def _dct2(x, axes, norm):
  """Two-dimensional type-II DCT of ``x`` over the pair of axes ``axes``.

  Both axes are reordered with ``_dct_interleave``, a 2-D FFT is taken over
  ``axes``, and the spectrum is combined with the ``_W4`` twiddle factors
  before taking twice the real part.

  Args:
    x: input array.
    axes: a pair of (possibly negative) axis indices.
    norm: ``'ortho'`` applies ``_dct_ortho_norm`` along both axes; any other
      value leaves the result unnormalized.

  Returns:
    An array with the same shape as ``x``.
  """
  canon = partial(_canonicalize_axis, num_dims=x.ndim)
  ax1, ax2 = (canon(ax) for ax in axes)
  n1 = x.shape[ax1]
  n2 = x.shape[ax2]

  interleaved = _dct_interleave(_dct_interleave(x, ax1), ax2)
  spectrum = jnp.fft.fftn(interleaved, axes=axes)

  def _freqs(size, ax):
    # Indices 0..size-1 laid out along `ax`, with singleton dims elsewhere.
    others = [d for d in range(x.ndim) if d != ax]
    return lax.expand_dims(jnp.arange(size), others)

  f1 = _freqs(n1, ax1)
  f2 = _freqs(n2, ax2)
  # flip + roll(1) maps entry k to spectrum[..., (-k) mod n2] along ax2.
  reflected = jnp.roll(jnp.flip(spectrum, axis=ax2), shift=1, axis=ax2)
  combined = _W4(n1, f1) * (_W4(n2, f2) * spectrum + _W4(n2, -f2) * reflected)
  result = 2 * combined.real
  if norm != 'ortho':
    return result
  return _dct_ortho_norm(_dct_ortho_norm(result, ax1), ax2)
def _dct_ortho_norm(out, axis): factor = lax.concatenate([ lax.full((1, ), 4, out.dtype), lax.full((out.shape[axis] - 1, ), 2, out.dtype) ], 0) factor = lax.expand_dims(factor, [a for a in range(out.ndim) if a != axis]) return out / lax.sqrt(factor * out.shape[axis])
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  """Compute ``log(sum(b * exp(a)))`` along ``axis`` in a numerically stable way.

  The per-reduction maximum of ``a`` (replaced by 0 when it is not finite) is
  subtracted before exponentiating and added back afterwards.

  Args:
    a: input array.
    axis: axis or axes to reduce over, resolved by ``_reduction_dims``.
    b: optional scaling factors, broadcast against ``a``. Entries where
      ``b == 0`` contribute nothing, even if ``a`` is inf/NaN there.
    keepdims: if True, reduced axes are kept with size one.
    return_sign: if True, return ``(out, sign)`` where ``sign`` is the sign of
      the (weighted) sum.

  Returns:
    ``out``, or ``(out, sign)`` when ``return_sign`` is True. Without
    ``return_sign``, a negative weighted sum yields NaN.
  """
  if b is not None:
    a, b = jnp.broadcast_arrays(a, b)
    # Exclude zero-weight entries up front: otherwise a +inf/NaN in `a` where
    # `b == 0` poisons both the maximum and the sum (0 * inf -> NaN).
    a = jnp.where(b != 0, a, -np.inf)
  dims = _reduction_dims(a, axis)
  dimadd = lambda x: lax.expand_dims(x, dims)
  amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
  # Use 0 as the shift when the maximum is not finite, so the subtraction
  # below does not itself produce NaN (e.g. inf - inf).
  amax = lax.stop_gradient(
      lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
  amax_singletons = dimadd(amax)
  if b is None:
    out = lax.add(
        lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                           _constant_like(a, 0), lax.add, dims)),
        amax)
    sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
    sign = jnp.where(out == -np.inf, 0.0, sign)
  else:
    sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                        _constant_like(a, 0), lax.add, dims)
    sign = lax.stop_gradient(lax.sign(sumexp))
    out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (dimadd(out), dimadd(sign)) if keepdims else (out, sign)
  if b is not None:
    # log|sum| is only the log of the sum when the sum is non-negative.
    out = jnp.where(sign < 0, np.nan, out)
  return dimadd(out) if keepdims else out
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  """Compute ``log(sum(b * exp(a)))`` along ``axis`` in a numerically stable way.

  Args:
    a: input array; non-inexact (integer/bool) input is promoted to an
      inexact dtype.
    axis: axis or axes to reduce over, resolved by ``_reduction_dims``.
    b: optional scaling factors, promoted together with ``a``. Entries where
      ``b == 0`` are excluded from the sum.
    keepdims: if True, reduced axes are kept with size one.
    return_sign: if True, return ``(out, sign)`` where ``sign`` is the sign of
      the (weighted) sum.

  Returns:
    ``out``, or ``(out, sign)`` when ``return_sign`` is True. Without
    ``return_sign``, a negative weighted sum yields NaN.
  """
  if b is not None:
    a, b = _promote_args_inexact("logsumexp", a, b)
    a = jnp.where(b != 0, a, -jnp.inf)
  else:
    # Promote integer/bool input as well: `lax.is_finite` and `lax.exp` below
    # are only defined for inexact dtypes.
    a, = _promote_args_inexact("logsumexp", a)
  pos_dims, dims = _reduction_dims(a, axis)
  amax = jnp.max(a, axis=dims, keepdims=keepdims)
  # Use 0 as the shift when the maximum is not finite, so the subtraction
  # below does not itself produce NaN.
  amax = lax.stop_gradient(
      lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
  if b is None:
    out = lax.add(
        lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims, keepdims=keepdims)),
        amax)
    sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
    sign = jnp.where(out == -np.inf, 0.0, sign)
  else:
    sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                     axis=dims, keepdims=keepdims)
    sign = lax.stop_gradient(lax.sign(sumexp))
    out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (out, sign)
  if b is not None:
    # log|sum| is only the log of the sum when the sum is non-negative.
    out = jnp.where(sign < 0, np.nan, out)
  return out
def _pinv_jvp(rcond, primals, tangents):
  """JVP rule for ``pinv``.

  Implements the derivative of the Moore-Penrose pseudo-inverse from
  G. H. Golub and V. Pereyra, "The Differentiation of Pseudo-Inverses and
  Nonlinear Least Squares Problems Whose Variables Separate", SIAM Journal on
  Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432
  (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative):

    dP = -P dA P + P P^H dA^H (I - A P) + (I - P A) dA^H P^H P

  Args:
    rcond: cutoff forwarded to ``pinv``.
    primals: 1-tuple holding ``A`` of shape ``[..., m, n]``.
    tangents: 1-tuple holding the tangent ``dA``.

  Returns:
    ``(P, dP)`` with ``P = pinv(A, rcond=rcond)``.
  """
  (mat,) = primals
  (mat_dot,) = tangents
  pinv_mat = pinv(mat, rcond=rcond)
  rows, cols = mat.shape[-2:]
  batch_axes = range(mat.ndim - 2)
  # Identities matching A @ P (rows x rows) and P @ A (cols x cols), with
  # leading singleton batch axes so they broadcast.
  eye_rows = lax.expand_dims(jnp.eye(rows, dtype=mat.dtype), batch_axes)
  eye_cols = lax.expand_dims(jnp.eye(cols, dtype=mat.dtype), batch_axes)
  # TODO(phawkins): on TPU, we would need to opt into high precision here.
  # TODO(phawkins): consider if this can be simplified in the Hermitian case.
  term1 = -pinv_mat @ mat_dot @ pinv_mat
  term2 = pinv_mat @ _H(pinv_mat) @ _H(mat_dot) @ (eye_rows - mat @ pinv_mat)
  term3 = (eye_cols - pinv_mat @ mat) @ _H(mat_dot) @ _H(pinv_mat) @ pinv_mat
  return pinv_mat, term1 + term2 + term3
def one_hot(x: Array, num_classes: int, *,
            dtype: Any = jnp.float64, axis: Union[int, AxisName] = -1) -> Array:
  """One-hot encodes the given indices.

  Each index in the input ``x`` becomes a length-``num_classes`` vector of
  zeros with a one at position ``index``::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) are encoded as all zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension; must be a
      concrete integer.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    axis: position of the new one-hot axis in the output, or the name of a
      mapped axis whose size must equal ``num_classes``.
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  try:
    out_axis = util.canonicalize_axis(axis, x.ndim + 1)
  except TypeError:
    # Not an integer axis: compare x with the index along the named axis.
    named_size = lax.psum(1, axis)
    if num_classes != named_size:
      raise ValueError(
          f"Expected num_classes to match the size of axis {axis}, "
          f"but {num_classes} != {named_size}") from None
    return jnp.asarray(x == lax.axis_index(axis), dtype=dtype)
  axis = operator.index(axis)
  expanded = lax.expand_dims(x, (axis,))
  class_shape = [1] * x.ndim
  class_shape.insert(out_axis, num_classes)
  classes = lax.broadcast_in_dim(
      jnp.arange(num_classes, dtype=x.dtype), class_shape, (out_axis,))
  return jnp.asarray(expanded == classes, dtype=dtype)
def pinv(a, rcond=None):
  """Moore-Penrose pseudo-inverse of a matrix or a batch of matrices.

  Uses the same algorithm as NumPy:
  https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979

  Args:
    a: array of shape ``[..., m, n]``. Integer/bool input is promoted to an
      inexact dtype first.
    rcond: relative cutoff; singular values less than or equal to
      ``rcond * largest_singular_value`` are treated as zero. A scalar or an
      array aligned with the trailing batch dimensions of ``a``. Defaults to
      ``10 * max(m, n) * eps`` for the dtype of ``a``.

  Returns:
    Array of shape ``[..., n, m]`` with the (promoted) dtype of ``a``.
  """
  # Promote integer/bool input: `jnp.finfo` below and the SVD need an
  # inexact dtype.
  a = _promote_arg_dtypes(jnp.asarray(a))
  # With conj(a) = u @ diag(s) @ vh, pinv(a) = vh^T @ diag(1/s) @ u^T, so
  # only plain transposes are needed below.
  a = jnp.conj(a)
  if rcond is None:
    max_rows_cols = max(a.shape[-2:])
    rcond = 10. * max_rows_cols * jnp.finfo(a.dtype).eps
  rcond = jnp.asarray(rcond)
  u, s, vh = svd(a, full_matrices=False)
  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero.
  rcond = lax.expand_dims(rcond[..., jnp.newaxis],
                          range(s.ndim - rcond.ndim - 1))
  cutoff = rcond * jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf)
  # Replacing discarded values by +inf makes their reciprocal exactly zero.
  s = jnp.where(s > cutoff, s, jnp.inf)
  res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]))
  return lax.convert_element_type(res, a.dtype)
def _slogdet_lu(a):
  """Sign and log-absolute-determinant of square matrix (batch) ``a`` via LU.

  Returns:
    ``(sign, logabsdet)``. If any diagonal entry of the LU factor is exactly
    zero, ``sign`` is 0 and ``logabsdet`` is ``-inf``. For complex input
    ``sign`` is the unit-modulus phase of the determinant.
  """
  dtype = lax.dtype(a)
  lu, pivot, _ = lax_linalg.lu(a)
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  singular = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
  # Every position where pivot[i] != i is a row interchange, flipping the sign.
  identity_perm = lax.expand_dims(jnp.arange(a.shape[-1]),
                                  range(pivot.ndim - 1))
  flips = jnp.count_nonzero(pivot != identity_perm, axis=-1)
  if not jnp.iscomplexobj(a):
    # Real input: negative diagonal entries also flip the sign.
    phase = jnp.array(1, dtype=dtype)
    flips = flips + jnp.count_nonzero(diag < 0, axis=-1)
  else:
    phase = jnp.prod(diag / jnp.abs(diag), axis=-1)
  parity_sign = jnp.array(-2 * (flips % 2) + 1, dtype=dtype)
  sign = jnp.where(singular, jnp.array(0, dtype=dtype), phase * parity_sign)
  log_abs = jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)
  logdet = jnp.where(singular, jnp.array(-jnp.inf, dtype=dtype), log_abs)
  return sign, jnp.real(logdet)
def dct(x, type=2, n=None, axis=-1, norm=None):
  """Type-II discrete cosine transform of ``x`` along ``axis``.

  The signature mirrors ``scipy.fft.dct``.

  Args:
    x: input array.
    type: DCT type; only 2 is implemented.
    n: optional transform length. ``x`` is zero-padded (or, via negative
      padding, truncated) along ``axis`` to this length.
    axis: axis along which the transform is computed.
    norm: ``None``/``'backward'`` (no normalization) or ``'ortho'``.

  Returns:
    The DCT-II of ``x`` along ``axis``.

  Raises:
    NotImplementedError: if ``type != 2``.
    ValueError: if ``norm`` is not one of the supported values.
  """
  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  # Reject unsupported normalizations (e.g. 'forward') instead of silently
  # returning an unnormalized result.
  if norm not in (None, 'backward', 'ortho'):
    raise ValueError(
        f"dct: norm={norm!r} is not implemented; expected None, 'backward' "
        "or 'ortho'")
  axis = _canonicalize_axis(axis, x.ndim)
  if n is not None:
    x = lax.pad(x, jnp.array(0, x.dtype),
                [(0, n - x.shape[axis] if a == axis else 0, 0)
                 for a in range(x.ndim)])
  N = x.shape[axis]
  v = _dct_interleave(x, axis)
  V = jnp.fft.fft(v, axis=axis)
  k = lax.expand_dims(jnp.arange(N), [a for a in range(x.ndim) if a != axis])
  out = V * _W4(N, k)
  out = 2 * out.real
  if norm == 'ortho':
    out = _dct_ortho_norm(out, axis)
  return out
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  """Compute ``log(sum(b * exp(a)))`` along ``axis`` in a numerically stable way.

  Args:
    a: input array; promoted to an inexact dtype (complex is supported).
    axis: axis or axes to reduce over, resolved by ``_reduction_dims``.
    b: optional scaling factors, promoted together with ``a``. Entries where
      ``b == 0`` are excluded from the sum.
    keepdims: if True, reduced axes are kept with size one.
    return_sign: if True, return ``(out, sign)`` such that
      ``sign * exp(out)`` equals the (weighted) sum.

  Returns:
    ``out``, or ``(out, sign)`` when ``return_sign`` is True. For real inputs
    without ``return_sign``, a negative weighted sum yields NaN.
  """
  if b is not None:
    a, b = _promote_args_inexact("logsumexp", a, b)
    a = jnp.where(b != 0, a, -jnp.inf)
  else:
    a, = _promote_args_inexact("logsumexp", a)
  pos_dims, dims = _reduction_dims(a, axis)
  amax = jnp.max(a, axis=dims, keepdims=keepdims)
  amax = lax.stop_gradient(
      lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)

  # fast path if the result cannot be negative.
  if b is None and not np.issubdtype(a.dtype, np.complexfloating):
    out = lax.add(
        lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims, keepdims=keepdims)),
        amax)
    sign = jnp.where(jnp.isnan(out), out, 1.0)
    sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
  else:
    expsub = lax.exp(lax.sub(a, amax_with_dims))
    if b is not None:
      expsub = lax.mul(expsub, b)
    sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

    sign = lax.stop_gradient(jnp.sign(sumexp))
    if np.issubdtype(sumexp.dtype, np.complexfloating):
      if return_sign:
        # Strip the phase so that sign * exp(out) reconstructs the sum.
        # conj(sign) * sumexp == |sumexp| when sign is the unit phase
        # z/|z|, and equals sign * sumexp when sign is real (+/-1).
        sumexp = jnp.conj(sign) * sumexp
      out = lax.add(lax.log(sumexp), amax)
    else:
      out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (out, sign)
  if b is not None:
    if not np.issubdtype(out.dtype, np.complexfloating):
      # Deliberately produce NaN for negative sums; don't trip debug_nans.
      with jax.debug_nans(False):
        out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
  return out
def _irfft_transpose(t, fft_lengths):
  """Transpose rule for the IRFFT.

  The transpose is the RFFT of the cotangent ``t``, scaled by
  ``1 / prod(fft_lengths)`` and by a mask. The mask doubles every bin except
  the first and, for even lengths, the last, because those
  Hermitian-symmetric components are de-duplicated in the RFFT.
  """
  x = fft(t, xla_client.FftType.RFFT, fft_lengths)
  num_bins = x.shape[-1]
  odd = fft_lengths[-1] % 2
  fill = partial(lax.full_like, t, dtype=t.dtype)
  pieces = [
      fill(1.0, shape=(1,)),                   # first bin: kept as-is
      fill(2.0, shape=(num_bins - 2 + odd,)),  # interior bins: doubled
      fill(1.0, shape=(1 - odd,)),             # last bin, even lengths only
  ]
  mask = lax.expand_dims(lax.concatenate(pieces, dimension=0),
                         range(x.ndim - 1))
  out = (1 / prod(fft_lengths)) * mask * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  # Use JAX's convention for complex gradients
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)
def slogdet(a):
  """Sign and natural log of the absolute determinant of square matrices.

  Args:
    a: array of shape ``[..., n, n]``; integer/bool input is promoted to an
      inexact dtype.

  Returns:
    ``(sign, logabsdet)`` as computed by ``_slogdet_lu``. ``sign`` is 0 and
    ``logabsdet`` is ``-inf`` when the LU factor has an exactly-zero diagonal
    entry.

  Raises:
    ValueError: if ``a`` does not have shape ``[..., n, n]``.
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  a_shape = jnp.shape(a)
  if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))
  # The LU-based computation was a verbatim copy of `_slogdet_lu`; delegate
  # so the two cannot drift apart.
  return _slogdet_lu(a)
def _one_hot(x: Array, num_classes: int, *,
             dtype: Any, axis: Union[int, AxisName]) -> Array:
  """Implementation of one-hot encoding behind ``jax.nn.one_hot``.

  ``axis`` is either an integer position for the new class axis in the
  output, or the name of a mapped axis whose size must equal
  ``num_classes``.
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  try:
    new_axis = util.canonicalize_axis(axis, x.ndim + 1)
  except TypeError:
    # Not an integer axis: compare x with the index along the named axis.
    mapped_size = lax.psum(1, axis)
    if num_classes != mapped_size:
      raise ValueError(
          f"Expected num_classes to match the size of axis {axis}, "
          f"but {num_classes} != {mapped_size}") from None
    return jnp.asarray(x == lax.axis_index(axis), dtype=dtype)
  axis = operator.index(axis)  # type: ignore[arg-type]
  expanded = lax.expand_dims(x, (axis,))
  target_shape = [1] * x.ndim
  target_shape.insert(new_axis, num_classes)
  class_ids = lax.broadcasted_iota(x.dtype, target_shape, new_axis)
  return jnp.asarray(expanded == class_ids, dtype=dtype)
def _reduction(a, name, np_fun, op, init_val, has_identity=True,
               preproc=None, bool_op=None, upcast_f16_for_computation=False,
               axis=None, dtype=None, out=None, keepdims=False, initial=None,
               where_=None, parallel_reduce=None):
  """Shared driver for ``jnp.<name>`` reductions.

  Args:
    a: array-like input.
    name: reduction name, used in error messages as ``jnp.{name}``.
    np_fun: NumPy function applied to a scalar of ``a``'s dtype to infer the
      default result dtype.
    op: binary reduction op passed to ``lax.reduce``.
    init_val: initial value, resolved by ``_reduction_init_val``.
    has_identity: whether ``op`` has an identity. If False, ``initial`` is
      required when ``where_`` is given or when a reduced axis may be empty.
    preproc: optional function applied to ``a`` before reducing.
    bool_op: op used instead of ``op`` when the computation dtype is bool;
      defaults to ``op``.
    upcast_f16_for_computation: if True, inexact results are computed in the
      dtype returned by ``_upcast_f16``.
    axis: axis or axes to reduce, resolved by ``_reduction_dims``.
    dtype: requested output dtype.
    out: must be None; accepted only for NumPy compatibility.
    keepdims: if True, reduced axes are kept with size one.
    initial: optional value combined with the result via ``op`` after the
      reduction.
    where_: optional mask; masked-out entries are replaced by the init value.
    parallel_reduce: function used when ``_reduction_dims`` returns distinct
      ``pos_dims``/``dims`` objects (presumably named axes, per the error
      message below).

  Returns:
    The reduced array, converted to ``dtype`` or the inferred result dtype.

  Raises:
    NotImplementedError: if ``out`` is given, or a named reduction is
      requested without ``parallel_reduce``.
    ValueError: if ``op`` has no identity and neither ``initial`` nor a
      non-empty reduced axis makes the result well defined.
  """
  bool_op = bool_op or op
  # Note: we must accept out=None as an argument, because numpy reductions delegate to
  # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
  # exists, passing along all its arguments.
  if out is not None:
    raise NotImplementedError(
        f"The 'out' argument to jnp.{name} is not supported.")
  _check_arraylike(name, a)
  lax_internal._check_user_dtype_supported(dtype, name)
  # `axis` must be known at trace time.
  axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")

  if initial is None and not has_identity and where_ is not None:
    raise ValueError(
        f"reduction operation {name} does not have an identity, so to use a "
        f"where mask one has to specify 'initial'")

  a = a if isinstance(a, ndarray) else _asarray(a)
  a = preproc(a) if preproc else a
  pos_dims, dims = _reduction_dims(a, axis)

  if initial is None and not has_identity:
    # Without an identity, every reduced axis must have at least one element.
    shape = np.shape(a)
    if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
      raise ValueError(
          f"zero-size array to reduction operation {name} which has no identity"
      )

  # Infer the result dtype by applying the NumPy function to a scalar of a's
  # dtype, unless the caller requested one explicitly.
  result_dtype = dtypes.canonicalize_dtype(
      dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
  if upcast_f16_for_computation and dtypes.issubdtype(
      result_dtype, np.inexact):
    computation_dtype = _upcast_f16(result_dtype)
  else:
    computation_dtype = result_dtype
  a = lax.convert_element_type(a, computation_dtype)
  op = op if computation_dtype != np.bool_ else bool_op
  # NB: in XLA, init_val must be an identity for the op, so the user-specified
  # initial value must be applied afterward.
  init_val = _reduction_init_val(a, init_val)
  if where_ is not None:
    a = _where(where_, a, init_val)
  # Identity (not equality) check: `_reduction_dims` returns distinct objects
  # only when the reduction is not a plain positional one.
  if pos_dims is not dims:
    if parallel_reduce is None:
      raise NotImplementedError(
          f"Named reductions not implemented for jnp.{name}()")
    result = parallel_reduce(a, dims)
  else:
    result = lax.reduce(a, init_val, op, dims)
  if initial is not None:
    result = op(lax.convert_element_type(initial, a.dtype), result)
  if keepdims:
    result = lax.expand_dims(result, pos_dims)
  return lax.convert_element_type(result, dtype or result_dtype)
def _cofactor_solve(a, b):
  """Equivalent to det(a)*solve(a, b) for nonsingular a.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
    det(a)*solve(a, b) =
    prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
    prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then
    y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii})
          = x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.
  (The sign of the row permutation p is folded into these products; see
  `partial_det` below.)

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b

  Raises:
    ValueError: if a is not [..., m, m] or b's trailing shape differs from a's.
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  b = _promote_arg_dtypes(jnp.asarray(b))
  a_shape = jnp.shape(a)
  b_shape = jnp.shape(b)
  a_ndims = len(a_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
          and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  # 1x1 case: det(a) is the single entry and the adjugate is 1, so
  # adjugate(a) @ b == b.
  if a_shape[-1] == 1:
    return a[..., 0, 0], b
  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  # Broadcast a's factorization and b to a common batch shape.
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  # Each position where pivots[i] != i is a row interchange; its parity gives
  # the permutation sign (+1 or -1).
  iota = lax.expand_dims(jnp.arange(a_shape[-1]), range(pivots.ndim - 1))
  parity = jnp.count_nonzero(pivots != iota, axis=-1)
  sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
  # Replace u_{nn} as described in the docstring so the upper solve stays
  # finite when rank(a) == n-1.
  lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
  permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
  # Open-mesh batch indices for gathering rows by `permutation`. The
  # generator variable `b` shadows the argument only inside the generator.
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
  # filter out any matrices that are not full rank
  d = jnp.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
  d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
  x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
  # Apply p^-1 to x by gathering rows, then solve with unit-diagonal l.
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  # Scale every row except the last by det(a); the last row picks up its
  # factor from the replaced u_{nn} in the upper solve below.
  x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                       x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x
def expand_dims(self, dimensions: Sequence[int]):
  """Insert singleton axes into the key array at ``dimensions``.

  Follows ``lax.expand_dims`` rather than ``jnp.expand_dims``: ``dimensions``
  is a sequence of positions, possibly negative, interpreted relative to the
  output rank.
  """
  out_rank = self.ndim + len(set(dimensions))
  canonical = [canonicalize_axis(dim, out_rank) for dim in dimensions]
  expanded_keys = lax.expand_dims(self._keys, canonical)
  return PRNGKeyArray(self.impl, expanded_keys)
def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=False,
            size=None, fill_value=None, return_true_size=False):
  """Find the unique elements of an array along a particular axis.

  Args:
    ar: input array.
    axis: axis along which uniqueness is determined.
    return_index: also return, for each unique entry, its index into ``ar``
      (taken from the sort permutation).
    return_inverse: also return indices that map each element of ``ar`` to
      its unique entry.
    return_counts: also return the number of occurrences of each unique entry.
    size: optional static number of unique entries to return. Without it,
      ``mask`` must be concrete (not traced).
    fill_value: value used to pad the result when ``size`` exceeds the number
      of unique entries; required when ``ar`` is empty along ``axis`` and
      ``size`` is nonzero.
    return_true_size: also return ``mask.sum()``, the number of entries
      flagged by ``_unique_sorted_mask`` (presumably the true number of unique
      values).

  Returns:
    The unique values, or a tuple with the requested extras in the order
    above.

  Raises:
    ValueError: for zero-sized input with a nonzero ``size`` and no
      ``fill_value``.
  """
  if ar.shape[axis] == 0 and size and fill_value is None:
    raise ValueError(
      "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified")

  # NOTE(review): presumably aux holds the sorted entries with `axis` moved to
  # the front, mask flags the first of each run of equal entries, and perm is
  # the sort permutation -- confirm against _unique_sorted_mask.
  aux, mask, perm = _unique_sorted_mask(ar, axis)
  if size is None:
    # Boolean-mask indexing requires a concrete mask.
    ind = core.concrete_or_error(None, mask,
        "The error arose in jnp.unique(). " + UNIQUE_SIZE_HINT)
  else:
    # Fixed-length index list; extra slots are padded by `nonzero`.
    ind = nonzero(mask, size=size)[0]
  result = aux[ind] if aux.size else aux
  if fill_value is not None:
    fill_value = asarray(fill_value, dtype=result.dtype)
  if size is not None and fill_value is not None:
    if result.shape[0]:
      # Positions at or beyond the true unique count get fill_value.
      valid = lax.expand_dims(arange(size) < mask.sum(), tuple(range(1, result.ndim)))
      result = where(valid, result, fill_value)
    else:
      result = full_like(result, fill_value, shape=(size, *result.shape[1:]))
  # The unique axis is at position 0 of `result`; move it back to `axis`.
  result = moveaxis(result, 0, axis)

  ret = (result,)
  if return_index:
    if aux.size:
      ret += (perm[ind],)
    else:
      ret += (perm,)
  if return_inverse:
    if aux.size:
      # Running count of flagged entries minus one gives each sorted entry's
      # unique index; scatter back to original order via perm.
      imask = cumsum(mask) - 1
      inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(dtypes.int_))
      inv_idx = inv_idx.at[perm].set(imask)
    else:
      inv_idx = zeros(ar.shape[axis], dtype=int)
    ret += (inv_idx,)
  if return_counts:
    if aux.size:
      if size is None:
        idx = append(nonzero(mask)[0], mask.size)
      else:
        idx = nonzero(mask, size=size + 1)[0]
        # Padded (zero) positions after the first become mask.size so that
        # their differences below are zero.
        idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size))
      # Counts are gaps between consecutive run starts.
      ret += (diff(idx),)
    elif ar.shape[axis]:
      ret += (array([ar.shape[axis]], dtype=dtypes.canonicalize_dtype(dtypes.int_)),)
    else:
      ret += (empty(0, dtype=int),)
  if return_true_size:
    # Useful for internal uses of unique().
    ret += (mask.sum(),)
  return ret[0] if len(ret) == 1 else ret