zer0pts CTF 2022 | OK

#zer0ptsCTF2022

from Crypto.Util.number import isPrime, getPrime, getRandomRange, inverse
import os
import signal

signal.alarm(300)

flag = os.environ.get("FLAG", "0nepoint{GOLDEN SMILE & SILVER TEARS}")
flag = int(flag.encode().hex(), 16)

P = 2 ** 1000 - 1
while not isPrime(P): P -= 2

p = getPrime(512)
q = getPrime(512)
e = 65537
phi = (p-1)*(q-1)
d = inverse(e, phi)
n = p*q

key = getRandomRange(0, n)
ciphertext = pow(flag, e, P) ^ key

x1 = getRandomRange(0, n)
x2 = getRandomRange(0, n)

print("P = {}".format(P))
print("n = {}".format(n))
print("e = {}".format(e))
print("x1 = {}".format(x1))
print("x2 = {}".format(x2))

# pick a random number k and compute v = k**e + (x1|x2)
# if you add x1, you can get key = c1 - k mod n
# elif you add x2, you can get ciphertext = c2 - k mod n
v = int(input("v: "))

k1 = pow(v - x1, d, n)
k2 = pow(v - x2, d, n)

print("c1 = {}".format((k1 + key) % n))
print("c2 = {}".format((k2 + ciphertext) % n))

RSA n, e, x_1, x_2が与えられる Oblivious Transfer

receiver(player)は kという値を作って v = x_1 + k^e \mod nまたは v = x_2 + k^e \mod nを計算して送り、sender (server)は k_1 = (v - x_1)^d \mod n, k_2 = (v - x_2)^d \mod nとして、 c_1 = m_1 + k_1 c_2 = m_2 + k_2を送る

receiverは m_1 = c_1 - kまたは m_2 = c_2 - kを復号可能、というのが本来のプロトコル

今回は平文をランダムに暗号化した ct = (m^e \mod P) \oplus keyか、その暗号化に用いられている keyのいずれかが手にはいる

この両方をなんとかして手に入れたい、というのが今回の目的

ct + key mod nの算出

 vとして  v - x_1 \equiv -(v - x_2) \mod nが成立するような値を送ってみる。すると dが奇数なので k_1 \equiv (v - x1)^d \equiv (-(v - x_2))^d \equiv -(v - x_2)^d \equiv -k_2 \mod nとなる。

このとき c_1 \equiv k_1 + ct \mod n,  c_2 \equiv -k_1 + key \mod nとなるので、 c_1 + c_2 \equiv ct + key \mod nを計算することができる。

ここから ct \oplus keyを求めたい。 ct + key \mod nは何回でも求められるとき、 ct \oplus keyを復元できるだろうか。

ct + keyの復元

とりあえず ct + key \mod nだと扱いにくいので、これを ct + keyにする。 ct \oplus key = flag^eはいつでも固定なので、 ct \oplus keyの最下位bitも当然いつでも同じになる。最下位bitについては加算と論理和は区別がないので、 ct + keyの結果の偶奇も常に固定される。

偶奇は2通りしかないので両方のパターンを試してみればいい。例えば ct + keyが偶数だろう、ということにすると ct + key \mod nが奇数ならその値は ct + key - nということになるから、 ct + keyを復元するには (ct + key \mod n) + nを計算すれば良い( nは奇数)

ct ^ keyの復元

 ct_i \oplus key_i = Xであるような ct_i + key_i = Y_iが十分に集まれば Xを復元できる。

 X jbit目が1のとき、繰り上がりがなければ ct_i + key_i = 1となるはずである。

 Y_{i,j} = 0ならば必ず次のbitに繰り上がりがあるので、 Y_{i,j+1}は必ず ct_{i,j} + key_{i,j} + 1になっていつも同じ

 Y_{i,j} = 1ならば繰り上がりは必ずないので Y_{i,j+1}は必ず ct_{i,j} + key_{i,j}になっていつも同じ

 X jbit目が0のときは ct_i + key_i = 0となるが、この時繰り上がりがあるかどうかは不定なので、 Y_{i,j} iによって変わる。

したがって Y_iを十分集めて、すべての Y_{i,j}が一致するような jについて、 X j-1ビット目は1である、一致しない場合は0である、と言えそう

from Crypto.Util.number import inverse, isPrime
from ptrlib import Socket

P = 2 ** 1000 - 1
while not isPrime(P): P -= 2
e = 65537

oracles_0 = []
oracles_1 = []
for i in range(15):
    sock = Socket("localhost", 9999)
    assert P == int(sock.recvlineafter("P = "))
    n = int(sock.recvlineafter("n = "))
    assert e == int(sock.recvlineafter("e = "))
    x1 = int(sock.recvlineafter("x1 = "))
    x2 = int(sock.recvlineafter("x2 = "))

    v = (x1 + x2) * inverse(2, n) % n
    sock.sendlineafter("v: ", str(v))

    c1 = int(sock.recvlineafter("c1 = "))
    c2 = int(sock.recvlineafter("c2 = "))
    x = (c1 + c2) % n
    oracles_0.append(x if x % 2 == 0 else x + n)
    oracles_1.append(x if x % 2 == 1 else x + n)

assert len(set([o % 2 for o in oracles_0])) == 1
assert len(set([o % 2 for o in oracles_1])) == 1

def solve(oracles):
    bit_len = max([o.bit_length() for o in oracles])

    r = 0
    for i in range(bit_len - 1):
        zero_cases = set()
        one_cases = set()

        for o in oracles:
            if (o >> i) & 1 == 0:
                zero_cases.add((o >> (i+1)) & 1)
            else:
                one_cases.add((o >> (i+1)) & 1)

        if len(zero_cases) <= 1 and len(one_cases) <= 1:
            r |= 1 << i
    
    return r

        
if (r := solve(oracles_0)) != 0:
    flag = pow(r, inverse(e, P - 1), P)
    print(flag.to_bytes((flag.bit_length() + 7) // 8, "big"))

if (r := solve(oracles_1)) != 0:
    flag = pow(r, inverse(e, P - 1), P)
    print(flag.to_bytes((flag.bit_length() + 7) // 8, "big"))