TetCTF 2022 | shares

#tetctf2022

"""
This is an (incomplete) implement for a new (and experimental) secret/password
sharing scheme.

The idea is simple. Basically, a secret or password is turned into a set of
finite field elements and each share is just a linear combination of these
elements. On the other hand, when enough shares are collected, the finite field
elements are determined, allowing the original secret or password to be
recovered.
"""
from typing import List
from secrets import randbelow
import string

ALLOWED_CHARS = string.ascii_lowercase + string.digits + "_"
P = len(ALLOWED_CHARS)
INT_TO_CHAR = {}
CHAR_TO_INT = {}
for _i, _c in enumerate(ALLOWED_CHARS):
    INT_TO_CHAR[_i] = _c
    CHAR_TO_INT[_c] = _i


def get_shares(password: str, n: int, t: int) -> List[str]:
    """
    Get password shares.

    Args:
        password: the password to be shared.
        n: the number of shares returned.
        t: the minimum number of shares needed to recover the password.

    Returns:
        the shares.
    """
    assert len(password) <= t
    assert n > 0

    ffes = [CHAR_TO_INT[c] for c in password]
    ffes += [randbelow(P) for _ in range(t - len(password))]
    result = []
    for _ in range(n):
        coeffs = [randbelow(P) for _ in range(len(ffes))]
        s = sum([x * y for x, y in zip(coeffs, ffes)]) % P
        coeffs.append(s)
        result.append("".join(INT_TO_CHAR[i] for i in coeffs))

    return result


def combine_shares(shares: List[str]) -> str:
    raise Exception("unimplemented")


def main():
    pw_len = 16
    password = "".join(INT_TO_CHAR[randbelow(P)] for _ in range(pw_len))

    # how about n < t :D
    n = 16
    t = 32

    for _ in range(2022):
        line = input()
        if line == password:
            from secret import FLAG
            print(FLAG)
            return
        else:
            print(get_shares(password, n, t))


if __name__ == '__main__':
    main()

get_shares という関数で、passwordをシェアするための配列を生成してくれるSecret Sharing

パスワードの文字数 l = 16、シェアの数 n = 16、秘密の復元に必要なシェアの数 t = 32

シェアの生成  s_i = \sum a_{i,j}x_j where  a_i = (a_1, a_2, \dots a_{32}) \in GF(P)^t, x = (x_1, x_2, \dots, x_{32}) \in GF(P)^t

シェアは2022回まで生成できて、 xは共通。目的は (x_1, \dots x_l)の復元

行列の形で書けばこう。

 \begin{pmatrix} a_{1,1} &amp; a_{1,2} &amp; \dots &amp;  a_{1,32} \ a_{2,1} &amp; a_{2,2} &amp; \dots &amp; a_{2,32} \ \vdots \ a_{16,1} &amp; a_{16,2} &amp; \dots &amp; a_{16,32} \end{pmatrix} \begin{pmatrix} x_1 \ x_2 \ \vdots \ x_{32} \end{pmatrix}= \begin{pmatrix} s_1 \ s_2 \ \vdots \ s_{16} \end{pmatrix}

さらにこう書いても良い。  Ax + By = s とする

 \begin{pmatrix} a_{1,1} &amp; \dots &amp; a_{1,16} \ a_{2,1} &amp; \dots &amp; a_{2,16} \ \vdots \ a_{16,1} &amp; \dots &amp; a_{16,16} \end{pmatrix} \begin{pmatrix} x_1 \ x_2 \ \vdots \ x_{16} \end{pmatrix} +  \begin{pmatrix} b_{1,1} &amp; \dots &amp; b_{1,16} \ b_{2,1} &amp; \dots &amp; b_{2,16} \ \vdots \ b_{16,1} &amp; \dots &amp; b_{16,16} \end{pmatrix} \begin{pmatrix} y_1 \ y_2 \ \vdots \ y_{16} \end{pmatrix} = \begin{pmatrix} s_1 \ s_2 \ \vdots \ s_{16} \end{pmatrix}

このとき、 Bがnot full-rankedなら \lbrack B \mid A  \rbrackRow Echelon Formを考えると、 (b_{i,1}, \dots, b_{i,16}) = 0 となるような行が存在する(要するに掃き出し法をやった時に下の方の行は1が残らなくて0になる)

このとき Ax = sだけの式となる

 (x_{1}, \dots, x_{16})は2022回の施行の中で共通かつ、 A, sは毎回わかっているから、 b_{i} = 0となるような行について a_i x = s_i という式を16個集めてこれを解けば xが求まる ( #連立方程式を解く問題

sageではleft_kernel という関数で行列 Aに対して wA = 0を満たすようなleft kernel wを求めることができる。これを使えば

 wAx + wBy = wAx + 0 = ws

が構成できるのでこれで解ける

(ある行だけを0にするような係数を探して組み合わせてもいいけど、こっちの方が楽……)

import string

ALLOWED_CHARS = string.ascii_lowercase + string.digits + "_"
P = len(ALLOWED_CHARS)
INT_TO_CHAR = {}
CHAR_TO_INT = {}
for _i, _c in enumerate(ALLOWED_CHARS):
    INT_TO_CHAR[_i] = _c
    CHAR_TO_INT[_c] = _i

F = GF(P)
from ptrlib import Process
sock = Process(["python3", "./shares.py"])

lhs = []
rhs = []

while True:
    sock.sendline("a")
    mat = matrix(F, [[CHAR_TO_INT[c] for c in row] for row in eval(sock.recvline().decode())])

    A = mat[:16,:16]
    B = mat[:16,16:32]
    s = mat[:16,32:]

    # A*x + B*y = s
    
    if B.rank() != 16:
        # ker*(A*x + B*y) = ker*s
        # ker*A*x + 0 = ker*s
        ker = B.left_kernel().matrix()
        assert (ker * B).is_zero()

        lhs.append((ker*A)[0])
        rhs.append((ker*s)[0,0])

    if len(lhs) == 16:
        x = matrix(F, lhs).solve_right(vector(F, rhs))
        password = "".join(INT_TO_CHAR[v] for v in x)
        sock.sendline(password)
        sock.interactive()
        quit()