def __init__(s, nbits=8, mbits=1):
  """Equality-compare datapath with a stored reference value.

  The reference register loads ``req_msg_data`` whenever the control unit
  raises ``reference_reg_en``.  Every cycle ``req_msg_data`` is compared
  for equality with the stored reference; the 1-bit outcome is captured by
  the result register when ``result_reg_en`` is high, and that register
  drives ``resp_msg_data``.

  Args:
    nbits: width of the request data and of the stored reference.
    mbits: accepted for interface compatibility; not used by this block.
  """

  #-------------------------------------------------------------------
  # Interface
  #-------------------------------------------------------------------

  s.req_msg_data     = InPort(nbits)
  s.resp_msg_data    = OutPort(1)
  s.result_reg_en    = InPort(1)   # ctrl -> dpath
  s.reference_reg_en = InPort(1)   # ctrl -> dpath

  # Internal nets
  s.reference_data = Wire(Bits(nbits))
  s.cmp_result     = Wire(Bits(1))

  #-------------------------------------------------------------------
  # Reference register
  #-------------------------------------------------------------------

  ref_reg = RegEnRst(nbits)
  s.Reg_reference_data = ref_reg
  s.connect_pairs(
    ref_reg.en,  s.reference_reg_en,
    ref_reg.in_, s.req_msg_data,
    ref_reg.out, s.reference_data,
  )

  #-------------------------------------------------------------------
  # Equality comparator: cmp_result = (req_msg_data == reference_data)
  #-------------------------------------------------------------------

  eq_cmp = EqComparator(nbits)
  s.EqComparator_1 = eq_cmp
  s.connect_pairs(
    eq_cmp.in0, s.req_msg_data,
    eq_cmp.in1, s.reference_data,
    eq_cmp.out, s.cmp_result,
  )

  #-------------------------------------------------------------------
  # Result register (drives the response)
  #-------------------------------------------------------------------

  res_reg = RegEnRst(1)
  s.Reg_cmp_result = res_reg
  s.connect_pairs(
    res_reg.en,  s.result_reg_en,
    res_reg.in_, s.cmp_result,
    res_reg.out, s.resp_msg_data,
  )
def __init__(s, nbits=6, k=3):
  """Datapath that compares a request against a stored maximum.

  ``isLarger`` reports ``req_msg_data > knn_data1``, where ``knn_data1``
  is either the request itself (``knn_mux_sel == 0``) or the value held
  in the max register (``knn_mux_sel == 1``).  When enabled, the max
  register loads the request value and the index register loads
  ``knn_counter``; their outputs form the response.  Presumably the
  control unit uses this to track the largest of k values and its
  index -- TODO confirm against the ctrl.

  Args:
    nbits: data width.
    k: number of candidates; sets the width of the index ports.
  """

  # Width needed to hold an index in [0, k).
  idx_nbits = int(math.ceil(math.log(k, 2)))

  # Interface
  s.req_msg_data  = InPort(nbits)
  s.resp_msg_data = OutPort(nbits)
  s.resp_msg_idx  = OutPort(idx_nbits)

  # dpath -> ctrl
  s.isLarger = OutPort(1)

  # ctrl -> dpath
  s.max_reg_en  = InPort(1)
  s.idx_reg_en  = InPort(1)
  s.knn_mux_sel = InPort(1)
  s.knn_counter = InPort(idx_nbits)

  # First comparator operand is always the incoming request.
  s.knn_data0 = Wire(Bits(nbits))
  s.connect(s.req_msg_data, s.knn_data0)

  s.knn_data1   = Wire(Bits(nbits))
  s.max_reg_out = Wire(Bits(nbits))

  # Second operand select: 0 -> request, 1 -> stored maximum.
  sel_mux = Mux(nbits, 2)
  s.knn_mux = sel_mux
  s.connect_pairs(
    sel_mux.sel,    s.knn_mux_sel,
    sel_mux.in_[0], s.req_msg_data,
    sel_mux.in_[1], s.max_reg_out,
    sel_mux.out,    s.knn_data1,
  )

  # isLarger = knn_data0 > knn_data1
  gt_cmp = GtComparator(nbits)
  s.knn_GtComparator = gt_cmp
  s.connect_pairs(
    gt_cmp.in0, s.knn_data0,
    gt_cmp.in1, s.knn_data1,
    gt_cmp.out, s.isLarger,
  )

  # Max register: captures the request value when max_reg_en is high.
  value_reg = RegEnRst(nbits)
  s.max_reg = value_reg
  s.connect_pairs(
    value_reg.en,  s.max_reg_en,
    value_reg.in_, s.knn_data0,
    value_reg.out, s.max_reg_out,
  )

  # Index register: captures knn_counter when idx_reg_en is high.
  index_reg = RegEnRst(idx_nbits)
  s.idx_reg = index_reg
  s.connect_pairs(
    index_reg.en,  s.idx_reg_en,
    index_reg.in_, s.knn_counter,
    index_reg.out, s.resp_msg_idx,
  )

  s.connect(s.max_reg_out, s.resp_msg_data)
def __init__( s, nreqs ):
  """Round-robin arbiter over ``nreqs`` request lines.

  ``grants`` has at most one bit set per cycle.  The priority register
  (reset value 1, i.e. requester 0 first) is loaded with ``grants``
  rotated left by one whenever any grant is issued.  As a result, the
  requester after the most recent winner has the highest priority next.

  Args:
    nreqs: number of requesters (width of ``reqs`` and ``grants``).
  """

  nreqsX2 = nreqs * 2

  s.reqs   = InPort ( nreqs )
  s.grants = OutPort( nreqs )

  # priority enable: asserted whenever some grant is issued

  s.priority_en = Wire( 1 )

  # priority register: one-hot priority, updated to grants rotated left by 1

  s.priority_reg = m = RegEnRst( nreqs, reset_value = 1 )

  s.connect_dict({
    m.en           : s.priority_en,
    m.in_[1:nreqs] : s.grants[0:nreqs-1],
    m.in_[0]       : s.grants[nreqs-1]
  })

  # The request/priority vectors are doubled so that a single linear kill
  # chain can scan from the priority position around to wrap-around.
  # kills[i] == 1 means "a grant was already issued earlier in the chain".

  s.kills        = Wire( 2*nreqs + 1 )
  s.priority_int = Wire( 2*nreqs )
  s.reqs_int     = Wire( 2*nreqs )
  s.grants_int   = Wire( 2*nreqs )

  #-------------------------------------------------------------------
  # comb
  #-------------------------------------------------------------------

  @s.combinational
  def comb():

    # Everything before the priority position starts out killed.
    s.kills[0].value = 1

    # Priority only in the lower copy; requests duplicated in both copies.
    s.priority_int[    0:nreqs  ].value = s.priority_reg.out
    s.priority_int[nreqs:nreqsX2].value = 0
    s.reqs_int    [    0:nreqs  ].value = s.reqs
    s.reqs_int    [nreqs:nreqsX2].value = s.reqs

    # Calculate the kill chain

    for i in range( nreqsX2 ):

      # Set internal grants: the priority position ignores the kill chain

      if s.priority_int[i].value:
        s.grants_int[i].value = s.reqs_int[i]

      else:
        s.grants_int[i].value = ~s.kills[i] & s.reqs_int[i]

      # Set kill signals: the priority position restarts the chain

      if s.priority_int[i].value:
        s.kills[i+1].value = s.grants_int[i]

      else:
        s.kills[i+1].value = s.kills[i] | s.grants_int[i]

    # Assign the output ports: fold the two copies back together

    for i in range( nreqs ):
      s.grants[i].value = s.grants_int[i] | s.grants_int[nreqs+i]

    # Set the priority enable

    s.priority_en.value = ( s.grants != 0 )
def __init__(s, num_cores=1):
  """Five-stage (F/D/X/M/W) processor datapath.

  Instantiates the PC logic, the instruction register, the register file,
  the immediate generator, the operand muxes, the ALU, the multiplier, and
  the pipeline registers between stages.  All mux selects and register
  enables come from the control unit through the ``*_F``/``*_D``/``*_X``/
  ``*_M``/``*_W`` input ports.  The decoded instruction, the branch
  conditions and the multiplier handshake are reported back through the
  status output ports.

  Args:
    num_cores: constant driven onto input 1 of the csrr select mux.
  """

  #---------------------------------------------------------------------
  # Interface
  #---------------------------------------------------------------------

  # Parameters

  s.core_id = InPort(32)

  # imem ports

  s.imemreq_msg       = OutPort(MemReqMsg4B)
  s.imemresp_msg_data = InPort(32)

  # dmem ports

  s.dmemreq_msg_addr  = OutPort(32)
  s.dmemreq_msg_data  = OutPort(32)
  s.dmemresp_msg_data = InPort(32)

  # mngr ports

  s.mngr2proc_data = InPort(32)
  s.proc2mngr_data = OutPort(32)

  # Control signals (ctrl->dpath)

  s.reg_en_F        = InPort(1)
  s.pc_sel_F        = InPort(2)

  s.reg_en_D        = InPort(1)
  s.op1_sel_D       = InPort(1)
  s.op2_sel_D       = InPort(2)
  s.csrr_sel_D      = InPort(2)
  s.imm_type_D      = InPort(3)
  s.imul_req_val_D  = InPort(1)

  s.reg_en_X        = InPort(1)
  s.alu_fn_X        = InPort(4)
  s.ex_result_sel_X = InPort(2)
  s.imul_resp_rdy_X = InPort(1)

  s.reg_en_M        = InPort(1)
  s.wb_result_sel_M = InPort(1)

  s.reg_en_W        = InPort(1)
  s.rf_waddr_W      = InPort(5)
  s.rf_wen_W        = InPort(1)
  s.stats_en_wen_W  = InPort(1)

  # Status signals (dpath->ctrl)

  s.inst_D          = OutPort(32)
  s.imul_req_rdy_D  = OutPort(1)

  s.br_cond_eq_X    = OutPort(1)
  s.br_cond_ltu_X   = OutPort(1)
  s.br_cond_lt_X    = OutPort(1)
  s.imul_resp_val_X = OutPort(1)

  # stats_en output

  s.stats_en = OutPort(1)

  #---------------------------------------------------------------------
  # F stage
  #---------------------------------------------------------------------

  s.pc_F       = Wire(32)
  s.pc_plus4_F = Wire(32)

  # PC+4 incrementer

  s.pc_incr_F = m = Incrementer(nbits=32, increment_amount=4)
  s.connect_pairs(
    m.in_, s.pc_F,
    m.out, s.pc_plus4_F,
  )

  # Forward declarations: the redirect targets are produced in later
  # stages but consumed by the PC select mux below.

  s.br_target_X   = Wire(32)
  s.jal_target_D  = Wire(32)
  s.jalr_target_X = Wire(32)

  # PC select mux: 0 = PC+4, 1 = branch target, 2 = jal target,
  # 3 = jalr target

  s.pc_sel_mux_F = m = Mux(dtype=32, nports=4)
  s.connect_pairs(
    m.in_[0], s.pc_plus4_F,
    m.in_[1], s.br_target_X,
    m.in_[2], s.jal_target_D,
    m.in_[3], s.jalr_target_X,
    m.sel,    s.pc_sel_F,
  )

  # The fetch address is the *next* PC taken straight from the select mux,
  # so the imem request goes out in the same cycle the PC register loads.

  @s.combinational
  def imem_req_F():
    s.imemreq_msg.addr.value = s.pc_sel_mux_F.out

  # PC register.  Reset to c_reset_vector - 4 so that, with pc_sel_F = 0,
  # the first fetch address (PC+4) is c_reset_vector.

  s.pc_reg_F = m = RegEnRst(dtype=32, reset_value=c_reset_vector - 4)
  s.connect_pairs(
    m.en,  s.reg_en_F,
    m.in_, s.pc_sel_mux_F.out,
    m.out, s.pc_F,
  )

  #---------------------------------------------------------------------
  # D stage
  #---------------------------------------------------------------------

  # PC reg in D stage.  It carries the fetch PC along with the instruction
  # so D can compute PC-relative targets (PC + imm).

  s.pc_reg_D = m = RegEnRst(dtype=32)
  s.connect_pairs(
    m.en,  s.reg_en_D,
    m.in_, s.pc_F,
  )

  # Instruction reg

  s.inst_D_reg = m = RegEnRst(dtype=32, reset_value=c_reset_inst)
  s.connect_pairs(
    m.en,  s.reg_en_D,
    m.in_, s.imemresp_msg_data,
    m.out, s.inst_D,            # to ctrl
  )

  # Register file.  The *_D read-data wires are redundant but make it
  # explicit that these values are read in the D stage.

  s.rf_rdata0_D = Wire(32)
  s.rf_rdata1_D = Wire(32)

  s.rf_wdata_W  = Wire(32)

  s.rf = m = RegisterFile(dtype=32, nregs=32, rd_ports=2, const_zero=True)
  s.connect_pairs(
    m.rd_addr[0], s.inst_D[RS1],
    m.rd_addr[1], s.inst_D[RS2],

    m.rd_data[0], s.rf_rdata0_D,
    m.rd_data[1], s.rf_rdata1_D,

    m.wr_en,      s.rf_wen_W,
    m.wr_addr,    s.rf_waddr_W,
    m.wr_data,    s.rf_wdata_W,
  )

  # Immediate generator

  s.imm_gen_D = m = ImmGenPRTL()
  s.connect_pairs(
    m.imm_type, s.imm_type_D,
    m.inst,     s.inst_D,
  )

  # csrr sel mux: 0 = mngr2proc, 1 = num_cores constant, 2 = core_id

  s.csrr_sel_mux_D = m = Mux(dtype=32, nports=3)
  s.connect_pairs(
    m.in_[0], s.mngr2proc_data,
    m.in_[1], num_cores,
    m.in_[2], s.core_id,
    m.sel,    s.csrr_sel_D,
  )

  # op1 sel mux: 0 = rs1, 1 = PC

  s.op1_sel_mux_D = m = Mux(dtype=32, nports=2)
  s.connect_pairs(
    m.in_[0], s.rf_rdata0_D,
    m.in_[1], s.pc_reg_D.out,
    m.sel,    s.op1_sel_D,
  )

  # op2 sel mux: 0 = rs2, 1 = immediate, 2 = csrr sel mux output.
  # Two muxes are chained here (csrr + op2) for pedagogy.

  s.op2_sel_mux_D = m = Mux(dtype=32, nports=3)
  s.connect_pairs(
    m.in_[0], s.rf_rdata1_D,
    m.in_[1], s.imm_gen_D.imm,
    m.in_[2], s.csrr_sel_mux_D.out,
    m.sel,    s.op2_sel_D,
  )

  # Branch/jal targets are both PC + imm, computed once in D.

  s.pc_plus_imm_D = m = Adder(32)
  s.connect_pairs(
    m.in0, s.pc_reg_D.out,
    m.in1, s.imm_gen_D.imm,
    m.out, s.jal_target_D,
  )

  #---------------------------------------------------------------------
  # X stage
  #---------------------------------------------------------------------

  # Branches resolve in X, so the target computed in D is registered
  # into X.

  s.br_target_reg_X = m = RegEnRst(dtype=32, reset_value=0)
  s.connect_pairs(
    m.en,  s.reg_en_X,
    m.in_, s.pc_plus_imm_D.out,
    m.out, s.br_target_X,
  )

  # op1 reg

  s.op1_reg_X = m = RegEnRst(dtype=32, reset_value=0)
  s.connect_pairs(
    m.en,  s.reg_en_X,
    m.in_, s.op1_sel_mux_D.out,
  )

  # op2 reg

  s.op2_reg_X = m = RegEnRst(dtype=32, reset_value=0)
  s.connect_pairs(
    m.en,  s.reg_en_X,
    m.in_, s.op2_sel_mux_D.out,
  )

  # dmemreq data reg: store data comes from register-file read port 1 (rs2)

  s.dmem_write_data_reg_X = m = RegEnRst(dtype=32, reset_value=0)
  s.connect_pairs(
    m.en,  s.reg_en_X,
    m.in_, s.rf_rdata1_D,
  )

  # pc reg

  s.pc_reg_X = m = RegEnRst(dtype=32, reset_value=0)
  s.connect_pairs(
    m.en,  s.reg_en_X,
    m.in_, s.pc_reg_D.out,
  )

  # ALU; its output also serves as the jalr target

  s.alu_X = m = AluPRTL()
  s.connect_pairs(
    m.in0,     s.op1_reg_X.out,
    m.in1,     s.op2_reg_X.out,
    m.fn,      s.alu_fn_X,
    m.ops_eq,  s.br_cond_eq_X,
    m.ops_ltu, s.br_cond_ltu_X,
    m.ops_lt,  s.br_cond_lt_X,
    m.out,     s.jalr_target_X,
  )

  # Multiplier.  Its request operands come straight from the D-stage
  # muxes, so the request is issued from D and the response read in X.

  s.imul_X = m = IntMulAltRTL()
  s.connect_pairs(
    m.req.msg[0:32],  s.op1_sel_mux_D.out,
    m.req.msg[32:64], s.op2_sel_mux_D.out,
    m.req.val,        s.imul_req_val_D,
    m.req.rdy,        s.imul_req_rdy_D,
    m.resp.val,       s.imul_resp_val_X,
    m.resp.rdy,       s.imul_resp_rdy_X,
  )

  # PC+4 incrementer (link value for jal/jalr)

  s.pc_incr_X = m = Incrementer(nbits=32, increment_amount=4)
  s.connect_pairs(
    m.in_, s.pc_reg_X.out,
  )

  # ex result mux: 0 = PC+4, 1 = ALU, 2 = multiplier

  s.ex_result_sel_mux_X = m = Mux(dtype=32, nports=3)
  s.connect_pairs(
    m.in_[0], s.pc_incr_X.out,
    m.in_[1], s.alu_X.out,
    m.in_[2], s.imul_X.resp.msg,
    m.sel,    s.ex_result_sel_X,
  )

  # dmemreq address and data

  s.connect(s.dmemreq_msg_addr, s.alu_X.out)
  s.connect(s.dmemreq_msg_data, s.dmem_write_data_reg_X.out)

  #---------------------------------------------------------------------
  # M stage
  #---------------------------------------------------------------------

  # Execution result reg

  s.ex_result_reg_M = m = RegEnRst(dtype=32, reset_value=0)
  s.connect_pairs(
    m.en,  s.reg_en_M,
    m.in_, s.ex_result_sel_mux_X.out,
  )

  # Writeback result selection mux: 0 = execution result, 1 = load data

  s.wb_result_sel_mux_M = m = Mux(dtype=32, nports=2)
  s.connect_pairs(
    m.in_[0], s.ex_result_reg_M.out,
    m.in_[1], s.dmemresp_msg_data,
    m.sel,    s.wb_result_sel_M,
  )

  #---------------------------------------------------------------------
  # W stage
  #---------------------------------------------------------------------

  # Writeback result reg

  s.wb_result_reg_W = m = RegEnRst(dtype=32, reset_value=0)
  s.connect_pairs(
    m.en,  s.reg_en_W,
    m.in_, s.wb_result_sel_mux_M.out,
  )

  s.connect(s.proc2mngr_data, s.wb_result_reg_W.out)
  s.connect(s.rf_wdata_W,     s.wb_result_reg_W.out)

  # stats_en register: loaded with the W-stage result when the control
  # unit asserts stats_en_wen_W.

  s.stats_en_reg_W = m = RegEnRst(dtype=32, reset_value=0)
  s.connect_pairs(
    m.en,  s.stats_en_wen_W,
    m.in_, s.wb_result_reg_W.out,
  )

  @s.combinational
  def stats_en_logic_W():
    # Reduction-OR of the 32-bit register.  The comparison is used instead
    # of the Python builtin any(), which iterates over Bits and is not
    # part of the translatable combinational subset.
    s.stats_en.value = s.stats_en_reg_W.out != 0
def __init__(s):
  """Per-stage pipeline control: valid bit, stall and squash handling.

  The stage's valid bit is held in ``val_reg``.  It is loaded from
  ``pvalid`` whenever the stage's pipeline registers are enabled, and it
  is exposed as ``pipereg_val``.  There are two stall sources
  (``nstall``/``ostall``) and two squash sources (``nsquash``/
  ``osquash``); presumably they come from the next stage and from this
  stage -- TODO confirm.  Per cycle:

  * ``nsquash`` high: emit a bubble (``nvalid = 0``), enable the
    registers, no ``pipe_go``.
  * otherwise, ``nstall`` or ``ostall`` high: emit a bubble and hold the
    registers, no ``pipe_go``.
  * otherwise: ``nvalid`` and ``pipe_go`` follow the valid bit, and the
    registers are enabled.

  ``pstall``/``psquash`` are the ORs of the respective stall/squash
  inputs.
  """

  #-------------------------------------------------------------------
  # Input Ports
  #-------------------------------------------------------------------

  s.pvalid  = InPort(1)
  s.nstall  = InPort(1)
  s.nsquash = InPort(1)
  s.ostall  = InPort(1)
  s.osquash = InPort(1)

  #-------------------------------------------------------------------
  # Output Ports
  #-------------------------------------------------------------------

  s.nvalid      = OutPort(1)
  s.pstall      = OutPort(1)
  s.psquash     = OutPort(1)
  s.pipereg_en  = OutPort(1)
  s.pipereg_val = OutPort(1)
  s.pipe_go     = OutPort(1)

  #-------------------------------------------------------------------
  # Valid-bit register
  #-------------------------------------------------------------------

  s.val_reg = RegEnRst(1, reset_value=0)
  s.connect_pairs(
    s.val_reg.in_, s.pvalid,
    s.val_reg.out, s.pipereg_val,
  )

  #-------------------------------------------------------------------
  # Combinational Logic
  #-------------------------------------------------------------------

  @s.combinational
  def comb():

    if s.nsquash.value:
      # Squashed: drop the instruction, but keep the registers moving.
      s.nvalid.value     = 0
      s.pipereg_en.value = 1
      s.val_reg.en.value = 1
      s.pipe_go.value    = 0

    elif s.nstall.value or s.ostall.value:
      # Stalled: hold the registers and send a bubble downstream.
      s.nvalid.value     = 0
      s.pipereg_en.value = 0
      s.val_reg.en.value = 0
      s.pipe_go.value    = 0

    else:
      # Normal flow: the valid bit passes through and decides pipe_go.
      s.nvalid.value     = s.val_reg.out.value
      s.pipereg_en.value = 1
      s.val_reg.en.value = 1
      s.pipe_go.value    = s.val_reg.out.value

    # Forward the combined stall and squash requests.
    s.pstall.value  = s.nstall.value  | s.ostall.value
    s.psquash.value = s.nsquash.value | s.osquash.value
def __init__(s, mbits=1):
  """Two-state control FSM for the compare unit.

  IDLE accepts a request (``req_rdy = 1``) and, on the handshake, moves to
  CMP.  CMP presents the response (``resp_val = 1``) and returns to IDLE
  once it is accepted.  While in IDLE the result and message-type
  registers are enabled.  The reference register is enabled only when a
  valid request of type 1 is present.

  Args:
    mbits: width of the request/response message-type field.
  """

  # Interface
  s.req_val  = InPort(1)
  s.req_rdy  = OutPort(1)
  s.resp_val = OutPort(1)
  s.resp_rdy = InPort(1)

  s.req_msg_type  = InPort(mbits)
  s.resp_msg_type = OutPort(mbits)

  # ctrl -> dpath
  s.result_reg_en    = OutPort(1)
  s.reference_reg_en = OutPort(1)

  s.msg_type_reg_en = Wire(Bits(1))

  s.STATE_IDLE = 0
  s.STATE_CMP  = 1  # do compare or store reference data

  s.state = RegRst(1, reset_value=s.STATE_IDLE)

  #------------------------------------------------------
  # state transition logic
  #------------------------------------------------------

  @s.combinational
  def state_transitions():

    curr_state = s.state.out
    next_state = s.state.out

    if (curr_state == s.STATE_IDLE):
      if (s.req_val and s.req_rdy):
        next_state = s.STATE_CMP

    if (curr_state == s.STATE_CMP):
      if (s.resp_val and s.resp_rdy):
        next_state = s.STATE_IDLE

    s.state.in_.value = next_state

  #------------------------------------------------------
  # state output logic
  #------------------------------------------------------

  @s.combinational
  def state_outputs():

    current_state = s.state.out

    if (current_state == s.STATE_IDLE):
      s.req_rdy.value       = 1
      s.resp_val.value      = 0
      s.result_reg_en.value = 1
      # Gate with req_val: the reference register persists across
      # transactions, so it must not load whatever message bits happen to
      # be on the bus while no valid request is present.
      s.reference_reg_en.value = s.req_val & s.req_msg_type
      s.msg_type_reg_en.value  = 1

    elif (current_state == s.STATE_CMP):
      s.req_rdy.value          = 0
      s.resp_val.value         = 1
      s.result_reg_en.value    = 0
      s.reference_reg_en.value = 0
      s.msg_type_reg_en.value  = 0

  # Register for resp msg type (same width as the message-type ports)
  s.Reg_msg_type = m = RegEnRst(mbits)
  s.connect_dict({
    m.en:  s.msg_type_reg_en,
    m.in_: s.req_msg_type,
    m.out: s.resp_msg_type,
  })
def __init__(s, mapper_num=3, nbits=6, k=3, rst_value=50):
  """k-entry table keeping the smallest distances seen so far.

  Each cycle, the smallest of the ``mapper_num`` inputs is compared with
  the largest table entry.  If it is strictly smaller, it overwrites that
  entry.  Asserting ``rst`` writes ``rst_value`` into every entry.
  ``out`` is the sum of all k entries.

  Args:
    mapper_num: number of distance inputs.
    nbits: width of each distance.
    k: number of table entries.
    rst_value: value loaded into every entry on reset / ``rst``.
  """

  addr_nbits = int(math.ceil(math.log(k, 2)))
  sum_nbits  = int(math.ceil(math.log((2**nbits - 1) * k, 2)))

  # Interface
  s.in_ = [InPort(nbits) for _ in range(mapper_num)]
  s.out = OutPort(sum_nbits)
  s.rst = InPort(1)

  # Internal wires
  s.min_dist = Wire(Bits(nbits))
  s.max_dist = Wire(Bits(nbits))
  s.max_idx  = Wire(Bits(addr_nbits))
  s.isLess   = Wire(Bits(1))
  s.rd_data  = Wire[k](Bits(nbits))
  s.wr_data  = Wire(Bits(nbits))
  s.wr_addr  = Wire[k](Bits(addr_nbits))  # not driven or read in this block
  s.wr_en    = Wire[k](Bits(1))

  # Smallest incoming distance -> min_dist
  s.findmin = FindMin(nbits, mapper_num)
  for i, src in enumerate(s.in_):
    s.connect(s.findmin.in_[i], src)
  s.connect(s.findmin.out, s.min_dist)

  # Largest table entry and its index -> max_dist / max_idx
  s.findmax = FindMaxIdx(nbits, k)
  for i in range(k):
    s.connect(s.findmax.in_[i], s.rd_data[i])
  s.connect_pairs(
    s.findmax.out, s.max_dist,
    s.findmax.idx, s.max_idx,
  )

  # isLess = min_dist < max_dist
  s.cmp = LtComparator(nbits)
  s.connect_pairs(
    s.cmp.in0, s.min_dist,
    s.cmp.in1, s.max_dist,
    s.cmp.out, s.isLess,
  )

  # Write enables: all entries on rst, otherwise only the current maximum
  # (and only if the new minimum beats it).
  @s.combinational
  def comb_logic():
    for i in xrange(k):
      if s.rst == 1:
        s.wr_en[i].value = 1
      elif i == s.max_idx:
        s.wr_en[i].value = s.isLess
      else:
        s.wr_en[i].value = 0

  # Write data: min_dist normally, rst_value while rst is asserted.
  data_mux = Mux(nbits, 2)
  s.wr_data_mux = data_mux
  s.connect_pairs(
    data_mux.in_[0], s.min_dist,
    data_mux.in_[1], rst_value,
    data_mux.sel,    s.rst,
    data_mux.out,    s.wr_data,
  )

  # Table storage
  s.regs = [RegEnRst(nbits, rst_value) for _ in range(k)]
  for i, entry in enumerate(s.regs):
    s.connect_pairs(
      entry.in_, s.wr_data,
      entry.out, s.rd_data[i],
      entry.en,  s.wr_en[i],
    )

  # Sum of all table entries
  s.addertree = AdderTree(nbits, k)
  for i in range(k):
    s.connect(s.addertree.in_[i], s.rd_data[i])
  s.connect(s.addertree.out, s.out)
def __init__(s, k=3):
  """Datapath of the KNN classifier.

  Holds two register files:

  * ``knn_table``: ``k * DIGIT`` entries of ``RE_DATA_SIZE`` bits.  They
    are fed through ``FindMaxPRTL``, which reports the largest value
    streamed in and its index.
  * ``knn_vote``: ``DIGIT`` entries of ``sum_nbits`` bits.  They are fed
    through ``FindMinPRTL``, whose winning index is latched into the
    response register (``resp_msg_digit``).

  Select 0 on either write-data mux writes the init constant (50 per table
  entry, ``50 * k`` per vote entry).  Select 1 writes the registered
  request (table) or the updated vote ``vote - table_max + request``.
  The address layout is chosen by the control unit.

  Args:
    k: number of neighbours kept per digit.
  """

  # Derived widths
  sum_nbits       = int(math.ceil(math.log(50 * k, 2)))
  knn_addr_nbits  = int(math.ceil(math.log(k * DIGIT, 2)))  # max 30 entries
  vote_addr_nbits = int(math.ceil(math.log(DIGIT, 2)))      # max 10 entries
  knn_idx_nbits   = int(math.ceil(math.log(k, 2)))          # max 3 entries

  s.req_msg_data   = InPort(RE_DATA_SIZE)
  s.resp_msg_digit = OutPort(4)

  # ctrl -> dpath
  s.knn_wr_data_mux_sel  = InPort(1)
  s.knn_wr_addr          = InPort(knn_addr_nbits)
  s.knn_rd_addr          = InPort(knn_addr_nbits)
  s.knn_wr_en            = InPort(1)
  s.vote_wr_data_mux_sel = InPort(1)
  s.vote_wr_addr         = InPort(vote_addr_nbits)
  s.vote_rd_addr         = InPort(vote_addr_nbits)
  s.vote_wr_en           = InPort(1)
  s.FindMax_req_val      = InPort(1)
  s.FindMax_resp_rdy     = InPort(1)
  s.FindMin_req_val      = InPort(1)
  s.FindMin_resp_rdy     = InPort(1)
  s.msg_data_reg_en      = InPort(1)
  s.msg_idx_reg_en       = InPort(1)

  # dpath -> ctrl
  s.FindMax_req_rdy  = OutPort(1)
  s.FindMax_resp_val = OutPort(1)
  s.FindMax_resp_idx = OutPort(knn_idx_nbits)
  s.FindMin_req_rdy  = OutPort(1)
  s.FindMin_resp_val = OutPort(1)
  s.isSmaller        = OutPort(1)

  # Internal wires
  s.knn_rd_data       = Wire(Bits(RE_DATA_SIZE))
  s.knn_wr_data       = Wire(Bits(RE_DATA_SIZE))
  s.subtractor_out    = Wire(Bits(sum_nbits))
  s.adder_out         = Wire(Bits(sum_nbits))
  s.vote_rd_data      = Wire(Bits(sum_nbits))
  s.vote_wr_data      = Wire(Bits(sum_nbits))
  s.FindMax_req_data  = Wire(Bits(RE_DATA_SIZE))
  s.FindMax_resp_data = Wire(Bits(RE_DATA_SIZE))
  s.FindMin_req_data  = Wire(Bits(sum_nbits))
  s.FindMin_resp_data = Wire(Bits(sum_nbits))
  s.FindMin_resp_idx  = Wire(Bits(vote_addr_nbits))

  # Request data register
  s.req_msg_data_q = Wire(Bits(RE_DATA_SIZE))
  data_reg = RegEnRst(RE_DATA_SIZE)
  s.req_msg_data_reg = data_reg
  s.connect_pairs(
    data_reg.en,  s.msg_data_reg_en,
    data_reg.in_, s.req_msg_data,
    data_reg.out, s.req_msg_data_q,
  )

  # knn_table write data: 0 -> init constant 50, 1 -> registered request
  knn_mux = Mux(RE_DATA_SIZE, 2)
  s.knn_wr_data_mux = knn_mux
  s.connect_pairs(
    knn_mux.sel,    s.knn_wr_data_mux_sel,
    knn_mux.in_[0], 50,
    knn_mux.in_[1], s.req_msg_data_q,
    knn_mux.out,    s.knn_wr_data,
  )

  # knn_table register file
  table = RegisterFile(dtype=Bits(RE_DATA_SIZE), nregs=k * DIGIT,
                       rd_ports=1, wr_ports=1, const_zero=False)
  s.knn_table = table
  s.connect_pairs(
    table.rd_addr[0], s.knn_rd_addr,
    table.rd_data[0], s.knn_rd_data,
    table.wr_addr,    s.knn_wr_addr,
    table.wr_data,    s.knn_wr_data,
    table.wr_en,      s.knn_wr_en,
  )

  # knn_vote write data: 0 -> init constant 50*k, 1 -> updated vote
  vote_mux = Mux(sum_nbits, 2)
  s.vote_wr_data_mux = vote_mux
  s.connect_pairs(
    vote_mux.sel,    s.vote_wr_data_mux_sel,
    vote_mux.in_[0], 50 * k,
    vote_mux.in_[1], s.adder_out,
    vote_mux.out,    s.vote_wr_data,
  )

  # knn_vote register file
  votes = RegisterFile(dtype=Bits(sum_nbits), nregs=DIGIT,
                       rd_ports=1, wr_ports=1, const_zero=False)
  s.knn_vote = votes
  s.connect_pairs(
    votes.rd_addr[0], s.vote_rd_addr,
    votes.rd_data[0], s.vote_rd_data,
    votes.wr_addr,    s.vote_wr_addr,
    votes.wr_data,    s.vote_wr_data,
    votes.wr_en,      s.vote_wr_en,
  )

  # Maximum of the knn_table values streamed through knn_rd_data
  s.connect_wire(s.knn_rd_data, s.FindMax_req_data)
  fmax = FindMaxPRTL(RE_DATA_SIZE, k)
  s.findmax = fmax
  s.connect_pairs(
    fmax.req.val,       s.FindMax_req_val,
    fmax.req.rdy,       s.FindMax_req_rdy,
    fmax.req.msg.data,  s.FindMax_req_data,
    fmax.resp.val,      s.FindMax_resp_val,
    fmax.resp.rdy,      s.FindMax_resp_rdy,
    fmax.resp.msg.data, s.FindMax_resp_data,
    fmax.resp.msg.idx,  s.FindMax_resp_idx,
  )

  # isSmaller = registered request < current table maximum
  lt = LtComparator(RE_DATA_SIZE)
  s.knn_LtComparator = lt
  s.connect_pairs(
    lt.in0, s.req_msg_data_q,
    lt.in1, s.FindMax_resp_data,
    lt.out, s.isSmaller,
  )

  # Widen the table maximum to the vote width
  s.FindMax_resp_data_zext = Wire(Bits(sum_nbits))
  max_zext = ZeroExtender(RE_DATA_SIZE, sum_nbits)
  s.FindMax_resp_data_zexter = max_zext
  s.connect_pairs(
    max_zext.in_, s.FindMax_resp_data,
    max_zext.out, s.FindMax_resp_data_zext,
  )

  # subtractor_out = vote - table maximum
  sub = Subtractor(sum_nbits)
  s.subtractor = sub
  s.connect_pairs(
    sub.in0, s.vote_rd_data,
    sub.in1, s.FindMax_resp_data_zext,
    sub.out, s.subtractor_out,
  )

  # Widen the registered request to the vote width
  s.req_msg_data_zext = Wire(Bits(sum_nbits))
  req_zext = ZeroExtender(RE_DATA_SIZE, sum_nbits)
  s.req_msg_data_zexter = req_zext
  s.connect_pairs(
    req_zext.in_, s.req_msg_data_q,
    req_zext.out, s.req_msg_data_zext,
  )

  # adder_out = (vote - table maximum) + request
  add = Adder(sum_nbits)
  s.adder = add
  s.connect_pairs(
    add.in0, s.subtractor_out,
    add.in1, s.req_msg_data_zext,
    add.cin, 0,
    add.out, s.adder_out,
  )

  # Minimum of the knn_vote values streamed through vote_rd_data; its
  # index is the answer digit.
  s.connect_wire(s.vote_rd_data, s.FindMin_req_data)
  fmin = FindMinPRTL(sum_nbits, DIGIT)
  s.findmin = fmin
  s.connect_pairs(
    fmin.req.val,        s.FindMin_req_val,
    fmin.req.rdy,        s.FindMin_req_rdy,
    fmin.req.msg.data,   s.FindMin_req_data,
    fmin.resp.val,       s.FindMin_resp_val,
    fmin.resp.rdy,       s.FindMin_resp_rdy,
    fmin.resp.msg.data,  s.FindMin_resp_data,
    fmin.resp.msg.digit, s.FindMin_resp_idx,
  )

  # Response index register
  s.resp_msg_idx_q = Wire(Bits(vote_addr_nbits))
  idx_reg = RegEnRst(vote_addr_nbits)
  s.req_msg_idx_reg = idx_reg
  s.connect_pairs(
    idx_reg.en,  s.msg_idx_reg_en,
    idx_reg.in_, s.FindMin_resp_idx,
    idx_reg.out, s.resp_msg_idx_q,
  )

  # Output digit
  s.connect(s.resp_msg_idx_q, s.resp_msg_digit)
def __init__(s, k=3):
  """Control FSM of the KNN classifier.

  States:

  * IDLE: accepts a request.  Type 1 starts INIT, type 0 starts MAX, and
    type 2 starts MIN.
  * INIT: writes every knn_table entry (and the first DIGIT vote entries)
    through the init path of the write-data muxes.
  * MAX: streams the k table entries of the requested digit into FindMax.
    Once they are all read, it overwrites the maximum entry (and updates
    the vote) if the new request is smaller.
  * MIN: streams the DIGIT vote entries into FindMin and latches the
    resulting index.
  * DONE: presents the response until it is accepted.

  Every output is driven in every reachable state, so no latches are
  inferred.

  Args:
    k: number of neighbours kept per digit.
  """

  knn_addr_nbits  = int(math.ceil(math.log(k * DIGIT, 2)))  # max 30 entries
  vote_addr_nbits = int(math.ceil(math.log(DIGIT, 2)))      # max 10 entries
  knn_idx_nbits   = int(math.ceil(math.log(k, 2)))          # index < k

  # ctrl <-> toplevel
  s.req_val       = InPort(1)
  s.req_rdy       = OutPort(1)
  s.resp_val      = OutPort(1)
  s.resp_rdy      = InPort(1)
  s.req_msg_digit = InPort(4)
  s.req_msg_type  = InPort(2)
  s.resp_msg_type = OutPort(2)

  # ctrl -> dpath
  s.knn_wr_data_mux_sel  = OutPort(1)
  s.knn_wr_addr          = OutPort(knn_addr_nbits)
  s.knn_rd_addr          = OutPort(knn_addr_nbits)
  s.knn_wr_en            = OutPort(1)
  s.vote_wr_data_mux_sel = OutPort(1)
  s.vote_wr_addr         = OutPort(vote_addr_nbits)
  s.vote_rd_addr         = OutPort(vote_addr_nbits)
  s.vote_wr_en           = OutPort(1)
  s.FindMax_req_val      = OutPort(1)
  s.FindMax_resp_rdy     = OutPort(1)
  s.FindMin_req_val      = OutPort(1)
  s.FindMin_resp_rdy     = OutPort(1)
  s.msg_data_reg_en      = OutPort(1)
  s.msg_idx_reg_en       = OutPort(1)

  # dpath -> ctrl
  s.FindMax_req_rdy  = InPort(1)
  s.FindMax_resp_val = InPort(1)
  s.FindMax_resp_idx = InPort(knn_idx_nbits)  # index within a digit's k entries
  s.FindMin_req_rdy  = InPort(1)
  s.FindMin_resp_val = InPort(1)
  s.isSmaller        = InPort(1)

  s.msg_type_reg_en  = Wire(Bits(1))
  s.msg_digit_reg_en = Wire(Bits(1))
  s.req_msg_digit_q  = Wire(Bits(4))
  s.init_go          = Wire(Bits(1))
  s.max_go           = Wire(Bits(1))
  s.min_go           = Wire(Bits(1))

  # State element
  s.STATE_IDLE = 0
  s.STATE_INIT = 1
  s.STATE_MAX  = 2
  s.STATE_MIN  = 3
  s.STATE_DONE = 4

  s.state = RegRst(3, reset_value=s.STATE_IDLE)

  # Counters
  s.init_count = Wire(knn_addr_nbits)   # INIT: table write address
  s.knn_count  = Wire(knn_addr_nbits)   # MAX: table read address
  s.vote_count = Wire(vote_addr_nbits)  # MIN: vote read address

  @s.tick
  def counter():
    if (s.init_go == 1):
      s.init_count.next = s.init_count + 1
    else:
      s.init_count.next = 0

    # While idle, preload the first table address of the requested digit.
    if (s.max_go == 1):
      s.knn_count.next = s.knn_count + 1
    else:
      s.knn_count.next = s.req_msg_digit * k

    if (s.min_go == 1):
      s.vote_count.next = s.vote_count + 1
    else:
      s.vote_count.next = 0

  # State Transition Logic
  @s.combinational
  def state_transitions():
    curr_state = s.state.out
    next_state = s.state.out

    # Transition out of IDLE state
    if (curr_state == s.STATE_IDLE):
      if ((s.req_val and s.req_rdy) and (s.req_msg_type == 1)):
        next_state = s.STATE_INIT
      elif ((s.req_val and s.req_rdy) and (s.req_msg_type == 0)
            and (s.FindMax_req_val and s.FindMax_req_rdy)):
        next_state = s.STATE_MAX
      elif ((s.req_val and s.req_rdy) and (s.req_msg_type == 2)
            and (s.FindMin_req_val and s.FindMin_req_rdy)):
        next_state = s.STATE_MIN

    # Transition out of INIT state
    if (curr_state == s.STATE_INIT):
      if (s.init_count == k * DIGIT - 1):
        next_state = s.STATE_DONE

    # Transition out of MAX state
    if (curr_state == s.STATE_MAX):
      if (s.FindMax_resp_val and s.FindMax_resp_rdy):
        next_state = s.STATE_DONE

    # Transition out of MIN state
    if (curr_state == s.STATE_MIN):
      if (s.FindMin_resp_val and s.FindMin_resp_rdy):
        next_state = s.STATE_DONE

    # Transition out of DONE state
    if (curr_state == s.STATE_DONE):
      if (s.resp_val and s.resp_rdy):
        next_state = s.STATE_IDLE

    s.state.in_.value = next_state

  # State Output Logic
  @s.combinational
  def state_outputs():
    current_state = s.state.out

    # IDLE state
    if current_state == s.STATE_IDLE:
      s.req_rdy.value             = 1
      s.resp_val.value            = 0
      s.msg_data_reg_en.value     = 1
      s.msg_digit_reg_en.value    = 1
      s.msg_type_reg_en.value     = 1
      s.msg_idx_reg_en.value      = 0
      s.init_go.value             = 0
      s.knn_wr_data_mux_sel.value = 0
      s.knn_wr_en.value           = 0
      s.knn_wr_addr.value         = 0

      if ((s.req_val and s.req_rdy) and (s.req_msg_type == 0)):
        s.max_go.value           = 1
        s.FindMax_req_val.value  = 1
        s.FindMax_resp_rdy.value = 0
        s.knn_rd_addr.value      = s.knn_count
        s.min_go.value           = 0
        s.FindMin_req_val.value  = 0
        s.FindMin_resp_rdy.value = 0
        s.vote_rd_addr.value     = s.req_msg_digit_q
      elif ((s.req_val and s.req_rdy) and (s.req_msg_type == 2)):
        s.max_go.value           = 0
        s.FindMax_req_val.value  = 0
        s.FindMax_resp_rdy.value = 0
        s.knn_rd_addr.value      = 0
        s.min_go.value           = 1
        s.FindMin_req_val.value  = 1
        s.FindMin_resp_rdy.value = 0
        s.vote_rd_addr.value     = s.vote_count
      else:
        s.max_go.value           = 0
        s.FindMax_req_val.value  = 0
        s.FindMax_resp_rdy.value = 0
        s.knn_rd_addr.value      = 0
        s.min_go.value           = 0
        s.FindMin_req_val.value  = 0
        s.FindMin_resp_rdy.value = 0
        s.vote_rd_addr.value     = s.req_msg_digit_q

      s.vote_wr_data_mux_sel.value = 1
      s.vote_wr_en.value           = 0
      s.vote_wr_addr.value         = s.req_msg_digit_q

    # INIT state
    elif current_state == s.STATE_INIT:
      s.req_rdy.value             = 0
      s.resp_val.value            = 0
      s.msg_data_reg_en.value     = 0
      s.msg_digit_reg_en.value    = 0
      s.msg_type_reg_en.value     = 0
      s.msg_idx_reg_en.value      = 0
      s.init_go.value             = 1
      s.max_go.value              = 0
      s.min_go.value              = 0
      s.FindMin_req_val.value     = 0
      s.FindMin_resp_rdy.value    = 0
      s.knn_wr_data_mux_sel.value = 0
      s.knn_wr_en.value           = 1
      s.knn_wr_addr.value         = s.init_count
      s.FindMax_req_val.value     = 0
      s.FindMax_resp_rdy.value    = 0

      # knn_rd for debugging
      if s.init_count == 0:
        s.knn_rd_addr.value = 0
      else:
        s.knn_rd_addr.value = s.init_count - 1

      if (s.init_count < DIGIT):
        s.vote_wr_data_mux_sel.value = 0
        s.vote_wr_en.value           = 1
        s.vote_wr_addr.value         = s.init_count
        # vote_rd for debugging
        if s.init_count == 0:
          s.vote_rd_addr.value = 0
        else:
          s.vote_rd_addr.value = s.init_count - 1
      else:
        s.vote_wr_data_mux_sel.value = 1
        s.vote_wr_en.value           = 0
        s.vote_wr_addr.value         = 0
        s.vote_rd_addr.value         = 0

    # MAX state
    elif current_state == s.STATE_MAX:
      s.req_rdy.value          = 0
      s.resp_val.value         = 0
      s.msg_data_reg_en.value  = 0
      s.msg_digit_reg_en.value = 0
      s.msg_type_reg_en.value  = 0
      s.msg_idx_reg_en.value   = 0
      s.init_go.value          = 0
      s.max_go.value           = 1
      s.min_go.value           = 0
      s.FindMin_req_val.value  = 0
      s.FindMin_resp_rdy.value = 0
      s.FindMax_req_val.value  = 1
      s.FindMax_resp_rdy.value = 1

      # All k entries of this digit (digit*k .. digit*k + k - 1) have been
      # streamed once knn_count moves past the last one.
      if (s.knn_count > s.req_msg_digit_q * k + k - 1):
        s.knn_rd_addr.value = 0
        s.knn_wr_addr.value = s.req_msg_digit_q * k + s.FindMax_resp_idx
        if (s.isSmaller == 1):
          s.knn_wr_data_mux_sel.value = 1
          s.knn_wr_en.value           = 1
          s.vote_wr_en.value          = 1
        else:
          s.knn_wr_data_mux_sel.value = 0
          s.knn_wr_en.value           = 0
          s.vote_wr_en.value          = 0
      else:
        s.knn_wr_data_mux_sel.value = 0
        s.knn_rd_addr.value         = s.knn_count
        s.knn_wr_en.value           = 0
        s.knn_wr_addr.value         = 0
        s.vote_wr_en.value          = 0

      s.vote_wr_data_mux_sel.value = 1
      s.vote_wr_addr.value         = s.req_msg_digit_q
      s.vote_rd_addr.value         = s.req_msg_digit_q

    # MIN state
    elif current_state == s.STATE_MIN:
      s.req_rdy.value              = 0
      s.resp_val.value             = 0
      s.msg_data_reg_en.value      = 0
      s.msg_digit_reg_en.value     = 0
      s.msg_type_reg_en.value      = 0
      s.init_go.value              = 0
      s.max_go.value               = 0
      s.FindMax_req_val.value      = 0
      s.FindMax_resp_rdy.value     = 0
      s.min_go.value               = 1
      s.FindMin_req_val.value      = 1
      s.FindMin_resp_rdy.value     = 1
      s.vote_wr_data_mux_sel.value = 1
      s.vote_wr_en.value           = 0
      s.vote_wr_addr.value         = 0

      if (s.vote_count > DIGIT - 1):
        s.msg_idx_reg_en.value = 1
        s.vote_rd_addr.value   = 0
      else:
        s.msg_idx_reg_en.value = 0
        s.vote_rd_addr.value   = s.vote_count

      s.knn_wr_en.value           = 0
      s.knn_wr_data_mux_sel.value = 1
      s.knn_wr_addr.value         = 0
      s.knn_rd_addr.value         = 0

    # DONE state
    elif current_state == s.STATE_DONE:
      s.req_rdy.value             = 0
      s.resp_val.value            = 1
      s.msg_data_reg_en.value     = 0
      s.msg_digit_reg_en.value    = 0
      s.msg_type_reg_en.value     = 0
      s.msg_idx_reg_en.value      = 0
      s.init_go.value             = 0
      s.max_go.value              = 0
      s.min_go.value              = 0
      s.FindMin_req_val.value     = 0
      s.FindMin_resp_rdy.value    = 0
      s.knn_wr_data_mux_sel.value = 0
      s.knn_wr_en.value           = 0
      s.knn_wr_addr.value         = 0
      s.FindMax_req_val.value     = 0
      s.FindMax_resp_rdy.value    = 0
      # NOTE(review): fixed read address 0x1b (27) looks like a debug
      # probe -- confirm it is intentional.
      s.knn_rd_addr.value          = 0x1b
      s.vote_wr_data_mux_sel.value = 1
      s.vote_wr_en.value           = 0
      s.vote_wr_addr.value         = s.req_msg_digit_q
      s.vote_rd_addr.value         = s.req_msg_digit_q

  # Register for resp msg type
  s.Reg_msg_type = m = RegEnRst(2)
  s.connect_dict({
    m.en:  s.msg_type_reg_en,
    m.in_: s.req_msg_type,
    m.out: s.resp_msg_type,
  })

  # Register for req msg digit
  s.Reg_msg_digit = m = RegEnRst(4)
  s.connect_dict({
    m.en:  s.msg_digit_reg_en,
    m.in_: s.req_msg_digit,
    m.out: s.req_msg_digit_q,
  })