def mutate_dram_ref(obj: ir.Node, kwargs: Dict[str, int]) -> ir.Node:
  """Mutate a DRAM reference into a coalesced buffer-element access.

  Intended as a callback for ``ir.Node.visit``: DRAMRef nodes are replaced
  by an ``ir.Var`` naming the corresponding element of the DRAM buffer;
  Let nodes binding a DRAMRef are rewritten to an assignment statement;
  all other nodes pass through unchanged.

  Args:
    obj: The IR node currently being visited.
    kwargs: Visitor state; must contain 'coalescing_idx' and 'unroll_factor'.

  Returns:
    The (possibly replaced) IR node.

  Raises:
    NotImplementedError: If DRAM throughput does not match FIFO throughput.
  """
  if isinstance(obj, ir.DRAMRef):
    # NOTE(review): `burst_width` and `node` are free names not defined in
    # this snippet — presumably module-level / enclosing-scope bindings in
    # the full file; confirm they are in scope where this is defined.
    dram_throughput = len(obj.dram) * burst_width // obj.width_in_bits
    if node.dram_reads:
      # Reading from DRAM: throughput is output FIFOs per distinct read var.
      output_count = len({x[0].var for x in node.dram_reads})
      fifo_throughput = len(node.output_fifos) // output_count
    else:
      # Writing to DRAM: throughput is input FIFOs per distinct written var.
      input_count = len({x[0].var for x in node.dram_writes})
      fifo_throughput = len(node.input_fifos) // input_count
    # Mismatched throughput would require rate conversion, which is not
    # implemented; fail loudly instead of generating incorrect code.
    if dram_throughput != fifo_throughput:
      raise NotImplementedError(
          f'memory throughput {dram_throughput} != '
          f'processing throughput {fifo_throughput}')
    coalescing_idx = kwargs['coalescing_idx']
    unroll_factor = kwargs['unroll_factor']
    # Linear element index within the coalesced burst for this reference.
    elem_idx = coalescing_idx * unroll_factor + obj.offset
    # Elements are interleaved across banks: bank = idx mod #banks,
    # position within that bank's buffer = idx div #banks.
    return ir.Var(
        name='{buf}[{idx}]'.format(
            buf=obj.dram_buf_name(obj.dram[elem_idx % len(obj.dram)]),
            idx=elem_idx // len(obj.dram),
        ),
        idx=(),
    )
  if isinstance(obj, ir.Let) and isinstance(obj.name, ir.DRAMRef):
    # A Let whose target is a DRAMRef becomes a C-style assignment statement;
    # the target itself is mutated recursively via the same callback.
    return ir.Var(name='{} = {};'.format(
        obj.name.visit(mutate_dram_ref, kwargs), obj.expr.cl_expr),
                  idx=())
  return obj
def mutate_dram_ref_for_writes(obj, kwargs):
  """Replace a DRAM reference with a bit-sliced write into its DRAM buffer.

  Visitor callback: DRAMRef nodes become an ``ir.Var`` naming the bit range
  ``buf(msb, lsb)`` of the coalesced buffer for the selected bank; every
  other node is returned untouched.

  Args:
    obj: The IR node currently being visited.
    kwargs: Visitor state; 'coalescing_idx' and 'unroll_factor' are consumed
      (popped) on the first DRAMRef encountered.

  Returns:
    The (possibly replaced) IR node.
  """
  if not isinstance(obj, ir.DRAMRef):
    return obj
  # Consume the coalescing parameters from the shared kwargs dict.
  burst_idx = kwargs.pop('coalescing_idx')
  factor = kwargs.pop('unroll_factor')
  width = util.get_width_in_bits(obj.haoda_type)
  # Linear element index within the coalesced burst for this reference.
  element = burst_idx * factor + obj.offset
  # NOTE(review): `num_bank_map` is a free name — presumably bound in an
  # enclosing scope in the full file; confirm.
  banks = num_bank_map[obj.var]
  # Elements interleave across banks; the quotient selects the element slot
  # within the chosen bank, scaled to a bit range of the packed buffer.
  target_bank = obj.dram[element % banks]
  low_bit = (element // banks) * width
  high_bit = low_bit + width - 1
  slice_name = '{}({msb}, {lsb})'.format(obj.dram_buf_name(target_bank),
                                         msb=high_bit,
                                         lsb=low_bit)
  return ir.Var(name=slice_name, idx=())
def _mutate_dram_ref_for_writes(obj: ir.Node, kwargs: Dict[str, Any]) -> ir.Node:
  """Replace a DRAM reference with a bit-sliced write into its DRAM buffer.

  Visitor callback: DRAMRef nodes become an ``ir.Var`` naming the bit range
  ``buf(msb, lsb)`` of the coalesced buffer for the selected bank; every
  other node is returned untouched.

  Args:
    obj: The IR node currently being visited.
    kwargs: Visitor state; 'coalescing_idx', 'unroll_factor', 'interface',
      and 'num_bank_map' are consumed (popped) on the first DRAMRef.

  Returns:
    The (possibly replaced) IR node.
    (Fixed: annotation previously said ``-> None`` although the function
    always returns a node.)

  Raises:
    NotImplementedError: If the interface is not 'm_axi' or 'axis'.
  """
  if isinstance(obj, ir.DRAMRef):
    coalescing_idx = kwargs.pop('coalescing_idx')
    unroll_factor = kwargs.pop('unroll_factor')
    interface = kwargs.pop('interface')
    num_bank_map = kwargs.pop('num_bank_map')
    type_width = obj.haoda_type.width_in_bits
    # Linear element index within the coalesced burst for this reference.
    elem_idx = coalescing_idx * unroll_factor + obj.offset
    num_banks = num_bank_map[obj.var]
    # Elements interleave across banks; the quotient selects the element
    # slot within the chosen bank, scaled to a bit range.
    bank = obj.dram[elem_idx % num_banks]
    if interface in {'m_axi', 'axis'}:
      lsb = (elem_idx // num_banks) * type_width
      msb = lsb + type_width - 1
      return ir.Var(name=f'{obj.dram_buf_name(bank)}({msb}, {lsb})', idx=())
    # Fixed: previously an unsupported interface fell through and returned
    # None implicitly, silently corrupting the visited tree.
    raise NotImplementedError(f'unknown interface: {interface}')
  return obj
def mutate_dram_ref_for_reads(obj, kwargs):
  """Replace a DRAM reference with a reinterpreted bit-sliced buffer read.

  Visitor callback: DRAMRef nodes become an ``ir.Var`` wrapping the bit
  range ``buf(msb, lsb)`` in ``static_cast<ap_uint<...>>`` and
  ``Reinterpret<...>`` so the raw bits are read back as the element's C
  type; every other node is returned untouched.

  Args:
    obj: The IR node currently being visited.
    kwargs: Visitor state; 'coalescing_idx' and 'unroll_factor' are consumed
      (popped) on the first DRAMRef encountered.

  Returns:
    The (possibly replaced) IR node.
  """
  if isinstance(obj, ir.DRAMRef):
    coalescing_idx = kwargs.pop('coalescing_idx')
    unroll_factor = kwargs.pop('unroll_factor')
    type_width = util.get_width_in_bits(obj.haoda_type)
    # Linear element index within the coalesced burst for this reference.
    elem_idx = coalescing_idx * unroll_factor + obj.offset
    # NOTE(review): `num_bank_map` is a free name — presumably bound in an
    # enclosing scope in the full file; confirm.
    num_banks = num_bank_map[obj.var]
    # Fixed: was `expr.dram[...]` — `expr` is undefined here (NameError);
    # the banks belong to the reference being mutated, as in the writes
    # variant (`mutate_dram_ref_for_writes`).
    bank = obj.dram[elem_idx % num_banks]
    lsb = (elem_idx // num_banks) * type_width
    msb = lsb + type_width - 1
    return ir.Var(
        name='Reinterpret<{c_type}>(static_cast<ap_uint<{width} > >('
        '{dram_buf_name}({msb}, {lsb})))'.format(
            c_type=obj.c_type,
            dram_buf_name=obj.dram_buf_name(bank),
            msb=msb,
            lsb=lsb,
            width=msb - lsb + 1),
        idx=())
  return obj
def create_dataflow_graph(stencil):
  """Build the module-level dataflow graph for a stencil kernel.

  Constructs a SuperSourceNode graph with one LoadNode per DRAM bank of
  each input, one StoreNode per bank of each output, ForwardNode instances
  for every reuse-buffer offset, and ComputeNode instances per processing
  element, then wires the edges with ir.FIFO objects and attaches the
  expressions/lets each source node writes into its FIFOs.

  Args:
    stencil: The stencil description object (provides tensors, reuse
      buffers, unroll factor, statement tables, etc.).

  Returns:
    The fully wired SuperSourceNode.
  """
  chronological_tensors = stencil.chronological_tensors
  super_source = SuperSourceNode(
      fwd_nodes={},
      cpt_nodes={},
      super_sink=SuperSinkNode(),
  )

  # One memory node per (statement, DRAM bank).
  load_nodes = {
      stmt.name:
      tuple(LoadNode(var=stmt.name, bank=bank) for bank in stmt.dram)
      for stmt in stencil.input_stmts
  }
  store_nodes = {
      stmt.name:
      tuple(StoreNode(var=stmt.name, bank=bank) for bank in stmt.dram)
      for stmt in stencil.output_stmts
  }

  # Loads hang off the super source; stores feed the super sink.
  for mem_node in itertools.chain(*load_nodes.values()):
    super_source.add_child(mem_node)
  for mem_node in itertools.chain(*store_nodes.values()):
    mem_node.add_child(super_source.super_sink)

  # Debug helpers: ANSI-colored one-line descriptions of graph nodes.
  def color_id(node):
    if isinstance(node, LoadNode):
      return f'\033[33mload {node.var}[bank{node.bank}]\033[0m'
    if isinstance(node, StoreNode):
      return f'\033[36mstore {node.var}[bank{node.bank}]\033[0m'
    if isinstance(node, ForwardNode):
      return f'\033[32mforward {node.tensor.name} @{node.offset}\033[0m'
    if isinstance(node, ComputeNode):
      return f'\033[31mcompute {node.tensor.name} #{node.pe_id}\033[0m'
    return 'unknown node'

  def color_attr(node):
    result = []
    for k, v in node.__dict__.items():
      # Skip the attributes that would recurse back through the graph.
      if (node.__class__, k) in ((SuperSourceNode, 'parents'),
                                 (SuperSinkNode, 'children')):
        continue
      if k in ('parents', 'children'):
        result.append('%s: [%s]' % (k, ', '.join(map(color_id, v))))
      else:
        result.append('%s: %s' % (k, repr(v)))
    return '{%s}' % ', '.join(result)

  def color_print(node):
    return '%s: %s' % (color_id(node), color_attr(node))

  print_node = color_id

  if stencil.replication_factor > 1:
    # Replicated layout: a single PE column replicated across banks.
    replicated_next_fifo = stencil.get_replicated_next_fifo()
    replicated_all_points = stencil.get_replicated_all_points()
    replicated_reuse_buffers = stencil.get_replicated_reuse_buffers()

    def add_fwd_nodes(src_name):
      # Create a ForwardNode per reuse-buffer offset of src_name and link
      # each chain head to its producer (a load node or a compute node).
      dsts = replicated_all_points[src_name]
      reuse_buffer = replicated_reuse_buffers[src_name][1:]
      nodes_to_add = []
      for dst_point_dicts in dsts.values():
        for offset in dst_point_dicts:
          if (src_name, offset) in super_source.fwd_nodes:
            continue
          fwd_node = ForwardNode(
              tensor=stencil.tensors[src_name],
              offset=offset,
              depth=stencil.get_replicated_reuse_buffer_length(
                  src_name, offset))
          _logger.debug('create %s', print_node(fwd_node))
          # Chain heads are the intervals of zero length in the buffer.
          init_offsets = [
              start for start, end in reuse_buffer if start == end
          ]
          if offset in init_offsets:
            if src_name in [stencil.input.name]:
              # Forward from external input: pick the bank's load node.
              load_node_count = len(load_nodes[src_name])
              load_nodes[src_name][
                  load_node_count - 1 -
                  offset % load_node_count].add_child(fwd_node)
            else:
              # Forward from the output of the producing stage.
              (super_source.cpt_nodes[(src_name, 0)].add_child(fwd_node))
          super_source.fwd_nodes[(src_name, offset)] = fwd_node
          if offset in replicated_next_fifo[src_name]:
            # Defer fwd->fwd links until every ForwardNode exists.
            nodes_to_add.append(
                (fwd_node, (src_name,
                            replicated_next_fifo[src_name][offset])))
      for src_node, key in nodes_to_add:
        src_node.add_child(super_source.fwd_nodes[key])

    add_fwd_nodes(stencil.input.name)

    for stage in stencil.get_stages_chronologically():
      cpt_node = ComputeNode(stage=stage, pe_id=0)
      _logger.debug('create %s', print_node(cpt_node))
      super_source.cpt_nodes[(stage.name, 0)] = cpt_node
      for input_name, input_window in stage.window.items():
        for i in range(len(input_window)):
          # Locate the forward node whose access point feeds window slot i.
          offset = next(offset for offset, points in (
              replicated_all_points[input_name][stage.name].items())
                        if points == i)
          fwd_node = super_source.fwd_nodes[(input_name, offset)]
          _logger.debug(' access %s', print_node(fwd_node))
          fwd_node.add_child(cpt_node)
      if stage.is_output():
        super_source.cpt_nodes[stage.name,
                               0].add_child(store_nodes[stage.name][0])
      else:
        add_fwd_nodes(stage.name)

  else:
    # Non-replicated layout: unroll_factor PEs per stage.
    next_fifo = stencil.next_fifo
    all_points = stencil.all_points
    reuse_buffers = stencil.reuse_buffers

    def add_fwd_nodes(src_name):
      # Create a ForwardNode per reuse-buffer offset of src_name and link
      # each chain head to its producer (a load node or a compute node).
      dsts = all_points[src_name]
      reuse_buffer = reuse_buffers[src_name][1:]
      nodes_to_add = []
      for dst_point_dicts in dsts.values():
        for offset in dst_point_dicts:
          if (src_name, offset) in super_source.fwd_nodes:
            continue
          fwd_node = ForwardNode(tensor=stencil.tensors[src_name],
                                 offset=offset)
          #depth=stencil.get_reuse_buffer_length(src_name, offset))
          _logger.debug('create %s', print_node(fwd_node))
          # init_offsets is the start of each reuse chain
          init_offsets = [
              next(end for start, end in reuse_buffer if start == unroll_idx)
              for unroll_idx in reversed(range(stencil.unroll_factor))
          ]
          _logger.debug('reuse buffer: %s', reuse_buffer)
          _logger.debug('init offsets: %s', init_offsets)
          if offset in init_offsets:
            if src_name in stencil.input_names:
              # fwd from external input
              load_node_count = len(load_nodes[src_name])
              load_nodes[src_name][
                  load_node_count - 1 -
                  offset % load_node_count].add_child(fwd_node)
            else:
              # fwd from output of last stage
              # tensor name and offset are used to find the cpt node
              cpt_offset = next(
                  unroll_idx for unroll_idx in range(stencil.unroll_factor)
                  if init_offsets[unroll_idx] == offset)
              cpt_node = super_source.cpt_nodes[(src_name, cpt_offset)]
              cpt_node.add_child(fwd_node)
          super_source.fwd_nodes[(src_name, offset)] = fwd_node
          if offset in next_fifo[src_name]:
            # Defer fwd->fwd links until every ForwardNode exists.
            nodes_to_add.append(
                (fwd_node, (src_name, next_fifo[src_name][offset])))
      for src_node, key in nodes_to_add:
        # fwd from another fwd node
        src_node.add_child(super_source.fwd_nodes[key])

    for input_name in stencil.input_names:
      add_fwd_nodes(input_name)

    for tensor in chronological_tensors:
      if tensor.is_input():
        continue
      for unroll_index in range(stencil.unroll_factor):
        # PE ids count down while unroll indices count up.
        pe_id = stencil.unroll_factor - 1 - unroll_index
        cpt_node = ComputeNode(tensor=tensor, pe_id=pe_id)
        _logger.debug('create %s', print_node(cpt_node))
        super_source.cpt_nodes[(tensor.name, pe_id)] = cpt_node
        for input_name, input_window in tensor.ld_indices.items():
          for i in range(len(input_window)):
            # Locate the forward node whose access point feeds this PE's
            # window slot i.
            offset = next(
                offset for offset, points in all_points[input_name][
                    tensor.name].items()
                if pe_id in points and points[pe_id] == i)
            fwd_node = super_source.fwd_nodes[(input_name, offset)]
            _logger.debug('  access %s', print_node(fwd_node))
            fwd_node.add_child(cpt_node)
      if tensor.is_output():
        # Each PE stores to a bank, round-robin over the available banks.
        for pe_id in range(stencil.unroll_factor):
          super_source.cpt_nodes[tensor.name, pe_id].add_child(
              store_nodes[tensor.name][pe_id % len(
                  store_nodes[tensor.name])])
      else:
        add_fwd_nodes(tensor.name)

  # Second pass: create a FIFO per edge and attach the expression each
  # source node writes into it.
  # pylint: disable=too-many-nested-blocks
  for src_node in super_source.tpo_valid_node_gen():
    for dst_node in filter(is_valid_node, src_node.children):
      # 5 possible edge types:
      # 1. load => fwd
      # 2. fwd => fwd
      # 3. fwd => cpt
      # 4. cpt => fwd
      # 5. cpt => store
      if isinstance(src_node, LoadNode):
        write_lat = 0
      elif isinstance(src_node, ForwardNode):
        write_lat = 2
      elif isinstance(src_node, ComputeNode):
        write_lat = src_node.tensor.st_ref.lat
      else:
        raise util.InternalError('unexpected source node: %s' %
                                 repr(src_node))

      fifo = ir.FIFO(src_node, dst_node, depth=0, write_lat=write_lat)
      lets: List[ir.Let] = []
      if isinstance(src_node, LoadNode):
        # load => fwd: the FIFO carries a DRAM read of the input element.
        expr = ir.DRAMRef(
            haoda_type=dst_node.tensor.haoda_type,
            dram=(src_node.bank,),
            var=dst_node.tensor.name,
            offset=(stencil.unroll_factor - 1 - dst_node.offset) //
            len(stencil.stmt_table[dst_node.tensor.name].dram),
        )
      elif isinstance(src_node, ForwardNode):
        if isinstance(dst_node, ComputeNode):
          # fwd => cpt: record which window index this FIFO feeds.
          dst = src_node.tensor.children[dst_node.tensor.name]
          src_name = src_node.tensor.name
          unroll_idx = dst_node.pe_id
          point = all_points[src_name][dst.name][
              src_node.offset][unroll_idx]
          idx = list(dst.ld_indices[src_name].values())[point].idx
          _logger.debug(
              '%s%s referenced by <%s> @ unroll_idx=%d is %s',
              src_name, util.idx2str(idx), dst.name, unroll_idx,
              print_node(src_node))
          dst_node.fifo_map[src_name][idx] = fifo
        delay = stencil.reuse_buffer_lengths[src_node.tensor.name]\
                                            [src_node.offset]
        offset = src_node.offset - delay
        # Find the FIFO this forward node reads from.
        for parent in src_node.parents:  # fwd node has only 1 parent
          for fifo_r in parent.fifos:
            if fifo_r.edge == (parent, src_node):
              break
        if delay > 0:
          # A delayed read is bound to a let variable, reused if one
          # already exists for this upstream FIFO.
          # TODO: build an index somewhere
          for let in src_node.lets:
            # pylint: disable=undefined-loop-variable
            if isinstance(
                let.expr,
                ir.DelayedRef) and let.expr.ref == fifo_r:
              var_name = let.name
              var_type = let.haoda_type
              break
          else:
            var_name = 'let_%d' % len(src_node.lets)
            # pylint: disable=undefined-loop-variable
            var_type = fifo_r.haoda_type
            lets.append(
                ir.Let(haoda_type=var_type,
                       name=var_name,
                       expr=ir.DelayedRef(delay=delay, ref=fifo_r)))
          expr = ir.Var(name=var_name, idx=[])
          expr.haoda_type = var_type
        else:
          # No delay: forward the upstream FIFO straight through.
          expr = fifo_r  # pylint: disable=undefined-loop-variable
      elif isinstance(src_node, ComputeNode):

        def replace_refs_callback(obj, args):
          # Rewrite tensor references into reads of the mapped input FIFOs.
          if isinstance(obj, ir.Ref):
            _logger.debug(
                'replace %s with %s',
                obj,
                # pylint: disable=cell-var-from-loop,undefined-loop-variable
                src_node.fifo_map[obj.name][obj.idx])
            # pylint: disable=cell-var-from-loop,undefined-loop-variable
            return src_node.fifo_map[obj.name][obj.idx]
          return obj

        _logger.debug('lets: %s', src_node.tensor.lets)
        lets = [
            _.visit(replace_refs_callback) for _ in src_node.tensor.lets
        ]
        _logger.debug('replaced lets: %s', lets)
        _logger.debug('expr: %s', src_node.tensor.expr)
        expr = src_node.tensor.expr.visit(replace_refs_callback)
        _logger.debug('replaced expr: %s', expr)
        if isinstance(dst_node, StoreNode):
          # cpt => store: the store node writes this FIFO to DRAM.
          dram_ref = ir.DRAMRef(
              haoda_type=src_node.tensor.haoda_type,
              dram=(dst_node.bank,),
              var=src_node.tensor.name,
              offset=(src_node.pe_id) //
              len(stencil.stmt_table[src_node.tensor.name].dram),
          )
          dst_node.lets.append(
              ir.Let(haoda_type=None, name=dram_ref, expr=fifo))
      else:
        raise util.InternalError('unexpected node of type %s' %
                                 type(src_node))

      src_node.exprs[fifo] = expr
      src_node.lets.extend(_ for _ in lets if _ not in src_node.lets)
      _logger.debug(
          'fifo [%d]: %s%s => %s', fifo.depth, color_id(src_node),
          '' if fifo.write_lat is None else ' ~%d' % fifo.write_lat,
          color_id(dst_node))
  super_source.update_module_depths({})
  return super_source