SECCON CTF 2022 Quals | Witche's symmetric exam

#SECCON_CTF_2022_Quals

#kurenaif

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from flag import flag, secret_spell

key = get_random_bytes(16)
nonce = get_random_bytes(16)


def encrypt():
    data = secret_spell
    gcm_cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
    gcm_ciphertext, gcm_tag = gcm_cipher.encrypt_and_digest(data)

    ofb_input = pad(gcm_tag + gcm_cipher.nonce + gcm_ciphertext, 16)

    ofb_iv = get_random_bytes(16)
    ofb_cipher = AES.new(key, AES.MODE_OFB, iv=ofb_iv)
    ciphertext = ofb_cipher.encrypt(ofb_input)
    return ofb_iv + ciphertext


def decrypt(data):
    ofb_iv = data[:16]
    ofb_ciphertext = data[16:]
    ofb_cipher = AES.new(key, AES.MODE_OFB, iv=ofb_iv)

    try:
        m = ofb_cipher.decrypt(ofb_ciphertext)
        temp = unpad(m, 16)
    except:
        return b"ofb error"

    try:
        gcm_tag = temp[:16]
        gcm_nonce = temp[16:32]
        gcm_ciphertext = temp[32:]
        gcm_cipher = AES.new(key, AES.MODE_GCM, nonce=gcm_nonce)

        plaintext = gcm_cipher.decrypt_and_verify(gcm_ciphertext, gcm_tag)
    except:
        return b"gcm error"

    if b"give me key" == plaintext:
        your_spell = input("ok, please say secret spell:").encode()
        if your_spell == secret_spell:
            return flag
        else:
            return b"Try Harder"

    return b"ok"


print(f"ciphertext: {encrypt().hex()}")
while True:
    c = input("ciphertext: ")
    print(decrypt(bytes.fromhex(c)))

暗号化は iv + OFB(pad(tag + nonce + GCM(nonce, m)))

復号はpaddingが正しいか / GCMの検証ができるかの2種類のオラクルが得られる

step1. OFBのkeystreamを求める

  • 本来、OFBは任意長の入力を受け付けることができるのでpaddingが必要ないが、この問題では入力をpad/unpadしているので padding oracle attack ができる

  • 次の原理で、任意の  E^{k}(iv) が求められる。OFBによる暗号化は  C_{k} = E^{k}(iv) \oplus P_{k} だから、これがわかればOFBへの入力を復元できる

  • OFBによる先頭ブロックの復号は  E(iv) \oplus C_{1} として行われる

  • 1ブロックだけの暗号文  C' を自由に選択して復号する。  C' の最後尾の1バイトを全探索しながら復号していき、復号に成功したとき復号された平文の最後尾の1バイトは  0x01 であることが期待できる。したがって、  E(iv)_{-1} = 0x01 \oplus C'_{-1} が成り立つ

  • この性質を使って後ろから  E(iv) を求められる

  • 同様に、kブロックの暗号文のkブロック目を自由に変化させながら復号オラクルを得ることで、任意の  E^{k}(iv) が求まる

  • あるいは同様に1ブロックだけで考えて、ivとして既知の  E(iv) を指定することでも  E^2(iv) を得ることができる。これを繰り返しても任意の  E^{k}(iv)

   def encrypt(target, oracle):
       B = [0 for _ in range(16)]
       E = [0 for _ in range(16)]
       for i in range(16):
           j = 15 - i
           for x in range(256):
               B_ = B[:j] + [x] + [k^(i+1) for k in E[j+1:]]
               if oracle(target + bytes(B_)):
                   E[j] = x ^ (i+1)
                   break
           else:
               raise ValueError("bad luck")
       return bytes(E)

step2. GCMのplaintextを求める

  • GCM による暗号化はCTRモードと同じで、ctr_0 = GHASH(nonce || 0 || pad)からカウントアップしていってXORされる。したがってECB(ctr_i)と、GHASHの計算に使われるECB(0) を求める必要がある

  • 都合の良いことに、いまはOFBのdecryption oracleにより、 iv を指定できる状態で、任意の  E(iv) が手に入るから、 H = ECB(0) を求めた後 ctr_0 を計算すれば、平文はかんたんに求められる

step3. GCMのciphertextを作って、tagをforgeryする

  • これもencryption oracleを持っているのでやるだけ
from ptrlib import Socket, xor
from Crypto.Util.Padding import pad, unpad
from Crypto.Cipher._mode_gcm import _GHASH, _ghash_clmul
from progress.spinner import Spinner
from struct import pack


def encrypt(target, oracle):
    B = [0 for _ in range(16)]
    E = [0 for _ in range(16)]

    for i in range(16):
        j = 15 - i
        for x in range(256):
            B_ = B[:j] + [x] + [k^(i+1) for k in E[j+1:]]
            if oracle(target + bytes(B_)):
                E[j] = x ^ (i+1)
                break
        else:
            raise ValueError("bad luck")
    return bytes(E)


spinner = Spinner()
sock = Socket("localhost", 9999)
c = sock.recvlineafter(r"ciphertext: ").strip().decode()
c = bytes.fromhex(c)


def oracle(c: bytes) -> bool:
    spinner.next()
    sock.sendlineafter("ciphertext: ", c.hex())
    line = sock.recvline()
    if b"ofb error" in line:
        return False
    return True

# step 1. decrypt OFB

iv, c = c[:16], c[16:]
last_iv = iv
ks = b""
while len(ks) != len(c):
    last_iv = encrypt(last_iv, oracle)
    ks += last_iv
spinner.finish()

m = unpad(xor(ks, c), 16)
tag, nonce, c = m[:16], m[16:32], m[32:]


# step 2. decrypt GCM
H = encrypt(b"\0" * 16, oracle)
ghash = lambda m: _GHASH(H, _ghash_clmul).update(m).digest()


def aes_ctr(target):
    rep = (16 - len(nonce) % 16) % 16 + 8
    j0 = ghash(nonce + b"\0" * rep + (8 * len(nonce)).to_bytes(8, "big"))
    iv = (int.from_bytes(j0, "big") + 1) & 0xFFFFFFFF
    ctr = int.from_bytes(j0[:12] + iv.to_bytes(4, "big"), "big")

    ks = b""
    while len(ks) < len(target):
        ks += encrypt(ctr.to_bytes(16, "big"), oracle)
        ctr += 1
    return xor(target, ks)


spell = aes_ctr(c)
spinner.finish()
print(f"{spell=}")

# step 3. encrypt GCM
c = aes_ctr(b"give me key")
S = aes_ctr(b"\0" * 16)

# theoremoon: なんかこのあたりで 間違ってる!!!!
c_pad = b"\0" * ((16 - len(c)) % 16)
tag = xor(ghash(c + c_pad + pack('>QQ', 0, len(c) * 8)), S)

payload = pad(tag + nonce + c, 16)
iv = b"\0" * 16
last_iv = iv
ks = b""
while len(ks)  < len(payload):
    last_iv = encrypt(last_iv, oracle)
    ks += last_iv
spinner.finish()
sock.sendlineafter("ciphertext: ", (iv + xor(payload, ks)).hex())

sock.interactive()