Example #1
0
    def _check_trajectory_dimensions(self, experience):
        """Checks the given Trajectory for batch and time outer dimensions.

        Args:
            experience: A nest of Tensors expected to match
                `self.training_data_spec`, with `self._num_outer_dims` outer
                dimensions prepended to each spec shape.

        Raises:
            ValueError: If any Tensor in `experience` lacks the expected outer
                dimensions, or if `self._num_outer_dims == 2` and a tensor's
                time axis does not equal `self.train_sequence_length`.
        """
        if not nest_utils.is_batched_nested_tensors(
                experience,
                self.training_data_spec,
                num_outer_dims=self._num_outer_dims,
                allow_extra_fields=True,
        ):
            # Only build the (potentially large) shape debug structures when
            # we are actually going to raise.
            debug_str_1 = tf.nest.map_structure(lambda tp: tp.shape,
                                                experience)
            debug_str_2 = tf.nest.map_structure(lambda spec: spec.shape,
                                                self.training_data_spec)

            if self._num_outer_dims == 2:
                raise ValueError(
                    "All of the Tensors in `experience` must have two outer "
                    "dimensions: batch size and time. Specifically, tensors should be "
                    "shaped as [B x T x ...].\n"
                    "Full shapes of experience tensors:\n{}.\n"
                    "Full expected shapes (minus outer dimensions):\n{}.".
                    format(debug_str_1, debug_str_2))
            else:
                # self._num_outer_dims must be 1.
                raise ValueError(
                    "All of the Tensors in `experience` must have a single outer "
                    "batch_size dimension. If you also want to include an outer time "
                    "dimension, set num_outer_dims=2 when initializing your agent.\n"
                    "Full shapes of experience tensors:\n{}.\n"
                    "Full expected shapes (minus batch_size dimension):\n{}.".
                    format(debug_str_1, debug_str_2))

        # If we have a time dimension and a train_sequence_length, make sure they
        # match.
        if self._num_outer_dims == 2 and self.train_sequence_length is not None:

            def check_shape(path, t):  # pylint: disable=invalid-name
                if t.shape[1] != self.train_sequence_length:
                    debug_str = tf.nest.map_structure(lambda tp: tp.shape,
                                                      experience)
                    # Fixed message: was "at least one the Tensors" (missing
                    # "of"); now consistent with _validate_trajectory's text.
                    raise ValueError(
                        "The agent was configured to expect a `train_sequence_length` "
                        "of '{seq_len}'. Experience is expected to be shaped `[Batch x "
                        "Trajectory_sequence_length x spec.shape]` but at least one of the "
                        "Tensors in `experience` has a time axis dim value '{t_dim}' vs "
                        "the expected '{seq_len}'.\nFirst such tensor is:\n\t"
                        "experience.{path}. \nFull shape structure of "
                        "experience:\n\t{debug_str}".format(
                            seq_len=self.train_sequence_length,
                            t_dim=t.shape[1],
                            path=path,
                            debug_str=debug_str))

            nest_utils.map_structure_with_paths(check_shape, experience)
Example #2
0
def _validate_trajectory(
    value: trajectory.Trajectory,
    trajectory_spec: trajectory.Trajectory,
    sequence_length: typing.Optional[int],
    num_outer_dims: te.Literal[1, 2] = 2):  # pylint: disable=bad-whitespace
  """Validates that `value` matches `trajectory_spec` and `sequence_length`.

  Checks that every tensor in `value` carries the expected outer dimensions
  (`[B]` or `[B, T]` depending on `num_outer_dims`) and, when
  `sequence_length` is given, that each tensor's time axis equals it.

  Raises:
    ValueError: On any shape mismatch; the message includes the full shape
      structures of both `value` and `trajectory_spec`.
  """
  properly_batched = nest_utils.is_batched_nested_tensors(
      value, trajectory_spec, num_outer_dims=num_outer_dims,
      allow_extra_fields=True)
  if not properly_batched:
    actual_shapes = tf.nest.map_structure(lambda t: t.shape, value)
    expected_shapes = tf.nest.map_structure(
        lambda s: s.shape, trajectory_spec)

    if num_outer_dims == 2:
      dims_desc = 'two outer dimensions'
      prefix_desc = '[B, T]'
    else:
      dims_desc = 'one outer dimension'
      prefix_desc = '[B]'
    raise ValueError(
        'All of the Tensors in `value` must have {shape_str}. Specifically, '
        'tensors must have shape `{shape_prefix_str} + spec.shape`.\n'
        'Full shapes of value tensors:\n  {debug_str_1}.\n'
        'Expected shapes (excluding the {shape_str}):\n  {debug_str_2}.'
        .format(
            shape_str=dims_desc,
            debug_str_1=actual_shapes,
            debug_str_2=expected_shapes,
            shape_prefix_str=prefix_desc))

  if sequence_length is None:
    return

  # Every tensor's time axis (dim 1) must match the configured length.
  def check_shape(path, tensor):  # pylint: disable=invalid-name
    if tensor.shape[1] == sequence_length:
      return
    full_shapes = tf.nest.map_structure(lambda t: t.shape, value)
    raise ValueError(
        'The agent was configured to expect a `sequence_length` '
        'of \'{seq_len}\'. Value is expected to be shaped `[B, T] + '
        'spec.shape` but at least one of the Tensors in `value` has a '
        'time axis dim value \'{t_dim}\' vs '
        'the expected \'{seq_len}\'.\nFirst such tensor is:\n\t'
        'value.{path}. \nFull shape structure of '
        'value:\n\t{debug_str}'.format(
            seq_len=sequence_length,
            t_dim=tensor.shape[1],
            path=path,
            debug_str=full_shapes))

  nest_utils.map_structure_with_paths(check_shape, value)
Example #3
0
def merge_to_parameters_from_dict(
    value: Params, params_dict: Mapping[Text, Any]) -> Params:
  """Merges dict matching data of `parameters_to_dict(value)` to a new `Params`.

  For more details, see the example below and the documentation of
  `parameters_to_dict`.

  Example:

  ```python
  scale_matrix = tf.Variable([[1.0, 2.0], [-1.0, 0.0]])
  d = tfp.distributions.MultivariateNormalDiag(
      loc=[1.0, 1.0], scale_diag=[2.0, 3.0], validate_args=True)
  b = tfp.bijectors.ScaleMatvecLinearOperator(
      scale=tf.linalg.LinearOperatorFullMatrix(matrix=scale_matrix),
      adjoint=True)
  b_d = b(d)
  p = utils.get_parameters(b_d)

  params_dict = utils.parameters_to_dict(p)
  params_dict["bijector"]["scale"]["matrix"] = new_scale_matrix

  new_params = utils.merge_to_parameters_from_dict(
    p, params_dict)

  # new_d is a `ScaleMatvecLinearOperator()(MultivariateNormalDiag)` with
  # a new scale matrix.
  new_d = utils.make_from_parameters(new_params)
  ```

  Args:
    value: A `Params` from which `params_dict` was derived.
    params_dict: A nested `dict` created by e.g. calling
      `parameters_to_dict(value)` and  modifying it to modify parameters.
      **NOTE** If any keys in the dict are missing, the "default" value in
      `value` is used instead.

  Returns:
    A new `Params` object which can then be turned into e.g. a
    `tfp.Distribution` via `make_from_parameters`.

  Raises:
    ValueError: If `params_dict` has keys missing from `value.params`.
    KeyError: If a subdict entry is missing for a nested value in
      `value.params`.
  """
  new_params = {}
  # A missing dict means "no overrides": every lookup below then falls back
  # to the corresponding default stored in `value.params`.
  if params_dict is None:
    params_dict = {}

  # Flattened keys (e.g. "loc", or "scale:matrix" for nested entries) that
  # were matched against `params_dict`; used at the end to detect keys in
  # `params_dict` that do not correspond to anything in `value.params`.
  processed_params = set()
  for k, v in value.params.items():
    # pylint: disable=cell-var-from-loop
    # For a nested value under `k`: `visited` collects every flattened
    # sub-key encountered, `converted` the subset actually found in
    # `params_dict`.  A non-empty strict subset (some provided, some not)
    # is rejected below as partial information.
    visited = set()
    converted = set()

    def convert(params_k, p):
      # `params_k` is the structural path within a nested value (None for a
      # flat value); nested entries are keyed as "<k>:<params_k>" in
      # `params_dict`.  `k` is deliberately captured from the enclosing loop
      # (hence the cell-var-from-loop pragma above).
      if params_k is not None:
        params_key = "{}:{}".format(k, params_k)
        visited.add(params_key)
        params_dict_value = params_dict.get(params_key, None)
        if params_dict_value is not None:
          converted.add(params_key)
      else:
        params_key = k
        params_dict_value = params_dict.get(k, None)
      processed_params.add(params_key)
      if isinstance(p, Params):
        # Recurse into a nested Params; the dict entry (possibly None)
        # carries the sub-dict of overrides for it.
        return merge_to_parameters_from_dict(p, params_dict_value)
      else:
        # Use the override when provided, otherwise keep the default from
        # `value.params`.
        return params_dict_value if params_dict_value is not None else p
    # pylint: enable=cell-var-from-loop

    if tf.nest.is_nested(v):
      new_params[k] = nest_utils.map_structure_with_paths(convert, v)
      # Reject partially-specified nested overrides: if any sub-key was
      # supplied, all of them must be.
      if converted and visited != converted:
        raise KeyError(
            "Only saw partial information from the dictionary for nested "
            "key '{}' in params_dict.  Entries provided: {}.  "
            "Entries required: {}"
            .format(k, sorted(converted), sorted(visited)))
    else:
      new_params[k] = convert(None, v)

  # Any leftover keys were never matched by the loop above, i.e. they do not
  # correspond to entries of `value.params`.
  unvisited_params_keys = set(params_dict) - processed_params
  if unvisited_params_keys:
    raise ValueError(
        "params_dict had keys that were not part of value.params.  "
        "params_dict keys: {}, value.params processed keys: {}".format(
            sorted(params_dict.keys()), sorted(processed_params)))

  return Params(type_=value.type_, params=new_params)