def step_fn(inputs):
    """Per-replica evaluation step for one batch of `dataset_name`.

    The images are tiled once per ensemble member and fed to the model
    together with the mean lambdas repeated along axis 0. Per-member, Gibbs
    and diversity metrics are only recorded for the 'clean' dataset; the
    ensemble NLL, accuracy and ECE go to `metrics` for 'clean' and to
    `corrupt_metrics` for any other dataset name.

    Args:
      inputs: An `(images, labels)` pair.
    """
    images, labels = inputs
    num_members = FLAGS.ensemble_size
    is_clean = dataset_name == 'clean'

    # Only the images are tiled here; the labels stay untiled for the
    # ensemble metrics.
    tiled_images = tf.tile(images, [num_members, 1, 1, 1])

    # Mean lambdas, repeated `per_core_batch_size` times along axis 0.
    mean_lambdas = log_uniform_mean(lambda_parameters)
    model_lambdas = tf.repeat(mean_lambdas, per_core_batch_size, axis=0)

    # Forward pass on the tiled batch.
    logits = model([tiled_images, model_lambdas], training=False)
    if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
    tiled_probs = tf.nn.softmax(logits)
    member_probs_list = tf.split(
        tiled_probs, num_or_size_splits=num_members, axis=0)

    if is_clean:
        # Per-member performance.
        for member, member_probs in enumerate(member_probs_list):
            nll_key = 'test/nll_member_{}'.format(member)
            accuracy_key = 'test/accuracy_member_{}'.format(member)
            member_nll = tf.keras.losses.sparse_categorical_crossentropy(
                labels, member_probs)
            metrics[nll_key].update_state(member_nll)
            metrics[accuracy_key].update_state(labels, member_probs)

        # Gibbs performance (average of the per-member performance): here
        # the labels are tiled to line up with the tiled logits.
        tiled_labels = tf.tile(labels, [num_members])
        gibbs_losses = tf.keras.losses.sparse_categorical_crossentropy(
            tiled_labels, logits, from_logits=True)
        metrics['test/gibbs_nll'].update_state(tf.reduce_mean(gibbs_losses))
        metrics['test/gibbs_accuracy'].update_state(tiled_labels, tiled_probs)

    # Ensemble performance.
    ensemble_nll = ensemble_crossentropy(labels, logits, num_members)
    ensemble_probs = tf.reduce_mean(member_probs_list, axis=0)

    if is_clean:
        metrics['test/negative_log_likelihood'].update_state(ensemble_nll)
        metrics['test/accuracy'].update_state(labels, ensemble_probs)
        metrics['test/ece'].update_state(labels, ensemble_probs)

        # Diversity between the members' predictions.
        stacked_member_probs = tf.stack(member_probs_list, axis=0)
        diversity = um.average_pairwise_diversity(
            stacked_member_probs, num_members)
        for diversity_name, diversity_value in diversity.items():
            metrics['test/' + diversity_name].update_state(diversity_value)
    else:
        nll_key = 'test/nll_{}'.format(dataset_name)
        accuracy_key = 'test/accuracy_{}'.format(dataset_name)
        ece_key = 'test/ece_{}'.format(dataset_name)
        corrupt_metrics[nll_key].update_state(ensemble_nll)
        corrupt_metrics[accuracy_key].update_state(labels, ensemble_probs)
        corrupt_metrics[ece_key].update_state(labels, ensemble_probs)
def step_fn(inputs):
    """Per-replica evaluation step for one batch of `dataset_name`.

    Tiles the images once per ensemble member, runs the model and updates the
    member-level and ensemble-level metrics. Metrics go to `metrics` for the
    'clean' dataset and to `corrupt_metrics` for every other dataset except
    'confidence_validation', for which no metric is updated.

    Args:
      inputs: An `(images, labels)` pair.

    Returns:
      For 'confidence_validation', the tuple `(stacked member probabilities,
      labels)`; otherwise `None`.
    """
    images, labels = inputs
    num_members = FLAGS.ensemble_size
    is_clean = dataset_name == 'clean'
    is_confidence_validation = dataset_name == 'confidence_validation'

    tiled_images = tf.tile(images, [num_members, 1, 1, 1])
    logits = model(tiled_images, training=False)
    if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
    tiled_probs = tf.nn.softmax(logits)

    if is_clean:
        # Expose the member axis: [ensemble_size, -1, *trailing dims].
        member_shape = tf.concat(
            [[num_members, -1], tiled_probs.shape[1:]], 0)
        diversity = um.average_pairwise_diversity(
            tf.reshape(tiled_probs, member_shape), num_members)
        for diversity_name, diversity_value in diversity.items():
            test_diversity['test/' + diversity_name].update_state(
                diversity_value)

    member_probs_list = tf.split(
        tiled_probs, num_or_size_splits=num_members, axis=0)
    ensemble_probs = tf.reduce_mean(member_probs_list, axis=0)
    ensemble_nll = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            labels, ensemble_probs))

    # Member-level metrics.
    for member, member_probs in enumerate(member_probs_list):
        if is_clean:
            member_nll = tf.keras.losses.sparse_categorical_crossentropy(
                labels, member_probs)
            metrics['test/nll_member_{}'.format(member)].update_state(
                member_nll)
            metrics['test/accuracy_member_{}'.format(member)].update_state(
                labels, member_probs)
            metrics['test/member_accuracy_mean'].update_state(
                labels, member_probs)
            metrics['test/member_ece_mean'].update_state(
                labels, member_probs)
        elif not is_confidence_validation:
            member_accuracy_key = 'test/member_acc_mean_{}'.format(
                dataset_name)
            member_ece_key = 'test/member_ece_mean_{}'.format(dataset_name)
            corrupt_metrics[member_accuracy_key].update_state(
                labels, member_probs)
            corrupt_metrics[member_ece_key].update_state(
                labels, member_probs)

    # Ensemble-level metrics.
    if is_clean:
        metrics['test/negative_log_likelihood'].update_state(ensemble_nll)
        metrics['test/accuracy'].update_state(labels, ensemble_probs)
        metrics['test/ece'].update_state(labels, ensemble_probs)
    elif not is_confidence_validation:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            ensemble_nll)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, ensemble_probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
            labels, ensemble_probs)

    if is_confidence_validation:
        return tf.stack(member_probs_list, 0), labels
def step_fn(inputs):
    """Per-replica evaluation step for one batch of `dataset_name`.

    Runs the model `FLAGS.num_eval_samples` times on the (optionally tiled)
    images, marginalises the predictions over samples and ensemble members,
    and records NLL, KL, ELBO, accuracy and ECE. Metrics go to `metrics` for
    the 'clean' dataset and to `corrupt_metrics` otherwise. Diversity and
    per-member metrics are only recorded for 'clean' when
    `FLAGS.ensemble_size > 1`.

    Args:
      inputs: An `(images, labels)` pair.
    """
    images, labels = inputs
    num_members = FLAGS.ensemble_size
    num_samples = FLAGS.num_eval_samples

    if num_members > 1:
        images = tf.tile(images, [num_members, 1, 1, 1])

    # One forward pass per evaluation sample, reshaped to
    # [num_samples, num_members, -1, NUM_CLASSES].
    sampled_logits = [
        model(images, training=False) for _ in range(num_samples)
    ]
    logits = tf.reshape(
        sampled_logits, [num_samples, num_members, -1, NUM_CLASSES])
    if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
    all_probs = tf.nn.softmax(logits)
    marginal_probs = tf.math.reduce_mean(all_probs, axis=[0, 1])

    # Negative log marginal likelihood computed in a numerically-stable way:
    # log-sum-exp over the sample and member axes plus the log of the number
    # of mixture components.
    broadcast_labels = tf.broadcast_to(
        labels, [num_samples, num_members, labels.shape[0]])
    log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
        broadcast_labels, logits, from_logits=True)
    log_num_components = tf.math.log(float(num_samples * num_members))
    negative_log_likelihood = tf.reduce_mean(
        -tf.reduce_logsumexp(log_likelihoods, axis=[0, 1])
        + log_num_components)

    l2_loss = compute_l2_loss(model)
    kl = sum(model.losses) / IMAGENET_VALIDATION_IMAGES
    elbo = -(negative_log_likelihood + l2_loss + kl)

    if dataset_name != 'clean':
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/kl_{}'.format(dataset_name)].update_state(kl)
        corrupt_metrics['test/elbo_{}'.format(dataset_name)].update_state(
            elbo)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, marginal_probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
            labels, marginal_probs)
        return

    if num_members > 1:
        # Marginalise over samples only, keeping the member axis.
        member_probs_all = tf.reduce_mean(all_probs, axis=0)
        diversity = um.average_pairwise_diversity(
            member_probs_all, num_members)
        for diversity_name, diversity_value in diversity.items():
            test_diversity['test/' + diversity_name].update_state(
                diversity_value)

        for member in range(num_members):
            member_probs = member_probs_all[member]
            member_nll = tf.keras.losses.sparse_categorical_crossentropy(
                labels, member_probs)
            metrics['test/nll_member_{}'.format(member)].update_state(
                member_nll)
            metrics['test/accuracy_member_{}'.format(member)].update_state(
                labels, member_probs)

    metrics['test/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['test/kl'].update_state(kl)
    metrics['test/elbo'].update_state(elbo)
    metrics['test/accuracy'].update_state(labels, marginal_probs)
    metrics['test/ece'].update_state(labels, marginal_probs)
def testAveragePairwiseDiversity(self):
    """Checks `average_pairwise_diversity` returns three scalar entries."""
    num_models, batch_size, num_classes = 3, 2, 5
    random_logits = tf.random.normal([num_models, batch_size, num_classes])
    member_probs = tf.nn.softmax(random_logits)

    results = um.average_pairwise_diversity(
        member_probs, num_models=num_models)

    self.assertLen(results, 3)
    # Every reported diversity measure must be a scalar.
    for key in ('disagreement', 'average_kl', 'cosine_similarity'):
        self.assertEqual(results[key].shape, [])
def step_fn(inputs):
    """Per-replica evaluation step for one batch of `dataset_name`.

    A member axis is inserted at position 1 of the images and tiled
    `FLAGS.ensemble_size` times, so the model output carries one prediction
    per member along axis 1. Diversity and per-member metrics are recorded
    for the 'clean' dataset only; the marginalised NLL, accuracy and ECE go
    to `metrics` for 'clean' and to `corrupt_metrics` otherwise.

    Args:
      inputs: An `(images, labels)` pair.
    """
    images, labels = inputs
    num_members = FLAGS.ensemble_size
    is_clean = dataset_name == 'clean'

    member_images = tf.tile(
        tf.expand_dims(images, 1), [1, num_members, 1, 1, 1])
    logits = model(member_images, training=False)
    if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
    member_probs_all = tf.nn.softmax(logits)

    if is_clean:
        # Swap the first two axes so the member axis comes first.
        members_first = tf.transpose(member_probs_all, perm=[1, 0, 2])
        diversity = um.average_pairwise_diversity(members_first, num_members)
        for diversity_name, diversity_value in diversity.items():
            test_diversity['test/' + diversity_name].update_state(
                diversity_value)

        for member in range(num_members):
            member_probs = member_probs_all[:, member]
            member_nll = tf.keras.losses.sparse_categorical_crossentropy(
                labels, member_probs)
            metrics['test/nll_member_{}'.format(member)].update_state(
                member_nll)
            metrics['test/accuracy_member_{}'.format(member)].update_state(
                labels, member_probs)

    # Negative log marginal likelihood computed in a numerically-stable way.
    member_labels = tf.tile(tf.expand_dims(labels, 1), [1, num_members])
    log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
        member_labels, logits, from_logits=True)
    log_num_members = tf.math.log(float(num_members))
    negative_log_likelihood = tf.reduce_mean(
        -tf.reduce_logsumexp(log_likelihoods, axis=[1]) + log_num_members)

    # Marginalise over the member axis.
    marginal_probs = tf.math.reduce_mean(member_probs_all, axis=1)

    if is_clean:
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/accuracy'].update_state(labels, marginal_probs)
        metrics['test/ece'].update_state(labels, marginal_probs)
    else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, marginal_probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
            labels, marginal_probs)
def step_fn(inputs):
    """Per-replica training step: forward pass, gradient update and metrics.

    With `FLAGS.adaptive_mixup` the inputs are passed through unchanged;
    otherwise images and labels are tiled once per ensemble member (labels get
    a trailing axis in the tile multiples when `FLAGS.mixup_alpha > 0`). The
    loss is the cross entropy plus an explicit L2 penalty on variables whose
    name contains 'kernel' or 'bias'.

    Args:
      inputs: An `(images, labels)` pair.
    """
    images, labels = inputs
    num_members = FLAGS.ensemble_size
    uses_mixup = FLAGS.mixup_alpha > 0

    if FLAGS.adaptive_mixup:
        images = tf.identity(images)
        labels = tf.identity(labels)
    else:
        images = tf.tile(images, [num_members, 1, 1, 1])
        label_multiples = [num_members, 1] if uses_mixup else [num_members]
        labels = tf.tile(labels, label_multiples)

    with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
            logits = tf.cast(logits, tf.float32)

        probs = tf.nn.softmax(logits)
        member_shape = tf.concat([[num_members, -1], probs.shape[1:]], 0)
        diversity_results = um.average_pairwise_diversity(
            tf.reshape(probs, member_shape), num_members)

        # Dense cross entropy when mixup is on, sparse otherwise.
        if uses_mixup:
            per_example_loss = tf.keras.losses.categorical_crossentropy(
                labels, logits, from_logits=True)
        else:
            per_example_loss = (
                tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits, from_logits=True))
        negative_log_likelihood = tf.reduce_mean(per_example_loss)

        # Apply l2 on the slow weights and bias terms. This excludes BN
        # parameters and fast weight approximate posterior/prior parameters,
        # but pay caution to their naming scheme.
        flat_regularized = [
            tf.reshape(variable, (-1,))
            for variable in model.trainable_variables
            if 'kernel' in variable.name or 'bias' in variable.name
        ]
        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(flat_regularized, axis=0))
        loss = negative_log_likelihood + l2_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

    grads = tape.gradient(scaled_loss, model.trainable_variables)

    # Separate learning rate implementation.
    if FLAGS.fast_weight_lr_multiplier == 1.0:
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    else:
        grads_and_vars = []
        for grad, variable in zip(grads, model.trainable_variables):
            # Apply different learning rate on the fast weights. This
            # excludes BN and slow weights, but pay caution to the naming
            # scheme.
            is_fast_weight = ('batch_norm' not in variable.name
                              and 'kernel' not in variable.name)
            if is_fast_weight:
                grad = grad * FLAGS.fast_weight_lr_multiplier
            grads_and_vars.append((grad, variable))
        optimizer.apply_gradients(grads_and_vars)

    if uses_mixup:
        # The metrics below take class indices rather than dense labels.
        labels = tf.argmax(labels, axis=-1)
    metrics['train/ece'].update_state(labels, probs)
    metrics['train/loss'].update_state(loss)
    metrics['train/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['train/accuracy'].update_state(labels, logits)
    for diversity_name, diversity_value in diversity_results.items():
        training_diversity['train/' + diversity_name].update_state(
            diversity_value)
def step_fn(inputs):
    """Per-replica training step with a fast-weight diversity regulariser.

    Prepares the batch according to the augmentation flags, computes the cross
    entropy plus `sum(model.losses)` and, once `optimizer.iterations` passes
    `steps_per_epoch * FLAGS.diversity_start_epoch`, adds a scheduled pairwise
    cosine-similarity penalty on the fast weights. Gradients are applied and
    the training metrics are updated.

    Args:
      inputs: `(images, labels, idx)` when `FLAGS.forget_mixup` is set,
        otherwise `(images, labels)`.

    Returns:
      `(accuracy_counts, idx)` when `FLAGS.forget_mixup` is set, where
      `accuracy_counts` is a float32 tensor reshaped to
      `[FLAGS.ensemble_size, -1]`; otherwise `None`.
    """
    if FLAGS.forget_mixup:
        images, labels, idx = inputs
    else:
        images, labels = inputs

    num_members = FLAGS.ensemble_size
    uses_dense_labels = (
        FLAGS.mixup_alpha > 0 or FLAGS.label_smoothing > 0 or FLAGS.cutmix)

    if FLAGS.adaptive_mixup or FLAGS.forget_mixup:
        images = tf.identity(images)
    elif FLAGS.augmix or FLAGS.random_augment:
        # Swap the first two axes, then fold them into a single leading axis.
        dynamic_shape = tf.shape(images)
        images = tf.reshape(
            tf.transpose(images, [1, 0, 2, 3, 4]),
            [-1, dynamic_shape[2], dynamic_shape[3], dynamic_shape[4]])
    else:
        images = tf.tile(images, [num_members, 1, 1, 1])

    # Augmix, adaptive mixup, forget mixup preprocessing gives tiled labels.
    if not uses_dense_labels:
        labels = tf.tile(labels, [num_members])
    elif FLAGS.augmix or FLAGS.adaptive_mixup or FLAGS.forget_mixup:
        labels = tf.identity(labels)
    else:
        labels = tf.tile(labels, [num_members, 1])

    def _is_batch_norm(v):
        """Decide whether a variable belongs to `batch_norm`."""
        lowered_name = v.name.lower()
        return any(
            keyword in lowered_name
            for keyword in ('batchnorm', 'batch_norm', 'bn'))

    def _normalize(x):
        """Normalize an input with l2 norm."""
        norms = tf.norm(x, ord=2, axis=-1)
        return x / tf.expand_dims(norms, axis=-1)

    # Taking the sum of upper triangular of XX^T and divided by ensemble
    # size.
    def pairwise_cosine_distance(x):
        """Compute the pairwise distance in a matrix."""
        unit_rows = _normalize(x)
        gram_total = tf.reduce_sum(
            tf.matmul(unit_rows, unit_rows, transpose_b=True))
        return (gram_total - FLAGS.ensemble_size) / (
            2.0 * FLAGS.ensemble_size)

    with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
            logits = tf.cast(logits, tf.float32)

        if uses_dense_labels:
            per_example_loss = tf.keras.losses.categorical_crossentropy(
                labels, logits, from_logits=True)
        else:
            per_example_loss = (
                tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits, from_logits=True))
        negative_log_likelihood = tf.reduce_mean(per_example_loss)

        l2_loss = sum(model.losses)

        # Fast weights: non-batch-norm variables named 'alpha' or 'gamma'.
        fast_weights = [
            variable for variable in model.trainable_variables
            if not _is_batch_norm(variable)
            and ('alpha' in variable.name or 'gamma' in variable.name)
        ]
        pairwise_distance_loss = tf.add_n(
            [pairwise_cosine_distance(weight) for weight in fast_weights])

        diversity_start_iter = steps_per_epoch * FLAGS.diversity_start_epoch
        diversity_iterations = optimizer.iterations - diversity_start_iter
        if diversity_iterations > 0:
            diversity_coeff = diversity_schedule(diversity_iterations)
            diversity_loss = diversity_coeff * pairwise_distance_loss
            loss = negative_log_likelihood + l2_loss + diversity_loss
        else:
            loss = negative_log_likelihood + l2_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

    grads = tape.gradient(scaled_loss, model.trainable_variables)

    # Separate learning rate implementation.
    if FLAGS.fast_weight_lr_multiplier == 1.0:
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    else:
        grads_and_vars = []
        for grad, variable in zip(grads, model.trainable_variables):
            # Apply different learning rate on the fast weight approximate
            # posterior/prior parameters. This excludes BN and slow weights,
            # but pay caution to the naming scheme.
            if not _is_batch_norm(variable) and 'kernel' not in variable.name:
                grad = grad * FLAGS.fast_weight_lr_multiplier
            grads_and_vars.append((grad, variable))
        optimizer.apply_gradients(grads_and_vars)

    probs = tf.nn.softmax(logits)
    if num_members > 1:
        member_shape = tf.concat([[num_members, -1], probs.shape[1:]], 0)
        diversity_results = um.average_pairwise_diversity(
            tf.reshape(probs, member_shape), num_members)
        for diversity_name, diversity_value in diversity_results.items():
            training_diversity['train/' + diversity_name].update_state(
                diversity_value)

    if uses_dense_labels:
        # The metrics below take class indices rather than dense labels.
        labels = tf.argmax(labels, axis=-1)
    metrics['train/ece'].update_state(labels, probs)
    metrics['train/similarity'].update_state(pairwise_distance_loss)
    metrics['train/l2'].update_state(l2_loss)
    metrics['train/loss'].update_state(loss)
    metrics['train/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['train/accuracy'].update_state(labels, logits)

    if FLAGS.forget_mixup:
        train_predictions = tf.argmax(probs, -1)
        labels = tf.cast(labels, train_predictions.dtype)
        # For each ensemble member, we accumulate the accuracy counts.
        is_correct = train_predictions == labels
        accuracy_counts = tf.cast(
            tf.reshape(is_correct, [num_members, -1]), tf.float32)
        return accuracy_counts, idx
def main(argv):
    """Evaluate an ensemble of independently trained checkpoints.

    The run has two phases:

    1. For every checkpoint found under `FLAGS.checkpoint_dir` and every test
       dataset (clean plus each corruption/intensity pair), the model logits
       are computed and cached as `<dataset>_<member>.npy` in
       `FLAGS.output_dir`. Files that already exist are not recomputed.
    2. The cached logits are reloaded, combined across members, and the
       ensemble, per-member and diversity metrics are accumulated and logged.

    Args:
      argv: Unused command-line arguments.

    Raises:
      ValueError: If `FLAGS.use_gpu` is false, if `FLAGS.num_cores > 1`, or
        if no checkpoint index file is found under `FLAGS.checkpoint_dir`.
    """
    del argv  # unused arg
    if not FLAGS.use_gpu:
        raise ValueError('Only GPU is currently supported.')
    if FLAGS.num_cores > 1:
        raise ValueError('Only a single accelerator is currently supported.')
    tf.random.set_seed(FLAGS.seed)
    tf.io.gfile.makedirs(FLAGS.output_dir)

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    dataset_input_fn = utils.load_input_fn(
        split=tfds.Split.TEST,
        name=FLAGS.dataset,
        batch_size=FLAGS.per_core_batch_size,
        use_bfloat16=FLAGS.use_bfloat16)
    test_datasets = {'clean': dataset_input_fn()}

    corruption_types, max_intensity = utils.load_corrupted_test_info(
        FLAGS.dataset)
    # The loader only depends on the dataset flag, so pick it once.
    if FLAGS.dataset == 'cifar10':
        load_c_dataset = utils.load_cifar10_c_input_fn
    else:
        load_c_dataset = functools.partial(
            utils.load_cifar100_c_input_fn, path=FLAGS.cifar100_c_path)
    for name in corruption_types:
        for intensity in range(1, max_intensity + 1):
            dataset_name = '{0}_{1}'.format(name, intensity)
            corrupted_input_fn = load_c_dataset(
                corruption_name=name,
                corruption_intensity=intensity,
                batch_size=FLAGS.per_core_batch_size,
                use_bfloat16=FLAGS.use_bfloat16)
            test_datasets[dataset_name] = corrupted_input_fn()

    model = ub.models.wide_resnet(
        input_shape=ds_info.features['image'].shape,
        depth=28,
        width_multiplier=10,
        num_classes=num_classes,
        l2=0.,
        version=2)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Search for checkpoints from their index file; then remove the index
    # suffix.
    ensemble_filenames = tf.io.gfile.glob(
        os.path.join(FLAGS.checkpoint_dir, '**/*.index'))
    ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
    if not ensemble_filenames:
        # Without this guard the evaluation phase fails much later with an
        # opaque slicing error on an empty logits tensor.
        raise ValueError(
            'No checkpoints found under {}.'.format(FLAGS.checkpoint_dir))
    ensemble_size = len(ensemble_filenames)
    logging.info('Ensemble size: %s', ensemble_size)
    logging.info('Ensemble number of weights: %s',
                 ensemble_size * model.count_params())
    logging.info('Ensemble filenames: %s', str(ensemble_filenames))
    checkpoint = tf.train.Checkpoint(model=model)

    def _prediction_path(dataset, member):
        """Return the cache file holding `member`'s logits on `dataset`."""
        filename = '{dataset}_{member}.npy'.format(
            dataset=dataset, member=member)
        return os.path.join(FLAGS.output_dir, filename)

    # Write model predictions to files.
    num_datasets = len(test_datasets)
    for m, ensemble_filename in enumerate(ensemble_filenames):
        checkpoint.restore(ensemble_filename)
        for n, (name, test_dataset) in enumerate(test_datasets.items()):
            filename = _prediction_path(name, m)
            if not tf.io.gfile.exists(filename):
                logits = []
                test_iterator = iter(test_dataset)
                for _ in range(steps_per_eval):
                    features, _ = next(test_iterator)  # pytype: disable=attribute-error
                    logits.append(model(features, training=False))

                logits = tf.concat(logits, axis=0)
                # np.save emits bytes, so open the file in binary mode
                # (mirrors the 'rb' read in the evaluation phase).
                with tf.io.gfile.GFile(filename, 'wb') as f:
                    np.save(f, logits.numpy())
            percent = (m * num_datasets + (n + 1)) / (
                ensemble_size * num_datasets)
            message = (
                '{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size,
                                           n + 1, num_datasets))
            logging.info(message)

    metrics = {
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
    }
    corrupt_metrics = {}
    for name in test_datasets:
        corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
        corrupt_metrics['test/accuracy_{}'.format(name)] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
        corrupt_metrics['test/ece_{}'.format(name)] = (
            um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
    for i in range(ensemble_size):
        metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
        metrics['test/accuracy_member_{}'.format(i)] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
    test_diversity = {
        'test/disagreement': tf.keras.metrics.Mean(),
        'test/average_kl': tf.keras.metrics.Mean(),
        'test/cosine_similarity': tf.keras.metrics.Mean(),
    }
    # The diversity metric objects are shared with `metrics`, so they are
    # reported through `total_results` below.
    metrics.update(test_diversity)

    # Evaluate model predictions.
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
        logits_dataset = []
        for m in range(ensemble_size):
            with tf.io.gfile.GFile(_prediction_path(name, m), 'rb') as f:
                logits_dataset.append(np.load(f))

        # Stacked along a new leading member axis.
        logits_dataset = tf.convert_to_tensor(logits_dataset)
        test_iterator = iter(test_dataset)
        for step in range(steps_per_eval):
            _, labels = next(test_iterator)  # pytype: disable=attribute-error
            # Slice this step's batch out of the cached logits (axis 1).
            logits = logits_dataset[
                :, (step * batch_size):((step + 1) * batch_size)]
            labels = tf.cast(labels, tf.int32)
            negative_log_likelihood = um.ensemble_cross_entropy(
                labels, logits)
            per_probs = tf.nn.softmax(logits)
            probs = tf.reduce_mean(per_probs, axis=0)
            if name == 'clean':
                gibbs_ce = um.gibbs_cross_entropy(labels, logits)
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].update_state(labels, probs)

                for i in range(ensemble_size):
                    member_probs = per_probs[i]
                    member_loss = (
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, member_probs))
                    metrics['test/nll_member_{}'.format(i)].update_state(
                        member_loss)
                    metrics['test/accuracy_member_{}'.format(i)].update_state(
                        labels, member_probs)
                diversity_results = um.average_pairwise_diversity(
                    per_probs, ensemble_size)
                for k, v in diversity_results.items():
                    test_diversity['test/' + k].update_state(v)
            else:
                corrupt_metrics['test/nll_{}'.format(name)].update_state(
                    negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
                    labels, probs)
                corrupt_metrics['test/ece_{}'.format(name)].update_state(
                    labels, probs)

        message = (
            '{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
                (n + 1) / num_datasets, n + 1, num_datasets))
        logging.info(message)

    corrupt_results = utils.aggregate_corrupt_metrics(
        corrupt_metrics, corruption_types, max_intensity)
    total_results = {
        name: metric.result() for name, metric in metrics.items()
    }
    total_results.update(corrupt_results)
    logging.info('Metrics: %s', total_results)
def step_fn(inputs):
    """Per-replica training step with an annealed KL term.

    When `FLAGS.ensemble_size > 1`, images and labels are tiled once per
    member. The optimised loss is the cross entropy plus `compute_l2_loss`
    plus the KL term scaled by an annealing factor; the factor grows with
    `optimizer.iterations` and is capped at 1. The reported ELBO uses the
    unscaled KL.

    Args:
      inputs: An `(images, labels)` pair.
    """
    images, labels = inputs
    num_members = FLAGS.ensemble_size
    has_ensemble = num_members > 1

    if has_ensemble:
        images = tf.tile(images, [num_members, 1, 1, 1])
        labels = tf.tile(labels, [num_members])

    with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
            logits = tf.cast(logits, tf.float32)

        probs = tf.nn.softmax(logits)
        if has_ensemble:
            member_shape = tf.concat(
                [[num_members, -1], probs.shape[1:]], 0)
            diversity_results = um.average_pairwise_diversity(
                tf.reshape(probs, member_shape), num_members)

        per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits, from_logits=True)
        negative_log_likelihood = tf.reduce_mean(per_example_loss)
        l2_loss = compute_l2_loss(model)

        kl = sum(model.losses) / APPROX_IMAGENET_TRAIN_IMAGES
        # Linear warm-up of the KL weight, capped at 1.
        annealing_steps = steps_per_epoch * FLAGS.kl_annealing_epochs
        kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
        kl_scale /= annealing_steps
        kl_scale = tf.minimum(1., kl_scale)
        kl_loss = kl_scale * kl

        loss = negative_log_likelihood + l2_loss + kl_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync
        elbo = -(negative_log_likelihood + l2_loss + kl)

    grads = tape.gradient(scaled_loss, model.trainable_variables)

    # Separate learning rate implementation.
    if FLAGS.fast_weight_lr_multiplier == 1.0:
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    else:
        grads_and_vars = []
        for grad, variable in zip(grads, model.trainable_variables):
            # Apply different learning rate on the fast weights. This
            # excludes BN and slow weights, but pay caution to the naming
            # scheme.
            is_fast_weight = ('batch_norm' not in variable.name
                              and 'kernel' not in variable.name)
            if is_fast_weight:
                grad = grad * FLAGS.fast_weight_lr_multiplier
            grads_and_vars.append((grad, variable))
        optimizer.apply_gradients(grads_and_vars)

    metrics['train/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['train/kl'].update_state(kl)
    metrics['train/kl_scale'].update_state(kl_scale)
    metrics['train/elbo'].update_state(elbo)
    metrics['train/loss'].update_state(loss)
    metrics['train/accuracy'].update_state(labels, logits)
    metrics['train/ece'].update_state(labels, probs)
    if has_ensemble:
        for diversity_name, diversity_value in diversity_results.items():
            training_diversity['train/' + diversity_name].update_state(
                diversity_value)
def step_fn(inputs):
    """Per-replica training step: sample lambdas, update weights and metrics.

    The images are tiled once per ensemble member and fed to the model
    together with freshly sampled lambdas. The loss is either the Gibbs cross
    entropy (mean of the per-member cross entropies, `FLAGS.use_gibbs_ce`) or
    the ensemble cross entropy, plus `sum(model.losses)` divided by
    `train_sample_size`.

    The per-example training metrics (`train/ece`, `train/accuracy`) are
    computed on the tiled batch, so they always use tiled labels, regardless
    of which loss is selected. Only the ensemble cross entropy takes the
    untiled labels.

    Args:
      inputs: An `(images, labels)` pair.
    """
    images, labels = inputs
    images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
    # Labels matching the tiled batch. Needed by the Gibbs loss and by every
    # metric that compares against the un-averaged logits/probs.
    tiled_labels = tf.tile(labels, [FLAGS.ensemble_size])

    # generate lambdas
    lambdas = log_uniform_sample(per_core_batch_size, lambda_parameters)
    lambdas = tf.reshape(
        lambdas,
        (FLAGS.ensemble_size * per_core_batch_size, lambdas_config.dim))

    with tf.GradientTape() as tape:
        logits = model([images, lambdas], training=True)

        if FLAGS.use_bfloat16:
            logits = tf.cast(logits, tf.float32)

        if FLAGS.use_gibbs_ce:
            # Average of single model CEs; uses the tiled labels.
            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(
                    tiled_labels, logits, from_logits=True))
        else:
            # Ensemble CE uses no tiling of the labels
            negative_log_likelihood = ensemble_crossentropy(
                labels, logits, FLAGS.ensemble_size)
        # Note: Divide l2_loss by sample_size (this differs from
        # uncertainty_baselines implementation.)
        l2_loss = sum(model.losses) / train_sample_size
        loss = negative_log_likelihood + l2_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

    grads = tape.gradient(scaled_loss, model.trainable_variables)

    # Separate learning rate for fast weights: non-batch-norm variables whose
    # name contains 'alpha' or 'gamma'.
    grads_and_vars = []
    for grad, var in zip(grads, model.trainable_variables):
        is_fast_weight = (('alpha' in var.name or 'gamma' in var.name)
                          and 'batch_norm' not in var.name)
        if is_fast_weight:
            grads_and_vars.append(
                (grad * FLAGS.fast_weight_lr_multiplier, var))
        else:
            grads_and_vars.append((grad, var))
    optimizer.apply_gradients(grads_and_vars)

    probs = tf.nn.softmax(logits)
    per_probs = tf.split(
        probs, num_or_size_splits=FLAGS.ensemble_size, axis=0)
    per_probs_stacked = tf.stack(per_probs, axis=0)
    # `probs` and `logits` cover the tiled batch, so compare against the
    # tiled labels (previously untiled whenever use_gibbs_ce was False).
    metrics['train/ece'].update_state(tiled_labels, probs)
    metrics['train/loss'].update_state(loss)
    metrics['train/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['train/accuracy'].update_state(tiled_labels, logits)
    diversity_results = um.average_pairwise_diversity(
        per_probs_stacked, FLAGS.ensemble_size)
    for k, v in diversity_results.items():
        metrics['train/' + k].update_state(v)

    # NOTE(review): the rescaled gradients are unpacked here but not used in
    # the visible code; kept as in the original in case a consumer was
    # dropped from this view.
    if grads_and_vars:
        grads, _ = zip(*grads_and_vars)
def step_fn(inputs):
    """Per-replica evaluation step for one batch of `dataset_name`.

    The images are tiled per ensemble member and then per evaluation sample.
    The predictions are averaged over samples first and, when
    `FLAGS.ensemble_size > 1`, over members afterwards. Diversity and
    per-member metrics are recorded for the 'clean' dataset only; NLL,
    accuracy and ECE go to `metrics` for 'clean' (together with the total
    loss) and to `corrupt_metrics` otherwise.

    Args:
      inputs: An `(images, labels)` pair.
    """
    images, labels = inputs
    num_members = FLAGS.ensemble_size
    num_samples = FLAGS.num_eval_samples
    is_clean = dataset_name == 'clean'

    if num_members > 1:
        images = tf.tile(images, [num_members, 1, 1, 1])
    if num_samples > 1:
        images = tf.tile(images, [num_samples, 1, 1, 1])

    logits = model(images, training=False)
    probs = tf.nn.softmax(logits)

    if num_samples > 1:
        # Expose the sample axis and average it out.
        sample_shape = tf.concat([[num_samples, -1], probs.shape[1:]], 0)
        probs = tf.reduce_mean(tf.reshape(probs, sample_shape), 0)

    if num_members > 1:
        member_probs_list = tf.split(
            probs, num_or_size_splits=num_members, axis=0)
        if is_clean:
            member_shape = tf.concat(
                [[num_members, -1], probs.shape[1:]], 0)
            diversity = um.average_pairwise_diversity(
                tf.reshape(probs, member_shape), num_members)
            for diversity_name, diversity_value in diversity.items():
                test_diversity['test/' + diversity_name].update_state(
                    diversity_value)

            for member, member_probs in enumerate(member_probs_list):
                member_nll = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, member_probs)
                metrics['test/nll_member_{}'.format(member)].update_state(
                    member_nll)
                metrics['test/accuracy_member_{}'.format(member)].update_state(
                    labels, member_probs)

        # Average the member predictions.
        probs = tf.reduce_mean(member_probs_list, axis=0)

    negative_log_likelihood = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

    # L2 penalty over variables whose name contains 'kernel' or 'bias'.
    flat_regularized = [
        tf.reshape(variable, (-1,))
        for variable in model.trainable_variables
        if 'kernel' in variable.name or 'bias' in variable.name
    ]
    kl = sum(model.losses) / test_dataset_size
    l2_loss = kl + FLAGS.l2 * 2 * tf.nn.l2_loss(
        tf.concat(flat_regularized, axis=0))
    loss = negative_log_likelihood + l2_loss

    if is_clean:
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
        metrics['test/loss'].update_state(loss)
    else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
            labels, probs)
def step_fn(inputs):
    """Per-replica training step with member sampling / expected probs.

    With `FLAGS.version2` and more than one ensemble member, the images are
    tiled per member; the labels are tiled too unless
    `FLAGS.member_sampling` or `FLAGS.expected_probs` is set. In those two
    modes the probabilities are reduced to a single member's rows or to the
    mean over members before the loss. The optimised loss is the cross
    entropy on probabilities plus an explicit L2 penalty and an annealed KL
    term.

    Args:
      inputs: An `(images, labels)` pair.
    """
    images, labels = inputs
    num_members = FLAGS.ensemble_size
    num_samples = FLAGS.num_train_samples
    has_ensemble = FLAGS.version2 and num_members > 1

    if has_ensemble:
        images = tf.tile(images, [num_members, 1, 1, 1])
        if not (FLAGS.member_sampling or FLAGS.expected_probs):
            labels = tf.tile(labels, [num_members])

    if num_samples > 1:
        images = tf.tile(images, [num_samples, 1, 1, 1])

    with tf.GradientTape() as tape:
        logits = model(images, training=True)
        probs = tf.nn.softmax(logits)

        # Diversity evaluation.
        if has_ensemble:
            member_shape = tf.concat(
                [[num_members, -1], probs.shape[1:]], 0)
            diversity_results = um.average_pairwise_diversity(
                tf.reshape(probs, member_shape), num_members)

        if num_samples > 1:
            # Expose the sample axis and average it out.
            sample_shape = tf.concat(
                [[num_samples, -1], probs.shape[1:]], 0)
            probs = tf.reduce_mean(tf.reshape(probs, sample_shape), 0)

        if FLAGS.member_sampling and has_ensemble:
            # Select one uniformly drawn member by multiplying a one-hot row
            # vector with the probabilities flattened per member.
            idx = tf.random.uniform(
                [], maxval=num_members, dtype=tf.int64)
            idx_one_hot = tf.expand_dims(
                tf.one_hot(idx, num_members, dtype=probs.dtype), 0)
            probs_shape = probs.shape
            flat_member_probs = tf.reshape(probs, [num_members, -1])
            selected = tf.matmul(idx_one_hot, flat_member_probs)
            probs = tf.reshape(
                selected, tf.concat([[-1], probs_shape[1:]], 0))
        elif FLAGS.expected_probs and has_ensemble:
            # Average the probabilities over the member axis.
            member_shape = tf.concat(
                [[num_members, -1], probs.shape[1:]], 0)
            probs = tf.reduce_mean(tf.reshape(probs, member_shape), 0)

        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

        # Apply l2 on the slow weights and bias terms. This excludes BN
        # parameters and fast weight approximate posterior/prior parameters,
        # but pay caution to their naming scheme.
        flat_regularized = [
            tf.reshape(variable, (-1,))
            for variable in model.trainable_variables
            if 'kernel' in variable.name or 'bias' in variable.name
        ]
        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(flat_regularized, axis=0))

        kl = sum(model.losses) / train_dataset_size
        # Linear warm-up of the KL weight, capped at 1.
        kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
        kl_scale /= FLAGS.kl_annealing_steps
        kl_scale = tf.minimum(1., kl_scale)
        kl_loss = kl_scale * kl

        loss = negative_log_likelihood + l2_loss + kl_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

    grads = tape.gradient(scaled_loss, model.trainable_variables)

    # Separate learning rate implementation.
    if FLAGS.fast_weight_lr_multiplier == 1.0:
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    else:
        grad_list = []
        for grad, variable in zip(grads, model.trainable_variables):
            # Apply different learning rate on the fast weight approximate
            # posterior/prior parameters. This excludes BN and slow weights,
            # but pay caution to the naming scheme.
            is_fast_weight = ('batch_norm' not in variable.name
                              and 'kernel' not in variable.name)
            if is_fast_weight:
                grad = grad * FLAGS.fast_weight_lr_multiplier
            grad_list.append((grad, variable))
        optimizer.apply_gradients(grad_list)

    metrics['train/ece'].update_state(labels, probs)
    metrics['train/loss'].update_state(loss)
    metrics['train/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['train/accuracy'].update_state(labels, probs)
    if has_ensemble:
        for diversity_name, diversity_value in diversity_results.items():
            training_diversity['train/' + diversity_name].update_state(
                diversity_value)
def step_fn(inputs):
  """Run one evaluation step on a single replica.

  Tiles the images `ensemble_size` times, builds the matching lambdas, runs
  the model in inference mode and updates either the clean metrics or the
  corruption metrics, depending on `dataset_name`.

  Args:
    inputs: Mapping with a 'features' entry (images) and a 'labels' entry.
  """
  # Note that the labels are not tiled here, only the images.
  images = tf.tile(inputs['features'], [ensemble_size, 1, 1, 1])
  labels = inputs['labels']

  # Get the lambdas. The samples are always drawn; when num_eval_samples is
  # non-negative, the mean is prepended to them along axis 1.
  sampled_lambdas = log_uniform_sample(n_samples, lambda_parameters)
  if num_eval_samples < 0:
    lambdas = sampled_lambdas
  else:
    mean_lambdas = tf.expand_dims(log_uniform_mean(lambda_parameters), 1)
    lambdas = tf.concat((mean_lambdas, sampled_lambdas), 1)

  # lambdas with shape (ens size, samples, dim of lambdas)
  rep_lambdas = tf.reshape(
      tf.repeat(lambdas, per_core_batch_size, axis=1),
      (ensemble_size * per_core_batch_size, -1))

  # Evaluate on the test sets.
  logits = model([images, rep_lambdas], training=False)
  all_probs = tf.nn.softmax(logits)
  per_probs = tf.split(all_probs, num_or_size_splits=ensemble_size, axis=0)
  is_clean = dataset_name == 'clean'

  # Per member performance and gibbs performance (average per member perf).
  if is_clean:
    for member in range(FLAGS.ensemble_size):
      # We record the first sample of lambdas per batch-ens member.
      split_index = member * (ensemble_size // FLAGS.ensemble_size)
      member_probs = per_probs[split_index]
      metrics['test/nll_member_{}'.format(member)].update_state(
          tf.keras.losses.sparse_categorical_crossentropy(
              labels, member_probs))
      metrics['test/accuracy_member_{}'.format(member)].update_state(
          labels, member_probs)

    tiled_labels = tf.tile(labels, [ensemble_size])
    gibbs_nll = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            tiled_labels, logits, from_logits=True))
    metrics['test/gibbs_nll'].update_state(gibbs_nll)
    metrics['test/gibbs_accuracy'].update_state(tiled_labels, all_probs)

  # Ensemble performance.
  negative_log_likelihood = ensemble_crossentropy(
      labels, logits, ensemble_size)
  ensemble_probs = tf.reduce_mean(per_probs, axis=0)
  if is_clean:
    metrics['test/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['test/accuracy'].update_state(labels, ensemble_probs)
    metrics['test/ece'].update_state(labels, ensemble_probs)
  else:
    corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
        negative_log_likelihood)
    corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
        labels, ensemble_probs)
    corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
        labels, ensemble_probs)

  if is_clean:
    diversity_results = um.average_pairwise_diversity(
        tf.stack(per_probs, axis=0), ensemble_size)
    for name, value in diversity_results.items():
      metrics['test/' + name].update_state(value)
def main(argv):
  """Greedily select an ensemble from a checkpoint pool and evaluate it.

  The function:
    1. loads the clean test set, a validation split and every corrupted test
       set (severities 1 to 5);
    2. computes validation logits for every checkpoint returned by
       `parse_checkpoint_dir(FLAGS.checkpoint_dir)`;
    3. calls `greedy_selection` to pick `FLAGS.ensemble_size` members (the
       selection may contain duplicates);
    4. writes the test logits of each unique selected member to
       `FLAGS.output_dir` (existing files are reused);
    5. reloads those logits, computes ensemble / per-member / diversity
       metrics and logs them.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If `FLAGS.use_gpu` is false, if more than one core is
      requested, or if no checkpoint is found.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')

  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  ds_info = tfds.builder(FLAGS.dataset).info
  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_eval = ds_info.splits['test'].num_examples // batch_size
  num_classes = ds_info.features['label'].num_classes

  dataset = ub.datasets.get(
      FLAGS.dataset, split=tfds.Split.TEST).load(batch_size=batch_size)
  validation_percent = 1. - FLAGS.train_proportion
  val_dataset = ub.datasets.get(
      dataset_name=FLAGS.dataset,
      split=tfds.Split.VALIDATION,
      validation_percent=validation_percent,
      drop_remainder=False).load(batch_size=batch_size)
  steps_per_val_eval = int(ds_info.splits['train'].num_examples *
                           validation_percent) // batch_size

  test_datasets = {'clean': dataset}
  extra_kwargs = {}
  if FLAGS.dataset == 'cifar100':
    extra_kwargs['data_dir'] = FLAGS.cifar100_c_path
  corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
  for corruption_type in corruption_types:
    for severity in range(1, 6):
      dataset = ub.datasets.get(
          f'{FLAGS.dataset}_corrupted',
          corruption_type=corruption_type,
          severity=severity,
          split=tfds.Split.TEST,
          **extra_kwargs).load(batch_size=batch_size)
      test_datasets[f'{corruption_type}_{severity}'] = dataset

  model = ub.models.wide_resnet(
      input_shape=ds_info.features['image'].shape,
      depth=28,
      width_multiplier=10,
      num_classes=num_classes,
      l2=0.,
      version=2)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Search for checkpoints
  ensemble_filenames = parse_checkpoint_dir(FLAGS.checkpoint_dir)
  model_pool_size = len(ensemble_filenames)
  if model_pool_size == 0:
    # Fail early with a clear message instead of an obscure error later on.
    raise ValueError(
        'No checkpoints found in {}.'.format(FLAGS.checkpoint_dir))
  logging.info('Model pool size: %s', model_pool_size)
  logging.info('Ensemble size: %s', FLAGS.ensemble_size)
  logging.info('Ensemble number of weights: %s',
               FLAGS.ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Compute the logits on the validation set
  val_logits, val_labels = [], []
  for m, ensemble_filename in enumerate(ensemble_filenames):
    # Enforce memory clean-up
    tf.keras.backend.clear_session()
    checkpoint.restore(ensemble_filename)
    val_iterator = iter(val_dataset)
    val_logits_m = []
    for _ in range(steps_per_val_eval):
      inputs = next(val_iterator)
      features = inputs['features']
      labels = inputs['labels']
      val_logits_m.append(model(features, training=False))
      # The labels are collected once, while processing the first model.
      if m == 0:
        val_labels.append(labels)

    val_logits.append(tf.concat(val_logits_m, axis=0))
    if m == 0:
      val_labels = tf.concat(val_labels, axis=0)

    percent = (m + 1.) / model_pool_size
    message = ('{:.1%} completion for prediction on validation set: '
               'model {:d}/{:d}.'.format(percent, m + 1, model_pool_size))
    logging.info(message)

  selected_members, val_acc, val_nll = greedy_selection(val_logits, val_labels,
                                                        FLAGS.ensemble_size,
                                                        FLAGS.greedy_objective)
  unique_selected_members = list(set(selected_members))
  # Position of the first occurrence of each unique member in
  # `selected_members`. This position indexes the stacked logits and names the
  # per-member metrics. With duplicates the positions are not necessarily
  # 0..len(unique)-1, e.g. selected_members = [3, 3, 2, 7] gives {0, 2, 3}.
  member_positions = {
      member_id: selected_members.index(member_id)
      for member_id in unique_selected_members
  }
  message = ('Members selected by greedy procedure: {} (with {} unique '
             'member(s))\n\t{}').format(
                 selected_members, len(unique_selected_members),
                 [ensemble_filenames[i] for i in selected_members])
  logging.info(message)
  val_metrics = {
      'val/accuracy': tf.keras.metrics.Mean(),
      'val/negative_log_likelihood': tf.keras.metrics.Mean()
  }
  val_metrics['val/accuracy'].update_state(val_acc)
  val_metrics['val/negative_log_likelihood'].update_state(val_nll)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, member_id in enumerate(unique_selected_members):
    ensemble_filename = ensemble_filenames[member_id]
    checkpoint.restore(ensemble_filename)
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
      filename = '{dataset}_{member}.npy'.format(dataset=name,
                                                 member=member_id)
      filename = os.path.join(FLAGS.output_dir, filename)
      # Predictions already on disk are reused.
      if not tf.io.gfile.exists(filename):
        logits = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval):
          features = next(test_iterator)['features']  # pytype: disable=unsupported-operands
          logits.append(model(features, training=False))

        logits = tf.concat(logits, axis=0)
        # Binary mode, matching the 'rb' mode used to read the file back.
        with tf.io.gfile.GFile(filename, 'wb') as f:
          np.save(f, logits.numpy())
      numerator = m * num_datasets + (n + 1)
      denominator = len(unique_selected_members) * num_datasets
      percent = numerator / denominator
      message = ('{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                 'Dataset {:d}/{:d}'.format(percent,
                                            m + 1,
                                            len(unique_selected_members),
                                            n + 1,
                                            num_datasets))
      logging.info(message)

  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': rm.metrics.ExpectedCalibrationError(
          num_bins=FLAGS.num_bins),
  }
  metrics.update(val_metrics)
  corrupt_metrics = {}
  for name in test_datasets:
    corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
    corrupt_metrics['test/accuracy_{}'.format(name)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
    corrupt_metrics['test/ece_{}'.format(name)] = (
        rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
  # The per-member metrics are keyed by the same positions that are used to
  # update them in the evaluation loop, so that every lookup succeeds even
  # when `selected_members` contains duplicates.
  for i in sorted(member_positions.values()):
    metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
    metrics['test/accuracy_member_{}'.format(i)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
  test_diversity = {
      'test/disagreement': tf.keras.metrics.Mean(),
      'test/average_kl': tf.keras.metrics.Mean(),
      'test/cosine_similarity': tf.keras.metrics.Mean(),
  }
  metrics.update(test_diversity)

  # Evaluate model predictions.
  for n, (name, test_dataset) in enumerate(test_datasets.items()):
    logits_dataset = []
    # Duplicated members are loaded once per occurrence, so that they weigh
    # accordingly in the ensemble average.
    for member_id in selected_members:
      filename = '{dataset}_{member}.npy'.format(dataset=name,
                                                 member=member_id)
      filename = os.path.join(FLAGS.output_dir, filename)
      with tf.io.gfile.GFile(filename, 'rb') as f:
        logits_dataset.append(np.load(f))

    logits_dataset = tf.convert_to_tensor(logits_dataset)
    test_iterator = iter(test_dataset)
    for step in range(steps_per_eval):
      labels = next(test_iterator)['labels']  # pytype: disable=unsupported-operands
      logits = logits_dataset[:, (step*batch_size):((step+1)*batch_size)]
      labels = tf.cast(labels, tf.int32)
      negative_log_likelihood = um.ensemble_cross_entropy(labels, logits)
      per_probs = tf.nn.softmax(logits)
      probs = tf.reduce_mean(per_probs, axis=0)
      if name == 'clean':
        gibbs_ce = um.gibbs_cross_entropy(labels, logits)
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].add_batch(probs, label=labels)

        # Attention must be paid to deal with duplicated members, e.g.
        #   selected_members = [2, 7, 3, 3]
        #   unique_selected_members = [2, 3, 7]
        #   selected_members.index(3) --> 2
        for i in member_positions.values():
          member_probs = per_probs[i]
          member_loss = tf.keras.losses.sparse_categorical_crossentropy(
              labels, member_probs)
          metrics['test/nll_member_{}'.format(i)].update_state(member_loss)
          metrics['test/accuracy_member_{}'.format(i)].update_state(
              labels, member_probs)
        diversity_results = um.average_pairwise_diversity(
            per_probs, len(per_probs))
        for k, v in diversity_results.items():
          test_diversity['test/' + k].update_state(v)
      else:
        corrupt_metrics['test/nll_{}'.format(name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(name)].add_batch(
            probs, label=labels)

    message = ('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                    corruption_types)
  total_results = {name: metric.result() for name, metric in metrics.items()}
  total_results.update(corrupt_results)
  # Metrics from Robustness Metrics (like ECE) will return a dict with a
  # single key/value, instead of a scalar.
  total_results = {
      k: (list(v.values())[0] if isinstance(v, dict) else v)
      for k, v in total_results.items()
  }
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Evaluate an ensemble of heteroscedastic Wide ResNet checkpoints.

  The function:
    1. loads the clean test set and every corrupted test set (severities 1 to
       5);
    2. restores each checkpoint returned by
       `parse_checkpoint_dir(FLAGS.checkpoint_dir)` and writes its test logits
       to `FLAGS.output_dir` (existing files are reused);
    3. reloads those logits, computes ensemble / per-member / diversity
       metrics and logs them.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If `FLAGS.use_gpu` is false, if more than one core is
      requested, or if no checkpoint is found.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')

  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  ds_info = tfds.builder(FLAGS.dataset).info
  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_eval = ds_info.splits['test'].num_examples // batch_size
  num_classes = ds_info.features['label'].num_classes

  dataset = ub.datasets.get(
      FLAGS.dataset, split=tfds.Split.TEST).load(batch_size=batch_size)
  test_datasets = {'clean': dataset}
  extra_kwargs = {}
  if FLAGS.dataset == 'cifar100':
    extra_kwargs['data_dir'] = FLAGS.cifar100_c_path
  corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
  for corruption_type in corruption_types:
    for severity in range(1, 6):
      dataset = ub.datasets.get(
          f'{FLAGS.dataset}_corrupted',
          corruption_type=corruption_type,
          severity=severity,
          split=tfds.Split.TEST,
          **extra_kwargs).load(batch_size=batch_size)
      test_datasets[f'{corruption_type}_{severity}'] = dataset

  model = ub.models.wide_resnet_heteroscedastic(
      input_shape=ds_info.features['image'].shape,
      depth=28,
      width_multiplier=10,
      num_classes=num_classes,
      l2=0.,
      version=2,
      temperature=FLAGS.temperature,
      num_factors=FLAGS.num_factors,
      num_mc_samples=FLAGS.num_mc_samples)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Search for checkpoints from their index file; then remove the index suffix.
  ensemble_filenames = parse_checkpoint_dir(FLAGS.checkpoint_dir)
  ensemble_size = len(ensemble_filenames)
  if ensemble_size == 0:
    # Fail early with a clear message: with no member, the evaluation below
    # would otherwise break while slicing an empty logits tensor.
    raise ValueError(
        'No checkpoints found in {}.'.format(FLAGS.checkpoint_dir))
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename)
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      # Predictions already on disk are reused.
      if not tf.io.gfile.exists(filename):
        logits = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval):
          features = next(test_iterator)['features']  # pytype: disable=unsupported-operands
          logits.append(model(features, training=False))

        logits = tf.concat(logits, axis=0)
        # Binary mode, matching the 'rb' mode used to read the file back.
        with tf.io.gfile.GFile(filename, 'wb') as f:
          np.save(f, logits.numpy())
      percent = (m * num_datasets + (n + 1)) / (ensemble_size * num_datasets)
      message = ('{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                 'Dataset {:d}/{:d}'.format(percent,
                                            m + 1,
                                            ensemble_size,
                                            n + 1,
                                            num_datasets))
      logging.info(message)

  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  corrupt_metrics = {}
  for name in test_datasets:
    corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
    corrupt_metrics['test/accuracy_{}'.format(name)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
    corrupt_metrics['test/ece_{}'.format(name)] = (
        um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
  for i in range(ensemble_size):
    metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
    metrics['test/accuracy_member_{}'.format(i)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
  test_diversity = {
      'test/disagreement': tf.keras.metrics.Mean(),
      'test/average_kl': tf.keras.metrics.Mean(),
      'test/cosine_similarity': tf.keras.metrics.Mean(),
  }
  metrics.update(test_diversity)

  # Evaluate model predictions.
  for n, (name, test_dataset) in enumerate(test_datasets.items()):
    logits_dataset = []
    for m in range(ensemble_size):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      with tf.io.gfile.GFile(filename, 'rb') as f:
        logits_dataset.append(np.load(f))

    # Stack the per-member logits along a new leading (member) axis.
    logits_dataset = tf.convert_to_tensor(logits_dataset)
    test_iterator = iter(test_dataset)
    for step in range(steps_per_eval):
      labels = next(test_iterator)['labels']  # pytype: disable=unsupported-operands
      logits = logits_dataset[:, (step*batch_size):((step+1)*batch_size)]
      labels = tf.cast(labels, tf.int32)
      negative_log_likelihood = um.ensemble_cross_entropy(labels, logits)
      per_probs = tf.nn.softmax(logits)
      probs = tf.reduce_mean(per_probs, axis=0)
      if name == 'clean':
        gibbs_ce = um.gibbs_cross_entropy(labels, logits)
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].add_batch(probs, label=labels)

        for i in range(ensemble_size):
          member_probs = per_probs[i]
          member_loss = tf.keras.losses.sparse_categorical_crossentropy(
              labels, member_probs)
          metrics['test/nll_member_{}'.format(i)].update_state(member_loss)
          metrics['test/accuracy_member_{}'.format(i)].update_state(
              labels, member_probs)
        diversity_results = um.average_pairwise_diversity(
            per_probs, ensemble_size)
        for k, v in diversity_results.items():
          test_diversity['test/' + k].update_state(v)
      else:
        corrupt_metrics['test/nll_{}'.format(name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(name)].add_batch(
            probs, label=labels)

    message = ('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                    corruption_types)
  total_results = {name: metric.result() for name, metric in metrics.items()}
  total_results.update(corrupt_results)
  # Metrics from Robustness Metrics (like ECE) will return a dict with a
  # single key/value, instead of a scalar.
  total_results = {
      k: (list(v.values())[0] if isinstance(v, dict) else v)
      for k, v in total_results.items()
  }
  logging.info('Metrics: %s', total_results)