from secret import flag import os import random state_len = 624*4 right_pad = random.randint(0,state_len-len(flag)) left_pad = state_len-len(flag)-right_pad state_bytes = os.urandom(left_pad)+flag+os.urandom(right_pad) state = tuple( int.from_bytes(state_bytes[i:i+4],'big') for i in range(0,state_len,4) ) random.setstate((3,state+(624,),None)) outputs = [random.getrandbits(32) for i in range(624)] print(*outputs,sep='\n')
Mersenne Twister で624個の値出力が与えられているので、内部状態を構成して次の出力を予測することができる。
問題は624個の出力の前の状態を復元するところにある。つまりメルセンヌツイスターを逆に回せるか、という問題
メルセンヌツイスターの実装では状態の注目している部分の最下位bitが0か1かによって分岐が発生するが、その部分をうまく落とし込んでやればz3による復元が可能
from Crypto.Util.number import long_to_bytes from pwn import * from z3 import * # http://inaz2.hatenablog.com/entry/2016/03/07/194147 def untemper(x): x = unBitshiftRightXor(x, 18) x = unBitshiftLeftXor(x, 15, 0xEFC60000) x = unBitshiftLeftXor(x, 7, 0x9D2C5680) x = unBitshiftRightXor(x, 11) return x def unBitshiftRightXor(x, shift): i = 1 y = x while i * shift < 32: z = y >> shift y = x ^ z i += 1 return y def unBitshiftLeftXor(x, shift, mask): i = 1 y = x while i * shift < 32: z = y << shift y = x ^ (z & mask) i += 1 return y _r = remote("crypto.zh3r0.cf", 5555) outputs = [] for _ in range(624): outputs.append(int(_r.recvline())) state = tuple([untemper(x) for x in outputs] + [624]) _r.close() # https://hg.python.org/cpython/file/2.7/Modules/_randommodule.c#l95 N = 624 M = 397 MATRIX_A = 0x9908B0DF UPPER_MASK = 0x80000000 LOWER_MASK = 0x7FFFFFFF xs = [BitVec(f"x_{i}", 33) for i in range(624)] mt = xs.copy() for kk in range(N - M): y = (mt[kk] & UPPER_MASK) | (mt[kk + 1] & LOWER_MASK) mt[kk] = mt[kk + M] ^ (y >> 1) ^ (y % 2) * MATRIX_A for kk in range(N - M, N - 1): y = (mt[kk] & UPPER_MASK) | (mt[kk + 1] & LOWER_MASK) mt[kk] = mt[kk + (M - N)] ^ (y >> 1) ^ (y % 2) * MATRIX_A y = (mt[N - 1] & UPPER_MASK) | (mt[0] & LOWER_MASK) mt[N - 1] = mt[M - 1] ^ (y >> 1) ^ (y % 2) * MATRIX_A s = Solver() for i in range(624): s.add(mt[i] == state[i]) assert s.check() == sat m = s.model() last_state = [m[xs[i]].as_long() for i in range(624)] state_bytes = b"".join([long_to_bytes(s) for s in last_state]) print(state_bytes[state_bytes.find(b"zh3r0"):])