def visit_HorizontalExecution(
    self,
    node: oir.HorizontalExecution,
    tmps_to_replace: Set[str],
    symtable: Dict[str, Any],
    new_symbol_name: Callable[[str], str],
    **kwargs: Any,
) -> oir.HorizontalExecution:
    """Rename the replaceable temporaries used in ``node`` and declare them as local scalars.

    Only names from ``tmps_to_replace`` that occur as an ``oir.FieldAccess``
    somewhere inside ``node`` are considered. Each one is mapped to a new name
    obtained from ``new_symbol_name``. The body is visited with that mapping
    (``tmps_name_map``), presumably so that other visit methods rewrite the
    accesses. A ``LocalScalar`` carrying the new name is appended to the
    declarations for every such temporary, using the ``dtype``/``loc`` of its
    ``symtable`` entry.

    Args:
        node: The horizontal execution being transformed.
        tmps_to_replace: Names of temporaries eligible for replacement.
        symtable: Maps symbol names to declarations exposing ``dtype`` and ``loc``.
        new_symbol_name: Called once per replaced temporary to obtain its new name.
        **kwargs: Forwarded to the visit of the body.

    Returns:
        A new ``oir.HorizontalExecution`` with the visited body and the
        extended declaration list.
    """
    accessed_field_names = node.iter_tree().if_isinstance(oir.FieldAccess).getattr("name")
    used_tmps = accessed_field_names.if_in(tmps_to_replace).to_set()

    renaming = {}
    for tmp_name in used_tmps:
        renaming[tmp_name] = new_symbol_name(tmp_name)

    # The body is visited before the new declarations are built, matching the
    # original keyword-argument evaluation order.
    new_body = self.visit(node.body, tmps_name_map=renaming, symtable=symtable, **kwargs)

    scalar_decls = []
    for tmp_name in used_tmps:
        tmp_decl = symtable[tmp_name]
        scalar_decls.append(
            oir.LocalScalar(name=renaming[tmp_name], dtype=tmp_decl.dtype, loc=tmp_decl.loc)
        )

    return oir.HorizontalExecution(body=new_body, declarations=node.declarations + scalar_decls)
def visit_HorizontalExecution(
    self,
    node: oir.HorizontalExecution,
    local_tmps: Set[str],
    symtable: Dict[str, Any],
    **kwargs: Any,
) -> oir.HorizontalExecution:
    """Declare the local temporaries used in ``node`` as local scalars.

    Every name from ``local_tmps`` that occurs as an ``oir.FieldAccess`` inside
    ``node`` gets a ``LocalScalar`` declaration with the same name, using the
    ``dtype``/``loc`` of its ``symtable`` entry. Body and mask are then visited
    with ``local_tmps``. ``symtable`` is consumed here and not forwarded.

    Args:
        node: The horizontal execution being transformed.
        local_tmps: Names of temporaries to turn into local scalars.
        symtable: Maps symbol names to declarations exposing ``dtype`` and ``loc``.
        **kwargs: Forwarded to the visits of body and mask.

    Returns:
        A new ``oir.HorizontalExecution`` with visited body/mask and the
        extended declaration list.
    """
    accessed_field_names = node.iter_tree().if_isinstance(oir.FieldAccess).getattr("name")
    new_scalars = []
    for tmp_name in accessed_field_names.if_in(local_tmps).to_set():
        new_scalars.append(
            oir.LocalScalar(
                name=tmp_name,
                dtype=symtable[tmp_name].dtype,
                loc=symtable[tmp_name].loc,
            )
        )
    all_declarations = node.declarations + new_scalars

    visited_body = self.visit(node.body, local_tmps=local_tmps, **kwargs)
    visited_mask = self.visit(node.mask, local_tmps=local_tmps, **kwargs)
    return oir.HorizontalExecution(
        body=visited_body,
        mask=visited_mask,
        declarations=all_declarations,
    )
def visit_LocalScalar(self, node: oir.LocalScalar) -> oir.LocalScalar:
    """Return a new ``LocalScalar`` renamed through ``decls_map``.

    ``decls_map`` is not defined in this method. It is presumably a mapping
    from old to new scalar names supplied by an enclosing scope — confirm.
    Only ``name`` and ``dtype`` are carried over; other attributes of
    ``node`` (e.g. a source location, if any) are not copied.
    """
    renamed_to = decls_map[node.name]
    return oir.LocalScalar(dtype=node.dtype, name=renamed_to)
def _merge(
    self,
    horizontal_executions: List[oir.HorizontalExecution],
    symtable: Dict[str, Any],
    new_symbol_name: Callable[[str], str],
    protected_fields: Set[str],
) -> List[oir.HorizontalExecution]:
    """Recursively merge horizontal executions.

    Uses the following algorithm:

    1. Get output fields of the first horizontal execution.
    2. Check in which following h. execs. the outputs are read.
    3. Duplicate the body of the first h. exec. for each read access (with
       corresponding offset) and prepend it to the depending h. execs.
    4. Recurse on the resulting h. execs.

    The first h. exec. is NOT merged if any of the following holds. In that
    case it is kept as-is and the recursion continues on the remaining ones.

    - a field it accesses is written by one of the following h. execs.;
    - it writes one of ``protected_fields``;
    - ``len(first.body)`` exceeds ``self.max_horizontal_execution_body_size``;
    - ``self.allow_expensive_function_duplication`` is false and it calls one
      of SIN, COS, TAN, ARCSIN, ARCCOS, ARCTAN, SQRT, EXP, LOG.

    When it IS merged, the first h. exec. itself does not appear in the
    result. Its computation survives only as the copies prepended to the
    h. execs. that read its outputs.

    Args:
        horizontal_executions: The h. execs. to merge (presumably in
            execution order, since later ones are treated as dependents).
        symtable: Maps symbol names to declarations. It provides the dtype of
            each new local scalar and is passed to the shifted visit of the
            first body.
        new_symbol_name: Called to obtain the name of each new local scalar.
        protected_fields: Field names. An h. exec. writing any of them is
            never merged away.

    Returns:
        The list of remaining (possibly merged) h. execs.
    """
    if len(horizontal_executions) <= 1:
        return horizontal_executions
    first, *others = horizontal_executions
    first_accesses = AccessCollector.apply(first)
    other_accesses = AccessCollector.apply(others)

    def first_fields_rewritten_later() -> bool:
        # Duplicating `first` into a later h. exec. would make it observe (or
        # clobber) values written in between.
        return bool(first_accesses.fields() & other_accesses.write_fields())

    def first_has_large_body() -> bool:
        return len(first.body) > self.max_horizontal_execution_body_size

    def first_writes_protected() -> bool:
        return bool(protected_fields & first_accesses.write_fields())

    def first_has_expensive_function_call() -> bool:
        # Duplicating the body also duplicates these calls; can be allowed
        # via the `allow_expensive_function_duplication` switch.
        if self.allow_expensive_function_duplication:
            return False
        nf = common.NativeFunction
        expensive_calls = {
            nf.SIN,
            nf.COS,
            nf.TAN,
            nf.ARCSIN,
            nf.ARCCOS,
            nf.ARCTAN,
            nf.SQRT,
            nf.EXP,
            nf.LOG,
        }
        calls = first.iter_tree().if_isinstance(
            oir.NativeFuncCall).getattr("func")
        return any(call in expensive_calls for call in calls)

    # Guards are evaluated lazily, in this order (short-circuit `or`).
    if (first_fields_rewritten_later() or first_writes_protected()
            or first_has_large_body()
            or first_has_expensive_function_call()):
        return [first] + self._merge(others, symtable, new_symbol_name,
                                     protected_fields)

    writes = first_accesses.write_fields()
    others_otf: List[oir.HorizontalExecution] = []
    for horizontal_execution in others:
        # All offsets at which this h. exec. reads any field written by `first`.
        read_offsets: Set[Tuple[int, int, int]] = set()
        read_offsets = read_offsets.union(
            *(offsets for field, offsets in AccessCollector.apply(
                horizontal_execution).read_offsets().items()
              if field in writes))
        if not read_offsets:
            # Does not depend on `first`: keep unchanged.
            others_otf.append(horizontal_execution)
            continue
        # One new symbol per (written field, read offset) pair. NOTE(review):
        # this is the full cross product, so pairs that are never actually
        # read also get a name and a LocalScalar declaration below.
        offset_symbol_map = {(name, o): new_symbol_name(name)
                             for name in writes for o in read_offsets}
        merged = oir.HorizontalExecution(
            # Presumably rewrites reads of `first`'s outputs into accesses of
            # the matching new scalars (visit_FieldAccess not shown here).
            body=self.visit(horizontal_execution.body,
                            offset_symbol_map=offset_symbol_map),
            declarations=horizontal_execution.declarations + [
                oir.LocalScalar(name=new_name,
                                dtype=symtable[old_name].dtype)
                for (old_name, _), new_name in offset_symbol_map.items()
            ]
            # `first`'s own declarations, deduplicated by equality.
            # NOTE(review): a same-named but unequal declaration would be
            # added a second time — confirm this cannot happen.
            + [
                d for d in first.declarations
                if d not in horizontal_execution.declarations
            ],
        )
        # Prepend one copy of `first.body` per read offset (presumably shifted
        # by `offset` and writing to the matching scalars). Each copy is
        # prepended, so the final order is the reverse of the iteration order.
        for offset in read_offsets:
            merged.body = (self.visit(
                first.body,
                shift=offset,
                offset_symbol_map=offset_symbol_map,
                symtable=symtable,
            ) + merged.body)
        others_otf.append(merged)

    # `first` is not part of `others_otf`: it has been inlined where needed.
    return self._merge(others_otf, symtable, new_symbol_name,
                       protected_fields)