def api_fields(self) -> List[FieldDecl]:
    """Return field declarations: temporaries first, then API fields.

    Temporaries are the names in ``self.field_names`` that are not in
    ``self.fields``; they are declared with ``is_api=False``. Every name in
    ``self.fields`` is then declared with ``is_api=True``, in the iteration
    order of ``self.fields``. All declarations use ``DataType.AUTO`` and the
    axes of ``self.domain``.

    Returns:
        List of ``FieldDecl``: sorted temporaries followed by API fields.
    """
    # Sorted: iterating the set produced by ``difference`` yields an order that
    # depends on string hashing, which is randomized per interpreter run, so the
    # unsorted output order was not reproducible.
    tmp_field_names = sorted(self.field_names.difference(self.fields))
    tmp_fields = [
        FieldDecl(name=n, data_type=DataType.AUTO, axes=self.domain.axes_names, is_api=False)
        for n in tmp_field_names
    ]
    return tmp_fields + [
        FieldDecl(name=n, data_type=DataType.AUTO, axes=self.domain.axes_names, is_api=True)
        for n in self.fields
    ]
def test_field_decl_dims(defir_to_gtir, axes, expected_mask):
    """Lowering an INT64 FieldDecl with the given ``axes`` yields ``expected_mask`` dimensions."""
    source = FieldDecl(name="a", data_type=DataType.INT64, axes=axes, is_api=True)
    lowered = defir_to_gtir.visit_FieldDecl(source)
    assert lowered.dimensions == expected_mask
def api_fields(self) -> List[FieldDecl]:
    """Declare every name in ``self.fields`` as an API field.

    Each declaration uses ``DataType.AUTO``, the axes of ``self.domain`` and
    ``is_api=True``. Order follows the iteration order of ``self.fields``.
    """
    declarations = []
    for field_name in self.fields:
        declarations.append(
            FieldDecl(
                name=field_name,
                data_type=DataType.AUTO,
                axes=self.domain.axes_names,
                is_api=True,
            )
        )
    return declarations
def test_field_decl(defir_to_gtir):
    """Lowering a BOOL FieldDecl yields a ``gtir.FieldDecl`` keeping name and dtype."""
    source = FieldDecl(name="a", data_type=DataType.BOOL, axes=["I", "J", "K"], is_api=True)
    result = defir_to_gtir.visit_FieldDecl(source)
    assert isinstance(result, gtir.FieldDecl)
    assert result.name == "a"
    assert result.dtype == common.DataType.BOOL
def make_definition(
    name: str,
    fields: List[str],
    domain: Domain,
    body: BodyType,
    iteration_order: IterationOrder,
) -> StencilDefinition:
    """Build a ``StencilDefinition`` with one computation block from assignment tuples.

    Args:
        name: Stencil name; also passed as ``loc_scope`` to every ``make_assign``.
        fields: API field names, in signature order. Each becomes a positional
            (``is_keyword=False``) argument and an ``is_api=True`` declaration.
        domain: Its ``axes_names`` are used for every field declaration.
        body: Sequence of assignment tuples, each unpacked into ``make_assign``.
            Elements 0 and 1 of each tuple that are not in ``fields`` are
            declared as temporaries (``is_api=False``). ``body`` is iterated
            twice, so it must be re-iterable.
        iteration_order: Iteration order of the single computation block.

    Returns:
        A ``StencilDefinition`` with no parameters, an empty docstring and one
        ``ComputationBlock`` whose interval runs from ``LevelMarker.START`` to
        ``LevelMarker.END``. Each statement gets ``loc_line`` equal to its
        0-based index in ``body``.
    """
    api_signature = [ArgumentInfo(name=n, is_keyword=False) for n in fields]
    # Temporaries: names in position 0 or 1 of an assignment that are not API
    # fields. dict.fromkeys de-duplicates while keeping first-appearance order.
    # The previous set-based version iterated in a hash-dependent order, which
    # changes between interpreter runs for string names.
    api_names = set(fields)
    used_names = dict.fromkeys(n for assign in body for n in (assign[0], assign[1]))
    tmp_fields = [n for n in used_names if n not in api_names]
    api_fields = [
        FieldDecl(name=n, data_type=DataType.AUTO, axes=domain.axes_names, is_api=True)
        for n in fields
    ] + [
        FieldDecl(name=n, data_type=DataType.AUTO, axes=domain.axes_names, is_api=False)
        for n in tmp_fields
    ]
    return StencilDefinition(
        name=name,
        domain=domain,
        api_signature=api_signature,
        api_fields=api_fields,
        parameters=[],
        computations=[
            ComputationBlock(
                interval=AxisInterval(
                    start=AxisBound(level=LevelMarker.START),
                    end=AxisBound(level=LevelMarker.END),
                ),
                iteration_order=iteration_order,
                body=BlockStmt(
                    stmts=[
                        make_assign(*assign, loc_scope=name, loc_line=i)
                        for i, assign in enumerate(body)
                    ]
                ),
            )
        ],
        docstring="",
    )
def build(self) -> ComputationBlock:
    """Build a ``ComputationBlock`` from this builder and its children.

    Side effect: sets ``self.loc.scope`` to ``self.parent.child_scope`` when a
    parent is present, otherwise to ``self.scope``.

    Names in ``self.fields`` that are not in ``self.parent.fields`` are
    declared as temporaries (``DataType.AUTO``, the parent domain's axes,
    ``is_api=False``). These declarations are placed ahead of the statements
    built from ``self.children``. Without a parent, no temporaries are declared.

    Returns:
        ``ComputationBlock`` spanning ``LevelMarker.START + self.start`` to
        ``LevelMarker.END + self.end`` with iteration order ``self.order``.
    """
    self.loc.scope = self.parent.child_scope if self.parent else self.scope
    # Sorted: iterating a set of strings follows hash order, which is randomized
    # per interpreter run, so the unsorted declaration order was not reproducible.
    temp_fields = sorted(self.fields.difference(self.parent.fields)) if self.parent else []
    temp_decls = [
        FieldDecl(
            name=n,
            data_type=DataType.AUTO,
            axes=self.parent.domain.axes_names,
            is_api=False,
        )
        for n in temp_fields
    ]
    return ComputationBlock(
        interval=AxisInterval(
            start=AxisBound(level=LevelMarker.START, offset=self.start),
            end=AxisBound(level=LevelMarker.END, offset=self.end),
        ),
        iteration_order=self.order,
        body=BlockStmt(stmts=temp_decls + [stmt.build() for stmt in self.children]),
        loc=self.loc,
    )