#!/usr/bin/python import binascii from random import choice class Cipher: BLOCK_SIZE = 16 ROUNDS = 3 def __init__(self, key): assert(len(key) == self.BLOCK_SIZE * self.ROUNDS) self.key = key def __block_encrypt(self, block): enc = int.from_bytes(block, "big") for i in range(self.ROUNDS): k = int.from_bytes(self.key[i * self.BLOCK_SIZE:(i + 1) * self.BLOCK_SIZE], "big") enc &= k enc ^= k return hex(enc)[2:].rjust(self.BLOCK_SIZE * 2, "0") def __pad(self, msg): if len(msg) % self.BLOCK_SIZE != 0: return msg + (bytes([0]) * (self.BLOCK_SIZE - (len(msg) % self.BLOCK_SIZE))) else: return msg def encrypt(self, msg): m = self.__pad(msg) e = "" for i in range(0, len(m), self.BLOCK_SIZE): e += self.__block_encrypt(m[i:i + self.BLOCK_SIZE]) return e.encode() key = binascii.unhexlify("".join([choice(list("abcdef0123456789")) for a in range(Cipher.BLOCK_SIZE * Cipher.ROUNDS * 2)])) with open("flag", "rb") as f: flag = f.read() cipher = Cipher(key) while True: a = input("Would you like to encrypt [1], or try encrypting [2]? ") if a == "1": p = input("What would you like to encrypt: ") try: print(cipher.encrypt(binascii.unhexlify(p)).decode()) except: print("Invalid input. ") elif a == "2": for i in range(10): p = "".join([choice(list("abcdef0123456789")) for a in range(64)]) print("Encrypt this:", p) e = cipher.encrypt(binascii.unhexlify(p)).decode() c = input() if e != c: print("L") exit() print("W") print(flag.decode()) elif a.lower() == "quit": print("Bye") exit() else: print("Invalid input. ")
何回でも好きな平文を暗号化してくれるので、10回連続で向こうの平文を暗号化できればOK。暗号化できているかどうかは、向こうの暗号文と一致するかどうかで調べられる
暗号化方式は次の通り
があって、はそれぞれ16バイトある
1ブロックは次のように暗号化される(パディングはゼロパディング)
同じように暗号化できれば良いので、同じように暗号化できる鍵を探せば良い
のすべてのビットを1にして暗号化してみる。するとなのでになる。続いて、 となる
これでがもとまるが、これではまだ膨大な数の候補がある
ここでのビット目を0にしたやつを暗号化すると、のビット目が0のときは変わらず、のビット目が1のときはになる。このときのビット目が1なら[tex: c = *1]
from Crypto.Util.number import * from pwn import remote host, port = "crypto.2021.chall.actf.co", 21602 conn = remote(host, port) def get_ciphertext(conn, pt): conn.sendlineafter(b"try encrypting [2]? ", b"1") conn.sendlineafter(b"to encrypt: ", pt.hex().encode()) return bytes.fromhex(conn.recvline().strip().decode()) def encrypt_block(pt, lookup_table): bits = map(int, bin(bytes_to_long(pt))[2:].zfill(128)) ct = "" for i, bit in enumerate(bits): ct += lookup_table[i][bit] return long_to_bytes(int(ct, 2)) pt = b"\x00"*16 + b"\xff"*16 ct = get_ciphertext(conn, pt) zero_map = bin(bytes_to_long(ct[:16]))[2:].zfill(128) one_map = bin(bytes_to_long(ct[16:]))[2:].zfill(128) lookup_table = [] for i in range(128): lookup_table.append((zero_map[i], one_map[i])) conn.sendlineafter(b"try encrypting [2]? ", b"2") for i in range(10): conn.recvuntil(b"Encrypt this:") pt = bytes.fromhex(conn.recvline().strip().decode()) ct = b"" for i in range(0, len(pt), 16): ct += encrypt_block(pt[i:i+16], lookup_table) ct = ct.hex().encode() conn.sendline(ct) conn.interactive()
*1:k_2 \oplus 2^i