Example #1
0
    def smoothl1loss(model_output, targets, weights):  # pylint: disable=invalid-name
        r"""Computes the weight-normalized smooth L1 distance to `targets`.

        With :math:`d_i = |x_i - y_i|`, each element contributes

        .. math::
            z_i =
            \begin{cases}
            0.5 \, d_i^2, & \text{if } d_i < 1 \\
            d_i - 0.5, & \text{otherwise}
            \end{cases}

        which is the Huber loss with a threshold of 1.

        Args:
          model_output: Output from one batch, treated as an unanalyzed tensor.
          targets: Element-wise target values; must have the same shape as
              `model_output`.
          weights: Element-wise weight values; must have the same shape as
              `model_output` and `targets`.

        Returns:
          The sum of the weighted per-element distances divided by the sum of
          `weights`.
        """
        shapes.assert_same_shape(model_output, targets)
        shapes.assert_same_shape(targets, weights)
        abs_err = jnp.abs(model_output - targets)
        # Quadratic close to the target, linear further away.
        quadratic_part = 0.5 * abs_err**2
        linear_part = abs_err - 0.5
        per_element = jnp.where(abs_err < 1, quadratic_part, linear_part)
        shapes.assert_same_shape(per_element, weights)
        return jnp.sum(weights * per_element) / jnp.sum(weights)
Example #2
0
 def f(log_probs, advantages, old_log_probs, mask):
     """Returns the masked, advantage-weighted negative log-likelihood.

     `reweight`, `sampled_all_discrete`, `beta`, `w_max` and `awr_weights` are
     taken from the enclosing scope.
     """
     # Advantage weights are capped at `w_max`.
     capped_weights = jnp.minimum(awr_weights(advantages, beta), w_max)
     if reweight:
         # Use new policy weights for sampled actions instead; the gradient is
         # stopped so the ratio acts as a constant factor.
         policy_ratio = jnp.exp(
             fastmath.stop_gradient(log_probs) - old_log_probs)
         mask *= policy_ratio
     if sampled_all_discrete:
         # Actions were sampled uniformly; weight them by their old
         # probabilities.
         mask *= jnp.exp(old_log_probs)
     weighted_log_probs = log_probs * capped_weights * mask
     return -jnp.sum(weighted_log_probs) / jnp.sum(mask)
Example #3
0
 def f(model_output, targets, weights):  # pylint: disable=invalid-name
     """Returns the weight-normalized smooth L1 distance to `targets`.

     All three arguments must have the same shape.
     """
     shapes.assert_same_shape(model_output, targets)
     shapes.assert_same_shape(model_output, weights)
     abs_err = jnp.abs(model_output - targets)
     # Quadratic for errors below 1, linear (shifted by 0.5) otherwise.
     per_element = jnp.where(abs_err < 1, 0.5 * abs_err**2, abs_err - 0.5)
     weighted_total = jnp.sum(weights * per_element)
     return weighted_total / jnp.sum(weights)
Example #4
0
 def f(values, actions, returns, mask):
   """Returns the sum of the selected values divided by the sum of `mask`.

   `actions` must be 2-dimensional (its index grid is unpacked into two
   arrays) and is used to index the last axis of `values`.
   """
   # The length is taken from the shape of `returns`, and the superfluous
   # trailing slice of `values` is removed accordingly.
   # An analogous operation is done in value_batches_stream.
   n_steps = returns.shape[1]
   values = values[:, :n_steps, :]
   row_idx, col_idx = np.indices(actions.shape)
   chosen_values = values[row_idx, col_idx, actions]
   shapes.assert_same_shape(chosen_values, returns)
   shapes.assert_same_shape(chosen_values, mask)
   # NOTE(review): the numerator is not multiplied by `mask`, so masked-out
   # positions still contribute to the sum -- confirm this is intended.
   total = jnp.sum(chosen_values)
   return total / jnp.sum(mask)
Example #5
0
def A2CObjective(dist_inputs, values, returns, dones, rewards, actions, mask,
                 log_prob_fun, normalize_advantages):
    """Computes the Advantage Actor Critic (A2C) objective.

    Example shapes noted by the original author: `dist_inputs` float32[128,1,18];
    `values`, `returns` float32[128,1,1]; `dones` int32[128,1,1]; `actions`
    int32[128,1]; `mask` float32[128,1].

    Args:
      dist_inputs: Inputs to the action distribution, passed to `NewLogProbs`.
      values: Value predictions with a trailing axis (axis 2) of size 1.
      returns: Returns, same layout as `values`.
      dones: Episode-termination flags, same layout as `values`.
      rewards: Rewards, same layout as `values`.
      actions: Actions, passed to `NewLogProbs`.
      mask: Per-step mask; must match the squeezed shape of `values`.
      log_prob_fun: Log-probability function, passed to `NewLogProbs`.
      normalize_advantages: If true, advantages are shifted to zero mean and
          divided by their standard deviation (plus 1e-8).

    Returns:
      `-sum(new_log_probs * advantages * mask) / sum(mask)`.
    """
    # Drop the trailing axis so that (returns - values), the log-probs and the
    # mask all share one shape and nothing broadcasts by accident.
    values, returns, dones, rewards = (
        tensor.squeeze(axis=2) for tensor in (values, returns, dones, rewards))

    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert values.shape == mask.shape, (
        f'values.shape was {values.shape} and mask.shape was {mask.shape}')
    assert returns.shape[0] == dist_inputs.shape[0], (
        f'returns.shape[0] was {returns.shape[0]} and dist_inputs.shape[0] was '
        f'{dist_inputs.shape[0]}')

    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    assert new_log_probs.shape == mask.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and mask.shape was '
        f'{mask.shape}')

    # Functional equivalents of
    #   returns[dones] = rewards[dones]
    #   values[dones] = 0
    returns_at_done = jnp.where(dones, rewards, returns)
    values_at_done = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns_at_done - values_at_done
    if normalize_advantages:
        centered = advantages - jnp.mean(advantages)
        advantages = centered / (jnp.std(centered) + 1e-8)
    assert new_log_probs.shape == advantages.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    # The squeezes and assertions above exist to rule out products such as
    # [128,1] * [128,1,1] * [128]; all three factors have the same shape here.
    weighted_log_probs = new_log_probs * advantages * mask
    return -jnp.sum(weighted_log_probs) / jnp.sum(mask)
Example #6
0
    def f(model_output, targets, weights):  # pylint: disable=invalid-name
        """Computes the weight-normalized squared error to `targets`.

        Args:
          model_output: Output from one batch, treated as an unanalyzed tensor.
          targets: Element-wise target values; same shape as `model_output`.
          weights: Element-wise weights; same shape as `model_output` and
              `targets`.

        Returns:
          The sum of `weights * (model_output - targets)**2` divided by the sum
          of `weights`.
        """
        shapes.assert_same_shape(model_output, targets)
        shapes.assert_same_shape(targets, weights)
        squared_err = (model_output - targets)**2
        weighted_total = jnp.sum(weights * squared_err)
        return weighted_total / jnp.sum(weights)
Example #7
0
    def f(new_log_probs, advantages, old_log_probs, mask):
      """Computes the clipped PPO loss minus a scaled entropy term.

      All four arguments must have the same shape; the original author noted
      [128, 1] as an example. `self._epsilon`, `self._entropy_coeff` and
      `self._policy_dist` are taken from the enclosing scope.

      Args:
        new_log_probs: Log-probabilities under the current policy.
        advantages: Advantage estimates.
        old_log_probs: Log-probabilities under the policy that collected the
            data.
        mask: Weights applied to each element of the PPO objective.

      Returns:
        The masked PPO loss minus the mean of the scaled entropy.

      Raises:
        ValueError: If any two of the compared shapes differ.
      """
      if new_log_probs.shape != advantages.shape:
        raise ValueError('New log-probs and advantages shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           advantages.shape))
      if new_log_probs.shape != old_log_probs.shape:
        raise ValueError('New log-probs and old log-probs shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           old_log_probs.shape))
      if new_log_probs.shape != mask.shape:
        raise ValueError('New log-probs and mask shapes should be the same'
                         ', %s != %s' % (new_log_probs.shape, mask.shape))

      # The ratio between new and old probabilities, expressed through
      # log-probabilities and exponentiation.
      probs_ratio = jnp.exp(new_log_probs - old_log_probs)
      if advantages.shape != probs_ratio.shape:
        # The message names the two tensors that are actually compared.
        raise ValueError('Advantages and probs_ratio shapes '
                         'should be the same, %s != %s' % (advantages.shape,
                                                           probs_ratio.shape))
      unclipped_objective = probs_ratio * advantages
      clipped_objective = jnp.clip(probs_ratio,
                                   1 - self._epsilon,
                                   1 + self._epsilon) * advantages

      # Compare the two objectives with each other, matching the message.
      if unclipped_objective.shape != clipped_objective.shape:
        raise ValueError('unclipped_objective and clipped_objective shapes '
                         'should be the same, %s != %s' % (
                             unclipped_objective.shape,
                             clipped_objective.shape))

      # Taking the element-wise minimum never lets clipping increase the
      # objective.
      ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)

      if ppo_objective.shape != mask.shape:
        raise ValueError('ppo_objective and mask shapes '
                         'should be the same, %s != %s' % (
                             ppo_objective.shape,
                             mask.shape))

      ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask)
      entropy_vec = self._policy_dist.entropy(
          new_log_probs) * self._entropy_coeff
      entropy_loss = jnp.mean(entropy_vec)
      combined_loss = ppo_loss - entropy_loss

      return combined_loss
Example #8
0
    def f(model_output, targets, weights):  # pylint: disable=invalid-name
        """Computes the weighted sum of squared errors, normalized by weights.

        Args:
          model_output: Output from one batch, typically a 2- or 3-d array of
              float-valued elements.
          targets: Element-wise target values; same shape as `model_output`.
          weights: Element-wise weight values; same shape as `model_output`
              and `targets`.

        Returns:
          The sum of `weights * (model_output - targets)**2` divided by the sum
          of `weights`.
        """
        shapes.assert_same_shape(model_output, targets)
        shapes.assert_same_shape(targets, weights)
        residual = model_output - targets
        numerator = jnp.sum(weights * residual**2)
        denominator = jnp.sum(weights)
        return numerator / denominator
Example #9
0
def _n_weights_per_core(weights):  # pylint: disable=invalid-name
    """Calculates the number of weights per core.

    In multi-device settings, gradients and losses are averaged over all
    devices. When the loss is weighted and the number of weights can differ by
    device (e.g., when the weights represent the number of tokens in a batch
    of sentences), each token on each device should be weighted the same way.
    This function supports that by reporting the number of weights per core
    in multi-core settings, and simply the sum of weights in a single-core
    setting.

    Args:
      weights: tensor with arbitrary shape

    Returns:
      The sum of `weights` when fewer than two devices are present or when the
      cross-device sum fails (e.g., outside of pmap); otherwise the sum of
      weights over all cores divided by the number of cores.
    """
    local_sum = jnp.sum(weights)
    if fastmath.device_count() < 2:
        return local_sum
    try:
        # Summing the constant 1.0 over the 'batch' axis counts the cores.
        n_cores = fastmath.psum(jnp.array(1.0), 'batch')
        return fastmath.psum(local_sum, 'batch') / n_cores
    except (NameError, ValueError):
        # Running outside of pmap, e.g., on init: fall back to the local sum.
        return local_sum
Example #10
0
 def _l2_norm(self, flat_list):
   """Returns the square root of the summed squares of all listed tensors."""
   if fastmath.is_backend(fastmath.Backend.JAX):
     squared_norms = (jnp.vdot(t, t) for t in flat_list)
   else:  # TODO(lukaszkaiser): add vdot to TF-numpy
     squared_norms = (jnp.sum(t*t) for t in flat_list)
   return jnp.sqrt(sum(squared_norms))
Example #11
0
 def log_prob(self, inputs, point):
   """Returns the summed log-softmax entries selected by `point`."""
   log_probs = tl.LogSoftmax()(self._unflatten_inputs(inputs))
   # The one-hot product keeps only the entries specified by `point`.
   selected = log_probs * tl.one_hot(point, self._n_categories)
   # Reduce over the last `len(self._shape) + 1` axes: -1, -2, ...
   n_reduced_axes = len(self._shape) + 1
   reduced_axes = [-(i + 1) for i in range(n_reduced_axes)]
   return jnp.sum(selected, axis=reduced_axes)
Example #12
0
 def log_prob(self, inputs, point):
     """Returns the Gaussian log-density of `point` with mean `inputs`.

     The standard deviation is `self._std`; `point` is reshaped so that its
     leading axes match those of `inputs`.
     """
     point = point.reshape(inputs.shape[:-1] + (-1, ))
     squared_dist = jnp.sum((point - inputs)**2, axis=-1)
     l2_term = -squared_dist / (2 * self._std**2)
     # Per-dimension normalizing constant, repeated for every dimension of
     # `self._shape`.
     log_norm_per_dim = jnp.log(self._std) + jnp.log(jnp.sqrt(2 * jnp.pi))
     return l2_term - log_norm_per_dim * np.prod(self._shape)
Example #13
0
 def _UpdateRow(x):
   """Slices an (L2, H) block out of the first row, after the encoder part.

   `x` is a triple commented as having shapes (L, H), (L1, H) and (L2, H);
   `L2` and `H` come from the enclosing scope.
   """
   row_ed, row_e, _ = x
   # The number of non-zero encoder entries is where the slice starts.
   n_encoder = jnp.sum(row_e != 0, dtype=jnp.int32)
   start_indices = (n_encoder, 0)
   return jax.lax.dynamic_slice(row_ed, start_indices, (L2, H))
Example #14
0
 def _UpdateRow(x):
     """Slices an (L2, H) block out of the first row, after the encoder part.

     `x` is a triple commented as having shapes (L, H), (L1, H) and (L2, H);
     `L2` and `H` come from the enclosing scope.
     """
     row_ed, row_e, _ = x
     # The number of non-zero encoder entries is where the slice starts.
     n_encoder = jnp.sum(row_e != 0, dtype=jnp.int32)
     # Both start indices share one dtype to avoid an int32/int64 mismatch.
     column_start = jnp.array(0, dtype=n_encoder.dtype)
     start_indices = (n_encoder, column_start)
     return fastmath.dynamic_slice(row_ed, start_indices, (L2, H))
Example #15
0
 def _UpdateRow(x):
   # row_e - (L1, H), row_d - (L2, H), row_mask_e - (L1,)
   row_e, row_d, row_mask_e = x
   # final_row - (L1+L2, H)
   final_row = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0)
   # Find the last real token/vector of the encoder.
   e_idx = jnp.sum(row_mask_e, dtype=jnp.int32)
   # Starting after that index, update with the decoder row.
   return jax.lax.dynamic_update_slice(final_row, row_d, (e_idx, 0))
Example #16
0
 def _aggregate_values(self, values, aggregate_max, act_log_probs):
     if self._q_value:
         if aggregate_max:
             values = jnp.max(values, axis=1)
         elif self._sample_all_discrete_actions:
             values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
         else:
             values = jnp.mean(values, axis=1)
     return np.array(values)  # Move the values to CPU.
Example #17
0
 def _UpdateRow(x):
     """Writes the decoder row right after the real encoder entries.

     `x` is a triple (row_e, row_d, row_mask_e), commented as having shapes
     (L1, H), (L2, H) and (L1,). The result has shape (L1+L2, H).
     """
     row_e, row_d, row_mask_e = x
     # Encoder row followed by zeros that reserve room for the decoder row.
     decoder_placeholder = jnp.zeros_like(row_d)
     padded_row = jnp.concatenate([row_e, decoder_placeholder], axis=0)
     # The mask sum gives the index just past the last real encoder entry.
     n_encoder = jnp.sum(row_mask_e, dtype=jnp.int32)
     # Both start indices share one dtype to avoid an int32/int64 mismatch.
     column_start = jnp.array(0, dtype=n_encoder.dtype)
     start_indices = (n_encoder, column_start)
     return fastmath.dynamic_update_slice(padded_row, row_d, start_indices)
Example #18
0
 def log_prob(self, inputs, point):
     """Returns the Gaussian log-density of `point`, summed over the last axis.

     Mean and standard deviation are obtained from `self._params(inputs)`.
     """
     point = point.reshape(inputs.shape[:-1] + (-1, ))
     mean, std = self._params(inputs)
     scaled_sq_dist = (point - mean)**2 / (2 * std**2)
     # Normalizing constant of each one-dimensional Gaussian.
     log_norm = jnp.log(std) + jnp.log(jnp.sqrt(2 * jnp.pi))
     return -jnp.sum(scaled_sq_dist + log_norm, axis=-1)
def test_model(preds, target):
    """Computes the negated, padding-masked mean of the target entries.

    Args:
        preds (jax.interpreters.xla.DeviceArray): Predictions of a list of batches of tensors corresponding to lines of text.
        target (jax.interpreters.xla.DeviceArray): Actual list of batches of tensors corresponding to lines of text.

    Returns:
        float: log_perplexity of the model.
    """
    # Pick, for every position, the prediction entry of the target id.
    target_one_hot = tl.one_hot(target, preds.shape[-1])
    target_entries = np.sum(preds * target_one_hot, axis=-1)

    # Target id 0 marks padding; those positions get weight 0.
    non_pad = 1.0 - np.equal(target, 0)
    masked_entries = target_entries * non_pad

    mean_entry = np.sum(masked_entries) / np.sum(non_pad)
    return -mean_entry
Example #20
0
 def f(values, weights):  # pylint: disable=invalid-name
     """Returns the fraction of batch items with no weighted error.

     Assumes `weights` are 0 or 1, and that `values` is 1 where an element is
     correct and 0 where it is not.
     """
     # 1 where an element is wrong and counted, 0 where correct or masked.
     errors = (1.0 - values) * weights
     # Reduce over every axis except the batch axis. Summing 0s and 1s gives
     # 0 only when the whole item is error-free, and >= 1 otherwise.
     non_batch_axes = list(range(1, len(errors.shape)))
     errors_per_item = jnp.sum(errors, axis=non_batch_axes)
     # Clamp the error count to 1 and invert it: 1 means fully correct.
     item_is_correct = 1.0 - jnp.minimum(1.0, errors_per_item)
     return jnp.mean(item_is_correct)  # Mean over batch.
Example #21
0
    def f(log_probs, advantages, old_log_probs, mask):
      """Computes the masked A2C loss minus a scaled entropy term.

      `log_probs`, `advantages` and `mask` must share one shape; the original
      author noted [128, 1] as an example. `self._policy_dist` and
      `self._entropy_coeff` are taken from the enclosing scope.

      Raises:
        ValueError: If `advantages` or `mask` differ in shape from `log_probs`.
      """
      del old_log_probs  # Not used in A2C.
      expected_shape = log_probs.shape
      if expected_shape != advantages.shape:
        raise ValueError('New log-probs and advantages shapes '
                         'should be the same, %s != %s' % (log_probs.shape,
                                                           advantages.shape))
      if expected_shape != mask.shape:
        raise ValueError('New log-probs and mask shapes should be the same'
                         ', %s != %s' % (log_probs.shape, mask.shape))

      weighted_log_probs = log_probs * advantages * mask
      policy_loss = -jnp.sum(weighted_log_probs) / jnp.sum(mask)

      scaled_entropy = self._policy_dist.entropy(log_probs) * self._entropy_coeff
      return policy_loss - jnp.mean(scaled_entropy)
Example #22
0
 def fn(dist_inputs, actions, q_values, act_log_probs, mask):
     """Returns Q-values minus their aggregate over the first (swapped) axis.

     `preprocess` and `self` are taken from the enclosing scope.
     """
     del dist_inputs, actions, mask
     # Swap the first two axes of both tensors.
     q_by_sample = jnp.swapaxes(q_values, 0, 1)
     log_probs_by_sample = jnp.swapaxes(act_log_probs, 0, 1)
     if self._sample_all_discrete_actions:
         # Probability-weighted sum of the Q-values.
         baseline = jnp.sum(
             q_by_sample * jnp.exp(log_probs_by_sample), axis=0)
     else:
         baseline = jnp.mean(q_by_sample, axis=0)
     # `baseline` broadcasts over the leading axis.
     advantages = q_by_sample - baseline
     if not preprocess:
         return advantages
     return self._preprocess_advantages(advantages)
Example #23
0
def _category_cross_entropy(  # pylint: disable=invalid-name
    model_output, targets, label_smoothing):
  """Computes category cross entropy with optional label smoothing.

  Args:
    model_output: Tensor whose last axis holds one entry per category.
    targets: Category ids, turned into one-hot vectors over the last axis.
    label_smoothing: Falsy for no smoothing; otherwise a value in [0, 1].

  Returns:
    The negated sum over the last axis of the target distribution times the
    log-softmax of `model_output`.

  Raises:
    ValueError: If `label_smoothing` is truthy and outside [0, 1].
  """
  n_categories = model_output.shape[-1]
  target_distributions = core.one_hot(targets, n_categories)
  if label_smoothing:
    if label_smoothing < 0. or label_smoothing > 1.:
      raise ValueError(
          f'Arg label_smoothing ({label_smoothing}) must be between 0 and 1.')
    # Shrink the one-hot mass, then spread `label_smoothing` evenly over all
    # categories.
    target_distributions = target_distributions * (1. - label_smoothing)
    target_distributions = (
        target_distributions + label_smoothing / n_categories)
  model_log_distributions = core.log_softmax(model_output)
  cross_terms = target_distributions * model_log_distributions
  return - jnp.sum(cross_terms, axis=-1)
Example #24
0
    def _aggregate_values(self, values, aggregate, act_log_probs):
        # Normalize the Q-values before aggragetion, so it can adapt to the scale
        # of the returns. This does not affect mean and max aggregation.
        scale = 1
        epsilon = 1e-5
        if self._q_value_normalization == 'std':
            scale = jnp.std(values) + epsilon
        elif self._q_value_normalization == 'abs':
            scale = jnp.mean(jnp.abs(values - jnp.mean(values))) + epsilon
        values /= scale

        temp = self._q_value_temperature
        if self._q_value:
            assert values.shape[:2] == (self._value_batch_size,
                                        self._q_value_n_samples)
            if aggregate == 'max':
                # max_a Q(s, a)
                values = jnp.max(values, axis=1)
            elif aggregate == 'softmax':
                # sum_a (Q(s, a) * w(s, a))
                # where w(s, .) = softmax (Q(s, .) / T)
                weights = tl.Softmax(axis=1)(values / temp)
                values = jnp.sum(values * weights, axis=1)
            elif aggregate == 'logsumexp':
                # log(mean_a exp(Q(s, a) / T)) * T
                n = values.shape[1]
                values = (fastmath.logsumexp(values / temp, axis=1) -
                          jnp.log(n)) * temp
            else:
                assert aggregate == 'mean'
                # mean_a Q(s, a)
                if self._sample_all_discrete_actions:
                    values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
                else:
                    values = jnp.mean(values, axis=1)

        # Re-scale the Q-values after aggregation.
        values *= scale
        return np.array(values)  # Move the values to CPU.
Example #25
0
def Sum(axis=-1, keepdims=False):
  """Returns a layer that sums a tensor along one axis.

  Values are grouped along `axis` and each group is replaced by its sum. With
  `keepdims=True` the summed axis is kept with size 1; with the default
  `keepdims=False` it is removed, which lowers the tensor's rank by one.

  Args:
    axis: Axis along which values are grouped for computing a sum.
    keepdims: If `True`, keep the resulting size 1 axis as a separate tensor
        axis; else, remove that axis.
  """
  def _sum(x):  # pylint: disable=invalid-name
    return jnp.sum(x, axis=axis, keepdims=keepdims)

  return Fn('Sum', _sum)
Example #26
0
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
    """Returns `log N(x | mu, diag(diag_sigma))`.

    The density is that of a Gaussian with a diagonal covariance matrix:
    `-0.5 * (k * log(2 * pi) + sum(log(diag_sigma))
             + sum((x - mu)**2 / diag_sigma))`,
    where `k` is the size of the last axis of `mu` and both sums run over the
    last axis.

    Args:
      x: Points at which to evaluate the log-density; the last axis indexes
          the dimensions of the Gaussian.
      mu: Means, broadcastable against `x`.
      diag_sigma: Diagonal entries of the covariance matrix (i.e. variances),
          broadcastable against `x`.

    Returns:
      The log-density, with the last axis of the inputs reduced away.
    """
    n_dims = mu.shape[-1]
    log_two_pi_term = n_dims * jnp.log(2 * jnp.pi)
    log_det = jnp.sum(jnp.log(diag_sigma), axis=-1)
    # The difference must be formed before dividing by the variances;
    # `x - mu / diag_sigma` would only rescale `mu`.
    diff = x - mu
    mahalanobis = jnp.sum(diff * diff / diag_sigma, axis=-1)
    return -0.5 * (log_two_pi_term + log_det + mahalanobis)
Example #27
0
    def _l2_norm(self, flat_list):
        """Computes one L2-like norm across every tensor in `flat_list`.

        Args:
          flat_list: Collection of tensors as a flat list (rather than, e.g.,
              a tree).

        Returns:
          A scalar value computed as if all the tensors in `flat_list` were
          joined and flattened into a single vector, and then the L2 norm of
          that vector was calculated.
        """
        if fastmath.is_backend(fastmath.Backend.JAX):
            sum_of_squares = sum(jnp.vdot(t, t) for t in flat_list)
        else:  # TODO(lukaszkaiser): add vdot to TF-numpy
            sum_of_squares = sum(jnp.sum(t * t) for t in flat_list)
        return jnp.sqrt(sum_of_squares)
Example #28
0
def TripletLossFn(v1, v2, margin=0.25):
    """Computes a triplet loss with closest-negative and mean-negative terms.

    Row `i` of `v1` and row `i` of `v2` form a positive pair; every other row
    of `v2` is a negative for row `i` of `v1`. The pairwise scores are plain
    dot products, so they are cosine similarities only if the rows are
    unit-normalized.

    Args:
        v1 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q1.
        v2 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q2.
        margin (float, optional): Desired margin. Defaults to 0.25.

    Returns:
        jax.interpreters.xla.DeviceArray: Triplet Loss.
    """
    # Pairwise scores: entry (i, j) is the dot product of v1[i] and v2[j].
    scores = fastnp.dot(v1, fastnp.transpose(v2))
    batch_size = len(scores)
    # The diagonal holds the scores of the positive (duplicate) pairs.
    positive = fastnp.diagonal(scores)
    identity = fastnp.eye(batch_size)
    # Subtract 2.0 on the diagonal so that, for scores in [-1, 1], a positive
    # can never be picked as the closest negative (subtracting only 1.0 still
    # lets it win against negatives below `positive - 1`).
    negative_without_positive = scores - 2.0 * identity
    # Row-wise maximum: the hardest negative for each row.
    closest_negative = negative_without_positive.max(axis=1)
    # Zero the diagonal so that positives do not enter the mean of negatives.
    negative_zero_on_duplicate = (1.0 - identity) * scores
    mean_negative = fastnp.sum(negative_zero_on_duplicate,
                               axis=1) / (batch_size - 1)
    # Hinge on the hardest negative.
    triplet_loss1 = fastnp.maximum((margin - positive + closest_negative), 0.0)
    # Hinge on the mean negative.
    triplet_loss2 = fastnp.maximum((margin - positive + mean_negative), 0.0)
    # Add the two losses together and average over the batch.
    triplet_loss = fastnp.mean(triplet_loss1 + triplet_loss2)

    return triplet_loss
Example #29
0
        def LossInput(dist_inputs, actions, q_values, act_log_probs, mask):  # pylint: disable=invalid-name
            """Calculates action log probabilities and preprocessed advantages.

            Returns:
              A tuple `(log_probs, advantages, act_log_probs, mask)`, where the
              last two are the inputs with their first two axes swapped.
            """
            # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
            q_values, mask, actions, act_log_probs = (
                jnp.swapaxes(tensor, 0, 1)
                for tensor in (q_values, mask, actions, act_log_probs))

            # TODO(pkozakowski,lukaszkaiser): Try max here, or reweighting?
            if self._sample_all_discrete_actions:
                action_probs = jnp.exp(act_log_probs)
                baseline = jnp.sum(q_values * action_probs, axis=0)
            else:
                baseline = jnp.mean(q_values, axis=0)
            # `baseline` broadcasts over the leading n_samples axis.
            advantages = self._preprocess_advantages(q_values - baseline)

            # Repeat the distribution inputs once per sample, then score the
            # sampled actions.
            broadcast_shape = (self._q_value_n_samples, ) + dist_inputs.shape
            dist_inputs = jnp.broadcast_to(dist_inputs, broadcast_shape)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            return (log_probs, advantages, act_log_probs, mask)
Example #30
0
    def policy_batch(self, trajectory_batch, shape_only=False):
        """Builds a policy training batch from a trajectory batch.

        Args:
          trajectory_batch: trax.rl.task.TimeStepBatch with a batch of
            trajectory slices. Elements should have shape
            (batch_size, seq_len, ...).
          shape_only: Whether to return dummy zero arrays of correct shape.
            Useful for initializing models.

        Returns:
          Triple (observations, actions, weights), where weights are the
          advantage-based weights for the policy loss, multiplied by the mask
          and divided by the mask sum. Shapes:
          - observations: (batch_size, seq_len) + observation_shape
          - actions: (batch_size, seq_len) + action_shape
          - weights: (batch_size, seq_len)
        """
        advantages = self.calculate_advantages(
            trajectory_batch, shape_only=shape_only)
        observations, actions, mask = self.trim_batch(
            trajectory_batch, advantages)
        masked_weights = self.calculate_weights(advantages) * mask
        weights = masked_weights / jnp.sum(mask)
        return (observations, actions, weights)