示例#1
0
def test_axis_interval(defir_to_gtir):
    """Visit an ``AxisInterval`` and check both results are ``gtir.AxisBound``.

    The interval runs from ``START + 0`` to ``END + 1``. The visitor's
    result is unpacked into exactly two values, so it must yield a pair.
    """
    lower = AxisBound(level=LevelMarker.START, offset=0)
    upper = AxisBound(level=LevelMarker.END, offset=1)
    interval = AxisInterval(start=lower, end=upper)

    axis_start, axis_end = defir_to_gtir.visit_AxisInterval(interval)

    for converted in (axis_start, axis_end):
        assert isinstance(converted, gtir.AxisBound)
示例#2
0
 def build(self) -> ComputationBlock:
     """Assemble the ``ComputationBlock`` described by this builder.

     Sets ``self.loc.scope`` to the parent's ``child_scope`` first, then
     builds every child statement into the block body. The block interval
     runs from ``START + self.start`` to ``END + self.end``.
     """
     self.loc.scope = self.parent.child_scope

     lower = AxisBound(level=LevelMarker.START, offset=self.start)
     upper = AxisBound(level=LevelMarker.END, offset=self.end)
     interval = AxisInterval(start=lower, end=upper)

     # Read the order before building children, matching the original
     # argument evaluation order.
     order = self.order

     statements = []
     for child in self.children:
         statements.append(child.build())

     return ComputationBlock(
         interval=interval,
         iteration_order=order,
         body=BlockStmt(stmts=statements),
         loc=self.loc,
     )
示例#3
0
 def build(self) -> ComputationBlock:
     """Assemble the ``ComputationBlock`` described by this builder.

     With a (truthy) parent, the location scope is the parent's
     ``child_scope`` and every field of this builder that the parent does
     not have gets a non-API ``FieldDecl`` at the start of the body.
     Without a parent, ``self.scope`` is used and nothing is declared.
     The built child statements follow the declarations.
     """
     if self.parent:
         self.loc.scope = self.parent.child_scope
         temp_fields = self.fields.difference(self.parent.fields)
     else:
         self.loc.scope = self.scope
         temp_fields = set()

     statements = []
     for field_name in temp_fields:
         # ``self.parent`` is only dereferenced here, i.e. only when there
         # is at least one temporary field (which implies a parent).
         statements.append(
             FieldDecl(
                 name=field_name,
                 data_type=DataType.AUTO,
                 axes=self.parent.domain.axes_names,
                 is_api=False,
             )
         )

     lower = AxisBound(level=LevelMarker.START, offset=self.start)
     upper = AxisBound(level=LevelMarker.END, offset=self.end)
     interval = AxisInterval(start=lower, end=upper)
     order = self.order

     for child in self.children:
         statements.append(child.build())

     return ComputationBlock(
         interval=interval,
         iteration_order=order,
         body=BlockStmt(stmts=statements),
         loc=self.loc,
     )
示例#4
0
def test_intervalinfo_overlap():
    """Exercise ``overlap_with_extent`` on three interval/extent pairs."""
    # Case 1: [START+2, START+4] against extent (0, 2).
    shifted_interval = AxisInterval(
        start=AxisBound(level=LevelMarker.START, offset=2),
        end=AxisBound(level=LevelMarker.START, offset=4),
    )
    overlap = overlap_with_extent(shifted_interval, (0, 2))
    assert overlap[0] == -2 and overlap[1] > 100

    # Case 2: [START-1, END+0] against extent (0, 2); the extent comes back
    # unchanged.
    wide_interval = AxisInterval(
        start=AxisBound(level=LevelMarker.START, offset=-1),
        end=AxisBound(level=LevelMarker.END, offset=0),
    )
    overlap = overlap_with_extent(wide_interval, (0, 2))
    assert overlap == (0, 2)

    # Case 3: [START-3, START-1] against the zero extent; no overlap is
    # reported.
    leading_interval = AxisInterval(
        start=AxisBound(level=LevelMarker.START, offset=-3),
        end=AxisBound(level=LevelMarker.START, offset=-1),
    )
    overlap = overlap_with_extent(leading_interval, (0, 0))
    assert overlap is None
    """
示例#5
0
def make_definition(name: str, fields: List[str], domain: Domain,
                    body: BodyType,
                    iteration_order: IterationOrder) -> StencilDefinition:
    """Build a ``StencilDefinition`` with one computation over the full axis.

    Args:
        name: Stencil name; also passed as ``loc_scope`` to each assignment.
        fields: Names of the API fields; each becomes a positional
            ``ArgumentInfo`` and an ``is_api=True`` ``FieldDecl``.
        domain: Supplies ``axes_names`` for every field declaration.
        body: Sequence of assignment descriptions. Elements ``[0]`` and
            ``[1]`` of each entry are treated as field names, and the whole
            entry is unpacked into ``make_assign``.
        iteration_order: Iteration order of the single computation block.

    Returns:
        The assembled ``StencilDefinition``. Names that appear in ``body``
        but not in ``fields`` are appended to ``api_fields`` as
        ``is_api=False`` declarations.
    """
    api_signature = [
        ArgumentInfo(name=field_name, is_keyword=False)
        for field_name in fields
    ]

    # Names found in the first two slots of the body entries, minus API fields.
    first_names = {entry[0] for entry in body}
    second_names = {entry[1] for entry in body}
    tmp_fields = first_names.union(second_names).difference(fields)

    declared = [
        FieldDecl(
            name=field_name,
            data_type=DataType.AUTO,
            axes=domain.axes_names,
            is_api=True,
        )
        for field_name in fields
    ]
    temporaries = [
        FieldDecl(
            name=field_name,
            data_type=DataType.AUTO,
            axes=domain.axes_names,
            is_api=False,
        )
        for field_name in tmp_fields
    ]
    api_fields = declared + temporaries

    full_interval = AxisInterval(
        start=AxisBound(level=LevelMarker.START),
        end=AxisBound(level=LevelMarker.END),
    )
    # The position of each entry in ``body`` serves as its location line.
    assignments = [
        make_assign(*assign, loc_scope=name, loc_line=line)
        for line, assign in enumerate(body)
    ]
    computation = ComputationBlock(
        interval=full_interval,
        iteration_order=iteration_order,
        body=BlockStmt(stmts=assignments),
    )

    return StencilDefinition(
        name=name,
        domain=domain,
        api_signature=api_signature,
        api_fields=api_fields,
        parameters=[],
        computations=[computation],
        docstring="",
    )
示例#6
0
def test_axis_bound(defir_to_gtir):
    """Visit an ``AxisBound`` and check type, level and offset of the result."""
    expected_offset = -51
    source_bound = AxisBound(level=LevelMarker.START, offset=expected_offset)

    converted = defir_to_gtir.visit_AxisBound(source_bound)

    assert isinstance(converted, gtir.AxisBound)
    assert converted.level == common.LevelMarker.START
    assert converted.offset == expected_offset