#!/usr/local/bin/python
#
# Polymero
#

# Imports
import hashlib
import math
import os
import time

from Crypto.Util.number import isPrime, getPrime, inverse

# Local import
# NOTE: raises AttributeError at startup when FLAG is not set in the
# environment (os.environ.get returns None).
FLAG = os.environ.get('FLAG').encode()


class URSA:
    """Upgraded RSA (faster and with cheap key cycling).

    The modulus is n = p * q where each prime has the form
    2 * (product of small primes) + 1.  The modulus stays fixed for the
    lifetime of the object; only the exponent pair (e, d) is cycled by
    update_key().

    Attributes:
        public: dict with 'n' (modulus) and 'e' (public exponent).
        private: dict with 'p', 'q', 'f' = (p - 1) * (q - 1) and 'd'
            (inverse of e modulo f).
    """

    def __init__(self, pbit, lbit):
        """Generate a key pair.

        Args:
            pbit: approximate bit size budget of each prime.
            lbit: bit size of the small primes multiplied together.
        """
        p, q = self.prime_gen(pbit, lbit)
        f = (p - 1) * (q - 1)
        self.public = {'n': p * q, 'e': 0x10001}
        self.private = {'p': p, 'q': q, 'f': f, 'd': inverse(self.public['e'], f)}

    @staticmethod
    def _smooth_prime(pbit, lbit):
        """Return a prime of the form 2 * prod(small primes) + 1.

        pbit // lbit primes of lbit bits are drawn per attempt.
        """
        while True:
            factors = [getPrime(lbit) for _ in range(pbit // lbit)]
            # Acceptance rule kept from the original: retry unless at least
            # two of the drawn primes repeat an earlier draw.
            if len(factors) - len(set(factors)) <= 1:
                continue
            product = 1
            for factor in factors:
                product *= factor
            candidate = 2 * product + 1
            if isPrime(candidate):
                return candidate

    def prime_gen(self, pbit, lbit):
        """Return the pair (P, Q) of primes used to build the modulus."""
        # Smooth primes are FAST primes ~ !
        # Q is generated before P, as in the original implementation.
        big_q = self._smooth_prime(pbit, lbit)
        big_p = self._smooth_prime(pbit, lbit)
        return big_p, big_q

    def update_key(self):
        """Replace (e, d) with a fresh exponent pair for the same modulus.

        Prime generation is expensive, so only d and e are updated.  The new
        d is derived by XOR-ing the current value with a SHA-512 digest of
        the value and the current time, reduced modulo f.  Candidates that
        are not invertible modulo f are re-derived until one is, and the key
        is only committed once both d and e are known, so a failed candidate
        can no longer leave d and e out of sync.
        """
        f = self.private['f']
        d = self.private['d']
        while True:
            digest = hashlib.sha512((str(d) + str(time.time())).encode()).digest()
            d = (d ^ int.from_bytes(digest, 'big')) % f
            if math.gcd(d, f) == 1:
                break
        self.private['d'] = d
        self.public['e'] = inverse(d, f)

    def _powmod_chunks(self, value, exponent):
        """Raise each base-n digit of value to exponent modulo n.

        Returns a list with one entry per iteration of "value //= n",
        least-significant chunk first; the list is empty for value == 0.

        Raises:
            ValueError: if value is negative.  Floor division never brings a
                negative integer to zero (-1 // n == -1), so the loop would
                otherwise run forever and grow the list without bound.
        """
        if value < 0:
            raise ValueError("input must be a non-negative integer")
        n = self.public['n']
        chunks = []
        while value:
            chunks.append(pow(value, exponent, n))
            value //= n
        return chunks

    def encrypt(self, m_int):
        """Encrypt a non-negative integer; returns a list of ciphertext ints.

        Raises:
            ValueError: if m_int is negative.
        """
        return self._powmod_chunks(m_int, self.public['e'])

    def decrypt(self, c_int):
        """Decrypt a non-negative integer; returns a list of plaintext ints.

        Raises:
            ValueError: if c_int is negative.
        """
        return self._powmod_chunks(c_int, self.private['d'])


def show_public_key(service):
    """Print the modulus fingerprint (SHA-256 of str(n)) and the exponent e."""
    print("| id = {}".format(hashlib.sha256(str(service.public['n']).encode()).hexdigest()))
    print("| e = {}".format(service.public['e']))


# Challenge setup
# NOTE(review): the banner and menu texts below are reproduced exactly as they
# were found; they look like they lost their original line breaks - confirm
# against the upstream challenge before changing them.
print("""| | ~ Welcome to URSA decryption services | Press enter to start key generation...""")
input("|")
print("| | Please hold on while we generate your primes... \n|\n|")

oracle = URSA(256, 12)
print("| ~ You are connected to an URSA-256-12 service, public key ::")
show_public_key(oracle)

print("|\n| ~ Here is a free flag sample, enjoy ::")
for i in oracle.encrypt(int.from_bytes(FLAG, 'big')):
    print("| {}".format(i))

MENU = """| | ~ Menu (key updated after {} requests):: | [E]ncrypt | [D]ecrypt | [U]pdate key | [Q]uit |"""

# Server loop
CYCLE = 0
while True:
    try:
        # Every fourth cycle (including the very first one) forces a key
        # update instead of prompting the client.
        if CYCLE % 4:
            print(MENU.format(4 - CYCLE))
            choice = input("| > ")
        else:
            choice = 'u'
        command = choice.lower()

        if command == 'e':
            msg = int(input("|\n| > (int) "))
            print("|\n| ~ Encryption ::")
            for i in oracle.encrypt(msg):
                print("| {}".format(i))
        elif command == 'd':
            cip = int(input("|\n| > (int) "))
            print("|\n| ~ Decryption ::")
            for i in oracle.decrypt(cip):
                print("| {}".format(i))
        elif command == 'u':
            oracle.update_key()
            print("|\n| ~ Key updated succesfully ::")
            show_public_key(oracle)
            CYCLE = 0
        elif command == 'q':
            print("|\n| ~ Closing services...\n|")
            break
        else:
            print("|\n| ~ ERROR - Unknown command")

        CYCLE += 1
    except KeyboardInterrupt:
        print("\n| ~ Closing services...\n|")
        break
    except EOFError:
        # The client closed its side; without this the loop would keep
        # failing on input() and never terminate.
        break
    except Exception:
        # Malformed integers, negative inputs, etc.
        print("|\n| ~ Please do NOT abuse our services.\n|")
RSA。$P-1$, $Q-1$ が smooth になるように素数 $P, Q$ が生成されており、$N = PQ$ は不明
flagを暗号化したものがもらえる
その後暗号化 & 復号をいくらでもしてもらえるが、途中でe, d がupdateされてしまう
$N$ が不明ならとりあえず $-1$ を入れるかと思ったけど、$-1$ を入れると無限ループするので、$N$ 以上の値を平文として入れると返ってくる配列の要素数が 1 個から 2 個に増えることを利用して二分探索する
$N$ が分かれば、$P-1$, $Q-1$ が B-smooth なので p-1 法で素因数分解できる
"""Solver for the URSA service.

Outline: the service encrypts a plaintext in base-n chunks, so the number of
ciphertext lines reveals whether the plaintext is below n.  A binary search on
that signal recovers n, Pollard's p-1 method factors it, and the flag sample
(encrypted under e = 65537 before the first key update) is decrypted.
"""
import hashlib
import math
import re

# A ciphertext line printed by the service: a pipe, whitespace, then digits.
_CIPHERTEXT_LINE = re.compile(r"\|\s+\d+")


def pollard(n):
    """Return a non-trivial factor of n using Pollard's p-1 method.

    For each base the value a = base^(b!) mod n is built up incrementally and
    gcd(a - 1, n) is inspected after every step.

    Raises:
        ValueError: if no base in [2, n) yields a factor.
    """
    for base in range(2, n):
        shared = math.gcd(base, n)
        if shared > 1:
            # base < n, so this is already a proper factor.
            return shared
        a = base
        b = 2
        while True:
            a = pow(a, b, n)
            d = math.gcd(a - 1, n)
            if 1 < d < n:
                return d
            if d == n:
                # Every prime factor was hit in the same step; a stays 1 from
                # here on, so continuing with this base can never succeed.
                break
            b += 1
    raise ValueError("no non-trivial factor found for {}".format(n))


def int_to_bytes(value):
    """Return the minimal big-endian byte string of a non-negative integer.

    Unlike bytes.fromhex(hex(value)[2:]) this also works when the hexadecimal
    form has an odd number of digits.
    """
    return value.to_bytes((value.bit_length() + 7) // 8, "big")


def count_ciphertext_lines(sock):
    """Read service output up to the next menu and count ciphertext lines."""
    count = 0
    while True:
        line = sock.recvline().decode().strip()
        if _CIPHERTEXT_LINE.match(line):
            count += 1
        if "~ Menu" in line:
            return count


def main():
    """Connect to the service, recover n, factor it and print the flag."""
    # Third-party client/progress libraries are imported here so the pure
    # helpers above stay importable without them.
    from ptrlib import Socket
    from tqdm import tqdm

    sock = Socket("nc localhost 9999")
    sock.sendline("a")
    nid = sock.recvlineafter("id = ").decode().strip()
    e = 65537
    c = int(sock.recvlineafter("enjoy ::\n|").strip())

    # binsearch to get n: plaintexts below n yield at most one ciphertext
    # line, plaintexts of n or more yield at least two.
    ok = 1
    ng = 2**513
    with tqdm() as pbar:
        while abs(ok - ng) > 1:
            m = (ok + ng) // 2
            sock.sendlineafter("> ", "E")
            sock.sendlineafter("> (int) ", str(m))
            if count_ciphertext_lines(sock) <= 1:
                ok = m
            else:
                ng = m
            pbar.update(1)

    print("-------")
    print(ok)
    print(ng)

    # Confirm the candidate against the SHA-256 fingerprint the service shows.
    for candidate in (ok, ng):
        if hashlib.sha256(str(candidate).encode()).hexdigest() == nid:
            n = candidate
            break
    else:
        raise RuntimeError("binary search did not converge on the advertised modulus")

    p = pollard(n)
    q = n // p
    d = pow(e, -1, (p - 1) * (q - 1))
    m = pow(c, d, n)
    print(int_to_bytes(m))


if __name__ == "__main__":
    main()