def _make_axis_interval(self, interval: IntervalInfo):
    """Build a ``gt_ir.AxisInterval`` from the two endpoints of *interval*.

    ``interval.start`` and ``interval.end`` are each indexed as
    ``(level_index, offset)``. The level index is mapped as follows:

    * ``0`` becomes ``gt_ir.LevelMarker.START``;
    * a value equal to ``self.data.nk_intervals`` becomes
      ``gt_ir.LevelMarker.END``;
    * any other value becomes a ``gt_ir.VarRef`` into
      ``self.data.splitters_var`` at position ``level_index - 1``.

    The offset is carried over unchanged.
    """
    start_endpoint, end_endpoint = interval.start, interval.end

    def endpoint_to_bound(endpoint):
        # Translate one (level_index, offset) endpoint into an AxisBound.
        level_index = endpoint[0]
        if level_index == 0:
            level = gt_ir.LevelMarker.START
        elif level_index == self.data.nk_intervals:
            level = gt_ir.LevelMarker.END
        else:
            # Interior levels are referenced through the splitters variable;
            # its index is shifted by one because level 0 is START.
            level = gt_ir.VarRef(name=self.data.splitters_var, index=level_index - 1)
        return gt_ir.AxisBound(level=level, offset=endpoint[1])

    start_bound = endpoint_to_bound(start_endpoint)
    end_bound = endpoint_to_bound(end_endpoint)
    return gt_ir.AxisInterval(start=start_bound, end=end_bound)
def visit_FieldRef(self, node: gt_ir.FieldRef, **kwargs) -> str:
    """Render a field access as a NumPy subscript expression (source text).

    For every non-sequential axis of the field a slice is emitted. By default
    its bounds are ``self.block_info.extent.lower_indices[d]`` (START-relative)
    and ``upper_indices[d]`` (END-relative). When the ``intervals`` keyword
    argument is given, the per-axis ``AxisInterval`` in it supplies the bounds
    instead. A START-relative start is clamped from below by the extent's
    lower index, and an END-relative end is clamped from above by the
    extent's upper index.

    START-relative bounds are rendered relative to
    ``<field><origin_marker>[fd]``. END-relative bounds additionally add
    ``<domain_arg_name>[fd]``. The access offset along the axis is added to
    both bounds.

    If the field has the sequential axis, it is rendered in one of three ways:

    * a slice between ``interval_k_start_name`` and ``interval_k_end_name``
      (shifted by the k offset) when iteration is PARALLEL and
      ``block_info.variable_koffsets`` is falsy;
    * a single index ``<origin> + <k_axis><offset>`` otherwise, for a constant
      k offset;
    * for a k offset that is a ``gt_ir.Expr``, the whole access is rendered
      as ``field[<AXIS>_<field>, ..., <k_axis> + <expr>]``. NOTE(review): the
      ``<AXIS>_<field>`` names are assumed to be defined by the surrounding
      generated code.

    Data-dimension indices (``node.data_index``) are appended after the
    spatial subscripts in every case.

    Args:
        node: The field reference to render.
        **kwargs: Optional ``intervals`` mapping parallel-axis names to
            ``gt_ir.AxisInterval`` restrictions.

    Returns:
        The subscript expression as a string. It is wrapped in
        ``np.asarray([...])`` when the field has no parallel axes and the
        access is not parallel.
    """
    intervals = kwargs.get("intervals", None)
    assert node.name in self.block_info.accessors

    field_axes = self.impl_node.fields[node.name].axes
    k_ax = self.domain.sequential_axis.name
    parallel_axes_names = [axis for axis in field_axes if axis != k_ax]
    parallel_axes_dims = [self.impl_node.domain.index(axis) for axis in parallel_axes_names]
    lower_indices = self.block_info.extent.lower_indices
    upper_indices = self.block_info.extent.upper_indices

    # --- Parallel axes: one "start : stop" slice per axis -------------------
    index = []
    for fd, d in enumerate(parallel_axes_dims):
        ax = self.domain.axes_names[d]
        ax_offset = node.offset.get(ax, 0)
        if intervals:
            restricted_interval = intervals[ax]
            start_offset = restricted_interval.start.offset
            if restricted_interval.start.level == gt_ir.LevelMarker.START:
                # Clamp so the slice starts no lower than the extent's lower index.
                start_offset = max(lower_indices[d], start_offset)
            end_offset = restricted_interval.end.offset
            if restricted_interval.end.level == gt_ir.LevelMarker.END:
                # Clamp so the slice ends no higher than the extent's upper index.
                end_offset = min(upper_indices[d], end_offset)
            axis_interval = gt_ir.AxisInterval(
                start=gt_ir.AxisBound(level=restricted_interval.start.level, offset=start_offset),
                end=gt_ir.AxisBound(level=restricted_interval.end.level, offset=end_offset),
            )
        else:
            axis_interval = gt_ir.AxisInterval(
                start=gt_ir.AxisBound(level=gt_ir.LevelMarker.START, offset=lower_indices[d]),
                end=gt_ir.AxisBound(level=gt_ir.LevelMarker.END, offset=upper_indices[d]),
            )

        origin_expr = f"{node.name}{self.origin_marker}[{fd}]"
        level_to_expr = {
            gt_ir.LevelMarker.START: origin_expr,
            gt_ir.LevelMarker.END: f"{origin_expr} + {self.domain_arg_name}[{fd}]",
        }
        bound_exprs = []
        for bound in (axis_interval.start, axis_interval.end):
            total_offset = bound.offset + ax_offset
            total_offset_expr = " {:+d}".format(total_offset) if total_offset != 0 else ""
            bound_exprs.append(f"{level_to_expr[bound.level]}{total_offset_expr}")
        index.append(f"{bound_exprs[0]} : {bound_exprs[1]}")

    # --- Sequential axis -----------------------------------------------------
    k_offset = node.offset.get(k_ax, 0)
    if isinstance(k_offset, gt_ir.Expr):
        # The k offset is an expression: render it as source text and use
        # point indexing instead of a vertical slice.
        variable_koffset = True
        is_parallel = False
        k_offset = self.visit(k_offset)
    else:
        variable_koffset = False
        is_parallel = (
            self.block_info.iteration_order == gt_ir.IterationOrder.PARALLEL
            and not self.block_info.variable_koffsets
        )

    if k_ax in field_axes:
        fd = field_axes.index(k_ax)
        if is_parallel:
            k_shift = " {:+d}".format(k_offset) if k_offset else ""
            start_expr = self.interval_k_start_name + k_shift
            end_expr = self.interval_k_end_name + k_shift
            index.append(
                "{name}{marker}[{fd}] + {start}:{name}{marker}[{fd}] + {stop}".format(
                    name=node.name,
                    start=start_expr,
                    marker=self.origin_marker,
                    stop=end_expr,
                    fd=fd,
                )
            )
        elif not variable_koffset:
            idx = "{:+d}".format(k_offset) if k_offset else ""
            index.append(
                "{name}{marker}[{fd}] + {ax}{idx}".format(
                    name=node.name,
                    marker=self.origin_marker,
                    fd=fd,
                    ax=k_ax,
                    idx=idx,
                )
            )

    # --- Assemble the subscript ----------------------------------------------
    data_index = [self.visit(i) for i in node.data_index]
    if variable_koffset:
        subscripts = [f"{axis_name.upper()}_{node.name}" for axis_name in parallel_axes_names]
        subscripts.append(f"{k_ax} + {k_offset}")
    else:
        subscripts = list(index)
    # Data-dimension indices always follow the spatial subscripts. Joining
    # everything in one list avoids emitting a dangling ", " when there is no
    # data index or no spatial subscript.
    subscripts.extend(data_index)
    source = f"{node.name}[{', '.join(subscripts)}]"

    if not parallel_axes_dims and not is_parallel:
        source = f"np.asarray([{source}])"
    return source