示例#1
0
def _g(word, key_expand_round):
    """AES key-schedule g-function: one-byte rotation, S-box, round constant.

    :param word: 32-bit WireVector holding one key-schedule word
    :param key_expand_round: zero-based round index; rcon is read at
        key_expand_round + 1
    :return: 32-bit WireVector, round_const ^ substituted rotated bytes

    NOTE(review): the selection order (2, 1, 0, 3) is a one-byte left
    rotation only if libutils.partition_wire returns the least significant
    byte first -- confirm against libutils.
    """
    byte = libutils.partition_wire(word, 8)
    # rotate by one byte and push every byte through the S-box
    substituted = pyrtl.concat(*[sbox[byte[idx]] for idx in (2, 1, 0, 3)])
    # the round constant only touches the most significant byte
    rc = pyrtl.concat(rcon[key_expand_round + 1], pyrtl.Const(0, bitwidth=24))
    return rc ^ substituted
示例#2
0
def barrel_shifter(shift_in, bit_in, direction, shift_dist, wrap_around=0):
    """
    Create a barrel shifter that operates on data based on the wire width
    :param shift_in: the input wire; must be at least 2 bits wide.
    :param bit_in: the 1-bit wire giving the value to shift in.
    :param direction: direction is a one bit wirevector representing shift direction
        0 = shift down, 1 = shift up.
    :param shift_dist: wirevector representing offset to shift; bit i enables
        the stage that shifts by 2**i.
    :param wrap_around: ****currently not implemented*****; must be 0.
    :return: shifted wirevector
    :raises NotImplementedError: if wrap_around is non-zero
    :raises ValueError: if shift_in is narrower than 2 bits
    """
    if wrap_around != 0:
        raise NotImplementedError

    # Implement with logN stages pyrtl.muxing between shifted and un-shifted values
    width = len(shift_in)
    if width < 2:
        raise ValueError('barrel shifter input must be at least 2 bits wide')
    # Stages shifting by 1, 2, ..., 2**(k-1) together reach every distance up
    # to 2**k - 1, so k = (width - 1).bit_length() stages cover 0 .. width - 1.
    log_length = (width - 1).bit_length()

    if len(shift_dist) > log_length:
        print('Warning: for barrel shifter, the shift distance wirevector '
              'has bits that are not used in the barrel shifter')

    val = shift_in
    append_val = bit_in
    for i in range(min(len(shift_dist), log_length)):
        shift_amt = pow(2, i)  # stages shift 1,2,4,8,...
        newval = pyrtl.mux(direction,
                           truecase=val[:-shift_amt],
                           falsecase=val[shift_amt:])
        newval = pyrtl.mux(direction,
                           truecase=pyrtl.concat(newval, append_val),
                           falsecase=pyrtl.concat(
                               append_val, newval))  # Build shifted value
        # pyrtl.mux shifted vs. unshifted by using i-th bit of shift amount signal
        val = pyrtl.mux(shift_dist[i], truecase=newval, falsecase=val)
        # the next stage shifts twice as far, so it needs twice as many fill bits
        append_val = pyrtl.concat(append_val, append_val)

    return val
示例#3
0
def barrel_shifter(shift_in, bit_in, direction, shift_dist, wrap_around=0):
    """
    Create a barrel shifter that operates on data based on the wire width
    :param shift_in: the input wire; must be at least 2 bits wide.
    :param bit_in: the 1-bit wire giving the value to shift in.
    :param direction: direction is a one bit wirevector representing shift direction
        0 = shift down, 1 = shift up.
    :param shift_dist: wirevector representing offset to shift; bit i enables
        the stage that shifts by 2**i.
    :param wrap_around: ****currently not implemented*****; must be 0.
    :return: shifted wirevector
    :raises NotImplementedError: if wrap_around is non-zero
    :raises ValueError: if shift_in is narrower than 2 bits
    """
    if wrap_around != 0:
        raise NotImplementedError

    # Implement with logN stages pyrtl.muxing between shifted and un-shifted values
    width = len(shift_in)
    if width < 2:
        raise ValueError('barrel shifter input must be at least 2 bits wide')
    # k stages (shifts 1..2**(k-1)) reach distances up to 2**k - 1 >= width - 1
    log_length = (width - 1).bit_length()

    if len(shift_dist) > log_length:
        print('Warning: for barrel shifter, the shift distance wirevector '
              'has bits that are not used in the barrel shifter')

    val = shift_in
    append_val = bit_in
    for i in range(min(len(shift_dist), log_length)):
        shift_amt = pow(2, i)  # stages shift 1,2,4,8,...
        newval = pyrtl.mux(direction, truecase=val[:-shift_amt], falsecase=val[shift_amt:])
        newval = pyrtl.mux(direction, truecase=pyrtl.concat(newval, append_val),
                           falsecase=pyrtl.concat(append_val, newval))  # Build shifted value
        # pyrtl.mux shifted vs. unshifted by using i-th bit of shift amount signal
        val = pyrtl.mux(shift_dist[i], truecase=newval, falsecase=val)
        append_val = pyrtl.concat(append_val, append_val)  # fill width doubles per stage

    return val
示例#4
0
def ripple_half_add(a, cin=0):
    """Add a carry-in to ``a`` with a recursive chain of half adders.

    :param a: WireVector to increment by cin
    :param cin: carry into the least significant bit (wire or int)
    :return: WireVector one bit wider than ``a`` (carry-out on top)

    half_adder's result is indexed as [0] = carry passed upward and
    [1] = sum bit kept at this position, as the recursion below uses it.
    """
    carry_in = pyrtl.as_wires(cin)
    if len(a) != 1:
        low = half_adder(a[0], carry_in)
        high = ripple_half_add(a[1:], low[0])
        return pyrtl.concat(high, low[1])
    return pyrtl.concat(*half_adder(a, carry_in))
示例#5
0
文件: aes.py 项目: LinChai/PyRTL
def _g(word, key_expand_round):
    """AES key-schedule g-function: one-byte rotation, S-box, round constant.

    :param word: 32-bit WireVector holding one key-schedule word
    :param key_expand_round: zero-based round index; rcon is read at
        key_expand_round + 1
    :return: 32-bit WireVector, round_const ^ substituted rotated bytes

    NOTE(review): the selection order (1, 2, 3, 0) is a one-byte left
    rotation only if libutils.partition_wire returns the most significant
    byte first -- confirm against libutils and the rest of the key schedule.
    """
    parts = libutils.partition_wire(word, 8)
    rotation = (1, 2, 3, 0)
    substituted = pyrtl.concat(*[sbox[parts[k]] for k in rotation])
    # only the top byte receives the round constant; the low 24 bits are zero
    round_const = pyrtl.concat(rcon[key_expand_round + 1],
                               pyrtl.Const(0, bitwidth=24))
    return round_const ^ substituted
示例#6
0
def ripple_half_add(a, cin=0):
    """Increment ``a`` by a carry-in using chained half adders.

    Works from the least significant bit upward. The carry (index 0 of
    half_adder's result) ripples into the next position. The sum bit
    (index 1) stays at the current position.

    :return: WireVector whose most significant bit is the final carry-out
    """
    cin = pyrtl.as_wires(cin)
    if len(a) == 1:
        result = pyrtl.concat(*half_adder(a, cin))
    else:
        bit0 = half_adder(a[0], cin)
        upper = ripple_half_add(a[1:], bit0[0])
        result = pyrtl.concat(upper, bit0[1])
    return result
示例#7
0
    def test_concat(self):
        # concat is order dependent: common subexpression elimination must
        # keep concat(x, y) and concat(y, x) as two separate nets
        first_in, second_in = pyrtl.Input(5), pyrtl.Input(5)
        out_a, out_b = pyrtl.Output(10), pyrtl.Output(10)
        out_a <<= pyrtl.concat(second_in, first_in)
        out_b <<= pyrtl.concat(first_in, second_in)

        pyrtl.common_subexp_elimination()

        # two concats plus the two wire nets driving the outputs
        self.num_net_of_type('c', 2)
        self.num_net_of_type('w', 2)
        self.assert_num_net(4)
        self.assert_num_wires(6)
        pyrtl.working_block().sanity_check()
示例#8
0
    def test_concat(self):
        # Argument order matters for concat, so the optimizer must not
        # merge the two concats below even though they share inputs.
        in_lo = pyrtl.Input(5)
        in_hi = pyrtl.Input(5)
        swapped = pyrtl.Output(10)
        straight = pyrtl.Output(10)
        swapped <<= pyrtl.concat(in_hi, in_lo)
        straight <<= pyrtl.concat(in_lo, in_hi)

        pyrtl.common_subexp_elimination()
        block = pyrtl.working_block()

        self.num_net_of_type('c', 2)
        self.num_net_of_type('w', 2)
        self.assert_num_net(4)
        self.assert_num_wires(6)
        block.sanity_check()
示例#9
0
def prng_lfsr(bitwidth, load, req, seed=None):
    """ Build a single-cycle PRNG on top of a 127-bit Fibonacci LFSR.

    :param bitwidth: width of the random number produced
    :param load: one-bit signal that loads the seed into the LFSR
    :param req: one-bit signal that advances the LFSR by bitwidth steps in a
      single cycle (load is checked first in the conditional block)
    :param seed: 127-bit WireVector or integer seed; when None, a random
      non-zero integer seed is drawn once while the design is built
      (self-seeding). Avoid self-seeding if reseeding at run time is required.
    :return: the low bitwidth bits of the LFSR register

    Compact and fast: one new number per clock cycle, period 2**127 - 1.
    Its linearity makes it statistically weak, but it is good enough for
    non-cryptographic uses such as test-pattern generation.
    """
    # 127 is a Mersenne prime, which gives the maximal period 2**127 - 1
    # no matter which output width is requested.
    if seed is None:
        import random
        seed = random.SystemRandom().randrange(1, 2**127)  # self-seed

    lfsr = pyrtl.Register(max(bitwidth, 127))

    # Unroll bitwidth LFSR steps combinationally. Each step appends the
    # feedback bit (taps 125 and 126) below the current vector.
    stepped = lfsr
    for _ in range(bitwidth):
        feedback = stepped[125] ^ stepped[126]
        stepped = pyrtl.concat(stepped, feedback)

    with pyrtl.conditional_assignment:
        with load:
            lfsr.next |= seed
        with req:
            lfsr.next |= stepped
    return lfsr[:bitwidth]
示例#10
0
文件: aes.py 项目: LinChai/PyRTL
def inv_shift_rows(in_vector):
    """Reorder the 16 state bytes for the AES inverse ShiftRows step.

    :param in_vector: WireVector of state bytes (128 bits)
    :return: WireVector of the same bytes, permuted

    NOTE(review): whether this table is the correct inverse ShiftRows
    depends on the byte order returned by libutils.partition_wire and on
    the state layout used elsewhere -- confirm.
    """
    state_bytes = libutils.partition_wire(in_vector, 8)
    # output byte sources, listed most significant first
    order = (0, 7, 10, 13,
             1, 4, 11, 14,
             2, 5, 8, 15,
             3, 6, 9, 12)
    return pyrtl.concat(*[state_bytes[k] for k in order])
示例#11
0
def simple_mult(A, B, start):
    """ Generate a simple sequential shift-and-add multiplier.

    Very small (a single adder) but slow: the product can take up to len(A)
    cycles to appear.

    :param A: input shifted right one bit per cycle
    :param B: input shifted left one bit per cycle and accumulated
    :param start: one-bit input; when high, A and B are latched and the
        accumulator is cleared
    :return: tuple (accum, done). accum is the len(A) + len(B) bit product
        register. done is a one-bit wire that is high once the shifted copy
        of A has reached 0; the product is then valid.
    """
    width_a = len(A)
    width_sum = len(B) + width_a
    a_reg = pyrtl.Register(width_a)
    b_reg = pyrtl.Register(width_sum)
    product = pyrtl.Register(width_sum)
    done = a_reg == 0  # no set bits left in A means nothing more to add

    with pyrtl.conditional_assignment:
        with start:  # load operands and clear the running sum
            a_reg.next |= A
            b_reg.next |= B
            product.next |= 0
        with ~done:  # idle once the work is finished
            a_reg.next |= a_reg[1:]  # shift right by one
            b_reg.next |= pyrtl.concat(b_reg, "1'b0")  # shift left by one
            with a_reg[0]:
                # add the shifted B only when the current LSB of A is 1
                product.next |= product + b_reg

    return product, done
示例#12
0
def shift_reg(din, n):
    """
    Use a shift register to create delay-coded thresholds.

    Every cycle the register shifts up by one position and din enters at
    bit 0.

    :param din: value shifted in at the least significant end (presumably a
        1-bit wire -- TODO confirm with callers)
    :param n: width of the shift register; must be at least 1
    :return: the n-bit pyrtl.Register
    """
    sr = pyrtl.Register(bitwidth=n)
    if n == 1:
        # sr[:-1] would be an empty slice, which pyrtl rejects; a one-bit
        # shift register simply loads din each cycle.
        sr.next <<= din
    else:
        sr.next <<= pyrtl.concat(sr[:-1], din)
    return sr
示例#13
0
def simple_mult(A, B, start):
    """ Build a slow, tiny shift-and-add multiplier.

    Only one adder is used, at the cost of up to len(A) cycles of latency.

    :param A: operand consumed one bit per cycle from its LSB
    :param B: operand doubled every cycle and conditionally accumulated
    :param start: one-bit "inputs are ready" signal that (re)starts the multiply
    :return: (accum, done) -- the product register (len(A) + len(B) bits)
        and a one-bit flag raised when the multiplication has finished
    """
    na = len(A)
    nb = len(B)
    a_shift = pyrtl.Register(na)
    b_shift = pyrtl.Register(nb + na)
    acc = pyrtl.Register(nb + na)
    done = a_shift == 0  # finished once every bit of A has been consumed

    with pyrtl.conditional_assignment:
        with start:
            a_shift.next |= A
            b_shift.next |= B
            acc.next |= 0
        with ~done:
            a_shift.next |= a_shift[1:]
            b_shift.next |= pyrtl.concat(b_shift, "1'b0")
            # "multiply" the doubled B by the current LSB of A
            with a_shift[0]:
                acc.next |= acc + b_shift

    return acc, done
示例#14
0
    def test_two_way_concat(self):
        upper = pyrtl.Const(0b1100)
        middle = pyrtl.Const(0b011, bitwidth=3)
        lower = pyrtl.Const(0b100110)
        out = pyrtl.Output(13, 'o')
        out <<= pyrtl.concat(upper, middle, lower)

        block = pyrtl.working_block()
        # before the pass: one three-input concat
        nets = list(block.logic_subset(op='c'))
        self.assertEqual(len(nets), 1)
        self.assertEqual(nets[0].args, (upper, middle, lower))

        pyrtl.two_way_concat()

        # after the pass: two chained two-input concats
        nets = list(block.logic_subset(op='c'))
        self.assertEqual(len(nets), 2)
        first = next(net for net in nets if upper is net.args[0])
        second = next(net for net in nets if lower is net.args[1])
        self.assertNotEqual(first, second)
        self.assertEqual(first.args, (upper, middle))
        self.assertEqual(second.args, (first.dests[0], lower))

        # value is 0b1100 | 0b011 | 0b100110 glued together, MSB first
        sim = pyrtl.Simulation()
        sim.step({})
        self.assertEqual(sim.inspect('o'), 0b1100011100110)
示例#15
0
 def test_area_est_unchanged(self):
     """Regression check: area estimate of a fixed design using many net types."""
     a = pyrtl.Const(2, 8)
     b = pyrtl.Const(85, 8)
     zero = pyrtl.Const(0, 1)
     reg = pyrtl.Register(8)
     mem = pyrtl.MemBlock(8, 8)
     out = pyrtl.Output(8)
     nota, aLSB, athenb, aORb, aANDb, aNANDb, \
     aXORb, aequalsb, altb, agtb, aselectb, \
     aplusb, bminusa, atimesb, memread = [pyrtl.Output() for i in range(15)]
     out <<= zero
     nota <<= ~a
     aLSB <<= a[0]
     athenb <<= pyrtl.concat(a, b)
     aORb <<= a | b
     aANDb <<= a & b
     aNANDb <<= a.nand(b)
     aXORb <<= a ^ b
     aequalsb <<= a == b
     altb <<= a < b
     agtb <<= a > b
     aselectb <<= pyrtl.select(zero, a, b)
     reg.next <<= a
     aplusb <<= a + b
     bminusa <<= a - b
     atimesb <<= a * b
     memread <<= mem[0]
     mem[1] <<= a
     # assertEquals is a deprecated alias that was removed in Python 3.12
     self.assertEqual(estimate.area_estimation(), (0.00734386752, 0.01879779717361501))
示例#16
0
 def test_time_est_unchanged(self):
     """Regression check: timing analysis of a fixed design using many net types."""
     a = pyrtl.Const(2, 8)
     b = pyrtl.Const(85, 8)
     zero = pyrtl.Const(0, 1)
     reg = pyrtl.Register(8)
     mem = pyrtl.MemBlock(8, 8)
     out = pyrtl.Output(8)
     nota, aLSB, athenb, aORb, aANDb, aNANDb, \
     aXORb, aequalsb, altb, agtb, aselectb, \
     aplusb, bminusa, atimesb, memread = [pyrtl.Output() for i in range(15)]
     out <<= zero
     nota <<= ~a
     aLSB <<= a[0]
     athenb <<= pyrtl.concat(a, b)
     aORb <<= a | b
     aANDb <<= a & b
     aNANDb <<= a.nand(b)
     aXORb <<= a ^ b
     aequalsb <<= a == b
     altb <<= a < b
     agtb <<= a > b
     aselectb <<= pyrtl.select(zero, a, b)
     reg.next <<= a
     aplusb <<= a + b
     bminusa <<= a - b
     atimesb <<= a * b
     memread <<= mem[0]
     mem[1] <<= a
     timing = estimate.TimingAnalysis()
     self.assertEqual(timing.max_freq(), 610.2770657878676)
     # assertEquals is a deprecated alias that was removed in Python 3.12
     self.assertEqual(timing.max_length(), 1255.6000000000001)
示例#17
0
def kogge_stone(a, b, cin=0):
    """
    Build a Kogge-Stone adder for two inputs.

    :param a, b: the two WireVectors to add (bitwidths don't need to match)
    :param cin: an optional carry-in, as a WireVector or value
    :return: a WireVector holding a + b + cin, with the carry-out as its
        most significant bit

    The Kogge-Stone adder is a fast parallel-prefix adder with O(log(n))
    propagation delay, useful for performance-critical designs. It costs
    O(n log(n)) area and has large fan-out.
    """
    a, b = libutils.match_bitwidth(a, b)
    width = len(a)

    propagate = a ^ b
    prop_bits = list(propagate)
    gen_bits = list(a & b)

    # prefix network: every level doubles the span each bit summarizes
    dist = 1
    while dist < width:
        for i in range(width - 1, dist - 1, -1):
            p = prop_bits[i]
            gen_bits[i] = gen_bits[i] | (p & gen_bits[i - dist])
            # group-propagate is only needed later when i >= 2 * dist;
            # skipping it avoids creating unused nets and wires
            if i >= dist * 2:
                prop_bits[i] = p & prop_bits[i - dist]
        dist *= 2

    # carries sit one position up, with cin entering at the bottom
    carries = [pyrtl.as_wires(cin)] + gen_bits
    return pyrtl.concat(*reversed(carries)) ^ propagate
示例#18
0
文件: aes.py 项目: jolting/PyRTL
def inv_shift_rows(in_vector):
    """Apply the AES inverse ShiftRows byte permutation to a 128-bit state.

    Byte 0 is taken from the most significant end of ``in_vector``. The
    index table matches inverse ShiftRows for a column-major state (byte
    r + 4*c holds row r, column c), where row r is rotated right by r.
    """
    state = [in_vector[hi - 8:hi] for hi in range(128, 0, -8)]
    permuted = [state[k] for k in (0, 13, 10, 7, 4, 1, 14, 11,
                                   8, 5, 2, 15, 12, 9, 6, 3)]
    return pyrtl.concat(*permuted)
示例#19
0
文件: adders.py 项目: jolting/PyRTL
def cla_adder(a, b, cin=0, la_unit_len=4):
    """
    Carry Lookahead Adder

    :param a, b: WireVectors to add (matched to the same bitwidth first)
    :param cin: carry into the least significant unit
    :param int la_unit_len: the number of input bits each lookahead unit handles
    :return: WireVector holding the sum, with the final carry-out on top

    A Carry LookAhead Adder computes carries faster than a ripple-carry
    adder. It is slower than a Kogge-Stone adder but uses less area.
    """
    a, b = pyrtl.match_bitwidth(a, b)
    if len(a) > la_unit_len:
        # add the lowest chunk, then recurse on the rest with its carry
        low_sum, carry = _cla_adder_unit(a[0:la_unit_len], b[0:la_unit_len], cin)
        high = cla_adder(a[la_unit_len:], b[la_unit_len:], carry, la_unit_len)
        return pyrtl.concat(high, low_sum)
    unit_sum, carry_out = _cla_adder_unit(a, b, cin)
    return pyrtl.concat(carry_out, unit_sum)
示例#20
0
def barrel_shifter(bits_to_shift,
                   bit_in,
                   direction,
                   shift_dist,
                   wrap_around=0):
    """ Build a logarithmic barrel shifter sized to the input width.

    :param bits_to_shift: the input wire
    :param bit_in: the 1-bit wire whose value fills vacated positions
    :param direction: one-bit WireVector selecting the direction
        (0 = shift down, 1 = shift up)
    :param shift_dist: WireVector giving the shift amount; bit i enables a
        shift by 2**i
    :param wrap_around: ****currently not implemented****
    :return: shifted WireVector
    :raises NotImplementedError: if wrap_around is non-zero
    """
    from pyrtl import concat, select  # just for readability

    if wrap_around != 0:
        raise NotImplementedError

    width = len(bits_to_shift)
    result = bits_to_shift
    fill = bit_in

    for stage in range(len(shift_dist)):
        amount = 1 << stage  # stages shift by 1, 2, 4, 8, ...
        if amount >= width:
            # a shift this large pushes every original bit out
            result = select(shift_dist[stage], truecase=fill, falsecase=result)
            continue
        shifted = select(
            direction,
            concat(result[:-amount], fill),  # shift up
            concat(fill, result[amount:]))  # shift down
        result = select(shift_dist[stage], truecase=shifted, falsecase=result)
        # the fill doubles for the next stage but never exceeds the width
        fill = concat(fill, fill)[:width]

    return result
示例#21
0
def cla_adder(a, b, cin=0, la_unit_len=4):
    """
    Carry Lookahead Adder

    :param int la_unit_len: the length of input that every unit processes
    :return: WireVector holding a + b + cin, carry-out in the top bit

    Faster than a ripple-carry adder because carries are computed by
    lookahead units. Not as fast as Kogge-Stone, but uses less area.
    """
    a, b = pyrtl.match_bitwidth(a, b)
    if len(a) <= la_unit_len:
        partial, carry_out = _cla_adder_unit(a, b, cin)
        result = pyrtl.concat(carry_out, partial)
    else:
        head = slice(0, la_unit_len)
        partial, carry_out = _cla_adder_unit(a[head], b[head], cin)
        upper = cla_adder(a[la_unit_len:], b[la_unit_len:], carry_out, la_unit_len)
        result = pyrtl.concat(upper, partial)
    return result
示例#22
0
def inv_shift_rows(in_vector):
    """Inverse AES ShiftRows on a 128-bit state.

    Bytes are numbered 0..15 from the most significant end of ``in_vector``.
    The output takes its bytes from the positions listed in ``_SOURCES``,
    most significant first.
    """
    _SOURCES = (0, 13, 10, 7,
                4, 1, 14, 11,
                8, 5, 2, 15,
                12, 9, 6, 3)
    offsets = range(128, 0, -8)
    cells = [in_vector[top - 8:top] for top in offsets]
    return pyrtl.concat(*(cells[src] for src in _SOURCES))
示例#23
0
def ripple_add(a, b, cin=0):
    """Ripple-carry adder built recursively from one_bit_add.

    one_bit_add's result is indexed here as [0] = sum bit and
    [1] = carry-out, which is passed to the next, more significant bit.

    :return: WireVector holding a + b + cin
    """
    a, b = libutils.match_bitwidth(a, b)
    carry = pyrtl.as_wires(cin)
    if len(a) != 1:
        bit0 = one_bit_add(a[0], b[0], carry)
        upper = ripple_add(a[1:], b[1:], bit0[1])
        return pyrtl.concat(upper, bit0[0])
    return one_bit_add(a, b, carry)
示例#24
0
def key_expansion(key):
    """Expand an AES key into the round-key schedule.

    The key is split into 32-bit words. Then ten rounds each append four
    words: the first mixes in _g of the previous round's last word, and the
    other three chain an xor with the previous word.

    :return: the concatenation of all words (presumably 44 words for a
        128-bit key), words[0] most significant
    """
    words = list(libutils.partition_wire(key, 32))
    for rnd in range(10):
        base = 4 * rnd
        words.append(words[base] ^ _g(words[base + 3], rnd))
        for offset in (1, 2, 3):
            words.append(words[-1] ^ words[base + offset])
    return pyrtl.concat(*words)
示例#25
0
文件: aes.py 项目: LinChai/PyRTL
def inv_mix_columns(in_vector):
    """AES InvMixColumns over all 16 state bytes.

    Each output byte xors four Galois-field products. Each product comes
    from inv_galois_mult applied to a byte chosen by _mod_add(index, loc, 4),
    with the matching table from _igm_divisor.
    """
    state = libutils.partition_wire(in_vector, 8)

    def _inv_mix_single(index):
        products = [inv_galois_mult(state[_mod_add(index, loc, 4)], table)
                    for loc, table in enumerate(_igm_divisor)]
        return products[0] ^ products[1] ^ products[2] ^ products[3]

    mixed = [_inv_mix_single(i) for i in range(len(state))]
    return pyrtl.concat(*mixed)
示例#26
0
 def test_async_check_should_pass_with_cat(self):
     # A memory read address built by concatenating slices of two input
     # addresses must still pass the block sanity check.
     mem = pyrtl.MemBlock(bitwidth=self.bitwidth,
                          addrwidth=self.addrwidth,
                          name='memory')
     high = self.mem_read_address1[0]
     low = self.mem_read_address2[0:-1]
     self.output1 <<= mem[pyrtl.concat(high, low)]
     mem[self.mem_write_address] <<= self.mem_write_data
     pyrtl.working_block().sanity_check()
示例#27
0
def ripple_add(a, b, cin=0):
    """Recursive ripple-carry adder.

    The low bit is added with one_bit_add. Its carry-out ([1]) feeds the
    recursive call for the remaining bits, and its sum bit ([0]) ends up at
    the bottom of the result.
    """
    a, b = libutils.match_bitwidth(a, b)
    cin = pyrtl.as_wires(cin)
    if len(a) == 1:
        result = one_bit_add(a, b, cin)
    else:
        low = one_bit_add(a[0], b[0], cin)
        rest = ripple_add(a[1:], b[1:], low[1])
        result = pyrtl.concat(rest, low[0])
    return result
示例#28
0
 def test_netgraph_same_wire_multiple_edges_to_same_net(self):
     # Using one wire three times as an argument of a single net should give
     # one graph entry for that net holding three edges, each being the wire.
     one = pyrtl.Const(1, 1)
     tripled = pyrtl.concat(one, one, one)
     graph = pyrtl.net_graph()
     self.assertEqual(len(graph[one]), 1)
     parallel_edges = list(graph[one].values())[0]
     self.assertEqual(len(parallel_edges), 3)
     for edge in parallel_edges:
         self.assertIs(edge, one)
示例#29
0
文件: aes.py 项目: LinChai/PyRTL
def key_expansion(key):
    """Expand an AES key into its round-key schedule.

    The key's 32-bit words are reversed before expansion. NOTE(review): this
    makes words[0] the most significant word only if
    libutils.partition_wire returns the least significant part first --
    confirm. Ten rounds each append four new words.
    """
    words = list(reversed(libutils.partition_wire(key, 32)))
    for rnd in range(10):
        base = rnd * 4
        words.append(words[base] ^ _g(words[base + 3], rnd))
        for k in range(1, 4):
            words.append(words[-1] ^ words[base + k])
    return pyrtl.concat(*words)
示例#30
0
def barrel_shifter_v2(bits_to_shift,
                      bit_in,
                      direction,
                      shift_dist,
                      wrap_around=0):
    """
    Create a barrel shifter that operates on data based on the wire width

    :param bits_to_shift: the input wire (at least 2 bits wide)
    :param bit_in: the 1-bit wire giving the value to shift in
    :param direction: a one bit WireVector representing shift direction
        (0 = shift down, 1 = shift up)
    :param shift_dist: WireVector representing offset to shift; bit i enables
        the stage that shifts by 2**i. Bits beyond those needed for the input
        width are ignored.
    :param wrap_around: ****currently not implemented****
    :return: shifted WireVector
    :raises NotImplementedError: if wrap_around is non-zero
    :raises ValueError: if bits_to_shift is narrower than 2 bits
    """
    if wrap_around != 0:
        raise NotImplementedError

    width = len(bits_to_shift)
    if width < 2:
        raise ValueError('barrel_shifter_v2 input must be at least 2 bits wide')
    # Implement with logN stages pyrtl.muxing between shifted and un-shifted values.
    # k stages (shifts 1 .. 2**(k-1)) reach distances up to 2**k - 1, so
    # k = (width - 1).bit_length() covers every shift from 0 to width - 1.
    log_length = (width - 1).bit_length()

    val = bits_to_shift
    append_val = bit_in
    for i in range(min(len(shift_dist), log_length)):
        shift_amt = pow(2, i)  # stages shift 1,2,4,8,...
        newval = pyrtl.select(direction,
                              truecase=val[:-shift_amt],
                              falsecase=val[shift_amt:])
        newval = pyrtl.select(direction,
                              truecase=pyrtl.concat(newval, append_val),
                              falsecase=pyrtl.concat(
                                  append_val, newval))  # Build shifted value
        # select shifted vs. unshifted by using i-th bit of shift amount signal
        val = pyrtl.select(shift_dist[i], truecase=newval, falsecase=val)
        # The next stage shifts twice as far, so the fill must double too;
        # appending a single bit_in here would leave it too narrow.
        append_val = pyrtl.concat(append_val, append_val)

    return val
示例#31
0
def ripple_add(a, b, cin=0):
    """Recursive ripple-carry adder.

    :param a, b: WireVectors to add; the shorter one is zero-extended first
    :param cin: carry into the least significant bit
    :return: tuple (sumbits, cout)
    """
    # match_bitwidth zero-extends the shorter wire so both have equal width
    a, b = pyrtl.match_bitwidth(a, b)
    if len(a) != 1:
        lsbit, ripplecarry = one_bit_add(a[0], b[0], cin)
        msbits, cout = ripple_add(a[1:], b[1:], ripplecarry)
        return pyrtl.concat(msbits, lsbit), cout
    sumbits, cout = one_bit_add(a, b, cin)
    return sumbits, cout
示例#32
0
def carrysave_adder(a, b, c, final_adder=ripple_add):
    """
    Adds three wirevectors up in an efficient manner
    :param WireVector a, b, c : the three wires to add up
    :param function final_adder : The adder to use to do the final addition
    :return: a wirevector with length 2 longer than the largest input
    """
    a, b, c = libutils.match_bitwidth(a, b, c)
    partial_sum = a ^ b ^ c  # per-bit sum, ignoring carries
    shift_carry = (a | b) & (a | c) & (b | c)  # per-bit carry (majority)
    if len(partial_sum) == 1:
        # partial_sum[1:] would be an empty slice, which pyrtl rejects.
        # For one-bit inputs a + b + c == 2 * carry + sum exactly, so no
        # adder is needed; a zero MSB keeps the documented width (+2).
        return pyrtl.concat(pyrtl.Const(0, 1), shift_carry, partial_sum)
    # Bit 0 of the result needs no addition; the carries (worth one position
    # more) line up with partial_sum[1:].
    return pyrtl.concat(final_adder(partial_sum[1:], shift_carry), partial_sum[0])
def ripple_add(a, b, carry_in=0):
    """Ripple-carry adder, built one bit at a time by recursion.

    :param a, b: WireVectors to add; the narrower one is zero-extended
    :param carry_in: carry into the least significant bit
    :return: tuple (sumbits, carry_out)
    """
    a, b = pyrtl.match_bitwidth(a, b)
    if len(a) == 1:
        sum_bit, carry_out = one_bit_add(a, b, carry_in)
        return sum_bit, carry_out
    low_bit, carry = one_bit_add(a[0], b[0], carry_in)
    high_bits, carry_out = ripple_add(a[1:], b[1:], carry)
    return pyrtl.concat(high_bits, low_bit), carry_out
示例#34
0
def carrysave_adder(a, b, c):
    """
    Adds three wirevectors up in an efficient manner
    :param a, b, c: the three wirevectors to add up
    :return: a wirevector with length 2 longer than the largest input
    """
    # match_bitwidth returns the extended wires; capture them so the rest of
    # the block works on equal-width operands (the result was discarded).
    a, b, c = libutils.match_bitwidth(a, b, c)
    partial_sum = a ^ b ^ c  # per-bit sum, ignoring carries
    shift_carry = (a | b) & (a | c) & (b | c)  # per-bit carry (majority)
    # carries are worth one position more, so append a 0 below them
    shift_carry_1 = pyrtl.concat(shift_carry, 0)
    return ripple_add(partial_sum, shift_carry_1)
示例#35
0
def generate_full_mux(a, b, sel):
    """Build a multiplexer out of one-bit muxes, one per bit position.

    b is selected when sel is high, a otherwise. The widths of a and b are
    asserted to be equal.
    """
    assert len(a) == len(b)
    if len(a) != 1:
        low_bit = generate_one_bit_mux(a[0], b[0], sel)
        high_bits = generate_full_mux(a[1:], b[1:], sel)
        return pyrtl.concat(high_bits, low_bit)
    return generate_one_bit_mux(a, b, sel)
示例#36
0
def generate_full_mux(a, b, sel):
    """Recursively generate a multiplexer; b is chosen when sel is high.

    The least significant bit is muxed here and the remaining bits are
    handled by a recursive call, then the two are concatenated.
    """
    assert len(a) == len(b)
    if len(a) == 1:
        return generate_one_bit_mux(a, b, sel)
    bit0 = generate_one_bit_mux(a[0], b[0], sel)
    rest = generate_full_mux(a[1:], b[1:], sel)
    return pyrtl.concat(rest, bit0)
示例#37
0
文件: aes.py 项目: LinChai/PyRTL
def inv_mix_columns(in_vector):
    """AES InvMixColumns over every byte of the state.

    Output byte ``index`` is the xor of four products. Product ``loc`` is
    inv_galois_mult of the byte at _mod_add(index, loc, 4), using
    _igm_divisor[loc] as its table.
    """
    bytes_in = libutils.partition_wire(in_vector, 8)

    def _inv_mix_single(index):
        terms = []
        for loc, mult_table in enumerate(_igm_divisor):
            source = bytes_in[_mod_add(index, loc, 4)]
            terms.append(inv_galois_mult(source, mult_table))
        return terms[0] ^ terms[1] ^ terms[2] ^ terms[3]

    inverted = [_inv_mix_single(index) for index in range(len(bytes_in))]
    return pyrtl.concat(*inverted)
示例#38
0
def generate_full_adder(a, b, cin=None):
    """ Generate an arbitrary-bitwidth ripple-carry adder from one-bit adders.

    :param a, b: equal-width WireVectors (asserted)
    :param cin: carry-in; a 1-bit constant 0 is used when omitted
    :return: tuple (sumbits, cout)
    """
    assert len(a) == len(b)
    carry = pyrtl.Const(0, bitwidth=1) if cin is None else cin
    if len(a) != 1:
        lsbit, ripplecarry = generate_one_bit_adder(a[0], b[0], carry)
        msbits, cout = generate_full_adder(a[1:], b[1:], ripplecarry)
        return pyrtl.concat(msbits, lsbit), cout
    sumbits, cout = generate_one_bit_adder(a, b, carry)
    return sumbits, cout
示例#39
0
def _sparse_adder(wire_array_2, adder):
    """Sum a column array of wires whose columns hold at most two entries.

    :param wire_array_2: list of columns, least significant first; each
        column is a list of wires. Leading single-wire columns are passed
        straight through. From the first two-wire column onward, the first
        and second wires of every column form the two adder operands
        (missing entries are filled with Const(0)).
    :param adder: function that adds two WireVectors
    :return: WireVector holding the total
    :raises pyrtl.PyrtlError: if wire_array_2 is empty
    """
    if not wire_array_2:
        raise pyrtl.PyrtlError('_sparse_adder needs at least one column of wires')

    result = []
    for single_w_index, column in enumerate(wire_array_2):
        if len(column) == 2:  # Check if the two wire vectors overlap yet
            break
        result.append(column[0])
    else:
        # No column holds two wires, so nothing needs adding. Falling through
        # would build only one operand tuple and fail on add_wires[1].
        return pyrtl.concat_list(result)

    import six
    wires_to_zip = wire_array_2[single_w_index:]
    add_wires = tuple(six.moves.zip_longest(*wires_to_zip, fillvalue=pyrtl.Const(0)))
    adder_result = adder(pyrtl.concat_list(add_wires[0]), pyrtl.concat_list(add_wires[1]))
    return pyrtl.concat(adder_result, *reversed(result))
示例#40
0
def generate_full_adder(a, b, cin=None):
    """ Generate a ripple-carry adder of any width.

    The least significant bit is added with generate_one_bit_adder and its
    carry ripples into a recursive call for the remaining bits.

    :return: tuple (sumbits, cout)
    """
    assert len(a) == len(b)
    if cin is None:
        cin = pyrtl.Const(0, bitwidth=1)  # default carry-in of 0
    if len(a) == 1:
        return tuple(generate_one_bit_adder_result for generate_one_bit_adder_result in _unpack2(generate_one_bit_adder(a, b, cin)))
    low_sum, carry = generate_one_bit_adder(a[0], b[0], cin)
    high_sum, cout = generate_full_adder(a[1:], b[1:], carry)
    return pyrtl.concat(high_sum, low_sum), cout


def _unpack2(pair):
    """Unpack exactly two items, raising ValueError otherwise (like `x, y = pair`)."""
    first, second = pair
    return first, second
示例#41
0
文件: adders.py 项目: jolting/PyRTL
def carrysave_adder(a, b, c, final_adder=ripple_add):
    """
    Add three wirevectors with one carry-save stage and a single final adder.

    :param WireVector a, b, c : the three wires to add up
    :param function final_adder : adder used for the last two-operand sum
    :return: a wirevector with length 2 longer than the largest input
    """
    a, b, c = libutils.match_bitwidth(a, b, c)
    sum_bits = a ^ b ^ c
    carry_bits = (a | b) & (a | c) & (b | c)
    # carries weigh twice as much as the sum bits: move them up one place
    return final_adder(sum_bits, pyrtl.concat(carry_bits, 0))
示例#42
0
def ripple_add(a, b, cin=0):
    """Ripple-carry adder that accepts operands of different widths.

    Once the shorter operand runs out, the rest of the longer one is finished
    with ripple_half_add. one_bit_add's result is indexed as [0] = sum bit
    and [1] = carry-out.
    """
    longer, shorter = (b, a) if len(a) < len(b) else (a, b)
    carry = pyrtl.as_wires(cin)
    if len(longer) == 1:
        return one_bit_add(longer, shorter, carry)
    low = one_bit_add(longer[0], shorter[0], carry)
    if len(shorter) == 1:
        upper = ripple_half_add(longer[1:], low[1])
    else:
        upper = ripple_add(longer[1:], shorter[1:], low[1])
    return pyrtl.concat(upper, low[0])
示例#43
0
def ripple_add(a, b, cin=0):
    """Add two wires of possibly unequal width with a ripple-carry chain.

    The wider operand is moved to ``a``. Full one-bit adders are used while
    both operands have bits left; then ripple_half_add absorbs the carry into
    the remaining bits of ``a``.
    """
    if len(b) > len(a):  # keep b as the narrower wire
        a, b = b, a
    cin = pyrtl.as_wires(cin)
    if len(a) == 1:
        return one_bit_add(a, b, cin)
    lsb = one_bit_add(a[0], b[0], cin)
    carry_up = lsb[1]
    msbits = (ripple_half_add(a[1:], carry_up) if len(b) == 1
              else ripple_add(a[1:], b[1:], carry_up))
    return pyrtl.concat(msbits, lsb[0])
示例#44
0
    def test_loop_2(self):
        data_in = pyrtl.Input(10)
        other_in = pyrtl.Input(9)
        loop_bit = pyrtl.WireVector(1)
        # concat is MSB-first, so loop_bit lands at bit 3 of the 10-bit
        # result: above other_in[6:9] (bits 0-2), below other_in[0:6].
        combined = pyrtl.concat(other_in[0:6], loop_bit, other_in[6:9])
        anded = data_in & combined
        # bit 3 of anded depends on loop_bit, so this closes a real loop
        loop_bit <<= anded[3]
        result = pyrtl.Output(10)
        result <<= loop_bit

        self.assert_has_loop()
示例#45
0
    def test_loop_2(self):
        in_1 = pyrtl.Input(10)
        in_2 = pyrtl.Input(9)
        feedback = pyrtl.WireVector(1)
        # Placing the 1-bit wire between in_2[0:6] and in_2[6:9] puts it at
        # bit 3 of the concatenation (concat lists the MSB part first)...
        comp_wire = pyrtl.concat(in_2[0:6], feedback, in_2[6:9])
        r_wire = in_1 & comp_wire
        # ...so driving it from r_wire[3] makes it depend on itself.
        feedback <<= r_wire[3]
        out = pyrtl.Output(10)
        out <<= feedback

        self.assert_has_loop()
示例#46
0
def _sparse_adder(wire_array_2, adder):
    """Sum a column array of wires whose columns hold at most two entries.

    :param wire_array_2: list of columns, least significant first; each
        column is a list of wires. Leading single-wire columns pass straight
        through; from the first two-wire column on, the first and second
        wires of each column become the two adder operands (gaps are filled
        with Const(0)).
    :param adder: function that adds two WireVectors
    :return: WireVector holding the total
    :raises pyrtl.PyrtlError: if wire_array_2 is empty
    """
    if not wire_array_2:
        raise pyrtl.PyrtlError(
            '_sparse_adder needs at least one column of wires')

    result = []
    for single_w_index, column in enumerate(wire_array_2):
        if len(column) == 2:  # Check if the two wire vectors overlap yet
            break
        result.append(column[0])
    else:
        # No column holds two wires: nothing to add. Falling through would
        # produce a single operand tuple and fail on add_wires[1].
        return pyrtl.concat_list(result)

    import six
    wires_to_zip = wire_array_2[single_w_index:]
    add_wires = tuple(
        six.moves.zip_longest(*wires_to_zip, fillvalue=pyrtl.Const(0)))
    adder_result = adder(pyrtl.concat_list(add_wires[0]),
                         pyrtl.concat_list(add_wires[1]))
    return pyrtl.concat(adder_result, *reversed(result))
示例#47
0
def _one_cycle_mult(areg, breg, rem_bits, sum_sf=0, curr_bit=0):
    """ returns a WireVector sum of rem_bits multiplies (in one clock cycle)

    :param areg: multiplier WireVector; bits curr_bit .. curr_bit+rem_bits-1 are used
    :param breg: multiplicand WireVector
    :param rem_bits: number of bits of areg to process
    :param sum_sf: partial sum accumulated so far (0 by default)
    :param curr_bit: index of the first areg bit to process
    :return: sum_sf plus (breg << i) for every processed bit i of areg that is 1

    note: this method requires a lot of area because every processed bit builds
    its own shifted copy of breg, a full-width mask and an adder
    """
    for bit in range(curr_bit, curr_bit + rem_bits):
        # Bit 0 needs no shift; otherwise append `bit` zeros below breg
        shifted_b = breg if bit == 0 else pyrtl.concat(breg, pyrtl.Const(0, bit))
        # Replicate areg[bit] across the FULL width of the shifted operand.
        # pyrtl's & zero-extends the narrower operand, so a mask only len(breg)
        # wide would clear the top `bit` bits of the partial product.
        bit_mask = areg[bit].sign_extended(len(shifted_b))
        sum_sf = sum_sf + (bit_mask & shifted_b)
    return sum_sf
示例#48
0
def _sparse_adder(wire_array_2, adder):
    """Add up a column-reduced array of bits using a single final adder.

    :param wire_array_2: sequence indexed by bit position (LSB first); each
        entry is a sequence of at most two 1-bit wires to be summed at that
        position (an empty entry contributes 0)
    :param adder: callable taking two WireVectors and returning their sum
    :return: WireVector holding the weighted sum of every wire in the array
    :raises PyrtlError: if the array is empty or any position holds more
        than two wires

    Low-order positions holding at most one wire need no addition and are
    forwarded unchanged; the adder only spans from the first position that
    holds two wires up to the most significant position.
    """
    if not wire_array_2:
        raise pyrtl.PyrtlError('_sparse_adder requires at least one bit position')
    if any(len(wires) > 2 for wires in wire_array_2):
        # Silently dropping a third addend would give a wrong sum
        raise pyrtl.PyrtlError('_sparse_adder expects at most two wires per bit position')

    def _nth_or_zero(wires, n):
        # A missing operand at a bit position contributes a constant 0
        return wires[n] if len(wires) > n else pyrtl.Const(0)

    # Locate the first bit position where two wires overlap
    first_overlap = len(wire_array_2)
    for index, wires in enumerate(wire_array_2):
        if len(wires) == 2:
            first_overlap = index
            break

    passthrough = [_nth_or_zero(wires, 0) for wires in wire_array_2[:first_overlap]]
    if first_overlap == len(wire_array_2):
        # Nothing overlaps: no adder is needed, and no bit may be counted twice
        return pyrtl.concat(*reversed(passthrough))

    upper = wire_array_2[first_overlap:]
    operand_a = [_nth_or_zero(wires, 0) for wires in upper]
    operand_b = [_nth_or_zero(wires, 1) for wires in upper]
    adder_result = adder(pyrtl.concat_list(operand_a), pyrtl.concat_list(operand_b))
    # concat puts its first argument in the MSBs, so the adder output sits
    # above the forwarded low-order bits
    return pyrtl.concat(adder_result, *reversed(passthrough))
示例#49
0
    def test_as_graph_duplicate_args(self):
        """net_connections must give a consistent graph when nets reuse one wire as several args."""
        addr = pyrtl.Input(3)
        flag = pyrtl.Input(1)
        mem_out = pyrtl.Output()
        # These results are never read; they exist only so the block holds
        # nets whose operands are the same wire twice.
        self_and = addr & addr
        self_concat = pyrtl.concat(addr, addr)
        wide_mem = pyrtl.MemBlock(addrwidth=3, bitwidth=3, name='m')
        narrow_mem = pyrtl.MemBlock(addrwidth=1, bitwidth=1, name='m')
        mem_out <<= wide_mem[addr]
        wide_mem[addr] <<= addr
        narrow_mem[flag] <<= pyrtl.MemBlock.EnabledWrite(flag, flag)

        block = pyrtl.working_block()
        # Check the connection graphs with the net_connections flag off, then on
        sources, dests = block.net_connections(False)
        self.check_graph_correctness(sources, dests)

        sources, dests = block.net_connections(True)
        self.check_graph_correctness(sources, dests, True)
示例#50
0
    def test_edge_case_1(self):
        """A loop visible only at the net level must vanish once synthesis splits wires."""
        in_1 = pyrtl.Input(10)
        in_2 = pyrtl.Input(9)
        fake_loop_wire = pyrtl.WireVector(1)
        # concat puts its first argument in the MSBs: in_2[4:9] fills bits 0-4 and
        # fake_loop_wire sits at bit 5, so r_wire[3] really depends only on
        # in_1[3] and in_2[7] -- there is no bit-level loop
        comp_wire = pyrtl.concat(in_2[0:4], fake_loop_wire, in_2[4:9])
        r_wire = in_1 & comp_wire
        fake_loop_wire <<= r_wire[3]
        out = pyrtl.Output(10)
        out <<= fake_loop_wire

        # Yes, because we only check loops on a net level, this will still be
        # a loop pre synth
        self.assertIsNotNone(pyrtl.find_loop())
        pyrtl.synthesize()

        # Because synth separates the individual wires, it also resolves the loop
        self.assertIsNone(pyrtl.find_loop())
        pyrtl.optimize()
        self.assertIsNone(pyrtl.find_loop())
示例#51
0
    def test_edge_case_1(self):
        """A loop visible only at the net level must vanish once synthesis splits wires."""
        in_1 = pyrtl.Input(10)
        in_2 = pyrtl.Input(9)
        fake_loop_wire = pyrtl.WireVector(1)
        # concat puts its first argument in the MSBs: in_2[4:9] fills bits 0-4 and
        # fake_loop_wire sits at bit 5, so r_wire[3] really depends only on
        # in_1[3] and in_2[7] -- there is no bit-level loop
        comp_wire = pyrtl.concat(in_2[0:4], fake_loop_wire, in_2[4:9])
        r_wire = in_1 & comp_wire
        fake_loop_wire <<= r_wire[3]
        out = pyrtl.Output(10)
        out <<= fake_loop_wire

        # Yes, because we only check loops on a net level, this will still be
        # a loop pre synth
        self.assertIsNotNone(pyrtl.find_loop())
        pyrtl.synthesize()

        # Because synth separates the individual wires, it also resolves the loop
        self.assertIsNone(pyrtl.find_loop())
        pyrtl.optimize()
        self.assertIsNone(pyrtl.find_loop())
示例#52
0
def _shifted_reg_next(reg, direct, num=1):
    """
    Compute a shifted value suitable for assigning to a register's 'next'.

    use: myReg.next = shifted_reg_next(myReg, 'l', 4)
    :param reg: WireVector/Register whose value is shifted
    :param string direct: direction of shift, either 'l' or 'r'
    :param int num: number of bit positions to shift (must be >= 0)
    :return: the shifted WireVector (reg itself when num == 0), or the integer
        0 when num >= len(reg) and every bit would be shifted out
    :raises PyrtlError: if direct is not 'l'/'r' or num is negative
    """
    if direct not in ('l', 'r'):
        raise pyrtl.PyrtlError("direction must be specified with 'direct' "
                               "parameter as either 'l' or 'r'")
    if num < 0:
        raise pyrtl.PyrtlError('shift amount must be non-negative')
    if num >= len(reg):
        return 0
    if num == 0:
        # Avoid building a zero-width Const for a no-op shift
        return reg
    if direct == 'l':
        # Zeros are appended below reg; the result is len(reg) + num bits wide
        # and presumably truncated to len(reg) on assignment to .next (confirm)
        return pyrtl.concat(reg, pyrtl.Const(0, num))
    return reg[num:]
示例#53
0
def _cla_adder_unit(a, b, cin):
    """
    Build one carry-lookahead unit over the bits of a and b.

    Generate (a & b) and propagate (a ^ b) signals are derived from the inputs
    alone, never from the sum.  Sum bits use a rippled carry chain, while the
    unit's carry-out is formed from the group generate/propagate terms so it
    can serve as cin for the next unit.

    :return: tuple (sum WireVector with bit 0 as LSB, carry-out)
    """
    generate = a & b
    propagate = a ^ b
    assert len(propagate) == len(generate)

    carries = [generate[0] | (propagate[0] & cin)]
    total = propagate[0] ^ cin

    group_gen, group_prop = generate[0], propagate[0]
    for idx in range(1, len(propagate)):
        group_gen = generate[idx] | (propagate[idx] & group_gen)
        group_prop = group_prop & propagate[idx]
        prev_carry = carries[-1]  # carry out of bit idx - 1
        # New sum bit becomes the MSB of the running result
        total = pyrtl.concat(propagate[idx] ^ prev_carry, total)
        carries.append(generate[idx] | (propagate[idx] & prev_carry))

    carry_out = group_gen | (group_prop & cin)
    return total, carry_out
示例#54
0
def simple_mult(A, B, start):
    """ Generate a simple shift-and-add multiplier.

    Builds a slow, small multiplier using the shift-and-add algorithm. It needs
    very little area (a single adder) but has long delay (worst case is
    len(A) cycles).

    :param A: first operand WireVector; shifted right one bit per cycle
    :param B: second operand WireVector; shifted left one bit per cycle
    :param start: one-bit input indicating the inputs are ready; while it is
        high the operands are loaded and the accumulator is cleared
    :return: tuple (product, done). done is a one-bit signal raised when the
        multiplication is finished, at which point product holds the result.
        When _trivial_mult handles the operands, its result is returned with
        done tied to a constant 1.
    """
    shortcut = _trivial_mult(A, B)
    if shortcut is not None:
        return shortcut, pyrtl.Const(1, 1)

    a_width = len(A)
    product_width = len(B) + a_width
    a_reg = pyrtl.Register(a_width)
    b_reg = pyrtl.Register(product_width)
    accum = pyrtl.Register(product_width)
    # Finished once every set bit of the A operand has been shifted out
    done = (a_reg == 0)

    with pyrtl.conditional_assignment:
        with start:  # load operands and clear the running sum
            a_reg.next |= A
            b_reg.next |= B
            accum.next |= 0
        with ~done:  # one shift-and-add step per cycle while work remains
            a_reg.next |= a_reg[1:]  # drop the LSB: right shift by one
            b_reg.next |= pyrtl.concat(b_reg, pyrtl.Const(0, 1))  # append a 0: left shift
            # Replicate A's current LSB into a full-width mask so the shifted
            # B is added only when that bit is 1
            lsb_mask = a_reg[0].sign_extended(len(accum))
            accum.next |= accum + (lsb_mask & b_reg)

    return accum, done