Example #1
    def setUp(self):
        super().setUp()
        self._data_context = data_converter.DataContext(
            time_step_spec=ts.TimeStep(step_type=(),
                                       reward=tf.TensorSpec((), tf.float32),
                                       discount=tf.TensorSpec((), tf.float32),
                                       observation=()),
            action_spec={'action1': tf.TensorSpec((), tf.float32)},
            info_spec=(),
            policy_state_spec=(),
            use_half_transition=True)

        self._data_context_with_state = data_converter.DataContext(
            time_step_spec=ts.TimeStep(step_type=(),
                                       reward=tf.TensorSpec((), tf.float32),
                                       discount=tf.TensorSpec((), tf.float32),
                                       observation=tf.TensorSpec((2, ),
                                                                 tf.float32)),
            action_spec={'action1': tf.TensorSpec((), tf.float32)},
            info_spec=(),
            policy_state_spec=[
                tf.TensorSpec((2, ), tf.float32),
                tf.TensorSpec((2, ), tf.float32)
            ],
            use_half_transition=True)
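These fixtures build two half-transition `DataContext` objects, one stateless and one with a two-element `policy_state_spec`. For reference, a context like this is the kind of object later handed to a converter such as `data_converter.AsHalfTransition` (Example #3 shows this inside an agent). A minimal standalone sketch, assuming the usual tf_agents imports:

    import tensorflow as tf
    from tf_agents.agents import data_converter
    from tf_agents.trajectories import time_step as ts

    # Same half-transition context as the stateless fixture above.
    context = data_converter.DataContext(
        time_step_spec=ts.TimeStep(step_type=(),
                                   reward=tf.TensorSpec((), tf.float32),
                                   discount=tf.TensorSpec((), tf.float32),
                                   observation=()),
        action_spec={'action1': tf.TensorSpec((), tf.float32)},
        info_spec=(),
        policy_state_spec=(),
        use_half_transition=True)

    # A half-transition context is typically paired with an AsHalfTransition
    # converter, which turns sampled experience into the Transition format
    # consumed by an agent's loss computation.
    converter = data_converter.AsHalfTransition(context, squeeze_time_dim=True)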
Example #2
    def setUp(self):
        super(AsTrajectoryTest, self).setUp()
        self._data_context = data_converter.DataContext(
            time_step_spec=ts.TimeStep(step_type=(),
                                       reward=tf.TensorSpec((), tf.float32),
                                       discount=tf.TensorSpec((), tf.float32),
                                       observation=()),
            action_spec={'action1': tf.TensorSpec((), tf.float32)},
            info_spec=())
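This fixture, from `AsTrajectoryTest`, builds a trajectory-oriented context (no `use_half_transition`). A hedged sketch of the converter such a context is presumably paired with; `context` here stands for the object built above:

    from tf_agents.agents import data_converter

    # An AsTrajectory converter validates sampled experience against the
    # context's specs and returns it in Trajectory format.
    as_trajectory = data_converter.AsTrajectory(context)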
Example #3
    def _setup_data_converter(self, q_network, gamma, n_step_update):
        if q_network.state_spec:
            if not self._in_graph_bellman_update:
                self._data_context = data_converter.DataContext(
                    time_step_spec=self._time_step_spec,
                    action_spec=self._action_spec,
                    info_spec=self._collect_policy.info_spec,
                    policy_state_spec=self._q_network.state_spec,
                    use_half_transition=True)
                self._as_transition = data_converter.AsHalfTransition(
                    self.data_context, squeeze_time_dim=False)
            else:
                self._data_context = data_converter.DataContext(
                    time_step_spec=self._time_step_spec,
                    action_spec=self._action_spec,
                    info_spec=self._collect_policy.info_spec,
                    policy_state_spec=self._q_network.state_spec,
                    use_half_transition=False)
                self._as_transition = data_converter.AsTransition(
                    self.data_context,
                    squeeze_time_dim=False,
                    prepend_t0_to_next_time_step=True)
        else:
            if not self._in_graph_bellman_update:
                self._data_context = data_converter.DataContext(
                    time_step_spec=self._time_step_spec,
                    action_spec=self._action_spec,
                    info_spec=self._collect_policy.info_spec,
                    policy_state_spec=self._q_network.state_spec,
                    use_half_transition=True)

                self._as_transition = data_converter.AsHalfTransition(
                    self.data_context, squeeze_time_dim=True)
            else:
                # This reduces the n-step return and removes the extra time dimension,
                # allowing the rest of the computations to be independent of the
                # n-step parameter.
                self._as_transition = data_converter.AsNStepTransition(
                    self.data_context, gamma=gamma, n=n_step_update)
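The branches above differ only in which converter ends up in `self._as_transition`; in every case that converter is later applied to sampled experience. A standalone sketch of the n-step branch, assuming `context` is a `DataContext` built as in the earlier examples and with illustrative values for `gamma` and `n`:

    from tf_agents.agents import data_converter

    # Reduces an n-step experience window to a single transition so the rest
    # of the loss computation is independent of the n-step parameter.
    as_n_step = data_converter.AsNStepTransition(context, gamma=0.99, n=2)

    # Whichever branch is chosen, the converter is called on sampled
    # experience and returns a Transition(time_step, action_step,
    # next_time_step) tuple, e.g.:
    #   transition = as_n_step(experience)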
Example #4
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 policy: tf_policy.TFPolicy,
                 collect_policy: tf_policy.TFPolicy,
                 train_sequence_length: Optional[int],
                 num_outer_dims: int = 2,
                 training_data_spec: Optional[types.NestedTensorSpec] = None,
                 train_argspec: Optional[Dict[Text,
                                              types.NestedTensorSpec]] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 enable_summaries: bool = True,
                 train_step_counter: Optional[tf.Variable] = None,
                 validate_args: bool = True):
        """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A nest of tf.TypeSpec representing the time_steps.
        Provided by the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current policy.
      collect_policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      training_data_spec: A nest of TensorSpec specifying the structure of data
        the train() function expects. If None, defaults to the trajectory_spec
        of the collect_policy.
      train_argspec: (Optional) Describes additional supported arguments
        to the `train` call.  This must be a `dict` mapping strings to nests
        of specs.  Overriding the `experience` arg is also supported.

        Some algorithms require additional arguments to the `train()` call, and
        while TF-Agents encourages most of these to be provided in the
        `policy_info` / `info` field of `experience`, sometimes the extra
        information doesn't fit well, i.e., when it doesn't come from the
        policy.

        **NOTE** kwargs will not have their outer dimensions validated.
        In particular, `train_sequence_length` is ignored for these inputs,
        and they may have any, or inconsistent, batch/time dimensions; only
        their inner shape dimensions are checked against `train_argspec`.

        Below is an example:

        ```python
        class MyAgent(TFAgent):
          def __init__(self, counterfactual_training, ...):
             collect_policy = ...
             train_argspec = None
             if counterfactual_training:
               train_argspec = dict(
                  counterfactual=collect_policy.trajectory_spec)
             super(...).__init__(
               ...
               train_argspec=train_argspec)

        my_agent = MyAgent(...)

        for ...:
          experience, counterfactual = next(experience_and_counterfactual_iter)
          loss_info = my_agent.train(experience, counterfactual=counterfactual)
        ```
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      validate_args: Python bool.  Whether to verify inputs to, and outputs of,
        functions like `train` and `preprocess_sequence` against spec
        structures, dtypes, and shapes.

        Research code may prefer to set this value to `False` to allow iterating
        on input and output structures without being hamstrung by overly
        rigid checking (at the cost of harder-to-debug errors).

        See also `TFPolicy.validate_args`.

    Raises:
      TypeError: If `validate_args is True` and `train_argspec` is not a `dict`.
      ValueError: If `validate_args is True` and `train_argspec` has the keys
        `experience` or `weights`.
      TypeError: If `validate_args is True` and any leaf nodes in
        `train_argspec` values are not subclasses of `tf.TypeSpec`.
      ValueError: If `validate_args is True` and `time_step_spec` is not an
        instance of `ts.TimeStep`.
      ValueError: If `num_outer_dims` is not in `[1, 2]`.
    """
        if validate_args:

            def _each_isinstance(spec, spec_types):
                """Checks if each element of `spec` is instance of `spec_types`."""
                return all(
                    [isinstance(s, spec_types) for s in tf.nest.flatten(spec)])

            if not _each_isinstance(time_step_spec, tf.TypeSpec):
                raise TypeError(
                    "time_step_spec has to contain TypeSpec (TensorSpec, "
                    "SparseTensorSpec, etc) objects, but received: {}".format(
                        time_step_spec))

            if not _each_isinstance(action_spec,
                                    tensor_spec.BoundedTensorSpec):
                raise TypeError(
                    "action_spec has to contain BoundedTensorSpec objects, but received: "
                    "{}".format(action_spec))

        common.check_tf1_allowed()
        common.tf_agents_gauge.get_cell("TFAgent").set(True)
        common.assert_members_are_not_overridden(base_cls=TFAgent,
                                                 instance=self)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise TypeError(
                "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
                .format(type(time_step_spec)))

        if num_outer_dims not in [1, 2]:
            raise ValueError("num_outer_dims must be in [1, 2].")

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy = policy
        self._collect_policy = collect_policy
        self._train_sequence_length = train_sequence_length
        self._num_outer_dims = num_outer_dims
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._enable_summaries = enable_summaries
        self._training_data_spec = training_data_spec
        self._validate_args = validate_args
        # Data context for data collected directly from the collect policy.
        self._collect_data_context = data_converter.DataContext(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            info_spec=collect_policy.info_spec)
        # Data context for data passed to train().  May be different if
        # training_data_spec is provided.
        if training_data_spec is not None:
            data_context_info_spec = getattr(training_data_spec, "policy_info",
                                             ())
        else:
            data_context_info_spec = collect_policy.info_spec
        self._data_context = data_converter.DataContext(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            info_spec=data_context_info_spec)
        if train_argspec is None:
            train_argspec = {}
        elif validate_args:
            if not isinstance(train_argspec, dict):
                raise TypeError(
                    "train_argspec must be a dict, but saw: {}".format(
                        train_argspec))
            if "weights" in train_argspec or "experience" in train_argspec:
                raise ValueError(
                    "train_argspec must not override 'weights' or "
                    "'experience' keys, but saw: {}".format(train_argspec))
            if not all(
                    isinstance(x, tf.TypeSpec)
                    for x in tf.nest.flatten(train_argspec)):
                raise TypeError(
                    "train_argspec contains non-TensorSpec objects: {}".format(
                        train_argspec))
        train_argspec = dict(train_argspec)  # Create a local copy.
        self._train_argspec = train_argspec
        if train_step_counter is None:
            train_step_counter = tf.compat.v1.train.get_or_create_global_step()
        self._train_step_counter = train_step_counter
        self._train_fn = common.function_in_tf1()(self._train)
        self._initialize_fn = common.function_in_tf1()(self._initialize)
        self._preprocess_sequence_fn = common.function_in_tf1()(
            self._preprocess_sequence)
        self._loss_fn = common.function_in_tf1()(self._loss)
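Note that this variant reads only the `policy_info` field from `training_data_spec` (via `getattr`) when building the train-time data context, whereas Example #5 below falls back field by field. A small illustration of that lookup; the spec values here are placeholders, not part of the example:

    import tensorflow as tf
    from tf_agents.trajectories import trajectory

    # A Trajectory spec whose policy_info differs from the collect policy's
    # info_spec; every other field is an illustrative scalar spec.
    scalar = tf.TensorSpec((), tf.float32)
    training_data_spec = trajectory.Trajectory(
        step_type=scalar, observation=scalar, action=scalar,
        policy_info={'log_prob': scalar},
        next_step_type=scalar, reward=scalar, discount=scalar)

    # The constructor above effectively does:
    info_spec = getattr(training_data_spec, 'policy_info', ())
    assert info_spec == {'log_prob': scalar}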
Example #5
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 policy: tf_policy.TFPolicy,
                 collect_policy: tf_policy.TFPolicy,
                 train_sequence_length: Optional[int],
                 num_outer_dims: int = 2,
                 training_data_spec: Optional[types.NestedTensorSpec] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 enable_summaries: bool = True,
                 train_step_counter: Optional[tf.Variable] = None):
        """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A nest of tf.TypeSpec representing the time_steps.
        Provided by the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current policy.
      collect_policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      training_data_spec: A nest of TensorSpec specifying the structure of data
        the train() function expects. If None, defaults to the trajectory_spec
        of the collect_policy.
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.

    Raises:
      ValueError: If `num_outer_dims` is not in `[1, 2]`.
    """
        common.check_tf1_allowed()
        common.tf_agents_gauge.get_cell("TFAgent").set(True)
        common.tf_agents_gauge.get_cell(str(type(self))).set(True)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise TypeError(
                "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
                .format(type(time_step_spec)))

        if num_outer_dims not in [1, 2]:
            raise ValueError("num_outer_dims must be in [1, 2].")

        time_step_spec = tensor_spec.from_spec(time_step_spec)
        action_spec = tensor_spec.from_spec(action_spec)
        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy = policy
        self._collect_policy = collect_policy
        self._train_sequence_length = train_sequence_length
        self._num_outer_dims = num_outer_dims
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._enable_summaries = enable_summaries
        self._training_data_spec = training_data_spec
        # Data context for data collected directly from the collect policy.
        self._collect_data_context = data_converter.DataContext(
            time_step_spec=self._time_step_spec,
            action_spec=self._action_spec,
            info_spec=collect_policy.info_spec)
        # Data context for data passed to train().  May be different if
        # training_data_spec is provided.
        if training_data_spec is not None:
            training_data_spec = tensor_spec.from_spec(training_data_spec)
            # training_data_spec can be anything; so build a data_context
            # via best-effort with fall-backs to the collect data spec.
            training_discount_spec = getattr(training_data_spec, "discount",
                                             time_step_spec.discount)
            training_observation_spec = getattr(training_data_spec,
                                                "observation",
                                                time_step_spec.observation)
            training_reward_spec = getattr(training_data_spec, "reward",
                                           time_step_spec.reward)
            training_step_type_spec = getattr(training_data_spec, "step_type",
                                              time_step_spec.step_type)
            training_policy_info_spec = getattr(training_data_spec,
                                                "policy_info",
                                                collect_policy.info_spec)
            training_action_spec = getattr(training_data_spec, "action",
                                           action_spec)
            self._data_context = data_converter.DataContext(
                time_step_spec=ts.TimeStep(
                    discount=training_discount_spec,
                    observation=training_observation_spec,
                    reward=training_reward_spec,
                    step_type=training_step_type_spec),
                action_spec=training_action_spec,
                info_spec=training_policy_info_spec)
        else:
            self._data_context = data_converter.DataContext(
                time_step_spec=time_step_spec,
                action_spec=action_spec,
                info_spec=collect_policy.info_spec)
        if train_step_counter is None:
            train_step_counter = tf.compat.v1.train.get_or_create_global_step()
        self._train_step_counter = train_step_counter
        self._train_fn = common.function_in_tf1()(self._train)
        self._initialize_fn = common.function_in_tf1()(self._initialize)
        self._preprocess_sequence_fn = common.function_in_tf1()(
            self._preprocess_sequence)
        self._loss_fn = common.function_in_tf1()(self._loss)
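Both constructor variants above are meant to be called from subclasses. A minimal hypothetical subclass skeleton showing how the pieces fit together; the class and method bodies here are illustrative, not taken from TF-Agents:

    from typing import Optional
    import tensorflow as tf
    from tf_agents.agents import tf_agent

    class MinimalAgent(tf_agent.TFAgent):
        """Hypothetical skeleton: forwards specs and policies to TFAgent.__init__."""

        def __init__(self, time_step_spec, action_spec, policy, collect_policy,
                     train_step_counter: Optional[tf.Variable] = None):
            super(MinimalAgent, self).__init__(
                time_step_spec,
                action_spec,
                policy,
                collect_policy,
                train_sequence_length=None,  # no fixed T required at train time
                train_step_counter=train_step_counter)

        def _initialize(self):
            pass

        def _train(self, experience, weights=None):
            # A real agent would compute a loss from `experience` here.
            return tf_agent.LossInfo(loss=tf.constant(0.0), extra=())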