Ejemplo n.º 1
0
def prioritized_mux(selects, vals):
    """
    Pick the value whose select bit is the first one set.

    :param [WireVector] selects: per-value select signals; earlier entries
        take priority over later ones
    :param [WireVector] vals: candidate values, paired positionally with
        ``selects``
    :return: WireVector

    The mux is built as a balanced tree: the lists are split in half and the
    front half wins whenever any of its selects is high. When no select is
    high, the final entry of ``vals`` is what comes out. A single-entry list
    returns that entry unconditionally.
    """
    if len(selects) != len(vals):
        raise pyrtl.PyrtlError("Number of select and val signals must match")
    count = len(vals)
    if not count:
        raise pyrtl.PyrtlError("Must have a signal to mux")
    if count == 1:
        return vals[0]

    mid = count // 2
    front_selects, back_selects = selects[:mid], selects[mid:]
    front_vals, back_vals = vals[:mid], vals[mid:]
    front_hit = pyrtl.rtl_any(*front_selects)
    return pyrtl.select(front_hit,
                        truecase=prioritized_mux(front_selects, front_vals),
                        falsecase=prioritized_mux(back_selects, back_vals))
Ejemplo n.º 2
0
def sparse_mux(sel, vals):
    """ Mux that avoids instantiating unnecessary mux_2s when possible.

    :param WireVector sel: Select wire, determines what is selected on a given cycle
    :param dictionary vals: dictionary of values at mux inputs (of type `{int:WireVector}`).
        The caller's dictionary is not modified.
    :return: WireVector selected by ``sel`` from ``vals``
    :raises PyrtlError: if a key is neither an integer nor ``SparseDefault``,
        or if an integer key is outside ``range(2**len(sel))``

    This mux supports not having a full specification. Indices that are not
    specified are treated as don't-cares.

    It also supports a specified default value, SparseDefault: when it is a
    key, every index in ``range(2**len(sel))`` without an explicit entry is
    filled with that default value.
    """
    import numbers

    # Work on a copy: the default expansion below removes the SparseDefault
    # key and adds entries, which must not leak back into the caller's dict.
    vals = dict(vals)
    max_val = 2**len(sel) - 1
    if SparseDefault in vals:
        default_val = vals.pop(SparseDefault)
        for i in range(max_val + 1):
            vals.setdefault(i, default_val)

    for key in vals:
        if not isinstance(key, numbers.Integral):
            raise pyrtl.PyrtlError("value %s must be either an integer or 'default'" % str(key))
        if key < 0 or key > max_val:
            raise pyrtl.PyrtlError("value %s is out of range of the sel wire" % str(key))

    return _sparse_mux(sel, vals)
Ejemplo n.º 3
0
def dada_reducer(wire_array_2, result_bitwidth, final_adder=kogge_stone):
    """
    The reduction and final adding part of a Dadda tree. Useful for adding many numbers together.
    The use of single bitwidth wires is to allow for additional flexibility.

    :param [[Wirevector]] wire_array_2: An array of arrays of single bitwidth
        wirevectors; ``wire_array_2[i]`` holds the bits of weight ``2**i``
        (sums stay at index ``i``, carries move to ``i + 1``).
        NOTE(review): the inner lists are consumed in place via ``pop(0)``,
        so the caller's lists are emptied by this call.
        NOTE(review): appears to assume ``len(wire_array_2) <= result_bitwidth``
        on entry (``deferred`` only has ``result_bitwidth + 1`` slots) — confirm.
    :param int result_bitwidth: The bitwidth you want for the resulting wire.
        Used to eliminate unnecessary wires; bits of weight
        ``2**result_bitwidth`` and above are dropped after each stage.
    :param final_adder: The adder used for the final addition
    :return: wirevector of at most ``result_bitwidth`` bits
    :raises PyrtlError: if an element is not a 1-bit WireVector, or if a
        stage fails to reduce a column to its target height
    """
    # NOTE(review): `math` is not used anywhere in this function.
    import math
    # verification that the wires are actually wirevectors of length 1
    for wire_set in wire_array_2:
        for a_wire in wire_set:
            if not isinstance(a_wire, pyrtl.WireVector) or len(a_wire) != 1:
                raise pyrtl.PyrtlError(
                    "The item {} is not a valid element for the wire_array_2. "
                    "It must be a WireVector of bitwidth 1".format(a_wire))

    # Dadda height sequence d_1 = 2, d_{j+1} = floor(1.5 * d_j):
    # 2, 3, 4, 6, 9, 13, ... built until it exceeds the tallest column.
    max_width = max(len(i) for i in wire_array_2)
    reduction_schedule = [2]
    while reduction_schedule[-1] <= max_width:
        reduction_schedule.append(int(reduction_schedule[-1] * 3 / 2))

    # One stage per target height, from the largest target below the tallest
    # column down to 2.
    for reduction_target in reversed(reduction_schedule[:-1]):
        deferred = [[] for weight in range(result_bitwidth + 1)]
        # NOTE(review): last_round is never read; it is only referenced by the
        # commented-out conditions below.
        last_round = (max(len(i) for i in wire_array_2) == 3)
        for i, w_array in enumerate(
                wire_array_2):  # Start with low weights and start reducing
            # Column height counts both unprocessed wires and what is already
            # deferred here (including carries pushed up from column i - 1).
            while len(w_array) + len(deferred[i]) > reduction_target:
                if len(w_array) + len(deferred[i]) - reduction_target >= 2:
                    # Full adder: 3 wires in, 1 sum here, 1 carry up (height -2).
                    cout, sum = _one_bit_add_no_concat(*(w_array.pop(0)
                                                         for j in range(3)))
                    deferred[i].append(sum)
                    deferred[i + 1].append(cout)
                else:
                    # Half adder: 2 wires in, 1 sum here, 1 carry up (height -1).
                    # Leftover alternatives from an earlier variant (no effect):
                    # if (last_round and len(deferred[i]) % 3 == 1) or (len(deferred[i]) % 3 == 2):
                    # if not(last_round and len(wire_array_2[i + 1]) < 3):
                    cout, sum = half_adder(*(w_array.pop(0) for j in range(2)))
                    deferred[i].append(sum)
                    deferred[i + 1].append(cout)
            # Wires not needed for this stage's target are passed through.
            deferred[i].extend(w_array)
            if len(deferred[i]) > reduction_target:
                raise pyrtl.PyrtlError(
                    "Expected that the code would be able to reduce more wires"
                )
        wire_array_2 = deferred[:result_bitwidth]

    # After the final stage (target height 2) every column holds at most two
    # wires, so a single two-operand addition finishes the sum.
    # _sparse_adder presumably builds the two rows and applies final_adder.
    result = _sparse_adder(wire_array_2, final_adder)
    if len(result) > result_bitwidth:
        return result[:result_bitwidth]
    else:
        return result
Ejemplo n.º 4
0
def _get_output(out_wire, block):
    """
    Resolve which Output wire a circuit check should observe.

    :param out_wire: an explicit WireVector to use, or None to search ``block``
    :param block: the block searched when ``out_wire`` is None
    :return: the WireVector to observe
    :raises PyrtlError: if ``out_wire`` is None and ``block`` does not contain
        exactly one Output wire, or if ``out_wire`` is neither None nor a
        WireVector
    """
    if out_wire is None:
        outs = block.wirevector_subset(pyrtl.Output)
        if len(outs) != 1:
            raise pyrtl.PyrtlError(
                "If you don't have exactly one Output wire, you must "
                "specify the Output wire to use")
        # Exactly one element, so pop() is unambiguous.
        return outs.pop()
    if isinstance(out_wire, pyrtl.WireVector):
        return out_wire
    raise pyrtl.PyrtlError("Invalid out_wire, %s" % str(out_wire))
Ejemplo n.º 5
0
 def option(self, select_val, *data_signals):
     """Register the data signals driven when the signal wire equals ``select_val``.

     :param select_val: selecting value; normalised with
         ``pyrtl.infer_val_and_bitwidth`` at ``self.signal_wire.bitwidth``
     :param data_signals: values for this option, forwarded as a tuple to
         ``self._add_signal``
     :raises PyrtlError: if the connector is already finalized or the same
         instruction value was registered before
     """
     self._check_finalized()
     instr, _unused_bitwidth = pyrtl.infer_val_and_bitwidth(
         select_val, self.signal_wire.bitwidth)
     duplicate = instr in self.instructions
     if duplicate:
         raise pyrtl.PyrtlError("instruction %s already exists" % str(select_val))
     self.instructions.append(instr)
     self._add_signal(data_signals)
Ejemplo n.º 6
0
 def __getattr__(self, name):
     """Return the pipeline register ``name`` for the current stage.

     Python only calls this when normal attribute lookup fails.

     :raises AttributeError: for ``_pipeline_register_map``,
         ``_current_stage_num`` and dunder names. Reading the two internal
         attributes below would otherwise re-enter ``__getattr__`` whenever
         they are not set yet (unbounded recursion). Protocol probes such as
         ``getattr(obj, '__deepcopy__', None)`` expect AttributeError for a
         missing attribute.
     :raises PyrtlError: if the current stage has no register called ``name``
     """
     if (name in ('_pipeline_register_map', '_current_stage_num')
             or (name.startswith('__') and name.endswith('__'))):
         raise AttributeError(name)
     try:
         return self._pipeline_register_map[self._current_stage_num][name]
     except KeyError:
         raise pyrtl.PyrtlError(
             'error, no pipeline register "%s" defined for stage %d' %
             (name, self._current_stage_num))
Ejemplo n.º 7
0
def calcuate_max_and_min_bitwidths(max_bitwidth=None, exact_bitwidth=None):
    """Work out the ``(min_bitwidth, max_bitwidth)`` range to use.

    ``max_bitwidth`` takes precedence: when it is given, the range is
    ``(1, max_bitwidth)`` and ``exact_bitwidth`` is ignored. Otherwise
    ``exact_bitwidth`` pins both ends of the range.

    :return: tuple ``(min_bitwidth, max_bitwidth)``
    :raises PyrtlError: if neither argument is given
    """
    if max_bitwidth is None and exact_bitwidth is None:
        raise pyrtl.PyrtlError("A max or exact bitwidth must be specified")
    if max_bitwidth is None:
        return exact_bitwidth, exact_bitwidth
    return 1, max_bitwidth
Ejemplo n.º 8
0
    def decryption_statem(self, ciphertext_in, key_in, reset):
        """
        Builds a multiple cycle AES Decryption state machine circuit

        :param ciphertext_in: the ciphertext to decrypt; must be the same
          length as key_in
        :param key_in: the AES key; its round keys come from self._key_gen
        :param reset: a one bit signal telling the state machine
          to reset and accept the current ciphertext and key
        :return ready, plain_text: ready is a one bit signal showing
          that the decryption result (plain_text) has been calculated.
          The second value is the internal state register (named
          cipher_text below); ready is high once the counter reaches 10.
        :raises PyrtlError: if key_in and ciphertext_in differ in length

        NOTE(review): the register is named 'counter' explicitly; building
        more than one instance in the same block may clash on that name —
        confirm.
        """
        if len(key_in) != len(ciphertext_in):
            raise pyrtl.PyrtlError(
                "AES key and ciphertext should be the same length")

        # cipher_text holds the evolving state; key holds the key between cycles.
        cipher_text, key = (pyrtl.Register(len(ciphertext_in))
                            for i in range(2))
        key_exp_in, add_round_in = (pyrtl.WireVector(len(ciphertext_in))
                                    for i in range(2))

        # this is not part of the state machine as we need the keys in
        # reverse order... (reversed() returns a one-shot iterator; it is
        # consumed exactly once by the mux below)
        reversed_key_list = reversed(self._key_gen(key_exp_in))

        counter = pyrtl.Register(4, 'counter')
        round = pyrtl.WireVector(4)
        counter.next <<= round

        inv_shift = self._inv_shift_rows(cipher_text)
        inv_sub = self._sub_bytes(inv_shift, True)
        # round 0 selects the last key produced by _key_gen, round 1 the one
        # before it, and so on.
        key_out = pyrtl.mux(round, *reversed_key_list, default=0)
        add_round_out = self._add_round_key(add_round_in, key_out)
        inv_mix_out = self._mix_columns(add_round_out, True)

        with pyrtl.conditional_assignment:
            with reset == 1:
                # Load: combine the incoming ciphertext with the round-0 key.
                round |= 0
                key.next |= key_in
                key_exp_in |= key_in  # to lower the number of cycles needed
                cipher_text.next |= add_round_out
                add_round_in |= ciphertext_in

            with counter == 10:  # keep everything the same
                # NOTE(review): key_exp_in / add_round_in are not assigned in
                # this branch; presumably they take the conditional default.
                round |= counter
                cipher_text.next |= cipher_text

            with pyrtl.otherwise:  # running through AES
                round |= counter + 1

                key.next |= key
                key_exp_in |= key
                add_round_in |= inv_sub
                with counter == 9:
                    # Last round: no inverse MixColumns.
                    cipher_text.next |= add_round_out
                with pyrtl.otherwise:
                    cipher_text.next |= inv_mix_out

        ready = (counter == 10)
        return ready, cipher_text
Ejemplo n.º 9
0
def partition_wire(wire, partition_size):
    """Split ``wire`` into equal ``partition_size``-wide slices, MSB first.

    Element 0 of the result is the top ``partition_size`` bits of ``wire``;
    the last element is the bottom ``partition_size`` bits.

    :param wire: wire to partition
    :param partition_size: width of each slice
    :return: list of slices, most significant first
    :raises PyrtlError: if ``len(wire)`` is not a multiple of ``partition_size``
    """
    total = len(wire)
    if total % partition_size != 0:
        raise pyrtl.PyrtlError(
            "Wire {} cannot be evenly partitioned into items of size {}".
            format(wire, partition_size))
    # Build LSB-first, then flip to the MSB-first order callers expect.
    slices = [wire[low:low + partition_size]
              for low in range(0, total, partition_size)]
    slices.reverse()
    return slices
Ejemplo n.º 10
0
    def _add_signal(self, data_signals):
        """Append one data signal to each destination wire's instruction list.

        :param data_signals: sequence with one entry per ``self.dest_wires``;
            entry k is converted with ``pyrtl.as_wires`` to the bitwidth of
            destination k and appended to ``self.dest_instrs_info[dest]``
        :raises PyrtlError: if the connector is finalized, or if the number of
            signals does not match the number of destination wires
        """
        self._check_finalized()
        received = len(data_signals)
        expected = len(self.dest_wires)
        if received != expected:
            raise pyrtl.PyrtlError("Incorrect number of data_signals for "
                                   "instruction received {} , expected {}"
                                   .format(received, expected))

        for dest, raw in zip(self.dest_wires, data_signals):
            converted = pyrtl.as_wires(raw, dest.bitwidth)
            self.dest_instrs_info[dest].append(converted)
Ejemplo n.º 11
0
def _sparse_mux(sel, vals):
    """
    Recursive worker for ``sparse_mux``; avoids unnecessary mux_2s.

    :param WireVector sel: Select wire, determines what is selected on a given cycle
    :param {int: WireVector} vals: maps select values to the wire chosen for
        that value; select values missing from the dict are don't-cares
    :return: WireVector chosen by ``sel``
    :raises PyrtlError: if ``vals`` is empty, or if ``sel`` is one bit wide
        but ``vals`` lacks an entry for 0 or for 1

    The MSB of ``sel`` splits ``vals`` into a lower and an upper half. When
    only one half has entries, that half is muxed on the remaining bits and no
    mux_2 is needed. Otherwise both halves are muxed and joined by a 2:1 select,
    unless ``_is_equivelent`` reports the two results equivalent.
    """
    candidates = list(vals.values())
    if len(candidates) == 0:
        raise pyrtl.PyrtlError("Needs at least one parameter for val")
    if len(candidates) == 1:
        return candidates[0]

    if len(sel) == 1:
        try:
            low_result, high_result = vals[0], vals[1]
        except KeyError:
            raise pyrtl.PyrtlError("Failed to retrieve values for smartmux. "
                                   "The length of sel might be wrong")
    else:
        half = 2**(len(sel) - 1)
        lower = {}
        upper = {}
        for index, wire in vals.items():
            if index < half:
                lower[index] = wire
            else:
                upper[index - half] = wire
        if not lower:
            return sparse_mux(sel[:-1], upper)
        if not upper:
            return sparse_mux(sel[:-1], lower)

        low_result = sparse_mux(sel[:-1], lower)
        high_result = sparse_mux(sel[:-1], upper)

    if _is_equivelent(low_result, high_result):
        return high_result
    return pyrtl.select(sel[-1], falsecase=low_result, truecase=high_result)
Ejemplo n.º 12
0
def twos_comp_repr(val, bitwidth):
    """
    Converts a value to its two's-complement (positive) integer representation using a
    given bitwidth (only converts the value if it is negative).
    For use with Simulation.step() etc. in passing negative numbers, which it does not accept

    :param int val: the signed value to encode
    :param int bitwidth: width of the encoding; ``val`` must lie in
        ``[-2**(bitwidth - 1), 2**(bitwidth - 1) - 1]``
    :return: ``val`` unchanged if it is non-negative, otherwise
        ``2**bitwidth + val``
    :raises PyrtlError: if ``val`` does not fit in ``bitwidth`` bits
    """
    # A negative value -m only needs the magnitude bits of m - 1 (e.g. -8 is
    # 1000 in 4 bits), and ~val == m - 1 for val < 0. Using abs(val) here
    # would wrongly reject the most negative representable value.
    magnitude = val if val >= 0 else ~val
    correctbw = magnitude.bit_length() + 1  # plus one sign bit
    if bitwidth < correctbw:
        raise pyrtl.PyrtlError("please choose a larger target bitwidth")
    if val >= 0:
        return val
    # Masking a negative Python int yields the same bits as "flip and add one".
    return val & (2**bitwidth - 1)
Ejemplo n.º 13
0
    def decryption(self, ciphertext, key):
        """
        Build a single-cycle AES decryption circuit (all rounds unrolled).

        :param WireVector ciphertext: data to decrypt; must be
            ``self._key_len`` bits
        :param WireVector key: AES key to decrypt with, the same key used to
            encrypt (AES is symmetric); must be ``self._key_len`` bits
        :return: a WireVector containing the plaintext
        :raises PyrtlError: if either input has the wrong length
        """
        if len(ciphertext) != self._key_len:
            raise pyrtl.PyrtlError("Ciphertext length is invalid")
        if len(key) != self._key_len:
            raise pyrtl.PyrtlError("key length is invalid")

        round_keys = self._key_gen(key)
        # Start by removing the last round key.
        state = self._add_round_key(ciphertext, round_keys[10])

        # Walk the remaining round keys from 9 down to 0. The pass that uses
        # key 0 skips the inverse MixColumns step.
        for key_index in range(9, -1, -1):
            state = self._inv_shift_rows(state)
            state = self._sub_bytes(state, True)
            state = self._add_round_key(state, round_keys[key_index])
            if key_index != 0:
                state = self._mix_columns(state, True)

        return state
Ejemplo n.º 14
0
    def encrypt_state_m(self, plaintext_in, key_in, reset):
        """
        Builds a multiple cycle AES Encryption state machine circuit

        :param plaintext_in: the data to encrypt; must be the same length as
          key_in
        :param key_in: the AES key
        :param reset: a one bit signal telling the state machine
          to reset and accept the current plaintext and key
        :return ready, cipher_text: ready is a one bit signal showing
          that the encryption result (cipher_text) has been calculated.
          The second value is the internal state register (named
          plain_text below); ready is high once the counter reaches 10.
        :raises PyrtlError: if key_in and plaintext_in differ in length

        NOTE(review): wires are named 'counter' and 'round' explicitly;
        building more than one instance in the same block may clash on
        those names — confirm.
        """
        if len(key_in) != len(plaintext_in):
            raise pyrtl.PyrtlError(
                "AES key and plaintext should be the same length")

        # plain_text holds the evolving state; key holds the current round key.
        plain_text, key = (pyrtl.Register(len(plaintext_in)) for i in range(2))
        key_exp_in, add_round_in = (pyrtl.WireVector(len(plaintext_in))
                                    for i in range(2))

        counter = pyrtl.Register(4, 'counter')
        round = pyrtl.WireVector(4, 'round')
        counter.next <<= round
        # One round's datapath, shared by every cycle.
        sub_out = self._sub_bytes(plain_text)
        shift_out = self._shift_rows(sub_out)
        mix_out = self._mix_columns(shift_out)
        # presumably derives the next round key from the current one and the
        # round counter — confirm against _key_expansion
        key_out = self._key_expansion(key, counter)
        add_round_out = self._add_round_key(add_round_in, key_exp_in)
        with pyrtl.conditional_assignment:
            with reset == 1:
                # Load: initial AddRoundKey of the plaintext with key_in.
                round |= 0
                key_exp_in |= key_in  # to lower the number of cycles
                plain_text.next |= add_round_out
                key.next |= key_in
                add_round_in |= plaintext_in

            with counter == 10:  # keep everything the same
                round |= counter
                plain_text.next |= plain_text

            with pyrtl.otherwise:  # running through AES
                round |= counter + 1
                key_exp_in |= key_out
                plain_text.next |= add_round_out
                key.next |= key_out
                with counter == 9:
                    # Last round: no MixColumns.
                    add_round_in |= shift_out
                with pyrtl.otherwise:
                    add_round_in |= mix_out

        ready = (counter == 10)
        return ready, plain_text
Ejemplo n.º 15
0
def rev_twos_comp_repr(val, bitwidth):
    """
    Takes a two's-complement represented value and
    converts it to a signed integer based on the provided bitwidth.
    For use with Simulation.inspect() etc. when expecting negative numbers,
    which it does not recognize

    :param int val: raw non-negative value read from a ``bitwidth``-bit wire
    :param int bitwidth: width of the two's-complement encoding
    :return: the signed integer encoded by ``val``; values with the MSB set
        are negative (e.g. ``0b1000`` at width 4 is -8)
    :raises PyrtlError: if ``val`` needs more than ``bitwidth`` bits
    """
    valbl = val.bit_length()
    # 2**(bitwidth - 1) (MSB only) is a valid encoding of -2**(bitwidth - 1),
    # so only values that are too wide are rejected.
    if bitwidth < valbl:
        raise pyrtl.PyrtlError("please choose a larger target bitwidth")
    if bitwidth == valbl:  # MSB is a 1, value is negative
        return -((~val & (2**bitwidth - 1)) + 1
                 )  # flip the bits, add one, and make negative
    else:
        return val
Ejemplo n.º 16
0
def signed_tree_multiplier(A,
                           B,
                           reducer=adders.wallace_reducer,
                           adder_func=adders.kogge_stone):
    """Same as tree_multiplier, but uses two's-complement signed integers.

    :param WireVector A: signed operand; needs at least 2 bits (sign + value)
    :param WireVector B: signed operand; needs at least 2 bits (sign + value)
    :param reducer: reduction function, forwarded to ``tree_multiplier``
    :param adder_func: final adder, forwarded to ``tree_multiplier``
    :return: WireVector holding the signed product (the unsigned product is
        zero-extended to ``len(A) + len(B)`` bits before the sign is applied)
    :raises PyrtlError: if ``A`` or ``B`` is only one bit wide
    """
    if len(A) == 1 or len(B) == 1:
        raise pyrtl.PyrtlError(
            "sign bit required, one or both wires too small")

    aneg, bneg = A[-1], B[-1]
    # _twos_comp_conditional presumably negates its input when the given bit
    # is set, turning A and B into magnitudes here.
    a = _twos_comp_conditional(A, aneg)
    b = _twos_comp_conditional(B, bneg)

    # Forward reducer/adder_func; previously they were accepted but ignored.
    res = tree_multiplier(a[:-1], b[:-1], reducer=reducer,
                          adder_func=adder_func).zero_extended(len(A) + len(B))
    # The product is negative exactly when the operand signs differ.
    return _twos_comp_conditional(res, aneg ^ bneg)
Ejemplo n.º 17
0
    def encryption(self, plaintext, key):
        """
        Builds a single cycle AES Encryption circuit

        :param WireVector plaintext: text to encrypt; must be
            ``self._key_len`` bits
        :param WireVector key: AES key to use to encrypt; must be
            ``self._key_len`` bits
        :return: a WireVector containing the ciphertext
        :raises PyrtlError: if either input has the wrong length
        """
        if len(plaintext) != self._key_len:
            raise pyrtl.PyrtlError("Plaintext length is invalid")
        if len(key) != self._key_len:
            raise pyrtl.PyrtlError("key length is invalid")

        key_list = self._key_gen(key)
        # Initial AddRoundKey with round key 0.
        t = self._add_round_key(plaintext, key_list[0])

        # Rounds 1..10; the final round omits MixColumns.
        for round_num in range(1, 11):
            t = self._sub_bytes(t)
            t = self._shift_rows(t)
            if round_num != 10:
                t = self._mix_columns(t)
            t = self._add_round_key(t, key_list[round_num])
        return t
Ejemplo n.º 18
0
def partition_wire(wire, partition_size):
    """Partition a wire into a list of equally sized slices.

    :param wire: Wire to partition
    :param partition_size: Integer representing size of each partition
    :return: list of slices, least significant first: for a positive
        ``partition_size``, element k is
        ``wire[k * partition_size:(k + 1) * partition_size]``
    :raises PyrtlError: if the bitwidth is not divisible by ``partition_size``

    The wire's bitwidth must be evenly divisible by 'partition_size'.
    """
    width = len(wire)
    if width % partition_size != 0:
        raise pyrtl.PyrtlError(
            "Wire {} cannot be evenly partitioned into items of size {}".
            format(wire, partition_size))
    parts = []
    for start in range(0, width, partition_size):
        stop = start + partition_size
        parts.append(wire[start:stop])
    return parts
Ejemplo n.º 19
0
def _general_adder_reducer(wire_array_2, result_bitwidth, reduce_2s,
                           final_adder):
    """
    Does the reduction and final adding for both Dadda and Wallace reducers.

    :param [[Wirevector]] wire_array_2: An array of arrays of single bitwidth
        wirevectors; ``wire_array_2[i]`` holds the bits of weight ``2**i``
        (sums stay at index ``i``, carries move to ``i + 1``).
        NOTE(review): the inner lists are consumed in place via ``pop(0)``,
        so the caller's lists are modified by this call.
        NOTE(review): appears to assume ``len(wire_array_2) <= result_bitwidth``
        on entry (``deferred`` only has ``result_bitwidth + 1`` slots) — confirm.
    :param int result_bitwidth: The bitwidth you want for the resulting wire.
        Used to eliminate unnecessary wires; bits of weight
        ``2**result_bitwidth`` and above are dropped after each pass.
    :param Bool reduce_2s: True=Wallace Reducer, False=Dadda Reducer
        (whether columns of exactly two wires are half-added each pass)
    :param final_adder: The adder used for the final addition
    :return: wirevector of at most ``result_bitwidth`` bits
    :raises PyrtlError: if an element of ``wire_array_2`` is not a 1-bit
        WireVector
    """
    # verification that the wires are actually wirevectors of length 1
    for wire_set in wire_array_2:
        for a_wire in wire_set:
            if not isinstance(a_wire, pyrtl.WireVector) or len(a_wire) != 1:
                raise pyrtl.PyrtlError(
                    "The item {} is not a valid element for the wire_array_2. "
                    "It must be a WireVector of bitwidth 1".format(a_wire))

    # Keep reducing while any column has 3 or more wires.
    while not all(len(i) <= 2 for i in wire_array_2):
        deferred = [[] for weight in range(result_bitwidth + 1)]
        for i, w_array in enumerate(
                wire_array_2):  # Start with low weights and start reducing
            # Full adders: 3 wires in, 1 sum here, 1 carry to the next weight.
            while len(w_array) >= 3:
                cout, sum = _one_bit_add_no_concat(*(w_array.pop(0)
                                                     for j in range(3)))
                deferred[i].append(sum)
                deferred[i + 1].append(cout)

            if len(w_array) == 2 and reduce_2s:
                # Wallace only: half-add a leftover pair.
                cout, sum = half_adder(*w_array)
                deferred[i].append(sum)
                deferred[i + 1].append(cout)
            else:
                # Remaining 0-2 wires are passed through unchanged.
                deferred[i].extend(w_array)
        wire_array_2 = deferred[:result_bitwidth]

    # At this stage every column has at most 2 wires (two rows left).
    # now we need to add them up; _sparse_adder presumably forms the two rows
    # and applies final_adder.
    result = _sparse_adder(wire_array_2, final_adder)
    if len(result) > result_bitwidth:
        return result[:result_bitwidth]
    else:
        return result
Ejemplo n.º 20
0
    def and_inv_op(net):
        """Re-express one net with AND and NOT gates.

        Returns True to keep ``net`` as-is. Otherwise the replacement logic is
        driven onto the net's destination wire and None is returned
        (presumably used as a per-net transform callback — confirm).
        NOTE(review): '|' is also passed through untouched, so the result is
        not a pure AND/INV netlist — confirm that this is intended.

        :raises PyrtlError: for any op without a rewrite rule here
        """
        if net.op in '~|&rwcsm@':
            return True

        def arg(num):
            return net.args[num]

        dest = net.dests[0]
        if net.op == '^':
            # a ^ b is 1 exactly when the inputs are neither both 1 nor
            # both 0 (the previous all_0 & ~all_1 reduced to NOR).
            all_1 = arg(0) & arg(1)
            all_0 = ~arg(0) & ~arg(1)
            dest <<= ~all_0 & ~all_1
        elif net.op == 'n':
            # NAND
            dest <<= ~(arg(0) & arg(1))
        else:
            raise pyrtl.PyrtlError(
                "Op, '{}' is not supported in and_inv_synth".format(net.op))
Ejemplo n.º 21
0
def circuit_equivalence(equv_func,
                        in_wires=None,
                        out_wire=None,
                        block=None,
                        print_invalid=True):
    """
    Checks whether a circuit is equivalent to a python function by
    simulating every combination of input values.

    :param equv_func: function to test circuit equivalence of. int args are passed
      in the order of the in_wires; it must return an integer
    :param [Input] in_wires: wires for input (in order of args for equiv_func).
      default: all input wires, in their name's alphabetical order.
    :param Output out_wire: wire to use for output, default: find the only one
    :param block: block to check (default: the working block)
    :param Bool print_invalid: if True, print the first mismatching input
      combination along with the simulated and expected outputs
    :return: bool: True if the circuit matches equv_func for every input
    :raises PyrtlError: if equv_func returns a non-integer

    The loop runs 2**(total input bits) simulation steps, so this is only
    practical for small circuits.
    """
    import numbers

    block = pyrtl.working_block(block)
    in_wires = _get_inputs(in_wires, block)
    out_wire = _get_output(out_wire, block)

    # now we get into the algorithm
    bits_to_test = sum(w.bitwidth for w in in_wires)

    # Simulate the block under test, not whatever the current working block is.
    sim = pyrtl.Simulation(block=block)
    for test_val in range(2**bits_to_test):
        vals = _create_seq_list(in_wires, test_val)
        sim.step(dict(zip(in_wires, vals)))
        out_val = sim.inspect(out_wire)
        expected_val = equv_func(*vals)
        if not isinstance(expected_val, numbers.Integral):
            raise pyrtl.PyrtlError(
                "Equv_func return %s, which is not an integer" %
                repr(expected_val))
        if out_val != expected_val:
            if print_invalid:
                situation_str = ', '.join(
                    str(w) + ' = ' + str(v) for w, v in zip(in_wires, vals))
                print("in situation {}, got: {} expected: {}".format(
                    situation_str, out_val, expected_val))
            return False
    return True
Ejemplo n.º 22
0
def complex_mult(A, B, shifts, start):
    """ Generate shift-and-add multiplier that can shift and add multiple bits per clock cycle.
    Uses substantially more space than `simple_mult()` but is much faster.

    :param WireVector A, B: two input wires for the multiplication
    :param int shifts: number of spaces Register is to be shifted per clk cycle
        (cannot be greater than the length of `A` or `B`)
    :param bool start: start signal
    :returns: Register containing the product; the "done" signal
    :raises PyrtlError: if `shifts` is greater than the length of `A` or `B`
    """
    alen = len(A)
    blen = len(B)
    # Validate before creating any hardware so that a bad call does not leave
    # orphaned registers/logic behind in the working block.
    if (shifts > alen) or (shifts > blen):
        raise pyrtl.PyrtlError(
            "shift is larger than one or both of the parameters A or B, "
            "please choose smaller shift")

    areg = pyrtl.Register(alen)
    breg = pyrtl.Register(alen + blen)  # room for B shifted left
    accum = pyrtl.Register(alen + blen)
    done = (areg == 0)  # Multiplication is finished when a becomes 0

    # During multiplication, shift a right every cycle 'shifts' times,
    # shift b left every cycle 'shifts' times
    with pyrtl.conditional_assignment:
        with start:  # initialization
            areg.next |= A
            breg.next |= B
            accum.next |= 0

        with ~done:  # don't run when there's no work to do
            # Add breg times the low `shifts` bits of areg (presumably what
            # _one_cycle_mult computes), then shift both operands.
            areg.next |= libutils._shifted_reg_next(areg, 'r',
                                                    shifts)  # right shift
            breg.next |= libutils._shifted_reg_next(breg, 'l',
                                                    shifts)  # left shift
            accum.next |= accum + _one_cycle_mult(areg, breg, shifts)

    return accum, done
Ejemplo n.º 23
0
def _shifted_reg_next(reg, direct, num=1):
    """
    Creates a shifted 'next' value for a shifted (left or right) register.
    Use: `myReg.next <<= _shifted_reg_next(myReg, 'l', 4)`

    :param reg: the register (or wire) to shift
    :param string direct: direction of shift, either 'l' or 'r'
    :param int num: number of shifts
    :return: the integer 0 if ``num >= len(reg)`` (every bit shifted out);
        otherwise a WireVector. A left shift appends ``num`` zero bits below
        ``reg``, so it is ``len(reg) + num`` bits wide; callers assigning it
        back to ``reg.next`` presumably rely on the upper bits being
        truncated. A right shift drops the low ``num`` bits
        (``len(reg) - num`` bits wide).
    :raises PyrtlError: if ``direct`` is not 'l' or 'r'
    """
    if direct == 'l':
        if num >= len(reg):
            return 0
        return pyrtl.concat(reg, pyrtl.Const(0, num))
    elif direct == 'r':
        if num >= len(reg):
            return 0
        return reg[num:]
    raise pyrtl.PyrtlError("direction must be specified with 'direct' "
                           "parameter as either 'l' or 'r'")
Ejemplo n.º 24
0
def _general_adder_reducer(wire_array_2, result_bitwidth, reduce_2s, final_adder):
    """
    Does the reduction and final adding for both Dadda and Wallace reducers.

    These reductions take place in the form of full-adders (3 inputs),
    half-adders (2 inputs), or just passing a wire along (1 input).

    These reductions take place as long as some weight has more than 2
    wires left. When every weight has at most 2 wires, the two remaining
    rows are summed with ``final_adder``.

    :param [[Wirevector]] wire_array_2: An array of arrays of single bitwidth
        wirevectors; ``wire_array_2[i]`` holds the bits of weight ``2**i``.
        NOTE(review): the inner lists are consumed in place via ``pop(0)``.
    :param int result_bitwidth: The bitwidth you want for the resulting wire.
        Used to eliminate unnecessary wires; carries that would land at
        weight ``2**result_bitwidth`` or above are dropped.
    :param Bool reduce_2s: True=Wallace Reducer, False=Dadda Reducer
    :param final_adder: The adder used for the final addition
    :return: wirevector of at most ``result_bitwidth`` bits
    :raises PyrtlError: if an element of ``wire_array_2`` is not a 1-bit
        WireVector
    """
    # verification that the wires are actually wirevectors of length 1
    for wire_set in wire_array_2:
        for a_wire in wire_set:
            if not isinstance(a_wire, pyrtl.WireVector) or len(a_wire) != 1:
                raise pyrtl.PyrtlError(
                    "The item %s is not a valid element for the wire_array_2. "
                    "It must be a WireVector of bitwidth 1" % str(a_wire))

    deferred = [[] for weight in range(result_bitwidth)]
    while not all(len(i) <= 2 for i in wire_array_2):
        # While some weight still has more than 2 wires
        for i in range(len(wire_array_2)):  # Start with low weights and start reducing
            while len(wire_array_2[i]) >= 3:  # Reduce with Full Adders until < 3 wires
                a, b, cin = (wire_array_2[i].pop(0) for j in range(3))
                deferred[i].append(a ^ b ^ cin)  # deferred bit keeps this sum
                if i + 1 < result_bitwidth:  # watch out for index bounds
                    deferred[i + 1].append((a & b) | (b & cin) | (a & cin))  # cout goes up by one

            if len(wire_array_2[i]) == 2:
                if reduce_2s:  # Reduce with a Half Adder if exactly 2 wires
                    a, b = wire_array_2[i].pop(0), wire_array_2[i].pop(0)
                    deferred[i].append(a ^ b)  # deferred bit keeps this sum
                    if i + 1 < result_bitwidth:
                        deferred[i + 1].append(a & b)  # cout goes up one weight
                else:
                    deferred[i].extend(wire_array_2[i])

            elif len(wire_array_2[i]) == 1:  # Remaining wire is passed along the reductions
                deferred[i].append(wire_array_2[i][0])  # deferred bit keeps this value

        wire_array_2 = deferred  # Set bits equal to the deferred values
        deferred = [[] for weight in range(result_bitwidth)]  # Reset deferred to empty

    # At this stage in the multiplication we have only 2 wire vectors left.

    num1 = []
    num2 = []
    # Becomes True at the first weight holding two wires: below it, bits
    # pass straight into the result; from it upward, the two rows are added.
    weve_seen_a_two = False
    result = None

    for i in range(result_bitwidth):

        if len(wire_array_2[i]) == 2:  # Check if the two wire vectors overlap yet
            weve_seen_a_two = True

        if not weve_seen_a_two:  # If they have not overlapped, add the 1's to result
            if result is None:
                result = wire_array_2[i][0]
            else:
                result = pyrtl.concat(wire_array_2[i][0], result)
        else:
            # For overlapping bits, create num1 and num2
            if weve_seen_a_two and len(wire_array_2[i]) == 2:
                num1.insert(0, wire_array_2[i][0])  # because we need to prepend to the list
                num2.insert(0, wire_array_2[i][1])

            # If there's 1 left it's part of num2
            if weve_seen_a_two and len(wire_array_2[i]) == 1 and i < result_bitwidth:
                num1.insert(0, pyrtl.Const(0))
                num2.insert(0, wire_array_2[i][0])

    adder_result = final_adder(pyrtl.concat(*num1), pyrtl.concat(*num2))

    # Concatenate the results, and then return them.
    # Perhaps here we should slice off the overflow bit, if it exceeds bit_length?
    # result = result[:-1]
    if result is None:
        result = adder_result
    else:
        result = pyrtl.concat(adder_result, result)
    if len(result) > result_bitwidth:
        return result[:result_bitwidth]
    else:
        return result
Ejemplo n.º 25
0
def csprng_trivium(bitwidth, load, req, seed=None, bits_per_cycle=64):
    """ Builds a cryptographically secure PRNG using the Trivium stream cipher.

    :param bitwidth: the desired bitwidth of the random number
    :param load: one bit signal to load the seed into the prng
    :param req: one bit signal to request a random number
    :param seed: 160 bits WireVector (80 bits key + 80 bits IV), defaults to None (self-seeding),
      refrain from self-seeding if reseeding at run time is needed
    :param bits_per_cycle: the number of output bits to generate in parallel each cycle,
      up to 64 bits, must be a power of two: either 1, 2, 4, 8, 16, 32, or 64
    :return ready, rand: ready is a one bit signal showing either the random number has
      been produced or the seed has been initialized, rand is a register containing the
      random number with the given bitwidth
    :raises PyrtlError: if bits_per_cycle is not one of 1, 2, 4, 8, 16, 32, 64

    This prng uses Trivium's key stream as its random bits output.
    Both seed and key stream are MSB first (the earliest bit is stored at the MSB).
    Trivium has a seed initialization stage that discards the first weak 1152 output bits
    after each loading. Generation stage can take multiple cycles as well depending on the
    given bitwidth and bits_per_cycle.
    Has smaller gate area and faster speed than AES-CTR and any other stream cipher.
    Passes all known statistical tests. Can be used to generate encryption keys or IVs.
    Designed to securely generate up to 2**64 bits. If more than 2**64 bits is needed,
    must reseed after each generation of 2**64 bits.

    Trivium specifications:
    http://www.ecrypt.eu.org/stream/ciphers/trivium/trivium.pdf
    See also the eSTREAM portfolio page:
    http://www.ecrypt.eu.org/stream/e2-trivium.html
    """
    from math import ceil
    # Explicit whitelist: the previous divisibility test
    # ((64 // n) * n == 64) also accepted negative powers of two and raised
    # ZeroDivisionError for 0.
    if bits_per_cycle not in (1, 2, 4, 8, 16, 32, 64):
        raise pyrtl.PyrtlError('bits_per_cycle is invalid')
    if seed is None:
        import random
        cryptogen = random.SystemRandom()
        seed = cryptogen.randrange(
            2**160)  # seed itself if no seed signal is given
    seed = pyrtl.as_wires(seed, 160)
    key = seed[80:]
    iv = seed[:80]

    # Trivium's three shift registers (93 + 84 + 111 = 288 bits of state).
    a = pyrtl.Register(93)
    b = pyrtl.Register(84)
    c = pyrtl.Register(111)
    feedback_a, feedback_b, feedback_c, output = ([] for i in range(4))
    for i in range(bits_per_cycle):
        t1 = a[65 - i] ^ a[92 - i]
        t2 = b[68 - i] ^ b[83 - i]
        t3 = c[65 - i] ^ c[110 - i]
        feedback_a.append(t3 ^ c[108 - i] & c[109 - i] ^ a[68 - i])
        feedback_b.append(t1 ^ a[90 - i] & a[91 - i] ^ b[77 - i])
        feedback_c.append(t2 ^ b[81 - i] & b[82 - i] ^ c[86 - i])
        output.append(t1 ^ t2 ^ t3)
    # update internal states by shifting bits_per_cycle times
    a_next = pyrtl.concat(a, *feedback_a)
    b_next = pyrtl.concat(b, *feedback_b)
    c_next = pyrtl.concat(c, *feedback_c)

    init_cycles = 1152 // bits_per_cycle
    gen_cycles = int(ceil(bitwidth / bits_per_cycle))
    # The counter must reach init_cycles and gen_cycles - 1, i.e. every value
    # below max(init_cycles + 1, gen_cycles). Integer bit_length gives
    # ceil(log2(n)) exactly, without floating-point rounding for large n.
    counter_bitwidth = (max(init_cycles + 1, gen_cycles) - 1).bit_length()
    rand = pyrtl.Register(bitwidth)
    counter = pyrtl.Register(counter_bitwidth, 'counter')
    init_done = counter == init_cycles
    gen_done = counter == gen_cycles - 1
    state = pyrtl.Register(2)
    WAIT, INIT, GEN = (pyrtl.Const(x) for x in range(3))
    with pyrtl.conditional_assignment:
        with load:
            # Trivium key/IV setup: key into a, IV into b, c = 111 followed by zeros.
            counter.next |= 0
            a.next |= key
            b.next |= iv
            c.next |= pyrtl.concat(pyrtl.Const("3'b111"), pyrtl.Const(0, 108))
            state.next |= INIT
        with req:
            counter.next |= 0
            a.next |= a_next
            b.next |= b_next
            c.next |= c_next
            rand.next |= pyrtl.concat(rand, *output)
            state.next |= GEN
        with state == INIT:
            # Discard key stream until init_cycles have elapsed.
            with ~init_done:
                counter.next |= counter + 1
                a.next |= a_next
                b.next |= b_next
                c.next |= c_next
        with state == GEN:
            # Shift key stream bits into rand until the requested width is filled.
            with ~gen_done:
                counter.next |= counter + 1
                a.next |= a_next
                b.next |= b_next
                c.next |= c_next
                rand.next |= pyrtl.concat(rand, *output)

    ready = ~load & ~req & ((state == INIT) & init_done |
                            (state == GEN) & gen_done)
    return ready, rand
Ejemplo n.º 26
0
 def _check_finalized(self):
     """Refuse further changes once the connector has been finalized.

     Called at the start of mutating methods such as ``option`` and
     ``_add_signal``.

     :raises PyrtlError: if ``self._final`` is truthy
     """
     if not self._final:
         return
     raise pyrtl.PyrtlError("Cannot change InstrConnector, already finalized")