def apply(self, transform_data: TransformData) -> TransformData:
    """Greedily merge adjacent compatible blocks, then adjacent compatible IJ blocks.

    Each list is walked once, in order. A candidate is folded into the most
    recently kept element when the matching compatibility predicate accepts
    the pair; otherwise it is kept as a new element.

    Args:
        transform_data: Object whose ``blocks`` are merged in place. Its
            ``has_sequential_axis``, ``min_k_interval_sizes`` and
            ``id_generator`` attributes are read during merging.

    Returns:
        The same ``transform_data`` instance, with ``blocks`` replaced by the
        merged list and every remaining block's ``ij_blocks`` replaced by its
        merged list.
    """
    # Greedy strategy to merge multi-stages.
    # Starting from an empty list (instead of ``blocks[0]``) keeps the pass
    # well-defined when there are no blocks at all.
    merged_blocks = []
    for candidate in transform_data.blocks:
        if merged_blocks and self._are_compatible_multi_stages(
            merged_blocks[-1], candidate, transform_data.has_sequential_axis
        ):
            merged = merged_blocks[-1]
            # The surviving block gets a new id before absorbing the candidate.
            merged.id = transform_data.id_generator.new
            self._merge_domain_blocks(merged, candidate)
        else:
            merged_blocks.append(candidate)

    # Greedy strategy to merge stages (same scheme, one level down).
    for block in merged_blocks:
        merged_ijs = []
        for ij_candidate in block.ij_blocks:
            if merged_ijs and self._are_compatible_stages(
                merged_ijs[-1], ij_candidate, transform_data.min_k_interval_sizes
            ):
                merged = merged_ijs[-1]
                merged.id = transform_data.id_generator.new
                self._merge_ij_blocks(merged, ij_candidate, transform_data)
            else:
                merged_ijs.append(ij_candidate)
        block.ij_blocks = merged_ijs

    transform_data.blocks = merged_blocks
    return transform_data
def apply(self, transform_data: TransformData) -> TransformData:
    """Merge blocks, then the IJ blocks inside each merged block.

    Both levels are delegated to ``greedy_merging_with_wrapper``: the outer
    call uses ``MultiStageMergingWrapper`` over ``transform_data.blocks`` and
    the inner call uses ``StageMergingWrapper`` over each block's
    ``ij_blocks``.

    Args:
        transform_data: Object whose ``blocks`` are replaced in place; it is
            also passed to the merging helper as ``parent``.

    Returns:
        The same ``transform_data`` instance.
    """
    multi_stages = greedy_merging_with_wrapper(
        transform_data.blocks,
        MultiStageMergingWrapper,
        parent=transform_data,
    )

    for multi_stage in multi_stages:
        # The enclosing block is handed to the helper as ``parent_block``.
        stages = greedy_merging_with_wrapper(
            multi_stage.ij_blocks,
            StageMergingWrapper,
            parent=transform_data,
            parent_block=multi_stage,
        )
        multi_stage.ij_blocks = stages

    transform_data.blocks = multi_stages
    return transform_data
def apply(self, transform_data: TransformData) -> TransformData:
    """Greedily merge adjacent elements at block, IJ-block and interval-block level.

    Each list is walked once, in order. A candidate is folded into the most
    recently kept element when it is compatible with it; otherwise it is kept
    as a new element. Blocks and IJ blocks use the ``_are_compatible_*``
    predicates and ``_merge_blocks``; interval blocks are merged inline when
    their ``interval`` attributes compare equal.

    Args:
        transform_data: Object whose ``blocks`` are merged in place. Its
            ``has_sequential_axis``, ``min_k_interval_sizes`` and
            ``id_generator`` attributes are read during merging.

    Returns:
        The same ``transform_data`` instance, with ``blocks`` replaced by the
        merged list.
    """
    # Greedy strategy to merge multi-stages.
    # Lists are seeded empty (instead of with element 0) so that empty input
    # passes through without an IndexError.
    merged_blocks = []
    for candidate in transform_data.blocks:
        if merged_blocks and self._are_compatible_multi_stages(
            merged_blocks[-1], candidate, transform_data.has_sequential_axis
        ):
            merged = merged_blocks[-1]
            merged.id = transform_data.id_generator.new
            self._merge_blocks(merged, candidate, "ij")
        else:
            merged_blocks.append(candidate)

    # Greedy strategy to merge stages.
    for block in merged_blocks:
        merged_ijs = []
        for ij_candidate in block.ij_blocks:
            if merged_ijs and self._are_compatible_stages(
                merged_ijs[-1], ij_candidate, transform_data.min_k_interval_sizes
            ):
                merged = merged_ijs[-1]
                merged.id = transform_data.id_generator.new
                self._merge_blocks(merged, ij_candidate, "interval")
            else:
                merged_ijs.append(ij_candidate)
        block.ij_blocks = merged_ijs

    # Greedy strategy to merge apply methods.
    for block in merged_blocks:
        for ij_block in block.ij_blocks:
            merged_ints = []
            for int_candidate in ij_block.interval_blocks:
                if merged_ints and int_candidate.interval == merged_ints[-1].interval:
                    merged_int = merged_ints[-1]
                    merged_int.id = transform_data.id_generator.new
                    # ``extend`` keeps ``stmts`` a flat list of statements;
                    # ``append`` would nest the candidate's whole list as a
                    # single element.
                    merged_int.stmts.extend(int_candidate.stmts)
                    # Union the extents of inputs present on both sides and
                    # adopt the candidate's extent for new input names.
                    for name, extent in int_candidate.inputs.items():
                        if name in merged_int.inputs:
                            merged_int.inputs[name] |= extent
                        else:
                            merged_int.inputs[name] = extent
                    merged_int.outputs |= int_candidate.outputs
                else:
                    merged_ints.append(int_candidate)
            ij_block.interval_blocks = merged_ints

    transform_data.blocks = merged_blocks
    return transform_data
def apply(self, transform_data: TransformData):
    """Split every PARALLEL block into one new block per statement.

    For each block whose ``iteration_order`` equals
    ``gt_ir.IterationOrder.PARALLEL``, every statement found in its interval
    blocks is wrapped in its own ``IntervalBlockInfo``, ``IJBlockInfo`` and
    ``DomainBlockInfo``. Blocks with any other iteration order are kept
    unchanged, in their original position.

    Args:
        transform_data: Object whose ``blocks`` are replaced in place; its
            ``ndims`` and ``id_generator`` attributes are also read.

    Returns:
        The same ``transform_data`` instance.
    """
    no_extent = Extent.zeros(transform_data.ndims)
    result = []

    for block in transform_data.blocks:
        is_parallel = block.iteration_order == gt_ir.IterationOrder.PARALLEL
        if not is_parallel:
            result.append(block)
            continue

        # Put every statement in a single stage
        for ij_block in block.ij_blocks:
            for interval_block in ij_block.interval_blocks:
                for stmt_info in interval_block.stmts:
                    stmt_interval = interval_block.interval

                    # Ids are drawn innermost-first: interval, IJ, domain.
                    single_interval_block = IntervalBlockInfo(
                        transform_data.id_generator.new,
                        stmt_interval,
                        [stmt_info],
                        stmt_info.inputs,
                        stmt_info.outputs,
                    )

                    ij_id = transform_data.id_generator.new
                    # Each level gets its own copy of the inputs mapping and
                    # outputs set.
                    ij_inputs = dict(single_interval_block.inputs)
                    ij_outputs = set(single_interval_block.outputs)
                    single_ij_block = IJBlockInfo(
                        ij_id,
                        {stmt_interval},
                        [single_interval_block],
                        ij_inputs,
                        ij_outputs,
                        compute_extent=no_extent,
                    )

                    domain_id = transform_data.id_generator.new
                    single_domain_block = DomainBlockInfo(
                        domain_id,
                        block.iteration_order,
                        set(single_ij_block.intervals),
                        [single_ij_block],
                        dict(single_ij_block.inputs),
                        set(single_ij_block.outputs),
                    )
                    result.append(single_domain_block)

    transform_data.blocks = result
    return transform_data
def apply(self, transform_data: TransformData):
    """Replace ``transform_data.blocks`` with the result of a ``SplitBlocksVisitor`` run.

    Args:
        transform_data: Object handed to a new ``self.SplitBlocksVisitor``
            instance; its ``blocks`` attribute is overwritten with whatever
            ``visit`` returns.

    Returns:
        The same ``transform_data`` instance.
    """
    visitor = self.SplitBlocksVisitor()
    new_blocks = visitor.visit(transform_data)
    transform_data.blocks = new_blocks
    return transform_data