def model_fn(self, features, labels, mode, params):
  """The implementation of the PPO algorithm.

  Args:
    features: dict from string to tensor with shape
      'mcts_features': [BATCH_SIZE, env.state_space]
    labels: dict from string to tensor with shape
      'mcts_enable_tensor': [BATCH_SIZE]
      'mean_tensor': [BATCH_SIZE, self._env_action_space]
      'logstd_tensor': [BATCH_SIZE, self._env_action_space]
      'value_tensor': [BATCH_SIZE]
      'return_tensor': [BATCH_SIZE]
      'policy_value_tensor': [BATCH_SIZE]
      'policy_return_tensor': [BATCH_SIZE]
      'policy_old_neg_logprob_tensor': [BATCH_SIZE]
      'policy_action_tensor': [BATCH_SIZE, self._env_action_space]
    mode: a tf.estimator.ModeKeys (batchnorm params update for TRAIN only).
    params: (Ignored; needed for compat with TPUEstimator).

  Returns:
    tf.estimator.EstimatorSpec with props.
      mode: same as mode arg.
      predictions: dict of tensors
        'mean': [BATCH_SIZE, self._env_action_space]
        'logstd': [BATCH_SIZE, self._env_action_space]
        'value': [BATCH_SIZE]
        'action': [BATCH_SIZE, self._env_action_space]
        'neg_logprob': [BATCH_SIZE, self._env_action_space]
      loss: a single value tensor.
      train_op: train op.
      eval_metric_ops: dict of metric tensors.
  """
  # Policy network
  network_out = self.model_inference_fn_ppo(features['mcts_features'], 'new')
  self.value_new = network_out[0]
  self.logstd_new = network_out[1]
  self.mean_new = network_out[2]
  self.global_step = tf.train.get_or_create_global_step()

  # Sample an action from the current policy distribution.
  pd_new = distributions.MultiVariateNormalDiag(
      mean=self.mean_new, logstd=self.logstd_new)
  action_sample = pd_new.sample()
  action_sample_neg_logprob = pd_new.negative_log_prob(action_sample)

  # Used during TF estimator prediction.
  if mode == tf_estimator.ModeKeys.PREDICT:
    predictions = {
        'mean': self.mean_new,
        'logstd': self.logstd_new,
        'value': self.value_new,
        'action': action_sample,
        'neg_logprob': action_sample_neg_logprob
    }
    pred_estimator = tf_estimator.tpu.TPUEstimatorSpec(
        mode,
        predictions=predictions,
        export_outputs={
            'ppo_inference':
                tf_estimator.export.PredictOutput({
                    'mean': self.mean_new,
                    'logstd': self.logstd_new,
                    'value': self.value_new,
                    'action': action_sample,
                    'neg_logprob': action_sample_neg_logprob
                })
        })
    return pred_estimator.as_estimator_spec()

  # True only when every example in the batch was produced by MCTS sampling.
  self.mcts_sampling_enable = tf.reduce_all(labels['mcts_enable_tensor'])

  self.mean_old = labels['mean_tensor']
  self.logstd_old = labels['logstd_tensor']
  pd_old = distributions.MultiVariateNormalDiag(
      mean=self.mean_old, logstd=self.logstd_old)

  batch_advantage_norm = self.calc_normalized_advantage(
      return_tensor=labels['policy_return_tensor'],
      value_tensor=labels['policy_value_tensor'])

  self.compute_total_loss(pd_new, pd_old, labels['value_tensor'],
                          labels['return_tensor'], batch_advantage_norm,
                          labels['policy_old_neg_logprob_tensor'],
                          labels['policy_action_tensor'])

  # Update the learning rate.
  self.update_learning_rate()

  # Build the training operation.
  self.total_params = tf.trainable_variables(scope='newpolicy')
  train_ops = self.build_training_op(self.total_loss)

  host_call = self.create_host_call_fn(params)

  if mode != tf_estimator.ModeKeys.TRAIN:
    raise ValueError('Estimator mode should be train at this point.')

  if mode == tf_estimator.ModeKeys.TRAIN:
    # Set up the fine-tuning scaffold. The scaffold is used to restore the
    # weights from _warmstart_file. If _warmstart_file is None, training
    # starts from the beginning.
    if self._warmstart_file:
      logging.info('Warmstart')

      def tpu_scaffold():
        # Restore all variables under the 'newpolicy' scope from the
        # warm-start checkpoint.
        tf.train.init_from_checkpoint(self._warmstart_file,
                                      {'newpolicy/': 'newpolicy/'})
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      scaffold_fn = None

    tpu_estimator_spec = tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=self.total_loss,
        train_op=train_ops,
        host_call=host_call,
        scaffold_fn=scaffold_fn)
    if self._use_tpu:
      return tpu_estimator_spec
    else:
      return tpu_estimator_spec.as_estimator_spec()
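
# Illustrative sketch, not used by the estimator above: a plain-NumPy model of
# the math that `distributions.MultiVariateNormalDiag` is assumed to implement
# for `pd_new.sample()` and `pd_new.negative_log_prob()` (a standard diagonal
# Gaussian; the project's actual class may differ in details such as output
# shapes).
import numpy as np


def _diag_gaussian_sample(mean, logstd, rng=np.random):
  """Reparameterized sample: mean + std * eps, with eps ~ N(0, I)."""
  return mean + np.exp(logstd) * rng.standard_normal(mean.shape)


def _diag_gaussian_neg_logprob(action, mean, logstd):
  """Negative log density of a diagonal Gaussian, summed over action dims."""
  num_dims = action.shape[-1]
  return (0.5 * np.sum(np.square((action - mean) / np.exp(logstd)), axis=-1)
          + 0.5 * num_dims * np.log(2.0 * np.pi)
          + np.sum(logstd, axis=-1))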
def compute_total_loss(self, pd_new, pd_old, value_tensor, return_tensor,
                       batch_advantage_norm, policy_old_neg_logprob_tensor,
                       policy_action_tensor):
  """Defines the total loss function.

  Args:
    pd_new: The current policy distribution (a multivariate normal
      distribution). This policy distribution gets updated in the course of
      training.
    pd_old: The old policy distribution that we used when sampling the
      trajectory (a multivariate normal distribution).
    value_tensor: The values associated with the rollout trajectory.
    return_tensor: The return values computed for the rollout trajectory.
    batch_advantage_norm: The normalized advantage tensor computed for a
      batch of data. The advantages are computed with the generalized
      advantage estimation (GAE) formula.
    policy_old_neg_logprob_tensor: The negative log probabilities from the
      policy rollouts.
    policy_action_tensor: The actions from the policy rollouts.
  """
  # Policy loss
  ppo_policy_loss_out = ppo_loss.ppo_policy_loss(
      neg_logprobs_old=policy_old_neg_logprob_tensor,
      actions=policy_action_tensor,
      advantages=batch_advantage_norm,
      dist_new=pd_new,
      mcts_sampling=self.mcts_sampling_enable)
  (self.policy_loss, self.approxkl, self.clipfrac,
   self.policy_ratio) = ppo_policy_loss_out

  # Value loss
  if self._ppo2_enable:
    self.value_loss = ppo_loss.ppo2_value_loss(
        value_old=value_tensor,
        pred_value=self.value_new,
        returns=return_tensor)
  else:
    self.value_loss = ppo_loss.ppo1_value_loss(
        pred_value=self.value_new, returns=return_tensor)

  # MSE loss between the means and the standard deviations.
  self.mean_mse_loss, self.logstd_mse_loss = ppo_loss.l2_norm_policy_loss(
      policy_mean=self.mean_new,
      policy_logstd=self.logstd_new,
      mcts_mean=self.mean_old,
      mcts_logstd=self.logstd_old)

  # KL divergence between the MCTS policy and the current policy, used as an
  # imitation signal when the batch comes from MCTS sampling.
  mcts_dist = distributions.MultiVariateNormalDiag(
      mean=self.mean_old, logstd=self.logstd_old)
  policy_dist = distributions.MultiVariateNormalDiag(
      mean=self.mean_new, logstd=self.logstd_new)
  self.imitation_kl_divergence = tf.reduce_mean(
      policy_dist.kl_divergence(mcts_dist))

  # Calculate the KL divergence and entropy of the new distribution.
  self.kl_divergence = tf.reduce_mean(pd_new.kl_divergence(pd_old))
  self.entropy = pd_new.entropy()

  # Calculate the entropy loss.
  self.entropy_loss = tf.reduce_mean(self.entropy)

  # Calculate the total loss. Depending on whether the batch was produced by
  # MCTS sampling or by policy rollouts, a different loss is used.
  total_loss_ppo = (self._policy_coeff * self.policy_loss) + (
      self._value_coeff * self.value_loss) - (
          self._entropy_coeff * self.entropy_loss)
  total_loss_mcts = (self._value_coeff * self.value_loss) + (
      self._mse_loss_coeff *
      (self.imitation_kl_divergence + self.entropy_loss))
  self.total_loss = tf.cond(self.mcts_sampling_enable,
                            lambda: total_loss_mcts,
                            lambda: total_loss_ppo)
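
# Illustrative sketch, not used above: the standard PPO clipped surrogate
# objective (Schulman et al., 2017) that `ppo_loss.ppo_policy_loss` is assumed
# to implement; the project's function additionally returns approxkl, clipfrac
# and the probability ratio, and handles the MCTS sampling flag.
import numpy as np


def _ppo_clip_policy_loss(neg_logprobs_new, neg_logprobs_old, advantages,
                          clip_eps=0.2):
  """Clipped surrogate policy loss (negated objective, so lower is better)."""
  # ratio = pi_new(a|s) / pi_old(a|s) = exp(neg_logprob_old - neg_logprob_new)
  ratio = np.exp(neg_logprobs_old - neg_logprobs_new)
  unclipped = ratio * advantages
  clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
  # Elementwise minimum of the two surrogates; once the ratio leaves the
  # [1 - clip_eps, 1 + clip_eps] band in the favorable direction, the clipped
  # term dominates and its gradient vanishes, which bounds the policy update.
  return -np.mean(np.minimum(unclipped, clipped))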