예제 #1
0
파일: metrics.py 프로젝트: wbh-lab/trax
 def f(model_output, target_category):  # pylint: disable=invalid-name
   """Returns the binary cross-entropy of `model_output` vs. `target_category`.

   Args:
     model_output: Predicted probabilities, same shape as `target_category`.
         Values must lie in (0, 1) for `log(p)` and `log(1 - p)` below to be
         finite.
     target_category: Targets in [0, 1] (presumably 0/1 labels -- confirm
         against callers), same shape as `model_output`.

   Returns:
     Scalar: minus the summed per-element log-likelihood, divided by the size
     of axis 0 (the batch size).
   """
   shapes.assert_same_shape(model_output, target_category)
   batch_size = model_output.shape[0]
   # Element-wise products summed over every axis. The former
   # `dot(transpose(t), log(p))` form is only a scalar for 1-D or (batch, 1)
   # inputs; for (batch, k) with k > 1 it produced a (k, k) matrix whose
   # off-diagonal entries pair unrelated columns.
   log_likelihood = (target_category * jnp.log(model_output) +
                     (1 - target_category) * jnp.log(1 - model_output))
   return -jnp.sum(log_likelihood) / batch_size
예제 #2
0
 def f(model_output, targets, weights):  # pylint: disable=invalid-name
     """Weighted smooth-L1 (Huber with threshold 1) error, normalized by weight.

     Args:
       model_output: Predicted values.
       targets: Target values, same shape as `model_output`.
       weights: Per-element weights, same shape as `model_output`.

     Returns:
       sum(weights * smooth_l1(model_output - targets)) / sum(weights).
     """
     shapes.assert_same_shape(model_output, targets)
     shapes.assert_same_shape(model_output, weights)
     abs_err = jnp.abs(model_output - targets)
     # Quadratic below an absolute error of 1, linear (shifted by 0.5) above,
     # so the two pieces meet continuously at |err| == 1.
     quadratic_part = 0.5 * abs_err**2
     linear_part = abs_err - 0.5
     per_element = jnp.where(abs_err < 1, quadratic_part, linear_part)
     return jnp.sum(weights * per_element) / jnp.sum(weights)
예제 #3
0
    def policy_batches_stream(self):
        """Use the RLTask self._task to create inputs to the policy model.

        Yields:
          Tuples `(observations, actions, q_values, act_log_probs, mask)`,
          where `q_values`, `act_log_probs` and `mask` all have shape
          (batch_size, n_samples, length).

        Raises:
          ValueError: If the Q-values produced by the value model are not
            3-dimensional, or their first two dimensions do not equal the
            policy batch size and the configured number of Q-value samples.
        """
        # For now TD-0 estimation of the value. TODO(pkozakowski): Support others?
        for np_trajectory in self._task.trajectory_batch_stream(
                self._policy_batch_size,
                epochs=self._replay_epochs,
                max_slice_length=self._max_slice_length,
                include_final_state=False,
        ):
            (q_values, actions,
             act_log_probs) = self._run_value_model(np_trajectory.observations,
                                                    np_trajectory.dist_inputs)
            shapes.assert_same_shape(q_values, act_log_probs)

            # q_values shape: (batch_size, n_samples, length)
            if len(q_values.shape) != 3:
                raise ValueError(
                    'Q-values are expected to have shape [batch_size, ' +
                    'n_samples, length], got: %s' % str(q_values.shape))
            if q_values.shape[1] != self._q_value_n_samples:
                raise ValueError(
                    'Q-values dimension 1 should = n_samples, %d != %d' %
                    (q_values.shape[1], self._q_value_n_samples))
            if q_values.shape[0] != self._policy_batch_size:
                # Report dimension 0 -- the one actually being compared. The
                # message previously printed dimension 1 by mistake.
                raise ValueError(
                    'Q-values dimension 0 should = policy batch size, ' +
                    '%d != %d' % (q_values.shape[0], self._policy_batch_size))

            # Insert a size-1 axis at position 1 (the n_samples axis) and
            # broadcast, so the trajectory mask lines up with q_values.
            mask = np_trajectory.mask
            mask = np.reshape(mask, [mask.shape[0], 1] + list(mask.shape[1:]))
            mask = jnp.broadcast_to(mask, q_values.shape)
            shapes.assert_same_shape(mask, q_values)
            yield (np_trajectory.observations, actions, q_values,
                   act_log_probs, mask)
예제 #4
0
 def f(model_output, targets, weights):  # pylint: disable=invalid-name
     """Fraction of sequences in which every non-padding position is correct.

     Args:
       model_output: Scores whose last axis indexes classes; the argmax over
           that axis is taken as the prediction.
       targets: Class-index targets, same shape as the predictions.
       weights: Positions where `weights == 0` are treated as padding and
           always count as correct.

     Returns:
       The average, over all sequences (reduction along the last axis of the
       predictions), of whether the whole sequence is correct.
     """
     predictions = jnp.argmax(model_output, axis=-1)
     shapes.assert_same_shape(predictions, targets)
     matches = jnp.equal(predictions, targets)
     is_padding = jnp.equal(weights, 0)
     # A position "passes" if it is predicted correctly or is padding.
     passes = jnp.logical_or(matches, is_padding)
     whole_sequence_ok = jnp.all(passes, axis=-1)
     return jnp.average(whole_sequence_ok)
예제 #5
0
파일: training.py 프로젝트: yliu45/trax
 def f(values, actions, returns, mask):
   """Returns the masked mean of the values of the actions actually taken.

   Args:
     values: Per-action values indexed as `values[batch, time, action]`. The
         time axis may be longer than that of `returns`; extra steps are
         dropped.
     actions: Integer actions of shape (batch, time); each selects one entry
         along the last axis of `values`.
     returns: Array with the same shape as the selected values. Only its
         `shape[1]` is used here, to decide how many time steps to keep.
     mask: Weights with the same shape as the selected values. Zero entries
         (e.g. padding) do not contribute to the mean.

   Returns:
     sum(mask * selected_values) / sum(mask).
   """
   ind_0, ind_1 = np.indices(actions.shape)
   # Use the length of `returns` to drop the superfluous trailing slice of
   # `values`; value_batches_stream performs the analogous truncation.
   length = returns.shape[1]
   values = values[:, :length, :]
   selected_values = values[ind_0, ind_1, actions]
   shapes.assert_same_shape(selected_values, returns)
   shapes.assert_same_shape(selected_values, mask)
   # Apply the mask in the numerator too; otherwise masked-out positions
   # inflate the sum while being excluded from the denominator.
   return jnp.sum(selected_values * mask) / jnp.sum(mask)
예제 #6
0
파일: metrics.py 프로젝트: ppvalluri09/trax
    def f(model_output, targets, weights):  # pylint: disable=invalid-name
        """Returns the weight-normalized sum of squared `model_output - targets`.

        Computes sum(weights * (model_output - targets)**2) / sum(weights).
        Despite the "L2" name, no square root is taken.

        Args:
          model_output: Output from one batch, treated as an unanalyzed tensor.
          targets: Tensor of same shape as `model_output` containing
              element-wise target values.
          weights: Tensor of same shape as `model_output` and `targets`.
        """
        shapes.assert_same_shape(model_output, targets)
        shapes.assert_same_shape(targets, weights)
        squared_error = (model_output - targets)**2
        weighted_error = weights * squared_error
        total_weight = jnp.sum(weights)
        return jnp.sum(weighted_error) / total_weight
예제 #7
0
파일: training.py 프로젝트: yliu45/trax
 def f(values, actions, returns, mask):
   """Value-regression loss on the values of the actions actually taken.

   Truncates `values` along axis 1 to the length of `returns`, picks
   `values[b, t, actions[b, t]]`, and scores that selection against
   `returns` with a mask-weighted smooth-L1 loss when `self._smoothl1loss`
   is truthy, or an L2 loss otherwise.
   """
   batch_idx, time_idx = np.indices(actions.shape)
   # The length of `returns` decides how much of `values` to keep; the
   # superfluous trailing slice is dropped (value_batches_stream does the
   # same).
   values = values[:, :returns.shape[1], :]
   selected_values = values[batch_idx, time_idx, actions]
   shapes.assert_same_shape(selected_values, returns)
   shapes.assert_same_shape(selected_values, mask)
   loss_layer = tl.SmoothL1Loss() if self._smoothl1loss else tl.L2Loss()
   return loss_layer.forward((selected_values, returns, mask))
예제 #8
0
파일: metrics.py 프로젝트: yaoshuyin/trax
    def f(model_output, targets, weights):  # pylint: disable=invalid-name
        """Returns weighted sum-of-squared-errors for `model_output` vs. `targets`.

        The weighted sum is divided by the total weight, so the result is the
        weighted mean of the squared errors.

        Args:
          model_output: Output from one batch, typically a 2- or 3-d array of
              float-valued elements.
          targets: Tensor of same shape as `model_output` containing
              element-wise target values.
          weights: Tensor of same shape as `model_output` and `targets`,
              containing element-wise weight values.
        """
        shapes.assert_same_shape(model_output, targets)
        shapes.assert_same_shape(targets, weights)
        residual = model_output - targets
        weighted_sse = jnp.sum(weights * residual**2)
        return weighted_sse / jnp.sum(weights)
예제 #9
0
파일: metrics.py 프로젝트: yaoshuyin/trax
    def smoothl1loss(model_output, targets, weights):  # pylint: disable=invalid-name
        r"""Returns weighted smooth L1 norm of `model_output - targets`.

        The smooth L1 loss, also known as the Huber loss, is defined as:
        .. math::
            z_i =
            \begin{cases}
            0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
            |x_i - y_i| - 0.5, & \text{otherwise }
            \end{cases}

        The per-element :math:`z_i` are multiplied by `weights`, summed, and
        divided by the sum of `weights`.

        Args:
          model_output: Output from one batch, treated as an unanalyzed tensor.
          targets: Tensor of same shape as `model_output` containing
              element-wise target values.
          weights: Tensor of same shape as `model_output` and `targets`,
              containing element-wise weight values.
        """
        shapes.assert_same_shape(model_output, targets)
        shapes.assert_same_shape(targets, weights)
        abs_diff = jnp.abs(model_output - targets)
        huber = jnp.where(abs_diff < 1, 0.5 * abs_diff**2, abs_diff - 0.5)
        shapes.assert_same_shape(huber, weights)
        return jnp.sum(weights * huber) / jnp.sum(weights)
예제 #10
0
  def policy_batches_stream(self):
    """Use the RLTask self._task to create inputs to the policy model.

    Advantages are the Q-values minus their mean over the n_samples axis.

    Yields:
      Tuples `(observations, actions, advantages, mask, mask)`, where
      `advantages` and `mask` have batch_size as their leading axis. The mask
      appears twice, as the last two tuple elements.

    Raises:
      ValueError: If the sample-averaged values are not 2-dimensional, or
        their first dimension differs from the policy batch size.
    """
    # For now TD-0 estimation of the value. TODO(pkozakowski): Support others?
    trajectories = self._task.trajectory_batch_stream(
        self._policy_batch_size,
        epochs=self._replay_epochs,
        max_slice_length=self._max_slice_length,
        include_final_state=False,
    )
    for traj in trajectories:
      q_values, actions = self._run_value_model(
          traj.observations, traj.dist_inputs)
      # TODO(pkozakowski): Try max here.
      values = jnp.mean(q_values, axis=0)

      if len(values.shape) != 2:
        raise ValueError('Values are expected to have shape ' +
                         '[batch_size, length], got: %s' % str(values.shape))
      if values.shape[0] != self._policy_batch_size:
        raise ValueError('Values first dimension should = policy batch size, ' +
                         '%d != %d' %(values.shape[0], self._policy_batch_size))

      # Broadcasting subtracts the per-position mean from every sample:
      # (n_samples, batch_size, length) - (batch_size, length).
      advantages = q_values - values
      mask = jnp.broadcast_to(traj.mask, advantages.shape)
      expected_shape = (self._q_value_n_samples,) + values.shape
      shapes.assert_shape_equals(advantages, expected_shape)
      shapes.assert_same_shape(mask, advantages)

      # Put batch_size first so the input is split between accelerators
      # along the batch axis rather than the n_samples axis.
      batch_major_advantages = jnp.swapaxes(advantages, 0, 1)
      batch_major_mask = jnp.swapaxes(mask, 0, 1)
      yield (traj.observations, actions, batch_major_advantages,
             batch_major_mask, batch_major_mask)
예제 #11
0
파일: metrics.py 프로젝트: huyunzhi/trax
def L2Loss(inputs):
    """Masked mean squared error.

    Args:
      inputs: Triple `(y_hat, y, mask)` of equally shaped arrays: predictions,
          targets, and per-element weights.

    Returns:
      sum(mask * (y_hat - y)**2) / sum(mask).
    """
    predictions, targets, weights = inputs
    shapes.assert_same_shape(predictions, targets)
    shapes.assert_same_shape(targets, weights)
    weighted_sq_err = weights * (predictions - targets)**2
    return np.sum(weighted_sq_err) / np.sum(weights)
예제 #12
0
 def f(model_output, targets):  # pylint: disable=invalid-name
     """Unweighted accuracy: share of positions where the argmax hits the target.

     Args:
       model_output: Scores whose last axis indexes classes.
       targets: Class-index targets, same shape as the argmax predictions.

     Returns:
       (number of correct positions) / (total number of positions).
     """
     predictions = jnp.argmax(model_output, axis=-1)
     shapes.assert_same_shape(predictions, targets)
     hits = jnp.sum(jnp.equal(predictions, targets))
     return hits / predictions.size
예제 #13
0
 def RawL2Loss(inputs, **unused_kwargs):
     """Unweighted mean squared error between predictions and targets.

     Args:
       inputs: Pair `(y_hat, y)` of equally shaped arrays.
       **unused_kwargs: Accepted and ignored.

     Returns:
       mean((y_hat - y)**2).
     """
     predictions, targets = inputs
     shapes.assert_same_shape(predictions, targets)
     residual = predictions - targets
     return np.mean(residual**2)
예제 #14
0
 def f(model_output, targets, weights):  # pylint: disable=invalid-name
     """Weighted mean squared error.

     Args:
       model_output: Predicted values.
       targets: Target values, same shape as `model_output`.
       weights: Per-element weights, same shape as `model_output`.

     Returns:
       sum(weights * (model_output - targets)**2) / sum(weights).
     """
     shapes.assert_same_shape(model_output, targets)
     shapes.assert_same_shape(model_output, weights)
     squared_errors = (model_output - targets)**2
     return jnp.sum(weights * squared_errors) / jnp.sum(weights)
예제 #15
0
 def MaskedL2Loss(inputs, **unused_kwargs):
     """Mean squared error weighted by a mask.

     Args:
       inputs: Triple `(y_hat, y, mask)` of equally shaped arrays.
       **unused_kwargs: Accepted and ignored.

     Returns:
       sum(mask * (y_hat - y)**2) / sum(mask).
     """
     predictions, targets, mask = inputs
     shapes.assert_same_shape(predictions, targets)
     shapes.assert_same_shape(targets, mask)
     masked_sq_err = mask * (predictions - targets)**2
     return np.sum(masked_sq_err) / np.sum(mask)
예제 #16
0
 def f(model_output, targets, weights):  # pylint: disable=invalid-name
     """Weighted accuracy of the argmax predictions.

     Args:
       model_output: Scores whose last axis indexes classes.
       targets: Class-index targets, same shape as the argmax predictions.
       weights: Per-position weights, broadcast against the predictions.

     Returns:
       sum(weights * [prediction == target]) / sum(weights).
     """
     predictions = jnp.argmax(model_output, axis=-1)
     shapes.assert_same_shape(predictions, targets)
     is_correct = jnp.equal(predictions, targets)
     weighted_hits = jnp.sum(is_correct * weights)
     return weighted_hits / jnp.sum(weights)
예제 #17
0
파일: metrics.py 프로젝트: zhaoqiuye/trax
 def f(y_hat, y, mask):  # pylint: disable=invalid-name
     """Mask-weighted mean squared error of `y_hat` against `y`.

     Args:
       y_hat: Predictions.
       y: Targets, same shape as `y_hat`.
       mask: Per-element weights, same shape as `y`.

     Returns:
       sum(mask * (y_hat - y)**2) / sum(mask).
     """
     shapes.assert_same_shape(y_hat, y)
     shapes.assert_same_shape(y, mask)
     residual = y_hat - y
     return np.sum(mask * residual**2) / np.sum(mask)