def intervals_overlap_or_imply_reorder(
    self,
    interval_a: IntervalInfo,
    interval_b: IntervalInfo,
) -> bool:
    """Tell whether two *different* intervals overlap.

    Args:
        interval_a: Interval whose ``overlaps`` method is queried.
        interval_b: Interval passed to ``interval_a.overlaps``.

    Returns:
        The result of ``interval_a != interval_b`` when that comparison is
        falsy (``overlaps`` is then never called); otherwise whatever
        ``interval_a.overlaps(interval_b, self.min_k_interval_sizes)``
        returns.
    """
    differ = interval_a != interval_b
    if not differ:
        # Equal intervals are reported as-is, without consulting ``overlaps``.
        return differ
    return interval_a.overlaps(interval_b, self.min_k_interval_sizes)
def visit_StencilDefinition(self, node: gt_ir.StencilDefinition):
    """Collect one ``IntervalInfo`` per computation of a stencil definition.

    The method works in two passes over ``node.computations``:

    1. Every interval bound whose level is a ``gt_ir.VarRef`` is resolved
       against ``node.parameters``.  The matching parameter's name is stored
       in ``self.data.splitters_var`` and ``self.data.min_k_interval_sizes``
       is re-created as ``[1] * (length + 1)``.  When no bound is a
       ``VarRef`` the defaults (``None`` and ``[0]``) stay in place.
    2. Every bound is translated into an ``(index, offset)`` pair, validated,
       and used to raise the entries of ``self.data.min_k_interval_sizes``.

    Args:
        node: Stencil definition providing ``computations`` and
            ``parameters``.

    Returns:
        A list with one ``IntervalInfo`` per computation, in the order of
        ``node.computations``.

    Raises:
        IntervalSpecificationError: If a variable bound does not match a
            parameter with non-zero ``length``, if it references a variable
            other than ``self.data.splitters_var``, if an offset has a sign
            that is not allowed for its level marker, or if the computed
            index falls outside ``0..self.data.nk_intervals``.
    """
    data = self.data
    data.splitters_var = None
    data.min_k_interval_sizes = [0]

    # Pass 1: locate the declaration behind any variable-based bound.
    for computation in node.computations:
        interval_def = computation.interval
        for axis_bound in (interval_def.start, interval_def.end):
            level = axis_bound.level
            if not isinstance(level, gt_ir.VarRef):
                continue

            wanted = level.name
            decl = next(
                (param for param in node.parameters if param.name == wanted),
                None,
            )
            if decl is None or decl.length == 0:
                raise IntervalSpecificationError(
                    interval_def,
                    "Invalid variable reference in interval specification",
                    loc=axis_bound.loc,
                )

            data.splitters_var = decl.name
            data.min_k_interval_sizes = [1] * (decl.length + 1)

    # Pass 2: turn every interval definition into (index, offset) bounds.
    computation_intervals = []
    for computation in node.computations:
        interval_def = computation.interval
        bounds = []

        for axis_bound in (interval_def.start, interval_def.end):
            level = axis_bound.level
            offset = axis_bound.offset

            if isinstance(level, gt_ir.VarRef):
                # Variable bound: it must use the variable found in pass 1.
                if level.name != data.splitters_var:
                    raise IntervalSpecificationError(
                        interval_def,
                        "Non matching variable reference in interval specification",
                        loc=axis_bound.loc,
                    )
                # A negative offset maps to the entry one position earlier.
                index = level.index + (0 if offset < 0 else 1)
            else:
                # Level-marker bound: the index is either 0 or nk_intervals.
                if offset < 0 or level == gt_ir.LevelMarker.END:
                    index = data.nk_intervals
                else:
                    index = 0

                # Negative offsets are only accepted on END, positive ones
                # only on START.
                negative_off_end = offset < 0 and level != gt_ir.LevelMarker.END
                if negative_off_end or (
                    offset > 0 and level != gt_ir.LevelMarker.START
                ):
                    raise IntervalSpecificationError(
                        interval_def,
                        "Invalid offset in interval specification",
                        loc=axis_bound.loc,
                    )

            if index < 0 or index > data.nk_intervals:
                raise IntervalSpecificationError(
                    interval_def,
                    "Invalid variable reference in interval specification",
                    loc=axis_bound.loc,
                )

            bounds.append((index, offset))

            # Raise the recorded minimum size for this index if needed.
            if index < data.nk_intervals:
                sizes = data.min_k_interval_sizes
                sizes[index] = max(sizes[index], offset)

        (start_index, start_offset), (end_index, end_offset) = bounds
        if start_index == end_index - 1:
            # Both bounds delimit the same entry: account for both offsets.
            required = 1 + start_offset - end_offset
            sizes = data.min_k_interval_sizes
            sizes[start_index] = max(sizes[start_index], required)

        computation_intervals.append(IntervalInfo(*bounds))

    return computation_intervals