branch and prune

いわゆる分枝限定法で、 p, qの半分程度のbit(実験によると大体53%〜が精度良く機能する)が分かっているときに素因数分解できる方法

↓のコードでは分かっていると行ったbitのうち下位何bitかがエラーのケースに対応している。あと aeroctfbranch_and_prune は動作が微妙に違って、前者のほうが性能が良い

前者は Hensel's Liftを使った方法で、 f(r) = N - r \equiv 0 \mod 2^kという式が成立するとき、Hensel's Liftより f(s) = N - s \equiv 0 \mod 2^{k+m}を求めることができる。ここで r = p'q'というわけ。詳しくは……わからないですね

from random import getrandbits
from itertools import product
from gmpy2 import mpz, next_prime
import time

BITS = 512
THRESHOLD = 0.53

while True:
    X = getrandbits(BITS)
    if bin(X).count("1") >= BITS * THRESHOLD:
        break

p = next_prime(X | getrandbits(BITS))
q = next_prime(X | getrandbits(BITS))

N = p * q

# https://github.com/AeroCTF/aero-ctf-2021/tree/main/ideas/crypto/boggart
def aeroctf(bits, N, rp, rq, error_size):
    candidates = [(mpz(1), mpz(1))]
    rp, rq = mpz(rp), mpz(rq)
    rp, rq = (rp >> error_size) << error_size, (rq >> error_size) << error_size

    for i in range(1, bits):
        next_candidates = set()

        for p, q in candidates:
            if (N - p * q).bit_test(i):
                next_candidates.add((p.bit_set(i), q))
                next_candidates.add((p, q.bit_set(i)))
            else:
                next_candidates.add((p, q))
                next_candidates.add((p.bit_set(i), q.bit_set(i)))

        candidates = set()

        for p, q in next_candidates:
            if (q, p) in candidates:
                continue
            if i < error_size or ((p & rp).bit_test(i) == rp.bit_test(i) and (q & rq).bit_test(i) == rq.bit_test(i)):
                if p * q == N:
                    return (p, q)
                candidates.add((p, q))

        # print(i, len(candidates))
    return candidates

def branch_and_prune(N, pknown, qknown, n, error_size):
    pknown = (pknown >> error_size) << error_size
    qknown = (qknown >> error_size) << error_size
    cands = set([(0, 0)])
    for i in range(n):
        mod = 1 << (i + 1)

        next_cands = set()
        for cand in cands:
            cur_p, cur_q = cand
            for k in product([0, 1], repeat=2):
                pbit, qbit = k
                if (pknown >> i) & 1:
                    pbit = 1
                if (qknown >> i) & 1:
                    qbit = 1
                px = cur_p | (pbit << i)
                qx = cur_q | (qbit << i)
                p_ = X | px
                q_ = X | qx
                if p_ * q_ == N:
                    return p_, q_
                if i < error_size or ((p_ * q_) % mod == N % mod and p_ * q_ <= N):
                    if (qx, px) not in next_cands:
                        next_cands.add((px, qx))
        cands = next_cands

# print(branch_and_prune(N, X, BITS))

t1 = time.time()
print(aeroctf(BITS, N, X, X, 10))
t2 = time.time()
print(branch_and_prune(N, X, X, BITS, 10))
t3 = time.time()

print(t2 - t1)
print(t3 - t2)