RaR CTF 2021 | rotoRSA

#rarctf2021

from sympy import poly, symbols
from collections import deque
import Crypto.Random.random as random
from Crypto.Util.number import getPrime, bytes_to_long, long_to_bytes
import sys

def build_poly(coeffs):
    """Return the sympy polynomial ``sum(coeffs[i] * x**i)`` in the variable x.

    ``coeffs[i]`` is the coefficient of ``x**i``. The generator ``x`` is
    passed to ``poly`` explicitly. Without it, sympy cannot infer a generator
    when every non-constant coefficient is zero (the expression collapses to
    a plain integer) and raises ``GeneratorsNeeded``.
    """
    x = symbols('x')
    return poly(sum(coeff * x ** i for i, coeff in enumerate(coeffs)), x)

def encrypt_msg(msg, poly, e, N):
    """Pad ``msg`` with the polynomial ``poly`` and RSA-encrypt the result.

    The padded value ``poly(msg)`` is echoed to stderr as a debug trace.
    Returns the hex encoding of ``long_to_bytes(poly(msg) ** e mod N)``.
    """
    padded = poly(msg)
    print(padded, file=sys.stderr)
    ciphertext = pow(padded, e, N)
    return long_to_bytes(ciphertext).hex()

# RSA key: two 256-bit primes and a small public exponent e = 11.
p = getPrime(256)
q = getPrime(256)
N = p * q
e = 11

# Read the flag inside a context manager so the file handle is closed
# promptly instead of being left to the garbage collector.
with open("/challenge/flag.txt", "rb") as flag_file:
    flag = bytes_to_long(flag_file.read())

# 16 random padding coefficients; the service loop rotates this deque by one
# after every request.
coeffs = deque([random.randint(0, 128) for _ in range(16)])


# Banner shown once per connection; it publishes the RSA modulus N and the
# public exponent e.
welcome_message = f"""
Welcome to RotorSA!
With our state of the art encryption system, you have two options:
1. Encrypt a message
2. Get the encrypted flag
The current public key is
n = {N}
e = {e}
"""

print(welcome_message)

# Service loop. The padding polynomial is rebuilt from the current rotation of
# `coeffs` before each request. The rotation advances by one after each
# request, whatever the choice was (even an unrecognised one).
while True:
    padding = build_poly(coeffs)
    choice = int(input('What is your choice? '))
    if choice == 2:
        encrypted_flag = encrypt_msg(flag, padding, e, N)
        print(f'The encrypted flag is {encrypted_flag}')
    elif choice == 1:
        message = int(input('What is your message? '), 16)
        # Debug trace of the current coefficients and the plaintext (stderr only).
        for traced in (coeffs, message):
            print(traced, file=sys.stderr)
        encrypted = encrypt_msg(message, padding, e, N)
        print(encrypted, file=sys.stderr)
        print(f'The encrypted message is {encrypted}')
    coeffs.rotate(1)

メッセージを係数16個の多項式でpaddingしてからRSA暗号化する問題。係数はリクエストのたびに1つずつローテートされるので、16回暗号化すると一周して元に戻る。

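0を暗号化してもらうと、暗号文は $c_0^e \bmod n$（$c_0$ はその時点の定数項）になる。係数は 0〜128 なので $c_0^{11} \le 2^{77} < n$ となり mod が効かず、暗号文の整数 $e$ 乗根がそのまま定数項になる。これを16回繰り返せば全係数を復元できる。あとはローテーションの位置を合わせ、フラグ $m$ を連続する2つのpadding $f_1, f_2$ で暗号化してもらう。すると $m$ は $f_1(x)^e - c_1$ と $f_2(x)^e - c_2$ の共通根なので、多項式GCD（Franklin–Reiter related message attack）で $x - m$ が得られ、復号できる。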

from sage.all import *
from ptrlib import Socket

# Connect to the challenge service (swap in the commented line to test locally).
sock = Socket("193.57.159.27", 27407)
# sock = Socket("localhost", 9999)

# The banner publishes the public key: read n first, then e.
n, e = (int(sock.recvlineafter(label)) for label in ("n = ", "e = "))

def _leak_constant_term():
    """Encrypt 0 and return the current padding's constant term.

    The ciphertext is c0**e mod n. Coefficients are at most 128, so c0**e is
    below n, and the exact integer e-th root recovers c0.
    """
    sock.sendlineafter("choice? ", "1")
    sock.sendlineafter("message? ", "0")
    ciphertext = int(sock.recvlineafter("message is ").decode(), 16)
    return int(Integer(ciphertext).nth_root(e))


# One query per rotation step; the server rotates its coefficients after each.
coeffs = [_leak_constant_term() for _ in range(16)]

# The server calls coeffs.rotate(1) after every request, so the t-th
# "encrypt 0" query leaked the original coefficient at index (-t mod 16).
# Undo that to get the coefficients in rotation-0 order (coeff1); one more
# right rotation gives the padding used on the following request (coeff2).
coeff1 = coeffs[:1] + coeffs[:0:-1]
coeff2 = coeff1[-1:] + coeff1[:-1]
print(coeff1)
print(coeff2)

# Pause so the recovered coefficients can be inspected before continuing.
input("[X]")


# Sanity check: encrypt a known plaintext twice (rotation 0, then rotation 1)
# and confirm the recovered coefficients reproduce the server's ciphertexts.
def _local_encrypt(cs, m):
    """Return (sum(cs[i] * m**i)) ** e mod n, mirroring the server's padding."""
    return pow(sum(c * m**i for i, c in enumerate(cs)), e, n)


x = 200
for expected_coeffs in (coeff1, coeff2):
    sock.sendlineafter("choice? ", "1")
    sock.sendlineafter("message? ", hex(x))
    server_ct = int(sock.recvlineafter("message is ").decode(), 16)
    # Use x itself (not a hard-coded 200) so the check stays valid if x changes.
    print(_local_encrypt(expected_coeffs, x) == server_ct)


# Advance the server's rotation back to 0 (mod the period) so that the two
# flag encryptions below use coeff1 and then coeff2. Requests so far:
# 16 for coefficient recovery plus 2 for the sanity check.
period = len(coeffs)
requests_made = 16 + 2
for _ in range(-requests_made % period):
    sock.sendlineafter("choice? ", "1")
    sock.sendlineafter("message? ", "200")


def _fetch_encrypted_flag():
    """Ask the server for the encrypted flag and return it as an integer."""
    sock.sendlineafter("choice? ", "2")
    return int(sock.recvlineafter("flag is "), 16)


c1 = _fetch_encrypted_flag()  # padded with coeff1 (rotation 0)
c2 = _fetch_encrypted_flag()  # padded with coeff2 (rotation 1)

# Polynomial ring (Z/nZ)[x]; rebinds x from the test plaintext to its generator.
PR = PolynomialRing(Zmod(n), "x")
(x,) = PR.gens()

def gcd(a, b):
    """Euclidean GCD of ``a`` and ``b`` (integers or polynomials).

    The result is not normalised; for polynomials the caller must make it
    monic if needed.
    """
    while b != 0:
        remainder = a % b
        a = b
        b = remainder
    return a

# Both ciphertexts encrypt the same flag m under known paddings, so m is a
# common root of f1 and f2. Their GCD is expected to be the linear factor
# (x - m), up to a constant multiple.
f1 = c1 - sum([c * x**i for i, c in enumerate(coeff1)]) ** e
f2 = c2 - sum([c * x**i for i, c in enumerate(coeff2)]) ** e

r = gcd(f1, f2)
# Guard: anything other than a linear GCD means the attack did not isolate m
# (e.g. a misaligned rotation), and reading off r[0], r[1] would print garbage.
if r.degree() != 1:
    raise ValueError(f"expected a linear GCD, got degree {r.degree()}")
# Root of r[1]*x + r[0] is -r[0] / r[1] mod n.
m = int(-r[0] * inverse_mod(r[1], n) % n)
print(m)
print(m.to_bytes((m.bit_length() + 7) // 8, "big"))