def apply(self, graph: SDFGState, sdfg: SDFG):
    """
    Nests the matched subgraph into a nested SDFG, applies
    ``GPUTransformSDFG`` to that nested SDFG, and then simplifies the outer
    SDFG.

    For expression 0 the nested subgraph is the scope of the matched map
    entry; for any other expression it is the single matched reduce node.
    The ``register_trans``, ``sequential_innermaps`` and ``toplevel_trans``
    options of this transformation are forwarded to ``GPUTransformSDFG``.
    """
    if self.expr_index == 0:
        to_nest = graph.scope_subgraph(self.map_entry)
    else:
        to_nest = SubgraphView(graph, [self.reduce])
    nsdfg_node = helpers.nest_state_subgraph(sdfg, graph, to_nest, full_data=self.fullcopy)

    # Avoiding import loops
    from dace.transformation.interstate import GPUTransformSDFG

    gpu_transform = GPUTransformSDFG(sdfg, 0, -1, {}, 0)
    # Forward the relevant options of this transformation to the inner one
    for option in ('register_trans', 'sequential_innermaps', 'toplevel_trans'):
        setattr(gpu_transform, option, getattr(self, option))

    nested = nsdfg_node.sdfg
    gpu_transform.apply(nested, nested)

    # Inline back as necessary
    sdfg.simplify()
def dry_run(sdfg: SDFG, *args, **kwargs) -> Any:
    """
    Runs ``sdfg`` once with data instrumentation enabled, unless a usable
    instrumented-data report from a previous run already exists.

    Positional arguments are matched, in order, to ``sdfg.arg_names`` and
    merged into the keyword arguments. An existing report is discarded (via
    ``sdfg.clear_data_reports()``) when one of its arrays has a shape that
    differs from the SDFG's descriptor AND from the shape of the argument
    given for it in this call.

    :param sdfg: The SDFG to (possibly) run.
    :param args: Positional arguments, matched to ``sdfg.arg_names``.
    :param kwargs: Keyword arguments passed to the SDFG call.
    :return: The SDFG's return value if it was executed, or None if an
             existing report was reused and nothing was run.
    """
    # Check existing instrumented data for shape mismatch
    kwargs.update(zip(sdfg.arg_names, args))

    dreport = sdfg.get_instrumented_data()
    if dreport is not None:
        for data in dreport.keys():
            rep_arr = dreport.get_first_version(data)
            sdfg_arr = sdfg.arrays[data]
            # Potential shape mismatch
            if rep_arr.shape != sdfg_arr.shape:
                # Check given data first. ``.get`` so that a report entry with
                # no matching argument is treated like an argument without a
                # shape instead of raising KeyError.
                given = kwargs.get(data)
                if hasattr(given, 'shape') and rep_arr.shape != given.shape:
                    sdfg.clear_data_reports()
                    dreport = None
                    break

    if dreport is not None:
        # A valid report is available; nothing needs to run
        return None

    # No valid instrumented data available yet: run in data instrumentation mode
    for state in sdfg.nodes():
        for node in state.nodes():
            if isinstance(node, dace.nodes.AccessNode) and not node.desc(sdfg).transient:
                node.instrument = dace.DataInstrumentationType.Save
    try:
        result = sdfg(**kwargs)
    finally:
        # Disable data instrumentation from now on, even if the run raised
        for state in sdfg.nodes():
            for node in state.nodes():
                if isinstance(node, dace.nodes.AccessNode):
                    node.instrument = dace.DataInstrumentationType.No_Instrumentation
    return result
def fuse_states(sdfg: SDFG, strict: bool = True, progress: bool = False) -> int:
    """
    Fuses all possible states of an SDFG (and all sub-SDFGs) using an
    optimized routine that uses the structure of the StateFusion
    transformation.

    Each (sub-)SDFG is swept over its state edges repeatedly. Within a
    single sweep a state takes part in at most one fusion, and sweeping
    stops once a whole sweep fuses nothing.

    :param sdfg: The SDFG to transform.
    :param strict: If True (default), operates in strict mode.
    :param progress: If True, prints out a progress bar of fusion (may be
                     inaccurate, requires ``tqdm``)
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    total_fused = 0
    if progress:
        from tqdm import tqdm
        bar = tqdm(total=sum(sd.number_of_edges() for sd in sdfg.all_sdfgs_recursive()))

    for subsdfg in sdfg.all_sdfgs_recursive():
        sdfg_id = subsdfg.sdfg_id
        keep_sweeping = True
        while keep_sweeping:
            keep_sweeping = False
            # States already fused in this sweep; their edges are now stale
            touched = set()
            for first, second in list(subsdfg.nx.edges):
                if first in touched or second in touched:
                    continue
                candidate = {StateFusion.first_state: first, StateFusion.second_state: second}
                fusion = StateFusion(sdfg_id, -1, candidate, 0, override=True)
                if not fusion.can_be_applied(subsdfg, candidate, 0, subsdfg, strict=strict):
                    continue
                fusion.apply(subsdfg)
                keep_sweeping = True
                total_fused += 1
                if progress:
                    bar.update(1)
                touched.update((first, second))

    if progress:
        bar.close()
    if config.Config.get_bool('debugprint'):
        print(f'Applied {total_fused} State Fusions')
    return total_fused
def fuse_states(sdfg: SDFG) -> int:
    """
    Fuses all possible states of an SDFG (and all sub-SDFGs) using an
    optimized routine that uses the structure of the StateFusion
    transformation.

    Each (sub-)SDFG is swept over its state edges until a full sweep fuses
    nothing. Within one sweep each state participates in at most one fusion.
    Applicability is always checked with ``strict=True``.

    :param sdfg: The SDFG to transform.
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    total_fused = 0
    for subsdfg in sdfg.all_sdfgs_recursive():
        sdfg_id = subsdfg.sdfg_id
        keep_sweeping = True
        while keep_sweeping:
            keep_sweeping = False
            # States already fused in this sweep; their edges are now stale
            touched = set()
            for first, second in list(subsdfg.nx.edges):
                if first in touched or second in touched:
                    continue
                candidate = {StateFusion.first_state: first, StateFusion.second_state: second}
                fusion = StateFusion(sdfg_id, -1, candidate, 0, override=True)
                if not fusion.can_be_applied(subsdfg, candidate, 0, subsdfg, strict=True):
                    continue
                fusion.apply(subsdfg)
                keep_sweeping = True
                total_fused += 1
                touched.update((first, second))

    if config.Config.get_bool('debugprint'):
        print(f'Applied {total_fused} State Fusions')
    return total_fused
def apply(self, graph: SDFGState, sdfg: SDFG):
    """
    Replaces two transposed operands of a matrix product by a single
    transpose of the product with swapped operands.

    The inputs of ``transpose_a`` are connected to the product's ``_b``
    connector and the inputs of ``transpose_b`` to its ``_a`` connector. The
    two transpose nodes and the matched ``at``/``bt`` nodes are then removed.
    The product's output is routed through a new temporary transient into a
    new ``blas.Transpose`` node, which takes over all former consumers of the
    product. (Presumably this applies the identity A^T B^T = (B A)^T.)
    """
    import dace.libraries.blas as blas

    # Fetch every matched node before the graph is modified
    transpose_a, at_node = self.transpose_a, self.at
    transpose_b, bt_node = self.transpose_b, self.bt
    product = self.a_times_b

    # Feed each transpose's inputs straight into the product, swapping the
    # operand connectors
    for transpose_node, product_conn in ((transpose_a, '_b'), (transpose_b, '_a')):
        for src, src_conn, _, _, in_memlet in graph.in_edges(transpose_node):
            graph.add_edge(src, src_conn, product, product_conn, in_memlet)
        graph.remove_node(transpose_node)
    graph.remove_node(at_node)
    graph.remove_node(bt_node)

    # The temporary holds the untransposed product: swap the first two
    # dimensions of the (squeezed) output subset
    for _, _, _, _, out_memlet in graph.out_edges(product):
        squeezed = dcpy(out_memlet.subset)
        squeezed.squeeze()
        dims = squeezed.size()
        shape = [dims[1], dims[0]]
        break

    dtype = product.dtype
    tmp_name, tmp_arr = sdfg.add_temp_transient(shape, dtype)
    tmp_acc = graph.add_access(tmp_name)
    transpose_c = blas.Transpose('_Transpose_', dtype)

    # Move every consumer of the product to the output of the new transpose
    for out_edge in graph.out_edges(product):
        graph.remove_edge(out_edge)
        graph.add_edge(transpose_c, '_out', out_edge.dst, out_edge.dst_conn, out_edge.data)

    # product -> temporary -> transpose
    graph.add_edge(product, '_c', tmp_acc, None, dace.Memlet.from_array(tmp_name, tmp_arr))
    graph.add_edge(tmp_acc, None, transpose_c, '_inp', dace.Memlet.from_array(tmp_name, tmp_arr))
def structured_control_flow_tree(sdfg: SDFG, dispatch_state: Callable[[SDFGState], str]) -> ControlFlow:
    """
    Returns a structured control-flow tree (i.e., with constructs such as
    branches and loops) from an SDFG, which can be used to generate its code
    in a compiler- and human-friendly way.

    A state with several outgoing edges is recorded as a branch with a merge
    state when the union of the acyclic dominance frontiers of its successors
    is a single state. A successor with an empty frontier contributes itself
    to that union. Two-edge states where exactly one successor has the state
    as its parent in the state parent tree are treated as natural loops and
    are not recorded as branches.

    :param sdfg: The SDFG to iterate over.
    :param dispatch_state: A function that dispatches code generation for a
                           single state.
    :return: Control-flow block representing the entire SDFG.
    """
    # Avoid import loops
    from dace.sdfg.analysis import cfg

    # Get parent states and back-edges
    ptree = cfg.state_parent_tree(sdfg)
    back_edges = cfg.back_edges(sdfg)

    # Annotate branches
    branch_merges: Dict[SDFGState, SDFGState] = {}
    adf = cfg.acyclic_dominance_frontier(sdfg)
    for state in sdfg.nodes():
        oedges = sdfg.out_edges(state)
        if len(oedges) <= 1:
            continue  # Not a branch

        if len(oedges) == 2:
            first_is_child = ptree[oedges[0].dst] == state
            second_is_child = ptree[oedges[1].dst] == state
            if first_is_child != second_is_child:
                continue  # Natural loop

        merge_candidates = set()
        for oedge in oedges:
            merge_candidates |= adf[oedge.dst] or {oedge.dst}
        if len(merge_candidates) == 1:
            branch_merges[state] = next(iter(merge_candidates))

    root_block = GeneralBlock(dispatch_state, [], [])
    _structured_control_flow_traversal(sdfg, sdfg.start_state, ptree, branch_merges, back_edges, dispatch_state,
                                       root_block)
    return root_block
def load_precompiled_sdfg(folder: str):
    """
    Loads a pre-compiled SDFG from an output folder (e.g. ".dacecache/program").

    The folder must contain a file called "program.sdfg" and a subfolder
    called "build" holding the shared library ``lib<sdfg name>.<ext>``, where
    ``<ext>`` is the ``compiler.library_extension`` configuration entry.

    :param folder: Path to SDFG output folder.
    :return: A callable CompiledSDFG object.
    """
    from dace.codegen import compiled_sdfg as csdfg

    sdfg = SDFG.from_file(os.path.join(folder, 'program.sdfg'))
    extension = config.Config.get('compiler', 'library_extension')
    library_path = os.path.join(folder, 'build', f'lib{sdfg.name}.{extension}')
    library = csdfg.ReloadableDLL(library_path, sdfg.name)
    return csdfg.CompiledSDFG(sdfg, library)
def inline_sdfgs(sdfg: SDFG, strict: bool = True, progress: bool = False) -> int:
    """
    Inlines all possible nested SDFGs (or sub-SDFGs) using an optimized
    routine that uses the structure of the SDFG hierarchy.

    SDFGs are visited in reverse order of ``sdfg.all_sdfgs_recursive()``, and
    every ``NestedSDFG`` node found in their states is offered to
    ``InlineSDFG``.

    :param sdfg: The SDFG to transform.
    :param strict: If True (default), operates in strict mode.
    :param progress: If True, prints out a progress bar of inlining (may be
                     inaccurate, requires ``tqdm``)
    :return: The total number of SDFGs inlined.
    """
    from dace.transformation.interstate import InlineSDFG  # Avoid import loop

    inlined = 0
    all_sdfgs = list(sdfg.all_sdfgs_recursive())
    if progress:
        from tqdm import tqdm
        bar = tqdm(total=len(all_sdfgs))

    for subsdfg in all_sdfgs[::-1]:
        sdfg_id = subsdfg.sdfg_id
        for state_id, state in enumerate(subsdfg.nodes()):
            for node in state.nodes():
                if not isinstance(node, NestedSDFG):
                    continue
                # We have to reevaluate every time due to changing IDs
                candidate = {InlineSDFG._nested_sdfg: state.node_id(node)}
                inliner = InlineSDFG(sdfg_id, state_id, candidate, 0, override=True)
                if not inliner.can_be_applied(state, candidate, 0, subsdfg, strict=strict):
                    continue
                inliner.apply(subsdfg)
                inlined += 1
                if progress:
                    bar.update(1)

    if progress:
        bar.close()
    if config.Config.get_bool('debugprint'):
        print(f'Inlined {inlined} SDFGs')
    return inlined
def consolidate_edges(sdfg: SDFG, starting_scope=None) -> int:
    """
    Union scope-entering memlets relating to the same data node in all states.

    This effectively reduces the number of connectors and allows more
    transformations to be performed, at the cost of losing the individual
    per-tasklet memlets.

    :param sdfg: The SDFG to consolidate.
    :param starting_scope: Optional scope-tree node. If given, only the state
                           containing its entry node is processed, starting at
                           this scope and moving outwards through its parents,
                           and memlets are repropagated from this scope only.
                           If None, every state is processed bottom-up from
                           its scope leaves and memlets are repropagated over
                           the whole SDFG.
    :return: Number of edges removed.
    """
    from dace.sdfg.propagation import propagate_memlets_sdfg, propagate_memlets_scope

    removed = 0
    for state in sdfg.nodes():
        if starting_scope and starting_scope.entry not in state.nodes():
            continue

        # Start bottom-up, one scope level at a time
        level = [starting_scope] if starting_scope else state.scope_leaves()
        while level:
            parents = []
            for scope in level:
                removed += consolidate_edges_scope(state, scope.entry)
                removed += consolidate_edges_scope(state, scope.exit)
                if scope.parent is not None:
                    parents.append(scope.parent)
            level = parents

        if starting_scope is not None:
            # Repropagate memlets from this scope outwards; no other state
            # needs to be traversed
            propagate_memlets_scope(sdfg, state, starting_scope)
            break

    # Repropagate memlets
    if starting_scope is None:
        propagate_memlets_sdfg(sdfg)

    return removed
def get_next_nonempty_states(sdfg: SDFG, state: SDFGState) -> Set[SDFGState]:
    """
    From the given state, return the next set of states that are reachable
    in the SDFG, skipping empty states. Traversal stops at the non-empty
    state.

    This function is used to determine whether synchronization should happen
    at the end of a GPU state.

    :param sdfg: The SDFG that contains the state.
    :param state: The state to start from.
    :return: A set of reachable non-empty states.
    """
    reachable: Set[SDFGState] = set()

    # From each successor, only continue past states that are empty
    for successor in sdfg.successors(state):
        reachable.update(dfs_conditional(sdfg, sources=[successor], condition=lambda parent, _: parent.is_empty()))

    # Keep only the non-empty states
    return {s for s in reachable if not s.is_empty()}
def consolidate_edges(sdfg: SDFG) -> int:
    """
    Union scope-entering memlets relating to the same data node in all states.

    This effectively reduces the number of connectors and allows more
    transformations to be performed, at the cost of losing the individual
    per-tasklet memlets. Every state is processed bottom-up, starting from
    its scope leaves and moving outwards one parent level at a time.

    :param sdfg: The SDFG to consolidate.
    :return: Number of edges removed.
    """
    removed = 0
    for state in sdfg.nodes():
        # Start bottom-up
        level = state.scope_leaves()
        while level:
            parents = []
            for scope in level:
                removed += consolidate_edges_scope(state, scope.entry)
                removed += consolidate_edges_scope(state, scope.exit)
                if scope.parent is not None:
                    parents.append(scope.parent)
            level = parents
    return removed
def _structured_control_flow_traversal(
        sdfg: SDFG,
        start: SDFGState,
        ptree: Dict[SDFGState, SDFGState],
        branch_merges: Dict[SDFGState, SDFGState],
        back_edges: List[Edge[InterstateEdge]],
        dispatch_state: Callable[[SDFGState], str],
        parent_block: GeneralBlock,
        stop: SDFGState = None,
        generate_children_of: SDFGState = None) -> Set[SDFGState]:
    """
    Helper function for ``structured_control_flow_tree``.

    Traverses states from ``start`` with an explicit stack and appends the
    corresponding control-flow blocks (single states, if/switch/if-else
    chains, loops) to ``parent_block``. It recurses into branch targets and
    loop bodies.

    :param sdfg: SDFG.
    :param start: Starting state for traversal.
    :param ptree: State parent tree (computed from ``state_parent_tree``).
    :param branch_merges: Dictionary mapping from branch state to its merge
                          state.
    :param back_edges: Back edges of the SDFG. Used to separate loop init
                       edges (not back edges) from the loop back edge when
                       marking edges to ignore.
    :param dispatch_state: A function that dispatches code generation for a
                           single state.
    :param parent_block: The block to append children to.
    :param stop: Stopping state to not traverse through (merge state of a
                 branch or guard state of a loop).
    :param generate_children_of: If given, popped states for which
                                 ``_child_of(node, generate_children_of, ptree)``
                                 is False are skipped.
    :return: The set of states visited by this traversal (including those
             visited by its recursive sub-traversals), excluding ``stop``.
    """

    # Traverse states in custom order
    visited = set()
    if stop is not None:
        # Pre-mark the stop state so it is never entered
        visited.add(stop)
    stack = [start]
    while stack:
        node = stack.pop()
        if (generate_children_of is not None
                and not _child_of(node, generate_children_of, ptree)):
            continue
        if node in visited:
            continue
        visited.add(node)
        stateblock = SingleState(dispatch_state, node)

        oe = sdfg.out_edges(node)
        if len(oe) == 0:  # End state
            # If there are no remaining nodes, this is the last state and it can
            # be marked as such
            if len(stack) == 0:
                stateblock.last_state = True
            parent_block.elements.append(stateblock)
            continue
        elif len(oe) == 1:  # No traversal change
            stack.append(oe[0].dst)
            parent_block.elements.append(stateblock)
            continue

        # Potential branch or loop
        if node in branch_merges:
            mergestate = branch_merges[node]

            # Add branching node and ignore outgoing edges
            parent_block.elements.append(stateblock)
            parent_block.edges_to_ignore.extend(oe)
            # NOTE(review): the branching state is marked as last_state;
            # presumably so no fall-through transition is emitted after it,
            # since the branch construct handles control flow — confirm
            stateblock.last_state = True

            # Parse all outgoing edges recursively first
            cblocks: Dict[Edge[InterstateEdge], GeneralBlock] = {}
            for branch in oe:
                cblocks[branch] = GeneralBlock(dispatch_state, [], [])
                # Each branch body is traversed up to (not into) the merge state
                visited |= _structured_control_flow_traversal(
                    sdfg,
                    branch.dst,
                    ptree,
                    branch_merges,
                    back_edges,
                    dispatch_state,
                    cblocks[branch],
                    stop=mergestate,
                    generate_children_of=node)

            # Classify branch type:
            branch_block = None
            # If there are 2 out edges, one negation of the other:
            #   * if/else in case both branches are not merge state
            #   * if without else in case one branch is merge state
            if (len(oe) == 2 and oe[0].data.condition_sympy() == sp.Not(
                    oe[1].data.condition_sympy())):
                # If without else
                if oe[0].dst is mergestate:
                    branch_block = IfScope(dispatch_state, sdfg, node,
                                           oe[1].data.condition,
                                           cblocks[oe[1]])
                elif oe[1].dst is mergestate:
                    branch_block = IfScope(dispatch_state, sdfg, node,
                                           oe[0].data.condition,
                                           cblocks[oe[0]])
                else:
                    branch_block = IfScope(dispatch_state, sdfg, node,
                                           oe[0].data.condition,
                                           cblocks[oe[0]], cblocks[oe[1]])
            else:
                # If there are 2 or more edges (one is not the negation of the
                # other):
                switch = _cases_from_branches(oe, cblocks)
                if switch:
                    # If all edges are of form "x == y" for a single x and
                    # integer y, it is a switch/case
                    branch_block = SwitchCaseScope(dispatch_state, sdfg, node,
                                                   switch[0], switch[1])
                else:
                    # Otherwise, create if/else if/.../else goto exit chain
                    branch_block = IfElseChain(dispatch_state, sdfg, node,
                                               [(e.data.condition, cblocks[e])
                                                for e in oe])
            # End of branch classification
            parent_block.elements.append(branch_block)
            # Continue after the branch at its merge state, unless the merge
            # state is where this traversal has to stop
            if mergestate != stop:
                stack.append(mergestate)

        elif len(oe) == 2:  # Potential loop
            # TODO(later): Recognize do/while loops
            # If loop, traverse body, then exit
            body_start = None
            loop_exit = None
            scope = None
            # The successor whose parent (in ptree) is this node is taken as
            # the loop body; the other successor is the loop exit
            if ptree[oe[0].dst] == node and ptree[oe[1].dst] != node:
                scope = _loop_from_structure(sdfg, node, oe[0], oe[1],
                                             back_edges, dispatch_state)
                body_start = oe[0].dst
                loop_exit = oe[1].dst
            elif ptree[oe[1].dst] == node and ptree[oe[0].dst] != node:
                scope = _loop_from_structure(sdfg, node, oe[1], oe[0],
                                             back_edges, dispatch_state)
                body_start = oe[1].dst
                loop_exit = oe[0].dst

            if scope:
                # Traverse the loop body, stopping when it reaches the guard
                visited |= _structured_control_flow_traversal(
                    sdfg,
                    body_start,
                    ptree,
                    branch_merges,
                    back_edges,
                    dispatch_state,
                    scope.body,
                    stop=node,
                    generate_children_of=node)

                # Add branching node and ignore outgoing edges
                parent_block.elements.append(stateblock)
                parent_block.edges_to_ignore.extend(oe)

                parent_block.elements.append(scope)

                # If for loop, ignore certain edges
                if isinstance(scope, ForScope):
                    # Mark init edge(s) to ignore in parent_block and all children
                    _ignore_recursive([
                        e for e in sdfg.in_edges(node) if e not in back_edges
                    ], parent_block)
                    # Mark back edge for ignoring in all children of loop body
                    _ignore_recursive(
                        [e for e in sdfg.in_edges(node) if e in back_edges],
                        scope.body)

                stack.append(loop_exit)
                continue

            # No proper loop detected: Unstructured control flow
            parent_block.elements.append(stateblock)
            stack.extend([e.dst for e in oe])
        else:  # No merge state: Unstructured control flow
            parent_block.elements.append(stateblock)
            stack.extend([e.dst for e in oe])

    return visited - {stop}
def _loop_from_structure(
        sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge],
        leave_edge: Edge[InterstateEdge],
        back_edges: List[Edge[InterstateEdge]],
        dispatch_state: Callable[[SDFGState], str]) -> Union[ForScope, WhileScope]:
    """
    Helper method that constructs the correct structured loop construct from
    a set of states. Can construct for or while loops.

    :param sdfg: The SDFG containing the loop.
    :param guard: The loop guard state (the state whose two outgoing edges
                  enter and leave the loop).
    :param enter_edge: The guard's outgoing edge that enters the loop body.
    :param leave_edge: The guard's outgoing edge that leaves the loop.
    :param back_edges: Back edges of the SDFG.
    :param dispatch_state: A function that dispatches code generation for a
                           single state.
    :return: A ``ForScope`` or ``WhileScope`` whose ``body`` block has no
             elements yet (only the increment edge is marked as ignored).
             Returns None if the structure is not a recognizable loop: the
             guard does not have exactly one incoming back edge, the
             enter/leave conditions are not negations of each other, the
             guard state is not empty, the back edge is conditional, or the
             enter edge has assignments.
    """
    body = GeneralBlock(dispatch_state, [], [])

    guard_inedges = sdfg.in_edges(guard)
    increment_edges = [e for e in guard_inedges if e in back_edges]
    init_edges = [e for e in guard_inedges if e not in back_edges]

    # If no back edge found (or more than one, indicating a "continue"
    # statement), disregard
    if len(increment_edges) != 1:
        return None
    increment_edge = increment_edges[0]

    # Mark increment edge to be ignored in body
    body.edges_to_ignore.append(increment_edge)

    # Outgoing edges must be a negation of each other
    if enter_edge.data.condition_sympy() != (sp.Not(
            leave_edge.data.condition_sympy())):
        return None

    # Body of guard state must be empty
    if not guard.is_empty():
        return None

    if not increment_edge.data.is_unconditional():
        return None
    if len(enter_edge.data.assignments) > 0:
        return None

    condition = enter_edge.data.condition

    # Detect whether this loop is a for loop:
    # All incoming edges to the guard must set the same variable
    itvars = None
    for iedge in guard_inedges:
        if itvars is None:
            itvars = set(iedge.data.assignments.keys())
        else:
            itvars &= iedge.data.assignments.keys()
    # A for loop also needs at least one init (non-back) edge to take the
    # initial value from. A guard reached only through its back edge (e.g.,
    # when it is the start state) would otherwise raise IndexError on
    # ``init_edges[0]``; such loops fall through to a while loop instead.
    if itvars and len(itvars) == 1 and init_edges:
        itvar = next(iter(itvars))
        init = init_edges[0].data.assignments[itvar]

        # Check that all init edges are the same and that increment edge only
        # increments
        if (all(e.data.assignments[itvar] == init for e in init_edges)
                and len(increment_edge.data.assignments) == 1):
            update = increment_edge.data.assignments[itvar]
            return ForScope(dispatch_state, itvar, guard, init, condition,
                            update, body)

    # Otherwise, it is a while loop
    return WhileScope(dispatch_state, guard, condition, body)
def apply(self, graph: SDFGState, sdfg: SDFG):
    """
    Inserts an access node between ``node_a`` and ``node_b`` on the edge that
    carries ``self.array``, and routes that edge's data through it.

    If ``self.create_array`` is set, a new transient named
    ``self.prefix + <original data name>`` is added to ``sdfg``
    (``find_new_name=True`` may alter the name). Its shape is the
    over-approximated bounding-box size of the original memlet, and its dtype
    is that of the original data. Otherwise the name
    ``self.prefix + <original data name>`` is used as-is (presumably it must
    already exist in ``sdfg`` — confirm).

    The original edge is replaced by ``node_a -> access node -> node_b``.
    Both new memlets start as deep copies of the original memlet. The memlet
    tree of one of the two new edges (chosen by ``propagate_forward``) is
    then offset by the start indices of the original subset and renamed to
    the new data.

    The new data name and access node are stored in ``self._local_name`` and
    ``self._data_node``.

    :return: The newly created access node.
    :raises StopIteration: If ``self.array`` is empty/None and no edge between
                           the two nodes has non-None data and no WCR.
    :raises NameError: If there is no edge between ``node_a`` and ``node_b``.
    """
    node_a = self.node_a
    node_b = self.node_b
    prefix = self.prefix

    # Determine direction of new memlet
    # NOTE(review): presumably True when node_b lies within node_a's scope;
    # confirm the argument order of scope_contains_scope
    scope_dict = graph.scope_dict()
    propagate_forward = sd.scope_contains_scope(scope_dict, node_a, node_b)

    array = self.array
    if array is None or len(array) == 0:
        # No array given: pick the first edge with data and no WCR
        array = next(e.data.data
                     for e in graph.edges_between(node_a, node_b)
                     if e.data.data is not None and e.data.wcr is None)

    original_edge = None
    invariant_memlet = None
    for edge in graph.edges_between(node_a, node_b):
        if array == edge.data.data:
            original_edge = edge
            invariant_memlet = edge.data
            break
    if invariant_memlet is None:
        # Fall back to the first edge between the nodes, with a warning
        for edge in graph.edges_between(node_a, node_b):
            original_edge = edge
            invariant_memlet = edge.data
            warnings.warn('Array %s not found! Using array %s instead.' %
                          (array, invariant_memlet.data))
            array = invariant_memlet.data
            break
    if invariant_memlet is None:
        raise NameError('Array %s not found!' % array)
    if self.create_array:
        # Add transient array
        new_data, _ = sdfg.add_transient(
            name=prefix + invariant_memlet.data,
            shape=[
                symbolic.overapproximate(r).simplify()
                for r in invariant_memlet.bounding_box_size()
            ],
            dtype=sdfg.arrays[invariant_memlet.data].dtype,
            find_new_name=True)

    else:
        new_data = prefix + invariant_memlet.data
    data_node = nodes.AccessNode(new_data)
    # Store as fields so that other transformations can use them
    self._local_name = new_data
    self._data_node = data_node

    to_data_mm = copy.deepcopy(invariant_memlet)
    from_data_mm = copy.deepcopy(invariant_memlet)
    # Start index of each dimension of the original subset
    offset = subsets.Indices([r[0] for r in invariant_memlet.subset])

    # Reconnect, assuming one edge to the access node
    graph.remove_edge(original_edge)
    # Both branches add the same two edges; they differ only in which edge's
    # memlet tree is offset below
    if propagate_forward:
        graph.add_edge(node_a, original_edge.src_conn, data_node, None,
                       to_data_mm)
        new_edge = graph.add_edge(data_node, None, node_b,
                                  original_edge.dst_conn, from_data_mm)
    else:
        new_edge = graph.add_edge(node_a, original_edge.src_conn, data_node,
                                  None, to_data_mm)
        graph.add_edge(data_node, None, node_b, original_edge.dst_conn,
                       from_data_mm)

    # Offset all edges in the memlet tree (including the new edge)
    # NOTE(review): the second argument to offset() is presumably "negative",
    # i.e. subtract the start indices — confirm
    for edge in graph.memlet_tree(new_edge):
        edge.data.subset.offset(offset, True)
        edge.data.data = new_data

    return data_node
def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, multistate: bool = True) -> int:
    """
    Inlines all possible nested SDFGs (or sub-SDFGs) using an optimized
    routine that uses the structure of the SDFG hierarchy.

    :param sdfg: The SDFG to transform.
    :param permissive: If True, operates in permissive mode, which ignores
                       some checks.
    :param progress: If True, prints out a progress bar of inlining (may be
                     inaccurate, requires ``tqdm``). If None, prints out
                     progress if over 5 seconds have passed. If False, never
                     shows progress bar. If ``tqdm`` cannot be imported, no
                     progress bar is shown.
    :param multistate: If True (default), each nested SDFG is first offered
                       to ``InlineMultistateSDFG``, and ``InlineSDFG`` is
                       tried only if that does not apply. If False, only
                       ``InlineSDFG`` is used.
    :return: The total number of SDFGs inlined.
    """
    # Avoid import loops
    from dace.transformation.interstate import InlineSDFG, InlineMultistateSDFG

    tqdm = None
    if progress is True or progress is None:
        try:
            from tqdm import tqdm
        except ImportError:
            pass
    if tqdm is None:
        # Without tqdm no progress bar can be drawn. Force-disable it so the
        # code below never calls ``tqdm`` while it is None (this would raise
        # TypeError for progress=True).
        progress = False

    counter = 0
    sdfgs = list(sdfg.all_sdfgs_recursive())
    pbar = None
    if progress is True:
        pbar = tqdm(total=len(sdfgs), desc='Inlining SDFGs')

    start = time.time()

    for sd in reversed(sdfgs):
        sdfg_id = sd.sdfg_id
        for state in sd.nodes():
            for node in state.nodes():
                # Lazily show a progress bar once inlining takes over 5 seconds
                if progress is None and (time.time() - start) > 5:
                    progress = True
                    pbar = tqdm(total=len(sdfgs), desc='Inlining SDFG', initial=counter)

                if not isinstance(node, NestedSDFG):
                    continue
                # We have to reevaluate every time due to changing IDs
                node_id = state.node_id(node)
                state_id = sd.node_id(state)
                if multistate:
                    candidate = {
                        InlineMultistateSDFG.nested_sdfg: node_id,
                    }
                    inliner = InlineMultistateSDFG(sdfg_id, state_id, candidate, 0, override=True)
                    if inliner.can_be_applied(state, candidate, 0, sd, permissive=permissive):
                        inliner.apply(sd)
                        counter += 1
                        if pbar is not None:
                            pbar.update(1)
                        continue

                candidate = {
                    InlineSDFG._nested_sdfg: node_id,
                }
                inliner = InlineSDFG(sdfg_id, state_id, candidate, 0, override=True)
                if inliner.can_be_applied(state, candidate, 0, sd, permissive=permissive):
                    inliner.apply(sd)
                    counter += 1
                    if pbar is not None:
                        pbar.update(1)

    if pbar is not None:
        pbar.close()
    if config.Config.get_bool('debugprint') and counter > 0:
        print(f'Inlined {counter} SDFGs')
    return counter
def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int:
    """
    Fuses all possible states of an SDFG (and all sub-SDFGs) using an
    optimized routine that uses the structure of the StateFusion
    transformation.

    :param sdfg: The SDFG to transform.
    :param permissive: If True, operates in permissive mode, which ignores
                       some race condition checks.
    :param progress: If True, prints out a progress bar of fusion (may be
                     inaccurate, requires ``tqdm``). If None, prints out
                     progress if over 5 seconds have passed. If False, never
                     shows progress bar. If ``tqdm`` cannot be imported, no
                     progress bar is shown.
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    tqdm = None
    if progress is True or progress is None:
        try:
            from tqdm import tqdm
        except ImportError:
            pass
    if tqdm is None:
        # Without tqdm no progress bar can be drawn. Force-disable it so the
        # code below never calls ``tqdm`` while it is None (this would raise
        # TypeError for progress=True).
        progress = False

    counter = 0
    fusible_states = 0
    if progress is True or progress is None:
        fusible_states = sum(sd.number_of_edges() for sd in sdfg.all_sdfgs_recursive())

    pbar = None
    if progress is True:
        pbar = tqdm(total=fusible_states, desc='Fusing states')

    start = time.time()

    for sd in sdfg.all_sdfgs_recursive():
        sdfg_id = sd.sdfg_id
        while True:
            edges = list(sd.nx.edges)
            applied = 0
            # States already fused in this sweep; their edges are now stale
            skip_nodes = set()
            for u, v in edges:
                # Lazily show a progress bar once fusion takes over 5 seconds
                if progress is None and (time.time() - start) > 5:
                    progress = True
                    pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter)

                if u in skip_nodes or v in skip_nodes:
                    continue
                candidate = {StateFusion.first_state: u, StateFusion.second_state: v}
                sf = StateFusion(sdfg_id, -1, candidate, 0, override=True)
                if sf.can_be_applied(sd, candidate, 0, sd, permissive=permissive):
                    sf.apply(sd)
                    applied += 1
                    counter += 1
                    if pbar is not None:
                        pbar.update(1)
                    skip_nodes.add(u)
                    skip_nodes.add(v)
            if applied == 0:
                break

    if pbar is not None:
        pbar.close()
    if config.Config.get_bool('debugprint') and counter > 0:
        print(f'Applied {counter} State Fusions')
    return counter