Example #1
0
    def test_fixed_preconditioner(self):
        """Moments must be rescaled consistently on preconditioner updates."""
        prob_model = sgmcmc_testlib.Normal2D(noise_sigma=1.0)
        model = tf.keras.Sequential([prob_model])
        model.build(input_shape=(1, 1))
        variable = model.trainable_variables[0]

        optimizer = sgmcmc.NaiveSymplecticEulerMCMC(
            total_sample_size=1,
            learning_rate=0.01,
            momentum_decay=0.7,
            preconditioner='fixed')

        # Start from an identity preconditioner and sample.
        optimizer.set_preconditioner_dict({variable.name: 1.0},
                                          model.trainable_variables)
        sgmcmc_testlib.sample_model(model, optimizer, 2000)
        self._check_kinetic_temperature_regions(model, optimizer)

        # Change the preconditioner, capturing moments before and after.
        moments_before = tf.identity(optimizer.get_slot(variable, 'moments'))
        optimizer.set_preconditioner_dict({variable.name: 100.0},
                                          model.trainable_variables)
        moments_after = tf.identity(optimizer.get_slot(variable, 'moments'))

        # The update must rescale moments by sqrt(new/old) so that kinetic
        # temperatures stay in the valid regions.
        expected_moments = tf.sqrt(100.0 / 1.0) * moments_before
        self.assertAllClose(
            moments_after,
            expected_moments,
            msg='Moments not adjusted on preconditioner update.')
        self._check_kinetic_temperature_regions(model, optimizer)

        # Sampling with the new preconditioner must keep temperatures ok.
        sgmcmc_testlib.sample_model(model, optimizer, 2000)
        self._check_kinetic_temperature_regions(model, optimizer)
Example #2
0
    def test_timestep_factor(self):
        """Halving timestep_factor should halve ESS without hurting KL."""
        prob_model = sgmcmc_testlib.Normal2D(correlation=0.99,
                                             noise_sigma=0.25,
                                             uniform_noise=True)
        model = tf.keras.Sequential([prob_model])
        optimizer = sgmcmc.NaiveSymplecticEulerMCMC(total_sample_size=1,
                                                    learning_rate=0.01,
                                                    momentum_decay=0.9,
                                                    timestep_factor=1.0)

        nburnin = 4096
        nsamples = 262144

        # Accuracy (KL) must not degrade and sampling efficiency (ESS) must
        # drop by roughly half each time timestep_factor is halved.
        prev_kl = None
        prev_efficiency = None
        for factor in (1.0, 0.5, 0.25, 0.125):
            optimizer.timestep_factor.assign(factor)
            # Burn-in draws are discarded; only the second run is analyzed.
            sgmcmc_testlib.sample_model(model, optimizer, nburnin)
            samples = sgmcmc_testlib.sample_model(model, optimizer, nsamples)

            mean = np.mean(samples, axis=1)
            cov = np.cov(samples)
            _, efficiencies = sgmcmc_testlib.compute_ess_multidimensional(
                samples)
            efficiency = np.min(efficiencies)
            kl = self._kl2d(prob_model, mean, cov)

            config = optimizer.get_config()
            name = config['name']
            lr = config['learning_rate']
            momentum_decay = config.get('momentum_decay', -1.0)
            dlr, dmomentum_decay = optimizer.dynamics_parameters(tf.float32)
            dlr = float(dlr)
            dmomentum_decay = float(dmomentum_decay)

            logging.info(
                '%s(lr=%.4f, momentum_decay=%.4f, timestep_factor=%.4f) => '
                '(dlr=%.4f, dmomentum_decay=%.4f) => '
                'mean (%.4f,%.4f)  kl %.4f  eff %.5f', name, lr,
                momentum_decay, factor, dlr, dmomentum_decay, mean[0],
                mean[1], kl, efficiency)

            # Compare against the previous (larger) timestep_factor.
            if prev_kl is not None:
                self.assertLess(
                    kl,
                    prev_kl + 0.006,
                    msg='Decreasing timestep_factor to %.4f increased KL '
                    'from %.5f to %.5f' % (factor, prev_kl, kl))
                self.assertAlmostEqual(
                    efficiency / prev_efficiency,
                    0.5,
                    delta=0.05,
                    msg='Decreasing timestep_factor to %.4f '
                    'produced efficiency %.5f, compared to previous '
                    'efficiency of %.5f, but ratio %.5f not close '
                    'to 1/2.' % (factor, efficiency, prev_efficiency,
                                 efficiency / prev_efficiency))

            prev_kl = kl
            prev_efficiency = efficiency
Example #3
0
def main(argv):
  """Train a model with SG-MCMC and evaluate the resulting ensemble.

  Loads the dataset selected by --dataset, builds the model selected by
  --model with proper priors, runs cyclical SG-MCMC sampling, evaluates the
  collected ensemble on the test set, and optionally appends the experiment
  metadata to a shared sweeps CSV in the parent directory.

  Args:
    argv: Unused command-line arguments (required by absl.app.run).

  Raises:
    ValueError: if --dataset, --pfac, --model or --method has an unknown
      value, or if cycle_start_sampling < cycle_length.
  """
  del argv  # unused arg

  tf.io.gfile.makedirs(FLAGS.output_dir)

  # Load data.  A fixed seed ensures every run sees the same subsample.
  tf.random.set_seed(DATASET_SEED)

  if FLAGS.dataset == 'cifar10':
    dataset_train, ds_info = datasets.load_cifar10(
        tfds.Split.TRAIN, with_info=True,
        data_augmentation=FLAGS.cifar_data_augmentation,
        subsample_n=FLAGS.subsample_train_size)
    dataset_test = datasets.load_cifar10(tfds.Split.TEST)
    logging.info('CIFAR10 dataset loaded.')

  elif FLAGS.dataset == 'imdb':
    dataset, ds_info = datasets.load_imdb(
        with_info=True, subsample_n=FLAGS.subsample_train_size)
    dataset_train, dataset_test = datasets.get_generators_from_ds(dataset)
    logging.info('IMDB dataset loaded.')

  else:
    raise ValueError('Unknown dataset {}.'.format(FLAGS.dataset))

  # Prepare data for SG-MCMC methods
  dataset_size = ds_info['train_num_examples']
  dataset_size_orig = ds_info.get('train_num_examples_orig', dataset_size)
  dataset_train = dataset_train.repeat().shuffle(10 * FLAGS.batch_size).batch(
      FLAGS.batch_size)
  test_batch_size = 100
  validation_steps = ds_info['test_num_examples'] // test_batch_size
  dataset_test_single = dataset_test.batch(FLAGS.batch_size)
  dataset_test = dataset_test.repeat().batch(test_batch_size)

  # If --pretend_batch_size flag is provided any cycle/epoch-length computation
  # is done using this pretend_batch_size.  Real batches are all still
  # FLAGS.batch_size of length.  This feature is used in the batch size ablation
  # study.
  #
  # Also, always determine number of iterations from original data set size
  if FLAGS.pretend_batch_size >= 1:
    steps_per_epoch = dataset_size_orig // FLAGS.pretend_batch_size
  else:
    steps_per_epoch = dataset_size_orig // FLAGS.batch_size

  # Set seed for the experiment
  tf.random.set_seed(FLAGS.seed)

  # Build model using pfac for proper priors.  The prior weight 1/N turns the
  # regularizer into a per-example prior contribution for SG-MCMC.
  reg_weight = 1.0 / float(dataset_size)
  if FLAGS.pfac == 'default':
    pfac = priorfactory.DefaultPriorFactory(weight=reg_weight)
  elif FLAGS.pfac == 'gaussian':
    pfac = priorfactory.GaussianPriorFactory(prior_stddev=1.0,
                                             weight=reg_weight)
  else:
    raise ValueError('Choose pfac from: default, gaussian.')

  input_shape = ds_info['input_shape']

  if FLAGS.model == 'cnnlstm':
    assert FLAGS.dataset == 'imdb'
    model = models.build_cnnlstm(ds_info['num_words'],
                                 ds_info['sequence_length'],
                                 pfac)

  elif FLAGS.model == 'resnet':
    assert FLAGS.dataset == 'cifar10'
    model = models.build_resnet_v1(
        input_shape=input_shape,
        depth=20,
        num_classes=ds_info['num_classes'],
        pfac=pfac,
        use_frn=FLAGS.resnet_use_frn,
        use_internal_bias=FLAGS.resnet_bias)
  else:
    raise ValueError('Choose model from: cnnlstm, resnet.')

  model.summary()

  # Setup callbacks executed in keras.compile loop
  callbacks = []

  # Setup preconditioner; precond_dict holds extra optimizer kwargs.
  precond_dict = dict()

  if FLAGS.use_preconditioner:
    precond_dict['preconditioner'] = 'fixed'
    logging.info('Use fixed preconditioner.')
  else:
    logging.info('No preconditioner is used.')

  # Always append preconditioner callback to compute ctemp statistics
  precond_estimator_cb = keras_utils.EstimatePreconditionerCallback(
      gradest_train_fn,
      iter(dataset_train),
      every_nth_epoch=1,
      batch_count=32,
      raw_second_moment=True,
      update_precond=FLAGS.use_preconditioner)
  callbacks.append(precond_estimator_cb)

  # Setup MCMC method
  if FLAGS.method == 'sgmcmc':
    # SG-MCMC optimizer, first-order symplectic Euler integrator
    optimizer = sgmcmc.NaiveSymplecticEulerMCMC(
        total_sample_size=dataset_size,
        learning_rate=FLAGS.init_learning_rate,
        momentum_decay=FLAGS.momentum_decay,
        temp=FLAGS.temperature,
        **precond_dict)
    logging.info('Use symplectic Euler integrator.')

  elif FLAGS.method == 'baoab':
    # SG-MCMC optimizer, second-order accurate BAOAB integrator
    optimizer = sgmcmc.BAOABMCMC(
        total_sample_size=dataset_size,
        learning_rate=FLAGS.init_learning_rate,
        momentum_decay=FLAGS.momentum_decay,
        temp=FLAGS.temperature,
        **precond_dict)
    logging.info('Use BAOAB integrator.')
  else:
    raise ValueError('Choose method from: sgmcmc, baoab.')

  # Statistics for evaluation of ensemble performance
  perf_stats = {
      'ens_gce': stats.MeanStatistic(stats.ClassificationGibbsCrossEntropy()),
      'ens_ce': stats.MeanStatistic(stats.ClassificationCrossEntropy()),
      'ens_ce_sem': stats.StandardError(stats.ClassificationCrossEntropy()),
      'ens_brier': stats.MeanStatistic(stats.BrierScore()),
      'ens_brier_unc': stats.BrierUncertainty(),
      'ens_brier_res': stats.BrierResolution(),
      'ens_brier_reliab': stats.BrierReliability(),
      'ens_ece': stats.ECE(10),
      'ens_gacc': stats.MeanStatistic(stats.GibbsAccuracy()),
      'ens_acc': stats.MeanStatistic(stats.Accuracy()),
      'ens_acc_sem': stats.StandardError(stats.Accuracy()),
  }

  # Parallel tuples of statistic labels and statistic objects.
  perf_stats_l, perf_stats_s = zip(*(perf_stats.items()))

  # Setup ensemble
  ens = ensemble.EmpiricalEnsemble(model, input_shape)
  last_ens_eval = {'size': 0}  # ensemble size from last evaluation

  def cycle_ens_eval_maybe():
    """Ensemble evaluation callback, only evaluate at end of cycle."""

    # Only re-evaluate when a new member was added since the last evaluation.
    if len(ens) > last_ens_eval['size']:
      last_ens_eval['size'] = len(ens)
      logging.info('... evaluate ensemble on %d members', len(ens))
      return ens.evaluate_ensemble(
          dataset=dataset_test_single, statistics=perf_stats_s)
    else:
      return None

  ensemble_eval_cb = keras_utils.EvaluateEnsemblePartial(
      cycle_ens_eval_maybe, perf_stats_l)
  callbacks.append(ensemble_eval_cb)

  # Setup cyclical learning rate and temperature schedule for sgmcmc
  if FLAGS.method == 'sgmcmc' or FLAGS.method == 'baoab':
    # setup cyclical learning rate schedule
    cyclic_sampler_cb = keras_utils.CyclicSamplerCallback(
        ens,
        FLAGS.cycle_length * steps_per_epoch,  # number of iterations per cycle
        FLAGS.cycle_start_sampling,  # sampling phase start epoch
        schedule=FLAGS.cycle_schedule,
        min_value=0.0)  # timestep_factor min value
    callbacks.append(cyclic_sampler_cb)

    # Setup temperature ramp-up schedule
    begin_ramp_epoch = FLAGS.cycle_start_sampling - FLAGS.cycle_length
    if begin_ramp_epoch < 0:
      raise ValueError(
          'cycle_start_sampling must be greater equal than cycle_length.')
    ramp_iterations = FLAGS.cycle_length
    tempramp_cb = keras_utils.TemperatureRampScheduler(
        0.0, FLAGS.temperature, begin_ramp_epoch * steps_per_epoch,
        ramp_iterations * steps_per_epoch)
    # T0, Tf, begin_iter, ramp_epochs
    callbacks.append(tempramp_cb)

  # Additional callbacks
  # Plot additional logs
  def plot_logs(epoch, logs):
    """Add learning rate, timestep factor and ensemble size to epoch logs."""
    del epoch  # unused
    logs['lr'] = optimizer.get_config()['learning_rate']
    if FLAGS.method == 'sgmcmc':
      logs['timestep_factor'] = optimizer.get_config()['timestep_factor']
    logs['ens_size'] = len(ens)
  plot_logs_cb = tf.keras.callbacks.LambdaCallback(on_epoch_end=plot_logs)

  # Write logs to tensorboard
  tensorboard_cb = tf.keras.callbacks.TensorBoard(
      log_dir=FLAGS.output_dir, write_graph=False)

  # Output ktemp
  diag_cb = keras_utils.PrintDiagnosticsCallback(10)

  callbacks.extend([
      diag_cb,
      plot_logs_cb,
      keras_utils.TemperatureMetric(),
      keras_utils.SamplerTemperatureMetric(),
      tensorboard_cb,  # Should be after all callbacks that write logs
      tf.keras.callbacks.CSVLogger(os.path.join(FLAGS.output_dir, 'logs.csv'))
  ])

  # Keras train model
  metrics = [
      tf.keras.metrics.SparseCategoricalCrossentropy(
          name='negative_log_likelihood',
          from_logits=True),
      tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
  model.compile(
      optimizer,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=metrics)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  model.fit(
      dataset_train,
      steps_per_epoch=steps_per_epoch,
      epochs=FLAGS.train_epochs,
      validation_data=dataset_test,
      validation_steps=validation_steps,
      callbacks=callbacks)

  # Evaluate final ensemble performance
  logging.info('Ensemble has %d members, computing final performance metrics.',
               len(ens))

  if ens.weights_list:
    ens_perf_stats = ens.evaluate_ensemble(dataset_test_single, perf_stats_s)
    print('Test set metrics:')
    for label, stat_value in zip(perf_stats_l, ens_perf_stats):
      stat_value = float(stat_value)
      logging.info('%s: %.5f', label, stat_value)
      print('%s: %.5f' % (label, stat_value))

  # Add experiment info to experiment metadata csv file in *parent folder*
  if FLAGS.write_experiment_metadata_to_csv:
    csv_path = pathlib.Path(FLAGS.output_dir).parent / 'run_sweeps.csv'
    data = {
        'id': [FLAGS.experiment_id],
        'seed': [FLAGS.seed],
        'temperature': [FLAGS.temperature],
        'dir': ['run_{}'.format(FLAGS.experiment_id)]
    }
    if tf.io.gfile.exists(csv_path):
      # DataFrame.append was removed in pandas 2.0; concatenate instead.
      sweeps_df = pd.concat(
          [pd.read_csv(csv_path), pd.DataFrame.from_dict(data)],
          ignore_index=True).set_index('id')
    else:
      sweeps_df = pd.DataFrame.from_dict(data).set_index('id')

    # save experiment metadata csv file
    sweeps_df.to_csv(csv_path)