예제 #1
0
def print_kernel(name: str,
                 printer: util.Printer,
                 node: ir.Module,
                 module_trait: ir.ModuleTrait,
                 module_trait_id: int,
                 burst_width: int = 256) -> None:
    """Print the OpenCL kernel function for one dataflow module.

    The emitted kernel consists of I/O comments, kernel attributes, the
    function header, declarations for every `ir.DelayedRef` found in the
    module trait, and a single loop whose body loads from DRAM (if any),
    reads the input channels, evaluates the trait's lets and expressions,
    writes the output channels, and stores to DRAM (if any).

    Modules that touch DRAM get one `__global` pointer parameter per
    (dram_ref, bank) pair plus `coalesced_data_num`, which bounds the loop.
    All other modules are printed as parameterless `autorun` kernels whose
    loop header has empty clauses.

    Args:
      name: name of the kernel function to print.
      printer: printer that receives the generated code.
      node: module instance; its `dram_reads`, `dram_writes`, `input_fifos`
        and `output_fifos` are consulted.
      module_trait: trait providing the `loads`, `lets` and `exprs` to print.
      module_trait_id: not used by this function.
      burst_width: divided by an element's `width_in_bits` to obtain the
        number of elements per DRAM access (presumably the burst width in
        bits -- confirm with callers).

    Raises:
      ValueError: if `node` both reads from and writes to DRAM.
      NotImplementedError: if, while rewriting a `ir.DRAMRef`, the DRAM
        throughput does not equal the FIFO throughput.
    """
    if node.dram_reads and node.dram_writes:
        raise ValueError('cannot read and write DRAM in the same module')
    println = printer.println

    # print I/O info
    for port, arg in zip(module_trait.loads, node.input_fifos):
        println('//  input <{0.cl_type}> <- {1}'.format(port, arg))
    for expr, arg in zip(module_trait.exprs, node.output_fifos):
        println('// output <{}> -> {}'.format(expr.cl_type, arg))

    # print kernel function
    printer.println('__kernel')
    # attributes shared by every kernel; more are appended per kernel kind below
    kernel_attrs = ['reqd_work_group_size(1, 1, 1)', 'max_global_work_dim(0)']

    def print_kernel_attrs():
        # reads `kernel_attrs` at call time, so it must be called only after
        # the kind-specific attribute has been appended
        for attr in kernel_attrs:
            println('__attribute(({}))'.format(attr))

    if node.dram_reads or node.dram_writes:
        # at most one of the two lists is non-empty (checked above); each
        # (dram_ref, bank) pair becomes one pointer parameter placed in the
        # bank's HBM buffer location
        params = [
            '__global {} {}* restrict {}'.format(
                '__attribute((buffer_location("HBM{}")))'.format(bank),
                dram_ref.haoda_type.get_cl_vec_type(burst_width),
                util.get_port_name(dram_ref.var, bank))
            for dram_ref, bank in node.dram_reads or node.dram_writes
        ]
        params.append('ulong coalesced_data_num')
        kernel_attrs.append('uses_global_work_offset(0)')
        print_kernel_attrs()
        printer.print_func(name='void {}'.format(name), params=params, align=0)
    else:
        # modules without DRAM access take no arguments and are autorun
        kernel_attrs.append('autorun')
        print_kernel_attrs()
        println('void {}()'.format(name))
    printer.do_scope(name)

    # prepare for any DelayedRef
    def get_delays(obj: ir.Node, delays: List[ir.DelayedRef]) -> ir.Node:
        # visitor callback: records every DelayedRef and returns the node
        # unchanged
        if isinstance(obj, ir.DelayedRef):
            delays.append(obj)
        return obj

    delays = []  # type: List[ir.DelayedRef]
    for let in module_trait.lets:
        let.visit(get_delays, delays)
    for expr in module_trait.exprs:
        expr.visit(get_delays, delays)
    _logger.debug('delays: %s', delays)

    # inter-iteration declarations
    for delay in delays:
        println(delay.cl_buf_decl)
        println(delay.cl_ptr_decl)

    def mutate_dram_ref(obj: ir.Node, kwargs: Dict[str, int]) -> ir.Node:
        # visitor callback: replaces a DRAMRef with a Var that names one
        # element of the local per-bank buffer, and replaces a Let whose
        # target is a DRAMRef with a Var holding the assignment statement
        if isinstance(obj, ir.DRAMRef):
            # elements delivered per loop iteration by all banks of this ref
            dram_throughput = len(obj.dram) * burst_width // obj.width_in_bits
            # FIFOs per distinct DRAM variable on the non-DRAM side of the
            # module
            if node.dram_reads:
                output_count = len({x[0].var for x in node.dram_reads})
                fifo_throughput = len(node.output_fifos) // output_count
            else:
                input_count = len({x[0].var for x in node.dram_writes})
                fifo_throughput = len(node.input_fifos) // input_count
            if dram_throughput != fifo_throughput:
                raise NotImplementedError(
                    f'memory throughput {dram_throughput} != '
                    f'processing throughput {fifo_throughput}')

            coalescing_idx = kwargs['coalescing_idx']
            unroll_factor = kwargs['unroll_factor']
            elem_idx = coalescing_idx * unroll_factor + obj.offset
            # consecutive elements rotate through the banks: the remainder
            # selects the bank's buffer, the quotient the slot within it
            return ir.Var(
                name='{buf}[{idx}]'.format(
                    buf=obj.dram_buf_name(obj.dram[elem_idx % len(obj.dram)]),
                    idx=elem_idx // len(obj.dram),
                ),
                idx=(),
            )
        if isinstance(obj, ir.Let) and isinstance(obj.name, ir.DRAMRef):
            return ir.Var(name='{} = {};'.format(
                obj.name.visit(mutate_dram_ref, kwargs), obj.expr.cl_expr),
                          idx=())
        return obj

    # DRAM modules iterate over the coalesced data; other modules get empty
    # loop clauses (presumably an endless loop -- depends on printer.for_)
    if node.dram_reads or node.dram_writes:
        loop_args = 'ulong i = 0', 'i < coalesced_data_num', '++i'
    else:
        loop_args = '', '', ''

    # the pragma is only printed when there are delay buffers
    if delays:
        println('#pragma ivdep', indent=0)
    with printer.for_(*loop_args):
        # print DelayedRef (if any)
        for delay in delays:
            println('const {} {};'.format(delay.cl_type, delay.c_buf_load))

        # print load from DRAM (if any)
        # first declare one local buffer per (dram_ref, bank) ...
        for dram_ref, bank in node.dram_reads:
            println('{gentype} {buf}[{n}];'.format(
                gentype=dram_ref.cl_type,
                n=burst_width // dram_ref.width_in_bits,
                buf=dram_ref.dram_buf_name(bank),
            ))
        # ... then emit a vstoreN call that fills it from the i-th vector of
        # the corresponding port
        for dram_ref, bank in node.dram_reads:
            println('vstore{n}({vec_ptr}[i], 0, {buf});'.format(
                vec_ptr=util.get_port_name(dram_ref.var, bank),
                n=burst_width // dram_ref.width_in_bits,
                buf=dram_ref.dram_buf_name(bank),
            ))

        # read from FIFOs
        for port, arg in zip(module_trait.loads, node.input_fifos):
            println(
                '{0.cl_type} {0.ref_name} = read_channel_intel({1});'.format(
                    port, arg))

        # declare buffer for DRAM writes
        for dram_ref, bank in node.dram_writes:
            n = burst_width // dram_ref.width_in_bits
            buf = dram_ref.dram_buf_name(bank)
            println('{gentype} {buf}[{n}];'.format(
                gentype=dram_ref.cl_type,
                n=n,
                buf=buf,
            ))

        # arguments handed to mutate_dram_ref; the unroll factor is the
        # number of input FIFOs, or of output FIFOs when there are no inputs
        mutate_kwargs = {
            'coalescing_idx': 0,
            'unroll_factor': len(node.input_fifos) or len(node.output_fifos),
        }

        # print Let (if any)
        for let in module_trait.lets:
            println(let.visit(mutate_dram_ref, mutate_kwargs).cl_expr)

        # write each expression to its output channel
        for expr, arg in zip(module_trait.exprs, node.output_fifos):
            println('write_channel_intel({}, {});'.format(
                arg,
                expr.visit(mutate_dram_ref, mutate_kwargs).cl_expr))

        # update DelayedRef (if any)
        for delay in delays:
            println(delay.c_buf_store)
            println('{} = {};'.format(delay.ptr, delay.cl_next_ptr_expr))

        # print store to DRAM (if any)
        for dram_ref, bank in node.dram_writes:
            println('{vec_ptr}[i] = vload{n}(0, {buf});'.format(
                vec_ptr=util.get_port_name(dram_ref.var, bank),
                n=burst_width // dram_ref.width_in_bits,
                buf=dram_ref.dram_buf_name(bank),
            ))

    printer.un_scope()  # end of kernel function
    _logger.debug('printing: %s', module_trait)
예제 #2
0
def print_top_module(
    printer: backend.VerilogPrinter,
    super_source: dataflow.SuperSourceNode,
    inputs: Sequence[Tuple[str, str, int, str]],
    outputs: Sequence[Tuple[str, str, int, str]],
    module_name: str = 'Dataflow',
    interface: str = 'm_axi',
) -> None:
    """Print the top-level Verilog module of the dataflow design.

  The output consists of the module header and port declarations, a chain of
  reset registers, one FIFO instance per FIFO of every module yielded by
  `super_source.tpo_valid_node_gen()`, one instance per such module, and --
  after `endmodule` -- the definition of every distinct FIFO module that was
  instantiated. With the 'axis' interface, every external port additionally
  gets a depth-2 FIFO between the AXI-Stream signals and the internal
  `_V_V` FIFO signals.

  Args:
    printer: printer to print to
    super_source: SuperSourceNode carrying the IR tree.
    inputs: sequence of (port_name, bundle_name, width, _) of input ports
    outputs: sequence of (port_name, bundle_name, width, _) of output ports
    module_name: name of the module
    interface: interface type, supported values are 'm_axi' and 'axis'
  """
    printer.printlns('`timescale 1 ns / 1 ps', '`default_nettype none')

    # all external ports, inputs first
    ports = *inputs, *outputs

    # unpack suffixes
    data_in = FIFO_PORT_SUFFIXES['data_in']
    not_full = FIFO_PORT_SUFFIXES['not_full']
    write_enable = FIFO_PORT_SUFFIXES['write_enable']
    data_out = FIFO_PORT_SUFFIXES['data_out']
    not_empty = FIFO_PORT_SUFFIXES['not_empty']
    read_enable = FIFO_PORT_SUFFIXES['read_enable']
    not_block = FIFO_PORT_SUFFIXES['not_block']
    data = AXIS_PORT_SUFFIXES['data']
    valid = AXIS_PORT_SUFFIXES['valid']
    ready = AXIS_PORT_SUFFIXES['ready']

    # prepare arguments
    # control signals: 'm_axi' uses the ap_* handshake with an active-high
    # reset, 'axis' only takes the clock and ap_rst_n
    input_args = ['ap_clk']
    output_args: List[str] = []
    if interface == 'm_axi':
        input_args += 'ap_rst', 'ap_start', 'ap_continue'
        output_args += 'ap_done', 'ap_idle', 'ap_ready'
    elif interface == 'axis':
        input_args.append('ap_rst_n')

    # module port list: control signals followed by the data ports
    args = list(input_args + output_args)
    if interface == 'm_axi':
        # outputs are exposed as the write side of a FIFO, inputs as the
        # read side
        for port_name, _, _, _ in outputs:
            args.append(f'{port_name}_V_V{data_in}')
            args.append(f'{port_name}_V_V{not_full}')
            args.append(f'{port_name}_V_V{write_enable}')
        for port_name, _, _, _ in inputs:
            args.append(f'{port_name}_V_V{data_out}')
            args.append(f'{port_name}_V_V{not_empty}')
            args.append(f'{port_name}_V_V{read_enable}')
    elif interface == 'axis':
        # every port contributes one signal per AXI-Stream suffix
        for port_name, _, _, _ in ports:
            args.extend(port_name + suffix
                        for suffix in AXIS_PORT_SUFFIXES.values())

    # print module interface
    printer.module(module_name, args)
    printer.println()

    # print signals for modules
    printer.printlns(
        *(f'input  wire {arg};' for arg in input_args),
        *(f'output wire {arg};' for arg in output_args),
    )
    for port_name, _, width, _ in outputs:
        if interface == 'm_axi':
            printer.printlns(
                f'output wire [{width - 1}:0] {port_name}_V_V{data_in};',
                f'input  wire {port_name}_V_V{not_full};',
                f'output wire {port_name}_V_V{write_enable};',
            )
        elif interface == 'axis':
            # AXI-Stream signals are the ports; the FIFO-side `_V_V` signals
            # are internal wires
            printer.printlns(
                f'output wire [{width - 1}:0] {port_name}{data};',
                f'output wire {port_name}{valid};',
                f'input  wire {port_name}{ready};',
                f'wire [{width - 1}:0] {port_name}_V_V{data_in};',
                f'wire {port_name}_V_V{not_full};',
                f'wire {port_name}_V_V{write_enable};',
            )
    for port_name, _, width, _ in inputs:
        if interface == 'm_axi':
            printer.printlns(
                f'input  wire [{width - 1}:0] {port_name}_V_V{data_out};',
                f'input  wire {port_name}_V_V{not_empty};',
                f'output wire {port_name}_V_V{read_enable};',
            )
        elif interface == 'axis':
            printer.printlns(
                f'input  wire [{width - 1}:0] {port_name}{data};',
                f'input  wire {port_name}{valid};',
                f'output wire {port_name}{ready};',
                f'wire [{width - 1}:0] {port_name}_V_V{data_out};',
                f'wire {port_name}_V_V{not_empty};',
                f'wire {port_name}_V_V{read_enable};',
            )
    printer.println()

    # not used
    # NOTE(review): with 'm_axi' these three names were already declared as
    # `output wire` above -- confirm the target tools accept the `reg`
    # redeclaration.
    printer.printlns(
        "reg ap_done = 1'b0;",
        "reg ap_idle = 1'b1;",
        "reg ap_ready = 1'b0;",
    )
    if interface == 'axis':
        # 'axis' has no ap_start port, so it is tied high
        printer.println("wire ap_start = 1'b1;")

    # print signals for FIFOs
    # NOTE(review): the `reg` below is named `{port_name}{data_in}` without
    # the `_V_V` infix used everywhere else, and the `write_enable` /
    # `read_enable` wires repeat names already declared as output ports
    # above -- confirm both are intended.
    if interface == 'm_axi':
        for port_name, _, width, _ in outputs:
            printer.printlns(
                f'reg [{width - 1}:0] {port_name}{data_in};',
                f'wire {port_name}_V_V{write_enable};',
            )
        for port_name, _, _, _ in inputs:
            printer.println(f'wire {port_name}_V_V{read_enable};')
    printer.println()

    # register reset signal
    # the reset passes through `ap_rst_reg_level` registers; stage i (except
    # the last) carries a max_fanout attribute of 8 ** i
    ap_rst_reg_level = 8
    rst = 'ap_rst'
    if interface == 'axis':
        # 'axis' provides an active-low reset, so it is inverted
        rst = '~ap_rst_n'
    printer.printlns(
        f'wire ap_rst_reg_0 = {rst};',
        *(f'(* shreg_extract = "no", max_fanout = {8 ** i} *) reg ap_rst_reg_{i};'
          for i in range(1, ap_rst_reg_level)),
        f'(* shreg_extract = "no" *) reg ap_rst_reg_{ap_rst_reg_level};',
        f'wire ap_rst_reg = ap_rst_reg_{ap_rst_reg_level};',
    )
    if ap_rst_reg_level > 0:
        # each stage latches the previous one on the rising clock edge
        with printer.always('posedge ap_clk'):
            printer.printlns(f'ap_rst_reg_{i + 1} <= ap_rst_reg_{i};'
                             for i in range(ap_rst_reg_level))

    # handshake registers: held at their initial values during reset,
    # otherwise ap_idle follows the inverse of ap_start
    with printer.always('posedge ap_clk'):
        with printer.if_('ap_rst_reg'):
            printer.printlns(
                "ap_done <= 1'b0;",
                "ap_idle <= 1'b1;",
                "ap_ready <= 1'b0;",
            )
            printer.else_()
            printer.println('ap_idle <= ~ap_start;')
    printer.println()

    if interface == 'm_axi':
        # used by cosim for deadlock detection
        printer.printlns(f'reg {port_name}_V_V{not_block};'
                         for port_name, _, _, _ in ports)
        # an output is not blocked while its FIFO is not full, an input
        # while its FIFO is not empty
        with printer.always('*'):
            printer.printlns(
                *(f'{port_name}_V_V{not_block} = {port_name}_V_V{not_full};'
                  for port_name, _, _, _ in outputs),
                *(f'{port_name}_V_V{not_block} = {port_name}_V_V{not_empty};'
                  for port_name, _, _, _ in inputs),
            )
    printer.println()

    # (width, depth) pairs of every FIFO instantiated below
    fifos: Set[Tuple[int, int]] = set()  # used for printing FIFO modules

    if interface == 'axis':
        # input ports: the AXI-Stream side writes into a depth-2 FIFO whose
        # read side drives the internal `_V_V` signals
        for port_name, _, width, _ in inputs:
            printer.module_instance(
                'fifo_w{width}_d{depth}_A'.format(width=width, depth=2),
                port_name + '_fifo',
                args={
                    'clk': 'ap_clk',
                    'reset': 'ap_rst_reg',
                    'if_read_ce': "1'b1",
                    'if_write_ce': "1'b1",
                    f'if{data_in}': f'{port_name}{data}',
                    f'if{not_full}': f'{port_name}{ready}',
                    f'if{write_enable}': f'{port_name}{valid}',
                    f'if{data_out}': f'{port_name}_V_V{data_out}',
                    f'if{not_empty}': f'{port_name}_V_V{not_empty}',
                    f'if{read_enable}': f'{port_name}_V_V{read_enable}',
                },
            )
            fifos.add((width, 2))
            printer.println()

        # output ports: the internal `_V_V` signals write into a depth-2
        # FIFO whose read side drives the AXI-Stream signals
        for port_name, _, width, _ in outputs:
            printer.module_instance(
                f'fifo_w{width}_d2_A',
                port_name + '_fifo',
                args={
                    'clk': 'ap_clk',
                    'reset': 'ap_rst_reg',
                    'if_read_ce': "1'b1",
                    'if_write_ce': "1'b1",
                    f'if{data_in}': f'{port_name}_V_V{data_in}',
                    f'if{not_full}': f'{port_name}_V_V{not_full}',
                    f'if{write_enable}': f'{port_name}_V_V{write_enable}',
                    f'if{data_out}': f'{port_name}{data}',
                    f'if{not_empty}': f'{port_name}{valid}',
                    f'if{read_enable}': f'{port_name}{ready}',
                },
            )
            fifos.add((width, 2))
            printer.println()

    # print FIFO instances
    for module in super_source.tpo_valid_node_gen():
        for fifo in module.fifos:
            name = fifo.c_expr
            msb = fifo.width_in_bits - 1
            # fifo.depth is the "extra" capacity of a FIFO; the base depth is 3, 1 for
            # registering the input, 1 for keeping II=1 when FIFO is (almost) full, 1
            # for keeping II=1 when FIFO is relaxed from back pressure (necessary
            # because the optimal FIFO depths may require back pressure)
            depth = fifo.depth + 3
            # declare the six signals of the FIFO before instantiating it
            printer.printlns(
                f'wire [{msb}:0] {name}{data_in};',
                f'wire {name}{not_full};',
                f'wire {name}{write_enable};',
                f'wire [{msb}:0] {name}{data_out};',
                f'wire {name}{not_empty};',
                f'wire {name}{read_enable};',
                '',
            )
            printer.module_instance(
                f'fifo_w{fifo.width_in_bits}_d{depth}_A',
                name,
                args={
                    'clk': 'ap_clk',
                    'reset': 'ap_rst_reg',
                    'if_read_ce': "1'b1",
                    'if_write_ce': "1'b1",
                    f'if{data_in}': f'{name}{data_in}',
                    f'if{not_full}': f'{name}{not_full}',
                    f'if{write_enable}': f'{name}{write_enable}',
                    f'if{data_out}': f'{name}{data_out}',
                    f'if{not_empty}': f'{name}{not_empty}',
                    f'if{read_enable}': f'{name}{read_enable}',
                },
            )
            fifos.add((fifo.width_in_bits, depth))
            printer.println()

    # print module instances
    for module in super_source.tpo_valid_node_gen():
        module_trait, module_trait_id = super_source.module_table[module]
        # every module instance is clocked, reset by the registered reset,
        # and started unconditionally
        arg_dict = {
            'ap_clk': 'ap_clk',
            'ap_rst': 'ap_rst_reg',
            'ap_start': "1'b1",
        }
        # DRAM writes connect to the write side of the external port's FIFO
        # signals
        for dram_ref, bank in module.dram_writes:
            port = dram_ref.dram_fifo_name(bank)
            fifo = util.get_port_name(dram_ref.var, bank)
            arg_dict.update({
                f'{port}_V{data_in}':
                f'{fifo}_V_V{data_in}',
                f'{port}_V{not_full}':
                f'{fifo}_V_V{not_full}',
                f'{port}_V{write_enable}':
                f'{fifo}_V_V{write_enable}',
            })
        # internal FIFOs: trait-level port names map to instance-level FIFO
        # names
        for port, fifo in zip(module_trait.output_fifos, module.output_fifos):
            arg_dict.update({
                f'{port}_V{data_in}': f'{fifo}{data_in}',
                f'{port}_V{not_full}': f'{fifo}{not_full}',
                f'{port}_V{write_enable}': f'{fifo}{write_enable}',
            })
        for port, fifo in zip(module_trait.input_fifos, module.input_fifos):
            # the data is concatenated with a constant 1 bit on the MSB side
            # (presumably a valid/enable bit expected by the module -- confirm)
            arg_dict.update({
                f'{port}_V{data_out}': f"{{1'b1, {fifo}{data_out}}}",
                f'{port}_V{not_empty}': f'{fifo}{not_empty}',
                f'{port}_V{read_enable}': f'{fifo}{read_enable}',
            })
        # DRAM reads connect to the read side of the external port's FIFO
        # signals, with the same constant 1 bit prepended to the data
        for dram_ref, bank in module.dram_reads:
            port = dram_ref.dram_fifo_name(bank)
            fifo = util.get_port_name(dram_ref.var, bank)
            arg_dict.update({
                f'{port}_V{data_out}':
                f"{{1'b1, {fifo}_V_V{data_out}}}",
                f'{port}_V{not_empty}':
                f'{fifo}_V_V{not_empty}',
                f'{port}_V{read_enable}':
                f'{fifo}_V_V{read_enable}',
            })
        printer.module_instance(util.get_func_name(module_trait_id),
                                module.name, arg_dict)
        printer.println()
    printer.endmodule()
    # restore the default net type that was disabled at the top of the file
    printer.println('`default_nettype wire')

    # print FIFO modules
    # one definition per distinct (width, depth) pair collected above
    for fifo in fifos:
        printer.fifo_module(*fifo)
예제 #3
0
def print_code(stencil, xo_file, platform=None, jobs=os.cpu_count()):
    """Generate hardware object file for the given Stencil.

  Working `vivado` and `vivado_hls` is required in the PATH.

  All intermediate files are created in a temporary directory that is removed
  when the function returns; only the packaged object is copied to `xo_file`.
  If any synthesis job fails, its output is logged as an error and the
  process exits with that job's return code.

  Args:
    stencil: Stencil object to generate from.
    xo_file: file object to write to.
    platform: path to the SDAccel platform directory. If it is None or does
      not exist and the XDEVICE environment variable is set, the platform is
      looked up under /opt/xilinx/platforms and then $XILINX_SDX/platforms.
    jobs: maximum number of jobs running in parallel.

  Raises:
    ValueError: if no existing platform directory can be determined.
  """

    haoda_type = 'uint%d' % stencil.burst_width

    def get_ports(stmts):
        # One (port name, bundle name, type, buffer name) tuple per DRAM bank
        # of every statement, in statement order.
        return [(util.get_port_name(stmt.name, bank),
                 util.get_bundle_name(stmt.name, bank), haoda_type,
                 util.get_port_buf_name(stmt.name, bank))
                for stmt in stmts
                for bank in stmt.dram]

    outputs = get_ports(stencil.output_stmts)
    inputs = get_ports(stencil.input_stmts)
    # bundle names of all ports, outputs first
    m_axi_names = [bundle_name for _, bundle_name, _, _ in outputs + inputs]

    top_name = stencil.app_name + '_kernel'

    # fall back to the platform named by XDEVICE when the given one is unusable
    if 'XDEVICE' in os.environ:
        xdevice = os.environ['XDEVICE'].replace(':', '_').replace('.', '_')
        if platform is None or not os.path.exists(platform):
            platform = os.path.join('/opt/xilinx/platforms', xdevice)
        if not os.path.exists(platform) and 'XILINX_SDX' in os.environ:
            platform = os.path.join(os.environ['XILINX_SDX'], 'platforms',
                                    xdevice)
    if platform is None or not os.path.exists(platform):
        raise ValueError('Cannot determine platform from environment.')
    device_info = backend.get_device_info(platform)

    with tempfile.TemporaryDirectory(prefix='sodac-xrtl-') as tmpdir:
        dataflow_kernel = os.path.join(tmpdir, 'dataflow_kernel.cpp')
        with open(dataflow_kernel, 'w') as dataflow_kernel_obj:
            print_dataflow_hls_interface(util.Printer(dataflow_kernel_obj),
                                         top_name, inputs, outputs)

        kernel_xml = os.path.join(tmpdir, 'kernel.xml')
        with open(kernel_xml, 'w') as kernel_xml_obj:
            backend.print_kernel_xml(top_name, outputs + inputs,
                                     kernel_xml_obj)

        kernel_file = os.path.join(tmpdir, 'kernel.cpp')
        with open(kernel_file, 'w') as kernel_fileobj:
            hls_kernel.print_code(stencil, kernel_fileobj)

        # synthesize every module trait and the top-level interface in
        # parallel
        super_source = stencil.dataflow_super_source
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=jobs) as executor:
            futures = [
                executor.submit(synthesis_module, tmpdir, [kernel_file],
                                util.get_func_name(module_id), device_info)
                for module_id in range(len(super_source.module_traits))
            ]
            futures.append(
                executor.submit(synthesis_module, tmpdir, [dataflow_kernel],
                                top_name, device_info))
            for future in concurrent.futures.as_completed(futures):
                returncode, stdout, stderr = future.result()
                log_func = _logger.error if returncode != 0 else _logger.debug
                if stdout:
                    log_func(stdout.decode())
                if stderr:
                    log_func(stderr.decode())
                if returncode != 0:
                    util.pause_for_debugging()
                    sys.exit(returncode)

        hdl_dir = os.path.join(tmpdir, 'hdl')
        with open(os.path.join(hdl_dir, 'Dataflow.v'), mode='w') as dataflow_v:
            print_top_module(backend.VerilogPrinter(dataflow_v),
                             stencil.dataflow_super_source, inputs, outputs)

        util.pause_for_debugging()

        # package everything and copy the result to the caller's file object
        xo_filename = os.path.join(tmpdir, stencil.app_name + '.xo')
        with backend.PackageXo(xo_filename, top_name, kernel_xml, hdl_dir,
                               m_axi_names, [dataflow_kernel]) as proc:
            stdout, stderr = proc.communicate()
        log_func = _logger.error if proc.returncode != 0 else _logger.debug
        log_func(stdout.decode())
        log_func(stderr.decode())
        with open(xo_filename, mode='rb') as xo_fileobj:
            shutil.copyfileobj(xo_fileobj, xo_file)
예제 #4
0
def print_code(
    stencil: core.Stencil,
    xo_file: IO[bytes],
    device_info: Dict[str, str],
    jobs: Optional[int] = os.cpu_count(),
    rpt_file: Optional[str] = None,
    interface: str = 'm_axi',
) -> None:
    """Generate hardware object file for the given Stencil.

  Working `vivado` and `vivado_hls` is required in the PATH.

  All intermediate files are created in a temporary directory that is removed
  when the function returns; only the packaged object is copied to `xo_file`.
  If any synthesis job fails, its output is logged as an error and the
  process exits with that job's return code.

  Args:
    stencil: Stencil object to generate from.
    xo_file: file object to write to.
    device_info: dict of 'part_num' and 'clock_period'.
    jobs: maximum number of jobs running in parallel.
    rpt_file: path of the generated report; None disables report generation.
    interface: interface type, supported values are 'm_axi' and 'axis'.

  Raises:
    hls_report.BadReport: if the performance section of a module's synthesis
      report cannot be read.
  """

    iface_names = []  # for axis
    m_axi_names = []  # for m_axi
    inputs = []
    outputs = []

    # one port per DRAM bank of every statement; outputs come first in the
    # name lists
    for stmt in stencil.output_stmts:
        for bank in stmt.dram:
            port_name = util.get_port_name(stmt.name, bank)
            bundle_name = util.get_bundle_name(stmt.name, bank)
            iface_names.append(port_name)
            m_axi_names.append(bundle_name)
            outputs.append((port_name, bundle_name, stencil.burst_width,
                            util.get_port_buf_name(stmt.name, bank)))
    for stmt in stencil.input_stmts:
        for bank in stmt.dram:
            port_name = util.get_port_name(stmt.name, bank)
            bundle_name = util.get_bundle_name(stmt.name, bank)
            iface_names.append(port_name)
            m_axi_names.append(bundle_name)
            inputs.append((port_name, bundle_name, stencil.burst_width,
                           util.get_port_buf_name(stmt.name, bank)))

    top_name = stencil.kernel_name
    with tempfile.TemporaryDirectory(prefix='sodac-xrtl-') as tmpdir:
        kernel_xml = os.path.join(tmpdir, 'kernel.xml')
        with open(kernel_xml, 'w') as kernel_xml_obj:
            print_kernel_xml(top_name, inputs, outputs, kernel_xml_obj,
                             interface)

        kernel_file = os.path.join(tmpdir, 'kernel.cpp')
        with open(kernel_file, 'w') as kernel_fileobj:
            hls_kernel.print_code(stencil, kernel_fileobj)

        # Collect one synthesis job per module trait. The first element of
        # each tuple is the length of the module's generated source and is
        # only used to order the jobs.
        args = []
        for module_trait_id, module_trait in enumerate(stencil.module_traits):
            sio = io.StringIO()
            hls_kernel.print_module_definition(util.CppPrinter(sio),
                                               module_trait,
                                               module_trait_id,
                                               burst_width=stencil.burst_width)
            args.append(
                (len(sio.getvalue()), synthesis_module, tmpdir, [kernel_file],
                 util.get_func_name(module_trait_id), device_info))

        # only the 'm_axi' interface needs the HLS dataflow wrapper
        if interface == 'm_axi':
            sio = io.StringIO()
            print_dataflow_hls_interface(util.CppPrinter(sio), top_name,
                                         inputs, outputs)
            dataflow_kernel = os.path.join(tmpdir, 'dataflow_kernel.cpp')
            with open(dataflow_kernel, 'w') as dataflow_kernel_obj:
                dataflow_kernel_obj.write(sio.getvalue())
            args.append((len(sio.getvalue()), synthesis_module, tmpdir,
                         [dataflow_kernel], top_name, device_info))
        # submit the jobs with the longest source first
        args.sort(key=lambda x: x[0], reverse=True)

        super_source = stencil.dataflow_super_source
        job_server = util.release_job_slot()
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=jobs) as executor:
            threads = [executor.submit(*x[1:]) for x in args]
            for future in concurrent.futures.as_completed(threads):
                returncode, stdout, stderr = future.result()
                log_func = _logger.error if returncode != 0 else _logger.debug
                if stdout:
                    log_func(stdout.decode())
                if stderr:
                    log_func(stderr.decode())
                if returncode != 0:
                    util.pause_for_debugging()
                    sys.exit(returncode)
        util.acquire_job_slot(job_server)

        # generate HLS report
        depths: Dict[int, int] = {}
        hls_resources = hls_report.HlsResources()
        if interface == 'm_axi':
            # start from the top-level report minus the 'Dataflow' report
            hls_resources = hls_report.resources(
                os.path.join(tmpdir, 'report', top_name + '_csynth.xml'))
            hls_resources -= hls_report.resources(
                os.path.join(tmpdir, 'report', 'Dataflow_csynth.xml'))

        _logger.info(hls_resources)
        for module_id, nodes in enumerate(
                super_source.module_trait_table.values()):
            module_name = util.get_func_name(module_id)
            report_file = os.path.join(tmpdir, 'report',
                                       module_name + '_csynth.xml')
            hls_resource = hls_report.resources(report_file)
            use_count = len(nodes)
            try:
                perf = hls_report.performance(report_file)
                _logger.info('%s, usage: %5d times, II: %3d, Depth: %3d',
                             hls_resource, use_count, perf.ii, perf.depth)
                depths[module_id] = perf.depth
            except hls_report.BadReport as e:
                # Logger.warn is a deprecated alias that no longer exists in
                # Python 3.13; warning() is the supported spelling.
                _logger.warning('%s in %s report (%s)', e, module_name,
                                report_file)
                _logger.info('%s, usage: %5d times', hls_resource, use_count)
                # bare raise keeps the original traceback intact
                raise
            # each module is instantiated once per node that uses its trait
            hls_resources += hls_resource * use_count
        _logger.info('total usage:')
        _logger.info(hls_resources)
        if rpt_file:
            rpt_json = collections.OrderedDict([('name', top_name)] +
                                               list(hls_resources))
            with open(rpt_file, mode='w') as rpt_fileobj:
                json.dump(rpt_json, rpt_fileobj, indent=2)

        # update the module pipeline depths
        stencil.dataflow_super_source.update_module_depths(depths)

        # with 'axis' the hand-written Verilog module is the kernel itself,
        # so it takes the kernel's name
        hdl_dir = os.path.join(tmpdir, 'hdl')
        module_name = 'Dataflow'
        if interface == 'axis':
            module_name = top_name
        with open(os.path.join(hdl_dir, f'{module_name}.v'),
                  mode='w') as fileobj:
            print_top_module(
                backend.VerilogPrinter(fileobj),
                stencil.dataflow_super_source,
                inputs,
                outputs,
                module_name,
                interface,
            )

        util.pause_for_debugging()

        # package everything and copy the result to the caller's file object
        xo_filename = os.path.join(tmpdir, stencil.app_name + '.xo')
        kwargs = {}
        if interface == 'm_axi':
            kwargs['m_axi_names'] = m_axi_names
        elif interface == 'axis':
            kwargs['iface_names'] = iface_names
        with backend.PackageXo(
                xo_filename,
                top_name,
                kernel_xml,
                hdl_dir,
                **kwargs,
        ) as proc:
            stdout, stderr = proc.communicate()
        log_func = _logger.error if proc.returncode != 0 else _logger.debug
        log_func(stdout.decode())
        log_func(stderr.decode())
        with open(xo_filename, mode='rb') as xo_fileobj:
            shutil.copyfileobj(xo_fileobj, xo_file)
예제 #5
0
def print_top_module(printer, super_source, inputs, outputs):
    """Print the top-level ``Dataflow`` Verilog module through ``printer``.

    The output is produced in this order: the timescale directive, the module
    header with its port list, port direction declarations, internal reg/wire
    declarations, the reset / ``ap_idle`` logic, the ``not_block`` aliases,
    one FIFO instance for every FIFO of every module yielded by
    ``super_source.tpo_node_gen()``, one instance per module, ``endmodule``,
    and finally one FIFO module definition per distinct (width, depth) pair.

    Args:
        printer: Object providing ``println``, ``module``, ``always``, ``if_``,
            ``else_``, ``module_instance``, ``endmodule`` and ``fifo_module``.
        super_source: Object providing ``tpo_node_gen()`` (called three times)
            and ``module_table``, which maps each yielded module to a
            ``(module_trait, module_trait_id)`` pair.
        inputs: Re-iterable collection of 4-tuples; element 0 is used as the
            port name and element 2 is passed to ``util.get_width_in_bits``.
        outputs: Same layout as ``inputs``, for the output ports.
    """
    emit = printer.println

    def expand(template, *positional, **named):
        # Fill a template; the FIFO port suffixes are always available by name.
        return template.format(*positional, **named, **FIFO_PORT_SUFFIXES)

    def connect(table, names, template_pairs):
        # Record one pin -> net connection per (pin, net) template pair.
        for pin_template, net_template in template_pairs:
            table[pin_template.format(**names)] = net_template.format(**names)

    # ---- module header -----------------------------------------------------
    emit('`timescale 1 ns / 1 ps')
    port_list = [
        'ap_clk', 'ap_rst', 'ap_start', 'ap_done', 'ap_continue', 'ap_idle',
        'ap_ready'
    ]
    write_side = ('{}_V_V{data_in}', '{}_V_V{not_full}',
                  '{}_V_V{write_enable}')
    read_side = ('{}_V_V{data_out}', '{}_V_V{not_empty}',
                 '{}_V_V{read_enable}')
    for port_name, _, _, _ in outputs:
        port_list.extend(expand(template, port_name) for template in write_side)
    for port_name, _, _, _ in inputs:
        port_list.extend(expand(template, port_name) for template in read_side)
    printer.module('Dataflow', port_list)
    emit()

    # ---- port directions ---------------------------------------------------
    for signal in ('ap_clk', 'ap_rst', 'ap_start', 'ap_continue'):
        emit('input  %s;' % signal)
    for signal in ('ap_done', 'ap_idle', 'ap_ready'):
        emit('output %s;' % signal)
    for port_name, _, haoda_type, _ in outputs:
        msb = util.get_width_in_bits(haoda_type) - 1
        emit(expand('output [{}:0] {port_name}_V_V{data_in};', msb,
                    port_name=port_name))
        emit(expand('input  {port_name}_V_V{not_full};', port_name=port_name))
        emit(expand('output {port_name}_V_V{write_enable};',
                    port_name=port_name))
    for port_name, _, haoda_type, _ in inputs:
        msb = util.get_width_in_bits(haoda_type) - 1
        emit(expand('input  [{}:0] {port_name}_V_V{data_out};', msb,
                    port_name=port_name))
        emit(expand('input  {port_name}_V_V{not_empty};', port_name=port_name))
        emit(expand('output {port_name}_V_V{read_enable};',
                    port_name=port_name))
    emit()

    # ---- internal regs and wires -------------------------------------------
    for declaration in ("reg ap_done = 1'b0;", "reg ap_idle = 1'b1;",
                        "reg ap_ready = 1'b0;"):
        emit(declaration)

    for port_name, _, haoda_type, _ in outputs:
        msb = util.get_width_in_bits(haoda_type) - 1
        emit(expand('reg [{}:0] {port_name}{data_in};', msb,
                    port_name=port_name))
        emit(expand('wire {port_name}_V_V{write_enable};',
                    port_name=port_name))
    for port_name, _, _, _ in inputs:
        emit(expand('wire {}_V_V{read_enable};', port_name))
    emit('reg ap_rst_n_inv;')
    with printer.always('*'):
        emit('ap_rst_n_inv = ap_rst;')
    emit()

    # ---- control-signal reset logic ----------------------------------------
    with printer.always('posedge ap_clk'), printer.if_('ap_rst'):
        emit("ap_done <= 1'b0;")
        emit("ap_idle <= 1'b1;")
        emit("ap_ready <= 1'b0;")
        printer.else_()
        emit('ap_idle <= ~ap_start;')

    # ---- not_block aliases: outputs first, then inputs ---------------------
    for ports in (outputs, inputs):
        for port_name, _, _, _ in ports:
            emit(expand('reg {}_V_V{not_block};', port_name))

    with printer.always('*'):
        for port_name, _, _, _ in outputs:
            emit(expand(
                '{port_name}_V_V{not_block} = {port_name}_V_V{not_full};',
                port_name=port_name))
        for port_name, _, _, _ in inputs:
            emit(expand(
                '{port_name}_V_V{not_block} = {port_name}_V_V{not_empty};',
                port_name=port_name))
    emit()

    # ---- one FIFO instance per FIFO ----------------------------------------
    fifo_wires = (
        'wire [{msb}:0] {name}{data_in};',
        'wire {name}{not_full};',
        'wire {name}{write_enable};',
        'wire [{msb}:0] {name}{data_out};',
        'wire {name}{not_empty};',
        'wire {name}{read_enable};',
    )
    fifo_pins = (
        ('if{data_in}', '{name}{data_in}'),
        ('if{not_full}', '{name}{not_full}'),
        ('if{write_enable}', '{name}{write_enable}'),
        ('if{data_out}', '{name}{data_out}'),
        ('if{not_empty}', '{name}{not_empty}'),
        ('if{read_enable}', '{name}{read_enable}'),
    )
    for module in super_source.tpo_node_gen():
        for fifo in module.fifos:
            names = {
                'name': fifo.c_expr,
                'msb': fifo.width_in_bits - 1,
                **FIFO_PORT_SUFFIXES
            }
            for wire_template in fifo_wires:
                emit(wire_template.format(**names))
            emit()

            connections = collections.OrderedDict([
                ('clk', 'ap_clk'),
                ('reset', 'ap_rst_n_inv'),
                ('if_read_ce', "1'b1"),
                ('if_write_ce', "1'b1"),
            ])
            connect(connections, names, fifo_pins)
            instance_type = 'fifo_w{width}_d{depth}_A'.format(
                width=fifo.width_in_bits, depth=fifo.depth + 2)
            printer.module_instance(instance_type, fifo.c_expr, connections)
            emit()

    # ---- one instance per dataflow module ----------------------------------
    # Writers drive either a top-level port ("_V_V" nets) or an internal FIFO.
    write_to_port = (
        ('{port}_V{data_in}', '{fifo}_V_V{data_in}'),
        ('{port}_V{not_full}', '{fifo}_V_V{not_full}'),
        ('{port}_V{write_enable}', '{fifo}_V_V{write_enable}'),
    )
    write_to_fifo = (
        ('{port}_V{data_in}', '{fifo}{data_in}'),
        ('{port}_V{not_full}', '{fifo}{not_full}'),
        ('{port}_V{write_enable}', '{fifo}{write_enable}'),
    )
    # On the read side the data net is emitted with a leading 1'b1 bit.
    read_from_fifo = (
        ('{port}_V{data_out}', "{{1'b1, {fifo}{data_out}}}"),
        ('{port}_V{not_empty}', '{fifo}{not_empty}'),
        ('{port}_V{read_enable}', '{fifo}{read_enable}'),
    )
    read_from_port = (
        ('{port}_V{data_out}', "{{1'b1, {fifo}_V_V{data_out}}}"),
        ('{port}_V{not_empty}', '{fifo}_V_V{not_empty}'),
        ('{port}_V{read_enable}', '{fifo}_V_V{read_enable}'),
    )
    for module in super_source.tpo_node_gen():
        module_trait, module_trait_id = super_source.module_table[module]
        connections = collections.OrderedDict([
            ('ap_clk', 'ap_clk'),
            ('ap_rst', 'ap_rst_n_inv'),
            ('ap_start', "1'b1"),
        ])
        for dram_ref, bank in module.dram_writes:
            names = dict(port=dram_ref.dram_fifo_name(bank),
                         fifo=util.get_port_name(dram_ref.var, bank),
                         **FIFO_PORT_SUFFIXES)
            connect(connections, names, write_to_port)
        for port, fifo in zip(module_trait.output_fifos, module.output_fifos):
            names = dict(port=port, fifo=fifo, **FIFO_PORT_SUFFIXES)
            connect(connections, names, write_to_fifo)
        for port, fifo in zip(module_trait.input_fifos, module.input_fifos):
            names = dict(port=port, fifo=fifo, **FIFO_PORT_SUFFIXES)
            connect(connections, names, read_from_fifo)
        for dram_ref, bank in module.dram_reads:
            names = dict(port=dram_ref.dram_fifo_name(bank),
                         fifo=util.get_port_name(dram_ref.var, bank),
                         **FIFO_PORT_SUFFIXES)
            connect(connections, names, read_from_port)
        printer.module_instance(util.get_func_name(module_trait_id),
                                module.name, connections)
        emit()
    printer.endmodule()

    # ---- FIFO module definitions, one per distinct (width, depth) ----------
    fifo_shapes = {
        (fifo.width_in_bits, fifo.depth + 2)
        for module in super_source.tpo_node_gen()
        for fifo in module.fifos
    }
    for shape in fifo_shapes:
        printer.fifo_module(*shape)