Example 1
    def test_empirical_fisher_should_regularize_changed_regularizer(self):
        """Asserts regularizer correctly changed by `should_regularize` function."""
        labels = tf.constant([[1.0, 0.0]])

        def make_logits():
            a = tf.get_variable("a", initializer=tf.constant(1.))
            b = tf.get_variable("b", initializer=tf.constant(1.))
            l = tf.multiply(a, b)
            return tf.stack([[l, tf.subtract(1., l)]])

        _, regularizer_b = dfr.make_empirical_fisher_regularizer(
            make_logits,
            labels,
            "test_scope_should",
            lambda name: "b" in name,
            # Note: for the `"b" in name` check to work as intended, the
            # scope name must not contain the letter "b".
            self.perturbation)

        _, regularizer = dfr.make_empirical_fisher_regularizer(
            make_logits, labels, "test_scope", lambda name: True,
            self.perturbation)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            self.assertNotEqual(sess.run(regularizer_b), sess.run(regularizer))
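The `should_regularize` predicate is applied to scope-qualified variable names, which is why the comment above warns against a "b" anywhere in the scope name. A minimal sketch of the naming behavior (assuming `dfr` builds its variables under an ordinary `tf.variable_scope`, as these tests suggest):

    import tensorflow as tf

    with tf.variable_scope("test_scope_should"):
        a = tf.get_variable("a", initializer=tf.constant(1.))
        b = tf.get_variable("b", initializer=tf.constant(1.))

    should_regularize = lambda name: "b" in name
    print(a.name, should_regularize(a.name))  # test_scope_should/a:0 False
    print(b.name, should_regularize(b.name))  # test_scope_should/b:0 True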
Example 2
    def test_empirical_fisher_batch(self):
        """Asserts sum property of regularizer gradient for sum reduction loss."""
        labels_full_batch = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5],
                             [0.5, 0.5]]
        labels_one_batch = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
        labels_two_batch = [[0.5, 0.5], [0.5, 0.5]]

        x_full_batch = np.array([1., 2., 3., 4., 5.])
        x_part_one_batch = np.array([1., 2., 3.])
        x_part_two_batch = np.array([4., 5.])

        def make_make_logits_part(x_batch):
            def make_logits():
                v = tf.get_variable("v", initializer=tf.constant(2.5))
                y = tf.multiply(v, x_batch)
                return tf.transpose(tf.stack([tf.sin(y), tf.cos(y)]))

            return make_logits

        _, regularizer_full = dfr.make_empirical_fisher_regularizer(
            make_make_logits_part(x_full_batch), labels_full_batch,
            "test_scope", lambda name: True, self.perturbation)

        _, regularizer_part_one = dfr.make_empirical_fisher_regularizer(
            make_make_logits_part(x_part_one_batch), labels_one_batch,
            "test_scope_part_one", lambda name: True, self.perturbation)

        _, regularizer_part_two = dfr.make_empirical_fisher_regularizer(
            make_make_logits_part(x_part_two_batch), labels_two_batch,
            "test_scope_part_two", lambda name: True, self.perturbation)

        with tf.variable_scope("test_scope", reuse=True):
            gradient_full = tf.gradients(regularizer_full,
                                         tf.get_variable("v"))[0]

        with tf.variable_scope("test_scope_part_one", reuse=True):
            gradient_part_one = tf.gradients(regularizer_part_one,
                                             tf.get_variable("v"))[0]

        with tf.variable_scope("test_scope_part_two", reuse=True):
            gradient_part_two = tf.gradients(regularizer_part_two,
                                             tf.get_variable("v"))[0]

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            self.assertAllClose(5 * sess.run(gradient_full),
                                3 * sess.run(gradient_part_one) +
                                2 * sess.run(gradient_part_two),
                                rtol=self.rtol)
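The 5/3/2 weighting is just linearity of the mean: if the regularizer averages per-example terms, then batch_size times the full-batch mean equals the sum over examples, which splits into the two partial sums, and gradients inherit the identity. A quick numpy check with made-up per-example gradient values:

    import numpy as np

    # Hypothetical per-example regularizer gradients for a batch of 5.
    g = np.array([0.3, -1.2, 0.7, 2.0, -0.4])

    lhs = 5 * g.mean()                           # batch_size * full-batch mean
    rhs = 3 * g[:3].mean() + 2 * g[3:].mean()    # weighted partial means
    assert np.isclose(lhs, rhs)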
Example 3
    def test_empirical_fisher_should_regularize_unchanged_loss(self):
        """Asserts unregularized loss unchanged by `should_regularize` function."""
        labels = tf.constant([[1.0, 0.0]])

        def make_logits():
            l = tf.get_variable("a", initializer=tf.constant(1.))
            return tf.stack([[l, tf.subtract(1., l)]])

        loss_true, _ = dfr.make_empirical_fisher_regularizer(
            make_logits, labels, "test_scope", lambda name: True,
            self.perturbation)

        loss_false, _ = dfr.make_empirical_fisher_regularizer(
            make_logits, labels, "test_scope_2", lambda name: False,
            self.perturbation)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            self.assertAllEqual(sess.run(loss_true), sess.run(loss_false))
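Both calls should return the same unregularized loss. Assuming that first tensor is the plain softmax cross-entropy (which its use in the training script of Example 7 suggests), the expected value can be checked by hand:

    import numpy as np

    # With a = 1.0 the logits are [1.0, 0.0] and the label is class 0.
    logits = np.array([1.0, 0.0])
    log_softmax = logits - np.log(np.sum(np.exp(logits)))
    loss = -log_softmax[0]  # ~0.3133, independent of `should_regularize`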
Example 4
    def test_empirical_fisher_constant_loss_regularizer(self):
        """Asserts regularizer for loss without variables evaluates to zero."""
        labels = tf.constant([[1.0, 0.0]])
        make_logits = lambda: tf.constant([[0.5, 0.5]])
        _, regularizer = dfr.make_empirical_fisher_regularizer(
            make_logits, labels, "test_scope", lambda name: True,
            self.perturbation)

        with self.test_session() as sess:
            self.assertAllEqual(sess.run(regularizer), 0.)
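The zero follows from the logits depending on no trainable variables. The actual `dfr` computation presumably perturbs variables by the finite `perturbation` argument, but a simplified squared-gradient-norm stand-in (hypothetical, not the library's implementation) makes the vanishing explicit:

    import tensorflow as tf

    def squared_gradient_norm(value, variables):
        # Sum of squared gradients of `value` w.r.t. `variables`;
        # trivially zero when no variable influences the value.
        grads = tf.gradients(value, variables)
        terms = [tf.reduce_sum(tf.square(g)) for g in grads if g is not None]
        return tf.add_n(terms) if terms else tf.constant(0.)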
Example 5
    def make_two_vars_product_loss_and_regularizer(self):
        """Helper that creates Tensors for a * b loss and its regularizer."""
        def make_loss():
            a = tf.get_variable("a", initializer=tf.constant(2.))
            b = tf.get_variable("b", initializer=tf.constant(3.))
            return tf.multiply(a, b)

        return dfr.make_empirical_fisher_regularizer(make_loss, "test_scope",
                                                     lambda name: True,
                                                     self.perturbation)
Example 6
    def make_empirical_fisher_sin_logits_and_regularizer(self):
        """Helper that creates Tensors for sin(x) logits and the regularizer."""
        labels = tf.constant([[1.0, 0.0]])

        def make_logits():
            x = tf.get_variable("x", initializer=tf.constant(2.))
            y = tf.get_variable("y", initializer=tf.constant(3.))
            return tf.stack([[tf.sin(x), tf.sin(y)]])

        return dfr.make_empirical_fisher_regularizer(make_logits, labels,
                                                     "test_scope",
                                                     lambda name: True,
                                                     self.perturbation)
Example 7
def main(unused_argv):
    del unused_argv

    tf.set_random_seed(FLAGS.random_seed)
    sess = tf.Session()

    # Build datasets. Materialize the random draws once: every sess.run of a
    # tf.random_normal op samples fresh values, so re-evaluating `inputs` for
    # each dataset below would decorrelate the inputs from their labels.
    inputs = tf.random_normal(
        (FLAGS.train_size + FLAGS.test_size, FLAGS.input_dimension))
    weights = tf.random_normal((FLAGS.input_dimension, 2))
    inputs_values = sess.run(inputs)
    logits = tf.matmul(inputs_values, weights)
    distribution = tf.distributions.Categorical(logits=sess.run(logits))
    labels = tf.one_hot(distribution.sample(), 2)
    labels_values = sess.run(labels)

    train_data = tf.data.Dataset.from_tensor_slices(
        (inputs_values, labels_values))
    train_data = train_data.take(FLAGS.train_size)
    train_data = train_data.cache()
    train_data = train_data.repeat()

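    # Two pipelines over the same training pool: the batched iterator drives
    # the optimization steps, the full-pool iterator feeds evaluation.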
    train_data_batch = train_data.batch(FLAGS.batch_size)
    train_data_batch_next = (
        train_data_batch.make_one_shot_iterator().get_next())

    train_data = train_data.batch(FLAGS.train_size)
    train_data_next = train_data.make_one_shot_iterator().get_next()

    test_data = tf.data.Dataset.from_tensor_slices(
        (inputs_values, labels_values))
    test_data = test_data.skip(FLAGS.train_size)
    test_data = test_data.take(FLAGS.test_size)
    test_data = test_data.cache()
    test_data = test_data.repeat()

    test_data = test_data.batch(FLAGS.test_size)
    test_data_next = test_data.make_one_shot_iterator().get_next()

    # Build graph
    input_batch = tf.placeholder(tf.float32,
                                 shape=(None, FLAGS.input_dimension))
    label_batch = tf.placeholder(tf.float32, shape=(None, 2))

    def make_logits():
        """Builds fully connected ReLU neural network model and returns logits."""
        input_layer = tf.layers.dense(inputs=input_batch,
                                      units=FLAGS.layer_width,
                                      activation=tf.nn.relu)
        previous_layer = input_layer

        for _ in range(FLAGS.num_hidden_layers):
            layer = tf.layers.dense(inputs=previous_layer,
                                    units=FLAGS.layer_width,
                                    activation=tf.nn.relu)
            previous_layer = layer

        logits = tf.layers.dense(inputs=previous_layer, units=2)

        return logits

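    # Select the penalty: "None" trains unregularized, "fre" builds the
    # empirical Fisher regularizer (which needs labels), "fr" the standard
    # Fisher regularizer.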
    if FLAGS.regularization_norm == "None":
        with tf.variable_scope("regularizer_scope"):
            logits = make_logits()
        regularizer = tf.constant(0.)

    elif FLAGS.regularization_norm == "fre":
        logits, regularizer = dfr.make_empirical_fisher_regularizer(
            make_logits, label_batch, "regularizer_scope", lambda name: True,
            1e-4)

    elif FLAGS.regularization_norm == "fr":
        logits, regularizer = dfr.make_standard_fisher_regularizer(
            make_logits, "regularizer_scope", lambda name: True, 1e-4,
            FLAGS.differentiate_probability)

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=label_batch,
                                                   logits=logits))
    total_loss = loss + FLAGS.fisher_rao_lambda * regularizer

    if FLAGS.optimizer == "sgd":
        optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)

    elif FLAGS.optimizer == "adam":
        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)

    train = optimizer.minimize(total_loss)

    accuracy = accuracy_reduction(labels=label_batch,
                                  predictions=tf.nn.softmax(logits))

    train_loss_trajectory = []
    train_accuracy_trajectory = []

    test_loss_trajectory = []
    test_accuracy_trajectory = []

    regularizer_trajectory = []

    sess.run(tf.global_variables_initializer())

    # Optimization loop
    for i in range(FLAGS.train_steps):
        if i % FLAGS.eval_steps == 0:
            tf.logging.info("iter " + str(i) + " / " + str(FLAGS.train_steps))

            # Train loss and accuracy
            train_sample = sess.run(train_data_next)

            # `loss` is already a scalar mean; avoid building new reduce ops
            # in the graph on every evaluation step.
            train_loss = sess.run(loss,
                                  feed_dict={
                                      input_batch: train_sample[0],
                                      label_batch: train_sample[1]
                                  })
            train_loss_trajectory.append(train_loss)

            train_accuracy = sess.run(accuracy,
                                      feed_dict={
                                          input_batch: train_sample[0],
                                          label_batch: train_sample[1]
                                      })
            train_accuracy_trajectory.append(train_accuracy)

            tf.logging.info("train loss " + str(train_loss))
            tf.logging.info("train accuracy " + str(train_accuracy))

            # Test loss and accuracy
            test_sample = sess.run(test_data_next)

            test_loss = sess.run(loss,
                                 feed_dict={
                                     input_batch: test_sample[0],
                                     label_batch: test_sample[1]
                                 })
            test_loss_trajectory.append(test_loss)

            test_accuracy = sess.run(accuracy,
                                     feed_dict={
                                         input_batch: test_sample[0],
                                         label_batch: test_sample[1]
                                     })
            test_accuracy_trajectory.append(test_accuracy)

            tf.logging.info("test loss " + str(test_loss))
            tf.logging.info("test accuracy " + str(test_accuracy))

        batch_sample = sess.run(train_data_batch_next)

        regularizer_loss, _ = sess.run((regularizer, train),
                                       feed_dict={
                                           input_batch: batch_sample[0],
                                           label_batch: batch_sample[1]
                                       })
        regularizer_trajectory.append(regularizer_loss)

    output_filename = FLAGS.output_dir + make_output_filename()
    with tf.gfile.Open(output_filename, "w") as output_file:
        output_file.write("{\n")
        output_file.write("\"hparams\" : " + make_output_flags_string() +
                          ",\n")
        output_file.write("\"train_loss\" : " + str(train_loss_trajectory) +
                          ",\n")
        output_file.write(
            ("\"train_accuracy\" : " + str(train_accuracy_trajectory) + ",\n"))
        output_file.write("\"test_loss\" : " + str(test_loss_trajectory) +
                          ",\n")
        output_file.write(
            ("\"test_accuracy\" : " + str(test_accuracy_trajectory) + ",\n"))
        output_file.write(
            ("\"regularizer_loss\" : " + str(regularizer_trajectory) + "\n"))
        output_file.write("}\n")