    def model_fn(self, features, labels, mode, params):
        """Implements the PPO algorithm as a (TPU)Estimator `model_fn`.

        Args:
          features: dict from string to tensor with shape
            'mcts_features': [BATCH_SIZE, env.state_space]
          labels: dict from string to tensor. The keys consumed here are:
            'mcts_enable_tensor': whether MCTS sampling was used for the batch.
            'mean_tensor', 'logstd_tensor': parameters of the old policy
              distribution, [BATCH_SIZE, self._env_action_space].
            'value_tensor', 'return_tensor': values and returns of the rollout
              trajectory, [BATCH_SIZE].
            'policy_value_tensor', 'policy_return_tensor': values and returns
              used for advantage normalization, [BATCH_SIZE].
            'policy_old_neg_logprob_tensor': negative log probabilities from
              the policy rollouts.
            'policy_action_tensor': actions from the policy rollouts,
              [BATCH_SIZE, self._env_action_space].
          mode: a tf.estimator.ModeKeys (batchnorm params are updated in TRAIN
            mode only).
          params: ignored; needed for compatibility with TPUEstimator.

        Returns:
          A TPUEstimatorSpec (converted to a tf.estimator.EstimatorSpec when
          not running on TPU) with:
            mode: same as the mode arg.
            predictions: dict of tensors
              'mean': [BATCH_SIZE, self._env_action_space]
              'logstd': [BATCH_SIZE, self._env_action_space]
              'value': [BATCH_SIZE]
              'action': [BATCH_SIZE, self._env_action_space]
              'neg_logprob': [BATCH_SIZE, self._env_action_space]
            loss: a scalar loss tensor.
            train_op: the training op (TRAIN mode only).
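
        Example:
          An illustrative sketch of wiring this model_fn into an estimator
          (the agent, input_fn, and configuration values below are
          placeholders, not part of this module):

            agent = ...  # an instance of this class
            estimator = tf_estimator.tpu.TPUEstimator(
                model_fn=agent.model_fn,
                config=tf_estimator.tpu.RunConfig(model_dir='/tmp/ppo'),
                use_tpu=False,
                train_batch_size=64,
                predict_batch_size=64)
            estimator.train(input_fn=train_input_fn, max_steps=1000)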
    """

        # Forward pass of the 'new' policy/value network.
        network_out = self.model_inference_fn_ppo(features['mcts_features'],
                                                  'new')
        self.value_new = network_out[0]
        self.logstd_new = network_out[1]
        self.mean_new = network_out[2]

        self.global_step = tf.train.get_or_create_global_step()
        # Sample an action
        pd_new = distributions.MultiVariateNormalDiag(mean=self.mean_new,
                                                      logstd=self.logstd_new)
        action_sample = pd_new.sample()
        action_sample_neg_logprob = pd_new.negative_log_prob(action_sample)
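        # Assuming a standard diagonal-Gaussian implementation of
        # MultiVariateNormalDiag, these correspond to
        #   sample    = mean + exp(logstd) * eps,  eps ~ N(0, I)
        #   -log p(a) = 0.5 * sum(((a - mean) / exp(logstd))**2
        #                         + 2 * logstd + log(2 * pi))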

        # Used during TF estimator prediction
        if mode == tf_estimator.ModeKeys.PREDICT:
            predictions = {
                'mean': self.mean_new,
                'logstd': self.logstd_new,
                'value': self.value_new,
                'action': action_sample,
                'neg_logprob': action_sample_neg_logprob
            }
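            # Illustrative only: when this spec is used through
            # `estimator.predict()`, each yielded element is a dict with these
            # keys, e.g.
            #   for pred in estimator.predict(input_fn=predict_input_fn):
            #     action = pred['action']
            # (`predict_input_fn` is a placeholder name.)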
            pred_estimator = tf_estimator.tpu.TPUEstimatorSpec(
                mode,
                predictions=predictions,
                export_outputs={
                    'ppo_inference':
                        tf_estimator.export.PredictOutput({
                            'mean': self.mean_new,
                            'logstd': self.logstd_new,
                            'value': self.value_new,
                            'action': action_sample,
                            'neg_logprob': action_sample_neg_logprob
                        })
                })
            return pred_estimator.as_estimator_spec()

        # True only if MCTS sampling was enabled for the entire batch.
        self.mcts_sampling_enable = tf.reduce_all(labels['mcts_enable_tensor'])

        self.mean_old = labels['mean_tensor']
        self.logstd_old = labels['logstd_tensor']
        pd_old = distributions.MultiVariateNormalDiag(mean=self.mean_old,
                                                      logstd=self.logstd_old)

        batch_advantage_norm = self.calc_normalized_advantage(
            return_tensor=labels['policy_return_tensor'],
            value_tensor=labels['policy_value_tensor'])
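        # batch_advantage_norm holds the per-batch normalized advantages
        # (GAE-based, per the compute_total_loss docstring) that weight the
        # policy-gradient term of the loss.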

        self.compute_total_loss(pd_new, pd_old, labels['value_tensor'],
                                labels['return_tensor'], batch_advantage_norm,
                                labels['policy_old_neg_logprob_tensor'],
                                labels['policy_action_tensor'])
        # Update learning rate
        self.update_learning_rate()

        # Build training operation
        self.total_params = tf.trainable_variables(scope='newpolicy')

        train_ops = self.build_training_op(self.total_loss)

        host_call = self.create_host_call_fn(params)

        if mode != tf_estimator.ModeKeys.TRAIN:
            raise ValueError('Estimator mode should be train at this point.')

        # Set up the fine-tuning scaffold. The scaffold restores the weights
        # from self._warmstart_file; if that is None, training starts from
        # scratch.
        if self._warmstart_file:
            logging.info('Warmstart')

            def tpu_scaffold():
                # Restore all variables under the 'newpolicy/' scope from the
                # warm-start checkpoint into the same scope in this graph.
                tf.train.init_from_checkpoint(self._warmstart_file,
                                              {'newpolicy/': 'newpolicy/'})
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            scaffold_fn = None

        tpu_estimator_spec = tf_estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=self.total_loss,
            train_op=train_ops,
            host_call=host_call,
            scaffold_fn=scaffold_fn)
        if self._use_tpu:
            return tpu_estimator_spec
        return tpu_estimator_spec.as_estimator_spec()

    def compute_total_loss(self, pd_new, pd_old, value_tensor, return_tensor,
                           batch_advantage_norm, policy_old_neg_logprob_tensor,
                           policy_action_tensor):
        """Defines the total loss function.

    Args:
      pd_new: The current policy distribution
        (a multivariate normal distribution). This policy distribution gets
        updated in the course of training.
      pd_old: The old policy distribution that we use during sampling the
        trajectory (a multivariate normal distribution).
      value_tensor: The values associated to the rollout trajectory.
      return_tensor: The return values computed for the rollout trajectory.
      batch_advantage_norm: The normalized advantage tensor computed for a
        batch of data. For advantage calculation, we use generalized
        advantage estimation (GAE) formula.
      policy_old_neg_logprob_tensor: The negative log probabilities from the
        policy rollouts.
      policy_action_tensor: The actions from the policy rollouts.
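
        Depending on the sampling mode, the method combines the individual
        terms into the total loss as follows (mirroring the code below):

          PPO sampling:
            total_loss = policy_coeff * policy_loss
                         + value_coeff * value_loss
                         - entropy_coeff * entropy_loss

          MCTS sampling:
            total_loss = value_coeff * value_loss
                         + mse_loss_coeff * (imitation_kl_divergence
                                             + entropy_loss)

        where the coefficients are the corresponding self._*_coeff attributes.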
    """
        # Policy loss
        ppo_policy_loss_out = ppo_loss.ppo_policy_loss(
            neg_logprobs_old=policy_old_neg_logprob_tensor,
            actions=policy_action_tensor,
            advantages=batch_advantage_norm,
            dist_new=pd_new,
            mcts_sampling=self.mcts_sampling_enable)

        (self.policy_loss, self.approxkl, self.clipfrac,
         self.policy_ratio) = ppo_policy_loss_out
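        # For reference: in the standard PPO clipped-surrogate objective
        # (which `ppo_loss.ppo_policy_loss` is assumed to follow), with
        # probability ratio r = exp(neg_logprobs_old - neg_logprob_new):
        #   policy_loss = -mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A))
        # `approxkl` and `clipfrac` are the usual diagnostics: an estimate of
        # the KL between the old and new policies, and the fraction of ratios
        # that hit the clipping range.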

        # Value Loss
        if self._ppo2_enable:
            self.value_loss = ppo_loss.ppo2_value_loss(
                value_old=value_tensor,
                pred_value=self.value_new,
                returns=return_tensor)
        else:
            self.value_loss = ppo_loss.ppo1_value_loss(
                pred_value=self.value_new, returns=return_tensor)
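        # For reference: a PPO2-style value loss typically clips the predicted
        # value around the rollout value before the squared error, roughly
        #   0.5 * mean(max((v - R)^2, (clip(v, v_old - eps, v_old + eps) - R)^2))
        # whereas the PPO1-style loss is a plain squared error against the
        # returns. The exact forms are defined in `ppo_loss`.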

        # L2 (MSE) losses between the policy's mean / log-std and the MCTS
        # policy's mean / log-std.
        self.mean_mse_loss, self.logstd_mse_loss = ppo_loss.l2_norm_policy_loss(
            policy_mean=self.mean_new,
            policy_logstd=self.logstd_new,
            mcts_mean=self.mean_old,
            mcts_logstd=self.logstd_old)

        mcts_dist = distributions.MultiVariateNormalDiag(
            mean=self.mean_old, logstd=self.logstd_old)
        policy_dist = distributions.MultiVariateNormalDiag(
            mean=self.mean_new, logstd=self.logstd_new)
        self.imitation_kl_divergence = tf.reduce_mean(
            policy_dist.kl_divergence(mcts_dist))
        # Calculate KL divergence and entropy of new distribution
        self.kl_divergence = tf.reduce_mean(pd_new.kl_divergence(pd_old))
        self.entropy = pd_new.entropy()

        # Calculate entropy loss
        self.entropy_loss = tf.reduce_mean(self.entropy)
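        # The batch-averaged entropy acts as an exploration bonus in the PPO
        # objective below (it is subtracted from the loss) and as a
        # regularizer added alongside the imitation KL term in the MCTS
        # objective.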

        # Calculate total loss
        total_loss_ppo = (self._policy_coeff * self.policy_loss) + (
            self._value_coeff * self.value_loss) - (self._entropy_coeff *
                                                    self.entropy_loss)

        total_loss_mcts = (self._value_coeff * self.value_loss) + (
            self._mse_loss_coeff *
            (self.imitation_kl_divergence + self.entropy_loss))

        self.total_loss = tf.cond(self.mcts_sampling_enable,
                                  lambda: total_loss_mcts,
                                  lambda: total_loss_ppo)