Ejemplo n.º 1
0
    def apply(self, transform_data: TransformData):
        """Greedily fuse adjacent multi-stages, then adjacent stages inside each.

        Two passes are made:

        1. Neighbouring domain blocks of ``transform_data.blocks`` are fused
           whenever ``self._are_compatible_multi_stages`` accepts the pair;
           the surviving block is given a fresh id and absorbs the other one
           through ``self._merge_domain_blocks``.
        2. Inside every surviving domain block, neighbouring ``ij_blocks``
           are fused whenever ``self._are_compatible_stages`` accepts the
           pair; the survivor is given a fresh id and absorbs the other one
           through ``self._merge_ij_blocks``.

        Only the most recently kept element is ever considered as a merge
        target, so the strategy is greedy and order-preserving.

        Args:
            transform_data: pipeline state; its ``blocks`` list is replaced
                and the kept blocks are mutated in place.

        Returns:
            The same ``transform_data`` object.

        Raises:
            IndexError: if ``blocks`` or any visited ``ij_blocks`` is empty.
        """

        def fold_neighbours(items, can_fuse, fuse):
            # Walk left to right: fold each item into the last survivor when
            # allowed, otherwise the item becomes the new last survivor.
            survivors = [items[0]]
            for item in items[1:]:
                last = survivors[-1]
                if not can_fuse(last, item):
                    survivors.append(item)
                    continue
                last.id = transform_data.id_generator.new
                fuse(last, item)
            return survivors

        domain_blocks = fold_neighbours(
            transform_data.blocks,
            lambda last, item: self._are_compatible_multi_stages(
                last, item, transform_data.has_sequential_axis
            ),
            lambda last, item: self._merge_domain_blocks(last, item),
        )

        for domain_block in domain_blocks:
            domain_block.ij_blocks = fold_neighbours(
                domain_block.ij_blocks,
                lambda last, item: self._are_compatible_stages(
                    last, item, transform_data.min_k_interval_sizes
                ),
                lambda last, item: self._merge_ij_blocks(last, item, transform_data),
            )

        transform_data.blocks = domain_blocks
        return transform_data
Ejemplo n.º 2
0
 def apply(self, transform_data: TransformData) -> TransformData:
     """Run the wrapper-based greedy merging over blocks, then over stages.

     First ``transform_data.blocks`` is passed to
     ``greedy_merging_with_wrapper`` with ``MultiStageMergingWrapper``.
     Then, for each resulting block, its ``ij_blocks`` are passed to
     ``greedy_merging_with_wrapper`` with ``StageMergingWrapper``, and the
     result replaces that block's ``ij_blocks``. The exact merge criteria
     live in the wrapper classes (not visible here).

     Args:
         transform_data: pipeline state; its ``blocks`` list is replaced.

     Returns:
         The same ``transform_data`` object.
     """

     def merge_stages_of(domain_block):
         # Stage-level merging needs both the pipeline state and the owning block.
         return greedy_merging_with_wrapper(
             domain_block.ij_blocks,
             StageMergingWrapper,
             parent=transform_data,
             parent_block=domain_block,
         )

     domain_blocks = greedy_merging_with_wrapper(
         transform_data.blocks,
         MultiStageMergingWrapper,
         parent=transform_data,
     )
     for domain_block in domain_blocks:
         domain_block.ij_blocks = merge_stages_of(domain_block)

     transform_data.blocks = domain_blocks
     return transform_data
Ejemplo n.º 3
0
    def apply(self, transform_data: TransformData):
        """Greedily merge multi-stages, stages and interval blocks.

        Three passes are made. Each pass only compares an element with the
        last element kept so far, so merging is greedy and order-preserving:

        1. Adjacent domain blocks are merged when
           ``self._are_compatible_multi_stages`` accepts them, via
           ``self._merge_blocks(..., "ij")``.
        2. Inside each domain block, adjacent ``ij_blocks`` are merged when
           ``self._are_compatible_stages`` accepts them, via
           ``self._merge_blocks(..., "interval")``.
        3. Inside each IJ block, adjacent ``interval_blocks`` with an equal
           ``interval`` are merged. Statements are concatenated, input
           extents are combined per name with ``|``, and outputs are
           combined with ``|=``.

        Every merge target receives a fresh id from
        ``transform_data.id_generator.new``.

        Args:
            transform_data: pipeline state; its ``blocks`` list is replaced
                and the kept blocks are mutated in place.

        Returns:
            The same ``transform_data`` object.

        Raises:
            IndexError: if ``blocks``, or any visited ``ij_blocks`` /
                ``interval_blocks`` list, is empty.
        """
        # Pass 1: merge compatible adjacent multi-stages.
        merged_blocks = [transform_data.blocks[0]]
        for candidate in transform_data.blocks[1:]:
            merged = merged_blocks[-1]
            if self._are_compatible_multi_stages(
                merged, candidate, transform_data.has_sequential_axis
            ):
                merged.id = transform_data.id_generator.new
                self._merge_blocks(merged, candidate, "ij")
            else:
                merged_blocks.append(candidate)

        # Pass 2: merge compatible adjacent stages within each multi-stage.
        for block in merged_blocks:
            merged_ijs = [block.ij_blocks[0]]
            for ij_candidate in block.ij_blocks[1:]:
                merged = merged_ijs[-1]
                if self._are_compatible_stages(
                    merged, ij_candidate, transform_data.min_k_interval_sizes
                ):
                    merged.id = transform_data.id_generator.new
                    self._merge_blocks(merged, ij_candidate, "interval")
                else:
                    merged_ijs.append(ij_candidate)

            block.ij_blocks = merged_ijs

        # Pass 3: merge adjacent interval blocks (apply methods) that cover
        # the same interval.
        for block in merged_blocks:
            for ij_block in block.ij_blocks:
                merged_ints = [ij_block.interval_blocks[0]]
                for int_candidate in ij_block.interval_blocks[1:]:
                    merged_int = merged_ints[-1]
                    if int_candidate.interval == merged_int.interval:
                        merged_int.id = transform_data.id_generator.new
                        # Concatenate the candidate's statements into the
                        # target's list; ``append`` would nest the whole list
                        # as a single element.
                        merged_int.stmts.extend(int_candidate.stmts)
                        for name, extent in int_candidate.inputs.items():
                            if name in merged_int.inputs:
                                merged_int.inputs[name] |= extent
                            else:
                                merged_int.inputs[name] = extent

                        merged_int.outputs |= int_candidate.outputs

                    else:
                        merged_ints.append(int_candidate)

                ij_block.interval_blocks = merged_ints

        transform_data.blocks = merged_blocks

        return transform_data
Ejemplo n.º 4
0
    def apply(self, transform_data: TransformData):
        """Split every PARALLEL multi-stage into one domain block per statement.

        A block is split when its ``iteration_order`` equals
        ``gt_ir.IterationOrder.PARALLEL``. Each statement of each of its
        interval blocks then becomes its own ``DomainBlockInfo``. That block
        wraps a single ``IJBlockInfo``, which wraps a single
        ``IntervalBlockInfo``, and each new object gets a fresh id. Blocks
        with any other iteration order are kept unchanged. The relative
        order of blocks and statements is preserved.

        Args:
            transform_data: pipeline state; its ``blocks`` list is replaced.

        Returns:
            The same ``transform_data`` object.
        """
        # NOTE(review): this single Extent instance is shared as
        # ``compute_extent`` by every new IJ block; assumed never updated in
        # place — confirm.
        zero_extent = Extent.zeros(transform_data.ndims)
        blocks = []
        for block in transform_data.blocks:
            if block.iteration_order == gt_ir.IterationOrder.PARALLEL:
                # Put every statement in a single stage
                for ij_block in block.ij_blocks:
                    for interval_block in ij_block.interval_blocks:
                        interval = interval_block.interval
                        for stmt_info in interval_block.stmts:
                            new_interval_block = IntervalBlockInfo(
                                transform_data.id_generator.new,
                                interval,
                                [stmt_info],
                                # Copy the statement's containers, as the IJ and
                                # domain levels below already do, so later
                                # in-place merges of this interval block cannot
                                # mutate ``stmt_info`` itself.
                                {**stmt_info.inputs},
                                set(stmt_info.outputs),
                            )
                            new_ij_block = IJBlockInfo(
                                transform_data.id_generator.new,
                                {interval},
                                [new_interval_block],
                                {**new_interval_block.inputs},
                                set(new_interval_block.outputs),
                                compute_extent=zero_extent,
                            )
                            new_block = DomainBlockInfo(
                                transform_data.id_generator.new,
                                block.iteration_order,
                                set(new_ij_block.intervals),
                                [new_ij_block],
                                {**new_ij_block.inputs},
                                set(new_ij_block.outputs),
                            )
                            blocks.append(new_block)
            else:
                blocks.append(block)

        transform_data.blocks = blocks

        return transform_data
Ejemplo n.º 5
0
    def apply(self, transform_data: TransformData):
        """Replace ``transform_data.blocks`` with the output of ``SplitBlocksVisitor``.

        A new ``self.SplitBlocksVisitor`` is created and its ``visit`` method
        is called on *transform_data*. The returned value becomes the new
        ``blocks`` list.

        Args:
            transform_data: pipeline state; its ``blocks`` attribute is replaced.

        Returns:
            The same ``transform_data`` object.
        """
        visitor = self.SplitBlocksVisitor()
        split_blocks = visitor.visit(transform_data)
        transform_data.blocks = split_blocks
        return transform_data