Example #1
0
 def api_fields(self) -> List[FieldDecl]:
     """Return declarations for every field the definition references.

     Names in ``self.field_names`` that are not in ``self.fields`` are
     declared first with ``is_api=False``. They are followed by one
     declaration per entry of ``self.fields`` (in that order) with
     ``is_api=True``. Every declaration uses ``DataType.AUTO`` and
     ``self.domain.axes_names``.

     Returns:
         A list of ``FieldDecl``: non-API (temporary) fields first, then
         API fields.
     """
     # ``difference`` yields a set. For str names its iteration order depends
     # on hash randomization (PYTHONHASHSEED), so the declaration order could
     # change between interpreter runs. Sorting makes the result
     # deterministic. Assumes field names are mutually comparable (e.g. all
     # str) -- TODO confirm.
     tmp_field_names = sorted(self.field_names.difference(self.fields))
     tmp_fields = [
         FieldDecl(name=n,
                   data_type=DataType.AUTO,
                   axes=self.domain.axes_names,
                   is_api=False) for n in tmp_field_names
     ]
     return tmp_fields + [
         FieldDecl(name=n,
                   data_type=DataType.AUTO,
                   axes=self.domain.axes_names,
                   is_api=True) for n in self.fields
     ]
Example #2
0
def test_field_decl_dims(defir_to_gtir, axes, expected_mask):
    """Check that an INT64 API field declared over ``axes`` converts to
    a declaration whose ``dimensions`` equal ``expected_mask``."""
    # ``axes`` / ``expected_mask`` are presumably supplied by pytest
    # parametrization outside this view -- confirm.
    converted = defir_to_gtir.visit_FieldDecl(
        FieldDecl(
            name="a",
            data_type=DataType.INT64,
            axes=axes,
            is_api=True,
        ))
    assert converted.dimensions == expected_mask
Example #3
0
 def api_fields(self) -> List[FieldDecl]:
     """Return one API field declaration per entry of ``self.fields``.

     Each declaration keeps the order of ``self.fields`` and uses
     ``DataType.AUTO``, ``self.domain.axes_names`` and ``is_api=True``.
     """
     declarations: List[FieldDecl] = []
     for field_name in self.fields:
         # ``self.domain`` is read per entry, so an empty ``self.fields``
         # never touches it.
         declarations.append(
             FieldDecl(
                 name=field_name,
                 data_type=DataType.AUTO,
                 axes=self.domain.axes_names,
                 is_api=True,
             ))
     return declarations
Example #4
0
def test_field_decl(defir_to_gtir):
    """Check that a BOOL API field converts to a ``gtir.FieldDecl`` that
    keeps its name and carries ``common.DataType.BOOL``."""
    source_decl = FieldDecl(
        name="a",
        data_type=DataType.BOOL,
        axes=["I", "J", "K"],
        is_api=True,
    )
    result = defir_to_gtir.visit_FieldDecl(source_decl)

    assert isinstance(result, gtir.FieldDecl)
    assert result.name == "a"
    assert result.dtype == common.DataType.BOOL
Example #5
0
def make_definition(name: str, fields: List[str], domain: Domain,
                    body: BodyType,
                    iteration_order: IterationOrder) -> StencilDefinition:
    """Build a ``StencilDefinition`` holding a single computation block.

    Args:
        name: Stencil name; also used as ``loc_scope`` for every assignment.
        fields: API field names, in signature order.
        domain: Domain whose ``axes_names`` are used for every field.
        body: Sequence of assignment specs. Each item is unpacked into
            ``make_assign``, and its elements ``[0]`` and ``[1]`` are treated
            as field names (presumably target and source -- confirm against
            ``make_assign``).
        iteration_order: Iteration order of the computation block.

    Returns:
        A ``StencilDefinition`` with no parameters and an empty docstring.
        Its ``api_fields`` list the API fields (``is_api=True``) followed by
        every other name referenced in ``body`` (``is_api=False``). Its single
        ``ComputationBlock`` spans ``LevelMarker.START`` to
        ``LevelMarker.END``, with one assignment per ``body`` item whose
        ``loc_line`` is the item's index.
    """
    api_signature = [ArgumentInfo(name=n, is_keyword=False) for n in fields]
    api_names = set(fields)
    # Collect referenced names in first-appearance order (dicts preserve
    # insertion order) rather than via set union. Set iteration order for
    # str depends on hash randomization, so temporaries could otherwise be
    # declared in a different order on each run.
    referenced = dict.fromkeys(item[0] for item in body)
    referenced.update(dict.fromkeys(item[1] for item in body))
    tmp_fields = [n for n in referenced if n not in api_names]
    api_fields = [
        FieldDecl(name=n,
                  data_type=DataType.AUTO,
                  axes=domain.axes_names,
                  is_api=True) for n in fields
    ] + [
        FieldDecl(name=n,
                  data_type=DataType.AUTO,
                  axes=domain.axes_names,
                  is_api=False) for n in tmp_fields
    ]
    return StencilDefinition(
        name=name,
        domain=domain,
        api_signature=api_signature,
        api_fields=api_fields,
        parameters=[],
        computations=[
            ComputationBlock(
                interval=AxisInterval(start=AxisBound(level=LevelMarker.START),
                                      end=AxisBound(level=LevelMarker.END)),
                iteration_order=iteration_order,
                body=BlockStmt(stmts=[
                    make_assign(*assign, loc_scope=name, loc_line=i)
                    for i, assign in enumerate(body)
                ]),
            )
        ],
        docstring="",
    )
Example #6
0
 def build(self) -> ComputationBlock:
     """Build a ``ComputationBlock`` from this builder and its children.

     Side effect: sets ``self.loc.scope`` to ``self.parent.child_scope``
     when there is a parent, otherwise to ``self.scope``.

     Names in ``self.fields`` that are not in ``self.parent.fields`` are
     declared as non-API fields (``is_api=False``, ``DataType.AUTO``, the
     parent domain's axes) at the start of the block body. The built
     statements of ``self.children`` follow them. Without a parent no such
     declarations are emitted.

     Returns:
         A ``ComputationBlock`` whose interval runs from the START level
         offset by ``self.start`` to the END level offset by ``self.end``,
         with iteration order ``self.order`` and location ``self.loc``.
     """
     self.loc.scope = self.parent.child_scope if self.parent else self.scope
     # Sorted so declarations come out in a deterministic order. Iterating a
     # set of str names follows hash randomization and could differ between
     # runs. Assumes names are mutually comparable (e.g. all str) -- confirm.
     temp_fields = sorted(self.fields.difference(
         self.parent.fields)) if self.parent else []
     temp_decls = [
         FieldDecl(name=n,
                   data_type=DataType.AUTO,
                   axes=self.parent.domain.axes_names,
                   is_api=False) for n in temp_fields
     ]
     return ComputationBlock(
         interval=AxisInterval(
             start=AxisBound(level=LevelMarker.START, offset=self.start),
             end=AxisBound(level=LevelMarker.END, offset=self.end),
         ),
         iteration_order=self.order,
         body=BlockStmt(stmts=temp_decls +
                        [stmt.build() for stmt in self.children], ),
         loc=self.loc,
     )