Example #1
  def _check_network_output(self, net, label):
    """Check outputs of q_net and target_q_net against expected shape.

    Subclasses that require different q_network outputs should override
    this function.

    Args:
      net: A `Network`.
      label: A label to print in case of a mismatch.
    """
    network_utils.check_single_floating_network_output(
        net.create_variables(),
        expected_output_shape=(self._num_actions,),
        label=label)
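The helper above delegates the actual validation to `network_utils.check_single_floating_network_output`, which receives the output spec returned by `net.create_variables()`. As a rough, hedged sketch of the kind of check being performed (illustrative only, not the library's implementation):

```
# Illustrative sketch only; the real validation lives in the library's
# network_utils module (check_single_floating_network_output).
def sketch_check_single_floating_output(output_spec, expected_output_shape, label):
  # Expect a single floating-point TensorSpec with the given (unbatched) shape.
  if not output_spec.dtype.is_floating:
    raise ValueError('Expected a floating point output for %s; got dtype %s.'
                     % (label, output_spec.dtype))
  if tuple(output_spec.shape.as_list()) != tuple(expected_output_shape):
    raise ValueError('Expected output shape %s for %s; got shape %s.'
                     % (expected_output_shape, label, output_spec.shape))
```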
Example #2
  def __init__(
      self,
      time_step_spec: ts.TimeStep,
      action_spec: types.NestedTensorSpec,
      q_network: network.Network,
      emit_log_probability: bool = False,
      observation_and_action_constraint_splitter: Optional[
          types.Splitter] = None,
      validate_action_spec_and_network: bool = True,
      name: Optional[Text] = None):
    """Builds a Q-Policy given a q_network.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: An instance of a `tf_agents.network.Network`,
        callable via `network(observation, step_type) -> (output, final_state)`.
      emit_log_probability: Whether to emit log-probs in info of `PolicyStep`.
      observation_and_action_constraint_splitter: A function used to process
        observations with action constraints. These constraints can indicate,
        for example, a mask of valid/invalid actions for a given state of the
        environment.
        The function takes in a full observation and returns a tuple consisting
        of 1) the part of the observation intended as input to the network and
        2) the constraint. An example
        `observation_and_action_constraint_splitter` could be as simple as:
        ```
        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']
        ```
        *Note*: when using `observation_and_action_constraint_splitter`, make
        sure the provided `q_network` is compatible with the network-specific
        half of the output of the `observation_and_action_constraint_splitter`.
        In particular, `observation_and_action_constraint_splitter` will be
        called on the observation before passing to the network.
        If `observation_and_action_constraint_splitter` is None, action
        constraints are not applied.
      validate_action_spec_and_network: If `True` (default),
        action_spec is checked to make sure it is a single scalar spec
        with a minimum of zero.  Also validates that the network's output
        matches the spec.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      ValueError: If `q_network.action_spec` exists and is not compatible with
        `action_spec`.
      NotImplementedError: If `action_spec` contains more than one
        `BoundedTensorSpec`.
    """
    action_spec = tensor_spec.from_spec(action_spec)
    time_step_spec = tensor_spec.from_spec(time_step_spec)

    network_action_spec = getattr(q_network, 'action_spec', None)

    if network_action_spec is not None:
      action_spec = cast(tf.TypeSpec, action_spec)
      if not action_spec.is_compatible_with(network_action_spec):
        raise ValueError(
            'action_spec must be compatible with q_network.action_spec; '
            'instead got action_spec=%s, q_network.action_spec=%s' % (
                action_spec, network_action_spec))

    flat_action_spec = tf.nest.flatten(action_spec)
    if len(flat_action_spec) > 1:
      raise ValueError(
          'Only scalar actions are supported now, but action spec is: {}'
          .format(action_spec))
    if validate_action_spec_and_network:
      spec = flat_action_spec[0]
      if spec.shape.rank > 0:
        raise ValueError(
            'Only scalar actions are supported now, but action spec is: {}'
            .format(action_spec))

      if spec.minimum != 0:
        raise ValueError(
            'Action specs should have minimum of 0, but saw: {0}'.format(spec))

      num_actions = spec.maximum - spec.minimum + 1
      network_utils.check_single_floating_network_output(
          q_network.create_variables(), (num_actions,), str(q_network))

    # We need to maintain the flat action spec for dtype, shape and range.
    self._flat_action_spec = flat_action_spec[0]

    self._q_network = q_network
    super(QPolicy, self).__init__(
        time_step_spec,
        action_spec,
        policy_state_spec=q_network.state_spec,
        clip=False,
        emit_log_probability=emit_log_probability,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter),
        name=name)
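For context, a typical construction of this policy pairs it with a `QNetwork` built from the environment specs. A minimal usage sketch (the observation and action specs below are hypothetical):

```
import tensorflow as tf

from tf_agents.networks import q_network
from tf_agents.policies import q_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# Hypothetical specs: a 4-dimensional float observation and 3 discrete actions.
observation_spec = tf.TensorSpec([4], tf.float32)
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, minimum=0, maximum=2)

my_q_network = q_network.QNetwork(observation_spec, action_spec)
policy = q_policy.QPolicy(time_step_spec, action_spec, q_network=my_q_network)
```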
Example #3
  def __init__(self,
               time_step_spec: ts.TimeStep,
               action_spec: types.NestedTensorSpec,
               q_network: network.Network,
               min_q_value: float,
               max_q_value: float,
               observation_and_action_constraint_splitter: Optional[
                   types.Splitter] = None,
               temperature: types.Float = 1.0):
    """Builds a categorical Q-policy given a categorical Q-network.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A `BoundedTensorSpec` representing the actions.
      q_network: A network.Network to use for our policy.
      min_q_value: A float specifying the minimum Q-value, used for setting up
        the support.
      max_q_value: A float specifying the maximum Q-value, used for setting up
        the support.
      observation_and_action_constraint_splitter: A function used to process
        observations with action constraints. These constraints can indicate,
        for example, a mask of valid/invalid actions for a given state of the
        environment.
        The function takes in a full observation and returns a tuple consisting
        of 1) the part of the observation intended as input to the network and
        2) the constraint. An example
        `observation_and_action_constraint_splitter` could be as simple as:
        ```
        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']
        ```
        *Note*: when using `observation_and_action_constraint_splitter`, make
        sure the provided `q_network` is compatible with the network-specific
        half of the output of the `observation_and_action_constraint_splitter`.
        In particular, `observation_and_action_constraint_splitter` will be
        called on the observation before passing to the network.
        If `observation_and_action_constraint_splitter` is None, action
        constraints are not applied.
      temperature: Temperature for sampling; as it approaches 0.0, action
        selection approaches arg_max.

    Raises:
      ValueError: if `q_network` does not have property `num_atoms`.
      TypeError: if `action_spec` is not a `BoundedTensorSpec`.
    """
    network_action_spec = getattr(q_network, 'action_spec', None)

    if network_action_spec is not None:
      action_spec = cast(tf.TypeSpec, action_spec)
      if not action_spec.is_compatible_with(network_action_spec):
        raise ValueError(
            'action_spec must be compatible with q_network.action_spec; '
            'instead got action_spec=%s, q_network.action_spec=%s' %
            (action_spec, network_action_spec))

    if not isinstance(action_spec, tensor_spec.BoundedTensorSpec):
      raise TypeError(
          'action_spec must be a BoundedTensorSpec. Got: %s' % (action_spec,))

    action_spec = cast(tensor_spec.BoundedTensorSpec, action_spec)
    if action_spec.minimum != 0:
      raise ValueError(
          'Action specs should have minimum of 0, but saw: {0}.  If collecting '
          'from a python environment, consider using '
          'tf_agents.environments.wrappers.ActionOffsetWrapper.'.format(
              action_spec))

    num_actions = action_spec.maximum - action_spec.minimum + 1
    try:
      num_atoms = q_network.num_atoms
    except AttributeError:
      raise ValueError(
          'Expected q_network to have property `num_atoms`, but '
          'it doesn\'t. (Note: you likely want to use a '
          'CategoricalQNetwork.) Network is: %s' % q_network)
    self._num_atoms = num_atoms

    network_utils.check_single_floating_network_output(
        q_network.create_variables(), (num_actions, num_atoms), str(q_network))

    super(CategoricalQPolicy, self).__init__(
        time_step_spec,
        action_spec,
        policy_state_spec=q_network.state_spec,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter))

    self._temperature = tf.convert_to_tensor(temperature, dtype=tf.float32)
    self._q_network = q_network

    # Generate support in numpy so that we can assign it to a constant and
    # avoid having a tensor property.
    support = np.linspace(
        min_q_value, max_q_value, self._num_atoms, dtype=np.float32)
    self._support = tf.constant(support, dtype=tf.float32)
    self._action_dtype = action_spec.dtype
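The `_support` assembled above is just an evenly spaced grid of `num_atoms` return values between `min_q_value` and `max_q_value`. A small numeric sketch with hypothetical values (51 atoms over [-10, 10], a common C51 setting) makes the spacing concrete:

```
import numpy as np

min_q_value, max_q_value, num_atoms = -10.0, 10.0, 51  # hypothetical values
support = np.linspace(min_q_value, max_q_value, num_atoms, dtype=np.float32)
# support[0] == -10.0, support[-1] == 10.0, and neighbouring atoms are
# (max_q_value - min_q_value) / (num_atoms - 1) == 0.4 apart.
```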
Example #4
  def _check_network_output(self, net, label):
    network_utils.check_single_floating_network_output(
        net.create_variables(),
        expected_output_shape=(self._num_actions, self._num_atoms),
        label=label)
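For intuition on the expected `(num_actions, num_atoms)` output shape: a categorical (C51-style) Q-network emits a distribution over `num_atoms` support values per action, and scalar Q-values are recovered as the expectation over that support. A hedged sketch of that reduction (not the library's exact code; names are illustrative):

```
import tensorflow as tf

def expected_q_values(logits, support):
  # logits: [batch_size, num_actions, num_atoms]; support: [num_atoms].
  probabilities = tf.nn.softmax(logits, axis=-1)
  # Expectation over the support yields one scalar Q-value per action:
  # shape [batch_size, num_actions].
  return tf.reduce_sum(probabilities * support, axis=-1)
```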
Example #5
  def __init__(self,
               time_step_spec: ts.TimeStep,
               action_spec: types.NestedTensorSpec,
               q_network: network.Network,
               sampler: cem_actions_sampler.ActionsSampler,
               init_mean: types.NestedArray,
               init_var: types.NestedArray,
               actor_policy: Optional[tf_policy.TFPolicy] = None,
               minimal_var: float = 0.0001,
               info_spec: types.NestedSpecTensorOrArray = (),
               num_samples: int = 32,
               num_elites: int = 4,
               num_iterations: int = 32,
               emit_log_probability: bool = False,
               preprocess_state_action: bool = True,
               training: bool = False,
               weights: types.NestedTensorOrArray = None,
               name: Optional[str] = None):
    """Builds a CEM-Policy given a network and a sampler.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: An instance of a `tf_agents.network.Network`, callable via
        `network(observation, step_type) -> (output, final_state)`.
      sampler: Samples the actions needed for the CEM.
      init_mean: A list, tuple, or scalar representing the initial mean for
        actions.
      init_var: A list, tuple, or scalar representing the initial variance for
        actions.
      actor_policy: Optional actor policy.
      minimal_var: Minimal variance to prevent the CEM distribution from
        collapsing.
      info_spec: A policy info spec.
      num_samples: Number of samples to sample each round.
      num_elites: Number of best actions each round to refit the distribution
        with.
      num_iterations: Number of iterations to run the CEM loop.
      emit_log_probability: Whether to emit log-probs in info of `PolicyStep`.
      preprocess_state_action: The shape of the state is (B, ...) and the shape
        of the action is (B, N, A). When preprocess_state_action is enabled,
        the state will be tile-batched to (BxN, ...) and the action will be
        reshaped to (BxN, A). When preprocess_state_action is disabled, the
        same operation needs to be done inside the network. This is helpful
        when the input has large memory requirements and the replication of
        the state can happen after a few layers inside the network. (See the
        shape sketch after this example.)
      training: Whether it is in training mode or inference mode.
      weights: A nested structure of weights w/ the same structure as action.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      ValueError: If `q_network.action_spec` exists and is not compatible with
        `action_spec`.
    """
    network_action_spec = getattr(q_network, 'action_spec', None)

    if network_action_spec is not None:
      if not action_spec.is_compatible_with(network_action_spec):
        raise ValueError(
            'action_spec must be compatible with q_network.action_spec; '
            'instead got action_spec=%s, q_network.action_spec=%s' %
            (action_spec, network_action_spec))

    if q_network:
      network_utils.check_single_floating_network_output(
          q_network.create_variables(),
          expected_output_shape=(),
          label=str(q_network))
      policy_state_spec = q_network.state_spec
    else:
      policy_state_spec = ()

    self._actor_policy = actor_policy
    self._q_network = q_network
    self._init_mean = init_mean
    self._init_var = init_var
    self._minimal_var = minimal_var
    self._num_samples = num_samples  # N
    self._num_elites = num_elites  # M
    self._num_iterations = num_iterations
    self._actions_sampler = sampler
    self._observation_spec = time_step_spec.observation
    self._training = training
    self._preprocess_state_action = preprocess_state_action
    self._weights = weights

    super(CEMPolicy, self).__init__(
        time_step_spec,
        action_spec,
        info_spec=info_spec,
        policy_state_spec=policy_state_spec,
        clip=False,
        emit_log_probability=emit_log_probability,
        name=name)
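The `preprocess_state_action` behaviour described in the docstring above amounts to a shape transformation so the network can score every sampled action against its own state. A minimal sketch with plain TensorFlow ops (illustrative only; the policy itself may rely on TF-Agents utilities for this):

```
import tensorflow as tf

def tile_state_and_flatten_actions(state, actions):
  # state: [B, ...] observation batch; actions: [B, N, A] candidate actions.
  # Returns state of shape [B*N, ...] and actions of shape [B*N, A], with the
  # i-th state replicated next to its own N candidate actions.
  batch_size = tf.shape(actions)[0]
  num_samples = tf.shape(actions)[1]
  tiled_state = tf.repeat(state, repeats=num_samples, axis=0)
  flat_actions = tf.reshape(actions, [batch_size * num_samples, -1])
  return tiled_state, flat_actions
```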