def __init__( s, dtype ):
    """Build the datapath of a single-entry bypass queue.

    An enable-gated register holds one buffered element. A 2:1 mux drives
    `deq_bits` with either that stored element (select 0) or the incoming
    `enq_bits` (select 1, the bypass path). Both the register enable and the
    mux select come from the control unit.

    Args:
      dtype: message type/width carried by enq_bits and deq_bits.
    """
    # Message ports
    s.enq_bits = InPort ( dtype )
    s.deq_bits = OutPort( dtype )

    # Control inputs (ctrl -> dpath)
    s.wen            = InPort( 1 )
    s.bypass_mux_sel = InPort( 1 )

    # Storage element: captures enq_bits whenever wen is asserted
    s.queue = RegEn( dtype )
    s.connect_pairs(
        s.queue.en,  s.wen,
        s.queue.in_, s.enq_bits,
    )

    # Output selection: 0 -> stored element, 1 -> bypass enq_bits
    s.bypass_mux = Mux( dtype, 2 )
    s.connect_dict({
        s.bypass_mux.sel    : s.bypass_mux_sel,
        s.bypass_mux.in_[0] : s.queue.out,
        s.bypass_mux.in_[1] : s.enq_bits,
        s.bypass_mux.out    : s.deq_bits,
    })
def __init__(s, Type):
    """Build a single-entry bypass queue using enq/deq RTL interfaces.

    While the buffer is empty, the incoming message is forwarded straight to
    deq through `byp_mux`. A message is latched into `buf`, and `full` is
    set, only when it is enqueued in a cycle where deq.en is low.

    Args:
      Type: message type carried on enq.msg / deq.msg.
    """
    s.enq = EnqIfcRTL(Type)
    s.deq = DeqIfcRTL(Type)

    # One-element storage and its 1-bit occupancy flag
    s.buf = RegEn(Type)(in_=s.enq.msg)
    s.full = Reg(Bits1)

    # Output mux, selected by the occupancy flag:
    #   full  (1) -> buf.out
    #   empty (0) -> enq.msg  (bypass)
    byp_inputs = {
        0: s.enq.msg,
        1: s.buf.out,
    }
    s.byp_mux = Mux(Type, 2)(out=s.deq.msg, in_=byp_inputs, sel=s.full.out)

    @s.update
    def up_bypq_set_enq_rdy():
        # New messages are accepted only while the buffer is empty
        s.enq.rdy = ~s.full.out

    @s.update
    def up_bypq_set_deq_rdy():
        # Data is available if it is buffered, or if it is being enqueued
        # this very cycle (bypass)
        s.deq.rdy = s.full.out | s.enq.en

    @s.update
    def up_bypq_full():
        # Store the incoming message only when the receiver does not take it
        # this cycle (deq.en low), even if deq.rdy was high
        s.buf.en = ~s.deq.en & s.enq.en
        # Stay/become full unless a dequeue happens
        s.full.in_ = ~s.deq.en & (s.enq.en | s.full.out)
def __init__(s):
    """Build a combinational step of what appears to be a shift-and-add multiplier.

    Per step:
      out_a      = in_a << 1
      out_b      = in_b >> 1
      out_result = in_result + in_a   if in_b[0] == 1
                   in_result          otherwise
      out_val    = in_val
    All data paths are 16 bits wide.
    """
    # Inputs
    s.in_val    = InPort(1)
    s.in_a      = InPort(16)
    s.in_b      = InPort(16)
    s.in_result = InPort(16)

    # Outputs
    s.out_val    = OutPort(1)
    s.out_a      = OutPort(16)
    s.out_b      = OutPort(16)
    s.out_result = OutPort(16)

    # out_b = in_b >> 1
    s.rshifter = RightLogicalShifter(16)
    s.connect_pairs(
        s.rshifter.in_,   s.in_b,
        s.rshifter.shamt, 1,
        s.rshifter.out,   s.out_b,
    )

    # out_a = in_a << 1
    s.lshifter = LeftLogicalShifter(16)
    s.connect_pairs(
        s.lshifter.in_,   s.in_a,
        s.lshifter.shamt, 1,
        s.lshifter.out,   s.out_a,
    )

    # Candidate accumulation: in_result + in_a
    s.add = Adder(16)
    s.connect_pairs(
        s.add.in0, s.in_a,
        s.add.in1, s.in_result,
    )

    # Pick the accumulated value only when the current multiplier bit is set
    s.result_mux = Mux(16, 2)
    s.connect_pairs(
        s.result_mux.sel,    s.in_b[0],
        s.result_mux.in_[0], s.in_result,
        s.result_mux.in_[1], s.add.out,
        s.result_mux.out,    s.out_result,
    )

    # Valid bit passes straight through
    s.connect(s.in_val, s.out_val)
def __init__(s, nbits=6, k=3):
    """Build the datapath that tracks a maximum value and the index it came from.

    The incoming value (`knn_data0`) is compared against `knn_data1`, which
    is either the incoming value itself (knn_mux_sel=0) or the stored
    maximum (knn_mux_sel=1). `isLarger` reports `knn_data0 > knn_data1` to
    the control unit. When enabled by the control unit, `max_reg` captures
    the incoming value and `idx_reg` captures `knn_counter`.

    Args:
      nbits: width of each data value.
      k:     number of candidates; index ports are ceil(log2(k)) bits wide.
    """
    idx_nbits = int(math.ceil(math.log(k, 2)))

    # Interface
    s.req_msg_data  = InPort(nbits)
    s.resp_msg_data = OutPort(nbits)
    s.resp_msg_idx  = OutPort(idx_nbits)

    # Status (dpath -> ctrl)
    s.isLarger = OutPort(1)

    # Control (ctrl -> dpath)
    s.max_reg_en  = InPort(1)
    s.idx_reg_en  = InPort(1)
    s.knn_mux_sel = InPort(1)
    s.knn_counter = InPort(idx_nbits)

    # Internal signals
    s.knn_data0   = Wire(Bits(nbits))
    s.knn_data1   = Wire(Bits(nbits))
    s.max_reg_out = Wire(Bits(nbits))

    s.connect(s.req_msg_data, s.knn_data0)

    # Second comparator operand: incoming value (0) or stored maximum (1)
    s.knn_mux = Mux(nbits, 2)
    s.connect_pairs(
        s.knn_mux.sel,    s.knn_mux_sel,
        s.knn_mux.in_[0], s.req_msg_data,
        s.knn_mux.in_[1], s.max_reg_out,
        s.knn_mux.out,    s.knn_data1,
    )

    # isLarger = knn_data0 > knn_data1
    s.knn_GtComparator = GtComparator(nbits)
    s.connect_pairs(
        s.knn_GtComparator.in0, s.knn_data0,
        s.knn_GtComparator.in1, s.knn_data1,
        s.knn_GtComparator.out, s.isLarger,
    )

    # Stored maximum value
    s.max_reg = RegEnRst(nbits)
    s.connect_pairs(
        s.max_reg.en,  s.max_reg_en,
        s.max_reg.in_, s.knn_data0,
        s.max_reg.out, s.max_reg_out,
    )

    # Index (counter value) associated with the stored maximum
    s.idx_reg = RegEnRst(idx_nbits)
    s.connect_pairs(
        s.idx_reg.en,  s.idx_reg_en,
        s.idx_reg.in_, s.knn_counter,
        s.idx_reg.out, s.resp_msg_idx,
    )

    s.connect(s.max_reg_out, s.resp_msg_data)
def __init__(s, Type):
    """Build a single-entry bypass queue with val/rdy interfaces.

    A message arriving while the queue is empty is forwarded to deq through
    `byp_mux`. It is written into `buf` only if it is enqueued while
    deq.rdy is low. `full` is a registered flag updated from `next_full` on
    each clock edge.

    Two variants of the combinational logic are generated. When `Type is
    int`, they use Python `not`; otherwise they use bitwise `~` on the
    Bits1 signals.

    Args:
      Type: message type; passing `int` selects the int-based variant and
        makes the `full` / `next_full` wires plain ints.
    """
    s.enq = InValRdyIfc(Type)
    s.deq = OutValRdyIfc(Type)
    # One-message storage; its enable is driven in up_bypq_internal
    s.buf = RegEn(Type)(in_=s.enq.msg)
    # Occupancy flag and its next-cycle value
    s.next_full = Wire(int if Type is int else Bits1)
    s.full = Wire(int if Type is int else Bits1)
    s.byp_mux = Mux(Type, 2)(
        out=s.deq.msg,
        in_={
            0: s.enq.msg,
            1: s.buf.out,
        },
        sel=s.full,  # full -- buf.out, empty -- bypass
    )

    # Register the occupancy flag on the clock edge
    @s.update_on_edge
    def up_full():
        s.full = s.next_full

    if Type is int:

        @s.update
        def up_bypq_set_enq_rdy():
            # Accept only while empty
            s.enq.rdy = not s.full

        @s.update
        def up_bypq_internal():
            # Capture an enqueued message that the receiver is not ready for
            s.buf.en = (not s.deq.rdy) & (s.enq.val & s.enq.rdy)
            # Become/stay full if data is offered on deq but not accepted
            s.next_full = (not s.deq.rdy) & s.deq.val
    else:

        @s.update
        def up_bypq_set_enq_rdy():
            s.enq.rdy = ~s.full

        @s.update
        def up_bypq_internal():
            s.buf.en = (~s.deq.rdy) & (s.enq.val & s.enq.rdy)
            s.next_full = (~s.deq.rdy) & s.deq.val

    # this enables the sender to make enq.val depend on enq.rdy
    # (deq.val is high if a message is buffered or one is arriving now)
    @s.update
    def up_bypq_set_deq_val():
        s.deq.val = s.full | s.enq.val
def __init__ ( s ):
    """Build a 32-bit accumulator datapath fed by a 1-bit input.

    Each enabled cycle the register loads:
        zext(req_msg_data) + (sel ? reg_out : 0)
    so sel=0 restarts the sum and sel=1 keeps accumulating. The register
    value is driven on resp_msg_data.
    """
    # Interface
    s.req_msg_data  = InPort (1)
    s.resp_msg_data = OutPort(32)
    s.sel = InPort(1)
    s.en  = InPort(1)

    # Internal wires (forward declared for the feedback path)
    s.reg_out   = Wire(32)
    s.adder_out = Wire(32)

    # Second adder operand: 0 (restart) or the running total
    s.mux = Mux(32, 2)
    s.connect_pairs(
        s.mux.sel,    s.sel,
        s.mux.in_[0], 0,
        s.mux.in_[1], s.reg_out,
    )

    # Accumulator register
    s.reg = RegEn(32)
    s.connect_pairs(
        s.reg.en,  s.en,
        s.reg.in_, s.adder_out,
        s.reg.out, s.reg_out,
    )

    # 1-bit input widened to 32 bits
    s.zext = ZeroExtender(1, 32)
    s.connect(s.zext.in_, s.req_msg_data)

    # Sum of the widened input bit and the selected operand
    s.add = Adder(32)
    s.connect_pairs(
        s.add.in0, s.zext.out,
        s.add.in1, s.mux.out,
        s.add.cin, 0,
        s.add.out, s.adder_out,
    )

    s.connect(s.reg_out, s.resp_msg_data)
def __init__(s, nreqs):
    """Build a round-robin arbiter whose request vector is gated by `en`.

    When `en` is 0, the arbiter sees an all-zero request vector. When `en`
    is 1, it sees `reqs` unchanged. Grants come directly from the arbiter.

    Args:
      nreqs: number of requesters (width of reqs and grants).
    """
    s.en     = InPort(1)
    s.reqs   = InPort(nreqs)
    s.grants = OutPort(nreqs)

    # Mux input indices, chosen by the `en` bit
    NO_ARB, ARB = 0, 1

    s.reqs_mux = Mux(nreqs, nports=2)
    s.connect_pairs(
        s.reqs_mux.sel,         s.en,
        s.reqs_mux.in_[NO_ARB], 0,
        s.reqs_mux.in_[ARB],    s.reqs,
    )

    s.rr_arbiter = RoundRobinArbiter(nreqs)
    s.connect_pairs(
        s.rr_arbiter.reqs,   s.reqs_mux.out,
        s.rr_arbiter.grants, s.grants,
    )
def __init__(s, nbits=6, k=3):
    """Build a circuit that finds the maximum of `k` inputs and its index.

    The circuit is a linear chain of k-1 2:1 muxes:
      - stage 0 chooses between in_[0] and in_[1];
      - stage i (i >= 1) chooses between the previous stage's winner and
        in_[i+1].
    A stage switches to its newer input only when its GtComparator reports
    that input greater than the current winner. `idx` is the input index of
    the final winner.

    Args:
      nbits: width of each input.
      k:     number of inputs (the chain uses muxs[0], so k >= 2).
    """
    idx_nbits = int(math.ceil(math.log(k, 2)))
    nstages = k - 1

    # Interface
    s.in_ = [InPort(nbits) for _ in range(k)]
    s.out = OutPort(nbits)
    s.idx = OutPort(idx_nbits)

    s.muxs = [Mux(nbits, 2) for _ in xrange(nstages)]
    s.cmps = [GtComparator(nbits) for _ in xrange(nstages)]

    # Stage 0: select in_[1] when in_[1] > in_[0]
    s.connect_wire(s.muxs[0].in_[0], s.in_[0])
    s.connect_wire(s.muxs[0].in_[1], s.in_[1])
    s.connect_wire(s.cmps[0].in0, s.in_[1])
    s.connect_wire(s.cmps[0].in1, s.in_[0])
    s.connect_wire(s.cmps[0].out, s.muxs[0].sel)

    # Stage i: select in_[i+1] when it beats the previous winner
    for i in xrange(1, nstages):
        prev, mux, cmpr = s.muxs[i - 1], s.muxs[i], s.cmps[i]
        s.connect_pairs(
            mux.in_[0], prev.out,
            mux.in_[1], s.in_[i + 1],
            mux.sel,    cmpr.out,
            cmpr.in0,   s.in_[i + 1],
            cmpr.in1,   prev.out,
        )

    @s.combinational
    def comb_logic():
        # The last stage that switched to its newer input determines the
        # winner; stage i switching means the winner is in_[i+1].
        s.idx.value = 0
        for i in range(k - 1):
            if s.muxs[i].sel == 1:
                s.idx.value = i + 1

    s.connect(s.muxs[-1].out, s.out)
def __init__(s, nbits=6, k=3):
    """Build a circuit that finds the minimum of `k` inputs.

    The circuit is a linear chain of k-1 2:1 muxes:
      - stage 0 chooses between in_[0] and in_[1];
      - stage i (i >= 1) chooses between the previous stage's result and
        in_[i+1].
    A stage takes its newer input only when its LtComparator reports that
    input less than the running minimum.

    Args:
      nbits: width of each input.
      k:     number of inputs (the chain uses muxs[0], so k >= 2).
    """
    nstages = k - 1

    # Interface
    s.in_ = [InPort(nbits) for _ in range(k)]
    s.out = OutPort(nbits)

    s.muxs = [Mux(nbits, 2) for _ in xrange(nstages)]
    s.cmps = [LtComparator(nbits) for _ in xrange(nstages)]

    # Stage 0: select in_[1] when in_[1] < in_[0]
    first_mux, first_cmp = s.muxs[0], s.cmps[0]
    s.connect_wire(first_mux.in_[0], s.in_[0])
    s.connect_wire(first_mux.in_[1], s.in_[1])
    s.connect_wire(first_cmp.in0, s.in_[1])
    s.connect_wire(first_cmp.in1, s.in_[0])
    s.connect_wire(first_cmp.out, first_mux.sel)

    # Stage i: select in_[i+1] when it is below the previous result
    for i in xrange(1, nstages):
        prev, mux, cmpr = s.muxs[i - 1], s.muxs[i], s.cmps[i]
        s.connect_pairs(
            mux.in_[0], prev.out,
            mux.in_[1], s.in_[i + 1],
            mux.sel,    cmpr.out,
            cmpr.in0,   s.in_[i + 1],
            cmpr.in1,   prev.out,
        )

    s.connect(s.muxs[-1].out, s.out)
def __init__(s):
    """Build a datapath that accumulates bit mismatches between a test vector
    and a training vector.

    Both vectors are held in RegEn registers. The shared `sel` input picks
    one bit position (0..48) from each through two 49:1 muxes, and an
    EqComparator compares the two selected bits. The inverted result
    (1 = mismatch), zero-extended, is added either to 0 (sel_out=0) or to
    the current output register value (sel_out=1). The sum is stored when
    en_out is high and driven on resp_msg_data. resp_msg_type is tied to 0,
    and req_msg_digit is passed through to resp_msg_digit.
    """
    # Interface
    s.req_msg_data = InPort(DATA_NBITS)
    s.resp_msg_data = OutPort(DISTANCE_NBITS)
    s.req_msg_type = InPort(TYPE_NBITS)
    s.resp_msg_type = OutPort(TYPE_NBITS)
    s.req_msg_digit = InPort(DIGIT_NBITS)
    s.resp_msg_digit = OutPort(DIGIT_NBITS)
    # Control signals (ctrl -> dpath)
    s.en_test = InPort(1)
    s.en_train = InPort(1)
    s.en_out = InPort(1)
    s.sel_out = InPort(1)
    s.sel = InPort(6)  # bit index shared by both 49:1 muxes below
    # Input Mux for Test Data
    # req_msg_type=1 passes req_msg_data; req_msg_type=0 feeds the mux
    # output back into itself.
    # NOTE(review): `in_test` is both the mux output and its in_[0], which
    # forms a combinational loop (a latch if synthesized). reg_test below
    # already holds state via its enable, so this mux looks redundant --
    # confirm against the control unit before simplifying. Also assumes
    # TYPE_NBITS == 1, since it is used directly as a 2:1 mux select.
    s.in_test = Wire(DATA_NBITS)
    s.mux_in_test = m = Mux(DATA_NBITS, 2)
    s.connect_dict({
        m.sel: s.req_msg_type,
        m.in_[0]: s.in_test,
        m.in_[1]: s.req_msg_data,
        m.out: s.in_test
    })
    # Input Mux for Train Data
    # Mirror image of the test mux: req_msg_type=0 passes req_msg_data,
    # req_msg_type=1 feeds the output back (same combinational-loop concern).
    s.in_train = Wire(DATA_NBITS)
    s.mux_in_train = m = Mux(DATA_NBITS, 2)
    s.connect_dict({
        m.sel: s.req_msg_type,
        m.in_[0]: s.req_msg_data,
        m.in_[1]: s.in_train,
        m.out: s.in_train
    })
    # Register for Test Data
    s.out_test = Wire(DATA_NBITS)
    s.reg_test = m = RegEn(DATA_NBITS)
    s.connect_dict({m.en: s.en_test, m.in_: s.in_test, m.out: s.out_test})
    # Register for Train Data
    s.out_train = Wire(DATA_NBITS)
    s.reg_train = m = RegEn(DATA_NBITS)
    s.connect_dict({
        m.en: s.en_train,
        m.in_: s.in_train,
        m.out: s.out_train
    })
    # 49-1 Mux for Test Data (bit i of out_test on input i; assumes
    # DATA_NBITS >= 49)
    s.data_test = Wire(1)
    s.mux_test = m = Mux(1, 49)
    for i in range(49):
        s.connect(m.in_[i], s.out_test[i])
    s.connect(m.sel, s.sel)
    s.connect(m.out, s.data_test)
    # 49-1 Mux for Train Data (same bit index as the test mux)
    s.data_train = Wire(1)
    s.mux_train = m = Mux(1, 49)
    for i in range(49):
        s.connect(m.in_[i], s.out_train[i])
    s.connect(m.sel, s.sel)
    s.connect(m.out, s.data_train)
    # Comparator: is_equal = (data_test == data_train)
    s.is_not_equal = Wire(1)
    s.is_equal = Wire(1)
    s.comp = m = EqComparator(1)
    s.connect_dict({
        m.in0: s.data_test,
        m.in1: s.data_train,
        m.out: s.is_equal
    })

    # Mismatch flag, the value that gets accumulated
    @s.combinational
    def not_value():
        s.is_not_equal.value = ~s.is_equal

    # Zero Extender: widen the 1-bit mismatch flag to the distance width
    s.zext = m = ZeroExtender(1, DISTANCE_NBITS)
    s.connect_dict({m.in_: s.is_not_equal})
    # Input Mux for Adder: 0 (restart) or running distance
    s.reg_out = Wire(DISTANCE_NBITS)
    s.mux_add = m = Mux(DISTANCE_NBITS, 2)
    s.connect_dict({m.sel: s.sel_out, m.in_[0]: 0, m.in_[1]: s.reg_out})
    # Adder
    s.add = m = Adder(DISTANCE_NBITS)
    s.connect_dict({m.in0: s.zext.out, m.in1: s.mux_add.out, m.cin: 0})
    # Output Register
    s.reg = m = RegEn(DISTANCE_NBITS)
    s.connect_dict({m.en: s.en_out, m.in_: s.add.out, m.out: s.reg_out})
    # Connect to output port
    s.connect(s.reg_out, s.resp_msg_data)
    s.connect(0, s.resp_msg_type)
    s.connect(s.req_msg_digit, s.resp_msg_digit)
def __init__(s, k=3):
    """Build the datapath for a kNN table/vote stage (naming suggests kNN;
    confirm).

    Structure as wired below:
      - req_msg_data is captured into req_msg_data_q (RegEnRst).
      - knn_table: RegisterFile of k*DIGIT entries of RE_DATA_SIZE bits.
        Its write data is the constant 50 (mux sel 0) or req_msg_data_q
        (sel 1).
      - knn_vote: RegisterFile of DIGIT entries of SUM_DATA_SIZE bits. Its
        write data is the constant 50*k (sel 0) or
            vote_rd_data - zext(FindMax_resp_data) + zext(req_msg_data_q)
        (sel 1).
      - FindMaxPRTL is fed knn_table read data.
        isSmaller = req_msg_data_q < FindMax_resp_data.
      - FindMinPRTL is fed knn_vote read data. Its digit output is
        registered and driven on resp_msg_digit.

    Args:
      k: neighbours per digit; sizes knn_table and the vote sum width.
    """
    # Width for sums up to 50 * k (the constant vote-table initial value)
    SUM_DATA_SIZE = int(math.ceil(math.log(50 * k, 2)))
    s.req_msg_data = InPort(RE_DATA_SIZE)
    # NOTE(review): fixed 4 bits, while resp_msg_idx_q (connected below) is
    # ceil(log2(DIGIT)) bits; these match only for 9 <= DIGIT <= 16.
    s.resp_msg_digit = OutPort(4)
    # ctrl->dpath
    s.knn_wr_data_mux_sel = InPort(1)
    s.knn_wr_addr = InPort(int(math.ceil(math.log(k * DIGIT, 2))))  # max 30
    s.knn_rd_addr = InPort(int(math.ceil(math.log(k * DIGIT, 2))))  # max 30
    s.knn_wr_en = InPort(1)
    s.vote_wr_data_mux_sel = InPort(1)
    s.vote_wr_addr = InPort(int(math.ceil(math.log(DIGIT, 2))))  # max 10
    s.vote_rd_addr = InPort(int(math.ceil(math.log(DIGIT, 2))))  # max 10
    s.vote_wr_en = InPort(1)
    s.FindMax_req_val = InPort(1)
    s.FindMax_resp_rdy = InPort(1)
    s.FindMin_req_val = InPort(1)
    s.FindMin_resp_rdy = InPort(1)
    s.msg_data_reg_en = InPort(1)
    s.msg_idx_reg_en = InPort(1)
    # dpath->ctrl
    s.FindMax_req_rdy = OutPort(1)
    s.FindMax_resp_val = OutPort(1)
    s.FindMax_resp_idx = OutPort(int(math.ceil(math.log(k, 2))))  # max 3
    s.FindMin_req_rdy = OutPort(1)
    s.FindMin_resp_val = OutPort(1)
    s.isSmaller = OutPort(1)
    # internal wires
    s.knn_rd_data = Wire(Bits(RE_DATA_SIZE))
    s.knn_wr_data = Wire(Bits(RE_DATA_SIZE))
    s.subtractor_out = Wire(Bits(SUM_DATA_SIZE))
    s.adder_out = Wire(Bits(SUM_DATA_SIZE))
    s.vote_rd_data = Wire(Bits(SUM_DATA_SIZE))
    s.vote_wr_data = Wire(Bits(SUM_DATA_SIZE))
    s.FindMax_req_data = Wire(Bits(RE_DATA_SIZE))
    s.FindMax_resp_data = Wire(Bits(RE_DATA_SIZE))
    s.FindMin_req_data = Wire(Bits(SUM_DATA_SIZE))
    s.FindMin_resp_data = Wire(Bits(SUM_DATA_SIZE))
    s.FindMin_resp_idx = Wire(Bits(int(math.ceil(math.log(DIGIT, 2)))))  # max 10
    # Req msg data Register
    s.req_msg_data_q = Wire(Bits(RE_DATA_SIZE))
    s.req_msg_data_reg = m = RegEnRst(RE_DATA_SIZE)
    s.connect_dict({
        m.en: s.msg_data_reg_en,
        m.in_: s.req_msg_data,
        m.out: s.req_msg_data_q
    })
    # knn_wr_data Mux: constant 50 (sel 0) or registered request data (sel 1)
    s.knn_wr_data_mux = m = Mux(RE_DATA_SIZE, 2)
    s.connect_dict({
        m.sel: s.knn_wr_data_mux_sel,
        m.in_[0]: 50,
        m.in_[1]: s.req_msg_data_q,
        m.out: s.knn_wr_data
    })
    # register file knn_table
    s.knn_table = m = RegisterFile(dtype=Bits(RE_DATA_SIZE), nregs=k * DIGIT,
                                   rd_ports=1, wr_ports=1, const_zero=False)
    s.connect_dict({
        m.rd_addr[0]: s.knn_rd_addr,
        m.rd_data[0]: s.knn_rd_data,
        m.wr_addr: s.knn_wr_addr,
        m.wr_data: s.knn_wr_data,
        m.wr_en: s.knn_wr_en
    })
    # vote_wr_data Mux: constant 50 * k (sel 0) or updated sum (sel 1)
    s.vote_wr_data_mux = m = Mux(SUM_DATA_SIZE, 2)
    s.connect_dict({
        m.sel: s.vote_wr_data_mux_sel,
        m.in_[0]: 50 * k,
        m.in_[1]: s.adder_out,
        m.out: s.vote_wr_data
    })
    # register file knn_vote
    s.knn_vote = m = RegisterFile(dtype=Bits(SUM_DATA_SIZE), nregs=DIGIT,
                                  rd_ports=1, wr_ports=1, const_zero=False)
    s.connect_dict({
        m.rd_addr[0]: s.vote_rd_addr,
        m.rd_data[0]: s.vote_rd_data,
        m.wr_addr: s.vote_wr_addr,
        m.wr_data: s.vote_wr_data,
        m.wr_en: s.vote_wr_en
    })
    # Find max value of knn_table for a given digit
    s.connect_wire(s.knn_rd_data, s.FindMax_req_data)
    s.findmax = m = FindMaxPRTL(RE_DATA_SIZE, k)
    s.connect_dict({
        m.req.val: s.FindMax_req_val,
        m.req.rdy: s.FindMax_req_rdy,
        m.req.msg.data: s.FindMax_req_data,
        m.resp.val: s.FindMax_resp_val,
        m.resp.rdy: s.FindMax_resp_rdy,
        m.resp.msg.data: s.FindMax_resp_data,
        m.resp.msg.idx: s.FindMax_resp_idx
    })
    # Less than comparator: isSmaller = req_msg_data_q < FindMax_resp_data
    s.knn_LtComparator = m = LtComparator(RE_DATA_SIZE)
    s.connect_dict({
        m.in0: s.req_msg_data_q,
        m.in1: s.FindMax_resp_data,
        m.out: s.isSmaller
    })
    # Zero extender: widen the max value to the sum width
    s.FindMax_resp_data_zext = Wire(Bits(SUM_DATA_SIZE))
    s.FindMax_resp_data_zexter = m = ZeroExtender(RE_DATA_SIZE, SUM_DATA_SIZE)
    s.connect_dict({
        m.in_: s.FindMax_resp_data,
        m.out: s.FindMax_resp_data_zext,
    })
    # Subtractor: vote_rd_data - zext(FindMax_resp_data)
    s.subtractor = m = Subtractor(SUM_DATA_SIZE)
    s.connect_dict({
        m.in0: s.vote_rd_data,
        m.in1: s.FindMax_resp_data_zext,
        m.out: s.subtractor_out
    })
    # Zero extender: widen the registered request data to the sum width
    s.req_msg_data_zext = Wire(Bits(SUM_DATA_SIZE))
    s.req_msg_data_zexter = m = ZeroExtender(RE_DATA_SIZE, SUM_DATA_SIZE)
    s.connect_dict({
        m.in_: s.req_msg_data_q,
        m.out: s.req_msg_data_zext,
    })
    # Adder: (vote_rd_data - max) + req_msg_data_q, written back via the
    # vote_wr_data mux
    s.adder = m = Adder(SUM_DATA_SIZE)
    s.connect_dict({
        m.in0: s.subtractor_out,
        m.in1: s.req_msg_data_zext,
        m.cin: 0,
        m.out: s.adder_out
    })
    # Find min value of knn_vote, return digit
    s.connect_wire(s.vote_rd_data, s.FindMin_req_data)
    s.findmin = m = FindMinPRTL(SUM_DATA_SIZE, DIGIT)
    s.connect_dict({
        m.req.val: s.FindMin_req_val,
        m.req.rdy: s.FindMin_req_rdy,
        m.req.msg.data: s.FindMin_req_data,
        m.resp.val: s.FindMin_resp_val,
        m.resp.rdy: s.FindMin_resp_rdy,
        m.resp.msg.data: s.FindMin_resp_data,
        m.resp.msg.digit: s.FindMin_resp_idx
    })
    # Resp idx Register
    s.resp_msg_idx_q = Wire(Bits(int(math.ceil(math.log(DIGIT, 2)))))
    s.req_msg_idx_reg = m = RegEnRst(int(math.ceil(math.log(DIGIT, 2))))
    s.connect_dict({
        m.en: s.msg_idx_reg_en,
        m.in_: s.FindMin_resp_idx,
        m.out: s.resp_msg_idx_q
    })
    # connect output idx
    s.connect(s.resp_msg_idx_q, s.resp_msg_digit)
def __init__(s, mapper_num=3, nbits=6, k=3, rst_value=50):
    """Build a k-entry table of the smallest distances seen, plus their sum.

    Each cycle, the smallest of the `mapper_num` inputs is compared with
    the largest entry in the table. If it is strictly smaller
    (LtComparator), it overwrites that entry. While `rst` is 1, every entry
    is written with `rst_value`. `out` is the AdderTree sum of all k
    entries.

    Args:
      mapper_num: number of distance inputs.
      nbits:      width of each distance.
      k:          number of table entries.
      rst_value:  value written into every entry while rst is high (also
                  the registers' reset value).
    """
    addr_nbits = int(math.ceil(math.log(k, 2)))
    sum_nbits = int(math.ceil(math.log((2**nbits - 1) * k, 2)))

    # Interface
    s.in_ = [InPort(nbits) for _ in range(mapper_num)]
    s.out = OutPort(sum_nbits)
    s.rst = InPort(1)

    # Internal wires
    s.min_dist = Wire(Bits(nbits))
    s.max_dist = Wire(Bits(nbits))
    s.max_idx = Wire(Bits(addr_nbits))
    s.isLess = Wire(Bits(1))
    s.rd_data = Wire[k](Bits(nbits))
    s.wr_data = Wire(Bits(nbits))
    s.wr_addr = Wire[k](Bits(addr_nbits))  # declared but not driven here
    s.wr_en = Wire[k](Bits(1))

    # Smallest incoming distance
    s.findmin = FindMin(nbits, mapper_num)
    for j in xrange(mapper_num):
        s.connect_pairs(s.findmin.in_[j], s.in_[j])
    s.connect_pairs(s.findmin.out, s.min_dist)

    # Largest table entry and where it lives
    s.findmax = FindMaxIdx(nbits, k)
    for j in xrange(k):
        s.connect_pairs(s.findmax.in_[j], s.rd_data[j])
    s.connect_pairs(
        s.findmax.out, s.max_dist,
        s.findmax.idx, s.max_idx,
    )

    # isLess = min_dist < max_dist
    s.cmp = LtComparator(nbits)
    s.connect_pairs(
        s.cmp.in0, s.min_dist,
        s.cmp.in1, s.max_dist,
        s.cmp.out, s.isLess,
    )

    # Write enables: everything on rst, else only the max slot when the new
    # minimum beats it
    @s.combinational
    def comb_logic():
        for i in xrange(k):
            if s.rst == 1:
                s.wr_en[i].value = 1
            elif i == s.max_idx:
                s.wr_en[i].value = s.isLess
            else:
                s.wr_en[i].value = 0

    # Write data: rst ? rst_value : min_dist
    s.wr_data_mux = Mux(nbits, 2)
    s.connect_pairs(
        s.wr_data_mux.in_[0], s.min_dist,
        s.wr_data_mux.in_[1], rst_value,
        s.wr_data_mux.sel,    s.rst,
        s.wr_data_mux.out,    s.wr_data,
    )

    # Table storage: one enabled register per entry
    s.regs = [RegEnRst(nbits, rst_value) for _ in xrange(k)]
    for j, reg in enumerate(s.regs):
        s.connect_pairs(
            reg.in_, s.wr_data,
            reg.out, s.rd_data[j],
            reg.en,  s.wr_en[j],
        )

    # Sum of all entries
    s.addertree = AdderTree(nbits, k)
    for j in xrange(k):
        s.connect_pairs(s.addertree.in_[j], s.rd_data[j])
    s.connect(s.addertree.out, s.out)
def __init__(s):
    """Build a shift-amount generator from the lowest set bit of `a`.

    The lowest set bit of `a` is isolated as a one-hot word:
        xor_out    = a ^ (a - 1)            (ones up to and including it)
        rshift_out = ~(a ^ (a - 1)) >> 1    (ones from it upward)
        and_out    = rshift_out & xor_out   (only that bit)
    `Encoder` turns and_out into a 6-bit value. `shamt` is driven from mux
    input MUX_SEL_ENCODER (the encoder output) or MUX_SEL_SKIP (constant
    32), with the select being the OR of all encoder output bits.
    """
    # Interface
    s.a = InPort(32)
    s.shamt = OutPort(6)

    # Internal wires
    s.sub_out = Wire(32)
    s.xor_out = Wire(32)
    s.xor_not_out = Wire(32)
    s.rshift_out = Wire(32)
    s.and_out = Wire(32)
    s.encoder_out = Wire(6)
    s.mux_sel = Wire(MUX_SEL_NBITS)

    # sub_out = a - 1
    s.sub = Subtractor(32)
    s.connect_dict({
        s.sub.in0: s.a,
        s.sub.in1: 1,
        s.sub.out: s.sub_out,
    })

    # rshift_out = xor_not_out >> 1
    s.rshift = RightLogicalShifter(32)
    s.connect_dict({
        s.rshift.in_:   s.xor_not_out,
        s.rshift.shamt: 1,
        s.rshift.out:   s.rshift_out,
    })

    # One-hot lowest set bit -> 6-bit encoding
    s.encoder = Encoder()
    s.connect_dict({
        s.encoder.in_: s.and_out,
        s.encoder.out: s.encoder_out,
    })

    # Final shift amount: encoder output or the constant 32
    s.mux = Mux(6, 2)
    s.connect_dict({
        s.mux.sel:                  s.mux_sel,
        s.mux.in_[MUX_SEL_ENCODER]: s.encoder_out,
        s.mux.in_[MUX_SEL_SKIP]:    32,
        s.mux.out:                  s.shamt,
    })

    @s.combinational
    def xor_block():
        s.xor_out.value = s.a ^ s.sub_out
        s.xor_not_out.value = ~(s.a ^ s.sub_out)

    @s.combinational
    def and_block():
        s.and_out.value = s.rshift_out & s.xor_out

    @s.combinational
    def mux_sel_block():
        # OR-reduction of the encoder output: 1 iff it is non-zero
        s.mux_sel.value = (s.encoder_out[0] | s.encoder_out[1] |
                           s.encoder_out[2] | s.encoder_out[3] |
                           s.encoder_out[4] | s.encoder_out[5])
def __init__(s):
    """Build what appears to be a subtract-and-swap GCD datapath (16-bit).

    The A register loads the new request operand, a - b, or b. The B
    register loads a or the new request operand. The control unit sees
    `is_b_zero` (b == 0) and `is_a_lt_b` (a < b). The subtractor output
    a - b is also the response.
    """
    # Data interface
    s.req_msg_a = InPort(16)
    s.req_msg_b = InPort(16)
    s.resp_msg = OutPort(16)

    # Control signals (ctrl -> dpath)
    s.a_mux_sel = InPort(A_MUX_SEL_NBITS)
    s.a_reg_en = InPort(1)
    s.b_mux_sel = InPort(B_MUX_SEL_NBITS)
    s.b_reg_en = InPort(1)

    # Status signals (dpath -> ctrl)
    s.is_b_zero = OutPort(1)
    s.is_a_lt_b = OutPort(1)

    # Feedback wires used before their drivers are created
    s.sub_out = Wire(16)
    s.b_reg_out = Wire(16)

    # Operand A: request, a - b, or b
    s.a_mux = Mux(16, 3)
    s.connect_pairs(
        s.a_mux.sel,                s.a_mux_sel,
        s.a_mux.in_[A_MUX_SEL_IN],  s.req_msg_a,
        s.a_mux.in_[A_MUX_SEL_SUB], s.sub_out,
        s.a_mux.in_[A_MUX_SEL_B],   s.b_reg_out,
    )
    s.a_reg = RegEn(16)
    s.connect_pairs(
        s.a_reg.en,  s.a_reg_en,
        s.a_reg.in_, s.a_mux.out,
    )

    # Operand B: a or request
    s.b_mux = Mux(16, 2)
    s.connect_pairs(
        s.b_mux.sel,               s.b_mux_sel,
        s.b_mux.in_[B_MUX_SEL_A],  s.a_reg.out,
        s.b_mux.in_[B_MUX_SEL_IN], s.req_msg_b,
    )
    s.b_reg = RegEn(16)
    s.connect_pairs(
        s.b_reg.en,  s.b_reg_en,
        s.b_reg.in_, s.b_mux.out,
        s.b_reg.out, s.b_reg_out,
    )

    # Status: b == 0
    s.b_zero = ZeroComparator(16)
    s.connect_pairs(
        s.b_zero.in_, s.b_reg.out,
        s.b_zero.out, s.is_b_zero,
    )

    # Status: a < b
    s.a_lt_b = LtComparator(16)
    s.connect_pairs(
        s.a_lt_b.in0, s.a_reg.out,
        s.a_lt_b.in1, s.b_reg.out,
        s.a_lt_b.out, s.is_a_lt_b,
    )

    # a - b, fed back to the A mux and driven out as the response
    s.sub = Subtractor(16)
    s.connect_pairs(
        s.sub.in0, s.a_reg.out,
        s.sub.in1, s.b_reg.out,
        s.sub.out, s.sub_out,
    )
    s.connect(s.sub.out, s.resp_msg)
def __init__(s): #================================================================== # Interfaces #================================================================== s.req_msg_a = InPort(32) s.req_msg_b = InPort(32) s.resp_msg = OutPort(32) # Control signals s.a_mux_sel = InPort(A_MUX_SEL_NBITS) s.b_mux_sel = InPort(B_MUX_SEL_NBITS) s.result_mux_sel = InPort(RES_MUX_SEL_NBITS) s.add_mux_sel = InPort(ADD_MUX_SEL_NBITS) s.result_en = InPort(1) s.result_sign = InPort(OUT_MUX_SEL_NBITS) # Status signals s.b_lsb = OutPort(1) s.a_msb = OutPort(1) s.b_msb = OutPort(1) s.to_ctrl_shamt = OutPort(6) # shamt input to shifters, calculated by ShamtGen s.shamt = Wire(6) # Binary representation of the multiplier s.bit_string = Wire(32) #================================================================== # Structure #================================================================== # A Mux s.in_a = Wire(32) # Take the absolute value of the input @s.combinational def sign_handling_a(): s.in_a.value = s.req_msg_a if ~s.req_msg_a[31] \ else (~s.req_msg_a) + 1 s.l_shift_out = Wire(32) s.a_mux = m = Mux(32, 2) s.connect_pairs( m.sel, s.a_mux_sel, m.in_[A_MUX_SEL_IN], s.in_a, m.in_[A_MUX_SEL_SHIFT], s.l_shift_out, ) # A Register s.a_reg = m = Reg(32) s.connect(m.in_, s.a_mux.out) # Left Shifter s.l_shift = m = LeftLogicalShifter(32, 6) s.connect_pairs( m.in_, s.a_reg.out, m.shamt, s.shamt, m.out, s.l_shift_out, ) # B Mux s.in_b = Wire(32) # Take the absolute value of the input @s.combinational def sign_handling_b(): s.in_b.value = s.req_msg_b if ~s.req_msg_b[31] \ else (~s.req_msg_b) + 1 s.r_shift_out = Wire(32) s.b_mux = m = Mux(32, 2) s.connect_pairs( m.sel, s.b_mux_sel, m.in_[B_MUX_SEL_IN], s.in_b, m.in_[B_MUX_SEL_SHIFT], s.r_shift_out, ) # B Register s.b_reg = m = Reg(32) s.connect(m.in_, s.b_mux.out) # Take the higher 31 bits and add 0 in the high order # The ShamtGen module will generate the appropriate shamt # according to the number of consecutive zeros in lower s.bit_string 
@s.combinational def bit_string_block(): s.bit_string.value = concat(Bits(1, 0), s.b_reg.out[1:32]) # Right Shifter s.r_shift = m = RightLogicalShifter(32, 6) s.connect_pairs( m.in_, s.b_reg.out, m.shamt, s.shamt, m.out, s.r_shift_out, ) # Result Mux s.add_mux_out = Wire(32) s.result_mux = m = Mux(32, 2) s.connect_pairs( m.sel, s.result_mux_sel, m.in_[RES_MUX_SEL_ZERO], 0, m.in_[RES_MUX_SEL_ADD], s.add_mux_out, ) # Result Register s.res_reg = m = RegEn(32) s.connect_pairs(m.in_, s.result_mux.out, m.en, s.result_en) # Adder s.adder = m = Adder(32) s.connect_pairs( m.in0, s.a_reg.out, m.in1, s.res_reg.out, m.cin, 0, ) # Add Mux s.add_mux = m = Mux(32, 2) s.connect_pairs( m.sel, s.add_mux_sel, m.in_[ADD_MUX_SEL_ADD], s.adder.out, m.in_[ADD_MUX_SEL_RES], s.res_reg.out, m.out, s.add_mux_out, ) # ShamtGen s.shamt_gen = m = ShamtGenPRTL() s.connect_pairs( m.a, s.bit_string, m.shamt, s.shamt, ) # Forward shamt to control unit so the counter can update # accordingly @s.combinational def to_ctrl_shamt_block(): s.to_ctrl_shamt.value = s.shamt # Output MUX s.res_neg = Wire(32) # Generate -res in case the result is negative @s.combinational def twos_compl_block(): s.res_neg.value = (~s.res_reg.out) + 1 s.out_mux = m = Mux(32, 2) s.connect_pairs( m.sel, s.result_sign, m.in_[OUT_MUX_SEL_POS], s.res_reg.out, m.in_[OUT_MUX_SEL_NEG], s.res_neg, m.out, s.resp_msg, ) # Connect status signals s.connect(s.b_reg.out[0], s.b_lsb) s.connect(s.req_msg_a[31], s.a_msb) s.connect(s.req_msg_b[31], s.b_msb)
def __init__(s, num_cores=1):
    """Build the datapath of a five-stage (F/D/X/M/W) pipelined processor.

    Signal names carry their stage suffix. Per stage:
      F: PC+4 incrementer, 4-way PC select mux, PC register. The imem
         request address is the PC mux output, i.e. the value pc_reg_F
         will latch.
      D: PC/instruction registers, 2-read-port register file, immediate
         generator, op1/op2/csrr select muxes, PC+imm adder (JAL target).
      X: operand registers, ALU (also produces branch conditions and the
         JALR target), iterative multiplier, PC+4 for link values, ex
         result mux.
      M: ex result register, writeback select (ex result vs dmem data).
      W: writeback register feeding the register file, proc2mngr, and the
         stats_en register.

    Args:
      num_cores: constant presented on input 1 of the csrr select mux.
    """
    #---------------------------------------------------------------------
    # Interface
    #---------------------------------------------------------------------
    # Parameters
    s.core_id = InPort(32)
    # imem ports
    s.imemreq_msg = OutPort(MemReqMsg4B)
    s.imemresp_msg_data = InPort(32)
    # dmem ports
    s.dmemreq_msg_addr = OutPort(32)
    s.dmemreq_msg_data = OutPort(32)
    s.dmemresp_msg_data = InPort(32)
    # mngr ports
    s.mngr2proc_data = InPort(32)
    s.proc2mngr_data = OutPort(32)
    # Control signals (ctrl->dpath)
    s.reg_en_F = InPort(1)
    s.pc_sel_F = InPort(2)
    s.reg_en_D = InPort(1)
    s.op1_sel_D = InPort(1)
    s.op2_sel_D = InPort(2)
    s.csrr_sel_D = InPort(2)
    s.imm_type_D = InPort(3)
    s.imul_req_val_D = InPort(1)
    s.reg_en_X = InPort(1)
    s.alu_fn_X = InPort(4)
    s.ex_result_sel_X = InPort(2)
    s.imul_resp_rdy_X = InPort(1)
    s.reg_en_M = InPort(1)
    s.wb_result_sel_M = InPort(1)
    s.reg_en_W = InPort(1)
    s.rf_waddr_W = InPort(5)
    s.rf_wen_W = InPort(1)
    s.stats_en_wen_W = InPort(1)
    # Status signals (dpath->Ctrl)
    s.inst_D = OutPort(32)
    s.imul_req_rdy_D = OutPort(1)
    s.br_cond_eq_X = OutPort(1)
    s.br_cond_ltu_X = OutPort(1)
    s.br_cond_lt_X = OutPort(1)
    s.imul_resp_val_X = OutPort(1)
    # stats_en output
    s.stats_en = OutPort(1)
    #---------------------------------------------------------------------
    # F stage
    #---------------------------------------------------------------------
    s.pc_F = Wire(32)
    s.pc_plus4_F = Wire(32)
    # PC+4 incrementer
    s.pc_incr_F = m = Incrementer(nbits=32, increment_amount=4)
    s.connect_pairs(m.in_, s.pc_F, m.out, s.pc_plus4_F)
    # forward declaration for branch target and jal target
    s.br_target_X = Wire(32)
    s.jal_target_D = Wire(32)
    s.jalr_target_X = Wire(32)
    # PC sel mux: 0 = PC+4, 1 = branch target, 2 = JAL target, 3 = JALR target
    s.pc_sel_mux_F = m = Mux(dtype=32, nports=4)
    s.connect_pairs(m.in_[0], s.pc_plus4_F,
                    m.in_[1], s.br_target_X,
                    m.in_[2], s.jal_target_D,
                    m.in_[3], s.jalr_target_X,
                    m.sel, s.pc_sel_F)

    # imem is addressed with the next PC (the PC mux output)
    @s.combinational
    def imem_req_F():
        s.imemreq_msg.addr.value = s.pc_sel_mux_F.out

    # PC register (reset value is c_reset_vector - 4)
    s.pc_reg_F = m = RegEnRst(dtype=32, reset_value=c_reset_vector - 4)
    s.connect_pairs(m.en, s.reg_en_F, m.in_, s.pc_sel_mux_F.out, m.out, s.pc_F)
    #---------------------------------------------------------------------
    # D stage
    #---------------------------------------------------------------------
    # PC reg in D stage
    # This value is basically passed from F stage for the corresponding
    # instruction to use, e.g. branch to (PC+imm)
    s.pc_reg_D = m = RegEnRst(dtype=32)
    s.connect_pairs(
        m.en, s.reg_en_D,
        m.in_, s.pc_F,
    )
    # Instruction reg
    s.inst_D_reg = m = RegEnRst(dtype=32, reset_value=c_reset_inst)
    s.connect_pairs(
        m.en, s.reg_en_D,
        m.in_, s.imemresp_msg_data,
        m.out, s.inst_D  # to ctrl
    )
    # Register File
    # The rf_rdata_D wires, albeit redundant in some sense, are used to
    # remind people these data are from D stage.
    s.rf_rdata0_D = Wire(32)
    s.rf_rdata1_D = Wire(32)
    s.rf_wdata_W = Wire(32)
    s.rf = m = RegisterFile(dtype=32, nregs=32, rd_ports=2, const_zero=True)
    s.connect_pairs(m.rd_addr[0], s.inst_D[RS1],
                    m.rd_addr[1], s.inst_D[RS2],
                    m.rd_data[0], s.rf_rdata0_D,
                    m.rd_data[1], s.rf_rdata1_D,
                    m.wr_en, s.rf_wen_W,
                    m.wr_addr, s.rf_waddr_W,
                    m.wr_data, s.rf_wdata_W)
    # Immediate generator
    s.imm_gen_D = m = ImmGenPRTL()
    s.connect_pairs(m.imm_type, s.imm_type_D, m.inst, s.inst_D)
    # csrr sel mux: 0 = mngr2proc data, 1 = num_cores constant, 2 = core_id
    s.csrr_sel_mux_D = m = Mux(dtype=32, nports=3)
    s.connect_pairs(
        m.in_[0], s.mngr2proc_data,
        m.in_[1], num_cores,
        m.in_[2], s.core_id,
        m.sel, s.csrr_sel_D,
    )
    # op1 sel mux: 0 = rs1 data, 1 = PC
    s.op1_sel_mux_D = m = Mux(dtype=32, nports=2)
    s.connect_pairs(
        m.in_[0], s.rf_rdata0_D,
        m.in_[1], s.pc_reg_D.out,
        m.sel, s.op1_sel_D,
    )
    # op2 sel mux
    # This mux chooses among RS2, imm, and the output of the above csrr
    # sel mux. Basically we are using two muxes here for pedagogy.
    s.op2_sel_mux_D = m = Mux(dtype=32, nports=3)
    s.connect_pairs(
        m.in_[0], s.rf_rdata1_D,
        m.in_[1], s.imm_gen_D.imm,
        m.in_[2], s.csrr_sel_mux_D.out,
        m.sel, s.op2_sel_D,
    )
    # Risc-V always calcs branch/jal target by adding imm(generated above) to PC
    s.pc_plus_imm_D = m = Adder(32)
    s.connect_pairs(
        m.in0, s.pc_reg_D.out,
        m.in1, s.imm_gen_D.imm,
        m.out, s.jal_target_D,
    )
    #---------------------------------------------------------------------
    # X stage
    #---------------------------------------------------------------------
    # br_target_reg_X
    # Since branches are resolved in X stage, we register the target,
    # which is already calculated in D stage, to X stage.
    s.br_target_reg_X = m = RegEnRst(dtype=32, reset_value=0)
    s.connect_pairs(m.en, s.reg_en_X,
                    m.in_, s.pc_plus_imm_D.out,
                    m.out, s.br_target_X)
    # op1 reg
    s.op1_reg_X = m = RegEnRst(dtype=32, reset_value=0)
    s.connect_pairs(
        m.en, s.reg_en_X,
        m.in_, s.op1_sel_mux_D.out,
    )
    # op2 reg
    s.op2_reg_X = m = RegEnRst(dtype=32, reset_value=0)
    s.connect_pairs(
        m.en, s.reg_en_X,
        m.in_, s.op2_sel_mux_D.out,
    )
    # dmemreq data reg (store data comes from rs2)
    s.dmem_write_data_reg_X = m = RegEnRst(dtype=32, reset_value=0)
    s.connect_pairs(
        m.en, s.reg_en_X,
        m.in_, s.rf_rdata1_D,
    )
    # pc reg
    s.pc_reg_X = m = RegEnRst(dtype=32, reset_value=0)
    s.connect_pairs(
        m.en, s.reg_en_X,
        m.in_, s.pc_reg_D.out,
    )
    # ALU
    # NOTE(review): the ALU output drives jalr_target_X directly. RISC-V
    # requires bit 0 of the JALR target to be cleared -- confirm AluPRTL
    # does this for the function selected on JALR.
    s.alu_X = m = AluPRTL()
    s.connect_pairs(
        m.in0, s.op1_reg_X.out,
        m.in1, s.op2_reg_X.out,
        m.fn, s.alu_fn_X,
        m.ops_eq, s.br_cond_eq_X,
        m.ops_ltu, s.br_cond_ltu_X,
        m.ops_lt, s.br_cond_lt_X,
        m.out, s.jalr_target_X,
    )
    # Multiplier: the request is issued from D-stage operands (bits [0:32]
    # = op1, [32:64] = op2); the response is consumed in X
    s.imul_X = m = IntMulAltRTL()
    s.connect_pairs(
        m.req.msg[0:32], s.op1_sel_mux_D.out,
        m.req.msg[32:64], s.op2_sel_mux_D.out,
        m.req.val, s.imul_req_val_D,
        m.req.rdy, s.imul_req_rdy_D,
        m.resp.val, s.imul_resp_val_X,
        m.resp.rdy, s.imul_resp_rdy_X,
    )
    # PC+4 Incrementer (link value for jumps)
    s.pc_incr_X = m = Incrementer(nbits=32, increment_amount=4)
    s.connect_pairs(
        m.in_, s.pc_reg_X.out,
    )
    # ex result Mux: 0 = PC+4, 1 = ALU, 2 = multiplier response
    s.ex_result_sel_mux_X = m = Mux(dtype=32, nports=3)
    s.connect_pairs(
        m.in_[0], s.pc_incr_X.out,
        m.in_[1], s.alu_X.out,
        m.in_[2], s.imul_X.resp.msg,
        m.sel, s.ex_result_sel_X,
    )
    # dmemreq address
    s.connect(s.dmemreq_msg_addr, s.alu_X.out)
    s.connect(s.dmemreq_msg_data, s.dmem_write_data_reg_X.out)
    #---------------------------------------------------------------------
    # M stage
    #---------------------------------------------------------------------
    # Alu execution result reg
    s.ex_result_reg_M = m = RegEnRst(dtype=32, reset_value=0)
    s.connect_pairs(
        m.en, s.reg_en_M,
        m.in_, s.ex_result_sel_mux_X.out,
    )
    # Writeback result selection mux: 0 = ex result, 1 = dmem response
    s.wb_result_sel_mux_M = m = Mux(dtype=32, nports=2)
    s.connect_pairs(m.in_[0], s.ex_result_reg_M.out,
                    m.in_[1], s.dmemresp_msg_data,
                    m.sel, s.wb_result_sel_M)
    #---------------------------------------------------------------------
    # W stage
    #---------------------------------------------------------------------
    # Writeback result reg
    s.wb_result_reg_W = m = RegEnRst(dtype=32, reset_value=0)
    s.connect_pairs(
        m.en, s.reg_en_W,
        m.in_, s.wb_result_sel_mux_M.out,
    )
    s.connect(s.proc2mngr_data, s.wb_result_reg_W.out)
    s.connect(s.rf_wdata_W, s.wb_result_reg_W.out)
    s.stats_en_reg_W = m = RegEnRst(dtype=32, reset_value=0)
    # stats_en logic: register loads the W-stage result when stats_en_wen_W
    s.connect_pairs(
        m.en, s.stats_en_wen_W,
        m.in_, s.wb_result_reg_W.out,
    )

    @s.combinational
    def stats_en_logic_W():
        s.stats_en.value = any(
            s.stats_en_reg_W.out)  # reduction with bitwise OR
def __init__(s): #================================================================== # Interfaces #================================================================== s.req_msg_a = InPort(32) s.req_msg_b = InPort(32) s.resp_msg = OutPort(32) # Control signals s.a_mux_sel = InPort(A_MUX_SEL_NBITS) s.b_mux_sel = InPort(B_MUX_SEL_NBITS) s.result_mux_sel = InPort(RES_MUX_SEL_NBITS) s.add_mux_sel = InPort(ADD_MUX_SEL_NBITS) s.result_en = InPort(1) s.result_sign = InPort(OUT_MUX_SEL_NBITS) # Status signals s.b_lsb = OutPort(1) s.a_msb = OutPort(1) s.b_msb = OutPort(1) #================================================================== # Structure #================================================================== # A Mux s.in_a = Wire(32) # Take the abs value of the input @s.combinational def sign_handling_a(): s.in_a.value = s.req_msg_a if ~s.req_msg_a[31] \ else (~s.req_msg_a) + 1 s.l_shift_out = Wire(32) s.a_mux = m = Mux(32, 2) s.connect_pairs( m.sel, s.a_mux_sel, m.in_[A_MUX_SEL_IN], s.in_a, m.in_[A_MUX_SEL_SHIFT], s.l_shift_out, ) # A Register s.a_reg = m = Reg(32) s.connect(m.in_, s.a_mux.out) # Left Shifter s.l_shift = m = LeftLogicalShifter(32) s.connect_pairs( m.in_, s.a_reg.out, m.shamt, 1, m.out, s.l_shift_out, ) # B Mux s.in_b = Wire(32) # Take the abs value of the input @s.combinational def sign_handling_b(): s.in_b.value = s.req_msg_b if ~s.req_msg_b[31] \ else (~s.req_msg_b) + 1 s.r_shift_out = Wire(32) s.b_mux = m = Mux(32, 2) s.connect_pairs( m.sel, s.b_mux_sel, m.in_[B_MUX_SEL_IN], s.in_b, m.in_[B_MUX_SEL_SHIFT], s.r_shift_out, ) # B Register s.b_reg = m = Reg(32) s.connect(m.in_, s.b_mux.out) # Right Shifter s.r_shift = m = RightLogicalShifter(32) s.connect_pairs( m.in_, s.b_reg.out, m.shamt, 1, m.out, s.r_shift_out, ) # Result Mux s.add_mux_out = Wire(32) s.result_mux = m = Mux(32, 2) s.connect_pairs( m.sel, s.result_mux_sel, m.in_[RES_MUX_SEL_ZERO], 0, m.in_[RES_MUX_SEL_ADD], s.add_mux_out, ) # Result Register s.res_reg = m = RegEn(32) 
s.connect_pairs(m.in_, s.result_mux.out, m.en, s.result_en) # Adder s.adder = m = Adder(32) s.connect_pairs( m.in0, s.a_reg.out, m.in1, s.res_reg.out, m.cin, 0, ) # Add Mux s.add_mux = m = Mux(32, 2) s.connect_pairs( m.sel, s.add_mux_sel, m.in_[ADD_MUX_SEL_ADD], s.adder.out, m.in_[ADD_MUX_SEL_RES], s.res_reg.out, m.out, s.add_mux_out, ) # Output MUX s.res_neg = Wire(32) # Generate -res in case the output is negative @s.combinational def twos_compl_block(): s.res_neg.value = (~s.res_reg.out) + 1 s.out_mux = m = Mux(32, 2) s.connect_pairs( m.sel, s.result_sign, m.in_[OUT_MUX_SEL_POS], s.res_reg.out, m.in_[OUT_MUX_SEL_NEG], s.res_neg, m.out, s.resp_msg, ) # Connect status signals s.connect(s.b_reg.out[0], s.b_lsb) s.connect(s.req_msg_a[31], s.a_msb) s.connect(s.req_msg_b[31], s.b_msb)
def __init__(s):
    """Compose a CGRA, its controller FSM, a test memory and the queues between them.

    Port 0 of the test memory ``s.ocm`` receives requests through a
    3-input mux. The mux inputs are the external ``in_mem`` path (via
    ``s.in_mem_wire``) and the two CGRA request ports; the select comes
    from ``s.fsm.in_mux_sel``. Memory responses are steered by
    ``s.fsm.out_mux_sel`` in the demux combinational block below.

    Per-PE queues sit between the CGRA and the FSM in both directions:
    ``resp_q`` carries CGRA -> FSM and ``ctr_q`` carries FSM -> CGRA.

    Relies on ``nWid``, ``nPE``, ``inst_msg``, ``fsm`` and the queue and
    memory classes being defined outside this method.
    """
    # interface from source and sink
    s.in_mem = InValRdyBundle(nWid)
    s.out_mem = OutValRdyBundle(nWid)

    s.cgra = CgraRTL()
    s.fsm = fsm()

    #Memory
    s.ocm = TestMemoryFuture(MemMsg( 0, 32, nWid))  # opaque field, address, data width

    # Input Mux: 3 request sources share memory port 0.
    s.input_mux = m = Mux(dtype=MemMsg(0, 32, nWid), nports=3)
    s.in_mem_wire = Wire(MemMsg(0, 32, nWid))

    # Copy bits [0:51] of the incoming request message into in_mem_wire,
    # field by field.
    # NOTE(review): s.in_mem_wire is a Wire, not a val/rdy bundle; confirm
    # that it exposes a `.msg` attribute.
    @s.combinational
    def logic():
        s.in_mem_wire.msg[0:16].value = s.in_mem.msg[0:16]  # data
        s.in_mem_wire.msg[16:18].value = s.in_mem.msg[16:18]  # length in bytes of read or write
        s.in_mem_wire.msg[18:50].value = s.in_mem.msg[18:50]  # address field
        s.in_mem_wire.msg[50:51].value = s.in_mem.msg[50:51]  # type: read or write

    # `m` still refers to s.input_mux here.
    s.connect_pairs(
        m.in_[0], s.in_mem_wire,
        m.in_[1], s.cgra.ocmreqs[0],
        m.in_[2], s.cgra.ocmreqs[1],
        m.sel, s.fsm.in_mux_sel,  # Add later in Cpath
    )

    # Memory Connections
    s.connect(s.ocm.reqs[0], s.input_mux.out)

    # Demux logic: route the memory response to the external port
    # (sel 0) or to CGRA response port 0 or 1 (sel 1 or 2).
    # NOTE(review): this block is also named `logic`, which rebinds the
    # local name of the block above. The decorator has already registered
    # the first one, but a unique name would be clearer and may matter for
    # translation; confirm.
    # NOTE(review): the out_mem writes omit `.value`, unlike every other
    # combinational write in this method; confirm they update out_mem.msg.
    # NOTE(review): there are no default assignments. Outputs not written
    # on a given branch are left unassigned (latch-like); confirm this is
    # intended.
    @s.combinational
    def logic():
        if s.fsm.out_mux_sel.value == 0:
            s.out_mem.msg[0:16] = s.ocm.resps[0].msg[0:16]
            s.out_mem.msg[16:18] = s.ocm.resps[0].msg[16:18]
            s.out_mem.msg[18:50] = s.ocm.resps[0].msg[18:50]
            s.out_mem.msg[50:51] = s.ocm.resps[0].msg[50:51]
        elif s.fsm.out_mux_sel.value == 1:
            s.cgra.ocmresps[0].msg.value = s.ocm.resps[0].msg
        elif s.fsm.out_mux_sel.value == 2:
            s.cgra.ocmresps[1].msg.value = s.ocm.resps[0].msg

    # queue for control word (FSM -> CGRA), one per PE
    #s.ctr_q = SingleElementBypassQueue[nPE](inst_msg())
    #s.ctr_q = SingleElementNormalQueue[nPE](inst_msg())
    s.ctr_q = NormalQueue[nPE](2, inst_msg())

    # queue for cgra-to-fsm response, one per PE
    s.resp_q = SingleElementBypassQueue[nPE](1)
    #s.resp_q = SingleElementNormalQueue[nPE](1)

    for x in range(nPE):
        s.connect(s.cgra.out_fsm[x], s.resp_q[x].enq)
        s.connect(s.resp_q[x].deq, s.fsm.in_[x])
        s.connect(s.fsm.out[x], s.ctr_q[x].enq)
        s.connect(s.ctr_q[x].deq, s.cgra.in_control[x])

    # NOTE(review): in_mem also feeds in_mem_wire above. out_mem is also
    # written by the demux block above, so this connection may create a
    # second driver; confirm.
    s.connect(s.in_mem, s.cgra.in_mem)
    s.connect(s.out_mem, s.cgra.out_mem)
def __init__(s, cpu_ifc_types):
    """Build the operand muxes/registers, comparators and subtractor.

    A and B registers are loaded through muxes. B is checked for zero and
    compared against A (A < B), and the difference A - B feeds back into
    the A mux and drives the response data. The structure appears to be
    the classic subtract-and-swap GCD datapath. Sequencing presumably
    lives in a control unit that drives ``s.cs`` and reads ``s.ss``.

    Args:
      cpu_ifc_types: message-type bundle. ``.req`` and ``.resp``
        parameterize the val/rdy request and response bundles, and
        ``.req.data`` is used as the width/type of every datapath wire
        and submodule below.
    """
    s.cpu_ifc_req = InValRdyBundle(cpu_ifc_types.req)
    s.cpu_ifc_resp = OutValRdyBundle(cpu_ifc_types.resp)

    # Width/type shared by every datapath wire and submodule.
    # (The leftover debug prints of this value were removed; they ran on
    # every elaboration.)
    size = cpu_ifc_types.req.data

    # Control signals in, status signals out.
    s.cs = InPort(CtrlSignals())
    s.ss = OutPort(StatusSignals())

    # Interface wires
    s.in_msg_a = Wire(size)
    s.in_msg_b = Wire(size)
    s.out_msg = Wire(size)

    #-----------------------------------------------------------------------
    # Connectivity and Logic
    #-----------------------------------------------------------------------

    # Both operand wires are fed from the same request data field. Which
    # register captures it is decided by the A/B mux selects and the
    # register enables in s.cs.
    s.connect(s.cpu_ifc_req.msg.data, s.in_msg_a)
    s.connect(s.cpu_ifc_req.msg.data, s.in_msg_b)
    s.connect(s.cpu_ifc_resp.msg.data, s.out_msg)

    #---------------------------------------------------------------------
    # Datapath Structural Composition
    #---------------------------------------------------------------------

    # Feedback wires into the A mux.
    s.sub_out = Wire(size)
    s.b_reg_out = Wire(size)

    # A mux
    # NOTE(review): A_MUX_SEL_C is wired to b_reg_out, the same source as
    # A_MUX_SEL_B; confirm this duplicate is intentional.
    s.a_mux = m = Mux(size, 4)
    s.connect_dict({
        m.sel: s.cs.a_mux_sel,
        m.in_[A_MUX_SEL_IN]: s.in_msg_a,
        m.in_[A_MUX_SEL_SUB]: s.sub_out,
        m.in_[A_MUX_SEL_B]: s.b_reg_out,
        m.in_[A_MUX_SEL_C]: s.b_reg_out,
    })

    # A register
    s.a_reg = m = regs.RegEn(size)
    s.connect_dict({
        m.en: s.cs.a_reg_en,
        m.in_: s.a_mux.out,
    })

    # B mux
    s.b_mux = m = Mux(size, 2)
    s.connect_dict({
        m.sel: s.cs.b_mux_sel,
        m.in_[B_MUX_SEL_A]: s.a_reg.out,
        m.in_[B_MUX_SEL_IN]: s.in_msg_b,
    })

    # B register
    s.b_reg = m = regs.RegEn(size)
    s.connect_dict({
        m.en: s.cs.b_reg_en,
        m.in_: s.b_mux.out,
        m.out: s.b_reg_out,
    })

    # Zero compare (status: B == 0)
    s.b_zero = m = arith.ZeroComparator(size)
    s.connect_dict({
        m.in_: s.b_reg.out,
        m.out: s.ss.is_b_zero,
    })

    # Less-than comparator (status: A < B)
    s.a_lt_b = m = arith.LtComparator(size)
    s.connect_dict({
        m.in0: s.a_reg.out,
        m.in1: s.b_reg.out,
        m.out: s.ss.is_a_lt_b,
    })

    # Subtractor (A - B), fed back into the A mux
    s.sub = m = arith.Subtractor(size)
    s.connect_dict({
        m.in0: s.a_reg.out,
        m.in1: s.b_reg.out,
        m.out: s.sub_out,
    })

    # connect to output port
    s.connect(s.sub.out, s.out_msg)
def __init__(s, nbits: int):
    """Divider datapath that produces two quotient bits per iteration.

    Two cascaded subtract stages are used. ``sub1`` subtracts the divisor
    from the remainder. ``sub2`` subtracts the divisor shifted right by
    one from either sub1's difference or the unmodified remainder.

    The MSB of each 2*nbits difference is exported as ``sub_negative1`` /
    ``sub_negative2`` for the control unit. The inverted pair of flags is
    added into the quotient after it has been shifted left by 2. The
    divisor is shifted right by 2 when the divisor mux selects
    D_MUX_SEL_RSH.

    Args:
      nbits: width of each operand/result half. Request and response
        messages are ``2 * nbits`` bits wide.

    Message layout (from the slicing below):
      req_msg[0:nbits]        -> zero-extended into the remainder mux IN input
      req_msg[nbits:2*nbits]  -> divisor mux IN input, pre-shifted left by nbits-1
      resp_msg[0:nbits]       <- remainder_reg.out[0:nbits]
      resp_msg[nbits:2*nbits] <- quotient_reg.out
    """
    nbitsx2 = nbits * 2
    dtype = mk_bits(nbits)
    dtypex2 = mk_bits(nbitsx2)

    s.req_msg = InVPort(dtypex2)
    s.resp_msg = OutVPort(dtypex2)

    # Status signals (sign bits of the two subtract stages)
    s.sub_negative1 = OutVPort(Bits1)
    s.sub_negative2 = OutVPort(Bits1)

    # Control signals
    s.quotient_mux_sel = InVPort(Bits1)
    s.quotient_reg_en = InVPort(Bits1)
    s.remainder_mux_sel = InVPort(Bits2)
    s.remainder_reg_en = InVPort(Bits1)
    s.divisor_mux_sel = InVPort(Bits1)

    # Dpath components

    # Remainder mux inputs: IN (request), SUB1 and SUB2 (stage differences).
    # NOTE(review): there is no "hold" input. When both stages are negative
    # the remainder presumably has to be held by deasserting
    # remainder_reg_en in ctrl; confirm.
    s.remainder_mux = Mux(dtypex2, 3)(sel=s.remainder_mux_sel)

    # Zero-extend the low half of req_msg into the 2*nbits remainder.
    @s.update
    def up_remainder_mux_in0():
        s.remainder_mux.in_[R_MUX_SEL_IN] = dtypex2()
        s.remainder_mux.in_[R_MUX_SEL_IN][0:nbits] = s.req_msg[0:nbits]

    s.remainder_reg = RegEn(dtypex2)(
        in_=s.remainder_mux.out,
        en=s.remainder_reg_en,
    )
    # lower bits of resp_msg save the remainder
    s.connect(s.resp_msg[0:nbits], s.remainder_reg.out[0:nbits])

    s.divisor_mux = Mux(dtypex2, 2)(sel=s.divisor_mux_sel)

    # Place the high half of req_msg at bits [nbits-1, 2*nbits-1), i.e.
    # shifted left by nbits-1, with all other bits zero.
    @s.update
    def up_divisor_mux_in0():
        s.divisor_mux.in_[D_MUX_SEL_IN] = dtypex2()
        s.divisor_mux.in_[D_MUX_SEL_IN][nbits - 1:nbitsx2 - 1] = s.req_msg[nbits:nbitsx2]

    # Always-enabled divisor register.
    s.divisor_reg = Reg(dtypex2)(in_=s.divisor_mux.out)

    # Quotient mux: constant 0 (clear) or shifted-and-incremented quotient.
    s.quotient_mux = Mux(dtype, 2)(sel=s.quotient_mux_sel)
    s.connect(s.quotient_mux.in_[Q_MUX_SEL_0], 0)

    s.quotient_reg = RegEn(dtype)(
        in_=s.quotient_mux.out,
        en=s.quotient_reg_en,
        # higher bits of resp_msg save the quotient
        out=s.resp_msg[nbits:nbitsx2],
    )

    # shamt should be 2 bits!
    # (presumably the second LShifter argument is the shamt width; confirm)
    s.quotient_lsh = LShifter(dtype, 2)(in_=s.quotient_reg.out)
    s.connect(s.quotient_lsh.shamt, 2)

    # inc[1] = sub_negative1, inc[0] = sub_negative2
    s.inc = Wire(Bits2)
    s.connect(s.sub_negative1, s.inc[1])
    s.connect(s.sub_negative2, s.inc[0])

    # ~inc yields the two new quotient bits (1 = that stage's subtraction
    # did not go negative). The low 2 bits of the shifted quotient are
    # zero, so the add only fills them in.
    @s.update
    def up_quotient_inc():
        s.quotient_mux.in_[Q_MUX_SEL_LSH] = s.quotient_lsh.out + ~s.inc

    # stage 1/2: remainder - divisor
    s.sub1 = Subtractor(dtypex2)(
        in0=s.remainder_reg.out,
        in1=s.divisor_reg.out,
        out=s.remainder_mux.in_[R_MUX_SEL_SUB1],
    )
    # MSB of the difference is treated as the "negative" flag.
    s.connect(s.sub_negative1, s.sub1.out[nbitsx2 - 1])

    # If stage 1 went negative, stage 2 works on the unmodified remainder;
    # otherwise it works on sub1's difference.
    s.remainder_mid_mux = Mux(dtypex2, 2)(
        in_={
            0: s.sub1.out,
            1: s.remainder_reg.out,
        },
        sel=s.sub_negative1,
    )

    # Divisor >> 1 for the second stage.
    s.divisor_rsh1 = RShifter(dtypex2, 1)(in_=s.divisor_reg.out, )
    s.connect(s.divisor_rsh1.shamt, 1)

    # stage 2/2: (stage-1 result) - (divisor >> 1)
    s.sub2 = Subtractor(dtypex2)(
        in0=s.remainder_mid_mux.out,
        in1=s.divisor_rsh1.out,
        out=s.remainder_mux.in_[R_MUX_SEL_SUB2],
    )
    s.connect(s.sub_negative2, s.sub2.out[nbitsx2 - 1])

    # Divisor >> 2 becomes the next divisor when D_MUX_SEL_RSH is selected.
    s.divisor_rsh2 = RShifter(dtypex2, 1)(
        in_=s.divisor_rsh1.out,
        out=s.divisor_mux.in_[D_MUX_SEL_RSH],
    )
    s.connect(s.divisor_rsh2.shamt, 1)
def elaborate_logic(s):
    """Instantiate and wire the 32-bit datapath submodules.

    The A register is loaded from a 3-way mux (request operand, A - B, or
    B). The B register is loaded from a 2-way mux (A or request operand).
    Two comparators drive the ``is_b_zero`` and ``is_a_lt_b`` status
    ports. The subtractor's A - B result feeds back into the A mux and
    drives ``out_msg``. All select/enable and status ports are expected
    to be declared on ``s`` before this runs.
    """
    # Feedback wires back into the A mux.
    s.sub_out = Wire(32)
    s.b_reg_out = Wire(32)

    # --- A operand ---------------------------------------------------
    s.a_mux = Mux(32, 3)
    s.connect(s.a_mux.sel, s.a_mux_sel)
    s.connect(s.a_mux.in_[A_MUX_SEL_IN], s.in_msg_a)
    s.connect(s.a_mux.in_[A_MUX_SEL_SUB], s.sub_out)
    s.connect(s.a_mux.in_[A_MUX_SEL_B], s.b_reg_out)

    s.a_reg = RegEn(32)
    s.connect(s.a_reg.en, s.a_reg_en)
    s.connect(s.a_reg.in_, s.a_mux.out)

    # --- B operand ---------------------------------------------------
    s.b_mux = Mux(32, 2)
    s.connect(s.b_mux.sel, s.b_mux_sel)
    s.connect(s.b_mux.in_[B_MUX_SEL_A], s.a_reg.out)
    s.connect(s.b_mux.in_[B_MUX_SEL_IN], s.in_msg_b)

    s.b_reg = RegEn(32)
    s.connect(s.b_reg.en, s.b_reg_en)
    s.connect(s.b_reg.in_, s.b_mux.out)
    s.connect(s.b_reg.out, s.b_reg_out)

    # --- Status: B == 0 ----------------------------------------------
    s.b_zero = arith.ZeroComparator(32)
    s.connect(s.b_zero.in_, s.b_reg.out)
    s.connect(s.b_zero.out, s.is_b_zero)

    # --- Status: A < B -----------------------------------------------
    s.a_lt_b = arith.LtComparator(32)
    s.connect(s.a_lt_b.in0, s.a_reg.out)
    s.connect(s.a_lt_b.in1, s.b_reg.out)
    s.connect(s.a_lt_b.out, s.is_a_lt_b)

    # --- A - B -------------------------------------------------------
    s.sub = arith.Subtractor(32)
    s.connect(s.sub.in0, s.a_reg.out)
    s.connect(s.sub.in1, s.b_reg.out)
    s.connect(s.sub.out, s.sub_out)

    # The difference is also the datapath's output message.
    s.connect(s.sub.out, s.out_msg)