Zh3r0 CTF V2 | Twist and Shout

#zh3ro_CTF_2021

from secret import flag
import os
import random

state_len = 624*4
right_pad = random.randint(0,state_len-len(flag))
left_pad = state_len-len(flag)-right_pad
state_bytes = os.urandom(left_pad)+flag+os.urandom(right_pad)
state = tuple( int.from_bytes(state_bytes[i:i+4],'big') for i in range(0,state_len,4) )
random.setstate((3,state+(624,),None))
outputs = [random.getrandbits(32) for i in range(624)]
print(*outputs,sep='\n')
  • Mersenne Twister で624個の値出力が与えられているので、内部状態を構成して次の出力を予測することができる。

  • 問題は624個の出力の前の状態を復元するところにある。つまりメルセンヌツイスターを逆に回せるか、という問題

メルセンヌツイスターの実装では状態の注目している部分の最下位bitが0か1かによって分岐が発生するが、その部分をうまく落とし込んでやればz3による復元が可能

from Crypto.Util.number import long_to_bytes
from pwn import *
from z3 import *


# http://inaz2.hatenablog.com/entry/2016/03/07/194147
def untemper(x):
    x = unBitshiftRightXor(x, 18)
    x = unBitshiftLeftXor(x, 15, 0xEFC60000)
    x = unBitshiftLeftXor(x, 7, 0x9D2C5680)
    x = unBitshiftRightXor(x, 11)
    return x


def unBitshiftRightXor(x, shift):
    i = 1
    y = x
    while i * shift < 32:
        z = y >> shift
        y = x ^ z
        i += 1
    return y


def unBitshiftLeftXor(x, shift, mask):
    i = 1
    y = x
    while i * shift < 32:
        z = y << shift
        y = x ^ (z & mask)
        i += 1
    return y


_r = remote("crypto.zh3r0.cf", 5555)
outputs = []
for _ in range(624):
    outputs.append(int(_r.recvline()))
state = tuple([untemper(x) for x in outputs] + [624])
_r.close()

# https://hg.python.org/cpython/file/2.7/Modules/_randommodule.c#l95
N = 624
M = 397
MATRIX_A = 0x9908B0DF
UPPER_MASK = 0x80000000
LOWER_MASK = 0x7FFFFFFF
xs = [BitVec(f"x_{i}", 33) for i in range(624)]
mt = xs.copy()
for kk in range(N - M):
    y = (mt[kk] & UPPER_MASK) | (mt[kk + 1] & LOWER_MASK)
    mt[kk] = mt[kk + M] ^ (y >> 1) ^ (y % 2) * MATRIX_A
for kk in range(N - M, N - 1):
    y = (mt[kk] & UPPER_MASK) | (mt[kk + 1] & LOWER_MASK)
    mt[kk] = mt[kk + (M - N)] ^ (y >> 1) ^ (y % 2) * MATRIX_A
y = (mt[N - 1] & UPPER_MASK) | (mt[0] & LOWER_MASK)
mt[N - 1] = mt[M - 1] ^ (y >> 1) ^ (y % 2) * MATRIX_A

s = Solver()
for i in range(624):
    s.add(mt[i] == state[i])
assert s.check() == sat
m = s.model()
last_state = [m[xs[i]].as_long() for i in range(624)]
state_bytes = b"".join([long_to_bytes(s) for s in last_state])
print(state_bytes[state_bytes.find(b"zh3r0"):])