redpwn CTF 2021 | retrosign

#redpwnctf2021

TokyoWesterns CTF 6th 2020 | circularと全くおなじ問題設定。つまりOng-Schnorr-Shamir Digital Signature Schemeで、おなじsolverが使える

#!/usr/local/bin/python

from Crypto.Util.number import getPrime, bytes_to_long
from Crypto.Hash import SHA256
from binascii import unhexlify
from secrets import randbelow

# Load the flag once, when the module is first executed; leading/trailing
# whitespace (such as a final newline) is stripped from the file contents.
with open('flag.txt','r') as f:
    flag = f.read().strip()

def sha256(val):
    """Return the raw SHA-256 digest (bytes) of the byte string ``val``."""
    return SHA256.new(val).digest()

def execute(cmd):
    """Run an already-authorized command string and print its result.

    ``"sice_deets"`` prints the flag, ``"bad_signature"`` prints the
    intrusion warning, and anything else is reported as unknown.
    """
    if cmd == "sice_deets":
        print(flag)
        return
    if cmd == "bad_signature":
        reply = "INTRUSION DETECTED!"
    else:
        reply = "Command unknown."
    print(reply)

def authorize_command(cmd, sig):
    """Check the signature ``sig`` for ``cmd`` and execute the command if valid.

    ``sig`` must be exactly 256 bytes: two 128-byte big-endian integers
    ``a`` and ``b``.  The signature is accepted when
    ``a^2 + k*b^2 (mod n)`` equals the SHA-256 digest of ``cmd`` read as an
    integer; otherwise the ``"bad_signature"`` command is executed instead.

    ``cmd`` is a byte string; ``k`` and ``n`` are the module-level parameters.
    """
    assert len(sig) == 128*2
    a, b = (bytes_to_long(half) for half in (sig[:128], sig[128:]))
    signed_value = (a**2 + k*b**2) % n
    if signed_value != bytes_to_long(sha256(cmd)):
        execute("bad_signature")
        return
    execute(cmd.decode())

# Challenge parameters, generated afresh each time the script starts.
# The modulus n is the product of two 512-bit primes; only n and k are ever
# printed to the client (see interact), p and q are not.
p = getPrime(512)
q = getPrime(512)
n = p * q
# Random coefficient in [0, n) used in the check a^2 + k*b^2 == H(cmd) (mod n).
k = randbelow(n)
def interact():
    """Serve one session: show the banner and parameters, then verify one command.

    Prints the public values ``n`` and ``k``, reads a command line (stripped,
    lower-cased, encoded to bytes) and a hex-encoded signature, and passes
    both to ``authorize_command``.
    """
    separator = "==============================================================================="
    banner = (
        separator,
        "This mainframe is protected with state-of-the-art intrusion detection software.",
        "All commands are passed through a signature-based filter.",
        separator,
        "The following configuration is in place:",
    )
    for line in banner:
        print(line)
    print(f"n = {n};\nk = {k};")
    print("Server configured.")

    raw_cmd = input(">>> ")
    cmd = raw_cmd.strip().lower().encode()
    sig_hex = input("$$$ ")
    sig = unhexlify(sig_hex)

    authorize_command(cmd, sig)
    print("Connection closed.")

if __name__ == "__main__":
    # Hide every ordinary failure (bad hex, wrong signature length, EOF, ...)
    # behind one generic message so no traceback reaches the client.
    # `except Exception` rather than a bare `except` lets KeyboardInterrupt
    # and SystemExit propagate, so the process can still be stopped normally.
    try:
        interact()
    except Exception:
        print("An error has occurred.")

つまり n, k が与えられるので、$a^2 + k b^2 \equiv H(\mathrm{cmd}) \pmod{n}$ を満たす a, b を求めよ、という問題(方程式、Diophantine Equation)

↑にも書いたけどTokyoWesterns CTF 6th 2020 | circularのパクりで同じsolverが走る

s3v3ru5のsolverは動かなかったけどrkm0959のsolverは動いた

from Crypto.Util.number import long_to_bytes, bytes_to_long, GCD, inverse, isPrime
import random

def kthp(n, k):
    """Return the integer k-th root of ``n``, rounded down.

    Computes the largest integer ``r`` with ``r ** k <= n`` using exact
    integer arithmetic, so it is safe for arbitrarily large ``n``.

    Args:
        n: Non-negative integer whose root is taken.
        k: Positive integer root degree.

    Returns:
        ``floor(n ** (1 / k))`` as an int.

    Raises:
        ValueError: If ``n`` is negative or ``k`` is less than 1.
    """
    if k < 1:
        raise ValueError("root degree k must be a positive integer")
    if n < 0:
        raise ValueError("n must be non-negative")
    if n == 0:
        return 0
    # Double the upper bound until hi ** k >= n, then binary-search [1, hi].
    hi = 2
    while hi ** k < n:
        hi <<= 1
    lo = 1
    best = 1  # 1 ** k <= n always holds for n >= 1
    while lo <= hi:
        mid = (lo + hi) // 2
        if mid ** k <= n:
            best = mid
            lo = mid + 1
        else:
            hi = mid - 1
    return best

def tonelli(x, tar):
    """Return a square root of ``tar`` modulo the prime ``x`` (Tonelli-Shanks).

    Args:
        x: Prime modulus. The result is unspecified if ``x`` is not prime.
        tar: Integer whose square root is wanted; it is reduced modulo ``x``.

    Returns:
        An integer ``r`` in ``[0, x)`` with ``r * r % x == tar % x``, or
        ``-1`` when ``tar`` is a quadratic non-residue modulo ``x``.
    """
    tar %= x
    # Zero is its own square root.  This has to be tested before Euler's
    # criterion, which evaluates to 0 (not 1) for a zero residue.
    if tar == 0:
        return 0
    # Modulo 2 every residue is its own square root, and there is no
    # non-residue for the search below to find.
    if x == 2:
        return tar
    # Euler's criterion: tar is a square iff tar^((x-1)/2) == 1 (mod x).
    if pow(tar, (x - 1) // 2, x) != 1:
        return -1
    # Fast path for x == 3 (mod 4).
    if x % 4 == 3:
        return pow(tar, (x + 1) // 4, x)

    # Write x - 1 = Q * 2^S with Q odd.
    Q, S = x - 1, 0
    while Q % 2 == 0:
        Q //= 2
        S += 1

    # Find the smallest quadratic non-residue z >= 2.
    z = 2
    while pow(z, (x - 1) // 2, x) == 1:
        z += 1

    M, c, t, R = S, pow(z, Q, x), pow(tar, Q, x), pow(tar, (Q + 1) // 2, x)
    while True:
        if t == 0:
            return 0
        if t == 1:
            return R % x
        # Least sx >= 1 with t^(2^sx) == 1 (mod x).
        sx, tem = 0, t
        while True:
            sx += 1
            tem = (tem * tem) % x
            if tem == 1:
                break
        b = pow(c, 1 << (M - sx - 1), x)
        M, c = sx, (b * b) % x
        t, R = (t * c) % x, (R * b) % x

def comb(x1, y1, x2, y2, k, n):
    """Compose two representations by the form ``x^2 + k*y^2`` modulo ``n``.

    If ``x1^2 + k*y1^2 = A`` and ``x2^2 + k*y2^2 = B``, the returned pair
    ``(X, Y)`` satisfies ``X^2 + k*Y^2 = A*B (mod n)``.
    """
    first = x1 * x2 + k * y1 * y2
    second = x1 * y2 - x2 * y1
    return first % n, second % n

def solve(k, m, n):  # solve x^2 + ky^2 == m mod n
    """Find ``(x, y)`` with ``x^2 + k*y^2 == m (mod n)`` without factoring ``n``.

    The function is randomised and recursive: it reduces the problem to a
    similar one with different coefficients and calls itself.  ``k`` may be
    negative (the recursive call passes a negated value).

    Args:
        k: Coefficient of ``y^2``.
        m: Target value.
        n: Modulus.

    Returns:
        A pair ``(x, y)`` of integers satisfying the congruence.

    Side effects:
        Prints a debug line on every call, and calls ``exit()`` if it
        happens to find a prime sharing a factor with ``n``.
    """
    print("solve", k, m, n)
    # Trivial case: m is a perfect square over the integers -> y = 0.
    fu = kthp(m, 2)
    if fu * fu == m:
        return (fu, 0)
    if k < 0:
        se = kthp(-k, 2)
        if se * se == -k:
            # -k = se^2, so x^2 + k*y^2 = (x - se*y) * (x + se*y).
            # Choosing x - se*y = 1 and x + se*y = m gives the values below.
            retx = (m+1) * inverse(2, n) % n
            rety = (m-1) * inverse(2 * se, n) % n
            return retx, rety
    if m == 1:
        return (1, 0)
    if m == k % n:
        return (0, 1)
    # Randomise the target: multiply m by the value u^2 + k*v^2 until the
    # product m_0 (mod n) is prime and -k has a square root x_0 modulo m_0.
    while True:
        u = random.getrandbits(1024)
        v = random.getrandbits(1024)
        m_0 = (m * (u * u + k * v * v)) % n
        if isPrime(m_0):
            if GCD(m_0, n) != 1:
                # m_0 is prime and shares a factor with n; the script just
                # reports it and terminates the process.
                print("LOL", m_0)
                exit()
            x_0 = tonelli(m_0, (-k) % m_0)
            # tonelli returns -1 for a non-residue; this check rejects that.
            if (x_0 * x_0 + k) % m_0 == 0:
                break
    # Descent: build sequences with xs[i]^2 + k == ms[i] * ms[i+1], where each
    # new x is the previous one reduced modulo the new m (taking the smaller
    # of r and m - r).
    ms = [m_0]
    xs = [x_0]
    sz = 1
    while True:
        new_m = (xs[sz-1] * xs[sz-1] + k) // ms[sz-1]
        ms.append(new_m)
        # Stopping rules (one for each sign of k).
        if k > 0 and xs[sz-1] <= ms[sz] <= ms[sz-1]:
            sz = sz + 1
            break
        if k < 0 and abs(ms[sz]) <= kthp(abs(k), 2):
            sz = sz + 1
            break
        xs.append(min(xs[sz-1] % ms[sz], ms[sz] - (xs[sz-1] % ms[sz])))
        sz = sz + 1
    assert sz == len(ms)
    assert sz - 1 == len(xs)
    # Compose the representations (xs[i], 1) of ms[i]*ms[i+1] with comb().
    # After dividing both components by dv = ms[1] * ... * ms[sz-1], the pair
    # (uu, vv) represents ms[0] / ms[sz-1] modulo n.
    uu, vv = xs[0], 1
    dv = 1
    for i in range(1, sz-1):
        assert (xs[i] ** 2 + k) % n == (ms[i] * ms[i+1]) % n
        uu, vv = comb(uu, vv, xs[i], 1, k, n)
        dv = (dv * ms[i]) % n
    dv = (dv * ms[sz-1]) % n
    uu = (uu * inverse(dv, n)) % n
    vv = (vv * inverse(dv, n)) % n
    # Recursive step: X^2 - ms[sz-1]*Y^2 == -k (mod n) rearranges to
    # (X/Y)^2 + k*(1/Y)^2 == ms[sz-1], i.e. a representation of the last m.
    X, Y = solve(-ms[sz-1], (-k) % n, n)
    soly = inverse(Y, n)
    solx = (X * soly) % n
    # Combining with (uu, vv) yields a representation of m_0 ...
    finx, finy = comb(solx, soly, uu, vv, k, n)
    # ... and dividing out the random factor u^2 + k*v^2 recovers one for m.
    godx = ((finx * u - k * finy * v) * inverse(u * u + k * v * v, n)) % n
    gody = ((finx * v + finy * u) * inverse(u * u + k * v * v, n)) % n
    return godx, gody

from ptrlib import Socket, Process
from hashlib import sha256


# Alternative target for testing against a local instance:
# sock = Socket("nc localhost 9999")

# Connect to the remote service.  It first sends a line containing "sh -s ";
# the text after that marker is passed to the external `redpwnpow` program and
# the program's output is sent back (presumably a proof-of-work challenge and
# its solution -- confirm against the service).
sock = Socket("nc mc.ax 31079")
pow_input = sock.recvlineafter("sh -s ").decode()
pow_solved = Process(["redpwnpow", pow_input]).recv().decode().strip()
sock.sendline(pow_solved)


# Parse the public parameters.  The server prints "n = <n>;" and "k = <k>;",
# so the last character is dropped to remove the ';' (assumes recvlineafter
# already strips the newline -- TODO confirm).
n = int(sock.recvlineafter("n = ")[:-1])
k = int(sock.recvlineafter("k = ")[:-1])

print(n, k)

# The command to forge a signature for.
cmd = b"sice_deets"

# Target value: SHA-256 of the command interpreted as a big-endian integer.
h = bytes_to_long(sha256(cmd).digest())

# Find x, y with x^2 + k*y^2 == h (mod n).
x, y = solve(k, h, n)

print("[*] x = ", x)
print("[*] y = ", y)

print(cmd)
print(sock.recvuntil(">>").decode())
sock.sendline(cmd.decode())
# The signature is sent as hex: x and y each encoded as 128 big-endian bytes.
sock.sendlineafter("$$$ ", (long_to_bytes(x, 128) + long_to_bytes(y, 128)).hex())

# Hand the connection over to the terminal to read the server's reply.
sock.interactive()