Beispiel #1
0
    def __init__(self, x_len):
        """Decode every immediate format in parallel and select one.

        Each of the six immediate encodings is extracted from ``io.inst``;
        the one whose selector constant matches ``io.sel`` is driven onto
        ``io.O`` as an unsigned value.  When ``io.sel`` matches none of the
        selectors, the I-type immediate with its lowest bit cleared is used.

        NOTE(review): the field orderings below assume ``m.concat`` places
        its first argument in the least-significant bits - confirm against
        the magma documentation.
        """
        super().__init__(x_len)
        word = self.io.inst

        def sign_extended(field):
            # Reinterpret the raw field as signed, then widen it to x_len.
            return m.sext_to(m.sint(field), x_len)

        imm_i = sign_extended(word[20:32])
        imm_s = sign_extended(m.concat(word[7:12], word[25:32]))
        # Branch and jump offsets carry an implicit zero in bit 0.
        imm_b = sign_extended(
            m.concat(m.bits(0, 1), word[8:12], word[25:31], word[7],
                     word[31]))
        # The upper immediate occupies bits 12 and up; the low 12 are zero.
        imm_u = m.concat(m.bits(0, 12), word[12:32])
        imm_j = sign_extended(
            m.concat(m.bits(0, 1), word[21:25], word[25:31], word[20],
                     word[12:20], word[31]))
        # The 5-bit field at [15:20] is zero-extended rather than
        # sign-extended.
        imm_z = m.sint(m.zext_to(word[15:20], x_len))

        by_selector = {
            IMM_I: imm_i,
            IMM_S: imm_s,
            IMM_B: imm_b,
            IMM_U: imm_u,
            IMM_J: imm_j,
            IMM_Z: imm_z
        }
        chosen = m.dict_lookup(by_selector, self.io.sel, imm_i & -2)
        self.io.O @= m.uint(chosen)
Beispiel #2
0
 def bytes_to_x_size(self, bytes_):
     """Translate a byte count into a 3-bit size code.

     The power-of-two byte counts 1, 2, 4, ..., 128 map to their base-2
     logarithm (0 through 7).  Any other value of ``bytes_`` selects the
     fallback code ``0b111``.
     """
     size_codes = {1 << code: m.bits(code, 3) for code in range(8)}
     return m.dict_lookup(size_codes, bytes_, m.bits(0b111, 3))
Beispiel #3
0
    def __init__(self, x_len):
        """Decode ``io.inst`` into the per-stage control signals.

        The instruction is looked up in ``inst_map``; instructions with no
        entry fall back to ``default``.  The resulting row is then fanned
        out, field by field, to the control outputs.
        """
        io = make_ControlIO(x_len)
        self.io = io

        decoded = m.dict_lookup(inst_map, io.inst, default=default)

        # Fetch stage: next-PC selection and instruction kill.
        io.pc_sel @= decoded[0]
        io.inst_kill @= decoded[6]

        # Execute stage: operand selects, immediate format, ALU op,
        # branch type and store type.
        io.A_sel @= decoded[1]
        io.B_sel @= decoded[2]
        io.imm_sel @= decoded[3]
        io.alu_op @= decoded[4]
        io.br_type @= decoded[5]
        io.st_type @= decoded[7]

        # Write-back stage: load type, write-back source and enable,
        # CSR command and the illegal-instruction flag.
        io.ld_type @= decoded[8]
        io.wb_sel @= decoded[9]
        io.wb_en @= decoded[10]
        io.csr_cmd @= decoded[11]
        io.illegal @= decoded[12]
Beispiel #4
0
    class CSR_DUT(m.Circuit):
        # Test harness circuit for the CSR unit.
        #
        # A counter steps through fixed stimulus lists; each step drives a
        # CSRGen instance (through a Control decoder) and, in parallel, a
        # reference model built from the plain registers in `regs`.  Both the
        # device outputs and the model's expected values are exposed on the
        # IO so an external checker can compare them.
        #
        # x_len, n, insts, pc, addr, data and nop are not defined in this
        # class body; they come from the enclosing scope (not shown here).
        io = m.IO(done=m.Out(m.Bit),
                  check=m.Out(m.Bit),
                  rdata=m.Out(m.UInt[x_len]),
                  expected_rdata=m.Out(m.UInt[x_len]),
                  epc=m.Out(m.UInt[x_len]),
                  expected_epc=m.Out(m.UInt[x_len]),
                  evec=m.Out(m.UInt[x_len]),
                  expected_evec=m.Out(m.UInt[x_len]),
                  expt=m.Out(m.Bit),
                  expected_expt=m.Out(m.Bit))
        io += m.ClockIO(has_reset=True)

        # Reference model state: one 32-bit register per entry of CSR.regs,
        # with reset values for mcpuid, mstatus and mtvec and zero elsewhere.
        regs = {}
        for reg in CSR.regs:
            if reg == CSR.mcpuid:
                init = (1 << (ord('I') - ord('A')) | 1 <<
                        (ord('U') - ord('A')))
            elif reg == CSR.mstatus:
                init = (CSR.PRV_M.ext(30) << 4) | (CSR.PRV_M.ext(30) << 1)
            elif reg == CSR.mtvec:
                init = Const.PC_EVEC
            else:
                init = 0
            regs[reg] = m.Register(init=BV[32](init), reset_type=m.Reset)()

        # Device under test and the decoder that produces its control inputs.
        csr = CSRGen(x_len)()
        ctrl = Control.Control(x_len)()

        # counter.O indexes every stimulus list below.
        counter = CounterModM(n, n.bit_length())
        inst = m.mux(insts, counter.O)
        ctrl.inst @= inst
        csr.inst @= inst
        csr_cmd = ctrl.csr_cmd
        csr.cmd @= csr_cmd
        csr.illegal @= ctrl.illegal
        csr.st_type @= ctrl.st_type
        csr.ld_type @= ctrl.ld_type
        csr.pc_check @= ctrl.pc_sel == Control.PC_ALU
        csr.pc @= m.mux(pc, counter.O)
        csr.addr @= m.mux(addr, counter.O)
        csr.I @= m.mux(data, counter.O)
        csr.stall @= False
        csr.host.fromhost.valid @= False
        csr.host.fromhost.data @= 0

        # values known statically
        # NOTE: inside a comprehension body, names bound in this class body
        # are not visible.  `csr(inst)` therefore does NOT call the CSRGen
        # instance assigned to `csr` above; it resolves to a `csr` name from
        # the enclosing/module scope (not shown - presumably a helper that
        # extracts the CSR address field, like `rs1`; confirm).  Do not move
        # these out of the comprehensions without accounting for that.
        _csr_addr = [csr(inst) for inst in insts]
        _rs1_addr = [rs1(inst) for inst in insts]
        # Read-only when address bits 11 and 10 are both set, or for
        # mtvec / mtdeleg.
        _csr_ro = [((((x >> 11) & 0x1) > 0x0) & (((x >> 10) & 0x1) > 0x0)) |
                   (x == CSR.mtvec) | (x == CSR.mtdeleg) for x in _csr_addr]
        _csr_valid = [x in CSR.regs for x in _csr_addr]
        # should be <= prv in runtime
        _prv_level = [(x >> 8) & 0x3 for x in _csr_addr]
        # should consider prv in runtime
        # ecall / ebreak / eret are told apart by address bits 0 and 8.
        _is_ecall = [((x & 0x1) == 0x0) & (((x >> 8) & 0x1) == 0x0)
                     for x in _csr_addr]
        _is_ebreak = [((x & 0x1) > 0x0) & (((x >> 8) & 0x1) == 0x0)
                      for x in _csr_addr]
        _is_eret = [((x & 0x1) == 0x0) & (((x >> 8) & 0x1) > 0x0)
                    for x in _csr_addr]
        # should consider pc_check in runtime
        _iaddr_invalid = [((x >> 1) & 0x1) > 0 for x in addr]
        # should consider ld_type & sd_type
        # Word accesses are misaligned if either low address bit is set,
        # half-word accesses if bit 0 is set.
        _waddr_invalid = [(((x >> 1) & 0x1) > 0) | ((x & 0x1) > 0)
                          for x in addr]
        _haddr_invalid = [(x & 0x1) > 0 for x in addr]

        # values known at runtime
        csr_addr = m.mux(_csr_addr, counter.O)
        rs1_addr = m.mux(_rs1_addr, counter.O)
        csr_ro = m.mux(_csr_ro, counter.O)
        csr_valid = m.mux(_csr_valid, counter.O)

        # Write enable: an explicit write command, or a command with bit 1
        # set and a non-zero rs1 field.
        wen = (csr_cmd == CSR.W) | (csr_cmd[1] & (rs1_addr != 0))
        # Fields of the model's mstatus: prv1 = bits [5:4], ie1 = bit 3,
        # prv = bits [2:1], ie = bit 0.
        prv1 = (regs[CSR.mstatus].O >> 4) & 0x3
        ie1 = (regs[CSR.mstatus].O >> 3) & 0x1
        prv = (regs[CSR.mstatus].O >> 1) & 0x3
        ie = regs[CSR.mstatus].O & 0x1
        prv_inst = csr_cmd == CSR.P
        prv_valid = (m.uint(m.zext_to(m.mux(_prv_level, counter.O), 32)) <=
                     m.uint(prv))
        iaddr_invalid = m.mux(_iaddr_invalid, counter.O) & csr.pc_check.value()
        laddr_invalid = (m.mux(_haddr_invalid, counter.O) &
                         ((ctrl.ld_type == Control.LD_LH) |
                          (ctrl.ld_type == Control.LD_LHU))
                         | m.mux(_waddr_invalid, counter.O) &
                         (ctrl.ld_type == Control.LD_LW))
        saddr_invalid = (m.mux(_haddr_invalid, counter.O) &
                         (ctrl.st_type == Control.ST_SH)
                         | m.mux(_waddr_invalid, counter.O) &
                         (ctrl.st_type == Control.ST_SW))
        is_ecall = prv_inst & m.mux(_is_ecall, counter.O)
        is_ebreak = prv_inst & m.mux(_is_ebreak, counter.O)
        is_eret = prv_inst & m.mux(_is_eret, counter.O)
        # Expected exception flag: OR of every individual exception source.
        exception = (ctrl.illegal | iaddr_invalid | laddr_invalid
                     | saddr_invalid | (((csr_cmd & 0x3) > 0) &
                                        (~csr_valid | ~prv_valid)) |
                     (csr_ro & wen) | (prv_inst & ~prv_valid) | is_ecall
                     | is_ebreak)
        # An instruction counts as retired unless it is the nop or raised an
        # exception other than ecall / ebreak.
        instret = (inst != nop) & (~exception | is_ecall | is_ebreak)

        # Expected read data, and the value a write / set / clear command
        # would store.
        rdata = m.dict_lookup({key: value.O
                               for key, value in regs.items()}, csr_addr)
        wdata = m.dict_lookup(
            {
                CSR.W: csr.I.value(),
                CSR.S: (csr.I.value() | rdata),
                CSR.C: (~csr.I.value() & rdata)
            }, csr_cmd)

        # compute state
        # Free-running counters increment every cycle.
        regs[CSR.time].I @= regs[CSR.time].O + 1
        regs[CSR.timew].I @= regs[CSR.timew].O + 1
        regs[CSR.mtime].I @= regs[CSR.mtime].O + 1
        regs[CSR.cycle].I @= regs[CSR.cycle].O + 1
        regs[CSR.cyclew].I @= regs[CSR.cyclew].O + 1

        time_max = regs[CSR.time].O.reduce_and()
        # TODO: mtime has same default value as this case (from chisel code)
        # https://github.com/ucb-bar/riscv-mini/blob/release/src/test/scala/CSRTests.scala#L140
        # mtime_reg = regs[CSR.mtime]
        # mtime_reg.I @= m.mux([mtime_reg.O, mtime_reg.O + 1], time_max)

        # The high halves advance when the matching low half is all ones.
        # NOTE(review): incr_when / update_when are defined outside this
        # class; their exact semantics are not visible here.
        incr_when(regs[CSR.timeh], time_max)
        incr_when(regs[CSR.timehw], time_max)

        cycle_max = regs[CSR.cycle].O.reduce_and()

        incr_when(regs[CSR.cycleh], cycle_max)
        incr_when(regs[CSR.cyclehw], cycle_max)

        incr_when(regs[CSR.instret], instret)
        incr_when(regs[CSR.instretw], instret)

        instret_max = regs[CSR.instret].O.reduce_and()
        incr_when(regs[CSR.instreth], instret & instret_max)
        incr_when(regs[CSR.instrethw], instret & instret_max)

        # Ordinary CSR writes only happen without an exception or eret.
        cond = ~exception & ~is_eret & wen
        # Assuming these are mutually exclusive, so we don't need chained
        # elsewhen
        update_when(regs[CSR.mstatus], m.zext_to(wdata[0:6], 32),
                    cond & (csr_addr == CSR.mstatus))
        update_when(regs[CSR.mip],
                    (m.bits(wdata[7], 32) << 7) | (m.bits(wdata[3], 32) << 3),
                    cond & (csr_addr == CSR.mip))
        update_when(regs[CSR.mie],
                    (m.bits(wdata[7], 32) << 7) | (m.bits(wdata[3], 32) << 3),
                    cond & (csr_addr == CSR.mie))
        update_when(regs[CSR.mepc], (wdata >> 2) << 2,
                    cond & (csr_addr == CSR.mepc))
        update_when(regs[CSR.mcause], wdata & (1 << 31 | 0xf),
                    cond & (csr_addr == CSR.mcause))
        # Aliased addresses: a write to any alias updates every model
        # register that mirrors the same counter.
        update_when(regs[CSR.time], wdata,
                    cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
        update_when(regs[CSR.timew], wdata,
                    cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
        update_when(regs[CSR.mtime], wdata,
                    cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
        update_when(
            regs[CSR.timeh], wdata,
            cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
        update_when(
            regs[CSR.timehw], wdata,
            cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
        update_when(
            regs[CSR.mtimeh], wdata,
            cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
        update_when(regs[CSR.cycle], wdata, cond & (csr_addr == CSR.cyclew))
        update_when(regs[CSR.cyclew], wdata, cond & (csr_addr == CSR.cyclew))
        update_when(regs[CSR.cycleh], wdata, cond & (csr_addr == CSR.cyclehw))
        update_when(regs[CSR.cyclehw], wdata, cond & (csr_addr == CSR.cyclehw))
        update_when(regs[CSR.instret], wdata,
                    cond & (csr_addr == CSR.instretw))
        update_when(regs[CSR.instretw], wdata,
                    cond & (csr_addr == CSR.instretw))
        update_when(regs[CSR.instreth], wdata,
                    cond & (csr_addr == CSR.instrethw))
        update_when(regs[CSR.instrethw], wdata,
                    cond & (csr_addr == CSR.instrethw))
        update_when(regs[CSR.mtimecmp], wdata,
                    cond & (csr_addr == CSR.mtimecmp))
        update_when(regs[CSR.mscratch], wdata,
                    cond & (csr_addr == CSR.mscratch))
        update_when(regs[CSR.mbadaddr], wdata,
                    cond & (csr_addr == CSR.mbadaddr))
        update_when(regs[CSR.mtohost], wdata, cond & (csr_addr == CSR.mtohost))
        update_when(regs[CSR.mfromhost], wdata,
                    cond & (csr_addr == CSR.mfromhost))

        # eret
        # Restores prv/ie from prv1/ie1 and resets prv1/ie1 to PRV_U / 1.
        update_when(regs[CSR.mstatus],
                    (CSR.PRV_U.zext(30) << 4) | (1 << 3) | (prv1 << 1) | ie1,
                    ~exception & is_eret)

        # TODO: exception logic comes after since it has priority
        # Cause selection; the outermost mux wins, giving the priority
        # iaddr > laddr > saddr > ecall > ebreak > illegal instruction.
        Cause = make_Cause(x_len)
        mcause = m.mux([
            m.mux([
                m.mux([
                    m.mux([
                        m.mux([Cause.IllegalInst, Cause.Breakpoint],
                              is_ebreak),
                        Cause.Ecall + prv,
                    ], is_ecall),
                    Cause.StoreAddrMisaligned,
                ], saddr_invalid),
                Cause.LoadAddrMisaligned,
            ], laddr_invalid),
            Cause.InstAddrMisaligned,
        ], iaddr_invalid)
        update_when(regs[CSR.mcause], mcause, exception)

        # On an exception: record the word-aligned pc, push prv/ie into
        # prv1/ie1, enter PRV_M with ie cleared, and latch the bad address
        # for misaligned accesses.
        update_when(regs[CSR.mepc], (csr.pc.value() >> 2) << 2, exception)
        update_when(regs[CSR.mstatus],
                    (prv << 4) | (ie << 3) | (CSR.PRV_M.zext(30) << 1),
                    exception)
        update_when(
            regs[CSR.mbadaddr], csr.addr.value(),
            exception & (iaddr_invalid | laddr_invalid | saddr_invalid))

        # Expected epc / evec outputs derived from the model state.
        epc = regs[CSR.mepc].O
        evec = regs[CSR.mtvec].O + (prv << 6)

        # Simulation trace output.
        m.display("*** Counter: %d ***", counter.O)
        m.display("[in] inst: 0x%x, pc: 0x%x, addr: 0x%x, in: 0x%x", csr.inst,
                  csr.pc, csr.addr, csr.I)

        m.display(
            "     cmd: 0x%x, st_type: 0x%x, ld_type: 0x%x, illegal: %d, "
            "pc_check: %d", csr.cmd, csr.st_type, csr.ld_type, csr.illegal,
            csr.pc_check)

        m.display("[state] csr addr: %x", csr_addr)

        for reg_addr, reg in regs.items():
            m.display(f" {hex(int(reg_addr))} -> 0x%x", reg.O)

        m.display(
            "[out] read: 0x%x =? 0x%x, epc: 0x%x =? 0x%x, evec: 0x%x ?= "
            "0x%x, expt: %d ?= %d", csr.O, rdata, csr.epc, epc, csr.evec, evec,
            csr.expt, exception)
        # check is asserted on every step where the counter is non-zero.
        io.check @= counter.O.reduce_or()

        io.rdata @= csr.O
        io.expected_rdata @= rdata

        io.epc @= csr.epc
        io.expected_epc @= epc

        io.evec @= csr.evec
        io.expected_evec @= evec

        io.expt @= csr.expt
        io.expected_expt @= exception

        # io.failed @= counter.O.reduce_or() & (
        #     (csr.O != rdata) |
        #     (csr.epc != epc) |
        #     (csr.evec != evec) |
        #     (csr.expt != exception)
        # )
        io.done @= counter.COUT
        # Any model register whose input was never driven above simply
        # holds its current value.
        for key, reg in regs.items():
            if not reg.I.driven():
                reg.I @= reg.O
Beispiel #5
0
    def __init__(self, x_len):
        """Build the control/status register file for an ``x_len``-bit core.

        The block provides a combinational read port selected by bits
        [20:32] of ``io.inst``, detects exceptions (``io.expt``), exposes the
        exception vector and saved pc (``io.evec`` / ``io.epc``), and
        computes the next value of every state register in ``logic`` below.
        """
        Cause = make_Cause(x_len)

        self.io = io = m.IO(
            stall=m.In(m.Bit),
            cmd=m.In(m.UInt[3]),
            I=m.In(m.UInt[x_len]),
            O=m.Out(m.UInt[x_len]),
            # Exception
            pc=m.In(m.UInt[x_len]),
            addr=m.In(m.UInt[x_len]),
            inst=m.In(m.UInt[x_len]),
            illegal=m.In(m.Bit),
            st_type=m.In(m.UInt[2]),
            ld_type=m.In(m.UInt[3]),
            pc_check=m.In(m.Bit),
            expt=m.Out(m.Bit),
            evec=m.Out(m.UInt[x_len]),
            epc=m.Out(
                m.UInt[x_len])) + HostIO(x_len) + m.ClockIO(has_reset=True)

        # CSR address and source-register fields of the instruction.
        csr_addr = io.inst[20:32]
        rs1_addr = io.inst[15:20]

        # user counters
        time = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        timeh = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        cycle = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        cycleh = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        instret = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        instreth = m.Register(m.UInt[x_len], reset_type=m.Reset)()

        # Constant CPU id word: a 26-bit extension mask with the 'I' and 'U'
        # bits set, zero padding, then a 2-bit base field.
        mcpuid = m.concat(
            BV[26](
                1 << (ord('I') - ord('A')) |  # Base ISA
                1 << (ord('U') - ord('A'))),  # User Mode
            BV[x_len - 28](0),
            BV[2](0),  # RV32I
        )
        mimpid = BV[x_len](0)
        mhartid = BV[x_len](0)

        # interrupt enable stack
        # Only the two innermost levels (PRV/IE and PRV1/IE1) are real
        # registers; levels 2 and 3 are constants.
        PRV = m.Register(m.UInt[len(CSR.PRV_M)],
                         init=CSR.PRV_M,
                         reset_type=m.Reset)()
        PRV1 = m.Register(m.UInt[len(CSR.PRV_M)],
                          init=CSR.PRV_M,
                          reset_type=m.Reset)()
        PRV2 = BV[2](0)
        PRV3 = BV[2](0)
        IE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        IE1 = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        IE2 = False
        IE3 = False

        # virtualization management field
        VM = BV[5](0)

        # memory privilege
        MPRV = False

        # Extension context status
        XS = BV[2](0)
        FS = BV[2](0)
        SD = BV[1](0)
        # Read view of mstatus, assembled from the fields above.
        mstatus = m.concat(IE.O, PRV.O, IE1.O, PRV1.O, IE2, PRV2, IE3, PRV3,
                           FS, XS, MPRV, VM, BV[x_len - 23](0), SD)
        # mtvec and mtdeleg are constants, not registers.
        mtvec = BV[x_len](Const.PC_EVEC)
        mtdeleg = BV[x_len](0)

        # interrupt registers
        # Only the M-level timer/software bits are stored; the H and S
        # variants are constant False.
        MTIP = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        HTIP = False
        STIP = False
        MTIE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        HTIE = False
        STIE = False
        MSIP = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        HSIP = False
        SSIP = False
        MSIE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        HSIE = False
        SSIE = False

        mip = m.concat(Bit(False), SSIP, HSIP, MSIP.O, Bit(False), STIP, HTIP,
                       MTIP.O, BV[x_len - 8](0))
        mie = m.concat(Bit(False), SSIE, HSIE, MSIE.O, Bit(False), STIE, HTIE,
                       MTIE.O, BV[x_len - 8](0))

        mtimecmp = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        mscratch = m.Register(m.UInt[x_len], reset_type=m.Reset)()

        mepc = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        mcause = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        mbadaddr = m.Register(m.UInt[x_len], reset_type=m.Reset)()

        mtohost = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        mfromhost = m.Register(m.UInt[x_len], reset_type=m.Reset)()

        io.host.tohost @= mtohost.O
        # Read map: CSR address -> value.  Several addresses alias the same
        # storage (e.g. cycle/cyclew, time/timew/mtime).
        csr_file = {
            CSR.cycle: cycle.O,
            CSR.time: time.O,
            CSR.instret: instret.O,
            CSR.cycleh: cycleh.O,
            CSR.timeh: timeh.O,
            CSR.instreth: instreth.O,
            CSR.cyclew: cycle.O,
            CSR.timew: time.O,
            CSR.instretw: instret.O,
            CSR.cyclehw: cycleh.O,
            CSR.timehw: timeh.O,
            CSR.instrethw: instreth.O,
            CSR.mcpuid: mcpuid,
            CSR.mimpid: mimpid,
            CSR.mhartid: mhartid,
            CSR.mtvec: mtvec,
            CSR.mtdeleg: mtdeleg,
            CSR.mie: mie,
            CSR.mtimecmp: mtimecmp.O,
            CSR.mtime: time.O,
            CSR.mtimeh: timeh.O,
            CSR.mscratch: mscratch.O,
            CSR.mepc: mepc.O,
            CSR.mcause: mcause.O,
            CSR.mbadaddr: mbadaddr.O,
            CSR.mip: mip,
            CSR.mtohost: mtohost.O,
            CSR.mfromhost: mfromhost.O,
            CSR.mstatus: mstatus,
        }
        out = m.dict_lookup(csr_file, csr_addr)
        io.O @= out

        # Address bits [8:10] must not exceed the current privilege level.
        priv_valid = csr_addr[8:10] <= PRV.O
        priv_inst = io.cmd == CSR.P
        # The three privileged instructions are told apart by address bits
        # 0 and 8.
        is_E_call = priv_inst & ~csr_addr[0] & ~csr_addr[8]
        is_E_break = priv_inst & csr_addr[0] & ~csr_addr[8]
        is_E_ret = priv_inst & ~csr_addr[0] & csr_addr[8]
        # The address is valid if it matches any key of the read map.
        csr_valid = m.reduce(operator.or_,
                             m.bits([csr_addr == key for key in csr_file]))
        # Read-only: both of address bits [10:12] set, or mtvec / mtdeleg.
        csr_RO = (csr_addr[10:12].reduce_and() | (csr_addr == CSR.mtvec) |
                  (csr_addr == CSR.mtdeleg))
        # Write enable: an explicit write, or cmd bit 1 with a non-zero rs1
        # field.
        wen = (io.cmd == CSR.W) | io.cmd[1] & rs1_addr.reduce_or()
        # Value to store for write / set-bits / clear-bits commands.
        wdata = m.dict_lookup(
            {
                CSR.W: io.I,
                CSR.S: out | io.I,
                CSR.C: out & ~io.I
            }, io.cmd)

        # Misalignment checks on the instruction, load and store addresses.
        iaddr_invalid = io.pc_check & io.addr[1]

        laddr_invalid = m.dict_lookup(
            {
                Control.LD_LW: io.addr[0:2].reduce_or(),
                Control.LD_LH: io.addr[0],
                Control.LD_LHU: io.addr[0]
            }, io.ld_type)

        saddr_invalid = m.dict_lookup(
            {
                Control.ST_SW: io.addr[0:2].reduce_or(),
                Control.ST_SH: io.addr[0]
            }, io.st_type)

        # Exception flag: OR of every individual exception source.
        expt = (io.illegal | iaddr_invalid | laddr_invalid | saddr_invalid
                | io.cmd[0:2].reduce_or() & (~csr_valid | ~priv_valid)
                | wen & csr_RO | (priv_inst & ~priv_valid) | is_E_call
                | is_E_break)
        io.expt @= expt

        # Exception vector: mtvec offset by the current privilege level
        # shifted left by 6.
        io.evec @= mtvec + (m.zext_to(PRV.O, x_len) << 6)
        io.epc @= mepc.O

        # Next-state logic.  Every register first gets a default input and
        # later statements override it, so statement order is significant.
        @m.inline_combinational()
        def logic():
            # Counters
            time.I @= time.O + 1
            timeh.I @= timeh.O
            if time.O.reduce_and():
                timeh.I @= timeh.O + 1

            cycle.I @= cycle.O + 1
            cycleh.I @= cycleh.O
            if cycle.O.reduce_and():
                cycleh.I @= cycleh.O + 1
            instret.I @= instret.O
            is_inst_ret = ((io.inst != Instructions.NOP) &
                           (~expt | is_E_call | is_E_break) & ~io.stall)
            if is_inst_ret:
                instret.I @= instret.O + 1
            instreth.I @= instreth.O
            if is_inst_ret & instret.O.reduce_and():
                instreth.I @= instreth.O + 1

            # Default: every remaining register holds its value.
            mbadaddr.I @= mbadaddr.O
            mepc.I @= mepc.O
            mcause.I @= mcause.O
            PRV.I @= PRV.O
            IE.I @= IE.O
            IE1.I @= IE1.O
            PRV1.I @= PRV1.O
            MTIP.I @= MTIP.O
            MSIP.I @= MSIP.O
            MTIE.I @= MTIE.O
            MSIE.I @= MSIE.O
            mtimecmp.I @= mtimecmp.O
            mscratch.I @= mscratch.O
            mtohost.I @= mtohost.O
            mfromhost.I @= mfromhost.O
            if io.host.fromhost.valid:
                mfromhost.I @= io.host.fromhost.data

            # State updates are suppressed while stalled.  Priority:
            # exception, then eret, then an ordinary CSR write.
            if ~io.stall:
                if expt:
                    mepc.I @= io.pc >> 2 << 2
                    if iaddr_invalid:
                        mcause.I @= Cause.InstAddrMisaligned
                    elif laddr_invalid:
                        mcause.I @= Cause.LoadAddrMisaligned
                    elif saddr_invalid:
                        mcause.I @= Cause.StoreAddrMisaligned
                    elif is_E_call:
                        mcause.I @= Cause.Ecall + m.zext_to(PRV.O, x_len)
                    elif is_E_break:
                        mcause.I @= Cause.Breakpoint
                    else:
                        mcause.I @= Cause.IllegalInst
                    PRV.I @= CSR.PRV_M
                    IE.I @= False
                    PRV1.I @= PRV.O
                    IE1.I @= IE.O
                    if iaddr_invalid | laddr_invalid | saddr_invalid:
                        mbadaddr.I @= io.addr
                elif is_E_ret:
                    PRV.I @= PRV1.O
                    IE.I @= IE1.O
                    PRV1.I @= CSR.PRV_U
                    IE1.I @= True
                elif wen:
                    if csr_addr == CSR.mstatus:
                        PRV1.I @= wdata[4:6]
                        IE1.I @= wdata[3]
                        PRV.I @= wdata[1:3]
                        IE.I @= wdata[0]
                    elif csr_addr == CSR.mip:
                        MTIP.I @= wdata[7]
                        MSIP.I @= wdata[3]
                    elif csr_addr == CSR.mie:
                        MTIE.I @= wdata[7]
                        MSIE.I @= wdata[3]
                    elif csr_addr == CSR.mtime:
                        time.I @= wdata
                    elif csr_addr == CSR.mtimeh:
                        timeh.I @= wdata
                    elif csr_addr == CSR.mtimecmp:
                        mtimecmp.I @= wdata
                    elif csr_addr == CSR.mscratch:
                        mscratch.I @= wdata
                    elif csr_addr == CSR.mepc:
                        mepc.I @= wdata >> 2 << 2
                    elif csr_addr == CSR.mcause:
                        mcause.I @= wdata & (1 << (x_len - 1) | 0xf)
                    elif csr_addr == CSR.mbadaddr:
                        mbadaddr.I @= wdata
                    elif csr_addr == CSR.mtohost:
                        mtohost.I @= wdata
                    elif csr_addr == CSR.mfromhost:
                        mfromhost.I @= wdata
                    elif csr_addr == CSR.cyclew:
                        cycle.I @= wdata
                    elif csr_addr == CSR.timew:
                        time.I @= wdata
                    elif csr_addr == CSR.instretw:
                        instret.I @= wdata
                    elif csr_addr == CSR.cyclehw:
                        cycleh.I @= wdata
                    elif csr_addr == CSR.timehw:
                        timeh.I @= wdata
                    elif csr_addr == CSR.instrethw:
                        instreth.I @= wdata
Beispiel #6
0
    def __init__(self,
                 x_len,
                 ALU=ALUArea,
                 ImmGen=ImmGenWire,
                 BrCond=BrCondArea):
        """Build the three-stage (fetch / execute / write-back) datapath.

        Args:
            x_len: Data width in bits, forwarded to every sub-block.
            ALU: Generator used for the ALU; instantiated as
                ``ALU(x_len)()``.
            ImmGen: Generator used for the immediate decoder.
            BrCond: Generator used for the branch-condition unit.
        """
        self.io = make_DatapathIO(x_len) + m.ClockIO(has_reset=True)
        csr = CSRGen(x_len)()
        reg_file = RegFile(x_len)()
        alu = ALU(x_len)()
        imm_gen = ImmGen(x_len)()
        # Use the caller-supplied generator.  This previously instantiated
        # BrCondArea unconditionally, silently ignoring the BrCond argument.
        br_cond = BrCond(x_len)()

        # Fetch / Execute Registers
        fe_inst = m.Register(init=Instructions.NOP, has_enable=True)()
        fe_pc = m.Register(m.UInt[x_len], has_enable=True)()

        # Execute / Write Back Registers
        ew_inst = m.Register(init=Instructions.NOP)()
        ew_pc = m.Register(m.UInt[x_len])()
        ew_alu = m.Register(m.UInt[x_len])()
        csr_in = m.Register(m.UInt[x_len])()

        # Control signals
        # Registered copies of the decoder outputs, consumed in write-back.
        st_type = m.Register(type(self.io.ctrl.st_type).undirected_t)()
        ld_type = m.Register(type(self.io.ctrl.ld_type).undirected_t)()
        wb_sel = m.Register(type(self.io.ctrl.wb_sel).undirected_t)()
        wb_en = m.Register(m.Bit)()
        csr_cmd = m.Register(type(self.io.ctrl.csr_cmd).undirected_t)()
        illegal = m.Register(m.Bit)()
        pc_check = m.Register(m.Bit)()

        # Fetch
        # `started` is RESET delayed by one cycle.
        started = m.Register(m.Bit)()(m.bit(self.io.RESET))
        # Stall whenever either cache has no valid response.
        stall = ~self.io.icache.resp.valid | ~self.io.dcache.resp.valid
        # The pc resets to PC_START - 4 (presumably so that the first
        # pc + 4 fetch lands on PC_START).
        pc = m.Register(init=UIntVector[x_len](Const.PC_START) -
                        UIntVector[x_len](4))()
        # Next pc; the outermost mux wins, giving the priority
        # stall > exception > eret > jump / taken branch > hold > pc + 4.
        npc = m.mux([
            m.mux([
                m.mux([
                    m.mux([
                        m.mux([pc.O + m.uint(4, x_len), pc.O],
                              self.io.ctrl.pc_sel == PC_0), alu.sum_ >> 1 << 1
                    ], (self.io.ctrl.pc_sel == PC_ALU) | br_cond.taken),
                    csr.epc
                ], self.io.ctrl.pc_sel == PC_EPC), csr.evec
            ], csr.expt), pc.O
        ], stall)

        # The fetched instruction is replaced by a NOP right after reset, on
        # a kill request, on a taken branch, or on an exception.
        inst = m.mux([self.io.icache.resp.data.data, Instructions.NOP], started
                     | self.io.ctrl.inst_kill | br_cond.taken | csr.expt)

        pc.I @= npc
        self.io.icache.req.data.addr @= npc
        self.io.icache.req.data.data @= 0
        self.io.icache.req.data.mask @= 0
        self.io.icache.req.valid @= ~stall
        self.io.icache.abort @= False

        fe_pc.I @= pc.O
        fe_pc.CE @= m.enable(~stall)
        fe_inst.I @= inst
        fe_inst.CE @= m.enable(~stall)

        # Execute
        # Decode
        self.io.ctrl.inst @= fe_inst.O

        # reg_file read
        rs1_addr = fe_inst.O[15:20]
        rs2_addr = fe_inst.O[20:25]
        reg_file.raddr1 @= rs1_addr
        reg_file.raddr2 @= rs2_addr

        # gen immediates
        imm_gen.inst @= fe_inst.O
        imm_gen.sel @= self.io.ctrl.imm_sel

        # bypass
        # Forward the write-back ALU result when it targets a non-zero source
        # register that the executing instruction reads.
        wb_rd_addr = ew_inst.O[7:12]
        rs1_hazard = wb_en.O & rs1_addr.reduce_or() & (rs1_addr == wb_rd_addr)
        rs2_hazard = wb_en.O & rs2_addr.reduce_or() & (rs2_addr == wb_rd_addr)
        rs1 = m.mux([reg_file.rdata1, ew_alu.O],
                    (wb_sel.O == WB_ALU) & rs1_hazard)
        rs2 = m.mux([reg_file.rdata2, ew_alu.O],
                    (wb_sel.O == WB_ALU) & rs2_hazard)

        # ALU operations
        alu.A @= m.mux([fe_pc.O, rs1], self.io.ctrl.A_sel == A_RS1)
        alu.B @= m.mux([imm_gen.O, rs2], self.io.ctrl.B_sel == B_RS2)
        alu.op @= self.io.ctrl.alu_op

        # Branch condition calc
        br_cond.rs1 @= rs1
        br_cond.rs2 @= rs2
        br_cond.br_type @= self.io.ctrl.br_type

        # D$ access
        # Word-aligned address; while stalled, the registered address is
        # replayed instead of the fresh ALU sum.
        daddr = m.mux([alu.sum_, ew_alu.O], stall) >> 2 << 2
        # Bit offset of the addressed byte within the word (byte offset * 8).
        w_offset = ((m.bits(alu.sum_[1], x_len) << 4) |
                    (m.bits(alu.sum_[0], x_len) << 3))
        self.io.dcache.req.valid @= ~stall & (self.io.ctrl.st_type.reduce_or()
                                              |
                                              self.io.ctrl.ld_type.reduce_or())
        self.io.dcache.req.data.addr @= daddr
        self.io.dcache.req.data.data @= rs2 << w_offset
        # Byte-enable mask per store type; all zeros when not storing.
        self.io.dcache.req.data.mask @= m.dict_lookup(
            {
                ST_SW: m.bits(0b1111, 4),
                ST_SH: m.bits(0b11, 4) << m.zext(alu.sum_[0:2], 2),
                ST_SB: m.bits(0b1, 4) << m.zext(alu.sum_[0:2], 2),
            }, m.mux([self.io.ctrl.st_type, st_type.O], stall), m.bits(0, 4))

        # Pipelining
        # Every register first defaults to holding its value; the branches
        # below override that, so statement order is significant.
        @m.inline_combinational()
        def pipeline_logic():
            ew_pc.I @= ew_pc.O
            ew_inst.I @= ew_inst.O
            ew_alu.I @= ew_alu.O
            csr_in.I @= csr_in.O
            st_type.I @= st_type.O
            ld_type.I @= ld_type.O
            wb_sel.I @= wb_sel.O
            wb_en.I @= wb_en.O
            csr_cmd.I @= csr_cmd.O
            illegal.I @= illegal.O
            pc_check.I @= pc_check.O
            if m.bit(self.io.RESET) | ~stall & csr.expt:
                # Reset or exception: clear the registered control signals.
                st_type.I @= 0
                ld_type.I @= 0
                wb_en.I @= 0
                csr_cmd.I @= 0
                illegal.I @= False
                pc_check.I @= False
            elif ~stall & ~csr.expt:
                # Normal advance: latch the execute-stage results.
                ew_pc.I @= fe_pc.O
                ew_inst.I @= fe_inst.O
                ew_alu.I @= alu.O
                csr_in.I @= m.mux([rs1, imm_gen.O],
                                  self.io.ctrl.imm_sel == IMM_Z)
                st_type.I @= self.io.ctrl.st_type
                ld_type.I @= self.io.ctrl.ld_type
                wb_sel.I @= self.io.ctrl.wb_sel
                wb_en.I @= self.io.ctrl.wb_en
                csr_cmd.I @= self.io.ctrl.csr_cmd
                illegal.I @= self.io.ctrl.illegal
                pc_check.I @= self.io.ctrl.pc_sel == PC_ALU

        # Load
        # Shift the addressed byte/half-word down to bit 0, then extend it
        # according to the load type; any other type takes the whole word.
        l_offset = ((m.uint(ew_alu.O[1], x_len) << 4) |
                    (m.uint(ew_alu.O[0], x_len) << 3))
        l_shift = self.io.dcache.resp.data.data >> l_offset
        load = m.dict_lookup(
            {
                LD_LH: m.sext_to(m.sint(l_shift[0:16]), x_len),
                LD_LHU: m.sint(m.zext_to(l_shift[0:16], x_len)),
                LD_LB: m.sext_to(m.sint(l_shift[0:8]), x_len),
                LD_LBU: m.sint(m.zext_to(l_shift[0:8], x_len))
            }, ld_type.O, m.sint(self.io.dcache.resp.data.data))

        # CSR access
        csr.stall @= stall
        csr.I @= csr_in.O
        csr.cmd @= csr_cmd.O
        csr.inst @= ew_inst.O
        csr.pc @= ew_pc.O
        csr.addr @= ew_alu.O
        csr.illegal @= illegal.O
        csr.pc_check @= pc_check.O
        csr.ld_type @= ld_type.O
        csr.st_type @= st_type.O
        self.io.host @= csr.host

        # Regfile write
        # The ALU result is written back unless wb_sel picks another source.
        reg_write = m.dict_lookup(
            {
                WB_MEM: m.uint(load),
                WB_PC4: (ew_pc.O + 4),
                WB_CSR: csr.O
            }, wb_sel.O, ew_alu.O)

        reg_file.wen @= m.enable(wb_en.O & ~stall & ~csr.expt)
        reg_file.waddr @= wb_rd_addr
        reg_file.wdata @= reg_write

        # Abort store when there's an exception
        self.io.dcache.abort @= csr.expt