Example #1
0
    def apply(self, transform_data: TransformData):
        """Greedily fuse neighbouring compatible blocks of ``transform_data``.

        Two sweeps run back to back: the first folds adjacent domain blocks
        (multi-stages) together, the second does the same for the IJ blocks
        (stages) inside every surviving domain block. Whenever a candidate is
        folded into its predecessor, the predecessor first receives a fresh id
        from ``transform_data.id_generator``.

        Element 0 of ``transform_data.blocks`` and of every ``ij_blocks`` list
        is indexed directly, so those sequences are expected to be non-empty.

        Returns the same ``transform_data`` object with ``blocks`` (and each
        block's ``ij_blocks``) replaced by the merged lists.
        """
        # Sweep 1: multi-stages.
        domain_blocks = [transform_data.blocks[0]]
        for next_block in transform_data.blocks[1:]:
            last_block = domain_blocks[-1]
            if not self._are_compatible_multi_stages(
                    last_block, next_block, transform_data.has_sequential_axis):
                # Incompatible: the candidate starts a new run.
                domain_blocks.append(next_block)
                continue
            last_block.id = transform_data.id_generator.new
            self._merge_domain_blocks(last_block, next_block)

        # Sweep 2: stages within each remaining multi-stage.
        for domain_block in domain_blocks:
            stages = [domain_block.ij_blocks[0]]
            for next_stage in domain_block.ij_blocks[1:]:
                last_stage = stages[-1]
                if not self._are_compatible_stages(
                        last_stage, next_stage,
                        transform_data.min_k_interval_sizes):
                    stages.append(next_stage)
                    continue
                last_stage.id = transform_data.id_generator.new
                self._merge_ij_blocks(last_stage, next_stage, transform_data)

            domain_block.ij_blocks = stages

        transform_data.blocks = domain_blocks

        return transform_data
Example #2
0
 def build_transform(self):
     """Build the definition and wrap it in a fresh ``TransformData``.

     The definition comes from ``self.build()``; the implementation IR is
     initialised from that definition, and the build options carry
     ``self.name`` together with this module's ``__name__``.
     """
     definition_ir = self.build()
     implementation_ir = init_implementation_from_definition(definition_ir)
     build_options = BuildOptions(name=self.name, module=__name__)
     return TransformData(
         definition_ir=definition_ir,
         implementation_ir=implementation_ir,
         options=build_options,
     )
Example #3
0
    def __call__(self, definition_ir, options):
        """Run the transformation pipeline and return the implementation IR.

        An empty ``StencilImplementation`` is seeded from ``definition_ir``,
        stored with the inputs in ``self.transform_data``, and then handed to
        every pass in order. If ``options.build_info`` is not ``None`` it is
        filled with the definition IR, the implementation IR and the symbol
        table before returning.
        """
        empty_iir = gt_ir.StencilImplementation(
            name=definition_ir.name,
            api_signature=[],
            domain=definition_ir.domain,
            fields={},
            parameters={},
            multi_stages=[],
            fields_extents={},
            unreferenced=[],
            axis_splitters_var=None,
            externals=definition_ir.externals,
            sources=definition_ir.sources,
            docstring=definition_ir.docstring,
        )
        self.transform_data = TransformData(
            definition_ir=definition_ir, implementation_ir=empty_iir, options=options
        )

        # The order of this tuple is the execution order of the pipeline.
        pipeline = (
            InitInfoPass,  # initialize auxiliary data
            NormalizeBlocksPass,  # turn compute units into atomic execution units
            ComputeExtentsPass,  # compute stage extents
            RemoveUnreachedStatementsPass,  # drop HorizontalIf statements without effect
            MergeBlocksPass,  # merge compatible blocks
            ComputeUsedSymbolsPass,  # compute used symbols
            BuildIIRPass,  # build IIR
            DataTypePass,  # fill in missing dtypes
            # Turn temporary fields that are only written and read within the
            # same function into local scalars.
            DemoteLocalTemporariesToVariablesPass,
            # Replace temporary fields only assigned to scalar literals with
            # the actual values.
            ConstantFoldingPass,
            HousekeepingPass,  # prune some stages that don't have effect
        )
        for transform_pass in pipeline:
            transform_pass.apply(self.transform_data)

        if options.build_info is not None:
            exported = (
                ("def_ir", "definition_ir"),
                ("iir", "implementation_ir"),
                ("symbol_info", "symbols"),
            )
            for key, attribute in exported:
                options.build_info[key] = getattr(self.transform_data, attribute)

        return self.transform_data.implementation_ir
Example #4
0
    def apply(self, transform_data: TransformData) -> TransformData:
        """Collect the demotable symbols of the IIR and demote them.

        ``CollectDemotableSymbols`` is run over the implementation IR first;
        its result parameterises ``DemoteSymbols``, whose output replaces
        ``transform_data.implementation_ir``. The same ``transform_data``
        object is returned.
        """
        demotables = self.CollectDemotableSymbols()(transform_data.implementation_ir)
        transform_data.implementation_ir = self.DemoteSymbols(demotables)(
            transform_data.implementation_ir
        )
        return transform_data
Example #5
0
 def apply(self, transform_data: TransformData) -> TransformData:
     """Merge blocks, then the IJ blocks of every merged block.

     Both levels go through ``greedy_merging_with_wrapper``: the top-level
     blocks with ``MultiStageMergingWrapper`` and each block's ``ij_blocks``
     with ``StageMergingWrapper``. ``transform_data.blocks`` is only replaced
     once all IJ blocks have been processed; the same object is returned.
     """
     multi_stage_blocks = greedy_merging_with_wrapper(
         transform_data.blocks, MultiStageMergingWrapper, parent=transform_data
     )
     for multi_stage in multi_stage_blocks:
         merged_stages = greedy_merging_with_wrapper(
             multi_stage.ij_blocks,
             StageMergingWrapper,
             parent=transform_data,
             parent_block=multi_stage,
         )
         multi_stage.ij_blocks = merged_stages
     transform_data.blocks = multi_stage_blocks
     return transform_data
Example #6
0
    def apply(self, transform_data: TransformData):
        """Greedily merge adjacent compatible blocks at three nesting levels.

        The sweeps run in this order:

        1. adjacent domain blocks (multi-stages) accepted by
           ``_are_compatible_multi_stages``;
        2. adjacent IJ blocks (stages) inside each remaining domain block
           accepted by ``_are_compatible_stages``;
        3. adjacent interval blocks inside each remaining IJ block that have
           an equal ``interval``.

        A block that absorbs its successor is given a fresh id from
        ``transform_data.id_generator`` before the merge. Element 0 of
        ``transform_data.blocks``, of every ``ij_blocks`` and of every
        ``interval_blocks`` list is indexed directly, so those sequences are
        expected to be non-empty.

        Returns the same ``transform_data`` object with its block hierarchy
        replaced by the merged lists.
        """
        # Greedy strategy to merge multi-stages
        merged_blocks = [transform_data.blocks[0]]
        for candidate in transform_data.blocks[1:]:
            merged = merged_blocks[-1]
            if self._are_compatible_multi_stages(
                merged, candidate, transform_data.has_sequential_axis
            ):
                merged.id = transform_data.id_generator.new
                self._merge_blocks(merged, candidate, "ij")
            else:
                merged_blocks.append(candidate)

        # Greedy strategy to merge stages
        # assert transform_data.has_sequential_axis
        for block in merged_blocks:
            merged_ijs = [block.ij_blocks[0]]
            for ij_candidate in block.ij_blocks[1:]:
                merged = merged_ijs[-1]
                if self._are_compatible_stages(
                    merged, ij_candidate, transform_data.min_k_interval_sizes
                ):
                    merged.id = transform_data.id_generator.new
                    self._merge_blocks(merged, ij_candidate, "interval")
                else:
                    merged_ijs.append(ij_candidate)

            block.ij_blocks = merged_ijs

        # Greedy strategy to merge apply methods
        for block in merged_blocks:
            for ij_block in block.ij_blocks:
                merged_ints = [ij_block.interval_blocks[0]]
                for int_candidate in ij_block.interval_blocks[1:]:
                    merged_int = merged_ints[-1]
                    if int_candidate.interval == merged_int.interval:
                        merged_int.id = transform_data.id_generator.new
                        # ``extend`` keeps ``stmts`` a flat list of statements;
                        # ``append`` would insert the candidate's whole list as
                        # a single nested element.
                        merged_int.stmts.extend(int_candidate.stmts)
                        # Union the extents of inputs both blocks read, and
                        # adopt the inputs only the candidate reads.
                        for name, extent in int_candidate.inputs.items():
                            if name in merged_int.inputs:
                                merged_int.inputs[name] |= extent
                            else:
                                merged_int.inputs[name] = extent

                        merged_int.outputs |= int_candidate.outputs

                    else:
                        merged_ints.append(int_candidate)

                ij_block.interval_blocks = merged_ints

        transform_data.blocks = merged_blocks

        return transform_data
Example #7
0
    def __call__(self, definition_ir, options):
        """Lower ``definition_ir`` to an implementation IR.

        A blank ``StencilImplementation`` is created from the definition's
        metadata and stored, together with the inputs, in
        ``self.transform_data``. Each pass is then instantiated and applied in
        sequence. When ``options.build_info`` is not ``None`` the definition
        IR, implementation IR and symbol table are recorded in it.

        Returns ``self.transform_data.implementation_ir``.
        """
        blank_iir = gt_ir.StencilImplementation(
            name=definition_ir.name,
            api_signature=[],
            domain=definition_ir.domain,
            fields={},
            parameters={},
            multi_stages=[],
            fields_extents={},
            unreferenced=[],
            axis_splitters_var=None,
            externals=definition_ir.externals,
            sources=definition_ir.sources,
            docstring=definition_ir.docstring,
        )
        self.transform_data = TransformData(
            definition_ir=definition_ir,
            implementation_ir=blank_iir,
            options=options,
        )

        # Initialize auxiliary data
        InitInfoPass().apply(self.transform_data)
        # Turn compute units into atomic execution units
        NormalizeBlocksPass().apply(self.transform_data)
        # Compute stage extents
        ComputeExtentsPass().apply(self.transform_data)
        # Merge compatible blocks
        MergeBlocksPass().apply(self.transform_data)
        # Compute used symbols
        ComputeUsedSymbolsPass().apply(self.transform_data)
        # Build IIR
        BuildIIRPass().apply(self.transform_data)
        # Fill in missing dtypes
        DataTypePass().apply(self.transform_data)

        if options.build_info is not None:
            options.build_info["def_ir"] = self.transform_data.definition_ir
            options.build_info["iir"] = self.transform_data.implementation_ir
            options.build_info["symbol_info"] = self.transform_data.symbols

        return self.transform_data.implementation_ir
Example #8
0
def make_transform_data(
    *,
    name: str,
    domain: Domain,
    fields: List[str],
    body: BodyType,
    iteration_order: IterationOrder,
) -> TransformData:
    """Create a ``TransformData`` from the pieces of a stencil definition.

    The definition is produced by ``make_definition``; the implementation IR
    is initialised from that definition, and the build options carry ``name``
    together with this module's ``__name__``. All arguments are keyword-only.
    """
    definition_ir = make_definition(name, fields, domain, body, iteration_order)
    implementation_ir = init_implementation_from_definition(definition_ir)
    build_options = BuildOptions(name=name, module=__name__)
    return TransformData(
        definition_ir=definition_ir,
        implementation_ir=implementation_ir,
        options=build_options,
    )
Example #9
0
    def apply(self, transform_data: TransformData):
        """Give every statement of a PARALLEL block its own block hierarchy.

        Blocks whose ``iteration_order`` is not ``PARALLEL`` are passed through
        unchanged. For a ``PARALLEL`` block, each statement is wrapped in a new
        interval block, a new IJ block (with a zero compute extent) and a new
        domain block, in that order, each drawing a fresh id from
        ``transform_data.id_generator``.

        Returns the same ``transform_data`` object with ``blocks`` replaced by
        the normalized list.
        """
        no_extent = Extent.zeros(transform_data.ndims)
        normalized = []
        for domain_block in transform_data.blocks:
            if domain_block.iteration_order != gt_ir.IterationOrder.PARALLEL:
                normalized.append(domain_block)
                continue

            # Lazily walk every statement together with the interval it runs on.
            statements = (
                (interval_block.interval, stmt_info)
                for ij_block in domain_block.ij_blocks
                for interval_block in ij_block.interval_blocks
                for stmt_info in interval_block.stmts
            )
            for interval, stmt_info in statements:
                # Ids are drawn innermost-first: interval block, IJ block,
                # then domain block.
                single_interval = IntervalBlockInfo(
                    transform_data.id_generator.new,
                    interval,
                    [stmt_info],
                    stmt_info.inputs,
                    stmt_info.outputs,
                )
                single_ij = IJBlockInfo(
                    transform_data.id_generator.new,
                    {interval},
                    [single_interval],
                    {**single_interval.inputs},
                    set(single_interval.outputs),
                    compute_extent=no_extent,
                )
                single_domain = DomainBlockInfo(
                    transform_data.id_generator.new,
                    domain_block.iteration_order,
                    set(single_ij.intervals),
                    [single_ij],
                    {**single_ij.inputs},
                    set(single_ij.outputs),
                )
                normalized.append(single_domain)

        transform_data.blocks = normalized

        return transform_data
Example #10
0
    def __call__(self, definition_ir, options):
        """Run every transformation pass and return the implementation IR.

        An empty ``StencilImplementation`` is seeded from ``definition_ir`` and
        stored, with the inputs, in ``self.transform_data``. The pass classes
        listed below are then instantiated and applied one after another. If
        ``options.build_info`` is not ``None`` it receives the definition IR,
        the implementation IR and the symbol table.
        """
        seed_iir = gt_ir.StencilImplementation(
            name=definition_ir.name,
            api_signature=[],
            domain=definition_ir.domain,
            fields={},
            parameters={},
            multi_stages=[],
            fields_extents={},
            unreferenced=[],
            axis_splitters_var=None,
            externals=definition_ir.externals,
            sources=definition_ir.sources,
            docstring=definition_ir.docstring,
        )
        self.transform_data = TransformData(
            definition_ir=definition_ir,
            implementation_ir=seed_iir,
            options=options,
        )

        # Execution order of the pipeline; each class is instantiated right
        # before its ``apply`` is called.
        pass_types = (
            InitInfoPass,  # initialize auxiliary data
            NormalizeBlocksPass,  # turn compute units into atomic execution units
            ComputeExtentsPass,  # compute stage extents
            MergeBlocksPass,  # merge compatible blocks
            ComputeUsedSymbolsPass,  # compute used symbols
            BuildIIRPass,  # build IIR
            DataTypePass,  # fill in missing dtypes
            # Turn temporary fields that are only written and read within the
            # same function into local scalars.
            DemoteLocalTemporariesToVariablesPass,
            HousekeepingPass,  # prune some stages that don't have effect
        )
        for pass_type in pass_types:
            pass_type().apply(self.transform_data)

        if options.build_info is None:
            return self.transform_data.implementation_ir

        options.build_info["def_ir"] = self.transform_data.definition_ir
        options.build_info["iir"] = self.transform_data.implementation_ir
        options.build_info["symbol_info"] = self.transform_data.symbols

        return self.transform_data.implementation_ir
Example #11
0
    def apply(self, transform_data: TransformData):
        """Replace ``transform_data.blocks`` with the visitor's result.

        A new ``SplitBlocksVisitor`` visits ``transform_data``; whatever it
        returns becomes the new ``blocks``. The same object is returned.
        """
        splitter = self.SplitBlocksVisitor()
        split_blocks = splitter.visit(transform_data)
        transform_data.blocks = split_blocks

        return transform_data