示例#1
0
    def update(self, tensor, outer_dims=(0, )):
        """Updates the normalizer variables from `tensor`.

        Args:
          tensor: Nest of tensors. It is validated against
            `self._tensor_spec` and each entry is cast to the dtype chosen by
            `self.map_dtype` before the update ops are built.
          outer_dims: Forwarded unchanged to `self._update_ops`.

        Returns:
          The result of `tf.group` applied to the ops returned by
          `self._update_ops`.
        """
        nest_utils.assert_matching_dtypes_and_inner_shapes(
            tensor, self._tensor_spec, caller=self,
            tensors_name='tensor', specs_name='tensor_spec')

        def _to_mapped_dtype(entry):
            # `map_dtype` decides which dtype each incoming entry is cast to.
            return tf.cast(entry, self.map_dtype(entry.dtype))

        cast_tensor = tf.nest.map_structure(_to_mapped_dtype, tensor)
        update_ops = self._update_ops(cast_tensor, outer_dims)
        return tf.group(update_ops)
示例#2
0
    def _update_ops(self, tensors, outer_dims):
        """Builds the ops that fold `tensors` into the normalizer variables.

        For every flattened entry, the mean and the sum of squared
        differences from that mean are reduced over the outer (batch) axes,
        then combined with the stored count / average / m2 / m2-carry
        variables through `parallel_variance_calculation`.

        Args:
          tensors: The tensors of values to be normalized.
          outer_dims: Ignored.  The batch dimensions are extracted by
            comparing the associated tensor with the specs.

        Returns:
          A list of ops which, when run, will update all necessary
          normalization variables.
        """
        del outer_dims  # The outer axes are derived from the specs instead.

        nest_utils.assert_matching_dtypes_and_inner_shapes(
            tensors, self._tensor_spec, caller=self,
            tensors_name='tensors', specs_name='tensor_spec')

        batch_shape = nest_utils.get_outer_shape(tensors, self._tensor_spec)
        static_rank = tf.compat.dimension_value(batch_shape.shape[0])
        if static_rank is None:
            # Outer rank unknown statically: build the axes as a tensor.
            reduce_axes = tf.range(tf.size(batch_shape))
        else:
            reduce_axes = list(range(static_rank))

        # Number of samples contributed by this batch.
        batch_count = tf.cast(tf.reduce_prod(batch_shape), self.dtype)

        assign_ops = []
        for idx, value in enumerate(tf.nest.flatten(tensors)):
            value = tf.cast(value, self.dtype)
            batch_mean = tf.math.reduce_mean(value, reduce_axes)
            batch_m2 = tf.math.reduce_sum(
                tf.math.squared_difference(value, batch_mean), reduce_axes)

            # Running statistics stored for this entry.
            count_var = self._count[idx]
            avg_var = self._avg[idx]
            m2_var = self._m2[idx]
            carry_var = self._m2_carry[idx]

            new_count, new_avg, new_m2, new_carry = (
                parallel_variance_calculation(
                    batch_count, batch_mean, batch_m2,
                    count_var, avg_var, m2_var, carry_var))

            # All merged values must be computed before any variable is
            # overwritten.
            merged = [new_count, new_avg, new_m2, new_carry]
            with tf.control_dependencies(merged):
                assign_ops.append(count_var.assign(new_count))
                assign_ops.append(avg_var.assign(new_avg))
                assign_ops.append(m2_var.assign(new_m2))
                assign_ops.append(carry_var.assign(new_carry))

        return assign_ops
示例#3
0
def test_loss_and_train_output(test: test_utils.TestCase,
                               expect_equal_loss_values: bool,
                               agent: tf_agent.TFAgent,
                               experience: types.NestedTensor,
                               weights: Optional[types.Tensor] = None,
                               **kwargs):
  """Checks that `agent.train()` and `agent.loss()` return comparable outputs.

  `train()` is called first, then `loss()`, both with the same arguments.  The
  two results must have the same type, and matching dtypes and inner shapes.
  Their `loss` fields are required to be equal or unequal depending on
  `expect_equal_loss_values`.

  Args:
    test: An instance of `test_utils.TestCase`; used for all assertions.
    expect_equal_loss_values: Whether to expect `LossInfo.loss` to have the same
      values for loss() and train().
    agent: An instance of `TFAgent`.
    experience: A batch of experience data in the form of a `Trajectory`.
    weights: (optional).  A `Tensor` containing weights to be used when
      calculating the total loss.
    **kwargs: Any additional data as args to `train` and `loss`.
  """
  train_result = agent.train(experience=experience, weights=weights, **kwargs)
  loss_result = agent.loss(experience=experience, weights=weights, **kwargs)

  test.assertIsInstance(train_result, tf_agent.LossInfo)
  test.assertEqual(type(train_result), type(loss_result))

  # Compare loss values.
  train_loss = train_result.loss
  loss_only_loss = loss_result.loss
  if not expect_equal_loss_values:
    same_value_msg = (
        'Expected train() and loss() output to have different loss values, '
        'but both are {loss}.'.format(loss=train_loss))
    test.assertNotEqual(train_loss, loss_only_loss, msg=same_value_msg)
  else:
    mismatch_msg = (
        'Expected equal loss values, but train() has output '
        '{loss_from_train} vs loss() output {loss_from_loss}.'.format(
            loss_from_train=train_loss, loss_from_loss=loss_only_loss))
    test.assertEqual(train_loss, loss_only_loss, msg=mismatch_msg)

  # Check that both `LossInfo` outputs have matching dtypes and inner shapes.
  nest_utils.assert_matching_dtypes_and_inner_shapes(
      train_result, loss_result, test,
      '`LossInfo` from train()', '`LossInfo` from loss()')
示例#4
0
    def _check_train_argspec(self, kwargs):
        """Validates the `train()` keyword arguments against the argspec.

        The check is delegated to
        `nest_utils.assert_matching_dtypes_and_inner_shapes`, with extra
        fields allowed.

        Args:
          kwargs: The `kwargs` passed to `train()`.

        Raises:
          AttributeError: If `kwargs` keyset doesn't match `train_argspec`.
          ValueError: If `kwargs` do not match the specs in `train_argspec`.
        """
        expected_specs = self.train_argspec
        nest_utils.assert_matching_dtypes_and_inner_shapes(
            kwargs, expected_specs,
            allow_extra_fields=True, caller=self,
            tensors_name="`kwargs`", specs_name="`train_argspec`")
示例#5
0
 def testNestedNestWithNestedState(self):
     """Checks output and state specs of a NestMap nested in a NestMap."""

     def _float_spec(size):
         # Per-example spec of a float32 vector with `size` entries.
         return tf.TensorSpec(dtype=tf.float32, shape=(size, ))

     # layer structure: (., {'a': {'b': .}})
     dense = tf.keras.layers.Dense(7)
     lstm = tf.keras.layers.LSTM(8,
                                 return_state=True,
                                 return_sequences=True)
     inner_map = nest_map.NestMap({'b': lstm})
     net = nest_map.NestMap((dense, {'a': inner_map}))

     inputs = (tf.ones((1, 2)), {'a': {'b': tf.ones((1, 2))}})
     # TODO(b/177337002): remove the forced tuple wrapping the LSTM
     # state once we make a generic KerasWrapper network and clean up
     # Sequential and NestMap to use that instead of singleton Sequential.
     lstm_state = ((tf.ones((1, 8)), tf.ones((1, 8))), )
     initial_state = ((), {'a': {'b': lstm_state}})

     out, state = net(inputs, network_state=initial_state)

     out_expected = (_float_spec(7), {'a': {'b': _float_spec(8)}})
     nest_utils.assert_matching_dtypes_and_inner_shapes(
         out, out_expected,
         caller=self, tensors_name='out', specs_name='out_expected')

     lstm_state_expected = ((_float_spec(8), _float_spec(8)), )
     state_expected = ((), {'a': {'b': lstm_state_expected}})
     nest_utils.assert_matching_dtypes_and_inner_shapes(
         state, state_expected,
         caller=self, tensors_name='state', specs_name='state_expected')
示例#6
0
    def normalize(self,
                  tensor,
                  clip_value=5.0,
                  center_mean=True,
                  variance_epsilon=1e-3):
        """Normalizes `tensor` with the current mean / variance estimates.

        The input is validated against `self._tensor_spec`, flattened, and
        cast with `self.map_dtype`.  Each entry then goes through
        `tf.nn.batch_normalization` (no offset, no scale) and is optionally
        clipped.  The result is packed back into the structure of
        `self._tensor_spec` and cast to the spec dtypes.

        Args:
          tensor: Tensor to normalize.
          clip_value: Clips normalized observations between +/- this value
            if clip_value > 0, otherwise does not apply clipping.
          center_mean: If true, subtracts off mean from normalized tensor;
            otherwise a zero mean is used.
          variance_epsilon: Epsilon to avoid division by zero in
            normalization.

        Returns:
          normalized_tensor: Tensor after applying normalization.
        """
        nest_utils.assert_matching_dtypes_and_inner_shapes(
            tensor, self._tensor_spec, caller=self,
            tensors_name='tensors', specs_name='tensor_spec')

        flat_inputs = [
            tf.cast(entry, self.map_dtype(entry.dtype))
            for entry in tf.nest.flatten(tensor)
        ]

        with tf.name_scope(self._scope + '/normalize'):
            mean_estimate, var_estimate = self._get_mean_var_estimates()
            if center_mean:
                mean = mean_estimate
            else:
                # Without centering, normalize around zero.
                mean = tf.nest.map_structure(tf.zeros_like, mean_estimate)

            def _batch_norm(value, value_mean, value_var):
                return tf.nn.batch_normalization(
                    value, value_mean, value_var,
                    offset=None, scale=None,
                    variance_epsilon=variance_epsilon,
                    name='normalized_tensor')

            flat_result = nest_utils.map_structure_up_to(
                self._flat_variable_spec, _batch_norm,
                flat_inputs, mean, var_estimate,
                check_types=False)

            if clip_value > 0:
                flat_result = tf.nest.map_structure(
                    lambda t: tf.clip_by_value(
                        t, -clip_value, clip_value,
                        name='clipped_normalized_tensor'),
                    flat_result)

        # Restore the caller's structure and the dtypes declared in the spec.
        packed = tf.nest.pack_sequence_as(self._tensor_spec, flat_result)
        return tf.nest.map_structure(
            lambda t, spec: tf.cast(t, spec.dtype), packed,
            self._tensor_spec)
示例#7
0
    def __call__(self, inputs, *args, **kwargs):
        """A wrapper around `Network.call`.

        A typical `call` method in a class subclassing `Network` will have a
        signature that accepts `inputs`, as well as other `*args` and
        `**kwargs`.  `call` can optionally also accept `step_type` and
        `network_state` (if `state_spec` is not the trivial `()`).  e.g.:

        ```python
        def call(self,
                 inputs,
                 step_type=None,
                 network_state=(),
                 training=False):
            ...
            return outputs, new_network_state
        ```

        We will validate the first argument (`inputs`)
        against `self.input_tensor_spec` if one is available.

        If a `network_state` kwarg is given it is also validated against
        `self.state_spec`.  Similarly, the return value of the `call` method
        is expected to be a tuple/list with 2 values: `(output, new_state)`.
        We validate `new_state` against `self.state_spec`.

        If no `network_state` kwarg is given (or if an empty
        `network_state = ()` is given), it is up to `call` to assume a proper
        "empty" state, and to emit an appropriate `output_state`.

        Args:
          inputs: The input to `self.call`, matching
            `self.input_tensor_spec`.
          *args: Additional arguments to `self.call`.
          **kwargs: Additional keyword arguments to `self.call`.
            These can include `network_state` and `step_type`.  `step_type`
            is required if the network's `call` requires it.
            `network_state` is required if the underlying network's `call`
            requires it.

        Returns:
          A tuple `(outputs, new_network_state)`.
        """
        if self.input_tensor_spec is not None:
            nest_utils.assert_matching_dtypes_and_inner_shapes(
                inputs,
                self.input_tensor_spec,
                allow_extra_fields=True,
                caller=self,
                tensors_name="`inputs`",
                specs_name="`input_tensor_spec`")

        call_argspec = tf_inspect.getargspec(self.call)

        # Convert *args, **kwargs to a canonical kwarg representation.
        normalized_kwargs = tf_inspect.getcallargs(self.call, inputs, *args,
                                                   **kwargs)
        # TODO(b/156315434): Rename network_state to just state.
        network_state = normalized_kwargs.get("network_state", None)
        normalized_kwargs.pop("self", None)

        if common.safe_has_state(network_state):
            nest_utils.assert_matching_dtypes_and_inner_shapes(
                network_state,
                self.state_spec,
                allow_extra_fields=True,
                caller=self,
                tensors_name="`network_state`",
                specs_name="`state_spec`")

        # Only forward `step_type` if `call` names it or accepts **kwargs.
        if "step_type" not in call_argspec.args and not call_argspec.keywords:
            normalized_kwargs.pop("step_type", None)

        # Decide emptiness by identity/type rather than with
        # `network_state in (None, ())`: the membership test evaluates
        # `network_state == ()`, which for tensor- or array-valued states is
        # an elementwise comparison whose truth value may be ambiguous or
        # raise.
        state_is_empty = network_state is None or (
            isinstance(network_state, tuple) and len(network_state) == 0)

        # Drop an empty `network_state` when `call` cannot receive it.
        if (state_is_empty
                and "network_state" not in call_argspec.args
                and not call_argspec.keywords):
            normalized_kwargs.pop("network_state", None)

        outputs, new_state = super(Network, self).__call__(**normalized_kwargs)

        nest_utils.assert_matching_dtypes_and_inner_shapes(
            new_state,
            self.state_spec,
            allow_extra_fields=True,
            caller=self,
            tensors_name="`new_state`",
            specs_name="`state_spec`")

        return outputs, new_state