def barrel_shifter(shift_in, bit_in, direction, shift_dist, wrap_around=0):
    """ Create a barrel shifter that operates on data based on the wire width

    :param shift_in: the input wire
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit WireVector representing shift direction
        (0 = shift down, 1 = shift up)
    :param shift_dist: WireVector representing offset to shift
    :param wrap_around: ****currently not implemented**** (the value is ignored)
    :return: shifted WireVector, same width as shift_in

    Only the low ceil(log2(len(shift_in))) bits of shift_dist are used; a
    warning is printed when shift_dist carries more bits than that.
    """
    # Implement with logN stages pyrtl.muxing between shifted and un-shifted values
    val = shift_in
    append_val = bit_in
    # Number of bits needed to express every shift from 0 to width - 1.
    # Integer arithmetic avoids float rounding and handles a 1-bit input
    # (zero stages) without raising.
    log_length = (len(shift_in) - 1).bit_length()
    if len(shift_dist) > log_length:
        print('Warning: for barrel shifter, the shift distance wirevector '
              'has bits that are not used in the barrel shifter')
    for i in range(min(len(shift_dist), log_length)):
        shift_amt = 1 << i  # stages shift 1,2,4,8,...
        # Drop the bits that fall off the end for the chosen direction ...
        newval = pyrtl.select(direction,
                              truecase=val[:-shift_amt],
                              falsecase=val[shift_amt:])
        # ... and fill the vacated positions, restoring the original width.
        newval = pyrtl.select(direction,
                              truecase=pyrtl.concat(newval, append_val),
                              falsecase=pyrtl.concat(append_val, newval))
        # pyrtl.mux shifted vs. unshifted by using i-th bit of shift amount signal
        val = pyrtl.select(shift_dist[i], truecase=newval, falsecase=val)
        # The next stage shifts twice as far, so the fill value must double
        # in width to keep both select inputs the same size.
        append_val = pyrtl.concat(append_val, append_val)
    return val
def barrel_shifter(shift_in, bit_in, direction, shift_dist, wrap_around=0):
    """ Create a barrel shifter that operates on data based on the wire width

    :param shift_in: the input wire
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit wirevector representing shift direction
        (0 = shift down, 1 = shift up)
    :param shift_dist: wirevector representing offset to shift
    :param wrap_around: ****currently not implemented**** (the value is ignored)
    :return: shifted wirevector, same width as shift_in

    Stage i shifts by 2**i and is controlled by bit i of shift_dist. Bits of
    shift_dist beyond ceil(log2(len(shift_in))) are ignored, and a warning is
    printed when such bits exist.
    """
    # Implement with logN stages pyrtl.muxing between shifted and un-shifted values
    val = shift_in
    append_val = bit_in
    # Stages required to reach any shift in 0 .. width - 1; computed with
    # integers so a 1-bit input yields zero stages instead of a math error.
    log_length = (len(shift_in) - 1).bit_length()
    if len(shift_dist) > log_length:
        print('Warning: for barrel shifter, the shift distance wirevector '
              'has bits that are not used in the barrel shifter')
    for i in range(min(len(shift_dist), log_length)):
        shift_amt = 1 << i  # stages shift 1,2,4,8,...
        # Keep only the bits that survive a shift in the chosen direction.
        newval = pyrtl.select(direction,
                              truecase=val[:-shift_amt],
                              falsecase=val[shift_amt:])
        # Build shifted value by filling the vacated side with the shift-in bits.
        newval = pyrtl.select(direction,
                              truecase=pyrtl.concat(newval, append_val),
                              falsecase=pyrtl.concat(append_val, newval))
        # The stage that shifts by 2**i must be controlled by bit i of the
        # distance (not bit i - 1, which would wrap to the MSB for stage 0).
        val = pyrtl.select(shift_dist[i], truecase=newval, falsecase=val)
        # Double the fill width to match the next stage's shift amount.
        append_val = pyrtl.concat(append_val, append_val)
    return val
def __init__(self, depth_width=2, data_width=32):
    """ Build the FIFO's ports, pointer registers, status flags and storage.

    :param depth_width: number of address bits; the storage holds
        2**depth_width entries
    :param data_width: bitwidth of each stored word
    """
    aw = depth_width
    dw = data_width

    # External ports.
    self.wr_data_i = pyrtl.Input(dw, 'wr_data_i')
    self.wr_en_i = pyrtl.Input(1, 'wr_en_i')
    self.rd_data_o = pyrtl.Output(dw, 'rd_data_o')
    self.rd_en_i = pyrtl.Input(1, 'rd_en_i')
    self.full_o = pyrtl.Output(1, 'full_o')
    self.empty_o = pyrtl.Output(1, 'empty_o')
    self.one_left = pyrtl.Output(1, 'one_left')
    self.reset = pyrtl.Input(1, 'reset')

    # Pointers carry one bit more than the address so that "full" and
    # "empty" (both have equal address bits) can be told apart.
    self.write_pointer = pyrtl.Register(aw + 1, 'write_pointer')
    self.read_pointer = pyrtl.Register(aw + 1, 'read_pointer')

    # The sum is one bit wider than the pointer; it is truncated back to
    # the pointer width when assigned to the register.
    self.read_plus_1 = pyrtl.Const(1, 1) + self.read_pointer
    # Reset must clear the read pointer as well as the write pointer,
    # otherwise the two disagree after reset and the FIFO looks non-empty.
    self.read_pointer.next <<= pyrtl.select(
        self.reset,
        truecase=pyrtl.Const(0, aw + 1),
        falsecase=pyrtl.select(self.rd_en_i,
                               truecase=self.read_plus_1,
                               falsecase=self.read_pointer))
    self.write_pointer.next <<= pyrtl.select(
        self.reset,
        truecase=pyrtl.Const(0, aw + 1),
        falsecase=pyrtl.select(self.wr_en_i,
                               truecase=self.write_pointer + 1,
                               falsecase=self.write_pointer))

    self.empty_int = pyrtl.WireVector(1, 'empty_int')
    self.full_or_empty = pyrtl.WireVector(1, 'full_or_empty')
    # Equal top (wrap) bits: both pointers are on the same lap.
    self.empty_int <<= self.write_pointer[aw] == self.read_pointer[aw]
    # Equal address bits: the FIFO is either completely full or empty.
    self.full_or_empty <<= self.write_pointer[0:aw] == self.read_pointer[0:aw]
    self.full_o <<= self.full_or_empty & ~self.empty_int
    self.empty_o <<= self.full_or_empty & self.empty_int
    # Compare at pointer width so the check still holds when the read
    # pointer wraps from its maximum value back to zero.
    self.one_left <<= self.read_plus_1[0:aw + 1] == self.write_pointer

    # Backing storage, addressed by the pointers without their wrap bit.
    self.mem = m = _RAM(num_entries=1 << aw, data_nbits=dw,
                        name='FIFOStorage')
    m.wen <<= self.wr_en_i
    m.ren <<= self.rd_en_i
    m.raddr <<= self.read_pointer[0:aw]
    m.waddr <<= self.write_pointer[0:aw]
    m.wdata <<= self.wr_data_i
    # Drive zero on the read port whenever no read is requested.
    self.rd_data_o <<= pyrtl.select(self.rd_en_i,
                                    truecase=m.rdata,
                                    falsecase=pyrtl.Const(0, dw))
def test_time_est_unchanged(self):
    """ Pin the timing estimate of a small fixed circuit.

    The circuit uses one of each basic operation so that a change to the
    timing model shows up as a different max frequency or max length.
    """
    a = pyrtl.Const(2, 8)
    b = pyrtl.Const(85, 8)
    zero = pyrtl.Const(0, 1)
    reg = pyrtl.Register(8)
    mem = pyrtl.MemBlock(8, 8)
    out = pyrtl.Output(8)
    nota, aLSB, athenb, aORb, aANDb, aNANDb, \
        aXORb, aequalsb, altb, agtb, aselectb, \
        aplusb, bminusa, atimesb, memread = [pyrtl.Output() for _ in range(15)]
    out <<= zero
    nota <<= ~a
    aLSB <<= a[0]
    athenb <<= pyrtl.concat(a, b)
    aORb <<= a | b
    aANDb <<= a & b
    aNANDb <<= a.nand(b)
    aXORb <<= a ^ b
    aequalsb <<= a == b
    altb <<= a < b
    agtb <<= a > b
    aselectb <<= pyrtl.select(zero, a, b)
    reg.next <<= a
    aplusb <<= a + b
    bminusa <<= a - b
    atimesb <<= a * b
    memread <<= mem[0]
    mem[1] <<= a

    timing = estimate.TimingAnalysis()
    # assertEqual rather than the deprecated assertEquals alias, which no
    # longer exists in Python 3.12.
    self.assertEqual(timing.max_freq(), 610.2770657878676)
    self.assertEqual(timing.max_length(), 1255.6000000000001)
def test_area_est_unchanged(self):
    """ Pin the area estimate of a small fixed circuit.

    The circuit uses one of each basic operation so that a change to the
    area model shows up as a different result tuple.
    """
    a = pyrtl.Const(2, 8)
    b = pyrtl.Const(85, 8)
    zero = pyrtl.Const(0, 1)
    reg = pyrtl.Register(8)
    mem = pyrtl.MemBlock(8, 8)
    out = pyrtl.Output(8)
    nota, aLSB, athenb, aORb, aANDb, aNANDb, \
        aXORb, aequalsb, altb, agtb, aselectb, \
        aplusb, bminusa, atimesb, memread = [pyrtl.Output() for _ in range(15)]
    out <<= zero
    nota <<= ~a
    aLSB <<= a[0]
    athenb <<= pyrtl.concat(a, b)
    aORb <<= a | b
    aANDb <<= a & b
    aNANDb <<= a.nand(b)
    aXORb <<= a ^ b
    aequalsb <<= a == b
    altb <<= a < b
    agtb <<= a > b
    aselectb <<= pyrtl.select(zero, a, b)
    reg.next <<= a
    aplusb <<= a + b
    bminusa <<= a - b
    atimesb <<= a * b
    memread <<= mem[0]
    mem[1] <<= a

    # assertEqual rather than the deprecated assertEquals alias, which no
    # longer exists in Python 3.12.
    self.assertEqual(estimate.area_estimation(),
                     (0.00734386752, 0.01879779717361501))
def prioritized_mux(selects, vals):
    """ Pick the value whose select bit is the first one set.

    :param [WireVector] selects: one select wire per candidate value
    :param [WireVector] vals: candidate values, in the same order as selects
    :return: WireVector

    The choice is built as a binary tree of two-way selects. When no select
    is high, the last entry of vals is returned.
    """
    if len(selects) != len(vals):
        raise pyrtl.PyrtlError("Number of select and val signals must match")

    count = len(vals)
    if count == 0:
        raise pyrtl.PyrtlError("Must have a signal to mux")
    if count == 1:
        return vals[0]

    # Split the candidates; the lower half wins whenever any of its selects
    # is set, which gives earlier entries priority.
    mid = count // 2
    lower_hit = pyrtl.rtl_any(*selects[:mid])
    lower_choice = prioritized_mux(selects[:mid], vals[:mid])
    upper_choice = prioritized_mux(selects[mid:], vals[mid:])
    return pyrtl.select(lower_hit,
                        truecase=lower_choice,
                        falsecase=upper_choice)
def barrel_shifter(bits_to_shift, bit_in, direction, shift_dist, wrap_around=0):
    """ Build a logarithmic barrel shifter sized to the input wire.

    :param bits_to_shift: the input wire
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit WireVector representing shift direction
        (0 = shift down, 1 = shift up)
    :param shift_dist: WireVector representing offset to shift
    :param wrap_around: ****currently not implemented****; any non-zero value
        raises NotImplementedError
    :return: shifted WireVector
    """
    from pyrtl import concat, select  # just for readability

    if wrap_around != 0:
        raise NotImplementedError

    # One stage per bit of the distance; stage k conditionally shifts by 2**k.
    width = len(bits_to_shift)
    current = bits_to_shift
    fill = bit_in
    for stage in range(len(shift_dist)):
        distance = 1 << stage  # 1, 2, 4, 8, ...
        if distance >= width:
            # Shifting this far pushes out every original bit, so the
            # result is just the fill pattern.
            current = select(shift_dist[stage],
                             truecase=fill,
                             falsecase=current)
            continue

        shifted_up = concat(current[:-distance], fill)
        shifted_down = concat(fill, current[distance:])
        candidate = select(direction, shifted_up, shifted_down)
        # Take the shifted value only when this bit of the distance is set.
        current = select(shift_dist[stage],
                         truecase=candidate,
                         falsecase=current)
        # The fill pattern doubles each stage but never exceeds full width.
        fill = concat(fill, fill)[:width]
    return current
def relu(vec):
    """ Apply a rectifier to every wire in vec.

    The top bit of each wire is used as the two's complement sign bit: when it
    is set (negative) the element becomes a zero constant of the same width,
    otherwise the wire passes through unchanged.

    :param vec: iterable of wires
    :return: list with one selected wire per input element
    """
    rectified = []
    for word in vec:
        sign_bit = word[-1]
        zero = pyrtl.Const(0, len(word))
        rectified.append(
            pyrtl.select(sign_bit, truecase=zero, falsecase=word))
    return rectified
def _sparse_mux(sel, vals): """ Mux that avoids instantiating unnecessary mux_2s when possible. :param WireVector sel: Select wire, determines what is selected on a given cycle :param {int: WireVector} vals: dictionary to store the values that are :return: Wirevector that signifies the change This mux supports not having a full specification. indices that are not specified are treated as Don't Cares """ items = list(vals.values()) if len(vals) <= 1: if len(vals) == 0: raise pyrtl.PyrtlError("Needs at least one parameter for val") return items[0] if len(sel) == 1: try: false_result = vals[0] true_result = vals[1] except KeyError: raise pyrtl.PyrtlError("Failed to retrieve values for smartmux. " "The length of sel might be wrong") else: half = 2**(len(sel) - 1) first_dict = {indx: wire for indx, wire in vals.items() if indx < half} second_dict = { indx - half: wire for indx, wire in vals.items() if indx >= half } if not len(first_dict): return sparse_mux(sel[:-1], second_dict) if not len(second_dict): return sparse_mux(sel[:-1], first_dict) false_result = sparse_mux(sel[:-1], first_dict) true_result = sparse_mux(sel[:-1], second_dict) if _is_equivelent(false_result, true_result): return true_result return pyrtl.select(sel[-1], falsecase=false_result, truecase=true_result)
def _sparse_mux(sel, vals): """ Mux that avoids instantiating unnecessary mux_2s when possible. :param WireVector sel: Select wire, determines what is selected on a given cycle :param {int: WireVector} vals: dictionary to store the values that are :return: Wirevector that signifies the change This mux supports not having a full specification. indices that are not specified are treated as Don't Cares """ items = list(vals.values()) if len(vals) <= 1: if len(vals) == 0: raise pyrtl.PyrtlError("Needs at least one parameter for val") return items[0] if len(sel) == 1: try: false_result = vals[0] true_result = vals[1] except KeyError: raise pyrtl.PyrtlError("Failed to retrieve values for smartmux. " "The length of sel might be wrong") else: half = 2**(len(sel) - 1) first_dict = {indx: wire for indx, wire in vals.items() if indx < half} second_dict = {indx-half: wire for indx, wire in vals.items() if indx >= half} if not len(first_dict): return sparse_mux(sel[:-1], second_dict) if not len(second_dict): return sparse_mux(sel[:-1], first_dict) false_result = sparse_mux(sel[:-1], first_dict) true_result = sparse_mux(sel[:-1], second_dict) if _is_equivelent(false_result, true_result): return true_result return pyrtl.select(sel[-1], falsecase=false_result, truecase=true_result)
def __init__(self, num_entries, data_nbits, reset_value=0, name=u''):
    """ Build a memory block with one gated read port and one enabled write port.

    :param num_entries: number of words; the address width is clog2(num_entries)
    :param data_nbits: bitwidth of each word
    :param reset_value: stored on the instance; not otherwise used here
    :param name: suffix appended to '_RAM_' to name the underlying MemBlock
    """
    self.addr_nbits = clog2(num_entries)
    self.reset_value = reset_value
    self.num_entries = num_entries
    self.data_nbits = data_nbits
    self.mem = pyrtl.MemBlock(bitwidth=data_nbits,
                              addrwidth=self.addr_nbits,
                              name='_RAM_' + name,
                              asynchronous=False)
    # NOTE(review): these wire names are fixed strings and do not include
    # `name`; confirm this is acceptable if several instances are built in
    # the same block.
    self.wen = pyrtl.WireVector(1, 'wen')
    self.ren = pyrtl.WireVector(1, 'ren')
    self.waddr = pyrtl.WireVector(self.addr_nbits, 'waddr')
    self.raddr = pyrtl.WireVector(self.addr_nbits, 'raddr')
    self.wdata = pyrtl.WireVector(self.data_nbits, 'wdata')
    self.rdata = pyrtl.WireVector(self.data_nbits, 'rdata')
    # Read data is forced to zero while the read enable is low. The zero
    # constant is sized to the data width rather than a hard-coded 32 bits.
    self.rdata <<= pyrtl.select(self.ren,
                                truecase=self.mem[self.raddr],
                                falsecase=pyrtl.Const(0, self.data_nbits))
    # The write only takes effect when wen is high.
    self.mem[self.waddr] <<= pyrtl.MemBlock.EnabledWrite(self.wdata,
                                                         self.wen)