Example #1
 def ClippedObjectiveMean(
     dist_inputs, values, returns, actions, old_log_probs):
   """Clipped objective from the PPO algorithm."""
   advantages = returns - values
   probs_ratio = rl_layers.ProbsRatio(
       dist_inputs, actions, old_log_probs,
       log_prob_fun=self._policy_dist.log_prob)
   clipped_objective = rl_layers.ClippedObjective(
       probs_ratio, advantages, epsilon=self._epsilon)
   return jnp.mean(clipped_objective)
Example #2
 def f(preds, values, returns, actions, mask):
     advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
     logps = self._policy_dist.log_prob(preds, actions)
     awr_loss = actor_critic.AWRLoss(beta=self._beta,
                                     w_max=self._w_max)(
                                         (logps, advantages,
                                          jnp.zeros_like(logps), mask))
     l2_value_loss = jnp.mean(
         (returns - values)**2) * self._value_loss_coeff
     return awr_loss + l2_value_loss
Example #3
File: metrics.py Project: zhaoqiuye/trax
 def f(values, weights):  # pylint: disable=invalid-name
     # This function assumes weights are 0 or 1.
     # not_correct is 1 for a not-correct token and 0 for a correct or masked one.
     not_correct = (1.0 - values) * weights
     axis_to_sum = list(range(1, len(not_correct.shape)))
     # Summing not-correct on all axes but batch. We're summing 0s and 1s,
     # so the sum is 0 if it's all 0 and >=1 in all other cases.
     not_correct_seq = np.sum(not_correct, axis=axis_to_sum)
     # Sequence is correct if not_correct_seq is 0, reverting here.
     correct_seq = 1.0 - np.minimum(1.0, not_correct_seq)
     return np.mean(correct_seq)  # Mean over batch.
Example #4
File: metrics.py Project: zsunpku/trax
def _WeightedSequenceMean(inputs, **unused_kwargs):
  """Returns a layer to compute weighted seqeunce accuracy mean."""
  values, weights = inputs  # This function assumes weights are 0 or 1.
  not_correct = (1.0 - values) * weights  # 1: not-correct, 0: correct or masked
  axis_to_sum = list(range(1, len(not_correct.shape)))
  # Summing not-correct on all axes but batch. We're summing 0s and 1s,
  # so the sum is 0 if it's all 0 and >=1 in all other cases.
  not_correct_seq = np.sum(not_correct, axis=axis_to_sum)
  # Sequence is correct if not_correct_seq is 0, reverting here.
  correct_seq = 1.0 - np.minimum(1.0, not_correct_seq)
  return np.mean(correct_seq)  # Mean over batch.
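The two snippets above compute weighted sequence accuracy: a sequence counts as correct only if every unmasked token is correct. A minimal plain-NumPy sketch with hypothetical data (not taken from the examples) illustrates the mechanics:

import numpy as np

values = np.array([[1., 1., 1.],   # every token correct
                   [1., 0., 1.]])  # one token wrong
weights = np.array([[1., 1., 0.],  # last position masked out
                    [1., 1., 1.]])

not_correct = (1.0 - values) * weights          # 1 marks an unmasked wrong token
not_correct_seq = np.sum(not_correct, axis=1)   # [0., 1.]
correct_seq = 1.0 - np.minimum(1.0, not_correct_seq)
print(np.mean(correct_seq))                     # 0.5: only the first sequence is fully correct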
Example #5
 def AWRJointLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
     preds, values, returns, actions, mask = x
     advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
     logps = self._policy_dist.log_prob(preds, actions)
     awr_loss = actor_critic.AWRLoss(beta=self._beta,
                                     w_max=self._w_max)(
                                         (logps, advantages,
                                          jnp.zeros_like(logps), mask))
     l2_value_loss = jnp.mean(
         (returns - values)**2) * self._value_loss_coeff
     return awr_loss + l2_value_loss
Example #6
def A2CObjective(dist_inputs, values, returns, actions, mask, log_prob_fun,
                 normalize_advantages):
    """Definition of the Advantage Actor Critic (A2C) loss."""
    returns = returns.squeeze()
    values = values.squeeze()
    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    return -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask)
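The advantage normalization used above (and in several later examples) is a standard standardization step: subtract the mean, then divide by the standard deviation plus a small constant. A minimal NumPy sketch with hypothetical advantages:

import numpy as np

advantages = np.array([1.0, 2.0, 3.0])
advantages = advantages - np.mean(advantages)   # [-1.  0.  1.]
advantages /= np.std(advantages) + 1e-8         # roughly [-1.22  0.  1.22]
print(advantages)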
Example #7
 def predict(x, weights, state, rng):
   """Predict function JIT-compileds and parallelized as requested."""
   res, state = _combine_devices(model_predict(
       reshape_by_device(x, n_devices),
       weights,
       state,
       jnp.stack(math.random.split(rng, n_devices))))
   if do_mean:
     return math.nested_map(lambda y: jnp.mean(y, axis=0), res), state
   else:
     return res, state
Example #8
def actor_loss(actions, advantage_weights, log_probab_actions_new, state=None):
    """Actor loss."""

    # log_probab_actions_new's shape is (AB, 1, #C, #A), AB is actor batch.
    lp = jnp.squeeze(log_probab_actions_new, axis=1)
    AB, NC = actions.shape  # pylint: disable=invalid-name
    log_probs = lp[jnp.arange(AB)[:, None], jnp.arange(NC)[None, :], actions]

    # TODO(afrozm): Clarify this.
    #   log_probs are shaped (AB, #C), however advantage_weights are (AB,)
    return -1.0 * jnp.mean(log_probs * advantage_weights[:, None]), state
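The fancy indexing above selects, for each batch item and each control, the log-prob of the action actually taken. A minimal NumPy sketch with hypothetical shapes (AB actor batch, NC controls, NA actions per control):

import numpy as np

AB, NC, NA = 2, 3, 4
lp = np.arange(AB * NC * NA, dtype=float).reshape(AB, NC, NA)
actions = np.array([[0, 1, 2],
                    [3, 0, 1]])
log_probs = lp[np.arange(AB)[:, None], np.arange(NC)[None, :], actions]
print(log_probs.shape)   # (2, 3): one selected log-prob per (batch item, control)
print(log_probs)         # [[ 0.  5. 10.] [15. 16. 21.]]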
Example #9
 def f(dist_inputs, values, returns, actions, old_log_probs):
   """Clipped objective from the PPO algorithm."""
   advantages = returns - values
   probs_ratio = rl_layers.ProbsRatio(
       dist_inputs, actions, old_log_probs,
       log_prob_fun=self._policy_dist.log_prob)
   # advantages has shape [128,1,1] and probs_ratio has shape [128,1],
   # so we squeeze the extra axis below.
   advantages = advantages.squeeze(axis=2)
   clipped_objective = rl_layers.ClippedObjective(
       probs_ratio, advantages, epsilon=self._epsilon)
   return jnp.mean(clipped_objective)
Example #10
 def f(dist_inputs, values, returns, actions, old_log_probs, mask):
     """A2C objective mean."""
     del old_log_probs
     a2c_objective = rl_layers.A2CObjective(
         dist_inputs,
         values,
         returns,
         actions,
         mask,
         log_prob_fun=self._policy_dist.log_prob,
         normalize_advantages=self._normalize_advantages)
     return jnp.mean(a2c_objective)
Example #11
 def f(dist_inputs, values, returns, actions, old_log_probs):
     """Clipped objective from the PPO algorithm."""
     ppo_objective = rl_layers.PPOObjective(
         dist_inputs,
         values,
         returns,
         actions,
         old_log_probs,
         log_prob_fun=self._policy_dist.log_prob,
         epsilon=self._epsilon,
         normalize_advantages=self._normalize_advantages)
     return jnp.mean(ppo_objective)
Example #12
def ApproximateKLDivergence(dist_inputs, actions, old_log_probs, log_prob_fun):
    """Probability Ratio from the PPO algorithm."""
    # TODO(henrykm): Clarify the old_log_probs and squeezing
    # Old log probs have an undesirable extra dimension which we remove here
    old_log_probs = jnp.array(old_log_probs.squeeze(axis=-1),
                              dtype=jnp.float32)
    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    # The KL divergence is approximated as half the squared mean
    # difference of the new and old log-probs.
    approximate_kl_divergence = 0.5 * \
        jnp.mean(new_log_probs - old_log_probs) ** 2
    return approximate_kl_divergence
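Note that, by Python operator precedence, the expression above evaluates as 0.5 * (mean(new - old)) ** 2, i.e. the mean of the log-prob differences is squared, not the per-element differences. A tiny NumPy sketch with hypothetical values:

import numpy as np

new_log_probs = np.array([-1.0, -2.0])
old_log_probs = np.array([-1.5, -1.5])
diff = new_log_probs - old_log_probs     # [ 0.5, -0.5]
approx_kl = 0.5 * np.mean(diff) ** 2     # ** binds tighter than *
print(approx_kl)                         # 0.0 -- the differences cancel in the mean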
Example #13
 def fn(dist_inputs, actions, q_values, act_log_probs, mask):
     del dist_inputs, actions, mask
     q_values = jnp.swapaxes(q_values, 0, 1)
     act_log_probs = jnp.swapaxes(act_log_probs, 0, 1)
     if self._sample_all_discrete_actions:
         values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
     else:
         values = jnp.mean(q_values, axis=0)
     advantages = q_values - values  # Broadcasting values over n_samples
     if preprocess:
         advantages = self._preprocess_advantages(advantages)
     return advantages
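The broadcasting noted in the comment above can be seen in a minimal NumPy sketch with hypothetical shapes: q_values has shape (n_samples, batch), the mean over axis 0 has shape (batch,), and the subtraction broadcasts values over n_samples:

import numpy as np

q_values = np.array([[1.0, 4.0],
                     [3.0, 2.0]])      # (n_samples=2, batch=2)
values = np.mean(q_values, axis=0)     # [2., 3.]
advantages = q_values - values         # values broadcast over n_samples
print(advantages)                      # [[-1.  1.] [ 1. -1.]]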
Example #14
def A2CObjective(dist_inputs, values, returns, dones, rewards, actions, mask,
                 log_prob_fun, normalize_advantages):
    """Definition of the Advantage Actor Critic (A2C) loss."""
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape int32[128,1,1]
    # actions of the shape int32[128,1]
    # and mask of the shape float32[128,1]
    # We have to squeeze values and returns, because we
    # are planning to compute (return - values) * new_log_probs * mask
    # and all of them should be of the same dimension
    values = values.squeeze(axis=2)
    returns = returns.squeeze(axis=2)
    dones = dones.squeeze(axis=2)
    rewards = rewards.squeeze(axis=2)
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert values.shape == mask.shape, (
        f'values.shape was {values.shape} and mask.shape was {mask.shape}')
    assert returns.shape[0] == dist_inputs.shape[0], (
        f'returns.shape[0] was {returns.shape[0]} and dist_inputs.shape[0] was '
        f'{dist_inputs.shape[0]}')

    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    assert new_log_probs.shape == mask.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and mask.shape was '
        f'{mask.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    assert new_log_probs.shape == advantages.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    # One motivation for the squeezes and assertions is to
    # avoid [128,1] * [128,1,1] * [128] multiplications in the definition
    # of the A2C objective - we insist on the same shapes.
    a2c_objective = -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask)
    return a2c_objective
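The "jaxified" assignments above replace the in-place updates returns[dones] = rewards[dones] and values[dones] = 0 with functional jnp.where calls. A minimal NumPy sketch with hypothetical values shows the same pattern:

import numpy as np

dones = np.array([False, True, False])
rewards = np.array([1.0, 2.0, 3.0])
returns = np.array([5.0, 5.0, 5.0])
values = np.array([4.0, 4.0, 4.0])

returns = np.where(dones, rewards, returns)               # [5. 2. 5.]
values = np.where(dones, np.zeros_like(values), values)   # [4. 0. 4.]
print(returns, values)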
Example #15
    def f(new_log_probs, advantages, old_log_probs, mask):
      # new_log_probs of the shape float32[128,1]
      # advantages of the shape int32[128,1]
      # old_log_probs of the shape int32[128,1]
      # mask of the shape int32[128,1]
      if new_log_probs.shape != advantages.shape:
        raise ValueError('New log-probs and advantages shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           advantages.shape))
      if new_log_probs.shape != old_log_probs.shape:
        raise ValueError('New log-probs and old log-probs shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           old_log_probs.shape))
      if new_log_probs.shape != mask.shape:
        raise ValueError('New log-probs and mask shapes should be the same'
                         ', %s != %s' % (new_log_probs.shape, mask.shape))

      # The ratio between new_probs and old_probs expressed
      # using log_probs and exponentiation
      probs_ratio = jnp.exp(new_log_probs - old_log_probs)
      if advantages.shape != probs_ratio.shape:
        raise ValueError('Advantages and probs-ratio shapes '
                         'should be the same, %s != %s' % (advantages.shape,
                                                           probs_ratio.shape))
      unclipped_objective = probs_ratio * advantages
      clipped_objective = jnp.clip(probs_ratio,
                                   1 - self._epsilon,
                                   1 + self._epsilon) * advantages

      if unclipped_objective.shape != probs_ratio.shape:
        raise ValueError('unclipped_objective and probs_ratio shapes '
                         'should be the same, %s != %s' % (
                             unclipped_objective.shape,
                             probs_ratio.shape))

      ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)

      if ppo_objective.shape != mask.shape:
        raise ValueError('ppo_objective and mask shapes '
                         'should be the same, %s != %s' % (
                             ppo_objective.shape,
                             mask.shape))

      ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask)
      entropy_vec = self._policy_dist.entropy(
          new_log_probs) * self._entropy_coeff
      entropy_loss = jnp.mean(entropy_vec)
      combined_loss = ppo_loss - entropy_loss

      return combined_loss
Example #16
        def ProbsRatioMean(x, **unused_kwargs):
            """Probability Ratio Mean from the PPO algorithm."""
            dist_inputs, _, _, actions, old_log_probs = x
            new_log_probs = self._policy_dist.log_prob(dist_inputs, actions)

            # Old log probs have an undesirable extra dimension which we remove here
            old_log_probs = jnp.array(old_log_probs.squeeze(axis=-1),
                                      dtype=jnp.float32)
            new_log_probs = jnp.array(new_log_probs.squeeze(axis=-1))

            # The ratio between new_probs and old_probs expressed
            # using log_probs and exponentiation
            probs_ratio = jnp.exp(new_log_probs - old_log_probs)
            return jnp.mean(probs_ratio)
Example #17
File: core.py Project: yangcaot/trax
def Mean(axis=-1, keepdims=False):
  """Returns a layer that computes mean values using one tensor axis.

  `Mean` uses one tensor axis to form groups of values and replaces each group
  with the mean value of that group. The resulting values can either remain
  in their own size 1 axis (`keepdims=True`), or that axis can be removed from
  the overall tensor (default `keepdims=False`), lowering the rank of the
  tensor by one.

  Args:
    axis: Axis along which values are grouped for computing a mean.
    keepdims: If `True`, keep the resulting size 1 axis as a separate tensor
        axis; else, remove that axis.
  """
  return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))
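A minimal usage sketch of the layer above, assuming trax is installed and the layer is exposed as trax.layers.Mean; the import path and the direct call on a NumPy array are assumptions, not taken from the listing:

import numpy as np
import trax.layers as tl

layer = tl.Mean(axis=-1)        # average over the last axis, dropping it
x = np.array([[1., 2., 3.],
              [4., 5., 6.]])
print(layer(x))                 # [2. 5.]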
Example #18
  def forward(self, inputs, weights):
    gamma, beta, epsilon_l = weights

    epsilon = self._init_epsilon
    if epsilon_l is not base.EMPTY_WEIGHTS:
      epsilon += np.abs(epsilon_l[0])

    # Omit B and C
    axis = tuple(range(1, len(np.shape(inputs)) - 1))
    # (B, 1, 1, C)
    nu2 = np.mean(inputs**2, axis=axis, keepdims=True)
    # (B, W, H, C)
    xhat = inputs / np.sqrt(nu2 + epsilon)

    return gamma * xhat + beta
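The forward pass above is a filter-response-style normalization: each activation is divided by the root mean square computed over the spatial axes (omitting batch and channels), then scaled and shifted. A minimal NumPy sketch with a hypothetical (B, W, H, C) input:

import numpy as np

inputs = np.random.rand(2, 4, 4, 3)                  # (B, W, H, C)
epsilon = 1e-6
axis = tuple(range(1, inputs.ndim - 1))              # spatial axes only
nu2 = np.mean(inputs**2, axis=axis, keepdims=True)   # (B, 1, 1, C)
xhat = inputs / np.sqrt(nu2 + epsilon)
print(xhat.shape)                                    # (2, 4, 4, 3)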
Example #19
def PPOObjective(dist_inputs, values, returns, actions, old_log_probs,
                 log_prob_fun, epsilon, normalize_advantages):
    """PPO Objective."""
    # Returns and values are arriving with two extra dimensions
    # TODO(henrykm): remove these dimensions at an earlier stage?
    returns = returns.squeeze()
    values = values.squeeze()
    probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    unclipped_objective = UnclippedObjective(probs_ratio, advantages)
    clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    return ppo_objective
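The clipped surrogate computed above reduces to min(r * A, clip(r, 1 - epsilon, 1 + epsilon) * A). A worked NumPy sketch with hypothetical ratios and unit advantages:

import numpy as np

epsilon = 0.2
probs_ratio = np.array([0.5, 1.0, 1.5])   # new_prob / old_prob
advantages = np.array([1.0, 1.0, 1.0])

unclipped = probs_ratio * advantages
clipped = np.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages
print(np.minimum(unclipped, clipped))     # [0.5 1.  1.2]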
Example #20
 def f(dist_inputs, values, returns, dones, rewards, actions,
       old_log_probs):
     """Unclipped objective Mean from the PPO algorithm."""
     del dones, rewards
     advantages = returns - values
     probs_ratio = rl_layers.ProbsRatio(
         dist_inputs,
         actions,
         old_log_probs,
         log_prob_fun=self._policy_dist.log_prob)
     # advantages has shape [128,1,1] and probs_ratio has shape [128,1],
     # so we squeeze the extra axis below.
     advantages = advantages.squeeze(axis=2)
     unclipped_objective = rl_layers.UnclippedObjective(
         probs_ratio, advantages)
     return jnp.mean(unclipped_objective)
Example #21
def PPOObjective(dist_inputs, values, returns, actions, old_log_probs,
                 log_prob_fun, epsilon, normalize_advantages):
    """PPO Objective."""
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # actions of the shape int32[128,1]
    # and old_log_probs of the shape float32[128,1]
    returns = returns.squeeze(axis=2)
    values = values.squeeze(axis=2)
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert returns.shape == old_log_probs.shape, (
        f'returns.shape was {returns.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
    assert probs_ratio.shape == old_log_probs.shape, (
        f'probs_ratio.shape was {probs_ratio.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    assert old_log_probs.shape == advantages.shape, (
        f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    unclipped_objective = UnclippedObjective(probs_ratio, advantages)
    assert unclipped_objective.shape == advantages.shape, (
        f'advantages.shape was {advantages.shape} and '
        f'unclipped_objective.shape was {unclipped_objective.shape}')

    clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
    assert clipped_objective.shape == advantages.shape, (
        f'advantages.shape was {advantages.shape} and '
        f'clipped_objective.shape was {clipped_objective.shape}')

    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    assert ppo_objective.shape == advantages.shape, (
        f'advantages.shape was {advantages.shape} and '
        f'ppo_objective.shape was {ppo_objective.shape}')

    return ppo_objective
Example #22
def critic_loss(observations,
                target_values,
                value_predictions_new,
                state=None):
    """Critic loss."""
    # There is no padding involved here, these are all observations.
    (batch, *obs_shape) = observations.shape
    del obs_shape
    if (batch, ) != target_values.shape:
        raise ValueError(f'batch dimension is not the same: obs batch {batch}'
                         f' vs target values batch {target_values.shape[0]}')

    # TODO(afrozm): In the reference implementation, they pass the target through
    #  a trained normalizer before subtracting.

    loss = 0.5 * jnp.mean(jnp.square(target_values - value_predictions_new))
    return loss, state
Example #23
 def policy_inputs(self, trajectory, values):
     """Create inputs to policy model from a TrajectoryNp and values."""
     # How much TD to use is determined by the added policy slice length,
     # as the policy batches need to be this much longer to calculate TD.
     advantages = self._advantage_estimator(
         trajectory.rewards,
         trajectory.returns,
         values,
         gamma=self._task.gamma,
         n_extra_steps=self._added_policy_slice_length,
     )
     if self._advantage_normalization:
         advantages = (
             (advantages - jnp.mean(advantages)) /
             (jnp.std(advantages) + self._advantage_normalization_epsilon))
     # Observations should be the same length as advantages - so if we are
     # using n_extra_steps, we need to trim the length to match.
     obs = trajectory.observations[:, :advantages.shape[1]]
     act = trajectory.actions[:, :advantages.shape[1]]
     old_logps = trajectory.log_probs[:, :advantages.shape[1]]
     mask = trajectory.mask[:, :advantages.shape[1]]  # Mask to zero-out padding.
     # Shape checks to help debugging.
     if len(advantages.shape) != 2:
         raise ValueError('Advantages are expected to have shape ' +
                          '[batch_size, length], got: %s' %
                          str(advantages.shape))
     if act.shape[0:2] != advantages.shape:
         raise ValueError(
             'First 2 dimensions of actions should be the same as in '
             'advantages, %s != %s' % (act.shape[0:2], advantages.shape))
     if obs.shape[0:2] != advantages.shape:
         raise ValueError(
             'First 2 dimensions of observations should be the same '
             'as in advantages, %s != %s' %
             (obs.shape[0:2], advantages.shape))
     if old_logps.shape != advantages.shape:
         raise ValueError(
             'Old log-probs and advantages shapes should be the same'
             ', %s != %s' % (old_logps.shape, advantages.shape))
     if mask.shape != advantages.shape:
         raise ValueError('Mask and advantages shapes should be the same'
                          ', %s != %s' % (mask.shape, advantages.shape))
     return (obs, act, advantages, old_logps, mask)
Example #24
        def LossInput(dist_inputs, actions, q_values, act_log_probs, mask):  # pylint: disable=invalid-name
            """Calculates action log probabilities and normalizes advantages."""
            # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
            q_values = jnp.swapaxes(q_values, 0, 1)
            mask = jnp.swapaxes(mask, 0, 1)
            actions = jnp.swapaxes(actions, 0, 1)
            act_log_probs = jnp.swapaxes(act_log_probs, 0, 1)

            # TODO(pkozakowski,lukaszkaiser): Try max here, or reweighting?
            # Reweight: values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
            values = jnp.mean(q_values, axis=0)
            advantages = q_values - values  # Broadcasting values over n_samples
            advantages = self._preprocess_advantages(advantages)

            # Broadcast inputs and calculate log-probs
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            return (log_probs, advantages, act_log_probs, mask)
Example #25
 def policy_metrics(self):
     metrics = {
         'policy_loss':
         self.policy_loss,
         'advantage_mean':
         tl.Serial(
             self._policy_inputs_to_advantages(False),
             tl.Fn('Mean', lambda x: jnp.mean(x))  # pylint: disable=unnecessary-lambda
         ),
         'advantage_std':
         tl.Serial(
             self._policy_inputs_to_advantages(False),
             tl.Fn('Std', lambda x: jnp.std(x))  # pylint: disable=unnecessary-lambda
         )
     }
     metrics.update(
         awr_metrics(
             self._beta,
             preprocess_layer=self._policy_inputs_to_advantages(True)))
     return metrics
Example #26
File: trainer_lib.py Project: srush/trax
  def train_step(self, batch):
    """Run one training step and update self._opt_state."""
    # Calculate the current optimizer parameters.
    # TODO(pkozakowski): Optimizer parameters get polluted with model state,
    # which doesn't break anything but is weird. Filter it out.
    opt_param_updates = self._for_n_devices(
        math.nested_map(np.array, self.nontrainable_params))
    opt_state = self._opt_state
    opt_state.opt_params.update(opt_param_updates)

    # Run the update.
    (weights, slots, stat), self._model_state, self._rngs = self._jit_update_fn(
        self._step, opt_state, batch, self._model_state, self._rngs)
    self._model_state = self._map_to_state_dicts(self._state_dicts_update)
    self._opt_state = opt_state._replace(weights=weights, slots=slots)
    if self._should_log_now():
      for name, value in stat.items():
        scalar_value = np.mean(value)  # On multiple devices, take the mean.
        self._train_sw.scalar('training/' + name, scalar_value, step=self._step)
    self._step += 1
Example #27
    def f(log_probs, advantages, old_log_probs, mask):
      del old_log_probs  # Not used in A2C.
      # log_probs of the shape float32[128,1]
      # advantages of the shape int32[128,1]
      # mask of the shape int32[128,1]
      if log_probs.shape != advantages.shape:
        raise ValueError('New log-probs and advantages shapes '
                         'should be the same, %s != %s' % (log_probs.shape,
                                                           advantages.shape))
      if log_probs.shape != mask.shape:
        raise ValueError('New log-probs and mask shapes should be the same'
                         ', %s != %s' % (log_probs.shape, mask.shape))

      a2c_objective = -jnp.sum(log_probs * advantages * mask) / jnp.sum(mask)

      entropy_vec = self._policy_dist.entropy(log_probs) * self._entropy_coeff
      entropy_loss = jnp.mean(entropy_vec)

      combined_loss = a2c_objective - entropy_loss

      return combined_loss
Example #28
  def policy_batches_stream(self):
    """Use the RLTask self._task to create inputs to the policy model."""
    # For now TD-0 estimation of the value. TODO(pkozakowski): Support others?
    for np_trajectory in self._task.trajectory_batch_stream(
        self._policy_batch_size,
        epochs=self._replay_epochs,
        max_slice_length=self._max_slice_length,
        include_final_state=False,
    ):
      (q_values, actions) = self._run_value_model(
          np_trajectory.observations, np_trajectory.dist_inputs
      )
      # TODO(pkozakowski): Try max here.
      values = jnp.mean(q_values, axis=0)

      if len(values.shape) != 2:
        raise ValueError('Values are expected to have shape ' +
                         '[batch_size, length], got: %s' % str(values.shape))
      if values.shape[0] != self._policy_batch_size:
        raise ValueError('Values first dimension should = policy batch size, ' +
                         '%d != %d' %(values.shape[0], self._policy_batch_size))

      # q_values shape: (n_samples, batch_size, length)
      # values shape: (batch_size, length)
      # Computing advantages by broadcasting over n_samples.
      advantages = q_values - values
      mask = jnp.broadcast_to(np_trajectory.mask, advantages.shape)

      shapes.assert_shape_equals(
          advantages, (self._q_value_n_samples,) + values.shape
      )
      shapes.assert_same_shape(mask, advantages)

      # Swapping the n_samples and batch_size axes, so the input is split
      # between accelerators along the batch_size axis.
      advantages = jnp.swapaxes(advantages, 0, 1)
      mask = jnp.swapaxes(mask, 0, 1)

      yield (np_trajectory.observations, actions, advantages, mask, mask)
Example #29
  def test_custom_id_grad(self):

    class IdWithIdGrad(base.Layer):

      def forward(self, x, weights):
        return x

      @property
      def has_backward(self):
        return True

      def backward(self, inputs, output, grad, weights, state, new_state, rng):
        return (inputs, ())

    layer = IdWithIdGrad()
    rng = math.random.get_prng(0)
    input_signature = shapes.ShapeDtype((9, 17))
    random_input = math.random.uniform(rng, input_signature.shape,
                                       minval=-1.0, maxval=1.0)
    layer.init(input_signature)
    f = lambda x: jnp.mean(layer(x))
    grad = math.grad(f)(random_input)
    self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
    self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
Example #30
 def predict(x, weights, state, rng):
     """Predict function jited and parallelized as requested."""
     res, state = _combine_devices(
         model_predict(reshape_by_device(x, n_devices), weights, state,
                       np.stack(math.random.split(rng, n_devices))))
     return math.nested_map(lambda y: np.mean(y, axis=0), res), state