Example #1
    def testDistributedLinearAgentUpdate(self,
                                         batch_size,
                                         context_dim,
                                         exploration_policy,
                                         dtype,
                                         use_eigendecomp=False):
        """Same as above, but uses the distributed train function of the agent."""

        # Construct a `Trajectory` for the given action, observation, reward.
        num_actions = 5
        initial_step, final_step = _get_initial_and_final_steps(
            batch_size, context_dim)
        action = np.random.randint(num_actions,
                                   size=batch_size,
                                   dtype=np.int32)
        action_step = _get_action_step(action)
        experience = _get_experience(initial_step, action_step, final_step)

        # Construct an agent and perform the update.
        observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
        time_step_spec = time_step.time_step_spec(observation_spec)
        action_spec = tensor_spec.BoundedTensorSpec(dtype=tf.int32,
                                                    shape=(),
                                                    minimum=0,
                                                    maximum=num_actions - 1)

        agent = linear_agent.LinearBanditAgent(
            exploration_policy=exploration_policy,
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            dtype=dtype)
        self.evaluate(agent.initialize())
        train_fn = common.function_in_tf1()(agent._distributed_train_step)
        loss_info = train_fn(experience=experience)
        self.evaluate(loss_info)

        final_a = self.evaluate(agent.cov_matrix)
        final_b = self.evaluate(agent.data_vector)

        # Compute the expected updated estimates.
        observations_list = tf.dynamic_partition(
            data=tf.reshape(experience.observation, [batch_size, context_dim]),
            partitions=tf.convert_to_tensor(action),
            num_partitions=num_actions)
        rewards_list = tf.dynamic_partition(
            data=tf.reshape(experience.reward, [batch_size]),
            partitions=tf.convert_to_tensor(action),
            num_partitions=num_actions)
        expected_a_updated_list = []
        expected_b_updated_list = []
        expected_theta_updated_list = []
        for observations_for_arm, rewards_for_arm in zip(
                observations_list, rewards_list):
            num_samples_for_arm_current = tf.cast(
                tf.shape(rewards_for_arm)[0], tf.float32)
            num_samples_for_arm_total = num_samples_for_arm_current

            # pylint: disable=cell-var-from-loop
            def true_fn():
                a_new = tf.matmul(observations_for_arm,
                                  observations_for_arm,
                                  transpose_a=True)
                b_new = bandit_utils.sum_reward_weighted_observations(
                    rewards_for_arm, observations_for_arm)
                return a_new, b_new

            def false_fn():
                return tf.zeros([context_dim,
                                 context_dim]), tf.zeros([context_dim])

            a_new, b_new = tf.cond(
                tf.squeeze(num_samples_for_arm_total) > 0, true_fn, false_fn)
            theta_new = tf.squeeze(tf.linalg.solve(
                a_new + tf.eye(context_dim), tf.expand_dims(b_new, axis=-1)),
                                   axis=-1)

            expected_a_updated_list.append(self.evaluate(a_new))
            expected_b_updated_list.append(self.evaluate(b_new))
            expected_theta_updated_list.append(self.evaluate(theta_new))

        # Check that the actual updated estimates match the expectations.
        self.assertAllClose(expected_a_updated_list, final_a)
        self.assertAllClose(expected_b_updated_list, final_b)
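
The expected values above are the standard per-arm linear-bandit sufficient statistics: `A = X^T X`, `b = sum_i r_i * x_i`, and `theta` solved from `(A + I) theta = b`. Below is a minimal NumPy sketch of that computation, for reference only; it assumes `bandit_utils.sum_reward_weighted_observations(r, X)` returns `sum_i r_i * x_i`, which is how the test uses it.

```python
# Illustrative sketch (not TF-Agents code) of the per-arm statistics the
# test above verifies.
import numpy as np


def expected_arm_statistics(observations, rewards, context_dim):
    """Returns (A, b, theta) for one arm from its observations and rewards."""
    if len(rewards) == 0:
        a = np.zeros((context_dim, context_dim))
        b = np.zeros(context_dim)
    else:
        x = np.asarray(observations, dtype=np.float64)  # [n, context_dim]
        r = np.asarray(rewards, dtype=np.float64)       # [n]
        a = x.T @ x                                     # A = X^T X
        b = r @ x                                       # b = sum_i r_i * x_i
    theta = np.linalg.solve(a + np.eye(context_dim), b)
    return a, b, theta


# Two observations collected for one arm in a 3-dimensional context.
a, b, theta = expected_arm_statistics(
    [[1., 2., 3.], [4., 5., 6.]], [1., 0.5], context_dim=3)
```
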
Example #2
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 policy,
                 collect_policy,
                 train_sequence_length,
                 num_outer_dims=2,
                 train_argspec=None,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 enable_summaries=True,
                 train_step_counter=None):
        """Meant to be called by subclass constructors.
    Args:
      time_step_spec: A nest of tf.TypeSpec representing the time_steps.
        Provided by the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.Base` representing the Agent's current
        policy.
      collect_policy: An instance of `tf_policy.Base` representing the Agent's
        current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      train_argspec: (Optional) Describes additional supported arguments
        to the `train` call.  This must be a `dict` mapping strings to nests
        of specs.  Overriding the `experience` arg is not supported.
        Some algorithms require additional arguments to the `train()` call, and
        while TF-Agents encourages most of these to be provided in the
        `policy_info` / `info` field of `experience`, sometimes the extra
        information doesn't fit well, i.e., when it doesn't come from the
        policy.
        **NOTE** kwargs will not have their outer dimensions validated.
        In particular, `train_sequence_length` is ignored for these inputs,
        and they may have any, or inconsistent, batch/time dimensions; only
        their inner shape dimensions are checked against `train_argspec`.
        Below is an example:
        ```python
        class MyAgent(TFAgent):
          def __init__(self, counterfactual_training, ...):
             collect_policy = ...
             train_argspec = None
             if counterfactual_training:
               train_argspec = dict(
                  counterfactual=collect_policy.trajectory_spec)
             super(...).__init__(
               ...
               train_argspec=train_argspec)
        my_agent = MyAgent(...)
        for ...:
          experience, counterfactual = next(experience_and_counterfactual_iter)
          loss_info = my_agent.train(experience, counterfactual=counterfactual)
        ```
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
    Raises:
      TypeError: If `train_argspec` is not a `dict`.
      ValueError: If `train_argspec` has the keys `experience` or `weights`.
      TypeError: If any leaf nodes in `train_argspec` values are not
        subclasses of `tf.TypeSpec`.
      ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
      ValueError: If `num_outer_dims` is not in [1, 2].
    """
        def _each_isinstance(spec, spec_types):
            """Checks if each element of `spec` is instance of any of `spec_types`."""
            return all(
                [isinstance(s, spec_types) for s in tf.nest.flatten(spec)])

        if not _each_isinstance(time_step_spec, tf.TypeSpec):
            raise TypeError(
                "time_step_spec has to contain TypeSpec (TensorSpec, "
                "SparseTensorSpec, etc) objects, but received: {}".format(
                    time_step_spec))

        if not _each_isinstance(action_spec, tensor_spec.BoundedTensorSpec):
            raise TypeError(
                "action_spec has to contain BoundedTensorSpec objects, but received: "
                "{}".format(action_spec))

        common.check_tf1_allowed()
        common.tf_agents_gauge.get_cell("TFAgent").set(True)
        common.assert_members_are_not_overridden(base_cls=TFAgent,
                                                 instance=self)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise ValueError(
                "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
                .format(type(time_step_spec)))

        if num_outer_dims not in [1, 2]:
            raise ValueError("num_outer_dims must be in [1, 2].")

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy = policy
        self._collect_policy = collect_policy
        self._train_sequence_length = train_sequence_length
        self._num_outer_dims = num_outer_dims
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._enable_summaries = enable_summaries
        if train_argspec is None:
            train_argspec = {}
        else:
            if not isinstance(train_argspec, dict):
                raise TypeError(
                    "train_argspec must be a dict, but saw: {}".format(
                        train_argspec))
            train_argspec = dict(train_argspec)  # Create a local copy.
            if "weights" in train_argspec or "experience" in train_argspec:
                raise ValueError(
                    "train_argspec must not override 'weights' or "
                    "'experience' keys, but saw: {}".format(train_argspec))
            if not all(
                    isinstance(x, tf.TypeSpec)
                    for x in tf.nest.flatten(train_argspec)):
                raise TypeError(
                    "train_argspec contains non-TensorSpec objects: {}".format(
                        train_argspec))
        self._train_argspec = train_argspec
        if train_step_counter is None:
            train_step_counter = tf.compat.v1.train.get_or_create_global_step()
        self._train_step_counter = train_step_counter
        self._train_fn = common.function_in_tf1()(self._train)
        self._initialize_fn = common.function_in_tf1()(self._initialize)
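
As the docstring notes, `time_step_spec` and `action_spec` are provided by the user. Below is a short sketch of building specs in the form the validation above accepts, following the same pattern used in the test snippets elsewhere on this page.

```python
# Sketch of user-provided specs that pass the checks in the constructor above:
# a `ts.TimeStep` nest of TypeSpecs and a nest of BoundedTensorSpec actions.
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tensor_spec.TensorSpec([10], tf.float32)
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    dtype=tf.int32, shape=(), minimum=0, maximum=4)

# Anything other than 1 or 2 for `num_outer_dims` would trigger the
# ValueError raised above.
```
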
Example #3
    def testActionBatchWithVariablesAndPolicyUpdate(self, batch_size,
                                                    actions_from_reward_layer):

        a_list = []
        a_new_list = []
        b_list = []
        b_new_list = []
        num_samples_list = []
        num_samples_new_list = []
        for k in range(1, self._num_actions + 1):
            a_initial_value = k + 1 + 2 * k * tf.eye(self._encoding_dim,
                                                     dtype=tf.float32)
            a_for_one_arm = tf.compat.v2.Variable(a_initial_value)
            a_list.append(a_for_one_arm)
            b_initial_value = tf.constant(k * np.ones(self._encoding_dim),
                                          dtype=tf.float32)
            b_for_one_arm = tf.compat.v2.Variable(b_initial_value)
            b_list.append(b_for_one_arm)
            num_samples_initial_value = tf.constant([1], dtype=tf.float32)
            num_samples_for_one_arm = tf.compat.v2.Variable(
                num_samples_initial_value)
            num_samples_list.append(num_samples_for_one_arm)

            # Variables for the new policy (they differ by an offset).
            a_new_for_one_arm = tf.compat.v2.Variable(a_initial_value +
                                                      _POLICY_VARIABLES_OFFSET)
            a_new_list.append(a_new_for_one_arm)
            b_new_for_one_arm = tf.compat.v2.Variable(b_initial_value +
                                                      _POLICY_VARIABLES_OFFSET)
            b_new_list.append(b_new_for_one_arm)
            num_samples_for_one_arm_new = tf.compat.v2.Variable(
                num_samples_initial_value + _POLICY_VARIABLES_OFFSET)
            num_samples_new_list.append(num_samples_for_one_arm_new)

        policy = neural_linucb_policy.NeuralLinUCBPolicy(
            encoding_network=DummyNet(),
            encoding_dim=self._encoding_dim,
            reward_layer=get_reward_layer(),
            actions_from_reward_layer=tf.constant(actions_from_reward_layer,
                                                  dtype=tf.bool),
            cov_matrix=a_list,
            data_vector=b_list,
            num_samples=num_samples_list,
            epsilon_greedy=0.0,
            time_step_spec=self._time_step_spec)

        new_policy = neural_linucb_policy.NeuralLinUCBPolicy(
            encoding_network=DummyNet(),
            encoding_dim=self._encoding_dim,
            reward_layer=get_reward_layer(),
            actions_from_reward_layer=tf.constant(actions_from_reward_layer,
                                                  dtype=tf.bool),
            cov_matrix=a_new_list,
            data_vector=b_new_list,
            num_samples=num_samples_new_list,
            epsilon_greedy=0.0,
            time_step_spec=self._time_step_spec)

        action_step = policy.action(
            self._time_step_batch(batch_size=batch_size))
        new_action_step = new_policy.action(
            self._time_step_batch(batch_size=batch_size))
        self.assertEqual(action_step.action.shape,
                         new_action_step.action.shape)
        self.assertEqual(action_step.action.dtype,
                         new_action_step.action.dtype)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(new_policy.update(policy))

        action_fn = common.function_in_tf1()(policy.action)
        action_step = action_fn(self._time_step_batch(batch_size=batch_size))
        new_action_fn = common.function_in_tf1()(new_policy.action)
        new_action_step = new_action_fn(
            self._time_step_batch(batch_size=batch_size))

        actions_, new_actions_ = self.evaluate(
            [action_step.action, new_action_step.action])
        self.assertAllEqual(actions_, new_actions_)
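
The property this test exercises, that after `new_policy.update(policy)` the two policies act identically even though their variables started at different values, comes down to copying variable values from one policy into the other. A toy illustration of that idea (not the `NeuralLinUCBPolicy.update` implementation):

```python
import tensorflow as tf

source_var = tf.Variable([1.0, 2.0])
dest_var = tf.Variable([101.0, 102.0])  # differs by an offset, as in the test

# Copying the source values makes the two sets of variables agree again.
dest_var.assign(source_var)
assert (dest_var.numpy() == source_var.numpy()).all()
```
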
Example #4
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 critic_network,
                 actor_network,
                 model_network,
                 compressor_network,
                 actor_optimizer,
                 critic_optimizer,
                 alpha_optimizer,
                 model_optimizer,
                 sequence_length,
                 target_update_tau=1.0,
                 target_update_period=1,
                 td_errors_loss_fn=tf.math.squared_difference,
                 gamma=1.0,
                 reward_scale_factor=1.0,
                 initial_log_alpha=0.0,
                 target_entropy=None,
                 gradient_clipping=None,
                 trainable_model=True,
                 critic_input='state',
                 actor_input='state',
                 critic_input_stop_gradient=True,
                 actor_input_stop_gradient=False,
                 model_batch_size=None,
                 control_timestep=None,
                 num_images_per_summary=1,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 train_step_counter=None,
                 name=None):
        tf.Module.__init__(self, name=name)

        self._critic_network1 = critic_network
        self._critic_network2 = critic_network.copy(name='CriticNetwork2')
        self._target_critic_network1 = critic_network.copy(
            name='TargetCriticNetwork1')
        self._target_critic_network2 = critic_network.copy(
            name='TargetCriticNetwork2')
        self._actor_network = actor_network
        self._model_network = model_network
        self._compressor_network = compressor_network

        policy = ActorSequencePolicy(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            model_network=self._model_network,
            compressor_network=self._compressor_network,
            sequence_length=sequence_length,
            actor_input=actor_input,
            control_timestep=control_timestep,
            num_images_per_summary=num_images_per_summary,
            debug_summaries=debug_summaries)

        self._log_alpha = common.create_variable(
            'initial_log_alpha',
            initial_value=initial_log_alpha,
            dtype=tf.float32,
            trainable=True)

        # If target_entropy was not passed, set it to negative of the total number
        # of action dimensions.
        if target_entropy is None:
            flat_action_spec = tf.nest.flatten(action_spec)
            target_entropy = -np.sum([
                np.product(single_spec.shape.as_list())
                for single_spec in flat_action_spec
            ])

        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._alpha_optimizer = alpha_optimizer
        self._model_optimizer = model_optimizer
        self._sequence_length = sequence_length
        self._td_errors_loss_fn = td_errors_loss_fn
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._target_entropy = target_entropy
        self._gradient_clipping = gradient_clipping
        self._trainable_model = trainable_model
        self._critic_input = critic_input
        self._actor_input = actor_input
        self._critic_input_stop_gradient = critic_input_stop_gradient
        self._actor_input_stop_gradient = actor_input_stop_gradient
        self._model_batch_size = model_batch_size
        self._control_timestep = control_timestep
        self._num_images_per_summary = num_images_per_summary
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._update_target = self._get_target_updater(
            tau=self._target_update_tau, period=self._target_update_period)

        self._actor_time_step_spec = time_step_spec._replace(
            observation=actor_network.input_tensor_spec)
        super(SlacAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy=policy,
                             collect_policy=policy,
                             train_sequence_length=sequence_length + 1,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter)
        self._train_model_fn = common.function_in_tf1()(self._train_model)
Example #5
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 policy_state_spec=(),
                 info_spec=(),
                 clip=True,
                 emit_log_probability=False,
                 automatic_state_reset=True,
                 name=None):
        """Initialization of Base class.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps. Usually
        provided by the user to the subclass.
      action_spec: A nest of BoundedTensorSpec representing the actions. Usually
        provided by the user to the subclass.
      policy_state_spec: A nest of TensorSpec representing the policy_state.
        Provided by the subclass, not directly by the user.
      info_spec: A nest of TensorSpec representing the policy info. Provided by
        the subclass, not directly by the user.
      clip: Whether to clip actions to spec before returning them.  Default
        True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped
        continuous actions for training.
      emit_log_probability: Emit log-probabilities of actions, if supported. If
        True, policy_step.info will have CommonFields.LOG_PROBABILITY set.
        Please consult utility methods provided in policy_step for setting and
        retrieving these. When working with custom policies, either provide a
        dictionary info_spec or a namedtuple with the field 'log_probability'.
      automatic_state_reset:  If `True`, then `get_initial_policy_state` is used
        to clear state in `action()` and `distribution()` for time steps
        where `time_step.is_first()`.
      name: A name for this module. Defaults to the class name.
    """
        super(Base, self).__init__(name=name)
        common.assert_members_are_not_overridden(base_cls=Base, instance=self)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise ValueError(
                'The `time_step_spec` must be an instance of `TimeStep`, but is `{}`.'
                .format(type(time_step_spec)))

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy_state_spec = policy_state_spec
        self._emit_log_probability = emit_log_probability
        if emit_log_probability:
            log_probability_spec = tensor_spec.BoundedTensorSpec(
                shape=(),
                dtype=tf.float32,
                maximum=0,
                minimum=-float('inf'),
                name='log_probability')
            log_probability_spec = tf.nest.map_structure(
                lambda _: log_probability_spec, action_spec)
            info_spec = policy_step.set_log_probability(
                info_spec, log_probability_spec)

        self._info_spec = info_spec
        self._setup_specs()
        self._clip = clip
        self._action_fn = common.function_in_tf1()(self._action)
        self._automatic_state_reset = automatic_state_reset
Example #6
    def testNeuralLinUCBUpdateDistributed(self, batch_size=1, context_dim=10):
        """Same as above but with distributed LinUCB updates."""

        # Construct a `Trajectory` for the given action, observation, reward.
        num_actions = 5
        initial_step, final_step = _get_initial_and_final_steps(
            batch_size, context_dim)
        action = np.random.randint(num_actions,
                                   size=batch_size,
                                   dtype=np.int32)
        action_step = _get_action_step(action)
        experience = _get_experience(initial_step, action_step, final_step)

        # Construct an agent and perform the update.
        observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
        time_step_spec = time_step.time_step_spec(observation_spec)
        action_spec = tensor_spec.BoundedTensorSpec(dtype=tf.int32,
                                                    shape=(),
                                                    minimum=0,
                                                    maximum=num_actions - 1)
        encoder = DummyNet(observation_spec)
        encoding_dim = 10
        agent = neural_linucb_agent.NeuralLinUCBAgent(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            encoding_network=encoder,
            encoding_network_num_train_steps=0,
            encoding_dim=encoding_dim,
            optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-2))

        self.evaluate(agent.initialize())
        self.evaluate(tf.compat.v1.global_variables_initializer())
        # Call the distributed LinUCB training instead of agent.train().
        train_fn = common.function_in_tf1()(
            agent.compute_loss_using_linucb_distributed)
        reward = tf.cast(experience.reward, agent._dtype)
        loss_info = train_fn(experience.observation,
                             action,
                             reward,
                             weights=None)
        self.evaluate(loss_info)
        final_a = self.evaluate(agent.cov_matrix)
        final_b = self.evaluate(agent.data_vector)

        # Compute the expected updated estimates.
        observations_list = tf.dynamic_partition(
            data=tf.reshape(tf.cast(experience.observation, tf.float64),
                            [batch_size, context_dim]),
            partitions=tf.convert_to_tensor(action),
            num_partitions=num_actions)
        rewards_list = tf.dynamic_partition(
            data=tf.reshape(tf.cast(experience.reward, tf.float64),
                            [batch_size]),
            partitions=tf.convert_to_tensor(action),
            num_partitions=num_actions)
        expected_a_updated_list = []
        expected_b_updated_list = []
        for observations_for_arm, rewards_for_arm in zip(
                observations_list, rewards_list):

            encoded_observations_for_arm, _ = encoder(observations_for_arm)
            encoded_observations_for_arm = tf.cast(
                encoded_observations_for_arm, dtype=tf.float64)

            num_samples_for_arm_current = tf.cast(
                tf.shape(rewards_for_arm)[0], tf.float64)
            num_samples_for_arm_total = num_samples_for_arm_current

            # pylint: disable=cell-var-from-loop
            def true_fn():
                a_new = tf.matmul(encoded_observations_for_arm,
                                  encoded_observations_for_arm,
                                  transpose_a=True)
                b_new = bandit_utils.sum_reward_weighted_observations(
                    rewards_for_arm, encoded_observations_for_arm)
                return a_new, b_new

            def false_fn():
                return (tf.zeros([encoding_dim, encoding_dim],
                                 dtype=tf.float64),
                        tf.zeros([encoding_dim], dtype=tf.float64))

            a_new, b_new = tf.cond(
                tf.squeeze(num_samples_for_arm_total) > 0, true_fn, false_fn)

            expected_a_updated_list.append(self.evaluate(a_new))
            expected_b_updated_list.append(self.evaluate(b_new))

        # Check that the actual updated estimates match the expectations.
        self.assertAllClose(expected_a_updated_list, final_a)
        self.assertAllClose(expected_b_updated_list, final_b)
Example #7
  def __init__(
      self,
      time_step_spec: ts.TimeStep,
      action_spec: types.NestedTensorSpec,
      policy: tf_policy.TFPolicy,
      collect_policy: tf_policy.TFPolicy,
      train_sequence_length: Optional[int],
      num_outer_dims: int = 2,
      training_data_spec: Optional[types.NestedTensorSpec] = None,
      debug_summaries: bool = False,
      summarize_grads_and_vars: bool = False,
      enable_summaries: bool = True,
      train_step_counter: Optional[tf.Variable] = None):
    """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A nest of tf.TypeSpec representing the time_steps.
        Provided by the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current policy.
      collect_policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      training_data_spec: A nest of TensorSpec specifying the structure of data
        the train() function expects. If None, defaults to the trajectory_spec
        of the collect_policy.
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.

    Raises:
      ValueError: If `num_outer_dims` is not in `[1, 2]`.
    """
    common.check_tf1_allowed()
    common.tf_agents_gauge.get_cell("TFAgent").set(True)
    common.tf_agents_gauge.get_cell(str(type(self))).set(True)
    if not isinstance(time_step_spec, ts.TimeStep):
      raise TypeError(
          "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
          .format(type(time_step_spec)))

    if num_outer_dims not in [1, 2]:
      raise ValueError("num_outer_dims must be in [1, 2].")

    self._time_step_spec = time_step_spec
    self._action_spec = action_spec
    self._policy = policy
    self._collect_policy = collect_policy
    self._train_sequence_length = train_sequence_length
    self._num_outer_dims = num_outer_dims
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._enable_summaries = enable_summaries
    self._training_data_spec = training_data_spec
    # Data context for data collected directly from the collect policy.
    self._collect_data_context = data_converter.DataContext(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=collect_policy.info_spec)
    # Data context for data passed to train().  May be different if
    # training_data_spec is provided.
    if training_data_spec is not None:
      # training_data_spec can be anything; so build a data_context
      # via best-effort with fall-backs to the collect data spec.
      training_discount_spec = getattr(
          training_data_spec, "discount", time_step_spec.discount)
      training_observation_spec = getattr(
          training_data_spec, "observation", time_step_spec.observation)
      training_reward_spec = getattr(
          training_data_spec, "reward", time_step_spec.reward)
      training_step_type_spec = getattr(
          training_data_spec, "step_type", time_step_spec.step_type)
      training_policy_info_spec = getattr(
          training_data_spec, "policy_info", collect_policy.info_spec)
      training_action_spec = getattr(
          training_data_spec, "action", action_spec)
      self._data_context = data_converter.DataContext(
          time_step_spec=ts.TimeStep(
              discount=training_discount_spec,
              observation=training_observation_spec,
              reward=training_reward_spec,
              step_type=training_step_type_spec),
          action_spec=training_action_spec,
          info_spec=training_policy_info_spec)
    else:
      self._data_context = data_converter.DataContext(
          time_step_spec=time_step_spec,
          action_spec=action_spec,
          info_spec=collect_policy.info_spec)
    if train_step_counter is None:
      train_step_counter = tf.compat.v1.train.get_or_create_global_step()
    self._train_step_counter = train_step_counter
    self._train_fn = common.function_in_tf1()(self._train)
    self._initialize_fn = common.function_in_tf1()(self._initialize)
    self._preprocess_sequence_fn = common.function_in_tf1()(
        self._preprocess_sequence)
    self._loss_fn = common.function_in_tf1()(self._loss)
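
The `training_data_spec` handling above relies on `getattr` with a default so that any field missing from the user-supplied spec falls back to the collect-time spec. A tiny standalone sketch of that fallback pattern, with placeholder strings standing in for real TensorSpecs (`PartialSpec` is hypothetical):

```python
import collections

# Hypothetical user spec that only overrides two of the fields.
PartialSpec = collections.namedtuple('PartialSpec', ['observation', 'reward'])
training_data_spec = PartialSpec(observation='custom_obs_spec',
                                 reward='custom_reward_spec')

# Fields present on the spec are used; missing ones fall back to the default.
observation_spec = getattr(training_data_spec, 'observation', 'collect_obs_spec')
discount_spec = getattr(training_data_spec, 'discount', 'collect_discount_spec')

assert observation_spec == 'custom_obs_spec'
assert discount_spec == 'collect_discount_spec'
```
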
Example #8
  def __init__(self,
               time_step_spec,
               action_spec,
               policy,
               collect_policy,
               train_sequence_length,
               num_outer_dims=2,
               debug_summaries=False,
               summarize_grads_and_vars=False,
               disable_summaries=False,
               train_step_counter=None):
    """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps. Provided by
        the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.Base` representing the Agent's current
        policy.
      collect_policy: An instance of `tf_policy.Base` representing the Agent's
        current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      disable_summaries: A bool; if true, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate all summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.

    Raises:
      ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
      ValueError: If `num_outer_dims` is not in [1, 2].
    """
    common.assert_members_are_not_overridden(base_cls=TFAgent, instance=self)
    if not isinstance(time_step_spec, ts.TimeStep):
      raise ValueError(
          "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
          .format(type(time_step_spec)))

    if num_outer_dims not in [1, 2]:
      raise ValueError("num_outer_dims must be in [1, 2].")

    self._time_step_spec = time_step_spec
    self._action_spec = action_spec
    self._policy = policy
    self._collect_policy = collect_policy
    self._train_sequence_length = train_sequence_length
    self._num_outer_dims = num_outer_dims
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._disable_summaries = disable_summaries
    if train_step_counter is None:
      train_step_counter = tf.compat.v1.train.get_or_create_global_step()
    self._train_step_counter = train_step_counter
    self._train_fn = common.function_in_tf1()(self._train)
    self._initialize_fn = common.function_in_tf1()(self._initialize)
Example #9
    def testSparseObs(self, batch_size, actions_from_reward_layer):
        obs_spec = {
            'global': {
                'sport': tensor_spec.TensorSpec((), tf.string)
            },
            'per_arm': {
                'name': tensor_spec.TensorSpec((3, ), tf.string),
                'fruit': tensor_spec.TensorSpec((3, ), tf.string)
            }
        }
        columns_a = tf.feature_column.indicator_column(
            tf.feature_column.categorical_column_with_vocabulary_list(
                'name', ['bob', 'george', 'wanda']))
        columns_b = tf.feature_column.indicator_column(
            tf.feature_column.categorical_column_with_vocabulary_list(
                'fruit', ['banana', 'kiwi', 'pear']))
        columns_c = tf.feature_column.indicator_column(
            tf.feature_column.categorical_column_with_vocabulary_list(
                'sport', ['bridge', 'chess', 'snooker']))

        dummy_net = arm_network.create_feed_forward_common_tower_network(
            obs_spec,
            global_layers=(3, 4, 5),
            arm_layers=(3, 2),
            common_layers=(4, 3),
            output_dim=self._encoding_dim,
            global_preprocessing_combiner=(
                tf.compat.v2.keras.layers.DenseFeatures([columns_c])),
            arm_preprocessing_combiner=tf.compat.v2.keras.layers.DenseFeatures(
                [columns_a, columns_b]))
        time_step_spec = ts.time_step_spec(obs_spec)
        reward_layer = get_per_arm_reward_layer(
            encoding_dim=self._encoding_dim)
        policy = neural_linucb_policy.NeuralLinUCBPolicy(
            dummy_net,
            self._encoding_dim,
            reward_layer,
            actions_from_reward_layer=tf.constant(actions_from_reward_layer,
                                                  dtype=tf.bool),
            cov_matrix=self._a[0:1],
            data_vector=self._b[0:1],
            num_samples=self._num_samples_per_arm[0:1],
            epsilon_greedy=0.0,
            time_step_spec=time_step_spec,
            accepts_per_arm_features=True,
            emit_policy_info=('predicted_rewards_mean', ))
        observations = {
            'global': {
                'sport': tf.constant(['snooker', 'chess'])
            },
            'per_arm': {
                'name':
                tf.constant([['george', 'george', 'george'],
                             ['bob', 'bob', 'bob']]),
                'fruit':
                tf.constant([['banana', 'banana', 'banana'],
                             ['kiwi', 'kiwi', 'kiwi']])
            }
        }

        time_step = ts.restart(observations, batch_size=2)
        action_fn = common.function_in_tf1()(policy.action)
        action_step = action_fn(time_step, seed=1)
        self.assertEqual(action_step.action.shape.as_list(), [2])
        self.assertEqual(action_step.action.dtype, tf.int32)
        # Initialize all variables
        self.evaluate([
            tf.compat.v1.global_variables_initializer(),
            tf.compat.v1.tables_initializer()
        ])
        action = self.evaluate(action_step.action)
        self.assertAllEqual(action.shape, [2])
        p_info = self.evaluate(action_step.info)
        self.assertAllEqual(p_info.predicted_rewards_mean.shape, [2, 3])
        self.assertAllEqual(p_info.chosen_arm_features['name'].shape, [2])
        self.assertAllEqual(p_info.chosen_arm_features['fruit'].shape, [2])
        first_action = action[0]
        first_arm_name_feature = observations[
            bandit_spec_utils.PER_ARM_FEATURE_KEY]['name'][0]
        self.assertAllEqual(p_info.chosen_arm_features['name'][0],
                            first_arm_name_feature[first_action])
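
The string features in this test are turned into one-hot vectors by the `DenseFeatures` preprocessing combiners. Below is a standalone sketch of what one of those combiners produces, reusing the same `tf.feature_column` calls as the test (an API that is deprecated in recent TF releases); the values shown assume indicator encoding in vocabulary order.

```python
import tensorflow as tf

sport_column = tf.feature_column.indicator_column(
    tf.feature_column.categorical_column_with_vocabulary_list(
        'sport', ['bridge', 'chess', 'snooker']))
combiner = tf.compat.v2.keras.layers.DenseFeatures([sport_column])

one_hot = combiner({'sport': tf.constant(['snooker', 'chess'])})
# one_hot has shape [2, 3]; each row is one-hot over the vocabulary, e.g.
# [[0., 0., 1.],
#  [0., 1., 0.]]
```
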
Example #10
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 policy_state_spec=(),
                 info_spec=(),
                 clip=True,
                 emit_log_probability=False,
                 automatic_state_reset=True,
                 observation_and_action_constraint_splitter=None,
                 name=None):
        """Initialization of Base class.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps. Usually
        provided by the user to the subclass.
      action_spec: A nest of BoundedTensorSpec representing the actions. Usually
        provided by the user to the subclass.
      policy_state_spec: A nest of TensorSpec representing the policy_state.
        Provided by the subclass, not directly by the user.
      info_spec: A nest of TensorSpec representing the policy info. Provided by
        the subclass, not directly by the user.
      clip: Whether to clip actions to spec before returning them.  Default
        True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped
        continuous actions for training.
      emit_log_probability: Emit log-probabilities of actions, if supported. If
        True, policy_step.info will have CommonFields.LOG_PROBABILITY set.
        Please consult utility methods provided in policy_step for setting and
        retrieving these. When working with custom policies, either provide a
        dictionary info_spec or a namedtuple with the field 'log_probability'.
      automatic_state_reset:  If `True`, then `get_initial_policy_state` is used
        to clear state in `action()` and `distribution()` for time steps
        where `time_step.is_first()`.
      observation_and_action_constraint_splitter: A function used to process
        observations with action constraints. These constraints can indicate,
        for example, a mask of valid/invalid actions for a given state of the
        environment. The function takes in a full observation and returns a
        tuple consisting of 1) the part of the observation intended as input to
        the network and 2) the constraint. An example
        `observation_and_action_constraint_splitter` could be as simple as:
        ```python
        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']
        ```
        *Note*: when using `observation_and_action_constraint_splitter`, make
          sure the provided `q_network` is compatible with the network-specific
          half of the output of the
          `observation_and_action_constraint_splitter`. In particular,
          `observation_and_action_constraint_splitter` will be called on the
          observation before passing to the network. If
          `observation_and_action_constraint_splitter` is None, action
          constraints are not applied.
      name: A name for this module. Defaults to the class name.
    """
        super(Base, self).__init__(name=name)
        common.tf_agents_gauge.get_cell('TFAPolicy').set(True)
        common.assert_members_are_not_overridden(base_cls=Base, instance=self)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise ValueError(
                'The `time_step_spec` must be an instance of `TimeStep`, but is `{}`.'
                .format(type(time_step_spec)))

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy_state_spec = policy_state_spec
        self._emit_log_probability = emit_log_probability
        if emit_log_probability:
            log_probability_spec = tensor_spec.BoundedTensorSpec(
                shape=(),
                dtype=tf.float32,
                maximum=0,
                minimum=-float('inf'),
                name='log_probability')
            log_probability_spec = tf.nest.map_structure(
                lambda _: log_probability_spec, action_spec)
            info_spec = policy_step.set_log_probability(
                info_spec, log_probability_spec)

        self._info_spec = info_spec
        self._setup_specs()
        self._clip = clip
        self._action_fn = common.function_in_tf1()(self._action)
        self._automatic_state_reset = automatic_state_reset
        self._observation_and_action_constraint_splitter = (
            observation_and_action_constraint_splitter)
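
For reference, here is the splitter from the docstring above written out as a runnable snippet, with an illustrative observation that carries both the network input and an action mask (the 0/1 mask convention is an assumption for illustration, not something this class enforces):

```python
import tensorflow as tf

def observation_and_action_constraint_splitter(observation):
    # Returns (network input, action constraint), per the docstring above.
    return observation['network_input'], observation['constraint']

observation = {
    'network_input': tf.constant([[0.1, 0.2, 0.3]]),
    'constraint': tf.constant([[1, 0, 1]]),  # e.g. 1 = valid, 0 = masked out
}
net_input, mask = observation_and_action_constraint_splitter(observation)
```
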
Example #11
    def testPerArmObservation(self, batch_size, actions_from_reward_layer):
        global_obs_dim = 7
        arm_obs_dim = 3
        obs_spec = bandit_spec_utils.create_per_arm_observation_spec(
            global_obs_dim,
            arm_obs_dim,
            self._num_actions,
            add_num_actions_feature=True)
        time_step_spec = ts.time_step_spec(obs_spec)
        dummy_net = arm_network.create_feed_forward_common_tower_network(
            obs_spec,
            global_layers=(3, 4, 5),
            arm_layers=(3, 2),
            common_layers=(4, 3),
            output_dim=self._encoding_dim)
        reward_layer = get_per_arm_reward_layer(
            encoding_dim=self._encoding_dim)

        policy = neural_linucb_policy.NeuralLinUCBPolicy(
            dummy_net,
            self._encoding_dim,
            reward_layer,
            actions_from_reward_layer=tf.constant(actions_from_reward_layer,
                                                  dtype=tf.bool),
            cov_matrix=self._a[0:1],
            data_vector=self._b[0:1],
            num_samples=self._num_samples_per_arm[0:1],
            epsilon_greedy=0.0,
            time_step_spec=time_step_spec,
            accepts_per_arm_features=True,
            emit_policy_info=('predicted_rewards_mean',
                              'predicted_rewards_optimistic'))

        current_time_step = self._per_arm_time_step_batch(
            batch_size=batch_size,
            global_obs_dim=global_obs_dim,
            arm_obs_dim=arm_obs_dim)
        action_step = policy.action(current_time_step)
        self.assertEqual(action_step.action.dtype, tf.int32)
        self.evaluate(tf.compat.v1.global_variables_initializer())
        action_fn = common.function_in_tf1()(policy.action)
        action_step = action_fn(current_time_step)

        input_observation = current_time_step.observation
        encoded_observation, _ = dummy_net(input_observation)

        if actions_from_reward_layer:
            predicted_rewards_from_reward_layer = reward_layer(
                encoded_observation)
            predicted_rewards_expected = self.evaluate(
                predicted_rewards_from_reward_layer).reshape(
                    (-1, self._num_actions))
        else:
            observation_numpy = self.evaluate(encoded_observation)
            predicted_rewards_expected = (
                self._get_predicted_rewards_from_per_arm_linucb(
                    observation_numpy, batch_size))

        p_info = self.evaluate(action_step.info)
        self.assertEqual(p_info.predicted_rewards_mean.dtype, np.float32)
        self.assertAllClose(p_info.predicted_rewards_mean,
                            predicted_rewards_expected)
        self.assertAllGreaterEqual(
            p_info.predicted_rewards_optimistic - predicted_rewards_expected,
            0)
Example #12
  def __init__(self,

               # counter
               train_step_counter,

               # specs
               time_step_spec,
               action_spec,

               # networks
               critic_network,
               actor_network,
               model_network,
               compressor_network,

               # optimizers
               actor_optimizer,
               critic_optimizer,
               alpha_optimizer,
               model_optimizer,

               # target update
               target_update_tau=1.0,
               target_update_period=1,

               # inputs and stop gradients
               critic_input='state',
               actor_input='state',
               critic_input_stop_gradient=True,
               actor_input_stop_gradient=False,

               # model stuff
               model_batch_size=256,  # will round to the nearest full trajectory
               ac_batch_size=128,

               # other
               episodes_per_trial=1,
               num_tasks_per_train=1,
               num_batches_per_sampled_trials=1,
               td_errors_loss_fn=tf.math.squared_difference,
               gamma=1.0,
               reward_scale_factor=1.0,
               task_reward_dim=None,
               initial_log_alpha=0.0,
               target_entropy=None,
               gradient_clipping=None,
               control_timestep=None,
               num_images_per_summary=1,

               offline_ratio=None,
               override_reward_func=None,
               ):

    tf.Module.__init__(self)
    self.override_reward_func = override_reward_func
    self.offline_ratio = offline_ratio

    ################
    # critic
    ################
    # networks
    self._critic_network1 = critic_network
    self._critic_network2 = critic_network.copy(name='CriticNetwork2')
    self._target_critic_network1 = critic_network.copy(name='TargetCriticNetwork1')
    self._target_critic_network2 = critic_network.copy(name='TargetCriticNetwork2')
    # update the target networks
    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._update_target = self._get_target_updater(tau=self._target_update_tau, period=self._target_update_period)

    ################
    # model
    ################
    self._model_network = model_network
    self.model_input = self._model_network.model_input

    ################
    # compressor
    ################
    self._compressor_network = compressor_network

    ################
    # actor
    ################
    self._actor_network = actor_network

    ################
    # policies
    ################

    self.condition_on_full_latent_dist = (
        actor_input == "latentDistribution" and critic_input == "latentDistribution")

    # both policies below share the same actor network
    # but they process latents (to give to actor network) in potentially different ways

    # used for eval
    which_posterior='first'
    if self._model_network.sparse_reward_inputs:
      which_rew_input='sparse'
    else:
      which_rew_input='dense'

    policy = MeldPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        model_network=self._model_network,
        actor_input=actor_input,
        which_posterior=which_posterior,
        which_rew_input=which_rew_input,
        )

    # used for collecting data during training

    # overwrite if specified (eg for double agent)
    which_posterior='first'
    if self._model_network.sparse_reward_inputs:
      which_rew_input='sparse'
    else:
      which_rew_input='dense'

    collect_policy = MeldPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      model_network=self._model_network,
      actor_input=actor_input,
      which_posterior=which_posterior,
      which_rew_input=which_rew_input,
      )


    ################
    # more vars
    ################
    self.num_batches_per_sampled_trials = num_batches_per_sampled_trials
    self.episodes_per_trial = episodes_per_trial
    self._task_reward_dim = task_reward_dim
    self._log_alpha = common.create_variable(
        'initial_log_alpha',
        initial_value=initial_log_alpha,
        dtype=tf.float32,
        trainable=True)

    # If target_entropy was not passed, set it to negative of the total number
    # of action dimensions.
    if target_entropy is None:
      flat_action_spec = tf.nest.flatten(action_spec)
      target_entropy = -np.sum([
        np.product(single_spec.shape.as_list())
        for single_spec in flat_action_spec
      ])

    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._alpha_optimizer = alpha_optimizer
    self._model_optimizer = model_optimizer
    self._td_errors_loss_fn = td_errors_loss_fn
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._target_entropy = target_entropy
    self._gradient_clipping = gradient_clipping

    self._critic_input = critic_input
    self._actor_input = actor_input
    self._critic_input_stop_gradient = critic_input_stop_gradient
    self._actor_input_stop_gradient = actor_input_stop_gradient
    self._model_batch_size = model_batch_size
    self._ac_batch_size = ac_batch_size
    self._control_timestep = control_timestep
    self._num_images_per_summary = num_images_per_summary
    self._actor_time_step_spec = time_step_spec._replace(observation=actor_network.input_tensor_spec)
    self._num_tasks_per_train = num_tasks_per_train

    ################
    # init tf agent
    ################

    super(MeldAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=collect_policy,  # used to set self.step_spec
        train_sequence_length=None,  # train function can accept experience of any length T (i.e., [B, T, ...])
        train_step_counter=train_step_counter)

    self._train_model_fn = common.function_in_tf1()(self._train_model)
    self._train_ac_fn = common.function_in_tf1()(self._train_ac)
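
When `target_entropy` is not passed, the constructor above defaults it to minus the total number of action dimensions. A small worked example of that default (the 6-dimensional action spec is just an illustration):

```python
import numpy as np
import tensorflow as tf
from tf_agents.specs import tensor_spec

action_spec = tensor_spec.BoundedTensorSpec(
    [6], tf.float32, minimum=-1.0, maximum=1.0)
flat_action_spec = tf.nest.flatten(action_spec)
target_entropy = -np.sum(
    [np.prod(spec.shape.as_list()) for spec in flat_action_spec])
# target_entropy == -6
```
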
Example #13
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 policy: tf_policy.TFPolicy,
                 collect_policy: tf_policy.TFPolicy,
                 train_sequence_length: Optional[int],
                 num_outer_dims: int = 2,
                 training_data_spec: Optional[types.NestedTensorSpec] = None,
                 train_argspec: Optional[Dict[Text,
                                              types.NestedTensorSpec]] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 enable_summaries: bool = True,
                 train_step_counter: Optional[tf.Variable] = None,
                 validate_args: bool = True):
        """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A nest of tf.TypeSpec representing the time_steps.
        Provided by the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current policy.
      collect_policy: An instance of `tf_policy.TFPolicy` representing the
        Agent's current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      training_data_spec: A nest of TensorSpec specifying the structure of data
        the train() function expects. If None, defaults to the trajectory_spec
        of the collect_policy.
      train_argspec: (Optional) Describes additional supported arguments
        to the `train` call.  This must be a `dict` mapping strings to nests
        of specs.  Overriding the `experience` arg is not supported.

        Some algorithms require additional arguments to the `train()` call, and
        while TF-Agents encourages most of these to be provided in the
        `policy_info` / `info` field of `experience`, sometimes the extra
        information doesn't fit well, i.e., when it doesn't come from the
        policy.

        **NOTE** kwargs will not have their outer dimensions validated.
        In particular, `train_sequence_length` is ignored for these inputs,
        and they may have any, or inconsistent, batch/time dimensions; only
        their inner shape dimensions are checked against `train_argspec`.

        Below is an example:

        ```python
        class MyAgent(TFAgent):
          def __init__(self, counterfactual_training, ...):
             collect_policy = ...
             train_argspec = None
             if counterfactual_training:
               train_argspec = dict(
                  counterfactual=collect_policy.trajectory_spec)
             super(...).__init__(
               ...
               train_argspec=train_argspec)

        my_agent = MyAgent(...)

        for ...:
          experience, counterfactual = next(experience_and_counterfactual_iter)
          loss_info = my_agent.train(experience, counterfactual=counterfactual)
        ```
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      validate_args: Python bool.  Whether to verify inputs to, and outputs of,
        functions like `train` and `preprocess_sequence` against spec
        structures, dtypes, and shapes.

        Research code may prefer to set this value to `False` to allow iterating
        on input and output structures without being hamstrung by overly
        rigid checking (at the cost of harder-to-debug errors).

        See also `TFPolicy.validate_args`.

    Raises:
      TypeError: If `validate_args is True` and `train_argspec` is not a `dict`.
      ValueError: If `validate_args is True` and `train_argspec` has the keys
        `experience` or `weights`.
      TypeError: If `validate_args is True` and any leaf nodes in
        `train_argspec` values are not subclasses of `tf.TypeSpec`.
      TypeError: If `time_step_spec` is not an instance of `ts.TimeStep`.
      ValueError: If `num_outer_dims` is not in `[1, 2]`.
    """
        if validate_args:

            def _each_isinstance(spec, spec_types):
                """Checks if each element of `spec` is instance of `spec_types`."""
                return all(
                    [isinstance(s, spec_types) for s in tf.nest.flatten(spec)])

            if not _each_isinstance(time_step_spec, tf.TypeSpec):
                raise TypeError(
                    "time_step_spec has to contain TypeSpec (TensorSpec, "
                    "SparseTensorSpec, etc) objects, but received: {}".format(
                        time_step_spec))

            if not _each_isinstance(action_spec,
                                    tensor_spec.BoundedTensorSpec):
                raise TypeError(
                    "action_spec has to contain BoundedTensorSpec objects, but received: "
                    "{}".format(action_spec))

        common.check_tf1_allowed()
        common.tf_agents_gauge.get_cell("TFAgent").set(True)
        common.tf_agents_gauge.get_cell(str(type(self))).set(True)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise TypeError(
                "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
                .format(type(time_step_spec)))

        if num_outer_dims not in [1, 2]:
            raise ValueError("num_outer_dims must be in [1, 2].")

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy = policy
        self._collect_policy = collect_policy
        self._train_sequence_length = train_sequence_length
        self._num_outer_dims = num_outer_dims
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._enable_summaries = enable_summaries
        self._training_data_spec = training_data_spec
        self._validate_args = validate_args
        # Data context for data collected directly from the collect policy.
        self._collect_data_context = data_converter.DataContext(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            info_spec=collect_policy.info_spec)
        # Data context for data passed to train().  May be different if
        # training_data_spec is provided.
        if training_data_spec is not None:
            # training_data_spec can be anything; so build a data_context
            # via best-effort with fall-backs to the collect data spec.
            training_discount_spec = getattr(training_data_spec, "discount",
                                             time_step_spec.discount)
            training_observation_spec = getattr(training_data_spec,
                                                "observation",
                                                time_step_spec.observation)
            training_reward_spec = getattr(training_data_spec, "reward",
                                           time_step_spec.reward)
            training_step_type_spec = getattr(training_data_spec, "step_type",
                                              time_step_spec.step_type)
            training_policy_info_spec = getattr(training_data_spec,
                                                "policy_info",
                                                collect_policy.info_spec)
            training_action_spec = getattr(training_data_spec, "action",
                                           action_spec)
            self._data_context = data_converter.DataContext(
                time_step_spec=ts.TimeStep(
                    discount=training_discount_spec,
                    observation=training_observation_spec,
                    reward=training_reward_spec,
                    step_type=training_step_type_spec),
                action_spec=training_action_spec,
                info_spec=training_policy_info_spec)
        else:
            self._data_context = data_converter.DataContext(
                time_step_spec=time_step_spec,
                action_spec=action_spec,
                info_spec=collect_policy.info_spec)
        if train_argspec is None:
            train_argspec = {}
        elif validate_args:
            if not isinstance(train_argspec, dict):
                raise TypeError(
                    "train_argspec must be a dict, but saw: {}".format(
                        train_argspec))
            if "weights" in train_argspec or "experience" in train_argspec:
                raise ValueError(
                    "train_argspec must not override 'weights' or "
                    "'experience' keys, but saw: {}".format(train_argspec))
            if not all(
                    isinstance(x, tf.TypeSpec)
                    for x in tf.nest.flatten(train_argspec)):
                raise TypeError(
                    "train_argspec contains non-TensorSpec objects: {}".format(
                        train_argspec))
        train_argspec = dict(train_argspec)  # Create a local copy.
        self._train_argspec = train_argspec
        if train_step_counter is None:
            train_step_counter = tf.compat.v1.train.get_or_create_global_step()
        self._train_step_counter = train_step_counter
        self._train_fn = common.function_in_tf1()(self._train)
        self._initialize_fn = common.function_in_tf1()(self._initialize)
        self._preprocess_sequence_fn = common.function_in_tf1()(
            self._preprocess_sequence)
        self._loss_fn = common.function_in_tf1()(self._loss)
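
A hedged sketch of a subclass calling the constructor documented above; `NoOpAgent` is hypothetical, trains nothing, and only illustrates the required arguments plus the `_initialize`/`_train` overrides that the base class wraps with `common.function_in_tf1()`:

import tensorflow as tf
from tf_agents.agents import tf_agent
from tf_agents.policies import random_tf_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts


class NoOpAgent(tf_agent.TFAgent):
  """Does nothing on train(); exists only to show the constructor contract."""

  def __init__(self, time_step_spec, action_spec):
    policy = random_tf_policy.RandomTFPolicy(time_step_spec, action_spec)
    super(NoOpAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=policy,
        # None: train() accepts experience with any time dimension T.
        train_sequence_length=None)

  def _initialize(self):
    pass

  def _train(self, experience, weights=None):
    return tf_agent.LossInfo(loss=tf.constant(0.0), extra=())


observation_spec = tensor_spec.TensorSpec([4], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, minimum=0, maximum=1)
agent = NoOpAgent(ts.time_step_spec(observation_spec), action_spec)
agent.initialize()
print(agent.collect_data_spec)
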
Example #14
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 policy,
                 collect_policy,
                 train_sequence_length,
                 update_period=None,
                 debug_summaries=False,
                 enable_functions=True,
                 summarize_grads_and_vars=False,
                 train_step_counter=None):
        """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps. Provided by
        the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.Base` representing the Agent's current
        policy.
      collect_policy: An instance of `tf_policy.Base` representing the Agent's
        current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      update_period: Optional update period; stored on the agent as
        `self.update_period` for use by subclasses.
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      enable_functions: A bool; if true, the public `train` and `initialize`
        calls dispatch through the `common.function_in_tf1()`-wrapped
        `_train` and `_initialize`.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
    """
        common.assert_members_are_not_overridden(base_cls=TFAgent,
                                                 instance=self)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise ValueError(
                "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
                .format(type(time_step_spec)))

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy = policy
        # self._py_policy = py_tf_policy.PyTFPolicy(policy)
        self._collect_policy = collect_policy
        # self._collect_py_policy = py_tf_policy.PyTFPolicy(collect_policy)
        self._train_sequence_length = train_sequence_length
        self.update_period = update_period
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        if train_step_counter is None:
            train_step_counter = tf.compat.v1.train.get_or_create_global_step()
        self._train_step_counter = train_step_counter
        self._train_fn = common.function_in_tf1()(self._train)
        self._initialize_fn = common.function_in_tf1()(self._initialize)

        self._enable_functions = enable_functions
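
Both constructors above take `train_sequence_length`; here is a hedged toy sketch of what a fixed value of 2 (the non-RNN DQN case from the docstring) means for the data pipeline, using a bare `TensorSpec` instead of a full trajectory spec:

import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec

data_spec = tensor_spec.TensorSpec([3], tf.float32, name='obs')
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec, batch_size=1, max_length=100)

# Fill the buffer with a few single-step items (outer dim == batch_size == 1).
for i in range(10):
  replay_buffer.add_batch(tf.fill([1, 3], float(i)))

# num_steps=2 yields tensors shaped [sample_batch_size, 2, ...], the
# [B, T, ...] layout an agent built with train_sequence_length=2 expects.
dataset = replay_buffer.as_dataset(sample_batch_size=4, num_steps=2)
experience, _ = next(iter(dataset))
print(experience.shape)  # (4, 2, 3)
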
Example #15
def train_eval(
    root_dir,
    load_root_dir=None,
    env_load_fn=None,
    gym_env_wrappers=(),
    monitor=False,
    env_name=None,
    agent_class=None,
    initial_collect_driver_class=None,
    collect_driver_class=None,
    online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver,
    num_global_steps=1000000,
    rb_size=None,
    train_steps_per_iteration=1,
    train_metrics=None,
    eval_metrics=None,
    train_metrics_callback=None,
    # SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    sc_rb_size=None,
    target_safety=None,
    train_sc_steps=10,
    train_sc_interval=1000,
    online_critic=False,
    n_envs=None,
    finetune_sc=False,
    pretraining=True,
    lambda_schedule_nsteps=0,
    lambda_initial=0.,
    lambda_final=1.,
    kstep_fail=0,
    # Ensemble Critic training args
    num_critics=None,
    critic_learning_rate=3e-4,
    # Wcpg Critic args
    critic_preprocessing_layer_size=256,
    # Params for train
    batch_size=256,
    # Params for eval
    run_eval=False,
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    keep_rb_checkpoint=False,
    log_interval=1000,
    summary_interval=1000,
    monitor_interval=5000,
    summaries_flush_secs=10,
    early_termination_fn=None,
    debug_summaries=False,
    seed=None,
    eager_debug=False,
    env_metric_factories=None,
    wandb=False):  # pylint: disable=unused-argument

  """train and eval script for SQRL."""
  if isinstance(agent_class, str):
    assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(agent_class)
    agent_class = ALGOS.get(agent_class)
  n_envs = n_envs or num_eval_episodes
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  # =====================================================================#
  #  Setup summary metrics, file writers, and create env                 #
  # =====================================================================#
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
    train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  train_metrics = train_metrics or []
  eval_metrics = eval_metrics or []

  updating_sc = online_critic and (not load_root_dir or finetune_sc)
  logging.debug('updating safety critic: %s', updating_sc)

  if seed:
    tf.compat.v1.set_random_seed(seed)

  if agent_class in SAFETY_AGENTS:
    if online_critic:
      sc_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
          [lambda: env_load_fn(env_name)] * n_envs
        ))
      if seed:
        seeds = [seed * n_envs + i for i in range(n_envs)]
        try:
          sc_tf_env.pyenv.seed(seeds)
        except Exception:  # Seeding is best-effort; not all envs support it.
          pass

  if run_eval:
    eval_dir = os.path.join(root_dir, 'eval')
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
                     tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=n_envs),
                     tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes, batch_size=n_envs),
                   ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
      parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * n_envs
      ))
    if seed:
      try:
        for i, pyenv in enumerate(eval_tf_env.pyenv.envs):
          pyenv.seed(seed * n_envs + i)
      except Exception:  # Seeding is best-effort; not all envs support it.
        pass
  elif 'Drunk' in env_name:
    # Just visualizes trajectories in drunk spider environment
    eval_tf_env = tf_py_environment.TFPyEnvironment(
      env_load_fn(env_name))
  else:
    eval_tf_env = None

  if monitor:
    vid_path = os.path.join(root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  global_step = tf.compat.v1.train.get_or_create_global_step()

  with tf.summary.record_if(
          lambda: tf.math.equal(global_step % summary_interval, 0)):
    py_env = env_load_fn(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    if seed:
      try:
        for i, pyenv in enumerate(tf_env.pyenv.envs):
          pyenv.seed(seed * n_envs + i)
      except Exception:  # Seeding is best-effort; not all envs support it.
        pass
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    logging.debug('obs spec: %s', observation_spec)
    logging.debug('action spec: %s', action_spec)

    # =====================================================================#
    #  Setup agent class                                                   #
    # =====================================================================#

    if agent_class == wcpg_agent.WcpgAgent:
      alpha_spec = tensor_spec.BoundedTensorSpec(shape=(1,), dtype=tf.float32, minimum=0., maximum=1.,
                                                 name='alpha')
      input_tensor_spec = (observation_spec, action_spec, alpha_spec)
      critic_net = agents.DistributionalCriticNetwork(
        input_tensor_spec, preprocessing_layer_size=critic_preprocessing_layer_size,
        joint_fc_layer_params=critic_joint_fc_layers)
      actor_net = agents.WcpgActorNetwork((observation_spec, alpha_spec), action_spec)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)
      critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
      logging.debug('Making SQRL agent')
      if lambda_schedule_nsteps > 0:
        lambda_update_every_nsteps = num_global_steps // lambda_schedule_nsteps
        step_size = (lambda_final - lambda_initial) / lambda_update_every_nsteps
        lambda_scheduler = lambda lam: common.periodically(
          body=lambda: tf.group(lam.assign(lam + step_size)),
          period=lambda_update_every_nsteps)
      else:
        lambda_scheduler = None
      safety_critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)
      safety_threshold = target_safety
      thresholds = [safety_threshold, 0.5]
      sc_metrics = [tf.keras.metrics.AUC(name='safety_critic_auc'),
                    tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                                   thresholds=thresholds),
                    tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                                    thresholds=thresholds),
                    tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                                   thresholds=thresholds),
                    tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                                    thresholds=thresholds),
                    tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                                    threshold=0.5)]
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        safety_critic_network=safety_critic_net,
        train_step_counter=global_step,
        debug_summaries=debug_summaries,
        safety_pretraining=pretraining,
        train_critic_online=online_critic,
        initial_log_lambda=lambda_initial,
        log_lambda=(lambda_scheduler is None),
        lambda_scheduler=lambda_scheduler)
    elif agent_class is ensemble_sac_agent.EnsembleSacAgent:
      critic_nets, critic_optimizers = [critic_net], [tf.keras.optimizers.Adam(critic_learning_rate)]
      for _ in range(num_critics - 1):
        critic_nets.append(agents.CriticNetwork((observation_spec, action_spec),
                                                joint_fc_layer_params=critic_joint_fc_layers))
        critic_optimizers.append(tf.keras.optimizers.Adam(critic_learning_rate))
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_networks=critic_nets,
        critic_optimizers=critic_optimizers,
        debug_summaries=debug_summaries
      )
    else:  # agent is either SacAgent or WcpgAgent
      logging.debug('critic input_tensor_spec: %s', critic_net.input_tensor_spec)
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        train_step_counter=global_step,
        debug_summaries=debug_summaries)

    tf_agent.initialize()

    # =====================================================================#
    #  Setup replay buffer                                                 #
    # =====================================================================#
    collect_data_spec = tf_agent.collect_data_spec

    logging.debug('Allocating replay buffer ...')
    # Add to replay buffer and other agent specific observers.
    rb_size = rb_size or 1000000
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec,
      batch_size=1,
      max_length=rb_size)

    logging.debug('RB capacity: %i', replay_buffer.capacity)
    logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec)

    if agent_class in SAFETY_AGENTS:
      sc_rb_size = sc_rb_size or num_eval_episodes * 500
      sc_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=sc_rb_size,
        dataset_window_shift=1)

    num_episodes = tf_metrics.NumberOfEpisodes()
    num_env_steps = tf_metrics.EnvironmentSteps()
    return_metric = tf_metrics.AverageReturnMetric(
      buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
    train_metrics = [
                      num_episodes, num_env_steps,
                      return_metric,
                      tf_metrics.AverageEpisodeLengthMetric(
                        buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
                    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    if 'Minitaur' in env_name and not pretraining:
      goal_vel = gin.query_parameter("%GOAL_VELOCITY")
      early_termination_fn = train_utils.MinitaurTerminationFn(
        speed_metric=train_metrics[-2], total_falls_metric=train_metrics[-3],
        env_steps_metric=num_env_steps, goal_speed=goal_vel)

    if env_metric_factories:
      for env_metric in env_metric_factories:
        train_metrics.append(tf_py_metric.TFPyMetric(env_metric(tf_env.pyenv.envs)))
        if run_eval:
          eval_metrics.append(env_metric([env for env in
                                          eval_tf_env.pyenv._envs]))

    # =====================================================================#
    #  Setup collect policies                                              #
    # =====================================================================#
    if not online_critic:
      eval_policy = tf_agent.policy
      collect_policy = tf_agent.collect_policy
      if not pretraining and agent_class in SAFETY_AGENTS:
        collect_policy = tf_agent.safe_policy
    else:
      eval_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy
      collect_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy
      online_collect_policy = tf_agent.safe_policy  # if pretraining else tf_agent.collect_policy
      if pretraining:
        online_collect_policy._training = False

    if not load_root_dir:
      initial_collect_policy = random_tf_policy.RandomTFPolicy(time_step_spec, action_spec)
    else:
      initial_collect_policy = collect_policy
    if agent_class == wcpg_agent.WcpgAgent:
      initial_collect_policy = agents.WcpgPolicyWrapper(initial_collect_policy)

    # =====================================================================#
    #  Setup Checkpointing                                                 #
    # =====================================================================#
    train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir,
      agent=tf_agent,
      global_step=global_step,
      metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=eval_policy,
      global_step=global_step)

    rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
    rb_checkpointer = common.Checkpointer(
      ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer)

    if online_critic:
      online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer')
      online_rb_checkpointer = common.Checkpointer(
        ckpt_dir=online_rb_ckpt_dir,
        max_to_keep=1,
        replay_buffer=sc_buffer)

    # loads agent, replay buffer, and online sc/buffer if online_critic
    if load_root_dir:
      load_root_dir = os.path.expanduser(load_root_dir)
      load_train_dir = os.path.join(load_root_dir, 'train')
      misc.load_agent_ckpt(load_train_dir, tf_agent)
      if len(os.listdir(os.path.join(load_train_dir, 'replay_buffer'))) > 1:
        load_rb_ckpt_dir = os.path.join(load_train_dir, 'replay_buffer')
        misc.load_rb_ckpt(load_rb_ckpt_dir, replay_buffer)
      if online_critic:
        load_online_sc_ckpt_dir = os.path.join(load_root_dir, 'sc')
        load_online_rb_ckpt_dir = os.path.join(load_train_dir,
                                               'online_replay_buffer')
        if osp.exists(load_online_rb_ckpt_dir):
          misc.load_rb_ckpt(load_online_rb_ckpt_dir, sc_buffer)
        if osp.exists(load_online_sc_ckpt_dir):
          misc.load_safety_critic_ckpt(load_online_sc_ckpt_dir,
                                       safety_critic_net)
      elif agent_class in SAFETY_AGENTS:
        offline_run = sorted(os.listdir(os.path.join(load_train_dir, 'offline')))[-1]
        load_sc_ckpt_dir = os.path.join(load_train_dir, 'offline',
                                        offline_run, 'safety_critic')
        if osp.exists(load_sc_ckpt_dir):
          sc_net_off = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=(512, 512),
            name='SafetyCriticOffline')
          sc_net_off.create_variables()
          target_sc_net_off = common.maybe_copy_target_network_with_checks(
            sc_net_off, None, 'TargetSafetyCriticNetwork')
          sc_optimizer = tf.keras.optimizers.Adam(critic_learning_rate)
          _ = misc.load_safety_critic_ckpt(
            load_sc_ckpt_dir, safety_critic_net=sc_net_off,
            target_safety_critic=target_sc_net_off,
            optimizer=sc_optimizer)
          tf_agent._safety_critic_network = sc_net_off
          tf_agent._target_safety_critic_network = target_sc_net_off
          tf_agent._safety_critic_optimizer = sc_optimizer
    else:
      train_checkpointer.initialize_or_restore()
      rb_checkpointer.initialize_or_restore()
      if online_critic:
        online_rb_checkpointer.initialize_or_restore()

    if agent_class in SAFETY_AGENTS:
      sc_dir = os.path.join(root_dir, 'sc')
      safety_critic_checkpointer = common.Checkpointer(
        ckpt_dir=sc_dir,
        safety_critic=tf_agent._safety_critic_network,
        # pylint: disable=protected-access
        target_safety_critic=tf_agent._target_safety_critic_network,
        optimizer=tf_agent._safety_critic_optimizer,
        global_step=global_step)

      if not (load_root_dir and not online_critic):
        safety_critic_checkpointer.initialize_or_restore()

    agent_observers = [replay_buffer.add_batch] + train_metrics
    collect_driver = collect_driver_class(
      tf_env, collect_policy, observers=agent_observers)
    collect_driver.run = common.function_in_tf1()(collect_driver.run)

    if online_critic:
      logging.debug('online driver class: %s', online_driver_class)
      online_agent_observers = [num_episodes, num_env_steps,
                                sc_buffer.add_batch]
      online_driver = online_driver_class(
        sc_tf_env, online_collect_policy, observers=online_agent_observers,
        num_episodes=num_eval_episodes)
      online_driver.run = common.function_in_tf1()(online_driver.run)

    if eager_debug:
      tf.config.experimental_run_functions_eagerly(True)
    else:
      config_saver = gin.tf.GinConfigSaverHook(train_dir, summarize_config=True)
      tf.function(config_saver.after_create_session)()

    if global_step == 0:
      logging.info('Performing initial collection ...')
      init_collect_observers = agent_observers
      if agent_class in SAFETY_AGENTS:
        init_collect_observers += [sc_buffer.add_batch]
      initial_collect_driver_class(
        tf_env,
        initial_collect_policy,
        observers=init_collect_observers).run()
      last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
      logging.info('Data saved after initial collection: %d steps', last_id)
      if agent_class in SAFETY_AGENTS:
        last_id = sc_buffer._get_last_id()  # pylint: disable=protected-access
        logging.debug('Data saved in sc_buffer after initial collection: %d steps', last_id)

    if run_eval:
      results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='EvalMetrics',
      )
      if train_metrics_callback is not None:
        train_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    train_step = train_utils.get_train_step(tf_agent, replay_buffer, batch_size)

    if agent_class in SAFETY_AGENTS:
      critic_train_step = train_utils.get_critic_train_step(
        tf_agent, replay_buffer, sc_buffer, batch_size=batch_size,
        updating_sc=updating_sc, metrics=sc_metrics)

    if early_termination_fn is None:
      early_termination_fn = lambda: False

    loss_diverged = False
    # How many consecutive steps was loss diverged for.
    loss_divergence_counter = 0
    mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')

    if agent_class in SAFETY_AGENTS:
      resample_counter = collect_policy._resample_counter
      mean_resample_ac = tf.keras.metrics.Mean(name='mean_unsafe_ac_freq')
      sc_metrics.append(mean_resample_ac)

      if online_critic:
        logging.debug('starting safety critic pretraining')
        # don't fine-tune safety critic
        if global_step.numpy() == 0:
          for _ in range(train_sc_steps):
            sc_loss, lambda_loss = critic_train_step()
          critic_results = [('sc_loss', sc_loss.numpy()), ('lambda_loss', lambda_loss.numpy())]
          for critic_metric in sc_metrics:
            res = critic_metric.result().numpy()
            if not res.shape:
              critic_results.append((critic_metric.name, res))
            else:
              for r, thresh in zip(res, thresholds):
                name = '_'.join([critic_metric.name, str(thresh)])
                critic_results.append((name, r))
            critic_metric.reset_states()
          if train_metrics_callback:
            train_metrics_callback(collections.OrderedDict(critic_results),
                                   step=global_step.numpy())

    logging.debug('Starting main train loop...')
    curr_ep = []
    global_step_val = global_step.numpy()
    while global_step_val <= num_global_steps and not early_termination_fn():
      start_time = time.time()

      # MEASURE ACTION RESAMPLING FREQUENCY
      if agent_class in SAFETY_AGENTS:
        if pretraining and global_step_val == num_global_steps // 2:
          if online_critic:
            online_collect_policy._training = True
          collect_policy._training = True
        if online_critic or collect_policy._training:
          mean_resample_ac(resample_counter.result())
          resample_counter.reset()
          if time_step is None or time_step.is_last():
            resample_ac_freq = mean_resample_ac.result()
            mean_resample_ac.reset_states()
            tf.compat.v2.summary.scalar(
              name='resample_ac_freq', data=resample_ac_freq, step=global_step)

      # RUN COLLECTION
      time_step, policy_state = collect_driver.run(
        time_step=time_step,
        policy_state=policy_state,
      )

      # get last step taken by step_driver
      traj = replay_buffer._data_table.read(replay_buffer._get_last_id() %
                                            replay_buffer._capacity)
      curr_ep.append(traj)

      if time_step.is_last():
        if agent_class in SAFETY_AGENTS:
          if time_step.observation['task_agn_rew']:
            if kstep_fail:
              # applies task agn rew. over last k steps
              for traj in curr_ep[-kstep_fail:]:
                traj.observation['task_agn_rew'] = 1.
                sc_buffer.add_batch(traj)
            else:
              for traj in curr_ep:
                sc_buffer.add_batch(traj)
        curr_ep = []
        if agent_class == wcpg_agent.WcpgAgent:
          collect_policy._alpha = None  # reset WCPG alpha

      if (global_step_val + 1) % log_interval == 0:
        logging.debug('policy eval: %4.2f sec', time.time() - start_time)

      # PERFORMS TRAIN STEP ON ALGORITHM (OFF-POLICY)
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
        mean_train_loss(train_loss.loss)

      current_step = global_step.numpy()
      total_loss = mean_train_loss.result()
      mean_train_loss.reset_states()

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(
          collections.OrderedDict([(k, v.numpy()) for k, v in
                                   train_loss.extra._asdict().items()]),
          step=current_step)
        train_metrics_callback(
          {'train_loss': total_loss.numpy()}, step=current_step)

      # TRAIN AND/OR EVAL SAFETY CRITIC
      if agent_class in SAFETY_AGENTS and current_step % train_sc_interval == 0:
        if online_critic:
          batch_time_step = sc_tf_env.reset()

          # run online critic training collect & update
          batch_policy_state = online_collect_policy.get_initial_state(
            sc_tf_env.batch_size)
          online_driver.run(time_step=batch_time_step,
                            policy_state=batch_policy_state)
        for _ in range(train_sc_steps):
          sc_loss, lambda_loss = critic_train_step()
        # log safety_critic loss results
        critic_results = [('sc_loss', sc_loss.numpy()),
                          ('lambda_loss', lambda_loss.numpy())]
        metric_utils.log_metrics(sc_metrics)
        for critic_metric in sc_metrics:
          res = critic_metric.result().numpy()
          if not res.shape:
            critic_results.append((critic_metric.name, res))
          else:
            for r, thresh in zip(res, thresholds):
              name = '_'.join([critic_metric.name, str(thresh)])
              critic_results.append((name, r))
          critic_metric.reset_states()
        if train_metrics_callback and current_step % summary_interval == 0:
          train_metrics_callback(collections.OrderedDict(critic_results),
                                 step=current_step)

      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
              total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          loss_diverged = True
          logging.info('Loss diverged, critic_loss: %s, actor_loss: %s',
                       train_loss.extra.critic_loss,
                       train_loss.extra.actor_loss)
          break
      else:
        loss_divergence_counter = 0

      time_acc += time.time() - start_time

      # LOGGING AND METRICS
      if current_step % log_interval == 0:
        metric_utils.log_metrics(train_metrics)
        logging.info('step = %d, loss = %f', current_step, total_loss)
        steps_per_sec = (current_step - timed_at_step) / time_acc
        logging.info('%4.2f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = current_step
        time_acc = 0

      train_results = []

      for metric in train_metrics[2:]:
        if isinstance(metric, (metrics.AverageEarlyFailureMetric,
                               metrics.AverageFallenMetric,
                               metrics.AverageSuccessMetric)):
          # Plot failure as a fn of return
          metric.tf_summaries(
            train_step=global_step, step_metrics=[num_env_steps, num_episodes,
                                                  return_metric])
        else:
          metric.tf_summaries(
            train_step=global_step, step_metrics=[num_env_steps, num_episodes])
        train_results.append((metric.name, metric.result().numpy()))

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(collections.OrderedDict(train_results),
                               step=global_step.numpy())

      if current_step % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=current_step)

      if current_step % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=current_step)
        if agent_class in SAFETY_AGENTS:
          safety_critic_checkpointer.save(global_step=current_step)
          if online_critic:
            online_rb_checkpointer.save(global_step=current_step)

      if rb_checkpoint_interval and current_step % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=current_step)

      if wandb and current_step % eval_interval == 0 and "Drunk" in env_name:
        misc.record_point_mass_episode(eval_tf_env, eval_policy, current_step)
        if online_critic:
          misc.record_point_mass_episode(eval_tf_env, tf_agent.safe_policy,
                                         current_step, 'safe-trajectory')

      if run_eval and current_step % eval_interval == 0:
        eval_results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='EvalMetrics',
        )
        if train_metrics_callback is not None:
          train_metrics_callback(eval_results, current_step)
        metric_utils.log_metrics(eval_metrics)

        with eval_summary_writer.as_default():
          for eval_metric in eval_metrics[2:]:
            eval_metric.tf_summaries(train_step=global_step,
                                     step_metrics=eval_metrics[:2])

      if monitor and current_step % monitor_interval == 0:
        monitor_time_step = monitor_py_env.reset()
        monitor_policy_state = eval_policy.get_initial_state(1)
        ep_len = 0
        monitor_start = time.time()
        while not monitor_time_step.is_last():
          monitor_action = eval_policy.action(monitor_time_step, monitor_policy_state)
          action, monitor_policy_state = monitor_action.action, monitor_action.state
          monitor_time_step = monitor_py_env.step(action)
          ep_len += 1
        logging.debug('saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                      current_step, ep_len, time.time() - monitor_start)

      global_step_val = current_step

  if early_termination_fn():
    # Early termination: save any checkpoints that were not written this step.
    if global_step_val % train_checkpoint_interval != 0:
      train_checkpointer.save(global_step=global_step_val)

    if global_step_val % policy_checkpoint_interval != 0:
      policy_checkpointer.save(global_step=global_step_val)
      if agent_class in SAFETY_AGENTS:
        safety_critic_checkpointer.save(global_step=global_step_val)
        if online_critic:
          online_rb_checkpointer.save(global_step=global_step_val)

    if rb_checkpoint_interval and global_step_val % rb_checkpoint_interval == 0:
      rb_checkpointer.save(global_step=global_step_val)

  if not keep_rb_checkpoint:
    misc.cleanup_checkpoints(rb_ckpt_dir)

  if loss_diverged:
    # Raise an error at the very end after the cleanup.
    raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
      total_loss, global_step.numpy()))

  return total_loss
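
`train_eval` above relies on `common.Checkpointer` for the agent, policy, safety-critic, and replay-buffer state. A hedged, self-contained sketch of that save/restore pattern with toy trackables (the counter variable and temp directory are hypothetical, not part of SQRL):

import os
import tempfile

import tensorflow as tf
from tf_agents.utils import common

ckpt_dir = os.path.join(tempfile.mkdtemp(), 'train')
global_step = tf.compat.v1.train.get_or_create_global_step()
counter = tf.Variable(0, dtype=tf.int64, name='toy_counter')

checkpointer = common.Checkpointer(
    ckpt_dir=ckpt_dir, max_to_keep=1,
    global_step=global_step, counter=counter)
# No-op on the first run; restores the latest checkpoint on later runs.
checkpointer.initialize_or_restore()

counter.assign_add(5)
global_step.assign_add(1)
# Same call pattern as the train/policy/rb checkpointers in train_eval.
checkpointer.save(global_step=global_step)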