def apply(self, transform_data: TransformData):
    """Greedily merge compatible multi-stages, then compatible stages inside each.

    Two left-to-right passes:

    1. consecutive domain blocks (multi-stages) in ``transform_data.blocks`` that
       ``self._are_compatible_multi_stages`` accepts are fused with
       ``self._merge_domain_blocks``;
    2. inside every surviving domain block, consecutive IJ blocks (stages) that
       ``self._are_compatible_stages`` accepts are fused with ``self._merge_ij_blocks``.

    Each block that absorbs a neighbour receives a fresh id from
    ``transform_data.id_generator``. Empty block lists are returned as empty lists
    instead of raising ``IndexError``.

    Args:
        transform_data: Pipeline state; ``transform_data.blocks`` and each block's
            ``ij_blocks`` list are replaced by their merged versions.

    Returns:
        The same ``transform_data`` instance.
    """

    def greedy_merge(items, are_compatible, merge):
        # Each candidate is compared only with the most recent survivor, so merging
        # never reorders blocks. Iterate over a copy (like the original slice did)
        # so that ``merge`` can never disturb the iteration.
        survivors = []
        for candidate in list(items):
            if survivors and are_compatible(survivors[-1], candidate):
                target = survivors[-1]
                target.id = transform_data.id_generator.new
                merge(target, candidate)
            else:
                survivors.append(candidate)
        return survivors

    merged_blocks = greedy_merge(
        transform_data.blocks,
        lambda merged, candidate: self._are_compatible_multi_stages(
            merged, candidate, transform_data.has_sequential_axis
        ),
        self._merge_domain_blocks,
    )

    for block in merged_blocks:
        block.ij_blocks = greedy_merge(
            block.ij_blocks,
            lambda merged, candidate: self._are_compatible_stages(
                merged, candidate, transform_data.min_k_interval_sizes
            ),
            lambda merged, candidate: self._merge_ij_blocks(merged, candidate, transform_data),
        )

    transform_data.blocks = merged_blocks
    return transform_data
def build_transform(self):
    """Build the definition IR and wrap it in a fresh ``TransformData``.

    The implementation IR is initialized from the definition via
    ``init_implementation_from_definition``. The build options are named after
    ``self.name`` and tagged with this module's ``__name__``.

    Returns:
        A new ``TransformData`` instance.
    """
    definition_ir = self.build()
    implementation_ir = init_implementation_from_definition(definition_ir)
    build_options = BuildOptions(name=self.name, module=__name__)
    return TransformData(
        definition_ir=definition_ir,
        implementation_ir=implementation_ir,
        options=build_options,
    )
def __call__(self, definition_ir, options):
    """Lower ``definition_ir`` into a ``gt_ir.StencilImplementation``.

    A skeleton implementation IR is created with an empty signature, empty fields,
    empty parameters and no multi-stages. It copies name, domain, externals, sources
    and docstring from the definition. A ``TransformData`` holding both IRs is stored
    on ``self.transform_data``, and a fixed sequence of passes is applied to it via
    each pass class's ``apply``.

    Args:
        definition_ir: The stencil definition IR to lower.
        options: Build options. If ``options.build_info`` is not ``None``, the
            definition IR, implementation IR and symbol table are recorded in it
            under ``"def_ir"``, ``"iir"`` and ``"symbol_info"``.

    Returns:
        ``self.transform_data.implementation_ir`` after all passes have run.
    """
    implementation_ir = gt_ir.StencilImplementation(
        name=definition_ir.name,
        api_signature=[],
        domain=definition_ir.domain,
        fields={},
        parameters={},
        multi_stages=[],
        fields_extents={},
        unreferenced=[],
        axis_splitters_var=None,
        externals=definition_ir.externals,
        sources=definition_ir.sources,
        docstring=definition_ir.docstring,
    )
    self.transform_data = TransformData(
        definition_ir=definition_ir, implementation_ir=implementation_ir, options=options
    )

    # Passes are applied strictly in this order, all on the same transform data.
    pass_sequence = (
        InitInfoPass,  # initialize auxiliary data
        NormalizeBlocksPass,  # turn compute units into atomic execution units
        ComputeExtentsPass,  # compute stage extents
        RemoveUnreachedStatementsPass,  # drop HorizontalIf statements without effect
        MergeBlocksPass,  # merge compatible blocks
        ComputeUsedSymbolsPass,  # compute used symbols
        BuildIIRPass,  # build the implementation IR
        DataTypePass,  # fill in missing dtypes
        # turn temporaries only written and read within the same function into scalars
        DemoteLocalTemporariesToVariablesPass,
        ConstantFoldingPass,  # inline temporaries only assigned scalar literals
        HousekeepingPass,  # prune stages that have no effect
    )
    for analysis_pass in pass_sequence:
        analysis_pass.apply(self.transform_data)

    build_info = options.build_info
    if build_info is not None:
        build_info["def_ir"] = self.transform_data.definition_ir
        build_info["iir"] = self.transform_data.implementation_ir
        build_info["symbol_info"] = self.transform_data.symbols

    return self.transform_data.implementation_ir
def apply(self, transform_data: TransformData) -> TransformData:
    """Rewrite the implementation IR, demoting the symbols found to be demotable.

    ``self.CollectDemotableSymbols`` is instantiated and run on the implementation IR
    to collect the symbols. ``self.DemoteSymbols`` is then built from that collection
    and run on the same IR. Its result replaces ``transform_data.implementation_ir``.

    Args:
        transform_data: Pipeline state whose ``implementation_ir`` is rewritten.

    Returns:
        The same ``transform_data`` instance.
    """
    iir = transform_data.implementation_ir
    demotable_symbols = self.CollectDemotableSymbols()(iir)
    transform_data.implementation_ir = self.DemoteSymbols(demotable_symbols)(iir)
    return transform_data
def apply(self, transform_data: TransformData) -> TransformData:
    """Merge multi-stages, then the stages inside every resulting multi-stage.

    Merging is delegated to ``greedy_merging_with_wrapper``. It uses
    ``MultiStageMergingWrapper`` for ``transform_data.blocks`` and
    ``StageMergingWrapper`` for each block's ``ij_blocks``. The owning block is
    passed as ``parent_block`` at stage level.

    Args:
        transform_data: Pipeline state; ``blocks`` and every block's ``ij_blocks``
            are replaced by the merged lists.

    Returns:
        The same ``transform_data`` instance.
    """
    multi_stages = greedy_merging_with_wrapper(
        transform_data.blocks, MultiStageMergingWrapper, parent=transform_data
    )
    for domain_block in multi_stages:
        merged_stages = greedy_merging_with_wrapper(
            domain_block.ij_blocks,
            StageMergingWrapper,
            parent=transform_data,
            parent_block=domain_block,
        )
        domain_block.ij_blocks = merged_stages
    transform_data.blocks = multi_stages
    return transform_data
def apply(self, transform_data: TransformData):
    """Greedily merge compatible multi-stages, stages and interval blocks.

    Three left-to-right passes:

    1. consecutive domain blocks (multi-stages) in ``transform_data.blocks`` that
       ``self._are_compatible_multi_stages`` accepts are fused with
       ``self._merge_blocks(..., "ij")``;
    2. inside each resulting domain block, consecutive IJ blocks (stages) that
       ``self._are_compatible_stages`` accepts are fused with
       ``self._merge_blocks(..., "interval")``;
    3. inside each IJ block, consecutive interval blocks with an equal ``interval``
       are fused. Their statements are concatenated, input extents are combined per
       name with ``|``, and outputs are combined with ``|=``.

    Each block that absorbs a neighbour receives a fresh id from
    ``transform_data.id_generator``. Empty lists at any level stay empty instead of
    raising ``IndexError``.

    Args:
        transform_data: Pipeline state; ``blocks`` and the nested ``ij_blocks`` /
            ``interval_blocks`` lists are replaced by their merged versions.

    Returns:
        The same ``transform_data`` instance.
    """

    def greedy_merge(items, are_compatible, merge):
        # Each candidate is compared only with the most recent survivor, so merging
        # never reorders blocks. Iterate over a copy (like the original slices did)
        # so that ``merge`` can never disturb the iteration.
        survivors = []
        for candidate in list(items):
            if survivors and are_compatible(survivors[-1], candidate):
                target = survivors[-1]
                target.id = transform_data.id_generator.new
                merge(target, candidate)
            else:
                survivors.append(candidate)
        return survivors

    def merge_interval_blocks(target, candidate):
        # ``extend``, not ``append``: the candidate's statements must be added one by
        # one. Appending the list object itself would nest it as a single element.
        target.stmts.extend(candidate.stmts)
        for name, extent in candidate.inputs.items():
            if name in target.inputs:
                target.inputs[name] |= extent
            else:
                target.inputs[name] = extent
        target.outputs |= candidate.outputs

    merged_blocks = greedy_merge(
        transform_data.blocks,
        lambda merged, candidate: self._are_compatible_multi_stages(
            merged, candidate, transform_data.has_sequential_axis
        ),
        lambda merged, candidate: self._merge_blocks(merged, candidate, "ij"),
    )

    for block in merged_blocks:
        block.ij_blocks = greedy_merge(
            block.ij_blocks,
            lambda merged, candidate: self._are_compatible_stages(
                merged, candidate, transform_data.min_k_interval_sizes
            ),
            lambda merged, candidate: self._merge_blocks(merged, candidate, "interval"),
        )

    for block in merged_blocks:
        for ij_block in block.ij_blocks:
            ij_block.interval_blocks = greedy_merge(
                ij_block.interval_blocks,
                lambda merged, candidate: candidate.interval == merged.interval,
                merge_interval_blocks,
            )

    transform_data.blocks = merged_blocks
    return transform_data
def __call__(self, definition_ir, options):
    """Lower ``definition_ir`` into a ``gt_ir.StencilImplementation``.

    A skeleton implementation IR is created with an empty signature, empty fields,
    empty parameters and no multi-stages. It copies name, domain, externals, sources
    and docstring from the definition. A ``TransformData`` holding both IRs is stored
    on ``self.transform_data``. A fixed sequence of pass classes is then instantiated
    one at a time, and each instance's ``apply`` is run on it.

    Args:
        definition_ir: The stencil definition IR to lower.
        options: Build options. If ``options.build_info`` is not ``None``, the
            definition IR, implementation IR and symbol table are recorded in it
            under ``"def_ir"``, ``"iir"`` and ``"symbol_info"``.

    Returns:
        ``self.transform_data.implementation_ir`` after all passes have run.
    """
    implementation_ir = gt_ir.StencilImplementation(
        name=definition_ir.name,
        api_signature=[],
        domain=definition_ir.domain,
        fields={},
        parameters={},
        multi_stages=[],
        fields_extents={},
        unreferenced=[],
        axis_splitters_var=None,
        externals=definition_ir.externals,
        sources=definition_ir.sources,
        docstring=definition_ir.docstring,
    )
    self.transform_data = TransformData(
        definition_ir=definition_ir, implementation_ir=implementation_ir, options=options
    )

    # Passes are instantiated and applied strictly in this order.
    pass_sequence = (
        InitInfoPass,  # initialize auxiliary data
        NormalizeBlocksPass,  # turn compute units into atomic execution units
        ComputeExtentsPass,  # compute stage extents
        MergeBlocksPass,  # merge compatible blocks
        ComputeUsedSymbolsPass,  # compute used symbols
        BuildIIRPass,  # build the implementation IR
        DataTypePass,  # fill in missing dtypes
    )
    for pass_class in pass_sequence:
        pass_class().apply(self.transform_data)

    build_info = options.build_info
    if build_info is not None:
        build_info["def_ir"] = self.transform_data.definition_ir
        build_info["iir"] = self.transform_data.implementation_ir
        build_info["symbol_info"] = self.transform_data.symbols

    return self.transform_data.implementation_ir
def make_transform_data(
    *,
    name: str,
    domain: Domain,
    fields: List[str],
    body: BodyType,
    iteration_order: IterationOrder,
) -> TransformData:
    """Create a ``TransformData`` from a freshly built stencil definition.

    The definition IR comes from ``make_definition(name, fields, domain, body,
    iteration_order)``. Its implementation IR is initialized with
    ``init_implementation_from_definition``. The build options are named ``name`` and
    tagged with this module's ``__name__``.

    Args:
        name: Stencil name, used for both the definition and the build options.
        domain: Domain passed to ``make_definition``.
        fields: Field names passed to ``make_definition``.
        body: Body passed to ``make_definition``.
        iteration_order: Iteration order passed to ``make_definition``.

    Returns:
        A new ``TransformData`` instance.
    """
    definition_ir = make_definition(name, fields, domain, body, iteration_order)
    implementation_ir = init_implementation_from_definition(definition_ir)
    build_options = BuildOptions(name=name, module=__name__)
    return TransformData(
        definition_ir=definition_ir,
        implementation_ir=implementation_ir,
        options=build_options,
    )
def apply(self, transform_data: TransformData):
    """Give every statement of a PARALLEL domain block its own block hierarchy.

    For each domain block whose ``iteration_order`` is ``PARALLEL``, every statement
    of every interval block is wrapped in a new ``IntervalBlockInfo`` ->
    ``IJBlockInfo`` -> ``DomainBlockInfo`` chain holding only that statement. The new
    IJ block's ``compute_extent`` is a zero extent. Domain blocks with any other
    iteration order are kept unchanged. The relative order of blocks and statements
    is preserved.

    Args:
        transform_data: Pipeline state; ``transform_data.blocks`` is replaced.

    Returns:
        The same ``transform_data`` instance.
    """
    zero_extent = Extent.zeros(transform_data.ndims)
    normalized_blocks = []

    for block in transform_data.blocks:
        if block.iteration_order != gt_ir.IterationOrder.PARALLEL:
            normalized_blocks.append(block)
            continue

        for ij_block in block.ij_blocks:
            for interval_block in ij_block.interval_blocks:
                interval = interval_block.interval
                for stmt_info in interval_block.stmts:
                    # NOTE(review): the interval block shares the ``inputs``/``outputs``
                    # objects of ``stmt_info``, while the IJ and domain levels below take
                    # copies. In-place updates to the interval block's inputs/outputs
                    # therefore also change the statement's. Confirm this is intended.
                    single_interval = IntervalBlockInfo(
                        transform_data.id_generator.new,
                        interval,
                        [stmt_info],
                        stmt_info.inputs,
                        stmt_info.outputs,
                    )
                    single_stage = IJBlockInfo(
                        transform_data.id_generator.new,
                        {interval},
                        [single_interval],
                        {**single_interval.inputs},
                        set(single_interval.outputs),
                        compute_extent=zero_extent,
                    )
                    normalized_blocks.append(
                        DomainBlockInfo(
                            transform_data.id_generator.new,
                            block.iteration_order,
                            set(single_stage.intervals),
                            [single_stage],
                            {**single_stage.inputs},
                            set(single_stage.outputs),
                        )
                    )

    transform_data.blocks = normalized_blocks
    return transform_data
def __call__(self, definition_ir, options):
    """Lower ``definition_ir`` into a ``gt_ir.StencilImplementation``.

    A skeleton implementation IR is created with an empty signature, empty fields,
    empty parameters and no multi-stages. It copies name, domain, externals, sources
    and docstring from the definition. A ``TransformData`` holding both IRs is stored
    on ``self.transform_data``. A fixed sequence of pass classes is then instantiated
    one at a time, and each instance's ``apply`` is run on it.

    Args:
        definition_ir: The stencil definition IR to lower.
        options: Build options. If ``options.build_info`` is not ``None``, the
            definition IR, implementation IR and symbol table are recorded in it
            under ``"def_ir"``, ``"iir"`` and ``"symbol_info"``.

    Returns:
        ``self.transform_data.implementation_ir`` after all passes have run.
    """
    implementation_ir = gt_ir.StencilImplementation(
        name=definition_ir.name,
        api_signature=[],
        domain=definition_ir.domain,
        fields={},
        parameters={},
        multi_stages=[],
        fields_extents={},
        unreferenced=[],
        axis_splitters_var=None,
        externals=definition_ir.externals,
        sources=definition_ir.sources,
        docstring=definition_ir.docstring,
    )
    self.transform_data = TransformData(
        definition_ir=definition_ir, implementation_ir=implementation_ir, options=options
    )

    # Passes are instantiated and applied strictly in this order.
    pass_sequence = (
        InitInfoPass,  # initialize auxiliary data
        NormalizeBlocksPass,  # turn compute units into atomic execution units
        ComputeExtentsPass,  # compute stage extents
        MergeBlocksPass,  # merge compatible blocks
        ComputeUsedSymbolsPass,  # compute used symbols
        BuildIIRPass,  # build the implementation IR
        DataTypePass,  # fill in missing dtypes
        # turn temporaries only written and read within the same function into scalars
        DemoteLocalTemporariesToVariablesPass,
        HousekeepingPass,  # prune stages that have no effect
    )
    for pass_class in pass_sequence:
        pass_class().apply(self.transform_data)

    build_info = options.build_info
    if build_info is not None:
        build_info["def_ir"] = self.transform_data.definition_ir
        build_info["iir"] = self.transform_data.implementation_ir
        build_info["symbol_info"] = self.transform_data.symbols

    return self.transform_data.implementation_ir
def apply(self, transform_data: TransformData):
    """Replace ``transform_data.blocks`` with the output of ``SplitBlocksVisitor``.

    A new ``self.SplitBlocksVisitor`` is created and visits the whole
    ``transform_data``. Whatever it returns becomes the new block list.

    Args:
        transform_data: Pipeline state whose ``blocks`` attribute is replaced.

    Returns:
        The same ``transform_data`` instance.
    """
    splitter = self.SplitBlocksVisitor()
    split_blocks = splitter.visit(transform_data)
    transform_data.blocks = split_blocks
    return transform_data