Пример #1
0
def barrel_shifter(shift_in, bit_in, direction, shift_dist, wrap_around=0):
    """
    Create a barrel shifter that operates on data based on the wire width

    One stage is built per bit of ``shift_dist``: stage ``i`` conditionally
    shifts by ``2**i`` positions, selected by ``shift_dist[i]``, so every
    distance representable in ``shift_dist`` is honored. Vacated positions
    are filled with copies of ``bit_in``. A stage whose distance reaches the
    full input width replaces the data entirely with fill bits.

    :param shift_in: the input wire
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit WireVector representing shift direction
        (0 = shift down, 1 = shift up)
    :param shift_dist: WireVector representing offset to shift
    :param wrap_around: ****currently not implemented****; must be 0
    :raises NotImplementedError: if ``wrap_around`` is nonzero
    :return: shifted WireVector, the same width as ``shift_in``
    """
    if wrap_around != 0:
        # Fail loudly rather than silently building a non-wrapping shifter.
        raise NotImplementedError('barrel_shifter: wrap_around is not implemented')

    final_width = len(shift_in)
    val = shift_in
    # Fill bits for the current stage. Their width equals the stage's shift
    # amount, capped at the full width.
    append_val = bit_in

    for i in range(len(shift_dist)):
        shift_amt = 2 ** i  # stages shift 1, 2, 4, 8, ...
        if shift_amt < final_width:
            # pyrtl.concat puts its first argument in the most significant bits.
            newval = pyrtl.select(
                direction,
                truecase=pyrtl.concat(val[:-shift_amt], append_val),  # shift up
                falsecase=pyrtl.concat(append_val, val[shift_amt:]))  # shift down
            # The i-th bit of the distance decides whether this stage shifts.
            val = pyrtl.select(shift_dist[i], truecase=newval, falsecase=val)
            # The next stage shifts twice as far, so it needs twice the fill.
            append_val = pyrtl.concat(append_val, append_val)[:final_width]
        else:
            # Shifting by >= the width discards all data. At this point
            # append_val is final_width copies of bit_in.
            val = pyrtl.select(shift_dist[i], truecase=append_val, falsecase=val)

    return val
Пример #2
0
def barrel_shifter(shift_in, bit_in, direction, shift_dist, wrap_around=0):
    """
    Create a barrel shifter that operates on data based on the wire width

    Stage ``i`` (one per bit of ``shift_dist``) conditionally shifts by
    ``2**i`` positions and is controlled by ``shift_dist[i]``. Vacated
    positions are filled with ``bit_in``. Any stage whose distance is at
    least the input width replaces the data with fill bits.

    :param shift_in: the input wire
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit wirevector representing shift direction:
        0 = shift down, 1 = shift up
    :param shift_dist: wirevector representing offset to shift
    :param wrap_around: ****currently not implemented****; must be 0
    :raises NotImplementedError: if ``wrap_around`` is nonzero
    :return: shifted wirevector, the same width as ``shift_in``
    """
    if wrap_around != 0:
        # Fail loudly rather than silently building a non-wrapping shifter.
        raise NotImplementedError('barrel_shifter: wrap_around is not implemented')

    final_width = len(shift_in)
    val = shift_in
    # Fill bits for the current stage (width == that stage's shift amount,
    # capped at the full width).
    append_val = bit_in

    for i in range(len(shift_dist)):
        shift_amt = 2 ** i  # stages shift 1, 2, 4, 8, ...
        if shift_amt < final_width:
            # pyrtl.concat puts its first argument in the most significant bits.
            newval = pyrtl.select(
                direction,
                truecase=pyrtl.concat(val[:-shift_amt], append_val),  # shift up
                falsecase=pyrtl.concat(append_val, val[shift_amt:]))  # shift down
            # Stage i is selected by bit i of the distance (not bit i - 1).
            val = pyrtl.select(shift_dist[i], truecase=newval, falsecase=val)
            append_val = pyrtl.concat(append_val, append_val)[:final_width]
        else:
            # Shifting by >= the width discards all data.
            val = pyrtl.select(shift_dist[i], truecase=append_val, falsecase=val)

    return val
Пример #3
0
    def __init__(self, depth_width=2, data_width=32):
        """
        Build a FIFO holding ``2**depth_width`` words of ``data_width`` bits.

        The read and write pointers are ``depth_width + 1`` bits wide. The low
        ``depth_width`` bits address the storage RAM. The extra top bit tells
        full from empty when the RAM addresses are equal: matching top bits
        mean empty, differing top bits mean full.

        :param depth_width: number of RAM address bits
        :param data_width: bit width of each stored word

        NOTE(review): the pointers advance on rd_en_i / wr_en_i without
        checking empty_o / full_o. Callers must not read while empty or
        write while full.
        """
        aw = depth_width
        dw = data_width

        self.wr_data_i = pyrtl.Input(dw, 'wr_data_i')
        self.wr_en_i = pyrtl.Input(1, 'wr_en_i')
        self.rd_data_o = pyrtl.Output(dw, 'rd_data_o')
        self.rd_en_i = pyrtl.Input(1, 'rd_en_i')
        self.full_o = pyrtl.Output(1, 'full_o')
        self.empty_o = pyrtl.Output(1, 'empty_o')
        self.one_left = pyrtl.Output(1, 'one_left')

        self.reset = pyrtl.Input(1, 'reset')

        self.write_pointer = pyrtl.Register(aw + 1, 'write_pointer')
        self.read_pointer = pyrtl.Register(aw + 1, 'read_pointer')

        self.read_plus_1 = pyrtl.Const(1, 1) + self.read_pointer
        # Reset clears the read pointer as well as the write pointer.
        # Otherwise the FIFO would not come out of reset empty.
        self.read_pointer.next <<= pyrtl.select(
            self.reset,
            truecase=pyrtl.Const(0, aw + 1),
            falsecase=pyrtl.select(self.rd_en_i,
                                   truecase=self.read_plus_1,
                                   falsecase=self.read_pointer))

        self.write_pointer.next <<= pyrtl.select(
            self.reset,
            truecase=pyrtl.Const(0, aw + 1),
            falsecase=pyrtl.select(self.wr_en_i,
                                   truecase=self.write_pointer + 1,
                                   falsecase=self.write_pointer))

        self.empty_int = pyrtl.WireVector(1, 'empty_int')
        self.full_or_empty = pyrtl.WireVector(1, 'full_or_empty')

        # Despite its name, empty_int only says the top (wrap) bits match.
        # It is combined with full_or_empty below.
        self.empty_int <<= self.write_pointer[aw] == self.read_pointer[aw]
        # Equal RAM addresses mean the FIFO is either completely full or empty.
        self.full_or_empty <<= self.write_pointer[0:aw] == self.read_pointer[
            0:aw]
        self.full_o <<= self.full_or_empty & ~self.empty_int
        self.empty_o <<= self.full_or_empty & self.empty_int
        # The write pointer is exactly one ahead of the read pointer. The sum
        # is one bit wider than the pointers, so slice it back to aw + 1 bits;
        # the match then also holds when the read pointer wraps.
        self.one_left <<= self.read_plus_1[:aw + 1] == self.write_pointer

        self.mem = m = _RAM(num_entries=1 << aw,
                            data_nbits=dw,
                            name='FIFOStorage')
        m.wen <<= self.wr_en_i
        m.ren <<= self.rd_en_i
        m.raddr <<= self.read_pointer[0:aw]
        m.waddr <<= self.write_pointer[0:aw]
        m.wdata <<= self.wr_data_i
        # Output zero when no read is requested.
        self.rd_data_o <<= pyrtl.select(self.rd_en_i,
                                        truecase=m.rdata,
                                        falsecase=pyrtl.Const(0, dw))
Пример #4
0
 def test_time_est_unchanged(self):
     """Regression test: TimingAnalysis results for a fixed design are unchanged.

     Builds one instance of each of several primitive operations (logic,
     comparison, select, arithmetic, register, memory read and write) on
     constant 8-bit inputs. It then checks the estimated maximum frequency
     and maximum path length against previously recorded values.
     """
     a = pyrtl.Const(2, 8)
     b = pyrtl.Const(85, 8)
     zero = pyrtl.Const(0, 1)
     reg = pyrtl.Register(8)
     mem = pyrtl.MemBlock(8, 8)
     out = pyrtl.Output(8)
     nota, aLSB, athenb, aORb, aANDb, aNANDb, \
     aXORb, aequalsb, altb, agtb, aselectb, \
     aplusb, bminusa, atimesb, memread = [pyrtl.Output() for _ in range(15)]
     out <<= zero
     nota <<= ~a
     aLSB <<= a[0]
     athenb <<= pyrtl.concat(a, b)
     aORb <<= a | b
     aANDb <<= a & b
     aNANDb <<= a.nand(b)
     aXORb <<= a ^ b
     aequalsb <<= a == b
     altb <<= a < b
     agtb <<= a > b
     aselectb <<= pyrtl.select(zero, a, b)
     reg.next <<= a
     aplusb <<= a + b
     # NOTE: computes a - b despite the name. The recorded values below were
     # taken with this exact design, so it is left unchanged.
     bminusa <<= a - b
     atimesb <<= a * b
     memread <<= mem[0]
     mem[1] <<= a
     timing = estimate.TimingAnalysis()
     # Compare floats approximately. assertEquals is a deprecated alias that
     # was removed in Python 3.12.
     self.assertAlmostEqual(timing.max_freq(), 610.2770657878676)
     self.assertAlmostEqual(timing.max_length(), 1255.6000000000001)
Пример #5
0
 def test_area_est_unchanged(self):
     """Regression test: area_estimation results for a fixed design are unchanged.

     Builds one instance of each of several primitive operations (logic,
     comparison, select, arithmetic, register, memory read and write) on
     constant 8-bit inputs. It then checks both values returned by
     estimate.area_estimation() against previously recorded numbers.
     """
     a = pyrtl.Const(2, 8)
     b = pyrtl.Const(85, 8)
     zero = pyrtl.Const(0, 1)
     reg = pyrtl.Register(8)
     mem = pyrtl.MemBlock(8, 8)
     out = pyrtl.Output(8)
     nota, aLSB, athenb, aORb, aANDb, aNANDb, \
     aXORb, aequalsb, altb, agtb, aselectb, \
     aplusb, bminusa, atimesb, memread = [pyrtl.Output() for _ in range(15)]
     out <<= zero
     nota <<= ~a
     aLSB <<= a[0]
     athenb <<= pyrtl.concat(a, b)
     aORb <<= a | b
     aANDb <<= a & b
     aNANDb <<= a.nand(b)
     aXORb <<= a ^ b
     aequalsb <<= a == b
     altb <<= a < b
     agtb <<= a > b
     aselectb <<= pyrtl.select(zero, a, b)
     reg.next <<= a
     aplusb <<= a + b
     # NOTE: computes a - b despite the name. The recorded values below were
     # taken with this exact design, so it is left unchanged.
     bminusa <<= a - b
     atimesb <<= a * b
     memread <<= mem[0]
     mem[1] <<= a
     # Unpacking checks that exactly two values come back. Each float is then
     # compared approximately; assertEquals is a deprecated alias that was
     # removed in Python 3.12.
     first_area, second_area = estimate.area_estimation()
     self.assertAlmostEqual(first_area, 0.00734386752, places=12)
     self.assertAlmostEqual(second_area, 0.01879779717361501, places=12)
Пример #6
0
def prioritized_mux(selects, vals):
    """
    Select the value paired with the first (lowest-index) high select bit.

    :param [WireVector] selects: select WireVectors in priority order;
        selects[k] being high requests vals[k]
    :param [WireVector] vals: candidate values, paired index-by-index with
        ``selects``
    :return: WireVector

    When no select is high, the final entry of ``vals`` is produced. The
    lists are halved on each recursive call, so the recursion depth grows
    logarithmically with the number of candidates.
    """
    count = len(vals)
    if len(selects) != count:
        raise pyrtl.PyrtlError("Number of select and val signals must match")
    if count == 0:
        raise pyrtl.PyrtlError("Must have a signal to mux")
    if count == 1:
        # A lone candidate is the answer regardless of its select bit.
        return vals[0]

    split = count // 2
    # Any high select in the front half beats everything in the back half.
    front_hit = pyrtl.rtl_any(*selects[:split])
    front_choice = prioritized_mux(selects[:split], vals[:split])
    back_choice = prioritized_mux(selects[split:], vals[split:])
    return pyrtl.select(front_hit, truecase=front_choice, falsecase=back_choice)
Пример #7
0
def barrel_shifter(bits_to_shift,
                   bit_in,
                   direction,
                   shift_dist,
                   wrap_around=0):
    """Build a logarithmic barrel shifter sized by the width of ``bits_to_shift``.

    Stage ``k`` (one per bit of ``shift_dist``) optionally moves the data by
    ``2**k`` positions, controlled by ``shift_dist[k]``. Vacated positions are
    filled with ``bit_in``. Stages whose distance reaches the full width
    replace the data entirely with fill bits.

    :param bits_to_shift: the input wire
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit WireVector representing shift direction
        (0 = shift down, 1 = shift up)
    :param shift_dist: WireVector representing offset to shift
    :param wrap_around: must be 0; wrapping is not implemented
    :raises NotImplementedError: if ``wrap_around`` is nonzero
    :return: shifted WireVector
    """
    from pyrtl import concat, select  # local aliases for readability

    if wrap_around != 0:
        raise NotImplementedError

    width = len(bits_to_shift)
    result = bits_to_shift
    # Fill bits for the current stage; doubled (up to the full width) after
    # every stage that performs a partial shift.
    fill = bit_in

    for stage in range(len(shift_dist)):
        amount = 2 ** stage
        if amount >= width:
            # Every original bit would be pushed out: pick all-fill or keep.
            result = select(shift_dist[stage], truecase=fill, falsecase=result)
            continue
        shifted_up = concat(result[:-amount], fill)
        shifted_down = concat(fill, result[amount:])
        moved = select(direction, shifted_up, shifted_down)
        result = select(shift_dist[stage], truecase=moved, falsecase=result)
        fill = concat(fill, fill)[:width]

    return result
Пример #8
0
def relu(vec):
    """
    Apply a ReLU to every two's-complement WireVector in ``vec``.

    A word whose most significant bit (the two's-complement sign bit) is 1 is
    negative and becomes a zero constant of the same width. Any other word is
    passed through unchanged.

    :param vec: iterable of WireVectors
    :return: list of WireVectors, in the same order as ``vec``
    """
    rectified = []
    for word in vec:
        sign_bit = word[-1]
        rectified.append(pyrtl.select(sign_bit,
                                      falsecase=word,
                                      truecase=pyrtl.Const(0, len(word))))
    return rectified
Пример #9
0
def _sparse_mux(sel, vals):
    """
    Mux that avoids instantiating unnecessary mux_2s when possible.

    :param WireVector sel: Select wire, determines what is selected on a given cycle
    :param {int: WireVector} vals: dictionary to store the values that are
    :return: Wirevector that signifies the change

    This mux supports not having a full specification. indices that are not
    specified are treated as Don't Cares
    """
    items = list(vals.values())
    if len(vals) <= 1:
        if len(vals) == 0:
            raise pyrtl.PyrtlError("Needs at least one parameter for val")
        return items[0]

    if len(sel) == 1:
        try:
            false_result = vals[0]
            true_result = vals[1]
        except KeyError:
            raise pyrtl.PyrtlError("Failed to retrieve values for smartmux. "
                                   "The length of sel might be wrong")
    else:
        half = 2**(len(sel) - 1)

        first_dict = {indx: wire for indx, wire in vals.items() if indx < half}
        second_dict = {
            indx - half: wire
            for indx, wire in vals.items() if indx >= half
        }
        if not len(first_dict):
            return sparse_mux(sel[:-1], second_dict)
        if not len(second_dict):
            return sparse_mux(sel[:-1], first_dict)

        false_result = sparse_mux(sel[:-1], first_dict)
        true_result = sparse_mux(sel[:-1], second_dict)
    if _is_equivelent(false_result, true_result):
        return true_result
    return pyrtl.select(sel[-1], falsecase=false_result, truecase=true_result)
Пример #10
0
def _sparse_mux(sel, vals):
    """
    Mux that avoids instantiating unnecessary mux_2s when possible.

    :param WireVector sel: Select wire, determines what is selected on a given cycle
    :param {int: WireVector} vals: dictionary to store the values that are
    :return: Wirevector that signifies the change

    This mux supports not having a full specification. indices that are not
    specified are treated as Don't Cares
    """
    items = list(vals.values())
    if len(vals) <= 1:
        if len(vals) == 0:
            raise pyrtl.PyrtlError("Needs at least one parameter for val")
        return items[0]

    if len(sel) == 1:
        try:
            false_result = vals[0]
            true_result = vals[1]
        except KeyError:
            raise pyrtl.PyrtlError("Failed to retrieve values for smartmux. "
                                   "The length of sel might be wrong")
    else:
        half = 2**(len(sel) - 1)

        first_dict = {indx: wire for indx, wire in vals.items() if indx < half}
        second_dict = {indx-half: wire for indx, wire in vals.items() if indx >= half}
        if not len(first_dict):
            return sparse_mux(sel[:-1], second_dict)
        if not len(second_dict):
            return sparse_mux(sel[:-1], first_dict)

        false_result = sparse_mux(sel[:-1], first_dict)
        true_result = sparse_mux(sel[:-1], second_dict)
    if _is_equivelent(false_result, true_result):
        return true_result
    return pyrtl.select(sel[-1], falsecase=false_result, truecase=true_result)
Пример #11
0
    def __init__(self, num_entries, data_nbits, reset_value=0, name=u''):
        """
        Wrap a pyrtl MemBlock with enable-gated read and write ports.

        :param num_entries: number of words; the address width is
            clog2(num_entries)
        :param data_nbits: bit width of each word
        :param reset_value: stored on the instance; not otherwise used here
        :param name: suffix for the MemBlock name ('_RAM_' + name)

        NOTE(review): the port WireVectors use fixed names ('wen', 'raddr',
        ...), so two instances in the same block will likely collide on
        wire names. Confirm that only one _RAM is built per block.
        """
        self.addr_nbits = clog2(num_entries)
        self.reset_value = reset_value
        self.num_entries = num_entries
        self.data_nbits = data_nbits

        self.mem = pyrtl.MemBlock(bitwidth=data_nbits,
                                  addrwidth=self.addr_nbits,
                                  name='_RAM_' + name,
                                  asynchronous=False)

        self.wen = pyrtl.WireVector(1, 'wen')
        self.ren = pyrtl.WireVector(1, 'ren')
        self.waddr = pyrtl.WireVector(self.addr_nbits, 'waddr')
        self.raddr = pyrtl.WireVector(self.addr_nbits, 'raddr')
        self.wdata = pyrtl.WireVector(self.data_nbits, 'wdata')
        self.rdata = pyrtl.WireVector(self.data_nbits, 'rdata')

        # Read data is zero unless a read is enabled. The zero constant
        # matches the data width instead of a hard-coded 32 bits.
        self.rdata <<= pyrtl.select(self.ren,
                                    truecase=self.mem[self.raddr],
                                    falsecase=pyrtl.Const(0, self.data_nbits))
        # Write wdata to waddr only while wen is high.
        self.mem[self.waddr] <<= pyrtl.MemBlock.EnabledWrite(
            self.wdata, self.wen)
Пример #12
0
def prioritized_mux(selects, vals):
    """
    Select the value paired with the first (lowest-index) high select bit.

    :param [WireVector] selects: select WireVectors in priority order;
        selects[k] being high requests vals[k]
    :param [WireVector] vals: candidate values, paired index-by-index with
        ``selects``
    :return: WireVector

    When no select is high, the final entry of ``vals`` is produced. The
    lists are halved on each recursive call, so the recursion depth grows
    logarithmically with the number of candidates.
    """
    count = len(vals)
    if len(selects) != count:
        raise pyrtl.PyrtlError("Number of select and val signals must match")
    if count == 0:
        raise pyrtl.PyrtlError("Must have a signal to mux")
    if count == 1:
        # A lone candidate is the answer regardless of its select bit.
        return vals[0]

    split = count // 2
    # Any high select in the front half beats everything in the back half.
    front_hit = pyrtl.rtl_any(*selects[:split])
    front_choice = prioritized_mux(selects[:split], vals[:split])
    back_choice = prioritized_mux(selects[split:], vals[split:])
    return pyrtl.select(front_hit, truecase=front_choice, falsecase=back_choice)