Truncated LCG

#yoshicamp2020winter

通常のLCGのパラメータ a, b, m

出力系列 x

  •  x_1 = a*x_0 + b \mod m

  •  x_2 = a * x_1 + b \mod m

  •  \dots

Truncated LCGでは(例えば) x_iの下位半分がわからないときに、 x_{i+1}を求めたい

 x_i = 2^k y_i + z_iと書いて、 z_iが未知

まずb=0の場合

 x_1 = ax_0 \mod m

 x_2 = ax_1 = a^2x_0 \mod m

 x_3 = ax_2 = a^3x_0 \mod m

 \vdots

次のような等式を考えてみる

 \begin{pmatrix} x_0 & x_1 & \dots & x_n \end{pmatrix} \begin{pmatrix} m & a & a^2 & \dots & a ^n \  & -1 \  & & -1 \  & & & \ddots \  & & & & -1\end{pmatrix}^t = \begin{pmatrix} 0 & 0 & \dots & 0 \end{pmatrix} \mod m

ベクトルを \vec{x}と置き、行列を Lと置いて:  \vec{x}L = \vec{0} \mod m

 Lを基底簡約して Bにして:  \vec{x}B = \vec{0} \mod m

 \vec{x} = 2^k\vec{y} + \vec{z}なので (2^k\vec{y} + \vec{z})B = \vec{l}m

 \vec{z}B = \vec{l}m - 2^k\vec{y}B

 \vec{z} Bも小さいので \vec{z}B \simeq 0と近似して \vec{l} = \lceil 2^k\vec{y}B / m \rfloor

これで \vec{z}以外がわかったので、 \vec{z} = (\vec{l}m - 2^k\vec{y}B)*B^{-1}

[*  b\ne 0のとき]

 x_2 = ax_1 + b = a(ax_0 + b) + b = a^2x_0 + ab + b

 x_3 = a^3x_0 + a^2b + ab + b

 x_4 = a^4x_0 + a^3b + a^2b + ab + b

 x_i = a^ix_0 + f_iとおいて

 \begin{pmatrix} x_0 - b & x_1 - f_1 & \dots & x_n - f_n \end{pmatrix} \begin{pmatrix} m & a & a^2 & \dots & a ^n \  & -1 \  & & -1 \  & & & \ddots \  & & & & -1\end{pmatrix}^t = 0 \mod m

例題と回答

#LLL

mod = random_prime(1<<64)
a = randint(1, mod-1)
c = randint(1, mod-1)
seed = int.from_bytes(open("flag.txt","rb").read().strip(), "big")

class LCG():
  def __init__(self, a, c, mod, seed):
    self.a = a
    self.c = c
    self.mod = mod
    self.seed = seed

  def next(self):
    self.seed = (self.seed * self.a + self.c) % self.mod
    return self.seed

lcg = LCG(a, c, mod, seed)
outputs = []
for _ in range(10):
  r = lcg.next()
  half = r >> 32
  outputs.append(half)

print("a={}".format(a))
print("c={}".format(c))
print("mod={}".format(mod))
print("outputs={}".format(outputs))
def guess_state(xs, a, b, m):
    # s1 = a*s0 + b
    # s2 = a^2*s0 + ab + b
    # s3 = a^3*s0 + aab + ab + b
    # s4 = a^4*s0 + aaab + aab + ab + b

    n = len(xs)
    ks = [b]
    for i in range(1, n):
        ks.append(ks[i-1] + a^i * b)

    xxs = [(xs[i] << 32) - ks[i]  for i in range(n)]

    L = matrix(ZZ, n, 1, [m] + [a^i for i in range(1, n)])
    L = L.augment(matrix.identity(n)[:,1:] * -1)
    B = L.LLL()
    print(list(B))

    xB = B * vector(xxs)
    k = [round(x/m) for x in xB]
    yys = B.solve_right(vector(k) * m - xB)
    print(yys)

    s1 = (xs[0] << 32) + yys[0]
    s0 = (s1 - b) * inverse_mod(a, m) % m
    return s0

a=4984204341965293659
c=812992498535670640
mod=12007620771585163889
outputs=[1603731004, 1981439462, 843976945, 95071464, 1700919759, 296270126, 952176344, 1205445753, 1258468465, 143557791]

seed = guess_state(outputs, a, c, mod)
print(bytes.fromhex(hex(seed)[2:]))