Google CTF 2021 | Pythia

#googlectf2021

#!/usr/bin/python -u
import random
import secrets
import string
import time
from base64 import b64encode, b64decode

from cryptography.exceptions import InvalidTag
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt

max_queries = 150  # total menu interactions allowed per connection
query_delay = 10   # seconds slept before answering options 2 and 3

# Three independent 3-letter lowercase passwords; option 2 requires all three
# concatenated. They guard the flag, so draw them from a CSPRNG (`secrets`)
# rather than `random`; alphabet and length are unchanged.
passwords = [
    ''.join(secrets.choice(string.ascii_lowercase) for _ in range(3)).encode('UTF-8')
    for _ in range(3)
]

# Read the flag once at startup and close the handle immediately.
with open("flag.txt", "rb") as flag_file:
    flag = flag_file.read()

def menu():
    """Print the main menu and read the user's choice.

    Returns:
        The entered option as an int, or -1 if the input is not an integer
        or stdin is exhausted (EOF).
    """
    print("What you wanna do?")
    print("1- Set key")
    print("2- Read flag")
    print("3- Decrypt text")
    print("4- Exit")
    try:
        return int(input(">>> "))
    except (ValueError, EOFError):
        # Non-numeric input and EOF both map to -1 ("Invalid option!").
        # KeyboardInterrupt / SystemExit are deliberately not caught.
        return -1

print("Welcome!\n")

# Index into `passwords` of the key option 3 decrypts with; set via option 1.
key_used = 0

for query in range(max_queries):
    option = menu()

    if option == 1:
        print("Which key you want to use [0-2]?")
        try:
            i = int(input(">>> "))
        except (ValueError, EOFError):
            i = -1
        if 0 <= i <= 2:
            key_used = i
        else:
            print("Please select a valid key.")
    elif option == 2:
        print("Password?")
        passwd = bytes(input(">>> "), 'UTF-8')

        print("Checking...")
        # Prevent bruteforce attacks...
        time.sleep(query_delay)
        if passwd == (passwords[0] + passwords[1] + passwords[2]):
            print("ACCESS GRANTED: " + flag.decode('UTF-8'))
        else:
            print("ACCESS DENIED!")
    elif option == 3:
        print("Send your ciphertext ")

        ct = input(">>> ")
        print("Decrypting...")
        # Prevent bruteforce attacks...
        time.sleep(query_delay)
        try:
            nonce, ciphertext = ct.split(",")
            nonce = b64decode(nonce)
            ciphertext = b64decode(ciphertext)
        except ValueError:
            # Wrong number of comma-separated fields, or malformed base64
            # (binascii.Error is a ValueError subclass).
            print("ERROR: Ciphertext has invalid format. Must be of the form \"nonce,ciphertext\", where nonce and ciphertext are base64 strings.")
            continue

        # Derive the AES key from the selected password (fixed empty salt).
        kdf = Scrypt(salt=b'', length=16, n=2**4, r=8, p=1, backend=default_backend())
        key = kdf.derive(passwords[key_used])
        try:
            # Only whether authentication succeeds matters; the plaintext is
            # never shown to the client.
            AESGCM(key).decrypt(nonce, ciphertext, associated_data=None)
        except (InvalidTag, ValueError):
            # InvalidTag: tag check failed. ValueError: parameters rejected
            # by AESGCM (e.g. an unsupported nonce length).
            print("ERROR: Decryption failed. Key was not correct.")
            continue

        print("Decryption successful")
    elif option == 4:
        print("Bye!")
        break
    else:
        print("Invalid option!")
    # `query` is zero-based and the current query is already spent.
    print("You have " + str(max_queries - query - 1) + " trials left...\n")

AES-GCM

なんもしらん ("I don't know anything about it.")

from base64 import b64encode, b64decode
from cryptography.hazmat.primitives.ciphers import (
        Cipher, algorithms, modes
    )
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.exceptions import InvalidTag
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.number import long_to_bytes, bytes_to_long
from bitstring import BitArray, Bits
import binascii
import sys


# Size of one GCM/AES block in bits (16 bytes * 8).
GCM_BITS_PER_BLOCK = 16 * 8
# One all-zero 16-byte block; encrypting it under a key yields GHASH's H.
ALL_ZEROS = bytes(16)


def check_correctness(keyset, nonce, ct):
    """Check that *ct* authenticates under every key in *keyset*.

    Attempts ``AESGCM(key).decrypt(nonce, ct, None)`` for each key. It prints
    the index of every key that fails. If all keys succeed, it prints a
    summary line instead.

    Args:
        keyset: sequence of AES keys (bytes) to try.
        nonce: GCM nonce used for every decryption.
        ct: ciphertext with its 16-byte tag appended.

    Returns:
        True if every key decrypted *ct* (vacuously True for an empty
        keyset), otherwise False.
    """
    all_ok = True

    for i, key in enumerate(keyset):
        # Build the cipher from the key under test, not from some outer name.
        try:
            AESGCM(key).decrypt(nonce, ct, None)
        except InvalidTag:
            print('key %s failed' % i)
            all_ok = False

    if all_ok:
        print("All %s keys decrypted the ciphertext" % len(keyset))
    return all_ok

def pad(a):
    """Right-pad list *a* with '0' entries up to GCM_BITS_PER_BLOCK items.

    A new list is returned when padding is needed. Otherwise *a* itself is
    returned unchanged, even if it is longer than a block.
    """
    missing = GCM_BITS_PER_BLOCK - len(a)
    return a + ['0'] * missing if missing > 0 else a


def bytes_to_element(val, field, a):
    """Map byte string *val* to an element of *field* using GCM bit order.

    The i-th bit of *val* (as iterated by BitArray) becomes the coefficient
    of a^i, so the first bit is the constant term. ``^`` is exponentiation
    here (Sage preparser), not XOR.
    """
    element = field.fetch_int(0)
    for power, bit_set in enumerate(BitArray(val)):
        if bit_set:
            element += a^power
    return element

def multi_collide_gcm(keyset, nonce, tag, first_block=None, use_magma=False):
    """Forge one AES-GCM ciphertext whose tag verifies under every key in *keyset*.

    This is SageMath source: the ``P.<x> = ...`` generator syntax and ``^``
    as exponentiation require the Sage preparser.

    For each key K it computes the GHASH key H = E_K(0^128) and the tag mask
    E_K(J0). With m ciphertext blocks C_1..C_m and no associated data, the
    GCM tag is

        T = C_1*H^(m+1) + ... + C_m*H^2 + L*H + E_K(J0)

    in GF(2^128). Here L is the length block and "+" is XOR. The unknown
    blocks are the coefficients of a polynomial f of degree < len(keyset).
    For every key, f(H) must equal b * H^-2, where b collects all known
    terms. Lagrange interpolation through the points (H, b*H^-2) yields f.
    Its coefficients, highest degree first, become the free ciphertext
    blocks.

    Args:
        keyset: AES keys (bytes) the forged ciphertext must verify under.
        nonce: GCM nonce. J0 is built as nonce || 0^31 || 1, which assumes a
            12-byte nonce (TODO confirm before using other lengths).
        tag: 16-byte tag appended to the output. The blocks are solved so
            that this tag verifies under each key.
        first_block: optional fixed leading ciphertext block, zero-padded to
            16 bytes on output. When given, the output has len(keyset) + 1
            blocks; otherwise it has len(keyset).
        use_magma: interpolate in Magma through Sage's ``magma`` interface
            instead of ``lagrange_polynomial``.

    Returns:
        bytes: the ciphertext blocks followed by *tag*, i.e. the
        ciphertext || tag layout that ``AESGCM.decrypt`` takes.

    NOTE(review):
        - ``h_bf^-2`` assumes no key gives H == 0.
        - Interpolation assumes all H values are distinct, i.e. no duplicate
          keys.
        - If f's leading coefficient is zero, ``f.list()`` is shorter than
          len(keyset). Fewer blocks are then emitted than ``lens`` encodes,
          and the forgery would not verify. This case is not handled.
    """

    # GF(2)[x] and GF(2^128) with the GCM reduction polynomial
    # x^128 + x^7 + x^2 + x + 1; `a` is the field generator.
    P.<x> = PolynomialRing(GF(2))
    p = x^128 + x^7 + x^2 + x + 1
    GFghash.<a> = GF(2^128,'x',modulus=p)
    if use_magma:
        # Mirror the same field on the Magma side for Interpolation below.
        t = "p:=IrreducibleLowTermGF2Polynomial(128); GFghash<a> := ext<GF(2) | p>;"
        magma.eval(t)
    else:
        # Univariate polynomials over GF(2^128), used for lagrange_polynomial.
        R = PolynomialRing(GFghash, 'x')

    # Length block L = len(A) in bits (64 bits) || len(C) in bits (64 bits).
    # There is no associated data, and C is one block per key, plus one more
    # when first_block is given.
    if first_block is not None:
        ctbitlen = (len(keyset) + 1) * GCM_BITS_PER_BLOCK
    else:
        ctbitlen = len(keyset) * GCM_BITS_PER_BLOCK
    adbitlen = 0
    lens = (adbitlen << 64) | ctbitlen
    lens_byte = int(lens).to_bytes(16,byteorder='big')
    lens_bf = bytes_to_element(lens_byte, GFghash, a)

    # Pre-counter block J0 = nonce || 0^31 || 1 (not an increment);
    # E_K(J0) is the mask XORed into GHASH to form the tag.
    nonce_plus = int((int.from_bytes(nonce,'big') << 32) | 1).to_bytes(16,'big')

    # Field encodings of the fixed leading block (if any) and of the tag.
    if first_block is not None:
        block_bf = bytes_to_element(first_block, GFghash, a)
    tag_bf = bytes_to_element(tag, GFghash, a)
    keyset_len = len(keyset)

    # Interpolation points: abscissae H, ordinates b * H^-2.
    if use_magma:
        I = []
        V = []
    else:
        pairs = []

    for k in keyset:
        # GHASH key H = E_K(0^128)
        aes = AES.new(k, AES.MODE_ECB)
        H = aes.encrypt(ALL_ZEROS)
        h_bf = bytes_to_element(H, GFghash, a)

        # Tag mask E_K(J0). Rebinding P shadows the polynomial ring P above;
        # this is harmless because the ring is not used again.
        P = aes.encrypt(nonce_plus)
        p_bf = bytes_to_element(P, GFghash, a)

        if first_block is not None:
            # assign (lens * H) + P + T + (C1 * H^{k+2}) to b
            # (k = keyset_len here; C1 is the highest-power block)
            b = (lens_bf * h_bf) + p_bf + tag_bf + (block_bf * h_bf^(keyset_len+2))
        else:
            # assign (lens * H) + P + T to b
            b = (lens_bf * h_bf) + p_bf + tag_bf

        # get pair (H, b*(H^-2)): the free blocks multiply H^2 and up, so
        # dividing by H^2 turns them into the coefficients of f(H).
        y =  b * h_bf^-2
        if use_magma:
            I.append(h_bf)
            V.append(y)
        else:
            pairs.append((h_bf, y))

    # compute Lagrange interpolation
    if use_magma:
        f = magma("Interpolation(%s,%s)" % (I,V)).sage()
    else:
        f = R.lagrange_polynomial(pairs)
    # f.list() is lowest degree first; the ciphertext wants highest first.
    coeffs = f.list()
    coeffs.reverse()

    # Serialize each field element as 128 bits. polynomial().list() gives
    # coefficients from a^0 upward, matching bytes_to_element's bit order,
    # and pad() fills the missing high-degree bits with '0'.
    if first_block is not None:
        ct = list(map(str, block_bf.polynomial().list()))
        ct_pad = pad(ct)
        ct = Bits(bin=''.join(ct_pad))
    else:
        # NOTE(review): adding a Bits to '' relies on Bits.__radd__ accepting
        # the empty string; Bits() would be more explicit.
        ct = ''
    
    for i in range(len(coeffs)):
        ct_i = list(map(str, coeffs[i].polynomial().list()))
        ct_pad = pad(ct_i)
        ct_i = Bits(bin=''.join(ct_pad))
        ct += ct_i
    ct = ct.bytes
    
    # ciphertext || tag
    return ct+tag

from pwn import remote
import ast

def recv_menu(io, option):
    """Wait for the server's ">>> " prompt, then answer with *option*."""
    prompt = b">>> "
    reply = str(option).encode()
    io.recvuntil(prompt)
    io.sendline(reply)

def set_key(io, key_id):
    """Pick menu option 1, then answer the key-slot prompt with *key_id*."""
    # Both prompts end with the same ">>> " marker, so the menu helper can
    # answer each of them in turn.
    for answer in (1, key_id):
        recv_menu(io, answer)

def decrypt_text(io, nonce, ciphertext):
    """Submit nonce/ciphertext to menu option 3 and report whether it decrypted.

    The payload is sent as ``base64(nonce),base64(ciphertext)``. Two reply
    lines are read: the first is printed raw, and the second is the verdict.

    Returns:
        True if the verdict line contains "Decryption successful".
    """
    recv_menu(io, 3)
    print(io.recvuntil(b">>> "))
    io.sendline(b",".join((b64encode(nonce), b64encode(ciphertext))))
    print(io.recvline())
    verdict = io.recvline()
    print("s =", verdict.decode())
    return b"Decryption successful" in verdict

def get_password(io, key_id, poss_keys):
    """Find which candidate key the server uses in key slot *key_id*.

    The server's "Decrypt text" option acts as a partitioning oracle. A
    ciphertext built by ``multi_collide_gcm`` for a subset of keys decrypts
    successfully only when the server's active key is in that subset.

    Phase 1 splits *poss_keys* into chunks of ``len(poss_keys) // 60``
    (at least 1) and tests them in order until one is accepted. Phase 2
    binary-searches that chunk, spending one oracle query per halving.

    Args:
        io: connected pwntools tube to the challenge server.
        key_id: server key slot (0-2) to select before querying.
        poss_keys: list of candidate AES keys.

    Returns:
        The single remaining candidate key.

    Raises:
        ValueError: if *poss_keys* is empty.
        RuntimeError: if no chunk is accepted in phase 1, i.e. the server's
            key is not among *poss_keys* (or a query misbehaved).
    """
    if not poss_keys:
        raise ValueError("poss_keys must not be empty")

    # Fixed forgery parameters; only the key subset varies between queries.
    first_block = b'\x01'
    nonce = b'\x00' * 12
    tag = b'\x01' * 16

    def _oracle_accepts(keyset):
        # Forge a ciphertext valid under every key in `keyset` and ask the
        # server to decrypt it; success means its active key is in `keyset`.
        print("generating collision")
        ct = multi_collide_gcm(keyset, nonce, tag, first_block=first_block)
        print("done collision")
        return decrypt_text(io, nonce, ct)

    set_key(io, key_id)
    print("key id =", key_id)

    # initial filtering: coarse chunks. max(1, ...) keeps the step non-zero
    # when there are fewer than 60 candidates (range() rejects a 0 step).
    increment = max(1, len(poss_keys) // 60)
    for i in range(0, len(poss_keys), increment):
        keyset = poss_keys[i:i + increment]
        print("keyset len =", len(keyset))
        if _oracle_accepts(keyset):
            poss_keys = keyset
            break
    else:
        # Without a hit, the binary search below would always pick the upper
        # half and silently return the last candidate, so fail loudly.
        raise RuntimeError("key %s not found among the candidate keys" % key_id)

    print("intial filtering done")
    print("len(poss_keys) =", len(poss_keys))

    # Binary search: keep whichever half contains the server's key.
    while len(poss_keys) > 1:
        print("len(poss_keys) =", len(poss_keys))
        if len(poss_keys) < 10:
            print(poss_keys)
        mid = len(poss_keys) // 2
        set_1, set_2 = poss_keys[:mid], poss_keys[mid:]
        poss_keys = set_1 if _oracle_accepts(set_1) else set_2
    return poss_keys[0]

host, port = "pythia.2021.ctfcompetition.com", 1337

# Candidate AES keys mapped to the password text each corresponds to;
# presumably precomputed offline with the server's Scrypt parameters (verify).
# Read before connecting, so a missing file fails without using a
# connection, and via `with`, so the handle is closed. literal_eval only
# accepts Python literals.
with open("possible_keys") as candidates_file:
    possible_keys = ast.literal_eval(candidates_file.read())

io = remote(host, port)

# Recover the key in each of the three server slots and concatenate the
# matching passwords in slot order.
password = possible_keys[get_password(io, 0, list(possible_keys.keys()))]
for key_id in (1, 2):
    password += possible_keys[get_password(io, key_id, list(possible_keys.keys()))]

print('password = ', password)
io.interactive()