#LINECTF2022
from present import Present
from Crypto.Util.strxor import strxor
import os, re
class CTRMode:
    """Counter (CTR) mode built on DoubleRoundReducedPresent.

    Keystream block number ``c`` is ``cipher.encrypt(nonce || c)``. The counter
    ``c`` is encoded big-endian in half a block. Encryption and decryption are
    the same XOR operation.
    """

    def __init__(self, key, nonce=None):
        """
        key:   raw key bytes passed to DoubleRoundReducedPresent
        nonce: nonce bytes prepended to the counter; when None, a random
               half-block nonce is drawn from os.urandom
        """
        self.key = key
        self.cipher = DoubleRoundReducedPresent(key)
        if nonce is None:
            nonce = os.urandom(self.cipher.block_size // 2)
        self.nonce = nonce

    def XorStream(self, data):
        """XOR ``data`` with the CTR keystream and return the result as bytes.

        The last block may be partial. Only as many keystream bytes as the
        block needs are used.
        Raises SystemExit(1) if the cipher returns b"" for a counter block.
        DoubleRoundReducedPresent does that when its input exceeds one block,
        for example when the nonce is too long.
        """
        block_size = self.cipher.block_size
        counter_len = block_size // 2
        chunks = []
        for counter, start in enumerate(range(0, len(data), block_size)):
            keystream = self.cipher.encrypt(self.nonce + counter.to_bytes(counter_len, 'big'))
            if keystream == b"":
                # Same effect as the site-provided exit(1), without depending on `site`.
                raise SystemExit(1)
            # Slicing already truncates the final partial block.
            block = data[start:start + block_size]
            chunks.append(strxor(keystream[:len(block)], block))
        return b"".join(chunks)

    def encrypt(self, plaintext):
        """Encrypt ``plaintext`` (bytes) in CTR mode."""
        return self.XorStream(plaintext)

    def decrypt(self, ciphertext):
        """Decrypt ``ciphertext`` (bytes); identical to encryption in CTR mode."""
        return self.XorStream(ciphertext)
class DoubleRoundReducedPresent:
    """Cascade of two round-reduced PRESENT ciphers: E(p) = E_k1(E_k0(p)).

    The key is split in half. The first half becomes cipher0 and the second
    half becomes cipher1. With key_length = 160 bits, each half is a 10-byte
    (80-bit) PRESENT key.
    """

    def __init__(self, key, rounds=16):
        """
        key:    raw key bytes (key_length bits; 20 bytes)
        rounds: PRESENT round count used by both halves (16 by default)
        """
        self.block_size = 8
        self.key_length = 160
        self.round = rounds
        half = self.key_length // 8 // 2  # bytes per sub-key
        self.cipher0 = Present(key[0:half], self.round)
        self.cipher1 = Present(key[half:2 * half], self.round)

    def encrypt(self, plaintext):
        """Encrypt one block of at most block_size bytes.

        Returns 8 bytes, or b"" after printing an error if the input is
        longer than one block.
        """
        if len(plaintext) > self.block_size:
            print("Error: Plaintext must be at most %d bytes per block" % self.block_size)
            return b""
        return self.cipher1.encrypt(self.cipher0.encrypt(plaintext))

    def decrypt(self, ciphertext):
        """Decrypt one block of at most block_size bytes.

        Applies the two halves in reverse order. Returns b"" after printing
        an error if the input is longer than one block.
        """
        if len(ciphertext) > self.block_size:
            print("Error: Ciphertext must be at most %d bytes per block" % self.block_size)
            return b""
        return self.cipher0.decrypt(self.cipher1.decrypt(ciphertext))
if __name__ == "__main__":
    import sys

    # keyfile.py and flag.py live in ./secret/ next to this script.
    sys.path.append(os.path.join(os.path.dirname(__file__), './secret/'))
    from keyfile import key
    from flag import flag

    # DoubleRoundReducedPresent needs exactly 20 key bytes (two 10-byte
    # PRESENT keys) drawn from '0'-'3'. Reject anything else up front instead
    # of silently truncating a long key or failing inside Present on a short one.
    if len(key) != 20 or not re.fullmatch(r'[0-3]+', key):
        sys.exit(1)
    key = key.encode('ascii')
    flag = flag.encode('ascii')

    cipher = CTRMode(key)
    ciphertext = cipher.encrypt(flag)
    print(ciphertext.hex())
    print(cipher.nonce.hex())
'''
Python 3 port of the PRESENT block cipher implementation.
The original (Python 2) code is available here:
http://www.lightweightcrypto.org/downloads/implementations/pypresent.py
'''
'''
key = bytes.fromhex("00000000000000000000")
plain = bytes.fromhex("0000000000000000")
cipher = Present(key)
encrypted = cipher.encrypt(plain)
print(encrypted.hex())
>> '5579c1387b228445'
decrypted = cipher.decrypt(encrypted)
decrypted.hex()
>> '0000000000000000'
'''
class Present:
    def __init__(self, key, rounds=32):
        """Create a PRESENT cipher object.

        key:    the key as 10 bytes (80-bit) or 16 bytes (128-bit)
        rounds: the number of rounds as an integer, 32 by default
        Raises ValueError for any other key length.
        """
        self.rounds = rounds
        if len(key) * 8 == 80:
            self.roundkeys = generateRoundkeys80(byte2number(key), self.rounds)
        elif len(key) * 8 == 128:
            self.roundkeys = generateRoundkeys128(byte2number(key), self.rounds)
        else:
            # Raising a (class, message) tuple is a TypeError in Python 3;
            # instantiate the exception so callers really get a ValueError.
            raise ValueError("Key must be a 128-bit or 80-bit rawstring")

    def encrypt(self, block):
        """Encrypt one block (8 bytes).

        Input: plaintext block as bytes (big-endian)
        Output: ciphertext block as 8 bytes
        """
        state = byte2number(block)
        # rounds-1 full rounds, then whitening with the final round key.
        for i in range(self.rounds - 1):
            state = addRoundKey(state, self.roundkeys[i])
            state = sBoxLayer(state)
            state = pLayer(state)
        cipher = addRoundKey(state, self.roundkeys[-1])
        return number2byte_N(cipher, 8)

    def decrypt(self, block):
        """Decrypt one block (8 bytes).

        Input: ciphertext block as bytes (big-endian)
        Output: plaintext block as 8 bytes
        """
        state = byte2number(block)
        # Undo the rounds in reverse order, applying round keys last-to-first.
        for i in range(self.rounds - 1):
            state = addRoundKey(state, self.roundkeys[-i - 1])
            state = pLayer_dec(state)
            state = sBoxLayer_dec(state)
        decipher = addRoundKey(state, self.roundkeys[0])
        return number2byte_N(decipher, 8)

    def get_block_size(self):
        """Return the block size in bytes (always 8)."""
        return 8
# PRESENT 4-bit S-box (encryption direction).
Sbox = [0xc, 0x5, 0x6, 0xb, 0x9, 0x0, 0xa, 0xd, 0x3, 0xe, 0xf, 0x8, 0x4, 0x7, 0x1, 0x2]
# Sorting the inputs by their S-box output gives the inverse permutation:
# entry v holds the input x with Sbox[x] == v.
Sbox_inv = sorted(range(16), key=Sbox.__getitem__)
# Bit permutation: bit i moves to position 16 * (i % 4) + i // 4.
PBox = [16 * (i % 4) + i // 4 for i in range(64)]
# Inverse bit permutation, built the same way as Sbox_inv.
PBox_inv = sorted(range(64), key=PBox.__getitem__)
def generateRoundkeys80(key, rounds):
    """Expand an 80-bit PRESENT key into ``rounds`` 64-bit round keys.

    key:    the key register as an 80-bit integer
    rounds: number of round keys to produce
    Returns a list of ints. Each is the top 64 bits of the register before
    that step's update.
    """
    keep_low19 = (1 << 19) - 1
    keep_low76 = (1 << 76) - 1
    schedule = []
    for round_counter in range(1, rounds + 1):
        schedule.append(key >> 16)
        # Rotate the 80-bit register left by 61 (= right by 19).
        key = ((key & keep_low19) << 61) + (key >> 19)
        # Pass the top nibble through the S-box.
        key = (Sbox[key >> 76] << 76) + (key & keep_low76)
        # Mix the round counter into bits 19..15.
        key ^= round_counter << 15
    return schedule
def generateRoundkeys128(key, rounds):
    """Expand a 128-bit PRESENT key into ``rounds`` 64-bit round keys.

    key:    the key register as a 128-bit integer
    rounds: number of round keys to produce
    Returns a list of ints. Each is the top 64 bits of the register before
    that step's update.
    """
    keep_low67 = (1 << 67) - 1
    keep_low120 = (1 << 120) - 1
    schedule = []
    for round_counter in range(1, rounds + 1):
        schedule.append(key >> 64)
        # Rotate the 128-bit register left by 61 (= right by 67).
        key = ((key & keep_low67) << 61) + (key >> 67)
        # Pass the two top nibbles through the S-box.
        top_nibble = Sbox[key >> 124] << 124
        next_nibble = Sbox[(key >> 120) & 0xF] << 120
        key = top_nibble + next_nibble + (key & keep_low120)
        # Mix the round counter into bits 66..62.
        key ^= round_counter << 62
    return schedule
def addRoundKey(state, roundkey):
    """Return ``state`` XOR ``roundkey`` (both integers)."""
    mixed = roundkey ^ state
    return mixed
def sBoxLayer(state):
    """Apply Sbox to each of the 16 nibbles of a 64-bit state.

    Input: 64-bit integer
    Output: 64-bit integer"""
    return sum(Sbox[(state >> shift) & 0xF] << shift for shift in range(0, 64, 4))
def sBoxLayer_dec(state):
    """Apply Sbox_inv to each of the 16 nibbles of a 64-bit state (undoes sBoxLayer).

    Input: 64-bit integer
    Output: 64-bit integer"""
    return sum(Sbox_inv[(state >> shift) & 0xF] << shift for shift in range(0, 64, 4))
def pLayer(state):
    """Move bit ``src`` of the state to bit ``PBox[src]``.

    Input: 64-bit integer
    Output: 64-bit integer"""
    return sum(((state >> src) & 1) << dst for src, dst in enumerate(PBox))
def pLayer_dec(state):
    """Move bit ``src`` of the state to bit ``PBox_inv[src]`` (undoes pLayer).

    Input: 64-bit integer
    Output: 64-bit integer"""
    return sum(((state >> src) & 1) << dst for src, dst in enumerate(PBox_inv))
def byte2number(i):
    """Interpret a bytes-like object as a big-endian unsigned integer.

    Input: bytes (big-endian)
    Output: int (0 for empty input)
    """
    return int.from_bytes(i, byteorder='big')
def number2byte_N(i, N):
    """Encode integer ``i`` as exactly ``N`` big-endian bytes.

    i: int
    N: output length in bytes
    Output: bytes (big-endian); int.to_bytes raises OverflowError if i does not fit
    """
    encoded = i.to_bytes(length=N, byteorder='big')
    return encoded
def _test():
    """Run the module's doctests through doctest.testmod()."""
    from doctest import testmod
    testmod()


if __name__ == "__main__":
    _test()
# meet-in-the-middle attack
from present import Present
from ptrlib import xor
from itertools import product
from main import CTRMode
from tqdm import tqdm
import sys

ciphertext_hex = "3201339d0fcffbd152f169ddcb8349647d8bc36a73abc4d981d3206f4b1d98468995b9b1c15dc0f0"
nonce_hex = "32e10325"
ciphertext = bytes.fromhex(ciphertext_hex)
nonce = bytes.fromhex(nonce_hex)

BLOCK_SIZE = 8          # DoubleRoundReducedPresent block size in bytes
ROUNDS = 16             # rounds used by each PRESENT half
KEY_ALPHABET = "0123"   # main.py only accepts keys matching [0-3]+
HALF_KEY_LEN = 10       # each PRESENT half takes a 10-byte (80-bit) key
KNOWN_PREFIX = b"LINECTF{"

# CTR mode: C0 = P0 xor E(nonce || 0). The flag format fixes P0, so XOR-ing it
# with the first ciphertext block recovers the first keystream block. Slice to
# exactly one block so the result does not depend on how `xor` treats
# inputs of unequal length.
keystream0 = xor(KNOWN_PREFIX, ciphertext[:BLOCK_SIZE])
counter_block = nonce + (0).to_bytes(BLOCK_SIZE // 2, "big")
n_candidates = len(KEY_ALPHABET) ** HALF_KEY_LEN


def _candidate_half_keys():
    """Yield every HALF_KEY_LEN-character key over KEY_ALPHABET, as bytes."""
    for pat in product(KEY_ALPHABET, repeat=HALF_KEY_LEN):
        yield "".join(pat).encode()


# Forward half: map each middle value E_k0(nonce || 0) to its k0.
table = {}
print("forward...")
for k0 in tqdm(_candidate_half_keys(), total=n_candidates):
    table[Present(k0, ROUNDS).encrypt(counter_block)] = k0

# Backward half: D_k1(keystream0) must equal one of the middle values.
print("backward...")
key = None
for k1 in tqdm(_candidate_half_keys(), total=n_candidates):
    middle = Present(k1, ROUNDS).decrypt(keystream0)
    if middle in table:
        print("found")
        # DoubleRoundReducedPresent uses key[0:10] first and key[10:20] second.
        key = table[middle] + k1
        break

if key is None:
    # Report failure with a non-zero exit status instead of quitting silently.
    sys.exit("key not found")

ctr = CTRMode(key, nonce)
print(ctr.decrypt(ciphertext))