Example #1
    def signature(
        cls,
        environment_spec: mava_specs.EnvironmentSpec,
        extras_spec: tf.TypeSpec = {},
    ) -> tf.TypeSpec:

        # This function currently assumes that self._discount is a scalar.
        # If it ever becomes a nested structure and/or a np.ndarray, this method
        # will need to know its structure / shape. This is because the signature
        # discount shape is the environment's discount shape and this adder's
        # discount shape broadcasted together. Also, the reward shape is this
        # signature discount shape broadcasted together with the environment
        # reward shape. As long as self._discount is a scalar, it affects neither
        # the signature discount shape nor the signature reward shape, so we can
        # ignore it.

        agent_specs = environment_spec.get_agent_specs()
        agents = environment_spec.get_agent_ids()
        env_extras_spec = environment_spec.get_extra_specs()
        # Merge into a copy so the mutable default argument is never mutated.
        extras_spec = {**extras_spec, **env_extras_spec}

        obs_specs = {}
        act_specs = {}
        reward_specs = {}
        step_discount_specs = {}
        for agent in agents:

            rewards_spec, step_discounts_spec = tree_utils.broadcast_structures(
                agent_specs[agent].rewards, agent_specs[agent].discounts
            )

            rewards_spec = tree.map_structure(
                _broadcast_specs, rewards_spec, step_discounts_spec
            )
            step_discounts_spec = tree.map_structure(copy.deepcopy, step_discounts_spec)

            obs_specs[agent] = agent_specs[agent].observations
            act_specs[agent] = agent_specs[agent].actions
            reward_specs[agent] = rewards_spec
            step_discount_specs[agent] = step_discounts_spec

        transition_spec = [
            obs_specs,
            act_specs,
            extras_spec,
            reward_specs,
            step_discount_specs,
            obs_specs,  # next_observation
            extras_spec,
        ]

        return tree.map_structure_with_path(
            base.spec_like_to_tensor_spec, tuple(transition_spec)
        )
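
The comment at the top of this method reasons in terms of NumPy-style shape
broadcasting. A minimal NumPy sketch of that reasoning, independent of Mava's
`tree_utils.broadcast_structures` / `_broadcast_specs` helpers (which are assumed
here to follow ordinary broadcasting rules; the shapes are illustrative):

import numpy as np

# Hypothetical per-agent shapes taken from an environment spec.
reward_shape = (3,)    # e.g. one reward per objective
discount_shape = ()    # scalar per-step discount

# The signature reward shape is the two shapes broadcast together.
print(np.broadcast_shapes(reward_shape, discount_shape))  # -> (3,)

# Broadcasting in a scalar adder discount (self._discount) changes nothing,
# which is why the method can ignore it as long as it stays a scalar.
adder_discount_shape = ()
print(np.broadcast_shapes(reward_shape, discount_shape,
                          adder_discount_shape))          # -> (3,)
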
Example #2
    def signature(
        cls,
        environment_spec: specs.EnvironmentSpec,
        extras_spec: tf.TypeSpec = {},
    ) -> tf.TypeSpec:
        """This is a helper method for generating signatures for Reverb tables.
        Signatures are useful for validating data types and shapes, see Reverb's
        documentation for details on how they are used.
        Args:
          environment_spec: A `specs.EnvironmentSpec` whose fields are nested
            structures with leaf nodes that have `.shape` and `.dtype` attributes.
            This should come from the environment that will be used to generate
            the data inserted into the Reverb table.
          extras_spec: A nested structure with leaf nodes that have `.shape` and
            `.dtype` attributes. The structure (and shapes/dtypes) of this must
            be the same as the `extras` passed into `ReverbAdder.add`.
        Returns:
          A `Step` whose leaf nodes are `tf.TensorSpec` objects.
        """
        agent_specs = environment_spec.get_agent_specs()
        agents = environment_spec.get_agent_ids()
        env_extras_spec = environment_spec.get_extra_specs()
        # Merge into a copy so the mutable default argument is never mutated.
        extras_spec = {**extras_spec, **env_extras_spec}

        obs_specs = {}
        act_specs = {}
        reward_specs = {}
        step_discount_specs = {}
        for agent in agents:
            rewards_spec, step_discounts_spec = tree_utils.broadcast_structures(
                agent_specs[agent].rewards, agent_specs[agent].discounts)
            obs_specs[agent] = agent_specs[agent].observations
            act_specs[agent] = agent_specs[agent].actions
            reward_specs[agent] = rewards_spec
            step_discount_specs[agent] = step_discounts_spec

        spec_step = base.Step(
            observations=obs_specs,
            actions=act_specs,
            rewards=reward_specs,
            discounts=step_discount_specs,
            start_of_episode=specs.Array(shape=(), dtype=bool),
            extras=extras_spec,
        )
        return tree.map_structure_with_path(base.spec_like_to_tensor_spec,
                                            spec_step)
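
The docstring notes that these signatures exist so Reverb can validate the data
written to a table. A rough sketch of how such a signature is typically wired
into a Reverb table, using a hand-built dict of `tf.TensorSpec`s in place of the
`Step` returned above (the table name, selectors, and sizes are illustrative,
not taken from Mava):

import reverb
import tensorflow as tf

# Hand-built stand-in for the nested structure of tf.TensorSpec leaves that
# signature() returns.
table_signature = {
    'observations': {'agent_0': tf.TensorSpec(shape=(4,), dtype=tf.float32)},
    'actions': {'agent_0': tf.TensorSpec(shape=(), dtype=tf.int32)},
    'rewards': {'agent_0': tf.TensorSpec(shape=(), dtype=tf.float32)},
}

# Reverb checks the shapes/dtypes of inserted items against the signature.
table = reverb.Table(
    name='multiagent_table',
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    max_size=10_000,
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=table_signature,
)
server = reverb.Server(tables=[table])
server.stop()
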
Example #3
from absl import logging

import tensorflow as tf


def remove_singleton_batch_spec_dim(spec: tf.TypeSpec,
                                    outer_ndim: int) -> tf.TypeSpec:
    """Look for `spec`'s shape, check that outer dim is 1, and remove it.

  If `spec.shape[i] != 1` for any `i in range(outer_ndim)`, we stop removing
  singleton batch dimensions at `i` and return what's left.  This is necessary
  to handle the outputs of inconsistent layers like `tf.keras.layers.LSTM()`
  which may take as input `(batch, time, dim) = (1, 1, Nin)` and emits only the
  batch entry if `time == 1`: output shape is `(1, Nout)`.  We log an error
  in these cases.

  Args:
    spec: A `tf.TypeSpec`.
    outer_ndim: The maximum number of outer singleton dims to remove.

  Returns:
    A `tf.TypeSpec`, the spec without its outer batch dimension(s).

  Raises:
    ValueError: If `spec` lacks a `shape` property.
  """
    shape = getattr(spec, 'shape', None)
    if shape is None:
        shape = getattr(spec, '_shape', None)
    if shape is None:
        raise ValueError(
            'Could not remove singleton batch dim from spec; it lacks a shape: {}'
            .format(spec))
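    # Note: the checks below index into the *original* shape captured above,
    # even though `spec` itself is re-assigned by `_unbatch()` on each pass.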
    for i in range(outer_ndim):
        if len(shape) <= i:
            logging.error(
                'Could not remove singleton batch dim from spec; len(shape) < %d.  '
                'Shape: %s.  Skipping.', i + 1, shape)
            break
        if tf.compat.dimension_value(shape[i]) != 1:
            logging.error(
                'Could not remove singleton batch dim from spec; shape[%d] != 1: %s '
                '(shape: %s).  Skipping.', i, spec, shape)
            break
        spec = spec._unbatch()  # pylint: disable=protected-access
    return spec
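
A quick sanity check of the behaviour the docstring describes, assuming the
function above is available as written (the shapes are illustrative):

import tensorflow as tf

# Both leading singleton dims are removed.
spec = tf.TensorSpec(shape=(1, 1, 8), dtype=tf.float32)
print(remove_singleton_batch_spec_dim(spec, outer_ndim=2))  # shape (8,)

# Removal stops at the first non-singleton dim: the leading 1 is stripped,
# shape[1] == 4 triggers a logged error, and what's left is returned.
spec = tf.TensorSpec(shape=(1, 4, 8), dtype=tf.float32)
print(remove_singleton_batch_spec_dim(spec, outer_ndim=2))  # shape (4, 8)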