def visit_FieldIfStmt(
    self,
    node: gtir.FieldIfStmt,
    *,
    mask: oir.Expr = None,
    ctx: Context,
    **kwargs: Any,
) -> List[oir.Stmt]:
    """Lower a field ``if`` statement into a mask assignment plus masked branch statements.

    A boolean temporary named ``mask_<id(node)>`` is created and appended to
    ``ctx.temp_fields``. The first returned statement assigns the visited
    condition to that temporary; the statements produced by visiting the
    true branch body (and the false branch body, if present) follow it.

    Args:
        node: The field ``if`` statement to lower.
        mask: Mask expression of an enclosing conditional, or ``None`` when
            there is none. When given, it is AND-ed with this statement's
            own mask before being passed on to the branches.
        ctx: Context whose ``temp_fields`` list receives the mask temporary.
        **kwargs: Forwarded unchanged to the visits of the branch bodies.

    Returns:
        The mask assignment followed by the statements of the branches.
    """

    def _and_with_outer(expr: oir.Expr) -> oir.Expr:
        # Compare against None explicitly: ``mask`` is an expression node, and
        # whether a mask was supplied must not depend on the node's truthiness.
        if mask is None:
            return expr
        return oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=expr, loc=node.loc)

    mask_field_decl = oir.Temporary(
        name=f"mask_{id(node)}", dtype=DataType.BOOL, dimensions=(True, True, True)
    )
    ctx.temp_fields.append(mask_field_decl)

    # The condition is evaluated once into the temporary; both branches then
    # read the temporary instead of re-evaluating the condition.
    stmts = [
        oir.AssignStmt(
            left=oir.FieldAccess(
                name=mask_field_decl.name,
                offset=CartesianOffset.zero(),
                dtype=DataType.BOOL,
                loc=node.loc,
            ),
            right=self.visit(node.cond),
        )
    ]

    current_mask = oir.FieldAccess(
        name=mask_field_decl.name,
        offset=CartesianOffset.zero(),
        dtype=mask_field_decl.dtype,
        loc=node.loc,
    )

    stmts.extend(
        self.visit(node.true_branch.body, mask=_and_with_outer(current_mask), ctx=ctx, **kwargs)
    )

    if node.false_branch:
        # The false branch runs where this statement's own mask is NOT set.
        negated_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=current_mask)
        stmts.extend(
            self.visit(
                node.false_branch.body, mask=_and_with_outer(negated_mask), ctx=ctx, **kwargs
            )
        )

    return stmts
def _make_scalar_accessor(name: str) -> gtcpp.AccessorRef:
    """Build a scalar-kind ``INT32`` accessor reference to ``name`` at zero offset."""
    zero_offset = CartesianOffset.zero()
    accessor = gtcpp.AccessorRef(
        name=name,
        offset=zero_offset,
        kind=ExprKind.SCALAR,
        dtype=common.DataType.INT32,
    )
    return accessor
def visit_FieldIfStmt(
    self,
    node: gtir.FieldIfStmt,
    *,
    mask: oir.Expr = None,
    ctx: Context,
    **kwargs: Any,
) -> None:
    """Lower a field ``if`` statement by registering a mask and visiting both branches under it.

    The visited condition is handed to ``_create_mask`` together with ``ctx``
    and the name ``mask_<node.id_>``. The true branch body is then visited
    with that mask, and the false branch body (if present) with its negation.

    Args:
        node: The field ``if`` statement to lower.
        mask: Mask expression of an enclosing conditional, or ``None`` when
            there is none. When given, it is AND-ed with this statement's
            own mask before being passed on to the branches.
        ctx: Context passed to ``_create_mask`` and to the branch visits.
        **kwargs: Accepted for signature compatibility; not forwarded to the
            branch visits.
    """

    def _and_with_outer(expr: oir.Expr) -> oir.Expr:
        # Compare against None explicitly: ``mask`` is an expression node, and
        # whether a mask was supplied must not depend on the node's truthiness.
        if mask is None:
            return expr
        return oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=expr)

    mask_field_decl = _create_mask(ctx, f"mask_{node.id_}", self.visit(node.cond))
    current_mask = oir.FieldAccess(
        name=mask_field_decl.name,
        offset=CartesianOffset.zero(),
        dtype=mask_field_decl.dtype,
    )

    self.visit(node.true_branch.body, mask=_and_with_outer(current_mask), ctx=ctx)

    if node.false_branch:
        # The false branch runs where this statement's own mask is NOT set.
        negated_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=current_mask)
        self.visit(
            node.false_branch.body,
            mask=_and_with_outer(negated_mask),
            ctx=ctx,
        )
def visit_ScalarAccess(
    self, node: oir.ScalarAccess, **kwargs: Any
) -> Union[gtcpp.AccessorRef, gtcpp.ScalarAccess]:
    """Translate a scalar access, choosing the target node from the symbol it names.

    ``kwargs`` must contain ``"stencil_symtable"``. If ``node.name`` resolves
    there to an ``oir.ScalarDecl``, an ``AccessorRef`` at zero offset is
    returned, built from the declaration's name and dtype. Otherwise a
    ``ScalarAccess`` is returned, built from the node's own name and dtype;
    a resolved symbol that is not a ``ScalarDecl`` is asserted to be an
    ``oir.LocalScalar``.
    """
    assert "stencil_symtable" in kwargs
    symtable = kwargs["stencil_symtable"]

    # Names missing from the symbol table become plain scalar accesses.
    if node.name not in symtable:
        return gtcpp.ScalarAccess(name=node.name, dtype=node.dtype)

    decl = symtable[node.name]
    if isinstance(decl, oir.ScalarDecl):
        return gtcpp.AccessorRef(
            name=decl.name,
            offset=CartesianOffset.zero(),
            dtype=decl.dtype,
        )

    assert isinstance(decl, oir.LocalScalar)
    return gtcpp.ScalarAccess(name=node.name, dtype=node.dtype)
def _create_mask(ctx: "GTIRToOIR.Context", name: str, cond: oir.Expr) -> oir.Temporary:
    """Declare a boolean mask temporary and register the execution that fills it.

    A ``BOOL`` temporary called ``name`` is added to ``ctx`` via ``add_decl``.
    A horizontal execution whose single statement assigns ``cond`` to that
    temporary at zero offset is then added via ``add_horizontal_execution``.

    Returns:
        The declaration of the newly created mask temporary.
    """
    mask_decl = oir.Temporary(name=name, dtype=DataType.BOOL, dimensions=(True, True, True))
    ctx.add_decl(mask_decl)

    mask_target = oir.FieldAccess(
        name=mask_decl.name,
        offset=CartesianOffset.zero(),
        dtype=mask_decl.dtype,
    )
    write_mask = oir.AssignStmt(left=mask_target, right=cond)
    ctx.add_horizontal_execution(
        oir.HorizontalExecution(body=[write_mask], declarations=[])
    )
    return mask_decl
def __init__(self, name) -> None:
    """Store ``name`` and set the remaining attributes to fixed defaults.

    The offset is ``CartesianOffset.zero()``, the kind is ``ExprKind.FIELD``
    and the dtype is ``DataType.FLOAT32``.
    """
    self._name = name
    default_offset = CartesianOffset.zero()
    self._offset = default_offset
    self._kind = ExprKind.FIELD
    self._dtype = DataType.FLOAT32