Esempio n. 1
0
    def __init__(
        self,
        learning_rate,
        cg_max_iterations=20,
        cg_damping=1e-3,
        cg_unroll_loop=False,
        scope='natural-gradient',
        summary_labels=()
    ):
        """
        Sets up a natural gradient optimizer together with its conjugate gradient solver.

        Args:
            learning_rate: Learning rate, i.e. KL-divergence of distributions between
                optimization steps. Must be strictly positive (checked via assertion).
            cg_max_iterations: Maximum number of iterations of the conjugate gradient solver.
            cg_damping: Damping factor of the conjugate gradient solver.
            cg_unroll_loop: If true, the conjugate gradient loop is unrolled.
            scope: Passed through unchanged to the parent class constructor.
            summary_labels: Passed through unchanged to the parent class constructor.
        """
        # Validate and store the step size before anything else is constructed.
        assert learning_rate > 0.0
        self.learning_rate = learning_rate

        # Collect the solver configuration first, then build the solver from it.
        solver_options = dict(
            max_iterations=cg_max_iterations,
            damping=cg_damping,
            unroll_loop=cg_unroll_loop
        )
        self.solver = ConjugateGradient(**solver_options)

        # The parent constructor runs last, after the optimizer attributes are set.
        parent = super(NaturalGradient, self)
        parent.__init__(scope=scope, summary_labels=summary_labels)
class NaturalGradient(Optimizer):
    """
    Natural gradient optimizer.

    The update direction is obtained by solving a linear system involving the Fisher matrix
    with a conjugate gradient solver; the Fisher matrix itself is never built explicitly,
    only its product with a vector (via gradients of the KL-divergence). The resulting step
    is rescaled so that the estimated KL-divergence of the update equals `learning_rate`.
    """

    def __init__(self, learning_rate, cg_max_iterations=20, cg_damping=1e-3, cg_unroll_loop=False):
        """
        Creates a new natural gradient optimizer instance.

        Args:
            learning_rate: Learning rate, i.e. KL-divergence of distributions between optimization steps.
                Must be strictly positive (checked via assertion).
            cg_max_iterations: Conjugate gradient solver max iterations.
            cg_damping: Conjugate gradient solver damping factor.
            cg_unroll_loop: Unroll conjugate gradient loop if true.
        """
        super(NaturalGradient, self).__init__()

        assert learning_rate > 0.0
        self.learning_rate = learning_rate

        self.solver = ConjugateGradient(
            max_iterations=cg_max_iterations,
            damping=cg_damping,
            unroll_loop=cg_unroll_loop
        )

    def tf_step(self, time, variables, fn_loss, fn_kl_divergence, return_estimated_improvement=False, **kwargs):
        """
        Creates the TensorFlow operations for performing an optimization step.

        Args:
            time: Time tensor (not used by this method).
            variables: List of variables to optimize.
            fn_loss: A callable returning the loss of the current model.
            fn_kl_divergence: A callable returning the KL-divergence relative to the current model.
            return_estimated_improvement: Returns the estimated improvement resulting from the
                natural gradient calculation if true.
            **kwargs: Additional arguments, not used.

        Returns:
            List of delta tensors corresponding to the updates for each optimized variable.
            If `return_estimated_improvement` is true, a tuple `(deltas, estimated_improvement)`
            is returned instead. When the quadratic form `0.5 * delta' * F * delta'` is not
            positive, no update is applied and the deltas (and the improvement) are zero.
        """

        # Optimize: argmin(w) loss(w + delta) such that kldiv(P(w) || P(w + delta)) = learning_rate
        # For more details, see our blogpost: [LINK]

        # tf.gradients yields None for variables which do not influence the differentiated
        # tensor. Such a gradient is zero, so substitute an explicit zero tensor; otherwise the
        # arithmetic below (delta * grad, -grad) fails with a TypeError during graph construction.
        def zero_if_unconnected(gradients):
            return [
                tf.zeros_like(tensor=variable) if gradient is None else gradient
                for gradient, variable in zip(gradients, variables)
            ]

        # grad(kldiv)
        kldiv_gradients = zero_if_unconnected(tf.gradients(ys=fn_kl_divergence(), xs=variables))

        # Calculates the product x * F of a given vector x with the fisher matrix F.
        # Incorporating the product prevents having to actually calculate the entire matrix explicitly.
        def fisher_matrix_product(x):
            # Gradient is not propagated through solver.
            x = [tf.stop_gradient(input=delta) for delta in x]

            # delta' * grad(kldiv)
            x_kldiv_gradients = tf.add_n(inputs=[
                tf.reduce_sum(input_tensor=(delta * grad)) for delta, grad in zip(x, kldiv_gradients)
            ])

            # [delta' * F] = grad(delta' * grad(kldiv))
            return zero_if_unconnected(tf.gradients(ys=x_kldiv_gradients, xs=variables))

        # grad(loss)
        loss_gradients = zero_if_unconnected(tf.gradients(ys=fn_loss(), xs=variables))

        # Solve the following system for delta' via the conjugate gradient solver.
        # [delta' * F] * delta' = -grad(loss)
        # --> delta'  (= lambda * delta)
        deltas = self.solver.solve(fn_x=fisher_matrix_product, x_init=None, b=[-grad for grad in loss_gradients])

        # delta' * F
        delta_fisher_matrix_product = fisher_matrix_product(x=deltas)

        # c' = 0.5 * delta' * F * delta'  (= lambda * c)
        # TODO: Why constant and hence KL-divergence sometimes negative?
        constant = 0.5 * tf.add_n(inputs=[
            tf.reduce_sum(input_tensor=(delta_F * delta))
            for delta_F, delta in zip(delta_fisher_matrix_product, deltas)
        ])

        # Natural gradient step if constant > 0
        def natural_gradient_step():
            # lambda = sqrt(c' / c)
            lagrange_multiplier = tf.sqrt(x=(constant / self.learning_rate))

            # delta = delta' / lambda
            estimated_deltas = [delta / lagrange_multiplier for delta in deltas]

            # improvement = grad(loss) * delta  (= loss_new - loss_old)
            estimated_improvement = tf.add_n(inputs=[
                tf.reduce_sum(input_tensor=(grad * delta))
                for grad, delta in zip(loss_gradients, estimated_deltas)
            ])

            # Apply natural gradient improvement.
            applied = self.apply_step(variables=variables, deltas=estimated_deltas)

            with tf.control_dependencies(control_inputs=(applied,)):
                # Trivial operation to enforce control dependency: the returned tensors are
                # created inside the control-dependency context, so fetching them runs the update.
                if return_estimated_improvement:
                    return [estimated_delta + 0.0 for estimated_delta in estimated_deltas], estimated_improvement
                else:
                    return [estimated_delta + 0.0 for estimated_delta in estimated_deltas]

        # Zero step if constant <= 0
        def zero_step():
            if return_estimated_improvement:
                return [tf.zeros_like(tensor=delta) for delta in deltas], 0.0
            else:
                return [tf.zeros_like(tensor=delta) for delta in deltas]

        # Natural gradient step only works if constant > 0
        # (the square root in the Lagrange multiplier requires a positive argument).
        return tf.cond(pred=(constant > 0.0), true_fn=natural_gradient_step, false_fn=zero_step)