Exemple #1
0
def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
    """Return the log-density of a Gaussian with a full covariance matrix.

    Evaluates
    -0.5 * (d * log(2 * pi) + log|det(sigma)| + (x - mu)^T sigma^-1 (x - mu)),
    where d is the size of the last axis of `mu`.

    Args:
      x: point(s) at which the density is evaluated; last axis has size d.
      mu: mean vector(s); last axis has size d.
      sigma: covariance matrix (or matrices); the last two axes are d x d.

    Returns:
      The log-density, with the trailing vector axis reduced away.
    """
    diff = x - mu
    norm_term = mu.shape[-1] * jnp.log(2 * jnp.pi)
    # Only the log of the absolute determinant is used; the sign returned by
    # slogdet is dropped (a valid covariance has a positive determinant).
    log_abs_det = jnp.linalg.slogdet(sigma)[1]
    # sigma^-1 (x - mu), obtained from a linear solve rather than an inverse.
    solved = jnp.linalg.solve(sigma, diff)
    # Row vector times column vector yields a (..., 1, 1) result; indexing
    # with [..., 0, 0] removes the two singleton axes.
    quad_form = jnp.matmul(diff[..., None, :], solved[..., :, None])[..., 0, 0]
    return -0.5 * (norm_term + log_abs_det + quad_form)
Exemple #2
0
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
    """Compute log N(x | mu, diag(diag_sigma)).

    Evaluates
    -0.5 * (d * log(2 * pi) + sum(log(diag_sigma))
            + sum((x - mu)**2 / diag_sigma)),
    where d is the size of the last axis of `mu` and all sums run over the
    last axis.

    Args:
      x: point(s) at which the density is evaluated; last axis has size d.
      mu: mean vector(s); last axis has size d.
      diag_sigma: diagonal entries of the covariance matrix (per-dimension
        variances); last axis has size d.

    Returns:
      The log-density, with the last axis reduced away.
    """
    diff = x - mu
    a = mu.shape[-1] * jnp.log(2 * jnp.pi)
    # Log-determinant of a diagonal matrix is the sum of the log diagonal.
    b = jnp.sum(jnp.log(diag_sigma), axis=-1)
    # Mahalanobis term. The difference must be formed *before* dividing by
    # the variances: `x - mu / diag_sigma` would only rescale `mu`.
    c = jnp.sum(diff * (diff / diag_sigma), axis=-1)
    return -0.5 * (a + b + c)
Exemple #3
0
    def update(self, step, grads, weights, slots, opt_params):
        """Compute updated weights and optimizer slots for one weight tensor.

        The update keeps a running estimate of the squared gradient (either
        in full, or factored into per-row and per-column means for tensors
        of rank >= 2), rescales the gradient by its inverse square root,
        optionally clips the result and applies momentum, then subtracts it
        from the (weight-decayed) weights. NOTE(review): this appears to be
        the Adafactor update rule -- confirm against the enclosing class.

        Args:
          step: current training step; forwarded to `self._decay_rate_pow`.
          grads: gradient for `weights`.
          weights: current value of the weight tensor.
          slots: sequence of optimizer slots for this tensor, read in order:
            `(v_row, v_col)` when the factored path is taken, otherwise
            `(v,)`; followed by the momentum slot `m` when momentum is
            enabled. This sequence is not modified.
          opt_params: mapping with keys 'learning_rate', 'beta1',
            'decay_rate', 'clipping_threshold', 'weight_decay_rate',
            'epsilon1' and 'epsilon2'. This mapping and its values are not
            modified.

        Returns:
          A pair `(new_weights, updates)`: the new weights cast to the dtype
          of `weights`, and the list of new slot values in the same order
          as they were read from `slots`.
        """
        learning_rate = opt_params['learning_rate']
        beta1 = opt_params['beta1']
        decay_rate = opt_params['decay_rate']
        clipping_threshold = opt_params['clipping_threshold']
        weight_decay_rate = opt_params['weight_decay_rate']
        epsilon1 = opt_params['epsilon1']
        epsilon2 = opt_params['epsilon2']

        # The configured 'decay_rate' is used as an exponent to derive the
        # step-dependent decay actually applied to the running averages.
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        mixing_rate = 1.0 - decay_rate

        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            # Out-of-place multiply: `*=` would modify the array stored in
            # `opt_params` when the learning rate is a mutable ndarray.
            update_scale = update_scale * np.maximum(
                np.sqrt(np.mean(weights * weights)), epsilon2)

        updates = []
        grads_sqr = grads * grads + epsilon1
        # Slots are read by index (instead of `slots.pop(0)`) so the
        # caller's sequence is left untouched.
        if self._factored and len(weights.shape) >= 2:
            v_row, v_col = slots[0], slots[1]
            momentum_index = 2
            new_v_row = (decay_rate * v_row +
                         mixing_rate * np.mean(grads_sqr, axis=-1))
            new_v_col = (decay_rate * v_col +
                         mixing_rate * np.mean(grads_sqr, axis=-2))
            updates.extend([new_v_row, new_v_col])
            row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (new_v_row / row_col_mean)**-0.5
            col_factor = (new_v_col)**-0.5
            y = (grads * np.expand_dims(row_factor, axis=-1) *
                 np.expand_dims(col_factor, axis=-2))
        else:
            v = slots[0]
            momentum_index = 1
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v)**-0.5

        if self._do_clipping:
            # Scale down only when the RMS of the update exceeds the
            # threshold; the denominator is never below 1.
            clipping_denom = np.maximum(
                1.0, np.sqrt(np.mean(y * y)) / clipping_threshold)
            y = y / clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots[momentum_index]
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_weights = (1 - weight_decay_rate) * weights - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_weights.astype(weights.dtype), updates
Exemple #4
0
 def forward_with_state(self,
                        inputs,
                        weights=base.EMPTY_WEIGHTS,
                        state=base.EMPTY_STATE,
                        rng=None,
                        **kwargs):
     """Add positional weights to `inputs`, returning `(outputs, state)`.

     In 'train' and 'eval' modes the first `inputs.shape[1]` positions of
     `weights` are added to the inputs, optionally multiplied by a dropout
     mask, and `state` is returned unchanged. In 'predict' mode the single
     position indexed by `state` is added and `state + 1` is returned.

     Args:
       inputs: input array; axis 1 is used as the position axis.
       weights: positional weights, indexed as `weights[:, position, :]`.
       state: in 'predict' mode, the index of the current position.
       rng: random key used to sample the dropout mask.
       **kwargs: accepted but not used by this method.

     Returns:
       A pair `(outputs, new_state)`.
     """
     if self._mode not in ('train', 'eval'):
         assert self._mode == 'predict'
         assert self._dropout == 0
         # State is only used for fast inference, where the model is called
         # position-by-position on consecutive elements. It stores the index
         # of the current position and is advanced by one on every call.
         current = np.expand_dims(weights[0, state, :], 1)
         return (inputs + current, state + 1)

     seq_len = np.shape(inputs)[1]
     positional = weights[:, :seq_len, :]
     if self._dropout == 0:
         return (inputs + positional, state)

     # The mask has size 1 along every broadcast dimension so that a single
     # draw is shared across that dimension.
     mask_shape = list(positional.shape)
     for axis in self._dropout_broadcast_dims:
         mask_shape[axis] = 1
     keep_prob = 1.0 - self._dropout
     if math.backend_name() == 'jax':
         keep_prob = jax.lax.tie_in(
             inputs, np.full((), keep_prob, dtype=inputs.dtype))
     keep = math.random.bernoulli(rng, keep_prob, tuple(mask_shape))
     # Dividing by keep_prob rescales the kept entries.
     scale = keep.astype(inputs.dtype) / keep_prob
     return (inputs + positional * scale, state)
Exemple #5
0
def combined_loss(new_weights,
                  observations,
                  actions,
                  target_values,
                  advantage_weights,
                  policy_and_value_net_apply,
                  state=None,
                  rng=None):
  """Evaluates the critic loss, the actor loss and the entropy.

  Args:
    new_weights: weights passed to `policy_and_value_net_apply`.
    observations: batch of observations; a singleton axis is inserted at
      position 1 before the network is applied.
    actions: actions forwarded to `actor_loss`.
    target_values: value targets forwarded to `critic_loss`.
    advantage_weights: advantage weights forwarded to `actor_loss`.
    policy_and_value_net_apply: callable invoked as
      `fn(observations, weights=..., state=..., rng=...)` that returns a pair
      `(log_probab_actions, value_predictions)`.
    state: state passed to the network and to `critic_loss`.
    rng: random key passed to the network.

  Returns:
    A tuple `(critic_loss_val, actor_loss_val, entropy_val, final_state)`.
  """
  # `policy_and_value_net_apply` expects inputs shaped
  # (batch, 1, *obs_shape), hence the extra axis.
  observations = np.expand_dims(observations, axis=1)

  log_probs, value_preds = policy_and_value_net_apply(
      observations, weights=new_weights, state=state, rng=rng)

  # The state is threaded through the critic loss and then the actor loss.
  critic_loss_val, state_after_critic = critic_loss(
      observations, target_values, value_preds, state=state)
  actor_loss_val, final_state = actor_loss(
      actions, advantage_weights, log_probs, state=state_after_critic)

  return critic_loss_val, actor_loss_val, entropy(log_probs), final_state
Exemple #6
0
  def forward_with_state(self, inputs, weights, state, rng):
    """Adds positional weights to `inputs`, returning `(outputs, state)`.

    Outside 'predict' mode the first `inputs.shape[1]` positions of
    `weights` are added to the inputs, optionally multiplied by a dropout
    mask, and `state` is returned unchanged. In 'predict' mode the positions
    starting at `state` are added and `state` is advanced by the number of
    positions consumed.

    Args:
      inputs: input array; axis 0 is iterated per example and axis 1 is the
        position axis.
      weights: positional weights, indexed as `weights[:, position, :]`.
      state: in 'predict' mode, the index of the current position (indexed
        per example as `state[i]` when more than one position is fed).
      rng: random key used to sample the dropout mask.

    Returns:
      A pair `(outputs, new_state)`.

    Raises:
      ValueError: if the mode is 'predict' and the dropout rate is non-zero.
    """
    if self._mode == 'predict':
      if self._dropout != 0:
        raise ValueError(f'In predict mode, but dropout rate '
                         f'({self._dropout}) is not zero.')

      # State is only used for fast inference, where the model is called on
      # consecutive elements. It stores the index of the current position
      # and is advanced on every call.
      num_positions = inputs.shape[1]
      if num_positions == 1:
        return (inputs + jnp.expand_dims(weights[0, state, :], 1), state + 1)
      # Several positions at once: slice a window of `num_positions` rows
      # per example, starting at that example's current position.
      windows = [
          jax.lax.dynamic_slice_in_dim(
              weights[0], state[row], num_positions, axis=0)
          for row in range(inputs.shape[0])
      ]
      return inputs + jnp.stack(windows, 0), state + num_positions

    seq_len = jnp.shape(inputs)[1]
    positional = weights[:, :seq_len, :]
    if self._dropout == 0:
      return (inputs + positional, state)

    # The mask has size 1 along every broadcast dimension so that a single
    # draw is shared across that dimension.
    mask_shape = list(positional.shape)
    for axis in self._dropout_broadcast_dims:
      mask_shape[axis] = 1
    keep_prob = 1.0 - self._dropout
    if math.backend_name() == 'jax':
      keep_prob = jax.lax.tie_in(
          inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
    keep = math.random.bernoulli(rng, keep_prob, tuple(mask_shape))
    # Dividing by keep_prob rescales the kept entries.
    scale = keep.astype(inputs.dtype) / keep_prob
    return (inputs + positional * scale, state)
Exemple #7
0
def combined_loss(new_weights,
                  observations,
                  actions,
                  target_values,
                  advantage_weights,
                  policy_and_value_net_apply,
                  action_space,
                  state=None,
                  rng=None):
    """Evaluates the critic loss, the actor loss and the entropy.

    Args:
      new_weights: weights passed to the policy-and-value network.
      observations: batch of observations; a singleton axis is inserted at
        position 1 before the network is run.
      actions: actions forwarded to `actor_loss`.
      target_values: value targets forwarded to `critic_loss`.
      advantage_weights: advantage weights forwarded to `actor_loss`.
      policy_and_value_net_apply: network apply function handed to
        `policy_based_utils.run_policy_all_timesteps`.
      action_space: action space handed to
        `policy_based_utils.run_policy_all_timesteps`.
      state: network state handed to the policy run.
      rng: random key handed to the policy run.

    Returns:
      A tuple `(critic_loss_val, actor_loss_val, entropy_val, final_state)`.
    """
    # TODO(afrozm): Eventually the observations preceding this one should be
    #  fed in as well. The replay buffer currently yields each observation on
    #  its own; for transformer-like policies the earlier observations are
    #  needed too, and the extra dimension would then be time. For now the
    #  input is reshaped as (batch, 1, *obs_shape).
    observations = jnp.expand_dims(observations, axis=1)

    # The run returns four values; the new RNG key is not used here.
    log_probs, value_preds, policy_state, _ = (
        policy_based_utils.run_policy_all_timesteps(
            policy_and_value_net_apply,
            observations,
            new_weights,
            state,
            rng,
            action_space,
        ))

    # The state from the policy run is threaded through the critic loss and
    # then the actor loss.
    critic_loss_val, state_after_critic = critic_loss(
        observations, target_values, value_preds, state=policy_state)
    actor_loss_val, final_state = actor_loss(
        actions, advantage_weights, log_probs, state=state_after_critic)

    return critic_loss_val, actor_loss_val, entropy(log_probs), final_state