Beispiel #1
0
    def _kl_divergence(self, time_steps, action_distribution_parameters,
                       current_policy_distribution):
        """Return the nested KL divergence of stored vs. current distributions.

        Args:
            time_steps: Nest of time steps; only used to determine how many
                outer dimensions there are relative to `self.time_step_spec`.
            action_distribution_parameters: Nest of parameters from which the
                previously used action distributions are rebuilt.
            current_policy_distribution: Nest of distributions to compare
                against the rebuilt ones.

        Returns:
            Whatever `ppo_utils.nested_kl_divergence` returns for the two
            distribution nests.
        """
        # Outer dims are enumerated as 0..outer_rank-1.
        outer_rank = nest_utils.get_outer_rank(time_steps, self.time_step_spec)
        leading_dims = list(range(outer_rank))

        # Rebuild distribution objects from the stored parameters.
        rebuilt_distribution = distribution_spec.nested_distributions_from_specs(
            self._action_distribution_spec, action_distribution_parameters)

        return ppo_utils.nested_kl_divergence(rebuilt_distribution,
                                              current_policy_distribution,
                                              outer_dims=leading_dims)
        def _train_loop_body(counter, policy_state, training_info_ta):
            """Process the stored experience at index `counter`.

            Relies on `experience_ta` and `initial_train_state` from the
            enclosing scope.

            Args:
                counter: Index of the step to read and of the slot to write.
                policy_state: State carried over from the previous step.
                training_info_ta: Nest of TensorArrays collecting the
                    per-step training info.

            Returns:
                A list `[counter + 1, new policy state, updated arrays]`.
            """
            # Read this step from every experience array.
            step_exp = tf.nest.map_structure(lambda arr: arr.read(counter),
                                             experience_ta)

            # Keep the raw parameters: they are stored in the training info,
            # while the experience itself carries distribution objects.
            collect_params = step_exp.action_distribution
            step_exp = step_exp._replace(
                action_distribution=nested_distributions_from_specs(
                    self._action_distribution_spec, collect_params))

            is_first_step = tf.equal(step_exp.step_type, StepType.FIRST)
            policy_state = common.reset_state_if_necessary(
                policy_state, initial_train_state, is_first_step)

            policy_step = common.algorithm_step(
                self._algorithm,
                self._observation_transformer,
                step_exp,
                policy_state,
                training=True)

            step_info = make_training_info(
                action=step_exp.action,
                action_distribution=common.get_distribution_params(
                    policy_step.action),
                reward=step_exp.reward,
                discount=step_exp.discount,
                step_type=step_exp.step_type,
                info=policy_step.info,
                collect_info=step_exp.info,
                collect_action_distribution=collect_params)

            training_info_ta = tf.nest.map_structure(
                lambda arr, value: arr.write(counter, value),
                training_info_ta, step_info)

            return [counter + 1, policy_step.state, training_info_ta]
Beispiel #3
0
    def _iter(self, time_step, policy_state):
        """Run one training iteration.

        A while-loop runs `self._train_loop_body` for up to
        `self._train_interval` iterations under a persistent gradient tape.
        One extra "final" training info is then appended, and the stacked
        result is handed to `self._algorithm.train_complete`.

        Args:
            time_step: Time step the loop starts from.
            policy_state: Policy state the loop starts from.

        Returns:
            A tuple `(next_time_step, next_state)`. In `FINAL_STEP_SKIP` mode
            both come from `self._step`; otherwise they are the time step and
            policy state as they were after the loop.
        """
        counter = tf.zeros((), tf.int32)
        env_batch_size = self._env.batch_size

        def _new_array(spec):
            # One slot per loop iteration plus one for the final step.
            slot_shape = tf.TensorShape([env_batch_size]).concatenate(
                spec.shape)
            return tf.TensorArray(dtype=spec.dtype,
                                  size=self._train_interval + 1,
                                  element_shape=slot_shape)

        training_info_ta = tf.nest.map_structure(_new_array,
                                                 self._training_info_spec)

        # Only the explicitly watched trainable variables are tracked.
        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch(self._trainable_variables)
            loop_state = [counter, time_step, policy_state, training_info_ta]
            # The loop is bounded by `maximum_iterations`, not by `cond`.
            [counter, time_step, policy_state,
             training_info_ta] = tf.while_loop(
                 cond=lambda *_: True,
                 body=self._train_loop_body,
                 loop_vars=loop_state,
                 back_prop=True,
                 parallel_iterations=1,
                 maximum_iterations=self._train_interval,
                 name='iter_loop')

        if self._final_step_mode == OnPolicyDriver.FINAL_STEP_SKIP:
            next_time_step, policy_step, action = self._step(
                time_step, policy_state)
            next_state = policy_step.state
        else:
            # Evaluate the rollout on the current step and hand the time step
            # and state back unchanged.
            policy_step = common.algorithm_step(
                self._algorithm.rollout, self._observation_transformer,
                time_step, policy_state)
            action = common.sample_action_distribution(policy_step.action)
            next_time_step, next_state = time_step, policy_state

        final_training_info = make_training_info(
            action_distribution=common.get_distribution_params(
                policy_step.action),
            action=action,
            reward=time_step.reward,
            discount=time_step.discount,
            step_type=time_step.step_type,
            info=policy_step.info)

        # Re-enter the same tape so the write/stack ops are recorded as well.
        with tape:
            training_info_ta = tf.nest.map_structure(
                lambda arr, value: arr.write(counter, value),
                training_info_ta, final_training_info)
            training_info = tf.nest.map_structure(lambda arr: arr.stack(),
                                                  training_info_ta)

            # Swap the stacked parameters for distribution objects.
            training_info = training_info._replace(
                action_distribution=nested_distributions_from_specs(
                    self._algorithm.action_distribution_spec,
                    training_info.action_distribution))

        loss_info, grads_and_vars = self._algorithm.train_complete(
            tape, training_info)

        # Drop the reference to the persistent tape.
        del tape

        self._training_summary(training_info, loss_info, grads_and_vars)
        self._train_step_counter.assign_add(1)

        return next_time_step, next_state
Beispiel #4
0
    def _train(self, experience, weights):
        """Run `self._num_epochs` update epochs on a batch of experience.

        The experience is split into transitions, and log-probabilities of the
        collected actions, value predictions, returns and advantages are
        computed once up front. Each epoch then computes a loss under its own
        gradient tape and applies gradients to the actor and value network
        weights. Afterwards the adaptive KL beta and the normalizers are
        updated and loss summaries are written.

        Args:
            experience: Trajectory-like nest accepted by
                `trajectory.to_transition`; its `observation` and `step_type`
                fields are also fed to the value network.
            weights: Optional per-step weights. They are multiplied by the
                valid-timestep mask; if `None`, the mask itself is used.

        Returns:
            The loss info of the last epoch, with every element passed through
            `tf.identity`.
        """
        # Get individual tensors from transitions.
        (time_steps, policy_steps_,
         next_time_steps) = trajectory.to_transition(experience)
        actions = policy_steps_.action

        if self._debug_summaries:
            actions_list = tf.nest.flatten(actions)
            # Only suffix the summary name with an index when there is more
            # than one action tensor.
            show_action_index = len(actions_list) != 1
            for i, single_action in enumerate(actions_list):
                action_name = ('actions_{}'.format(i)
                               if show_action_index else 'actions')
                tf.compat.v2.summary.histogram(name=action_name,
                                               data=single_action,
                                               step=self.train_step_counter)

        # The policy info stored at collection time is used as the action
        # distribution parameters.
        action_distribution_parameters = policy_steps_.info

        # Reconstruct per-timestep policy distribution from stored distribution
        #   parameters.
        old_actions_distribution = (
            distribution_spec.nested_distributions_from_specs(
                self._action_distribution_spec,
                action_distribution_parameters))

        # Compute log probability of actions taken during data collection, using the
        #   collect policy distribution.
        act_log_probs = common.log_probability(old_actions_distribution,
                                               actions, self._action_spec)

        # Compute the value predictions for states using the current value function.
        # To be used for return & advantage computation.
        batch_size = nest_utils.get_outer_shape(time_steps,
                                                self._time_step_spec)[0]
        policy_state = self._collect_policy.get_initial_state(
            batch_size=batch_size)

        value_preds, unused_policy_state = self._collect_policy.apply_value_network(
            experience.observation,
            experience.step_type,
            policy_state=policy_state)
        # Value predictions are treated as constants from here on.
        value_preds = tf.stop_gradient(value_preds)

        valid_mask = ppo_utils.make_timestep_mask(next_time_steps)

        if weights is None:
            weights = valid_mask
        else:
            weights *= valid_mask

        returns, normalized_advantages = self.compute_return_and_advantage(
            next_time_steps, value_preds)

        # Loss tensors across batches will be aggregated for summaries.
        policy_gradient_losses = []
        value_estimation_losses = []
        l2_regularization_losses = []
        entropy_regularization_losses = []
        kl_penalty_losses = []

        # NOTE(review): if `self._num_epochs` is 0, `loss_info` stays None and
        # the loss lists above stay empty, which the code below does not
        # appear to handle -- confirm that at least one epoch is enforced.
        loss_info = None  # TODO(b/123627451): Remove.
        # For each epoch, create its own train op that depends on the previous one.
        for i_epoch in range(self._num_epochs):
            with tf.name_scope('epoch_%d' % i_epoch):
                # Only save debug summaries for first and last epochs.
                debug_summaries = (self._debug_summaries
                                   and (i_epoch == 0
                                        or i_epoch == self._num_epochs - 1))

                # Build one epoch train op.
                with tf.GradientTape() as tape:
                    loss_info = self.get_epoch_loss(
                        time_steps, actions, act_log_probs, returns,
                        normalized_advantages, action_distribution_parameters,
                        weights, self.train_step_counter, debug_summaries)

                variables_to_train = (self._actor_net.trainable_weights +
                                      self._value_net.trainable_weights)
                grads = tape.gradient(loss_info.loss, variables_to_train)
                # Tuple is used for py3, where zip is a generator producing values once.
                grads_and_vars = tuple(zip(grads, variables_to_train))
                # Clipping is only applied for a positive clipping value.
                if self._gradient_clipping > 0:
                    grads_and_vars = eager_utils.clip_gradient_norms(
                        grads_and_vars, self._gradient_clipping)

                # If summarize_gradients, create functions for summarizing both
                # gradients and variables.
                if self._summarize_grads_and_vars and debug_summaries:
                    eager_utils.add_gradients_summaries(
                        grads_and_vars, self.train_step_counter)
                    eager_utils.add_variables_summaries(
                        grads_and_vars, self.train_step_counter)

                self._optimizer.apply_gradients(
                    grads_and_vars, global_step=self.train_step_counter)

                # Record this epoch's loss components for the summed
                # summaries written after the loop.
                policy_gradient_losses.append(
                    loss_info.extra.policy_gradient_loss)
                value_estimation_losses.append(
                    loss_info.extra.value_estimation_loss)
                l2_regularization_losses.append(
                    loss_info.extra.l2_regularization_loss)
                entropy_regularization_losses.append(
                    loss_info.extra.entropy_regularization_loss)
                kl_penalty_losses.append(loss_info.extra.kl_penalty_loss)

        # After update epochs, update adaptive kl beta, then update observation
        #   normalizer and reward normalizer.
        batch_size = nest_utils.get_outer_shape(time_steps,
                                                self._time_step_spec)[0]
        policy_state = self._collect_policy.get_initial_state(batch_size)
        # Compute the mean kl from previous action distribution.
        kl_divergence = self._kl_divergence(
            time_steps, action_distribution_parameters,
            self._collect_policy.distribution(time_steps, policy_state).action)
        self.update_adaptive_kl_beta(kl_divergence)

        # NOTE(review): as written, the reward normalizer is updated only when
        # no observation normalizer is configured -- confirm this is
        # intentional (see the TODO below).
        if self._observation_normalizer:
            self._observation_normalizer.update(time_steps.observation,
                                                outer_dims=[0, 1])
        else:
            # TODO(b/127661780): Verify performance of reward_normalizer when obs are
            #                    not normalized
            if self._reward_normalizer:
                self._reward_normalizer.update(next_time_steps.reward,
                                               outer_dims=[0, 1])

        loss_info = tf.nest.map_structure(tf.identity, loss_info)

        # Make summaries for total loss across all epochs.
        # The *_losses lists will have been populated by
        #   calls to self.get_epoch_loss.
        with tf.name_scope('Losses/'):
            total_policy_gradient_loss = tf.add_n(policy_gradient_losses)
            total_value_estimation_loss = tf.add_n(value_estimation_losses)
            total_l2_regularization_loss = tf.add_n(l2_regularization_losses)
            total_entropy_regularization_loss = tf.add_n(
                entropy_regularization_losses)
            total_kl_penalty_loss = tf.add_n(kl_penalty_losses)
            tf.compat.v2.summary.scalar(name='policy_gradient_loss',
                                        data=total_policy_gradient_loss,
                                        step=self.train_step_counter)
            tf.compat.v2.summary.scalar(name='value_estimation_loss',
                                        data=total_value_estimation_loss,
                                        step=self.train_step_counter)
            tf.compat.v2.summary.scalar(name='l2_regularization_loss',
                                        data=total_l2_regularization_loss,
                                        step=self.train_step_counter)
            tf.compat.v2.summary.scalar(name='entropy_regularization_loss',
                                        data=total_entropy_regularization_loss,
                                        step=self.train_step_counter)
            tf.compat.v2.summary.scalar(name='kl_penalty_loss',
                                        data=total_kl_penalty_loss,
                                        step=self.train_step_counter)

            # Sum of absolute values, so components of opposite sign do not
            # cancel each other out in the summary.
            total_abs_loss = (tf.abs(total_policy_gradient_loss) +
                              tf.abs(total_value_estimation_loss) +
                              tf.abs(total_entropy_regularization_loss) +
                              tf.abs(total_l2_regularization_loss) +
                              tf.abs(total_kl_penalty_loss))

            tf.compat.v2.summary.scalar(name='total_abs_loss',
                                        data=total_abs_loss,
                                        step=self.train_step_counter)

        if self._summarize_grads_and_vars:
            with tf.name_scope('Variables/'):
                all_vars = (self._actor_net.trainable_weights +
                            self._value_net.trainable_weights)
                for var in all_vars:
                    # ':' in the variable name is replaced with '_' for the
                    # summary name.
                    tf.compat.v2.summary.histogram(
                        name=var.name.replace(':', '_'),
                        data=var,
                        step=self.train_step_counter)

        return loss_info
    def _update(self, experience, weight):
        """Replay stored experience step by step and complete a training step.

        The experience is unstacked along its first dimension into
        TensorArrays and fed one step at a time through the algorithm inside
        a `tf.while_loop`, all under a persistent gradient tape. The per-step
        training info is stacked again and passed to
        `self._algorithm.train_complete`.

        Args:
            experience: Nest of tensors; `experience.step_type.shape` supplies
                the number of steps (dim 0) and the batch size (dim 1).
            weight: Forwarded unchanged to `train_complete`.

        Returns:
            A tuple `(training_info, loss_info, grads_and_vars)`.
        """
        step_type_shape = experience.step_type.shape
        batch_size = step_type_shape[1]
        start_index = tf.zeros((), tf.int32)
        initial_train_state = common.get_initial_policy_state(
            batch_size, self._algorithm.train_state_spec)
        num_steps = step_type_shape[0]

        def _new_array(spec):
            # One slot per step.
            slot_shape = tf.TensorShape([batch_size]).concatenate(spec.shape)
            return tf.TensorArray(dtype=spec.dtype,
                                  size=num_steps,
                                  element_shape=slot_shape)

        # Unstack the experience so that step `i` can be read by index.
        empty_experience_ta = tf.nest.map_structure(
            _new_array, self._processed_experience_spec)
        experience_ta = tf.nest.map_structure(
            lambda tensor, arr: arr.unstack(tensor), experience,
            empty_experience_ta)
        training_info_ta = tf.nest.map_structure(_new_array,
                                                 self._training_info_spec)

        def _process_step(counter, policy_state, training_info_ta):
            """Train on the experience at index `counter` and record the result."""
            step_exp = tf.nest.map_structure(lambda arr: arr.read(counter),
                                             experience_ta)

            # Keep the raw parameters for the training info; the experience
            # itself gets distribution objects.
            collect_params = step_exp.action_distribution
            step_exp = step_exp._replace(
                action_distribution=nested_distributions_from_specs(
                    self._action_distribution_spec, collect_params))

            is_first_step = tf.equal(step_exp.step_type, StepType.FIRST)
            policy_state = common.reset_state_if_necessary(
                policy_state, initial_train_state, is_first_step)

            policy_step = common.algorithm_step(
                self._algorithm,
                self._observation_transformer,
                step_exp,
                policy_state,
                training=True)

            step_info = make_training_info(
                action=step_exp.action,
                action_distribution=common.get_distribution_params(
                    policy_step.action),
                reward=step_exp.reward,
                discount=step_exp.discount,
                step_type=step_exp.step_type,
                info=policy_step.info,
                collect_info=step_exp.info,
                collect_action_distribution=collect_params)

            training_info_ta = tf.nest.map_structure(
                lambda arr, value: arr.write(counter, value),
                training_info_ta, step_info)

            return [counter + 1, policy_step.state, training_info_ta]

        # Only the explicitly watched trainable variables are tracked.
        with tf.GradientTape(persistent=True,
                             watch_accessed_variables=False) as tape:
            tape.watch(self._trainable_variables)
            _, _, training_info_ta = tf.while_loop(
                cond=lambda index, *_: tf.less(index, num_steps),
                body=_process_step,
                loop_vars=[start_index, initial_train_state, training_info_ta],
                back_prop=True,
                name="train_loop")
            training_info = tf.nest.map_structure(lambda arr: arr.stack(),
                                                  training_info_ta)

            # Turn both sets of stacked parameters back into distributions.
            train_distribution = nested_distributions_from_specs(
                self._action_distribution_spec,
                training_info.action_distribution)
            collect_distribution = nested_distributions_from_specs(
                self._action_distribution_spec,
                training_info.collect_action_distribution)
            training_info = training_info._replace(
                action_distribution=train_distribution,
                collect_action_distribution=collect_distribution)

        loss_info, grads_and_vars = self._algorithm.train_complete(
            tape=tape, training_info=training_info, weight=weight)

        # Drop the reference to the persistent tape.
        del tape

        return training_info, loss_info, grads_and_vars
    def _prepare_specs(self, algorithm):
        """Derive and store the tensor specs needed by this driver.

        Runs `algorithm.predict`, `algorithm.preprocess_experience` and one
        training-mode algorithm step on initial/zero-valued inputs, and reads
        the specs off the results. Sets `_time_step_spec`, `_action_spec`,
        `_pred_policy_step_spec`, `_action_distribution_spec`,
        `_action_dist_param_spec`, `_experience_spec`,
        `_processed_experience_spec` and `_training_info_spec`.

        Args:
            algorithm: The algorithm whose outputs determine the specs.
        """

        def _drop_leading_dim(tensor):
            # The spec describes everything but the first dimension.
            return tf.TensorSpec(tensor.shape[1:], tensor.dtype)

        def extract_spec(nest):
            return tf.nest.map_structure(_drop_leading_dim, nest)

        def _as_distribution_spec(spec):
            # A plain tensor spec is wrapped as a Deterministic distribution
            # located at that tensor; anything else is passed through.
            if not isinstance(spec, tf.TensorSpec):
                return spec
            return DistributionSpec(tfp.distributions.Deterministic,
                                    input_params_spec={"loc": spec},
                                    sample_spec=spec)

        time_step = self.get_initial_time_step()
        self._time_step_spec = extract_spec(time_step)
        self._action_spec = self._env.action_spec()

        predict_step = algorithm.predict(time_step, self._initial_state)
        predict_info_spec = extract_spec(predict_step.info)
        self._pred_policy_step_spec = PolicyStep(
            action=self._action_spec,
            state=algorithm.predict_state_spec,
            info=predict_info_spec)

        self._action_distribution_spec = tf.nest.map_structure(
            _as_distribution_spec, algorithm.action_distribution_spec)
        self._action_dist_param_spec = tf.nest.map_structure(
            lambda dist_spec: dist_spec.input_params_spec,
            self._action_distribution_spec)

        ts_spec = self._time_step_spec
        self._experience_spec = Experience(
            step_type=ts_spec.step_type,
            reward=ts_spec.reward,
            discount=ts_spec.discount,
            observation=ts_spec.observation,
            prev_action=self._action_spec,
            action=self._action_spec,
            info=predict_info_spec,
            action_distribution=self._action_dist_param_spec)

        # Build a placeholder experience from the initial time step and
        # zero-valued distribution parameters.
        zero_dist_params = common.zero_tensor_from_nested_spec(
            self._experience_spec.action_distribution, self._env.batch_size)
        placeholder_exp = Experience(
            step_type=time_step.step_type,
            reward=time_step.reward,
            discount=time_step.discount,
            observation=time_step.observation,
            prev_action=time_step.prev_action,
            action=time_step.prev_action,
            info=predict_step.info,
            action_distribution=nested_distributions_from_specs(
                self._action_distribution_spec, zero_dist_params))

        processed_exp = algorithm.preprocess_experience(placeholder_exp)
        self._processed_experience_spec = self._experience_spec._replace(
            info=extract_spec(processed_exp.info))

        # One training-mode step yields the spec of the training info.
        initial_train_state = common.get_initial_policy_state(
            self._env.batch_size, algorithm.train_state_spec)
        train_step = common.algorithm_step(
            algorithm,
            ob_transformer=self._observation_transformer,
            time_step=placeholder_exp,
            state=initial_train_state,
            training=True)
        self._training_info_spec = make_training_info(
            action=self._action_spec,
            action_distribution=self._action_dist_param_spec,
            step_type=ts_spec.step_type,
            reward=ts_spec.reward,
            discount=ts_spec.discount,
            info=extract_spec(train_step.info),
            collect_info=self._processed_experience_spec.info,
            collect_action_distribution=self._action_dist_param_spec)
Beispiel #7
0
    def _train(self, experience, weights, train_step_counter):
        """Build the ops for `self._num_epochs` update epochs on a batch.

        The experience is split into transitions, and log-probabilities of the
        collected actions, value predictions, returns and advantages are
        computed once. One train op per epoch is then created, each under a
        control dependency on the previous epoch's loss info. The adaptive KL
        beta update, the normalizer updates and the returned loss info are
        chained after the last epoch through further control dependencies.

        Args:
            experience: Trajectory-like nest of tensors. It is sliced along
                axis 1 into `[:, :-1]` and `[:, 1:]` to form transitions; its
                `observation` and `step_type` fields are also fed to the
                value network.
            weights: Optional per-step weights. They are multiplied by the
                valid-timestep mask; if `None`, the mask itself is used.
            train_step_counter: Passed through to `self.build_train_op`.

        Returns:
            The loss info of the last epoch, with every element passed through
            `tf.identity` under a control dependency on the normalizer
            updates.
        """
        # Change trajectory to transitions.
        trajectory0 = nest.map_structure(lambda t: t[:, :-1], experience)
        trajectory1 = nest.map_structure(lambda t: t[:, 1:], experience)

        # Get individual tensors from transitions.
        (time_steps, policy_steps_,
         next_time_steps) = trajectory.to_transition(trajectory0, trajectory1)
        actions = policy_steps_.action
        if self._debug_summaries:
            tf.contrib.summary.histogram('actions', actions)

        # The policy info stored at collection time is used as the action
        # distribution parameters.
        action_distribution_parameters = policy_steps_.info

        # Reconstruct per-timestep policy distribution from stored distribution
        #   parameters.
        old_actions_distribution = (
            distribution_spec.nested_distributions_from_specs(
                self._action_distribution_spec,
                action_distribution_parameters))

        # Compute log probability of actions taken during data collection, using the
        #   collect policy distribution.
        act_log_probs = common_utils.log_probability(old_actions_distribution,
                                                     actions,
                                                     self._action_spec)

        # Compute the value predictions for states using the current value function.
        # To be used for return & advantage computation.
        batch_size = nest_utils.get_outer_shape(time_steps,
                                                self._time_step_spec)[0]
        policy_state = self._collect_policy.get_initial_state(
            batch_size=batch_size)

        value_preds, unused_policy_state = self._collect_policy.apply_value_network(
            experience.observation,
            experience.step_type,
            policy_state=policy_state)
        # Value predictions are treated as constants from here on.
        value_preds = tf.stop_gradient(value_preds)

        valid_mask = ppo_utils.make_timestep_mask(next_time_steps)

        if weights is None:
            weights = valid_mask
        else:
            weights *= valid_mask

        returns, normalized_advantages = self.compute_return_and_advantage(
            next_time_steps, value_preds)

        # Loss tensors across batches will be aggregated for summaries.
        policy_gradient_losses = []
        value_estimation_losses = []
        l2_regularization_losses = []
        entropy_regularization_losses = []
        kl_penalty_losses = []

        # For each epoch, create its own train op that depends on the previous one.
        # The no-op seeds the dependency chain for the first epoch.
        loss_info = tf.no_op()
        for i_epoch in range(self._num_epochs):
            with tf.name_scope('epoch_%d' % i_epoch):
                with tf.control_dependencies(nest.flatten(loss_info)):
                    # Only save debug summaries for first and last epochs.
                    debug_summaries = (self._debug_summaries and
                                       (i_epoch == 0
                                        or i_epoch == self._num_epochs - 1))

                    # Build one epoch train op.
                    loss_info = self.build_train_op(
                        time_steps, actions, act_log_probs, returns,
                        normalized_advantages, action_distribution_parameters,
                        weights, train_step_counter,
                        self._summarize_grads_and_vars,
                        self._gradient_clipping, debug_summaries)

                    # Record this epoch's loss components for the summed
                    # summaries built after the loop.
                    policy_gradient_losses.append(
                        loss_info.extra.policy_gradient_loss)
                    value_estimation_losses.append(
                        loss_info.extra.value_estimation_loss)
                    l2_regularization_losses.append(
                        loss_info.extra.l2_regularization_loss)
                    entropy_regularization_losses.append(
                        loss_info.extra.entropy_regularization_loss)
                    kl_penalty_losses.append(loss_info.extra.kl_penalty_loss)

        # After update epochs, update adaptive kl beta, then update observation
        #   normalizer and reward normalizer.
        with tf.control_dependencies(nest.flatten(loss_info)):
            # Compute the mean kl from old.
            batch_size = nest_utils.get_outer_shape(time_steps,
                                                    self._time_step_spec)[0]
            policy_state = self._collect_policy.get_initial_state(batch_size)
            kl_divergence = self._kl_divergence(
                time_steps, action_distribution_parameters,
                self._collect_policy.distribution(time_steps,
                                                  policy_state).action)
            update_adaptive_kl_beta_op = self.update_adaptive_kl_beta(
                kl_divergence)

        # Both normalizer updates are created under a dependency on the KL
        # beta update; a missing normalizer is replaced by a no-op.
        with tf.control_dependencies([update_adaptive_kl_beta_op]):
            if self._observation_normalizer:
                update_obs_norm = (self._observation_normalizer.update(
                    time_steps.observation, outer_dims=[0, 1]))
            else:
                update_obs_norm = tf.no_op()
            if self._reward_normalizer:
                update_reward_norm = self._reward_normalizer.update(
                    next_time_steps.reward, outer_dims=[0, 1])
            else:
                update_reward_norm = tf.no_op()

        # The returned loss info depends on both normalizer updates.
        with tf.control_dependencies([update_obs_norm, update_reward_norm]):
            loss_info = nest.map_structure(tf.identity, loss_info)

        # Make summaries for total loss across all epochs.
        # The *_losses lists will have been populated by
        #   calls to self.build_train_op.
        with tf.name_scope('Losses/'):
            total_policy_gradient_loss = tf.add_n(policy_gradient_losses)
            total_value_estimation_loss = tf.add_n(value_estimation_losses)
            total_l2_regularization_loss = tf.add_n(l2_regularization_losses)
            total_entropy_regularization_loss = tf.add_n(
                entropy_regularization_losses)
            total_kl_penalty_loss = tf.add_n(kl_penalty_losses)
            tf.contrib.summary.scalar('policy_gradient_loss',
                                      total_policy_gradient_loss)
            tf.contrib.summary.scalar('value_estimation_loss',
                                      total_value_estimation_loss)
            tf.contrib.summary.scalar('l2_regularization_loss',
                                      total_l2_regularization_loss)
            # The entropy summary is written only when entropy regularization
            # is set; the total below includes the entropy term regardless.
            if self._entropy_regularization:
                tf.contrib.summary.scalar('entropy_regularization_loss',
                                          total_entropy_regularization_loss)
            tf.contrib.summary.scalar('kl_penalty_loss', total_kl_penalty_loss)

            # Sum of absolute values, so components of opposite sign do not
            # cancel each other out in the summary.
            total_abs_loss = (tf.abs(total_policy_gradient_loss) +
                              tf.abs(total_value_estimation_loss) +
                              tf.abs(total_entropy_regularization_loss) +
                              tf.abs(total_l2_regularization_loss) +
                              tf.abs(total_kl_penalty_loss))

            tf.contrib.summary.scalar('total_abs_loss', total_abs_loss)

        if self._summarize_grads_and_vars:
            with tf.name_scope('Variables/'):
                all_vars = (self._actor_net.trainable_weights +
                            self._value_net.trainable_weights)
                for var in all_vars:
                    # ':' in the variable name is replaced with '_' for the
                    # summary name.
                    tf.contrib.summary.histogram(var.name.replace(':', '_'),
                                                 var)

        return loss_info
Beispiel #8
0
    def prepare_off_policy_specs(self, time_step: ActionTimeStep):
        """Derive and store the tensor specs used for off-policy training.

        Called by OffPolicyDriver._prepare_spec(). Runs a rollout, the
        experience preprocessing and one train step on initial/zero-valued
        inputs and reads the specs off the results. Sets `_env_batch_size`,
        `_time_step_spec`, `_action_distribution_spec`,
        `_action_dist_param_spec`, `_experience_spec`,
        `_processed_experience_spec` and `_training_info_spec`.

        Args:
            time_step: A time step whose first `step_type` dimension is taken
                as the environment batch size.
        """
        self._env_batch_size = time_step.step_type.shape[0]
        self._time_step_spec = common.extract_spec(time_step)

        initial_state = common.get_initial_policy_state(
            self._env_batch_size, self.train_state_spec)
        rollout_step = self.rollout(self.transform_timestep(time_step),
                                    initial_state)
        rollout_info_spec = common.extract_spec(rollout_step.info)

        self._action_distribution_spec = tf.nest.map_structure(
            common.to_distribution_spec, self.action_distribution_spec)
        self._action_dist_param_spec = tf.nest.map_structure(
            lambda dist_spec: dist_spec.input_params_spec,
            self._action_distribution_spec)

        # The rollout state is part of the experience only when
        # `_use_rollout_state` is set; otherwise an empty tuple is stored.
        ts_spec = self._time_step_spec
        self._experience_spec = Experience(
            step_type=ts_spec.step_type,
            reward=ts_spec.reward,
            discount=ts_spec.discount,
            observation=ts_spec.observation,
            prev_action=self._action_spec,
            action=self._action_spec,
            info=rollout_info_spec,
            action_distribution=self._action_dist_param_spec,
            state=self.train_state_spec if self._use_rollout_state else ())

        # Build a placeholder experience from the given time step and
        # zero-valued distribution parameters.
        zero_dist_params = common.zero_tensor_from_nested_spec(
            self._experience_spec.action_distribution, self._env_batch_size)
        placeholder_dist = nested_distributions_from_specs(
            self._action_distribution_spec, zero_dist_params)
        placeholder_exp = Experience(
            step_type=time_step.step_type,
            reward=time_step.reward,
            discount=time_step.discount,
            observation=time_step.observation,
            prev_action=time_step.prev_action,
            action=time_step.prev_action,
            info=rollout_step.info,
            action_distribution=placeholder_dist,
            state=initial_state if self._use_rollout_state else ())

        processed_exp = self.preprocess_experience(
            self.transform_timestep(placeholder_exp))
        self._processed_experience_spec = self._experience_spec._replace(
            observation=common.extract_spec(processed_exp.observation),
            info=common.extract_spec(processed_exp.info))

        # One train step yields the spec of the training info.
        train_policy_step = common.algorithm_step(
            algorithm_step_func=self.train_step,
            time_step=processed_exp,
            state=initial_state)
        self._training_info_spec = TrainingInfo(
            action_distribution=self._action_dist_param_spec,
            info=common.extract_spec(train_policy_step.info))