def visit_FieldIfStmt(
    self,
    node: gtir.FieldIfStmt,
    *,
    mask: oir.Expr = None,
    ctx: Context,
    **kwargs: Any,
) -> None:
    """Lower a field-valued ``if`` by materializing its condition as a mask.

    The lowered condition is handed to ``_create_mask`` together with the
    name ``mask_<node.id_>``. The true branch is then visited under a read of
    that mask, and the optional false branch under its logical negation. When
    an enclosing ``mask`` is supplied (truthy), each branch mask is AND-ed with
    it so nested conditionals compose.

    Nothing is returned and the results of visiting the branch bodies are
    discarded; this presumably relies on the visitor recording statements
    through ``ctx`` -- NOTE(review): confirm.
    """

    def restrict(branch_mask: oir.Expr) -> oir.Expr:
        # Nest the branch condition inside any enclosing condition.
        if not mask:
            return branch_mask
        return oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=branch_mask)

    decl = _create_mask(ctx, f"mask_{node.id_}", self.visit(node.cond))
    cond_access = oir.FieldAccess(
        name=decl.name,
        offset=CartesianOffset.zero(),
        dtype=decl.dtype,
    )

    self.visit(node.true_branch.body, mask=restrict(cond_access), ctx=ctx)

    if node.false_branch:
        negated = oir.UnaryOp(op=UnaryOperator.NOT, expr=cond_access)
        self.visit(node.false_branch.body, mask=restrict(negated), ctx=ctx)
def visit_ScalarIfStmt(
    self,
    node: gtir.ScalarIfStmt,
    *,
    mask: oir.Expr = None,
    ctx: Context,
    **kwargs: Any,
) -> None:
    """Lower a scalar ``if`` by using its lowered condition directly as a mask.

    No temporary is created: the visited condition expression itself guards
    the true branch, and its logical negation guards the optional false
    branch. When an enclosing ``mask`` is supplied (truthy), each branch mask
    is AND-ed with it.

    Nothing is returned; the results of visiting the branch bodies are
    discarded (presumably statements are recorded via ``ctx`` --
    NOTE(review): confirm).
    """
    cond = self.visit(node.cond)

    true_mask = (
        oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=cond) if mask else cond
    )
    self.visit(node.true_branch.body, mask=true_mask, ctx=ctx)

    if not node.false_branch:
        return

    not_cond = oir.UnaryOp(op=UnaryOperator.NOT, expr=cond)
    false_mask = (
        oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=not_cond)
        if mask
        else not_cond
    )
    self.visit(node.false_branch.body, mask=false_mask, ctx=ctx)
def visit_FieldIfStmt(
    self,
    node: gtir.FieldIfStmt,
    *,
    mask: oir.Expr = None,
    ctx: Context,
    **kwargs: Any,
) -> List[oir.Stmt]:
    """Lower a field-valued ``if`` into masked OIR statements.

    A boolean temporary is declared (``dimensions=(True, True, True)``,
    presumably I/J/K -- confirm) and registered in ``ctx.temp_fields``. The
    first returned statement assigns the lowered condition into it. The true
    branch is then lowered under a read of that temporary, and the optional
    false branch under its logical negation. When an enclosing ``mask`` is
    given, each branch mask is AND-ed with it so nested conditionals compose.

    Args:
        node: The GTIR field-if statement to lower.
        mask: Mask of an enclosing conditional, or ``None`` at top level.
        ctx: Lowering context; receives the new mask temporary.
        **kwargs: Forwarded unchanged when visiting the branch bodies.

    Returns:
        The condition assignment followed by the lowered statements of the
        true branch and, if present, the false branch.
    """
    # NOTE(review): id(node) is unique only while `node` is alive and differs
    # between runs, so generated temporary names are not reproducible.
    mask_field_decl = oir.Temporary(
        name=f"mask_{id(node)}", dtype=DataType.BOOL, dimensions=(True, True, True)
    )
    ctx.temp_fields.append(mask_field_decl)
    stmts = [
        oir.AssignStmt(
            left=oir.FieldAccess(
                name=mask_field_decl.name,
                offset=CartesianOffset.zero(),
                dtype=DataType.BOOL,
                loc=node.loc,
            ),
            right=self.visit(node.cond),
        )
    ]

    current_mask = oir.FieldAccess(
        name=mask_field_decl.name,
        offset=CartesianOffset.zero(),
        dtype=mask_field_decl.dtype,
        loc=node.loc,
    )

    combined_mask = current_mask
    if mask is not None:
        combined_mask = oir.BinaryOp(
            op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc
        )
    stmts.extend(self.visit(node.true_branch.body, mask=combined_mask, ctx=ctx, **kwargs))

    if node.false_branch:
        # Carry the source location on the negated mask too, like every other
        # node built here.
        combined_mask = oir.UnaryOp(
            op=UnaryOperator.NOT, expr=current_mask, loc=node.loc
        )
        if mask is not None:
            combined_mask = oir.BinaryOp(
                op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc
            )
        stmts.extend(
            self.visit(node.false_branch.body, mask=combined_mask, ctx=ctx, **kwargs)
        )

    return stmts
def visit_UnaryOp(self, node: gtir.UnaryOp, **kwargs: Any) -> oir.UnaryOp:
    """Lower a ``gtir.UnaryOp`` to an ``oir.UnaryOp`` with the same operator.

    The operand is lowered recursively and the node's source location is
    preserved. Extra keyword arguments (such as context a parent visitor
    forwards) are accepted and ignored, so a visitor call that passes keywords
    does not raise ``TypeError``.
    """
    return oir.UnaryOp(op=node.op, expr=self.visit(node.expr), loc=node.loc)
def visit_UnaryOp(self, node: gtir.UnaryOp, **kwargs: Any) -> oir.UnaryOp:
    """Lower a ``gtir.UnaryOp`` to an ``oir.UnaryOp`` with the same operator.

    The operand is lowered recursively. The source location of ``node`` is
    carried over so diagnostics on the lowered expression point at the
    original code. Extra keyword arguments are accepted and ignored.
    """
    return oir.UnaryOp(op=node.op, expr=self.visit(node.expr), loc=node.loc)