corCTF 2021 | babypad

#corctf2021

from Crypto.Cipher import AES
from Crypto.Util import Counter
from Crypto.Util.Padding import pad, unpad
from Crypto.Util.number import bytes_to_long
import os

# Read the flag once at start-up and close the file immediately (the original
# left the handle open). The key is 16 random bytes (AES-128), generated once
# per process.
with open("/challenge/flag.txt", encoding="utf-8") as flag_file:
    flag = flag_file.read().encode()
key = os.urandom(16)

def encrypt(pt):
    """Encrypt *pt* with AES-CTR under the module-level ``key``.

    A fresh random 16-byte value seeds the 128-bit counter and is prepended
    to the output, so ``decrypt`` can rebuild the same counter. The plaintext
    is padded with ``Crypto.Util.Padding.pad`` (PKCS#7 by default) to a
    16-byte boundary first. CTR does not need padding; it is what creates the
    padding oracle in ``decrypt``.
    """
    seed = os.urandom(16)
    counter = Counter.new(128, initial_value=bytes_to_long(seed))
    body = AES.new(key, AES.MODE_CTR, counter=counter).encrypt(pad(pt, 16))
    return seed + body

def decrypt(ct):
    """Report whether *ct* decrypts to data that ``unpad`` accepts.

    *ct* is laid out as ``counter_seed (16 bytes) || ciphertext``, matching
    ``encrypt``'s output. The plaintext is never returned: the function
    answers 1 when unpadding succeeds and 0 when anything at all raises
    (bad padding, bad length, ...). This 1/0 answer is the padding oracle.
    """
    try:
        seed, body = ct[:16], ct[16:]
        counter = Counter.new(128, initial_value=bytes_to_long(seed))
        plain = AES.new(key, AES.MODE_CTR, counter=counter).decrypt(body)
        unpad(plain, 16)
    except Exception:
        return 0
    return 1

def main():
    """Serve the challenge: print the encrypted flag, then answer oracle queries.

    Each line read after the "> " prompt is parsed as hex and passed to
    ``decrypt``, and its 1/0 result is printed. Lines that are not valid hex
    are ignored without a reply. The loop ends when input is exhausted.
    """
    print(encrypt(flag).hex())
    while True:
        try:
            line = input("> ")
        except EOFError:
            # Once stdin is closed, input() raises EOFError on every call.
            # Swallowing it (as a blanket `except Exception` does) would turn
            # a disconnected client into an endless busy loop.
            break
        try:
            query = bytes.fromhex(line)
        except ValueError:
            continue  # malformed hex: no reply
        print(decrypt(query))

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

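CTRモードに対するPadding Oracle Attack。そもそもCTRモードはストリーム暗号として動作するのでパディングは不要だけど……。CTRモードでは暗号文のバイトにある値をXORすると、復号結果の同じ位置のバイトにも同じ値がXORされる。そのため、パディングの正否だけを返すオラクルがあれば、末尾から1バイトずつ平文を特定できる。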

from Crypto.Cipher import AES
from Crypto.Util import Counter
from Crypto.Util.Padding import pad, unpad
from Crypto.Util.number import bytes_to_long
import os
import string

from pwn import remote, xor

def unpad_ct(iv, ct):
    """Ask the remote padding oracle about ``iv || ct``.

    Waits for the server's "> " prompt, sends the concatenation hex-encoded,
    and reads one reply line.

    Returns:
        True if the stripped reply contains b"1" (padding accepted),
        otherwise False.
    """
    io.recvuntil(b"> ")
    query = (iv + ct).hex()
    io.sendline(query.encode())
    reply = io.recvline().strip()
    return b"1" in reply

def decrypt_block(iv, block, oracle=None, alphabet=None):
    """Recover the plaintext of one 16-byte AES-CTR ciphertext block.

    CTR decryption is ``plaintext = ciphertext XOR keystream``. XOR-ing a
    guess ``g`` and a padding value ``p`` into a ciphertext byte therefore
    makes the server decrypt that position to ``P ^ g ^ p``. Working from the
    last byte to the first, the bytes already recovered are forced to the
    current padding value. The byte at ``ind`` is then guessed until the
    oracle accepts the padding, which happens when ``g == P``.

    Args:
        iv: 16-byte counter start the server will use for *this* block. For
            the k-th ciphertext block of a message this is the message IV + k.
        block: the 16-byte ciphertext block to attack.
        oracle: callable ``(iv, ct) -> bool`` reporting valid padding;
            defaults to ``unpad_ct`` (the remote server).
        alphabet: byte values to try first (e.g. likely flag characters);
            defaults to the module-level ``charset``. All remaining byte
            values are tried afterwards, so bytes outside it are still found.

    Returns:
        bytearray of length 16 holding the recovered plaintext.

    Raises:
        ValueError: if ``block`` is not 16 bytes long, or no byte value makes
            the oracle accept the padding at some position.
    """
    if oracle is None:
        oracle = unpad_ct
    if alphabet is None:
        alphabet = charset
    if len(block) != 16:
        raise ValueError(f"expected a 16-byte block, got {len(block)} bytes")

    # Preferred guesses first (deduplicated, order kept), then every other byte.
    preferred = list(dict.fromkeys(alphabet))
    tried = set(preferred)
    candidates = preferred + [b for b in range(256) if b not in tried]

    known_plaintext = bytearray(16)
    for ind in range(15, -1, -1):
        pad_byte = 16 - ind
        # Leave bytes before `ind` untouched; force ind..15 to decrypt to
        # pad_byte. The forged block is sent as-is even for pad_byte == 16:
        # a 16-byte input with 16 bytes of padding is valid, and prepending
        # data would shift it onto the wrong keystream block.
        target = bytes(ind) + bytes([pad_byte]) * pad_byte
        for guess in candidates:
            known_plaintext[ind] = guess
            forged = bytearray(
                c ^ k ^ t for c, k, t in zip(block, known_plaintext, target)
            )
            if not oracle(iv, forged):
                continue
            if pad_byte == 1:
                # A hit here may be a longer valid padding (e.g. ...\x02\x02)
                # rather than a lone \x01. Flipping one bit of the
                # second-to-last byte keeps a genuine \x01 valid but breaks
                # any longer padding; in that case keep searching.
                probe = bytearray(forged)
                probe[-2] ^= 1
                if not oracle(iv, probe):
                    continue
            break
        else:
            raise ValueError(f"no byte value accepted by the oracle at index {ind}")
        print("known_plaintext =", known_plaintext)
    return known_plaintext

# Guess order for plaintext bytes: "_" first, then printable ASCII, then the
# PKCS#7 padding values 0x01..0x10 that can appear in the final block.
charset = ("_" + string.printable).encode() + bytes(range(1, 17))

host, port = "babypad.be.ax", 1337
io = remote(host, port)

# The server's first line is hex(iv || ciphertext) of the padded flag.
ciphertext = bytes.fromhex(io.recvline().strip().decode())
iv, ct = ciphertext[:16], ciphertext[16:]

# In CTR mode the k-th block is encrypted with counter value iv + k. Each
# block must therefore be sent with its own counter start; reusing `iv` for
# every block pairs blocks k > 0 with the wrong keystream.
counter_start = bytes_to_long(iv)
plaintext = b""
for block_num in range(len(ct) // 16):
    block_iv = ((counter_start + block_num) % (1 << 128)).to_bytes(16, "big")
    block = ct[block_num * 16:(block_num + 1) * 16]
    plaintext += bytes(decrypt_block(block_iv, block))
    print(plaintext)

# The flag was padded before encryption; strip that padding from the result.
print(unpad(plaintext, 16))