def __init__(s, interface):
    """Register-array RAM whose read ports are updated from a tick block.

    Storage is ``s.interface.NumWords`` wires of type ``s.interface.Data``.
    The port counts come from the lengths of ``s.read_data`` and
    ``s.write_data`` supplied by the interface.

    In both variants a single ``tick_rtl`` block first performs every enabled
    write and then samples each read port at ``read_next_addr``:

    * Without ``Bypass`` all assignments use ``.n``. Reads therefore sample
      the contents from before this tick's writes.
    * With ``Bypass`` all assignments use ``.v``. This presumably relies on
      PyMTL's immediate ``.v`` semantics so that a read in the same tick sees
      data written in that tick (confirm against the simulator).
    """
    UseInterface(s, interface)
    word_count = s.interface.NumWords
    num_read_ports = len(s.read_data)
    num_write_ports = len(s.write_data)

    # The core storage: one wire per word.
    s.regs = [Wire(s.interface.Data) for _ in range(word_count)]

    if not s.interface.Bypass:

        @s.tick_rtl
        def handle_writes():
            # Writes and reads are both scheduled with .n (next value).
            for i in range(num_write_ports):
                if s.write_call[i]:
                    s.regs[s.write_addr[i]].n = s.write_data[i]
            for i in range(num_read_ports):
                s.read_data[i].n = s.regs[s.read_next_addr[i]]

    else:

        @s.tick_rtl
        def handle_writes():
            # Writes are done first with .v so the reads below can see them.
            for i in range(num_write_ports):
                if s.write_call[i]:
                    s.regs[s.write_addr[i]].v = s.write_data[i]
            for i in range(num_read_ports):
                s.read_data[i].v = s.regs[s.read_next_addr[i]]
def __init__(s, dtype, nports):
    """N-way multiplexer over ``nports`` inputs of type ``dtype``.

    ``mux_out`` combinationally follows ``mux_in_[mux_select]``.
    """
    UseInterface(s, MuxInterface(dtype, nports))

    @s.combinational
    def select():
        # Simulation-time check that the select names an existing input.
        assert s.mux_select < nports
        s.mux_out.v = s.mux_in_[s.mux_select]
def __init__(s, interface):
    """Fixed-priority arbiter that grants the lowest-indexed active request.

    ``reqs & -reqs`` isolates the least-significant set bit of the request
    vector. The result is a one-hot grant, or zero when nothing is requested.
    """
    UseInterface(s, interface)

    @s.combinational
    def compute():
        # PYMTL_BROKEN unary - translates but does not simulate, so the
        # two's-complement negation is written as (0 - reqs).
        s.grant_grant.v = s.grant_reqs & (0 - s.grant_reqs)
def __init__(s, interface):
    """Round-robin arbiter built from two fixed-priority arbiters.

    A register (reset to 0) holds the last grant shifted left by one. A
    ThermometerMask turns that value into a mask, presumably selecting the
    request bits at or above the saved position. Requests are arbitrated
    twice:

    * ``masked_arb`` sees only the requests that pass the mask.
    * ``raw_arb`` sees every request.

    The masked grant wins whenever it is non-zero. Otherwise the raw grant
    is used, which wraps the search back around to the low indices.
    """
    UseInterface(s, interface)
    width = s.interface.nreqs

    s.mask = Register(RegisterInterface(Bits(width)), reset_value=0)
    s.masker = ThermometerMask(ThermometerMaskInterface(width))
    s.raw_arb = PriorityArbiter(ArbiterInterface(width))
    s.masked_arb = PriorityArbiter(ArbiterInterface(width))
    s.final_grant = Wire(width)

    s.connect(s.raw_arb.grant_reqs, s.grant_reqs)
    s.connect(s.masker.mask_in_, s.mask.read_data)

    @s.combinational
    def compute():
        s.masked_arb.grant_reqs.v = s.grant_reqs & s.masker.mask_out
        # Fall back to the unmasked arbiter when no masked request exists.
        if s.masked_arb.grant_grant == 0:
            s.final_grant.v = s.raw_arb.grant_grant
        else:
            s.final_grant.v = s.masked_arb.grant_grant

    @s.combinational
    def shift_write():
        # Remember one position past the current grant for the next round.
        s.mask.write_data.v = s.final_grant << 1

    s.connect(s.grant_grant, s.final_grant)
def __init__(s, interface, reset_value=None):
    """A single register holding a value of type ``s.interface.Data``.

    The interface controls two options:

    * ``enable``: when set, the register loads only while ``write_call`` is
      high. Otherwise it loads ``write_data`` on every tick.
    * ``write_read_bypass``: when set, ``read_data`` shows ``write_data``
      combinationally whenever a load is happening. Otherwise it always
      shows the stored value.

    Args:
        interface: the register interface.
        reset_value: if not None, the value loaded while ``s.reset`` is high;
            reset takes priority over a write. If None, the register has no
            reset behavior.
    """
    UseInterface(s, interface)
    s.reg_value = Wire(s.interface.Data)
    s.update = Wire(1)

    # `update` is the effective load enable.
    if not s.interface.enable:
        s.connect(s.update, 1)
    else:
        s.connect(s.update, s.write_call)

    if not s.interface.write_read_bypass:
        s.connect(s.read_data, s.reg_value)
    else:

        @s.combinational
        def read():
            s.read_data.v = s.write_data if s.update else s.reg_value

    # Sequential update block, with or without a reset value.
    if reset_value is None:

        @s.tick_rtl
        def update():
            if s.update:
                s.reg_value.n = s.write_data

    else:

        @s.tick_rtl
        def update():
            if s.reset:
                s.reg_value.n = reset_value
            elif s.update:
                s.reg_value.n = s.write_data
def __init__(s, xlen):
    """Stub that drives a required ``kill_notify`` from a required redirect check.

    The model exposes no methods of its own. It requires:

    * ``check_redirect``: returns a 1-bit ``redirect`` and an ``xlen``-bit
      ``target``.
    * ``kill_notify``: takes a 1-bit ``msg``.

    The kill message is wired directly to the redirect flag.
    """
    UseInterface(s, Interface([]))

    redirect_spec = MethodSpec(
        'check_redirect',
        args={},
        rets={
            'redirect': Bits(1),
            'target': Bits(xlen),
        },
        call=False,
        rdy=False,
    )
    kill_spec = MethodSpec(
        'kill_notify',
        args={
            'msg': Bits(1),
        },
        rets=None,
        call=False,
        rdy=False,
    )
    s.require(redirect_spec, kill_spec)

    s.connect(s.kill_notify_msg, s.check_redirect_redirect)
def __init__(s):
    """Test model that packs ``make_pair_a`` / ``make_pair_b`` into a pair.

    ``make_pair_pair.first`` is driven from ``make_pair_a`` and
    ``make_pair_pair.second`` from ``make_pair_b``.
    """
    UseInterface(s, PairTestInterface())

    @s.combinational
    def do():
        # Fields are written through ``.v``, like every other signal write in
        # a combinational block. A bare ``struct.first = x`` is a Python
        # attribute assignment, not a signal write.
        s.make_pair_pair.first.v = s.make_pair_a
        s.make_pair_pair.second.v = s.make_pair_b
def __init__(s, mem_func):
    """Payload generator for memory instructions of kind ``mem_func``.

    Fields are decoded from the instruction's ``funct3``:

    * bit 2 is the ``unsigned`` flag;
    * bits [1:0] are the access ``width``.

    ``func`` is fixed to ``int(mem_func)``. ``gen_valid`` rejects encodings
    that are illegal for this kind of access.
    """
    UseInterface(s, PayloadGeneratorInterface(Bits(1), MemMsg()))
    s.connect(s.gen_payload.func, int(mem_func))

    # PYMTL_BROKEN: funct3 goes through a local wire before it is sliced.
    s.funct3 = Wire(3)
    s.connect(s.gen_inst.funct3, s.funct3)
    # PYMTL_BROKEN: `unsigned` is a Verilog keyword, hence the trailing `_`.
    s.unsigned_ = Wire(1)
    s.width = Wire(2)

    # PYMTL_BROKEN
    @s.combinational
    def handle_funct3():
        s.unsigned_.v = s.funct3[2]
        s.width.v = s.funct3[0:2]

    s.connect(s.gen_payload.unsigned, s.unsigned_)
    s.connect(s.gen_payload.width, s.width)

    # Loads: the only illegal combination is an unsigned double word
    # (width == 0b11 with unsigned == 1).
    # Stores: every width is legal, but the unsigned bit must be 0.
    if mem_func == MemFunc.MEM_FUNC_LOAD:

        @s.combinational
        def compute_valid():
            s.gen_valid.v = not (s.width == 0b11 and s.unsigned_ == 0b1)

    else:

        @s.combinational
        def compute_valid():
            s.gen_valid.v = not s.unsigned_
def __init__(s, inwidth):
    """Priority decoder: report the index of the lowest set ``decode_signal`` bit.

    The decoder is a ripple chain of ``inwidth + 1`` stages:

    * Stage 0 means "nothing found yet".
    * Stage ``k + 1`` keeps an earlier hit if there was one. Otherwise it
      records index ``k`` when bit ``k`` is set.

    The final stage drives ``decode_decoded`` and ``decode_valid``. The
    decoded output is 0 when no bit is set.
    """
    UseInterface(s, PriorityDecoderInterface(inwidth))
    s.valid = [Wire(1) for _ in range(inwidth + 1)]
    s.outs = [Wire(s.interface.Out) for _ in range(inwidth + 1)]

    # PYMTL_BROKEN: constants are driven from a block instead of s.connect.
    @s.combinational
    def connect_is_broken():
        s.valid[0].v = 0
        s.outs[0].v = 0

    for stage in range(1, inwidth + 1):
        # Stage `n` looks at input bit `i` and the result of stage `i`.
        @s.combinational
        def handle_decode(n=stage, i=stage - 1):
            if s.valid[i]:
                s.valid[n].v = 1
                s.outs[n].v = s.outs[i]
            elif s.decode_signal[i]:
                s.valid[n].v = 1
                s.outs[n].v = i
            else:
                s.valid[n].v = 0
                s.outs[n].v = 0

    s.connect(s.outs[inwidth], s.decode_decoded)
    s.connect(s.valid[inwidth], s.decode_valid)
def __init__(s, interface, svalues):
    """Case-style mux that forwards the input whose select value matches.

    ``svalues[k]`` is the select value for ``mux_in_[k]``. Matching is done
    in a chain that starts from ``mux_default``, and each matching stage
    overrides the earlier ones. If several entries share a select value, the
    one with the highest index therefore wins. ``mux_matched`` is 1 when any
    entry matched.
    """
    UseInterface(s, interface)
    size = s.interface.nports
    assert size == len(svalues)

    s.out_chain = [Wire(s.interface.Data) for _ in range(size + 1)]
    s.valid_chain = [Wire(1) for _ in range(size + 1)]

    # PYMTL_BROKEN: the chain head is driven from a block instead of s.connect.
    @s.combinational
    def connect_is_broken():
        s.out_chain[0].v = s.mux_default
        s.valid_chain[0].v = 0

    for last, svalue in enumerate(svalues):

        @s.combinational
        def chain(curr=last + 1, last=last, svalue=int(svalue)):
            if s.mux_select == svalue:
                s.out_chain[curr].v = s.mux_in_[last]
                s.valid_chain[curr].v = 1
            else:
                s.out_chain[curr].v = s.out_chain[last]
                s.valid_chain[curr].v = s.valid_chain[last]

    # Index `size` is the last element of each chain.
    s.connect(s.mux_out, s.out_chain[size])
    s.connect(s.mux_matched, s.valid_chain[size])
def __init__(s):
    """Adapt the raw multiplier output into the pipeline's output message.

    The original request is recovered from the low ``MMsg().nbits`` bits of
    ``fuse_kill_data.result`` (the "magic cast"). The output is
    ``fuse_kill_data`` with its ``result`` field replaced by:

    * ``op32``: the low 32 bits of ``fuse_internal_out``, sign-extended to
      ``XLEN``;
    * variant N or U: the low ``XLEN`` bits of ``fuse_internal_out``;
    * any other variant: bits ``[XLEN, 2*XLEN)`` of ``fuse_internal_out``.
    """
    UseInterface(s, MultOutputPipelineAdapterInterface())
    s.out_temp = Wire(s.interface.Out)
    s.out_mmsg = Wire(MMsg())
    s.out_32 = Wire(32)

    # PYMTL_BROKEN: temporary wires are used to work around a PyMTL bug.
    mmsg_nbits = MMsg().nbits

    # PYMTL_BROKEN: 2D array bug; copy the struct field into a plain wire.
    s.fuse_kill_data_result = Wire(XLEN)

    @s.combinational
    def connect_wire_workaround():
        s.fuse_kill_data_result.v = s.fuse_kill_data.result

    @s.combinational
    def compute_out(XLEN_2=2 * XLEN, num_bits=mmsg_nbits):
        # Magic cast: reinterpret the low bits of the stored result as an MMsg.
        s.out_mmsg.v = s.fuse_kill_data_result[:num_bits]
        s.out_temp.v = s.fuse_kill_data
        s.out_32.v = s.fuse_internal_out[:32]
        if s.out_mmsg.op32:
            s.out_temp.result.v = sext(s.out_32, XLEN)
        elif (s.out_mmsg.variant == MVariant.M_VARIANT_N or
              s.out_mmsg.variant == MVariant.M_VARIANT_U):
            s.out_temp.result.v = s.fuse_internal_out[:XLEN]
        else:
            s.out_temp.result.v = s.fuse_internal_out[XLEN:XLEN_2]

    s.connect(s.fuse_out, s.out_temp)
def __init__(s, stage_class, drop_controller_class=None):
    """Wrap ``stage_class`` in a ForwardingStage and insert a Forwarder.

    The inner stage is generated through ``gen_stage``, which also receives
    ``drop_controller_class``. All of its methods except ``forward`` are
    wrapped directly. ``forward``, ``peek`` and ``take`` are routed through a
    Forwarder sitting between the generated stage and this model's own
    methods. ``kill_notify`` is connected through only when this model
    exposes it.
    """

    def gen():
        return ForwardingStage(stage_class)

    # Give the generated stage the wrapped class's name.
    gen.__name__ = stage_class.__name__

    s.gen_stage = gen_stage(gen, drop_controller_class)()
    UseInterface(s, s.gen_stage.interface)
    s.wrap(s.gen_stage, ['forward'])

    forward_spec = MethodSpec(
        'forward',
        args={
            'tag': PREG_IDX_NBITS,
            'value': Bits(XLEN),
        },
        rets=None,
        call=True,
        rdy=False,
    )
    s.require(forward_spec)

    s.forwarder = Forwarder()
    routed = ('forward', 'peek', 'take')
    # Inner stage -> forwarder inputs (in_forward, in_peek, in_take).
    for method in routed:
        s.connect_m(getattr(s.forwarder, 'in_' + method),
                    getattr(s.gen_stage, method))
    # Forwarder outputs -> this model's forward, peek, take.
    for method in routed:
        s.connect_m(getattr(s, method), getattr(s.forwarder, method))

    if hasattr(s, 'kill_notify'):
        s.connect_m(s.gen_stage.kill_notify, s.kill_notify)
def __init__(s, alu_interface):
    """Single-cycle comparator producing a 1-bit result.

    Unless ``exec_unsigned`` is high, the most significant bit of each source
    is flipped before an unsigned comparison. Flipping the sign bit maps
    two's-complement order onto unsigned order, so a single unsigned ``<``
    handles both signed and unsigned comparisons.

    ``exec_func`` selects EQ, NE, LT or GE. Any other function yields 0.
    """
    UseInterface(s, alu_interface)
    xlen = s.interface.Xlen
    # PYMTL BROKEN:
    # Index of the most significant bit, precomputed as a plain constant.
    XLEN_M1 = xlen - 1

    # Inputs after sign adjustment, plus the decoded function.
    s.s0_ = Wire(xlen)
    s.s1_ = Wire(xlen)
    s.func_ = Wire(CMPFunc.bits)
    # Comparison flags.
    s.eq_ = Wire(1)
    s.lt_ = Wire(1)
    # Result.
    s.res_ = Wire(1)

    # Single-cycle unit, so it is always ready.
    s.connect(s.exec_rdy, 1)
    s.connect(s.exec_res, s.res_)
    s.connect(s.func_, s.exec_func)

    # All workarounds due to slicing-in-concat() issues: each source is split
    # into its MSB and its lower bits, then concatenated back.
    s.s0_lower_ = Wire(XLEN_M1)
    s.s0_up_ = Wire(1)
    s.s1_lower_ = Wire(XLEN_M1)
    s.s1_up_ = Wire(1)

    @s.combinational
    def set_flags():
        s.eq_.v = s.s0_ == s.s1_
        s.lt_.v = s.s0_ < s.s1_

    @s.combinational
    def set_signed():
        # Flip the uppermost bit when the comparison is signed.
        s.s0_up_.v = s.exec_src0[XLEN_M1] if s.exec_unsigned else not s.exec_src0[XLEN_M1]
        s.s1_up_.v = s.exec_src1[XLEN_M1] if s.exec_unsigned else not s.exec_src1[XLEN_M1]
        s.s0_lower_.v = s.exec_src0[0:XLEN_M1]
        s.s1_lower_.v = s.exec_src1[0:XLEN_M1]
        # Reassemble so the values can be compared as unsigned.
        s.s0_.v = concat(s.s0_up_, s.s0_lower_)
        s.s1_.v = concat(s.s1_up_, s.s1_lower_)

    @s.combinational
    def eval_comb():
        s.res_.v = 0
        if s.func_ == CMPFunc.CMP_EQ:
            s.res_.v = s.eq_
        elif s.func_ == CMPFunc.CMP_NE:
            s.res_.v = not s.eq_
        elif s.func_ == CMPFunc.CMP_LT:
            s.res_.v = s.lt_
        elif s.func_ == CMPFunc.CMP_GE:
            s.res_.v = not s.lt_ or s.eq_
def __init__(s, interface, clients):
    """Merge several peek/take clients into one peek/take interface.

    For each name in ``clients`` this model requires two methods:

    * ``<client>_peek``: returns ``msg`` of ``s.interface.MsgType``, with rdy;
    * ``<client>_take``: call only.

    A PriorityArbiter picks among the ready clients. A CaseMux keyed on the
    one-hot grant forwards the chosen client's message to ``peek_msg``, with
    0 as the default. ``peek_rdy`` is high whenever some client is granted.
    ``take_call`` is passed only to the granted client.
    """
    UseInterface(s, interface)

    # Two required methods per client, in client order: peek, then take.
    reqs = [
        spec for client in clients for spec in (
            MethodSpec(
                '{}_peek'.format(client),
                args=None,
                rets={
                    'msg': s.interface.MsgType,
                },
                call=False,
                rdy=True,
            ),
            MethodSpec(
                '{}_take'.format(client),
                args=None,
                rets=None,
                call=True,
                rdy=False,
            ),
        )
    ]
    s.require(*reqs)

    ninputs = len(clients)
    s.index_peek_msg = [Wire(s.interface.MsgType) for _ in range(ninputs)]
    s.index_peek_rdy = [Wire(1) for _ in range(ninputs)]
    s.index_take_call = [Wire(1) for _ in range(ninputs)]

    # Gather the per-client methods into index-addressable wire arrays.
    for i, client in enumerate(clients):
        peek = getattr(s, '{}_peek'.format(client))
        take = getattr(s, '{}_take'.format(client))
        s.connect(s.index_peek_msg[i], peek.msg)
        s.connect(s.index_peek_rdy[i], peek.rdy)
        s.connect(take.call, s.index_take_call[i])

    s.arb = PriorityArbiter(ArbiterInterface(ninputs))
    s.mux = CaseMux(
        CaseMuxInterface(s.interface.MsgType, Bits(ninputs), ninputs),
        [1 << i for i in range(ninputs)])

    @s.combinational
    def compute_ready():
        s.peek_rdy.v = (s.arb.grant_grant != 0)

    for i in range(ninputs):
        s.connect(s.arb.grant_reqs[i], s.index_peek_rdy[i])

        # Call an input if it is granted and we are being called.
        @s.combinational
        def compute_call(i=i):
            s.index_take_call[i].v = s.arb.grant_grant[i] & s.take_call

        s.connect(s.mux.mux_in_[i], s.index_peek_msg[i])

    s.connect(s.mux.mux_default, 0)
    s.connect(s.mux.mux_select, s.arb.grant_grant)
    s.connect(s.peek_msg, s.mux.mux_out)
def __init__(s, interface):
    """Drop controller that passes ``check_in_`` through unchanged.

    ``check_keep`` is the logical negation of ``check_msg``.
    """
    UseInterface(s, interface)
    s.connect(s.check_out, s.check_in_)

    @s.combinational
    def handle_check_keep():
        s.check_keep.v = not s.check_msg
def __init__(s, interface):
    """Thermometer mask: set every bit at or above the lowest set input bit.

    For non-zero ``mask_in_``, ``mask_in_ ^ (mask_in_ - 1)`` sets the lowest
    set bit and every bit below it. Inverting that and OR-ing ``mask_in_``
    back in leaves the lowest set bit and everything above it set. For zero
    input the XOR is already all ones and is used as-is.

    Each branch computes the result in a single expression. The block
    therefore no longer writes ``mask_out`` and reads it back: it is not
    sensitive to the signal it drives, and ``mask_out`` never passes through
    intermediate values.
    """
    UseInterface(s, interface)

    @s.combinational
    def compute():
        if s.mask_in_ == 0:
            # 0 ^ (0 - 1) wraps around to all ones.
            s.mask_out.v = s.mask_in_ ^ (s.mask_in_ - 1)
        else:
            s.mask_out.v = ~(s.mask_in_ ^ (s.mask_in_ - 1)) | s.mask_in_
def __init__(s):
    """Test stage that accepts every input and outputs it plus 2 (8-bit)."""
    UseInterface(s, StageInterface(Bits(8), Bits(8)))
    # Never applies backpressure.
    s.connect(s.process_accepted, 1)

    @s.combinational
    def compute():
        s.process_out.v = s.process_in_ + 2
def __init__(s):
    """Combine one sub-decoder per entry of ``classes`` into a single decoder.

    ``classes`` is a free variable, presumably supplied by an enclosing
    factory (not visible here). Each instantiated decoder becomes child ``k``
    of a CompositeDecoder, whose ``decode`` method is exposed as this
    model's ``decode``.
    """
    UseInterface(s, SubDecoderInterface())
    s.composite_decoder = CompositeDecoder(len(classes))
    s.decs = [class_() for class_ in classes]

    for child_idx, dec in enumerate(s.decs):
        s.connect_m(s.composite_decoder.decode_child[child_idx], dec.decode)

    s.connect_m(s.decode, s.composite_decoder.decode)
def __init__(s, interface):
    """Drop controller driven by kill / clear masks in ``check_msg``.

    Outputs:

    * ``check_keep`` is low when any bit of ``check_in_`` overlaps
      ``check_msg.kill_mask``, or when ``check_msg.force`` is set.
    * ``check_out`` is ``check_in_`` with the ``check_msg.clear_mask`` bits
      cleared.
    """
    UseInterface(s, interface)
    s.kill_match = Wire(s.interface.Out.nbits)

    @s.combinational
    def handle_check():
        s.kill_match.v = s.check_in_ & s.check_msg.kill_mask
        s.check_keep.v = not (reduce_or(s.kill_match) or s.check_msg.force)
        s.check_out.v = s.check_in_ & (~s.check_msg.clear_mask)
def __init__(s):
    """Test model that splits ``decompose_in_`` into its four fields."""
    UseInterface(s, StructTestInterface())

    @s.combinational
    def do():
        # One output port per struct field.
        s.decompose_w1.v = s.decompose_in_.w1
        s.decompose_w2.v = s.decompose_in_.w2
        s.decompose_w3.v = s.decompose_in_.w3
        s.decompose_w4.v = s.decompose_in_.w4
def __init__(s, interface):
    """Route one input stream to the client chosen by a required ``sort``.

    Required methods:

    * ``in_peek`` / ``in_take``: the upstream stream.
    * ``sort``: maps a message to a client index (``pipe``).

    Every client's peek sees ``in_peek_msg``, but only the client whose index
    equals ``sort_pipe`` reports ready. ``in_take_call`` is taken from the
    selected client's take call through a 1-bit mux.
    """
    UseInterface(s, interface)
    size = len(s.interface.clients)
    Pipe = Bits(clog2nz(size))

    in_peek_spec = MethodSpec(
        'in_peek',
        args=None,
        rets={
            'msg': s.interface.Data,
        },
        call=False,
        rdy=True,
    )
    in_take_spec = MethodSpec(
        'in_take',
        args=None,
        rets=None,
        call=True,
        rdy=False,
    )
    sort_spec = MethodSpec(
        'sort',
        args={'msg': s.interface.Data},
        rets={'pipe': Pipe},
        call=False,
        rdy=False,
    )
    s.require(in_peek_spec, in_take_spec, sort_spec)

    # Per-client method bundles, in client order.
    s.peek_array = [
        getattr(s, '{}_peek'.format(client)) for client in s.interface.clients
    ]
    s.take_array = [
        getattr(s, '{}_take'.format(client)) for client in s.interface.clients
    ]
    s.rdy_array = [Wire(1) for _ in range(size)]
    for idx in range(size):
        s.connect(s.peek_array[idx].rdy, s.rdy_array[idx])

    # The sorter classifies whatever message is currently at the input.
    s.connect(s.sort_msg, s.in_peek_msg)

    s.take_mux = Mux(Bits(1), size)
    s.effective_call = Wire(1)

    for i in range(size):

        @s.combinational
        def handle_rdy(i=i):
            s.rdy_array[i].v = (s.sort_pipe == i) and s.in_peek_rdy

        s.connect(s.peek_array[i].msg, s.in_peek_msg)
        s.connect(s.take_mux.mux_in_[i], s.take_array[i].call)

    s.connect(s.take_mux.mux_select, s.sort_pipe)
    s.connect(s.in_take_call, s.take_mux.mux_out)
def __init__(s, mul_interface, use_mul=True):
    """Combinational multiplier producing a ``ProductLen``-bit result.

    With ``use_mul`` (the default) the raw product is
    ``mult_src1 * mult_src2``. Otherwise it is built as a shift-and-add
    chain:

    * ``src1_`` is ``mult_src1`` zero-extended or truncated to
      ``ProductLen`` bits;
    * ``tmps_[k]`` is the sum of ``src1_ << j`` over every set bit ``j`` among
      the low ``k`` bits of ``mult_src2``.

    The raw product (``MultiplierLen + MultiplicandLen`` bits) is finally
    zero-extended or truncated to ``ProductLen`` bits.
    """
    UseInterface(s, mul_interface)
    assert s.interface.MultiplierLen >= s.interface.MultiplicandLen

    plen = s.interface.ProductLen
    res_len = s.interface.MultiplierLen + s.interface.MultiplicandLen
    s.tmp_res = Wire(res_len)

    if use_mul:

        @s.combinational
        def eval_mul():
            s.tmp_res.v = (s.mult_src1 * s.mult_src2)

    else:
        s.src1_ = Wire(plen)
        s.tmps_ = [
            Wire(res_len) for _ in range(s.interface.MultiplicandLen + 1)
        ]

        if plen >= s.interface.MultiplierLen:

            @s.combinational
            def src1_zext():
                s.src1_.v = s.mult_src1

        else:

            @s.combinational
            def src1_truncate():
                s.src1_.v = s.mult_src1[:plen]

        s.connect_wire(s.tmp_res, s.tmps_[s.interface.MultiplicandLen])

        # PYMTL_BROKEN: the direction of s.connect_wire(s.tmps_[0], 0) is
        # inferred wrong, so the chain head is driven from a block instead.
        @s.combinational
        def eval_base():
            s.tmps_[0].v = 0

        for i in range(1, s.interface.MultiplicandLen + 1):

            @s.combinational
            def eval_partial(i=i):
                s.tmps_[i].v = s.tmps_[i - 1]
                if s.mult_src2[i - 1]:
                    s.tmps_[i].v = s.tmps_[i - 1] + (s.src1_ << (i - 1))

    # Zero-extend or truncate the raw product to ProductLen.
    if plen > res_len:

        @s.combinational
        def zext_prod():
            s.mult_res.v = zext(s.tmp_res, plen)

    else:

        @s.combinational
        def trunc_prod():
            s.mult_res.v = s.tmp_res[:plen]
def __init__(s, dtype, nports):
    """Split ``unpack_packed`` into ``nports`` equal-width fields.

    Output ``k`` is bits ``[k * w, (k + 1) * w)`` of the packed input, where
    ``w`` is ``s.interface.Data.nbits``.
    """
    UseInterface(s, UnpackerInterface(dtype, nports))
    field_bits = s.interface.Data.nbits

    for i in range(nports):

        @s.combinational
        def unpack(i=i, start=i * field_bits, end=(i + 1) * field_bits):
            s.unpack_out[i].v = s.unpack_packed[start:end]
def __init__(s):
    """Two-stage test pipeline: a CounterStage feeding an Add2Stage.

    The second stage's ``peek`` / ``take`` are exposed as this model's own.
    """
    UseInterface(s, PipelineStageInterface(Bits(8), None))
    s.stage_0 = CounterStage()
    s.stage_1 = Add2Stage()

    # Stage 0 output -> stage 1 input.
    s.connect_m(s.stage_0.peek, s.stage_1.in_peek)
    s.connect_m(s.stage_0.take, s.stage_1.in_take)

    # Stage 1 output -> this model's output.
    s.connect_m(s.stage_1.peek, s.peek)
    s.connect_m(s.stage_1.take, s.take)
def __init__(s):
    """Test model: add the low 4 bits of two inputs and sign-extend to 8 bits."""
    UseInterface(s, FunctionalFormTestInterface())
    s.sum_4 = Wire(4)

    @s.combinational
    def add():
        s.sum_4.v = s.add_w4_a[:4] + s.add_w4_b[:4]

    # Sign-extend the 4-bit sum into the 8-bit output using the functional form.
    s.call(Sext, s.sum_4, 8, out=s.add_w4_sum)
def __init__(s):
    """Source stage that emits an 8-bit counter.

    The counter resets to 0, ``process_out`` shows its current value, and
    each ``process_call`` increments it.
    """
    UseInterface(s, StageInterface(None, Bits(8)))
    s.counter = Register(
        RegisterInterface(Bits(8), enable=True), reset_value=0)

    # Never applies backpressure.
    s.connect(s.process_accepted, 1)
    s.connect(s.process_out, s.counter.read_data)
    # Each process call is a register write of (current + 1).
    s.connect(s.counter.write_call, s.process_call)

    @s.combinational
    def count():
        s.counter.write_data.v = s.counter.read_data + 1
def __init__(s, interface, In, Intermediate=None):
    """Generic pipeline stage around a required ``process`` method.

    Args:
        interface: stage interface. ``MsgType`` is the output type (None for
            a sink) and ``KillArgType`` is the kill argument type (None when
            the stage is never killed).
        In: input message type, or None for a source stage with no upstream
            methods.
        Intermediate: type stored by the drop controller. Defaults to the
            output type.

    The stage advances when the upstream has data (always true for a source)
    and the output slot is free (always true for a sink). Upstream data is
    taken only when ``process`` accepts it. Outputs are held in a
    ValidValueManager, which implements ``check``, ``peek``, ``take`` and,
    when present, ``kill_notify``.
    """
    UseInterface(s, interface)
    # The intermediate type is the output type unless specified.
    Intermediate = Intermediate or s.interface.MsgType

    # Require the methods of an incoming pipeline stage, renamed
    # in_peek / in_take. If In is None the upstream has no methods.
    upstream_methods = PipelineStageInterface(In, None).methods.values()
    s.require(*[
        m.variant(name='in_{}'.format(m.name)) for m in upstream_methods
    ])
    s.require(StageInterface(In, interface.MsgType)['process'])

    # A stage that produces output also requires a drop controller.
    if interface.MsgType is not None:
        s.require(
            DropControllerInterface(interface.MsgType, Intermediate,
                                    interface.KillArgType)['check'])
        s.vvm = ValidValueManager(
            ValidValueManagerInterface(interface.MsgType, Intermediate,
                                       interface.KillArgType))

    s.input_available = Wire(1)
    s.output_clear = Wire(1)
    s.advance = Wire(1)
    s.taking = Wire(1)

    @s.combinational
    def handle_taking():
        s.taking.v = s.advance & s.process_accepted

    s.connect(s.process_call, s.advance)

    if In is None:
        # A source stage always has input available.
        s.connect(s.input_available, 1)
    else:
        s.connect(s.process_in_, s.in_peek_msg)
        s.connect(s.in_take_call, s.taking)
        s.connect(s.input_available, s.in_peek_rdy)

    if interface.MsgType is None:
        # A sink stage never waits for its output slot.
        s.connect(s.output_clear, 1)
    else:
        s.connect(s.vvm.add_msg, s.process_out)
        s.connect(s.vvm.add_call, s.taking)
        if s.interface.KillArgType is not None:
            s.connect_m(s.vvm.kill_notify, s.kill_notify)
        s.connect_m(s.vvm.check, s.check)
        s.connect_m(s.vvm.peek, s.peek)
        s.connect_m(s.vvm.take, s.take)
        s.connect(s.output_clear, s.vvm.add_rdy)

    @s.combinational
    def handle_advance():
        s.advance.v = s.output_clear and s.input_available
def __init__(s, interface):
    """Output ``fuse_internal_out`` with ``hdr_branch_mask`` replaced.

    The replacement value is ``fuse_kill_data``.
    """
    UseInterface(s, interface)
    # PYMTL_BROKEN: use a temporary wire to work around a PyMTL bug.
    s.out_temp = Wire(s.interface.Out)

    @s.combinational
    def compute_out():
        s.out_temp.v = s.fuse_internal_out
        s.out_temp.hdr_branch_mask.v = s.fuse_kill_data

    s.connect(s.fuse_out, s.out_temp)
def __init__(s, interface, key_groups, reset_values=None):
    """Register file addressed by arbitrary keys instead of indices.

    ``key_groups`` is a list like ``[(a, b), c]``. Each element is one
    storage slot. A tuple element lists keys that all alias the same slot.
    A key may appear only once across all groups.

    Each read and write port translates its key through its own
    LookupTable into an address of an AsynchronousRAM:

    * A read of an unknown key returns 0, with ``read_valid`` low.
    * A write with an unknown key is suppressed, with ``write_valid`` low.

    ``reset_values`` is passed through to the AsynchronousRAM.
    """
    UseInterface(s, interface)
    size = len(key_groups)
    Key = s.interface.Key
    Value = s.interface.Value
    num_read_ports = s.interface.num_read_ports
    num_write_ports = s.interface.num_write_ports

    # Map every key (including aliases) to its slot index.
    mapping = {}
    for slot, group in enumerate(key_groups):
        aliases = group if isinstance(group, tuple) else (group, )
        for key in aliases:
            assert key not in mapping
            mapping[key] = slot

    s.async_ram = AsynchronousRAM(
        AsynchronousRAMInterface(Value, size, num_read_ports, num_write_ports),
        reset_values=reset_values)

    def make_lut():
        # One key -> RAM address translator per port.
        return LookupTable(
            LookupTableInterface(Key, s.async_ram.interface.Addr), mapping)

    s.read_luts = [make_lut() for _ in range(num_read_ports)]
    s.write_luts = [make_lut() for _ in range(num_write_ports)]

    for i in range(num_read_ports):
        lut = s.read_luts[i]
        s.connect(lut.lookup_in_, s.read_key[i])
        s.connect(s.async_ram.read_addr[i], lut.lookup_out)
        s.connect(s.read_valid[i], lut.lookup_valid)

        # An unknown key reads as 0.
        @s.combinational
        def handle_invalid_read(i=i):
            if s.read_luts[i].lookup_valid:
                s.read_value[i].v = s.async_ram.read_data[i]
            else:
                s.read_value[i].v = 0

    for i in range(num_write_ports):
        lut = s.write_luts[i]
        s.connect(lut.lookup_in_, s.write_key[i])
        s.connect(s.async_ram.write_addr[i], lut.lookup_out)
        s.connect(s.async_ram.write_data[i], s.write_value[i])
        s.connect(s.write_valid[i], lut.lookup_valid)

        # Only forward the write when the key was recognized.
        @s.combinational
        def compute_write_call(i=i):
            s.async_ram.write_call[i].v = s.write_luts[i].lookup_valid and s.write_call[i]
def __init__(s, num_stages):
    """Front end that feeds a pipelined multiplier from an upstream stage.

    Operand signedness follows the variant:

    * src1 is signed unless the variant is U or HU;
    * src2 is signed only for N and H.

    For ``op32`` requests both operands are the low 32 bits sign-extended to
    ``XLEN``. A request is taken when the multiplier is ready and the
    upstream has data. The multiplier's ``peek`` / ``take`` are exposed as
    this model's output.
    """
    UseInterface(s, MultInternalInterface())

    # Require the methods of an incoming pipeline stage, renamed
    # in_peek / in_take.
    upstream_methods = PipelineStageInterface(MultIn(), None).methods.values()
    s.require(*[
        m.variant(name='in_{}'.format(m.name)) for m in upstream_methods
    ])

    s.multiplier = MulPipelined(
        MulPipelinedInterface(XLEN, keep_upper=True),
        nstages=num_stages,
        use_mul=True)

    # PYMTL_BROKEN: double array in a combinational block; copy the operands
    # into plain wires so they can be sliced.
    s.in_peek_msg_a = Wire(XLEN)
    s.in_peek_msg_b = Wire(XLEN)
    s.connect(s.in_peek_msg_a, s.in_peek_msg.a)
    s.connect(s.in_peek_msg_b, s.in_peek_msg.b)
    s.op32_a = Wire(32)
    s.op32_b = Wire(32)

    @s.combinational
    def set_inputs():
        s.op32_a.v = s.in_peek_msg_a[:32]
        s.op32_b.v = s.in_peek_msg_b[:32]
        s.multiplier.mult_src1_signed.v = not (
            s.in_peek_msg.variant == MVariant.M_VARIANT_U or
            s.in_peek_msg.variant == MVariant.M_VARIANT_HU)
        s.multiplier.mult_src2_signed.v = (
            s.in_peek_msg.variant == MVariant.M_VARIANT_N or
            s.in_peek_msg.variant == MVariant.M_VARIANT_H)
        if s.in_peek_msg.op32:
            s.multiplier.mult_src1.v = sext(s.op32_a, XLEN)
            s.multiplier.mult_src2.v = sext(s.op32_b, XLEN)
        else:
            s.multiplier.mult_src1.v = s.in_peek_msg.a
            s.multiplier.mult_src2.v = s.in_peek_msg.b

    # Starting a multiply and taking from upstream are the same event.
    s.connect(s.multiplier.mult_call, s.in_take_call)

    @s.combinational
    def set_in_take_call():
        s.in_take_call.v = s.multiplier.mult_rdy and s.in_peek_rdy

    # Connect the output side.
    s.connect(s.multiplier.take_call, s.take_call)
    s.connect(s.peek_rdy, s.multiplier.peek_rdy)
    s.connect(s.peek_msg, s.multiplier.peek_res)