Midnight Sun CTF 2022 Quals | Pelle's Rotor Supported Arithmetic

#MidnightSunCTF_2022_Quals

#!/usr/bin/python3
from sys import stdin, stdout, exit
from flag import FLAG
from secrets import randbelow
from gmpy import next_prime

# RSA key generation: p and q are each next_prime() of a random integer below 2**512,
# drawn in that order (p first, then q).
p, q = (int(next_prime(randbelow(2**512))) for _ in range(2))
n = p * q
e = 65537

# Private exponent matching e = 65537. Its decimal string is what the oracle rotates.
phi = (p - 1) * (q - 1)
d = int(pow(e, -1, phi))
d_len = len(str(d))

# The flag is NOT encrypted with e: it uses this large fixed exponent instead,
# so knowing d for e = 65537 alone does not decrypt it directly.
flag_exponent = 3331646268016923629
print("encrypted flag", pow(FLAG, flag_exponent, n))
stdout.flush()

ctr = 0
def oracle(c, i):
    """Return pow(c, rot(d, i), n) as an int.

    rot(d, i) cyclically moves the first ``i % d_len`` decimal digits of d to
    the end and re-reads the result as an integer (a leading zero is dropped).

    Only ``10 * d_len // 9 + 1`` calls are answered. After that, a refusal
    message is printed and None is returned without incrementing the counter.
    """
    global ctr
    budget = 10 * d_len // 9
    if ctr <= budget:
        ctr += 1
        digits = str(d)
        cut = i % d_len
        rotated = int(digits[cut:] + digits[:cut])
        return int(pow(c, rotated, n))
    print("Come on, that was already way too generous...")
    return

def banner():
    """Write the interactive menu to stdout and return the number of characters written."""
    return stdout.write("""
Pelle's Rotor Supported Arithmetic Oracle
1) Query the oracle with a ciphertext and rotation value.
2) Exit.
""")

banner()
stdout.flush()

# Menu dispatch table. The request loop calls the selected entry with two
# integer arguments (ciphertext, rotation), including for entry 2 (sys.exit).
choices = dict([
    (1, oracle),
    (2, exit),
])

while True:
    try:
        # Every request reads the same three lines: menu choice, ciphertext, rotation.
        selected = stdin.readline()
        print("c:")
        stdout.flush()
        cipher = stdin.readline()
        print("rot:")
        stdout.flush()
        rotation = stdin.readline()
        handler = choices.get(int(selected))
        result = handler(int(cipher), int(rotation))
        print(result)
        stdout.flush()
    except Exception as e:
        # Any failure ends the session after echoing the error text. Examples:
        # a non-integer line, an unknown choice (handler is None), or choice 2,
        # because sys.exit accepts at most one argument and raises TypeError here.
        stdout.write("%s\n" % e)
        stdout.flush()
        exit()

RSA の秘密鍵 $d$ に関するオラクルが得られる問題。類題としては TetCTF 2022 | fault や corCTF 2021 | dividing secrets など。

$c^{rot(d, 1)}$ を考えると、$d$ の未知の先頭の桁を $x$、$d$ の桁数を $l$ として以下が成り立つので、$x \in \{0, \dots, 9\}$ を全探索すればよい。

$$c^{rot(d, 1)} \equiv c^{10d} \cdot c^{-10^{l}x} \cdot c^{x} \pmod n$$

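一般に $rot(d, i)$ の先頭の桁を $x$ とすると $$c^{rot(d, i+1)} \equiv c^{10 \cdot rot(d, i) - 10^{l}x + x} \pmod n$$ が成り立つので、これを $i = 0, 1, \dots, l-1$ について繰り返せば $d$ の全桁が先頭から順に求まる。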

$n$ は $c = -1$ を入れれば手に入る（$d$ は奇数なので $(-1)^d \equiv n - 1 \pmod n$ となる）。

桁数 $l$ は、$c^{rot(d, 0)} = c^{rot(d, l)}$ となる $l$ を探せば良い（$l$ 桁ずらすと元の $d$ に戻るため）。

$d$ が分かれば、"how to factorize N given d" の方法で $n$ を素因数分解でき、あとはどうとでもなる。

from ptrlib import Socket
from Crypto.Util.number import inverse

def factorize(N, e, d):
    """Factor an RSA modulus N = p*q given a matching key pair (e, d).

    Standard randomized reduction: write k = e*d - 1 = 2**s * t with t odd.
    For a base g, walk x = g**(2**j * t) mod N for j = 0..s. At the first j
    where x is 1 modulo one prime factor but not the other, gcd(x - 1, N) is a
    proper factor. A random base succeeds with probability at least 1/2 for
    N = p*q with distinct odd primes.

    Returns:
        A tuple (f, N // f) with 1 < f < N.

    Raises:
        ValueError: if e*d - 1 is not a positive even integer, if it does not
            annihilate base 2 modulo N (d does not match e), or if no factor is
            found within the bounded number of bases tried.
    """
    from math import gcd

    k = d * e - 1
    if k <= 0 or k % 2 == 1:
        raise ValueError("e*d - 1 must be a positive even integer")
    if pow(2, k, N) != 1:
        raise ValueError("d is not a private exponent matching (N, e)")

    # Split k = 2**s * t with t odd.
    s, t = 0, k
    while t % 2 == 0:
        t //= 2
        s += 1

    # Bounded instead of `while True`, so bad inputs fail loudly rather than hang.
    max_bases = 1000
    for g in range(2, 2 + max_bases):
        # A base sharing a factor with N factors it outright.
        f = gcd(g, N)
        if 1 < f < N:
            return f, N // f
        # Check every element of the squaring chain, not only g**t: the
        # split usually appears at some later j.
        x = pow(g, t, N)
        for _ in range(s + 1):
            f = gcd(x - 1, N)
            if 1 < f < N:
                return f, N // f
            x = x * x % N
    raise ValueError(f"no factor of N found after {max_bases} bases")

# Connect to the challenge service. Its first line is the flag ciphertext,
# which the server encrypts with its large fixed exponent, not with 65537.
HOST, PORT = "localhost", 9999
sock = Socket(HOST, PORT)

flag_line = sock.recvlineafter("encrypted flag ")
target = int(flag_line)
# Skip the rest of the menu banner, which ends with the "2) Exit." entry.
sock.recvuntil("2) Exit.\n")

def oracle(c, r):
    """Send one query (menu choice 1) and return the server's reply as an int.

    The server answers with pow(c, rot(d, r), n). It only answers a limited
    number of queries; after that it replies with a text message instead of
    a number.

    Raises:
        ValueError: if the reply is not an integer (e.g. the query budget is
            exhausted). The message includes the raw reply for debugging.
    """
    sock.sendline("1")
    sock.sendlineafter("c:\n", str(c))
    sock.sendlineafter("rot:\n", str(r))

    reply = sock.recvline()
    try:
        return int(reply)
    except ValueError:
        raise ValueError(
            f"oracle returned a non-integer reply for c={c}, r={r} "
            f"(query budget exhausted?): {reply!r}"
        ) from None

# Recover n: e*d ≡ 1 (mod phi) with phi even forces d to be odd,
# so (-1)**d ≡ -1 ≡ n - 1 (mod n).
n = oracle(-1, 0) + 1

# Recover the digit count of d. Rotating by exactly d_len digits is the
# identity, so 2**rot(d, guess) equals 2**d once guess reaches d_len.
# Count down from the digit count of n; guess 0 always matches, so the loop ends.
t = oracle(2, 0)
for d_len in range(len(str(n)), -1, -1):
    if oracle(2, d_len) == t:
        break

# Recover d one decimal digit at a time.
# Use a ciphertext whose plaintext is known (m = n // 2) so the result can be verified.
m = n // 2
c = pow(m, 65537, n)
# oracles[i] = c**rot(d, i). Rotating by d_len wraps back to d itself.
oracles = [oracle(c, i) for i in range(d_len)]
oracles.append(oracles[0])

# If x is the leading digit of rot(d, i), then
#   rot(d, i + 1) = 10 * rot(d, i) - 10**d_len * x + x,
# so oracles[i + 1] = oracles[i]**10 * c**(x - 10**d_len * x)  (mod n).
# The second factor depends only on x. Precompute it once for x = 0..9
# instead of redoing the 10**d_len-sized exponentiation and inverse per digit.
digit_factors = [
    pow(c, x, n) * inverse(pow(c, 10**d_len * x, n), n) % n
    for x in range(10)
]

digits = 0
for i, (cur, nxt) in enumerate(zip(oracles, oracles[1:])):
    base = pow(cur, 10, n)
    for x, factor in enumerate(digit_factors):
        if base * factor % n == nxt:
            digits = 10 * digits + x
            break
    else:
        raise ValueError(f"not found at {i}")

# Explicit check instead of `assert`, which is stripped under `python -O`.
if pow(c, digits, n) != m:
    raise ValueError("recovered d does not decrypt the known test ciphertext")

# Factor n with the recovered (65537, d) pair, then derive the private exponent
# for the large public exponent the flag was actually encrypted with.
p, q = factorize(n, 65537, digits)
e = 3331646268016923629
d = pow(e, -1, (p - 1) * (q - 1))

m = pow(target, d, n)
# Size the output to the plaintext rather than a fixed 100 bytes. A fixed width
# NUL-pads short flags and raises OverflowError for plaintexts over 100 bytes
# (n itself is about 128 bytes).
print(m.to_bytes((m.bit_length() + 7) // 8, "big"))