InCTF | waRSAw

#inctf2019

https://github.com/ashutosh1206/Crypton/tree/master/RSA-encryption/Attack-LSBit-Oracle-variant

Just another conventional RSA attack, isn't it?

#!/usr/bin/env python
from Crypto.PublicKey import *
from Crypto.Util.number import *
import os, sys

class Unbuffered(object):
    """File-like proxy that flushes the wrapped stream after every write.

    ``write`` and ``writelines`` forward to the wrapped stream and then call
    its ``flush()``. Every other attribute (``encoding``, ``fileno``, ...) is
    looked up on the wrapped stream via ``__getattr__``.
    """

    def __init__(self, stream):
        self.stream = stream

    def _forward_then_flush(self, method_name, payload):
        # Call the named output method on the wrapped stream, then flush it.
        getattr(self.stream, method_name)(payload)
        self.stream.flush()

    def write(self, data):
        self._forward_then_flush("write", data)

    def writelines(self, datas):
        self._forward_then_flush("writelines", datas)

    def __getattr__(self, attr):
        # Only reached for names not found on the proxy itself.
        return getattr(self.stream, attr)

# Replace stdout with a wrapper that flushes after every write, so each
# prompt and result is emitted immediately (presumably because stdout is a
# pipe/socket to the remote player, where output would otherwise sit in a
# buffer while raw_input() blocks -- confirm against the deployment).
sys.stdout = Unbuffered(sys.stdout)

def _encrypt(message, e, n):
    """Textbook-RSA encrypt a byte string (no padding).

    The message is converted to an integer with ``bytes_to_long``, raised to
    ``e`` modulo ``n``, and the result is returned as bytes via
    ``long_to_bytes``.
    """
    return long_to_bytes(pow(bytes_to_long(message), e, n))

def _decrypt(ciphertext, d, n):
    """Decrypt ``ciphertext`` but reveal only the parity of the plaintext.

    This is the challenge's intended leak: the caller receives
    ``long_to_bytes(m % 2)`` where ``m = c^d mod n``, never ``m`` itself.
    """
    plaintext = pow(bytes_to_long(ciphertext), d, n)
    parity = plaintext % 2
    return long_to_bytes(parity)

def genkey(size):
    """Generate an RSA key from two ``size // 2``-bit primes.

    Returns ``(p, q, e, d, phin, n)`` with ``e = 65537``,
    ``phin = (p - 1) * (q - 1)``, ``d = e^-1 mod phin`` and ``n = p * q``.

    Prime pairs for which ``e`` is not invertible modulo ``phin`` (or with
    ``p == q``) are discarded and redrawn, instead of aborting the session
    with an ``AssertionError`` (which ``python -O`` would also strip).
    """
    e = 65537
    # Floor division: identical result on Python 2, and still an int on 3.
    half = size // 2
    while True:
        p = getPrime(half)
        q = getPrime(half)
        phin = (p - 1) * (q - 1)
        # e is prime, so gcd(e, phin) != 1 only if e divides p-1 or q-1: rare,
        # but simply retry rather than crash.
        if p != q and GCD(e, phin) == 1:
            break
    d = inverse(e, phin)
    n = p * q
    return (p, q, e, d, phin, n)

if __name__ == "__main__":
    p, q, e, d, phin, n = genkey(1024)
    flag = open("flag").read().strip()
    print "Welcome to RSA encryption oracle!"
    print "Here take your flag (in hex): ", _encrypt(flag, e, n).encode("hex")
    print "Here take modulus: ", n
    print "RSA service"
    print "[1] Encrypt"
    print "[2] Decrypt"
    option = int(raw_input("Enter your choice: "))
    if option == 1:
        try:
            message = raw_input("Enter the message you want to encrypt (in hex): ").decode("hex")
        except:
            print "Enter proper hex chars"
            exit(0)
        ct = _encrypt(message, e, n)
        print "Here take your ciphertext (in hex): ", ct.encode("hex")
        print "\n\n"
    elif option == 2:
        try:
            ciphertext = raw_input("Enter the ciphertext you want to decrypt (in hex): ").decode("hex")
        except:
            print "Enter proper hex chars"
            exit(0)
        msg = _decrypt(ciphertext, d, n)
        print "Here take your plaintext (in hex): ", msg.encode("hex")
        print "\n\n"
    else:
        print "Enter a valid option!"
    print "Exiting..."

RSAの問題。接続するたびに異なる $n$ で暗号化されたflagの暗号文 $c$ が渡され、1回だけEncryptかDecryptができるが、Decryptでは復号結果の最下位bitしかもらえない。

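問題としてはLSB Leak Attack(LSB Decryption Oracle Attack)の様に見える。しかし通常のLSB Leak Attackは、同じ $n$ の下で $c$ に $2^e$ を繰り返し掛けて $m$ の範囲を二分探索していく手法なので、$n$ が毎回違うこの問題ではうまく動かない。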

あとは素朴に $2^{-xe} \bmod N$ を計算しておいて $c \cdot 2^{-xe}$ を復号し、$(c \cdot 2^{-xe})^d \equiv m \cdot 2^{-x} \pmod N$ の最下位bitを得る方法が思いつく。これについて考える。

#### ちゃんと考える

まず、$k$ bitの $m$ を $m = 2^{k-1}a_{k-1} + \dots + 2^2a_2 + 2^1a_1 + a_0$ と書く。$m < N$ なので、単にこれの mod 2 をとってみると $m \equiv a_0 \pmod 2$ となることは明らか。

では次に、$m \cdot 2^{-1} \bmod N$ の mod 2 をとってみる。

$$m \cdot 2^{-1} \equiv \{2(2^{k-2}a_{k-1}+\cdots+2^1a_2+a_1)+a_0\} \cdot 2^{-1} \equiv (2^{k-2}a_{k-1}+\cdots+2^1a_2+a_1) + 2^{-1}a_0 \pmod N$$

右辺第1項は $\lfloor m/2 \rfloor$ であり、その mod 2 は $a_1$ である。

(ここで $2^{-1}$ は mod $N$ での逆元 $(N+1)/2$ のことであり、整数の $1/2$ ではないので消えない)

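このように $a_0$ の項が残るが、$a_0$ は既知なので、$(m \cdot 2^{-1} \bmod N) - (a_0 \cdot 2^{-1} \bmod N) \equiv \lfloor m/2 \rfloor \pmod N$ となる。

$m \ll N$ なら $\lfloor m/2 \rfloor + (a_0 \cdot 2^{-1} \bmod N) < N$ がほぼ確実に成り立つ。したがってこの差は整数として $\lfloor m/2 \rfloor$ に一致し、その偶奇はオラクルが返す最下位bitと $a_0 \cdot 2^{-1} \bmod N$ の偶奇だけで決まる。よって $a_1$ が計算できる。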

同様に、

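$m \cdot 2^{-2} \equiv \lfloor m/4 \rfloor + (2a_1 + a_0) \cdot 2^{-2} \pmod N$ となるので、$a_2 \equiv (m \cdot 2^{-2} \bmod N) - ((2a_1 + a_0) \cdot 2^{-2} \bmod N) \pmod 2$ より

$a_2$ が求まる。一般に、既知の下位 $i$ bit $r = m \bmod 2^i$ を使えば $a_i \equiv (m \cdot 2^{-i} \bmod N) - (r \cdot 2^{-i} \bmod N) \pmod 2$ で $i$ bit目が求まる。これは $N$ が毎回変わっても成り立つ。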

以下はローカルで手法を確認するためのサンプルコード。1bit求めるごとに鍵を作り直して、毎回 $n$ が変わるリモートの挙動を模倣している。

#!/usr/bin/env python
from Crypto.PublicKey import *
from Crypto.Util.number import *


def genkey(size):
    """Generate a fresh RSA key pair for the local simulation.

    Two ``size // 2``-bit primes are drawn with ``getPrime``. Pairs for which
    ``e = 202001`` is not invertible modulo ``(p - 1) * (q - 1)`` (or with
    ``p == q``) are discarded and redrawn instead of aborting on an
    ``assert``.

    Returns ``((e, n), d)``: the public key and the private exponent.
    """
    e = 202001
    while True:
        p = getPrime(size // 2)
        q = getPrime(size // 2)
        phin = (p - 1) * (q - 1)
        # Rare, but the driver calls genkey once per recovered bit, so an
        # assert here would occasionally abort an otherwise fine run.
        if p != q and GCD(e, phin) == 1:
            break
    d = inverse(e, phin)
    n = p * q
    return (e, n), d


def encrypt(m, e, n):
    """Return the textbook-RSA ciphertext ``m**e mod n`` (integer in, integer out)."""
    return pow(m, e, n)


def decrypt(c, d, n):
    """Simulated LSB oracle: decrypt ``c`` and return only its parity.

    Mirrors the challenge server, which reveals ``(c**d mod n) % 2`` and
    nothing else.
    """
    plaintext = pow(c, d, n)
    return plaintext & 1


# Initial parameters: a random 100-bit secret and a first key pair.
m = getRandomNBitInteger(100)
(e, n), d = genkey(1024)
c = encrypt(m, e, n)

# Least significant bit: decrypting c itself leaks m mod 2 directly.
m2 = decrypt(c, d, n)
bits = str(m2)
i = 1

# The simulation knows m, so it uses m's bit length as the stop condition.
while i < m.bit_length():
    if GCD(2 ** i, n) == 1:
        # Query the oracle on Enc(m * 2^{-i} mod n) under the current key.
        inv2 = inverse(2 ** i, n)   # 2^{-i} mod n
        inv = pow(inv2, e, n)       # Enc(2^{-i})
        m2 = decrypt((inv * c) % n, d, n)
        # With r = int(bits, 2) the known low bits:
        #   (m * 2^{-i} mod n) - (r * 2^{-i} mod n) == floor(m / 2^i)
        # (barring a negligible wrap past n since m << n), so bit i is the
        # parity of that difference. Reduce r * 2^{-i} mod n first and take
        # % 2 of the possibly-negative difference. Reducing the difference
        # mod the odd n instead flips the parity only when r != 0, which
        # makes any `1 - bit` compensation wrong for an even m.
        m2 = (m2 - (int(bits, 2) * inv2) % n) % 2
        bits = str(m2) + bits

        i += 1

    # Fresh key for every query, mimicking the remote service.
    (e, n), d = genkey(1024)
    c = encrypt(m, e, n)

assert m == int(bits, 2)
from ptrlib.pwn.sock import Socket
from Crypto.Util.number import *
from binascii import hexlify, unhexlify
from logging import getLogger, WARN

# Set the "ptrlib.pwn" logger's threshold just above WARNING, so only records
# with a higher level (e.g. ERROR) are shown. The script below opens a new
# connection for every recovered bit; this presumably hides ptrlib's
# per-connection log lines (verify against the installed ptrlib version).
getLogger("ptrlib.pwn").setLevel(WARN + 1)


def getparam():
    """Open a new connection and read the two values the server prints first.

    Returns ``(sock, c, n)``:
      * the still-open socket, positioned just after the modulus line;
      * the flag ciphertext (sent as hex) as an integer;
      * this connection's RSA modulus.
    """
    sock = Socket("localhost", 8888)

    def read_value():
        # Each value follows a "<label>:  " prefix (label ": " plus the space
        # Python 2's comma-separated print inserts before the value).
        sock.recvuntil(":  ")
        return sock.recvline()

    flag_hex = read_value()
    c = int.from_bytes(unhexlify(flag_hex), "big")
    n = int(read_value())
    return sock, c, n


def decrypt(sock, c):
    """Spend this connection's single query on decrypting ``c``.

    Selects menu option 2, sends ``c`` as hex, parses the hex reply into an
    integer (the leaked least significant bit), then closes ``sock``.
    """
    sock.recvuntil(": ")   # "Enter your choice: "
    sock.sendline("2")     # [2] Decrypt
    sock.recvuntil(": ")   # ciphertext prompt
    payload = hexlify(long_to_bytes(c))
    sock.sendline(payload)
    sock.recvuntil(":  ")
    reply = sock.recvline()
    leaked = bytes_to_long(unhexlify(reply))
    sock.close()

    return leaked


e = 65537  # public exponent hard-coded in the challenge server's genkey()

# Bit 0: decrypting the untouched flag ciphertext leaks m mod 2 directly.
sock, c, n = getparam()
m = decrypt(sock, c)
bits = str(m)
print(bits)

i = 1
while True:
    # Each connection has a fresh key and allows one oracle query, so
    # reconnect for every bit and redo the arithmetic under the new n.
    sock, c, n = getparam()
    if GCD(2 ** i, n) != 1:
        # An honest RSA modulus is odd, so this is purely defensive.
        sock.close()
        continue
    inv2 = inverse(2 ** i, n)  # 2^{-i} mod n
    inv = pow(inv2, e, n)      # Enc(2^{-i}): inv * c encrypts m * 2^{-i}
    m = decrypt(sock, (inv * c) % n)
    # With r = int(bits, 2) the known low bits:
    #   (m * 2^{-i} mod n) - (r * 2^{-i} mod n) == floor(m / 2^i)
    # (barring a negligible wrap past n), so bit i is the parity of the
    # difference; Python's % 2 handles the possibly-negative value.
    m = (m - (int(bits, 2) * inv2) % n) % 2
    bits = str(m) + bits
    if len(bits) % 8 == 0:
        print(long_to_bytes(int(bits, 2)))
        # Past the top of the flag every recovered bit is 0. A printable flag
        # never contains a NUL byte, so a whole zero byte means we are done.
        if bits.startswith("0" * 8):
            break
    i += 1