Example #1
0
 def intervals_overlap_or_imply_reorder(
     self,
     interval_a: IntervalInfo,
     interval_b: IntervalInfo,
 ) -> bool:
     """Tell whether two *distinct* intervals overlap.

     Two intervals that compare equal are never reported. For distinct intervals
     the decision is delegated to ``interval_a.overlaps``, which also receives
     ``self.min_k_interval_sizes``.

     Args:
         interval_a: first interval to compare.
         interval_b: second interval to compare.

     Returns:
         The (falsy) result of ``interval_a != interval_b`` when the intervals are
         equal; otherwise whatever ``interval_a.overlaps(...)`` returns.
     """
     are_distinct = interval_a != interval_b
     if not are_distinct:
         # Mirror ``a != b and ...`` short-circuiting: overlaps() is not called.
         return are_distinct
     return interval_a.overlaps(interval_b, self.min_k_interval_sizes)
Example #2
0
        def visit_StencilDefinition(self, node: gt_ir.StencilDefinition):
            """Build one ``IntervalInfo`` per computation of *node*.

            The computations are scanned twice:

            1. Find the dynamic splitters variable, i.e. a stencil parameter that an
               interval bound references through ``gt_ir.VarRef``. When one is found,
               ``self.data.splitters_var`` is set to its name and
               ``self.data.min_k_interval_sizes`` becomes ``[1] * (length + 1)``.
               Otherwise they stay ``None`` and ``[0]`` respectively.
            2. Translate every interval bound into an ``(index, offset)`` pair,
               validate it, and raise entries of ``self.data.min_k_interval_sizes``
               as required by the offsets.

            Args:
                node: the stencil definition being visited.

            Returns:
                A list of ``IntervalInfo`` objects, one per entry of
                ``node.computations`` and in the same order.

            Raises:
                IntervalSpecificationError: when a bound references an unknown or
                    zero-length parameter, a variable other than the selected
                    splitters variable, or an index outside
                    ``0..self.data.nk_intervals``. Also raised when a static bound
                    uses a negative offset on a level other than END, or a positive
                    offset on a level other than START.
            """
            data = self.data
            data.splitters_var = None
            data.min_k_interval_sizes = [0]

            # Pass 1: locate the dynamic splitters variable.
            # Every VarRef bound overwrites the previous choice, so the last one wins.
            for computation in node.computations:
                interval_def = computation.interval
                for bound in (interval_def.start, interval_def.end):
                    if not isinstance(bound.level, gt_ir.VarRef):
                        continue

                    ref_name = bound.level.name
                    param = next((p for p in node.parameters if p.name == ref_name), None)
                    if param is None or param.length == 0:
                        raise IntervalSpecificationError(
                            interval_def,
                            "Invalid variable reference in interval specification",
                            loc=bound.loc,
                        )

                    data.splitters_var = param.name
                    data.min_k_interval_sizes = [1] * (param.length + 1)

            # Pass 2: resolve every bound to an (index, offset) pair.
            intervals = []
            for computation in node.computations:
                interval_def = computation.interval
                resolved = []

                for bound in (interval_def.start, interval_def.end):
                    if isinstance(bound.level, gt_ir.VarRef):
                        # Dynamic splitter: it must be the variable selected in pass 1.
                        if bound.level.name != data.splitters_var:
                            raise IntervalSpecificationError(
                                interval_def,
                                "Non matching variable reference in interval specification",
                                loc=bound.loc,
                            )
                        index = bound.level.index + 1
                        offset = bound.offset
                        # A negative offset is recorded against the previous index.
                        if offset < 0:
                            index -= 1
                    else:
                        # Static splitter: END, or any negative offset, maps to
                        # nk_intervals; everything else maps to index 0.
                        offset = bound.offset
                        index = (
                            data.nk_intervals
                            if offset < 0 or bound.level == gt_ir.LevelMarker.END
                            else 0
                        )
                        # Only END may carry a negative offset; only START a positive one.
                        if (offset < 0 and bound.level != gt_ir.LevelMarker.END) or (
                            offset > 0 and bound.level != gt_ir.LevelMarker.START
                        ):
                            raise IntervalSpecificationError(
                                interval_def,
                                "Invalid offset in interval specification",
                                loc=bound.loc,
                            )

                    if not 0 <= index <= data.nk_intervals:
                        raise IntervalSpecificationError(
                            interval_def,
                            "Invalid variable reference in interval specification",
                            loc=bound.loc,
                        )

                    resolved.append((index, offset))
                    # The bound's offset is a lower bound on that interval's size.
                    if index < data.nk_intervals:
                        data.min_k_interval_sizes[index] = max(
                            data.min_k_interval_sizes[index], offset
                        )

                (start_index, start_offset), (end_index, end_offset) = resolved
                # Start and end sit on consecutive indices. Interval `start_index`
                # must then hold at least 1 + start_offset - end_offset levels.
                if start_index == end_index - 1:
                    min_size = 1 + start_offset - end_offset
                    data.min_k_interval_sizes[start_index] = max(
                        data.min_k_interval_sizes[start_index], min_size
                    )

                intervals.append(IntervalInfo(*resolved))

            return intervals