class _ROM(Circuit):
    """Lookup table producing a value of type ``t`` for each address.

    ``t``, ``n`` and ``init`` are free variables supplied by the enclosing
    scope (not visible in this block). The table is assembled from one
    single-bit LUT per bit of ``t``; LUT ``k`` is initialised from element
    ``k`` of every entry of ``init``.
    """
    name = 'LUT_{}t_{}n'.format(cleanName(str(t)), n)
    addr_width = getRAMAddrWidth(n)
    IO = ['addr', In(Bits[addr_width]), 'data', Out(t)] + ClockInterface()

    @classmethod
    def definition(cls):
        """Instantiate the per-bit LUTs and reassemble their outputs into ``t``."""
        bit_count = GetCoreIRBackend().get_type(t).size

        # One LUT per bit position of the flattened type.
        per_bit_luts = [
            DefineLUT(seq2int([entry[pos] for entry in init]),
                      getRAMAddrWidth(n))()
            for pos in range(bit_count)
        ]

        # Convert the flat LUT output bits back into a value of type t.
        rebuild = Hydrate(t)
        for pos, lut in enumerate(per_bit_luts):
            wire(lut.O, rebuild.I[pos])
        wire(rebuild.out, cls.data)

        # Every LUT sees the same address; LUT inputs are named I0, I1, ...
        for lut in per_bit_luts:
            for addr_bit in range(cls.addr_width):
                wire(cls.addr[addr_bit], getattr(lut, "I" + str(addr_bit)))
class _SIPO(Circuit):
    """Serial-in parallel-out shift register for values of type ``t``.

    ``t``, ``n``, ``init``, ``has_ce`` and ``has_reset`` are free variables
    supplied by the enclosing scope (not visible in this block). The input
    is flattened to bits, each bit is fed to its own ``n``-deep SIPO, and
    the ``n`` parallel outputs are rebuilt into values of type ``t``.
    """
    # The original format string had only four placeholders for five
    # arguments, so has_reset was silently dropped and circuits that differ
    # only in has_reset shared one name. Each parameter now has its own slot.
    name = 'SIPO_{}t_{}n_{}init_{}CE_{}RESET'.format(
        cleanName(str(t)), str(n), str(init), str(has_ce), str(has_reset))
    IO = ['I', In(t), 'O', Out(Array[n, t])] + \
        ClockInterface(has_ce, has_reset)

    @classmethod
    def definition(cls):
        """Wire one bit-level SIPO per bit of ``t`` between Dehydrate and Hydrate."""
        type_size_in_bits = GetCoreIRBackend().get_type(t).size
        type_to_bits = Dehydrate(t)
        sipos = MapParallel(type_size_in_bits,
                            DefineSIPO(n, init, has_ce, has_reset))
        bits_to_type = MapParallel(n, DefineHydrate(t))

        for bit_in_type in range(type_size_in_bits):
            wire(type_to_bits.out[bit_in_type], sipos.I[bit_in_type])
            # Transpose: SIPOs are indexed [bit][stage], the hydrators
            # are indexed [stage][bit].
            for sipo_output in range(n):
                wire(sipos.O[bit_in_type][sipo_output],
                     bits_to_type.I[sipo_output][bit_in_type])

        wire(cls.I, type_to_bits.I)
        wire(bits_to_type.out, cls.O)

        # Fan the optional clock-enable / reset out to every bit-level SIPO.
        for bit_in_type in range(type_size_in_bits):
            if has_ce:
                wire(cls.CE, sipos.CE[bit_in_type])
            if has_reset:
                wire(cls.RESET, sipos.RESET[bit_in_type])
class _Mux(Circuit):
    """``n``-way multiplexer over values of type ``t``.

    ``t`` and ``n`` are free variables supplied by the enclosing scope
    (not visible in this block).
    """
    name = 'Mux_{}t_{}n'.format(cleanName(str(t)), n)
    addr_width = getRAMAddrWidth(n)
    IO = [
        'data', In(Array[n, t]),
        'sel', In(Bits[addr_width]),
        'out', Out(t),
    ]

    @classmethod
    def definition(cls):
        """Build the mux, or a plain pass-through when there is nothing to select."""
        if n <= 1:
            # Single input: forward it and terminate the unused select port.
            wire(cls.data[0], cls.out)
            unused_sel = DefineTermAnyType(Bits[cls.addr_width])()
            wire(cls.sel, unused_sel.I)
            return

        bit_count = GetCoreIRBackend().get_type(t).size
        flat_mux = CommonlibMuxN(n, bit_count)

        # Flatten every candidate to bits before feeding the bit-level mux.
        flatten_all = DefineNativeMapParallel(n, DefineDehydrate(t))()
        wire(cls.data, flatten_all.I)
        wire(flatten_all.out, flat_mux.I.data)

        # Rebuild the selected bits into a value of type t.
        rebuild = Hydrate(t)
        wire(flat_mux.out, rebuild.I)
        wire(rebuild.out, cls.out)

        wire(cls.sel, flat_mux.I.sel)
class _BitonicSort(Circuit):
    """Sorting network for ``n`` elements of type ``T``, for any ``n``.

    ``T``, ``n`` and ``cmp_component`` are free variables supplied by the
    enclosing scope (not visible in this block). The work is delegated to a
    power-of-two bitonic sorter; its surplus inputs are fed an all-ones
    constant and its surplus outputs are terminated.
    """
    name = "BitonicSort_t{}_n{}".format(cleanName(str(T)), str(n))
    IO = ['I', In(Array[n, T]), 'O', Out(Array[n, T])]

    @classmethod
    def definition(cls):
        """Pad to the next power of two and wrap the power-of-two sorter."""
        elem_bits = T.size()
        padded_n = 2 ** ceil(log2(n))

        if padded_n > n:
            # All-ones value used to fill the sorter inputs that have no
            # real data.
            all_ones_flat = DefineCoreirConst(elem_bits, 2 ** elem_bits - 1)()
            all_ones = Hydrate(T)
            wire(all_ones_flat.O, all_ones.I)

        sorter = DefineBitonicSortPow2(T, padded_n, cmp_component)()

        # Real lanes: connect straight through the sorter.
        for lane in range(n):
            wire(cls.I[lane], sorter.I[lane])
            wire(cls.O[lane], sorter.O[lane])

        # Padding lanes: constant in, terminated out.
        for lane in range(n, padded_n):
            wire(all_ones.out, sorter.I[lane])
            sink = TermAnyType(T)
            wire(sink.I, sorter.O[lane])
class _RAM(Circuit):
    """RAM with ``n`` entries of type ``t``.

    ``t``, ``n`` and ``read_latency`` are free variables supplied by the
    enclosing scope (not visible in this block). Data is flattened to bits
    on the write side and rebuilt into ``t`` on the read side around a
    bit-level RAM.
    """
    name = 'RAM_{}t_{}n'.format(cleanName(str(t)), n)
    addr_width = getRAMAddrWidth(n)
    IO = [
        'RADDR', In(Bits[addr_width]),
        'RDATA', Out(t),
        'WADDR', In(Bits[addr_width]),
        'WDATA', In(t),
        'WE', In(Bit),
    ] + ClockInterface()

    @classmethod
    def definition(cls):
        """Wrap a bit-level RAM with Dehydrate/Hydrate converters."""
        bit_count = GetCoreIRBackend().get_type(t).size
        flat_ram = DefineRAM(n, bit_count, read_latency=read_latency)()

        # Write path: typed value -> bits -> RAM.
        flatten = Dehydrate(t)
        wire(cls.WDATA, flatten.I)
        wire(flatten.out, flat_ram.WDATA)

        # Read path: RAM -> bits -> typed value.
        rebuild = Hydrate(t)
        wire(flat_ram.RDATA, rebuild.I)
        wire(rebuild.out, cls.RDATA)

        # Addresses and write-enable go straight to the underlying RAM.
        wire(cls.RADDR, flat_ram.RADDR)
        wire(flat_ram.WADDR, cls.WADDR)
        wire(cls.WE, flat_ram.WE)
class _Sort2Elements(Circuit):
    """Two-element compare-and-swap over values of type ``T``.

    ``T`` and ``cmp_component`` are free variables supplied by the enclosing
    scope (not visible in this block). ``cmp_component`` is applied to each
    input to obtain the value that is compared; the comparison is an
    unsigned less-than on the flattened bits of those values.
    """
    name = "Sort2Elements_T{}".format(cleanName(str(T)))
    IO = [
        'I0', In(T), 'I1', In(T),
        'O0', Out(T), 'O1', Out(T),
    ] + ClockInterface()

    @classmethod
    def definition(cls):
        """Compare the two keys and select the in-order or swapped pair."""
        pair_mux = DefineMuxAnyType(Array[2, T], 2)()

        # Extract the comparison key from each input and flatten it to bits.
        key0 = cmp_component(cls.I0)
        key1 = cmp_component(cls.I1)
        key0_flat = Dehydrate(type(key0))
        key1_flat = Dehydrate(type(key1))
        wire(key0_flat.I, key0)
        wire(key1_flat.I, key1)

        less_than = DefineCoreirUlt(key0_flat.out.N)()
        wire(less_than.I0, key0_flat.out)
        wire(less_than.I1, key1_flat.out)

        # less_than drives the select: mux input 1 holds the pair in its
        # original order (I0, I1), mux input 0 holds the swapped pair.
        in_order = pair_mux.data[1]
        swapped = pair_mux.data[0]
        wire(cls.I0, in_order[0])
        wire(cls.I1, in_order[1])
        wire(cls.I0, swapped[1])
        wire(cls.I1, swapped[0])

        wire(less_than.O, pair_mux.sel[0])
        wire(pair_mux.out[0], cls.O0)
        wire(pair_mux.out[1], cls.O1)
class _BitonicSortPow2(Circuit):
    """Bitonic sorting network built from stages of bitonic mergers.

    ``T``, ``n`` and ``cmp_component`` are free variables supplied by the
    enclosing scope (not visible in this block). Stage ``s`` uses mergers of
    width ``2**(s+1)``; there are ``int(log2(n))`` stages.
    """
    name = "BitonicSortPow2_t{}_n{}".format(cleanName(str(T)), str(n))
    IO = ['I', In(Array[n, T]), 'O', Out(Array[n, T])]

    @classmethod
    def definition(cls):
        """Chain the merge stages from the inputs to the outputs."""
        # Ports feeding the next stage; initially the circuit inputs.
        stage_inputs = [cls.I[idx] for idx in range(n)]

        for stage in range(int(log2(n))):
            merge_width = 2 ** (stage + 1)
            stage_outputs = []
            for merge_idx in range(n // merge_width):
                # The merger's position within the stage is passed along
                # as its last generator argument.
                merger = DefineBitonicMergePow2(
                    T, merge_width, cmp_component, merge_idx)()
                first_port = merge_idx * merge_width
                for lane in range(merge_width):
                    wire(stage_inputs[first_port + lane], merger.I[lane])
                    stage_outputs.append(merger.O[lane])
            stage_inputs = stage_outputs

        # Whatever the last stage produced drives the circuit outputs.
        for idx, port in enumerate(stage_inputs):
            wire(port, cls.O[idx])
class _Down_S(Circuit):
    """Select element ``idx`` out of an ``n``-element space sequence.

    ``n``, ``idx``, ``elem_t``, ``has_valid`` and ``valid_ports`` are free
    variables supplied by the enclosing scope (not visible in this block).
    The inputs that are not selected are wired to a terminator.
    """
    name = "Down_S_n{}_sel{}_tEl{}_v{}".format(
        str(n), str(idx), cleanName(str(elem_t)), str(has_valid))
    binary_op = False
    st_in_t = [ST_SSeq(n, elem_t)]
    st_out_t = ST_SSeq(1, elem_t)
    IO = ['I', In(st_in_t[0].magma_repr()),
          'O', Out(st_out_t.magma_repr())] + ClockInterface(False, False)
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Route input ``idx`` to the output and terminate every other input."""
        if n > 1:
            # Sink sized for the n - 1 inputs that are dropped.
            discard = TermAnyType(ST_SSeq(n - 1, elem_t).magma_repr())

        kept_so_far = 0
        for pos in range(len(cls.I)):
            if pos != idx:
                # Compact the sink index by skipping over the kept input.
                wire(cls.I[pos], discard.I[pos - kept_so_far])
                continue
            wire(cls.I[pos], cls.O[0])
            kept_so_far += 1

        if has_valid:
            wire(cls.valid_up, cls.valid_down)
class _Const(Circuit):
    """Constant generator that plays values of space-time type ``t`` from a LUT.

    ``t``, ``ts_values``, ``delay``, ``has_ce``, ``has_reset``, ``has_valid``
    and ``valid_ports`` are free variables supplied by the enclosing scope
    (not visible in this block). A counter modulo ``t.time()`` addresses the
    LUT; it only advances while the circuit is enabled.
    """
    name = "Const_t{}_hasCE{}_hasReset{}_hasValid{}".format(
        cleanName(str(t)), str(has_ce), str(has_reset), str(has_valid))
    IO = ['O', Out(t.magma_repr())] + ClockInterface(has_ce, has_reset)
    binary_op = False
    st_in_t = []
    st_out_t = t
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Build the enable logic, the LUT and its position counter."""
        high = DefineCoreirConst(1, 1)().O[0]

        # Enabled immediately when there is no delay; otherwise gated by the
        # valid output of an always-running initial delay counter.
        enabled = high
        if delay != 0:
            startup = InitialDelayCounter(delay)
            wire(startup.CE, high)
            enabled = startup.valid
        if has_ce:
            enabled = bit(cls.CE) & enabled

        table = DefineLUTAnyType(t.magma_repr(), t.time(),
                                 ts_arrays_to_bits(ts_values))()
        position = AESizedCounterModM(t.time(), has_ce=True,
                                      has_reset=has_reset)
        wire(position.O, table.addr)
        wire(cls.O, table.data)
        wire(enabled, position.CE)
        if has_reset:
            wire(cls.RESET, position.RESET)

        if has_valid:
            # The upstream valid is ignored (terminated); downstream valid
            # follows the internal enable.
            ignored_valid = TermAnyType(Bit)
            wire(cls.valid_up, ignored_valid.I)
            wire(enabled, cls.valid_down)
class _ShiftT(Circuit):
    """Shift a time sequence by ``shift_amount`` elements using a RAM.

    ``n``, ``i``, ``shift_amount``, ``elem_t``, ``has_ce``, ``has_reset``,
    ``has_valid`` and ``valid_ports`` are free variables supplied by the
    enclosing scope (not visible in this block).
    """
    name = "Shift_t_n{}_i{}_amt{}_tEl{}__hasCE{}_hasReset{}_hasValid{}".format(
        str(n), str(i), str(shift_amount), cleanName(str(elem_t)),
        str(has_ce), str(has_reset), str(has_valid))
    binary_op = False
    st_in_t = [ST_TSeq(n, i, elem_t)]
    st_out_t = ST_TSeq(n, i, elem_t)
    IO = ['I', In(st_in_t[0].magma_repr()),
          'O', Out(st_out_t.magma_repr())] + ClockInterface(has_ce, has_reset)
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Wire the RAM, its address counter and the per-element counter."""
        enabled = DefineCoreirConst(1, 1)().O[0]
        if has_valid:
            enabled = cls.valid_up & enabled
            wire(cls.valid_up, cls.valid_down)
        if has_ce:
            enabled = bit(cls.CE) & enabled

        storage = DefineRAM_ST(elem_t, shift_amount, has_reset=has_reset)()

        # Write and read use the same location: the first pass through an
        # element only needs the write, later passes write and read. The
        # output for the first pass is undefined, so reading anything is ok.
        element_clock = DefineNestedCounters(
            elem_t, has_ce=True, has_reset=has_reset)()

        # It is fine that this does not account for the invalid clocks:
        # after them the next iteration may start from a non-zero index,
        # which does not matter because the address just loops around.
        slot = AESizedCounterModM(shift_amount, has_ce=True,
                                  has_reset=has_reset)

        wire(slot.O, storage.WADDR)
        wire(slot.O, storage.RADDR)
        wire(enabled, storage.WE)
        wire(enabled, storage.RE)
        # Advance to the next slot once the current element has finished.
        wire(enabled & element_clock.last, slot.CE)
        wire(enabled, element_clock.CE)

        # The element counter's valid output is unused here.
        unused_valid = TermAnyType(Bit)
        wire(element_clock.valid, unused_valid.I)

        wire(cls.I, storage.WDATA)
        wire(storage.RDATA, cls.O)

        if has_reset:
            for block in (storage, slot, element_clock):
                wire(block.RESET, cls.RESET)
class _SIPO(Circuit):
    """Terminator that sinks an input of type ``t``.

    ``t`` is a free variable supplied by the enclosing scope (not visible
    in this block).

    NOTE(review): the Python class is called ``_SIPO`` although the circuit
    it defines is named ``Term_...``; the identifier is kept because code
    outside this block may refer to it.
    """
    name = 'Term_{}t'.format(cleanName(str(t)))
    IO = ['I', In(t)]

    @classmethod
    def definition(cls):
        """Flatten the input to bits and feed it to a bit-level terminator."""
        bit_count = GetCoreIRBackend().get_type(t).size
        flatten = Dehydrate(t)
        sink = Term(bit_count)
        wire(cls.I, flatten.I)
        wire(flatten.out, sink.I)
class _Up_T(Circuit):
    """Repeat the first element of a time sequence for ``n`` elements.

    ``n``, ``i``, ``elem_t``, ``has_ce``, ``has_reset``, ``has_valid`` and
    ``valid_ports`` are free variables supplied by the enclosing scope (not
    visible in this block). The first element is forwarded directly and also
    stored; for later elements the stored copy is emitted.
    """
    name = "Up_T_n{}_i{}_tEl{}_hasCE{}_hasReset{}_hasValid{}".format(
        str(n), str(i), cleanName(str(elem_t)),
        str(has_ce), str(has_reset), str(has_valid))
    binary_op = False
    st_in_t = [ST_TSeq(1, n + i - 1, elem_t)]
    st_out_t = ST_TSeq(n, i, elem_t)
    IO = ['I', In(st_in_t[0].magma_repr()),
          'O', Out(st_out_t.magma_repr())] + ClockInterface(has_ce, has_reset)
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Wire the one-element store, its counters and the output selector."""
        enabled = DefineCoreirConst(1, 1)().O[0]
        if has_valid:
            enabled = cls.valid_up & enabled
            wire(cls.valid_up, cls.valid_down)
        if has_ce:
            enabled = bit(cls.CE) & enabled

        # Written during the first element, read during the following ones.
        storage = DefineRAM_ST(elem_t, 1, has_reset=has_reset)()
        within_element = DefineNestedCounters(
            elem_t, has_ce=True, has_reset=has_reset)()
        which_element = AESizedCounterModM(
            n + i, has_ce=True, has_reset=has_reset)
        on_first = Decode(0, which_element.O.N)(which_element.O)

        # The store has a single slot, so both addresses are constant zero.
        addr_zero = DefineCoreirConst(1, 0)().O
        wire(addr_zero, storage.WADDR)
        wire(addr_zero, storage.RADDR)

        wire(enabled & on_first, storage.WE)
        wire(enabled, storage.RE)
        wire(enabled, within_element.CE)
        # Move on to the next element once the current one has finished.
        wire(enabled & within_element.last, which_element.CE)

        # The element counter's valid output is unused here.
        unused_valid = TermAnyType(Bit)
        wire(within_element.valid, unused_valid.I)

        wire(cls.I, storage.WDATA)

        # On the first element the input is sent straight out; otherwise
        # the stored copy is used.
        chooser = DefineMuxAnyType(elem_t.magma_repr(), 2)()
        wire(on_first, chooser.sel[0])
        wire(storage.RDATA, chooser.data[0])
        wire(cls.I, chooser.data[1])
        wire(chooser.out, cls.O)

        if has_reset:
            for block in (storage, within_element, which_element):
                wire(block.RESET, cls.RESET)
class _Passthrough(Circuit):
    """Reinterpret a value of type ``t_in`` as ``t_out`` by rewiring ports.

    ``t_in``, ``t_out``, ``has_valid`` and ``valid_ports`` are free variables
    supplied by the enclosing scope (not visible in this block).
    """
    name = "Passthrough_tIn{}_tOut{}".format(
        cleanName(str(t_in)), str(cleanName(str(t_out))))
    IO = ['I', In(t_in.magma_repr()), 'O', Out(t_out.magma_repr())]
    binary_op = False
    st_in_t = [t_in]
    st_out_t = t_out
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Flatten both sides and connect the resulting ports pairwise."""
        cls.output_delay = 0
        sources = get_nested_ports(cls.I, num_nested_space_layers(t_in), [])
        sinks = get_nested_ports(cls.O, num_nested_space_layers(t_out), [])
        # zip stops at the shorter list, so unmatched ports are left alone.
        for source, sink in zip(sources, sinks):
            wire(source, sink)
        if has_valid:
            wire(cls.valid_up, cls.valid_down)
class _AtomTupleCreator(Circuit):
    """Pack two inputs of types ``t0`` and ``t1`` into an atom tuple.

    ``t0``, ``t1``, ``has_valid`` and ``valid_ports`` are free variables
    supplied by the enclosing scope (not visible in this block).
    """
    name = "atomTupleCreator_t0{}_t1{}".format(
        cleanName(str(t0)), cleanName(str(t1)))
    binary_op = True
    st_in_t = [t0, t1]
    st_out_t = ST_Atom_Tuple(t0, t1)
    IO = [
        "I0", In(st_in_t[0].magma_repr()),
        "I1", In(st_in_t[1].magma_repr()),
        "O", Out(st_out_t.magma_repr()),
    ]
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Connect I0 and I1 to tuple fields 0 and 1 respectively."""
        for field, source in enumerate((cls.I0, cls.I1)):
            wire(source, cls.O[field])
        if has_valid:
            wire(cls.valid_up, cls.valid_down)
class _Register(Circuit):
    """Register for a value of arbitrary type ``t``.

    ``t``, ``init``, ``has_ce`` and ``has_reset`` are free variables supplied
    by the enclosing scope (not visible in this block). Arrays of bits map
    onto one multi-bit register; other arrays and tuples recurse element by
    element through ``DefineRegisterAnyType``; anything else is treated as a
    single bit.
    """
    name = 'Register_{}t_{}init_{}CE_{}RESET'.format(
        cleanName(str(t)), str(init), str(has_ce), str(has_reset))
    IO = ['I', In(t), 'O', Out(t)] + \
        ClockInterface(has_ce, has_reset)

    @classmethod
    def definition(cls):
        """Instantiate the underlying register(s) and wire data, CE and RESET."""
        # True when a single multi-bit register covers the whole array, so
        # the array's elements are wired to individual bits of that register.
        nested = False
        if issubclass(type(t), ArrayKind):
            if type(t.T) == BitKind:
                # init is forwarded here as well. The original omitted it
                # in this branch only, so arrays of bits always used the
                # register's default initial value even though the circuit
                # name advertises init.
                regs = [
                    DefineRegister(t.N, init=init, has_ce=has_ce,
                                   has_reset=has_reset)()
                ]
                nested = True
            else:
                regs = [
                    DefineRegisterAnyType(t.T, init, has_ce, has_reset)()
                    for _ in range(t.N)
                ]
        elif issubclass(type(t), TupleKind):
            regs = [
                DefineRegisterAnyType(t_inner, init, has_ce, has_reset)()
                for t_inner in t.Ts
            ]
        else:
            regs = [DefineRegister(1, init, has_ce, has_reset)()]

        if nested:
            whole = regs[0]
            for i in range(t.N):
                wire(cls.I[i], whole.I[i])
                wire(cls.O[i], whole.O[i])
            if has_ce:
                wire(cls.CE, whole.CE)
            if has_reset:
                wire(cls.RESET, whole.RESET)
        else:
            for i, reg in enumerate(regs):
                if type(t) == BitKind:
                    # A lone bit lives in bit 0 of a 1-bit register.
                    wire(cls.I, reg.I[0])
                    wire(cls.O, reg.O[0])
                else:
                    wire(cls.I[i], reg.I)
                    wire(cls.O[i], reg.O)
                if has_ce:
                    wire(cls.CE, reg.CE)
                if has_reset:
                    wire(cls.RESET, reg.RESET)
class _Up_S(Circuit):
    """Broadcast a single element to all ``n`` positions of a space sequence.

    ``n``, ``elem_t``, ``has_valid`` and ``valid_ports`` are free variables
    supplied by the enclosing scope (not visible in this block).
    """
    name = "Up_S_n{}_tEl{}_v{}".format(
        str(n), cleanName(str(elem_t)), str(has_valid))
    binary_op = False
    st_in_t = [ST_SSeq(1, elem_t)]
    st_out_t = ST_SSeq(n, elem_t)
    IO = ['I', In(st_in_t[0].magma_repr()),
          'O', Out(st_out_t.magma_repr())] + ClockInterface(False, False)
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Wire the only input element to every output element."""
        for i in range(n):
            # The original wrote `cls.wire = wire(...)`, which stored the
            # return value of wire() on the class under the name `wire`,
            # shadowing any attribute of that name on the circuit. Only the
            # connection itself is wanted.
            wire(cls.I[0], cls.O[i])
        if has_valid:
            wire(cls.valid_up, cls.valid_down)
class _RAM_ST(Circuit):
    """RAM holding ``n`` values of space-time type ``t``.

    ``t``, ``n``, ``has_reset`` and ``read_latency`` are free variables
    supplied by the enclosing scope (not visible in this block). Each of the
    ``n`` entries is its own RAM with ``t.valid_clocks()`` locations; nested
    counters step through an entry's locations while ``WE`` / ``RE`` are
    high, and ``WADDR`` / ``RADDR`` pick which entry is written / read.
    """
    # n is part of the name: it determines the address width and the number
    # of RAMs, so circuits that differ only in n must not share a name (the
    # original name omitted it).
    name = 'RAM_ST_{}_n{}_hasReset{}'.format(
        cleanName(str(t)), str(n), str(has_reset))
    addr_width = getRAMAddrWidth(n)
    IO = [
        'RADDR', In(Bits[addr_width]),
        'RDATA', Out(t.magma_repr()),
        'WADDR', In(Bits[addr_width]),
        'WDATA', In(t.magma_repr()),
        'WE', In(Bit),
        'RE', In(Bit),
    ] + ClockInterface(has_ce=False, has_reset=has_reset)

    @classmethod
    def definition(cls):
        """Instantiate the per-entry RAMs, the position counters and the read mux."""
        # Each valid clock delivers one magma_repr value; it is read from or
        # written to one location of the selected entry's RAM.
        rams = [
            DefineRAMAnyType(t.magma_repr(), t.valid_clocks(),
                             read_latency=read_latency)()
            for _ in range(n)
        ]

        read_time_position_counter = DefineNestedCounters(
            t, has_cur_valid=True, has_ce=True, has_reset=has_reset)()
        read_valid_term = TermAnyType(Bit)
        read_last_term = TermAnyType(Bit)

        write_time_position_counter = DefineNestedCounters(
            t, has_cur_valid=True, has_ce=True, has_reset=has_reset)()
        write_valid_term = TermAnyType(Bit)
        write_last_term = TermAnyType(Bit)

        read_selector = DefineMuxAnyType(t.magma_repr(), n)()

        for i in range(n):
            wire(cls.WDATA, rams[i].WDATA)
            wire(write_time_position_counter.cur_valid, rams[i].WADDR)
            wire(read_selector.data[i], rams[i].RDATA)
            wire(read_time_position_counter.cur_valid, rams[i].RADDR)
            # Only the RAM addressed by WADDR is written, and only on clocks
            # the write counter reports as valid.
            write_cur_ram = Decode(i, cls.WADDR.N)(cls.WADDR)
            wire(write_cur_ram & write_time_position_counter.valid,
                 rams[i].WE)

        wire(cls.RADDR, read_selector.sel)
        wire(cls.RDATA, read_selector.out)

        wire(cls.WE, write_time_position_counter.CE)
        wire(cls.RE, read_time_position_counter.CE)

        # Terminate the counter outputs; the write counter's valid is also
        # consumed by the write-enable logic above.
        wire(read_time_position_counter.valid, read_valid_term.I)
        wire(read_time_position_counter.last, read_last_term.I)
        wire(write_time_position_counter.valid, write_valid_term.I)
        wire(write_time_position_counter.last, write_last_term.I)

        if has_reset:
            wire(cls.RESET, write_time_position_counter.RESET)
            wire(cls.RESET, read_time_position_counter.RESET)
class _Partition_S(Circuit):
    """Split a flat space sequence of ``no * ni`` elements into ``no`` groups of ``ni``.

    ``no``, ``ni``, ``elem_t``, ``has_valid`` and ``valid_ports`` are free
    variables supplied by the enclosing scope (not visible in this block).
    """
    name = "Partition_S_no{}_ni{}_tEl{}_v{}".format(
        str(no), str(ni), cleanName(str(elem_t)), str(has_valid))
    binary_op = False
    st_in_t = [ST_SSeq(no * ni, elem_t)]
    st_out_t = ST_SSeq(no, ST_SSeq(ni, elem_t))
    IO = ['I', In(st_in_t[0].magma_repr()),
          'O', Out(st_out_t.magma_repr())] + ClockInterface(False, False)
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Wire flat input ``k`` to output ``[k // ni][k % ni]``."""
        for flat in range(no * ni):
            outer, inner = divmod(flat, ni)
            wire(cls.I[flat], cls.O[outer][inner])
        if has_valid:
            wire(cls.valid_up, cls.valid_down)
class _ShiftS(Circuit):
    """Rotate a space sequence of ``n`` elements by ``shift_amount`` positions.

    ``n``, ``shift_amount``, ``elem_t``, ``has_valid`` and ``valid_ports``
    are free variables supplied by the enclosing scope (not visible in this
    block).
    """
    name = "Shift_S_n{}_amt{}_tEl{}_hasValid{}".format(
        str(n), str(shift_amount), cleanName(str(elem_t)), str(has_valid))
    binary_op = False
    st_in_t = [ST_SSeq(n, elem_t)]
    st_out_t = ST_SSeq(n, elem_t)
    IO = ['I', In(st_in_t[0].magma_repr()),
          'O', Out(st_out_t.magma_repr())] + ClockInterface(False, False)
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Wire input ``k`` to output ``(k + shift_amount) % n``."""
        for source in range(n):
            # Wrap around: the first shift_amount outputs are undefined,
            # so anything may be sent out there.
            target = (source + shift_amount) % n
            wire(cls.I[source], cls.O[target])
        if has_valid:
            wire(cls.valid_up, cls.valid_down)
class _ROM(Circuit):
    """ROM with ``n`` entries of type ``t``, initialised from ``init``.

    ``t``, ``n`` and ``init`` are free variables supplied by the enclosing
    scope (not visible in this block).
    """
    name = 'ROM_{}t_{}n'.format(cleanName(str(t)), n)
    addr_width = getRAMAddrWidth(n)
    IO = [
        'RADDR', In(Bits[addr_width]),
        'RDATA', Out(t),
        'REN', In(Bit),
    ] + ClockInterface()

    @classmethod
    def definition(cls):
        """Wrap a bit-level ROM and rebuild its output into ``t``."""
        bit_count = GetCoreIRBackend().get_type(t).size
        # Each entry of init is passed to the ROM as a plain list.
        contents = [list(entry) for entry in init]
        flat_rom = DefineROM(n, bit_count)(
            coreir_configargs={"init": contents})

        rebuild = Hydrate(t)
        wire(flat_rom.rdata, rebuild.I)
        wire(rebuild.out, cls.RDATA)

        wire(cls.RADDR, flat_rom.raddr)
        wire(cls.REN, flat_rom.ren)
class _TupleToSSeq(Circuit):
    """Convert a space-sequence tuple of ``n`` ``t`` values into a space sequence.

    ``t``, ``n``, ``has_valid`` and ``valid_ports`` are free variables
    supplied by the enclosing scope (not visible in this block). The input
    is wired to the output unchanged.
    """
    name = "stupleToSSeq_t{}_n{}".format(cleanName(str(t)), str(n))
    binary_op = False
    st_in_t = [ST_SSeq_Tuple(n, t)]
    st_out_t = ST_SSeq(n, t)
    IO = ["I", In(st_in_t[0].magma_repr()),
          "O", Out(st_out_t.magma_repr())]
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Connect input to output and forward valid when present."""
        wire(cls.I, cls.O)
        if not has_valid:
            return
        wire(cls.valid_up, cls.valid_down)
class _ShiftTS(Circuit):
    """Shift a time sequence of space sequences by ``shift_amount`` elements.

    ``no``, ``io``, ``ni``, ``shift_amount``, ``elem_t``, ``has_ce``,
    ``has_reset``, ``has_valid`` and ``valid_ports`` are free variables
    supplied by the enclosing scope (not visible in this block). Every
    output lane takes the input lane ``shift_amount`` positions before it
    (modulo ``ni``); lanes that need a delay in time go through a Shift_T.
    """
    name = "Shift_ts_no{}_io{}_ni{}_amt{}_tEl{}__hasCE{}_hasReset{}_hasValid{}".format(
        str(no), str(io), str(ni), str(shift_amount), cleanName(str(elem_t)),
        str(has_ce), str(has_reset), str(has_valid))
    binary_op = False
    st_in_t = [ST_TSeq(no, io, ST_SSeq(ni, elem_t))]
    st_out_t = ST_TSeq(no, io, ST_SSeq(ni, elem_t))
    IO = ['I', In(st_in_t[0].magma_repr()),
          'O', Out(st_out_t.magma_repr())] + ClockInterface(has_ce, has_reset)
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Create the per-lane time shifters and wire the lane rotation."""
        enabled = DefineCoreirConst(1, 1)().O[0]
        if has_valid:
            enabled = cls.valid_up & enabled
            wire(cls.valid_up, cls.valid_down)
        if has_ce:
            enabled = bit(cls.CE) & enabled

        # The shifters do not need their own valid ports: they are gated
        # through the enable signal wired to their CE.
        lane_shifters = []
        for lane in range(ni):
            time_shift = (ni - lane + shift_amount - 1) // ni
            if time_shift != 0:
                lane_shifters.append(
                    DefineShift_T(no, io, time_shift, elem_t,
                                  True, has_reset, False)())
            else:
                # No delay in time for this lane: a plain wire suffices.
                lane_shifters.append(None)

        for lane, shifter in enumerate(lane_shifters):
            source = cls.I[(lane - shift_amount) % ni]
            if shifter is None:
                wire(source, cls.O[lane])
                continue
            wire(source, shifter.I)
            wire(shifter.O, cls.O[lane])
            wire(enabled, shifter.CE)
            if has_reset:
                wire(cls.RESET, shifter.RESET)
class _Snd(Circuit):
    """Project the second field out of a tuple of type ``t``.

    ``t``, ``has_valid`` and ``valid_ports`` are free variables supplied by
    the enclosing scope (not visible in this block). The first field is
    wired to a terminator.
    """
    name = "snd_t{}".format(cleanName(str(t)))
    binary_op = False
    st_in_t = [t]
    st_out_t = t.t1
    IO = ["I", In(st_in_t[0].magma_repr()),
          "O", Out(st_out_t.magma_repr())]
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Terminate field 0 and forward field 1 to the output."""
        dropped_first = TermAnyType(t.t0.magma_repr())
        wire(cls.I[0], dropped_first.I)
        wire(cls.I[1], cls.O)
        if has_valid:
            wire(cls.valid_up, cls.valid_down)
class _Eq(Circuit):
    """Equality comparison of the two fields of an atom tuple of ``t`` values.

    ``t``, ``has_valid`` and ``valid_ports`` are free variables supplied by
    the enclosing scope (not visible in this block).
    """
    name = "Eq_Atom_{}t".format(cleanName(str(t)))
    binary_op = False
    # The declared input type follows t, matching the ports below. The
    # original hard-coded ST_Int here, so the metadata disagreed with the
    # actual ports whenever t was not ST_Int.
    st_in_t = [ST_Atom_Tuple(t, t)]
    st_out_t = ST_Bit()
    IO = [
        'I', In(ST_Atom_Tuple(t, t).magma_repr()),
        'O', Out(ST_Bit().magma_repr()),
    ]
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Instantiate an EQ sized for ``t`` and wire both tuple fields into it."""
        op = DefineEQ(t.magma_repr().size())()
        wire(cls.I[0], op.I0)
        wire(cls.I[1], op.I1)
        wire(op.O, cls.O[0])
        if has_valid:
            wire(cls.valid_up, cls.valid_down)
class _SeqTupleCreator(Circuit):
    """Pack two inputs of type ``t`` into a two-element space-sequence tuple.

    ``t``, ``has_valid`` and ``valid_ports`` are free variables supplied by
    the enclosing scope (not visible in this block).
    """
    name = "sseqTupleCreator_t{}".format(cleanName(str(t)))
    binary_op = True
    st_in_t = [t, t]
    st_out_t = ST_SSeq_Tuple(2, t)
    IO = [
        "I0", In(t.magma_repr()),
        "I1", In(t.magma_repr()),
        "O", Out(st_out_t.magma_repr()),
    ]
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Connect I0 and I1 to tuple positions 0 and 1 respectively."""
        for position, source in enumerate((cls.I0, cls.I1)):
            wire(source, cls.O[position])
        if has_valid:
            wire(cls.valid_up, cls.valid_down)
class _If(Circuit):
    """Two-way select between atoms of type ``t`` driven by a bit.

    ``t``, ``has_valid`` and ``valid_ports`` are free variables supplied by
    the enclosing scope (not visible in this block). The input is a tuple
    ``(condition, (a, b))``: ``a`` feeds mux input 1 and ``b`` feeds mux
    input 0.
    """
    name = "If_Atom_{}t".format(cleanName(str(t)))
    binary_op = False
    st_in_t = [ST_Atom_Tuple(ST_Bit(), ST_Atom_Tuple(t, t))]
    st_out_t = t
    IO = ['I', In(st_in_t[0].magma_repr()),
          'O', Out(st_out_t.magma_repr())]
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Wire the condition to the mux select and the two branches to its data."""
        chooser = DefineMuxAnyType(t.magma_repr(), 2)()
        condition = cls.I[0]
        branches = cls.I[1]
        wire(condition, chooser.sel)
        wire(branches[0], chooser.data[1])
        wire(branches[1], chooser.data[0])
        wire(chooser.out, cls.O)
        if has_valid:
            wire(cls.valid_up, cls.valid_down)
class _SeqTupleAppender(Circuit):
    """Append one ``t`` value to a space-sequence tuple of ``n`` ``t`` values.

    ``t``, ``n``, ``has_valid`` and ``valid_ports`` are free variables
    supplied by the enclosing scope (not visible in this block).
    """
    name = "sseqTupleAppender_t{}_n{}".format(cleanName(str(t)), str(n))
    binary_op = True
    st_in_t = [ST_SSeq_Tuple(n, t), t]
    st_out_t = ST_SSeq_Tuple(n + 1, t)
    IO = [
        "I0", In(st_in_t[0].magma_repr()),
        "I1", In(st_in_t[1].magma_repr()),
        "O", Out(st_out_t.magma_repr()),
    ]
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Copy the existing ``n`` elements and place I1 at position ``n``."""
        for position in range(n):
            wire(cls.I0[position], cls.O[position])
        # The new element takes the last slot of the longer tuple.
        wire(cls.I1, cls.O[n])
        if has_valid:
            wire(cls.valid_up, cls.valid_down)
class _Down_T(Circuit):
    """Select element ``idx`` out of a time sequence of ``n`` elements.

    ``n``, ``i``, ``idx``, ``elem_t``, ``has_valid`` and ``valid_ports`` are
    free variables supplied by the enclosing scope (not visible in this
    block). Data is passed straight through; only the valid signal is
    delayed so that it lines up with the selected element.
    """
    # The prefix is "Down_T": the original said "Down_S", copied from the
    # space-sequence variant, which mislabelled this circuit.
    name = "Down_T_n{}_i{}_sel{}_tEl{}_v{}".format(
        str(n), str(i), str(idx), cleanName(str(elem_t)), str(has_valid))
    binary_op = False
    st_in_t = [ST_TSeq(n, i, elem_t)]
    st_out_t = ST_TSeq(1, n + i - 1, elem_t)
    IO = ['I', In(st_in_t[0].magma_repr()),
          'O', Out(st_out_t.magma_repr())] + \
        ClockInterface(False, False)
    if has_valid:
        IO += valid_ports

    @classmethod
    def definition(cls):
        """Forward the data and delay valid by ``elem_t.time() * idx``."""
        wire(cls.I, cls.O)
        if has_valid:
            # Delay applied to valid before the selected element.
            valid_delay = elem_t.time() * idx
            if valid_delay == 0:
                wire(cls.valid_up, cls.valid_down)
            else:
                delay_counter = InitialDelayCounter(valid_delay)
                wire(cls.valid_up, delay_counter.CE)
                wire(delay_counter.valid, cls.valid_down)
class _BitonicMergePow2(Circuit):
    """Recursive bitonic merger for ``n`` elements of type ``T``.

    ``T``, ``n``, ``cmp_component`` and ``ith_merge`` are free variables
    supplied by the enclosing scope (not visible in this block). Element
    ``k`` is compared with element ``k + n // 2``; each half is then merged
    recursively. The parity of ``ith_merge`` decides which sorter output
    goes to which half.
    """
    name = "BitonicMergePow2_t{}_n{}_ithMerge{}".format(
        cleanName(str(T)), str(n), str(ith_merge))
    IO = ['I', In(Array[n, T]), 'O', Out(Array[n, T])]

    @classmethod
    def definition(cls):
        """Build the compare-and-swap layer and the two half-size mergers."""
        half = n // 2
        keep_order = ith_merge % 2 == 0

        # First compare-and-swap layer across the two halves.
        after_swap = [0] * n
        for k in range(half):
            sorter = DefineSort2Elements(T, cmp_component)()
            wire(cls.I[k], sorter.I0)
            wire(cls.I[half + k], sorter.I1)
            # Reverse the direction when ith_merge is odd.
            if keep_order:
                after_swap[k], after_swap[half + k] = sorter.O0, sorter.O1
            else:
                after_swap[k], after_swap[half + k] = sorter.O1, sorter.O0

        if n == 2:
            # Base case: the single sorter already produced the result.
            wire(after_swap[0], cls.O[0])
            wire(after_swap[1], cls.O[1])
            return

        # Recursive case: merge each half with a half-size merger.
        lower = DefineBitonicMergePow2(T, half, cmp_component, ith_merge)()
        upper = DefineBitonicMergePow2(T, half, cmp_component, ith_merge)()
        for k in range(half):
            wire(after_swap[k], lower.I[k])
            wire(after_swap[half + k], upper.I[k])
            wire(lower.O[k], cls.O[k])
            wire(upper.O[k], cls.O[half + k])
class _LB(Circuit):
    """Two-dimensional line buffer circuit.

    ``pixel_type``, ``pixels_per_row_per_clock``, ``rows_of_pixels_per_clock``,
    ``window_cols``, ``window_rows``, ``image_cols``, ``image_rows``,
    ``stride_cols``, ``stride_rows``, ``origin_cols``, ``origin_rows`` and
    ``add_debug_interface`` are free variables supplied by the enclosing
    scope (not visible in this block).

    The class body first validates the parameters and raises ``Exception``
    with an explanatory message when a combination is unsupported. It then
    derives the output rates and declares the interface. ``definition``
    wraps an ``AnyDimensionalLineBuffer`` and, when ``stride_rows`` exceeds
    ``rows_of_pixels_per_clock``, routes its windows through a
    ``DelayedBuffer``.
    """

    # ---- parameter validation -------------------------------------------
    # Compared with the original, the messages now put a space between
    # adjacent string literals, name the parameter actually tested, have one
    # placeholder per argument, and always include the reason.
    if image_cols % pixels_per_row_per_clock != 0:
        reason = """
        this is necessary so that input a complete row before getting
        the input for the next row. This means don't have a clock cycle
        where input pixels are from two different rows in the input image.
        """
        raise Exception("Aetherling's Native LineBuffer has invalid "
                        "parameters: image_cols {} not divisible by "
                        "pixels_per_row_per_clock {}. \n Reason: {}".format(
                            image_cols, pixels_per_row_per_clock, reason))

    if image_rows % rows_of_pixels_per_clock != 0:
        reason = """
        this is necessary so that input a complete image with the same
        number of input pixels every clock, don't have a weird ending
        with only 1 pixel from the image left to input
        """
        raise Exception("Aetherling's Native LineBuffer has invalid "
                        "parameters: image_rows {} not divisible by "
                        "rows_of_pixels_per_clock {}. \n Reason: {}".format(
                            image_rows, rows_of_pixels_per_clock, reason))

    if image_cols % stride_cols != 0:
        reason = "stride_cols is downsample factor for number of columns. " \
                 "This requirement ensures that the column downsample factor " \
                 "cleanly divides the image's number of columns"
        raise Exception("Aetherling's Native LineBuffer has invalid "
                        "parameters: image_cols {} not divisible by "
                        "stride_cols {}. \n Reason: {}".format(
                            image_cols, stride_cols, reason))

    if image_rows % stride_rows != 0:
        reason = "stride_rows is downsample factor for number of rows. " \
                 "This requirement ensures that the row downsample factor " \
                 "cleanly divides the image's number of rows"
        raise Exception("Aetherling's Native LineBuffer has invalid "
                        "parameters: image_rows {} not divisible by "
                        "stride_rows {}. \n Reason: {}".format(
                            image_rows, stride_rows, reason))

    if (stride_cols % pixels_per_row_per_clock != 0 and
            pixels_per_row_per_clock % stride_cols != 0):
        reason = """
        the average number of output windows per row per clock =
            pixels per row per clock / stride cols per clock
        Number of output windows per row per clock must be integer or
        reciprocal of one so that the position of output windows relative
        to the new pixels each clock is constant. Otherwise there will be
        different numbers of new pixels in each window in each clock.
        This is challenging to implement.
        """
        raise Exception(
            "Aetherling's Native LineBuffer has invalid "
            "parameters: stride_cols {} not divisible by "
            "pixels_per_row_per_clock {} nor vice-versa. One of them must "
            "be divisible by the other. \n Reason: {}".format(
                stride_cols, pixels_per_row_per_clock, reason))

    if (stride_rows % rows_of_pixels_per_clock != 0 and
            rows_of_pixels_per_clock % stride_rows != 0):
        reason = """
        the average number of rows of output windows per clock =
            rows of pixels per clock / stride rows per clock
        Number of rows of output windows per clock must be integer or
        reciprocal of one so that the position of output windows relative
        to the new pixels each clock is constant. Otherwise there will be
        different numbers of new pixels in each window in each clock.
        This is challenging to implement.
        """
        raise Exception(
            "Aetherling's Native LineBuffer has invalid "
            "parameters: stride_rows {} not divisible by "
            "rows_of_pixels_per_clock {} nor vice-versa. One of them must "
            "be divisible by the other. \n Reason: {}".format(
                stride_rows, rows_of_pixels_per_clock, reason))

    if ((stride_cols * stride_rows) %
            (pixels_per_row_per_clock * rows_of_pixels_per_clock) != 0) and \
            ((pixels_per_row_per_clock * rows_of_pixels_per_clock) %
             (stride_cols * stride_rows) != 0):
        reason = """
        the average number of output windows per clock =
            (pixels per row per clock * rows of pixels per clock) /
            (stride cols per clock * stride rows per clock)
        Number of output windows per clock must be integer or reciprocal
        of one so that throughput is an easier factor to manipulate with
        map/underutil. Otherwise throughput is a weird fraction and the
        downstream system is either only partially used on some clocks or
        the sequence length is multiplied by a weird factor that makes the
        rational number throughput become an integer.
        """
        raise Exception(
            "Aetherling's Native LineBuffer has invalid "
            "parameters: stride_cols {} * stride_rows {} not divisible by "
            "pixels_per_row_per_clock {} * rows_of_pixels_per_clock {} nor "
            "vice-versa. One of them must be divisible by the other. "
            "\n Reason: {}".format(
                stride_cols, stride_rows, pixels_per_row_per_clock,
                rows_of_pixels_per_clock, reason))

    if abs(origin_cols) >= window_cols:
        reason = """
        origin must be less than window. If abs(origin_cols) was greater
        than window_cols, then entire first window would be garbage
        """
        raise Exception("Aetherling's Native LineBuffer has invalid "
                        "parameters: |origin_cols| {} greater than or equal "
                        "to window width {}. \n Reason: {}".format(
                            abs(origin_cols), window_cols, reason))

    if abs(origin_rows) >= window_rows:
        reason = """
        origin must be less than window. If abs(origin_rows) was greater
        than window_rows, then entire first window would be garbage
        """
        raise Exception("Aetherling's Native LineBuffer has invalid "
                        "parameters: |origin_rows| {} greater than or equal "
                        "to window height {}. \n Reason: {}".format(
                            abs(origin_rows), window_rows, reason))

    if origin_cols > 0:
        reason = """
        origin_cols can't go into image. That would be cropping the first
        cols of the image and linebuffer doesn't do cropping.
        """
        raise Exception(
            "Aetherling's Native LineBuffer has invalid "
            "parameters: origin_cols {} greater than 0. \n Reason: {}".format(
                origin_cols, reason))

    if origin_rows > 0:
        reason = """
        origin_rows can't go into image. That would be cropping the first
        rows of the image and linebuffer doesn't do cropping.
        """
        raise Exception(
            "Aetherling's Native LineBuffer has invalid "
            "parameters: origin_rows {} greater than 0. \n Reason: {}".format(
                origin_rows, reason))

    if window_cols - origin_cols >= image_cols:
        reason = """
        need window_cols plus abs(origin_cols) outputs to do wiring. If
        the image_cols is smaller than this, will have issues with
        internal wiring. Additionally, the linebuffer isn't used for
        images that are small enough to be processed in one or a few
        clock cycles. This is a weird edge case that I don't want to deal
        with and shouldn't occur in the real world.
        """
        raise Exception(
            "Aetherling's Native LineBuffer has invalid "
            "parameters: window_cols {} - origin_cols {} "
            "greater than or equal to image_cols {}. \n Reason: {}".format(
                window_cols, origin_cols, image_cols, reason))

    if window_rows - origin_rows >= image_rows:
        reason = """
        need window_rows plus abs(origin_rows) outputs to do wiring. If
        the image_rows is smaller than this, will have issues with
        internal wiring. Additionally, the linebuffer isn't used for
        images that are small enough to be processed in one or a few
        clock cycles. This is a weird edge case that I don't want to deal
        with and shouldn't occur in the real world.
        """
        raise Exception(
            "Aetherling's Native LineBuffer has invalid "
            "parameters: window_rows {} - origin_rows {} "
            "greater than or equal to image_rows {}. \n Reason: {}".format(
                window_rows, origin_rows, image_rows, reason))

    # ---- naming and derived rates ---------------------------------------
    name = "TwoDimensionalLineBuffer_{}type_{}x{}pxPerClock_{}x{}window" \
           "_{}x{}img_{}x{}stride_{}x{}origin".format(
               cleanName(str(pixel_type)),
               pixels_per_row_per_clock,
               rows_of_pixels_per_clock,
               window_cols,
               window_rows,
               image_cols,
               image_rows,
               stride_cols,
               stride_rows,
               origin_cols,
               origin_rows
           )

    # If pixels per clock is greater than the stride, that many new windows
    # are emitted per clock; otherwise one is emitted per clock once enough
    # pixels have arrived. A buffer makes sure that windows come out at a
    # constant rate, not more than one per clock even if the overall rate is
    # not greater than 1.
    windows_per_active_clock = max(
        (rows_of_pixels_per_clock * pixels_per_row_per_clock) //
        (stride_rows * stride_cols), 1)

    # A buffered cycle is the length of time needed to collect and emit
    # windows at an even rate.
    windows_per_row_per_clock = max(
        pixels_per_row_per_clock // stride_cols, 1)
    rows_of_windows_per_clock = max(
        rows_of_pixels_per_clock // stride_rows, 1)
    time_per_buffered_cycle = (
        (image_cols * stride_rows) //
        (rows_of_pixels_per_clock * pixels_per_row_per_clock))

    # ---- interface ---------------------------------------------------------
    if add_debug_interface:
        debug_interface = [
            'undelayedO', Out(Array[windows_per_active_clock,
                                    Array[window_rows,
                                          Array[window_cols,
                                                Out(pixel_type)]]]),
            'dbCE', Out(Bit),
            'dbWE', Out(Bit)
        ] + GetDBDebugInterface(
            Array[window_rows, Array[window_cols, pixel_type]],
            image_cols // stride_cols,
            max(pixels_per_row_per_clock // stride_cols, 1),
        )
    else:
        debug_interface = []

    IO = [
        'I', In(Array[rows_of_pixels_per_clock,
                      Array[pixels_per_row_per_clock, In(pixel_type)]]),
        'O', Out(Array[windows_per_active_clock,
                       Array[window_rows,
                             Array[window_cols, Out(pixel_type)]]]),
        'valid', Out(Bit),
        'ready', Out(Bit)
    ] + debug_interface + ClockInterface(has_ce=True)

    @classmethod
    def definition(cls):
        """Instantiate the N-dimensional line buffer and wire its outputs.

        When ``stride_rows <= rows_of_pixels_per_clock`` the line buffer's
        windows and valid are forwarded directly. Otherwise they pass
        through a ``DelayedBuffer`` whose clock-enable is derived from the
        line buffer's valid, a one-shot counter and a registered copy of CE.
        """
        def window_positions():
            # Yield (flat index, row of windows, window within row) for
            # every window the line buffer emits per clock, row-major.
            for row in range(cls.rows_of_windows_per_clock):
                for col in range(cls.windows_per_row_per_clock):
                    yield row * cls.windows_per_row_per_clock + col, row, col

        lb = AnyDimensionalLineBuffer(
            pixel_type,
            [rows_of_pixels_per_clock, pixels_per_row_per_clock],
            [window_rows, window_cols],
            [image_rows, image_cols],
            [stride_rows, stride_cols],
            [origin_rows, origin_cols])
        wire(cls.I, lb.I)

        if stride_rows <= rows_of_pixels_per_clock:
            for flat, row, col in window_positions():
                wire(cls.O[flat], lb.O[row][col])
            wire(cls.valid, lb.valid)
        else:
            if add_debug_interface:
                # Expose the line buffer's windows before the delayed buffer.
                for flat, row, col in window_positions():
                    wire(cls.undelayedO[flat], lb.O[row][col])

            db = DelayedBuffer(
                Array[window_rows, Array[window_cols, pixel_type]],
                image_cols // stride_cols,
                max(pixels_per_row_per_clock // stride_cols, 1),
                cls.time_per_buffered_cycle,
                add_debug_interface=add_debug_interface)
            for flat, row, col in window_positions():
                wire(db.I[flat], lb.O[row][col])
            wire(lb.valid, db.WE)
            wire(db.valid, cls.valid)
            wire(db.O, cls.O)

            # The first time lb is valid, the delayed buffer becomes
            # enabled permanently: this counter advances exactly once.
            first_valid_counter = SizedCounterModM(2, has_ce=True)
            zero_const = DefineCoreirConst(1, 0)()
            wire(lb.valid & (zero_const.O == first_valid_counter.O),
                 first_valid_counter.CE)

            # Delay the CE of the delayed buffer: the LB output hits the DB
            # one clock later, so the DB is given the CE of that clock.
            delayed_ce_for_db_valid = DefineRegisterAnyType(Bit)()
            wire(bit(cls.CE), delayed_ce_for_db_valid.I)

            # Need lb.valid or the counter because lb.valid is already 1 on
            # the first valid clock while the counter is still 0.
            wire((lb.valid | first_valid_counter.O[0]) &
                 delayed_ce_for_db_valid.O, db.CE)

            if add_debug_interface:
                wire((lb.valid | first_valid_counter.O[0]) &
                     delayed_ce_for_db_valid.O, cls.dbCE)
                wire(lb.valid, cls.dbWE)
                wire(db.WDATA, cls.WDATA)
                wire(db.RDATA, cls.RDATA)
                wire(db.WADDR, cls.WADDR)
                wire(db.RADDR, cls.RADDR)
                wire(db.RAMWE, cls.RAMWE)

        wire(cls.CE, lb.CE)
        # ready is tied high.
        wire(cls.ready, 1)