Example #1
0
    def visit_Stencil(self, node: oir.Stencil,
                      **kwargs: Any) -> npir.Computation:
        """Lower an ``oir.Stencil`` into an ``npir.Computation``.

        Extents are computed once up front. The per-field extents go to the
        field and temporary declaration visitors. The per-block extents go to
        the vertical-loop visitor.
        """
        field_extents, block_extents = compute_extents(node)
        params = node.params

        # Every parameter name, in declaration order.
        arguments = [param.name for param in params]

        scalar_params = [p for p in params if isinstance(p, oir.ScalarDecl)]
        field_params = [p for p in params if isinstance(p, oir.FieldDecl)]

        param_decls = [self.visit(p, **kwargs) for p in scalar_params]
        # NOTE(review): unlike the scalar and temporary visits, API field
        # declarations are visited without **kwargs -- confirm intentional.
        api_field_decls = [
            self.visit(p, field_extents=field_extents) for p in field_params
        ]
        temp_decls = [
            self.visit(d, field_extents=field_extents, **kwargs)
            for d in node.declarations
        ]

        lowered_loops = self.visit(node.vertical_loops,
                                   block_extents=block_extents,
                                   **kwargs)
        vertical_passes = utils.flatten_list(lowered_loops)

        return npir.Computation(
            arguments=arguments,
            api_field_decls=api_field_decls,
            param_decls=param_decls,
            temp_decls=temp_decls,
            vertical_passes=vertical_passes,
        )
Example #2
0
    def visit_VerticalLoop(self, node: gtir.VerticalLoop, *,
                           ctx: Context) -> oir.VerticalLoop:
        """Translate a ``gtir.VerticalLoop`` into an ``oir.VerticalLoop``.

        Each top-level body statement becomes its own
        ``oir.HorizontalExecution``. The local scalars gathered in ``ctx``
        while visiting that statement become the execution's declarations.
        The loop's temporaries are appended to ``ctx.temp_fields`` as
        ``oir.Temporary`` declarations. The result has a single section over
        the translated interval.
        """
        executions: List[oir.HorizontalExecution] = []
        for gtir_stmt in node.body:
            # Give each horizontal execution a fresh local-scalar scope.
            ctx.reset_local_scalars()
            translated = self.visit(gtir_stmt, ctx=ctx)
            if isinstance(translated, oir.Stmt):
                translated = [translated]
            executions.append(
                oir.HorizontalExecution(body=utils.flatten_list(translated),
                                        declarations=ctx.local_scalars))

        new_temporaries = [
            oir.Temporary(name=tmp.name,
                          dtype=tmp.dtype,
                          dimensions=tmp.dimensions)
            for tmp in node.temporaries
        ]
        ctx.temp_fields += new_temporaries

        section = oir.VerticalLoopSection(
            interval=self.visit(node.interval),
            horizontal_executions=executions,
            loc=node.loc,
        )
        return oir.VerticalLoop(loop_order=node.loop_order, sections=[section])
Example #3
0
    def visit_HorizontalRestriction(self, node: oir.HorizontalRestriction, *,
                                    extent: Extent, **kwargs: Any) -> Any:
        """Lower a horizontal restriction into a flat list of npir statements.

        The restriction's mask is made relative to ``extent``. If
        ``compute_relative_mask`` returns ``None``, the restriction is dropped
        (``NOTHING``). Otherwise the body is visited with the relative mask
        as ``horizontal_mask`` and the result is flattened.
        """
        relative_mask = compute_relative_mask(extent, node.mask)
        if relative_mask is not None:
            lowered = self.visit(node.body,
                                 horizontal_mask=relative_mask,
                                 **kwargs)
            return utils.flatten_list(lowered)
        # NOTE(review): presumably ``None`` means the restriction does not
        # apply within ``extent`` -- confirm against compute_relative_mask.
        return NOTHING
Example #4
0
    def visit_HorizontalRestriction(
            self, node: gtir.HorizontalRestriction,
            **kwargs: Any) -> oir.HorizontalRestriction:
        """Translate a ``gtir.HorizontalRestriction`` into its oir form.

        A body statement may translate to a single ``oir.Stmt`` or to a
        (possibly nested) collection of statements. Either way, the results
        are flattened into one body list. The mask is copied through
        unchanged.
        """
        flat_body = []
        for gtir_stmt in node.body:
            translated = self.visit(gtir_stmt, **kwargs)
            if isinstance(translated, oir.Stmt):
                translated = [translated]
            flat_body += utils.flatten_list(translated)

        return oir.HorizontalRestriction(mask=node.mask, body=flat_body)
Example #5
0
def _all_local_scalars_are_unique_type(stencil: npir.Computation) -> bool:
    """Check that local scalars sharing a name also share a dtype.

    Gathers the ``declarations`` of every ``npir.HorizontalBlock`` in
    ``stencil``. Returns ``False`` as soon as a name appears with a dtype
    different from the one it was first declared with, and ``True``
    otherwise.
    """
    horizontal_blocks = stencil.iter_tree().if_isinstance(npir.HorizontalBlock)
    declarations = utils.flatten_list(
        horizontal_blocks.getattr("declarations").to_list())

    first_seen_dtype: Dict[str, common.DataType] = {}
    for decl in declarations:
        # setdefault stores the dtype on first sight and returns the stored one.
        if decl.dtype != first_seen_dtype.setdefault(decl.name, decl.dtype):
            return False
    return True
Example #6
0
 def visit_While(self,
                 node: oir.While,
                 *,
                 mask: Optional[npir.Expr] = None,
                 **kwargs: Any) -> npir.While:
     """Lower an ``oir.While`` into an ``npir.While``.

     The condition is visited under the incoming ``mask``. The body is then
     visited under ``mask AND cond``, or just ``cond`` when no mask is
     active. The returned ``npir.While`` carries ``cond`` itself, not
     combined with ``mask``.
     """
     cond = self.visit(node.cond, mask=mask, **kwargs)
     body_mask = (npir.VectorLogic(op=common.LogicalOperator.AND,
                                   left=mask,
                                   right=cond) if mask else cond)
     lowered_body = self.visit(node.body, mask=body_mask, **kwargs)
     return npir.While(cond=cond, body=utils.flatten_list(lowered_body))
Example #7
0
    def visit_MaskStmt(
        self,
        node: oir.MaskStmt,
        *,
        mask: Optional[npir.Expr] = None,
        **kwargs: Any,
    ) -> List[npir.Stmt]:
        """Lower an ``oir.MaskStmt`` by pushing its mask down onto the body.

        The statement's own mask expression is combined with any incoming
        ``mask`` via a logical AND. The body is then visited under that
        combined mask and returned as a flat statement list.
        """
        # NOTE(review): the inner mask expression is visited without the
        # outer ``mask`` keyword -- confirm this is intended.
        own_mask = self.visit(node.mask, **kwargs)
        if mask:
            effective_mask = npir.VectorLogic(op=common.LogicalOperator.AND,
                                              left=mask,
                                              right=own_mask)
        else:
            effective_mask = own_mask

        lowered_body = self.visit(node.body, mask=effective_mask, **kwargs)
        return utils.flatten_list(lowered_body)
Example #8
0
    def _impl(cls: Type[pydantic.BaseModel],
              values: RootValidatorValuesType) -> RootValidatorValuesType:
        """Root-validator body: reject trees containing nodes with a falsy dtype.

        Flattens the validated field values and walks the tree of every
        ``Node`` found. Any tree element exposing a ``dtype`` attribute whose
        value is falsy is collected. If any are found, raises ``ValueError``
        listing them; otherwise returns ``values`` unchanged.
        """
        # NOTE(review): "falsy dtype" presumably means "dtype not set"; confirm
        # which DataType members evaluate as falsy.
        missing_dtype = []
        for value in flatten_list(values.values()):
            if not isinstance(value, Node):
                continue
            missing_dtype.extend(
                dtype_node
                for dtype_node in value.iter_tree().if_hasattr("dtype")
                if not dtype_node.dtype)

        if missing_dtype:
            raise ValueError(
                "Nodes without dtype detected {}".format(missing_dtype))
        return values
Example #9
0
 def visit_HorizontalExecution(
     self,
     node: oir.HorizontalExecution,
     *,
     block_extents: Optional[Dict[int, Extent]] = None,
     **kwargs: Any,
 ) -> npir.HorizontalBlock:
     """Lower an ``oir.HorizontalExecution`` into an ``npir.HorizontalBlock``.

     The block's extent is looked up in ``block_extents`` by ``id(node)``.
     When the mapping is missing or empty, the zero extent
     ``((0, 0), (0, 0))`` is used instead.
     """
     lowered_body = utils.flatten_list(self.visit(node.body, **kwargs))
     extent = block_extents[id(node)] if block_extents else ((0, 0), (0, 0))
     lowered_decls = self.visit(node.declarations, **kwargs)
     return npir.HorizontalBlock(declarations=lowered_decls,
                                 body=lowered_body,
                                 extent=extent)
Example #10
0
    def visit_While(self,
                    node: gtir.While,
                    *,
                    mask: oir.Expr = None,
                    **kwargs: Any):
        """Translate a ``gtir.While`` into an ``oir.While``.

        Body statements are translated and flattened without the mask. When
        a ``mask`` is given, two things happen:

        * the mask is ANDed into the loop condition;
        * the resulting ``oir.While`` is wrapped in an ``oir.MaskStmt``
          carrying the same mask.
        """
        # NOTE(review): ``mask`` defaults to None, so ``Optional[oir.Expr]``
        # would be the accurate annotation.
        translated_body = []
        for gtir_stmt in node.body:
            translated = self.visit(gtir_stmt, **kwargs)
            if isinstance(translated, oir.Stmt):
                translated = [translated]
            translated_body += utils.flatten_list(translated)

        # NOTE(review): the condition is visited without **kwargs, unlike the
        # body statements -- confirm intentional.
        loop_cond = self.visit(node.cond)
        if mask:
            # Presumably keeps the loop from continuing at masked-out points;
            # confirm against the backend lowering.
            loop_cond = oir.BinaryOp(op=common.LogicalOperator.AND,
                                     left=mask,
                                     right=loop_cond)
        loop = oir.While(cond=loop_cond, body=translated_body, loc=node.loc)
        if mask is None:
            return loop
        return oir.MaskStmt(body=[loop], mask=mask, loc=node.loc)