RaRCTF 2021 | randompad

#rarctf2021

from random import getrandbits
from Crypto.Util.number import getPrime, long_to_bytes, bytes_to_long

def keygen():  # normal rsa key generation
    """Create a fresh RSA public key.

    Returns:
        A tuple ``(e, n)`` where ``e`` is the fixed public exponent 3 and
        ``n`` is the product of two primes obtained from ``getPrime(1024)``.
    """
    e = 3

    def _draw_prime():
        # Keep drawing until p - 1 is not divisible by 3, so that the
        # exponent 3 shares no factor with p - 1.
        candidate = getPrime(1024)
        while (candidate - 1) % 3 == 0:
            candidate = getPrime(1024)
        return candidate

    first = _draw_prime()
    second = _draw_prime()
    return e, first * second

def pad(m, n):  # pkcs#1 v1.5
    """Wrap the integer message ``m`` in a PKCS#1 v1.5-style block for modulus ``n``.

    The block layout is ``00 02 || filler || 00 || message`` and has the same
    byte length as ``n``.

    Args:
        m: Message as a non-negative integer.
        n: RSA modulus.

    Returns:
        The padded block as an integer, or ``-1`` when the message leaves
        fewer than 8 filler bytes (i.e. it is too long for this modulus).
    """
    message = long_to_bytes(m)
    modulus_len = len(long_to_bytes(n))
    if len(message) >= modulus_len - 11:
        return -1

    # Three bytes are fixed (00, 02 and the 00 separator); the rest is filler.
    fill_len = modulus_len - len(message) - 3
    # The filler comes from random.getrandbits and is left-padded with zero
    # bytes, so it always occupies exactly fill_len bytes.
    filler = long_to_bytes(getrandbits(fill_len * 8)).rjust(fill_len, b"\x00")
    block = b"".join((b"\x00\x02", filler, b"\x00", message))
    return int.from_bytes(block, "big")

def encrypt(m, e, n):  # standard rsa
    """Pad ``m``, encrypt it under the public key ``(e, n)`` and print the result.

    Prints ``c: <ciphertext>`` on success, or an error line when ``pad``
    reports (via ``-1``) that the message does not fit the modulus.
    Nothing is returned.
    """
    padded = pad(m, n)
    if padded == -1:
        print("error :(", "message too long")
        return
    ciphertext = pow(padded, e, n)
    print(f"c: {ciphertext}")

menu = """
[1] enc()
[2] enc(flag)
[3] quit
"""[1:]

FLAG_PATH = "/challenge/flag.txt"


def _read_flag():
    """Return the raw bytes of the flag file, closing the handle promptly."""
    with open(FLAG_PATH, "rb") as flag_file:
        return flag_file.read()


# Publish the public key once at start-up; the same key serves every request.
e, n = keygen()
print(f"e: {e}")
print(f"n: {n}")
# Start-up sanity check on the deployed flag file.
assert len(_read_flag()) < 55

# Interactive menu loop: encrypt a user-supplied integer, encrypt the flag,
# or quit.  Errors in a single request are reported and the loop continues.
while True:
    try:
        print(menu)
        opt = input("opt: ")
        if opt == "1":
            encrypt(int(input("msg: ")), e, n)
        elif opt == "2":
            encrypt(bytes_to_long(_read_flag()), e, n)
        elif opt == "3":
            print("bye")
            exit(0)
        else:
            print("idk")
    except EOFError:
        # The client closed its input; without this the loop would spin
        # forever reporting the same EOF error.
        exit(0)
    except Exception as err:
        # The exception must not be bound to ``e``: Python deletes the
        # ``except ... as`` name when the handler ends, which would remove
        # the global RSA exponent and break every later encryption.
        print("error :(", err)

RaRCTF 2021 | unrandompadの続き

平文をめちゃくちゃ小さくすればunivariate coppersmith methodからpaddingが求められる。paddingはpython標準のランダムで生成されているので、Mersenne Twisterの状態を復元すれば未来の出力がわかる。

from sage.all import *
from random import getrandbits
import random
from Crypto.Util.number import getPrime, long_to_bytes, bytes_to_long
from ptrlib import Socket

class MersenneTwister(random.Random):
    """Pure-Python MT19937 whose output stream matches ``random.Random``.

    The generator keeps its 624 state words in ``self.state`` as a rolling
    window and an index ``self.p`` pointing at the word that the next call to
    :meth:`next` will overwrite.  Each step replaces exactly one word, which
    makes it possible to step backwards (:meth:`prev`) and to rebuild the
    generator from any 624 consecutive 32-bit outputs (:meth:`setoutputs`).

    ``getstate``/``setstate`` convert to and from the tuple format used by
    ``random.Random`` so states can be exchanged with the stdlib generator.
    """

    N = 624                      # number of 32-bit words in the state
    M = 397                      # offset of the word mixed into each update
    A = [0, 0x9908b0df]          # twist constant, indexed by the low bit of x
    UPPER_MASK = 0x80000000      # most significant bit of a word
    LOWER_MASK = 0x7fffffff      # least significant 31 bits of a word

    def __init__(self, x=None):
        """Create a generator seeded with ``x`` (same meaning as ``random.Random``)."""
        self.seed(x)

    def seed(self, a=None, version=2):
        """Seed exactly like ``random.Random.seed`` by borrowing its state."""
        r = random.Random()
        r.seed(a, version)
        self.setstate(r.getstate())

    def getstate(self):
        """Return the state as a ``random.Random``-compatible tuple.

        The stdlib format stores a fully regenerated word array plus the
        number of words already consumed, so a copy of this generator is
        advanced to the end of its current block before the words are read.
        """
        t = self.clone()
        p = t.p
        if p == 0:
            # Nothing consumed from the upcoming block: the stdlib encodes
            # this as "index == N" over the current words.
            p = self.N
        else:
            for _ in range(self.N - p):
                t.next()
        return (3, tuple(t.state + [p]), getattr(self, "gauss_next", None))

    def _getstate(self):
        """Return the raw rolling state: the 624 words followed by ``p``."""
        return tuple(self.state + [self.p])

    def setstate(self, state):
        """Load a state tuple produced by ``random.Random.getstate``.

        Args:
            state: ``(version, words + (index,), gauss_next)``; the version
                field is not inspected.
        """
        assert len(state) == 3
        assert len(state[1]) == self.N + 1
        self._setstate(state[1])
        # Without this attribute the inherited ``gauss()`` raises
        # AttributeError, because this class never runs Random.__init__.
        self.gauss_next = state[2]

        if self.p != 0:
            # The stdlib array is fully regenerated, whereas the rolling
            # representation expects words p..N-1 to still hold their
            # pre-update values.  Undo those N - p updates.
            p = self.p
            self.p = 0
            for _ in range(self.N - p):
                self.prev()
            self.p = p

    def _setstate(self, state):
        """Load a raw rolling state: 624 words followed by the index."""
        self.state = list(state[:self.N])
        self.p = state[self.N] % self.N

    def getrandbits(self, k: int) -> int:
        """Return a non-negative integer with ``k`` random bits.

        Words are consumed 32 bits at a time, least significant word first;
        a final partial word contributes its top bits.

        Raises:
            ValueError: if ``k`` is negative.
        """
        if k < 0:
            raise ValueError("number of bits must be non-negative")

        result = 0
        shift = 0
        while k > 0:
            x = self.next()
            if k < 32:
                x >>= 32 - k
            result |= x << shift
            shift += 32
            k -= 32
        return result

    def next_value(self) -> int:
        """Compute (without storing) the word that will replace ``state[p]``."""
        p, q = self.p, (self.p + 1) % self.N
        # Combine the top bit of the current word with the low 31 bits of
        # the following one.
        a = self.state[p] & self.UPPER_MASK
        b = self.state[q] & self.LOWER_MASK
        x = a | b

        k = (p + self.M) % self.N
        return self.state[k] ^ (x >> 1) ^ self.A[x & 1]

    def next(self) -> int:
        """Advance one step and return the tempered 32-bit output."""
        y = self.next_value()
        self.state[self.p] = y
        self.p = (self.p + 1) % self.N
        return self._tempering(y)

    def prev(self) -> int:
        """Undo the most recent :meth:`next` and return the output it produced."""
        p = (self.p - 2) % self.N
        k = (p + self.M) % self.N

        # t == (x >> 1) ^ A[x & 1] for the update that produced state[p].
        t = self.state[p] ^ self.state[k]
        # A[1] has its top bit set and x >> 1 does not, so the top bit of t
        # equals x & 1; removing A[...] leaves x >> 1.
        x_ = t ^ self.A[t >> 31]

        # The word being restored contributed its top bit to the update of
        # its own slot (q) and its low 31 bits to the update of slot p.
        q, l = (p + 1) % self.N, (k + 1) % self.N
        head = ((self.state[q] ^ self.state[l]) << 1) & self.UPPER_MASK
        body = (x_ << 1) & self.LOWER_MASK
        tail = t >> 31

        self.p = (self.p - 1) % self.N
        self.state[self.p] = head | body | tail
        return self._tempering(self.next_value())

    def random(self):
        """Return a float in [0.0, 1.0) built from 53 random bits (two words)."""
        a = self.next() >> 5
        b = self.next() >> 6
        return ((a * 2**26 + b) * (1.0 / 2**53))

    def _tempering(self, y):
        """Apply the MT19937 output transform to a raw state word."""
        y ^= y >> 11
        y ^= (y << 7) & 0x9d2c5680
        y ^= (y << 15) & 0xefc60000
        y ^= y >> 18
        return y

    def _untempering(self, y):
        """Invert :meth:`_tempering`, recovering the raw state word."""
        y ^= y >> 18
        y ^= (y << 15) & 0xefc60000
        # The masked left shift by 7 is undone in a single expression; each
        # further term uses the mask ANDed with itself shifted by 7 again.
        y ^= ((y << 7) & 0x9d2c5680) ^ ((y << 14) & 0x94284000) ^ ((y << 21) & 0x14200000) ^ ((y << 28) & 0x10000000)
        y ^= (y >> 11) ^ (y >> 22)
        return y

    def setoutputs(self, outputs):
        """Rebuild the state from 624 consecutive 32-bit outputs.

        After this call the generator continues the observed stream: the
        next output is the one that follows ``outputs[-1]``.
        """
        assert len(outputs) == self.N
        self._setstate([self._untempering(o) for o in outputs] + [0])
        # Any cached gauss() value belonged to the previous stream.
        self.gauss_next = None

    def clone(self):
        """Return an independent copy with the same state."""
        t = MersenneTwister()
        t._setstate(self._getstate())
        t.gauss_next = getattr(self, "gauss_next", None)
        return t


sock = Socket("193.57.159.27", 40341)
e = int(sock.recvlineafter("e: "))
n = int(sock.recvlineafter("n: "))

ns = len(long_to_bytes(n))
# Probe messages are 15 bytes shorter than the modulus, so the server's
# random pad is ns - ms - 3 = 12 bytes, i.e. exactly three 32-bit MT words.
ms = ns - 15

PR = PolynomialRing(Zmod(n), name="x")
x = PR.gen()
outputs = []

# Phase 1: collect 624 consecutive 32-bit outputs of the server's PRNG,
# three per query, by recovering the padding of a known plaintext.
for _ in range(MersenneTwister.N // 3):
    # Force the top bit so the message always encodes to exactly ms bytes.
    # A shorter message would make the server draw a longer pad, consuming
    # an extra MT word and desynchronising the collected outputs.
    m = getrandbits(ms * 8) | (1 << (ms * 8 - 1))
    sock.sendlineafter("opt: ", "1")
    sock.sendlineafter("msg: ", str(m))
    c = int(sock.recvlineafter("c: "))

    # Unknown x is the block prefix 00 02 || pad || 00 sitting above m.
    f = ((x)*(2**(ms*8)) + m)**3 - c
    f = f.monic()

    roots = f.small_roots(2**(8 * 16))
    if not roots:
        raise RuntimeError("small_roots found no padding for a probe message")
    root = int(roots[0])
    # Drop the trailing 00 separator and the leading 02 marker, leaving
    # the 96-bit pad.
    bits = (root >> 8) % (2**96)

    # getrandbits fills the least significant 32-bit word first.
    outputs.append(bits & 0xffffffff)
    outputs.append((bits >> 32) & 0xffffffff)
    outputs.append((bits >> 64) & 0xffffffff)

    print(len(outputs))


# Phase 2: request the encrypted flag; its pad is the PRNG's next output.
sock.sendlineafter("opt: ", "2")
c = int(sock.recvlineafter("c: "))

# The flag length is unknown, so try each candidate length; every attempt
# must start from the same recovered PRNG state.
for mlen in range(10, 55):
    print("len=", mlen)
    # Strictly speaking setstate/getstate could be used instead of
    # rebuilding the generator each time, but that path was suspected of
    # being buggy.
    mt = MersenneTwister()
    mt.setoutputs(outputs)

    pad_length = ns - mlen - 3
    # Left-pad with zero bytes exactly as the server does, otherwise a pad
    # whose top byte is zero would shift the whole prefix.
    ps = long_to_bytes(mt.getrandbits(pad_length * 8)).rjust(pad_length, b"\x00")
    pad_bytes = b"\x00\x02" + ps + b"\x00"
    f = c - (bytes_to_long(pad_bytes) * 2**(8*mlen) + x)**3
    f = f.monic()
    roots = f.small_roots(X=2**(8*mlen), beta=0.02)
    for r in roots:
        print(r)