Example #1
    def _build_policy_loss(self, i):
        """ Build policy network loss """
        pol_dist = self.policy._dist

        # Entropy terms
        embedding_entropy, inference_ce, policy_entropy = \
            self._build_entropy_terms(i)

        # Augment the path rewards with entropy terms
        with tf.name_scope("augmented_rewards"):
            rewards = i.reward_var \
                      - (self.inference_ce_coeff * inference_ce) \
                      + (self.policy_ent_coeff * policy_entropy)

        with tf.name_scope("policy_loss"):
            with tf.name_scope("advantages"):
                advantages = compute_advantages(self.discount,
                                                self.gae_lambda,
                                                self.max_path_length,
                                                i.baseline_var,
                                                rewards,
                                                name="advantages")

                # Flatten and filter valids
                adv_flat = flatten_batch(advantages, name="adv_flat")
                adv_valid = filter_valids(adv_flat,
                                          i.flat.valid_var,
                                          name="adv_valid")

            policy_dist_info_flat = self.policy.dist_info_sym(
                i.flat.task_var,
                i.flat.obs_var,
                i.flat.policy_state_info_vars,
                name="policy_dist_info_flat")
            policy_dist_info_valid = filter_valids_dict(
                policy_dist_info_flat,
                i.flat.valid_var,
                name="policy_dist_info_valid")

            # Optionally normalize advantages
            eps = tf.constant(1e-8, dtype=tf.float32)
            if self.center_adv:
                with tf.name_scope("center_adv"):
                    mean, var = tf.nn.moments(adv_valid, axes=[0])
                    adv_valid = tf.nn.batch_normalization(
                        adv_valid, mean, var, 0, 1, eps)
            if self.positive_adv:
                with tf.name_scope("positive_adv"):
                    m = tf.reduce_min(adv_valid)
                    adv_valid = (adv_valid - m) + eps

            # Calculate loss function and KL divergence
            with tf.name_scope("kl"):
                kl = pol_dist.kl_sym(
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                )
                pol_mean_kl = tf.reduce_mean(kl)

            # Calculate surrogate loss
            with tf.name_scope("surr_loss"):
                lr = pol_dist.likelihood_ratio_sym(
                    i.valid.action_var,
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                    name="lr")

                # Policy gradient surrogate objective
                surr_vanilla = lr * adv_valid

                if self._pg_loss == PGLoss.VANILLA:
                    # VPG, TRPO use the standard surrogate objective
                    surr_obj = tf.identity(surr_vanilla, name="surr_obj")
                elif self._pg_loss == PGLoss.CLIP:
                    # PPO uses a surrogate objective with clipped LR
                    lr_clip = tf.clip_by_value(lr,
                                               1 - self.lr_clip_range,
                                               1 + self.lr_clip_range,
                                               name="lr_clip")
                    surr_clip = lr_clip * adv_valid
                    surr_obj = tf.minimum(surr_vanilla,
                                          surr_clip,
                                          name="surr_obj")
                else:
                    raise NotImplementedError("Unknown PGLoss")

                # Maximize E[surrogate objective] by minimizing
                # -E_t[surrogate objective]
                surr_loss = -tf.reduce_mean(surr_obj)

                # Embedding entropy bonus
                surr_loss -= self.embedding_ent_coeff * embedding_entropy

            embed_mean_kl = self._build_embedding_kl(i)

        # Diagnostic functions
        self.f_policy_kl = tensor_utils.compile_function(
            flatten_inputs(self._policy_opt_inputs),
            pol_mean_kl,
            log_name="f_policy_kl")

        self.f_rewards = tensor_utils.compile_function(
            flatten_inputs(self._policy_opt_inputs),
            rewards,
            log_name="f_rewards")

        # returns = self._build_returns(rewards)
        returns = discounted_returns(self.discount,
                                     self.max_path_length,
                                     rewards,
                                     name="returns")
        self.f_returns = tensor_utils.compile_function(
            flatten_inputs(self._policy_opt_inputs),
            returns,
            log_name="f_returns")

        return surr_loss, pol_mean_kl, embed_mean_kl
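
The graph built in Example #1 encodes the PPO-style clipped surrogate over entropy-augmented rewards. As a rough orientation, the following is a minimal eager NumPy sketch of the same objective; the names new_logp, old_logp and adv are illustrative placeholders, not garage symbols.

import numpy as np

def clipped_surrogate_loss(new_logp, old_logp, adv, clip_range=0.2):
    """Eager analogue of the surr_loss tensor built above."""
    lr = np.exp(new_logp - old_logp)               # likelihood ratio
    surr_vanilla = lr * adv                        # standard surrogate
    lr_clip = np.clip(lr, 1 - clip_range, 1 + clip_range)
    surr_clip = lr_clip * adv                      # clipped surrogate
    surr_obj = np.minimum(surr_vanilla, surr_clip)
    return -np.mean(surr_obj)                      # minimize the negative mean

# Toy usage with random per-timestep quantities.
rng = np.random.default_rng(0)
new_logp = rng.normal(size=100)
old_logp = new_logp + rng.normal(scale=0.1, size=100)
adv = rng.normal(size=100)
print(clipped_surrogate_loss(new_logp, old_logp, adv))
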
Example #2
    def _build_policy_loss(self, i):
        pol_dist = self.policy.distribution

        policy_entropy = self._build_entropy_term(i)

        with tf.name_scope("augmented_rewards"):
            rewards = i.reward_var + (self.policy_ent_coeff * policy_entropy)

        with tf.name_scope("policy_loss"):
            advantages = compute_advantages(
                self.discount,
                self.gae_lambda,
                self.max_path_length,
                i.baseline_var,
                rewards,
                name="advantages")

            adv_flat = flatten_batch(advantages, name="adv_flat")
            adv_valid = filter_valids(
                adv_flat, i.flat.valid_var, name="adv_valid")

            if self.policy.recurrent:
                advantages = tf.reshape(advantages, [-1, self.max_path_length])

            # Optionally normalize advantages
            eps = tf.constant(1e-8, dtype=tf.float32)
            if self.center_adv:
                with tf.name_scope("center_adv"):
                    mean, var = tf.nn.moments(adv_valid, axes=[0])
                    adv_valid = tf.nn.batch_normalization(
                        adv_valid, mean, var, 0, 1, eps)
            if self.positive_adv:
                with tf.name_scope("positive_adv"):
                    m = tf.reduce_min(adv_valid)
                    adv_valid = (adv_valid - m) + eps

            if self.policy.recurrent:
                policy_dist_info = self.policy.dist_info_sym(
                    i.obs_var,
                    i.policy_state_info_vars,
                    name="policy_dist_info")
            else:
                policy_dist_info_flat = self.policy.dist_info_sym(
                    i.flat.obs_var,
                    i.flat.policy_state_info_vars,
                    name="policy_dist_info_flat")

                policy_dist_info_valid = filter_valids_dict(
                    policy_dist_info_flat,
                    i.flat.valid_var,
                    name="policy_dist_info_valid")

            # Calculate loss function and KL divergence
            with tf.name_scope("kl"):
                if self.policy.recurrent:
                    kl = pol_dist.kl_sym(
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                    )
                    pol_mean_kl = tf.reduce_sum(
                        kl * i.valid_var) / tf.reduce_sum(i.valid_var)
                else:
                    kl = pol_dist.kl_sym(
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                    )
                    pol_mean_kl = tf.reduce_mean(kl)

            # Calculate vanilla loss
            with tf.name_scope("vanilla_loss"):
                if self.policy.recurrent:
                    ll = pol_dist.log_likelihood_sym(
                        i.action_var, policy_dist_info, name="log_likelihood")

                    vanilla = ll * advantages * i.valid_var
                else:
                    ll = pol_dist.log_likelihood_sym(
                        i.valid.action_var,
                        policy_dist_info_valid,
                        name="log_likelihood")

                    vanilla = ll * adv_valid

            # Calculate surrogate loss
            with tf.name_scope("surrogate_loss"):
                if self.policy.recurrent:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.action_var,
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                        name="lr")

                    surrogate = lr * advantages * i.valid_var
                else:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.valid.action_var,
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                        name="lr")

                    surrogate = lr * adv_valid

            # Finalize objective function
            with tf.name_scope("loss"):
                if self._pg_loss == PGLoss.VANILLA:
                    # VPG uses the vanilla objective
                    obj = tf.identity(vanilla, name="vanilla_obj")
                elif self._pg_loss == PGLoss.SURROGATE:
                    # TRPO uses the standard surrogate objective
                    obj = tf.identity(surrogate, name="surr_obj")
                elif self._pg_loss == PGLoss.SURROGATE_CLIP:
                    lr_clip = tf.clip_by_value(
                        lr,
                        1 - self.lr_clip_range,
                        1 + self.lr_clip_range,
                        name="lr_clip")
                    if self.policy.recurrent:
                        surr_clip = lr_clip * advantages * i.valid_var
                    else:
                        surr_clip = lr_clip * adv_valid
                    obj = tf.minimum(surrogate, surr_clip, name="surr_obj")
                else:
                    raise NotImplementedError("Unknown PGLoss")

                # Maximize E[surrogate objective] by minimizing
                # -E_t[surrogate objective]
                if self.policy.recurrent:
                    loss = -tf.reduce_sum(obj) / tf.reduce_sum(i.valid_var)
                else:
                    loss = -tf.reduce_mean(obj)

            # Diagnostic functions
            self.f_policy_kl = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                pol_mean_kl,
                log_name="f_policy_kl")

            self.f_rewards = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                rewards,
                log_name="f_rewards")

            returns = discounted_returns(self.discount, self.max_path_length,
                                         rewards)
            self.f_returns = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                returns,
                log_name="f_returns")

            return loss, pol_mean_kl
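
compute_advantages appears in every example but is used as a black box. Assuming it implements generalized advantage estimation (GAE), a minimal NumPy version for a single finite-length path could look like the sketch below; the function name and signature are illustrative, not the garage API.

import numpy as np

def gae_advantages(rewards, baselines, discount=0.99, gae_lambda=0.97):
    """GAE for one path: A_t = sum_k (discount * gae_lambda)^k * delta_{t+k}."""
    # Value estimate after the final step is taken to be zero (episode ends).
    next_values = np.append(baselines[1:], 0.0)
    deltas = rewards + discount * next_values - baselines
    advantages = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = deltas[t] + discount * gae_lambda * running
        advantages[t] = running
    return advantages

rewards = np.array([1.0, 0.0, 0.5, 1.0])
baselines = np.array([0.8, 0.6, 0.7, 0.9])
print(gae_advantages(rewards, baselines))
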
Example #3
    def _build_policy_loss(self, i):
        """Build policy loss and other output tensors.

        Args:
            i (namedtuple): Collection of variables to compute policy loss.

        Returns:
            tf.Tensor: Policy loss.
            tf.Tensor: Mean policy KL divergence.

        """
        policy_entropy = self._build_entropy_term(i)
        rewards = i.reward_var

        if self._maximum_entropy:
            with tf.name_scope('augmented_rewards'):
                rewards = i.reward_var + (self._policy_ent_coeff *
                                          policy_entropy)

        with tf.name_scope('policy_loss'):
            adv = compute_advantages(self._discount,
                                     self._gae_lambda,
                                     self.max_path_length,
                                     i.baseline_var,
                                     rewards,
                                     name='adv')

            adv = tf.reshape(adv, [-1, self.max_path_length])
            # Optionally normalize advantages
            eps = tf.constant(1e-8, dtype=tf.float32)
            if self._center_adv:
                adv = center_advs(adv, axes=[0], eps=eps)

            if self._positive_adv:
                adv = positive_advs(adv, eps)

            with tf.name_scope('kl'):
                kl = self._old_policy.distribution.kl_divergence(
                    self.policy.distribution)
                pol_mean_kl = tf.reduce_mean(kl)

            # Calculate vanilla loss
            with tf.name_scope('vanilla_loss'):
                ll = self.policy.distribution.log_prob(i.action_var,
                                                       name='log_likelihood')
                vanilla = ll * adv

            # Calculate surrogate loss
            with tf.name_scope('surrogate_loss'):
                lr = tf.exp(
                    ll - self._old_policy.distribution.log_prob(i.action_var))
                surrogate = lr * adv

            # Finalize objective function
            with tf.name_scope('loss'):
                if self._pg_loss == 'vanilla':
                    # VPG uses the vanilla objective
                    obj = tf.identity(vanilla, name='vanilla_obj')
                elif self._pg_loss == 'surrogate':
                    # TRPO uses the standard surrogate objective
                    obj = tf.identity(surrogate, name='surr_obj')
                elif self._pg_loss == 'surrogate_clip':
                    lr_clip = tf.clip_by_value(lr,
                                               1 - self._lr_clip_range,
                                               1 + self._lr_clip_range,
                                               name='lr_clip')
                    surr_clip = lr_clip * adv
                    obj = tf.minimum(surrogate, surr_clip, name='surr_obj')

                if self._entropy_regularzied:
                    obj += self._policy_ent_coeff * policy_entropy

                # filter only the valid values
                obj = tf.boolean_mask(obj, i.valid_var)
                # Maximize E[surrogate objective] by minimizing
                # -E_t[surrogate objective]
                loss = -tf.reduce_mean(obj)

            # Diagnostic functions
            self._f_policy_kl = tf.compat.v1.get_default_session(
            ).make_callable(pol_mean_kl,
                            feed_list=flatten_inputs(self._policy_opt_inputs))

            self._f_rewards = tf.compat.v1.get_default_session().make_callable(
                rewards, feed_list=flatten_inputs(self._policy_opt_inputs))

            returns = discounted_returns(self._discount, self.max_path_length,
                                         rewards)
            self._f_returns = tf.compat.v1.get_default_session().make_callable(
                returns, feed_list=flatten_inputs(self._policy_opt_inputs))

            return loss, pol_mean_kl
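
Examples #3 to #5 replace the hand-built tf.nn.batch_normalization block with the center_advs and positive_advs helpers. Conceptually both are simple elementwise transforms; the NumPy sketch below shows what they plausibly compute (zero-mean/unit-variance scaling, and a shift that makes every advantage strictly positive). It is a paraphrase for intuition, not the garage implementation.

import numpy as np

EPS = 1e-8  # same role as the eps constant in the graphs above

def center(adv):
    """Rescale advantages to zero mean and unit variance."""
    return (adv - adv.mean()) / (adv.std() + EPS)

def make_positive(adv):
    """Shift advantages so the smallest one sits just above zero."""
    return adv - adv.min() + EPS

adv = np.array([-2.0, -0.5, 0.0, 1.5, 3.0])
print(center(adv))
print(make_positive(adv))
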
Example #4
    def _build_policy_loss(self, i):
        """Build policy loss and other output tensors.

        Args:
            i (namedtuple): Collection of variables to compute policy loss.

        Returns:
            tf.Tensor: Policy loss.
            tf.Tensor: Mean policy KL divergence.

        """
        # pylint: disable=too-many-statements
        self._policy_network, self._encoder_network = (self.policy.build(
            i.augmented_obs_var, i.task_var, name='loss_policy'))
        self._old_policy_network, self._old_encoder_network = (
            self._old_policy.build(i.augmented_obs_var,
                                   i.task_var,
                                   name='loss_old_policy'))
        self._infer_network = self._inference.build(i.augmented_traj_var,
                                                    name='loss_infer')
        self._old_infer_network = self._old_inference.build(
            i.augmented_traj_var, name='loss_old_infer')

        pol_dist = self._policy_network.dist
        old_pol_dist = self._old_policy_network.dist

        # Entropy terms
        encoder_entropy, inference_ce, policy_entropy = (
            self._build_entropy_terms(i))

        # Augment the path rewards with entropy terms
        with tf.name_scope('augmented_rewards'):
            rewards = (i.reward_var -
                       (self.inference_ce_coeff * inference_ce) +
                       (self._policy_ent_coeff * policy_entropy))

        with tf.name_scope('policy_loss'):
            with tf.name_scope('advantages'):
                adv = compute_advantages(self._discount,
                                         self._gae_lambda,
                                         self.max_path_length,
                                         i.baseline_var,
                                         rewards,
                                         name='advantages')
                adv = tf.reshape(adv, [-1, self.max_path_length])

            # Optionally normalize advantages
            eps = tf.constant(1e-8, dtype=tf.float32)
            if self._center_adv:
                adv = center_advs(adv, axes=[0], eps=eps)

            if self._positive_adv:
                adv = positive_advs(adv, eps)

            # Calculate loss function and KL divergence
            with tf.name_scope('kl'):
                kl = old_pol_dist.kl_divergence(pol_dist)
                pol_mean_kl = tf.reduce_mean(kl)

            ll = pol_dist.log_prob(i.action_var, name='log_likelihood')

            # Calculate surrogate loss
            with tf.name_scope('surr_loss'):
                old_ll = old_pol_dist.log_prob(i.action_var)
                old_ll = tf.stop_gradient(old_ll)
                # Clip early to avoid overflow
                lr = tf.exp(
                    tf.minimum(ll - old_ll, np.log(1 + self._lr_clip_range)))

                surrogate = lr * adv

                surrogate = tf.debugging.check_numerics(surrogate,
                                                        message='surrogate')

            # Finalize objective function
            with tf.name_scope('loss'):
                lr_clip = tf.clip_by_value(lr,
                                           1 - self._lr_clip_range,
                                           1 + self._lr_clip_range,
                                           name='lr_clip')
                surr_clip = lr_clip * adv
                obj = tf.minimum(surrogate, surr_clip, name='surr_obj')
                obj = tf.boolean_mask(obj, i.valid_var)
                # Maximize E[surrogate objective] by minimizing
                # -E_t[surrogate objective]
                loss = -tf.reduce_mean(obj)

                # Encoder entropy bonus
                loss -= self.encoder_ent_coeff * encoder_entropy

            encoder_mean_kl = self._build_encoder_kl()

            # Diagnostic functions
            self._f_policy_kl = tf.compat.v1.get_default_session(
            ).make_callable(pol_mean_kl,
                            feed_list=flatten_inputs(self._policy_opt_inputs))

            self._f_rewards = tf.compat.v1.get_default_session().make_callable(
                rewards, feed_list=flatten_inputs(self._policy_opt_inputs))

            returns = discounted_returns(self._discount,
                                         self.max_path_length,
                                         rewards,
                                         name='returns')
            self._f_returns = tf.compat.v1.get_default_session().make_callable(
                returns, feed_list=flatten_inputs(self._policy_opt_inputs))

        return loss, pol_mean_kl, encoder_mean_kl
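
Example #4 guards against overflow by clipping the log-ratio before exponentiating, and only then applies the usual [1 - clip_range, 1 + clip_range] clip. Below is a NumPy sketch of that two-stage guard; the array names are made up for the demonstration.

import numpy as np

def safe_ratio(ll, old_ll, clip_range=0.2):
    """Likelihood ratio with the same overflow guard as Example #4.

    The log-ratio is capped at log(1 + clip_range) before exp(), so the
    upper bound of the later clip is already respected and exp() cannot
    produce inf for very large log-ratios.
    """
    log_ratio = np.minimum(ll - old_ll, np.log(1 + clip_range))
    lr = np.exp(log_ratio)
    lr_clip = np.clip(lr, 1 - clip_range, 1 + clip_range)
    return lr, lr_clip

ll = np.array([0.0, 5.0, 100.0])       # new-policy log-likelihoods
old_ll = np.zeros(3)                   # old-policy log-likelihoods
print(safe_ratio(ll, old_ll))
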
Example #5
    def _build_policy_loss(self, i):
        pol_dist = self.policy.distribution
        policy_entropy = self._build_entropy_term(i)
        rewards = i.reward_var

        if self._maximum_entropy:
            with tf.name_scope('augmented_rewards'):
                rewards = i.reward_var + self.policy_ent_coeff * policy_entropy

        with tf.name_scope('policy_loss'):
            adv = compute_advantages(self.discount,
                                     self.gae_lambda,
                                     self.max_path_length,
                                     i.baseline_var,
                                     rewards,
                                     name='adv')

            adv_flat = flatten_batch(adv, name='adv_flat')
            adv_valid = filter_valids(adv_flat,
                                      i.flat.valid_var,
                                      name='adv_valid')

            if self.policy.recurrent:
                adv = tf.reshape(adv, [-1, self.max_path_length])

            # Optionally normalize advantages
            eps = tf.constant(1e-8, dtype=tf.float32)
            if self.center_adv:
                if self.policy.recurrent:
                    adv = center_advs(adv, axes=[0], eps=eps)
                else:
                    adv_valid = center_advs(adv_valid, axes=[0], eps=eps)

            if self.positive_adv:
                if self.policy.recurrent:
                    adv = positive_advs(adv, eps)
                else:
                    adv_valid = positive_advs(adv_valid, eps)

            if self.policy.recurrent:
                policy_dist_info = self.policy.dist_info_sym(
                    i.obs_var,
                    i.policy_state_info_vars,
                    name='policy_dist_info')
            else:
                policy_dist_info_flat = self.policy.dist_info_sym(
                    i.flat.obs_var,
                    i.flat.policy_state_info_vars,
                    name='policy_dist_info_flat')

                policy_dist_info_valid = filter_valids_dict(
                    policy_dist_info_flat,
                    i.flat.valid_var,
                    name='policy_dist_info_valid')

                policy_dist_info = policy_dist_info_valid

            # Calculate loss function and KL divergence
            with tf.name_scope('kl'):
                if self.policy.recurrent:
                    kl = pol_dist.kl_sym(
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                    )
                    pol_mean_kl = tf.reduce_sum(
                        kl * i.valid_var) / tf.reduce_sum(i.valid_var)
                else:
                    kl = pol_dist.kl_sym(
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                    )
                    pol_mean_kl = tf.reduce_mean(kl)

            # Calculate vanilla loss
            with tf.name_scope('vanilla_loss'):
                if self.policy.recurrent:
                    ll = pol_dist.log_likelihood_sym(i.action_var,
                                                     policy_dist_info,
                                                     name='log_likelihood')

                    vanilla = ll * adv * i.valid_var
                else:
                    ll = pol_dist.log_likelihood_sym(i.valid.action_var,
                                                     policy_dist_info_valid,
                                                     name='log_likelihood')

                    vanilla = ll * adv_valid

            # Calculate surrogate loss
            with tf.name_scope('surrogate_loss'):
                if self.policy.recurrent:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.action_var,
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                        name='lr')

                    surrogate = lr * adv * i.valid_var
                else:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.valid.action_var,
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                        name='lr')

                    surrogate = lr * adv_valid

            # Finalize objective function
            with tf.name_scope('loss'):
                if self._pg_loss == 'vanilla':
                    # VPG uses the vanilla objective
                    obj = tf.identity(vanilla, name='vanilla_obj')
                elif self._pg_loss == 'surrogate':
                    # TRPO uses the standard surrogate objective
                    obj = tf.identity(surrogate, name='surr_obj')
                elif self._pg_loss == 'surrogate_clip':
                    lr_clip = tf.clip_by_value(lr,
                                               1 - self.lr_clip_range,
                                               1 + self.lr_clip_range,
                                               name='lr_clip')
                    if self.policy.recurrent:
                        surr_clip = lr_clip * adv * i.valid_var
                    else:
                        surr_clip = lr_clip * adv_valid
                    obj = tf.minimum(surrogate, surr_clip, name='surr_obj')

                if self._entropy_regularzied:
                    obj += self.policy_ent_coeff * policy_entropy

                # Maximize E[surrogate objective] by minimizing
                # -E_t[surrogate objective]
                if self.policy.recurrent:
                    loss = -tf.reduce_sum(obj) / tf.reduce_sum(i.valid_var)
                else:
                    loss = -tf.reduce_mean(obj)

            # Diagnostic functions
            self.f_policy_kl = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                pol_mean_kl,
                log_name='f_policy_kl')

            self.f_rewards = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                rewards,
                log_name='f_rewards')

            returns = discounted_returns(self.discount, self.max_path_length,
                                         rewards)
            self.f_returns = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                returns,
                log_name='f_returns')

            return loss, pol_mean_kl
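
Examples #2, #5 and #6 average the objective in two different ways: recurrent policies keep the padded [batch, max_path_length] tensors and compute sum(obj * valid) / sum(valid), while feed-forward policies filter out the invalid timesteps first and take a plain mean. The two are numerically equivalent, as this small NumPy check illustrates (the arrays are made up for the demonstration).

import numpy as np

obj = np.array([[1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0]])      # padded per-timestep objective
valid = np.array([[1.0, 1.0, 0.0],
                  [1.0, 0.0, 0.0]])    # 1 where a real timestep exists

# Recurrent-style masked mean over the padded tensors.
masked_mean = np.sum(obj * valid) / np.sum(valid)

# Feed-forward style: flatten, drop invalid entries, then a plain mean.
flat_mean = obj.ravel()[valid.ravel() > 0].mean()

print(masked_mean, flat_mean)          # both print 2.333...
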
Example #6
    def _build_policy_loss(self, i):
        pol_dist = self.policy.distribution

        policy_entropy = self._build_entropy_term(i)

        with tf.name_scope('augmented_rewards'):
            rewards = i.reward_var + (self.policy_ent_coeff * policy_entropy)

        with tf.name_scope('policy_loss'):
            advantages = compute_advantages(self.discount,
                                            self.gae_lambda,
                                            self.max_path_length,
                                            i.baseline_var,
                                            rewards,
                                            name='advantages')

            adv_flat = flatten_batch(advantages, name='adv_flat')
            adv_valid = filter_valids(adv_flat,
                                      i.flat.valid_var,
                                      name='adv_valid')

            if self.policy.recurrent:
                advantages = tf.reshape(advantages, [-1, self.max_path_length])

            # Optionally normalize advantages
            eps = tf.constant(1e-8, dtype=tf.float32)
            if self.center_adv:
                with tf.name_scope('center_adv'):
                    mean, var = tf.nn.moments(adv_valid, axes=[0])
                    adv_valid = tf.nn.batch_normalization(
                        adv_valid, mean, var, 0, 1, eps)
            if self.positive_adv:
                with tf.name_scope('positive_adv'):
                    m = tf.reduce_min(adv_valid)
                    adv_valid = (adv_valid - m) + eps

            if self.policy.recurrent:
                policy_dist_info = self.policy.dist_info_sym(
                    i.obs_var,
                    i.policy_state_info_vars,
                    name='policy_dist_info')
            else:
                policy_dist_info_flat = self.policy.dist_info_sym(
                    i.flat.obs_var,
                    i.flat.policy_state_info_vars,
                    name='policy_dist_info_flat')

                policy_dist_info_valid = filter_valids_dict(
                    policy_dist_info_flat,
                    i.flat.valid_var,
                    name='policy_dist_info_valid')

            # Calculate loss function and KL divergence
            with tf.name_scope('kl'):
                if self.policy.recurrent:
                    kl = pol_dist.kl_sym(
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                    )
                    pol_mean_kl = tf.reduce_sum(
                        kl * i.valid_var) / tf.reduce_sum(i.valid_var)
                else:
                    kl = pol_dist.kl_sym(
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                    )
                    pol_mean_kl = tf.reduce_mean(kl)

            # Calculate surrogate loss
            with tf.name_scope('surr_loss'):
                if self.policy.recurrent:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.action_var,
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                        name='lr')

                    surr_vanilla = lr * advantages * i.valid_var
                else:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.valid.action_var,
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                        name='lr')

                    surr_vanilla = lr * adv_valid

                if self._pg_loss == PGLoss.VANILLA:
                    # VPG, TRPO use the standard surrogate objective
                    surr_obj = tf.identity(surr_vanilla, name='surr_obj')
                elif self._pg_loss == PGLoss.CLIP:
                    lr_clip = tf.clip_by_value(lr,
                                               1 - self.clip_range,
                                               1 + self.clip_range,
                                               name='lr_clip')
                    if self.policy.recurrent:
                        surr_clip = lr_clip * advantages * i.valid_var
                    else:
                        surr_clip = lr_clip * adv_valid
                    surr_obj = tf.minimum(surr_vanilla,
                                          surr_clip,
                                          name='surr_obj')
                else:
                    raise NotImplementedError('Unknown PGLoss')

                # Maximize E[surrogate objective] by minimizing
                # -E_t[surrogate objective]
                if self.policy.recurrent:
                    # Use the (possibly clipped) surrogate objective, as in
                    # the feed-forward branch below.
                    surr_loss = (-tf.reduce_sum(surr_obj)) / tf.reduce_sum(
                        i.valid_var)
                else:
                    surr_loss = -tf.reduce_mean(surr_obj)

            # Diagnostic functions
            self.f_policy_kl = compile_function(
                flatten_inputs(self._policy_opt_inputs),
                pol_mean_kl,
                log_name='f_policy_kl')

            self.f_rewards = compile_function(
                flatten_inputs(self._policy_opt_inputs),
                rewards,
                log_name='f_rewards')

            returns = discounted_returns(self.discount, self.max_path_length,
                                         rewards)
            self.f_returns = compile_function(
                flatten_inputs(self._policy_opt_inputs),
                returns,
                log_name='f_returns')

            return surr_loss, pol_mean_kl
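
Every example also returns pol_mean_kl as a diagnostic, built from kl_sym or kl_divergence. Assuming a diagonal Gaussian policy, the per-timestep KL has a closed form; the sketch below states it in NumPy under that assumption and is not tied to any particular garage distribution class.

import numpy as np

def diag_gaussian_kl(old_mean, old_log_std, new_mean, new_log_std):
    """KL(old || new) per timestep for diagonal Gaussian policies."""
    old_var = np.exp(2 * old_log_std)
    new_var = np.exp(2 * new_log_std)
    kl_per_dim = (new_log_std - old_log_std
                  + (old_var + (old_mean - new_mean) ** 2) / (2 * new_var)
                  - 0.5)
    return kl_per_dim.sum(axis=-1)

old_mean = np.zeros((4, 2))
old_log_std = np.zeros((4, 2))
new_mean = 0.1 * np.ones((4, 2))
new_log_std = -0.1 * np.ones((4, 2))
kl = diag_gaussian_kl(old_mean, old_log_std, new_mean, new_log_std)
print(kl.mean())   # the analogue of pol_mean_kl
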