def elaborate_logic(s):
    """Wrap a combinational SRAM with registers on all request inputs.

    ``wen``, ``wben``, ``addr`` and ``wdata`` are each captured in a
    register; the register outputs drive the SRAM request ports and the
    SRAM read data is wired directly to ``s.rdata``.
    """
    # One input register per request-side signal.
    s.wen_reg = Reg(1)
    s.wben_reg = Reg(s.num_nbytes)
    s.addr_reg = Reg(s.addr_nbits)
    s.wdata_reg = Reg(s.data_nbits)

    for reg, port in (
        (s.wen_reg, s.wen),
        (s.wben_reg, s.wben),
        (s.addr_reg, s.addr),
        (s.wdata_reg, s.wdata),
    ):
        s.connect(reg.in_, port)

    # Instantiate the combinational SRAM and drive it from the registers.
    sram = SRAMBytesComb_rst_1rw(
        s.num_entries,
        s.num_nbytes,
        reset_value=s.reset_value,
    )
    s.SRAM = sram

    s.connect_dict({
        sram.wen: s.wen_reg.out,
        sram.wben: s.wben_reg.out,
        sram.addr: s.addr_reg.out,
        sram.wdata: s.wdata_reg.out,
        sram.rdata: s.rdata,
    })
def __init__(s, Type):
    """Single-entry bypass queue carrying messages of ``Type``.

    While the buffer is empty the enqueued message is forwarded straight
    to the dequeue side through a mux; it is only stored when it is
    enqueued but not dequeued in the same cycle.
    """
    # Interfaces
    s.enq = EnqIfcRTL(Type)
    s.deq = DeqIfcRTL(Type)

    # Storage element and its occupancy flag
    s.buf = RegEn(Type)(in_=s.enq.msg)
    s.full = Reg(Bits1)

    # Output select: full -> buffered message, empty -> bypass enq.msg
    s.byp_mux = Mux(Type, 2)(
        sel=s.full.out,
        in_={
            1: s.buf.out,
            0: s.enq.msg,
        },
        out=s.deq.msg,
    )

    @s.update
    def up_bypq_set_enq_rdy():
        # Can accept a message only when the buffer is empty.
        s.enq.rdy = ~s.full.out

    @s.update
    def up_bypq_set_deq_rdy():
        # A message is available if one is buffered or one is being
        # enqueued this cycle (bypass).
        s.deq.rdy = s.enq.en | s.full.out

    @s.update
    def up_bypq_full():
        # Capture the message only if it is enqueued and the receiver
        # does not take it in the same cycle.
        s.buf.en = s.enq.en & ~s.deq.en
        s.full.in_ = (s.full.out | s.enq.en) & ~s.deq.en
def __init__(s, type_):
    """Single-entry normal queue carrying messages of ``type_``.

    The queue accepts a message only while empty and offers one only
    while full; the stored message drives ``s.deq.msg`` directly.

    Args:
        type_: message type passed to the enq/deq interfaces and to the
            buffer register.
    """
    # Interfaces
    s.enq = EnqIfcRTL(type_)
    s.deq = DeqIfcRTL(type_)

    # The buffer input must come from this component's own enq interface
    # (``s.enq.msg``); a bare ``enq`` name is not defined in this scope.
    s.buf = RegEn(type_)(out=s.deq.msg, in_=s.enq.msg)
    s.full = Reg(Bits1)

    @s.update
    def up_normq_set_both_rdy():
        # Empty -> ready to enqueue; full -> ready to dequeue.
        s.enq.rdy = ~s.full.out
        s.deq.rdy = s.full.out

    @s.update
    def up_normq_full():
        # Store on every enqueue; stay/become full unless dequeued.
        s.buf.en = s.enq.en
        s.full.in_ = ~s.deq.en & (s.enq.en | s.full.out)
def __init__(s, Type):
    """Single-entry pipelined queue carrying messages of ``Type``.

    The stored message drives ``s.deq.msg``; a new message may be
    enqueued while full as long as the current one is dequeued in the
    same cycle.
    """
    # Interfaces
    s.enq = EnqIfcRTL(Type)
    s.deq = DeqIfcRTL(Type)

    # Occupancy flag and storage element
    s.full = Reg(Bits1)
    s.buf = RegEn(Type)(in_=s.enq.msg, out=s.deq.msg)

    @s.update
    def up_pipeq_set_deq_rdy():
        # Something can be dequeued only when the buffer holds a message.
        s.deq.rdy = s.full.out

    @s.update
    def up_pipeq_set_enq_rdy():
        # Ready when empty, or when the held message leaves this cycle.
        s.enq.rdy = s.deq.en | ~s.full.out

    @s.update
    def up_pipeq_full():
        # Store on every enqueue; remain full if nothing was dequeued.
        s.buf.en = s.enq.en
        s.full.in_ = (~s.deq.en & s.full.out) | s.enq.en
def __init__(s):
    """Build the control unit: a four-state FSM (IDLE, INIT, CALC, DONE).

    The FSM accepts requests through a val/rdy handshake, steps a counter
    during CALC, and raises ``resp_val`` in DONE until the response is
    accepted. The enable and select outputs are driven purely from the
    current state (and the counter while in CALC).
    """
    # Request/response handshake
    s.req_val = InPort(1)
    s.req_rdy = OutPort(1)
    s.resp_val = OutPort(1)
    s.resp_rdy = InPort(1)
    # Request kind: 1 moves IDLE -> INIT, 0 moves IDLE/INIT -> CALC
    s.req_msg_type = InPort(1)

    # Control signals (ctrl -> dpath)
    s.en_test = OutPort(1)
    s.en_train = OutPort(1)
    s.en_out = OutPort(1)
    s.sel_out = OutPort(1)
    s.sel = OutPort(6)

    # State element
    s.STATE_IDLE = 0
    s.STATE_INIT = 1
    s.STATE_CALC = 2
    s.STATE_DONE = 3
    s.state = RegRst(2, reset_value=s.STATE_IDLE)
    # NOTE(review): the counter is DATA_NBITS bits wide but is copied onto
    # the 6-bit `sel` port in CALC -- confirm the widths are compatible.
    s.counter = Reg(DATA_NBITS)

    # State Transition Logic
    @s.combinational
    def state_transitions():
        curr_state = s.state.out
        # Default: stay in the current state
        next_state = s.state.out

        # Transition out of IDLE state
        if (curr_state == s.STATE_IDLE):
            if ((s.req_val and s.req_rdy) and (s.req_msg_type == 1)):
                next_state = s.STATE_INIT
            elif ((s.req_val and s.req_rdy) and (s.req_msg_type == 0)):
                next_state = s.STATE_CALC

        # Transition out of INIT state
        if (curr_state == s.STATE_INIT):
            if ((s.req_val and s.req_rdy) and (s.req_msg_type == 0)):
                next_state = s.STATE_CALC

        # Transition out of CALC state: leave once the counter is about to
        # reach DATA_NBITS
        if (curr_state == s.STATE_CALC):
            if ((s.counter.out + 1) == DATA_NBITS):
                next_state = s.STATE_DONE

        # Transition out of DONE state: wait for the response handshake
        if (curr_state == s.STATE_DONE):
            if (s.resp_val and s.resp_rdy):
                next_state = s.STATE_IDLE

        s.state.in_.value = next_state

    # State Output Logic
    @s.combinational
    def state_outputs():
        current_state = s.state.out

        # IDLE state: ready for a request, both test and train enables set
        if current_state == s.STATE_IDLE:
            s.req_rdy.value = 1
            s.resp_val.value = 0
            s.en_out.value = 0
            s.en_test.value = 1
            s.en_train.value = 1
            s.sel.value = 0
            s.sel_out.value = 0
            s.counter.in_.value = 0

        # INIT state: still ready for a request, only the train enable set
        elif current_state == s.STATE_INIT:
            s.req_rdy.value = 1
            s.resp_val.value = 0
            s.en_out.value = 0
            s.en_test.value = 0
            s.en_train.value = 1
            s.sel.value = 0
            s.sel_out.value = 0
            s.counter.in_.value = 0

        # CALC state: output enabled, `sel` follows the counter, counter
        # increments every cycle
        elif current_state == s.STATE_CALC:
            s.req_rdy.value = 0
            s.resp_val.value = 0
            s.en_out.value = 1
            # sel_out is 0 only on the first CALC step (counter == 0)
            if (s.counter.out == 0):
                s.sel_out.value = 0
            else:
                s.sel_out.value = 1
            s.en_test.value = 0
            s.en_train.value = 0
            s.sel.value = s.counter.out
            s.counter.in_.value = s.counter.out + 1

        # DONE state: response valid, everything else idle, counter cleared
        elif current_state == s.STATE_DONE:
            s.req_rdy.value = 0
            s.resp_val.value = 1
            s.en_out.value = 0
            s.en_test.value = 0
            s.en_train.value = 0
            s.sel.value = 0
            s.sel_out.value = 0
            s.counter.in_.value = 0
def __init__(s, mapper_num=10, reducer_num=1, train_size=600):
    """Build the scheduler that feeds training/test data to the mappers.

    At construction time the training sets are read from
    ``data/training_set_<i>.dat`` (one file per digit, at most
    ``train_size`` hexadecimal entries each). The hardware is a six-state
    FSM (IDLE, SOURCE, INIT, START, WRITE, END) that takes configuration
    messages on ``s.in_``, writes the training data into the register
    files, reads test data from global memory, broadcasts it to the
    mappers and writes the merger result back to global memory.

    Args:
        mapper_num: number of mapper request ports.
        reducer_num: divides ``mapper_num`` to size ``regf_rdaddr``.
        train_size: number of training entries kept per digit.
    """
    TRAIN_DATA = train_size
    # Number of address bits needed to index TRAIN_DATA entries
    TRAIN_LOG = int(math.ceil(math.log(TRAIN_DATA, 2)))

    # import training data and store them into array
    # (flat list: entries of digit i start at index i * TRAIN_DATA, provided
    # every file supplies at least TRAIN_DATA lines -- TODO confirm)
    training_data = []
    for i in xrange(DIGIT):
        count = 0
        filename = 'data/training_set_' + str(i) + '.dat'
        with open(filename, 'r') as f:
            for L in f:
                # Keep only the first TRAIN_DATA lines of each file
                if (count > TRAIN_DATA - 1):
                    break
                # Each line is a hex value followed by ",\n"
                training_data.append(int(L.replace(',\n', ''), 16))
                count = count + 1

    # Top Level Interface
    s.in_ = InValRdyBundle(digitrecReqMsg())
    s.out = OutValRdyBundle(digitrecRespMsg())
    # NOTE(review): base/size are declared as InPorts but are also written
    # in the SOURCE state below -- confirm this is intended.
    s.base = InPort(32)
    s.size = InPort(32)

    # Global Memory Interface
    s.gmem_req = OutValRdyBundle(MemReqMsg(8, 32, 64))
    s.gmem_resp = InValRdyBundle(MemRespMsg(8, 64))

    # Register File Interface
    s.regf_addr = OutPort[DIGIT](TRAIN_LOG)
    s.regf_data = OutPort[DIGIT](DATA_BITS)
    s.regf_wren = OutPort[DIGIT](1)
    # NOTE(review): `/` here (and below) relies on Python 2 integer division.
    s.regf_rdaddr = OutPort[mapper_num / reducer_num](TRAIN_LOG)

    # Mapper Interface
    s.map_req = OutPort[mapper_num](DATA_BITS)

    # Reducer Reset
    s.red_rst = OutPort(1)

    # Merger Interface
    s.merger_resp = InPort(DIGIT_LOG)

    # States
    s.STATE_IDLE = 0  # Idle state, scheduler waiting for top level to start
    s.STATE_SOURCE = 1  # Source state, handling the Test Source, getting base, size, ref info
    s.STATE_INIT = 2  # Init state, scheduler assigns input info to each Mapper
    s.STATE_START = 3  # Start state, scheduler gets test data, starts distributing and sorting
    s.STATE_WRITE = 4  # Write state, scheduler writes merger data to memory
    s.STATE_END = 5  # End state, scheduler has loaded all tasks from global memory and is done
    s.state = RegRst(4, reset_value=s.STATE_IDLE)

    # Counters
    s.input_count = Wire(TEST_LOG)
    s.result_count = Wire(TEST_LOG)
    s.train_count_rd = Wire(TRAIN_LOG)
    s.train_count_wr = Wire(32)
    s.train_data_wr = Wire(1)
    s.train_data_rd = Wire(1)

    # Logic to Increment Counters
    # NOTE(review): input_count, train_count_rd and train_count_wr are
    # updated here via .next and also cleared via .value in the IDLE branch
    # of state_outputs -- confirm the two writers do not conflict.
    @s.tick
    def counter():
        # Count issued read requests (test inputs fetched)
        if (s.gmem_req.val and s.gmem_req.msg.type_ == TYPE_READ):
            s.input_count.next = s.input_count + 1
        # Count issued write requests (results stored)
        if (s.gmem_req.val and s.gmem_req.msg.type_ == TYPE_WRITE):
            s.result_count.next = s.result_count + 1
        # Training-data read pointer: cleared by rst, otherwise advances by
        # the number of mappers per digit while reading
        if s.rst:
            s.train_count_rd.next = 0
        elif s.train_data_rd:
            s.train_count_rd.next = s.train_count_rd + (mapper_num / DIGIT)
        # Training-data write pointer
        if (s.train_data_wr):
            s.train_count_wr.next = s.train_count_wr + 1

    # Signals
    s.go = Wire(1)  # go signal tells scheduler to start scheduling
    s.done = Wire(1)  # done signal indicates everything is done
    s.rst = Wire(1)  # reset train count every test data processed

    # Reference data
    s.reference = Reg(dtype=DATA_BITS)  # reference stores test data

    #---------------------------------------------------------------------
    # Initialize Register File for Training data
    #---------------------------------------------------------------------
    @s.combinational
    def traindata():
        if s.train_data_wr:
            # Write entry train_count_wr of every digit's training set into
            # that digit's register file
            for i in xrange(DIGIT):
                s.regf_addr[i].value = s.train_count_wr
                s.regf_data[i].value = training_data[i * TRAIN_DATA + s.train_count_wr]
                s.regf_wren[i].value = 1
        else:
            for i in xrange(DIGIT):
                s.regf_wren[i].value = 0

    #---------------------------------------------------------------------
    # Assign Task to Mapper Combinational Logic
    #---------------------------------------------------------------------
    @s.combinational
    def mapper():
        # broadcast the reference (test) data to every mapper and set the
        # register-file read addresses
        for i in xrange(DIGIT):
            for j in xrange(mapper_num / DIGIT):
                if (s.train_data_rd):
                    # NOTE(review): the stride 10 is hard-coded; presumably it
                    # should track DIGIT -- confirm.
                    s.map_req[j * 10 + i].value = s.reference.out
                    s.regf_rdaddr[j].value = s.train_count_rd + j

    #---------------------------------------------------------------------
    # Task State Transition Logic
    #---------------------------------------------------------------------
    @s.combinational
    def state_transitions():
        curr_state = s.state.out
        # Default: stay in the current state
        next_state = s.state.out

        if (curr_state == s.STATE_IDLE):
            if (s.in_.val):
                next_state = s.STATE_SOURCE

        if (curr_state == s.STATE_SOURCE):
            if (s.go):
                next_state = s.STATE_INIT
            elif (s.done):
                next_state = s.STATE_IDLE

        # INIT ends when the last training entry is being written
        if (curr_state == s.STATE_INIT):
            if (s.train_count_wr == TRAIN_DATA - 1):
                next_state = s.STATE_START

        # START ends when the read pointer reaches the last group of entries
        if (curr_state == s.STATE_START):
            if (s.train_count_rd == TRAIN_DATA - (mapper_num / DIGIT)):
                next_state = s.STATE_WRITE

        # After writing a result: finish if all inputs were fetched,
        # otherwise process the next test input
        if (curr_state == s.STATE_WRITE):
            if (s.input_count == s.size):
                next_state = s.STATE_END
            else:
                next_state = s.STATE_START

        if (curr_state == s.STATE_END):
            if s.gmem_resp.val:
                next_state = s.STATE_SOURCE

        s.state.in_.value = next_state

    #---------------------------------------------------------------------
    # Task State Output Logic
    #---------------------------------------------------------------------
    @s.combinational
    def state_outputs():
        current_state = s.state.out

        # Defaults: no handshake is active unless a state enables it below
        s.gmem_req.val.value = 0
        s.gmem_resp.rdy.value = 0
        s.in_.rdy.value = 0
        s.out.val.value = 0

        # In IDLE state: clear counters and internal flags
        if (current_state == s.STATE_IDLE):
            s.input_count.value = 0
            s.train_count_rd.value = 0
            s.train_count_wr.value = 0
            # NOTE(review): this writes .value on the Reg instance itself
            # rather than on its in_ port -- confirm.
            s.reference.value = 0
            s.go.value = 0
            s.train_data_rd.value = 0
            s.train_data_wr.value = 0
            s.done.value = 0
            s.rst.value = 0
            s.red_rst.value = 0

        # In SOURCE state: handle configuration messages from the top level
        if (current_state == s.STATE_SOURCE):
            if (s.in_.val and s.out.rdy):
                if (s.in_.msg.type_ == digitrecReqMsg.TYPE_WRITE):
                    if (s.in_.msg.addr == 0):  # start computing
                        s.go.value = 1
                    elif (s.in_.msg.addr == 1):  # base address
                        s.base.value = s.in_.msg.data
                    elif (s.in_.msg.addr == 2):  # size
                        s.size.value = s.in_.msg.data

                    # Send xcel response message
                    s.in_.rdy.value = 1
                    s.out.msg.type_.value = digitrecReqMsg.TYPE_WRITE
                    s.out.msg.data.value = 0
                    s.out.val.value = 1

                elif (s.in_.msg.type_ == digitrecReqMsg.TYPE_READ):
                    # the computing is done, send response message
                    if (s.done):
                        s.out.msg.type_.value = digitrecReqMsg.TYPE_READ
                        s.out.msg.data.value = 1
                        s.in_.rdy.value = 1
                        s.out.val.value = 1

        # In INIT state: stream the training data into the register files
        if (current_state == s.STATE_INIT):
            s.train_data_wr.value = 1
            s.go.value = 0
            # at the end of init, send read req to global memory
            if s.train_count_wr == TRAIN_DATA - 1:
                if s.gmem_req.rdy:
                    # Test inputs are 8 bytes apart starting at `base`
                    s.gmem_req.msg.addr.value = s.base + (8 * s.input_count)
                    s.gmem_req.msg.type_.value = TYPE_READ
                    s.gmem_req.val.value = 1
                    s.red_rst.value = 1

        # In START state: read training data and handle memory responses
        if (current_state == s.STATE_START):
            s.train_data_wr.value = 0
            s.train_data_rd.value = 1
            s.rst.value = 0
            s.red_rst.value = 0
            if s.gmem_resp.val:
                # if response type is read, store test data to reference, hold
                # response val until everything is done, which is set in WRITE
                # state
                if s.gmem_resp.msg.type_ == TYPE_READ:
                    s.gmem_resp.rdy.value = 1
                    s.reference.in_.value = s.gmem_resp.msg.data
                else:
                    # if response type is write, set response rdy, send another
                    # req to read test data
                    s.gmem_resp.rdy.value = 1
                    s.gmem_req.msg.addr.value = s.base + (8 * s.input_count)
                    s.gmem_req.msg.type_.value = TYPE_READ
                    s.gmem_req.val.value = 1
                    s.red_rst.value = 1
                    s.train_data_rd.value = 0

        # In WRITE state
        if (current_state == s.STATE_WRITE):
            s.train_data_rd.value = 0
            # one test data done processed, write result from merger to memory
            if (s.gmem_req.rdy):
                # Results are stored 8 bytes apart starting at address 0x2000
                s.gmem_req.msg.addr.value = 0x2000 + (8 * s.result_count)
                s.gmem_req.msg.data.value = s.merger_resp
                s.gmem_req.msg.type_.value = TYPE_WRITE
                s.gmem_req.val.value = 1
                s.rst.value = 1

        # In END state: accept the final memory response and flag completion
        if (current_state == s.STATE_END):
            if s.gmem_resp.val:
                s.gmem_resp.rdy.value = 1
                s.done.value = 1
def test_reg16():
    """Run the shared register test against a 16-bit ``Reg``."""
    dut = Reg(16)
    reg_test(dut)
def test_reg8():
    """Run the shared register test against an 8-bit ``Reg``."""
    dut = Reg(8)
    reg_test(dut)
def __init__(s):
    """Build the variable-shift shift/add datapath.

    The magnitudes of the two request operands can be loaded into the A
    and B registers. Both registers are shifted by the same ``shamt``,
    which ShamtGenPRTL derives from B with its least-significant bit
    dropped. An adder accumulates A into the result register, and the
    output mux returns either the result or its two's complement.
    """
    #==================================================================
    # Interfaces
    #==================================================================
    s.req_msg_a = InPort(32)
    s.req_msg_b = InPort(32)
    s.resp_msg = OutPort(32)

    # Control signals
    s.a_mux_sel = InPort(A_MUX_SEL_NBITS)
    s.b_mux_sel = InPort(B_MUX_SEL_NBITS)
    s.result_mux_sel = InPort(RES_MUX_SEL_NBITS)
    s.add_mux_sel = InPort(ADD_MUX_SEL_NBITS)
    s.result_en = InPort(1)
    s.result_sign = InPort(OUT_MUX_SEL_NBITS)

    # Status signals
    s.b_lsb = OutPort(1)
    s.a_msb = OutPort(1)
    s.b_msb = OutPort(1)
    s.to_ctrl_shamt = OutPort(6)

    # Shift amount fed to both shifters, produced by ShamtGen
    s.shamt = Wire(6)

    # Bit string handed to ShamtGen
    s.bit_string = Wire(32)

    #==================================================================
    # Structure
    #==================================================================

    # ---- A path ------------------------------------------------------
    s.in_a = Wire(32)

    @s.combinational
    def sign_handling_a():
        # Absolute value: negate (two's complement) when the sign bit is set
        if ~s.req_msg_a[31]:
            s.in_a.value = s.req_msg_a
        else:
            s.in_a.value = (~s.req_msg_a) + 1

    s.l_shift_out = Wire(32)

    # A mux: new operand magnitude or the shifted A register
    s.a_mux = Mux(32, 2)
    s.connect_pairs(
        s.a_mux.sel, s.a_mux_sel,
        s.a_mux.in_[A_MUX_SEL_IN], s.in_a,
        s.a_mux.in_[A_MUX_SEL_SHIFT], s.l_shift_out,
    )

    # A register
    s.a_reg = Reg(32)
    s.connect(s.a_reg.in_, s.a_mux.out)

    # Left shifter on A
    s.l_shift = LeftLogicalShifter(32, 6)
    s.connect_pairs(
        s.l_shift.in_, s.a_reg.out,
        s.l_shift.shamt, s.shamt,
        s.l_shift.out, s.l_shift_out,
    )

    # ---- B path ------------------------------------------------------
    s.in_b = Wire(32)

    @s.combinational
    def sign_handling_b():
        # Absolute value: negate (two's complement) when the sign bit is set
        if ~s.req_msg_b[31]:
            s.in_b.value = s.req_msg_b
        else:
            s.in_b.value = (~s.req_msg_b) + 1

    s.r_shift_out = Wire(32)

    # B mux: new operand magnitude or the shifted B register
    s.b_mux = Mux(32, 2)
    s.connect_pairs(
        s.b_mux.sel, s.b_mux_sel,
        s.b_mux.in_[B_MUX_SEL_IN], s.in_b,
        s.b_mux.in_[B_MUX_SEL_SHIFT], s.r_shift_out,
    )

    # B register
    s.b_reg = Reg(32)
    s.connect(s.b_reg.in_, s.b_mux.out)

    # Drop bit 0 of B and pad a zero on top; ShamtGen derives the shift
    # amount from this value.
    @s.combinational
    def bit_string_block():
        s.bit_string.value = concat(Bits(1, 0), s.b_reg.out[1:32])

    # Right shifter on B
    s.r_shift = RightLogicalShifter(32, 6)
    s.connect_pairs(
        s.r_shift.in_, s.b_reg.out,
        s.r_shift.shamt, s.shamt,
        s.r_shift.out, s.r_shift_out,
    )

    # ---- Result path -------------------------------------------------
    s.add_mux_out = Wire(32)

    # Result mux: clear to zero or take the add-mux output
    s.result_mux = Mux(32, 2)
    s.connect_pairs(
        s.result_mux.sel, s.result_mux_sel,
        s.result_mux.in_[RES_MUX_SEL_ZERO], 0,
        s.result_mux.in_[RES_MUX_SEL_ADD], s.add_mux_out,
    )

    # Result register (enabled)
    s.res_reg = RegEn(32)
    s.connect_pairs(
        s.res_reg.in_, s.result_mux.out,
        s.res_reg.en, s.result_en,
    )

    # Adder: A + result
    s.adder = Adder(32)
    s.connect_pairs(
        s.adder.in0, s.a_reg.out,
        s.adder.in1, s.res_reg.out,
        s.adder.cin, 0,
    )

    # Add mux: accumulate or keep the current result
    s.add_mux = Mux(32, 2)
    s.connect_pairs(
        s.add_mux.sel, s.add_mux_sel,
        s.add_mux.in_[ADD_MUX_SEL_ADD], s.adder.out,
        s.add_mux.in_[ADD_MUX_SEL_RES], s.res_reg.out,
        s.add_mux.out, s.add_mux_out,
    )

    # ---- Shift-amount generation --------------------------------------
    s.shamt_gen = ShamtGenPRTL()
    s.connect_pairs(
        s.shamt_gen.a, s.bit_string,
        s.shamt_gen.shamt, s.shamt,
    )

    # Forward shamt to the control unit so its counter can update
    # accordingly
    @s.combinational
    def to_ctrl_shamt_block():
        s.to_ctrl_shamt.value = s.shamt

    # ---- Output selection ----------------------------------------------
    s.res_neg = Wire(32)

    # Negated result, used when the product is negative
    @s.combinational
    def twos_compl_block():
        s.res_neg.value = (~s.res_reg.out) + 1

    s.out_mux = Mux(32, 2)
    s.connect_pairs(
        s.out_mux.sel, s.result_sign,
        s.out_mux.in_[OUT_MUX_SEL_POS], s.res_reg.out,
        s.out_mux.in_[OUT_MUX_SEL_NEG], s.res_neg,
        s.out_mux.out, s.resp_msg,
    )

    # Status signals back to the control unit
    s.connect(s.b_reg.out[0], s.b_lsb)
    s.connect(s.req_msg_a[31], s.a_msb)
    s.connect(s.req_msg_b[31], s.b_msb)
def __init__(s):
    """Build the fixed-shift (one bit per step) shift/add datapath.

    The magnitudes of the two request operands can be loaded into the A
    and B registers; A is shifted left and B right by a constant 1. An
    adder accumulates A into the result register, and the output mux
    returns either the result or its two's complement.
    """
    #==================================================================
    # Interfaces
    #==================================================================
    s.req_msg_a = InPort(32)
    s.req_msg_b = InPort(32)
    s.resp_msg = OutPort(32)

    # Control signals
    s.a_mux_sel = InPort(A_MUX_SEL_NBITS)
    s.b_mux_sel = InPort(B_MUX_SEL_NBITS)
    s.result_mux_sel = InPort(RES_MUX_SEL_NBITS)
    s.add_mux_sel = InPort(ADD_MUX_SEL_NBITS)
    s.result_en = InPort(1)
    s.result_sign = InPort(OUT_MUX_SEL_NBITS)

    # Status signals
    s.b_lsb = OutPort(1)
    s.a_msb = OutPort(1)
    s.b_msb = OutPort(1)

    #==================================================================
    # Structure
    #==================================================================

    # ---- A path ------------------------------------------------------
    s.in_a = Wire(32)

    @s.combinational
    def sign_handling_a():
        # Absolute value: negate (two's complement) when the sign bit is set
        if ~s.req_msg_a[31]:
            s.in_a.value = s.req_msg_a
        else:
            s.in_a.value = (~s.req_msg_a) + 1

    s.l_shift_out = Wire(32)

    # A mux: new operand magnitude or the shifted A register
    s.a_mux = Mux(32, 2)
    s.connect_pairs(
        s.a_mux.sel, s.a_mux_sel,
        s.a_mux.in_[A_MUX_SEL_IN], s.in_a,
        s.a_mux.in_[A_MUX_SEL_SHIFT], s.l_shift_out,
    )

    # A register
    s.a_reg = Reg(32)
    s.connect(s.a_reg.in_, s.a_mux.out)

    # Left shifter on A, constant shift of one
    s.l_shift = LeftLogicalShifter(32)
    s.connect_pairs(
        s.l_shift.in_, s.a_reg.out,
        s.l_shift.shamt, 1,
        s.l_shift.out, s.l_shift_out,
    )

    # ---- B path ------------------------------------------------------
    s.in_b = Wire(32)

    @s.combinational
    def sign_handling_b():
        # Absolute value: negate (two's complement) when the sign bit is set
        if ~s.req_msg_b[31]:
            s.in_b.value = s.req_msg_b
        else:
            s.in_b.value = (~s.req_msg_b) + 1

    s.r_shift_out = Wire(32)

    # B mux: new operand magnitude or the shifted B register
    s.b_mux = Mux(32, 2)
    s.connect_pairs(
        s.b_mux.sel, s.b_mux_sel,
        s.b_mux.in_[B_MUX_SEL_IN], s.in_b,
        s.b_mux.in_[B_MUX_SEL_SHIFT], s.r_shift_out,
    )

    # B register
    s.b_reg = Reg(32)
    s.connect(s.b_reg.in_, s.b_mux.out)

    # Right shifter on B, constant shift of one
    s.r_shift = RightLogicalShifter(32)
    s.connect_pairs(
        s.r_shift.in_, s.b_reg.out,
        s.r_shift.shamt, 1,
        s.r_shift.out, s.r_shift_out,
    )

    # ---- Result path -------------------------------------------------
    s.add_mux_out = Wire(32)

    # Result mux: clear to zero or take the add-mux output
    s.result_mux = Mux(32, 2)
    s.connect_pairs(
        s.result_mux.sel, s.result_mux_sel,
        s.result_mux.in_[RES_MUX_SEL_ZERO], 0,
        s.result_mux.in_[RES_MUX_SEL_ADD], s.add_mux_out,
    )

    # Result register (enabled)
    s.res_reg = RegEn(32)
    s.connect_pairs(
        s.res_reg.in_, s.result_mux.out,
        s.res_reg.en, s.result_en,
    )

    # Adder: A + result
    s.adder = Adder(32)
    s.connect_pairs(
        s.adder.in0, s.a_reg.out,
        s.adder.in1, s.res_reg.out,
        s.adder.cin, 0,
    )

    # Add mux: accumulate or keep the current result
    s.add_mux = Mux(32, 2)
    s.connect_pairs(
        s.add_mux.sel, s.add_mux_sel,
        s.add_mux.in_[ADD_MUX_SEL_ADD], s.adder.out,
        s.add_mux.in_[ADD_MUX_SEL_RES], s.res_reg.out,
        s.add_mux.out, s.add_mux_out,
    )

    # ---- Output selection ----------------------------------------------
    s.res_neg = Wire(32)

    # Negated result, used when the output is negative
    @s.combinational
    def twos_compl_block():
        s.res_neg.value = (~s.res_reg.out) + 1

    s.out_mux = Mux(32, 2)
    s.connect_pairs(
        s.out_mux.sel, s.result_sign,
        s.out_mux.in_[OUT_MUX_SEL_POS], s.res_reg.out,
        s.out_mux.in_[OUT_MUX_SEL_NEG], s.res_neg,
        s.out_mux.out, s.resp_msg,
    )

    # Status signals back to the control unit
    s.connect(s.b_reg.out[0], s.b_lsb)
    s.connect(s.req_msg_a[31], s.a_msb)
    s.connect(s.req_msg_b[31], s.b_msb)
def __init__(s, mem_ifc_types=MemMsg(8, 32, 32)):
    """Build an accelerator that sorts an in-memory array by repeated passes.

    The accelerator is configured through ``xcelreq`` writes (register 1 =
    base address, register 2 = size, register 0 = go). Each pass streams
    the array through register ``a``: the smaller of ``a`` and the newly
    read element is written back to the previous slot and the larger is
    kept in ``a``. ``outer_count`` counts passes and ``inner_count``
    indexes elements within a pass.

    Args:
        mem_ifc_types: memory message type pair providing ``.req`` and
            ``.resp``. NOTE(review): the default instance is created once
            at definition time and shared by all callers.
    """
    # Interface
    s.xcelreq = InValRdyBundle(XcelReqMsg())
    s.xcelresp = OutValRdyBundle(XcelRespMsg())
    s.memreq = OutValRdyBundle(mem_ifc_types.req)
    s.memresp = InValRdyBundle(mem_ifc_types.resp)

    # Queues
    # NOTE(review): the memory queues use fixed MemReqMsg(8, 32, 32) /
    # MemRespMsg(8, 32) types rather than mem_ifc_types -- confirm.
    s.xcelreq_q = SingleElementPipelinedQueue(XcelReqMsg())
    s.connect(s.xcelreq, s.xcelreq_q.enq)

    s.memreq_q = SingleElementBypassQueue(MemReqMsg(8, 32, 32))
    s.connect(s.memreq, s.memreq_q.deq)

    s.memresp_q = SingleElementPipelinedQueue(MemRespMsg(8, 32))
    s.connect(s.memresp, s.memresp_q.enq)

    # Internal state
    s.base_addr = Reg(32)
    s.size = Reg(32)
    s.inner_count = Reg(32)
    s.outer_count = Reg(32)
    s.a = Reg(32)  # running maximum carried through the current pass

    # Line tracing
    s.prev_state = 0
    s.xcfg_trace = " "

    # Helpers to make memory read/write requests
    s.mk_rd = mem_ifc_types.req.mk_rd
    s.mk_wr = mem_ifc_types.req.mk_wr

    #=====================================================================
    # State Update
    #=====================================================================
    s.STATE_XCFG = 0
    s.STATE_FIRST0 = 1
    s.STATE_FIRST1 = 2
    s.STATE_BUBBLE0 = 3
    s.STATE_BUBBLE1 = 4
    s.STATE_LAST = 5

    s.state = Wire(8)
    s.go = Wire(1)

    @s.tick_rtl
    def block0():
        if s.reset:
            s.state.next = s.STATE_XCFG
        else:
            # Default: hold the current state
            s.state.next = s.state

            if s.state == s.STATE_XCFG:
                # Start once the go write has been acknowledged
                if s.go & s.xcelresp.val & s.xcelresp.rdy:
                    s.state.next = s.STATE_FIRST0

            elif s.state == s.STATE_FIRST0:
                if s.memreq_q.enq.rdy:
                    s.state.next = s.STATE_FIRST1

            # In the remaining states memresp_q.deq.rdy is the signal this
            # block's sibling (block1) raises when it consumes a response,
            # so the state advances exactly when a step was performed.
            elif s.state == s.STATE_FIRST1:
                if s.memreq_q.enq.rdy and s.memresp_q.deq.rdy:
                    s.state.next = s.STATE_BUBBLE0

            elif s.state == s.STATE_BUBBLE0:
                if s.memreq_q.enq.rdy and s.memresp_q.deq.rdy:
                    s.state.next = s.STATE_BUBBLE1

            elif s.state == s.STATE_BUBBLE1:
                if s.memreq_q.enq.rdy and s.memresp_q.deq.rdy:
                    # More elements left in this pass?
                    if s.inner_count.out + 1 < s.size.out:
                        s.state.next = s.STATE_BUBBLE0
                    else:
                        s.state.next = s.STATE_LAST

            elif s.state == s.STATE_LAST:
                if s.memreq_q.enq.rdy and s.memresp_q.deq.rdy:
                    # More passes left?
                    if s.outer_count.out + 1 < s.size.out:
                        s.state.next = s.STATE_FIRST1
                    else:
                        s.state.next = s.STATE_XCFG

    #=====================================================================
    # State Outputs
    #=====================================================================
    @s.combinational
    def block1():
        # Defaults: no handshakes active, counters hold their value
        s.xcelreq_q.deq.rdy.value = 0
        s.xcelresp.val.value = 0
        s.memreq_q.enq.val.value = 0
        s.memresp_q.deq.rdy.value = 0
        s.go.value = 0

        s.outer_count.in_.value = s.outer_count.out
        s.inner_count.in_.value = s.inner_count.out

        #-------------------------------------------------------------------
        # STATE: XCFG
        #-------------------------------------------------------------------
        if s.state == s.STATE_XCFG:
            # Pass the request handshake straight through to the response
            s.xcelreq_q.deq.rdy.value = s.xcelresp.rdy
            s.xcelresp.val.value = s.xcelreq_q.deq.val

            if s.xcelreq_q.deq.val:
                if s.xcelreq_q.deq.msg.type_ == XcelReqMsg.TYPE_WRITE:
                    if s.xcelreq_q.deq.msg.raddr == 0:
                        # go: restart the pass counter and raise go
                        s.outer_count.in_.value = 0
                        s.go.value = 1
                    elif s.xcelreq_q.deq.msg.raddr == 1:
                        s.base_addr.in_.value = s.xcelreq_q.deq.msg.data
                    elif s.xcelreq_q.deq.msg.raddr == 2:
                        s.size.in_.value = s.xcelreq_q.deq.msg.data

                    # Send xcel response message
                    s.xcelresp.msg.type_.value = XcelRespMsg.TYPE_WRITE

                else:
                    # Send xcel response message, obviously you only want to
                    # send the response message when accelerator is done
                    s.xcelresp.msg.type_.value = XcelRespMsg.TYPE_READ
                    s.xcelresp.msg.data.value = 1

        #-------------------------------------------------------------------
        # STATE: FIRST0
        #-------------------------------------------------------------------
        # Send the first memory read request for the very first
        # element in the array.
        elif s.state == s.STATE_FIRST0:
            if s.memreq_q.enq.rdy:
                s.memreq_q.enq.val.value = 1
                s.memreq_q.enq.msg.type_.value = MemReqMsg.TYPE_READ
                s.memreq_q.enq.msg.opaque.value = 0
                s.memreq_q.enq.msg.addr.value = s.base_addr.out + 4 * s.inner_count.out
                s.memreq_q.enq.msg.len.value = 0
                s.inner_count.in_.value = 1

        #-------------------------------------------------------------------
        # STATE: FIRST1
        #-------------------------------------------------------------------
        # Wait for the memory response for the first element in the array,
        # and once it arrives store this element in a, and send the memory
        # read request for the second element.
        elif s.state == s.STATE_FIRST1:
            if s.memreq_q.enq.rdy and s.memresp_q.deq.val:
                s.memresp_q.deq.rdy.value = 1
                s.a.in_.value = s.memresp_q.deq.msg.data

                s.memreq_q.enq.val.value = 1
                s.memreq_q.enq.msg.type_.value = MemReqMsg.TYPE_READ
                s.memreq_q.enq.msg.opaque.value = 0
                s.memreq_q.enq.msg.addr.value = s.base_addr.out + 4 * s.inner_count.out
                s.memreq_q.enq.msg.len.value = 0

        #-------------------------------------------------------------------
        # STATE: BUBBLE0
        #-------------------------------------------------------------------
        # Wait for the memory read response to get the next element,
        # compare the new value to the previous max value, update a with
        # the new max value, and send a memory request to store the new min
        # value. Notice how we decrement the write address by four since we
        # want to store the new min value to the _previous_ element.
        elif s.state == s.STATE_BUBBLE0:
            if s.memreq_q.enq.rdy and s.memresp_q.deq.val:
                s.memresp_q.deq.rdy.value = 1
                # NOTE(review): this compares/stores the whole response
                # message (deq.msg), whereas FIRST1 uses deq.msg.data --
                # confirm whether .data was intended here.
                if s.a.out > s.memresp_q.deq.msg:
                    s.a.in_.value = s.a.out
                    s.memreq_q.enq.msg.data.value = s.memresp_q.deq.msg
                else:
                    s.a.in_.value = s.memresp_q.deq.msg
                    s.memreq_q.enq.msg.data.value = s.a.out

                s.memreq_q.enq.val.value = 1
                s.memreq_q.enq.msg.type_.value = MemReqMsg.TYPE_WRITE
                s.memreq_q.enq.msg.opaque.value = 0
                s.memreq_q.enq.msg.addr.value = (s.base_addr.out + 4 *
                                                 (s.inner_count.out - 1))
                s.memreq_q.enq.msg.len.value = 0

        #-------------------------------------------------------------------
        # STATE: BUBBLE1
        #-------------------------------------------------------------------
        # Wait for the memory write response, and then check to see if we
        # have reached the end of the array. If we have not reached the end
        # of the array, then make a new memory read request for the next
        # element; if we have reached the end of the array, then make a
        # final write request (with value from a) to update the final
        # element in the array.
        elif s.state == s.STATE_BUBBLE1:
            if s.memreq_q.enq.rdy and s.memresp_q.deq.val:
                s.memresp_q.deq.rdy.value = 1
                s.inner_count.in_.value = s.inner_count.out + 1

                if s.inner_count.out + 1 < s.size.out:
                    s.memreq_q.enq.val.value = 1
                    s.memreq_q.enq.msg.type_.value = MemReqMsg.TYPE_READ
                    s.memreq_q.enq.msg.opaque.value = 0
                    s.memreq_q.enq.msg.addr.value = s.base_addr.out + 4 * (
                        s.inner_count.out + 1)
                    s.memreq_q.enq.msg.len.value = 0
                else:
                    s.memreq_q.enq.val.value = 1
                    s.memreq_q.enq.msg.type_.value = MemReqMsg.TYPE_WRITE
                    s.memreq_q.enq.msg.opaque.value = 0
                    s.memreq_q.enq.msg.addr.value = (s.base_addr.out + 4 *
                                                     (s.inner_count.out))
                    s.memreq_q.enq.msg.len.value = 0
                    s.memreq_q.enq.msg.data.value = s.a.out

        #-------------------------------------------------------------------
        # STATE: LAST
        #-------------------------------------------------------------------
        # Wait for the last response, and then check to see if we need to
        # go through the array again. If we do need to go through array
        # again, then make a new memory read request for the very first
        # element in the array; if we do not need to go through the array
        # again, then we are all done and we can go back to accelerator
        # configuration.
        elif s.state == s.STATE_LAST:
            if s.memreq_q.enq.rdy and s.memresp_q.deq.val:
                s.memresp_q.deq.rdy.value = 1
                s.outer_count.in_.value = s.outer_count.out + 1

                if s.outer_count.out + 1 < s.size.out:
                    s.memreq_q.enq.val.value = 1
                    s.memreq_q.enq.msg.type_.value = MemReqMsg.TYPE_READ
                    s.memreq_q.enq.msg.opaque.value = 0
                    s.memreq_q.enq.msg.addr.value = s.base_addr.out
                    s.memreq_q.enq.msg.len.value = 0
                    s.inner_count.in_.value = 1
def elaborate_logic(s):
    """Build a two-stage pipelined sorting network for four 16-bit inputs.

    Stage B compares the input pairs (0, 1) and (2, 3). Stage C compares
    the two minima and the two maxima, then orders the two middle values.
    ``s.out[0..3]`` receive the values in ascending order.
    """
    #---------------------------------------------------------------------
    # Stage A->B pipeline registers
    #---------------------------------------------------------------------
    s.reg_AB = [Reg(16) for _ in range(4)]
    for idx in range(4):
        s.connect(s.reg_AB[idx].in_, s.in_[idx])

    #---------------------------------------------------------------------
    # Stage B combinational logic
    #---------------------------------------------------------------------
    s.cmp_B0 = MinMax()
    s.connect_dict({
        s.cmp_B0.in0: s.reg_AB[0].out,
        s.cmp_B0.in1: s.reg_AB[1].out,
    })

    s.cmp_B1 = MinMax()
    s.connect_dict({
        s.cmp_B1.in0: s.reg_AB[2].out,
        s.cmp_B1.in1: s.reg_AB[3].out,
    })

    #---------------------------------------------------------------------
    # Stage B->C pipeline registers
    #---------------------------------------------------------------------
    s.reg_BC = [Reg(16) for _ in range(4)]
    stage_b_results = [
        s.cmp_B0.min,
        s.cmp_B0.max,
        s.cmp_B1.min,
        s.cmp_B1.max,
    ]
    for idx, result in enumerate(stage_b_results):
        s.connect(s.reg_BC[idx].in_, result)

    #---------------------------------------------------------------------
    # Stage C combinational logic
    #---------------------------------------------------------------------
    # Smallest overall is the min of the two pair minima
    s.cmp_C0 = MinMax()
    s.connect_dict({
        s.cmp_C0.in0: s.reg_BC[0].out,
        s.cmp_C0.in1: s.reg_BC[2].out,
    })

    # Largest overall is the max of the two pair maxima
    s.cmp_C1 = MinMax()
    s.connect_dict({
        s.cmp_C1.in0: s.reg_BC[1].out,
        s.cmp_C1.in1: s.reg_BC[3].out,
    })

    # Order the two remaining middle values
    s.cmp_C2 = MinMax()
    s.connect_dict({
        s.cmp_C2.in0: s.cmp_C0.max,
        s.cmp_C2.in1: s.cmp_C1.min,
    })

    # Connect to output ports
    sorted_results = [
        s.cmp_C0.min,
        s.cmp_C2.min,
        s.cmp_C2.max,
        s.cmp_C1.max,
    ]
    for idx, result in enumerate(sorted_results):
        s.connect(result, s.out[idx])
def __init__(s, nbits):
    """Build the divider datapath that produces two quotient bits per step.

    ``req_msg`` packs two ``nbits``-wide operands: the low half seeds the
    remainder register and the high half seeds the divisor register. Each
    step performs two chained subtractions (divisor, then divisor >> 1),
    shifts the divisor right by two in total and shifts two new bits into
    the quotient. ``resp_msg`` returns the remainder in its low half and
    the quotient in its high half.

    Args:
        nbits: width of each operand; internal remainder/divisor paths
            are ``2 * nbits`` wide.
    """
    nbitsx2 = nbits * 2

    dtype = mk_bits(nbits)
    dtypex2 = mk_bits(nbitsx2)

    s.req_msg = InVPort(dtypex2)
    s.resp_msg = OutVPort(dtypex2)

    # Status signals (sign bits of the two subtraction results)
    s.sub_negative1 = OutVPort(Bits1)
    s.sub_negative2 = OutVPort(Bits1)

    # Control signals
    s.quotient_mux_sel = InVPort(Bits1)
    s.quotient_reg_en = InVPort(Bits1)

    s.remainder_mux_sel = InVPort(Bits2)
    s.remainder_reg_en = InVPort(Bits1)

    s.divisor_mux_sel = InVPort(Bits1)

    # Dpath components
    s.remainder_mux = Mux(dtypex2, 3)(sel=s.remainder_mux_sel)

    @s.update
    def up_remainder_mux_in0():
        # Zero-extend the low half of req_msg into the remainder path
        s.remainder_mux.in_[R_MUX_SEL_IN] = dtypex2()
        s.remainder_mux.in_[R_MUX_SEL_IN][0:nbits] = s.req_msg[0:nbits]

    s.remainder_reg = RegEn(dtypex2)(
        in_=s.remainder_mux.out,
        en=s.remainder_reg_en,
    )
    # lower bits of resp_msg save the remainder
    s.connect(s.resp_msg[0:nbits], s.remainder_reg.out[0:nbits])

    s.divisor_mux = Mux(dtypex2, 2)(sel=s.divisor_mux_sel)

    @s.update
    def up_divisor_mux_in0():
        # Place the high half of req_msg at bit offset nbits - 1 of the
        # double-width divisor, leaving the remaining bits zero
        s.divisor_mux.in_[D_MUX_SEL_IN] = dtypex2()
        s.divisor_mux.in_[D_MUX_SEL_IN][nbits - 1:nbitsx2 - 1] = s.req_msg[nbits:nbitsx2]

    s.divisor_reg = Reg(dtypex2)(in_=s.divisor_mux.out)

    s.quotient_mux = Mux(dtype, 2)(sel=s.quotient_mux_sel)
    s.connect(s.quotient_mux.in_[Q_MUX_SEL_0], 0)

    s.quotient_reg = RegEn(dtype)(
        in_=s.quotient_mux.out,
        en=s.quotient_reg_en,
        # higher bits of resp_msg save the quotient
        out=s.resp_msg[nbits:nbitsx2],
    )

    # shamt should be 2 bits!
    s.quotient_lsh = LShifter(dtype, 2)(in_=s.quotient_reg.out)
    s.connect(s.quotient_lsh.shamt, 2)

    # inc collects the two "subtraction went negative" flags
    # (bit 1 = first subtraction, bit 0 = second subtraction)
    s.inc = Wire(Bits2)
    s.connect(s.sub_negative1, s.inc[1])
    s.connect(s.sub_negative2, s.inc[0])

    @s.update
    def up_quotient_inc():
        # New quotient = (quotient << 2) plus the inverted negative flags,
        # i.e. a 1 is shifted in for every subtraction that did not go
        # negative
        s.quotient_mux.in_[Q_MUX_SEL_LSH] = s.quotient_lsh.out + ~s.inc

    # stage 1/2
    s.sub1 = Subtractor(dtypex2)(
        in0=s.remainder_reg.out,
        in1=s.divisor_reg.out,
        out=s.remainder_mux.in_[R_MUX_SEL_SUB1],
    )
    # Top bit of the difference is the sign of the first subtraction
    s.connect(s.sub_negative1, s.sub1.out[nbitsx2 - 1])

    # Keep the old remainder if the first subtraction went negative
    s.remainder_mid_mux = Mux(dtypex2, 2)(
        in_={
            0: s.sub1.out,
            1: s.remainder_reg.out,
        },
        sel=s.sub_negative1,
    )

    s.divisor_rsh1 = RShifter(dtypex2, 1)(in_=s.divisor_reg.out, )
    s.connect(s.divisor_rsh1.shamt, 1)

    # stage 2/2
    s.sub2 = Subtractor(dtypex2)(
        in0=s.remainder_mid_mux.out,
        in1=s.divisor_rsh1.out,
        out=s.remainder_mux.in_[R_MUX_SEL_SUB2],
    )
    # Top bit of the difference is the sign of the second subtraction
    s.connect(s.sub_negative2, s.sub2.out[nbitsx2 - 1])

    # Divisor shifted right a second time feeds back into the divisor mux
    s.divisor_rsh2 = RShifter(dtypex2, 1)(
        in_=s.divisor_rsh1.out,
        out=s.divisor_mux.in_[D_MUX_SEL_RSH],
    )
    s.connect(s.divisor_rsh2.shamt, 1)
def __init__(s, nbits):
    """Build the divider control unit.

    The state register doubles as a down-counter: IDLE is 0, DONE is 1
    and a calculation starts at ``1 + nbits // 2`` and decrements by one
    each cycle until it reaches DONE.

    Args:
        nbits: operand width; determines the state width and the number
            of calculation cycles.
    """
    # Request/response handshake
    s.req_val = InVPort(Bits1)
    s.req_rdy = OutVPort(Bits1)
    s.resp_val = OutVPort(Bits1)
    s.resp_rdy = InVPort(Bits1)

    # Status signals
    s.sub_negative1 = InVPort(Bits1)
    s.sub_negative2 = InVPort(Bits1)

    # Control signals
    s.quotient_mux_sel = OutVPort(Bits1)
    s.quotient_reg_en = OutVPort(Bits1)

    s.remainder_mux_sel = OutVPort(Bits2)
    s.remainder_reg_en = OutVPort(Bits1)

    s.divisor_mux_sel = OutVPort(Bits1)

    state_dtype = mk_bits(1 + clog2(nbits))
    s.state = Reg(state_dtype)

    s.STATE_IDLE = state_dtype(0)
    s.STATE_DONE = state_dtype(1)
    # Floor division keeps this an integer under both Python 2 and 3.
    s.STATE_CALC = state_dtype(1 + nbits // 2)

    @s.update
    def state_transitions():
        curr_state = s.state.out

        # Default: hold the current state, so that IDLE without a request
        # and DONE without an accepted response never reuse a stale value.
        s.state.in_ = curr_state

        if curr_state == s.STATE_IDLE:
            if s.req_val and s.req_rdy:
                s.state.in_ = s.STATE_CALC

        elif curr_state == s.STATE_DONE:
            if s.resp_val and s.resp_rdy:
                s.state.in_ = s.STATE_IDLE

        else:
            # Calculating: count down towards DONE
            s.state.in_ = curr_state - 1

    @s.update
    def state_outputs():
        curr_state = s.state.out

        if curr_state == s.STATE_IDLE:
            s.req_rdy = Bits1(1)
            s.resp_val = Bits1(0)

            # Load operands and clear the quotient
            s.remainder_mux_sel = Bits2(R_MUX_SEL_IN)
            s.remainder_reg_en = Bits1(1)

            # quotient_mux_sel is a one-bit port, so drive it with Bits1
            s.quotient_mux_sel = Bits1(Q_MUX_SEL_0)
            s.quotient_reg_en = Bits1(1)

            s.divisor_mux_sel = Bits1(D_MUX_SEL_IN)

        elif curr_state == s.STATE_DONE:
            s.req_rdy = Bits1(0)
            s.resp_val = Bits1(1)

            # Freeze quotient and remainder while the response is pending
            s.quotient_mux_sel = Bits1(Q_MUX_SEL_0)
            s.quotient_reg_en = Bits1(0)

            s.remainder_mux_sel = Bits2(R_MUX_SEL_IN)
            s.remainder_reg_en = Bits1(0)

            s.divisor_mux_sel = Bits1(D_MUX_SEL_IN)

        else:  # calculating
            s.req_rdy = Bits1(0)
            s.resp_val = Bits1(0)

            # Update the remainder unless both subtractions went negative
            s.remainder_reg_en = ~(s.sub_negative1 & s.sub_negative2)
            # Take the second difference only if it is non-negative
            if s.sub_negative2:
                s.remainder_mux_sel = Bits2(R_MUX_SEL_SUB1)
            else:
                s.remainder_mux_sel = Bits2(R_MUX_SEL_SUB2)

            s.quotient_reg_en = Bits1(1)
            s.quotient_mux_sel = Bits1(Q_MUX_SEL_LSH)

            s.divisor_mux_sel = Bits1(D_MUX_SEL_RSH)