예제 #1
0
    def _make_axis_interval(self, interval: IntervalInfo):
        """Translate ``interval`` into a ``gt_ir.AxisInterval``.

        ``interval.start`` and ``interval.end`` are each read as an indexable
        pair: item ``0`` selects the reference level and item ``1`` is passed
        through unchanged as the bound offset.

        Level selection, based on item ``0``:

        * ``0`` -> ``gt_ir.LevelMarker.START``
        * ``self.data.nk_intervals`` -> ``gt_ir.LevelMarker.END``
        * anything else -> a ``gt_ir.VarRef`` named ``self.data.splitters_var``
          with index ``item0 - 1``

        Returns:
            The ``gt_ir.AxisInterval`` built from the two converted bounds.
        """

        def as_axis_bound(bound):
            # Resolve the level first; the offset is forwarded as-is.
            position = bound[0]
            if position == 0:
                level = gt_ir.LevelMarker.START
            elif position == self.data.nk_intervals:
                level = gt_ir.LevelMarker.END
            else:
                # Intermediate positions refer to an entry of the splitters
                # variable, shifted by one relative to the position value.
                level = gt_ir.VarRef(name=self.data.splitters_var, index=position - 1)
            return gt_ir.AxisBound(level=level, offset=bound[1])

        raw_start, raw_end = interval.start, interval.end
        start_bound = as_axis_bound(raw_start)
        end_bound = as_axis_bound(raw_end)

        return gt_ir.AxisInterval(start=start_bound, end=end_bound)
예제 #2
0
    def visit_FieldRef(self, node: gt_ir.FieldRef, **kwargs) -> str:
        """Render the source text of a subscript expression for a field access.

        The result is a string of the form ``"<name>[<indices>]"`` built from:

        * one ``"start : end"`` slice per non-sequential axis of the field,
          expressed relative to ``<name><origin_marker>[<pos>]`` and shifted by
          the block extent (optionally restricted by ``intervals``) plus the
          node's offset on that axis;
        * one entry for the sequential axis, when the field has it: a slice
          between ``self.interval_k_start_name`` and
          ``self.interval_k_end_name`` when the block is handled as parallel,
          otherwise a single position ``<origin> + <axis name><offset>``;
        * the visited ``node.data_index`` items, comma separated.

        When the offset on the sequential axis is a ``gt_ir.Expr`` the access
        is rendered instead as ``"<name>[<AXIS>_<name>, ..., <k> + <expr>]"``.

        Args:
            node: The field reference to render. ``node.name`` must be one of
                ``self.block_info.accessors``.
            **kwargs: Only ``"intervals"`` is read. When truthy it is indexed
                by axis name and each entry must expose ``start``/``end``
                bounds with ``level`` and ``offset`` attributes.

        Returns:
            The generated expression. It is wrapped as ``np.asarray([...])``
            when the field has no non-sequential axes and the access is not
            rendered as parallel.
        """
        intervals = kwargs.get("intervals", None)
        assert node.name in self.block_info.accessors

        start_marker = gt_ir.LevelMarker.START
        end_marker = gt_ir.LevelMarker.END

        order_is_parallel = self.block_info.iteration_order == gt_ir.IterationOrder.PARALLEL
        field_axes = self.impl_node.fields[node.name].axes
        k_ax = self.domain.sequential_axis.name

        # Axes of this field other than the sequential one, and the position
        # of each of them inside the implementation domain.
        parallel_axes_names = [axis for axis in field_axes if axis != k_ax]
        parallel_axes_dims = [self.impl_node.domain.index(axis) for axis in parallel_axes_names]

        lower_indices = self.block_info.extent.lower_indices
        upper_indices = self.block_info.extent.upper_indices

        index = []
        for fd, d in enumerate(parallel_axes_dims):
            ax = self.domain.axes_names[d]
            ax_offset = node.offset.get(ax, 0)

            if intervals:
                restricted = intervals[ax]
                lo, hi = restricted.start, restricted.end
                # Bounds anchored at the domain START/END are clamped against
                # the block extent; bounds with any other level keep their
                # own offset untouched.
                if lo.level == start_marker:
                    start_offset = max(lower_indices[d], lo.offset)
                else:
                    start_offset = lo.offset
                if hi.level == end_marker:
                    end_offset = min(upper_indices[d], hi.offset)
                else:
                    end_offset = hi.offset
                start_bound = gt_ir.AxisBound(level=lo.level, offset=start_offset)
                end_bound = gt_ir.AxisBound(level=hi.level, offset=end_offset)
            else:
                # Without a restriction the slice spans the full block extent.
                start_bound = gt_ir.AxisBound(level=start_marker, offset=lower_indices[d])
                end_bound = gt_ir.AxisBound(level=end_marker, offset=upper_indices[d])
            axis_interval = gt_ir.AxisInterval(start=start_bound, end=end_bound)

            origin_expr = f"{node.name}{self.origin_marker}[{fd}]"
            # Only START and END levels can be rendered here; any other level
            # raises KeyError on lookup below.
            level_to_expr = {
                start_marker: origin_expr,
                end_marker: f"{origin_expr} + {self.domain_arg_name}[{fd}]",
            }

            rendered_bounds = []
            for bound in (axis_interval.start, axis_interval.end):
                shifted = bound.offset + ax_offset
                # A zero shift is omitted; otherwise it is always signed.
                suffix = f" {shifted:+d}" if shifted != 0 else ""
                rendered_bounds.append(f"{level_to_expr[bound.level]}{suffix}")

            index.append(" : ".join(rendered_bounds))

        k_offset = node.offset.get(k_ax, 0)
        variable_koffset = isinstance(k_offset, gt_ir.Expr)
        if variable_koffset:
            # An expression offset is rendered through the visitor and rules
            # out the parallel (slice-based) form.
            is_parallel = False
            k_offset = self.visit(k_offset)
        else:
            is_parallel = order_is_parallel and not self.block_info.variable_koffsets

        # With a variable offset nothing is appended for the sequential axis:
        # that case builds its subscript separately further down.
        if k_ax in field_axes and not variable_koffset:
            fd = field_axes.index(k_ax)
            k_origin = f"{node.name}{self.origin_marker}[{fd}]"
            if is_parallel:
                shift = f" {k_offset:+d}" if k_offset else ""
                start_expr = self.interval_k_start_name + shift
                end_expr = self.interval_k_end_name + shift
                index.append(f"{k_origin} + {start_expr}:{k_origin} + {end_expr}")
            else:
                shift = f"{k_offset:+d}" if k_offset else ""
                index.append(f"{k_origin} + {k_ax}{shift}")

        # NOTE(review): this always starts with ", ", so an empty
        # ``node.data_index`` leaves a trailing ", " inside the subscript --
        # confirm that is intended.
        data_idx = ", " + ",".join(self.visit(i) for i in node.data_index)

        if variable_koffset:
            # NOTE(review): ``data_idx`` is not part of this form -- confirm
            # that data dimensions are meant to be dropped here.
            gathered_axes = ", ".join(
                f"{axis_name.upper()}_{node.name}" for axis_name in parallel_axes_names
            )
            source = f"{node.name}[{gathered_axes}, {k_ax} + {k_offset}]"
        else:
            joined_index = ", ".join(index)
            source = f"{node.name}[{joined_index}{data_idx}]"

        if not (parallel_axes_dims or is_parallel):
            source = f"np.asarray([{source}])"

        return source