def update_opt(self, f, target, inputs, reg_coeff):
        """Set up the Hessian-vector product function for ``f``.

        The product is formed without materialising the Hessian. The
        gradient of ``sum_i <grad_i f, x_i>`` with respect to the parameters
        equals ``H x``, where ``H`` is the Hessian of ``f`` and ``x`` is
        supplied at call time through one placeholder per parameter.

        NOTE(review): SOURCE defines ``update_opt`` again, with the same
        signature, directly after this method. If both are in the same
        class, the later definition replaces this one. Consider deleting the
        duplicate.

        Args:
            f: Tensor to differentiate. Presumably a scalar objective or
                constraint term; TODO confirm against callers.
            target: Object exposing ``get_params(trainable=True)``. Its
                trainable parameters are the variables the Hessian is taken
                with respect to.
            inputs: Sequence (tuple or list) of input tensors. These are
                passed to the compiled function before the vector
                placeholders.
            reg_coeff: Stored on ``self.reg_coeff``. Not used in this
                method; presumably consumed elsewhere in the class (verify).

        Side effects:
            Sets ``self.target``, ``self.reg_coeff`` and ``self.opt_fun``.
            ``self.opt_fun`` is an ``ext.lazydict`` whose ``f_Hx_plain``
            entry is produced by a lambda, so compilation is presumably
            deferred until first access. The compiled function takes
            ``inputs`` followed by one vector tensor per parameter. It
            returns the flattened Hessian-vector product.
        """
        self.target = target
        self.reg_coeff = reg_coeff
        params = target.get_params(trainable=True)

        def _zeros_for_none(grads):
            # tf.gradients yields None for parameters the differentiated
            # tensor does not depend on. Substitute zeros shaped like the
            # parameter so every entry is a real tensor.
            return [
                tf.zeros_like(param) if grad is None else grad
                for grad, param in zip(grads, params)
            ]

        constraint_grads = _zeros_for_none(tf.gradients(f, xs=params))

        # One placeholder per parameter holds the vector x. TF tensor names
        # end in ":<output index>"; strip that suffix for the new tensor name.
        xs = tuple(
            tensor_utils.new_tensor_like(p.name.split(":")[0], p)
            for p in params)

        def hx_plain():
            # d/dtheta sum_i <g_i, x_i> = H x, since x does not depend on theta.
            grad_dot_x = tf.reduce_sum(
                tf.stack([
                    tf.reduce_sum(g * x)
                    for g, x in zip(constraint_grads, xs)
                ]))
            hx_splits = _zeros_for_none(tf.gradients(grad_dot_x, params))
            return tensor_utils.flatten_tensor_variables(hx_splits)

        # tuple(...) accepts ``inputs`` as any sequence. Concatenating a
        # list directly with the tuple ``xs`` would raise TypeError.
        self.opt_fun = ext.lazydict(
            f_Hx_plain=lambda: tensor_utils.compile_function(
                inputs=tuple(inputs) + xs,
                outputs=hx_plain(),
                log_name="f_Hx_plain",
            ))
    def update_opt(self, f, target, inputs, reg_coeff):
        """Set up the Hessian-vector product function for ``f``.

        The product is formed without materialising the Hessian. The
        gradient of ``sum_i <grad_i f, x_i>`` with respect to the parameters
        equals ``H x``, where ``H`` is the Hessian of ``f`` and ``x`` is
        supplied at call time through one placeholder per parameter.

        NOTE(review): SOURCE also defines ``update_opt`` immediately above
        this method. If both are in the same class, this later definition is
        the one in effect. Consider removing the other copy.

        Args:
            f: Tensor to differentiate. Presumably a scalar objective or
                constraint term; TODO confirm against callers.
            target: Object exposing ``get_params(trainable=True)``. Its
                trainable parameters are the variables the Hessian is taken
                with respect to.
            inputs: Sequence (tuple or list) of input tensors. These are
                passed to the compiled function before the vector
                placeholders.
            reg_coeff: Stored on ``self.reg_coeff``. Not used in this
                method; presumably consumed elsewhere in the class (verify).

        Side effects:
            Sets ``self.target``, ``self.reg_coeff`` and ``self.opt_fun``.
            ``self.opt_fun`` is an ``ext.lazydict`` whose ``f_Hx_plain``
            entry is produced by a lambda, so compilation is presumably
            deferred until first access. The compiled function takes
            ``inputs`` followed by one vector tensor per parameter. It
            returns the flattened Hessian-vector product.
        """
        self.target = target
        self.reg_coeff = reg_coeff
        params = target.get_params(trainable=True)

        def _zeros_for_none(grads):
            # tf.gradients yields None for parameters the differentiated
            # tensor does not depend on. Substitute zeros shaped like the
            # parameter so every entry is a real tensor.
            return [
                tf.zeros_like(param) if grad is None else grad
                for grad, param in zip(grads, params)
            ]

        constraint_grads = _zeros_for_none(tf.gradients(f, xs=params))

        # One placeholder per parameter holds the vector x. TF tensor names
        # end in ":<output index>"; strip that suffix for the new tensor name.
        xs = tuple(
            tensor_utils.new_tensor_like(p.name.split(":")[0], p)
            for p in params)

        def hx_plain():
            # d/dtheta sum_i <g_i, x_i> = H x, since x does not depend on theta.
            grad_dot_x = tf.reduce_sum(
                tf.stack([
                    tf.reduce_sum(g * x)
                    for g, x in zip(constraint_grads, xs)
                ]))
            hx_splits = _zeros_for_none(tf.gradients(grad_dot_x, params))
            return tensor_utils.flatten_tensor_variables(hx_splits)

        # tuple(...) accepts ``inputs`` as any sequence. Concatenating a
        # list directly with the tuple ``xs`` would raise TypeError.
        self.opt_fun = ext.lazydict(
            f_Hx_plain=lambda: tensor_utils.compile_function(
                inputs=tuple(inputs) + xs,
                outputs=hx_plain(),
                log_name="f_Hx_plain",
            ))