# Example 1
def normal_kl(mean1, logvar1, mean2, logvar2):
  """Analytic KL divergence between two diagonal normal distributions.

  Both distributions are parameterized by their mean and log variance.

  Args:
    mean1: mean of the first distribution
    logvar1: log variance of the first distribution
    mean2: mean of the second distribution
    logvar2: log variance of the second distribution

  Returns:
    KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))
  """
  log_var_ratio = logvar2 - logvar1
  var_ratio = jnp.exp(logvar1 - logvar2)
  scaled_mean_diff = jnp.square(mean1 - mean2) * jnp.exp(-logvar2)
  return 0.5 * (log_var_ratio + var_ratio + scaled_mean_diff - 1.0)
def clip_grad(grad, config):
    """Clips the gradient using the method given in the config.

    Args:
        grad: A pytree of gradient arrays.
        config: Configuration providing `config.opt.clip_by` (only
            'global_norm' is supported) and `config.opt.clip_value`.

    Returns:
        The gradient pytree, rescaled so its global L2 norm does not exceed
        `clip_value` (unchanged when already within the bound).

    Raises:
        ValueError: If `config.opt.clip_by` is not 'global_norm'.
    """
    clip_by = config.opt.clip_by
    clip_value = config.opt.clip_value
    if clip_by == 'global_norm':
        # L2 norm across every leaf of the gradient pytree.
        global_norm = jnp.sqrt(
            sum(jnp.sum(jnp.square(x))
                for x in jax.tree_util.tree_leaves(grad)))
        should_clip = global_norm > clip_value
        # Use jax.tree_util.tree_map for consistency with tree_leaves above;
        # the top-level jax.tree_map alias is deprecated and removed in
        # recent JAX releases.
        grad = jax.tree_util.tree_map(
            lambda g: jnp.where(should_clip, g * clip_value / global_norm, g),
            grad)
    else:
        raise ValueError('Unexpected value for config.opt.clip_by', clip_by)
    return grad
def diag_gaussian_log_likelihood(z, mean=0.0, logvar=0.0, varmin=1e-16):
    """Log-likelihood under a Gaussian distribution with diagonal covariance.

    Returns the log-likelihood for each dimension.

    Args:
      z: The value to compute the log-likelihood.
      mean: The mean of the Gaussian.
      logvar: The log variance of the Gaussian.
      varmin: Minimum variance allowed (numerically useful).

    Returns:
      The log-likelihood under the Gaussian model.
    """
    # Clamp the variance away from zero for numerical stability.
    logvar_wm = np.log(np.exp(logvar) + varmin)
    # Bug fix: the log-normalizer previously used the raw `logvar` while the
    # exponent used the clamped `logvar_wm`, making the density inconsistent
    # whenever varmin mattered; use the clamped value in both terms.
    return (-0.5 * (logvar_wm + np.log(2 * np.pi) + np.square(
        (z - mean) / (np.exp(0.5 * (logvar_wm))))))
# Example 4
def torus2sphere(ra: jnp.ndarray, ang: jnp.ndarray) -> jnp.ndarray:
    """Map points on a two-dimensional torus to points on a two-dimensional
    sphere.

    Args:
        ra: First radial coordinate.
        ang: Angular coordinate.

    Returns:
        out: Conversion of the inputs on a torus into a sphere.

    """
    ra_col = ra[..., jnp.newaxis]
    # Embed the angle on the unit circle, then shrink the circle component so
    # that the concatenated vector has unit norm.
    circle = pm.sphere.ang2euclid(ang)
    shrunk = jnp.sqrt(1. - jnp.square(ra_col)) * circle
    return jnp.concatenate((shrunk, ra_col), axis=-1)
# Example 5
    def loss(
        params: hk.Params,
        inputs: np.ndarray,
        targets: np.ndarray,
    ) -> jnp.DeviceArray:
        """Compute the loss of the network, including L2."""
        assert targets.dtype == np.int32
        batch_size = inputs.shape[0]
        log_probs = net.apply(params, inputs)

        # Cross-entropy between one-hot targets and predicted log-probs,
        # averaged over the batch.
        xent = -jnp.sum(hk.one_hot(targets, NUM_DIGITS) * log_probs)
        xent = xent / batch_size

        # L2 penalty over every parameter leaf.
        l2_penalty = 0.5 * sum(
            jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))

        return xent + 1e-4 * l2_penalty
# Example 6
        def loss(params: hk.Params,
                 sample: reverb.ReplaySample) -> jnp.ndarray:
            """Entropy-regularised actor-critic loss.

            Combines a V-trace critic loss, an importance-weighted policy
            gradient loss and an entropy regulariser into one scalar.

            Args:
              params: Haiku parameters of the unrolled actor-critic network.
              sample: Replay sample; `sample.data` unpacks into
                (observations, actions, rewards, discounts, extra).

            Returns:
              Scalar mean loss over the trajectory.
            """

            # Extract the data.
            observations, actions, rewards, discounts, extra = sample.data
            # Recurrent core state at the start of the unroll.
            initial_state = tree.map_structure(lambda s: s[0],
                                               extra['core_state'])
            # Logits of the behavior policy that generated the data.
            behaviour_logits = extra['logits']

            # Apply reward clipping.
            rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

            # Unroll current policy over observations.
            (logits, values), _ = unroll_fn.apply(params, observations,
                                                  initial_state)

            # Compute importance sampling weights: current policy / behavior policy.
            rhos = rlax.categorical_importance_sampling_ratios(
                logits[:-1], behaviour_logits[:-1], actions[:-1])

            # Critic loss.
            vtrace_returns = rlax.vtrace_td_error_and_advantage(
                v_tm1=values[:-1],
                v_t=values[1:],
                r_t=rewards[:-1],
                discount_t=discounts[:-1] * discount,
                rho_t=rhos)
            critic_loss = jnp.square(vtrace_returns.errors)

            # Policy gradient loss.
            policy_gradient_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=actions[:-1],
                adv_t=vtrace_returns.pg_advantage,
                w_t=jnp.ones_like(rewards[:-1]))

            # Entropy regulariser.
            entropy_loss = rlax.entropy_loss(logits[:-1],
                                             jnp.ones_like(rewards[:-1]))

            # Combine weighted sum of actor & critic losses.
            mean_loss = jnp.mean(policy_gradient_loss +
                                 baseline_cost * critic_loss +
                                 entropy_cost * entropy_loss)

            return mean_loss
def evaluate_model(model: flax.nn.Model, X: Array, Y: Array) -> float:
  """
  Evaluates model and returns loss.

  Args:
    model: flax.nn.Model.
    X: physics trajectory of shape
      [num_samples, num_integration_steps, {state dimensions}].
    Y: observations, jax.numpy.ndarray of shape
      [num_samples, num_integration_steps, {observation dimensions}].

  Returns:
    Mean-squared loss.
  """
  # Encode the observations and score them against the true trajectory with
  # a mean-squared error.
  encoded = model(Y)
  return jnp.mean(jnp.square(X - encoded))
def kl_gauss_gauss(q_mean, q_logvar, p_mean, p_logvar, varmin=1e-16):
  """Compute the KL divergence between two diagonal Gaussian distributions.

            KL(q||p) = E_q[log q(z) - log p(z))]
  Args:
    q_mean: mean of q
    q_logvar: logvar of q
    p_mean: mean of p
    p_logvar: logvar of p
    varmin: minimum variance allowed, useful for numerical stability
  Returns:
    np array of KL, computed analytically, same size as q_mean
  """
  # Clamp both variances away from zero before applying the closed form.
  safe_q_logvar = np.log(np.exp(q_logvar) + varmin)
  safe_p_logvar = np.log(np.exp(p_logvar) + varmin)
  log_det_term = safe_p_logvar - safe_q_logvar
  trace_term = np.exp(safe_q_logvar - safe_p_logvar)
  mean_term = np.square((q_mean - p_mean) / np.exp(0.5 * safe_p_logvar))
  return 0.5 * (log_det_term + trace_term + mean_term - 1.0)
# Example 9
 def _single_kohn_sham_iteration(carry, inputs):
   """One Kohn-Sham self-consistency step, shaped for `jax.lax.scan`.

   Args:
     carry: Tuple (idx, old_state, alpha, converged, differences) carried
       between iterations.
     inputs: Unused per-step scan input.

   Returns:
     The updated carry tuple plus the state emitted for this step.
   """
   del inputs
   idx, old_state, alpha, converged, differences = carry
   # Once converged, re-emit the previous state instead of iterating again.
   # NOTE(review): `_uncoveraged_...` looks like a typo for "unconverged";
   # the helper is defined elsewhere, so the name is kept as-is here.
   state, differences = jax.lax.cond(
       converged,
       true_operand=(old_state, differences),
       true_fun=_converged_kohn_sham_iteration,
       false_operand=(idx, old_state, alpha, differences),
       false_fun=_uncoveraged_kohn_sham_iteration)
   # Declare convergence when the density change (MSE) drops below tolerance.
   converged = jnp.mean(jnp.square(
       state.density - old_state.density)) < density_mse_converge_tolerance
   # Optionally block gradients through the early iterations.
   state = jax.lax.cond(
       idx <= stop_gradient_step,
       true_fun=jax.lax.stop_gradient,
       false_fun=lambda x: x,
       operand=state)
   return (idx + 1, state, alpha * alpha_decay, converged, differences), state
# Example 10
    def expected_value_delta(params: base.Params, state: CvState) -> float:
        """Expected value of second order expansion of `function` at dist mean.

        Delta-method approximation: E[f(x)] ~ f(mu) + 0.5 * tr(H(mu) Sigma)
        for a diagonal Gaussian with mean `params[0]` and log-scale
        `params[1]`.

        Args:
            params: (mean, log_scale) of the sampling distribution.
            state: Unused control-variate state.

        Returns:
            The second-order approximation of E[function(x)].
        """
        del state
        mean_dist = params[0]
        # params[1] holds the log of the scale, so variance = exp(.)^2.
        var_dist = jnp.square(jnp.exp(params[1]))
        hessians = jax.hessian(function)(mean_dist)

        assert hessians.ndim == 2
        hess_diags = jnp.diag(hessians)
        assert hess_diags.ndim == 1

        # Trace (Hessian * Sigma) and we use that Sigma is diagonal.
        expected_second_order_term = jnp.sum(var_dist * hess_diags) / 2.

        expected_control_variate = function(mean_dist)
        expected_control_variate += expected_second_order_term
        return expected_control_variate
# Example 11
  def data_fidelity(
      self,
      input_data: jnp.ndarray,
      recons: jnp.ndarray,
  ) -> jnp.ndarray:
    """Compute Data fidelity (recons loss) for given input and recons.

    Args:
     input_data: Input batch of shape (batch_size, ...).
     recons: Reconstruction of the input data. An array with the same shape as
       `input_data.data`.
    Returns:
     Computed data fidelity term across batch of data. An array of shape
     `(batch_size,)`.
    """
    batch_size = input_data.shape[0]
    # Flatten everything but the batch axis, then score each example with a
    # squared-error term scaled by the observation variance.
    residual = jnp.reshape(input_data - recons, (batch_size, -1))
    squared_error = jnp.sum(jnp.square(residual), axis=1)
    return -0.5 * squared_error / self._obs_var
    def euclidean_distance(self, x, y, aux=None):
        """Euclidean distance between simulation summaries and target summary

        The target summary gains a leading axis so it broadcasts against the
        whole batch of simulation summaries.

        Parameters
        ----------
        x : (any, n_params)
            Envisioned as a whole batch of summaries of simulations
        y : (n_params)
            Envisioned as a target summary
        aux : None, default=None
            Empty holder so that function works in the same way as `F_distance`
        """
        delta = x - y[np.newaxis]
        return np.sqrt(np.sum(np.square(delta), axis=-1))
# Example 13
 def _split_grad(self, param, state, grad, decay):
   """Split the gradient for the direction and scale."""
   if param.size > param.shape[-1]:
     red_dims = tuple(range(param.ndim-1))
     direction = param / state.mult
     norm = jnp.sqrt(jnp.square(param).sum(red_dims, keepdims=True))
     scale = norm * jnp.sign(state.mult)
     scale_grad = jnp.sum(
         grad * direction, axis=red_dims, keepdims=True)
     direction_grad = state.mult * (grad - scale_grad * direction)
     if decay != 0:
       direction_grad = direction_grad + decay * direction
     direction_info = direction, state.direction_state, direction_grad
     scale_info = scale, state.scale_state, scale_grad
     return direction_info + scale_info
   else:
     return (param, state.direction_state, grad, (), (), ())
# Example 14
 def loss_fn(rng, params, states, batch):
     """Denoising score-matching loss for the VE SDE (SMLD-style).

     Args:
         rng: PRNG key.
         params: Model parameters.
         states: Mutable model states (e.g. batch statistics).
         batch: Dict with an 'image' entry.

     Returns:
         Tuple of (scalar loss, new model states).
     """
     model_fn = mutils.get_model_fn(model, params, states, train=train)
     data = batch['image']
     rng, step_rng = random.split(rng)
     # Sample one noise level index per example.
     labels = random.choice(step_rng, vesde.N, shape=(data.shape[0], ))
     sigmas = smld_sigma_array[labels]
     rng, step_rng = random.split(rng)
     noise = batch_mul(random.normal(step_rng, data.shape), sigmas)
     perturbed_data = noise + data
     rng, step_rng = random.split(rng)
     score, new_model_state = model_fn(perturbed_data, labels, rng=step_rng)
     # Score of the Gaussian perturbation kernel: -noise / sigma^2.
     target = -batch_mul(noise, 1. / (sigmas**2))
     losses = jnp.square(score - target)
     # Reduce per-example and weight by sigma^2.
     losses = reduce_op(losses.reshape(
         (losses.shape[0], -1)), axis=-1) * sigmas**2
     loss = jnp.mean(losses)
     return loss, new_model_state
# Example 15
def self_interaction_weight(reshaped_density, dx, width):
  """Gets self-interaction weight.

  When the density integral is one, the self-interaction weight is 1. The weight
  goes to zero when the density integral deviates from one.

  Args:
    reshaped_density: Float numpy array with any shape. The total size should be
       num_grids.
    dx: Float, grid spacing.
    width: Float, the width of the Gaussian function.

  Returns:
    Float, the self-interaction weight.
  """
  # Riemann-sum approximation of the density integral on the grid.
  integral = jnp.sum(reshaped_density) * dx
  deviation = (integral - 1) / width
  return jnp.exp(-jnp.square(deviation))
# Example 16
    def loss(self, x, targets, z_loss=1):
        """Cross-entropy loss with an auxiliary z-loss on the log-normalizer.

        Args:
            x: Hidden states; normalized and projected to logits.
            targets: Integer class ids.
            z_loss: Weight of the auxiliary log(Z)^2 term that discourages
                the softmax normalizer from drifting (training stabilizer).

        Returns:
            Tuple of (per-example loss with the mean z-loss term added,
            boolean `correct` mask).
        """
        x = self.norm(x)
        logits = self.proj(x)

        # Subtract the per-row max for numerical stability.
        logits -= logits.max(-1, keepdims=True)

        gt_onehot = jax.nn.one_hot(targets, self.dim)
        predicted_logits = jnp.sum(jnp.multiply(gt_onehot, logits), axis=-1)
        exp_logits = jnp.exp(logits)

        sum_exp_logits = exp_logits.sum(axis=-1)

        # Standard softmax cross-entropy: log Z - logit of the true class.
        loss = jnp.log(sum_exp_logits) - predicted_logits

        loss += (1e-4 * jnp.square(jnp.log(sum_exp_logits)) * z_loss).mean()
        # After max-subtraction the true class is the argmax iff its shifted
        # logit equals zero.
        correct = (0.0 == predicted_logits)
        return loss, correct
# Example 17
    def apply(self, y, vgp: VariationalGaussianProcess):
        """Negative expected Gaussian log-likelihood of `y` under q(u).

        Args:
            y: Observed targets.
            vgp: Variational GP whose `marginal()` provides the mean and
                scale of q(u).

        Returns:
            Scalar: minus the expected log-likelihood, suitable for
            minimization.
        """
        # Learnable observation noise, kept positive via softplus.
        obs_noise_scale = jax.nn.softplus(
            self.param('observation_noise_scale', (),
                       jax.nn.initializers.ones))

        variational_distribution = vgp.marginal()
        qu_mean = variational_distribution.mean
        qu_scale = variational_distribution.scale

        # Expected value of iid gaussians under q(u)
        expected_gll_under_qu = -.5 * jnp.squeeze(
            (jnp.sum(jnp.square(qu_mean - y)) +
             jnp.trace(qu_scale @ qu_scale.T)) / obs_noise_scale**2 +
            y.shape[-1] * jnp.log(obs_noise_scale**2) + jnp.log(2 * jnp.pi))

        # flip sign to minimize the elbo
        return -expected_gll_under_qu
# Example 18
 def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff):
     """PPO clipped-surrogate loss with value and entropy terms."""
     states, actions, old_log_probs, returns, advantages = minibatch
     log_probs, values = model(states)
     values = values[:, 0]  # Convert shapes: (batch, 1) to (batch, ).

     # Policy entropy, averaged over the batch.
     probs = jnp.exp(log_probs)
     entropy = jnp.mean(jnp.sum(-probs * log_probs, axis=1))

     # Probability ratio between current and behavior policies for the
     # actions actually taken.
     taken_log_probs = jax.vmap(lambda lp, a: lp[a])(log_probs, actions)
     ratios = jnp.exp(taken_log_probs - old_log_probs)

     # Advantage normalization (following the OpenAI baselines).
     advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                      1e-8)

     surrogate = ratios * advantages
     clipped = advantages * jax.lax.clamp(1. - clip_param, ratios,
                                          1. + clip_param)
     ppo_term = -jnp.mean(jnp.minimum(surrogate, clipped), axis=0)
     value_term = jnp.mean(jnp.square(returns - values), axis=0)
     return ppo_term + vf_coeff * value_term - entropy_coeff * entropy
# Example 19
def _linear_regression_gibbs_fn(X, XX, XY, Y, rng_key, gibbs_sites, hmc_sites):
    """Gibbs update that resamples the regression coefficients `beta`.

    Draws beta from its conditional Gaussian given the current noise scale
    (read from the HMC sites), using precomputed sufficient statistics.

    Args:
        X: Design matrix of shape (N, P).
        XX: Precomputed X^T X, shape (P, P).
        XY: Precomputed X^T Y, shape (P,).
        Y: Targets (unused here; the statistics above are precomputed).
        rng_key: PRNG key for the proposal draw.
        gibbs_sites: Current Gibbs sites (unused).
        hmc_sites: HMC sites providing 'sigma' or 'log_sigma'.

    Returns:
        Dict with the new 'beta' sample.
    """
    N, P = X.shape

    # The noise scale may be parameterized on the log scale by the HMC kernel.
    sigma = jnp.exp(hmc_sites['log_sigma']
                    ) if 'log_sigma' in hmc_sites else hmc_sites['sigma']

    sigma_sq = jnp.square(sigma)
    # Conditional precision of beta; the identity term corresponds to a
    # standard-normal prior on the coefficients.
    covar_inv = XX / sigma_sq + jnp.eye(P)

    L = cho_factor(covar_inv, lower=True)[0]
    L_inv = solve_triangular(L, jnp.eye(P), lower=True)
    loc = cho_solve((L, True), XY) / sigma_sq

    beta_proposal = dist.MultivariateNormal(loc=loc,
                                            scale_tril=L_inv).sample(rng_key)

    return {'beta': beta_proposal}
# Example 20
def mobius_spline_log_prob(ra: jnp.ndarray, raunif: jnp.ndarray,
                           rb: jnp.ndarray, rbunif: jnp.ndarray,
                           ang: jnp.ndarray, angunif: jnp.ndarray,
                           w: jnp.ndarray, xk: jnp.ndarray, yk: jnp.ndarray,
                           deltak: jnp.ndarray, xl: jnp.ndarray,
                           yl: jnp.ndarray, deltal: jnp.ndarray):
    """Compute the log-density of the Mobius spline transformation on the sphere.

    Sums the change-of-variables log-densities of the two rational-quadratic
    spline radial factors and the Mobius-transformed angle, plus the
    log-determinant of the sphere-embedding Jacobian.

    Args:
        ra, raunif: First radial coordinate and its pre-transform value.
        rb, rbunif: Second radial coordinate and its pre-transform value.
        ang, angunif: Angular coordinate and its pre-transform value.
        w: Mobius transform parameters.
        xk, yk, deltak: Spline knots and derivatives for the first radius.
        xl, yl, deltal: Spline knots and derivatives for the second radius.

    Returns:
        Log-probability of the transformed sample.
    """
    # Change of variables through the spline: subtract the log-derivative.
    # The -log(2) presumably accounts for a base measure of width 2 — TODO
    # confirm against the sampler that produces `raunif`.
    lpra = -jnp.log(2.) - jnp.log(
        grad_rational_quadratic(raunif, xk, yk, deltak))
    lprb = -jnp.log(2.) - jnp.log(
        vmap(grad_rational_quadratic)(rbunif, xl, yl, deltal))
    lpang = vmap(mobius_log_prob)(angunif, w)
    # Jacobian correction for embedding the coordinates onto the sphere.
    ldj = -(3. / 2 - 1.) * jnp.log(1 - jnp.square(ra))
    log_prob = lpra + lprb + lpang + ldj
    return log_prob
# Example 21
def calc_gp_prior(X, K, n_lats, Fourier=False):
    """Calculate the GP log prior (time domain).

    x_samples should be nsamples by T; K is T by T.

    Args:
        X: Latent samples, indexed by latent on the first axis.
        K: Prior variances per latent (diagonal, Fourier-domain form).
        n_lats: Number of latents to sum over.
        Fourier: When True, accumulate the diagonal Gaussian log prior per
            latent; otherwise return 0 (the dense time-domain branch is not
            implemented).

    Returns:
        Scalar total log prior.
    """
    total_prior = 0
    if Fourier:
        for lat in np.arange(n_lats):
            x_lat = X[lat]
            # Diagonal Gaussian log-density: quadratic form + log normalizer.
            quad_form = np.sum(np.square(x_lat) / K[lat])
            log_norm = np.sum(np.log(2 * np.pi * K[lat]))
            total_prior = total_prior - 0.5 * (quad_form + log_norm)
    return total_prior
# Example 22
 def loss_fn(x, excite_port_idx=0):
     """Mean power transmission through the device for one excitation port.

     Args:
         x: Design variables for the design regions.
         excite_port_idx: Index of the port carrying the source (0 or 1).

     Returns:
         Mean of |t|^2 over the simulated frequencies.
     """
     wrapped_meep = mpa.MeepJaxWrapper(
         simulation,
         [sources[excite_port_idx]],
         monitors,
         design_regions,
         frequencies,
     )
     monitor_values = wrapped_meep([x])
     # Presumably s{1,2}{p,m} are port mode coefficients (port index,
     # forward/backward) — confirm against the monitor definitions.
     s1p, s1m, s2m, s2p = monitor_values
     if excite_port_idx == 0:
         t = s2m / s1p
     else:
         t = s1m / s2p
     # Mean transmission vs wavelength
     t_mean = jnp.mean(jnp.square(jnp.abs(t)))
     return t_mean
# Example 23
 def loss_fn(rng, params, states, batch):
     """Denoising loss for the VP SDE (DDPM-style epsilon prediction).

     Args:
         rng: PRNG key.
         params: Model parameters.
         states: Mutable model states (e.g. batch statistics).
         batch: Dict with an 'image' entry.

     Returns:
         Tuple of (scalar loss, new model states).
     """
     model_fn = mutils.get_model_fn(model, params, states, train=train)
     data = batch['image']
     rng, step_rng = random.split(rng)
     # Sample a random diffusion timestep per example.
     labels = random.choice(step_rng, vpsde.N, shape=(data.shape[0], ))
     sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod
     sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod
     rng, step_rng = random.split(rng)
     noise = random.normal(step_rng, data.shape)
     # Forward diffusion: mix clean data with Gaussian noise at timestep t.
     perturbed_data = batch_mul(sqrt_alphas_cumprod[labels], data) + \
                      batch_mul(sqrt_1m_alphas_cumprod[labels], noise)
     rng, step_rng = random.split(rng)
     score, new_model_state = model_fn(perturbed_data, labels, rng=step_rng)
     # The model is trained to predict the added noise.
     losses = jnp.square(score - noise)
     losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1)
     loss = jnp.mean(losses)
     return loss, new_model_state
# Example 24
def compute_stats(inputs: JTensor,
                  padding: Optional[JTensor] = None) -> NestedMap:
  """Computes various stats over the valid data points in inputs."""
  # Accumulate everything in fp32 regardless of the input dtype.
  inputs = inputs.astype(jnp.float32)
  if padding is None:
    padding = jnp.zeros_like(inputs)
  assert inputs.ndim == padding.ndim
  mask = 1.0 - padding

  # Masked count and sum; maximum(1, .) guards an all-padded input.
  count_v = jnp.sum(jnp.ones_like(inputs) * mask)
  sum_v = jnp.sum(inputs * mask)
  mean_v = sum_v / jnp.maximum(1.0, count_v)
  centered_sq_sum = jnp.sum(jnp.square((inputs - mean_v) * mask))
  std_v = jnp.sqrt(centered_sq_sum / jnp.maximum(1.0, count_v))
  max_v = jnp.max(jnp.abs(inputs * mask))

  return NestedMap(mean_v=mean_v, std_v=std_v, max_v=max_v)
# Example 25
    def cluster_dist_metric(point, metric_state: MetricState):
        """Distance from `point` to every cluster center under `method`.

        Args:
            point: A single point to score against all cluster centers.
            metric_state: Cluster centers plus per-cluster metric data
                (C, radii, num_k).

        Returns:
            Per-cluster distances. NOTE(review): implicitly returns None for
            an unrecognized `method` — confirm callers never pass one.
        """
        if method == 'euclidean':
            # Squared Euclidean distance to each center.
            dist = metric_state.cluster_centers - point
            return jnp.sum(jnp.square(dist), axis=-1)
        if method == 'mahalanobis':
            dist = metric_state.cluster_centers - point
            maha = vmap(lambda dist, C: dist @ C @ dist)(dist, metric_state.C)
            return maha
        if method == 'ellipsoid':
            dist = metric_state.cluster_centers - point
            # Mahalanobis distance reweighted by each cluster's ellipsoid
            # volume relative to its share of the sampled volume.
            weighted_maha = vmap(
                lambda dist, C, radii, num_k: (dist @ C @ dist) * jnp.exp(
                    log_ellipsoid_volume(radii) - jnp.log(num_k) + jnp.log(
                        num_S) - log_VS))(dist, metric_state.C,
                                          metric_state.radii,
                                          metric_state.num_k)

            return weighted_maha
# Example 26
 def optimize(state,
              grad,
              warmup=config.optim.warmup,
              grad_clip=config.optim.grad_clip):
   """Optimizes with warmup and gradient clipping (disabled if negative).

   Args:
     state: Training state holding `lr`, `step`, and the `optimizer`.
     grad: Gradient pytree.
     warmup: Number of linear-warmup steps (<= 0 disables warmup).
     grad_clip: Global-norm clip threshold (< 0 disables clipping).

   Returns:
     The optimizer state after applying the (possibly clipped) gradient.
   """
   lr = state.lr
   if warmup > 0:
     # Linear learning-rate warmup over the first `warmup` steps.
     lr = lr * jnp.minimum(state.step / warmup, 1.0)
   if grad_clip >= 0:
     # Compute global gradient norm
     # NOTE(review): jax.tree_leaves / jax.tree_map are deprecated aliases of
     # the jax.tree_util versions in newer JAX releases.
     grad_norm = jnp.sqrt(
         sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(grad)]))
     # Clip gradient
     clipped_grad = jax.tree_map(
         lambda x: x * grad_clip / jnp.maximum(grad_norm, grad_clip), grad)
   else:  # disabling gradient clipping if grad_clip < 0
     clipped_grad = grad
   return state.optimizer.apply_gradient(clipped_grad, learning_rate=lr)
def kl_gauss_gauss(z_mean, z_logvar, prior_params):
  """Compute the KL divergence between two diagonal Gaussian distributions.

            KL(q||p) = E_q[log q(z) - log p(z))]

  Args:
    z_mean: mean of posterior, z ~ q(z|x)
    z_logvar: logvar of posterior
    prior_params: dict with keys 'mean' and 'logvar' describing the
      prior p(z)

  Returns:
    np array of KL, computed analytically, same size as z_mean
  """
  p_mean = prior_params['mean']
  p_logvar = prior_params['logvar']
  log_det_term = p_logvar - z_logvar
  trace_term = np.exp(z_logvar - p_logvar)
  mean_term = np.square((z_mean - p_mean) / np.exp(0.5 * p_logvar))
  return 0.5 * (log_det_term + trace_term + mean_term - 1.0)
# Example 28
def loss_fn(params: flax.core.frozen_dict.FrozenDict,
            module: models.ActorCritic, minibatch: Tuple, clip_param: float,
            vf_coeff: float, entropy_coeff: float):
    """Evaluate the loss function.

  Compute loss as a sum of three components: the negative of the PPO clipped
  surrogate objective, the value function loss and the negative of the entropy
  bonus.

  Args:
    params: the parameters of the actor-critic model
    module: the actor-critic model
    minibatch: Tuple of five elements forming one experience batch:
               states: shape (batch_size, 84, 84, 4)
               actions: shape (batch_size,)
               old_log_probs: shape (batch_size,)
               returns: shape (batch_size,)
               advantages: shape (batch_size,)
    clip_param: the PPO clipping parameter used to clamp ratios in loss function
    vf_coeff: weighs value function loss in total loss
    entropy_coeff: weighs entropy bonus in the total loss

  Returns:
    loss: the PPO loss, scalar quantity
  """
    states, actions, old_log_probs, returns, advantages = minibatch
    log_probs, values = agent.policy_action(params, module, states)
    values = values[:, 0]  # Convert shapes: (batch, 1) to (batch, ).
    probs = jnp.exp(log_probs)

    value_loss = jnp.mean(jnp.square(returns - values), axis=0)

    # Policy entropy averaged over the batch (the entropy bonus).
    entropy = jnp.sum(-probs * log_probs, axis=1).mean()

    # Log-probability of the action actually taken in each state.
    log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions)
    ratios = jnp.exp(log_probs_act_taken - old_log_probs)
    # Advantage normalization (following the OpenAI baselines).
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    PG_loss = ratios * advantages
    clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios,
                                              1. + clip_param)
    # Pessimistic (min) combination of the clipped and unclipped surrogates.
    PPO_loss = -jnp.mean(jnp.minimum(PG_loss, clipped_loss), axis=0)

    return PPO_loss + vf_coeff * value_loss - entropy_coeff * entropy
# Example 29
def diffusion_reverse(*, x, z_t, logsnr_s, logsnr_t, x_logvar):
    """q(z_s | z_t, x) (requires logsnr_s > logsnr_t (i.e. s < t)).

    Computes the Gaussian reverse-diffusion posterior. `x_logvar` selects
    how uncertainty about x enters the variance: 'small', 'large',
    'medium:<frac>' (interpolation in log-variance), or an explicit
    log-variance array matching `x`.

    Returns:
        Dict with 'mean', 'std', 'var' and 'logvar' of q(z_s | z_t, x).
    """
    alpha_st = jnp.sqrt((1. + jnp.exp(-logsnr_t)) / (1. + jnp.exp(-logsnr_s)))
    alpha_s = jnp.sqrt(nn.sigmoid(logsnr_s))
    r = jnp.exp(logsnr_t - logsnr_s)  # SNR(t)/SNR(s)
    one_minus_r = -jnp.expm1(logsnr_t - logsnr_s)  # 1-SNR(t)/SNR(s)
    log_one_minus_r = utils.log1mexp(logsnr_s -
                                     logsnr_t)  # log(1-SNR(t)/SNR(s))

    mean = r * alpha_st * z_t + one_minus_r * alpha_s * x

    if isinstance(x_logvar, str):
        if x_logvar == 'small':
            # same as setting x_logvar to -infinity
            var = one_minus_r * nn.sigmoid(-logsnr_s)
            logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
        elif x_logvar == 'large':
            # same as setting x_logvar to nn.log_sigmoid(-logsnr_t)
            var = one_minus_r * nn.sigmoid(-logsnr_t)
            logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_t)
        elif x_logvar.startswith('medium:'):
            # Interpolate in log-variance between the 'small' and 'large'
            # endpoints by the given fraction.
            _, frac = x_logvar.split(':')
            frac = float(frac)
            logging.info('logvar frac=%f', frac)
            assert 0 <= frac <= 1
            min_logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
            max_logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_t)
            logvar = frac * max_logvar + (1 - frac) * min_logvar
            var = jnp.exp(logvar)
        else:
            raise NotImplementedError(x_logvar)
    else:
        assert isinstance(x_logvar, jnp.ndarray) or isinstance(
            x_logvar, onp.ndarray)
        assert x_logvar.shape == x.shape
        # start with "small" variance
        var = one_minus_r * nn.sigmoid(-logsnr_s)
        logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
        # extra variance weight is (one_minus_r*alpha_s)**2
        var += jnp.square(one_minus_r) * nn.sigmoid(logsnr_s) * jnp.exp(
            x_logvar)
        # Same addition performed stably in log space.
        logvar = jnp.logaddexp(
            logvar, 2. * log_one_minus_r + nn.log_sigmoid(logsnr_s) + x_logvar)
    return {'mean': mean, 'std': jnp.sqrt(var), 'var': var, 'logvar': logvar}
# Example 30
    def loss_fn(agent_params):
        """V-trace actor-critic loss with reconstruction regularizers.

        Bug fixes relative to the original:
          * `pg_advantage` was referenced but never defined (the vtrace
            unpacking bound `pg_advantages`).
          * the vision reconstruction loss compared against the language
            reconstruction instead of `reconstructed_vision`.
          * the function returned an undefined name `loss` instead of the
            computed `regularized_loss`.

        Args:
            agent_params: Parameters of the fast-slow agent model.

        Returns:
            Scalar regularized loss.
        """
        (
            logits,
            v,
            reconstructed_lang,
            reconstructed_vision,
            *_,
        ) = apply_fast_slow_agent_model(
            agent_params,
            num_embeddings,
            embedding_dim,
            language_state,
            vision_state,
            h_prev,
            decoder_h_prev,
            False,
        )

        rhos = categorical_importance_sampling_ratios(logits[:-1],
                                                      behavior_logits[:-1],
                                                      actions[:-1])

        errors, pg_advantages, q_estimate = vtrace()
        critic_loss = jnp.square(errors)

        log_pi_a = jnp.take_along_axis(jax.nn.log_softmax(logits[:-1]),
                                       actions[:-1],
                                       axis=-1)

        # Stop gradients through the advantages so the policy-gradient term
        # only updates the policy head.
        pg_advantages = jax.lax.stop_gradient(pg_advantages)
        pg_loss = jnp.mean(-log_pi_a * pg_advantages)

        # Entropy regularizer (cross-entropy of the policy with itself).
        entropy_loss = cross_entropy_loss_fn(logits, logits)

        language_reconstruction_loss = cross_entropy_loss_fn(
            language_state, reconstructed_lang)
        vision_reconstruction_loss = cross_entropy_loss_fn(
            vision_state, reconstructed_vision)

        regularized_loss = jnp.mean(
            policy_eps * pg_loss + baseline_eps * critic_loss +
            entropy_eps * entropy_loss + reconstruction_eps *
            (language_reconstruction_loss + vision_reconstruction_loss))
        return regularized_loss