Code Example #1
def compute_loss(predicted_positions,
                 predicted_momentums,
                 target_positions,
                 target_momentums,
                 auxiliary_predictions=None,
                 regularizations=None):
    """Computes the loss for the given predictions."""
    assert predicted_positions.shape == target_positions.shape, f'Got predicted_positions: {predicted_positions.shape}, target_positions: {target_positions.shape}'
    assert predicted_momentums.shape == target_momentums.shape, f'Got predicted_momentums: {predicted_momentums.shape}, target_momentums: {target_momentums.shape}'

    loss = optax.l2_loss(predictions=predicted_positions,
                         targets=target_positions)
    loss += optax.l2_loss(predictions=predicted_momentums,
                          targets=target_momentums)
    loss = jnp.mean(loss)

    if auxiliary_predictions is not None:
        angular_velocities = auxiliary_predictions['angular_velocities']
        angular_velocities_variances = jnp.var(angular_velocities,
                                               axis=0).sum()
        loss += regularizations[
            'angular_velocities'] * angular_velocities_variances

        actions = auxiliary_predictions['actions']
        actions_variances = jnp.var(actions, axis=0).sum()
        loss += regularizations['actions'] * actions_variances
    return loss
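A minimal usage sketch, assuming jax.numpy and optax are imported as in the example above, with made-up shapes and hypothetical regularization weights:

import jax.numpy as jnp
import optax

positions = jnp.zeros((8, 3))
momentums = jnp.zeros((8, 3))
aux = {'angular_velocities': jnp.ones((8, 3)), 'actions': jnp.ones((8, 2))}
regs = {'angular_velocities': 0.1, 'actions': 0.1}  # hypothetical weights

loss = compute_loss(positions + 0.1, momentums - 0.1,
                    positions, momentums,
                    auxiliary_predictions=aux,
                    regularizations=regs)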
Code Example #2
def loss_reg(theta, X, Y, W, data_weights, reg_scale):
    # Residuals of the linear fit against the targets.
    l_chiSq = jnp.einsum('jia,jka->jki', X, theta) - Y
    # Weighted chi-square per fit, averaged over the last axis.
    chiSq = jnp.mean(l_chiSq * W * l_chiSq, axis=-1)

    # Ridge (L2) penalty on the fit coefficients.
    reg_overlap = jnp.sum(theta**2, axis=-1)

    # Scale the regularization strength by the target variance.
    reg_scale = jnp.expand_dims(
        jnp.expand_dims(reg_scale / jnp.var(Y, axis=-1), -1), -1)

    # Note: the data_weights argument is overridden by the target variances.
    data_weights = jnp.var(Y, axis=-1)
    loss = jnp.sum(data_weights * (chiSq + reg_scale * reg_overlap)) \
        / (chiSq.shape[1] * jnp.sum(data_weights))
    return loss
Code Example #3
 def testVarianceScaling(self, map_in, map_out, fan, distr):
     shape = (80, 50, 7)
     fan_in, fan_out = jax._src.nn.initializers._compute_fans(
         NamedShape(*shape), 0, 1)
     key = jax.random.PRNGKey(0)
     base_scaling = partial(jax.nn.initializers.variance_scaling, 100, fan,
                            distr)
     ref_sampler = lambda: base_scaling(in_axis=0, out_axis=1)(key, shape)
     if map_in and map_out:
         out_axes = ['i', 'o', ...]
         named_shape = NamedShape(shape[2], i=shape[0], o=shape[1])
         xmap_sampler = lambda: base_scaling(in_axis='i', out_axis='o')(
             key, named_shape)
     elif map_in:
         out_axes = ['i', ...]
         named_shape = NamedShape(shape[1], shape[2], i=shape[0])
         xmap_sampler = lambda: base_scaling(in_axis='i', out_axis=0)(
             key, named_shape)
     elif map_out:
         out_axes = [None, 'o', ...]
         named_shape = NamedShape(shape[0], shape[2], o=shape[1])
         xmap_sampler = lambda: base_scaling(in_axis=0, out_axis='o')(
             key, named_shape)
     mapped_sampler = xmap(xmap_sampler,
                           in_axes=(),
                           out_axes=out_axes,
                           axis_sizes={
                               'i': shape[0],
                               'o': shape[1]
                           })
     self.assertAllClose(jnp.var(mapped_sampler()),
                         jnp.var(ref_sampler()),
                         atol=1e-4,
                         rtol=2e-2)
Code Example #4
File: utils_test.py Project: BwRy/jraph
 def test_segment_variance(self):
     result = utils.segment_variance(jnp.arange(8),
                                     jnp.array([0, 0, 0, 1, 1, 2, 2, 2]), 3)
     self.assertAllClose(
         result,
         jnp.stack([
             jnp.var(jnp.arange(3)),
             jnp.var(jnp.arange(3, 5)),
             jnp.var(jnp.arange(5, 8))
         ]))
Code Example #5
File: utils_jax.py Project: chaichontat/janelia2020
def correlate(data, fitted):
    """Designed for images (n_imgs, pxs); returns the per-image Pearson correlation."""
    x, y = check_data(data, fitted)

    Ex = jnp.mean(x, axis=1, keepdims=True)
    Ey = jnp.mean(y, axis=1, keepdims=True)

    cov = jnp.mean((x - Ex) * (y - Ey), axis=1)

    # Standard deviations: square roots of the per-image variances.
    std_x = jnp.sqrt(jnp.var(x, axis=1))
    std_y = jnp.sqrt(jnp.var(y, axis=1))

    return cov / (std_x * std_y)
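In other words, the function returns the per-image Pearson correlation coefficient between the data and the fit, with mean, variance, and covariance taken over the pixel axis:

$$ r_i = \frac{\mathrm{Cov}(x_i, y_i)}{\sqrt{\mathrm{Var}(x_i)}\,\sqrt{\mathrm{Var}(y_i)}} $$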
Code Example #6
def test_submission_from_samples_linear(metaculus_questions, logistic_mixture_samples):
    normalized_mixture = metaculus_questions.continuous_linear_open_question.get_submission_from_samples(
        logistic_mixture_samples
    )
    normalized_mixture_samples = [normalized_mixture.sample() for _ in range(5000)]
    mixture_samples = metaculus_questions.continuous_linear_open_question.denormalize_samples(
        normalized_mixture_samples
    )
    assert float(np.mean(logistic_mixture_samples)) == pytest.approx(
        float(np.mean(mixture_samples)), rel=0.1
    )
    assert float(np.var(logistic_mixture_samples)) == pytest.approx(
        float(np.var(mixture_samples)), rel=0.2
    )
Code Example #7
def test_discrete_gibbs_gmm_1d(modified):
    def model(probs, locs):
        c = numpyro.sample("c", dist.Categorical(probs))
        numpyro.sample("x", dist.Normal(locs[c], 0.5))

    probs = jnp.array([0.15, 0.3, 0.3, 0.25])
    locs = jnp.array([-2, 0, 2, 4])
    kernel = DiscreteHMCGibbs(NUTS(model), modified=modified)
    mcmc = MCMC(kernel, 1000, 200000, progress_bar=False)
    mcmc.run(random.PRNGKey(0), probs, locs)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["x"]), 1.3, atol=0.1)
    assert_allclose(jnp.var(samples["x"]), 4.36, atol=0.1)
    assert_allclose(jnp.mean(samples["c"]), 1.65, atol=0.1)
    assert_allclose(jnp.var(samples["c"]), 1.03, atol=0.1)
Code Example #8
def visualize_normals(depth, acc, scaling=None):
    """Visualize fake normals of `depth` (optionally scaled to be isotropic)."""
    if scaling is None:
        mask = ~jnp.isnan(depth)
        x, y = jnp.meshgrid(jnp.arange(depth.shape[1]),
                            jnp.arange(depth.shape[0]),
                            indexing='xy')
        xy_var = (jnp.var(x[mask]) + jnp.var(y[mask])) / 2
        z_var = jnp.var(depth[mask])
        scaling = jnp.sqrt(xy_var / z_var)

    scaled_depth = scaling * depth
    normals = depth_to_normals(scaled_depth)
    return matte(
        jnp.isnan(normals) + jnp.nan_to_num((normals + 1) / 2, 0), acc)
Code Example #9
def test_mean_and_var_mid():
    t0 = 0.0
    t1 = 3.0
    y0 = np.linspace(0.1, 0.9, D)
    num_samples = 500

    vals = onp.zeros((num_samples, D))
    for i in range(num_samples):
        rng = random.PRNGKey(i)
        bm = make_brownian_motion(t0, np.zeros(y0.shape), t1, rng)
        vals[i, :] = bm(t1 / 2.0)

    print(np.mean(vals), np.var(vals))
    assert np.allclose(np.mean(vals), 0.0, atol=1e-1, rtol=1e-1)
    assert np.allclose(np.var(vals), t1 / 2.0, atol=1e-1, rtol=1e-1)
Code Example #10
    def update(chain_state, _,
               rhat_state: GelmanRubinState) -> GelmanRubinState:
        """Update rhat estimates

        Parameters
        ----------
        chain_state: HMCState
            The chain state
        rhat_state: GelmanRubinState
            The GelmanRubinState from the previous draw

        Returns
        -------
        An updated GelmanRubinState object
        """
        within_state, *_ = rhat_state

        positions = chain_state.position
        within_state = w_update(within_state, positions)
        covariance, step, mean = w_covariance(within_state)
        within_var = jnp.mean(covariance, axis=0)
        between_var = jnp.var(mean, axis=0, ddof=1)
        estimator = ((step - 1) / step) * within_var + between_var
        rhat = jnp.sqrt(estimator / within_var)
        worst_rhat = rhat[jnp.argmax(jnp.abs(rhat - 1.0))]

        return GelmanRubinState(within_state, rhat, worst_rhat)
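This is the standard Gelman-Rubin construction: with $W$ the average within-chain variance and $B$ the variance (ddof=1) of the per-chain means after $n$ draws,

$$ \hat{R} = \sqrt{\frac{\frac{n-1}{n}\,W + B}{W}}, $$

and `worst_rhat` reports the component whose $\hat{R}$ is furthest from 1.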
Code Example #11
    def variance_loss_objective(encoder_params,
                                decoder_params,
                                batch,
                                prng_key,
                                num_samples=1):
        """
        Computes the variance loss objective of a discrete VAE.
        :param encoder_params: encoder parameters (list)
        :param decoder_params: decoder parameters (list)
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param num_samples: number of samples
        """
        posterior_params = encoder(encoder_params, batch)  # BxD
        posterior_samples = lax.stop_gradient(
            bernoulli.sample(posterior_params, prng_key, num_samples))

        log_prior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            posterior_samples, 0.5 * np.ones(
                (batch.shape[0], posterior_params.shape[-1]))),
                           axis=(1, 2))  # SxBxD

        log_posterior = np.sum(vmap(bernoulli.logpmf,
                                    in_axes=(0, None))(posterior_samples,
                                                       posterior_params),
                               axis=(1, 2))

        log_likelihood = vmap(bernoulli_log_likelihood,
                              in_axes=(None, 0, None))(decoder_params,
                                                       posterior_samples,
                                                       batch)
        elbo_samples = log_likelihood - log_posterior + log_prior
        return np.var(elbo_samples, axis=0, ddof=1) / batch.shape[0]
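What is returned is, up to the batch-size scaling, the sample variance (ddof=1) across the $S$ Monte Carlo samples of the batch-summed ELBO estimate, roughly

$$ \mathcal{L} = \frac{1}{B}\,\widehat{\mathrm{Var}}_{s}\Big[\textstyle\sum_b \big(\log p(x_b \mid z_{s,b}) + \log p(z_{s,b}) - \log q(z_{s,b} \mid x_b)\big)\Big], \qquad z_{s,b} \sim q(z \mid x_b), $$

with gradients blocked through the samples by `lax.stop_gradient`.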
Code Example #12
    def predict_f(key, x, x_cov, full_covariance=False):

        # create distribution
        # print(x.min(), x.max(), x_cov.min(), x_cov.max())
        x_dist = z(x, x_cov)

        # sample
        x_mc_samples = x_dist.sample((n_samples,), key)

        # function predictions over mc samples
        # (N,M,P) = f(N,D,M)
        y_mu_mc = jax.vmap(gp_pred.predict_mean, in_axes=0, out_axes=1)(x_mc_samples)

        # mean of mc samples
        # (N,P,) = (N,M,P)
        y_mu = jnp.mean(y_mu_mc, axis=1)

        if full_covariance:
            # ===================
            # Covariance
            # ===================
            # (N,P,M) - (N,P,1) -> (N,P,M)
            dfydx = y_mu_mc - y_mu[..., None]

            # (N,M,P) @ (M,M) @ (N,M,P) -> (N,P,D)
            cov = wc * jnp.einsum("ijk,lmn->ikl", dfydx, dfydx.T)

            return y_mu, cov
        else:
            # (N,P) = (N,M,P)
            var = jnp.var(y_mu_mc, axis=1)
            return y_mu, var
Code Example #13
def test_estimate_likelihood(kernel_cls):
    data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = 0.1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (10_000,)
    )
    n, _ = data.shape
    num_warmup = 200
    num_samples = 200
    num_blocks = 20

    def model(data):
        mean = numpyro.sample(
            "mean", dist.Normal(ref_params, jnp.ones_like(ref_params))
        )
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-2) as idx:
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=data[idx])

    proxy_fn = HMCECS.taylor_proxy({"mean": ref_params})
    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

    mcmc.run(random.PRNGKey(0), data, extra_fields=["hmc_state.potential_energy"])

    pes = mcmc.get_extra_fields()["hmc_state.potential_energy"]
    samples = mcmc.get_samples()
    pes_full = vmap(
        lambda sample: log_density(
            model, (data,), {}, {**sample, **{"N": jnp.arange(n)}}
        )[0]
    )(samples)

    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.0
Code Example #14
def test_dense_mass(kernel_cls, rho):
    warmup_steps, num_samples = 20000, 10000

    true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

    def model():
        numpyro.sample(
            "x", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov)
        )

    if kernel_cls is HMC or kernel_cls is NUTS:
        kernel = kernel_cls(model, trajectory_length=2.0, dense_mass=True)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model, dense_mass=True)

    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(0))

    mass_matrix_sqrt = mcmc.last_state.adapt_state.mass_matrix_sqrt
    if kernel_cls is HMC or kernel_cls is NUTS:
        mass_matrix_sqrt = mass_matrix_sqrt[("x",)]
    mass_matrix = jnp.matmul(mass_matrix_sqrt, jnp.transpose(mass_matrix_sqrt))
    estimated_cov = jnp.linalg.inv(mass_matrix)
    assert_allclose(estimated_cov, true_cov, rtol=0.10)

    samples = mcmc.get_samples()["x"]
    assert_allclose(jnp.mean(samples[:, 0]), jnp.array(0.0), atol=0.50)
    assert_allclose(jnp.mean(samples[:, 1]), jnp.array(0.0), atol=0.05)
    assert_allclose(jnp.mean(samples[:, 0] * samples[:, 1]), jnp.array(rho), atol=0.20)
    assert_allclose(jnp.var(samples, axis=0), jnp.array([10.0, 0.1]), rtol=0.20)
Code Example #15
File: mip_test.py Project: wx-b/mipnerf
 def test_expected_sin(self):
     normal_samples = random.normal(random.PRNGKey(0), (10000, ))
     for mu, var in [(0, 1), (1, 3), (-2, .2), (10, 10)]:
         sin_mu, sin_var = mip.expected_sin(mu, var)
         x = jnp.sin(jnp.sqrt(var) * normal_samples + mu)
         self.assertAllClose(sin_mu, jnp.mean(x), atol=1e-2)
         self.assertAllClose(sin_var, jnp.var(x), atol=1e-2)
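For reference, the closed form that `mip.expected_sin` is presumably implementing is the mean and variance of the sine of a Gaussian variable $X \sim \mathcal{N}(\mu, \sigma^2)$:

$$ \mathbb{E}[\sin X] = e^{-\sigma^2/2}\sin\mu, \qquad \mathrm{Var}[\sin X] = \tfrac{1}{2}\big(1 - e^{-2\sigma^2}\cos 2\mu\big) - \mathbb{E}[\sin X]^2, $$

which the test checks against Monte Carlo estimates.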
Code Example #16
    def estimator(likelihoods, params, gibbs_state):
        subsample_log_liks = defaultdict(float)
        for (fn, value, name, subsample_dim) in likelihoods.values():
            subsample_log_liks[name] += _sum_all_except_at_dim(
                fn.log_prob(value), subsample_dim
            )

        log_lik_sum = 0.0

        proxy_value_all, proxy_value_subsample = proxy_fn(
            params, subsample_log_liks.keys(), gibbs_state
        )

        for (
            name,
            subsample_log_lik,
        ) in subsample_log_liks.items():  # loop over all subsample sites
            n, m = subsample_plate_sizes[name]

            diff = subsample_log_lik - proxy_value_subsample[name]

            unbiased_log_lik = proxy_value_all[name] + n * jnp.mean(diff)
            variance = n ** 2 / m * jnp.var(diff)
            log_lik_sum += unbiased_log_lik - 0.5 * variance
        return log_lik_sum
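In formula form: for a subsample $S$ of size $m$ drawn from $n$ data points, with $d_i$ the difference between the exact and the proxy per-point log-likelihood, the estimator accumulates

$$ \log\hat{L} = \ell_{\text{proxy}} + \frac{n}{m}\sum_{i \in S} d_i \;-\; \frac{1}{2}\,\frac{n^2}{m}\,\widehat{\mathrm{Var}}_{i \in S}(d_i), $$

i.e. the bias-corrected subsampled likelihood with the variance penalty used in energy-conserving subsampling HMC.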
Code Example #17
 def apply(params, inputs):
     # Standardize the inputs with a learned scale (params[-2]) and shift (params[-1]).
     inputs = params[-2] / np.sqrt(np.var(inputs, axis=0)) * (
         inputs - np.mean(inputs, axis=0)) + params[-1]
     # Residual MLP blocks.
     for _ in range(depth):
         inputs = mlp(params[:-2], inputs) + inputs
     return inputs
Code Example #18
File: test_metaculus.py Project: peterhurford/ergo
def test_submission_from_samples_smooth(metaculus_questions,
                                        smooth_logistic_mixture):
    samples = np.array([smooth_logistic_mixture.sample() for _ in range(5000)])
    fit_mixture = metaculus_questions.continuous_linear_open_question.get_submission_from_samples(
        samples)
    normalized_samples_from_fit_mixture = [
        fit_mixture.sample() for _ in range(5000)
    ]
    mixture_samples = metaculus_questions.continuous_linear_open_question.denormalize_samples(
        normalized_samples_from_fit_mixture)
    assert float(np.mean(samples)) == pytest.approx(float(
        np.mean(mixture_samples)),
                                                    rel=0.1)
    assert float(np.var(samples)) == pytest.approx(float(
        np.var(mixture_samples)),
                                                   rel=0.2)
Code Example #19
File: bayesian_regression.py Project: GJBoth/modax
def bayesian_regression(
        X,
        y,
        prior_init=None,
        hyper_prior=((1e-6, 1e-6), (1e-6, 1e-6)),
        tol=1e-5,
        max_iter=300,
):

    n_samples, n_features = X.shape
    # Prepping matrices
    XT_y = jnp.dot(X.T, y)
    _, S, Vh = jnp.linalg.svd(X, full_matrices=False)
    eigen_vals = S**2

    if prior_init is None:
        prior_init = jnp.ones((2, ))
        # jax.ops.index_update has been removed from JAX; .at[].set() is the equivalent update.
        prior_init = prior_init.at[1].set(1 / (jnp.var(y) + 1e-7))

    # Running
    prior_params, iterations = fixed_point_solver(
        update,
        (X, y, eigen_vals, Vh, XT_y, hyper_prior),
        prior_init,
        lambda z_prev, z: jnp.linalg.norm(z_prev[0] - z[0]) > tol,
        max_iter=max_iter,
    )

    prior = stop_gradient(prior_params)
    log_LL, mn = evidence(X, y, prior, eigen_vals, Vh, XT_y, hyper_prior)
    metrics = (iterations, 0.0)
    return log_LL, mn, prior, metrics
Code Example #20
File: simulate_test.py Project: scnlong/jax-md
    def test_brownian(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)
        key, T_split, mass_split = random.split(key, 3)

        _, shift = space.free()
        energy_fn = lambda R, **kwargs: f32(0)

        R = np.zeros((BROWNIAN_PARTICLE_COUNT, 2), dtype=dtype)
        mass = random.uniform(mass_split, (),
                              minval=0.1,
                              maxval=10.0,
                              dtype=dtype)
        T = random.uniform(T_split, (), minval=0.3, maxval=1.4, dtype=dtype)

        dt = f32(1e-2)
        gamma = f32(0.1)

        init_fn, apply_fn = simulate.brownian(energy_fn,
                                              shift,
                                              dt,
                                              T,
                                              gamma=gamma)
        apply_fn = jit(apply_fn)

        state = init_fn(key, R, mass)

        sim_t = f32(BROWNIAN_DYNAMICS_STEPS * dt)
        for _ in range(BROWNIAN_DYNAMICS_STEPS):
            state = apply_fn(state)

        msd = np.var(state.position)
        th_msd = dtype(2 * T / (mass * gamma) * sim_t)
        assert np.abs(msd - th_msd) / msd < 1e-2
        assert state.position.dtype == dtype
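The reference value is the theoretical mean-squared displacement per coordinate of an overdamped Brownian particle,

$$ \langle x^2 \rangle = \frac{2\,k_B T}{m\,\gamma}\,t, $$

with $k_B = 1$ in simulation units, which is what `th_msd = 2 * T / (mass * gamma) * sim_t` encodes.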
Code Example #21
File: layer_norm.py Project: stjordanis/dm-haiku
    def __call__(
        self,
        inputs: jnp.ndarray,
        scale: Optional[jnp.ndarray] = None,
        offset: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Connects the layer norm.

    Args:
      inputs: An array, where the data format is ``[N, ..., C]``.
      scale: An array up to n-D. The shape of this tensor must be broadcastable
        to the shape of ``inputs``. This is the scale applied to the normalized
        inputs. This cannot be passed in if the module was constructed with
        ``create_scale=True``.
      offset: An array up to n-D. The shape of this tensor must be broadcastable
        to the shape of ``inputs``. This is the offset applied to the normalized
        inputs. This cannot be passed in if the module was constructed with
        ``create_offset=True``.

    Returns:
      The array, normalized.
    """
        if self.create_scale and scale is not None:
            raise ValueError(
                "Cannot pass `scale` at call time if `create_scale=True`.")
        if self.create_offset and offset is not None:
            raise ValueError(
                "Cannot pass `offset` at call time if `create_offset=True`.")

        axis = self.axis
        if isinstance(axis, slice):
            axis = tuple(range(inputs.ndim)[axis])

        mean = jnp.mean(inputs, axis=axis, keepdims=True)
        variance = jnp.var(inputs, axis=axis, keepdims=True)

        param_shape = inputs.shape[-1:]
        if self.create_scale:
            scale = hk.get_parameter("scale",
                                     param_shape,
                                     inputs.dtype,
                                     init=self.scale_init)
        elif scale is None:
            scale = np.array(1., dtype=inputs.dtype)

        if self.create_offset:
            offset = hk.get_parameter("offset",
                                      param_shape,
                                      inputs.dtype,
                                      init=self.offset_init)
        elif offset is None:
            offset = np.array(0., dtype=inputs.dtype)

        scale = jnp.broadcast_to(scale, inputs.shape)
        offset = jnp.broadcast_to(offset, inputs.shape)
        mean = jnp.broadcast_to(mean, inputs.shape)

        eps = jax.lax.convert_element_type(self.eps, variance.dtype)
        inv = scale * jax.lax.rsqrt(variance + eps)
        return inv * (inputs - mean) + offset
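The returned value is the usual layer-norm expression,

$$ y = \text{scale}\cdot\frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \text{offset}, $$

with $\mu$ and $\sigma^2$ computed by `jnp.mean` and `jnp.var` over the configured axes.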
Code Example #22
def normal_equation_reg(X, Y, W, reg_scale):
    # Expand the per-point weights into diagonal weight matrices.
    W = jnp.expand_dims(W, axis=-1)\
        *jnp.expand_dims(jnp.expand_dims(jnp.eye(W.shape[-1]), 0), 0)
    X = jnp.expand_dims(X, 1)

    # Normal matrix X^T W X.
    norm = jnp.matmul(X.transpose((0, 1, 3, 2)), jnp.matmul(W, X))

    # Ridge regularizer (identity), with strength scaled by the target variance.
    reg_overlap = jnp.expand_dims(jnp.expand_dims(jnp.eye(X.shape[-1]), 0), 0)
    reg_scale = jnp.expand_dims(
        jnp.expand_dims(reg_scale / jnp.var(Y, axis=-1), -1), -1)

    denom = jnp.linalg.inv(norm + reg_scale * reg_overlap)

    # Right-hand side X^T W Y.
    Y = jnp.expand_dims(Y, axis=-1)
    numerator = jnp.matmul(X.transpose((0, 1, 3, 2)), jnp.matmul(W, Y))
    fit = jnp.matmul(denom, numerator)

    return jnp.reshape(fit, list(fit.shape)[:-1])
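This is the closed-form ridge-regularized weighted least-squares (normal-equation) solution,

$$ \hat\theta = \big(X^\top W X + \lambda I\big)^{-1} X^\top W Y, \qquad \lambda = \frac{\text{reg\_scale}}{\mathrm{Var}(Y)}, $$

applied independently over the two leading batch dimensions.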
Code Example #23
def parametric(subposteriors, diagonal=False):
    """
    Merges subposteriors following (embarrassingly parallel) parametric Monte Carlo algorithm.

    **References:**

    1. *Asymptotically Exact, Embarrassingly Parallel MCMC*,
       Willie Neiswanger, Chong Wang, Eric Xing

    :param list subposteriors: a list in which each element is a collection of samples.
    :param bool diagonal: whether to compute weights using variance or covariance, defaults to
        `False` (using covariance).
    :return: the estimated mean and variance/covariance parameters of the joined posterior
    """
    joined_subposteriors = tree_multimap(lambda *args: np.stack(args), *subposteriors)
    joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(joined_subposteriors)

    submeans = np.mean(joined_subposteriors, axis=1)
    if diagonal:
        # NB: jax.numpy.var does not support ddof=1, so we do it manually
        weights = vmap(lambda x: 1 / np.var(x, ddof=1, axis=0))(joined_subposteriors)
        var = 1 / np.sum(weights, axis=0)
        normalized_weights = var * weights

        # comparing to consensus implementation, we compute weighted mean here
        mean = np.einsum('ij,ij->j', normalized_weights, submeans)
        return mean, var
    else:
        weights = vmap(lambda x: np.linalg.inv(np.cov(x.T)))(joined_subposteriors)
        cov = np.linalg.inv(np.sum(weights, axis=0))
        normalized_weights = np.matmul(cov, weights)

        # comparing to consensus implementation, we compute weighted mean here
        mean = np.einsum('ijk,ik->j', normalized_weights, submeans)
        return mean, cov
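In the dense branch this is the parametric (Gaussian) combination rule from Neiswanger et al.: each subposterior contributes its precision, and

$$ \Sigma = \Big(\sum_k \Sigma_k^{-1}\Big)^{-1}, \qquad \mu = \Sigma \sum_k \Sigma_k^{-1}\,\mu_k, $$

where $\Sigma_k$ and $\mu_k$ are the sample covariance and mean of subposterior $k$; the diagonal branch applies the same rule with per-coordinate variances.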
Code Example #24
File: nfnet.py Project: isseebx123/deepmind-research
 def __call__(self, x, is_training):
     out = self.activation(x) * self.beta
     if self.stride > 1:  # Average-pool downsample. This is the transition-block path; when does self.stride > 1 actually apply?
         shortcut = hk.avg_pool(out,
                                window_shape=(1, 2, 2, 1),
                                strides=(1, 2, 2, 1),
                                padding='SAME')
         if self.use_projection:
             shortcut = self.conv_shortcut(shortcut)
     elif self.use_projection:
         shortcut = self.conv_shortcut(out)
     else:
         shortcut = x
     out = self.conv0(out)  # 1x1
     out = self.conv1(self.activation(out))  # 3x3
     if self.use_two_convs:
         out = self.conv1b(self.activation(out))  # 3x3
     out = self.conv2(self.activation(out))  # 1x1
     out = (
         self.se(out) * 2
     ) * out  # Multiply by 2 for rescaling; supposedly some paper found that doubling here works better(?)
     # Get average residual standard deviation for reporting metrics.
     res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
     # Apply stochdepth if applicable.
     if self._has_stochdepth:
         out = self.stoch_depth(out, is_training)
     # SkipInit Gain
     out = out * hk.get_parameter(
         'skip_gain', (), out.dtype, init=jnp.zeros)
     return out * self.alpha + shortcut, res_avg_var
Code Example #25
    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        mean = jnp.mean(inputs, axis=-1, keepdims=True)
        variance = jnp.var(inputs, axis=-1, keepdims=True)

        param_shape = inputs.shape[-1:]
        scale = hk.get_parameter("scale",
                                 param_shape,
                                 inputs.dtype,
                                 init=jnp.ones)
        scale = jax.lax.all_gather(scale, "shard")[0]

        offset = hk.get_parameter("offset",
                                  param_shape,
                                  inputs.dtype,
                                  init=jnp.zeros)
        offset = jax.lax.all_gather(offset, "shard")[0]

        scale = jnp.broadcast_to(scale, inputs.shape)
        offset = jnp.broadcast_to(offset, inputs.shape)
        mean = jnp.broadcast_to(mean, inputs.shape)

        inv = scale * jax.lax.rsqrt(variance + 1e-5)
        if self.offset:
            return inv * (inputs - mean) + offset
        else:
            return inv * (inputs - mean)
Code Example #26
    def __call__(self, x, is_training):
        bias1a = hk.get_parameter('bias1a', (), x.dtype, init=jnp.zeros)
        bias1b = hk.get_parameter('bias1b', (), x.dtype, init=jnp.zeros)
        bias2a = hk.get_parameter('bias2a', (), x.dtype, init=jnp.zeros)
        bias2b = hk.get_parameter('bias2b', (), x.dtype, init=jnp.zeros)
        bias3a = hk.get_parameter('bias3a', (), x.dtype, init=jnp.zeros)
        bias3b = hk.get_parameter('bias3b', (), x.dtype, init=jnp.zeros)
        scale = hk.get_parameter('scale', (), x.dtype, init=jnp.ones)

        out = x + bias1a
        shortcut = out
        if self.use_projection:  # Downsample with conv1x1
            shortcut = self.conv_shortcut(shortcut)
        out = self.conv0(out)
        out = self.activation(out + bias1b)
        out = self.conv1(out + bias2a)
        out = self.activation(out + bias2b)
        out = self.conv2(out + bias3a)
        out = out * scale + bias3b
        # Get average residual variance for reporting metrics.
        res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
        # Apply stochdepth if applicable.
        if self._has_stochdepth:
            out = self.stoch_depth(out, is_training)
        # SkipInit Gain
        out = out + shortcut
        return self.activation(out), res_avg_var
Code Example #27
 def apply_fun(params, inputs):
     beta, gamma = params
     mu, var = jnp.mean(inputs, axis=(0, 1, 2)), jnp.var(inputs,
                                                         axis=(0, 1, 2))
     lnorm = (jnp.reshape(mu, (1, 1, 1, -1)) -
              inputs) / jnp.sqrt(jnp.reshape(var, (1, 1, 1, -1)) + 1e-5)
     return jnp.reshape(gamma, (1, 1, 1, -1)) * lnorm + jnp.reshape(
         beta, (1, 1, 1, -1))
Code Example #28
    def __call__(self, x):
        means = jnp.mean(x, axis=(1, 2))
        m = jnp.mean(means, axis=-1, keepdims=True)
        v = jnp.var(means, axis=-1, keepdims=True)
        means_plus = (means - m) / jnp.sqrt(v + 1e-5)

        h = (x - means[:, None, None, :]
             ) / jnp.sqrt(jnp.var(x, axis=(1, 2), keepdims=True) + 1e-5)

        h = h + means_plus[:, None, None, :] * self.param(
            'alpha', InstanceNorm2dPlus.scale_init, (1, 1, 1, x.shape[-1]))
        h = h * self.param('gamma', InstanceNorm2dPlus.scale_init,
                           (1, 1, 1, x.shape[-1]))
        if self.bias:
            h = h + self.param('beta', init.zeros, (1, 1, 1, x.shape[-1]))

        return h
Code Example #29
File: test_hmc_gibbs.py Project: mjbajwa/numpyro
def test_discrete_gibbs_gmm_1d(modified, kernel, inner_kernel, kwargs):
    def model(probs, locs):
        c = numpyro.sample("c", dist.Categorical(probs))
        numpyro.sample("x", dist.Normal(locs[c], 0.5))

    probs = jnp.array([0.15, 0.3, 0.3, 0.25])
    locs = jnp.array([-2, 0, 2, 4])
    sampler = kernel(inner_kernel(model, trajectory_length=1.2),
                     modified=modified,
                     **kwargs)
    mcmc = MCMC(sampler, 1000, 200000, progress_bar=False)
    mcmc.run(random.PRNGKey(0), probs, locs)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["x"]), 1.3, atol=0.1)
    assert_allclose(jnp.var(samples["x"]), 4.36, atol=0.4)
    assert_allclose(jnp.mean(samples["c"]), 1.65, atol=0.1)
    assert_allclose(jnp.var(samples["c"]), 1.03, atol=0.1)
Code Example #30
def signal_metrics(x, i):
    """Things to measure about a NCHW tensor activation."""
    metrics = {}
    # Average channel-wise mean-squared
    metrics[f'avg_sq_mean_{i}'] = jnp.mean(jnp.mean(x, axis=[0, 1, 2])**2)
    # Average channel variance
    metrics[f'avg_var_{i}'] = jnp.mean(jnp.var(x, axis=[0, 1, 2]))
    return metrics