Example #1
def main(argv):
    del argv  # unused arg
    if not FLAGS.use_gpu:
        raise ValueError('Only GPU is currently supported.')
    if FLAGS.num_cores > 1:
        raise ValueError('Only a single accelerator is currently supported.')
    tf.enable_v2_behavior()
    tf.random.set_seed(FLAGS.seed)
    tf.io.gfile.makedirs(FLAGS.output_dir)

    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

    dataset_test = utils.ImageNetInput(is_training=False,
                                       data_dir=FLAGS.data_dir,
                                       batch_size=FLAGS.per_core_batch_size,
                                       use_bfloat16=False).input_fn()
    test_datasets = {'clean': dataset_test}
    corruption_types, max_intensity = utils.load_corrupted_test_info()
    for name in corruption_types:
        for intensity in range(1, max_intensity + 1):
            dataset_name = '{0}_{1}'.format(name, intensity)
            test_datasets[dataset_name] = utils.load_corrupted_test_dataset(
                name=name,
                intensity=intensity,
                batch_size=FLAGS.per_core_batch_size,
                drop_remainder=True,
                use_bfloat16=False)

    model = deterministic_model.resnet50(input_shape=(224, 224, 3),
                                         num_classes=NUM_CLASSES)

    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Search for checkpoints from their index file; then remove the index suffix.
    ensemble_filenames = tf.io.gfile.glob(
        os.path.join(FLAGS.checkpoint_dir, '**/*.index'))
    ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
    ensemble_size = len(ensemble_filenames)
    logging.info('Ensemble size: %s', ensemble_size)
    logging.info('Ensemble number of weights: %s',
                 ensemble_size * model.count_params())
    logging.info('Ensemble filenames: %s', str(ensemble_filenames))
    checkpoint = tf.train.Checkpoint(model=model)

    # Write model predictions to files.
    num_datasets = len(test_datasets)
    for m, ensemble_filename in enumerate(ensemble_filenames):
        checkpoint.restore(ensemble_filename)
        for n, (name, test_dataset) in enumerate(test_datasets.items()):
            filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
            filename = os.path.join(FLAGS.output_dir, filename)
            if not tf.io.gfile.exists(filename):
                logits = []
                test_iterator = iter(test_dataset)
                for _ in range(steps_per_eval):
                    features, _ = next(test_iterator)  # pytype: disable=attribute-error
                    logits.append(model(features, training=False))

                logits = tf.concat(logits, axis=0)
                with tf.io.gfile.GFile(filename, 'w') as f:
                    np.save(f, logits.numpy())
            percent = (m * num_datasets +
                       (n + 1)) / (ensemble_size * num_datasets)
            message = (
                '{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size,
                                           n + 1, num_datasets))
            logging.info(message)

    metrics = {
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': ed.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
    }
    corrupt_metrics = {}
    for name in test_datasets:
        corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
        corrupt_metrics['test/accuracy_{}'.format(name)] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
        corrupt_metrics['test/ece_{}'.format(name)] = (
            ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

    # Evaluate model predictions.
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
        logits_dataset = []
        for m in range(ensemble_size):
            filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
            filename = os.path.join(FLAGS.output_dir, filename)
            with tf.io.gfile.GFile(filename, 'rb') as f:
                logits_dataset.append(np.load(f))

        logits_dataset = tf.convert_to_tensor(logits_dataset)
        test_iterator = iter(test_dataset)
        for step in range(steps_per_eval):
            _, labels = next(test_iterator)  # pytype: disable=attribute-error
            logits = logits_dataset[:, (step * batch_size):((step + 1) *
                                                            batch_size)]
            labels = tf.cast(tf.reshape(labels, [-1]), tf.int32)
            negative_log_likelihood = tf.reduce_mean(
                ensemble_negative_log_likelihood(labels, logits))
            per_probs = tf.nn.softmax(logits)
            probs = tf.reduce_mean(per_probs, axis=0)
            if name == 'clean':
                gibbs_ce = tf.reduce_mean(gibbs_cross_entropy(labels, logits))
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].update_state(labels, probs)
            else:
                corrupt_metrics['test/nll_{}'.format(name)].update_state(
                    negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
                    labels, probs)
                corrupt_metrics['test/ece_{}'.format(name)].update_state(
                    labels, probs)

        message = (
            '{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
                (n + 1) / num_datasets, n + 1, num_datasets))
        logging.info(message)

    corrupt_results = utils.aggregate_corrupt_metrics(
        corrupt_metrics, corruption_types, max_intensity,
        FLAGS.alexnet_errors_path)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    logging.info('Metrics: %s', total_results)
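
Example #1 above (and Example #3 below) call ensemble_negative_log_likelihood and gibbs_cross_entropy, which are not shown in the snippets. The following is a minimal sketch of the two helpers, assuming logits stacks the ensemble members along the first axis with shape [ensemble_size, batch_size, num_classes] and labels holds sparse integer class ids; it is a plausible reconstruction, not the original definitions.

import tensorflow.compat.v2 as tf  # matches the tf.enable_v2_behavior() usage above


def ensemble_negative_log_likelihood(labels, logits):
    """NLL of the ensemble's mixture predictive distribution (sketch)."""
    labels = tf.cast(labels, tf.int32)
    logits = tf.convert_to_tensor(logits)
    ensemble_size = float(logits.shape[0])
    # Per-member cross entropy, shape [ensemble_size, batch_size].
    ce = tf.keras.losses.sparse_categorical_crossentropy(
        tf.broadcast_to(labels[tf.newaxis, ...], tf.shape(logits)[:-1]),
        logits,
        from_logits=True)
    # -log((1/M) * sum_m exp(-ce_m)) = -logsumexp_m(-ce_m) + log(M).
    return -tf.reduce_logsumexp(-ce, axis=0) + tf.math.log(ensemble_size)


def gibbs_cross_entropy(labels, logits):
    """Average of the individual members' cross entropies (sketch)."""
    labels = tf.cast(labels, tf.int32)
    logits = tf.convert_to_tensor(logits)
    ce = tf.keras.losses.sparse_categorical_crossentropy(
        tf.broadcast_to(labels[tf.newaxis, ...], tf.shape(logits)[:-1]),
        logits,
        from_logits=True)
    return tf.reduce_mean(ce, axis=0)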
Example #2
def main(argv):
  del argv  # unused arg
  tf.enable_v2_behavior()
  tf.random.set_seed(FLAGS.seed)

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_epoch = APPROX_IMAGENET_TRAIN_IMAGES // batch_size
  steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

  logging.info('Saving checkpoints at %s', FLAGS.output_dir)

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

  imagenet_train = utils.ImageNetInput(
      is_training=True,
      data_dir=FLAGS.data_dir,
      batch_size=batch_size,
      use_bfloat16=not FLAGS.use_gpu,
      drop_remainder=True)
  imagenet_eval = utils.ImageNetInput(
      is_training=False,
      data_dir=FLAGS.data_dir,
      batch_size=batch_size,
      use_bfloat16=not FLAGS.use_gpu,
      drop_remainder=True)
  train_dataset = strategy.experimental_distribute_dataset(
      imagenet_train.input_fn())
  test_dataset = strategy.experimental_distribute_dataset(
      imagenet_eval.input_fn())

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  with strategy.scope():
    logging.info('Building Keras ResNet-50 model')
    model = deterministic_model.resnet50(input_shape=(224, 224, 3),
                                         num_classes=NUM_CLASSES)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Scale learning rate and decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate * batch_size / 256
    learning_rate = utils.LearningRateSchedule(steps_per_epoch,
                                               base_lr,
                                               FLAGS.train_epochs,
                                               _LR_SCHEDULE)
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                        momentum=0.9,
                                        nesterov=True)
    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss': tf.keras.metrics.Mean(),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
    }
    logging.info('Finished building Keras ResNet-50 model')

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  @tf.function
  def train_step(iterator):
    """Training StepFn."""
    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)

        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels,
                                                            logits,
                                                            from_logits=True))
        filtered_variables = []
        for var in model.trainable_variables:
          # Apply L2 to variables whose names contain 'kernel' or 'bias'
          # (conv/dense weights and biases); BN parameters (gamma/beta) are
          # excluded, so pay attention to the layer naming scheme.
          if 'kernel' in var.name or 'bias' in var.name:
            filtered_variables.append(tf.reshape(var, (-1,)))

        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
        # Scale the loss, since TPUStrategy sum-reduces gradients across replicas.
        loss = negative_log_likelihood + l2_loss
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, logits)

    strategy.experimental_run_v2(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator):
    """Evaluation StepFn."""
    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)

      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(labels,
                                                          logits,
                                                          from_logits=True))
      metrics['test/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['test/accuracy'].update_state(labels, logits)

    strategy.experimental_run_v2(step_fn, args=(next(iterator),))

  train_iterator = iter(train_dataset)
  start_time = time.time()
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    for step in range(steps_per_epoch):
      train_step(train_iterator)

      current_step = epoch * steps_per_epoch + (step + 1)
      max_steps = steps_per_epoch * FLAGS.train_epochs
      time_elapsed = time.time() - start_time
      steps_per_sec = float(current_step) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps,
                     epoch + 1,
                     FLAGS.train_epochs,
                     steps_per_sec,
                     eta_seconds / 60,
                     time_elapsed / 60))
      if step % 20 == 0:
        logging.info(message)

    test_iterator = iter(test_dataset)
    for step in range(steps_per_eval):
      if step % 20 == 0:
        logging.info('Starting to run eval step %s of epoch: %s', step, epoch)
      test_step(test_iterator)

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    with summary_writer.as_default():
      for name, metric in metrics.items():
        tf.summary.scalar(name, metric.result(), step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (epoch + 1) % 20 == 0:
      checkpoint_name = checkpoint.save(os.path.join(
          FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)
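
All three examples also rely on module-level flags, ImageNet constants, and the _LR_SCHEDULE step-decay table, none of which appear in the snippets. Below is a minimal sketch of what they assume: the flag names are taken from the usage above, while the default values and the schedule entries are illustrative guesses rather than the originals.

from absl import app
from absl import flags

flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('per_core_batch_size', 128, 'Batch size per core.')
flags.DEFINE_integer('num_cores', 8, 'Number of accelerator cores.')
flags.DEFINE_bool('use_gpu', False, 'Run on GPU instead of TPU.')
flags.DEFINE_bool('use_bfloat16', True, 'Use mixed bfloat16 precision.')
flags.DEFINE_string('tpu', None, 'TPU address, or None for a local TPU.')
flags.DEFINE_string('data_dir', None, 'Path to the ImageNet TFRecords.')
flags.DEFINE_string('output_dir', '/tmp/imagenet', 'Output/checkpoint directory.')
flags.DEFINE_string('checkpoint_dir', None, 'Directory holding ensemble checkpoints.')
flags.DEFINE_string('alexnet_errors_path', None, 'Path to AlexNet corruption errors.')
flags.DEFINE_integer('num_bins', 15, 'Number of bins for expected calibration error.')
flags.DEFINE_float('base_learning_rate', 0.1, 'Base learning rate at batch size 256.')
flags.DEFINE_float('l2', 1e-4, 'L2 regularization coefficient.')
flags.DEFINE_integer('train_epochs', 90, 'Number of training epochs.')
FLAGS = flags.FLAGS

# Standard ILSVRC-2012 split sizes.
APPROX_IMAGENET_TRAIN_IMAGES = 1281167
IMAGENET_VALIDATION_IMAGES = 50000
NUM_CLASSES = 1000

# (multiplier, start_epoch) pairs for the step-decay learning-rate schedule.
# These values are a typical ResNet-50 recipe and only an assumption here.
_LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]


if __name__ == '__main__':
  app.run(main)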
Example #3
def main(argv):
    del argv  # unused arg
    if FLAGS.num_cores > 1:
        raise ValueError('Only a single accelerator is currently supported.')
    tf.enable_v2_behavior()
    tf.random.set_seed(FLAGS.seed)

    dataset_test = utils.ImageNetInput(is_training=False,
                                       data_dir=FLAGS.data_dir,
                                       batch_size=FLAGS.per_core_batch_size,
                                       use_bfloat16=False).input_fn()
    test_datasets = {'clean': dataset_test}

    model = deterministic_model.resnet50(input_shape=(224, 224, 3),
                                         num_classes=NUM_CLASSES)

    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Search for checkpoints from their index file; then remove the index suffix.
    ensemble_filenames = tf.io.gfile.glob(
        os.path.join(FLAGS.output_dir, '**/*.index'))
    ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
    ensemble_size = len(ensemble_filenames)
    logging.info('Ensemble size: %s', ensemble_size)
    logging.info('Ensemble number of weights: %s',
                 ensemble_size * model.count_params())
    logging.info('Ensemble filenames: %s', str(ensemble_filenames))
    checkpoint = tf.train.Checkpoint(model=model)

    # Collect the logits output by each ensemble member on every test data
    # point, along with the corresponding labels.

    logits_test = {'clean': []}
    labels_test = {'clean': []}
    corruption_types, max_intensity = utils.load_corrupted_test_info()
    for name in corruption_types:
        for intensity in range(1, max_intensity + 1):
            dataset_name = '{0}_{1}'.format(name, intensity)
            logits_test[dataset_name] = []
            labels_test[dataset_name] = []

            test_datasets[dataset_name] = utils.load_corrupted_test_dataset(
                name=name,
                intensity=intensity,
                batch_size=FLAGS.per_core_batch_size,
                drop_remainder=True,
                use_bfloat16=False)

    for m, ensemble_filename in enumerate(ensemble_filenames):
        checkpoint.restore(ensemble_filename)
        logging.info('Working on test data for ensemble member %s', m)
        for name, test_dataset in test_datasets.items():
            logits = []
            for features, labels in test_dataset:
                logits.append(model(features, training=False))
                if m == 0:
                    labels_test[name].append(labels)

            logits = tf.concat(logits, axis=0)
            logits_test[name].append(logits)
            if m == 0:
                labels_test[name] = tf.concat(labels_test[name], axis=0)
            logging.info('Finished testing on %s', name)

    metrics = {
        'test/ece': ed.metrics.ExpectedCalibrationError(
            num_classes=NUM_CLASSES, num_bins=15),
    }
    corrupt_metrics = {}
    for name in test_datasets:
        corrupt_metrics['test/ece_{}'.format(name)] = (
            ed.metrics.ExpectedCalibrationError(num_classes=NUM_CLASSES,
                                                num_bins=15))
        corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
        corrupt_metrics['test/accuracy_{}'.format(name)] = (
            tf.keras.metrics.Mean())

    for name, test_dataset in test_datasets.items():
        labels = labels_test[name]
        logits = logits_test[name]
        nll_test = ensemble_negative_log_likelihood(labels, logits)
        gibbs_ce_test = gibbs_cross_entropy(labels_test[name],
                                            logits_test[name])
        labels = tf.cast(labels, tf.int32)
        logits = tf.convert_to_tensor(logits)
        per_probs = tf.nn.softmax(logits)
        probs = tf.reduce_mean(per_probs, axis=0)
        accuracy = tf.keras.metrics.sparse_categorical_accuracy(labels, probs)
        if name == 'clean':
            metrics['test/negative_log_likelihood'] = tf.reduce_mean(nll_test)
            metrics['test/gibbs_cross_entropy'] = tf.reduce_mean(gibbs_ce_test)
            metrics['test/accuracy'] = tf.reduce_mean(accuracy)
            metrics['test/ece'].update_state(labels, probs)
        else:
            corrupt_metrics['test/nll_{}'.format(name)].update_state(
                tf.reduce_mean(nll_test))
            corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
                tf.reduce_mean(accuracy))
            corrupt_metrics['test/ece_{}'.format(name)].update_state(
                labels, probs)

    corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                      corruption_types,
                                                      max_intensity)
    metrics['test/ece'] = metrics['test/ece'].result()
    total_results = {name: metric for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    logging.info('Metrics: %s', total_results)
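
Examples #1 and #3 also delegate to utils.aggregate_corrupt_metrics, which is not reproduced here. The sketch below illustrates the aggregation as it is used above, assuming the corrupted metric keys follow the test/<metric>_<corruption>_<intensity> pattern built in the loops and that the helper simply averages results per intensity and overall; the real utility presumably also uses alexnet_errors_path to report mCE, which this sketch omits.

import tensorflow.compat.v2 as tf


def aggregate_corrupt_metrics(corrupt_metrics, corruption_types,
                              max_intensity, alexnet_errors_path=None):
    """Averages corruption metrics per intensity and overall (sketch)."""
    del alexnet_errors_path  # mCE computation omitted in this sketch.
    results = {}
    for metric in ('nll', 'accuracy', 'ece'):
        intensity_means = []
        for intensity in range(1, max_intensity + 1):
            # One value per corruption type at this intensity, e.g.
            # corrupt_metrics['test/nll_gaussian_noise_3'].result().
            values = [
                corrupt_metrics['test/{}_{}_{}'.format(
                    metric, name, intensity)].result()
                for name in corruption_types
            ]
            mean = tf.reduce_mean(values)
            results['test/{}_mean_corrupted_{}'.format(metric, intensity)] = mean
            intensity_means.append(mean)
        results['test/{}_mean_corrupted'.format(metric)] = tf.reduce_mean(
            intensity_means)
    return results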