def apply(self, sdfg: SDFG):
    """Inline the matched single-state nested SDFG node into its parent state,
    in place.

    The transformation:
      1. merges nested global/init/exit code and constants into the parent,
      2. promotes nested transients to (renamed) parent transients,
      3. replaces nested data/symbol names with their outer counterparts,
      4. copies the nested state's nodes/edges into the parent state and
         reconnects memlet paths,
      5. removes the nested SDFG node.

    Assumes the nested SDFG has exactly one state (only ``nsdfg.nodes()[0]``
    is inlined).

    :param sdfg: The top-level SDFG containing the matched nested SDFG node.
    """
    state: SDFGState = sdfg.nodes()[self.state_id]
    nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
    nsdfg: SDFG = nsdfg_node.sdfg
    nstate: SDFGState = nsdfg.nodes()[0]

    # Propagate a non-default schedule of the nested SDFG node downwards
    # before inlining, so inner defaults are resolved consistently.
    if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
        infer_types.set_default_schedule_and_storage_types(
            nsdfg, nsdfg_node.schedule)

    nsdfg_scope_entry = state.entry_node(nsdfg_node)
    nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                        if nsdfg_scope_entry is not None else None)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Constants: merge, warning (but not failing) on value mismatches.
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    inputs: Dict[str, MultiConnectorEdge] = {}
    outputs: Dict[str, MultiConnectorEdge] = {}
    input_set: Dict[str, str] = {}
    output_set: Dict[str, str] = {}
    for e in state.in_edges(nsdfg_node):
        inputs[e.dst_conn] = e
        input_set[e.data.data] = e.dst_conn
    for e in state.out_edges(nsdfg_node):
        outputs[e.src_conn] = e
        output_set[e.data.data] = e.src_conn

    # Access nodes that need to be reshaped (inner array has more dimensions
    # than the outer memlet subset, or strides are incompatible).
    reshapes: Set[str] = set()
    for aname, array in nsdfg.arrays.items():
        if array.transient:
            continue
        edge = None
        if aname in inputs:
            edge = inputs[aname]
            if len(array.shape) > len(edge.data.subset):
                reshapes.add(aname)
                continue
        if aname in outputs:
            edge = outputs[aname]
            if len(array.shape) > len(edge.data.subset):
                reshapes.add(aname)
                continue
        if edge is not None and not InlineSDFG._check_strides(
                array.strides, sdfg.arrays[edge.data.data].strides,
                edge.data, nsdfg_node):
            reshapes.add(aname)

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace(symname, '__dacesym_' + symname)
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace('__dacesym_' + symname, symvalue)

    # All transients become transients of the parent (if data already
    # exists, find new name)
    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode):
            datadesc = nsdfg.arrays[node.data]
            if node.data not in transients and datadesc.transient:
                name = sdfg.add_datadesc('%s_%s' % (nsdfg.label, node.data),
                                         datadesc,
                                         find_new_name=True)
                transients[node.data] = name

    # All transients of edges between code nodes are also added to parent
    for edge in nstate.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            if edge.data.data is not None:
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, edge.data.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[edge.data.data] = name

    # Collect nodes to add to top-level graph
    new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
    new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
    source_accesses = set()
    sink_accesses = set()
    for node in nstate.source_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients
                and node.data not in reshapes):
            new_incoming_edges[node] = inputs[node.data]
            source_accesses.add(node)
    for node in nstate.sink_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients
                and node.data not in reshapes):
            new_outgoing_edges[node] = outputs[node.data]
            sink_accesses.add(node)

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace data names with their top-level counterparts
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })

    # Add views whenever reshapes are necessary
    for dname in reshapes:
        desc = nsdfg.arrays[dname]
        # To avoid potential confusion, rename protected __return keyword
        if dname.startswith('__return'):
            newname = f'{nsdfg.name}_ret{dname[8:]}'
        else:
            newname = dname
        newname, _ = sdfg.add_view(newname,
                                   desc.shape,
                                   desc.dtype,
                                   storage=desc.storage,
                                   strides=desc.strides,
                                   offset=desc.offset,
                                   debuginfo=desc.debuginfo,
                                   allow_conflicts=desc.allow_conflicts,
                                   total_size=desc.total_size,
                                   alignment=desc.alignment,
                                   may_alias=desc.may_alias,
                                   find_new_name=True)
        repldict[dname] = newname

    # Apply the rename mapping to all access nodes and memlets of the
    # nested state.
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in repldict:
            node.data = repldict[node.data]
    for edge in nstate.edges():
        if edge.data.data in repldict:
            edge.data.data = repldict[edge.data.data]

    # Add extra access nodes for out/in view nodes
    # NOTE(review): node.data was already renamed via repldict above, while
    # `reshapes` holds the pre-rename names; this test only matches when
    # add_view kept the original name — confirm intended behavior.
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in reshapes:
            if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                # Such a node has to be in the output set
                edge = outputs[node.data]

                # Redirect outgoing edges through access node
                out_edges = list(nstate.out_edges(node))
                anode = nstate.add_access(edge.data.data)
                vnode = nstate.add_access(node.data)
                nstate.add_nedge(node, anode, edge.data)
                nstate.add_nedge(anode, vnode, edge.data)
                for e in out_edges:
                    nstate.remove_edge(e)
                    nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,
                                    e.data)

    #######################################################
    # Add nested SDFG into top-level SDFG

    # Add nested nodes into original state (source/sink access nodes that are
    # replaced by outer data are excluded).
    subgraph = SubgraphView(nstate, [
        n for n in nstate.nodes()
        if n not in (source_accesses | sink_accesses)
    ])
    state.add_nodes_from(subgraph.nodes())
    for edge in subgraph.edges():
        state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Reconnect inlined SDFG

    # If a source/sink node is one of the inputs/outputs, reconnect it,
    # replacing memlets in outgoing/incoming paths
    modified_edges = set()
    modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                               state, True)
    modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                               state, False)

    # Reshape: add connections to viewed data
    self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
                              True)
    self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
                              False)

    # Modify all other internal edges pertaining to input/output nodes
    for node in subgraph.nodes():
        if isinstance(node, nodes.AccessNode):
            if node.data in input_set or node.data in output_set:
                if node.data in input_set:
                    outer_edge = inputs[input_set[node.data]]
                else:
                    outer_edge = outputs[output_set[node.data]]

                for edge in state.all_edges(node):
                    if (edge not in modified_edges
                            and edge.data.data == node.data):
                        for e in state.memlet_tree(edge):
                            if e.data.data == node.data:
                                # Offset inner memlet into the outer data's
                                # coordinate space.
                                e._data = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data)

    # If source/sink node is not connected to a source/destination access
    # node, and the nested SDFG is in a scope, connect to scope with empty
    # memlets
    if nsdfg_scope_entry is not None:
        for node in subgraph.nodes():
            if state.in_degree(node) == 0:
                state.add_edge(nsdfg_scope_entry, None, node, None, Memlet())
            if state.out_degree(node) == 0:
                state.add_edge(node, None, nsdfg_scope_exit, None, Memlet())

    # Replace nested SDFG parents with new SDFG
    for node in nstate.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = state
            node.sdfg.parent_sdfg = sdfg
            node.sdfg.parent_nsdfg_node = node

    # Remove all unused external inputs/output memlet paths, as well as
    # resulting isolated nodes
    # NOTE(review): `inputs.keys()` are connector name strings while
    # `source_accesses`/`sink_accesses` contain AccessNode objects, so the
    # set differences below are likely always the full key sets — confirm.
    removed_in_edges = self._remove_edge_path(state,
                                              inputs,
                                              set(inputs.keys()) -
                                              source_accesses,
                                              reverse=True)
    removed_out_edges = self._remove_edge_path(state,
                                               outputs,
                                               set(outputs.keys()) -
                                               sink_accesses,
                                               reverse=False)

    # Re-add in/out edges to first/last nodes in subgraph
    order = [
        x for x in nx.topological_sort(nstate._nx)
        if isinstance(x, nodes.AccessNode)
    ]
    for edge in removed_in_edges:
        # Find first access node that refers to this edge
        node = next(n for n in order if n.data == edge.data.data)
        state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                       edge.data)
    for edge in removed_out_edges:
        # Find last access node that refers to this edge
        node = next(n for n in reversed(order) if n.data == edge.data.data)
        state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Remove nested SDFG node
    state.remove_node(nsdfg_node)
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.
        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the
                         given SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph is not state and subgraph.graph is not state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_children()
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level
    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError('Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Consolidate edges in top scope
    utils.consolidate_edges(sdfg, scope)
    snodes = subgraph.nodes()

    # Collect inputs and outputs of the nested SDFG: edges crossing the
    # subgraph boundary.
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in snodes:
        for edge in state.in_edges(node):
            if edge.src not in snodes:
                inputs.append(edge)
        for edge in state.out_edges(node):
            if edge.dst not in snodes:
                outputs.append(edge)

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes()
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(
        n.data for s in sdfg.nodes() for n in s.nodes()
        if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = {}
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if node.has_reads(state):
                input_arrays.add(node.data)
            if node.has_writes(state):
                # Keep the write-conflict resolution of the first write edge.
                output_arrays[node.data] = state.in_edges(node)[0].data.wcr

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    # NOTE(review): the loop variables below shadow the `name` parameter;
    # harmless here since the parameter was already consumed above.
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays.keys()):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    input_names: Dict[MultiConnectorEdge, str] = {}
    output_names: Dict[MultiConnectorEdge, str] = {}
    global_subsets: Dict[str, Tuple[str, Subset]] = {}
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            # Same data used by multiple edges: grow the nested descriptor
            # to the union of all accessed subsets.
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        input_names[edge] = new_name
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        output_names[edge] = new_name
    ###################

    # Add scope symbols to the nested SDFG
    defined_vars = set(
        symbolic.pystr_to_symbolic(s)
        for s in (state.symbols_defined_at(top_scopenode).keys()
                  | sdfg.symbols))
    for v in defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Add constants to nested SDFG
    for cstname, cstval in sdfg.constants.items():
        nsdfg.add_constant(cstname, cstval)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn,
                        copy.deepcopy(e.data))

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg
            node.sdfg.parent_nsdfg_node = node

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for edge, name in input_names.items():
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(node, None, edge.dst,
                                                edge.dst_conn, new_edge)))
    for edge, name in output_names.items():
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(edge.src, edge.src_conn,
                                                node, None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(
                    global_subsets[original_edge.data.data][1], True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(
        nsdfg, None,
        set(input_names.values()) | input_arrays,
        set(output_names.values()) | output_arrays.keys())

    # Reconnect memlets to nested SDFG
    reconnected_in = set()
    reconnected_out = set()
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:
            empty_input = edge
            continue

        name = input_names[edge]
        if name in reconnected_in:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
        reconnected_in.add(name)

    for edge in outputs:
        if edge.data.data is None:
            empty_output = edge
            continue

        name = output_names[edge]
        if name in reconnected_out:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        data.wcr = edge.data.wcr
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)
        reconnected_out.add(name)

    # Connect access nodes to internal input/output data as necessary
    entry = scope.entry
    exit = scope.exit  # NOTE: shadows the `exit` builtin (local only)
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            state.add_nedge(entry, node, Memlet())
        state.add_edge(node, None, nested_sdfg, name,
                       Memlet.from_array(name, sdfg.arrays[name]))
    for name, wcr in output_arrays.items():
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, Memlet())
        state.add_edge(nested_sdfg, name, node, None,
                       Memlet(data=name, wcr=wcr))

    # Graph was not reconnected, but needs to be
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg,
                       None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst,
                       empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    # Remove newly isolated nodes due to memlet consolidation
    for edge in inputs:
        if state.in_degree(edge.src) + state.out_degree(edge.src) == 0:
            state.remove_node(edge.src)
    for edge in outputs:
        if state.in_degree(edge.dst) + state.out_degree(edge.dst) == 0:
            state.remove_node(edge.dst)

    return nested_sdfg
def generate_reference(name, chain):
    """Generates a simple, unoptimized SDFG to run on the CPU, for
    verification purposes.

    Builds one state per kernel in topological order, chained with
    interstate edges so kernels execute sequentially.
    """
    sdfg = SDFG(name)

    # Register chain constants on the SDFG
    for const_name, const in chain.constants.items():
        sdfg.add_constant(const_name, const["value"],
                          dace.data.Scalar(const["data_type"]))

    (dimensions_to_skip, shape, vector_length, parameters, iterators,
     memcopy_indices, memcopy_accesses) = _generate_init(chain)

    previous_state = sdfg.add_state("init")

    # Throw vectorization in the bin for the reference code
    vector_length = 1

    shape = tuple(map(int, shape))

    shape_of_input = {}  # Maps inputs to their shape tuple

    for graph_node in chain.graph.nodes():
        if not isinstance(graph_node, (Input, Output)):
            continue
        if isinstance(graph_node, Input):
            # Derive the (possibly lower-dimensional) shape from the first
            # output's "input_dims", falling back to all parameters.
            outputs_iter = iter(graph_node.outputs.values())
            try:
                first_output = next(outputs_iter)
            except StopIteration:
                raise ValueError("No outputs found for input node.")
            has_dims = ("input_dims" in first_output
                        and first_output["input_dims"] is not None)
            pars = (tuple(first_output["input_dims"])
                    if has_dims else tuple(parameters))
            arr_shape = tuple(s for s, p in zip(shape, parameters)
                              if p in pars)
            shape_of_input[graph_node.name] = arr_shape
        else:
            arr_shape = shape
        if len(arr_shape) > 0:
            try:
                sdfg.add_array(graph_node.name, arr_shape,
                               graph_node.data_type)
            except NameError:
                # Already declared: make it read-write instead
                sdfg.data(
                    graph_node.name).access = dace.dtypes.AccessType.ReadWrite
        else:
            # Zero-dimensional data becomes a symbol
            sdfg.add_symbol(graph_node.name, graph_node.data_type)

    # Intermediate buffers between kernels become transient arrays
    for link in chain.graph.edges(data=True):
        link_name = link[0].name
        if link_name not in sdfg.arrays and link_name not in sdfg.symbols:
            sdfg.add_array(link_name,
                           shape,
                           link[0].data_type,
                           transient=True)
            shape_of_input[link_name] = tuple(shape)

    iterators_of_input = {
        field: tuple("0:{}".format(s) for s in field_shape)
        for field, field_shape in shape_of_input.items()
    }

    # Enforce dependencies via topological sort
    for graph_node in nx.topological_sort(chain.graph):

        if not isinstance(graph_node, Kernel):
            continue

        kernel_state = sdfg.add_state(graph_node.name)
        sdfg.add_edge(previous_state, kernel_state, dace.InterstateEdge())

        (stencil_node, input_to_connector,
         output_to_connector) = _generate_stencil(graph_node, chain, shape,
                                                  dimensions_to_skip)
        stencil_node.implementation = "CPU"

        for field, connector in input_to_connector.items():

            if len(iterators_of_input[field]) == 0:
                continue  # Scalar variable

            # Outer memory read
            read_node = kernel_state.add_read(field)
            kernel_state.add_memlet_path(
                read_node,
                stencil_node,
                dst_conn=connector,
                memlet=Memlet.simple(field,
                                     ", ".join(iterators_of_input[field])))

        for _, connector in output_to_connector.items():

            # Outer write
            write_node = kernel_state.add_write(graph_node.name)
            kernel_state.add_memlet_path(
                stencil_node,
                write_node,
                src_conn=connector,
                memlet=Memlet.simple(
                    graph_node.name,
                    ", ".join("0:{}".format(s) for s in shape)))

        previous_state = kernel_state

    return sdfg
def apply(self, outer_state: SDFGState, sdfg: SDFG):
    """Inline a multi-state nested SDFG node into the top-level SDFG's state
    machine, in place.

    Splices the nested SDFG's states into the parent SDFG in place of
    ``outer_state``, reconnecting interstate edges from the outer state to
    the nested start/sink states, and removes ``outer_state``.

    :param outer_state: The state containing the nested SDFG node.
    :param sdfg: The top-level SDFG.
    :return: The list of states that were added to the top-level SDFG
             (``nsdfg.nodes()``).
    """
    nsdfg_node = self.nested_sdfg
    nsdfg: SDFG = nsdfg_node.sdfg

    # Propagate a non-default node schedule downwards before inlining.
    if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
        infer_types.set_default_schedule_and_storage_types(
            nsdfg, nsdfg_node.schedule)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Environments: propagate the nested node's environments to inner
    # code nodes so they survive inlining.
    for nstate in nsdfg.nodes():
        for node in nstate.nodes():
            if isinstance(node, nodes.CodeNode):
                node.environments |= nsdfg_node.environments

    # Constants: merge, warning (but not failing) on value mismatches.
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Symbols: all outer symbols, plus any defined by interstate edges.
    outer_symbols = {str(k): v for k, v in sdfg.symbols.items()}
    for ise in sdfg.edges():
        outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols))

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    inputs: Dict[str, MultiConnectorEdge] = {}
    outputs: Dict[str, MultiConnectorEdge] = {}
    input_set: Dict[str, str] = {}
    output_set: Dict[str, str] = {}
    for e in outer_state.in_edges(nsdfg_node):
        inputs[e.dst_conn] = e
        input_set[e.data.data] = e.dst_conn
    for e in outer_state.out_edges(nsdfg_node):
        outputs[e.src_conn] = e
        output_set[e.data.data] = e.src_conn

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    symbolic.safe_replace(nsdfg_node.symbol_mapping, nsdfg.replace_dict)

    # NOTE: Reshape (view) detection is currently disabled for multi-state
    # inlining; see the single-state InlineSDFG.apply for the live version.

    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}

    # All transients become transients of the parent (if data already
    # exists, find new name)
    for nstate in nsdfg.nodes():
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    new_name = node.data
                    # Prefix with the nested label only on a name clash
                    if (new_name in sdfg.arrays
                            or new_name in outer_symbols
                            or new_name in sdfg.constants):
                        new_name = f'{nsdfg.label}_{node.data}'

                    name = sdfg.add_datadesc(new_name,
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                if edge.data.data is not None:
                    datadesc = nsdfg.arrays[edge.data.data]
                    if edge.data.data not in transients and datadesc.transient:
                        new_name = edge.data.data
                        if (new_name in sdfg.arrays
                                or new_name in outer_symbols
                                or new_name in sdfg.constants):
                            new_name = f'{nsdfg.label}_{edge.data.data}'

                        name = sdfg.add_datadesc(new_name,
                                                 datadesc,
                                                 find_new_name=True)
                        transients[edge.data.data] = name

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace data names with their top-level counterparts
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })

    symbolic.safe_replace(repldict,
                          lambda m: replace_datadesc_names(nsdfg, m),
                          value_as_string=True)

    # NOTE: View creation for reshapes and the extra access-node insertion
    # for in/out views are disabled here; the single-state InlineSDFG.apply
    # contains the corresponding live implementation.

    # Make unique names for states
    statenames = set(s.label for s in sdfg.nodes())
    for nstate in nsdfg.nodes():
        if nstate.label in statenames:
            newname = data.find_new_name(nstate.label, statenames)
            statenames.add(newname)
            nstate.set_label(newname)

    #######################################################
    # Collect and modify interstate edges as necessary

    outer_assignments = set()
    for e in sdfg.edges():
        outer_assignments |= e.data.assignments.keys()

    inner_assignments = set()
    for e in nsdfg.edges():
        inner_assignments |= e.data.assignments.keys()

    # Rename inner interstate assignments that would clash with outer ones.
    assignments_to_replace = inner_assignments & outer_assignments
    sym_replacements: Dict[str, str] = {}
    allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys())
    for assign in assignments_to_replace:
        newname = data.find_new_name(assign, allnames)
        allnames.add(newname)
        sym_replacements[assign] = newname
    nsdfg.replace_dict(sym_replacements)

    #######################################################
    # Add nested SDFG states into top-level SDFG

    outer_start_state = sdfg.start_state

    sdfg.add_nodes_from(nsdfg.nodes())
    for ise in nsdfg.edges():
        sdfg.add_edge(ise.src, ise.dst, ise.data)

    #######################################################
    # Reconnect inlined SDFG

    source = nsdfg.start_state
    sinks = nsdfg.sink_nodes()

    # Reconnect state machine
    # NOTE(review): the same InterstateEdge object `e.data` is attached to
    # every sink when there are multiple sinks — confirm sharing is safe.
    for e in sdfg.in_edges(outer_state):
        sdfg.add_edge(e.src, source, e.data)
    for e in sdfg.out_edges(outer_state):
        for sink in sinks:
            sdfg.add_edge(sink, e.dst, e.data)

    # Modify start state as necessary
    if outer_start_state is outer_state:
        sdfg.start_state = sdfg.node_id(source)

    # TODO: Modify memlets by offsetting
    # NOTE: The memlet-path reconnection and unsqueeze logic (see the
    # single-state InlineSDFG.apply) is not yet implemented here.

    # Replace nested SDFG parents with new SDFG
    for nstate in nsdfg.nodes():
        nstate.parent = sdfg
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

    #######################################################
    # Remove nested SDFG and state
    sdfg.remove_node(outer_state)

    return nsdfg.nodes()
def generate_sdfg(name, chain, synthetic_reads=False, specialize_scalars=False):
    """Build a DaCe SDFG implementing the given stencil chain for FPGA.

    The SDFG has three states: ``initialize`` (host->device copies),
    ``compute`` (readers, stencil kernels, writers connected by streams),
    and ``finalize`` (device->host copies).

    :param name: Name of the generated SDFG.
    :param chain: Stencil chain object; its graph nodes are Input/Output/Kernel
        instances and its edges define the dataflow between them.
    :param synthetic_reads: If truthy, input readers emit the constant
        ``float(synthetic_reads)`` instead of reading from memory (no host
        arrays are created for inputs).
    :param specialize_scalars: If True, zero-dimensional inputs whose data
        files are present are baked into the SDFG as constants.
    :return: The constructed SDFG.
    """
    sdfg = SDFG(name)
    # Chain-level constants become SDFG constants (scalar-typed).
    for k, v in chain.constants.items():
        sdfg.add_constant(k, v["value"], dace.data.Scalar(v["data_type"]))
    if specialize_scalars:
        # Scalars (no input dimensions) with available data files are
        # specialized to compile-time constants; missing files are skipped.
        for k, v in chain.inputs.items():
            if len(v["input_dims"]) == 0:
                try:
                    val = stencilflow.load_array(v)
                except FileNotFoundError:
                    continue
                print(f"Specialized constant {k} to {val}.")
                sdfg.add_constant(k, val)
    # Three-state pipeline: initialize -> compute -> finalize.
    pre_state = sdfg.add_state("initialize")
    state = sdfg.add_state("compute")
    post_state = sdfg.add_state("finalize")
    sdfg.add_edge(pre_state, state, InterstateEdge())
    sdfg.add_edge(state, post_state, InterstateEdge())
    (dimensions_to_skip, shape, vector_length, parameters, iterators,
     memcopy_indices, memcopy_accesses) = _generate_init(chain)
    vshape = list(shape)  # Copy
    # Vectorized shape: innermost dimension is divided by the vector length.
    if vector_length > 1:
        vshape[-1] //= vector_length

    def add_input(node, bank):
        # Generate the reader for one Input node: a map that streams the
        # input array to every consuming kernel. `bank` selects the FPGA
        # memory bank for the device-side array.

        # Collapse iterators and shape if input is lower dimensional
        for output in node.outputs.values():
            try:
                input_pars = output["input_dims"][:]
            except (KeyError, TypeError):
                input_pars = list(parameters)  # Copy
            break  # Just needed any output to retrieve the dimensions
        else:
            raise ValueError("Input {} is not connected to anything.".format(
                node.name))
        # If scalar, just add a symbol
        if len(input_pars) == 0:
            sdfg.add_symbol(node.name, node.data_type)
            return  # We're done
        input_shape = [shape[list(parameters).index(i)] for i in input_pars]
        input_accesses = str(functools.reduce(operator.mul, input_shape, 1))
        # Only vectorize the read if the innermost dimensions is read
        input_vector_length = (vector_length
                               if input_pars[-1] == parameters[-1] else 1)
        input_vtype = (dace.dtypes.vector(node.data_type, input_vector_length)
                       if input_vector_length > 1 else node.data_type)
        input_vshape = list(input_shape)
        if input_vector_length > 1:
            input_vshape[-1] //= input_vector_length
        # Sort to get deterministic output
        outputs = sorted([e[1].name for e in chain.graph.out_edges(node)])
        out_memlets = ["_" + o for o in outputs]
        entry, exit = state.add_map("read_" + node.name,
                                    iterators,
                                    schedule=ScheduleType.FPGA_Device)
        if not synthetic_reads:
            # Host-side array, which will be an input argument
            sdfg.add_array(node.name + "_host", input_shape, node.data_type)
            # Device-side copy
            _, array = sdfg.add_array(node.name,
                                      input_vshape,
                                      input_vtype,
                                      storage=StorageType.FPGA_Global,
                                      transient=True)
            array.location["bank"] = bank
            access_node = state.add_read(node.name)
            # Copy data to the FPGA
            copy_host = pre_state.add_read(node.name + "_host")
            copy_fpga = pre_state.add_write(node.name)
            pre_state.add_memlet_path(copy_host,
                                      copy_fpga,
                                      memlet=Memlet.simple(
                                          copy_fpga,
                                          ", ".join("0:{}".format(s)
                                                    for s in input_vshape),
                                          num_accesses=input_accesses))
            tasklet_code = "\n".join(
                ["{} = memory".format(o) for o in out_memlets])
            tasklet = state.add_tasklet("read_" + node.name, {"memory"},
                                        out_memlets, tasklet_code)
            vectorized_pars = input_pars
            # if input_vector_length > 1:
            #     vectorized_pars[-1] = "{}*{}".format(input_vector_length,
            #                                          vectorized_pars[-1])
            # Lower-dimensional arrays should buffer values and send them
            # multiple times
            is_lower_dim = len(input_shape) != len(shape)
            if is_lower_dim:
                # Stage the lower-dimensional input into a local buffer once,
                # then replay it from the buffer inside the full iteration
                # space.
                buffer_name = node.name + "_buffer"
                sdfg.add_array(buffer_name,
                               input_shape,
                               input_vtype,
                               storage=StorageType.FPGA_Local,
                               transient=True)
                buffer_node = state.add_access(buffer_name)
                buffer_entry, buffer_exit = state.add_map(
                    "buffer_" + node.name, {
                        k: "0:{}".format(v)
                        for k, v in zip(input_pars, input_shape)
                    },
                    schedule=dace.ScheduleType.FPGA_Device)
                buffer_tasklet = state.add_tasklet("buffer_" + node.name,
                                                   {"memory"}, {"buffer"},
                                                   "buffer = memory")
                state.add_memlet_path(access_node,
                                      buffer_entry,
                                      buffer_tasklet,
                                      dst_conn="memory",
                                      memlet=dace.Memlet.simple(
                                          access_node.data,
                                          ", ".join(vectorized_pars),
                                          num_accesses=1))
                state.add_memlet_path(buffer_tasklet,
                                      buffer_exit,
                                      buffer_node,
                                      src_conn="buffer",
                                      memlet=dace.Memlet.simple(
                                          buffer_node.data,
                                          ", ".join(input_pars),
                                          num_accesses=1))
                state.add_memlet_path(buffer_node,
                                      entry,
                                      tasklet,
                                      dst_conn="memory",
                                      memlet=dace.Memlet.simple(
                                          buffer_node.data,
                                          ", ".join(input_pars),
                                          num_accesses=1))
            else:
                # Full-dimensional input: read directly from the device array.
                state.add_memlet_path(access_node,
                                      entry,
                                      tasklet,
                                      dst_conn="memory",
                                      memlet=Memlet.simple(
                                          node.name,
                                          ", ".join(vectorized_pars),
                                          num_accesses=1))
        else:
            # Generate synthetic inputs without memory: the tasklet emits a
            # constant value instead of reading from an array.
            tasklet_code = "\n".join([
                "{} = {}".format(o, float(synthetic_reads)) for o in out_memlets
            ])
            tasklet = state.add_tasklet("read_" + node.name, {}, out_memlets,
                                        tasklet_code)
            state.add_memlet_path(entry, tasklet, memlet=dace.Memlet())
        # Add memlets to all FIFOs connecting to compute units
        for out_name, out_memlet in zip(outputs, out_memlets):
            stream_name = "read_{}_to_{}".format(node.name, out_name)
            write_node = state.add_write(stream_name)
            state.add_memlet_path(tasklet,
                                  exit,
                                  write_node,
                                  src_conn=out_memlet,
                                  memlet=Memlet.simple(stream_name,
                                                       "0",
                                                       num_accesses=1))

    def add_output(node, bank):
        # Generate the writer for one Output node: a map that drains the
        # producing kernel's stream into the device array, plus the
        # device->host copy in the finalize state.

        # Host-side array, which will be an output argument
        try:
            sdfg.add_array(node.name + "_host", shape, node.data_type)
            _, array = sdfg.add_array(node.name,
                                      vshape,
                                      dace.dtypes.vector(
                                          node.data_type, vector_length),
                                      storage=StorageType.FPGA_Global,
                                      transient=True)
            array.location["bank"] = bank
        except NameError:
            # This array is also read
            sdfg.data(node.name + "_host").access = dace.AccessType.ReadWrite
            sdfg.data(node.name).access = dace.AccessType.ReadWrite
        # Device-side copy
        write_node = state.add_write(node.name)
        # Copy data to the host
        copy_fpga = post_state.add_read(node.name)
        copy_host = post_state.add_write(node.name + "_host")
        post_state.add_memlet_path(copy_fpga,
                                   copy_host,
                                   memlet=Memlet.simple(
                                       copy_fpga,
                                       ", ".join(memcopy_indices),
                                       num_accesses=memcopy_accesses))
        entry, exit = state.add_map("write_" + node.name,
                                    iterators,
                                    schedule=ScheduleType.FPGA_Device)
        src = chain.graph.in_edges(node)
        if len(src) > 1:
            raise RuntimeError("Only one writer per output supported")
        src = next(iter(src))[0]
        in_memlet = "_" + src.name
        tasklet_code = "memory = " + in_memlet
        tasklet = state.add_tasklet("write_" + node.name, {in_memlet},
                                    {"memory"}, tasklet_code)
        vectorized_pars = copy.copy(parameters)
        # if vector_length > 1:
        #     vectorized_pars[-1] = "{}*{}".format(vector_length,
        #                                          vectorized_pars[-1])
        stream_name = "{}_to_write_{}".format(src.name, node.name)
        read_node = state.add_read(stream_name)
        state.add_memlet_path(read_node,
                              entry,
                              tasklet,
                              dst_conn=in_memlet,
                              memlet=Memlet.simple(stream_name,
                                                   "0",
                                                   num_accesses=1))
        state.add_memlet_path(tasklet,
                              exit,
                              write_node,
                              src_conn="memory",
                              memlet=Memlet.simple(node.name,
                                                   ", ".join(vectorized_pars),
                                                   num_accesses=1))

    def add_kernel(node):
        # Instantiate one stencil compute node and connect it to the streams
        # feeding it (from memory readers or other kernels) and draining it.
        (stencil_node, input_to_connector,
         output_to_connector) = _generate_stencil(node, chain, shape,
                                                  dimensions_to_skip)
        if len(stencil_node.output_fields) == 0:
            # A stencil with no outputs is dead code; tolerate it only when it
            # also has no inputs.
            if len(input_to_connector) == 0:
                warnings.warn("Ignoring orphan stencil: {}".format(node.name))
            else:
                raise ValueError("Orphan stencil with inputs: {}".format(
                    node.name))
            return
        # Pick the stencil implementation matching the configured FPGA vendor.
        vendor_str = dace.config.Config.get("compiler", "fpga_vendor")
        if vendor_str == "intel_fpga":
            stencil_node.implementation = "Intel FPGA"
        elif vendor_str == "xilinx":
            stencil_node.implementation = "Xilinx"
        else:
            raise ValueError(f"Unsupported FPGA backend: {vendor_str}")
        state.add_node(stencil_node)
        # Non-Kernel neighbors are memory endpoints; this determines the
        # stream naming convention used below.
        is_from_memory = {
            e[0].name: not isinstance(e[0], stencilflow.kernel.Kernel)
            for e in chain.graph.in_edges(node)
        }
        is_to_memory = {
            e[1].name: not isinstance(e[1], stencilflow.kernel.Kernel)
            for e in chain.graph.out_edges(node)
        }
        # Add read nodes and memlets
        for field_name, connector in input_to_connector.items():
            input_vector_length = vector_length
            try:
                # Scalars are symbols rather than data nodes
                if len(node.inputs[field_name]["input_dims"]) == 0:
                    continue
                else:
                    # If the innermost dimension of this field is not the
                    # vectorized one, read it as scalars
                    if (node.inputs[field_name]["input_dims"][-1] !=
                            parameters[-1]):
                        input_vector_length = 1
            except (KeyError, TypeError):
                pass  # input_dim is not defined or is None
            if is_from_memory[field_name]:
                stream_name = "read_{}_to_{}".format(field_name, node.name)
            else:
                stream_name = "{}_to_{}".format(field_name, node.name)
            # Outer memory read
            read_node = state.add_read(stream_name)
            state.add_memlet_path(read_node,
                                  stencil_node,
                                  dst_conn=connector,
                                  memlet=Memlet.simple(
                                      stream_name,
                                      "0",
                                      num_accesses=memcopy_accesses))
        # Add write nodes and memlets
        for output_name, connector in output_to_connector.items():
            # Add write node and memlet
            if is_to_memory[output_name]:
                stream_name = "{}_to_write_{}".format(node.name, output_name)
            else:
                stream_name = "{}_to_{}".format(node.name, output_name)
            # Outer write
            write_node = state.add_write(stream_name)
            state.add_memlet_path(stencil_node,
                                  write_node,
                                  src_conn=connector,
                                  memlet=Memlet.simple(
                                      stream_name,
                                      "0",
                                      num_accesses=memcopy_accesses))

    # First generate all connections between kernels and memories
    for link in chain.graph.edges(data=True):
        _add_pipe(sdfg, link, parameters, vector_length)
    bank = 0
    # Now generate all memory access functions so arrays are registered
    for node in chain.graph.nodes():
        if isinstance(node, Input):
            add_input(node, bank)
            # Round-robin assignment of inputs/outputs to memory banks.
            bank = (bank + 1) % NUM_BANKS
        elif isinstance(node, Output):
            add_output(node, bank)
            bank = (bank + 1) % NUM_BANKS
        elif isinstance(node, Kernel):
            # Generate these separately after
            pass
        else:
            raise RuntimeError("Unexpected node type: {}".format(
                node.node_type))
    # Finally generate the compute kernels
    for node in chain.graph.nodes():
        if isinstance(node, Kernel):
            add_kernel(node)
    return sdfg