def update(self, tensor, outer_dims=(0, )):
    """Builds the op that refreshes this normalizer's running statistics.

    The incoming nest is validated against `self._tensor_spec`, each leaf is
    cast to the dtype returned by `self.map_dtype` for that leaf's dtype, and
    the resulting nest is handed to `self._update_ops`.

    Args:
      tensor: Nest of tensors whose dtypes and inner shapes match
        `self._tensor_spec`.
      outer_dims: Forwarded unchanged to `self._update_ops`.

    Returns:
      The result of `tf.group` applied to the ops returned by
      `self._update_ops`.
    """
    nest_utils.assert_matching_dtypes_and_inner_shapes(
        tensor,
        self._tensor_spec,
        caller=self,
        tensors_name='tensor',
        specs_name='tensor_spec')

    def _cast_leaf(leaf):
        # The target dtype is derived from the leaf's own dtype.
        return tf.cast(leaf, self.map_dtype(leaf.dtype))

    cast_tensor = tf.nest.map_structure(_cast_leaf, tensor)
    pending_ops = self._update_ops(cast_tensor, outer_dims)
    return tf.group(pending_ops)
def _update_ops(self, tensors, outer_dims):
    """Creates the assign ops that fold a new batch into the stored statistics.

    For every flattened leaf, the batch mean and the batch sum of squared
    differences from that mean are computed over the outer (batch) axes. They
    are combined with the stored `_count`, `_avg`, `_m2` and `_m2_carry`
    entries through `parallel_variance_calculation`, and the combined values
    are assigned back to those variables.

    Args:
      tensors: The tensors of values to be normalized.
      outer_dims: Ignored. The batch dimensions are extracted by comparing the
        tensors with `self._tensor_spec`.

    Returns:
      A list of ops which, when run, update all necessary normalization
      variables (four assign ops per flattened leaf).
    """
    del outer_dims  # Unused; the outer shape is recovered from the spec below.
    nest_utils.assert_matching_dtypes_and_inner_shapes(
        tensors,
        self._tensor_spec,
        caller=self,
        tensors_name='tensors',
        specs_name='tensor_spec')

    batch_shape = nest_utils.get_outer_shape(tensors, self._tensor_spec)
    static_outer_rank = tf.compat.dimension_value(batch_shape.shape[0])
    if static_outer_rank is None:
        # Number of outer dims is only known at run time.
        reduce_axes = tf.range(tf.size(batch_shape))
    else:
        reduce_axes = list(range(static_outer_rank))

    batch_count = tf.cast(tf.reduce_prod(batch_shape), self.dtype)

    leaves = tf.nest.flatten(tensors)
    ops = []
    for idx, leaf in enumerate(leaves):
        values = tf.cast(leaf, self.dtype)
        batch_avg = tf.math.reduce_mean(values, reduce_axes)
        batch_m2 = tf.math.reduce_sum(
            tf.math.squared_difference(values, batch_avg), reduce_axes)

        combined = parallel_variance_calculation(
            batch_count,
            batch_avg,
            batch_m2,
            self._count[idx],
            self._avg[idx],
            self._m2[idx],
            self._m2_carry[idx])
        new_count, new_avg, new_m2, new_m2_carry = combined

        # All combined values must be computed before any variable is
        # overwritten, since they are derived from the current variables.
        with tf.control_dependencies(
                [new_count, new_avg, new_m2, new_m2_carry]):
            ops.append(self._count[idx].assign(new_count))
            ops.append(self._avg[idx].assign(new_avg))
            ops.append(self._m2[idx].assign(new_m2))
            ops.append(self._m2_carry[idx].assign(new_m2_carry))

    return ops
def test_loss_and_train_output(test: test_utils.TestCase,
                               expect_equal_loss_values: bool,
                               agent: tf_agent.TFAgent,
                               experience: types.NestedTensor,
                               weights: Optional[types.Tensor] = None,
                               **kwargs):
    """Checks that `agent.train()` and `agent.loss()` produce comparable output.

    Both methods are called with the same arguments. The results must be
    `LossInfo` objects of the same type with matching dtypes and inner shapes;
    their `loss` fields are compared for equality or inequality depending on
    `expect_equal_loss_values`.

    Args:
      test: An instance of `test_utils.TestCase` used to make the assertions.
      expect_equal_loss_values: Whether `LossInfo.loss` is expected to hold the
        same value for `loss()` and `train()`.
      agent: An instance of `TFAgent`.
      experience: A batch of experience data in the form of a `Trajectory`.
      weights: (optional). A `Tensor` containing weights to be used when
        calculating the total loss.
      **kwargs: Any additional data passed as args to `train` and `loss`.
    """
    # `train()` is invoked first, then `loss()`, with identical arguments.
    train_info = agent.train(experience=experience, weights=weights, **kwargs)
    loss_info = agent.loss(experience=experience, weights=weights, **kwargs)

    test.assertIsInstance(train_info, tf_agent.LossInfo)
    test.assertEqual(type(train_info), type(loss_info))

    # Compare loss values.
    train_loss = train_info.loss
    eval_loss = loss_info.loss
    if not expect_equal_loss_values:
        different_msg = (
            'Expected train() and loss() output to have different loss values, '
            'but both are {loss}.'.format(loss=train_loss))
        test.assertNotEqual(train_loss, eval_loss, msg=different_msg)
    else:
        equal_msg = (
            'Expected equal loss values, but train() has output '
            '{loss_from_train} vs loss() output {loss_from_loss}.'.format(
                loss_from_train=train_loss, loss_from_loss=eval_loss))
        test.assertEqual(train_loss, eval_loss, msg=equal_msg)

    # Check that both `LossInfo` outputs have matching dtypes and inner shapes.
    nest_utils.assert_matching_dtypes_and_inner_shapes(
        train_info,
        loss_info,
        test,
        '`LossInfo` from train()',
        '`LossInfo` from loss()')
def _check_train_argspec(self, kwargs):
    """Validates the `kwargs` given to `train()` against `self.train_argspec`.

    The check is delegated to
    `nest_utils.assert_matching_dtypes_and_inner_shapes` with
    `allow_extra_fields=True`; any exception it raises propagates to the
    caller.

    Args:
      kwargs: The `kwargs` passed to `train()`.

    Raises:
      AttributeError: If `kwargs` keyset doesn't match `train_argspec`.
      ValueError: If `kwargs` do not match the specs in `train_argspec`.
    """
    expected_spec = self.train_argspec
    nest_utils.assert_matching_dtypes_and_inner_shapes(
        kwargs,
        expected_spec,
        caller=self,
        allow_extra_fields=True,
        tensors_name="`kwargs`",
        specs_name="`train_argspec`")
def testNestedNestWithNestedState(self):
    """Runs a nested `NestMap` holding a stateful LSTM and checks its specs.

    The outer `NestMap` wraps a `Dense` layer and a dict whose entry is an
    inner `NestMap` around an LSTM that returns state. Both the output and the
    returned state are validated against hand-written expected specs.
    """

    def _float_vector_spec(size):
        # Expected inner spec: a float32 vector of the given length.
        return tf.TensorSpec(dtype=tf.float32, shape=(size, ))

    # layer structure: (., {'a': {'b': .}})
    dense_layer = tf.keras.layers.Dense(7)
    lstm_layer = tf.keras.layers.LSTM(
        8, return_state=True, return_sequences=True)
    inner_map = nest_map.NestMap({'b': lstm_layer})
    net = nest_map.NestMap((dense_layer, {'a': inner_map}))

    # TODO(b/177337002): remove the forced tuple wrapping the LSTM
    # state once we make a generic KerasWrapper network and clean up
    # Sequential and NestMap to use that instead of singleton Sequential.
    inputs = (tf.ones((1, 2)), {'a': {'b': tf.ones((1, 2))}})
    initial_state = ((), {
        'a': {
            'b': ((tf.ones((1, 8)), tf.ones((1, 8))), )
        }
    })
    out, state = net(inputs, network_state=initial_state)

    out_expected = (_float_vector_spec(7), {'a': {'b': _float_vector_spec(8)}})
    nest_utils.assert_matching_dtypes_and_inner_shapes(
        out,
        out_expected,
        caller=self,
        tensors_name='out',
        specs_name='out_expected')

    state_expected = ((), {
        'a': {
            'b': ((_float_vector_spec(8), _float_vector_spec(8)), )
        }
    })
    nest_utils.assert_matching_dtypes_and_inner_shapes(
        state,
        state_expected,
        caller=self,
        tensors_name='state',
        specs_name='state_expected')
def normalize(self,
              tensor,
              clip_value=5.0,
              center_mean=True,
              variance_epsilon=1e-3):
    """Normalizes `tensor` with the current mean and variance estimates.

    The input nest is validated against `self._tensor_spec`, flattened, and
    each leaf is cast to the dtype returned by `self.map_dtype`. The leaves are
    then passed through `tf.nn.batch_normalization` (no offset, no scale),
    optionally clipped, packed back into the structure of `self._tensor_spec`
    and cast to the dtypes of that spec.

    Args:
      tensor: Tensor to normalize.
      clip_value: Clips normalized observations between +/- this value if
        clip_value > 0, otherwise does not apply clipping.
      center_mean: If true, subtracts off mean from normalized tensor;
        otherwise zeros shaped like the mean estimate are used instead.
      variance_epsilon: Epsilon to avoid division by zero in normalization.

    Returns:
      normalized_tensor: Tensor after applying normalization.
    """
    nest_utils.assert_matching_dtypes_and_inner_shapes(
        tensor,
        self._tensor_spec,
        caller=self,
        tensors_name='tensors',
        specs_name='tensor_spec')

    flat_inputs = []
    for leaf in tf.nest.flatten(tensor):
        flat_inputs.append(tf.cast(leaf, self.map_dtype(leaf.dtype)))

    # NOTE(review): the name scope is taken to cover estimation, normalization
    # and clipping; confirm against the original formatting.
    with tf.name_scope(self._scope + '/normalize'):
        mean_estimate, var_estimate = self._get_mean_var_estimates()
        if center_mean:
            mean = mean_estimate
        else:
            mean = tf.nest.map_structure(tf.zeros_like, mean_estimate)

        def _standardize(leaf, leaf_mean, leaf_var):
            return tf.nn.batch_normalization(
                leaf,
                leaf_mean,
                leaf_var,
                offset=None,
                scale=None,
                variance_epsilon=variance_epsilon,
                name='normalized_tensor')

        result = nest_utils.map_structure_up_to(
            self._flat_variable_spec,
            _standardize,
            flat_inputs,
            mean,
            var_estimate,
            check_types=False)

        if clip_value > 0:
            result = tf.nest.map_structure(
                lambda leaf: tf.clip_by_value(
                    leaf,
                    -clip_value,
                    clip_value,
                    name='clipped_normalized_tensor'),
                result)

    # Restore the caller's nest structure and the dtypes declared in the spec.
    result = tf.nest.pack_sequence_as(self._tensor_spec, result)
    return tf.nest.map_structure(
        lambda leaf, spec: tf.cast(leaf, spec.dtype),
        result,
        self._tensor_spec)
def __call__(self, inputs, *args, **kwargs):
    """Validates arguments, invokes `Network.call`, and validates the new state.

    A typical `call` method in a class subclassing `Network` accepts `inputs`
    plus other `*args` and `**kwargs`. It can optionally also accept
    `step_type` and `network_state` (the latter when `state_spec` is not the
    trivial `()`), e.g.:

    ```python
    def call(self, inputs, step_type=None, network_state=(), training=False):
      ...
      return outputs, new_network_state
    ```

    The first argument (`inputs`) is validated against
    `self.input_tensor_spec` if one is available. If a non-empty
    `network_state` is given it is validated against `self.state_spec`.

    The return value of `call` is expected to be a tuple/list with 2 values:
    `(output, new_state)`; `new_state` is validated against `self.state_spec`.

    If no `network_state` kwarg is given (or an empty `network_state = ()` is
    given), it is up to `call` to assume a proper "empty" state and to emit an
    appropriate `output_state`.

    `step_type` is dropped before dispatch when `call` neither names it nor
    takes `**kwargs`; an empty `network_state` is dropped under the same
    condition.

    Args:
      inputs: The input to `self.call`, matching `self.input_tensor_spec`.
      *args: Additional arguments to `self.call`.
      **kwargs: Additional keyword arguments to `self.call`. These can include
        `network_state` and `step_type`. `step_type` is required if the
        network's `call` requires it. `network_state` is required if the
        underlying network's `call` requires it.

    Returns:
      A tuple `(outputs, new_network_state)`.
    """
    if self.input_tensor_spec is not None:
        nest_utils.assert_matching_dtypes_and_inner_shapes(
            inputs,
            self.input_tensor_spec,
            allow_extra_fields=True,
            caller=self,
            tensors_name="`inputs`",
            specs_name="`input_tensor_spec`")

    argspec = tf_inspect.getargspec(self.call)

    def _call_cannot_take(arg_name):
        # True when `call` neither names the argument nor accepts **kwargs.
        return not (arg_name in argspec.args or argspec.keywords)

    # Convert *args, **kwargs to a canonical kwarg representation.
    call_kwargs = tf_inspect.getcallargs(self.call, inputs, *args, **kwargs)
    # TODO(b/156315434): Rename network_state to just state.
    network_state = call_kwargs.get("network_state")
    call_kwargs.pop("self", None)

    if common.safe_has_state(network_state):
        nest_utils.assert_matching_dtypes_and_inner_shapes(
            network_state,
            self.state_spec,
            allow_extra_fields=True,
            caller=self,
            tensors_name="`network_state`",
            specs_name="`state_spec`")

    if _call_cannot_take("step_type"):
        call_kwargs.pop("step_type", None)

    # The emptiness test is evaluated first, as in the original condition.
    state_is_empty = network_state in (None, ())
    if state_is_empty and _call_cannot_take("network_state"):
        call_kwargs.pop("network_state", None)

    outputs, new_state = super(Network, self).__call__(**call_kwargs)

    nest_utils.assert_matching_dtypes_and_inner_shapes(
        new_state,
        self.state_spec,
        allow_extra_fields=True,
        caller=self,
        tensors_name="`new_state`",
        specs_name="`state_spec`")

    return outputs, new_state