def _gamma_jax(shape, alpha, beta=None, dtype=np.float32, seed=None, name=None):  # pylint: disable=unused-argument
  """JAX-based reparameterized gamma sampler emulating `tf.random.gamma`.

  As in `tf.random.gamma`, `shape` is the number of draws *per* parameter
  batch member. The result has shape `shape + batch_shape`, where
  `batch_shape` is the broadcast shape of `alpha` and `beta`.
  `jax.random.gamma` instead takes the full output shape, which must be
  broadcast-compatible with `a`, so the batch shape is appended here.

  Args:
    shape: Sample shape, prepended to the parameter batch shape.
    alpha: Concentration parameter(s).
    beta: Optional rate parameter(s). Samples are divided by `beta`. When
      None, the samples are left unscaled.
    dtype: Passed as `dtype_hint` to `utils.common_dtype` together with
      `alpha` and `beta` to choose the output dtype.
    seed: JAX PRNGKey. Required.
    name: Ignored; kept for signature compatibility with TF.

  Returns:
    Samples of shape `shape + broadcast(alpha, beta).shape`, clipped below at
    `np.finfo(dtype).tiny`.

  Raises:
    ValueError: if `seed` is None.
  """
  dtype = utils.common_dtype([alpha, beta], dtype_hint=dtype)
  alpha = np.array(alpha, dtype=dtype)
  beta = None if beta is None else np.array(beta, dtype=dtype)
  shape = _ensure_tuple(shape)
  import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
  if seed is None:
    raise ValueError('Must provide PRNGKey to sample in JAX.')
  # Parameter batch shape: broadcast of alpha and (if given) beta.
  if beta is None:
    batch_shape = np.shape(alpha)
  else:
    batch_shape = np.broadcast_arrays(alpha, beta)[0].shape
  out_shape = tuple(shape) + tuple(batch_shape)
  # TODO(srvasude): Sample in the given dtype once
  # https://github.com/google/jax/issues/2130 is fixed.
  samps = jaxrand.gamma(
      key=seed, a=alpha, shape=out_shape, dtype=np.float64).astype(dtype)
  # Match the 0->tiny behavior of tf.random.gamma.
  return np.maximum(
      np.finfo(dtype).tiny, samps if beta is None else samps / beta)
def _one_hot(  # pylint: disable=unused-argument
    indices,
    depth,
    on_value=None,
    off_value=None,
    axis=None,
    dtype=None,
    name=None):
  """NumPy emulation of `tf.one_hot`.

  Args:
    indices: Index array. Each entry selects the "on" position along a new
      trailing axis of length `depth`.
    depth: Length of the new one-hot axis.
    on_value: Fill value at the selected position; defaults to 1.
    off_value: Fill value everywhere else; defaults to 0.
    axis: If not None, the one-hot axis is moved from the last position to
      `axis`.
    dtype: Output dtype. When None, it is chosen by `utils.common_dtype` over
      `on_value` and `off_value`, with `np.float32` as the hint.
    name: Ignored.

  Returns:
    Array of shape `indices.shape + (depth,)` (before any `axis` move).
    Indices outside `[0, depth)` match no position, so their rows are all
    `off_value`.
  """
  on_value = 1 if on_value is None else on_value
  off_value = 0 if off_value is None else off_value
  if dtype is None:
    dtype = utils.common_dtype([on_value, off_value], np.float32)
  indices = np.array(indices)
  depth = np.array(depth)
  positions = np.arange(depth, dtype=indices.dtype)
  # A nonzero distance between a position and its target index means "off".
  is_off = abs(positions - indices[..., np.newaxis]) > 0
  off_fill = np.array(off_value, dtype)
  on_fill = np.array(on_value, dtype)
  result = np.where(is_off, off_fill, on_fill)
  if axis is None:
    return result
  return np.moveaxis(result, -1, axis)
def _uniform_jax(shape, minval=0, maxval=None, dtype=tf.float32, seed=None, name=None):  # pylint: disable=unused-argument
  """Jax uniform random sampler emulating `tf.random.uniform`.

  Args:
    shape: Requested output shape. It is combined with the shapes of `minval`
      and `maxval` via `_bcast_shape` to get the sampled shape.
    minval: Lower bound of the sampled range.
    maxval: Upper bound of the sampled range. Required for integer dtypes;
      defaults to 1 for non-integer dtypes.
    dtype: Passed as `dtype_hint` to `utils.common_dtype` together with
      `minval` and `maxval` to choose the output dtype.
    seed: JAX PRNGKey. Required.
    name: Ignored.

  Returns:
    `jax.random.randint` samples for integer dtypes, otherwise
    `jax.random.uniform` samples.

  Raises:
    ValueError: if `seed` is None, or if `maxval` is None for an integer
      dtype.
  """
  import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
  if seed is None:
    raise ValueError('Must provide PRNGKey to sample in JAX.')
  dtype = utils.common_dtype([minval, maxval], dtype_hint=dtype)
  final_rank = max(
      [len(shape), len(np.shape(minval)), len(np.shape(maxval))])
  is_integer = np.issubdtype(dtype, np.integer)
  if maxval is None:
    if is_integer:
      raise ValueError(
          'Must specify maxval for integer dtype {}.'.format(dtype))
    # `np.array(1, dtype=...)` accepts both a scalar type and an `np.dtype`
    # instance; calling `dtype(1)` only works for the former.
    maxval = np.array(1, dtype=dtype)
  shape = _bcast_shape(shape, [minval, maxval])
  # We must match ranks, as lax.max refuses to broadcast different-rank args.
  rank_padding = np.zeros([1] * final_rank, dtype=dtype)
  minval = minval + rank_padding
  if is_integer:
    # Only `minval` is rank-padded for randint, as in the original code.
    return jaxrand.randint(key=seed, shape=shape, minval=minval,
                           maxval=maxval, dtype=dtype)
  maxval = maxval + rank_padding
  return jaxrand.uniform(key=seed, shape=shape, dtype=dtype,
                         minval=minval, maxval=maxval)
def _poisson(shape, lam, dtype=tf.float32, seed=None, name=None):  # pylint: disable=unused-argument
  """NumPy Poisson sampler emulating `tf.random.poisson`.

  Args:
    shape: Sample shape. `np.shape(lam)` is appended to it to form the output
      shape.
    lam: Poisson rate parameter(s).
    dtype: Passed as `dtype_hint` to `utils.common_dtype` along with `lam`.
      The samples are cast to the result.
    seed: Optional integer seed. Its low 32 bits (`seed & 0xffffffff`) seed a
      fresh `np.random.RandomState`. When None, the global `np.random` state
      is used.
    name: Ignored.

  Returns:
    An array of shape `_ensure_tuple(shape) + np.shape(lam)`.
  """
  if seed is None:
    generator = np.random
  else:
    generator = np.random.RandomState(seed & 0xffffffff)
  out_dtype = utils.common_dtype([lam], dtype_hint=dtype)
  sample_shape = _ensure_tuple(shape)
  full_shape = sample_shape + np.shape(lam)
  draws = generator.poisson(lam=lam, size=full_shape)
  return draws.astype(out_dtype)
def _normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None):  # pylint: disable=unused-argument
  """NumPy normal sampler emulating `tf.random.normal`.

  Args:
    shape: Requested output shape. It is combined with the shapes of `mean`
      and `stddev` via `_bcast_shape`.
    mean: Location parameter(s).
    stddev: Scale parameter(s).
    dtype: Passed as `dtype_hint` to `utils.common_dtype` along with `mean`
      and `stddev`. The samples are cast to the result.
    seed: Optional integer seed. Its low 32 bits (`seed & 0xffffffff`) seed a
      fresh `np.random.RandomState`. When None, the global `np.random` state
      is used.
    name: Ignored.

  Returns:
    Normal samples of shape `_bcast_shape(shape, [mean, stddev])`.
  """
  if seed is None:
    generator = np.random
  else:
    generator = np.random.RandomState(seed & 0xffffffff)
  out_dtype = utils.common_dtype([mean, stddev], dtype_hint=dtype)
  out_shape = _bcast_shape(shape, [mean, stddev])
  draws = generator.normal(loc=mean, scale=stddev, size=out_shape)
  return draws.astype(out_dtype)
# Elementwise `tf.math` emulations backed by NumPy / SciPy. Each lambda
# repeats the TF argument list, including `name`, which is accepted and
# ignored. `utils.copy_docstring` receives the TF symbol path as a string;
# presumably it attaches that symbol's docstring to the lambda (not visible
# here -- confirm in `utils`).
ndtri = utils.copy_docstring('tf.math.ndtri', lambda x, name=None: scipy_special.ndtri(x))
negative = utils.copy_docstring('tf.math.negative', lambda x, name=None: np.negative(x))
nextafter = utils.copy_docstring(
    'tf.math.nextafter', lambda x1, x2, name=None: np.nextafter(x1, x2))
not_equal = utils.copy_docstring('tf.math.not_equal', lambda x, y, name=None: np.not_equal(x, y))
# The order `a` is converted with np.int32 before the SciPy call, so a
# fractional order is truncated. The result is cast to
# utils.common_dtype([a, x]) with np.float32 as the hint.
# NOTE(review): confirm truncation matches tf.math.polygamma for non-integer a.
polygamma = utils.copy_docstring(
    'tf.math.polygamma',
    lambda a, x, name=None: scipy_special.polygamma(np.int32(a), x).astype(  # pylint: disable=unused-argument,g-long-lambda
        utils.common_dtype([a, x], dtype_hint=np.float32)))
# `coeffs` goes through np.asarray, so a list of equally-shaped arrays is
# stacked before np.polyval evaluates it.
# NOTE(review): np.polyval takes the highest-degree coefficient first;
# confirm this matches tf.math.polyval's ordering.
polyval = utils.copy_docstring(
    'tf.math.polyval',
    lambda coeffs, x, name=None: np.polyval(np.asarray(coeffs), np.asarray(x)))
# Rebinds the builtin `pow` at module level (hence the pylint disable).
pow = utils.copy_docstring(  # pylint: disable=redefined-builtin
    'tf.math.pow',
    lambda x, y, name=None: np.power(x, y))
real = utils.copy_docstring('tf.math.real', lambda input, name=None: np.real(input))
reciprocal = utils.copy_docstring('tf.math.reciprocal', lambda x, name=None: np.reciprocal(x))
one_hot = utils.copy_docstring(tf.one_hot, _one_hot)

ones = utils.copy_docstring(
    tf.ones,
    lambda shape, dtype=tf.float32, name=None: np.ones(  # pylint: disable=g-long-lambda
        shape, utils.numpy_dtype(dtype)))

ones_like = utils.copy_docstring(tf.ones_like, _ones_like)

pad = utils.copy_docstring(tf.pad, _pad)


def _range(start, limit=None, delta=1, dtype=None, name='range'):  # pylint: disable=unused-argument
  """Emulates `tf.range`, including its dtype inference.

  When `dtype` is None, the result dtype is inferred from *all* of `start`,
  `limit` and `delta`. Python ints count as int32 and Python floats as
  float32; the widest is chosen along int32 < int64 < float32 < float64.
  Inferring from `start` alone truncated e.g. `range(0, 2, 0.5)` to
  `[0, 0, 1, 1]`. Arguments whose dtype lies outside that hierarchy fall back
  to the former rule (infer from `start`, int32 hint).

  Args:
    start: First value, or the exclusive limit when `limit` is None.
    limit: Optional exclusive upper limit.
    delta: Step size.
    dtype: Optional output dtype; converted with `utils.numpy_dtype`.
    name: Ignored.

  Returns:
    `np.arange(start, limit, delta)` cast to the requested or inferred dtype.
  """
  if dtype is None:
    hierarchy = [np.dtype(t) for t in (np.int32, np.int64, np.float32, np.float64)]
    arg_dtypes = []
    supported = True
    for arg in (start, limit, delta):
      if arg is None:
        continue  # `limit` is optional.
      if hasattr(arg, 'dtype'):
        arg_dtype = np.dtype(utils.numpy_dtype(arg.dtype))
      elif isinstance(arg, float):
        arg_dtype = np.dtype(np.float32)
      elif isinstance(arg, int):
        arg_dtype = np.dtype(np.int32)
      else:
        supported = False
        break
      if arg_dtype not in hierarchy:
        supported = False
        break
      arg_dtypes.append(arg_dtype)
    if supported and arg_dtypes:
      dtype = max(arg_dtypes, key=hierarchy.index).type
    else:
      dtype = utils.common_dtype([start], np.int32)
  return np.arange(start, limit, delta).astype(utils.numpy_dtype(dtype))


range = utils.copy_docstring(  # pylint: disable=redefined-builtin
    tf.range, _range)

rank = utils.copy_docstring(
    tf.rank, lambda input, name=None: np.int32(np.array(input).ndim))  # pylint: disable=redefined-builtin,g-long-lambda

reshape = utils.copy_docstring(
    tf.reshape, lambda tensor, shape, name=None: np.reshape(tensor, shape))

# `name` is accepted (and ignored) like every other op here and like `tf.roll`.
roll = utils.copy_docstring(
    tf.roll, lambda input, shift, axis, name=None: np.roll(input, shift, axis))  # pylint: disable=redefined-builtin,unused-argument

searchsorted = utils.copy_docstring(tf.searchsorted, _searchsorted)

shape = utils.copy_docstring(tf.shape, _shape)

size = utils.copy_docstring(tf.size, _size)