Code Example #1
File: scale.py  Project: matthewfeickert/jax
def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
            antialias: bool, precision):
    if len(shape) != image.ndim:
        msg = (
            'shape must have length equal to the number of dimensions of x; '
            f' {shape} vs {image.shape}')
        raise ValueError(msg)
    if isinstance(method, str):
        method = ResizeMethod.from_string(method)
    if method == ResizeMethod.NEAREST:
        return _resize_nearest(image, shape)
    assert isinstance(method, ResizeMethod)
    kernel = _kernels[method]

    if not jnp.issubdtype(image.dtype, jnp.inexact):
        image = lax.convert_element_type(image,
                                         jnp.result_type(image, jnp.float32))
    # Skip dimensions that have scale=1 and translation=0. This is only possible
    # because all of the current resize methods (kernels) are interpolating, so
    # output = input under an identity warp.
    spatial_dims = tuple(
        i for i in range(len(shape))
        if not core.symbolic_equal_dim(image.shape[i], shape[i]))
    scale = [
        1.0 if core.symbolic_equal_dim(shape[d], 0)
        else (core.dimension_as_value(shape[d]) /
              core.dimension_as_value(image.shape[d]))
        for d in spatial_dims
    ]
    return _scale_and_translate(image, shape, spatial_dims, scale,
                                [0.] * len(spatial_dims), kernel, antialias,
                                precision)
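
A quick orientation, not part of the original listing: _resize appears to be the helper behind the public jax.image.resize API. The sketch below is a minimal usage example and assumes a JAX version that exposes the shape, method, and antialias parameters shown.

# Minimal usage sketch for jax.image.resize (assumed to match the helper above).
import jax.numpy as jnp
from jax import image

x = jnp.arange(16.0).reshape(4, 4)   # a small 4x4 "image"
# Upsample to 8x8 with a linear (triangle) kernel; antialias only matters when
# downsampling.
y = image.resize(x, shape=(8, 8), method="linear", antialias=True)
print(y.shape)  # (8, 8)
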
Code Example #2
File: shape_poly_test.py  Project: memari-majid/jax
    def test_dim_vars_symbolic_equal(self):
        da, db = shape_poly.parse_spec("a, b", (2, 3))
        self.assertTrue(core.symbolic_equal_dim(da, da))
        self.assertFalse(core.symbolic_equal_dim(da, 1))
        self.assertFalse(core.symbolic_equal_dim(da, db))

        self.assertTrue(core.symbolic_equal_one_of_dim(da, [2, da]))
        self.assertFalse(core.symbolic_equal_one_of_dim(da, [2, db]))
        self.assertFalse(core.symbolic_equal_one_of_dim(da, []))

        self.assertTrue(core.symbolic_equal_one_of_dim(2, [da, 3, 2]))
        self.assertFalse(core.symbolic_equal_one_of_dim(1, [2, db]))
        self.assertFalse(core.symbolic_equal_one_of_dim(3, []))
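
As an aside (not from the original test file): symbolic_equal_one_of_dim behaves like a membership test layered on the pairwise check. The sketch below illustrates that relationship with a hypothetical helper, _equal_one_of, rather than JAX's own implementation.

# Hedged sketch: a "one of" check built from a pairwise dimension comparison.
# `_equal_one_of` is a hypothetical stand-in, not JAX's own code.
from typing import Any, Callable, Sequence

def _equal_one_of(dim_equal: Callable[[Any, Any], bool],
                  d: Any, candidates: Sequence[Any]) -> bool:
    # True if `d` equals at least one candidate; an empty candidate list is
    # vacuously False, matching the assertions in the test above.
    return any(dim_equal(d, c) for c in candidates)

# With plain integers, ordinary equality reproduces the concrete cases:
assert _equal_one_of(lambda x, y: x == y, 2, [3, 2])
assert not _equal_one_of(lambda x, y: x == y, 3, [])
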
Code Example #3
def _pre_gather_with_batch_dims(args: GatherArgs):
    """Returns True if this call to gather has non-empty batch dimensions.

  This is for instance triggered when doing jax.vmap(lax.dynamic_slice).
  """
    # All dimensions in the output array and not in offset_dims are batch_dims.
    batch_dims = tuple(x for x in range(len(args.out_aval.shape))
                       if x not in args.dnums.offset_dims)

    # We assume exactly one batch (and one or more non-batch dimensions).
    if len(batch_dims) != 1:
        raise ValueError(f"batch_dims is {len(batch_dims)} but should be 1")

    # `start_index_map` maps indices in `start_indices` to indices in `operand`.
    # For simplicity, we currently only consider the case where this mapping is
    # the identity function, i.e., [2, 3] in `start_indices` maps to
    # `operand[2, 3]`.
    if args.dnums.start_index_map != tuple(range(
            args.start_indices_shape[-1])):
        raise ValueError("unsupported start_index_map")

    # The batch dims in `start_indices` and `operand` should agree.
    if not core.symbolic_equal_dim(args.op_shape[0],
                                   args.start_indices_shape[0]):
        raise ValueError("Batch dimensions in operand and start_indices don't "
                         "agree")
Code Example #4
def _broadcast_to(arr, shape):
    if hasattr(arr, "broadcast_to"):
        return arr.broadcast_to(shape)
    _check_arraylike("broadcast_to", arr)
    arr = arr if isinstance(arr, ndarray) else _asarray(arr)
    if not isinstance(shape, tuple) and np.ndim(shape) == 0:
        shape = (shape, )
    shape = core.canonicalize_shape(shape)  # check that shape is concrete
    arr_shape = np.shape(arr)
    if core.symbolic_equal_shape(arr_shape, shape):
        return arr
    else:
        nlead = len(shape) - len(arr_shape)
        shape_tail = shape[nlead:]
        compatible = all(
            core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
            for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
        if nlead < 0 or not compatible:
            msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
            raise ValueError(msg.format(arr_shape, shape))
        diff, = np.where(
            tuple(not core.symbolic_equal_dim(arr_d, shape_d)
                  for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
        new_dims = tuple(range(nlead)) + tuple(nlead + diff)
        kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
        return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape,
                                    kept_dims)
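
For context (not part of the original file): _broadcast_to backs the public jnp.broadcast_to. The short sketch below exercises the two moves the code handles, adding leading axes and stretching size-1 axes.

# Hedged usage sketch of the behavior implemented above, via the public API.
import jax.numpy as jnp

a = jnp.ones((3, 1))
b = jnp.broadcast_to(a, (2, 3, 4))   # add a leading axis, stretch the size-1 axis
print(b.shape)  # (2, 3, 4)

# Incompatible trailing dimensions raise a ValueError, mirroring the check above:
# jnp.broadcast_to(jnp.ones((3, 2)), (3, 4))   # would fail
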
Code Example #5
  def test_dim_vars_symbolic_equal(self):
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    self.assertTrue(core.symbolic_equal_dim(a, a))
    self.assertFalse(core.symbolic_equal_dim(a, 1))
    self.assertFalse(core.symbolic_equal_dim(a, b))

    self.assertTrue(core.symbolic_equal_one_of_dim(a, [2, a]))
    self.assertFalse(core.symbolic_equal_one_of_dim(a, [2, b]))
    self.assertFalse(core.symbolic_equal_one_of_dim(a, []))

    self.assertTrue(core.symbolic_equal_one_of_dim(2, [a, 3, 2]))
    self.assertFalse(core.symbolic_equal_one_of_dim(1, [2, b]))
    self.assertFalse(core.symbolic_equal_one_of_dim(3, []))

    self.assertTrue(core.symbolic_equal_dim(1, jnp.add(0, 1)))  # A DeviceArray
    with self.assertRaisesRegex(
        TypeError,
        re.escape("Shapes must be 1D sequences of concrete values of integer "
                  "type, got (1, 'a').")):
      self.assertTrue(core.symbolic_equal_dim(1, "a"))
Code Example #6
File: reductions.py  Project: cloudhan/jax
def _average(a,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None,
             returned=False):
    a = _asarray(a)

    if weights is None:  # Treat all weights as 1
        avg = mean(a, axis=axis)
        if axis is None:
            weights_sum = lax.full((),
                                   core.dimension_as_value(np.size(a)),
                                   dtype=avg.dtype)
        else:
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(a.shape[axis]),
                                        dtype=avg.dtype)
    else:
        weights = _asarray(weights)

        if dtypes.issubdtype(a.dtype, np.inexact):
            out_dtype = dtypes.result_type(a.dtype, weights.dtype)
        else:
            out_dtype = dtypes.result_type(a.dtype, weights.dtype,
                                           dtypes.float_)
        out_dtype = dtypes.canonicalize_dtype(out_dtype)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            weights = _broadcast_to(weights,
                                    (a_ndim - 1) * (1, ) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, dtype=out_dtype)
        avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg
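
For context (not part of the original file): _average is the implementation behind the public jnp.average. A short, hedged usage sketch of the weighted path and the returned=True pair follows.

# Hedged usage sketch for jnp.average, which wraps the helper above.
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
w = jnp.array([1.0, 2.0, 3.0])   # 1D weights along axis=1

avg, wsum = jnp.average(a, axis=1, weights=w, returned=True)
print(avg)   # per-row weighted averages
print(wsum)  # sum of weights, broadcast to avg's shape
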
Code Example #7
File: shape_poly.py  Project: zhangqiaorjc/jax
  def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
    sz1 = np.prod(s1)
    sz2 = np.prod(s2)
    if core.symbolic_equal_dim(sz1, sz2):  # Takes care also of sz1 == sz2 == 0
      return 1
    err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}"
    try:
      q, r = _ensure_poly(sz1).divmod(sz2)
    except InconclusiveDimensionOperation:
      raise InconclusiveDimensionOperation(err_msg)
    if r != 0:
      raise InconclusiveDimensionOperation(err_msg)
    return q  # type: ignore[return-value]
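
For fully concrete shapes the method above reduces to integer arithmetic on the shape products. The sketch below is a hypothetical, NumPy-only restatement of that concrete path (it does not cover the symbolic branch).

# Hedged sketch of the concrete (non-symbolic) case: the quotient of the two
# element counts, with an error when the division is not exact.
# `divide_concrete_shape_sizes` is a hypothetical helper, not JAX code.
import numpy as np

def divide_concrete_shape_sizes(s1, s2):
    sz1, sz2 = int(np.prod(s1)), int(np.prod(s2))
    if sz1 == sz2:   # also covers sz1 == sz2 == 0
        return 1
    q, r = divmod(sz1, sz2)
    if r != 0:
        raise ValueError(
            f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
    return q

print(divide_concrete_shape_sizes((2, 3, 4), (4, 3)))   # 2
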
Code Example #8
def _resize_nearest(x, output_shape: core.Shape):
  input_shape = x.shape
  assert len(input_shape) == len(output_shape)
  spatial_dims = tuple(i for i in range(len(input_shape))
                       if not core.symbolic_equal_dim(input_shape[i], output_shape[i]))
  for d in spatial_dims:
    m = input_shape[d]
    n = output_shape[d]
    offsets = (jnp.arange(n) + 0.5) * core.dimension_as_value(m) / core.dimension_as_value(n)
    # TODO(b/206898375): this computation produces the wrong result on
    # CPU and GPU when using float64. Use float32 until the bug is fixed.
    offsets = jnp.floor(offsets.astype(np.float32)).astype(np.int32)
    indices = [slice(None)] * len(input_shape)
    indices[d] = offsets
    x = x[tuple(indices)]
  return x
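
The core of the nearest-neighbor path is the per-axis offset computation. The sketch below restates that mapping with plain NumPy for one axis; the extents are illustrative only.

# Hedged sketch: the per-axis source-index mapping used by _resize_nearest.
import numpy as np

m, n = 4, 6   # input and output extent along one axis
offsets = np.floor((np.arange(n) + 0.5) * m / n).astype(np.int32)
print(offsets)   # [0 1 1 2 3 3]: each output pixel picks its nearest input pixel
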
Code Example #9
File: impl_no_xla.py  Project: matthewfeickert/jax
def _pre_gather_for_multidim_indexing(args: GatherArgs):
    """Returns True if this call to gather represents multi-dimensional indexing.

  E.g., jnp.take(op, [[0], [1]], axis=0).
  Note we currently only support multi-dimensional indexing if the last
  dimension is 1.
  """
    # Handle only the case when tf.gather argument batch_dims=0.
    # Find axis to match the tf.gather semantics
    # Let I = len(start_indices_shape)
    # let O = len(op_shape)
    # slice_sizes == op_shape[:axis] + (1,) + op_shape[axis+1:]
    # collapsed_slice_dims == (axis,)
    # start_index_map == (axis,)
    # offset_dims == (0, 1, ..., axis - 1, axis + I, ..., O + I - 1)
    # We added a trailing dimension of size 1
    op_shape = args.op_shape
    start_index_map = args.dnums.start_index_map
    collapsed_slice_dims = args.dnums.collapsed_slice_dims
    offset_dims = args.dnums.offset_dims
    if not (len(op_shape) >= 1 and len(start_index_map) == 1
            and len(collapsed_slice_dims) == 1 and collapsed_slice_dims[0]
            == start_index_map[0] and len(offset_dims) == len(op_shape) - 1):
        raise ValueError("unsupported dimension numbers")
    # We added a trailing dimension of size 1
    if not core.symbolic_equal_dim(args.start_indices_shape[-1], 1):
        raise ValueError("start_indices shape[-1] should be 1")
    # Guess the axis
    axis = collapsed_slice_dims[0]
    index_dims = len(args.start_indices_shape) - 1
    expected_offset_dims = tuple(
        list(range(axis)) +
        list(range(axis + index_dims,
                   len(op_shape) + index_dims - 1)))
    if offset_dims != expected_offset_dims:
        raise ValueError("unsupported offset_dims")
    expected_slice_sizes = op_shape[:axis] + (
        1, ) + op_shape[axis + 1:]  # type: ignore
    if not core.symbolic_equal_shape(args.slice_sizes, expected_slice_sizes):
        raise ValueError("unsupported slice_sizes")
Code Example #10
File: convolution.py  Project: frederikwilde/jax
def _conv_general_dilated_shape_rule(lhs: core.ShapedArray,
                                     rhs: core.ShapedArray, *, window_strides,
                                     padding, lhs_dilation, rhs_dilation,
                                     dimension_numbers, feature_group_count,
                                     batch_group_count,
                                     **unused_kwargs) -> Tuple[int, ...]:
    assert type(dimension_numbers) is ConvDimensionNumbers
    if len(lhs.shape) != len(rhs.shape):
        msg = ("conv_general_dilated lhs and rhs must have the same number of "
               "dimensions, but got {} and {}.")
        raise ValueError(msg.format(lhs.shape, rhs.shape))
    if not feature_group_count > 0:
        msg = ("conv_general_dilated feature_group_count "
               "must be a positive integer, got {}.")
        raise ValueError(msg.format(feature_group_count))
    lhs_feature_count = lhs.shape[dimension_numbers.lhs_spec[1]]
    quot, rem = divmod(lhs_feature_count, feature_group_count)
    if rem:
        msg = (
            "conv_general_dilated feature_group_count must divide lhs feature "
            "dimension size, but {} does not divide {}.")
        raise ValueError(msg.format(feature_group_count, lhs_feature_count))
    if not core.symbolic_equal_dim(quot,
                                   rhs.shape[dimension_numbers.rhs_spec[1]]):
        msg = (
            "conv_general_dilated lhs feature dimension size divided by "
            "feature_group_count must equal the rhs input feature dimension "
            "size, but {} // {} != {}.")
        raise ValueError(
            msg.format(lhs_feature_count, feature_group_count,
                       rhs.shape[dimension_numbers.rhs_spec[1]]))
    if rhs.shape[dimension_numbers.rhs_spec[0]] % feature_group_count:
        msg = (
            "conv_general_dilated rhs output feature dimension size must be a "
            "multiple of feature_group_count, but {} is not a multiple of {}.")
        raise ValueError(
            msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
                       feature_group_count))

    if not batch_group_count > 0:
        msg = ("conv_general_dilated batch_group_count "
               "must be a positive integer, got {}.")
        raise ValueError(msg.format(batch_group_count))
    lhs_batch_count = lhs.shape[dimension_numbers.lhs_spec[0]]
    if batch_group_count > 1 and lhs_batch_count % batch_group_count != 0:
        msg = ("conv_general_dilated batch_group_count must divide lhs batch "
               "dimension size, but {} does not divide {}.")
        raise ValueError(msg.format(batch_group_count, lhs_batch_count))

    if rhs.shape[dimension_numbers.rhs_spec[0]] % batch_group_count:
        msg = (
            "conv_general_dilated rhs output feature dimension size must be a "
            "multiple of batch_group_count, but {} is not a multiple of {}.")
        raise ValueError(
            msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
                       batch_group_count))

    if batch_group_count > 1 and feature_group_count > 1:
        msg = (
            "At most one of batch_group_count and feature_group_count may be > "
            "1, got batch_group_count={} and feature_group_count={}")
        raise ValueError(msg.format(batch_group_count, feature_group_count))

    if len(_conv_sdims(dimension_numbers.rhs_spec)) != len(window_strides):
        msg = ("conv_general_dilated window and window_strides must have "
               "the same number of dimensions, but got {} and {}")
        raise ValueError(
            msg.format(len(_conv_sdims(dimension_numbers.rhs_spec)),
                       len(window_strides)))

    lhs_perm, rhs_perm, out_perm = dimension_numbers
    lhs_trans = lax._dilate_shape(np.take(lhs.shape, lhs_perm), lhs_dilation)
    rhs_trans = lax._dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation)
    out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
                                 batch_group_count)
    return tuple(np.take(out_trans,
                         np.argsort(out_perm)))  # type: ignore[arg-type]
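
The feature_group_count checks above correspond to grouped convolutions. The hedged sketch below builds one call that satisfies those constraints, relying on lax.conv_general_dilated's default NCHW/OIHW layout; the shapes are illustrative only.

# Hedged sketch: a grouped convolution that passes the shape-rule checks above.
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 8, 16, 16))   # NCHW: 8 input features
rhs = jnp.ones((12, 4, 3, 3))    # OIHW: 12 output features, 8 // 2 = 4 inputs each
out = lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding="SAME",
    feature_group_count=2,       # 8 % 2 == 0 and 12 % 2 == 0, as required above
)
print(out.shape)   # (1, 12, 16, 16)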