def __init__(self, x_len):
    """Decode every RISC-V immediate format from ``io.inst`` and select one.

    ``io.sel`` chooses which immediate drives ``io.O``. Any selector value
    not in the table yields the I-type immediate with bit 0 cleared.

    Args:
        x_len: datapath width in bits; signed immediates are sign-extended
            to this width.
    """
    super().__init__(x_len)
    inst = self.io.inst

    def signed(field):
        # Reinterpret ``field`` as signed and sign-extend it to x_len bits.
        return m.sext_to(m.sint(field), x_len)

    # m.concat places its first argument in the least-significant bits, so
    # every field list below reads LSB -> MSB.
    i_imm = signed(inst[20:32])
    s_imm = signed(m.concat(inst[7:12], inst[25:32]))
    b_imm = signed(
        m.concat(m.bits(0, 1), inst[8:12], inst[25:31], inst[7], inst[31]))
    u_imm = m.concat(m.bits(0, 12), inst[12:32])
    j_imm = signed(
        m.concat(m.bits(0, 1), inst[21:25], inst[25:31], inst[20],
                 inst[12:20], inst[31]))
    # CSR-immediate: the 5-bit rs1 field, zero-extended.
    z_imm = m.sint(m.zext_to(inst[15:20], x_len))

    immediates = {
        IMM_I: i_imm,
        IMM_S: s_imm,
        IMM_B: b_imm,
        IMM_U: u_imm,
        IMM_J: j_imm,
        IMM_Z: z_imm,
    }
    self.io.O @= m.uint(m.dict_lookup(immediates, self.io.sel, i_imm & -2))
def bytes_to_x_size(self, bytes_):
    """Encode a byte count as a 3-bit log2 transfer size.

    The counts 1, 2, 4, ..., 128 map to 0..7 respectively. Any other value
    of ``bytes_`` selects the default ``0b111``.
    """
    # Keys are the powers of two 2**0 .. 2**7, in ascending order.
    size_table = {1 << log2: m.bits(log2, 3) for log2 in range(8)}
    return m.dict_lookup(size_table, bytes_, m.bits(0b111, 3))
def __init__(self, x_len):
    """Decode ``io.inst`` into the fetch/execute/write-back control outputs.

    The instruction is looked up in ``inst_map``, with ``default`` as the
    fallback row. Each output port is then driven from a fixed position of
    the selected row.

    Args:
        x_len: datapath width in bits, forwarded to ``make_ControlIO``.
    """
    self.io = io = make_ControlIO(x_len)
    decoded = m.dict_lookup(inst_map, io.inst, default=default)

    # (output port, position in the decoded row), listed in wiring order.
    port_fields = (
        # Fetch
        ("pc_sel", 0),
        ("inst_kill", 6),
        # Execute
        ("A_sel", 1),
        ("B_sel", 2),
        ("imm_sel", 3),
        ("alu_op", 4),
        ("br_type", 5),
        ("st_type", 7),
        # Write back
        ("ld_type", 8),
        ("wb_sel", 9),
        ("wb_en", 10),
        ("csr_cmd", 11),
        ("illegal", 12),
    )
    for port_name, position in port_fields:
        port = getattr(io, port_name)
        port @= decoded[position]
class CSR_DUT(m.Circuit):
    """Self-checking harness comparing CSRGen against a reference model.

    A counter indexes the ``insts``/``pc``/``addr``/``data`` sequences, which
    drive both a ``Control`` decoder and a ``CSRGen`` instance. A parallel
    model of the CSR state is kept in ``regs``. The DUT's outputs and the
    model's expected values are exposed side by side on ``io``.

    NOTE(review): ``x_len``, ``n``, ``insts``, ``pc``, ``addr``, ``data``,
    ``nop``, ``csr``, ``rs1``, ``incr_when`` and ``update_when`` are free names
    here. They are presumably defined in the enclosing test function — verify.
    """
    io = m.IO(done=m.Out(m.Bit), check=m.Out(m.Bit),
              rdata=m.Out(m.UInt[x_len]),
              expected_rdata=m.Out(m.UInt[x_len]),
              epc=m.Out(m.UInt[x_len]),
              expected_epc=m.Out(m.UInt[x_len]),
              evec=m.Out(m.UInt[x_len]),
              expected_evec=m.Out(m.UInt[x_len]),
              expt=m.Out(m.Bit),
              expected_expt=m.Out(m.Bit))
    io += m.ClockIO(has_reset=True)

    # Reference-model register per CSR address, with reset values for the
    # CSRs that have non-zero ones.
    # NOTE(review): model registers are hard-coded to 32 bits (BV[32], and
    # zext_to(..., 32) / 1 << 31 below) while ports use x_len. This harness
    # presumably only runs with x_len == 32.
    regs = {}
    for reg in CSR.regs:
        if reg == CSR.mcpuid:
            init = (1 << (ord('I') - ord('A')) |
                    1 << (ord('U') - ord('A')))
        elif reg == CSR.mstatus:
            init = (CSR.PRV_M.ext(30) << 4) | (CSR.PRV_M.ext(30) << 1)
        elif reg == CSR.mtvec:
            init = Const.PC_EVEC
        else:
            init = 0
        regs[reg] = m.Register(init=BV[32](init),
                               reset_type=m.Reset)()

    csr = CSRGen(x_len)()
    ctrl = Control.Control(x_len)()

    # counter.O selects the current test vector from each stimulus list.
    counter = CounterModM(n, n.bit_length())
    inst = m.mux(insts, counter.O)
    ctrl.inst @= inst
    csr.inst @= inst
    csr_cmd = ctrl.csr_cmd
    csr.cmd @= csr_cmd
    csr.illegal @= ctrl.illegal
    csr.st_type @= ctrl.st_type
    csr.ld_type @= ctrl.ld_type
    csr.pc_check @= ctrl.pc_sel == Control.PC_ALU
    csr.pc @= m.mux(pc, counter.O)
    csr.addr @= m.mux(addr, counter.O)
    csr.I @= m.mux(data, counter.O)
    csr.stall @= False
    csr.host.fromhost.valid @= False
    csr.host.fromhost.data @= 0

    # values known statically
    # NOTE: names bound in a class body are not visible inside a
    # comprehension (only its first iterable is evaluated in class scope).
    # So ``csr(inst)`` below calls the enclosing-scope ``csr`` helper, not
    # the CSRGen instance assigned to ``csr`` above.
    _csr_addr = [csr(inst) for inst in insts]
    _rs1_addr = [rs1(inst) for inst in insts]
    # Read-only when address bits 11 and 10 are both set, or mtvec/mtdeleg.
    _csr_ro = [((((x >> 11) & 0x1) > 0x0) & (((x >> 10) & 0x1) > 0x0)) |
               (x == CSR.mtvec) | (x == CSR.mtdeleg) for x in _csr_addr]
    _csr_valid = [x in CSR.regs for x in _csr_addr]
    # should be <= prv in runtime
    _prv_level = [(x >> 8) & 0x3 for x in _csr_addr]
    # should consider prv in runtime
    _is_ecall = [((x & 0x1) == 0x0) & (((x >> 8) & 0x1) == 0x0)
                 for x in _csr_addr]
    _is_ebreak = [((x & 0x1) > 0x0) & (((x >> 8) & 0x1) == 0x0)
                  for x in _csr_addr]
    _is_eret = [((x & 0x1) == 0x0) & (((x >> 8) & 0x1) > 0x0)
                for x in _csr_addr]
    # should consider pc_check in runtime
    _iaddr_invalid = [((x >> 1) & 0x1) > 0 for x in addr]
    # should consider ld_type & st_type
    _waddr_invalid = [(((x >> 1) & 0x1) > 0) | ((x & 0x1) > 0) for x in addr]
    _haddr_invalid = [(x & 0x1) > 0 for x in addr]

    # values known at runtime
    csr_addr = m.mux(_csr_addr, counter.O)
    rs1_addr = m.mux(_rs1_addr, counter.O)
    csr_ro = m.mux(_csr_ro, counter.O)
    csr_valid = m.mux(_csr_valid, counter.O)
    # Write when cmd == W, or cmd bit 1 is set and rs1 is non-zero.
    wen = (csr_cmd == CSR.W) | (csr_cmd[1] & (rs1_addr != 0))
    # Model mstatus fields: IE bit 0, PRV bits 1-2, IE1 bit 3, PRV1 bits 4-5.
    prv1 = (regs[CSR.mstatus].O >> 4) & 0x3
    ie1 = (regs[CSR.mstatus].O >> 3) & 0x1
    prv = (regs[CSR.mstatus].O >> 1) & 0x3
    ie = regs[CSR.mstatus].O & 0x1
    prv_inst = csr_cmd == CSR.P
    prv_valid = (m.uint(m.zext_to(m.mux(_prv_level, counter.O), 32)) <=
                 m.uint(prv))
    iaddr_invalid = m.mux(_iaddr_invalid, counter.O) & csr.pc_check.value()
    # ``&`` binds tighter than ``|``: (half & LH/LHU) | (word & LW).
    laddr_invalid = (m.mux(_haddr_invalid, counter.O) &
                     ((ctrl.ld_type == Control.LD_LH) |
                      (ctrl.ld_type == Control.LD_LHU)) |
                     m.mux(_waddr_invalid, counter.O) &
                     (ctrl.ld_type == Control.LD_LW))
    saddr_invalid = (m.mux(_haddr_invalid, counter.O) &
                     (ctrl.st_type == Control.ST_SH) |
                     m.mux(_waddr_invalid, counter.O) &
                     (ctrl.st_type == Control.ST_SW))
    is_ecall = prv_inst & m.mux(_is_ecall, counter.O)
    is_ebreak = prv_inst & m.mux(_is_ebreak, counter.O)
    is_eret = prv_inst & m.mux(_is_eret, counter.O)
    exception = (ctrl.illegal | iaddr_invalid | laddr_invalid |
                 saddr_invalid |
                 (((csr_cmd & 0x3) > 0) & (~csr_valid | ~prv_valid)) |
                 (csr_ro & wen) | (prv_inst & ~prv_valid) | is_ecall |
                 is_ebreak)
    # An instruction retires if it is not a NOP and either raised no
    # exception or was an ecall/ebreak.
    instret = (inst != nop) & (~exception | is_ecall | is_ebreak)
    rdata = m.dict_lookup({key: value.O for key, value in regs.items()},
                          csr_addr)
    wdata = m.dict_lookup(
        {
            CSR.W: csr.I.value(),
            CSR.S: (csr.I.value() | rdata),
            CSR.C: (~csr.I.value() & rdata)
        }, csr_cmd)

    # compute state
    # NOTE(review): presumably update_when(reg, value, cond) re-drives
    # reg.I as a mux between its current driver and ``value``, so later
    # calls take priority over earlier ones. Confirm against the helper.
    regs[CSR.time].I @= regs[CSR.time].O + 1
    regs[CSR.timew].I @= regs[CSR.timew].O + 1
    regs[CSR.mtime].I @= regs[CSR.mtime].O + 1
    regs[CSR.cycle].I @= regs[CSR.cycle].O + 1
    regs[CSR.cyclew].I @= regs[CSR.cyclew].O + 1
    time_max = regs[CSR.time].O.reduce_and()
    # TODO: mtime has same default value as this case (from chisel code)
    # https://github.com/ucb-bar/riscv-mini/blob/release/src/test/scala/CSRTests.scala#L140
    # mtime_reg = regs[CSR.mtime]
    # mtime_reg.I @= m.mux([mtime_reg.O, mtime_reg.O + 1], time_max)
    # High halves increment when the low half is all ones (about to wrap).
    incr_when(regs[CSR.timeh], time_max)
    incr_when(regs[CSR.timehw], time_max)
    cycle_max = regs[CSR.cycle].O.reduce_and()
    incr_when(regs[CSR.cycleh], cycle_max)
    incr_when(regs[CSR.cyclehw], cycle_max)
    incr_when(regs[CSR.instret], instret)
    incr_when(regs[CSR.instretw], instret)
    instret_max = regs[CSR.instret].O.reduce_and()
    incr_when(regs[CSR.instreth], instret & instret_max)
    incr_when(regs[CSR.instrethw], instret & instret_max)

    # CSR writes happen only for non-excepting, non-eret instructions.
    cond = ~exception & ~is_eret & wen
    # Assuming these are mutually exclusive, so we don't need chained
    # elsewhen
    update_when(regs[CSR.mstatus], m.zext_to(wdata[0:6], 32),
                cond & (csr_addr == CSR.mstatus))
    update_when(regs[CSR.mip],
                (m.bits(wdata[7], 32) << 7) | (m.bits(wdata[3], 32) << 3),
                cond & (csr_addr == CSR.mip))
    update_when(regs[CSR.mie],
                (m.bits(wdata[7], 32) << 7) | (m.bits(wdata[3], 32) << 3),
                cond & (csr_addr == CSR.mie))
    update_when(regs[CSR.mepc], (wdata >> 2) << 2,
                cond & (csr_addr == CSR.mepc))
    update_when(regs[CSR.mcause], wdata & (1 << 31 | 0xf),
                cond & (csr_addr == CSR.mcause))
    update_when(regs[CSR.time], wdata,
                cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
    update_when(regs[CSR.timew], wdata,
                cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
    update_when(regs[CSR.mtime], wdata,
                cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
    update_when(
        regs[CSR.timeh], wdata,
        cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
    update_when(
        regs[CSR.timehw], wdata,
        cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
    update_when(
        regs[CSR.mtimeh], wdata,
        cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
    update_when(regs[CSR.cycle], wdata, cond & (csr_addr == CSR.cyclew))
    update_when(regs[CSR.cyclew], wdata, cond & (csr_addr == CSR.cyclew))
    update_when(regs[CSR.cycleh], wdata, cond & (csr_addr == CSR.cyclehw))
    update_when(regs[CSR.cyclehw], wdata, cond & (csr_addr == CSR.cyclehw))
    update_when(regs[CSR.instret], wdata, cond & (csr_addr == CSR.instretw))
    update_when(regs[CSR.instretw], wdata, cond & (csr_addr == CSR.instretw))
    update_when(regs[CSR.instreth], wdata, cond & (csr_addr == CSR.instrethw))
    update_when(regs[CSR.instrethw], wdata,
                cond & (csr_addr == CSR.instrethw))
    update_when(regs[CSR.mtimecmp], wdata, cond & (csr_addr == CSR.mtimecmp))
    update_when(regs[CSR.mscratch], wdata, cond & (csr_addr == CSR.mscratch))
    update_when(regs[CSR.mbadaddr], wdata, cond & (csr_addr == CSR.mbadaddr))
    update_when(regs[CSR.mtohost], wdata, cond & (csr_addr == CSR.mtohost))
    update_when(regs[CSR.mfromhost], wdata,
                cond & (csr_addr == CSR.mfromhost))
    # eret
    # Pop the privilege stack: PRV <- PRV1, IE <- IE1, PRV1 <- U, IE1 <- 1.
    update_when(regs[CSR.mstatus],
                (CSR.PRV_U.zext(30) << 4) | (1 << 3) | (prv1 << 1) | ie1,
                ~exception & is_eret)
    # TODO: exception logic comes after since it has priority
    Cause = make_Cause(x_len)
    # Cause priority (outermost mux wins): inst-addr misaligned, load
    # misaligned, store misaligned, ecall, ebreak, otherwise illegal inst.
    mcause = m.mux([
        m.mux([
            m.mux([
                m.mux([
                    m.mux([Cause.IllegalInst, Cause.Breakpoint], is_ebreak),
                    Cause.Ecall + prv,
                ], is_ecall),
                Cause.StoreAddrMisaligned,
            ], saddr_invalid),
            Cause.LoadAddrMisaligned,
        ], laddr_invalid),
        Cause.InstAddrMisaligned,
    ], iaddr_invalid)
    update_when(regs[CSR.mcause], mcause, exception)
    update_when(regs[CSR.mepc], (csr.pc.value() >> 2) << 2, exception)
    # Push the privilege stack: PRV1 <- PRV, IE1 <- IE, PRV <- M, IE <- 0.
    update_when(regs[CSR.mstatus],
                (prv << 4) | (ie << 3) | (CSR.PRV_M.zext(30) << 1),
                exception)
    update_when(
        regs[CSR.mbadaddr], csr.addr.value(),
        exception & (iaddr_invalid | laddr_invalid | saddr_invalid))

    epc = regs[CSR.mepc].O
    evec = regs[CSR.mtvec].O + (prv << 6)

    # Simulation trace of inputs, model state and DUT-vs-model outputs.
    m.display("*** Counter: %d ***", counter.O)
    m.display("[in] inst: 0x%x, pc: 0x%x, addr: 0x%x, in: 0x%x",
              csr.inst, csr.pc, csr.addr, csr.I)
    m.display(
        " cmd: 0x%x, st_type: 0x%x, ld_type: 0x%x, illegal: %d, "
        "pc_check: %d", csr.cmd, csr.st_type, csr.ld_type, csr.illegal,
        csr.pc_check)
    m.display("[state] csr addr: %x", csr_addr)
    for reg_addr, reg in regs.items():
        m.display(f" {hex(int(reg_addr))} -> 0x%x", reg.O)
    m.display(
        "[out] read: 0x%x =? 0x%x, epc: 0x%x =? 0x%x, evec: 0x%x ?= "
        "0x%x, expt: %d ?= %d", csr.O, rdata, csr.epc, epc, csr.evec,
        evec, csr.expt, exception)

    # ``check`` is high whenever the counter is non-zero. Presumably the
    # first vector is skipped because the model state lags by one cycle
    # there — TODO confirm.
    io.check @= counter.O.reduce_or()
    io.rdata @= csr.O
    io.expected_rdata @= rdata
    io.epc @= csr.epc
    io.expected_epc @= epc
    io.evec @= csr.evec
    io.expected_evec @= evec
    io.expt @= csr.expt
    io.expected_expt @= exception
    # io.failed @= counter.O.reduce_or() & (
    #     (csr.O != rdata) |
    #     (csr.epc != epc) |
    #     (csr.evec != evec) |
    #     (csr.expt != exception)
    # )
    io.done @= counter.COUT

    # Any model register not driven above holds its value.
    for key, reg in regs.items():
        if not reg.I.driven():
            reg.I @= reg.O
def __init__(self, x_len):
    """Build the machine-mode CSR file and its exception logic.

    The CSR address (``inst[20:32]``) and rs1 (``inst[15:20]``) are decoded
    from ``io.inst``. The selected CSR is driven on ``io.O``. ``io.expt`` is
    raised for illegal, misaligned, privilege, ecall and ebreak conditions.
    ``io.evec``/``io.epc`` expose the trap vector and the saved PC. The
    ``logic()`` block below computes every register's next value.

    Args:
        x_len: register width in bits.
    """
    Cause = make_Cause(x_len)
    self.io = io = m.IO(
        stall=m.In(m.Bit),
        cmd=m.In(m.UInt[3]),
        I=m.In(m.UInt[x_len]),
        O=m.Out(m.UInt[x_len]),
        # Exception
        pc=m.In(m.UInt[x_len]),
        addr=m.In(m.UInt[x_len]),
        inst=m.In(m.UInt[x_len]),
        illegal=m.In(m.Bit),
        st_type=m.In(m.UInt[2]),
        ld_type=m.In(m.UInt[3]),
        pc_check=m.In(m.Bit),
        expt=m.Out(m.Bit),
        evec=m.Out(m.UInt[x_len]),
        epc=m.Out(
            m.UInt[x_len])) + HostIO(x_len) + m.ClockIO(has_reset=True)

    csr_addr = io.inst[20:32]
    rs1_addr = io.inst[15:20]

    # user counters
    time = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    timeh = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    cycle = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    cycleh = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    instret = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    instreth = m.Register(m.UInt[x_len], reset_type=m.Reset)()

    # m.concat places its first argument in the least-significant bits, so
    # the fields of every concat below are listed LSB first.
    mcpuid = m.concat(
        BV[26](
            1 << (ord('I') - ord('A')) |  # Base ISA
            1 << (ord('U') - ord('A'))),  # User Mode
        BV[x_len - 28](0),
        BV[2](0),  # RV32I
    )
    mimpid = BV[x_len](0)
    mhartid = BV[x_len](0)

    # interrupt enable stack
    PRV = m.Register(m.UInt[len(CSR.PRV_M)], init=CSR.PRV_M,
                     reset_type=m.Reset)()
    PRV1 = m.Register(m.UInt[len(CSR.PRV_M)], init=CSR.PRV_M,
                      reset_type=m.Reset)()
    PRV2 = BV[2](0)
    PRV3 = BV[2](0)
    IE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    IE1 = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    IE2 = False
    IE3 = False
    # virtualization management field
    VM = BV[5](0)
    # memory privilege
    MPRV = False
    # Extension context status
    XS = BV[2](0)
    FS = BV[2](0)
    SD = BV[1](0)
    # Resulting layout: IE bit 0, PRV bits 1-2, IE1 bit 3, PRV1 bits 4-5.
    # This matches the field slices used by the mstatus write in logic().
    mstatus = m.concat(IE.O, PRV.O, IE1.O, PRV1.O, IE2, PRV2, IE3, PRV3, FS,
                       XS, MPRV, VM, BV[x_len - 23](0), SD)
    mtvec = BV[x_len](Const.PC_EVEC)
    mtdeleg = BV[x_len](0)

    # interrupt registers
    MTIP = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    HTIP = False
    STIP = False
    MTIE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    HTIE = False
    STIE = False
    MSIP = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    HSIP = False
    SSIP = False
    MSIE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    HSIE = False
    SSIE = False
    # MSIP/MSIE land in bit 3 and MTIP/MTIE in bit 7, matching the
    # wdata[3]/wdata[7] writes in logic().
    mip = m.concat(Bit(False), SSIP, HSIP, MSIP.O, Bit(False), STIP, HTIP,
                   MTIP.O, BV[x_len - 8](0))
    mie = m.concat(Bit(False), SSIE, HSIE, MSIE.O, Bit(False), STIE, HTIE,
                   MTIE.O, BV[x_len - 8](0))

    mtimecmp = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    mscratch = m.Register(m.UInt[x_len], reset_type=m.Reset)()

    mepc = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    mcause = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    mbadaddr = m.Register(m.UInt[x_len], reset_type=m.Reset)()

    mtohost = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    mfromhost = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    io.host.tohost @= mtohost.O

    # Read map: CSR address -> value. Several addresses alias the same
    # storage (e.g. cycle/cyclew, time/timew/mtime).
    csr_file = {
        CSR.cycle: cycle.O,
        CSR.time: time.O,
        CSR.instret: instret.O,
        CSR.cycleh: cycleh.O,
        CSR.timeh: timeh.O,
        CSR.instreth: instreth.O,
        CSR.cyclew: cycle.O,
        CSR.timew: time.O,
        CSR.instretw: instret.O,
        CSR.cyclehw: cycleh.O,
        CSR.timehw: timeh.O,
        CSR.instrethw: instreth.O,
        CSR.mcpuid: mcpuid,
        CSR.mimpid: mimpid,
        CSR.mhartid: mhartid,
        CSR.mtvec: mtvec,
        CSR.mtdeleg: mtdeleg,
        CSR.mie: mie,
        CSR.mtimecmp: mtimecmp.O,
        CSR.mtime: time.O,
        CSR.mtimeh: timeh.O,
        CSR.mscratch: mscratch.O,
        CSR.mepc: mepc.O,
        CSR.mcause: mcause.O,
        CSR.mbadaddr: mbadaddr.O,
        CSR.mip: mip,
        CSR.mtohost: mtohost.O,
        CSR.mfromhost: mfromhost.O,
        CSR.mstatus: mstatus,
    }

    out = m.dict_lookup(csr_file, csr_addr)
    io.O @= out

    # Address bits 8-9 give the minimum privilege needed to access the CSR.
    priv_valid = csr_addr[8:10] <= PRV.O
    priv_inst = io.cmd == CSR.P
    # Privileged instructions are distinguished by csr_addr bits 0 and 8.
    is_E_call = priv_inst & ~csr_addr[0] & ~csr_addr[8]
    is_E_break = priv_inst & csr_addr[0] & ~csr_addr[8]
    is_E_ret = priv_inst & ~csr_addr[0] & csr_addr[8]
    csr_valid = m.reduce(operator.or_,
                         m.bits([csr_addr == key for key in csr_file]))
    # Read-only: address bits 10-11 both set, or mtvec/mtdeleg.
    csr_RO = (csr_addr[10:12].reduce_and() | (csr_addr == CSR.mtvec) |
              (csr_addr == CSR.mtdeleg))
    # Write on cmd == W, or on set/clear (cmd[1]) with a non-zero rs1.
    wen = (io.cmd == CSR.W) | io.cmd[1] & rs1_addr.reduce_or()
    wdata = m.dict_lookup(
        {
            CSR.W: io.I,
            CSR.S: out | io.I,
            CSR.C: out & ~io.I
        }, io.cmd)

    iaddr_invalid = io.pc_check & io.addr[1]
    laddr_invalid = m.dict_lookup(
        {
            Control.LD_LW: io.addr[0:2].reduce_or(),
            Control.LD_LH: io.addr[0],
            Control.LD_LHU: io.addr[0]
        }, io.ld_type)
    saddr_invalid = m.dict_lookup(
        {
            Control.ST_SW: io.addr[0:2].reduce_or(),
            Control.ST_SH: io.addr[0]
        }, io.st_type)
    # ``&`` binds tighter than ``|`` in each term below.
    expt = (io.illegal | iaddr_invalid | laddr_invalid | saddr_invalid |
            io.cmd[0:2].reduce_or() & (~csr_valid | ~priv_valid) |
            wen & csr_RO | (priv_inst & ~priv_valid) | is_E_call |
            is_E_break)
    io.expt @= expt
    io.evec @= mtvec + (m.zext_to(PRV.O, x_len) << 6)
    io.epc @= mepc.O

    # NOTE: inline_combinational rewrites this function from its source, so
    # only ``#`` comments (no docstring) are used inside it. Every register
    # first gets a "hold" value; later conditional assignments override it.
    @m.inline_combinational()
    def logic():
        # Counters
        time.I @= time.O + 1
        timeh.I @= timeh.O
        # High half increments when the low half is all ones (about to wrap).
        if time.O.reduce_and():
            timeh.I @= timeh.O + 1

        cycle.I @= cycle.O + 1
        cycleh.I @= cycleh.O
        if cycle.O.reduce_and():
            cycleh.I @= cycleh.O + 1

        instret.I @= instret.O
        # Retire: non-NOP, not stalled, and no exception except ecall/ebreak.
        is_inst_ret = ((io.inst != Instructions.NOP) &
                       (~expt | is_E_call | is_E_break) & ~io.stall)
        if is_inst_ret:
            instret.I @= instret.O + 1
        instreth.I @= instreth.O
        if is_inst_ret & instret.O.reduce_and():
            instreth.I @= instreth.O + 1

        mbadaddr.I @= mbadaddr.O
        mepc.I @= mepc.O
        mcause.I @= mcause.O
        PRV.I @= PRV.O
        IE.I @= IE.O
        IE1.I @= IE1.O
        PRV1.I @= PRV1.O
        MTIP.I @= MTIP.O
        MSIP.I @= MSIP.O
        MTIE.I @= MTIE.O
        MSIE.I @= MSIE.O
        mtimecmp.I @= mtimecmp.O
        mscratch.I @= mscratch.O
        mtohost.I @= mtohost.O
        mfromhost.I @= mfromhost.O
        if io.host.fromhost.valid:
            mfromhost.I @= io.host.fromhost.data

        # Exception, eret and CSR-write updates happen only when not stalled.
        # They are mutually exclusive in this priority order.
        if ~io.stall:
            if expt:
                # Save the word-aligned PC and record the highest-priority
                # cause.
                mepc.I @= io.pc >> 2 << 2
                if iaddr_invalid:
                    mcause.I @= Cause.InstAddrMisaligned
                elif laddr_invalid:
                    mcause.I @= Cause.LoadAddrMisaligned
                elif saddr_invalid:
                    mcause.I @= Cause.StoreAddrMisaligned
                elif is_E_call:
                    mcause.I @= Cause.Ecall + m.zext_to(PRV.O, x_len)
                elif is_E_break:
                    mcause.I @= Cause.Breakpoint
                else:
                    mcause.I @= Cause.IllegalInst
                # Push the privilege/interrupt-enable stack and enter M-mode.
                PRV.I @= CSR.PRV_M
                IE.I @= False
                PRV1.I @= PRV.O
                IE1.I @= IE.O
                if iaddr_invalid | laddr_invalid | saddr_invalid:
                    mbadaddr.I @= io.addr
            elif is_E_ret:
                # Pop the stack; the vacated slot becomes U-mode with IE set.
                PRV.I @= PRV1.O
                IE.I @= IE1.O
                PRV1.I @= CSR.PRV_U
                IE1.I @= True
            elif wen:
                if csr_addr == CSR.mstatus:
                    PRV1.I @= wdata[4:6]
                    IE1.I @= wdata[3]
                    PRV.I @= wdata[1:3]
                    IE.I @= wdata[0]
                elif csr_addr == CSR.mip:
                    MTIP.I @= wdata[7]
                    MSIP.I @= wdata[3]
                elif csr_addr == CSR.mie:
                    MTIE.I @= wdata[7]
                    MSIE.I @= wdata[3]
                elif csr_addr == CSR.mtime:
                    time.I @= wdata
                elif csr_addr == CSR.mtimeh:
                    timeh.I @= wdata
                elif csr_addr == CSR.mtimecmp:
                    mtimecmp.I @= wdata
                elif csr_addr == CSR.mscratch:
                    mscratch.I @= wdata
                elif csr_addr == CSR.mepc:
                    mepc.I @= wdata >> 2 << 2
                elif csr_addr == CSR.mcause:
                    # Keep only the MSB and the low 4 bits of the cause.
                    mcause.I @= wdata & (1 << (x_len - 1) | 0xf)
                elif csr_addr == CSR.mbadaddr:
                    mbadaddr.I @= wdata
                elif csr_addr == CSR.mtohost:
                    mtohost.I @= wdata
                elif csr_addr == CSR.mfromhost:
                    mfromhost.I @= wdata
                elif csr_addr == CSR.cyclew:
                    cycle.I @= wdata
                elif csr_addr == CSR.timew:
                    time.I @= wdata
                elif csr_addr == CSR.instretw:
                    instret.I @= wdata
                elif csr_addr == CSR.cyclehw:
                    cycleh.I @= wdata
                elif csr_addr == CSR.timehw:
                    timeh.I @= wdata
                elif csr_addr == CSR.instrethw:
                    instreth.I @= wdata
def __init__(self, x_len, ALU=ALUArea, ImmGen=ImmGenWire,
             BrCond=BrCondArea):
    """Build the pipelined datapath: fetch, execute, and write back.

    Fetch-stage values are latched into ``fe_*`` registers and
    execute-stage results into ``ew_*`` registers for write-back. Control
    comes from ``io.ctrl``. Exceptions, ``mepc``/``mtvec`` and host I/O are
    handled by an embedded ``CSRGen``.

    Args:
        x_len: datapath width in bits.
        ALU: generator class for the ALU; instantiated as ``ALU(x_len)()``.
        ImmGen: generator class for the immediate generator.
        BrCond: generator class for the branch-condition unit.
    """
    self.io = make_DatapathIO(x_len) + m.ClockIO(has_reset=True)
    csr = CSRGen(x_len)()
    reg_file = RegFile(x_len)()
    alu = ALU(x_len)()
    imm_gen = ImmGen(x_len)()
    # Honor the injected implementation, as ALU and ImmGen already do.
    # Previously BrCondArea was hard-coded and the BrCond argument was
    # silently ignored.
    br_cond = BrCond(x_len)()

    # Fetch / Execute Registers
    fe_inst = m.Register(init=Instructions.NOP, has_enable=True)()
    fe_pc = m.Register(m.UInt[x_len], has_enable=True)()

    # Execute / Write Back Registers
    ew_inst = m.Register(init=Instructions.NOP)()
    ew_pc = m.Register(m.UInt[x_len])()
    ew_alu = m.Register(m.UInt[x_len])()
    csr_in = m.Register(m.UInt[x_len])()

    # Control signals
    st_type = m.Register(type(self.io.ctrl.st_type).undirected_t)()
    ld_type = m.Register(type(self.io.ctrl.ld_type).undirected_t)()
    wb_sel = m.Register(type(self.io.ctrl.wb_sel).undirected_t)()
    wb_en = m.Register(m.Bit)()
    csr_cmd = m.Register(type(self.io.ctrl.csr_cmd).undirected_t)()
    illegal = m.Register(m.Bit)()
    pc_check = m.Register(m.Bit)()

    # Fetch
    # ``started`` is RESET delayed by one cycle; it forces a NOP fetch.
    started = m.Register(m.Bit)()(m.bit(self.io.RESET))
    # Stall whenever either cache has no valid response.
    stall = ~self.io.icache.resp.valid | ~self.io.dcache.resp.valid
    pc = m.Register(init=UIntVector[x_len](Const.PC_START) -
                    UIntVector[x_len](4))()
    # Next-PC priority (outermost mux wins): stall holds the PC, then the
    # exception vector, then EPC, then ALU target / taken branch, then
    # PC_0 holds, otherwise PC + 4.
    npc = m.mux([
        m.mux([
            m.mux([
                m.mux([
                    m.mux([pc.O + m.uint(4, x_len), pc.O],
                          self.io.ctrl.pc_sel == PC_0),
                    alu.sum_ >> 1 << 1
                ], (self.io.ctrl.pc_sel == PC_ALU) | br_cond.taken),
                csr.epc
            ], self.io.ctrl.pc_sel == PC_EPC),
            csr.evec
        ], csr.expt),
        pc.O
    ], stall)
    # Replace the fetched word with a NOP right after reset, when control
    # kills it, on a taken branch, or on an exception.
    inst = m.mux([self.io.icache.resp.data.data, Instructions.NOP],
                 started | self.io.ctrl.inst_kill | br_cond.taken |
                 csr.expt)
    pc.I @= npc
    self.io.icache.req.data.addr @= npc
    self.io.icache.req.data.data @= 0
    self.io.icache.req.data.mask @= 0
    self.io.icache.req.valid @= ~stall
    self.io.icache.abort @= False

    fe_pc.I @= pc.O
    fe_pc.CE @= m.enable(~stall)
    fe_inst.I @= inst
    fe_inst.CE @= m.enable(~stall)

    # Execute
    # Decode
    self.io.ctrl.inst @= fe_inst.O

    # reg_file read
    rs1_addr = fe_inst.O[15:20]
    rs2_addr = fe_inst.O[20:25]
    reg_file.raddr1 @= rs1_addr
    reg_file.raddr2 @= rs2_addr

    # gen immediates
    imm_gen.inst @= fe_inst.O
    imm_gen.sel @= self.io.ctrl.imm_sel

    # bypass
    # Forward the ALU result of the instruction in write-back when it
    # targets a non-zero source register of the instruction in execute.
    wb_rd_addr = ew_inst.O[7:12]
    rs1_hazard = wb_en.O & rs1_addr.reduce_or() & (rs1_addr == wb_rd_addr)
    rs2_hazard = wb_en.O & rs2_addr.reduce_or() & (rs2_addr == wb_rd_addr)
    rs1 = m.mux([reg_file.rdata1, ew_alu.O],
                (wb_sel.O == WB_ALU) & rs1_hazard)
    rs2 = m.mux([reg_file.rdata2, ew_alu.O],
                (wb_sel.O == WB_ALU) & rs2_hazard)

    # ALU operations
    alu.A @= m.mux([fe_pc.O, rs1], self.io.ctrl.A_sel == A_RS1)
    alu.B @= m.mux([imm_gen.O, rs2], self.io.ctrl.B_sel == B_RS2)
    alu.op @= self.io.ctrl.alu_op

    # Branch condition calc
    br_cond.rs1 @= rs1
    br_cond.rs2 @= rs2
    br_cond.br_type @= self.io.ctrl.br_type

    # D$ access
    # Word-aligned address; while stalled, keep presenting the address
    # latched in ew_alu.
    daddr = m.mux([alu.sum_, ew_alu.O], stall) >> 2 << 2
    # Bit offset of the addressed byte within the word (byte index * 8).
    w_offset = ((m.bits(alu.sum_[1], x_len) << 4) |
                (m.bits(alu.sum_[0], x_len) << 3))
    self.io.dcache.req.valid @= ~stall & (self.io.ctrl.st_type.reduce_or() |
                                          self.io.ctrl.ld_type.reduce_or())
    self.io.dcache.req.data.addr @= daddr
    self.io.dcache.req.data.data @= rs2 << w_offset
    self.io.dcache.req.data.mask @= m.dict_lookup(
        {
            ST_SW: m.bits(0b1111, 4),
            ST_SH: m.bits(0b11, 4) << m.zext(alu.sum_[0:2], 2),
            ST_SB: m.bits(0b1, 4) << m.zext(alu.sum_[0:2], 2),
        },
        m.mux([self.io.ctrl.st_type, st_type.O], stall),
        m.bits(0, 4))

    # Pipelining
    # NOTE: inline_combinational rewrites this function from its source,
    # so only ``#`` comments are used inside it.
    @m.inline_combinational()
    def pipeline_logic():
        ew_pc.I @= ew_pc.O
        ew_inst.I @= ew_inst.O
        ew_alu.I @= ew_alu.O
        csr_in.I @= csr_in.O
        st_type.I @= st_type.O
        ld_type.I @= ld_type.O
        wb_sel.I @= wb_sel.O
        wb_en.I @= wb_en.O
        csr_cmd.I @= csr_cmd.O
        illegal.I @= illegal.O
        pc_check.I @= pc_check.O
        # ``&`` binds tighter than ``|``: RESET | (~stall & expt).
        # On reset or an exception, squash the control signals.
        if m.bit(self.io.RESET) | ~stall & csr.expt:
            st_type.I @= 0
            ld_type.I @= 0
            wb_en.I @= 0
            csr_cmd.I @= 0
            illegal.I @= False
            pc_check.I @= False
        elif ~stall & ~csr.expt:
            ew_pc.I @= fe_pc.O
            ew_inst.I @= fe_inst.O
            ew_alu.I @= alu.O
            csr_in.I @= m.mux([rs1, imm_gen.O],
                              self.io.ctrl.imm_sel == IMM_Z)
            st_type.I @= self.io.ctrl.st_type
            ld_type.I @= self.io.ctrl.ld_type
            wb_sel.I @= self.io.ctrl.wb_sel
            wb_en.I @= self.io.ctrl.wb_en
            csr_cmd.I @= self.io.ctrl.csr_cmd
            illegal.I @= self.io.ctrl.illegal
            pc_check.I @= self.io.ctrl.pc_sel == PC_ALU

    # Load
    # Shift the returned word so the addressed byte/halfword sits at bit 0,
    # then sign- or zero-extend according to ld_type.
    l_offset = ((m.uint(ew_alu.O[1], x_len) << 4) |
                (m.uint(ew_alu.O[0], x_len) << 3))
    l_shift = self.io.dcache.resp.data.data >> l_offset
    load = m.dict_lookup(
        {
            LD_LH: m.sext_to(m.sint(l_shift[0:16]), x_len),
            LD_LHU: m.sint(m.zext_to(l_shift[0:16], x_len)),
            LD_LB: m.sext_to(m.sint(l_shift[0:8]), x_len),
            LD_LBU: m.sint(m.zext_to(l_shift[0:8], x_len))
        }, ld_type.O, m.sint(self.io.dcache.resp.data.data))

    # CSR access
    csr.stall @= stall
    csr.I @= csr_in.O
    csr.cmd @= csr_cmd.O
    csr.inst @= ew_inst.O
    csr.pc @= ew_pc.O
    csr.addr @= ew_alu.O
    csr.illegal @= illegal.O
    csr.pc_check @= pc_check.O
    csr.ld_type @= ld_type.O
    csr.st_type @= st_type.O
    self.io.host @= csr.host

    # Regfile write
    reg_write = m.dict_lookup(
        {
            WB_MEM: m.uint(load),
            WB_PC4: (ew_pc.O + 4),
            WB_CSR: csr.O
        }, wb_sel.O, ew_alu.O)

    reg_file.wen @= m.enable(wb_en.O & ~stall & ~csr.expt)
    reg_file.waddr @= wb_rd_addr
    reg_file.wdata @= reg_write

    # Abort store when there's an exception
    self.io.dcache.abort @= csr.expt