Example #1
def _mean(a,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None,
          out=None,
          keepdims=False,
          *,
          where=None):
    _check_arraylike("mean", a)
    lax_internal._check_user_dtype_supported(dtype, "mean")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.mean is not supported.")

    if where is None:
        if axis is None:
            normalizer = core.dimension_as_value(np.size(a))
        else:
            normalizer = core.dimension_as_value(_axis_size(a, axis))
    else:
        normalizer = sum(_broadcast_to(where, np.shape(a)),
                         axis,
                         dtype=dtype,
                         keepdims=keepdims)

    if dtype is None:
        dtype = dtypes._to_inexact_dtype(dtypes.dtype(a))
    dtype = dtypes.canonicalize_dtype(dtype)

    return lax.div(sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
                   lax.convert_element_type(normalizer, dtype))
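A minimal usage sketch of the public jnp.mean wrapper built on _mean: the `where` mask switches the normalizer from the full element count to the count of selected elements. Array values below are illustrative only.

import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
mask = jnp.array([[True, False], [True, True]])

print(jnp.mean(a))              # 2.5        -> normalizer is size(a) = 4
print(jnp.mean(a, axis=0))      # [2. 3.]    -> normalizer is the axis size, 2
print(jnp.mean(a, where=mask))  # 2.6666667  -> (1 + 3 + 4) / 3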
Example #2
def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
            antialias: bool, precision):
    if len(shape) != image.ndim:
        msg = ('shape must have length equal to the number of dimensions of '
               f'image; {shape} vs {image.shape}')
        raise ValueError(msg)
    if isinstance(method, str):
        method = ResizeMethod.from_string(method)
    if method == ResizeMethod.NEAREST:
        return _resize_nearest(image, shape)
    assert isinstance(method, ResizeMethod)
    kernel = _kernels[method]

    if not jnp.issubdtype(image.dtype, jnp.inexact):
        image = lax.convert_element_type(image,
                                         jnp.result_type(image, jnp.float32))
    # Skip dimensions that have scale=1 and translation=0; this is only valid
    # because all of the current resize methods (kernels) are interpolating,
    # so output = input under an identity warp.
    spatial_dims = tuple(
        i for i in range(len(shape))
        if not core.symbolic_equal_dim(image.shape[i], shape[i]))
    scale = [
        1.0 if core.symbolic_equal_dim(shape[d], 0)
        else (core.dimension_as_value(shape[d]) /
              core.dimension_as_value(image.shape[d]))
        for d in spatial_dims
    ]
    return _scale_and_translate(image, shape, spatial_dims, scale,
                                [0.] * len(spatial_dims), kernel, antialias,
                                precision)
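A usage sketch of the public jax.image.resize entry point that dispatches to _resize: axes whose size is unchanged are skipped, and method="nearest" takes the dedicated _resize_nearest path shown in Example #5. Shapes here are illustrative only.

import jax.numpy as jnp
from jax import image

x = jnp.arange(16.0).reshape(4, 4)
up = image.resize(x, shape=(8, 8), method="linear", antialias=True)
rows_kept = image.resize(x, shape=(4, 8), method="linear")  # axis 0 unchanged
print(up.shape, rows_kept.shape)  # (8, 8) (4, 8)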
Example #3
def _average(a,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None,
             returned=False):
    a = _asarray(a)

    if weights is None:  # Treat all weights as 1
        avg = mean(a, axis=axis)
        if axis is None:
            weights_sum = lax.full((),
                                   core.dimension_as_value(np.size(a)),
                                   dtype=avg.dtype)
        else:
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(a.shape[axis]),
                                        dtype=avg.dtype)
    else:
        weights = _asarray(weights)

        if dtypes.issubdtype(a.dtype, np.inexact):
            out_dtype = dtypes.result_type(a.dtype, weights.dtype)
        else:
            out_dtype = dtypes.result_type(a.dtype, weights.dtype,
                                           dtypes.float_)
        out_dtype = dtypes.canonicalize_dtype(out_dtype)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            weights = _broadcast_to(weights,
                                    (a_ndim - 1) * (1,) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, dtype=out_dtype)
        avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg
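A usage sketch of the public jnp.average wrapper: 1-D weights are broadcast along the reduction axis, and returned=True also yields the sum of weights used as the normalizer. Values below are illustrative only.

import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
w = jnp.array([0.25, 0.75])

avg, wsum = jnp.average(a, axis=1, weights=w, returned=True)
print(avg)   # [1.75 3.75] -> 0.25*1 + 0.75*2 and 0.25*3 + 0.75*4
print(wsum)  # [1. 1.]     -> sum of weights, broadcast to avg.shape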
Example #4
def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize,
                       scale, translation, kernel: Callable, antialias: bool):
    inv_scale = 1. / scale
    # When downsampling, the kernel should be scaled, since we want to low-pass
    # filter and interpolate; when upsampling it should not be, since we only
    # want to interpolate.
    kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.

    sample_f = ((jnp.arange(output_size) + 0.5) * inv_scale -
                translation * inv_scale - 0.5)
    x = (
        jnp.abs(sample_f[jnp.newaxis, :] -
                jnp.arange(input_size, dtype=sample_f.dtype)[:, jnp.newaxis]) /
        kernel_scale)
    weights = kernel(x)

    total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
    weights = jnp.where(
        jnp.abs(total_weight_sum) > 1000. * np.finfo(np.float32).eps,
        jnp.divide(weights,
                   jnp.where(total_weight_sum != 0, total_weight_sum, 1)), 0)
    # Zero out weights where the sample location is completely outside the input
    # range.
    # Note sample_f has already had the 0.5 removed, hence the weird range below.
    input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
    return jnp.where(
        jnp.logical_and(sample_f >= -0.5,
                        sample_f <= input_size_minus_0_5)[jnp.newaxis, :],
        weights, 0)
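A self-contained sketch of the weight-matrix construction above, substituting a hand-written triangle (tent) kernel for the internal _kernels table; the kernel, the sizes, and the simplified normalization are illustrative assumptions, not the library's exact code path.

import jax.numpy as jnp

def triangle_kernel(x):
    # Linear-interpolation (tent) kernel with support [-1, 1].
    return jnp.maximum(0.0, 1.0 - jnp.abs(x))

input_size, output_size, scale, translation = 4, 8, 2.0, 0.0
inv_scale = 1.0 / scale
kernel_scale = jnp.maximum(inv_scale, 1.0)  # antialias=True branch

sample_f = ((jnp.arange(output_size) + 0.5) * inv_scale
            - translation * inv_scale - 0.5)
x = jnp.abs(sample_f[jnp.newaxis, :] -
            jnp.arange(input_size, dtype=sample_f.dtype)[:, jnp.newaxis]) / kernel_scale
weights = triangle_kernel(x)
weights = weights / jnp.sum(weights, axis=0, keepdims=True)

print(weights.shape)             # (4, 8): one column of weights per output sample
print(jnp.sum(weights, axis=0))  # each column sums to 1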
Example #5
def _resize_nearest(x, output_shape: core.Shape):
  input_shape = x.shape
  assert len(input_shape) == len(output_shape)
  spatial_dims = tuple(i for i in range(len(input_shape))
                       if not core.symbolic_equal_dim(input_shape[i], output_shape[i]))
  for d in spatial_dims:
    m = input_shape[d]
    n = output_shape[d]
    offsets = (jnp.arange(n) + 0.5) * core.dimension_as_value(m) / core.dimension_as_value(n)
    # TODO(b/206898375): this computation produces the wrong result on
    # CPU and GPU when using float64. Use float32 until the bug is fixed.
    offsets = jnp.floor(offsets.astype(np.float32)).astype(np.int32)
    indices = [slice(None)] * len(input_shape)
    indices[d] = offsets
    x = x[tuple(indices)]
  return x
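A usage sketch of the nearest-neighbor path through the public API: each output index i along a resized axis reads input index floor((i + 0.5) * m / n), one axis at a time.

import jax.numpy as jnp
from jax import image

x = jnp.array([[0.0, 1.0],
               [2.0, 3.0]])
print(image.resize(x, shape=(4, 4), method="nearest"))
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]
#  [2. 2. 3. 3.]
#  [2. 2. 3. 3.]]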
Example #6
def _var(a,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         dtype=None,
         out=None,
         ddof=0,
         keepdims=False,
         *,
         where=None):
    _check_arraylike("var", a)
    lax_internal._check_user_dtype_supported(dtype, "var")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.var is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a = a.astype(computation_dtype)
    a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
    centered = lax.sub(a, a_mean)
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    if where is None:
        if axis is None:
            normalizer = core.dimension_as_value(np.size(a))
        else:
            normalizer = core.dimension_as_value(_axis_size(a, axis))
    else:
        normalizer = sum(_broadcast_to(where, np.shape(a)),
                         axis,
                         dtype=dtype,
                         keepdims=keepdims)
    normalizer = normalizer - ddof

    result = sum(centered, axis, keepdims=keepdims, where=where)
    out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
    return lax.convert_element_type(out, dtype)
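A usage sketch of the public jnp.var wrapper: ddof shifts the normalizer (Bessel's correction with ddof=1), and `where` masks both the mean and the squared deviations. Values below are illustrative only.

import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0, 4.0])
print(jnp.var(a))          # 1.25      -> normalizer N = 4
print(jnp.var(a, ddof=1))  # 1.6666666 -> normalizer N - ddof = 3

mask = jnp.array([True, True, True, False])
print(jnp.var(a, where=mask))  # 0.6666667 -> variance of [1, 2, 3]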