Beispiel #1
0
    def _make_apply_block(self, interval_block):
        """Build a ``gt_ir.ApplyBlock`` from a single interval block.

        Every statement that is not a ``gt_ir.Decl`` goes into the block body,
        in its original order. ``gt_ir.VarDecl`` statements whose name is not
        in ``self.data.symbols`` are recorded as the apply block's local
        symbols, so local temporaries stay declared. All other declarations
        are not attached to the apply block.

        Args:
            interval_block: object exposing ``stmts`` (items with a ``.stmt``
                attribute) and ``interval``.

        Returns:
            The new ``gt_ir.ApplyBlock``.
        """
        body_stmts = []
        local_symbols = {}
        for stmt_info in interval_block.stmts:
            stmt = stmt_info.stmt
            if not isinstance(stmt, gt_ir.Decl):
                body_stmts.append(stmt)
            elif isinstance(stmt, gt_ir.VarDecl) and stmt.name not in self.data.symbols:
                # Previously every declaration was dropped here, which lost the
                # declarations of block-local temporaries.
                local_symbols[stmt.name] = stmt

        return gt_ir.ApplyBlock(
            interval=self._make_axis_interval(interval_block.interval),
            local_symbols=local_symbols,
            body=gt_ir.BlockStmt(stmts=body_stmts),
        )
Beispiel #2
0
    def _make_stage(self, ij_block):
        """Build a ``gt_ir.Stage`` from an IJ block.

        Each interval block of ``ij_block`` becomes one ``gt_ir.ApplyBlock``.
        Each field read or written by ``ij_block`` becomes one accessor.

        Args:
            ij_block: block exposing ``interval_blocks``, ``inputs`` (a mapping
                of field name to extent), ``outputs``, ``id`` and
                ``compute_extent``.

        Returns:
            A ``gt_ir.Stage`` named ``"stage__<id>"``.
        """
        # Apply blocks: one per interval block.
        apply_blocks = []
        for int_block in ij_block.interval_blocks:
            stmts = []
            local_symbols = {}
            for stmt_info in int_block.stmts:
                stmt = stmt_info.stmt
                if not isinstance(stmt, gt_ir.Decl):
                    stmts.append(stmt)
                elif stmt.name not in self.data.symbols:
                    # Declarations of names unknown to ``self.data.symbols``
                    # are local temporaries of this apply block.
                    assert isinstance(stmt, gt_ir.VarDecl)
                    local_symbols[stmt.name] = stmt
                # Declarations of names already in ``self.data.symbols`` are
                # not attached to the apply block.

            apply_blocks.append(
                gt_ir.ApplyBlock(
                    interval=self._make_axis_interval(int_block.interval),
                    local_symbols=local_symbols,
                    body=gt_ir.BlockStmt(stmts=stmts),
                )
            )

        # Accessors: inputs first (read-write when also an output), then the
        # outputs that are never read. One zero extent with the stencil's
        # dimensionality is used for both cases.
        zero_extent = Extent.zeros(self.data.ndims)
        accessors = []
        remaining_outputs = set(ij_block.outputs)
        for name, extent in ij_block.inputs.items():
            read_write = name in remaining_outputs
            if read_write:
                remaining_outputs.remove(name)
                # Widen the read extent to include the zero extent; rebinding
                # avoids any in-place update of the value held in ij_block.inputs.
                extent = extent | zero_extent
            accessors.append(self._make_accessor(name, extent, read_write))

        # Iterate in sorted order so the accessor list does not depend on set
        # iteration order (which varies across runs for str names under hash
        # randomization). Assumes field names are mutually comparable (str).
        for name in sorted(remaining_outputs):
            accessors.append(self._make_accessor(name, zero_extent, True))

        return gt_ir.Stage(
            name="stage__{}".format(ij_block.id),
            accessors=accessors,
            apply_blocks=apply_blocks,
            compute_extent=ij_block.compute_extent,
        )