コード例 #1
0
ファイル: gtscript_frontend.py プロジェクト: muellch/gt4py
    def visit_FunctionDef(self, node: ast.FunctionDef) -> list:
        """Visit the stencil definition function and collect its computation blocks.

        Every statement in the function body is visited and the (possibly
        multiple) results are flattened into a single list.

        Raises:
            GTScriptSyntaxError: if any collected item is not a
                ``gt_ir.ComputationBlock``.
        """
        collected = [
            item
            for stmt in node.body
            for item in gt_utils.listify(self.visit(stmt))
        ]

        has_foreign_item = any(
            not isinstance(item, gt_ir.ComputationBlock) for item in collected)
        if has_foreign_item:
            raise GTScriptSyntaxError("Invalid stencil definition",
                                      loc=gt_ir.Location.from_ast_node(node))

        return collected
コード例 #2
0
ファイル: gtscript_frontend.py プロジェクト: muellch/gt4py
    def visit_If(self, node: ast.If) -> gt_ir.If:
        """Translate a Python ``if`` statement into a ``gt_ir.If`` node.

        Both branches are visited and flattened into statement lists, then the
        condition is visited. A missing or empty ``else`` branch is
        represented as ``else_body=None``.
        """
        def visit_branch(branch_stmts):
            # Flatten the (possibly multi-statement) results of each visit.
            visited = []
            for stmt in branch_stmts:
                visited.extend(gt_utils.listify(self.visit(stmt)))
            return visited

        main_stmts = visit_branch(node.body)
        assert all(isinstance(item, gt_ir.Statement) for item in main_stmts)

        else_stmts = visit_branch(node.orelse)
        assert all(isinstance(item, gt_ir.Statement) for item in else_stmts)

        condition = gt_ir.utils.make_expr(self.visit(node.test))
        main_body = gt_ir.BlockStmt(stmts=main_stmts)
        else_body = gt_ir.BlockStmt(stmts=else_stmts) if else_stmts else None

        return gt_ir.If(condition=condition,
                        main_body=main_body,
                        else_body=else_body)
コード例 #3
0
ファイル: gtscript_frontend.py プロジェクト: muellch/gt4py
    def _visit_interval_node(self, node: ast.With) -> gt_ir.ComputationBlock:
        """Translate a ``with interval(...):`` statement into a ComputationBlock.

        Accepted call forms:
          * ``interval(...)`` for the full axis interval;
          * two bounds, given positionally (``interval(0, -1)``), by the
            ``start``/``end`` keywords in any order, or mixed
            (``interval(0, end=-1)``).

        The body is visited with ``self.parsing_context`` set to
        ``ParsingContext.INTERVAL`` and the context is set back to
        ``ParsingContext.COMPUTATION`` afterwards, even if visiting raises.
        The resulting block uses ``gt_ir.IterationOrder.PARALLEL``.

        Raises:
            GTScriptSyntaxError: if the ``interval`` call signature is invalid
                (wrong number of parameters, unknown keywords, or a bound given
                both positionally and by keyword), or if the range itself is
                invalid.
        """
        loc = gt_ir.Location.from_ast_node(node)
        interval_error = GTScriptSyntaxError(
            f"Invalid 'interval' specification at line {loc.line} (column {loc.column})",
            loc=loc)
        range_error = GTScriptSyntaxError(
            f"Invalid interval range specification at line {loc.line} (column {loc.column})",
            loc=loc,
        )

        interval_node = node.items[0].context_expr
        n_params = len(interval_node.args) + len(interval_node.keywords)
        if (not 1 <= n_params <= 2
                or any(keyword.arg not in ["start", "end"]
                       for keyword in interval_node.keywords)):
            raise interval_error

        if (len(interval_node.args) == 1 and not interval_node.keywords
                and isinstance(interval_node.args[0], ast.Ellipsis)):
            interval = gt_ir.AxisInterval.full_interval()
        else:
            # Resolve the bounds by name: positional args fill start, end in
            # order, and keywords are matched by their name. This way,
            # ``interval(end=-1, start=0)`` means start=0, end=-1 instead of
            # silently swapping the bounds.
            bounds = dict(zip(("start", "end"), interval_node.args))
            for keyword in interval_node.keywords:
                if keyword.arg in bounds:
                    # Same bound given both positionally and by keyword.
                    raise interval_error
                bounds[keyword.arg] = keyword.value
            # A missing bound becomes None and fails the type check below.
            range_node = [bounds.get("start"), bounds.get("end")]
            if not all(
                    isinstance(elem, (ast.Num, ast.UnaryOp, ast.NameConstant))
                    for elem in range_node):
                raise range_error
            range_value = tuple(self.visit(elem) for elem in range_node)
            try:
                interval = gt_ir.utils.make_axis_interval(range_value)
            except AssertionError as e:
                raise range_error from e

        self.parsing_context = ParsingContext.INTERVAL
        try:
            stmts = []
            for stmt in node.body:
                stmts.extend(gt_utils.listify(self.visit(stmt)))
        finally:
            # Do not leave the parser in INTERVAL mode if the body fails.
            self.parsing_context = ParsingContext.COMPUTATION

        result = gt_ir.ComputationBlock(
            interval=interval,
            iteration_order=gt_ir.IterationOrder.PARALLEL,
            body=gt_ir.BlockStmt(stmts=stmts),
        )

        return result
コード例 #4
0
ファイル: gtscript_frontend.py プロジェクト: gronerl/gt4py
    def visit_If(self, node: ast.If) -> gt_ir.If:
        """Translate an ``if`` statement, folding it at compile time when possible.

        ``node.test`` is first evaluated with ``gt_utils.meta.ast_eval``
        against ``self.externals``. If that yields a value (anything other
        than ``NOTHING``), only the taken branch is visited and its flattened
        statements are returned as a plain list. Note that this path returns
        a list despite the annotation. Otherwise both branches are visited
        and a run-time ``gt_ir.If`` node is returned; an empty ``else``
        becomes ``else_body=None``.
        """
        def visit_branch(branch_stmts):
            # Flatten the (possibly multi-statement) results of each visit.
            visited = []
            for stmt in branch_stmts:
                visited.extend(gt_utils.listify(self.visit(stmt)))
            return visited

        condition_value = gt_utils.meta.ast_eval(
            node.test, self.externals, default=NOTHING)

        if condition_value is not NOTHING:
            # Compile-time evaluation: keep only the branch that is taken.
            taken_branch = node.body if condition_value else node.orelse
            return visit_branch(taken_branch)

        # Run-time evaluation
        main_stmts = visit_branch(node.body)
        assert all(isinstance(item, gt_ir.Statement) for item in main_stmts)

        else_stmts = visit_branch(node.orelse)
        assert all(isinstance(item, gt_ir.Statement) for item in else_stmts)

        condition = gt_ir.utils.make_expr(self.visit(node.test))
        main_body = gt_ir.BlockStmt(stmts=main_stmts)
        else_body = gt_ir.BlockStmt(stmts=else_stmts) if else_stmts else None

        return gt_ir.If(condition=condition,
                        main_body=main_body,
                        else_body=else_body)
コード例 #5
0
ファイル: nodes.py プロジェクト: eddie-c-davis/gt4py
    def recurse(node: Node) -> Generator[Node, None, None]:
        """Yield every node of ``node_type`` in the subtree rooted at ``node``.

        The traversal is recursive and post-order: the children found in all
        attributes of ``node`` are explored first. After that, ``node`` itself
        is yielded exactly once if it is an instance of ``node_type``.
        ``node_type`` is a free variable, presumably bound by the enclosing
        function (not visible here).
        """
        for _, value in iter_attributes(node):
            # Inspect the attribute *value*, not ``node``, to find children.
            # A child can be a single Node, the values of a mapping, or the
            # items of a non-string iterable.
            if isinstance(value, Node):
                children = (value,)
            elif isinstance(value, collections.abc.Mapping):
                children = value.values()
            elif (isinstance(value, collections.abc.Iterable)
                  and not isinstance(value, (str, bytes))):
                children = value
            else:
                children = ()

            for child in children:
                if isinstance(child, Node):
                    yield from recurse(child)

        if isinstance(node, node_type):
            yield node
コード例 #6
0
def filter_nodes_dfs(root_node, node_type):
    """Yield an iterator over the nodes of node_type inside root_node in DFS order.

    The traversal uses an explicit stack and is pre-order: a node is yielded
    before its children are explored. Children are pushed in attribute order
    and popped LIFO, so siblings are visited from last to first.

    Args:
        root_node: ``Node`` where the traversal starts; it is included in
            the search.
        node_type: class (or tuple of classes) matched with ``isinstance``.

    Yields:
        Every visited node that is an instance of ``node_type``.
    """
    stack = [root_node]
    while stack:
        curr = stack.pop()
        assert isinstance(curr, Node)

        if isinstance(curr, node_type):
            yield curr

        for _, value in iter_attributes(curr):
            # Inspect the attribute *value*, not ``curr``, to find children.
            # A child can be a single Node, the values of a mapping, or the
            # items of a non-string iterable.
            if isinstance(value, Node):
                children = (value,)
            elif isinstance(value, collections.abc.Mapping):
                children = value.values()
            elif (isinstance(value, collections.abc.Iterable)
                  and not isinstance(value, (str, bytes))):
                children = value
            else:
                children = ()

            stack.extend(child for child in children if isinstance(child, Node))
コード例 #7
0
 def __getitem__(self, item):
     """Return ``super().__getitem__(item)``, syncing device data first if needed.

     When ``self._is_device_modified`` is truthy, ``self.device_to_host()`` is
     called before indexing. The index, normalized through
     ``gt_utils.listify``, is stored in ``self._new_index``.
     """
     device_is_ahead = self._is_device_modified
     if device_is_ahead:
         self.device_to_host()
     normalized_index = gt_utils.listify(item)
     self._new_index = normalized_index
     parent_getitem = super().__getitem__
     return parent_getitem(item)
コード例 #8
0
 def __getitem__(self, item):
     """Record the listified index in ``self._new_index``, then defer to the parent.

     Returns whatever ``super().__getitem__(item)`` returns.
     """
     normalized_index = gt_utils.listify(item)
     self._new_index = normalized_index
     parent_getitem = super().__getitem__
     return parent_getitem(item)