Example #1
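A TF1 graph-mode unit test (assumed to live in a tf.test.TestCase subclass that also defines _softplus_inverse_np) exercising sampling.sample_auxiliary_op at the limits of the auxiliary variance ratio: a ratio near 0 leaves the prior p and posterior q untouched, a ratio near 1 lets the auxiliary variable fully determine the sample and collapses both scales, a near-delta posterior is left fixed, and a posterior equal to the prior stays equal to it.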
    def test_sample_auxiliary_op(self):
        p_fn, q_fn = sampling.mean_field_fn()
        p = p_fn(tf.float32, (), 'test_prior', True,
                 tf.get_variable).distribution
        q = q_fn(tf.float32, (), 'test_posterior', True,
                 tf.get_variable).distribution

        # Test benign auxiliary variable
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 1e-10)
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1.1, session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(0.5),
                                   session=sess)

        sess.run(sample_op)

        tolerance = 0.0001
        self.assertLess(np.abs(sess.run(p.scale) - 1.), tolerance)
        self.assertLess(np.abs(sess.run(p.loc) - 1.), tolerance)
        self.assertLess(np.abs(sess.run(q.scale) - 0.5), tolerance)
        self.assertLess(np.abs(sess.run(q.loc) - 1.1), tolerance)

        # Test fully determining auxiliary variable
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 1. - 1e-10)
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1.1, session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(.5), session=sess)

        sess.run(sample_op)

        self.assertLess(np.abs(sess.run(q.loc) - sess.run(p.loc)), tolerance)
        self.assertLess(sess.run(p.scale), tolerance)
        self.assertLess(sess.run(q.scale), tolerance)

        # Test delta posterior
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 0.5)
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1.1, session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(1e-10),
                                   session=sess)

        sess.run(sample_op)

        self.assertLess(np.abs(sess.run(q.loc) - 1.1), tolerance)
        self.assertLess(sess.run(q.scale), tolerance)

        # Test prior is posterior
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 0.5)
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1., session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)

        sess.run(sample_op)

        self.assertLess(np.abs(sess.run(q.loc - p.loc)), tolerance)
        self.assertLess(np.abs(sess.run(q.scale - p.scale)), tolerance)
Example #2
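A training driver from the same project (FLAGS, datasets, lenet5, res_net, get_losses_and_metrics, ensemble_metrics, and sample_auxiliary_op are assumed to come from sibling modules). It trains a base variational model and saves its weights, trains it further as an overtrained baseline, then performs refined VI: each ensemble member restarts from the base weights and alternates auxiliary-variable sampling with short fine-tuning runs before metrics for all three variants are logged.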
def main(argv):
    del argv  # unused arg
    np.random.seed(FLAGS.seed)
    tf.random.set_seed(FLAGS.seed)
    tf.io.gfile.makedirs(FLAGS.output_dir)
    tf1.disable_v2_behavior()

    session = tf1.Session()
    with session.as_default():
        x_train, y_train, x_test, y_test = datasets.load(session)
        n_train = x_train.shape[0]

        num_classes = int(np.amax(y_train)) + 1
        if not FLAGS.resnet:
            model = lenet5(n_train, x_train.shape[1:], num_classes)
        else:
            datagen = tf.keras.preprocessing.image.ImageDataGenerator(
                rotation_range=90,
                width_shift_range=0.1,
                height_shift_range=0.1,
                horizontal_flip=True)
            datagen.fit(x_train)
            model = res_net(n_train,
                            x_train.shape[1:],
                            num_classes,
                            batchnorm=FLAGS.batchnorm,
                            variational='hybrid' if FLAGS.hybrid else 'full')

            def schedule_fn(epoch):
                """Learning rate schedule function."""
                rate = FLAGS.learning_rate
                if epoch > 180:
                    rate *= 0.5e-3
                elif epoch > 160:
                    rate *= 1e-3
                elif epoch > 120:
                    rate *= 1e-2
                elif epoch > 80:
                    rate *= 1e-1
                return float(rate)

            lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule_fn)

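        # Give every layer non-trainable scalars that accumulate the KL cost
        # introduced by auxiliary sampling (updated via assign_add below).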
        for l in model.layers:
            l.kl_cost_weight = l.add_weight(
                name='kl_cost_weight',
                shape=(),
                initializer=tf.constant_initializer(0.),
                trainable=False)
            l.kl_cost_bias = l.add_weight(
                name='kl_cost_bias',
                shape=(),
                initializer=tf.constant_initializer(0.),
                trainable=False)

        [negative_log_likelihood, accuracy, log_likelihood, kl,
         elbo] = get_losses_and_metrics(model, n_train)

        metrics = [elbo, log_likelihood, kl, accuracy]

        tensorboard = tf1.keras.callbacks.TensorBoard(
            log_dir=FLAGS.output_dir,
            update_freq=FLAGS.batch_size * FLAGS.validation_freq)
        if FLAGS.resnet:
            callbacks = [tensorboard, lr_callback]
        else:
            callbacks = [tensorboard]

        if not FLAGS.resnet or not FLAGS.data_augmentation:

            def fit_fn(model,
                       steps,
                       initial_epoch=0,
                       with_lr_schedule=FLAGS.resnet):
                return model.fit(
                    x=x_train,
                    y=y_train,
                    batch_size=FLAGS.batch_size,
                    epochs=initial_epoch +
                    (FLAGS.batch_size * steps) // n_train,
                    initial_epoch=initial_epoch,
                    validation_data=(x_test, y_test),
                    validation_freq=(
                        (FLAGS.validation_freq * FLAGS.batch_size) // n_train),
                    verbose=1,
                    callbacks=callbacks if with_lr_schedule else [tensorboard])
        else:

            def fit_fn(model,
                       steps,
                       initial_epoch=0,
                       with_lr_schedule=FLAGS.resnet):
                return model.fit_generator(
                    datagen.flow(x_train, y_train,
                                 batch_size=FLAGS.batch_size),
                    epochs=initial_epoch +
                    (FLAGS.batch_size * steps) // n_train,
                    initial_epoch=initial_epoch,
                    steps_per_epoch=n_train // FLAGS.batch_size,
                    validation_data=(x_test, y_test),
                    validation_freq=max(
                        (FLAGS.validation_freq * FLAGS.batch_size) // n_train,
                        1),
                    verbose=1,
                    callbacks=callbacks if with_lr_schedule else [tensorboard])

        model.compile(
            optimizer=tf.keras.optimizers.Adam(
                learning_rate=float(FLAGS.learning_rate)),
            loss=negative_log_likelihood,
            metrics=metrics)
        session.run(tf1.global_variables_initializer())

        train_epochs = (FLAGS.training_steps * FLAGS.batch_size) // n_train
        fit_fn(model, FLAGS.training_steps)

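        # Wrap the model in a backend function that returns per-example
        # log-likelihoods and logits for a batch of (inputs, labels).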
        labels = tf.keras.layers.Input(shape=y_train.shape[1:])
        ll = tf.keras.backend.function([model.input, labels], [
            model.output.distribution.log_prob(tf.squeeze(labels)),
            model.output.distribution.logits
        ])

        base_metrics = [
            ensemble_metrics(x_train, y_train, model, ll),
            ensemble_metrics(x_test, y_test, model, ll)
        ]
        model_dir = os.path.join(FLAGS.output_dir, 'models')
        tf.io.gfile.makedirs(model_dir)
        base_model_filename = os.path.join(model_dir, 'base_model.weights')
        model.save_weights(base_model_filename)

        # Train base model further for comparison.
        fit_fn(model,
               FLAGS.n_auxiliary_variables *
               FLAGS.auxiliary_sampling_frequency * FLAGS.ensemble_size,
               initial_epoch=train_epochs)

        overtrained_metrics = [
            ensemble_metrics(x_train, y_train, model, ll),
            ensemble_metrics(x_test, y_test, model, ll)
        ]

        # Perform refined VI.
        sample_op = []
        for l in model.layers:
            if isinstance(l, (tfp.layers.DenseLocalReparameterization,
                              tfp.layers.Convolution2DFlipout)):
                weight_op, weight_cost = sample_auxiliary_op(
                    l.kernel_prior.distribution,
                    l.kernel_posterior.distribution,
                    FLAGS.auxiliary_variance_ratio)
                sample_op.append(weight_op)
                sample_op.append(l.kl_cost_weight.assign_add(weight_cost))
                # Fix the variance of the prior
                session.run(l.kernel_prior.distribution.istrainable.assign(0.))
                if hasattr(l.bias_prior, 'distribution'):
                    bias_op, bias_cost = sample_auxiliary_op(
                        l.bias_prior.distribution,
                        l.bias_posterior.distribution,
                        FLAGS.auxiliary_variance_ratio)
                    sample_op.append(bias_op)
                    sample_op.append(l.kl_cost_bias.assign_add(bias_cost))
                    # Fix the variance of the prior
                    session.run(
                        l.bias_prior.distribution.istrainable.assign(0.))

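        # Build the ensemble: restart each member from the base weights, then
        # alternate auxiliary-variable sampling with short fine-tuning runs at
        # a decayed learning rate.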
        ensemble_filenames = []
        for i in range(FLAGS.ensemble_size):
            model.load_weights(base_model_filename)
            for j in range(FLAGS.n_auxiliary_variables):
                session.run(sample_op)
                model.compile(
                    optimizer=tf.keras.optimizers.Adam(
                        # The learning rate is proportional to the scale of the prior.
                        learning_rate=float(
                            FLAGS.learning_rate_for_sampling *
                            np.sqrt(1. - FLAGS.auxiliary_variance_ratio)**j)),
                    loss=negative_log_likelihood,
                    metrics=metrics)
                fit_fn(model,
                       FLAGS.auxiliary_sampling_frequency,
                       initial_epoch=train_epochs,
                       with_lr_schedule=False)
            ensemble_filename = os.path.join(
                model_dir, 'ensemble_component_' + str(i) + '.weights')
            ensemble_filenames.append(ensemble_filename)
            model.save_weights(ensemble_filename)

        auxiliary_metrics = [
            ensemble_metrics(x_train,
                             y_train,
                             model,
                             ll,
                             weight_files=ensemble_filenames),
            ensemble_metrics(x_test,
                             y_test,
                             model,
                             ll,
                             weight_files=ensemble_filenames)
        ]

        for results, name in [(base_metrics, 'Base model'),
                              (overtrained_metrics, 'Overtrained model'),
                              (auxiliary_metrics, 'Auxiliary sampling')]:
            logging.info(name)
            for metrics_dict, split in [(results[0], 'Training'),
                                        (results[1], 'Testing')]:
                logging.info(split)
                for metric_name in metrics_dict:
                    logging.info('%s: %s', metric_name,
                                 metrics_dict[metric_name])