def __init__(s, nbits=8, mbits=1):
    """Build the compare datapath: a reference register plus an equality check.

    Args:
        nbits: width of the request data and of the stored reference value.
        mbits: accepted for interface compatibility; not used by this
            datapath.

    The incoming request data is compared against the stored reference
    with an EqComparator. The 1-bit result is captured in a register that
    drives ``resp_msg_data``. Both registers are loaded under control of
    the enable input ports.
    """
    # ---- Ports -------------------------------------------------------------
    s.req_msg_data = InPort(nbits)
    s.resp_msg_data = OutPort(1)

    # Register enables driven by the control unit
    s.result_reg_en = InPort(1)
    s.reference_reg_en = InPort(1)

    # ---- Internal wires ----------------------------------------------------
    s.reference_data = Wire(Bits(nbits))
    s.cmp_result = Wire(Bits(1))

    # Reference register: captures req_msg_data when reference_reg_en is set
    s.Reg_reference_data = ref_reg = RegEnRst(nbits)
    s.connect_pairs(
        ref_reg.en, s.reference_reg_en,
        ref_reg.in_, s.req_msg_data,
        ref_reg.out, s.reference_data,
    )

    # Equality comparator: incoming data vs. stored reference
    s.EqComparator_1 = eq_cmp = EqComparator(nbits)
    s.connect_pairs(
        eq_cmp.in0, s.req_msg_data,
        eq_cmp.in1, s.reference_data,
        eq_cmp.out, s.cmp_result,
    )

    # Result register: holds the comparison outcome for the response
    s.Reg_cmp_result = res_reg = RegEnRst(1)
    s.connect_pairs(
        res_reg.en, s.result_reg_en,
        res_reg.in_, s.cmp_result,
        res_reg.out, s.resp_msg_data,
    )
예제 #2
0
    def __init__(s, nbits=6, k=3):
        """Build the running-maximum datapath with index capture.

        Args:
            nbits: width of each data value.
            k: number of candidates; index ports are ceil(log2(k)) bits wide.

        ``isLarger`` is the GtComparator result of the incoming data against
        the mux-selected operand. That operand is the incoming data itself
        when ``knn_mux_sel`` is 0, or the stored maximum when it is 1.
        ``max_reg`` loads the incoming data and ``idx_reg`` loads
        ``knn_counter``, each under its own enable.
        """
        idx_nbits = int(math.ceil(math.log(k, 2)))

        # Interface
        s.req_msg_data = InPort(nbits)
        s.resp_msg_data = OutPort(nbits)
        s.resp_msg_idx = OutPort(idx_nbits)

        # dpath -> ctrl
        s.isLarger = OutPort(1)

        # ctrl -> dpath
        s.max_reg_en = InPort(1)
        s.idx_reg_en = InPort(1)
        s.knn_mux_sel = InPort(1)
        s.knn_counter = InPort(idx_nbits)  # max 3

        # Comparator operand 0: the incoming data
        s.knn_data0 = Wire(Bits(nbits))
        s.connect(s.req_msg_data, s.knn_data0)

        # Comparator operand 1: chosen by knn_mux
        s.knn_data1 = Wire(Bits(nbits))
        s.max_reg_out = Wire(Bits(nbits))

        s.knn_mux = sel_mux = Mux(nbits, 2)
        s.connect_pairs(
            sel_mux.sel, s.knn_mux_sel,
            sel_mux.in_[0], s.req_msg_data,
            sel_mux.in_[1], s.max_reg_out,
            sel_mux.out, s.knn_data1,
        )

        # Greater-than comparator: knn_data0 vs. knn_data1
        s.knn_GtComparator = gt_cmp = GtComparator(nbits)
        s.connect_pairs(
            gt_cmp.in0, s.knn_data0,
            gt_cmp.in1, s.knn_data1,
            gt_cmp.out, s.isLarger,
        )

        # Register holding the current maximum value
        s.max_reg = max_reg = RegEnRst(nbits)
        s.connect_pairs(
            max_reg.en, s.max_reg_en,
            max_reg.in_, s.knn_data0,
            max_reg.out, s.max_reg_out,
        )

        # Register holding the index loaded alongside the maximum
        s.idx_reg = idx_reg = RegEnRst(idx_nbits)
        s.connect_pairs(
            idx_reg.en, s.idx_reg_en,
            idx_reg.in_, s.knn_counter,
            idx_reg.out, s.resp_msg_idx,
        )

        s.connect(s.max_reg_out, s.resp_msg_data)
예제 #3
0
  def __init__( s, nreqs ):
    """Round-robin arbiter over ``nreqs`` request lines.

    ``priority_reg`` resets to 1, so bit 0 starts with priority. Whenever
    any grant is issued, it loads the grant vector rotated left by one
    position, which gives priority to the line just after the last winner.

    Arbitration runs a kill chain over a doubled (2*nreqs) copy of the
    requests. The priority pointer exists only in the lower copy, so
    requests below the pointer are reconsidered in the upper copy after the
    chain wraps around. Each output grant is the OR of its two copies.
    """
    width2 = nreqs * 2

    s.reqs   = InPort ( nreqs )
    s.grants = OutPort( nreqs )

    # Loads the priority register whenever any grant is issued
    s.priority_en = Wire( 1 )

    # Priority register: next priority = grants rotated left by one bit
    s.priority_reg = preg = RegEnRst( nreqs, reset_value = 1 )
    s.connect_pairs(
      preg.en,           s.priority_en,
      preg.in_[1:nreqs], s.grants[0:nreqs-1],
      preg.in_[0],       s.grants[nreqs-1],
    )

    # Doubled-width internal vectors (kills carries one extra chain bit)
    s.kills        = Wire( 2*nreqs + 1 )
    s.priority_int = Wire( 2*nreqs )
    s.reqs_int     = Wire( 2*nreqs )
    s.grants_int   = Wire( 2*nreqs )

    @s.combinational
    def comb():

      # The chain starts "killed": nothing is granted until the priority
      # position is reached.
      s.kills[0].value = 1

      # Priority pointer lives in the lower copy only; requests in both.
      s.priority_int[    0:nreqs ].value = s.priority_reg.out
      s.priority_int[nreqs:width2].value = 0
      s.reqs_int    [    0:nreqs ].value = s.reqs
      s.reqs_int    [nreqs:width2].value = s.reqs

      for i in range( width2 ):
        if s.priority_int[i].value:
          # Priority position: grant if requested. From here on the chain
          # is killed exactly when this position was granted.
          s.grants_int[i].value = s.reqs_int[i]
          s.kills[i+1].value    = s.grants_int[i]
        else:
          # Grant only if no earlier grant has killed this position. Once
          # killed, the chain stays killed.
          s.grants_int[i].value = ~s.kills[i] & s.reqs_int[i]
          s.kills[i+1].value    = s.kills[i] | s.grants_int[i]

      # Fold the two copies back onto the nreqs output lines
      for i in range( nreqs ):
        s.grants[i].value = s.grants_int[i] | s.grants_int[nreqs+i]

      # Advance the priority pointer whenever something was granted
      s.priority_en.value = ( s.grants != 0 )
예제 #4
0
    def __init__(s, num_cores=1):
        """Build the F/D/X/M/W pipelined processor datapath.

        Args:
            num_cores: constant driven into input 1 of the CSRR select mux
                (``csrr_sel_mux_D``).

        The control unit drives the ``*_F`` .. ``*_W`` input ports (register
        enables and mux selects). It reads back the status output ports
        (``inst_D``, branch-condition flags and multiplier handshakes).
        The operand muxes in this block select only from register-file,
        PC, immediate and CSR sources; there are no bypass paths here.
        """

        #---------------------------------------------------------------------
        # Interface
        #---------------------------------------------------------------------

        # Parameters

        s.core_id = InPort(32)

        # imem ports

        s.imemreq_msg = OutPort(MemReqMsg4B)
        s.imemresp_msg_data = InPort(32)

        # dmem ports

        s.dmemreq_msg_addr = OutPort(32)
        s.dmemreq_msg_data = OutPort(32)
        s.dmemresp_msg_data = InPort(32)

        # mngr ports

        s.mngr2proc_data = InPort(32)
        s.proc2mngr_data = OutPort(32)

        # Control signals (ctrl->dpath)

        s.reg_en_F = InPort(1)
        s.pc_sel_F = InPort(2)

        s.reg_en_D = InPort(1)
        s.op1_sel_D = InPort(1)
        s.op2_sel_D = InPort(2)
        s.csrr_sel_D = InPort(2)
        s.imm_type_D = InPort(3)
        s.imul_req_val_D = InPort(1)

        s.reg_en_X = InPort(1)
        s.alu_fn_X = InPort(4)
        s.ex_result_sel_X = InPort(2)
        s.imul_resp_rdy_X = InPort(1)

        s.reg_en_M = InPort(1)
        s.wb_result_sel_M = InPort(1)

        s.reg_en_W = InPort(1)
        s.rf_waddr_W = InPort(5)
        s.rf_wen_W = InPort(1)
        s.stats_en_wen_W = InPort(1)

        # Status signals (dpath->Ctrl)

        s.inst_D = OutPort(32)
        s.imul_req_rdy_D = OutPort(1)
        s.br_cond_eq_X = OutPort(1)
        s.br_cond_ltu_X = OutPort(1)
        s.br_cond_lt_X = OutPort(1)
        s.imul_resp_val_X = OutPort(1)

        # stats_en output

        s.stats_en = OutPort(1)

        #---------------------------------------------------------------------
        # F stage
        #---------------------------------------------------------------------

        s.pc_F = Wire(32)
        s.pc_plus4_F = Wire(32)

        # PC+4 incrementer

        s.pc_incr_F = m = Incrementer(nbits=32, increment_amount=4)
        s.connect_pairs(m.in_, s.pc_F, m.out, s.pc_plus4_F)

        # Forward declarations for the branch, jal and jalr targets. They are
        # produced by later-stage logic below but consumed by the PC select
        # mux here.

        s.br_target_X = Wire(32)
        s.jal_target_D = Wire(32)
        s.jalr_target_X = Wire(32)
        # PC sel mux
        # pc_sel_F: 0 -> PC+4, 1 -> br_target_X, 2 -> jal_target_D,
        #           3 -> jalr_target_X

        s.pc_sel_mux_F = m = Mux(dtype=32, nports=4)
        s.connect_pairs(m.in_[0], s.pc_plus4_F, m.in_[1], s.br_target_X,
                        m.in_[2], s.jal_target_D, m.in_[3], s.jalr_target_X,
                        m.sel, s.pc_sel_F)

        # The imem request address is the PC select mux output, i.e. the
        # value being written into pc_reg_F, not the current pc_F.
        @s.combinational
        def imem_req_F():
            s.imemreq_msg.addr.value = s.pc_sel_mux_F.out

        # PC register
        # Reset to c_reset_vector - 4 so that, out of reset, the PC+4
        # incrementer (input 0 of pc_sel_mux_F) produces c_reset_vector.

        s.pc_reg_F = m = RegEnRst(dtype=32, reset_value=c_reset_vector - 4)
        s.connect_pairs(m.en, s.reg_en_F, m.in_, s.pc_sel_mux_F.out, m.out,
                        s.pc_F)

        #---------------------------------------------------------------------
        # D stage
        #---------------------------------------------------------------------

        # PC reg in D stage
        # This value is passed from the F stage for the corresponding
        # instruction to use, e.g. branch to (PC+imm). Its output is not
        # connected here; downstream logic reads s.pc_reg_D.out directly.

        s.pc_reg_D = m = RegEnRst(dtype=32)
        s.connect_pairs(
            m.en,
            s.reg_en_D,
            m.in_,
            s.pc_F,
        )

        # Instruction reg
        # Resets to c_reset_inst (presumably a nop encoding -- confirm).

        s.inst_D_reg = m = RegEnRst(dtype=32, reset_value=c_reset_inst)
        s.connect_pairs(
            m.en,
            s.reg_en_D,
            m.in_,
            s.imemresp_msg_data,
            m.out,
            s.inst_D  # to ctrl
        )

        # Register File
        # The rf_rdata_D wires, albeit redundant in some sense, are used to
        # remind people these data are from D stage.
        # Read addresses come from the RS1/RS2 fields of inst_D; writes
        # come from the W stage (rf_wen_W, rf_waddr_W, rf_wdata_W).

        s.rf_rdata0_D = Wire(32)
        s.rf_rdata1_D = Wire(32)

        s.rf_wdata_W = Wire(32)

        s.rf = m = RegisterFile(dtype=32,
                                nregs=32,
                                rd_ports=2,
                                const_zero=True)
        s.connect_pairs(m.rd_addr[0], s.inst_D[RS1], m.rd_addr[1],
                        s.inst_D[RS2], m.rd_data[0], s.rf_rdata0_D,
                        m.rd_data[1], s.rf_rdata1_D, m.wr_en, s.rf_wen_W,
                        m.wr_addr, s.rf_waddr_W, m.wr_data, s.rf_wdata_W)

        # Immediate generator

        s.imm_gen_D = m = ImmGenPRTL()
        s.connect_pairs(m.imm_type, s.imm_type_D, m.inst, s.inst_D)

        # csrr sel mux
        # csrr_sel_D: 0 -> mngr2proc_data, 1 -> num_cores (constant),
        #             2 -> core_id

        s.csrr_sel_mux_D = m = Mux(dtype=32, nports=3)
        s.connect_pairs(
            m.in_[0],
            s.mngr2proc_data,
            m.in_[1],
            num_cores,
            m.in_[2],
            s.core_id,
            m.sel,
            s.csrr_sel_D,
        )

        # op1 sel mux
        # op1_sel_D: 0 -> register-file read data 0, 1 -> PC of the D-stage
        #            instruction
        s.op1_sel_mux_D = m = Mux(dtype=32, nports=2)
        s.connect_pairs(
            m.in_[0],
            s.rf_rdata0_D,
            m.in_[1],
            s.pc_reg_D.out,
            m.sel,
            s.op1_sel_D,
        )

        # op2 sel mux
        # This mux chooses among RS2, imm, and the output of the above csrr
        # sel mux. Basically we are using two muxes here for pedagogy.
        # op2_sel_D: 0 -> register-file read data 1, 1 -> immediate,
        #            2 -> csrr sel mux output

        s.op2_sel_mux_D = m = Mux(dtype=32, nports=3)
        s.connect_pairs(
            m.in_[0],
            s.rf_rdata1_D,
            m.in_[1],
            s.imm_gen_D.imm,
            m.in_[2],
            s.csrr_sel_mux_D.out,
            m.sel,
            s.op2_sel_D,
        )

        # RISC-V computes branch/jal targets by adding the immediate
        # (generated above) to the PC. The sum drives jal_target_D directly
        # and is registered into the X stage as the branch target.

        s.pc_plus_imm_D = m = Adder(32)
        s.connect_pairs(
            m.in0,
            s.pc_reg_D.out,
            m.in1,
            s.imm_gen_D.imm,
            m.out,
            s.jal_target_D,
        )

        #---------------------------------------------------------------------
        # X stage
        #---------------------------------------------------------------------

        # br_target_reg_X
        # Since branches are resolved in X stage, we register the target,
        # which is already calculated in D stage, to X stage.

        s.br_target_reg_X = m = RegEnRst(dtype=32, reset_value=0)
        s.connect_pairs(m.en, s.reg_en_X, m.in_, s.pc_plus_imm_D.out, m.out,
                        s.br_target_X)

        # op1 reg (output read via s.op1_reg_X.out)

        s.op1_reg_X = m = RegEnRst(dtype=32, reset_value=0)
        s.connect_pairs(
            m.en,
            s.reg_en_X,
            m.in_,
            s.op1_sel_mux_D.out,
        )

        # op2 reg (output read via s.op2_reg_X.out)

        s.op2_reg_X = m = RegEnRst(dtype=32, reset_value=0)
        s.connect_pairs(
            m.en,
            s.reg_en_X,
            m.in_,
            s.op2_sel_mux_D.out,
        )

        # dmemreq data reg: carries register-file read data 1 (store data)
        # into the X stage

        s.dmem_write_data_reg_X = m = RegEnRst(dtype=32, reset_value=0)
        s.connect_pairs(
            m.en,
            s.reg_en_X,
            m.in_,
            s.rf_rdata1_D,
        )

        # pc reg

        s.pc_reg_X = m = RegEnRst(dtype=32, reset_value=0)
        s.connect_pairs(
            m.en,
            s.reg_en_X,
            m.in_,
            s.pc_reg_D.out,
        )

        # ALU
        # Besides the execute result, the ALU output also drives
        # jalr_target_X (below) and the dmem request address.

        s.alu_X = m = AluPRTL()
        s.connect_pairs(
            m.in0,
            s.op1_reg_X.out,
            m.in1,
            s.op2_reg_X.out,
            m.fn,
            s.alu_fn_X,
            m.ops_eq,
            s.br_cond_eq_X,
            m.ops_ltu,
            s.br_cond_ltu_X,
            m.ops_lt,
            s.br_cond_lt_X,
            m.out,
            s.jalr_target_X,
        )

        # Multiplier
        # Request operands are taken from the D-stage operand muxes (before
        # the X registers), so the request handshake is in D
        # (imul_req_*_D) and the response handshake is in X (imul_resp_*_X).

        s.imul_X = m = IntMulAltRTL()
        s.connect_pairs(
            m.req.msg[0:32],
            s.op1_sel_mux_D.out,
            m.req.msg[32:64],
            s.op2_sel_mux_D.out,
            m.req.val,
            s.imul_req_val_D,
            m.req.rdy,
            s.imul_req_rdy_D,
            m.resp.val,
            s.imul_resp_val_X,
            m.resp.rdy,
            s.imul_resp_rdy_X,
        )

        # PC+4 Incrementer

        s.pc_incr_X = m = Incrementer(nbits=32, increment_amount=4)
        s.connect_pairs(
            m.in_,
            s.pc_reg_X.out,
        )

        # ex result Mux
        # ex_result_sel_X: 0 -> PC+4 of the X-stage instruction (presumably
        #                  the jal/jalr link value), 1 -> ALU output,
        #                  2 -> multiplier response

        s.ex_result_sel_mux_X = m = Mux(dtype=32, nports=3)
        s.connect_pairs(
            m.in_[0],
            s.pc_incr_X.out,
            m.in_[1],
            s.alu_X.out,
            m.in_[2],
            s.imul_X.resp.msg,
            m.sel,
            s.ex_result_sel_X,
        )

        # dmemreq address (ALU output) and write data

        s.connect(s.dmemreq_msg_addr, s.alu_X.out)
        s.connect(s.dmemreq_msg_data, s.dmem_write_data_reg_X.out)
        #---------------------------------------------------------------------
        # M stage
        #---------------------------------------------------------------------

        # Execute result reg: holds the ex result mux output (PC+4, ALU or
        # multiplier result)

        s.ex_result_reg_M = m = RegEnRst(dtype=32, reset_value=0)
        s.connect_pairs(
            m.en,
            s.reg_en_M,
            m.in_,
            s.ex_result_sel_mux_X.out,
        )

        # Writeback result selection mux
        # wb_result_sel_M: 0 -> execute result, 1 -> dmem response data

        s.wb_result_sel_mux_M = m = Mux(dtype=32, nports=2)
        s.connect_pairs(m.in_[0], s.ex_result_reg_M.out, m.in_[1],
                        s.dmemresp_msg_data, m.sel, s.wb_result_sel_M)

        #---------------------------------------------------------------------
        # W stage
        #---------------------------------------------------------------------

        # Writeback result reg

        s.wb_result_reg_W = m = RegEnRst(dtype=32, reset_value=0)
        s.connect_pairs(
            m.en,
            s.reg_en_W,
            m.in_,
            s.wb_result_sel_mux_M.out,
        )

        # The writeback result goes both to the manager port and to the
        # register-file write data.

        s.connect(s.proc2mngr_data, s.wb_result_reg_W.out)

        s.connect(s.rf_wdata_W, s.wb_result_reg_W.out)

        s.stats_en_reg_W = m = RegEnRst(dtype=32, reset_value=0)

        # stats_en logic
        # stats_en_reg_W loads the writeback result when stats_en_wen_W is
        # set; stats_en is asserted when any bit of that register is set.

        s.connect_pairs(
            m.en,
            s.stats_en_wen_W,
            m.in_,
            s.wb_result_reg_W.out,
        )

        # NOTE(review): Python's any() over a Bits value relies on Bits being
        # iterable in simulation and is unlikely to translate to Verilog;
        # PyMTL's reduce_or() is the usual reduction -- confirm.
        @s.combinational
        def stats_en_logic_W():
            s.stats_en.value = any(
                s.stats_en_reg_W.out)  # reduction with bitwise OR
예제 #5
0
    def __init__(s):
        """Build the control logic for a single pipeline stage.

        Inputs: ``pvalid`` (valid bit loaded into this stage), ``nstall`` /
        ``nsquash`` and ``ostall`` / ``osquash`` (stall and squash sources;
        presumably "from the next stage" and "originated by this stage"
        respectively -- confirm against the pipeline).

        Outputs: ``nvalid`` (valid bit forwarded onward, forced to 0 when
        squashed or stalled), ``pipereg_en`` (enable for this stage's
        pipeline registers), ``pipereg_val`` (the stored valid bit),
        ``pipe_go`` (stage holds a valid transaction that is neither
        squashed nor stalled), and ``pstall`` / ``psquash`` (ORs of the
        stall and squash sources).
        """
        # ---- Inputs ----------------------------------------------------------
        s.pvalid = InPort(1)
        s.nstall = InPort(1)
        s.nsquash = InPort(1)
        s.ostall = InPort(1)
        s.osquash = InPort(1)

        # ---- Outputs ---------------------------------------------------------
        s.nvalid = OutPort(1)
        s.pstall = OutPort(1)
        s.psquash = OutPort(1)
        s.pipereg_en = OutPort(1)
        s.pipereg_val = OutPort(1)
        s.pipe_go = OutPort(1)

        # Valid-bit register for this stage. It loads pvalid, its enable is
        # driven by the combinational block below, and its output is exposed
        # directly as pipereg_val.
        s.val_reg = RegEnRst(1, reset_value=0)
        s.connect(s.val_reg.in_, s.pvalid)
        s.connect(s.val_reg.out, s.pipereg_val)

        @s.combinational
        def comb():

            # Squashed or stalled: forward a bubble and do not "go".
            # Otherwise forward the stored valid bit and go iff it is set.
            if s.nsquash.value or s.nstall.value or s.ostall.value:
                s.nvalid.value = 0
                s.pipe_go.value = 0
            else:
                s.nvalid.value = s.val_reg.out.value
                if s.val_reg.out.value:
                    s.pipe_go.value = 1
                else:
                    s.pipe_go.value = 0

            # Hold the pipeline registers only when stalled and not being
            # squashed; a squash always lets the registers load.
            if not s.nsquash.value and (s.nstall.value or s.ostall.value):
                s.pipereg_en.value = 0
                s.val_reg.en.value = 0
            else:
                s.pipereg_en.value = 1
                s.val_reg.en.value = 1

            # Combined stall and squash outputs
            s.pstall.value = s.nstall.value | s.ostall.value
            s.psquash.value = s.nsquash.value | s.osquash.value
    def __init__(s, mbits=1):
        """Build the two-state val/rdy control FSM for the compare datapath.

        Args:
            mbits: width of the request/response message-type field.

        The FSM waits in STATE_IDLE with ``req_rdy`` high. It moves to
        STATE_CMP when the request handshake fires, and returns to
        STATE_IDLE when the response handshake fires. The request's message
        type is captured in ``Reg_msg_type`` and returned on
        ``resp_msg_type``.
        """

        # Interface
        s.req_val = InPort(1)
        s.req_rdy = OutPort(1)

        s.resp_val = OutPort(1)
        s.resp_rdy = InPort(1)

        s.req_msg_type = InPort(mbits)
        s.resp_msg_type = OutPort(mbits)

        # ctrl -> dpath register enables
        s.result_reg_en = OutPort(1)
        s.reference_reg_en = OutPort(1)

        s.msg_type_reg_en = Wire(Bits(1))

        s.STATE_IDLE = 0
        s.STATE_CMP = 1  # do compare or store reference data

        s.state = RegRst(1, reset_value=s.STATE_IDLE)

        #------------------------------------------------------
        # state transition logic
        #------------------------------------------------------

        @s.combinational
        def state_transitions():

            curr_state = s.state.out
            next_state = s.state.out

            # IDLE -> CMP once a request is accepted
            if (curr_state == s.STATE_IDLE):
                if (s.req_val and s.req_rdy):
                    next_state = s.STATE_CMP

            # CMP -> IDLE once the response is taken
            if (curr_state == s.STATE_CMP):
                if (s.resp_val and s.resp_rdy):
                    next_state = s.STATE_IDLE

            s.state.in_.value = next_state

        #------------------------------------------------------
        # state output logic
        #------------------------------------------------------

        @s.combinational
        def state_outputs():

            current_state = s.state.out

            if (current_state == s.STATE_IDLE):
                s.req_rdy.value = 1
                s.resp_val.value = 0
                s.result_reg_en.value = 1
                # Message fields are don't-care while req_val is low. Gate
                # the reference-register enable with req_val so the stored
                # reference is not overwritten by idle-cycle garbage.
                # NOTE(review): for mbits > 1 this still assigns an
                # mbits-wide value to a 1-bit port -- confirm encoding.
                if s.req_val:
                    s.reference_reg_en.value = s.req_msg_type
                else:
                    s.reference_reg_en.value = 0
                s.msg_type_reg_en.value = 1

            elif (current_state == s.STATE_CMP):
                s.req_rdy.value = 0
                s.resp_val.value = 1
                s.result_reg_en.value = 0
                s.reference_reg_en.value = 0
                s.msg_type_reg_en.value = 0

        # Register for resp msg type. It is sized to mbits to match
        # req_msg_type/resp_msg_type; a 1-bit register mismatches them
        # whenever mbits > 1.
        s.Reg_msg_type = m = RegEnRst(mbits)
        s.connect_dict({
            m.en: s.msg_type_reg_en,
            m.in_: s.req_msg_type,
            m.out: s.resp_msg_type,
        })
예제 #7
0
    def __init__(s, mapper_num=3, nbits=6, k=3, rst_value=50):
        """Maintain a k-entry distance table and expose the sum of its entries.

        Args:
            mapper_num: number of distance inputs (``in_`` ports).
            nbits: width of each distance value.
            k: number of table entries (registers).
            rst_value: value written into every entry when ``rst`` is high;
                also the registers' reset value.

        Each cycle, when ``rst`` is low, the FindMin output (``min_dist``)
        is written into the entry at FindMaxIdx's ``max_idx``. The write
        happens only if the LtComparator reports ``min_dist < max_dist``.
        ``out`` is driven by an AdderTree over all k entries.
        """
        idx_nbits = int(math.ceil(math.log(k, 2)))
        total_nbits = int(math.ceil(math.log((2**nbits - 1) * k, 2)))

        # Interface
        s.in_ = [InPort(nbits) for _ in range(mapper_num)]
        s.out = OutPort(total_nbits)
        s.rst = InPort(1)

        # Internal wires
        s.min_dist = Wire(Bits(nbits))
        s.max_dist = Wire(Bits(nbits))
        s.max_idx = Wire(Bits(idx_nbits))
        s.isLess = Wire(Bits(1))

        s.rd_data = Wire[k](Bits(nbits))
        s.wr_data = Wire(Bits(nbits))
        s.wr_addr = Wire[k](Bits(idx_nbits))  # declared, not connected here
        s.wr_en = Wire[k](Bits(1))

        # FindMin over the incoming distances -> min_dist
        s.findmin = FindMin(nbits, mapper_num)
        for i in xrange(mapper_num):
            s.connect(s.findmin.in_[i], s.in_[i])
        s.connect(s.findmin.out, s.min_dist)

        # FindMaxIdx over the table entries -> max_dist / max_idx
        s.findmax = FindMaxIdx(nbits, k)
        for i in xrange(k):
            s.connect(s.findmax.in_[i], s.rd_data[i])
        s.connect_pairs(
            s.findmax.out, s.max_dist,
            s.findmax.idx, s.max_idx,
        )

        # isLess = min_dist < max_dist
        s.cmp = LtComparator(nbits)
        s.connect_pairs(
            s.cmp.in0, s.min_dist,
            s.cmp.in1, s.max_dist,
            s.cmp.out, s.isLess,
        )

        # Write enables: rst writes every entry; otherwise only the entry
        # at max_idx, and only when isLess.
        @s.combinational
        def comb_logic():
            for i in xrange(k):
                if (s.rst == 1):
                    s.wr_en[i].value = 1
                elif (i == s.max_idx):
                    s.wr_en[i].value = s.isLess
                else:
                    s.wr_en[i].value = 0

        # Write data: min_dist normally, rst_value while rst is high
        s.wr_data_mux = data_mux = Mux(nbits, 2)
        s.connect_pairs(
            data_mux.in_[0], s.min_dist,
            data_mux.in_[1], rst_value,
            data_mux.sel, s.rst,
            data_mux.out, s.wr_data,
        )

        # Table entries
        s.regs = [RegEnRst(nbits, rst_value) for _ in xrange(k)]
        for i, entry in enumerate(s.regs):
            s.connect_pairs(
                entry.in_, s.wr_data,
                entry.out, s.rd_data[i],
                entry.en, s.wr_en[i],
            )

        # AdderTree over all table entries drives out
        s.addertree = AdderTree(nbits, k)
        for i in xrange(k):
            s.connect(s.addertree.in_[i], s.rd_data[i])
        s.connect(s.addertree.out, s.out)
    def __init__(s, k=3):
        """Build the KNN datapath: a distance table, a vote table and two scanners.

        Args:
            k: entries per digit in ``knn_table`` (which has k * DIGIT
                entries of RE_DATA_SIZE bits).

        ``knn_vote`` holds DIGIT entries of SUM_DATA_SIZE bits. Its update
        value is ``vote_rd_data - FindMax result + registered request data``,
        computed by the subtractor/adder chain. FindMaxPRTL scans
        ``knn_table`` read data and FindMinPRTL scans ``knn_vote`` read
        data. FindMin's returned index is registered onto ``resp_msg_digit``.
        """
        SUM_DATA_SIZE = int(math.ceil(math.log(50 * k, 2)))

        # Address / index widths, computed once and reused below
        knn_addr_nbits = int(math.ceil(math.log(k * DIGIT, 2)))  # max 30
        vote_addr_nbits = int(math.ceil(math.log(DIGIT, 2)))  # max 10
        max_idx_nbits = int(math.ceil(math.log(k, 2)))  # max 3

        s.req_msg_data = InPort(RE_DATA_SIZE)
        s.resp_msg_digit = OutPort(4)

        # ctrl -> dpath
        s.knn_wr_data_mux_sel = InPort(1)
        s.knn_wr_addr = InPort(knn_addr_nbits)
        s.knn_rd_addr = InPort(knn_addr_nbits)
        s.knn_wr_en = InPort(1)

        s.vote_wr_data_mux_sel = InPort(1)
        s.vote_wr_addr = InPort(vote_addr_nbits)
        s.vote_rd_addr = InPort(vote_addr_nbits)
        s.vote_wr_en = InPort(1)

        s.FindMax_req_val = InPort(1)
        s.FindMax_resp_rdy = InPort(1)
        s.FindMin_req_val = InPort(1)
        s.FindMin_resp_rdy = InPort(1)

        s.msg_data_reg_en = InPort(1)
        s.msg_idx_reg_en = InPort(1)

        # dpath -> ctrl
        s.FindMax_req_rdy = OutPort(1)
        s.FindMax_resp_val = OutPort(1)
        s.FindMax_resp_idx = OutPort(max_idx_nbits)
        s.FindMin_req_rdy = OutPort(1)
        s.FindMin_resp_val = OutPort(1)
        s.isSmaller = OutPort(1)

        # Internal wires
        s.knn_rd_data = Wire(Bits(RE_DATA_SIZE))
        s.knn_wr_data = Wire(Bits(RE_DATA_SIZE))

        s.subtractor_out = Wire(Bits(SUM_DATA_SIZE))
        s.adder_out = Wire(Bits(SUM_DATA_SIZE))

        s.vote_rd_data = Wire(Bits(SUM_DATA_SIZE))
        s.vote_wr_data = Wire(Bits(SUM_DATA_SIZE))

        s.FindMax_req_data = Wire(Bits(RE_DATA_SIZE))
        s.FindMax_resp_data = Wire(Bits(RE_DATA_SIZE))
        s.FindMin_req_data = Wire(Bits(SUM_DATA_SIZE))
        s.FindMin_resp_data = Wire(Bits(SUM_DATA_SIZE))
        s.FindMin_resp_idx = Wire(Bits(vote_addr_nbits))  # max 10

        # Request data register
        s.req_msg_data_q = Wire(Bits(RE_DATA_SIZE))

        s.req_msg_data_reg = data_reg = RegEnRst(RE_DATA_SIZE)
        s.connect_pairs(
            data_reg.en, s.msg_data_reg_en,
            data_reg.in_, s.req_msg_data,
            data_reg.out, s.req_msg_data_q,
        )

        # knn_table write data: 0 -> constant 50, 1 -> registered request data
        s.knn_wr_data_mux = knn_mux = Mux(RE_DATA_SIZE, 2)
        s.connect_pairs(
            knn_mux.sel, s.knn_wr_data_mux_sel,
            knn_mux.in_[0], 50,
            knn_mux.in_[1], s.req_msg_data_q,
            knn_mux.out, s.knn_wr_data,
        )

        # knn_table register file (k * DIGIT entries)
        s.knn_table = table = RegisterFile(dtype=Bits(RE_DATA_SIZE),
                                           nregs=k * DIGIT,
                                           rd_ports=1,
                                           wr_ports=1,
                                           const_zero=False)
        s.connect_pairs(
            table.rd_addr[0], s.knn_rd_addr,
            table.rd_data[0], s.knn_rd_data,
            table.wr_addr, s.knn_wr_addr,
            table.wr_data, s.knn_wr_data,
            table.wr_en, s.knn_wr_en,
        )

        # knn_vote write data: 0 -> constant 50 * k, 1 -> adder output
        s.vote_wr_data_mux = vote_mux = Mux(SUM_DATA_SIZE, 2)
        s.connect_pairs(
            vote_mux.sel, s.vote_wr_data_mux_sel,
            vote_mux.in_[0], 50 * k,
            vote_mux.in_[1], s.adder_out,
            vote_mux.out, s.vote_wr_data,
        )

        # knn_vote register file (DIGIT entries)
        s.knn_vote = votes = RegisterFile(dtype=Bits(SUM_DATA_SIZE),
                                          nregs=DIGIT,
                                          rd_ports=1,
                                          wr_ports=1,
                                          const_zero=False)
        s.connect_pairs(
            votes.rd_addr[0], s.vote_rd_addr,
            votes.rd_data[0], s.vote_rd_data,
            votes.wr_addr, s.vote_wr_addr,
            votes.wr_data, s.vote_wr_data,
            votes.wr_en, s.vote_wr_en,
        )

        # FindMax is fed from the knn_table read port
        s.connect_wire(s.knn_rd_data, s.FindMax_req_data)

        s.findmax = fmax = FindMaxPRTL(RE_DATA_SIZE, k)
        s.connect_pairs(
            fmax.req.val, s.FindMax_req_val,
            fmax.req.rdy, s.FindMax_req_rdy,
            fmax.req.msg.data, s.FindMax_req_data,
            fmax.resp.val, s.FindMax_resp_val,
            fmax.resp.rdy, s.FindMax_resp_rdy,
            fmax.resp.msg.data, s.FindMax_resp_data,
            fmax.resp.msg.idx, s.FindMax_resp_idx,
        )

        # isSmaller: registered request data vs. FindMax result
        s.knn_LtComparator = lt_cmp = LtComparator(RE_DATA_SIZE)
        s.connect_pairs(
            lt_cmp.in0, s.req_msg_data_q,
            lt_cmp.in1, s.FindMax_resp_data,
            lt_cmp.out, s.isSmaller,
        )

        # Zero-extend the FindMax result to the vote width
        s.FindMax_resp_data_zext = Wire(Bits(SUM_DATA_SIZE))
        s.FindMax_resp_data_zexter = max_zext = ZeroExtender(RE_DATA_SIZE,
                                                             SUM_DATA_SIZE)
        s.connect_pairs(
            max_zext.in_, s.FindMax_resp_data,
            max_zext.out, s.FindMax_resp_data_zext,
        )

        # subtractor_out = vote_rd_data - zero-extended FindMax result
        s.subtractor = sub = Subtractor(SUM_DATA_SIZE)
        s.connect_pairs(
            sub.in0, s.vote_rd_data,
            sub.in1, s.FindMax_resp_data_zext,
            sub.out, s.subtractor_out,
        )

        # Zero-extend the registered request data to the vote width
        s.req_msg_data_zext = Wire(Bits(SUM_DATA_SIZE))
        s.req_msg_data_zexter = req_zext = ZeroExtender(RE_DATA_SIZE,
                                                        SUM_DATA_SIZE)
        s.connect_pairs(
            req_zext.in_, s.req_msg_data_q,
            req_zext.out, s.req_msg_data_zext,
        )

        # adder_out = subtractor_out + zero-extended request data (cin = 0)
        s.adder = add = Adder(SUM_DATA_SIZE)
        s.connect_pairs(
            add.in0, s.subtractor_out,
            add.in1, s.req_msg_data_zext,
            add.cin, 0,
            add.out, s.adder_out,
        )

        # FindMin is fed from the knn_vote read port and returns a digit
        s.connect_wire(s.vote_rd_data, s.FindMin_req_data)

        s.findmin = fmin = FindMinPRTL(SUM_DATA_SIZE, DIGIT)
        s.connect_pairs(
            fmin.req.val, s.FindMin_req_val,
            fmin.req.rdy, s.FindMin_req_rdy,
            fmin.req.msg.data, s.FindMin_req_data,
            fmin.resp.val, s.FindMin_resp_val,
            fmin.resp.rdy, s.FindMin_resp_rdy,
            fmin.resp.msg.data, s.FindMin_resp_data,
            fmin.resp.msg.digit, s.FindMin_resp_idx,
        )

        # Response digit register
        s.resp_msg_idx_q = Wire(Bits(vote_addr_nbits))

        s.req_msg_idx_reg = idx_reg = RegEnRst(vote_addr_nbits)
        s.connect_pairs(
            idx_reg.en, s.msg_idx_reg_en,
            idx_reg.in_, s.FindMin_resp_idx,
            idx_reg.out, s.resp_msg_idx_q,
        )

        # Drive the response digit from the registered index
        s.connect(s.resp_msg_idx_q, s.resp_msg_digit)
    def __init__(s, k=3):

        # ctrl to toplevel
        s.req_val = InPort(1)
        s.req_rdy = OutPort(1)

        s.resp_val = OutPort(1)
        s.resp_rdy = InPort(1)

        s.req_msg_digit = InPort(4)

        s.req_msg_type = InPort(2)
        s.resp_msg_type = OutPort(2)

        # ctrl->dpath
        s.knn_wr_data_mux_sel = OutPort(1)
        s.knn_wr_addr = OutPort(int(math.ceil(math.log(k * DIGIT,
                                                       2))))  # max 30
        s.knn_rd_addr = OutPort(int(math.ceil(math.log(k * DIGIT,
                                                       2))))  # max 30
        s.knn_wr_en = OutPort(1)

        s.vote_wr_data_mux_sel = OutPort(1)
        s.vote_wr_addr = OutPort(int(math.ceil(math.log(DIGIT, 2))))  # max 10
        s.vote_rd_addr = OutPort(int(math.ceil(math.log(DIGIT, 2))))  # max 10
        s.vote_wr_en = OutPort(1)

        s.FindMax_req_val = OutPort(1)
        s.FindMax_resp_rdy = OutPort(1)
        s.FindMin_req_val = OutPort(1)
        s.FindMin_resp_rdy = OutPort(1)

        s.msg_data_reg_en = OutPort(1)
        s.msg_idx_reg_en = OutPort(1)

        # dpath->ctrl
        s.FindMax_req_rdy = InPort(1)
        s.FindMax_resp_val = InPort(1)
        s.FindMax_resp_idx = InPort(int(math.ceil(math.log(k, 2))))  # max 10
        s.FindMin_req_rdy = InPort(1)
        s.FindMin_resp_val = InPort(1)
        s.isSmaller = InPort(1)

        s.msg_type_reg_en = Wire(Bits(1))
        s.msg_digit_reg_en = Wire(Bits(1))
        s.req_msg_digit_q = Wire(Bits(4))
        s.init_go = Wire(Bits(1))
        s.max_go = Wire(Bits(1))
        s.min_go = Wire(Bits(1))

        # State element
        s.STATE_IDLE = 0
        s.STATE_INIT = 1
        s.STATE_MAX = 2
        s.STATE_MIN = 3
        s.STATE_DONE = 4

        s.state = RegRst(3, reset_value=s.STATE_IDLE)

        # Counters
        s.init_count = Wire(int(math.ceil(math.log(k * DIGIT, 2))))  # max 30
        s.knn_count = Wire(int(math.ceil(math.log(k * DIGIT, 2))))  # max 30
        s.vote_count = Wire(int(math.ceil(math.log(DIGIT, 2))))  # max 10

        @s.tick
        def counter():
            if (s.init_go == 1):
                s.init_count.next = s.init_count + 1
            else:
                s.init_count.next = 0

            if (s.max_go == 1):
                s.knn_count.next = s.knn_count + 1
            else:
                s.knn_count.next = s.req_msg_digit * k

            if (s.min_go == 1):
                s.vote_count.next = s.vote_count + 1
            else:
                s.vote_count.next = 0

        # State Transition Logic
        @s.combinational
        def state_transitions():

            # Next-state logic for the control FSM. The FSM holds its current
            # state by default and leaves a state only when that state's exit
            # condition holds. At most one state matches, so the checks form
            # a single if/elif chain.

            cur = s.state.out
            nxt = s.state.out

            if cur == s.STATE_IDLE:
                # Dispatch on the request type once the request handshake
                # fires. Types 0 and 2 additionally need the FindMax/FindMin
                # request handshake.
                if (s.req_val and s.req_rdy) and s.req_msg_type == 1:
                    nxt = s.STATE_INIT
                elif ((s.req_val and s.req_rdy) and s.req_msg_type == 0
                      and (s.FindMax_req_val and s.FindMax_req_rdy)):
                    nxt = s.STATE_MAX
                elif ((s.req_val and s.req_rdy) and s.req_msg_type == 2
                      and (s.FindMin_req_val and s.FindMin_req_rdy)):
                    nxt = s.STATE_MIN

            elif cur == s.STATE_INIT:
                # Leave INIT once init_count reaches its last value.
                if s.init_count == k * DIGIT - 1:
                    nxt = s.STATE_DONE

            elif cur == s.STATE_MAX:
                # Leave MAX on the FindMax response handshake.
                if s.FindMax_resp_val and s.FindMax_resp_rdy:
                    nxt = s.STATE_DONE

            elif cur == s.STATE_MIN:
                # Leave MIN on the FindMin response handshake.
                if s.FindMin_resp_val and s.FindMin_resp_rdy:
                    nxt = s.STATE_DONE

            elif cur == s.STATE_DONE:
                # Return to IDLE once the response has been accepted.
                if s.resp_val and s.resp_rdy:
                    nxt = s.STATE_IDLE

            s.state.in_.value = nxt

        # State Output Logic
        @s.combinational
        def state_outputs():

            # Output logic for the control FSM. Every state branch assigns
            # every control output. A signal left unassigned in a combinational
            # block keeps its previous value (an unintended latch).

            current_state = s.state.out

            # IDLE state: ready for a request; capture the request fields.
            if current_state == s.STATE_IDLE:
                s.req_rdy.value = 1
                s.resp_val.value = 0

                s.msg_data_reg_en.value = 1
                s.msg_digit_reg_en.value = 1
                s.msg_type_reg_en.value = 1
                # Drive the idx-register enable explicitly (it was previously
                # misspelled, which left this output undriven in IDLE).
                s.msg_idx_reg_en.value = 0

                s.init_go.value = 0

                s.knn_wr_data_mux_sel.value = 0
                s.knn_wr_en.value = 0
                s.knn_wr_addr.value = 0
                if ((s.req_val and s.req_rdy) and (s.req_msg_type == 0)):
                    # Type-0 request: start the MAX phase in this same cycle.
                    s.max_go.value = 1
                    s.FindMax_req_val.value = 1
                    s.FindMax_resp_rdy.value = 0
                    s.knn_rd_addr.value = s.knn_count
                    s.min_go.value = 0
                    s.FindMin_req_val.value = 0
                    s.FindMin_resp_rdy.value = 0
                    s.vote_rd_addr.value = s.req_msg_digit_q
                elif ((s.req_val and s.req_rdy) and (s.req_msg_type == 2)):
                    # Type-2 request: start the MIN phase in this same cycle.
                    s.max_go.value = 0
                    s.FindMax_req_val.value = 0
                    s.FindMax_resp_rdy.value = 0
                    s.knn_rd_addr.value = 0
                    s.min_go.value = 1
                    s.FindMin_req_val.value = 1
                    s.FindMin_resp_rdy.value = 0
                    s.vote_rd_addr.value = s.vote_count
                else:
                    s.max_go.value = 0
                    s.FindMax_req_val.value = 0
                    s.FindMax_resp_rdy.value = 0
                    s.knn_rd_addr.value = 0
                    s.min_go.value = 0
                    s.FindMin_req_val.value = 0
                    s.FindMin_resp_rdy.value = 0
                    s.vote_rd_addr.value = s.req_msg_digit_q

                s.vote_wr_data_mux_sel.value = 1
                s.vote_wr_en.value = 0
                s.vote_wr_addr.value = s.req_msg_digit_q

            # INIT state: write every knn entry (and the first DIGIT vote
            # entries) at the address given by init_count.
            elif current_state == s.STATE_INIT:
                s.req_rdy.value = 0
                s.resp_val.value = 0

                s.msg_data_reg_en.value = 0
                s.msg_digit_reg_en.value = 0
                s.msg_type_reg_en.value = 0
                s.msg_idx_reg_en.value = 0

                s.init_go.value = 1
                s.max_go.value = 0

                s.min_go.value = 0
                s.FindMin_req_val.value = 0
                s.FindMin_resp_rdy.value = 0

                s.knn_wr_data_mux_sel.value = 0
                s.knn_wr_en.value = 1
                s.knn_wr_addr.value = s.init_count
                s.FindMax_req_val.value = 0
                s.FindMax_resp_rdy.value = 0
                # knn_rd for debugging: read back the previously written entry
                if s.init_count == 0:
                    s.knn_rd_addr.value = 0
                else:
                    s.knn_rd_addr.value = s.init_count - 1

                if (s.init_count < DIGIT):
                    s.vote_wr_data_mux_sel.value = 0
                    s.vote_wr_en.value = 1
                    s.vote_wr_addr.value = s.init_count
                    # vote_rd for debugging: read back the previous entry
                    if s.init_count == 0:
                        s.vote_rd_addr.value = 0
                    else:
                        s.vote_rd_addr.value = s.init_count - 1
                else:
                    s.vote_wr_data_mux_sel.value = 1
                    s.vote_wr_en.value = 0
                    s.vote_wr_addr.value = 0
                    s.vote_rd_addr.value = 0

            # MAX state: stream knn entries to FindMax, then conditionally
            # overwrite the entry at the reported index.
            elif current_state == s.STATE_MAX:
                s.req_rdy.value = 0
                s.resp_val.value = 0

                s.msg_data_reg_en.value = 0
                s.msg_digit_reg_en.value = 0
                s.msg_type_reg_en.value = 0
                s.msg_idx_reg_en.value = 0

                s.init_go.value = 0
                s.max_go.value = 1

                s.min_go.value = 0
                s.FindMin_req_val.value = 0
                s.FindMin_resp_rdy.value = 0

                s.FindMax_req_val.value = 1
                s.FindMax_resp_rdy.value = 1
                if (s.knn_count > s.req_msg_digit_q * k + 2):
                    # Past the k entries of this digit's group: write at the
                    # index reported by FindMax.
                    s.knn_rd_addr.value = 0
                    s.knn_wr_addr.value = s.req_msg_digit_q * k + s.FindMax_resp_idx
                    if (s.isSmaller == 1):
                        s.knn_wr_data_mux_sel.value = 1
                        s.knn_wr_en.value = 1
                        s.vote_wr_en.value = 1
                    else:
                        s.knn_wr_data_mux_sel.value = 0
                        s.knn_wr_en.value = 0
                        # No knn write this cycle, so no vote write either.
                        # Previously unassigned here, so a 1 from an earlier
                        # evaluation could persist.
                        s.vote_wr_en.value = 0
                else:
                    s.knn_wr_data_mux_sel.value = 0
                    s.knn_rd_addr.value = s.knn_count
                    s.knn_wr_en.value = 0
                    s.knn_wr_addr.value = 0
                    s.vote_wr_en.value = 0

                s.vote_wr_data_mux_sel.value = 1
                s.vote_wr_addr.value = s.req_msg_digit_q
                s.vote_rd_addr.value = s.req_msg_digit_q

            # MIN state: stream vote entries to FindMin; enable the idx
            # register once vote_count has passed the last entry.
            elif current_state == s.STATE_MIN:
                s.req_rdy.value = 0
                s.resp_val.value = 0

                s.msg_data_reg_en.value = 0
                s.msg_digit_reg_en.value = 0
                s.msg_type_reg_en.value = 0

                s.init_go.value = 0
                s.max_go.value = 0
                s.FindMax_req_val.value = 0
                s.FindMax_resp_rdy.value = 0

                s.min_go.value = 1
                s.FindMin_req_val.value = 1
                s.FindMin_resp_rdy.value = 1

                s.vote_wr_data_mux_sel.value = 1
                s.vote_wr_en.value = 0
                s.vote_wr_addr.value = 0
                if (s.vote_count > DIGIT - 1):
                    s.msg_idx_reg_en.value = 1
                    s.vote_rd_addr.value = 0
                else:
                    s.msg_idx_reg_en.value = 0
                    s.vote_rd_addr.value = s.vote_count

                s.knn_wr_en.value = 0
                s.knn_wr_data_mux_sel.value = 1
                s.knn_wr_addr.value = 0
                s.knn_rd_addr.value = 0

            # DONE state: hold the response valid until it is accepted.
            elif current_state == s.STATE_DONE:
                s.req_rdy.value = 0
                s.resp_val.value = 1

                s.msg_data_reg_en.value = 0
                s.msg_digit_reg_en.value = 0
                s.msg_type_reg_en.value = 0
                s.msg_idx_reg_en.value = 0

                s.init_go.value = 0
                s.max_go.value = 0

                s.min_go.value = 0
                s.FindMin_req_val.value = 0
                s.FindMin_resp_rdy.value = 0

                s.knn_wr_data_mux_sel.value = 0
                s.knn_wr_en.value = 0
                s.knn_wr_addr.value = 0
                s.FindMax_req_val.value = 0
                s.FindMax_resp_rdy.value = 0
                # NOTE(review): fixed read address 0x1b; its purpose is not
                # visible here (possibly a debug read) -- confirm.
                s.knn_rd_addr.value = 0x1b

                s.vote_wr_data_mux_sel.value = 1
                s.vote_wr_en.value = 0
                s.vote_wr_addr.value = s.req_msg_digit_q
                s.vote_rd_addr.value = s.req_msg_digit_q

        # Register for resp msg type
        s.Reg_msg_type = m = RegEnRst(2)
        s.connect_dict({
            m.en: s.msg_type_reg_en,
            m.in_: s.req_msg_type,
            m.out: s.resp_msg_type,
        })

        # Register for req msg digit
        s.Reg_msg_digit = m = RegEnRst(4)
        s.connect_dict({
            m.en: s.msg_digit_reg_en,
            m.in_: s.req_msg_digit,
            m.out: s.req_msg_digit_q,
        })