Example #1
    def update_opt(self, loss, target, inputs, extra_inputs=None, **kwargs):
        """
        :param loss: Symbolic expression for the loss function.
        :param target: A parameterized object to optimize over. It should
         implement methods of the
        :class:`garage.core.paramerized.Parameterized` class.
        :param leq_constraint: A constraint provided as a tuple (f, epsilon),
         of the form f(*inputs) <= epsilon.
        :param inputs: A list of symbolic variables as inputs
        :return: No return value.
        """
        with tf.variable_scope(self._name,
                               values=[
                                   loss,
                                   target.get_params(trainable=True), inputs,
                                   extra_inputs
                               ]):

            self._target = target

            self._train_op = self._tf_optimizer.minimize(
                loss, var_list=target.get_params(trainable=True))

            # updates = OrderedDict(
            #     [(k, v.astype(k.dtype)) for k, v in updates.iteritems()])

            if extra_inputs is None:
                extra_inputs = list()
            self._input_vars = inputs + extra_inputs
            self._opt_fun = ext.LazyDict(
                f_loss=lambda: tensor_utils.compile_function(
                    inputs + extra_inputs, loss), )
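
All of these examples stash their compiled functions in an `ext.LazyDict`, so the TensorFlow/Theano functions are only built the first time they are looked up. A minimal sketch of that lazy-compilation pattern in plain Python (an illustrative stand-in, not garage's actual `ext.LazyDict`):

# Minimal sketch of the lazy-compilation pattern used by ext.LazyDict above.
# Illustrative stand-in only, not garage's implementation.
class LazyDict:
    def __init__(self, **thunks):
        self._thunks = thunks   # name -> zero-argument callable
        self._cache = {}        # name -> result of calling the thunk once

    def __getitem__(self, key):
        if key not in self._cache:
            self._cache[key] = self._thunks[key]()  # "compile" on first access
        return self._cache[key]


# Usage: the compilation only happens when "f_loss" is first requested.
opt_fun = LazyDict(f_loss=lambda: print("compiling f_loss") or (lambda x: x ** 2))
f_loss = opt_fun["f_loss"]   # prints "compiling f_loss" once
print(f_loss(3.0))           # 9.0
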
Example #2
    def update_opt(self,
                   loss,
                   target,
                   leq_constraint,
                   inputs,
                   constraint_name="constraint",
                   name=None,
                   *args,
                   **kwargs):
        """
        :param loss: Symbolic expression for the loss function.
        :param target: A parameterized object to optimize over. It should
         implement methods of the
         :class:`garage.core.parameterized.Parameterized` class.
        :param leq_constraint: A constraint provided as a tuple (f, epsilon),
         of the form f(*inputs) <= epsilon.
        :param inputs: A list of symbolic variables as inputs
        :return: No return value.
        """
        params = target.get_params(trainable=True)
        with tf.name_scope(name, "PenaltyLbfgsOptimizer",
                           [leq_constraint, loss, params]):
            constraint_term, constraint_value = leq_constraint
            penalty_var = tf.placeholder(tf.float32, tuple(), name="penalty")
            penalized_loss = loss + penalty_var * constraint_term

            self._target = target
            self._max_constraint_val = constraint_value
            self._constraint_name = constraint_name

            def get_opt_output():
                with tf.name_scope("get_opt_output",
                                   values=[params, penalized_loss]):
                    grads = tf.gradients(penalized_loss, params)
                    for idx, (grad, param) in enumerate(zip(grads, params)):
                        if grad is None:
                            grads[idx] = tf.zeros_like(param)
                    flat_grad = tensor_utils.flatten_tensor_variables(grads)
                    return [
                        tf.cast(penalized_loss, tf.float64),
                        tf.cast(flat_grad, tf.float64),
                    ]

            self._opt_fun = ext.LazyDict(
                f_loss=lambda: tensor_utils.compile_function(
                    inputs, loss, log_name="f_loss"),
                f_constraint=lambda: tensor_utils.compile_function(
                    inputs, constraint_term, log_name="f_constraint"),
                f_penalized_loss=lambda: tensor_utils.compile_function(
                    inputs=inputs + [penalty_var],
                    outputs=[penalized_loss, loss, constraint_term],
                    log_name="f_penalized_loss",
                ),
                f_opt=lambda: tensor_utils.compile_function(
                    inputs=inputs + [penalty_var],
                    outputs=get_opt_output(),
                ))
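
`PenaltyLbfgsOptimizer.update_opt` folds the constraint into a single objective, `penalized_loss = loss + penalty_var * constraint_term`, and elsewhere in the class the penalty is adjusted until `f_constraint` falls below the epsilon from `leq_constraint`. A rough SciPy illustration of that penalty search on a toy problem (hypothetical objective and constraint, not garage code):

import numpy as np
from scipy.optimize import minimize

# Toy problem: minimize (x - 3)^2 subject to x^2 <= 1, by minimizing
# loss + penalty * constraint and growing the penalty until the constraint holds.
loss = lambda z: (z[0] - 3.0) ** 2
constraint = lambda z: z[0] ** 2          # f(x); we want f(x) <= epsilon
epsilon = 1.0

penalty = 1.0
x = np.array([0.0])
for _ in range(10):
    penalized = lambda z, p=penalty: loss(z) + p * constraint(z)
    x = minimize(penalized, x, method="L-BFGS-B").x
    if constraint(x) <= epsilon:
        break
    penalty *= 10.0                       # constraint violated: increase penalty

print(x, constraint(x))                   # x well inside the constraint set
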
Example #3
    def update_opt(self, f, target, inputs, reg_coeff, name=None):
        self.target = target
        self.reg_coeff = reg_coeff
        params = target.get_params(trainable=True)

        with tf.name_scope(name, "FiniteDifferenceHvp",
                           [f, inputs, params, target]):
            constraint_grads = tf.gradients(f,
                                            xs=params,
                                            name="gradients_constraint")
            for idx, (grad, param) in enumerate(zip(constraint_grads, params)):
                if grad is None:
                    constraint_grads[idx] = tf.zeros_like(param)

            flat_grad = tensor_utils.flatten_tensor_variables(constraint_grads)

            def f_hx_plain(*args):
                with tf.name_scope("f_hx_plain", values=[inputs, self.target]):
                    inputs_ = args[:len(inputs)]
                    xs = args[len(inputs):]
                    flat_xs = np.concatenate(
                        [np.reshape(x, (-1, )) for x in xs])
                    param_val = self.target.get_param_values(trainable=True)
                    eps = np.cast['float32'](
                        self.base_eps / (np.linalg.norm(param_val) + 1e-8))
                    self.target.set_param_values(param_val + eps * flat_xs,
                                                 trainable=True)
                    flat_grad_dvplus = self.opt_fun["f_grad"](*inputs_)
                    self.target.set_param_values(param_val, trainable=True)
                    if self.symmetric:
                        self.target.set_param_values(param_val - eps * flat_xs,
                                                     trainable=True)
                        flat_grad_dvminus = self.opt_fun["f_grad"](*inputs_)
                        hx = (flat_grad_dvplus - flat_grad_dvminus) / (2 * eps)
                        self.target.set_param_values(param_val, trainable=True)
                    else:
                        flat_grad = self.opt_fun["f_grad"](*inputs_)
                        hx = (flat_grad_dvplus - flat_grad) / eps
                    return hx

            self.opt_fun = ext.LazyDict(
                f_grad=lambda: tensor_utils.compile_function(
                    inputs=inputs,
                    outputs=flat_grad,
                    log_name="f_grad",
                ),
                f_hx_plain=lambda: f_hx_plain,
            )
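
`FiniteDifferenceHvp` never forms the Hessian of the constraint: `f_hx_plain` perturbs the flattened parameters by `±eps * x` and differences the resulting gradients, optionally with a symmetric (central) difference. The same idea in standalone NumPy, on a quadratic whose Hessian is known so the approximation can be checked (illustrative only):

import numpy as np

# f(theta) = 0.5 * theta^T A theta, so grad(theta) = A @ theta and the Hessian is A.
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
A = A @ A.T                              # symmetric positive semi-definite Hessian
grad = lambda theta: A @ theta

def hvp_finite_difference(theta, x, eps=1e-5, symmetric=True):
    """Approximate H @ x from gradient evaluations only, as f_hx_plain does."""
    if symmetric:
        return (grad(theta + eps * x) - grad(theta - eps * x)) / (2 * eps)
    return (grad(theta + eps * x) - grad(theta)) / eps

theta = rng.normal(size=5)
x = rng.normal(size=5)
print(np.max(np.abs(hvp_finite_difference(theta, x) - A @ x)))  # ~1e-9
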
Example #4
    def update_opt(self, f, target, inputs, reg_coeff, name=None):
        self.target = target
        self.reg_coeff = reg_coeff
        params = target.get_params(trainable=True)
        with tf.name_scope(name, "PerlmutterHvp", [f, inputs, params]):
            constraint_grads = tf.gradients(f,
                                            xs=params,
                                            name="gradients_constraint")
            for idx, (grad, param) in enumerate(zip(constraint_grads, params)):
                if grad is None:
                    constraint_grads[idx] = tf.zeros_like(param)

            xs = tuple([
                tensor_utils.new_tensor_like(p.name.split(":")[0], p)
                for p in params
            ])

            def hx_plain():
                with tf.name_scope("hx_plain",
                                   values=[constraint_grads, params, xs]):
                    with tf.name_scope("hx_function",
                                       values=[constraint_grads, xs]):
                        hx_f = tf.reduce_sum(
                            tf.stack([
                                tf.reduce_sum(g * x)
                                for g, x in zip(constraint_grads, xs)
                            ]))
                    hx_plain_splits = tf.gradients(hx_f,
                                                   params,
                                                   name="gradients_hx_plain")
                    for idx, (hx,
                              param) in enumerate(zip(hx_plain_splits,
                                                      params)):
                        if hx is None:
                            hx_plain_splits[idx] = tf.zeros_like(param)
                    return tensor_utils.flatten_tensor_variables(
                        hx_plain_splits)

            self.opt_fun = ext.LazyDict(
                f_hx_plain=lambda: tensor_utils.compile_function(
                    inputs=inputs + xs,
                    outputs=hx_plain(),
                    log_name="f_hx_plain",
                ), )
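
`PerlmutterHvp` instead uses the exact identity Hx = ∇θ(∇θ f(θ) · x): it builds the scalar `hx_f = sum_i sum(g_i * x_i)` and differentiates it again with respect to the parameters. A small NumPy check of that identity on a quadratic, where the gradient and Hessian are available in closed form (illustrative only):

import numpy as np

# f(theta) = 0.5 * theta^T A theta  =>  grad f = A @ theta,  Hessian H = A.
rng = np.random.default_rng(1)
A = rng.normal(size=(4, 4))
A = A @ A.T
grad_f = lambda theta: A @ theta

theta = rng.normal(size=4)
x = rng.normal(size=4)

# Pearlmutter's trick: H @ x equals the gradient of h(theta) = grad_f(theta) . x.
h = lambda t: grad_f(t) @ x
eps = 1e-6
grad_h = np.array([
    (h(theta + eps * e) - h(theta - eps * e)) / (2 * eps)
    for e in np.eye(4)
])
print(np.max(np.abs(grad_h - A @ x)))   # agrees with H @ x up to ~1e-7
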
Example #5
    def update_opt(self,
                   loss,
                   target,
                   inputs,
                   extra_inputs=None,
                   name=None,
                   *args,
                   **kwargs):
        """
        :param loss: Symbolic expression for the loss function.
        :param target: A parameterized object to optimize over. It should
         implement methods of the
        :class:`garage.core.paramerized.Parameterized` class.
        :param leq_constraint: A constraint provided as a tuple (f, epsilon),
         of the form f(*inputs) <= epsilon.
        :param inputs: A list of symbolic variables as inputs
        :return: No return value.
        """
        self._target = target
        params = target.get_params(trainable=True)
        with tf.name_scope(name, "LbfgsOptimizer",
                           [loss, inputs, params, extra_inputs]):

            def get_opt_output():
                with tf.name_scope("get_opt_output", [loss, params]):
                    flat_grad = tensor_utils.flatten_tensor_variables(
                        tf.gradients(loss, params))
                    return [
                        tf.cast(loss, tf.float64),
                        tf.cast(flat_grad, tf.float64)
                    ]

            if extra_inputs is None:
                extra_inputs = list()

            self._opt_fun = ext.LazyDict(
                f_loss=lambda: tensor_utils.compile_function(
                    inputs + extra_inputs, loss),
                f_opt=lambda: tensor_utils.compile_function(
                    inputs=inputs + extra_inputs,
                    outputs=get_opt_output(),
                ))
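
`get_opt_output` casts the loss and the flattened gradient to float64 because the actual optimization step hands them to SciPy's L-BFGS routine, which expects double precision. A minimal sketch of that hand-off with `scipy.optimize.fmin_l_bfgs_b`, using a toy quadratic in place of the compiled `f_opt` (this assumes the optimizer drives SciPy's L-BFGS, as rllab/garage do):

import numpy as np
from scipy.optimize import fmin_l_bfgs_b

# Toy stand-in for the compiled f_opt above: returns (loss, flat gradient),
# both float64, which is what fmin_l_bfgs_b expects.
target = np.array([1.0, -2.0, 0.5])

def f_opt(flat_params):
    diff = flat_params - target
    loss = 0.5 * np.dot(diff, diff)
    grad = diff
    return np.float64(loss), grad.astype(np.float64)

x0 = np.zeros(3)
x_opt, final_loss, info = fmin_l_bfgs_b(func=f_opt, x0=x0, maxiter=20)
print(x_opt, final_loss)                 # converges to `target`, loss ~ 0
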
Example #6
    def update_opt(self, f, target, inputs, reg_coeff):
        self.target = target
        self.reg_coeff = reg_coeff

        params = target.get_params(trainable=True)

        constraint_grads = theano.grad(f,
                                       wrt=params,
                                       disconnected_inputs='warn')
        flat_grad = tensor_utils.flatten_tensor_variables(constraint_grads)

        def f_hx_plain(*args):
            inputs_ = args[:len(inputs)]
            xs = args[len(inputs):]
            flat_xs = np.concatenate([np.reshape(x, (-1, )) for x in xs])
            param_val = self.target.get_param_values(trainable=True)
            eps = np.cast['float32'](self.base_eps /
                                     (np.linalg.norm(param_val) + 1e-8))
            self.target.set_param_values(param_val + eps * flat_xs,
                                         trainable=True)
            flat_grad_dvplus = self.opt_fun["f_grad"](*inputs_)
            if self.symmetric:
                self.target.set_param_values(param_val - eps * flat_xs,
                                             trainable=True)
                flat_grad_dvminus = self.opt_fun["f_grad"](*inputs_)
                hx = (flat_grad_dvplus - flat_grad_dvminus) / (2 * eps)
                self.target.set_param_values(param_val, trainable=True)
            else:
                self.target.set_param_values(param_val, trainable=True)
                flat_grad = self.opt_fun["f_grad"](*inputs_)
                hx = (flat_grad_dvplus - flat_grad) / eps
            return hx

        self.opt_fun = ext.LazyDict(
            f_grad=lambda: tensor_utils.compile_function(
                inputs=inputs,
                outputs=flat_grad,
                log_name="f_grad",
            ),
            f_hx_plain=lambda: f_hx_plain,
        )
Example #7
    def update_opt(self,
                   loss,
                   target,
                   inputs,
                   extra_inputs=None,
                   gradients=None,
                   **kwargs):
        """
        :param loss: Symbolic expression for the loss function.
        :param target: A parameterized object to optimize over. It should
         implement methods of the
         :class:`garage.core.paramerized.Parameterized` class.
        :param leq_constraint: A constraint provided as a tuple (f, epsilon),
         of the form f(*inputs) <= epsilon.
        :param inputs: A list of symbolic variables as inputs
        :return: No return value.
        """

        self._target = target

        if gradients is None:
            gradients = theano.grad(loss,
                                    target.get_params(trainable=True),
                                    disconnected_inputs='ignore')
        updates = self._update_method(gradients,
                                      target.get_params(trainable=True))
        updates = OrderedDict([(k, v.astype(k.dtype))
                               for k, v in updates.items()])

        if extra_inputs is None:
            extra_inputs = list()

        self._opt_fun = ext.LazyDict(
            f_loss=lambda: tensor_utils.compile_function(
                inputs + extra_inputs, loss),
            f_opt=lambda: tensor_utils.compile_function(
                inputs=inputs + extra_inputs,
                outputs=loss,
                updates=updates,
            ))
Example #8
    def update_opt(self, f, target, inputs, reg_coeff):
        self.target = target
        self.reg_coeff = reg_coeff
        params = target.get_params(trainable=True)

        constraint_grads = theano.grad(f,
                                       wrt=params,
                                       disconnected_inputs='warn')
        xs = tuple([
            tensor_utils.new_tensor_like("%s x" % p.name, p) for p in params
        ])

        def hx_plain():
            hx_plain_splits = TT.grad(TT.sum(
                [TT.sum(g * x) for g, x in zip(constraint_grads, xs)]),
                                      wrt=params,
                                      disconnected_inputs='warn')
            return TT.concatenate([TT.flatten(s) for s in hx_plain_splits])

        self.opt_fun = ext.LazyDict(
            f_hx_plain=lambda: tensor_utils.compile_function(
                inputs=inputs + xs,
                outputs=hx_plain(),
                log_name="f_hx_plain",
            ), )
Example #9
    def update_opt(self,
                   loss,
                   target,
                   leq_constraint,
                   inputs,
                   extra_inputs=None,
                   constraint_name="constraint",
                   *args,
                   **kwargs):
        """
        :param loss: Symbolic expression for the loss function.
        :param target: A parameterized object to optimize over. It should
         implement methods of the
         :class:`garage.core.parameterized.Parameterized` class.
        :param leq_constraint: A constraint provided as a tuple (f, epsilon),
         of the form f(*inputs) <= epsilon.
        :param inputs: A list of symbolic variables as inputs, which could be
         subsampled if needed. It is assumed that the first dimension of these
         inputs should correspond to the number of data points
        :param extra_inputs: A list of symbolic variables as extra inputs which
         should not be subsampled
        :return: No return value.
        """

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        else:
            extra_inputs = tuple(extra_inputs)

        constraint_term, constraint_value = leq_constraint

        params = target.get_params(trainable=True)
        grads = theano.grad(loss, wrt=params, disconnected_inputs='warn')
        flat_grad = tensor_utils.flatten_tensor_variables(grads)

        self._hvp_approach.update_opt(f=constraint_term,
                                      target=target,
                                      inputs=inputs + extra_inputs,
                                      reg_coeff=self._reg_coeff)

        self._target = target
        self._max_constraint_val = constraint_value
        self._constraint_name = constraint_name

        self._opt_fun = ext.LazyDict(
            f_loss=lambda: tensor_utils.compile_function(
                inputs=inputs + extra_inputs,
                outputs=loss,
                log_name="f_loss",
            ),
            f_grad=lambda: tensor_utils.compile_function(
                inputs=inputs + extra_inputs,
                outputs=flat_grad,
                log_name="f_grad",
            ),
            f_constraint=lambda: tensor_utils.compile_function(
                inputs=inputs + extra_inputs,
                outputs=constraint_term,
                log_name="constraint",
            ),
            f_loss_constraint=lambda: tensor_utils.compile_function(
                inputs=inputs + extra_inputs,
                outputs=[loss, constraint_term],
                log_name="f_loss_constraint",
            ),
        )
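
The conjugate gradient optimizer ties the pieces together: `f_grad` yields the flattened gradient g, the HVP approach yields a callable for (H + reg_coeff·I)x, and the search direction is obtained by solving Hx = g with conjugate gradient, so H is never materialized. A compact NumPy conjugate-gradient solver driven only by an HVP callable (an illustrative solver, independent of the library's own implementation):

import numpy as np

def cg(hvp, g, iters=10, tol=1e-10):
    """Solve H x = g given only a Hessian-vector-product callable."""
    x = np.zeros_like(g)
    r = g.copy()                 # residual g - H x (x = 0 initially)
    p = r.copy()
    rdotr = r @ r
    for _ in range(iters):
        Hp = hvp(p)
        alpha = rdotr / (p @ Hp)
        x += alpha * p
        r -= alpha * Hp
        new_rdotr = r @ r
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

# Toy check with an explicit SPD matrix standing in for the constraint Hessian.
rng = np.random.default_rng(2)
A = rng.normal(size=(6, 6))
A = A @ A.T + 1e-3 * np.eye(6)
g = rng.normal(size=6)
x = cg(lambda v: A @ v, g, iters=50)
print(np.max(np.abs(A @ x - g)))         # ~0: x solves H x = g
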
Example #10
    def update_opt(self,
                   loss,
                   target,
                   leq_constraint,
                   inputs,
                   extra_inputs=None,
                   name=None,
                   constraint_name="constraint",
                   *args,
                   **kwargs):
        """
        :param loss: Symbolic expression for the loss function.
        :param target: A parameterized object to optimize over. It should
         implement methods of the
         :class:`garage.core.parameterized.Parameterized` class.
        :param leq_constraint: A constraint provided as a tuple (f, epsilon),
         of the form f(*inputs) <= epsilon.
        :param inputs: A list of symbolic variables as inputs, which could be
         subsampled if needed. It is assumed that the first dimension of these
         inputs should correspond to the number of data points
        :param extra_inputs: A list of symbolic variables as extra inputs which
         should not be subsampled
        :return: No return value.
        """
        params = target.get_params(trainable=True)
        with tf.name_scope(
                name, "ConjugateGradientOptimizer",
                [loss, target, leq_constraint, inputs, extra_inputs,
                 params]):  # yapf: disable
            inputs = tuple(inputs)
            if extra_inputs is None:
                extra_inputs = tuple()
            else:
                extra_inputs = tuple(extra_inputs)

            constraint_term, constraint_value = leq_constraint

            with tf.name_scope("loss_gradients", values=[loss, params]):
                grads = tf.gradients(loss, xs=params)
                for idx, (grad, param) in enumerate(zip(grads, params)):
                    if grad is None:
                        grads[idx] = tf.zeros_like(param)
                flat_grad = tensor_utils.flatten_tensor_variables(grads)

            self._hvp_approach.update_opt(f=constraint_term,
                                          target=target,
                                          inputs=inputs + extra_inputs,
                                          reg_coeff=self._reg_coeff,
                                          name="update_opt_" + constraint_name)

            self._target = target
            self._max_constraint_val = constraint_value
            self._constraint_name = constraint_name

            self._opt_fun = ext.LazyDict(
                f_loss=lambda: tensor_utils.compile_function(
                    inputs=inputs + extra_inputs,
                    outputs=loss,
                    log_name="f_loss",
                ),
                f_grad=lambda: tensor_utils.compile_function(
                    inputs=inputs + extra_inputs,
                    outputs=flat_grad,
                    log_name="f_grad",
                ),
                f_constraint=lambda: tensor_utils.compile_function(
                    inputs=inputs + extra_inputs,
                    outputs=constraint_term,
                    log_name="constraint",
                ),
                f_loss_constraint=lambda: tensor_utils.compile_function(
                    inputs=inputs + extra_inputs,
                    outputs=[loss, constraint_term],
                    log_name="f_loss_constraint",
                ),
            )
Example #11
    def update_opt(self,
                   loss,
                   loss_tilde,
                   target,
                   target_tilde,
                   leq_constraint,
                   inputs,
                   extra_inputs=None,
                   **kwargs):
        """
        :param loss: Symbolic expression for the loss function.
        :param target: A parameterized object to optimize over. It should
         implement methods of the
        :class:`garage.core.paramerized.Parameterized` class.
        :param leq_constraint: A constraint provided as a tuple (f, epsilon),
         of the form f(*inputs) <= epsilon.
        :param inputs: A list of symbolic variables as inputs
        :return: No return value.
        """
        if extra_inputs is None:
            extra_inputs = list()
        self._input_vars = inputs + extra_inputs

        self._target = target
        self._target_tilde = target_tilde

        constraint_term, constraint_value = leq_constraint
        self._max_constraint_val = constraint_value

        w = target.get_params(trainable=True)
        grads = tf.gradients(loss, xs=w)
        for idx, (g, param) in enumerate(zip(grads, w)):
            if g is None:
                grads[idx] = tf.zeros_like(param)
        flat_grad = tensor_utils.flatten_tensor_variables(grads)

        w_tilde = target_tilde.get_params(trainable=True)
        grads_tilde = tf.gradients(loss_tilde, xs=w_tilde)
        for idx, (g_t, param_t) in enumerate(zip(grads_tilde, w_tilde)):
            if g_t is None:
                grads_tilde[idx] = tf.zeros_like(param_t)
        flat_grad_tilde = tensor_utils.flatten_tensor_variables(grads_tilde)

        self._opt_fun = ext.LazyDict(
            f_loss=lambda: tensor_utils.compile_function(
                inputs=inputs + extra_inputs,
                outputs=loss,
            ),
            f_loss_tilde=lambda: tensor_utils.compile_function(
                inputs=inputs + extra_inputs,
                outputs=loss_tilde,
            ),
            f_grad=lambda: tensor_utils.compile_function(
                inputs=inputs + extra_inputs,
                outputs=flat_grad,
            ),
            f_grad_tilde=lambda: tensor_utils.compile_function(
                inputs=inputs + extra_inputs,
                outputs=flat_grad_tilde,
            ),
            f_loss_constraint=lambda: tensor_utils.compile_function(
                inputs=inputs + extra_inputs,
                outputs=[loss, constraint_term],
            ),
        )
        # NOTE: the conversions below run after the compiled functions have
        # already been defined above, so they have no effect here.
        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        else:
            extra_inputs = tuple(extra_inputs)
Example #12
File: capg.py  Project: Mee321/HAPG_exp
    def init_opt(self):
        with tf.name_scope("inputs"):
            observations_var = self.env.observation_space.new_tensor_variable(
                'observations', extra_dims=1)
            actions_var = self.env.action_space.new_tensor_variable(
                'actions', extra_dims=1)
            advantages_var = tensor_utils.new_tensor('advantage',
                                                     ndim=1,
                                                     dtype=tf.float32)
            dist = self.policy.distribution
            dist_info_vars = self.policy.dist_info_sym(observations_var)
            old_dist_info_vars = self.backup_policy.dist_info_sym(
                observations_var)

        kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
        mean_kl = tf.reduce_mean(kl)
        max_kl = tf.reduce_max(kl)

        pos_eps_dist_info_vars = self.pos_eps_policy.dist_info_sym(
            observations_var)
        neg_eps_dist_info_vars = self.neg_eps_policy.dist_info_sym(
            observations_var)
        mix_dist_info_vars = self.mix_policy.dist_info_sym(observations_var)

        surr = tf.reduce_mean(
            dist.log_likelihood_sym(actions_var, dist_info_vars) *
            advantages_var)
        surr_pos_eps = tf.reduce_mean(
            dist.log_likelihood_sym(actions_var, pos_eps_dist_info_vars) *
            advantages_var)
        surr_neg_eps = tf.reduce_mean(
            dist.log_likelihood_sym(actions_var, neg_eps_dist_info_vars) *
            advantages_var)
        surr_mix = tf.reduce_mean(
            dist.log_likelihood_sym(actions_var, mix_dist_info_vars) *
            advantages_var)
        surr_loglikelihood = tf.reduce_sum(
            dist.log_likelihood_sym(actions_var, mix_dist_info_vars))

        params = self.policy.get_params(trainable=True)
        mix_params = self.mix_policy.get_params(trainable=True)
        pos_eps_params = self.pos_eps_policy.get_params(trainable=True)
        neg_eps_params = self.neg_eps_policy.get_params(trainable=True)

        grads = tf.gradients(surr, params)
        grad_pos_eps = tf.gradients(surr_pos_eps, pos_eps_params)
        grad_neg_eps = tf.gradients(surr_neg_eps, neg_eps_params)
        grad_mix = tf.gradients(surr_mix, mix_params)
        grad_mix_lh = tf.gradients(surr_loglikelihood, mix_params)

        self._opt_fun = ext.LazyDict(
            f_loss=lambda: tensor_utils.compile_function(
                inputs=[observations_var, actions_var, advantages_var],
                outputs=surr,
                log_name="f_loss",
            ),
            f_train=lambda: tensor_utils.compile_function(
                inputs=[observations_var, actions_var, advantages_var],
                outputs=grads,
                log_name="f_grad"),
            f_mix_grad=lambda: tensor_utils.compile_function(
                inputs=[observations_var, actions_var, advantages_var],
                outputs=grad_mix,
                log_name="f_mix_grad"),
            f_pos_grad=lambda: tensor_utils.compile_function(
                inputs=[observations_var, actions_var, advantages_var],
                outputs=grad_pos_eps),
            f_neg_grad=lambda: tensor_utils.compile_function(
                inputs=[observations_var, actions_var, advantages_var],
                outputs=grad_neg_eps),
            f_mix_lh=lambda: tensor_utils.compile_function(
                inputs=[observations_var, actions_var], outputs=grad_mix_lh),
            f_kl=lambda: tensor_utils.compile_function(
                inputs=[observations_var],
                outputs=[mean_kl, max_kl],
            ))
Example #13
File: catrpo.py  Project: Mee321/HAPG_exp
    def init_opt(self):
        observations_var = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1,
        )
        actions_var = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )
        advantages_var = tensor_utils.new_tensor(
            name='advantage',
            ndim=1,
            dtype=tf.float32,
        )
        dist = self.policy.distribution

        old_dist_info_vars = self.backup_policy.dist_info_sym(observations_var)
        dist_info_vars = self.policy.dist_info_sym(observations_var)

        kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
        mean_kl = tf.reduce_mean(kl)
        max_kl = tf.reduce_max(kl)

        pos_eps_dist_info_vars = self.pos_eps_policy.dist_info_sym(
            observations_var)
        neg_eps_dist_info_vars = self.neg_eps_policy.dist_info_sym(
            observations_var)
        mix_dist_info_vars = self.mix_policy.dist_info_sym(observations_var)

        # formulate as a minimization problem
        # The gradient of the surrogate objective is the policy gradient
        surr = -tf.reduce_mean(
            dist.log_likelihood_sym(actions_var, dist_info_vars) *
            advantages_var)
        surr_pos_eps = -tf.reduce_mean(
            dist.log_likelihood_sym(actions_var, pos_eps_dist_info_vars) *
            advantages_var)
        surr_neg_eps = -tf.reduce_mean(
            dist.log_likelihood_sym(actions_var, neg_eps_dist_info_vars) *
            advantages_var)
        surr_mix = -tf.reduce_mean(
            dist.log_likelihood_sym(actions_var, mix_dist_info_vars) *
            advantages_var)
        surr_loglikelihood = tf.reduce_sum(
            dist.log_likelihood_sym(actions_var, mix_dist_info_vars))

        params = self.policy.get_params(trainable=True)
        mix_params = self.mix_policy.get_params(trainable=True)
        pos_eps_params = self.pos_eps_policy.get_params(trainable=True)
        neg_eps_params = self.neg_eps_policy.get_params(trainable=True)

        grads = tf.gradients(surr, params)
        grad_pos_eps = tf.gradients(surr_pos_eps, pos_eps_params)
        grad_neg_eps = tf.gradients(surr_neg_eps, neg_eps_params)
        grad_mix = tf.gradients(surr_mix, mix_params)
        grad_mix_lh = tf.gradients(surr_loglikelihood, mix_params)

        inputs_list = [observations_var, actions_var, advantages_var]

        self.optimizer.update_opt(loss=surr,
                                  target=self.policy,
                                  leq_constraint=(mean_kl, self.delta),
                                  inputs=inputs_list)

        self._opt_fun = ext.LazyDict(
            f_loss=lambda: tensor_utils.compile_function(
                inputs=inputs_list,
                outputs=surr,
                log_name="f_loss",
            ),
            f_train=lambda: tensor_utils.compile_function(
                inputs=inputs_list, outputs=grads, log_name="f_grad"),
            f_mix_grad=lambda: tensor_utils.compile_function(
                inputs=inputs_list, outputs=grad_mix, log_name="f_mix_grad"),
            f_pos_grad=lambda: tensor_utils.compile_function(
                inputs=inputs_list, outputs=grad_pos_eps),
            f_neg_grad=lambda: tensor_utils.compile_function(
                inputs=inputs_list, outputs=grad_neg_eps),
            f_mix_lh=lambda: tensor_utils.compile_function(
                inputs=inputs_list, outputs=grad_mix_lh),
            f_kl=lambda: tensor_utils.compile_function(
                inputs=inputs_list,
                outputs=[mean_kl, max_kl],
            ))
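
The comment in Example #13 ("the gradient of the surrogate objective is the policy gradient") is the score-function identity ∇θ E[R] = E[∇θ log πθ(a|s) · A]. A tiny NumPy check of that identity for a two-action categorical policy, comparing the analytic REINFORCE gradient with a finite-difference derivative of the expected return (a hypothetical toy setup, unrelated to the repo):

import numpy as np

# Two-action categorical policy pi_theta = softmax(theta); a fixed reward per action.
rewards = np.array([1.0, 0.0])
softmax = lambda z: np.exp(z - z.max()) / np.exp(z - z.max()).sum()

def expected_return(theta):
    return softmax(theta) @ rewards

def policy_gradient(theta):
    # E_a[ grad log pi(a) * r(a) ]  with  grad log pi(a) = e_a - pi
    pi = softmax(theta)
    return sum(pi[a] * rewards[a] * (np.eye(2)[a] - pi) for a in range(2))

theta = np.array([0.2, -0.4])
eps = 1e-6
fd = np.array([
    (expected_return(theta + eps * e) - expected_return(theta - eps * e)) / (2 * eps)
    for e in np.eye(2)
])
print(policy_gradient(theta), fd)        # the two gradients match
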