Example #1
0
def unbatch_nested_tensors(tensors, specs=None):
    """Strip the leading batch dimension from a nest of tensors.

    When `specs` is None, axis 0 of every tensor is squeezed unconditionally.
    Otherwise the tensors are flattened and paired with the shapes derived
    from `specs`; axis 0 is squeezed only for tensors whose rank is exactly
    one more than the rank of the matching shape, and all other tensors are
    passed through untouched.

    Args:
      tensors: Nested list/tuple or dict of batched Tensors.
      specs: Optional nested list/tuple or dict of TensorSpecs describing the
        shape of the non-batched Tensors.

    Returns:
      A nest with the same structure as `tensors` holding the non-batched
      version of each tensor.

    Raises:
      ValueError: if the tensors and specs have incompatible dimensions or
        shapes.
    """

    def _drop_batch_axis(batched):
        return composite.squeeze(batched, 0)

    if specs is None:
        return tf.nest.map_structure(_drop_batch_axis, tensors)

    flat_tensors, flat_shapes = _flatten_and_check_shape_nested_tensors(
        tensors, specs)

    # A tensor counts as batched only when it carries exactly one more
    # dimension than its spec'd shape.
    squeezed = [
        _drop_batch_axis(item) if item.shape.rank == spec_shape.rank + 1
        else item
        for item, spec_shape in zip(flat_tensors, flat_shapes)
    ]
    return tf.nest.pack_sequence_as(tensors, squeezed)
Example #2
0
def experience_to_transitions(experience, squeeze_time_dim):
  """Split `experience` into time steps, policy steps and next time steps.

  Args:
    experience: The value handed to `to_transition`.
    squeeze_time_dim: When truthy, axis 1 of every element of the resulting
      transition nest is squeezed with `composite.squeeze`.

  Returns:
    A 3-tuple `(time_steps, policy_steps, next_time_steps)`.
  """

  def _squeeze_axis_one(element):
    return composite.squeeze(element, 1)

  result = to_transition(experience)
  if squeeze_time_dim:
    result = tf.nest.map_structure(_squeeze_axis_one, result)

  # Unpacking also enforces that the transition has exactly three parts.
  current_steps, action_steps, following_steps = result
  return current_steps, action_steps, following_steps
Example #3
0
    def _experience_to_transitions(self, experience):
        """Split `experience` into time steps, actions and next time steps.

        The experience is converted with `trajectory.to_transition`.  When the
        Q-network's `state_spec` is falsy (presumably a stateless network --
        confirm against the network class), axis 1 of every element is
        squeezed.

        Args:
          experience: The value handed to `trajectory.to_transition`.

        Returns:
          A 3-tuple `(time_steps, actions, next_time_steps)` where `actions`
          is the `action` attribute of the policy steps.
        """
        steps = trajectory.to_transition(experience)

        has_network_state = bool(self._q_network.state_spec)
        if not has_network_state:
            steps = tf.nest.map_structure(
                lambda element: composite.squeeze(element, 1), steps)

        current_steps, action_steps, following_steps = steps
        return current_steps, action_steps.action, following_steps
Example #4
0
def experience_to_transitions(experience: Trajectory,
                              squeeze_time_dim: bool) -> Transition:
    """Convert `experience` into a transition.

    Args:
      experience: The trajectory handed to `to_transition`.
      squeeze_time_dim: When truthy, axis 1 of every element of the
        transition is squeezed with `composite.squeeze`.

    Returns:
      The result of `to_transition`, with axis 1 squeezed if requested.
    """
    converted = to_transition(experience)
    if not squeeze_time_dim:
        return converted

    return tf.nest.map_structure(
        lambda element: composite.squeeze(element, 1), converted)
Example #5
0
def experience_to_transitions(
    experience: Trajectory, squeeze_time_dim: bool
) -> Tuple[ts.TimeStep, policy_step.PolicyStep, ts.TimeStep]:
    """Split `experience` into time steps, policy steps and next time steps.

    Args:
      experience: The trajectory handed to `to_transition`.
      squeeze_time_dim: When truthy, axis 1 of every element of the
        transition is squeezed with `composite.squeeze`.

    Returns:
      A 3-tuple `(time_steps, policy_steps, next_time_steps)`.
    """

    def _squeeze_axis_one(element):
        return composite.squeeze(element, 1)

    result = to_transition(experience)
    if squeeze_time_dim:
        result = tf.nest.map_structure(_squeeze_axis_one, result)

    # Unpacking also enforces that the transition has exactly three parts.
    current_steps, action_steps, following_steps = result
    return current_steps, action_steps, following_steps
  def __call__(self, value: typing.Any):
    """Converts `value` to a validated, pruned `Transition`.

    - A `Transition` input is only validated and pruned.
    - A `Trajectory` input is first validated against the converter's
      trajectory spec, then converted with `trajectory.to_transition`.
    - With `squeeze_time_dim = True` a `Trajectory` must have tensors with
      `[B, T=2]` outer dims; after conversion the singleton time axis
      (axis 1) is squeezed away.

    Args:
      value: A `Trajectory` or `Transition` object to convert.

    Returns:
      A validated and pruned `Transition`.  If `squeeze_time_dim = True`,
      the resulting `Transition` has tensors with shape `[B, ...]`.  Otherwise,
      the tensors will have shape `[B, T - 1, ...]`.

    Raises:
      TypeError: If `value` is not one of `Trajectory` or `Transition`, or if
        it has a structure that doesn't match the converter's spec.
      ValueError: If `value` has structure that doesn't match the converter's
        spec, or if `squeeze_time_dim=True` and `value` is a `Trajectory`
        with a time dimension having value other than `T=2`.
    """
    if not isinstance(value, trajectory.Transition):
      if not isinstance(value, trajectory.Trajectory):
        raise TypeError('Input type not supported: {}'.format(value))

      # Squeezing is only valid for exactly two frames; otherwise any
      # sequence length is accepted.
      expected_length = 2 if self._squeeze_time_dim else None
      _validate_trajectory(
          value,
          self._data_context.trajectory_spec,
          sequence_length=expected_length)

      value = trajectory.to_transition(value)
      if self._squeeze_time_dim:
        # Drop the time axis, which is now of size one.
        value = tf.nest.map_structure(
            lambda element: composite.squeeze(element, axis=1), value)

    self._validate_transition(value)
    return nest_utils.prune_extra_keys(
        self._data_context.transition_spec, value)
Example #7
0
def to_n_step_transition(
    trajectory: Trajectory,
    gamma: types.Float
) -> Transition:
  """Builds an n-step transition out of a trajectory with `T=N + 1` frames.

  **NOTE** Tensors of `trajectory` are sliced along their *second* (`time`)
  dimension to extract the fields needed for the n-step transition.

  The resulting `next_time_step.{reward, discount}` hold the N-step discounted
  reward and discount:

  ```
  next_time_step.reward = r_t +
                          g^{1} * d_t * r_{t+1} +
                          g^{2} * d_t * d_{t+1} * r_{t+2} +
                          g^{3} * d_t * d_{t+1} * d_{t+2} * r_{t+3} +
                          ...
                          g^{N-1} * d_t * ... * d_{t+N-2} * r_{t+N-1}
  next_time_step.discount = g^{N-1} * d_t * d_{t+1} * ... * d_{t+N-1}
  ```

  Expressed in python:

  ```python
  discount = gamma**(N-1) * reduce_prod(trajectory.discount[:, :-1])
  reward = discounted_return(
      rewards=trajectory.reward[:, :-1],
      discounts=gamma * trajectory.discount[:, :-1])
  ```

  If `trajectory.discount[:, :-1]` is all ones, this reduces to:

  ```python
  next_time_step.discount = (
      gamma**(N-1) * tf.ones_like(trajectory.discount[:, 0]))
  next_time_step.reward = (
      sum_{n=0}^{N-1} gamma**n * trajectory.reward[:, n])
  ```

  Args:
    trajectory: An instance of `Trajectory`. The tensors in Trajectory must have
      shape `[B, T, ...]`.  `discount` is assumed to be a scalar float,
      hence the shape of `trajectory.discount` must be `[B, T]`.
    gamma: A floating point scalar; the discount factor.

  Returns:
    An N-step `Transition` where `N = T - 1`.  `time_step.{reward, discount}`
    are filled with NaN.  The n-step discounted reward and final discount are
    stored in `next_time_step.{reward, discount}`.  All tensors in the
    `Transition` have shape `[B, ...]` (no time dimension).

  Raises:
    ValueError: if `discount.shape.rank != 2`.
    ValueError: if the statically known `discount.shape[1]` is less than 2.
  """
  _validate_rank(trajectory.discount, min_rank=2, max_rank=2)

  # Prefer the static frame count so XLA can be used when the time dimension
  # is fixed; fall back to the dynamic shape otherwise.
  num_frames = (tf.compat.dimension_value(trajectory.discount.shape[1])
                or tf.shape(trajectory.discount)[1])

  known_num_frames = tf.get_static_value(num_frames)
  if known_num_frames in (0, 1):
    raise ValueError(
        'Trajectory frame count must be at least 2, but saw {}.  Shape of '
        'trajectory.discount: {}'.format(known_num_frames,
                                         trajectory.discount.shape))

  num_steps = num_frames - 1

  # The composite helpers are used so that SparseTensor etc. in the
  # observations are handled properly.
  def _take_first_frame(element):
    # Equivalent of element[:, 0].
    leading = composite.slice_to(element, axis=1, end=1)
    return composite.squeeze(leading, axis=1)

  def _take_last_frame(element):
    # Equivalent of element[:, -1].
    trailing = composite.slice_from(element, axis=1, start=-1)
    return composite.squeeze(trailing, axis=1)

  def _nan_like(element):
    return np.nan * tf.ones_like(element)

  first_frame = tf.nest.map_structure(_take_first_frame, trajectory)
  final_frame = tf.nest.map_structure(_take_last_frame, trajectory)

  # The last time index of reward and discount only holds dummy values that
  # pad them to the observation's length, so it is excluded from the return.
  step_rewards = trajectory.reward[:, :-1]
  step_discounts = trajectory.discount[:, :-1]

  policy_steps = policy_step.PolicyStep(
      action=first_frame.action, state=(), info=first_frame.policy_info)

  n_step_reward = value_ops.discounted_return(
      rewards=step_rewards,
      discounts=gamma * step_discounts,
      time_major=False,
      provide_all_returns=False)

  # NOTE: one fewer power of gamma is applied here than there are discounts,
  # so that a learner/update applying its own extra discount (e.g. gamma)
  # does not apply it twice.
  n_step_discount = (
      gamma**(num_steps - 1) * tf.math.reduce_prod(step_discounts, axis=1))

  # Reward and discount of the first time step are unknown, hence NaN.
  time_steps = ts.TimeStep(
      first_frame.step_type,
      reward=tf.nest.map_structure(_nan_like, first_frame.reward),
      discount=_nan_like(first_frame.discount),
      observation=first_frame.observation)
  next_time_steps = ts.TimeStep(
      step_type=final_frame.step_type,
      reward=n_step_reward,
      discount=n_step_discount,
      observation=final_frame.observation)
  return Transition(time_steps, policy_steps, next_time_steps)
Example #8
0
  def __call__(self, value: typing.Any) -> trajectory.Transition:
    """Converts `value` to a validated, pruned `Transition`.

    - A transition-like input is normalized with `_as_tfa_transition` and
      then only validated and pruned.
    - A trajectory-like input is first validated against the converter's
      trajectory spec, then converted with `trajectory.to_transition`.
    - With `squeeze_time_dim = True` a trajectory must have tensors with
      `[B, T=2]` outer dims; after conversion the singleton time axis
      (axis 1) is squeezed away.
    - When the converter was configured to prepend t0, the first frame of
      `time_step.observation` is concatenated in front of
      `next_time_step.observation` along axis 1.

    Args:
      value: A `Trajectory` or `Transition` object to convert.

    Returns:
      A validated and pruned `Transition`.  If `squeeze_time_dim = True`,
      the resulting `Transition` has tensors with shape `[B, ...]`.  Otherwise,
      the tensors will have shape `[B, T - 1, ...]`.

    Raises:
      TypeError: If `value` is not one of `Trajectory` or `Transition`, or if
        it has a structure that doesn't match the converter's spec.
      ValueError: If `value` has structure that doesn't match the converter's
        spec, or if `squeeze_time_dim=True` and `value` is a `Trajectory`
        with a time dimension having value other than `T=2`.
    """
    transition_like = _is_transition_like(value)
    if not transition_like and not _is_trajectory_like(value):
      raise TypeError('Input type not supported: {}'.format(value))

    if transition_like:
      value = _as_tfa_transition(value)
    else:
      # Squeezing is only valid for exactly two frames; otherwise any
      # sequence length is accepted.
      expected_length = 2 if self._squeeze_time_dim else None
      _validate_trajectory(
          value,
          self._data_context.trajectory_spec,
          sequence_length=expected_length)

      value = trajectory.to_transition(value)
      if self._squeeze_time_dim:
        # Drop the time axis, which is now of size one.
        value = tf.nest.map_structure(
            lambda element: composite.squeeze(element, axis=1), value)

    # Only the batch dim remains after squeezing; otherwise batch and time.
    outer_dim_count = 1 if self._squeeze_time_dim else 2
    _validate_transition(
        value, self._data_context.transition_spec, outer_dim_count)

    value = nest_utils.prune_extra_keys(
        self._data_context.transition_spec, value)

    if not self._prepend_t0_to_next_time_step:
      return value

    # Useful with sequential models: it lets the target_q network see all
    # the information, including the first frame.
    def _prepend_first_frame(current_obs, next_obs):
      return tf.concat([current_obs[:, :1, ...], next_obs], axis=1)

    augmented_next_step = value.next_time_step._replace(
        observation=tf.nest.map_structure(
            _prepend_first_frame,
            value.time_step.observation,
            value.next_time_step.observation))
    return value._replace(next_time_step=augmented_next_step)