class DUT(m.Circuit):
    # Interface only: a 1-bit input `I`, a 1-bit output `O`, and the ports
    # contributed by the default `m.ClockIO()`.
    io = m.IO(I=m.In(m.Bit), O=m.Out(m.Bit))
    io += m.ClockIO()
class MyAdder(m.Circuit):
    # 4-bit unsigned input `a`, 4-bit unsigned output `b`.
    IO = [
        "a", m.In(m.UInt[4]),
        "b", m.Out(m.UInt[4]),
    ]

    @classmethod
    def definition(io):
        """Drive output `b` with `a + 1`."""
        incremented = io.a + 1
        io.b <= incremented
class ConfigReg(m.Circuit):
    # 2-bit data in (`D`) / data out (`Q`) plus a clock interface that
    # includes a clock-enable port.
    io = (
        m.IO(D=m.In(m.Bits[2]), Q=m.Out(m.Bits[2]))
        + m.ClockIO(has_ce=True)
    )

    # 2-bit register with clock enable; its output drives `Q`.
    reg = mantle.Register(2, has_ce=True, name="conf_reg")
    io.Q @= reg(io.D, CE=io.CE)
def _get_primitive_drivers(
        self,
        primitive: m.DefineCircuitKind,
        inst_bit: m.Out(m.Bit),
) -> Iterable[m.Bit]:
    """Return the primitive drivers of an instance output bit.

    The instance port is first mapped to the matching definition port via
    ``inst_port_to_defn_port`` and the lookup is delegated to the
    module-level ``get_primitive_drivers`` with ``allow_default=False``.
    ``primitive`` is accepted but not used by this method.
    """
    assert inst_bit.is_output()
    return get_primitive_drivers(
        inst_port_to_defn_port(inst_bit),
        allow_default=False,
    )
def make_HandshakeData(data_type):
    """Build a pair of handshake bundle types for ``data_type``.

    Returns ``(in_type, out_type)``: ``in_type`` is a tuple with ``data`` and
    ``valid`` as inputs and ``ready`` as an output; ``out_type`` is
    ``m.Flip(in_type)``.
    """
    receiving = m.Tuple(
        data=m.In(data_type),
        valid=m.In(m.Bit),
        ready=m.Out(m.Bit),
    )
    return receiving, m.Flip(receiving)
class SimpleInverter(m.Circuit):
    # 1-bit input `a`, 1-bit output `nota`.
    IO = [
        "a", m.In(m.Bit),
        "nota", m.Out(m.Bit),
    ]

    @classmethod
    def definition(io):
        """Drive `nota` with the bitwise complement of `a`."""
        inverted = ~io.a
        io.nota <= inverted
class Rasterizer(m.Circuit):
    # Four-stage pipeline: bounding box -> iterator -> hash_jtree ->
    # sampletest. Every stage shares `CLK` and `RESET`.
    #
    # NOTE(review): `vertices`, `axes`, `bits`, `color_channels`,
    # `SampleSize`, `integer_bits`, `fractional_bits`, `modified_fsm` and the
    # `pipe_stages_*` values are not defined in this block; presumably they
    # come from an enclosing generator scope -- confirm.
    IO = [
        "CLK", m.In(m.Clock),
        "RESET", m.In(m.Reset),
        "valid_in", m.In(m.Bits(1)),
        "poly", m.In(Polygon(vertices, axes, bits)),
        "color_in", m.In(Colors(color_channels, bits)),
        "is_quad", m.In(m.Bits(1)),
        "screen_max", m.In(Point(2, bits)),
        "sample_size", m.In(SampleSize),
        "halt", m.Out(m.Bits(1)),
        "valid_hit", m.Out(m.Bits(1)),
        "hit", m.Out(Point(axes, bits)),
        "color_out", m.Out(Colors(color_channels, bits))
    ]

    @classmethod
    def definition(io):
        """Instantiate the four stages and wire them in sequence."""
        # Each define_* call returns a circuit definition, which is then
        # instantiated immediately by the trailing `()`.
        bbox_inst = bbox.define_compute_bounding_box(
            integer_bits, fractional_bits, vertices, axes, color_channels,
            pipe_stages_box)()
        iterator_inst = iterator.define_iterator(integer_bits,
                                                 fractional_bits, vertices,
                                                 axes, color_channels,
                                                 modified_fsm)()
        hash_jtree_inst = hash_jtree.define_hash_jtree(
            integer_bits, fractional_bits, vertices, axes, color_channels,
            pipe_stages_hash)()
        sampletest_inst = sampletest.define_sampletest(
            integer_bits, fractional_bits, vertices, axes, color_channels,
            pipe_stages_samp)()

        # Stage 1: bounding box, fed directly from the top-level inputs.
        m.wire(io.CLK, bbox_inst.CLK)
        m.wire(bbox_inst.RESET, io.RESET)
        m.wire(bbox_inst.valid_in, io.valid_in)
        m.wire(bbox_inst.poly_in, io.poly)
        m.wire(bbox_inst.color_in, io.color_in)
        m.wire(bbox_inst.is_quad_in, io.is_quad)
        m.wire(bbox_inst.screen_max, io.screen_max)
        m.wire(bbox_inst.sample_size, io.sample_size)
        # The iterator's halt is connected both to the bounding-box stage
        # (here) and to the top-level `halt` output (below).
        m.wire(bbox_inst.halt, iterator_inst.halt)

        # Stage 2: iterator, fed from the bounding-box outputs.
        m.wire(iterator_inst.CLK, io.CLK)
        m.wire(iterator_inst.RESET, io.RESET)
        m.wire(iterator_inst.poly_in, bbox_inst.poly_out)
        m.wire(iterator_inst.color_in, bbox_inst.color_out)
        m.wire(iterator_inst.valid_in, bbox_inst.valid_out)
        m.wire(iterator_inst.is_quad_in, bbox_inst.is_quad_out)
        m.wire(iterator_inst.sample_size, io.sample_size)
        m.wire(iterator_inst.halt, io.halt)
        m.wire(iterator_inst.box, bbox_inst.box)

        # Stage 3: hash_jtree, fed from the iterator outputs.
        m.wire(hash_jtree_inst.CLK, io.CLK)
        m.wire(hash_jtree_inst.RESET, io.RESET)
        m.wire(hash_jtree_inst.poly_in, iterator_inst.poly_out)
        m.wire(hash_jtree_inst.color_in, iterator_inst.color_out)
        m.wire(hash_jtree_inst.is_quad_in, iterator_inst.is_quad_out)
        m.wire(hash_jtree_inst.sample_in, iterator_inst.sample)
        m.wire(hash_jtree_inst.valid_sample_in, iterator_inst.valid_sample)
        m.wire(hash_jtree_inst.sample_size, io.sample_size)

        # Stage 4: sample test; its outputs are the top-level results.
        m.wire(sampletest_inst.CLK, io.CLK)
        m.wire(sampletest_inst.RESET, io.RESET)
        m.wire(sampletest_inst.poly, hash_jtree_inst.poly_out)
        m.wire(sampletest_inst.color_in, hash_jtree_inst.color_out)
        m.wire(sampletest_inst.sample, hash_jtree_inst.sample_out)
        m.wire(sampletest_inst.valid_sample, hash_jtree_inst.valid_sample_out)
        m.wire(sampletest_inst.is_quad_in, hash_jtree_inst.is_quad_out)
        m.wire(sampletest_inst.hit, io.hit)
        m.wire(sampletest_inst.color_out, io.color_out)
        m.wire(sampletest_inst.valid_hit, io.valid_hit)
def definition(io):
    """Build the downscale-and-row-assembly pipeline and its UART outputs.

    Reads `io.LOAD`, `io.SCK`, `io.DATA` and `io.CLK`; drives `io.WADDR`,
    `io.DONE`, `io.UART`, `io.O` and `io.VALID`.

    NOTE(review): `rising`, `falling`, `reverse`, `UART`, `buf_size`, `wi`,
    `wo`, `a`, `b`, `THRESH`, `LOW_MASK` and `HIGH_MASK` are not defined in
    this block; presumably module-level or enclosing-scope names -- confirm.
    """
    load = io.LOAD
    # An edge of SCK in either direction counts as a baud tick.
    baud = rising(io.SCK) | falling(io.SCK)

    # 13-bit modulo-`buf_size` counter, enabled while loading on a baud tick.
    valid_counter = mantle.CounterModM(buf_size, 13, has_ce=True)
    m.wire(load & baud, valid_counter.CE)

    # Counter values at which `valid` is asserted.
    valid_list = [wi * (b - 1) + i * a - 1 for i in range(1, wo + 1)]

    valid = m.GND

    # OR together one decoder per entry of valid_list.
    for i in valid_list:
        valid = valid | mantle.Decode(i, 13)(valid_counter.O)

    # register on input
    st_in = mantle.Register(16, has_ce=True)
    st_in(io.DATA)
    m.wire(load, st_in.CE)

    # --------------------------DOWNSCALING----------------------------- #
    # downscale the image from 320x240 to 16x16
    # Downscale is only declared here; its definition lives elsewhere.
    Downscale = m.DeclareCircuit(
        'Downscale',
        "I_0_0", m.In(m.Array(1, m.Array(1, m.Array(16, m.Bit)))),
        "WE", m.In(m.Bit), "CLK", m.In(m.Clock),
        "O", m.Out(m.Array(16, m.Bit)), "V", m.Out(m.Bit))

    dscale = Downscale()

    m.wire(st_in.O, dscale.I_0_0[0][0])
    m.wire(1, dscale.WE)
    # The downscaler is clocked by `load`, not by io.CLK.
    m.wire(load, dscale.CLK)

    add16 = mantle.Add(16)  # needed for Add16 definition

    # --------------------------FILL IMG RAM--------------------------- #
    # each valid output of dscale represents a pixel in 16x16 binary image
    # accumulate each group of 16 pixels into a 16-bit value representing
    # a row in the image
    col = mantle.CounterModM(16, 5, has_ce=True)
    col_ce = rising(valid)
    m.wire(col_ce, col.CE)

    # shift each bit in one at a time until we get an entire row
    # px_bit is set when the downscaled value is <= THRESH and `valid` is high.
    px_bit = mantle.ULE(16)(dscale.O, m.uint(THRESH, 16)) & valid
    row_reg = mantle.SIPO(16, has_ce=True)
    row_reg(px_bit)
    m.wire(col_ce, row_reg.CE)

    # reverse the row bits since the image is flipped
    row = reverse(row_reg.O)

    rowaddr = mantle.Counter(5, has_ce=True)

    # img_full is set once rowaddr reaches 16 and is never reset (R input 0).
    img_full = mantle.SRFF(has_ce=True)
    img_full(mantle.EQ(5)(rowaddr.O, m.bits(16, 5)), 0)
    m.wire(falling(col.COUT), img_full.CE)
    # Advance the row address on each column wrap until the image is full.
    row_ce = rising(col.COUT) & ~img_full.O
    m.wire(row_ce, rowaddr.CE)

    # Low 4 bits of the 5-bit row counter form the write address.
    waddr = rowaddr.O[:4]

    # we_counter = mantle.CounterModM(16, 5, has_ce=True)
    # m.wire(rising(valid), we_counter.CE)

    rdy = col.COUT & ~img_full.O
    pulse_count = mantle.Counter(5, has_ce=True)
    # `we` is high while pulse_count >= 1; the counter is enabled by `rdy`
    # and keeps counting while `we` is high.
    we = mantle.UGE(5)(pulse_count.O, m.uint(1, 5))
    pulse_count(CE=(we | rdy))

    # ---------------------------UART OUTPUT----------------------------- #
    row_load = row_ce
    row_baud = mantle.FF()(baud)
    uart_row = UART(16)
    uart_row(CLK=io.CLK, BAUD=row_baud, DATA=row, LOAD=row_load)

    # NOTE(review): uart_addr is instantiated and driven, but its output is
    # not wired to anything in this block.
    uart_addr = UART(4)
    uart_addr(CLK=io.CLK, BAUD=row_baud, DATA=waddr, LOAD=row_load)

    # split 16-bit row data into 8-bit packets so it can be parsed
    # NOTE(review): low_byte, high_byte and uart_counter's output are not
    # used further in this block.
    low_byte = row & LOW_MASK
    high_byte = row & HIGH_MASK
    uart_counter = mantle.CounterModM(8, 4, has_ce=True)
    m.wire(rising(valid), uart_counter.CE)

    m.wire(waddr, io.WADDR)
    # NOTE(review): the instances themselves (not `.O`) are passed to m.wire
    # for img_full and uart_row -- confirm m.wire resolves these as intended.
    m.wire(img_full, io.DONE)
    m.wire(uart_row, io.UART)
    m.wire(row, io.O)
    m.wire(we, io.VALID)
def __init__(self, inputs):
    """Build the four-sided switch: per-track muxes, buffers and config.

    Args:
        inputs: ports whose types are outputs (asserted below). Each is
            exposed as an input port of this generator and fed into the
            muxes of its layer, as grouped by ``__organize_inputs``.
    """
    super().__init__()

    self.all_inputs = inputs
    self.inputs = self.__organize_inputs(inputs)

    self.add_ports(
        north=SideType(5, (1, 16)),
        west=SideType(5, (1, 16)),
        south=SideType(5, (1, 16)),
        east=SideType(5, (1, 16)),
        clk=magma.In(magma.Clock),
        reset=magma.In(magma.AsyncReset),
        config=magma.In(ConfigurationType(8, 32)),
        read_config_data=magma.Out(magma.Bits(32)),
    )

    # TODO(rsetaluri): Clean up this logic.
    # Add one input port per external input, named after the input itself.
    # (The enumerate index `i` is not used.)
    for i, input_ in enumerate(self.all_inputs):
        assert input_.type().isoutput()
        port_name = f"{input_._name}"
        self.add_port(port_name, magma.In(input_.type()))

    sides = (self.ports.north, self.ports.west, self.ports.south,
             self.ports.east)
    self.muxs = self.__make_muxs(sides)
    for (side, layer, track), mux in self.muxs.items():
        idx = 0
        # Mux inputs come first from the same layer/track on every *other*
        # side ...
        for side_in in sides:
            if side_in == side:
                continue
            mux_in = getattr(side_in.I, f"layer{layer}")[track]
            self.wire(mux_in, mux.ports.I[idx])
            idx += 1
        # ... then from the external inputs registered for this layer.
        for input_ in self.inputs[layer]:
            port_name = input_._name
            self.wire(self.ports[port_name], mux.ports.I[idx])
            idx += 1
        # The side output is driven by the buffered version of the mux.
        buffered_mux = self.__make_register_buffer(mux)
        mux_out = getattr(side.O, f"layer{layer}")[track]
        self.wire(buffered_mux.ports.O, mux_out)
        # Add corresponding config register.
        config_name = f"mux_{side._name}_{layer}_{track}"
        config_name_mux = config_name + '_sel'
        config_name_buffer = config_name + '_buffer_sel'
        self.add_config(config_name_mux, mux.sel_bits)
        self.wire(self.registers[config_name_mux].ports.O, mux.ports.S)
        self.add_config(config_name_buffer, buffered_mux.sel_bits)
        self.wire(self.registers[config_name_buffer].ports.O,
                  buffered_mux.ports.S)
    # NOTE(rsetaluri): We set the config register addresses explicitly and
    # in a well-defined order. This ordering can be considered a part of
    # the functional spec of this module.
    # Addresses are assigned side-major, then layer, then track; each track
    # takes two consecutive addresses (mux select, then buffer select).
    idx = 0
    for side in sides:
        for layer in (1, 16):
            for track in range(5):
                reg_name = f"mux_{side._name}_{layer}_{track}"
                reg_name_mux = reg_name + '_sel'
                reg_name_buffer = reg_name + '_buffer_sel'
                self.registers[reg_name_mux].set_addr(idx)
                idx += 1
                self.registers[reg_name_buffer].set_addr(idx)
                idx += 1

    # Common configuration-bus hookup for every register.
    # (The enumerate index `idx` is not used in this loop.)
    for idx, reg in enumerate(self.registers.values()):
        reg.set_addr_width(8)
        reg.set_data_width(32)
        self.wire(self.ports.config.config_addr, reg.ports.config_addr)
        self.wire(self.ports.config.config_data, reg.ports.config_data)
        self.wire(self.ports.config.write[0], reg.ports.config_en)
        self.wire(self.ports.reset, reg.ports.reset)

    # read_config_data output
    num_config_reg = len(self.registers)
    if(num_config_reg > 1):
        self.read_config_data_mux = MuxWrapper(num_config_reg, 32)
        sel_bits = self.read_config_data_mux.sel_bits
        # Wire up config_addr to select input of read_data MUX
        # TODO(rsetaluri): Make this a mux with default.
        self.wire(self.ports.config.config_addr[:sel_bits],
                  self.read_config_data_mux.ports.S)
        self.wire(self.read_config_data_mux.ports.O,
                  self.ports.read_config_data)
        # Mux input `idx` follows the iteration order of self.registers;
        # each register output is zero-extended to 32 bits first.
        for idx, reg in enumerate(self.registers.values()):
            zext = ZextWrapper(reg.width, 32)
            self.wire(reg.ports.O, zext.ports.I)
            zext_out = zext.ports.O
            self.wire(zext_out, self.read_config_data_mux.ports.I[idx])
    # If we only have 1 config register, we don't need a mux
    # Wire sole config register directly to read_config_data_output
    # NOTE(review): self.registers is indexed by name strings everywhere
    # above; indexing it with 0 here presumably fails unless the container
    # also accepts integer keys -- confirm. The address-assignment loop
    # above requires 80 named registers, so this branch looks unreachable.
    else:
        self.wire(self.registers[0].ports.O, self.ports.read_config_data)
def __init__(self, addr_width, data_width):
    """Create the ``slave`` and ``master`` AXI4 product types.

    Both types have the same fields in the same order; the master type has
    every direction reversed relative to the slave type.

    Args:
        addr_width: bit width of ``awaddr`` / ``araddr``.
        data_width: bit width of ``wdata`` / ``rdata``.
    """
    self.addr_width = addr_width
    self.data_width = data_width

    def make_fields(from_master, from_slave):
        # `from_master` qualifies signals produced by the master side,
        # `from_slave` those produced by the slave side.
        return dict(
            awaddr=from_master(magma.Bits[addr_width]),
            awvalid=from_master(magma.Bit),
            awready=from_slave(magma.Bit),
            wdata=from_master(magma.Bits[data_width]),
            wvalid=from_master(magma.Bit),
            wready=from_slave(magma.Bit),
            bready=from_master(magma.Bit),
            bresp=from_slave(magma.Bits[2]),
            bvalid=from_slave(magma.Bit),
            araddr=from_master(magma.Bits[addr_width]),
            arvalid=from_master(magma.Bit),
            arready=from_slave(magma.Bit),
            rdata=from_slave(magma.Bits[data_width]),
            rresp=from_slave(magma.Bits[2]),
            rvalid=from_slave(magma.Bit),
            rready=from_master(magma.Bit),
        )

    # Slave view: master-produced signals are inputs.
    slave_fields = make_fields(magma.In, magma.Out)
    self.slave = magma.Product.from_fields("AXI4SlaveType", slave_fields)

    # Master view: master-produced signals are outputs.
    master_fields = make_fields(magma.Out, magma.In)
    self.master = magma.Product.from_fields("AXI4MasterType", master_fields)
class dut(m.Circuit):
    # Interface-only circuit named 'test_meas_width': 8-bit input `in_` and
    # 8-bit output `out`.
    name = 'test_meas_width'
    io = m.IO(
        in_=m.In(m.Bits[8]),
        out=m.Out(m.Bits[8]),
    )
import magma as m
from magma.bitutils import int2seq
import mantle
from rom import ROM8, ROM16
from loam.boards.hx8kboard import HX8KBoard
from uart import UART

# NOTE(review): `a` and `b` are not used in this span; presumably parameters
# consumed later in the script -- confirm.
a = 2
b = 2
width = 16
# 16-wide input / output array types.
TIN = m.Array(width, m.BitIn)
TOUT = m.Array(width, m.Out(m.Bit))

# Board setup: turn on the clock, D1, and header pins J2[9..12] as outputs.
hx8kboard = HX8KBoard()
hx8kboard.Clock.on()
hx8kboard.D1.on()
hx8kboard.J2[9].output().on()
hx8kboard.J2[10].output().on()
hx8kboard.J2[11].output().on()
hx8kboard.J2[12].output().on()

main = hx8kboard.main()

# baud for uart output
# 8-bit modulo-103 counter; its carry-out serves as the baud tick.
clock = mantle.CounterModM(103, 8)
baud = clock.COUT

# 5-bit counter that advances once per baud tick.
bit_counter = mantle.Counter(5, has_ce=True)
m.wire(baud, bit_counter.CE)
def __init__(self, num_banks, num_io, num_cfg, bank_addr,
             cfg_addr=32, cfg_data=32):
    """Wrap the genesis2 global-buffer circuit behind typed ports.

    Args:
        num_banks: number of banks; ``ceil(log2(num_banks))`` is added to
            ``bank_addr`` to form the global address width.
        num_io: number of entries in each ``cgra_to_io_*`` /
            ``io_to_cgra_*`` array port.
        num_cfg: number of entries in ``glb_to_cgra_config``.
        bank_addr: per-bank address width.
        cfg_addr: configuration address width (default 32).
        cfg_data: configuration data width (default 32).
    """
    super().__init__()

    self.num_banks = num_banks
    self.bank_addr = bank_addr
    # Global address = bank-select bits + per-bank address bits.
    self.glb_addr = math.ceil(math.log2(self.num_banks)) + self.bank_addr
    self.num_io = num_io
    self.num_cfg = num_cfg
    # Fixed data widths used for the port types below.
    self.bank_data = 64
    self.cgra_data = 16
    self.cfg_addr = cfg_addr
    self.cfg_data = cfg_data

    self.cgra_config_type = ConfigurationType(self.cfg_addr, self.cfg_data)
    self.glb_config_type = ConfigurationType(self.cfg_addr, self.cfg_data)

    self.add_ports(
        clk=magma.In(magma.Clock),
        reset=magma.In(magma.AsyncReset),

        soc_data=MMIOType(self.glb_addr, self.bank_data),

        cgra_to_io_wr_en=magma.In(magma.Array[self.num_io, magma.Bit]),
        cgra_to_io_rd_en=magma.In(magma.Array[self.num_io, magma.Bit]),
        io_to_cgra_rd_data_valid=magma.Out(
            magma.Array[self.num_io, magma.Bit]),
        cgra_to_io_wr_data=magma.In(
            magma.Array[self.num_io, magma.Bits[self.cgra_data]]),
        io_to_cgra_rd_data=magma.Out(
            magma.Array[self.num_io, magma.Bits[self.cgra_data]]),
        cgra_to_io_addr_high=magma.In(
            magma.Array[self.num_io, magma.Bits[self.cgra_data]]),
        cgra_to_io_addr_low=magma.In(
            magma.Array[self.num_io, magma.Bits[self.cgra_data]]),

        glc_to_io_stall=magma.In(magma.Bit),
        cgra_start_pulse=magma.In(magma.Bit),
        config_start_pulse=magma.In(magma.Bit),
        config_done_pulse=magma.Out(magma.Bit),

        cgra_config=magma.In(self.cgra_config_type),
        glb_to_cgra_config=magma.Out(
            magma.Array[self.num_cfg, self.cgra_config_type]),

        glb_config=magma.In(self.glb_config_type),
        glb_config_rd_data=magma.Out(magma.Bits[self.cfg_data]),
        glb_sram_config_wr=magma.In(magma.Bit),
        glb_sram_config_rd=magma.In(magma.Bit)
    )

    # Obtain the wrapped circuit from the genesis2 wrapper in "declare"
    # mode and adapt it with FromMagma.
    wrapper = global_buffer_genesis2.glb_wrapper
    param_mapping = global_buffer_genesis2.param_mapping
    generator = wrapper.generator(param_mapping, mode="declare")
    circ = generator(num_banks=self.num_banks,
                     num_io=self.num_io,
                     num_cfg=self.num_cfg,
                     bank_addr=self.bank_addr,
                     cfg_addr=self.cfg_addr,
                     cfg_data=self.cfg_data)
    self.underlying = FromMagma(circ)

    self.wire(self.ports.clk, self.underlying.ports.clk)
    self.wire(self.ports.reset, self.underlying.ports.reset)

    # soc_data fields map onto the underlying host_* ports.
    self.wire(self.ports.soc_data.wr_en,
              self.underlying.ports.host_wr_en)
    self.wire(self.ports.soc_data.wr_addr,
              self.underlying.ports.host_wr_addr)
    self.wire(self.ports.soc_data.wr_data,
              self.underlying.ports.host_wr_data)
    self.wire(self.ports.soc_data.rd_en,
              self.underlying.ports.host_rd_en)
    self.wire(self.ports.soc_data.rd_addr,
              self.underlying.ports.host_rd_addr)
    self.wire(self.ports.soc_data.rd_data,
              self.underlying.ports.host_rd_data)

    # The underlying data/address ports are flat buses; entry i of each
    # array port maps to the slice [i * cgra_data, (i + 1) * cgra_data).
    for i in range(self.num_io):
        self.wire(self.ports.cgra_to_io_wr_en[i],
                  self.underlying.ports.cgra_to_io_wr_en[i])
        self.wire(self.ports.cgra_to_io_rd_en[i],
                  self.underlying.ports.cgra_to_io_rd_en[i])
        self.wire(self.ports.io_to_cgra_rd_data_valid[i],
                  self.underlying.ports.io_to_cgra_rd_data_valid[i])
        self.wire(self.ports.cgra_to_io_wr_data[i],
                  self.underlying.ports.cgra_to_io_wr_data[
                      i * self.cgra_data:(i + 1) * self.cgra_data])
        self.wire(self.ports.io_to_cgra_rd_data[i],
                  self.underlying.ports.io_to_cgra_rd_data[
                      i * self.cgra_data:(i + 1) * self.cgra_data])
        self.wire(self.ports.cgra_to_io_addr_high[i],
                  self.underlying.ports.cgra_to_io_addr_high[
                      i * self.cgra_data:(i + 1) * self.cgra_data])
        self.wire(self.ports.cgra_to_io_addr_low[i],
                  self.underlying.ports.cgra_to_io_addr_low[
                      i * self.cgra_data:(i + 1) * self.cgra_data])

    # Same flattening for the parallel configuration outputs, sliced by
    # cfg_addr / cfg_data.
    for i in range(self.num_cfg):
        self.wire(self.ports.glb_to_cgra_config[i].write[0],
                  self.underlying.ports.glb_to_cgra_cfg_wr[i])
        self.wire(self.ports.glb_to_cgra_config[i].read[0],
                  self.underlying.ports.glb_to_cgra_cfg_rd[i])
        self.wire(self.ports.glb_to_cgra_config[i].config_addr,
                  self.underlying.ports.glb_to_cgra_cfg_addr[
                      i * self.cfg_addr:(i + 1) * self.cfg_addr])
        self.wire(self.ports.glb_to_cgra_config[i].config_data,
                  self.underlying.ports.glb_to_cgra_cfg_data[
                      i * self.cfg_data:(i + 1) * self.cfg_data])

    self.wire(self.ports.glc_to_io_stall,
              self.underlying.ports.glc_to_io_stall)

    # cgra_config fields map onto the underlying glc_to_cgra_cfg_* ports.
    self.wire(self.ports.cgra_config.write[0],
              self.underlying.ports.glc_to_cgra_cfg_wr)
    self.wire(self.ports.cgra_config.read[0],
              self.underlying.ports.glc_to_cgra_cfg_rd)
    self.wire(self.ports.cgra_config.config_addr,
              self.underlying.ports.glc_to_cgra_cfg_addr)
    self.wire(self.ports.cgra_config.config_data,
              self.underlying.ports.glc_to_cgra_cfg_data)

    self.wire(self.ports.cgra_start_pulse,
              self.underlying.ports.cgra_start_pulse)
    self.wire(self.ports.config_start_pulse,
              self.underlying.ports.config_start_pulse)
    self.wire(self.ports.config_done_pulse,
              self.underlying.ports.config_done_pulse)

    # glb_config fields map onto the underlying glb_config_* ports.
    self.wire(self.ports.glb_config.write[0],
              self.underlying.ports.glb_config_wr)
    self.wire(self.ports.glb_config.read[0],
              self.underlying.ports.glb_config_rd)
    self.wire(self.ports.glb_config.config_data,
              self.underlying.ports.glb_config_wr_data)
    self.wire(self.ports.glb_config.config_addr,
              self.underlying.ports.glb_config_addr)
    self.wire(self.ports.glb_config_rd_data,
              self.underlying.ports.glb_config_rd_data)
    self.wire(self.ports.glb_sram_config_wr,
              self.underlying.ports.glb_sram_config_wr)
    self.wire(self.ports.glb_sram_config_rd,
              self.underlying.ports.glb_sram_config_rd)
class paramadd(m.Circuit):
    # Interface-only circuit: `n_bits`-wide input `a_val` and output `c_val`.
    # `n_bits` is taken from the surrounding scope.
    io = m.IO(
        a_val=m.In(m.Bits[n_bits]),
        c_val=m.Out(m.Bits[n_bits]),
    )
class dut(m.Circuit):
    # Interface-only circuit named 'test_conv'. The `r2i_*` pair has a real
    # input and a signed 8-bit output; the `i2r_*` pair has a signed 8-bit
    # input and a real output.
    name = 'test_conv'
    io = m.IO(
        r2i_i=fault.RealIn,
        r2i_o=m.Out(m.SInt[8]),
        i2r_i=m.In(m.SInt[8]),
        i2r_o=fault.RealOut,
    )
def __init__(self, width, height, add_pd, interconnect_only: bool = False,
             use_sram_stub: bool = True):
    """Assemble the top level: interconnect plus optional global blocks.

    Args:
        width: array width; also sets the number of parallel config and IO
            channels (``ceil(width / 4)``).
        height: array height.
        add_pd: forwarded to ``create_cgra``.
        interconnect_only: when True, skip the global controller / global
            buffer and lift the interconnect ports to the top instead.
        use_sram_stub: forwarded to ``create_cgra``.
    """
    super().__init__()

    # configuration parameters
    config_addr_width = 32
    config_data_width = 32
    axi_addr_width = 12
    tile_id_width = 16
    config_addr_reg_width = 8
    num_tracks = 5

    # size
    self.width = width
    self.height = height

    # only north side has IO
    io_side = IOSide.North

    # global buffer parameters
    num_banks = 32
    bank_addr_width = 17
    bank_data_width = 64
    glb_addr_width = 32

    # parallel configuration parameter
    num_parallel_cfg = math.ceil(width / 4)

    # number of input/output channels parameter
    num_io = math.ceil(width / 4)

    # The global-signal wiring scheme depends on whether the global
    # controller and buffer are present.
    if not interconnect_only:
        wiring = GlobalSignalWiring.ParallelMeso
        self.global_controller = GlobalController(config_addr_width,
                                                  config_data_width,
                                                  axi_addr_width)

        self.global_buffer = GlobalBuffer(num_banks=num_banks,
                                          num_io=num_io,
                                          num_cfg=num_parallel_cfg,
                                          bank_addr_width=bank_addr_width,
                                          glb_addr_width=glb_addr_width,
                                          cfg_addr_width=config_addr_width,
                                          cfg_data_width=config_data_width,
                                          axi_addr_width=axi_addr_width)
    else:
        wiring = GlobalSignalWiring.Meso

    interconnect = create_cgra(width, height, io_side,
                               reg_addr_width=config_addr_reg_width,
                               config_data_width=config_data_width,
                               tile_id_width=tile_id_width,
                               num_tracks=num_tracks,
                               add_pd=add_pd,
                               use_sram_stub=use_sram_stub,
                               global_signal_wiring=wiring,
                               num_parallel_config=num_parallel_cfg,
                               mem_ratio=(1, 4))

    self.interconnect = interconnect

    if not interconnect_only:
        self.add_ports(
            jtag=JTAGType,
            clk_in=magma.In(magma.Clock),
            reset_in=magma.In(magma.AsyncReset),
            soc_data=SoCDataType(glb_addr_width, bank_data_width),
            axi4_ctrl=AXI4SlaveType(axi_addr_width, config_data_width),
            cgra_running_clk_out=magma.Out(magma.Clock),
        )

        # top <-> global controller ports connection
        self.wire(self.ports.clk_in, self.global_controller.ports.clk_in)
        self.wire(self.ports.reset_in,
                  self.global_controller.ports.reset_in)
        self.wire(self.ports.jtag, self.global_controller.ports.jtag)
        self.wire(self.ports.axi4_ctrl,
                  self.global_controller.ports.axi4_ctrl)
        self.wire(self.ports.cgra_running_clk_out,
                  self.global_controller.ports.clk_out)

        # top <-> global buffer ports connection
        self.wire(self.ports.soc_data, self.global_buffer.ports.soc_data)

        # Remaining controller / buffer / interconnect connections are
        # delegated to helper functions.
        glc_interconnect_wiring(self)
        glb_glc_wiring(self)
        glb_interconnect_wiring(self, width, num_parallel_cfg)
    else:
        # lift all the interconnect ports up
        self._lift_interconnect_ports(config_data_width)

    # Initial state; nothing in this method updates these afterwards.
    self.mapper_initalized = False
    self.__rewrite_rules = None
def test_extension_no_error(op):
    """Check that ``op`` accepts a 2-bit signed value and the integer ``2``.

    Any exception raised while building the value or applying ``op`` is
    turned into an ``AssertionError`` with the message "This should work".
    The error is raised explicitly instead of via ``assert False`` so the
    check is not stripped when Python runs with ``-O``, and it is chained to
    the original exception so the root cause stays visible in the traceback.
    """
    try:
        a = m.Out(m.SInt[2])()
        op(a, 2)
    except Exception as e:
        raise AssertionError("This should work") from e
class ClocksT(m.Product):
    # Product of two clock fields with opposite directions.
    clk0 = m.In(m.Clock)    # clock input
    clk1 = m.Out(m.Clock)   # clock output
class mybuf_inc_test(m.Circuit):
    # Interface-only circuit: 1-bit input `in_`, 1-bit output `out`.
    io = m.IO(
        in_=m.In(m.Bit),
        out=m.Out(m.Bit),
    )
def __init__(self, tiles: Dict[int, Tile], config_addr_width: int,
             config_data_width: int, tile_id_width: int = 16,
             full_config_addr_width: int = 32):
    """Build the circuit for one tile position from its per-width tiles.

    Args:
        tiles: mapping from track bit-width to the ``Tile`` of that width;
            all entries must share the same x/y and the same core.
        config_addr_width: address width used inside CBs/SBs.
        config_data_width: configuration data width.
        tile_id_width: number of address bits used for the tile id.
        full_config_addr_width: width of the full configuration address.
    """
    super().__init__()

    self.tiles = tiles
    self.config_addr_width = config_addr_width
    self.config_data_width = config_data_width
    self.tile_id_width = tile_id_width

    # compute config addr sizes
    # (the tuples in the comments below are the slice bounds as written by
    # the original author)
    # (16, 24)
    full_width = full_config_addr_width
    self.feature_addr_slice = slice(full_width - self.tile_id_width,
                                    full_width - self.config_addr_width)
    # (0, 16)
    self.tile_id_slice = slice(0, self.tile_id_width)
    # (24, 32)
    self.feature_config_slice = slice(full_width - self.config_addr_width,
                                      full_width)

    # sanity check
    # -1 marks "not seen yet"; the first tile fixes x, y and core, and all
    # later tiles must match.
    x = -1
    y = -1
    core = None
    for bit_width, tile in self.tiles.items():
        assert bit_width == tile.track_width
        if x == -1:
            x = tile.x
            y = tile.y
            core = tile.core
        else:
            assert x == tile.x
            assert y == tile.y
            # the restriction is that all the tiles in the same coordinate
            # have to have the same core, otherwise it's physically
            # impossible
            assert core == tile.core

    assert x != -1 and y != -1
    self.x = x
    self.y = y
    # Unwrap: keep the `.core` attribute of the shared core object.
    self.core = core.core

    # create cb and switchbox
    self.cbs: Dict[str, CB] = {}
    self.sbs: Dict[int, SB] = {}
    # we only create cb if it's an input port, which doesn't have
    # graph neighbors
    for bit_width, tile in self.tiles.items():
        core = tile.core
        # connection box time
        for port_name, port_node in tile.ports.items():
            # input ports
            if len(port_node) == 0:
                # make sure that it has at least one connection
                assert len(port_node.get_conn_in()) > 0
                assert bit_width == port_node.width
                # create a CB
                port_ref = core.get_port_ref(port_node.name)
                cb = CB(port_node, config_addr_width, config_data_width)
                self.wire(cb.ports.O, port_ref)
                self.cbs[port_name] = cb
            else:
                # output ports
                assert len(port_node.get_conn_in()) == 0
                assert bit_width == port_node.width

        # switch box time
        # One SB per tile, keyed by the switchbox's own width.
        sb = SB(tile.switchbox, config_addr_width, config_data_width)
        self.sbs[sb.switchbox.width] = sb

    # lift all the sb ports up
    for _, switchbox in self.sbs.items():
        sbs = switchbox.switchbox.get_all_sbs()
        assert switchbox.switchbox.x == self.x
        assert switchbox.switchbox.y == self.y
        for sb in sbs:
            sb_name = create_name(str(sb))
            node, mux = switchbox.sb_muxs[str(sb)]
            assert node == sb
            assert sb.x == self.x
            assert sb.y == self.y
            port = switchbox.ports[sb_name]
            if node.io == SwitchBoxIO.SB_IN:
                self.add_port(sb_name, magma.In(port.base_type()))
                # FIXME:
                # it seems like I need this hack to by-pass coreIR's
                # checking, even though it's connected below
                self.wire(self.ports[sb_name], mux.ports.I)
            else:
                self.add_port(sb_name, magma.Out(port.base_type()))
            assert port.owner() == switchbox
            self.wire(self.ports[sb_name], port)

    # connect ports from cb to switch box and back
    for _, cb in self.cbs.items():
        conn_ins = cb.node.get_conn_in()
        # CB input index follows the order of the node's incoming
        # connections.
        for idx, node in enumerate(conn_ins):
            assert isinstance(node, SwitchBoxNode)
            assert node.x == self.x
            assert node.y == self.y
            bit_width = node.width
            sb_circuit = self.sbs[bit_width]
            if node.io == SwitchBoxIO.SB_IN:
                # get the internal wire
                n, sb_mux = sb_circuit.sb_muxs[str(node)]
                assert n == node
                self.wire(sb_mux.ports.O, cb.ports.I[idx])
            else:
                sb_name = create_name(str(node))
                self.wire(sb_circuit.ports[sb_name], cb.ports.I[idx])

    # connect ports from core to switch box
    for bit_width, tile in self.tiles.items():
        for _, port_node in tile.ports.items():
            # Only ports with outgoing neighbours (core outputs) are
            # handled here.
            if len(port_node) > 0:
                assert len(port_node.get_conn_in()) == 0
                port_name = port_node.name
                for sb_node in port_node:
                    assert isinstance(sb_node, SwitchBoxNode)
                    assert sb_node.x == self.x
                    assert sb_node.y == self.y
                    # Position of this port among the SB node's inputs
                    # selects the mux input to drive.
                    idx = sb_node.get_conn_in().index(port_node)
                    sb_circuit = self.sbs[port_node.width]
                    # we need to find the actual mux
                    n, mux = sb_circuit.sb_muxs[str(sb_node)]
                    assert n == sb_node
                    # the generator doesn't allow circular reference
                    # we have to be very creative here
                    # Add the core port to the SB once, then wire it to the
                    # mux inside the SB itself.
                    if port_name not in sb_circuit.ports:
                        sb_circuit.add_port(port_name,
                                            magma.In(magma.Bits(bit_width)))
                        self.wire(self.core.ports[port_name],
                                  sb_circuit.ports[port_name])
                    sb_circuit.wire(sb_circuit.ports[port_name],
                                    mux.ports.I[idx])

    # add configuration space
    # we can't use the InterconnectConfigurable because the tile class
    # doesn't have any mux
    self.__add_config()
def definition(io):
    """Build the downscale pipeline with row/column tracking and UART output.

    Reads `io.LOAD`, `io.BAUD`, `io.DATA` and `io.CLK`; drives `io.ROW`,
    `io.DONE` and `io.UART`.

    NOTE(review): `rising`, `falling`, `UART`, `buf_size`, `wi`, `wo`, `a`
    and `b` are not defined in this block; presumably module-level or
    enclosing-scope names -- confirm.
    """
    load = io.LOAD
    baud = io.BAUD

    # 13-bit modulo-`buf_size` counter, enabled while loading on a baud tick.
    valid_counter = mantle.CounterModM(buf_size, 13, has_ce=True)
    m.wire(load & baud, valid_counter.CE)

    # Counter values at which `valid` is asserted.
    valid_list = [wi * (b - 1) + i * a - 1 for i in range(1, wo + 1)]

    valid = m.GND

    # OR together one decoder per entry of valid_list.
    # NOTE(review): `valid` is not used after this loop in this block.
    for i in valid_list:
        valid = valid | mantle.Decode(i, 13)(valid_counter.O)

    # register on input
    st_in = mantle.Register(16, has_ce=True)
    st_in(io.DATA)
    m.wire(load, st_in.CE)

    # --------------------------DOWNSCALING----------------------------- #
    # downscale the image from 320x240 to 16x16
    # Downscale is only declared here; its definition lives elsewhere.
    Downscale = m.DeclareCircuit(
        'Downscale',
        "I_0_0", m.In(m.Array(1, m.Array(1, m.Array(16, m.Bit)))),
        "WE", m.In(m.Bit), "CLK", m.In(m.Clock),
        "O", m.Out(m.Array(16, m.Bit)), "V", m.Out(m.Bit))

    dscale = Downscale()

    m.wire(st_in.O, dscale.I_0_0[0][0])
    m.wire(1, dscale.WE)
    # The downscaler is clocked by `load`, not by io.CLK.
    m.wire(load, dscale.CLK)

    add16 = mantle.Add(16)  # needed for Add16 definition

    # --------------------------FILL IMG RAM--------------------------- #
    # each valid output of dscale represents an entry of 16x16 binary image
    # accumulate each group of 16 entries into a 16-bit value representing
    # a row of the image
    col = mantle.Counter(4, has_ce=True)

    # row_full is set once the column counter reaches 15 and is never reset
    # (R input is 0); it then blocks further column increments.
    row_full = mantle.SRFF(has_ce=True)
    row_full(mantle.EQ(4)(col.O, m.bits(15, 4)), 0)
    m.wire(falling(dscale.V), row_full.CE)
    col_ce = rising(dscale.V) & ~row_full.O
    m.wire(col_ce, col.CE)

    row = mantle.Counter(4, has_ce=True)

    # Same scheme one level up: img_full is set once the row counter
    # reaches 15 and then blocks further row increments.
    img_full = mantle.SRFF(has_ce=True)
    img_full(mantle.EQ(4)(row.O, m.bits(15, 4)), 0)
    m.wire(falling(col.COUT), img_full.CE)
    row_ce = rising(col.COUT) & ~img_full.O
    m.wire(row_ce, row.CE)

    # ---------------------------UART OUTPUT----------------------------- #
    # The raw 16-bit downscaler output is sent over the UART.
    uart_st = UART(16)
    uart_st(CLK=io.CLK, BAUD=baud, DATA=dscale.O, LOAD=load)

    m.wire(row.O, io.ROW)
    m.wire(img_full.O, io.DONE)
    m.wire(uart_st.O, io.UART)
def __add_config(self):
    """Add the configuration ports and the read/write decode logic.

    Fans the configuration bus out to every feature, muxes the features'
    read data onto ``read_config_data``, and gates both reads and writes on
    the address's tile-id field matching the ``tile_id`` port.
    """
    # NOTE(review): both ConfigurationType arguments are
    # config_data_width; confirm the first is not meant to be an address
    # width.
    self.add_ports(
        config=magma.In(ConfigurationType(self.config_data_width,
                                          self.config_data_width)),
        tile_id=magma.In(magma.Bits(self.tile_id_width)),
        clk=magma.In(magma.Clock),
        reset=magma.In(magma.AsyncReset),
        read_config_data=magma.Out(magma.Bits(self.config_data_width)))

    # `features` is only used for its length; the loops below call
    # self.features() again.
    features = self.features()
    num_features = len(features)
    self.read_data_mux = MuxWithDefaultWrapper(num_features,
                                               self.config_data_width,
                                               self.config_addr_width,
                                               0)
    # most of the logic copied from tile_magma.py
    # remove all hardcoded values
    for feature in self.features():
        self.wire(self.ports.config.config_addr[self.feature_config_slice],
                  feature.ports.config.config_addr)
        self.wire(self.ports.config.config_data,
                  feature.ports.config.config_data)
        self.wire(self.ports.config.read, feature.ports.config.read)

    # Connect S input to config_addr.feature.
    self.wire(self.ports.config.config_addr[self.feature_addr_slice],
              self.read_data_mux.ports.S)
    self.wire(self.read_data_mux.ports.O, self.ports.read_config_data)

    # Logic to generate EN input for read_data_mux
    self.read_and_tile = FromMagma(mantle.DefineAnd(2))
    self.eq_tile = FromMagma(mantle.DefineEQ(self.tile_id_width))
    # config_addr.tile_id == self.tile_id?
    self.wire(self.ports.tile_id, self.eq_tile.ports.I0)
    self.wire(self.ports.config.config_addr[self.tile_id_slice],
              self.eq_tile.ports.I1)
    # (config_addr.tile_id == self.tile_id) & READ
    self.wire(self.read_and_tile.ports.I0, self.eq_tile.ports.O)
    self.wire(self.read_and_tile.ports.I1, self.ports.config.read[0])
    # read_data_mux.EN = (config_addr.tile_id == self.tile_id) & READ
    self.wire(self.read_and_tile.ports.O, self.read_data_mux.ports.EN[0])

    # Logic for writing to config registers
    # Config_en_tile = (config_addr.tile_id == self.tile_id & WRITE)
    self.write_and_tile = FromMagma(mantle.DefineAnd(2))
    self.wire(self.write_and_tile.ports.I0, self.eq_tile.ports.O)
    self.wire(self.write_and_tile.ports.I1, self.ports.config.write[0])
    # Per-feature decoder and AND gate, indexed in feature order.
    self.decode_feat = []
    self.feat_and_config_en_tile = []
    for i, feat in enumerate(self.features()):
        # wire each feature's read_data output to
        # read_data_mux inputs
        self.wire(feat.ports.read_config_data,
                  self.read_data_mux.ports.I[i])
        # for each feature,
        # config_en = (config_addr.feature == feature_num) & config_en_tile
        self.decode_feat.append(
            FromMagma(mantle.DefineDecode(i, self.config_addr_width)))
        self.feat_and_config_en_tile.append(FromMagma(mantle.DefineAnd(2)))
        self.wire(self.decode_feat[i].ports.I,
                  self.ports.config.config_addr[self.feature_addr_slice])
        self.wire(self.decode_feat[i].ports.O,
                  self.feat_and_config_en_tile[i].ports.I0)
        self.wire(self.write_and_tile.ports.O,
                  self.feat_and_config_en_tile[i].ports.I1)
        self.wire(self.feat_and_config_en_tile[i].ports.O,
                  feat.ports.config.write[0])
class dut(m.Circuit):
    # Interface-only circuit named 'test_sub': two signed 63-bit inputs
    # (`a`, `b`) and one signed 64-bit output (`c`).
    name = 'test_sub'
    io = m.IO(
        a=m.In(m.SInt[63]),
        b=m.In(m.SInt[63]),
        c=m.Out(m.SInt[64]),
    )
class CSR_DUT(m.Circuit):
    # Test harness circuit: instantiates the CSR unit and a Control unit,
    # replays the stimulus lists (`insts`, `pc`, `addr`, `data`, taken from
    # the enclosing scope) indexed by a counter, and builds an independent
    # register-based reference model.  Both the CSR outputs and the model's
    # expected values are exposed on `io` so an external checker can
    # compare them.
    io = m.IO(done=m.Out(m.Bit),
              check=m.Out(m.Bit),
              rdata=m.Out(m.UInt[x_len]),
              expected_rdata=m.Out(m.UInt[x_len]),
              epc=m.Out(m.UInt[x_len]),
              expected_epc=m.Out(m.UInt[x_len]),
              evec=m.Out(m.UInt[x_len]),
              expected_evec=m.Out(m.UInt[x_len]),
              expt=m.Out(m.Bit),
              expected_expt=m.Out(m.Bit))
    io += m.ClockIO(has_reset=True)

    # Reference model state: one 32-bit register per entry of CSR.regs,
    # with non-zero reset values for mcpuid, mstatus and mtvec.
    regs = {}
    for reg in CSR.regs:
        if reg == CSR.mcpuid:
            init = (1 << (ord('I') - ord('A')) | 1 << (ord('U') - ord('A')))
        elif reg == CSR.mstatus:
            init = (CSR.PRV_M.ext(30) << 4) | (CSR.PRV_M.ext(30) << 1)
        elif reg == CSR.mtvec:
            init = Const.PC_EVEC
        else:
            init = 0
        regs[reg] = m.Register(init=BV[32](init), reset_type=m.Reset)()

    # Devices under test plus the stimulus counter.
    csr = CSRGen(x_len)()
    ctrl = Control.Control(x_len)()
    counter = CounterModM(n, n.bit_length())
    inst = m.mux(insts, counter.O)
    ctrl.inst @= inst
    csr.inst @= inst
    csr_cmd = ctrl.csr_cmd
    csr.cmd @= csr_cmd
    csr.illegal @= ctrl.illegal
    csr.st_type @= ctrl.st_type
    csr.ld_type @= ctrl.ld_type
    csr.pc_check @= ctrl.pc_sel == Control.PC_ALU
    csr.pc @= m.mux(pc, counter.O)
    csr.addr @= m.mux(addr, counter.O)
    csr.I @= m.mux(data, counter.O)
    # Stall and host interface are tied off for this test.
    csr.stall @= False
    csr.host.fromhost.valid @= False
    csr.host.fromhost.data @= 0

    # values known statically
    # NOTE: names bound in a class body are not visible inside the body of
    # a comprehension, so `csr(inst)` below does NOT call the class-level
    # `csr` instance created above; it resolves to a `csr` name in the
    # enclosing/module scope (presumably an instruction-field helper, like
    # `rs1` -- confirm against the file's imports).
    _csr_addr = [csr(inst) for inst in insts]
    _rs1_addr = [rs1(inst) for inst in insts]
    # Read-only when address bits 11 and 10 are both set, or for
    # mtvec/mtdeleg.
    _csr_ro = [((((x >> 11) & 0x1) > 0x0) & (((x >> 10) & 0x1) > 0x0)) |
               (x == CSR.mtvec) | (x == CSR.mtdeleg) for x in _csr_addr]
    _csr_valid = [x in CSR.regs for x in _csr_addr]
    # should be <= prv in runtime
    _prv_level = [(x >> 8) & 0x3 for x in _csr_addr]
    # should consider prv in runtime
    _is_ecall = [((x & 0x1) == 0x0) & (((x >> 8) & 0x1) == 0x0)
                 for x in _csr_addr]
    _is_ebreak = [((x & 0x1) > 0x0) & (((x >> 8) & 0x1) == 0x0)
                  for x in _csr_addr]
    _is_eret = [((x & 0x1) == 0x0) & (((x >> 8) & 0x1) > 0x0)
                for x in _csr_addr]
    # should consider pc_check in runtime
    _iaddr_invalid = [((x >> 1) & 0x1) > 0 for x in addr]
    # should consider ld_type & sd_type
    _waddr_invalid = [(((x >> 1) & 0x1) > 0) | ((x & 0x1) > 0) for x in addr]
    _haddr_invalid = [(x & 0x1) > 0 for x in addr]

    # values known at runtime
    csr_addr = m.mux(_csr_addr, counter.O)
    rs1_addr = m.mux(_rs1_addr, counter.O)
    csr_ro = m.mux(_csr_ro, counter.O)
    csr_valid = m.mux(_csr_valid, counter.O)
    wen = (csr_cmd == CSR.W) | (csr_cmd[1] & (rs1_addr != 0))
    # Fields of the model's mstatus: ie = bit 0, prv = bits 2:1,
    # ie1 = bit 3, prv1 = bits 5:4.
    prv1 = (regs[CSR.mstatus].O >> 4) & 0x3
    ie1 = (regs[CSR.mstatus].O >> 3) & 0x1
    prv = (regs[CSR.mstatus].O >> 1) & 0x3
    ie = regs[CSR.mstatus].O & 0x1
    prv_inst = csr_cmd == CSR.P
    prv_valid = (m.uint(m.zext_to(m.mux(_prv_level, counter.O), 32)) <=
                 m.uint(prv))
    iaddr_invalid = m.mux(_iaddr_invalid, counter.O) & csr.pc_check.value()
    laddr_invalid = (m.mux(_haddr_invalid, counter.O) &
                     ((ctrl.ld_type == Control.LD_LH) |
                      (ctrl.ld_type == Control.LD_LHU)) |
                     m.mux(_waddr_invalid, counter.O) &
                     (ctrl.ld_type == Control.LD_LW))
    saddr_invalid = (m.mux(_haddr_invalid, counter.O) &
                     (ctrl.st_type == Control.ST_SH) |
                     m.mux(_waddr_invalid, counter.O) &
                     (ctrl.st_type == Control.ST_SW))
    is_ecall = prv_inst & m.mux(_is_ecall, counter.O)
    is_ebreak = prv_inst & m.mux(_is_ebreak, counter.O)
    is_eret = prv_inst & m.mux(_is_eret, counter.O)
    exception = (ctrl.illegal | iaddr_invalid | laddr_invalid |
                 saddr_invalid |
                 (((csr_cmd & 0x3) > 0) & (~csr_valid | ~prv_valid)) |
                 (csr_ro & wen) | (prv_inst & ~prv_valid) | is_ecall |
                 is_ebreak)
    instret = (inst != nop) & (~exception | is_ecall | is_ebreak)
    # Expected read data comes from the model registers; `regs.items()` is
    # the outermost iterable, so it is evaluated in class scope.
    rdata = m.dict_lookup({key: value.O
                           for key, value in regs.items()}, csr_addr)
    wdata = m.dict_lookup(
        {
            CSR.W: csr.I.value(),
            CSR.S: (csr.I.value() | rdata),
            CSR.C: (~csr.I.value() & rdata)
        }, csr_cmd)

    # compute state
    # Free-running counters (default next-state; may be overridden by the
    # update_when calls further down -- exact override semantics depend on
    # the update_when helper, which is defined elsewhere).
    regs[CSR.time].I @= regs[CSR.time].O + 1
    regs[CSR.timew].I @= regs[CSR.timew].O + 1
    regs[CSR.mtime].I @= regs[CSR.mtime].O + 1
    regs[CSR.cycle].I @= regs[CSR.cycle].O + 1
    regs[CSR.cyclew].I @= regs[CSR.cyclew].O + 1
    time_max = regs[CSR.time].O.reduce_and()
    # TODO: mtime has same default value as this case (from chisel code)
    # https://github.com/ucb-bar/riscv-mini/blob/release/src/test/scala/CSRTests.scala#L140
    # mtime_reg = regs[CSR.mtime]
    # mtime_reg.I @= m.mux([mtime_reg.O, mtime_reg.O + 1], time_max)
    incr_when(regs[CSR.timeh], time_max)
    incr_when(regs[CSR.timehw], time_max)
    cycle_max = regs[CSR.cycle].O.reduce_and()
    incr_when(regs[CSR.cycleh], cycle_max)
    incr_when(regs[CSR.cyclehw], cycle_max)
    incr_when(regs[CSR.instret], instret)
    incr_when(regs[CSR.instretw], instret)
    instret_max = regs[CSR.instret].O.reduce_and()
    incr_when(regs[CSR.instreth], instret & instret_max)
    incr_when(regs[CSR.instrethw], instret & instret_max)

    # Explicit CSR writes: only when there is no exception, no eret, and
    # the command actually writes.
    cond = ~exception & ~is_eret & wen
    # Assuming these are mutually exclusive, so we don't need chained
    # elsewhen
    update_when(regs[CSR.mstatus], m.zext_to(wdata[0:6], 32),
                cond & (csr_addr == CSR.mstatus))
    update_when(regs[CSR.mip],
                (m.bits(wdata[7], 32) << 7) | (m.bits(wdata[3], 32) << 3),
                cond & (csr_addr == CSR.mip))
    update_when(regs[CSR.mie],
                (m.bits(wdata[7], 32) << 7) | (m.bits(wdata[3], 32) << 3),
                cond & (csr_addr == CSR.mie))
    update_when(regs[CSR.mepc], (wdata >> 2) << 2,
                cond & (csr_addr == CSR.mepc))
    update_when(regs[CSR.mcause], wdata & (1 << 31 | 0xf),
                cond & (csr_addr == CSR.mcause))
    # time / timew / mtime are kept in lock-step: a write to either timew
    # or mtime updates all three model registers (same for the high words).
    update_when(regs[CSR.time], wdata,
                cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
    update_when(regs[CSR.timew], wdata,
                cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
    update_when(regs[CSR.mtime], wdata,
                cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
    update_when(
        regs[CSR.timeh], wdata,
        cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
    update_when(
        regs[CSR.timehw], wdata,
        cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
    update_when(
        regs[CSR.mtimeh], wdata,
        cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
    update_when(regs[CSR.cycle], wdata, cond & (csr_addr == CSR.cyclew))
    update_when(regs[CSR.cyclew], wdata, cond & (csr_addr == CSR.cyclew))
    update_when(regs[CSR.cycleh], wdata, cond & (csr_addr == CSR.cyclehw))
    update_when(regs[CSR.cyclehw], wdata, cond & (csr_addr == CSR.cyclehw))
    update_when(regs[CSR.instret], wdata,
                cond & (csr_addr == CSR.instretw))
    update_when(regs[CSR.instretw], wdata,
                cond & (csr_addr == CSR.instretw))
    update_when(regs[CSR.instreth], wdata,
                cond & (csr_addr == CSR.instrethw))
    update_when(regs[CSR.instrethw], wdata,
                cond & (csr_addr == CSR.instrethw))
    update_when(regs[CSR.mtimecmp], wdata,
                cond & (csr_addr == CSR.mtimecmp))
    update_when(regs[CSR.mscratch], wdata,
                cond & (csr_addr == CSR.mscratch))
    update_when(regs[CSR.mbadaddr], wdata,
                cond & (csr_addr == CSR.mbadaddr))
    update_when(regs[CSR.mtohost], wdata, cond & (csr_addr == CSR.mtohost))
    update_when(regs[CSR.mfromhost], wdata,
                cond & (csr_addr == CSR.mfromhost))

    # eret
    # Writes PRV_U into prv1, sets ie1, and moves the old prv1/ie1 down
    # into prv/ie.
    update_when(regs[CSR.mstatus],
                (CSR.PRV_U.zext(30) << 4) | (1 << 3) | (prv1 << 1) | ie1,
                ~exception & is_eret)

    # TODO: exception logic comes after since it has priority
    Cause = make_Cause(x_len)
    # Nested muxes, outermost select wins: iaddr > laddr > saddr > ecall >
    # ebreak > illegal instruction.
    mcause = m.mux([
        m.mux([
            m.mux([
                m.mux([
                    m.mux([Cause.IllegalInst, Cause.Breakpoint], is_ebreak),
                    Cause.Ecall + prv,
                ], is_ecall),
                Cause.StoreAddrMisaligned,
            ], saddr_invalid),
            Cause.LoadAddrMisaligned,
        ], laddr_invalid),
        Cause.InstAddrMisaligned,
    ], iaddr_invalid)
    update_when(regs[CSR.mcause], mcause, exception)
    update_when(regs[CSR.mepc], (csr.pc.value() >> 2) << 2, exception)
    update_when(regs[CSR.mstatus],
                (prv << 4) | (ie << 3) | (CSR.PRV_M.zext(30) << 1),
                exception)
    update_when(
        regs[CSR.mbadaddr], csr.addr.value(),
        exception & (iaddr_invalid | laddr_invalid | saddr_invalid))

    epc = regs[CSR.mepc].O
    evec = regs[CSR.mtvec].O + (prv << 6)

    # Simulation-time trace of inputs, model state, and DUT-vs-model
    # outputs.
    m.display("*** Counter: %d ***", counter.O)
    m.display("[in] inst: 0x%x, pc: 0x%x, addr: 0x%x, in: 0x%x", csr.inst,
              csr.pc, csr.addr, csr.I)
    m.display(
        " cmd: 0x%x, st_type: 0x%x, ld_type: 0x%x, illegal: %d, "
        "pc_check: %d", csr.cmd, csr.st_type, csr.ld_type, csr.illegal,
        csr.pc_check)
    m.display("[state] csr addr: %x", csr_addr)
    for reg_addr, reg in regs.items():
        m.display(f" {hex(int(reg_addr))} -> 0x%x", reg.O)
    m.display(
        "[out] read: 0x%x =? 0x%x, epc: 0x%x =? 0x%x, evec: 0x%x ?= "
        "0x%x, expt: %d ?= %d", csr.O, rdata, csr.epc, epc, csr.evec, evec,
        csr.expt, exception)

    # `check` is asserted whenever the counter is non-zero.
    io.check @= counter.O.reduce_or()
    io.rdata @= csr.O
    io.expected_rdata @= rdata
    io.epc @= csr.epc
    io.expected_epc @= epc
    io.evec @= csr.evec
    io.expected_evec @= evec
    io.expt @= csr.expt
    io.expected_expt @= exception
    # io.failed @= counter.O.reduce_or() & (
    #     (csr.O != rdata) |
    #     (csr.epc != epc) |
    #     (csr.evec != evec) |
    #     (csr.expt != exception)
    # )
    io.done @= counter.COUT

    # Any model register whose input was never driven above simply holds
    # its current value.
    for key, reg in regs.items():
        if not reg.I.driven():
            reg.I @= reg.O
class Process(m.Circuit):
    """Assemble 16-bit pixels from a byte stream, reduce each one to a
    single summed colour value, and send that value out over a UART.

    The summed value is also exposed on ``PXV`` and the UART load strobe
    on ``LOAD``.
    """
    name = "Process"
    IO = [
        'CLK', m.In(m.Clock),
        'SCK', m.In(m.Bit),
        'DATA', m.In(m.Bits(8)),
        'VALID', m.In(m.Bit),
        'PXV', m.Out(m.Bits(16)),
        'UART', m.Out(m.Bit),
        'LOAD', m.Out(m.Bit),
    ]

    @classmethod
    def definition(io):
        sck_rise = rising(io.SCK)
        sck_fall = falling(io.SCK)

        # A pixel arrives as 16 bits (high byte, then low byte); count SCK
        # rising edges to know where we are inside it.
        counter = mantle.Counter(4, has_ce=True, has_reset=True)
        m.wire(sck_rise, counter.CE)

        # Count 7 marks the high byte as valid, count 15 the low byte.
        low_valid = mantle.Decode(15, 4)(counter.O)
        high_valid = mantle.Decode(7, 4)(counter.O)

        # Byte-wide holding registers, each loaded (and clock-enabled) by
        # its own valid strobe.
        low_reg = mantle.PIPO(8, has_ce=True)
        high_reg = mantle.PIPO(8, has_ce=True)
        low_reg(0, io.DATA, low_valid)
        high_reg(0, io.DATA, high_valid)
        m.wire(low_valid, low_reg.CE)
        m.wire(high_valid, high_reg.CE)

        # Assemble the 16-bit RGB565 value: (high << 8) + low.
        shift_left = mantle.LSL(16)
        high_wide = m.uint(m.concat(high_reg.O, zeros))
        high_shifted = m.uint(shift_left(high_wide, m.bits(8, 4)))
        low_wide = m.uint(m.concat(low_reg.O, zeros))
        pixel = high_shifted + low_wide

        # Isolate each colour channel and shift it down to bit 0.
        red_shift = mantle.LSR(16)
        red = m.uint(red_shift(pixel & RMASK, m.bits(11, 4)))
        green_shift = mantle.LSR(16)
        green = m.uint(green_shift(pixel & GMASK, m.bits(5, 4)))
        blue = m.uint(pixel & BMASK)

        # Sum of the three channels is the grayscale value (0 to 125).
        px_val = red + green + blue

        # ------------------------- UART output ------------------------- #
        # Run the 16-bit UART at 2x speed: tick on both SCK edges.
        baud = sck_rise | sck_fall

        # Reset the bit counter at the start of a pixel transfer, i.e. when
        # VALID is high but its registered copy is still low.
        valid_ff = mantle.FF(has_ce=True)
        m.wire(baud, valid_ff.CE)
        start_lut = mantle.LUT2(I0 & ~I1)
        valid_delayed = valid_ff(io.VALID)
        u_reset = start_lut(io.VALID, valid_delayed)
        m.wire(u_reset, counter.RESET)

        # Load strobe: VALID and the high-byte strobe are set while the
        # registered copy of that strobe is still low.
        high_ff = mantle.FF(has_ce=True)
        m.wire(baud, high_ff.CE)
        load_lut = mantle.LUT3(I0 & I1 & ~I2)
        high_delayed = high_ff(high_valid)
        load = load_lut(io.VALID, high_valid, high_delayed)

        uart = UART(16)
        uart(CLK=io.CLK, BAUD=baud, DATA=px_val, LOAD=load)

        m.wire(px_val, io.PXV)
        m.wire(uart, io.UART)
        m.wire(load, io.LOAD)
# Fragment of a top-level wiring script: `load`, `baud`, `printf`, `rom`,
# `main` and `UART` are defined outside this view.

# Clock-enable `printf` only on baud ticks where `load` is asserted.
m.wire(load & baud, printf.CE)
# The pixel value for the rest of this section is the ROM output.
px_val = rom.O

# register on input
st_in = mantle.Register(16, has_ce=True)
st_in(px_val)
m.wire(load, st_in.CE)

# ---------------------------STENCILING----------------------------- #
# Declaration only (no body here) of the external 'Downscale' circuit:
# one 16-bit pixel in, one 16-bit value plus a 1-bit `V` output out.
Downscale = m.DeclareCircuit(
    'Downscale',
    "I_0_0", m.In(m.Array(1, m.Array(1, m.Array(16, m.Bit)))),
    "WE", m.In(m.Bit),
    "CLK", m.In(m.Clock),
    "O", m.Out(m.Array(16, m.Bit)),
    "V", m.Out(m.Bit))

dscale = Downscale()

m.wire(st_in.O, dscale.I_0_0[0][0])
m.wire(1, dscale.WE)  # write-enable tied high
# The downscale block is clocked by the `load` strobe, not by a
# free-running clock.
m.wire(load, dscale.CLK)

add16 = mantle.Add(16)  # needed for Add16 definition

# ---------------------------UART OUTPUT----------------------------- #
# Two 16-bit UARTs are created; only `uart_px` is wired in this view.
uart_px = UART(16)
uart_px(CLK=main.CLKIN, BAUD=baud, DATA=px_val, LOAD=load)
uart_st = UART(16)
# Parameters and port types for the 'Downscale' top-level circuit.

# `a * b` is the number of inputs of the reduce interface below;
# presumably `a` and `b` are the downscale window's two dimensions --
# confirm against the code that uses them.
a = 20
b = 15
samples = 16

# image dimensions (height and width)
im_w = 320
im_h = 240

c = coreir.Context()
cirb = CoreIRBackend(c)
scope = Scope()

# 8-bit values but extend to 16-bit to avoid carryover in addition
width = 16
TIN = m.Array(width, m.BitIn)
# NOTE(review): bare `Bit` here versus `m.Bit` elsewhere in this block --
# confirm `Bit` is actually in scope.
TOUT = m.Array(width, m.Out(Bit))

# Line Buffer interface
inType = m.Array(1, m.Array(1, TIN))  # one pixel in per clock
# NOTE(review): the outer dimension is `width` (the bit width), not `b`;
# verify this is the intended window shape.
outType = m.Array(width, m.Array(a, TOUT))  # downscale window
imgType = m.Array(im_h, m.Array(im_w, TIN))  # image dimensions

# Reduce interface
inType2 = m.In(m.Array(a*b, TIN))
outType2 = TOUT

# Top level module: line buffer input, reduce output
args = ['I', inType, 'O', outType2, 'WE', m.BitIn, 'V', m.Out(m.Bit)] + \
    m.ClockInterface(False, False)

# Opens the 'Downscale' definition; its body is outside this view.
top = m.DefineCircuit('Downscale', *args)
class TestPLL(m.Circuit):
    """Drive ``CLKOUT`` from an ``SB_PLL`` instance fed by ``CLKIN``."""
    io = m.IO(
        CLKIN=m.In(m.Clock),
        CLKOUT=m.Out(m.Clock),
    )

    # SB_PLL is parameterised with (32000000, 16000000); presumably the
    # input and output frequencies in Hz -- confirm against SB_PLL.
    clk = SB_PLL(32_000_000, 16_000_000)(I=io.CLKIN)
    m.wire(clk, io.CLKOUT)
def DefineDecode(i, n):
    """Define a circuit that compares its ``n``-bit input against ``i``.

    The circuit has an ``n``-bit input ``I`` and a 1-bit output ``O``.
    ``O`` is driven by ``EQ(n)`` applied to ``I`` and the constant
    ``m.bits(i, n)`` (``EQ`` is defined elsewhere; presumably an ``n``-bit
    equality comparator).

    Args:
        i: Constant value to compare against.
        n: Width of the input in bits.

    Returns:
        The circuit definition, named ``Decode{i}_{n}``.
    """
    # The two parameters are separated in the name.  Plain concatenation is
    # ambiguous: (i=11, n=4) and (i=1, n=14) would both be "Decode114",
    # giving two different circuits the same name.
    circ = m.DefineCircuit(f"Decode{i}_{n}",
                           "I", m.In(m.Bits(n)),
                           "O", m.Out(m.Bit))
    m.wire(circ.O, EQ(n)(circ.I, m.bits(i, n)))
    m.EndDefine()
    return circ
def __init__(self, x_len: int):
    # Build the CSR (control and status register) unit for an
    # `x_len`-bit datapath: declares the IO, instantiates the state
    # registers, assembles the read map (`csr_file`), computes the
    # exception condition, and describes the next-state logic in
    # `logic()` below.
    Cause = make_Cause(x_len)
    self.io = io = m.IO(
        stall=m.In(m.Bit),
        cmd=m.In(m.UInt[3]),
        I=m.In(m.UInt[x_len]),
        O=m.Out(m.UInt[x_len]),
        # Exception
        pc=m.In(m.UInt[x_len]),
        addr=m.In(m.UInt[x_len]),
        inst=m.In(m.UInt[x_len]),
        illegal=m.In(m.Bit),
        st_type=m.In(m.UInt[2]),
        ld_type=m.In(m.UInt[3]),
        pc_check=m.In(m.Bit),
        expt=m.Out(m.Bit),
        evec=m.Out(m.UInt[x_len]),
        epc=m.Out(
            m.UInt[x_len])) + HostIO(x_len) + m.ClockIO(has_reset=True)

    # Instruction fields: CSR address in bits 31:20, rs1 in bits 19:15.
    csr_addr = io.inst[20:32]
    rs1_addr = io.inst[15:20]

    # user counters
    time = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    timeh = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    cycle = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    cycleh = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    instret = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    instreth = m.Register(m.UInt[x_len], reset_type=m.Reset)()

    # Constant, read-only identification values.  The mcpuid operand
    # widths (26 + (x_len - 28) + 2) add up to x_len.
    mcpuid = m.concat(
        BV[26](
            1 << (ord('I') - ord('A')) |  # Base ISA
            1 << (ord('U') - ord('A'))),  # User Mode
        BV[x_len - 28](0),
        BV[2](0),  # RV32I
    )
    mimpid = BV[x_len](0)
    mhartid = BV[x_len](0)

    # interrupt enable stack
    # Only levels 0 and 1 are registers; levels 2 and 3 are constants.
    PRV = m.Register(m.UInt[len(CSR.PRV_M)],
                     init=CSR.PRV_M,
                     reset_type=m.Reset)()
    PRV1 = m.Register(m.UInt[len(CSR.PRV_M)],
                      init=CSR.PRV_M,
                      reset_type=m.Reset)()
    PRV2 = BV[2](0)
    PRV3 = BV[2](0)
    IE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    IE1 = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    IE2 = False
    IE3 = False
    # virtualization management field
    VM = BV[5](0)
    # memory privilege
    MPRV = False
    # Extension context status
    XS = BV[2](0)
    FS = BV[2](0)
    SD = BV[1](0)
    # The mstatus write path in logic() reads IE from wdata[0], PRV from
    # wdata[1:3], IE1 from wdata[3] and PRV1 from wdata[4:6], matching the
    # operand order here (first operand presumably lands at bit 0).
    mstatus = m.concat(IE.O, PRV.O, IE1.O, PRV1.O, IE2, PRV2, IE3, PRV3, FS,
                       XS, MPRV, VM, BV[x_len - 23](0), SD)
    mtvec = BV[x_len](Const.PC_EVEC)
    mtdeleg = BV[x_len](0)

    # interrupt registers
    # Only the machine-level (M*) bits are registers; H*/S* are constant
    # False.
    MTIP = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    HTIP = False
    STIP = False
    MTIE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    HTIE = False
    STIE = False
    MSIP = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    HSIP = False
    SSIP = False
    MSIE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
    HSIE = False
    SSIE = False
    # The mip/mie write paths in logic() read the software bit from
    # wdata[3] and the timer bit from wdata[7], matching the operand
    # positions of MSIP/MSIE and MTIP/MTIE here.
    mip = m.concat(Bit(False), SSIP, HSIP, MSIP.O, Bit(False), STIP, HTIP,
                   MTIP.O, BV[x_len - 8](0))
    mie = m.concat(Bit(False), SSIE, HSIE, MSIE.O, Bit(False), STIE, HTIE,
                   MTIE.O, BV[x_len - 8](0))

    mtimecmp = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    mscratch = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    mepc = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    mcause = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    mbadaddr = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    mtohost = m.Register(m.UInt[x_len], reset_type=m.Reset)()
    mfromhost = m.Register(m.UInt[x_len], reset_type=m.Reset)()

    io.host.tohost @= mtohost.O

    # Read map: CSR address -> value returned on io.O.  Several addresses
    # alias the same storage (e.g. cycle/cyclew, time/timew/mtime).
    csr_file = {
        CSR.cycle: cycle.O,
        CSR.time: time.O,
        CSR.instret: instret.O,
        CSR.cycleh: cycleh.O,
        CSR.timeh: timeh.O,
        CSR.instreth: instreth.O,
        CSR.cyclew: cycle.O,
        CSR.timew: time.O,
        CSR.instretw: instret.O,
        CSR.cyclehw: cycleh.O,
        CSR.timehw: timeh.O,
        CSR.instrethw: instreth.O,
        CSR.mcpuid: mcpuid,
        CSR.mimpid: mimpid,
        CSR.mhartid: mhartid,
        CSR.mtvec: mtvec,
        CSR.mtdeleg: mtdeleg,
        CSR.mie: mie,
        CSR.mtimecmp: mtimecmp.O,
        CSR.mtime: time.O,
        CSR.mtimeh: timeh.O,
        CSR.mscratch: mscratch.O,
        CSR.mepc: mepc.O,
        CSR.mcause: mcause.O,
        CSR.mbadaddr: mbadaddr.O,
        CSR.mip: mip,
        CSR.mtohost: mtohost.O,
        CSR.mfromhost: mfromhost.O,
        CSR.mstatus: mstatus,
    }
    out = m.dict_lookup(csr_file, csr_addr)
    io.O @= out

    # Address bits 9:8 give the privilege level required for the access.
    priv_valid = csr_addr[8:10] <= PRV.O
    priv_inst = io.cmd == CSR.P
    # Privileged instructions are told apart by address bits 0 and 8.
    is_E_call = priv_inst & ~csr_addr[0] & ~csr_addr[8]
    is_E_break = priv_inst & csr_addr[0] & ~csr_addr[8]
    is_E_ret = priv_inst & ~csr_addr[0] & csr_addr[8]
    # Valid when the address matches any key of the read map.
    csr_valid = m.reduce(operator.or_,
                         m.bits([csr_addr == key for key in csr_file]))
    # Read-only when address bits 11:10 are both set, or for
    # mtvec/mtdeleg.
    csr_RO = (csr_addr[10:12].reduce_and() | (csr_addr == CSR.mtvec) |
              (csr_addr == CSR.mtdeleg))
    # `&` binds tighter than `|`: W always writes; commands with bit 1 set
    # write only when rs1 is non-zero.
    wen = (io.cmd == CSR.W) | io.cmd[1] & rs1_addr.reduce_or()
    wdata = m.dict_lookup(
        {
            CSR.W: io.I,
            CSR.S: out | io.I,
            CSR.C: out & ~io.I
        }, io.cmd)

    # Misalignment checks on the low address bits.
    iaddr_invalid = io.pc_check & io.addr[1]
    laddr_invalid = m.dict_lookup(
        {
            Control.LD_LW: io.addr[0:2].reduce_or(),
            Control.LD_LH: io.addr[0],
            Control.LD_LHU: io.addr[0]
        }, io.ld_type)
    saddr_invalid = m.dict_lookup(
        {
            Control.ST_SW: io.addr[0:2].reduce_or(),
            Control.ST_SH: io.addr[0]
        }, io.st_type)

    expt = (io.illegal | iaddr_invalid | laddr_invalid | saddr_invalid |
            io.cmd[0:2].reduce_or() & (~csr_valid | ~priv_valid) |
            wen & csr_RO | (priv_inst & ~priv_valid) | is_E_call |
            is_E_break)
    io.expt @= expt
    io.evec @= mtvec + (m.zext_to(PRV.O, x_len) << 6)
    io.epc @= mepc.O

    @m.inline_combinational()
    def logic():
        # Next-state logic.  Each register first gets a default (hold or
        # increment); the conditional assignments that follow are meant
        # to override those defaults.
        # Counters
        time.I @= time.O + 1
        timeh.I @= timeh.O
        if time.O.reduce_and():
            timeh.I @= timeh.O + 1

        cycle.I @= cycle.O + 1
        cycleh.I @= cycleh.O
        if cycle.O.reduce_and():
            cycleh.I @= cycleh.O + 1

        instret.I @= instret.O
        # An instruction retires when it is not a NOP, is not stalled, and
        # either raised no exception or is an ecall/ebreak.
        is_inst_ret = ((io.inst != Instructions.NOP) &
                       (~expt | is_E_call | is_E_break) & ~io.stall)
        if is_inst_ret:
            instret.I @= instret.O + 1

        instreth.I @= instreth.O
        if is_inst_ret & instret.O.reduce_and():
            instreth.I @= instreth.O + 1

        # Hold defaults for the remaining state.
        mbadaddr.I @= mbadaddr.O
        mepc.I @= mepc.O
        mcause.I @= mcause.O
        PRV.I @= PRV.O
        IE.I @= IE.O
        IE1.I @= IE1.O
        PRV1.I @= PRV1.O
        MTIP.I @= MTIP.O
        MSIP.I @= MSIP.O
        MTIE.I @= MTIE.O
        MSIE.I @= MSIE.O
        mtimecmp.I @= mtimecmp.O
        mscratch.I @= mscratch.O
        mtohost.I @= mtohost.O
        mfromhost.I @= mfromhost.O
        if io.host.fromhost.valid:
            mfromhost.I @= io.host.fromhost.data

        if ~io.stall:
            # Priority: exception, then eret, then an explicit CSR write.
            if expt:
                # Save the word-aligned pc and record the cause.
                mepc.I @= io.pc >> 2 << 2
                if iaddr_invalid:
                    mcause.I @= Cause.InstAddrMisaligned
                elif laddr_invalid:
                    mcause.I @= Cause.LoadAddrMisaligned
                elif saddr_invalid:
                    mcause.I @= Cause.StoreAddrMisaligned
                elif is_E_call:
                    mcause.I @= Cause.Ecall + m.zext_to(PRV.O, x_len)
                elif is_E_break:
                    mcause.I @= Cause.Breakpoint
                else:
                    mcause.I @= Cause.IllegalInst
                # Push the privilege / interrupt-enable stack.
                PRV.I @= CSR.PRV_M
                IE.I @= False
                PRV1.I @= PRV.O
                IE1.I @= IE.O
                if iaddr_invalid | laddr_invalid | saddr_invalid:
                    mbadaddr.I @= io.addr
            elif is_E_ret:
                # Pop the privilege / interrupt-enable stack.
                PRV.I @= PRV1.O
                IE.I @= IE1.O
                PRV1.I @= CSR.PRV_U
                IE1.I @= True
            elif wen:
                if csr_addr == CSR.mstatus:
                    PRV1.I @= wdata[4:6]
                    IE1.I @= wdata[3]
                    PRV.I @= wdata[1:3]
                    IE.I @= wdata[0]
                elif csr_addr == CSR.mip:
                    MTIP.I @= wdata[7]
                    MSIP.I @= wdata[3]
                elif csr_addr == CSR.mie:
                    MTIE.I @= wdata[7]
                    MSIE.I @= wdata[3]
                elif csr_addr == CSR.mtime:
                    time.I @= wdata
                elif csr_addr == CSR.mtimeh:
                    timeh.I @= wdata
                elif csr_addr == CSR.mtimecmp:
                    mtimecmp.I @= wdata
                elif csr_addr == CSR.mscratch:
                    mscratch.I @= wdata
                elif csr_addr == CSR.mepc:
                    # Keep mepc word-aligned.
                    mepc.I @= wdata >> 2 << 2
                elif csr_addr == CSR.mcause:
                    # Only the top bit and the low four bits are writable.
                    mcause.I @= wdata & (1 << (x_len - 1) | 0xf)
                elif csr_addr == CSR.mbadaddr:
                    mbadaddr.I @= wdata
                elif csr_addr == CSR.mtohost:
                    mtohost.I @= wdata
                elif csr_addr == CSR.mfromhost:
                    mfromhost.I @= wdata
                elif csr_addr == CSR.cyclew:
                    cycle.I @= wdata
                elif csr_addr == CSR.timew:
                    time.I @= wdata
                elif csr_addr == CSR.instretw:
                    instret.I @= wdata
                elif csr_addr == CSR.cyclehw:
                    cycleh.I @= wdata
                elif csr_addr == CSR.timehw:
                    timeh.I @= wdata
                elif csr_addr == CSR.instrethw:
                    instreth.I @= wdata