from Crypto.Util.number import isPrime, getPrime, getRandomRange, inverse import os import signal signal.alarm(300) flag = os.environ.get("FLAG", "0nepoint{GOLDEN SMILE & SILVER TEARS}") flag = int(flag.encode().hex(), 16) P = 2 ** 1000 - 1 while not isPrime(P): P -= 2 p = getPrime(512) q = getPrime(512) e = 65537 phi = (p-1)*(q-1) d = inverse(e, phi) n = p*q key = getRandomRange(0, n) ciphertext = pow(flag, e, P) ^ key x1 = getRandomRange(0, n) x2 = getRandomRange(0, n) print("P = {}".format(P)) print("n = {}".format(n)) print("e = {}".format(e)) print("x1 = {}".format(x1)) print("x2 = {}".format(x2)) # pick a random number k and compute v = k**e + (x1|x2) # if you add x1, you can get key = c1 - k mod n # elif you add x2, you can get ciphertext = c2 - k mod n v = int(input("v: ")) k1 = pow(v - x1, d, n) k2 = pow(v - x2, d, n) print("c1 = {}".format((k1 + key) % n)) print("c2 = {}".format((k2 + ciphertext) % n))
RSAでが与えられる Oblivious Transfer
receiver(player)はという値を作ってまたはを計算して送り、sender (server)はとして、とを送る
receiverはまたはを復号可能、というのが本来のプロトコル
今回は平文をランダムに暗号化したか、その暗号化に用いられているのいずれかが手にはいる
この両方をなんとかして手に入れたい、というのが今回の目的
ct + key mod nの算出
として が成立するような値を送ってみる。するとが奇数なのでとなる。
このとき, となるので、を計算することができる。
ここからを求めたい。は何回でも求められるとき、を復元できるだろうか。
ct + keyの復元
とりあえずだと扱いにくいので、これをにする。はいつでも固定なので、の最下位bitも当然いつでも同じになる。最下位bitについては加算と論理和は区別がないので、の結果の偶奇も常に固定される。
偶奇は2通りしかないので両方のパターンを試してみればいい。例えばが偶数だろう、ということにするとが奇数ならその値はということになるから、を復元するにはを計算すれば良い(は奇数)
ct ^ keyの復元
であるようなが十分に集まればを復元できる。
のbit目が1のとき、繰り上がりがなければとなるはずである。
ならば必ず次のbitに繰り上がりがあるので、は必ずになっていつも同じ
ならば繰り上がりは必ずないのでは必ずになっていつも同じ
のbit目が0のときはとなるが、この時繰り上がりがあるかどうかは不定なので、はによって変わる。
したがってを十分集めて、すべてのが一致するようなについて、のビット目は1である、一致しない場合は0である、と言えそう
from Crypto.Util.number import inverse, isPrime from ptrlib import Socket P = 2 ** 1000 - 1 while not isPrime(P): P -= 2 e = 65537 oracles_0 = [] oracles_1 = [] for i in range(15): sock = Socket("localhost", 9999) assert P == int(sock.recvlineafter("P = ")) n = int(sock.recvlineafter("n = ")) assert e == int(sock.recvlineafter("e = ")) x1 = int(sock.recvlineafter("x1 = ")) x2 = int(sock.recvlineafter("x2 = ")) v = (x1 + x2) * inverse(2, n) % n sock.sendlineafter("v: ", str(v)) c1 = int(sock.recvlineafter("c1 = ")) c2 = int(sock.recvlineafter("c2 = ")) x = (c1 + c2) % n oracles_0.append(x if x % 2 == 0 else x + n) oracles_1.append(x if x % 2 == 1 else x + n) assert len(set([o % 2 for o in oracles_0])) == 1 assert len(set([o % 2 for o in oracles_1])) == 1 def solve(oracles): bit_len = max([o.bit_length() for o in oracles]) r = 0 for i in range(bit_len - 1): zero_cases = set() one_cases = set() for o in oracles: if (o >> i) & 1 == 0: zero_cases.add((o >> (i+1)) & 1) else: one_cases.add((o >> (i+1)) & 1) if len(zero_cases) <= 1 and len(one_cases) <= 1: r |= 1 << i return r if (r := solve(oracles_0)) != 0: flag = pow(r, inverse(e, P - 1), P) print(flag.to_bytes((flag.bit_length() + 7) // 8, "big")) if (r := solve(oracles_1)) != 0: flag = pow(r, inverse(e, P - 1), P) print(flag.to_bytes((flag.bit_length() + 7) // 8, "big"))