#!/usr/bin/env python3 import random from flag import flag def keygen(ln): # Generate a linearly independent key arr = [ 1 << i for i in range(ln) ] for i in range(ln): for j in range(i): if random.getrandbits(1): arr[j] ^= arr[i] for i in range(ln): for j in range(i): if random.getrandbits(1): arr[ln - 1 - j] ^= arr[ln - 1 - i] return arr def gen_keystream(key): ln = len(key) # Generate some fake values based on the given key... fake = [0] * ln for i in range(ln): for j in range(ln // 3): if i + j + 1 >= ln: break fake[i] ^= key[i + j + 1] # Generate the keystream res = [] for i in range(ln): t = random.getrandbits(1) if t: res.append((t, [fake[i], key[i]])) else: res.append((t, [key[i], fake[i]])) # Shuffle! random.shuffle(res) keystream = [v[0] for v in res] public = [v[1] for v in res] return keystream, public def xor(a, b): return [x ^ y for x, y in zip(a, b)] def recover_keystream(key, public): st = set(key) keystream = [] for v0, v1 in public: if v0 in st: keystream.append(0) elif v1 in st: keystream.append(1) else: assert False, "Failed to recover the keystream" return keystream def bytes_to_bits(inp): res = [] for v in inp: res.extend(list(map(int, format(v, '08b')))) return res def bits_to_bytes(inp): res = [] for i in range(0, len(inp), 8): res.append(int(''.join(map(str, inp[i:i+8])), 2)) return bytes(res) flag = bytes_to_bits(flag) key = keygen(len(flag)) keystream, public = gen_keystream(key) assert keystream == recover_keystream(key, public) enc = bits_to_bytes(xor(flag, keystream)) print(enc.hex()) print(public)
鍵系列 key
に対して、公開鍵は[(v[0], v[1])]
となっている。どちらかはfake
系列になっていて、
というXor Sum。さらに公開鍵はシャッフルされている。
ここでを構成するの要素は固定であることに気が付き、どういう関係になっているのかダンプしてみると
となっていることに気がつく。
また、fake
系列の末尾の要素はその次のkey
が存在しないため、かならず0になっている。
以上の観察より
値が0になっている=末尾の鍵であることがわかり、0であるほうがfake、0でないほうがkeyであることがわかる
以上で手に入れた
key[-1]
と同じ値がどこかにあり、それがfake[-2]
であるがわかっているので
fake[-3]
を探せる
……というように後ろから鍵系列を求めることができる
def X(xs): x = 0 for i in range(len(xs)): x ^= xs[i] return x def Z(key, ind): xs = [key[i] for i in ind] return X(xs) def xor(a, b): return [x ^ y for x, y in zip(a, b)] def bytes_to_bits(inp): res = [] for v in inp: res.extend(list(map(int, format(v, '08b')))) return res def bits_to_bytes(inp): res = [] for i in range(0, len(inp), 8): res.append(int(''.join(map(str, inp[i:i+8])), 2)) return bytes(res) def make_pattern(ln): fake = [[] for _ in range(ln)] for i in range(ln): for j in range(ln // 3): if i + j + 1 >= ln: break fake[i].append( i + j + 1 ) return fake def recover_keystream(key, public): st = set(key) keystream = [] for v0, v1 in public: if v0 in st: keystream.append(0) elif v1 in st: keystream.append(1) else: # assert False, "Failed to recover the keystream" keystream.append(None) return keystream with open("output.txt") as f: enc = bytes.fromhex(f.readline().strip()) pubkey = eval(f.readline().strip()) patterns = make_pattern(len(pubkey)) # -- recover key from last used = set() key = [None for _ in range(len(pubkey))] for i in range(len(pubkey)): j = -(i+1) curKey = Z(key, patterns[j]) for k in range(len(pubkey)): if k in used: continue s, t = pubkey[k] if s == curKey: used.add(k) key[j] = t break elif t == curKey: used.add(k) key[j] = s break else: raise ValueError("not found") # -- recover keystream keyst = recover_keystream(key, pubkey) # print(keyst) enc = bytes_to_bits(enc) plain = xor(enc, keyst) print(plain) print(bits_to_bytes(plain))