Example #1
0
    def visit_FieldIfStmt(self,
                          node: gtir.FieldIfStmt,
                          *,
                          mask: oir.Expr = None,
                          ctx: Context,
                          **kwargs: Any) -> List[oir.Stmt]:
        """Lower a ``gtir.FieldIfStmt`` into a list of masked OIR statements.

        The condition is stored once in a new boolean temporary field, which
        is appended to ``ctx.temp_fields``. The true branch is then visited
        with a mask reading that field, and the false branch (if any) with
        its logical negation. When an enclosing ``mask`` is supplied, it is
        AND-ed with each branch mask so nested conditionals compose.

        Args:
            node: The field-level if statement to lower.
            mask: Mask of an enclosing conditional, or ``None`` when there is
                no enclosing conditional.
            ctx: Lowering context; receives the mask temporary declaration.
            **kwargs: Forwarded to the visits of both branch bodies.

        Returns:
            The assignment that fills the mask field, followed by the lowered
            statements of the true branch and, if present, the false branch.
        """
        # NOTE(review): ``id(node)`` is unique while the tree is alive, but it
        # differs between runs, so generated temporary names are not
        # reproducible across compilations.
        mask_field_decl = oir.Temporary(name=f"mask_{id(node)}",
                                        dtype=DataType.BOOL,
                                        dimensions=(True, True, True))
        ctx.temp_fields.append(mask_field_decl)

        def mask_access() -> oir.FieldAccess:
            # Fresh zero-offset read/write access to the mask field.
            return oir.FieldAccess(
                name=mask_field_decl.name,
                offset=CartesianOffset.zero(),
                dtype=mask_field_decl.dtype,
                loc=node.loc,
            )

        def with_outer_mask(branch_mask: oir.Expr) -> oir.Expr:
            # Restrict a branch mask by the enclosing one (if any), so a
            # nested if only executes where every outer condition holds.
            if mask is None:
                return branch_mask
            return oir.BinaryOp(op=LogicalOperator.AND,
                                left=mask,
                                right=branch_mask,
                                loc=node.loc)

        stmts: List[oir.Stmt] = [
            oir.AssignStmt(left=mask_access(), right=self.visit(node.cond))
        ]

        current_mask = mask_access()
        stmts.extend(
            self.visit(node.true_branch.body,
                       mask=with_outer_mask(current_mask),
                       ctx=ctx,
                       **kwargs))

        if node.false_branch:
            negated_mask = oir.UnaryOp(op=UnaryOperator.NOT,
                                       expr=current_mask,
                                       loc=node.loc)
            stmts.extend(
                self.visit(node.false_branch.body,
                           mask=with_outer_mask(negated_mask),
                           ctx=ctx,
                           **kwargs))

        return stmts
Example #2
0
 def _make_scalar_accessor(name: str) -> gtcpp.AccessorRef:
     """Return a zero-offset ``INT32`` accessor reference of scalar kind named *name*."""
     zero_offset = CartesianOffset.zero()
     accessor = gtcpp.AccessorRef(
         name=name,
         offset=zero_offset,
         kind=ExprKind.SCALAR,
         dtype=common.DataType.INT32,
     )
     return accessor
Example #3
0
    def visit_FieldIfStmt(self,
                          node: gtir.FieldIfStmt,
                          *,
                          mask: oir.Expr = None,
                          ctx: Context,
                          **kwargs: Any) -> None:
        """Lower a ``gtir.FieldIfStmt`` by visiting its branches under masks.

        The lowered condition is handed to ``_create_mask`` together with the
        name ``mask_<node.id_>``. The true branch is visited with a mask that
        reads the resulting field; the false branch (if any) is visited with
        its logical negation. When an enclosing ``mask`` is supplied, it is
        AND-ed with each branch mask so nested conditionals compose.

        Args:
            node: The field-level if statement to lower.
            mask: Mask of an enclosing conditional, or ``None`` when there is
                no enclosing conditional.
            ctx: Lowering context, passed to ``_create_mask`` and to the
                branch visits.
            **kwargs: Accepted but not forwarded to the branch visits.
        """
        mask_field_decl = _create_mask(ctx, f"mask_{node.id_}",
                                       self.visit(node.cond))
        current_mask = oir.FieldAccess(name=mask_field_decl.name,
                                       offset=CartesianOffset.zero(),
                                       dtype=mask_field_decl.dtype)

        def with_outer_mask(branch_mask: oir.Expr) -> oir.Expr:
            # Restrict a branch mask by the enclosing one (if any), so a
            # nested if only executes where every outer condition holds.
            if mask is None:
                return branch_mask
            return oir.BinaryOp(op=LogicalOperator.AND,
                                left=mask,
                                right=branch_mask)

        # NOTE(review): unlike the branch visits' ``ctx``, ``**kwargs`` is not
        # forwarded here — confirm that dropping them is intended.
        self.visit(node.true_branch.body,
                   mask=with_outer_mask(current_mask),
                   ctx=ctx)

        if node.false_branch:
            negated_mask = oir.UnaryOp(op=UnaryOperator.NOT,
                                       expr=current_mask)
            self.visit(
                node.false_branch.body,
                mask=with_outer_mask(negated_mask),
                ctx=ctx,
            )
Example #4
0
 def visit_ScalarAccess(
         self, node: oir.ScalarAccess,
         **kwargs: Any) -> Union[gtcpp.AccessorRef, gtcpp.ScalarAccess]:
     """Translate an OIR scalar access into gtcpp.

     Names found in ``kwargs["stencil_symtable"]`` that resolve to an
     ``oir.ScalarDecl`` become a zero-offset ``gtcpp.AccessorRef``. Any other
     symbol found there must be an ``oir.LocalScalar``. Local scalars and
     names absent from the symbol table become a ``gtcpp.ScalarAccess``.
     """
     assert "stencil_symtable" in kwargs
     symtable = kwargs["stencil_symtable"]

     if node.name not in symtable:
         return gtcpp.ScalarAccess(name=node.name, dtype=node.dtype)

     symbol = symtable[node.name]
     if isinstance(symbol, oir.ScalarDecl):
         # NOTE(review): no ``kind`` is passed, so AccessorRef's default kind
         # applies to this scalar parameter — confirm that is intended.
         return gtcpp.AccessorRef(name=symbol.name,
                                  offset=CartesianOffset.zero(),
                                  dtype=symbol.dtype)

     assert isinstance(symbol, oir.LocalScalar)
     return gtcpp.ScalarAccess(name=node.name, dtype=node.dtype)
Example #5
0
def _create_mask(ctx: "GTIRToOIR.Context", name: str, cond: oir.Expr) -> oir.Temporary:
    """Declare a boolean mask temporary on *ctx* and schedule filling it with *cond*.

    A ``BOOL`` ``oir.Temporary`` called *name*, with ``dimensions=(True, True, True)``,
    is registered via ``ctx.add_decl``. Then a new ``oir.HorizontalExecution``,
    whose single statement assigns *cond* to that temporary at zero offset, is
    registered via ``ctx.add_horizontal_execution``.

    Returns:
        The temporary's declaration, so callers can build accesses to it.
    """
    decl = oir.Temporary(name=name, dtype=DataType.BOOL, dimensions=(True, True, True))
    ctx.add_decl(decl)

    target = oir.FieldAccess(name=decl.name, offset=CartesianOffset.zero(), dtype=decl.dtype)
    assignment = oir.AssignStmt(left=target, right=cond)
    ctx.add_horizontal_execution(oir.HorizontalExecution(body=[assignment], declarations=[]))
    return decl
Example #6
0
 def __init__(self, name) -> None:
     """Store *name* and give the remaining attributes fixed starting values.

     ``_offset`` starts as ``CartesianOffset.zero()``, ``_kind`` as
     ``ExprKind.FIELD`` and ``_dtype`` as ``DataType.FLOAT32``.
     """
     self._name = name
     # Presumably overridden later by other methods of this class — confirm.
     self._offset = CartesianOffset.zero()
     self._kind = ExprKind.FIELD
     self._dtype = DataType.FLOAT32