def print_kernel(name: str,
                 printer: util.Printer,
                 node: ir.Module,
                 module_trait: ir.ModuleTrait,
                 module_trait_id: int,
                 burst_width: int = 256) -> None:
    """Print the OpenCL kernel that implements one dataflow module.

    The emitted text is, in order: comment lines describing the channel
    connections, the `__kernel` qualifier with its attributes, the function
    signature, declarations for every delayed reference, and one loop that
    loads the inputs, evaluates the lets/expressions of the module trait, and
    stores the outputs.

    Args:
      name: name of the emitted kernel function.
      printer: printer that receives the generated code.
      node: module instance; supplies the DRAM references and the FIFOs.
      module_trait: trait of the module; supplies loads, lets, and exprs.
      module_trait_id: ID of the module trait; not read by this function.
      burst_width: bit width of the vector type used for each DRAM pointer;
        `burst_width // width_in_bits` elements are buffered per iteration.

    Raises:
      ValueError: if the module both reads from and writes to DRAM.
      NotImplementedError: if, while rewriting a DRAM reference, the DRAM
        throughput does not equal the FIFO throughput.
    """
    if node.dram_reads and node.dram_writes:
        raise ValueError('cannot read and write DRAM in the same module')

    println = printer.println
    accesses_dram = bool(node.dram_reads or node.dram_writes)

    # Describe the channel connections as comments ahead of the kernel.
    for load, channel in zip(module_trait.loads, node.input_fifos):
        println(f'// input <{load.cl_type}> <- {channel}')
    for expr, channel in zip(module_trait.exprs, node.output_fifos):
        println(f'// output <{expr.cl_type}> -> {channel}')

    # Kernel qualifier, attributes, and signature.  A module that touches
    # DRAM takes pointer parameters; any other module is an autorun kernel
    # without parameters.
    println('__kernel')
    attributes = ['reqd_work_group_size(1, 1, 1)', 'max_global_work_dim(0)']
    if accesses_dram:
        params = []
        for dram_ref, bank in node.dram_reads or node.dram_writes:
            location = f'__attribute((buffer_location("HBM{bank}")))'
            vec_type = dram_ref.haoda_type.get_cl_vec_type(burst_width)
            port_name = util.get_port_name(dram_ref.var, bank)
            params.append(
                f'__global {location} {vec_type}* restrict {port_name}')
        params.append('ulong coalesced_data_num')
        attributes.append('uses_global_work_offset(0)')
    else:
        attributes.append('autorun')
    for attribute in attributes:
        println(f'__attribute(({attribute}))')
    if accesses_dram:
        printer.print_func(name=f'void {name}', params=params, align=0)
    else:
        println(f'void {name}()')
    printer.do_scope(name)

    # Collect every DelayedRef reachable from the lets and the expressions.
    def get_delays(obj: ir.Node, found: List[ir.DelayedRef]) -> ir.Node:
        if isinstance(obj, ir.DelayedRef):
            found.append(obj)
        return obj

    delays: List[ir.DelayedRef] = []
    for let in module_trait.lets:
        let.visit(get_delays, delays)
    for expr in module_trait.exprs:
        expr.visit(get_delays, delays)
    _logger.debug('delays: %s', delays)

    # Declarations that persist across loop iterations.
    for delay in delays:
        println(delay.cl_buf_decl)
        println(delay.cl_ptr_decl)

    def mutate_dram_ref(obj: ir.Node, kwargs: Dict[str, int]) -> ir.Node:
        """Rewrite DRAM references into accesses to the local buffers."""
        if isinstance(obj, ir.DRAMRef):
            bank_count = len(obj.dram)
            dram_throughput = bank_count * burst_width // obj.width_in_bits
            # A reading module feeds its output FIFOs; a writing module
            # drains its input FIFOs.
            if node.dram_reads:
                dram_refs, fifos = node.dram_reads, node.output_fifos
            else:
                dram_refs, fifos = node.dram_writes, node.input_fifos
            var_count = len({ref.var for ref, _ in dram_refs})
            fifo_throughput = len(fifos) // var_count
            if dram_throughput != fifo_throughput:
                raise NotImplementedError(
                    f'memory throughput {dram_throughput} != '
                    f'processing throughput {fifo_throughput}')
            elem_idx = (kwargs['coalescing_idx'] * kwargs['unroll_factor'] +
                        obj.offset)
            # Elements are interleaved across the banks of the reference.
            buf = obj.dram_buf_name(obj.dram[elem_idx % bank_count])
            return ir.Var(name=f'{buf}[{elem_idx // bank_count}]', idx=())
        if isinstance(obj, ir.Let) and isinstance(obj.name, ir.DRAMRef):
            # A let that targets DRAM becomes a store into the local buffer.
            target = obj.name.visit(mutate_dram_ref, kwargs)
            return ir.Var(name=f'{target} = {obj.expr.cl_expr};', idx=())
        return obj

    loop_args = (('ulong i = 0', 'i < coalesced_data_num', '++i')
                 if accesses_dram else ('', '', ''))
    if delays:
        println('#pragma ivdep', indent=0)
    with printer.for_(*loop_args):
        # Load the delayed values (if any).
        for delay in delays:
            println(f'const {delay.cl_type} {delay.c_buf_load};')

        # Load from DRAM (if any): declare all buffers, then fill them.
        for dram_ref, bank in node.dram_reads:
            count = burst_width // dram_ref.width_in_bits
            println(f'{dram_ref.cl_type} '
                    f'{dram_ref.dram_buf_name(bank)}[{count}];')
        for dram_ref, bank in node.dram_reads:
            vec_ptr = util.get_port_name(dram_ref.var, bank)
            count = burst_width // dram_ref.width_in_bits
            buf = dram_ref.dram_buf_name(bank)
            println(f'vstore{count}({vec_ptr}[i], 0, {buf});')

        # Read from the input channels.
        for load, channel in zip(module_trait.loads, node.input_fifos):
            println(f'{load.cl_type} {load.ref_name} = '
                    f'read_channel_intel({channel});')

        # Declare the buffers that DRAM writes are staged in.
        for dram_ref, bank in node.dram_writes:
            count = burst_width // dram_ref.width_in_bits
            println(f'{dram_ref.cl_type} '
                    f'{dram_ref.dram_buf_name(bank)}[{count}];')

        mutate_kwargs = {
            'coalescing_idx': 0,
            'unroll_factor': len(node.input_fifos) or len(node.output_fifos),
        }

        # Lets (if any), then the writes to the output channels.
        for let in module_trait.lets:
            println(let.visit(mutate_dram_ref, mutate_kwargs).cl_expr)
        for expr, channel in zip(module_trait.exprs, node.output_fifos):
            value = expr.visit(mutate_dram_ref, mutate_kwargs).cl_expr
            println(f'write_channel_intel({channel}, {value});')

        # Update the delayed values (if any).
        for delay in delays:
            println(delay.c_buf_store)
            println(f'{delay.ptr} = {delay.cl_next_ptr_expr};')

        # Store to DRAM (if any).
        for dram_ref, bank in node.dram_writes:
            vec_ptr = util.get_port_name(dram_ref.var, bank)
            count = burst_width // dram_ref.width_in_bits
            buf = dram_ref.dram_buf_name(bank)
            println(f'{vec_ptr}[i] = vload{count}(0, {buf});')

    printer.un_scope()  # end of kernel function
    _logger.debug('printing: %s', module_trait)
def print_top_module(
    printer: backend.VerilogPrinter,
    super_source: dataflow.SuperSourceNode,
    inputs: Sequence[Tuple[str, str, int, str]],
    outputs: Sequence[Tuple[str, str, int, str]],
    module_name: str = 'Dataflow',
    interface: str = 'm_axi',
) -> None:
    """Print the top-level Verilog module of the dataflow design.

    The module declares the external ports, a pipelined reset, one FIFO per
    internal connection, and one instance per dataflow module.  The FIFO
    module definitions are printed after `endmodule`.

    Args:
      printer: printer to print to
      super_source: SuperSourceNode carrying the IR tree.
      inputs: sequence of (port_name, bundle_name, width, _) of input ports
      outputs: sequence of (port_name, bundle_name, width, _) of output ports
      module_name: name of the module
      interface: interface type, supported values are 'm_axi' and 'axis'

    Raises:
      ValueError: if `interface` is not one of the supported values.
    """
    # Any other value would skip every interface-specific branch below and
    # silently emit a module without reset or data ports.
    if interface not in ('m_axi', 'axis'):
        raise ValueError(f'unsupported interface: {interface!r}')

    printer.printlns('`timescale 1 ns / 1 ps', '`default_nettype none')
    ports = *inputs, *outputs

    # unpack suffixes
    data_in = FIFO_PORT_SUFFIXES['data_in']
    not_full = FIFO_PORT_SUFFIXES['not_full']
    write_enable = FIFO_PORT_SUFFIXES['write_enable']
    data_out = FIFO_PORT_SUFFIXES['data_out']
    not_empty = FIFO_PORT_SUFFIXES['not_empty']
    read_enable = FIFO_PORT_SUFFIXES['read_enable']
    not_block = FIFO_PORT_SUFFIXES['not_block']
    data = AXIS_PORT_SUFFIXES['data']
    valid = AXIS_PORT_SUFFIXES['valid']
    ready = AXIS_PORT_SUFFIXES['ready']

    # prepare arguments
    input_args = ['ap_clk']
    output_args: List[str] = []
    if interface == 'm_axi':
        input_args += 'ap_rst', 'ap_start', 'ap_continue'
        output_args += 'ap_done', 'ap_idle', 'ap_ready'
    elif interface == 'axis':
        input_args.append('ap_rst_n')
    args = list(input_args + output_args)
    if interface == 'm_axi':
        for port_name, _, _, _ in outputs:
            args.append(f'{port_name}_V_V{data_in}')
            args.append(f'{port_name}_V_V{not_full}')
            args.append(f'{port_name}_V_V{write_enable}')
        for port_name, _, _, _ in inputs:
            args.append(f'{port_name}_V_V{data_out}')
            args.append(f'{port_name}_V_V{not_empty}')
            args.append(f'{port_name}_V_V{read_enable}')
    elif interface == 'axis':
        for port_name, _, _, _ in ports:
            args.extend(
                port_name + suffix for suffix in AXIS_PORT_SUFFIXES.values())

    # print module interface
    printer.module(module_name, args)
    printer.println()

    # print signals for modules
    printer.printlns(
        *(f'input wire {arg};' for arg in input_args),
        *(f'output wire {arg};' for arg in output_args),
    )
    for port_name, _, width, _ in outputs:
        if interface == 'm_axi':
            printer.printlns(
                f'output wire [{width - 1}:0] {port_name}_V_V{data_in};',
                f'input wire {port_name}_V_V{not_full};',
                f'output wire {port_name}_V_V{write_enable};',
            )
        elif interface == 'axis':
            # The AXI-Stream signals are the ports; the `_V_V` signals are
            # internal wires on the other side of the boundary FIFO.
            printer.printlns(
                f'output wire [{width - 1}:0] {port_name}{data};',
                f'output wire {port_name}{valid};',
                f'input wire {port_name}{ready};',
                f'wire [{width - 1}:0] {port_name}_V_V{data_in};',
                f'wire {port_name}_V_V{not_full};',
                f'wire {port_name}_V_V{write_enable};',
            )
    for port_name, _, width, _ in inputs:
        if interface == 'm_axi':
            printer.printlns(
                f'input wire [{width - 1}:0] {port_name}_V_V{data_out};',
                f'input wire {port_name}_V_V{not_empty};',
                f'output wire {port_name}_V_V{read_enable};',
            )
        elif interface == 'axis':
            printer.printlns(
                f'input wire [{width - 1}:0] {port_name}{data};',
                f'input wire {port_name}{valid};',
                f'output wire {port_name}{ready};',
                f'wire [{width - 1}:0] {port_name}_V_V{data_out};',
                f'wire {port_name}_V_V{not_empty};',
                f'wire {port_name}_V_V{read_enable};',
            )
    printer.println()

    # not used
    printer.printlns(
        "reg ap_done = 1'b0;",
        "reg ap_idle = 1'b1;",
        "reg ap_ready = 1'b0;",
    )
    if interface == 'axis':
        # There is no ap_start port in this mode; tie it high.
        printer.println("wire ap_start = 1'b1;")

    # print signals for FIFOs
    if interface == 'm_axi':
        for port_name, _, width, _ in outputs:
            printer.printlns(
                f'reg [{width - 1}:0] {port_name}{data_in};',
                f'wire {port_name}_V_V{write_enable};',
            )
        for port_name, _, _, _ in inputs:
            printer.println(f'wire {port_name}_V_V{read_enable};')
    printer.println()

    # register reset signal: a chain of `ap_rst_reg_level` registers, with
    # the max_fanout attribute growing by a factor of 8 per stage
    ap_rst_reg_level = 8
    rst = 'ap_rst'
    if interface == 'axis':
        rst = '~ap_rst_n'  # ap_rst_n is inverted to match ap_rst
    printer.printlns(
        f'wire ap_rst_reg_0 = {rst};',
        *(f'(* shreg_extract = "no", max_fanout = {8 ** i} *) '
          f'reg ap_rst_reg_{i};' for i in range(1, ap_rst_reg_level)),
        f'(* shreg_extract = "no" *) reg ap_rst_reg_{ap_rst_reg_level};',
        f'wire ap_rst_reg = ap_rst_reg_{ap_rst_reg_level};',
    )
    if ap_rst_reg_level > 0:
        with printer.always('posedge ap_clk'):
            printer.printlns(f'ap_rst_reg_{i + 1} <= ap_rst_reg_{i};'
                             for i in range(ap_rst_reg_level))

    with printer.always('posedge ap_clk'):
        with printer.if_('ap_rst_reg'):
            printer.printlns(
                "ap_done <= 1'b0;",
                "ap_idle <= 1'b1;",
                "ap_ready <= 1'b0;",
            )
            # else_() switches to the else branch of the enclosing if_ block.
            printer.else_()
            printer.println('ap_idle <= ~ap_start;')
    printer.println()

    if interface == 'm_axi':
        # used by cosim for deadlock detection
        printer.printlns(
            f'reg {port_name}_V_V{not_block};' for port_name, _, _, _ in ports)
        with printer.always('*'):
            printer.printlns(
                *(f'{port_name}_V_V{not_block} = {port_name}_V_V{not_full};'
                  for port_name, _, _, _ in outputs),
                *(f'{port_name}_V_V{not_block} = {port_name}_V_V{not_empty};'
                  for port_name, _, _, _ in inputs),
            )
        printer.println()

    fifos: Set[Tuple[int, int]] = set()  # used for printing FIFO modules

    if interface == 'axis':
        # Each AXI-Stream port is connected through a FIFO of depth 2.
        for port_name, _, width, _ in inputs:
            printer.module_instance(
                f'fifo_w{width}_d2_A',
                port_name + '_fifo',
                args={
                    'clk': 'ap_clk',
                    'reset': 'ap_rst_reg',
                    'if_read_ce': "1'b1",
                    'if_write_ce': "1'b1",
                    f'if{data_in}': f'{port_name}{data}',
                    f'if{not_full}': f'{port_name}{ready}',
                    f'if{write_enable}': f'{port_name}{valid}',
                    f'if{data_out}': f'{port_name}_V_V{data_out}',
                    f'if{not_empty}': f'{port_name}_V_V{not_empty}',
                    f'if{read_enable}': f'{port_name}_V_V{read_enable}',
                },
            )
            fifos.add((width, 2))
            printer.println()
        for port_name, _, width, _ in outputs:
            printer.module_instance(
                f'fifo_w{width}_d2_A',
                port_name + '_fifo',
                args={
                    'clk': 'ap_clk',
                    'reset': 'ap_rst_reg',
                    'if_read_ce': "1'b1",
                    'if_write_ce': "1'b1",
                    f'if{data_in}': f'{port_name}_V_V{data_in}',
                    f'if{not_full}': f'{port_name}_V_V{not_full}',
                    f'if{write_enable}': f'{port_name}_V_V{write_enable}',
                    f'if{data_out}': f'{port_name}{data}',
                    f'if{not_empty}': f'{port_name}{valid}',
                    f'if{read_enable}': f'{port_name}{ready}',
                },
            )
            fifos.add((width, 2))
            printer.println()

    # print FIFO instances
    for module in super_source.tpo_valid_node_gen():
        for fifo in module.fifos:
            name = fifo.c_expr
            msb = fifo.width_in_bits - 1
            # fifo.depth is the "extra" capacity of a FIFO; the base depth is
            # 3, 1 for registering the input, 1 for keeping II=1 when FIFO is
            # (almost) full, 1 for keeping II=1 when FIFO is relaxed from
            # back pressure (necessary because the optimal FIFO depths may
            # require back pressure)
            depth = fifo.depth + 3
            printer.printlns(
                f'wire [{msb}:0] {name}{data_in};',
                f'wire {name}{not_full};',
                f'wire {name}{write_enable};',
                f'wire [{msb}:0] {name}{data_out};',
                f'wire {name}{not_empty};',
                f'wire {name}{read_enable};',
                '',
            )
            printer.module_instance(
                f'fifo_w{fifo.width_in_bits}_d{depth}_A',
                name,
                args={
                    'clk': 'ap_clk',
                    'reset': 'ap_rst_reg',
                    'if_read_ce': "1'b1",
                    'if_write_ce': "1'b1",
                    f'if{data_in}': f'{name}{data_in}',
                    f'if{not_full}': f'{name}{not_full}',
                    f'if{write_enable}': f'{name}{write_enable}',
                    f'if{data_out}': f'{name}{data_out}',
                    f'if{not_empty}': f'{name}{not_empty}',
                    f'if{read_enable}': f'{name}{read_enable}',
                },
            )
            fifos.add((fifo.width_in_bits, depth))
            printer.println()

    # print module instances
    for module in super_source.tpo_valid_node_gen():
        module_trait, module_trait_id = super_source.module_table[module]
        arg_dict = {
            'ap_clk': 'ap_clk',
            'ap_rst': 'ap_rst_reg',
            'ap_start': "1'b1",
        }
        for dram_ref, bank in module.dram_writes:
            port = dram_ref.dram_fifo_name(bank)
            fifo = util.get_port_name(dram_ref.var, bank)
            arg_dict.update({
                f'{port}_V{data_in}': f'{fifo}_V_V{data_in}',
                f'{port}_V{not_full}': f'{fifo}_V_V{not_full}',
                f'{port}_V{write_enable}': f'{fifo}_V_V{write_enable}',
            })
        for port, fifo in zip(module_trait.output_fifos, module.output_fifos):
            arg_dict.update({
                f'{port}_V{data_in}': f'{fifo}{data_in}',
                f'{port}_V{not_full}': f'{fifo}{not_full}',
                f'{port}_V{write_enable}': f'{fifo}{write_enable}',
            })
        for port, fifo in zip(module_trait.input_fifos, module.input_fifos):
            # The data read by a module is prefixed with a constant 1 bit.
            arg_dict.update({
                f'{port}_V{data_out}': f"{{1'b1, {fifo}{data_out}}}",
                f'{port}_V{not_empty}': f'{fifo}{not_empty}',
                f'{port}_V{read_enable}': f'{fifo}{read_enable}',
            })
        for dram_ref, bank in module.dram_reads:
            port = dram_ref.dram_fifo_name(bank)
            fifo = util.get_port_name(dram_ref.var, bank)
            arg_dict.update({
                f'{port}_V{data_out}': f"{{1'b1, {fifo}_V_V{data_out}}}",
                f'{port}_V{not_empty}': f'{fifo}_V_V{not_empty}',
                f'{port}_V{read_enable}': f'{fifo}_V_V{read_enable}',
            })
        printer.module_instance(util.get_func_name(module_trait_id),
                                module.name, arg_dict)
        printer.println()
    printer.endmodule()
    printer.println('`default_nettype wire')

    # print FIFO modules, one definition per distinct (width, depth) pair
    for fifo in fifos:
        printer.fifo_module(*fifo)
def print_code(stencil, xo_file, platform=None, jobs=os.cpu_count()):
    """Generate hardware object file for the given Stencil.

    Working `vivado` and `vivado_hls` is required in the PATH.

    When `platform` is missing or does not exist and the `XDEVICE` environment
    variable is set, the platform is looked up under `/opt/xilinx/platforms`
    and then under `$XILINX_SDX/platforms`.

    Args:
      stencil: Stencil object to generate from.
      xo_file: file object to write to.
      platform: path to the SDAccel platform directory.
      jobs: maximum number of jobs running in parallel.

    Raises:
      ValueError: if no existing platform directory can be determined.
    """
    # Every port uses the same unsigned type of the burst width.
    haoda_type = 'uint%d' % stencil.burst_width

    def collect_ports(stmts):
        """Return one (port, bundle, type, buffer) tuple per DRAM bank."""
        return [(util.get_port_name(stmt.name, bank),
                 util.get_bundle_name(stmt.name, bank), haoda_type,
                 util.get_port_buf_name(stmt.name, bank))
                for stmt in stmts
                for bank in stmt.dram]

    outputs = collect_ports(stencil.output_stmts)
    inputs = collect_ports(stencil.input_stmts)
    # Bundle names are listed with the outputs first, then the inputs.
    m_axi_names = [bundle_name for _, bundle_name, _, _ in outputs + inputs]

    top_name = stencil.app_name + '_kernel'

    if 'XDEVICE' in os.environ:
        xdevice = os.environ['XDEVICE'].replace(':', '_').replace('.', '_')
        if platform is None or not os.path.exists(platform):
            platform = os.path.join('/opt/xilinx/platforms', xdevice)
        if platform is None or not os.path.exists(platform):
            if 'XILINX_SDX' in os.environ:
                platform = os.path.join(os.environ['XILINX_SDX'], 'platforms',
                                        xdevice)
    if platform is None or not os.path.exists(platform):
        raise ValueError('Cannot determine platform from environment.')
    device_info = backend.get_device_info(platform)

    with tempfile.TemporaryDirectory(prefix='sodac-xrtl-') as tmpdir:
        # Write the sources consumed by the synthesis jobs and the packager.
        dataflow_kernel = os.path.join(tmpdir, 'dataflow_kernel.cpp')
        with open(dataflow_kernel, 'w') as dataflow_kernel_obj:
            print_dataflow_hls_interface(util.Printer(dataflow_kernel_obj),
                                         top_name, inputs, outputs)

        kernel_xml = os.path.join(tmpdir, 'kernel.xml')
        with open(kernel_xml, 'w') as kernel_xml_obj:
            backend.print_kernel_xml(top_name, outputs + inputs,
                                     kernel_xml_obj)

        kernel_file = os.path.join(tmpdir, 'kernel.cpp')
        with open(kernel_file, 'w') as kernel_fileobj:
            hls_kernel.print_code(stencil, kernel_fileobj)

        # Synthesize every module trait and the top-level interface.
        super_source = stencil.dataflow_super_source
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=jobs) as executor:
            threads = [
                executor.submit(synthesis_module, tmpdir, [kernel_file],
                                util.get_func_name(module_id), device_info)
                for module_id in range(len(super_source.module_traits))
            ]
            threads.append(
                executor.submit(synthesis_module, tmpdir, [dataflow_kernel],
                                top_name, device_info))
            for future in concurrent.futures.as_completed(threads):
                returncode, stdout, stderr = future.result()
                log_func = _logger.error if returncode != 0 else _logger.debug
                if stdout:
                    log_func(stdout.decode())
                if stderr:
                    log_func(stderr.decode())
                if returncode != 0:
                    # Abort with the exit status of the first failed job.
                    util.pause_for_debugging()
                    sys.exit(returncode)

        hdl_dir = os.path.join(tmpdir, 'hdl')
        with open(os.path.join(hdl_dir, 'Dataflow.v'), mode='w') as dataflow_v:
            print_top_module(backend.VerilogPrinter(dataflow_v),
                             stencil.dataflow_super_source, inputs, outputs)

        util.pause_for_debugging()

        # Package the result and copy it to the caller's file object while
        # the temporary directory still exists.
        xo_filename = os.path.join(tmpdir, stencil.app_name + '.xo')
        with backend.PackageXo(xo_filename, top_name, kernel_xml, hdl_dir,
                               m_axi_names, [dataflow_kernel]) as proc:
            stdout, stderr = proc.communicate()
        log_func = _logger.error if proc.returncode != 0 else _logger.debug
        log_func(stdout.decode())
        log_func(stderr.decode())
        with open(xo_filename, mode='rb') as xo_fileobj:
            shutil.copyfileobj(xo_fileobj, xo_file)
def print_code(
    stencil: core.Stencil,
    xo_file: IO[bytes],
    device_info: Dict[str, str],
    jobs: Optional[int] = os.cpu_count(),
    rpt_file: Optional[str] = None,
    interface: str = 'm_axi',
) -> None:
    """Generate hardware object file for the given Stencil.

    Working `vivado` and `vivado_hls` is required in the PATH.

    Args:
      stencil: Stencil object to generate from.
      xo_file: file object to write to.
      device_info: dict of 'part_num' and 'clock_period'.
      jobs: maximum number of jobs running in parallel.
      rpt_file: path of the generated report; None disables report generation.
      interface: interface type, supported values are 'm_axi' and 'axis'.

    Raises:
      hls_report.BadReport: re-raised when the performance report of a module
        cannot be read.
    """
    iface_names = []  # for axis
    m_axi_names = []  # for m_axi
    inputs = []
    outputs = []
    for stmt in stencil.output_stmts:
        for bank in stmt.dram:
            port_name = util.get_port_name(stmt.name, bank)
            bundle_name = util.get_bundle_name(stmt.name, bank)
            iface_names.append(port_name)
            m_axi_names.append(bundle_name)
            outputs.append((port_name, bundle_name, stencil.burst_width,
                            util.get_port_buf_name(stmt.name, bank)))
    for stmt in stencil.input_stmts:
        for bank in stmt.dram:
            port_name = util.get_port_name(stmt.name, bank)
            bundle_name = util.get_bundle_name(stmt.name, bank)
            iface_names.append(port_name)
            m_axi_names.append(bundle_name)
            inputs.append((port_name, bundle_name, stencil.burst_width,
                           util.get_port_buf_name(stmt.name, bank)))

    top_name = stencil.kernel_name

    with tempfile.TemporaryDirectory(prefix='sodac-xrtl-') as tmpdir:
        kernel_xml = os.path.join(tmpdir, 'kernel.xml')
        with open(kernel_xml, 'w') as kernel_xml_obj:
            print_kernel_xml(top_name, inputs, outputs, kernel_xml_obj,
                             interface)

        kernel_file = os.path.join(tmpdir, 'kernel.cpp')
        with open(kernel_file, 'w') as kernel_fileobj:
            hls_kernel.print_code(stencil, kernel_fileobj)

        # Each entry is (code size, callable, *arguments); the code size of a
        # module is used below to decide the submission order.
        args = []
        for module_trait_id, module_trait in enumerate(stencil.module_traits):
            sio = io.StringIO()
            hls_kernel.print_module_definition(
                util.CppPrinter(sio),
                module_trait,
                module_trait_id,
                burst_width=stencil.burst_width)
            args.append(
                (len(sio.getvalue()), synthesis_module, tmpdir, [kernel_file],
                 util.get_func_name(module_trait_id), device_info))

        if interface == 'm_axi':
            sio = io.StringIO()
            print_dataflow_hls_interface(util.CppPrinter(sio), top_name,
                                         inputs, outputs)
            dataflow_kernel = os.path.join(tmpdir, 'dataflow_kernel.cpp')
            with open(dataflow_kernel, 'w') as dataflow_kernel_obj:
                dataflow_kernel_obj.write(sio.getvalue())
            args.append(
                (len(sio.getvalue()), synthesis_module, tmpdir,
                 [dataflow_kernel], top_name, device_info))

        # Submit the modules with the most generated code first.
        args.sort(key=lambda x: x[0], reverse=True)

        super_source = stencil.dataflow_super_source
        # NOTE(review): presumably hands this process's job slot back to a
        # job server while the worker threads run; confirm in `util`.
        job_server = util.release_job_slot()
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=jobs) as executor:
            threads = [executor.submit(*x[1:]) for x in args]
            for future in concurrent.futures.as_completed(threads):
                returncode, stdout, stderr = future.result()
                log_func = _logger.error if returncode != 0 else _logger.debug
                if stdout:
                    log_func(stdout.decode())
                if stderr:
                    log_func(stderr.decode())
                if returncode != 0:
                    # Abort with the exit status of the first failed job.
                    util.pause_for_debugging()
                    sys.exit(returncode)
        util.acquire_job_slot(job_server)

        # generate HLS report
        depths: Dict[int, int] = {}
        hls_resources = hls_report.HlsResources()
        if interface == 'm_axi':
            # Start from the top-level report minus the `Dataflow` report;
            # the per-module usage is added in the loop below.
            hls_resources = hls_report.resources(
                os.path.join(tmpdir, 'report', top_name + '_csynth.xml'))
            hls_resources -= hls_report.resources(
                os.path.join(tmpdir, 'report', 'Dataflow_csynth.xml'))
        _logger.info(hls_resources)
        for module_id, nodes in enumerate(
                super_source.module_trait_table.values()):
            module_name = util.get_func_name(module_id)
            report_file = os.path.join(tmpdir, 'report',
                                       module_name + '_csynth.xml')
            hls_resource = hls_report.resources(report_file)
            use_count = len(nodes)
            try:
                perf = hls_report.performance(report_file)
                _logger.info('%s, usage: %5d times, II: %3d, Depth: %3d',
                             hls_resource, use_count, perf.ii, perf.depth)
                depths[module_id] = perf.depth
            except hls_report.BadReport as e:
                # Logger.warn is a deprecated alias; use Logger.warning.
                _logger.warning('%s in %s report (%s)', e, module_name,
                                report_file)
                _logger.info('%s, usage: %5d times', hls_resource, use_count)
                # A bare raise re-raises the active exception unchanged.
                raise
            hls_resources += hls_resource * use_count
        _logger.info('total usage:')
        _logger.info(hls_resources)
        if rpt_file:
            rpt_json = collections.OrderedDict([('name', top_name)] +
                                               list(hls_resources))
            with open(rpt_file, mode='w') as rpt_fileobj:
                json.dump(rpt_json, rpt_fileobj, indent=2)

        # update the module pipeline depths
        stencil.dataflow_super_source.update_module_depths(depths)

        hdl_dir = os.path.join(tmpdir, 'hdl')
        # With the axis interface the Verilog module is the top itself.
        module_name = 'Dataflow'
        if interface == 'axis':
            module_name = top_name
        with open(os.path.join(hdl_dir, f'{module_name}.v'),
                  mode='w') as fileobj:
            print_top_module(
                backend.VerilogPrinter(fileobj),
                stencil.dataflow_super_source,
                inputs,
                outputs,
                module_name,
                interface,
            )

        util.pause_for_debugging()

        # Package the result and copy it to the caller's file object while
        # the temporary directory still exists.
        xo_filename = os.path.join(tmpdir, stencil.app_name + '.xo')
        kwargs = {}
        if interface == 'm_axi':
            kwargs['m_axi_names'] = m_axi_names
        elif interface == 'axis':
            kwargs['iface_names'] = iface_names
        with backend.PackageXo(
                xo_filename,
                top_name,
                kernel_xml,
                hdl_dir,
                **kwargs,
        ) as proc:
            stdout, stderr = proc.communicate()
        log_func = _logger.error if proc.returncode != 0 else _logger.debug
        log_func(stdout.decode())
        log_func(stderr.decode())
        with open(xo_filename, mode='rb') as xo_fileobj:
            shutil.copyfileobj(xo_fileobj, xo_file)
def print_top_module(printer, super_source, inputs, outputs):
    """Print the `Dataflow` Verilog module and the FIFO modules it uses.

    The module declares the control and FIFO ports, one FIFO per internal
    connection, and one instance per dataflow module.  The FIFO module
    definitions are printed after `endmodule`.

    Args:
      printer: Verilog printer to print to.
      super_source: node providing `tpo_node_gen()` and `module_table`.
      inputs: sequence of (port_name, _, haoda_type, _) of input ports.
      outputs: sequence of (port_name, _, haoda_type, _) of output ports.
    """
    println = printer.println
    println('`timescale 1 ns / 1 ps')

    # Unpack the port suffixes once instead of on every format call.
    data_in = FIFO_PORT_SUFFIXES['data_in']
    not_full = FIFO_PORT_SUFFIXES['not_full']
    write_enable = FIFO_PORT_SUFFIXES['write_enable']
    data_out = FIFO_PORT_SUFFIXES['data_out']
    not_empty = FIFO_PORT_SUFFIXES['not_empty']
    read_enable = FIFO_PORT_SUFFIXES['read_enable']
    not_block = FIFO_PORT_SUFFIXES['not_block']
    write_side = (data_in, not_full, write_enable)
    read_side = (data_out, not_empty, read_enable)

    # Module header with the complete port list.
    port_list = [
        'ap_clk', 'ap_rst', 'ap_start', 'ap_done', 'ap_continue', 'ap_idle',
        'ap_ready'
    ]
    for port_name, _, _, _ in outputs:
        port_list.extend(f'{port_name}_V_V{suffix}' for suffix in write_side)
    for port_name, _, _, _ in inputs:
        port_list.extend(f'{port_name}_V_V{suffix}' for suffix in read_side)
    printer.module('Dataflow', port_list)
    println()

    # Port directions.
    for arg in ('ap_clk', 'ap_rst', 'ap_start', 'ap_continue'):
        println(f'input {arg};')
    for arg in ('ap_done', 'ap_idle', 'ap_ready'):
        println(f'output {arg};')
    for port_name, _, haoda_type, _ in outputs:
        msb = util.get_width_in_bits(haoda_type) - 1
        println(f'output [{msb}:0] {port_name}_V_V{data_in};')
        println(f'input {port_name}_V_V{not_full};')
        println(f'output {port_name}_V_V{write_enable};')
    for port_name, _, haoda_type, _ in inputs:
        msb = util.get_width_in_bits(haoda_type) - 1
        println(f'input [{msb}:0] {port_name}_V_V{data_out};')
        println(f'input {port_name}_V_V{not_empty};')
        println(f'output {port_name}_V_V{read_enable};')
    println()

    # Net types of the control outputs and of the FIFO ports.
    println("reg ap_done = 1'b0;")
    println("reg ap_idle = 1'b1;")
    println("reg ap_ready = 1'b0;")
    for port_name, _, haoda_type, _ in outputs:
        msb = util.get_width_in_bits(haoda_type) - 1
        println(f'reg [{msb}:0] {port_name}{data_in};')
        println(f'wire {port_name}_V_V{write_enable};')
    for port_name, _, _, _ in inputs:
        println(f'wire {port_name}_V_V{read_enable};')

    # Reset signal shared by every FIFO and module instance.
    println('reg ap_rst_n_inv;')
    with printer.always('*'):
        println('ap_rst_n_inv = ap_rst;')
    println()

    with printer.always('posedge ap_clk'):
        with printer.if_('ap_rst'):
            println("ap_done <= 1'b0;")
            println("ap_idle <= 1'b1;")
            println("ap_ready <= 1'b0;")
            # else_() switches to the else branch of the enclosing if_ block.
            printer.else_()
            println('ap_idle <= ~ap_start;')

    # One "not blocked" signal per port, driven by not_full / not_empty.
    for port_name, _, _, _ in (*outputs, *inputs):
        println(f'reg {port_name}_V_V{not_block};')
    with printer.always('*'):
        for port_name, _, _, _ in outputs:
            println(f'{port_name}_V_V{not_block} = '
                    f'{port_name}_V_V{not_full};')
        for port_name, _, _, _ in inputs:
            println(f'{port_name}_V_V{not_block} = '
                    f'{port_name}_V_V{not_empty};')
    println()

    # Wires and instance of every internal FIFO.
    for module in super_source.tpo_node_gen():
        for fifo in module.fifos:
            name = fifo.c_expr
            msb = fifo.width_in_bits - 1
            println(f'wire [{msb}:0] {name}{data_in};')
            println(f'wire {name}{not_full};')
            println(f'wire {name}{write_enable};')
            println(f'wire [{msb}:0] {name}{data_out};')
            println(f'wire {name}{not_empty};')
            println(f'wire {name}{read_enable};')
            println()

            fifo_args = collections.OrderedDict([
                ('clk', 'ap_clk'),
                ('reset', 'ap_rst_n_inv'),
                ('if_read_ce', "1'b1"),
                ('if_write_ce', "1'b1"),
            ])
            for suffix in (*write_side, *read_side):
                fifo_args[f'if{suffix}'] = f'{name}{suffix}'
            printer.module_instance(
                f'fifo_w{fifo.width_in_bits}_d{fifo.depth + 2}_A', name,
                fifo_args)
            println()

    # Instance of every dataflow module.
    for module in super_source.tpo_node_gen():
        module_trait, module_trait_id = super_source.module_table[module]
        instance_args = collections.OrderedDict([
            ('ap_clk', 'ap_clk'),
            ('ap_rst', 'ap_rst_n_inv'),
            ('ap_start', "1'b1"),
        ])
        for dram_ref, bank in module.dram_writes:
            port = dram_ref.dram_fifo_name(bank)
            fifo = util.get_port_name(dram_ref.var, bank)
            for suffix in write_side:
                instance_args[f'{port}_V{suffix}'] = f'{fifo}_V_V{suffix}'
        for port, fifo in zip(module_trait.output_fifos, module.output_fifos):
            for suffix in write_side:
                instance_args[f'{port}_V{suffix}'] = f'{fifo}{suffix}'
        for port, fifo in zip(module_trait.input_fifos, module.input_fifos):
            # The data read by a module is prefixed with a constant 1 bit.
            instance_args[f'{port}_V{data_out}'] = (
                f"{{1'b1, {fifo}{data_out}}}")
            instance_args[f'{port}_V{not_empty}'] = f'{fifo}{not_empty}'
            instance_args[f'{port}_V{read_enable}'] = f'{fifo}{read_enable}'
        for dram_ref, bank in module.dram_reads:
            port = dram_ref.dram_fifo_name(bank)
            fifo = util.get_port_name(dram_ref.var, bank)
            instance_args[f'{port}_V{data_out}'] = (
                f"{{1'b1, {fifo}_V_V{data_out}}}")
            instance_args[f'{port}_V{not_empty}'] = f'{fifo}_V_V{not_empty}'
            instance_args[f'{port}_V{read_enable}'] = (
                f'{fifo}_V_V{read_enable}')
        printer.module_instance(util.get_func_name(module_trait_id),
                                module.name, instance_args)
        println()
    printer.endmodule()

    # One FIFO module definition per distinct (width, depth) pair.
    fifo_shapes = {(fifo.width_in_bits, fifo.depth + 2)
                   for module in super_source.tpo_node_gen()
                   for fifo in module.fifos}
    for fifo_shape in fifo_shapes:
        printer.fifo_module(*fifo_shape)