def _g(word, key_expand_round):
    # One-byte left circular rotation, substitution of each byte
    a = libutils.partition_wire(word, 8)
    sub = pyrtl.concat(sbox[a[2]], sbox[a[1]], sbox[a[0]], sbox[a[3]])
    # xor substituted bytes with round constant
    round_const = pyrtl.concat(rcon[key_expand_round + 1], pyrtl.Const(0, bitwidth=24))
    return round_const ^ sub
def barrel_shifter(shift_in, bit_in, direction, shift_dist, wrap_around=0):
    """ Create a barrel shifter that operates on data based on the wire width.

    :param shift_in: the input wire
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit WireVector representing shift direction
        (0 = shift down, 1 = shift up)
    :param shift_dist: WireVector representing offset to shift
    :param wrap_around: ****currently not implemented****
    :return: shifted WireVector
    """
    # Implement with logN stages muxing between shifted and un-shifted values
    val = shift_in
    append_val = bit_in
    log_length = int(math.log(len(shift_in) - 1, 2))  # note the one offset

    if len(shift_dist) > log_length:
        print('Warning: for barrel shifter, the shift distance wirevector '
              'has bits that are not used in the barrel shifter')

    for i in range(min(len(shift_dist), log_length)):
        shift_amt = pow(2, i)  # stages shift 1, 2, 4, 8, ...
        newval = pyrtl.mux(direction, truecase=val[:-shift_amt],
                           falsecase=val[shift_amt:])
        newval = pyrtl.mux(direction, truecase=pyrtl.concat(newval, append_val),
                           falsecase=pyrtl.concat(append_val, newval))  # build shifted value
        # mux shifted vs. unshifted by using the i-th bit of the shift amount signal
        # (the original indexed shift_dist[i - 1], which reads the MSB on stage 0)
        val = pyrtl.mux(shift_dist[i], truecase=newval, falsecase=val)
        append_val = pyrtl.concat(append_val, append_val)

    return val
def ripple_half_add(a, cin=0):
    cin = pyrtl.as_wires(cin)
    if len(a) == 1:
        return pyrtl.concat(*half_adder(a, cin))
    else:
        ripplecarry = half_adder(a[0], cin)
        msbits = ripple_half_add(a[1:], ripplecarry[0])
        return pyrtl.concat(msbits, ripplecarry[1])
def _g(word, key_expand_round):
    # One-byte left circular rotation, substitution of each byte
    a = libutils.partition_wire(word, 8)
    sub = pyrtl.concat(sbox[a[1]], sbox[a[2]], sbox[a[3]], sbox[a[0]])
    # xor substituted bytes with round constant
    round_const = pyrtl.concat(rcon[key_expand_round + 1], pyrtl.Const(0, bitwidth=24))
    return round_const ^ sub
def test_concat(self):
    # concat's args are order dependent, therefore we need to check
    # that we aren't mangling them
    ins = [pyrtl.Input(5) for i in range(2)]
    outs = [pyrtl.Output(10) for i in range(2)]
    outs[0] <<= pyrtl.concat(ins[1], ins[0])
    outs[1] <<= pyrtl.concat(ins[0], ins[1])
    pyrtl.common_subexp_elimination()
    self.num_net_of_type('c', 2)
    self.num_net_of_type('w', 2)
    self.assert_num_net(4)
    self.assert_num_wires(6)
    pyrtl.working_block().sanity_check()
def prng_lfsr(bitwidth, load, req, seed=None):
    """ Builds a single-cycle PRNG from a 127-bit Fibonacci LFSR.

    :param bitwidth: the desired bitwidth of the random number
    :param load: one bit signal to load the seed into the prng
    :param req: one bit signal to request a random number
    :param seed: 127-bit WireVector, defaults to None (self-seeding);
        refrain from self-seeding if reseeding at run time is required
    :return: register containing the random number with the given bitwidth

    A very fast and compact PRNG that generates a random number using only one
    clock cycle. Has a period of 2**127 - 1. Its linearity makes it statistically
    weak, but it should be good enough for any noncryptographic purpose, like
    test pattern generation.
    """
    # 127 bits is chosen because 127 is a Mersenne prime, which maximizes the
    # period of the LFSR at 2**127 - 1 for any requested bitwidth
    if seed is None:
        import random
        cryptogen = random.SystemRandom()
        seed = cryptogen.randrange(1, 2**127)  # seed itself if no seed signal is given

    lfsr = pyrtl.Register(127 if bitwidth < 127 else bitwidth)
    # leap ahead by shifting the LFSR bitwidth times
    leap_ahead = lfsr
    for i in range(bitwidth):
        leap_ahead = pyrtl.concat(leap_ahead, leap_ahead[125] ^ leap_ahead[126])

    with pyrtl.conditional_assignment:
        with load:
            lfsr.next |= seed
        with req:
            lfsr.next |= leap_ahead
    return lfsr[:bitwidth]
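# A minimal usage sketch for prng_lfsr above (illustrative, not part of the source
# project): it assumes pyrtl is imported and prng_lfsr is in scope. With a
# self-generated seed the output is nondeterministic, so it is printed, not asserted.
def _demo_prng_lfsr():
    import pyrtl
    pyrtl.reset_working_block()
    load, req = pyrtl.Input(1, 'load'), pyrtl.Input(1, 'req')
    rand_out = pyrtl.Output(32, 'rand')
    rand_out <<= prng_lfsr(32, load, req)

    sim = pyrtl.Simulation()
    sim.step({'load': 1, 'req': 0})  # load the (self-generated) seed
    sim.step({'load': 0, 'req': 1})  # request a number
    sim.step({'load': 0, 'req': 0})  # the leap-ahead value is now in the register
    print(sim.inspect('rand'))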
def inv_shift_rows(in_vector):
    a = libutils.partition_wire(in_vector, 8)
    out_vector = pyrtl.concat(a[0], a[7], a[10], a[13],
                              a[1], a[4], a[11], a[14],
                              a[2], a[5], a[8], a[15],
                              a[3], a[6], a[9], a[12])
    return out_vector
def simple_mult(A, B, start):
    """ Generate a simple shift-and-add multiplier.

    Builds a slow, small multiplier using the simple shift-and-add algorithm.
    Requires very small area (it uses only a single adder), but has long delay
    (worst case is len(A) cycles). A and B are arbitrary-length inputs; start
    is a one-bit input to indicate the inputs are ready. done is a one-bit
    output signal raised when the multiplication is finished, at which point
    the product will be on the result line (returned by the function).
    """
    alen = len(A)
    blen = len(B)
    areg = pyrtl.Register(alen)
    breg = pyrtl.Register(blen + alen)
    accum = pyrtl.Register(blen + alen)
    done = areg == 0  # multiplication is finished when a becomes 0

    # During multiplication, shift a right every cycle, b left every cycle
    with pyrtl.conditional_assignment:
        with start:  # initialization
            areg.next |= A
            breg.next |= B
            accum.next |= 0
        with ~done:  # don't run when there's no work to do
            areg.next |= areg[1:]  # right shift
            breg.next |= pyrtl.concat(breg, "1'b0")  # left shift
            # "Multiply" shifted breg by LSB of areg by conditionally adding
            with areg[0]:
                accum.next |= accum + breg  # adds to accum only when LSB of areg is 1

    return accum, done
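# Hedged usage sketch for the simple_mult above (names are illustrative): drive
# start high for one cycle, then clock until done goes high. Note that done also
# reads high during the start cycle (areg is still 0), so one working cycle is
# taken before polling.
def _demo_simple_mult():
    import pyrtl
    pyrtl.reset_working_block()
    a_in, b_in = pyrtl.Input(4, 'a'), pyrtl.Input(4, 'b')
    start = pyrtl.Input(1, 'start')
    product, done = simple_mult(a_in, b_in, start)
    prod_out, done_out = pyrtl.Output(8, 'product'), pyrtl.Output(1, 'done')
    prod_out <<= product
    done_out <<= done

    sim = pyrtl.Simulation()
    sim.step({'a': 5, 'b': 6, 'start': 1})  # load the operands
    sim.step({'a': 0, 'b': 0, 'start': 0})  # first working cycle
    while sim.inspect('done') == 0:
        sim.step({'a': 0, 'b': 0, 'start': 0})
    assert sim.inspect('product') == 30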
def shift_reg(din, n):
    """ Use a shift register to create delay-coded thresholds. """
    sr = pyrtl.Register(bitwidth=n)
    sr.next <<= pyrtl.concat(sr[:-1], din)
    return sr
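# Hedged usage sketch for shift_reg (illustrative): bits enter at the LSB, and the
# register output lags its input by one cycle, so one extra step is taken before
# checking the result.
def _demo_shift_reg():
    import pyrtl
    pyrtl.reset_working_block()
    din = pyrtl.Input(1, 'din')
    q = pyrtl.Output(4, 'q')
    q <<= shift_reg(din, 4)

    sim = pyrtl.Simulation()
    for bit in [1, 0, 1, 1]:  # first bit in ends up at the MSB: 0b1011
        sim.step({'din': bit})
    sim.step({'din': 0})  # one extra cycle so the final shift is visible
    assert sim.inspect('q') == 0b1011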
def test_two_way_concat(self):
    i = pyrtl.Const(0b1100)
    j = pyrtl.Const(0b011, bitwidth=3)
    k = pyrtl.Const(0b100110)
    o = pyrtl.Output(13, 'o')
    o <<= pyrtl.concat(i, j, k)

    block = pyrtl.working_block()
    concat_nets = list(block.logic_subset(op='c'))
    self.assertEqual(len(concat_nets), 1)
    self.assertEqual(concat_nets[0].args, (i, j, k))

    pyrtl.two_way_concat()

    concat_nets = list(block.logic_subset(op='c'))
    self.assertEqual(len(concat_nets), 2)
    upper_concat = next(n for n in concat_nets if i is n.args[0])
    lower_concat = next(n for n in concat_nets if k is n.args[1])
    self.assertNotEqual(upper_concat, lower_concat)
    self.assertEqual(upper_concat.args, (i, j))
    self.assertEqual(lower_concat.args, (upper_concat.dests[0], k))

    sim = pyrtl.Simulation()
    sim.step({})
    self.assertEqual(sim.inspect('o'), 0b1100011100110)
def test_area_est_unchanged(self):
    a = pyrtl.Const(2, 8)
    b = pyrtl.Const(85, 8)
    zero = pyrtl.Const(0, 1)
    reg = pyrtl.Register(8)
    mem = pyrtl.MemBlock(8, 8)
    out = pyrtl.Output(8)
    nota, aLSB, athenb, aORb, aANDb, aNANDb, \
        aXORb, aequalsb, altb, agtb, aselectb, \
        aplusb, bminusa, atimesb, memread = [pyrtl.Output() for i in range(15)]
    out <<= zero
    nota <<= ~a
    aLSB <<= a[0]
    athenb <<= pyrtl.concat(a, b)
    aORb <<= a | b
    aANDb <<= a & b
    aNANDb <<= a.nand(b)
    aXORb <<= a ^ b
    aequalsb <<= a == b
    altb <<= a < b
    agtb <<= a > b
    aselectb <<= pyrtl.select(zero, a, b)
    reg.next <<= a
    aplusb <<= a + b
    bminusa <<= a - b
    atimesb <<= a * b
    memread <<= mem[0]
    mem[1] <<= a
    self.assertEqual(estimate.area_estimation(),
                     (0.00734386752, 0.01879779717361501))
def test_time_est_unchanged(self):
    a = pyrtl.Const(2, 8)
    b = pyrtl.Const(85, 8)
    zero = pyrtl.Const(0, 1)
    reg = pyrtl.Register(8)
    mem = pyrtl.MemBlock(8, 8)
    out = pyrtl.Output(8)
    nota, aLSB, athenb, aORb, aANDb, aNANDb, \
        aXORb, aequalsb, altb, agtb, aselectb, \
        aplusb, bminusa, atimesb, memread = [pyrtl.Output() for i in range(15)]
    out <<= zero
    nota <<= ~a
    aLSB <<= a[0]
    athenb <<= pyrtl.concat(a, b)
    aORb <<= a | b
    aANDb <<= a & b
    aNANDb <<= a.nand(b)
    aXORb <<= a ^ b
    aequalsb <<= a == b
    altb <<= a < b
    agtb <<= a > b
    aselectb <<= pyrtl.select(zero, a, b)
    reg.next <<= a
    aplusb <<= a + b
    bminusa <<= a - b
    atimesb <<= a * b
    memread <<= mem[0]
    mem[1] <<= a
    timing = estimate.TimingAnalysis()
    self.assertEqual(timing.max_freq(), 610.2770657878676)
    self.assertEqual(timing.max_length(), 1255.6000000000001)
def kogge_stone(a, b, cin=0):
    """ Creates a Kogge-Stone adder given two inputs.

    :param a, b: the two WireVectors to add up (bitwidths don't need to match)
    :param cin: an optional carry WireVector or value
    :return: a WireVector representing the output of the adder

    The Kogge-Stone adder is a fast tree-based adder with O(log(n)) propagation
    delay, useful for performance-critical designs. However, it has O(n log(n))
    area usage and large fan-out.
    """
    a, b = libutils.match_bitwidth(a, b)
    prop_orig = a ^ b
    prop_bits = [i for i in prop_orig]
    gen_bits = [i for i in a & b]
    prop_dist = 1

    # creation of the carry calculation
    while prop_dist < len(a):
        for i in reversed(range(prop_dist, len(a))):
            prop_old = prop_bits[i]
            gen_bits[i] = gen_bits[i] | (prop_old & gen_bits[i - prop_dist])
            if i >= prop_dist * 2:  # to prevent creating unnecessary nets and wires
                prop_bits[i] = prop_old & prop_bits[i - prop_dist]
        prop_dist *= 2

    # assembling the result of the addition
    # preparing the cin (and conveniently shifting the gen bits)
    gen_bits.insert(0, pyrtl.as_wires(cin))
    return pyrtl.concat(*reversed(gen_bits)) ^ prop_orig
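# Hedged usage sketch for kogge_stone (illustrative): assumes pyrtl and
# pyrtl.rtllib.libutils are imported. The result is one bit wider than the
# inputs so the carry-out is kept.
def _demo_kogge_stone():
    import pyrtl
    pyrtl.reset_working_block()
    a, b = pyrtl.Input(8, 'a'), pyrtl.Input(8, 'b')
    s = pyrtl.Output(9, 's')
    s <<= kogge_stone(a, b)

    sim = pyrtl.Simulation()
    sim.step({'a': 200, 'b': 100})
    assert sim.inspect('s') == 300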
def inv_shift_rows(in_vector):
    # a = libutils.partition_wire(in_vector, 8)
    a = [in_vector[offset - 8:offset] for offset in range(128, 0, -8)]
    out_vector = pyrtl.concat(a[0], a[13], a[10], a[7],
                              a[4], a[1], a[14], a[11],
                              a[8], a[5], a[2], a[15],
                              a[12], a[9], a[6], a[3])
    return out_vector
def cla_adder(a, b, cin=0, la_unit_len=4):
    """ Carry Lookahead Adder.

    :param int la_unit_len: the length of input that every unit processes

    A carry lookahead adder is faster than a ripple carry adder because it
    calculates the carry bits faster. It is not as fast as a Kogge-Stone adder,
    but it uses less area.
    """
    a, b = pyrtl.match_bitwidth(a, b)
    if len(a) <= la_unit_len:
        sum, cout = _cla_adder_unit(a, b, cin)
        return pyrtl.concat(cout, sum)
    else:
        sum, cout = _cla_adder_unit(a[0:la_unit_len], b[0:la_unit_len], cin)
        msbits = cla_adder(a[la_unit_len:], b[la_unit_len:], cout, la_unit_len)
        return pyrtl.concat(msbits, sum)
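# Hedged usage sketch for cla_adder (illustrative): assumes the _cla_adder_unit
# helper shown later in this collection is also in scope. The 8-bit inputs are
# processed as two 4-bit lookahead units chained through their carries.
def _demo_cla_adder():
    import pyrtl
    pyrtl.reset_working_block()
    a, b = pyrtl.Input(8, 'a'), pyrtl.Input(8, 'b')
    s = pyrtl.Output(9, 's')
    s <<= cla_adder(a, b)

    sim = pyrtl.Simulation()
    sim.step({'a': 200, 'b': 100})
    assert sim.inspect('s') == 300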
def barrel_shifter(bits_to_shift, bit_in, direction, shift_dist, wrap_around=0):
    """ Create a barrel shifter that operates on data based on the wire width.

    :param bits_to_shift: the input wire
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit WireVector representing shift direction
        (0 = shift down, 1 = shift up)
    :param shift_dist: WireVector representing offset to shift
    :param wrap_around: ****currently not implemented****
    :return: shifted WireVector
    """
    from pyrtl import concat, select  # just for readability

    if wrap_around != 0:
        raise NotImplementedError

    # Implement with logN stages selecting between shifted and un-shifted values
    final_width = len(bits_to_shift)
    val = bits_to_shift
    append_val = bit_in

    for i in range(len(shift_dist)):
        shift_amt = pow(2, i)  # stages shift 1, 2, 4, 8, ...
        if shift_amt < final_width:
            newval = select(direction,
                            concat(val[:-shift_amt], append_val),  # shift up
                            concat(append_val, val[shift_amt:]))  # shift down
            val = select(shift_dist[i],
                         truecase=newval,  # if the bit of shift is 1, do the shift
                         falsecase=val)  # otherwise, don't
            # the value to append grows exponentially, but is capped at full width
            append_val = concat(append_val, append_val)[:final_width]
        else:
            # if we are shifting this much, all the data is gone
            val = select(shift_dist[i],
                         truecase=append_val,  # if the bit of shift is 1, do the shift
                         falsecase=val)  # otherwise, don't

    return val
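# Hedged usage sketch for the barrel_shifter above (illustrative): shift an
# 8-bit value up by 2 positions, filling with a constant 0 bit.
def _demo_barrel_shifter():
    import pyrtl
    pyrtl.reset_working_block()
    data = pyrtl.Input(8, 'data')
    dist = pyrtl.Input(3, 'dist')
    direction = pyrtl.Input(1, 'direction')
    shifted = pyrtl.Output(8, 'shifted')
    shifted <<= barrel_shifter(data, pyrtl.Const(0, 1), direction, dist)

    sim = pyrtl.Simulation()
    sim.step({'data': 0b10110001, 'dist': 2, 'direction': 1})  # shift up by 2
    assert sim.inspect('shifted') == 0b11000100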
def ripple_add(a, b, cin=0):
    a, b = libutils.match_bitwidth(a, b)
    cin = pyrtl.as_wires(cin)
    if len(a) == 1:
        return one_bit_add(a, b, cin)
    else:
        ripplecarry = one_bit_add(a[0], b[0], cin)
        msbits = ripple_add(a[1:], b[1:], ripplecarry[1])
        return pyrtl.concat(msbits, ripplecarry[0])
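# Hedged usage sketch for this ripple_add variant (illustrative): one_bit_add is
# not shown in this collection, so a plausible full adder returning
# pyrtl.concat(carry, sum) is assumed here, matching how ripplecarry[1] (carry)
# and ripplecarry[0] (sum) are indexed above. Assumes pyrtl and
# pyrtl.rtllib.libutils are imported.
def one_bit_add(a, b, cin):
    # assumed full adder: bit 1 is carry-out, bit 0 is sum
    return pyrtl.concat((a & b) | ((a ^ b) & cin), a ^ b ^ cin)

def _demo_ripple_add():
    import pyrtl
    pyrtl.reset_working_block()
    x, y = pyrtl.Input(4, 'x'), pyrtl.Input(4, 'y')
    s = pyrtl.Output(5, 's')
    s <<= ripple_add(x, y)

    sim = pyrtl.Simulation()
    sim.step({'x': 9, 'y': 11})
    assert sim.inspect('s') == 20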
def key_expansion(key):
    w = list(libutils.partition_wire(key, 32))
    for key_expand_round in range(10):
        last = key_expand_round * 4
        w.append(w[last] ^ _g(w[last + 3], key_expand_round))
        w.append(w[-1] ^ w[last + 1])
        w.append(w[-1] ^ w[last + 2])
        w.append(w[-1] ^ w[last + 3])
    return pyrtl.concat(*w)
def inv_mix_columns(in_vector):
    def _inv_mix_single(index):
        mult_items = [inv_galois_mult(a[_mod_add(index, loc, 4)], mult_table)
                      for loc, mult_table in enumerate(_igm_divisor)]
        return mult_items[0] ^ mult_items[1] ^ mult_items[2] ^ mult_items[3]

    a = libutils.partition_wire(in_vector, 8)
    inverted = [_inv_mix_single(index) for index in range(len(a))]
    return pyrtl.concat(*inverted)
def test_async_check_should_pass_with_cat(self):
    memory = pyrtl.MemBlock(
        bitwidth=self.bitwidth,
        addrwidth=self.addrwidth,
        name='memory')
    addr = pyrtl.concat(self.mem_read_address1[0],
                        self.mem_read_address2[0:-1])
    self.output1 <<= memory[addr]
    memory[self.mem_write_address] <<= self.mem_write_data
    pyrtl.working_block().sanity_check()
def test_netgraph_same_wire_multiple_edges_to_same_net(self):
    c = pyrtl.Const(1, 1)
    w = pyrtl.concat(c, c, c)
    g = pyrtl.net_graph()
    self.assertEqual(len(g[c]), 1)
    edges = list(g[c].values())[0]
    self.assertEqual(len(edges), 3)
    for w in edges:
        self.assertIs(w, c)
def key_expansion(key):
    w = list(reversed(libutils.partition_wire(key, 32)))
    for key_expand_round in range(10):
        last = key_expand_round * 4
        w.append(w[last] ^ _g(w[last + 3], key_expand_round))
        w.append(w[-1] ^ w[last + 1])
        w.append(w[-1] ^ w[last + 2])
        w.append(w[-1] ^ w[last + 3])
    return pyrtl.concat(*w)
def barrel_shifter_v2(bits_to_shift, bit_in, direction, shift_dist, wrap_around=0):
    """ Create a barrel shifter that operates on data based on the wire width.

    :param bits_to_shift: the input wire
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit WireVector representing shift direction
        (0 = shift down, 1 = shift up)
    :param shift_dist: WireVector representing offset to shift
    :param wrap_around: ****currently not implemented****
    :return: shifted WireVector
    """
    # Implement with logN stages selecting between shifted and un-shifted values
    val = bits_to_shift
    append_val = bit_in
    log_length = int(math.log(len(bits_to_shift) - 1, 2))  # note the one offset

    if wrap_around != 0:
        raise NotImplementedError

    # if len(shift_dist) > log_length:
    #     raise pyrtl.PyrtlError('the shift distance wirevector '
    #                            'has bits that are not used in the barrel shifter')

    for i in range(min(len(shift_dist), log_length)):
        shift_amt = pow(2, i)  # stages shift 1, 2, 4, 8, ...
        newval = pyrtl.select(direction, truecase=val[:-shift_amt],
                              falsecase=val[shift_amt:])
        newval = pyrtl.select(direction, truecase=pyrtl.concat(newval, append_val),
                              falsecase=pyrtl.concat(append_val, newval))  # build shifted value
        # select shifted vs. unshifted by using the i-th bit of the shift amount signal
        val = pyrtl.select(shift_dist[i], truecase=newval, falsecase=val)
        # double the append value so stage i has the 2**i fill bits it needs
        # (the original appended a single bit_in per stage, which falls short
        # from the third stage onward)
        append_val = pyrtl.concat(append_val, append_val)
    return val
def ripple_add(a, b, cin=0):
    a, b = pyrtl.match_bitwidth(a, b)
    # match_bitwidth matches the bitwidths of multiple wires;
    # by default, it zero-extends the shorter ones
    if len(a) == 1:
        sumbits, cout = one_bit_add(a, b, cin)
    else:
        lsbit, ripplecarry = one_bit_add(a[0], b[0], cin)
        msbits, cout = ripple_add(a[1:], b[1:], ripplecarry)
        sumbits = pyrtl.concat(msbits, lsbit)
    return sumbits, cout
def carrysave_adder(a, b, c, final_adder=ripple_add):
    """ Adds three wirevectors up in an efficient manner.

    :param WireVector a, b, c: the three wires to add up
    :param function final_adder: the adder to use for the final addition
    :return: a wirevector with length 2 longer than the largest input
    """
    a, b, c = libutils.match_bitwidth(a, b, c)
    partial_sum = a ^ b ^ c
    shift_carry = (a | b) & (a | c) & (b | c)
    return pyrtl.concat(final_adder(partial_sum[1:], shift_carry), partial_sum[0])
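# Hedged usage sketch for this carrysave_adder (illustrative): it assumes the
# ripple_add variant and the assumed one_bit_add helper sketched earlier are in
# scope, since final_adder defaults to ripple_add. The result is two bits wider
# than the inputs. For example, 7 + 5 + 3: the partial sum is 7^5^3 = 1 and the
# majority carry is 7, so the final add gives 1 + (7 << 1) = 15.
def _demo_carrysave_adder():
    import pyrtl
    pyrtl.reset_working_block()
    a, b, c = pyrtl.Input(4, 'a'), pyrtl.Input(4, 'b'), pyrtl.Input(4, 'c')
    total = pyrtl.Output(6, 'total')
    total <<= carrysave_adder(a, b, c)

    sim = pyrtl.Simulation()
    sim.step({'a': 7, 'b': 5, 'c': 3})
    assert sim.inspect('total') == 15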
def ripple_add(a, b, carry_in=0):
    a, b = pyrtl.match_bitwidth(a, b)
    # match_bitwidth matches the bitwidths of multiple wires;
    # by default, it zero-extends the shorter ones
    if len(a) == 1:
        sumbits, carry_out = one_bit_add(a, b, carry_in)
    else:
        lsbit, ripplecarry = one_bit_add(a[0], b[0], carry_in)
        msbits, carry_out = ripple_add(a[1:], b[1:], ripplecarry)
        sumbits = pyrtl.concat(msbits, lsbit)
    return sumbits, carry_out
def carrysave_adder(a, b, c):
    """ Adds three wirevectors up in an efficient manner.

    :param a, b, c: the three wirevectors to add up
    :return: a wirevector with length 2 longer than the largest input
    """
    # match_bitwidth returns the padded wires; the original discarded them
    a, b, c = libutils.match_bitwidth(a, b, c)
    partial_sum = a ^ b ^ c
    shift_carry = (a | b) & (a | c) & (b | c)
    shift_carry_1 = pyrtl.concat(shift_carry, 0)
    return ripple_add(partial_sum, shift_carry_1)
def generate_full_mux(a, b, sel):
    """ Generates a multiplexor; b is the one selected when sel is high. """
    assert len(a) == len(b)
    if len(a) == 1:
        out = generate_one_bit_mux(a, b, sel)
    else:
        lsbit = generate_one_bit_mux(a[0], b[0], sel)
        msbits = generate_full_mux(a[1:], b[1:], sel)
        out = pyrtl.concat(msbits, lsbit)
    return out
def generate_full_adder(a, b, cin=None):
    """ Generates an arbitrary bitwidth ripple-carry adder. """
    assert len(a) == len(b)
    if cin is None:
        cin = pyrtl.Const(0, bitwidth=1)
    if len(a) == 1:
        sumbits, cout = generate_one_bit_adder(a, b, cin)
    else:
        lsbit, ripplecarry = generate_one_bit_adder(a[0], b[0], cin)
        msbits, cout = generate_full_adder(a[1:], b[1:], ripplecarry)
        sumbits = pyrtl.concat(msbits, lsbit)
    return sumbits, cout
def _sparse_adder(wire_array_2, adder):
    result = []
    for single_w_index in range(len(wire_array_2)):
        if len(wire_array_2[single_w_index]) == 2:  # check if the two wire vectors overlap yet
            break
        result.append(wire_array_2[single_w_index][0])

    import six
    wires_to_zip = wire_array_2[single_w_index:]
    add_wires = tuple(six.moves.zip_longest(*wires_to_zip, fillvalue=pyrtl.Const(0)))
    adder_result = adder(pyrtl.concat_list(add_wires[0]), pyrtl.concat_list(add_wires[1]))
    return pyrtl.concat(adder_result, *reversed(result))
def carrysave_adder(a, b, c, final_adder=ripple_add):
    """ Adds three wirevectors up in an efficient manner.

    :param WireVector a, b, c: the three wires to add up
    :param function final_adder: the adder to use for the final addition
    :return: a wirevector with length 2 longer than the largest input
    """
    a, b, c = libutils.match_bitwidth(a, b, c)
    partial_sum = a ^ b ^ c
    shift_carry = (a | b) & (a | c) & (b | c)
    shift_carry_1 = pyrtl.concat(shift_carry, 0)
    return final_adder(partial_sum, shift_carry_1)
def ripple_add(a, b, cin=0):
    if len(a) < len(b):  # make sure that b is the shorter wire
        b, a = a, b
    cin = pyrtl.as_wires(cin)
    if len(a) == 1:
        return one_bit_add(a, b, cin)
    else:
        ripplecarry = one_bit_add(a[0], b[0], cin)
        if len(b) == 1:
            msbits = ripple_half_add(a[1:], ripplecarry[1])
        else:
            msbits = ripple_add(a[1:], b[1:], ripplecarry[1])
        return pyrtl.concat(msbits, ripplecarry[0])
def test_loop_2(self):
    in_1 = pyrtl.Input(10)
    in_2 = pyrtl.Input(9)
    fake_loop_wire = pyrtl.WireVector(1)
    # Note the slight difference from the last test case on the next line
    comp_wire = pyrtl.concat(in_2[0:6], fake_loop_wire, in_2[6:9])
    r_wire = in_1 & comp_wire
    fake_loop_wire <<= r_wire[3]
    out = pyrtl.Output(10)
    out <<= fake_loop_wire

    # It causes there to be a real loop
    self.assert_has_loop()
def _one_cycle_mult(areg, breg, rem_bits, sum_sf=0, curr_bit=0):
    """ Returns a WireVector sum of rem_bits multiplies (in one clock cycle).

    Note: this method requires a lot of area because of the indexing
    in the else statement.
    """
    if rem_bits == 0:
        return sum_sf
    else:
        a_curr_val = areg[curr_bit].sign_extended(len(breg))
        if curr_bit == 0:  # if no shift
            return _one_cycle_mult(areg, breg, rem_bits - 1,
                                   sum_sf + (a_curr_val & breg),
                                   curr_bit + 1)
        else:
            return _one_cycle_mult(
                areg, breg, rem_bits - 1,
                sum_sf + (a_curr_val & pyrtl.concat(breg, pyrtl.Const(0, curr_bit))),
                curr_bit + 1)
def _sparse_adder(wire_array_2, adder):
    bitwidth = len(wire_array_2)
    add_wires = [], []
    result = []
    for single_w_index in range(bitwidth):
        if len(wire_array_2[single_w_index]) == 2:  # check if the two wire vectors overlap yet
            break
        result.append(wire_array_2[single_w_index][0])

    for w_loc in range(single_w_index, bitwidth):
        for i in range(2):
            if len(wire_array_2[w_loc]) >= i + 1:
                add_wires[i].append(wire_array_2[w_loc][i])
            else:
                add_wires[i].append(pyrtl.Const(0))

    adder_result = adder(pyrtl.concat_list(add_wires[0]), pyrtl.concat_list(add_wires[1]))
    return pyrtl.concat(adder_result, *reversed(result))
def test_as_graph_duplicate_args(self):
    a = pyrtl.Input(3)
    x = pyrtl.Input(1)
    d = pyrtl.Output()
    b = a & a
    c = pyrtl.concat(a, a)
    m = pyrtl.MemBlock(addrwidth=3, bitwidth=3, name='m')
    m2 = pyrtl.MemBlock(addrwidth=1, bitwidth=1, name='m')
    d <<= m[a]
    m[a] <<= a
    m2[x] <<= pyrtl.MemBlock.EnabledWrite(x, x)

    b = pyrtl.working_block()
    src_g, dst_g = b.net_connections(False)
    self.check_graph_correctness(src_g, dst_g)

    src_g, dst_g = b.net_connections(True)
    self.check_graph_correctness(src_g, dst_g, True)
def test_edge_case_1(self):
    in_1 = pyrtl.Input(10)
    in_2 = pyrtl.Input(9)
    fake_loop_wire = pyrtl.WireVector(1)
    comp_wire = pyrtl.concat(in_2[0:4], fake_loop_wire, in_2[4:9])
    r_wire = in_1 & comp_wire
    fake_loop_wire <<= r_wire[3]
    out = pyrtl.Output(10)
    out <<= fake_loop_wire

    # Yes, because we only check loops on a net level, this will still be
    # a loop pre-synth
    self.assertNotEqual(pyrtl.find_loop(), None)
    pyrtl.synthesize()

    # Because synth separates the individual wires, it also resolves the loop
    self.assertEqual(pyrtl.find_loop(), None)
    pyrtl.optimize()
    self.assertEqual(pyrtl.find_loop(), None)
def _shifted_reg_next(reg, direct, num=1):
    """ Creates a 'next' value for a shifted (left or right) register.

    Use: myReg.next = _shifted_reg_next(myReg, 'l', 4)

    :param string direct: direction of shift, either 'l' or 'r'
    :param int num: number of shifts
    :return: a 'next' property for the shifted register
    """
    if direct == 'l':
        if num >= len(reg):
            return 0
        else:
            return pyrtl.concat(reg, pyrtl.Const(0, num))
    elif direct == 'r':
        if num >= len(reg):
            return 0
        else:
            return reg[num:]
    else:
        raise pyrtl.PyrtlError("direction must be specified with 'direct' "
                               "parameter as either 'l' or 'r'")
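# Hedged usage sketch for _shifted_reg_next (illustrative): a register that is
# loaded once and then shifts left by one position every following cycle. As in
# the docstring above, the returned wire may be wider than the register and is
# truncated when assigned to reg.next.
def _demo_shifted_reg_next():
    import pyrtl
    pyrtl.reset_working_block()
    load = pyrtl.Input(1, 'load')
    r = pyrtl.Register(8)
    with pyrtl.conditional_assignment:
        with load:
            r.next |= 0b00000101
        with pyrtl.otherwise:
            r.next |= _shifted_reg_next(r, 'l', 1)
    out = pyrtl.Output(8, 'out')
    out <<= r

    sim = pyrtl.Simulation()
    sim.step({'load': 1})  # register loads 0b0101
    sim.step({'load': 0})  # register shows 0b0101, shifts to 0b1010
    sim.step({'load': 0})  # register shows 0b1010
    assert sim.inspect('out') == 0b00001010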
def _cla_adder_unit(a, b, cin):
    """ Carry generation and propagation signals will be calculated only using
    the inputs; their values don't rely on the sum. Every unit generates a
    cout signal which is used as cin for the next unit.
    """
    gen = a & b
    prop = a ^ b
    assert len(prop) == len(gen)

    carry = [gen[0] | prop[0] & cin]
    sum_bit = prop[0] ^ cin

    cur_gen = gen[0]
    cur_prop = prop[0]
    for i in range(1, len(prop)):
        cur_gen = gen[i] | (prop[i] & cur_gen)
        cur_prop = cur_prop & prop[i]
        sum_bit = pyrtl.concat(prop[i] ^ carry[i - 1], sum_bit)
        carry.append(gen[i] | (prop[i] & carry[i - 1]))
    cout = cur_gen | (cur_prop & cin)
    return sum_bit, cout
def simple_mult(A, B, start):
    """ Generate a simple shift-and-add multiplier.

    Builds a slow, small multiplier using the simple shift-and-add algorithm.
    Requires very small area (it uses only a single adder), but has long delay
    (worst case is len(A) cycles). A and B are arbitrary-length inputs; start
    is a one-bit input to indicate the inputs are ready. done is a one-bit
    output signal raised when the multiplication is finished, at which point
    the product will be on the result line (returned by the function).
    """
    triv_result = _trivial_mult(A, B)
    if triv_result is not None:
        return triv_result, pyrtl.Const(1, 1)

    alen = len(A)
    blen = len(B)
    areg = pyrtl.Register(alen)
    breg = pyrtl.Register(blen + alen)
    accum = pyrtl.Register(blen + alen)
    done = (areg == 0)  # multiplication is finished when a becomes 0

    # During multiplication, shift a right every cycle, b left every cycle
    with pyrtl.conditional_assignment:
        with start:  # initialization
            areg.next |= A
            breg.next |= B
            accum.next |= 0
        with ~done:  # don't run when there's no work to do
            areg.next |= areg[1:]  # right shift
            breg.next |= pyrtl.concat(breg, pyrtl.Const(0, 1))  # left shift
            a_0_val = areg[0].sign_extended(len(accum))
            # adds to accum only when LSB of areg is 1
            accum.next |= accum + (a_0_val & breg)

    return accum, done