def compute_loss(predicted_positions, predicted_momentums, target_positions,
                 target_momentums, auxiliary_predictions=None,
                 regularizations=None):
  """Computes the loss for the given predictions."""
  assert predicted_positions.shape == target_positions.shape, (
      f'Got predicted_positions: {predicted_positions.shape}, '
      f'target_positions: {target_positions.shape}')
  assert predicted_momentums.shape == target_momentums.shape, (
      f'Got predicted_momentums: {predicted_momentums.shape}, '
      f'target_momentums: {target_momentums.shape}')

  loss = optax.l2_loss(predictions=predicted_positions,
                       targets=target_positions)
  loss += optax.l2_loss(predictions=predicted_momentums,
                        targets=target_momentums)
  loss = jnp.mean(loss)

  if auxiliary_predictions is not None:
    angular_velocities = auxiliary_predictions['angular_velocities']
    angular_velocities_variances = jnp.var(angular_velocities, axis=0).sum()
    loss += regularizations['angular_velocities'] * angular_velocities_variances

    actions = auxiliary_predictions['actions']
    actions_variances = jnp.var(actions, axis=0).sum()
    loss += regularizations['actions'] * actions_variances

  return loss
def loss_reg(theta, X, Y, W, data_weights, reg_scale):
  """Weighted chi-square loss with a ridge (L2) penalty on the coefficients."""
  # Residuals of the linear fits.
  l_chiSq = jnp.einsum('jia,jka->jki', X, theta) - Y
  chiSq = jnp.mean(l_chiSq * W * l_chiSq, axis=-1)
  # Ridge penalty on theta.
  reg_overlap = jnp.sum(theta**2, axis=-1)
  # Scale the penalty by the inverse variance of the targets.
  reg_scale = jnp.expand_dims(
      jnp.expand_dims(reg_scale / jnp.var(Y, axis=-1), -1), -1)
  # Note: the data_weights argument is overridden by the target variances.
  data_weights = jnp.var(Y, axis=-1)
  loss = jnp.sum(data_weights * (chiSq + reg_scale * reg_overlap)) \
      / (chiSq.shape[1] * jnp.sum(data_weights))
  return loss
def testVarianceScaling(self, map_in, map_out, fan, distr):
  shape = (80, 50, 7)
  fan_in, fan_out = jax._src.nn.initializers._compute_fans(
      NamedShape(*shape), 0, 1)
  key = jax.random.PRNGKey(0)
  base_scaling = partial(jax.nn.initializers.variance_scaling, 100, fan, distr)
  ref_sampler = lambda: base_scaling(in_axis=0, out_axis=1)(key, shape)
  if map_in and map_out:
    out_axes = ['i', 'o', ...]
    named_shape = NamedShape(shape[2], i=shape[0], o=shape[1])
    xmap_sampler = lambda: base_scaling(in_axis='i', out_axis='o')(
        key, named_shape)
  elif map_in:
    out_axes = ['i', ...]
    named_shape = NamedShape(shape[1], shape[2], i=shape[0])
    xmap_sampler = lambda: base_scaling(in_axis='i', out_axis=0)(
        key, named_shape)
  elif map_out:
    out_axes = [None, 'o', ...]
    named_shape = NamedShape(shape[0], shape[2], o=shape[1])
    xmap_sampler = lambda: base_scaling(in_axis=0, out_axis='o')(
        key, named_shape)
  mapped_sampler = xmap(xmap_sampler, in_axes=(), out_axes=out_axes,
                        axis_sizes={'i': shape[0], 'o': shape[1]})
  self.assertAllClose(jnp.var(mapped_sampler()), jnp.var(ref_sampler()),
                      atol=1e-4, rtol=2e-2)
def test_segment_variance(self):
  result = utils.segment_variance(
      jnp.arange(8), jnp.array([0, 0, 0, 1, 1, 2, 2, 2]), 3)
  self.assertAllClose(
      result,
      jnp.stack([
          jnp.var(jnp.arange(3)),
          jnp.var(jnp.arange(3, 5)),
          jnp.var(jnp.arange(5, 8))
      ]))
def correlate(data, fitted):
  """Per-image Pearson correlation. Designed for images of shape (n_imgs, pxs)."""
  x, y = check_data(data, fitted)
  Ex = jnp.mean(x, axis=1, keepdims=True)
  Ey = jnp.mean(y, axis=1, keepdims=True)
  cov = jnp.mean((x - Ex) * (y - Ey), axis=1)
  std_x = jnp.sqrt(jnp.var(x, axis=1))
  std_y = jnp.sqrt(jnp.var(y, axis=1))
  return cov / (std_x * std_y)
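# Hedged usage sketch for `correlate` (illustrative only: assumes `check_data`
# simply validates and returns the two (n_imgs, pxs) arrays; the random data
# below is made up).
def _correlate_usage_example():
  import jax
  key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
  data = jax.random.normal(key_a, (4, 1024))                  # 4 images, 1024 pixels
  fitted = data + 0.1 * jax.random.normal(key_b, (4, 1024))   # noisy reconstruction
  return correlate(data, fitted)  # shape (4,): per-image Pearson correlation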
def test_submission_from_samples_linear(metaculus_questions,
                                        logistic_mixture_samples):
    normalized_mixture = metaculus_questions.continuous_linear_open_question.get_submission_from_samples(
        logistic_mixture_samples)
    normalized_mixture_samples = [normalized_mixture.sample() for _ in range(5000)]
    mixture_samples = metaculus_questions.continuous_linear_open_question.denormalize_samples(
        normalized_mixture_samples)
    assert float(np.mean(logistic_mixture_samples)) == pytest.approx(
        float(np.mean(mixture_samples)), rel=0.1)
    assert float(np.var(logistic_mixture_samples)) == pytest.approx(
        float(np.var(mixture_samples)), rel=0.2)
def test_discrete_gibbs_gmm_1d(modified):
    def model(probs, locs):
        c = numpyro.sample("c", dist.Categorical(probs))
        numpyro.sample("x", dist.Normal(locs[c], 0.5))

    probs = jnp.array([0.15, 0.3, 0.3, 0.25])
    locs = jnp.array([-2, 0, 2, 4])
    kernel = DiscreteHMCGibbs(NUTS(model), modified=modified)
    mcmc = MCMC(kernel, 1000, 200000, progress_bar=False)
    mcmc.run(random.PRNGKey(0), probs, locs)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["x"]), 1.3, atol=0.1)
    assert_allclose(jnp.var(samples["x"]), 4.36, atol=0.1)
    assert_allclose(jnp.mean(samples["c"]), 1.65, atol=0.1)
    assert_allclose(jnp.var(samples["c"]), 1.03, atol=0.1)
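# Where the reference moments above come from (worked arithmetic, for clarity):
#   E[x]   = sum_i p_i * loc_i = 0.15*(-2) + 0.3*0 + 0.3*2 + 0.25*4  = 1.3
#   Var[x] = 0.5**2 + sum_i p_i * loc_i**2 - E[x]**2
#          = 0.25 + (0.6 + 0.0 + 1.2 + 4.0) - 1.69                   = 4.36
#   E[c]   = 0.15*0 + 0.3*1 + 0.3*2 + 0.25*3                         = 1.65
#   Var[c] = (0.3 + 1.2 + 2.25) - 1.65**2                            = 1.0275 ~= 1.03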
def visualize_normals(depth, acc, scaling=None):
  """Visualize fake normals of `depth` (optionally scaled to be isotropic)."""
  if scaling is None:
    mask = ~jnp.isnan(depth)
    x, y = jnp.meshgrid(
        jnp.arange(depth.shape[1]), jnp.arange(depth.shape[0]), indexing='xy')
    xy_var = (jnp.var(x[mask]) + jnp.var(y[mask])) / 2
    z_var = jnp.var(depth[mask])
    scaling = jnp.sqrt(xy_var / z_var)

  scaled_depth = scaling * depth
  normals = depth_to_normals(scaled_depth)
  return matte(
      jnp.isnan(normals) + jnp.nan_to_num((normals + 1) / 2, 0), acc)
def test_mean_and_var_mid():
    t0 = 0.0
    t1 = 3.0
    y0 = np.linspace(0.1, 0.9, D)
    num_samples = 500
    vals = onp.zeros((num_samples, D))
    for i in range(num_samples):
        rng = random.PRNGKey(i)
        bm = make_brownian_motion(t0, np.zeros(y0.shape), t1, rng)
        vals[i, :] = bm(t1 / 2.0)
    print(np.mean(vals), np.var(vals))
    # For standard Brownian motion, Var[B(t)] = t, so at the midpoint the
    # variance should be t1 / 2.
    assert np.allclose(np.mean(vals), 0.0, atol=1e-1, rtol=1e-1)
    assert np.allclose(np.var(vals), t1 / 2.0, atol=1e-1, rtol=1e-1)
def update(chain_state, _, rhat_state: GelmanRubinState) -> GelmanRubinState:
    """Update the rhat estimates.

    Parameters
    ----------
    chain_state: HMCState
        The chain state.
    rhat_state: GelmanRubinState
        The GelmanRubinState from the previous draw.

    Returns
    -------
    An updated GelmanRubinState object.
    """
    within_state, *_ = rhat_state

    positions = chain_state.position
    within_state = w_update(within_state, positions)
    covariance, step, mean = w_covariance(within_state)
    within_var = jnp.mean(covariance, axis=0)
    between_var = jnp.var(mean, axis=0, ddof=1)
    estimator = ((step - 1) / step) * within_var + between_var
    rhat = jnp.sqrt(estimator / within_var)
    worst_rhat = rhat[jnp.argmax(jnp.abs(rhat - 1.0))]

    return GelmanRubinState(within_state, rhat, worst_rhat)
def variance_loss_objective(encoder_params, decoder_params, batch, prng_key,
                            num_samples=1):
    """Computes the variance loss objective of a discrete VAE.

    :param encoder_params: encoder parameters (list)
    :param decoder_params: decoder parameters (list)
    :param batch: batch of data (jax.numpy array)
    :param prng_key: PRNG key
    :param num_samples: number of samples
    """
    posterior_params = encoder(encoder_params, batch)  # BxD
    posterior_samples = lax.stop_gradient(
        bernoulli.sample(posterior_params, prng_key, num_samples))  # SxBxD
    log_prior = np.sum(
        vmap(bernoulli.logpmf, in_axes=(0, None))(
            posterior_samples,
            0.5 * np.ones((batch.shape[0], posterior_params.shape[-1]))),
        axis=(1, 2))
    log_posterior = np.sum(
        vmap(bernoulli.logpmf, in_axes=(0, None))(posterior_samples,
                                                  posterior_params),
        axis=(1, 2))
    log_likelihood = vmap(bernoulli_log_likelihood, in_axes=(None, 0, None))(
        decoder_params, posterior_samples, batch)
    elbo_samples = log_likelihood - log_posterior + log_prior
    return np.var(elbo_samples, axis=0, ddof=1) / batch.shape[0]
def predict_f(key, x, x_cov, full_covariance=False):
    # create distribution
    x_dist = z(x, x_cov)

    # sample
    x_mc_samples = x_dist.sample((n_samples,), key)

    # function predictions over mc samples
    # (N,M,P) = f(N,D,M)
    y_mu_mc = jax.vmap(gp_pred.predict_mean, in_axes=0, out_axes=1)(x_mc_samples)

    # mean of mc samples
    # (N,P,) = (N,M,P)
    y_mu = jnp.mean(y_mu_mc, axis=1)

    if full_covariance:
        # ===================
        # Covariance
        # ===================
        # (N,P,M) - (N,P,1) -> (N,P,M)
        dfydx = y_mu_mc - y_mu[..., None]
        # (N,M,P) @ (M,M) @ (N,M,P) -> (N,P,D)
        cov = wc * jnp.einsum("ijk,lmn->ikl", dfydx, dfydx.T)
        return y_mu, cov
    else:
        # (N,P) = (N,M,P)
        var = jnp.var(y_mu_mc, axis=1)
        return y_mu, var
def test_estimate_likelihood(kernel_cls):
    data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = 0.1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (10_000,)
    )
    n, _ = data.shape
    num_warmup = 200
    num_samples = 200
    num_blocks = 20

    def model(data):
        mean = numpyro.sample(
            "mean", dist.Normal(ref_params, jnp.ones_like(ref_params))
        )
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-2) as idx:
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=data[idx])

    proxy_fn = HMCECS.taylor_proxy({"mean": ref_params})
    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

    mcmc.run(random.PRNGKey(0), data, extra_fields=["hmc_state.potential_energy"])

    pes = mcmc.get_extra_fields()["hmc_state.potential_energy"]
    samples = mcmc.get_samples()
    pes_full = vmap(
        lambda sample: log_density(
            model, (data,), {}, {**sample, **{"N": jnp.arange(n)}}
        )[0]
    )(samples)

    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.0
def test_dense_mass(kernel_cls, rho):
    warmup_steps, num_samples = 20000, 10000

    true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

    def model():
        numpyro.sample(
            "x", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov)
        )

    if kernel_cls is HMC or kernel_cls is NUTS:
        kernel = kernel_cls(model, trajectory_length=2.0, dense_mass=True)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model, dense_mass=True)

    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(0))

    mass_matrix_sqrt = mcmc.last_state.adapt_state.mass_matrix_sqrt
    if kernel_cls is HMC or kernel_cls is NUTS:
        mass_matrix_sqrt = mass_matrix_sqrt[("x",)]
    mass_matrix = jnp.matmul(mass_matrix_sqrt, jnp.transpose(mass_matrix_sqrt))
    estimated_cov = jnp.linalg.inv(mass_matrix)
    assert_allclose(estimated_cov, true_cov, rtol=0.10)

    samples = mcmc.get_samples()["x"]
    assert_allclose(jnp.mean(samples[:, 0]), jnp.array(0.0), atol=0.50)
    assert_allclose(jnp.mean(samples[:, 1]), jnp.array(0.0), atol=0.05)
    assert_allclose(jnp.mean(samples[:, 0] * samples[:, 1]), jnp.array(rho), atol=0.20)
    assert_allclose(jnp.var(samples, axis=0), jnp.array([10.0, 0.1]), rtol=0.20)
def test_expected_sin(self):
  normal_samples = random.normal(random.PRNGKey(0), (10000,))
  for mu, var in [(0, 1), (1, 3), (-2, .2), (10, 10)]:
    sin_mu, sin_var = mip.expected_sin(mu, var)
    x = jnp.sin(jnp.sqrt(var) * normal_samples + mu)
    self.assertAllClose(sin_mu, jnp.mean(x), atol=1e-2)
    self.assertAllClose(sin_var, jnp.var(x), atol=1e-2)
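# Closed form being tested above (a sketch of the expected result, not
# necessarily mip.expected_sin's exact implementation): for X ~ N(mu, var),
#   E[sin X]   = exp(-var/2) * sin(mu)
#   Var[sin X] = 0.5 * (1 - exp(-2*var) * cos(2*mu)) - exp(-var) * sin(mu)**2
def _expected_sin_closed_form(mu, var):
  import jax.numpy as jnp
  sin_mu = jnp.exp(-0.5 * var) * jnp.sin(mu)
  sin_var = (0.5 * (1 - jnp.exp(-2 * var) * jnp.cos(2 * mu))
             - jnp.exp(-var) * jnp.sin(mu)**2)
  return sin_mu, sin_var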
def estimator(likelihoods, params, gibbs_state):
    subsample_log_liks = defaultdict(float)
    for (fn, value, name, subsample_dim) in likelihoods.values():
        subsample_log_liks[name] += _sum_all_except_at_dim(
            fn.log_prob(value), subsample_dim
        )

    log_lik_sum = 0.0

    proxy_value_all, proxy_value_subsample = proxy_fn(
        params, subsample_log_liks.keys(), gibbs_state
    )

    for (
        name,
        subsample_log_lik,
    ) in subsample_log_liks.items():  # loop over all subsample sites
        n, m = subsample_plate_sizes[name]

        # Difference between the exact and proxy log likelihoods on the subsample.
        diff = subsample_log_lik - proxy_value_subsample[name]

        # Unbiased estimate of the full-data log likelihood, corrected by half
        # of the estimator's variance.
        unbiased_log_lik = proxy_value_all[name] + n * jnp.mean(diff)
        variance = n ** 2 / m * jnp.var(diff)
        log_lik_sum += unbiased_log_lik - 0.5 * variance
    return log_lik_sum
def apply(params, inputs):
    # Standardize the inputs, then scale and shift with the last two parameters.
    inputs = params[-2] / np.sqrt(np.var(inputs, axis=0)) * (
        inputs - np.mean(inputs, axis=0)) + params[-1]
    for i in range(depth):
        # Residual MLP block.
        inputs = mlp(params[:-2], inputs) + inputs
    return inputs
def test_submission_from_samples_smooth(metaculus_questions,
                                        smooth_logistic_mixture):
    samples = np.array([smooth_logistic_mixture.sample() for _ in range(5000)])
    fit_mixture = metaculus_questions.continuous_linear_open_question.get_submission_from_samples(
        samples)
    normalized_samples_from_fit_mixture = [
        fit_mixture.sample() for _ in range(5000)
    ]
    mixture_samples = metaculus_questions.continuous_linear_open_question.denormalize_samples(
        normalized_samples_from_fit_mixture)
    assert float(np.mean(samples)) == pytest.approx(
        float(np.mean(mixture_samples)), rel=0.1)
    assert float(np.var(samples)) == pytest.approx(
        float(np.var(mixture_samples)), rel=0.2)
def bayesian_regression(
    X,
    y,
    prior_init=None,
    hyper_prior=((1e-6, 1e-6), (1e-6, 1e-6)),
    tol=1e-5,
    max_iter=300,
):
    n_samples, n_features = X.shape

    # Prepping matrices
    XT_y = jnp.dot(X.T, y)
    _, S, Vh = jnp.linalg.svd(X, full_matrices=False)
    eigen_vals = S**2

    if prior_init is None:
        prior_init = jnp.ones((2,))
        # Note: jax.ops.index_update is from older JAX releases; newer versions
        # use prior_init.at[1].set(...) instead.
        prior_init = jax.ops.index_update(prior_init, 1, 1 / (jnp.var(y) + 1e-7))

    # Running
    prior_params, iterations = fixed_point_solver(
        update,
        (X, y, eigen_vals, Vh, XT_y, hyper_prior),
        prior_init,
        lambda z_prev, z: jnp.linalg.norm(z_prev[0] - z[0]) > tol,
        max_iter=max_iter,
    )

    prior = stop_gradient(prior_params)
    log_LL, mn = evidence(X, y, prior, eigen_vals, Vh, XT_y, hyper_prior)

    metrics = (iterations, 0.0)
    return log_LL, mn, prior, metrics
def test_brownian(self, spatial_dimension, dtype):
  key = random.PRNGKey(0)
  key, T_split, mass_split = random.split(key, 3)

  _, shift = space.free()
  energy_fn = lambda R, **kwargs: f32(0)

  R = np.zeros((BROWNIAN_PARTICLE_COUNT, 2), dtype=dtype)
  mass = random.uniform(mass_split, (), minval=0.1, maxval=10.0, dtype=dtype)
  T = random.uniform(T_split, (), minval=0.3, maxval=1.4, dtype=dtype)

  dt = f32(1e-2)
  gamma = f32(0.1)

  init_fn, apply_fn = simulate.brownian(energy_fn, shift, dt, T, gamma=gamma)
  apply_fn = jit(apply_fn)

  state = init_fn(key, R, mass)

  sim_t = f32(BROWNIAN_DYNAMICS_STEPS * dt)
  for _ in range(BROWNIAN_DYNAMICS_STEPS):
    state = apply_fn(state)

  msd = np.var(state.position)
  th_msd = dtype(2 * T / (mass * gamma) * sim_t)
  assert np.abs(msd - th_msd) / msd < 1e-2
  assert state.position.dtype == dtype
def __call__(
    self,
    inputs: jnp.ndarray,
    scale: Optional[jnp.ndarray] = None,
    offset: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
  """Connects the layer norm.

  Args:
    inputs: An array, where the data format is ``[N, ..., C]``.
    scale: An array up to n-D. The shape of this tensor must be broadcastable
      to the shape of ``inputs``. This is the scale applied to the normalized
      inputs. This cannot be passed in if the module was constructed with
      ``create_scale=True``.
    offset: An array up to n-D. The shape of this tensor must be broadcastable
      to the shape of ``inputs``. This is the offset applied to the normalized
      inputs. This cannot be passed in if the module was constructed with
      ``create_offset=True``.

  Returns:
    The array, normalized.
  """
  if self.create_scale and scale is not None:
    raise ValueError(
        "Cannot pass `scale` at call time if `create_scale=True`.")
  if self.create_offset and offset is not None:
    raise ValueError(
        "Cannot pass `offset` at call time if `create_offset=True`.")

  axis = self.axis
  if isinstance(axis, slice):
    axis = tuple(range(inputs.ndim)[axis])

  mean = jnp.mean(inputs, axis=axis, keepdims=True)
  variance = jnp.var(inputs, axis=axis, keepdims=True)

  param_shape = inputs.shape[-1:]
  if self.create_scale:
    scale = hk.get_parameter("scale", param_shape, inputs.dtype,
                             init=self.scale_init)
  elif scale is None:
    scale = np.array(1., dtype=inputs.dtype)

  if self.create_offset:
    offset = hk.get_parameter("offset", param_shape, inputs.dtype,
                              init=self.offset_init)
  elif offset is None:
    offset = np.array(0., dtype=inputs.dtype)

  scale = jnp.broadcast_to(scale, inputs.shape)
  offset = jnp.broadcast_to(offset, inputs.shape)
  mean = jnp.broadcast_to(mean, inputs.shape)

  eps = jax.lax.convert_element_type(self.eps, variance.dtype)
  inv = scale * jax.lax.rsqrt(variance + eps)
  return inv * (inputs - mean) + offset
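# Hedged usage sketch, assuming this __call__ belongs to a module equivalent to
# Haiku's hk.LayerNorm (the constructor arguments below are the standard
# hk.LayerNorm ones, not taken from the surrounding class definition).
def _layer_norm_usage_example():
  import jax
  import jax.numpy as jnp
  import haiku as hk

  def forward(x):
    ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
    return ln(x)

  net = hk.transform(forward)
  x = jnp.ones([2, 3, 8])                      # [N, ..., C]
  params = net.init(jax.random.PRNGKey(0), x)
  return net.apply(params, None, x)            # normalized over the last axis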
def normal_equation_reg(X, Y, W, reg_scale):
  """Solves the ridge-regularized weighted normal equations."""
  # Expand the weights into a diagonal weight matrix.
  W = jnp.expand_dims(W, axis=-1) \
      * jnp.expand_dims(jnp.expand_dims(jnp.eye(W.shape[-1]), 0), 0)
  X = jnp.expand_dims(X, 1)
  # X^T W X
  norm = jnp.matmul(X.transpose((0, 1, 3, 2)), jnp.matmul(W, X))
  # Ridge regularizer (identity overlap).
  reg_overlap = jnp.expand_dims(jnp.expand_dims(jnp.eye(X.shape[-1]), 0), 0)
  # Scale the regularization by the inverse variance of the targets.
  reg_scale = jnp.expand_dims(
      jnp.expand_dims(reg_scale / jnp.var(Y, axis=-1), -1), -1)
  denom = jnp.linalg.inv(norm + reg_scale * reg_overlap)
  Y = jnp.expand_dims(Y, axis=-1)
  # X^T W Y
  numerator = jnp.matmul(X.transpose((0, 1, 3, 2)), jnp.matmul(W, Y))
  fit = jnp.matmul(denom, numerator)
  return jnp.reshape(fit, list(fit.shape)[:-1])
def parametric(subposteriors, diagonal=False):
    """
    Merges subposteriors following the (embarrassingly parallel) parametric
    Monte Carlo algorithm.

    **References:**

    1. *Asymptotically Exact, Embarrassingly Parallel MCMC*,
       Willie Neiswanger, Chong Wang, Eric Xing

    :param list subposteriors: a list in which each element is a collection of samples.
    :param bool diagonal: whether to compute weights using variance or covariance,
        defaults to `False` (using covariance).
    :return: the estimated mean and variance/covariance parameters of the joined posterior
    """
    joined_subposteriors = tree_multimap(lambda *args: np.stack(args), *subposteriors)
    joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(
        joined_subposteriors)

    submeans = np.mean(joined_subposteriors, axis=1)
    if diagonal:
        weights = vmap(lambda x: 1 / np.var(x, ddof=1, axis=0))(joined_subposteriors)
        var = 1 / np.sum(weights, axis=0)
        normalized_weights = var * weights

        # comparing to consensus implementation, we compute weighted mean here
        mean = np.einsum('ij,ij->j', normalized_weights, submeans)
        return mean, var
    else:
        weights = vmap(lambda x: np.linalg.inv(np.cov(x.T)))(joined_subposteriors)
        cov = np.linalg.inv(np.sum(weights, axis=0))
        normalized_weights = np.matmul(cov, weights)

        # comparing to consensus implementation, we compute weighted mean here
        mean = np.einsum('ijk,ik->j', normalized_weights, submeans)
        return mean, cov
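# Hedged usage sketch for `parametric` (illustrative only: two fake
# subposteriors, each a dict of 1000 draws of a 2-d latent, built with
# jax.random; any pytree of samples with a leading draw dimension works).
def _parametric_usage_example():
    import jax
    key1, key2 = jax.random.split(jax.random.PRNGKey(0))
    sub1 = {"theta": jax.random.normal(key1, (1000, 2))}
    sub2 = {"theta": jax.random.normal(key2, (1000, 2)) + 0.1}
    mean, cov = parametric([sub1, sub2])                      # full-covariance weights
    mean_d, var_d = parametric([sub1, sub2], diagonal=True)   # per-dimension weights
    return mean, cov, mean_d, var_d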
def __call__(self, x, is_training):
  out = self.activation(x) * self.beta

  if self.stride > 1:
    # Average-pool downsample. (Transition-block path; note to self: when does
    # self.stride > 1 apply?)
    shortcut = hk.avg_pool(out, window_shape=(1, 2, 2, 1),
                           strides=(1, 2, 2, 1), padding='SAME')
    if self.use_projection:
      shortcut = self.conv_shortcut(shortcut)
  elif self.use_projection:
    shortcut = self.conv_shortcut(out)
  else:
    shortcut = x

  out = self.conv0(out)  # 1x1
  out = self.conv1(self.activation(out))  # 3x3
  if self.use_two_convs:
    out = self.conv1b(self.activation(out))  # 3x3
  out = self.conv2(self.activation(out))  # 1x1
  # Multiply by 2 for rescaling (reportedly works better when doubled,
  # according to some paper?).
  out = (self.se(out) * 2) * out

  # Get average residual variance for reporting metrics.
  res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))

  # Apply stochdepth if applicable.
  if self._has_stochdepth:
    out = self.stoch_depth(out, is_training)

  # SkipInit Gain
  out = out * hk.get_parameter('skip_gain', (), out.dtype, init=jnp.zeros)
  return out * self.alpha + shortcut, res_avg_var
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    mean = jnp.mean(inputs, axis=-1, keepdims=True)
    variance = jnp.var(inputs, axis=-1, keepdims=True)

    param_shape = inputs.shape[-1:]
    scale = hk.get_parameter("scale", param_shape, inputs.dtype, init=jnp.ones)
    scale = jax.lax.all_gather(scale, "shard")[0]

    offset = hk.get_parameter("offset", param_shape, inputs.dtype, init=jnp.zeros)
    offset = jax.lax.all_gather(offset, "shard")[0]

    scale = jnp.broadcast_to(scale, inputs.shape)
    offset = jnp.broadcast_to(offset, inputs.shape)
    mean = jnp.broadcast_to(mean, inputs.shape)

    inv = scale * jax.lax.rsqrt(variance + 1e-5)
    if self.offset:
        return inv * (inputs - mean) + offset
    else:
        return inv * (inputs - mean)
def __call__(self, x, is_training):
  bias1a = hk.get_parameter('bias1a', (), x.dtype, init=jnp.zeros)
  bias1b = hk.get_parameter('bias1b', (), x.dtype, init=jnp.zeros)
  bias2a = hk.get_parameter('bias2a', (), x.dtype, init=jnp.zeros)
  bias2b = hk.get_parameter('bias2b', (), x.dtype, init=jnp.zeros)
  bias3a = hk.get_parameter('bias3a', (), x.dtype, init=jnp.zeros)
  bias3b = hk.get_parameter('bias3b', (), x.dtype, init=jnp.zeros)
  scale = hk.get_parameter('scale', (), x.dtype, init=jnp.ones)

  out = x + bias1a
  shortcut = out
  if self.use_projection:  # Downsample with conv1x1
    shortcut = self.conv_shortcut(shortcut)
  out = self.conv0(out)
  out = self.activation(out + bias1b)
  out = self.conv1(out + bias2a)
  out = self.activation(out + bias2b)
  out = self.conv2(out + bias3a)
  out = out * scale + bias3b

  # Get average residual variance for reporting metrics.
  res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))

  # Apply stochdepth if applicable.
  if self._has_stochdepth:
    out = self.stoch_depth(out, is_training)

  # SkipInit Gain
  out = out + shortcut
  return self.activation(out), res_avg_var
def apply_fun(params, inputs):
    beta, gamma = params
    mu, var = jnp.mean(inputs, axis=(0, 1, 2)), jnp.var(inputs, axis=(0, 1, 2))
    lnorm = (jnp.reshape(mu, (1, 1, 1, -1)) - inputs) / jnp.sqrt(
        jnp.reshape(var, (1, 1, 1, -1)) + 1e-5)
    return jnp.reshape(gamma, (1, 1, 1, -1)) * lnorm + jnp.reshape(
        beta, (1, 1, 1, -1))
def __call__(self, x):
    means = jnp.mean(x, axis=(1, 2))
    m = jnp.mean(means, axis=-1, keepdims=True)
    v = jnp.var(means, axis=-1, keepdims=True)
    means_plus = (means - m) / jnp.sqrt(v + 1e-5)
    h = (x - means[:, None, None, :]) / jnp.sqrt(
        jnp.var(x, axis=(1, 2), keepdims=True) + 1e-5)
    h = h + means_plus[:, None, None, :] * self.param(
        'alpha', InstanceNorm2dPlus.scale_init, (1, 1, 1, x.shape[-1]))
    h = h * self.param('gamma', InstanceNorm2dPlus.scale_init,
                       (1, 1, 1, x.shape[-1]))
    if self.bias:
        h = h + self.param('beta', init.zeros, (1, 1, 1, x.shape[-1]))
    return h
def test_discrete_gibbs_gmm_1d(modified, kernel, inner_kernel, kwargs):
    def model(probs, locs):
        c = numpyro.sample("c", dist.Categorical(probs))
        numpyro.sample("x", dist.Normal(locs[c], 0.5))

    probs = jnp.array([0.15, 0.3, 0.3, 0.25])
    locs = jnp.array([-2, 0, 2, 4])
    sampler = kernel(
        inner_kernel(model, trajectory_length=1.2), modified=modified, **kwargs)
    mcmc = MCMC(sampler, 1000, 200000, progress_bar=False)
    mcmc.run(random.PRNGKey(0), probs, locs)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["x"]), 1.3, atol=0.1)
    assert_allclose(jnp.var(samples["x"]), 4.36, atol=0.4)
    assert_allclose(jnp.mean(samples["c"]), 1.65, atol=0.1)
    assert_allclose(jnp.var(samples["c"]), 1.03, atol=0.1)
def signal_metrics(x, i):
  """Things to measure about an NHWC tensor activation."""
  metrics = {}
  # Average channel-wise mean-squared value.
  metrics[f'avg_sq_mean_{i}'] = jnp.mean(jnp.mean(x, axis=[0, 1, 2])**2)
  # Average channel variance.
  metrics[f'avg_var_{i}'] = jnp.mean(jnp.var(x, axis=[0, 1, 2]))
  return metrics
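# Hedged usage sketch for `signal_metrics` (the NHWC shape below is an
# illustrative assumption; any activation tensor with channels last works).
def _signal_metrics_usage_example():
  import jax
  x = jax.random.normal(jax.random.PRNGKey(0), (8, 32, 32, 64))  # N, H, W, C
  metrics = signal_metrics(x, i=0)
  # For unit-normal activations: avg_sq_mean_0 ~= 0.0 and avg_var_0 ~= 1.0.
  return metrics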