Example #1
    def _resample(self, rng):
        """(Re)Samples posterior/marginals given past test results.

    Args:
      rng: random key

    Produces and examines first the marginal produced by LBP.
    If that marginal is not valid because LBP did not converge,
    or that posterior samples are needed in the next iteration of the
    simulator by a group selector, we compute both marginals
    and posterior  using the more expensive sampler.
    """
        # reset marginals
        self.state.marginals = []
        # compute marginal using a cheap LBP sampler.
        lbp_sampler = self._samplers[0]
        lbp_sampler.produce_sample(rng, self.state)
        # if marginal is valid (i.e. LBP has converged), append it to state.
        if not np.any(np.isnan(lbp_sampler.marginal)):
            self.state.marginals.append(lbp_sampler.marginal)
            self.state.update_particles(lbp_sampler)
        # if marginal has not converged, or expensive sampler is needed, use it.
        if (np.any(np.isnan(lbp_sampler.marginal))
                or (self._policy.next_selector.NEEDS_POSTERIOR
                    and self.state.extra_tests_needed > 0 and
                    (self.state.curr_cycle == 0
                     or self.state.curr_cycle < self._max_test_cycles - 1))):
            sampler = self._samplers[1]
            sampler.produce_sample(rng, self.state)
            self.state.marginals.append(sampler.marginal)
            self.state.update_particles(sampler)
Example #2
    def test_gradient_sinkhorn_euclidean(self, lse_mode, momentum_strategy):
        """Test gradient w.r.t. locations x of reg-ot-cost."""
        d = 3
        n = 10
        m = 15
        keys = jax.random.split(self.rng, 4)
        x = jax.random.normal(keys[0], (n, d)) / 10
        y = jax.random.normal(keys[1], (m, d)) / 10

        a = jax.random.uniform(keys[2], (n, ))
        b = jax.random.uniform(keys[3], (m, ))
        # Adding zero weights to test proper handling
        # (a.at[i].set(v) replaces the deprecated jax.ops.index_update).
        a = a.at[0].set(0)
        b = b.at[3].set(0)
        a = a / jnp.sum(a)
        b = b / jnp.sum(b)

        def loss_fn(x, y):
            geom = pointcloud.PointCloud(x, y, epsilon=0.01)
            f, g, regularized_transport_cost, _, _ = sinkhorn.sinkhorn(
                geom,
                a,
                b,
                momentum_strategy=momentum_strategy,
                lse_mode=lse_mode)
            return regularized_transport_cost, (geom, f, g)

        delta = jax.random.normal(keys[0], (n, d))
        delta = delta / jnp.sqrt(jnp.vdot(delta, delta))
        eps = 1e-5  # perturbation magnitude

        # first calculation of gradient
        loss_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
        (loss_value, aux), grad_loss = loss_and_grad(x, y)
        custom_grad = jnp.sum(delta * grad_loss)
        self.assertFalse(jnp.isnan(loss_value))
        self.assertEqual(grad_loss.shape, x.shape)
        self.assertFalse(jnp.any(jnp.isnan(grad_loss)))

        # second calculation of gradient
        tm = aux[0].transport_from_potentials(aux[1], aux[2])
        tmp = 2 * tm[:, :, None] * (x[:, None, :] - y[None, :, :])
        grad_x = jnp.sum(tmp, 1)
        other_grad = jnp.sum(delta * grad_x)

        # third calculation of gradient
        loss_delta_plus, _ = loss_fn(x + eps * delta, y)
        loss_delta_minus, _ = loss_fn(x - eps * delta, y)
        finite_diff_grad = (loss_delta_plus - loss_delta_minus) / (2 * eps)

        self.assertAllClose(custom_grad, other_grad, rtol=1e-02, atol=1e-02)
        self.assertAllClose(custom_grad,
                            finite_diff_grad,
                            rtol=1e-02,
                            atol=1e-02)
        self.assertAllClose(other_grad,
                            finite_diff_grad,
                            rtol=1e-02,
                            atol=1e-02)
        self.assertFalse(jnp.any(jnp.isnan(custom_grad)))
Example #3
def training_step(H, data_input, target, optimizer, ema, rng):
    def loss_fun(params):
        stats = VAE(H).apply({'params': params}, data_input, target, rng)
        return -stats['elbo'] * np.log(2), stats

    gradval, stats = lax.pmean(
        grad(loss_fun, has_aux=True)(optimizer.target), 'batch')
    gradval, grad_norm = clip_grad_norm(gradval, H.grad_clip)

    ll_nans = jnp.any(jnp.isnan(stats['log_likelihood']))
    kl_nans = jnp.any(jnp.isnan(stats['kl']))
    stats.update(log_likelihood_nans=ll_nans, kl_nans=kl_nans)

    learning_rate = H.lr * linear_warmup(H.warmup_iters)(optimizer.state.step)

    # only update if no rank has a nan and if the grad norm is below a specific
    # threshold
    def skip_update(_):
        # Only increment the step
        return optimizer.replace(state=optimizer.state.replace(
            step=optimizer.state.step + 1)), ema

    def update(_):
        optimizer_ = optimizer.apply_gradient(gradval,
                                              learning_rate=learning_rate)
        e_decay = H.ema_rate
        ema_ = tree_multimap(lambda e, p: e * e_decay + (1 - e_decay) * p, ema,
                             optimizer.target)
        return optimizer_, ema_

    skip = (ll_nans | kl_nans | ((H.skip_threshold != -1)
                                 & ~(grad_norm < H.skip_threshold)))
    optimizer, ema = lax.cond(skip, skip_update, update, None)
    stats.update(skipped_updates=skip, grad_norm=grad_norm)
    return optimizer, ema, stats
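The step above leans on two helpers that the snippet does not show, clip_grad_norm and linear_warmup. A minimal sketch of both, assuming clip_grad_norm performs a global-norm clip and returns the rescaled gradients together with the raw norm (which the skip logic above needs), and that linear_warmup ramps a learning-rate multiplier from 0 to 1:

import jax
import jax.numpy as jnp

def clip_grad_norm(grads, max_norm):
    # Global L2 norm across every leaf of the gradient pytree.
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.vdot(g, g) for g in leaves))
    # Rescale only when the norm exceeds the threshold.
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
    return clipped, norm

def linear_warmup(warmup_iters):
    # Multiplier that grows linearly from 0 to 1 over the first warmup_iters steps.
    return lambda step: jnp.minimum(step / warmup_iters, 1.0)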
Example #4
File: sa.py Project: lumip/numpyro
    def sample_kernel(sa_state, model_args=(), model_kwargs=None):
        pe_fn = potential_fn
        if potential_fn_gen:
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        zs, pes, loc, scale = sa_state.adapt_state
        # we recompute loc/scale after each iteration to avoid precision loss
        # XXX: consider to expose a setting to do this job periodically
        # to save some computations
        loc = jnp.mean(zs, 0)
        if scale.ndim == 2:
            cov = jnp.cov(zs, rowvar=False, bias=True)
            if cov.shape == ():  # JAX returns scalar for 1D input
                cov = cov.reshape((1, 1))
            cholesky = jnp.linalg.cholesky(cov)
            scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
        else:
            scale = jnp.std(zs, 0)

        rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(sa_state.rng_key, 4)
        _, unravel_fn = ravel_pytree(sa_state.z)

        z = loc + _sample_proposal(scale, rng_key_z)
        pe = pe_fn(unravel_fn(z))
        pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
        diverging = (pe - sa_state.potential_energy) > max_delta_energy

        # NB: all terms having the pattern *s will have shape N x ...
        # and all terms having the pattern *s_ will have shape (N + 1) x ...
        locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
        zs_ = jnp.concatenate([zs, z[None, :]])
        pes_ = jnp.concatenate([pes, pe[None]])
        locs_ = jnp.concatenate([locs, loc[None, :]])
        scales_ = jnp.concatenate([scales, scale[None, ...]])
        if scale.ndim == 2:  # dense_mass
            log_weights_ = dist.MultivariateNormal(locs_, scale_tril=scales_).log_prob(zs_) + pes_
        else:
            log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
        # mask invalid values (nan, +inf) by -inf
        log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_, -jnp.inf)
        # get rejecting index
        j = random.categorical(rng_key_reject, log_weights_)
        zs = _numpy_delete(zs_, j)
        pes = _numpy_delete(pes_, j)
        loc = locs_[j]
        scale = scales_[j]
        adapt_state = SAAdaptState(zs, pes, loc, scale)

        # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
        itr = sa_state.i + 1
        n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n

        # XXX: we make a modification of SA sampler in [1]
        # in [1], each MCMC state contains N points `zs`
        # here we do resampling to pick randomly a point from those N points
        k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
        z = unravel_fn(zs[k])
        pe = pes[k]
        return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)
Example #5
def rmsprop(gradval,
            params,
            eta=0.01,
            gamma=0.9,
            eps=1e-8,
            R=500,
            per=100,
            disp=None):
    params = {k: np.array(v) for k, v in params.items()}
    grms = {k: np.zeros_like(v) for k, v in params.items()}
    n = len(params)

    # iterate to max
    for j in range(R + 1):
        val, grad = gradval(params)

        vnan = np.isnan(val)
        gnan = tree_map(lambda g: np.isnan(g).any(), grad)

        if vnan or np.any(tree_leaves(gnan)):
            print('Encountered nans!')
            if disp is not None:
                disp(j, val, params)
            return params

        for k in params:
            grms[k] += (1 - gamma) * (grad[k]**2 - grms[k])
            params[k] += eta * grad[k] / np.sqrt(grms[k] + eps)

        if disp is not None and j % per == 0:
            disp(j, val, params)

    return params
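A hypothetical usage sketch, not from the source: gradval is expected to map a parameter dict to a (value, gradient-dict) pair, which jax.value_and_grad provides directly (the snippet's np, tree_map and tree_leaves are assumed to be numpy and jax.tree_util helpers). Note that the loop adds the gradient to the parameters, i.e. it ascends, so the toy objective below is one to maximize:

import jax
import jax.numpy as jnp

def objective(params):
    # Concave toy objective, maximized at w = 0 and b = 1.
    return -jnp.sum(params['w'] ** 2) - jnp.sum((params['b'] - 1.0) ** 2)

gradval = jax.value_and_grad(objective)
init = {'w': jnp.ones(3), 'b': jnp.zeros(2)}
fitted = rmsprop(gradval, init, eta=0.05, R=200)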
Example #6
 def update_fn(updates, opt_state, params=None):
     del params
     opt_state = ZeroNansState(
         jax.tree_map(lambda p: jnp.any(jnp.isnan(p)), updates))
     updates = jax.tree_map(
         lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates)
     return updates, opt_state
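The update_fn above follows optax's (updates, state) convention. A minimal sketch of packaging it as a complete gradient transformation, assuming ZeroNansState is just a NamedTuple holding the per-leaf NaN flags (recent optax versions ship a similar optax.zero_nans()):

from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import optax

class ZeroNansState(NamedTuple):
    found_nan: Any  # pytree of booleans, one flag per update leaf

def zero_nans() -> optax.GradientTransformation:
    def init_fn(params):
        return ZeroNansState(jax.tree_map(lambda p: jnp.array(False), params))

    def update_fn(updates, opt_state, params=None):
        del params
        opt_state = ZeroNansState(
            jax.tree_map(lambda p: jnp.any(jnp.isnan(p)), updates))
        updates = jax.tree_map(
            lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates)
        return updates, opt_state

    return optax.GradientTransformation(init_fn, update_fn)

# Chain it in front of the optimizer so NaN gradients become zero updates:
# tx = optax.chain(zero_nans(), optax.adam(1e-3))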
Example #7
def test_sigmoid(x, version):
    """Check for null gradient issues."""
    result = sigmoid(x, version=version)
    assert not np.isnan(result)

    dsigmoid = grad(sigmoid)
    dresult = dsigmoid(x, version=version)
    assert not np.isnan(dresult)
Example #8
    def test_regularized_unbalanced_bures(self):
        """Tests Regularized Unbalanced Bures."""
        x = jnp.concatenate((jnp.array([0.9]), self.x[0, :]))
        y = jnp.concatenate((jnp.array([1.1]), self.y[0, :]))

        rub = costs.UnbalancedBures(self.dim, 1, 0.8)
        self.assertFalse(jnp.any(jnp.isnan(rub(x, y))))
        self.assertFalse(jnp.any(jnp.isnan(rub(y, x))))
        self.assertAllClose(rub(x, y), rub(y, x), rtol=1e-3, atol=1e-3)
Example #9
def test_sample_momentum_resolution():
    N = 100
    mom = Momentum.from_ndarray(rjax.uniform(rng, (N, 3)))
    smom, cov = sample_momentum_resolution(rng, mom)

    assert cov.shape == (N, 3, 3)
    assert smom.as_array.shape == (N, 3)
    assert not np.isnan(cov).any()
    assert not np.isnan(smom.as_array).any()
Example #10
def test_sample_position_resolution():
    N = 100
    pos = Position.from_ndarray(rjax.uniform(rng, (N, 3)))
    spos, cov = sample_position_resolution(rng, pos)

    assert cov.shape == (N, 3, 3)
    assert spos.as_array.shape == (N, 3)
    assert not np.isnan(cov).any()
    assert not np.isnan(spos.as_array).any()
Example #11
def test_sample_helix_resolution():
    N = 100
    hel = Helix.from_ndarray(rjax.uniform(rng, (N, 5)))
    shel, cov = sample_helix_resolution(rng, hel)

    assert cov.shape == (N, 5, 5)
    assert shel.as_array.shape == (N, 5)
    assert not np.isnan(cov).any()
    assert not np.isnan(shel.as_array).any()
Example #12
def test_sample_cluster_resolution():
    N = 100
    clu = Cluster.from_ndarray(rjax.uniform(rng, (N, 3)))
    sclu, cov = sample_cluster_resolution(rng, clu)

    assert cov.shape == (N, 3, 3)
    assert sclu.as_array.shape == (N, 3)
    assert not np.isnan(cov).any()
    assert not np.isnan(sclu.as_array).any()
Example #13
 def test_sin(self):
     """In [-1e10, 1e10] safe_sin and safe_cos are accurate."""
     for fn in ['sin', 'cos']:
         y_true, y = safe_trig_harness(fn, 10)
         self.assertLess(np.max(np.abs(y - y_true)), 1e-4)
         self.assertFalse(jnp.any(jnp.isnan(y)))
     # Beyond that range it's less accurate but we just don't want it to be NaN.
     for fn in ['sin', 'cos']:
         y_true, y = safe_trig_harness(fn, 60)
         self.assertFalse(jnp.any(jnp.isnan(y)))
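The helpers under test (safe_sin, safe_cos and the safe_trig_harness that compares them against a reference) are not shown. A plausible sketch of the "safe" idea, assuming the implementation reduces large arguments modulo a bounded multiple of pi before calling the ordinary trig function, so that huge float32 inputs cannot degenerate into NaN:

import jax.numpy as jnp

def safe_sin(x, bound=100 * jnp.pi):
    # Wrap very large arguments back into a bounded range before taking sin;
    # 100*pi is a whole number of periods, so the value is unchanged.
    return jnp.sin(jnp.where(jnp.abs(x) < bound, x, x % bound))

def safe_cos(x, bound=100 * jnp.pi):
    return jnp.cos(jnp.where(jnp.abs(x) < bound, x, x % bound))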
Example #14
def adam(gradval,
         params,
         eta=0.01,
         beta1=0.9,
         beta2=0.9,
         eps=1e-7,
         c=0.01,
         R=500,
         per=100,
         disp=None,
         log=False):
    params = {k: np.array(v) for k, v in params.items()}
    gavg = {k: np.zeros_like(v) for k, v in params.items()}
    grms = {k: np.zeros_like(v) for k, v in params.items()}
    n = len(params)

    if log:
        hist = {k: [] for k in params}

    # iterate to max
    for j in range(R + 1):
        val, grad = gradval(params)

        # test for nans
        vnan = np.isnan(val) or np.isinf(val)
        gnan = np.array(
            [np.isnan(g).any() or np.isinf(g).any() for g in grad.values()])
        if vnan or gnan.any():
            print('Encountered nan/inf!')
            if disp is not None:
                disp(j, val, params)
            summary_naninf(grad)
            return params

        # clip gradients
        gradc = {k: clip2(grad[k], c) for k in params}

        # early bias correction
        etat = eta * (np.sqrt(1 - beta2**(j + 1))) / (1 - beta1**(j + 1))

        for k in params:
            gavg[k] += (1 - beta1) * (gradc[k] - gavg[k])
            grms[k] += (1 - beta2) * (gradc[k]**2 - grms[k])
            params[k] += etat * gavg[k] / (np.sqrt(grms[k]) + eps)

            if log:
                hist[k].append(params[k])

        if disp is not None and j % per == 0:
            disp(j, val, params)

    if log:
        hist = {k: np.stack(v) for k, v in hist.items()}
        return hist

    return params
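The loop above also depends on clip2 and summary_naninf, which are not shown. A plausible sketch of clip2 as a symmetric element-wise clip, which is one natural reading of the c parameter:

import jax.numpy as jnp

def clip2(g, c):
    # Clip every gradient entry into the interval [-c, c].
    return jnp.clip(g, -c, c)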
Example #15
    def sample(self, key: jnp.ndarray) -> jnp.ndarray:
        """Sample from the distribution.

        Args:
            key: JAX random key.
        """
        sample = jax.random.beta(key, self.alpha, self.beta)

        is_nan = jnp.logical_or(jnp.isnan(self.alpha), jnp.isnan(self.beta))
        return jnp.where(is_nan,
                         jnp.full(self.alpha.shape, jnp.nan),
                         sample)
Example #16
def test_scce_out_of_bounds():
    ypred = jnp.zeros([4, 10])
    ytrue0 = jnp.array([0, 0, -1, 0])
    ytrue1 = jnp.array([0, 0, 10, 0])

    scce = elegy.losses.SparseCategoricalCrossentropy()

    assert jnp.isnan(scce(ytrue0, ypred)).any()
    assert jnp.isnan(scce(ytrue1, ypred)).any()

    scce = elegy.losses.SparseCategoricalCrossentropy(check_bounds=False)
    assert not jnp.isnan(scce(ytrue0, ypred)).any()
    assert not jnp.isnan(scce(ytrue1, ypred)).any()
Example #17
    def sample(self, key: jnp.ndarray) -> jnp.ndarray:
        """Sample from the distribution.

        Args:
            key: JAX random key.
        """
        std_norm = jax.random.normal(key,
                                     shape=self.mu.shape,
                                     dtype=self.mu.dtype)

        is_nan = jnp.logical_or(jnp.isnan(self.mu), jnp.isnan(self.sigma))
        return jnp.where(is_nan, jnp.full(self.mu.shape, jnp.nan),
                         std_norm * self.sigma + self.mu)
Example #18
 def test_behler_parrinello_network(self, N_types, dtype):
     key = random.PRNGKey(1)
     R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
     species = np.array([1, 1, N_types]) if N_types > 1 else None
     box_size = f32(1.5)
     displacement, _ = space.periodic(box_size)
     nn_init, nn_apply = energy.behler_parrinello(displacement, species)
     params = nn_init(key, R)
     nn_force_fn = grad(nn_apply, argnums=1)
     nn_force = jit(nn_force_fn)(params, R)
     nn_energy = jit(nn_apply)(params, R)
     self.assertAllClose(np.any(np.isnan(nn_energy)), False)
     self.assertAllClose(np.any(np.isnan(nn_force)), False)
     self.assertAllClose(nn_force.shape, [3, 3])
Example #19
def visualize_normals(depth, acc, scaling=None):
    """Visualize fake normals of `depth` (optionally scaled to be isotropic)."""
    if scaling is None:
        mask = ~jnp.isnan(depth)
        x, y = jnp.meshgrid(jnp.arange(depth.shape[1]),
                            jnp.arange(depth.shape[0]),
                            indexing='xy')
        xy_var = (jnp.var(x[mask]) + jnp.var(y[mask])) / 2
        z_var = jnp.var(depth[mask])
        scaling = jnp.sqrt(xy_var / z_var)

    scaled_depth = scaling * depth
    normals = depth_to_normals(scaled_depth)
    return matte(
        jnp.isnan(normals) + jnp.nan_to_num((normals + 1) / 2, 0), acc)
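depth_to_normals is not shown in the snippet. A minimal sketch, assuming it estimates screen-space normals from finite differences of the (scaled) depth map:

import jax.numpy as jnp

def depth_to_normals(depth):
    # Assumed helper: unit normals from depth gradients, output shape [H, W, 3].
    dz_dy, dz_dx = jnp.gradient(depth)  # derivatives along rows and columns
    normals = jnp.stack([-dz_dx, -dz_dy, jnp.ones_like(depth)], axis=-1)
    return normals / jnp.linalg.norm(normals, axis=-1, keepdims=True)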
Example #20
 def test_prior(self,key, num_samples, log_likelihood=None, **kwargs):
     keys = random.split(key, num_samples)
     for key in keys:
         U = random.uniform(key, shape=(self.U_ndims,))
         Y = self(U, **kwargs)
         for k,v in Y.items():
             if jnp.any(jnp.isnan(v)):
                 raise ValueError('nan in prior transform',Y)
         if log_likelihood is not None:
             loglik = log_likelihood(**Y)
             if jnp.isnan(loglik):
                 raise ValueError("Log likelihood is nan", Y)
             if loglik == 0.:
                 print("Log likelihood is zero", loglik, Y)
             print(loglik)
Example #21
    def wrapper(self, x, *args):
        f, grad = self.fun(x, *args)
        if np.isnan(f) or np.any(np.isnan(grad)):
            print("nan detected")
            print(x)
            print(f)
            print(grad)

            if self.fundebug is not None:
                self.fundebug(x)

            assert (0)
        self.f = f
        self.grad = grad
        return f, grad
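A hypothetical usage sketch, not from the source: objective wrappers of this shape are typically handed to scipy.optimize.minimize with jac=True, where each call must return the objective value and its gradient; jax.value_and_grad supplies both, and a NaN guard like the one above then catches any bad evaluation before it derails the optimizer:

import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import minimize

value_and_grad = jax.value_and_grad(lambda x: jnp.sum((x - 1.0) ** 2))

def fun(x):
    # scipy passes and expects float64 numpy arrays.
    f, g = value_and_grad(jnp.asarray(x))
    return float(f), np.asarray(g, dtype=np.float64)

result = minimize(fun, x0=np.zeros(3), jac=True, method="L-BFGS-B")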
Example #22
def _contains_query(vals, query):
    if isinstance(query, tuple):
        # Consume the map eagerly; Python 3's map is lazy, so the recursive
        # checks (and the exceptions they raise) would otherwise never run.
        return list(map(partial(_contains_query, vals), query))

    if np.isnan(query):
        if np.any(np.isnan(vals)):
            raise FoundValue('NaN')
    elif np.isinf(query):
        if np.any(np.isinf(vals)):
            raise FoundValue('Found Inf')
    elif np.isscalar(query):
        if np.any(vals == query):
            raise FoundValue(str(query))
    else:
        raise ValueError('Malformed Query: {}'.format(query))
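A hypothetical usage sketch; FoundValue (the exception raised above) and the imports (numpy as np, functools.partial) are not shown in the snippet and are assumed to live in the same module as _contains_query:

import numpy as np

values = np.array([1.0, np.nan, 3.0])
try:
    _contains_query(values, (np.nan, np.inf))  # query NaN first, then Inf
except FoundValue as found:
    print('query matched:', found)  # -> query matched: NaN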
Example #23
    def test_gradient_sinkhorn_geometry(self, lse_mode):
        """Test gradient w.r.t. cost matrix."""
        for n, m in ((11, 13), (15, 9)):
            keys = jax.random.split(self.rng, 2)
            cost_matrix = jnp.abs(jax.random.normal(keys[0], (n, m)))
            delta = jax.random.normal(keys[1], (n, m))
            delta = delta / jnp.sqrt(jnp.vdot(delta, delta))
            eps = 1e-3  # perturbation magnitude

            def loss_fn(cm):
                a = jnp.ones(cm.shape[0]) / cm.shape[0]
                b = jnp.ones(cm.shape[1]) / cm.shape[1]
                geom = geometry.Geometry(cm, epsilon=0.5)
                out = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode)
                return out.reg_ot_cost, (geom, out.f, out.g)

            # first calculation of gradient
            loss_and_grad = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
            (loss_value, aux), grad_loss = loss_and_grad(cost_matrix)
            custom_grad = jnp.sum(delta * grad_loss)

            self.assertFalse(jnp.isnan(loss_value))
            self.assertEqual(grad_loss.shape, cost_matrix.shape)
            self.assertFalse(jnp.any(jnp.isnan(grad_loss)))

            # second calculation of gradient
            transport_matrix = aux[0].transport_from_potentials(aux[1], aux[2])
            grad_x = transport_matrix
            other_grad = jnp.sum(delta * grad_x)

            # third calculation of gradient
            loss_delta_plus, _ = loss_fn(cost_matrix + eps * delta)
            loss_delta_minus, _ = loss_fn(cost_matrix - eps * delta)
            finite_diff_grad = (loss_delta_plus - loss_delta_minus) / (2 * eps)

            self.assertAllClose(custom_grad,
                                other_grad,
                                rtol=1e-02,
                                atol=1e-02)
            self.assertAllClose(custom_grad,
                                finite_diff_grad,
                                rtol=1e-02,
                                atol=1e-02)
            self.assertAllClose(other_grad,
                                finite_diff_grad,
                                rtol=1e-02,
                                atol=1e-02)
            self.assertFalse(jnp.any(jnp.isnan(custom_grad)))
Example #24
    def kernel(rng_key: jax.random.PRNGKey,
               state: RWMState) -> Tuple[RWMState, RWMInfo]:
        """Moves the chain by one step using the Random Walk Metropolis algorithm.

        Parameters
        ----------
        rng_key:
           The pseudo-random number generator key used to generate random numbers.
        state:
            The current state of the chain: position, log-probability and gradient
            of the log-probability.

        Returns
        -------
        The next state of the chain and additional information about the current step.
        """
        key_move, key_accept = jax.random.split(rng_key)

        position, log_prob = state

        move_proposal = proposal_generator(key_move)
        new_position = position + move_proposal
        new_log_prob = logpdf(new_position)
        new_state = RWMState(new_position, new_log_prob)

        delta = new_log_prob - log_prob
        delta = np.where(np.isnan(delta), -np.inf, delta)
        p_accept = np.clip(np.exp(delta), a_max=1)

        do_accept = jax.random.bernoulli(key_accept, p_accept)
        accept_state = (new_state, RWMInfo(new_state, p_accept, True))
        reject_state = (state, RWMInfo(new_state, p_accept, False))
        # np.where cannot select between pytrees of states; use lax.cond instead.
        return jax.lax.cond(do_accept, lambda _: accept_state,
                            lambda _: reject_state, None)
Example #25
    def update(self, transition_batch, return_td_error=False):
        r"""

        Update the model parameters (weights) of the underlying function approximator.

        Parameters
        ----------
        transition_batch : TransitionBatch

            A batch of transitions.

        return_td_error : bool, optional

            Whether to return the TD-errors.

        Returns
        -------
        metrics : dict of scalar ndarrays

            The structure of the metrics dict is ``{name: score}``.

        td_error : ndarray, optional

            The non-aggregated TD-errors, :code:`shape == (batch_size,)`. This is only returned if
            we set :code:`return_td_error=True`.

        """
        grads, function_state, metrics, td_error = self.grads_and_metrics(
            transition_batch)
        if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
            raise RuntimeError(f"found nan's in grads: {grads}")
        self.update_from_grads(grads, function_state)
        return (metrics, td_error) if return_td_error else metrics
Example #26
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
    out = prim.bind(*in_vals, **params)
    if ErrorCategory.NAN not in enabled_errors:
        return out, error
    no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
    msg = f"nan generated by primitive {prim.name} at {summary()}"
    return out, assert_func(error, no_nans, msg, None)
Example #27
 def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
     kv = self.train_op(progress, data['image'].numpy(),
                        data['label'].numpy())
     for k, v in kv.items():
         if jn.isnan(v):
             raise ValueError('NaN, try reducing learning rate', k)
         summary.scalar(k, float(v))
Example #28
    def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args,
                  model_kwargs, rng_key):
        if potential_fn_gen:
            nonlocal vv_update
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn)

        num_steps = _get_num_steps(step_size, trajectory_len)
        vv_state_new = fori_loop(
            0, num_steps,
            lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
            vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf,
                                 delta_energy)
        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
        diverging = delta_energy > max_delta_energy
        transition = random.bernoulli(rng_key, accept_prob)
        vv_state, energy = cond(transition, (vv_state_new, energy_new),
                                identity, (vv_state, energy_old), identity)
        return vv_state, energy, num_steps, accept_prob, diverging
Example #29
def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix,
                    step_size, going_right, energy_current, max_delta_energy):
    step_size = np.where(going_right, step_size, -step_size)
    z_new, r_new, potential_energy_new, z_new_grad = vv_update(
        step_size,
        inverse_mass_matrix,
        (z, r, energy_current, z_grad),
    )

    energy_new = potential_energy_new + kinetic_fn(inverse_mass_matrix, r_new)
    delta_energy = energy_new - energy_current
    # Handles the NaN case.
    delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
    tree_weight = -delta_energy

    diverging = delta_energy > max_delta_energy
    accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
    return TreeInfo(z_new,
                    r_new,
                    z_new_grad,
                    z_new,
                    r_new,
                    z_new_grad,
                    z_new,
                    potential_energy_new,
                    z_new_grad,
                    energy_new,
                    depth=0,
                    weight=tree_weight,
                    r_sum=r_new,
                    turning=False,
                    diverging=diverging,
                    sum_accept_probs=accept_prob,
                    num_proposals=1)
Example #30
def update_state_arm(state: StateArm, samples: List, grads: List, error_fn: Callable, get_fb_grads=None) -> StateArm:
    last_sample = samples[-1] if samples is not None else state.last_sample

    if get_fb_grads:
        samples, grads = get_fb_grads(samples)

    samples = flatten_param_list(samples)
    grads = flatten_param_list(grads)
    # concatenate
    if state.samples is not None:
        all_samples = jnp.concatenate([state.samples, samples])
    else:
        all_samples = samples
    if state.grads is not None:
        all_grads = jnp.concatenate([state.grads, grads])
    else:
        all_grads = grads
    _metric = error_fn(samples, grads)
    metric = _metric if not jnp.isnan(_metric) else jnp.inf
    state = StateArm(hyperparameters=state.hyperparameters,
                     run_timed_sampler=state.run_timed_sampler,
                     last_sample=last_sample,
                     samples=all_samples,
                     grads=all_grads,
                     metric=metric
                    )
    return state