def normal_kl(mean1, logvar1, mean2, logvar2):
    """Analytic KL divergence between two diagonal Gaussians.

    Both distributions are parameterized by a mean and a log variance.

    Args:
      mean1: mean of the first distribution.
      logvar1: log variance of the first distribution.
      mean2: mean of the second distribution.
      logvar2: log variance of the second distribution.

    Returns:
      KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))), elementwise.
    """
    var_ratio = jnp.exp(logvar1 - logvar2)
    scaled_mean_diff = jnp.square(mean1 - mean2) * jnp.exp(-logvar2)
    return 0.5 * (logvar2 - logvar1 + var_ratio + scaled_mean_diff - 1.0)
def clip_grad(grad, config):
    """Clips the gradient using the method given in the config.

    Args:
      grad: Gradient pytree to clip.
      config: Configuration object; reads `config.opt.clip_by` (currently only
        'global_norm' is supported) and `config.opt.clip_value` (the maximum
        allowed global norm).

    Returns:
      The clipped gradient pytree, with the same structure as `grad`.

    Raises:
      ValueError: If `config.opt.clip_by` is not a supported method.
    """
    clip_by = config.opt.clip_by
    clip_value = config.opt.clip_value
    if clip_by == 'global_norm':
        # L2 norm over all leaves of the gradient pytree.
        global_norm = jnp.sqrt(
            sum(jnp.sum(jnp.square(x))
                for x in jax.tree_util.tree_leaves(grad)))
        should_clip = global_norm > clip_value
        # Rescale every leaf only when the global norm exceeds the threshold;
        # jnp.where keeps this branch-free and jit-compatible.
        # Fix: `jax.tree_map` is deprecated/removed in recent JAX releases;
        # use `jax.tree_util.tree_map` (consistent with the tree_leaves call).
        grad = jax.tree_util.tree_map(
            lambda g: jnp.where(should_clip, g * clip_value / global_norm, g),
            grad)
    else:
        raise ValueError('Unexpected value for config.opt.clip_by', clip_by)
    return grad
def diag_gaussian_log_likelihood(z, mean=0.0, logvar=0.0, varmin=1e-16):
    """Log-likelihood under a Gaussian distribution with diagonal covariance.

    Returns the log-likelihood for each dimension.

    Args:
      z: The value to compute the log-likelihood for.
      mean: The mean of the Gaussian.
      logvar: The log variance of the Gaussian.
      varmin: Minimum variance allowed (numerically useful).

    Returns:
      The log-likelihood under the Gaussian model, elementwise.
    """
    # Floor the variance at varmin for numerical stability.
    logvar_wm = np.log(np.exp(logvar) + varmin)
    # Fix: use the variance-floored logvar_wm in the normalization term too.
    # The original used the raw `logvar` there, so the density was
    # inconsistent with the floored variance used in the exponent.
    return -0.5 * (logvar_wm + np.log(2 * np.pi)
                   + np.square((z - mean) / np.exp(0.5 * logvar_wm)))
def torus2sphere(ra: jnp.ndarray, ang: jnp.ndarray) -> jnp.ndarray:
    """Map points on a two-dimensional torus to points on a two-dimensional
    sphere.

    Args:
      ra: First radial coordinate.
      ang: Angular coordinate.

    Returns:
      out: Conversion of the inputs on a torus into a sphere.
    """
    # Embed the angular coordinate on the unit circle.
    circle = pm.sphere.ang2euclid(ang)
    height = ra[..., jnp.newaxis]
    # Shrink the circle so that (equatorial, height) lies on the unit sphere.
    equatorial = jnp.sqrt(1. - jnp.square(height)) * circle
    return jnp.concatenate((equatorial, height), axis=-1)
def loss(
    params: hk.Params,
    inputs: np.ndarray,
    targets: np.ndarray,
) -> jnp.DeviceArray:
    """Compute the loss of the network, including L2."""
    assert targets.dtype == np.int32
    batch_size = inputs.shape[0]
    log_probs = net.apply(params, inputs)
    # L2 regularizer summed over every parameter tensor.
    l2_loss = 0.5 * sum(
        jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
    # Mean softmax cross-entropy over the batch.
    labels = hk.one_hot(targets, NUM_DIGITS)
    softmax_xent = -jnp.sum(labels * log_probs) / batch_size
    return softmax_xent + 1e-4 * l2_loss
def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.ndarray: """Entropy-regularised actor-critic loss.""" # Extract the data. observations, actions, rewards, discounts, extra = sample.data initial_state = tree.map_structure(lambda s: s[0], extra['core_state']) behaviour_logits = extra['logits'] # Apply reward clipping. rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward) # Unroll current policy over observations. (logits, values), _ = unroll_fn.apply(params, observations, initial_state) # Compute importance sampling weights: current policy / behavior policy. rhos = rlax.categorical_importance_sampling_ratios( logits[:-1], behaviour_logits[:-1], actions[:-1]) # Critic loss. vtrace_returns = rlax.vtrace_td_error_and_advantage( v_tm1=values[:-1], v_t=values[1:], r_t=rewards[:-1], discount_t=discounts[:-1] * discount, rho_t=rhos) critic_loss = jnp.square(vtrace_returns.errors) # Policy gradient loss. policy_gradient_loss = rlax.policy_gradient_loss( logits_t=logits[:-1], a_t=actions[:-1], adv_t=vtrace_returns.pg_advantage, w_t=jnp.ones_like(rewards[:-1])) # Entropy regulariser. entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards[:-1])) # Combine weighted sum of actor & critic losses. mean_loss = jnp.mean(policy_gradient_loss + baseline_cost * critic_loss + entropy_cost * entropy_loss) return mean_loss
def evaluate_model(model: flax.nn.Model, X: Array, Y: Array) -> float:
    """Evaluates the model and returns the mean-squared loss.

    Args:
      model: flax.nn.Model.
      X: physics trajectory of shape [num_samples, num_integration_steps,
        {state dimensions}].
      Y: observations, jax.numpy.ndarray of shape [num_samples,
        num_integration_steps, {observation dimensions}].

    Returns:
      Mean-squared loss.
    """
    encoded = model(Y)
    residual = X - encoded
    return jnp.mean(jnp.square(residual))
def kl_gauss_gauss(q_mean, q_logvar, p_mean, p_logvar, varmin=1e-16):
    """Compute the KL divergence between two diagonal Gaussian distributions.

    KL(q||p) = E_q[log q(z) - log p(z))]

    Args:
      q_mean: mean of q
      q_logvar: logvar of q
      p_mean: mean of p
      p_logvar: logvar of p
      varmin: minimum variance allowed, useful for numerical stability

    Returns:
      np array of KL, computed analytically, same size as q_mean
    """
    # Floor both variances at varmin before taking logs.
    q_logvar = np.log(np.exp(q_logvar) + varmin)
    p_logvar = np.log(np.exp(p_logvar) + varmin)
    var_ratio = np.exp(q_logvar - p_logvar)
    scaled_mean_diff = np.square((q_mean - p_mean) / np.exp(0.5 * p_logvar))
    return 0.5 * (p_logvar - q_logvar + var_ratio + scaled_mean_diff - 1.0)
def _single_kohn_sham_iteration(carry, inputs):
    """One scan step of the self-consistent Kohn-Sham loop.

    Carry is (idx, state, alpha, converged, differences). Once `converged`
    is True, the converged branch keeps returning the frozen state so the
    remaining scan steps are no-ops.
    """
    del inputs
    idx, old_state, alpha, converged, differences = carry
    # Branch on convergence: either freeze the state or run one more
    # Kohn-Sham iteration with the current mixing factor alpha.
    state, differences = jax.lax.cond(
        converged,
        true_operand=(old_state, differences),
        true_fun=_converged_kohn_sham_iteration,
        false_operand=(idx, old_state, alpha, differences),
        false_fun=_uncoveraged_kohn_sham_iteration)
    # Declare convergence when the density MSE change drops below tolerance.
    converged = jnp.mean(jnp.square(
        state.density - old_state.density)) < density_mse_converge_tolerance
    # Stop gradients through the first `stop_gradient_step` iterations,
    # truncating backprop through the early (unconverged) steps.
    state = jax.lax.cond(
        idx <= stop_gradient_step,
        true_fun=jax.lax.stop_gradient,
        false_fun=lambda x: x,
        operand=state)
    # alpha decays geometrically each iteration.
    return (idx + 1, state, alpha * alpha_decay, converged, differences), state
def expected_value_delta(params: base.Params, state: CvState) -> float:
    """Expected value of second order expansion of `function` at dist mean."""
    del state
    mean_dist = params[0]
    # params[1] holds log standard deviations; square to get variances.
    var_dist = jnp.square(jnp.exp(params[1]))
    # Diagonal of the Hessian of `function` evaluated at the mean.
    hessians = jax.hessian(function)(mean_dist)
    assert hessians.ndim == 2
    hess_diags = jnp.diag(hessians)
    assert hess_diags.ndim == 1
    # Second-order term: 0.5 * Trace(Hessian @ Sigma), using that Sigma is
    # diagonal.
    second_order_term = 0.5 * jnp.sum(var_dist * hess_diags)
    return function(mean_dist) + second_order_term
def data_fidelity(
    self,
    input_data: jnp.ndarray,
    recons: jnp.ndarray,
) -> jnp.ndarray:
    """Compute Data fidelity (recons loss) for given input and recons.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      recons: Reconstruction of the input data. An array with the same shape
        as `input_data.data`.

    Returns:
      Computed data fidelity term across batch of data. An array of shape
      `(batch_size,)`.
    """
    batch_size = input_data.shape[0]
    # Flatten per-example residuals and take the squared L2 norm.
    residual = (input_data - recons).reshape(batch_size, -1)
    squared_norm = jnp.sum(jnp.square(residual), axis=1)
    return -0.5 * squared_norm / self._obs_var
def euclidean_distance(self, x, y, aux=None):
    """Euclidean distance between simulation summaries and target summary

    The target summary is expanded on the zeroth axis so that it broadcasts
    against the whole batch of simulation summaries.

    Parameters
    ----------
    x : (any, n_params)
        Envisioned as a whole batch of summaries of simulations
    y : (n_params)
        Envisioned as a target summary
    aux : None, default=None
        Empty holder so that function works in the same way as `F_distance`
    """
    delta = x - np.expand_dims(y, 0)
    return np.sqrt(np.sum(np.square(delta), -1))
def _split_grad(self, param, state, grad, decay):
    """Split the gradient for the direction and scale.

    Decomposes a weight matrix into a unit-ish direction and a scalar scale
    per output column, returning a 6-tuple:
    (direction, direction_state, direction_grad, scale, scale_state,
    scale_grad). Vector parameters (size == last dim) pass through unchanged.
    """
    # Only split when there is more than one element per output column.
    if param.size > param.shape[-1]:
        # Reduce over every axis except the last one.
        red_dims = tuple(range(param.ndim-1))
        direction = param / state.mult
        # Column-wise L2 norm of the parameter.
        norm = jnp.sqrt(jnp.square(param).sum(red_dims, keepdims=True))
        scale = norm * jnp.sign(state.mult)
        # Component of the gradient along the direction (chain rule for the
        # scale factor).
        scale_grad = jnp.sum(
            grad * direction, axis=red_dims, keepdims=True)
        # Remaining (orthogonal) component drives the direction update.
        direction_grad = state.mult * (grad - scale_grad * direction)
        if decay != 0:
            # Weight decay is applied to the direction only.
            direction_grad = direction_grad + decay * direction
        direction_info = direction, state.direction_state, direction_grad
        scale_info = scale, state.scale_state, scale_grad
        # Tuple concatenation -> the flat 6-tuple described above.
        return direction_info + scale_info
    else:
        # No split: parameter/grad act as the direction part; empty
        # placeholders fill the scale slots.
        return (param, state.direction_state, grad, (), (), ())
def loss_fn(rng, params, states, batch):
    """SMLD (NCSN) denoising score-matching loss for one batch.

    Perturbs the data with per-example noise levels drawn from
    `smld_sigma_array` and regresses the model output onto the score of the
    perturbation kernel, weighting each example by sigma^2.
    """
    model_fn = mutils.get_model_fn(model, params, states, train=train)
    data = batch['image']
    # Sample a noise-level index per example.
    rng, step_rng = random.split(rng)
    labels = random.choice(step_rng, vesde.N, shape=(data.shape[0], ))
    sigmas = smld_sigma_array[labels]
    # Perturb the data with Gaussian noise of the chosen standard deviation.
    rng, step_rng = random.split(rng)
    noise = batch_mul(random.normal(step_rng, data.shape), sigmas)
    perturbed_data = noise + data
    rng, step_rng = random.split(rng)
    score, new_model_state = model_fn(perturbed_data, labels, rng=step_rng)
    # Score of the perturbation kernel: -noise / sigma^2.
    target = -batch_mul(noise, 1. / (sigmas**2))
    losses = jnp.square(score - target)
    # Reduce per example, then weight by sigma^2 (SMLD weighting).
    losses = reduce_op(losses.reshape(
        (losses.shape[0], -1)), axis=-1) * sigmas**2
    loss = jnp.mean(losses)
    return loss, new_model_state
def self_interaction_weight(reshaped_density, dx, width):
    """Gets self-interaction weight.

    When the density integral is one, the self-interaction weight is 1. The
    weight decays as a Gaussian as the integral deviates from one.

    Args:
      reshaped_density: Float numpy array with any shape. The total size
        should be num_grids.
      dx: Float, grid spacing.
      width: Float, the width of the Gaussian function.

    Returns:
      Float, the self-interaction weight.
    """
    # Riemann-sum approximation of the density integral on the grid.
    integral = dx * jnp.sum(reshaped_density)
    deviation = (integral - 1) / width
    return jnp.exp(-jnp.square(deviation))
def loss(self, x, targets, z_loss=1):
    """Softmax cross-entropy with z-loss regularization.

    Returns the per-example loss (plus a scalar z-loss term) and a boolean
    array marking examples whose target logit is the maximum.
    """
    hidden = self.norm(x)
    logits = self.proj(hidden)
    # Subtract the max for numerical stability; the argmax logit becomes 0.
    logits = logits - logits.max(-1, keepdims=True)
    gt_onehot = jax.nn.one_hot(targets, self.dim)
    predicted_logits = jnp.sum(jnp.multiply(gt_onehot, logits), axis=-1)
    sum_exp_logits = jnp.exp(logits).sum(axis=-1)
    loss = jnp.log(sum_exp_logits) - predicted_logits
    # z-loss penalizes large log-partition values.
    loss = loss + (1e-4 * jnp.square(jnp.log(sum_exp_logits)) * z_loss).mean()
    # After the max shift, the target is the argmax exactly when its logit
    # equals zero.
    correct = (0.0 == predicted_logits)
    return loss, correct
def apply(self, y, vgp: VariationalGaussianProcess):
    """Negative expected Gaussian log-likelihood of `y` under q(u).

    Computes E_{q(u)}[-log N(y | u, obs_noise_scale^2)] with a learned
    observation noise scale, for use inside an ELBO objective.
    """
    # Learned observation noise, constrained positive via softplus.
    obs_noise_scale = jax.nn.softplus(
        self.param('observation_noise_scale', (), jax.nn.initializers.ones))
    variational_distribution = vgp.marginal()
    qu_mean = variational_distribution.mean
    qu_scale = variational_distribution.scale
    # Expected value of iid gaussians under q(u):
    # E_q[||y - u||^2] = ||y - mean||^2 + tr(scale @ scale^T).
    expected_gll_under_qu = -.5 * jnp.squeeze(
        (jnp.sum(jnp.square(qu_mean - y)) +
         jnp.trace(qu_scale @ qu_scale.T)) / obs_noise_scale**2 +
        y.shape[-1] * jnp.log(obs_noise_scale**2) + jnp.log(2 * jnp.pi))
    # flip sign to minimize the elbo
    return -expected_gll_under_qu
def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff):
    """PPO loss: clipped surrogate + value loss - entropy bonus."""
    states, actions, old_log_probs, returns, advantages = minibatch
    log_probs, values = model(states)
    # Convert shapes: (batch, 1) to (batch, ).
    values = values[:, 0]
    # Entropy of the current policy.
    probs = jnp.exp(log_probs)
    entropy = jnp.mean(jnp.sum(-probs * log_probs, axis=1))
    # Probability ratios pi_new(a|s) / pi_old(a|s) for the taken actions.
    log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions)
    ratios = jnp.exp(log_probs_act_taken - old_log_probs)
    # Advantage normalization (following the OpenAI baselines).
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    surrogate = ratios * advantages
    clipped_surrogate = advantages * jax.lax.clamp(1. - clip_param, ratios,
                                                   1. + clip_param)
    ppo_loss = -jnp.mean(jnp.minimum(surrogate, clipped_surrogate), axis=0)
    value_loss = jnp.mean(jnp.square(returns - values), axis=0)
    return ppo_loss + vf_coeff * value_loss - entropy_coeff * entropy
def _linear_regression_gibbs_fn(X, XX, XY, Y, rng_key, gibbs_sites, hmc_sites):
    """Gibbs update for the regression coefficients `beta`.

    Samples beta from its conditional Gaussian posterior with precision
    (X^T X / sigma^2 + I), given the current noise scale from the HMC sites.
    """
    N, P = X.shape
    # The noise scale may be parameterized as `log_sigma` or `sigma`.
    sigma = jnp.exp(hmc_sites['log_sigma']
                    ) if 'log_sigma' in hmc_sites else hmc_sites['sigma']
    sigma_sq = jnp.square(sigma)
    # Posterior precision of beta.
    covar_inv = XX / sigma_sq + jnp.eye(P)
    L = cho_factor(covar_inv, lower=True)[0]
    L_inv = solve_triangular(L, jnp.eye(P), lower=True)
    loc = cho_solve((L, True), XY) / sigma_sq
    # NOTE(review): scale_tril=L_inv yields covariance L^-1 L^-T, while the
    # posterior covariance is (L L^T)^-1 = L^-T L^-1 — confirm this
    # factorization is the intended one.
    beta_proposal = dist.MultivariateNormal(
        loc=loc, scale_tril=L_inv).sample(rng_key)
    return {'beta': beta_proposal}
def mobius_spline_log_prob(ra: jnp.ndarray, raunif: jnp.ndarray,
                           rb: jnp.ndarray, rbunif: jnp.ndarray,
                           ang: jnp.ndarray, angunif: jnp.ndarray,
                           w: jnp.ndarray, xk: jnp.ndarray, yk: jnp.ndarray,
                           deltak: jnp.ndarray, xl: jnp.ndarray,
                           yl: jnp.ndarray, deltal: jnp.ndarray):
    """Compute the log-density of the Mobius spline transformation on the sphere.
    """
    # First radial coordinate: uniform base density (factor 1/2) pushed
    # through a rational-quadratic spline (change of variables).
    lpra = -jnp.log(2.) - jnp.log(
        grad_rational_quadratic(raunif, xk, yk, deltak))
    # Second radial coordinate, vmapped over its own spline parameters.
    lprb = -jnp.log(2.) - jnp.log(
        vmap(grad_rational_quadratic)(rbunif, xl, yl, deltal))
    # Angular coordinate under the Mobius transform.
    lpang = vmap(mobius_log_prob)(angunif, w)
    # Log-det Jacobian of the torus-to-sphere change of coordinates.
    ldj = -(3. / 2 - 1.) * jnp.log(1 - jnp.square(ra))
    log_prob = lpra + lprb + lpang + ldj
    return log_prob
def calc_gp_prior(X, K, n_lats, Fourier=False):
    '''Calculates the GP log prior.

    x_samples should be nsamples by T; K is T by T. In the Fourier branch,
    X[i] and K[i] are treated as per-latent vectors of coefficients and
    variances (diagonal Gaussian). The non-Fourier (time-domain) branch is
    not implemented and the function then returns 0.
    '''
    total_prior = 0
    if Fourier:
        # Diagonal Gaussian log-density, accumulated per latent.
        for lat in np.arange(n_lats):
            x_lat = X[lat]
            quad_term = np.sum(np.square(x_lat) / K[lat])
            log_det_term = np.sum(np.log(2 * np.pi * K[lat]))
            total_prior = total_prior - 0.5 * (quad_term + log_det_term)
    # else:
    #   Kinv = np.linalg.inv(K)
    #   log_prior = -(1/2)*(np.shape(Kinv)[0]*np.log(2*np.pi) + np.matmul(rate,np.matmul(Kinv,rate))+ np.linalg.slogdet(K)[1])
    return total_prior
def loss_fn(x, excite_port_idx=0):
    """Mean power transmission through the device for design variables `x`.

    Args:
      x: Design variable array for the design regions.
      excite_port_idx: Index of the port to excite (0 or 1).

    Returns:
      Mean of |t|^2 over the simulated frequencies.
    """
    # Wrap the meep simulation so it is differentiable through JAX.
    wrapped_meep = mpa.MeepJaxWrapper(
        simulation,
        [sources[excite_port_idx]],
        monitors,
        design_regions,
        frequencies,
    )
    monitor_values = wrapped_meep([x])
    s1p, s1m, s2m, s2p = monitor_values
    # Transmission coefficient from the excited port to the opposite port.
    if excite_port_idx == 0:
        t = s2m / s1p
    else:
        t = s1m / s2p
    # Mean transmission vs wavelength
    t_mean = jnp.mean(jnp.square(jnp.abs(t)))
    return t_mean
def loss_fn(rng, params, states, batch):
    """DDPM epsilon-prediction loss for one batch.

    Samples a diffusion timestep per example, perturbs the data with the
    corresponding forward-process marginal and regresses the model output
    onto the injected noise.
    """
    model_fn = mutils.get_model_fn(model, params, states, train=train)
    data = batch['image']
    # Sample a timestep index per example.
    rng, step_rng = random.split(rng)
    labels = random.choice(step_rng, vpsde.N, shape=(data.shape[0], ))
    sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod
    sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod
    rng, step_rng = random.split(rng)
    noise = random.normal(step_rng, data.shape)
    # Forward diffusion marginal q(x_t | x_0).
    perturbed_data = batch_mul(sqrt_alphas_cumprod[labels], data) + \
        batch_mul(sqrt_1m_alphas_cumprod[labels], noise)
    rng, step_rng = random.split(rng)
    score, new_model_state = model_fn(perturbed_data, labels, rng=step_rng)
    # Epsilon parameterization: the model predicts the injected noise.
    losses = jnp.square(score - noise)
    losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1)
    loss = jnp.mean(losses)
    return loss, new_model_state
def compute_stats(inputs: JTensor,
                  padding: Optional[JTensor] = None) -> NestedMap:
    """Computes mean/std/max stats over the valid data points in inputs.

    Args:
      inputs: Values to summarize.
      padding: Optional padding indicator with the same rank as `inputs`;
        positions where padding is 1.0 are excluded from the stats.

    Returns:
      A NestedMap with fields `mean_v`, `std_v` and `max_v`.
    """
    # Let's compute stats in fp32 regardless of the input dtype.
    inputs = inputs.astype(jnp.float32)
    if padding is None:
        padding = jnp.zeros_like(inputs)
    assert inputs.ndim == padding.ndim
    mask = 1.0 - padding
    valid_count = jnp.sum(jnp.ones_like(inputs) * mask)
    denom = jnp.maximum(1.0, valid_count)
    mean_v = jnp.sum(inputs * mask) / denom
    centered_sq_sum = jnp.sum(jnp.square((inputs - mean_v) * mask))
    std_v = jnp.sqrt(centered_sq_sum / denom)
    max_v = jnp.max(jnp.abs(inputs * mask))
    return NestedMap(mean_v=mean_v, std_v=std_v, max_v=max_v)
def cluster_dist_metric(point, metric_state: MetricState):
    """Distance from `point` to every cluster center under `method`.

    Supports 'euclidean' (squared distance), 'mahalanobis' (quadratic form
    with per-cluster matrix C) and 'ellipsoid' (Mahalanobis weighted by the
    cluster's relative ellipsoid volume).
    """
    if method == 'euclidean':
        dist = metric_state.cluster_centers - point
        return jnp.sum(jnp.square(dist), axis=-1)
    if method == 'mahalanobis':
        dist = metric_state.cluster_centers - point
        # Quadratic form d^T C d per cluster.
        maha = vmap(lambda dist, C: dist @ C @ dist)(dist, metric_state.C)
        return maha
    if method == 'ellipsoid':
        dist = metric_state.cluster_centers - point
        # Mahalanobis distance weighted by each cluster's ellipsoid volume
        # relative to its point count (in log space for stability).
        weighted_maha = vmap(
            lambda dist, C, radii, num_k: (dist @ C @ dist) * jnp.exp(
                log_ellipsoid_volume(radii) - jnp.log(num_k) + jnp.log(
                    num_S) - log_VS))(dist, metric_state.C, metric_state.radii,
                                      metric_state.num_k)
        return weighted_maha
def optimize(state,
             grad,
             warmup=config.optim.warmup,
             grad_clip=config.optim.grad_clip):
    """Optimizes with warmup and gradient clipping (disabled if negative).

    Args:
      state: Training state providing `lr`, `step` and `optimizer`.
      grad: Gradient pytree.
      warmup: Number of linear-warmup steps (<= 0 disables warmup).
      grad_clip: Maximum global gradient norm; negative disables clipping.

    Returns:
      The result of `state.optimizer.apply_gradient` on the clipped gradient.
    """
    lr = state.lr
    if warmup > 0:
        # Linear learning-rate warmup.
        lr = lr * jnp.minimum(state.step / warmup, 1.0)
    if grad_clip >= 0:
        # Compute global gradient norm.
        # Fix: `jax.tree_leaves`/`jax.tree_map` are deprecated and removed in
        # recent JAX releases; use the `jax.tree_util` equivalents.
        grad_norm = jnp.sqrt(
            sum(jnp.sum(jnp.square(x))
                for x in jax.tree_util.tree_leaves(grad)))
        # Rescale so the global norm never exceeds grad_clip.
        clipped_grad = jax.tree_util.tree_map(
            lambda x: x * grad_clip / jnp.maximum(grad_norm, grad_clip), grad)
    else:  # disabling gradient clipping if grad_clip < 0
        clipped_grad = grad
    return state.optimizer.apply_gradient(clipped_grad, learning_rate=lr)
def kl_gauss_gauss(z_mean, z_logvar, prior_params):
    """Compute the KL divergence between two diagonal Gaussian distributions.

    KL(q||p) = E_q[log q(z) - log p(z))]

    Args:
      z_mean: mean of posterior, z ~ q(z|x)
      z_logvar: logvar of posterior
      prior_params: dict with keys 'mean' and 'logvar' describing the prior
        p(z)

    Returns:
      np array of KL, computed analytically, same size as z_mean
    """
    p_mean = prior_params['mean']
    p_logvar = prior_params['logvar']
    var_ratio = np.exp(z_logvar - p_logvar)
    scaled_mean_diff = np.square((z_mean - p_mean) / np.exp(0.5 * p_logvar))
    return 0.5 * (p_logvar - z_logvar + var_ratio + scaled_mean_diff - 1.0)
def loss_fn(params: flax.core.frozen_dict.FrozenDict,
            module: models.ActorCritic,
            minibatch: Tuple,
            clip_param: float,
            vf_coeff: float,
            entropy_coeff: float):
    """Evaluate the loss function.

    The loss is the sum of three components: the negative of the PPO clipped
    surrogate objective, the weighted value function loss, and the negative
    of the weighted entropy bonus.

    Args:
      params: the parameters of the actor-critic model
      module: the actor-critic model
      minibatch: Tuple of five elements forming one experience batch:
        states: shape (batch_size, 84, 84, 4)
        actions: shape (batch_size, 84, 84, 4)
        old_log_probs: shape (batch_size,)
        returns: shape (batch_size,)
        advantages: shape (batch_size,)
      clip_param: the PPO clipping parameter used to clamp ratios in the loss
        function
      vf_coeff: weighs value function loss in total loss
      entropy_coeff: weighs entropy bonus in the total loss

    Returns:
      loss: the PPO loss, scalar quantity
    """
    states, actions, old_log_probs, returns, advantages = minibatch
    log_probs, values = agent.policy_action(params, module, states)
    # Convert shapes: (batch, 1) to (batch, ).
    values = values[:, 0]
    # Value-function regression loss.
    value_loss = jnp.mean(jnp.square(returns - values), axis=0)
    # Entropy bonus of the current policy.
    probs = jnp.exp(log_probs)
    entropy = jnp.mean(jnp.sum(-probs * log_probs, axis=1))
    # Probability ratios pi_new(a|s) / pi_old(a|s) for the taken actions.
    log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions)
    ratios = jnp.exp(log_probs_act_taken - old_log_probs)
    # Advantage normalization (following the OpenAI baselines).
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    surrogate = ratios * advantages
    clipped_surrogate = advantages * jax.lax.clamp(1. - clip_param, ratios,
                                                   1. + clip_param)
    ppo_loss = -jnp.mean(jnp.minimum(surrogate, clipped_surrogate), axis=0)
    return ppo_loss + vf_coeff * value_loss - entropy_coeff * entropy
def diffusion_reverse(*, x, z_t, logsnr_s, logsnr_t, x_logvar):
    """q(z_s | z_t, x) (requires logsnr_s > logsnr_t (i.e. s < t)).

    Returns the mean/std/var/logvar of the Gaussian reverse-diffusion
    posterior. `x_logvar` controls the extra variance contributed by the
    uncertainty in x: 'small', 'large', 'medium:<frac>', or an explicit
    per-pixel log-variance array.
    """
    alpha_st = jnp.sqrt((1. + jnp.exp(-logsnr_t)) / (1. + jnp.exp(-logsnr_s)))
    alpha_s = jnp.sqrt(nn.sigmoid(logsnr_s))
    r = jnp.exp(logsnr_t - logsnr_s)  # SNR(t)/SNR(s)
    one_minus_r = -jnp.expm1(logsnr_t - logsnr_s)  # 1-SNR(t)/SNR(s)
    # log(1-SNR(t)/SNR(s)), computed stably via log1mexp.
    log_one_minus_r = utils.log1mexp(logsnr_s - logsnr_t)
    mean = r * alpha_st * z_t + one_minus_r * alpha_s * x
    if isinstance(x_logvar, str):
        if x_logvar == 'small':
            # same as setting x_logvar to -infinity
            var = one_minus_r * nn.sigmoid(-logsnr_s)
            logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
        elif x_logvar == 'large':
            # same as setting x_logvar to nn.log_sigmoid(-logsnr_t)
            var = one_minus_r * nn.sigmoid(-logsnr_t)
            logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_t)
        elif x_logvar.startswith('medium:'):
            _, frac = x_logvar.split(':')
            frac = float(frac)
            logging.info('logvar frac=%f', frac)
            assert 0 <= frac <= 1
            min_logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
            max_logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_t)
            # Interpolate between 'small' and 'large' in log-variance space.
            logvar = frac * max_logvar + (1 - frac) * min_logvar
            var = jnp.exp(logvar)
        else:
            raise NotImplementedError(x_logvar)
    else:
        assert isinstance(x_logvar, jnp.ndarray) or isinstance(
            x_logvar, onp.ndarray)
        assert x_logvar.shape == x.shape
        # start with "small" variance
        var = one_minus_r * nn.sigmoid(-logsnr_s)
        logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
        # extra variance weight is (one_minus_r*alpha_s)**2
        var += jnp.square(one_minus_r) * nn.sigmoid(logsnr_s) * jnp.exp(
            x_logvar)
        # Same addition in log space, via logaddexp for stability.
        logvar = jnp.logaddexp(
            logvar, 2. * log_one_minus_r + nn.log_sigmoid(logsnr_s) + x_logvar)
    return {'mean': mean, 'std': jnp.sqrt(var), 'var': var, 'logvar': logvar}
def loss_fn(agent_params):
    """Combined V-trace actor-critic + reconstruction loss for the agent.

    Runs the fast/slow agent model, then mixes policy-gradient, critic,
    entropy and language/vision reconstruction terms into one scalar loss.
    """
    (
        logits,
        v,
        reconstructed_lang,
        reconstructed_vision,
        *_,
    ) = apply_fast_slow_agent_model(
        agent_params,
        num_embeddings,
        embedding_dim,
        language_state,
        vision_state,
        h_prev,
        decoder_h_prev,
        False,
    )
    # NOTE: `rhos` and `q_estimate` are currently unused; kept for parity
    # with the V-trace computation.
    rhos = categorical_importance_sampling_ratios(logits[:-1],
                                                  behavior_logits[:-1],
                                                  actions[:-1])
    errors, pg_advantages, q_estimate = vtrace()
    critic_loss = jnp.square(errors)
    log_pi_a = jnp.take_along_axis(jax.nn.log_softmax(logits[:-1]),
                                   actions[:-1], axis=-1)
    # Fix: the original referenced an undefined `pg_advantage`; the V-trace
    # advantages are `pg_advantages`. Stop gradients so the policy loss does
    # not backprop through the advantage estimate.
    pg_advantages = jax.lax.stop_gradient(pg_advantages)
    pg_loss = jnp.mean(-log_pi_a * pg_advantages)
    entropy_loss = cross_entropy_loss_fn(logits, logits)
    language_reconstruction_loss = cross_entropy_loss_fn(
        language_state, reconstructed_lang)
    # Fix: the vision term previously compared against the *language*
    # reconstruction; it must use the vision reconstruction.
    vision_reconstruction_loss = cross_entropy_loss_fn(
        vision_state, reconstructed_vision)
    regularized_loss = jnp.mean(
        policy_eps * pg_loss + baseline_eps * critic_loss +
        entropy_eps * entropy_loss + reconstruction_eps *
        (language_reconstruction_loss + vision_reconstruction_loss))
    # Fix: the original returned an undefined name `loss`.
    return regularized_loss