def test_axis_interval(defir_to_gtir):
    """Convert a definition-IR ``AxisInterval`` and check both resulting GTIR bounds.

    Besides the node types, the offsets are checked. Start and end use distinct
    offsets (0 and 1), so a converter that swapped or dropped a bound would fail here
    rather than silently pass on the type checks alone.
    """
    axis_interval = AxisInterval(
        start=AxisBound(level=LevelMarker.START, offset=0),
        end=AxisBound(level=LevelMarker.END, offset=1),
    )
    axis_start, axis_end = defir_to_gtir.visit_AxisInterval(axis_interval)
    assert isinstance(axis_start, gtir.AxisBound)
    assert isinstance(axis_end, gtir.AxisBound)
    # Offsets must be carried over unchanged and in start/end order.
    assert axis_start.offset == 0
    assert axis_end.offset == 1
def build(self) -> ComputationBlock:
    """Assemble the ``ComputationBlock`` described by this builder.

    ``self.parent`` is dereferenced unconditionally, so a parent must be attached.
    As a side effect, ``self.loc.scope`` is overwritten with ``self.parent.child_scope``.

    The block's interval runs from ``LevelMarker.START`` offset ``self.start`` to
    ``LevelMarker.END`` offset ``self.end``. Its body holds one statement per child,
    produced by calling each child's ``build()`` in order.
    """
    self.loc.scope = self.parent.child_scope
    lower = AxisBound(level=LevelMarker.START, offset=self.start)
    upper = AxisBound(level=LevelMarker.END, offset=self.end)
    order = self.order
    statements = [child.build() for child in self.children]
    return ComputationBlock(
        interval=AxisInterval(start=lower, end=upper),
        iteration_order=order,
        body=BlockStmt(stmts=statements),
        loc=self.loc,
    )
def build(self) -> ComputationBlock:
    """Build a ``ComputationBlock`` from this builder, declaring its temporaries first.

    Side effect: ``self.loc.scope`` is set to the parent's ``child_scope`` when a
    parent exists, otherwise to ``self.scope``.

    Names in ``self.fields`` that are not in the parent's ``fields`` are declared at
    the top of the block body. Each becomes a non-API ``FieldDecl`` with
    ``DataType.AUTO`` and the parent's domain axes. Without a parent, no declarations
    are emitted.

    Returns:
        A ``ComputationBlock`` whose interval runs from ``LevelMarker.START`` offset
        ``self.start`` to ``LevelMarker.END`` offset ``self.end``, iterated in
        ``self.order``. The body is the temporary declarations followed by each
        child's ``build()`` result.
    """
    self.loc.scope = self.parent.child_scope if self.parent else self.scope
    # Sort the names: a set of strings iterates in an order that can change between
    # interpreter runs (hash randomization), which would make the emitted IR unstable.
    # NOTE(review): assumes field names are mutually orderable (str) — confirm.
    temp_fields = sorted(self.fields.difference(self.parent.fields)) if self.parent else []
    # The comprehension body only runs when temp_fields is non-empty, which implies a
    # parent exists, so dereferencing self.parent.domain here is safe.
    temp_decls = [
        FieldDecl(
            name=n,
            data_type=DataType.AUTO,
            axes=self.parent.domain.axes_names,
            is_api=False,
        )
        for n in temp_fields
    ]
    return ComputationBlock(
        interval=AxisInterval(
            start=AxisBound(level=LevelMarker.START, offset=self.start),
            end=AxisBound(level=LevelMarker.END, offset=self.end),
        ),
        iteration_order=self.order,
        body=BlockStmt(stmts=temp_decls + [stmt.build() for stmt in self.children]),
        loc=self.loc,
    )
def test_intervalinfo_overlap(): overlap = overlap_with_extent( AxisInterval( start=AxisBound(level=LevelMarker.START, offset=2), end=AxisBound(level=LevelMarker.START, offset=4), ), (0, 2), ) assert overlap[0] == -2 and overlap[1] > 100 overlap = overlap_with_extent( AxisInterval( start=AxisBound(level=LevelMarker.START, offset=-1), end=AxisBound(level=LevelMarker.END, offset=0), ), (0, 2), ) assert overlap == (0, 2) overlap = overlap_with_extent( AxisInterval( start=AxisBound(level=LevelMarker.START, offset=-3), end=AxisBound(level=LevelMarker.START, offset=-1), ), (0, 0), ) assert overlap is None """
def make_definition(
    name: str,
    fields: List[str],
    domain: Domain,
    body: BodyType,
    iteration_order: IterationOrder,
) -> StencilDefinition:
    """Build a ``StencilDefinition`` with a single computation block from a compact body.

    Args:
        name: Stencil name. It is also passed as ``loc_scope`` to every ``make_assign`` call.
        fields: API field names, in signature order. Each becomes a positional
            (non-keyword) argument and an API ``FieldDecl``.
        domain: Domain whose ``axes_names`` are used for every field declaration.
        body: Sequence of assignment specs. Each entry is unpacked into ``make_assign``,
            and its items ``[0]`` and ``[1]`` are collected as field names. Presumably
            these are the assignment target and source — confirm against ``make_assign``.
        iteration_order: Iteration order of the single computation block.

    Returns:
        A ``StencilDefinition`` with no parameters and an empty docstring. Its API
        fields come first, followed by one non-API ``FieldDecl`` per temporary name
        (a name referenced in ``body`` but not listed in ``fields``). Its one
        ``ComputationBlock`` spans ``LevelMarker.START`` to ``LevelMarker.END``.
    """
    api_signature = [ArgumentInfo(name=n, is_keyword=False) for n in fields]
    referenced = {assign[0] for assign in body} | {assign[1] for assign in body}
    # Sorted so the temporary declarations come out in a reproducible order; iterating
    # a set of str directly depends on hash randomization and varies between runs.
    tmp_fields = sorted(referenced.difference(fields))
    api_fields = [
        FieldDecl(name=n, data_type=DataType.AUTO, axes=domain.axes_names, is_api=True)
        for n in fields
    ] + [
        FieldDecl(name=n, data_type=DataType.AUTO, axes=domain.axes_names, is_api=False)
        for n in tmp_fields
    ]
    return StencilDefinition(
        name=name,
        domain=domain,
        api_signature=api_signature,
        api_fields=api_fields,
        parameters=[],
        computations=[
            ComputationBlock(
                interval=AxisInterval(
                    start=AxisBound(level=LevelMarker.START),
                    end=AxisBound(level=LevelMarker.END),
                ),
                iteration_order=iteration_order,
                body=BlockStmt(
                    stmts=[
                        # The line number of each assignment is its index in ``body``.
                        make_assign(*assign, loc_scope=name, loc_line=i)
                        for i, assign in enumerate(body)
                    ]
                ),
            )
        ],
        docstring="",
    )
def test_axis_bound(defir_to_gtir):
    """``visit_AxisBound`` must yield a GTIR bound with the same level and offset."""
    expected_offset = -51
    converted = defir_to_gtir.visit_AxisBound(
        AxisBound(level=LevelMarker.START, offset=expected_offset)
    )
    assert isinstance(converted, gtir.AxisBound)
    assert converted.level == common.LevelMarker.START
    assert converted.offset == expected_offset