ASIS CTF 2020 | hadamard

アダマール行列(全要素が \lbrace-1, 1\rbrace で、 HH^t = nI_nとなる H)の一部分が? に置き換えられた行列が与えられるので、?を適切に埋める。

アダマール行列は簡単な構成法がある(Sylvester's constructionなど)のだが、行列式の値を変えないような変形、すなわち、行と行のswap, 列同士のswap、任意の行または列に-1を掛けるという操作を行っても、それもまたアダマール行列である。というので、無数の例を作ることができ、もともとの形を類推してそこから穴埋めを行う、というのは難しい

アダマール行列の性質に#### 各行は直交するというのがあって、すなわち内積が0になる。これをつかってz3などで制約をつくって解く

from ptrlib import *
from hashlib import md5
from z3 import *

def parse_matrix(matrixstr):
    M = "[" + matrixstr.replace("'?'", "0").replace("\n", ",") + "]"
    return eval(M)

def encode_matrix(M):
    return md5(str(M).encode()).hexdigest()

def solve_matrix(M):
    solver = Solver()

    poses = []
    for y in range(len(M)):
        for x in range(len(M)):
            if M[y][x] == 0:
                M[y][x] = Int("{}-{}".format(y, x))
                solver.add(Or(M[y][x] == 1, M[y][x] == -1))
                poses.append((y, x))

    for l1 in range(len(M)):
        for l2 in range(len(M)):
            if l1 == l2:
                continue

            inner_prod = 0
            for i in range(len(M)):
                inner_prod += M[l1][i] * M[l2][i]
            solver.add(inner_prod == 0)

    if solver.check() == sat:
        m = solver.model()
        for p in poses:
            y, x = p[0],p[1]
            M[y][x] = m[M[y][x]]
    else:
        raise Exception("unsat")

sock = Socket("76.74.178.201", 8001)


for i in range(10):
    sock.recvuntil("M =\n ")
    matrixstr = sock.recvuntil("\n|").decode()[:-2]
    M = parse_matrix(matrixstr)
    print("stage{}, size={}".format(i + 1, len(M)))
    solve_matrix(M)
    sock.sendline(encode_matrix(M))
sock.interactive()

↓は別解

Combinational Matrix Theoryというのがあって、これを使っているんだと思う

Hadamard matrix, a square matrix of 1 and –1 coefficients with each pair of rows having matching coefficients in exactly half of their columns

import pwn as p
from hashlib import md5
from copy import deepcopy
from z3 import *
import itertools

def parse_matrix(matrixstr):
    c = matrixstr.count("'?'")
    M = "[" + matrixstr.replace("'?'", "0").replace("\n", ",") + "]"
    return eval(M), c

def check_matrix(M):
    m = Matrix(ZZ, M)
    nI = m * m.transpose()
    checkM = matrix.identity(len(M)) * nI[0][0]
    if nI != checkM:
        return False
    return True

def encode_matrix(M):
    return md5(str(M).encode()).hexdigest()
def bruteforce_solve(M, zeropos):
    if len(zeropos) == 0:
        if check_matrix(M):
            return M
        else:
            return None

    zeropos2 = deepcopy(zeropos)
    p, zeropos2 = zeropos2[0], zeropos2[1:]
    M1 = deepcopy(M)
    M2 = deepcopy(M)
    M1[p[0]][p[1]] = 1
    MR = bruteforce_solve(M1, zeropos2)
    if MR:
        return MR

    M2[p[0]][p[1]] = -1
    MR = bruteforce_solve(M2, zeropos2)
    if MR:
        return MR
    return None

def z3_to_dict(model):
    d = {}
    for c in model:
        d[str(c)] = model[c]
    return d

def z3_solve_m(M, c):
    ys = [Bool(f'y{i}') for i in range(c)]
    s = Solver()
    oM = deepcopy(M)
    k = 0
    for i, r in enumerate(M):
        for j, e in enumerate(r):
            if e == 1:
                M[i][j] = True
            if e == -1:
                M[i][j] = False
            if e == 0:
                M[i][j] = ys[k]
                k += 1
    for r1, r2 in itertools.combinations(M, 2):
        assert len(r1) == len(r2)
        a = sum([IntSort().cast((r1[i] == r2[i])) for i in range(len(r1))])
        b = sum([IntSort().cast((r1[i] != r2[i])) for i in range(len(r1))])
        s.add(a == b)
    print(s.check())
    if not s.check():
        assert False
    m = s.model()
    d = z3_to_dict(m)
    k = 0
    for i, r in enumerate(oM):
        for j, e in enumerate(r):
            if e == 0:
                a = d[f"y{k}"].sexpr()
                k += 1
                print(a, a == "true")
                if a == "true":
                    oM[i][j] = 1
                else:
                    oM[i][j] = -1
                    
    print("check", check_matrix(oM))
    return oM
    
    
sock = p.remote("76.74.178.201", 8001)


for i in range(20):
    sock.recvuntil("M =\n ")
    matrixstr = sock.recvuntil("\n|").decode()[:-2]
    M, c = parse_matrix(matrixstr)
    
    
    print("stage{}, size={}, variables={}".format(i + 1, len(M), c))
    rM = z3_solve_m(M, c)
    print(rM)

#     zeropos = []
#     for y in range(len(M)):
#         for x in range(len(M)):
#             if M[y][x] == 0:
#                 zeropos.append((y, x))
#     rM = bruteforce_solve(M, zeropos)
    
    sock.sendline(encode_matrix(rM))
sock.interactive()