def test_classification_ce(self):
        cce = stats.ClassificationCrossEntropy()

        logits1 = tf.math.log([[0.3, 0.7], [0.6, 0.4]])
        logits2 = tf.math.log([[0.2, 0.8], [0.5, 0.5]])
        logits3 = tf.math.log([[0.4, 0.6], [0.4, 0.6]])

        labels = tf.convert_to_tensor([1, 0], dtype=tf.int32)

        cce.reset()
        cce.update(logits1, labels)
        cce.update(logits2, labels)
        cce.update(logits3, labels)
        ce = cce.result()

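        # The ensemble CE is computed from the member-averaged probabilities:
        # example 0, true class 1: mean(0.7, 0.8, 0.6) = 0.7 -> CE = -log(0.7)
        # example 1, true class 0: mean(0.6, 0.5, 0.4) = 0.5 -> CE = -log(0.5)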
        self.assertAlmostEqual(-math.log(0.7), float(ce[0]), delta=TOL)
        self.assertAlmostEqual(-math.log(0.5), float(ce[1]), delta=TOL)

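        # By contrast, the Gibbs cross entropy averages the per-member cross
        # entropies, -(1/M) * sum_m log p_m(y), so it should equal the mean of
        # the per-member CE results collected below.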
        ces = []
        gce = stats.ClassificationGibbsCrossEntropy()
        gce.reset()
        for logits in [logits1, logits2, logits3]:
            cce.reset()
            cce.update(logits, labels)
            ces.append(cce.result())

            gce.update(logits, labels)

        self.assertAllClose(tf.reduce_mean(tf.stack(ces, axis=0), axis=0),
                            gce.result(),
                            atol=TOL,
                            msg="Gibbs cross entropy does not match mean CE.")

def test_fresh_reservoir_ensemble(self):
        model = tf.keras.models.Sequential([
            tf.keras.layers.Dense(10, activation="relu"),
            tf.keras.layers.Dense(10)
        ])

        input_shape = (None, 10)
        model.build(input_shape=input_shape)
        orig_weights = model.get_weights()
        # Keras does not provide a good reinit function, just draw random weights:
        weights1 = [np.random.random(w.shape) for w in orig_weights]
        weights2 = [np.random.random(w.shape) for w in orig_weights]

        ens = ensemble.EmpiricalEnsemble(model, input_shape,
                                         [weights1, weights2])
        self.assertLen(ens, 2, msg="Empirical ensemble len wrong.")

        y_true = np.random.choice(10, 20)
        x = np.random.normal(0, 1, (20, 10))

        ens = ensemble.FreshReservoirEnsemble(model,
                                              input_shape,
                                              capacity=2,
                                              freshness=50)
        ens.append(weights1)
        ens.append(weights2)
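        # freshness=50 (percent) means only the freshest half of the appended
        # members can enter the reservoir, hence the length-1 assertion below.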
        self.assertLen(ens, 1, msg="Fresh reservoir ensemble len wrong.")

        statistics = [stats.ClassificationLogProb()]
        ens_pred = ens.evaluate_ensemble(x, statistics)
        self.assertLen(
            statistics,
            len(ens_pred),
            msg="Number of prediction outputs differ from statistics count.")

        self.assertLen(
            x,
            int(ens_pred[0].shape[0]),
            msg="Ensemble prediction statistics output has wrong shape.")

        statistics = [stats.Accuracy(), stats.ClassificationCrossEntropy()]
        ens_eval = ens.evaluate_ensemble((x, y_true), statistics)
        self.assertLen(
            statistics,
            len(ens_eval),
            msg="Number of evaluation outputs differ from statistics count.")
Example #3
def main(argv):
  del argv  # unused arg

  tf.io.gfile.makedirs(FLAGS.output_dir)

  # Load data
  tf.random.set_seed(DATASET_SEED)

  if FLAGS.dataset == 'cifar10':
    dataset_train, ds_info = datasets.load_cifar10(
        tfds.Split.TRAIN, with_info=True,
        data_augmentation=FLAGS.cifar_data_augmentation,
        subsample_n=FLAGS.subsample_train_size)
    dataset_test = datasets.load_cifar10(tfds.Split.TEST)
    logging.info('CIFAR10 dataset loaded.')

  elif FLAGS.dataset == 'imdb':
    dataset, ds_info = datasets.load_imdb(
        with_info=True, subsample_n=FLAGS.subsample_train_size)
    dataset_train, dataset_test = datasets.get_generators_from_ds(dataset)
    logging.info('IMDB dataset loaded.')

  else:
    raise ValueError('Unknown dataset {}.'.format(FLAGS.dataset))

  # Prepare data for SG-MCMC methods
  dataset_size = ds_info['train_num_examples']
  dataset_size_orig = ds_info.get('train_num_examples_orig', dataset_size)
  dataset_train = dataset_train.repeat().shuffle(10 * FLAGS.batch_size).batch(
      FLAGS.batch_size)
  test_batch_size = 100
  validation_steps = ds_info['test_num_examples'] // test_batch_size
  dataset_test_single = dataset_test.batch(FLAGS.batch_size)
  dataset_test = dataset_test.repeat().batch(test_batch_size)
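  # dataset_test_single makes a single pass (used for ensemble evaluation),
  # while the repeated dataset_test is consumed by Keras validation for
  # validation_steps batches per epoch.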

  # If the --pretend_batch_size flag is provided, any cycle/epoch-length
  # computation is done using pretend_batch_size.  Real batches still have
  # length FLAGS.batch_size.  This feature is used in the batch size ablation
  # study.
  #
  # Also, always determine number of iterations from original data set size
  if FLAGS.pretend_batch_size >= 1:
    steps_per_epoch = dataset_size_orig // FLAGS.pretend_batch_size
  else:
    steps_per_epoch = dataset_size_orig // FLAGS.batch_size
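  # For example (hypothetical settings): CIFAR-10 has 50000 training examples,
  # so with batch_size=128 and no subsampling this gives 50000 // 128 = 390
  # steps per epoch.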

  # Set seed for the experiment
  tf.random.set_seed(FLAGS.seed)

  # Build model using pfac for proper priors
  reg_weight = 1.0 / float(dataset_size)
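  # The prior factories below presumably add the prior as per-layer
  # regularizers, so the 1/N weight keeps its contribution on the same
  # per-example scale as the averaged likelihood.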
  if FLAGS.pfac == 'default':
    pfac = priorfactory.DefaultPriorFactory(weight=reg_weight)
  elif FLAGS.pfac == 'gaussian':
    pfac = priorfactory.GaussianPriorFactory(prior_stddev=1.0,
                                             weight=reg_weight)
  else:
    raise ValueError('Choose pfac from: default, gaussian.')

  input_shape = ds_info['input_shape']

  if FLAGS.model == 'cnnlstm':
    assert FLAGS.dataset == 'imdb'
    model = models.build_cnnlstm(ds_info['num_words'],
                                 ds_info['sequence_length'],
                                 pfac)

  elif FLAGS.model == 'resnet':
    assert FLAGS.dataset == 'cifar10'
    model = models.build_resnet_v1(
        input_shape=input_shape,
        depth=20,
        num_classes=ds_info['num_classes'],
        pfac=pfac,
        use_frn=FLAGS.resnet_use_frn,
        use_internal_bias=FLAGS.resnet_bias)
  else:
    raise ValueError('Choose model from: cnnlstm, resnet.')

  model.summary()

  # Setup callbacks executed in keras.compile loop
  callbacks = []

  # Setup preconditioner
  precond_dict = dict()

  if FLAGS.use_preconditioner:
    precond_dict['preconditioner'] = 'fixed'
    logging.info('Use fixed preconditioner.')
  else:
    logging.info('No preconditioner is used.')

  # Always append preconditioner callback to compute ctemp statistics
  precond_estimator_cb = keras_utils.EstimatePreconditionerCallback(
      gradest_train_fn,
      iter(dataset_train),
      every_nth_epoch=1,
      batch_count=32,
      raw_second_moment=True,
      update_precond=FLAGS.use_preconditioner)
  callbacks.append(precond_estimator_cb)

  # Setup MCMC method
  if FLAGS.method == 'sgmcmc':
    # SG-MCMC optimizer, first-order symplectic Euler integrator
    optimizer = sgmcmc.NaiveSymplecticEulerMCMC(
        total_sample_size=dataset_size,
        learning_rate=FLAGS.init_learning_rate,
        momentum_decay=FLAGS.momentum_decay,
        temp=FLAGS.temperature,
        **precond_dict)
    logging.info('Use symplectic Euler integrator.')

  elif FLAGS.method == 'baoab':
    # SG-MCMC optimizer, second-order accurate BAOAB integrator
    optimizer = sgmcmc.BAOABMCMC(
        total_sample_size=dataset_size,
        learning_rate=FLAGS.init_learning_rate,
        momentum_decay=FLAGS.momentum_decay,
        temp=FLAGS.temperature,
        **precond_dict)
    logging.info('Use BAOAB integrator.')
  else:
    raise ValueError('Choose method from: sgmcmc, baoab.')

  # Statistics for evaluation of ensemble performance
  perf_stats = {
      'ens_gce': stats.MeanStatistic(stats.ClassificationGibbsCrossEntropy()),
      'ens_ce': stats.MeanStatistic(stats.ClassificationCrossEntropy()),
      'ens_ce_sem': stats.StandardError(stats.ClassificationCrossEntropy()),
      'ens_brier': stats.MeanStatistic(stats.BrierScore()),
      'ens_brier_unc': stats.BrierUncertainty(),
      'ens_brier_res': stats.BrierResolution(),
      'ens_brier_reliab': stats.BrierReliability(),
      'ens_ece': stats.ECE(10),
      'ens_gacc': stats.MeanStatistic(stats.GibbsAccuracy()),
      'ens_acc': stats.MeanStatistic(stats.Accuracy()),
      'ens_acc_sem': stats.StandardError(stats.Accuracy()),
  }

  perf_stats_l, perf_stats_s = zip(*(perf_stats.items()))
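  # perf_stats_l now holds the metric labels and perf_stats_s the matching
  # statistic objects, in the same order.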

  # Setup ensemble
  ens = ensemble.EmpiricalEnsemble(model, input_shape)
  last_ens_eval = {'size': 0}  # ensemble size from last evaluation
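  # (a dict is used as a mutable cell so the nested closure below can update it)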

  def cycle_ens_eval_maybe():
    """Ensemble evaluation callback, only evaluate at end of cycle."""

    if len(ens) > last_ens_eval['size']:
      last_ens_eval['size'] = len(ens)
      logging.info('... evaluate ensemble on %d members', len(ens))
      return ens.evaluate_ensemble(
          dataset=dataset_test_single, statistics=perf_stats_s)
    else:
      return None

  ensemble_eval_cb = keras_utils.EvaluateEnsemblePartial(
      cycle_ens_eval_maybe, perf_stats_l)
  callbacks.append(ensemble_eval_cb)

  # Setup cyclical learning rate and temperature schedule for sgmcmc
  if FLAGS.method == 'sgmcmc' or FLAGS.method == 'baoab':
    # setup cyclical learning rate schedule
    cyclic_sampler_cb = keras_utils.CyclicSamplerCallback(
        ens,
        FLAGS.cycle_length * steps_per_epoch,  # number of iterations per cycle
        FLAGS.cycle_start_sampling,  # sampling phase start epoch
        schedule=FLAGS.cycle_schedule,
        min_value=0.0)  # timestep_factor min value
    callbacks.append(cyclic_sampler_cb)

    # Setup temperature ramp-up schedule
    begin_ramp_epoch = FLAGS.cycle_start_sampling - FLAGS.cycle_length
    if begin_ramp_epoch < 0:
      raise ValueError(
          'cycle_start_sampling must be greater than or equal to cycle_length.')
    ramp_iterations = FLAGS.cycle_length
    tempramp_cb = keras_utils.TemperatureRampScheduler(
        0.0, FLAGS.temperature, begin_ramp_epoch * steps_per_epoch,
        ramp_iterations * steps_per_epoch)
    # T0, Tf, begin_iter, ramp_epochs
    callbacks.append(tempramp_cb)
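    # Hypothetical example: with cycle_length=10 and cycle_start_sampling=50,
    # the temperature ramps from 0.0 to FLAGS.temperature over epochs 40-50,
    # reaching the target value just as the sampling phase begins.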

  # Additional callbacks
  # Plot additional logs
  def plot_logs(epoch, logs):
    del epoch  # unused
    logs['lr'] = optimizer.get_config()['learning_rate']
    if FLAGS.method == 'sgmcmc':
      logs['timestep_factor'] = optimizer.get_config()['timestep_factor']
    logs['ens_size'] = len(ens)
  plot_logs_cb = tf.keras.callbacks.LambdaCallback(on_epoch_end=plot_logs)

  # Write logs to tensorboard
  tensorboard_cb = tf.keras.callbacks.TensorBoard(
      log_dir=FLAGS.output_dir, write_graph=False)

  # Output ktemp
  diag_cb = keras_utils.PrintDiagnosticsCallback(10)

  callbacks.extend([
      diag_cb,
      plot_logs_cb,
      keras_utils.TemperatureMetric(),
      keras_utils.SamplerTemperatureMetric(),
      tensorboard_cb,  # Should be after all callbacks that write logs
      tf.keras.callbacks.CSVLogger(os.path.join(FLAGS.output_dir, 'logs.csv'))
  ])

  # Keras train model
  metrics = [
      tf.keras.metrics.SparseCategoricalCrossentropy(
          name='negative_log_likelihood',
          from_logits=True),
      tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
  model.compile(
      optimizer,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=metrics)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  model.fit(
      dataset_train,
      steps_per_epoch=steps_per_epoch,
      epochs=FLAGS.train_epochs,
      validation_data=dataset_test,
      validation_steps=validation_steps,
      callbacks=callbacks)

  # Evaluate final ensemble performance
  logging.info('Ensemble has %d members, computing final performance metrics.',
               len(ens))

  if ens.weights_list:
    ens_perf_stats = ens.evaluate_ensemble(dataset_test_single, perf_stats_s)
    print('Test set metrics:')
    for label, stat_value in zip(perf_stats_l, ens_perf_stats):
      stat_value = float(stat_value)
      logging.info('%s: %.5f', label, stat_value)
      print('%s: %.5f' % (label, stat_value))

  # Add experiment info to experiment metadata csv file in *parent folder*
  if FLAGS.write_experiment_metadata_to_csv:
    csv_path = pathlib.Path.joinpath(
        pathlib.Path(FLAGS.output_dir).parent, 'run_sweeps.csv')
    data = {
        'id': [FLAGS.experiment_id],
        'seed': [FLAGS.seed],
        'temperature': [FLAGS.temperature],
        'dir': ['run_{}'.format(FLAGS.experiment_id)]
    }
    if tf.io.gfile.exists(csv_path):
      sweeps_df = pd.read_csv(csv_path)
      sweeps_df = pd.concat(
          [sweeps_df, pd.DataFrame.from_dict(data)],
          ignore_index=True).set_index('id')
    else:
      sweeps_df = pd.DataFrame.from_dict(data).set_index('id')

    # save experiment metadata csv file
    sweeps_df.to_csv(csv_path)