Exemple #1
0
    def _make_apply_block(self, interval_block):
        """Convert an interval block into a ``gt_ir.ApplyBlock``.

        Every statement of *interval_block* that is not a ``gt_ir.Decl`` is
        kept, in its original order, as the apply block's body; declarations
        are dropped. The block's interval is converted with
        ``self._make_axis_interval``.
        """
        body = gt_ir.BlockStmt(
            stmts=[
                info.stmt
                for info in interval_block.stmts
                if not isinstance(info.stmt, gt_ir.Decl)
            ]
        )
        return gt_ir.ApplyBlock(
            interval=self._make_axis_interval(interval_block.interval),
            body=body,
        )
Exemple #2
0
    def visit_If(self, node: ast.If) -> gt_ir.If:
        """Lower a Python ``if`` statement to a run-time ``gt_ir.If`` node.

        Both branches are visited (``if`` body first, then ``else`` body) and
        every visited result must be a ``gt_ir.Statement``. The test
        expression is visited last and wrapped with
        ``gt_ir.utils.make_expr``. ``else_body`` is ``None`` when the ``else``
        branch lowers to no statements.
        """
        def lower_all(ast_nodes):
            # Visit each node and flatten single results / lists into one list.
            lowered = []
            for ast_node in ast_nodes:
                lowered.extend(gt_utils.listify(self.visit(ast_node)))
            return lowered

        main_stmts = lower_all(node.body)
        assert all(isinstance(item, gt_ir.Statement) for item in main_stmts)

        else_stmts = lower_all(node.orelse)
        assert all(isinstance(item, gt_ir.Statement) for item in else_stmts)

        return gt_ir.If(
            condition=gt_ir.utils.make_expr(self.visit(node.test)),
            main_body=gt_ir.BlockStmt(stmts=main_stmts),
            else_body=gt_ir.BlockStmt(stmts=else_stmts) if else_stmts else None,
        )
Exemple #3
0
    def _visit_interval_node(self, node: ast.With) -> gt_ir.ComputationBlock:
        """Lower a ``with interval(...):`` statement to a ``ComputationBlock``.

        Accepted forms of the ``interval`` call:

        * ``interval(...)``: a single positional ellipsis selects the full
          interval;
        * two bounds (numbers, unary expressions or name constants) given
          positionally and/or as ``start=`` / ``end=`` keywords in any order,
          e.g. ``interval(0, -1)``, ``interval(end=-1, start=0)`` or
          ``interval(0, end=-1)``.

        The body is visited with ``self.parsing_context`` set to
        ``ParsingContext.INTERVAL``; the context is reset to
        ``ParsingContext.COMPUTATION`` afterwards, even if visiting fails.
        The resulting block always uses ``IterationOrder.PARALLEL``.

        Raises:
            GTScriptSyntaxError: if the ``interval`` call has a wrong number
                or name of arguments, gives a bound twice, or if its bounds
                cannot be converted into an axis interval.
        """
        loc = gt_ir.Location.from_ast_node(node)
        interval_error = GTScriptSyntaxError(
            f"Invalid 'interval' specification at line {loc.line} (column {loc.column})",
            loc=loc)
        range_error = GTScriptSyntaxError(
            f"Invalid interval range specification at line {loc.line} (column {loc.column})",
            loc=loc,
        )

        interval_node = node.items[0].context_expr
        if not isinstance(interval_node, ast.Call):
            raise interval_error
        args = interval_node.args
        keywords = interval_node.keywords
        if (not 1 <= len(args) + len(keywords) <= 2
                or any(keyword.arg not in ("start", "end") for keyword in keywords)):
            raise interval_error

        # Map positional arguments, then keywords, onto the named bounds so
        # that keyword order is irrelevant and a lone keyword is reported as
        # a syntax error instead of failing with an IndexError.
        bounds = dict(zip(("start", "end"), args))
        for keyword in keywords:
            if keyword.arg in bounds:
                # Same bound given both positionally and by keyword.
                raise interval_error
            bounds[keyword.arg] = keyword.value

        if not keywords and len(args) == 1 and isinstance(args[0], ast.Ellipsis):
            interval = gt_ir.AxisInterval.full_interval()
        elif set(bounds) == {"start", "end"} and all(
                isinstance(bound, (ast.Num, ast.UnaryOp, ast.NameConstant))
                for bound in bounds.values()):
            range_value = (self.visit(bounds["start"]), self.visit(bounds["end"]))
            try:
                interval = gt_ir.utils.make_axis_interval(range_value)
            except AssertionError as e:
                raise range_error from e
        else:
            raise range_error

        self.parsing_context = ParsingContext.INTERVAL
        stmts = []
        try:
            for stmt in node.body:
                stmts.extend(gt_utils.listify(self.visit(stmt)))
        finally:
            # Restore the enclosing context even when the body is invalid.
            self.parsing_context = ParsingContext.COMPUTATION

        result = gt_ir.ComputationBlock(
            interval=interval,
            iteration_order=gt_ir.IterationOrder.PARALLEL,
            body=gt_ir.BlockStmt(stmts=stmts),
        )

        return result
Exemple #4
0
    def visit_If(self, node: ast.If) -> gt_ir.If:
        """Lower a Python ``if`` statement, folding it at compile time if possible.

        The test is first evaluated with ``gt_utils.meta.ast_eval`` against
        ``self.externals``:

        * if that yields a value (anything other than ``NOTHING``), only the
          selected branch is visited and its statements are returned as a
          plain list (possibly empty), not as a ``gt_ir.If`` despite the
          annotation;
        * otherwise a run-time ``gt_ir.If`` is built from both visited
          branches (whose items must be ``gt_ir.Statement``). Its
          ``else_body`` is ``None`` when the ``else`` branch lowers to no
          statements.
        """
        def lower_all(ast_nodes):
            # Visit each node and flatten single results / lists into one list.
            lowered = []
            for ast_node in ast_nodes:
                lowered.extend(gt_utils.listify(self.visit(ast_node)))
            return lowered

        condition_value = gt_utils.meta.ast_eval(
            node.test, self.externals, default=NOTHING
        )
        if condition_value is not NOTHING:
            # Compile-time evaluation: the unselected branch is never visited.
            return lower_all(node.body if condition_value else node.orelse)

        # Run-time evaluation
        main_stmts = lower_all(node.body)
        assert all(isinstance(item, gt_ir.Statement) for item in main_stmts)

        else_stmts = lower_all(node.orelse)
        assert all(isinstance(item, gt_ir.Statement) for item in else_stmts)

        return gt_ir.If(
            condition=gt_ir.utils.make_expr(self.visit(node.test)),
            main_body=gt_ir.BlockStmt(stmts=main_stmts),
            else_body=gt_ir.BlockStmt(stmts=else_stmts) if else_stmts else None,
        )
Exemple #5
0
    def _make_stage(self, ij_block):
        """Build a ``gt_ir.Stage`` from an IJ block.

        Each interval block of *ij_block* becomes one ``gt_ir.ApplyBlock``:

        * non-declaration statements form its body, in order;
        * ``gt_ir.Decl`` statements whose name is in ``self.data.symbols``
          are not kept in the apply block;
        * remaining declarations (asserted to be ``gt_ir.VarDecl``) become
          its ``local_symbols``.

        Accessors are created for every input (read-write, with its extent
        widened to include the zero extent, if the name is also an output),
        then for every output that is not an input (read-write, zero extent),
        in sorted name order so the generated stage is deterministic.
        """
        zero_extent = Extent.zeros(self.data.ndims)

        # Apply blocks
        apply_blocks = []
        for int_block in ij_block.interval_blocks:
            stmts = []
            local_symbols = {}
            for stmt_info in int_block.stmts:
                stmt = stmt_info.stmt
                if not isinstance(stmt, gt_ir.Decl):
                    stmts.append(stmt)
                elif stmt.name not in self.data.symbols:
                    assert isinstance(stmt, gt_ir.VarDecl)
                    local_symbols[stmt.name] = stmt
                # Declarations of names already in self.data.symbols are
                # intentionally not carried into the apply block.

            apply_blocks.append(
                gt_ir.ApplyBlock(
                    interval=self._make_axis_interval(int_block.interval),
                    local_symbols=local_symbols,
                    body=gt_ir.BlockStmt(stmts=stmts),
                )
            )

        # Accessors
        accessors = []
        remaining_outputs = set(ij_block.outputs)
        for name, extent in ij_block.inputs.items():
            read_write = name in remaining_outputs
            if read_write:
                remaining_outputs.remove(name)
                # Use the same dimensionality as the output-only accessors.
                extent |= zero_extent
            accessors.append(self._make_accessor(name, extent, read_write))
        # Sort: set iteration order depends on string hash randomization.
        for name in sorted(remaining_outputs):
            accessors.append(self._make_accessor(name, zero_extent, True))

        stage = gt_ir.Stage(
            name="stage__{}".format(ij_block.id),
            accessors=accessors,
            apply_blocks=apply_blocks,
            compute_extent=ij_block.compute_extent,
        )

        return stage