def alu(a, b, op):
    """Simplified 1-bit ALU.

    Chooses the outputs according to ``op``:
        op == 0: a AND b   (carry out driven with 0)
        op == 1: a XNOR b  (carry out driven with 0)
        op == 2: a + b     (sum and carry from half_adder)

    :param a: first operand.
    :param b: second operand.
    :param op: operation select, compared against 0b00 / 0b01 / 0b10.
    :return: tuple ``(alu_r, alu_cout)`` of 1-bit WireVectors named
        'alu_r' and 'alu_cout'.
    """
    and_result = a & b
    xnor_result = ~(a ^ b)
    add_sum, add_carry = half_adder(a, b)

    alu_r = pyrtl.WireVector(bitwidth=1, name='alu_r')
    alu_cout = pyrtl.WireVector(bitwidth=1, name='alu_cout')

    # One row per supported opcode: (opcode, result, carry out).  Each loop
    # iteration opens one sibling condition inside the same
    # conditional_assignment, exactly like writing the three `with` blocks
    # out by hand.
    choices = (
        (0b00, and_result, 0),
        (0b01, xnor_result, 0),
        (0b10, add_sum, add_carry),
    )
    with pyrtl.conditional_assignment:
        for opcode, result, carry in choices:
            with op == opcode:
                alu_r |= result
                alu_cout |= carry
    return alu_r, alu_cout
def alu(a, b, op):
    """Simplified 1-bit ALU.

    Chooses the outputs according to ``op``:
        op == 0: a AND b   (carry out 0)
        op == 1: a XNOR b  (carry out 0)
        op == 2: a + b     (sum on alu_r, carry on alu_cout)

    :param a: first operand.
    :param b: second operand.
    :param op: operation select.
    :return: tuple ``(alu_r, alu_cout)`` of 1-bit WireVectors.
    """
    # Operation 0: a and b
    op0 = a & b
    # Operation 1: a xnor b
    op1 = ~(a ^ b)
    # Operation 2: a + b.  The first half_adder output is the one routed to
    # alu_r and the second to alu_cout.  The original unpacked them as
    # (op2_c, op2_s) and swapped them back on assignment.  The wiring is
    # unchanged; only the names now match what they carry.
    # NOTE(review): assumes half_adder returns (sum, carry) -- confirm.
    op2_s, op2_c = half_adder(a, b)

    # Based on the given "op", return the proper signals as outputs
    alu_r = pyrtl.WireVector(bitwidth=1)
    alu_cout = pyrtl.WireVector(bitwidth=1)

    with pyrtl.conditional_assignment:
        with op == 0:
            alu_r |= op0
            alu_cout |= 0  # logic ops produce no carry
        with op == 1:
            alu_r |= op1
            alu_cout |= 0
        with op == 2:
            alu_r |= op2_s
            alu_cout |= op2_c
    return alu_r, alu_cout
def test_one_bit_adder_matches_expected(self):
    """Build a 1-bit full adder, drive it with random inputs, and check
    ``sum`` and ``carry_out`` against the truth table on every cycle.

    Also renders the trace to HTML to make sure that path runs.
    """
    temp1 = pyrtl.WireVector(bitwidth=1, name='temp1')
    temp2 = pyrtl.WireVector()
    a, b, c = pyrtl.Input(1, 'a'), pyrtl.Input(1, 'b'), pyrtl.Input(1, 'c')
    # Local renamed from `sum` to avoid shadowing the builtin; the Output
    # itself is still named 'sum'.
    sum_out, carry_out = pyrtl.Output(1, 'sum'), pyrtl.Output(1, 'carry_out')
    sum_out <<= a ^ b ^ c
    temp1 <<= a & b  # connect the result of a & b to the pre-allocated wirevector
    temp2 <<= a & c
    temp3 = b & c  # temp3 IS the result of b & c (this is the first mention of temp3)
    carry_out <<= temp1 | temp2 | temp3

    sim_trace = pyrtl.SimulationTrace()
    sim = pyrtl.Simulation(tracer=sim_trace)
    for _ in range(15):
        vals = {
            'a': random.choice([0, 1]),
            'b': random.choice([0, 1]),
            'c': random.choice([0, 1]),
        }
        sim.step(vals)
        # The original only stepped the simulation and never checked the
        # outputs, despite the test's name.
        expected_sum = vals['a'] ^ vals['b'] ^ vals['c']
        expected_carry = ((vals['a'] & vals['b'])
                          | (vals['a'] & vals['c'])
                          | (vals['b'] & vals['c']))
        self.assertEqual(sim.inspect('sum'), expected_sum)
        self.assertEqual(sim.inspect('carry_out'), expected_carry)
    inputoutput.trace_to_html(sim_trace)  # tests if it compiles or not
def test_no_dup(self):
    """sparse_mux over two distinct wires must return a wire that is
    neither of its inputs."""
    sel, first, second = (pyrtl.WireVector(3) for _ in range(3))
    result = muxes.sparse_mux(sel, {6: first, 2: second})
    for candidate in (first, second):
        self.assertIsNot(result, candidate)
def test_mux_enough_inputs_with_default(self):
    """corecircuits.mux accepts four data inputs on a 2-bit select even when
    a default is also supplied (the test passes if no error is raised)."""
    a = pyrtl.WireVector(name='a', bitwidth=3)
    b, c, d = (pyrtl.WireVector(name=label, bitwidth=1) for label in ('b', 'c', 'd'))
    s = pyrtl.WireVector(name='s', bitwidth=2)
    pyrtl.corecircuits.mux(s, a, b, c, d, default=0)
def test_equivelence_of_different_nets(self):
    """LogicNets that differ in op, op_param, argument order, argument
    count, or destinations must not compare as the same net."""
    a, b, c = (pyrtl.WireVector() for _ in range(3))

    sub_john = pyrtl.LogicNet('-', 'John', (a, b), (c, ))
    add_john = pyrtl.LogicNet('+', 'John', (a, b), (c, ))
    add_xx = pyrtl.LogicNet('+', 'xx', (a, b), (c, ))
    add_xx_swapped = pyrtl.LogicNet('+', 'xx', (b, a), (c, ))
    three_args = pyrtl.LogicNet('+', 'xx', (b, a, c), (c, ))
    two_dests = pyrtl.LogicNet('+', 'xx', (b, a, c), (c, a))
    dest_a = pyrtl.LogicNet('+', 'xx', (b, a, c), (a, ))

    for left, right in (
            (sub_john, add_john),
            (add_john, add_xx),
            (add_xx, add_xx_swapped),
            (add_xx_swapped, three_args),
            (three_args, two_dests),
            (three_args, dest_a),
            (two_dests, dest_a)):
        self.assertDifferentNets(left, right)

    # Edge case: a repeated argument is not the same as a single argument.
    repeated_arg = pyrtl.LogicNet('+', 'John', (a, a), (c, ))
    single_arg = pyrtl.LogicNet('+', 'John', (a, ), (c, ))
    self.assertDifferentNets(repeated_arg, single_arg)
def test_mux_too_many_inputs(self):
    """pyrtl.mux must reject three data inputs on a 1-bit select."""
    a = pyrtl.WireVector(name='a', bitwidth=3)
    b, c = (pyrtl.WireVector(name=label, bitwidth=1) for label in ('b', 'c'))
    s = pyrtl.WireVector(name='s', bitwidth=1)
    with self.assertRaises(pyrtl.PyrtlError):
        pyrtl.mux(s, a, b, c)
def test_mux_not_enough_inputs(self):
    """corecircuits.mux must reject three data inputs on a 2-bit select
    when no default is given."""
    a = pyrtl.WireVector(name='a', bitwidth=3)
    b, c = (pyrtl.WireVector(name=label, bitwidth=1) for label in ('b', 'c'))
    s = pyrtl.WireVector(name='s', bitwidth=2)
    with self.assertRaises(pyrtl.PyrtlError):
        pyrtl.corecircuits.mux(s, a, b, c)
def test_equivelence_of_same_nets(self):
    """Two LogicNets built separately from identical fields are distinct
    objects that nonetheless compare equal."""
    a, b, c = (pyrtl.WireVector(1) for _ in range(3))

    def build_net():
        return pyrtl.LogicNet('+', 'xx', (a, b), (c,))

    first, second = build_net(), build_net()
    self.assertIsNot(first, second)
    self.assertEqual(first, second)
def decryption_statem(self, ciphertext_in, key_in, reset):
    """ Builds a multiple cycle AES Decryption state machine circuit

    :param ciphertext_in: the ciphertext to decrypt; must have the same
        bitwidth as key_in.
    :param key_in: the AES key; must have the same bitwidth as
        ciphertext_in.
    :param reset: a one bit signal telling the state machine
        to reset and accept the current ciphertext and key
    :return ready, plain_text: ready is a one bit signal showing
        that the decryption result (plain_text) has been calculated.
        The second value is the internal ``cipher_text`` register, which
        is rewritten every round and holds the decrypted result once
        ready is high.
    :raises pyrtl.PyrtlError: if key_in and ciphertext_in differ in length.
    """
    if len(key_in) != len(ciphertext_in):
        raise pyrtl.PyrtlError(
            "AES key and ciphertext should be the same length")

    cipher_text, key = (pyrtl.Register(len(ciphertext_in)) for i in range(2))
    key_exp_in, add_round_in = (pyrtl.WireVector(len(ciphertext_in)) for i in range(2))

    # this is not part of the state machine as we need the keys in
    # reverse order...
    reversed_key_list = reversed(self._key_gen(key_exp_in))

    # The round counter register is fed from the `round` wire, which the
    # conditional block below drives.
    # NOTE(review): `round` shadows the builtin round() inside this method.
    counter = pyrtl.Register(4, 'counter')
    round = pyrtl.WireVector(4)
    counter.next <<= round

    inv_shift = self._inv_shift_rows(cipher_text)
    inv_sub = self._sub_bytes(inv_shift, True)
    # The round key for the current round is picked by indexing the reversed
    # key list with `round`.
    key_out = pyrtl.mux(round, *reversed_key_list, default=0)
    add_round_out = self._add_round_key(add_round_in, key_out)
    inv_mix_out = self._mix_columns(add_round_out, True)

    with pyrtl.conditional_assignment:
        # Reset (checked first): load the key, and already apply the round-0
        # key to the incoming ciphertext in the same cycle.
        with reset == 1:
            round |= 0
            key.next |= key_in
            key_exp_in |= key_in  # to lower the number of cycles needed
            cipher_text.next |= add_round_out
            add_round_in |= ciphertext_in

        with counter == 10:  # keep everything the same
            round |= counter
            cipher_text.next |= cipher_text

        with pyrtl.otherwise:  # running through AES
            round |= counter + 1

            key.next |= key
            key_exp_in |= key
            add_round_in |= inv_sub
            # The last round (counter == 9) takes the AddRoundKey output
            # directly and skips the inverse MixColumns step.
            with counter == 9:
                cipher_text.next |= add_round_out
            with pyrtl.otherwise:
                cipher_text.next |= inv_mix_out

    ready = (counter == 10)
    return ready, cipher_text
def test_bad_bitwidth(self):
    """WireVector must reject a string, negative, or zero bitwidth, and a
    WireVector passed in the bitwidth position."""
    for invalid in ('happy', -1, 0):
        with self.assertRaises(pyrtl.PyrtlError):
            pyrtl.WireVector(bitwidth=invalid)

    existing = pyrtl.WireVector(1)
    with self.assertRaises(pyrtl.PyrtlError):
        pyrtl.WireVector(existing)
def test_timing_basic_2(self):
    """Timing of ``(~(in0 | in1)) & in2``, with explicit intermediate wires,
    is checked by everything_t_procedure."""
    in0, in1 = pyrtl.Input(bitwidth=1), pyrtl.Input(bitwidth=1)
    in2 = pyrtl.Input(bitwidth=1)
    or_wire, inv_wire = pyrtl.WireVector(), pyrtl.WireVector()
    result = pyrtl.Output()

    or_wire <<= in0 | in1
    inv_wire <<= ~or_wire
    result <<= inv_wire & in2

    self.everything_t_procedure(3, 3)
def test_combo_1(self):
    """The ``(~(in0 | in1)) & in2`` circuit is checked by
    everything_t_procedure with expected values of 252.3."""
    in0, in1 = pyrtl.Input(bitwidth=1), pyrtl.Input(bitwidth=1)
    or_wire, inv_wire = pyrtl.WireVector(), pyrtl.WireVector()
    in2 = pyrtl.Input(bitwidth=1)
    result = pyrtl.Output()

    or_wire <<= in0 | in1
    inv_wire <<= ~or_wire
    result <<= inv_wire & in2

    self.everything_t_procedure(252.3, 252.3)
def test_wrong_input_types_fail(self):
    """pyrtl.chop must reject WireVector segment widths and a non-wire
    value in the input position."""
    instr = pyrtl.WireVector(name='instr', bitwidth=32)
    x = pyrtl.WireVector(name='x', bitwidth=10)
    rejected_calls = (
        (instr, x, 10, 12),
        (instr, 10, x, 12),
        (instr, x),
        (10, 5, 5),
    )
    for call_args in rejected_calls:
        with self.assertRaises(pyrtl.PyrtlError):
            pyrtl.chop(*call_args)
def test_net_odd_wires(self):
    """Net validation must flag a wire that belongs to another block and a
    wire that was removed from the working block."""
    # Case 1: the argument wire is re-homed to a different block.
    shared = pyrtl.WireVector(2, 'wire')
    cross_block_net = self.new_net(args=(shared, shared))
    shared._block = pyrtl.Block()
    self.invalid_net("net references different block", cross_block_net)

    # Case 2: start over, then remove the argument wire from the block.
    pyrtl.reset_working_block()
    orphan = pyrtl.WireVector(2, 'wire')
    orphan_net = self.new_net(args=(orphan,))
    pyrtl.working_block().remove_wirevector(orphan)
    self.invalid_net("net with unknown source", orphan_net)
def test_randomly_replace(self):
    """After insert_random_inversions(1) on ``a & b`` the block should hold
    three '~' nets and one '&' net whose args and dest are all new wires."""
    a, b = pyrtl.WireVector(3), pyrtl.WireVector(3)
    o = a & b
    insert_random_inversions(1)

    block = pyrtl.working_block()
    for op_symbol, expected_count in (('~', 3), ('&', 1)):
        self.num_net_of_type(op_symbol, expected_count, block)

    and_net = block.logic_subset('&').pop()
    for arg in and_net.args:
        for original in (a, b):
            self.assertIsNot(arg, original)
    self.assertIsNot(and_net.dests[0], o)
def test_wirevector_1(self):
    """A chain of three 1-bit pass-through wires feeding an inverter;
    checks the timing values and that three nets remain afterwards."""
    inwire = pyrtl.Input(bitwidth=1)
    chain = [pyrtl.WireVector(bitwidth=1) for _ in range(3)]
    outwire = pyrtl.Output()

    # Connect input -> chain[0] -> chain[1] -> chain[2].
    upstream = inwire
    for link in chain:
        link <<= upstream
        upstream = link
    outwire <<= ~upstream

    self.everything_t_procedure(48.5, 48.5)
    block = pyrtl.working_block()
    self.assert_num_net(3, block)
def __init__(self, depth_width=2, data_width=32):
    """Build a FIFO of ``1 << depth_width`` entries, each ``data_width``
    bits wide, backed by a ``_RAM`` named 'FIFOStorage'.

    The read and write pointers are ``depth_width + 1`` bits wide.  The low
    ``depth_width`` bits address the RAM and the top bit tells a wrap apart:
    equal low bits with equal top bits means empty, and equal low bits with
    different top bits means full.

    :param depth_width: number of RAM address bits.
    :param data_width: bitwidth of each stored word.
    """
    aw = depth_width
    dw = data_width

    # Ports
    self.wr_data_i = pyrtl.Input(dw, 'wr_data_i')
    self.wr_en_i = pyrtl.Input(1, 'wr_en_i')
    self.rd_data_o = pyrtl.Output(dw, 'rd_data_o')
    self.rd_en_i = pyrtl.Input(1, 'rd_en_i')
    self.full_o = pyrtl.Output(1, 'full_o')
    self.empty_o = pyrtl.Output(1, 'empty_o')
    self.one_left = pyrtl.Output(1, 'one_left')
    self.reset = pyrtl.Input(1, 'reset')

    # Pointers: aw address bits plus one wrap bit.
    self.write_pointer = pyrtl.Register(aw + 1, 'write_pointer')
    self.read_pointer = pyrtl.Register(aw + 1, 'read_pointer')

    # read_plus_1 is aw + 2 bits wide; the extra bit is dropped when it is
    # connected to the (aw + 1)-bit register or sliced below.
    self.read_plus_1 = pyrtl.Const(1, 1) + self.read_pointer
    # Bug fix: reset now clears the read pointer too.  Previously only the
    # write pointer was reset, so after reset the FIFO was not empty and
    # reported garbage occupancy.
    self.read_pointer.next <<= pyrtl.select(
        self.reset,
        truecase=pyrtl.Const(0, aw + 1),
        falsecase=pyrtl.select(self.rd_en_i,
                               truecase=self.read_plus_1,
                               falsecase=self.read_pointer))
    self.write_pointer.next <<= pyrtl.select(
        self.reset,
        truecase=pyrtl.Const(0, aw + 1),
        falsecase=pyrtl.select(self.wr_en_i,
                               truecase=self.write_pointer + 1,
                               falsecase=self.write_pointer))

    # Status flags
    self.empty_int = pyrtl.WireVector(1, 'empty_int')
    self.full_or_empty = pyrtl.WireVector(1, 'full_or_empty')
    self.empty_int <<= self.write_pointer[aw] == self.read_pointer[aw]
    self.full_or_empty <<= self.write_pointer[0:aw] == self.read_pointer[0:aw]
    self.full_o <<= self.full_or_empty & ~self.empty_int
    self.empty_o <<= self.full_or_empty & self.empty_int
    # Bug fix: truncate read_pointer + 1 back to pointer width so that it
    # wraps the same way the pointer registers do.  The untruncated
    # (aw + 2)-bit sum never equals a wrapped write pointer of 0.
    self.one_left <<= self.read_plus_1[0:aw + 1] == self.write_pointer

    # Storage
    self.mem = m = _RAM(num_entries=1 << aw, data_nbits=dw, name='FIFOStorage')
    m.wen <<= self.wr_en_i
    m.ren <<= self.rd_en_i
    m.raddr <<= self.read_pointer[0:aw]
    m.waddr <<= self.write_pointer[0:aw]
    m.wdata <<= self.wr_data_i
    # Output reads as 0 when no read is requested.
    # NOTE(review): reads on empty and writes on full are not gated here;
    # presumably the caller must honour empty_o / full_o.
    self.rd_data_o <<= pyrtl.select(self.rd_en_i,
                                    truecase=m.rdata,
                                    falsecase=pyrtl.Const(0, dw))
def test_wire_net_removal_2(self):
    """After synthesize and optimize, the pass-through wires should be
    removed while the input is kept; checks the resulting net and wire
    counts.  tempwire2 is a dead branch (it drives nothing)."""
    inwire = pyrtl.Input(bitwidth=3)
    tempwire, tempwire2 = pyrtl.WireVector(), pyrtl.WireVector()
    outwire = pyrtl.Output()

    tempwire <<= inwire
    tempwire2 <<= tempwire
    outwire <<= tempwire

    for transform in (pyrtl.synthesize, pyrtl.optimize):
        transform()

    block = pyrtl.working_block(None)
    self.assertEqual(len(block.logic), 5)
    self.assertEqual(len(block.wirevector_set), 6)
def encrypt_state_m(self, plaintext_in, key_in, reset):
    """ Builds a multiple cycle AES Encryption state machine circuit

    :param plaintext_in: the plaintext to encrypt; must have the same
        bitwidth as key_in.
    :param key_in: the AES key; must have the same bitwidth as
        plaintext_in.
    :param reset: a one bit signal telling the state machine
        to reset and accept the current plaintext and key
    :return ready, cipher_text: ready is a one bit signal showing
        that the encryption result (cipher_text) has been calculated.
        The second value is the internal ``plain_text`` register, which
        is rewritten every round and holds the encrypted result once
        ready is high.
    :raises pyrtl.PyrtlError: if key_in and plaintext_in differ in length.
    """
    if len(key_in) != len(plaintext_in):
        raise pyrtl.PyrtlError(
            "AES key and plaintext should be the same length")

    plain_text, key = (pyrtl.Register(len(plaintext_in)) for i in range(2))
    key_exp_in, add_round_in = (pyrtl.WireVector(len(plaintext_in)) for i in range(2))

    # The round counter register is fed from the `round` wire, which the
    # conditional block below drives.
    # NOTE(review): `round` shadows the builtin round() inside this method.
    counter = pyrtl.Register(4, 'counter')
    round = pyrtl.WireVector(4, 'round')
    counter.next <<= round

    sub_out = self._sub_bytes(plain_text)
    shift_out = self._shift_rows(sub_out)
    mix_out = self._mix_columns(shift_out)
    # Next round key, derived from the stored key and the current counter.
    key_out = self._key_expansion(key, counter)
    add_round_out = self._add_round_key(add_round_in, key_exp_in)

    with pyrtl.conditional_assignment:
        # Reset (checked first): load the key, and already apply it to the
        # incoming plaintext in the same cycle.
        with reset == 1:
            round |= 0
            key_exp_in |= key_in  # to lower the number of cycles
            plain_text.next |= add_round_out
            key.next |= key_in
            add_round_in |= plaintext_in

        with counter == 10:  # keep everything the same
            round |= counter
            plain_text.next |= plain_text

        with pyrtl.otherwise:  # running through AES
            round |= counter + 1

            key_exp_in |= key_out
            plain_text.next |= add_round_out
            key.next |= key_out
            # The last round (counter == 9) feeds the ShiftRows output to
            # AddRoundKey and skips MixColumns.
            with counter == 9:
                add_round_in |= shift_out
            with pyrtl.otherwise:
                add_round_in |= mix_out

    ready = (counter == 10)
    return ready, plain_text
def decode_instruction(instr):
    """Slice an instruction word into all of its R/I/J-type fields.

    Every field is extracted unconditionally; the caller uses the ones that
    apply to the instruction's format:
        R-TYPE: op, rs, rt, rd, sh, funct
        I-TYPE: op, rs, rt, imm
        J-TYPE: op, addr

    :param instr: instruction WireVector.  Slices reach bit 31, so it is
        presumably 32 bits wide.
    :return: tuple ``(op, rs, rt, rd, sh, func, imm, addr)`` of WireVectors
        named after the fields.
    """
    # (wire name, low bit, high bit exclusive); each wire's bitwidth is the
    # width of its slice.
    layout = (
        ('op', 26, 32),
        ('rs', 21, 26),
        ('rt', 16, 21),
        ('rd', 11, 16),
        ('sh', 6, 11),
        ('func', 0, 6),
        ('imm', 0, 16),
        ('addr', 0, 26),
    )
    fields = [pyrtl.WireVector(bitwidth=hi - lo, name=field_name)
              for field_name, lo, hi in layout]
    for field, (_, lo, hi) in zip(fields, layout):
        field <<= instr[lo:hi]
    return tuple(fields)
def test_wirevector_1(self):
    """Three chained 1-bit wires feeding an inverter; checks the timing
    values and that the block ends up with three logic nets."""
    inwire = pyrtl.Input(bitwidth=1)
    stage0, stage1, stage2 = (pyrtl.WireVector(bitwidth=1) for _ in range(3))
    outwire = pyrtl.Output()

    stage0 <<= inwire
    stage1 <<= stage0
    stage2 <<= stage1
    outwire <<= ~stage2

    self.everything_t_procedure(1, 1)
    self.assertEqual(len(pyrtl.working_block().logic), 3)
def test_wire_net_removal_2(self):
    """synthesize + optimize should remove the middle wires but keep the
    input; checks the net and wire counts through the helper asserts."""
    source = pyrtl.Input(bitwidth=3)
    hop1, hop2 = (pyrtl.WireVector() for _ in range(2))
    sink = pyrtl.Output()

    hop1 <<= source
    hop2 <<= hop1  # dead branch: hop2 drives nothing
    sink <<= hop1

    pyrtl.synthesize()
    pyrtl.optimize()

    block = pyrtl.working_block()
    self.assert_num_net(5, block)
    self.assert_num_wires(6, block)
def test_no_logic_net_comparisons(self):
    """LogicNet must refuse the ordering comparisons <, <=, > and >=."""
    a, b, select, outwire = (pyrtl.WireVector(bitwidth=3) for _ in range(4))
    net1 = pyrtl.LogicNet(op='x', op_param=None, args=(select, a, b), dests=(outwire,))
    net2 = pyrtl.LogicNet(op='x', op_param=None, args=(select, b, a), dests=(outwire,))

    orderings = (
        lambda lhs, rhs: lhs < rhs,
        lambda lhs, rhs: lhs <= rhs,
        lambda lhs, rhs: lhs > rhs,
        lambda lhs, rhs: lhs >= rhs,
    )
    for compare in orderings:
        with self.assertRaises(pyrtl.PyrtlError):
            compare(net1, net2)
def test_timing_error(self):
    """Cross-coupled NAND wires form a combinational loop.  Running
    synthesize, optimize and timing analysis on it must raise PyrtlError
    (any step may be the one that raises)."""
    in_a, in_b = pyrtl.Input(bitwidth=1), pyrtl.Input(bitwidth=1)
    nand_a, nand_b = pyrtl.WireVector(1), pyrtl.WireVector(1)
    result = pyrtl.Output()

    nand_a <<= ~(in_a & nand_b)
    nand_b <<= ~(in_b & nand_a)
    result <<= nand_a

    with self.assertRaises(pyrtl.PyrtlError):
        pyrtl.synthesize()
        pyrtl.optimize()
        estimate.TimingAnalysis(pyrtl.working_block()).max_length()
def test_weird_wire_names(self):
    """ Some simulations need to be careful when handling special names
    (eg Fastsim June 2016).

    Builds a combinational path (input -> wire -> output) and a registered
    path (input -> register -> output), all with unusual names.  Checks
    inspection by object and by name, plus the recorded trace.
    """
    src = pyrtl.Input(8, '"182&!!!\n')
    comb_out = pyrtl.Output(8, '*^*)#*$\'*')
    reg_out = pyrtl.Output(8, 'test@+')
    mid = pyrtl.WireVector(8, '[][[-=--09888')
    reg = pyrtl.Register(8, '&@#)^#@^&(asdfkhafkjh')

    mid <<= src
    reg.next <<= src
    comb_out <<= mid
    reg_out <<= reg

    trace = pyrtl.SimulationTrace()
    sim = self.sim(tracer=trace)

    sim.step({src: 28})
    for handle in (comb_out, comb_out.name):
        self.assertEqual(sim.inspect(handle), 28)
    self.assertEqual(trace.trace[comb_out.name], [28])

    sim.step({src: 233})
    self.assertEqual(sim.inspect(comb_out), 233)
    for handle in (reg_out, reg_out.name):
        self.assertEqual(sim.inspect(handle), 28)
    self.assertEqual(trace.trace[reg_out.name], [0, 28])
def test_no_immed_operators(self):
    """In-place &=, ^=, +=, -= and *= on a WireVector must each raise
    PyrtlError.  Every case starts from a fresh 3-bit wire."""
    def fresh_wire():
        return pyrtl.WireVector(bitwidth=3)

    with self.assertRaises(pyrtl.PyrtlError):
        wire = fresh_wire()
        wire &= 2
    with self.assertRaises(pyrtl.PyrtlError):
        wire = fresh_wire()
        wire ^= 2
    with self.assertRaises(pyrtl.PyrtlError):
        wire = fresh_wire()
        wire += 2
    with self.assertRaises(pyrtl.PyrtlError):
        wire = fresh_wire()
        wire -= 2
    with self.assertRaises(pyrtl.PyrtlError):
        wire = fresh_wire()
        wire *= 2
def test_dup_consts2(self):
    """sparse_mux whose inputs are two separate Consts of the same value
    should collapse to a Const holding that value."""
    sel = pyrtl.WireVector(3)
    four_x, four_y = pyrtl.Const(4), pyrtl.Const(4)
    merged = muxes.sparse_mux(sel, {6: four_x, 2: four_y})
    self.assertIsInstance(merged, pyrtl.Const)
    self.assertEqual(merged.val, 4)
def test_error_condition_connect_const(self):
    """Conditionally assigning (|=) to a Const inside a condition must
    raise PyrtlError."""
    const_wire = pyrtl.Const(3, 2)
    target = pyrtl.WireVector(bitwidth=2, name='o')
    with pyrtl.conditional_assignment:
        with const_wire <= 2:
            with self.assertRaises(pyrtl.PyrtlError):
                const_wire |= target
def test_condition_nice_error_message(self):
    """Opening a condition (``with reg <= 2:``) with no enclosing
    pyrtl.conditional_assignment must raise PyrtlError."""
    with self.assertRaises(pyrtl.PyrtlError):
        counter_reg = pyrtl.Register(bitwidth=2, name='i')
        target = pyrtl.WireVector(bitwidth=2, name='o')
        counter_reg.next <<= counter_reg + 1
        with counter_reg <= 2:
            target |= 1