import sys
from random import getrandbits

from Crypto.Util.number import getPrime, long_to_bytes, bytes_to_long

# Path of the flag served by menu option [2].
FLAG_PATH = "/challenge/flag.txt"


def keygen():
    """Generate an RSA key pair with public exponent 3.

    Each 1024-bit prime p is resampled until ``(p - 1) % 3 != 0`` so that
    3 is invertible modulo p - 1 and e = 3 is a valid exponent.

    Returns:
        ``(e, n)`` with ``e == 3`` and ``n`` the product of the two primes.
    """
    e = 3
    primes = []
    for _ in range(2):
        while True:
            p = getPrime(1024)
            if (p - 1) % 3:
                break
        primes.append(p)
    return e, primes[0] * primes[1]


def pad(m, n):
    """Pad integer ``m`` PKCS#1 v1.5 type-2 style for modulus ``n``.

    Layout (big-endian bytes): ``00 02 || PS || 00 || M``, with PS sized so
    the block has the same byte length as ``n``.

    NOTE: PS comes from ``random.getrandbits`` (Mersenne Twister, predictable)
    and may contain zero bytes, unlike real PKCS#1 v1.5. This is the intended
    weakness of the challenge and is deliberately left as-is.

    Returns:
        The padded integer, or -1 if ``m`` is too long
        (``len(M) >= len(n) - 11`` bytes).
    """
    ms = long_to_bytes(m)
    ns = long_to_bytes(n)
    if len(ms) >= len(ns) - 11:
        return -1
    padlength = len(ns) - len(ms) - 3
    ps = long_to_bytes(getrandbits(padlength * 8)).rjust(padlength, b"\x00")
    return int.from_bytes(b"\x00\x02" + ps + b"\x00" + ms, "big")


def encrypt(m, e, n):
    """Pad ``m`` and print its RSA ciphertext ``pad(m)^e mod n``.

    Prints ``c: <int>`` on success, or an error line when ``m`` is too long.
    """
    res = pad(m, n)
    if res != -1:
        print(f"c: {pow(res, e, n)}")
    else:
        print("error :(", "message too long")


def _read_flag():
    """Return the raw flag bytes, closing the file handle afterwards."""
    with open(FLAG_PATH, "rb") as f:
        return f.read()


menu = """
[1] enc()
[2] enc(flag)
[3] quit
"""[1:]

e, n = keygen()
print(f"e: {e}")
print(f"n: {n}")
assert len(_read_flag()) < 55

while True:
    try:
        print(menu)
        opt = input("opt: ")
        if opt == "1":
            encrypt(int(input("msg: ")), e, n)
        elif opt == "2":
            encrypt(bytes_to_long(_read_flag()), e, n)
        elif opt == "3":
            print("bye")
            sys.exit(0)
        else:
            print("idk")
    # Bind the exception to ``err``, not ``e``: ``except ... as e`` would
    # overwrite the global public exponent, and Python deletes the ``as``
    # target when the handler finishes, so every later encryption would fail
    # with NameError.
    except Exception as err:
        print("error :(", err)
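平文を十分に長く(n のバイト長 − 15 バイト)して padding を 12 バイト(96 ビット)まで短くすれば、未知部分が小さくなるので univariate Coppersmith method で padding を求められる。padding は Python 標準の random(Mersenne Twister)の getrandbits で生成されているので、1 回の暗号化で得られる 32 ビット出力 3 個を 208 回分、計 624 個集めて内部状態を復元すれば、未来の出力(フラグ暗号化時の padding)がわかる。padding が既知になればフラグ(55 バイト未満)だけが小さな未知数として残るので、もう一度 Coppersmith method で求められる。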
# Keep the stdlib ``random`` imports AFTER ``from sage.all import *`` so they
# are not shadowed by anything the star-import exports.
from sage.all import *
from random import getrandbits
import random
from Crypto.Util.number import getPrime, long_to_bytes, bytes_to_long
from ptrlib import Socket


class MersenneTwister(random.Random):
    """Pure-Python MT19937 whose internal state can be rebuilt from outputs.

    The state is twisted incrementally, one word per ``next()`` call, at index
    ``p``. ``setoutputs`` loads 624 consecutive tempered 32-bit outputs, after
    which further calls predict the generator that produced them.
    """

    N = 624  # state size in 32-bit words
    M = 397  # twist offset
    A = [0, 0x9908b0df]  # twist matrix constant, selected by the low bit
    UPPER_MASK = 0x80000000
    LOWER_MASK = 0x7fffffff

    def __init__(self, x=None):
        self.seed(x)

    def seed(self, a=None, version=2):
        """Seed exactly like ``random.Random(a)`` by copying its state."""
        r = random.Random()
        r.seed(a, version)
        self.setstate(r.getstate())

    def getstate(self):
        """Return a ``random.Random``-compatible state tuple.

        NOTE(review): the original author suspected getstate/setstate are
        buggy; the solver below only relies on ``setoutputs``.
        """
        t = self.clone()
        p = t.p
        if p == 0:
            p = self.N
        else:
            for _ in range(self.N - p):
                t.next()
        return (3, tuple(t.state + [p]), None)

    def _getstate(self):
        """Return the raw state: the N words followed by the index ``p``."""
        return tuple(self.state + [self.p])

    def setstate(self, state):
        """Load a ``random.Random``-style state tuple (see note on getstate)."""
        assert len(state) == 3
        assert len(state[1]) == self.N + 1
        self._setstate(state[1])
        if self.p != 0:
            p = self.p
            self.p = 0
            for _ in range(self.N - p):
                self.prev()
            self.p = p

    def _setstate(self, state):
        """Load the raw state produced by ``_getstate``."""
        self.state = list(state[:self.N])
        self.p = state[self.N] % self.N

    def getrandbits(self, k: int) -> int:
        """Return a k-bit integer using CPython's word layout.

        32-bit outputs are concatenated little-endian (first output = lowest
        word); for the final partial word only its top bits are kept.
        """
        if k < 0:
            raise ValueError("number of bits must be non-negative")
        if k == 0:
            return 0
        if k <= 32:
            return self.next() >> (32 - k)
        b = b''
        while k > 0:
            x = self.next()
            if k < 32:
                x = x >> (32 - k)
            b += x.to_bytes(4, 'little')
            k -= 32
        return int.from_bytes(b, 'little')

    def next_value(self) -> int:
        """Compute (without storing) the twisted word for index ``p``."""
        p, q = self.p, (self.p + 1) % self.N
        # update state
        a = self.state[p] & self.UPPER_MASK
        b = self.state[q] & self.LOWER_MASK
        x = a | b
        k = (p + self.M) % self.N
        return self.state[k] ^ (x >> 1) ^ self.A[x & 1]

    def next(self) -> int:
        """Twist one word, advance ``p`` and return the tempered output."""
        y = self.next_value()
        self.state[self.p] = y
        self.p = (self.p + 1) % self.N
        return self._tempering(y)

    def prev(self) -> int:
        """Step one word backwards, restoring the word the last ``next`` replaced."""
        p = (self.p - 2) % self.N
        k = (p + self.M) % self.N
        t = self.state[p] ^ self.state[k]  # (x >> 1) ^ self.A[x & 1]
        x_ = t ^ self.A[t >> 31]  # x >> 1 because t>>31 == 1 iff x&1 == 1)
        q, l = (p + 1) % self.N, (k + 1) % self.N
        head = ((self.state[q] ^ self.state[l]) << 1) & self.UPPER_MASK
        body = (x_ << 1) & self.LOWER_MASK
        tail = t >> 31
        self.p = (self.p - 1) % self.N
        self.state[self.p] = head | body | tail
        return self._tempering(self.next_value())

    def random(self):
        """Return a float in [0, 1) built from two outputs (27 + 26 bits)."""
        a = self.next() >> 5
        b = self.next() >> 6
        return ((a * 2**26 + b) * (1.0 / 2**53))

    def _tempering(self, y):
        """Apply the MT19937 output tempering to a 32-bit word."""
        y ^= y >> 11
        y ^= (y << 7) & 0x9d2c5680
        y ^= (y << 15) & 0xefc60000
        y ^= y >> 18
        return y

    def _untempering(self, y):
        """Invert ``_tempering``, recovering the raw state word."""
        y ^= y >> 18
        y ^= (y << 15) & 0xefc60000
        y ^= (((y << 7) & 0x9d2c5680) ^ ((y << 14) & 0x94284000)
              ^ ((y << 21) & 0x14200000) ^ ((y << 28) & 0x10000000))
        y ^= (y >> 11) ^ (y >> 22)
        return y

    def setoutputs(self, outputs):
        """Load N consecutive tempered outputs as the state, with ``p = 0``."""
        assert len(outputs) == self.N
        self._setstate([self._untempering(o) for o in outputs] + [0])

    def clone(self):
        """Return an independent copy with the same state and index."""
        t = MersenneTwister()
        t._setstate(self._getstate())
        return t


sock = Socket("193.57.159.27", 40341)
e = int(sock.recvlineafter("e: "))
n = int(sock.recvlineafter("n: "))
ns = len(long_to_bytes(n))  # modulus length in bytes

# With an ms-byte message the server's PS is ns - ms - 3 = 12 bytes = 96 bits,
# i.e. exactly three 32-bit MT outputs per query.
ms = ns - 15
PAD_LEN = ns - ms - 3
WORDS_PER_QUERY = PAD_LEN // 4

PR = PolynomialRing(Zmod(n), name="x")
x = PR.gen()

# Phase 1: leak 624 MT outputs through the padding of chosen messages.
outputs = []
for _ in range(MersenneTwister.N // WORDS_PER_QUERY):
    # Force the top bit so m is exactly ms bytes long; a zero leading byte
    # would make the server use 13 padding bytes (4 MT words) and break the
    # extraction below.
    m = getrandbits(ms * 8) | (1 << (ms * 8 - 1))
    sock.sendlineafter("opt: ", "1")
    sock.sendlineafter("msg: ", str(m))
    c = int(sock.recvlineafter("c: "))
    # Padded plaintext = (0x02 || PS || 0x00) * 2^(8*ms) + m; the unknown
    # (0x02 || PS || 0x00) is 14 bytes < 2^(8*16), far below n^(1/e).
    f = (x * 2**(ms * 8) + m)**e - c
    f = f.monic()
    roots = f.small_roots(2**(8 * 16))
    if not roots:
        raise RuntimeError("Coppersmith failed to recover the padding")
    root = int(roots[0])
    bits = (root >> 8) % 2**(8 * PAD_LEN)  # drop trailing 0x00 and leading 0x02
    # getrandbits packs outputs little-endian: the first output is the low word.
    outputs.extend((bits >> (32 * i)) & 0xffffffff for i in range(WORDS_PER_QUERY))
print(len(outputs))

# Phase 2: predict the flag's padding and solve for the flag itself.
sock.sendlineafter("opt: ", "2")
c = int(sock.recvlineafter("c: "))

# The flag length is unknown (the server asserts it is < 55 bytes): try all.
for mlen in range(1, 55):
    print("len=", mlen)
    # Rebuild the predictor each iteration: the flag encryption consumed exactly
    # one getrandbits call right after the 624 recorded outputs.
    mt = MersenneTwister()
    mt.setoutputs(outputs)
    pad_length = ns - mlen - 3
    # Must match the server byte-for-byte, including rjust of leading zeros.
    ps = long_to_bytes(mt.getrandbits(pad_length * 8)).rjust(pad_length, b"\x00")
    prefix = bytes_to_long(b"\x00\x02" + ps + b"\x00")
    f = c - (prefix * 2**(8 * mlen) + x)**e
    f = f.monic()
    # Roots are modulo n itself, so beta = 1.0. A <= 54-byte flag (432 bits)
    # lies below n^(1/3 - 1/16) ~ 2^554 for a 2048-bit n.
    roots = f.small_roots(X=2**(8 * mlen), beta=1.0, epsilon=1 / 16)
    for r in roots:
        print(r, long_to_bytes(int(r)))
    if roots:
        break