Mersenne Twister

Python の `random.Random` とインタフェースにそこそこ互換性があり、読みやすく、状態の復元ができて、逆方向（backward）の計算もできる Mersenne Twister。

import os
import random

class MersenneTwister(random.Random):
    N = 624
    M = 397
    A = [0, 0x9908b0df]
    UPPER_MASK = 0x80000000
    LOWER_MASK = 0x7fffffff

    def __init__(self, x=None):
        self.seed(x)

    def seed(self, a=None, version=2):
        r = random.Random()
        r.seed(a, version)
        self.setstate(r.getstate())

    def getstate(self):
        t = self.clone()
        p = t.p
        if p == 0:
            p = self.N
        else:
            for _ in range(self.N - p):
                t.next()
        return (3, tuple(t.state + [p]), None)

    def _getstate(self):
        return tuple(self.state + [self.p])

    def setstate(self, state):
        assert len(state) == 3
        assert len(state[1]) == self.N + 1
        self._setstate(state[1])

        if self.p != 0:
            p = self.p
            self.p = 0
            for _ in range(self.N - p):
                self.prev()
            self.p = p

    def _setstate(self, state):
        self.state = list(state[:self.N])
        self.p = state[self.N] % self.N

    def getrandbits(self, k: int)->int:
        if k < 0:
            raise ValueError("number of bits must be non-negative")
        if k == 0:
            return 0

        if k <= 32:
            return self.next() >> (32 - k)

        b = b''
        while k > 0:
            x = self.next()
            if k < 32:
                x = x >> (32 - k)
            b += x.to_bytes(4, 'little')
            k -= 32
        return int.from_bytes(b, 'little')

    def next_value(self)->int:
        p, q = self.p, (self.p + 1) % self.N
        # update state
        a = self.state[p] & self.UPPER_MASK
        b = self.state[q] & self.LOWER_MASK
        x = a | b

        k = (p + self.M) % self.N
        return self.state[k] ^ (x >> 1) ^ self.A[x & 1]

    def next(self)->int:
        y = self.next_value()
        self.state[self.p] = y
        self.p = (self.p + 1) % self.N
        return self._tempering(y)

    def prev(self)->int:
        p = (self.p - 2) % self.N
        k = (p + self.M) % self.N

        t = self.state[p] ^ self.state[k]  # (x >> 1) ^ self.A[x & 1]
        x_ = t ^ self.A[t >> 31]           # x >> 1 because t>>31 == 1 iff x&1 == 1)

        q, l = (p + 1) % self.N, (k + 1) % self.N
        head = ((self.state[q] ^ self.state[l]) << 1)&self.UPPER_MASK
        body = (x_ << 1)&self.LOWER_MASK
        tail = t >> 31

        self.p = (self.p - 1) % self.N
        self.state[self.p] = head|body|tail
        return self._tempering(self.next_value())

    def random(self):
        a = self.next() >> 5
        b = self.next() >> 6
        return ((a * 2**26 + b) * (1.0 / 2**53))

    def _tempering(self, y):
        y ^= y >> 11
        y ^= (y << 7) & 0x9d2c5680
        y ^= (y << 15) & 0xefc60000
        y ^= y >> 18
        return y

    def _untempering(self, y):
        y ^= y >> 18
        y ^= (y << 15) & 0xefc60000
        y ^= ((y << 7) & 0x9d2c5680) ^ ((y << 14) & 0x94284000) ^ ((y << 21) & 0x14200000) ^ ((y << 28) & 0x10000000)
        y ^= (y >> 11) ^ (y >> 22)
        return y

    def setoutputs(self, outputs):
        assert len(outputs) == self.N
        self._setstate([self._untempering(o) for o in outputs] + [0])

    def clone(self):
        t = MersenneTwister()
        t._setstate(self._getstate())
        return t


if __name__ == '__main__':
    # Self-test against the interpreter's global `random` generator.  Every
    # check draws from that shared generator, so their order matters.
    # setoutputs(): 624 consecutive 32-bit outputs determine the state, after
    # which outputs of every width from 1 to 999 bits must agree.
    # NOTE(review): the getstate() comparison below also relies on these 624
    # outputs forming one aligned CPython batch, i.e. on the global generator
    # not having been used before this point.
    outputs = [random.getrandbits(32) for _ in range(624)]
    mt = MersenneTwister()
    mt.setoutputs(outputs)
    for i in range(1, 1000):
        assert random.getrandbits(i) == mt.getrandbits(i), i
    assert random.getstate() == mt.getstate()

    # seed(): the same seed yields the same floats and the same exported state.
    random.seed(100)
    mt.seed(100)
    for i in range(1000):
        assert random.random() == mt.random(), i
    assert random.getstate() == mt.getstate()

    # setstate(): importing CPython's state reproduces randint() over ranges
    # of up to 2**999.
    random.seed(random.getrandbits(32))
    mt.setstate(random.getstate())
    for i in range(1000):
        assert random.randint(0, 2**i) == mt.randint(0, 2**i), i
    assert random.getstate() == mt.getstate()

    # prev(): stepping back over 1000 outputs restores the earlier state and
    # replays the same outputs.
    mt2 = mt.clone()
    xs = [mt.getrandbits(32) for _ in range(1000)]
    for _ in range(1000):
        mt.prev()
    assert mt.getstate() == mt2.getstate()
    ys = [mt.getrandbits(32) for _ in range(1000)]
    assert xs == ys

z3 Mersenne Twister

あまり作り込んでいない。たいていの用途では https://github.com/icemonster/symbolic_mersenne_cracker/ を使えば十分そう。

from claripy import BVS, BVV, Solver, LShR, If, simplify


class SymbolicMersenneTwister:
    """MT19937 over claripy bit-vectors, for recovering state from partial outputs.

    The initial ``N`` state words are free 32-bit symbols.  Words are
    regenerated one at a time in a circular buffer (``index`` is the next word
    to regenerate).  Observed outputs become solver constraints via
    ``setoutput``; unobserved ones are stepped over with ``skip``.
    """
    N = 624
    M = 397

    UPPER_MASK = BVV(0x80000000, 32)
    LOWER_MASK = BVV(0x7fffffff, 32)
    A = BVV(0x9908b0df, 32)

    def __init__(self):
        # Unknown initial state: one free 32-bit symbol per word.
        self.state = [BVS(f'MT_{i}', 32) for i in range(self.N)]
        self.index = 0  # next word to regenerate
        self.solver = Solver()

    def _tempering(self, y):
        """Symbolic MT19937 output tempering (``LShR`` is a logical right shift)."""
        y = y ^ LShR(y, 11)
        y = y ^ ((y << 7) & BVV(0x9D2C5680, 32))
        y = y ^ ((y << 15) & BVV(0xEFC60000, 32))
        y = y ^ LShR(y, 18)
        return y

    def skip(self):
        """Advance past one output whose value is unknown.

        The word must still be regenerated symbolically, because later words
        and the next pass are computed from its new value.  Only moving the
        index would leave the stale initial symbol in the buffer, and that
        symbol would then stand both for the old word (whose low bits feed the
        previous word's generation) and for the new one.  This gives the
        solver constraints that contradict the real generator.
        """
        self.next()

    def _next_value(self):
        """Symbolically compute the regenerated value of word ``index``."""
        p, q = self.index, (self.index + 1) % self.N
        a = self.state[p] & self.UPPER_MASK
        b = self.state[q] & self.LOWER_MASK
        x = a | b

        k = (p + self.M) % self.N
        # Multiplying by A: shift right, XOR the mask when the low bit is set.
        return simplify(self.state[k] ^ LShR(x, 1) ^ If(x & 1 == 0, BVV(0, 32), self.A))

    def next(self):
        """Regenerate word ``index``, advance, and return the symbolic tempered output."""
        y = self._next_value()
        self.state[self.index] = y
        self.index = (self.index + 1) % self.N
        return self._tempering(y)

    def setoutput(self, value):
        """Advance one step and constrain that output to equal ``value``."""
        self.solver.add(self.next() == value)

    def getrand32bits(self):
        """Advance one step and return a concrete value for that output.

        Returns one solution consistent with the constraints added so far, or
        None when ``solver.eval`` yields nothing.  The value is not guaranteed
        to be unique if the observed outputs under-determine the state.
        """
        y = self.next()
        ans = self.solver.eval(y, 1)
        if len(ans) > 0:
            return ans[0]
        return None



if __name__ == '__main__':
    import random

    # Demo: reveal roughly 60% of 624 consecutive outputs of `r` (the rest are
    # skipped as unknown), then print whether each of the next 624 predictions
    # matches the real stream.  Some lines may print False when the revealed
    # outputs do not pin down the state bits a prediction depends on.
    r = random.Random()   # generator whose state is being recovered
    r2 = random.Random()  # independent coin deciding which outputs are revealed

    mt = SymbolicMersenneTwister()

    for _ in range(624):
        v = r.getrandbits(32)
        if r2.random() >= 0.4:
            mt.setoutput(v)
        else:
            mt.skip()

    for _ in range(624):
        print(r.getrandbits(32) == mt.getrand32bits())


メルセンヌツイスターを行列で表す

  • WIP

$w$ ビットの乱数列を作ることを考え、各乱数を $w$ 次元ベクトル $\vec{x}$ として表すことにする。

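メルセンヌ・ツイスタでは、$\vec{x}_{k+n}$ を、それ以前のベクトル $\vec{x}_{k+m}$, $\vec{x}_{k}$, $\vec{x}_{k+1}$ とある行列 $A$ を使って次の漸化式で求める（MT19937 では $n = 624$, $m = 397$）。実際に出力する乱数は、これに撹拌（tempering）行列 $T$ を掛けた $\vec{z}_{k+n}$ である。

$$
\vec{x}_{k+n} = \vec{x}_{k+m} \oplus (\vec{x}_k^u \mid \vec{x}_{k+1}^l) A, \qquad \vec{z}_{k+n} = \vec{x}_{k+n} T
$$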

ここで $\vec{x}_k^u \mid \vec{x}_{k+1}^l$ は、$\vec{x}_k$ の上位 $w-r$ ビットと $\vec{x}_{k+1}$ の下位 $r$ ビットを連結したものである（MT19937 では $w = 32$, $r = 31$）。また $\oplus$ は XOR を表す（XOR だが、いまは $\bmod 2$ で考えているので通常の加算としてよい）。

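$$
A = \begin{pmatrix}
0 & 1 & 0 & \cdots & 0 \\
0 & 0 & 1 & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
0 & 0 & 0 & \cdots & 1 \\
a_{w-1} & a_{w-2} & a_{w-3} & \cdots & a_{0}
\end{pmatrix}
$$

行ベクトル $\vec{x} = (x_{w-1}, \ldots, x_0)$ に右から $A$ を掛けると、$\vec{x}A$ は $\vec{x}$ を 1 ビット右シフトし、最下位ビット $x_0$ が 1 のときだけ $\vec{a} = (a_{w-1}, \ldots, a_0)$ を XOR したものになる（コードの `(x >> 1) ^ A[x & 1]`、$\vec{a}$ = `0x9908b0df` に対応）。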