Code example #1
    def _init_model_and_dist_class(self):
        if is_overridden(self.make_model) and is_overridden(
                self.make_model_and_action_dist):
            raise ValueError(
                "Only one of make_model or make_model_and_action_dist "
                "can be overridden.")

        if is_overridden(self.make_model):
            model = self.make_model()
            dist_class, _ = ModelCatalog.get_action_dist(
                self.action_space,
                self.config["model"],
                framework=self.framework)
        elif is_overridden(self.make_model_and_action_dist):
            model, dist_class = self.make_model_and_action_dist()
        else:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                self.action_space,
                self.config["model"],
                framework=self.framework)
            model = ModelCatalog.get_model_v2(
                obs_space=self.observation_space,
                action_space=self.action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework=self.framework,
            )
        return model, dist_class
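Every snippet in this collection gates optional user overrides with `is_overridden`, which comes from RLlib's annotation utilities. As a rough, standalone sketch of the underlying pattern (illustrative only; the real helper may differ in detail):

    # Standalone sketch of the override-detection pattern (not Ray's source):
    # base-class hooks are tagged with a marker attribute, and is_overridden()
    # reports whether a subclass replaced the hook (and thereby dropped the tag).
    def OverrideToImplementCustomLogic(method):
        method.__is_overridden__ = False
        return method

    def is_overridden(method):
        # Bound methods forward attribute lookups to the underlying function,
        # so an un-overridden hook still carries the marker.
        return getattr(method, "__is_overridden__", True)

    class BasePolicy:
        @OverrideToImplementCustomLogic
        def make_model(self):
            return None

    class MyPolicy(BasePolicy):
        def make_model(self):
            return "custom model"

    assert not is_overridden(BasePolicy().make_model)
    assert is_overridden(MyPolicy().make_model)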
Code example #2
    def _compute_gradients_helper(self, samples):
        """Computes and returns grads as eager tensors."""

        # Increase the tracing counter to make sure we don't re-trace too
        # often. If eager_tracing=True, this counter should only get
        # incremented during the @tf.function trace operations, never when
        # calling the already traced function after that.
        self._re_trace_counter += 1

        # Gather all variables for which to calculate losses.
        if isinstance(self.model, tf.keras.Model):
            variables = self.model.trainable_variables
        else:
            variables = self.model.trainable_variables()

        # Calculate the loss(es) inside a tf GradientTape.
        with tf.GradientTape(
            persistent=is_overridden(self.compute_gradients_fn)
        ) as tape:
            losses = self.loss(self.model, self.dist_class, samples)
        losses = force_list(losses)

        # User provided a custom compute_gradients_fn.
        if is_overridden(self.compute_gradients_fn):
            # Wrap our tape inside a wrapper, such that the resulting
            # object looks like a "classic" tf.optimizer. This way, custom
            # compute_gradients_fn will work on both tf static graph
            # and tf-eager.
            optimizer = _OptimizerWrapper(tape)
            # More than one loss term/optimizer.
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                grads_and_vars = self.compute_gradients_fn(
                    [optimizer] * len(losses), losses
                )
            # Only one loss and one optimizer.
            else:
                grads_and_vars = [self.compute_gradients_fn(optimizer, losses[0])]
        # Default: Compute gradients using the above tape.
        else:
            grads_and_vars = [
                list(zip(tape.gradient(loss, variables), variables)) for loss in losses
            ]

        if log_once("grad_vars"):
            for g_and_v in grads_and_vars:
                for g, v in g_and_v:
                    if g is not None:
                        logger.info(f"Optimizing variable {v.name}")

        # `grads_and_vars` is returned as a list (len=num optimizers/losses)
        # of lists of (grad, var) tuples.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
        # `grads_and_vars` is returned as a list of (grad, var) tuples.
        else:
            grads_and_vars = grads_and_vars[0]
            grads = [g for g, _ in grads_and_vars]

        stats = self._stats(samples, grads)
        return grads_and_vars, grads, stats
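The `_OptimizerWrapper` used above is not reproduced in these snippets. A plausible minimal shape for such an adapter, plus a gradient function written once against the classic tf1-style optimizer API, might look like this (the names, the clipping value, and the standalone usage are illustrative assumptions, not Ray's source):

    import tensorflow as tf

    class _TapeOptimizerSketch:
        """Exposes compute_gradients(loss, var_list) on top of a GradientTape."""

        def __init__(self, tape):
            self.tape = tape

        def compute_gradients(self, loss, var_list):
            grads = self.tape.gradient(loss, var_list)
            return list(zip(grads, var_list))

    def clipping_compute_gradients_fn(optimizer, loss, var_list):
        # Written once against the classic optimizer API; in eager mode the
        # "optimizer" is really a tape wrapper like the one above.
        grads_and_vars = optimizer.compute_gradients(loss, var_list)
        return [(tf.clip_by_norm(g, 40.0), v)
                for g, v in grads_and_vars if g is not None]

    # Minimal eager-mode check.
    w = tf.Variable(3.0)
    with tf.GradientTape() as tape:
        loss = tf.square(w)
    print(clipping_compute_gradients_fn(_TapeOptimizerSketch(tape), loss, [w]))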
Code example #3
File: eager_tf_policy_v2.py  Project: smorad/ray
    def _init_dist_class(self):
        if is_overridden(self.action_sampler_fn) or is_overridden(
                self.action_distribution_fn):
            if not is_overridden(self.make_model):
                raise ValueError(
                    "`make_model` is required if `action_sampler_fn` OR "
                    "`action_distribution_fn` is given")
        else:
            dist_class, _ = ModelCatalog.get_action_dist(
                self.action_space, self.config["model"])
        return dist_class
Code example #4
    def compute_log_likelihoods(
        self,
        actions,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        actions_normalized=True,
    ):
        if is_overridden(self.action_sampler_fn) and not is_overridden(
            self.action_distribution_fn
        ):
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o an "
                "`action_distribution_fn` and a provided "
                "`action_sampler_fn`!"
            )

        seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
        input_batch = SampleBatch(
            {SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)},
            _is_training=False,
        )
        if prev_action_batch is not None:
            input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor(
                prev_action_batch
            )
        if prev_reward_batch is not None:
            input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor(
                prev_reward_batch
            )

        # Exploration hook before each forward pass.
        self.exploration.before_compute_actions(explore=False)

        # Action dist class and inputs are generated via custom function.
        if is_overridden(self.action_distribution_fn):
            dist_inputs, self.dist_class, _ = self.action_distribution_fn(
                self, self.model, input_batch, explore=False, is_training=False
            )
        # Default log-likelihood calculation.
        else:
            dist_inputs, _ = self.model(input_batch, state_batches, seq_lens)

        action_dist = self.dist_class(dist_inputs, self.model)

        # Normalize actions if necessary.
        if not actions_normalized and self.config["normalize_actions"]:
            actions = normalize_action(actions, self.action_space_struct)

        log_likelihoods = action_dist.logp(actions)

        return log_likelihoods
Code example #5
    def _stats(self, samples, grads):
        fetches = {}
        if is_overridden(self.stats_fn):
            fetches[LEARNER_STATS_KEY] = {
                k: v for k, v in self.stats_fn(samples).items()
            }
        else:
            fetches[LEARNER_STATS_KEY] = {}

        fetches.update({k: v for k, v in self.extra_learn_fetches_fn().items()})
        fetches.update({k: v for k, v in self.grad_stats_fn(samples, grads).items()})
        return fetches
Code example #6
File: callbacks.py  Project: ray-project/ray
    def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
        if legacy_callbacks_dict:
            deprecation_warning(
                "callbacks dict interface",
                "a class extending rllib.algorithms.callbacks.DefaultCallbacks",
            )
        self.legacy_callbacks = legacy_callbacks_dict or {}
        if is_overridden(self.on_trainer_init):
            deprecation_warning(
                old="on_trainer_init(trainer, **kwargs)",
                new="on_algorithm_init(algorithm, **kwargs)",
                error=False,
            )
Code example #7
    def gradients(self, optimizer, loss):
        optimizers = force_list(optimizer)
        losses = force_list(loss)

        if is_overridden(self.compute_gradients_fn):
            # New API: Allow more than one optimizer -> Return a list of
            # lists of gradients.
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                return self.compute_gradients_fn(optimizers, losses)
            # Old API: Return a single List of gradients.
            else:
                return self.compute_gradients_fn(optimizers[0], losses[0])
        else:
            return super().gradients(optimizers, losses)
Code example #8
    def _compute_action_probs(self, obs: TensorType) -> TensorType:
        """Compute action distribution over the action space.

        Args:
            obs: A tensor of observations of shape (batch_size * obs_dim)

        Returns:
            action_probs: A tensor of action probabilities
            of shape (batch_size * action_dim)
        """
        input_dict = {SampleBatch.OBS: obs}
        seq_lens = torch.ones(len(obs), device=self.device, dtype=int)
        state_batches = []
        if is_overridden(self.policy.action_distribution_fn):
            try:
                # TorchPolicyV2 function signature
                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
                    self.policy.model,
                    obs_batch=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
            except TypeError:
                # TorchPolicyV1 function signature for compatibility with DQN
                # TODO: Remove this once DQNTorchPolicy is migrated to PolicyV2
                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
                    self.policy,
                    self.policy.model,
                    input_dict=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
        else:
            dist_class = self.policy.dist_class
            dist_inputs, _ = self.policy.model(input_dict, state_batches, seq_lens)
        action_dist = dist_class(dist_inputs, self.policy.model)
        assert isinstance(
            action_dist.dist, torch.distributions.categorical.Categorical
        ), "FQE only supports Categorical or MultiCategorical distributions!"
        action_probs = action_dist.dist.probs
        return action_probs
Code example #9
File: eager_tf_policy_v2.py  Project: smorad/ray
    def _apply_gradients_helper(self, grads_and_vars):
        # Increase the tracing counter to make sure we don't re-trace too
        # often. If eager_tracing=True, this counter should only get
        # incremented during the @tf.function trace operations, never when
        # calling the already traced function after that.
        self._re_trace_counter += 1

        if is_overridden(self.apply_gradients_fn):
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                self.apply_gradients_fn(self._optimizers, grads_and_vars)
            else:
                self.apply_gradients_fn(self._optimizer, grads_and_vars)
        else:
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                for i, o in enumerate(self._optimizers):
                    o.apply_gradients([(g, v) for g, v in grads_and_vars[i]
                                       if g is not None])
            else:
                self._optimizer.apply_gradients([(g, v)
                                                 for g, v in grads_and_vars
                                                 if g is not None])
Code example #10
    def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                               explore, timestep):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep
        self._is_recurrent = state_batches is not None and state_batches != []

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        if is_overridden(self.action_sampler_fn):
            action_dist = dist_inputs = None
            actions, logp, state_out = self.action_sampler_fn(
                self.model,
                obs_batch=input_dict,
                state_batches=state_batches,
                explore=explore,
                timestep=timestep,
            )
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(explore=explore,
                                                    timestep=timestep)
            if is_overridden(self.action_distribution_fn):
                dist_inputs, dist_class, state_out = self.action_distribution_fn(
                    self.model,
                    obs_batch=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=explore,
                    timestep=timestep,
                    is_training=False,
                )
            else:
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches,
                                                    seq_lens)

            if not (isinstance(dist_class, functools.partial)
                    or issubclass(dist_class, TorchDistributionWrapper)):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__))
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = self.exploration.get_exploration_action(
                action_distribution=action_dist,
                timestep=timestep,
                explore=explore)

        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(input_dict, state_batches,
                                              self.model, action_dist)

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        return convert_to_numpy((actions, state_out, extra_fetches))
Code example #11
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorStructType], TensorStructType],
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorStructType],
                                          TensorStructType]] = None,
        prev_reward_batch: Optional[Union[List[TensorStructType],
                                          TensorStructType]] = None,
        actions_normalized: bool = True,
    ) -> TensorType:

        if is_overridden(self.action_sampler_fn) and not is_overridden(
                self.action_distribution_fn):
            raise ValueError("Cannot compute log-prob/likelihood w/o an "
                             "`action_distribution_fn` and a provided "
                             "`action_sampler_fn`!")

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
                SampleBatch.ACTIONS: actions
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device)
                for s in (state_batches or [])
            ]

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if is_overridden(self.action_distribution_fn):
                dist_inputs, dist_class, state_out = self.action_distribution_fn(
                    self.model,
                    input_dict=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)

            action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods
Code example #12
File: eager_tf_policy_v2.py  Project: smorad/ray
    def _compute_actions_helper(
        self,
        input_dict,
        state_batches,
        episodes,
        explore,
        timestep,
        _ray_trace_ctx=None,
    ):
        # Increase the tracing counter to make sure we don't re-trace too
        # often. If eager_tracing=True, this counter should only get
        # incremented during the @tf.function trace operations, never when
        # calling the already traced function after that.
        self._re_trace_counter += 1

        # Calculate RNN sequence lengths.
        batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
        seq_lens = tf.ones(batch_size,
                           dtype=tf.int32) if state_batches else None

        # Add default and custom fetches.
        extra_fetches = {}

        # Use Exploration object.
        with tf.variable_creator_scope(_disallow_var_creation):
            if is_overridden(self.action_sampler_fn):
                dist_inputs = None
                state_out = []
                actions, logp, dist_inputs, state_out = self.action_sampler_fn(
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=explore,
                    timestep=timestep,
                    episodes=episodes,
                )
            else:
                if is_overridden(self.action_distribution_fn):

                    # Try new action_distribution_fn signature, supporting
                    # state_batches and seq_lens.
                    (
                        dist_inputs,
                        self.dist_class,
                        state_out,
                    ) = self.action_distribution_fn(
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=False,
                    )
                elif isinstance(self.model, tf.keras.Model):
                    input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
                    if state_batches and "state_in_0" not in input_dict:
                        for i, s in enumerate(state_batches):
                            input_dict[f"state_in_{i}"] = s
                    self._lazy_tensor_dict(input_dict)
                    dist_inputs, state_out, extra_fetches = self.model(
                        input_dict)
                else:
                    dist_inputs, state_out = self.model(
                        input_dict, state_batches, seq_lens)

                action_dist = self.dist_class(dist_inputs, self.model)

                # Get the exploration action from the forward results.
                actions, logp = self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore,
                )

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
            extra_fetches[SampleBatch.ACTION_LOGP] = logp
        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
        # Custom extra fetches.
        extra_fetches.update(self.extra_action_out_fn())

        return actions, state_out, extra_fetches
Code example #13
    def _init_action_fetches(
        self, timestep: Union[int, TensorType], explore: Union[bool,
                                                               TensorType]
    ) -> Tuple[TensorType, TensorType, TensorType, type, Dict[str,
                                                              TensorType]]:
        """Create action related fields for base Policy and loss initialization."""
        # Multi-GPU towers do not need any action computing/exploration
        # graphs.
        sampled_action = None
        sampled_action_logp = None
        dist_inputs = None
        extra_action_fetches = {}
        self._state_out = None
        if not self._is_tower:
            # Create the Exploration object to use for this Policy.
            self.exploration = self._create_exploration()

            # Fully customized action generation (e.g., custom policy).
            if is_overridden(self.action_sampler_fn):
                (
                    sampled_action,
                    sampled_action_logp,
                    dist_inputs,
                    self._state_out,
                ) = self.action_sampler_fn(
                    self.model,
                    obs_batch=self._input_dict[SampleBatch.CUR_OBS],
                    state_batches=self._state_inputs,
                    seq_lens=self._seq_lens,
                    prev_action_batch=self._input_dict.get(
                        SampleBatch.PREV_ACTIONS),
                    prev_reward_batch=self._input_dict.get(
                        SampleBatch.PREV_REWARDS),
                    explore=explore,
                    is_training=self._input_dict.is_training,
                )
            # Distribution generation is customized, e.g., DQN, DDPG.
            else:
                if is_overridden(self.action_distribution_fn):
                    # Try new action_distribution_fn signature, supporting
                    # state_batches and seq_lens.
                    in_dict = self._input_dict
                    (
                        dist_inputs,
                        self.dist_class,
                        self._state_out,
                    ) = self.action_distribution_fn(
                        self.model,
                        input_dict=in_dict,
                        state_batches=self._state_inputs,
                        seq_lens=self._seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=in_dict.is_training,
                    )
                # Default distribution generation behavior:
                # Pass through model. E.g., PG, PPO.
                else:
                    if isinstance(self.model, tf.keras.Model):
                        dist_inputs, self._state_out, extra_action_fetches = self.model(
                            self._input_dict)
                    else:
                        dist_inputs, self._state_out = self.model(
                            self._input_dict)

                action_dist = self.dist_class(dist_inputs, self.model)

                # Using exploration to get final action (e.g. via sampling).
                (
                    sampled_action,
                    sampled_action_logp,
                ) = self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore)

        if dist_inputs is not None:
            extra_action_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        if sampled_action_logp is not None:
            extra_action_fetches[SampleBatch.ACTION_LOGP] = sampled_action_logp
            extra_action_fetches[SampleBatch.ACTION_PROB] = tf.exp(
                tf.cast(sampled_action_logp, tf.float32))

        return (
            sampled_action,
            sampled_action_logp,
            dist_inputs,
            extra_action_fetches,
        )
Code example #14
File: policy.py  Project: smorad/ray
    def _initialize_loss_from_dummy_batch(
        self,
        auto_remove_unneeded_view_reqs: bool = True,
        stats_fn=None,
    ) -> None:
        """Performs test calls through policy's model and loss.

        NOTE: This base method should work for define-by-run Policies such as
        torch and tf-eager policies.

        If required, this will automatically detect which data views are
        required by a) the forward pass, b) the postprocessing, and c) the
        loss functions, and remove from self.view_requirements those entries
        that are not necessary for these computations (to save data storage
        and transfer).

        Args:
            auto_remove_unneeded_view_reqs (bool): Whether to automatically
                remove those ViewRequirements records from
                self.view_requirements that are not needed.
            stats_fn (Optional[Callable[[Policy, SampleBatch], Dict[str,
                TensorType]]]): An optional stats function to be called after
                the loss.
        """
        # Signal the Policy that we currently do not want to eager/jit trace
        # any function calls. This makes it possible to track which columns
        # in the dummy batch are accessed by the different functions (e.g.
        # the loss), so that we can then adjust our view requirements.
        self._no_tracing = True

        sample_batch_size = max(self.batch_divisibility_req * 4, 32)
        self._dummy_batch = self._get_dummy_batch_from_view_requirements(
            sample_batch_size)
        self._lazy_tensor_dict(self._dummy_batch)
        actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
            self._dummy_batch, explore=False)
        for key, view_req in self.view_requirements.items():
            if key not in self._dummy_batch.accessed_keys:
                view_req.used_for_compute_actions = False
        # Add all extra action outputs to view requirements (these may be
        # filtered out later again, if not needed for postprocessing or loss).
        for key, value in extra_outs.items():
            self._dummy_batch[key] = value
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(-1.0,
                                         1.0,
                                         shape=value.shape[1:],
                                         dtype=value.dtype),
                    used_for_compute_actions=False,
                )
        for key in self._dummy_batch.accessed_keys:
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement()
            self.view_requirements[key].used_for_compute_actions = True
        self._dummy_batch = self._get_dummy_batch_from_view_requirements(
            sample_batch_size)
        self._dummy_batch.set_get_interceptor(None)
        self.exploration.postprocess_trajectory(self, self._dummy_batch)
        postprocessed_batch = self.postprocess_trajectory(self._dummy_batch)
        seq_lens = None
        if state_outs:
            B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
            i = 0
            while "state_in_{}".format(i) in postprocessed_batch:
                postprocessed_batch["state_in_{}".format(
                    i)] = postprocessed_batch["state_in_{}".format(i)][:B]
                if "state_out_{}".format(i) in postprocessed_batch:
                    postprocessed_batch["state_out_{}".format(
                        i)] = postprocessed_batch["state_out_{}".format(i)][:B]
                i += 1
            seq_len = sample_batch_size // B
            seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
            postprocessed_batch[SampleBatch.SEQ_LENS] = seq_lens
        # Switch on lazy to-tensor conversion on `postprocessed_batch`.
        train_batch = self._lazy_tensor_dict(postprocessed_batch)
        # Calling loss, so set `is_training` to True.
        train_batch.set_training(True)
        if seq_lens is not None:
            train_batch[SampleBatch.SEQ_LENS] = seq_lens
        train_batch.count = self._dummy_batch.count
        # Call the loss function, if it exists.
        # TODO(jungong) : clean up after all agents get migrated.
        # We should simply do self.loss(...) here.
        if self._loss is not None:
            self._loss(self, self.model, self.dist_class, train_batch)
        elif is_overridden(self.loss):
            self.loss(self.model, self.dist_class, train_batch)
        # Call the stats fn, if given.
        # TODO(jungong) : clean up after all agents get migrated.
        # We should simply do self.stats_fn(train_batch) here.
        if stats_fn is not None:
            stats_fn(self, train_batch)
        if hasattr(self, "stats_fn"):
            self.stats_fn(train_batch)

        # Re-enable tracing.
        self._no_tracing = False

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = (train_batch.accessed_keys
                                 | self._dummy_batch.accessed_keys
                                 | self._dummy_batch.added_keys)
            for key in all_accessed_keys:
                if key not in self.view_requirements and key != SampleBatch.SEQ_LENS:
                    self.view_requirements[key] = ViewRequirement(
                        used_for_compute_actions=False)
            if self._loss or is_overridden(self.loss):
                # Tag those only needed for post-processing (with some
                # exceptions).
                for key in self._dummy_batch.accessed_keys:
                    if (key not in train_batch.accessed_keys
                            and key in self.view_requirements
                            and key not in self.model.view_requirements
                            and key not in [
                                SampleBatch.EPS_ID,
                                SampleBatch.AGENT_INDEX,
                                SampleBatch.UNROLL_ID,
                                SampleBatch.DONES,
                                SampleBatch.REWARDS,
                                SampleBatch.INFOS,
                            ]):
                        self.view_requirements[key].used_for_training = False
                # Remove those not needed at all (leave those that are needed
                # by Sampler to properly execute sample collection).
                # Also always leave DONES, REWARDS, INFOS, no matter what.
                for key in list(self.view_requirements.keys()):
                    if (key not in all_accessed_keys and key not in [
                            SampleBatch.EPS_ID,
                            SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID,
                            SampleBatch.DONES,
                            SampleBatch.REWARDS,
                            SampleBatch.INFOS,
                    ] and key not in self.model.view_requirements):
                        # If user deleted this key manually in postprocessing
                        # fn, warn about it and do not remove from
                        # view-requirements.
                        if key in self._dummy_batch.deleted_keys:
                            logger.warning(
                                "SampleBatch key '{}' was deleted manually in "
                                "postprocessing function! RLlib will "
                                "automatically remove non-used items from the "
                                "data stream. Remove the `del` from your "
                                "postprocessing function.".format(key))
                        # If we are not writing output to disk, it is safe to
                        # erase this key to save space in the sample batch.
                        elif self.config["output"] is None:
                            del self.view_requirements[key]