def main(argv):
  """Trains and evaluates a Wide ResNet 28-10 under a `tf.distribute` strategy.

  The function builds the train/validation/test (and optionally OOD and
  corrupted) datasets, creates the model, optimizer and metrics inside the
  strategy scope, restores the latest checkpoint if one exists, and then runs
  the epoch loop: train, evaluate, write summaries, reset metrics and
  periodically checkpoint. All configuration is read from `FLAGS`.

  Args:
    argv: Command-line arguments forwarded by the app runner; unused.
  """
  fmt = '[%(filename)s:%(lineno)s] %(message)s'
  formatter = logging.PythonFormatter(fmt)
  logging.get_absl_handler().setFormatter(formatter)
  del argv  # unused arg

  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  data_dir = FLAGS.data_dir
  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

  ds_info = tfds.builder(FLAGS.dataset).info
  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  train_dataset_size = (
      ds_info.splits['train'].num_examples * FLAGS.train_proportion)
  # This estimate is only logged; the value used for training is recomputed
  # from the train builder once the datasets have been created.
  steps_per_epoch = int(train_dataset_size / batch_size)
  logging.info('Steps per epoch %s', steps_per_epoch)
  logging.info('Size of the dataset %s', ds_info.splits['train'].num_examples)
  logging.info('Train proportion %s', FLAGS.train_proportion)

  aug_params = {
      'augmix': FLAGS.augmix,
      'aug_count': FLAGS.aug_count,
      'augmix_depth': FLAGS.augmix_depth,
      'augmix_prob_coeff': FLAGS.augmix_prob_coeff,
      'augmix_width': FLAGS.augmix_width,
  }

  # Note that stateless_{fold_in,split} may incur a performance cost, but a
  # quick side-by-side test seemed to imply this was minimal.
  seeds = tf.random.experimental.stateless_split(
      [FLAGS.seed, FLAGS.seed + 1], 2)[:, 0]
  train_builder = ub.datasets.get(
      FLAGS.dataset,
      data_dir=data_dir,
      download_data=FLAGS.download_data,
      split=tfds.Split.TRAIN,
      seed=seeds[0],
      aug_params=aug_params,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size,
      validation_percent=1. - FLAGS.train_proportion,
  )
  train_dataset = train_builder.load(batch_size=batch_size)
  validation_dataset = None
  steps_per_validation = 0
  if FLAGS.train_proportion < 1.0:
    validation_builder = ub.datasets.get(
        FLAGS.dataset,
        split=tfds.Split.VALIDATION,
        validation_percent=1. - FLAGS.train_proportion,
        data_dir=data_dir,
        drop_remainder=FLAGS.drop_remainder_for_eval)
    validation_dataset = validation_builder.load(batch_size=batch_size)
    validation_dataset = strategy.experimental_distribute_dataset(
        validation_dataset)
    steps_per_validation = validation_builder.num_examples // batch_size
  clean_test_builder = ub.datasets.get(
      FLAGS.dataset,
      split=tfds.Split.TEST,
      data_dir=data_dir,
      drop_remainder=FLAGS.drop_remainder_for_eval)
  clean_test_dataset = clean_test_builder.load(batch_size=batch_size)
  test_datasets = {
      'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
  }


  train_dataset = strategy.experimental_distribute_dataset(train_dataset)

  steps_per_epoch = train_builder.num_examples // batch_size
  steps_per_eval = clean_test_builder.num_examples // batch_size
  num_classes = 100 if FLAGS.dataset == 'cifar100' else 10

  if FLAGS.eval_on_ood:
    ood_dataset_names = FLAGS.ood_dataset
    ood_ds, steps_per_ood = ood_utils.load_ood_datasets(
        ood_dataset_names,
        clean_test_builder,
        1. - FLAGS.train_proportion,
        batch_size,
        drop_remainder=FLAGS.drop_remainder_for_eval)
    ood_datasets = {
        name: strategy.experimental_distribute_dataset(ds)
        for name, ds in ood_ds.items()
    }

  if FLAGS.corruptions_interval > 0:
    if FLAGS.dataset == 'cifar100':
      data_dir = FLAGS.cifar100_c_path
    corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
    for corruption_type in corruption_types:
      for severity in range(1, 6):
        dataset = ub.datasets.get(
            f'{FLAGS.dataset}_corrupted',
            corruption_type=corruption_type,
            severity=severity,
            split=tfds.Split.TEST,
            data_dir=data_dir).load(batch_size=batch_size)
        test_datasets[f'{corruption_type}_{severity}'] = (
            strategy.experimental_distribute_dataset(dataset))

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building ResNet model')
    model = ub.models.wide_resnet(
        input_shape=(32, 32, 3),
        depth=28,
        width_multiplier=10,
        num_classes=num_classes,
        l2=FLAGS.l2,
        hps=_extract_hyperparameter_dictionary(),
        seed=seeds[1])
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate * batch_size / 128
    lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                       for start_epoch_str in FLAGS.lr_decay_epochs]
    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                        momentum=1.0 - FLAGS.one_minus_momentum,
                                        nesterov=True)
    metrics = {
        'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
        'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss':
            tf.keras.metrics.Mean(),
        'train/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
        'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
    }
    # Compare against None explicitly instead of relying on the truthiness of
    # a distributed dataset object.
    if validation_dataset is not None:
      metrics.update({
          'validation/negative_log_likelihood': tf.keras.metrics.Mean(),
          'validation/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
          'validation/ece': rm.metrics.ExpectedCalibrationError(
              num_bins=FLAGS.num_bins),
      })
    if FLAGS.eval_on_ood:
      ood_metrics = ood_utils.create_ood_metrics(ood_dataset_names)
      metrics.update(ood_metrics)
    if FLAGS.corruptions_interval > 0:
      corrupt_metrics = {}
      for intensity in range(1, 6):
        for corruption in corruption_types:
          dataset_name = '{0}_{1}'.format(corruption, intensity)
          corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
              rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    if FLAGS.saved_model_dir:
      logging.info('Saved model dir : %s', FLAGS.saved_model_dir)
      latest_checkpoint = tf.train.latest_checkpoint(FLAGS.saved_model_dir)
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
    if FLAGS.eval_only:
      initial_epoch = FLAGS.train_epochs - 1  # Run just one epoch of eval

  @tf.function
  def train_step(iterator):
    """Runs `steps_per_epoch` training steps, pulling batches from `iterator`."""
    def step_fn(inputs):
      """Per-replica step: forward pass, loss, gradient update, metrics."""
      images = inputs['features']
      labels = inputs['labels']

      if FLAGS.augmix and FLAGS.aug_count >= 1:
        # Index 0 at augmix processing is the unperturbed image.
        # We take just 1 augmented image from the returned augmented images.
        images = images[:, 1, ...]
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.label_smoothing == 0.:
          negative_log_likelihood = tf.reduce_mean(
              tf.keras.losses.sparse_categorical_crossentropy(labels,
                                                              logits,
                                                              from_logits=True))
        else:
          one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32), num_classes)
          negative_log_likelihood = tf.reduce_mean(
              tf.keras.losses.categorical_crossentropy(
                  one_hot_labels,
                  logits,
                  from_logits=True,
                  label_smoothing=FLAGS.label_smoothing))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + l2_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

      probs = tf.nn.softmax(logits)
      metrics['train/ece'].add_batch(probs, label=labels)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, logits)

    for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_split, dataset_name, num_steps):
    """Runs `num_steps` evaluation steps and updates the matching metrics.

    Args:
      iterator: Iterator over a distributed evaluation dataset.
      dataset_split: Metric-name prefix used for clean data (for example
        'test' or 'validation').
      dataset_name: 'clean', a name starting with 'ood/', or otherwise a
        corrupted-dataset name used to look up `corrupt_metrics`.
      num_steps: Number of batches to consume from `iterator`.
    """
    def step_fn(inputs):
      """Per-replica evaluation step."""
      images = inputs['features']
      labels = inputs['labels']
      logits = model(images, training=False)
      probs = tf.nn.softmax(logits)

      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

      if dataset_name == 'clean':
        metrics[f'{dataset_split}/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics[f'{dataset_split}/accuracy'].update_state(labels, probs)
        metrics[f'{dataset_split}/ece'].add_batch(probs, label=labels)
      elif dataset_name.startswith('ood/'):
        ood_labels = 1 - inputs['is_in_distribution']
        if FLAGS.dempster_shafer_ood:
          ood_scores = ood_utils.DempsterShaferUncertainty(logits)
        else:
          ood_scores = 1 - tf.reduce_max(probs, axis=-1)

        # Edgecase for if dataset_name contains underscores
        for name, metric in metrics.items():
          if dataset_name in name:
            metric.update_state(ood_labels, ood_scores)
      else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch(
            probs, label=labels)

    for _ in tf.range(tf.cast(num_steps, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  # NOTE(review): this step function is not called anywhere in `main`, and the
  # 'cifar10h/*' metrics it updates are not created here. Confirm whether it
  # is still needed.
  @tf.function
  def cifar10h_test_step(iterator, num_steps):
    """Runs `num_steps` evaluation steps and updates the 'cifar10h/*' metrics."""
    def step_fn(inputs):
      """Per-replica evaluation step."""
      images = inputs['features']
      labels = inputs['labels']
      logits = model(images, training=False)

      negative_log_likelihood = tf.keras.losses.CategoricalCrossentropy(
          from_logits=True,
          reduction=tf.keras.losses.Reduction.NONE)(labels, logits)

      negative_log_likelihood = tf.reduce_mean(negative_log_likelihood)
      metrics['cifar10h/nll'].update_state(negative_log_likelihood)

      label_diversity, sample_diversity, ged = _generalized_energy_distance(
          labels, tf.nn.softmax(logits), 10)

      metrics['cifar10h/ged'].update_state(ged)
      metrics['cifar10h/ged_label_diversity'].update_state(
          tf.reduce_mean(label_diversity))
      metrics['cifar10h/ged_sample_diversity'].update_state(
          tf.reduce_mean(sample_diversity))

    for _ in tf.range(tf.cast(num_steps, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})
  metrics.update({'train/ms_per_example': tf.keras.metrics.Mean()})

  train_iterator = iter(train_dataset)
  start_time = time.time()
  tb_callback = None
  if FLAGS.collect_profile:
    tb_callback = tf.keras.callbacks.TensorBoard(
        profile_batch=(100, 102),
        log_dir=os.path.join(FLAGS.output_dir, 'logs'))
    tb_callback.set_model(model)
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    if tb_callback:
      tb_callback.on_epoch_begin(epoch)
    if not FLAGS.eval_only:
      train_start_time = time.time()
      train_step(train_iterator)
      # One train_step call consumes `steps_per_epoch` batches, so normalize by
      # the number of examples processed and convert seconds to milliseconds.
      train_examples = max(steps_per_epoch * batch_size, 1)
      ms_per_example = (time.time() - train_start_time) * 1e3 / train_examples
      metrics['train/ms_per_example'].update_state(ms_per_example)

      current_step = (epoch + 1) * steps_per_epoch
      max_steps = steps_per_epoch * FLAGS.train_epochs
      time_elapsed = time.time() - start_time
      # Only steps executed by this process count towards throughput;
      # otherwise the rate and ETA are wrong after resuming from a checkpoint.
      steps_this_run = (epoch + 1 - initial_epoch) * steps_per_epoch
      steps_per_sec = float(steps_this_run) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
      logging.info(message)
    if tb_callback:
      tb_callback.on_epoch_end(epoch)

    if validation_dataset is not None:
      validation_iterator = iter(validation_dataset)
      test_step(
          validation_iterator, 'validation', 'clean', steps_per_validation)
    datasets_to_evaluate = {'clean': test_datasets['clean']}
    if (FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0):
      datasets_to_evaluate = test_datasets
    for dataset_name, test_dataset in datasets_to_evaluate.items():
      test_iterator = iter(test_dataset)
      logging.info('Testing on dataset %s', dataset_name)
      logging.info('Starting to run eval at epoch: %s', epoch)
      test_start_time = time.time()
      test_step(test_iterator, 'test', dataset_name, steps_per_eval)
      # Same normalization as for training: `steps_per_eval` batches were
      # evaluated in this call.
      eval_examples = max(steps_per_eval * batch_size, 1)
      ms_per_example = (time.time() - test_start_time) * 1e3 / eval_examples
      metrics['test/ms_per_example'].update_state(ms_per_example)

      logging.info('Done with testing on %s', dataset_name)

    if FLAGS.eval_on_ood:
      for ood_dataset_name, ood_dataset in ood_datasets.items():
        ood_iterator = iter(ood_dataset)
        logging.info('Calculating OOD on dataset %s', ood_dataset_name)
        logging.info('Running OOD eval at epoch: %s', epoch)
        test_step(ood_iterator, 'test', ood_dataset_name,
                  steps_per_ood[ood_dataset_name])

        logging.info('Done with OOD eval on %s', ood_dataset_name)

    corrupt_results = {}
    if (FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0):
      corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                        corruption_types)

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    # Metrics from Robustness Metrics (like ECE) will return a dict with a
    # single key/value, instead of a scalar.
    total_results = {
        k: (list(v.values())[0] if isinstance(v, dict) else v)
        for k, v in total_results.items()
    }
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if FLAGS.corruptions_interval > 0:
      for metric in corrupt_metrics.values():
        metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)
  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
    })
def main(config, output_dir):

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)

    save_checkpoint_path = None
    if config.get('checkpoint_steps'):
        gfile.makedirs(output_dir)
        save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        """Log `note` at INFO level, but only on JAX process 0."""
        if jax.process_index() != 0:
            return
        logging.info('NOTE: %s', note)

    write_note('Initializing...')

    # Verify settings to make sure no checkpoints are accidentally missed:
    # kept checkpoints can only be written on steps that are checkpointed.
    if config.get('keep_checkpoint_steps'):
        assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
        assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
            f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should '
            f'be divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

    write_note('Initializing train dataset...')
    rng, train_ds_rng = jax.random.split(rng)
    train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
    train_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.train_split,
        rng=train_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch_size=config.get('prefetch_to_host', 2),
        data_dir=config.get('data_dir'))

    # Start prefetching already.
    train_iter = input_utils.start_input_pipeline(
        train_ds, config.get('prefetch_to_device', 1))

    write_note('Initializing val dataset(s)...')

    def _get_val_split(dataset, split, pp_eval, data_dir=None):
        """Build a started evaluation input pipeline for one dataset split.

        Args:
          dataset: Dataset name forwarded to `input_utils`.
          split: Split specification to evaluate on.
          pp_eval: Either a preprocessing spec string, which is parsed with
            `preprocess_spec.parse`, or an already-built preprocessing
            callable that is used as is.
          data_dir: Optional data directory forwarded to `input_utils`.

        Returns:
          A `(val_iter, val_steps)` tuple: the started input pipeline and the
          number of evaluation steps, rounded up.
        """
        # Round up so that the last, incomplete batch is included as well.
        num_examples = input_utils.get_num_examples(
            dataset,
            split=split,
            process_batch_size=local_batch_size_eval,
            drop_remainder=False,
            data_dir=data_dir)
        num_steps = int(np.ceil(num_examples / batch_size_eval))
        logging.info('Running validation for %d steps for %s, %s', num_steps,
                     dataset, split)

        # A string spec still has to be turned into a preprocessing function.
        preprocess_fn = pp_eval
        if isinstance(preprocess_fn, str):
            preprocess_fn = preprocess_spec.parse(
                spec=preprocess_fn, available_ops=preprocess_utils.all_ops())

        eval_ds = input_utils.get_data(
            dataset=dataset,
            split=split,
            rng=None,
            process_batch_size=local_batch_size_eval,
            preprocess_fn=preprocess_fn,
            cache=config.get('val_cache', 'batched'),
            repeat_after_batching=True,
            shuffle=False,
            prefetch_size=config.get('prefetch_to_host', 2),
            drop_remainder=False,
            data_dir=data_dir)
        eval_iter = input_utils.start_input_pipeline(
            eval_ds, config.get('prefetch_to_device', 1))
        return eval_iter, num_steps

    val_iter_splits = {
        'val':
        _get_val_split(config.dataset,
                       split=config.val_split,
                       pp_eval=config.pp_eval,
                       data_dir=config.get('data_dir'))
    }

    if config.get('eval_on_cifar_10h'):
        cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
            config.get('data_dir', None))
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_cifar_10h,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
        val_iter_splits['cifar_10h'] = _get_val_split(
            'cifar10',
            split=config.get('cifar_10h_split') or 'test',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
    elif config.get('eval_on_imagenet_real'):
        imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn(
        )
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_imagenet_real,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
        val_iter_imagenet_real, val_steps = _get_val_split(
            'imagenet2012_real',
            split=config.get('imagenet_real_split') or 'validation',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
        val_iter_splits['imagenet_real'] = (val_iter_imagenet_real, val_steps)

    ood_ds = {}
    # OOD data is only loaded when both `ood_datasets` and `ood_methods` are
    # set to non-empty values.
    if config.get('ood_datasets') and config.get('ood_methods'):
        logging.info('loading OOD dataset = %s', config.get('ood_datasets'))
        ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
            config.dataset,
            config.ood_datasets,
            config.ood_split,
            config.pp_eval,
            config.pp_eval_ood,
            config.ood_methods,
            config.train_split,
            config.get('data_dir'),
            _get_val_split,
        )

    ntrain_img = input_utils.get_num_examples(
        config.dataset,
        split=config.train_split,
        process_batch_size=local_batch_size,
        data_dir=config.get('data_dir'))
    steps_per_epoch = int(ntrain_img / batch_size)

    if config.get('num_epochs'):
        total_steps = int(config.num_epochs * steps_per_epoch)
        assert not config.get(
            'total_steps'), 'Set either num_epochs or total_steps'
    else:
        total_steps = config.total_steps

    logging.info('Total train data points: %d', ntrain_img)
    logging.info(
        'Running for %d steps, that means %f epochs and %d steps per epoch',
        total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

    write_note('Initializing model...')
    logging.info('config.model = %s', config.get('model'))
    model = ub.models.vision_transformer(num_classes=config.num_classes,
                                         **config.get('model', {}))

    # We want all parameters to be created in host RAM, not on any device, they'll
    # be sent there later as needed, otherwise we already encountered two
    # situations where we allocate them twice.
    @partial(jax.jit, backend='cpu')
    def init(rng):
        """Create the initial model parameters, jitted for the CPU backend.

        Args:
          rng: PRNG key passed to `model.init`.

        Returns:
          The unfrozen 'params' collection, with the head bias overwritten and,
          when `config.model_init` is set, the head kernel zeroed.
        """
        # Drop the two leading dimensions of the dataset's image spec.
        image_size = tuple(train_ds.element_spec['image'].shape[2:])
        logging.info('image_size = %s', image_size)
        dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32)
        variables = model.init(rng, dummy_input, train=False)
        params = flax.core.unfreeze(variables)['params']
        head = params['head']

        # Set bias in the head to a low value, such that loss is small
        # initially.
        head['bias'] = jnp.full_like(head['bias'],
                                     config.get('init_head_bias', 0))

        # For fine-tuning, the head kernel starts from all zeros.
        if config.get('model_init'):
            head['kernel'] = jnp.zeros_like(head['kernel'])

        return params

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    # Only the first process logs the parameter overview and parameter count.
    if jax.process_index() == 0:
        # `jax.tree_util.tree_leaves` replaces the deprecated top-level
        # `jax.tree_flatten` alias and returns just the leaves.
        num_params = sum(p.size for p in jax.tree_util.tree_leaves(params_cpu))
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    @partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels, mask):
        """Evaluate one batch per device and reduce across the 'batch' axis.

        Args:
          params: Model parameters; frozen before being applied.
          images: Model inputs.
          labels: Label array indexed as `[example, class]`.
          mask: Per-example weights; examples with all-zero labels are
            additionally masked out.

        Returns:
          `(ncorrect, loss, n, metric_args)` where the first three are sums
          over the 'batch' axis and `metric_args` is the all-gathered list
          `[logits, labels, pre_logits, mask]`.
        """
        # Ignore the entries with all zero labels for evaluation.
        mask = mask * labels.max(axis=1)
        logits, out = model.apply(
            {'params': flax.core.freeze(params)}, images, train=False)
        label_indices = config.get('label_indices')
        logging.info('!!! mask %s, label_indices %s', mask, label_indices)
        if label_indices:
            logits = logits[:, label_indices]
            num_label_columns = len(label_indices)
        else:
            num_label_columns = config.num_classes

        # Note that logits and labels are usually of the shape
        # [batch, num_classes]. For OOD data, when num_classes_ood >
        # num_classes_ind, labels are truncated to match the shape of the
        # logits. That only avoids a shape mismatch: the resulting losses
        # carry no meaning for OOD data, which belongs to no IND class.
        loss_fn = getattr(train_utils, config.get('loss', 'sigmoid_xent'))
        per_example_loss = loss_fn(
            logits=logits,
            labels=labels[:, :num_label_columns],
            reduction=False)
        loss = jax.lax.psum(per_example_loss * mask, axis_name='batch')

        # Look up the label value at the highest-logit index of each example.
        predicted = jnp.argmax(logits, axis=1)
        top1_correct = jnp.take_along_axis(
            labels, predicted[:, None], axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    @partial(jax.pmap, axis_name='batch')
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        """Evaluate one batch per device using argmax-hardened labels.

        Args:
          params: Model parameters; frozen before being applied.
          images: Model inputs.
          labels: Label array indexed as `[example, class]`; used as is for
            the loss and reduced to its per-row argmax for accuracy.
          mask: Only passed through into `metric_args`; it is not applied to
            the loss or the counts.

        Returns:
          `(ncorrect, loss, n, metric_args)` where the first three are sums
          over the 'batch' axis and `metric_args` is the all-gathered list
          `[logits, labels, pre_logits, mask]`.
        """
        logits, out = model.apply(
            {'params': flax.core.freeze(params)}, images, train=False)
        label_indices = config.get('label_indices')
        if label_indices:
            logits = logits[:, label_indices]

        loss_fn = getattr(train_utils, config.get('loss', 'softmax_xent'))
        per_example_loss = loss_fn(
            logits=logits, labels=labels, reduction=False)
        loss = jax.lax.psum(per_example_loss, axis_name='batch')

        predicted = jnp.argmax(logits, axis=1)
        # Turn each label row into a 10-way one-hot vector at its argmax.
        hard_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        # Look up the one-hot label at the highest-logit index of each example.
        top1_correct = jnp.take_along_axis(
            hard_labels, predicted[:, None], axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        # NOTE(review): `n` is the cross-device sum of the one-hot label
        # matrix, not of a per-example mask. Confirm callers expect this.
        n = jax.lax.psum(hard_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Setup function for computing representation.
    @partial(jax.pmap, axis_name='batch')
    def representation_fn(params, images, labels, mask):
        """Compute the few-shot representation and gather it across devices.

        Args:
          params: Model parameters; frozen before being applied.
          images: Model inputs.
          labels: Labels, gathered unchanged.
          mask: Mask, gathered unchanged.

        Returns:
          `(representation, labels, mask)`, each all-gathered over the 'batch'
          axis. The representation is the model output selected by
          `config.fewshot.representation_layer`.
        """
        variables = {'params': flax.core.freeze(params)}
        _, outputs = model.apply(variables, images, train=False)
        features = outputs[config.fewshot.representation_layer]

        gathered_features = jax.lax.all_gather(features, 'batch')
        gathered_labels = jax.lax.all_gather(labels, 'batch')
        gathered_mask = jax.lax.all_gather(mask, 'batch')
        return gathered_features, gathered_labels, gathered_mask

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    write_note(f'Initializing {opt_name} optimizer...')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this, such that the arrays that are created are created on the same
    # device as the input is, in this case the CPU. Else they'd be on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    weight_decay_rules = config.get('weight_decay', []) or []
    rescale_value = config.lr.base if config.get(
        'weight_decay_decouple') else 1.
    weight_decay_fn = train_utils.get_weight_decay_fn(
        weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

    @partial(jax.pmap, axis_name='batch', donate_argnums=(0, ))
    def update_fn(opt, lr, images, labels, rng):
        """Run one optimizer update on every device.

        Args:
          opt: Optimizer whose `target` holds the parameters. It is donated
            (`donate_argnums=(0,)`), so callers must use the returned one.
          lr: Learning rate passed to `apply_gradient` and to the weight
            decay function.
          images: Model inputs for this step.
          labels: Targets for this step.
          rng: PRNG key. It is split once per call; one half is returned and
            the other seeds the per-device dropout key.

        Returns:
          `(opt, loss, rng, measurements)`: the updated optimizer, the loss
          averaged over the 'batch' axis, the advanced PRNG key, and a dict
          with 'l2_params' and, when computed, 'l2_grads'.
        """

        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))

        def loss_fn(params, images, labels):
            """Return the training loss for `params` on one batch."""
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={'dropout': rng_model_local})
            label_indices = config.get('label_indices')
            if label_indices:
                logits = logits[:, label_indices]
            return getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        loss, grads = train_utils.accumulate_gradient(
            jax.value_and_grad(loss_fn), opt.target, images, labels,
            config.get('grad_accum_steps'))
        loss, grads = jax.lax.pmean((loss, grads), axis_name='batch')

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or config.get('grad_clip_norm'):
            # `jax.tree_util.tree_leaves` replaces the deprecated top-level
            # `jax.tree_flatten` alias.
            grad_leaves = jax.tree_util.tree_leaves(grads)
            l2_g = jnp.sqrt(sum(jnp.vdot(p, p) for p in grad_leaves))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        # `l2_g` is always defined here because the condition above also holds
        # whenever `grad_clip_norm` is set.
        if config.get('grad_clip_norm'):
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            grads = jax.tree_util.tree_map(lambda p: g_factor * p, grads)
        opt = opt.apply_gradient(grads, learning_rate=lr)

        opt = opt.replace(target=weight_decay_fn(opt.target, lr))

        param_leaves = jax.tree_util.tree_leaves(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum(jnp.vdot(p, p) for p in param_leaves))

        return opt, loss, rng, measurements

    rng, train_loop_rngs = jax.random.split(rng)
    reint_params = ('head/kernel', 'head/bias')
    if config.get('only_eval', False) or not config.get('reint_head', True):
        reint_params = []
    checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
        train_loop_rngs=train_loop_rngs,
        save_checkpoint_path=save_checkpoint_path,
        init_optimizer=opt_cpu,
        init_params=params_cpu,
        init_fixed_model_states=None,
        default_reinit_params=reint_params,
        config=config,
    )
    train_loop_rngs = checkpoint_data.train_loop_rngs
    opt_cpu = checkpoint_data.optimizer
    accumulated_train_time = checkpoint_data.accumulated_train_time

    write_note('Adapting the checkpoint model...')
    adapted_params = checkpoint_utils.adapt_upstream_architecture(
        init_params=params_cpu, loaded_params=opt_cpu.target)
    opt_cpu = opt_cpu.replace(target=adapted_params)

    write_note('Kicking off misc stuff...')
    first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
    if first_step == 0 and jax.process_index() == 0:
        writer.write_hparams(dict(config))
    chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                                accumulated_train_time)
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=first_step + 10)

    # Prepare the learning-rate and pre-fetch it to device to avoid delays.
    lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                      **config.get('lr', {}))
    # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
    # necessary for TPUs.
    lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)),
                                          config.get('prefetch_to_device', 1))

    write_note(f'Replicating...\n{chrono.note}')
    opt_repl = flax_utils.replicate(opt_cpu)

    write_note(f'Initializing few-shotters...\n{chrono.note}')
    fewshotter = None
    if 'fewshot' in config and fewshot is not None:
        fewshotter = fewshot.FewShotEvaluator(
            representation_fn, config.fewshot,
            config.fewshot.get('batch_size') or batch_size_eval)

    checkpoint_writer = None

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    train_loss = -jnp.inf
    val_loss = {val_name: -jnp.inf for val_name, _ in val_iter_splits.items()}
    fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

    write_note(f'First step compilations...\n{chrono.note}')
    logging.info('first_step = %s', first_step)
    # Advance the iterators if we are restarting from an earlier checkpoint.
    # TODO(dusenberrymw): Look into checkpointing dataset state instead.
    if first_step > 0:
        write_note('Advancing iterators after resuming from a checkpoint...')
        # Skip the learning rates and training batches that correspond to the
        # steps already taken before the restart.
        lr_iter = itertools.islice(lr_iter, first_step, None)
        train_iter = itertools.islice(train_iter, first_step, None)
        # NOTE: Validation only ran on some of the earlier steps, so count
        # those runs to know how many validation batches to skip per split.
        num_val_runs = sum(
            train_utils.itstime(past_step, config.log_eval_steps, total_steps)
            for past_step in range(1, first_step + 1))
        for val_name, (val_iter, val_steps) in val_iter_splits.items():
            skipped_batches = num_val_runs * val_steps
            val_iter = itertools.islice(val_iter, skipped_batches, None)
            val_iter_splits[val_name] = (val_iter, val_steps)

    # Using a python integer for step here, because opt.state.step is allocated
    # on TPU during replication.
    for step, train_batch, lr_repl in zip(
            range(first_step + 1, total_steps + 1), train_iter, lr_iter):

        with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
            if not config.get('only_eval', False):
                opt_repl, loss_value, train_loop_rngs, extra_measurements = update_fn(
                    opt_repl,
                    lr_repl,
                    train_batch['image'],
                    train_batch['labels'],
                    rng=train_loop_rngs)

        if jax.process_index() == 0:
            profiler(step)

        # Checkpoint saving
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.get('checkpoint_steps'), total_steps, process=0):
            write_note('Checkpointing...')
            chrono.pause()
            train_utils.checkpointing_timeout(
                checkpoint_writer, config.get('checkpoint_timeout', 1))
            accumulated_train_time = chrono.accum_train_time
            # We need to transfer the weights over now or else we risk keeping them
            # alive while they'll be updated in a future step, creating hard to debug
            # memory errors (see b/160593526). Also, takes device 0's params only.
            opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]),
                                             opt_repl)

            # Check whether we want to keep a copy of the current checkpoint.
            copy_step = None
            if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                                   total_steps):
                write_note('Keeping a checkpoint copy...')
                copy_step = step

            # Checkpoint should be a nested dictionary or FLAX datataclasses from
            # `flax.struct`. Both can be present in a checkpoint.
            checkpoint_data = checkpoint_utils.CheckpointData(
                train_loop_rngs=train_loop_rngs,
                optimizer=opt_cpu,
                accumulated_train_time=accumulated_train_time)

            checkpoint_writer = pool.apply_async(
                checkpoint_utils.checkpoint_trained_model,
                (checkpoint_data, save_checkpoint_path, copy_step))
            chrono.resume()

        # Report training progress
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.log_training_steps, total_steps, process=0):
            write_note('Reporting training progress...')
            train_loss = loss_value[
                0]  # Keep to return for reproducibility tests.
            timing_measurements, note = chrono.tick(step)
            write_note(note)
            train_measurements = {}
            train_measurements.update({
                'learning_rate': lr_repl[0],
                'training_loss': train_loss,
            })
            train_measurements.update(
                flax.jax_utils.unreplicate(extra_measurements))
            train_measurements.update(timing_measurements)
            writer.write_scalars(step, train_measurements)

        # Report validation performance
        if train_utils.itstime(step, config.log_eval_steps, total_steps):
            write_note('Evaluating on the validation set...')
            chrono.pause()
            for val_name, (val_iter, val_steps) in val_iter_splits.items():
                # Sets up evaluation metrics.
                ece_num_bins = config.get('ece_num_bins', 15)
                auc_num_bins = config.get('auc_num_bins', 1000)
                ece = rm.metrics.ExpectedCalibrationError(
                    num_bins=ece_num_bins)
                calib_auc = rm.metrics.CalibrationAUC(
                    correct_pred_as_pos_label=False)
                oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.005, num_bins=auc_num_bins)
                oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.01, num_bins=auc_num_bins)
                oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.02, num_bins=auc_num_bins)
                oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.05, num_bins=auc_num_bins)
                label_diversity = tf.keras.metrics.Mean()
                sample_diversity = tf.keras.metrics.Mean()
                ged = tf.keras.metrics.Mean()

                # Runs evaluation loop.
                ncorrect, loss, nseen = 0, 0, 0
                for _, batch in zip(range(val_steps), val_iter):
                    if val_name == 'cifar_10h':
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            cifar_10h_evaluation_fn(opt_repl.target,
                                                    batch['image'],
                                                    batch['labels'],
                                                    batch['mask']))
                    else:
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            evaluation_fn(opt_repl.target, batch['image'],
                                          batch['labels'], batch['mask']))
                    # All results are a replicated array shaped as follows:
                    # (local_devices, per_device_batch_size, elem_shape...)
                    # with each local device's entry being identical as they got psum'd.
                    # So let's just take the first one to the host as numpy.
                    ncorrect += np.sum(np.array(batch_ncorrect[0]))
                    loss += np.sum(np.array(batch_losses[0]))
                    nseen += np.sum(np.array(batch_n[0]))
                    if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                        # Here we parse batch_metric_args to compute uncertainty metrics.
                        # (e.g., ECE or Calibration AUC).
                        logits, labels, _, masks = batch_metric_args
                        # Use the builtin `bool`: the `np.bool` alias was deprecated
                        # in NumPy 1.20 and removed in 1.24 (AttributeError there).
                        masks = np.array(masks[0], dtype=bool)
                        logits = np.array(logits[0])
                        probs = jax.nn.softmax(logits)
                        # From one-hot to integer labels, as required by ECE.
                        int_labels = np.argmax(np.array(labels[0]), axis=-1)
                        int_preds = np.argmax(logits, axis=-1)
                        confidence = np.max(probs, axis=-1)
                        # Diversity/GED metrics are only tracked for these splits;
                        # the check does not depend on the loop below.
                        track_label_uncertainty = val_name in (
                            'cifar_10h', 'imagenet_real')
                        for p, c, l, d, m, label in zip(
                                probs, confidence, int_labels, int_preds,
                                masks, labels[0]):
                            # `m` is a boolean mask selecting the entries to keep.
                            ece.add_batch(p[m, :], label=l[m])
                            calib_auc.add_batch(d[m],
                                                label=l[m],
                                                confidence=c[m])
                            # TODO(jereliu): Extend to support soft multi-class probabilities.
                            oc_auc_0_5.add_batch(d[m],
                                                 label=l[m],
                                                 custom_binning_score=c[m])
                            oc_auc_1.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_2.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_5.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])

                            if track_label_uncertainty:
                                diversity_stats = (
                                    data_uncertainty_utils.
                                    generalized_energy_distance(
                                        label[m], p[m, :],
                                        config.num_classes))
                                (batch_label_diversity,
                                 batch_sample_diversity,
                                 batch_ged) = diversity_stats
                                label_diversity.update_state(
                                    batch_label_diversity)
                                sample_diversity.update_state(
                                    batch_sample_diversity)
                                ged.update_state(batch_ged)

                val_loss[
                    val_name] = loss / nseen  # Keep for reproducibility tests.
                val_measurements = {
                    f'{val_name}_prec@1': ncorrect / nseen,
                    f'{val_name}_loss': val_loss[val_name],
                }
                if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                    val_measurements[f'{val_name}_ece'] = ece.result()['ece']
                    val_measurements[
                        f'{val_name}_calib_auc'] = calib_auc.result(
                        )['calibration_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_1%'] = oc_auc_1.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_2%'] = oc_auc_2.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_5%'] = oc_auc_5.result(
                        )['collaborative_auc']
                writer.write_scalars(step, val_measurements)

                if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                    cifar_10h_measurements = {
                        f'{val_name}_label_diversity':
                        label_diversity.result(),
                        f'{val_name}_sample_diversity':
                        sample_diversity.result(),
                        f'{val_name}_ged': ged.result(),
                    }
                    writer.write_scalars(step, cifar_10h_measurements)

            # OOD eval
            # Entries in the ood_ds dict include:
            # (ind_dataset, ood_dataset1, ood_dataset2, ...).
            # OOD metrics are computed using ind_dataset paired with each of the
            # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
            # is also included in the ood_ds.
            if ood_ds and config.ood_methods:
                ood_measurements = ood_utils.eval_ood_metrics(
                    ood_ds, ood_ds_names, config.ood_methods, evaluation_fn,
                    opt_repl)
                writer.write_scalars(step, ood_measurements)
            chrono.resume()

        if 'fewshot' in config and fewshotter is not None:
            # Compute few-shot on-the-fly evaluation.
            if train_utils.itstime(step, config.fewshot.log_steps,
                                   total_steps):
                chrono.pause()
                write_note(f'Few-shot evaluation...\n{chrono.note}')
                # Keep `results` to return for reproducibility tests.
                fewshot_results, best_l2 = fewshotter.run_all(
                    opt_repl.target, config.fewshot.datasets)

                # TODO(dusenberrymw): Remove this once fewshot.py is updated.
                def make_writer_measure_fn(step):
                    """Return a `(name, value)` callback writing one scalar at `step`."""
                    def _write_single_scalar(name, value):
                        scalars = {name: value}
                        writer.write_scalars(step, scalars)

                    return _write_single_scalar

                fewshotter.walk_results(make_writer_measure_fn(step),
                                        fewshot_results, best_l2)
                chrono.resume()

        # End of step.
        if config.get('testing_failure_step'):
            # Break early to simulate infra failures in test cases.
            if config.testing_failure_step == step:
                break

    write_note(f'Done!\n{chrono.note}')
    pool.close()
    pool.join()
    writer.close()

    # Return final training loss, validation loss, and fewshot results for
    # reproducibility test cases.
    return train_loss, val_loss, fewshot_results
# Example no. 3
# 0
def main(config, output_dir):

    seed = config.get('seed', 0)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)
    tf.io.gfile.makedirs(output_dir)

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        """Log `note` at INFO level, but only on the process with index 0."""
        if jax.process_index() != 0:
            return
        logging.info('NOTE: %s', note)

    write_note('Initializing...')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

    write_note('Initializing val dataset(s)...')

    def _get_val_split(dataset, split, pp_eval, data_dir=None):
        """Build the unshuffled, single-epoch evaluation dataset for a split.

        Args:
          dataset: Dataset name handed to `input_utils`.
          split: Split of `dataset` to evaluate on.
          pp_eval: Preprocessing function, or a string spec that is parsed
            into one with `preprocess_spec.parse`.
          data_dir: Optional data directory forwarded to `input_utils`.

        Returns:
          The dataset returned by `input_utils.get_data`.
        """
        # Ceil rounding so that the last, incomplete batch is also counted.
        num_examples = input_utils.get_num_examples(
            dataset,
            split=split,
            process_batch_size=local_batch_size_eval,
            drop_remainder=False,
            data_dir=data_dir)
        val_steps = int(np.ceil(num_examples / batch_size_eval))
        logging.info('Running validation for %d steps for %s, %s', val_steps,
                     dataset, split)

        # Accept either a ready-made callable or a textual preprocessing spec.
        preprocess_fn = pp_eval
        if isinstance(preprocess_fn, str):
            preprocess_fn = preprocess_spec.parse(
                spec=preprocess_fn, available_ops=preprocess_utils.all_ops())

        data_kwargs = {
            'dataset': dataset,
            'split': split,
            'rng': None,
            'process_batch_size': local_batch_size_eval,
            'preprocess_fn': preprocess_fn,
            'cache': config.get('val_cache', 'batched'),
            'num_epochs': 1,
            'repeat_after_batching': True,
            'shuffle': False,
            'prefetch_size': config.get('prefetch_to_host', 2),
            'drop_remainder': False,
            'data_dir': data_dir,
        }
        return input_utils.get_data(**data_kwargs)

    val_ds_splits = {
        'val':
        _get_val_split(config.dataset,
                       split=config.val_split,
                       pp_eval=config.pp_eval,
                       data_dir=config.get('data_dir'))
    }

    if config.get('test_split'):
        val_ds_splits.update({
            'test':
            _get_val_split(config.dataset,
                           split=config.test_split,
                           pp_eval=config.pp_eval,
                           data_dir=config.get('data_dir'))
        })

    if config.get('eval_on_cifar_10h'):
        cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
            config.get('data_dir', None))
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_cifar_10h,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
        val_ds_splits['cifar_10h'] = _get_val_split(
            'cifar10',
            split=config.get('cifar_10h_split') or 'test',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
    elif config.get('eval_on_imagenet_real'):
        imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn(
        )
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_imagenet_real,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
        val_ds_splits['imagenet_real'] = _get_val_split(
            'imagenet2012_real',
            split=config.get('imagenet_real_split') or 'validation',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))

    # OOD datasets are only loaded when both the datasets and the OOD methods
    # are configured and non-empty; otherwise both containers stay empty.
    ood_ds = {}
    # Bind the names up front so they are defined even when nothing is loaded.
    # NOTE(review): an empty list is assumed to be an acceptable "no datasets"
    # value for any later consumer -- confirm against the rest of `main`.
    ood_ds_names = []
    if config.get('ood_datasets') and config.get('ood_methods'):
        logging.info('loading OOD dataset = %s', config.get('ood_datasets'))
        ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
            config.dataset,
            config.ood_datasets,
            config.ood_split,
            config.pp_eval,
            config.pp_eval_ood,
            config.ood_methods,
            config.train_split,
            config.get('data_dir'),
            _get_val_split,
        )

    write_note('Initializing model...')
    logging.info('config.model = %s', config.model)
    model = ub.models.vision_transformer(num_classes=config.num_classes,
                                         **config.model)

    ensemble_pred_fn = functools.partial(ensemble_prediction_fn, model.apply)

    @functools.partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels, mask):
        """Evaluate the ensemble on one batch, mapped over devices.

        Args:
          params: Dict of the form
            {'model_1': params_model_1, 'model_2': params_model_2, ...}.
          images: Input images.
          labels: Label array; its max over axis 1 is folded into `mask`.
          mask: Per-example weights; entries with all-zero labels are zeroed.

        Returns:
          A tuple `(ncorrect, loss, n, metric_args)` where the first three are
          summed over the 'batch' axis and `metric_args` is the all-gathered
          list `[ens_logits, labels, ens_prelogits, mask]`.
        """
        # Ignore the entries with all zero labels for evaluation.
        mask *= labels.max(axis=1)
        loss_name = config.get('loss', 'sigmoid_xent')
        ens_logits, ens_prelogits = ensemble_pred_fn(params, images, loss_name)

        label_indices = config.get('label_indices')
        logging.info('!!! mask %s, label_indices %s', mask, label_indices)
        if label_indices:
            ens_logits = ens_logits[:, label_indices]

        # Note that logits and labels are usually of the shape
        # [batch, num_classes]. For OOD data with more classes than the
        # in-distribution data, labels are truncated to the logits' width purely
        # to avoid a shape mismatch; the resulting losses carry no meaning for
        # OOD data, because OOD examples belong to no in-distribution class.
        loss_impl = getattr(train_utils, loss_name)
        if label_indices:
            num_label_columns = len(label_indices)
        else:
            num_label_columns = config.num_classes
        per_example_loss = loss_impl(logits=ens_logits,
                                     labels=labels[:, :num_label_columns],
                                     reduction=False)
        loss = jax.lax.psum(per_example_loss * mask, axis_name='batch')

        # Extracts the label at the highest logit index for each image.
        predicted_class = jnp.argmax(ens_logits, axis=1)
        label_at_prediction = jnp.take_along_axis(labels,
                                                  predicted_class[:, None],
                                                  axis=1)
        top1_correct = label_at_prediction[:, 0]
        ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')

        gathered_inputs = [ens_logits, labels, ens_prelogits, mask]
        metric_args = jax.lax.all_gather(gathered_inputs, axis_name='batch')
        return ncorrect, loss, n, metric_args

    @functools.partial(jax.pmap, axis_name='batch')
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        """Evaluate the ensemble on one CIFAR-10H batch, mapped over devices.

        `mask` is not applied to the loss or the counts here; it is only
        forwarded as part of `metric_args`.

        Args:
          params: Ensemble parameters passed to `ensemble_pred_fn`.
          images: Input images.
          labels: Label array; argmax over axis 1 gives the reference class.
          mask: Per-example mask, returned untouched inside `metric_args`.

        Returns:
          A tuple `(ncorrect, loss, n, metric_args)` where the first three are
          summed over the 'batch' axis and `metric_args` is the all-gathered
          list `[ens_logits, labels, ens_prelogits, mask]`.
        """
        loss_name = config.get('loss', 'softmax_xent')
        ens_logits, ens_prelogits = ensemble_pred_fn(params, images, loss_name)
        label_indices = config.get('label_indices')
        if label_indices:
            ens_logits = ens_logits[:, label_indices]

        loss_impl = getattr(train_utils, loss_name)
        per_example_loss = loss_impl(logits=ens_logits,
                                     labels=labels,
                                     reduction=False)
        loss = jax.lax.psum(per_example_loss, axis_name='batch')

        predicted_class = jnp.argmax(ens_logits, axis=1)
        # Turn each label row into a one-hot row (10 classes) at its argmax.
        reference_class = jnp.argmax(labels, axis=1)
        one_hot_labels = jnp.eye(10)[reference_class]

        # Extracts the one-hot label at the highest logit index for each image.
        label_at_prediction = jnp.take_along_axis(one_hot_labels,
                                                  predicted_class[:, None],
                                                  axis=1)
        top1_correct = label_at_prediction[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        gathered_inputs = [ens_logits, labels, ens_prelogits, mask]
        metric_args = jax.lax.all_gather(gathered_inputs, axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Setup function for computing representation.
    @functools.partial(jax.pmap, axis_name='batch')
    def representation_fn(params, images, labels, mask):
        """Compute concatenated ensemble representations, mapped over devices.

        Each ensemble member's representation layer is concatenated along
        axis 1, giving [batch_size, representation_size * ensemble_size].
        During few-shot eval, a single linear regressor is applied over all
        dimensions.

        Returns:
          A tuple `(representation, labels, mask)`, each all-gathered over
          the 'batch' axis.
        """
        def _member_representation(member_params):
            _, outputs = model.apply(
                {'params': flax.core.freeze(member_params)},
                images,
                train=False)
            return outputs[config.fewshot.representation_layer]

        per_member = [
            _member_representation(member_params)
            for member_params in params.values()
        ]
        representation = jnp.concatenate(per_member, axis=1)

        gathered_representation = jax.lax.all_gather(representation, 'batch')
        gathered_labels = jax.lax.all_gather(labels, 'batch')
        gathered_mask = jax.lax.all_gather(mask, 'batch')
        return gathered_representation, gathered_labels, gathered_mask

    write_note('Load checkpoints...')
    ensemble_params = load_checkpoints(config)

    write_note('Replicating...')
    ensemble_params = flax.jax_utils.replicate(ensemble_params)

    if jax.process_index() == 0:
        writer.write_hparams(dict(config))

    write_note('Initializing few-shotters...')
    fewshotter = None
    if 'fewshot' in config and fewshot is not None:
        fewshotter = fewshot.FewShotEvaluator(
            representation_fn, config.fewshot,
            config.fewshot.get('batch_size') or batch_size_eval)

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
    fewshot_results = {'dummy': {(0, 1): -jnp.inf}}
    step = 1

    # Report validation performance.
    write_note('Evaluating on the validation set...')
    for val_name, val_ds in val_ds_splits.items():
        # Sets up evaluation metrics.
        ece_num_bins = config.get('ece_num_bins', 15)
        auc_num_bins = config.get('auc_num_bins', 1000)
        ece = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins)
        calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False)
        oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.005,
                                                       num_bins=auc_num_bins)
        oc_auc_1 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.01,
                                                     num_bins=auc_num_bins)
        oc_auc_2 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.02,
                                                     num_bins=auc_num_bins)
        oc_auc_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.05,
                                                     num_bins=auc_num_bins)
        label_diversity = tf.keras.metrics.Mean()
        sample_diversity = tf.keras.metrics.Mean()
        ged = tf.keras.metrics.Mean()

        # Runs evaluation loop.
        val_iter = input_utils.start_input_pipeline(
            val_ds, config.get('prefetch_to_device', 1))
        ncorrect, loss, nseen = 0, 0, 0
        for batch in val_iter:
            if val_name == 'cifar_10h':
                batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                    cifar_10h_evaluation_fn(ensemble_params, batch['image'],
                                            batch['labels'], batch['mask']))
            else:
                batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                    evaluation_fn(ensemble_params, batch['image'],
                                  batch['labels'], batch['mask']))
            # All results are a replicated array shaped as follows:
            # (local_devices, per_device_batch_size, elem_shape...)
            # with each local device's entry being identical as they got psum'd.
            # So let's just take the first one to the host as numpy.
            ncorrect += np.sum(np.array(batch_ncorrect[0]))
            loss += np.sum(np.array(batch_losses[0]))
            nseen += np.sum(np.array(batch_n[0]))
            if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                # Here we parse batch_metric_args to compute uncertainty metrics.
                # (e.g., ECE or Calibration AUC).
                logits, labels, _, masks = batch_metric_args
                # Use the builtin `bool`: the `np.bool` alias was deprecated in
                # NumPy 1.20 and removed in 1.24 (AttributeError there).
                masks = np.array(masks[0], dtype=bool)
                logits = np.array(logits[0])
                probs = jax.nn.softmax(logits)
                # From one-hot to integer labels, as required by ECE.
                int_labels = np.argmax(np.array(labels[0]), axis=-1)
                int_preds = np.argmax(logits, axis=-1)
                confidence = np.max(probs, axis=-1)
                # Diversity/GED metrics are only tracked for these splits; the
                # check does not depend on the loop below.
                track_label_uncertainty = val_name in ('cifar_10h',
                                                       'imagenet_real')
                for p, c, l, d, m, label in zip(probs, confidence, int_labels,
                                                int_preds, masks, labels[0]):
                    # `m` is a boolean mask selecting the entries to keep.
                    ece.add_batch(p[m, :], label=l[m])
                    calib_auc.add_batch(d[m], label=l[m], confidence=c[m])
                    # TODO(jereliu): Extend to support soft multi-class probabilities.
                    oc_auc_0_5.add_batch(d[m],
                                         label=l[m],
                                         custom_binning_score=c[m])
                    oc_auc_1.add_batch(d[m],
                                       label=l[m],
                                       custom_binning_score=c[m])
                    oc_auc_2.add_batch(d[m],
                                       label=l[m],
                                       custom_binning_score=c[m])
                    oc_auc_5.add_batch(d[m],
                                       label=l[m],
                                       custom_binning_score=c[m])

                    if track_label_uncertainty:
                        diversity_stats = (
                            data_uncertainty_utils.generalized_energy_distance(
                                label[m], p[m, :], config.num_classes))
                        (batch_label_diversity, batch_sample_diversity,
                         batch_ged) = diversity_stats
                        label_diversity.update_state(batch_label_diversity)
                        sample_diversity.update_state(batch_sample_diversity)
                        ged.update_state(batch_ged)

        val_loss[val_name] = loss / nseen  # Keep for reproducibility tests.
        val_measurements = {
            f'{val_name}_prec@1': ncorrect / nseen,
            f'{val_name}_loss': val_loss[val_name],
        }
        if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
            val_measurements[f'{val_name}_ece'] = ece.result()['ece']
            val_measurements[f'{val_name}_calib_auc'] = calib_auc.result(
            )['calibration_auc']
            val_measurements[f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result(
            )['collaborative_auc']
            val_measurements[f'{val_name}_oc_auc_1%'] = oc_auc_1.result(
            )['collaborative_auc']
            val_measurements[f'{val_name}_oc_auc_2%'] = oc_auc_2.result(
            )['collaborative_auc']
            val_measurements[f'{val_name}_oc_auc_5%'] = oc_auc_5.result(
            )['collaborative_auc']
        writer.write_scalars(step, val_measurements)

        if val_name == 'cifar_10h' or val_name == 'imagenet_real':
            cifar_10h_measurements = {
                f'{val_name}_label_diversity': label_diversity.result(),
                f'{val_name}_sample_diversity': sample_diversity.result(),
                f'{val_name}_ged': ged.result(),
            }
            writer.write_scalars(step, cifar_10h_measurements)

    # OOD eval
    # Entries in the ood_ds dict include:
    # (ind_dataset, ood_dataset1, ood_dataset2, ...).
    # OOD metrics are computed using ind_dataset paired with each of the
    # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
    # is also included in the ood_ds.
    if ood_ds and config.ood_methods:
        ood_measurements = ood_utils.eval_ood_metrics(ood_ds,
                                                      ood_ds_names,
                                                      config.ood_methods,
                                                      evaluation_fn,
                                                      ensemble_params,
                                                      n_prefetch=config.get(
                                                          'prefetch_to_device',
                                                          1))
        writer.write_scalars(step, ood_measurements)

    if 'fewshot' in config and fewshotter is not None:
        # Compute few-shot on-the-fly evaluation.
        write_note('Few-shot evaluation...')
        # Keep `results` to return for reproducibility tests.
        fewshot_results, best_l2 = fewshotter.run_all(ensemble_params,
                                                      config.fewshot.datasets)

        # TODO(dusenberrymw): Remove this once fewshot.py is updated.
        def make_writer_measure_fn(step):
            def writer_measure(name, value):
                writer.write_scalars(step, {name: value})

            return writer_measure

        fewshotter.walk_results(make_writer_measure_fn(step), fewshot_results,
                                best_l2)

    write_note('Done!')
    pool.close()
    pool.join()
    writer.close()

    # Return final training loss, validation loss, and fewshot results for
    # reproducibility test cases.
    return val_loss, fewshot_results
Esempio n. 4
0
def main(argv):
  """Trains and evaluates a Wide ResNet (depth 28, width 10) SNGP model.

  All configuration is read from `FLAGS`. The function:

  1. Picks a distribution strategy (`MirroredStrategy` when `FLAGS.use_gpu`,
     otherwise `TPUStrategy`).
  2. Builds the CIFAR-10 / CIFAR-100 train, optional validation and clean test
     datasets, plus optional OOD and corrupted test datasets.
  3. Builds the model, LR schedule, SGD optimizer and metrics inside the
     strategy scope and restores the latest checkpoint, if any.
  4. For each epoch, runs training (unless `FLAGS.eval_only`), evaluates,
     writes scalar summaries, resets the metrics and optionally checkpoints.
  5. Saves a final checkpoint, the Keras model and the hyperparameters.

  Args:
    argv: Unused.
  """
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)
  # Split the seed into a 2-tuple, for passing into dataset builder.
  dataset_seed = (FLAGS.seed, FLAGS.seed + 1)

  data_dir = FLAGS.data_dir
  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

  # Each training example is tiled `num_dropout_samples_training` times inside
  # the train step, so the loaded training batch is shrunk by that factor.
  batch_size = FLAGS.total_batch_size // FLAGS.num_dropout_samples_training
  test_batch_size = FLAGS.total_batch_size
  num_classes = 10 if FLAGS.dataset == 'cifar10' else 100

  aug_params = {
      'augmix': FLAGS.augmix,
      'aug_count': FLAGS.aug_count,
      'augmix_depth': FLAGS.augmix_depth,
      'augmix_prob_coeff': FLAGS.augmix_prob_coeff,
      'augmix_width': FLAGS.augmix_width,
      'ensemble_size': 1,
      'mixup_alpha': FLAGS.mixup_alpha,
  }
  validation_proportion = 1. - FLAGS.train_proportion
  use_validation_set = validation_proportion > 0.
  if FLAGS.dataset == 'cifar10':
    dataset_builder_class = ub.datasets.Cifar10Dataset
  else:
    dataset_builder_class = ub.datasets.Cifar100Dataset
  train_dataset_builder = dataset_builder_class(
      data_dir=data_dir,
      download_data=FLAGS.download_data,
      split=tfds.Split.TRAIN,
      use_bfloat16=FLAGS.use_bfloat16,
      aug_params=aug_params,
      validation_percent=validation_proportion,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size,
      seed=dataset_seed)
  train_dataset = train_dataset_builder.load(batch_size=batch_size)
  if use_validation_set:
    validation_dataset_builder = dataset_builder_class(
        data_dir=data_dir,
        download_data=FLAGS.download_data,
        split=tfds.Split.VALIDATION,
        use_bfloat16=FLAGS.use_bfloat16,
        validation_percent=validation_proportion,
        drop_remainder=FLAGS.drop_remainder_for_eval)
    validation_dataset = validation_dataset_builder.load(
        batch_size=test_batch_size)
    validation_dataset = strategy.experimental_distribute_dataset(
        validation_dataset)
    val_sample_size = validation_dataset_builder.num_examples
    steps_per_val = int(val_sample_size / test_batch_size)
  clean_test_dataset_builder = dataset_builder_class(
      data_dir=data_dir,
      download_data=FLAGS.download_data,
      split=tfds.Split.TEST,
      use_bfloat16=FLAGS.use_bfloat16,
      drop_remainder=FLAGS.drop_remainder_for_eval)
  clean_test_dataset = clean_test_dataset_builder.load(
      batch_size=test_batch_size)

  steps_per_epoch = train_dataset_builder.num_examples // batch_size
  steps_per_eval = clean_test_dataset_builder.num_examples // test_batch_size
  train_dataset = strategy.experimental_distribute_dataset(train_dataset)
  test_datasets = {
      'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
  }

  if FLAGS.eval_on_ood:
    ood_dataset_names = FLAGS.ood_dataset
    ood_ds, steps_per_ood = ood_utils.load_ood_datasets(
        ood_dataset_names,
        clean_test_dataset_builder,
        validation_proportion,
        test_batch_size,
        drop_remainder=FLAGS.drop_remainder_for_eval)
    ood_datasets = {
        name: strategy.experimental_distribute_dataset(ds)
        for name, ds in ood_ds.items()
    }

  if FLAGS.corruptions_interval > 0:
    if FLAGS.dataset == 'cifar100':
      data_dir = FLAGS.cifar100_c_path
    corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
    for corruption_type in corruption_types:
      for severity in range(1, 6):
        # Load with the evaluation batch size: these datasets are iterated for
        # `steps_per_eval` steps, which is derived from `test_batch_size`.
        dataset = ub.datasets.get(
            f'{FLAGS.dataset}_corrupted',
            corruption_type=corruption_type,
            severity=severity,
            split=tfds.Split.TEST,
            data_dir=data_dir,
            drop_remainder=FLAGS.drop_remainder_for_eval).load(
                batch_size=test_batch_size)
        test_datasets[f'{corruption_type}_{severity}'] = (
            strategy.experimental_distribute_dataset(dataset))

  if FLAGS.use_bfloat16:
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building ResNet model')
    if FLAGS.use_spec_norm:
      logging.info('Use Spectral Normalization with norm bound %.2f',
                   FLAGS.spec_norm_bound)
    if FLAGS.use_gp_layer:
      logging.info('Use GP layer with hidden units %d', FLAGS.gp_hidden_dim)

    model = ub.models.wide_resnet_sngp(
        input_shape=(32, 32, 3),
        batch_size=batch_size,
        depth=28,
        width_multiplier=10,
        num_classes=num_classes,
        l2=FLAGS.l2,
        use_mc_dropout=FLAGS.use_mc_dropout,
        use_filterwise_dropout=FLAGS.use_filterwise_dropout,
        dropout_rate=FLAGS.dropout_rate,
        use_gp_layer=FLAGS.use_gp_layer,
        gp_input_dim=FLAGS.gp_input_dim,
        gp_hidden_dim=FLAGS.gp_hidden_dim,
        gp_scale=FLAGS.gp_scale,
        gp_bias=FLAGS.gp_bias,
        gp_input_normalization=FLAGS.gp_input_normalization,
        gp_random_feature_type=FLAGS.gp_random_feature_type,
        gp_cov_discount_factor=FLAGS.gp_cov_discount_factor,
        gp_cov_ridge_penalty=FLAGS.gp_cov_ridge_penalty,
        use_spec_norm=FLAGS.use_spec_norm,
        spec_norm_iteration=FLAGS.spec_norm_iteration,
        spec_norm_bound=FLAGS.spec_norm_bound)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate * batch_size / 128
    lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                       for start_epoch_str in FLAGS.lr_decay_epochs]
    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                        momentum=1.0 - FLAGS.one_minus_momentum,
                                        nesterov=True)
    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss': tf.keras.metrics.Mean(),
        'train/ece': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'test/stddev': tf.keras.metrics.Mean(),
    }
    if use_validation_set:
      metrics.update({
          'val/negative_log_likelihood': tf.keras.metrics.Mean(),
          'val/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
          'val/ece': rm.metrics.ExpectedCalibrationError(
              num_bins=FLAGS.num_bins),
          'val/stddev': tf.keras.metrics.Mean(),
      })
    if FLAGS.eval_on_ood:
      ood_metrics = ood_utils.create_ood_metrics(
          ood_dataset_names, tpr_list=FLAGS.ood_tpr_threshold)
      metrics.update(ood_metrics)
    if FLAGS.corruptions_interval > 0:
      # Kept separate from `metrics`; aggregated per epoch further below.
      corrupt_metrics = {}
      for intensity in range(1, 6):
        for corruption in corruption_types:
          dataset_name = f'{corruption}_{intensity}'
          corrupt_metrics[f'test/nll_{dataset_name}'] = (
              tf.keras.metrics.Mean())
          corrupt_metrics[f'test/accuracy_{dataset_name}'] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics[f'test/ece_{dataset_name}'] = (
              rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
          corrupt_metrics[f'test/stddev_{dataset_name}'] = (
              tf.keras.metrics.Mean())

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    logging.info('Output dir : %s', FLAGS.output_dir)
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
    if FLAGS.saved_model_dir:
      # Restored after (and therefore on top of) any checkpoint found in
      # `FLAGS.output_dir`; `initial_epoch` is not recomputed here.
      logging.info('Saved model dir : %s', FLAGS.saved_model_dir)
      latest_checkpoint = tf.train.latest_checkpoint(FLAGS.saved_model_dir)
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
    if FLAGS.eval_only:
      initial_epoch = FLAGS.train_epochs - 1  # Run just one epoch of eval

  @tf.function
  def train_step(iterator, step):
    """Training StepFn."""

    def step_fn(inputs, step):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']

      if tf.equal(step, 0) and FLAGS.gp_cov_discount_factor < 0:
        # Resetting covariance estimator at the beginning of a new epoch.
        if FLAGS.use_gp_layer:
          model.layers[-1].reset_covariance_matrix()

      if FLAGS.augmix and FLAGS.aug_count >= 1:
        # Index 0 at augmix preprocessing is the unperturbed image.
        images = images[:, 1, ...]
        # This is for the case of combining AugMix and Mixup.
        if FLAGS.mixup_alpha > 0:
          labels = tf.split(labels, FLAGS.aug_count + 1, axis=0)[1]
      # Repeat the batch once per training dropout sample. With mixup the
      # labels carry a trailing class axis, hence the extra multiple.
      images = tf.tile(images, [FLAGS.num_dropout_samples_training, 1, 1, 1])
      if FLAGS.mixup_alpha > 0:
        labels = tf.tile(labels, [FLAGS.num_dropout_samples_training, 1])
      else:
        labels = tf.tile(labels, [FLAGS.num_dropout_samples_training])

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if isinstance(logits, (list, tuple)):
          # If model returns a tuple of (logits, covmat), extract logits
          logits, _ = logits
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        if FLAGS.mixup_alpha > 0:
          negative_log_likelihood = tf.reduce_mean(
              tf.keras.losses.categorical_crossentropy(labels,
                                                       logits,
                                                       from_logits=True))
        else:
          negative_log_likelihood = tf.reduce_mean(
              tf.keras.losses.sparse_categorical_crossentropy(labels,
                                                              logits,
                                                              from_logits=True))

        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + l2_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

      probs = tf.nn.softmax(logits)
      if FLAGS.mixup_alpha > 0:
        # The metrics below take integer class labels.
        labels = tf.argmax(labels, axis=-1)
      metrics['train/ece'].add_batch(probs, label=labels)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, logits)

    strategy.run(step_fn, args=(next(iterator), step))

  @tf.function
  def test_step(iterator, dataset_name, num_steps):
    """Evaluation StepFn."""
    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']

      logits_list = []
      stddev_list = []
      for _ in range(FLAGS.num_dropout_samples):
        logits = model(images, training=False)
        if isinstance(logits, (list, tuple)):
          # If model returns a tuple of (logits, covmat), extract both
          logits, covmat = logits
          if FLAGS.use_bfloat16:
            logits = tf.cast(logits, tf.float32)
          logits = ed.layers.utils.mean_field_logits(
              logits, covmat, mean_field_factor=FLAGS.gp_mean_field_factor)
        else:
          # No covariance returned: use an identity matrix as a placeholder.
          covmat = tf.eye(logits.shape[0])
          if FLAGS.use_bfloat16:
            logits = tf.cast(logits, tf.float32)
        stddev = tf.sqrt(tf.linalg.diag_part(covmat))

        stddev_list.append(stddev)
        logits_list.append(logits)

      # Logits dimension is (num_samples, batch_size, num_classes).
      logits_list = tf.stack(logits_list, axis=0)
      stddev_list = tf.stack(stddev_list, axis=0)

      stddev = tf.reduce_mean(stddev_list, axis=0)
      probs_list = tf.nn.softmax(logits_list)
      probs = tf.reduce_mean(probs_list, axis=0)
      logits = tf.reduce_mean(logits_list, axis=0)

      labels_broadcasted = tf.broadcast_to(
          labels, [FLAGS.num_dropout_samples,
                   tf.shape(labels)[0]])
      log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
          labels_broadcasted, logits_list, from_logits=True)
      # NLL of the sample-averaged predictive distribution:
      # -log(mean_s exp(ll_s)) = -logsumexp_s(ll_s) + log(num_samples).
      negative_log_likelihood = tf.reduce_mean(
          -tf.reduce_logsumexp(log_likelihoods, axis=[0]) +
          tf.math.log(float(FLAGS.num_dropout_samples)))

      logging.info('Dataset name : %s', dataset_name)
      if dataset_name == 'clean':
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].add_batch(probs, label=labels)
        metrics['test/stddev'].update_state(stddev)
      elif dataset_name == 'val':
        metrics['val/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['val/accuracy'].update_state(labels, probs)
        metrics['val/ece'].add_batch(probs, label=labels)
        metrics['val/stddev'].update_state(stddev)
      elif dataset_name.startswith('ood/'):
        ood_labels = 1 - inputs['is_in_distribution']
        if FLAGS.dempster_shafer_ood:
          ood_scores = ood_utils.DempsterShaferUncertainty(logits)
        else:
          ood_scores = 1 - tf.reduce_max(probs, axis=-1)

        # Edgecase for if dataset_name contains underscores
        for name, metric in metrics.items():
          if dataset_name in name:
            metric.update_state(ood_labels, ood_scores)
      elif FLAGS.corruptions_interval > 0:
        corrupt_metrics[f'test/nll_{dataset_name}'].update_state(
            negative_log_likelihood)
        corrupt_metrics[f'test/accuracy_{dataset_name}'].update_state(
            labels, probs)
        corrupt_metrics[f'test/ece_{dataset_name}'].add_batch(
            probs, label=labels)
        corrupt_metrics[f'test/stddev_{dataset_name}'].update_state(
            stddev)

    for _ in tf.range(tf.cast(num_steps, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  step_variable = tf.Variable(0, dtype=tf.int32)
  train_iterator = iter(train_dataset)
  start_time = time.time()
  max_steps = steps_per_epoch * FLAGS.train_epochs

  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    if not FLAGS.eval_only:
      for step in range(steps_per_epoch):
        step_variable.assign(step)
        # Pass `step` as a tf.Variable to train_step to prevent the tf.function
        # train_step() re-compiling itself at each function call.
        train_step(train_iterator, step_variable)

        # Progress is only reported every 20 steps, so only build the message
        # when it is actually going to be logged.
        if step % 20 == 0:
          current_step = epoch * steps_per_epoch + (step + 1)
          time_elapsed = time.time() - start_time
          steps_per_sec = float(current_step) / time_elapsed
          eta_seconds = (max_steps - current_step) / steps_per_sec
          message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                     'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                         current_step / max_steps,
                         epoch + 1,
                         FLAGS.train_epochs,
                         steps_per_sec,
                         eta_seconds / 60,
                         time_elapsed / 60))
          logging.info(message)

    run_corruptions = (FLAGS.corruptions_interval > 0 and
                       (epoch + 1) % FLAGS.corruptions_interval == 0)
    if run_corruptions:
      # Clean test set plus every corrupted variant.
      datasets_to_evaluate = dict(test_datasets)
    else:
      datasets_to_evaluate = {'clean': test_datasets['clean']}
    # The 'val/*' metrics are reported every epoch, so the validation set is
    # evaluated on corruption epochs as well.
    if use_validation_set:
      datasets_to_evaluate['val'] = validation_dataset
    for dataset_name, test_dataset in datasets_to_evaluate.items():
      test_iterator = iter(test_dataset)
      logging.info('Testing on dataset %s', dataset_name)
      # Use a per-dataset local so the validation step count never overwrites
      # `steps_per_eval` for the test sets evaluated afterwards.
      if dataset_name == 'val':
        num_eval_steps = steps_per_val
      else:
        num_eval_steps = steps_per_eval
      logging.info('Starting to run eval at epoch: %s', epoch)
      test_start_time = time.time()
      test_step(test_iterator, dataset_name, num_eval_steps)
      # NOTE(review): despite the metric name, this is elapsed seconds * 1e6
      # divided by the training batch size, not milliseconds per evaluated
      # example. Left unchanged so logged values stay comparable; confirm
      # whether this is intended.
      ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
      metrics['test/ms_per_example'].update_state(ms_per_example)

      logging.info('Done with testing on %s', dataset_name)

    if FLAGS.eval_on_ood:
      for ood_dataset_name, ood_dataset in ood_datasets.items():
        ood_iterator = iter(ood_dataset)
        logging.info('Calculating OOD on dataset %s', ood_dataset_name)
        logging.info('Running OOD eval at epoch: %s', epoch)
        test_step(ood_iterator, ood_dataset_name,
                  steps_per_ood[ood_dataset_name])

        logging.info('Done with OOD eval on %s', ood_dataset_name)

    corrupt_results = {}
    if run_corruptions:
      corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                        corruption_types)

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    if use_validation_set:
      logging.info('Val NLL: %.4f, Accuracy: %.2f%%',
                   metrics['val/negative_log_likelihood'].result(),
                   metrics['val/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    # Metrics from Robustness Metrics (like ECE) will return a dict with a
    # single key/value, instead of a scalar.
    total_results = {
        k: (list(v.values())[0] if isinstance(v, dict) else v)
        for k, v in total_results.items()
    }
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if FLAGS.corruptions_interval > 0:
      for metric in corrupt_metrics.values():
        metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  final_save_name = os.path.join(FLAGS.output_dir, 'model')
  model.save(final_save_name)
  logging.info('Saved model to %s', final_save_name)
  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
        'gp_mean_field_factor': FLAGS.gp_mean_field_factor,
    })
Esempio n. 5
0
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    data_dir = FLAGS.data_dir
    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = (FLAGS.per_core_batch_size * FLAGS.num_cores //
                  FLAGS.num_dropout_samples_training)
    test_batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    steps_per_epoch = ds_info.splits['train'].num_examples // batch_size
    steps_per_eval = ds_info.splits['test'].num_examples // test_batch_size
    num_classes = ds_info.features['label'].num_classes

    train_builder = ub.datasets.get(FLAGS.dataset,
                                    data_dir=data_dir,
                                    download_data=FLAGS.download_data,
                                    split=tfds.Split.TRAIN,
                                    validation_percent=1. -
                                    FLAGS.train_proportion)
    train_dataset = train_builder.load(batch_size=batch_size)
    validation_dataset = None
    steps_per_validation = 0
    if FLAGS.train_proportion < 1.0:
        validation_builder = ub.datasets.get(
            FLAGS.dataset,
            data_dir=data_dir,
            split=tfds.Split.VALIDATION,
            validation_percent=1. - FLAGS.train_proportion,
            drop_remainder=FLAGS.drop_remainder_for_eval)
        validation_dataset = validation_builder.load(
            batch_size=test_batch_size)
        validation_dataset = strategy.experimental_distribute_dataset(
            validation_dataset)
        steps_per_validation = validation_builder.num_examples // test_batch_size
    clean_test_builder = ub.datasets.get(
        FLAGS.dataset,
        data_dir=data_dir,
        split=tfds.Split.TEST,
        drop_remainder=FLAGS.drop_remainder_for_eval)
    clean_test_dataset = clean_test_builder.load(batch_size=test_batch_size)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_datasets = {
        'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
    }
    steps_per_epoch = train_builder.num_examples // batch_size
    steps_per_eval = clean_test_builder.num_examples // batch_size
    num_classes = 100 if FLAGS.dataset == 'cifar100' else 10

    if FLAGS.eval_on_ood:
        ood_dataset_names = FLAGS.ood_dataset
        ood_ds, steps_per_ood = ood_utils.load_ood_datasets(
            ood_dataset_names,
            clean_test_builder,
            1. - FLAGS.train_proportion,
            batch_size,
            drop_remainder=FLAGS.drop_remainder_for_eval)
        ood_datasets = {
            name: strategy.experimental_distribute_dataset(ds)
            for name, ds in ood_ds.items()
        }
    if FLAGS.corruptions_interval > 0:
        if FLAGS.dataset == 'cifar100':
            data_dir = FLAGS.cifar100_c_path
        corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
        for corruption_type in corruption_types:
            for severity in range(1, 6):
                dataset = ub.datasets.get(
                    f'{FLAGS.dataset}_corrupted',
                    corruption_type=corruption_type,
                    data_dir=data_dir,
                    severity=severity,
                    split=tfds.Split.TEST,
                    drop_remainder=FLAGS.drop_remainder_for_eval).load(
                        batch_size=test_batch_size)
                test_datasets[f'{corruption_type}_{severity}'] = (
                    strategy.experimental_distribute_dataset(dataset))

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building ResNet model')
        model = ub.models.wide_resnet_dropout(
            input_shape=(32, 32, 3),
            depth=28,
            width_multiplier=10,
            num_classes=num_classes,
            l2=FLAGS.l2,
            dropout_rate=FLAGS.dropout_rate,
            residual_dropout=FLAGS.residual_dropout,
            filterwise_dropout=FLAGS.filterwise_dropout)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                           for start_epoch_str in FLAGS.lr_decay_epochs]
        lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=1.0 -
                                            FLAGS.one_minus_momentum,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        }
        if validation_dataset:
            metrics.update({
                'validation/negative_log_likelihood':
                tf.keras.metrics.Mean(),
                'validation/accuracy':
                tf.keras.metrics.SparseCategoricalAccuracy(),
                'validation/ece':
                rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            })
        if FLAGS.eval_on_ood:
            ood_metrics = ood_utils.create_ood_metrics(ood_dataset_names)
            metrics.update(ood_metrics)
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, 6):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        rm.metrics.ExpectedCalibrationError(
                            num_bins=FLAGS.num_bins))

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    @tf.function
    def train_step(iterator):
        """Runs `steps_per_epoch` training steps, drawing batches from `iterator`.

        Args:
          iterator: Iterator whose elements are dicts holding a 'features'
            and a 'labels' entry.
        """
        def step_fn(inputs):
            """Applies one gradient update on a single replica."""
            num_samples = FLAGS.num_dropout_samples_training
            # Repeat the batch along its leading axis so that every example
            # goes through the model `num_samples` times in one forward pass.
            features = tf.tile(inputs['features'], [num_samples, 1, 1, 1])
            targets = tf.tile(inputs['labels'], [num_samples])

            with tf.GradientTape() as tape:
                logits = model(features, training=True)
                cross_entropy = tf.keras.losses.sparse_categorical_crossentropy(
                    targets, logits, from_logits=True)
                negative_log_likelihood = tf.reduce_mean(cross_entropy)
                # `model.losses` carries the regularization terms of the model.
                loss = negative_log_likelihood + sum(model.losses)
                # Scale the loss given the TPUStrategy will reduce sum all
                # gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            variables = model.trainable_variables
            gradients = tape.gradient(scaled_loss, variables)
            optimizer.apply_gradients(zip(gradients, variables))

            # Training metrics are tracked on the unscaled loss.
            metrics['train/ece'].add_batch(tf.nn.softmax(logits), label=targets)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(targets, logits)

        num_steps = tf.cast(steps_per_epoch, tf.int32)
        for _ in tf.range(num_steps):
            strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_split, dataset_name, num_steps):
        """Evaluates `num_steps` batches from `iterator` and updates metrics.

        Args:
          iterator: Iterator whose elements are dicts holding 'features' and
            'labels' (and 'is_in_distribution' for OOD datasets).
          dataset_split: Metric-name prefix used for clean data, e.g. 'test'
            or 'validation'.
          dataset_name: Either 'clean', a name starting with 'ood/', or any
            other name, which is treated as a corruption dataset and used as
            the key suffix into `corrupt_metrics`.
          num_steps: Number of batches to consume from `iterator`.
        """
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            num_samples = FLAGS.num_dropout_samples

            # Run the model `num_samples` times on the same batch.
            # Stacked logits dimension is (num_samples, batch_size, num_classes).
            logits_list = tf.stack(
                [model(images, training=False) for _ in range(num_samples)],
                axis=0)
            probs = tf.reduce_mean(tf.nn.softmax(logits_list), axis=0)

            labels_broadcasted = tf.broadcast_to(
                labels, [num_samples, tf.shape(labels)[0]])
            log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
                labels_broadcasted, logits_list, from_logits=True)
            # Negative log of the sample-averaged likelihood:
            # -log(1/S * sum_s p_s) = -logsumexp_s(log p_s) + log(S).
            negative_log_likelihood = tf.reduce_mean(
                -tf.reduce_logsumexp(log_likelihoods, axis=[0]) +
                tf.math.log(float(num_samples)))

            if dataset_name == 'clean':
                metrics[
                    f'{dataset_split}/negative_log_likelihood'].update_state(
                        negative_log_likelihood)
                metrics[f'{dataset_split}/accuracy'].update_state(
                    labels, probs)
                metrics[f'{dataset_split}/ece'].add_batch(probs, label=labels)
            elif dataset_name.startswith('ood/'):
                ood_labels = 1 - inputs['is_in_distribution']
                if FLAGS.dempster_shafer_ood:
                    # Score the logits averaged over all samples, not just the
                    # last forward pass, so the score reflects every sample
                    # like the other metrics computed here.
                    ood_scores = ood_utils.DempsterShaferUncertainty(
                        tf.reduce_mean(logits_list, axis=0))
                else:
                    ood_scores = 1 - tf.reduce_max(probs, axis=-1)

                # OOD metrics are looked up by substring match on their name.
                for name, metric in metrics.items():
                    if dataset_name in name:
                        metric.update_state(ood_labels, ood_scores)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch(
                    probs, label=labels)

        for _ in tf.range(tf.cast(num_steps, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

    train_iterator = iter(train_dataset)
    start_time = time.time()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        train_step(train_iterator)

        current_step = (epoch + 1) * steps_per_epoch
        max_steps = steps_per_epoch * FLAGS.train_epochs
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                       steps_per_sec, eta_seconds / 60, time_elapsed / 60))
        logging.info(message)

        if validation_dataset:
            validation_iterator = iter(validation_dataset)
            test_step(validation_iterator, 'validation', 'clean',
                      steps_per_validation)
        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            logging.info('Starting to run eval at epoch: %s', epoch)
            test_start_time = time.time()
            test_step(test_iterator, 'test', dataset_name, steps_per_eval)
            ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
            metrics['test/ms_per_example'].update_state(ms_per_example)

            logging.info('Done with testing on %s', dataset_name)

        if FLAGS.eval_on_ood:
            for ood_dataset_name, ood_dataset in ood_datasets.items():
                ood_iterator = iter(ood_dataset)
                logging.info('Calculating OOD on dataset %s', ood_dataset_name)
                logging.info('Running OOD eval at epoch: %s', epoch)
                test_step(ood_iterator, 'test', ood_dataset_name,
                          steps_per_ood[ood_dataset_name])

                logging.info('Done with OOD eval on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if FLAGS.corruptions_interval > 0:
            for metric in corrupt_metrics.values():
                metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)
    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate':
            FLAGS.base_learning_rate,
            'one_minus_momentum':
            FLAGS.one_minus_momentum,
            'l2':
            FLAGS.l2,
            'dropout_rate':
            FLAGS.dropout_rate,
            'num_dropout_samples':
            FLAGS.num_dropout_samples,
            'num_dropout_samples_training':
            FLAGS.num_dropout_samples_training,
        })
# Example no. 6
# 0
def main(argv):
    """Evaluates a Wide ResNet 28-10 ensemble built from saved checkpoints.

    Every checkpoint returned by `parse_checkpoint_dir(FLAGS.checkpoint_dir)`
    is restored in turn and its logits on each test dataset (clean, optional
    OOD datasets, and every corruption type at severities 1-5) are cached as
    `.npy` files in `FLAGS.output_dir`. Existing cache files are reused. The
    cached logits are then combined to compute ensemble and per-member
    metrics, which are logged.

    Args:
      argv: Unused positional command-line arguments.

    Raises:
      ValueError: If `FLAGS.use_gpu` is false, if more than one core is
        requested, or if no checkpoint is found.
    """
    del argv  # unused arg
    if not FLAGS.use_gpu:
        raise ValueError('Only GPU is currently supported.')
    if FLAGS.num_cores > 1:
        raise ValueError('Only a single accelerator is currently supported.')
    tf.random.set_seed(FLAGS.seed)
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('output_dir=%s', FLAGS.output_dir)

    def logits_path(dataset_name, member):
        """Returns the cache file holding one member's logits on a dataset."""
        # ood dataset name has '/', which must not end up in the file name.
        basename = '{dataset}_{member}.npy'.format(
            dataset=dataset_name.replace('/', '_'), member=member)
        return os.path.join(FLAGS.output_dir, basename)

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    data_dir = FLAGS.data_dir
    dataset_builder = ub.datasets.get(
        FLAGS.dataset,
        download_data=FLAGS.download_data,
        data_dir=data_dir,
        split=tfds.Split.TEST,
        drop_remainder=FLAGS.drop_remainder_for_eval)
    dataset = dataset_builder.load(batch_size=batch_size)
    test_datasets = {'clean': dataset}
    if FLAGS.eval_on_ood:
        ood_dataset_names = FLAGS.ood_dataset
        ood_datasets, steps_per_ood = ood_utils.load_ood_datasets(
            ood_dataset_names,
            dataset_builder,
            1. - FLAGS.train_proportion,
            batch_size,
            drop_remainder=FLAGS.drop_remainder_for_eval)
        test_datasets.update(ood_datasets)
    if FLAGS.dataset == 'cifar100':
        # The corrupted CIFAR-100 data is read from a separate location.
        data_dir = FLAGS.cifar100_c_path
    corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
    for corruption_type in corruption_types:
        for severity in range(1, 6):
            dataset = ub.datasets.get(
                f'{FLAGS.dataset}_corrupted',
                corruption_type=corruption_type,
                download_data=FLAGS.download_data,
                data_dir=data_dir,
                severity=severity,
                split=tfds.Split.TEST,
                drop_remainder=FLAGS.drop_remainder_for_eval).load(
                    batch_size=batch_size)
            test_datasets[f'{corruption_type}_{severity}'] = dataset

    model = ub.models.wide_resnet(input_shape=ds_info.features['image'].shape,
                                  depth=28,
                                  width_multiplier=10,
                                  num_classes=num_classes,
                                  l2=0.,
                                  version=2)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Search for checkpoints from their index file; then remove the index suffix.
    ensemble_filenames = parse_checkpoint_dir(FLAGS.checkpoint_dir)
    ensemble_size = len(ensemble_filenames)
    if ensemble_size == 0:
        # Fail early: an empty ensemble would otherwise only surface later as
        # an obscure slicing error on an empty logits tensor.
        raise ValueError(
            'No checkpoints found in {}.'.format(FLAGS.checkpoint_dir))
    logging.info('Ensemble size: %s', ensemble_size)
    logging.info('Ensemble number of weights: %s',
                 ensemble_size * model.count_params())
    logging.info('Ensemble filenames: %s', ensemble_filenames)
    checkpoint = tf.train.Checkpoint(model=model)

    # Write model predictions to files.
    num_datasets = len(test_datasets)
    for m, ensemble_filename in enumerate(ensemble_filenames):
        checkpoint.restore(ensemble_filename)
        for n, (name, test_dataset) in enumerate(test_datasets.items()):
            filename = logits_path(name, m)
            # Cached predictions from a previous run are reused as-is.
            if not tf.io.gfile.exists(filename):
                logits = []
                test_iterator = iter(test_dataset)
                steps = steps_per_eval if 'ood/' not in name else steps_per_ood[
                    name]
                for _ in range(steps):
                    features = next(test_iterator)['features']  # pytype: disable=unsupported-operands
                    logits.append(model(features, training=False))

                logits = tf.concat(logits, axis=0)
                # np.save emits bytes, so open the file in binary mode.
                with tf.io.gfile.GFile(filename, 'wb') as f:
                    np.save(f, logits.numpy())
            percent = (m * num_datasets +
                       (n + 1)) / (ensemble_size * num_datasets)
            message = (
                '{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size,
                                           n + 1, num_datasets))
            logging.info(message)

    metrics = {
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece':
        rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/diversity': rm.metrics.AveragePairwiseDiversity(),
    }
    if FLAGS.eval_on_ood:
        ood_metrics = ood_utils.create_ood_metrics(ood_dataset_names)
        metrics.update(ood_metrics)
    corrupt_metrics = {}
    for name in test_datasets:
        corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
        corrupt_metrics['test/accuracy_{}'.format(name)] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
        corrupt_metrics['test/ece_{}'.format(name)] = (
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
    for i in range(ensemble_size):
        metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
        metrics['test/accuracy_member_{}'.format(i)] = (
            tf.keras.metrics.SparseCategoricalAccuracy())

    # Evaluate model predictions.
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
        logits_dataset = []
        for m in range(ensemble_size):
            with tf.io.gfile.GFile(logits_path(name, m), 'rb') as f:
                logits_dataset.append(np.load(f))

        # Leading axis indexes the ensemble member.
        logits_dataset = tf.convert_to_tensor(logits_dataset)
        test_iterator = iter(test_dataset)
        steps = steps_per_eval if 'ood/' not in name else steps_per_ood[name]
        for step in range(steps):
            inputs = next(test_iterator)
            labels = inputs['labels']  # pytype: disable=unsupported-operands
            # Cached logits are stored in iteration order, so the batch is
            # recovered by slicing along the example axis.
            logits = logits_dataset[:, (step * batch_size):((step + 1) *
                                                            batch_size)]
            labels = tf.cast(labels, tf.int32)
            negative_log_likelihood_metric = rm.metrics.EnsembleCrossEntropy()
            negative_log_likelihood_metric.add_batch(logits, labels=labels)
            negative_log_likelihood = list(
                negative_log_likelihood_metric.result().values())[0]
            per_probs = tf.nn.softmax(logits)
            probs = tf.reduce_mean(per_probs, axis=0)
            logits_mean = tf.reduce_mean(logits, axis=0)

            if name == 'clean':
                gibbs_ce_metric = rm.metrics.GibbsCrossEntropy()
                gibbs_ce_metric.add_batch(logits, labels=labels)
                gibbs_ce = list(gibbs_ce_metric.result().values())[0]
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].add_batch(probs, label=labels)

                for i in range(ensemble_size):
                    member_probs = per_probs[i]
                    member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                        labels, member_probs)
                    metrics['test/nll_member_{}'.format(i)].update_state(
                        member_loss)
                    metrics['test/accuracy_member_{}'.format(i)].update_state(
                        labels, member_probs)
                metrics['test/diversity'].add_batch(per_probs)

            elif name.startswith('ood/'):
                ood_labels = 1 - inputs['is_in_distribution']  # pytype: disable=unsupported-operands
                if FLAGS.dempster_shafer_ood:
                    ood_scores = ood_utils.DempsterShaferUncertainty(
                        logits_mean)
                else:
                    ood_scores = 1 - tf.reduce_max(probs, axis=-1)

                # OOD metrics are looked up by substring match on their name.
                for metric_name, metric in metrics.items():
                    if name in metric_name:
                        metric.update_state(ood_labels, ood_scores)
            else:
                corrupt_metrics['test/nll_{}'.format(name)].update_state(
                    negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
                    labels, probs)
                corrupt_metrics['test/ece_{}'.format(name)].add_batch(
                    probs, label=labels)

        message = (
            '{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
                (n + 1) / num_datasets, n + 1, num_datasets))
        logging.info(message)

    corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                      corruption_types)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    # Results from Robustness Metrics themselves return a dict, so flatten them.
    total_results = utils.flatten_dictionary(total_results)
    logging.info('Metrics: %s', total_results)
def main(argv):
    """Evaluates an ensemble of SNGP Wide ResNet 28-10 checkpoints.

    Checkpoint index files are globbed below `FLAGS.checkpoint_dir`, filtered
    to paths naming the requested `use_gp_layer` / `use_spec_norm` setting,
    and `FLAGS.ensemble_size` of them are drawn with replacement using a
    NumPy RNG seeded with `FLAGS.seed`. Each selected checkpoint is restored
    in turn and its logits on every test dataset (clean, optional OOD
    datasets, and every corruption type at severities 1-5) are cached as
    `.npy` files in `FLAGS.output_dir`. Existing cache files are reused. The
    cached logits are then combined into ensemble metrics, which are logged.

    Args:
      argv: Unused positional command-line arguments.

    Raises:
      ValueError: If `FLAGS.use_gpu` is false, if more than one core is
        requested, or if no checkpoint matches the requested architecture.
    """
    del argv  # unused arg
    if not FLAGS.use_gpu:
        raise ValueError('Only GPU is currently supported.')
    if FLAGS.num_cores > 1:
        raise ValueError('Only a single accelerator is currently supported.')
    tf.random.set_seed(FLAGS.seed)
    tf.io.gfile.makedirs(FLAGS.output_dir)

    def logits_path(dataset_name, member):
        """Returns the cache file holding one member's logits on a dataset."""
        # ood dataset name has '/', which must not end up in the file name.
        basename = '{dataset}_{member}.npy'.format(
            dataset=dataset_name.replace('/', '_'), member=member)
        return os.path.join(FLAGS.output_dir, basename)

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = FLAGS.total_batch_size
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    data_dir = FLAGS.data_dir
    dataset_builder = ub.datasets.get(
        FLAGS.dataset,
        data_dir=data_dir,
        download_data=FLAGS.download_data,
        split=tfds.Split.TEST,
        drop_remainder=FLAGS.drop_remainder_for_eval)
    dataset = dataset_builder.load(batch_size=batch_size)
    test_datasets = {'clean': dataset}
    if FLAGS.eval_on_ood:
        ood_dataset_names = FLAGS.ood_dataset
        ood_datasets, steps_per_ood = ood_utils.load_ood_datasets(
            ood_dataset_names,
            dataset_builder,
            1. - FLAGS.train_proportion,
            batch_size,
            drop_remainder=FLAGS.drop_remainder_for_eval)
        test_datasets.update(ood_datasets)
    if FLAGS.dataset == 'cifar100':
        # The corrupted CIFAR-100 data is read from a separate location.
        data_dir = FLAGS.cifar100_c_path
    corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
    for corruption_type in corruption_types:
        for severity in range(1, 6):
            dataset = ub.datasets.get(
                f'{FLAGS.dataset}_corrupted',
                corruption_type=corruption_type,
                data_dir=data_dir,
                severity=severity,
                split=tfds.Split.TEST,
                drop_remainder=FLAGS.drop_remainder_for_eval).load(
                    batch_size=batch_size)
            test_datasets[f'{corruption_type}_{severity}'] = dataset

    model = ub.models.wide_resnet_sngp(
        input_shape=ds_info.features['image'].shape,
        batch_size=FLAGS.total_batch_size // FLAGS.num_cores,
        depth=28,
        width_multiplier=10,
        num_classes=num_classes,
        l2=0.,
        use_mc_dropout=FLAGS.use_mc_dropout,
        use_filterwise_dropout=FLAGS.use_filterwise_dropout,
        dropout_rate=FLAGS.dropout_rate,
        use_gp_layer=FLAGS.use_gp_layer,
        gp_input_dim=FLAGS.gp_input_dim,
        gp_hidden_dim=FLAGS.gp_hidden_dim,
        gp_scale=FLAGS.gp_scale,
        gp_bias=FLAGS.gp_bias,
        gp_input_normalization=FLAGS.gp_input_normalization,
        gp_random_feature_type=FLAGS.gp_random_feature_type,
        gp_cov_discount_factor=FLAGS.gp_cov_discount_factor,
        gp_cov_ridge_penalty=FLAGS.gp_cov_ridge_penalty,
        use_spec_norm=FLAGS.use_spec_norm,
        spec_norm_iteration=FLAGS.spec_norm_iteration,
        spec_norm_bound=FLAGS.spec_norm_bound)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Search for checkpoints from their index file; then remove the index suffix.
    index_filenames = tf.io.gfile.glob(
        os.path.join(FLAGS.checkpoint_dir, '**/*.index'))
    # Only apply ensemble on the models with the same model architecture
    matching_filenames = [
        filename for filename in index_filenames
        if f'use_gp_layer:{FLAGS.use_gp_layer}' in filename
        and f'use_spec_norm:{FLAGS.use_spec_norm}' in filename
    ]
    if not matching_filenames:
        # np.random.choice would also raise ValueError on an empty pool, but
        # with a message that does not point at the checkpoint directory.
        raise ValueError(
            'No checkpoint under {} matches use_gp_layer:{} and '
            'use_spec_norm:{}.'.format(FLAGS.checkpoint_dir,
                                       FLAGS.use_gp_layer,
                                       FLAGS.use_spec_norm))
    np.random.seed(FLAGS.seed)
    # Sampling is with replacement, so a checkpoint may appear more than once.
    ensemble_filenames = np.random.choice(matching_filenames,
                                          FLAGS.ensemble_size,
                                          replace=True)

    # Strip the trailing '.index' (6 characters) to get the checkpoint prefix.
    ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
    ensemble_size = len(ensemble_filenames)
    logging.info('Ensemble size: %s', ensemble_size)
    logging.info('Ensemble filenames: %s', ensemble_filenames)
    logging.info('Ensemble number of weights: %s',
                 ensemble_size * model.count_params())
    checkpoint = tf.train.Checkpoint(model=model)

    # Write model predictions to files.
    num_datasets = len(test_datasets)
    for m, ensemble_filename in enumerate(ensemble_filenames):
        checkpoint.restore(ensemble_filename)
        for n, (name, test_dataset) in enumerate(test_datasets.items()):
            filename = logits_path(name, m)
            # Cached predictions from a previous run are reused as-is.
            if not tf.io.gfile.exists(filename):
                logits = []
                test_iterator = iter(test_dataset)
                steps = steps_per_eval if 'ood/' not in name else steps_per_ood[
                    name]
                for _ in range(steps):
                    features = next(test_iterator)['features']  # pytype: disable=unsupported-operands
                    logits_member = model(features, training=False)
                    if isinstance(logits_member, (list, tuple)):
                        # If model returns a tuple of (logits, covmat), extract both
                        logits_member, covmat_member = logits_member
                        logits_member = ed.layers.utils.mean_field_logits(
                            logits_member, covmat_member,
                            FLAGS.gp_mean_field_factor_ensemble)
                    logits.append(logits_member)

                logits = tf.concat(logits, axis=0)
                # np.save emits bytes, so open the file in binary mode.
                with tf.io.gfile.GFile(filename, 'wb') as f:
                    np.save(f, logits.numpy())
            percent = (m * num_datasets +
                       (n + 1)) / (ensemble_size * num_datasets)
            message = (
                '{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size,
                                           n + 1, num_datasets))
            logging.info(message)

    metrics = {
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece':
        rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
    }
    if FLAGS.eval_on_ood:
        ood_metrics = ood_utils.create_ood_metrics(ood_dataset_names)
        metrics.update(ood_metrics)
    corrupt_metrics = {}
    for name in test_datasets:
        corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
        corrupt_metrics['test/accuracy_{}'.format(name)] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
        corrupt_metrics['test/ece_{}'.format(name)] = (
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

    # Evaluate model predictions.
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
        logits_dataset = []
        for m in range(ensemble_size):
            with tf.io.gfile.GFile(logits_path(name, m), 'rb') as f:
                logits_dataset.append(np.load(f))

        # Leading axis indexes the ensemble member.
        logits_dataset = tf.convert_to_tensor(logits_dataset)
        test_iterator = iter(test_dataset)
        steps = steps_per_eval if 'ood/' not in name else steps_per_ood[name]
        for step in range(steps):
            inputs = next(test_iterator)
            labels = inputs['labels']  # pytype: disable=unsupported-operands
            # Cached logits are stored in iteration order, so the batch is
            # recovered by slicing along the example axis.
            logits = logits_dataset[:, (step * batch_size):((step + 1) *
                                                            batch_size)]
            labels = tf.cast(labels, tf.int32)
            negative_log_likelihood_metric = rm.metrics.EnsembleCrossEntropy()
            negative_log_likelihood_metric.add_batch(logits, labels=labels)
            negative_log_likelihood = list(
                negative_log_likelihood_metric.result().values())[0]
            per_probs = tf.nn.softmax(logits)
            probs = tf.reduce_mean(per_probs, axis=0)
            logits_mean = tf.reduce_mean(logits, axis=0)
            if name == 'clean':
                gibbs_ce_metric = rm.metrics.GibbsCrossEntropy()
                gibbs_ce_metric.add_batch(logits, labels=labels)
                gibbs_ce = list(gibbs_ce_metric.result().values())[0]
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].add_batch(probs, label=labels)
            elif name.startswith('ood/'):
                ood_labels = 1 - inputs['is_in_distribution']  # pytype: disable=unsupported-operands
                if FLAGS.dempster_shafer_ood:
                    ood_scores = ood_utils.DempsterShaferUncertainty(
                        logits_mean)
                else:
                    ood_scores = 1 - tf.reduce_max(probs, axis=-1)

                # OOD metrics are looked up by substring match on their name.
                for metric_name, metric in metrics.items():
                    if name in metric_name:
                        metric.update_state(ood_labels, ood_scores)
            else:
                corrupt_metrics['test/nll_{}'.format(name)].update_state(
                    negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
                    labels, probs)
                corrupt_metrics['test/ece_{}'.format(name)].add_batch(
                    probs, label=labels)

        message = (
            '{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
                (n + 1) / num_datasets, n + 1, num_datasets))
        logging.info(message)

    corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                      corruption_types)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    # Metrics from Robustness Metrics (like ECE) will return a dict with a
    # single key/value, instead of a scalar.
    total_results = {
        k: (list(v.values())[0] if isinstance(v, dict) else v)
        for k, v in total_results.items()
    }
    logging.info('Metrics: %s', total_results)