from hash import secureHash import SocketServer import string import random from text import * import os from hashlib import sha256 PORT = 2000 ROUNDS = 10 TIMEOUT = 120 sigma = string.ascii_letters + string.digits + "!@#$%^&*()-_=+[{]}<>.,?;:" def isPrintable(x): global sigma alpha = set(sigma) beta = set(x) return alpha.intersection(beta) == beta def get_random_string(l, s): return "".join([random.choice(s) for i in range(l)]) class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): def PoW(self): s = os.urandom(10) h = sha256(s).hexdigest() self.request.sendall( "Provide a hex string X such that sha256(X)[-6:] = {}\n".format(h[-6:]) ) inp = self.request.recv(2048).strip().lower() is_hex = 1 for c in inp: if not c in "0123456789abcdef": is_hex = 0 if is_hex and sha256(inp.decode("hex")).hexdigest()[-6:] == h[-6:]: self.request.sendall("Good, you can continue!\n") return True else: self.request.sendall("Oops, your string didn't respect the criterion.\n") return False def challenge(self, n): s = get_random_string(random.randint(30, 35), sigma) H = secureHash() H.update(s) h = H.hexdigest() self.request.sendall(chall_intro.format(n, s, h)) inp = self.request.recv(2048).strip() H_inp = secureHash() H_inp.update(inp) h_inp = H_inp.hexdigest() if inp == s or not isPrintable(inp): self.request.sendall(chall_wrong) return False if h_inp != h: self.request.sendall(chall_wrong) return False self.request.sendall(chall_ok) return True def handle(self): self.request.settimeout(TIMEOUT) if not self.PoW(): return self.request.sendall(intro.format(ROUNDS)) for i in range(ROUNDS): if not self.challenge(i + 1): self.request.sendall(losing_outro) return self.request.sendall(winning_outro.format(FLAG)) def finish(self): logger.info("%s client disconnected" % self.client_address[0]) class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer): pass if __name__ == "__main__": server = ThreadedTCPServer(("0.0.0.0", PORT), ThreadedTCPRequestHandler) server.allow_reuse_address = True server.serve_forever()
class secureHash(object): def __init__(self): self.bits = 128 self.mod = 2**128 self.mask = 2**128 - 1 self.step = 23643483844282862943960719738L self.hash = 9144491976215488621715609182563L def update(self, inp): for ch in inp: self.hash = ((self.hash + ord(ch)) * self.step) & self.mask def hexdigest(self): x = self.hash out = '' for i in range(self.bits/8): out=hex(x & 0xff)[2:].replace('L','').zfill(2)+out x >>= 8 return out
ソースコードがめちゃくちゃ読みにくいが(大体SocketServerがわるい)、とにかく、このSecureHashというハッシュを10回衝突させれば良い。SecureHashとかいう名前でありながら実装はRollingHashなので、うまくやれば衝突はさほど難しくない、はず。特に2冪modのRollingHashの衝突は容易だったはず。
とはいえ2冪modじゃなくても衝突はできる。やっていきます。
さて、今という文字列が与えられたとする。このときのhash値は次のようになる。
ここで次のような値列を考えてみる。
modと加算の性質から当然以下が導けるので、そういう文字列はと衝突する
じゃあどうやってそういうを作るんですか?
そこで登場するのがLLLです。今回考えるのは次のような部分和問題になる
ナップザック暗号と同じ要領で、
[[1, ..., 0, s^k] ,[0, 1, ..., s^{k-1}] ,... ,[0, ..., 1, s^0] ,[0, ..., 0, -t*N]]
という行列に対してLLLをすると、どこかの行で、そういう答えが見つかるはず
このとき、答えのベクトルをとしてLLLで求めたいベクトルはなので(LLL and Markle-Hellman Knapsack cryptosystemでいう)、末尾要素が0になっているベクトルを探す。
in outも書くと面倒なのでできていることをチェックするだけのスクリプトになった 。kの範囲を適当に定めている。あと体感で100回まわすのも10000回回すのも同じ(みつからないときはみつからない)
import random import string from hash import secureHash ROUNDS = 10 sigma = string.ascii_letters + string.digits + "!@#$%^&*()-_=+[{]}<>.,?;:" def get_random_string(l, s): return "".join([random.choice(s) for i in range(l)]) def getHash(s): H = secureHash() H.update(s) return H.hexdigest() def makeChallenge(): s = get_random_string(random.randint(30, 35), sigma) return (s, getHash(s)) def checkChallenge(s1, s2): if s1 == s2 or not all([c in sigma for c in s2]): return False h1 = getHash(s1) h2 = getHash(s2) if h1 != h2: return False return True def getCollision(s, h): mod = 2**128 step = 23643483844282862943960719738L hash = 9144491976215488621715609182563L try: for k in [len(s)] + list(range(30, 36)): for t in range(1, 100): A = [[1 if i == j else 0 for i in range(k + 1)] for j in range(k + 1)] for i in range(len(A) - 1): A[i][-1] = power_mod(step, k-i-1, mod) A[-1][-1] = t * mod A = matrix(A) A = A.LLL() # check rows for r in A: if r[-1] == 0 and not all([v == 0 for v in r]): s2 = [ord(c) + e for c, e in zip(s,r)] if all([(c in range(256)) and (chr(c) in sigma) for c in s2]): assert(False) except AssertionError: r = [0] * (len(s) - len(r)) + list(r) return "".join([chr(ord(c) + e) for c, e in zip(s, r)]) return None ROUND = 10 for i in range(1, ROUND + 1): s, h = makeChallenge() s2 = getCollision(s, h) ok = checkChallenge(s, s2) if s2 else False print("ROUND {}/{}: {}".format(i, ROUND, ok)) if not ok: print("s1: {}, {}".format(repr(s), getHash(s))) print("s2: {}, {}".format(repr(s2), getHash(s2))) break