def test_cnn_sparse_init_kaiming(self):
    """Checks kaiming normal sparse initialization for convolutional layer."""
    _, initial_params = MaskedCNN.init_by_shape(self._rng, (self._input_shape,))
    self._unmasked_model = flax.nn.Model(MaskedCNN, initial_params)

    mask = masked.simple_mask(self._unmasked_model, jnp.ones,
                              masked.WEIGHT_PARAM_NAMES)

    _, initial_params = MaskedCNNSparseInit.init_by_shape(
        jax.random.PRNGKey(42), (self._input_shape,), mask=mask)
    self._masked_model_sparse_init = flax.nn.Model(MaskedCNNSparseInit,
                                                   initial_params)

    mean_init = jnp.mean(
        self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])
    stddev_init = jnp.std(
        self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])

    mean_sparse_init = jnp.mean(
        self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
        ['kernel'])
    stddev_sparse_init = jnp.std(
        self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
        ['kernel'])

    with self.subTest(name='test_cnn_sparse_init_mean'):
        self.assertBetween(mean_sparse_init, mean_init - 2 * stddev_init,
                           mean_init + 2 * stddev_init)

    with self.subTest(name='test_cnn_sparse_init_stddev'):
        self.assertBetween(stddev_sparse_init, 0.5 * stddev_init,
                           1.5 * stddev_init)
def test_get_proposal_loc_and_scale(dense_mass):
    N = 10
    dim = 3
    samples = random.normal(random.PRNGKey(0), (N, dim))
    loc = np.mean(samples[:-1], 0)
    if dense_mass:
        scale = np.linalg.cholesky(np.cov(samples[:-1], rowvar=False, bias=True))
    else:
        scale = np.std(samples[:-1], 0)
    actual_loc, actual_scale = _get_proposal_loc_and_scale(
        samples[:-1], loc, scale, samples[-1])
    expected_loc, expected_scale = [], []
    for i in range(N - 1):
        samples_i = onp.delete(samples, i, axis=0)
        expected_loc.append(np.mean(samples_i, 0))
        if dense_mass:
            expected_scale.append(
                np.linalg.cholesky(np.cov(samples_i, rowvar=False, bias=True)))
        else:
            expected_scale.append(np.std(samples_i, 0))
    expected_loc = np.stack(expected_loc)
    expected_scale = np.stack(expected_scale)
    assert_allclose(actual_loc, expected_loc, rtol=1e-4)
    assert_allclose(actual_scale, expected_scale, atol=1e-6, rtol=0.05)
def initialize(self, normalization='log_return'):
    """
    Description: Check if data exists, else download, clean, and setup.

    Args:
        normalization (str/None): if None, no data normalization.
            If 'log_return', return log(x_t / x_(t-1)).
            If 'return', return (x_t - x_(t-1)) / x_(t-1).
    Returns:
        The first S&P 500 value
    """
    self.initialized = True
    self.has_regressors = False
    self.normalization = normalization
    if normalization is not None:
        assert normalization in [
            'return', 'log_return'
        ], "normalization must be either None, 'return', or 'log_return'"
    self.T = 0
    df = sp500()  # get data
    self.max_T = df.shape[0]
    data = df['value'].values.tolist()
    if normalization == 'return':
        data = np.array([(data[i + 1] - data[i]) / data[i]
                         for i in range(len(data) - 1)])
        self.std = np.std(data)
        data /= self.std
    elif normalization == 'log_return':
        data = np.array(
            [np.log(data[i + 1] / data[i]) for i in range(len(data) - 1)])
        self.std = np.std(data)
        data /= self.std
    else:
        data = np.array(data)
        self.std = np.std(data)
    self.data = data
    return self.data[self.T]
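# Illustration only (not part of the environment class above): a minimal sketch of the
# 'return' vs 'log_return' normalization from initialize(), applied to a made-up price
# series with plain NumPy. The sp500() loader and the class attributes are not used here.
import numpy as np

prices = np.array([100.0, 101.0, 99.5, 102.0, 103.5])

simple_returns = (prices[1:] - prices[:-1]) / prices[:-1]  # 'return'
log_returns = np.log(prices[1:] / prices[:-1])             # 'log_return'

# Both variants are then scaled to unit standard deviation, as in initialize().
simple_returns = simple_returns / np.std(simple_returns)
log_returns = log_returns / np.std(log_returns)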
def test_mean_var(jax_dist, sp_dist, params):
    n = 20000 if jax_dist in [dist.LKJ, dist.LKJCholesky] else 200000
    d_jax = jax_dist(*params)
    k = random.PRNGKey(0)
    samples = d_jax.sample(k, sample_shape=(n,))
    # check with suitable scipy implementation if available
    if sp_dist and not _is_batched_multivariate(d_jax):
        d_sp = sp_dist(*params)
        try:
            sp_mean = d_sp.mean()
        except TypeError:  # mvn does not have .mean() method
            sp_mean = d_sp.mean
        # for multivariate distns try .cov first
        if d_jax.event_shape:
            try:
                sp_var = np.diag(d_sp.cov())
            except TypeError:  # mvn does not have .cov() method
                sp_var = np.diag(d_sp.cov)
            except AttributeError:
                sp_var = d_sp.var()
        else:
            sp_var = d_sp.var()
        assert_allclose(d_jax.mean, sp_mean, rtol=0.01, atol=1e-7)
        assert_allclose(d_jax.variance, sp_var, rtol=0.01, atol=1e-7)
        if np.all(np.isfinite(sp_mean)):
            assert_allclose(np.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if np.all(np.isfinite(sp_var)):
            assert_allclose(np.std(samples, 0), np.sqrt(d_jax.variance),
                            rtol=0.05, atol=1e-2)
    elif jax_dist in [dist.LKJ, dist.LKJCholesky]:
        if jax_dist is dist.LKJCholesky:
            corr_samples = np.matmul(samples, np.swapaxes(samples, -2, -1))
        else:
            corr_samples = samples
        dimension, concentration, _ = params
        # marginal of off-diagonal entries
        marginal = dist.Beta(concentration + 0.5 * (dimension - 2),
                             concentration + 0.5 * (dimension - 2))
        # scale statistics due to linear mapping
        marginal_mean = 2 * marginal.mean - 1
        marginal_std = 2 * np.sqrt(marginal.variance)
        expected_mean = np.broadcast_to(
            np.reshape(marginal_mean, np.shape(marginal_mean) + (1, 1)),
            np.shape(marginal_mean) + d_jax.event_shape)
        expected_std = np.broadcast_to(
            np.reshape(marginal_std, np.shape(marginal_std) + (1, 1)),
            np.shape(marginal_std) + d_jax.event_shape)
        # diagonal elements of correlation matrices are 1
        expected_mean = expected_mean * (1 - np.identity(dimension)) + np.identity(dimension)
        expected_std = expected_std * (1 - np.identity(dimension))
        assert_allclose(np.mean(corr_samples, axis=0), expected_mean, atol=0.01)
        assert_allclose(np.std(corr_samples, axis=0), expected_std, atol=0.01)
    else:
        if np.all(np.isfinite(d_jax.mean)):
            assert_allclose(np.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if np.all(np.isfinite(d_jax.variance)):
            assert_allclose(np.std(samples, 0), np.sqrt(d_jax.variance),
                            rtol=0.05, atol=1e-2)
def standardize(self, data, time_series=False):
    if time_series:
        self.mean = np.mean(data, axis=1)
        self.stddev = np.std(data, axis=1)
        norm_data = (data - np.expand_dims(self.mean, axis=1)) / np.expand_dims(
            self.stddev, axis=1)
    else:
        self.mean = np.mean(data, axis=0)
        self.stddev = np.std(data, axis=0)
        norm_data = (data - self.mean) / self.stddev
    return norm_data
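# Usage sketch for standardize() above. Assumptions: the function is reachable as a
# plain function, `np` refers to NumPy, and the owning class is unknown here, so a
# SimpleNamespace stands in for `self` (it only needs to accept attribute assignment).
from types import SimpleNamespace
import numpy as np

scaler = SimpleNamespace()
data = np.random.randn(100, 4) * 3.0 + 5.0   # (samples, features)
norm = standardize(scaler, data)             # feature-wise: statistics over axis=0
# scaler.mean / scaler.stddev now hold per-feature statistics, and norm has
# approximately zero mean and unit standard deviation per feature.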
def test_resnet_imagenet(self):
    rng = random.PRNGKey(10)
    key1, key2 = random.split(rng)
    x = random.normal(key1, (128, 32, 32, 3))
    activation_f = 'bias_scale_SELU_norm'
    model_def = wideresnet.ResNetImageNet50.partial(
        num_classes=1000,
        activation_f=activation_f,
        normalization='none',
        std_penalty_mult=0,
        use_residual=2,
        bias_scale=0.0,
        weight_norm='fixed',
        softplus_scale=1,
        compensate_padding=True,
        no_head=True,
    )
    (y, _, metrics), _ = model_def.create(key2, x, train=True)
    mean = jnp.mean(y, axis=(0, 1, 2))
    std = jnp.std(y, axis=(0, 1, 2))
    mean_x = jnp.mean(x, axis=(0, 1, 2))
    std_x = jnp.std(x, axis=(0, 1, 2))
    onp.testing.assert_allclose(mean_x, jnp.zeros_like(mean_x), atol=0.1)
    onp.testing.assert_allclose(std_x, jnp.ones_like(std_x), atol=0.1)
    for metric_key, metric_value in metrics.items():
        if ('postnorm' in metric_key or 'postact' in metric_key or
                'postres' in metric_key):
            if 'std' in metric_key:
                onp.testing.assert_allclose(metric_value,
                                            jnp.ones_like(metric_value),
                                            atol=0.1, err_msg=metric_key)
            elif 'mean' in metric_key:
                onp.testing.assert_allclose(metric_value,
                                            jnp.zeros_like(metric_value),
                                            atol=0.1, err_msg=metric_key)
    onp.testing.assert_allclose(std, jnp.ones_like(std), atol=0.4)
    onp.testing.assert_allclose(mean, jnp.zeros_like(mean), atol=0.6)
def test_resnetv1(self):
    rng = random.PRNGKey(10)
    key1, key2 = random.split(rng)
    x = random.normal(key1, (128, 32, 32, 3))
    activation_f = 'bias_scale_SELU_norm'
    model_def = wideresnet.ResnetV1.partial(
        depth=20,
        num_outputs=10,
        activation_f=activation_f,
        normalization='none',
        dropout_rate=0,
        std_penalty_mult=0,
        use_residual=2,  # TODO(basv): test with residual.
        bias_scale=0.0,
        weight_norm='none',
        no_head=True,
        report_metrics=True,
    )
    (y, _, metrics), _ = model_def.create(key2, x)
    mean = jnp.mean(y, axis=(0, 1, 2))
    std = jnp.std(y, axis=(0, 1, 2))
    mean_x = jnp.mean(x, axis=(0, 1, 2))
    std_x = jnp.std(x, axis=(0, 1, 2))
    onp.testing.assert_allclose(mean_x, jnp.zeros_like(mean_x), atol=0.1)
    onp.testing.assert_allclose(std_x, jnp.ones_like(std_x), atol=0.1)
    for metric_key, metric_value in metrics.items():
        if ('postnorm' in metric_key or 'postact' in metric_key or
                'postres' in metric_key):
            if 'std' in metric_key:
                onp.testing.assert_allclose(metric_value,
                                            jnp.ones_like(metric_value),
                                            atol=0.1, err_msg=metric_key)
            elif 'mean' in metric_key:
                onp.testing.assert_allclose(metric_value,
                                            jnp.zeros_like(metric_value),
                                            atol=0.3, err_msg=metric_key)
    onp.testing.assert_allclose(mean, jnp.zeros_like(mean), atol=0.2)
    onp.testing.assert_allclose(std, jnp.ones_like(std), atol=0.3)
def test_wrn26_4(self):
    rng = random.PRNGKey(10)
    key1, key2 = random.split(rng)
    x = random.normal(key1, (128, 32, 32, 3))
    for activation_f in ['bias_scale_SELU_norm']:
        model_def = wideresnet.WideResnet.partial(
            blocks_per_group=4,
            channel_multiplier=4,
            num_outputs=10,
            activation_f=activation_f,
            normalization='none',
            dropout_rate=0,
            std_penalty_mult=0,
            use_residual=2,  # TODO(basv): test with residual.
            bias_scale=0.0,
            weight_norm='learned',
            no_head=True,
        )
        (y, _, metrics), _ = model_def.create(key2, x)
        mean = jnp.mean(jnp.abs(jnp.mean(y, axis=(0, 1, 2))))
        std = jnp.mean(jnp.std(y, axis=(0, 1, 2)))
        mean_x = jnp.mean(x, axis=(0, 1, 2))
        std_x = jnp.std(x, axis=(0, 1, 2))
        onp.testing.assert_allclose(mean_x, jnp.zeros_like(mean_x), atol=0.1)
        onp.testing.assert_allclose(std_x, jnp.ones_like(std_x), atol=0.1)
        for metric_key, metric_value in metrics.items():
            if ('postnorm' in metric_key or 'postact' in metric_key or
                    'postres' in metric_key):
                if 'std' in metric_key:
                    onp.testing.assert_allclose(metric_value,
                                                jnp.ones_like(metric_value),
                                                atol=0.2, err_msg=metric_key)
                elif 'mean' in metric_key:
                    onp.testing.assert_allclose(metric_value,
                                                jnp.zeros_like(metric_value),
                                                atol=0.2, err_msg=metric_key)
        onp.testing.assert_allclose(mean, jnp.zeros_like(mean), atol=0.1)
        onp.testing.assert_allclose(std, jnp.ones_like(std), atol=0.1)
def eval_log_prob(model, data, energies, F, T):
    # logPM = model(data)
    # logPE = -energies / T - F
    # return jnp.linalg.norm(logPM - logPE) / data.shape[0]
    logPM = model(data) + energies / T
    # logPE = -energies / T - F
    return jnp.std(logPM)
def test_overall_mean_variance(self):
    noise = OrnsteinUhlenbeckNoise(random_seed=13)
    x = jnp.stack([noise(0.) for _ in range(1000)])
    mu, sigma = jnp.mean(x), jnp.std(x)
    self.assertLess(abs(mu), noise.theta)
    self.assertGreater(sigma, noise.sigma)
    self.assertLess(sigma, noise.sigma * 2)
def sample_kernel(sa_state, model_args=(), model_kwargs=None):
    pe_fn = potential_fn
    if potential_fn_gen:
        pe_fn = potential_fn_gen(*model_args, **model_kwargs)
    zs, pes, loc, scale = sa_state.adapt_state
    # we recompute loc/scale after each iteration to avoid precision loss
    # XXX: consider to expose a setting to do this job periodically
    # to save some computations
    loc = jnp.mean(zs, 0)
    if scale.ndim == 2:
        cov = jnp.cov(zs, rowvar=False, bias=True)
        if cov.shape == ():  # JAX returns scalar for 1D input
            cov = cov.reshape((1, 1))
        cholesky = jnp.linalg.cholesky(cov)
        scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
    else:
        scale = jnp.std(zs, 0)

    rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(
        sa_state.rng_key, 4)
    _, unravel_fn = ravel_pytree(sa_state.z)

    z = loc + _sample_proposal(scale, rng_key_z)
    pe = pe_fn(unravel_fn(z))
    pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
    diverging = (pe - sa_state.potential_energy) > max_delta_energy

    # NB: all terms having the pattern *s will have shape N x ...
    # and all terms having the pattern *s_ will have shape (N + 1) x ...
    locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
    zs_ = jnp.concatenate([zs, z[None, :]])
    pes_ = jnp.concatenate([pes, pe[None]])
    locs_ = jnp.concatenate([locs, loc[None, :]])
    scales_ = jnp.concatenate([scales, scale[None, ...]])
    if scale.ndim == 2:  # dense_mass
        log_weights_ = dist.MultivariateNormal(
            locs_, scale_tril=scales_).log_prob(zs_) + pes_
    else:
        log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
    # mask invalid values (nan, +inf) by -inf
    log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_, -jnp.inf)
    # get rejecting index
    j = random.categorical(rng_key_reject, log_weights_)
    zs = _numpy_delete(zs_, j)
    pes = _numpy_delete(pes_, j)
    loc = locs_[j]
    scale = scales_[j]
    adapt_state = SAAdaptState(zs, pes, loc, scale)

    # NB: weights[-1] / sum(weights) is the probability of rejecting
    # the new sample `z`.
    accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
    itr = sa_state.i + 1
    n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
    mean_accept_prob = sa_state.mean_accept_prob + (
        accept_prob - sa_state.mean_accept_prob) / n

    # XXX: we make a modification of SA sampler in [1]
    # in [1], each MCMC state contains N points `zs`
    # here we do resampling to pick randomly a point from those N points
    k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
    z = unravel_fn(zs[k])
    pe = pes[k]
    return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging,
                   adapt_state, rng_key)
def NES_profile_jax(params, params_to_xL, score_function, npop=50,
                    sigma_noise=0.1, alpha=0.05):
    """Natural Evolution Strategies update step.

    Args:
        params: current parameter vector of shape (num_params,)
        params_to_xL: maps a parameter vector to the quantity scored by
            score_function
        score_function: returns a scalar score given xL
        npop: population size
        sigma_noise: standard deviation of the parameter perturbations
        alpha: learning rate
    """

    def single_update(pi, ni):
        p_new = pi + sigma_noise * ni
        xL_new = params_to_xL(p_new)
        reward_new = score_function(xL=xL_new)
        return reward_new

    num_params = params.shape[0]
    xL = params_to_xL(params)
    N = np.array(onp.random.randn(npop, num_params))
    R = vmap(single_update, (None, 0), 0)(params, N)
    # standardize the population scores before estimating the search gradient
    A = (R - np.mean(R)) / (np.std(R) + 1e-6)
    params_update = params - alpha / (npop * sigma_noise) * np.dot(N.T, A)
    return params_update
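# Toy usage sketch for NES_profile_jax above. Assumptions: the same aliases as inside
# the function (`np` is jax.numpy, `onp` is NumPy, `vmap` is jax.vmap), and a made-up
# target/score for illustration. With the minus sign in the update, the step descends
# on the score, so the score here is written as a quantity to minimize.
import jax.numpy as np
import numpy as onp
from jax import vmap

target = np.array([1.0, -2.0, 0.5])

def params_to_xL(p):
    return p  # identity mapping in this toy problem

def score_function(xL):
    return np.sum((xL - target) ** 2)  # squared error to be minimized

params = np.zeros(3)
for _ in range(20):
    params = NES_profile_jax(params, params_to_xL, score_function)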
def main(args):
    _, fetch_train = load_dataset(UCBADMIT, split="train", shuffle=False)
    dept, male, applications, admit = fetch_train()
    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)
    pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male,
                                      applications)["probs"]
    header = "=" * 30 + "glmm - TRAIN" + "=" * 30
    print_results(header, pred_probs, dept, male, admit / applications)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    ax.plot(range(1, 13), admit / applications, "o", ms=7, label="actual rate")
    ax.errorbar(
        range(1, 13),
        jnp.mean(pred_probs, 0),
        jnp.std(pred_probs, 0),
        fmt="o",
        c="k",
        mfc="none",
        ms=7,
        elinewidth=1,
        label=r"mean $\pm$ std",
    )
    ax.plot(range(1, 13), jnp.percentile(pred_probs, 5, 0), "k+")
    ax.plot(range(1, 13), jnp.percentile(pred_probs, 95, 0), "k+")
    ax.set(
        xlabel="cases",
        ylabel="admit rate",
        title="Posterior Predictive Check with 90% CI",
    )
    ax.legend()

    plt.savefig("ucbadmit_plot.pdf")
def get_norm(init_x):
    mean = jnp.mean(init_x, axis=0)
    std = jnp.std(init_x, axis=0)

    def norm(x):
        return (x - mean) / (std + 1e-5)

    return norm
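# Usage sketch for get_norm() above: fit the normalizer on a reference batch, then apply
# the returned closure to new inputs. Assumes jax.numpy as jnp and jax.random are available
# and that get_norm is in scope; the shapes below are illustrative.
import jax.numpy as jnp
from jax import random

init_x = random.normal(random.PRNGKey(0), (1024, 8)) * 4.0 + 2.0
norm = get_norm(init_x)

new_x = random.normal(random.PRNGKey(1), (32, 8)) * 4.0 + 2.0
normalized = norm(new_x)  # approximately zero mean, unit std per feature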
def test_unnormalized_normal_chain(kernel, kwargs, num_chains):
    from numpyro.contrib.tfp import mcmc

    # TODO: remove when this issue is fixed upstream
    # https://github.com/tensorflow/probability/pull/1087
    if num_chains == 2 and kernel == "ReplicaExchangeMC":
        pytest.xfail(
            "ReplicaExchangeMC is not fully compatible with omnistaging yet.")

    kernel_class = getattr(mcmc, kernel)

    true_mean, true_std = 1., 0.5
    warmup_steps, num_samples = (1000, 8000)

    def potential_fn(z):
        return 0.5 * ((z - true_mean) / true_std)**2

    init_params = jnp.array(0.) if num_chains == 1 else jnp.array([0., 2.])
    tfp_kernel = kernel_class(potential_fn=potential_fn, **kwargs)
    mcmc = MCMC(tfp_kernel, warmup_steps, num_samples, num_chains=num_chains,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)
def get_toy_pairs(N=20, S=2, P=10, sigma_obs=0.05,
                  active_pairs=[(0, 1), (1, 2)]):
    assert S < P and P > 1 and S > 0
    onp.random.seed(0)

    X = onp.random.randn(N, P)
    # generate S coefficients with non-negligible magnitude
    W = 0.5 + 2.5 * onp.random.rand(S)
    # generate data using the S coefficients and however many pairwise interactions
    Y = onp.sum(X[:, 0:S] * W, axis=-1)
    # now add in all pairwise interactions
    for pair in active_pairs:
        Y += X[:, pair[0]] * X[:, pair[1]]
    Y += sigma_obs * onp.random.randn(N)
    Y -= np.mean(Y)
    Y_std = np.std(Y)

    assert X.shape == (N, P)
    assert Y.shape == (N,)

    return X, Y / Y_std, W / Y_std, 1.0 / Y_std
def sinkhorn_for_sort(inputs: jnp.ndarray,
                      weights: jnp.ndarray,
                      target_weights: jnp.ndarray,
                      sinkhorn_kw,
                      pointcloud_kw) -> jnp.ndarray:
    """Runs sinkhorn on a fixed increasing target.

    Args:
        inputs: jnp.ndarray[num_points]. Must be one dimensional.
        weights: jnp.ndarray[num_points]. The weights 'a' for the inputs.
        target_weights: jnp.ndarray[num_targets]: the weights of the targets. It
            may be of a different size than the weights.
        sinkhorn_kw: a dictionary holding the sinkhorn keyword arguments. See
            sinkhorn.py for more details.
        pointcloud_kw: a dictionary holding the keyword arguments of the
            PointCloud class. See pointcloud.py for more details.

    Returns:
        A jnp.ndarray<float> representing the transport matrix of the inputs onto
        the underlying sorted target.
    """
    shape = inputs.shape
    if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
        raise ValueError(
            f"Shape ({shape}) not supported. The input should be one-dimensional.")

    x = jnp.expand_dims(jnp.squeeze(inputs), axis=1)
    x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
    a = jnp.squeeze(weights)
    b = jnp.squeeze(target_weights)
    num_targets = b.shape[0]
    y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
    geom = pointcloud.PointCloud(x, y, **pointcloud_kw)
    res = sinkhorn.sinkhorn(geom, a, b, **sinkhorn_kw)
    return geom.transport_from_potentials(res.f, res.g)
def ksd_squared_l(samples, logp, k, return_stddev=False):
    r"""
    O(n) time estimator for the KSD.

    Arguments:
        * samples: np.array of shape (n, d)
        * logp: callable
        * k: callable, computes scalar-valued kernel k(x, y) given two input
          arguments of shape (d,).

    Returns:
        * The square of the stein discrepancy KSD(q, p). KSD is approximated as
          $\sum_i g(x_i, y_i)$, where the x and y are iid distributed as q
        * The approximate variance of h(X, Y)
    """
    try:
        xs, ys = samples.split(2)
    except ValueError:  # uneven split
        xs, ys = samples[:-1].split(2)

    def h(x, y):
        """x, y: np.arrays of shape (d,)"""
        def inner(x):
            return stein_operator(lambda y_: k(x, y_), y, logp)
        return stein_operator(inner, x, logp, transposed=True)

    outs = vmap(h)(xs, ys)
    if return_stddev:
        return np.mean(outs), np.std(outs, ddof=1) / xs.shape[0]
    else:
        return np.mean(outs)
def log_s_shift_init(shape, dtype):
    # Data-dependent initializer closure: relies on `x`, `self.f`, `weight_logits`,
    # `means` and `log_scales` from the enclosing scope.
    if x.ndim == len(shape):
        return jnp.zeros(shape, dtype)
    z = self.f(weight_logits, means, log_scales, x)
    axes = tuple(jnp.arange(len(z.shape) - len(shape)))
    return jnp.log(jnp.std(z, axis=axes) + 1e-5)
def test_weight_norm_standard(self):
    rng = random.PRNGKey(5)
    key1, key2 = random.split(rng)
    for k in [3, 5]:
        for padding in ['VALID', 'SAME']:
            for layer in [
                conv_layers.Conv, conv_layers.ConvWS, conv_layers.ConvFixedScale
            ]:
                x = random.normal(key1, (512, 32, 32, 128))
                y = x
                for i in range(5):
                    y, _ = layer.create(
                        key2,
                        y,
                        features=128,
                        kernel_size=(k, k),
                        bias=False,
                        padding=padding,
                        kernel_init=jax.nn.initializers.lecun_normal())
                    mean = jnp.mean(y)
                    std = jnp.std(y)
                    err_msg = 'layer %s, padding %s, kernel_size %d, depth %d' % (
                        layer.__name__, padding, k, i)
                    onp.testing.assert_allclose(
                        mean, jnp.zeros_like(mean), atol=0.1, err_msg=err_msg)
                    onp.testing.assert_allclose(
                        std, jnp.ones_like(std), atol=0.1, err_msg=err_msg)
def model_update_minibatch(
    carry: Tuple[networks_lib.Params, optax.OptState],
    minibatch: Batch,
) -> Tuple[Tuple[networks_lib.Params, optax.OptState], Dict[str, jnp.ndarray]]:
    """Performs model update for a single minibatch."""
    params, opt_state = carry
    # Normalize advantages at the minibatch level before using them.
    advantages = (
        (minibatch.advantages - jnp.mean(minibatch.advantages, axis=0)) /
        (jnp.std(minibatch.advantages, axis=0) + 1e-8))
    gradients, metrics = grad_fn(params, minibatch.observations,
                                 minibatch.actions,
                                 minibatch.behavior_log_probs,
                                 minibatch.target_values, advantages,
                                 minibatch.behavior_values)

    # Apply updates
    updates, opt_state = optimizer.update(gradients, opt_state)
    params = optax.apply_updates(params, updates)

    metrics['norm_grad'] = optax.global_norm(gradients)
    metrics['norm_updates'] = optax.global_norm(updates)
    return (params, opt_state), metrics
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    true_mean, true_std = 1.0, 0.5
    num_warmup, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std)**2)

    init_params = jnp.array(0.0)
    if kernel_cls is SA:
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8,
                            dense_mass=dense_mass)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)

    if "JAX_ENABLE_X64" in os.environ:
        assert hmc_states.dtype == jnp.float64
def test_klee_measure():
    from jax import random, jit, disable_jit
    import pylab as plt

    N, D = 2, 2
    points = random.uniform(random.PRNGKey(1), shape=(N, D))
    # points = jnp.array([[0., 1.], [0., 0.]])
    eps = 0.1
    gamma = 0.90
    for w in jnp.linspace(0., 1., 10):
        true_volume = 2. * w**2 - cubes_intersect_volume(points[0, :],
                                                         points[1, :], w)
        vol = jnp.exp(
            jit(
                vmap(lambda key: log_klee_measure(
                    key, points, w, eps=eps, gamma=gamma)))(random.split(
                        random.PRNGKey(0), 100)))
        print(jnp.mean(vol), jnp.std(vol), true_volume)
        eps_bound = jnp.mean((vol <= true_volume * (1. + eps))
                             & (vol >= true_volume * (1. - eps)))
        l = w / 2.
        plt.scatter(points[:, 0], points[:, 1])
        for i in range(N):
            plt.plot([
                points[i, 0] - l, points[i, 0] + l, points[i, 0] + l,
                points[i, 0] - l, points[i, 0] - l
            ], [
                points[i, 1] - l, points[i, 1] - l, points[i, 1] + l,
                points[i, 1] + l, points[i, 1] - l
            ], c='black')
        plt.title("prob_bound {} true_volume {}".format(eps_bound, true_volume))
        plt.show()
def get_opd(self, wave):
    """
    Parameters
    ----------
    wave : morphine.Wavefront (or float)
        Incoming Wavefront before this optic to set wavelength and scale,
        or a float giving the wavelength in meters for a temporary Wavefront
        used to compute the OPD.
    """
    y, x = self.get_coordinates(wave)
    rho, theta = _wave_y_x_to_rho_theta(y, x, self.radius)

    psd = np.power(rho, -self.index)  # generate power-law PSD

    # if provided, set a seed for random number generator
    np.random.seed(self.seed)
    # generate random phase screen
    rndm_phase = np.random.normal(size=(len(y), len(x)))
    # FT of random phase screen to get random PSD
    rndm_psd = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(rndm_phase)))
    # scale random PSD by power-law PSD
    scaled = np.sqrt(psd) * rndm_psd
    # FT of scaled random PSD makes phase screen
    phase_screen = np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(scaled))).real

    phase_screen -= np.mean(phase_screen)  # force zero-mean
    # normalize to wanted input rms wfe
    opd = phase_screen / np.std(phase_screen) * self.wfe
    return opd
def transport_for_sort(inputs: jnp.ndarray,
                       weights: jnp.ndarray,
                       target_weights: jnp.ndarray,
                       kwargs) -> jnp.ndarray:
    """Runs sinkhorn on a fixed increasing target.

    Args:
        inputs: jnp.ndarray[num_points]. Must be one dimensional.
        weights: jnp.ndarray[num_points]. The weights 'a' for the inputs.
        target_weights: jnp.ndarray[num_targets]: the weights of the targets. It
            may be of a different size than the weights.
        kwargs: a dictionary holding the sinkhorn keyword arguments and the
            pointcloud argument.

    Returns:
        A jnp.ndarray<float> representing the transport matrix of the inputs onto
        the underlying sorted target.
    """
    shape = inputs.shape
    if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
        raise ValueError(
            f'Shape ({shape}) not supported. The input should be one-dimensional.')

    x = jnp.expand_dims(jnp.squeeze(inputs), axis=1)
    x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
    a = jnp.squeeze(weights)
    b = jnp.squeeze(target_weights)
    num_targets = b.shape[0]
    y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
    return transport.Transport(x, y, a=a, b=b, **kwargs)
def ppo_loss_given_predictions(log_probab_actions_new,
                               log_probab_actions_old,
                               value_predictions_old,
                               padded_actions,
                               padded_rewards,
                               reward_mask,
                               gamma=0.99,
                               lambda_=0.95,
                               epsilon=0.2):
    """PPO objective, with an eventual minus sign, given predictions."""
    B, T = padded_rewards.shape  # pylint: disable=invalid-name
    _, _, C, A = log_probab_actions_old.shape  # pylint: disable=invalid-name

    assert (B, T) == padded_rewards.shape
    assert (B, T, C) == padded_actions.shape
    assert (B, T) == reward_mask.shape

    assert (B, T + 1, 1) == value_predictions_old.shape
    assert (B, T + 1, C, A) == log_probab_actions_old.shape
    assert (B, T + 1, C, A) == log_probab_actions_new.shape

    # (B, T)
    td_deltas = deltas(
        np.squeeze(value_predictions_old, axis=2),  # (B, T+1)
        padded_rewards,
        reward_mask,
        gamma=gamma)

    # (B, T)
    advantages = gae_advantages(
        td_deltas, reward_mask, lambda_=lambda_, gamma=gamma)

    # Normalize the advantages.
    advantage_mean = np.mean(advantages)
    advantage_std = np.std(advantages)
    advantages = (advantages - advantage_mean) / (advantage_std + 1e-8)

    # (B, T)
    ratios = compute_probab_ratios(log_probab_actions_new,
                                   log_probab_actions_old,
                                   padded_actions, reward_mask)
    assert (B, T, C) == ratios.shape

    # (B, T)
    objective = clipped_objective(ratios, advantages, reward_mask,
                                  epsilon=epsilon)
    assert (B, T, C) == objective.shape

    # ()
    average_objective = np.sum(objective) / np.sum(reward_mask)

    # Loss is negative objective.
    ppo_loss = -average_objective

    summaries = {
        "ppo_loss": ppo_loss,
        "advantage_mean": advantage_mean,
        "advantage_std": advantage_std,
    }

    return (ppo_loss, summaries)
def render_rays_fine(rays, z_vals, weights, num_importance, perturbation=True,
                     rng=None):
    """Render rays for the fine model.

    Args:
        rays: (2, num_rays, 3) origin and direction generated rays
        z_vals: (num_rays, num_samples) depths of the sampled positions
        weights: (num_rays, num_samples) weights assigned to each sampled color
            for the coarse model
        num_importance: number of samples used in the fine model
        perturbation: whether to apply jitter on each ray or not
        rng: random key

    Returns:
        pts: (num_rays, num_samples + num_importance, 3) points in space to
            evaluate the model at
        z_vals: (num_rays, num_samples + num_importance) depths of the sampled
            positions
        z_samples_std: (num_rays,) standard deviation of the distances along
            each ray for the newly drawn importance samples
    """
    rays_o, rays_d = rays
    z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
    z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], num_importance,
                           perturbation, rng)
    z_samples = lax.stop_gradient(z_samples)

    # obtain all points to evaluate color density at
    z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
    z_vals = z_vals.astype(rays_d.dtype)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    return pts, z_vals, jnp.std(z_samples, axis=-1)
def test_nested_normalize(self):
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    })
    x1 = {
        'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5),
        'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2)
    }
    x2 = {
        'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5) + 20,
        'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) + 8
    }
    x3 = {
        'a': jnp.arange(40, dtype=jnp.float32).reshape(4, 2, 5),
        'b': jnp.arange(16, dtype=jnp.float32).reshape(4, 2, 2)
    }

    state = update_and_validate(state, x1)
    state = update_and_validate(state, x2)
    state = update_and_validate(state, x3)
    normalized = running_statistics.normalize(x3, state)

    mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized)
    std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized)
    tree.map_structure(lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean)
    tree.map_structure(lambda x: self.assert_allclose(x, jnp.ones_like(x)), std)
def test_pmap_update_nested(self):
    local_device_count = jax.local_device_count()
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    })
    x = {
        'a': (jnp.arange(15 * local_device_count,
                         dtype=jnp.float32)).reshape(local_device_count, 3, 5),
        'b': (jnp.arange(6 * local_device_count,
                         dtype=jnp.float32)).reshape(local_device_count, 3, 2),
    }

    devices = jax.local_devices()
    state = jax.device_put_replicated(state, devices)
    pmap_axis_name = 'i'
    state = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
        pmap_axis_name)(state, x)
    state = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
        pmap_axis_name)(state, x)
    normalized = jax.pmap(running_statistics.normalize)(x, state)

    mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized)
    std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized)
    tree.map_structure(lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean)
    tree.map_structure(lambda x: self.assert_allclose(x, jnp.ones_like(x)), std)
def __call__(self, x, sigmas, train=True):
    # per image standardization
    N = np.prod(x.shape[1:])
    x = (x - jnp.mean(x, axis=(1, 2, 3), keepdims=True)) / jnp.maximum(
        jnp.std(x, axis=(1, 2, 3), keepdims=True), 1. / np.sqrt(N))

    temb = GaussianFourierProjection(embedding_size=128, scale=16)(jnp.log(sigmas))
    temb = nn.Dense(128 * 4)(temb)
    temb = nn.Dense(128 * 4)(nn.swish(temb))

    x = nn.Conv(16, (3, 3), padding='SAME', name='init_conv',
                kernel_init=conv_kernel_init_fn, use_bias=False)(x)
    x = WideResnetGroup(self.blocks_per_group, 16 * self.channel_multiplier,
                        activate_before_residual=True)(x, temb, train)
    x = WideResnetGroup(self.blocks_per_group, 32 * self.channel_multiplier,
                        (2, 2))(x, temb, train)
    x = WideResnetGroup(self.blocks_per_group, 64 * self.channel_multiplier,
                        (2, 2))(x, temb, train)
    x = activation(x, train=train, name='pre-pool-bn')
    x = nn.avg_pool(x, x.shape[1:3])
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(self.num_outputs, kernel_init=dense_layer_init_fn)(x)
    return x