def _print_module_func_call(printer, node, module_trait_id, **kwargs):
    """Print a call statement to the HLS function implementing a module.

    The arguments are emitted in the order DRAM writes, output FIFOs, input
    FIFOs, DRAM reads. Each one is prefixed with an ``/* input*/`` or
    ``/*output*/`` marker comment and passed by address (``&``).

    Args:
        printer: printer providing ``print_func``.
        node: module node. ``dram_reads``/``dram_writes`` yield
            ``(dram_ref, bank)`` pairs; ``input_fifos``/``output_fifos`` yield
            FIFO names (strings, since they are concatenated below).
        module_trait_id: id used to derive the called function's name.
        **kwargs: accepted for call-site compatibility; unused.
    """
    func_name = util.get_func_name(module_trait_id)
    dram_reads = tuple(
        '/* input*/ &' + util.get_port_buf_name(dram_ref.var, bank)
        for dram_ref, bank in node.dram_reads)
    dram_writes = tuple(
        '/*output*/ &' + util.get_port_buf_name(dram_ref.var, bank)
        for dram_ref, bank in node.dram_writes)
    output_fifos = tuple('/*output*/ &' + fifo for fifo in node.output_fifos)
    input_fifos = tuple('/* input*/ &' + fifo for fifo in node.input_fifos)
    # Outputs precede inputs in the generated signature; keep the same order.
    params = dram_writes + output_fifos + input_fifos + dram_reads
    printer.print_func(func_name, params, suffix=';', align=0)
def _print_module_func_call(printer: util.CppPrinter, node: ir.Module,
                            module_trait_id: int, interface: str) -> None:
    """Print a call statement to the HLS function implementing a module.

    Nothing is printed unless ``interface`` is ``'m_axi'`` or ``'axis'``.

    Args:
        printer: C++ printer to emit to.
        node: module node. ``dram_reads``/``dram_writes`` yield
            ``(dram_ref, bank)`` pairs; ``input_fifos``/``output_fifos`` yield
            FIFO names.
        module_trait_id: id used to derive the called function's name.
        interface: ``'m_axi'`` names DRAM ports with
            ``util.get_port_buf_name``; ``'axis'`` uses
            ``util.get_port_name``.
    """
    if interface == 'm_axi':
        get_port_name = util.get_port_buf_name
    elif interface == 'axis':
        get_port_name = util.get_port_name
    else:
        # No call is printed for other interfaces. Return before building the
        # arguments: there is no port-naming helper for them, and touching one
        # would raise UnboundLocalError.
        return
    func_name = util.get_func_name(module_trait_id)
    dram_reads = tuple(' /* input*/ ' + get_port_name(dram_ref.var, bank)
                       for dram_ref, bank in node.dram_reads)
    dram_writes = tuple(' /*output*/ ' + get_port_name(dram_ref.var, bank)
                        for dram_ref, bank in node.dram_writes)
    output_fifos = tuple(' /*output*/ ' + fifo for fifo in node.output_fifos)
    input_fifos = tuple(' /* input*/ ' + fifo for fifo in node.input_fifos)
    # Outputs precede inputs in the generated signature; keep the same order.
    params = dram_writes + output_fifos + input_fifos + dram_reads
    printer.print_func(func_name, params, suffix=';', align=0)
def print_module_definition(
    printer: util.CppPrinter,
    module_trait: ir.ModuleTrait,
    module_trait_id: int,
    burst_width: int,
    interface: str = SUPPORTED_INTERFACES[0],
) -> None:
    """Print the HLS C++ function definition for one module trait.

    The emitted function takes ``fifo_stores + fifo_loads`` (as returned by
    ``_process_accesses``) as parameters. Its body is a labeled, pipelined
    loop that, while every input FIFO is non-empty, reads inputs, evaluates
    the trait's ``lets`` and ``exprs`` and writes the results. Modules that
    read or write DRAM repeat the body ``coalescing_factor`` times per
    iteration so that one ``burst_width``-bit DRAM word is (un)packed per
    iteration.

    Args:
        printer: C++ printer to emit to.
        module_trait: the module trait to generate code for.
        module_trait_id: id used to derive the function and loop-label names.
        burst_width: width in bits of one DRAM transfer.
        interface: for ``'m_axi'`` and ``'axis'`` the loop is driven by an
            ``enable`` flag and FIFO I/O goes through ReadData/WriteData; for
            any other value the loop is ``for (;;)`` and those ReadData/WriteData
            calls are not emitted.
    """
    func_name = util.get_func_name(module_trait_id)
    func_lower_name = util.get_module_name(module_trait_id)
    # Local annotations are not evaluated at runtime (PEP 526).
    delays: List[ir.DelayedRef] = []
    for let in module_trait.lets:
        let.visit(_get_delays, delays)
    for expr in module_trait.exprs:
        expr.visit(_get_delays, delays)
    _logger.debug('delays: %s', delays)

    # NOTE(review): the meaning of each returned item is inferred from how it
    # is used below; confirm against _process_accesses.
    (
        fifo_loads,
        fifo_stores,
        dram_read_map,
        dram_write_map,
        num_bank_map,
        dram_reads,
        coalescing_factor,
        ii,
    ) = _process_accesses(
        module_trait,
        burst_width,
        interface,
    )
    dram_rw_map = {**dram_read_map, **dram_write_map}

    # print function
    printer.print_func(f'void {func_name}', fifo_stores + fifo_loads, align=0)
    printer.do_scope(func_name)

    if interface == 'm_axi':
        printer.printlns(
            *(f'#pragma HLS data_pack variable = {dram_ref.dram_fifo_name(bank)}'
              for dram_ref, bank in module_trait.dram_writes +
              module_trait.dram_reads),
            *(f'#pragma HLS data_pack variable = {arg}'
              for arg in module_trait.output_fifos + module_trait.input_fifos),
        )

    # print inter-iteration declarations
    printer.printlns(x.c_buf_decl for x in delays)
    printer.printlns(x.c_ptr_decl for x in delays)

    # print loop
    printer.println(f'{func_lower_name}:', indent=0)
    if interface in {'m_axi', 'axis'}:
        # Loop until an input reports end of data (see `enable = enabled;`).
        printer.println('for (bool enable = true; enable;)')
    else:
        printer.println('for (;;)')
    printer.do_scope(f'for {func_lower_name}')

    printer.printlns(
        f'#pragma HLS pipeline II = {ii}',
        *(f'#pragma HLS dependence variable = {delay.buf_name} inter false'
          for delay in delays),
    )

    # print emptyness tests: proceed only when every input FIFO
    # (stencil loads and DRAM read FIFOs) has data
    printer.println('if (%s)' % (' && '.join(
        f'!{fifo}.empty()'
        for fifo in [_.ld_name for _ in module_trait.loads] +
        [_.dram_fifo_name(bank) for _ in dram_reads for bank in _.dram])))
    printer.do_scope('if not empty')

    # print intra-iteration declarations
    printer.printlns(
        f'{fifo_in.c_type} {fifo_in.ref_name};'
        for fifo_in in module_trait.loads)
    if interface in {'m_axi', 'axis'}:
        # One burst-wide buffer per DRAM bank accessed by this module.
        printer.printlns(
            f'ap_uint<{burst_width}> {dram.dram_buf_name(bank)};'
            for var, accesses in dram_rw_map.items()
            for dram in (next(iter(_.values())) for _ in accesses.values())
            for bank in dram.dram)

    if interface in {'m_axi', 'axis'}:
        # print enable conditions
        # DRAM-writing modules read their inputs inside the coalescing loop
        # below instead, so the loads are only read here when there is no
        # DRAM write.
        if not dram_write_map:
            printer.printlns(f'const bool {fifo_in.ref_name}_enable = '
                             f'ReadData({fifo_in.ref_name}, {fifo_in.ld_name});'
                             for fifo_in in module_trait.loads)
        printer.printlns(
            f'const bool {x.dram_buf_name(bank)}_enable = '
            f'ReadData({x.dram_buf_name(bank)}, {x.dram_fifo_name(bank)});'
            for x in dram_reads
            for bank in x.dram)
        if not dram_write_map:
            printer.println(
                'const bool enabled = %s;' % ' && '.join(
                    [f'{y.ref_name}_enable' for y in module_trait.loads] + [
                        f'{x.dram_buf_name(bank)}_enable'
                        for x in dram_reads
                        for bank in x.dram
                    ]))
            printer.println('enable = enabled;')

    # print delays (if any)
    printer.printlns(f'const {x.c_type} {x.c_buf_load};' for x in delays)

    # mutate dram ref for writes: pack coalescing_factor iterations' worth of
    # values into the DRAM buffers, then write each buffer once
    if dram_write_map:
        for coalescing_idx in range(coalescing_factor):
            if interface in {'m_axi', 'axis'}:
                for fifo_in in module_trait.loads:
                    # Only the last read's status becomes the enable flag.
                    if coalescing_idx == coalescing_factor - 1:
                        prefix = f'const bool {fifo_in.ref_name}_enable = '
                    else:
                        prefix = ''
                    printer.println(f'{prefix}ReadData({fifo_in.ref_name},'
                                    f' {fifo_in.ld_name});')
                if coalescing_idx == coalescing_factor - 1:
                    printer.printlns(
                        'const bool enabled = %s;' % ' && '.join(
                            [f'{x.ref_name}_enable' for x in module_trait.loads] + [
                                f'{x.dram_buf_name(bank)}_enable'
                                for x in dram_reads
                                for bank in x.dram
                            ]),
                        'enable = enabled;',
                    )
            for idx, let in enumerate(module_trait.lets):
                # The dict is built from the un-mutated `let` before rebinding.
                let = let.visit(
                    _mutate_dram_ref_for_writes,
                    dict(
                        coalescing_idx=coalescing_idx,
                        unroll_factor=len(
                            dram_write_map[let.name.var][let.name.dram]),
                        num_bank_map=num_bank_map,
                        interface=interface,
                    ))
                if interface in {'m_axi', 'axis'}:
                    printer.println(
                        f'{let.name} = '
                        f'Reinterpret<ap_uint<{let.expr.haoda_type.width_in_bits}>>'
                        f'({let.expr.c_expr});')
        if interface in {'m_axi', 'axis'}:
            printer.printlns(
                f'WriteData({dram.dram_fifo_name(bank)}, '
                f'{dram.dram_buf_name(bank)}, enabled);'
                for var in dram_write_map
                for dram in (
                    next(iter(_.values())) for _ in dram_write_map[var].values())
                for bank in dram.dram)
    else:
        printer.printlns(let.c_expr for let in module_trait.lets)

    # mutate dram ref for reads: unpack one DRAM buffer into
    # coalescing_factor successive FIFO writes
    if dram_read_map:
        for coalescing_idx in range(coalescing_factor):
            for idx, expr in enumerate(module_trait.exprs):
                c_expr = expr.visit(
                    _mutate_dram_ref_for_reads,
                    dict(
                        coalescing_idx=coalescing_idx,
                        unroll_factor=len(dram_read_map[expr.var][expr.dram]),
                        num_bank_map=num_bank_map,
                        interface=interface,
                        expr=expr,
                    )).c_expr
                if interface in {'m_axi', 'axis'}:
                    # Only the last write of the burst carries the real
                    # enable flag; earlier ones are unconditional.
                    if coalescing_idx < coalescing_factor - 1:
                        enabled = 'true'
                    else:
                        enabled = 'enabled'
                    printer.println(
                        f'WriteData({ir.FIFORef.ST_PREFIX}{idx}, {c_expr}, {enabled});')
    else:
        if interface in {'m_axi', 'axis'}:
            printer.printlns(f'WriteData({ir.FIFORef.ST_PREFIX}{idx}, '
                             f'{expr.c_type}({expr.c_expr}), enabled);'
                             for idx, expr in enumerate(module_trait.exprs))

    for delay in delays:
        printer.printlns(
            delay.c_buf_store,
            f'{delay.ptr} = {delay.c_next_ptr_expr};',
        )

    # Close the `if not empty`, loop and function scopes.
    printer.un_scope()
    printer.un_scope()
    printer.un_scope()
    printer.println()
    _logger.debug('printing: %s', module_trait)
def print_top_module(
    printer: backend.VerilogPrinter,
    super_source: dataflow.SuperSourceNode,
    inputs: Sequence[Tuple[str, str, int, str]],
    outputs: Sequence[Tuple[str, str, int, str]],
    module_name: str = 'Dataflow',
    interface: str = 'm_axi',
) -> None:
    """Generate the top-level Verilog module that wires the dataflow together.

    The generated module declares the external ports and a pipelined reset
    tree. It instantiates one FIFO per dataflow edge and one instance per
    dataflow module; for ``'axis'`` it also adds a depth-2 FIFO between each
    AXI-Stream port and the dataflow. After the top module, the definition
    of every distinct ``(width, depth)`` FIFO used is printed.

    Args:
        printer: printer to print to
        super_source: SuperSourceNode carrying the IR tree.
        inputs: sequence of (port_name, bundle_name, width, _) of input ports
        outputs: sequence of (port_name, bundle_name, width, _) of output ports
        module_name: name of the module
        interface: interface type, supported values are 'm_axi' and 'axis'
    """
    printer.printlns('`timescale 1 ns / 1 ps', '`default_nettype none')

    ports = *inputs, *outputs

    # unpack suffixes
    data_in = FIFO_PORT_SUFFIXES['data_in']
    not_full = FIFO_PORT_SUFFIXES['not_full']
    write_enable = FIFO_PORT_SUFFIXES['write_enable']
    data_out = FIFO_PORT_SUFFIXES['data_out']
    not_empty = FIFO_PORT_SUFFIXES['not_empty']
    read_enable = FIFO_PORT_SUFFIXES['read_enable']
    not_block = FIFO_PORT_SUFFIXES['not_block']
    data = AXIS_PORT_SUFFIXES['data']
    valid = AXIS_PORT_SUFFIXES['valid']
    ready = AXIS_PORT_SUFFIXES['ready']

    # Depth of the FIFOs between AXI-Stream ports and the dataflow (axis only).
    axis_fifo_depth = 2

    # prepare arguments
    input_args = ['ap_clk']
    output_args: List[str] = []
    if interface == 'm_axi':
        input_args += 'ap_rst', 'ap_start', 'ap_continue'
        output_args += 'ap_done', 'ap_idle', 'ap_ready'
    elif interface == 'axis':
        input_args.append('ap_rst_n')

    args = list(input_args + output_args)
    if interface == 'm_axi':
        for port_name, _, _, _ in outputs:
            args.append(f'{port_name}_V_V{data_in}')
            args.append(f'{port_name}_V_V{not_full}')
            args.append(f'{port_name}_V_V{write_enable}')
        for port_name, _, _, _ in inputs:
            args.append(f'{port_name}_V_V{data_out}')
            args.append(f'{port_name}_V_V{not_empty}')
            args.append(f'{port_name}_V_V{read_enable}')
    elif interface == 'axis':
        for port_name, _, _, _ in ports:
            args.extend(
                port_name + suffix for suffix in AXIS_PORT_SUFFIXES.values())

    # print module interface
    printer.module(module_name, args)
    printer.println()

    # print signals for modules
    printer.printlns(
        *(f'input wire {arg};' for arg in input_args),
        *(f'output wire {arg};' for arg in output_args),
    )
    for port_name, _, width, _ in outputs:
        if interface == 'm_axi':
            printer.printlns(
                f'output wire [{width - 1}:0] {port_name}_V_V{data_in};',
                f'input wire {port_name}_V_V{not_full};',
                f'output wire {port_name}_V_V{write_enable};',
            )
        elif interface == 'axis':
            # External AXI-Stream port plus internal FIFO-style wires.
            printer.printlns(
                f'output wire [{width - 1}:0] {port_name}{data};',
                f'output wire {port_name}{valid};',
                f'input wire {port_name}{ready};',
                f'wire [{width - 1}:0] {port_name}_V_V{data_in};',
                f'wire {port_name}_V_V{not_full};',
                f'wire {port_name}_V_V{write_enable};',
            )
    for port_name, _, width, _ in inputs:
        if interface == 'm_axi':
            printer.printlns(
                f'input wire [{width - 1}:0] {port_name}_V_V{data_out};',
                f'input wire {port_name}_V_V{not_empty};',
                f'output wire {port_name}_V_V{read_enable};',
            )
        elif interface == 'axis':
            printer.printlns(
                f'input wire [{width - 1}:0] {port_name}{data};',
                f'input wire {port_name}{valid};',
                f'output wire {port_name}{ready};',
                f'wire [{width - 1}:0] {port_name}_V_V{data_out};',
                f'wire {port_name}_V_V{not_empty};',
                f'wire {port_name}_V_V{read_enable};',
            )
    printer.println()

    # not used
    printer.printlns(
        "reg ap_done = 1'b0;",
        "reg ap_idle = 1'b1;",
        "reg ap_ready = 1'b0;",
    )
    if interface == 'axis':
        # AXI-Stream kernels have no ap_start port; run unconditionally.
        printer.println("wire ap_start = 1'b1;")

    # print signals for FIFOs
    if interface == 'm_axi':
        for port_name, _, width, _ in outputs:
            printer.printlns(
                f'reg [{width - 1}:0] {port_name}{data_in};',
                f'wire {port_name}_V_V{write_enable};',
            )
        for port_name, _, _, _ in inputs:
            printer.println(f'wire {port_name}_V_V{read_enable};')
        printer.println()

    # register reset signal through a chain of registers whose max_fanout
    # grows by 8x per stage
    ap_rst_reg_level = 8
    rst = '~ap_rst_n' if interface == 'axis' else 'ap_rst'
    printer.printlns(
        f'wire ap_rst_reg_0 = {rst};',
        *(f'(* shreg_extract = "no", max_fanout = {8 ** i} *) reg ap_rst_reg_{i};'
          for i in range(1, ap_rst_reg_level)),
        f'(* shreg_extract = "no" *) reg ap_rst_reg_{ap_rst_reg_level};',
        f'wire ap_rst_reg = ap_rst_reg_{ap_rst_reg_level};',
    )

    if ap_rst_reg_level > 0:
        with printer.always('posedge ap_clk'):
            printer.printlns(f'ap_rst_reg_{i + 1} <= ap_rst_reg_{i};'
                             for i in range(ap_rst_reg_level))

    with printer.always('posedge ap_clk'):
        with printer.if_('ap_rst_reg'):
            printer.printlns(
                "ap_done <= 1'b0;",
                "ap_idle <= 1'b1;",
                "ap_ready <= 1'b0;",
            )
            printer.else_()
            printer.println('ap_idle <= ~ap_start;')

    printer.println()

    if interface == 'm_axi':
        # used by cosim for deadlock detection
        printer.printlns(f'reg {port_name}_V_V{not_block};'
                         for port_name, _, _, _ in ports)
        with printer.always('*'):
            printer.printlns(
                *(f'{port_name}_V_V{not_block} = {port_name}_V_V{not_full};'
                  for port_name, _, _, _ in outputs),
                *(f'{port_name}_V_V{not_block} = {port_name}_V_V{not_empty};'
                  for port_name, _, _, _ in inputs),
            )
        printer.println()

    fifos: Set[Tuple[int, int]] = set()  # used for printing FIFO modules

    if interface == 'axis':
        # Input ports: AXI-Stream (data/valid/ready) -> FIFO -> dataflow.
        for port_name, _, width, _ in inputs:
            printer.module_instance(
                f'fifo_w{width}_d{axis_fifo_depth}_A',
                port_name + '_fifo',
                args={
                    'clk': 'ap_clk',
                    'reset': 'ap_rst_reg',
                    'if_read_ce': "1'b1",
                    'if_write_ce': "1'b1",
                    f'if{data_in}': f'{port_name}{data}',
                    f'if{not_full}': f'{port_name}{ready}',
                    f'if{write_enable}': f'{port_name}{valid}',
                    f'if{data_out}': f'{port_name}_V_V{data_out}',
                    f'if{not_empty}': f'{port_name}_V_V{not_empty}',
                    f'if{read_enable}': f'{port_name}_V_V{read_enable}',
                },
            )
            fifos.add((width, axis_fifo_depth))
            printer.println()
        # Output ports: dataflow -> FIFO -> AXI-Stream.
        for port_name, _, width, _ in outputs:
            printer.module_instance(
                f'fifo_w{width}_d{axis_fifo_depth}_A',
                port_name + '_fifo',
                args={
                    'clk': 'ap_clk',
                    'reset': 'ap_rst_reg',
                    'if_read_ce': "1'b1",
                    'if_write_ce': "1'b1",
                    f'if{data_in}': f'{port_name}_V_V{data_in}',
                    f'if{not_full}': f'{port_name}_V_V{not_full}',
                    f'if{write_enable}': f'{port_name}_V_V{write_enable}',
                    f'if{data_out}': f'{port_name}{data}',
                    f'if{not_empty}': f'{port_name}{valid}',
                    f'if{read_enable}': f'{port_name}{ready}',
                },
            )
            fifos.add((width, axis_fifo_depth))
            printer.println()

    # print FIFO instances
    for module in super_source.tpo_valid_node_gen():
        for fifo in module.fifos:
            name = fifo.c_expr
            msb = fifo.width_in_bits - 1
            # fifo.depth is the "extra" capacity of a FIFO; the base depth is 3:
            # 1 for registering the input, 1 for keeping II=1 when the FIFO is
            # (almost) full, and 1 for keeping II=1 when the FIFO is relaxed
            # from back pressure (necessary because the optimal FIFO depths
            # may require back pressure).
            depth = fifo.depth + 3
            printer.printlns(
                f'wire [{msb}:0] {name}{data_in};',
                f'wire {name}{not_full};',
                f'wire {name}{write_enable};',
                f'wire [{msb}:0] {name}{data_out};',
                f'wire {name}{not_empty};',
                f'wire {name}{read_enable};',
                '',
            )
            printer.module_instance(
                f'fifo_w{fifo.width_in_bits}_d{depth}_A',
                name,
                args={
                    'clk': 'ap_clk',
                    'reset': 'ap_rst_reg',
                    'if_read_ce': "1'b1",
                    'if_write_ce': "1'b1",
                    f'if{data_in}': f'{name}{data_in}',
                    f'if{not_full}': f'{name}{not_full}',
                    f'if{write_enable}': f'{name}{write_enable}',
                    f'if{data_out}': f'{name}{data_out}',
                    f'if{not_empty}': f'{name}{not_empty}',
                    f'if{read_enable}': f'{name}{read_enable}',
                },
            )
            fifos.add((fifo.width_in_bits, depth))
            printer.println()

    # print module instances
    for module in super_source.tpo_valid_node_gen():
        module_trait, module_trait_id = super_source.module_table[module]
        arg_dict = {
            'ap_clk': 'ap_clk',
            'ap_rst': 'ap_rst_reg',
            'ap_start': "1'b1",
        }
        for dram_ref, bank in module.dram_writes:
            port = dram_ref.dram_fifo_name(bank)
            fifo = util.get_port_name(dram_ref.var, bank)
            arg_dict.update({
                f'{port}_V{data_in}': f'{fifo}_V_V{data_in}',
                f'{port}_V{not_full}': f'{fifo}_V_V{not_full}',
                f'{port}_V{write_enable}': f'{fifo}_V_V{write_enable}',
            })
        for port, fifo in zip(module_trait.output_fifos, module.output_fifos):
            arg_dict.update({
                f'{port}_V{data_in}': f'{fifo}{data_in}',
                f'{port}_V{not_full}': f'{fifo}{not_full}',
                f'{port}_V{write_enable}': f'{fifo}{write_enable}',
            })
        # Read data is prefixed with a constant 1'b1 bit (Verilog
        # concatenation) before being passed to the module's input port.
        for port, fifo in zip(module_trait.input_fifos, module.input_fifos):
            arg_dict.update({
                f'{port}_V{data_out}': f"{{1'b1, {fifo}{data_out}}}",
                f'{port}_V{not_empty}': f'{fifo}{not_empty}',
                f'{port}_V{read_enable}': f'{fifo}{read_enable}',
            })
        for dram_ref, bank in module.dram_reads:
            port = dram_ref.dram_fifo_name(bank)
            fifo = util.get_port_name(dram_ref.var, bank)
            arg_dict.update({
                f'{port}_V{data_out}': f"{{1'b1, {fifo}_V_V{data_out}}}",
                f'{port}_V{not_empty}': f'{fifo}_V_V{not_empty}',
                f'{port}_V{read_enable}': f'{fifo}_V_V{read_enable}',
            })
        printer.module_instance(util.get_func_name(module_trait_id), module.name,
                                arg_dict)
        printer.println()

    printer.endmodule()
    printer.println('`default_nettype wire')

    # print FIFO modules
    for fifo in fifos:
        printer.fifo_module(*fifo)
def print_code(
    stencil: core.Stencil,
    xo_file: IO[bytes],
    device_info: Dict[str, str],
    jobs: Optional[int] = os.cpu_count(),
    rpt_file: Optional[str] = None,
    interface: str = 'm_axi',
) -> None:
    """Generate hardware object file for the given Stencil.

    Working `vivado` and `vivado_hls` is required in the PATH.

    Every module (and, for ``'m_axi'``, the dataflow wrapper) is synthesized
    in parallel, largest source first. Resource usage is then summarized from
    the HLS reports, the top-level Verilog is generated with the measured
    pipeline depths, and the result is packaged into an ``.xo`` file.

    Args:
        stencil: Stencil object to generate from.
        xo_file: file object to write to.
        device_info: dict of 'part_num' and 'clock_period'.
        jobs: maximum number of jobs running in parallel.
        rpt_file: path of the generated report; None disables report
            generation.
        interface: interface type, supported values are 'm_axi' and 'axis'.

    Raises:
        hls_report.BadReport: if a module's performance report is unusable.
        SystemExit: if any synthesis job fails.
    """
    iface_names = []  # for axis
    m_axi_names = []  # for m_axi
    inputs = []
    outputs = []
    for stmt in stencil.output_stmts:
        for bank in stmt.dram:
            port_name = util.get_port_name(stmt.name, bank)
            bundle_name = util.get_bundle_name(stmt.name, bank)
            iface_names.append(port_name)
            m_axi_names.append(bundle_name)
            outputs.append((port_name, bundle_name, stencil.burst_width,
                            util.get_port_buf_name(stmt.name, bank)))
    for stmt in stencil.input_stmts:
        for bank in stmt.dram:
            port_name = util.get_port_name(stmt.name, bank)
            bundle_name = util.get_bundle_name(stmt.name, bank)
            iface_names.append(port_name)
            m_axi_names.append(bundle_name)
            inputs.append((port_name, bundle_name, stencil.burst_width,
                           util.get_port_buf_name(stmt.name, bank)))

    top_name = stencil.kernel_name

    with tempfile.TemporaryDirectory(prefix='sodac-xrtl-') as tmpdir:
        kernel_xml = os.path.join(tmpdir, 'kernel.xml')
        with open(kernel_xml, 'w') as kernel_xml_obj:
            print_kernel_xml(top_name, inputs, outputs, kernel_xml_obj, interface)

        kernel_file = os.path.join(tmpdir, 'kernel.cpp')
        with open(kernel_file, 'w') as kernel_fileobj:
            hls_kernel.print_code(stencil, kernel_fileobj)

        # Each entry is (estimated size, callable, *callable args); the size
        # of the generated source is only used to schedule big jobs first.
        args = []
        for module_trait_id, module_trait in enumerate(stencil.module_traits):
            sio = io.StringIO()
            hls_kernel.print_module_definition(util.CppPrinter(sio),
                                               module_trait,
                                               module_trait_id,
                                               burst_width=stencil.burst_width)
            args.append(
                (len(sio.getvalue()), synthesis_module, tmpdir, [kernel_file],
                 util.get_func_name(module_trait_id), device_info))

        if interface == 'm_axi':
            sio = io.StringIO()
            print_dataflow_hls_interface(util.CppPrinter(sio), top_name, inputs,
                                         outputs)
            dataflow_kernel = os.path.join(tmpdir, 'dataflow_kernel.cpp')
            with open(dataflow_kernel, 'w') as dataflow_kernel_obj:
                dataflow_kernel_obj.write(sio.getvalue())
            args.append((len(sio.getvalue()), synthesis_module, tmpdir,
                         [dataflow_kernel], top_name, device_info))

        args.sort(key=lambda x: x[0], reverse=True)

        super_source = stencil.dataflow_super_source

        job_server = util.release_job_slot()
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=jobs) as executor:
            threads = [executor.submit(*x[1:]) for x in args]
            for future in concurrent.futures.as_completed(threads):
                returncode, stdout, stderr = future.result()
                log_func = _logger.error if returncode != 0 else _logger.debug
                if stdout:
                    log_func(stdout.decode())
                if stderr:
                    log_func(stderr.decode())
                if returncode != 0:
                    util.pause_for_debugging()
                    sys.exit(returncode)
        util.acquire_job_slot(job_server)

        # generate HLS report
        depths: Dict[int, int] = {}
        hls_resources = hls_report.HlsResources()
        if interface == 'm_axi':
            # Wrapper resources minus the inner Dataflow module, whose
            # modules are accounted for individually below.
            hls_resources = hls_report.resources(
                os.path.join(tmpdir, 'report', top_name + '_csynth.xml'))
            hls_resources -= hls_report.resources(
                os.path.join(tmpdir, 'report', 'Dataflow_csynth.xml'))
            _logger.info(hls_resources)
        for module_id, nodes in enumerate(
                super_source.module_trait_table.values()):
            module_name = util.get_func_name(module_id)
            report_file = os.path.join(tmpdir, 'report',
                                       module_name + '_csynth.xml')
            hls_resource = hls_report.resources(report_file)
            use_count = len(nodes)
            try:
                perf = hls_report.performance(report_file)
                _logger.info('%s, usage: %5d times, II: %3d, Depth: %3d',
                             hls_resource, use_count, perf.ii, perf.depth)
                depths[module_id] = perf.depth
            except hls_report.BadReport as e:
                # Logger.warn is a deprecated alias of Logger.warning.
                _logger.warning('%s in %s report (%s)', e, module_name,
                                report_file)
                _logger.info('%s, usage: %5d times', hls_resource, use_count)
                # Bare raise keeps the original traceback.
                raise
            hls_resources += hls_resource * use_count
        _logger.info('total usage:')
        _logger.info(hls_resources)
        if rpt_file:
            rpt_json = collections.OrderedDict([('name', top_name)] +
                                               list(hls_resources))
            with open(rpt_file, mode='w') as rpt_fileobj:
                json.dump(rpt_json, rpt_fileobj, indent=2)

        # update the module pipeline depths
        stencil.dataflow_super_source.update_module_depths(depths)

        hdl_dir = os.path.join(tmpdir, 'hdl')
        module_name = top_name if interface == 'axis' else 'Dataflow'
        with open(os.path.join(hdl_dir, f'{module_name}.v'), mode='w') as fileobj:
            print_top_module(
                backend.VerilogPrinter(fileobj),
                stencil.dataflow_super_source,
                inputs,
                outputs,
                module_name,
                interface,
            )

        util.pause_for_debugging()

        xo_filename = os.path.join(tmpdir, stencil.app_name + '.xo')
        kwargs = {}
        if interface == 'm_axi':
            kwargs['m_axi_names'] = m_axi_names
        elif interface == 'axis':
            kwargs['iface_names'] = iface_names
        with backend.PackageXo(
                xo_filename,
                top_name,
                kernel_xml,
                hdl_dir,
                **kwargs,
        ) as proc:
            stdout, stderr = proc.communicate()
        log_func = _logger.error if proc.returncode != 0 else _logger.debug
        log_func(stdout.decode())
        log_func(stderr.decode())
        with open(xo_filename, mode='rb') as xo_fileobj:
            shutil.copyfileobj(xo_fileobj, xo_file)
def _print_module_definition(printer, module_trait, module_trait_id, **kwargs):
    """Print the HLS C++ function definition for one module trait (legacy).

    FIFO parameters are ``hls::stream`` pointers. The body is a labeled
    ``<name>_epoch`` loop pipelined at ``ii``; it runs while every input FIFO
    is non-empty and stops once an input read reports the end of data.
    Modules that read DRAM (unpacking) or write DRAM (packing) repeat the body
    ``coalescing_factor`` times so that one burst-wide word is processed per
    iteration.

    Args:
        printer: printer providing println/print_func/do_scope/un_scope.
        module_trait: the module trait to generate code for.
        module_trait_id: id used to derive function and loop-label names.
        **kwargs: must contain ``burst_width`` when the module accesses DRAM;
            it is popped from ``kwargs`` in that case.

    Raises:
        AssertionError: on inconsistent bank counts or unexpected access
            patterns, or if a module both reads and writes DRAM.
        util.InternalError: if one cycle would need more than one burst.
    """
    println = printer.println
    do_scope = printer.do_scope
    un_scope = printer.un_scope
    func_name = util.get_func_name(module_trait_id)
    func_lower_name = util.get_module_name(module_trait_id)
    ii = 1

    def get_delays(obj, delays):
        # Visitor callback: collect every DelayedRef node into `delays`.
        if isinstance(obj, ir.DelayedRef):
            delays.append(obj)
        return obj

    delays = []
    for let in module_trait.lets:
        let.visit(get_delays, delays)
    for expr in module_trait.exprs:
        expr.visit(get_delays, delays)
    _logger.debug('delays: %s', delays)

    fifo_loads = tuple(
        '/* input*/ hls::stream<Data<{} > >* {}'.format(_.c_type, _.ld_name)
        for _ in module_trait.loads)
    fifo_stores = tuple('/*output*/ hls::stream<Data<{} > >* {}{}'.format(
        expr.c_type, ir.FIFORef.ST_PREFIX, idx)
                        for idx, expr in enumerate(module_trait.exprs))

    # look for DRAM access
    reads_in_lets = tuple(_.expr for _ in module_trait.lets)
    writes_in_lets = tuple(_.name
                           for _ in module_trait.lets
                           if not isinstance(_.name, str))
    reads_in_exprs = module_trait.exprs
    dram_reads = visitor.get_dram_refs(reads_in_lets + reads_in_exprs)
    dram_writes = visitor.get_dram_refs(writes_in_lets)
    # var -> dram -> offset -> DRAMRef (after the regrouping below)
    dram_read_map = collections.OrderedDict()
    dram_write_map = collections.OrderedDict()
    all_dram_reads = ()
    num_bank_map = {}
    if dram_reads:  # this is an unpacking module
        assert not dram_writes, 'cannot read and write DRAM in the same module'
        for dram_read in dram_reads:
            dram_read_map.setdefault(dram_read.var,
                                     collections.OrderedDict()).setdefault(
                                         dram_read.dram, []).append(dram_read)
        _logger.debug('dram read map: %s', dram_read_map)
        burst_width = kwargs.pop('burst_width')
        for var in dram_read_map:
            for dram in dram_read_map[var]:
                # number of elements per cycle
                batch_size = len(dram_read_map[var][dram])
                dram_read_map[var][dram] = collections.OrderedDict(
                    (_.offset, _) for _ in dram_read_map[var][dram])
                dram_reads = dram_read_map[var][dram]
                num_banks = len(next(iter(dram_reads.values())).dram)
                if var in num_bank_map:
                    assert num_bank_map[
                        var] == num_banks, 'inconsistent num banks'
                else:
                    num_bank_map[var] = num_banks
                _logger.debug('dram reads: %s', dram_reads)
                # Offsets must be exactly 0..batch_size-1.
                assert tuple(sorted(dram_reads.keys())) == tuple(range(batch_size)), \
                    'unexpected DRAM accesses pattern %s' % dram_reads
                batch_width = sum(
                    util.get_width_in_bits(_.haoda_type)
                    for _ in dram_reads.values())
                del dram_reads
                if burst_width * num_banks >= batch_width:
                    assert burst_width * num_banks % batch_width == 0, \
                        'cannot process such a burst'
                    # a single burst consumed in multiple cycles
                    coalescing_factor = burst_width * num_banks // batch_width
                    ii = coalescing_factor
                else:
                    assert batch_width * num_banks % burst_width == 0, \
                        'cannot process such a burst'
                    # multiple bursts consumed in a single cycle
                    # reassemble_factor = batch_width // (burst_width * num_banks)
                    raise util.InternalError('cannot process such a burst yet')
            # One representative DRAMRef per dram of this var.
            dram_reads = tuple(
                next(iter(_.values())) for _ in dram_read_map[var].values())
            all_dram_reads += dram_reads
            fifo_loads += tuple(
                '/* input*/ hls::stream<Data<ap_uint<{burst_width} > > >* '
                '{bank_name}'.format(burst_width=burst_width,
                                     bank_name=_.dram_fifo_name(bank))
                for _ in dram_reads
                for bank in _.dram)
    elif dram_writes:  # this is a packing module
        for dram_write in dram_writes:
            dram_write_map.setdefault(dram_write.var,
                                      collections.OrderedDict()).setdefault(
                                          dram_write.dram, []).append(dram_write)
        _logger.debug('dram write map: %s', dram_write_map)
        burst_width = kwargs.pop('burst_width')
        for var in dram_write_map:
            for dram in dram_write_map[var]:
                # number of elements per cycle
                batch_size = len(dram_write_map[var][dram])
                dram_write_map[var][dram] = collections.OrderedDict(
                    (_.offset, _) for _ in dram_write_map[var][dram])
                dram_writes = dram_write_map[var][dram]
                num_banks = len(next(iter(dram_writes.values())).dram)
                if var in num_bank_map:
                    assert num_bank_map[
                        var] == num_banks, 'inconsistent num banks'
                else:
                    num_bank_map[var] = num_banks
                _logger.debug('dram writes: %s', dram_writes)
                assert tuple(sorted(dram_writes.keys())) == tuple(range(batch_size)), \
                    'unexpected DRAM accesses pattern %s' % dram_writes
                batch_width = sum(
                    util.get_width_in_bits(_.haoda_type)
                    for _ in dram_writes.values())
                del dram_writes
                if burst_width * num_banks >= batch_width:
                    assert burst_width * num_banks % batch_width == 0, \
                        'cannot process such a burst'
                    # a single burst consumed in multiple cycles
                    coalescing_factor = burst_width * num_banks // batch_width
                    ii = coalescing_factor
                else:
                    assert batch_width * num_banks % burst_width == 0, \
                        'cannot process such a burst'
                    # multiple bursts consumed in a single cycle
                    # reassemble_factor = batch_width // (burst_width * num_banks)
                    raise util.InternalError('cannot process such a burst yet')
            dram_writes = tuple(
                next(iter(_.values())) for _ in dram_write_map[var].values())
            fifo_stores += tuple(
                '/*output*/ hls::stream<Data<ap_uint<{burst_width} > > >* '
                '{bank_name}'.format(burst_width=burst_width,
                                     bank_name=_.dram_fifo_name(bank))
                for _ in dram_writes
                for bank in _.dram)

    # print function
    printer.print_func('void {func_name}'.format(**locals()),
                       fifo_stores + fifo_loads,
                       align=0)
    do_scope(func_name)

    for dram_ref, bank in module_trait.dram_writes:
        println(
            '#pragma HLS data_pack variable = {}'.format(
                dram_ref.dram_fifo_name(bank)), 0)
    for arg in module_trait.output_fifos:
        println('#pragma HLS data_pack variable = %s' % arg, 0)
    for arg in module_trait.input_fifos:
        println('#pragma HLS data_pack variable = %s' % arg, 0)
    for dram_ref, bank in module_trait.dram_reads:
        println(
            '#pragma HLS data_pack variable = {}'.format(
                dram_ref.dram_fifo_name(bank)), 0)

    # print inter-iteration declarations
    for delay in delays:
        println(delay.c_buf_decl)
        println(delay.c_ptr_decl)

    # print loop
    println('{}_epoch:'.format(func_lower_name), indent=0)
    println('for (bool enable = true; enable;)')
    do_scope('for {}_epoch'.format(func_lower_name))
    println('#pragma HLS pipeline II=%d' % ii, 0)
    for delay in delays:
        println(
            '#pragma HLS dependence variable=%s inter false' % delay.buf_name,
            0)

    # print emptyness tests
    println('if (%s)' %
            (' && '.join('!{fifo}->empty()'.format(fifo=fifo)
                         for fifo in tuple(_.ld_name
                                           for _ in module_trait.loads) +
                         tuple(_.dram_fifo_name(bank)
                               for _ in all_dram_reads
                               for bank in _.dram))))

    do_scope('if not empty')

    # print intra-iteration declarations
    for fifo_in in module_trait.loads:
        println('{fifo_in.c_type} {fifo_in.ref_name};'.format(**locals()))
    # One burst-wide buffer per DRAM bank accessed by this module.
    for var in dram_read_map:
        for dram in (next(iter(_.values())) for _ in dram_read_map[var].values()):
            for bank in dram.dram:
                println('ap_uint<{}> {};'.format(burst_width,
                                                 dram.dram_buf_name(bank)))
    for var in dram_write_map:
        for dram in (next(iter(_.values())) for _ in dram_write_map[var].values()):
            for bank in dram.dram:
                println('ap_uint<{}> {};'.format(burst_width,
                                                 dram.dram_buf_name(bank)))

    # print enable conditions
    # Packing modules read their loads inside the coalescing loop below.
    if not dram_write_map:
        for fifo_in in module_trait.loads:
            println('const bool {fifo_in.ref_name}_enable = '
                    'ReadData(&{fifo_in.ref_name}, {fifo_in.ld_name});'.format(
                        **locals()))
    for dram in all_dram_reads:
        for bank in dram.dram:
            println('const bool {dram_buf_name}_enable = '
                    'ReadData(&{dram_buf_name}, {dram_fifo_name});'.format(
                        dram_buf_name=dram.dram_buf_name(bank),
                        dram_fifo_name=dram.dram_fifo_name(bank)))
    if not dram_write_map:
        println('const bool enabled = %s;' % (' && '.join(
            tuple('{_.ref_name}_enable'.format(_=_)
                  for _ in module_trait.loads) +
            tuple('{}_enable'.format(_.dram_buf_name(bank))
                  for _ in all_dram_reads
                  for bank in _.dram))))
        println('enable = enabled;')

    # print delays (if any)
    for delay in delays:
        println('const {} {};'.format(delay.c_type, delay.c_buf_load))

    # print lets

    def mutate_dram_ref_for_writes(obj, kwargs):
        # Visitor callback: replace a DRAMRef with the bit slice of its bank
        # buffer that holds element `coalescing_idx * unroll_factor + offset`.
        # Elements are distributed round-robin across banks.
        # NOTE(review): the entries are pop()'d, so this assumes at most one
        # DRAMRef per visited let (true when only `let.name` is a DRAMRef).
        if isinstance(obj, ir.DRAMRef):
            coalescing_idx = kwargs.pop('coalescing_idx')
            unroll_factor = kwargs.pop('unroll_factor')
            type_width = util.get_width_in_bits(obj.haoda_type)
            elem_idx = coalescing_idx * unroll_factor + obj.offset
            num_banks = num_bank_map[obj.var]
            bank = obj.dram[elem_idx % num_banks]
            lsb = (elem_idx // num_banks) * type_width
            msb = lsb + type_width - 1
            return ir.Var(name='{}({msb}, {lsb})'.format(
                obj.dram_buf_name(bank), msb=msb, lsb=lsb),
                          idx=())
        return obj

    # mutate dram ref for writes
    if dram_write_map:
        for coalescing_idx in range(coalescing_factor):
            for fifo_in in module_trait.loads:
                # Only the last read's status becomes the enable flag.
                if coalescing_idx == coalescing_factor - 1:
                    prefix = 'const bool {fifo_in.ref_name}_enable = '.format(
                        fifo_in=fifo_in)
                else:
                    prefix = ''
                println('{prefix}ReadData(&{fifo_in.ref_name},'
                        ' {fifo_in.ld_name});'.format(fifo_in=fifo_in,
                                                      prefix=prefix))
            if coalescing_idx == coalescing_factor - 1:
                # NOTE(review): `dram_reads` here is whatever get_dram_refs
                # returned (falsy in a packing module), so this presumably
                # contributes no DRAM terms.
                println('const bool enabled = %s;' % (' && '.join(
                    tuple('{_.ref_name}_enable'.format(_=_)
                          for _ in module_trait.loads) +
                    tuple('{}_enable'.format(_.dram_buf_name(bank))
                          for _ in dram_reads
                          for bank in _.dram))))
                println('enable = enabled;')
            for idx, let in enumerate(module_trait.lets):
                let = let.visit(
                    mutate_dram_ref_for_writes, {
                        'coalescing_idx': coalescing_idx,
                        'unroll_factor':
                            len(dram_write_map[let.name.var][let.name.dram])
                    })
                println('{} = Reinterpret<ap_uint<{width} > >({});'.format(
                    let.name,
                    let.expr.c_expr,
                    width=util.get_width_in_bits(let.expr.haoda_type)))
        # After all coalescing_factor slices are filled, write each buffer.
        for var in dram_write_map:
            for dram in (next(iter(_.values()))
                         for _ in dram_write_map[var].values()):
                for bank in dram.dram:
                    println('WriteData({}, {}, enabled);'.format(
                        dram.dram_fifo_name(bank), dram.dram_buf_name(bank)))
    else:
        for let in module_trait.lets:
            println(let.c_expr)

    def mutate_dram_ref_for_reads(obj, kwargs):
        # Visitor callback: replace a DRAMRef with a reinterpretation of the
        # bit slice of its bank buffer holding the current element.
        # NOTE(review): `bank` is taken from `expr`, the enclosing function's
        # loop variable bound at call time (late-binding closure), not from
        # `obj`. This is equivalent when the visited expr is the DRAMRef itself.
        if isinstance(obj, ir.DRAMRef):
            coalescing_idx = kwargs.pop('coalescing_idx')
            unroll_factor = kwargs.pop('unroll_factor')
            type_width = util.get_width_in_bits(obj.haoda_type)
            elem_idx = coalescing_idx * unroll_factor + obj.offset
            num_banks = num_bank_map[obj.var]
            bank = expr.dram[elem_idx % num_banks]
            lsb = (elem_idx // num_banks) * type_width
            msb = lsb + type_width - 1
            return ir.Var(
                name='Reinterpret<{c_type}>(static_cast<ap_uint<{width} > >('
                '{dram_buf_name}({msb}, {lsb})))'.format(
                    c_type=obj.c_type,
                    dram_buf_name=obj.dram_buf_name(bank),
                    msb=msb,
                    lsb=lsb,
                    width=msb - lsb + 1),
                idx=())
        return obj

    # mutate dram ref for reads
    if dram_read_map:
        for coalescing_idx in range(coalescing_factor):
            for idx, expr in enumerate(module_trait.exprs):
                # Only the last write of the burst carries the real enable flag.
                println('WriteData({}{}, {}, {});'.format(
                    ir.FIFORef.ST_PREFIX, idx,
                    expr.visit(
                        mutate_dram_ref_for_reads, {
                            'coalescing_idx': coalescing_idx,
                            'unroll_factor':
                                len(dram_read_map[expr.var][expr.dram])
                        }).c_expr,
                    'true' if coalescing_idx < coalescing_factor - 1 else
                    'enabled'))
    else:
        for idx, expr in enumerate(module_trait.exprs):
            println('WriteData({}{}, {}({}), enabled);'.format(
                ir.FIFORef.ST_PREFIX, idx, expr.c_type, expr.c_expr))

    for delay in delays:
        println(delay.c_buf_store)
        println('{} = {};'.format(delay.ptr, delay.c_next_ptr_expr))

    # Close the `if not empty`, loop and function scopes.
    un_scope()
    un_scope()
    un_scope()
    _logger.debug('printing: %s', module_trait)
def print_code(stencil, xo_file, platform=None, jobs=os.cpu_count()):
  """Generate hardware object file for the given Stencil.

  Working `vivado` and `vivado_hls` is required in the PATH.

  Args:
    stencil: Stencil object to generate from.
    xo_file: file object to write to; the packaged .xo file is copied into it
      as bytes, so it must be opened in binary mode.
    platform: path to the SDAccel platform directory. If it is None or does
      not exist and XDEVICE is set, it is looked up under
      /opt/xilinx/platforms and then under $XILINX_SDX/platforms.
    jobs: maximum number of jobs running in parallel.

  Raises:
    ValueError: if no existing platform directory can be determined.
    SystemExit: if synthesizing any module, or packaging the .xo file, exits
      with a non-zero return code.
  """
  burst_type = 'uint%d' % stencil.burst_width

  def get_ports(stmts):
    """Return one (port_name, bundle_name, haoda_type, port_buf_name) per bank."""
    return [(util.get_port_name(stmt.name, bank),
             util.get_bundle_name(stmt.name, bank), burst_type,
             util.get_port_buf_name(stmt.name, bank))
            for stmt in stmts
            for bank in stmt.dram]

  outputs = get_ports(stencil.output_stmts)
  inputs = get_ports(stencil.input_stmts)
  # One m_axi bundle per DRAM bank: all outputs first, then all inputs.
  m_axi_names = [bundle_name for _, bundle_name, _, _ in outputs + inputs]

  top_name = stencil.app_name + '_kernel'

  def is_valid_platform(path):
    return path is not None and os.path.exists(path)

  if 'XDEVICE' in os.environ:
    xdevice = os.environ['XDEVICE'].replace(':', '_').replace('.', '_')
    if not is_valid_platform(platform):
      platform = os.path.join('/opt/xilinx/platforms', xdevice)
    if not is_valid_platform(platform) and 'XILINX_SDX' in os.environ:
      platform = os.path.join(os.environ['XILINX_SDX'], 'platforms', xdevice)
  if not is_valid_platform(platform):
    raise ValueError('Cannot determine platform from environment.')
  device_info = backend.get_device_info(platform)

  with tempfile.TemporaryDirectory(prefix='sodac-xrtl-') as tmpdir:
    dataflow_kernel = os.path.join(tmpdir, 'dataflow_kernel.cpp')
    with open(dataflow_kernel, 'w') as dataflow_kernel_obj:
      print_dataflow_hls_interface(util.Printer(dataflow_kernel_obj), top_name,
                                   inputs, outputs)

    kernel_xml = os.path.join(tmpdir, 'kernel.xml')
    with open(kernel_xml, 'w') as kernel_xml_obj:
      backend.print_kernel_xml(top_name, outputs + inputs, kernel_xml_obj)

    kernel_file = os.path.join(tmpdir, 'kernel.cpp')
    with open(kernel_file, 'w') as kernel_fileobj:
      hls_kernel.print_code(stencil, kernel_fileobj)

    super_source = stencil.dataflow_super_source
    with concurrent.futures.ThreadPoolExecutor(max_workers=jobs) as executor:
      # One synthesis job per module trait, plus one for the dataflow top.
      futures = [
          executor.submit(synthesis_module, tmpdir, [kernel_file],
                          util.get_func_name(module_id), device_info)
          for module_id in range(len(super_source.module_traits))
      ]
      futures.append(
          executor.submit(synthesis_module, tmpdir, [dataflow_kernel],
                          top_name, device_info))
      for future in concurrent.futures.as_completed(futures):
        returncode, stdout, stderr = future.result()
        log_func = _logger.error if returncode != 0 else _logger.debug
        if stdout:
          log_func(stdout.decode())
        if stderr:
          log_func(stderr.decode())
        if returncode != 0:
          # Leaving the executor's `with` block waits for every outstanding
          # job; cancel those not yet started so a failure exits promptly.
          for pending in futures:
            pending.cancel()
          util.pause_for_debugging()
          sys.exit(returncode)

    hdl_dir = os.path.join(tmpdir, 'hdl')
    with open(os.path.join(hdl_dir, 'Dataflow.v'), mode='w') as dataflow_v:
      print_top_module(backend.VerilogPrinter(dataflow_v),
                       stencil.dataflow_super_source, inputs, outputs)

    util.pause_for_debugging()

    xo_filename = os.path.join(tmpdir, stencil.app_name + '.xo')
    with backend.PackageXo(xo_filename, top_name, kernel_xml, hdl_dir,
                           m_axi_names, [dataflow_kernel]) as proc:
      stdout, stderr = proc.communicate()
    log_func = _logger.error if proc.returncode != 0 else _logger.debug
    if stdout:
      log_func(stdout.decode())
    if stderr:
      log_func(stderr.decode())
    if proc.returncode != 0:
      # Packaging failed: don't try to copy a missing or partial .xo file;
      # fail the same way a synthesis error does.
      util.pause_for_debugging()
      sys.exit(proc.returncode)
    with open(xo_filename, mode='rb') as xo_fileobj:
      shutil.copyfileobj(xo_fileobj, xo_file)
def print_top_module(printer, super_source, inputs, outputs):
  """Print the Verilog `Dataflow` top module and the FIFO modules it uses.

  The emitted module has the control ports ap_clk, ap_rst, ap_start, ap_done,
  ap_continue, ap_idle and ap_ready. It also has, for every output stream,
  write-side FIFO ports (data_in / not_full / write_enable) and, for every
  input stream, read-side FIFO ports (data_out / not_empty / read_enable).

  It instantiates:
    * one FIFO for each entry of `module.fifos` of every module; and
    * one instance per module yielded by `super_source.tpo_node_gen()`.

  After `endmodule`, it emits one FIFO module definition per distinct
  (width, depth) pair that was instantiated.

  Args:
    printer: Verilog printer. This function uses its `println`, `module`,
      `always`, `if_`, `else_`, `module_instance`, `endmodule` and
      `fifo_module` methods.
    super_source: dataflow super source. It is iterated with
      `tpo_node_gen()`, and `module_table` maps each module to a
      (module_trait, module_trait_id) pair.
    inputs: 4-tuples (port_name, bundle_name, haoda_type, port_buf_name) for
      input streams, as built by `print_code`. Only port_name and
      haoda_type are used here.
    outputs: same tuple layout as `inputs`, for output streams.
  """
  println = printer.println
  println('`timescale 1 ns / 1 ps')
  # Top-level port list: control signals first, then the write-side ports of
  # each output stream, then the read-side ports of each input stream.
  args = [
      'ap_clk', 'ap_rst', 'ap_start', 'ap_done', 'ap_continue', 'ap_idle',
      'ap_ready'
  ]
  for port_name, _, _, _ in outputs:
    args.append('{}_V_V{data_in}'.format(port_name, **FIFO_PORT_SUFFIXES))
    args.append('{}_V_V{not_full}'.format(port_name, **FIFO_PORT_SUFFIXES))
    args.append('{}_V_V{write_enable}'.format(port_name, **FIFO_PORT_SUFFIXES))
  for port_name, _, _, _ in inputs:
    args.append('{}_V_V{data_out}'.format(port_name, **FIFO_PORT_SUFFIXES))
    args.append('{}_V_V{not_empty}'.format(port_name, **FIFO_PORT_SUFFIXES))
    args.append('{}_V_V{read_enable}'.format(port_name, **FIFO_PORT_SUFFIXES))
  printer.module('Dataflow', args)
  println()

  input_args = 'ap_clk', 'ap_rst', 'ap_start', 'ap_continue'
  output_args = 'ap_done', 'ap_idle', 'ap_ready'

  # Port direction declarations (non-ANSI style: listed above, typed here).
  for arg in input_args:
    println('input %s;' % arg)
  for arg in output_args:
    println('output %s;' % arg)
  for port_name, _, haoda_type, _ in outputs:
    kwargs = dict(port_name=port_name, **FIFO_PORT_SUFFIXES)
    # The data bus width comes from the stream's haoda type.
    println('output [{}:0] {port_name}_V_V{data_in};'.format(
        util.get_width_in_bits(haoda_type) - 1, **kwargs))
    println('input {port_name}_V_V{not_full};'.format(**kwargs))
    println('output {port_name}_V_V{write_enable};'.format(**kwargs))
  for port_name, _, haoda_type, _ in inputs:
    kwargs = dict(port_name=port_name, **FIFO_PORT_SUFFIXES)
    println('input [{}:0] {port_name}_V_V{data_out};'.format(
        util.get_width_in_bits(haoda_type) - 1, **kwargs))
    println('input {port_name}_V_V{not_empty};'.format(**kwargs))
    println('output {port_name}_V_V{read_enable};'.format(**kwargs))
  println()

  # Control outputs are registers with initial values.
  println("reg ap_done = 1'b0;")
  println("reg ap_idle = 1'b1;")
  println("reg ap_ready = 1'b0;")

  for port_name, _, haoda_type, _ in outputs:
    kwargs = dict(port_name=port_name, **FIFO_PORT_SUFFIXES)
    # NOTE(review): this reg is named `<port_name><data_in>` without the
    # `_V_V` infix used by the output port declared above, and nothing else
    # in this function references it. Confirm whether it is intended.
    println('reg [{}:0] {port_name}{data_in};'.format(
        util.get_width_in_bits(haoda_type) - 1, **kwargs))
    println('wire {port_name}_V_V{write_enable};'.format(**kwargs))
  for port_name, _, haoda_type, _ in inputs:
    println('wire {}_V_V{read_enable};'.format(port_name, **FIFO_PORT_SUFFIXES))
  # Despite its name, ap_rst_n_inv is a plain (non-inverted) copy of ap_rst.
  # It drives the `reset` of every FIFO instance and the `ap_rst` of every
  # module instance emitted below.
  println('reg ap_rst_n_inv;')
  with printer.always('*'):
    println('ap_rst_n_inv = ap_rst;')
  println()

  # On reset: ap_done/ap_ready go low and ap_idle goes high. Otherwise
  # ap_idle follows ~ap_start. ap_done and ap_ready are never driven high
  # in this function.
  with printer.always('posedge ap_clk'):
    with printer.if_('ap_rst'):
      println("ap_done <= 1'b0;")
      println("ap_idle <= 1'b1;")
      println("ap_ready <= 1'b0;")
      printer.else_()
      println('ap_idle <= ~ap_start;')

  # *_not_block regs mirror not_full (outputs) / not_empty (inputs).
  # NOTE(review): they are not referenced elsewhere in this function.
  for port_name, _, _, _ in outputs:
    println('reg {}_V_V{not_block};'.format(port_name, **FIFO_PORT_SUFFIXES))
  for port_name, _, _, _ in inputs:
    println('reg {}_V_V{not_block};'.format(port_name, **FIFO_PORT_SUFFIXES))

  with printer.always('*'):
    for port_name, _, _, _ in outputs:
      println('{port_name}_V_V{not_block} = {port_name}_V_V{not_full};'.format(
          port_name=port_name, **FIFO_PORT_SUFFIXES))
    for port_name, _, _, _ in inputs:
      println('{port_name}_V_V{not_block} = {port_name}_V_V{not_empty};'.format(
          port_name=port_name, **FIFO_PORT_SUFFIXES))
  println()

  # Inter-module FIFOs: declare the six interface wires of each FIFO, then
  # instantiate it. Read and write clock enables are tied high.
  for module in super_source.tpo_node_gen():
    for fifo in module.fifos:
      kwargs = {
          'name': fifo.c_expr,
          'msb': fifo.width_in_bits - 1,
          **FIFO_PORT_SUFFIXES
      }
      println('wire [{msb}:0] {name}{data_in};'.format(**kwargs))
      println('wire {name}{not_full};'.format(**kwargs))
      println('wire {name}{write_enable};'.format(**kwargs))
      println('wire [{msb}:0] {name}{data_out};'.format(**kwargs))
      println('wire {name}{not_empty};'.format(**kwargs))
      println('wire {name}{read_enable};'.format(**kwargs))
      println()

      args = collections.OrderedDict(
          (('clk', 'ap_clk'), ('reset', 'ap_rst_n_inv'),
           ('if_read_ce', "1'b1"), ('if_write_ce', "1'b1"),
           ('if{data_in}'.format(**kwargs), '{name}{data_in}'.format(**kwargs)),
           ('if{not_full}'.format(**kwargs),
            '{name}{not_full}'.format(**kwargs)),
           ('if{write_enable}'.format(**kwargs),
            '{name}{write_enable}'.format(**kwargs)),
           ('if{data_out}'.format(**kwargs),
            '{name}{data_out}'.format(**kwargs)),
           ('if{not_empty}'.format(**kwargs),
            '{name}{not_empty}'.format(**kwargs)),
           ('if{read_enable}'.format(**kwargs),
            '{name}{read_enable}'.format(**kwargs))))
      # NOTE(review): the instantiated depth is fifo.depth + 2. The same
      # padding is used for the module definitions at the end; the reason
      # is not visible in this file.
      printer.module_instance(
          'fifo_w{width}_d{depth}_A'.format(width=fifo.width_in_bits,
                                            depth=fifo.depth + 2),
          fifo.c_expr, args)
      println()

  # One instance per module, in tpo_node_gen() order, with ap_start tied
  # high. Connections, in order:
  #   DRAM writes  -> top-level output ports;
  #   output fifos -> FIFO write side;
  #   input fifos  -> FIFO read side;
  #   DRAM reads   -> top-level input ports.
  for module in super_source.tpo_node_gen():
    module_trait, module_trait_id = super_source.module_table[module]
    args = collections.OrderedDict(
        (('ap_clk', 'ap_clk'), ('ap_rst', 'ap_rst_n_inv'), ('ap_start', "1'b1")))
    for dram_ref, bank in module.dram_writes:
      kwargs = dict(port=dram_ref.dram_fifo_name(bank),
                    fifo=util.get_port_name(dram_ref.var, bank),
                    **FIFO_PORT_SUFFIXES)
      args['{port}_V{data_in}'.format(**kwargs)] = \
          '{fifo}_V_V{data_in}'.format(**kwargs)
      args['{port}_V{not_full}'.format(**kwargs)] = \
          '{fifo}_V_V{not_full}'.format(**kwargs)
      args['{port}_V{write_enable}'.format(**kwargs)] = \
          '{fifo}_V_V{write_enable}'.format(**kwargs)
    # Trait-level port names are paired positionally with this instance's
    # FIFO names.
    for port, fifo in zip(module_trait.output_fifos, module.output_fifos):
      kwargs = dict(port=port, fifo=fifo, **FIFO_PORT_SUFFIXES)
      args['{port}_V{data_in}'.format(**kwargs)] = \
          '{fifo}{data_in}'.format(**kwargs)
      args['{port}_V{not_full}'.format(**kwargs)] = \
          '{fifo}{not_full}'.format(**kwargs)
      args['{port}_V{write_enable}'.format(**kwargs)] = \
          '{fifo}{write_enable}'.format(**kwargs)
    for port, fifo in zip(module_trait.input_fifos, module.input_fifos):
      kwargs = dict(port=port, fifo=fifo, **FIFO_PORT_SUFFIXES)
      # A constant 1'b1 is concatenated above the FIFO data. NOTE(review):
      # this presumably supplies an extra leading bit the module's input
      # port expects; confirm against the generated module.
      args['{port}_V{data_out}'.format(**kwargs)] = \
          "{{1'b1, {fifo}{data_out}}}".format(**kwargs)
      args['{port}_V{not_empty}'.format(**kwargs)] = \
          '{fifo}{not_empty}'.format(**kwargs)
      args['{port}_V{read_enable}'.format(**kwargs)] = \
          '{fifo}{read_enable}'.format(**kwargs)
    for dram_ref, bank in module.dram_reads:
      kwargs = dict(port=dram_ref.dram_fifo_name(bank),
                    fifo=util.get_port_name(dram_ref.var, bank),
                    **FIFO_PORT_SUFFIXES)
      args['{port}_V{data_out}'.format(**kwargs)] = \
          "{{1'b1, {fifo}_V_V{data_out}}}".format(**kwargs)
      args['{port}_V{not_empty}'.format(**kwargs)] = \
          '{fifo}_V_V{not_empty}'.format(**kwargs)
      args['{port}_V{read_enable}'.format(**kwargs)] = \
          '{fifo}_V_V{read_enable}'.format(**kwargs)
    printer.module_instance(util.get_func_name(module_trait_id), module.name,
                            args)
    println()

  printer.endmodule()

  # Emit one FIFO module definition per distinct (width, padded depth) pair
  # instantiated above; the set removes duplicates.
  fifos = set()
  for module in super_source.tpo_node_gen():
    for fifo in module.fifos:
      fifos.add((fifo.width_in_bits, fifo.depth + 2))
  for fifo in fifos:
    printer.fifo_module(*fifo)