def _gamma_jax(shape,
               alpha,
               beta=None,
               dtype=np.float32,
               seed=None,
               name=None):  # pylint: disable=unused-argument
    """Draws gamma samples through `jax.random.gamma`.

    Args:
      shape: Requested shape; normalized with `_ensure_tuple` and handed to
        `jax.random.gamma` unchanged.
      alpha: Concentration parameter(s); converted to an array of the working
        dtype.
      beta: Optional divisor applied to the raw samples. When `None`, the raw
        samples are returned (after clamping).
      dtype: Hint given to `utils.common_dtype` along with `alpha` and `beta`
        to pick the working dtype.
      seed: JAX PRNGKey. Required.
      name: Unused; accepted for signature compatibility.

    Returns:
      Samples cast to the working dtype and clamped from below at that
      dtype's `np.finfo(...).tiny`.

    Raises:
      ValueError: If `seed` is `None`.
    """
    dtype = utils.common_dtype([alpha, beta], dtype_hint=dtype)
    alpha = np.array(alpha, dtype=dtype)
    if beta is not None:
        beta = np.array(beta, dtype=dtype)
    shape = _ensure_tuple(shape)
    import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
    if seed is None:
        raise ValueError('Must provide PRNGKey to sample in JAX.')
    # TODO(srvasude): Sample in the given dtype once
    # https://github.com/google/jax/issues/2130 is fixed.
    # Until then, draw in float64 and cast down afterwards.
    raw_draws = jaxrand.gamma(key=seed, a=alpha, shape=shape, dtype=np.float64)
    draws = raw_draws.astype(dtype)
    # Match the 0->tiny behavior of tf.random.gamma.
    floor = np.finfo(dtype).tiny
    if beta is not None:
        draws = draws / beta
    return np.maximum(floor, draws)
Example #2
0
def _one_hot(  # pylint: disable=unused-argument
    indices,
    depth,
    on_value=None,
    off_value=None,
    axis=None,
    dtype=None,
    name=None):
  """Builds a one-hot encoding of `indices` along a new axis.

  Args:
    indices: Array-like of positions to mark.
    depth: Length of the new one-hot axis.
    on_value: Value written where the position equals the index. Defaults to 1.
    off_value: Value written everywhere else. Defaults to 0.
    axis: Destination of the one-hot axis. When `None` it stays last.
    dtype: Output dtype. When `None` it is chosen by `utils.common_dtype` from
      `on_value` and `off_value`, with `np.float32` as the hint.
    name: Unused; accepted for signature compatibility.

  Returns:
    Array of shape `indices.shape + (depth,)` (with the last axis moved to
    `axis` if given) filled with `on_value` / `off_value`.
  """
  on_value = 1 if on_value is None else on_value
  off_value = 0 if off_value is None else off_value
  if dtype is None:
    dtype = utils.common_dtype([on_value, off_value], np.float32)
  indices = np.array(indices)
  depth = np.array(depth)
  # Compare every candidate position against each index; a non-zero distance
  # means the slot is "off".
  positions = np.arange(depth, dtype=indices.dtype)
  is_off = abs(positions - indices[..., np.newaxis]) > 0
  cold = np.array(off_value, dtype)
  hot = np.array(on_value, dtype)
  encoded = np.where(is_off, cold, hot)
  if axis is None:
    return encoded
  return np.moveaxis(encoded, -1, axis)
def _uniform_jax(shape,
                 minval=0,
                 maxval=None,
                 dtype=tf.float32,
                 seed=None,
                 name=None):  # pylint: disable=unused-argument
    """Draws uniform samples with JAX (`randint` for ints, `uniform` otherwise).

    Args:
      shape: Requested sample shape; must support `len()`.
      minval: Lower bound of the sampling range.
      maxval: Upper bound. Mandatory for integer dtypes; for other dtypes a
        missing value is replaced by `dtype(1)`.
      dtype: Hint given to `utils.common_dtype` along with the bounds to pick
        the working dtype.
      seed: JAX PRNGKey. Required.
      name: Unused; accepted for signature compatibility.

    Returns:
      The result of `jax.random.randint` or `jax.random.uniform` for the
      shape produced by `_bcast_shape(shape, [minval, maxval])`.

    Raises:
      ValueError: If `seed` is `None`, or if the working dtype is an integer
        type and `maxval` is `None`.
    """
    import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
    if seed is None:
        raise ValueError('Must provide PRNGKey to sample in JAX.')
    dtype = utils.common_dtype([minval, maxval], dtype_hint=dtype)
    final_rank = max(
        [len(shape), len(np.shape(minval)),
         len(np.shape(maxval))])
    wants_integers = np.issubdtype(dtype, np.integer)
    if maxval is None:
        if wants_integers:
            raise ValueError(
                'Must specify maxval for integer dtype {}.'.format(dtype))
        maxval = dtype(1)
    shape = _bcast_shape(shape, [minval, maxval])
    # We must match ranks, as lax.max refuses to broadcast different-rank args.
    rank_pad = np.zeros([1] * final_rank, dtype=dtype)
    minval = minval + rank_pad
    if wants_integers:
        # NOTE(review): only `minval` is rank-padded on the integer path, while
        # the float path pads both bounds -- confirm this asymmetry is intended.
        return jaxrand.randint(key=seed,
                               shape=shape,
                               minval=minval,
                               maxval=maxval,
                               dtype=dtype)
    maxval = maxval + rank_pad
    return jaxrand.uniform(key=seed,
                           shape=shape,
                           dtype=dtype,
                           minval=minval,
                           maxval=maxval)
def _poisson(shape, lam, dtype=tf.float32, seed=None, name=None):  # pylint: disable=unused-argument
    """Draws Poisson samples with NumPy.

    Args:
      shape: Leading sample shape; normalized with `_ensure_tuple`.
      lam: Rate parameter(s). Its shape is appended to `shape`.
      dtype: Hint given to `utils.common_dtype` along with `lam` to pick the
        output dtype.
      seed: Optional integer seed. When `None`, the global `np.random` state
        is used.
      name: Unused; accepted for signature compatibility.

    Returns:
      Samples of shape `shape + np.shape(lam)` cast to the output dtype.
    """
    if seed is None:
        rng = np.random
    else:
        # Only the low 32 bits of the seed are handed to RandomState.
        rng = np.random.RandomState(seed & 0xffffffff)
    dtype = utils.common_dtype([lam], dtype_hint=dtype)
    full_shape = _ensure_tuple(shape) + np.shape(lam)
    draws = rng.poisson(lam=lam, size=full_shape)
    return draws.astype(dtype)
Example #5
0
def _normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None,
            name=None):  # pylint: disable=unused-argument
  """Draws normal samples with NumPy.

  Args:
    shape: Requested sample shape; combined with the parameters via
      `_bcast_shape`.
    mean: Location parameter(s).
    stddev: Scale parameter(s).
    dtype: Hint given to `utils.common_dtype` along with `mean` and `stddev`
      to pick the output dtype.
    seed: Optional integer seed. When `None`, the global `np.random` state is
      used.
    name: Unused; accepted for signature compatibility.

  Returns:
    Samples of the shape returned by `_bcast_shape`, cast to the output dtype.
  """
  if seed is None:
    rng = np.random
  else:
    # Only the low 32 bits of the seed are handed to RandomState.
    rng = np.random.RandomState(seed & 0xffffffff)
  dtype = utils.common_dtype([mean, stddev], dtype_hint=dtype)
  full_shape = _bcast_shape(shape, [mean, stddev])
  draws = rng.normal(loc=mean, scale=stddev, size=full_shape)
  return draws.astype(dtype)
Example #6
0
# NumPy/SciPy-backed stand-ins for `tf.math.*` functions.
#
# Each wrapper takes a trailing `name=None` argument and ignores it; only the
# remaining arguments are forwarded. `utils.copy_docstring` is defined
# elsewhere -- presumably it attaches the docs of the TF symbol named by its
# first argument to the wrapper (confirm in `utils`).

# Forwards to `scipy_special.ndtri`.
ndtri = utils.copy_docstring('tf.math.ndtri',
                             lambda x, name=None: scipy_special.ndtri(x))

# Forwards to `np.negative`.
negative = utils.copy_docstring('tf.math.negative',
                                lambda x, name=None: np.negative(x))

# Forwards to `np.nextafter`.
nextafter = utils.copy_docstring(
    'tf.math.nextafter', lambda x1, x2, name=None: np.nextafter(x1, x2))

# Forwards to `np.not_equal`.
not_equal = utils.copy_docstring('tf.math.not_equal',
                                 lambda x, y, name=None: np.not_equal(x, y))

# The order `a` is cast with `np.int32` before calling SciPy, and the result
# is cast to the dtype `utils.common_dtype` picks for `[a, x]` (hint: float32).
polygamma = utils.copy_docstring(
    'tf.math.polygamma',
    lambda a, x, name=None: scipy_special.polygamma(np.int32(a), x).astype(  # pylint: disable=unused-argument,g-long-lambda
        utils.common_dtype([a, x], dtype_hint=np.float32)))

# Both arguments are converted with `np.asarray` before `np.polyval`.
polyval = utils.copy_docstring(
    'tf.math.polyval',
    lambda coeffs, x, name=None: np.polyval(np.asarray(coeffs), np.asarray(x)))

# NOTE: rebinding `pow` shadows the builtin for the rest of this module.
pow = utils.copy_docstring(  # pylint: disable=redefined-builtin
    'tf.math.pow',
    lambda x, y, name=None: np.power(x, y))

# Forwards to `np.real`.
real = utils.copy_docstring('tf.math.real',
                            lambda input, name=None: np.real(input))

# Forwards to `np.reciprocal`.
reciprocal = utils.copy_docstring('tf.math.reciprocal',
                                  lambda x, name=None: np.reciprocal(x))
Example #7
0
# NumPy-backed stand-ins for TensorFlow array functions.
#
# `utils.copy_docstring` is defined elsewhere -- presumably it attaches the
# docs of the TF function given as its first argument to the replacement
# (confirm in `utils`). Wrappers that take `name` ignore it.

# Binds the module-private `_one_hot` implementation.
one_hot = utils.copy_docstring(tf.one_hot, _one_hot)

# The requested dtype is passed through `utils.numpy_dtype` before being
# handed to `np.ones`.
ones = utils.copy_docstring(
    tf.ones,
    lambda shape, dtype=tf.float32, name=None: np.ones(  # pylint: disable=g-long-lambda
        shape, utils.numpy_dtype(dtype)))

# `_ones_like` is defined outside this view.
ones_like = utils.copy_docstring(tf.ones_like, _ones_like)

# `_pad` is defined outside this view.
pad = utils.copy_docstring(tf.pad, _pad)

# NOTE: rebinding `range` shadows the builtin for the rest of this module.
# When `dtype` is falsy, the output dtype comes from
# `utils.common_dtype([start], np.int32)`.
range = utils.copy_docstring(  # pylint: disable=redefined-builtin
    tf.range,
    lambda start, limit=None, delta=1, dtype=None, name='range': np.arange(  # pylint: disable=g-long-lambda
        start, limit, delta).astype(
            utils.numpy_dtype(dtype or utils.common_dtype([start], np.int32))))

# Returns the number of dimensions of `input` as an `np.int32` scalar.
rank = utils.copy_docstring(
    tf.rank, lambda input, name=None: np.int32(np.array(input).ndim))  # pylint: disable=redefined-builtin,g-long-lambda

# Forwards to `np.reshape`.
reshape = utils.copy_docstring(
    tf.reshape, lambda tensor, shape, name=None: np.reshape(tensor, shape))

# NumPy-backed stand-in for `tf.roll`; forwards to `np.roll`.
# `name` is accepted and ignored so that callers passing `name=...`, as
# `tf.roll` allows, do not hit a TypeError.
roll = utils.copy_docstring(
    tf.roll,
    lambda input, shift, axis, name=None: np.roll(input, shift, axis))

# The three bindings below pair a TF function with a module-private
# implementation (`_searchsorted`, `_shape`, `_size`) defined outside this
# view; `utils.copy_docstring` presumably carries the TF docs over to them.
searchsorted = utils.copy_docstring(tf.searchsorted, _searchsorted)

# NOTE: from here on, the module-level name `shape` refers to this wrapper.
shape = utils.copy_docstring(tf.shape, _shape)

size = utils.copy_docstring(tf.size, _size)