Пример #1
0
def alu(a, b, op):
    """
        Simplified one-bit ALU built from PyRTL wires.

        Result selected by ``op``:
            op == 0: a AND b
            op == 1: a XNOR b
            op == 2: a + b (sum bit on alu_r, carry bit on alu_cout)
        For op == 3 no branch fires; the outputs then take PyRTL's
        conditional-assignment default (presumably 0 — confirm).

        :return: tuple (alu_r, alu_cout) of 1-bit WireVectors named
            'alu_r' and 'alu_cout'
    """
    and_bit = a & b
    xnor_bit = ~(a ^ b)
    # NOTE(review): assumes half_adder returns (sum, carry) — confirm.
    sum_bit, carry_bit = half_adder(a, b)

    alu_r = pyrtl.WireVector(bitwidth=1, name='alu_r')
    alu_cout = pyrtl.WireVector(bitwidth=1, name='alu_cout')

    # The opcodes are mutually exclusive; only addition produces a carry.
    with pyrtl.conditional_assignment:
        with op == 0:
            alu_r |= and_bit
            alu_cout |= 0
        with op == 1:
            alu_r |= xnor_bit
            alu_cout |= 0
        with op == 2:
            alu_r |= sum_bit
            alu_cout |= carry_bit

    return alu_r, alu_cout
Пример #2
0
def alu(a, b, op):
    """
        Implementation of the desired simplified ALU:
            if op == 0: return a and b
            elif op == 1: return a xnor b
            elif op == 2: return a + b

        The first value returned by half_adder drives alu_r and the second
        drives alu_cout. For op == 3 no branch fires and both outputs take
        PyRTL's conditional-assignment default (presumably 0 — confirm).

        :return: tuple (alu_r, alu_cout) of unnamed 1-bit WireVectors
    """
    # Operation 0: a and b
    op0 = a & b
    # Operation 1: a xnor b
    op1 = ~(a ^ b)
    # Operation 2: a + b
    # The names now match how the values are used: the first output feeds
    # the result, the second the carry. The old names were swapped.
    # NOTE(review): assumes half_adder returns (sum, carry) — confirm.
    op2_s, op2_c = half_adder(a, b)
    # Based on the given "op", return the proper signals as outputs
    alu_r = pyrtl.WireVector(bitwidth=1)
    alu_cout = pyrtl.WireVector(bitwidth=1)

    with pyrtl.conditional_assignment:
        with op == 0:
            alu_r |= op0
            alu_cout |= 0  # logic ops never carry; made explicit
        with op == 1:
            alu_r |= op1
            alu_cout |= 0
        with op == 2:
            alu_r |= op2_s
            alu_cout |= op2_c

    return alu_r, alu_cout
Пример #3
0
    def test_one_bit_adder_matches_expected(self):
        """Build a one-bit full adder and simulate it on random inputs.

        The adder mixes three wiring styles: a named pre-allocated wire
        (temp1), an unnamed pre-allocated wire (temp2) and a wire created
        by the expression itself (temp3). It is simulated for 15 random
        cycles, and the trace is rendered to HTML to check that this
        succeeds.

        NOTE(review): despite the name, no simulated output is compared
        against a^b^c / majority(a, b, c); consider adding such checks.
        """
        temp1 = pyrtl.WireVector(bitwidth=1, name='temp1')
        temp2 = pyrtl.WireVector()

        a, b, c = pyrtl.Input(1, 'a'), pyrtl.Input(1, 'b'), pyrtl.Input(1, 'c')
        # Bound to `sum_out` so the builtin `sum` is not shadowed; the
        # hardware output is still named 'sum'.
        sum_out, carry_out = pyrtl.Output(1, 'sum'), pyrtl.Output(1, 'carry_out')

        sum_out <<= a ^ b ^ c

        temp1 <<= a & b  # connect the result of a & b to the pre-allocated wirevector
        temp2 <<= a & c
        temp3 = b & c  # temp3 IS the result of b & c (this is the first mention of temp3)
        carry_out <<= temp1 | temp2 | temp3

        sim_trace = pyrtl.SimulationTrace()
        sim = pyrtl.Simulation(tracer=sim_trace)
        for _ in range(15):
            sim.step({
                'a': random.choice([0, 1]),
                'b': random.choice([0, 1]),
                'c': random.choice([0, 1])
                })

        htmlstring = inputoutput.trace_to_html(sim_trace)  # tests if it compiles or not
Пример #4
0
 def test_no_dup(self):
     """sparse_mux over two distinct wires must not simply return one of them."""
     selector = pyrtl.WireVector(3)
     first = pyrtl.WireVector(3)
     second = pyrtl.WireVector(3)
     result = muxes.sparse_mux(selector, {6: first, 2: second})
     for candidate in (first, second):
         self.assertIsNot(result, candidate)
Пример #5
0
 def test_mux_enough_inputs_with_default(self):
     """A 2-bit select with exactly four inputs plus a default must build cleanly.

     The test passes as long as constructing the mux does not raise.
     """
     widths = (('a', 3), ('b', 1), ('c', 1), ('d', 1))
     a, b, c, d = [pyrtl.WireVector(name=wire_name, bitwidth=width)
                   for wire_name, width in widths]
     select = pyrtl.WireVector(name='s', bitwidth=2)
     pyrtl.corecircuits.mux(select, a, b, c, d, default=0)
Пример #6
0
    def test_equivelence_of_different_nets(self):
        """LogicNets differing in op, op_param, args or dests must compare as different."""
        a, b, c = pyrtl.WireVector(), pyrtl.WireVector(), pyrtl.WireVector()

        minus_john = pyrtl.LogicNet('-', 'John', (a, b), (c, ))
        plus_john = pyrtl.LogicNet('+', 'John', (a, b), (c, ))
        plus_xx = pyrtl.LogicNet('+', 'xx', (a, b), (c, ))
        swapped_args = pyrtl.LogicNet('+', 'xx', (b, a), (c, ))
        three_args = pyrtl.LogicNet('+', 'xx', (b, a, c), (c, ))
        two_dests = pyrtl.LogicNet('+', 'xx', (b, a, c), (c, a))
        other_dest = pyrtl.LogicNet('+', 'xx', (b, a, c), (a, ))

        unequal_pairs = [
            (minus_john, plus_john),     # op differs
            (plus_john, plus_xx),        # op_param differs
            (plus_xx, swapped_args),     # argument order differs
            (swapped_args, three_args),  # argument count differs
            (three_args, two_dests),     # destination count differs
            (three_args, other_dest),    # destination wire differs
            (two_dests, other_dest),     # destination count and wires differ
        ]
        for first, second in unequal_pairs:
            self.assertDifferentNets(first, second)

        # Edge case: a repeated argument is not the same as a single argument.
        self.assertDifferentNets(
            pyrtl.LogicNet('+', 'John', (a, a), (c, )),
            pyrtl.LogicNet('+', 'John', (a, ), (c, )))
Пример #7
0
 def test_mux_too_many_inputs(self):
     """Three inputs on a 1-bit select (two addressable slots) must raise."""
     a = pyrtl.WireVector(name='a', bitwidth=3)
     b, c = (pyrtl.WireVector(name=wire_name, bitwidth=1) for wire_name in ('b', 'c'))
     select = pyrtl.WireVector(name='s', bitwidth=1)
     with self.assertRaises(pyrtl.PyrtlError):
         pyrtl.mux(select, a, b, c)
Пример #8
0
 def test_mux_not_enough_inputs(self):
     """Three inputs and no default on a 2-bit select (four slots) must raise."""
     a = pyrtl.WireVector(name='a', bitwidth=3)
     b, c = (pyrtl.WireVector(name=wire_name, bitwidth=1) for wire_name in ('b', 'c'))
     select = pyrtl.WireVector(name='s', bitwidth=2)
     with self.assertRaises(pyrtl.PyrtlError):
         pyrtl.corecircuits.mux(select, a, b, c)
Пример #9
0
 def test_equivelence_of_same_nets(self):
     """Two separately built LogicNets with identical fields are distinct objects but equal."""
     a, b, c = (pyrtl.WireVector(1) for _ in range(3))

     def build():
         return pyrtl.LogicNet('+', 'xx', (a, b), (c,))

     first, second = build(), build()
     self.assertIsNot(first, second)
     self.assertEqual(first, second)
Пример #10
0
    def decryption_statem(self, ciphertext_in, key_in, reset):
        """
        Builds a multiple cycle AES Decryption state machine circuit

        :param ciphertext_in: WireVector holding the block to decrypt
        :param key_in: WireVector holding the cipher key; must have the
          same bit length as ciphertext_in
        :param reset: a one bit signal telling the state machine
          to reset and accept the current ciphertext and key
        :return ready, plain_text: ready is a one bit signal (counter == 10)
          showing that the decryption result has been calculated.
          The second value is the internal state register (named
          ``cipher_text`` in the code), which holds the decryption result
          once ready is high.
        :raises pyrtl.PyrtlError: if key_in and ciphertext_in differ in length

        """
        if len(key_in) != len(ciphertext_in):
            raise pyrtl.PyrtlError(
                "AES key and ciphertext should be the same length")

        # State registers: the round state (despite its name, it ends up
        # holding the plaintext) and the cipher key.
        cipher_text, key = (pyrtl.Register(len(ciphertext_in))
                            for i in range(2))
        # Combinational inputs to the key schedule and to AddRoundKey.
        key_exp_in, add_round_in = (pyrtl.WireVector(len(ciphertext_in))
                                    for i in range(2))

        # this is not part of the state machine as we need the keys in
        # reverse order...
        # NOTE(review): assumes _key_gen returns round keys in encryption
        # order, so reversing yields decryption order — confirm.
        # reversed() is a one-shot iterator, consumed once by the mux below.
        reversed_key_list = reversed(self._key_gen(key_exp_in))

        # `round` is the value counter will take next cycle.
        counter = pyrtl.Register(4, 'counter')
        round = pyrtl.WireVector(4)
        counter.next <<= round

        # One decryption round, computed combinationally from the current state.
        inv_shift = self._inv_shift_rows(cipher_text)
        inv_sub = self._sub_bytes(inv_shift, True)  # True presumably selects the inverse S-box — confirm
        # Round key indexed by the next round number (0 during reset).
        key_out = pyrtl.mux(round, *reversed_key_list, default=0)
        add_round_out = self._add_round_key(add_round_in, key_out)
        inv_mix_out = self._mix_columns(add_round_out, True)

        with pyrtl.conditional_assignment:
            # Reset: latch the key and perform the initial AddRoundKey
            # directly on ciphertext_in, using the first reversed key.
            with reset == 1:
                round |= 0
                key.next |= key_in
                key_exp_in |= key_in  # to lower the number of cycles needed
                cipher_text.next |= add_round_out
                add_round_in |= ciphertext_in

            # Finished: hold the state. key.next is not assigned here
            # (presumably the register keeps its value — confirm).
            with counter == 10:  # keep everything the same
                round |= counter
                cipher_text.next |= cipher_text

            with pyrtl.otherwise:  # running through AES
                round |= counter + 1

                key.next |= key
                key_exp_in |= key
                add_round_in |= inv_sub
                # The last round (counter == 9) skips InvMixColumns.
                with counter == 9:
                    cipher_text.next |= add_round_out
                with pyrtl.otherwise:
                    cipher_text.next |= inv_mix_out

        ready = (counter == 10)
        return ready, cipher_text
Пример #11
0
 def test_bad_bitwidth(self):
     """WireVector must reject a string, negative, zero or WireVector bitwidth."""
     for bad_width in ('happy', -1, 0):
         with self.assertRaises(pyrtl.PyrtlError):
             pyrtl.WireVector(bitwidth=bad_width)
     existing = pyrtl.WireVector(1)
     with self.assertRaises(pyrtl.PyrtlError):
         pyrtl.WireVector(existing)
Пример #12
0
    def test_timing_basic_2(self):
        """Run the timing procedure on (~(in_a | in_b)) & in_c with 1-bit inputs."""
        in_a, in_b, in_c = (pyrtl.Input(bitwidth=1) for _ in range(3))
        or_out, nor_out = pyrtl.WireVector(), pyrtl.WireVector()
        result = pyrtl.Output()

        or_out <<= in_a | in_b
        nor_out <<= ~or_out
        result <<= nor_out & in_c
        self.everything_t_procedure(3, 3)
Пример #13
0
    def test_combo_1(self):
        """Run the timing procedure on (~(in_a | in_b)) & in_c, declared in a different order."""
        in_a, in_b = (pyrtl.Input(bitwidth=1) for _ in range(2))
        or_out, nor_out = pyrtl.WireVector(), pyrtl.WireVector()
        in_c = pyrtl.Input(bitwidth=1)
        result = pyrtl.Output()

        or_out <<= in_a | in_b
        nor_out <<= ~or_out
        result <<= nor_out & in_c
        self.everything_t_procedure(252.3, 252.3)
Пример #14
0
 def test_wrong_input_types_fail(self):
     """Each of these malformed pyrtl.chop calls must raise PyrtlError."""
     instr = pyrtl.WireVector(name='instr', bitwidth=32)
     x = pyrtl.WireVector(name='x', bitwidth=10)
     bad_calls = (
         (instr, x, 10, 12),  # WireVector where a segment width goes
         (instr, 10, x, 12),  # same, in the second segment position
         (instr, x),          # WireVector as the only segment width
         (10, 5, 5),          # plain int as the value to chop
     )
     for args in bad_calls:
         with self.assertRaises(pyrtl.PyrtlError):
             pyrtl.chop(*args)
Пример #15
0
    def test_net_odd_wires(self):
        """Nets whose wire belongs to another block, or was removed, must be reported invalid."""
        # Case 1: the (doubly used) argument wire is re-homed to a foreign block.
        shared = pyrtl.WireVector(2, 'wire')
        doubled_net = self.new_net(args=(shared, shared))
        shared._block = pyrtl.Block()
        self.invalid_net("net references different block", doubled_net)

        # Case 2: start fresh, then delete the argument wire from the block.
        pyrtl.reset_working_block()
        orphan = pyrtl.WireVector(2, 'wire')
        lone_net = self.new_net(args=(orphan,))
        pyrtl.working_block().remove_wirevector(orphan)
        self.invalid_net("net with unknown source", lone_net)
Пример #16
0
    def test_randomly_replace(self):
        """After insert_random_inversions(1), a & b should become 3 '~' nets plus 1 '&' net.

        The surviving '&' net must no longer read the original wires or
        drive the original output wire.
        """
        a, b = pyrtl.WireVector(3), pyrtl.WireVector(3)
        out = a & b
        insert_random_inversions(1)
        block = pyrtl.working_block()
        for op, expected_count in (('~', 3), ('&', 1)):
            self.num_net_of_type(op, expected_count, block)

        and_net = block.logic_subset('&').pop()
        for arg in and_net.args:
            for original in (a, b):
                self.assertIsNot(arg, original)
        self.assertIsNot(and_net.dests[0], out)
Пример #17
0
    def test_wirevector_1(self):
        """Chain three 1-bit wires into an inverter, run timing, then expect 3 nets."""
        inwire = pyrtl.Input(bitwidth=1)
        chain = [pyrtl.WireVector(bitwidth=1) for _ in range(3)]
        outwire = pyrtl.Output()

        previous = inwire
        for stage in chain:
            stage <<= previous
            previous = stage
        outwire <<= ~previous
        self.everything_t_procedure(48.5, 48.5)
        self.assert_num_net(3, pyrtl.working_block())
Пример #18
0
    def __init__(self, depth_width=2, data_width=32):
        """Build a FIFO with ``1 << depth_width`` entries of ``data_width`` bits.

        The read and write pointers carry one extra (wrap) bit above the
        address bits. When the low ``depth_width`` bits match, equal wrap
        bits mean empty and differing wrap bits mean full.

        :param depth_width: number of address bits into the storage RAM
        :param data_width: bit width of each stored word
        """
        aw = depth_width
        dw = data_width

        # External ports.
        self.wr_data_i = pyrtl.Input(dw, 'wr_data_i')
        self.wr_en_i = pyrtl.Input(1, 'wr_en_i')
        self.rd_data_o = pyrtl.Output(dw, 'rd_data_o')
        self.rd_en_i = pyrtl.Input(1, 'rd_en_i')
        self.full_o = pyrtl.Output(1, 'full_o')
        self.empty_o = pyrtl.Output(1, 'empty_o')
        self.one_left = pyrtl.Output(1, 'one_left')

        self.reset = pyrtl.Input(1, 'reset')

        # Pointers are aw + 1 bits: aw address bits plus one wrap bit.
        self.write_pointer = pyrtl.Register(aw + 1, 'write_pointer')
        self.read_pointer = pyrtl.Register(aw + 1, 'read_pointer')

        # aw + 2 bits wide; the carry bit is dropped on assignment below.
        self.read_plus_1 = pyrtl.Const(1, 1) + self.read_pointer
        # Reset clears BOTH pointers. Previously only the write pointer was
        # cleared, leaving a stale read pointer and a corrupt FIFO state.
        self.read_pointer.next <<= pyrtl.select(
            self.reset,
            truecase=pyrtl.Const(0, aw + 1),
            falsecase=pyrtl.select(self.rd_en_i,
                                   truecase=self.read_plus_1,
                                   falsecase=self.read_pointer))

        self.write_pointer.next <<= pyrtl.select(
            self.reset,
            truecase=pyrtl.Const(0, aw + 1),
            falsecase=pyrtl.select(self.wr_en_i,
                                   truecase=self.write_pointer + 1,
                                   falsecase=self.write_pointer))

        # empty_int: the wrap bits are equal.
        # full_or_empty: the address bits are equal.
        self.empty_int = pyrtl.WireVector(1, 'empty_int')
        self.full_or_empty = pyrtl.WireVector(1, 'full_or_empty')

        self.empty_int <<= self.write_pointer[aw] == self.read_pointer[aw]
        self.full_or_empty <<= self.write_pointer[0:aw] == self.read_pointer[
            0:aw]
        self.full_o <<= self.full_or_empty & ~self.empty_int
        self.empty_o <<= self.full_or_empty & self.empty_int
        # Exactly one entry is stored when write == read + 1 modulo
        # 2**(aw + 1). read_plus_1 is aw + 2 bits wide, so truncate it first.
        # Otherwise the carry bit makes the wrap-around case (read_pointer
        # all ones, write_pointer back at 0) compare unequal.
        self.one_left <<= self.read_plus_1[0:aw + 1] == self.write_pointer

        self.mem = m = _RAM(num_entries=1 << aw,
                            data_nbits=dw,
                            name='FIFOStorage')
        m.wen <<= self.wr_en_i
        m.ren <<= self.rd_en_i
        m.raddr <<= self.read_pointer[0:aw]
        m.waddr <<= self.write_pointer[0:aw]
        m.wdata <<= self.wr_data_i
        # Drive zero on the data output when no read is requested.
        self.rd_data_o <<= pyrtl.select(self.rd_en_i,
                                        truecase=m.rdata,
                                        falsecase=pyrtl.Const(0, dw))
Пример #19
0
 def test_wire_net_removal_2(self):
     """synthesize + optimize should strip pass-through wires (5 nets, 6 wires remain)."""
     source = pyrtl.Input(bitwidth=3)
     middle, dangling = pyrtl.WireVector(), pyrtl.WireVector()
     sink = pyrtl.Output()
     middle <<= source
     dangling <<= middle
     sink <<= middle
     for transform in (pyrtl.synthesize, pyrtl.optimize):
         transform()
     # should remove the middle wires but keep the input
     block = pyrtl.working_block(None)
     net_count, wire_count = len(block.logic), len(block.wirevector_set)
     self.assertEqual(net_count, 5)
     self.assertEqual(wire_count, 6)
Пример #20
0
    def encrypt_state_m(self, plaintext_in, key_in, reset):
        """
        Builds a multiple cycle AES Encryption state machine circuit

        :param plaintext_in: WireVector holding the block to encrypt
        :param key_in: WireVector holding the cipher key; must have the
          same bit length as plaintext_in
        :param reset: a one bit signal telling the state machine
          to reset and accept the current plaintext and key
        :return ready, cipher_text: ready is a one bit signal (counter == 10)
          showing that the encryption result (cipher_text) has been
          calculated. The second value is the internal state register
          (named ``plain_text`` in the code).
        :raises pyrtl.PyrtlError: if key_in and plaintext_in differ in length

        """
        if len(key_in) != len(plaintext_in):
            raise pyrtl.PyrtlError(
                "AES key and plaintext should be the same length")

        # State registers: the round state and the current round key.
        plain_text, key = (pyrtl.Register(len(plaintext_in)) for i in range(2))
        # Combinational inputs to AddRoundKey (key and data sides).
        key_exp_in, add_round_in = (pyrtl.WireVector(len(plaintext_in))
                                    for i in range(2))

        # `round` is the value counter will take next cycle.
        counter = pyrtl.Register(4, 'counter')
        round = pyrtl.WireVector(4, 'round')
        counter.next <<= round
        # One encryption round, computed combinationally from the current state.
        sub_out = self._sub_bytes(plain_text)
        shift_out = self._shift_rows(sub_out)
        mix_out = self._mix_columns(shift_out)
        # Next round key from the current key register; counter presumably
        # selects the round constant — confirm in _key_expansion.
        key_out = self._key_expansion(key, counter)
        add_round_out = self._add_round_key(add_round_in, key_exp_in)
        with pyrtl.conditional_assignment:
            # Reset: initial AddRoundKey of plaintext_in with the raw key_in.
            with reset == 1:
                round |= 0
                key_exp_in |= key_in  # to lower the number of cycles
                plain_text.next |= add_round_out
                key.next |= key_in
                add_round_in |= plaintext_in

            # Finished: hold the state. key.next is not assigned here
            # (presumably the register keeps its value — confirm).
            with counter == 10:  # keep everything the same
                round |= counter
                plain_text.next |= plain_text

            with pyrtl.otherwise:  # running through AES
                round |= counter + 1
                key_exp_in |= key_out
                plain_text.next |= add_round_out
                key.next |= key_out
                # The last round (counter == 9) skips MixColumns.
                with counter == 9:
                    add_round_in |= shift_out
                with pyrtl.otherwise:
                    add_round_in |= mix_out

        ready = (counter == 10)
        return ready, plain_text
Пример #21
0
def decode_instruction(instr):
    """Split a 32-bit instruction word into its R/I/J-type (MIPS-style) fields.

        R-TYPE: op, rs, rt, rd, sh, funct
        I-TYPE: op, rs, rt, imm
        J-TYPE: op, addr

    Every field is extracted unconditionally; callers pick the fields that
    apply to the instruction's format.

    :param instr: instruction WireVector (assumed at least 32 bits — confirm)
    :return: tuple (op, rs, rt, rd, sh, func, imm, addr) of WireVectors
        named after each field
    """
    # (wire name, low bit, high bit exclusive), listed in return order.
    field_layout = (
        ('op', 26, 32),
        ('rs', 21, 26),
        ('rt', 16, 21),
        ('rd', 11, 16),
        ('sh', 6, 11),
        ('func', 0, 6),
        ('imm', 0, 16),
        ('addr', 0, 26),
    )
    fields = [pyrtl.WireVector(bitwidth=hi - lo, name=field_name)
              for field_name, lo, hi in field_layout]
    for wire, (_, lo, hi) in zip(fields, field_layout):
        wire <<= instr[lo:hi]
    return tuple(fields)
Пример #22
0
    def test_wirevector_1(self):
        """Chain three 1-bit wires into an inverter, run timing, then expect 3 logic nets."""
        source = pyrtl.Input(bitwidth=1)
        stage_a = pyrtl.WireVector(bitwidth=1)
        stage_b = pyrtl.WireVector(bitwidth=1)
        stage_c = pyrtl.WireVector(bitwidth=1)
        sink = pyrtl.Output()

        stage_a <<= source
        stage_b <<= stage_a
        stage_c <<= stage_b
        sink <<= ~stage_c
        self.everything_t_procedure(1, 1)
        self.assertEqual(len(pyrtl.working_block().logic), 3)
Пример #23
0
 def test_wire_net_removal_2(self):
     """Optimization should drop pass-through wires, leaving 5 nets and 6 wires."""
     in_wire = pyrtl.Input(bitwidth=3)
     hop1 = pyrtl.WireVector()
     hop2 = pyrtl.WireVector()
     out_wire = pyrtl.Output()
     hop1 <<= in_wire
     hop2 <<= hop1
     out_wire <<= hop1

     pyrtl.synthesize()
     pyrtl.optimize()

     # should remove the middle wires but keep the input
     optimized = pyrtl.working_block()
     self.assert_num_net(5, optimized)
     self.assert_num_wires(6, optimized)
Пример #24
0
 def test_no_logic_net_comparisons(self):
     """Ordering comparisons (<, <=, >, >=) between LogicNets must raise PyrtlError."""
     a = pyrtl.WireVector(bitwidth=3)
     b = pyrtl.WireVector(bitwidth=3)
     select = pyrtl.WireVector(bitwidth=3)
     outwire = pyrtl.WireVector(bitwidth=3)
     left = pyrtl.LogicNet(op='x', op_param=None, args=(select, a, b), dests=(outwire,))
     right = pyrtl.LogicNet(op='x', op_param=None, args=(select, b, a), dests=(outwire,))

     orderings = (
         lambda x, y: x < y,
         lambda x, y: x <= y,
         lambda x, y: x > y,
         lambda x, y: x >= y,
     )
     for compare in orderings:
         with self.assertRaises(pyrtl.PyrtlError):
             compare(left, right)
Пример #25
0
    def test_timing_error(self):
        """A combinational loop (two cross-coupled NANDs) must make this flow raise PyrtlError.

        NOTE(review): the error may come from any step inside the
        assertRaises block (synthesize, optimize or timing analysis).
        """
        in_set, in_reset = pyrtl.Input(bitwidth=1), pyrtl.Input(bitwidth=1)
        q, q_bar = pyrtl.WireVector(1), pyrtl.WireVector(1)
        outwire = pyrtl.Output()

        # Each NAND feeds the other, with no register to break the loop.
        q <<= ~(in_set & q_bar)
        q_bar <<= ~(in_reset & q)
        outwire <<= q

        with self.assertRaises(pyrtl.PyrtlError):
            pyrtl.synthesize()
            pyrtl.optimize()
            analysis = estimate.TimingAnalysis(pyrtl.working_block())
            analysis.max_length()
Пример #26
0
    def test_weird_wire_names(self):
        """
        Wires named with quotes, escapes and punctuation must still simulate,
        be inspectable by object and by name, and appear in the trace
        (some simulations, e.g. Fastsim in June 2016, mishandled them).
        """
        inp = pyrtl.Input(8, '"182&!!!\n')
        out = pyrtl.Output(8, '*^*)#*$\'*')
        out_reg = pyrtl.Output(8, 'test@+')
        passthrough = pyrtl.WireVector(8, '[][[-=--09888')
        delay = pyrtl.Register(8, '&@#)^#@^&(asdfkhafkjh')

        passthrough <<= inp
        delay.next <<= inp
        out <<= passthrough
        out_reg <<= delay

        trace = pyrtl.SimulationTrace()
        sim = self.sim(tracer=trace)

        def check_value(wire, expected):
            # Inspect both by wire object and by its raw name string.
            self.assertEqual(sim.inspect(wire), expected)
            self.assertEqual(sim.inspect(wire.name), expected)

        sim.step({inp: 28})
        check_value(out, 28)
        self.assertEqual(trace.trace[out.name], [28])

        sim.step({inp: 233})
        self.assertEqual(sim.inspect(out), 233)
        check_value(out_reg, 28)  # the register output lags one cycle
        self.assertEqual(trace.trace[out_reg.name], [0, 28])
Пример #27
0
 def test_no_immed_operators(self):
     """In-place operators (&=, ^=, +=, -=, *=) on a WireVector must raise PyrtlError."""
     def fresh():
         return pyrtl.WireVector(bitwidth=3)

     with self.assertRaises(pyrtl.PyrtlError):
         wire = fresh()
         wire &= 2
     with self.assertRaises(pyrtl.PyrtlError):
         wire = fresh()
         wire ^= 2
     with self.assertRaises(pyrtl.PyrtlError):
         wire = fresh()
         wire += 2
     with self.assertRaises(pyrtl.PyrtlError):
         wire = fresh()
         wire -= 2
     with self.assertRaises(pyrtl.PyrtlError):
         wire = fresh()
         wire *= 2
Пример #28
0
 def test_dup_consts2(self):
     """sparse_mux over two equal-valued Consts should yield a Const of that value."""
     selector = pyrtl.WireVector(3)
     four_a, four_b = pyrtl.Const(4), pyrtl.Const(4)
     result = muxes.sparse_mux(selector, {6: four_a, 2: four_b})
     self.assertIsInstance(result, pyrtl.Const)
     self.assertEqual(result.val, 4)
Пример #29
0
 def test_error_condition_connect_const(self):
     """Conditionally assigning into a Const (inside conditional_assignment) must raise."""
     const_three = pyrtl.Const(3, 2)
     target = pyrtl.WireVector(bitwidth=2, name='o')
     with pyrtl.conditional_assignment:
         with const_three <= 2:
             with self.assertRaises(pyrtl.PyrtlError):
                 const_three |= target
Пример #30
0
 def test_condition_nice_error_message(self):
     """Using a condition outside pyrtl.conditional_assignment must raise PyrtlError."""
     with self.assertRaises(pyrtl.PyrtlError):
         counter = pyrtl.Register(bitwidth=2, name='i')
         out = pyrtl.WireVector(bitwidth=2, name='o')
         counter.next <<= counter + 1
         # No enclosing conditional_assignment context here.
         with counter <= 2:
             out |= 1