Example 1
 def testDerivativeIsBoundedWhenAlphaIsBelow2(self):
     # Assert that |d_x| < |x|/scale^2 when alpha <= 2.
     _, _, x, alpha, scale, d_x, _, _ = self._precompute_lossfun_inputs()
     mask = jnp.isfinite(alpha) & (alpha <= 2)
     grad = jnp.abs(d_x[mask])
     bound = ((jnp.abs(x[mask]) + (300. * jnp.finfo(jnp.float32).eps)) /
              scale[mask]**2)
     self.assertTrue(jnp.all(grad <= bound))
Example 2
def binary_crossentropy_loss(params, predict, data):
    inputs, targets = data
    probs = predict(params, inputs)
    eps = jnp.finfo(probs.dtype).eps
    probs = jnp.clip(probs, eps, 1 - eps)
    loss = -(jsp.special.xlogy(targets, probs) +
             jsp.special.xlogy(1 - targets, 1 - probs)).mean()
    return loss
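A minimal usage sketch (the predictor and data below are hypothetical, shown only to illustrate the (params, predict, data) calling convention assumed by the snippet above):

import jax
import jax.numpy as jnp
import jax.scipy as jsp  # the loss above refers to this module as `jsp`

# Hypothetical logistic-regression predictor; params is a (weights, bias) pair.
def predict(params, inputs):
    w, b = params
    return jax.nn.sigmoid(inputs @ w + b)

params = (jnp.zeros(3), 0.0)
inputs = jnp.ones((4, 3))
targets = jnp.array([0., 1., 1., 0.])
print(binary_crossentropy_loss(params, predict, (inputs, targets)))  # ~log(2) with zero weights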
Example 3
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html
     # |det|(J) = Product(y * (1 - z))
     x = x - jnp.log(x.shape[-1] - jnp.arange(x.shape[-1]))
     z = jnp.clip(expit(x), a_min=jnp.finfo(x.dtype).tiny)
     # XXX we use the identity 1 - z = z * exp(-x) to not worry about
     # the case z ~ 1
     return jnp.sum(jnp.log(y[..., :-1] * z) - x, axis=-1)
Example 4
 def testDerivativeIsMonotonicWrtX(self):
     # Check that the loss increases monotonically with |x|.
     _, _, x, alpha, _, d_x, _, _ = self._precompute_lossfun_inputs()
     # This is just to suppress a warning below.
     d_x = jnp.where(jnp.isfinite(d_x), d_x, jnp.zeros_like(d_x))
     mask = jnp.isfinite(alpha) & (jnp.abs(d_x) >
                                   (300. * jnp.finfo(jnp.float32).eps))
     chex.assert_trees_all_close(jnp.sign(d_x[mask]), jnp.sign(x[mask]))
Example 5
 def testLossIsBoundedWhenAlphaIsNegative(self):
     # Assert that loss < (alpha - 2)/alpha when alpha < 0.
     _, loss, _, alpha, _, _, _, _ = self._precompute_lossfun_inputs()
     mask = alpha < 0.
     min_val = jnp.finfo(jnp.float32).min
     alpha_clipped = jnp.maximum(min_val, alpha[mask])
     self.assertTrue(
         jnp.all(loss[mask] <= ((alpha_clipped - 2.) / alpha_clipped)))
Example 6
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator, inverse_mass_matrix,
                              position, rng, init_step_size):
    """
    Finds a reasonable step size by tuning `init_step_size`. This function is used
    to avoid working with a step size that is too large or too small in HMC.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A generator to get a random momentum variable.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param jax.random.PRNGKey rng: Random key to be used as the source of randomness.
    :param float init_step_size: Initial step size to be tuned.
    :return: a reasonable value for step size.
    :rtype: float
    """
    # We are going to find a step_size which makes accept_prob (Metropolis correction)
    # near the target_accept_prob. If accept_prob:=exp(-delta_energy) is small,
    # then we have to decrease step_size; otherwise, increase step_size.
    target_accept_prob = np.log(0.8)

    _, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    z = position
    potential_energy, z_grad = value_and_grad(potential_fn)(z)
    tiny = np.finfo(get_dtype(init_step_size)).tiny

    def _body_fn(state):
        step_size, _, direction, rng = state
        rng, rng_momentum = random.split(rng)
        # Scale step_size: increase or decrease by 2x depending on direction;
        # direction=1 means keep increasing step_size, otherwise keep decreasing it.
        # Note that the direction is -1 if delta_energy is `NaN`, which may be the
        # case for a diverging trajectory (e.g. in the case of evaluating log prob
        # of a value simulated using a large step size for a constrained sample site).
        step_size = (2.0 ** direction) * step_size
        r = momentum_generator(inverse_mass_matrix, rng_momentum)
        _, r_new, potential_energy_new, _ = vv_update(step_size,
                                                      inverse_mass_matrix,
                                                      (z, r, potential_energy, z_grad))
        energy_current = kinetic_fn(inverse_mass_matrix, r) + potential_energy
        energy_new = kinetic_fn(inverse_mass_matrix, r_new) + potential_energy_new
        delta_energy = energy_new - energy_current
        direction_new = np.where(target_accept_prob < -delta_energy, 1, -1)
        return step_size, direction, direction_new, rng

    def _cond_fn(state):
        step_size, last_direction, direction, _ = state
        # condition to run only if step_size is not so small or we are not decreasing step_size
        not_small_step_size_cond = (step_size > tiny) | (direction >= 0)
        return not_small_step_size_cond & ((last_direction == 0) | (direction == last_direction))

    step_size, _, _, _ = while_loop(_cond_fn, _body_fn, (init_step_size, 0, 0, rng))
    return step_size
Example 7
 def sample(self, key, sample_shape=()):
     assert is_prng_key(key)
     dtype = jnp.result_type(float)
     finfo = jnp.finfo(dtype)
     minval = finfo.tiny
     u = random.uniform(key,
                        shape=sample_shape + self.batch_shape,
                        minval=minval)
     return self.base_dist.icdf(u * self._cdf_at_high)
Example 8
File: eigh.py Project: wayfeng/jax
def _projector_subspace(P, H, rank, maxiter=2):
    """ Decomposes the `n x n` rank `rank` Hermitian projector `P` into
  an `n x rank` isometry `Vm` such that `P = Vm @ Vm.conj().T` and
  an `n x (n - rank)` isometry `Vp` such that `I - P = Vp @ Vp.conj().T`.

  The subspaces are computed using the naive QR eigendecomposition
  algorithm, which converges very quickly due to the sharp separation
  between the relevant eigenvalues of the projector.

  Args:
    P: A rank-`rank` Hermitian projector into the space of `H`'s
       first `rank` eigenpairs.
    H: The aforementioned Hermitian matrix, which is used to track
       convergence.
    rank: Rank of `P`.
    maxiter: Maximum number of iterations.
  Returns:
    Vm, Vp: Isometries into the eigenspaces described in the docstring.
  """
    # Choose an initial guess: the `rank` largest-norm columns of P.
    column_norms = jnp.linalg.norm(P, axis=1)
    sort_idxs = jnp.argsort(column_norms)
    X = P[:, sort_idxs]
    X = X[:, :rank]

    H_norm = jnp.linalg.norm(H)
    thresh = 10 * jnp.finfo(X.dtype).eps * H_norm

    # First iteration skips the matmul.
    def body_f_after_matmul(X):
        Q, _ = jnp.linalg.qr(X, mode="complete")
        V1 = Q[:, :rank]
        V2 = Q[:, rank:]
        # TODO: might be able to get away with lower precision here
        error_matrix = jnp.dot(V2.conj().T, H, precision=lax.Precision.HIGHEST)
        error_matrix = jnp.dot(error_matrix,
                               V1,
                               precision=lax.Precision.HIGHEST)
        error = jnp.linalg.norm(error_matrix) / H_norm
        return V1, V2, error

    def cond_f(args):
        _, _, j, error = args
        still_counting = j < maxiter
        unconverged = error > thresh
        return jnp.logical_and(still_counting, unconverged)[0]

    def body_f(args):
        V1, _, j, _ = args
        X = jnp.dot(P, V1, precision=lax.Precision.HIGHEST)
        V1, V2, error = body_f_after_matmul(X)
        return V1, V2, j + 1, error

    V1, V2, error = body_f_after_matmul(X)
    one = jnp.ones(1, dtype=jnp.int32)
    V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
    return V1, V2
Example 9
    def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
        if not FLAGS.jax_enable_x64 and jnp.issubdtype(dtype, np.float64):
            raise SkipTest("can't test float64 agreement")

        bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
        numpy_bits = np.array(1., dtype).view(bits_dtype)
        xla_bits = api.jit(lambda: lax.bitcast_convert_type(
            np.array(1., dtype), bits_dtype))()
        self.assertEqual(numpy_bits, xla_bits)
Example 10
    def testPolar(self, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
                  shape, method, side, nonzero_condition_number, dtype, seed):
        """ Tests jax.scipy.linalg.polar."""
        if jtu.device_under_test() != "cpu":
            if jnp.dtype(dtype).name in ("bfloat16", "float16"):
                raise unittest.SkipTest("Skip half precision off CPU.")

        m, n = shape
        if (method == "qdwh" and ((side == "left" and m >= n) or
                                  (side == "right" and m < n))):
            raise unittest.SkipTest("method=qdwh does not support these sizes")

        matrix, _ = _initialize_polar_test(self.rng(), shape, n_zero_sv,
                                           degeneracy, geometric_spectrum,
                                           max_sv, nonzero_condition_number,
                                           dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError,
                              jsp.linalg.polar,
                              matrix,
                              method=method,
                              side=side)
            return

        unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
        if shape[0] >= shape[1]:
            should_be_eye = np.matmul(unitary.conj().T, unitary)
        else:
            should_be_eye = np.matmul(unitary, unitary.conj().T)
        tol = 500 * float(jnp.finfo(matrix.dtype).eps)
        eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
        with self.subTest('Test unitarity.'):
            self.assertAllClose(eye_mat, should_be_eye, atol=tol * min(shape))

        with self.subTest('Test Hermiticity.'):
            self.assertAllClose(posdef,
                                posdef.conj().T,
                                atol=tol * jnp.linalg.norm(posdef))

        ev, _ = np.linalg.eigh(posdef)
        ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
        negative_ev = jnp.sum(ev < 0.)
        with self.subTest('Test positive definiteness.'):
            self.assertEqual(negative_ev, 0)

        if side == "right":
            recon = jnp.matmul(unitary,
                               posdef,
                               precision=lax.Precision.HIGHEST)
        elif side == "left":
            recon = jnp.matmul(posdef,
                               unitary,
                               precision=lax.Precision.HIGHEST)
        with self.subTest('Test reconstruction.'):
            self.assertAllClose(matrix,
                                recon,
                                atol=tol * jnp.linalg.norm(matrix))
Example 11
def lossfun(x, alpha, scale):
    r"""Implements the general form of the loss.

  This implements the rho(x, \alpha, c) function described in "A General and
  Adaptive Robust Loss Function", Jonathan T. Barron,
  https://arxiv.org/abs/1701.03077.

  Args:
    x: The residual for which the loss is being computed. x can have any shape,
      and alpha and scale will be broadcasted to match x's shape if necessary.
    alpha: The shape parameter of the loss (\alpha in the paper), where more
      negative values produce a loss with more robust behavior (outliers "cost"
      less), and more positive values produce a loss with less robust behavior
      (outliers are penalized more heavily). Alpha can be any value in
      [-infinity, infinity], but the gradient of the loss with respect to alpha
      is 0 at -infinity, infinity, 0, and 2. Varying alpha allows for smooth
      interpolation between several discrete robust losses:
        alpha=-Infinity: Welsch/Leclerc Loss.
        alpha=-2: Geman-McClure loss.
        alpha=0: Cauchy/Lorentzian loss.
        alpha=1: Charbonnier/pseudo-Huber loss.
        alpha=2: L2 loss.
    scale: The scale parameter of the loss. When |x| < scale, the loss is an
      L2-like quadratic bowl, and when |x| > scale the loss function takes on a
      different shape according to alpha.

  Returns:
    The losses for each element of x, in the same shape as x.
  """
    eps = jnp.finfo(jnp.float32).eps

    # `scale` must be > 0.
    scale = jnp.maximum(eps, scale)

    # The loss when alpha == 2. This will get reused repeatedly.
    loss_two = 0.5 * (x / scale)**2

    # "Safe" versions of log1p and expm1 that will not NaN-out.
    log1p_safe = lambda x: jnp.log1p(jnp.minimum(x, 3e37))
    expm1_safe = lambda x: jnp.expm1(jnp.minimum(x, 87.5))

    # The loss when not in one of the special cases.
    # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by.
    a = jnp.where(alpha >= 0, jnp.ones_like(alpha),
                  -jnp.ones_like(alpha)) * jnp.maximum(eps, jnp.abs(alpha))
    # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by.
    b = jnp.maximum(eps, jnp.abs(alpha - 2))
    loss_ow = (b / a) * ((loss_two / (0.5 * b) + 1)**(0.5 * alpha) - 1)

    # Select which of the cases of the loss to return as a function of alpha.
    return jnp.where(
        alpha == -jnp.inf, -expm1_safe(-loss_two),
        jnp.where(
            alpha == 0, log1p_safe(loss_two),
            jnp.where(
                alpha == 2, loss_two,
                jnp.where(alpha == jnp.inf, expm1_safe(loss_two), loss_ow))))
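A quick sanity check of the special cases listed in the docstring (a sketch, assuming `lossfun` above is in scope):

import jax.numpy as jnp

x = jnp.linspace(-3., 3., 7)
print(lossfun(x, alpha=2., scale=1.))        # 0.5 * (x / scale)**2, the L2 case
print(lossfun(x, alpha=0., scale=1.))        # log1p(0.5 * (x / scale)**2), the Cauchy case
print(lossfun(x, alpha=-jnp.inf, scale=1.))  # 1 - exp(-0.5 * (x / scale)**2), the Welsch case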
Example 12
def categorical_sample(key, probs):
    """Sample from a set of discrete probabilities."""
    probs = probs / probs.sum(axis=-1, keepdims=True)
    cpi = jnp.cumsum(probs, axis=-1)
    eps = jnp.finfo(probs.dtype).eps
    rnds = jax.random.uniform(key=key,
                              shape=probs.shape[:-1] + (1, ),
                              dtype=probs.dtype,
                              minval=eps)
    return jnp.argmin(jnp.logical_or(rnds > cpi, probs < eps), axis=-1)
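Usage sketch (assuming `categorical_sample` above is in scope); the probabilities need not be normalized, and entries below machine epsilon are never selected:

import jax
import jax.numpy as jnp

probs = jnp.array([0.2, 0.0, 0.5, 0.3])
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
samples = jax.vmap(lambda k: categorical_sample(k, probs))(keys)
# Empirical frequencies should roughly match probs; index 1 should never appear.
print(jnp.bincount(samples, length=4) / samples.shape[0])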
Example 13
def visualize_depth(x, acc, lo=None, hi=None):
    """Visualizes depth maps."""

    depth_curve_fn = lambda x: -jnp.log(x + jnp.finfo(jnp.float32).eps)
    return visualize_cmap(x,
                          acc,
                          cm.get_cmap('turbo'),
                          curve_fn=depth_curve_fn,
                          lo=lo,
                          hi=hi)
Example 14
def visualize_cmap(value,
                   weight,
                   colormap,
                   lo=None,
                   hi=None,
                   percentile=99.,
                   curve_fn=lambda x: x,
                   modulus=None,
                   matte_background=True):
    """Visualize a 1D image and a 1D weighting according to some colormap.

  Args:
    value: A 1D image.
    weight: A weight map, in [0, 1].
    colormap: A colormap function.
    lo: The lower bound to use when rendering, if None then use a percentile.
    hi: The upper bound to use when rendering, if None then use a percentile.
    percentile: What percentile of the value map to crop to when automatically
      generating `lo` and `hi`. Depends on `weight` as well as `value`.
    curve_fn: A curve function that gets applied to `value`, `lo`, and `hi`
      before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
    modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If
      `modulus` is not None, `lo`, `hi` and `percentile` will have no effect.
    matte_background: If True, matte the image over a checkerboard.

  Returns:
    A colormap rendering.
  """
    # Identify the values that bound the middle of `value` according to `weight`.
    lo_auto, hi_auto = math.weighted_percentile(
        value, weight, [50 - percentile / 2, 50 + percentile / 2])

    # If `lo` or `hi` are None, use the automatically-computed bounds above.
    eps = jnp.finfo(jnp.float32).eps
    lo = lo or (lo_auto - eps)
    hi = hi or (hi_auto + eps)

    # Curve all values.
    value, lo, hi = [curve_fn(x) for x in [value, lo, hi]]

    # Wrap the values around if requested.
    if modulus:
        value = jnp.mod(value, modulus) / modulus
    else:
        # Otherwise, just scale to [0, 1].
        value = jnp.nan_to_num(
            jnp.clip((value - jnp.minimum(lo, hi)) / jnp.abs(hi - lo), 0, 1))

    if colormap:
        colorized = colormap(value)[:, :, :3]
    else:
        assert len(value.shape) == 3 and value.shape[-1] == 3
        colorized = value

    return matte(colorized, weight) if matte_background else colorized
Example 15
    def testRngUniform(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.uniform(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckCollisions(samples, np.finfo(dtype).nmant)
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
Example 16
 def debiased_moments(self):
     """Returns debiased moments as in Adam."""
     tiny = jnp.finfo(self.decay_product).tiny
     debias = 1.0 / jnp.maximum(1 - self.decay_product, tiny)
     mean = jax.tree_map(lambda m1: m1 * debias, self.mu)
     # This computation of the variance may lose some numerical precision, if
     # the mean is not approximately zero.
     variance = jax.tree_map(
         lambda m2, m: jnp.maximum(0.0, m2 * debias - jnp.square(m)),
         self.nu, mean)
     return EmaMoments(mean=mean, variance=variance)
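A standalone sketch of the same Adam-style bias correction, with hypothetical scalar EMAs standing in for `self.mu`, `self.nu`, and `self.decay_product`:

import jax
import jax.numpy as jnp

beta, step = 0.9, 5
decay_product = jnp.asarray(beta ** step)  # stands in for self.decay_product
mu = {'w': jnp.asarray(0.3)}               # first-moment EMA (self.mu)
nu = {'w': jnp.asarray(0.2)}               # second-moment EMA (self.nu)

tiny = jnp.finfo(decay_product.dtype).tiny
debias = 1.0 / jnp.maximum(1 - decay_product, tiny)
mean = jax.tree_util.tree_map(lambda m1: m1 * debias, mu)
variance = jax.tree_util.tree_map(
    lambda m2, m: jnp.maximum(0.0, m2 * debias - jnp.square(m)), nu, mean)
print(mean, variance)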
Example 17
 def test_bicgstab_on_random_system(self, shape, dtype, preconditioner):
   rng = jtu.rand_default(self.rng())
   A = rng(shape, dtype)
   solution = rng(shape[1:], dtype)
   M = self._fetch_preconditioner(preconditioner, A, rng=rng)
   b = matmul_high_precision(A, solution)
   tol = shape[0] * jnp.finfo(A.dtype).eps
   x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M)
   using_x64 = solution.dtype.kind in {np.float64, np.complex128}
   solution_tol = 1e-8 if using_x64 else 1e-4
   self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
Example 18
def inverse_softplus(y):
    """Inverse of jax.nn.softplus, adapted from TensorFlow Probability."""
    threshold = jnp.log(jnp.finfo(jnp.float32).eps) + 2.
    is_too_small = y < jnp.exp(threshold)
    is_too_large = y > -threshold
    too_small_value = jnp.log(y)
    too_large_value = y
    y = jnp.where(is_too_small | is_too_large, 1., y)
    x = y + jnp.log(-jnp.expm1(-y))
    return jnp.where(is_too_small, too_small_value,
                     jnp.where(is_too_large, too_large_value, x))
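A quick round-trip check (assuming `inverse_softplus` above is in scope): composing with `jax.nn.softplus` should recover the input across several orders of magnitude, exercising the small, large, and middle branches.

import jax
import jax.numpy as jnp

x = jnp.array([-20., -5., 0., 5., 20.])
y = jax.nn.softplus(x)
print(jnp.allclose(inverse_softplus(y), x, atol=1e-4))  # True at float32 precision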
Example 19
 def sample(self, key, sample_shape=()):
     assert is_prng_key(key)
     dtype = jnp.result_type(float)
     finfo = jnp.finfo(dtype)
     minval = finfo.tiny
     u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
     loc = self.base_dist.loc
     sign = jnp.where(loc >= self.low, 1.0, -1.0)
     return (1 - sign) * loc + sign * self.base_dist.icdf(
         (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
     )
Example 20
 def _max_condition_number_to_be_non_singular(self):
     """Return the maximum condition number that we consider nonsingular."""
     with ops.name_scope("max_nonsingular_condition_number"):
         dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps
         eps = _ops.cast(
             math_ops.reduce_max([
                 100.,
                 _ops.cast(self.range_dimension_tensor(), self.dtype),
                 _ops.cast(self.domain_dimension_tensor(), self.dtype)
             ]), self.dtype) * dtype_eps
         return 1. / eps
Example 21
 def _sample_n(self, key: PRNGKey, n: int) -> Array:
     """See `Distribution._sample_n`."""
     out_shape = (n, ) + self.batch_shape
     dtype = jnp.result_type(self._loc, self._scale)
     uniform = jax.random.uniform(key,
                                  shape=out_shape,
                                  dtype=dtype,
                                  minval=jnp.finfo(dtype).tiny,
                                  maxval=1.)
     rnd = jnp.log(uniform) - jnp.log1p(-uniform)
     return self._scale * rnd + self._loc
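The same inverse-CDF trick as a standalone function (a minimal sketch with hypothetical names; the method above folds in location and scale via `self._loc` and `self._scale`):

import jax
import jax.numpy as jnp

def sample_logistic(key, loc, scale, shape):
    dtype = jnp.result_type(float)
    # Lower-bound the uniform draw by `tiny` so that log(u) and log1p(-u) stay finite.
    u = jax.random.uniform(key, shape=shape, dtype=dtype,
                           minval=jnp.finfo(dtype).tiny, maxval=1.)
    # Inverse CDF of the standard logistic: logit(u) = log(u) - log(1 - u).
    return scale * (jnp.log(u) - jnp.log1p(-u)) + loc

samples = sample_logistic(jax.random.PRNGKey(0), loc=0., scale=2., shape=(1000,))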
Example 22
def _check_symmetry(x: jnp.ndarray) -> bool:
    """Check if the array is symmetric."""
    m, n = x.shape
    eps = jnp.finfo(x.dtype).eps
    tol = 50.0 * eps
    is_symmetric = False
    if m == n:
        if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol:
            is_symmetric = True

    return is_symmetric
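Usage sketch (assuming `_check_symmetry` above is in scope; the function relies on both `numpy` and `jax.numpy`):

import numpy as np
import jax.numpy as jnp

a = jnp.arange(9.0).reshape(3, 3)
print(_check_symmetry(a))        # False
print(_check_symmetry(a + a.T))  # True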
Example 23
def _svd_tall_and_square_input(
        a: Any, hermitian: bool, compute_uv: bool,
        max_iterations: int) -> Union[Any, Sequence[Any]]:
    """Singular value decomposition for m x n matrix and m >= n.

  Args:
    a: A matrix of shape `m x n` with `m >= n`.
    hermitian: True if `a` is Hermitian.
    compute_uv: Whether to compute also `u` and `v` in addition to `s`.
    max_iterations: The predefined maximum number of iterations of QDWH.

  Returns:
    A 3-tuple (`u`, `s`, `v`), where `u` is a unitary matrix of shape `m x n`,
    `s` is a vector of length `n` containing the singular values in descending
    order, `v` is a unitary matrix of shape `n x n`, and
    `a = (u * s) @ v.T.conj()`. For `compute_uv=False`, only `s` is returned.
  """

    u, h, _, _ = lax.linalg.qdwh(a,
                                 is_hermitian=hermitian,
                                 max_iterations=max_iterations)

    # TODO: Use `eigvals_only=True` if `compute_uv=False`.
    v, s = lax.linalg.eigh(h)

    # Flips the singular values in descending order.
    s_out = jnp.flip(s)

    if not compute_uv:
        return s_out

    # Reorders eigenvectors.
    v_out = jnp.fliplr(v)

    u_out = u @ v_out

    # Apply a correction if the `u` computed from qdwh is not unitary.
    # Section 5.5 of Nakatsukasa, Yuji, and Nicholas J. Higham. "Stable and
    # efficient spectral divide and conquer algorithms for the symmetric
    # eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing
    # 35, no. 3 (2013): A1325-A1349.
    def correct_rank_deficiency(u_out):
        u_out, r = lax.linalg.qr(u_out, full_matrices=False)
        u_out = u_out @ jnp.diag(lax.sign(jnp.diag(r)))
        return u_out

    eps = float(jnp.finfo(a.dtype).eps)
    u_out = lax.cond(s[0] < a.shape[1] * eps * s_out[0],
                     correct_rank_deficiency,
                     lambda u_out: u_out,
                     operand=(u_out))

    return (u_out, s_out, v_out)
Example 24
  def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng_factory, indexer):
    rng = rng_factory()
    tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
    unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

    @api.jit
    def fun(unpacked_indexer, x):
      indexer = pack_indexer(unpacked_indexer)
      return x[indexer]

    arr = rng(shape, dtype)
    check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol)
Example 25
    def _inverse(self, y):
        # inverse stick-breaking
        remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
        pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)]
        remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0)
        finfo = jnp.finfo(y.dtype)
        remainder = jnp.clip(remainder, a_min=finfo.tiny)
        t = y / remainder

        # inverse of tanh
        t = jnp.clip(t, a_min=-1 + finfo.eps, a_max=1 - finfo.eps)
        return jnp.arctanh(t)
Example 26
 def __call__(self, shape: Sequence[int], dtype: Any) -> jnp.ndarray:
     real_dtype = jnp.finfo(dtype).dtype
     m = jax.lax.convert_element_type(self.mean, dtype)
     s = jax.lax.convert_element_type(self.stddev, real_dtype)
     is_complex = jnp.issubdtype(dtype, jnp.complexfloating)
     if is_complex:
         shape = [2, *shape]
     unscaled = jax.random.truncated_normal(hk.next_rng_key(), -2., 2.,
                                            shape, real_dtype)
     if is_complex:
         unscaled = unscaled[0] + 1j * unscaled[1]
     return s * unscaled + m
Example 27
    def test_no_privacy(self):
        """l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD."""
        dp_agg = privacy.differentially_private_aggregate(
            l2_norm_clip=jnp.finfo(jnp.float32).max,
            noise_multiplier=0.,
            seed=0)
        state = dp_agg.init(self.params)
        update_fn = self.variant(dp_agg.update)
        mean_grads = jax.tree_map(lambda g: g.mean(0), self.per_eg_grads)

        for _ in range(3):
            updates, state = update_fn(self.per_eg_grads, state)
            chex.assert_tree_all_close(updates, mean_grads)
Example 28
  def testRngUniform(self, dtype):
    if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
      raise SkipTest("random.uniform() not supported on TPU for 16-bit types.")
    key = random.PRNGKey(0)
    rand = lambda key: random.uniform(key, (10000,), dtype)
    crand = api.jit(rand)

    uncompiled_samples = rand(key)
    compiled_samples = crand(key)

    for samples in [uncompiled_samples, compiled_samples]:
      self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
      self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
Example 29
def makegauss2D(shape=(3, 3), sigma=0.5):
    """
    2D gaussian mask - should give the same result as MATLAB's
    fspecial('gaussian',[shape],[sigma])
    """
    m, n = [(ss - 1.0) / 2.0 for ss in shape]
    y, x = jnp.meshgrid(jnp.arange(-m, m + 1), jnp.arange(-n, n + 1))
    h = jnp.exp(-(x * x + y * y) / (2.0 * sigma * sigma))
    h = h.at[h < jnp.finfo(h.dtype).eps * h.max()].set(0)
    sumh = h.sum()
    if sumh != 0:
        h /= sumh
    return h
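Usage sketch: the kernel matches the requested shape and is normalized to sum to one.

import jax.numpy as jnp

h = makegauss2D(shape=(5, 5), sigma=1.0)
print(h.shape)         # (5, 5)
print(float(h.sum()))  # ~1.0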
Example 30
    def __call__(self, name, fn, obs):
        assert obs is None, "TransformReparam does not support observe statements"
        shape = fn.shape()
        fn, expand_shape, event_dim = self._unwrap(fn)
        transform = uniform_reparam_transform(fn)
        tiny = jnp.finfo(jnp.result_type(float)).tiny

        x = numpyro.sample(
            "{}_base".format(name),
            dist.Uniform(tiny,
                         1).expand(shape).to_event(event_dim).mask(False),
        )
        # Simulate a numpyro.deterministic() site.
        return None, transform(x)