Example #1
    def __init__(self, input_tensor_spec=None, state_spec=(), name=None):
        """Creates an instance of `Network`.

    Args:
      input_tensor_spec: A nest of `tf.TypeSpec` representing the
        input observations.  Optional.  If not provided, `create_variables()`
        will fail unless a spec is provided.
      state_spec: A nest of `tensor_spec.TensorSpec` representing the state
        needed by the network. Default is `()`, which means no state.
      name: (Optional.) A string representing the name of the network.
    """
        # Disable autocast because it may convert bfloats to other types, breaking
        # our spec checks.
        super(Network, self).__init__(name=name, autocast=False)
        common.check_tf1_allowed()

        # Required for summary() to work.
        self._is_graph_network = False

        self._input_tensor_spec = (tensor_spec.from_spec(input_tensor_spec)
                                   if input_tensor_spec is not None else None)
        # NOTE(ebrevdo): Would have preferred to call this output_tensor_spec, but
        # looks like keras.Layer already reserves that one.
        self._network_output_spec = None
        self._state_spec = tensor_spec.from_spec(state_spec)
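Below is a minimal sketch (not taken from the examples above) of a `Network` subclass that supplies an `input_tensor_spec` and an empty `state_spec`; the class name, spec shape, and layer size are assumptions.

```python
import tensorflow as tf
from tf_agents.networks import network


class MyValueNetwork(network.Network):
  """A tiny stateless network with a single dense head (illustrative only)."""

  def __init__(self, input_tensor_spec, name='MyValueNetwork'):
    super(MyValueNetwork, self).__init__(
        input_tensor_spec=input_tensor_spec, state_spec=(), name=name)
    self._dense = tf.keras.layers.Dense(1)

  def call(self, observations, step_type=None, network_state=(), training=False):
    # Networks return (output, network_state); the state is empty here.
    value = self._dense(observations)
    return tf.squeeze(value, -1), network_state


obs_spec = tf.TensorSpec([4], tf.float32)
net = MyValueNetwork(obs_spec)
net.create_variables()  # works because input_tensor_spec was provided
```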
Example #2
  def __init__(
      self,
      env: tf_environment.TFEnvironment,
      policy: tf_policy.TFPolicy,
      observers: Sequence[Callable[[trajectory.Trajectory], Any]],
      transition_observers: Optional[Sequence[Callable[[trajectory.Transition],
                                                       Any]]] = None,
      max_steps: Optional[types.Int] = None,
      max_episodes: Optional[types.Int] = None,
      disable_tf_function: bool = False):
    """A driver that runs a TF policy in a TF environment.

    **Note** about bias when using batched environments with `max_episodes`:
    When using `max_episodes != None`, a `run` step "finishes" when
    `max_episodes` have been completely collected (hit a boundary).
    When used in conjunction with environments that have variable-length
    episodes, this skews the distribution of collected episodes' lengths:
    short episodes are seen more frequently than long ones.
    As a result, running an `env` of `N > 1` batched environments
    with `max_episodes >= 1` is not the same as running a single
    environment with `max_episodes >= 1`.

    Args:
      env: A tf_environment.TFEnvironment environment.
      policy: A tf_policy.TFPolicy policy.
      observers: A list of observers that are notified after every step
        in the environment. Each observer is a callable(trajectory.Trajectory).
      transition_observers: A list of observers that are updated after every
        step in the environment. Each observer is a callable((TimeStep,
        PolicyStep, NextTimeStep)). The transition is shaped just as
        trajectories are for regular observers.
      max_steps: Optional maximum number of steps for each run() call. For
        batched or parallel environments, this is the maximum total number of
        steps summed across all environments. Also see below.  Default: 0.
      max_episodes: Optional maximum number of episodes for each run() call. For
        batched or parallel environments, this is the maximum total number of
        episodes summed across all environments. At least one of max_steps or
        max_episodes must be provided. If both are set, run() terminates when at
        least one of the conditions is
        satisfied.  Default: 0.
      disable_tf_function: If True, the use of tf.function for the run method is
        disabled.

    Raises:
      ValueError: If both max_steps and max_episodes are None.
    """
    common.check_tf1_allowed()
    max_steps = max_steps or 0
    max_episodes = max_episodes or 0
    if max_steps < 1 and max_episodes < 1:
      raise ValueError(
          'Either `max_steps` or `max_episodes` should be greater than 0.')

    super(TFDriver, self).__init__(env, policy, observers, transition_observers)

    self._max_steps = max_steps or np.inf
    self._max_episodes = max_episodes or np.inf

    if not disable_tf_function:
      self.run = common.function(self.run, autograph=True)
Example #3
  def __init__(self,
               env,
               agent,
               adversary_agent,
               adversary_env,
               env_metrics=None,
               collect=True,
               disable_tf_function=False,
               debug=False,
               combined_population=False,
               flexible_protagonist=False):
    """Runs the environment adversary and agents to collect episodes.

    Args:
      env: A tf_environment.Base environment.
      agent: An AgentTrainPackage for the main learner agent.
      adversary_agent: An AgentTrainPackage for the second agent, the
        adversary's ally. This can be None if using an unconstrained adversary
        environment.
      adversary_env: An AgentTrainPackage for the agent that controls the
        environment, learning to set parameters of the environment to decrease
        the agent's score relative to the adversary_agent. Can be None if using
        domain randomization.
      env_metrics: Global environment metrics to track (such as path length).
      collect: True if collecting episodes for training, otherwise eval.
      disable_tf_function: If True, the use of tf.function for the run method is
        disabled.
      debug: If True, outputs informative logging statements.
      combined_population: If True, the entire population of protagonists plays
        each generated environment, and regret is calculated as the difference
        between the max of the population and the average (there are no explicit
        antagonists).
      flexible_protagonist: If True, which agent plays the role of protagonist
        when calculating the regret depends on which has the lowest score.
    """
    common.check_tf1_allowed()
    self.debug = debug
    self.total_episodes_collected = 0

    if not disable_tf_function:
      self.run = common.function(self.run, autograph=True)
      self.run_agent = common.function(self.run_agent, autograph=True)

    self.env_metrics = env_metrics
    self.collect = collect
    self.env = env
    self.agent = agent
    self.adversary_agent = adversary_agent
    self.adversary_env = adversary_env
    self.combined_population = combined_population
    self.flexible_protagonist = flexible_protagonist
Example #4
    def __init__(self,
                 env: tf_environment.TFEnvironment,
                 policy: tf_policy.TFPolicy,
                 observers: Sequence[Callable[[trajectory.Trajectory], Any]],
                 transition_observers: Optional[Sequence[Callable[
                     [trajectory.Transition], Any]]] = None,
                 max_steps: Optional[types.Int] = None,
                 max_episodes: Optional[types.Int] = None,
                 disable_tf_function: bool = False):
        """A driver that runs a TF policy in a TF environment.

    Args:
      env: A tf_environment.TFEnvironment environment.
      policy: A tf_policy.TFPolicy policy.
      observers: A list of observers that are notified after every step
        in the environment. Each observer is a callable(trajectory.Trajectory).
      transition_observers: A list of observers that are updated after every
        step in the environment. Each observer is a callable((TimeStep,
        PolicyStep, NextTimeStep)). The transition is shaped just as
        trajectories are for regular observers.
      max_steps: Optional maximum number of steps for each run() call. For
        batched or parallel environments, this is the maximum total number of
        steps summed across all environments. Also see below.  Default: 0.
      max_episodes: Optional maximum number of episodes for each run() call. For
        batched or parallel environments, this is the maximum total number of
        episodes summed across all environments. At least one of max_steps or
        max_episodes must be provided. If both are set, run() terminates when at
        least one of the conditions is
        satisfied.  Default: 0.
      disable_tf_function: If True, the use of tf.function for the run method is
        disabled.

    Raises:
      ValueError: If both max_steps and max_episodes are None.
    """
        common.check_tf1_allowed()
        max_steps = max_steps or 0
        max_episodes = max_episodes or 0
        if max_steps < 1 and max_episodes < 1:
            raise ValueError(
                'Either `max_steps` or `max_episodes` should be greater than 0.'
            )

        super(TFDriver, self).__init__(env, policy, observers,
                                       transition_observers)

        self._max_steps = max_steps or np.inf
        self._max_episodes = max_episodes or np.inf

        if not disable_tf_function:
            self.run = common.function(self.run, autograph=True)
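A minimal usage sketch for `TFDriver` (not taken from the examples above); the environment name, the list-appending observer, and `max_steps=100` are assumptions.

```python
from tf_agents.drivers import tf_driver
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.policies import random_tf_policy

env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(), env.action_spec())

collected = []  # each observer is called with a trajectory.Trajectory per step
driver = tf_driver.TFDriver(
    env, policy, observers=[collected.append], max_steps=100)

time_step = env.reset()
policy_state = policy.get_initial_state(env.batch_size)
final_time_step, final_policy_state = driver.run(time_step, policy_state)
```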
Example #5
    def __init__(self, data_spec, capacity, stateful_dataset=False):
        """Initializes the replay buffer.

    Args:
      data_spec: A spec or a list/tuple/nest of specs describing a single item
        that can be stored in this buffer.
      capacity: Number of elements that the replay buffer can hold.
      stateful_dataset: Whether the dataset contains stateful ops or not.
    """
        super(ReplayBuffer, self).__init__()
        common.check_tf1_allowed()
        self._data_spec = data_spec
        self._capacity = capacity
        self._stateful_dataset = stateful_dataset
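`ReplayBuffer` above is the abstract base class. A minimal sketch (not taken from the examples above) using the concrete `TFUniformReplayBuffer`; the data spec, batch size, and capacity below are assumptions.

```python
import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer

data_spec = (tf.TensorSpec([3], tf.float32, 'observation'),
             tf.TensorSpec([], tf.int32, 'action'))

buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=data_spec,
    batch_size=2,      # outer batch dimension of every add_batch() call
    max_length=1000)   # per-batch-segment capacity

buffer.add_batch((tf.zeros([2, 3]), tf.zeros([2], tf.int32)))
dataset = buffer.as_dataset(sample_batch_size=1, num_steps=2)
```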
Example #6
  def __init__(self, input_tensor_spec, state_spec, name):
    """Creates an instance of `Network`.

    Args:
      input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the
        input observations.
      state_spec: A nest of `tensor_spec.TensorSpec` representing the state
        needed by the network. Use () if none.
      name: A string representing the name of the network.
    """
    super(Network, self).__init__(name=name)
    common.check_tf1_allowed()
    self._input_tensor_spec = input_tensor_spec
    self._state_spec = state_spec
Example #7
    def __init__(self, input_tensor_spec, state_spec, name=None):
        """Creates an instance of `Network`.

    Args:
      input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the
        input observations.
      state_spec: A nest of `tensor_spec.TensorSpec` representing the state
        needed by the network. Use () if none.
      name: (Optional.) A string representing the name of the network.
    """
        super(Network, self).__init__(name=name)
        common.check_tf1_allowed()

        # Required for summary() to work.
        self._is_graph_network = False

        self._input_tensor_spec = input_tensor_spec
        self._state_spec = state_spec
Example #8
    def __init__(self, input_tensor_spec=None, state_spec=(), name=None):
        """Creates an instance of `Network`.

    Args:
      input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the
        input observations.  Optional.  If not provided, `create_variables()`
        will fail unless a spec is provided.
      state_spec: A nest of `tensor_spec.TensorSpec` representing the state
        needed by the network. Default is `()`, which means no state.
      name: (Optional.) A string representing the name of the network.
    """
        super(Network, self).__init__(name=name)
        common.check_tf1_allowed()

        # Required for summary() to work.
        self._is_graph_network = False

        self._input_tensor_spec = input_tensor_spec
        self._state_spec = state_spec
Example #9
    def __init__(self, input_tensor_spec=None, state_spec=(), name=None):
        """Creates an instance of `Network`.

    Args:
      input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the
        input observations.  Optional.  If not provided, `create_variables()`
        will fail unless a spec is provided.
      state_spec: A nest of `tensor_spec.TensorSpec` representing the state
        needed by the network. Default is `()`, which means no state.
      name: (Optional.) A string representing the name of the network.
    """
        super(Network, self).__init__(name=name)
        common.check_tf1_allowed()

        # Required for summary() to work.
        self._is_graph_network = False

        self._input_tensor_spec = input_tensor_spec
        # NOTE(ebrevdo): Would have preferred to call this output_tensor_spec, but
        # looks like keras.Layer already reserves that one.
        self._network_output_spec = None
        self._state_spec = state_spec
Example #10
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 policy_state_spec=(),
                 info_spec=(),
                 clip=True,
                 emit_log_probability=False,
                 automatic_state_reset=True,
                 observation_and_action_constraint_splitter=None,
                 validate_args=True,
                 name=None):
        """Initialization of TFPolicy class.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps. Usually
        provided by the user to the subclass.
      action_spec: A nest of BoundedTensorSpec representing the actions. Usually
        provided by the user to the subclass.
      policy_state_spec: A nest of TensorSpec representing the policy_state.
        Provided by the subclass, not directly by the user.
      info_spec: A nest of TensorSpec representing the policy info. Provided by
        the subclass, not directly by the user.
      clip: Whether to clip actions to spec before returning them.  Default
        True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped
        continuous actions for training.
      emit_log_probability: Emit log-probabilities of actions, if supported. If
        True, policy_step.info will have CommonFields.LOG_PROBABILITY set.
        Please consult utility methods provided in policy_step for setting and
        retrieving these. When working with custom policies, either provide a
        dictionary info_spec or a namedtuple with the field 'log_probability'.
      automatic_state_reset:  If `True`, then `get_initial_policy_state` is used
        to clear state in `action()` and `distribution()` for time steps
        where `time_step.is_first()`.
      observation_and_action_constraint_splitter: A function used to process
        observations with action constraints. These constraints can indicate,
        for example, a mask of valid/invalid actions for a given state of the
        environment. The function takes in a full observation and returns a
        tuple consisting of 1) the part of the observation intended as input to
        the network and 2) the constraint. An example
        `observation_and_action_constraint_splitter` could be as simple as:
        ```
        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']
        ```
        *Note*: when using `observation_and_action_constraint_splitter`, make
          sure the provided `q_network` is compatible with the network-specific
          half of the output of the
          `observation_and_action_constraint_splitter`. In particular,
          `observation_and_action_constraint_splitter` will be called on the
          observation before passing to the network. If
          `observation_and_action_constraint_splitter` is None, action
          constraints are not applied.
      validate_args: Python bool.  Whether to verify inputs to, and outputs of,
        functions like `action` and `distribution` against spec structures,
        dtypes, and shapes.

        Research code may prefer to set this value to `False` to allow iterating
        on input and output structures without being hamstrung by overly
        rigid checking (at the cost of harder-to-debug errors).

        See also `TFAgent.validate_args`.
      name: A name for this module. Defaults to the class name.
    """
        super(TFPolicy, self).__init__(name=name)
        common.check_tf1_allowed()
        common.tf_agents_gauge.get_cell('TFAPolicy').set(True)
        common.assert_members_are_not_overridden(base_cls=TFPolicy,
                                                 instance=self)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise ValueError(
                'The `time_step_spec` must be an instance of `TimeStep`, but is `{}`.'
                .format(type(time_step_spec)))

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy_state_spec = policy_state_spec
        self._emit_log_probability = emit_log_probability
        self._validate_args = validate_args

        if emit_log_probability:
            log_probability_spec = tensor_spec.BoundedTensorSpec(
                shape=(),
                dtype=tf.float32,
                maximum=0,
                minimum=-float('inf'),
                name='log_probability')
            log_probability_spec = tf.nest.map_structure(
                lambda _: log_probability_spec, action_spec)
            info_spec = policy_step.set_log_probability(
                info_spec, log_probability_spec)

        self._info_spec = info_spec
        self._setup_specs()
        self._clip = clip
        self._action_fn = common.function_in_tf1()(self._action)
        self._automatic_state_reset = automatic_state_reset
        self._observation_and_action_constraint_splitter = (
            observation_and_action_constraint_splitter)
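A minimal sketch (not taken from the examples above) of exercising a concrete `TFPolicy`, including reading the log-probability set by `emit_log_probability`; the environment choice is an assumption.

```python
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.policies import random_tf_policy
from tf_agents.trajectories import policy_step

env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
policy = random_tf_policy.RandomTFPolicy(
    env.time_step_spec(), env.action_spec(), emit_log_probability=True)

time_step = env.reset()
state = policy.get_initial_state(env.batch_size)
action_step = policy.action(time_step, state)

# With emit_log_probability=True, the log-probability lives in the info field.
log_prob = policy_step.get_log_probability(action_step.info)
```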
Example #11
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 policy,
                 collect_policy,
                 train_sequence_length,
                 num_outer_dims=2,
                 training_data_spec=None,
                 train_argspec=None,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 enable_summaries=True,
                 train_step_counter=None):
        """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A nest of tf.TypeSpec representing the time_steps.
        Provided by the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.Base` representing the Agent's current
        policy.
      collect_policy: An instance of `tf_policy.Base` representing the Agent's
        current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      training_data_spec: A nest of TensorSpec specifying the structure of data
        the train() function expects. If None, defaults to the trajectory_spec
        of the collect_policy.
      train_argspec: (Optional) Describes additional supported arguments
        to the `train` call.  This must be a `dict` mapping strings to nests
        of specs.  Overriding the `experience` arg is not supported.

        Some algorithms require additional arguments to the `train()` call, and
        while TF-Agents encourages most of these to be provided in the
        `policy_info` / `info` field of `experience`, sometimes the extra
        information doesn't fit well, i.e., when it doesn't come from the
        policy.

        **NOTE** kwargs will not have their outer dimensions validated.
        In particular, `train_sequence_length` is ignored for these inputs,
        and they may have any, or inconsistent, batch/time dimensions; only
        their inner shape dimensions are checked against `train_argspec`.

        Below is an example:

        ```python
        class MyAgent(TFAgent):
          def __init__(self, counterfactual_training, ...):
             collect_policy = ...
             train_argspec = None
             if counterfactual_training:
               train_argspec = dict(
                  counterfactual=collect_policy.trajectory_spec)
             super(...).__init__(
               ...
               train_argspec=train_argspec)

        my_agent = MyAgent(...)

        for ...:
          experience, counterfactual = next(experience_and_counterfactual_iter)
          loss_info = my_agent.train(experience, counterfactual=counterfactual)
        ```
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.

    Raises:
      TypeError: If `train_argspec` is not a `dict`.
      ValueError: If `train_argspec` has the keys `experience` or `weights`.
      TypeError: If any leaf nodes in `train_argspec` values are not
        subclasses of `tf.TypeSpec`.
      ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
      ValueError: If `num_outer_dims` is not in [1, 2].
    """
        def _each_isinstance(spec, spec_types):
            """Checks if each element of `spec` is instance of any of `spec_types`."""
            return all(
                [isinstance(s, spec_types) for s in tf.nest.flatten(spec)])

        if not _each_isinstance(time_step_spec, tf.TypeSpec):
            raise TypeError(
                "time_step_spec has to contain TypeSpec (TensorSpec, "
                "SparseTensorSpec, etc) objects, but received: {}".format(
                    time_step_spec))

        if not _each_isinstance(action_spec, tensor_spec.BoundedTensorSpec):
            raise TypeError(
                "action_spec has to contain BoundedTensorSpec objects, but received: "
                "{}".format(action_spec))

        common.check_tf1_allowed()
        common.tf_agents_gauge.get_cell("TFAgent").set(True)
        common.assert_members_are_not_overridden(base_cls=TFAgent,
                                                 instance=self)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise ValueError(
                "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
                .format(type(time_step_spec)))

        if num_outer_dims not in [1, 2]:
            raise ValueError("num_outer_dims must be in [1, 2].")

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy = policy
        self._collect_policy = collect_policy
        self._train_sequence_length = train_sequence_length
        self._num_outer_dims = num_outer_dims
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._enable_summaries = enable_summaries
        self._training_data_spec = training_data_spec
        if train_argspec is None:
            train_argspec = {}
        else:
            if not isinstance(train_argspec, dict):
                raise TypeError(
                    "train_argspec must be a dict, but saw: {}".format(
                        train_argspec))
            train_argspec = dict(train_argspec)  # Create a local copy.
            if "weights" in train_argspec or "experience" in train_argspec:
                raise ValueError(
                    "train_argspec must not override 'weights' or "
                    "'experience' keys, but saw: {}".format(train_argspec))
            if not all(
                    isinstance(x, tf.TypeSpec)
                    for x in tf.nest.flatten(train_argspec)):
                raise TypeError(
                    "train_argspec contains non-TensorSpec objects: {}".format(
                        train_argspec))
        self._train_argspec = train_argspec
        if train_step_counter is None:
            train_step_counter = tf.compat.v1.train.get_or_create_global_step()
        self._train_step_counter = train_step_counter
        self._train_fn = common.function_in_tf1()(self._train)
        self._initialize_fn = common.function_in_tf1()(self._initialize)
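A minimal sketch (not taken from the examples above) of a concrete agent built on this constructor; `DqnAgent`, the Q-network, and the optimizer choice are assumptions.

```python
import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import q_network

env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
agent = dqn_agent.DqnAgent(
    env.time_step_spec(),
    env.action_spec(),
    q_network=q_network.QNetwork(env.observation_spec(), env.action_spec()),
    optimizer=tf.keras.optimizers.Adam(1e-3),
    train_step_counter=tf.Variable(0, dtype=tf.int64))
agent.initialize()

# A non-RNN DQN agent sets train_sequence_length=2, so `experience`
# (e.g. batches from a replay buffer dataset) must be shaped [B, 2, ...]:
# loss_info = agent.train(experience)
```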
Example #12
  def __init__(self, name, prefix='Metrics'):
    super(TFStepMetric, self).__init__(name)
    common.check_tf1_allowed()
    self._prefix = prefix
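`TFStepMetric` is the base class for the concrete metrics in `tf_agents.metrics.tf_metrics`. A minimal sketch (not taken from the examples above); the choice of metric is an assumption.

```python
from tf_agents.metrics import tf_metrics

avg_return = tf_metrics.AverageReturnMetric()  # inherits the default prefix='Metrics'
# Typically passed to a driver as an observer, e.g. observers=[avg_return];
# read the value with avg_return.result() and clear it with avg_return.reset().
```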
Example #13
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 policy: tf_policy.TFPolicy,
                 collect_policy: tf_policy.TFPolicy,
                 train_sequence_length: Optional[int],
                 num_outer_dims: int = 2,
                 training_data_spec: Optional[types.NestedTensorSpec] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 enable_summaries: bool = True,
                 train_step_counter: Optional[tf.Variable] = None):
        """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A nest of tf.TypeSpec representing the time_steps.
        Provided by the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current policy.
      collect_policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      training_data_spec: A nest of TensorSpec specifying the structure of data
        the train() function expects. If None, defaults to the trajectory_spec
        of the collect_policy.
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.

    Raises:
      ValueError: If `num_outer_dims` is not in `[1, 2]`.
    """
        common.check_tf1_allowed()
        common.tf_agents_gauge.get_cell("TFAgent").set(True)
        common.tf_agents_gauge.get_cell(str(type(self))).set(True)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise TypeError(
                "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
                .format(type(time_step_spec)))

        if num_outer_dims not in [1, 2]:
            raise ValueError("num_outer_dims must be in [1, 2].")

        time_step_spec = tensor_spec.from_spec(time_step_spec)
        action_spec = tensor_spec.from_spec(action_spec)
        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy = policy
        self._collect_policy = collect_policy
        self._train_sequence_length = train_sequence_length
        self._num_outer_dims = num_outer_dims
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._enable_summaries = enable_summaries
        self._training_data_spec = training_data_spec
        # Data context for data collected directly from the collect policy.
        self._collect_data_context = data_converter.DataContext(
            time_step_spec=self._time_step_spec,
            action_spec=self._action_spec,
            info_spec=collect_policy.info_spec)
        # Data context for data passed to train().  May be different if
        # training_data_spec is provided.
        if training_data_spec is not None:
            training_data_spec = tensor_spec.from_spec(training_data_spec)
            # training_data_spec can be anything; so build a data_context
            # via best-effort with fall-backs to the collect data spec.
            training_discount_spec = getattr(training_data_spec, "discount",
                                             time_step_spec.discount)
            training_observation_spec = getattr(training_data_spec,
                                                "observation",
                                                time_step_spec.observation)
            training_reward_spec = getattr(training_data_spec, "reward",
                                           time_step_spec.reward)
            training_step_type_spec = getattr(training_data_spec, "step_type",
                                              time_step_spec.step_type)
            training_policy_info_spec = getattr(training_data_spec,
                                                "policy_info",
                                                collect_policy.info_spec)
            training_action_spec = getattr(training_data_spec, "action",
                                           action_spec)
            self._data_context = data_converter.DataContext(
                time_step_spec=ts.TimeStep(
                    discount=training_discount_spec,
                    observation=training_observation_spec,
                    reward=training_reward_spec,
                    step_type=training_step_type_spec),
                action_spec=training_action_spec,
                info_spec=training_policy_info_spec)
        else:
            self._data_context = data_converter.DataContext(
                time_step_spec=time_step_spec,
                action_spec=action_spec,
                info_spec=collect_policy.info_spec)
        if train_step_counter is None:
            train_step_counter = tf.compat.v1.train.get_or_create_global_step()
        self._train_step_counter = train_step_counter
        self._train_fn = common.function_in_tf1()(self._train)
        self._initialize_fn = common.function_in_tf1()(self._initialize)
        self._preprocess_sequence_fn = common.function_in_tf1()(
            self._preprocess_sequence)
        self._loss_fn = common.function_in_tf1()(self._loss)
Example #14
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 policy: tf_policy.TFPolicy,
                 collect_policy: tf_policy.TFPolicy,
                 train_sequence_length: Optional[int],
                 num_outer_dims: int = 2,
                 training_data_spec: Optional[types.NestedTensorSpec] = None,
                 train_argspec: Optional[Dict[Text,
                                              types.NestedTensorSpec]] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 enable_summaries: bool = True,
                 train_step_counter: Optional[tf.Variable] = None,
                 validate_args: bool = True):
        """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A nest of tf.TypeSpec representing the time_steps.
        Provided by the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current policy.
      collect_policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      training_data_spec: A nest of TensorSpec specifying the structure of data
        the train() function expects. If None, defaults to the trajectory_spec
        of the collect_policy.
      train_argspec: (Optional) Describes additional supported arguments
        to the `train` call.  This must be a `dict` mapping strings to nests
        of specs.  Overriding the `experience` arg is not supported.

        Some algorithms require additional arguments to the `train()` call, and
        while TF-Agents encourages most of these to be provided in the
        `policy_info` / `info` field of `experience`, sometimes the extra
        information doesn't fit well, i.e., when it doesn't come from the
        policy.

        **NOTE** kwargs will not have their outer dimensions validated.
        In particular, `train_sequence_length` is ignored for these inputs,
        and they may have any, or inconsistent, batch/time dimensions; only
        their inner shape dimensions are checked against `train_argspec`.

        Below is an example:

        ```python
        class MyAgent(TFAgent):
          def __init__(self, counterfactual_training, ...):
             collect_policy = ...
             train_argspec = None
             if counterfactual_training:
               train_argspec = dict(
                  counterfactual=collect_policy.trajectory_spec)
             super(...).__init__(
               ...
               train_argspec=train_argspec)

        my_agent = MyAgent(...)

        for ...:
          experience, counterfactual = next(experience_and_counterfactual_iter)
          loss_info = my_agent.train(experience, counterfactual=counterfactual)
        ```
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      validate_args: Python bool.  Whether to verify inputs to, and outputs of,
        functions like `train` and `preprocess_sequence` against spec
        structures, dtypes, and shapes.

        Research code may prefer to set this value to `False` to allow iterating
        on input and output structures without being hamstrung by overly
        rigid checking (at the cost of harder-to-debug errors).

        See also `TFPolicy.validate_args`.

    Raises:
      TypeError: If `validate_args is True` and `train_argspec` is not a `dict`.
      ValueError: If `validate_args is True` and `train_argspec` has the keys
        `experience` or `weights`.
      TypeError: If `validate_args is True` and any leaf nodes in
        `train_argspec` values are not subclasses of `tf.TypeSpec`.
      ValueError: If `validate_args is True` and `time_step_spec` is not an
        instance of `ts.TimeStep`.
      ValueError: If `num_outer_dims` is not in `[1, 2]`.
    """
        if validate_args:

            def _each_isinstance(spec, spec_types):
                """Checks if each element of `spec` is instance of `spec_types`."""
                return all(
                    [isinstance(s, spec_types) for s in tf.nest.flatten(spec)])

            if not _each_isinstance(time_step_spec, tf.TypeSpec):
                raise TypeError(
                    "time_step_spec has to contain TypeSpec (TensorSpec, "
                    "SparseTensorSpec, etc) objects, but received: {}".format(
                        time_step_spec))

            if not _each_isinstance(action_spec,
                                    tensor_spec.BoundedTensorSpec):
                raise TypeError(
                    "action_spec has to contain BoundedTensorSpec objects, but received: "
                    "{}".format(action_spec))

        common.check_tf1_allowed()
        common.tf_agents_gauge.get_cell("TFAgent").set(True)
        common.tf_agents_gauge.get_cell(str(type(self))).set(True)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise TypeError(
                "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
                .format(type(time_step_spec)))

        if num_outer_dims not in [1, 2]:
            raise ValueError("num_outer_dims must be in [1, 2].")

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy = policy
        self._collect_policy = collect_policy
        self._train_sequence_length = train_sequence_length
        self._num_outer_dims = num_outer_dims
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._enable_summaries = enable_summaries
        self._training_data_spec = training_data_spec
        self._validate_args = validate_args
        # Data context for data collected directly from the collect policy.
        self._collect_data_context = data_converter.DataContext(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            info_spec=collect_policy.info_spec)
        # Data context for data passed to train().  May be different if
        # training_data_spec is provided.
        if training_data_spec is not None:
            # training_data_spec can be anything; so build a data_context
            # via best-effort with fall-backs to the collect data spec.
            training_discount_spec = getattr(training_data_spec, "discount",
                                             time_step_spec.discount)
            training_observation_spec = getattr(training_data_spec,
                                                "observation",
                                                time_step_spec.observation)
            training_reward_spec = getattr(training_data_spec, "reward",
                                           time_step_spec.reward)
            training_step_type_spec = getattr(training_data_spec, "step_type",
                                              time_step_spec.step_type)
            training_policy_info_spec = getattr(training_data_spec,
                                                "policy_info",
                                                collect_policy.info_spec)
            training_action_spec = getattr(training_data_spec, "action",
                                           action_spec)
            self._data_context = data_converter.DataContext(
                time_step_spec=ts.TimeStep(
                    discount=training_discount_spec,
                    observation=training_observation_spec,
                    reward=training_reward_spec,
                    step_type=training_step_type_spec),
                action_spec=training_action_spec,
                info_spec=training_policy_info_spec)
        else:
            self._data_context = data_converter.DataContext(
                time_step_spec=time_step_spec,
                action_spec=action_spec,
                info_spec=collect_policy.info_spec)
        if train_argspec is None:
            train_argspec = {}
        elif validate_args:
            if not isinstance(train_argspec, dict):
                raise TypeError(
                    "train_argspec must be a dict, but saw: {}".format(
                        train_argspec))
            if "weights" in train_argspec or "experience" in train_argspec:
                raise ValueError(
                    "train_argspec must not override 'weights' or "
                    "'experience' keys, but saw: {}".format(train_argspec))
            if not all(
                    isinstance(x, tf.TypeSpec)
                    for x in tf.nest.flatten(train_argspec)):
                raise TypeError(
                    "train_argspec contains non-TensorSpec objects: {}".format(
                        train_argspec))
        train_argspec = dict(train_argspec)  # Create a local copy.
        self._train_argspec = train_argspec
        if train_step_counter is None:
            train_step_counter = tf.compat.v1.train.get_or_create_global_step()
        self._train_step_counter = train_step_counter
        self._train_fn = common.function_in_tf1()(self._train)
        self._initialize_fn = common.function_in_tf1()(self._initialize)
        self._preprocess_sequence_fn = common.function_in_tf1()(
            self._preprocess_sequence)
        self._loss_fn = common.function_in_tf1()(self._loss)
Example #15
  def __init__(self,
               time_step_spec,
               action_spec,
               policy,
               collect_policy,
               train_sequence_length,
               num_outer_dims=2,
               debug_summaries=False,
               summarize_grads_and_vars=False,
               enable_summaries=True,
               train_step_counter=None):
    """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A nest of tf.TypeSpec representing the time_steps.
        Provided by the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.Base` representing the Agent's current
        policy.
      collect_policy: An instance of `tf_policy.Base` representing the Agent's
        current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.

    Raises:
      ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
      ValueError: If `num_outer_dims` is not in [1, 2].
    """

    def _each_isinstance(spec, spec_types):
      """Checks if each element of `spec` is instance of any of `spec_types`."""
      return all([isinstance(s, spec_types) for s in tf.nest.flatten(spec)])

    if not _each_isinstance(time_step_spec, tf.TypeSpec):
      raise TypeError("time_step_spec has to contain TypeSpec (TensorSpec, "
                      "SparseTensorSpec, etc) objects, but received: {}".format(
                          time_step_spec))

    if not _each_isinstance(action_spec, tensor_spec.BoundedTensorSpec):
      raise TypeError(
          "action_spec has to contain BoundedTensorSpec objects, but received: "
          "{}".format(action_spec))

    common.check_tf1_allowed()
    common.tf_agents_gauge.get_cell("TFAgent").set(True)
    common.assert_members_are_not_overridden(base_cls=TFAgent, instance=self)
    if not isinstance(time_step_spec, ts.TimeStep):
      raise ValueError(
          "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
          .format(type(time_step_spec)))

    if num_outer_dims not in [1, 2]:
      raise ValueError("num_outer_dims must be in [1, 2].")

    self._time_step_spec = time_step_spec
    self._action_spec = action_spec
    self._policy = policy
    self._collect_policy = collect_policy
    self._train_sequence_length = train_sequence_length
    self._num_outer_dims = num_outer_dims
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._enable_summaries = enable_summaries
    if train_step_counter is None:
      train_step_counter = tf.compat.v1.train.get_or_create_global_step()
    self._train_step_counter = train_step_counter
    self._train_fn = common.function_in_tf1()(self._train)
    self._initialize_fn = common.function_in_tf1()(self._initialize)