X-MAS CTF | Hashed Presents v1

from hash import secureHash
import SocketServer
import string
import random
from text import *
import os
from hashlib import sha256

PORT = 2000
ROUNDS = 10
TIMEOUT = 120
sigma = string.ascii_letters + string.digits + "!@#$%^&*()-_=+[{]}<>.,?;:"


def isPrintable(x):
    global sigma

    alpha = set(sigma)
    beta = set(x)

    return alpha.intersection(beta) == beta


def get_random_string(l, s):
    return "".join([random.choice(s) for i in range(l)])


class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
    def PoW(self):
        s = os.urandom(10)
        h = sha256(s).hexdigest()
        self.request.sendall(
            "Provide a hex string X such that sha256(X)[-6:] = {}\n".format(h[-6:])
        )
        inp = self.request.recv(2048).strip().lower()
        is_hex = 1
        for c in inp:
            if not c in "0123456789abcdef":
                is_hex = 0

        if is_hex and sha256(inp.decode("hex")).hexdigest()[-6:] == h[-6:]:
            self.request.sendall("Good, you can continue!\n")
            return True
        else:
            self.request.sendall("Oops, your string didn't respect the criterion.\n")
            return False

    def challenge(self, n):
        s = get_random_string(random.randint(30, 35), sigma)
        H = secureHash()
        H.update(s)
        h = H.hexdigest()

        self.request.sendall(chall_intro.format(n, s, h))
        inp = self.request.recv(2048).strip()
        H_inp = secureHash()
        H_inp.update(inp)
        h_inp = H_inp.hexdigest()

        if inp == s or not isPrintable(inp):
            self.request.sendall(chall_wrong)
            return False

        if h_inp != h:
            self.request.sendall(chall_wrong)
            return False

        self.request.sendall(chall_ok)
        return True

    def handle(self):
        self.request.settimeout(TIMEOUT)
        if not self.PoW():
            return
        self.request.sendall(intro.format(ROUNDS))

        for i in range(ROUNDS):
            if not self.challenge(i + 1):
                self.request.sendall(losing_outro)
                return

        self.request.sendall(winning_outro.format(FLAG))

        def finish(self):
            logger.info("%s client disconnected" % self.client_address[0])


class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
    pass


if __name__ == "__main__":
    server = ThreadedTCPServer(("0.0.0.0", PORT), ThreadedTCPRequestHandler)
    server.allow_reuse_address = True
    server.serve_forever()
class secureHash(object):
    def __init__(self):
        self.bits = 128
        self.mod = 2**128
        self.mask = 2**128 - 1
        self.step = 23643483844282862943960719738L
        self.hash = 9144491976215488621715609182563L

    def update(self, inp):
        for ch in inp:
            self.hash = ((self.hash + ord(ch)) * self.step) & self.mask

    def hexdigest(self):
        x = self.hash
        out = ''
        for i in range(self.bits/8):
            out=hex(x & 0xff)[2:].replace('L','').zfill(2)+out
            x >>= 8
        return out

ソースコードがめちゃくちゃ読みにくいが(大体SocketServerがわるい)、とにかく、このSecureHashというハッシュを10回衝突させれば良い。SecureHashとかいう名前でありながら実装はRollingHashなので、うまくやれば衝突はさほど難しくない、はず。特に2冪modのRollingHashの衝突は容易だったはず。

とはいえ2冪modじゃなくても衝突はできる。やっていきます。

さて、今 a_n ... a_1という文字列が与えられたとする。このときのhash値は次のようになる。

 \left(Hs^n + a_ns^n + a_{n-1}s^{n-1} + \cdots + a_1*s \right) \mod N = \left( Hs^n + \sum_{i=1}^n a_is^i \right) \mod N

ここで次のような値列 e_k ... e_1 (k \le n)を考えてみる。

 \sum_{i=1}^{k}e_is^i \equiv 0 \mod N

modと加算の性質から当然以下が導けるので、そういう文字列 a_n...a_{n-k+1}(a_k + e_k)...(a_1+e_1) a_n...a_1と衝突する

 Hs^n + \sum_{i=k+1}^n a_is^i + \sum_{i=1}^k(a_i+e_i)s^i \equiv  Hs^n + \sum_{i=1}^n a_is^i \mod N

じゃあどうやってそういう e_k...e_1を作るんですか?

そこで登場するのがLLLです。今回考えるのは次のような部分和問題になる

 t*N = e_ks^k + e_{k-1}s^{k-1} + \cdots + e_0

ナップザック暗号と同じ要領で、

[[1, ..., 0, s^k]
,[0, 1, ..., s^{k-1}]
,...
,[0, ..., 1, s^0]
,[0, ..., 0, -t*N]]

という行列に対してLLLをすると、どこかの行で、そういう答えが見つかるはず

このとき、答えのベクトルを \vec{u}としてLLLで求めたいベクトルは (\vec{u}, 0)なので(LLL and Markle-Hellman Knapsack cryptosystemでいう W)、末尾要素が0になっているベクトルを探す。

in outも書くと面倒なのでできていることをチェックするだけのスクリプトになった 。kの範囲を適当に定めている。あと体感で100回まわすのも10000回回すのも同じ(みつからないときはみつからない)

import random
import string
from hash import secureHash

ROUNDS = 10
sigma = string.ascii_letters + string.digits + "!@#$%^&*()-_=+[{]}<>.,?;:"

def get_random_string(l, s):
    return "".join([random.choice(s) for i in range(l)])

def getHash(s):
    H = secureHash()
    H.update(s)
    return H.hexdigest()

def makeChallenge():
    s = get_random_string(random.randint(30, 35), sigma)
    return (s, getHash(s))

def checkChallenge(s1, s2):
    if s1 == s2 or not all([c in sigma for c in s2]):
        return False

    h1 = getHash(s1)
    h2 = getHash(s2)

    if h1 != h2:
        return False
    return True

def getCollision(s, h):
    mod = 2**128
    step = 23643483844282862943960719738L
    hash = 9144491976215488621715609182563L

    try:
        for k in [len(s)] + list(range(30, 36)):
            for t in range(1, 100):
                A = [[1 if i == j else 0 for i in range(k + 1)] for j in range(k + 1)]
                for i in range(len(A) - 1):
                    A[i][-1] = power_mod(step, k-i-1, mod)
                A[-1][-1] = t * mod

                A = matrix(A)
                A = A.LLL()

                # check rows
                for r in A:
                    if r[-1] == 0 and not all([v == 0 for v in r]):
                        s2 = [ord(c) + e for c, e in zip(s,r)]
                        if all([(c in range(256)) and (chr(c) in sigma) for c in s2]):
                            assert(False)
    except AssertionError:
        r = [0] * (len(s) - len(r)) + list(r)
        return "".join([chr(ord(c) + e) for c, e in zip(s, r)])
    return None


ROUND = 10
for i in range(1, ROUND + 1):
    s, h = makeChallenge()
    s2 = getCollision(s, h)
    ok = checkChallenge(s, s2) if s2 else False
    print("ROUND {}/{}: {}".format(i, ROUND, ok))
    if not ok:
        print("s1: {}, {}".format(repr(s), getHash(s)))
        print("s2: {}, {}".format(repr(s2), getHash(s2)))
        break