async def main():
    """Run the secure LowMC demo end to end.

    Stages, each printed as a label padded to 32 characters followed by its
    wall-clock duration (time.perf_counter, 5 significant digits):
    generate inputs, read in constants, schedule keys, encrypt, decrypt.
    Finally checks that decryption reproduces the plaintext.

    Relies on names defined elsewhere in the module: mpc, sec, blocksize,
    keysize, numofboxes, rounds, initialize, to_bits, keyschedule, encrypt,
    decrypt (and the time module).
    """
    await mpc.start()
    print(
        f"LowMC with blocksize {blocksize}, keysize {keysize}, {numofboxes} s-boxes, {rounds} rounds\n"
    )

    # Stage 1: LowMC matrices/constants plus a random secret key and a random
    # plaintext, both as lists of secure bits (bits=True).
    print("generating input ...".ljust(32), end='', flush=True)
    t1 = time.perf_counter()
    linmatrices, invlinmatrices, roundconstants, keymatrices = initialize()
    key = mpc.random.getrandbits(sec, keysize, bits=True)
    inp = mpc.random.getrandbits(sec, blocksize, bits=True)
    # Open the plaintext now so it can be compared with the decryption below.
    inp0 = await mpc.output(mpc.from_bits(inp))
    t2 = time.perf_counter()
    print(f"{(t2 - t1):.5} seconds")

    # Stage 2: every constant is fed through mpc.input and opened again, and
    # index [0] keeps the copy contributed by party 0 -- presumably so that
    # all parties continue with identical constants (TODO confirm intent).
    # NOTE(review): t.value.value assumes sec is a secure field type whose
    # opened element wraps an integer-encoded value -- confirm against sec.
    # The opened integers are converted back to bit lists in place.
    print("reading in constants ...".ljust(32), end='', flush=True)
    t1 = time.perf_counter()
    for r in range(rounds):
        t = await mpc.output(mpc.input(sec(roundconstants[r]))[0])
        roundconstants[r] = to_bits(t.value.value, blocksize)
        for x in range(blocksize):
            t = await mpc.output(mpc.input(sec(linmatrices[r][x]))[0])
            linmatrices[r][x] = to_bits(t.value.value, blocksize)
            t = await mpc.output(mpc.input(sec(invlinmatrices[r][x]))[0])
            invlinmatrices[r][x] = to_bits(t.value.value, blocksize)
    # There are rounds + 1 key matrices (one per round key, incl. the initial whitening key).
    for r in range(rounds + 1):
        for x in range(keysize):
            t = await mpc.output(mpc.input(sec(keymatrices[r][x]))[0])
            keymatrices[r][x] = to_bits(t.value.value, keysize)
    t2 = time.perf_counter()
    print(f"{(t2 - t1):.5} seconds")

    # Stage 3: derive round keys, wait for each one, then left-pad it with
    # blocksize - keysize secure zeros so every round key is blocksize long.
    # NOTE(review): the padded list mixes sec(0) objects with the values
    # returned by mpc.gather; if those are local share values rather than
    # secure objects this only works because round keys are combined
    # linearly -- confirm.
    print("scheduling keys ...".ljust(32), end='', flush=True)
    t1 = time.perf_counter()
    roundkeys = keyschedule(key, keymatrices)
    for r in range(rounds + 1):
        rnd = await mpc.gather(roundkeys[r])
        roundkeys[r] = [sec(0) for _ in range(blocksize - keysize)] + rnd
    t2 = time.perf_counter()
    print(f"{(t2 - t1):.5} seconds")

    # Stage 4: encrypt. encc itself is never used; awaiting the output makes
    # the timing cover the completed encryption.
    print("encrypting ...".ljust(32), end='', flush=True)
    t1 = time.perf_counter()
    enc = encrypt(inp, roundkeys, linmatrices, roundconstants)
    encc = await mpc.output(enc)
    t2 = time.perf_counter()
    print(f"{(t2 - t1):.5} seconds")

    # Stage 5: decrypt with the inverse linear layers and open the result as a
    # single value, comparable with inp0.
    print("decrypting ...".ljust(32), end='', flush=True)
    t1 = time.perf_counter()
    dec = decrypt(enc, roundkeys, invlinmatrices, roundconstants)
    dec0 = await mpc.output(mpc.from_bits(dec))
    t2 = time.perf_counter()
    print(f"{(t2 - t1):.5} seconds")

    # Round-trip check. Being an assert, it is skipped under `python -O`.
    print("\nchecking ...".ljust(32), end='', flush=True)
    assert inp0 == dec0, f"{inp0.value.value} != {dec0.value.value}"
    print("ok!")
    await mpc.shutdown()
def quarterround(buffer, a, b, c, d):
    """Mix the words at indices a, b, c, d of buffer in place.

    Performs two passes of add / xor / rotate, updating d then b in each
    pass.  Presumably this is the ChaCha quarter-round: if rotl rotates a
    least-significant-first bit list, the amounts 16, 20, 24, 25 correspond
    to ChaCha's left rotations by 16, 12, 8, 7 (TODO confirm against rotl).
    Returns None; buffer is modified in place.
    """
    for rot_d, rot_b in ((16, 20), (24, 25)):
        buffer[a] += buffer[b]
        d_bits = to_bits_xor(buffer[d], buffer[a])
        buffer[d] = mpc.from_bits(rotl(d_bits, rot_d))

        buffer[c] += buffer[d]
        b_bits = to_bits_xor(buffer[b], buffer[c])
        buffer[b] = mpc.from_bits(rotl(b_bits, rot_b))
def sbox1(v):
    """AES inverse S-Box.

    Reverses the steps of sbox in opposite order: add the constant bit
    vector B to the bits of v, multiply by the matrix A1 (presumably the
    inverse of A), reassemble the bits and raise the result to the power 254.
    """
    shifted_bits = mpc.vector_add(mpc.to_bits(v), B)
    row = mpc.matrix_prod([shifted_bits], A1, True)[0]
    return mpc.from_bits(row)**254
def sbox(x):
    """AES S-Box.

    Raise x to the power 254 (multiplicative inverse for nonzero x in
    GF(2^8)), multiply its bit vector by the matrix A using
    mpc.matrix_prod's transposed-second-operand flag, add the constant
    vector B, and reassemble the bits into a single value.
    """
    bits = mpc.to_bits(x**254)
    bits = mpc.matrix_prod([bits], A, True)[0]
    return mpc.from_bits(mpc.vector_add(bits, B))
def test_empty_input(self):
    """Runtime helpers accept empty inputs.

    List-valued helpers must return [], sums/inner products 0, products 1,
    and from_bits of no bits 0.
    """
    secint = mpc.SecInt()
    run, check = mpc.run, self.assertEqual

    check(run(mpc.gather([])), [])
    check(run(mpc.output([])), [])
    check(mpc._reshare([]), [])
    check(mpc.convert([], None), [])

    # Reductions yield their neutral elements.
    check(mpc.sum([]), 0)
    check(mpc.prod([]), 1)
    check(mpc.in_prod([], []), 0)

    # Element-wise vector operations on empty vectors.
    check(mpc.vector_add([], []), [])
    check(mpc.vector_sub([], []), [])
    check(mpc.scalar_mul(secint(0), []), [])
    check(mpc.schur_prod([], []), [])

    check(mpc.from_bits([]), 0)
def code(key, nonce, inp):
    """Combine inp, in place, with a keystream derived from key and nonce.

    inp is processed in blocks of 16 words; the final block may be shorter.
    For block number b, two states are built with matrix(key, nonce,
    counter), where counter holds b >> 32 and b % bound (presumably
    bound == 2**32, i.e. the high and low 32-bit halves -- TODO confirm).
    One state goes through 10 doublerounds and is then added word-wise to
    the untouched copy.  Each of the block's words in inp is merged with the
    result via to_bits_xor; assuming that is a bitwise XOR, the same call
    both encrypts and decrypts.

    key must have 8 entries and nonce 2 (checked with assert).  Returns None.
    """
    assert len(key) == 8
    assert len(nonce) == 2
    length = len(inp)
    for start in range(0, length, 16):
        b = start // 16
        size = min(16, length - start)
        high, low = b >> 32, b % bound
        # Build two independent states: block is scrambled in place by
        # doubleround, while oblock keeps the initial state for the final add.
        block = matrix(key, nonce, [secint.field(high), secint.field(low)])
        oblock = matrix(key, nonce, [secint.field(high), secint.field(low)])
        for _ in range(10):
            doubleround(block)
        for i in range(size):
            block[i] += oblock[i]
            pos = start + i
            inp[pos] = mpc.from_bits(to_bits_xor(inp[pos], block[i]))
def test_secint(self):
    """Exercise secint I/O, arithmetic, bit decomposition and integer ops."""
    # NOTE(review): this chunk contains a second test_secint; if both are in
    # the same TestCase, the later definition shadows this one -- confirm.
    secint = mpc.SecInt()

    def out(x):
        # Open x and drive the resulting coroutine to completion.
        return mpc.run(mpc.output(x))

    a = secint(12)
    b = secint(13)
    self.assertEqual(out(mpc.input(a, 0)), 12)
    self.assertEqual(out(mpc.input([a, b], 0)), [12, 13])
    self.assertEqual(out(-a), -12)
    self.assertEqual(out(+a), 12)
    self.assertNotEqual(id(a), id(+a))  # NB: +a creates a copy
    self.assertEqual(out(a * b + b), 12 * 13 + 13)
    self.assertEqual(out((a * b) / b), 12)
    self.assertEqual(out((a * b) / 12), 13)
    self.assertEqual(out(a**11 * a**-6 * a**-5), 1)
    self.assertEqual(out(a**(secint.field.modulus - 1)), 1)

    # Empty bit decompositions are compared without opening them, because
    # mpc.output() only works for nonempty lists.
    self.assertEqual(mpc.to_bits(mpc.SecInt(0)(0)), [])
    self.assertEqual(out(mpc.to_bits(mpc.SecInt(1)(0))), [0])
    self.assertEqual(out(mpc.to_bits(mpc.SecInt(1)(1))), [1])
    self.assertEqual(mpc.to_bits(secint(0), 0), [])

    # (value, extra to_bits args, expected bits least significant first)
    for value, extra, expected in [
            (0, (), [0] * 32),
            (1, (), [1] + [0] * 31),
            (8113, (), [1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1] + [0] * 19),
            (2**31 - 1, (), [1] * 31 + [0]),
            (2**31 - 1, (16,), [1] * 16),
            (-1, (8,), [1] * 8),
            (-2**31, (), [0] * 31 + [1]),
            (-2**31, (16,), [0] * 16),
    ]:
        self.assertEqual(out(mpc.to_bits(secint(value), *extra)), expected)

    for value in (8113, 2**31 - 1):
        self.assertEqual(out(mpc.from_bits(mpc.to_bits(secint(value)))), value)
    # TODO: from_bits for negative numbers, e.g. a round trip of secint(-2**31).

    self.assertFalse(mpc.run(mpc.eq_public(secint(4), secint(2))))
    self.assertTrue(mpc.run(mpc.eq_public(secint(42), secint(42))))
    self.assertEqual(out(abs(secint(1))), 1)

    # Reduction modulo powers of two, including negative operands.
    for value, modulus, expected in [
            (-2**31, 2, 0), (-2**31 + 1, 2, 1), (-1, 2, 1), (0, 2, 0),
            (1, 2, 1), (2**31 - 1, 2, 1), (5, 2, 1), (-5, 2, 1),
            (50, 2, 0), (50, 4, 2), (50, 32, 18), (-50, 2, 0), (-50, 32, 14),
    ]:
        self.assertEqual(out(secint(value) % modulus), expected)
    for value, divisor, expected in [(5, 2, 2), (50, 2, 25), (50, 4, 12)]:
        self.assertEqual(out(secint(value) // divisor), expected)
    for value, expected in [(11, 88), (-11, -88)]:
        self.assertEqual(out(secint(value) << 3), expected)
    for value, expected in [(70, 17), (-70, -18)]:
        self.assertEqual(out(secint(value) >> 2), expected)
    # Reduction modulo a non-power of two.
    for value, expected in [(50, 16), (177, 7), (-50, 1), (-177, 10)]:
        self.assertEqual(out(secint(value) % 17), expected)

    self.assertEqual(out(secint(3)**0), 1)
    self.assertEqual(out(secint(3)**18), 3**18)
    self.assertIn(out(mpc.random_bit(secint)), [0, 1])
    self.assertIn(out(mpc.random_bit(secint, signed=True)), [-1, 1])
def test_secint(self):
    """Exercise secint I/O, arithmetic, bit decomposition and integer ops."""
    secint = mpc.SecInt()

    def reveal(x):
        # Open x and drive the resulting coroutine to completion.
        return mpc.run(mpc.output(x))

    a = secint(12)
    b = secint(13)
    self.assertEqual(reveal(mpc.input(a, 0)), 12)
    self.assertEqual(reveal(mpc.input([a, b], 0)), [12, 13])
    self.assertEqual(reveal(a * b + b), 12 * 13 + 13)
    self.assertEqual(reveal((a * b) / b), 12)
    self.assertEqual(reveal((a * b) / 12), 13)
    self.assertEqual(reveal(a**11 * a**(-6) * a**(-5)), 1)
    self.assertEqual(reveal(a**(secint.field.modulus - 1)), 1)
    self.assertEqual(reveal(secint(12)**73), 12**73)

    # Empty bit decompositions are compared without opening them, because
    # mpc.output() only works for non-empty lists.
    self.assertEqual(mpc.to_bits(mpc.SecInt(0)(0)), [])
    self.assertEqual(reveal(mpc.to_bits(mpc.SecInt(1)(0))), [0])
    self.assertEqual(reveal(mpc.to_bits(mpc.SecInt(1)(1))), [1])
    self.assertEqual(mpc.to_bits(secint(0), 0), [])

    # (value, extra to_bits args, expected bits least significant first)
    bit_cases = [
        (0, (), [0] * 32),
        (1, (), [1] + [0] * 31),
        (8113, (), [1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1] + [0] * 19),
        (2**31 - 1, (), [1] * 31 + [0]),
        (2**31 - 1, (16,), [1] * 16),
        (-1, (8,), [1] * 8),
        (-2**31, (), [0] * 31 + [1]),
        (-2**31, (16,), [0] * 16),
    ]
    for value, extra, expected in bit_cases:
        self.assertEqual(reveal(mpc.to_bits(secint(value), *extra)), expected)

    for value in (8113, 2**31 - 1):
        self.assertEqual(reveal(mpc.from_bits(mpc.to_bits(secint(value)))), value)
    # TODO: from_bits for negative numbers, e.g. a round trip of secint(-2**31).

    # Reduction modulo powers of two, including negative operands.
    mod_cases = [
        (-2**31, 2, 0), (-2**31 + 1, 2, 1), (-1, 2, 1), (0, 2, 0),
        (1, 2, 1), (2**31 - 1, 2, 1), (5, 2, 1), (-5, 2, 1),
        (50, 2, 0), (50, 4, 2), (50, 32, 18), (-50, 2, 0), (-50, 32, 14),
    ]
    for value, modulus, expected in mod_cases:
        self.assertEqual(reveal(secint(value) % modulus), expected)
    for value, divisor, expected in [(5, 2, 2), (50, 2, 25), (50, 4, 12)]:
        self.assertEqual(reveal(secint(value) // divisor), expected)
    for value, expected in [(11, 88), (-11, -88)]:
        self.assertEqual(reveal(secint(value) << 3), expected)
    for value, expected in [(70, 17), (-70, -18)]:
        self.assertEqual(reveal(secint(value) >> 2), expected)
    # Reduction modulo a non-power of two.
    for value, expected in [(50, 16), (177, 7), (-50, 1), (-177, 10)]:
        self.assertEqual(reveal(secint(value) % 17), expected)

    self.assertEqual(reveal(secint(3)**73), 3**73)
    self.assertIn(reveal(mpc.random_bit(secint)), [0, 1])
    self.assertIn(reveal(mpc.random_bit(secint, signed=True)), [-1, 1])