示例#1
0
    def mutate_dram_ref(obj: ir.Node, kwargs: Dict[str, int]) -> ir.Node:
        """Visitor callback that turns DRAM references into buffer accesses.

        Args:
            obj: The IR node currently being visited.
            kwargs: Must provide 'coalescing_idx' and 'unroll_factor'; the
                dict is only read, never modified.

        Returns:
            * For an `ir.DRAMRef`: an `ir.Var` named `<buffer>[<index>]`.
            * For an `ir.Let` whose name is an `ir.DRAMRef`: an `ir.Var`
              named `<rewritten name> = <obj.expr.cl_expr>;`.
            * Otherwise `obj` itself, untouched.

        Raises:
            NotImplementedError: If the throughput computed from the DRAM
                banks differs from the throughput computed from the FIFOs of
                the enclosing `node`.
        """
        if not isinstance(obj, ir.DRAMRef):
            if isinstance(obj, ir.Let) and isinstance(obj.name, ir.DRAMRef):
                # Rewrite the assignment target first, then read the value.
                lhs = obj.name.visit(mutate_dram_ref, kwargs)
                rhs = obj.expr.cl_expr
                return ir.Var(name=f'{lhs} = {rhs};', idx=())
            return obj

        bank_count = len(obj.dram)
        dram_throughput = bank_count * burst_width // obj.width_in_bits

        # `node` and `burst_width` come from the enclosing scope. A node with
        # DRAM reads is measured on its output FIFOs; otherwise the DRAM
        # writes and the input FIFOs are used.
        if node.dram_reads:
            var_count = len({ref[0].var for ref in node.dram_reads})
            fifo_count = len(node.output_fifos)
        else:
            var_count = len({ref[0].var for ref in node.dram_writes})
            fifo_count = len(node.input_fifos)
        fifo_throughput = fifo_count // var_count

        if dram_throughput != fifo_throughput:
            raise NotImplementedError(
                f'memory throughput {dram_throughput} != '
                f'processing throughput {fifo_throughput}')

        elem_idx = (kwargs['coalescing_idx'] * kwargs['unroll_factor'] +
                    obj.offset)
        # Consecutive elements go to consecutive banks: the remainder picks
        # the bank, the quotient is the index inside that bank's buffer.
        bank = obj.dram[elem_idx % bank_count]
        buf_name = obj.dram_buf_name(bank)
        return ir.Var(name=f'{buf_name}[{elem_idx // bank_count}]', idx=())
示例#2
0
 def mutate_dram_ref_for_writes(obj, kwargs):
     """Visitor callback rewriting a DRAM reference into a bit-range access.

     Non-`ir.DRAMRef` nodes are returned unchanged. For a DRAM reference the
     result is an `ir.Var` named `<buffer>(<msb>, <lsb>)`.

     Note that 'coalescing_idx' and 'unroll_factor' are *removed* from
     `kwargs` (via `pop`) when a DRAM reference is handled. `num_bank_map`
     is taken from the enclosing scope.
     """
     if not isinstance(obj, ir.DRAMRef):
         return obj

     coalescing_idx = kwargs.pop('coalescing_idx')
     unroll_factor = kwargs.pop('unroll_factor')
     bit_width = util.get_width_in_bits(obj.haoda_type)
     flat_idx = coalescing_idx * unroll_factor + obj.offset

     # The remainder selects the bank, the quotient the slot inside the
     # bank's word; each slot is `bit_width` bits wide.
     slot, bank_pos = divmod(flat_idx, num_bank_map[obj.var])
     bank = obj.dram[bank_pos]
     low_bit = slot * bit_width
     high_bit = low_bit + bit_width - 1
     return ir.Var(name=f'{obj.dram_buf_name(bank)}({high_bit}, {low_bit})',
                   idx=())
示例#3
0
def _mutate_dram_ref_for_writes(obj: ir.Node, kwargs: Dict[str, Any]) -> ir.Node:
  """Visitor callback rewriting a DRAM reference into a bit-range access.

  Args:
    obj: The IR node currently being visited.
    kwargs: Must provide 'coalescing_idx', 'unroll_factor', 'interface' and
      'num_bank_map'. All four keys are removed from the dict (via `pop`)
      whenever `obj` is an `ir.DRAMRef`.

  Returns:
    An `ir.Var` named `<buffer>(<msb>, <lsb>)` when `obj` is an `ir.DRAMRef`
    and the interface is 'm_axi' or 'axis'; otherwise `obj` itself. The
    function never returns None, so the annotation is `ir.Node`.
  """
  if not isinstance(obj, ir.DRAMRef):
    return obj

  coalescing_idx = kwargs.pop('coalescing_idx')
  unroll_factor = kwargs.pop('unroll_factor')
  interface = kwargs.pop('interface')
  num_bank_map = kwargs.pop('num_bank_map')
  type_width = obj.haoda_type.width_in_bits
  elem_idx = coalescing_idx * unroll_factor + obj.offset
  num_banks = num_bank_map[obj.var]
  # The bank lookup stays ahead of the interface check, as it always was, so
  # an out-of-range bank index still fails for every interface.
  bank = obj.dram[elem_idx % num_banks]

  if interface not in {'m_axi', 'axis'}:
    # Other interfaces keep the reference untouched.
    return obj

  # Each element occupies `type_width` bits inside its bank's word.
  lsb = (elem_idx // num_banks) * type_width
  msb = lsb + type_width - 1
  return ir.Var(name=f'{obj.dram_buf_name(bank)}({msb}, {lsb})', idx=())
示例#4
0
 def mutate_dram_ref_for_reads(obj, kwargs):
     """Visitor callback rewriting a DRAM reference into a typed bit-range read.

     Non-`ir.DRAMRef` nodes are returned unchanged. For a DRAM reference the
     result is an `ir.Var` whose name reinterprets the bit range
     `<buffer>(<msb>, <lsb>)` as `obj.c_type`.

     'coalescing_idx' and 'unroll_factor' are *removed* from `kwargs` (via
     `pop`) when a DRAM reference is handled. `num_bank_map` is taken from
     the enclosing scope.
     """
     if not isinstance(obj, ir.DRAMRef):
         return obj

     coalescing_idx = kwargs.pop('coalescing_idx')
     unroll_factor = kwargs.pop('unroll_factor')
     type_width = util.get_width_in_bits(obj.haoda_type)
     elem_idx = coalescing_idx * unroll_factor + obj.offset
     num_banks = num_bank_map[obj.var]
     # Read the bank from the node being visited, as the write mutator does,
     # instead of from a free variable `expr` of the surrounding scope.
     bank = obj.dram[elem_idx % num_banks]
     # Each element occupies `type_width` bits inside its bank's word.
     lsb = (elem_idx // num_banks) * type_width
     msb = lsb + type_width - 1
     buf_name = obj.dram_buf_name(bank)
     return ir.Var(
         name=f'Reinterpret<{obj.c_type}>(static_cast<ap_uint<{type_width} > >('
         f'{buf_name}({msb}, {lsb})))',
         idx=())
示例#5
0
def create_dataflow_graph(stencil):
    """Build the dataflow graph of `stencil` and return its super source.

    The function runs in two phases:

    1. Node construction. A `SuperSourceNode` is created together with one
       `LoadNode` / `StoreNode` per DRAM bank of every input / output
       statement. `ForwardNode`s and `ComputeNode`s are then created and
       wired via `add_child`, using one of two code paths selected by
       `stencil.replication_factor`.
    2. Edge annotation. For every edge yielded by walking
       `super_source.tpo_valid_node_gen()` an `ir.FIFO` is created and the
       expression written to it (plus any `ir.Let`s it needs) is recorded on
       the source node.

    Finally `super_source.update_module_depths({})` is called.

    Args:
        stencil: Object describing the stencil. This function reads, among
            others, `input_stmts`, `output_stmts`, `tensors`,
            `chronological_tensors`, `unroll_factor`, `replication_factor`,
            `stmt_table` and `reuse_buffer_lengths` from it.

    Returns:
        The populated `SuperSourceNode`.

    Raises:
        util.InternalError: If an edge starts at a node that is not a
            `LoadNode`, `ForwardNode` or `ComputeNode`.
    """
    chronological_tensors = stencil.chronological_tensors
    super_source = SuperSourceNode(
        fwd_nodes={},
        cpt_nodes={},
        super_sink=SuperSinkNode(),
    )

    # One load/store node per DRAM bank, grouped by statement name.
    load_nodes = {
        stmt.name:
        tuple(LoadNode(var=stmt.name, bank=bank) for bank in stmt.dram)
        for stmt in stencil.input_stmts
    }
    store_nodes = {
        stmt.name:
        tuple(StoreNode(var=stmt.name, bank=bank) for bank in stmt.dram)
        for stmt in stencil.output_stmts
    }

    # Loads hang off the super source; stores feed the super sink.
    for mem_node in itertools.chain(*load_nodes.values()):
        super_source.add_child(mem_node)
    for mem_node in itertools.chain(*store_nodes.values()):
        mem_node.add_child(super_source.super_sink)

    def color_id(node):
        """Return a short ANSI-colored label identifying `node`."""
        if isinstance(node, LoadNode):
            return f'\033[33mload {node.var}[bank{node.bank}]\033[0m'
        if isinstance(node, StoreNode):
            return f'\033[36mstore {node.var}[bank{node.bank}]\033[0m'
        if isinstance(node, ForwardNode):
            return f'\033[32mforward {node.tensor.name} @{node.offset}\033[0m'
        if isinstance(node, ComputeNode):
            return f'\033[31mcompute {node.tensor.name} #{node.pe_id}\033[0m'
        return 'unknown node'

    def color_attr(node):
        """Format the attributes in `node.__dict__` as a `{k: v, ...}` string.

        `parents` of a SuperSourceNode and `children` of a SuperSinkNode are
        skipped; other `parents`/`children` lists are shown via `color_id`.
        """
        result = []
        for k, v in node.__dict__.items():
            if (node.__class__, k) in ((SuperSourceNode, 'parents'),
                                       (SuperSinkNode, 'children')):
                continue
            if k in ('parents', 'children'):
                result.append('%s: [%s]' % (k, ', '.join(map(color_id, v))))
            else:
                result.append('%s: %s' % (k, repr(v)))
        return '{%s}' % ', '.join(result)

    def color_print(node):
        """Return `color_id(node)` followed by `color_attr(node)`.

        Not referenced anywhere else in this function; `print_node` below is
        bound to `color_id` instead.
        """
        return '%s: %s' % (color_id(node), color_attr(node))

    # Formatter used by the debug log messages below.
    print_node = color_id

    if stencil.replication_factor > 1:
        # Replicated path: one compute node (pe_id 0) per stage.
        # NOTE(review): compute nodes here are built with `stage=` whereas
        # `color_id` and the edge-annotation loop below read `.tensor`;
        # presumably this path targets an older node API -- confirm.
        replicated_next_fifo = stencil.get_replicated_next_fifo()
        replicated_all_points = stencil.get_replicated_all_points()
        replicated_reuse_buffers = stencil.get_replicated_reuse_buffers()

        def add_fwd_nodes(src_name):
            """Create and wire the forward nodes that carry `src_name`."""
            dsts = replicated_all_points[src_name]
            # The first entry of the reuse buffer list is skipped.
            reuse_buffer = replicated_reuse_buffers[src_name][1:]
            # (node, key) pairs whose fwd => fwd edge is added after all
            # forward nodes of this source exist.
            nodes_to_add = []
            for dst_point_dicts in dsts.values():
                for offset in dst_point_dicts:
                    if (src_name, offset) in super_source.fwd_nodes:
                        continue
                    fwd_node = ForwardNode(
                        tensor=stencil.tensors[src_name],
                        offset=offset,
                        depth=stencil.get_replicated_reuse_buffer_length(
                            src_name, offset))
                    _logger.debug('create %s', print_node(fwd_node))
                    # Offsets of reuse buffer entries whose start equals
                    # their end.
                    init_offsets = [
                        start for start, end in reuse_buffer if start == end
                    ]
                    if offset in init_offsets:
                        if src_name in [stencil.input.name]:
                            # Fed by a load node; the bank is chosen from the
                            # offset modulo the bank count, in reverse order.
                            load_node_count = len(load_nodes[src_name])
                            load_nodes[src_name][
                                load_node_count - 1 -
                                offset % load_node_count].add_child(fwd_node)
                        else:
                            # Fed by the compute node of the producing stage.
                            (super_source.cpt_nodes[(src_name,
                                                     0)].add_child(fwd_node))
                    super_source.fwd_nodes[(src_name, offset)] = fwd_node
                    if offset in replicated_next_fifo[src_name]:
                        nodes_to_add.append(
                            (fwd_node,
                             (src_name,
                              replicated_next_fifo[src_name][offset])))
            for src_node, key in nodes_to_add:
                src_node.add_child(super_source.fwd_nodes[key])

        add_fwd_nodes(stencil.input.name)

        for stage in stencil.get_stages_chronologically():
            cpt_node = ComputeNode(stage=stage, pe_id=0)
            _logger.debug('create %s', print_node(cpt_node))
            super_source.cpt_nodes[(stage.name, 0)] = cpt_node
            for input_name, input_window in stage.window.items():
                for i in range(len(input_window)):
                    # Find the offset whose point index is `i` for this
                    # stage; raises StopIteration if there is none.
                    offset = next(offset for offset, points in (
                        replicated_all_points[input_name][stage.name].items())
                                  if points == i)
                    fwd_node = super_source.fwd_nodes[(input_name, offset)]
                    _logger.debug('  access %s', print_node(fwd_node))
                    fwd_node.add_child(cpt_node)
            if stage.is_output():
                super_source.cpt_nodes[stage.name,
                                       0].add_child(store_nodes[stage.name][0])
            else:
                add_fwd_nodes(stage.name)

    else:
        # Non-replicated path: `unroll_factor` compute nodes per tensor.
        next_fifo = stencil.next_fifo
        all_points = stencil.all_points
        reuse_buffers = stencil.reuse_buffers

        def add_fwd_nodes(src_name):
            """Create and wire the forward nodes that carry `src_name`."""
            dsts = all_points[src_name]
            # The first entry of the reuse buffer list is skipped.
            reuse_buffer = reuse_buffers[src_name][1:]
            # (node, key) pairs whose fwd => fwd edge is added after all
            # forward nodes of this source exist.
            nodes_to_add = []
            for dst_point_dicts in dsts.values():
                for offset in dst_point_dicts:
                    if (src_name, offset) in super_source.fwd_nodes:
                        continue
                    fwd_node = ForwardNode(tensor=stencil.tensors[src_name],
                                           offset=offset)
                    #depth=stencil.get_reuse_buffer_length(src_name, offset))
                    _logger.debug('create %s', print_node(fwd_node))
                    # init_offsets is the start of each reuse chain
                    # (indexed by position; built from unroll indices in
                    # descending order)
                    init_offsets = [
                        next(end for start, end in reuse_buffer
                             if start == unroll_idx) for unroll_idx in
                        reversed(range(stencil.unroll_factor))
                    ]
                    _logger.debug('reuse buffer: %s', reuse_buffer)
                    _logger.debug('init offsets: %s', init_offsets)
                    if offset in init_offsets:
                        if src_name in stencil.input_names:
                            # fwd from external input
                            load_node_count = len(load_nodes[src_name])
                            load_nodes[src_name][
                                load_node_count - 1 -
                                offset % load_node_count].add_child(fwd_node)
                        else:
                            # fwd from output of last stage
                            # tensor name and offset are used to find the cpt node
                            cpt_offset = next(
                                unroll_idx
                                for unroll_idx in range(stencil.unroll_factor)
                                if init_offsets[unroll_idx] == offset)
                            cpt_node = super_source.cpt_nodes[(src_name,
                                                               cpt_offset)]
                            cpt_node.add_child(fwd_node)
                    super_source.fwd_nodes[(src_name, offset)] = fwd_node
                    if offset in next_fifo[src_name]:
                        nodes_to_add.append(
                            (fwd_node, (src_name,
                                        next_fifo[src_name][offset])))
            for src_node, key in nodes_to_add:
                # fwd from another fwd node
                src_node.add_child(super_source.fwd_nodes[key])

        for input_name in stencil.input_names:
            add_fwd_nodes(input_name)

        for tensor in chronological_tensors:
            if tensor.is_input():
                continue
            for unroll_index in range(stencil.unroll_factor):
                # Compute nodes are created from the highest pe_id down.
                pe_id = stencil.unroll_factor - 1 - unroll_index
                cpt_node = ComputeNode(tensor=tensor, pe_id=pe_id)
                _logger.debug('create %s', print_node(cpt_node))
                super_source.cpt_nodes[(tensor.name, pe_id)] = cpt_node
                for input_name, input_window in tensor.ld_indices.items():
                    for i in range(len(input_window)):
                        # Find the offset at which this PE reads point `i`;
                        # raises StopIteration if there is none.
                        offset = next(
                            offset
                            for offset, points in all_points[input_name][
                                tensor.name].items()
                            if pe_id in points and points[pe_id] == i)
                        fwd_node = super_source.fwd_nodes[(input_name, offset)]
                        _logger.debug('  access %s', print_node(fwd_node))
                        fwd_node.add_child(cpt_node)
            if tensor.is_output():
                # Spread the PEs over the available store nodes round-robin.
                for pe_id in range(stencil.unroll_factor):
                    super_source.cpt_nodes[tensor.name, pe_id].add_child(
                        store_nodes[tensor.name][pe_id % len(
                            store_nodes[tensor.name])])
            else:
                add_fwd_nodes(tensor.name)

    # Second phase: create a FIFO for every edge and record, on the source
    # node, the expression that is written into it.
    # NOTE(review): `all_points` used below is only bound in the `else`
    # branch above, so the fwd => cpt case would raise NameError after the
    # replicated path -- confirm whether that path is still supported.
    # pylint: disable=too-many-nested-blocks
    for src_node in super_source.tpo_valid_node_gen():
        for dst_node in filter(is_valid_node, src_node.children):
            # 5 possible edge types:
            # 1. load => fwd
            # 2. fwd => fwd
            # 3. fwd => cpt
            # 4. cpt => fwd
            # 5. cpt => store
            # The write latency depends only on the kind of source node.
            if isinstance(src_node, LoadNode):
                write_lat = 0
            elif isinstance(src_node, ForwardNode):
                write_lat = 2
            elif isinstance(src_node, ComputeNode):
                write_lat = src_node.tensor.st_ref.lat
            else:
                raise util.InternalError('unexpected source node: %s' %
                                         repr(src_node))

            fifo = ir.FIFO(src_node, dst_node, depth=0, write_lat=write_lat)
            # Let statements this edge contributes to the source node.
            lets: List[ir.Let] = []
            if isinstance(src_node, LoadNode):
                # load => fwd: the expression is a reference into the bank
                # served by this load node.
                expr = ir.DRAMRef(
                    haoda_type=dst_node.tensor.haoda_type,
                    dram=(src_node.bank, ),
                    var=dst_node.tensor.name,
                    offset=(stencil.unroll_factor - 1 - dst_node.offset) //
                    len(stencil.stmt_table[dst_node.tensor.name].dram),
                )
            elif isinstance(src_node, ForwardNode):
                if isinstance(dst_node, ComputeNode):
                    # fwd => cpt: register the FIFO in the consumer's
                    # fifo_map under the index it reads from this tensor.
                    dst = src_node.tensor.children[dst_node.tensor.name]
                    src_name = src_node.tensor.name
                    unroll_idx = dst_node.pe_id
                    point = all_points[src_name][dst.name][
                        src_node.offset][unroll_idx]
                    idx = list(dst.ld_indices[src_name].values())[point].idx
                    _logger.debug(
                        '%s%s referenced by <%s> @ unroll_idx=%d is %s',
                        src_name, util.idx2str(idx), dst.name, unroll_idx,
                        print_node(src_node))
                    dst_node.fifo_map[src_name][idx] = fifo
                delay = stencil.reuse_buffer_lengths[src_node.tensor.name]\
                                                    [src_node.offset]
                # NOTE(review): `offset` is assigned here but not read again
                # within this loop body.
                offset = src_node.offset - delay
                # Locate the FIFO feeding this forward node; `fifo_r` is
                # deliberately used after the loops (see pylint disables).
                for parent in src_node.parents:  # fwd node has only 1 parent
                    for fifo_r in parent.fifos:
                        if fifo_r.edge == (parent, src_node):
                            break
                if delay > 0:
                    # Delayed forward: write a let-bound variable holding
                    # `fifo_r` delayed by `delay`, reusing an existing let
                    # for the same FIFO when there is one.
                    # TODO: build an index somewhere
                    for let in src_node.lets:
                        # pylint: disable=undefined-loop-variable
                        if isinstance(
                                let.expr,
                                ir.DelayedRef) and let.expr.ref == fifo_r:
                            var_name = let.name
                            var_type = let.haoda_type
                            break
                    else:
                        # No matching let was found: create a new one.
                        var_name = 'let_%d' % len(src_node.lets)
                        # pylint: disable=undefined-loop-variable
                        var_type = fifo_r.haoda_type
                        lets.append(
                            ir.Let(haoda_type=var_type,
                                   name=var_name,
                                   expr=ir.DelayedRef(delay=delay,
                                                      ref=fifo_r)))
                    expr = ir.Var(name=var_name, idx=[])
                    expr.haoda_type = var_type
                else:
                    # No delay: forward the incoming FIFO directly.
                    expr = fifo_r  # pylint: disable=undefined-loop-variable
            elif isinstance(src_node, ComputeNode):

                def replace_refs_callback(obj, args):
                    """Replace each `ir.Ref` by the FIFO mapped to it.

                    `args` is unused.
                    """
                    if isinstance(obj, ir.Ref):
                        _logger.debug(
                            'replace %s with %s',
                            obj,
                            # pylint: disable=cell-var-from-loop,undefined-loop-variable
                            src_node.fifo_map[obj.name][obj.idx])
                        # pylint: disable=cell-var-from-loop,undefined-loop-variable
                        return src_node.fifo_map[obj.name][obj.idx]
                    return obj

                # Rewrite the tensor's lets and expression so that every
                # reference reads from the corresponding input FIFO.
                _logger.debug('lets: %s', src_node.tensor.lets)
                lets = [
                    _.visit(replace_refs_callback)
                    for _ in src_node.tensor.lets
                ]
                _logger.debug('replaced lets: %s', lets)
                _logger.debug('expr: %s', src_node.tensor.expr)
                expr = src_node.tensor.expr.visit(replace_refs_callback)
                _logger.debug('replaced expr: %s', expr)
                if isinstance(dst_node, StoreNode):
                    # cpt => store: the store node gets a let whose name is
                    # the DRAM reference and whose expression is the FIFO.
                    dram_ref = ir.DRAMRef(
                        haoda_type=src_node.tensor.haoda_type,
                        dram=(dst_node.bank, ),
                        var=src_node.tensor.name,
                        offset=(src_node.pe_id) //
                        len(stencil.stmt_table[src_node.tensor.name].dram),
                    )
                    dst_node.lets.append(
                        ir.Let(haoda_type=None, name=dram_ref, expr=fifo))
            else:
                raise util.InternalError('unexpected node of type %s' %
                                         type(src_node))

            # Record what is written to this FIFO and add only the lets the
            # source node does not already have.
            src_node.exprs[fifo] = expr
            src_node.lets.extend(_ for _ in lets if _ not in src_node.lets)
            _logger.debug(
                'fifo [%d]: %s%s => %s', fifo.depth, color_id(src_node),
                '' if fifo.write_lat is None else ' ~%d' % fifo.write_lat,
                color_id(dst_node))

    super_source.update_module_depths({})

    return super_source