def distribution(self, time_step, policy_state=()):
    """Returns the policy's distribution over next actions for `time_step`.

    Keys not requested by the specs are pruned from `time_step` and
    `policy_state`. Both are then checked against their specs, and the step
    that is returned is checked against `self._policy_step_spec`.

    Args:
      time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
      policy_state: A Tensor, or a nested dict, list or tuple of Tensors
        representing the previous policy_state.

    Returns:
      A `PolicyStep` named tuple containing:
        `action`: A tf.distribution capturing the distribution of next actions.
        `state`: A policy state tensor for the next call to distribution.
        `info`: Optional side information such as action log probabilities.
    """
    # Drop unrequested dict entries before the structure checks.
    step_in = nest_utils.prune_extra_keys(self._time_step_spec, time_step)
    state_in = nest_utils.prune_extra_keys(self._policy_state_spec,
                                           policy_state)
    tf.nest.assert_same_structure(step_in, self._time_step_spec)
    tf.nest.assert_same_structure(state_in, self._policy_state_spec)

    if self._automatic_state_reset:
        state_in = self._maybe_reset_state(step_in, state_in)

    result = self._distribution(time_step=step_in, policy_state=state_in)

    if self.emit_log_probability:
        # Zero log-probabilities; present only so `info` stays compatible with
        # the info_spec configured in the constructor.
        zero_log_probs = tf.nest.map_structure(
            lambda _: tf.constant(0., dtype=tf.float32),
            policy_step.get_log_probability(self._info_spec))
        new_info = policy_step.set_log_probability(result.info, zero_log_probs)
        result = result._replace(info=new_info)

    tf.nest.assert_same_structure(result, self._policy_step_spec)
    return result
def train(self, experience, weights=None, **kwargs):
    """Trains the agent.

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.training_data_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.train_sequence_length` if that
        property is not `None`.
      weights: (optional). A `Tensor`, either `0-D` or shaped `[batch]`,
        containing weights to be used when calculating the total train loss.
        Weights are typically multiplied elementwise against the per-batch loss,
        but the implementation is up to the Agent.
      **kwargs: Any additional data as declared by `self.train_argspec`.

    Returns:
      A `LossInfo` loss tuple containing loss and info tensors.
      - In eager mode, the loss values are first calculated, then a train step
        is performed before they are returned.
      - In graph mode, executing any or all of the loss tensors
        will first calculate the loss value(s), then perform a train step,
        and return the pre-train-step `LossInfo`.

    Raises:
      TypeError: If experience is not type `Trajectory`. Or if experience
        does not match `self.training_data_spec` structure types.
      ValueError: If experience tensors' time axes are not compatible with
        `self.train_sequence_length`. Or if experience does not match
        `self.training_data_spec` structure.
      ValueError: If the user does not pass `**kwargs` matching
        `self.train_argspec`.
      RuntimeError: If the class was not initialized properly (`super.__init__`
        was not called).
      TypeError: If the underlying train function does not return a
        `LossInfo`.
    """
    if self._enable_functions and getattr(self, "_train_fn", None) is None:
        raise RuntimeError(
            "Cannot find _train_fn. Did %s.__init__ call super?"
            % type(self).__name__)

    self._check_trajectory_dimensions(experience)
    self._check_train_argspec(kwargs)

    # Even though the checks above prune dict keys, we want them to see
    # the non-pruned versions to provide clearer error messages.
    # However, from here on out we want to remove dict entries that aren't
    # requested in the spec.
    experience = nest_utils.prune_extra_keys(
        self.training_data_spec, experience)
    kwargs = nest_utils.prune_extra_keys(self.train_argspec, kwargs)

    if self._enable_functions:
        loss_info = self._train_fn(
            experience=experience, weights=weights, **kwargs)
    else:
        loss_info = self._train(experience=experience, weights=weights, **kwargs)

    if not isinstance(loss_info, LossInfo):
        raise TypeError(
            "loss_info is not a subclass of LossInfo: {}".format(loss_info))
    return loss_info
def testInvalidWide(self):
    """Specs that don't line up with the value leave the value as-is."""
    # Each entry is (spec, value, expected); every expected equals the value.
    cases = [
        (None, {'a': 1}, {'a': 1}),
        ({'a': 1}, {}, {}),
        ({'a': 1}, {'c': 'c'}, {'c': 'c'}),
        ([], ['a'], ['a']),
        ([{}, {}], [{'a': 1}], [{'a': 1}]),
    ]
    for spec, value, expected in cases:
        self.assertEqual(nest_utils.prune_extra_keys(spec, value), expected)
def distribution(
    self, time_step: ts.TimeStep, policy_state: types.NestedTensor = ()
) -> policy_step.PolicyStep:
    """Returns the policy's distribution over next actions for `time_step`.

    When `validate_args` is enabled, keys not requested by the specs are
    pruned from the inputs. The inputs are then checked against
    `time_step_spec` / `policy_state_spec`, and the returned step is checked
    against `policy_step_spec`.

    Args:
      time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
      policy_state: A Tensor, or a nested dict, list or tuple of Tensors
        representing the previous policy_state.

    Returns:
      A `PolicyStep` named tuple containing:
        `action`: A tf.distribution capturing the distribution of next actions.
        `state`: A policy state tensor for the next call to distribution.
        `info`: Optional side information such as action log probabilities.

    Raises:
      ValueError or TypeError: If `validate_args is True` and inputs or
        outputs do not match `time_step_spec`, `policy_state_spec`, or
        `policy_step_spec`.
    """
    step_in, state_in = time_step, policy_state
    if self._validate_args:
        step_in = nest_utils.prune_extra_keys(self._time_step_spec, step_in)
        state_in = nest_utils.prune_extra_keys(self._policy_state_spec,
                                               state_in)
        nest_utils.assert_same_structure(
            step_in,
            self._time_step_spec,
            message='time_step and time_step_spec structures do not match')
        nest_utils.assert_same_structure(
            state_in,
            self._policy_state_spec,
            message='policy_state and policy_state_spec structures do not match')

    if self._automatic_state_reset:
        state_in = self._maybe_reset_state(step_in, state_in)

    dist_step = self._distribution(time_step=step_in, policy_state=state_in)

    if self.emit_log_probability:
        # Zero log-probabilities, only for compatibility with the info_spec
        # configured in the constructor.
        zero_log_probs = tf.nest.map_structure(
            lambda _: tf.constant(0., dtype=tf.float32),
            policy_step.get_log_probability(self._info_spec))
        new_info = policy_step.set_log_probability(dist_step.info,
                                                   zero_log_probs)
        dist_step = dist_step._replace(info=new_info)

    if self._validate_args:
        nest_utils.assert_same_structure(
            dist_step,
            self._policy_step_spec,
            message=('distribution output and policy_step_spec structures '
                     'do not match'))
    return dist_step
def testPruneExtraKeys(self):
    """Checks pruning over dicts, lists, tuples, None specs and DictWrappers."""
    # Each entry is (spec, value, expected pruned value).
    cases = [
        ({}, {'a': 1}, {}),
        ((), {'a': 1}, ()),
        ({'a': 1}, {'a': 'a'}, {'a': 'a'}),
        ({'a': 1}, {'a': 'a', 'b': 2}, {'a': 'a'}),
        ([{'a': 1}], [{'a': 'a', 'b': 2}], [{'a': 'a'}]),
        # An empty-tuple spec yields (); a None spec keeps the value.
        ({'a': (), 'b': None}, {'a': 1, 'b': 2}, {'a': (), 'b': 2}),
        (
            {'a': {'aa': 1, 'ab': 2}, 'b': {'ba': 1}},
            {
                'a': {'aa': 'aa', 'ab': 'ab', 'ac': 'ac'},
                'b': {'ba': 'ba', 'bb': 'bb'},
                'c': 'c',
            },
            {'a': {'aa': 'aa', 'ab': 'ab'}, 'b': {'ba': 'ba'}},
        ),
        (
            {'a': ()},
            DictWrapper({'a': DictWrapper({'b': None})}),
            {'a': ()},
        ),
        (
            {'a': 1, 'c': 2},
            DictWrapper({'a': DictWrapper({'b': None})}),
            {'a': {'b': None}},
        ),
    ]
    for spec, value, expected in cases:
        self.assertEqual(nest_utils.prune_extra_keys(spec, value), expected)
def testPruneExtraKeys(self):
    """Keys absent from the spec are dropped at every nesting level."""
    # Each entry is (spec, value, expected pruned value).
    cases = (
        ({}, {'a': 1}, {}),
        ({'a': 1}, {'a': 'a'}, {'a': 'a'}),
        ({'a': 1}, {'a': 'a', 'b': 2}, {'a': 'a'}),
        ([{'a': 1}], [{'a': 'a', 'b': 2}], [{'a': 'a'}]),
        (
            {'a': {'aa': 1, 'ab': 2}, 'b': {'ba': 1}},
            {
                'a': {'aa': 'aa', 'ab': 'ab', 'ac': 'ac'},
                'b': {'ba': 'ba', 'bb': 'bb'},
                'c': 'c',
            },
            {'a': {'aa': 'aa', 'ab': 'ab'}, 'b': {'ba': 'ba'}},
        ),
    )
    for spec, value, expected in cases:
        self.assertEqual(nest_utils.prune_extra_keys(spec, value), expected)
def testSubtypesOfListAndDict(self):
    """Pruning through ListWrapper/TupleWrapper/DictWrapper and a namedtuple."""

    class A(collections.namedtuple('A', ('a', 'b'))):
        pass

    # pylint: disable=invalid-name
    DictWrapper = data_structures.wrap_or_unwrap
    TupleWrapper = data_structures.wrap_or_unwrap
    # pylint: enable=invalid-name

    spec = [
        data_structures.ListWrapper([None, DictWrapper({'a': 3, 'b': 4})]),
        None,
        TupleWrapper((DictWrapper({'g': 5}),)),
        TupleWrapper(A(None, DictWrapper({'h': 6}))),
    ]
    value = [
        ['x', {'a': 'a', 'b': 'b', 'c': 'c'}],
        'd',
        ({'g': 'g', 'gg': 'gg'},),
        A(None, {'h': 'h', 'hh': 'hh'}),
    ]
    pruned = nest_utils.prune_extra_keys(spec, value)

    expected = [
        data_structures.ListWrapper(['x', DictWrapper({'a': 'a', 'b': 'b'})]),
        'd',
        TupleWrapper((DictWrapper({'g': 'g'}),)),
        TupleWrapper(A(None, DictWrapper({'h': 'h'}))),
    ]
    self.assertEqual(pruned, expected)
def __call__(self, value: typing.Any) -> trajectory.Transition:
    """Convert `value` to a `Transition` without a real next observation.

    - If `value` is transition-like, it is converted with
      `_as_tfa_transition` and then only validated.
    - If `value` is trajectory-like, it is repackaged per time step:
      `time_step` carries the observation with zero-filled reward/discount,
      and `next_time_step` carries the reward/discount with a zero-filled
      placeholder observation. When `squeeze_time_dim = True` the trajectory
      must have a time dimension of `T = 1`, which is removed.

    Args:
      value: A `Trajectory` or `Transition` object to convert.

    Returns:
      A validated and pruned `Transition`. If `squeeze_time_dim = True`, the
      resulting `Transition` has tensors with shape `[B, ...]`. Otherwise, the
      tensors keep their time dimension and have shape `[B, T, ...]`.

    Raises:
      TypeError: If `value` is not one of `Trajectory` or `Transition`.
      ValueError: If `value` has structure that doesn't match the converter's
        spec.
      TypeError: If `value` has a structure that doesn't match the converter's
        spec.
      ValueError: If `squeeze_time_dim=True` and `value` is a `Trajectory`
        with a time dimension having value other than `T=1`.
    """
    if _is_transition_like(value):
        value = _as_tfa_transition(value)
    elif _is_trajectory_like(value):
        required_sequence_length = 1 if self._squeeze_time_dim else None
        _validate_trajectory(
            value,
            self._data_context.trajectory_spec,
            sequence_length=required_sequence_length)
        if self._squeeze_time_dim:
            value = tf.nest.map_structure(lambda e: tf.squeeze(e, axis=1), value)
        policy_steps = policy_step.PolicyStep(
            action=value.action, state=(), info=value.policy_info)
        # TODO(b/130244652): Consider replacing 0 rewards & discounts with ().
        time_steps = ts.TimeStep(
            value.step_type,
            reward=tf.nest.map_structure(tf.zeros_like, value.reward),  # unknown
            discount=tf.zeros_like(value.discount),  # unknown
            observation=value.observation)
        next_time_steps = ts.TimeStep(
            step_type=value.next_step_type,
            reward=value.reward,
            discount=value.discount,
            # Placeholder: this conversion does not carry a next observation.
            observation=tf.zeros_like(value.discount))
        value = trajectory.Transition(time_steps, policy_steps, next_time_steps)
    else:
        raise TypeError('Input type not supported: {}'.format(value))

    num_outer_dims = 1 if self._squeeze_time_dim else 2
    _validate_transition(
        value, self._data_context.transition_spec, num_outer_dims)

    value = nest_utils.prune_extra_keys(
        self._data_context.transition_spec, value)
    return value
def testNamedTuple(self):
    """Pruning recurses into namedtuple-subclass fields and sibling dicts."""

    class A(collections.namedtuple('A', ('a', 'b'))):
        pass

    spec = [A(a={'aa': 1}, b=3), {'c': 4}]
    value = [A(a={'aa': 'aa', 'ab': 'ab'}, b='b'), {'c': 'c', 'd': 'd'}]
    pruned = nest_utils.prune_extra_keys(spec, value)
    self.assertEqual(pruned, [A(a={'aa': 'aa'}, b='b'), {'c': 'c'}])
def testOrderedDict(self):
    """Unrequested keys are dropped from nested OrderedDicts."""
    od = collections.OrderedDict
    # Keyword arguments keep their order (PEP 468), so these match the
    # original list-of-pairs construction exactly.
    spec = od(a=od(aa=1, ab=2), b=od(ba=1))
    value = od(
        a=od(aa='aa', ab='ab', ac='ac'),
        b=od(ba='ba', bb='bb'),
        c='c')
    pruned = nest_utils.prune_extra_keys(spec, value)
    expected = od(a=od(aa='aa', ab='ab'), b=od(ba='ba'))
    self.assertEqual(pruned, expected)
def __call__(self, value: typing.Any):
    """Converts `value` to a `Transition`, validating and pruning it.

    - A `Transition` goes straight to validation.
    - A `Trajectory` is validated and converted with
      `trajectory.to_transition`. With `squeeze_time_dim = True` it must have
      tensors with outer dims `[B, T=2]`, and the resulting singleton time
      dimension is removed.

    Args:
      value: A `Trajectory` or `Transition` object to convert.

    Returns:
      A validated and pruned `Transition`. If `squeeze_time_dim = True`, the
      resulting `Transition` has tensors with shape `[B, ...]`. Otherwise, the
      tensors will have shape `[B, T - 1, ...]`.

    Raises:
      TypeError: If `value` is not one of `Trajectory` or `Transition`.
      ValueError: If `value` has structure that doesn't match the converter's
        spec.
      TypeError: If `value` has a structure that doesn't match the converter's
        spec.
      ValueError: If `squeeze_time_dim=True` and `value` is a `Trajectory`
        with a time dimension having value other than `T=2`.
    """
    if isinstance(value, trajectory.Transition):
        converted = value
    elif isinstance(value, trajectory.Trajectory):
        seq_len = 2 if self._squeeze_time_dim else None
        _validate_trajectory(value, self._data_context.trajectory_spec,
                             sequence_length=seq_len)
        converted = trajectory.to_transition(value)
        if self._squeeze_time_dim:
            # Drop the now-singleton time dimension.
            converted = tf.nest.map_structure(
                lambda t: composite.squeeze(t, axis=1), converted)
    else:
        raise TypeError('Input type not supported: {}'.format(value))

    self._validate_transition(converted)
    return nest_utils.prune_extra_keys(self._data_context.transition_spec,
                                       converted)
def __call__(self, value: typing.Any):
    """Converts `value` to a `Trajectory`, validating and pruning it.

    - A `Trajectory` goes straight to validation.
    - A `Transition` is repackaged field-by-field into a `Trajectory`:
      step type and observation come from `time_step`, action and policy info
      from `action_step`, and next step type, reward and discount from
      `next_time_step`. It is then validated.
    - A `Transition` whose tensors have only a single (`[B]`) outer dim fails
      validation with a `ValueError`.

    Args:
      value: A `Trajectory` or `Transition` object to convert.

    Returns:
      A validated and pruned `Trajectory`.

    Raises:
      TypeError: If `value` is not one of `Trajectory` or `Transition`.
      ValueError: If `value` has structure that doesn't match the converter's
        spec.
      TypeError: If `value` has a structure that doesn't match the converter's
        spec.
      ValueError: If `value` is a `Transition` without a time dimension, as
        training Trajectories typically have batch and time dimensions.
    """
    if isinstance(value, trajectory.Trajectory):
        traj = value
    elif isinstance(value, trajectory.Transition):
        first_step = value.time_step
        taken = value.action_step
        second_step = value.next_time_step
        traj = trajectory.Trajectory(
            step_type=first_step.step_type,
            observation=first_step.observation,
            action=taken.action,
            policy_info=taken.info,
            next_step_type=second_step.step_type,
            reward=second_step.reward,
            discount=second_step.discount)
    else:
        raise TypeError('Input type not supported: {}'.format(value))

    _validate_trajectory(
        traj,
        self._data_context.trajectory_spec,
        sequence_length=self._sequence_length,
        num_outer_dims=self._num_outer_dims)
    return nest_utils.prune_extra_keys(self._data_context.trajectory_spec,
                                       traj)
def __call__(self, value: typing.Any) -> trajectory.Transition:
    """Convert `value` to an N-step `Transition`; validate data & prune.

    - If `value` is transition-like, it is converted with
      `_as_tfa_transition` and then only validated.
    - If `value` is trajectory-like, it is validated and collapsed with
      `trajectory.to_n_step_transition`, passing `self._gamma` as `gamma`.
      When `n` is not `None`, its time dimension must be `T = n + 1`.

    Args:
      value: A `Trajectory` or `Transition` object to convert.

    Returns:
      A validated and pruned `Transition` whose tensors have a single outer
      dimension, i.e. shape `[B, ...]`.

    Raises:
      TypeError: If `value` is not one of `Trajectory` or `Transition`.
      ValueError: If `value` has structure that doesn't match the converter's
        spec.
      TypeError: If `value` has a structure that doesn't match the converter's
        spec.
      ValueError: If `n != None` and `value` is a `Trajectory` with a time
        dimension having value other than `T=n + 1`.
    """
    if _is_transition_like(value):
        value = _as_tfa_transition(value)
    elif _is_trajectory_like(value):
        # With n unset, any trajectory length is accepted.
        required_sequence_length = None if self._n is None else self._n + 1
        _validate_trajectory(
            value,
            self._data_context.trajectory_spec,
            sequence_length=required_sequence_length)
        value = trajectory.to_n_step_transition(value, gamma=self._gamma)
    else:
        raise TypeError('Input type not supported: {}'.format(value))

    _validate_transition(
        value, self._data_context.transition_spec, num_outer_dims=1)

    value = nest_utils.prune_extra_keys(
        self._data_context.transition_spec, value)
    return value
def action(self, time_step, policy_state=(), seed=None):
    """Computes the next action for `time_step` and `policy_state`.

    Args:
      time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
      policy_state: A Tensor, or a nested dict, list or tuple of Tensors
        representing the previous policy_state.
      seed: Seed to use if action performs sampling (optional).

    Returns:
      A `PolicyStep` named tuple containing:
        `action`: An action Tensor matching the `action_spec`.
        `state`: A policy state tensor to be fed into the next call to action.
        `info`: Optional side information such as action log probabilities.

    Raises:
      RuntimeError: If subclass __init__ didn't call super().__init__.
      ValueError or TypeError: If `validate_args is True` and inputs or
        outputs do not match `time_step_spec`, `policy_state_spec`, or
        `policy_step_spec`.
      TypeError: If any produced action's dtype is not compatible with the
        matching `action_spec` dtype. This check runs even when
        `validate_args` is False.
    """
    if self._enable_functions:
        if getattr(self, '_action_fn', None) is None:
            raise RuntimeError(
                'Cannot find _action_fn. Did %s.__init__ call super?'
                % type(self).__name__)
        act = self._action_fn
    else:
        act = self._action

    if self._validate_args:
        # Prune unrequested keys first so the structure checks only see what
        # the specs ask for.
        time_step = nest_utils.prune_extra_keys(self._time_step_spec, time_step)
        policy_state = nest_utils.prune_extra_keys(self._policy_state_spec,
                                                   policy_state)
        nest_utils.assert_same_structure(
            time_step,
            self._time_step_spec,
            message='time_step and time_step_spec structures do not match')
        nest_utils.assert_same_structure(
            policy_state,
            self._policy_state_spec,
            message='policy_state and policy_state_spec structures do not match')

    if self._automatic_state_reset:
        policy_state = self._maybe_reset_state(time_step, policy_state)

    out = act(time_step=time_step, policy_state=policy_state, seed=seed)

    if self._validate_args:
        nest_utils.assert_same_structure(
            out.action,
            self._action_spec,
            message='action and action_spec structures do not match')

    if self._clip:

        def _clip_one(act_value, spec):
            # Only bounded specs define a range to clip into.
            if not isinstance(spec, tensor_spec.BoundedTensorSpec):
                return act_value
            return common.clip_to_spec(act_value, spec)

        clipped = tf.nest.map_structure(_clip_one, out.action,
                                        self._action_spec)
        out = out._replace(action=clipped)

    if self._validate_args:
        nest_utils.assert_same_structure(
            out,
            self._policy_step_spec,
            message='action output and policy_step_spec structures do not match')

    flat_actions = tf.nest.flatten(out.action)
    flat_specs = tf.nest.flatten(self.action_spec)
    dtype_matches = [
        a.dtype.is_compatible_with(s.dtype)
        for a, s in zip(flat_actions, flat_specs)
    ]
    if not all(dtype_matches):
        action_dtypes = tf.nest.map_structure(lambda t: t.dtype, out.action)
        spec_dtypes = tf.nest.map_structure(lambda t: t.dtype, self.action_spec)
        raise TypeError(
            'Policy produced an action with a dtype that doesn\'t '
            'match its action_spec. Got action:\n %s\n with '
            'action_spec:\n %s' % (action_dtypes, spec_dtypes))
    return out
def loss(self,
         experience: types.NestedTensor,
         weights: Optional[types.Tensor] = None,
         **kwargs) -> LossInfo:
    """Gets loss from the agent.

    If the user calls this from _train, it must be in a `tf.GradientTape` scope
    in order to apply gradients to trainable variables.
    If intermediate gradient steps are needed, _loss and _train will return
    different values since _loss only supports updating all gradients at once
    after all losses have been calculated.

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.training_data_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.train_sequence_length` if that
        property is not `None`.
      weights: (optional). A `Tensor`, either `0-D` or shaped `[batch]`,
        containing weights to be used when calculating the total train loss.
        Weights are typically multiplied elementwise against the per-batch loss,
        but the implementation is up to the Agent.
      **kwargs: Any additional data as args to `loss`.

    Returns:
      A `LossInfo` loss tuple containing loss and info tensors.

    Raises:
      TypeError: If `validate_args is True` and: Experience is not type
        `Trajectory`; or if `experience` does not match
        `self.training_data_spec` structure types.
      ValueError: If `validate_args is True` and: Experience tensors' time axes
        are not compatible with `self.train_sequence_length`; or if experience
        does not match `self.training_data_spec` structure.
      ValueError: If `validate_args is True` and the user does not pass
        `**kwargs` matching `self.train_argspec`.
      RuntimeError: If the class was not initialized properly (`super.__init__`
        was not called).
      TypeError: If the underlying loss function does not return a `LossInfo`.
    """
    if self._enable_functions and getattr(self, "_loss_fn", None) is None:
        raise RuntimeError(
            "Cannot find _loss_fn. Did %s.__init__ call super?"
            % type(self).__name__)

    if self._validate_args:
        self._check_trajectory_dimensions(experience)
        self._check_train_argspec(kwargs)

        # Even though the checks above prune dict keys, we want them to see
        # the non-pruned versions to provide clearer error messages.
        # However, from here on out we want to remove dict entries that aren't
        # requested in the spec.
        experience = nest_utils.prune_extra_keys(self.training_data_spec,
                                                 experience)
        kwargs = nest_utils.prune_extra_keys(self.train_argspec, kwargs)

    if self._enable_functions:
        loss_info = self._loss_fn(experience=experience, weights=weights, **kwargs)
    else:
        loss_info = self._loss(experience=experience, weights=weights, **kwargs)

    if not isinstance(loss_info, LossInfo):
        raise TypeError(
            "loss_info is not a subclass of LossInfo: {}".format(loss_info))
    return loss_info
def __call__(self, value: typing.Any) -> trajectory.Transition:
    """Converts `value` to a `Transition`, validating and pruning it.

    - A transition-like `value` is converted with `_as_tfa_transition` and
      then only validated.
    - A trajectory-like `value` is validated and converted with
      `trajectory.to_transition`. With `squeeze_time_dim = True` it must have
      tensors with outer dims `[B, T=2]`, and the resulting singleton time
      dimension is removed.
    - If `prepend_t0_to_next_time_step` is set, the first time slice of
      `time_step.observation` is concatenated in front of
      `next_time_step.observation` along axis 1.

    Args:
      value: A `Trajectory` or `Transition` object to convert.

    Returns:
      A validated and pruned `Transition`. If `squeeze_time_dim = True`, the
      resulting `Transition` has tensors with shape `[B, ...]`. Otherwise, the
      tensors will have shape `[B, T - 1, ...]`.

    Raises:
      TypeError: If `value` is not one of `Trajectory` or `Transition`.
      ValueError: If `value` has structure that doesn't match the converter's
        spec.
      TypeError: If `value` has a structure that doesn't match the converter's
        spec.
      ValueError: If `squeeze_time_dim=True` and `value` is a `Trajectory`
        with a time dimension having value other than `T=2`.
    """
    if _is_transition_like(value):
        converted = _as_tfa_transition(value)
    elif _is_trajectory_like(value):
        seq_len = 2 if self._squeeze_time_dim else None
        _validate_trajectory(value, self._data_context.trajectory_spec,
                             sequence_length=seq_len)
        converted = trajectory.to_transition(value)
        if self._squeeze_time_dim:
            # Drop the now-singleton time dimension.
            converted = tf.nest.map_structure(
                lambda t: composite.squeeze(t, axis=1), converted)
    else:
        raise TypeError('Input type not supported: {}'.format(value))

    outer_dims = 1 if self._squeeze_time_dim else 2
    _validate_transition(converted, self._data_context.transition_spec,
                         outer_dims)
    converted = nest_utils.prune_extra_keys(self._data_context.transition_spec,
                                            converted)

    if not self._prepend_t0_to_next_time_step:
        return converted

    # Useful with sequential models: lets the target network see the whole
    # observation sequence, including t0.
    # NOTE(review): `[:, :1, ...]` treats axis 1 as time. If
    # `squeeze_time_dim` were also True, axis 1 would no longer be time;
    # confirm the two options are never combined.
    extended_obs = tf.nest.map_structure(
        lambda first, rest: tf.concat([first[:, :1, ...], rest], axis=1),
        converted.time_step.observation,
        converted.next_time_step.observation)
    next_step = converted.next_time_step._replace(observation=extended_obs)
    return converted._replace(next_time_step=next_step)