pbctf 2021 | Alkaloid Stream

#pbctf_2021

#rbtree

#!/usr/bin/env python3

import random
from flag import flag

def keygen(ln):
    # Generate a linearly independent key
    arr = [ 1 << i for i in range(ln) ]

    for i in range(ln):
        for j in range(i):
            if random.getrandbits(1):
                arr[j] ^= arr[i]
    for i in range(ln):
        for j in range(i):
            if random.getrandbits(1):
                arr[ln - 1 - j] ^= arr[ln - 1 - i]

    return arr

def gen_keystream(key):
    ln = len(key)
    
    # Generate some fake values based on the given key...
    fake = [0] * ln
    for i in range(ln):
        for j in range(ln // 3):
            if i + j + 1 >= ln:
                break
            fake[i] ^= key[i + j + 1]

    # Generate the keystream
    res = []
    for i in range(ln):
        t = random.getrandbits(1)
        if t:
            res.append((t, [fake[i], key[i]]))
        else:
            res.append((t, [key[i], fake[i]]))

    # Shuffle!
    random.shuffle(res)

    keystream = [v[0] for v in res]
    public = [v[1] for v in res]
    return keystream, public

def xor(a, b):
    return [x ^ y for x, y in zip(a, b)]

def recover_keystream(key, public):
    st = set(key)
    keystream = []
    for v0, v1 in public:
        if v0 in st:
            keystream.append(0)
        elif v1 in st:
            keystream.append(1)
        else:
            assert False, "Failed to recover the keystream"
    return keystream

def bytes_to_bits(inp):
    res = []
    for v in inp:
        res.extend(list(map(int, format(v, '08b'))))
    return res

def bits_to_bytes(inp):
    res = []
    for i in range(0, len(inp), 8):
        res.append(int(''.join(map(str, inp[i:i+8])), 2))
    return bytes(res)

flag = bytes_to_bits(flag)

key = keygen(len(flag))
keystream, public = gen_keystream(key)
assert keystream == recover_keystream(key, public)
enc = bits_to_bytes(xor(flag, keystream))

print(enc.hex())
print(public)

#ストリーム暗号

鍵系列 key に対して、公開鍵は[(v[0], v[1])]となっている。どちらかはfake 系列になっていて、 fake_i = \bigoplus key_j

というXor Sum。さらに公開鍵はシャッフルされている。

ここで fake_iを構成する key_iの要素は固定であることに気が付き、どういう関係になっているのかダンプしてみると

 fake_i = key_{i+1} \oplus key_{i+2} \oplus \dotsとなっていることに気がつく。

また、fake系列の末尾の要素はその次のkeyが存在しないため、かならず0になっている。

以上の観察より

  • 値が0になっている=末尾の鍵であることがわかり、0であるほうがfake、0でないほうがkeyであることがわかる

  • 以上で手に入れた key[-1] と同じ値がどこかにあり、それが fake[-2] である

  •  fake_{-3} = key_{-1} \oplus key_{-2} がわかっているので fake[-3] を探せる

……というように後ろから鍵系列を求めることができる

def X(xs):
    x = 0
    for i in range(len(xs)):
        x ^= xs[i]
    return x

def Z(key, ind):
    xs = [key[i] for i in ind]
    return X(xs)

def xor(a, b):
    return [x ^ y for x, y in zip(a, b)]

def bytes_to_bits(inp):
    res = []
    for v in inp:
        res.extend(list(map(int, format(v, '08b'))))
    return res

def bits_to_bytes(inp):
    res = []
    for i in range(0, len(inp), 8):
        res.append(int(''.join(map(str, inp[i:i+8])), 2))
    return bytes(res)

def make_pattern(ln):
    fake = [[] for _ in range(ln)]
    for i in range(ln):
        for j in range(ln // 3):
            if i + j + 1 >= ln:
                break
            fake[i].append( i + j + 1 )
    return fake

def recover_keystream(key, public):
    st = set(key)
    keystream = []
    for v0, v1 in public:
        if v0 in st:
            keystream.append(0)
        elif v1 in st:
            keystream.append(1)
        else:
            # assert False, "Failed to recover the keystream"
            keystream.append(None)
    return keystream


with open("output.txt") as f:
    enc = bytes.fromhex(f.readline().strip())
    pubkey = eval(f.readline().strip())
patterns = make_pattern(len(pubkey))

# -- recover key from last
used = set()
key = [None for _ in range(len(pubkey))]
for i in range(len(pubkey)):
    j = -(i+1)
    curKey = Z(key, patterns[j])

    for k in range(len(pubkey)):
        if k in used:
            continue

        s, t = pubkey[k]
        if s == curKey:
            used.add(k)
            key[j] = t
            break
        elif t == curKey:
            used.add(k)
            key[j] = s
            break
    else:
        raise ValueError("not found")


# -- recover keystream
keyst = recover_keystream(key, pubkey)
# print(keyst)

enc = bytes_to_bits(enc)
plain = xor(enc, keyst)
print(plain)
print(bits_to_bytes(plain))