def test_stencil_to_computation() -> None:
    """Lower a single-assignment stencil with ``OirToNpir`` and check its interface.

    The stencil takes one field parameter ``a`` and one scalar parameter ``b``
    and assigns ``b`` to ``a``. Only ``a`` is expected among the API field
    declarations, both names among the arguments, and exactly one vertical
    pass in the result.
    """
    # Build the pieces in the same order the nested constructor call would
    # evaluate them, so any factory-side sequencing is unaffected.
    field_a = FieldDeclFactory(
        name="a",
        dtype=common.DataType.FLOAT64,
    )
    scalar_b = oir.ScalarDecl(
        name="b",
        dtype=common.DataType.INT32,
    )
    assignment = AssignStmtFactory(
        left=FieldAccessFactory(name="a"),
        right=ScalarAccessFactory(name="b"),
    )
    stencil = StencilFactory(
        name="stencil",
        params=[field_a, scalar_b],
        vertical_loops__0__sections__0__horizontal_executions__0__body=[assignment],
    )

    computation = OirToNpir().visit(stencil)

    api_field_names = {decl.name for decl in computation.api_field_decls}
    assert api_field_names == {"a"}
    assert set(computation.arguments) == {"a", "b"}
    assert len(computation.vertical_passes) == 1
def sdfg_arrays_to_oir_decls(
    sdfg: dace.SDFG,
) -> Tuple[List[oir.Decl], List[oir.Temporary]]:
    """Derive OIR declarations from the data containers and symbols of ``sdfg``.

    Every entry of ``sdfg.arrays`` is classified as follows:

    * a non-transient ``dace.data.Array`` becomes an ``oir.FieldDecl`` parameter,
    * a transient ``dace.data.Array`` becomes an ``oir.Temporary`` declaration,
    * anything else must be a ``dace.data.Scalar`` and becomes an
      ``oir.ScalarDecl`` parameter.

    Afterwards every entry of ``sdfg.symbols`` that is not reported by
    ``internal_symbols(sdfg)`` is appended as an ``oir.ScalarDecl`` parameter.

    Returns:
        A ``(params, decls)`` pair: the parameter declarations (fields and
        scalars from the arrays first, then symbols) and the temporaries.
    """
    api_params = []
    temporaries = []

    for array_name, descriptor in sdfg.arrays.items():
        data_type = common.typestr_to_data_type(dace_dtype_to_typestr(descriptor.dtype))

        if not isinstance(descriptor, dace.data.Array):
            assert isinstance(descriptor, dace.data.Scalar)
            api_params.append(oir.ScalarDecl(name=array_name, dtype=data_type))
            continue

        dims = array_dimensions(descriptor)
        # Transient arrays are temporaries; everything else is an API field.
        if descriptor.transient:
            decl_type, destination = oir.Temporary, temporaries
        else:
            decl_type, destination = oir.FieldDecl, api_params
        destination.append(
            decl_type(
                name=array_name,
                dtype=data_type,
                dimensions=dims,
                # Shape entries past the first ``sum(dims)`` ones are data dimensions.
                data_dims=descriptor.shape[sum(dims):],
            )
        )

    excluded = internal_symbols(sdfg)
    api_params.extend(
        oir.ScalarDecl(
            name=symbol_name,
            dtype=common.typestr_to_data_type(symbol_type.as_numpy_dtype().str),
        )
        for symbol_name, symbol_type in sdfg.symbols.items()
        if symbol_name not in excluded
    )
    return api_params, temporaries
def test_stencil_to_computation():
    """Lower a single-assignment stencil with ``OirToNpir`` and check its parameters.

    The stencil declares a field ``a`` and a scalar ``b`` and assigns ``b`` to
    ``a``. The result must list ``a`` as its only field parameter, ``a`` and
    ``b`` (in that order) as its parameters, and contain one vertical pass.
    """
    # Declarations first, then the statement: the same order in which the
    # nested constructor call evaluates its arguments.
    declarations = [
        FieldDeclFactory(
            name="a",
            dtype=common.DataType.FLOAT64,
        ),
        oir.ScalarDecl(
            name="b",
            dtype=common.DataType.INT32,
        ),
    ]
    target = FieldAccessFactory(name="a")
    value = ScalarAccessFactory(name="b")
    body = [AssignStmtFactory(left=target, right=value)]

    stencil = StencilFactory(
        name="stencil",
        params=declarations,
        vertical_loops__0__sections__0__horizontal_executions__0__body=body,
    )
    computation = OirToNpir().visit(stencil)

    assert computation.field_params == ["a"]
    assert computation.params == ["a", "b"]
    assert len(computation.vertical_passes) == 1
def visit_ScalarDecl(self, node: gtir.ScalarDecl) -> oir.ScalarDecl:
    """Translate a GTIR scalar declaration into an OIR scalar declaration.

    The ``name``, ``dtype`` and ``loc`` attributes of ``node`` are copied
    over unchanged.
    """
    copied_fields = {
        "name": node.name,
        "dtype": node.dtype,
        "loc": node.loc,
    }
    return oir.ScalarDecl(**copied_fields)
def visit_ScalarDecl(self, node: gtir.ScalarDecl, **kwargs: Any) -> oir.ScalarDecl:
    """Translate a GTIR scalar declaration into an OIR scalar declaration.

    Only ``name`` and ``dtype`` are copied from ``node``. Extra keyword
    arguments are accepted for signature compatibility and are not used.
    """
    scalar_name, scalar_dtype = node.name, node.dtype
    return oir.ScalarDecl(name=scalar_name, dtype=scalar_dtype)
def visit_VerticalLoopSection(
    self,
    node: oir.VerticalLoopSection,
    *,
    block_extents: Dict[int, Extent],
    new_symbol_name: Callable[[str], str],
    **kwargs: Any,
) -> oir.VerticalLoopSection:
    """Fuse consecutive horizontal executions of a section where possible.

    Walking the section's horizontal executions in order, each one is merged
    into the last entry of the result unless either

    * it reads, with a non-zero value in either of the first two offset
      components, a field that the last entry writes, or
    * its entry in ``block_extents`` differs from that of the last entry.

    On a merge, local declarations whose names occur in both executions are
    renamed in the second one via ``new_symbol_name``, and the bodies and
    declarations are concatenated into a new ``oir.HorizontalExecution``.

    Args:
        node: The section whose horizontal executions are processed.
        block_extents: Extents keyed by ``id()`` of the original horizontal
            execution nodes.
        new_symbol_name: Called with a clashing local name; returns the
            replacement name.
        **kwargs: Forwarded to ``self.visit`` when rewriting a merged body.

    Returns:
        A new section with the same interval and the merged executions.
    """
    merged = [node.horizontal_executions[0]]
    merged_extents = [block_extents[id(merged[-1])]]

    for candidate in node.horizontal_executions[1:]:
        previous = merged[-1]
        previous_extent = merged_extents[-1]

        written_before = AccessCollector.apply(previous).write_fields()
        candidate_reads = AccessCollector.apply(candidate).read_offsets()
        shifted_reads = {
            field
            for field, offsets in candidate_reads.items()
            if any(off[0] != 0 or off[1] != 0 for off in offsets)
        }
        conflicts = written_before & shifted_reads
        candidate_extent = block_extents[id(candidate)]

        if conflicts or previous_extent != candidate_extent:
            # Not mergeable: keep the candidate as its own entry.
            merged.append(candidate)
            merged_extents.append(candidate_extent)
            continue

        # Mergeable: locals declared in both executions must be renamed in
        # the candidate before the bodies are concatenated.
        previous_local_names = {decl.name for decl in previous.declarations}
        candidate_local_names = {decl.name for decl in candidate.declarations}
        clashing = previous_local_names & candidate_local_names

        # Old name -> fresh name, applied to the candidate only.
        scalar_map = {old_name: new_symbol_name(old_name) for old_name in clashing}
        candidate_decls = {decl.name: decl for decl in candidate.declarations}
        renamed_body = self.visit(candidate.body, scalar_map=scalar_map, **kwargs)

        kept_decls = [
            decl for decl in candidate.declarations if decl.name not in clashing
        ]
        renamed_decls = [
            oir.ScalarDecl(name=fresh_name, dtype=candidate_decls[old_name].dtype)
            for old_name, fresh_name in scalar_map.items()
        ]

        merged[-1] = oir.HorizontalExecution(
            body=previous.body + renamed_body,
            declarations=previous.declarations + kept_decls + renamed_decls,
        )

    return oir.VerticalLoopSection(
        interval=node.interval,
        horizontal_executions=merged,
    )
def visit_VerticalLoopSection(
    self,
    node: oir.VerticalLoopSection,
    *,
    block_extents: Dict[int, Extent],
    new_symbol_name: Callable[[str], str],
    **kwargs: Any,
) -> oir.VerticalLoopSection:
    """Fuse consecutive horizontal executions of a section where possible.

    The section's horizontal executions are processed in order. Each one is
    merged into the last entry of the result unless it reads, with a non-zero
    value in either of the first two offset components, a field written by
    that last entry, or unless its entry in ``block_extents`` differs from the
    last entry's.

    Intermediate results are kept in a plain local dataclass and converted
    back to ``oir.HorizontalExecution`` only once, when the section is built.

    Args:
        node: The section whose horizontal executions are processed.
        block_extents: Extents keyed by ``id()`` of the original horizontal
            execution nodes.
        new_symbol_name: Called with a clashing local name; returns the
            replacement name.
        **kwargs: Forwarded to ``self.visit`` when rewriting a merged body.

    Returns:
        A new section with the same interval and ``loc`` and the merged
        horizontal executions.
    """

    @dataclass
    class UncheckedHorizontalExecution:
        # local replacement without type checking for type-checked oir node
        # required to reach reasonable run times for large node counts
        body: List[oir.Stmt]
        declarations: List[oir.LocalScalar]
        loc: Optional[SourceLocation]

        # Guard against drift between this stand-in and the real node. The
        # class body runs on every call of the enclosing method, so this check
        # is re-evaluated each time (unless asserts are disabled).
        # NOTE(review): `__fields__` looks like the pydantic model field
        # mapping; confirm against the oir node base class.
        assert set(oir.HorizontalExecution.__fields__) == {
            "loc",
            "symtable_",
            "body",
            "declarations",
        }, (
            "Unexpected field in oir.HorizontalExecution, "
            "probably UncheckedHorizontalExecution needs an update"
        )

        # Wrap an existing oir node; body and declarations are shared, not copied.
        @classmethod
        def from_oir(cls, hexec: oir.HorizontalExecution):
            return cls(body=hexec.body, declarations=hexec.declarations, loc=hexec.loc)

        # Build the real oir node from the accumulated body and declarations.
        def to_oir(self) -> oir.HorizontalExecution:
            return oir.HorizontalExecution(
                body=self.body, declarations=self.declarations, loc=self.loc
            )

    # Seed the result with the first horizontal execution; extents are looked
    # up by the id() of the ORIGINAL node, not of the wrapper.
    horizontal_executions = [
        UncheckedHorizontalExecution.from_oir(node.horizontal_executions[0])
    ]
    new_block_extents = [block_extents[id(node.horizontal_executions[0])]]
    # Fields written by the current last entry of `horizontal_executions`;
    # maintained incrementally below instead of being recomputed per iteration.
    last_writes = AccessCollector.apply(node.horizontal_executions[0]).write_fields()
    for this_hexec in node.horizontal_executions[1:]:
        last_extent = new_block_extents[-1]

        # Fields read by this execution with a non-zero value in either of the
        # first two offset components.
        this_offset_reads = {
            name
            for name, offsets in AccessCollector.apply(this_hexec).read_offsets().items()
            if any(off[0] != 0 or off[1] != 0 for off in offsets)
        }

        reads_with_offset_after_write = last_writes & this_offset_reads
        this_extent = block_extents[id(this_hexec)]

        if reads_with_offset_after_write or last_extent != this_extent:
            # Cannot merge: simply append to list
            horizontal_executions.append(UncheckedHorizontalExecution.from_oir(this_hexec))
            new_block_extents.append(this_extent)
            # A new last entry starts here, so its write set starts fresh.
            last_writes = AccessCollector.apply(this_hexec).write_fields()
        else:
            # Merge
            # Local names declared both in the last entry and in this execution.
            duplicated_locals = {
                decl.name for decl in horizontal_executions[-1].declarations
            } & {decl.name for decl in this_hexec.declarations}
            # Map from old to new scalar names applied to the second horizontal execution
            scalar_map = {name: new_symbol_name(name) for name in duplicated_locals}
            locals_symtable = {decl.name: decl for decl in this_hexec.declarations}

            # Re-visit the body with the renaming map passed along as a keyword.
            new_body = self.visit(this_hexec.body, scalar_map=scalar_map, **kwargs)
            this_not_duplicated = [
                decl for decl in this_hexec.declarations if decl.name not in duplicated_locals
            ]
            # Re-declare each clashing local under its new name, keeping its dtype.
            this_mapped = [
                oir.ScalarDecl(name=scalar_map[name], dtype=locals_symtable[name].dtype)
                for name in duplicated_locals
            ]

            # The merged entry keeps the `loc` of the entry it was merged into.
            horizontal_executions[-1] = UncheckedHorizontalExecution(
                body=horizontal_executions[-1].body + new_body,
                declarations=(
                    horizontal_executions[-1].declarations + this_not_duplicated + this_mapped
                ),
                loc=horizontal_executions[-1].loc,
            )
            # The last entry now also writes whatever the appended body writes.
            last_writes |= AccessCollector.apply(new_body).write_fields()

    # Convert the unchecked stand-ins back into real oir nodes.
    return oir.VerticalLoopSection(
        interval=node.interval,
        horizontal_executions=[hexec.to_oir() for hexec in horizontal_executions],
        loc=node.loc,
    )