Esempio n. 1
0
    def distribution(self, time_step, policy_state=()):
        """Returns the distribution over next actions for `time_step`.

        Args:
          time_step: A `TimeStep` tuple matching `time_step_spec()`.
          policy_state: A Tensor, or a nested dict, list or tuple of Tensors,
            holding the policy state from the previous call.

        Returns:
          A `PolicyStep` named tuple with fields:

            `action`: A tf.distribution over the next actions.
            `state`: The policy state to pass into the next call.
            `info`: Optional side information, e.g. action log probabilities.
        """
        # Drop dict entries the specs do not mention before the strict
        # structure checks below.
        pruned_time_step = nest_utils.prune_extra_keys(
            self._time_step_spec, time_step)
        pruned_state = nest_utils.prune_extra_keys(
            self._policy_state_spec, policy_state)
        tf.nest.assert_same_structure(pruned_time_step, self._time_step_spec)
        tf.nest.assert_same_structure(pruned_state, self._policy_state_spec)

        if self._automatic_state_reset:
            pruned_state = self._maybe_reset_state(
                pruned_time_step, pruned_state)

        result = self._distribution(
            time_step=pruned_time_step, policy_state=pruned_state)

        if self.emit_log_probability:
            # Zero log-probabilities are filled in only so that `info`
            # matches the `info_spec` set up in the constructor.
            zero_log_probs = tf.nest.map_structure(
                lambda _: tf.constant(0., dtype=tf.float32),
                policy_step.get_log_probability(self._info_spec))
            result = result._replace(
                info=policy_step.set_log_probability(
                    result.info, zero_log_probs))

        tf.nest.assert_same_structure(result, self._policy_step_spec)
        return result
Esempio n. 2
0
  def train(self, experience, weights=None, **kwargs):
    """Trains the agent.

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.training_data_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.train_sequence_length` if that
        property is not `None`.
      weights: (optional).  A `Tensor`, either `0-D` or shaped `[batch]`,
        containing weights to be used when calculating the total train loss.
        Weights are typically multiplied elementwise against the per-batch loss,
        but the implementation is up to the Agent.
      **kwargs: Any additional data as declared by `self.train_argspec`.

    Returns:
      A `LossInfo` loss tuple containing loss and info tensors.
      - In eager mode, the loss values are first calculated, then a train step
        is performed before they are returned.
      - In graph mode, executing any or all of the loss tensors
        will first calculate the loss value(s), then perform a train step,
        and return the pre-train-step `LossInfo`.

    Raises:
      TypeError: If experience is not type `Trajectory`.  Or if experience
        does not match `self.training_data_spec` structure types.  Also raised
        if the underlying train implementation returns something that is not a
        `LossInfo`.
      ValueError: If experience tensors' time axes are not compatible with
        `self.train_sequence_length`.  Or if experience does not match
        `self.training_data_spec` structure.
      ValueError: If the user does not pass `**kwargs` matching
        `self.train_argspec`.
      RuntimeError: If the class was not initialized properly (`super.__init__`
        was not called).
    """
    # `_train_fn` is only needed (and only checked for) when functions are
    # enabled; otherwise `_train` is called directly below.
    if self._enable_functions and getattr(self, "_train_fn", None) is None:
      raise RuntimeError(
          "Cannot find _train_fn.  Did %s.__init__ call super?"
          % type(self).__name__)

    self._check_trajectory_dimensions(experience)
    self._check_train_argspec(kwargs)

    # Even though the checks above prune dict keys, we want them to see
    # the non-pruned versions to provide clearer error messages.
    # However, from here on out we want to remove dict entries that aren't
    # requested in the spec.
    experience = nest_utils.prune_extra_keys(
        self.training_data_spec, experience)
    kwargs = nest_utils.prune_extra_keys(self.train_argspec, kwargs)

    if self._enable_functions:
      loss_info = self._train_fn(
          experience=experience, weights=weights, **kwargs)
    else:
      loss_info = self._train(experience=experience, weights=weights, **kwargs)

    # Enforce the return-type contract of the subclass train implementation.
    if not isinstance(loss_info, LossInfo):
      raise TypeError(
          "loss_info is not a subclass of LossInfo: {}".format(loss_info))
    return loss_info
Esempio n. 3
0
 def testInvalidWide(self):
   # Each (spec, value, expected) triple pairs a value that does not line up
   # with its spec; the value is expected to come back unchanged.
   cases = (
       (None, {'a': 1}, {'a': 1}),
       ({'a': 1}, {}, {}),
       ({'a': 1}, {'c': 'c'}, {'c': 'c'}),
       ([], ['a'], ['a']),
       ([{}, {}], [{'a': 1}], [{'a': 1}]),
   )
   for spec, value, expected in cases:
     self.assertEqual(nest_utils.prune_extra_keys(spec, value), expected)
Esempio n. 4
0
    def distribution(
        self, time_step: ts.TimeStep, policy_state: types.NestedTensor = ()
    ) -> policy_step.PolicyStep:
        """Returns the distribution over next actions for `time_step`.

        Args:
          time_step: A `TimeStep` tuple matching `time_step_spec()`.
          policy_state: A Tensor, or a nested dict, list or tuple of Tensors,
            holding the policy state produced by the previous call.

        Returns:
          A `PolicyStep` named tuple with fields:

            `action`: A tf.distribution over the next actions.
            `state`: The policy state to feed into the next call.
            `info`: Optional side information such as action log probabilities.

        Raises:
          ValueError or TypeError: If `validate_args is True` and inputs or
            outputs do not match `time_step_spec`, `policy_state_spec`,
            or `policy_step_spec`.
        """
        if self._validate_args:
            # Strip dict entries the specs do not mention, then require the
            # remaining structures to line up with the specs exactly.
            time_step = nest_utils.prune_extra_keys(
                self._time_step_spec, time_step)
            policy_state = nest_utils.prune_extra_keys(
                self._policy_state_spec, policy_state)
            nest_utils.assert_same_structure(
                time_step, self._time_step_spec,
                message='time_step and time_step_spec structures do not match')
            nest_utils.assert_same_structure(
                policy_state, self._policy_state_spec,
                message=('policy_state and policy_state_spec structures '
                         'do not match'))

        if self._automatic_state_reset:
            policy_state = self._maybe_reset_state(time_step, policy_state)

        dist_step = self._distribution(
            time_step=time_step, policy_state=policy_state)

        if self.emit_log_probability:
            # Placeholder zero log-probabilities exist only so that `info`
            # matches the `info_spec` set up in the constructor.
            placeholder_log_probs = tf.nest.map_structure(
                lambda _: tf.constant(0., dtype=tf.float32),
                policy_step.get_log_probability(self._info_spec))
            dist_step = dist_step._replace(
                info=policy_step.set_log_probability(
                    dist_step.info, placeholder_log_probs))

        if self._validate_args:
            nest_utils.assert_same_structure(
                dist_step, self._policy_step_spec,
                message=('distribution output and policy_step_spec '
                         'structures do not match'))
        return dist_step
Esempio n. 5
0
  def testPruneExtraKeys(self):
    nested_spec = {'a': {'aa': 1, 'ab': 2}, 'b': {'ba': 1}}
    nested_value = {'a': {'aa': 'aa', 'ab': 'ab', 'ac': 'ac'},
                    'b': {'ba': 'ba', 'bb': 'bb'},
                    'c': 'c'}
    # (spec, value, expected): entries of `value` whose keys are absent from
    # `spec` are pruned, at every nesting level.
    cases = [
        ({}, {'a': 1}, {}),
        ((), {'a': 1}, ()),
        ({'a': 1}, {'a': 'a'}, {'a': 'a'}),
        ({'a': 1}, {'a': 'a', 'b': 2}, {'a': 'a'}),
        ([{'a': 1}], [{'a': 'a', 'b': 2}], [{'a': 'a'}]),
        ({'a': (), 'b': None}, {'a': 1, 'b': 2}, {'a': (), 'b': 2}),
        (nested_spec, nested_value,
         {'a': {'aa': 'aa', 'ab': 'ab'}, 'b': {'ba': 'ba'}}),
        # Values wrapped with `DictWrapper`.
        ({'a': ()},
         DictWrapper({'a': DictWrapper({'b': None})}),
         {'a': ()}),
        ({'a': 1, 'c': 2},
         DictWrapper({'a': DictWrapper({'b': None})}),
         {'a': {'b': None}}),
    ]
    for spec, value, expected in cases:
      self.assertEqual(nest_utils.prune_extra_keys(spec, value), expected)
Esempio n. 6
0
 def testPruneExtraKeys(self):
   # (spec, value, expected): keys of `value` that `spec` does not name are
   # dropped, including inside nested dicts and lists.
   cases = [
       ({}, {'a': 1}, {}),
       ({'a': 1}, {'a': 'a'}, {'a': 'a'}),
       ({'a': 1}, {'a': 'a', 'b': 2}, {'a': 'a'}),
       ([{'a': 1}], [{'a': 'a', 'b': 2}], [{'a': 'a'}]),
       ({'a': {'aa': 1, 'ab': 2}, 'b': {'ba': 1}},
        {'a': {'aa': 'aa', 'ab': 'ab', 'ac': 'ac'},
         'b': {'ba': 'ba', 'bb': 'bb'},
         'c': 'c'},
        {'a': {'aa': 'aa', 'ab': 'ab'}, 'b': {'ba': 'ba'}}),
   ]
   for spec, value, expected in cases:
     self.assertEqual(nest_utils.prune_extra_keys(spec, value), expected)
  def testSubtypesOfListAndDict(self):

    class A(collections.namedtuple('A', ('a', 'b'))):
      pass

    # pylint: disable=invalid-name
    DictWrapper = data_structures.wrap_or_unwrap
    TupleWrapper = data_structures.wrap_or_unwrap
    # pylint: enable=invalid-name

    # The spec is built from wrapped list/dict/tuple containers; the value
    # uses plain Python containers that carry extra keys to be pruned.
    spec = [
        data_structures.ListWrapper([None, DictWrapper({'a': 3, 'b': 4})]),
        None,
        TupleWrapper((DictWrapper({'g': 5}),)),
        TupleWrapper(A(None, DictWrapper({'h': 6}))),
    ]
    value = [
        ['x', {'a': 'a', 'b': 'b', 'c': 'c'}],
        'd',
        ({'g': 'g', 'gg': 'gg'},),
        A(None, {'h': 'h', 'hh': 'hh'}),
    ]
    # The pruned result is compared against wrapped containers.
    expected = [
        data_structures.ListWrapper(['x', DictWrapper({'a': 'a', 'b': 'b'})]),
        'd',
        TupleWrapper((DictWrapper({'g': 'g'}),)),
        TupleWrapper(A(None, DictWrapper({'h': 'h'}))),
    ]
    self.assertEqual(nest_utils.prune_extra_keys(spec, value), expected)
Esempio n. 8
0
  def __call__(self, value: typing.Any) -> trajectory.Transition:
    """Convert `value` to a half Transition; validate data & prune.

    - If `value` is transition-like, it is converted with
      `_as_tfa_transition`, then only validated and pruned.
    - If `value` is trajectory-like, it is validated against the trajectory
      spec. The time dimension must be `T=1` when `squeeze_time_dim = True`;
      otherwise no time length is required. When `squeeze_time_dim = True`
      the time dimension is squeezed away. The data is then repackaged as a
      `Transition`. Fields that a single trajectory step cannot supply are
      filled with zeros:
      - the first `time_step`'s `reward` and `discount`;
      - the `next_time_step`'s `observation`, which is shaped like
        `discount`, not like the observation.

    Args:
      value: A `Trajectory` or `Transition` to convert (presumably any object
        accepted by `_is_trajectory_like` / `_is_transition_like`).

    Returns:
      A validated and pruned `Transition`.  If `squeeze_time_dim = True`,
      the resulting `Transition` has tensors with shape `[B, ...]`.  Otherwise
      the time dimension is kept and tensors have shape `[B, T, ...]`.

    Raises:
      TypeError: If `value` is neither trajectory-like nor transition-like.
      ValueError: If `value` has structure that doesn't match the converter's
        spec.
      TypeError: If `value` has a structure that doesn't match the converter's
        spec.
      ValueError: If `squeeze_time_dim = True` and `value` is a trajectory
        with a time dimension having value other than `T=1`.
    """
    if _is_transition_like(value):
      value = _as_tfa_transition(value)
    elif _is_trajectory_like(value):
      # Only a single step is required when the time dim is squeezed away.
      required_sequence_length = 1 if self._squeeze_time_dim else None
      _validate_trajectory(
          value,
          self._data_context.trajectory_spec,
          sequence_length=required_sequence_length)
      if self._squeeze_time_dim:
        value = tf.nest.map_structure(lambda e: tf.squeeze(e, axis=1), value)
      # Policy state is not carried by a trajectory; use an empty state.
      policy_steps = policy_step.PolicyStep(
          action=value.action, state=(), info=value.policy_info)
      # TODO(b/130244652): Consider replacing 0 rewards & discounts with ().
      time_steps = ts.TimeStep(
          value.step_type,
          reward=tf.nest.map_structure(tf.zeros_like, value.reward),  # unknown
          discount=tf.zeros_like(value.discount),  # unknown
          observation=value.observation)
      # The next observation is unknown; a zeros placeholder shaped like
      # `discount` stands in for it.
      next_time_steps = ts.TimeStep(
          step_type=value.next_step_type,
          reward=value.reward,
          discount=value.discount,
          observation=tf.zeros_like(value.discount))
      value = trajectory.Transition(time_steps, policy_steps, next_time_steps)
    else:
      raise TypeError('Input type not supported: {}'.format(value))

    num_outer_dims = 1 if self._squeeze_time_dim else 2
    _validate_transition(
        value, self._data_context.transition_spec, num_outer_dims)

    value = nest_utils.prune_extra_keys(
        self._data_context.transition_spec, value)
    return value
Esempio n. 9
0
  def testNamedTuple(self):

    class A(collections.namedtuple('A', ('a', 'b'))):
      pass

    # A namedtuple subclass inside a list: its dict field is pruned and its
    # scalar field is kept; the sibling dict is pruned as well.
    spec = [A(a={'aa': 1}, b=3), {'c': 4}]
    value = [A(a={'aa': 'aa', 'ab': 'ab'}, b='b'), {'c': 'c', 'd': 'd'}]
    pruned = nest_utils.prune_extra_keys(spec, value)
    self.assertEqual(pruned, [A(a={'aa': 'aa'}, b='b'), {'c': 'c'}])
Esempio n. 10
0
    def testOrderedDict(self):
        OD = collections.OrderedDict  # pylint: disable=invalid-name

        spec = OD([('a', OD([('aa', 1), ('ab', 2)])),
                   ('b', OD([('ba', 1)]))])
        value = OD([('a', OD([('aa', 'aa'), ('ab', 'ab'), ('ac', 'ac')])),
                    ('b', OD([('ba', 'ba'), ('bb', 'bb')])),
                    ('c', 'c')])
        expected = OD([('a', OD([('aa', 'aa'), ('ab', 'ab')])),
                       ('b', OD([('ba', 'ba')]))])
        # Keys missing from the spec are dropped at every OrderedDict level.
        self.assertEqual(nest_utils.prune_extra_keys(spec, value), expected)
  def __call__(self, value: typing.Any):
    """Converts `value` to a Transition.  Performs data validation and pruning.

    - If `value` is already a `Transition`, only validation is performed.
    - If `value` is a `Trajectory` and `squeeze_time_dim = True` then
      `value` must have tensors with shape `[B, T=2]` outer dims.
      This is converted to a `Transition` object without a time
      dimension.
    - If `value` is a `Trajectory`, `squeeze_time_dim = True`, and its tensors
      have a time dimension with `T != 2`, a `ValueError` is raised.  When
      `squeeze_time_dim = False` the time length is not constrained here.

    Args:
      value: A `Trajectory` or `Transition` object to convert.

    Returns:
      A validated and pruned `Transition`.  If `squeeze_time_dim = True`,
      the resulting `Transition` has tensors with shape `[B, ...]`.  Otherwise,
      the tensors will have shape `[B, T - 1, ...]`.

    Raises:
      TypeError: If `value` is not one of `Trajectory` or `Transition`.
      ValueError: If `value` has structure that doesn't match the converter's
        spec.
      TypeError: If `value` has a structure that doesn't match the converter's
        spec.
      ValueError: If `squeeze_time_dim=True` and `value` is a `Trajectory`
        with a time dimension having value other than `T=2`.
    """
    if isinstance(value, trajectory.Transition):
      pass
    elif isinstance(value, trajectory.Trajectory):
      # The time length is only enforced when the time dim will be squeezed.
      required_sequence_length = 2 if self._squeeze_time_dim else None
      _validate_trajectory(
          value,
          self._data_context.trajectory_spec,
          sequence_length=required_sequence_length)
      value = trajectory.to_transition(value)
      # Remove the now-singleton time dim.
      if self._squeeze_time_dim:
        value = tf.nest.map_structure(
            lambda x: composite.squeeze(x, axis=1), value)
    else:
      raise TypeError('Input type not supported: {}'.format(value))

    self._validate_transition(value)
    value = nest_utils.prune_extra_keys(
        self._data_context.transition_spec, value)
    return value
Esempio n. 12
0
    def __call__(self, value: typing.Any):
        """Converts `value` to a Trajectory.  Performs data validation and pruning.

        - If `value` is already a `Trajectory`, only validation is performed.
        - If `value` is a `Transition` with tensors containing two (`[B, T]`)
          outer dims, then it is simply repackaged to a `Trajectory` and then
          validated.
        - If `value` is a `Transition` with tensors containing one (`[B]`)
          outer dim, validation is expected to fail with a `ValueError`.
          Validation uses `self._num_outer_dims` outer dims and
          `self._sequence_length`.

        Args:
          value: A `Trajectory` or `Transition` object to convert.

        Returns:
          A validated and pruned `Trajectory`.

        Raises:
          TypeError: If `value` is not one of `Trajectory` or `Transition`.
          ValueError: If `value` has structure that doesn't match the
            converter's spec.
          TypeError: If `value` has a structure that doesn't match the
            converter's spec.
          ValueError: If `value` is a `Transition` without a time dimension,
            as training Trajectories typically have batch and time dimensions.
        """
        if isinstance(value, trajectory.Trajectory):
            pass
        elif isinstance(value, trajectory.Transition):
            # Repackage field by field. Step type and observation come from
            # `time_step`. Reward, discount and next step type come from
            # `next_time_step`.
            value = trajectory.Trajectory(
                step_type=value.time_step.step_type,
                observation=value.time_step.observation,
                action=value.action_step.action,
                policy_info=value.action_step.info,
                next_step_type=value.next_time_step.step_type,
                reward=value.next_time_step.reward,
                discount=value.next_time_step.discount)
        else:
            raise TypeError('Input type not supported: {}'.format(value))
        _validate_trajectory(value,
                             self._data_context.trajectory_spec,
                             sequence_length=self._sequence_length,
                             num_outer_dims=self._num_outer_dims)
        value = nest_utils.prune_extra_keys(self._data_context.trajectory_spec,
                                            value)
        return value
Esempio n. 13
0
  def __call__(self, value: typing.Any) -> trajectory.Transition:
    """Convert `value` to an N-step Transition; validate data & prune.

    - If `value` is transition-like, it is converted with
      `_as_tfa_transition`, then only validated and pruned.
    - If `value` is trajectory-like, it is validated against the trajectory
      spec and converted with `trajectory.to_n_step_transition`. When `n` is
      not `None`, the time dimension must be `T = n + 1`; otherwise a
      `ValueError` is raised. `self._gamma` is passed through to that
      conversion, presumably as the discount factor.

    Args:
      value: A `Trajectory` or `Transition` object to convert.

    Returns:
      A validated and pruned `Transition`.  Its tensors are always validated
      with a single outer (batch) dimension, i.e. shape `[B, ...]`.

    Raises:
      TypeError: If `value` is not one of `Trajectory` or `Transition`.
      ValueError: If `value` has structure that doesn't match the converter's
        spec.
      TypeError: If `value` has a structure that doesn't match the converter's
        spec.
      ValueError: If `n != None` and `value` is a `Trajectory`
        with a time dimension having value other than `T=n + 1`.
    """
    if _is_transition_like(value):
      value = _as_tfa_transition(value)
    elif _is_trajectory_like(value):
      # No time-length requirement when `n` is None.
      _validate_trajectory(
          value,
          self._data_context.trajectory_spec,
          sequence_length=None if self._n is None else self._n + 1)
      value = trajectory.to_n_step_transition(value, gamma=self._gamma)
    else:
      raise TypeError('Input type not supported: {}'.format(value))

    _validate_transition(
        value, self._data_context.transition_spec, num_outer_dims=1)
    value = nest_utils.prune_extra_keys(
        self._data_context.transition_spec, value)
    return value
Esempio n. 14
0
    def action(self, time_step, policy_state=(), seed=None):
        """Computes the next action for `time_step` and `policy_state`.

        Args:
          time_step: A `TimeStep` tuple matching `time_step_spec()`.
          policy_state: A Tensor, or a nested dict, list or tuple of Tensors,
            holding the policy state from the previous call.
          seed: Optional seed, used if the action involves sampling.

        Returns:
          A `PolicyStep` named tuple with fields:
            `action`: An action Tensor matching the `action_spec`.
            `state`: The policy state to feed into the next call to action.
            `info`: Optional side information such as action log probabilities.

        Raises:
          RuntimeError: If subclass __init__ didn't call super().__init__.
          ValueError or TypeError: If `validate_args is True` and inputs or
            outputs do not match `time_step_spec`, `policy_state_spec`,
            or `policy_step_spec`.
          TypeError: If `validate_args is True` and the produced action dtypes
            are not compatible with the `action_spec` dtypes.
        """
        # Choose the implementation: `_action_fn` when functions are enabled
        # (it must have been set up by the base initializer), else `_action`.
        if self._enable_functions:
            action_fn = getattr(self, '_action_fn', None)
            if action_fn is None:
                raise RuntimeError(
                    'Cannot find _action_fn.  Did %s.__init__ call super?' %
                    type(self).__name__)
        else:
            action_fn = self._action

        if self._validate_args:
            # Drop dict entries the specs do not mention, then require exact
            # structural agreement with the specs.
            time_step = nest_utils.prune_extra_keys(
                self._time_step_spec, time_step)
            policy_state = nest_utils.prune_extra_keys(
                self._policy_state_spec, policy_state)
            nest_utils.assert_same_structure(
                time_step, self._time_step_spec,
                message='time_step and time_step_spec structures do not match')
            nest_utils.assert_same_structure(
                policy_state, self._policy_state_spec,
                message=('policy_state and policy_state_spec structures '
                         'do not match'))

        if self._automatic_state_reset:
            policy_state = self._maybe_reset_state(time_step, policy_state)
        step = action_fn(
            time_step=time_step, policy_state=policy_state, seed=seed)

        if self._validate_args:
            nest_utils.assert_same_structure(
                step.action, self._action_spec,
                message='action and action_spec structures do not match')

        if self._clip:

            def _clip_leaf(leaf_action, leaf_spec):
                # Only bounded specs define a range to clip into.
                if not isinstance(leaf_spec, tensor_spec.BoundedTensorSpec):
                    return leaf_action
                return common.clip_to_spec(leaf_action, leaf_spec)

            step = step._replace(action=tf.nest.map_structure(
                _clip_leaf, step.action, self._action_spec))

        if self._validate_args:
            nest_utils.assert_same_structure(
                step, self._policy_step_spec,
                message=('action output and policy_step_spec structures '
                         'do not match'))

            # Build the full list first so every leaf is checked.
            leaf_pairs = zip(tf.nest.flatten(step.action),
                             tf.nest.flatten(self.action_spec))
            dtype_matches = [
                produced.dtype.is_compatible_with(expected.dtype)
                for produced, expected in leaf_pairs
            ]

            if not all(dtype_matches):
                action_dtypes = tf.nest.map_structure(
                    lambda t: t.dtype, step.action)
                spec_dtypes = tf.nest.map_structure(
                    lambda t: t.dtype, self.action_spec)
                raise TypeError(
                    'Policy produced an action with a dtype that doesn\'t '
                    'match its action_spec. Got action:\n  %s\n with '
                    'action_spec:\n  %s' % (action_dtypes, spec_dtypes))

        return step
Esempio n. 15
0
    def loss(self,
             experience: types.NestedTensor,
             weights: Optional[types.Tensor] = None,
             **kwargs) -> LossInfo:
        """Gets loss from the agent.

        If the user calls this from _train, it must be in a `tf.GradientTape`
        scope in order to apply gradients to trainable variables.
        If intermediate gradient steps are needed, _loss and _train will return
        different values since _loss only supports updating all gradients at
        once after all losses have been calculated.

        Args:
          experience: A batch of experience data in the form of a
            `Trajectory`. The structure of `experience` must match that of
            `self.training_data_spec`. All tensors in `experience` must be
            shaped `[batch, time, ...]` where `time` must be equal to
            `self.train_sequence_length` if that property is not `None`.
          weights: (optional).  A `Tensor`, either `0-D` or shaped `[batch]`,
            containing weights to be used when calculating the total train
            loss. Weights are typically multiplied elementwise against the
            per-batch loss, but the implementation is up to the Agent.
          **kwargs: Any additional data as args to `loss`.

        Returns:
            A `LossInfo` loss tuple containing loss and info tensors.

        Raises:
          TypeError: If `validate_args is True` and: Experience is not type
            `Trajectory`; or if `experience`  does not match
            `self.training_data_spec` structure types.
          TypeError: If the underlying loss implementation returns something
            that is not a `LossInfo` (checked regardless of `validate_args`).
          ValueError: If `validate_args is True` and: Experience tensors' time
            axes are not compatible with `self.train_sequence_length`; or if
            experience does not match `self.training_data_spec` structure.
          ValueError: If `validate_args is True` and the user does not pass
            `**kwargs` matching `self.train_argspec`.
          RuntimeError: If the class was not initialized properly
            (`super.__init__` was not called).
        """
        # `_loss_fn` is only needed (and only checked for) when functions are
        # enabled; otherwise `_loss` is called directly below.
        if self._enable_functions and getattr(self, "_loss_fn", None) is None:
            raise RuntimeError(
                "Cannot find _loss_fn.  Did %s.__init__ call super?" %
                type(self).__name__)

        if self._validate_args:
            self._check_trajectory_dimensions(experience)
            self._check_train_argspec(kwargs)

            # Even though the checks above prune dict keys, we want them to see
            # the non-pruned versions to provide clearer error messages.
            # However, from here on out we want to remove dict entries that aren't
            # requested in the spec.
            experience = nest_utils.prune_extra_keys(self.training_data_spec,
                                                     experience)
            kwargs = nest_utils.prune_extra_keys(self.train_argspec, kwargs)

        if self._enable_functions:
            loss_info = self._loss_fn(experience=experience,
                                      weights=weights,
                                      **kwargs)
        else:
            loss_info = self._loss(experience=experience,
                                   weights=weights,
                                   **kwargs)

        # Enforce the return-type contract of the subclass loss implementation.
        if not isinstance(loss_info, LossInfo):
            raise TypeError(
                "loss_info is not a subclass of LossInfo: {}".format(
                    loss_info))
        return loss_info
Esempio n. 16
0
  def __call__(self, value: typing.Any) -> trajectory.Transition:
    """Converts `value` to a Transition.  Performs data validation and pruning.

    - If `value` is already a `Transition`, only validation is performed.
    - If `value` is a `Trajectory` and `squeeze_time_dim = True` then
      `value` must have tensors with shape `[B, T=2]` outer dims.
      This is converted to a `Transition` object without a time
      dimension.
    - If `value` is a `Trajectory`, `squeeze_time_dim = True`, and its tensors
      have a time dimension with `T != 2`, a `ValueError` is raised.  When
      `squeeze_time_dim = False` the time length is not constrained here.
    - If `prepend_t0_to_next_time_step` is enabled, the first time step's
      observation (index 0 along axis 1) is concatenated in front of
      `next_time_step.observation` along axis 1.

    Args:
      value: A `Trajectory` or `Transition` object to convert.

    Returns:
      A validated and pruned `Transition`.  If `squeeze_time_dim = True`,
      the resulting `Transition` has tensors with shape `[B, ...]`.  Otherwise,
      the tensors will have shape `[B, T - 1, ...]` (with the prepend option,
      `next_time_step.observation` gains one extra step along axis 1).

    Raises:
      TypeError: If `value` is not one of `Trajectory` or `Transition`.
      ValueError: If `value` has structure that doesn't match the converter's
        spec.
      TypeError: If `value` has a structure that doesn't match the converter's
        spec.
      ValueError: If `squeeze_time_dim=True` and `value` is a `Trajectory`
        with a time dimension having value other than `T=2`.
    """
    if _is_transition_like(value):
      value = _as_tfa_transition(value)
    elif _is_trajectory_like(value):
      # The time length is only enforced when the time dim will be squeezed.
      required_sequence_length = 2 if self._squeeze_time_dim else None
      _validate_trajectory(
          value,
          self._data_context.trajectory_spec,
          sequence_length=required_sequence_length)
      value = trajectory.to_transition(value)
      # Remove the now-singleton time dim.
      if self._squeeze_time_dim:
        value = tf.nest.map_structure(
            lambda x: composite.squeeze(x, axis=1), value)
    else:
      raise TypeError('Input type not supported: {}'.format(value))

    num_outer_dims = 1 if self._squeeze_time_dim else 2
    _validate_transition(
        value, self._data_context.transition_spec, num_outer_dims)

    value = nest_utils.prune_extra_keys(
        self._data_context.transition_spec, value)

    if self._prepend_t0_to_next_time_step:
      # This is useful when using sequential model. It allows target_q network
      # to take all the information.
      # NOTE(review): slicing `x[:, :1, ...]` treats axis 1 as time, so this
      # presumes `squeeze_time_dim=False`; confirm the constructor enforces it.
      next_time_step_with_t0 = value.next_time_step._replace(
          observation=tf.nest.map_structure(
              lambda x, y: tf.concat([x[:, :1, ...], y], axis=1),
              value.time_step.observation, value.next_time_step.observation))

      value = value._replace(next_time_step=next_time_step_with_t0)
    return value