InCTF 2020 | FaultyLFSR

#inctf2020 #LFSR

from random import Random
from flag import flag ,seeds
from Crypto.Cipher import AES
from hashlib import sha256


SECRET = 14810031
#assert seeds[3]%seeds[1] == 0 CLUE-1
def generate() :
    masks = [43, 578, 22079, 142962]
    for i in range(4) :
        assert (masks[i].bit_length() == seeds[i].bit_length()) == True
    

    l = [lfsr(seeds[i],masks[i],masks[i].bit_length()) for i in range(4)]
    data = ''
    for i in range(10000) :
        data += str(combine(l[0].next(),l[1].next(),l[2].next(),l[3].next()))
    f = open('rem_data','w')
    f.write(data[160:])
    f.close()
    return data[:160]

class lfsr():
    def __init__(self, init, mask,masklength):
        self.rand = Random()
        self.rand.seed(SECRET + mask)
        self.init = init
        self.mask = mask
        self.masklength = masklength


    def next(self):
        r  = self.rand.getrandbits(20)
        nextdata = ((self.init << 1)&0xffffff) ^ (self.mask & r)
        output = 0
        l = nextdata.bit_length()
        for i in range(0,l//2,2) :
            output += int(bin(nextdata)[2:][i])
            self.init = nextdata ^ output
        return output%2

def combine(a,b,c,d) :
    return (a^b)^(a|c)^(b|c)^(c|d)

def enc() :
    f = flag.ljust(16*((len(flag)/16)+1),'0')
    key = sha256(generate()).digest()
    return AES.new(key,AES.MODE_ECB).encrypt(f)

if __name__ == "__main__" :
    f = open('flag.enc','w')
    f.write(enc())
    f.close()

overview

  • flag をパディングしている

  • 鍵が generate() で生成されたものの sha256

  • generate()

    • mask と同じbit数の seed がある( seed はわからない)

      • seed[3] % seed[1] == 0 らしい
    • mask , seed が4つあり、LFSRのインスタンスを4つ作る

    • 鍵となるbit列は、4つのLFSRの出力をbit毎に combine することによって得られる

      • combine(a^b)^(a|c)^(b|c)^(c|d) という操作

        • 先頭だけ a^b になっていていかにも怪しい
    • 出力の先頭160bitが与えられており、残りの 9840ibt が generate() の返り値として使われる

  • lfsr は 通常の

    #LFSR とは少し様子が違う

  • r のシードは SECRET + maskSECRET は与えられている)

  • rMersenne Twister の出力20bit

  • output が 1bit ではない(返り値は1bit)

  •  def next(self):
         r = self.rand.getrandbits(20)
         nextdata = ((self.init << 1)&0xffffff) ^ (self.mask & r)
         output = 0
         l = nextdata.bit_length()
         for i in range(0, l//2, 2) :
             output += int(bin(nextdata)[2:][i])
             self.init = nextdata ^ output
         return output % 2
    

考察

  • そもそも LFSR の出力系列が与えられたときに、次を予測する、seedを求めることができる?

    • next の実装なんでこんなことになってるんだろう

      • 毎度内部状態が1bit以上かわる
    • Mersenne Twisterr を生成している理由がある?

  • 4つのLFSRのインスタンスで、ビット数が全然違うのは意味があるのか

  • combine によって何かしらの出力の偏りがあるかどうか

  • たとえば a が1なら bcd の値によって、 75% 出力が1とかそういうやつ

combineの偏りがあるか調査してみる

  • (a|c)^(b|c)^(c|d) は75% の割合で1になるので

  • 出力のたとえば半分を適当にサンプリングして(これによって、 サンプルしたケースにおいてはすべて (a|c)^(b|c)^(c|d) == 1 を仮定する)、 出力が1なら a == b , 0 なら a != b となるとしても良さそう

    • これはうそ で、そうなる確率は  0.75^{80} とかになる
  • そうすると a,b だけを気にすればよく、この場合 a, b のmask, seedは小さいので全探索などしても良さそう?

公式writeup

combineの偏りはもっとあった

  • けどよくわからん
from Crypto.Cipher import AES
from hashlib import sha256
from random import Random
from tqdm import tqdm
from sympy.ntheory import factorint

remdata = list(open("rem_data").read())
remdata = list(map(int, remdata))
masks = [43, 578, 22079, 142962]
SECRET = 14810031

class lfsr():
    def __init__(self, init, mask, masklength):
        self.rand = Random()
        self.rand.seed(SECRET + mask)
        self.init = init
        self.mask = mask
        self.masklength = masklength

    def next(self):
        r = self.rand.getrandbits(20)
        nextdata = ((self.init << 1)&0xffffff) ^ (self.mask & r)
        # print bin(nextdata)[2:]
        output = 0
        l = nextdata.bit_length()  # St-1
        for i in range(0, l//2, 2) :
            output += int(bin(nextdata)[2:][i])  # St-2
            # output+=(nextdata>>(l-i-1))&1
        self.init = nextdata ^ output
        return output %2

# Tried finding an alternative for St-1 and St-2 but couldnt. If this is true the challenge can be solved only through correlation attack

def combine(a, b, c, d) :
    return (a^b)^(a|c)^(b|c)^(c|d)
    # a==b 75%
    # c^d == out 75%
    # if d!=out => c==1 always

def solve_d() :
    poss = []
    for i in tqdm(range(2**17 , 2**17 + 2**14)) :  # given in the secription the reason for range
        d = lfsr(i, masks[3], masks[3].bit_length())
        ct = 0
        for k in range(160) :
            d.next()
        for j in remdata[:2000] :
            if d.next() == j :
                ct+=1.0
        # if i == 136757 :
        #    print " for the actual seed ct is ",ct
        if ct/2000 >= 0.74 :
            poss.append(i)
            print((i, ct/2000))
    bd = []
    for i in poss :
        for j in factorint(i).keys() :
            if j.bit_length() == 10 :
                bd.append((j, i))
    return bd


def solve_c() :
    pair = []
    poss_d = solve_d()
    for b, d in tqdm(poss_d) :
        #poss_c = []
        for i in tqdm(range(2**14, 2**15)) :
            dt = lfsr(d, masks[3], masks[3].bit_length())
            ct = lfsr(i, masks[2], masks[2].bit_length())
            ct1 =ct2 = 0
            for j in range(160) :
                dt.next()
                ct.next()
            for j in remdata[:2000] :
                dtmp = dt.next()
                ctmp = ct.next()
                if dtmp!=j and ctmp!=1 :
                    break
                if ctmp == j :
                    ct1+=1.0
                if ctmp^dtmp == j :
                    ct2+=1.0
            if i == 17523 :
                print("org", (i, ct1, ct2))

            if ct1/2000 > 0.45 and ct1/2000<0.6 and ct2/2000>0.74 :
                # poss_c.append(i)
                pair.append((b, i, d))
                print(i, ct1/2000, ct2/2000)

    return pair

def solve_rest() :

    res = []
    pair = solve_c()
    for b, c, d in tqdm(pair) :
        for a in range(2**5, 2**6) :
            l = [lfsr(i, j, j.bit_length()) for i, j in zip((a, b, c, d), masks)]
            for i in range(160) :
                l[0].next()
                l[1].next()
                l[2].next()
                l[3].next()
            ct = 0
            for i in remdata[:2000] :
                if combine(l[0].next(), l[1].next(), l[2].next(), l[3].next()) == i :

                    ct+=1
                    continue
                else :
                    break
            if ct == 2000 :
                decrypt([a, b, c, d])
    return res

def generate(seeds) :
    masks = [43, 578, 22079, 142962]
    l = [lfsr(seeds[i], masks[i], masks[i].bit_length()) for i in range(4)]
    data = ''
    for i in range(160) :
        data += str(combine(l[0].next(), l[1].next(), l[2].next(), l[3].next()))
    return data


def decrypt(seeds) :

    f = open('flag.enc', "rb").read()

    print(seeds)
    key = sha256(generate(seeds).encode()).digest()
    flag = AES.new(key, AES.MODE_ECB).decrypt(f)
    print(flag)


if __name__ == "__main__" :
    solve_rest()