def test_inv_sub_bytes(self):
    """Check self.aes.inv_sub_bytes against two hard-coded input/output vectors."""
    stimulus = [0x3e175076b61c04678dfc2295f6a8bfc0,
                0x2dfb02343f6d12dd09337ec75b36e3f0]
    expected = [0xd1876c0f79c4300ab45594add66ff41f,
                0xfa636a2825b339c940668a3157244d17]
    self.out_vector <<= self.aes.inv_sub_bytes(self.in_vector)
    simulated = testingutils.sim_and_ret_out(
        self.out_vector, (self.in_vector,), (stimulus,))
    self.assertEqual(simulated, expected)
def test_inv_mix_columns(self):
    """Check self.aes.inv_mix_columns against two hard-coded vectors."""
    self.out_vector <<= self.aes.inv_mix_columns(self.in_vector)
    stimulus = [
        0xe9f74eec023020f61bf2ccf2353c21c7,
        0xbaa03de7a1f9b56ed5512cba5f414d23,
    ]
    expected = [
        0x54d990a16ba09ab596bbf40ea111702f,
        0x3e1c22c0b6fcbf768da85067f6170495,
    ]
    outputs = testingutils.sim_and_ret_out(
        self.out_vector, (self.in_vector,), (stimulus,))
    self.assertEqual(outputs, expected)
def test_inv_shift_rows(self):
    """Check self.aes.inv_shift_rows against two hard-coded vectors."""
    circuit = self.aes.inv_shift_rows(self.in_vector)
    self.out_vector <<= circuit
    inputs = [0x3e1c22c0b6fcbf768da85067f6170495,
              0x2d6d7ef03f33e334093602dd5bfb12c7]
    wanted = [0x3e175076b61c04678dfc2295f6a8bfc0,
              0x2dfb02343f6d12dd09337ec75b36e3f0]
    got = testingutils.sim_and_ret_out(self.out_vector, (self.in_vector,), (inputs,))
    self.assertEqual(got, wanted)
def bitfield_update_checker(self, input_width, range_start, range_end, update_width,
                            test_amt=20):
    """Check pyrtl.bitfield_update against a pure-Python reference model.

    Drives random values onto an ``input_width``-bit input and an
    ``update_width``-bit update wire, overwrites bits
    ``[range_start:range_end]`` of the input with the update, and compares
    the simulated output with the reference on every cycle.

    :param input_width: bitwidth of the wire being updated (and of the output)
    :param range_start: start of the bit range; Python slice semantics
        (``None`` and negative, end-relative values are accepted)
    :param range_end: exclusive end of the bit range; Python slice semantics
    :param update_width: bitwidth of the replacement value
    :param test_amt: number of random test values (simulated cycles)
    """
    def ref(i, s, e, u):
        # bitfield_update slices its wire, so normalise the bounds the same
        # way a slice over input_width bits would (handles None, negative and
        # past-the-end indices, which raw shifts would mishandle or crash on).
        s, e, _ = slice(s, e).indices(input_width)
        e = max(e, s)  # a reversed/empty slice selects no bits
        mask = ((1 << e) - 1) - ((1 << s) - 1)
        return (i & ~mask) | ((u << s) & mask)

    inp, inp_vals = utils.an_input_and_vals(input_width, test_vals=test_amt, name='inp')
    upd, upd_vals = utils.an_input_and_vals(update_width, test_vals=test_amt, name='upd')
    out = pyrtl.Output(input_width, "out")
    bfu_out = pyrtl.bitfield_update(inp, range_start, range_end, upd)
    self.assertEqual(len(out), len(bfu_out))  # output should have width of input
    out <<= bfu_out
    true_result = [ref(i, range_start, range_end, u) for i, u in zip(inp_vals, upd_vals)]
    upd_result = utils.sim_and_ret_out(out, [inp, upd], [inp_vals, upd_vals])
    self.assertEqual(upd_result, true_result)
def adder_t_base(self, adder_func, **kwargs):
    """Check that adder_func(*wires) simulates to the per-cycle sum of its inputs.

    Extra keyword arguments are forwarded to utils.make_inputs_and_values.
    """
    in_wires, in_vals = utils.make_inputs_and_values(
        dist=utils.inverse_power_dist, **kwargs)
    result_wire = pyrtl.Output(name="test")
    result_wire <<= adder_func(*in_wires)
    simulated = utils.sim_and_ret_out(result_wire, in_wires, in_vals)
    expected = list(map(sum, zip(*in_vals)))
    self.assertEqual(simulated, expected)
def test_key_expansion(self):
    # TODO: this test is not at all correct and needs to be completely rewritten.
    # NOTE(review): the expected values below duplicate the inverse-ShiftRows
    # test vectors rather than a real decryption key schedule.
    round_keys = self.aes.decryption_key_gen(self.in_vector)
    self.out_vector <<= pyrtl.corecircuits.concat_list(round_keys)
    stimulus = [0xd1876c0f79c4300ab45594add66ff41f,
                0xfa636a2825b339c940668a3157244d17]
    expected = [0x3e175076b61c04678dfc2295f6a8bfc0,
                0x2dfb02343f6d12dd09337ec75b36e3f0]
    simulated = testingutils.sim_and_ret_out(
        self.out_vector, (self.in_vector,), (stimulus,))
    self.assertEqual(simulated, expected)
def test_mix_columns(self):
    """Check self.aes_encrypt._mix_columns against two hard-coded vectors."""
    mixed = self.aes_encrypt._mix_columns(self.in_vector)
    self.out_vector <<= mixed
    stimulus = [0x6353e08c0960e104cd70b751bacad0e7,
                0xa7be1a6997ad739bd8c9ca451f618b61]
    expected = [0x5f72641557f5bc92f7be3b291db9f91a,
                0xff87968431d86a51645151fa773ad009]
    outputs = testingutils.sim_and_ret_out(self.out_vector, (self.in_vector,), (stimulus,))
    self.assertEqual(outputs, expected)
def test_sub_bytes(self):
    """Check self.aes_encrypt._sub_bytes against two hard-coded vectors."""
    self.out_vector <<= self.aes_encrypt._sub_bytes(self.in_vector)
    inputs = [
        0x4915598f55e5d7a0daca94fa1f0a63f7,
        0xc62fe109f75eedc3cc79395d84f9cf5d,
    ]
    wanted = [
        0x3b59cb73fcd90ee05774222dc067fb68,
        0xb415f8016858552e4bb6124c5f998a4c,
    ]
    got = testingutils.sim_and_ret_out(self.out_vector, (self.in_vector,), (inputs,))
    self.assertEqual(got, wanted)
def test_shift_rows(self):
    """Check self.aes_encrypt._shift_rows against two hard-coded vectors."""
    shifted = self.aes_encrypt._shift_rows(self.in_vector)
    self.out_vector <<= shifted
    stimulus = [0x3b59cb73fcd90ee05774222dc067fb68,
                0xb415f8016858552e4bb6124c5f998a4c]
    expected = [0x3bd92268fc74fb735767cbe0c0590e2d,
                0xb458124c68b68a014b99f82e5f15554c]
    simulated = testingutils.sim_and_ret_out(
        self.out_vector, (self.in_vector,), (stimulus,))
    self.assertEqual(simulated, expected)
def test_fast_group_adder_1(self):
    """fast_group_adder over 7 inputs (up to 12 bits) equals their per-cycle sum."""
    operands, operand_vals = utils.make_inputs_and_values(
        max_bitwidth=12, num_wires=7, dist=utils.inverse_power_dist)
    total = pyrtl.Output(name="test")
    total <<= adders.fast_group_adder(operands)
    simulated = utils.sim_and_ret_out(total, operands, operand_vals)
    expected = list(map(sum, zip(*operand_vals)))
    self.assertEqual(simulated, expected)
def test_key_expansion(self):
    # TODO: this test is not at all correct and needs to be completely rewritten.
    # NOTE(review): the expected values below duplicate the ShiftRows test
    # outputs rather than a real encryption key schedule.
    round_keys = self.aes_encrypt._key_gen(self.in_vector)
    self.out_vector <<= pyrtl.concat_list(round_keys)
    stimulus = [0x4c9c1e66f771f0762c3f868e534df256,
                0xc57e1c159a9bd286f05f4be098c63439]
    expected = [0x3bd92268fc74fb735767cbe0c0590e2d,
                0xb458124c68b68a014b99f82e5f15554c]
    simulated = testingutils.sim_and_ret_out(
        self.out_vector, (self.in_vector,), (stimulus,))
    self.assertEqual(simulated, expected)
def test_fast_group_adder_1(self):
    """Sum 7 random wires (max 12 bits) with fast_group_adder and check each cycle."""
    in_wires, in_vals = utils.make_inputs_and_values(
        max_bitwidth=12,
        num_wires=7,
        dist=utils.inverse_power_dist,
    )
    sum_out = pyrtl.Output(name="test")
    adder_result = adders.fast_group_adder(in_wires)
    sum_out <<= adder_result
    results = utils.sim_and_ret_out(sum_out, in_wires, in_vals)
    reference = [sum(per_cycle) for per_cycle in zip(*in_vals)]
    self.assertEqual(results, reference)
def adder_t_base(self, adder_func, **kwargs):
    """Shared adder check: adder_func(*wires) must equal the per-cycle input sum.

    Any keyword arguments are passed through to utils.make_inputs_and_values.
    """
    operands, operand_vals = utils.make_inputs_and_values(
        dist=utils.inverse_power_dist, **kwargs)
    sum_out = pyrtl.Output(name="test")
    adder_result = adder_func(*operands)
    sum_out <<= adder_result
    results = utils.sim_and_ret_out(sum_out, operands, operand_vals)
    reference = [sum(per_cycle) for per_cycle in zip(*operand_vals)]
    self.assertEqual(results, reference)
def test_single_output_simulates_correctly(self):
    """A 4x4-bit product on an 8-bit output simulates correctly after
    pyrtl.direct_connect_outputs() is run."""
    lhs, lhs_vals = utils.an_input_and_vals(4, name='i')
    rhs, rhs_vals = utils.an_input_and_vals(4, name='j')
    product = pyrtl.Output(8, 'o')
    product <<= lhs * rhs
    pyrtl.direct_connect_outputs()
    expected = [a * b for a, b in zip(lhs_vals, rhs_vals)]
    simulated = utils.sim_and_ret_out(product, [lhs, rhs], [lhs_vals, rhs_vals])
    self.assertEqual(expected, simulated)
def shift_checker(self, shift_func, ref_func, input_width, shift_width, test_amt=20):
    """Compare a hardware shifter with a Python reference model.

    :param shift_func: builds the shifter; called as ``shift_func(value_wire, amount_wire)``
    :param ref_func: reference model; called as ``ref_func(value, amount)``
    :param input_width: bitwidth of the value being shifted (and of the output)
    :param shift_width: bitwidth of the shift-amount input
    :param test_amt: number of random test values (simulated cycles)
    """
    data, data_vals = utils.an_input_and_vals(input_width, test_vals=test_amt, name='inp')
    amount, amount_vals = utils.an_input_and_vals(shift_width, test_vals=test_amt,
                                                  name='shf')
    out = pyrtl.Output(input_width, "out")
    shifted = shift_func(data, amount)
    # the shifter must preserve the width of the value being shifted
    self.assertEqual(len(out), len(shifted))
    out <<= shifted
    expected = [ref_func(d, a) for d, a in zip(data_vals, amount_vals)]
    simulated = utils.sim_and_ret_out(out, [data, amount], [data_vals, amount_vals])
    self.assertEqual(simulated, expected)
def test_aes_full(self):
    """Full encryption of two (plaintext, key) pairs yields the expected ciphertexts."""
    key_wire = pyrtl.Input(bitwidth=128, name='aes_key')
    self.out_vector <<= self.aes_encrypt.encryption(self.in_vector, key_wire)
    plaintexts = [0x00112233445566778899aabbccddeeff, 0x0]
    key_vals = [0x000102030405060708090a0b0c0d0e0f, 0x0]
    expected = [0x69c4e0d86a7b0430d8cdb78070b4c55a,
                0x66e94bd4ef8a2c3b884cfa59ca342b2e]
    simulated = testingutils.sim_and_ret_out(
        self.out_vector, (self.in_vector, key_wire), (plaintexts, key_vals))
    self.assertEqual(simulated, expected)
def test_aes_full(self):
    """Full decryption of two (ciphertext, key) pairs yields the expected plaintexts."""
    key_wire = pyrtl.Input(bitwidth=128, name='aes_key')
    decrypted = self.aes.decryption(self.in_vector, key_wire)
    self.out_vector <<= decrypted
    ciphertexts = [0x3ad77bb40d7a3660a89ecaf32466ef97,
                   0x66e94bd4ef8a2c3b884cfa59ca342b2e]
    key_vals = [0x2b7e151628aed2a6abf7158809cf4f3c, 0x0]
    expected = [0x6bc1bee22e409f96e93d7e117393172a, 0x0]
    outputs = testingutils.sim_and_ret_out(
        self.out_vector, (self.in_vector, key_wire), (ciphertexts, key_vals))
    self.assertEqual(outputs, expected)
def mux_t_subprocess(self, addr_width, val_width):
    """Check pyrtl.corecircuits.mux over 2**addr_width constant inputs.

    :param addr_width: bitwidth of the select line
    :param val_width: exact bitwidth of each constant data input
    """
    consts, const_vals = utils.make_consts(num_wires=2**addr_width,
                                           exact_bitwidth=val_width)
    select, select_vals = utils.an_input_and_vals(addr_width, 40, "mux_ctrl")
    mux_out = pyrtl.Output(val_width, "mux_out")
    mux_out <<= pyrtl.corecircuits.mux(select, *consts)
    # the selected constant is simply the one indexed by the select value
    expected = [const_vals[s] for s in select_vals]
    simulated = utils.sim_and_ret_out(mux_out, (select,), (select_vals,))
    self.assertEqual(simulated, expected)
def test_select_with_5_wires(self):
    """prioritized_mux with five 1-bit selects and five 5-bit data wires
    matches the pri_mux_actual reference on every cycle."""
    val_width = 5
    selects, select_vals = utils.make_inputs_and_values(5, exact_bitwidth=1, test_vals=50)
    data, data_vals = utils.make_inputs_and_values(5, exact_bitwidth=val_width,
                                                   test_vals=50)
    out = pyrtl.Output(val_width, "out")
    out <<= muxes.prioritized_mux(selects, data)
    actual = utils.sim_and_ret_out(out, selects + data, select_vals + data_vals)
    # regroup the per-wire value lists into per-cycle tuples for the reference
    per_cycle = zip(zip(*select_vals), zip(*data_vals))
    expected = [pri_mux_actual(sel_tuple, data_tuple) for sel_tuple, data_tuple in per_cycle]
    self.assertEqual(actual, expected)
def test_select_no_pred(self):
    """select(ctrl, truecase, falsecase) with positional arguments yields 27 when
    ctrl is 1 and 12 when ctrl is 0."""
    vals = 12, 27
    consts = [pyrtl.Const(v) for v in vals]
    ctrl, ctrl_vals = utils.an_input_and_vals(1, 40, "sel_ctrl", utils.uniform_dist)
    out = pyrtl.Output(5, "mux_out")
    out <<= pyrtl.corecircuits.select(ctrl, consts[1], consts[0])
    expected = [vals[c] for c in ctrl_vals]
    simulated = utils.sim_and_ret_out(out, (ctrl,), (ctrl_vals,))
    self.assertEqual(simulated, expected)
def test_select(self):
    """select() with keyword falsecase/truecase yields 27 when the control bit is 1
    and 12 when it is 0."""
    vals = 12, 27
    mux_ins = [pyrtl.Const(x) for x in vals]
    # an_input_and_vals is the helper name used by every sibling test
    # (generate_in_wire_and_values was the only outlier).
    control, testctrl = utils.an_input_and_vals(1, 40, "sel_ctrl", utils.uniform_dist)
    out = pyrtl.Output(5, "mux_out")
    out <<= pyrtl.corecircuits.select(control, falsecase=mux_ins[0], truecase=mux_ins[1])
    true_result = [vals[i] for i in testctrl]
    mux_result = utils.sim_and_ret_out(out, (control,), (testctrl,))
    self.assertEqual(mux_result, true_result)
def test_xor(self):
    """tree_reduce(operator.xor, ...) over seven 8-bit wires equals the per-cycle XOR."""
    # functools.reduce is stdlib on Python 2.6+ and 3, so the third-party
    # six.moves shim is unnecessary here.
    import operator
    from functools import reduce

    wires, vals = utils.make_inputs_and_values(7, exact_bitwidth=8, dist=utils.uniform_dist)
    outwire = pyrtl.Output(name="test")
    outwire <<= pyrtl.tree_reduce(operator.xor, wires)
    out_vals = utils.sim_and_ret_out(outwire, wires, vals)
    true_result = [reduce(operator.xor, v) for v in zip(*vals)]
    self.assertEqual(out_vals, true_result)
def test_two_vals(self):
    """sparse_mux with keys {0, 1} on a 1-bit select behaves like a 2:1 mux."""
    select, select_vals = gen_in(1)
    first, first_vals = gen_in(3)
    second, second_vals = gen_in(3)
    out = pyrtl.Output(name="output")
    out <<= muxes.sparse_mux(select, {0: first, 1: second})
    stimulus = [select_vals, first_vals, second_vals]
    simulated = utils.sim_and_ret_out(out, [select, first, second], stimulus)
    expected = [b if s else a for s, a, b in zip(*stimulus)]
    self.assertEqual(simulated, expected)
def test_two_big_close(self):
    """sparse_mux with adjacent keys 5 and 6 on a 3-bit select.

    A random bit per cycle chooses whether the select is driven with 6
    (expect the second input) or 5 (expect the first input).
    """
    select = pyrtl.Input(3)
    first, first_vals = gen_in(3)
    second, second_vals = gen_in(3)
    out = pyrtl.Output(name="output")
    coin = [utils.uniform_dist(1) for _ in range(20)]
    driven_select = [6 if c else 5 for c in coin]
    out <<= muxes.sparse_mux(select, {5: first, 6: second})
    simulated = utils.sim_and_ret_out(out, [select, first, second],
                                      [driven_select, first_vals, second_vals])
    expected = [b if c else a for c, a, b in zip(coin, first_vals, second_vals)]
    self.assertEqual(simulated, expected)
def test_aes_full(self):
    """Encrypt two (plaintext, key) pairs and compare with the expected ciphertexts."""
    key_in = pyrtl.Input(bitwidth=128, name='aes_key')
    ciphertext_wire = self.aes_encrypt.encryption(self.in_vector, key_in)
    self.out_vector <<= ciphertext_wire
    plain_vals = [0x00112233445566778899aabbccddeeff, 0x0]
    key_vals = [0x000102030405060708090a0b0c0d0e0f, 0x0]
    wanted = [
        0x69c4e0d86a7b0430d8cdb78070b4c55a,
        0x66e94bd4ef8a2c3b884cfa59ca342b2e,
    ]
    got = testingutils.sim_and_ret_out(self.out_vector, (self.in_vector, key_in),
                                       (plain_vals, key_vals))
    self.assertEqual(got, wanted)
def test_mix_columns(self):
    """MixColumns (encrypt side) maps two hard-coded states to the expected outputs."""
    self.out_vector <<= self.aes_encrypt._mix_columns(self.in_vector)
    inputs = [0x6353e08c0960e104cd70b751bacad0e7,
              0xa7be1a6997ad739bd8c9ca451f618b61]
    wanted = [0x5f72641557f5bc92f7be3b291db9f91a,
              0xff87968431d86a51645151fa773ad009]
    got = testingutils.sim_and_ret_out(self.out_vector, (self.in_vector,), (inputs,))
    self.assertEqual(got, wanted)
def test_inv_mix_columns(self):
    """Check the decrypt-side _mix_columns(..., True) against two hard-coded vectors.

    The positional True presumably selects the inverse transform (TODO confirm).
    """
    unmixed = self.aes_decrypt._mix_columns(self.in_vector, True)
    self.out_vector <<= unmixed
    stimulus = [0xe9f74eec023020f61bf2ccf2353c21c7,
                0xbaa03de7a1f9b56ed5512cba5f414d23]
    expected = [0x54d990a16ba09ab596bbf40ea111702f,
                0x3e1c22c0b6fcbf768da85067f6170495]
    simulated = testingutils.sim_and_ret_out(
        self.out_vector, (self.in_vector,), (stimulus,))
    self.assertEqual(simulated, expected)
def test_sub_bytes(self):
    """SubBytes (encrypt side) maps two hard-coded states to the expected outputs."""
    substituted = self.aes_encrypt._sub_bytes(self.in_vector)
    self.out_vector <<= substituted
    stimulus = [0x4915598f55e5d7a0daca94fa1f0a63f7,
                0xc62fe109f75eedc3cc79395d84f9cf5d]
    expected = [0x3b59cb73fcd90ee05774222dc067fb68,
                0xb415f8016858552e4bb6124c5f998a4c]
    simulated = testingutils.sim_and_ret_out(
        self.out_vector, (self.in_vector,), (stimulus,))
    self.assertEqual(simulated, expected)
def test_inv_shift_rows(self):
    """Check self.aes_decrypt._inv_shift_rows against two hard-coded vectors."""
    self.out_vector <<= self.aes_decrypt._inv_shift_rows(self.in_vector)
    inputs = [
        0x3e1c22c0b6fcbf768da85067f6170495,
        0x2d6d7ef03f33e334093602dd5bfb12c7,
    ]
    wanted = [
        0x3e175076b61c04678dfc2295f6a8bfc0,
        0x2dfb02343f6d12dd09337ec75b36e3f0,
    ]
    got = testingutils.sim_and_ret_out(self.out_vector, (self.in_vector,), (inputs,))
    self.assertEqual(got, wanted)
def test_shift_rows(self):
    """ShiftRows (encrypt side) maps two hard-coded states to the expected outputs."""
    self.out_vector <<= self.aes_encrypt._shift_rows(self.in_vector)
    inputs = [0x3b59cb73fcd90ee05774222dc067fb68,
              0xb415f8016858552e4bb6124c5f998a4c]
    wanted = [0x3bd92268fc74fb735767cbe0c0590e2d,
              0xb458124c68b68a014b99f82e5f15554c]
    got = testingutils.sim_and_ret_out(self.out_vector, (self.in_vector,), (inputs,))
    self.assertEqual(got, wanted)
def test_aes_full(self):
    """Decrypt two (ciphertext, key) pairs and compare with the expected plaintexts."""
    key_in = pyrtl.Input(bitwidth=128, name='aes_key')
    self.out_vector <<= self.aes_decrypt.decryption(self.in_vector, key_in)
    cipher_vals = [
        0x3ad77bb40d7a3660a89ecaf32466ef97,
        0x66e94bd4ef8a2c3b884cfa59ca342b2e,
    ]
    key_vals = [0x2b7e151628aed2a6abf7158809cf4f3c, 0x0]
    wanted = [0x6bc1bee22e409f96e93d7e117393172a, 0x0]
    got = testingutils.sim_and_ret_out(self.out_vector, (self.in_vector, key_in),
                                       (cipher_vals, key_vals))
    self.assertEqual(got, wanted)
def test_default(self):
    """sparse_mux with a SparseDefault entry: select 5 gives the first input,
    6 gives the second, and any other select value gives the default."""
    select, select_vals = gen_in(3)
    first, first_vals = gen_in(3)
    second, second_vals = gen_in(3)
    fallback, fallback_vals = gen_in(3)
    out = pyrtl.Output(name="output")
    out <<= muxes.sparse_mux(select, {5: first, 6: second, muxes.SparseDefault: fallback})
    simulated = utils.sim_and_ret_out(out, [select, first, second, fallback],
                                      [select_vals, first_vals, second_vals, fallback_vals])

    def model(s, a, b, d):
        if s == 6:
            return b
        if s == 5:
            return a
        return d

    expected = [model(s, a, b, d) for s, a, b, d
                in zip(select_vals, first_vals, second_vals, fallback_vals)]
    self.assertEqual(simulated, expected)
def test_inv_sub_bytes(self):
    """Check the decrypt-side _sub_bytes(..., True) against two hard-coded vectors.

    The positional True presumably selects the inverse S-box (TODO confirm).
    """
    substituted = self.aes_decrypt._sub_bytes(self.in_vector, True)
    self.out_vector <<= substituted
    stimulus = [0x3e175076b61c04678dfc2295f6a8bfc0,
                0x2dfb02343f6d12dd09337ec75b36e3f0]
    expected = [0xd1876c0f79c4300ab45594add66ff41f,
                0xfa636a2825b339c940668a3157244d17]
    simulated = testingutils.sim_and_ret_out(
        self.out_vector, (self.in_vector,), (stimulus,))
    self.assertEqual(simulated, expected)
def test_fma_1(self):
    """fused_multiply_adder(a, b, c) on 10-bit unsigned inputs computes a*b + c.

    The result wire is expected to be 20 bits wide; it is driven onto a
    21-bit output.
    """
    wires, vals = utils.make_inputs_and_values(exact_bitwidth=10, num_wires=3,
                                               dist=utils.inverse_power_dist)
    test_w = multipliers.fused_multiply_adder(wires[0], wires[1], wires[2], False,
                                              reducer=adders.dada_reducer,
                                              adder_func=adders.ripple_add)
    self.assertEqual(len(test_w), 20)
    outwire = pyrtl.Output(21, "test")
    outwire <<= test_w
    out_vals = utils.sim_and_ret_out(outwire, wires, vals)
    # iterate cycles in lockstep instead of indexing with range(len(...))
    true_result = [a * b + c for a, b, c in zip(*vals)]
    self.assertEqual(out_vals, true_result)
def test_gen_fma_1(self):
    """generalized_fma of three products plus two addends matches the Python result."""
    wires, vals = utils.make_inputs_and_values(max_bitwidth=8, num_wires=8,
                                               dist=utils.inverse_power_dist)
    # mixing tuples and lists solely for readability purposes
    mult_pairs = [(wires[0], wires[1]), (wires[2], wires[3]), (wires[4], wires[5])]
    add_wires = (wires[6], wires[7])
    outwire = pyrtl.Output(name="test")
    outwire <<= multipliers.generalized_fma(mult_pairs, add_wires, signed=False)
    out_vals = utils.sim_and_ret_out(outwire, wires, vals)
    # iterate cycles in lockstep instead of indexing with range(len(...))
    true_result = [v[0] * v[1] + v[2] * v[3] + v[4] * v[5] + v[6] + v[7]
                   for v in zip(*vals)]
    self.assertEqual(out_vals, true_result)
def test_key_expansion(self):
    # TODO: this test is not at all correct and needs to be completely rewritten.
    # NOTE(review): the expected values below duplicate the inverse-ShiftRows
    # test vectors rather than a real decryption key schedule.
    key_words = self.aes_decrypt._key_gen(self.in_vector)
    self.out_vector <<= pyrtl.corecircuits.concat_list(key_words)
    inputs = [
        0xd1876c0f79c4300ab45594add66ff41f,
        0xfa636a2825b339c940668a3157244d17,
    ]
    wanted = [
        0x3e175076b61c04678dfc2295f6a8bfc0,
        0x2dfb02343f6d12dd09337ec75b36e3f0,
    ]
    got = testingutils.sim_and_ret_out(self.out_vector, (self.in_vector,), (inputs,))
    self.assertEqual(got, wanted)
def test_key_expansion(self):
    # TODO: this test is not at all correct and needs to be completely rewritten.
    # NOTE(review): the expected values below duplicate the ShiftRows test
    # outputs rather than a real encryption key schedule.
    key_words = self.aes_encrypt._key_gen(self.in_vector)
    self.out_vector <<= pyrtl.concat_list(key_words)
    inputs = [
        0x4c9c1e66f771f0762c3f868e534df256,
        0xc57e1c159a9bd286f05f4be098c63439,
    ]
    wanted = [
        0x3bd92268fc74fb735767cbe0c0590e2d,
        0xb458124c68b68a014b99f82e5f15554c,
    ]
    got = testingutils.sim_and_ret_out(self.out_vector, (self.in_vector,), (inputs,))
    self.assertEqual(got, wanted)
def test_mux_with_default(self):
    """mux() with fewer data inputs than select values routes the missing
    (highest) select values to the supplied default constant."""
    addr_width = 5
    val_width = 9
    default_val = 170  # arbitrary value under 2**val_width
    num_defaults = 5
    mux_ins, vals = utils.make_consts(num_wires=2**addr_width - num_defaults,
                                      exact_bitwidth=val_width,
                                      random_dist=utils.uniform_dist)
    control, testctrl = utils.an_input_and_vals(addr_width, 40, "mux_ctrl",
                                                utils.uniform_dist)
    # Pad the expectation table with one default per missing input; derive the
    # count from num_defaults rather than a hard-coded 5 so the two stay in sync.
    vals.extend([default_val] * num_defaults)
    out = pyrtl.Output(val_width, "mux_out")
    out <<= pyrtl.corecircuits.mux(control, *mux_ins, default=pyrtl.Const(default_val))
    true_result = [vals[i] for i in testctrl]
    mux_result = utils.sim_and_ret_out(out, (control,), (testctrl,))
    self.assertEqual(mux_result, true_result)
def test_fma_1(self):
    """fused_multiply_adder(a, b, c) on 10-bit unsigned inputs computes a*b + c.

    The result wire is expected to be 20 bits wide; it is driven onto a
    21-bit output.
    """
    wires, vals = utils.make_inputs_and_values(
        exact_bitwidth=10, num_wires=3, dist=utils.inverse_power_dist)
    test_w = multipliers.fused_multiply_adder(wires[0], wires[1], wires[2], False,
                                              reducer=adders.dada_reducer,
                                              adder_func=adders.ripple_add)
    self.assertEqual(len(test_w), 20)
    outwire = pyrtl.Output(21, "test")
    outwire <<= test_w
    out_vals = utils.sim_and_ret_out(outwire, wires, vals)
    # iterate cycles in lockstep instead of indexing with range(len(...))
    true_result = [a * b + c for a, b, c in zip(*vals)]
    self.assertEqual(out_vals, true_result)
def test_gen_fma_1(self):
    """generalized_fma of three products plus two addends matches the Python result."""
    wires, vals = utils.make_inputs_and_values(
        max_bitwidth=8, num_wires=8, dist=utils.inverse_power_dist)
    # mixing tuples and lists solely for readability purposes
    mult_pairs = [(wires[0], wires[1]), (wires[2], wires[3]), (wires[4], wires[5])]
    add_wires = (wires[6], wires[7])
    outwire = pyrtl.Output(name="test")
    outwire <<= multipliers.generalized_fma(mult_pairs, add_wires, signed=False)
    out_vals = utils.sim_and_ret_out(outwire, wires, vals)
    # iterate cycles in lockstep instead of indexing with range(len(...))
    true_result = [
        v[0] * v[1] + v[2] * v[3] + v[4] * v[5] + v[6] + v[7]
        for v in zip(*vals)
    ]
    self.assertEqual(out_vals, true_result)
def test_list_of_long_wires(self):
    """xor_all_bits over a list of four 13-bit wires equals their per-cycle XOR."""
    wire_list, wire_vals = utils.make_inputs_and_values(4, exact_bitwidth=13)
    result = pyrtl.Output(name='o')
    result <<= pyrtl.corecircuits.xor_all_bits(wire_list)
    expected = [a ^ b ^ c ^ d for a, b, c, d in zip(*wire_vals)]
    simulated = utils.sim_and_ret_out(result, wire_list, wire_vals)
    self.assertEqual(expected, simulated)