def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
            antialias: bool, precision):
  """Resizes ``image`` to ``shape`` with the requested interpolation method.

  Args:
    image: the array to resize.
    shape: the target shape; must have one entry per dimension of ``image``.
    method: a ``ResizeMethod`` or the string name of one.
    antialias: forwarded unchanged to ``_scale_and_translate``.
    precision: forwarded unchanged to ``_scale_and_translate``.

  Returns:
    The resized array.

  Raises:
    ValueError: if ``len(shape)`` differs from ``image.ndim``.
  """
  if len(shape) != image.ndim:
    raise ValueError(
        'shape must have length equal to the number of dimensions of x; '
        f' {shape} vs {image.shape}')

  if isinstance(method, str):
    method = ResizeMethod.from_string(method)
  # Nearest-neighbour resizing has its own implementation.
  if method == ResizeMethod.NEAREST:
    return _resize_nearest(image, shape)
  assert isinstance(method, ResizeMethod)
  kernel = _kernels[method]

  # Non-floating inputs are converted to result_type(image, float32) before
  # interpolation.
  if not jnp.issubdtype(image.dtype, jnp.inexact):
    image = lax.convert_element_type(image, jnp.result_type(image, jnp.float32))

  # Dimensions whose size does not change are skipped (scale=1,
  # translation=0). This is only valid because all of the current resize
  # kernels are interpolating, so an identity warp reproduces the input.
  in_shape = image.shape
  spatial_dims = tuple(
      dim for dim, (in_size, out_size) in enumerate(zip(in_shape, shape))
      if not core.symbolic_equal_dim(in_size, out_size))

  scale = []
  for dim in spatial_dims:
    if core.symbolic_equal_dim(shape[dim], 0):
      # Zero-sized output dimension: use a neutral scale instead of a ratio.
      scale.append(1.0)
    else:
      scale.append(core.dimension_as_value(shape[dim]) /
                   core.dimension_as_value(in_shape[dim]))
  translation = [0.] * len(spatial_dims)
  return _scale_and_translate(image, shape, spatial_dims, scale, translation,
                              kernel, antialias, precision)
def test_dim_vars_symbolic_equal(self):
  # Symbolic (in)equality of dimension variables against constants and
  # against each other, for both symbolic_equal_dim and
  # symbolic_equal_one_of_dim.
  dim_a, dim_b = shape_poly.parse_spec("a, b", (2, 3))

  self.assertTrue(core.symbolic_equal_dim(dim_a, dim_a))
  for other in (1, dim_b):
    self.assertFalse(core.symbolic_equal_dim(dim_a, other))

  # (dim, candidates, whether symbolic_equal_one_of_dim should be truthy)
  one_of_cases = (
      (dim_a, [2, dim_a], True),
      (dim_a, [2, dim_b], False),
      (dim_a, [], False),
      (2, [dim_a, 3, 2], True),
      (1, [2, dim_b], False),
      (3, [], False),
  )
  for dim, candidates, expected in one_of_cases:
    check = self.assertTrue if expected else self.assertFalse
    check(core.symbolic_equal_one_of_dim(dim, candidates))
def _pre_gather_with_batch_dims(args: GatherArgs):
  """Checks that this call to gather has exactly one, simple batch dimension.

  This is for instance triggered when doing jax.vmap(lax.dynamic_slice).

  The function does not return a value: it completes normally when the gather
  has the supported form and raises otherwise.

  Raises:
    ValueError: if the number of batch dimensions is not exactly 1, if
      `start_index_map` is not the identity mapping, or if the leading
      dimensions of the operand and of `start_indices` do not agree.
  """
  # All dimensions in the output array and not in offset_dims are batch_dims.
  batch_dims = tuple(x for x in range(len(args.out_aval.shape))
                     if x not in args.dnums.offset_dims)

  # Only a single batch dimension is supported.
  if len(batch_dims) != 1:
    raise ValueError(f"batch_dims is {len(batch_dims)} but should be 1")

  # `start_index_map` maps indices in `start_indices` to indices in `operand`.
  # For simplicity, we currently only consider the case where this mapping is
  # the identity function, i.e., [2, 3] in `start_indices` maps to
  # `operand[2, 3]`.
  if args.dnums.start_index_map != tuple(range(args.start_indices_shape[-1])):
    raise ValueError("unsupported start_index_map")

  # The batch dims in `start_indices` and `operand` should agree; both are
  # taken to be the leading dimension.
  if not core.symbolic_equal_dim(args.op_shape[0], args.start_indices_shape[0]):
    raise ValueError("Batch dimensions in operand and start_indices don't "
                     "agree")
def _broadcast_to(arr, shape):
  """Broadcasts ``arr`` to ``shape``.

  Objects that provide their own ``broadcast_to`` method are delegated to it.
  Otherwise ``arr`` is converted to an array, size-1 and missing leading
  dimensions are expanded, and the result is produced with
  ``lax.broadcast_in_dim``.

  Args:
    arr: array-like value to broadcast.
    shape: target shape; a scalar is treated as a 1-tuple.

  Returns:
    ``arr`` itself if it already has ``shape``, else the broadcast array.

  Raises:
    ValueError: if ``arr`` cannot be broadcast to ``shape``, including when
      ``shape`` has fewer dimensions than ``arr``.
  """
  if hasattr(arr, "broadcast_to"):
    return arr.broadcast_to(shape)
  _check_arraylike("broadcast_to", arr)
  arr = arr if isinstance(arr, ndarray) else _asarray(arr)
  if not isinstance(shape, tuple) and np.ndim(shape) == 0:
    shape = (shape,)
  shape = core.canonicalize_shape(shape)  # check that shape is concrete
  arr_shape = np.shape(arr)
  if core.symbolic_equal_shape(arr_shape, shape):
    return arr

  msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
  nlead = len(shape) - len(arr_shape)
  # Reject rank reduction up front: with nlead < 0, shape[nlead:] has a
  # different length than arr_shape, so the safe_zip below would hit its
  # (presumed) length check instead of raising this ValueError.
  if nlead < 0:
    raise ValueError(msg.format(arr_shape, shape))

  shape_tail = shape[nlead:]
  compatible = all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
                   for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
  if not compatible:
    raise ValueError(msg.format(arr_shape, shape))

  # Trailing dimensions of arr that differ from the target (necessarily of
  # size 1 at this point) are squeezed away and re-created by broadcasting.
  diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
                         for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
  new_dims = tuple(range(nlead)) + tuple(nlead + diff)
  kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
  return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims)
def test_dim_vars_symbolic_equal(self):
  a, b = shape_poly.parse_spec("a, b", (2, 3))
  self.assertTrue(core.symbolic_equal_dim(a, a))
  self.assertFalse(core.symbolic_equal_dim(a, 1))
  self.assertFalse(core.symbolic_equal_dim(a, b))

  self.assertTrue(core.symbolic_equal_one_of_dim(a, [2, a]))
  self.assertFalse(core.symbolic_equal_one_of_dim(a, [2, b]))
  self.assertFalse(core.symbolic_equal_one_of_dim(a, []))

  self.assertTrue(core.symbolic_equal_one_of_dim(2, [a, 3, 2]))
  self.assertFalse(core.symbolic_equal_one_of_dim(1, [2, b]))
  self.assertFalse(core.symbolic_equal_one_of_dim(3, []))

  # jnp.add(0, 1) yields a concrete JAX array (a DeviceArray) rather than a
  # Python int; it must still compare equal to the constant dimension 1.
  self.assertTrue(core.symbolic_equal_dim(1, jnp.add(0, 1)))

  # A string is not a valid dimension. The call itself raises, so it is made
  # directly; wrapping it in an assertion would never be evaluated.
  with self.assertRaisesRegex(TypeError,
                              re.escape("Shapes must be 1D sequences of concrete values of integer type, got (1, 'a').")):
    core.symbolic_equal_dim(1, "a")
def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
             returned=False):
  """Computes the (optionally weighted) average of ``a`` along ``axis``.

  Args:
    a: array-like input.
    axis: ``None`` to average over all elements, a single axis, or a tuple of
      axes.
    weights: optional array-like weights. If its shape differs from that of
      ``a``, it must be 1-D and ``axis`` must be a single axis whose length
      matches ``len(weights)``.
    returned: if True, also return the sum of the weights, broadcast to the
      shape of the average.

  Returns:
    The average, or ``(average, sum_of_weights)`` when ``returned`` is True.

  Raises:
    ValueError: if ``weights`` has a different shape from ``a`` and cannot be
      aligned with it along a single specified axis.
  """
  a = _asarray(a)

  if weights is None:  # Treat all weights as 1
    avg = mean(a, axis=axis)
    if axis is None:
      weights_sum = lax.full((), core.dimension_as_value(np.size(a)),
                             dtype=avg.dtype)
    elif isinstance(axis, tuple):
      # Each reduced axis multiplies the number of averaged elements.
      count = 1
      for d in axis:
        count = count * core.dimension_as_value(a.shape[d])
      weights_sum = lax.full_like(avg, count, dtype=avg.dtype)
    else:
      weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis]),
                                  dtype=avg.dtype)
  else:
    weights = _asarray(weights)

    if dtypes.issubdtype(a.dtype, np.inexact):
      out_dtype = dtypes.result_type(a.dtype, weights.dtype)
    else:
      out_dtype = dtypes.result_type(a.dtype, weights.dtype, dtypes.float_)
    out_dtype = dtypes.canonicalize_dtype(out_dtype)

    a_shape = np.shape(a)
    a_ndim = len(a_shape)
    weights_shape = np.shape(weights)
    if axis is None:
      pass
    elif isinstance(axis, tuple):
      axis = tuple(_canonicalize_axis(d, a_ndim) for d in axis)
    else:
      axis = _canonicalize_axis(axis, a_ndim)

    if a_shape != weights_shape:
      # Make sure the dimensions work out
      if axis is None:
        raise ValueError("Axis must be specified when shapes of a and "
                         "weights differ.")
      if isinstance(axis, tuple):
        raise ValueError("Single axis expected when shapes of a and "
                         "weights differ.")
      if len(weights_shape) != 1:
        raise ValueError("1D weights expected when shapes of a and "
                         "weights differ.")
      if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
        raise ValueError("Length of weights not "
                         "compatible with specified axis.")

      # Reshape the 1-D weights so they line up with `axis` of `a`.
      weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
      weights = _moveaxis(weights, -1, axis)

    weights_sum = sum(weights, axis=axis, dtype=out_dtype)
    avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

  if returned:
    if avg.shape != weights_sum.shape:
      weights_sum = _broadcast_to(weights_sum, avg.shape)
    return avg, weights_sum
  return avg
def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
  """Returns the quotient of the total sizes of shapes ``s1`` and ``s2``.

  Raises:
    InconclusiveDimensionOperation: if the division is inconclusive or the
      size of ``s1`` is not an exact multiple of the size of ``s2``.
  """
  numerator = np.prod(s1)
  denominator = np.prod(s2)
  # Equal sizes divide to exactly 1; this also covers 0 / 0.
  if core.symbolic_equal_dim(numerator, denominator):
    return 1

  err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}"
  try:
    quotient, remainder = _ensure_poly(numerator).divmod(denominator)
  except InconclusiveDimensionOperation:
    raise InconclusiveDimensionOperation(err_msg)
  if remainder != 0:
    raise InconclusiveDimensionOperation(err_msg)
  return quotient  # type: ignore[return-value]
def _resize_nearest(x, output_shape: core.Shape):
  """Nearest-neighbour resize of ``x`` to ``output_shape``.

  Every dimension whose size changes is resampled independently. Output
  position ``j`` takes the input element at
  ``floor((j + 0.5) * in_size / out_size)``; the ratio is computed in
  float32 before flooring. Dimensions whose size is unchanged are not touched.
  """
  in_shape = x.shape
  assert len(in_shape) == len(output_shape)
  for dim, (in_size, out_size) in enumerate(zip(in_shape, output_shape)):
    if core.symbolic_equal_dim(in_size, out_size):
      continue
    # Source coordinate corresponding to the centre of each output element.
    src = (jnp.arange(out_size) + 0.5) * core.dimension_as_value(in_size) / core.dimension_as_value(out_size)
    # TODO(b/206898375): this computation produces the wrong result on
    # CPU and GPU when using float64. Use float32 until the bug is fixed.
    src = jnp.floor(src.astype(np.float32)).astype(np.int32)
    selector = [slice(None)] * len(in_shape)
    selector[dim] = src
    x = x[tuple(selector)]
  return x
def _pre_gather_for_multidim_indexing(args: GatherArgs):
  """Checks that this call to gather represents multi-dimensional indexing.

  E.g., jnp.take(op, [[0], [1]], axis=0).

  Note we currently only support multi-dimensional indexing if the last
  dimension is 1.

  The function does not return a value: it completes normally when the gather
  has the supported form and raises otherwise.

  Raises:
    ValueError: if the dimension numbers, the trailing dimension of
      `start_indices`, `offset_dims` or `slice_sizes` do not match the
      supported pattern described below.
  """
  # Handle only the case when tf.gather argument batch_dims=0.
  # Find axis to match the tf.gather semantics
  # Let I = len(start_indices_shape); this includes the trailing dimension of
  #   size 1 that was added to the indices.
  # Let O = len(op_shape)
  # slice_sizes == op_shape[:axis] + (1,) + op_shape[axis+1:]
  # collapsed_slice_dims == (axis,)
  # start_index_map == (axis,)
  # offset_dims == (0, 1, ..., axis - 1, axis + I - 1, ..., O + I - 3)
  op_shape = args.op_shape
  start_index_map = args.dnums.start_index_map
  collapsed_slice_dims = args.dnums.collapsed_slice_dims
  offset_dims = args.dnums.offset_dims
  if not (len(op_shape) >= 1 and len(start_index_map) == 1 and
          len(collapsed_slice_dims) == 1 and
          collapsed_slice_dims[0] == start_index_map[0] and
          len(offset_dims) == len(op_shape) - 1):
    raise ValueError("unsupported dimension numbers")
  # We added a trailing dimension of size 1
  if not core.symbolic_equal_dim(args.start_indices_shape[-1], 1):
    raise ValueError("start_indices shape[-1] should be 1")
  # Guess the axis
  axis = collapsed_slice_dims[0]
  index_dims = len(args.start_indices_shape) - 1
  expected_offset_dims = tuple(
      list(range(axis)) +
      list(range(axis + index_dims, len(op_shape) + index_dims - 1)))
  if offset_dims != expected_offset_dims:
    raise ValueError("unsupported offset_dims")
  expected_slice_sizes = op_shape[:axis] + (1,) + op_shape[axis + 1:]  # type: ignore
  if not core.symbolic_equal_shape(args.slice_sizes, expected_slice_sizes):
    raise ValueError("unsupported slice_sizes")
def _conv_general_dilated_shape_rule(
    lhs: core.ShapedArray, rhs: core.ShapedArray, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, **unused_kwargs) -> Tuple[int, ...]:
  """Computes the output shape of ``conv_general_dilated``.

  The operand shapes are validated against ``dimension_numbers``,
  ``feature_group_count`` and ``batch_group_count``; the first inconsistency
  found raises ``ValueError``.

  Returns:
    The output shape as a tuple, ordered according to the output spec of
    ``dimension_numbers``.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  dnums = dimension_numbers
  lhs_shape, rhs_shape = lhs.shape, rhs.shape

  if len(lhs_shape) != len(rhs_shape):
    msg = ("conv_general_dilated lhs and rhs must have the same number of "
           "dimensions, but got {} and {}.")
    raise ValueError(msg.format(lhs_shape, rhs_shape))

  # Feature grouping: lhs features split evenly into groups matching rhs.
  if not feature_group_count > 0:
    msg = ("conv_general_dilated feature_group_count "
           "must be a positive integer, got {}.")
    raise ValueError(msg.format(feature_group_count))
  lhs_feature_count = lhs_shape[dnums.lhs_spec[1]]
  quot, rem = divmod(lhs_feature_count, feature_group_count)
  if rem:
    msg = ("conv_general_dilated feature_group_count must divide lhs feature "
           "dimension size, but {} does not divide {}.")
    raise ValueError(msg.format(feature_group_count, lhs_feature_count))
  rhs_in_features = rhs_shape[dnums.rhs_spec[1]]
  if not core.symbolic_equal_dim(quot, rhs_in_features):
    msg = ("conv_general_dilated lhs feature dimension size divided by "
           "feature_group_count must equal the rhs input feature dimension "
           "size, but {} // {} != {}.")
    raise ValueError(msg.format(lhs_feature_count, feature_group_count,
                                rhs_in_features))
  rhs_out_features = rhs_shape[dnums.rhs_spec[0]]
  if rhs_out_features % feature_group_count:
    msg = ("conv_general_dilated rhs output feature dimension size must be a "
           "multiple of feature_group_count, but {} is not a multiple of {}.")
    raise ValueError(msg.format(rhs_out_features, feature_group_count))

  # Batch grouping.
  if not batch_group_count > 0:
    msg = ("conv_general_dilated batch_group_count "
           "must be a positive integer, got {}.")
    raise ValueError(msg.format(batch_group_count))
  lhs_batch_count = lhs_shape[dnums.lhs_spec[0]]
  if batch_group_count > 1 and lhs_batch_count % batch_group_count != 0:
    msg = ("conv_general_dilated batch_group_count must divide lhs batch "
           "dimension size, but {} does not divide {}.")
    raise ValueError(msg.format(batch_group_count, lhs_batch_count))
  if rhs_out_features % batch_group_count:
    msg = ("conv_general_dilated rhs output feature dimension size must be a "
           "multiple of batch_group_count, but {} is not a multiple of {}.")
    raise ValueError(msg.format(rhs_out_features, batch_group_count))

  if batch_group_count > 1 and feature_group_count > 1:
    msg = ("At most one of batch_group_count and feature_group_count may be > "
           "1, got batch_group_count={} and feature_group_count={}")
    raise ValueError(msg.format(batch_group_count, feature_group_count))

  num_spatial_dims = len(_conv_sdims(dnums.rhs_spec))
  if num_spatial_dims != len(window_strides):
    msg = ("conv_general_dilated window and window_strides must have "
           "the same number of dimensions, but got {} and {}")
    raise ValueError(msg.format(num_spatial_dims, len(window_strides)))

  # Permute to the canonical layout, apply dilation, compute the output shape,
  # then permute back to the requested output layout.
  lhs_perm, rhs_perm, out_perm = dnums
  lhs_trans = lax._dilate_shape(np.take(lhs_shape, lhs_perm), lhs_dilation)
  rhs_trans = lax._dilate_shape(np.take(rhs_shape, rhs_perm), rhs_dilation)
  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
                               batch_group_count)
  return tuple(np.take(out_trans, np.argsort(out_perm)))  # type: ignore[arg-type]