0CTF Finals 2021 | ezrsa

#0ctf2021finals

from Crypto.Util.number import *
from secret import flag
from os import urandom

def gen(n_size, m_size):
    alpha = 0.5
    delta = 0.03
    d_size = int(delta * n_size)
    k_size = int((alpha + delta - 0.5) * n_size)
    c_size = int(n_size * (1 - alpha - 2 * delta))
    while True:
        while True:
            d_p = getRandomNBitInteger(d_size)
            d_q = getRandomNBitInteger(d_size)
            k = getRandomNBitInteger(k_size)
            l = getRandomNBitInteger(k_size)
            if GCD(k ,l) == 1 and GCD(d_p, k) == 1 and GCD(d_q, l) == 1:
                break
        e = inverse(d_p, k) * inverse(l, k) * l + inverse(d_q, l) * inverse(k, l) * k
        c = getRandomNBitInteger(c_size)
        e += c * k * l
        assert e * d_q % l == 1
        assert e * d_p % k == 1
        p = (e * d_p - 1) // k + 1
        q = (e * d_q - 1) // l + 1
        if isPrime(p) and isPrime(q):
            magic = 1337 * k ** 4 + 7331 * l ** 3 + 73331 * k ** 2 + 13337 * l ** 2 + 7 * k * l + 2 * k + l
            mask = 2 ** m_size - 1
            return (p * q, e), (d_p, d_q, p, q), (magic, d_p & mask, d_q & mask) 

def encrypt(m, pk):
    n, e = pk
    return pow(m, e, n)

n_size = 2000
m_size = 10
pk, sk, hint = gen(n_size, m_size)
flag = urandom(n_size // 8 - len(flag) - 1) + flag
enc = encrypt(int(flag.hex(), 16), pk)

print(pk)
print(hint)
print(enc)
'''
(13144833961692953638155744717380612667335058302310815242506755676885208234342620331186804951145894484501542968789132832800279633590988848298405521677820600481054741175400784558190943019903268095468121342412114428860754522164657102624139527993254089574309927288457799155130004731846999722554981630609692264462023821778810225493633789543259034893395115658330417361250466876018981150507377427664192443342394808337473089411393262018525828475108149889915075872592673448211565529063972264324533136645650169687118301014325354524932405270872098633633071371124551496573869700120350489760340226474892703585296623, 4976865541630914024304930292600669330017247151290783019063407119314069119952298933566289617702551408322779629557316539138884407655160925920670189379289389411163083468782698396121446186733546486790309424372952321446384824084362527492399667929050403530173432700957192011119967010196844119305465574740437)
(154118536863381755324327990994045278493514334577571515646858907141541837890, 431, 217)
12075538182684677737023332074837542797880423774993595442794806087281173669267997104408555839686283996516133283992342507757326913240132429242004071236464149863112788729225204797295863969020348408992315952963166814392745345811848977394200562308125908479180595553832800151118160338048296786712765863667672764499042391263351628529676289293121487926074423104988380291130127694041802572569416584214743544288441507782008422389394379332477148914009173609753877263990429988651290402630935296993764147874437465394433756515223371180032964253037946818633821940103044535390973722964105390263537722948112571112911062
'''

RSA-CRTインスタンス ed_p = 1 + k(p-1), ed_q = 1 + l(q-1)

 k, lからなるmagic と、 d_p, d_qの下位bitが与えられている

RSA-CRTのパラメータ同士の関係を考えると次の式が成り立つから \mod eで2式立って解けそうに見える

  •  magic = f(k, l)

  •  (k-1)*(l-1) \equiv klN \mod e

手で解くのは大変なのでresultantで変数を消去してunivariate coppersmith methodで求める

 k, lが求まれば、これを元に d_p, d_qunivariate coppersmith methodで計算して勝ち

load("./defund.sage")

n, e = (13144833961692953638155744717380612667335058302310815242506755676885208234342620331186804951145894484501542968789132832800279633590988848298405521677820600481054741175400784558190943019903268095468121342412114428860754522164657102624139527993254089574309927288457799155130004731846999722554981630609692264462023821778810225493633789543259034893395115658330417361250466876018981150507377427664192443342394808337473089411393262018525828475108149889915075872592673448211565529063972264324533136645650169687118301014325354524932405270872098633633071371124551496573869700120350489760340226474892703585296623, 4976865541630914024304930292600669330017247151290783019063407119314069119952298933566289617702551408322779629557316539138884407655160925920670189379289389411163083468782698396121446186733546486790309424372952321446384824084362527492399667929050403530173432700957192011119967010196844119305465574740437)
magic, mdp, mdq = (154118536863381755324327990994045278493514334577571515646858907141541837890, 431, 217)
c = 12075538182684677737023332074837542797880423774993595442794806087281173669267997104408555839686283996516133283992342507757326913240132429242004071236464149863112788729225204797295863969020348408992315952963166814392745345811848977394200562308125908479180595553832800151118160338048296786712765863667672764499042391263351628529676289293121487926074423104988380291130127694041802572569416584214743544288441507782008422389394379332477148914009173609753877263990429988651290402630935296993764147874437465394433756515223371180032964253037946818633821940103044535390973722964105390263537722948112571112911062


Zn = Zmod(e)
PR.<k,l> = PolynomialRing(Zn)

f1 = 1337 * k ** 4 + 7331 * l ** 3 + 73331 * k ** 2 + 13337 * l ** 2 + 7 * k * l + 2 * k + l - magic
f2 = (k-1)*(l-1) - k*l*n


PRz.<kz, lz> = PolynomialRing(Zn)
q1 = f1.change_ring(PRz)
q2 = f2.change_ring(PRz)

print("resultant...")
PRn.<ln> = PolynomialRing(Zn)
f = q1.resultant(q2)
f = f.univariate_polynomial().change_ring(PRn).subs(l=ln)
f = f.monic()

print("small roots...")
lvalue = int(f.small_roots(X=2^60)[0])
kvalue = int(f2.subs(l=lvalue).univariate_polynomial().monic().small_roots(X=2^60)[0])

print(lvalue, kvalue)

Pdp.<dp> = PolynomialRing(Zmod(kvalue))
fdp = e*(dp*2^10 + mdp) - 1

Pdq.<dq> = PolynomialRing(Zmod(lvalue))
fdq = e*(dq*2^10 + mdq) - 1

print("small roots...")
dp = int(fdp.monic().small_roots(X=2^50)[0]) * 2^10 + mdp
dq = int(fdq.monic().small_roots(X=2^50)[0]) * 2^10 + mdq
print(dp, dq)

p = (e*dp - 1) // kvalue + 1
q = n // p
assert n % p == 0
assert n % q == 0

phi = (p-1)*(q-1)
d = inverse_mod(e, phi)
m = pow(c, d, n)

print(m)
print(bytes.fromhex(hex(m)[2:]))