def count_arithmetic_ops(sdfg: dace.SDFG, symbols: Dict[str, Any] = None) -> int:
    """Return the total number of arithmetic operations over all states of ``sdfg``.

    :param sdfg: The SDFG whose states are counted.
    :param symbols: Optional symbol values forwarded to the per-state counter.
    :return: Sum of ``count_arithmetic_ops_state`` over every state.
    """
    syms = symbols or {}
    return sum(count_arithmetic_ops_state(state, syms) for state in sdfg.nodes())
def apply_pass(self, sdfg: SDFG, _) -> Dict[SDFGState, Set[SDFGState]]:
    """
    :return: A dictionary mapping each state to its other reachable states.
    """
    # In the transitive closure every reachable state becomes a direct successor.
    closure: nx.DiGraph = nx.transitive_closure(sdfg.nx)
    return {state: set(closure.successors(state)) for state in sdfg.nodes()}
def optimize_for_gpu(sdfg: dace.SDFG, m: int, n: int, k: int):
    """ Optimize the matrix multiplication example for GPUs.

    Applies (in order): map/reduce fusion, GPU transformation, two levels of
    tiling, map collapsing, shared-memory and register storage for the
    operands, accumulation for the output tile, unrolling of the microkernel,
    and double buffering of shared memory.

    :param sdfg: The matrix-multiplication SDFG, modified in place.
    :param m: Size tiled by ``i`` (tile size 64).
    :param n: Size tiled by ``j`` (tile size 64).
    :param k: Size tiled by ``k`` (tile size 8).
    """
    # Ensure integers are 32-bit by default
    dace.Config.set('compiler', 'default_data_types', value='C')
    # Fuse the map and reduce nodes
    sdfg.apply_transformations(MapReduceFusion)
    # Apply GPU transformation
    sdfg.apply_gpu_transformations()
    # Find multiplication map
    entry = find_map_by_param(sdfg, 'k')
    # Create a tiling strategy
    divides_evenly = (m % 64 == 0) and (n % 64 == 0) and (k % 8 == 0)
    xfutil.tile(sdfg, entry, divides_evenly, True, i=64, j=64, k=8)
    xfutil.tile(sdfg, entry, divides_evenly, True, i=8, j=4)
    # Create kernel schedule by collapsing and reordering maps
    gtile_i = find_map_by_param(sdfg, 'tile_i')
    gtile_j = find_map_by_param(sdfg, 'tile_j')
    btile_i = find_map_by_param(sdfg, 'tile1_i')
    btile_j = find_map_by_param(sdfg, 'tile1_j')
    MapCollapse.apply_to(sdfg, outer_map_entry=gtile_i, inner_map_entry=gtile_j, permissive=True)
    MapCollapse.apply_to(sdfg, outer_map_entry=btile_i, inner_map_entry=btile_j, permissive=True)
    btile = find_map_by_param(sdfg, 'tile1_i')
    btile.map.schedule = dace.ScheduleType.GPU_ThreadBlock
    # Add local storage (shared memory) for A and B on GPU
    ktile = find_map_by_param(sdfg, 'tile_k')
    smem_a = InLocalStorage.apply_to(sdfg, dict(array='A'), node_a=ktile, node_b=btile)
    smem_b = InLocalStorage.apply_to(sdfg, dict(array='B'), node_a=ktile, node_b=btile)
    sdfg.arrays[smem_a.data].storage = dace.StorageType.GPU_Shared
    sdfg.arrays[smem_b.data].storage = dace.StorageType.GPU_Shared
    # Add local storage (registers) for A and B
    ttile = find_map_by_param(sdfg, 'k')
    warptile, ttile = xfutil.extract_map_dims(sdfg, ttile, [2])
    InLocalStorage.apply_to(sdfg, dict(array='trans_gpu_A'), node_a=warptile, node_b=ttile)
    InLocalStorage.apply_to(sdfg, dict(array='trans_gpu_B'), node_a=warptile, node_b=ttile)
    # Add local storage (registers) for C
    # Find the state that contains the warp-tile map so we can locate its exit nodes
    state = next(s for s in sdfg.nodes() if warptile in s.nodes())
    warptile_exit = state.exit_node(warptile)
    btile_exit = state.exit_node(btile)
    AccumulateTransient.apply_to(sdfg, map_exit=warptile_exit, outer_map_exit=btile_exit)
    # Set C tile to zero on allocation
    c_access = next(n for n in state.data_nodes() if n.data == 'trans_gpu_C')
    c_access.setzero = True
    # Unroll microkernel maps
    ttile.map.unroll = True
    # Apply double-buffering on shared memory
    DoubleBuffering.apply_to(sdfg, map_entry=ktile, transient=smem_a)
def _sdfg_freeze_domain_and_origin(self, inner_sdfg: dace.SDFG, domain: Tuple[int, ...], origin: Dict[str, Tuple[int, ...]]):
    """Wrap ``inner_sdfg`` in a new SDFG with the given domain and origins frozen in.

    :param inner_sdfg: The SDFG to wrap as a nested SDFG.
    :param domain: Concrete domain sizes; a zero in any dimension empties the graph.
    :param origin: Mapping of array name to its origin offsets, forwarded to
        ``_sdfg_add_arrays_and_edges``.
    :return: The new wrapper SDFG.
    """
    wrapper_sdfg = dace.SDFG("frozen_" + inner_sdfg.name)
    state = wrapper_sdfg.add_state("frozen_" + inner_sdfg.name + "_state")
    inputs = set()
    outputs = set()
    # Collect non-transient arrays that are read (inputs) or written (outputs)
    # anywhere in the inner SDFG.
    for inner_state in inner_sdfg.nodes():
        for node in inner_state.nodes():
            if (not isinstance(node, dace.nodes.AccessNode) or inner_sdfg.arrays[node.data].transient):
                continue
            if node.has_reads(inner_state):
                inputs.add(node.data)
            if node.has_writes(inner_state):
                outputs.add(node.data)
    nsdfg = state.add_nested_sdfg(inner_sdfg, None, inputs, outputs)
    self._sdfg_add_arrays_and_edges(wrapper_sdfg, state, inner_sdfg, nsdfg, inputs, outputs, origins=origin)
    # in special case of empty domain, remove entire SDFG.
    if any(d == 0 for d in domain):
        states = wrapper_sdfg.states()
        assert len(states) == 1
        for node in states[0].nodes():
            state.remove_node(node)
    # make sure that symbols are passed through to inner sdfg
    for symbol in nsdfg.sdfg.free_symbols:
        if symbol not in wrapper_sdfg.symbols:
            wrapper_sdfg.add_symbol(symbol, nsdfg.sdfg.symbols[symbol])
    # Try to inline wrapped SDFG before symbols are specialized to avoid extra views
    inline_sdfgs(wrapper_sdfg)
    self._sdfg_specialize_symbols(wrapper_sdfg, domain)
    # Give all transients SDFG-wide lifetime so they persist across states.
    for _, _, array in wrapper_sdfg.arrays_recursive():
        if array.transient:
            array.lifetime = dace.dtypes.AllocationLifetime.SDFG
    # Argument names: positional signature entries that are not in the second
    # signature group (presumably keyword/field arguments — TODO confirm).
    signature = self.__sdfg_signature__()
    wrapper_sdfg.arg_names = [a for a in signature[0] if a not in signature[1]]
    return wrapper_sdfg
def _permute_array(array: Array, perm: Tuple[int, int, int], sdfg: dace.SDFG, array_name: str):
    """Permute the dimensions of ``array`` in place according to ``perm``.

    Shape, strides and offset are reordered, and every memlet subset in the
    SDFG that refers to ``array_name`` is reordered the same way.
    """
    def _reorder(seq):
        # Pick elements in the permuted dimension order.
        return [seq[dim] for dim in perm]

    array.shape = _reorder(array.shape)
    array.strides = _reorder(array.strides)
    array.offset = _reorder(array.offset)
    # Modify all edges coming in/out of the array
    for state in sdfg.nodes():
        for edge in state.edges():
            if edge.data.data == array_name:
                subset = edge.data.subset
                edge.data.subset = type(subset)(_reorder(subset))
def _specialize_transient_strides(sdfg: dace.SDFG, layout_map):
    """Replace the stride symbols of all transient arrays according to ``layout_map``."""
    transients = [arr for arr in sdfg.arrays.values() if arr.transient]
    replacements = replace_strides(transients, layout_map)
    sdfg.replace_dict(replacements)
    # Propagate the replacements into the symbol mappings of nested SDFGs.
    for state in sdfg.nodes():
        nested = (n for n in state.nodes() if isinstance(n, dace.nodes.NestedSDFG))
        for node in nested:
            for old, new in replacements.items():
                if old in node.symbol_mapping:
                    node.symbol_mapping[old] = new
    # The replaced stride symbols are no longer needed on the outer SDFG.
    for old in replacements:
        if old in sdfg.symbols:
            sdfg.remove_symbol(old)
def _remove_transients(sdfg: dace.SDFG, transients_to_remove: Dict[str, float], replacer: ast.NodeTransformer = ASTFindReplace): """ Replaces transients with constants, removing associated access nodes. """ # Remove transients for dname, val in transients_to_remove.items(): # Add constant, remove data descriptor del sdfg.arrays[dname] sdfg.add_constant(dname, val) for state in sdfg.nodes(): for node in state.nodes(): if (isinstance(node, dace.nodes.AccessNode) and node.data == dname): # For all access node instances, remove # outgoing edge connectors from subsequent nodes, # then remove access nodes for edge in state.out_edges(node): for e in state.memlet_tree(edge): # Do not break scopes if there are no other edges if len(state.edges_between(e.src, e.dst)) == 1: state.add_edge(e.src, None, e.dst, None, dace.Memlet()) state.remove_edge_and_connectors(e) # If tasklet, replace connector name with constant if isinstance(e.dst, dace.nodes.Tasklet): replacer({ e.dst_conn: dname }).visit(e.dst.code.code) # If stencil, handle similarly elif isinstance(e.dst, stencil.Stencil): del e.dst.accesses[e.dst_conn] for i, stmt in enumerate(e.dst.code.code): e.dst.code.code[i] = replacer({ e.dst_conn: dname }).visit(stmt) # If dst is a NestedSDFG, add the dst_connector as # a constant and remove internal nodes elif isinstance(e.dst, dace.nodes.NestedSDFG): nsdfg: dace.SDFG = e.dst.sdfg _remove_transients(nsdfg, {dname: val}) # Lastly, remove the node itself state.remove_node(node)
def can_be_applied(graph: dace.SDFGState, candidate: Dict[Any, int], expr_index: int, sdfg: dace.SDFG, strict=False):
    """Check whether the two matched stencils can be fused over the candidate transient."""
    first: Stencil = graph.node(candidate[StencilFusion._stencil_a])
    second: Stencil = graph.node(candidate[StencilFusion._stencil_b])
    tmp_node: nodes.AccessNode = graph.node(candidate[StencilFusion._tmp_array])

    # Ensure the stencil shapes match
    if len(first.shape) != len(second.shape):
        return False
    if any(dim_a != dim_b for dim_a, dim_b in zip(first.shape, second.shape)):
        return False

    # Ensure that the transient is not used anywhere else and can be removed
    if len(graph.all_edges(tmp_node)) != 2:
        return False
    if not sdfg.arrays[tmp_node.data].transient:
        return False
    occurrences = sum(1 for state in sdfg.nodes() for n in state.nodes()
                      if isinstance(n, nodes.AccessNode) and n.data == tmp_node.data)
    if occurrences > 1:
        return False

    # Ensure that second stencil only has one input access of the
    # candidate transient to remove
    out_edge = graph.out_edges(tmp_node)[0]
    if len(second.accesses[out_edge.dst_conn][1]) > 1:
        return False
    # TODO: Remove check once stencils can be offset
    if any(offset != 0 for offset in second.accesses[out_edge.dst_conn][1][0]):
        return False

    # Code languages must match
    if first.code.language != second.code.language:
        return False

    # TODO: Boundary condition matching checks
    return True
def add_gpu_location(sdfg: dace.SDFG, mapEntry, gpu):
    """Assign GPU ``gpu`` as the location of a map entry, the tasklets it feeds,
    and the non-scalar access nodes feeding the map.
    """
    # NOTE(review): indexing the state list by sdfg_id looks suspicious — confirm.
    state = sdfg.nodes()[sdfg.sdfg_id]
    mapEntry.location = {'gpu': gpu}
    # Tag tasklets directly downstream of the map entry.
    for edge in state.out_edges(mapEntry):
        if isinstance(edge.dst, nodes.Tasklet):
            edge.dst.location = {'gpu': gpu}
    # Tag the data descriptors of non-scalar arrays feeding the map.
    for edge in state.in_edges(mapEntry):
        producer = edge.src
        if isinstance(producer, nodes.AccessNode) and not isinstance(producer.desc(sdfg), Scalar):
            producer.desc(sdfg).location = {'gpu': gpu}
def unify_symbols(sdfg: dace.SDFG):
    """ Uses one set of symbols across all nested SDFGs.

    For every nested SDFG, symbols used in its symbol mapping are renamed so
    that the inner SDFG uses the outer symbol names directly, after which the
    symbol mapping becomes the identity. Recurses into nested SDFGs.

    :param sdfg: The SDFG to modify in place.
    """
    for state in sdfg.nodes():
        for node in state.nodes():
            if isinstance(node, dace.nodes.NestedSDFG):
                # First, get nested symbols and replace them if they match
                # the names of the outer symbols
                usedsyms: Set[str] = set()
                for symvalue in node.symbol_mapping.values():
                    usedsyms |= set(map(str, dace.symbolic.pystr_to_symbolic(symvalue).free_symbols))
                # Replace clashing names
                clashing = usedsyms & (node.sdfg.symbols.keys() | node.sdfg.arrays.keys())
                for clash in clashing:
                    new_name = find_new_name(node.sdfg, clash)
                    node.sdfg.replace(clash, new_name)
                    if clash in node.symbol_mapping:
                        node.symbol_mapping[new_name] = node.symbol_mapping[clash]
                        del node.symbol_mapping[clash]
                # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
                for symname, symvalue in node.symbol_mapping.items():
                    if str(symname) != str(symvalue):
                        node.sdfg.replace(symname, '__dacesym_' + symname)
                for symname, symvalue in node.symbol_mapping.items():
                    if str(symname) != str(symvalue):
                        if str(symvalue) in node.sdfg.symbols:
                            del node.sdfg.symbols[str(symvalue)]
                        node.sdfg.replace('__dacesym_' + symname, str(symvalue))
                # Replace symbol mapping with the identity on the used symbols
                node.symbol_mapping = {k: k for k in usedsyms}
                # Recursively descend
                unify_symbols(node.sdfg)
def get_parents(outermost_node: nodes.Node, node: nodes.Node, sdfg: dace.SDFG, state_id: int) -> List[nodes.Node]:
    """Return the chain of scope-parent nodes of ``node`` up to ``outermost_node``.

    :return: Parents ordered outermost-first; empty list if no parent scope is found.
    """
    # Because dfg is only a subgraph view, it does not contain the entry
    # node for a given entry. This O(n) solution is suboptimal
    enclosing = None
    for state in sdfg.nodes():
        scope = state.scope_dict().get(node)
        if scope is not None:
            enclosing = scope
            break
    if enclosing is None:
        return []
    if enclosing == outermost_node:
        return [enclosing]
    return PAPIUtils.get_parents(outermost_node, enclosing, sdfg, state_id) + [enclosing]
def remove_node_and_computation(sdfg: dace.SDFG, state: dace.SDFGState, node: nd.Node):
    """ Remove a node and the parent nodes that compute this node, if the
        outputs are not used elsewhere.

        :param sdfg: the sdfg containing the node.
        :param state: the state containing the node.
        :param node: the node to remove
    """
    queue = deque([node])
    # Guard against enqueueing the same node twice (e.g., a producer connected
    # through multiple edges or reachable via several paths), which would make
    # the second `remove_node` fail on an already-removed node.
    enqueued = {node}
    while len(queue) > 0:
        current_node = queue.popleft()
        edges = state.in_edges(current_node)
        state.remove_node(current_node)
        for e in edges:
            next_node = e.src
            if next_node in enqueued:
                continue
            # An access node must be kept if its data is also accessed in another
            # state. Restrict the scan to AccessNodes: other node types (e.g.,
            # tasklets, map entries) do not expose a `.data` attribute.
            data_used_in_other_states = isinstance(next_node, nd.AccessNode) and \
                any(n.data == next_node.data
                    for s in sdfg.nodes() if s is not state
                    for n in s.nodes() if isinstance(n, nd.AccessNode))
            if len(state.out_edges(next_node)) == 0 and not data_used_in_other_states:
                enqueued.add(next_node)
                queue.append(next_node)
def apply_pass(self, sdfg: SDFG, _) -> Dict[SDFGState, Tuple[Set[str], Set[str]]]:
    """
    :return: A dictionary mapping each state to a tuple of the sets of data
             container names it (reads, writes). (The previous docstring was
             copy-pasted from the reachability pass and incorrect.)
    """
    result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {}
    for state in sdfg.nodes():
        readset, writeset = set(), set()
        for anode in state.data_nodes():
            # Incoming edges write into the access node; outgoing edges read from it.
            if state.in_degree(anode) > 0:
                writeset.add(anode.data)
            if state.out_degree(anode) > 0:
                readset.add(anode.data)
        result[state] = (readset, writeset)
    # Edges that read from arrays add to both ends' access sets
    anames = sdfg.arrays.keys()
    for e in sdfg.edges():
        fsyms = e.data.free_symbols & anames
        if fsyms:
            result[e.src][0].update(fsyms)
            result[e.dst][0].update(fsyms)
    return result
def apply(self, sdfg: dace.SDFG):
    """Fuse tasklet ``t2`` into ``t1``: t2's connectors are renamed to avoid
    clashes, its edges are rerouted to t1, its code is appended to t1's code,
    and t2 is removed from the graph.
    """
    graph = sdfg.nodes()[self.state_id]
    t1 = graph.nodes()[self.subgraph[self.t1]]
    t2 = graph.nodes()[self.subgraph[self.t2]]

    def rename_conn(conn: str, names: Set[str]) -> str:
        """ Renames connector so that it doesn't clash with names. """
        # Reuse a trailing numeric suffix if present; otherwise append one.
        match = re.match('(.*?)([0-9]+)$', conn)
        if match:
            pre = match.group(1)
        else:
            pre = f'{conn}_'
        i = 0
        while f'{pre}{i}' in names:
            i += 1
        return f'{pre}{i}'

    def replace(tasklet, repl_dict):
        """ Renames connectors based on the input replacement dictionary. """
        if tasklet.language is dtypes.Language.Python:
            repl = ConnectorRenamer(repl_dict)
            for stmt in tasklet.code.code:
                repl.visit(stmt)
        elif tasklet.language is dtypes.Language.CPP:
            for old, new in repl_dict.items():
                tasklet.code.code = re.sub(r'\b%s\b' % re.escape(old), new, tasklet.code.as_string)

    def replace_lhs(tasklet, repl_dict):
        """ Replaces assignments' LHS based on the input replacement dictionary. This is used only on CPP tasklets. """
        if tasklet.language is dtypes.Language.Python:
            raise ValueError("This method should only be used with CPP Tasklets")
        elif tasklet.language is dtypes.Language.CPP:
            for old, new in repl_dict.items():
                # Negative lookbehind avoids re-declaring already `auto`-declared names.
                tasklet.code.code = re.sub(r'(?<!auto\s)%s[\s\t]*=' % re.escape(old), new, tasklet.code.as_string)

    def extract_lhs(tasklet) -> Set[str]:
        """ Returns the LHS of assignments in Tasklet code. """
        if tasklet.language is dtypes.Language.Python:
            extr = PythonLHSExtractor()
            for stmt in tasklet.code.code:
                extr.visit(stmt)
            return extr.assignments
        elif tasklet.language is dtypes.Language.CPP:
            rhs = set()
            for match in re.findall('[\s\t\n\r]*([\w]*)[\s\t]*=', tasklet.code.code):
                rhs.add(match)
            return rhs

    rdict = dict()        # t2 connector renames (clash avoidance)
    rdict_inout = dict()  # t2 input connectors fed directly by t1 outputs

    # Find names of current and former connectors
    # (assignments' LHS that are not connectors).
    t1_names = t1.in_connectors.keys() | t1.out_connectors.keys()
    t1_rhs = extract_lhs(t1)
    if t1_rhs:
        t1_names |= t1_rhs
    t2_names = t2.in_connectors.keys() | t2.out_connectors.keys()
    t2_rhs = extract_lhs(t2)
    if t2_rhs:
        t2_names |= t2_rhs

    # Change t2 connector names.
    nlist = list(t2_names)
    for name in nlist:
        if name in t1_names:
            newname = rename_conn(name, t1_names | t2_names)
            rdict[name] = newname
            t2_names.remove(name)
            t2_names.add(newname)
    if rdict:
        replace(t2, rdict)

    # Handle input edges.
    inconn = {}
    for e in graph.in_edges(t1):
        inconn[e.dst_conn] = t1.in_connectors[e.dst_conn]
    for e in graph.in_edges(t2):
        graph.remove_edge(e)
        conn = e.dst_conn
        if conn in rdict.keys():
            conn = rdict[conn]
        if e.src is t1:
            # Internal t1 -> t2 edge: record for in-out renaming, no new edge.
            rdict_inout[conn] = e.src_conn
        else:
            inconn[conn] = t2.in_connectors[e.dst_conn]
            graph.add_edge(e.src, e.src_conn, t1, conn, e.data)

    # Handle output edges.
    outconn = {}
    for e in graph.out_edges(t1):
        outconn[e.src_conn] = t1.out_connectors[e.src_conn]
    for e in graph.out_edges(t2):
        graph.remove_edge(e)
        conn = e.src_conn
        if conn in rdict:
            conn = rdict[conn]
        outconn[conn] = t2.out_connectors[e.src_conn]
        graph.add_edge(t1, conn, e.dst, e.dst_conn, e.data)

    # Rename in-out connectors.
    if rdict_inout:
        replace(t2, rdict_inout)

    # Update t1 connectors and code.
    t1.in_connectors = inconn
    t1.out_connectors = outconn
    if t1.language is dtypes.Language.Python:
        t1.code.code.extend(t2.code.code)
    elif t1.language is dtypes.Language.CPP:
        t1.code.code += f'\n{t2.code.code}'
    graph.remove_node(t2)

    # Fix CPP assignment LHS that are not connectors.
    if t1.language is dtypes.Language.CPP:
        rhs = extract_lhs(t1)
        repl_dict = dict()
        for name in rhs:
            if name not in inconn and name not in outconn:
                repl_dict[name] = f'auto {name} ='
        if repl_dict:
            replace_lhs(t1, repl_dict)
def count_moved_data(sdfg: dace.SDFG, symbols: Dict[str, Any] = None) -> int:
    """Return the total moved data over all states of ``sdfg``.

    :param sdfg: The SDFG whose states are measured.
    :param symbols: Optional symbol values forwarded to the per-state counter.
    :return: Sum of ``count_moved_data_state`` over every state.
    """
    symbols = symbols or {}
    per_state = (count_moved_data_state(state, symbols) for state in sdfg.nodes())
    return sum(per_state)
def _add_ort_init_code(sdfg: SDFG):
    """ Add onnxruntime initialization code to the SDFG if required

    Idempotent: the presence of "OrtKernelSession" in the SDFG's global code is
    used as the marker that setup was already added. Appends global declarations,
    init code (env/session options/CPU provider), and exit-time cleanup; if any
    node is scheduled on GPU, the CUDA provider and its memory infos are added too.
    """
    if "OrtKernelSession" not in sdfg.global_code['frame'].as_string:
        sdfg.append_global_code("""
// Start global ORT setup
const OrtApi* __ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

// helper function to check for status
void __ort_check_status(OrtStatus* status)
{
    if (status != NULL) {
        const char* msg = __ort_api->GetErrorMessage(status);
        fprintf(stderr, "%s\\n", msg);
        __ort_api->ReleaseStatus(status);
        exit(1);
    }
}
OrtEnv* __ort_env;
OrtKernelSession* __ort_session;
OrtSessionOptions* __ort_session_options;

OrtMemoryInfo* __ort_cpu_mem_info;
""")
        sdfg.append_init_code("""
__ort_check_status(__ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &__ort_cpu_mem_info));
__ort_check_status(__ort_api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "dace_graph", &__ort_env));
__ort_check_status(__ort_api->CreateSessionOptions(&__ort_session_options));
__ort_check_status(OrtSessionOptionsAppendExecutionProvider_CPU(__ort_session_options, /*use_arena=*/0));
""")

        session_cleanup_code = """
__ort_api->ReleaseMemoryInfo(__ort_cpu_mem_info);
__ort_api->ReleaseKernelSession(__ort_session);
__ort_api->ReleaseSessionOptions(__ort_session_options);
__ort_api->ReleaseEnv(__ort_env);
"""

        if any(hasattr(node, "schedule") and node.schedule == ScheduleType.GPU_Device
               for state in sdfg.nodes() for node in state.nodes()):
            # if the SDFG contains a GPU node, add the CUDA provider and the memory_info
            sdfg.append_global_code("OrtMemoryInfo* __ort_cuda_mem_info;\n")
            sdfg.append_global_code("OrtMemoryInfo* __ort_cuda_pinned_mem_info;\n")
            sdfg.append_init_code("""
__ort_check_status(__ort_api->CreateMemoryInfo("Cuda", /*allocator_type=*/OrtDeviceAllocator, /*device=*/0, /*mem_type=*/OrtMemTypeDefault, &__ort_cuda_mem_info));
__ort_check_status(__ort_api->CreateMemoryInfo("CudaPinned", /*allocator_type=*/OrtDeviceAllocator, /*device=*/0, /*mem_type=*/OrtMemTypeCPU, &__ort_cuda_pinned_mem_info));
__ort_check_status(OrtSessionOptionsAppendExecutionProvider_CUDA(__ort_session_options, /*device=*/0));
""")
            # GPU memory infos must be released before the CPU-side cleanup.
            session_cleanup_code = ("""
__ort_api->ReleaseMemoryInfo(__ort_cuda_mem_info);
__ort_api->ReleaseMemoryInfo(__ort_cuda_pinned_mem_info);
""" + session_cleanup_code)

        sdfg.append_global_code("// End global ORT setup\n")
        sdfg.prepend_exit_code(session_cleanup_code)
        sdfg.append_init_code("""
__ort_check_status(__ort_api->CreateKernelSession(__ort_session_options, &__ort_session, 12));
""")
def apply(self, sdfg: dace.SDFG):
    """Expand a multi-dimensional map into a chain of nested one-dimensional
    maps: the first parameter stays on the original map, and each remaining
    parameter gets its own sequential inner map.
    """
    # Extract the map and its entry and exit nodes.
    graph = sdfg.nodes()[self.state_id]
    map_entry = graph.nodes()[self.subgraph[MapExpansion._map_entry]]
    map_exit = graph.exit_node(map_entry)
    current_map = map_entry.map

    # Create new maps
    new_maps = [
        nodes.Map(current_map.label + '_' + str(param), [param],
                  subsets.Range([param_range]),
                  schedule=dtypes.ScheduleType.Sequential)
        for param, param_range in zip(current_map.params[1:], current_map.range[1:])
    ]
    # The original map keeps only its first dimension.
    current_map.params = [current_map.params[0]]
    current_map.range = subsets.Range([current_map.range[0]])

    # Create new map entries and exits
    entries = [nodes.MapEntry(new_map) for new_map in new_maps]
    exits = [nodes.MapExit(new_map) for new_map in new_maps]

    # Create edges, abiding by the following rules:
    # 1. If there are no edges coming from the outside, use empty memlets
    # 2. Edges with IN_* connectors replicate along the maps
    # 3. Edges for dynamic map ranges replicate until reaching range(s)
    for edge in graph.out_edges(map_entry):
        graph.remove_edge(edge)
        graph.add_memlet_path(map_entry, *entries, edge.dst,
                              src_conn=edge.src_conn,
                              memlet=edge.data,
                              dst_conn=edge.dst_conn)

    # Modify dynamic map ranges
    dynamic_edges = dace.sdfg.dynamic_map_inputs(graph, map_entry)
    for edge in dynamic_edges:
        # Remove old edge and connector
        graph.remove_edge(edge)
        edge.dst.remove_in_connector(edge.dst_conn)

        # Propagate to each range it belongs to
        path = []
        for mapnode in [map_entry] + entries:
            path.append(mapnode)
            # Re-add the dynamic input only up to the map whose range uses it.
            if any(edge.dst_conn in map(str, symbolic.symlist(r)) for r in mapnode.map.range):
                graph.add_memlet_path(edge.src, *path,
                                      memlet=edge.data,
                                      src_conn=edge.src_conn,
                                      dst_conn=edge.dst_conn)

    # Create new map exits
    for edge in graph.in_edges(map_exit):
        graph.remove_edge(edge)
        graph.add_memlet_path(edge.src, *exits[::-1], map_exit,
                              memlet=edge.data,
                              src_conn=edge.src_conn,
                              dst_conn=edge.dst_conn)
def apply(self, sdfg: dace.SDFG):
    """Distribute a map over a process grid: inputs are broadcast (scalars) or
    block-cyclically scattered (2D arrays) to per-rank local buffers, the map
    iterates over local ranges, and outputs are gathered back.

    NOTE(review): uses symbols ``commsize``, ``Px``, ``Py`` for the process
    grid; assumes 2D arrays divisible into Px-by-Py blocks — confirm upstream
    checks.
    """
    graph = sdfg.nodes()[self.state_id]
    map_entry = graph.nodes()[self.subgraph[self._map_entry]]
    map_exit = graph.exit_node(map_entry)

    sz = dace.symbol('commsize', dtype=dace.int32, integer=True, positive=True)
    Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
    Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)

    def _prod(sequence):
        # Product of a sequence (1 for the empty sequence).
        return reduce(lambda a, b: a * b, sequence, 1)

    # NOTE: Maps with step in their ranges are currently not supported
    if len(map_entry.map.params) == 2:
        # 2D map: split the two dimensions over a Px-by-Py grid.
        params = map_entry.map.params
        ranges = [None] * 2
        b, e, _ = map_entry.map.range[0]
        ranges[0] = (0, (e - b + 1) / Px - 1, 1)
        b, e, _ = map_entry.map.range[1]
        ranges[1] = (0, (e - b + 1) / Py - 1, 1)
        strides = [1]
    else:
        # Otherwise, flatten the map and split it evenly over commsize ranks.
        params = ['__iflat']
        sizes = map_entry.map.range.size_exact()
        total_size = _prod(sizes)
        ranges = [(0, (total_size) / sz - 1, 1)]
        strides = [_prod(sizes[i + 1:]) for i in range(len(sizes))]

    # Scalar holding the broadcast root rank (constant 0).
    root_name = sdfg.temp_data_name()
    sdfg.add_scalar(root_name, dace.int32, transient=True)
    root_node = graph.add_access(root_name)
    root_tasklet = graph.add_tasklet('_set_root_', {}, {'__out'}, '__out = 0')
    graph.add_edge(root_tasklet, '__out', root_node, None, dace.Memlet.simple(root_name, '0'))

    from dace.libraries.mpi import Bcast
    from dace.libraries.pblas import BlockCyclicScatter, BlockCyclicGather

    # Validate and collect the map's input access nodes.
    inputs = set()
    for src, _, _, _, m in graph.in_edges(map_entry):
        if not isinstance(src, nodes.AccessNode):
            raise NotImplementedError
        desc = src.desc(sdfg)
        if not isinstance(desc, (data.Scalar, data.Array)):
            raise NotImplementedError
        if list(desc.shape) != m.src_subset.size_exact():
            # Second attempt
            # TODO: We need a solution for symbols not matching
            if str(list(desc.shape)) != str(m.src_subset.size_exact()):
                raise NotImplementedError
        inputs.add(src)

    for inp in inputs:
        desc = inp.desc(sdfg)
        if isinstance(desc, data.Scalar):
            # Scalars are broadcast to all ranks.
            local_access = graph.add_access(inp.data)
            bcast_node = Bcast('_Bcast_')
            graph.add_edge(inp, None, bcast_node, '_inbuffer', dace.Memlet.from_array(inp.data, desc))
            graph.add_edge(root_node, None, bcast_node, '_root', dace.Memlet.simple(root_name, '0'))
            graph.add_edge(bcast_node, '_outbuffer', local_access, None, dace.Memlet.from_array(inp.data, desc))
            # Reroute the map input through the broadcast result.
            for e in graph.edges_between(inp, map_entry):
                graph.add_edge(local_access, None, map_entry, e.dst_conn, dace.Memlet.from_array(inp.data, desc))
                graph.remove_edge(e)
        elif isinstance(desc, data.Array):
            # Arrays are block-cyclically scattered into a local block.
            local_name, local_arr = sdfg.add_temp_transient(
                [(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                dtype=desc.dtype, storage=desc.storage)
            local_access = graph.add_access(local_name)
            bsizes_name, bsizes_arr = sdfg.add_temp_transient((2, ), dtype=dace.int32)
            bsizes_access = graph.add_access(bsizes_name)
            bsizes_tasklet = nodes.Tasklet(
                '_set_bsizes_', {}, {'__out'},
                "__out[0] = {x}; __out[1] = {y}".format(
                    x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
            graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None, dace.Memlet.from_array(bsizes_name, bsizes_arr))
            # Global/local array descriptors produced by the scatter node.
            gdesc_name, gdesc_arr = sdfg.add_temp_transient((9, ), dtype=dace.int32)
            gdesc_access = graph.add_access(gdesc_name)
            ldesc_name, ldesc_arr = sdfg.add_temp_transient((9, ), dtype=dace.int32)
            ldesc_access = graph.add_access(ldesc_name)
            scatter_node = BlockCyclicScatter('_Scatter_')
            graph.add_edge(inp, None, scatter_node, '_inbuffer', dace.Memlet.from_array(inp.data, desc))
            graph.add_edge(bsizes_access, None, scatter_node, '_block_sizes', dace.Memlet.from_array(bsizes_name, bsizes_arr))
            graph.add_edge(scatter_node, '_outbuffer', local_access, None, dace.Memlet.from_array(local_name, local_arr))
            graph.add_edge(scatter_node, '_gdescriptor', gdesc_access, None, dace.Memlet.from_array(gdesc_name, gdesc_arr))
            graph.add_edge(scatter_node, '_ldescriptor', ldesc_access, None, dace.Memlet.from_array(ldesc_name, ldesc_arr))
            # Reroute the map input through the local block and rename internal memlets.
            for e in graph.edges_between(inp, map_entry):
                graph.add_edge(local_access, None, map_entry, e.dst_conn, dace.Memlet.from_array(local_name, local_arr))
                graph.remove_edge(e)
            for e in graph.out_edges(map_entry):
                if e.data.data == inp.data:
                    e.data.data = local_name
        else:
            raise NotImplementedError

    # Validate and collect the map's output access nodes.
    outputs = set()
    for _, _, dst, _, m in graph.out_edges(map_exit):
        if not isinstance(dst, nodes.AccessNode):
            raise NotImplementedError
        desc = dst.desc(sdfg)
        if not isinstance(desc, data.Array):
            raise NotImplementedError
        try:
            if list(desc.shape) != m.dst_subset.size_exact():
                # Second attempt
                # TODO: We need a solution for symbols not matching
                if str(list(desc.shape)) != str(m.dst_subset.size_exact()):
                    raise NotImplementedError
        except AttributeError:
            # Memlet may not expose dst_subset; fall back to the plain subset.
            if list(desc.shape) != m.subset.size_exact():
                # Second attempt
                # TODO: We need a solution for symbols not matching
                if str(list(desc.shape)) != str(m.subset.size_exact()):
                    raise NotImplementedError
        outputs.add(dst)

    for out in outputs:
        desc = out.desc(sdfg)
        if isinstance(desc, data.Scalar):
            raise NotImplementedError
        elif isinstance(desc, data.Array):
            # Write into a local block, then gather into the global output.
            local_name, local_arr = sdfg.add_temp_transient(
                [(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                dtype=desc.dtype, storage=desc.storage)
            local_access = graph.add_access(local_name)
            bsizes_name, bsizes_arr = sdfg.add_temp_transient((2, ), dtype=dace.int32)
            bsizes_access = graph.add_access(bsizes_name)
            bsizes_tasklet = nodes.Tasklet(
                '_set_bsizes_', {}, {'__out'},
                "__out[0] = {x}; __out[1] = {y}".format(
                    x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
            graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None, dace.Memlet.from_array(bsizes_name, bsizes_arr))
            scatter_node = BlockCyclicGather('_Gather_')
            graph.add_edge(local_access, None, scatter_node, '_inbuffer', dace.Memlet.from_array(local_name, local_arr))
            graph.add_edge(bsizes_access, None, scatter_node, '_block_sizes', dace.Memlet.from_array(bsizes_name, bsizes_arr))
            graph.add_edge(scatter_node, '_outbuffer', out, None, dace.Memlet.from_array(out.data, desc))
            # Reroute the map output through the local block and rename internal memlets.
            for e in graph.edges_between(map_exit, out):
                graph.add_edge(map_exit, e.src_conn, local_access, None, dace.Memlet.from_array(local_name, local_arr))
                graph.remove_edge(e)
            for e in graph.in_edges(map_exit):
                if e.data.data == out.data:
                    e.data.data = local_name
        else:
            raise NotImplementedError

    # Finally, switch the map to the distributed (local) iteration space.
    map_entry.map.params = params
    map_entry.map.range = subsets.Range(ranges)
def apply(self, sdfg: dace.SDFG):
    """Constant-fold an ONNX node: compute its result at transformation time,
    store it as a weight of the parent ONNX model, wire an access node with the
    constant into the graph, and remove the now-dead producer subgraph.
    """
    # Extract the subgraph, execute it and insert an AccessNode to the result
    parent: ONNXModel = sdfg._parent_onnx_model
    state = sdfg.nodes()[self.state_id]
    node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]

    if isinstance(node, donnx.ONNXShape):
        # if we have a shape node, replace it with a constant
        assert len(state.in_edges(node)) == 1
        shape_in_edge = state.in_edges(node)[0]
        assert shape_in_edge.dst_conn == "data"
        shape_desc = sdfg.arrays[shape_in_edge.src.data]

        constant_name = sdfg.temp_data_name()
        clean_constant_name = clean_onnx_name(constant_name)
        sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ), dace.int64)

        assert constant_name not in parent.clean_weights
        # The shape itself becomes the constant value.
        parent.weights[constant_name] = np.array(shape_desc.shape, np.int64)

        assert len(state.out_edges(node)) == 1
        output_edge = state.out_edges(node)[0]
        access_shape = state.add_access(clean_constant_name)
        state.add_edge(access_shape, None, output_edge.dst, output_edge.dst_conn,
                       sdfg.make_array_memlet(clean_constant_name))
    else:
        # otherwise compute the result of the op by running it in a fresh sub-SDFG
        sub_sdfg = dace.SDFG("sub_sdfg")
        sub_state = sub_sdfg.add_state()

        node_copy = copy.deepcopy(node)
        sub_state.add_node(node_copy)

        inputs = {}
        for edge in state.in_edges(node):
            # we know from can_be_applied that all in edges are from AccessNodes
            assert (isinstance(edge.src, nd.AccessNode)
                    and hasattr(sdfg, "_parent_onnx_model")
                    and edge.src.data in sdfg._parent_onnx_model.clean_weights)
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

            input_value = sdfg._parent_onnx_model.clean_weights[edge.src.data]
            if len(input_value.shape) == 0:
                # 0-d array: pass the scalar value itself.
                inputs['array_' + edge.dst_conn] = input_value[()]
            else:
                inputs['array_' + edge.dst_conn] = input_value.copy()

            access = sub_state.add_access('array_' + edge.dst_conn)
            sub_state.add_edge(access, None, node_copy, edge.dst_conn,
                               sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

        outputs = {}
        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            if isinstance(desc, dt.Scalar):
                # we need to copy to an array of size [1] so that we can "return" the output from the sdfg
                desc.transient = True
                sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn, desc)
                sub_sdfg.add_array('array_' + edge.src_conn, [1], desc.dtype, transient=False)

                access_scalar = sub_state.add_access('scalar_array_' + edge.src_conn)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(node_copy, edge.src_conn, access_scalar, None,
                                   sub_sdfg.make_array_memlet('scalar_array_' + edge.src_conn))
                sub_state.add_edge(access_scalar, None, access, None,
                                   sub_sdfg.make_array_memlet('array_' + edge.src_conn))
            else:
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(node_copy, edge.src_conn, access, None,
                                   sub_sdfg.make_array_memlet('array_' + edge.src_conn))

            # Preallocate the host buffer the sub-SDFG writes into.
            if len(desc.shape) == 0:
                outputs['array_' + edge.src_conn] = np.empty((1, ), desc.dtype.as_numpy_dtype())
            else:
                outputs['array_' + edge.src_conn] = np.empty(tuple(desc.shape), desc.dtype.as_numpy_dtype())

        # Execute the folded computation now, at transformation time.
        sub_sdfg(**outputs, **inputs)

        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            output_value = outputs['array_' + edge.src_conn]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_datadesc(clean_constant_name, desc)

            assert constant_name not in parent.weights
            if isinstance(desc, dt.Scalar):
                parent.weights[constant_name] = output_value.reshape(())
            else:
                parent.weights[constant_name] = output_value

            access_constant = state.add_access(clean_constant_name)
            state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))

    # remove all now useless nodes with a reverse BFS
    queue = deque([node])
    while len(queue) > 0:
        current_node = queue.popleft()
        edges = state.in_edges(current_node)
        state.remove_node(current_node)
        for e in edges:
            next_node = e.src
            if len(state.out_edges(next_node)) == 0:
                queue.append(next_node)
def apply(self, sdfg: SDFG) -> Union[Any, None]:
    """
    Replace a single direct copy edge between a multi-bank (HBM) array and a
    "normal" array with an unrolled map of per-bank copies.

    The array that has one extra leading dimension is treated as the banked
    one; if the source carries it, data is gathered from the banks into the
    destination (``collect_src``), otherwise it is distributed to the banks.

    :param sdfg: The SDFG containing the matched src/dst access nodes.
    """
    # Load/parse infos from the SDFG
    graph = sdfg.nodes()[self.state_id]
    src = graph.nodes()[self.subgraph[BankSplit._src_node]]
    dst = graph.nodes()[self.subgraph[BankSplit._dst_node]]
    src_array = sdfg.arrays[src.data]
    dst_array = sdfg.arrays[dst.data]
    # The banked array has exactly one dimension more than the other one.
    collect_src = len(src_array.shape) - 1 == len(
        dst_array.shape
    )  # If this is not true we have to distribute to dst (checked in can_apply)
    if collect_src:
        bank_count = int(src_array.shape[0])
        true_size = dst_array.shape
    else:
        bank_count = int(dst_array.shape[0])
        true_size = src_array.shape
    ndim = len(true_size)

    # Move Default storage to the configured storage type for both arrays
    if sdfg.arrays[src.data].storage == dtypes.StorageType.Default:
        sdfg.arrays[src.data].storage = self.default_to_storage
    if sdfg.arrays[dst.data].storage == dtypes.StorageType.Default:
        sdfg.arrays[dst.data].storage = self.default_to_storage

    # Figure out how to split: default is to split only the first dimension
    # across all banks; otherwise use the user-provided per-dimension splits.
    if self.split_array_info is None:
        split_info = [1] * ndim
        split_info[0] = bank_count
    else:
        split_info = self.split_array_info
        if len(split_info) != ndim:
            raise RuntimeError(
                "Length of split_array_info must match number of "
                "dimensions")
    # The product of per-dimension splits must cover every bank exactly once.
    if functools.reduce(lambda a, b: a * b, split_info) != bank_count:
        raise RuntimeError(
            "Splitting is not possible with the selected splits"
            "and this number of HBM-banks (required number of banks "
            "!= actual number of banks)")

    # create the copy-subgraph: an unrolled map with one parameter per
    # dimension, iterating over the number of splits in that dimension
    ndrange = dict()
    usable_params = []
    for i in range(ndim):
        usable_params.append(f"i{i}")
    for i in range(ndim):
        ndrange[usable_params[i]] = f"0:{split_info[i]}"
    graph.remove_edge_and_connectors(graph.edges_between(src, dst)[0])
    copy_map_enter, copy_map_exit = graph.add_map(
        "hbm_bank_split", ndrange, dtypes.ScheduleType.Unrolled)
    # Empty memlets only establish scope containment of src/dst in the map
    graph.add_edge(copy_map_enter, None, src, None, memlet.Memlet())
    graph.add_edge(dst, None, copy_map_exit, None, memlet.Memlet())

    # Size of one per-bank chunk in each dimension (as strings for memlets)
    target_size = [
        str(x) for x in self._get_split_size(true_size, split_info)
    ]
    # Linearize the multi-dimensional split index into a single bank index:
    # after this loop, summing the entries yields a row-major flattening of
    # (i0, i1, ..., i{ndim-1}) over split_info (in-place scaling of earlier
    # terms by each later dimension's split count).
    target_hbm_bank = []
    for i in range(ndim):
        target_hbm_bank.append(usable_params[i])
        for j in range(i):
            target_hbm_bank[j] = f"{split_info[i]}*{target_hbm_bank[j]}"
    # Offset of the current chunk inside the non-banked array
    target_offset = []
    for i in range(ndim):
        target_offset.append(f"{usable_params[i]}*{target_size[i]}")

    # Assemble the memlet subset strings
    target_size_str = ", ".join(
        [f"{x}:{y}" for x, y in zip([0] * ndim, target_size)])
    target_hbm_bank_str = "+ ".join(target_hbm_bank)
    target_offset_str = ", ".join(
        [f"({x}):({x}+{y})" for x, y in zip(target_offset, target_size)])
    # Source subset -> destination subset; the banked side is indexed by the
    # linearized bank expression followed by the chunk-local range.
    if collect_src:
        copy_memlet = memlet.Memlet(
            f"{src.data}[{target_hbm_bank_str}, {target_size_str}]->"
            f"{target_offset_str}")
    else:
        copy_memlet = memlet.Memlet(
            f"{src.data}[{target_offset_str}]->{target_hbm_bank_str}, "
            f"{target_size_str}")
    graph.add_edge(src, None, dst, None, copy_memlet)
def apply(self, sdfg: dace.SDFG):
    """
    Expand a multi-dimensional map into a chain of nested one-dimensional
    maps (one map per original parameter), rewiring all edges through the
    new entry/exit chains and removing the original entry/exit nodes.
    """
    # Extract the map and its entry and exit nodes.
    graph = sdfg.nodes()[self.state_id]
    map_entry = graph.nodes()[self.subgraph[MapExpansion._map_entry]]
    map_exit = graph.exit_nodes(map_entry)[0]
    current_map = map_entry.map

    # Create new maps: one single-parameter map per original dimension,
    # scheduled sequentially (the outermost keeps the default schedule below)
    maps = [
        nodes.Map(current_map.label + '_' + str(param), [param],
                  subsets.Range([param_range]),
                  schedule=dtypes.ScheduleType.Sequential)
        for param, param_range in zip(current_map.params, current_map.range)
    ]
    # NOTE(review): writes the private backing field directly, presumably to
    # bypass schedule-setter logic — confirm against the Map class.
    maps[0]._schedule = dtypes.ScheduleType.Default

    # Create new map entries; the outermost inherits the original connectors
    entries = [nodes.MapEntry(new_map) for new_map in maps]
    entries[0].in_connectors = map_entry.in_connectors
    entries[0].out_connectors = map_entry.out_connectors
    num_entry_out_edges = len(graph.out_edges(map_entry))
    for i in range(1, len(entries)):
        # The generator's `i` intentionally shadows the outer loop index;
        # inner entries get one IN_k/OUT_k pair per pass-through edge.
        entries[i].in_connectors = set('IN_' + str(i + 1)
                                       for i in range(num_entry_out_edges))
        entries[i].out_connectors = set(
            'OUT_' + str(i + 1) for i in range(num_entry_out_edges))

    # Create new map exits; after reversal exits[0] is the innermost and
    # exits[-1] (outermost) inherits the original exit's connectors
    exits = [nodes.MapExit(new_map) for new_map in maps]
    exits.reverse()
    exits[-1].in_connectors = map_exit.in_connectors
    exits[-1].out_connectors = map_exit.out_connectors
    # NOTE(review): connector count for the exit chain is taken from
    # out_edges(map_exit), while the rewiring below iterates in_edges —
    # verify these counts always match for map exits.
    num_entry_out_edges = len(graph.out_edges(map_exit))
    for i in range(0, len(exits) - 1):
        exits[i].in_connectors = set('IN_' + str(i + 1)
                                     for i in range(num_entry_out_edges))
        exits[i].out_connectors = set('OUT_' + str(i + 1)
                                      for i in range(num_entry_out_edges))

    # Add new nodes to state
    graph.add_nodes_from(entries)
    graph.add_nodes_from(exits)

    # Redirect edges to new nodes: incoming edges go to the outermost entry,
    # outgoing edges leave from the outermost exit
    dace.graph.nxutil.change_edge_dest(graph, map_entry, entries[0])
    dace.graph.nxutil.change_edge_src(graph, map_exit, exits[-1])

    # Thread each former entry out-edge through the whole entry chain
    for i, e in enumerate(graph.out_edges(map_entry)):
        graph.remove_edge(e)
        graph.add_edge(entries[0], e.src_conn, entries[1],
                       'IN_' + str(i + 1), copy.deepcopy(e.data))
        graph.add_edge(entries[-1], 'OUT_' + str(i + 1), e.dst, e.dst_conn,
                       copy.deepcopy(e.data))
        for j in range(1, len(entries) - 1):
            graph.add_edge(entries[j], 'OUT_' + str(i + 1), entries[j + 1],
                           'IN_' + str(i + 1), copy.deepcopy(e.data))

    # Thread each former exit in-edge through the whole exit chain
    for i, e in enumerate(graph.in_edges(map_exit)):
        graph.remove_edge(e)
        graph.add_edge(e.src, e.src_conn, exits[0], 'IN_' + str(i + 1),
                       copy.deepcopy(e.data))
        graph.add_edge(exits[-2], 'OUT_' + str(i + 1), exits[-1], e.dst_conn,
                       copy.deepcopy(e.data))
        for j in range(0, len(exits) - 2):
            graph.add_edge(exits[j], 'OUT_' + str(i + 1), exits[j + 1],
                           'IN_' + str(i + 1), copy.deepcopy(e.data))

    # Remove old nodes
    graph.remove_node(map_entry)
    graph.remove_node(map_exit)
def apply_pass(self, sdfg: SDFG, _,
               initial_symbols: Optional[Dict[str, Any]] = None
               ) -> Optional[Set[str]]:
    """
    Propagates constants throughout the SDFG.

    :param sdfg: The SDFG to modify.
    :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass
                             results as ``{Pass subclass name: returned object from pass}``. If not run in a
                             pipeline, an empty dictionary is expected.
    :param initial_symbols: If not None, sets values of initial symbols.
    :return: A set of propagated constants, or None if nothing was changed.
             When ``self.recursive`` is set, the set instead contains
             ``(sdfg_id, symbol_name)`` tuples covering nested SDFGs.
    """
    initial_symbols = initial_symbols or {}

    # Early exit if no constants can be propagated
    if not initial_symbols and not self.should_apply(sdfg):
        result = {}
    else:
        # Trace all constants and symbols through states
        per_state_constants: Dict[SDFGState,
                                  Dict[str, Any]] = self.collect_constants(
                                      sdfg, initial_symbols)

        # Keep track of replaced and ambiguous symbols
        symbols_replaced: Dict[str, Any] = {}
        remaining_unknowns: Set[str] = set()

        # Collect symbols from symbol-dependent data descriptors
        # If there can be multiple values over the SDFG, the symbols are not propagated
        desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(
            sdfg, per_state_constants)

        # Replace constants per state
        for state, mapping in per_state_constants.items():
            # Symbols that are unknown in any state (or ambiguous in a data
            # descriptor) must survive; record them so they are excluded below
            remaining_unknowns.update({
                k
                for k, v in mapping.items()
                if v is _UnknownValue or k in multivalue_desc_symbols
            })
            mapping = {
                k: v
                for k, v in mapping.items()
                if v is not _UnknownValue and k not in multivalue_desc_symbols
            }

            # Update replaced symbols for later replacements
            symbols_replaced.update(mapping)

            # Replace in state contents
            state.replace_dict(mapping)
            # Replace in outgoing edges as well
            for e in sdfg.out_edges(state):
                e.data.replace_dict(mapping, replace_keys=False)

        # If symbols are never unknown any longer, remove from SDFG
        result = {
            k: v
            for k, v in symbols_replaced.items()
            if k not in remaining_unknowns
        }
        # Remove from symbol repository
        for sym in result:
            if sym in sdfg.symbols:
                sdfg.remove_symbol(sym)

        # Remove single-valued symbols from data descriptors (e.g., symbolic array size)
        sdfg.replace_dict(
            {k: v
             for k, v in result.items() if k in desc_symbols},
            replace_in_graph=False,
            replace_keys=False)

        # Remove constant symbol assignments in interstate edges
        # (set intersection via the dict-keys view: yields propagated symbols
        # that are also assigned on this edge)
        for edge in sdfg.edges():
            intersection = result & edge.data.assignments.keys()
            for sym in intersection:
                del edge.data.assignments[sym]

    result = set(result.keys())

    if self.recursive:
        # Change result to set of tuples
        sid = sdfg.sdfg_id
        result = set((sid, sym) for sym in result)

        for state in sdfg.nodes():
            for node in state.nodes():
                if isinstance(node, nodes.NestedSDFG):
                    nested_id = node.sdfg.sdfg_id
                    # Only symbol-mapping entries with concrete (non-symbolic)
                    # values can seed propagation inside the nested SDFG
                    const_syms = {
                        k: v
                        for k, v in node.symbol_mapping.items()
                        if not symbolic.issymbolic(v)
                    }
                    internal = self.apply_pass(node.sdfg, _, const_syms)
                    if internal:
                        for nid, removed in internal:
                            result.add((nid, removed))
                            # Remove symbol mapping if constant was completely propagated
                            if nid == nested_id and removed in node.symbol_mapping:
                                del node.symbol_mapping[removed]

    # Return result
    if not result:
        return None
    return result
def find_dead_states(
        self,
        sdfg: SDFG,
        set_unconditional_edges: bool = True) -> Set[SDFGState]:
    '''
    Computes the set of unreachable ("dead") states of an SDFG.

    A state is considered dead when no execution can reach it, i.e., when:

      * it is not reachable from the start state at all,
      * every condition on a path leading to it always evaluates to False, or
      * a sibling inter-state edge is always True, so the remaining edges out
        of the predecessor are never taken.

    :param sdfg: The SDFG to traverse.
    :param set_unconditional_edges: If True, conditions of edges evaluated as
                                    unconditional are removed.
    :return: A set of unreachable states.
    '''
    live: Set[SDFGState] = set()

    # Modified BFS from the start state: provably-False edges are never
    # followed, and if a provably-True edge leaves a state, only that edge
    # is followed. Everything BFS never touches is dead.
    worklist = collections.deque([sdfg.start_state])
    while worklist:
        current = worklist.popleft()
        if current in live:
            continue
        live.add(current)

        # Pass 1: look for a provably-True ("unconditional") outgoing edge.
        always_taken = None
        for edge in sdfg.out_edges(current):
            if not self.is_definitely_taken(edge.data):
                continue
            # Two unconditional edges out of one state make the SDFG invalid.
            if always_taken is not None:
                raise InvalidSDFGInterstateEdgeError(
                    'Multiple unconditional edges leave the same state',
                    sdfg, sdfg.edge_id(edge))
            always_taken = edge
            if set_unconditional_edges and not edge.data.is_unconditional():
                # Annotate the (provably True) edge as truly unconditional.
                edge.data.condition = CodeBlock('1')
            if edge.dst not in live:
                worklist.append(edge.dst)

        if always_taken is not None:
            # The unconditional edge shadows all other outgoing edges.
            continue

        # Pass 2: regular traversal, skipping provably-False edges.
        for edge in sdfg.out_edges(current):
            if self.is_definitely_not_taken(edge.data):
                continue
            if edge.dst not in live:
                worklist.append(edge.dst)

    # Dead states are exactly the states BFS never visited.
    return set(sdfg.nodes()) - live
def apply(self, sdfg: dace.SDFG):
    """
    Fold a constant-input ONNX node: evaluate its result ahead of time and
    replace the node with an access to a newly registered constant.

    Shape nodes are replaced directly with the (statically known) shape;
    any other node is deep-copied into a temporary sub-SDFG, executed with
    its constant inputs, and its outputs stored as weights of the parent
    ONNX model.
    """
    # Extract the subgraph, execute it and insert an AccessNode to the result
    # this method of execution is slow but simple. A better option would be to call the ORT
    # C API from a python object (like the OpChecker).
    parent: ONNXModel = sdfg._parent_onnx_model
    state = sdfg.nodes()[self.state_id]
    node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]
    log.debug(f"Applying constant folding: {node} in {state}")

    if isinstance(node, donnx.ONNXShape):
        # if we have a shape node, replace it with a constant
        assert len(state.in_edges(node)) == 1
        shape_in_edge = state.in_edges(node)[0]
        assert shape_in_edge.dst_conn == "data"
        shape_desc = sdfg.arrays[shape_in_edge.src.data]

        # Register a fresh 1-D int64 array holding the static shape
        constant_name = sdfg.temp_data_name()
        clean_constant_name = clean_onnx_name(constant_name)
        sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                       dace.int64)

        assert constant_name not in parent.clean_weights
        parent.weights[constant_name] = torch.from_numpy(
            np.array(shape_desc.shape, np.int64))

        # Rewire the single consumer to read the new constant instead
        assert len(state.out_edges(node)) == 1
        output_edge = state.out_edges(node)[0]
        access_shape = state.add_access(clean_constant_name)
        state.add_edge(access_shape, None, output_edge.dst,
                       output_edge.dst_conn,
                       sdfg.make_array_memlet(clean_constant_name))
    else:
        # otherwise compute the result of the op
        global UNIQUE_ID
        UNIQUE_ID += 1
        sub_sdfg = dace.SDFG("sub_sdfg_" + str(UNIQUE_ID))
        sub_state = sub_sdfg.add_state()

        node_copy = copy.deepcopy(node)
        sub_state.add_node(node_copy)

        # Build the sub-SDFG inputs from the node's constant input weights
        inputs = {}
        for edge in state.in_edges(node):
            # we know from can_be_applied that all in edges are from AccessNodes
            assert (isinstance(edge.src, nd.AccessNode)
                    and hasattr(sdfg, "_parent_onnx_model")
                    and edge.src.data in sdfg._parent_onnx_model.clean_weights)
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

            input_value = sdfg._parent_onnx_model.clean_weights[
                edge.src.data]
            if len(input_value.shape) == 0:
                # 0-d tensor: pass as a plain numpy scalar
                inputs['array_' + edge.dst_conn] = input_value.cpu().numpy()[()]
            else:
                inputs['array_' + edge.dst_conn] = input_value.clone()

            access = sub_state.add_access('array_' + edge.dst_conn)
            sub_state.add_edge(
                access, None, node_copy, edge.dst_conn,
                sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

        # Build the sub-SDFG outputs and preallocate host/GPU buffers
        outputs = {}
        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            if isinstance(desc, dt.Scalar):
                # we need to copy to an array of size [1] so that we can "return" the output from the sdfg
                desc.transient = True
                sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn, desc)
                sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                   desc.dtype,
                                   transient=False)
                access_scalar = sub_state.add_access('scalar_array_' +
                                                     edge.src_conn)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access_scalar, None,
                    sub_sdfg.make_array_memlet('scalar_array_' +
                                               edge.src_conn))
                sub_state.add_edge(
                    access_scalar, None, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))
            else:
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))

            # Preallocate the output buffer ([1] for 0-d results, see above)
            if len(desc.shape) == 0:
                empty_array = np.empty((1, ), desc.dtype.as_numpy_dtype())
            else:
                empty_array = np.empty(tuple(desc.shape),
                                       desc.dtype.as_numpy_dtype())

            empty_array = torch.from_numpy(empty_array)

            if desc.storage is dtypes.StorageType.GPU_Global:
                empty_array = empty_array.cuda()

            outputs['array_' + edge.src_conn] = empty_array

        # Execute the isolated sub-SDFG; results land in `outputs`
        sub_sdfg(**outputs, **inputs)

        # Register each computed output as a constant in the parent SDFG
        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            output_value = outputs['array_' + edge.src_conn]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_datadesc(clean_constant_name, desc)

            assert constant_name not in parent.weights
            assert type(output_value) is torch.Tensor

            if not dtypes.can_access(dtypes.ScheduleType.CPU_Multicore,
                                     desc.storage):
                # Constant lives in storage the host cannot touch directly:
                # stage it through a CPU-heap copy and store the weight under
                # the staging name instead.
                cpu_desc = copy.deepcopy(desc)
                cpu_desc.storage = dtypes.StorageType.CPU_Heap
                cpu_desc.transient = False
                desc.transient = True
                copy_in_name = sdfg.temp_data_name()
                clean_copy_in_name = clean_onnx_name(copy_in_name)
                sdfg.add_datadesc(clean_copy_in_name, cpu_desc)

                access_constant = state.add_access(clean_constant_name)
                state.add_edge(state.add_read(clean_copy_in_name), None,
                               access_constant, None,
                               sdfg.make_array_memlet(clean_copy_in_name))

                name_to_add = copy_in_name
            else:
                access_constant = state.add_read(clean_constant_name)
                name_to_add = constant_name

            if isinstance(desc, dt.Scalar):
                # Undo the [1]-array wrapping applied for scalar outputs
                parent.weights[name_to_add] = output_value.reshape(())
            else:
                parent.weights[name_to_add] = output_value

            state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))

    # remove all now useless nodes with a reverse BFS
    remove_node_and_computation(sdfg, state, node)