Example #1
    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model, subkey), substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                if v['type'] == 'sample' and not v['is_observed'] and not v['fn'].is_discrete:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            params = transform_fn(inv_transforms, constrained_values, invert=True)
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(subkey, jnp.shape(v), minval=-radius, maxval=radius)
                    key, subkey = random.split(key)

        potential_fn = partial(potential_energy, model, model_args, model_kwargs, enum=enum)
        pe, z_grad = value_and_grad(potential_fn)(params)
        z_grad_flat = ravel_pytree(z_grad)[0]
        is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        return i + 1, key, (params, pe, z_grad), is_valid
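The finiteness check at the heart of this pattern can be exercised on its own. A minimal self-contained sketch, using a toy quadratic `potential` as a stand-in for the model's actual potential energy:

import jax.numpy as jnp
from jax import value_and_grad
from jax.flatten_util import ravel_pytree

def potential(params):
    # stand-in for the model's potential energy
    return 0.5 * sum(jnp.sum(v ** 2) for v in params.values())

params = {"loc": jnp.zeros(3), "scale": jnp.ones(2)}
pe, z_grad = value_and_grad(potential)(params)
z_grad_flat, _ = ravel_pytree(z_grad)
is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))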
Example #2
    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        # Use `block` so that sample primitives in `init_loc_fn` are not recorded.
        seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed']:
                if v['intermediates']:
                    constrained_values[k] = v['intermediates'][0][0]
                    inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                else:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            elif v['type'] == 'param' and param_as_improper:
                constraint = v['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    base_transform = transform.parts[0]
                    inv_transforms[k] = base_transform
                    constrained_values[k] = base_transform(transform.inv(v['value']))
                else:
                    inv_transforms[k] = transform
                    constrained_values[k] = v['value']
        params = transform_fn(inv_transforms, constrained_values, invert=True)
        # NB: functools.partial; the old `jax.partial` alias was removed from JAX
        potential_fn = partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
        pe, param_grads = value_and_grad(potential_fn)(params)
        z_grad = ravel_pytree(param_grads)[0]
        is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
        return i + 1, key, params, is_valid
Example #3
  def testScaleAndTranslateGradFinite(self, antialias):
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    data = [
        51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92,
        41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89,
        71, 32, 23, 23, 35, 93
    ]

    x = jnp.array(data, dtype=jnp.float32).reshape(image_shape)
    scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
    translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

    def scale_fn(s):
      return jnp.sum(jax.image.scale_and_translate(
        x, target_shape, (0, 1, 2, 3), s, translation_a, "linear", antialias,
        precision=jax.lax.Precision.HIGHEST))

    scale_out = jax.grad(scale_fn)(scale_a)
    self.assertTrue(jnp.all(jnp.isfinite(scale_out)))

    def translate_fn(t):
      return jnp.sum(jax.image.scale_and_translate(
        x, target_shape, (0, 1, 2, 3), scale_a, t, "linear", antialias,
        precision=jax.lax.Precision.HIGHEST))

    translate_out = jax.grad(translate_fn)(translation_a)
    self.assertTrue(jnp.all(jnp.isfinite(translate_out)))
Example #4
        def _param_idx_to_str(idx: int) -> str:
            param = self.x[idx]

            if self.sig_x is None:
                sig = None
            else:
                sig = self.sig_x[idx]

            if self.bounds is None:
                low = None
                upp = None
            else:
                low = self.bounds[idx, 0]
                upp = self.bounds[idx, 1]

            params_str = f"  {param}"
            if sig is not None:
                params_str += f" +/- {sig}"
            params_str += ","
            if low is not None:
                if jnp.isfinite(low):
                    params_str += f"\t [Lower Bound = {low}]"
            if upp is not None:
                if jnp.isfinite(upp):
                    params_str += f"\t [Upper Bound = {upp}]"

            return params_str
Example #5
    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # XXX: we don't want to apply enum to draw latent samples
            model_ = model
            if enum:
                from numpyro.contrib.funsor import enum as enum_handler

                if isinstance(model, substitute) and isinstance(model.fn, enum_handler):
                    model_ = substitute(model.fn.fn, data=model.data)
                elif isinstance(model, enum_handler):
                    model_ = model.fn

            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model_, subkey), substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                if (
                    v["type"] == "sample"
                    and not v["is_observed"]
                    and not v["fn"].is_discrete
                ):
                    constrained_values[k] = v["value"]
                    inv_transforms[k] = biject_to(v["fn"].support)
            params = transform_fn(inv_transforms, constrained_values, invert=True)
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(
                        subkey, jnp.shape(v), minval=-radius, maxval=radius
                    )
                    key, subkey = random.split(key)

        potential_fn = partial(
            potential_energy, model, model_args, model_kwargs, enum=enum
        )
        if validate_grad:
            if forward_mode_differentiation:
                pe = potential_fn(params)
                z_grad = jacfwd(potential_fn)(params)
            else:
                pe, z_grad = value_and_grad(potential_fn)(params)
            z_grad_flat = ravel_pytree(z_grad)[0]
            is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        else:
            pe = potential_fn(params)
            is_valid = jnp.isfinite(pe)
            z_grad = None

        return i + 1, key, (params, pe, z_grad), is_valid
Example #6
  def testDerivativeIsMonotonicWrtX(self):
    # Check that the loss increases monotonically with |x|.
    _, _, x, alpha, _, d_x, _, _ = self._precompute_lossfun_inputs()
    # This is just to suppress a warning below.
    d_x = jnp.where(jnp.isfinite(d_x), d_x, jnp.zeros_like(d_x))
    mask = jnp.isfinite(alpha) & (jnp.abs(d_x) >
                                  (300. * jnp.finfo(jnp.float32).eps))
    chex.assert_tree_all_close(jnp.sign(d_x[mask]), jnp.sign(x[mask]))
Example #7
  def keep_step(grad_norm):
    keep_threshold = p.skip_step_gradient_norm_value
    if keep_threshold:
      return jnp.logical_and(
          jnp.all(jnp.isfinite(grad_norm)),
          jnp.all(jnp.less(grad_norm, keep_threshold)))
    else:
      return jnp.all(jnp.isfinite(grad_norm))
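A predicate like `keep_step` typically gates the actual parameter update. A minimal sketch of that gating with `lax.cond`; the threshold, learning rate, and pytree layout here are illustrative assumptions, not part of the snippet above:

import jax
import jax.numpy as jnp

def maybe_apply(params, grads, keep_threshold=10.0, lr=1e-2):
    # global gradient norm across all leaves
    leaves = jax.tree_util.tree_leaves(grads)
    grad_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    keep = jnp.isfinite(grad_norm) & (grad_norm < keep_threshold)
    return jax.lax.cond(
        keep,
        lambda _: jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads),
        lambda _: params,  # skip the step on a non-finite or oversized norm
        operand=None)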
Example #8
def test_mean_var(jax_dist, sp_dist, params):
    n = 20000 if jax_dist in [dist.LKJ, dist.LKJCholesky] else 200000
    d_jax = jax_dist(*params)
    k = random.PRNGKey(0)
    samples = d_jax.sample(k, sample_shape=(n,))
    # check with suitable scipy implementation if available
    if sp_dist and not _is_batched_multivariate(d_jax):
        d_sp = sp_dist(*params)
        try:
            sp_mean = d_sp.mean()
        except TypeError:  # mvn does not have .mean() method
            sp_mean = d_sp.mean
        # for multivariate distns try .cov first
        if d_jax.event_shape:
            try:
                sp_var = np.diag(d_sp.cov())
            except TypeError:  # mvn does not have .cov() method
                sp_var = np.diag(d_sp.cov)
            except AttributeError:
                sp_var = d_sp.var()
        else:
            sp_var = d_sp.var()
        assert_allclose(d_jax.mean, sp_mean, rtol=0.01, atol=1e-7)
        assert_allclose(d_jax.variance, sp_var, rtol=0.01, atol=1e-7)
        if np.all(np.isfinite(sp_mean)):
            assert_allclose(np.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if np.all(np.isfinite(sp_var)):
            assert_allclose(np.std(samples, 0), np.sqrt(d_jax.variance), rtol=0.05, atol=1e-2)
    elif jax_dist in [dist.LKJ, dist.LKJCholesky]:
        if jax_dist is dist.LKJCholesky:
            corr_samples = np.matmul(samples, np.swapaxes(samples, -2, -1))
        else:
            corr_samples = samples
        dimension, concentration, _ = params
        # marginal of off-diagonal entries
        marginal = dist.Beta(concentration + 0.5 * (dimension - 2),
                             concentration + 0.5 * (dimension - 2))
        # scale statistics due to linear mapping
        marginal_mean = 2 * marginal.mean - 1
        marginal_std = 2 * np.sqrt(marginal.variance)
        expected_mean = np.broadcast_to(np.reshape(marginal_mean, np.shape(marginal_mean) + (1, 1)),
                                        np.shape(marginal_mean) + d_jax.event_shape)
        expected_std = np.broadcast_to(np.reshape(marginal_std, np.shape(marginal_std) + (1, 1)),
                                       np.shape(marginal_std) + d_jax.event_shape)
        # diagonal elements of correlation matrices are 1
        expected_mean = expected_mean * (1 - np.identity(dimension)) + np.identity(dimension)
        expected_std = expected_std * (1 - np.identity(dimension))

        assert_allclose(np.mean(corr_samples, axis=0), expected_mean, atol=0.01)
        assert_allclose(np.std(corr_samples, axis=0), expected_std, atol=0.01)
    else:
        if np.all(np.isfinite(d_jax.mean)):
            assert_allclose(np.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if np.all(np.isfinite(d_jax.variance)):
            assert_allclose(np.std(samples, 0), np.sqrt(d_jax.variance), rtol=0.05, atol=1e-2)
Example #9
  def body_fn(iteration, const, state, compute_error):
    """Carries out sinkhorn iteration.

    Depending on lse_mode, these iterations can be either in:
      - log-space for numerical stability.
      - scaling space, using standard kernel-vector multiply operations.

    Args:
      iteration: iteration number
      const: tuple of constant parameters that do not change throughout the
        loop, here the geometry and the marginals a, b.
      state: potential/scaling variables updated in the loop & error log.
      compute_error: flag to indicate this iteration computes/stores an error

    Returns:
      state variables, i.e. errors and updated f_u, g_v potentials.
    """
    geom, a, b, _ = const
    errors, f_u, g_v = state

    # compute momentum term if needed, using previously seen errors.
    w = jax.lax.stop_gradient(jnp.where(
        iteration >= inner_iterations * chg_momentum_from + min_iterations,
        get_momentum(errors, chg_momentum_from),
        momentum_default))

    # Sinkhorn updates using momentum, in either scaling or potential form.
    if parallel_dual_updates:
      old_g_v = g_v
    if lse_mode:
      new_g_v = tau_b * geom.update_potential(f_u, g_v, jnp.log(b),
                                              iteration, axis=0)
      g_v = (1.0 - w) * jnp.where(jnp.isfinite(g_v), g_v, 0.0) + w * new_g_v
      new_f_u = tau_a * geom.update_potential(
          f_u, old_g_v if parallel_dual_updates else g_v,
          jnp.log(a), iteration, axis=1)
      f_u = (1.0 - w) * jnp.where(jnp.isfinite(f_u), f_u, 0.0) + w * new_f_u
    else:
      new_g_v = geom.update_scaling(f_u, b, iteration, axis=0) ** tau_b
      g_v = jnp.where(g_v > 0, g_v, 1) ** (1.0 - w) * new_g_v ** w
      new_f_u = geom.update_scaling(
          old_g_v if parallel_dual_updates else g_v,
          a, iteration, axis=1) ** tau_a
      f_u = jnp.where(f_u > 0, f_u, 1) ** (1.0 - w) * new_f_u ** w

    # re-computes error if compute_error is True, else set it to inf.
    err = jnp.where(
        jnp.logical_and(compute_error, iteration >= min_iterations),
        marginal_error(geom, a, b, tau_a, tau_b, f_u, g_v, norm_error,
                       lse_mode),
        jnp.inf)

    errors = jax.ops.index_update(
        errors, jax.ops.index[iteration // inner_iterations, :], err)
    return errors, f_u, g_v
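Note: `jax.ops.index_update` and `jax.ops.index` were removed in later JAX releases; in current JAX the error-log update above would be written with the `.at` syntax:

# modern equivalent of the jax.ops.index_update call above
errors = errors.at[iteration // inner_iterations, :].set(err)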
Example #10
def logmarglike_lineargaussianmodel_onetransfer(M_T,
                                                y,
                                                yinvvar,
                                                logyinvvar=None):
    """
    Fit linear model to one Gaussian data set, with no (=uniform) prior on the linear components.

    Parameters
    ----------
    y, yinvvar, logyinvvar : ndarray (n_pix_y)
        data and data inverse variances.
        Zeros will be ignored.
    M_T : ndarray (n_components, n_pix_y)
        design matrix of linear model

    Returns
    -------
    logfml : ndarray scalar
        log marginal likelihood of the data, with the linear parameters marginalised out
    theta_map : ndarray (n_components)
        Best fit MAP parameters
    theta_cov : ndarray (n_components, n_components)
        Parameter covariance

    """
    # assert y.shape[-2] == yinvvar.shape[-2]
    assert y.shape[-1] == yinvvar.shape[-1]
    # assert y.shape[-1] == 1
    assert M_T.shape[-1] == yinvvar.shape[-1]
    assert np.all(np.isfinite(yinvvar))  # all finite
    assert np.all(np.isfinite(y))  # all finite
    assert np.all(np.isfinite(M_T))  # all finite
    assert np.count_nonzero(yinvvar) > 2  # more than two valid (non-zero) pixels

    log2pi = np.log(2.0 * np.pi)
    nt = np.shape(M_T)[-2]
    ny = np.count_nonzero(yinvvar)
    M = np.transpose(M_T)  # (n_pix_y, n_components)
    Myinv = M * yinvvar[:, None]  # (n_pix_y, n_components)
    Hbar = np.matmul(M_T, Myinv)  #  (n_components, n_components)
    etabar = np.sum(Myinv * y[:, None], axis=0)  # (n_components)
    theta_map = np.linalg.solve(Hbar, etabar)  # (n_components)
    theta_cov = np.linalg.inv(Hbar)  # (n_components, n_components)
    if logyinvvar is None:
        logyinvvar = np.where(yinvvar == 0, 0, np.log(yinvvar))
    logdetH = np.sum(logyinvvar)  # scalar
    xi1 = -0.5 * (ny * log2pi - logdetH + np.sum(y * y * yinvvar))  # scalar
    sign, logdetHbar = np.linalg.slogdet(Hbar)
    xi2 = -0.5 * (nt * log2pi - logdetHbar + np.sum(etabar * theta_map))
    logfml = xi1 - xi2
    return logfml, theta_map, theta_cov
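A quick synthetic sanity check of the function above; the shapes and data are invented for illustration, and with noiseless data and uniform inverse variances the MAP estimate should recover the true coefficients:

import numpy as np

rng = np.random.default_rng(0)
M_T = rng.normal(size=(2, 10))      # (n_components, n_pix_y)
theta_true = np.array([1.5, -0.5])
y = theta_true @ M_T                # noiseless data
yinvvar = np.ones(10)               # unit inverse variances

logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_onetransfer(M_T, y, yinvvar)
assert np.allclose(theta_map, theta_true, atol=1e-4)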
Example #11
def assert_tree_all_finite(tree_like: ArrayTree):
  """Assert all tensor leaves in a tree are finite.

  Args:
    tree_like: pytree with array leaves

  Raises:
    AssertionError: if any leaf in the tree is non-finite.
  """
  all_finite = jax.tree_util.tree_all(
      jax.tree_map(lambda x: jnp.all(jnp.isfinite(x)), tree_like))
  if not all_finite:
    is_finite = lambda x: "Finite" if jnp.all(jnp.isfinite(x)) else "Nonfinite"
    error_msg = jax.tree_map(is_finite, tree_like)
    raise AssertionError(f"Tree contains non-finite value: {error_msg}.")
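A usage sketch, assuming `assert_tree_all_finite` from above is in scope: it passes silently on an all-finite pytree and raises on one containing a NaN:

import jax.numpy as jnp

assert_tree_all_finite({"w": jnp.ones(3), "b": jnp.zeros(2)})  # passes silently
try:
    assert_tree_all_finite({"w": jnp.array([1.0, jnp.nan])})
except AssertionError as err:
    print(err)  # Tree contains non-finite value: {'w': 'Nonfinite'}.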
Example #12
  def testDerivativeIsBoundedWhenAlphaIsBelow1(self):
    # Assert that |d_x| < 1/scale when alpha <= 1.
    _, _, _, alpha, scale, d_x, _, _ = self._precompute_lossfun_inputs()
    mask = jnp.isfinite(alpha) & (alpha <= 1)
    grad = jnp.abs(d_x[mask])
    bound = ((1. + (300. * jnp.finfo(jnp.float32).eps)) / scale[mask])
    self.assertTrue(jnp.all(grad <= bound))
Example #13
    def sample_kernel(sa_state, model_args=(), model_kwargs=None):
        pe_fn = potential_fn
        if potential_fn_gen:
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        zs, pes, loc, scale = sa_state.adapt_state
        # we recompute loc/scale after each iteration to avoid precision loss
        # XXX: consider exposing a setting to do this periodically
        # to save some computation
        loc = jnp.mean(zs, 0)
        if scale.ndim == 2:
            cov = jnp.cov(zs, rowvar=False, bias=True)
            if cov.shape == ():  # JAX returns scalar for 1D input
                cov = cov.reshape((1, 1))
            cholesky = jnp.linalg.cholesky(cov)
            scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
        else:
            scale = jnp.std(zs, 0)

        rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(sa_state.rng_key, 4)
        _, unravel_fn = ravel_pytree(sa_state.z)

        z = loc + _sample_proposal(scale, rng_key_z)
        pe = pe_fn(unravel_fn(z))
        pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
        diverging = (pe - sa_state.potential_energy) > max_delta_energy

        # NB: all terms having the pattern *s will have shape N x ...
        # and all terms having the pattern *s_ will have shape (N + 1) x ...
        locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
        zs_ = jnp.concatenate([zs, z[None, :]])
        pes_ = jnp.concatenate([pes, pe[None]])
        locs_ = jnp.concatenate([locs, loc[None, :]])
        scales_ = jnp.concatenate([scales, scale[None, ...]])
        if scale.ndim == 2:  # dense_mass
            log_weights_ = dist.MultivariateNormal(locs_, scale_tril=scales_).log_prob(zs_) + pes_
        else:
            log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
        # mask invalid values (nan, +inf) by -inf
        log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_, -jnp.inf)
        # get rejecting index
        j = random.categorical(rng_key_reject, log_weights_)
        zs = _numpy_delete(zs_, j)
        pes = _numpy_delete(pes_, j)
        loc = locs_[j]
        scale = scales_[j]
        adapt_state = SAAdaptState(zs, pes, loc, scale)

        # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
        itr = sa_state.i + 1
        n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n

        # XXX: we make a modification to the SA sampler from [1]:
        # in [1], each MCMC state contains N points `zs`;
        # here we resample to randomly pick one of those N points
        k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
        z = unravel_fn(zs[k])
        pe = pes[k]
        return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)
Example #14
    def sample(self, data=None, name=None, shape=None, obs=None):
        '''Sample responses'''
        name = name or self.name

        if data is None:
            X = self.X  # use same data used to create model
        else:
            info = self.X.design_info  # information from original data
            X = patsy.dmatrix(info, data)  # design matrix for new data

        linpred = np.array(X) @ self.theta

        if shape is not None:
            linpred = linpred.reshape(shape)  # reshape to tensor if requested

        fwd, inv = self.link()

        if self.guess is None:
            mu = inv(linpred)
        else:
            fwd_guess = fwd(self.guess)
            if not np.isfinite(fwd_guess):
                raise ValueError("Bad Guess")
            mu = inv(fwd_guess + linpred)

        y = numpyro.sample(name, self.family(mu), obs=obs)

        return y, mu, linpred
Example #15
def test_elbo_dynamic_support():
    x_prior = dist.TransformedDistribution(
        dist.Normal(),
        [AffineTransform(0, 2),
         SigmoidTransform(),
         AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    adam = optim.Adam(0.01)
    # set the base value of x_guide to 0.9
    x_base = 0.9
    guide = substitute(guide, base_param_map={'x': x_base})
    svi = SVI(model, guide, elbo, adam)
    svi_state = svi.init(random.PRNGKey(0), (), ())
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)
    x, _ = x_guide.transform_with_intermediates(x_base)
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)
Example #16
    def update(updates, state, params=None):
        inner_state = state.inner_state
        flat_updates = tree_flatten(updates)[0]
        isfinite = jnp.all(
            jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
        notfinite_count = jnp.where(isfinite, jnp.zeros([], jnp.int64),
                                    1 + state.notfinite_count)

        def do_update(_):
            return inner.update(updates, inner_state, params)

        def reject_update(_):
            return (tree_map(jnp.zeros_like, updates), inner_state)

        updates, new_inner_state = lax.cond(
            jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors),
            do_update,
            reject_update,
            operand=None)

        return updates, ApplyIfFiniteState(
            notfinite_count=notfinite_count,
            last_finite=isfinite,
            total_notfinite=jnp.logical_not(isfinite) + state.total_notfinite,
            inner_state=new_inner_state)
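This matches Optax's `apply_if_finite` wrapper (note the `ApplyIfFiniteState` it returns). A minimal usage sketch; the inner optimizer and the error budget are arbitrary choices:

import jax.numpy as jnp
import optax

opt = optax.apply_if_finite(optax.sgd(1e-2), max_consecutive_errors=5)
params = {"w": jnp.ones(3)}
state = opt.init(params)
grads = {"w": jnp.array([0.1, jnp.nan, 0.2])}  # a non-finite gradient
updates, state = opt.update(grads, state, params)
# `updates` is all zeros here, so applying it leaves the parameters unchanged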
Example #17
def _observe_binom_approx(name, latent, det_rate, det_conc, obs=None):
    '''Make observations of a latent variable using BinomialApprox.'''

    mask = True

    # Regularization: add reg to observed, and (reg/det_rate) to latent
    # The primary purpose is to avoid zeros, which are invalid values for
    # the Beta observation model.
    reg = 0.5
    latent = latent + (reg / det_rate)

    if obs is not None:
        '''
        Workaround for a jax issue: substitute default values
        AND mask out bad observations. 
        
        See https://forum.pyro.ai/t/behavior-of-mask-handler-with-invalid-observation-possible-bug/1719/5
        '''
        mask = np.isfinite(obs)
        obs = np.where(mask, obs, 0.5 * latent)
        obs = obs + reg

    det_rate = np.broadcast_to(det_rate, latent.shape)
    det_conc = np.minimum(
        det_conc,
        latent)  # don't allow it to be *more* concentrated than Binomial

    d = BinomialApprox(latent, det_rate, det_conc)  # latent was already regularized above

    with numpyro.handlers.mask(mask_array=mask):
        y = numpyro.sample(name, d, obs=obs)

    return y
Example #18
def tmrca_sf(t: np.ndarray, y: np.ndarray, n: int) -> np.ndarray:
    """The survival function of the TMRCA at each time point

    Args:
        t: time grid (including zero and infinity)
        y: effective population size in each epoch
        n: number of sampled haplotypes

    """
    # epoch durations
    s = np.diff(t)
    logu = -s / y
    logu = np.concatenate((np.array([0]), logu))
    # the A_2j are the product of this matrix
    # NOTE: using letter  "l" as a variable name to match text
    l = onp.arange(2, n + 1)[:, onp.newaxis]  # noqa: E741
    with onp.errstate(divide='ignore'):
        A2_terms = l * (l - 1) / (l * (l - 1) - l.T * (l.T - 1))
    onp.fill_diagonal(A2_terms, 1)
    A2 = np.prod(A2_terms, axis=0)

    binom_vec = l * (l - 1) / 2

    result = np.zeros(len(t))
    result = index_update(result, index[:-1],
                          np.squeeze(A2[np.newaxis, :]
                                     @ np.exp(np.cumsum(logu[np.newaxis, :-1],
                                                        axis=1)) ** binom_vec))

    assert np.all(np.isfinite(result))

    return result
Example #19
def get_initial_state(system,
                      rng,
                      generate_x_obs_seq_init,
                      dim_q,
                      tol,
                      adam_step_size=2e-1,
                      reg_coeff=5e-2,
                      coarse_tol=1e-1,
                      max_iters=1000,
                      max_num_tries=10):
    """Find an initial constraint satisying state.

    Uses a heuristic combination of gradient-based minimisation of the norm
    of a modified constraint function plus a subsequent projection step using a
    quasi-Newton method, to try to find an initial point `q` such that
    `max(abs(constr(q)) < tol`.
    """

    # Use optimizers to set optimizer initialization and update functions
    opt_init, opt_update, get_params = opt.adam(adam_step_size)

    # Define a compiled update step
    @api.jit
    def step(i, opt_state, x_obs_seq_init):
        q, = get_params(opt_state)
        (obj, constr), grad = system.value_and_grad_init_objective(
            q, x_obs_seq_init, reg_coeff)
        opt_state = opt_update(i, grad, opt_state)
        return opt_state, obj, constr

    for t in range(max_num_tries):
        logging.info(f'Starting try {t+1}')
        q_init = rng.standard_normal(dim_q)
        x_obs_seq_init = generate_x_obs_seq_init(rng)
        opt_state = opt_init((q_init, ))
        for i in range(max_iters):
            opt_state_next, norm, constr = step(i, opt_state, x_obs_seq_init)
            if not np.isfinite(norm):
                logging.info('Adam iteration diverged')
                break
            max_abs_constr = maximum_norm(constr)
            if max_abs_constr < coarse_tol:
                logging.info('Within coarse_tol attempting projection.')
                q_init, = get_params(opt_state)
                state = ConditionedDiffusionHamiltonianState(
                    q_init, x_obs_seq=x_obs_seq_init)
                try:
                    state = jitted_solve_projection_onto_manifold_quasi_newton(
                        state, state, 1., system, tol)
                except ConvergenceError:
                    logging.info('Quasi-Newton iteration diverged.')
                if np.max(np.abs(system.constr(state))) < tol:
                    logging.info('Found constraint satisfying state.')
                    state.mom = system.sample_momentum(state, rng)
                    return state
            if i % 100 == 0:
                logging.info(f'Iteration {i: >6}: mean|constr|^2 = {norm:.3e} '
                             f'max|constr| = {max_abs_constr:.3e}')
            opt_state = opt_state_next
    raise RuntimeError(f'Did not find valid state in {max_num_tries} tries.')
Example #20
    def body(state):
        p_k = -(state.H_k @ state.g_k)
        line_search_results = line_search(value_and_grad, state.x_k, p_k, old_fval=state.f_k, gfk=state.g_k,
                                          maxiter=ls_maxiter)
        state = state._replace(nfev=state.nfev + line_search_results.nfev,
                               ngev=state.ngev + line_search_results.ngev,
                               failed=line_search_results.failed,
                               ls_status=line_search_results.status)
        s_k = line_search_results.a_k * p_k
        x_kp1 = state.x_k + s_k
        f_kp1 = line_search_results.f_k
        g_kp1 = line_search_results.g_k
        y_k = g_kp1 - state.g_k
        rho_k = jnp.reciprocal(y_k @ s_k)

        sy_k = s_k[:, None] * y_k[None, :]
        w = jnp.eye(d) - rho_k * sy_k
        H_kp1 = jnp.where(
            jnp.isfinite(rho_k),
            jnp.linalg.multi_dot([w, state.H_k, w.T]) + rho_k * s_k[:, None] * s_k[None, :],
            state.H_k)

        converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

        state = state._replace(converged=converged,
                               k=state.k + 1,
                               x_k=x_kp1,
                               f_k=f_kp1,
                               g_k=g_kp1,
                               H_k=H_kp1
                               )

        return state
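The `jnp.isfinite(rho_k)` guard above matters when the curvature term `y_k @ s_k` is zero: `rho_k` becomes infinite and the BFGS update would fill `H_k` with non-finite entries. A tiny illustration of the guard in isolation:

import jax.numpy as jnp

rho_k = jnp.reciprocal(jnp.array(0.0))  # inf when y_k @ s_k == 0
H_k = jnp.eye(2)
candidate = jnp.full((2, 2), jnp.nan)   # stand-in for the poisoned update
H_kp1 = jnp.where(jnp.isfinite(rho_k), candidate, H_k)
# H_kp1 == H_k: the previous (finite) approximation is kept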
Example #21
  def cond_fn(iteration, const, state):
    threshold = const[-1]
    errors = state[0]
    err = errors[iteration // inner_iterations - 1, 0]

    return jnp.logical_or(iteration == 0,
                          jnp.logical_and(jnp.isfinite(err), err > threshold))
Example #22
    def test_discrete_barycenter_grid(self, lse_mode, debiased, epsilon):
        """Tests the discrete barycenters on a 5x5x5 grid.

        Puts two masses on opposing ends of the hypercube with small noise in
        between. Checks that their W barycenter sits (mostly) at the middle of
        the hypercube (i.e. index (5x5x5-1)/2).

        Args:
          lse_mode: bool, lse or scaling computations.
          debiased: bool, use (or not) debiasing as proposed in
            https://arxiv.org/abs/2006.02575
          epsilon: float, regularization parameter
        """
        size = jnp.array([5, 5, 5])
        grid_3d = grid.Grid(grid_size=size, epsilon=epsilon)
        a = jnp.ones(size)
        b = jnp.ones(size)
        a = a.ravel()
        b = b.ravel()
        a = jax.ops.index_update(a, 0, 10000)
        b = jax.ops.index_update(b, -1, 10000)
        a = a / jnp.sum(a)
        b = b / jnp.sum(b)
        threshold = 1e-2
        _, _, bar, errors = db.discrete_barycenter(grid_3d,
                                                   a=jnp.stack((a, b)),
                                                   threshold=threshold,
                                                   lse_mode=lse_mode,
                                                   debiased=debiased)
        self.assertGreater(bar[(jnp.prod(size) - 1) // 2], 0.7)
        self.assertGreater(1, bar[(jnp.prod(size) - 1) // 2])
        err = errors[jnp.isfinite(errors)][-1]
        self.assertGreater(threshold, err)
Example #23
def _transform(x: DeviceArray, bounds: Optional[DeviceArray]) -> DeviceArray:
    if bounds is None:
        return x

    low = bounds[:, 0]
    upp = bounds[:, 1]

    return jnp.where(
        jnp.isfinite(low) & jnp.isfinite(upp),
        _between(x, low, upp),
        jnp.where(
            jnp.isfinite(low),
            _greater_than(x, low),
            jnp.where(jnp.isfinite(upp), _less_than(x, upp), x),
        ),
    )
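The helpers `_between`, `_greater_than`, and `_less_than` are not shown in this example. The sketch below is one plausible set of smooth reparameterizations, stated purely as an assumption; the original definitions may differ:

import jax

def _between(x, low, upp):
    # squash an unconstrained x into the open interval (low, upp)
    return low + (upp - low) * jax.nn.sigmoid(x)

def _greater_than(x, low):
    # map an unconstrained x onto (low, inf)
    return low + jax.nn.softplus(x)

def _less_than(x, upp):
    # map an unconstrained x onto (-inf, upp)
    return upp - jax.nn.softplus(x)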
Example #24
def test_elbo_dynamic_support():
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {'_auto_latent': x_unconstrained})(*args, **kwargs)

    adam = optim.Adam(0.01)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    guide_log_prob = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_unconstrained).sum()
    transform = transforms.biject_to(constraints.interval(0, 5))
    x = transform(x_unconstrained)
    logdet = transform.log_abs_det_jacobian(x_unconstrained, x)
    model_log_prob = x_prior.log_prob(x) + logdet
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #25
    def test_near_singular_inverse(self, jit):
        rng = jtu.rand_default(self.rng())

        @partial(_maybe_jit, jit, static_argnums=1)
        def near_singular_inverse(N=5, eps=1E-40):
            X = rng((N, N), dtype='float64')
            X = jnp.asarray(X)
            X = X.at[-1].mul(eps)
            return jnp.linalg.inv(X)

        with enable_x64():
            result_64 = near_singular_inverse()
            self.assertTrue(jnp.all(jnp.isfinite(result_64)))

        with disable_x64():
            result_32 = near_singular_inverse()
            self.assertTrue(jnp.all(~jnp.isfinite(result_32)))
Example #26
def _max_mask_non_finite(x, axis=-1, keepdims=False, mask=0):
    """Returns `max` or `mask` if `max` is not finite."""
    m = np.max(x, axis=_astuple(axis), keepdims=keepdims)
    needs_masking = ~np.isfinite(m)
    if needs_masking.ndim > 0:
        m = np.where(needs_masking, mask, m)
    elif needs_masking:
        m = mask
    return m
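A usage sketch with NumPy. The `_astuple` helper referenced above is not shown, so a minimal stand-in is assumed here to make the snippet runnable:

import numpy as np

def _astuple(axis):
    # assumed stand-in for the helper referenced above
    return axis if isinstance(axis, tuple) else (axis,)

x = np.array([[1.0, 2.0],
              [np.inf, np.inf]])
print(_max_mask_non_finite(x, axis=-1))           # [2. 0.]: non-finite row masked
print(_max_mask_non_finite(x, axis=-1, mask=-1))  # [ 2. -1.]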
Example #27
    def eval_and_stable_update(self, fn: Callable,
                               state: _IterOptState) -> _IterOptState:
        """
        Like :meth:`eval_and_update` but when the value of the objective function
        or the gradients are not finite, we will not update the input `state`
        and will set the objective output to `nan`.

        :param fn: objective function.
        :param state: current optimizer state.
        :return: a pair of the output of objective function and the new optimizer state.
        """
        params = self.get_params(state)
        out, grads = value_and_grad(fn)(params)
        out, state = lax.cond(
            jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(),
            lambda _: (out, self.update(grads, state)),
            lambda _: (jnp.nan, state),
            None)
        return out, state
Example #28
def categorical_sample(key, probs):
  """Sample from a set of discrete probabilities."""
  probs = probs / probs.sum(axis=-1, keepdims=True)
  is_valid = jnp.logical_and(jnp.all(jnp.isfinite(probs)), jnp.all(probs >= 0))
  cpi = jnp.cumsum(probs, axis=-1)
  eps = jnp.finfo(probs.dtype).eps
  rnds = jax.random.uniform(
      key=key, shape=probs.shape[:-1] + (1,), dtype=probs.dtype, minval=eps)
  argmin = jnp.argmin(jnp.logical_or(rnds > cpi, probs < eps), axis=-1)
  return jnp.where(is_valid, argmin, -1)
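A usage sketch of `categorical_sample`: valid probabilities yield a category index, while non-finite (or negative) probabilities yield -1:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
print(categorical_sample(key, jnp.array([0.2, 0.5, 0.3])))      # an index in {0, 1, 2}
print(categorical_sample(key, jnp.array([0.2, jnp.nan, 0.3])))  # -1 (invalid input)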
Example #29
def log_posterior(theta):
    log_prior_val = uniform_log_prior(theta)
    if np.isfinite(log_prior_val):
        preds = get_predictions(theta)
        log_lik_val = np.sum(
            log_likelihood_as_fxn_of_prediction(preds, expt_means,
                                                expt_uncertainties))
        return log_prior_val + log_lik_val
    else:
        return -np.inf
Example #30
  def test_near_singular_inverse(self, jit):
    if jtu.device_under_test() == "tpu":
      self.skipTest("64-bit inverse not available on TPU")
    @partial(_maybe_jit, jit, static_argnums=1)
    def near_singular_inverse(key, N, eps):
      X = random.uniform(key, (N, N))
      X = X.at[-1].mul(eps)
      return jnp.linalg.inv(X)

    key = random.PRNGKey(1701)
    eps = 1E-40
    N = 5

    with enable_x64():
      result_64 = near_singular_inverse(key, N, eps)
      self.assertTrue(jnp.all(jnp.isfinite(result_64)))

    with disable_x64():
      result_32 = near_singular_inverse(key, N, eps)
      self.assertTrue(jnp.all(~jnp.isfinite(result_32)))