def testResNetV1(self):
  """Smoke-tests wide_resnet end to end on a tiny synthetic dataset."""
  # NOTE(review): the method name says V1 but `version=2` is built below —
  # confirm which is intended.
  tf.random.set_seed(83922)
  n_examples = 10
  n_per_batch = 5
  image_shape = (32, 32, 1)
  n_classes = 2

  # Random images, with labels drawn from a random linear model on the
  # flattened pixels so the labels are (weakly) predictable from the features.
  features = tf.random.normal((n_examples,) + image_shape)
  coeffs = tf.random.normal([tf.reduce_prod(image_shape), n_classes])
  flat_features = tf.reshape(features, [n_examples, -1])
  labels = tf.random.categorical(tf.matmul(flat_features, coeffs), 1)
  features, labels = self.evaluate([features, labels])

  dataset = (
      tf.data.Dataset.from_tensor_slices((features, labels))
      .repeat()
      .shuffle(n_examples)
      .batch(n_per_batch))

  model = deterministic.wide_resnet(
      input_shape=image_shape,
      depth=10,
      width_multiplier=1,
      num_classes=n_classes,
      l2=0.,
      version=2)
  model.compile(
      'adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
  history = model.fit(
      dataset, steps_per_epoch=n_examples // n_per_batch, epochs=2)
  # Cross-entropy is non-negative, so every recorded loss must be >= 0.
  self.assertAllGreaterEqual(history.history['loss'], 0.)
def main(argv):
  """Writes per-member ensemble predictions to disk, then evaluates them.

  For every checkpoint found under FLAGS.checkpoint_dir, restores the model,
  runs it over the clean test set and every corrupted variant, and caches the
  logits as one .npy file per (dataset, member) pair in FLAGS.output_dir.
  Then it reloads all cached logits and computes ensemble NLL, Gibbs cross
  entropy, accuracy, and ECE (clean set), plus per-corruption metrics.

  Args:
    argv: unused command-line arguments.

  Raises:
    ValueError: if GPU is disabled or more than one core is requested.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.enable_v2_behavior()
  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  ds_info = tfds.builder(FLAGS.dataset).info
  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_eval = ds_info.splits['test'].num_examples // batch_size
  num_classes = ds_info.features['label'].num_classes

  # 'clean' is the uncorrupted test split; corrupted variants are added below
  # keyed as '<corruption>_<intensity>'.
  dataset_input_fn = utils.load_input_fn(
      split=tfds.Split.TEST,
      name=FLAGS.dataset,
      batch_size=FLAGS.per_core_batch_size,
      use_bfloat16=FLAGS.use_bfloat16)
  test_datasets = {'clean': dataset_input_fn()}
  corruption_types, max_intensity = utils.load_corrupted_test_info(
      FLAGS.dataset)
  for name in corruption_types:
    for intensity in range(1, max_intensity + 1):
      dataset_name = '{0}_{1}'.format(name, intensity)
      if FLAGS.dataset == 'cifar10':
        load_c_dataset = utils.load_cifar10_c_input_fn
      else:
        # CIFAR-100-C needs an explicit path to the corrupted data.
        load_c_dataset = functools.partial(
            utils.load_cifar100_c_input_fn, path=FLAGS.cifar100_c_path)
      corrupted_input_fn = load_c_dataset(
          corruption_name=name,
          corruption_intensity=intensity,
          batch_size=FLAGS.per_core_batch_size,
          use_bfloat16=FLAGS.use_bfloat16)
      test_datasets[dataset_name] = corrupted_input_fn()

  model = deterministic.wide_resnet(
      input_shape=ds_info.features['image'].shape,
      depth=28,
      width_multiplier=10,
      num_classes=num_classes,
      l2=0.,
      version=2)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())
  # Search for checkpoints from their index file; then remove the index suffix.
  ensemble_filenames = tf.io.gfile.glob(
      os.path.join(FLAGS.checkpoint_dir, '**/*.index'))
  # Strip the 6-character '.index' suffix to get the checkpoint prefix.
  ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename)
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      # Skip datasets whose predictions were already cached by a prior run.
      if not tf.io.gfile.exists(filename):
        logits = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval):
          features, _ = next(test_iterator)  # pytype: disable=attribute-error
          logits.append(model(features, training=False))
        logits = tf.concat(logits, axis=0)
        # NOTE(review): text mode 'w' with np.save (binary payload) — confirm
        # GFile accepts bytes here; 'wb' would be the conventional mode.
        with tf.io.gfile.GFile(filename, 'w') as f:
          np.save(f, logits.numpy())
      percent = (m * num_datasets + (n + 1)) / (ensemble_size * num_datasets)
      message = (
          '{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
          'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size, n + 1,
                                     num_datasets))
      logging.info(message)

  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  corrupt_metrics = {}
  for name in test_datasets:
    corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
    corrupt_metrics['test/accuracy_{}'.format(name)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
    corrupt_metrics['test/ece_{}'.format(name)] = (
        ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

  # Evaluate model predictions.
  for n, (name, test_dataset) in enumerate(test_datasets.items()):
    # Reload the cached per-member logits: shape becomes
    # [ensemble_size, num_examples, num_classes] after stacking.
    logits_dataset = []
    for m in range(ensemble_size):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      with tf.io.gfile.GFile(filename, 'rb') as f:
        logits_dataset.append(np.load(f))
    logits_dataset = tf.convert_to_tensor(logits_dataset)
    test_iterator = iter(test_dataset)
    for step in range(steps_per_eval):
      _, labels = next(test_iterator)  # pytype: disable=attribute-error
      # Slice out this batch's logits for every ensemble member.
      logits = logits_dataset[:, (step * batch_size):((step + 1) * batch_size)]
      labels = tf.cast(labels, tf.int32)
      negative_log_likelihood = tf.reduce_mean(
          ensemble_negative_log_likelihood(labels, logits))
      # Ensemble prediction: average the members' softmax probabilities.
      per_probs = tf.nn.softmax(logits)
      probs = tf.reduce_mean(per_probs, axis=0)
      if name == 'clean':
        gibbs_ce = tf.reduce_mean(gibbs_cross_entropy(labels, logits))
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
      else:
        corrupt_metrics['test/nll_{}'.format(name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(name)].update_state(
            labels, probs)
    message = (
        '{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
            (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                    corruption_types,
                                                    max_intensity)
  total_results = {name: metric.result() for name, metric in metrics.items()}
  total_results.update(corrupt_results)
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Evaluates an ensemble in memory on clean and corrupted test sets.

  Restores each checkpoint found under FLAGS.output_dir, accumulates its
  logits over every test dataset (clean plus each corruption/intensity pair),
  then computes ensemble NLL, Gibbs cross entropy, accuracy, and ECE.

  Args:
    argv: unused command-line arguments.

  Raises:
    ValueError: if more than one accelerator core is requested.
  """
  del argv  # unused arg
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.enable_v2_behavior()
  tf.random.set_seed(FLAGS.seed)

  dataset_input_fn = utils.load_input_fn(
      tfds.Split.TEST,
      FLAGS.per_core_batch_size,
      name=FLAGS.dataset,
      use_bfloat16=False,
      normalize=True,
      drop_remainder=True,
      proportion=1.0)
  test_datasets = {'clean': dataset_input_fn()}
  ds_info = tfds.builder(FLAGS.dataset).info
  num_classes = ds_info.features['label'].num_classes
  model = deterministic.wide_resnet(
      input_shape=ds_info.features['image'].shape,
      depth=28,
      width_multiplier=10,
      num_classes=num_classes,
      l2=0.,
      version=2)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())
  # Search for checkpoints from their index file; then remove the index suffix.
  ensemble_filenames = tf.io.gfile.glob(
      os.path.join(FLAGS.output_dir, '**/*.index'))
  # Strip the 6-character '.index' suffix to get the checkpoint prefix.
  ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Collect the logits output for each ensemble member and train/test data
  # point. We also collect the labels.
  # TODO(trandustin): Refactor data loader so you can get the full dataset in
  # memory without looping.
  logits_test = {'clean': []}
  labels_test = {'clean': []}
  corruption_types, max_intensity = utils.load_corrupted_test_info(
      FLAGS.dataset)
  for name in corruption_types:
    for intensity in range(1, max_intensity + 1):
      dataset_name = '{0}_{1}'.format(name, intensity)
      logits_test[dataset_name] = []
      labels_test[dataset_name] = []
      if FLAGS.dataset == 'cifar10':
        load_c_dataset = utils.load_cifar10_c_input_fn
      else:
        # CIFAR-100-C needs an explicit path to the corrupted data.
        load_c_dataset = functools.partial(
            utils.load_cifar100_c_input_fn, path=FLAGS.cifar100_c_path)
      corrupted_input_fn = load_c_dataset(
          corruption_name=name,
          corruption_intensity=intensity,
          batch_size=FLAGS.per_core_batch_size,
          use_bfloat16=False)
      test_datasets[dataset_name] = corrupted_input_fn()

  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename)
    logging.info('Working on test data for ensemble member %s', m)
    for name, test_dataset in test_datasets.items():
      logits = []
      for features, labels in test_dataset:
        logits.append(model(features, training=False))
        if m == 0:
          # Labels are identical across members; collect them once.
          labels_test[name].append(labels)
      logits = tf.concat(logits, axis=0)
      logits_test[name].append(logits)
      if m == 0:
        labels_test[name] = tf.concat(labels_test[name], axis=0)
      # NOTE(review): format(name) here is equivalent to str(name).
      logging.info('Finished testing on %s', format(name))

  metrics = {
      'test/ece': ed.metrics.ExpectedCalibrationError(num_classes=num_classes,
                                                      num_bins=15)
  }
  corrupt_metrics = {}
  for name in test_datasets:
    corrupt_metrics['test/ece_{}'.format(
        name)] = ed.metrics.ExpectedCalibrationError(
            num_classes=num_classes, num_bins=15)
    corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
    corrupt_metrics['test/accuracy_{}'.format(
        name)] = tf.keras.metrics.Mean()

  for name, test_dataset in test_datasets.items():
    labels = labels_test[name]
    logits = logits_test[name]
    nll_test = ensemble_negative_log_likelihood(labels, logits)
    # NOTE(review): gibbs_ce_test is computed for every dataset but only used
    # in the 'clean' branch below.
    gibbs_ce_test = gibbs_cross_entropy(labels_test[name], logits_test[name])
    labels = tf.cast(labels, tf.int32)
    logits = tf.convert_to_tensor(logits)
    # Ensemble prediction: average the members' softmax probabilities.
    per_probs = tf.nn.softmax(logits)
    probs = tf.reduce_mean(per_probs, axis=0)
    accuracy = tf.keras.metrics.sparse_categorical_accuracy(labels, probs)
    if name == 'clean':
      metrics['test/negative_log_likelihood'] = tf.reduce_mean(nll_test)
      metrics['test/gibbs_cross_entropy'] = tf.reduce_mean(gibbs_ce_test)
      metrics['test/accuracy'] = tf.reduce_mean(accuracy)
      metrics['test/ece'].update_state(labels, probs)
    else:
      corrupt_metrics['test/nll_{}'.format(name)].update_state(
          tf.reduce_mean(nll_test))
      corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
          tf.reduce_mean(accuracy))
      corrupt_metrics['test/ece_{}'.format(name)].update_state(
          labels, probs)

  # NOTE(review): this empty dict is dead — immediately overwritten below.
  corrupt_results = {}
  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                    corruption_types,
                                                    max_intensity)
  # Resolve the ECE metric object to its scalar result before reporting.
  metrics['test/ece'] = metrics['test/ece'].result()
  total_results = {name: metric for name, metric in metrics.items()}
  total_results.update(corrupt_results)
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Evaluates a deep ensemble of wide ResNets on the train and test splits.

  Restores each ensemble member from the checkpoints found under
  FLAGS.output_dir, collects its logits over the full train and test sets,
  then reports ensemble negative log-likelihood, Gibbs cross entropy, and
  accuracy for both splits.

  Args:
    argv: unused command-line arguments.

  Raises:
    ValueError: if more than one accelerator core is requested.
  """
  del argv  # unused arg
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.enable_v2_behavior()
  tf.random.set_seed(FLAGS.seed)

  # TODO(trandustin): Replace with load_distributed_dataset. Currently hangs.
  dataset_train = utils.load_dataset(tfds.Split.TRAIN, FLAGS.dataset)
  dataset_test = utils.load_dataset(tfds.Split.TEST, FLAGS.dataset)
  dataset_train = dataset_train.batch(FLAGS.per_core_batch_size)
  dataset_test = dataset_test.batch(FLAGS.per_core_batch_size)
  ds_info = tfds.builder(FLAGS.dataset).info

  model = deterministic.wide_resnet(
      input_shape=ds_info.features['image'].shape,
      depth=28,
      width_multiplier=10,
      num_classes=ds_info.features['label'].num_classes,
      l2=0.,
      version=2)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())
  # Search for checkpoints from their index file; then remove the index suffix.
  ensemble_filenames = tf.io.gfile.glob(
      os.path.join(FLAGS.output_dir, '**/*.index'))
  # Strip the 6-character '.index' suffix to get the checkpoint prefix.
  ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Step counts are loop-invariant; hoisted out of the member loop (they were
  # recomputed on every iteration).
  batch_size = FLAGS.per_core_batch_size
  steps_per_epoch = ds_info.splits['train'].num_examples // batch_size
  steps_per_eval = ds_info.splits['test'].num_examples // batch_size

  # Collect the logits output for each ensemble member and train/test data
  # point. We also collect the labels.
  # TODO(trandustin): Refactor data loader so you can get the full dataset in
  # memory without looping.
  logits_train = []
  logits_test = []
  labels_train = []
  labels_test = []
  start_time = time.time()
  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename)

    logging.info('Working on training data for ensemble member %s', m)
    logits = []
    for features, labels in dataset_train:
      logits.append(model(features, training=False))
      if m == 0:
        labels_train.append(labels)
    logits = tf.concat(logits, axis=0)
    logits_train.append(logits)
    if m == 0:
      # Labels are identical across members; materialize them once.
      labels_train = tf.concat(labels_train, axis=0)

    logging.info('Working on test data for ensemble member %s', m)
    logits = []
    for features, labels in dataset_test:
      logits.append(model(features, training=False))
      if m == 0:
        labels_test.append(labels)
    logits = tf.concat(logits, axis=0)
    logits_test.append(logits)
    if m == 0:
      labels_test = tf.concat(labels_test, axis=0)

    current_step = (steps_per_epoch + steps_per_eval) * (m + 1)
    max_steps = (steps_per_epoch + steps_per_eval) * ensemble_size
    time_elapsed = time.time() - start_time
    steps_per_sec = float(current_step) / time_elapsed
    eta_seconds = (max_steps - current_step) / steps_per_sec
    message = (
        '{:.1%} completion: ensemble member {:d}/{:d}. {:.1f} steps/s. '
        'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
            (m + 1) / ensemble_size, m + 1, ensemble_size, steps_per_sec,
            eta_seconds / 60, time_elapsed / 60))
    # Fix: the progress message was built but never emitted; log it like the
    # sibling evaluation entry points in this file do.
    logging.info(message)

  metrics = {}
  # Compute the ensemble's NLL and Gibbs cross entropy for each data point.
  # Then average over the dataset.
  nll_train = ensemble_negative_log_likelihood(labels_train, logits_train)
  nll_test = ensemble_negative_log_likelihood(labels_test, logits_test)
  gibbs_ce_train = gibbs_cross_entropy(labels_train, logits_train)
  gibbs_ce_test = gibbs_cross_entropy(labels_test, logits_test)
  metrics['train_negative_log_likelihood'] = tf.reduce_mean(nll_train)
  metrics['test_negative_log_likelihood'] = tf.reduce_mean(nll_test)
  metrics['train_gibbs_cross_entropy'] = tf.reduce_mean(gibbs_ce_train)
  metrics['test_gibbs_cross_entropy'] = tf.reduce_mean(gibbs_ce_test)

  # Given the per-element logits tensor of shape [ensemble_size, dataset_size,
  # num_classes], average over the ensemble members' probabilities. Then
  # compute accuracy and average over the dataset.
  probs_train = tf.reduce_mean(tf.nn.softmax(logits_train), axis=0)
  probs_test = tf.reduce_mean(tf.nn.softmax(logits_test), axis=0)
  accuracy_train = tf.keras.metrics.sparse_categorical_accuracy(
      labels_train, probs_train)
  accuracy_test = tf.keras.metrics.sparse_categorical_accuracy(
      labels_test, probs_test)
  metrics['train_accuracy'] = tf.reduce_mean(accuracy_train)
  metrics['test_accuracy'] = tf.reduce_mean(accuracy_test)
  logging.info('Metrics: %s', metrics)