ALLES! CTF 2021 | Secure Flag Service

#alles_ctf_2021

#!/usr/bin/env python3
import base64
from Crypto.Cipher import AES
from Crypto.Hash import SHA3_256, HMAC, BLAKE2s
from Crypto.Random import urandom, random
# The commented-out import shows the real FLAG / PASSWORD come from a private
# `secret` module; the literals below stand in for them when running locally.
# from secret import FLAG, PASSWORD
FLAG = b"FLAG{UOUOUO_YOU_GOT_WHOWHWO}"
PASSWORD = b"year the password so cool not or notnot"

# Both keys are derived only from PASSWORD (BLAKE2s over PASSWORD plus a
# distinct label per key), so anyone who learns PASSWORD can re-derive them.
encryption_key = BLAKE2s.new(data=PASSWORD + b'encryption_key').digest()
mac_key = BLAKE2s.new(data=PASSWORD + b'mac_key').digest()

def int_to_bytes(i):
    """Return the shortest big-endian byte string for the non-negative int *i*.

    Zero maps to the empty byte string b''.
    """
    n_bytes = -(-i.bit_length() // 8)  # ceil(bit_length / 8)
    return i.to_bytes(n_bytes, 'big')

def encode(s):
    """Expand every bit of *s* into a 2-bit symbol and return the result as bytes.

    *s* is read as one big-endian integer, so its leading zero bits are
    dropped.  A 0 bit becomes ``00``; a 1 bit becomes ``01`` or ``10``,
    picked by ``random.randrange(0, 2)``.
    """
    plain_bits = format(int.from_bytes(s, byteorder='big'), 'b')
    symbols = []
    for b in plain_bits:
        if b != '1':
            symbols.append('00')
            continue
        symbols.append('01' if random.randrange(0, 2) else '10')
    return int_to_bytes(int(''.join(symbols), base=2))

def decode(s):
    """Collapse each 2-bit symbol of *s* back into a single bit.

    *s* is read as a big-endian integer and its bit string is left-padded
    to an even length.  ``00`` maps to 0 and any other pair maps to 1 --
    including ``11``, which ``encode`` never emits but which is accepted here.
    """
    raw = bin(int.from_bytes(s, byteorder='big'))[2:]
    if len(raw) % 2 == 1:
        raw = '0' + raw
    pairs = (raw[k:k + 2] for k in range(0, len(raw), 2))
    collapsed = ''.join('0' if p == '00' else '1' for p in pairs)
    return int_to_bytes(int(collapsed, base=2))

def encrypt(m):
    """Encrypt and authenticate *m*.

    Returns ``nonce || AES-CTR(encode(m) || HMAC(m))`` with an 8-byte random
    nonce.  The HMAC is computed over the plaintext *m*, not the ciphertext.
    """
    iv = urandom(8)
    cipher = AES.new(key=encryption_key, mode=AES.MODE_CTR, nonce=iv)
    mac = HMAC.new(key=mac_key, msg=m).digest()
    payload = encode(m) + mac
    return iv + cipher.encrypt(payload)

def decrypt(c):
    """Decrypt and verify a blob produced by ``encrypt``.

    *c* is ``nonce (8 bytes) || ciphertext``; the last 16 decrypted bytes are
    treated as the HMAC tag.  On any ValueError (e.g. a failed tag check) a
    fixed error line is printed and the process exits with status 1, so a
    returned message has always passed verification.
    """
    try:
        iv, body = c[:8], c[8:]
        cipher = AES.new(key=encryption_key, mode=AES.MODE_CTR, nonce=iv)
        plain = cipher.decrypt(body)
        message = decode(plain[:-16])
        tag = plain[-16:]
        HMAC.new(key=mac_key, msg=message).verify(mac_tag=tag)
    except ValueError:
        print("Get off my lawn or I call the police!!!")
        exit(1)
    else:
        return message

def main():
    """Read a base64 ciphertext; if it decrypts to PASSWORD, print base64(Enc(FLAG)).

    A ciphertext that verifies but is not PASSWORD gets "Wrong Password!!!".
    Any exception -- including the SystemExit raised inside ``decrypt`` --
    ends the process with exit(1).
    """
    try:
        blob = base64.b64decode(input('Encrypted password>'))
        candidate = decrypt(blob)
        if candidate != PASSWORD:
            print("Wrong Password!!!")
        else:
            print(str(base64.b64encode(encrypt(FLAG)), 'utf-8'))
    except BaseException:
        # Explicit spelling of a bare ``except:``: deliberately swallows
        # everything, BaseException subclasses included.
        exit(1)

# Start the service only when run as a script, not when imported.
if __name__ == '__main__':
    main()

これに加えて、`PASSWORD` を `encrypt` したものが事前に与えられている。

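`encode` と `decode` が特殊で、平文の `1` は `01` または `10`、`0` は `00` にエンコードされる。これに `11` を XOR すると、`1`（`01`/`10`）は `10`/`01` になって引き続き `1` にデコードされるが、`0`（`00`）は `11` になって `00` ではなくなるため `1` にデコードされてしまう。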

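CTRモードなので Bit Flipping Attack ができる。平文が変わると HMAC による認証が失敗してエラーメッセージ（`Get off my lawn ...`）が返るので、これがオラクルになる。暗号文の各2ビットに `11` を XOR して送り、エラーになれば元の平文ビットは `0`、ならなければ `1` と分かる。勝ち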

from ptrlib import Process
from base64 import b64encode, b64decode
from tqdm import tqdm
from logging import getLogger, WARN

# Hide ptrlib log messages below WARNING; the attack below spawns many processes.
getLogger("ptrlib").setLevel(WARN)

# argv used by Process(...) to launch the challenge service.
# NOTE(review): "".split(" ") evaluates to [''] -- this looks like a redacted
# placeholder; fill in the real command (e.g. ["python3", "server.py"]) before running.
CONN = "".split(" ")


# Ciphertext of PASSWORD given with the challenge: nonce (8 bytes) || AES-CTR body.
encrypted_password = b64decode(b'kgsekWGeAwPhz6tbMyLd34Bg5pwhy2TkQJF7NRYC987Ibuiu/dmNHqyYXHV0kXlksThSRi83Qu2owAiUdT9pfqlY')
nonce = encrypted_password[:8]
ciphertext = encrypted_password[8:]


def _pair_was_zero(flipped_ciphertext):
    """Submit a tampered ciphertext; return True if the server rejected it.

    XOR-ing a 2-bit symbol with 0b11 maps ``01``/``10`` to ``10``/``01``
    (still decodes to 1) but maps ``00`` to ``11`` (now decodes to 1 instead
    of 0).  Only the latter changes the plaintext, making the HMAC check fail
    with the "Get off ..." reply.
    """
    sock = Process(CONN)
    try:
        sock.sendlineafter("password>", b64encode(nonce + bytes(flipped_ciphertext)))
        return "Get off" in sock.recvline().decode()
    finally:
        # Close every spawned process so repeated queries do not leak them.
        sock.close()


# Recover the encoded plaintext one 2-bit symbol at a time, most significant
# symbol of each byte first.  Touching any tag byte always fails verification,
# so the 16 tag bytes contribute trailing 0 bits that are stripped later.
password = 0
for i in tqdm(range(len(ciphertext))):
    for j in range(4):
        flipped = bytearray(ciphertext)
        flipped[i] ^= 0b11 << (3 - j) * 2
        bit = 0 if _pair_was_zero(flipped) else 1
        password = password * 2 + bit
print("[+] password = {}".format(password))

from Crypto.Cipher import AES
from Crypto.Util.number import long_to_bytes
from Crypto.Hash import SHA3_256, HMAC, BLAKE2s


# The untouched ciphertext decrypts to PASSWORD, so the server replies with Enc(FLAG).
sock = Process(CONN)
try:
    sock.sendlineafter("password>", b64encode(encrypted_password))
    encrypted_flag = b64decode(sock.recvline())
finally:
    sock.close()  # release the spawned process

# The 16 tag bytes yielded 16 * 4 = 64 oracle bits, all 0 (any change to the
# tag fails verification); 64 bits = 8 bytes to drop from the end.
password = long_to_bytes(password)[:-8]
encryption_key = BLAKE2s.new(data=password + b'encryption_key').digest()

def decode(s):
    """Collapse 2-bit symbols to bits (``00`` -> 0, any other pair -> 1).

    Mirrors the server's decoder, except that the resulting integer is
    converted to bytes with ``long_to_bytes``.
    """
    raw = format(int.from_bytes(s, byteorder='big'), 'b')
    if len(raw) & 1:
        raw = '0' + raw
    out_bits = []
    for k in range(0, len(raw), 2):
        out_bits.append('0' if raw[k:k + 2] == '00' else '1')
    return long_to_bytes(int(''.join(out_bits), base=2))

def decrypt(c):
    """Decrypt ``nonce || AES-CTR body`` with the re-derived key, skipping the MAC.

    The trailing 16 bytes (the tag) are discarded unverified; only the
    decoded message part is returned.
    """
    cipher = AES.new(key=encryption_key, mode=AES.MODE_CTR, nonce=c[:8])
    plain = cipher.decrypt(c[8:])
    return decode(plain[:-16])

# Decrypt the server's reply with the re-derived encryption key and show the flag.
print(decrypt(encrypted_flag))