def main():
    """Search single-step MD4 inputs that reproduce a known output difference.

    ``k1``/``k2`` are two fixed 512-bit messages (hex).  Both are run through
    ``_md4.md4`` as constants; element 0 of their per-step values must match
    and element 1 must differ.  The XOR of the two element-1 values is the
    target difference ``x_new_d``.

    The solver then searches for message words ``new_x_1 < new_x_2`` that,
    pushed through one ``_md4.md4_round`` call (round function
    ``_md4.md4r1``) from identical input states, give outputs whose XOR is
    ``x_new_d``.  Each distinct ``new_x_1 ^ new_x_2`` found is printed in
    binary and then excluded, until the model becomes unsatisfiable.
    """
    mod = cmsh.Model()

    k1 = "839c7a4d7a92cb5678a5d5b9eea5a7573c8a74deb366c3dc20a083b69f5d2a3bb3719dc69891e9f95e809fd7e8b23ba6318edd45e51fe39708bf9427e9c3e8b9"
    k2 = "839c7a4d7a92cbd678a5d529eea5a7573c8a74deb366c3dc20a083b69f5d2a3bb3719dc69891e9f95e809fd7e8b23ba6318edc45e51fe39708bf9427e9c3e8b9"

    # Each message as 32-bit constant words (8 hex digits per word).
    blocks_value_1 = [
        mod.to_vector(word, width=32) for word in split_hex(k1, 8)
    ]
    blocks_value_2 = [
        mod.to_vector(word, width=32) for word in split_hex(k2, 8)
    ]

    # Only the per-step values are needed; the final results are discarded.
    _, r_rounds_1 = _md4.md4(mod, blocks_value_1)
    _, r_rounds_2 = _md4.md4(mod, blocks_value_2)

    assert r_rounds_1[0] == r_rounds_2[0]
    new_d_1, new_d_2 = r_rounds_1[1], r_rounds_2[1]
    x_new_d = new_d_1 ^ new_d_2
    assert new_d_1 != new_d_2

    # Free message words and free input states for the two single-step runs.
    new_x_1, new_x_2 = mod.vec(32), mod.vec(32)
    i_a, i_b, i_c, i_d = mod.vec(32), mod.vec(32), mod.vec(32), mod.vec(32)
    j_a, j_b, j_c, j_d = mod.vec(32), mod.vec(32), mod.vec(32), mod.vec(32)

    # The state is handed to md4_round rotated as (d, a, b, c).  Element [1]
    # of md4_round's return value is taken as the step output (meaning of the
    # trailing 0 and 7 arguments is defined by _md4 -- TODO confirm).
    l_out_1 = _md4.md4_round([new_x_1], _md4.md4r1, [i_d, i_a, i_b, i_c], 0,
                             7)[1]
    l_out_2 = _md4.md4_round([new_x_2], _md4.md4r1, [j_d, j_a, j_b, j_c], 0,
                             7)[1]

    # Distinct, ordered message words; identical input states; output XOR
    # must equal the known difference.
    mod.add_assert(new_x_1 < new_x_2)
    mod.add_assert(i_a == j_a)
    mod.add_assert(i_b == j_b)
    mod.add_assert(i_c == j_c)
    mod.add_assert(i_d == j_d)
    mod.add_assert((l_out_1 ^ l_out_2) == (x_new_d))

    sat = mod.solve()
    assert sat
    while sat:
        # The first column is always 0b0 because i_a == j_a is asserted; it
        # is printed as a sanity check.
        print(
            bin(int(i_a) ^ int(j_a)),
            bin(int(new_x_1) ^ int(new_x_2)),
            bin(int(x_new_d)),
        )
        # Exclude this message-word difference so the next solve finds a
        # new one.
        mod.add_assert(mod.negate_solution(new_x_1 ^ new_x_2))
        sat = mod.solve()
예제 #2
0
def main():
    model = cmsh.Model()

    k1 = "839c7a4d7a92cb5678a5d5b9eea5a7573c8a74deb366c3dc20a083b69f5d2a3bb3719dc69891e9f95e809fd7e8b23ba6318edd45e51fe39708bf9427e9c3e8b9"
    k2 = "839c7a4d7a92cbd678a5d529eea5a7573c8a74deb366c3dc20a083b69f5d2a3bb3719dc69891e9f95e809fd7e8b23ba6318edc45e51fe39708bf9427e9c3e8b9"

    blocks_value_1 = list(
        map(lambda x: model.to_vector(x, width=32), split_hex(k1, 8)))
    blocks_value_2 = list(
        map(lambda x: model.to_vector(x, width=32), split_hex(k2, 8)))

    bv1 = model.join_vec(blocks_value_1)
    bv2 = model.join_vec(blocks_value_2)

    r_value_1, r_rounds_1 = _md4.md4(model, blocks_value_1)
    r_value_2, r_rounds_2 = _md4.md4(model, blocks_value_2)

    blocks_1 = gen_blocks(model, 16, 32)
    blocks_2 = gen_blocks(model, 16, 32)

    b1 = model.join_vec(blocks_1)
    b2 = model.join_vec(blocks_2)

    iv_1_arr = gen_blocks(model, 4, 32)
    iv_2_arr = gen_blocks(model, 4, 32)
    iv_1 = model.join_vec(iv_1_arr)
    iv_2 = model.join_vec(iv_2_arr)

    # The blocks must differ...
    model.add_assert(b1 < b2)
    model.add_assert((b1 ^ b2) == (bv1 ^ bv2))

    # But their MD4s must be the same
    result_1_arr, rounds_1_arr = _md4.md4(model, blocks_1)
    result_2_arr, rounds_2_arr = _md4.md4(model, blocks_2)

    result_1 = model.join_vec(result_1_arr)
    rounds_1 = model.join_vec(rounds_1_arr)
    result_2 = model.join_vec(result_2_arr)
    rounds_2 = model.join_vec(rounds_2_arr)

    model.add_assert(iv_1 == iv_2)
    model.add_assert(result_1 == result_2)

    rr1 = model.join_vec(r_rounds_1)
    rr2 = model.join_vec(r_rounds_2)

    real_diff_path = rr1 ^ rr2
    differential = rounds_1 ^ rounds_2

    diff_path = differential == real_diff_path

    model.add_assume(diff_path)

    print(model.variables)
    print(len(model.clauses))

    assert model.solve()
    print(model.sat)

    solutions = []

    while len(solutions) < 32 and model.sat:
        left = (int(iv_1), int(b1))
        right = (int(iv_2), int(b2))

        print((left, right))

        real_diff_path = int(differential)
        solutions.append((left, right))

        try:
            model.remove_assume(diff_path)
            model.add_assert(differential != real_diff_path)
        except Exception as e:
            print(e)

        new_diff_path = (differential
                         ^ real_diff_path).bit_sum() <= len(solutions)
        model.add_assume(new_diff_path)
        diff_path = new_diff_path

        model.solve()

    print(solutions)
예제 #3
0
 def compute(self, model, block, iv=None, rounds=None):
     if iv is None:
         iv = self.default_state
     if rounds is None:
         rounds = self.rounds
     return _md4.md4(model, block, iv=iv, rounds=rounds)
예제 #4
0
def main():
    mod = cmsh.Model()

    k1 = "839c7a4d7a92cb5678a5d5b9eea5a7573c8a74deb366c3dc20a083b69f5d2a3bb3719dc69891e9f95e809fd7e8b23ba6318edd45e51fe39708bf9427e9c3e8b9"
    k2 = "839c7a4d7a92cbd678a5d529eea5a7573c8a74deb366c3dc20a083b69f5d2a3bb3719dc69891e9f95e809fd7e8b23ba6318edc45e51fe39708bf9427e9c3e8b9"

    known_shaped_1 = list(
        map(lambda x: mod.to_vector(x, width=32), split_hex(k1, 8)))
    known_shaped_2 = list(
        map(lambda x: mod.to_vector(x, width=32), split_hex(k2, 8)))
    known_result_shaped_1, known_rounds_shaped_1 = _md4.md4(
        mod, known_shaped_1)
    known_result_shaped_2, known_rounds_shaped_2 = _md4.md4(
        mod, known_shaped_2)

    known_1 = mod.join_vec(known_shaped_1)
    known_2 = mod.join_vec(known_shaped_2)
    known_result_1 = mod.join_vec(known_result_shaped_1)
    known_rounds_1 = mod.join_vec(known_rounds_shaped_1)
    known_result_2 = mod.join_vec(known_result_shaped_2)
    known_rounds_2 = mod.join_vec(known_rounds_shaped_2)

    block_1 = mod.vec(512)
    block_2 = mod.vec(512)
    block_shaped_1 = mod.split_vec(block_1, 32)
    block_shaped_2 = mod.split_vec(block_2, 32)

    result_shaped_1, rounds_shaped_1 = _md4.md4(mod, block_shaped_1)
    result_shaped_2, rounds_shaped_2 = _md4.md4(mod, block_shaped_2)

    result_1 = mod.join_vec(result_shaped_1)
    rounds_1 = mod.join_vec(rounds_shaped_1)
    result_2 = mod.join_vec(result_shaped_2)
    rounds_2 = mod.join_vec(rounds_shaped_2)

    mod.add_assert(block_1 != block_2)
    mod.add_assert(result_1 == result_2)
    dpath = (rounds_1 ^ rounds_2) == (known_rounds_1 ^ known_rounds_2)
    mod.add_assert(dpath)

    # for index, var in enumerate(block_1[:32]):
    #     mod.add_assume(var == known_1[index])
    # for index, var in enumerate(block_1[64:], 64):
    #     mod.add_assume(var == known_1[index])
    # for index, var in enumerate(block_2[:32]):
    #     mod.add_assume(var == known_2[index])
    # for index, var in enumerate(block_2[64:], 64):
    #     mod.add_assume(var == known_2[index])

    # mod.solve()

    # print(hex(int(block_1)) == hex(int(known_1)))
    # print(hex(int(block_2)) == hex(int(known_2)))

    # for index, var in enumerate(block_1[:32]):
    #     mod.remove_assume(var == known_1[index])
    # for index, var in enumerate(block_1[64:], 64):
    #     mod.remove_assume(var == known_1[index])
    # for index, var in enumerate(block_2[:32]):
    #     mod.remove_assume(var == known_2[index])
    # for index, var in enumerate(block_2[64:], 64):
    #     mod.remove_assume(var == known_2[index])

    # mod.remove_assume(dpath)
    # mod.add_assert((rounds_1[64:] ^ rounds_2[64:]) == (known_rounds_1[64:] ^ known_rounds_2[64:]))
    mod.add_assert((block_1 != known_1)
                   | (block_2 != known_2)
                   | (block_1 != known_2)
                   | (block_2 != known_1))

    mod.solve()
    print(hex(int(block_1)))
    print(hex(int(block_2)))
예제 #5
0
dps = []

for data in datas:
    left, right = data
    l_iv, l_bi = left
    r_iv, r_bi = right

    l_b = model.to_vec(l_bi, width=512)
    r_b = model.to_vec(r_bi, width=512)

    l_ba = model.split_vec(l_b, 32)
    r_ba = model.split_vec(r_b, 32)

    db = l_b ^ r_b

    l_r, l_rsa = m.md4(model, l_ba)
    r_r, r_rsa = m.md4(model, r_ba)

    l_rs = model.join_vec(l_rsa)
    r_rs = model.join_vec(r_rsa)
    dp = l_rs ^ r_rs

    lv = model.join_vec(l_r)
    rv = model.join_vec(r_r)

    assert int(lv) == int(rv)

    dps.append(int(dp))

print(len(dps))
for ii in range(0, len(dps)):