Example #1
    def test_cnn_sparse_init_kaiming(self):
        """Checks kaiming normal sparse initialization for convolutional layer."""
        _, initial_params = MaskedCNN.init_by_shape(self._rng,
                                                    (self._input_shape, ))
        self._unmasked_model = flax.nn.Model(MaskedCNN, initial_params)

        mask = masked.simple_mask(self._unmasked_model, jnp.ones,
                                  masked.WEIGHT_PARAM_NAMES)

        _, initial_params = MaskedCNNSparseInit.init_by_shape(
            jax.random.PRNGKey(42), (self._input_shape, ), mask=mask)
        self._masked_model_sparse_init = flax.nn.Model(MaskedCNNSparseInit,
                                                       initial_params)

        mean_init = jnp.mean(self._unmasked_model.params['MaskedModule_0']
                             ['unmasked']['kernel'])

        stddev_init = jnp.std(self._unmasked_model.params['MaskedModule_0']
                              ['unmasked']['kernel'])

        mean_sparse_init = jnp.mean(
            self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
            ['kernel'])

        stddev_sparse_init = jnp.std(
            self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
            ['kernel'])

        with self.subTest(name='test_cnn_sparse_init_mean'):
            self.assertBetween(mean_sparse_init, mean_init - 2 * stddev_init,
                               mean_init + 2 * stddev_init)

        with self.subTest(name='test_cnn_sparse_init_stddev'):
            self.assertBetween(stddev_sparse_init, 0.5 * stddev_init,
                               1.5 * stddev_init)
Example #2
def test_get_proposal_loc_and_scale(dense_mass):
    N = 10
    dim = 3
    samples = random.normal(random.PRNGKey(0), (N, dim))
    loc = np.mean(samples[:-1], 0)
    if dense_mass:
        scale = np.linalg.cholesky(
            np.cov(samples[:-1], rowvar=False, bias=True))
    else:
        scale = np.std(samples[:-1], 0)
    actual_loc, actual_scale = _get_proposal_loc_and_scale(
        samples[:-1], loc, scale, samples[-1])
    expected_loc, expected_scale = [], []
    for i in range(N - 1):
        samples_i = onp.delete(samples, i, axis=0)
        expected_loc.append(np.mean(samples_i, 0))
        if dense_mass:
            expected_scale.append(
                np.linalg.cholesky(np.cov(samples_i, rowvar=False, bias=True)))
        else:
            expected_scale.append(np.std(samples_i, 0))
    expected_loc = np.stack(expected_loc)
    expected_scale = np.stack(expected_scale)
    assert_allclose(actual_loc, expected_loc, rtol=1e-4)
    assert_allclose(actual_scale, expected_scale, atol=1e-6, rtol=0.05)
Example #3
 def initialize(self, normalization='log_return'):
     """
     Description: Check if data exists, else download, clean, and setup.
     Args:
          normalization (str/None): if None, no data normalization; if 'log_return',
              return log(x_t / x_(t-1)); if 'return', return (x_t - x_(t-1)) / x_(t-1).
     Returns:
         The first S&P 500 value
     """
     self.initialized = True
     self.has_regressors = False
     self.normalization = normalization
      if normalization is not None:
         assert normalization in [
             'return', 'log_return'
         ], "normalization must be either None, return, or log_return"
     self.T = 0
     df = sp500()  # get data
     self.max_T = df.shape[0]
     data = (df['value'].values.tolist())
     if normalization == 'return':
         data = np.array([(data[i + 1] - data[i]) / data[i]
                          for i in range(len(data) - 1)])
         self.std = np.std(data)
         data /= self.std
     elif normalization == 'log_return':
         data = np.array(
             [np.log(data[i + 1] / data[i]) for i in range(len(data) - 1)])
         self.std = np.std(data)
         data /= self.std
     else:
         data = np.array(data)
         self.std = np.std(data)
     self.data = data
     return self.data[self.T]
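The docstring above defines the two normalization modes; the following standalone sketch (plain NumPy, toy prices, not part of the original class) shows the quantities that initialize computes before dividing each series by its standard deviation.

import numpy as np

prices = [100.0, 101.0, 99.0, 102.0]  # hypothetical S&P 500 values

# 'return': (x_t - x_(t-1)) / x_(t-1)
returns = np.array([(prices[i + 1] - prices[i]) / prices[i]
                    for i in range(len(prices) - 1)])
returns /= np.std(returns)

# 'log_return': log(x_t / x_(t-1))
log_returns = np.array([np.log(prices[i + 1] / prices[i])
                        for i in range(len(prices) - 1)])
log_returns /= np.std(log_returns)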
Example #4
def test_mean_var(jax_dist, sp_dist, params):
    n = 20000 if jax_dist in [dist.LKJ, dist.LKJCholesky] else 200000
    d_jax = jax_dist(*params)
    k = random.PRNGKey(0)
    samples = d_jax.sample(k, sample_shape=(n,))
    # check with suitable scipy implementation if available
    if sp_dist and not _is_batched_multivariate(d_jax):
        d_sp = sp_dist(*params)
        try:
            sp_mean = d_sp.mean()
        except TypeError:  # mvn does not have .mean() method
            sp_mean = d_sp.mean
        # for multivariate distns try .cov first
        if d_jax.event_shape:
            try:
                sp_var = np.diag(d_sp.cov())
            except TypeError:  # mvn does not have .cov() method
                sp_var = np.diag(d_sp.cov)
            except AttributeError:
                sp_var = d_sp.var()
        else:
            sp_var = d_sp.var()
        assert_allclose(d_jax.mean, sp_mean, rtol=0.01, atol=1e-7)
        assert_allclose(d_jax.variance, sp_var, rtol=0.01, atol=1e-7)
        if np.all(np.isfinite(sp_mean)):
            assert_allclose(np.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if np.all(np.isfinite(sp_var)):
            assert_allclose(np.std(samples, 0), np.sqrt(d_jax.variance), rtol=0.05, atol=1e-2)
    elif jax_dist in [dist.LKJ, dist.LKJCholesky]:
        if jax_dist is dist.LKJCholesky:
            corr_samples = np.matmul(samples, np.swapaxes(samples, -2, -1))
        else:
            corr_samples = samples
        dimension, concentration, _ = params
        # marginal of off-diagonal entries
        marginal = dist.Beta(concentration + 0.5 * (dimension - 2),
                             concentration + 0.5 * (dimension - 2))
        # scale statistics due to linear mapping
        marginal_mean = 2 * marginal.mean - 1
        marginal_std = 2 * np.sqrt(marginal.variance)
        expected_mean = np.broadcast_to(np.reshape(marginal_mean, np.shape(marginal_mean) + (1, 1)),
                                        np.shape(marginal_mean) + d_jax.event_shape)
        expected_std = np.broadcast_to(np.reshape(marginal_std, np.shape(marginal_std) + (1, 1)),
                                       np.shape(marginal_std) + d_jax.event_shape)
        # diagonal elements of correlation matrices are 1
        expected_mean = expected_mean * (1 - np.identity(dimension)) + np.identity(dimension)
        expected_std = expected_std * (1 - np.identity(dimension))

        assert_allclose(np.mean(corr_samples, axis=0), expected_mean, atol=0.01)
        assert_allclose(np.std(corr_samples, axis=0), expected_std, atol=0.01)
    else:
        if np.all(np.isfinite(d_jax.mean)):
            assert_allclose(np.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if np.all(np.isfinite(d_jax.variance)):
            assert_allclose(np.std(samples, 0), np.sqrt(d_jax.variance), rtol=0.05, atol=1e-2)
Example #5
    def standardize(self, data, time_series=False):
        if time_series:
            self.mean = np.mean(data, axis=1)
            self.stddev = np.std(data, axis=1)
            norm_data = (data - np.expand_dims(
                self.mean, axis=1)) / np.expand_dims(self.stddev, axis=1)
        else:
            self.mean = np.mean(data, axis=0)
            self.stddev = np.std(data, axis=0)
            norm_data = (data - self.mean) / self.stddev

        return norm_data
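A quick standalone check of the non-time-series branch above (plain NumPy, synthetic data): after subtracting the per-feature mean and dividing by the per-feature standard deviation, every column has zero mean and unit standard deviation.

import numpy as np

data = np.random.default_rng(0).normal(loc=3.0, scale=2.0, size=(100, 4))
mean = np.mean(data, axis=0)
stddev = np.std(data, axis=0)
norm_data = (data - mean) / stddev

assert np.allclose(np.mean(norm_data, axis=0), 0.0, atol=1e-12)
assert np.allclose(np.std(norm_data, axis=0), 1.0, atol=1e-12)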
Example #6
  def test_resnet_imagenet(self):
    rng = random.PRNGKey(10)
    key1, key2 = random.split(rng)
    x = random.normal(key1, (128, 32, 32, 3))
    activation_f = 'bias_scale_SELU_norm'

    model_def = wideresnet.ResNetImageNet50.partial(
        num_classes=1000,
        activation_f=activation_f,
        normalization='none',
        std_penalty_mult=0,
        use_residual=2,
        bias_scale=0.0,
        weight_norm='fixed',
        softplus_scale=1,
        compensate_padding=True,
        no_head=True,
    )
    (y, _, metrics), _ = model_def.create(
        key2,
        x,
        train=True,
    )
    mean = jnp.mean(y, axis=(0, 1, 2))
    std = jnp.std(y, axis=(0, 1, 2))
    mean_x = jnp.mean(x, axis=(0, 1, 2))
    std_x = jnp.std(x, axis=(0, 1, 2))

    onp.testing.assert_allclose(mean_x, jnp.zeros_like(mean_x), atol=0.1)
    onp.testing.assert_allclose(std_x, jnp.ones_like(std_x), atol=0.1)

    for metric_key, metric_value in metrics.items():
      if 'postnorm' in metric_key or 'postact' in metric_key or 'postres' in metric_key:
        if 'std' in metric_key:
          onp.testing.assert_allclose(
              metric_value,
              jnp.ones_like(metric_value),
              atol=0.1,
              err_msg=metric_key)
        elif 'mean' in metric_key:
          onp.testing.assert_allclose(
              metric_value,
              jnp.zeros_like(metric_value),
              atol=0.1,
              err_msg=metric_key)

    onp.testing.assert_allclose(std, jnp.ones_like(std), atol=0.4)
    onp.testing.assert_allclose(mean, jnp.zeros_like(mean), atol=0.6)
Example #7
  def test_resnetv1(self):
    rng = random.PRNGKey(10)
    key1, key2 = random.split(rng)
    x = random.normal(key1, (128, 32, 32, 3))

    activation_f = 'bias_scale_SELU_norm'
    model_def = wideresnet.ResnetV1.partial(
        depth=20,
        num_outputs=10,
        activation_f=activation_f,
        normalization='none',
        dropout_rate=0,
        std_penalty_mult=0,
        use_residual=2,  # TODO(basv): test with residual.
        bias_scale=0.0,
        weight_norm='none',
        no_head=True,
        report_metrics=True,
    )
    (y, _, metrics), _ = model_def.create(
        key2,
        x,
    )
    mean = jnp.mean(y, axis=(0, 1, 2))
    std = jnp.std(y, axis=(0, 1, 2))
    mean_x = jnp.mean(x, axis=(0, 1, 2))
    std_x = jnp.std(x, axis=(0, 1, 2))

    onp.testing.assert_allclose(mean_x, jnp.zeros_like(mean_x), atol=0.1)
    onp.testing.assert_allclose(std_x, jnp.ones_like(std_x), atol=0.1)

    for metric_key, metric_value in metrics.items():
      if 'postnorm' in metric_key or 'postact' in metric_key or 'postres' in metric_key:
        if 'std' in metric_key:
          onp.testing.assert_allclose(
              metric_value,
              jnp.ones_like(metric_value),
              atol=0.1,
              err_msg=metric_key)
        elif 'mean' in metric_key:
          onp.testing.assert_allclose(
              metric_value,
              jnp.zeros_like(metric_value),
              atol=0.3,
              err_msg=metric_key)

    onp.testing.assert_allclose(mean, jnp.zeros_like(mean), atol=0.2)
    onp.testing.assert_allclose(std, jnp.ones_like(std), atol=0.3)
Example #8
    def test_wrn26_4(self):
        rng = random.PRNGKey(10)
        key1, key2 = random.split(rng)
        x = random.normal(key1, (128, 32, 32, 3))

        for activation_f in ['bias_scale_SELU_norm']:
            model_def = wideresnet.WideResnet.partial(
                blocks_per_group=4,
                channel_multiplier=4,
                num_outputs=10,
                activation_f=activation_f,
                normalization='none',
                dropout_rate=0,
                std_penalty_mult=0,
                use_residual=2,  # TODO(basv): test with residual.
                bias_scale=0.0,
                weight_norm='learned',
                no_head=True,
            )
            (y, _, metrics), _ = model_def.create(
                key2,
                x,
            )
            mean = jnp.mean(jnp.abs(jnp.mean(y, axis=(0, 1, 2))))
            std = jnp.mean(jnp.std(y, axis=(0, 1, 2)))
            mean_x = jnp.mean(x, axis=(0, 1, 2))
            std_x = jnp.std(x, axis=(0, 1, 2))

            onp.testing.assert_allclose(mean_x,
                                        jnp.zeros_like(mean_x),
                                        atol=0.1)
            onp.testing.assert_allclose(std_x, jnp.ones_like(std_x), atol=0.1)

            for metric_key, metric_value in metrics.items():
                if 'postnorm' in metric_key or 'postact' in metric_key or 'postres' in metric_key:
                    if 'std' in metric_key:
                        onp.testing.assert_allclose(
                            metric_value,
                            jnp.ones_like(metric_value),
                            atol=0.2,
                            err_msg=metric_key)
                    elif 'mean' in metric_key:
                        onp.testing.assert_allclose(
                            metric_value,
                            jnp.zeros_like(metric_value),
                            atol=0.2,
                            err_msg=metric_key)

            onp.testing.assert_allclose(mean, jnp.zeros_like(mean), atol=0.1)
            onp.testing.assert_allclose(std, jnp.ones_like(std), atol=0.1)
Example #9
def eval_log_prob(model, data, energies, F, T):
    # If model(data) equals the target log-probability -energies/T - F, then
    # model(data) + energies/T is the constant -F, so its std measures the misfit.
    # (Earlier variant: jnp.linalg.norm(model(data) - (-energies/T - F)) / data.shape[0])
    logPM = model(data) + energies / T
    return jnp.std(logPM)
Example #10
 def test_overall_mean_variance(self):
     noise = OrnsteinUhlenbeckNoise(random_seed=13)
     x = jnp.stack([noise(0.) for _ in range(1000)])
     mu, sigma = jnp.mean(x), jnp.std(x)
     self.assertLess(abs(mu), noise.theta)
     self.assertGreater(sigma, noise.sigma)
     self.assertLess(sigma, noise.sigma * 2)
Example #11
    def sample_kernel(sa_state, model_args=(), model_kwargs=None):
        pe_fn = potential_fn
        if potential_fn_gen:
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        zs, pes, loc, scale = sa_state.adapt_state
        # we recompute loc/scale after each iteration to avoid precision loss
        # XXX: consider to expose a setting to do this job periodically
        # to save some computations
        loc = jnp.mean(zs, 0)
        if scale.ndim == 2:
            cov = jnp.cov(zs, rowvar=False, bias=True)
            if cov.shape == ():  # JAX returns scalar for 1D input
                cov = cov.reshape((1, 1))
            cholesky = jnp.linalg.cholesky(cov)
            scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
        else:
            scale = jnp.std(zs, 0)

        rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(sa_state.rng_key, 4)
        _, unravel_fn = ravel_pytree(sa_state.z)

        z = loc + _sample_proposal(scale, rng_key_z)
        pe = pe_fn(unravel_fn(z))
        pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
        diverging = (pe - sa_state.potential_energy) > max_delta_energy

        # NB: all terms having the pattern *s will have shape N x ...
        # and all terms having the pattern *s_ will have shape (N + 1) x ...
        locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
        zs_ = jnp.concatenate([zs, z[None, :]])
        pes_ = jnp.concatenate([pes, pe[None]])
        locs_ = jnp.concatenate([locs, loc[None, :]])
        scales_ = jnp.concatenate([scales, scale[None, ...]])
        if scale.ndim == 2:  # dense_mass
            log_weights_ = dist.MultivariateNormal(locs_, scale_tril=scales_).log_prob(zs_) + pes_
        else:
            log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
        # mask invalid values (nan, +inf) by -inf
        log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_, -jnp.inf)
        # get rejecting index
        j = random.categorical(rng_key_reject, log_weights_)
        zs = _numpy_delete(zs_, j)
        pes = _numpy_delete(pes_, j)
        loc = locs_[j]
        scale = scales_[j]
        adapt_state = SAAdaptState(zs, pes, loc, scale)

        # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
        itr = sa_state.i + 1
        n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n

        # XXX: we make a modification of SA sampler in [1]
        # in [1], each MCMC state contains N points `zs`
        # here we do resampling to pick randomly a point from those N points
        k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
        z = unravel_fn(zs[k])
        pe = pes[k]
        return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)
Example #12
def NES_profile_jax(params,
                    params_to_xL,
                    score_function,
                    npop=50,
                    sigma_noise=0.1,
                    alpha=0.05):
    """Natural Evolutionary strategy
  
  Args:
  		npop: population size
  		sigma: standard deviation
  		alpha: learning rate
  """
    def single_update(pi, ni):
        p_new = pi + sigma_noise * ni
        xL_new = params_to_xL(p_new)
        reward_new = score_function(xL=xL_new)
        return reward_new

    num_params = params.shape[0]
    xL = params_to_xL(params)
    N = np.array(onp.random.randn(npop, num_params))
    R = vmap(single_update, (None, 0), 0)(params, N)
    A = (R - np.mean(R)) / (np.std(R) + 1e-6)
    params_update = params - alpha / (npop * sigma_noise) * np.dot(N.T, A)
    return params_update
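A toy invocation of NES_profile_jax (every name below is hypothetical and only for illustration): with an identity params_to_xL and a simple quadratic objective, repeated updates move the parameters from the origin toward the all-ones minimizer, since the update above subtracts the estimated gradient of the score.

import numpy as onp
import jax.numpy as np  # the example uses np for jax.numpy and onp for NumPy

def params_to_xL(p):      # hypothetical mapping: identity
    return p

def score_function(xL):   # hypothetical objective that the update drives down
    return np.sum((xL - 1.0) ** 2)

params = np.zeros(5)
for _ in range(50):
    params = NES_profile_jax(params, params_to_xL, score_function)
# params has moved from the origin toward the all-ones minimizer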
Example #13
def main(args):
    _, fetch_train = load_dataset(UCBADMIT, split="train", shuffle=False)
    dept, male, applications, admit = fetch_train()
    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)
    pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male,
                                      applications)["probs"]
    header = "=" * 30 + "glmm - TRAIN" + "=" * 30
    print_results(header, pred_probs, dept, male, admit / applications)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    ax.plot(range(1, 13), admit / applications, "o", ms=7, label="actual rate")
    ax.errorbar(
        range(1, 13),
        jnp.mean(pred_probs, 0),
        jnp.std(pred_probs, 0),
        fmt="o",
        c="k",
        mfc="none",
        ms=7,
        elinewidth=1,
        label=r"mean $\pm$ std",
    )
    ax.plot(range(1, 13), jnp.percentile(pred_probs, 5, 0), "k+")
    ax.plot(range(1, 13), jnp.percentile(pred_probs, 95, 0), "k+")
    ax.set(
        xlabel="cases",
        ylabel="admit rate",
        title="Posterior Predictive Check with 90% CI",
    )
    ax.legend()

    plt.savefig("ucbadmit_plot.pdf")
Example #14
def get_norm(init_x):
    mean = jnp.mean(init_x, axis=0)
    std = jnp.std(init_x, axis=0)

    def norm(x):
        return (x - mean) / (std + 1e-5)
    return norm
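A quick sanity check of get_norm (synthetic data, for illustration only): the returned closure standardizes inputs with the statistics of init_x, so applying it back to init_x yields roughly zero-mean, unit-std columns.

import jax.numpy as jnp
from jax import random

init_x = 2.0 + 5.0 * random.normal(random.PRNGKey(0), (1000, 3))
norm = get_norm(init_x)
z = norm(init_x)
print(jnp.mean(z, axis=0))  # ~0
print(jnp.std(z, axis=0))   # ~1 (slightly below 1 because of the 1e-5 in the denominator)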
Example #15
def test_unnormalized_normal_chain(kernel, kwargs, num_chains):
    from numpyro.contrib.tfp import mcmc

    # TODO: remove when this issue is fixed upstream
    # https://github.com/tensorflow/probability/pull/1087
    if num_chains == 2 and kernel == "ReplicaExchangeMC":
        pytest.xfail(
            "ReplicaExchangeMC is not fully compatible with omnistaging yet.")

    kernel_class = getattr(mcmc, kernel)

    true_mean, true_std = 1., 0.5
    warmup_steps, num_samples = (1000, 8000)

    def potential_fn(z):
        return 0.5 * ((z - true_mean) / true_std)**2

    init_params = jnp.array(0.) if num_chains == 1 else jnp.array([0., 2.])
    tfp_kernel = kernel_class(potential_fn=potential_fn, **kwargs)
    mcmc = MCMC(tfp_kernel,
                warmup_steps,
                num_samples,
                num_chains=num_chains,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)
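The potential used above is, up to an additive constant, the negative log-density of Normal(true_mean, true_std); that is why the sampled mean and standard deviation are compared against those values. A small standalone check of that identity (assuming numpyro is available, as in the test):

import jax.numpy as jnp
import numpyro.distributions as dist

true_mean, true_std = 1., 0.5
z = jnp.linspace(-1., 3., 5)
potential = 0.5 * ((z - true_mean) / true_std) ** 2
log_prob = dist.Normal(true_mean, true_std).log_prob(z)
# potential + log_prob is constant in z: -log(true_std * sqrt(2 * pi))
print(potential + log_prob)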
Example #16
def get_toy_pairs(N=20,
                  S=2,
                  P=10,
                  sigma_obs=0.05,
                  active_pairs=[(0, 1), (1, 2)]):
    assert S < P and P > 1 and S > 0

    onp.random.seed(0)

    X = onp.random.randn(N, P)
    # generate S coefficients with non-negligible magnitude
    W = 0.5 + 2.5 * onp.random.rand(S)
    # generate data using the S coefficients and however many pairwise interactions

    Y = onp.sum(X[:, 0:S] * W, axis=-1)

    # now add in all pairwise interactions
    for pair in active_pairs:
        Y += X[:, pair[0]] * X[:, pair[1]]

    Y += sigma_obs * onp.random.randn(N)
    Y -= np.mean(Y)
    Y_std = np.std(Y)

    assert X.shape == (N, P)
    assert Y.shape == (N, )

    return X, Y / Y_std, W / Y_std, 1.0 / Y_std
Example #17
def sinkhorn_for_sort(inputs: jnp.ndarray, weights: jnp.ndarray,
                      target_weights: jnp.ndarray, sinkhorn_kw,
                      pointcloud_kw) -> jnp.ndarray:
    """Runs sinkhorn on a fixed increasing target.

  Args:
    inputs: jnp.ndarray[num_points]. Must be one dimensional.
    weights: jnp.ndarray[num_points]. The weights 'a' for the inputs.
    target_weights: jnp.ndarray[num_targets]: the weights of the targets. It may
      be of a different size than the weights.
    sinkhorn_kw: a dictionary holding the sinkhorn keyword arguments. See
      sinkhorn.py for more details.
    pointcloud_kw: a dictionary holding the keyword arguments of the
      PointCloud class. See pointcloud.py for more details.

  Returns:
    A jnp.ndarray<float> representing the transport matrix of the inputs onto
    the underlying sorted target.
  """
    shape = inputs.shape
    if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
        raise ValueError(
            f"Shape ({shape}) not supported. The input should be one-dimensional.")

    x = jnp.expand_dims(jnp.squeeze(inputs), axis=1)
    x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
    a = jnp.squeeze(weights)
    b = jnp.squeeze(target_weights)
    num_targets = b.shape[0]
    y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
    geom = pointcloud.PointCloud(x, y, **pointcloud_kw)
    res = sinkhorn.sinkhorn(geom, a, b, **sinkhorn_kw)
    return geom.transport_from_potentials(res.f, res.g)
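The two lines that rescale x above (standardize, then sigmoid) squash arbitrary inputs into (0, 1) so they live on the same segment as the evenly spaced targets y; the same transformation in isolation, on made-up values:

import jax
import jax.numpy as jnp

inputs = jnp.array([10.0, -3.0, 250.0, 42.0])
x = jnp.expand_dims(jnp.squeeze(inputs), axis=1)
x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
print(x.ravel())  # all values strictly between 0 and 1, order preserved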
Example #18
def ksd_squared_l(samples, logp, k, return_stddev=False):
    """
    O(n) time estimator for the KSD.
    Arguments:
    * samples: np.array of shape (n, d)
    * logp: callable
    * k: callable, computes scalar-valued kernel k(x, y) given two input arguments of shape (d,).

    Returns:
    * The square of the stein discrepancy KSD(q, p).
    KSD is approximated as $\sum_i g(x_i, y_i)$, where the x and y are iid distributed as q
    * The approximate variance of h(X, Y)
    """
    try:
        xs, ys = samples.split(2)
    except ValueError:  # uneven split
        xs, ys = samples[:-1].split(2)

    def h(x, y):
        """x, y: np.arrays of shape (d,)"""
        def inner(x):
            return stein_operator(lambda y_: k(x, y_), y, logp)

        return stein_operator(inner, x, logp, transposed=True)

    outs = vmap(h)(xs, ys)
    if return_stddev:
        return np.mean(outs), np.std(outs, ddof=1) / xs.shape[0]
    else:
        return np.mean(outs)
Example #19
            def log_s_shift_init(shape, dtype):
                if x.ndim == len(shape):
                    return jnp.zeros(shape, dtype)

                z = self.f(weight_logits, means, log_scales, x)
                axes = tuple(jnp.arange(len(z.shape) - len(shape)))
                return jnp.log(jnp.std(z, axis=axes) + 1e-5)
Example #20
    def test_weight_norm_standard(self):

        rng = random.PRNGKey(5)
        key1, key2 = random.split(rng)
        for k in [3, 5]:
            for padding in ['VALID', 'SAME']:
                for layer in [
                        conv_layers.Conv, conv_layers.ConvWS,
                        conv_layers.ConvFixedScale
                ]:
                    x = random.normal(key1, (512, 32, 32, 128))
                    y = x
                    for i in range(5):
                        y, _ = layer.create(
                            key2,
                            y,
                            features=128,
                            kernel_size=(k, k),
                            bias=False,
                            padding=padding,
                            kernel_init=jax.nn.initializers.lecun_normal())

                        mean = jnp.mean(y)
                        std = jnp.std(y)

                        err_msg = 'layer %s, padding %s, kernel_size %d, depth %d' % (
                            layer.__name__, padding, k, i)
                        onp.testing.assert_allclose(mean,
                                                    jnp.zeros_like(mean),
                                                    atol=0.1,
                                                    err_msg=err_msg)
                        onp.testing.assert_allclose(std,
                                                    jnp.ones_like(std),
                                                    atol=0.1,
                                                    err_msg=err_msg)
Example #21
            def model_update_minibatch(
                carry: Tuple[networks_lib.Params, optax.OptState],
                minibatch: Batch,
            ) -> Tuple[Tuple[networks_lib.Params, optax.OptState], Dict[
                    str, jnp.ndarray]]:
                """Performs model update for a single minibatch."""
                params, opt_state = carry
                # Normalize advantages at the minibatch level before using them.
                advantages = ((minibatch.advantages -
                               jnp.mean(minibatch.advantages, axis=0)) /
                              (jnp.std(minibatch.advantages, axis=0) + 1e-8))
                gradients, metrics = grad_fn(params, minibatch.observations,
                                             minibatch.actions,
                                             minibatch.behavior_log_probs,
                                             minibatch.target_values,
                                             advantages,
                                             minibatch.behavior_values)

                # Apply updates
                updates, opt_state = optimizer.update(gradients, opt_state)
                params = optax.apply_updates(params, updates)

                metrics['norm_grad'] = optax.global_norm(gradients)
                metrics['norm_updates'] = optax.global_norm(updates)
                return (params, opt_state), metrics
Example #22
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    true_mean, true_std = 1.0, 0.5
    num_warmup, num_samples = (100000, 100000) if kernel_cls is SA else (1000,
                                                                         8000)

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std)**2)

    init_params = jnp.array(0.0)
    if kernel_cls is SA:
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(potential_fn=potential_fn,
                            trajectory_length=8,
                            dense_mass=dense_mass)
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)

    if "JAX_ENABLE_X64" in os.environ:
        assert hmc_states.dtype == jnp.float64
Example #23
def test_klee_measure():
    from jax import random, jit, disable_jit
    import pylab as plt
    N, D = 2, 2
    points = random.uniform(random.PRNGKey(1), shape=(N, D))
    # points = jnp.array([[0., 1.],[0., 0.]])
    eps = 0.1
    gamma = 0.90
    for w in jnp.linspace(0., 1., 10):
        true_volume = 2. * w**2 - cubes_intersect_volume(
            points[0, :], points[1, :], w)
        vol = jnp.exp(
            jit(
                vmap(lambda key: log_klee_measure(
                    key, points, w, eps=eps, gamma=gamma)))(random.split(
                        random.PRNGKey(0), 100)))
        print(jnp.mean(vol), jnp.std(vol), true_volume)
        eps_bound = jnp.mean((vol <= true_volume * (1. + eps))
                             & (vol >= true_volume * (1. - eps)))
        l = w / 2.
        plt.scatter(points[:, 0], points[:, 1])
        for i in range(N):
            plt.plot([
                points[i, 0] - l, points[i, 0] + l, points[i, 0] + l,
                points[i, 0] - l, points[i, 0] - l
            ], [
                points[i, 1] - l, points[i, 1] - l, points[i, 1] + l,
                points[i, 1] + l, points[i, 1] - l
            ],
                     c='black')
        plt.title("prob_bound {}".format(eps_bound, true_volume))
        plt.show()
Example #24
    def get_opd(self, wave):
        """
        Parameters
        ----------
        wave : morphine.Wavefront (or float)
            Incoming Wavefront before this optic to set wavelength and
            scale, or a float giving the wavelength in meters
            for a temporary Wavefront used to compute the OPD.
        """
        y, x = self.get_coordinates(wave)
        rho, theta = _wave_y_x_to_rho_theta(y, x, self.radius)
        psd = np.power(rho, -self.index)  # generate power-law PSD

        np.random.seed(
            self.seed)  # if provided, set a seed for random number generator
        rndm_phase = np.random.normal(
            size=(len(y), len(x)))  # generate random phase screen
        rndm_psd = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(
            rndm_phase)))  # FT of random phase screen to get random PSD
        scaled = np.sqrt(psd) * rndm_psd  # scale random PSD by power-law PSD
        phase_screen = np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(
            scaled))).real  # FT of scaled random PSD makes phase screen

        phase_screen -= np.mean(phase_screen)  # force zero-mean
        opd = phase_screen / np.std(
            phase_screen) * self.wfe  # normalize to wanted input rms wfe

        return opd
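The last two statements of get_opd implement "remove the mean, then scale to a target rms"; a minimal standalone sketch of that normalization with a hypothetical wavefront-error target:

import numpy as np

phase_screen = np.random.default_rng(0).normal(size=(64, 64))
wfe = 30e-9  # hypothetical target rms wavefront error, in meters

phase_screen -= np.mean(phase_screen)            # force zero mean
opd = phase_screen / np.std(phase_screen) * wfe  # rescale so std(opd) == wfe
assert np.isclose(np.std(opd), wfe)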
Example #25
def transport_for_sort(inputs: jnp.ndarray, weights: jnp.ndarray,
                       target_weights: jnp.ndarray, kwargs) -> jnp.ndarray:
    """Runs sinkhorn on a fixed increasing target.

  Args:
    inputs: jnp.ndarray[num_points]. Must be one dimensional.
    weights: jnp.ndarray[num_points]. The weights 'a' for the inputs.
    target_weights: jnp.ndarray[num_targets]: the weights of the targets. It may
      be of a different size than the weights.
    kwargs: a dictionary holding the sinkhorn keyword arguments and the
      pointcloud argument.

  Returns:
    A jnp.ndarray<float> representing the transport matrix of the inputs onto
    the underlying sorted target.
  """
    shape = inputs.shape
    if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
        raise ValueError(
            f'Shape ({shape}) not supported. The input should be one-dimensional.')

    x = jnp.expand_dims(jnp.squeeze(inputs), axis=1)
    x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
    a = jnp.squeeze(weights)
    b = jnp.squeeze(target_weights)
    num_targets = b.shape[0]
    y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
    return transport.Transport(x, y, a=a, b=b, **kwargs)
Example #26
def ppo_loss_given_predictions(log_probab_actions_new,
                               log_probab_actions_old,
                               value_predictions_old,
                               padded_actions,
                               padded_rewards,
                               reward_mask,
                               gamma=0.99,
                               lambda_=0.95,
                               epsilon=0.2):
  """PPO objective, with an eventual minus sign, given predictions."""
  B, T = padded_rewards.shape  # pylint: disable=invalid-name
  _, _, C, A = log_probab_actions_old.shape  # pylint: disable=invalid-name

  assert (B, T) == padded_rewards.shape
  assert (B, T, C) == padded_actions.shape
  assert (B, T) == reward_mask.shape

  assert (B, T + 1, 1) == value_predictions_old.shape
  assert (B, T + 1, C, A) == log_probab_actions_old.shape
  assert (B, T + 1, C, A) == log_probab_actions_new.shape

  # (B, T)
  td_deltas = deltas(
      np.squeeze(value_predictions_old, axis=2),  # (B, T+1)
      padded_rewards,
      reward_mask,
      gamma=gamma)

  # (B, T)
  advantages = gae_advantages(
      td_deltas, reward_mask, lambda_=lambda_, gamma=gamma)

  # Normalize the advantages.
  advantage_mean = np.mean(advantages)
  advantage_std = np.std(advantages)
  advantages = (advantages - advantage_mean) / (advantage_std + 1e-8)

  # (B, T)
  ratios = compute_probab_ratios(log_probab_actions_new, log_probab_actions_old,
                                 padded_actions, reward_mask)
  assert (B, T, C) == ratios.shape

  # (B, T)
  objective = clipped_objective(
      ratios, advantages, reward_mask, epsilon=epsilon)
  assert (B, T, C) == objective.shape

  # ()
  average_objective = np.sum(objective) / np.sum(reward_mask)

  # Loss is negative objective.
  ppo_loss = -average_objective

  summaries = {
      "ppo_loss": ppo_loss,
      "advantage_mean": advantage_mean,
      "advantage_std": advantage_std,
  }

  return (ppo_loss, summaries)
Example #27
def render_rays_fine(rays,
                     z_vals,
                     weights,
                     num_importance,
                     perturbation=True,
                     rng=None):
    """Render rays for the fine model.
    Args:
        rays: (2, num_rays, 3) origin and direction generated rays
        z_vals: (num_rays, num_samples) depths of the sampled positions
        weights: (num_rays, num_samples) weights assigned to each sampled color for the coarse model
        num_importance: number of samples used in the fine model
        perturbation: whether to apply jitter on each ray or not
        rng: random key
    Returns:
        pts: (num_rays, num_samples + num_importance, 3) points in space to evaluate model at
        z_vals: (num_rays, num_samples + num_importance) depths of the sampled positions
        z_samples: (num_rays) standard deviation of distances along ray for each sample
    """
    rays_o, rays_d = rays

    z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
    z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], num_importance,
                           perturbation, rng)
    z_samples = lax.stop_gradient(z_samples)

    # obtain all points to evaluate color density at
    z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
    z_vals = z_vals.astype(rays_d.dtype)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    return pts, z_vals, jnp.std(z_samples, axis=-1)
Example #28
  def test_nested_normalize(self):
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    })

    x1 = {
        'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5),
        'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2)
    }
    x2 = {
        'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5) + 20,
        'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) + 8
    }
    x3 = {
        'a': jnp.arange(40, dtype=jnp.float32).reshape(4, 2, 5),
        'b': jnp.arange(16, dtype=jnp.float32).reshape(4, 2, 2)
    }

    state = update_and_validate(state, x1)
    state = update_and_validate(state, x2)
    state = update_and_validate(state, x3)
    normalized = running_statistics.normalize(x3, state)

    mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized)
    std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.zeros_like(x)),
        mean)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.ones_like(x)),
        std)
Example #29
  def test_pmap_update_nested(self):
    local_device_count = jax.local_device_count()
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    })

    x = {
        'a': (jnp.arange(15 * local_device_count,
                         dtype=jnp.float32)).reshape(local_device_count, 3, 5),
        'b': (jnp.arange(6 * local_device_count,
                         dtype=jnp.float32)).reshape(local_device_count, 3, 2),
    }

    devices = jax.local_devices()
    state = jax.device_put_replicated(state, devices)
    pmap_axis_name = 'i'
    state = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
        pmap_axis_name)(state, x)
    state = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
        pmap_axis_name)(state, x)
    normalized = jax.pmap(running_statistics.normalize)(x, state)

    mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized)
    std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.ones_like(x)), std)
Example #30
    def __call__(self, x, sigmas, train=True):
        # per image standardization
        N = np.prod(x.shape[1:])
        x = (x - jnp.mean(x, axis=(1, 2, 3), keepdims=True)) / jnp.maximum(
            jnp.std(x, axis=(1, 2, 3), keepdims=True), 1. / np.sqrt(N))
        temb = GaussianFourierProjection(embedding_size=128,
                                         scale=16)(jnp.log(sigmas))
        temb = nn.Dense(128 * 4)(temb)
        temb = nn.Dense(128 * 4)(nn.swish(temb))

        x = nn.Conv(16, (3, 3),
                    padding='SAME',
                    name='init_conv',
                    kernel_init=conv_kernel_init_fn,
                    use_bias=False)(x)
        x = WideResnetGroup(self.blocks_per_group,
                            16 * self.channel_multiplier,
                            activate_before_residual=True)(x, temb, train)
        x = WideResnetGroup(self.blocks_per_group,
                            32 * self.channel_multiplier, (2, 2))(x, temb,
                                                                  train)
        x = WideResnetGroup(self.blocks_per_group,
                            64 * self.channel_multiplier, (2, 2))(x, temb,
                                                                  train)
        x = activation(x, train=train, name='pre-pool-bn')
        x = nn.avg_pool(x, x.shape[1:3])
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(self.num_outputs, kernel_init=dense_layer_init_fn)(x)
        return x
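The first two lines of __call__ above perform per-image standardization with a 1/sqrt(N) floor on the standard deviation (the same convention as TensorFlow's per_image_standardization); a standalone sketch on a batch of NHWC images, including the degenerate constant-image case the floor protects against:

import numpy as np
import jax.numpy as jnp

x = jnp.ones((8, 32, 32, 3))  # constant images: std would be 0 without the floor
N = np.prod(x.shape[1:])
x = (x - jnp.mean(x, axis=(1, 2, 3), keepdims=True)) / jnp.maximum(
    jnp.std(x, axis=(1, 2, 3), keepdims=True), 1. / np.sqrt(N))
# x is finite (all zeros here) instead of NaN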