zer0pts CTF 2021 | 3-AES

#zer0ptsCTF2021

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from binascii import hexlify, unhexlify
from hashlib import md5
import os
import signal
from flag import flag

# Three independent AES-128 keys (an MD5 digest is 16 bytes), each derived from
# only 3 random bytes: every key has just 2**24 candidates, which is what makes
# per-key brute force feasible once the layers can be separated.
keys = [md5(os.urandom(3)).digest() for _ in range(3)]


def get_ciphers(iv1, iv2):
    """Build fresh cipher objects for the three stacked AES layers.

    The returned list is ordered outermost-first for encryption:
      [0] AES-ECB under keys[0]
      [1] AES-CBC under keys[1] with IV ``iv1``
      [2] AES-CFB under keys[2] with IV ``iv2`` and 8*16 = 128-bit segments,
          i.e. one full AES block per segment.
    """
    ecb_layer = AES.new(keys[0], mode=AES.MODE_ECB)
    cbc_layer = AES.new(keys[1], mode=AES.MODE_CBC, iv=iv1)
    # segment_size is given in bits: 8 * 16 bits == one 16-byte block.
    cfb_layer = AES.new(keys[2], mode=AES.MODE_CFB, iv=iv2, segment_size=8 * 16)
    return [ecb_layer, cbc_layer, cfb_layer]

def encrypt(m: bytes, iv1: bytes, iv2: bytes) -> bytes:
    """Encrypt block-aligned ``m`` with ECB, then CBC (IV ``iv1``), then CFB (IV ``iv2``).

    Raises AssertionError when ``len(m)`` is not a multiple of 16.
    """
    assert len(m) % 16 == 0
    ecb, cbc, cfb = get_ciphers(iv1, iv2)
    # Innermost call runs first: ECB -> CBC -> CFB.
    return cfb.encrypt(cbc.encrypt(ecb.encrypt(m)))

def decrypt(c: bytes, iv1: bytes, iv2: bytes) -> bytes:
    """Invert :func:`encrypt`: peel off CFB (IV ``iv2``), then CBC (IV ``iv1``), then ECB.

    Raises AssertionError when ``len(c)`` is not a multiple of 16.
    """
    assert len(c) % 16 == 0
    ecb, cbc, cfb = get_ciphers(iv1, iv2)
    # Innermost call runs first: CFB -> CBC -> ECB (reverse of encryption).
    return ecb.decrypt(cbc.decrypt(cfb.decrypt(c)))

# Interactive oracle. signal.alarm(3600) schedules a SIGALRM after one hour;
# with the default handler that terminates the session.
signal.alarm(3600)
while True:
    print("==== MENU ====")
    print("1. Encrypt your plaintext")
    print("2. Decrypt your ciphertext")
    print("3. Get encrypted flag")
    # A non-integer answer raises ValueError and ends the session.
    choice = int(input("> "))

    if choice == 1:
        # Encryption oracle: the server picks fresh random IVs for every query
        # and replies with "iv1:iv2:ciphertext" in hex.
        plaintext = unhexlify(input("your plaintext(hex): "))
        iv1, iv2 = get_random_bytes(16), get_random_bytes(16)
        ciphertext = encrypt(plaintext, iv1, iv2)
        ciphertext = b":".join([hexlify(x) for x in [iv1, iv2, ciphertext]]).decode()
        print("here's the ciphertext: {}".format(ciphertext))

    elif choice == 2:
        # Decryption oracle: the client supplies both IVs *and* the ciphertext
        # ("iv1:iv2:ciphertext" in hex). This chosen-ciphertext freedom is what
        # the attack in the write-up exploits.
        ciphertext = input("your ciphertext: ")
        iv1, iv2, ciphertext = [unhexlify(x) for x in ciphertext.strip().split(":")]
        plaintext = decrypt(ciphertext, iv1, iv2)
        print("here's the plaintext(hex): {}".format(hexlify(plaintext).decode()))

    elif choice == 3:
        # The encrypted flag is handed out once (random IVs); the session then ends.
        plaintext = flag
        iv1, iv2 = get_random_bytes(16), get_random_bytes(16)
        ciphertext = encrypt(plaintext, iv1, iv2)
        ciphertext = b":".join([hexlify(x) for x in [iv1, iv2, ciphertext]]).decode()
        print("here's the encrypted flag: {}".format(ciphertext))
        exit()

    else:
        exit()

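AES の ECB・CBC・CFB の 3 つの暗号利用モードを、それぞれ独立な鍵で順に重ねている。以下、鍵 k_1, k_2, k_3 による AES ブロック暗号をそれぞれ E_1, E_2, E_3（その復号を D_1, D_2, D_3）と書き、CBC の IV（コード中の iv1）を IV_2、CFB の IV（コード中の iv2）を IV_3 とする。各鍵は md5(os.urandom(3)) なので、候補はそれぞれ 2^{24} 通りしかない。平文ブロック m_1, m_2 の暗号化は次のように行われる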

 c_1 = E_2(E_1(m_1) \oplus IV_2) \oplus E_3(IV_3)

 c_2 = E_2(E_1(m_2) \oplus E_2(E_1(m_1) \oplus IV_2)) \oplus E_3(c_1)

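ここで復号オラクル（メニュー 2）に c_1 = c_2 = IV_3 となる暗号文を投げ、返ってきた平文ブロック m_1, m_2 を使う。c_2 の式を変形すると（最後の等号は c_1 の式による）

 D_2(c_2 \oplus E_3(c_1)) \oplus E_1(m_2) = E_2(E_1(m_1) \oplus IV_2) = c_1 \oplus E_3(IV_3)

が成り立つ。一方 c_1 = c_2 = IV_3 より c_2 \oplus E_3(c_1) = c_1 \oplus E_3(IV_3) なので、同じ左辺は

 D_2(c_1 \oplus E_3(IV_3)) \oplus E_1(m_2) = D_2(E_2(E_1(m_1) \oplus IV_2)) \oplus E_1(m_2) = E_1(m_1) \oplus IV_2 \oplus E_1(m_2)

とも書ける。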

したがって

 c_1 \oplus E_3(IV_3) = E_1(m_1) \oplus IV_2 \oplus E_1(m_2)  \cdots \triangle

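となる。この式は E_1 と E_3 しか含まない（E_2 が消えている）。そこで c_1 \oplus E_3(IV_3) \oplus IV_2 を k_3 の全候補について表にしておき、E_1(m_1) \oplus E_1(m_2) を k_1 の全候補について表と照合する meet-in-the-middle attack で、k_1 と k_3 が求まる（各 2^{24} 通り）。k_1 と k_3 が分かれば、E_2(E_1(m_1) \oplus IV_2) = c_1 \oplus E_3(IV_3) から k_2 も全探索できる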

別解

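適当に c_1, IV_2, IV_3 を決めて復号オラクルに投げ、平文 m_1 を得る。このとき

 c_1 = E_2(E_1(m_1) \oplus IV_2) \oplus E_3(IV_3)

が成り立つ。次に c_1 と IV_3 はそのままで、IV_2 だけを別の値 IV_2' に変えて再度復号させる。その平文 m_1' について、次の式が得られる

 c_1 = E_2(E_1(m_1') \oplus IV_2') \oplus E_3(IV_3)

2 式で c_1 \oplus E_3(IV_3) は共通であり、E_2 は全単射なので

 E_1(m_1) \oplus IV_2 = E_1(m_1') \oplus IV_2'

が得られる。この式は E_1 しか含まないので、k_1 を 2^{24} 通り全探索できる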

from ptrlib import Socket, Process
from subprocess import run, PIPE
from binascii import hexlify, unhexlify
from Crypto.Cipher import AES


sock = Socket("localhost", 9999)
# Use the decryption oracle (menu 2), where we choose the IVs and ciphertext.
sock.sendlineafter("> ", "2")

# "A" * 32 hex digits == 16 bytes of 0xaa. With every field set to the same
# value, the query satisfies c_1 == c_2 == IV_3 (the CFB IV) from the write-up.
# iv2 here is the CBC IV (IV_2) and iv3 the CFB IV (IV_3).
iv2 = "A" * 32
iv3 = "A" * 32
a = "A" * 32
b = "A" * 32
c = "A" * 32

sock.sendlineafter("ciphertext: ", "{}:{}:{}".format(iv2, iv3, a+b+c))
plaintext = unhexlify(sock.recvlineafter(": "))

# m_1 and m_2 (first two plaintext blocks) as hex, for the C++ helpers.
A = hexlify(plaintext[:16]).decode()
B = hexlify(plaintext[16:32]).decode()

# Grab the encrypted flag ("iv1:iv2:ciphertext" in hex); the server exits afterwards.
sock.sendlineafter("> ", "3")
flag = sock.recvlineafter("flag: ")

sock.close()

# ---

# Recover k1/k3 by meet-in-the-middle, then k2 by brute force, using the
# compiled helpers. check=True turns a failed search (non-zero exit status)
# into a CalledProcessError instead of a confusing unpacking error below.
r = run(["./k1k3", A, B, iv2, iv3], stdout=PIPE, check=True)
k1, k3 = r.stdout.decode().strip().split("\n")
print("k1={}".format(k1))
print("k3={}".format(k3))

r = run(["./k2", A, B, iv2, iv3, k1, k3], stdout=PIPE, check=True)
k2 = r.stdout.decode().strip()
# A helper that exits 0 without printing a key must not yield an empty k2.
if not k2:
    raise RuntimeError("./k2 did not print a key")
print("k2={}".format(k2))

# ---

# Recovered AES-128 keys (hex from the helpers -> raw bytes), in layer order
# ECB, CBC, CFB.
keys = [unhexlify(hex_key) for hex_key in (k1, k2, k3)]

def get_ciphers(iv1, iv2):
    """Rebuild the server's layer stack from the recovered ``keys``.

    Returns [ECB(keys[0]), CBC(keys[1], iv1), CFB(keys[2], iv2, 128-bit segments)].
    """
    ecb_layer = AES.new(keys[0], mode=AES.MODE_ECB)
    cbc_layer = AES.new(keys[1], mode=AES.MODE_CBC, iv=iv1)
    cfb_layer = AES.new(keys[2], mode=AES.MODE_CFB, iv=iv2, segment_size=8 * 16)
    return [ecb_layer, cbc_layer, cfb_layer]

def decrypt(c: bytes, iv1: bytes, iv2: bytes) -> bytes:
    """Undo the three layers (CFB, then CBC, then ECB) with the recovered keys.

    Raises AssertionError when ``len(c)`` is not a multiple of 16.
    """
    assert len(c) % 16 == 0
    ecb, cbc, cfb = get_ciphers(iv1, iv2)
    return ecb.decrypt(cbc.decrypt(cfb.decrypt(c)))

# The flag line is "iv1:iv2:ciphertext" in hex; strip all three layers.
iv1, iv2, ciphertext = map(unhexlify, flag.decode().strip().split(":"))
plaintext = decrypt(ciphertext, iv1, iv2)
print(plaintext)
#include <openssl/aes.h>
#include <openssl/evp.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <unordered_map>

// AES-128-ECB encrypt `len` bytes of `data` under the 16-byte `key` into
// `dest`, which must hold at least `len` bytes and is zeroed first.
// EVP_EncryptFinal_ex is never called, so no padding block is produced;
// callers in this program always pass exactly one 16-byte block.
// Any OpenSSL failure terminates the process instead of silently leaving
// zeros in `dest` (which would make the key search quietly wrong).
void encrypt(const char* key, const char *data, int len, unsigned char* dest) {
    memset(dest, 0, len);
    int written = 0;

    // EVP_CIPHER_CTX_new() already yields an initialised context, so the
    // deprecated EVP_CIPHER_CTX_init() call is not needed.
    EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
    if (ctx == NULL
        || EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), NULL, (const unsigned char*)key, NULL) != 1
        || EVP_EncryptUpdate(ctx, dest, &written, (const unsigned char*)data, len) != 1) {
        EVP_CIPHER_CTX_free(ctx);
        exit(EXIT_FAILURE);
    }
    EVP_CIPHER_CTX_free(ctx);
}

// Convert one hexadecimal digit ('0'-'9', 'a'-'f', 'A'-'F') to its value 0-15.
// Any other character (including NUL) terminates the process with EXIT_FAILURE.
char fromHexChar(char c) {
    if (c >= 'A' && c <= 'F') {
        c = c - 'A' + 'a';  // fold upper case onto lower case
    }
    if (c >= 'a' && c <= 'f') {
        return 10 + (c - 'a');
    }
    if (c >= '0' && c <= '9') {
        return c - '0';
    }
    exit(EXIT_FAILURE);
}

// Map a nibble value 0-15 to its lowercase hex digit; larger values are not
// range-checked.
char toHexChar(unsigned char c) {
    return "0123456789abcdef"[c];
}

// Decode a NUL-terminated hex string into raw bytes, two digits per byte.
// Invalid digits, including the NUL that an odd-length input pairs with its
// last digit, make fromHexChar() terminate the process.
std::string fromhex(const char* source) {
    std::string bytes;
    for (const char* p = source; *p; p += 2) {
        const char high = fromHexChar(p[0]);
        const char low = fromHexChar(p[1]);
        bytes.push_back((char)((high << 4) | low));
    }
    return bytes;
}

// Encode `len` raw bytes starting at `src` as lowercase hex, high nibble first.
std::string tohex(const char *src, int len) {
    std::string hex;
    for (int i = 0; i < len; ++i) {
        const unsigned char byte = static_cast<unsigned char>(src[i]);
        hex += toHexChar(byte >> 4);
        hex += toHexChar(byte & 0x0f);
    }
    return hex;
}

// XOR two byte strings position by position.
// The result has the length of the shorter input; the previous version
// indexed `b` past its end whenever `b` was shorter than `a`.
// Arguments are taken by const reference because this runs ~2^24 times in
// the key search and the by-value copies were pure overhead.
std::string x(const std::string& a, const std::string& b) {
    const size_t n = a.size() < b.size() ? a.size() : b.size();
    std::string c(n, '\0');
    for (size_t i = 0; i < n; i++) {
        c[i] = (char)(a[i] ^ b[i]);
    }
    return c;
}

int main(int argc, char **argv) {
    std::string a = fromhex(argv[1]);
    std::string b = fromhex(argv[2]);
    std::string iv2 = fromhex(argv[3]);
    std::string iv3 = fromhex(argv[4]);

    std::unordered_map<std::string, std::string> map;
    auto left = x(iv2, iv3);

    unsigned char* m1 = (unsigned char*)malloc(sizeof(unsigned char) * a.size()) ;
    unsigned char* m2 = (unsigned char*)malloc(sizeof(unsigned char) * a.size()) ;

    char key[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
    for (int i = 0; i < 256; i++) {
        key[13] = i;
        for (int j = 0; j < 256; j++) {
            key[14] = j;
            for (int k = 0; k < 256; k++) {
                key[15] = k;

                encrypt(key, iv3.c_str(), iv3.size(), m1);
                auto t = x(left, std::string((char*) m1));
                map[t] = tohex(key, 16);
            }
        }
    }

    std::string k1, k3;

    for (int i = 0; i < 256; i++) {
        key[13] = i;
        for (int j = 0; j < 256; j++) {
            key[14] = j;
            for (int k = 0; k < 256; k++) {
                key[15] = k;

                encrypt(key, a.c_str(), a.size(), m1);
                encrypt(key, b.c_str(), b.size(), m2);
                auto t = x(std::string((char*) m2), std::string((char*) m1));

                if (map.find(t) != map.end()) {
                    printf("%s\n", tohex(key, 16).c_str());
                    printf("%s\n", map.at(t).c_str());

                    k1 = std::string((char*)key);
                    k3 = fromhex(map.at(t).c_str());

                    return 0;
                }
            }
        }
    }
    return 1;
}
#include <openssl/aes.h>
#include <openssl/evp.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <unordered_map>

// AES-128-ECB encrypt `len` bytes of `data` under the 16-byte `key` into
// `dest`, which must hold at least `len` bytes and is zeroed first.
// EVP_EncryptFinal_ex is never called, so no padding block is produced;
// callers in this program always pass exactly one 16-byte block.
// Any OpenSSL failure terminates the process instead of silently leaving
// zeros in `dest` (which would make the key search quietly wrong).
void encrypt(const char* key, const char *data, int len, unsigned char* dest) {
    memset(dest, 0, len);
    int written = 0;

    // EVP_CIPHER_CTX_new() already yields an initialised context, so the
    // deprecated EVP_CIPHER_CTX_init() call is not needed.
    EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
    if (ctx == NULL
        || EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), NULL, (const unsigned char*)key, NULL) != 1
        || EVP_EncryptUpdate(ctx, dest, &written, (const unsigned char*)data, len) != 1) {
        EVP_CIPHER_CTX_free(ctx);
        exit(EXIT_FAILURE);
    }
    EVP_CIPHER_CTX_free(ctx);
}

// Convert one hexadecimal digit ('0'-'9', 'a'-'f', 'A'-'F') to its value 0-15.
// Any other character (including NUL) terminates the process with EXIT_FAILURE.
char fromHexChar(char c) {
    if (c >= 'A' && c <= 'F') {
        c = c - 'A' + 'a';  // fold upper case onto lower case
    }
    if (c >= 'a' && c <= 'f') {
        return 10 + (c - 'a');
    }
    if (c >= '0' && c <= '9') {
        return c - '0';
    }
    exit(EXIT_FAILURE);
}

// Map a nibble value 0-15 to its lowercase hex digit; larger values are not
// range-checked.
char toHexChar(unsigned char c) {
    return "0123456789abcdef"[c];
}

// Decode a NUL-terminated hex string into raw bytes, two digits per byte.
// Invalid digits, including the NUL that an odd-length input pairs with its
// last digit, make fromHexChar() terminate the process.
std::string fromhex(const char* source) {
    std::string bytes;
    for (const char* p = source; *p; p += 2) {
        const char high = fromHexChar(p[0]);
        const char low = fromHexChar(p[1]);
        bytes.push_back((char)((high << 4) | low));
    }
    return bytes;
}

// Encode `len` raw bytes starting at `src` as lowercase hex, high nibble first.
std::string tohex(const char *src, int len) {
    std::string hex;
    for (int i = 0; i < len; ++i) {
        const unsigned char byte = static_cast<unsigned char>(src[i]);
        hex += toHexChar(byte >> 4);
        hex += toHexChar(byte & 0x0f);
    }
    return hex;
}

// XOR two byte strings position by position.
// The result has the length of the shorter input; the previous version
// indexed `b` past its end whenever `b` was shorter than `a`.
// Arguments are taken by const reference because this runs ~2^24 times in
// the key search and the by-value copies were pure overhead.
std::string x(const std::string& a, const std::string& b) {
    const size_t n = a.size() < b.size() ? a.size() : b.size();
    std::string c(n, '\0');
    for (size_t i = 0; i < n; i++) {
        c[i] = (char)(a[i] ^ b[i]);
    }
    return c;
}

int main(int argc, char **argv) {
    std::string a = fromhex(argv[1]);
    std::string b = fromhex(argv[2]);
    std::string iv2 = fromhex(argv[3]);
    std::string iv3 = fromhex(argv[4]);
    std::string k1 = fromhex(argv[5]);
    std::string k3 = fromhex(argv[6]);

    unsigned char* m1 = (unsigned char*)malloc(sizeof(unsigned char) * a.size()) ;
    unsigned char* m2 = (unsigned char*)malloc(sizeof(unsigned char) * a.size()) ;

    encrypt(k1.c_str(), a.c_str(), a.size(), m1);
    std::string target = x(std::string((char*)m1), iv2);

    encrypt(k3.c_str(), iv3.c_str(), iv3.size(), m1);
    std::string cmp_to = x(std::string((char*)m1), iv3);

    char key[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
    for (int i = 0; i < 256; i++) {
        key[13] = i;
        for (int j = 0; j < 256; j++) {
            key[14] = j;
            for (int k = 0; k < 256; k++) {
                key[15] = k;

                encrypt(key, target.c_str(), target.size(), m1);
                std::string result((char*) m1);
                if (result == cmp_to) {
                    printf("%s\n", tohex(key, 16).c_str());
                    return 0;
                }
            }
        }
    }
}