def test_sample_auxiliary_op(self):
  p_fn, q_fn = sampling.mean_field_fn()
  p = p_fn(tf.float32, (), 'test_prior', True, tf.get_variable).distribution
  q = q_fn(tf.float32, (), 'test_posterior', True,
           tf.get_variable).distribution

  # Test benign auxiliary variable.
  sample_op, _ = sampling.sample_auxiliary_op(p, q, 1e-10)
  sess = tf.Session()
  sess.run(tf.initialize_all_variables())

  p.loc.load(1., session=sess)
  p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
  q.loc.load(1.1, session=sess)
  q.untransformed_scale.load(self._softplus_inverse_np(0.5), session=sess)
  sess.run(sample_op)

  tolerance = 0.0001
  self.assertLess(np.abs(sess.run(p.scale) - 1.), tolerance)
  self.assertLess(np.abs(sess.run(p.loc) - 1.), tolerance)
  self.assertLess(np.abs(sess.run(q.scale) - 0.5), tolerance)
  self.assertLess(np.abs(sess.run(q.loc) - 1.1), tolerance)

  # Test fully determining auxiliary variable.
  sample_op, _ = sampling.sample_auxiliary_op(p, q, 1. - 1e-10)
  sess.run(tf.initialize_all_variables())

  p.loc.load(1., session=sess)
  p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
  q.loc.load(1.1, session=sess)
  q.untransformed_scale.load(self._softplus_inverse_np(.5), session=sess)
  sess.run(sample_op)

  self.assertLess(np.abs(sess.run(q.loc) - sess.run(p.loc)), tolerance)
  self.assertLess(sess.run(p.scale), tolerance)
  self.assertLess(sess.run(q.scale), tolerance)

  # Test delta posterior.
  sample_op, _ = sampling.sample_auxiliary_op(p, q, 0.5)
  sess.run(tf.initialize_all_variables())

  p.loc.load(1., session=sess)
  p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
  q.loc.load(1.1, session=sess)
  q.untransformed_scale.load(self._softplus_inverse_np(1e-10), session=sess)
  sess.run(sample_op)

  self.assertLess(np.abs(sess.run(q.loc) - 1.1), tolerance)
  self.assertLess(sess.run(q.scale), tolerance)

  # Test prior is posterior.
  sample_op, _ = sampling.sample_auxiliary_op(p, q, 0.5)
  sess.run(tf.initialize_all_variables())

  p.loc.load(1., session=sess)
  p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
  q.loc.load(1., session=sess)
  q.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
  sess.run(sample_op)

  self.assertLess(np.abs(sess.run(q.loc - p.loc)), tolerance)
  self.assertLess(np.abs(sess.run(q.scale - p.scale)), tolerance)
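# A minimal sketch of the `_softplus_inverse_np` helper the test relies on; it
# is not part of this excerpt, so the body below is an assumption. The idea is
# that loading softplus_inverse(s) into `untransformed_scale` makes the
# distribution's `scale` equal to s.
def _softplus_inverse_np(self, x):
  # softplus(y) = log(1 + exp(y))  =>  y = log(exp(x) - 1) = log(expm1(x)).
  return np.log(np.expm1(x))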
def main(argv):
  del argv  # unused arg
  np.random.seed(FLAGS.seed)
  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)
  tf1.disable_v2_behavior()
  session = tf1.Session()
  with session.as_default():
    x_train, y_train, x_test, y_test = datasets.load(session)
    n_train = x_train.shape[0]
    num_classes = int(np.amax(y_train)) + 1

    if not FLAGS.resnet:
      model = lenet5(n_train, x_train.shape[1:], num_classes)
    else:
      datagen = tf.keras.preprocessing.image.ImageDataGenerator(
          rotation_range=90,
          width_shift_range=0.1,
          height_shift_range=0.1,
          horizontal_flip=True)
      datagen.fit(x_train)
      model = res_net(n_train,
                      x_train.shape[1:],
                      num_classes,
                      batchnorm=FLAGS.batchnorm,
                      variational='hybrid' if FLAGS.hybrid else 'full')

      def schedule_fn(epoch):
        """Learning rate schedule function."""
        rate = FLAGS.learning_rate
        if epoch > 180:
          rate *= 0.5e-3
        elif epoch > 160:
          rate *= 1e-3
        elif epoch > 120:
          rate *= 1e-2
        elif epoch > 80:
          rate *= 1e-1
        return float(rate)

      lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule_fn)

    for l in model.layers:
      l.kl_cost_weight = l.add_weight(
          name='kl_cost_weight',
          shape=(),
          initializer=tf.constant_initializer(0.),
          trainable=False)
      l.kl_cost_bias = l.add_weight(
          name='kl_cost_bias',
          shape=(),
          initializer=tf.constant_initializer(0.),
          trainable=False)

    [negative_log_likelihood, accuracy, log_likelihood, kl,
     elbo] = get_losses_and_metrics(model, n_train)
    metrics = [elbo, log_likelihood, kl, accuracy]

    tensorboard = tf1.keras.callbacks.TensorBoard(
        log_dir=FLAGS.output_dir,
        update_freq=FLAGS.batch_size * FLAGS.validation_freq)
    if FLAGS.resnet:
      callbacks = [tensorboard, lr_callback]
    else:
      callbacks = [tensorboard]

    if not FLAGS.resnet or not FLAGS.data_augmentation:

      def fit_fn(model,
                 steps,
                 initial_epoch=0,
                 with_lr_schedule=FLAGS.resnet):
        return model.fit(
            x=x_train,
            y=y_train,
            batch_size=FLAGS.batch_size,
            epochs=initial_epoch + (FLAGS.batch_size * steps) // n_train,
            initial_epoch=initial_epoch,
            validation_data=(x_test, y_test),
            validation_freq=(
                (FLAGS.validation_freq * FLAGS.batch_size) // n_train),
            verbose=1,
            callbacks=callbacks if with_lr_schedule else [tensorboard])

    else:

      def fit_fn(model,
                 steps,
                 initial_epoch=0,
                 with_lr_schedule=FLAGS.resnet):
        return model.fit_generator(
            datagen.flow(x_train, y_train, batch_size=FLAGS.batch_size),
            epochs=initial_epoch + (FLAGS.batch_size * steps) // n_train,
            initial_epoch=initial_epoch,
            steps_per_epoch=n_train // FLAGS.batch_size,
            validation_data=(x_test, y_test),
            validation_freq=max(
                (FLAGS.validation_freq * FLAGS.batch_size) // n_train, 1),
            verbose=1,
            callbacks=callbacks if with_lr_schedule else [tensorboard])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(lr=float(FLAGS.learning_rate)),
        loss=negative_log_likelihood,
        metrics=metrics)
    session.run(tf1.initialize_all_variables())

    train_epochs = (FLAGS.training_steps * FLAGS.batch_size) // n_train
    fit_fn(model, FLAGS.training_steps)

    labels = tf.keras.layers.Input(shape=y_train.shape[1:])
    ll = tf.keras.backend.function([model.input, labels], [
        model.output.distribution.log_prob(tf.squeeze(labels)),
        model.output.distribution.logits
    ])

    base_metrics = [
        ensemble_metrics(x_train, y_train, model, ll),
        ensemble_metrics(x_test, y_test, model, ll)
    ]
    model_dir = os.path.join(FLAGS.output_dir, 'models')
    tf.io.gfile.makedirs(model_dir)
    base_model_filename = os.path.join(model_dir, 'base_model.weights')
    model.save_weights(base_model_filename)

    # Train base model further for comparison.
    fit_fn(model,
           FLAGS.n_auxiliary_variables * FLAGS.auxiliary_sampling_frequency *
           FLAGS.ensemble_size,
           initial_epoch=train_epochs)

    overtrained_metrics = [
        ensemble_metrics(x_train, y_train, model, ll),
        ensemble_metrics(x_test, y_test, model, ll)
    ]

    # Perform refined VI.
    sample_op = []
    for l in model.layers:
      if isinstance(l, tfp.layers.DenseLocalReparameterization) or isinstance(
          l, tfp.layers.Convolution2DFlipout):
        weight_op, weight_cost = sample_auxiliary_op(
            l.kernel_prior.distribution, l.kernel_posterior.distribution,
            FLAGS.auxiliary_variance_ratio)
        sample_op.append(weight_op)
        sample_op.append(l.kl_cost_weight.assign_add(weight_cost))
        # Fix the variance of the prior.
        session.run(l.kernel_prior.distribution.istrainable.assign(0.))

        if hasattr(l.bias_prior, 'distribution'):
          bias_op, bias_cost = sample_auxiliary_op(
              l.bias_prior.distribution, l.bias_posterior.distribution,
              FLAGS.auxiliary_variance_ratio)
          sample_op.append(bias_op)
          sample_op.append(l.kl_cost_bias.assign_add(bias_cost))
          # Fix the variance of the prior.
          session.run(l.bias_prior.distribution.istrainable.assign(0.))

    ensemble_filenames = []
    for i in range(FLAGS.ensemble_size):
      model.load_weights(base_model_filename)
      for j in range(FLAGS.n_auxiliary_variables):
        session.run(sample_op)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(
                # The learning rate is proportional to the scale of the prior.
                lr=float(FLAGS.learning_rate_for_sampling *
                         np.sqrt(1. - FLAGS.auxiliary_variance_ratio)**j)),
            loss=negative_log_likelihood,
            metrics=metrics)
        fit_fn(model,
               FLAGS.auxiliary_sampling_frequency,
               initial_epoch=train_epochs,
               with_lr_schedule=False)
      ensemble_filename = os.path.join(
          model_dir, 'ensemble_component_' + str(i) + '.weights')
      ensemble_filenames.append(ensemble_filename)
      model.save_weights(ensemble_filename)

    auxiliary_metrics = [
        ensemble_metrics(
            x_train, y_train, model, ll, weight_files=ensemble_filenames),
        ensemble_metrics(
            x_test, y_test, model, ll, weight_files=ensemble_filenames)
    ]

    for metrics, name in [(base_metrics, 'Base model'),
                          (overtrained_metrics, 'Overtrained model'),
                          (auxiliary_metrics, 'Auxiliary sampling')]:
      logging.info(name)
      for metrics_dict, split in [(metrics[0], 'Training'),
                                  (metrics[1], 'Testing')]:
        logging.info(split)
        for metric_name in metrics_dict:
          logging.info('%s: %s', metric_name, metrics_dict[metric_name])
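# A minimal entry-point sketch, assuming the absl pattern implied by the
# FLAGS/logging usage above; the excerpt's flag definitions and imports are
# not shown here.
if __name__ == '__main__':
  app.run(main)  # `app` is absl.app, assumed to be imported at module level.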