Example #1
0
    def visit_FieldIfStmt(self,
                          node: gtir.FieldIfStmt,
                          *,
                          mask: oir.Expr = None,
                          ctx: Context,
                          **kwargs: Any) -> None:
        """Lower a field ``if`` statement by visiting both branches under masks.

        The visited condition is handed to ``_create_mask`` together with the
        name ``mask_<node.id_>``; the returned declaration is read back through
        a zero-offset ``FieldAccess``. The true branch is visited with that
        access as its mask, and the false branch (if present) with its logical
        negation. When an enclosing ``mask`` is given, it is AND-ed in front of
        each branch mask so that nested conditionals compose.

        NOTE(review): ``_create_mask`` presumably registers a temporary holding
        the condition in ``ctx`` -- verify against its definition.

        Args:
            node: The gtir field if-statement to lower.
            mask: Mask of an enclosing conditional, or ``None`` at top level.
            ctx: Lowering context, passed to ``_create_mask`` and to both
                branch visits.
            **kwargs: Accepted for visitor compatibility; not forwarded.

        Returns:
            None. The results of the branch visits are discarded, so lowered
            statements are presumably recorded through ``ctx`` -- TODO confirm.
        """

        def restrict(branch_mask: oir.Expr) -> oir.Expr:
            # Honour the enclosing mask, if any; it stays on the left of the AND.
            if mask is None:
                return branch_mask
            return oir.BinaryOp(op=LogicalOperator.AND,
                                left=mask,
                                right=branch_mask)

        mask_field_decl = _create_mask(ctx, f"mask_{node.id_}",
                                       self.visit(node.cond))
        current_mask = oir.FieldAccess(name=mask_field_decl.name,
                                       offset=CartesianOffset.zero(),
                                       dtype=mask_field_decl.dtype)
        self.visit(node.true_branch.body, mask=restrict(current_mask), ctx=ctx)

        if node.false_branch:
            negated_mask = oir.UnaryOp(op=UnaryOperator.NOT,
                                       expr=current_mask)
            self.visit(
                node.false_branch.body,
                mask=restrict(negated_mask),
                ctx=ctx,
            )
Example #2
0
    def visit_ScalarIfStmt(self,
                           node: gtir.ScalarIfStmt,
                           *,
                           mask: oir.Expr = None,
                           ctx: Context,
                           **kwargs: Any) -> None:
        """Lower a scalar ``if`` statement by visiting both branches under masks.

        The visited condition expression is used directly as the true-branch
        mask, and its logical negation as the false-branch mask (when a false
        branch exists); no temporary is created. When an enclosing ``mask`` is
        given, it is AND-ed in front of each branch mask so that nested
        conditionals compose.

        NOTE(review): the same condition expression appears in both masks
        instead of being stored once, which assumes the true branch cannot
        change any value the condition reads -- confirm for scalar conditions.

        Args:
            node: The gtir scalar if-statement to lower.
            mask: Mask of an enclosing conditional, or ``None`` at top level.
            ctx: Lowering context, passed to both branch visits.
            **kwargs: Accepted for visitor compatibility; not forwarded.

        Returns:
            None. The results of the branch visits are discarded, so lowered
            statements are presumably recorded through ``ctx`` -- TODO confirm.
        """

        def restrict(branch_mask: oir.Expr) -> oir.Expr:
            # Honour the enclosing mask, if any; it stays on the left of the AND.
            if mask is None:
                return branch_mask
            return oir.BinaryOp(op=LogicalOperator.AND,
                                left=mask,
                                right=branch_mask)

        current_mask = self.visit(node.cond)
        self.visit(node.true_branch.body, mask=restrict(current_mask), ctx=ctx)

        if node.false_branch:
            negated_mask = oir.UnaryOp(op=UnaryOperator.NOT,
                                       expr=current_mask)
            self.visit(
                node.false_branch.body,
                mask=restrict(negated_mask),
                ctx=ctx,
            )
Example #3
0
    def visit_FieldIfStmt(self,
                          node: gtir.FieldIfStmt,
                          *,
                          mask: oir.Expr = None,
                          ctx: Context,
                          **kwargs: Any) -> List[oir.Stmt]:
        """Lower a field ``if`` statement into a flat list of masked OIR statements.

        A boolean 3D temporary ``mask_<id>`` is appended to ``ctx.temp_fields``.
        The first returned statement assigns the visited condition to it, so
        both branches read the mask as it was before either branch ran. The
        true branch is then visited under that mask, and the false branch (if
        present) under its negation. When an enclosing ``mask`` is given, it is
        AND-ed in front of each branch mask so that nested conditionals compose.
        Every generated node carries ``node.loc``.

        Args:
            node: The gtir field if-statement to lower.
            mask: Mask of an enclosing conditional, or ``None`` at top level.
            ctx: Lowering context; receives the mask temporary.
            **kwargs: Forwarded to the branch visits.

        Returns:
            The mask assignment followed by the lowered true-branch statements
            and then the lowered false-branch statements.
        """

        def restrict(branch_mask: oir.Expr) -> oir.Expr:
            # Honour the enclosing mask, if any; it stays on the left of the AND.
            if mask is None:
                return branch_mask
            return oir.BinaryOp(op=LogicalOperator.AND,
                                left=mask,
                                right=branch_mask,
                                loc=node.loc)

        # NOTE(review): id(node) is only unique among live objects and differs
        # between runs, so generated temporary names are not reproducible --
        # consider a stable node identifier.
        mask_field_decl = oir.Temporary(name=f"mask_{id(node)}",
                                        dtype=DataType.BOOL,
                                        dimensions=(True, True, True))
        ctx.temp_fields.append(mask_field_decl)
        stmts = [
            oir.AssignStmt(
                left=oir.FieldAccess(
                    name=mask_field_decl.name,
                    offset=CartesianOffset.zero(),
                    dtype=DataType.BOOL,
                    loc=node.loc,
                ),
                right=self.visit(node.cond),
            )
        ]

        current_mask = oir.FieldAccess(
            name=mask_field_decl.name,
            offset=CartesianOffset.zero(),
            dtype=mask_field_decl.dtype,
            loc=node.loc,
        )
        stmts.extend(
            self.visit(node.true_branch.body,
                       mask=restrict(current_mask),
                       ctx=ctx,
                       **kwargs))

        if node.false_branch:
            # Carry the source location like every other generated node.
            negated_mask = oir.UnaryOp(op=UnaryOperator.NOT,
                                       expr=current_mask,
                                       loc=node.loc)
            stmts.extend(
                self.visit(node.false_branch.body,
                           mask=restrict(negated_mask),
                           ctx=ctx,
                           **kwargs))

        return stmts
Example #4
0
 def visit_UnaryOp(self, node: gtir.UnaryOp, **kwargs: Any) -> oir.UnaryOp:
     """Translate a gtir unary operation into its oir counterpart.

     The operator is copied unchanged, the operand is translated recursively,
     and the source location is preserved.

     ``**kwargs`` is accepted (and ignored) so that visitor calls passing extra
     keyword arguments do not raise ``TypeError``; it is not forwarded to the
     operand visit.
     """
     return oir.UnaryOp(op=node.op, expr=self.visit(node.expr), loc=node.loc)
Example #5
0
 def visit_UnaryOp(self, node: gtir.UnaryOp, **kwargs: Any) -> oir.UnaryOp:
     """Translate a gtir unary operation into its oir counterpart.

     The operator is copied unchanged and the operand is translated
     recursively. The source location (``node.loc``) is carried over so that
     diagnostics on the generated node still point at the original code.

     ``**kwargs`` is accepted for visitor compatibility; it is not forwarded
     to the operand visit.
     """
     return oir.UnaryOp(op=node.op, expr=self.visit(node.expr), loc=node.loc)