def testEnsembleCrossEntropy(self):
  """Ensemble cross entropy must not exceed Gibbs cross entropy.

  The comparison is run once on multi-class logits and once on binary
  (single-logit) predictions; in both cases each loss must be a scalar.
  """
  batch_size = 2
  ensemble_size = 5

  def assert_ensemble_bound(labels, logits, **loss_kwargs):
    # Evaluate both losses on identical inputs, check they are scalars, then
    # check the ordering between them.
    ensemble_ce = um.ensemble_cross_entropy(labels, logits, **loss_kwargs)
    gibbs_ce = um.gibbs_cross_entropy(labels, logits, **loss_kwargs)
    self.assertEqual(ensemble_ce.shape, ())
    self.assertEqual(gibbs_ce.shape, ())
    self.assertLessEqual(ensemble_ce, gibbs_ce)

  # Multi-class: integer labels in [0, 3) and 3 logits per example per member.
  num_multiclass = 3
  int_labels = tf.random.uniform(
      [batch_size], minval=0, maxval=num_multiclass, dtype=tf.int32)
  multiclass_logits = tf.random.normal(
      [ensemble_size, batch_size, num_multiclass])
  assert_ensemble_bound(int_labels, multiclass_logits)

  # Binary: float labels drawn from [0, 1) and a single logit per example,
  # whose trailing size-1 axis is squeezed away before computing the losses.
  single_logit = 1
  float_labels = tf.random.uniform(
      [batch_size], minval=0, maxval=single_logit, dtype=tf.float32)
  binary_logits = tf.squeeze(
      tf.random.normal([ensemble_size, batch_size, single_logit]), axis=-1)
  assert_ensemble_bound(float_labels, binary_logits, binary=True)
def greedy_selection(val_logits, val_labels, max_ens_size, objective='nll',
                     *, accuracy_fn=None, nll_fn=None):
  """Greedy procedure from Caruana et al. 2004, with replacement.

  Starting from an empty ensemble, repeatedly appends the candidate model
  (possibly one already selected) whose inclusion yields the best validation
  objective. Stops once the ensemble holds `max_ens_size` distinct models or
  when no candidate strictly improves the objective.

  Args:
    val_logits: Sequence of per-model validation logits. A model's position in
      this sequence is its model id.
    val_labels: Validation labels, passed unchanged to the metric functions.
    max_ens_size: Maximum number of *distinct* models in the ensemble.
    objective: 'nll' (minimize NLL), 'acc' (maximize accuracy) or 'nll-acc'
      (minimize NLL minus accuracy).
    accuracy_fn: Optional callable `(labels, list_of_member_logits) -> acc`.
      Defaults to `_ensemble_accuracy`.
    nll_fn: Optional callable `(labels, list_of_member_logits) -> nll`.
      Defaults to `um.ensemble_cross_entropy`.

  Returns:
    A tuple `(ens, best_acc, best_nll)`. `ens` is the list of selected model
    ids in selection order (repeats allowed). `best_acc` and `best_nll` are
    the metrics of the last accepted ensemble, or `0.` and `np.inf` if no
    model was ever selected.

  Raises:
    AssertionError: If `objective` is not one of the supported values.
  """
  assert_msg = 'Unknown objective type (received {}).'.format(objective)
  assert objective in ('nll', 'acc', 'nll-acc'), assert_msg
  if accuracy_fn is None:
    accuracy_fn = _ensemble_accuracy
  if nll_fn is None:
    nll_fn = um.ensemble_cross_entropy

  # The search below always *minimizes* the objective, so accuracy (where
  # higher is better) must enter with a negative sign.
  if objective == 'nll':
    get_objective = lambda acc, nll: nll
  elif objective == 'acc':
    get_objective = lambda acc, nll: -acc
  else:
    get_objective = lambda acc, nll: nll - acc

  best_acc = 0.
  best_nll = np.inf
  best_objective = np.inf
  ens = []
  # Sampling is with replacement, so the size limit counts distinct models.
  while len(set(ens)) < max_ens_size:
    current_val_logits = [val_logits[model_id] for model_id in ens]
    best_model_id = None
    for model_id, logits in enumerate(val_logits):
      candidate = current_val_logits + [logits]
      acc = accuracy_fn(val_labels, candidate)
      nll = nll_fn(val_labels, candidate)
      obj = get_objective(acc, nll)
      # Strict improvement only: ties keep the earlier candidate, and the
      # search terminates once nothing beats the current ensemble.
      if obj < best_objective:
        best_acc = acc
        best_nll = nll
        best_objective = obj
        best_model_id = model_id
    if best_model_id is None:
      logging.info(
          'Ensemble could not be improved: Greedy selection stops.')
      break
    ens.append(best_model_id)
  return ens, best_acc, best_nll
def main(argv):
  """Evaluates an ensemble of Wide ResNet-28-10 checkpoints.

  Two phases:
    1. Prediction: for every checkpoint returned by `parse_checkpoint_dir` and
       every test dataset (the clean test split of `FLAGS.dataset` plus each
       corruption type at severities 1-5), run the model and save its logits
       to `<dataset>_<member>.npy` in `FLAGS.output_dir`. A file that already
       exists is not recomputed.
    2. Evaluation: reload the cached logits, compute ensemble metrics, per-member
       metrics and diversity on the clean data, plus per-dataset metrics on
       the corrupted data, then log everything.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If `FLAGS.use_gpu` is false or `FLAGS.num_cores` > 1.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  ds_info = tfds.builder(FLAGS.dataset).info
  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  # Integer division: a trailing partial batch is never predicted/evaluated.
  # The same step count is reused for every corrupted dataset as well.
  steps_per_eval = ds_info.splits['test'].num_examples // batch_size
  num_classes = ds_info.features['label'].num_classes

  dataset = ub.datasets.get(
      FLAGS.dataset, split=tfds.Split.TEST).load(batch_size=batch_size)
  test_datasets = {'clean': dataset}
  # Only the CIFAR-100 corrupted data needs an explicit data directory.
  extra_kwargs = {}
  if FLAGS.dataset == 'cifar100':
    extra_kwargs['data_dir'] = FLAGS.cifar100_c_path
  corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
  for corruption_type in corruption_types:
    for severity in range(1, 6):
      dataset = ub.datasets.get(
          f'{FLAGS.dataset}_corrupted',
          corruption_type=corruption_type,
          severity=severity,
          split=tfds.Split.TEST,
          **extra_kwargs).load(batch_size=batch_size)
      test_datasets[f'{corruption_type}_{severity}'] = dataset

  model = ub.models.wide_resnet(input_shape=ds_info.features['image'].shape,
                                depth=28,
                                width_multiplier=10,
                                num_classes=num_classes,
                                l2=0.,
                                version=2)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Collect the ensemble members' checkpoint paths (see `parse_checkpoint_dir`).
  ensemble_filenames = parse_checkpoint_dir(FLAGS.checkpoint_dir)
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename)
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
      # NOTE(review): cache files are keyed by member *index* only, so reuse
      # across runs assumes parse_checkpoint_dir returns a stable order.
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      if not tf.io.gfile.exists(filename):
        logits = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval):
          features = next(test_iterator)['features']  # pytype: disable=unsupported-operands
          logits.append(model(features, training=False))

        # Concatenate per-batch outputs along the example axis.
        logits = tf.concat(logits, axis=0)
        # NOTE(review): .npy bytes are written with mode 'w' but read back
        # with 'rb' below; 'wb' would be the symmetric choice.
        with tf.io.gfile.GFile(filename, 'w') as f:
          np.save(f, logits.numpy())
      percent = (m * num_datasets + (n + 1)) / (ensemble_size * num_datasets)
      message = (
          '{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
          'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size, n + 1,
                                     num_datasets))
      logging.info(message)

  # Ensemble-level metrics on the clean test set.
  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  # Per-dataset metrics. Entries are created for every dataset (including
  # 'clean'), but only the non-clean datasets update them below.
  corrupt_metrics = {}
  for name in test_datasets:
    corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
    corrupt_metrics['test/accuracy_{}'.format(name)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
    corrupt_metrics['test/ece_{}'.format(name)] = (
        rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
  # Individual-member metrics on the clean test set.
  for i in range(ensemble_size):
    metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
    metrics['test/accuracy_member_{}'.format(i)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
  # Diversity metrics are kept in their own dict (updated by key below) and
  # also merged into `metrics` so they are reported with everything else.
  test_diversity = {
      'test/disagreement': tf.keras.metrics.Mean(),
      'test/average_kl': tf.keras.metrics.Mean(),
      'test/cosine_similarity': tf.keras.metrics.Mean(),
  }
  metrics.update(test_diversity)

  # Evaluate model predictions.
  for n, (name, test_dataset) in enumerate(test_datasets.items()):
    logits_dataset = []
    for m in range(ensemble_size):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      with tf.io.gfile.GFile(filename, 'rb') as f:
        logits_dataset.append(np.load(f))

    # Stack the members' cached logits along a new leading (member) axis.
    logits_dataset = tf.convert_to_tensor(logits_dataset)
    # Re-iterate the dataset only to read labels; the slice below assumes it
    # yields batches in the same order as during the prediction phase.
    test_iterator = iter(test_dataset)
    for step in range(steps_per_eval):
      labels = next(test_iterator)['labels']  # pytype: disable=unsupported-operands
      logits = logits_dataset[:, (step * batch_size):((step + 1) * batch_size)]
      labels = tf.cast(labels, tf.int32)
      negative_log_likelihood = um.ensemble_cross_entropy(labels, logits)
      per_probs = tf.nn.softmax(logits)
      # Ensemble prediction: mean of member probabilities.
      probs = tf.reduce_mean(per_probs, axis=0)
      if name == 'clean':
        gibbs_ce = um.gibbs_cross_entropy(labels, logits)
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].add_batch(probs, label=labels)

        for i in range(ensemble_size):
          member_probs = per_probs[i]
          member_loss = tf.keras.losses.sparse_categorical_crossentropy(
              labels, member_probs)
          metrics['test/nll_member_{}'.format(i)].update_state(member_loss)
          metrics['test/accuracy_member_{}'.format(i)].update_state(
              labels, member_probs)
        # A fresh diversity metric per batch; its per-batch results are
        # averaged by the Mean metrics in `test_diversity`.
        diversity = rm.metrics.AveragePairwiseDiversity()
        diversity.add_batch(per_probs, num_models=ensemble_size)
        diversity_results = diversity.result()
        for k, v in diversity_results.items():
          test_diversity['test/' + k].update_state(v)
      else:
        corrupt_metrics['test/nll_{}'.format(name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(name)].add_batch(
            probs, label=labels)

    message = ('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                    corruption_types)
  total_results = {name: metric.result() for name, metric in metrics.items()}
  total_results.update(corrupt_results)
  # Metrics from Robustness Metrics (like ECE) will return a dict with a
  # single key/value, instead of a scalar.
  total_results = {
      k: (list(v.values())[0] if isinstance(v, dict) else v)
      for k, v in total_results.items()
  }
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Evaluates an ensemble of ResNet-50 checkpoints on ImageNet.

  Two phases:
    1. Prediction: for every checkpoint found under `FLAGS.checkpoint_dir` and
       every test dataset (clean ImageNet validation data plus each corruption
       type at intensities 1..max_intensity), run the model and save its
       logits to `<dataset>_<member>.npy` in `FLAGS.output_dir`. A file that
       already exists is not recomputed.
    2. Evaluation: reload the cached logits, compute ensemble NLL, Gibbs cross
       entropy, accuracy and ECE on the clean data, and per-dataset NLL,
       accuracy and ECE on the corrupted data, then log the results.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If `FLAGS.use_gpu` is false or `FLAGS.num_cores` > 1.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  # Integer division: a trailing partial batch is never predicted/evaluated.
  steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

  builder = utils.ImageNetInput(data_dir=FLAGS.data_dir,
                                use_bfloat16=False)
  clean_test_dataset = builder.as_dataset(split=tfds.Split.TEST,
                                          batch_size=batch_size)
  test_datasets = {'clean': clean_test_dataset}
  corruption_types, max_intensity = utils.load_corrupted_test_info()
  for name in corruption_types:
    for intensity in range(1, max_intensity + 1):
      dataset_name = '{0}_{1}'.format(name, intensity)
      test_datasets[dataset_name] = utils.load_corrupted_test_dataset(
          corruption_name=name,
          corruption_intensity=intensity,
          batch_size=batch_size,
          drop_remainder=True,
          use_bfloat16=False)

  model = ub.models.resnet50_deterministic(input_shape=(224, 224, 3),
                                          num_classes=NUM_CLASSES)

  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Search for checkpoints from their index file; then remove the index suffix.
  # `[:-6]` strips the 6-character '.index' suffix.
  # NOTE(review): glob order is not sorted here, while the prediction caches
  # below are keyed by member index — confirm the order is stable.
  ensemble_filenames = tf.io.gfile.glob(os.path.join(FLAGS.checkpoint_dir,
                                                     '**/*.index'))
  ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename)
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      if not tf.io.gfile.exists(filename):
        logits = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval):
          features, _ = next(test_iterator)  # pytype: disable=attribute-error
          logits.append(model(features, training=False))

        # Concatenate per-batch outputs along the example axis.
        logits = tf.concat(logits, axis=0)
        # NOTE(review): .npy bytes are written with mode 'w' but read back
        # with 'rb' below; 'wb' would be the symmetric choice.
        with tf.io.gfile.GFile(filename, 'w') as f:
          np.save(f, logits.numpy())
      percent = (m * num_datasets + (n + 1)) / (ensemble_size * num_datasets)
      message = ('{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                 'Dataset {:d}/{:d}'.format(percent,
                                            m + 1,
                                            ensemble_size,
                                            n + 1,
                                            num_datasets))
      logging.info(message)

  # Ensemble-level metrics on the clean test set.
  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  # Per-dataset metrics. Entries are created for every dataset (including
  # 'clean'), but only the non-clean datasets update them below.
  corrupt_metrics = {}
  for name in test_datasets:
    corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
    corrupt_metrics['test/accuracy_{}'.format(name)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
    corrupt_metrics['test/ece_{}'.format(
        name)] = um.ExpectedCalibrationError(num_bins=FLAGS.num_bins)

  # Evaluate model predictions.
  for n, (name, test_dataset) in enumerate(test_datasets.items()):
    logits_dataset = []
    for m in range(ensemble_size):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      with tf.io.gfile.GFile(filename, 'rb') as f:
        logits_dataset.append(np.load(f))

    # Stack the members' cached logits along a new leading (member) axis.
    logits_dataset = tf.convert_to_tensor(logits_dataset)
    # Re-iterate the dataset only to read labels; the slice below assumes it
    # yields batches in the same order as during the prediction phase.
    test_iterator = iter(test_dataset)
    for step in range(steps_per_eval):
      _, labels = next(test_iterator)  # pytype: disable=attribute-error
      logits = logits_dataset[:, (step*batch_size):((step+1)*batch_size)]
      # Flatten labels to a 1-D int32 vector before computing metrics.
      labels = tf.cast(tf.reshape(labels, [-1]), tf.int32)
      negative_log_likelihood = um.ensemble_cross_entropy(labels, logits)
      per_probs = tf.nn.softmax(logits)
      # Ensemble prediction: mean of member probabilities.
      probs = tf.reduce_mean(per_probs, axis=0)
      if name == 'clean':
        gibbs_ce = um.gibbs_cross_entropy(labels, logits)
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
      else:
        corrupt_metrics['test/nll_{}'.format(name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(name)].update_state(
            labels, probs)

    message = ('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                    corruption_types,
                                                    max_intensity,
                                                    FLAGS.alexnet_errors_path)
  total_results = {name: metric.result() for name, metric in metrics.items()}
  total_results.update(corrupt_results)
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Evaluates an ensemble of SNGP-BERT checkpoints on toxicity datasets.

  Test sets are the in-domain Wikipedia Toxicity data ('ind') and the Civil
  Comments / Civil Comments Identities data ('ood', 'ood_identity').

  Two phases:
    1. Prediction: for each of the first `FLAGS.num_models` checkpoints (in
       sorted path order) found under `FLAGS.checkpoint_dir`, run the model on
       every test set, apply the mean-field logit adjustment, and save the
       logits to `<dataset>_<member>.npy` in `FLAGS.output_dir`. A file that
       already exists is not recomputed.
    2. Evaluation: reload the cached logits, compute ensemble classification
       and calibration metrics per dataset, save texts/labels/logits (and
       identity labels for the identity dataset), then log the results.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If not running on a single GPU, or if `FLAGS.num_models`
      exceeds the number of checkpoints found.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')

  tf.random.set_seed(FLAGS.seed)
  logging.info('Model checkpoint will be saved at %s', FLAGS.output_dir)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  test_batch_size = batch_size
  data_buffer_size = batch_size * 10

  ind_dataset_builder = ds.WikipediaToxicityDataset(
      split='test',
      data_dir=FLAGS.in_dataset_dir,
      shuffle_buffer_size=data_buffer_size)
  ood_dataset_builder = ds.CivilCommentsDataset(
      split='test',
      data_dir=FLAGS.ood_dataset_dir,
      shuffle_buffer_size=data_buffer_size)
  ood_identity_dataset_builder = ds.CivilCommentsIdentitiesDataset(
      split='test',
      data_dir=FLAGS.identity_dataset_dir,
      shuffle_buffer_size=data_buffer_size)

  test_dataset_builders = {
      'ind': ind_dataset_builder,
      'ood': ood_dataset_builder,
      'ood_identity': ood_identity_dataset_builder,
  }

  class_weight = utils.create_class_weight(
      test_dataset_builders=test_dataset_builders)
  logging.info('class_weight: %s', str(class_weight))

  ds_info = ind_dataset_builder.tfds_info
  # Positive and negative classes.
  num_classes = ds_info.metadata['num_classes']

  test_datasets = {}
  steps_per_eval = {}
  for dataset_name, dataset_builder in test_dataset_builders.items():
    test_datasets[dataset_name] = dataset_builder.load(
        batch_size=test_batch_size)
    # Integer division: a trailing partial batch is never evaluated.
    steps_per_eval[dataset_name] = (
        dataset_builder.num_examples // test_batch_size)

  logging.info('Building %s model', FLAGS.model_family)

  bert_config_dir, _ = utils.resolve_bert_ckpt_and_config_dir(
      FLAGS.bert_model_type, FLAGS.bert_dir, FLAGS.bert_config_dir,
      FLAGS.bert_ckpt_dir)
  bert_config = utils.create_config(bert_config_dir)

  gp_layer_kwargs = dict(
      num_inducing=FLAGS.gp_hidden_dim,
      gp_kernel_scale=FLAGS.gp_scale,
      gp_output_bias=FLAGS.gp_bias,
      normalize_input=FLAGS.gp_input_normalization,
      gp_cov_momentum=FLAGS.gp_cov_discount_factor,
      gp_cov_ridge_penalty=FLAGS.gp_cov_ridge_penalty)
  spec_norm_kwargs = dict(
      iteration=FLAGS.spec_norm_iteration,
      norm_multiplier=FLAGS.spec_norm_bound)

  model, _ = ub.models.SngpBertBuilder(
      num_classes=num_classes,
      bert_config=bert_config,
      gp_layer_kwargs=gp_layer_kwargs,
      spec_norm_kwargs=spec_norm_kwargs,
      use_gp_layer=FLAGS.use_gp_layer,
      use_spec_norm_att=FLAGS.use_spec_norm_att,
      use_spec_norm_ffn=FLAGS.use_spec_norm_ffn,
      use_layer_norm_att=FLAGS.use_layer_norm_att,
      use_layer_norm_ffn=FLAGS.use_layer_norm_ffn,
      use_spec_norm_plr=FLAGS.use_spec_norm_plr)

  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Search for checkpoints from their index file; then remove the index suffix.
  # Sorting makes the member order deterministic, so the `[:FLAGS.num_models]`
  # subset and the index-keyed prediction caches below always refer to the
  # same checkpoints across runs.
  ensemble_filenames = sorted(
      tf.io.gfile.glob(os.path.join(FLAGS.checkpoint_dir, '**/*.index')))
  ensemble_filenames = [
      filename[:-len('.index')] for filename in ensemble_filenames
  ]
  if FLAGS.num_models > len(ensemble_filenames):
    raise ValueError(
        'Number of models to be included in the ensemble ({}) should not '
        'exceed the total number of models in the checkpoint_dir '
        '({}).'.format(FLAGS.num_models, len(ensemble_filenames)))
  ensemble_filenames = ensemble_filenames[:FLAGS.num_models]
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename).assert_existing_objects_matched()
    for n, (dataset_name, test_dataset) in enumerate(test_datasets.items()):
      filename = '{dataset}_{member}.npy'.format(dataset=dataset_name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      if not tf.io.gfile.exists(filename):
        logits_list = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval[dataset_name]):
          try:
            inputs = next(test_iterator)
          except StopIteration:
            # The dataset is exhausted; every later `next` would fail too.
            break
          features, _, _ = utils.create_feature_and_label(inputs)
          logits = model(features, training=False)

          if isinstance(logits, (list, tuple)):
            # If model returns a tuple of (logits, covmat), extract both.
            logits, covmat = logits
          else:
            covmat = tf.eye(test_batch_size)

          if FLAGS.use_bfloat16:
            logits = tf.cast(logits, tf.float32)
            covmat = tf.cast(covmat, tf.float32)

          logits = ed.layers.utils.mean_field_logits(
              logits, covmat,
              mean_field_factor=FLAGS.gp_mean_field_factor_ensemble)

          logits_list.append(logits)

        # Concatenate per-batch outputs along the example axis.
        logits_all = tf.concat(logits_list, axis=0)
        # .npy content is binary; write it in binary mode to match the 'rb'
        # read in the evaluation phase.
        with tf.io.gfile.GFile(filename, 'wb') as f:
          np.save(f, logits_all.numpy())
      percent = (m * num_datasets + (n + 1)) / (ensemble_size * num_datasets)
      message = ('{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                 'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size,
                                            n + 1, num_datasets))
      logging.info(message)

  # Metrics for the in-domain ('ind') dataset.
  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/auroc': tf.keras.metrics.AUC(curve='ROC'),
      'test/aupr': tf.keras.metrics.AUC(curve='PR'),
      'test/brier': tf.keras.metrics.MeanSquaredError(),
      'test/brier_weighted': tf.keras.metrics.MeanSquaredError(),
      'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
      'test/acc': tf.keras.metrics.Accuracy(),
      'test/acc_weighted': tf.keras.metrics.Accuracy(),
      'test/precision': tf.keras.metrics.Precision(),
      'test/recall': tf.keras.metrics.Recall(),
      'test/f1': tfa_metrics.F1Score(
          num_classes=num_classes,
          average='micro',
          threshold=FLAGS.ece_label_threshold)
  }
  for fraction in FLAGS.fractions:
    metrics.update({
        'test_collab_acc/collab_acc_{}'.format(fraction):
            um.OracleCollaborativeAccuracy(
                fraction=float(fraction), num_bins=FLAGS.num_bins)
    })
  # The same metrics, suffixed with the dataset name, for every other dataset.
  for dataset_name in test_datasets:
    if dataset_name == 'ind':
      continue
    metrics.update({
        'test/nll_{}'.format(dataset_name):
            tf.keras.metrics.Mean(),
        'test/auroc_{}'.format(dataset_name):
            tf.keras.metrics.AUC(curve='ROC'),
        'test/aupr_{}'.format(dataset_name):
            tf.keras.metrics.AUC(curve='PR'),
        'test/brier_{}'.format(dataset_name):
            tf.keras.metrics.MeanSquaredError(),
        'test/brier_weighted_{}'.format(dataset_name):
            tf.keras.metrics.MeanSquaredError(),
        'test/ece_{}'.format(dataset_name):
            um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/acc_weighted_{}'.format(dataset_name):
            tf.keras.metrics.Accuracy(),
        'test/acc_{}'.format(dataset_name):
            tf.keras.metrics.Accuracy(),
        'test/precision_{}'.format(dataset_name):
            tf.keras.metrics.Precision(),
        'test/recall_{}'.format(dataset_name):
            tf.keras.metrics.Recall(),
        'test/f1_{}'.format(dataset_name):
            tfa_metrics.F1Score(
                num_classes=num_classes,
                average='micro',
                threshold=FLAGS.ece_label_threshold)
    })
    for fraction in FLAGS.fractions:
      metrics.update({
          'test_collab_acc/collab_acc_{}_{}'.format(fraction, dataset_name):
              um.OracleCollaborativeAccuracy(
                  fraction=float(fraction), num_bins=FLAGS.num_bins)
      })

  @tf.function
  def generate_sample_weight(labels, class_weight, label_threshold=0.7):
    """Generate sample weight for weighted accuracy calculation."""
    if label_threshold != 0.7:
      logging.warning('The class weight was based on `label_threshold` = 0.7, '
                      'and weighted accuracy/brier will be meaningless if '
                      '`label_threshold` is not equal to this value, which is '
                      'recommended by Jigsaw Conversation AI team.')
    # Binarize labels, then look up each example's class weight.
    labels_int = tf.cast(labels > label_threshold, tf.int32)
    sample_weight = tf.gather(class_weight, labels_int)
    return sample_weight

  # Evaluate model predictions.
  for n, (dataset_name, test_dataset) in enumerate(test_datasets.items()):
    logits_dataset = []
    for m in range(ensemble_size):
      filename = '{dataset}_{member}.npy'.format(dataset=dataset_name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      with tf.io.gfile.GFile(filename, 'rb') as f:
        logits_dataset.append(np.load(f))

    # Stack the members' cached logits along a new leading (member) axis.
    logits_dataset = tf.convert_to_tensor(logits_dataset)
    test_iterator = iter(test_dataset)
    texts_list = []
    logits_list = []
    labels_list = []
    # Use dict to collect additional labels specified by additional label names.
    # Here we use `OrderedDict` to get consistent ordering for this dict so
    # we can retrieve the predictions for each identity labels in Colab.
    additional_labels_dict = collections.OrderedDict()
    for step in range(steps_per_eval[dataset_name]):
      try:
        inputs = next(test_iterator)  # type: Mapping[Text, tf.Tensor]  # pytype: disable=annotation-type-mismatch
      except StopIteration:
        # The dataset is exhausted; every later `next` would fail too.
        break
      features, labels, additional_labels = (
          utils.create_feature_and_label(inputs))

      logits = logits_dataset[:, (step * batch_size):((step + 1) * batch_size)]
      loss_logits = tf.squeeze(logits, axis=-1)
      negative_log_likelihood = um.ensemble_cross_entropy(
          labels, loss_logits, binary=True)

      per_probs = tf.nn.sigmoid(logits)
      # Ensemble prediction: mean of member probabilities.
      probs = tf.reduce_mean(per_probs, axis=0)
      # Cast labels to discrete for ECE computation
      ece_labels = tf.cast(labels > FLAGS.ece_label_threshold, tf.float32)
      one_hot_labels = tf.one_hot(tf.cast(ece_labels, tf.int32),
                                  depth=num_classes)
      # Two-column [P(negative), P(positive)] probabilities.
      ece_probs = tf.concat([1. - probs, probs], axis=1)
      pred_labels = tf.math.argmax(ece_probs, axis=-1)
      auc_probs = tf.squeeze(probs, axis=1)

      texts_list.append(inputs['input_ids'])
      logits_list.append(logits)
      labels_list.append(labels)
      if 'identity' in dataset_name:
        for identity_label_name in utils.IDENTITY_LABELS:
          additional_labels_dict.setdefault(identity_label_name, []).append(
              additional_labels[identity_label_name].numpy())

      sample_weight = generate_sample_weight(
          labels, class_weight['test/{}'.format(dataset_name)],
          FLAGS.ece_label_threshold)

      if dataset_name == 'ind':
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/auroc'].update_state(labels, auc_probs)
        metrics['test/aupr'].update_state(labels, auc_probs)
        metrics['test/brier'].update_state(labels, auc_probs)
        metrics['test/brier_weighted'].update_state(
            tf.expand_dims(labels, -1), probs, sample_weight=sample_weight)
        metrics['test/ece'].add_batch(ece_probs, label=ece_labels)
        metrics['test/acc'].update_state(ece_labels, pred_labels)
        metrics['test/acc_weighted'].update_state(
            ece_labels, pred_labels, sample_weight=sample_weight)
        metrics['test/precision'].update_state(ece_labels, pred_labels)
        metrics['test/recall'].update_state(ece_labels, pred_labels)
        metrics['test/f1'].update_state(one_hot_labels, ece_probs)
        for fraction in FLAGS.fractions:
          metrics['test_collab_acc/collab_acc_{}'.format(
              fraction)].update_state(ece_labels, ece_probs)
      else:
        metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        metrics['test/auroc_{}'.format(dataset_name)].update_state(
            labels, auc_probs)
        metrics['test/aupr_{}'.format(dataset_name)].update_state(
            labels, auc_probs)
        metrics['test/brier_{}'.format(dataset_name)].update_state(
            labels, auc_probs)
        metrics['test/brier_weighted_{}'.format(dataset_name)].update_state(
            tf.expand_dims(labels, -1), probs, sample_weight=sample_weight)
        metrics['test/ece_{}'.format(dataset_name)].add_batch(
            ece_probs, label=ece_labels)
        metrics['test/acc_{}'.format(dataset_name)].update_state(
            ece_labels, pred_labels)
        metrics['test/acc_weighted_{}'.format(dataset_name)].update_state(
            ece_labels, pred_labels, sample_weight=sample_weight)
        metrics['test/precision_{}'.format(dataset_name)].update_state(
            ece_labels, pred_labels)
        metrics['test/recall_{}'.format(dataset_name)].update_state(
            ece_labels, pred_labels)
        metrics['test/f1_{}'.format(dataset_name)].update_state(
            one_hot_labels, ece_probs)
        for fraction in FLAGS.fractions:
          metrics['test_collab_acc/collab_acc_{}_{}'.format(
              fraction, dataset_name)].update_state(ece_labels, ece_probs)

    texts_all = tf.concat(texts_list, axis=0)
    # Logits keep the member axis first, so batches join along axis 1.
    logits_all = tf.concat(logits_list, axis=1)
    labels_all = tf.concat(labels_list, axis=0)
    additional_labels_all = []
    if additional_labels_dict:
      additional_labels_all = list(additional_labels_dict.values())

    utils.save_prediction(
        texts_all.numpy(),
        path=os.path.join(FLAGS.output_dir, 'texts_{}'.format(dataset_name)))
    utils.save_prediction(
        labels_all.numpy(),
        path=os.path.join(FLAGS.output_dir, 'labels_{}'.format(dataset_name)))
    utils.save_prediction(
        logits_all.numpy(),
        path=os.path.join(FLAGS.output_dir, 'logits_{}'.format(dataset_name)))
    if 'identity' in dataset_name:
      utils.save_prediction(
          np.array(additional_labels_all),
          path=os.path.join(FLAGS.output_dir,
                            'additional_labels_{}'.format(dataset_name)))

    message = ('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  total_results = {name: metric.result() for name, metric in metrics.items()}
  # Metrics from Robustness Metrics (like ECE) will return a dict with a
  # single key/value, instead of a scalar.
  total_results = {
      k: (list(v.values())[0] if isinstance(v, dict) else v)
      for k, v in total_results.items()
  }
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Evaluates an ensemble of Wide ResNet-28-10 checkpoints on CIFAR.

  Two phases:
    1. Prediction: for every checkpoint found under `FLAGS.checkpoint_dir` and
       every test dataset (the clean test split of `FLAGS.dataset` plus each
       corruption type at intensities 1..max_intensity), run the model and
       save its logits to `<dataset>_<member>.npy` in `FLAGS.output_dir`. A
       file that already exists is not recomputed.
    2. Evaluation: reload the cached logits, compute ensemble, per-member and
       diversity metrics on the clean data, plus per-dataset metrics on the
       corrupted data, then log the results.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If `FLAGS.use_gpu` is false or `FLAGS.num_cores` > 1.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  ds_info = tfds.builder(FLAGS.dataset).info
  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  # Integer division: a trailing partial batch is never predicted/evaluated.
  # The same step count is reused for every corrupted dataset as well.
  steps_per_eval = ds_info.splits['test'].num_examples // batch_size
  num_classes = ds_info.features['label'].num_classes

  # NOTE(review): loaders receive per_core_batch_size while slicing below uses
  # batch_size; these agree only because num_cores > 1 is rejected above.
  dataset_input_fn = utils.load_input_fn(
      split=tfds.Split.TEST,
      name=FLAGS.dataset,
      batch_size=FLAGS.per_core_batch_size,
      use_bfloat16=FLAGS.use_bfloat16)
  test_datasets = {'clean': dataset_input_fn()}
  corruption_types, max_intensity = utils.load_corrupted_test_info(
      FLAGS.dataset)
  for name in corruption_types:
    for intensity in range(1, max_intensity + 1):
      dataset_name = '{0}_{1}'.format(name, intensity)
      # CIFAR-10-C and CIFAR-100-C use different loaders; the latter needs an
      # explicit path.
      if FLAGS.dataset == 'cifar10':
        load_c_dataset = utils.load_cifar10_c_input_fn
      else:
        load_c_dataset = functools.partial(utils.load_cifar100_c_input_fn,
                                           path=FLAGS.cifar100_c_path)
      corrupted_input_fn = load_c_dataset(
          corruption_name=name,
          corruption_intensity=intensity,
          batch_size=FLAGS.per_core_batch_size,
          use_bfloat16=FLAGS.use_bfloat16)
      test_datasets[dataset_name] = corrupted_input_fn()

  model = ub.models.wide_resnet(
      input_shape=ds_info.features['image'].shape,
      depth=28,
      width_multiplier=10,
      num_classes=num_classes,
      l2=0.,
      version=2)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Search for checkpoints from their index file; then remove the index suffix.
  # `[:-6]` strips the 6-character '.index' suffix.
  # NOTE(review): glob order is not sorted here, while the prediction caches
  # below are keyed by member index — confirm the order is stable.
  ensemble_filenames = tf.io.gfile.glob(os.path.join(FLAGS.checkpoint_dir,
                                                     '**/*.index'))
  ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename)
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      if not tf.io.gfile.exists(filename):
        logits = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval):
          features, _ = next(test_iterator)  # pytype: disable=attribute-error
          logits.append(model(features, training=False))

        # Concatenate per-batch outputs along the example axis.
        logits = tf.concat(logits, axis=0)
        # NOTE(review): .npy bytes are written with mode 'w' but read back
        # with 'rb' below; 'wb' would be the symmetric choice.
        with tf.io.gfile.GFile(filename, 'w') as f:
          np.save(f, logits.numpy())
      percent = (m * num_datasets + (n + 1)) / (ensemble_size * num_datasets)
      message = ('{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                 'Dataset {:d}/{:d}'.format(percent,
                                            m + 1,
                                            ensemble_size,
                                            n + 1,
                                            num_datasets))
      logging.info(message)

  # Ensemble-level metrics on the clean test set.
  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  # Per-dataset metrics. Entries are created for every dataset (including
  # 'clean'), but only the non-clean datasets update them below.
  corrupt_metrics = {}
  for name in test_datasets:
    corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
    corrupt_metrics['test/accuracy_{}'.format(name)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
    corrupt_metrics['test/ece_{}'.format(name)] = (
        um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
  # Individual-member metrics on the clean test set.
  for i in range(ensemble_size):
    metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
    metrics['test/accuracy_member_{}'.format(i)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
  # Diversity metrics are kept in their own dict (updated by key below) and
  # also merged into `metrics` so they are reported with everything else.
  test_diversity = {
      'test/disagreement': tf.keras.metrics.Mean(),
      'test/average_kl': tf.keras.metrics.Mean(),
      'test/cosine_similarity': tf.keras.metrics.Mean(),
  }
  metrics.update(test_diversity)

  # Evaluate model predictions.
  for n, (name, test_dataset) in enumerate(test_datasets.items()):
    logits_dataset = []
    for m in range(ensemble_size):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      with tf.io.gfile.GFile(filename, 'rb') as f:
        logits_dataset.append(np.load(f))

    # Stack the members' cached logits along a new leading (member) axis.
    logits_dataset = tf.convert_to_tensor(logits_dataset)
    # Re-iterate the dataset only to read labels; the slice below assumes it
    # yields batches in the same order as during the prediction phase.
    test_iterator = iter(test_dataset)
    for step in range(steps_per_eval):
      _, labels = next(test_iterator)  # pytype: disable=attribute-error
      logits = logits_dataset[:, (step*batch_size):((step+1)*batch_size)]
      labels = tf.cast(labels, tf.int32)
      negative_log_likelihood = um.ensemble_cross_entropy(labels, logits)
      per_probs = tf.nn.softmax(logits)
      # Ensemble prediction: mean of member probabilities.
      probs = tf.reduce_mean(per_probs, axis=0)
      if name == 'clean':
        gibbs_ce = um.gibbs_cross_entropy(labels, logits)
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)

        for i in range(ensemble_size):
          member_probs = per_probs[i]
          member_loss = tf.keras.losses.sparse_categorical_crossentropy(
              labels, member_probs)
          metrics['test/nll_member_{}'.format(i)].update_state(member_loss)
          metrics['test/accuracy_member_{}'.format(i)].update_state(
              labels, member_probs)
        # Per-batch diversity values, averaged by the Mean metrics in
        # `test_diversity`.
        diversity_results = um.average_pairwise_diversity(
            per_probs, ensemble_size)
        for k, v in diversity_results.items():
          test_diversity['test/' + k].update_state(v)
      else:
        corrupt_metrics['test/nll_{}'.format(name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(name)].update_state(
            labels, probs)

    message = ('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                    corruption_types,
                                                    max_intensity)
  total_results = {name: metric.result() for name, metric in metrics.items()}
  total_results.update(corrupt_results)
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Evaluates an ensemble of SNGP-BERT checkpoints on the CLINC test splits.

  Every checkpoint found under `FLAGS.checkpoint_dir` (via its `.index` file)
  is restored in turn. Its mean-field adjusted logits on each test split
  ('clean', 'ood', 'all') are cached as `{dataset}_{member}.npy` in
  `FLAGS.output_dir`; already existing files are reused. The cached logits are
  then reloaded, averaged in probability space, and reported as:
    * NLL, Gibbs cross entropy, accuracy and ECE on the 'clean' split,
    * NLL, accuracy and ECE on the other splits,
    * AUROC / AUPRC for out-of-scope detection on the combined 'all' split.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If not running on exactly one GPU.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  ind_dataset_builder = ub.datasets.ClincIntentDetectionDataset(
      split='test', data_dir=FLAGS.data_dir, data_mode='ind')
  ood_dataset_builder = ub.datasets.ClincIntentDetectionDataset(
      split='test', data_dir=FLAGS.data_dir, data_mode='ood')
  all_dataset_builder = ub.datasets.ClincIntentDetectionDataset(
      split='test', data_dir=FLAGS.data_dir, data_mode='all')

  dataset_builders = {
      'clean': ind_dataset_builder,
      'ood': ood_dataset_builder,
      'all': all_dataset_builder
  }

  ds_info = ind_dataset_builder.tfds_info
  feature_size = ds_info.metadata['feature_size']
  # num_classes is number of valid intents plus out-of-scope intent
  num_classes = ds_info.features['intent_label'].num_classes + 1
  # vocab_size is total number of valid tokens plus the out-of-vocabulary token.
  # NOTE(review): currently unused by this evaluation script.
  vocab_size = ind_dataset_builder.tokenizer.num_words + 1

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores

  test_datasets = {}
  steps_per_eval = {}
  for dataset_name, dataset_builder in dataset_builders.items():
    test_datasets[dataset_name] = dataset_builder.load(batch_size=batch_size)
    # Only full batches are evaluated; a trailing partial batch is skipped.
    steps_per_eval[dataset_name] = (
        dataset_builder.num_examples // batch_size)

  bert_config_dir, _ = sngp.resolve_bert_ckpt_and_config_dir(
      FLAGS.bert_dir, FLAGS.bert_config_dir, FLAGS.bert_ckpt_dir)
  bert_config = bert_utils.create_config(bert_config_dir)

  gp_layer_kwargs = dict(
      num_inducing=FLAGS.gp_hidden_dim,
      gp_kernel_scale=FLAGS.gp_scale,
      gp_output_bias=FLAGS.gp_bias,
      normalize_input=FLAGS.gp_input_normalization,
      gp_cov_momentum=FLAGS.gp_cov_discount_factor,
      gp_cov_ridge_penalty=FLAGS.gp_cov_ridge_penalty)
  spec_norm_kwargs = dict(
      iteration=FLAGS.spec_norm_iteration,
      norm_multiplier=FLAGS.spec_norm_bound)

  model, _ = ub.models.SngpBertBuilder(
      num_classes=num_classes,
      bert_config=bert_config,
      gp_layer_kwargs=gp_layer_kwargs,
      spec_norm_kwargs=spec_norm_kwargs,
      use_gp_layer=FLAGS.use_gp_layer,
      use_spec_norm_att=FLAGS.use_spec_norm_att,
      use_spec_norm_ffn=FLAGS.use_spec_norm_ffn,
      use_layer_norm_att=FLAGS.use_layer_norm_att,
      use_layer_norm_ffn=FLAGS.use_layer_norm_ffn,
      use_spec_norm_plr=FLAGS.use_spec_norm_plr)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Search for checkpoints from their index file; then remove the index suffix.
  ensemble_filenames = tf.io.gfile.glob(
      os.path.join(FLAGS.checkpoint_dir, '**/*.index'))
  ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename)
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      if not tf.io.gfile.exists(filename):
        logits_list = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval[name]):
          inputs = next(test_iterator)
          features, _ = bert_utils.create_feature_and_label(
              inputs, feature_size)
          logits, covmat = model(features, training=False)
          logits = ed.layers.utils.mean_field_logits(
              logits, covmat, mean_field_factor=FLAGS.gp_mean_field_factor)
          logits_list.append(logits)

        logits_list = tf.concat(logits_list, axis=0)
        with tf.io.gfile.GFile(filename, 'w') as f:
          np.save(f, logits_list.numpy())
      percent = (m * num_datasets + (n + 1)) / (ensemble_size * num_datasets)
      message = ('{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                 'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size,
                                            n + 1, num_datasets))
      logging.info(message)

  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  for dataset_name in test_datasets:
    if dataset_name != 'clean':
      metrics.update({
          'test/nll_{}'.format(dataset_name):
              tf.keras.metrics.Mean(),
          'test/accuracy_{}'.format(dataset_name):
              tf.keras.metrics.SparseCategoricalAccuracy(),
          'test/ece_{}'.format(dataset_name):
              um.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
      })

  # Finally, define OOD metrics for the combined IND and OOD dataset.
  metrics.update({
      'test/auroc_all': tf.keras.metrics.AUC(curve='ROC'),
      'test/auprc_all': tf.keras.metrics.AUC(curve='PR')
  })

  # Evaluate model predictions.
  for n, (name, test_dataset) in enumerate(test_datasets.items()):
    logits_dataset = []
    for m in range(ensemble_size):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      with tf.io.gfile.GFile(filename, 'rb') as f:
        logits_dataset.append(np.load(f))

    # Leading axis indexes ensemble members (one cached array per member).
    logits_dataset = tf.convert_to_tensor(logits_dataset)
    test_iterator = iter(test_dataset)
    for step in range(steps_per_eval[name]):
      inputs = next(test_iterator)
      _, labels = bert_utils.create_feature_and_label(inputs, feature_size)
      logits = logits_dataset[:, (step * batch_size):((step + 1) * batch_size)]

      labels = tf.cast(labels, tf.int32)
      negative_log_likelihood = um.ensemble_cross_entropy(labels, logits)
      per_probs = tf.nn.softmax(logits)
      probs = tf.reduce_mean(per_probs, axis=0)
      if name == 'clean':
        gibbs_ce = um.gibbs_cross_entropy(labels, logits)
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
      else:
        metrics['test/nll_{}'.format(name)].update_state(
            negative_log_likelihood)
        metrics['test/accuracy_{}'.format(name)].update_state(labels, probs)
        metrics['test/ece_{}'.format(name)].update_state(labels, probs)

      # OOD-detection metrics are defined only for the combined split. Key on
      # the loop's own `name`: the earlier `dataset_name` loop variable is
      # stale here (always the last key, 'all'), which previously fed clean
      # and OOD-only batches into these metrics as well.
      if name == 'all':
        # NOTE(review): 150 is assumed to be the out-of-scope intent label
        # (i.e. num_classes - 1); confirm against the dataset definition.
        ood_labels = tf.cast(labels == 150, labels.dtype)
        # Low max-probability is used as the OOD score.
        ood_probs = 1. - tf.reduce_max(probs, axis=-1)
        metrics['test/auroc_{}'.format(name)].update_state(
            ood_labels, ood_probs)
        metrics['test/auprc_{}'.format(name)].update_state(
            ood_labels, ood_probs)

    message = ('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  total_results = {name: metric.result() for name, metric in metrics.items()}
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Evaluates an ensemble of BERT toxicity classifiers.

  Checkpoints are discovered under `FLAGS.checkpoint_dir` via their `.index`
  files, and the first `FLAGS.num_models` of them (in sorted order) form the
  ensemble. For every member and every test split ('ind', 'ood',
  'ood_identity') the member's logits are cached as `{dataset}_{member}.npy`
  in `FLAGS.output_dir`; already existing files are reused. The cached logits
  are then averaged in probability space (sigmoid) to report NLL, AUROC, AUPR,
  mean squared error against the labels ('brier'), ECE and oracle
  collaborative accuracy per split. Texts, labels, per-member logits and (for
  identity splits) identity labels are saved with
  `deterministic.save_prediction`.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If not running on exactly one GPU, or if `FLAGS.num_models`
      exceeds the number of checkpoints found.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')

  tf.random.set_seed(FLAGS.seed)
  logging.info('Model checkpoint will be saved at %s', FLAGS.output_dir)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  test_batch_size = batch_size
  data_buffer_size = batch_size * 10

  ind_dataset_builder = ds.WikipediaToxicityDataset(
      batch_size=batch_size,
      eval_batch_size=test_batch_size,
      data_dir=FLAGS.in_dataset_dir,
      shuffle_buffer_size=data_buffer_size)
  ood_dataset_builder = ds.CivilCommentsDataset(
      batch_size=batch_size,
      eval_batch_size=test_batch_size,
      data_dir=FLAGS.ood_dataset_dir,
      shuffle_buffer_size=data_buffer_size)
  ood_identity_dataset_builder = ds.CivilCommentsIdentitiesDataset(
      batch_size=batch_size,
      eval_batch_size=test_batch_size,
      data_dir=FLAGS.identity_dataset_dir,
      shuffle_buffer_size=data_buffer_size)

  dataset_builders = {
      'ind': ind_dataset_builder,
      'ood': ood_dataset_builder,
      'ood_identity': ood_identity_dataset_builder,
  }

  ds_info = ind_dataset_builder.info
  feature_size = _MAX_SEQ_LENGTH
  num_classes = ds_info['num_classes']  # Positive and negative classes.

  test_datasets = {}
  steps_per_eval = {}
  for dataset_name, dataset_builder in dataset_builders.items():
    test_datasets[dataset_name] = dataset_builder.build(split=base.Split.TEST)
    steps_per_eval[dataset_name] = (
        dataset_builder.info['num_test_examples'] // test_batch_size)

  logging.info('Building %s model', FLAGS.model_family)

  bert_config_dir, _ = deterministic.resolve_bert_ckpt_and_config_dir(
      FLAGS.bert_dir, FLAGS.bert_config_dir, FLAGS.bert_ckpt_dir)
  bert_config = bert_utils.create_config(bert_config_dir)
  model, _ = ub.models.BertBuilder(
      num_classes=num_classes,
      max_seq_length=feature_size,
      bert_config=bert_config)

  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Search for checkpoints from their index file; then remove the index suffix.
  ensemble_filenames = tf.io.gfile.glob(
      os.path.join(FLAGS.checkpoint_dir, '**/*.index'))
  # Sort so that the subset kept by `num_models`, and the member index that
  # names the cached prediction files, do not depend on glob ordering.
  ensemble_filenames = sorted(
      filename[:-6] for filename in ensemble_filenames)
  if FLAGS.num_models > len(ensemble_filenames):
    raise ValueError('Number of models to be included in the ensemble '
                     'should not exceed the total number of models in '
                     'the checkpoint_dir.')
  ensemble_filenames = ensemble_filenames[:FLAGS.num_models]
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename).assert_existing_objects_matched()
    for n, (dataset_name, test_dataset) in enumerate(test_datasets.items()):
      filename = '{dataset}_{member}.npy'.format(dataset=dataset_name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      if not tf.io.gfile.exists(filename):
        logits = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval[dataset_name]):
          try:
            inputs = next(test_iterator)
          except StopIteration:
            # The dataset ran out early; further `next` calls cannot succeed.
            break
          features, _, _ = deterministic.create_feature_and_label(inputs)
          logits.append(model(features, training=False))

        logits = tf.concat(logits, axis=0)
        with tf.io.gfile.GFile(filename, 'w') as f:
          np.save(f, logits.numpy())
      percent = (m * num_datasets + (n + 1)) / (ensemble_size * num_datasets)
      message = ('{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                 'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size,
                                            n + 1, num_datasets))
      logging.info(message)

  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/auroc': tf.keras.metrics.AUC(curve='ROC'),
      'test/aupr': tf.keras.metrics.AUC(curve='PR'),
      'test/brier': tf.keras.metrics.MeanSquaredError(),
      'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  for fraction in FLAGS.fractions:
    metrics.update({
        'test_collab_acc/collab_acc_{}'.format(fraction):
            um.OracleCollaborativeAccuracy(
                fraction=float(fraction), num_bins=FLAGS.num_bins)
    })
  for dataset_name in test_datasets:
    if dataset_name != 'ind':
      metrics.update({
          'test/nll_{}'.format(dataset_name):
              tf.keras.metrics.Mean(),
          'test/auroc_{}'.format(dataset_name):
              tf.keras.metrics.AUC(curve='ROC'),
          'test/aupr_{}'.format(dataset_name):
              tf.keras.metrics.AUC(curve='PR'),
          'test/brier_{}'.format(dataset_name):
              tf.keras.metrics.MeanSquaredError(),
          'test/ece_{}'.format(dataset_name):
              um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
      })
      for fraction in FLAGS.fractions:
        metrics.update({
            'test_collab_acc/collab_acc_{}_{}'.format(fraction, dataset_name):
                um.OracleCollaborativeAccuracy(
                    fraction=float(fraction), num_bins=FLAGS.num_bins)
        })

  # Evaluate model predictions.
  for n, (dataset_name, test_dataset) in enumerate(test_datasets.items()):
    logits_dataset = []
    for m in range(ensemble_size):
      filename = '{dataset}_{member}.npy'.format(dataset=dataset_name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      with tf.io.gfile.GFile(filename, 'rb') as f:
        logits_dataset.append(np.load(f))

    # Leading axis indexes ensemble members (one cached array per member).
    logits_dataset = tf.convert_to_tensor(logits_dataset)
    test_iterator = iter(test_dataset)
    texts_list = []
    logits_list = []
    labels_list = []
    # Use dict to collect additional labels specified by additional label names.
    # Here we use `OrderedDict` to get consistent ordering for this dict so
    # we can retrieve the predictions for each identity labels in Colab.
    additional_labels_dict = collections.OrderedDict()
    for step in range(steps_per_eval[dataset_name]):
      try:
        inputs = next(test_iterator)  # type: Mapping[Text, tf.Tensor]
      except StopIteration:
        # The dataset ran out early; further `next` calls cannot succeed.
        break
      _, labels, additional_labels = (
          deterministic.create_feature_and_label(inputs))
      logits = logits_dataset[:, (step * batch_size):((step + 1) * batch_size)]
      loss_logits = tf.squeeze(logits, axis=-1)
      negative_log_likelihood = um.ensemble_cross_entropy(
          labels, loss_logits, binary=True)

      per_probs = tf.nn.sigmoid(logits)
      probs = tf.reduce_mean(per_probs, axis=0)
      # Cast labels to discrete for ECE computation
      ece_labels = tf.cast(labels > FLAGS.ece_label_threshold, tf.float32)
      # ECE expects per-class probabilities: [P(negative), P(positive)].
      ece_probs = tf.concat([1. - probs, probs], axis=1)
      auc_probs = tf.squeeze(probs, axis=1)

      texts_list.append(inputs['input_ids'])
      logits_list.append(logits)
      labels_list.append(labels)
      if 'identity' in dataset_name:
        for identity_label_name in deterministic._IDENTITY_LABELS:  # pylint: disable=protected-access
          additional_labels_dict.setdefault(identity_label_name, []).append(
              additional_labels[identity_label_name].numpy())

      if dataset_name == 'ind':
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/auroc'].update_state(labels, auc_probs)
        metrics['test/aupr'].update_state(labels, auc_probs)
        metrics['test/brier'].update_state(labels, auc_probs)
        metrics['test/ece'].update_state(ece_labels, ece_probs)
        for fraction in FLAGS.fractions:
          metrics['test_collab_acc/collab_acc_{}'.format(
              fraction)].update_state(ece_labels, ece_probs)
      else:
        metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        metrics['test/auroc_{}'.format(dataset_name)].update_state(
            labels, auc_probs)
        metrics['test/aupr_{}'.format(dataset_name)].update_state(
            labels, auc_probs)
        metrics['test/brier_{}'.format(dataset_name)].update_state(
            labels, auc_probs)
        metrics['test/ece_{}'.format(dataset_name)].update_state(
            ece_labels, ece_probs)
        for fraction in FLAGS.fractions:
          metrics['test_collab_acc/collab_acc_{}_{}'.format(
              fraction, dataset_name)].update_state(ece_labels, ece_probs)

    texts_all = tf.concat(texts_list, axis=0)
    # Logits keep the member axis first, so batches are joined along axis 1.
    logits_all = tf.concat(logits_list, axis=1)
    labels_all = tf.concat(labels_list, axis=0)
    additional_labels_all = []
    if additional_labels_dict:
      additional_labels_all = list(additional_labels_dict.values())

    deterministic.save_prediction(
        texts_all.numpy(),
        path=os.path.join(FLAGS.output_dir, 'texts_{}'.format(dataset_name)))
    deterministic.save_prediction(
        labels_all.numpy(),
        path=os.path.join(FLAGS.output_dir, 'labels_{}'.format(dataset_name)))
    deterministic.save_prediction(
        logits_all.numpy(),
        path=os.path.join(FLAGS.output_dir, 'logits_{}'.format(dataset_name)))
    if 'identity' in dataset_name:
      deterministic.save_prediction(
          np.array(additional_labels_all),
          path=os.path.join(FLAGS.output_dir,
                            'additional_labels_{}'.format(dataset_name)))

    message = ('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  total_results = {name: metric.result() for name, metric in metrics.items()}
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Builds and evaluates a greedy ensemble of Wide ResNet-28-10 models.

  The checkpoints returned by `parse_checkpoint_dir(FLAGS.checkpoint_dir)` form
  the model pool. Validation-set logits of every pool member are fed to
  `greedy_selection`, which adds members with replacement until
  `FLAGS.ensemble_size` distinct models are selected (or the objective stops
  improving). Test logits of each distinct selected member are cached as
  `{dataset}_{member_id}.npy` in `FLAGS.output_dir`; already existing files are
  reused. The ensemble (duplicates included, so repeated members carry more
  weight) is evaluated on the clean test set (NLL, Gibbs cross entropy,
  accuracy, ECE, per-member metrics, pairwise diversity) and on every
  corruption split.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If not running on exactly one GPU.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  ds_info = tfds.builder(FLAGS.dataset).info
  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_eval = ds_info.splits['test'].num_examples // batch_size
  num_classes = ds_info.features['label'].num_classes

  dataset = ub.datasets.get(
      FLAGS.dataset, split=tfds.Split.TEST).load(batch_size=batch_size)
  # The validation split is the part of the train split not used for training.
  validation_percent = 1. - FLAGS.train_proportion
  val_dataset = ub.datasets.get(
      dataset_name=FLAGS.dataset,
      split=tfds.Split.VALIDATION,
      validation_percent=validation_percent,
      drop_remainder=False).load(batch_size=batch_size)
  steps_per_val_eval = int(ds_info.splits['train'].num_examples *
                           validation_percent) // batch_size

  test_datasets = {'clean': dataset}
  extra_kwargs = {}
  if FLAGS.dataset == 'cifar100':
    extra_kwargs['data_dir'] = FLAGS.cifar100_c_path
  corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
  for corruption_type in corruption_types:
    for severity in range(1, 6):
      dataset = ub.datasets.get(
          f'{FLAGS.dataset}_corrupted',
          corruption_type=corruption_type,
          severity=severity,
          split=tfds.Split.TEST,
          **extra_kwargs).load(batch_size=batch_size)
      test_datasets[f'{corruption_type}_{severity}'] = dataset

  model = ub.models.wide_resnet(
      input_shape=ds_info.features['image'].shape,
      depth=28,
      width_multiplier=10,
      num_classes=num_classes,
      l2=0.,
      version=2)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Search for checkpoints
  ensemble_filenames = parse_checkpoint_dir(FLAGS.checkpoint_dir)

  model_pool_size = len(ensemble_filenames)
  logging.info('Model pool size: %s', model_pool_size)
  logging.info('Ensemble size: %s', FLAGS.ensemble_size)
  logging.info('Ensemble number of weights: %s',
               FLAGS.ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Compute the logits on the validation set
  val_logits, val_labels = [], []
  for m, ensemble_filename in enumerate(ensemble_filenames):
    # Enforce memory clean-up
    tf.keras.backend.clear_session()
    checkpoint.restore(ensemble_filename)
    val_iterator = iter(val_dataset)
    val_logits_m = []
    for _ in range(steps_per_val_eval):
      inputs = next(val_iterator)
      features = inputs['features']
      labels = inputs['labels']
      val_logits_m.append(model(features, training=False))
      # Labels are identical for every model; collect them only once.
      if m == 0:
        val_labels.append(labels)

    val_logits.append(tf.concat(val_logits_m, axis=0))
    if m == 0:
      val_labels = tf.concat(val_labels, axis=0)

    percent = (m + 1.) / model_pool_size
    message = ('{:.1%} completion for prediction on validation set: '
               'model {:d}/{:d}.'.format(percent, m + 1, model_pool_size))
    logging.info(message)

  selected_members, val_acc, val_nll = greedy_selection(
      val_logits, val_labels, FLAGS.ensemble_size, FLAGS.greedy_objective)
  # Sorted for a deterministic member -> metric-index mapping.
  unique_selected_members = sorted(set(selected_members))
  message = ('Members selected by greedy procedure: {} (with {} unique '
             'member(s))\n\t{}').format(
                 selected_members, len(unique_selected_members),
                 [ensemble_filenames[i] for i in selected_members])
  logging.info(message)

  val_metrics = {
      'val/accuracy': tf.keras.metrics.Mean(),
      'val/negative_log_likelihood': tf.keras.metrics.Mean()
  }
  val_metrics['val/accuracy'].update_state(val_acc)
  val_metrics['val/negative_log_likelihood'].update_state(val_nll)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, member_id in enumerate(unique_selected_members):
    ensemble_filename = ensemble_filenames[member_id]
    checkpoint.restore(ensemble_filename)
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=member_id)
      filename = os.path.join(FLAGS.output_dir, filename)
      if not tf.io.gfile.exists(filename):
        logits = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval):
          features = next(test_iterator)['features']  # pytype: disable=unsupported-operands
          logits.append(model(features, training=False))

        logits = tf.concat(logits, axis=0)
        with tf.io.gfile.GFile(filename, 'w') as f:
          np.save(f, logits.numpy())
      numerator = m * num_datasets + (n + 1)
      denominator = len(unique_selected_members) * num_datasets
      percent = numerator / denominator
      message = ('{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                 'Dataset {:d}/{:d}'.format(percent, m + 1,
                                            len(unique_selected_members),
                                            n + 1, num_datasets))
      logging.info(message)

  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': rm.metrics.ExpectedCalibrationError(
          num_bins=FLAGS.num_bins),
  }
  metrics.update(val_metrics)
  corrupt_metrics = {}
  for name in test_datasets:
    corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
    corrupt_metrics['test/accuracy_{}'.format(name)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
    corrupt_metrics['test/ece_{}'.format(name)] = (
        rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
  # Per-member metrics are indexed by position in `unique_selected_members`.
  for i in range(len(unique_selected_members)):
    metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
    metrics['test/accuracy_member_{}'.format(i)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
  test_diversity = {
      'test/disagreement': tf.keras.metrics.Mean(),
      'test/average_kl': tf.keras.metrics.Mean(),
      'test/cosine_similarity': tf.keras.metrics.Mean(),
  }
  metrics.update(test_diversity)

  # Evaluate model predictions.
  for n, (name, test_dataset) in enumerate(test_datasets.items()):
    logits_dataset = []
    # Duplicated members are loaded once per occurrence so that they are
    # weighted accordingly in the ensemble average.
    for member_id in selected_members:
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=member_id)
      filename = os.path.join(FLAGS.output_dir, filename)
      with tf.io.gfile.GFile(filename, 'rb') as f:
        logits_dataset.append(np.load(f))

    logits_dataset = tf.convert_to_tensor(logits_dataset)
    test_iterator = iter(test_dataset)
    for step in range(steps_per_eval):
      labels = next(test_iterator)['labels']  # pytype: disable=unsupported-operands
      logits = logits_dataset[:, (step * batch_size):((step + 1) * batch_size)]
      labels = tf.cast(labels, tf.int32)
      negative_log_likelihood = um.ensemble_cross_entropy(labels, logits)
      per_probs = tf.nn.softmax(logits)
      probs = tf.reduce_mean(per_probs, axis=0)
      if name == 'clean':
        gibbs_ce = um.gibbs_cross_entropy(labels, logits)
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].add_batch(probs, label=labels)

        # `per_probs` is indexed by position in `selected_members`, which may
        # contain duplicates (greedy selection is with replacement), whereas
        # the per-member metrics are indexed by position in
        # `unique_selected_members`. The two indices must not be mixed: e.g.
        # selected_members = [2, 2, 3] gives selected_members.index(3) == 2,
        # but only 'test/nll_member_0' and 'test/nll_member_1' exist.
        for i, member_id in enumerate(unique_selected_members):
          member_probs = per_probs[selected_members.index(member_id)]
          member_loss = tf.keras.losses.sparse_categorical_crossentropy(
              labels, member_probs)
          metrics['test/nll_member_{}'.format(i)].update_state(member_loss)
          metrics['test/accuracy_member_{}'.format(i)].update_state(
              labels, member_probs)
        diversity_results = um.average_pairwise_diversity(
            per_probs, len(per_probs))
        for k, v in diversity_results.items():
          test_diversity['test/' + k].update_state(v)
      else:
        corrupt_metrics['test/nll_{}'.format(name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(name)].add_batch(
            probs, label=labels)

    message = ('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                     corruption_types)
  total_results = {name: metric.result() for name, metric in metrics.items()}
  total_results.update(corrupt_results)
  # Metrics from Robustness Metrics (like ECE) will return a dict with a
  # single key/value, instead of a scalar.
  total_results = {
      k: (list(v.values())[0] if isinstance(v, dict) else v)
      for k, v in total_results.items()
  }
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Evaluates an ensemble of SNGP Wide ResNet-28-10 checkpoints.

  Each checkpoint under `FLAGS.checkpoint_dir` (found via its `.index` file) is
  restored in turn. Its mean-field adjusted logits on the clean test split and
  on every corruption/severity split are cached in `FLAGS.output_dir` as
  `{dataset}_{member}.npy`; files that already exist are reused. The cached
  logits of all members are then stacked per split and averaged in probability
  space to report NLL, Gibbs cross entropy, accuracy and ECE on the clean
  split, and NLL, accuracy and ECE aggregated over the corrupted splits.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If not running on exactly one GPU.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  info = tfds.builder(FLAGS.dataset).info
  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_eval = info.splits['test'].num_examples // batch_size
  num_classes = info.features['label'].num_classes

  def _load_test_split(dataset_name, **kwargs):
    """Loads the TEST split of `dataset_name`, batched by `batch_size`."""
    builder = ub.datasets.get(dataset_name, split=tfds.Split.TEST, **kwargs)
    return builder.load(batch_size=batch_size)

  test_datasets = {'clean': _load_test_split(FLAGS.dataset)}
  corrupted_kwargs = (
      {'data_dir': FLAGS.cifar100_c_path}
      if FLAGS.dataset == 'cifar100' else {})
  corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
  for corruption_type in corruption_types:
    for severity in range(1, 6):
      test_datasets[f'{corruption_type}_{severity}'] = _load_test_split(
          f'{FLAGS.dataset}_corrupted',
          corruption_type=corruption_type,
          severity=severity,
          **corrupted_kwargs)

  model = ub.models.wide_resnet_sngp(
      input_shape=info.features['image'].shape,
      batch_size=FLAGS.per_core_batch_size,
      depth=28,
      width_multiplier=10,
      num_classes=num_classes,
      l2=0.,
      use_mc_dropout=FLAGS.use_mc_dropout,
      use_filterwise_dropout=FLAGS.use_filterwise_dropout,
      dropout_rate=FLAGS.dropout_rate,
      use_gp_layer=FLAGS.use_gp_layer,
      gp_input_dim=FLAGS.gp_input_dim,
      gp_hidden_dim=FLAGS.gp_hidden_dim,
      gp_scale=FLAGS.gp_scale,
      gp_bias=FLAGS.gp_bias,
      gp_input_normalization=FLAGS.gp_input_normalization,
      gp_random_feature_type=FLAGS.gp_random_feature_type,
      gp_cov_discount_factor=FLAGS.gp_cov_discount_factor,
      gp_cov_ridge_penalty=FLAGS.gp_cov_ridge_penalty,
      use_spec_norm=FLAGS.use_spec_norm,
      spec_norm_iteration=FLAGS.spec_norm_iteration,
      spec_norm_bound=FLAGS.spec_norm_bound)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Checkpoints are located through their `.index` files; dropping that
  # suffix yields the checkpoint prefix that `restore` expects.
  index_suffix = '.index'
  index_files = tf.io.gfile.glob(
      os.path.join(FLAGS.checkpoint_dir, '**/*' + index_suffix))
  member_paths = [path[:-len(index_suffix)] for path in index_files]
  ensemble_size = len(member_paths)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(member_paths))
  checkpoint = tf.train.Checkpoint(model=model)

  def _cache_path(split_name, member_index):
    """Returns the .npy file holding one member's logits on one split."""
    basename = '{dataset}_{member}.npy'.format(
        dataset=split_name, member=member_index)
    return os.path.join(FLAGS.output_dir, basename)

  def _predict_split(split_dataset):
    """Runs `model` over `steps_per_eval` batches; returns mean-field logits."""
    batch_outputs = []
    iterator = iter(split_dataset)
    for _ in range(steps_per_eval):
      features = next(iterator)['features']  # pytype: disable=unsupported-operands
      output = model(features, training=False)
      if isinstance(output, tuple):
        # The GP output layer yields a (logits, covariance matrix) pair.
        batch_logits, batch_covmat = output
      else:
        batch_logits = output
        batch_covmat = tf.eye(FLAGS.per_core_batch_size)
      batch_outputs.append(
          ed.layers.utils.mean_field_logits(
              batch_logits, batch_covmat, FLAGS.gp_mean_field_factor_ensemble))
    return tf.concat(batch_outputs, axis=0)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  total_jobs = ensemble_size * num_datasets
  for member_index, member_path in enumerate(member_paths):
    checkpoint.restore(member_path)
    for split_index, (split_name, split_dataset) in enumerate(
        test_datasets.items()):
      cache_file = _cache_path(split_name, member_index)
      if not tf.io.gfile.exists(cache_file):
        member_logits = _predict_split(split_dataset)
        with tf.io.gfile.GFile(cache_file, 'w') as f:
          np.save(f, member_logits.numpy())
      jobs_done = member_index * num_datasets + (split_index + 1)
      logging.info(
          '{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
          'Dataset {:d}/{:d}'.format(jobs_done / total_jobs, member_index + 1,
                                     ensemble_size, split_index + 1,
                                     num_datasets))

  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  corrupt_metrics = {}
  for split_name in test_datasets:
    corrupt_metrics[f'test/nll_{split_name}'] = tf.keras.metrics.Mean()
    corrupt_metrics[f'test/accuracy_{split_name}'] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
    corrupt_metrics[f'test/ece_{split_name}'] = (
        um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

  # Evaluate model predictions.
  for split_index, (split_name, split_dataset) in enumerate(
      test_datasets.items()):
    member_arrays = []
    for member_index in range(ensemble_size):
      with tf.io.gfile.GFile(_cache_path(split_name, member_index), 'rb') as f:
        member_arrays.append(np.load(f))
    # Leading axis indexes ensemble members.
    stacked_logits = tf.convert_to_tensor(member_arrays)

    iterator = iter(split_dataset)
    for step in range(steps_per_eval):
      labels = next(iterator)['labels']  # pytype: disable=unsupported-operands
      start = step * batch_size
      logits = stacked_logits[:, start:start + batch_size]
      labels = tf.cast(labels, tf.int32)
      nll = um.ensemble_cross_entropy(labels, logits)
      probs = tf.reduce_mean(tf.nn.softmax(logits), axis=0)

      if split_name != 'clean':
        corrupt_metrics[f'test/nll_{split_name}'].update_state(nll)
        corrupt_metrics[f'test/accuracy_{split_name}'].update_state(
            labels, probs)
        corrupt_metrics[f'test/ece_{split_name}'].update_state(labels, probs)
        continue

      gibbs_ce = um.gibbs_cross_entropy(labels, logits)
      metrics['test/negative_log_likelihood'].update_state(nll)
      metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
      metrics['test/accuracy'].update_state(labels, probs)
      metrics['test/ece'].update_state(labels, probs)

    logging.info('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (split_index + 1) / num_datasets, split_index + 1, num_datasets))

  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                     corruption_types)
  total_results = {key: metric.result() for key, metric in metrics.items()}
  total_results.update(corrupt_results)
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Evaluates a deep ensemble of CLINC intent-detection models.

  Restores every checkpoint whose `.index` file is found under
  `FLAGS.checkpoint_dir`, caches each member's test-split logits as
  `{dataset}_{member}.npy` files in `FLAGS.output_dir` (files that already
  exist are reused), then reloads the cached logits and logs ensemble NLL,
  Gibbs cross entropy (clean split only), accuracy and ECE for every test
  split, plus OOD AUROC/AUPRC on the combined 'all' split.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If GPU use is disabled, more than one core is requested, the
      model family is neither TextCNN nor BERT, or no checkpoints are found.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  # Evaluation split name -> dataset builder. 'all' is the combined in-domain
  # and OOD split used for the OOD-detection metrics below.
  dataset_builders = {
      split_name: ub.datasets.ClincIntentDetectionDataset(
          batch_size=FLAGS.per_core_batch_size,
          eval_batch_size=FLAGS.per_core_batch_size,
          data_dir=FLAGS.data_dir,
          data_mode=data_mode)
      for split_name, data_mode in (
          ('clean', 'ind'), ('ood', 'ood'), ('all', 'all'))
  }
  ind_dataset_builder = dataset_builders['clean']

  ds_info = ind_dataset_builder.info
  feature_size = ds_info['feature_size']
  # num_classes is number of valid intents plus out-of-scope intent
  num_classes = ds_info['num_classes'] + 1
  # vocab_size is total number of valid tokens plus the out-of-vocabulary token.
  vocab_size = ind_dataset_builder.tokenizer.num_words + 1

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores

  test_datasets = {}
  steps_per_eval = {}
  for dataset_name, dataset_builder in dataset_builders.items():
    test_datasets[dataset_name] = dataset_builder.build(
        split=ub.datasets.base.Split.TEST)
    # Floor division: a trailing partial batch is not evaluated.
    steps_per_eval[dataset_name] = (
        dataset_builder.info['num_test_examples'] // batch_size)

  if FLAGS.model_family.lower() == 'textcnn':
    model = cnn_model.textcnn(filter_sizes=FLAGS.filter_sizes,
                              num_filters=FLAGS.num_filters,
                              num_classes=num_classes,
                              feature_size=feature_size,
                              vocab_size=vocab_size,
                              embed_size=FLAGS.embedding_size,
                              dropout_rate=FLAGS.dropout_rate,
                              l2=FLAGS.l2)
  elif FLAGS.model_family.lower() == 'bert':
    bert_config_dir, _ = deterministic.resolve_bert_ckpt_and_config_dir(
        FLAGS.bert_dir, FLAGS.bert_config_dir, FLAGS.bert_ckpt_dir)
    bert_config = bert_utils.create_config(bert_config_dir)
    model, _ = ub.models.BertBuilder(num_classes=num_classes,
                                     max_seq_length=feature_size,
                                     bert_config=bert_config)
  else:
    raise ValueError(
        'model_family ({}) can only be TextCNN or BERT.'.format(
            FLAGS.model_family))

  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Search for checkpoints from their index file; then remove the index suffix.
  # Sorted so that member index m (and thus the cached `{dataset}_{m}.npy`
  # file) refers to the same checkpoint from one run to the next.
  ensemble_filenames = sorted(
      filename[:-len('.index')] for filename in tf.io.gfile.glob(
          os.path.join(FLAGS.checkpoint_dir, '**/*.index')))
  ensemble_size = len(ensemble_filenames)
  if not ensemble_size:
    raise ValueError(
        'No checkpoints found in {}.'.format(FLAGS.checkpoint_dir))
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  def logits_path(split_name, member):
    """Returns the path of the cached logits of `member` on `split_name`."""
    return os.path.join(
        FLAGS.output_dir,
        '{dataset}_{member}.npy'.format(dataset=split_name, member=member))

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename)
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
      filename = logits_path(name, m)
      # Reuse logits cached by a previous run.
      if not tf.io.gfile.exists(filename):
        logits = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval[name]):
          inputs = next(test_iterator)
          features, _ = deterministic.create_feature_and_label(
              inputs, feature_size, model_family=FLAGS.model_family)
          logits.append(model(features, training=False))

        logits = tf.concat(logits, axis=0)
        # np.save emits binary data; open in binary mode to match the 'rb'
        # read below.
        with tf.io.gfile.GFile(filename, 'wb') as f:
          np.save(f, logits.numpy())
      percent = (m * num_datasets + (n + 1)) / (ensemble_size * num_datasets)
      message = (
          '{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
          'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size, n + 1,
                                     num_datasets))
      logging.info(message)

  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  for dataset_name in test_datasets:
    if dataset_name != 'clean':
      metrics.update({
          'test/nll_{}'.format(dataset_name): tf.keras.metrics.Mean(),
          'test/accuracy_{}'.format(dataset_name):
              tf.keras.metrics.SparseCategoricalAccuracy(),
          'test/ece_{}'.format(dataset_name):
              um.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
      })
  # Finally, define OOD metrics for the combined IND and OOD dataset.
  metrics.update({
      'test/auroc_all': tf.keras.metrics.AUC(curve='ROC'),
      'test/auprc_all': tf.keras.metrics.AUC(curve='PR')
  })

  # Evaluate model predictions.
  for n, (name, test_dataset) in enumerate(test_datasets.items()):
    # Stack the cached logits of every member along a new leading axis.
    logits_dataset = []
    for m in range(ensemble_size):
      with tf.io.gfile.GFile(logits_path(name, m), 'rb') as f:
        logits_dataset.append(np.load(f))
    logits_dataset = tf.convert_to_tensor(logits_dataset)

    test_iterator = iter(test_dataset)
    for step in range(steps_per_eval[name]):
      inputs = next(test_iterator)
      _, labels = deterministic.create_feature_and_label(
          inputs, feature_size, model_family=FLAGS.model_family)
      # NOTE(review): assumes the test iterator yields examples in the same
      # order as during prediction, so this slice lines up with `labels`.
      logits = logits_dataset[:, (step * batch_size):((step + 1) * batch_size)]
      labels = tf.cast(labels, tf.int32)
      negative_log_likelihood = um.ensemble_cross_entropy(labels, logits)
      per_probs = tf.nn.softmax(logits)
      # Ensemble prediction: mean of the members' probabilities.
      probs = tf.reduce_mean(per_probs, axis=0)
      if name == 'clean':
        gibbs_ce = um.gibbs_cross_entropy(labels, logits)
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
      else:
        metrics['test/nll_{}'.format(name)].update_state(
            negative_log_likelihood)
        metrics['test/accuracy_{}'.format(name)].update_state(
            labels, probs)
        metrics['test/ece_{}'.format(name)].update_state(labels, probs)

      # OOD-detection metrics are computed only on the combined split; key on
      # the current split `name`, not a variable left over from another loop.
      if name == 'all':
        # NOTE(review): 150 is assumed to be the out-of-scope intent label
        # (i.e. ds_info['num_classes']) — confirm against the dataset.
        ood_labels = tf.cast(labels == 150, labels.dtype)
        # OOD score: one minus the ensemble's maximum class probability.
        ood_probs = 1. - tf.reduce_max(probs, axis=-1)
        metrics['test/auroc_all'].update_state(ood_labels, ood_probs)
        metrics['test/auprc_all'].update_state(ood_labels, ood_probs)

    message = (
        '{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
            (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  total_results = {name: metric.result() for name, metric in metrics.items()}
  logging.info('Metrics: %s', total_results)