def apply(self, sdfg: dace.SDFG) -> None:
    """Fuse the ``left`` and ``right`` horizontal-execution library nodes into one.

    A new ``HorizontalExecutionLibraryNode`` is built whose OIR body and
    declarations are those of ``left`` followed by those of ``right``. Its
    iteration space is taken from ``left``, and its debug info is the merged
    source location of both nodes. Edges of ``left``, ``right`` and of the
    nodes lying strictly inside simple paths ``left -> ... -> right`` are
    rewired to the new node or removed. ``left`` and ``right`` are then deleted,
    and the intermediate nodes are cleaned up.

    :param sdfg: SDFG containing the state identified by ``self.state_id``.
    """
    state = sdfg.node(self.state_id)
    left = self.left(sdfg)
    right = self.right(sdfg)

    # Merge source locations
    dinfo = self._merge_source_locations(left, right)

    # Merge OIR nodes: statements of ``left`` execute before those of ``right``.
    res = HorizontalExecutionLibraryNode(
        oir_node=oir.HorizontalExecution(
            body=left.as_oir().body + right.as_oir().body,
            declarations=left.as_oir().declarations + right.as_oir().declarations,
        ),
        iteration_space=left.iteration_space,
        debuginfo=dinfo,
    )
    state.add_node(res)

    # Nodes strictly between left and right (path endpoints excluded).
    intermediate_accesses = {
        n for path in nx.all_simple_paths(state.nx, left, right) for n in path[1:-1]
    }

    # Rewire edges and connectors to ``res`` and drop the left->right links.
    # Each edge list is snapshotted with list() because edges are removed or
    # rewired while iterating.
    for edge in list(state.edges_between(left, right)):
        state.remove_edge_and_connectors(edge)
    for acc in intermediate_accesses:
        for edge in list(state.in_edges(acc)):
            if edge.src is not left:
                rewire_edge(state, edge, dst=res)
            else:
                state.remove_edge_and_connectors(edge)
        for edge in list(state.out_edges(acc)):
            if edge.dst is not right:
                rewire_edge(state, edge, src=res)
            else:
                state.remove_edge_and_connectors(edge)
    for edge in list(state.in_edges(left)):
        rewire_edge(state, edge, dst=res)
    for edge in list(state.out_edges(right)):
        rewire_edge(state, edge, src=res)
    for edge in list(state.out_edges(left)):
        rewire_edge(state, edge, src=res)
    for edge in list(state.in_edges(right)):
        rewire_edge(state, edge, dst=res)
    state.remove_node(left)
    state.remove_node(right)

    # Clean up the intermediate nodes after rewiring.
    for acc in intermediate_accesses:
        if not state.in_edges(acc):
            if not state.out_edges(acc):
                # Fully disconnected: remove it.
                state.remove_node(acc)
            else:
                # No producer is left. The only permitted remaining consumer is
                # a single edge into ``res``; drop the node and its IN_ connector.
                assert (
                    len(state.edges_between(acc, res)) == 1 and len(state.out_edges(acc)) == 1
                ), "Previously written array now read-only."
                state.remove_node(acc)
                res.remove_in_connector("IN_" + acc.label)
        elif not state.out_edges(acc):
            # Producers but no consumers remain. The original tested the bound
            # method ``state.out_edges`` (always truthy), so this never ran.
            acc.access = dace.AccessType.WriteOnly
def visit_HorizontalExecution(
    self,
    node: oir.HorizontalExecution,
    tmps_to_replace: Set[str],
    symtable: Dict[str, Any],
    new_symbol_name: Callable[[str], str],
    **kwargs: Any,
) -> oir.HorizontalExecution:
    """Replace temporaries used in ``node`` by freshly named local scalars.

    Field names accessed inside ``node`` that also appear in
    ``tmps_to_replace`` each receive a new name from ``new_symbol_name``.
    The body is visited with that renaming (passed as ``tmps_name_map``).
    A matching ``oir.LocalScalar``, with dtype and loc copied from
    ``symtable``, is appended to the existing declarations.
    """
    accessed_names = node.iter_tree().if_isinstance(oir.FieldAccess).getattr("name")
    replaced_here = accessed_names.if_in(tmps_to_replace).to_set()

    renaming = {}
    for old_name in replaced_here:
        renaming[old_name] = new_symbol_name(old_name)

    new_body = self.visit(node.body, tmps_name_map=renaming, symtable=symtable, **kwargs)
    extra_locals = [
        oir.LocalScalar(
            name=renaming[old_name],
            dtype=symtable[old_name].dtype,
            loc=symtable[old_name].loc,
        )
        for old_name in replaced_here
    ]
    return oir.HorizontalExecution(body=new_body, declarations=node.declarations + extra_locals)
def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context) -> oir.VerticalLoop:
    """Lower a GTIR vertical loop to an OIR vertical loop with a single section.

    Each statement of ``node.body`` becomes its own ``oir.HorizontalExecution``.
    Its declarations are the local scalars collected in ``ctx`` while that
    statement is visited. The loop's temporaries are appended to
    ``ctx.temp_fields`` as ``oir.Temporary`` declarations.
    """
    horiz_execs: List[oir.HorizontalExecution] = []
    for gtir_stmt in node.body:
        # Every statement starts with an empty set of local scalars.
        ctx.reset_local_scalars()
        lowered = self.visit(gtir_stmt, ctx=ctx)
        if isinstance(lowered, oir.Stmt):
            lowered = [lowered]
        oir_stmts = utils.flatten_list(lowered)
        horiz_execs.append(
            oir.HorizontalExecution(body=oir_stmts, declarations=ctx.local_scalars)
        )

    new_temporaries = [
        oir.Temporary(name=tmp.name, dtype=tmp.dtype, dimensions=tmp.dimensions)
        for tmp in node.temporaries
    ]
    ctx.temp_fields += new_temporaries

    only_section = oir.VerticalLoopSection(
        interval=self.visit(node.interval),
        horizontal_executions=horiz_execs,
        loc=node.loc,
    )
    return oir.VerticalLoop(loop_order=node.loop_order, sections=[only_section])
def visit_ParAssignStmt(
    self, node: gtir.ParAssignStmt, *, mask: oir.Expr = None, ctx: Context, **kwargs: Any
) -> None:
    """Register a parallel assignment as its own horizontal execution in ``ctx``.

    The assignment is wrapped in an ``oir.MaskStmt`` when ``mask`` is given.
    Nothing is returned; the result is added through
    ``ctx.add_horizontal_execution``.
    """
    assign = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right))
    stmt = assign if mask is None else oir.MaskStmt(body=[assign], mask=mask)
    ctx.add_horizontal_execution(
        oir.HorizontalExecution(body=[stmt], declarations=[]),
    )
def visit_HorizontalExecution(
    self,
    node: oir.HorizontalExecution,
    *,
    name_map: Dict[str, str],
    fills: List[oir.Stmt],
    flushes: List[oir.Stmt],
    **kwargs: Any,
) -> oir.HorizontalExecution:
    """Rename accesses in the body and surround it with fill and flush statements.

    The body is visited with ``name_map``. The resulting statements are
    preceded by ``fills`` and followed by ``flushes``. Declarations are kept
    unchanged.
    """
    renamed_body = self.visit(node.body, name_map=name_map, **kwargs)
    new_body = fills + renamed_body + flushes
    return oir.HorizontalExecution(body=new_body, declarations=node.declarations)
def visit_HorizontalExecution(
    self,
    node: oir.HorizontalExecution,
    local_tmps: Set[str],
    symtable: Dict[str, Any],
    **kwargs: Any,
) -> oir.HorizontalExecution:
    """Declare the local temporaries used in ``node`` as local scalars.

    Every field name accessed in ``node`` that is listed in ``local_tmps``
    gets an ``oir.LocalScalar`` declaration (dtype and loc taken from
    ``symtable``) appended to the existing ones. The body and the mask are
    visited with ``local_tmps`` forwarded.
    """
    used_local_tmps = (
        node.iter_tree()
        .if_isinstance(oir.FieldAccess)
        .getattr("name")
        .if_in(local_tmps)
        .to_set()
    )
    new_decls = [
        oir.LocalScalar(name=tmp_name, dtype=symtable[tmp_name].dtype, loc=symtable[tmp_name].loc)
        for tmp_name in used_local_tmps
    ]
    all_decls = node.declarations + new_decls

    visited_body = self.visit(node.body, local_tmps=local_tmps, **kwargs)
    visited_mask = self.visit(node.mask, local_tmps=local_tmps, **kwargs)
    return oir.HorizontalExecution(body=visited_body, mask=visited_mask, declarations=all_decls)
def _create_mask(ctx: "GTIRToOIR.Context", name: str, cond: oir.Expr) -> oir.Temporary:
    """Declare a boolean temporary ``name`` and fill it with ``cond``.

    The temporary (all three dimensions enabled) is registered via
    ``ctx.add_decl``. A horizontal execution assigning ``cond`` to it at
    zero offset is added via ``ctx.add_horizontal_execution``.

    :returns: The newly declared mask temporary.
    """
    mask_decl = oir.Temporary(name=name, dtype=DataType.BOOL, dimensions=(True, True, True))
    ctx.add_decl(mask_decl)

    target = oir.FieldAccess(
        name=mask_decl.name,
        offset=CartesianOffset.zero(),
        dtype=mask_decl.dtype,
    )
    fill_stmt = oir.AssignStmt(left=target, right=cond)
    ctx.add_horizontal_execution(oir.HorizontalExecution(body=[fill_stmt], declarations=[]))
    return mask_decl
def build(self):
    """Return an empty ``oir.HorizontalExecution``: no statements and no mask.

    No attribute of ``self`` is consulted.
    """
    empty_body = []
    return oir.HorizontalExecution(mask=None, body=empty_body)
def _merge(
    self,
    horizontal_executions: List[oir.HorizontalExecution],
    symtable: Dict[str, Any],
    new_symbol_name: Callable[[str], str],
    protected_fields: Set[str],
) -> List[oir.HorizontalExecution]:
    """Recursively merge horizontal executions by inlining producers into consumers.

    Algorithm:

    1. Get the output fields of the first horizontal execution.
    2. Find the following horizontal executions that read those outputs, and
       at which offsets.
    3. For every such read offset, prepend a copy of the first body, shifted
       by that offset, to the depending horizontal execution. Each copy
       writes to fresh local scalars, one per (output field, offset).
    4. Recurse on the resulting horizontal executions.

    The first horizontal execution is kept as-is (and the recursion continues
    on the rest) if any of the following holds:

    - a later one writes a field it accesses;
    - it writes a field in ``protected_fields``;
    - its body has more than ``self.max_horizontal_execution_body_size``
      statements;
    - it calls one of the listed expensive native functions while
      ``self.allow_expensive_function_duplication`` is false.

    Otherwise it does not appear in the result.

    :param horizontal_executions: Horizontal executions in execution order.
    :param symtable: Symbol table used to look up dtypes of written fields.
    :param new_symbol_name: Generator of fresh, unique symbol names.
    :param protected_fields: Fields whose writers must never be inlined away.
    :returns: The merged list of horizontal executions.
    """
    if len(horizontal_executions) <= 1:
        return horizontal_executions
    first, *others = horizontal_executions
    first_accesses = AccessCollector.apply(first)
    other_accesses = AccessCollector.apply(others)

    def first_fields_rewritten_later() -> bool:
        # Inlining would read values that a later execution has overwritten.
        return bool(first_accesses.fields() & other_accesses.write_fields())

    def first_has_large_body() -> bool:
        return len(first.body) > self.max_horizontal_execution_body_size

    def first_writes_protected() -> bool:
        return bool(protected_fields & first_accesses.write_fields())

    def first_has_expensive_function_call() -> bool:
        if self.allow_expensive_function_duplication:
            return False
        nf = common.NativeFunction
        expensive_calls = {
            nf.SIN,
            nf.COS,
            nf.TAN,
            nf.ARCSIN,
            nf.ARCCOS,
            nf.ARCTAN,
            nf.SQRT,
            nf.EXP,
            nf.LOG,
        }
        calls = first.iter_tree().if_isinstance(oir.NativeFuncCall).getattr("func")
        return any(call in expensive_calls for call in calls)

    if (
        first_fields_rewritten_later()
        or first_writes_protected()
        or first_has_large_body()
        or first_has_expensive_function_call()
    ):
        # ``first`` cannot be inlined: keep it and merge the remainder.
        return [first] + self._merge(others, symtable, new_symbol_name, protected_fields)

    writes = first_accesses.write_fields()
    others_otf = []
    for horizontal_execution in others:
        # Union of all offsets at which this execution reads any output of ``first``.
        read_offsets: Set[Tuple[int, int, int]] = set()
        read_offsets = read_offsets.union(
            *(
                offsets
                for field, offsets in AccessCollector.apply(horizontal_execution).read_offsets().items()
                if field in writes
            )
        )
        if not read_offsets:
            others_otf.append(horizontal_execution)
            continue
        # One fresh scalar per (output field, offset). Every shifted copy of
        # ``first`` writes all of its outputs, so the full cross product is needed.
        offset_symbol_map = {(name, o): new_symbol_name(name) for name in writes for o in read_offsets}
        merged = oir.HorizontalExecution(
            # Presumably the visitor replaces offset reads of ``first``'s outputs
            # with the matching scalar from offset_symbol_map (TODO confirm).
            body=self.visit(horizontal_execution.body, offset_symbol_map=offset_symbol_map),
            declarations=horizontal_execution.declarations
            + [
                oir.LocalScalar(name=new_name, dtype=symtable[old_name].dtype)
                for (old_name, _), new_name in offset_symbol_map.items()
            ]
            # ``first``'s own locals are needed by the inlined copies of its body.
            + [d for d in first.declarations if d not in horizontal_execution.declarations],
        )
        for offset in read_offsets:
            # Prepend one copy of ``first``'s body shifted by ``offset``.
            merged.body = (
                self.visit(
                    first.body,
                    shift=offset,
                    offset_symbol_map=offset_symbol_map,
                    symtable=symtable,
                )
                + merged.body
            )
        others_otf.append(merged)
    # ``first`` itself is not part of the result: its outputs are unprotected
    # and each reader now computes them inline.
    return self._merge(others_otf, symtable, new_symbol_name, protected_fields)
def visit_HorizontalExecution(self, node: oir.HorizontalExecution) -> oir.HorizontalExecution:
    """Rebuild ``node`` with its statement list passed through ``self._merge``.

    Declarations and source location are carried over unchanged.
    """
    merged_body = self._merge(node.body)
    return oir.HorizontalExecution(
        body=merged_body,
        declarations=node.declarations,
        loc=node.loc,
    )
def visit_VerticalLoopSection(
    self,
    node: oir.VerticalLoopSection,
    *,
    block_extents: Dict[int, Extent],
    new_symbol_name: Callable[[str], str],
    **kwargs: Any,
) -> oir.VerticalLoopSection:
    """Greedily fuse consecutive horizontal executions of a vertical loop section.

    A horizontal execution is appended to the previous (possibly already
    merged) one when three conditions hold:

    - both have the same block extent, looked up in ``block_extents`` by the
      ``id`` of the original node;
    - the new one does not read, at a non-zero horizontal offset, a field the
      previous one writes (read-after-write);
    - the new one does not write a field the previous one reads at a non-zero
      horizontal offset (write-after-read).

    Either offset hazard would let one point observe a neighbour's value from
    the wrong stage once the two executions are fused. Local scalars declared
    in both executions are renamed in the appended body via
    ``new_symbol_name``.

    :param node: Section whose horizontal executions are merged.
    :param block_extents: Extent of each original horizontal execution, keyed by ``id``.
    :param new_symbol_name: Generator of fresh, unique symbol names.
    :returns: A section with the merged horizontal executions.
    """
    if not node.horizontal_executions:
        # Nothing to merge (the original raised IndexError on ``[0]``).
        return node

    def offset_reads(accesses):
        """Names of fields read at a non-zero horizontal (i, j) offset."""
        return {
            name
            for name, offsets in accesses.read_offsets().items()
            if any(off[0] != 0 or off[1] != 0 for off in offsets)
        }

    horizontal_executions = [node.horizontal_executions[0]]
    new_block_extents = [block_extents[id(horizontal_executions[-1])]]
    for this_hexec in node.horizontal_executions[1:]:
        last_hexec = horizontal_executions[-1]
        last_extent = new_block_extents[-1]
        this_extent = block_extents[id(this_hexec)]

        last_accesses = AccessCollector.apply(last_hexec)
        this_accesses = AccessCollector.apply(this_hexec)
        reads_with_offset_after_write = last_accesses.write_fields() & offset_reads(this_accesses)
        writes_after_read_with_offset = this_accesses.write_fields() & offset_reads(last_accesses)

        if (
            reads_with_offset_after_write
            or writes_after_read_with_offset
            or last_extent != this_extent
        ):
            # Cannot merge: simply append to list
            horizontal_executions.append(this_hexec)
            new_block_extents.append(this_extent)
            continue

        # Merge. Locals declared in both executions are renamed in the second.
        duplicated_locals = {decl.name for decl in last_hexec.declarations} & {
            decl.name for decl in this_hexec.declarations
        }
        # Map from old to new scalar names applied to the second horizontal execution
        scalar_map = {name: new_symbol_name(name) for name in duplicated_locals}
        locals_symtable = {decl.name: decl for decl in this_hexec.declarations}
        new_body = self.visit(this_hexec.body, scalar_map=scalar_map, **kwargs)
        this_not_duplicated = [
            decl for decl in this_hexec.declarations if decl.name not in duplicated_locals
        ]
        # Horizontal-execution declarations are local scalars (oir.LocalScalar),
        # not API scalar parameters (oir.ScalarDecl).
        this_mapped = [
            oir.LocalScalar(name=scalar_map[name], dtype=locals_symtable[name].dtype)
            for name in duplicated_locals
        ]
        horizontal_executions[-1] = oir.HorizontalExecution(
            body=last_hexec.body + new_body,
            declarations=last_hexec.declarations + this_not_duplicated + this_mapped,
        )
    return oir.VerticalLoopSection(
        interval=node.interval, horizontal_executions=horizontal_executions
    )
def to_oir(self) -> oir.HorizontalExecution:
    """Convert this node into an ``oir.HorizontalExecution``.

    The body, declarations and source location are passed through unchanged.
    """
    node_fields = dict(body=self.body, declarations=self.declarations, loc=self.loc)
    return oir.HorizontalExecution(**node_fields)