ACTF 2022 | retros

#ACTF_2022

普通の問題設定が謎のrestrictionでカス問題に……

Revすると、独自VMで実行されるコード（バイトコード）を送れることがわかる。ただしAES-CBCで暗号化したものを送る必要があり、鍵は不明。

  • 幸い、復号（パディング検証）に成功したかどうかでメッセージが変わるので、Padding Oracle を使って任意の平文に対応する暗号文を偽造する Padding Oracle Encryption Attack ができる

  • 命令セットはこんな感じ

    • """
      VM:
          regs:
              0: pc
              1: memptr
              2: memptr
              3: value
              4: value
              5: global
              6: flag
          mem[32]: shuffled 32 values
          global: number
      """
      
      # VM register indices, used as instruction operands (see the register
      # table in the VM description above).
      pc = 0      # program counter
      ptr1 = 1    # memory pointer register
      ptr2 = 2    # memory pointer register
      val1 = 3    # value register
      val2 = 4    # value register
      glob = 5    # the "global" register
      flag = 6    # comparison flag
      
      
      def check_and_halt():
          """
          Encode opcode 0 (halt).

          print flag if mem is sorted
          """
          return [0]
      
      
      def add_g(idx):
          """
          Encode opcode 1: regs[idx] += global.

          Only registers 0, 1, 2, 5 and 6 are accepted as the target.
          """
          assert idx in [0, 1, 2, 5, 6]
          return [1, idx]
      
      
      def sub_g(idx):
          """
          Encode opcode 2: regs[idx] -= global.

          Only registers 0, 1, 2, 5 and 6 are accepted as the target.
          """
          assert idx in [0, 1, 2, 5, 6]
          return [2, idx]
      
      
      def mv_from_mem(reg, ptr):
          """
          Encode opcode 3 (load):

          reg = mem[ptr]
          """
          assert reg in [3, 4]  # general register
          assert ptr in [1, 2]  # mem ptr register
          return [3, reg, ptr]
      
      
      def mv_to_mem(ptr, reg):
          """
          Encode opcode 4 (store):

          mem[ptr] = reg
          """
          assert ptr in [1, 2]  # mem ptr register
          assert reg in [3, 4]  # general register
          return [4, ptr, reg]
      
      
      def set_g(val):
          """
          Encode opcode 5 with an 8-bit immediate split into high/low nibbles:

          global = val
          """
          assert 0 <= val < 256
          return [5, val >> 4, val & 0x0f]
      
      
      def set_g_if(val):
          """
          Encode opcode 6 (conditional set) with an 8-bit immediate split into
          high/low nibbles:

          if flag == 1:
              global = val
          """
          assert 0 <= val < 256
          return [6, val >> 4, val & 0x0f]
      
      
      def memcmp():
          """
          Encode opcode 7 (compare memory cells):

          if mem[ptr1] >= mem[ptr2]:
              flag = 1
          """
          return [7]
      
      
      def cmp_ge(reg):
          """
          Encode opcode 8 (compare register with global):

          if reg >= global
           flag = 1
          """
          assert reg in [0, 1, 2, 3, 4, 5, 6]  # any VM register
          return [8, reg]
      
      
      def mv_to_g(reg):
          """
          Encode opcode 9:

          global = reg

          reg may be any VM register index 0..6 (the same range cmp_ge accepts);
          anything else would not name a register or fit the 4-bit operand.
          """
          assert reg in [0, 1, 2, 3, 4, 5, 6]
          return [9, reg]
      
  • 31byte 以下のコードで、32byteのメモリをソートする必要がある

  • 頑張ってアセンブリを組み立てて、送りつける。readlineの実装がカスなので、入力に 0x0a が含まれているとそこで打ち切られてしまう。そのため、偽造した暗号文や IV に 0x0a が含まれないケースを引く必要がある（オラクルへのクエリのうち 0x0a を含むものは送らずに候補として後回しにしている）

import random
from ptrlib import Process
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad

"""
VM:
    regs:
        0: pc
        1: memptr
        2: memptr
        3: value
        4: value
        5: global
        6: flag
    mem[32]: shuffled 32 values
    global: number
"""

pc = 0
ptr1 = 1
ptr2 = 2
val1 = 3
val2 = 4
glob = 5
flag = 6


def check_and_halt():
    """Encode the halt instruction (opcode 0).

    When executed, the VM prints the flag if mem is sorted.
    """
    halt_opcode = 0
    return [halt_opcode]


def add_g(idx):
    """Encode ``regs[idx] += global`` (opcode 1).

    Only registers 0, 1, 2, 5 and 6 may be the target.
    """
    writable = (0, 1, 2, 5, 6)
    assert idx in writable
    opcode = 1
    return [opcode, idx]


def sub_g(idx):
    """Encode ``regs[idx] -= global`` (opcode 2).

    Only registers 0, 1, 2, 5 and 6 may be the target.
    """
    writable = (0, 1, 2, 5, 6)
    assert idx in writable
    opcode = 2
    return [opcode, idx]


def mv_from_mem(reg, ptr):
    """Encode the load ``reg = mem[ptr]`` (opcode 3).

    ``reg`` must be a value register (3 or 4) and ``ptr`` a memory pointer
    register (1 or 2).
    """
    value_regs = (3, 4)
    pointer_regs = (1, 2)
    assert reg in value_regs
    assert ptr in pointer_regs
    return [3, reg, ptr]


def mv_to_mem(ptr, reg):
    """Encode the store ``mem[ptr] = reg`` (opcode 4).

    ``ptr`` must be a memory pointer register (1 or 2) and ``reg`` a value
    register (3 or 4).
    """
    pointer_regs = (1, 2)
    value_regs = (3, 4)
    assert ptr in pointer_regs
    assert reg in value_regs
    return [4, ptr, reg]


def set_g(val):
    """Encode ``global = val`` (opcode 5).

    ``val`` is an 8-bit immediate, emitted as its high nibble then low nibble.
    """
    assert 0 <= val < 256
    high = val >> 4
    low = val & 0x0f
    return [5, high, low]


def set_g_if(val):
    """Encode the conditional ``if flag == 1: global = val`` (opcode 6).

    ``val`` is an 8-bit immediate, emitted as its high nibble then low nibble.
    """
    assert 0 <= val < 256
    high = val >> 4
    low = val & 0x0f
    return [6, high, low]


def memcmp():
    """Encode the memory comparison (opcode 7).

    Sets ``flag = 1`` when ``mem[ptr1] >= mem[ptr2]``.
    """
    opcode = 7
    return [opcode]


def cmp_ge(reg):
    """Encode the register comparison (opcode 8).

    Sets ``flag = 1`` when ``reg >= global``; ``reg`` is any register 0..6.
    """
    assert reg in range(7)
    return [8, reg]


def mv_to_g(reg):
    """Encode ``global = reg`` (opcode 9).

    ``reg`` may be any VM register index 0..6, the same range ``cmp_ge``
    accepts. Anything else would not name a register, and values >= 16 would
    not fit the 4-bit operand nibble the payload is packed into.
    """
    assert reg in [0, 1, 2, 3, 4, 5, 6]
    return [9, reg]

# Assembler state: the nibble stream being built, and label name -> offset.
instructions, jump_marker = [], {}


def jump1(name1):
    """Return a 5-nibble placeholder for an unconditional jump to ``name1``.

    Layout: set_g (3 nibbles) followed by add_g/sub_g on pc (2 nibbles). The
    label string sits in the first slot; the remaining slots are ``None``
    until the encoding pass fills them in.
    """
    return [name1] + [None] * 4


def jump(name1, name2):
    """Return an 8-nibble placeholder for a conditional jump.

    Control goes to ``name1`` when flag is 1, otherwise to ``name2``:
    set_g loads the distance to ``name2`` (3 nibbles), set_g_if overrides it
    with the distance to ``name1`` when flag is set (3 nibbles), then
    add_g/sub_g applies it to pc (2 nibbles).
    """
    default_target = [name2, None, None]
    flag_target = [name1, None, None]
    pc_update = [None, None]
    return default_target + flag_target + pc_update


def mark(name):
    """Bind label ``name`` to the current end of ``instructions``.

    The offset is stored in bits (4 per nibble), the same unit the encoding
    pass uses when computing pc displacements.
    """
    offset_bits = len(instructions) * 4
    jump_marker[name] = offset_bits



# Assemble the sorting program. Equivalent C-like pseudocode:
#
#   for (;;) {                                  // outer_loop_begin
#       ptr1 += 1;
#       ptr2 = 0;
#       if (ptr1 >= 32) goto check;
#       do {                                    // inner_loop_begin
#           if (!(mem[ptr1] >= mem[ptr2]))
#               swap(mem[ptr1], mem[ptr2]);     // swap
#           ptr2 += 1;                          // inner_loop_end
#       } while (!(ptr2 >= ptr1));
#   }
#   check: check_and_halt();
#
# i.e. for each ptr1, for ptr2 in 0..ptr1-1: swap if mem[ptr1] < mem[ptr2].
# jump(a, b) continues at label a when flag == 1, otherwise at label b.
# NOTE(review): this assumes ptr1 starts at 0, and that memcmp/cmp_ge clear
# flag when the comparison fails. Their docstrings only say they set
# flag = 1, so confirm against the VM.
mark("outer_loop_begin")
# ptr1 += 1
instructions += set_g(1)
instructions += add_g(ptr1)

# ptr2 = 0  (global = ptr2; ptr2 -= global)
instructions += mv_to_g(ptr2)
instructions += sub_g(ptr2)

# flag = ptr1 >= 32
# if (flag) goto check; else goto inner_loop_begin;
instructions += set_g(32)
instructions += cmp_ge(ptr1)
instructions += jump("check", "inner_loop_begin")

mark("inner_loop_begin")
# flag = mem[ptr1] >= mem[ptr2]
# already ordered -> inner_loop_end, otherwise -> swap
instructions += memcmp()
instructions += jump("inner_loop_end", "swap")

mark("swap")
# exchange mem[ptr1] and mem[ptr2] through val1/val2
instructions += mv_from_mem(val1, ptr1)
instructions += mv_from_mem(val2, ptr2)
instructions += mv_to_mem(ptr1, val2)
instructions += mv_to_mem(ptr2, val1)

mark("inner_loop_end")
# ptr2 += 1
# flag = ptr2 >= ptr1   (global = ptr1; compare ptr2 against it)
# if (flag) next outer iteration;
# else      continue inner loop;
instructions += set_g(1)
instructions += add_g(ptr2)

instructions += mv_to_g(ptr1)
instructions += cmp_ge(ptr2)
instructions += jump("outer_loop_begin", "inner_loop_begin")

mark("check")
instructions += check_and_halt()

# Resolve the jump placeholders left by jump() / jump1() into real
# set_g / set_g_if / add_g|sub_g(pc) instructions.
#
# Layouts (one list element = one 4-bit nibble):
#   jump(name1, name2): [name2 . .][name1 . .][None None]   (8 nibbles)
#                        set_g      set_g_if   add_g/sub_g pc
#   jump1(name1):       [name1 . .][None None]               (5 nibbles)
#                        set_g      add_g/sub_g pc
#
# Labels in `jump_marker` are bit offsets (nibble index * 4), so the pc at the
# end of the sequence is converted to bits before comparing or subtracting.
# Displacements are measured from the end of the whole sequence, i.e. the pc
# once add_g/sub_g has been consumed. The two targets of a jump() share one
# add_g/sub_g, so they must lie in the same direction unless one of the
# displacements is zero.
g_if_ind = set()  # indices of set_g_if slots (second label of a jump())
is_jump_forward = False  # direction for the next add_g/sub_g(pc)
first_displacement, first_name = 0, None  # set_g slot of the current jump
for i in range(len(instructions)):
    slot = instructions[i]
    if isinstance(slot, str):
        jump_to = jump_marker[slot]
        is_g_if_slot = i in g_if_ind
        if is_g_if_slot:
            end = i + 5  # set_g_if(3) + add_g/sub_g(2)
            encode = set_g_if
        elif i + 3 < len(instructions) and isinstance(instructions[i + 3], str):
            g_if_ind.add(i + 3)
            end = i + 8  # set_g(3) + set_g_if(3) + add_g/sub_g(2)
            encode = set_g
        else:
            end = i + 5  # jump1: set_g(3) + add_g/sub_g(2)
            encode = set_g

        displacement = jump_to - end * 4
        forward = displacement >= 0
        if not is_g_if_slot:
            is_jump_forward = forward
            first_displacement, first_name = displacement, slot
        elif displacement != 0:
            # The flag target decides the direction only if it actually moves
            # pc; refuse target pairs that would need opposite directions.
            if first_displacement != 0 and forward != is_jump_forward:
                raise ValueError(
                    f"jump targets {first_name!r} and {slot!r} need opposite directions"
                )
            is_jump_forward = forward

        # set_g / set_g_if assert that the magnitude fits in 8 bits.
        instructions[i:i + 3] = encode(abs(displacement))

    elif slot is None:
        instructions[i:i + 2] = add_g(pc) if is_jump_forward else sub_g(pc)

print(len(instructions), instructions)

# Disassemble the encoded stream into a human-readable listing.
# Opcode -> (instruction length in nibbles, mnemonic).
insts = [(1, "check"), (2, "add_g"), (2, "sub_g"), (3, "mv_from_mem"), (3, "mv_to_mem"), (3, "set_g"), (3, "set_g_if"), (1, "memcmp"), (2, "cmp_ge"), (2, "mv_to_g")]
listing = []
cursor = 0  # nibble index of the current instruction
while cursor < len(instructions):
    width, mnemonic = insts[instructions[cursor]]
    ops = instructions[cursor:cursor + width]
    operands = ops[1:]

    cells = [str(ops[0])]
    cells.append(str(ops[1]) if width >= 2 else "")
    cells.append(str(ops[2]) if width >= 3 else "")
    op, arg0, arg1 = (cell.rjust(3) for cell in cells)

    if mnemonic in ("set_g", "set_g_if"):
        immediate = ops[1] * 0x10 + ops[2]
        disasm = f'{mnemonic}({immediate})'.ljust(18) + f'# {ops[1]}, {ops[2]}'
    else:
        disasm = f'{mnemonic}({", ".join(str(x) for x in operands)})'
    offset = str(cursor * 4).rjust(3)  # bit offset, as used by jump labels
    listing.append(f'{offset}: {op} {arg0} {arg1} | {disasm}\n')
    cursor += width

print("".join(listing))

# Pack the nibble stream two per byte, high nibble first; an odd-length
# stream gets a trailing 0 nibble.
if len(instructions) % 2:
    instructions.append(0)
payload = [(hi << 4) | lo for hi, lo in zip(instructions[::2], instructions[1::2])]

# Make sure the program ends with a 0 byte (0 is the halt opcode).
if payload[-1]:
    payload.append(0)
print(len(payload), payload)
# 17..31 bytes PKCS#7-pads to exactly two AES blocks, which is what the
# forgery below produces.
assert 16 < len(payload) < 32
payload = pad(bytes(payload), 16)

# on remote, encode payload by using padding oracle encryption attack
# Local-testing setup only: a known key and an all-zero IV. The key is
# written to ./key, which is presumably read by the local ./retros build
# (TODO confirm). The remote key is unknown, which is why the ticket is
# forged with the padding oracle further down instead of being encrypted
# directly.
key = b"A" * 16
iv = bytes(16)
with open("./key", mode="wb") as key_file:
    key_file.write(key)

"""
aes = AES.new(mode=AES.MODE_CBC, key=key, iv=iv)
ticket = aes.encrypt(payload)

sock = Process("./retros")

input("WAIT> ")
sock.sendline(ticket)
sock.sendlineafter("fortune: ", b"\0"*0x10)
sock.interactive()
"""

import pwn
import subprocess

LOCAL = False

# Connect to the service: the local docker port when LOCAL, else the remote.
# if LOCAL: io = pwn.process('./retros')
if LOCAL:
    io = pwn.remote("localhost", 8003)
else:
    io = pwn.remote('123.60.146.157', 9999)

# The remote asks for a hashcash proof of work before talking to us.
if not LOCAL:
    print(io.recvline())
    challenge = io.recvline().split(b'`')[1].decode().split()[2]
    print(challenge)
    stamp_raw = subprocess.check_output(['hashcash', '-mb26', challenge])
    print(stamp_raw)
    stamp = stamp_raw.decode().split(' ')[-1].strip()
    print(stamp)
    io.sendline(stamp)
    print(io.recvline())
    print(io.recvline())

def _do_oracle(ct2, token, j):
    """Recover the intermediate block D(ct2) byte by byte, from index j down.

    ``token`` is the forged previous-ciphertext block. Bytes above ``j``
    already hold recovered intermediate bytes. They are XORed with the
    target padding value (16 - j) before sending, so a guess for
    ``token[j]`` passes the padding check exactly when it equals the
    intermediate byte. A reply containing b'not complete' is treated as a
    successful decryption.

    The service reads the ticket line by line, so a query containing 0x0a
    cannot be sent. Such guesses are kept as candidates and tried by
    backtracking when none of the sent guesses hits.

    Args:
        ct2: the ciphertext block being decrypted; must not contain 0x0a.
        token: 16-byte bytearray; ``token[j]`` is modified in place.
        j: index of the byte to recover (15 down to 0).

    Returns:
        (token, confident). ``confident`` is True when the byte at this
        level was confirmed by the oracle or a candidate led to a confirmed
        byte below it.

    Raises:
        ValueError: ct2 contains a newline (every query would be truncated).
        RuntimeError: no guess for byte j passed and none were skipped.
    """
    if j < 0:
        return token, False
    if b'\n' in ct2:
        raise ValueError(f'ct2 contains 0x0a, oracle queries would be truncated: {ct2!r}')

    candidate = []
    last_byte = None
    # Padding value (16 - j) over bytes j..15, zero elsewhere.
    xval = ([0] * 16 + [16 - j] * (16 - j))[-16:]
    # Pipeline all sendable queries first, then read replies in the same order.
    for i in range(256):
        token[j] = i
        send_token = pwn.xor(token, xval)
        if b'\n' in send_token:
            candidate.append(i)
            continue
        io.sendline(send_token + ct2)
        io.sendline(b'\x00' * 16)
    skipped = set(candidate)
    for i in range(256):
        if i in skipped:
            continue
        r = io.recvline()
        if b'not complete' in r:
            print(j, i, token)
            # NOTE: if several guesses pass (e.g. a 0x02 0x02 false positive
            # at j == 15) the last one wins.
            last_byte = i
    if last_byte is not None:
        token[j] = last_byte
        res, _ = _do_oracle(ct2, bytearray(token), j - 1)
        return res, True

    if not candidate:
        raise RuntimeError(f'padding oracle found no valid byte at {j=}: {bytes(token)!r}')
    # Show the suffix recovered so far; indexing token[j + 1] directly would
    # raise IndexError at j == 15.
    print(f'[+] {j=} {bytes(token[j + 1:])=} {len(candidate)=}')
    for i in candidate:
        token[j] = i
        res, confidence = _do_oracle(ct2, bytearray(token), j - 1)
        print(res)
        if confidence:
            print(f'[+] discoverd! {res=}, {confidence=}')
            return res, True
    return res, False

def do_oracle(ct2):
    """Return the intermediate block D(ct2) recovered via the padding oracle.

    Starts from an all-zero forged block and recovers bytes 15 down to 0.
    """
    recovered, _confident = _do_oracle(ct2, bytearray(16), 15)
    return recovered

# Forge a two-block ticket, plus the IV sent as the "fortune", that decrypts
# to `payload`. Work backwards from an arbitrary last ciphertext block:
#   ct2 = D(init_ct2) ^ payload[16:]   -> first ticket block
#   ct3 = D(ct2)      ^ payload[:16]   -> IV
# The service reads input line by line, so neither forged block may contain
# 0x0a. A newline in ct2 would also truncate every second-stage oracle query.
# If one appears, rerun; this presumably helps because the key differs per
# connection (TODO confirm).
init_ct2 = b'superneko'.ljust(16, b'\x00')

ct2 = do_oracle(init_ct2)
print(f'{ct2=}')
ct2 = pwn.xor(ct2, payload[16:])
if b'\n' in ct2:
    raise SystemExit(f'[-] forged block contains 0x0a, rerun: {ct2=}')
ct3 = do_oracle(ct2)
print(f'{ct3=}')
ct3 = pwn.xor(ct3, payload[:16])
if b'\n' in ct3:
    raise SystemExit(f'[-] forged IV contains 0x0a, rerun: {ct3=}')

io.send(ct2 + init_ct2 + b'\n' + ct3 + b'\n')

io.interactive()