def propagate_memlet(dfg_state,
                     memlet: Memlet,
                     scope_node: nodes.EntryNode,
                     union_inner_edges: bool,
                     arr=None):
    """ Propagates a memlet through a scope, i.e., computes the image of the
        memlet's subset function over the scope's integer set (for example,
        a map range), and returns a new memlet object.

        :param dfg_state: The SDFGState containing the scope.
        :param memlet: The memlet adjacent to the scope node on the inside.
        :param scope_node: A scope entry or exit node. An exit node is
                           resolved to its entry node through
                           ``dfg_state.scope_dict()``.
        :param union_inner_edges: If True, the other memlets on the inner
                                  side of the scope node that refer to the
                                  same data container are propagated
                                  together with ``memlet``.
        :param arr: Optional data descriptor for ``memlet.data``. When None,
                    it is looked up in ``dfg_state.parent.arrays``.
        :return: An empty ``Memlet()`` if ``memlet`` is empty. For map scopes,
                 the result of ``propagate_subset``. For consume scopes, a
                 copy of ``memlet`` that spans the whole array, has volume 0
                 and is marked dynamic.
        :raises TypeError: If ``scope_node`` is neither an entry nor an exit
                           node.
        :raises KeyError: If ``arr`` is None and ``memlet.data`` is not among
                          the SDFG's arrays.
        :raises NotImplementedError: If the scope is neither a map nor a
                                     consume scope.
    """
    # Normalize to the scope's entry node and collect the edges that touch
    # the given scope node from the inside.
    if isinstance(scope_node, nodes.EntryNode):
        entry_node = scope_node
        inner_edges = dfg_state.out_edges(scope_node)
    elif isinstance(scope_node, nodes.ExitNode):
        entry_node = dfg_state.scope_dict()[scope_node]
        inner_edges = dfg_state.in_edges(scope_node)
    else:
        raise TypeError('Trying to propagate through a non-scope node')

    if memlet.is_empty():
        return Memlet()

    sdfg = dfg_state.parent

    # Entry-node input connectors without the "IN_" prefix are treated as
    # scope symbols and are excluded from the externally defined variables.
    local_symbols = {
        conn
        for conn in entry_node.in_connectors if not conn.startswith('IN_')
    }
    defined_vars = [
        symbolic.pystr_to_symbolic(name)
        for name in dfg_state.symbols_defined_at(entry_node).keys()
        if name not in local_symbols
    ]

    # Memlets to propagate together. Optionally include the other inner
    # memlets on the same data container; the given memlet always goes last.
    aggdata = []
    if union_inner_edges:
        aggdata.extend(edge.data for edge in inner_edges
                       if edge.data.data == memlet.data
                       and edge.data != memlet)
    aggdata.append(memlet)

    if arr is None:
        if memlet.data not in sdfg.arrays:
            raise KeyError('Data descriptor (Array, Stream) "%s" not defined '
                           'in SDFG.' % memlet.data)
        arr = sdfg.arrays[memlet.data]

    if isinstance(entry_node, nodes.MapEntry):
        the_map = entry_node.map
        return propagate_subset(aggdata, arr, the_map.params, the_map.range,
                                defined_vars)

    if isinstance(entry_node, nodes.ConsumeEntry):
        # Nothing to analyze in a consume scope: conservatively cover the
        # whole array and mark the access count as dynamic.
        propagated = copy.copy(memlet)
        propagated.subset = subsets.Range.from_array(arr)
        propagated.other_subset = None
        propagated.volume = 0
        propagated.dynamic = True
        return propagated

    raise NotImplementedError('Unimplemented primitive: %s' % type(entry_node))
def propagate_memlet(dfg_state,
                     memlet: Memlet,
                     scope_node: nodes.EntryNode,
                     union_inner_edges: bool,
                     arr=None):
    """ Propagates a memlet through a scope, i.e., computes the image of the
        memlet's subset function over the scope's integer set (for example,
        a map range), and returns a new memlet object.

        For map scopes, each registered ``MemletPattern`` is tried in turn on
        every memlet being propagated. The resulting subsets are unioned. If
        no pattern matches, or if a union fails, the whole array is used
        instead and a warning is emitted. For consume scopes, the whole array
        is used.

        The propagated ``num_accesses`` is the sum of the inner memlets'
        ``num_accesses`` times the product of the scope map's range sizes.

        :param dfg_state: The SDFGState containing the scope.
        :param memlet: The memlet adjacent to the scope node on the inside.
        :param scope_node: A scope entry or exit node. An exit node is
                           resolved to its entry node through
                           ``dfg_state.scope_dict()``.
        :param union_inner_edges: If True, the other memlets on the inner
                                  side of the scope node that refer to the
                                  same data container are propagated
                                  together with ``memlet``.
        :param arr: Optional data descriptor for ``memlet.data``. When None,
                    it is looked up in ``dfg_state.parent.arrays``.
        :return: An empty ``Memlet()`` if ``memlet`` is empty. Otherwise, a
                 shallow copy of ``memlet`` with the propagated subset,
                 ``other_subset`` cleared, and an updated ``num_accesses`` /
                 ``dynamic`` flag.
        :raises TypeError: If ``scope_node`` is neither an entry nor an exit
                           node.
        :raises KeyError: If ``arr`` is None and ``memlet.data`` is not among
                          the SDFG's arrays.
        :raises NotImplementedError: If the scope is neither a map nor a
                                     consume scope.
    """
    # Normalize to the scope's entry node and collect the edges that touch
    # the given scope node from the inside.
    if isinstance(scope_node, nodes.EntryNode):
        entry_node = scope_node
        inner_edges = dfg_state.out_edges(scope_node)
    elif isinstance(scope_node, nodes.ExitNode):
        entry_node = dfg_state.scope_dict()[scope_node]
        inner_edges = dfg_state.in_edges(scope_node)
    else:
        raise TypeError('Trying to propagate through a non-scope node')

    if memlet.is_empty():
        return Memlet()

    sdfg = dfg_state.parent

    # Entry-node input connectors without the "IN_" prefix are treated as
    # scope symbols and are excluded from the externally defined variables.
    local_symbols = {
        conn
        for conn in entry_node.in_connectors if not conn.startswith('IN_')
    }
    defined_vars = [
        symbolic.pystr_to_symbolic(name)
        for name in dfg_state.symbols_defined_at(entry_node).keys()
        if name not in local_symbols
    ]

    # Memlets to propagate together. Optionally include the other inner
    # memlets on the same data container; the given memlet always goes last.
    aggdata = []
    if union_inner_edges:
        aggdata.extend(edge.data for edge in inner_edges
                       if edge.data.data == memlet.data
                       and edge.data != memlet)
    aggdata.append(memlet)

    if arr is None:
        if memlet.data not in sdfg.arrays:
            raise KeyError('Data descriptor (Array, Stream) "%s" not defined '
                           'in SDFG.' % memlet.data)
        arr = sdfg.arrays[memlet.data]

    # --- Subset propagation ---
    if isinstance(entry_node, nodes.MapEntry):
        the_map = entry_node.map
        variable_context = [
            defined_vars,
            [symbolic.pystr_to_symbolic(p) for p in the_map.params],
        ]

        def _image_of(inner):
            # The first registered pattern that matches determines the image.
            for pattern_cls in MemletPattern.extensions():
                pattern = pattern_cls()
                if pattern.match([inner.subset], variable_context,
                                 the_map.range, [inner]):
                    return pattern.propagate(arr, [inner.subset],
                                             the_map.range)
            # No pattern matched: warn and fall back to the entire array.
            warnings.warn('Cannot find appropriate memlet pattern to '
                          'propagate %s through %s' %
                          (str(inner.subset), str(the_map.range)))
            return subsets.Range.from_array(arr)

        new_subset = None
        for inner in aggdata:
            image = _image_of(inner)
            if new_subset is None:
                new_subset = image
                continue
            merged = subsets.union(new_subset, image)
            if merged is None:
                warnings.warn('Subset union failed between %s and %s ' %
                              (new_subset, image))
                new_subset = None
                break
            new_subset = merged

        # A failed union leaves no subset; fall back to the entire array.
        if new_subset is None:
            new_subset = subsets.Range.from_array(arr)
        assert new_subset is not None
    elif isinstance(entry_node, nodes.ConsumeEntry):
        # Nothing to analyze in a consume scope: cover the entire array.
        new_subset = subsets.Range.from_array(arr)
    else:
        raise NotImplementedError('Unimplemented primitive: %s' %
                                  type(entry_node))

    propagated = copy.copy(memlet)
    propagated.subset = new_subset
    propagated.other_subset = None

    # Accesses = (sum of inner accesses) * (number of points in the scope's
    # map range).
    inner_accesses = sum(m.num_accesses for m in aggdata)
    range_points = functools.reduce(lambda a, b: a * b,
                                    scope_node.map.range.size(), 1)
    propagated.num_accesses = inner_accesses * range_points

    if any(m.dynamic for m in aggdata):
        propagated.dynamic = True
    elif (symbolic.issymbolic(propagated.num_accesses)
          and any(s not in defined_vars
                  for s in propagated.num_accesses.free_symbols)):
        # The count depends on symbols not defined outside the scope: mark
        # the memlet dynamic and reset the count to 0 (unbounded dynamic).
        propagated.dynamic = True
        propagated.num_accesses = 0

    return propagated