技術めいた何か

社会人になってしまった

SECCON CTF 2021 Writeup(Crypto: cerberus)

チームlazy_oracleで参加しました。(一人チーム)
ブロック暗号問があったら解きたいなと思っていたところ、cerberusが該当問だったので解いてみました。

---追記(2021-12-13): 作問者のkurenaifさんからコメントを頂きました。

---追記終わり

まず、チャレンジャーには以下の情報が与えれます。

ソースコードを読むと、以下のことがわかりました。

  • 復号オラクルの実装である
    • 復号した平文は返さない
    • パディングチェックの結果は返す
  • 暗号アルゴリズムはAES-128-PCBC

詳細を見ると、
c = EncAES-128-PCBC(flag, key, iv)
のように鍵長128bitのAESのPCBCモードでflagが暗号化されているようです。 また、クライアントから受け付けたivと暗号文を使用してサーバーは復号処理を行い、最終ブロックのパディングチェックの結果を返します。 サーバーから与えられるパラメータは

  • 暗号文c
  • iv

また、サーバーのソースコードを見てみると入力に以下の条件があることがわかります。

  1. クライアントから受けとった暗号文がもとの暗号文と前方一致しているかチェック
  2. 復号した平文に対してパッディングチェックをし、成否を出力

つまり、今回自由に操作できるのは以下の2つであることがわかります。

  • iv
  • もとの暗号文c以降に任意のバイナリを連結すること

条件1,2の該当部分のソースコード

while True:
    c = base64.b64decode(input("spell:"))
    iv = c[:16]
    c = c[16:]

    if not c.startswith(ref_c):
        print("Grrrrrrr!!!!")
        continue

    m = decrypt(iv, c)

    try:
        unpad(m, block_size)
    except:
        print("little different :(")
        continue

    print("Great :)")

そして、条件2からPadding Oracle攻撃が狙えそうです。 ただし、PCBCモードは下の図のように2ブロック目以降は暗号文ブロックと平文ブロックが後ろのブロックの復号に影響を与えます。 なので、Padding Oracleをするために任意の入力を作るのは工夫が必要です。

f:id:atofaer:20211212145849p:plain
PCBC modeの復号処理

ここで、1ブロック目に着目します。 1ブロック目であればIVを制御するだけでPadding Oracle攻撃ができそうです。 ただし、条件1の制限があるため、オリジナルの暗号文cのうち1ブロック目だけを素朴に切り出してサーバーに入力することはできません。 なので、2ブロック目以降を打ち消す方法を考えます。

ここで、排他的論理和の性質を思い出します。 排他的論理和の真理値表は次のとおりです。 排他的論理和はA==Bのときに0を出力します。 つまり、同じ値を持つもの同士を打ち消す性質があります。

A B A xor B
0 0 0
0 1 1
1 0 1
1 1 0

つまり、以下の図のように暗号文cに2ブロック以降の暗号文C3, C2, C1を連結すれば2ブロック目以降を打ち消して1ブロック目の出力P0の値のみを最終ブロックにもっていけそうです。

f:id:atofaer:20211212151142p:plain
3ブロック以降を打ち消す

しかし、このままだとC0とC1の影響が残っているので、ivを操作して打ち消します。

f:id:atofaer:20211212162250p:plain
C0, C1を打ち消す

ここまでくれば、ivを操作してPadding Oracle攻撃で1ブロック目のDec(C0)(図赤丸部分)を求めて
P0 = Dec(C0) xor iv とすることで、平文P0を復元できます。

f:id:atofaer:20211212153450p:plain
Padding OracleでDec(C0)を復元する(赤丸部分)

また、2ブロック目は復元したDec(C0)を使い同じ方針で3ブロック目以降を打ち消した後にPadding Oracle攻撃でDec(C1)を復元します。
その後、P1 = Dec(C1) xor P0 xor C0
とすることで平文P1を復元できます。
3ブロック目以降も同様です。

f:id:atofaer:20211212162337p:plain
P1, P2, P3の復元

Solverは以下のとおりです。
SECCON{v._.^v-_-v^._.^_S0und_oF_0rpHeUs_Aha~~}

import socket
import base64
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from Crypto.Util.strxor import strxor

block_size = 16

#----------
# Netcat class: ref https://scrapbox.io/progfay-pub/netcat.py
class Netcat:
    """ Python 'netcat like' module """

    def __init__(self, ip, port):
        self.buff = ""
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.connect((ip, port))

    def read(self, length = 1024):
        """ Read 1024 bytes off the socket """
        return self.socket.recv(length)

    def write(self, data):
        self.socket.sendall(((str(data)+'\n')).encode('utf-8'))

    def close(self):
        self.socket.close()

#----------



def unpack_ct(b64string):
    a = base64.b64decode(b64string.encode('utf-8'))
    iv = a[:16]
    c = a[16:]
    return iv, c

def pack_ct(iv, c):
    return base64.b64encode(iv+c).decode('utf-8')

def recover_plaintext_block(iv, c, mask0=b"\x00"*block_size, mask1=b"\x00"*block_size):
    attack_query = c
    attack_query_iv = iv


    iv0 = b'\x00'*block_size
    iv0 = [iv0[i : i + 1] for i in range(0, len(iv0))]
    state = [None]*block_size
    for b in range(block_size):
        for j in range(b):
            # ターゲットより下位バイトをPadding
            iv0[block_size-1-j] = strxor(state[15-j], bytes([b+1]))
        for i in range(256):
            iv0[block_size-1-b] = bytes([i])
            iv0_t = strxor(b"".join(iv0), attack_query_iv)
            nc.write(pack_ct(iv0_t, attack_query))
            lines = nc.read().decode('utf-8').split('\n')
            if -1!=lines[0].find("Great"):
                print('found!')
                state[15-b] = strxor(bytes([i]), bytes([b+1]))
                break
            if 255==i:
                print('cannot find')
                exit(-1)
    plaintext=strxor(b"".join(state), mask0)
    plaintext=strxor(plaintext, mask1)
    return plaintext, b"".join(state)

nc = Netcat('cerberus.quals.seccon.jp', 8080)

lines = nc.read().decode('utf-8').split('\n')
for line in lines:
    print(line)
iv, c = unpack_ct(lines[-2])

num_of_block = len(c)//block_size
print("num of block: %d" %num_of_block)

print("recover p0...")
attack_query_iv = strxor(c[:16], c[16:32])
attack_query_ciphertext = c+c[48:64]+c[32:48]+c[16:32]
p0, state_c0 = recover_plaintext_block(attack_query_iv, attack_query_ciphertext, mask0=iv)
print("p0 is:")
print(p0)
print("---")

print("recover p1...")
attack_query_ciphertext = c+c[48:64]+c[32:48]
attack_query_iv = strxor(c[:16], c[16:32])
attack_query_iv = strxor(attack_query_iv, c[32:48])
attack_query_iv = strxor(attack_query_iv, state_c0)
p1, state_c1 = recover_plaintext_block(attack_query_iv, attack_query_ciphertext, mask0=c[0:16], mask1=p0)
print("p1 is:")
print(p1)
print("---")

print("recover p2...")
attack_query_ciphertext = c+c[48:64]
attack_query_iv = strxor(c[:16], c[16:32])
attack_query_iv = strxor(attack_query_iv, c[32:48])
attack_query_iv = strxor(attack_query_iv, c[48:64])
attack_query_iv = strxor(attack_query_iv, state_c0)
attack_query_iv = strxor(attack_query_iv, state_c1)
p2, state_c2 = recover_plaintext_block(attack_query_iv, attack_query_ciphertext, mask0=c[16:32], mask1=p1)
print("p2 is:")
print(p2)
print("---")

print("recover p3...")
attack_query_ciphertext = c
attack_query_iv = strxor(c[:16], c[16:32])
attack_query_iv = strxor(attack_query_iv, c[32:48])
attack_query_iv = strxor(attack_query_iv, state_c0)
attack_query_iv = strxor(attack_query_iv, state_c1)
attack_query_iv = strxor(attack_query_iv, state_c2)
p3, _ = recover_plaintext_block(attack_query_iv, attack_query_ciphertext, mask0=c[32:48], mask1=p2)
print("p3 is:")
print(p3)
print("---")

p_all = p0+p1+p2+p3
flag = unpad(p_all, block_size)[16:].decode('utf-8')
print(flag)