Google CTF 2022 | Maybe Someday

#googlectf2022

#!/usr/bin/python3

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from Crypto.Util.number import getPrime as get_prime
import math
import random
import os
import hashlib

# Suppose gcd(p, q) = 1. Find x such that
#   1. 0 <= x < p * q, and
#   2. x = a (mod p), and
#   3. x = b (mod q).
def crt(a, b, p, q):
    return (a*pow(q, -1, p)*q + b*pow(p, -1, q)*p) % (p*q)

def L(x, n):
    return (x-1) // n

class Paillier:
    def __init__(self):
        p = get_prime(1024)
        q = get_prime(1024)

        n = p * q
        λ = (p-1) * (q-1) // math.gcd(p-1, q-1) # lcm(p-1, q-1)
        g = random.randint(0, n-1)
        µ = pow(L(pow(g, λ, n**2), n), -1, n)

        self.n = n
        self.λ = λ
        self.g = g
        self.µ = µ

        self.p = p
        self.q = q

    # https://www.rfc-editor.org/rfc/rfc3447#section-7.2.1
    def pad(self, m):
        padding_size = 2048//8 - 3 - len(m)
        
        if padding_size < 8:
            raise Exception('message too long')

        random_padding = b'\0' * padding_size
        while b'\0' in random_padding:
            random_padding = os.urandom(padding_size)

        return b'\x00\x02' + random_padding + b'\x00' + m

    def unpad(self, m):
        if m[:2] != b'\x00\x02':
            raise Exception('decryption error')

        random_padding, m = m[2:].split(b'\x00', 1)

        if len(random_padding) < 8:
            raise Exception('decryption error')

        return m

    def public_key(self):
        return (self.n, self.g)

    def secret_key(self):
        return (self.λ, self.µ)

    def encrypt(self, m):
        g = self.g
        n = self.n

        m = self.pad(m)
        m = int.from_bytes(m, 'big')

        r = random.randint(0, n-1)
        c = pow(g, m, n**2) * pow(r, n, n**2) % n**2

        return c

    def decrypt(self, c):
        λ = self.λ
        µ = self.µ
        n = self.n

        m = L(pow(c, λ, n**2), n) * µ % n
        m = m.to_bytes(2048//8, 'big')

        return self.unpad(m)

    def fast_decrypt(self, c):
        λ = self.λ
        µ = self.µ
        n = self.n
        p = self.p
        q = self.q

        rp = pow(c, λ, p**2)
        rq = pow(c, λ, q**2)
        r = crt(rp, rq, p**2, q**2)
        m = L(r, n) * µ % n
        m = m.to_bytes(2048//8, 'big')

        return self.unpad(m)

def challenge(p):
    secret = os.urandom(2)
    secret = hashlib.sha512(secret).hexdigest().encode()

    c0 = p.encrypt(secret)
    print(f'{c0 = }')

    # # The secret has 16 bits of entropy.
    # # Hence 16 oracle calls should be sufficient, isn't it?
    # for _ in range(16):
    #     c = int(input())
    #     try:
    #         p.decrypt(c)
    #         print('😀')
    #     except:
    #         print('😡')

    # I decided to make it non-interactive to make this harder.
    # Good news: I'll give you 25% more oracle calls to compensate, anyways.
    cs = [int(input()) for _ in range(20)]
    for c in cs:
        try:
            p.fast_decrypt(c)
            print('😀')
        except:
            print('😡')

    guess = input().encode()

    if guess != secret: raise Exception('incorrect guess!')

def main():
    with open('/flag.txt', 'r') as f:
      flag = f.read()

    p = Paillier()
    n, g = p.public_key()
    print(f'{n = }')
    print(f'{g = }')

    try:
        # Once is happenstance. Twice is coincidence...
        # Sixteen times is a recovery of the pseudorandom number generator.
        for _ in range(16):
            challenge(p)
            print('💡')
        print(f'🏁 {flag}')
    except:
        print('👋')

if __name__ == '__main__':
    main()
  • Paillier暗号 の公開鍵とある暗号文  c_0 、decryption padding oracle が20回分、non-interactiveで与えられる

  •  c_0 は2byteの値をsha512してhexdigestしたもの

  • 目的は  c_0 を当てること

  • これを16回連続でできればクリア

  • 16bit の値を20回のクエリで当てたいので、1クエリでおよそ1bit 程度の情報が得られると良い

  • 得られる情報は、与えた暗号文が復号できるか否か

    • 復号できない場合 = 不正なpaddingになっているとき
  • paddingは PKCS#1 v1.5 によるもので、 00 02 random 00 plaintext となるようにpaddingされる

    • 復号できない場合というのは、 先頭が 00 02 から始まっていないか、平文全体で 00 が一度も含まれないとき
  • paillier暗号には 加法準同型性 があり、  m に対する暗号文  c が与えられたとき、任意の  x をもってきて  x + m を暗号化した暗号文  c' を作成することができる

  • 今回はplaintextはhexdigestなので平文1byteは '0' から 'f' までの16種類しかなく、 0 は必ず含まれない

  •  c_0 をもらってきて、適当に加算することで 00 02 random ff (plaintext - alpha) とすると……

    • plaintext - alpha に 00 が含まれているかどうかで復号に成功するかどうかが変わる

    • plaintext は '0' から 'f' までの16種類のアスキーコードからなる

    • たとえば '1234512345' という平文だったときに '1010101010' を引くと 00 31 02 33 04 30 01 32 03 34 となって 00 が含まれるので復号に成功する

    • 一方、 'abcdeabcdef' という平文に対して '1010101010' を引くと 30 62 32 63 34 60 31 62 33 64 となって 00 が含まれないので復号には失敗する

  • これを考えると暗号文の任意の部分列に狙った数字が含まれるかどうか、ということがわかるオラクルになる

  • 平文の候補は  2^{16} 通りしかないので、これでうまく条件分岐できればオラクルを元に候補を絞り込めそう

  • 今回は先頭10バイトに '0' を含む|含まない、 '1' を含む|含まない……ということをオラクルにしてやった

from ptrlib import Socket
from Crypto.Util.number import inverse
from itertools import product
import hashlib


table = []
for b in product(list(range(256)), repeat=2):
    table.append(hashlib.sha512(bytes(b)).hexdigest())

size = 10
basenum = int("3000" * size, 16)
basealp = int("6100" * size, 16)
one = int("0100" * size, 16)

sock = Socket("nc maybe-someday.2022.ctfcompetition.com 1337")

n = int(sock.recvlineafter("n = "))
g = int(sock.recvlineafter("g = "))


for stage in range(16):
    print("stage: {}".format(stage + 1))
    c0 = int(sock.recvlineafter("c0 = "))


    # fill 0
    c1 = c0 * pow(g, 0xff << 1024, n**2) % (n**2)


    cs = []
    for i in range(20):
        if i <= 9:
            cs.append(c1 * inverse(pow(g, (basenum + one * i) << (1024 - size*8*2), n**2), n**2) % (n**2))
        elif i <= 0xf:
            cs.append(c1 * inverse(pow(g, (basealp + one * (i - 10)) << (1024 - size*8*2), n**2), n**2) % (n**2))
        else:
            cs.append(c1 * inverse(pow(g, (basenum + one * (i - 16)) << (1024 - size*8*2 - 8), n**2), n**2) % (n**2))

    for c in cs:
        sock.sendline(str(c))

    isin = [0 for _ in range(20)]
    for i, c in enumerate(cs):
        if '😀' in sock.recvline().decode():
            isin[i] = 1

    for h in table:
        head = h[:2*size][::2]
        head2 = h[:2*size+1][1::2]
        match = True
        for i in range(16):
            if isin[i] == 1:
                if hex(i)[2:] not in head:
                    match = False
                    break
            else:
                if hex(i)[2:] in head:
                    match = False
                    break

        for i in range(4):
            if isin[i+0x10] == 1:
                if hex(i)[2:] not in head2:
                    match = False
                    break
            else:
                if hex(i)[2:] in head2:
                    match = False
                    break

        if match:
            sock.sendline(h)
            line = sock.recvline().decode()
            if '💡' in line:
                break
            else:
                print(line)
                raise ValueError("bad luck")
    else:
        raise ValueError("not found")

sock.interactive()