corCTF 2021 | lcg_k

#corctf2021

from Crypto.Util.number import bytes_to_long, inverse
from hashlib import sha256
from secrets import randbelow
from private import flag
from fastecdsa.curve import P256

# Base point taken from fastecdsa's P256 curve object; every public key and
# nonce point below is a multiple of it.
G = P256.G
# Modulus used for all signature arithmetic and for the LCG below
# (presumably the order of G -- confirm against the fastecdsa documentation).
N = P256.q

class RNG:
    """Linear congruential generator with update rule x -> (A*x + b) mod p."""

    def __init__(self, seed, A, b, p):
        # seed is the starting state; A, b, p are multiplier, increment, modulus.
        self.seed, self.A, self.b, self.p = seed, A, b, p

    def gen(self):
        """Yield successive LCG states forever.

        The first value produced is already one step past ``seed``; the seed
        itself is never emitted.
        """
        state = self.seed
        while True:
            state = (self.A * state + self.b) % self.p
            yield state

def H(m):
    """Return the SHA-256 digest of the bytes ``m`` converted to an integer."""
    return bytes_to_long(sha256(m).digest())

def sign(m):
    """Sign the bytes ``m`` with the private key ``d`` and return ``(r, s)``.

    The nonce is the next output of the module-level generator ``gen``, so
    each call advances the LCG by one step.
    """
    nonce = next(gen)
    nonce_point = nonce * G
    r = int(nonce_point.x) % N
    numerator = H(m) + d * r
    s = numerator * inverse(nonce, N) % N
    return r, s

def verify(r, s, m):
    """Check the ECDSA signature ``(r, s)`` on the bytes ``m`` against ``pub``.

    Returns True when the signature is valid and False otherwise, including
    when ``r`` or ``s`` lies outside ``[1, N-1]``.
    """
    # ECDSA verification requires 1 <= r, s <= N-1.  Values outside that range
    # must be rejected before any arithmetic: s == 0 (mod N) has no modular
    # inverse, and r == 0 must never be accepted as a valid x-coordinate.
    if not (0 < r < N and 0 < s < N):
        return False
    s_inv = inverse(s, N)  # computed once and reused for both scalars
    v1 = H(m) * s_inv % N
    v2 = r * s_inv % N
    V = v1*G + v2*pub
    return int(V.x) % N == r

# Secret LCG parameters; the LCG modulus is N, the same modulus the signature
# arithmetic uses.
seed, A, b = randbelow(N), randbelow(N), randbelow(N)
lcg = RNG(seed, A, b, N)
# Single shared nonce stream consumed by sign().
gen = lcg.gen()
# Private key and the matching public point.
d = randbelow(N)
pub = d*G
# The one message the server refuses to sign but asks the client to sign.
mymsg = b'i wish to know the ways of the world'

print('public key:', pub)
# Hashes of messages signed so far; used to reject repeats.
signed_hashes = []

# Signing oracle: exactly four requests.
for _ in range(4):
    m = bytes.fromhex(input('give me something to sign, in hex>'))
    h = H(m)
    # Refuse the target message itself and any message whose hash was already
    # signed in this session.
    if m == mymsg or h in signed_hashes:
        print("i won't sign that.")
        exit()
    signed_hashes.append(h)
    r, s = sign(m)
    print('r:', str(r))
    print('s:', str(s))
# Challenge phase: the client must supply a valid signature on mymsg.
print('now, i want you to sign my message.')
r = int(input('give me r>'))
s = int(input('give me s>'))
if verify(r, s, mymsg):
    print("nice. i'll give you the flag.")
    print(flag)
else:
    print("no, that's wrong.")

#LCG#ECDSA

フラグを得るためには mymsgに署名する必要があるが、4回与えられる署名のチャンスではメッセージそのものに署名することはできない。また同一のハッシュを生成するような(すなわち同一の)平文にも署名してもらえない。

4つの署名とLCGの漸化式から連立方程式を立てて解くのだろう

適当に  z_1, \dots, z_4に署名してもらって、それを (r_1, s_1), \dots, (r_4, s_4)と置き、またLCGのパラメータを A, b, p、出力される値を k_1, \dots, k_4とする

当然  s_i = (z_i + dr_i)k_i^{-1} \mod N が成り立つ。 k_i = Ak_{i-1} + b \mod p (この問題では p = N で、k_0 はLCGのseed)なので、r_1, s_1 についての式を A, b, k_0 を用いて表してみると

 s_1(Ak_0 + b) = z_1 + dr_1 \mod N

 s_2(Ak_1 + b) = s_2(A(Ak_0 + b) + b) = z_2 + dr_2 \mod N

という感じで4つ式が立つ。対して未知変数は A, b, k_0, dの4つなので解けそうな感じがしてきた

groebner basisを使っても解けるしresultantを使っても解けると思うが、今回は前者を用いた

from ptrlib import Socket
from hashlib import sha256

def H(m):
    """Hash the bytes ``m`` with SHA-256 and return the digest as an integer."""
    return int.from_bytes(sha256(m).digest(), "big")

# Curve parameters rebuilt by hand in Sage (intended to match the P256 curve
# used by the challenge server -- verify against fastecdsa if in doubt).
# Field prime.
p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff
K = GF(p)
# Weierstrass coefficients of y^2 = x^3 + a*x + b over K.
# NOTE: the name `b` is rebound later by the polynomial-ring declaration.
a = K(0xffffffff00000001000000000000000000000000fffffffffffffffffffffffc)
b = K(0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b)
E = EllipticCurve(K, (a, b))
# Base point (x, y).
G = E(0x6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296, 0x4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5)
# Supply the group order explicitly so that E.order() returns this value
# instead of being computed (the `* 0x1` factor has no effect on the value).
E.set_order(0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551 * 0x1)
N = E.order()

def sign(m, k):
    """Sign the bytes ``m`` with nonce ``k`` and return ``(r, s)`` as ints.

    The private key is read from the module-level name ``d`` at call time.
    """
    x_coord = (k*G)[0]
    r = int(x_coord) % N
    numerator = H(m) + d*r
    s = (numerator * inverse_mod(k, N)) % N
    return int(r), int(s)

def verify(P, r, s, m):
    """Return True if ``(r, s)`` verifies for the bytes ``m`` under point ``P``."""
    digest = H(m)
    s_inv = inverse_mod(s, N)
    u1 = digest * s_inv % N
    u2 = r * s_inv % N
    candidate = u1*G + u2*P
    return int(candidate[0]) % N == r


# Connect and read the public-key coordinates, which the server prints in hex
# after the "X: " and "Y: " labels.
sock = Socket("localhost", 9999)
X = int(sock.recvlineafter("X: "), 16)
Y = int(sock.recvlineafter("Y: "), 16)

# Message hashes and signature components, index-aligned across the lists.
zs, rs, ss = [], [], []

# Request signatures on four distinct one-byte messages 0x00 .. 0x03.
for msg in (bytes([i]) for i in range(4)):
    sock.sendlineafter("hex>", msg.hex())

    r = int(sock.recvlineafter("r: "))
    s = int(sock.recvlineafter("s: "))

    zs += [H(msg)]
    rs += [r]
    ss += [s]

# Sage preparser syntax: build a polynomial ring over GF(E.order()) and bind
# its four generators to A, b, k_0, d -- the unknown LCG multiplier, increment,
# seed and the private key.  This rebinds any earlier `b`.
PR.<A, b, k_0, d> = PolynomialRing(GF(E.order()))

def nthk(n):
    """Return the n-th LCG output, i.e. k -> A*k + b applied ``n`` times to k_0.

    ``A``, ``b`` and ``k_0`` are looked up at call time, so the result is a
    polynomial while they are ring generators and a plain number once they
    have been rebound to concrete values.  No modular reduction is done here.
    """
    term = k_0
    for _step in range(n):
        term = A*term + b
    return term


# One relation per signature: s_i * k_i - (z_i + d * r_i) = 0, with the i-th
# nonce k_i written as a polynomial in A, b, k_0 via nthk(i).
I = Ideal([
    ss[0]*nthk(1) - (zs[0] + d*rs[0]),
    ss[1]*nthk(2) - (zs[1] + d*rs[1]),
    ss[2]*nthk(3) - (zs[2] + d*rs[2]),
    ss[3]*nthk(4) - (zs[3] + d*rs[3]),
])
# print(I)  # debug aid
B = I.groebner_basis()
# NOTE(review): this assumes B[0] is univariate in d and that B[1], B[2], B[3]
# determine A, b, k_0 once d is fixed -- that depends on the term order and on
# the shape of the computed basis; confirm if the script fails here.
roots = B[0].univariate_polynomial().roots()
for root in roots:
    # Each entry unpacks into the root and a second value that is not used.
    d, _ = root
    P = int(d)*G
    # Keep only the candidate private key that reproduces the server's
    # public point.
    if P.xy() != (X, Y):
        print("[-] bad luck")
        continue
    
    # Substitute the recovered d and take the first root of each remaining
    # basis element.  From here on A, b and k_0 are plain ints, no longer ring
    # generators.
    A = int(B[1].subs(d=d).univariate_polynomial().roots()[0][0])
    b = int(B[2].subs(d=d).univariate_polynomial().roots()[0][0])
    k_0 = int(B[3].subs(d=d).univariate_polynomial().roots()[0][0])
    print(A)
    print(b)
    print(k_0)

    mymsg = b'i wish to know the ways of the world'

    # nthk now evaluates numerically; reduce the fifth LCG value modulo the
    # group order and use it as the signing nonce.
    k = nthk(5) % E.order()
    r, s = sign(mymsg, k)
    # Local sanity check before sending the forged signature.
    print(verify(P, r, s, mymsg))

    sock.sendlineafter("r>", str(r))
    sock.sendlineafter("s>", str(s))

    sock.interactive()

resultantを使うとこういう感じになるらしい

from Crypto.Util.number import bytes_to_long, inverse
from hashlib import sha256
from secrets import randbelow
from pwn import remote

def H(m):
    """Return the SHA-256 digest of the bytes ``m`` converted to an integer."""
    return bytes_to_long(sha256(m).digest())

def _strip_label(line, label):
    """Return ``line`` without a leading ``label`` and surrounding whitespace."""
    text = line.strip()
    # Remove the label as a real prefix.  str.strip(label) would instead treat
    # the label as a set of characters to strip from both ends.
    if text.startswith(label):
        text = text[len(label):]
    return text.strip()


def get_sig(m):
    """Ask the server to sign the bytes ``m`` and return ``(H(m), r, s)``.

    Waits for the ``hex>`` prompt on the module-level tube ``io``, sends ``m``
    hex-encoded, then reads the two reply lines labelled ``r:`` and ``s:``.
    Raises ValueError if either reply is not a decimal integer.
    """
    io.recvuntil(b"hex>")
    # The tube carries bytes, so encode the hex string explicitly.
    io.sendline(m.hex().encode())
    r = _strip_label(io.recvline().decode(), "r:")
    s = _strip_label(io.recvline().decode(), "s:")
    print(r, s)
    r, s = map(int, (r, s))
    return H(m), r, s

host, port = "crypto.be.ax", 6002
io = remote(host, port)

# The public key arrives as two hex values labelled "X: 0x" and "Y: 0x".
io.recvuntil(b"X: 0x")
X = int(io.recvline().strip().decode(), 16)
print(X)
# Consume the "Y: 0x" label the same way as for X.  The previous
# .strip("Y: 0x") treated its argument as a character set, so it also removed
# trailing '0' hex digits and silently corrupted Y whenever its hex form
# ended in 0.
io.recvuntil(b"Y: 0x")
Y = int(io.recvline().strip().decode(), 16)

# Four signatures on distinct messages; m1..m4 hold the message hashes.
m1, r1, s1 = get_sig(b"a")
m2, r2, s2 = get_sig(b"b")
m3, r3, s3 = get_sig(b"c")
m4, r4, s4 = get_sig(b"d")

# Modulus for all arithmetic below (presumably the P-256 group order, i.e. the
# server's N -- confirm).
n = 115792089210356248762697446949407573529996955224135760342422259061068512044369

# Sage preparser syntax: bivariate polynomial ring in the unknown LCG
# multiplier `a` and increment `b` over Z/nZ.
P.<a, b> = PolynomialRing(Zmod(n))

# Two polynomials in (a, b) built from the four signatures.  They appear to be
# the signing equations with the nonce and private key already eliminated
# (res1 involves signatures 1-3, res2 signatures 1, 2 and 4); the derivation
# is not shown in this file.
res1 = -a^2*m2*r1^2*s3 + a^2*m1*r1*r2*s3 + a*b*r1*r2*s1*s3 - a*b*r1^2*s2*s3 + a*m3*r1^2*s2 - a*m1*r1*r3*s2 - b*r1*r3*s1*s2 + b*r1*r2*s1*s3 - m3*r1*r2*s1 + m2*r1*r3*s1
res2 = -a^3*m2*r1^2*s4 + a^3*m1*r1*r2*s4 + a^2*b*r1*r2*s1*s4 - a^2*b*r1^2*s2*s4 + a*b*r1*r2*s1*s4 - a*b*r1^2*s2*s4 + a*m4*r1^2*s2 - a*m1*r1*r4*s2 - b*r1*r4*s1*s2 + b*r1*r2*s1*s4 - m4*r1*r2*s1 + m2*r1*r4*s1

# Move both polynomials into a second ring and take their resultant; the
# result is then treated as a univariate polynomial h.
# NOTE(review): the loop that follows treats the roots of h as values of b,
# so the resultant presumably eliminates `a` -- confirm.
PZ.<xz, yz> = PolynomialRing(Zmod(n))
q1 = res1.change_ring(PZ)
q2 = res2.change_ring(PZ)
h = q2.resultant(q1).univariate_polynomial()

# Every root of h is a candidate for b; substituting it into res1 leaves a
# univariate polynomial whose roots are the matching candidates for a.
AB = [
    (int(a_r), int(b_r))
    for (b_r, _) in h.roots()
    for (a_r, _) in res1.subs({b: int(b_r)}).univariate_polynomial().roots()
]

# Curve y^2 = x^3 - 3x + b over a prime field, with all parameters written in
# decimal (intended to be the server's P256 curve -- verify if in doubt).
P256 = EllipticCurve(GF(115792089210356248762697446949407573530086143415290314195533631308867097853951), 
    [-3, 41058363725152142129326129780047268409114441015993725554835256314039467401291])

# Base point, and the server's public key rebuilt from the received X and Y.
# NOTE: this rebinds `P`, which previously named the polynomial ring.
G = P256(48439561293906451759052585252797914202762949526041747995844080717082404635286, 36134250956749795798585127919587881956611106672985015071877198253568414405109)
P = P256(X, Y)

# Test each (a, b) candidate.  Combining the first two signing equations
#   s1*k = m1 + d*r1   and   s2*(a*k + b) = m2 + d*r2
# gives the first nonce k and then the private key d.  Note that the loop
# variables rebind `a` and `b`, which previously named the ring generators.
for (a, b) in AB:
    k = (m2*r1 - m1*r2 - b*r1*s2) * inverse_mod(a*r1*s2 - r2*s1, n) % n
    d = (s1*k - m1)*inverse_mod(r1, n) % n
    # Accept only the key that reproduces the server's public point.
    if d*G == P:
        print("Found private key", d)
        break
else:
    # No candidate matched (or AB was empty).  Abort rather than forge a
    # signature with a wrong or undefined `d`.
    raise SystemExit("no (a, b) candidate reproduces the public key")

# With d known, any fixed nonce produces a valid signature.
k = 193922
mymsg = b'i wish to know the ways of the world'
mi = H(mymsg)

r = int((k*G)[0]) % n
s = (mi + d*r)*inverse_mod(k, n) % n
io.recvuntil(b"r>")
io.sendline(str(r).encode())
io.recvuntil(b"s>")
io.sendline(str(s).encode())
io.interactive()