Aero CTF 2021 | boggart

#aeroctf2021 #good_challenges_2021

#!/usr/bin/env python3.8

from gmpy import next_prime
from random import getrandbits


def bytes_to_long(data):
    return int.from_bytes(data, 'big')


class Wardrobe:
    @staticmethod
    def create_boggarts(fear, danger):
        # for each level of danger we're increasing fear
        while danger > 0:
            fear = next_prime(fear)

            if getrandbits(1):
                yield fear
                danger -= 1


class Wizard:
    def __init__(self, magic, year, experience):
        self.magic = magic
        self.knowledge = year - 1  # the wizard is currently studying the current year
        self.experience = experience

    def cast_riddikulus(self, boggart):
        # increasing the wizard's experience by casting the riddikulus charm
        knowledge, experience = self.knowledge, self.experience

        while boggart > 1:
            knowledge, experience = experience, (experience * self.experience - knowledge) % self.magic
            boggart -= 1

        self.experience = experience


def main():
    year = 3
    bits = 512
    boggart_fear = 31337
    boggart_danger = 16

    neutral_magic, light_magic, dark_magic = [getrandbits(bits) for _ in range(3)]
    magic = next_prime(neutral_magic | light_magic) * next_prime(neutral_magic | dark_magic)

    print('Hello. I am Professor Remus Lupin. Today I\'m going to show you how to deal with the boggart.')

    print(neutral_magic)
    print(magic)

    with open('flag.txt', 'rb') as file:
        flag = file.read().strip()

    # some young wizards without knowledge of the riddikulus charm
    harry_potter = Wizard(magic, year, bytes_to_long(b'the boy who lived'))
    you = Wizard(magic, year, bytes_to_long(flag))

    for boggart in Wardrobe.create_boggarts(boggart_fear, boggart_danger):
        # wizards should train to learn the riddikulus charm
        harry_potter.cast_riddikulus(boggart)
        you.cast_riddikulus(boggart)

    # wizard's experience should be increased
    print(harry_potter.experience)
    print(you.experience)


if __name__ == '__main__':
    main()

概要

  • neutral , light , dark はそれぞれ 512bitの乱数

  • [tex: magic = next\prime(neutral|light) * next\prime(neutral|dark)]

    • neutralmagic は与えられる
  • create_boggarts(a, b) は次のように振る舞う

    • a = next_prime(a) としながら b 回、途中の a を返す

      • a を返すかどうかは1/2 の乱数で決まる

      • (のでおよそ 2b 回のループが回る?)

  • Wizard(magic, knowledge, experience) は次のように振る舞う

  • magic : ↑で与えられた合成数 magic (いわゆる n

  • knowledge : 2

  • experience : 平文

  • cast_riddikulus(b)

  •  k_0, e_0 = knowledge, experience として  e_b を計算する

  •  k_i = e_{i-1}

  •  e_i = e_{i-1} * e_{0} - k_{i-1} \mod n

考察

  • 求めたいのは you の初期の experience

  • ヒントは harry の初期の experience

    • 同じ引数で cast_riddikulus を呼んでいるので  e_i の値が異なる部分を直せばなんとかなる、かも
  • cast_riddikulus 自体は 16 回呼ばれていて、引数はそれぞれ素数で単調増加する

  • 実は cast_riddikulus の式は次のように書き換えられる。多項間漸化式、 リュカ数列 の形になっていることがわかる

    •  e_0 = 2

    •  e_1 = experience

    •  e_{i} = e_{i-1} * e_1 - e_{i-2} \mod N

  •  e_0, e_i から  e_1 を求められるか

    •  e_i = e_{i-1} e_1 - e_{i-2}

    •  e_{i-1} = e_{i-2}e_1 - e_{i-3}

    • 行列で書くと

      •  \begin{pmatrix} e_{i+2} \ e_{i+1} \end{pmatrix} = \begin{pmatrix} e_1 & -1 \ 1 & 0 \end{pmatrix} \begin{pmatrix} e_{i+1} \ e_i \end{pmatrix} =  \begin{pmatrix} e_1 & -1 \ 1 & 0 \end{pmatrix}^{i-1} \begin{pmatrix} e_1 \ e_0 \end{pmatrix}

        • こうかけるので、この暗号は暗号鍵についての可換性を持つことがわかる
    • また、  e_nリュカ数列

       V_n(e_1, 1) として表されるので、その一般項から↓と書ける

      •  e_n = \left(\frac{e_1 + \sqrt{e_1^2 - 4}}{2} \right)^n + \left(\frac{e_1 - \sqrt{e_1^2 - 4}}{2} \right)^n
  • harry e_1 が分かっていると何が嬉しいのか

    • cast_riddiculus の引数がわかるとか?
  • magic の値は謎の生成方法をとっているが、たとえばこれは素因数分解できるということなのか?

  • 素因数分解branch and prune でできることがわかった

    • 素因数分解できるかどうかは neutral がどれだけわかっているかによるので、よい neutral を引き当てるまで頑張る
  • 素因数分解できると何が嬉しい?

writeupを読んでいるとわかること

  • 素因数分解 にはbranch and pruneとHensel's liftを組み合わせた方法を使っている?

  • これは LUC Cryptosystem というらしい

  •  e_n をリュカ数列  V_n の一般項として表したときの

    •  e_n = \left(\frac{e_1 + \sqrt{e_1^2 - 4}}{2} \right)^n + \left(\frac{e_1 - \sqrt{e_1^2 - 4}}{2} \right)^n で、  t = \frac{e_1 + \sqrt{e_1^2 - 4}}{2} とおくと、  t^{-1} = \frac{e_1 - \sqrt{e_1^2 - 4}}{2} になっている

      • ※これはたまたまだと思う
    • かつ、リュカ数列  V の一般項  V_n = \alpha^n + \beta^n と表されるから、  e_n = t^n + t^{-n} である

  • ということは、暗号文  c = t^n + t^{-n} \mod N で、↓の手順で平文  m が復元できそう

    •  c から  t^n をもとめる

    •  t^n から  t を求める

    •  t から  m = t + t^{-1} を求める

  • この復元ステップを実行するためには

    •  \mod N素因数分解したい(  t^n から  t を求めるために必要そう)

    •  c^2-4=(t^n+t^{-n})^2-4 = (t^n-t^{-n})^2 なので  c^2 -4平方剰余 をとると  t^n - t^{-n} がでてくるので、  c = t^n + t^{-n} と足し引きして  t^n を取り出せる

  • harryの  e_1 e_n がわかっているので、これをもとに meet-in-the-middle attack をすることで cast_riddikulus の引数に使われた16個の素数を当てることができる

  • 大体32個〜の素数が全体で使われていると見て、16個を前半パート、16個を後半パートにして、前半パートのうちの8個を選んで暗号化したやつと、後半パートのうちの8個を選んで復号したやつの状態が一致すれば良さそう。これは16個から8個を選ぶ組み合わせの通りなので  \binom{16}{8} = 12870 通りの2倍で、だいたいfeasible

  • ちなみに鍵の可換性があるので前半パート後半パートは適当に選んで良くて、8個の鍵があたってさえいれば良い

かなりの運が必要だが、これで解ける

from ptrlib import Socket
from gmpy2 import mpz
from Crypto.Util.number import bytes_to_long

def legendre_symbol(a, p):
    ls = pow(a, (p - 1)//2, p)
    if ls == p - 1:
        return -1
    return ls

# Source: https://codereview.stackexchange.com/q/43210
def prime_mod_sqrt(a, p):
    """
    Square root modulo prime number
    Solve the equation
        x^2 = a mod p
    and return list of x solution
    http://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm
    """
    a %= p

    # Simple case
    if a == 0:
        return [0]
    if p == 2:
        return [a]

    # Check solution existence on odd prime
    if legendre_symbol(a, p) != 1:
        return []

    # Simple case
    if p % 4 == 3:
        x = pow(a, (p + 1)//4, p)
        return [x, p-x]

    # Factor p-1 on the form q * 2^s (with Q odd)
    q, s = p - 1, 0
    while q % 2 == 0:
        s += 1
        q //= 2

    # Select a z which is a quadratic non resudue modulo p
    z = 1
    while legendre_symbol(z, p) != -1:
        z += 1
    c = pow(z, q, p)

    # Search for a solution
    x = pow(a, (q + 1)//2, p)
    t = pow(a, q, p)
    m = s
    while t != 1:
        # Find the lowest i such that t^(2^i) = 1
        i, e = 0, 2
        for i in range(1, m):
            if pow(t, e, p) == 1:
                break
            e *= 2

        # Update next value to iterate
        b = pow(c, 2**(m - i - 1), p)
        x = (x * b) % p
        t = (t * b * b) % p
        c = (b * b) % p
        m = i
    return [x, p-x]

def factor(bits, N, rp, rq, error_size):
    candidates = [(mpz(1), mpz(1))]
    rp, rq = mpz(rp), mpz(rq)
    rp, rq = (rp >> error_size) << error_size, (rq >> error_size) << error_size

    for i in range(1, bits):
        next_candidates = set()

        for p, q in candidates:
            if (N - p * q).bit_test(i):
                next_candidates.add((p.bit_set(i), q))
                next_candidates.add((p, q.bit_set(i)))
            else:
                next_candidates.add((p, q))
                next_candidates.add((p.bit_set(i), q.bit_set(i)))

        candidates = set()

        for p, q in next_candidates:
            if (q, p) in candidates:
                continue
            if i < error_size or ((p & rp).bit_test(i) == rp.bit_test(i) and (q & rq).bit_test(i) == rq.bit_test(i)):
                if p * q == N:
                    return int(p), int(q)
                candidates.add((p, q))

    return None, None


def modsqrt(x, p, q):
    from itertools import product

    yp = prime_mod_sqrt(x, p)
    yq = prime_mod_sqrt(x, q)
    ms = []
    for pat in product(yp, yq): 
        ms.append(CRT_list([int(pat[0]), int(pat[1])], [p, q]))
    return ms

def encrypt(m, x, k, n):
    """
    m: plaintext (experience)
    x: knowledge
    k: boggart
    n: modulo (magic)
    """
    F = Zmod(n)
    M = matrix(F, [[m, -1], [1, 0]])
    C = M^(k-1) * matrix(F, [[m], [x]])
    return C[0,0]

def decrypt(c, x, k, p, q):
    """
    c: ciphertext (assert legendre_symbol(c^2 - 4, p) == 1 and legendre_symbol(c^2 - 4, q) == 1
    x: knowledge
    k: boggart
    p, q: n = p * q
    """
    tpt = c
    n = p * q
    d = pow(k, -1, (p-1)*(q-1))

    for tmt in modsqrt(int(c)^2-4, p, q):
        try:
            tn = (tpt + tmt) * inverse_mod(2, n) % n
            t = pow(tn, d, n)
            tinv = pow(t, -1, n)
            return (t + tinv) % n
        except ZeroDivisionError:
            continue

trial = 0

def solve():
    global trial

    THRESHOLD = 0.53
    BITS = 512

    while True:
        trial += 1
        print("[+] trial: {}".format(trial))
        # wait for good parameters
        sock = Socket("localhost", 19999)
        sock.recvline()

        X = int(sock.recvline())
        N = int(sock.recvline())
        Y = int(sock.recvline())
        C = int(sock.recvline())
        sock.close()

        print("[+] X = {} / {}".format(bin(X).count("1"), BITS))
        if not bin(X).count("1") >= THRESHOLD * BITS:
            continue

        try:
            inverse_mod(2, N)
        except ZeroDivisionError:
            continue

        # part 1. factoring
        print("[+] facotoring...")
        p, q = factor(BITS, N, X, X, 14)
        if p is None or q is None:
            continue
        print("[+] done")


        # check if ciphertext is good
        if legendre_symbol(Y^2 - 4, p) == 1 and legendre_symbol(Y^2 - 4, q) == 1:
            if legendre_symbol(C^2 - 4, p) == 1 and legendre_symbol(C^2 - 4, q) == 1:
                break
        print("[-] bad ciphertexts")

    # part 2. find primes
    prime_list = []
    prime = 31337
    for _ in range(36):
        prime = next_prime(prime)
        prime_list.append(prime)
    
    for i in range(14, 18):
        prime_left = prime_list[:i]
        prime_right = prime_list[i:]

        from itertools import combinations
        m = bytes_to_long(b'the boy who lived')
        table = {}
        print("[+] Meet in the middle: {}...".format(i))
        for ps in combinations(prime_left, 8):
            key = encrypt(m, 2, product(ps), N)
            key = int(key)
            table[key] = ps

        print("[+] Turning point...")

        for ps in combinations(prime_right, 8):
            key = decrypt(Y, 2, product(ps), p, q)
            if key is None:
                print("NONE")
                continue
            key = int(key)
            if key in table:
                primes = list(table[key]) + list(ps)
                break
        else:
            print("[-] Bad parameter")
            continue

        print("[+] Found boggarts")

        # part 3. decrypt
        plaintext = decrypt(C, 2, product(primes), p, q)
        return plaintext


if __name__ == "__main__":
    while True:
        p = solve()
        if p is not None:
            print(p)
            quit()