RTACTF | Proth RSA

#rtactf

from Crypto.Util.number import getRandomInteger, getPrime, isPrime
import os

def getProthPrime(n=512):
    """Return a random prime ``p = (2*k + 1) * 2**n + 1`` and the ``k`` used.

    See https://en.wikipedia.org/wiki/Proth_prime for the Proth form
    ``K * 2**n + 1``. Here the odd multiplier is ``K = 2*k + 1``, with ``k``
    drawn by ``getRandomInteger(n)``. That is at most ``n`` bits per
    PyCryptodome's documentation, so ``K`` can exceed ``2**n``. The strict
    Proth condition ``K < 2**n`` therefore is not enforced.

    Candidates are drawn repeatedly until ``isPrime`` accepts one.

    Returns:
        A tuple ``(p, k)``.
    """
    power_of_two = 1 << n  # the 2**n factor shared by every candidate
    while True:
        k = getRandomInteger(n)
        candidate = power_of_two * (2 * k + 1) + 1
        if not isPrime(candidate):
            continue
        return candidate, k

if __name__ == '__main__':
    # Plaintext: the FLAG environment variable, or a placeholder when unset,
    # read as a big-endian integer.
    flag_text = os.getenv("FLAG", "FAKE{sample_flag}")
    m = int.from_bytes(flag_text.encode(), 'big')

    # Key generation: two primes of the form (2*k + 1) * 2**512 + 1.
    # p is generated first, then q.
    p, k1 = getProthPrime()
    q, k2 = getProthPrime()
    n, e = p * q, 65537
    # Leaked hint: the product of the two k values, reduced mod n.
    # Publishing it is the deliberate weakness of this challenge.
    s = k1 * k2 % n

    # Textbook RSA encryption of the flag.
    c = pow(m, e, n)

    # Publish the public key, the hint and the ciphertext (but not p or q).
    print(f"n = 0x{n:x}")
    print(f"e = 0x{e:x}")
    print(f"s = 0x{s:x}")
    print(f"c = 0x{c:x}")

#RSA

# Solver (SageMath). The generator uses k1, k2 < 2**512, which gives
# k1*k2 < 2**1024 < n. So the leaked s equals k1*k2 exactly (the "% n" never
# reduces it). Together with n = p*q, this gives two polynomial equations in
# the unknowns u = k1 and v = k2, which Sage solves exactly.
with open("output.txt") as f:
    # output.txt lines are "n = 0x..", "e = 0x..", "s = 0x..", "c = 0x.." in order.
    n = int(f.readline().strip().split(" = ")[1], 16)
    e = int(f.readline().strip().split(" = ")[1], 16)
    s = int(f.readline().strip().split(" = ")[1], 16)
    c = int(f.readline().strip().split(" = ")[1], 16)

TWO_N = 1 << 512  # the 2**n factor; n = 512 matches getProthPrime's default

PR.<u,v> = QQ[]
p_poly = (2*u + 1) * TWO_N + 1
q_poly = (2*v + 1) * TWO_N + 1

I = Ideal([
    p_poly*q_poly - n,
    u*v - s,
])
# Only integer points are meaningful. The system is symmetric in (u, v), so
# either returned point yields the same {p, q}.
solutions = I.variety(ring=ZZ)
if not solutions:
    raise ValueError("no integer solution for (k1, k2); check output.txt")
k1, k2 = solutions[0][u], solutions[0][v]

p = (2*k1 + 1) * TWO_N + 1
q = (2*k2 + 1) * TWO_N + 1
if p * q != n:
    raise ValueError("recovered p*q does not match n")

d = int(pow(e, -1, (p-1)*(q-1)))
m = int(pow(c, d, n))

# int.to_bytes handles plaintexts whose leading byte is < 0x10 (odd-length
# hex), where bytes.fromhex(hex(m)[2:]) would raise ValueError.
print(m.to_bytes((m.bit_length() + 7) // 8, 'big'))