Esempio n. 1
0
class IType:
    def __init__(self, prefix=""):
        if prefix:
            prefix += "_"
        self.opcode = Signal(7, name=f"{prefix}opcode")
        self.rd = Signal(5, name=f"{prefix}rd")
        self.funct3 = Signal(3, name=f"{prefix}funct3")
        self.rs1 = Signal(5, name=f"{prefix}rs1")
        self.imm = Signal(32, name=f"{prefix}imm12")

    def elaborate(self, comb: List[Statement], input: Signal):
        comb += self.opcode.eq(input[0:7])
        comb += self.rd.eq(input[7:12])
        comb += self.funct3.eq(input[12:15])
        comb += self.rs1.eq(input[15:20])
        comb += self.imm.eq(Cat(input[20:32], Repl(input[31], 20)))

    def match(self,
              opcode=None,
              rd=None,
              funct3=None,
              rs1=None,
              imm=None) -> Value:
        """ Build boolean expression that matches x against provided parts """
        if type(imm) == int:
            assert imm.bit_length(
            ) <= 32, "imm must be 32 bit long(12 bits+signext)"
            imm = imm & (2**32) - 1
        subexpressions = []
        if opcode is not None:
            subexpressions.append(self.opcode.matches(opcode))
        if rd is not None: subexpressions.append(self.rd.matches(rd))
        if funct3 is not None:
            subexpressions.append(self.funct3.matches(funct3))
        if rs1 is not None: subexpressions.append(self.rs1.matches(rs1))
        if imm is not None: subexpressions.append(self.imm.matches(imm))

        if not subexpressions:
            print("warning: no matches provided for itype.match")
            return Const(1)
        res = subexpressions.pop(0)
        while subexpressions:
            res = res & subexpressions.pop(0)
        return res

    @staticmethod
    def build_i32(opcode: int = 0,
                  rd: int = 0,
                  funct3: int = 0,
                  rs1: int = 0,
                  imm: int = 0,
                  ensure_ints=True) -> int:
        if type(imm) == int:
            assert -2**12 <= imm < 2**12
        imm = imm & ((2**12) - 1)

        word = opcode | (rd << 7) | (funct3 << 12) | (rs1 << 15) | (imm << 20)
        return word
Esempio n. 2
0
class JType:
    def __init__(self, prefix=""):
        if prefix:
            prefix += "_"
        self.opcode = Signal(7, name=f"{prefix}opcode")
        self.rd = Signal(5, name=f"{prefix}rd")
        self.imm = Signal(32, name=f"{prefix}imm")

    def elaborate(self, comb: List[Statement], input: Signal):
        comb += self.opcode.eq(input[0:7])
        comb += self.rd.eq(input[7:12])
        comb += self.imm.eq(
            Cat(Const(0, 1), input[21:31], input[20], input[12:20],
                Repl(input[31], 12)))

    def match(self, opcode=None, rd=None, imm=None) -> Value:
        """ Build boolean expression that matches x against provided parts """
        if type(imm) == int:
            assert imm.bit_length(
            ) <= 32, "imm must be 32 bit long(12 bits+signext)"
            assert imm % 2 == 0, "jtype has 2-byte offset"
            imm = imm & (2**32) - 1
        subexpressions = []
        if opcode is not None:
            subexpressions.append(self.opcode.matches(opcode))
        if rd is not None: subexpressions.append(self.rd.matches(rd))
        if imm is not None: subexpressions.append(self.imm.matches(imm))

        if not subexpressions:
            print("warning: no matches provided for jtype.match")
            return Const(1)
        res = subexpressions.pop(0)
        while subexpressions:
            res = res & subexpressions.pop(0)
        return res

    @staticmethod
    def build_i32(opcode: int = 0, rd: int = 0, imm: int = 0) -> int:
        if type(imm) == int:
            assert -2**21 <= imm < 2**21
            assert imm % 2 == 0
        imm = imm & ((2**21) - 1)

        word = (opcode) | (rd << 7) | (bit_slice(imm, 10, 1) << 21) | (
            bit_slice(imm, 11, 11) << 20) | (bit_slice(imm, 19, 12) << 12)
        word = word | (bit_slice(imm, 20, 20) << 31)
        return word
Esempio n. 3
0
class UType:
    def __init__(self, prefix=""):
        if prefix:
            prefix += "_"
        self.opcode = Signal(7, name=f"{prefix}opcode")
        self.rd = Signal(5, name=f"{prefix}rd")
        self.imm = Signal(32, name=f"{prefix}imm")

    def elaborate(self, comb: List[Statement], input: Signal):
        comb += self.opcode.eq(input[0:7])
        comb += self.rd.eq(input[7:12])
        comb += self.imm.eq(Cat(Const(0, 12), input[12:32]))

    @staticmethod
    def build_i32(opcode: int = 0, rd: int = 0, imm: int = 0) -> int:
        if type(imm) == int:
            assert imm.bit_length(
            ) <= 32, "imm must be 32 bit long(12 bits+signext)"
            assert imm & ((1 << 12) - 1) == 0, "lower 12 bits must be zero"
        imm = imm & 0xFFFFFFFF

        word = (opcode) | (rd << 7) | (bit_slice(imm, 31, 12) << 12)
        return word

    def match(self, opcode=None, rd=None, imm=None) -> Value:
        """ Build boolean expression that matches x against provided parts """
        if type(imm) == int:
            assert imm.bit_length(
            ) <= 32, "imm must be 32 bit long(12 bits+signext)"
            assert imm & ((1 << 12) - 1) == 0, "lower 12 bits must be zero"
            imm = imm & (2**32) - 1
        subexpressions = []
        if opcode is not None:
            subexpressions.append(self.opcode.matches(opcode))
        if rd is not None: subexpressions.append(self.rd.matches(rd))
        if imm is not None: subexpressions.append(self.imm.matches(imm))

        if not subexpressions:
            print("warning: no matches provided for utype.match")
            return Const(1)
        res = subexpressions.pop(0)
        while subexpressions:
            res = res & subexpressions.pop(0)
        return res
Esempio n. 4
0
class SequencerROM(Elaboratable):
    """ROM for the sequencer card state machine."""
    def __init__(self):
        # Control line
        self.enable_sequencer_rom = Signal()

        # Inputs: 9 + 11 decoder
        # Since this is implemented by a ROM, the address lines
        # must be stable in order for the outputs to start becoming
        # stable. This means that if any input address depends on
        # any output data combinatorically, there's a danger of
        # going unstable. Therefore, all address lines must be
        # registered, or come combinatorically from registered data.

        self.memaddr_2_lsb = Signal(2)
        self.branch_cond = Signal()
        self._instr_phase = Signal(2)
        # Only used on instruction phase 1 in BRANCH.
        self.data_z_in_2_lsb0 = Signal()
        self.imm0 = Signal()
        self.rd0 = Signal()
        self.rs1_0 = Signal()

        # Instruction decoding
        self.opcode_select = Signal(OpcodeSelect)  # 4 bits
        self._funct3 = Signal(3)
        self._alu_func = Signal(4)

        ##############
        # Outputs (66 bits total)
        ##############

        # Raised on the last phase of an instruction.
        self.set_instr_complete = Signal()

        # Raised when the exception card should store trap data.
        self.save_trap_csrs = Signal()

        # CSR lines
        self.csr_to_x = Signal()
        self.z_to_csr = Signal()

        # Memory
        self.mem_rd = Signal(reset=1)
        self.mem_wr = Signal()
        # Bytes in memory word to write
        self.mem_wr_mask = Signal(4)

        self._next_instr_phase = Signal(2)

        self._x_reg_select = Signal(InstrReg)  # 2 bits
        self._y_reg_select = Signal(InstrReg)  # 2 bits
        self._z_reg_select = Signal(InstrReg)  # 2 bits

        # -> X
        self.x_mux_select = Signal(SeqMuxSelect)
        self.reg_to_x = Signal()

        # -> Y
        self.y_mux_select = Signal(SeqMuxSelect)
        self.reg_to_y = Signal()

        # -> Z
        self.z_mux_select = Signal(SeqMuxSelect)
        self.alu_op_to_z = Signal(AluOp)  # 4 bits

        # -> PC
        self.pc_mux_select = Signal(SeqMuxSelect)

        # -> tmp
        self.tmp_mux_select = Signal(SeqMuxSelect)

        # -> csr_num
        self._funct12_to_csr_num = Signal()
        self._mepc_num_to_csr_num = Signal()
        self._mcause_to_csr_num = Signal()

        # -> memaddr
        self.memaddr_mux_select = Signal(SeqMuxSelect)

        # -> memdata
        self.memdata_wr_mux_select = Signal(SeqMuxSelect)

        self._const = Signal(ConstSelect)  # select: 4 bits

        self.enter_trap = Signal()
        self.exit_trap = Signal()

        # Signals for next registers
        self.load_trap = Signal()
        self.next_trap = Signal()
        self.load_exception = Signal()
        self.next_exception = Signal()
        self.next_fatal = Signal()

    def elaborate(self, _: Platform) -> Module:
        """Implements the logic of the sequencer card."""
        m = Module()

        # Defaults
        m.d.comb += [
            self._next_instr_phase.eq(0),
            self.reg_to_x.eq(0),
            self.reg_to_y.eq(0),
            self.alu_op_to_z.eq(AluOp.NONE),
            self.mem_rd.eq(0),
            self.mem_wr.eq(0),
            self.mem_wr_mask.eq(0),
            self.csr_to_x.eq(0),
            self.z_to_csr.eq(0),
            self._funct12_to_csr_num.eq(0),
            self._mepc_num_to_csr_num.eq(0),
            self._mcause_to_csr_num.eq(0),
            self._x_reg_select.eq(0),
            self._y_reg_select.eq(0),
            self._z_reg_select.eq(0),
            self.enter_trap.eq(0),
            self.exit_trap.eq(0),
            self.save_trap_csrs.eq(0),
            self.pc_mux_select.eq(SeqMuxSelect.PC),
            self.memaddr_mux_select.eq(SeqMuxSelect.MEMADDR),
            self.memdata_wr_mux_select.eq(SeqMuxSelect.MEMDATA_WR),
            self.tmp_mux_select.eq(SeqMuxSelect.TMP),
            self.x_mux_select.eq(SeqMuxSelect.X),
            self.y_mux_select.eq(SeqMuxSelect.Y),
            self.z_mux_select.eq(SeqMuxSelect.Z),
            self._const.eq(0),
        ]

        m.d.comb += [
            self.load_trap.eq(0),
            self.next_trap.eq(0),
            self.load_exception.eq(0),
            self.next_exception.eq(0),
            self.next_fatal.eq(0),
        ]

        with m.If(self.enable_sequencer_rom):

            # Output control signals
            with m.Switch(self.opcode_select):
                with m.Case(OpcodeSelect.LUI):
                    self.handle_lui(m)

                with m.Case(OpcodeSelect.AUIPC):
                    self.handle_auipc(m)

                with m.Case(OpcodeSelect.OP_IMM):
                    self.handle_op_imm(m)

                with m.Case(OpcodeSelect.OP):
                    self.handle_op(m)

                with m.Case(OpcodeSelect.JAL):
                    self.handle_jal(m)

                with m.Case(OpcodeSelect.JALR):
                    self.handle_jalr(m)

                with m.Case(OpcodeSelect.BRANCH):
                    self.handle_branch(m)

                with m.Case(OpcodeSelect.LOAD):
                    self.handle_load(m)

                with m.Case(OpcodeSelect.STORE):
                    self.handle_store(m)

                with m.Case(OpcodeSelect.CSRS):
                    self.handle_csrs(m)

                with m.Case(OpcodeSelect.MRET):
                    self.handle_MRET(m)

                with m.Case(OpcodeSelect.ECALL):
                    self.handle_ECALL(m)

                with m.Case(OpcodeSelect.EBREAK):
                    self.handle_EBREAK(m)

                with m.Default():
                    self.handle_illegal_instr(m)

        return m

    def next_instr(self, m: Module, next_pc: NextPC = NextPC.PC_PLUS_4):
        """Sets signals to advance to the next instruction.

        next_pc is the signal to load the PC and MEMADDR registers with
        at the end of the instruction cycle.
        """
        m.d.comb += self.set_instr_complete.eq(1)
        if next_pc == NextPC.PC_PLUS_4:
            m.d.comb += self.pc_mux_select.eq(SeqMuxSelect.PC_PLUS_4)
            m.d.comb += self.memaddr_mux_select.eq(SeqMuxSelect.PC_PLUS_4)
        elif next_pc == NextPC.MEMADDR:
            m.d.comb += self.pc_mux_select.eq(SeqMuxSelect.MEMADDR)
        elif next_pc == NextPC.MEMADDR_NO_LSB:
            m.d.comb += self.pc_mux_select.eq(SeqMuxSelect.MEMADDR_LSB_MASKED)
        elif next_pc == NextPC.Z:
            m.d.comb += self.pc_mux_select.eq(SeqMuxSelect.Z)
            m.d.comb += self.memaddr_mux_select.eq(SeqMuxSelect.Z)
        elif next_pc == NextPC.X:
            m.d.comb += self.pc_mux_select.eq(SeqMuxSelect.X)
            m.d.comb += self.memaddr_mux_select.eq(SeqMuxSelect.X)

    def set_exception(self,
                      m: Module,
                      exc: ConstSelect,
                      mtval: SeqMuxSelect,
                      fatal: bool = True):
        m.d.comb += self.load_exception.eq(1)
        m.d.comb += self.next_exception.eq(1)
        m.d.comb += self.next_fatal.eq(1 if fatal else 0)

        m.d.comb += self._const.eq(exc)
        m.d.comb += self.x_mux_select.eq(SeqMuxSelect.CONST)
        m.d.comb += self.z_mux_select.eq(mtval)

        if fatal:
            m.d.comb += self.y_mux_select.eq(SeqMuxSelect.PC)
        else:
            m.d.comb += self.y_mux_select.eq(SeqMuxSelect.PC_PLUS_4)

        # X -> MCAUSE, Y -> MEPC, Z -> MTVAL
        m.d.comb += self.save_trap_csrs.eq(1)
        m.d.comb += self.load_trap.eq(1)
        m.d.comb += self.next_trap.eq(1)
        m.d.comb += self._next_instr_phase.eq(0)

    def handle_illegal_instr(self, m: Module):
        self.set_exception(m,
                           ConstSelect.EXC_ILLEGAL_INSTR,
                           mtval=SeqMuxSelect.INSTR)

    def handle_lui(self, m: Module):
        """Adds the LUI logic to the given module.

        rd <- r0 + imm
        PC <- PC + 4

        r0      -> X
        imm     -> Y
        ALU ADD -> Z
        Z       -> rd
        PC + 4  -> PC
        PC + 4  -> memaddr
        """
        m.d.comb += [
            self.reg_to_x.eq(1),
            self._x_reg_select.eq(InstrReg.ZERO),
            self.y_mux_select.eq(SeqMuxSelect.IMM),
            self.alu_op_to_z.eq(AluOp.ADD),
            self._z_reg_select.eq(InstrReg.RD),
        ]
        self.next_instr(m)

    def handle_auipc(self, m: Module):
        """Adds the AUIPC logic to the given module.

        rd <- PC + imm
        PC <- PC + 4

        PC      -> X
        imm     -> Y
        ALU ADD -> Z
        Z       -> rd
        PC + 4  -> PC
        PC + 4  -> memaddr
        """
        m.d.comb += [
            self.x_mux_select.eq(SeqMuxSelect.PC),
            self.y_mux_select.eq(SeqMuxSelect.IMM),
            self.alu_op_to_z.eq(AluOp.ADD),
            self._z_reg_select.eq(InstrReg.RD),
        ]
        self.next_instr(m)

    def handle_op_imm(self, m: Module):
        """Adds the OP_IMM logic to the given module.

        rd <- rs1 op imm
        PC <- PC + 4

        rs1     -> X
        imm     -> Y
        ALU op  -> Z
        Z       -> rd
        PC + 4  -> PC
        PC + 4  -> memaddr
        """
        with m.If(~self._alu_func.matches(
                AluFunc.ADD, AluFunc.SUB, AluFunc.SLL, AluFunc.SLT,
                AluFunc.SLTU, AluFunc.XOR, AluFunc.SRL, AluFunc.SRA,
                AluFunc.OR, AluFunc.AND)):
            self.handle_illegal_instr(m)
        with m.Else():
            m.d.comb += [
                self.reg_to_x.eq(1),
                self._x_reg_select.eq(InstrReg.RS1),
                self.y_mux_select.eq(SeqMuxSelect.IMM),
                self._z_reg_select.eq(InstrReg.RD),
            ]
            with m.Switch(self._alu_func):
                with m.Case(AluFunc.ADD):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.ADD)
                with m.Case(AluFunc.SUB):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SUB)
                with m.Case(AluFunc.SLL):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SLL)
                with m.Case(AluFunc.SLT):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SLT)
                with m.Case(AluFunc.SLTU):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SLTU)
                with m.Case(AluFunc.XOR):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.XOR)
                with m.Case(AluFunc.SRL):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SRL)
                with m.Case(AluFunc.SRA):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SRA)
                with m.Case(AluFunc.OR):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.OR)
                with m.Case(AluFunc.AND):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.AND)
            self.next_instr(m)

    def handle_op(self, m: Module):
        """Adds the OP logic to the given module.

        rd <- rs1 op rs2
        PC <- PC + 4

        rs1     -> X
        rs2     -> Y
        ALU op  -> Z
        Z       -> rd
        PC + 4  -> PC
        PC + 4  -> memaddr
        """
        with m.If(~self._alu_func.matches(
                AluFunc.ADD, AluFunc.SUB, AluFunc.SLL, AluFunc.SLT,
                AluFunc.SLTU, AluFunc.XOR, AluFunc.SRL, AluFunc.SRA,
                AluFunc.OR, AluFunc.AND)):
            self.handle_illegal_instr(m)
        with m.Else():
            m.d.comb += [
                self.reg_to_x.eq(1),
                self._x_reg_select.eq(InstrReg.RS1),
                self.reg_to_y.eq(1),
                self._y_reg_select.eq(InstrReg.RS2),
                self._z_reg_select.eq(InstrReg.RD),
            ]
            with m.Switch(self._alu_func):
                with m.Case(AluFunc.ADD):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.ADD)
                with m.Case(AluFunc.SUB):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SUB)
                with m.Case(AluFunc.SLL):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SLL)
                with m.Case(AluFunc.SLT):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SLT)
                with m.Case(AluFunc.SLTU):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SLTU)
                with m.Case(AluFunc.XOR):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.XOR)
                with m.Case(AluFunc.SRL):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SRL)
                with m.Case(AluFunc.SRA):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.SRA)
                with m.Case(AluFunc.OR):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.OR)
                with m.Case(AluFunc.AND):
                    m.d.comb += self.alu_op_to_z.eq(AluOp.AND)
            self.next_instr(m)

    def handle_jal(self, m: Module):
        """Adds the JAL logic to the given module.

        rd <- PC + 4, PC <- PC + imm

        PC      -> X
        imm     -> Y
        ALU ADD -> Z
        Z       -> memaddr
        ---------------------
        PC + 4  -> Z
        Z       -> rd
        memaddr -> PC   # This will zero the least significant bit

        Note that because the immediate value for JAL has its least
        significant bit set to zero by definition, and the PC is also
        assumed to be aligned, there is no loss in generality to clear
        the least significant bit when transferring memaddr to PC.
        """
        with m.If(self._instr_phase == 0):
            m.d.comb += [
                self.x_mux_select.eq(SeqMuxSelect.PC),
                self.y_mux_select.eq(SeqMuxSelect.IMM),
                self.alu_op_to_z.eq(AluOp.ADD),
                self.memaddr_mux_select.eq(SeqMuxSelect.Z),
                self._next_instr_phase.eq(1),
            ]
        with m.Else():
            with m.If(self.memaddr_2_lsb[1] != 0):
                self.set_exception(m,
                                   ConstSelect.EXC_INSTR_ADDR_MISALIGN,
                                   mtval=SeqMuxSelect.MEMADDR)
            with m.Else():
                m.d.comb += [
                    self.z_mux_select.eq(SeqMuxSelect.PC_PLUS_4),
                    self._z_reg_select.eq(InstrReg.RD),
                ]
                self.next_instr(m, NextPC.MEMADDR_NO_LSB)

    def handle_jalr(self, m: Module):
        """Adds the JALR logic to the given module.

        rd <- PC + 4, PC <- (rs1 + imm) & 0xFFFFFFFE

        rs1     -> X
        imm     -> Y
        ALU ADD -> Z
        Z       -> memaddr
        ---------------------
        PC + 4  -> Z
        Z       -> rd
        memaddr -> PC  # This will zero the least significant bit
        """
        with m.If(self._instr_phase == 0):
            m.d.comb += [
                self.reg_to_x.eq(1),
                self._x_reg_select.eq(InstrReg.RS1),
                self.y_mux_select.eq(SeqMuxSelect.IMM),
                self.alu_op_to_z.eq(AluOp.ADD),
                self.memaddr_mux_select.eq(SeqMuxSelect.Z),
                self._next_instr_phase.eq(1),
            ]
        with m.Else():
            with m.If(self.memaddr_2_lsb[1] != 0):
                self.set_exception(m,
                                   ConstSelect.EXC_INSTR_ADDR_MISALIGN,
                                   mtval=SeqMuxSelect.MEMADDR_LSB_MASKED)
            with m.Else():
                m.d.comb += [
                    self.z_mux_select.eq(SeqMuxSelect.PC_PLUS_4),
                    self._z_reg_select.eq(InstrReg.RD),
                ]
                self.next_instr(m, NextPC.MEMADDR_NO_LSB)

    def handle_branch(self, m: Module):
        """Adds the BRANCH logic to the given module.

        cond <- rs1 - rs2 < 0, rs1 - rs2 == 0
        if f(cond):
            PC <- PC + imm
        else:
            PC <- PC + 4

        rs1     -> X
        rs2     -> Y
        ALU SUB -> Z, cond
        --------------------- cond == 1
        PC      -> X
        imm/4   -> Y (imm for cond == 1, 4 otherwise)
        ALU ADD -> Z
        Z       -> PC
        Z       -> memaddr
        --------------------- cond == 0
        PC + 4  -> PC
        PC + 4  -> memaddr
        """
        with m.If(self._instr_phase == 0):
            m.d.comb += [
                self.reg_to_x.eq(1),
                self._x_reg_select.eq(InstrReg.RS1),
                self.reg_to_y.eq(1),
                self._y_reg_select.eq(InstrReg.RS2),
                self.alu_op_to_z.eq(AluOp.SUB),
                self._next_instr_phase.eq(1),
            ]
        with m.Elif(self._instr_phase == 1):
            with m.If(~self._funct3.matches(BranchCond.EQ, BranchCond.NE,
                                            BranchCond.LT, BranchCond.GE,
                                            BranchCond.LTU, BranchCond.GEU)):
                self.handle_illegal_instr(m)

            with m.Else():
                with m.If(self.branch_cond):
                    m.d.comb += self.y_mux_select.eq(SeqMuxSelect.IMM)
                with m.Else():
                    m.d.comb += self._const.eq(ConstSelect.SHAMT_4)
                    m.d.comb += self.y_mux_select.eq(SeqMuxSelect.CONST)

                m.d.comb += [
                    self.x_mux_select.eq(SeqMuxSelect.PC),
                    self.alu_op_to_z.eq(AluOp.ADD),
                ]

                with m.If(self.data_z_in_2_lsb0):
                    self.next_instr(m, NextPC.Z)

                with m.Else():
                    m.d.comb += self._next_instr_phase.eq(2)
                    m.d.comb += self.tmp_mux_select.eq(SeqMuxSelect.Z)

        with m.Else():
            self.set_exception(m,
                               ConstSelect.EXC_INSTR_ADDR_MISALIGN,
                               mtval=SeqMuxSelect.TMP)

    def handle_load(self, m: Module):
        """Adds the LOAD logic to the given module.

        Note that byte loads are byte-aligned, half-word loads
        are 16-bit aligned, and word loads are 32-bit aligned.
        Attempting to load unaligned will lead to undefined
        behavior.

        Operation is to load 32 bits from a 32-bit aligned
        address, and then perform at most two shifts to get
        the desired behavior: a shift left to get the most
        significant byte into the leftmost position, then a
        shift right to zero or sign extend the value.

        For example, for loading a half-word starting at
        address A where A%4=0, we first load the full 32
        bits at that address, resulting in XYHL, where X and
        Y are unwanted and H and L are the half-word we want
        to load. Then we shift left by 16: HL00. And finally
        we shift right by 16, either signed or unsigned
        depending on whether we are doing an LH or an LHU:
        ssHL / 00HL.

        addr <- rs1 + imm
        rd <- data at addr, possibly sign-extended
        PC <- PC + 4

        If we let N be addr%4, then:

        instr   N   shift1  shift2
        --------------------------
        LB      0   SLL 24  SRA 24
        LB      1   SLL 16  SRA 24
        LB      2   SLL  8  SRA 24
        LB      3   SLL  0  SRA 24
        LBU     0   SLL 24  SRL 24
        LBU     1   SLL 16  SRL 24
        LBU     2   SLL  8  SRL 24
        LBU     3   SLL  0  SRL 24
        LH      0   SLL 16  SRA 16
        LH      2   SLL  0  SRA 16
        LHU     0   SLL 16  SRL 16
        LHU     2   SLL  0  SRL 16
        LW      0   SLL  0  SRA  0
        (all other N are misaligned accesses)

        Where there is an SLL 0, the machine cycle
        could be skipped, but in the interests of
        simpler logic, we will not do that.

        rs1     -> X
        imm     -> Y
        ALU ADD -> Z
        Z       -> memaddr
        ---------------------
        memdata -> X
        shamt1  -> Y
        ALU SLL -> Z
        Z       -> rd
        ---------------------
        rd          -> X
        shamt2      -> Y
        ALU SRA/SRL -> Z
        Z           -> rd
        PC + 4      -> PC
        PC + 4      -> memaddr
        """
        with m.If(self._instr_phase == 0):
            m.d.comb += [
                self.reg_to_x.eq(1),
                self._x_reg_select.eq(InstrReg.RS1),
                self.y_mux_select.eq(SeqMuxSelect.IMM),
                self.alu_op_to_z.eq(AluOp.ADD),
                self.memaddr_mux_select.eq(SeqMuxSelect.Z),
                self._next_instr_phase.eq(1),
            ]

        with m.Elif(self._instr_phase == 1):
            # Check for exception conditions first
            with m.If(
                    self._funct3.matches(MemAccessWidth.H, MemAccessWidth.HU)
                    & self.memaddr_2_lsb[0]):
                self.set_exception(m,
                                   ConstSelect.EXC_LOAD_ADDR_MISALIGN,
                                   mtval=SeqMuxSelect.MEMADDR)

            with m.Elif((self._funct3 == MemAccessWidth.W)
                        & (self.memaddr_2_lsb != 0)):
                self.set_exception(m,
                                   ConstSelect.EXC_LOAD_ADDR_MISALIGN,
                                   mtval=SeqMuxSelect.MEMADDR)

            with m.Elif(~self._funct3.matches(
                    MemAccessWidth.B, MemAccessWidth.BU, MemAccessWidth.H,
                    MemAccessWidth.HU, MemAccessWidth.W)):
                self.handle_illegal_instr(m)

            with m.Else():
                m.d.comb += [
                    self.mem_rd.eq(1),
                    self.x_mux_select.eq(SeqMuxSelect.MEMDATA_RD),
                    self.y_mux_select.eq(SeqMuxSelect.CONST),
                    self.alu_op_to_z.eq(AluOp.SLL),
                    self._z_reg_select.eq(InstrReg.RD),
                    self._next_instr_phase.eq(2),
                ]

                with m.Switch(self._funct3):

                    with m.Case(MemAccessWidth.B, MemAccessWidth.BU):
                        with m.Switch(self.memaddr_2_lsb):
                            with m.Case(0):
                                m.d.comb += self._const.eq(
                                    ConstSelect.SHAMT_24)
                            with m.Case(1):
                                m.d.comb += self._const.eq(
                                    ConstSelect.SHAMT_16)
                            with m.Case(2):
                                m.d.comb += self._const.eq(ConstSelect.SHAMT_8)
                            with m.Case(3):
                                m.d.comb += self._const.eq(ConstSelect.SHAMT_0)

                    with m.Case(MemAccessWidth.H, MemAccessWidth.HU):
                        with m.Switch(self.memaddr_2_lsb):
                            with m.Case(0):
                                m.d.comb += self._const.eq(
                                    ConstSelect.SHAMT_16)
                            with m.Case(2):
                                m.d.comb += self._const.eq(ConstSelect.SHAMT_0)

                    with m.Case(MemAccessWidth.W):
                        m.d.comb += self._const.eq(ConstSelect.SHAMT_0)

        with m.Else():
            m.d.comb += [
                self.reg_to_x.eq(1),
                self._x_reg_select.eq(InstrReg.RD),
                self.y_mux_select.eq(SeqMuxSelect.CONST),
                self._z_reg_select.eq(InstrReg.RD),
            ]

            with m.Switch(self._funct3):
                with m.Case(MemAccessWidth.B):
                    m.d.comb += [
                        self._const.eq(ConstSelect.SHAMT_24),
                        self.alu_op_to_z.eq(AluOp.SRA),
                    ]
                with m.Case(MemAccessWidth.BU):
                    m.d.comb += [
                        self._const.eq(ConstSelect.SHAMT_24),
                        self.alu_op_to_z.eq(AluOp.SRL),
                    ]
                with m.Case(MemAccessWidth.H):
                    m.d.comb += [
                        self._const.eq(ConstSelect.SHAMT_16),
                        self.alu_op_to_z.eq(AluOp.SRA),
                    ]
                with m.Case(MemAccessWidth.HU):
                    m.d.comb += [
                        self._const.eq(ConstSelect.SHAMT_16),
                        self.alu_op_to_z.eq(AluOp.SRL),
                    ]
                with m.Case(MemAccessWidth.W):
                    m.d.comb += [
                        self._const.eq(ConstSelect.SHAMT_0),
                        self.alu_op_to_z.eq(AluOp.SRL),
                    ]

            self.next_instr(m)

    def handle_store(self, m: Module):
        """Adds the STORE logic to the given module.

        Note that byte stores are byte-aligned, half-word stores
        are 16-bit aligned, and word stores are 32-bit aligned.
        Attempting to stores unaligned will lead to undefined
        behavior.

        addr <- rs1 + imm
        data <- rs2
        PC <- PC + 4

        rs1     -> X
        imm     -> Y
        ALU ADD -> Z
        Z       -> memaddr
        ---------------------
        rs2     -> X
        shamt   -> Y
        ALU SLL -> Z
        Z       -> wrdata
                -> wrmask
        ---------------------
        PC + 4  -> PC
        PC + 4  -> memaddr
        """
        with m.If(self._instr_phase == 0):
            m.d.comb += [
                self.reg_to_x.eq(1),
                self._x_reg_select.eq(InstrReg.RS1),
                self.y_mux_select.eq(SeqMuxSelect.IMM),
                self.alu_op_to_z.eq(AluOp.ADD),
                self.memaddr_mux_select.eq(SeqMuxSelect.Z),
                self._next_instr_phase.eq(1),
            ]

        with m.Elif(self._instr_phase == 1):
            # Check for exception conditions first
            with m.If((self._funct3 == MemAccessWidth.H)
                      & self.memaddr_2_lsb[0]):
                self.set_exception(m,
                                   ConstSelect.EXC_STORE_AMO_ADDR_MISALIGN,
                                   mtval=SeqMuxSelect.MEMADDR)

            with m.Elif((self._funct3 == MemAccessWidth.W)
                        & (self.memaddr_2_lsb != 0)):
                self.set_exception(m,
                                   ConstSelect.EXC_STORE_AMO_ADDR_MISALIGN,
                                   mtval=SeqMuxSelect.MEMADDR)

            with m.Elif(~self._funct3.matches(
                    MemAccessWidth.B, MemAccessWidth.H, MemAccessWidth.W)):
                self.handle_illegal_instr(m)

            with m.Else():
                m.d.comb += [
                    self.reg_to_x.eq(1),
                    self._x_reg_select.eq(InstrReg.RS2),
                    self.y_mux_select.eq(SeqMuxSelect.CONST),
                    self.alu_op_to_z.eq(AluOp.SLL),
                    self.memdata_wr_mux_select.eq(SeqMuxSelect.Z),
                    self._next_instr_phase.eq(2),
                ]

                with m.Switch(self._funct3):

                    with m.Case(MemAccessWidth.B):
                        with m.Switch(self.memaddr_2_lsb):
                            with m.Case(0):
                                m.d.comb += self._const.eq(ConstSelect.SHAMT_0)
                            with m.Case(1):
                                m.d.comb += self._const.eq(ConstSelect.SHAMT_8)
                            with m.Case(2):
                                m.d.comb += self._const.eq(
                                    ConstSelect.SHAMT_16)
                            with m.Case(3):
                                m.d.comb += self._const.eq(
                                    ConstSelect.SHAMT_24)

                    with m.Case(MemAccessWidth.H):
                        with m.Switch(self.memaddr_2_lsb):
                            with m.Case(0):
                                m.d.comb += self._const.eq(ConstSelect.SHAMT_0)
                            with m.Case(2):
                                m.d.comb += self._const.eq(
                                    ConstSelect.SHAMT_16)

                    with m.Case(MemAccessWidth.W):
                        m.d.comb += self._const.eq(ConstSelect.SHAMT_0)

        with m.Else():
            with m.Switch(self._funct3):

                with m.Case(MemAccessWidth.B):
                    with m.Switch(self.memaddr_2_lsb):
                        with m.Case(0):
                            m.d.comb += self.mem_wr_mask.eq(0b0001)
                        with m.Case(1):
                            m.d.comb += self.mem_wr_mask.eq(0b0010)
                        with m.Case(2):
                            m.d.comb += self.mem_wr_mask.eq(0b0100)
                        with m.Case(3):
                            m.d.comb += self.mem_wr_mask.eq(0b1000)

                with m.Case(MemAccessWidth.H):
                    with m.Switch(self.memaddr_2_lsb):
                        with m.Case(0):
                            m.d.comb += self.mem_wr_mask.eq(0b0011)
                        with m.Case(2):
                            m.d.comb += self.mem_wr_mask.eq(0b1100)

                with m.Case(MemAccessWidth.W):
                    m.d.comb += self.mem_wr_mask.eq(0b1111)

            m.d.comb += self.mem_wr.eq(1)
            self.next_instr(m)

    def handle_csrs(self, m: Module):
        """Adds the SYSTEM (CSR opcodes) logic to the given module.

        Some points of interest:

        * Attempts to write a read-only register
          result in an illegal instruction exception.
        * Attempts to access a CSR that doesn't exist
          result in an illegal instruction exception.
        * Attempts to write read-only bits to a read/write CSR
          are ignored.

        Because we're building this in hardware, which is
        expensive, we're not implementing any CSRs that aren't
        strictly necessary. The documentation for the misa, mvendorid,
        marchid, and mimpid registers state that they can return zero if
        unimplemented. This implies that unimplemented CSRs still
        exist.

        The mhartid, because we only have one HART, can just return zero.
        """
        with m.Switch(self._funct3):
            with m.Case(SystemFunc.CSRRW):
                self.handle_CSRRW(m)
            with m.Case(SystemFunc.CSRRWI):
                self.handle_CSRRWI(m)
            with m.Case(SystemFunc.CSRRS):
                self.handle_CSRRS(m)
            with m.Case(SystemFunc.CSRRSI):
                self.handle_CSRRSI(m)
            with m.Case(SystemFunc.CSRRC):
                self.handle_CSRRC(m)
            with m.Case(SystemFunc.CSRRCI):
                self.handle_CSRRCI(m)
            with m.Default():
                self.handle_illegal_instr(m)

    def handle_CSRRW(self, m: Module):
        m.d.comb += self._funct12_to_csr_num.eq(1)

        with m.If(self._instr_phase == 0):
            with m.If(self.rd0):
                m.d.comb += [
                    self._x_reg_select.eq(InstrReg.ZERO),
                    self.reg_to_x.eq(1),
                ]
            with m.Else():
                m.d.comb += [self.csr_to_x.eq(1)]
            m.d.comb += [
                self._y_reg_select.eq(InstrReg.RS1),
                self.reg_to_y.eq(1),
                self.alu_op_to_z.eq(AluOp.Y),
                self.z_to_csr.eq(1),
                self.tmp_mux_select.eq(SeqMuxSelect.X),
                self._next_instr_phase.eq(1),
            ]

        with m.Else():
            m.d.comb += [
                self.z_mux_select.eq(SeqMuxSelect.TMP),
                self._z_reg_select.eq(InstrReg.RD),
            ]
            self.next_instr(m)

    def handle_CSRRWI(self, m: Module):
        m.d.comb += self._funct12_to_csr_num.eq(1)

        with m.If(self._instr_phase == 0):
            with m.If(self.rd0):
                m.d.comb += [
                    self._x_reg_select.eq(InstrReg.ZERO),
                    self.reg_to_x.eq(1),
                ]
            with m.Else():
                m.d.comb += [self.csr_to_x.eq(1)]
            m.d.comb += [
                self.y_mux_select.eq(SeqMuxSelect.IMM),
                self.alu_op_to_z.eq(AluOp.Y),
                self.z_to_csr.eq(1),
                self.tmp_mux_select.eq(SeqMuxSelect.X),
                self._next_instr_phase.eq(1),
            ]

        with m.Else():
            m.d.comb += [
                self.z_mux_select.eq(SeqMuxSelect.TMP),
                self._z_reg_select.eq(InstrReg.RD),
            ]
            self.next_instr(m)

    def handle_CSRRS(self, m: Module):
        m.d.comb += self._funct12_to_csr_num.eq(1)

        with m.If(self._instr_phase == 0):
            m.d.comb += [
                self.csr_to_x.eq(1),
                self._y_reg_select.eq(InstrReg.RS1),
                self.reg_to_y.eq(1),
                self.alu_op_to_z.eq(AluOp.OR),
                self.z_to_csr.eq(~self.rs1_0),
                self.tmp_mux_select.eq(SeqMuxSelect.X),
                self._next_instr_phase.eq(1),
            ]

        with m.Else():
            m.d.comb += [
                self.z_mux_select.eq(SeqMuxSelect.TMP),
                self._z_reg_select.eq(InstrReg.RD),
            ]
            self.next_instr(m)

    def handle_CSRRSI(self, m: Module):
        m.d.comb += self._funct12_to_csr_num.eq(1)

        with m.If(self._instr_phase == 0):
            m.d.comb += [
                self.csr_to_x.eq(1),
                self.y_mux_select.eq(SeqMuxSelect.IMM),
                self.alu_op_to_z.eq(AluOp.OR),
                self.z_to_csr.eq(~self.imm0),
                self.tmp_mux_select.eq(SeqMuxSelect.X),
                self._next_instr_phase.eq(1),
            ]

        with m.Else():
            m.d.comb += [
                self.z_mux_select.eq(SeqMuxSelect.TMP),
                self._z_reg_select.eq(InstrReg.RD),
            ]
            self.next_instr(m)

    def handle_CSRRC(self, m: Module):
        m.d.comb += self._funct12_to_csr_num.eq(1)

        with m.If(self._instr_phase == 0):
            m.d.comb += [
                self.csr_to_x.eq(1),
                self._y_reg_select.eq(InstrReg.RS1),
                self.reg_to_y.eq(1),
                self.alu_op_to_z.eq(AluOp.AND_NOT),
                self.z_to_csr.eq(~self.rs1_0),
                self.tmp_mux_select.eq(SeqMuxSelect.X),
                self._next_instr_phase.eq(1),
            ]

        with m.Else():
            m.d.comb += [
                self.z_mux_select.eq(SeqMuxSelect.TMP),
                self._z_reg_select.eq(InstrReg.RD),
            ]
            self.next_instr(m)

    def handle_CSRRCI(self, m: Module):
        m.d.comb += self._funct12_to_csr_num.eq(1)

        with m.If(self._instr_phase == 0):
            m.d.comb += [
                self.csr_to_x.eq(1),
                self.y_mux_select.eq(SeqMuxSelect.IMM),
                self.alu_op_to_z.eq(AluOp.AND_NOT),
                self.z_to_csr.eq(~self.imm0),
                self.tmp_mux_select.eq(SeqMuxSelect.X),
                self._next_instr_phase.eq(1),
            ]

        with m.Else():
            m.d.comb += [
                self.z_mux_select.eq(SeqMuxSelect.TMP),
                self._z_reg_select.eq(InstrReg.RD),
            ]
            self.next_instr(m)

    def handle_MRET(self, m: Module):
        m.d.comb += [
            self._mepc_num_to_csr_num.eq(1),
            self.csr_to_x.eq(1),
            self.exit_trap.eq(1),
        ]
        self.next_instr(m, NextPC.X)

    def handle_ECALL(self, m: Module):
        """Handles the ECALL instruction.

        Note that normally, ECALL is used from a lower privelege mode, which stores
        the PC of the instruction in the appropriate lower EPC CSR (e.g. SEPC or UEPC).
        This allows interrupts to be handled during the call, because we're in a higher
        privelege level. However, in machine mode, there is no higher privelege level,
        so we have no choice but to disable interrupts for an ECALL.
        """
        self.set_exception(m,
                           ConstSelect.EXC_ECALL_FROM_MACH_MODE,
                           mtval=SeqMuxSelect.PC,
                           fatal=False)

    def handle_EBREAK(self, m: Module):
        """Handles the EBREAK instruction.

        Note that normally, EBREAK is used from a lower privelege mode, which stores
        the PC of the instruction in the appropriate lower EPC CSR (e.g. SEPC or UEPC).
        This allows interrupts to be handled during the call, because we're in a higher
        privelege level. However, in machine mode, there is no higher privelege level,
        so we have no choice but to disable interrupts for an EBREAK.
        """
        self.set_exception(m,
                           ConstSelect.EXC_BREAKPOINT,
                           mtval=SeqMuxSelect.PC,
                           fatal=False)
Esempio n. 5
0
class BType:
    def __init__(self, prefix=""):
        if prefix:
            prefix += "_"
        self.opcode = Signal(7, name=f"{prefix}opcode")
        self.funct3 = Signal(3, name=f"{prefix}funct3")
        self.rs1 = Signal(5, name=f"{prefix}rs1")
        self.rs2 = Signal(5, name=f"{prefix}rs2")
        self.imm = Signal(32, name=f"{prefix}imm")

    def elaborate(self, comb: List[Statement], input: Signal):
        comb += self.opcode.eq(input[0:7])
        comb += self.funct3.eq(input[12:15])
        comb += self.rs1.eq(input[15:20])
        comb += self.rs2.eq(input[20:25])
        comb += self.imm.eq(
            Cat(Const(0, 1), input[8:12], input[25:31], input[7],
                Repl(input[31], 20)))

    @staticmethod
    def build_i32(opcode, funct3, rs1, rs2, imm):
        if type(imm) == int:
            assert imm % 2 == 0
            assert imm.bit_length() <= 32, "imm must be 32 bit long"
            imm = imm & (2**32) - 1
        value = opcode
        value = value | (funct3 << 12)
        value = value | (rs1 << 15)
        value = value | (rs2 << 20)
        value = value | ((bit_slice(imm, 4, 1)) << 8)
        value = value | ((bit_slice(imm, 10, 5)) << 25)
        value = value | ((bit_slice(imm, 11, 11)) << 7)
        value = value | ((bit_slice(imm, 12, 12)) << 31)

        return value

    def match(self,
              opcode=None,
              funct3=None,
              rs1=None,
              rs2=None,
              imm=None) -> Value:
        """ Build boolean expression that matches x against provided parts """
        if type(imm) == int:
            # TOOD: other types need it as well
            assert imm.bit_length() <= 32, "imm must be 32 bit long"
            assert imm % 2 == 0, "btype has 2-byte offset"
            # TODO: check ~20 hi bits of imm == sigen ext
            imm = imm & (2**32) - 1
        subexpressions = []
        if opcode is not None:
            subexpressions.append(self.opcode.matches(opcode))
        if funct3 is not None:
            subexpressions.append(self.funct3.matches(funct3))
        if rs1 is not None: subexpressions.append(self.rs1.matches(rs1))
        if rs2 is not None: subexpressions.append(self.rs2.matches(rs2))
        if imm is not None: subexpressions.append(self.imm.matches(imm))

        if not subexpressions:
            print("warning: no matches provided for btype.match")
            return Const(1)
        res = subexpressions.pop(0)
        while subexpressions:
            res = res & subexpressions.pop(0)
        return res
Esempio n. 6
0
class Core(Elaboratable):
    def __init__(self, verification: Instruction = None):
        self.enable = Signal(reset=1)
        self.addr = Signal(16)
        self.din = Signal(8)
        self.dout = Signal(8)
        self.RWB = Signal(reset=1)  # 1 = read, 0 = write

        # registers
        self.reg = Registers()
        self.tmp = Signal(8)  # temp signal when reading 16 bits

        # internal exec state
        self.opcode = Signal(8)
        self.cycle = Signal(4, reset=1)

        # formal verification
        self.verification = verification
        self.snapshot = Snapshot()

    def ports(self) -> List[Signal]:
        return [self.addr, self.din, self.dout, self.RWB]

    def elaborate(self, platform: Platform) -> Module:
        m = Module()

        m.submodules.alu = self.alu = ALU()

        """Fetch the opcode, common for all instr"""
        m.d.sync += self.opcode.eq(Mux(self.cycle == 1, self.dout, self.opcode))

        with m.Switch(Mux(self.cycle == 1, self.dout, self.opcode)):
            for i in implemented.implemented:
                with m.Case(i.opcode):
                    i.synth(self, m)
            with m.Default():
                m.d.comb += core.alu.oper.eq(Operation.NOP)
                m.d.sync += [
                    core.reg.PC.eq(add16(core.reg.PC, 1)),
                    core.enable.eq(1),
                    core.addr.eq(add16(core.reg.PC, 1)),
                    core.RWB.eq(1),
                    core.cycle.eq(1),
                ]

        if self.verification is not None:
            self.verify(m)

            with m.If(Initial()):
                m.d.sync += [
                    self.reg.A.eq(AnyConst(8)),
                    self.reg.X.eq(AnyConst(8)),
                    self.reg.Y.eq(AnyConst(8)),
                    self.reg.SP.eq(AnyConst(16)),
                    self.reg.PC.eq(AnyConst(16)),
                ]
                m.d.sync += [
                    self.reg.PSW.N.eq(AnyConst(1)),
                    self.reg.PSW.V.eq(AnyConst(1)),
                    self.reg.PSW.P.eq(AnyConst(1)),
                    self.reg.PSW.B.eq(AnyConst(1)),
                    self.reg.PSW.H.eq(AnyConst(1)),
                    self.reg.PSW.I.eq(AnyConst(1)),
                    self.reg.PSW.Z.eq(AnyConst(1)),
                    self.reg.PSW.C.eq(AnyConst(1)),
                ]

        return m

    def verify(self, m: Module):
        """Take snapshots of the state and check formally"""
        with m.If(self.cycle == 1):
            with m.If(self.dout.matches(self.verification.opcode)):
                self.snapshot.pre_snapshot(m, self.addr, self.dout, self.reg)
            with m.Else():
                self.snapshot.no_snapshot(m)
        with m.Else():
            with m.If(self.snapshot.taken & self.enable):
                """we keep track of every address during instr exec"""
                with m.If(self.RWB == 1):
                    self.snapshot.read(m, self.addr, self.dout)
                with m.If(self.RWB == 0):
                    self.snapshot.write(m, self.addr, self.din)

        with m.If((self.snapshot.taken) & (self.cycle == 1)):
            """at the start of the next instr, check"""
            self.snapshot.post_snapshot(m, self.reg)
            self.verification.check(m, self.snapshot, self.alu)