Ejemplo n.º 1
0
    class _ROM(Circuit):
        """Read-only table of ``n`` values of type ``t`` built from per-bit LUTs.

        Ports: ``addr`` selects an entry, ``data`` carries the value rebuilt
        from the LUT outputs, plus whatever ``ClockInterface()`` adds.
        """
        name = 'LUT_{}t_{}n'.format(cleanName(str(t)), n)
        addr_width = getRAMAddrWidth(n)
        IO = ['addr', In(Bits[addr_width]), 'data', Out(t)] + ClockInterface()

        @classmethod
        def definition(cls):
            """Create one LUT per bit of ``t`` and wire up address and data."""
            bit_count = GetCoreIRBackend().get_type(t).size
            # LUT number `bit` is initialised from bit `bit` of every entry
            # of `init`, packed into an integer by seq2int.
            per_bit_luts = [
                DefineLUT(seq2int([entry[bit] for entry in init]),
                          getRAMAddrWidth(n))()
                for bit in range(bit_count)
            ]

            # Reassemble the individual LUT outputs into a value of type t.
            rebuild = Hydrate(t)
            for bit, lut in enumerate(per_bit_luts):
                wire(lut.O, rebuild.I[bit])
            wire(rebuild.out, cls.data)

            # Every LUT sees the same address; LUT inputs are named I0, I1, ...
            for lut in per_bit_luts:
                for addr_bit in range(cls.addr_width):
                    wire(cls.addr[addr_bit],
                         getattr(lut, "I" + str(addr_bit)))
Ejemplo n.º 2
0
    class _SIPO(Circuit):
        """Serial-in, parallel-out wrapper for values of type ``t``.

        The input is flattened to bits, each bit is fed to its own SIPO of
        length ``n``, and the SIPO outputs are regrouped into ``n`` values of
        type ``t``.

        Ports: ``I`` (one ``t``), ``O`` (``Array[n, t]``), plus the ports
        produced by ``ClockInterface(has_ce, has_reset)``.
        """
        # One placeholder per generator parameter. The previous format string
        # had no placeholder for has_reset, so that argument never reached the
        # circuit name and variants differing only in has_reset shared a name.
        name = 'SIPO_{}t_{}n_{}init_{}CE_{}RESET'.format(cleanName(str(t)),
                                                         str(n), str(init),
                                                         str(has_ce),
                                                         str(has_reset))
        IO = ['I', In(t), 'O', Out(Array[n, t])] + \
                ClockInterface(has_ce,has_reset)

        @classmethod
        def definition(cls):
            """Instantiate the per-bit SIPOs and wire data, CE and RESET."""
            type_size_in_bits = GetCoreIRBackend().get_type(t).size
            type_to_bits = Dehydrate(t)
            sipos = MapParallel(type_size_in_bits,
                                DefineSIPO(n, init, has_ce, has_reset))
            bits_to_type = MapParallel(n, DefineHydrate(t))

            for bit_in_type in range(type_size_in_bits):
                wire(type_to_bits.out[bit_in_type], sipos.I[bit_in_type])
                # Transpose: SIPO outputs are indexed [bit][position], the
                # hydrate inputs are indexed [position][bit].
                for sipo_output in range(n):
                    wire(sipos.O[bit_in_type][sipo_output],
                         bits_to_type.I[sipo_output][bit_in_type])

            wire(cls.I, type_to_bits.I)
            wire(bits_to_type.out, cls.O)

            # Fan the shared control signals out to every per-bit SIPO.
            for bit_in_type in range(type_size_in_bits):
                if has_ce:
                    wire(cls.CE, sipos.CE[bit_in_type])
                if has_reset:
                    wire(cls.RESET, sipos.RESET[bit_in_type])
Ejemplo n.º 3
0
    class _Mux(Circuit):
        """``n``-way multiplexer over values of type ``t``.

        Ports: ``data`` (``Array[n, t]``), ``sel`` (``Bits[addr_width]``) and
        ``out`` (``t``).
        """
        name = 'Mux_{}t_{}n'.format(cleanName(str(t)), n)
        addr_width = getRAMAddrWidth(n)
        IO = [
            'data', In(Array[n, t]),
            'sel', In(Bits[addr_width]),
            'out', Out(t),
        ]

        @classmethod
        def definition(cls):
            """Wire a bit-level mux, or a plain pass-through when n <= 1."""
            if n <= 1:
                # Nothing to choose between: forward the only input and
                # terminate the select port so it is not left unconnected.
                wire(cls.data[0], cls.out)
                unused_sel = DefineTermAnyType(Bits[cls.addr_width])()
                wire(cls.sel, unused_sel.I)
                return

            bit_count = GetCoreIRBackend().get_type(t).size
            bit_mux = CommonlibMuxN(n, bit_count)

            # Flatten every candidate to bits before it enters the mux.
            flatten_inputs = DefineNativeMapParallel(n, DefineDehydrate(t))()
            wire(cls.data, flatten_inputs.I)
            wire(flatten_inputs.out, bit_mux.I.data)

            # Rebuild the selected bits into a value of type t.
            rebuild = Hydrate(t)
            wire(bit_mux.out, rebuild.I)
            wire(rebuild.out, cls.out)

            wire(cls.sel, bit_mux.I.sel)
Ejemplo n.º 4
0
    class _BitonicSort(Circuit):
        """Bitonic sorter for ``n`` elements of type ``T``, any ``n``.

        Wraps a power-of-two sorting network; inputs of the network beyond
        ``n`` are driven with an all-ones constant and the matching outputs
        are terminated.
        """
        name = "BitonicSort_t{}_n{}".format(cleanName(str(T)), str(n))
        IO = ['I', In(Array[n, T]), 'O', Out(Array[n, T])]

        @classmethod
        def definition(BitonicSort):
            """Instantiate the power-of-two network and pad the unused lanes."""
            elem_bits = T.size()
            padded_n = pow(2, ceil(log2(n)))
            if padded_n > n:
                # Padding value: every bit set, i.e. the maximum bit pattern.
                all_ones_flat = DefineCoreirConst(elem_bits,
                                                  pow(2, elem_bits) - 1)()
                all_ones = Hydrate(T)
                wire(all_ones_flat.O, all_ones.I)

            network = DefineBitonicSortPow2(T, padded_n, cmp_component)()

            # Real lanes connect straight through to this circuit's ports.
            for lane in range(n):
                wire(BitonicSort.I[lane], network.I[lane])
                wire(BitonicSort.O[lane], network.O[lane])

            # Padding lanes: constant in, terminated out.
            for lane in range(n, padded_n):
                wire(all_ones.out, network.I[lane])
                sink = TermAnyType(T)
                wire(sink.I, network.O[lane])
Ejemplo n.º 5
0
    class _RAM(Circuit):
        """RAM with ``n`` entries of type ``t`` on top of a bit-level RAM.

        Ports: ``RADDR``/``RDATA`` for reads, ``WADDR``/``WDATA``/``WE`` for
        writes, plus the ports produced by ``ClockInterface()``.
        """
        name = 'RAM_{}t_{}n'.format(cleanName(str(t)), n)
        addr_width = getRAMAddrWidth(n)
        IO = [
            'RADDR', In(Bits[addr_width]),
            'RDATA', Out(t),
            'WADDR', In(Bits[addr_width]),
            'WDATA', In(t),
            'WE', In(Bit),
        ] + ClockInterface()

        @classmethod
        def definition(cls):
            """Instantiate the bit-level RAM and convert data on both sides."""
            word_bits = GetCoreIRBackend().get_type(t).size
            memory = DefineRAM(n, word_bits, read_latency=read_latency)()

            # Write path: typed value -> flat bits -> memory.
            flatten_write = Dehydrate(t)
            wire(cls.WDATA, flatten_write.I)
            wire(flatten_write.out, memory.WDATA)

            # Read path: memory -> flat bits -> typed value.
            rebuild_read = Hydrate(t)
            wire(memory.RDATA, rebuild_read.I)
            wire(rebuild_read.out, cls.RDATA)

            # Addresses and write enable need no conversion.
            wire(cls.RADDR, memory.RADDR)
            wire(memory.WADDR, cls.WADDR)
            wire(cls.WE, memory.WE)
Ejemplo n.º 6
0
    class _Sort2Elements(Circuit):
        """Two-element compare-and-swap on values of type ``T``.

        ``cmp_component`` extracts the part of each input that is compared;
        the comparison result drives a mux choosing between the inputs in
        their original order and in swapped order.
        """
        name = "Sort2Elements_T{}".format(cleanName(str(T)))
        IO = [
            'I0', In(T),
            'I1', In(T),
            'O0', Out(T),
            'O1', Out(T),
        ] + ClockInterface()

        @classmethod
        def definition(Sort2Elements):
            """Compare the two keys and route the inputs through a 2-way mux."""
            order_mux = DefineMuxAnyType(Array[2, T], 2)()

            key0 = cmp_component(Sort2Elements.I0)
            key1 = cmp_component(Sort2Elements.I1)

            # Flatten both keys to bits so they can feed the comparator.
            key0_flat = Dehydrate(type(key0))
            key1_flat = Dehydrate(type(key1))

            wire(key0_flat.I, key0)
            wire(key1_flat.I, key1)

            less_than = DefineCoreirUlt(key0_flat.out.N)()

            wire(less_than.I0, key0_flat.out)
            wire(less_than.I1, key1_flat.out)

            # The comparator output is the mux select. Mux input 1 holds the
            # inputs in their original order (I0, I1); mux input 0 holds them
            # swapped (I1, I0).
            wire(Sort2Elements.I0, order_mux.data[1][0])
            wire(Sort2Elements.I1, order_mux.data[1][1])
            wire(Sort2Elements.I0, order_mux.data[0][1])
            wire(Sort2Elements.I1, order_mux.data[0][0])
            wire(less_than.O, order_mux.sel[0])

            wire(order_mux.out[0], Sort2Elements.O0)
            wire(order_mux.out[1], Sort2Elements.O1)
Ejemplo n.º 7
0
    class _BitonicSortPow2(Circuit):
        """Bitonic sorting network for ``n`` elements of type ``T``.

        Built as ``int(log2(n))`` stages of bitonic mergers; the merge width
        doubles from one stage to the next.
        """
        name = "BitonicSortPow2_t{}_n{}".format(cleanName(str(T)), str(n))
        IO = ['I', In(Array[n, T]), 'O', Out(Array[n, T])]

        @classmethod
        def definition(BitonicSortPow2):
            """Chain the merge stages from the inputs to the outputs."""
            # stage_ports[0] holds the circuit inputs; stage_ports[s + 1]
            # holds the outputs of the mergers created in stage s.
            stage_ports = [[BitonicSortPow2.I[idx] for idx in range(n)]]
            for stage in range(int(log2(n))):
                merge_width = pow(2, stage + 1)
                feeding_ports = stage_ports[stage]
                produced_ports = []
                for merge_idx in range(n // merge_width):
                    merger = DefineBitonicMergePow2(T, merge_width,
                                                    cmp_component,
                                                    merge_idx)()
                    # Each merger consumes the next contiguous run of
                    # merge_width ports from the previous stage.
                    first_port = merge_idx * merge_width
                    for lane in range(merge_width):
                        wire(feeding_ports[first_port + lane], merger.I[lane])
                        produced_ports.append(merger.O[lane])
                stage_ports.append(produced_ports)

            for idx, port in enumerate(stage_ports[-1]):
                wire(port, BitonicSortPow2.O[idx])
Ejemplo n.º 8
0
    class _Down_S(Circuit):
        """Select element ``idx`` out of an ``ST_SSeq`` of ``n`` elements.

        The chosen element becomes the single element of the output; all
        other inputs are wired into a terminator.
        """
        name = "Down_S_n{}_sel{}_tEl{}_v{}".format(str(n), str(idx),
                                                   cleanName(str(elem_t)),
                                                   str(has_valid))
        binary_op = False
        st_in_t = [ST_SSeq(n, elem_t)]
        st_out_t = ST_SSeq(1, elem_t)
        IO = ['I', In(st_in_t[0].magma_repr()),
              'O', Out(st_out_t.magma_repr())] + ClockInterface(False, False)
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Forward input ``idx`` and terminate the remaining inputs."""
            if n > 1:
                # Sink sized for the n - 1 inputs that are not forwarded.
                discarded = TermAnyType(ST_SSeq(n - 1, elem_t).magma_repr())
            forwarded_count = 0
            for pos in range(len(cls.I)):
                if pos != idx:
                    # Inputs after the selected one shift down by one slot.
                    wire(cls.I[pos], discarded.I[pos - forwarded_count])
                    continue
                wire(cls.I[pos], cls.O[0])
                forwarded_count += 1
            if has_valid:
                wire(cls.valid_up, cls.valid_down)
Ejemplo n.º 9
0
    class _Const(Circuit):
        """Constant generator for the space-time type ``t``.

        The values in ``ts_values`` are stored in a LUT that is stepped
        through by a counter modulo ``t.time()``.
        """
        name = "Const_t{}_hasCE{}_hasReset{}_hasValid{}".format(cleanName(str(t)), str(has_ce),
                                                                str(has_reset), str(has_valid))
        IO = ['O', Out(t.magma_repr())] + ClockInterface(has_ce, has_reset)
        binary_op = False
        st_in_t = []
        st_out_t = t
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Build the enable signal, the value LUT and its address counter."""
            const_high = DefineCoreirConst(1, 1)().O[0]
            if delay != 0:
                # Hold the address counter until the initial delay has passed.
                startup = InitialDelayCounter(delay)
                wire(startup.CE, const_high)
                running = startup.valid
            else:
                running = const_high
            if has_ce:
                running = bit(cls.CE) & running

            value_lut = DefineLUTAnyType(t.magma_repr(), t.time(), ts_arrays_to_bits(ts_values))()
            position = AESizedCounterModM(t.time(), has_ce=True, has_reset=has_reset)

            wire(position.O, value_lut.addr)
            wire(cls.O, value_lut.data)
            wire(running, position.CE)

            if has_reset:
                wire(cls.RESET, position.RESET)
            if has_valid:
                # valid_up is not used by this circuit; terminate it and
                # report validity from the local enable instead.
                valid_up_sink = TermAnyType(Bit)
                wire(cls.valid_up, valid_up_sink.I)
                wire(running, cls.valid_down)
Ejemplo n.º 10
0
    class _ShiftT(Circuit):
        """Shift an ``ST_TSeq(n, i, elem_t)`` by ``shift_amount`` elements in time.

        Elements are stored in a RAM with ``shift_amount`` slots; the same
        slot is read and written on every enabled clock.
        """
        name = "Shift_t_n{}_i{}_amt{}_tEl{}__hasCE{}_hasReset{}_hasValid{}".format(
            str(n), str(i), str(shift_amount), cleanName(str(elem_t)),
            str(has_ce), str(has_reset), str(has_valid))
        binary_op = False
        st_in_t = [ST_TSeq(n, i, elem_t)]
        st_out_t = ST_TSeq(n, i, elem_t)
        IO = ['I', In(st_in_t[0].magma_repr()),
              'O', Out(st_out_t.magma_repr())] + ClockInterface(has_ce, has_reset)
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Wire the element buffer, its address counter and the enables."""
            # Combined enable: constant 1, gated by valid_up and CE if present.
            active = DefineCoreirConst(1, 1)().O[0]
            if has_valid:
                active = cls.valid_up & active
                wire(cls.valid_up, cls.valid_down)
            if has_ce:
                active = bit(cls.CE) & active

            buffer = DefineRAM_ST(elem_t,
                                  shift_amount,
                                  has_reset=has_reset)()

            # Reads and writes use the same location. The first pass over a
            # slot only has meaningful writes; the output of that pass is
            # undefined, so whatever is read there is acceptable.
            elem_progress = DefineNestedCounters(elem_t,
                                                 has_ce=True,
                                                 has_reset=has_reset)()
            # The slot counter ignores invalid clocks. After them, the next
            # iteration may start at a non-zero slot, which is harmless since
            # the counter simply wraps around.
            slot = AESizedCounterModM(shift_amount,
                                      has_ce=True,
                                      has_reset=has_reset)

            wire(slot.O, buffer.WADDR)
            wire(slot.O, buffer.RADDR)

            wire(active, buffer.WE)
            wire(active, buffer.RE)
            # Advance to the next slot only when the element counter signals
            # `last` on an enabled clock.
            wire(active & elem_progress.last, slot.CE)
            wire(active, elem_progress.CE)

            # The element counter's valid output is unused here; terminate it.
            progress_valid_sink = TermAnyType(Bit)
            wire(elem_progress.valid, progress_valid_sink.I)

            wire(cls.I, buffer.WDATA)
            wire(buffer.RDATA, cls.O)
            if has_reset:
                wire(buffer.RESET, cls.RESET)
                wire(slot.RESET, cls.RESET)
                wire(elem_progress.RESET, cls.RESET)
Ejemplo n.º 11
0
    class _SIPO(Circuit):
        """Terminator for a value of type ``t``.

        Has a single input port ``I`` and no outputs; the input is flattened
        to bits and wired into a bit-level ``Term``.
        """
        # NOTE(review): the class is called _SIPO although the circuit name
        # is 'Term_...'; left unchanged because code outside this view may
        # refer to the class by this name.
        name = 'Term_{}t'.format(cleanName(str(t)))
        IO = ['I', In(t)]

        @classmethod
        def definition(cls):
            """Flatten the input and feed it to the bit-level terminator."""
            bit_count = GetCoreIRBackend().get_type(t).size
            flatten = Dehydrate(t)
            sink = Term(bit_count)
            wire(cls.I, flatten.I)
            wire(flatten.out, sink.I)
Ejemplo n.º 12
0
    class _Up_T(Circuit):
        """Repeat one element in time: ``ST_TSeq(1, n+i-1)`` to ``ST_TSeq(n, i)``.

        The first element is passed straight to the output and stored in a
        single-entry RAM; later elements are read back from that RAM.
        """
        name = "Up_T_n{}_i{}_tEl{}_hasCE{}_hasReset{}_hasValid{}".format(str(n), str(i),
                                                                          cleanName(str(elem_t)),
                                                                          str(has_ce), str(has_reset),
                                                                          str(has_valid))
        binary_op = False
        st_in_t = [ST_TSeq(1, n+i-1, elem_t)]
        st_out_t = ST_TSeq(n, i, elem_t)
        IO = ['I', In(st_in_t[0].magma_repr()),
              'O', Out(st_out_t.magma_repr())] + ClockInterface(has_ce, has_reset)
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Wire the store, the two counters and the output selector."""
            # Combined enable: constant 1, gated by valid_up and CE if present.
            active = DefineCoreirConst(1, 1)().O[0]
            if has_valid:
                active = cls.valid_up & active
                wire(cls.valid_up, cls.valid_down)
            if has_ce:
                active = bit(cls.CE) & active

            store = DefineRAM_ST(elem_t, 1, has_reset=has_reset)()

            # elem_progress tracks the position inside one element; elem_index
            # counts elements modulo n + i and advances on `last`.
            elem_progress = DefineNestedCounters(elem_t, has_ce=True, has_reset=has_reset)()
            elem_index = AESizedCounterModM(n + i, has_ce=True, has_reset=has_reset)
            on_first_elem = Decode(0, elem_index.O.N)(elem_index.O)

            # The store has a single entry, so both addresses are constant 0.
            addr_zero = DefineCoreirConst(1, 0)().O
            wire(addr_zero, store.WADDR)
            wire(addr_zero, store.RADDR)

            # Write only while the first element is arriving; read always.
            wire(active & on_first_elem, store.WE)
            wire(active, store.RE)
            wire(active, elem_progress.CE)
            wire(active & elem_progress.last, elem_index.CE)

            # The element counter's valid output is unused here; terminate it.
            progress_valid_sink = TermAnyType(Bit)
            wire(elem_progress.valid, progress_valid_sink.I)

            wire(cls.I, store.WDATA)

            out_mux = DefineMuxAnyType(elem_t.magma_repr(), 2)()

            # During the first element the input goes directly to the output;
            # afterwards the stored copy is used.
            wire(on_first_elem, out_mux.sel[0])
            wire(store.RDATA, out_mux.data[0])
            wire(cls.I, out_mux.data[1])
            wire(out_mux.out, cls.O)

            if has_reset:
                wire(store.RESET, cls.RESET)
                wire(elem_progress.RESET, cls.RESET)
                wire(elem_index.RESET, cls.RESET)
Ejemplo n.º 13
0
    class _Passthrough(Circuit):
        """Connect an input of type ``t_in`` to an output of type ``t_out``.

        Both sides are flattened with ``get_nested_ports`` and wired pairwise.
        """
        name = "Passthrough_tIn{}_tOut{}".format(cleanName(str(t_in)),
                                                 str(cleanName(str(t_out))))
        IO = ['I', In(t_in.magma_repr()), 'O', Out(t_out.magma_repr())]
        binary_op = False
        st_in_t = [t_in]
        st_out_t = t_out
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Wire the flattened input ports to the flattened output ports."""
            cls.output_delay = 0
            sources = get_nested_ports(cls.I,
                                       num_nested_space_layers(t_in), [])
            sinks = get_nested_ports(cls.O,
                                     num_nested_space_layers(t_out), [])
            # zip stops at the shorter list, so surplus ports on either side
            # stay unwired.
            for source, sink in zip(sources, sinks):
                wire(source, sink)
            if has_valid:
                wire(cls.valid_up, cls.valid_down)
Ejemplo n.º 14
0
    class _AtomTupleCreator(Circuit):
        """Pack two inputs of types ``t0`` and ``t1`` into an ``ST_Atom_Tuple``."""
        name = "atomTupleCreator_t0{}_t1{}".format(cleanName(str(t0)),
                                                   cleanName(str(t1)))

        binary_op = True
        st_in_t = [t0, t1]
        st_out_t = ST_Atom_Tuple(t0, t1)
        IO = [
            "I0", In(st_in_t[0].magma_repr()),
            "I1", In(st_in_t[1].magma_repr()),
            "O", Out(st_out_t.magma_repr()),
        ]
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(atom_tuple_creator):
            """Wire I0 and I1 to positions 0 and 1 of the output tuple."""
            members = (atom_tuple_creator.I0, atom_tuple_creator.I1)
            for position, member in enumerate(members):
                wire(member, atom_tuple_creator.O[position])
            if has_valid:
                wire(atom_tuple_creator.valid_up,
                     atom_tuple_creator.valid_down)
Ejemplo n.º 15
0
    class _Register(Circuit):
        """Register for a value of arbitrary magma type ``t``.

        Arrays of bits use one multi-bit register, other arrays and tuples
        use one recursively generated register per element, and anything else
        uses a single 1-bit register.

        Ports: ``I`` and ``O`` of type ``t``, plus the ports produced by
        ``ClockInterface(has_ce, has_reset)``.
        """
        name = 'Register_{}t_{}init_{}CE_{}RESET'.format(
            cleanName(str(t)), str(init), str(has_ce), str(has_reset))
        IO = ['I', In(t), 'O', Out(t)] + \
                ClockInterface(has_ce,has_reset)

        @classmethod
        def definition(cls):
            """Instantiate the underlying register(s) and wire data and control."""
            # True when one multi-bit register holds the whole Array-of-Bit
            # value, so data is wired bit by bit into that single instance.
            nested = False
            if issubclass(type(t), ArrayKind):
                if type(t.T) == BitKind:
                    # Forward `init` like every other branch does. It was
                    # previously omitted here, so arrays of bits ignored the
                    # requested initial value even though the circuit name
                    # advertises it.
                    regs = [
                        DefineRegister(t.N, init, has_ce, has_reset)()
                    ]
                    nested = True
                else:
                    regs = [
                        DefineRegisterAnyType(t.T, init, has_ce, has_reset)()
                        for _ in range(t.N)
                    ]
            elif issubclass(type(t), TupleKind):
                regs = [
                    DefineRegisterAnyType(t_inner, init, has_ce, has_reset)()
                    for t_inner in t.Ts
                ]
            else:
                regs = [DefineRegister(1, init, has_ce, has_reset)()]

            if nested:
                word_reg = regs[0]
                for i in range(t.N):
                    wire(cls.I[i], word_reg.I[i])
                    wire(cls.O[i], word_reg.O[i])
                if has_ce:
                    wire(cls.CE, word_reg.CE)
                if has_reset:
                    wire(cls.RESET, word_reg.RESET)
            else:
                for i, reg in enumerate(regs):
                    if type(t) == BitKind:
                        # A single Bit maps onto bit 0 of the 1-bit register.
                        wire(cls.I, reg.I[0])
                        wire(cls.O, reg.O[0])
                    else:
                        wire(cls.I[i], reg.I)
                        wire(cls.O[i], reg.O)
                    if has_ce:
                        wire(cls.CE, reg.CE)
                    if has_reset:
                        wire(cls.RESET, reg.RESET)
Ejemplo n.º 16
0
 class _Up_S(Circuit):
     """Broadcast the single element of an ``ST_SSeq(1)`` to ``n`` outputs."""
     name = "Up_S_n{}_tEl{}_v{}".format(str(n), cleanName(str(elem_t)), str(has_valid))
     binary_op = False
     st_in_t = [ST_SSeq(1, elem_t)]
     st_out_t = ST_SSeq(n, elem_t)
     IO = ['I', In(st_in_t[0].magma_repr()), 'O', Out(st_out_t.magma_repr())] + \
          ClockInterface(False, False)
     if has_valid:
         IO += valid_ports

     @classmethod
     def definition(cls):
         """Wire input element 0 to every output element."""
         for i in range(n):
             # Just make the connection. The result of wire() used to be
             # stored in cls.wire, which replaced the class's own `wire`
             # attribute with that return value.
             wire(cls.I[0], cls.O[i])
         if has_valid:
             wire(cls.valid_up, cls.valid_down)
Ejemplo n.º 17
0
    class _RAM_ST(Circuit):
        """RAM holding ``n`` values of the space-time type ``t``.

        Uses one inner RAM per stored value, each with ``t.valid_clocks()``
        entries. ``WADDR``/``RADDR`` choose the inner RAM; the position inside
        it comes from two nested counters advanced by ``WE`` and ``RE``.
        """
        # NOTE(review): the name encodes neither n nor read_latency - confirm
        # that instances differing only in those cannot clash.
        name = 'RAM_ST_{}_hasReset{}'.format(cleanName(str(t)), str(has_reset))
        addr_width = getRAMAddrWidth(n)
        IO = [
            'RADDR', In(Bits[addr_width]),
            'RDATA', Out(t.magma_repr()),
            'WADDR', In(Bits[addr_width]),
            'WDATA', In(t.magma_repr()),
            'WE', In(Bit),
            'RE', In(Bit),
        ] + ClockInterface(has_ce=False, has_reset=has_reset)

        @classmethod
        def definition(cls):
            """Instantiate the inner RAMs, the position counters and the read mux."""
            # Each valid clock delivers one magma_repr value, which is read
            # from or written to its own location in an inner RAM.
            banks = [DefineRAMAnyType(t.magma_repr(), t.valid_clocks(), read_latency=read_latency)() for _ in range(n)]
            read_pos = DefineNestedCounters(t, has_cur_valid=True, has_ce=True, has_reset=has_reset)()
            read_valid_sink = TermAnyType(Bit)
            read_last_sink = TermAnyType(Bit)
            write_pos = DefineNestedCounters(t, has_cur_valid=True, has_ce=True, has_reset=has_reset)()
            write_valid_sink = TermAnyType(Bit)
            write_last_sink = TermAnyType(Bit)
            read_mux = DefineMuxAnyType(t.magma_repr(), n)()

            for bank_idx, bank in enumerate(banks):
                wire(cls.WDATA, bank.WDATA)
                wire(write_pos.cur_valid, bank.WADDR)
                wire(read_mux.data[bank_idx], bank.RDATA)
                wire(read_pos.cur_valid, bank.RADDR)
                # Only the bank addressed by WADDR is written, and only while
                # the write counter reports a valid clock.
                bank_selected = Decode(bank_idx, cls.WADDR.N)(cls.WADDR)
                wire(bank_selected & write_pos.valid, bank.WE)

            wire(cls.RADDR, read_mux.sel)
            wire(cls.RDATA, read_mux.out)

            wire(cls.WE, write_pos.CE)
            wire(cls.RE, read_pos.CE)

            # Counter outputs that are otherwise unused are terminated.
            wire(read_pos.valid, read_valid_sink.I)
            wire(read_pos.last, read_last_sink.I)
            wire(write_pos.valid, write_valid_sink.I)
            wire(write_pos.last, write_last_sink.I)

            if has_reset:
                wire(cls.RESET, write_pos.RESET)
                wire(cls.RESET, read_pos.RESET)
Ejemplo n.º 18
0
    class _Partition_S(Circuit):
        """Regroup a flat ``ST_SSeq(no * ni)`` into ``no`` sequences of ``ni``."""
        name = "Partition_S_no{}_ni{}_tEl{}_v{}".format(
            str(no), str(ni), cleanName(str(elem_t)), str(has_valid))
        binary_op = False
        st_in_t = [ST_SSeq(no * ni, elem_t)]
        st_out_t = ST_SSeq(no, ST_SSeq(ni, elem_t))
        IO = ['I', In(st_in_t[0].magma_repr()),
              'O', Out(st_out_t.magma_repr())] + ClockInterface(False, False)
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Wire flat input ``k`` to output ``[k // ni][k % ni]``."""
            for flat_idx in range(no * ni):
                outer, inner = divmod(flat_idx, ni)
                wire(cls.I[flat_idx], cls.O[outer][inner])
            if has_valid:
                wire(cls.valid_up, cls.valid_down)
Ejemplo n.º 19
0
    class _ShiftS(Circuit):
        """Rotate an ``ST_SSeq`` of ``n`` elements by ``shift_amount`` positions."""
        name = "Shift_S_n{}_amt{}_tEl{}_hasValid{}".format(
            str(n), str(shift_amount), cleanName(str(elem_t)), str(has_valid))
        binary_op = False
        st_in_t = [ST_SSeq(n, elem_t)]
        st_out_t = ST_SSeq(n, elem_t)
        IO = ['I', In(st_in_t[0].magma_repr()),
              'O', Out(st_out_t.magma_repr())] + ClockInterface(False, False)
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Wire input ``k`` to output ``(k + shift_amount) % n``."""
            for source in range(n):
                # The first shift_amount outputs are undefined for a shift,
                # so wrapping the tail of the input around to them is fine.
                target = (source + shift_amount) % n
                wire(cls.I[source], cls.O[target])
            if has_valid:
                wire(cls.valid_up, cls.valid_down)
Ejemplo n.º 20
0
    class _ROM(Circuit):
        """ROM with ``n`` entries of type ``t`` on top of a bit-level ROM.

        Ports: ``RADDR``, ``RDATA`` and ``REN``, plus the ports produced by
        ``ClockInterface()``.
        """
        name = 'ROM_{}t_{}n'.format(cleanName(str(t)), n)
        addr_width = getRAMAddrWidth(n)
        IO = [
            'RADDR', In(Bits[addr_width]),
            'RDATA', Out(t),
            'REN', In(Bit),
        ] + ClockInterface()

        @classmethod
        def definition(cls):
            """Instantiate the ROM with its contents and rebuild typed output."""
            word_bits = GetCoreIRBackend().get_type(t).size
            rom_generator = DefineROM(n, word_bits)
            # The contents are handed over as a list of lists.
            contents = [list(entry) for entry in init]
            rom = rom_generator(coreir_configargs={"init": contents})

            # Read path: flat bits from the ROM -> value of type t.
            rebuild = Hydrate(t)
            wire(rom.rdata, rebuild.I)
            wire(rebuild.out, cls.RDATA)

            wire(cls.RADDR, rom.raddr)
            wire(cls.REN, rom.ren)
Ejemplo n.º 21
0
    class _TupleToSSeq(Circuit):
        """Convert an ``ST_SSeq_Tuple(n, t)`` into an ``ST_SSeq(n, t)``.

        The input port is wired directly to the output port.
        """
        name = "stupleToSSeq_t{}_n{}".format(cleanName(str(t)), str(n))

        binary_op = False
        st_in_t = [ST_SSeq_Tuple(n, t)]
        st_out_t = ST_SSeq(n, t)
        IO = ["I", In(st_in_t[0].magma_repr()),
              "O", Out(st_out_t.magma_repr())]
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Connect I to O, and valid_up to valid_down when present."""
            wire(cls.I, cls.O)
            if not has_valid:
                return
            wire(cls.valid_up, cls.valid_down)
Ejemplo n.º 22
0
    class _ShiftTS(Circuit):
        """Shift a ``TSeq(no, io, SSeq(ni, elem_t))`` by ``shift_amount`` elements.

        Lanes are rotated in space; lanes whose data wraps into a later clock
        additionally pass through a time shifter.
        """
        name = "Shift_ts_no{}_io{}_ni{}_amt{}_tEl{}__hasCE{}_hasReset{}_hasValid{}".format(
            str(no), str(io), str(ni),
            str(shift_amount), cleanName(str(elem_t)), str(has_ce),
            str(has_reset), str(has_valid))
        binary_op = False
        st_in_t = [ST_TSeq(no, io, ST_SSeq(ni, elem_t))]
        st_out_t = ST_TSeq(no, io, ST_SSeq(ni, elem_t))
        IO = ['I', In(st_in_t[0].magma_repr()),
              'O', Out(st_out_t.magma_repr())] + ClockInterface(has_ce, has_reset)
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Create the per-lane time shifters and wire the rotated lanes."""
            # Combined enable: constant 1, gated by valid_up and CE if present.
            active = DefineCoreirConst(1, 1)().O[0]
            if has_valid:
                active = cls.valid_up & active
                wire(cls.valid_up, cls.valid_down)
            if has_ce:
                active = bit(cls.CE) & active

            # The time shifters are built without valid ports; the enable
            # signal carries that information to them through CE.
            lane_shifters = []
            for lane in range(ni):
                # Number of whole clocks this output lane is delayed by.
                lane_delay = (ni - lane + shift_amount - 1) // ni
                shifter = None
                if lane_delay != 0:
                    shifter = DefineShift_T(no, io, lane_delay, elem_t, True,
                                            has_reset, False)()
                lane_shifters.append(shifter)

            for lane, shifter in enumerate(lane_shifters):
                source_lane = (lane - shift_amount) % ni
                if shifter is None:
                    # No delay in time: the lane is only moved in space.
                    wire(cls.I[source_lane], cls.O[lane])
                    continue
                wire(cls.I[source_lane], shifter.I)
                wire(shifter.O, cls.O[lane])
                wire(active, shifter.CE)
                if has_reset:
                    wire(cls.RESET, shifter.RESET)
Ejemplo n.º 23
0
    class _Snd(Circuit):
        """Return the second component of a tuple of type ``t``.

        The first component is wired into a terminator.
        """
        name = "snd_t{}".format(cleanName(str(t)))

        binary_op = False
        st_in_t = [t]
        st_out_t = t.t1
        IO = ["I", In(st_in_t[0].magma_repr()),
              "O", Out(st_out_t.magma_repr())]
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Terminate I[0] and forward I[1] to O."""
            first_sink = TermAnyType(t.t0.magma_repr())
            wire(cls.I[0], first_sink.I)
            wire(cls.I[1], cls.O)
            if not has_valid:
                return
            wire(cls.valid_up, cls.valid_down)
Ejemplo n.º 24
0
    class _Eq(Circuit):
        """Equality test on a pair of atoms of type ``t``.

        Port ``I`` is a tuple of two ``t`` values; port ``O`` is an ``ST_Bit``
        carrying the output of the equality operator.
        """
        name = "Eq_Atom_{}t".format(cleanName(str(t)))
        binary_op = False
        # The type metadata is derived from `t` so that it always matches the
        # ports below. It was previously hard-coded to a pair of ST_Int, which
        # disagreed with the ports for any other `t`.
        st_in_t = [ST_Atom_Tuple(t, t)]
        st_out_t = ST_Bit()
        IO = [
            'I',
            In(st_in_t[0].magma_repr()), 'O',
            Out(st_out_t.magma_repr())
        ]
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Compare I[0] with I[1] and drive bit 0 of O with the result."""
            # The comparator width is the size of t's magma representation.
            op = DefineEQ(t.magma_repr().size())()
            wire(cls.I[0], op.I0)
            wire(cls.I[1], op.I1)
            wire(op.O, cls.O[0])
            if has_valid:
                wire(cls.valid_up, cls.valid_down)
Ejemplo n.º 25
0
    class _SeqTupleCreator(Circuit):
        """Pack two inputs of type ``t`` into an ``ST_SSeq_Tuple(2, t)``."""
        name = "sseqTupleCreator_t{}".format(cleanName(str(t)))
        binary_op = True
        st_in_t = [t, t]
        st_out_t = ST_SSeq_Tuple(2, t)

        IO = [
            "I0", In(t.magma_repr()),
            "I1", In(t.magma_repr()),
            "O", Out(st_out_t.magma_repr()),
        ]
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(sseq_tuple_creator):
            """Wire I0 and I1 to positions 0 and 1 of the output."""
            members = (sseq_tuple_creator.I0, sseq_tuple_creator.I1)
            for position, member in enumerate(members):
                wire(member, sseq_tuple_creator.O[position])
            if has_valid:
                wire(sseq_tuple_creator.valid_up,
                     sseq_tuple_creator.valid_down)
Ejemplo n.º 26
0
    class _If(Circuit):
        """Two-way select on atoms of type ``t``.

        The input is a tuple ``(condition bit, (a, b))`` and the output is one
        ``t`` chosen by a 2-input mux.
        """
        name = "If_Atom_{}t".format(cleanName(str(t)))
        binary_op = False
        st_in_t = [ST_Atom_Tuple(ST_Bit(), ST_Atom_Tuple(t, t))]
        st_out_t = t
        IO = ['I', In(st_in_t[0].magma_repr()),
              'O', Out(st_out_t.magma_repr())]
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Drive the mux select from I[0] and its data from the pair I[1]."""
            chooser = DefineMuxAnyType(t.magma_repr(), 2)()
            wire(cls.I[0], chooser.sel)
            # The first value of the pair feeds mux input 1, the second value
            # feeds mux input 0.
            wire(cls.I[1][0], chooser.data[1])
            wire(cls.I[1][1], chooser.data[0])
            wire(chooser.out, cls.O)
            if not has_valid:
                return
            wire(cls.valid_up, cls.valid_down)
Ejemplo n.º 27
0
    class _SeqTupleAppender(Circuit):
        """Append one ``t`` to an ``ST_SSeq_Tuple(n, t)``, giving ``n + 1`` items."""
        name = "sseqTupleAppender_t{}_n{}".format(cleanName(str(t)), str(n))

        binary_op = True
        st_in_t = [ST_SSeq_Tuple(n, t), t]
        st_out_t = ST_SSeq_Tuple(n + 1, t)
        IO = [
            "I0", In(st_in_t[0].magma_repr()),
            "I1", In(st_in_t[1].magma_repr()),
            "O", Out(st_out_t.magma_repr()),
        ]
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(sseq_tuple_appender):
            """Copy the ``n`` items of I0 to O and place I1 at position ``n``."""
            for position in range(n):
                wire(sseq_tuple_appender.I0[position],
                     sseq_tuple_appender.O[position])
            wire(sseq_tuple_appender.I1, sseq_tuple_appender.O[n])
            if has_valid:
                wire(sseq_tuple_appender.valid_up,
                     sseq_tuple_appender.valid_down)
Ejemplo n.º 28
0
    class _Down_T(Circuit):
        """Keep element ``idx`` of a ``TSeq(n, i)``, giving a ``TSeq(1, n+i-1)``.

        Data passes straight through; the selection is expressed by delaying
        the valid signal by ``elem_t.time() * idx`` clocks.
        """
        # NOTE(review): the prefix says "Down_S" although this is the Down_T
        # circuit; left as is because changing it renames the emitted module.
        name = "Down_S_n{}_i{}_sel{}_tEl{}_v{}".format(str(n), str(i),
                                                       str(idx),
                                                       cleanName(str(elem_t)),
                                                       str(has_valid))
        binary_op = False
        st_in_t = [ST_TSeq(n, i, elem_t)]
        st_out_t = ST_TSeq(1, n + i - 1, elem_t)
        IO = ['I', In(st_in_t[0].magma_repr()),
              'O', Out(st_out_t.magma_repr())] + ClockInterface(False, False)
        if has_valid:
            IO += valid_ports

        @classmethod
        def definition(cls):
            """Pass the data through and, if present, delay the valid signal."""
            wire(cls.I, cls.O)
            if not has_valid:
                return
            clocks_before_selected = elem_t.time() * idx
            if clocks_before_selected == 0:
                # The selected element comes first: valid needs no delay.
                wire(cls.valid_up, cls.valid_down)
            else:
                delay_counter = InitialDelayCounter(clocks_before_selected)
                wire(cls.valid_up, delay_counter.CE)
                wire(delay_counter.valid, cls.valid_down)
Ejemplo n.º 29
0
    class _BitonicMergePow2(Circuit):
        """Recursive bitonic merger for ``n`` elements of type ``T``.

        One layer of compare-and-swap pairs element ``k`` with element
        ``n // 2 + k``; each half is then merged by a merger of half the size.
        ``ith_merge`` decides the direction of the swaps.
        """
        name = "BitonicMergePow2_t{}_n{}_ithMerge{}".format(
            cleanName(str(T)), str(n), str(ith_merge))
        IO = ['I', In(Array[n, T]), 'O', Out(Array[n, T])]

        @classmethod
        def definition(BitonicMergePow2):
            """Build the compare layer and recurse on both halves."""
            half = n // 2
            # Odd-numbered merges take the pair outputs in reverse order.
            reverse_pairs = ith_merge % 2 != 0

            # Outputs of the first compare layer, indexed like the inputs.
            compared = [0] * n
            for lane in range(half):
                pair_sort = DefineSort2Elements(T, cmp_component)()
                wire(BitonicMergePow2.I[lane], pair_sort.I0)
                wire(BitonicMergePow2.I[half + lane], pair_sort.I1)
                if reverse_pairs:
                    lower, upper = pair_sort.O1, pair_sort.O0
                else:
                    lower, upper = pair_sort.O0, pair_sort.O1
                compared[lane] = lower
                compared[half + lane] = upper

            if n == 2:
                # Base case: the single compare-and-swap is the whole merge.
                wire(compared[0], BitonicMergePow2.O[0])
                wire(compared[1], BitonicMergePow2.O[1])
                return

            # Recursive case: merge the lower and the upper half separately.
            lower_merger = DefineBitonicMergePow2(T, half, cmp_component,
                                                  ith_merge)()
            upper_merger = DefineBitonicMergePow2(T, half, cmp_component,
                                                  ith_merge)()
            for lane in range(half):
                wire(compared[lane], lower_merger.I[lane])
                wire(compared[half + lane], upper_merger.I[lane])
                wire(lower_merger.O[lane], BitonicMergePow2.O[lane])
                wire(upper_merger.O[lane], BitonicMergePow2.O[half + lane])
Ejemplo n.º 30
0
    class _LB(Circuit):
        # Two-dimensional line buffer circuit. It wraps an
        # AnyDimensionalLineBuffer and, when stride_rows is larger than
        # rows_of_pixels_per_clock, routes the produced windows through a
        # DelayedBuffer before they reach O.
        #
        # The class body first validates the generator parameters captured
        # from the enclosing scope. Every unsupported combination raises
        # Exception with a message naming the offending parameters, their
        # values and the reason for the restriction. After validation the
        # circuit name, derived rates and IO are computed.

        if image_cols % pixels_per_row_per_clock != 0:
            reason = """
            this is necessary so that input a complete row before getting 
            the input for the next row. This means don't have a clock cycle
            where input pixels are from two different rows in the input image.
            """
            raise Exception("Aetherling's Native LineBuffer has invalid "
                            "parameters: image_cols {} not divisible by "
                            "pixels_per_row_per_clock {}. \n Reason: {}".format(
                                image_cols, pixels_per_row_per_clock, reason))

        if image_rows % rows_of_pixels_per_clock != 0:
            reason = """
            this is necessary so that input a complete image with the same number 
            of input pixels every clock, don't have a weird ending
            with only 1 pixel from the image left to input
            """
            # the message must name rows_of_pixels_per_clock (the value that
            # is actually checked) and must include the reason
            raise Exception("Aetherling's Native LineBuffer has invalid "
                            "parameters: image_rows {} not divisible by "
                            "rows_of_pixels_per_clock {}. \n Reason: {}".format(
                                image_rows, rows_of_pixels_per_clock, reason))

        if image_cols % stride_cols != 0:
            reason = "stride_cols is downsample factor for number of columns. " \
                     "This requirement ensures that the column downsample factor " \
                     "cleanly divides the image's number of columns"
            raise Exception("Aetherling's Native LineBuffer has invalid "
                            "parameters: image_cols {} not divisible by "
                            "stride_cols {}. \n Reason: {}".format(
                                image_cols, stride_cols, reason))

        if image_rows % stride_rows != 0:
            reason = "stride_rows is downsample factor for number of rows. " \
                     "This requirement ensures that the row downsample factor " \
                     "cleanly divides the image's number of rows"
            raise Exception("Aetherling's Native LineBuffer has invalid "
                            "parameters: image_rows {} not divisible by "
                            "stride_rows {}. \n Reason: {}".format(
                                image_rows, stride_rows, reason))

        if (stride_cols % pixels_per_row_per_clock != 0
                and pixels_per_row_per_clock % stride_cols != 0):
            reason = """
            the average number of output windows per row per clock =
                pixels per row per clock / stride cols per clock 
            Number of output windows per row per clock must be integer or 
            reciprocal of one so that the position of output windows relative
            to the new pixels each clock is constant. 

            Otherwise there will be different numbers of new pixels in
            each window in each clock. This is challenging to implement.
            """
            raise Exception(
                "Aetherling's Native LineBuffer has invalid "
                "parameters: stride_cols {} not divisible by "
                "pixels_per_row_per_clock {} nor vice-verse. One of them must "
                "be divisible by the other. \n Reason: {}".format(
                    stride_cols, pixels_per_row_per_clock, reason))

        if (stride_rows % rows_of_pixels_per_clock != 0
                and rows_of_pixels_per_clock % stride_rows != 0):
            reason = """
            the average number of rows of output windows per clock =
                rows of pixels per clock / stride rows per clock 
            Number of rows of output windows per clock must be integer or 
            reciprocal of one so that the position of output windows relative
            to the new pixels each clock is constant. 

            Otherwise there will be different numbers of new pixels in
            each window in each clock. This is challenging to implement.
            """
            raise Exception(
                "Aetherling's Native LineBuffer has invalid "
                "parameters: stride_rows {} not divisible by "
                "rows_of_pixels_per_clock {} nor vice-verse. One of them must "
                "be divisible by the other. \n Reason: {}".format(
                    stride_rows, rows_of_pixels_per_clock, reason))

        if ((stride_cols * stride_rows) %
            (pixels_per_row_per_clock * rows_of_pixels_per_clock) != 0) and \
                ((pixels_per_row_per_clock * rows_of_pixels_per_clock) %
                 (stride_cols * stride_rows) != 0):
            reason = """
            the average number of output windows per clock =
                (pixels per row per clock * rows of pixels per clock) /
                (stride cols per clock * stride rows per clock)
            Number of output windows per clock must be integer or 
            reciprocal of one so that throughput is an easier factor to manipulate 
            with map/underutil.
            
            Otherwise throughput is a weird fraction and the downstream system is
            either only partially used on some clocks or the sequence length
            is multiplied by a weird factor that makes the rational number
            throughput become an integer.
            """
            # one placeholder per argument: previously the reason was dropped
            # and rows_of_pixels_per_clock was printed in its place
            raise Exception(
                "Aetherling's Native LineBuffer has invalid "
                "parameters: stride_cols {} * stride_rows {} not divisible by "
                "pixels_per_row_per_clock {} * rows_of_pixels_per_clock {} "
                "nor vice-verse. One of them must "
                "be divisible by the other. \n Reason: {}".format(
                    stride_cols, stride_rows, pixels_per_row_per_clock,
                    rows_of_pixels_per_clock, reason))

        if abs(origin_cols) >= window_cols:
            reason = """
            origin must be less than window. If abs(origin_cols) was greater
            than window_cols, then entire first window would be garbage         
            """
            raise Exception("Aetherling's Native LineBuffer has invalid "
                            "parameters: |origin_cols| {} greater than or "
                            "equal to window_cols {}. \n Reason: {}".format(
                                abs(origin_cols), window_cols, reason))

        if abs(origin_rows) >= window_rows:
            reason = """
            origin must be less than window. If abs(origin_rows) was greater
            than window_rows, then entire first window would be garbage         
            """
            raise Exception("Aetherling's Native LineBuffer has invalid "
                            "parameters: |origin_rows| {} greater than or "
                            "equal to window_rows {}. \n Reason: {}".format(
                                abs(origin_rows), window_rows, reason))

        if origin_cols > 0:
            reason = """
            origin_cols can't go into image. That would be cropping the first cols of the image
            and linebuffer doesn't do cropping.
            """
            raise Exception(
                "Aetherling's Native LineBuffer has invalid "
                "parameters: origin_cols {} greater than 0. \n Reason: {}".
                format(origin_cols, reason))

        if origin_rows > 0:
            reason = """
            origin_rows can't go into image. That would be cropping the first rows of the image
            and linebuffer doesn't do cropping.
            """
            raise Exception(
                "Aetherling's Native LineBuffer has invalid "
                "parameters: origin_rows {} greater than 0. \n Reason: {}".
                format(origin_rows, reason))

        if window_cols - origin_cols >= image_cols:
            reason = """
            need window_cols plus abs(origin_cols) outputs to do wiring.
            If the image_cols is smaller than this, will have issues with
            internal wiring. Additionally, the linebuffer isn't
            used for images that are small enough to be processed
            in one or a few clock cycles. This is a weird edge 
            case that I don't want to deal with and shouldn't occur
            in the real world.
            """
            raise Exception(
                "Aetherling's Native LineBuffer has invalid "
                "parameters: window_cols {} - origin_cols {} "
                "greater than or equal to image_cols {}. \n Reason: {}".format(
                    window_cols, origin_cols, image_cols, reason))

        if window_rows - origin_rows >= image_rows:
            reason = """
            need window_rows plus abs(origin_rows) outputs to do wiring.
            If the image_rows is smaller than this, will have issues with
            internal wiring. Additionally, the linebuffer isn't
            used for images that are small enough to be processed
            in one or a few clock cycles. This is a weird edge 
            case that I don't want to deal with and shouldn't occur
            in the real world.
            """
            raise Exception(
                "Aetherling's Native LineBuffer has invalid "
                "parameters: window_rows {} - origin_rows {} "
                "greater than or equal to image_rows {}. \n Reason: {}".format(
                    window_rows, origin_rows, image_rows, reason))

        name = "TwoDimensionalLineBuffer_{}type_{}x{}pxPerClock_{}x{}window" \
               "_{}x{}img_{}x{}stride_{}x{}origin".format(
            cleanName(str(pixel_type)),
            pixels_per_row_per_clock,
            rows_of_pixels_per_clock,
            window_cols,
            window_rows,
            image_cols,
            image_rows,
            stride_cols,
            stride_rows,
            origin_cols,
            origin_rows
        )

        # if pixel_per_clock greater than stride, emitting that many new windows per clock
        # else just emit one per clock when have enough pixels to do so
        # A buffer makes sure that windows come out at a constant rate, not more than
        # one per clock even if the overall rate is not greater than 1
        windows_per_active_clock = max(
            (rows_of_pixels_per_clock * pixels_per_row_per_clock) //
            (stride_rows * stride_cols), 1)

        # buffered cycle is length of time to collect and emit windows to get an
        # even rate
        windows_per_row_per_clock = max(
            pixels_per_row_per_clock // stride_cols, 1)
        rows_of_windows_per_clock = max(
            rows_of_pixels_per_clock // stride_rows, 1)
        time_per_buffered_cycle = (
            (image_cols * stride_rows) //
            (rows_of_pixels_per_clock * pixels_per_row_per_clock))

        if add_debug_interface:
            # extra outputs exposing the line buffer output before the
            # delayed buffer, the delayed buffer's CE/WE, and whatever ports
            # GetDBDebugInterface declares
            debug_interface = [
                'undelayedO',
                Out(Array[windows_per_active_clock,
                          Array[window_rows, Array[window_cols,
                                                   Out(pixel_type)]]]), 'dbCE',
                Out(Bit), 'dbWE',
                Out(Bit)
            ] + GetDBDebugInterface(
                Array[window_rows, Array[window_cols, pixel_type]],
                image_cols // stride_cols,
                max(pixels_per_row_per_clock // stride_cols, 1),
            )
        else:
            debug_interface = []
        IO = [
            'I',
            In(Array[rows_of_pixels_per_clock, Array[pixels_per_row_per_clock,
                                                     In(pixel_type)]]), 'O',
            Out(Array[windows_per_active_clock,
                      Array[window_rows, Array[window_cols,
                                               Out(pixel_type)]]]), 'valid',
            Out(Bit), 'ready',
            Out(Bit)
        ] + debug_interface + ClockInterface(has_ce=True)

        @classmethod
        def definition(cls):
            """Instantiate the underlying line buffer and wire it to the ports.

            When stride_rows <= rows_of_pixels_per_clock, the line buffer's
            2D grid of windows is flattened row-major onto O and valid comes
            directly from the line buffer. Otherwise the windows are written
            into a DelayedBuffer whose O and valid drive this circuit's
            outputs. In both cases CE is forwarded to the line buffer and
            ready is tied to constant 1.
            """
            lb = AnyDimensionalLineBuffer(
                pixel_type,
                [rows_of_pixels_per_clock, pixels_per_row_per_clock],
                [window_rows, window_cols], [image_rows, image_cols],
                [stride_rows, stride_cols], [origin_rows, origin_cols])
            wire(cls.I, lb.I)
            if stride_rows <= rows_of_pixels_per_clock:
                # NOTE(review): when add_debug_interface is set, this branch
                # leaves the debug outputs undriven -- confirm that the debug
                # interface is only meant to be used with the buffered branch.
                for row_of_windows in range(cls.rows_of_windows_per_clock):
                    for window_per_row in range(cls.windows_per_row_per_clock):
                        wire(
                            cls.O[row_of_windows *
                                  cls.windows_per_row_per_clock +
                                  window_per_row],
                            lb.O[row_of_windows][window_per_row])
                wire(cls.valid, lb.valid)

            else:
                if add_debug_interface:
                    # expose the raw line buffer windows next to the delayed
                    # ones
                    for row_of_windows in range(cls.rows_of_windows_per_clock):
                        for window_per_row in range(
                                cls.windows_per_row_per_clock):
                            wire(
                                cls.undelayedO[row_of_windows *
                                               cls.windows_per_row_per_clock +
                                               window_per_row],
                                lb.O[row_of_windows][window_per_row])
                db = DelayedBuffer(Array[window_rows, Array[window_cols,
                                                            pixel_type]],
                                   image_cols // stride_cols,
                                   max(pixels_per_row_per_clock // stride_cols,
                                       1),
                                   cls.time_per_buffered_cycle,
                                   add_debug_interface=add_debug_interface)
                for row_of_windows in range(cls.rows_of_windows_per_clock):
                    for window_per_row in range(cls.windows_per_row_per_clock):
                        wire(
                            db.I[row_of_windows * cls.windows_per_row_per_clock
                                 + window_per_row],
                            lb.O[row_of_windows][window_per_row])
                wire(lb.valid, db.WE)
                wire(db.valid, cls.valid)
                wire(db.O, cls.O)

                # first time lb is valid, delayed buffer becomes
                # valid permanently
                first_valid_counter = SizedCounterModM(2, has_ce=True)
                zero_const = DefineCoreirConst(1, 0)()
                # the counter is only enabled while it still reads 0
                wire(lb.valid & (zero_const.O == first_valid_counter.O),
                     first_valid_counter.CE)
                # delay the CE of the delayed buffer by one register, as the
                # LB output will hit the DB one clock later, so give the DB
                # that CE
                delayed_ce_for_db_valid = DefineRegisterAnyType(Bit)()
                wire(bit(cls.CE), delayed_ce_for_db_valid.I)
                #ce_or_last_valid = bit(cls.CE) | (lb.valid & ~last_clock_lb_valid.O)
                # need lb.valid or counter as lb.valid will be 1 on first clock where valid
                # while counter will still be 0
                wire((lb.valid | first_valid_counter.O[0])
                     & delayed_ce_for_db_valid.O, db.CE)
                if add_debug_interface:
                    wire((lb.valid | first_valid_counter.O[0])
                         & delayed_ce_for_db_valid.O, cls.dbCE)
                    wire(lb.valid, cls.dbWE)
                    # presumably these port names are the ones declared by
                    # GetDBDebugInterface -- verify against that helper
                    wire(db.WDATA, cls.WDATA)
                    wire(db.RDATA, cls.RDATA)
                    wire(db.WADDR, cls.WADDR)
                    wire(db.RADDR, cls.RADDR)
                    wire(db.RAMWE, cls.RAMWE)

            wire(cls.CE, lb.CE)
            wire(cls.ready, 1)