Exemplo n.º 1
0
def barrel_shifter(shift_in, bit_in, direction, shift_dist, wrap_around=0):
    """
    Create a barrel shifter that operates on data based on the wire width.

    The shifter is a chain of stages. Stage ``i`` either passes its input
    through unchanged or shifts it by ``2**i`` positions, as selected by
    bit ``i`` of ``shift_dist``.

    :param shift_in: the input wire to be shifted
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit wirevector representing shift direction;
        0 = shift down, 1 = shift up
    :param shift_dist: wirevector representing offset to shift
    :param wrap_around: ****currently not implemented**** (value is ignored)
    :return: shifted wirevector, the same width as ``shift_in``
    """
    val = shift_in
    append_val = bit_in

    # Number of select bits needed to express every distance in
    # 0 .. len(shift_in) - 1.  bit_length() is exact integer arithmetic, so
    # power-of-two widths get their top stage and a 1-bit input yields zero
    # stages instead of a math domain error.
    num_stages = max(len(shift_in) - 1, 0).bit_length()

    if len(shift_dist) > num_stages:
        print('Warning: for barrel shifter, the shift distance wirevector '
              'has bits that are not used in the barrel shifter')

    for i in range(min(len(shift_dist), num_stages)):
        shift_amt = 1 << i  # stages shift 1,2,4,8,...
        # Build the shifted candidate for this stage.  Shifting up drops the
        # top bits and fills from the bottom; shifting down does the reverse.
        shifted = pyrtl.mux(direction,
                            truecase=pyrtl.concat(val[:-shift_amt], append_val),
                            falsecase=pyrtl.concat(append_val, val[shift_amt:]))
        # Bit i of the shift distance (not i - 1) enables the 2**i stage.
        val = pyrtl.mux(shift_dist[i], truecase=shifted, falsecase=val)
        # The next stage shifts twice as far, so it needs twice the fill bits.
        append_val = pyrtl.concat(append_val, append_val)

    return val
Exemplo n.º 2
0
def barrel_shifter(shift_in, bit_in, direction, shift_dist, wrap_around=0):
    """
    Create a barrel shifter that operates on data based on the wire width.

    Implemented as a chain of mux stages: stage ``i`` chooses between its
    input and that input shifted by ``2**i`` positions, controlled by bit
    ``i`` of ``shift_dist``.

    :param shift_in: the input wire to be shifted
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit wirevector representing shift direction;
        0 = shift down, 1 = shift up
    :param shift_dist: wirevector representing offset to shift
    :param wrap_around: ****currently not implemented**** (value is ignored)
    :return: shifted wirevector, the same width as ``shift_in``
    """
    val = shift_in
    append_val = bit_in

    # Stages required to cover distances 0 .. len(shift_in) - 1.  Integer
    # bit_length() avoids float log rounding, gives power-of-two widths their
    # top stage, and handles a 1-bit input (zero stages) without raising.
    num_stages = max(len(shift_in) - 1, 0).bit_length()

    if len(shift_dist) > num_stages:
        print('Warning: for barrel shifter, the shift distance wirevector '
              'has bits that are not used in the barrel shifter')

    for i in range(min(len(shift_dist), num_stages)):
        shift_amt = 1 << i  # stages shift 1,2,4,8,...
        # Candidate value if this stage shifts: up keeps the low bits and
        # fills below them, down keeps the high bits and fills above them.
        shifted = pyrtl.mux(direction,
                            truecase=pyrtl.concat(val[:-shift_amt], append_val),
                            falsecase=pyrtl.concat(append_val, val[shift_amt:]))
        # Select shifted vs. unshifted using the i-th bit of the distance.
        val = pyrtl.mux(shift_dist[i], truecase=shifted, falsecase=val)
        # Double the fill pattern for the next, twice-as-large stage.
        append_val = pyrtl.concat(append_val, append_val)

    return val
Exemplo n.º 3
0
 def test_mux_enough_inputs_with_default(self):
     """Four inputs on a 2-bit select plus a default should build cleanly."""
     in0 = pyrtl.WireVector(bitwidth=3, name='a')
     in1 = pyrtl.WireVector(bitwidth=1, name='b')
     in2 = pyrtl.WireVector(bitwidth=1, name='c')
     in3 = pyrtl.WireVector(bitwidth=1, name='d')
     select = pyrtl.WireVector(bitwidth=2, name='s')
     # No assertion: the test passes if constructing the mux does not raise.
     pyrtl.mux(select, in0, in1, in2, in3, default=0)
Exemplo n.º 4
0
 def test_mux_not_enough_inputs(self):
     """Three inputs on a 2-bit select with no default must raise PyrtlError."""
     in0 = pyrtl.WireVector(bitwidth=3, name='a')
     in1 = pyrtl.WireVector(bitwidth=1, name='b')
     in2 = pyrtl.WireVector(bitwidth=1, name='c')
     select = pyrtl.WireVector(bitwidth=2, name='s')
     with self.assertRaises(pyrtl.PyrtlError):
         pyrtl.mux(select, in0, in1, in2)
Exemplo n.º 5
0
 def test_mux_not_enough_inputs(self):
     """Three inputs on a 2-bit select with no default must raise PyrtlError."""
     in0 = pyrtl.WireVector(bitwidth=3, name='a')
     in1 = pyrtl.WireVector(bitwidth=1, name='b')
     in2 = pyrtl.WireVector(bitwidth=1, name='c')
     select = pyrtl.WireVector(bitwidth=2, name='s')
     with self.assertRaises(pyrtl.PyrtlError):
         pyrtl.mux(select, in0, in1, in2)
Exemplo n.º 6
0
 def test_mux_enough_inputs_with_default(self):
     """Four inputs on a 2-bit select plus a default should build cleanly."""
     in0 = pyrtl.WireVector(bitwidth=3, name='a')
     in1 = pyrtl.WireVector(bitwidth=1, name='b')
     in2 = pyrtl.WireVector(bitwidth=1, name='c')
     in3 = pyrtl.WireVector(bitwidth=1, name='d')
     select = pyrtl.WireVector(bitwidth=2, name='s')
     # No assertion: the test passes if constructing the mux does not raise.
     pyrtl.mux(select, in0, in1, in2, in3, default=0)
Exemplo n.º 7
0
    def decryption_statem(self, ciphertext_in, key_in, reset):
        """
        Build a multiple cycle AES decryption state machine circuit.

        :param ciphertext_in: wire carrying the ciphertext to decrypt
        :param key_in: wire carrying the key; must be as wide as ciphertext_in
        :param reset: a one bit signal telling the state machine to restart
          at round 0 and take in the current ciphertext and key
        :return ready, plain_text: ready is a one bit signal that is high
          while the round counter equals 10; plain_text is the state
          register holding the result of the rounds run so far.
        :raises pyrtl.PyrtlError: if key_in and ciphertext_in differ in width
        """
        if len(key_in) != len(ciphertext_in):
            raise pyrtl.PyrtlError("AES key and ciphertext should be the same length")

        width = len(ciphertext_in)
        state = pyrtl.Register(width)
        key_reg = pyrtl.Register(width)
        key_expand_src = pyrtl.WireVector(width)
        round_key_operand = pyrtl.WireVector(width)

        # Key expansion sits outside the conditional block because the round
        # keys are consumed in reverse order during decryption.
        keys_last_first = reversed(self._key_gen(key_expand_src))

        counter = pyrtl.Register(4, 'counter')
        next_round = pyrtl.WireVector(4)
        counter.next <<= next_round

        # Datapath for one round, wired once and steered by the logic below.
        unshifted = self._inv_shift_rows(state)
        unsubstituted = self._sub_bytes(unshifted, True)
        selected_key = pyrtl.mux(next_round, *keys_last_first, default=0)
        keyed = self._add_round_key(round_key_operand, selected_key)
        unmixed = self._mix_columns(keyed, True)

        with pyrtl.conditional_assignment:
            with reset == 1:
                # Restart: latch the key and apply the first round key
                # straight to the incoming ciphertext.
                next_round |= 0
                key_reg.next |= key_in
                key_expand_src |= key_in  # to lower the number of cycles needed
                state.next |= keyed
                round_key_operand |= ciphertext_in

            with counter == 10:  # finished: hold the round count and state
                next_round |= counter
                state.next |= state

            with pyrtl.otherwise:  # stepping through the AES rounds
                next_round |= counter + 1

                key_reg.next |= key_reg
                key_expand_src |= key_reg
                round_key_operand |= unsubstituted
                with counter == 9:
                    # The last round stores the keyed value directly,
                    # skipping the inverse mix-columns step.
                    state.next |= keyed
                with pyrtl.otherwise:
                    state.next |= unmixed

        return counter == 10, state
Exemplo n.º 8
0
 def test_mux_too_many_inputs_with_default(self):
     """Five inputs on a 2-bit select must raise PyrtlError even with a default."""
     in0 = pyrtl.WireVector(bitwidth=3, name='a')
     in1 = pyrtl.WireVector(bitwidth=1, name='b')
     in2 = pyrtl.WireVector(bitwidth=1, name='c')
     in3 = pyrtl.WireVector(bitwidth=1, name='d')
     in4 = pyrtl.WireVector(bitwidth=1, name='e')
     select = pyrtl.WireVector(bitwidth=2, name='s')
     with self.assertRaises(pyrtl.PyrtlError):
         pyrtl.mux(select, in0, in1, in2, in3, in4, default=0)
Exemplo n.º 9
0
 def test_mux_too_many_inputs_with_default(self):
     """Five inputs on a 2-bit select must raise PyrtlError even with a default."""
     in0 = pyrtl.WireVector(bitwidth=3, name='a')
     in1 = pyrtl.WireVector(bitwidth=1, name='b')
     in2 = pyrtl.WireVector(bitwidth=1, name='c')
     in3 = pyrtl.WireVector(bitwidth=1, name='d')
     in4 = pyrtl.WireVector(bitwidth=1, name='e')
     select = pyrtl.WireVector(bitwidth=2, name='s')
     with self.assertRaises(pyrtl.PyrtlError):
         pyrtl.mux(select, in0, in1, in2, in3, in4, default=0)
Exemplo n.º 10
0
    def decryption_statem(self, ciphertext_in, key_in, reset):
        """
        Build a multiple cycle AES decryption state machine circuit.

        :param ciphertext_in: wire carrying the ciphertext to decrypt
        :param key_in: wire carrying the key; must be as wide as ciphertext_in
        :param reset: a one bit signal telling the state machine to restart
          at round 0 and take in the current ciphertext and key
        :return ready, plain_text: ready is a one bit signal that is high
          while the round counter equals 10; plain_text is the state
          register holding the result of the rounds run so far.
        :raises pyrtl.PyrtlError: if key_in and ciphertext_in differ in width
        """
        if len(key_in) != len(ciphertext_in):
            raise pyrtl.PyrtlError("AES key and ciphertext should be the same length")

        width = len(ciphertext_in)
        state = pyrtl.Register(width)
        key_reg = pyrtl.Register(width)
        key_expand_src = pyrtl.WireVector(width)
        round_key_operand = pyrtl.WireVector(width)

        # Key expansion sits outside the conditional block because the round
        # keys are consumed in reverse order during decryption.
        keys_last_first = reversed(self._key_gen(key_expand_src))

        counter = pyrtl.Register(4, 'counter')
        next_round = pyrtl.WireVector(4)
        counter.next <<= next_round

        # Datapath for one round, wired once and steered by the logic below.
        unshifted = self._inv_shift_rows(state)
        unsubstituted = self._sub_bytes(unshifted, True)
        selected_key = pyrtl.mux(next_round, *keys_last_first, default=0)
        keyed = self._add_round_key(round_key_operand, selected_key)
        unmixed = self._mix_columns(keyed, True)

        with pyrtl.conditional_assignment:
            with reset == 1:
                # Restart: latch the key and apply the first round key
                # straight to the incoming ciphertext.
                next_round |= 0
                key_reg.next |= key_in
                key_expand_src |= key_in  # to lower the number of cycles needed
                state.next |= keyed
                round_key_operand |= ciphertext_in

            with counter == 10:  # finished: hold the round count and state
                next_round |= counter
                state.next |= state

            with pyrtl.otherwise:  # stepping through the AES rounds
                next_round |= counter + 1

                key_reg.next |= key_reg
                key_expand_src |= key_reg
                round_key_operand |= unsubstituted
                with counter == 9:
                    # The last round stores the keyed value directly,
                    # skipping the inverse mix-columns step.
                    state.next |= keyed
                with pyrtl.otherwise:
                    state.next |= unmixed

        return counter == 10, state
Exemplo n.º 11
0
def basic_n_bit_mux(ctrl, mux_in, default=None):
    """
    Reduce a list of inputs to one value with a tree of two-way muxes.

    Each element of ``ctrl`` drives one layer of the tree: adjacent pairs of
    the current inputs are merged with ``pyrtl.mux``, the even-indexed entry
    as the false case and the odd-indexed entry as the true case.

    :param ctrl: iterable of select signals, one per layer, applied in order
    :param mux_in: indexable sequence of inputs to choose between
    :param default: value paired with a trailing unpaired input; when None,
        ``pyrtl.Const(0)`` is used
    :return: the first entry left after every layer has been applied
    """
    if default is None:
        default = pyrtl.Const(0)

    layer = mux_in
    for select_bit in ctrl:
        merged = []
        for even_idx in range(0, len(layer), 2):
            odd_idx = even_idx + 1
            # An odd-length layer has no partner for its last entry.
            partner = layer[odd_idx] if odd_idx < len(layer) else default
            merged.append(pyrtl.mux(select=select_bit,
                                    falsecase=layer[even_idx],
                                    truecase=partner))
        layer = merged
    return layer[0]
Exemplo n.º 12
0
    def setUp(self):
        """Build a 3-bit two-input mux circuit and a simulation with a trace."""
        pyrtl.reset_working_block()
        width = 3
        self.a = pyrtl.Input(bitwidth=width)
        self.b = pyrtl.Input(bitwidth=width)
        self.sel = pyrtl.Input(bitwidth=1)
        self.muxout = pyrtl.Output(bitwidth=width, name='muxout')
        self.muxout <<= pyrtl.mux(self.sel, self.a, self.b)

        # Simulation environment: self.sim is called as a factory here and
        # the object it returns then replaces it on this instance.
        tracer = pyrtl.SimulationTrace()
        self.sim_trace = tracer
        self.sim = self.sim(tracer=tracer)
Exemplo n.º 13
0
 def test_verilog_testbench_does_not_throw_error(self):
     """Simulate a resettable 3-bit counter, then emit a Verilog testbench."""
     reset_in = pyrtl.Input(1, 'zero')
     count_out = pyrtl.Output(3, 'counter_output')
     count_reg = pyrtl.Register(3, 'counter')
     count_reg.next <<= pyrtl.mux(reset_in, count_reg + 1, 0)
     count_out <<= count_reg

     trace = pyrtl.SimulationTrace([count_out, reset_in])
     simulation = pyrtl.Simulation(tracer=trace)
     # 15 cycles, with the reset input driven high about a quarter of the time.
     for _ in range(15):
         simulation.step({reset_in: random.choice([0, 0, 0, 1])})

     # The output is discarded; the test only checks that nothing raises.
     with io.StringIO() as testbench_out:
         pyrtl.output_verilog_testbench(testbench_out, trace)
Exemplo n.º 14
0
    def setUp(self):
        """Build a 3-bit two-input mux circuit and a simulation with a trace."""
        pyrtl.reset_working_block()
        width = 3
        self.a = pyrtl.Input(bitwidth=width)
        self.b = pyrtl.Input(bitwidth=width)
        self.sel = pyrtl.Input(bitwidth=1)
        self.muxout = pyrtl.Output(bitwidth=width, name='muxout')
        self.muxout <<= pyrtl.mux(self.sel, self.a, self.b)

        # Simulation environment: self.sim is called as a factory here and
        # the object it returns then replaces it on this instance.
        tracer = pyrtl.SimulationTrace()
        self.sim_trace = tracer
        self.sim = self.sim(tracer=tracer)
Exemplo n.º 15
0
def basic_n_bit_mux(ctrl, mux_in, default=None):
    """
    Reduce a list of inputs to one value with a tree of two-way muxes.

    Each element of ``ctrl`` drives one layer of the tree: adjacent pairs of
    the current inputs are merged with ``pyrtl.mux``, the even-indexed entry
    as the false case and the odd-indexed entry as the true case.

    :param ctrl: iterable of select signals, one per layer, applied in order
    :param mux_in: indexable sequence of inputs to choose between
    :param default: value paired with a trailing unpaired input; when None,
        ``pyrtl.Const(0)`` is used
    :return: the first entry left after every layer has been applied
    """
    if default is None:
        default = pyrtl.Const(0)

    layer = mux_in
    for select_bit in ctrl:
        merged = []
        for even_idx in range(0, len(layer), 2):
            odd_idx = even_idx + 1
            # An odd-length layer has no partner for its last entry.
            partner = layer[odd_idx] if odd_idx < len(layer) else default
            merged.append(pyrtl.mux(select=select_bit,
                                    falsecase=layer[even_idx],
                                    truecase=partner))
        layer = merged
    return layer[0]
Exemplo n.º 16
0
    def test_as_graph_memory(self):
        """Check net_connections graphs for a design that reads and writes a memory."""
        mem = pyrtl.MemBlock(addrwidth=2, bitwidth=2, name='m', max_read_ports=None)
        addr = pyrtl.Register(bitwidth=2, name='i')
        out = pyrtl.WireVector(bitwidth=2, name='o')
        addr.next <<= addr + 1
        # Write back to the addressed word, choosing between 0 and its
        # current contents based on whether it is non-zero.
        mem[addr] <<= pyrtl.mux((mem[addr] != 0), 0, mem[addr])
        out <<= mem[addr]

        block = pyrtl.working_block()

        # Validate the graphs for both settings of the net_connections flag.
        sources, dests = block.net_connections(False)
        self.check_graph_correctness(sources, dests)

        sources, dests = block.net_connections(True)
        self.check_graph_correctness(sources, dests, True)
Exemplo n.º 17
0
 def test_mux_simulation(self):
     """Feed register r back through an 8-way mux and check the traced sequence."""
     # Entry k is the value r takes on the cycle after it held k.
     next_value_table = (4, 3, 1, 7, 2, 6, 0, 5)
     self.r.next <<= pyrtl.mux(self.r, *next_value_table)
     self.check_trace('r 04213756\n')
Exemplo n.º 18
0
    })
# Render the trace collected so far.
sim_trace.render_trace(symbol_len=5, segment_size=5)

# ---- Exporting to Verilog ----

# Importing from Verilog is only half of the story: we also want a way to
# export a design back out to Verilog. To demonstrate PyRTL's ability to export
# Verilog, we will create a sample 3-bit counter. However, unlike the example
# in example2, we extend it to be synchronously resetting.

# Start from an empty working block so only the counter below is in the design.
pyrtl.reset_working_block()

zero = pyrtl.Input(1, 'zero')
counter_output = pyrtl.Output(3, 'counter_output')
counter = pyrtl.Register(3, 'counter')
# Next state: counter + 1 while "zero" is low, 0 once "zero" goes high.
counter.next <<= pyrtl.mux(zero, counter + 1, 0)
counter_output <<= counter

# The counter gets 0 in the next cycle if the "zero" signal goes high, otherwise just
# counter + 1.  Note that both "0" and "1" are bit extended to the proper length and
# here we are making use of that native add operation.  Let's dump this bad boy out
# to a Verilog file and see what it looks like (here we are using StringIO just to
# print it to a string for demo purposes; most likely you will want to pass a normal
# open file).

print("--- PyRTL Representation ---")
print(pyrtl.working_block())
print()

print("--- Verilog for the Counter ---")
with io.StringIO() as vfile:
Exemplo n.º 19
0
 def test_mux_too_many_inputs_with_extra_kwarg(self):
     """Passing an extra keyword (foo) alongside default must raise PyrtlError."""
     in0 = pyrtl.WireVector(bitwidth=3, name='a')
     in1 = pyrtl.WireVector(bitwidth=1, name='b')
     select = pyrtl.WireVector(bitwidth=2, name='s')
     with self.assertRaises(pyrtl.PyrtlError):
         pyrtl.mux(select, in0, in1, default=0, foo=1)
Exemplo n.º 20
0
 def test_mux_not_enough_inputs_but_default(self):
     """Two inputs on a 2-bit select should build cleanly when a default is given."""
     in0 = pyrtl.WireVector(bitwidth=3, name='a')
     in1 = pyrtl.WireVector(bitwidth=1, name='b')
     select = pyrtl.WireVector(bitwidth=2, name='s')
     # No assertion: the test passes if constructing the mux does not raise.
     pyrtl.mux(select, in0, in1, default=0)
Exemplo n.º 21
0
 def test_mux_not_enough_inputs_but_default(self):
     """Two inputs on a 2-bit select should build cleanly when a default is given."""
     in0 = pyrtl.WireVector(bitwidth=3, name='a')
     in1 = pyrtl.WireVector(bitwidth=1, name='b')
     select = pyrtl.WireVector(bitwidth=2, name='s')
     # No assertion: the test passes if constructing the mux does not raise.
     pyrtl.mux(select, in0, in1, default=0)
Exemplo n.º 22
0
 def test_mux_too_many_inputs_with_extra_kwarg(self):
     """Passing an extra keyword (foo) alongside default must raise PyrtlError."""
     in0 = pyrtl.WireVector(bitwidth=3, name='a')
     in1 = pyrtl.WireVector(bitwidth=1, name='b')
     select = pyrtl.WireVector(bitwidth=2, name='s')
     with self.assertRaises(pyrtl.PyrtlError):
         pyrtl.mux(select, in0, in1, default=0, foo=1)
Exemplo n.º 23
0
def SNmux(d, n, m, a1, a2, a3, a4):
    """
    Combine two mux selections over the four outputs of ``sn``.

    ``sn`` is applied to the four data arguments; ``n`` and ``m`` each pick
    one of its first four outputs.  The ``n`` selection is passed through
    ``rdelta`` together with ``d``, and the result is combined with the ``m``
    selection by ``rinhibit``.  (``sn``, ``rdelta`` and ``rinhibit`` are
    defined elsewhere; their semantics are not visible here.)

    :param d: first argument handed to ``rdelta``
    :param n: select signal for the value fed to ``rdelta``
    :param m: select signal for the value used as ``rinhibit``'s second argument
    :param a1: first input to ``sn``
    :param a2: second input to ``sn``
    :param a3: third input to ``sn``
    :param a4: fourth input to ``sn``
    :return: the result of ``rinhibit``
    """
    # Use the function's own arguments: the previous body referenced
    # in1..in4, which are not parameters here, and ignored a1..a4.
    outs = sn(a1, a2, a3, a4)
    delta_source = pyrtl.mux(n, outs[0], outs[1], outs[2], outs[3])
    delayed = rdelta(d, delta_source)
    inhibitor = pyrtl.mux(m, outs[0], outs[1], outs[2], outs[3])
    return rinhibit(delayed, inhibitor)
Exemplo n.º 24
0
# Render the trace collected so far.
sim_trace.render_trace(symbol_len=5, segment_size=5)


# ---- Exporting to Verilog ----

# Importing from Verilog is only half of the story: we also want a way to
# export a design back out to Verilog. To demonstrate PyRTL's ability to export
# Verilog, we will create a sample 3-bit counter. However, unlike the example
# in example2, we extend it to be synchronously resetting.

# Start from an empty working block so only the counter below is in the design.
pyrtl.reset_working_block()

zero = pyrtl.Input(1, 'zero')
counter_output = pyrtl.Output(3, 'counter_output')
counter = pyrtl.Register(3, 'counter')
# Next state: counter + 1 while "zero" is low, 0 once "zero" goes high.
counter.next <<= pyrtl.mux(zero, counter + 1, 0)
counter_output <<= counter

# The counter gets 0 in the next cycle if the "zero" signal goes high, otherwise just
# counter + 1.  Note that both "0" and "1" are bit extended to the proper length and
# here we are making use of that native add operation.  Let's dump this bad boy out
# to a Verilog file and see what it looks like (here we are using StringIO just to
# print it to a string for demo purposes; most likely you will want to pass a normal
# open file).

print("--- PyRTL Representation ---")
print(pyrtl.working_block())
print()

print("--- Verilog for the Counter ---")
with io.StringIO() as vfile:
Exemplo n.º 25
0
 def test_mux_simulation(self):
     """Feed register r back through an 8-way mux and check the traced sequence."""
     # Entry k is the value r takes on the cycle after it held k.
     next_value_table = (4, 3, 1, 7, 2, 6, 0, 5)
     self.r.next <<= pyrtl.mux(self.r, *next_value_table)
     self.check_trace('r 04213756\n')
Exemplo n.º 26
0
    else:
        lsbit, ripplecarry = one_bit_add(a[0], b[0], carry_in)
        msbits, carry_out = ripple_add(a[1:], b[1:], ripplecarry)
        sumbits = pyrtl.concat(msbits, lsbit)
    return sumbits, carry_out


# Inputs and outputs of the module: a 1-bit load strobe, an 8-bit value to
# load, and the 8-bit counter value driven out.
load = pyrtl.Input(1, 'load')
data = pyrtl.Input(8, 'data')
out = pyrtl.Output(8, 'out')

# Simple logic to allow loading of the counter: the register takes
# counter + 1 (via ripple_add) while load is low and takes data when load is
# high.  The carry out of the adder is not used.
counter = pyrtl.Register(8, 'counter')
# NOTE(review): "sum" shadows the builtin of the same name for the rest of
# this module.
sum, carry_out = ripple_add(counter, pyrtl.Const("1'b1"))
counter.next <<= pyrtl.mux(load, sum, data)
out <<= counter

# Set up the simulation, tracing the module's inputs and output.
sim_trace = pyrtl.SimulationTrace([load, data, out])
sim = pyrtl.Simulation(tracer=sim_trace)
# Run until a 'q' is received from the named pipe.  Each command is read as
# up to 3 characters: the first is parsed in base 2 as the load bit and the
# next two in base 16 as the data byte.  After each simulated step the
# counter value is written back to the output pipe as a hex string.
# NOTE(review): os.read returns bytes on Python 3, so the comparisons with
# '' and 'q' and the str passed to os.write look written for Python 2 --
# confirm the target interpreter.
# NOTE(review): an empty read leaves the loop spinning without blocking or
# exiting -- presumably the pipe is expected to stay open; verify.
while True:
    cmd = os.read(inFifo, 3)
    if cmd != '':
        if cmd == 'q':
            break
        l = int(cmd[0], 2)
        d = int(cmd[1:3], 16)
        sim.step({'load': l, 'data': d})
        os.write(outFifo, hex(sim.inspect(counter)))