def main(argv):
  """Trains and evaluates the ensemble wide ResNet-28 configured by FLAGS.

  The model consumes FLAGS.ensemble_size images per example (its input shape
  is `[ensemble_size] + image_shape`). During training each ensemble member
  receives its own shuffled view of the batch, stacked along axis 1. During
  evaluation every member sees the same image and member predictions are
  averaged.

  Side effects: creates FLAGS.output_dir, resumes from the latest checkpoint
  found there (if any), writes TensorBoard summaries to
  `<output_dir>/summaries` once per epoch, and saves checkpoints with prefix
  `<output_dir>/checkpoint` every FLAGS.checkpoint_interval epochs and once
  after training.

  Args:
    argv: Unused command-line arguments; all configuration is read from FLAGS.
  """
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

  ds_info = tfds.builder(FLAGS.dataset).info
  # The train step tiles every batch FLAGS.batch_repetitions times, so the
  # loaded batch is shrunk by that factor.
  train_batch_size = (FLAGS.per_core_batch_size * FLAGS.num_cores
                      // FLAGS.batch_repetitions)
  test_batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  train_dataset_size = ds_info.splits['train'].num_examples
  steps_per_epoch = train_dataset_size // train_batch_size
  steps_per_eval = ds_info.splits['test'].num_examples // test_batch_size
  num_classes = ds_info.features['label'].num_classes

  train_dataset = utils.load_dataset(
      split=tfds.Split.TRAIN,
      name=FLAGS.dataset,
      batch_size=train_batch_size,
      use_bfloat16=FLAGS.use_bfloat16)
  clean_test_dataset = utils.load_dataset(
      split=tfds.Split.TEST,
      name=FLAGS.dataset,
      batch_size=test_batch_size,
      use_bfloat16=FLAGS.use_bfloat16)
  train_dataset = strategy.experimental_distribute_dataset(train_dataset)
  test_datasets = {
      'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
  }
  if FLAGS.corruptions_interval > 0:
    if FLAGS.dataset == 'cifar10':
      load_c_dataset = utils.load_cifar10_c
    else:
      load_c_dataset = functools.partial(utils.load_cifar100_c,
                                         path=FLAGS.cifar100_c_path)
    corruption_types, max_intensity = utils.load_corrupted_test_info(
        FLAGS.dataset)
    for corruption in corruption_types:
      for intensity in range(1, max_intensity + 1):
        dataset = load_c_dataset(
            corruption_name=corruption,
            corruption_intensity=intensity,
            batch_size=test_batch_size,
            use_bfloat16=FLAGS.use_bfloat16)
        test_datasets['{0}_{1}'.format(corruption, intensity)] = (
            strategy.experimental_distribute_dataset(dataset))

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras model')
    model = cifar_model.wide_resnet(
        input_shape=[FLAGS.ensemble_size] +
        list(ds_info.features['image'].shape),
        depth=28,
        width_multiplier=FLAGS.width_multiplier,
        num_classes=num_classes,
        ensemble_size=FLAGS.ensemble_size)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate * train_batch_size / 128
    lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                       for start_epoch_str in FLAGS.lr_decay_epochs]
    lr_schedule = utils.LearningRateSchedule(steps_per_epoch,
                                             base_lr,
                                             FLAGS.lr_decay_ratio,
                                             lr_decay_epochs,
                                             FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=0.9, nesterov=True)
    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss': tf.keras.metrics.Mean(),
        'train/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
    }
    if FLAGS.corruptions_interval > 0:
      corrupt_metrics = {}
      for intensity in range(1, max_intensity + 1):
        for corruption in corruption_types:
          dataset_name = '{0}_{1}'.format(corruption, intensity)
          corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
              um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

    for i in range(FLAGS.ensemble_size):
      metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
      metrics['test/accuracy_member_{}'.format(i)] = (
          tf.keras.metrics.SparseCategoricalAccuracy())
    test_diversity = {
        'test/disagreement': tf.keras.metrics.Mean(),
        'test/average_kl': tf.keras.metrics.Mean(),
        'test/cosine_similarity': tf.keras.metrics.Mean(),
    }

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  @tf.function
  def train_step(iterator):
    """Runs one distributed training step on the next batch of `iterator`."""

    def step_fn(inputs):
      """Per-replica training step."""
      images, labels = inputs
      batch_size = tf.shape(images)[0]

      # Repeat every example FLAGS.batch_repetitions times, then shuffle.
      main_shuffle = tf.random.shuffle(tf.tile(
          tf.range(batch_size), [FLAGS.batch_repetitions]))
      # Only the leading (1 - input_repetition_probability) fraction of the
      # order is reshuffled per member; the tail is shared by all members, so
      # those positions feed the same example to every member.
      to_shuffle = tf.cast(tf.cast(tf.shape(main_shuffle)[0], tf.float32)
                           * (1. - FLAGS.input_repetition_probability),
                           tf.int32)
      shuffle_indices = [
          tf.concat([tf.random.shuffle(main_shuffle[:to_shuffle]),
                     main_shuffle[to_shuffle:]], axis=0)
          for _ in range(FLAGS.ensemble_size)]
      # One view per member, stacked along axis 1 (the model's ensemble axis).
      images = tf.stack([tf.gather(images, indices, axis=0)
                         for indices in shuffle_indices], axis=1)
      labels = tf.stack([tf.gather(labels, indices, axis=0)
                         for indices in shuffle_indices], axis=1)

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        # Sum the NLL over ensemble members (axis 1), average over the batch.
        negative_log_likelihood = tf.reduce_mean(tf.reduce_sum(
            tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True), axis=1))

        filtered_variables = []
        for var in model.trainable_variables:
          # Apply l2 on kernels, BN parameters and bias terms.
          if ('kernel' in var.name or 'batch_norm' in var.name or
              'bias' in var.name):
            filtered_variables.append(tf.reshape(var, (-1,)))

        # tf.nn.l2_loss returns half the sum of squares; the factor 2 undoes
        # that halving.
        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
        loss = negative_log_likelihood + l2_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      # Flatten (batch, member) so each member's prediction is scored as an
      # independent example.
      probs = tf.nn.softmax(tf.reshape(logits, [-1, num_classes]))
      flat_labels = tf.reshape(labels, [-1])
      metrics['train/ece'].update_state(flat_labels, probs)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(flat_labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_name):
    """Runs one distributed evaluation step and updates metrics.

    `dataset_name` is passed as a Python string, so the `== 'clean'` branches
    below are resolved when the function is traced.
    """

    def step_fn(inputs):
      """Per-replica evaluation step."""
      images, labels = inputs
      # Every member sees the same image at test time.
      images = tf.tile(
          tf.expand_dims(images, 1), [1, FLAGS.ensemble_size, 1, 1, 1])
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
      probs = tf.nn.softmax(logits)

      if dataset_name == 'clean':
        per_probs = tf.transpose(probs, perm=[1, 0, 2])
        diversity_results = um.average_pairwise_diversity(
            per_probs, FLAGS.ensemble_size)
        for k, v in diversity_results.items():
          test_diversity['test/' + k].update_state(v)

        # Per-member metrics are reported as clean-test numbers, so they must
        # not absorb predictions from corrupted datasets.
        for i in range(FLAGS.ensemble_size):
          member_probs = probs[:, i]
          member_loss = tf.keras.losses.sparse_categorical_crossentropy(
              labels, member_probs)
          metrics['test/nll_member_{}'.format(i)].update_state(member_loss)
          metrics['test/accuracy_member_{}'.format(i)].update_state(
              labels, member_probs)

      # Negative log marginal likelihood computed in a numerically-stable way:
      # -log(mean_m p_m(y|x)) = -logsumexp_m(log p_m(y|x)) + log(M).
      labels_tiled = tf.tile(
          tf.expand_dims(labels, 1), [1, FLAGS.ensemble_size])
      log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
          labels_tiled, logits, from_logits=True)
      negative_log_likelihood = tf.reduce_mean(
          -tf.reduce_logsumexp(log_likelihoods, axis=[1]) +
          tf.math.log(float(FLAGS.ensemble_size)))
      probs = tf.math.reduce_mean(probs, axis=1)  # marginalize

      if dataset_name == 'clean':
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
      else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
            labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})
  # Report the diversity metrics with the rest. One merge is enough because
  # the metric objects are reset, not recreated, each epoch.
  metrics.update(test_diversity)

  train_iterator = iter(train_dataset)
  start_time = time.time()
  # Steps already done by a restored checkpoint. They are excluded from the
  # throughput estimate because `start_time` only covers this process.
  initial_step = initial_epoch * steps_per_epoch
  max_steps = steps_per_epoch * FLAGS.train_epochs
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    for step in range(steps_per_epoch):
      train_step(train_iterator)

      current_step = epoch * steps_per_epoch + (step + 1)
      time_elapsed = time.time() - start_time
      steps_per_sec = float(current_step - initial_step) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps,
                     epoch + 1,
                     FLAGS.train_epochs,
                     steps_per_sec,
                     eta_seconds / 60,
                     time_elapsed / 60))
      if step % 20 == 0:
        logging.info(message)

    evaluate_corruptions = (FLAGS.corruptions_interval > 0 and
                            (epoch + 1) % FLAGS.corruptions_interval == 0)
    datasets_to_evaluate = {'clean': test_datasets['clean']}
    if evaluate_corruptions:
      datasets_to_evaluate = test_datasets
    for dataset_name, test_dataset in datasets_to_evaluate.items():
      test_iterator = iter(test_dataset)
      logging.info('Testing on dataset %s', dataset_name)
      for step in range(steps_per_eval):
        if step % 20 == 0:
          logging.info('Starting to run eval step %s of epoch: %s', step,
                       epoch)
        test_start_time = time.time()
        test_step(test_iterator, dataset_name)
        # Seconds -> milliseconds (x1e3), divided by the global test batch.
        # NOTE(review): if the step is dispatched asynchronously on the
        # device, this wall-clock figure may not reflect compute time.
        ms_per_example = (
            (time.time() - test_start_time) * 1e3 / test_batch_size)
        metrics['test/ms_per_example'].update_state(ms_per_example)

      logging.info('Done with testing on %s', dataset_name)

    corrupt_results = {}
    if evaluate_corruptions:
      corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                        corruption_types,
                                                        max_intensity)

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    for i in range(FLAGS.ensemble_size):
      logging.info(
          'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
          metrics['test/nll_member_{}'.format(i)].result(),
          metrics['test/accuracy_member_{}'.format(i)].result() * 100)

    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)
    summary_writer.flush()

    for metric in metrics.values():
      metric.reset_states()
    if evaluate_corruptions:
      # corrupt_metrics live outside `metrics`; reset them so each corruption
      # evaluation reports only its own epoch rather than a running total.
      for metric in corrupt_metrics.values():
        metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)
def main(argv):
  """Trains and evaluates the rank-1 ResNet-32 configured by FLAGS.

  Builds `resnet_cifar_model.rank1_resnet_v1` (depth 32, width multiplier 4)
  with FLAGS.ensemble_size members. The training loss has three terms:
    * the NLL of the (optionally sample- and member-averaged) probabilities;
    * L2 on variables whose names contain 'kernel' or 'bias';
    * `sum(model.losses) / train_dataset_size` (named `kl` below), scaled up
      linearly over FLAGS.kl_annealing_steps optimizer steps.

  Side effects:
    * creates FLAGS.output_dir and resumes from the latest checkpoint found
      there (if any);
    * writes TensorBoard summaries to `<output_dir>/summaries`;
    * reports progress notes and per-epoch measurements to `work_unit`, a
      name defined outside this function;
    * saves checkpoints every FLAGS.checkpoint_interval epochs.

  Args:
    argv: Unused command-line arguments; all configuration is read from FLAGS.
  """
  del argv  # Unused arg.
  tf.random.set_seed(FLAGS.seed)

  if FLAGS.version2:
    # In version2 each example is replicated FLAGS.ensemble_size times as well
    # as num_*_samples times inside the steps, so shrink the loaded batch.
    per_core_bs_train = FLAGS.per_core_batch_size // (FLAGS.ensemble_size *
                                                      FLAGS.num_train_samples)
    per_core_bs_eval = FLAGS.per_core_batch_size // (FLAGS.ensemble_size *
                                                     FLAGS.num_eval_samples)
  else:
    per_core_bs_train = FLAGS.per_core_batch_size // FLAGS.num_train_samples
    per_core_bs_eval = FLAGS.per_core_batch_size // FLAGS.num_eval_samples
  batch_size_train = per_core_bs_train * FLAGS.num_cores
  batch_size_eval = per_core_bs_eval * FLAGS.num_cores

  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

  train_dataset = utils.load_dataset(
      split=tfds.Split.TRAIN,
      name=FLAGS.dataset,
      batch_size=batch_size_train,
      use_bfloat16=FLAGS.use_bfloat16,
      normalize=False)
  clean_test_dataset = utils.load_dataset(
      split=tfds.Split.TEST,
      name=FLAGS.dataset,
      batch_size=batch_size_eval,
      use_bfloat16=FLAGS.use_bfloat16,
      normalize=False)
  train_dataset = strategy.experimental_distribute_dataset(train_dataset)
  test_datasets = {
      'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
  }
  if FLAGS.corruptions_interval > 0:
    if FLAGS.dataset == 'cifar10':
      load_c_dataset = utils.load_cifar10_c
    else:
      load_c_dataset = functools.partial(utils.load_cifar100_c,
                                         path=FLAGS.cifar100_c_path)
    corruption_types, max_intensity = utils.load_corrupted_test_info(
        FLAGS.dataset)
    for corruption in corruption_types:
      for intensity in range(1, max_intensity + 1):
        dataset = load_c_dataset(
            corruption_name=corruption,
            corruption_intensity=intensity,
            batch_size=batch_size_eval,
            use_bfloat16=FLAGS.use_bfloat16,
            normalize=False)
        test_datasets['{0}_{1}'.format(corruption, intensity)] = (
            strategy.experimental_distribute_dataset(dataset))

  ds_info = tfds.builder(FLAGS.dataset).info
  train_dataset_size = ds_info.splits['train'].num_examples
  test_dataset_size = ds_info.splits['test'].num_examples
  num_classes = ds_info.features['label'].num_classes

  steps_per_epoch = train_dataset_size // batch_size_train
  steps_per_eval = test_dataset_size // batch_size_eval

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-32 model')
    model = resnet_cifar_model.rank1_resnet_v1(
        input_shape=ds_info.features['image'].shape,
        depth=32,
        num_classes=num_classes,
        width_multiplier=4,
        alpha_initializer=FLAGS.alpha_initializer,
        gamma_initializer=FLAGS.gamma_initializer,
        alpha_regularizer=FLAGS.alpha_regularizer,
        gamma_regularizer=FLAGS.gamma_regularizer,
        use_additive_perturbation=FLAGS.use_additive_perturbation,
        ensemble_size=FLAGS.ensemble_size,
        random_sign_init=FLAGS.random_sign_init,
        dropout_rate=FLAGS.dropout_rate)
    # model.summary() prints and returns None; route its lines to the log.
    model.summary(print_fn=logging.info)
    base_lr = FLAGS.base_learning_rate * batch_size_train / 128
    lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                       for start_epoch_str in FLAGS.lr_decay_epochs]
    lr_schedule = utils.LearningRateSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=0.9, nesterov=True)
    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss': tf.keras.metrics.Mean(),
        'train/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/loss': tf.keras.metrics.Mean(),
    }
    if FLAGS.corruptions_interval > 0:
      corrupt_metrics = {}
      for intensity in range(1, max_intensity + 1):
        for corruption in corruption_types:
          dataset_name = '{0}_{1}'.format(corruption, intensity)
          corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
              um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

    test_diversity = {}
    training_diversity = {}
    if FLAGS.ensemble_size > 1:
      for i in range(FLAGS.ensemble_size):
        metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
        metrics['test/accuracy_member_{}'.format(i)] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
      test_diversity = {
          'test/disagreement': tf.keras.metrics.Mean(),
          'test/average_kl': tf.keras.metrics.Mean(),
          'test/cosine_similarity': tf.keras.metrics.Mean(),
      }
      training_diversity = {
          'train/disagreement': tf.keras.metrics.Mean(),
          'train/average_kl': tf.keras.metrics.Mean(),
          'train/cosine_similarity': tf.keras.metrics.Mean(),
      }

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  @tf.function
  def train_step(iterator):
    """Runs one distributed training step on the next batch of `iterator`."""

    def step_fn(inputs):
      """Per-replica training step."""
      images, labels = inputs
      if FLAGS.version2 and FLAGS.ensemble_size > 1:
        images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
        # Labels are only tiled when the loss is computed per member; the
        # member-sampling and expected-probs paths reduce probs back to the
        # original batch size first.
        if not (FLAGS.member_sampling or FLAGS.expected_probs):
          labels = tf.tile(labels, [FLAGS.ensemble_size])

      if FLAGS.num_train_samples > 1:
        images = tf.tile(images, [FLAGS.num_train_samples, 1, 1, 1])

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          # Compute probabilities and losses in float32; a no-op when the
          # logits are already float32.
          logits = tf.cast(logits, tf.float32)
        probs = tf.nn.softmax(logits)
        # Diversity evaluation.
        if FLAGS.version2 and FLAGS.ensemble_size > 1:
          per_probs = tf.reshape(
              probs, tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
          diversity_results = um.average_pairwise_diversity(
              per_probs, FLAGS.ensemble_size)

        if FLAGS.num_train_samples > 1:
          probs = tf.reshape(probs,
                             tf.concat([[FLAGS.num_train_samples, -1],
                                        probs.shape[1:]], 0))
          probs = tf.reduce_mean(probs, 0)

        if FLAGS.member_sampling and FLAGS.version2 and FLAGS.ensemble_size > 1:
          # Pick one random member per step via a one-hot matmul over the
          # leading member axis.
          idx = tf.random.uniform([], maxval=FLAGS.ensemble_size,
                                  dtype=tf.int64)
          idx_one_hot = tf.expand_dims(tf.one_hot(idx, FLAGS.ensemble_size,
                                                  dtype=probs.dtype), 0)
          probs_shape = probs.shape
          probs = tf.reshape(probs, [FLAGS.ensemble_size, -1])
          probs = tf.matmul(idx_one_hot, probs)
          probs = tf.reshape(probs, tf.concat([[-1], probs_shape[1:]], 0))

        elif FLAGS.expected_probs and FLAGS.version2 and FLAGS.ensemble_size > 1:
          probs = tf.reshape(probs,
                             tf.concat([[FLAGS.ensemble_size, -1],
                                        probs.shape[1:]], 0))
          probs = tf.reduce_mean(probs, 0)

        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

        filtered_variables = []
        for var in model.trainable_variables:
          # Apply l2 on the slow weights and bias terms. This excludes BN
          # parameters and fast weight approximate posterior/prior parameters,
          # but pay caution to their naming scheme.
          if 'kernel' in var.name or 'bias' in var.name:
            filtered_variables.append(tf.reshape(var, (-1,)))

        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
        kl = sum(model.losses) / train_dataset_size
        # Linear annealing: the weight grows with the step count and is
        # capped at 1 after FLAGS.kl_annealing_steps steps.
        kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
        kl_scale /= FLAGS.kl_annealing_steps
        kl_scale = tf.minimum(1., kl_scale)
        kl_loss = kl_scale * kl

        loss = negative_log_likelihood + l2_loss + kl_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)

      # Separate learning rate implementation.
      grad_list = []
      if FLAGS.fast_weight_lr_multiplier != 1.0:
        grads_and_vars = list(zip(grads, model.trainable_variables))
        for vec, var in grads_and_vars:
          # Apply a different learning rate to the fast weight approximate
          # posterior/prior parameters. This excludes BN and slow weights,
          # but pay caution to the naming scheme.
          # NOTE(review): bias terms are not excluded here, so they also get
          # the multiplier; confirm this is intended.
          if ('batch_norm' not in var.name and 'kernel' not in var.name):
            grad_list.append((vec * FLAGS.fast_weight_lr_multiplier, var))
          else:
            grad_list.append((vec, var))
        optimizer.apply_gradients(grad_list)
      else:
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

      metrics['train/ece'].update_state(labels, probs)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, probs)
      if FLAGS.version2 and FLAGS.ensemble_size > 1:
        for k, v in diversity_results.items():
          training_diversity['train/' + k].update_state(v)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_name):
    """Runs one distributed evaluation step and updates metrics.

    `dataset_name` is passed as a Python string, so the `== 'clean'` branches
    below are resolved when the function is traced.
    """

    def step_fn(inputs):
      """Per-replica evaluation step."""
      images, labels = inputs
      if FLAGS.ensemble_size > 1:
        images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
      if FLAGS.num_eval_samples > 1:
        images = tf.tile(images, [FLAGS.num_eval_samples, 1, 1, 1])
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        # Compute probabilities and losses in float32; a no-op when the
        # logits are already float32.
        logits = tf.cast(logits, tf.float32)
      probs = tf.nn.softmax(logits)

      if FLAGS.num_eval_samples > 1:
        probs = tf.reshape(probs,
                           tf.concat([[FLAGS.num_eval_samples, -1],
                                      probs.shape[1:]], 0))
        probs = tf.reduce_mean(probs, 0)

      if FLAGS.ensemble_size > 1:
        per_probs = tf.split(probs,
                             num_or_size_splits=FLAGS.ensemble_size,
                             axis=0)
        if dataset_name == 'clean':
          per_probs_tensor = tf.reshape(
              probs, tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
          diversity_results = um.average_pairwise_diversity(
              per_probs_tensor, FLAGS.ensemble_size)
          for k, v in diversity_results.items():
            test_diversity['test/' + k].update_state(v)

          # Per-member metrics are reported as clean-test numbers, so they
          # must not absorb predictions from corrupted datasets.
          for i in range(FLAGS.ensemble_size):
            member_probs = per_probs[i]
            member_nll = tf.keras.losses.sparse_categorical_crossentropy(
                labels, member_probs)
            metrics['test/nll_member_{}'.format(i)].update_state(member_nll)
            metrics['test/accuracy_member_{}'.format(i)].update_state(
                labels, member_probs)

        probs = tf.reduce_mean(per_probs, axis=0)

      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
      filtered_variables = []
      for var in model.trainable_variables:
        if 'kernel' in var.name or 'bias' in var.name:
          filtered_variables.append(tf.reshape(var, (-1,)))
      kl = sum(model.losses) / test_dataset_size
      # Despite its name, `l2_loss` here also includes the (unannealed) kl
      # term, so `test/loss` = NLL + kl + L2.
      l2_loss = kl + FLAGS.l2 * 2 * tf.nn.l2_loss(
          tf.concat(filtered_variables, axis=0))
      loss = negative_log_likelihood + l2_loss
      if dataset_name == 'clean':
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
        metrics['test/loss'].update_state(loss)
      else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
            labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  train_iterator = iter(train_dataset)
  start_time = time.time()
  # Steps already done by a restored checkpoint. They are excluded from the
  # throughput estimate because `start_time` only covers this process.
  initial_step = initial_epoch * steps_per_epoch
  max_steps = steps_per_epoch * FLAGS.train_epochs
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    for step in range(steps_per_epoch):
      train_step(train_iterator)

      current_step = epoch * steps_per_epoch + (step + 1)
      time_elapsed = time.time() - start_time
      steps_per_sec = float(current_step - initial_step) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps,
                     epoch + 1,
                     FLAGS.train_epochs,
                     steps_per_sec,
                     eta_seconds / 60,
                     time_elapsed / 60))
      work_unit.set_notes(message)
      if step % 20 == 0:
        logging.info(message)

    evaluate_corruptions = (FLAGS.corruptions_interval > 0 and
                            (epoch + 1) % FLAGS.corruptions_interval == 0)
    datasets_to_evaluate = {'clean': test_datasets['clean']}
    if evaluate_corruptions:
      datasets_to_evaluate = test_datasets
    for dataset_name, test_dataset in datasets_to_evaluate.items():
      test_iterator = iter(test_dataset)
      logging.info('Testing on dataset %s', dataset_name)
      for step in range(steps_per_eval):
        if step % 20 == 0:
          logging.info('Starting to run eval step %s of epoch: %s', step,
                       epoch)
        test_step(test_iterator, dataset_name)
      logging.info('Done with testing on %s', dataset_name)

    corrupt_results = {}
    if evaluate_corruptions:
      corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                        corruption_types,
                                                        max_intensity)

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    # Per-member metrics are only created when ensemble_size > 1; without this
    # guard an ensemble of one raises KeyError here.
    if FLAGS.ensemble_size > 1:
      for i in range(FLAGS.ensemble_size):
        logging.info('Member %d Test Loss: %.4f, Accuracy: %.2f%%',
                     i, metrics['test/nll_member_{}'.format(i)].result(),
                     metrics['test/accuracy_member_{}'.format(i)].result() * 100)

    # Materialize as a list. It is iterated twice (results, then reset); a
    # bare itertools.chain would be exhausted after the first pass and no
    # metric would ever be reset.
    total_metrics = list(itertools.chain(metrics.items(),
                                         training_diversity.items(),
                                         test_diversity.items()))
    total_results = {name: metric.result() for name, metric in total_metrics}
    total_results.update(corrupt_results)
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)
    for name, result in total_results.items():
      name = name.replace('/', '_')
      if 'negative_log_likelihood' in name:
        # Plots sort WIDs from high-to-low so look at maximization objectives.
        name = name.replace('negative_log_likelihood', 'log_likelihood')
        result = -result
      objective = work_unit.get_measurement_series(name)
      objective.create_measurement(result, epoch + 1)

    for _, metric in total_metrics:
      metric.reset_states()
    if evaluate_corruptions:
      # corrupt_metrics live outside total_metrics; reset them so each
      # corruption evaluation reports only its own epoch.
      for metric in corrupt_metrics.values():
        metric.reset_states()
    summary_writer.flush()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)
def main(argv): del argv # unused arg tf.io.gfile.makedirs(FLAGS.output_dir) logging.info('Saving checkpoints at %s', FLAGS.output_dir) tf.random.set_seed(FLAGS.seed) if FLAGS.use_gpu: logging.info('Use GPU') strategy = tf.distribute.MirroredStrategy() else: logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local') resolver = tf.distribute.cluster_resolver.TPUClusterResolver( tpu=FLAGS.tpu) tf.config.experimental_connect_to_cluster(resolver) tf.tpu.experimental.initialize_tpu_system(resolver) strategy = tf.distribute.TPUStrategy(resolver) ds_info = tfds.builder(FLAGS.dataset).info batch_size = ((FLAGS.per_core_batch_size // FLAGS.ensemble_size) * FLAGS.num_cores) train_dataset_size = ds_info.splits['train'].num_examples steps_per_epoch = train_dataset_size // batch_size test_dataset_size = ds_info.splits['test'].num_examples steps_per_eval = test_dataset_size // batch_size num_classes = ds_info.features['label'].num_classes train_dataset = utils.load_dataset(split=tfds.Split.TRAIN, name=FLAGS.dataset, batch_size=batch_size, use_bfloat16=FLAGS.use_bfloat16) clean_test_dataset = utils.load_dataset(split=tfds.Split.TEST, name=FLAGS.dataset, batch_size=batch_size, use_bfloat16=FLAGS.use_bfloat16) train_dataset = strategy.experimental_distribute_dataset(train_dataset) test_datasets = { 'clean': strategy.experimental_distribute_dataset(clean_test_dataset), } if FLAGS.corruptions_interval > 0: if FLAGS.dataset == 'cifar10': load_c_dataset = utils.load_cifar10_c else: load_c_dataset = functools.partial(utils.load_cifar100_c, path=FLAGS.cifar100_c_path) corruption_types, max_intensity = utils.load_corrupted_test_info( FLAGS.dataset) for corruption in corruption_types: for intensity in range(1, max_intensity + 1): dataset = load_c_dataset(corruption_name=corruption, corruption_intensity=intensity, batch_size=batch_size, use_bfloat16=FLAGS.use_bfloat16) test_datasets['{0}_{1}'.format(corruption, intensity)] = ( 
strategy.experimental_distribute_dataset(dataset)) if FLAGS.use_bfloat16: policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16') tf.keras.mixed_precision.experimental.set_policy(policy) summary_writer = tf.summary.create_file_writer( os.path.join(FLAGS.output_dir, 'summaries')) with strategy.scope(): logging.info('Building Keras model') model = cifar_model.wide_resnet( input_shape=ds_info.features['image'].shape, depth=28, width_multiplier=10, num_classes=num_classes, alpha_initializer=FLAGS.alpha_initializer, gamma_initializer=FLAGS.gamma_initializer, alpha_regularizer=FLAGS.alpha_regularizer, gamma_regularizer=FLAGS.gamma_regularizer, use_additive_perturbation=FLAGS.use_additive_perturbation, ensemble_size=FLAGS.ensemble_size, random_sign_init=FLAGS.random_sign_init, dropout_rate=FLAGS.dropout_rate, prior_mean=FLAGS.prior_mean, prior_stddev=FLAGS.prior_stddev) logging.info('Model input shape: %s', model.input_shape) logging.info('Model output shape: %s', model.output_shape) logging.info('Model number of weights: %s', model.count_params()) # Linearly scale learning rate and the decay epochs by vanilla settings. 
base_lr = FLAGS.base_learning_rate * batch_size / 128 lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200 for start_epoch_str in FLAGS.lr_decay_epochs] lr_schedule = refining.LearningRateScheduleWithRefining( steps_per_epoch, base_lr, decay_ratio=FLAGS.lr_decay_ratio, decay_epochs=lr_decay_epochs, warmup_epochs=FLAGS.lr_warmup_epochs, train_epochs=FLAGS.train_epochs, refining_learning_rate=FLAGS.refining_learning_rate) optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=0.9, nesterov=True) metrics = { 'train/negative_log_likelihood': tf.keras.metrics.Mean(), 'train/kl': tf.keras.metrics.Mean(), 'train/kl_scale': tf.keras.metrics.Mean(), 'train/elbo': tf.keras.metrics.Mean(), 'train/loss': tf.keras.metrics.Mean(), 'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'train/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins), 'test/negative_log_likelihood': tf.keras.metrics.Mean(), 'test/kl': tf.keras.metrics.Mean(), 'test/elbo': tf.keras.metrics.Mean(), 'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins), } if FLAGS.ensemble_size > 1: for i in range(FLAGS.ensemble_size): metrics['test/nll_member_{}'.format( i)] = tf.keras.metrics.Mean() metrics['test/accuracy_member_{}'.format(i)] = ( tf.keras.metrics.SparseCategoricalAccuracy()) if FLAGS.corruptions_interval > 0: corrupt_metrics = {} for intensity in range(1, max_intensity + 1): for corruption in corruption_types: dataset_name = '{0}_{1}'.format(corruption, intensity) corrupt_metrics['test/nll_{}'.format(dataset_name)] = ( tf.keras.metrics.Mean()) corrupt_metrics['test/kl_{}'.format(dataset_name)] = ( tf.keras.metrics.Mean()) corrupt_metrics['test/elbo_{}'.format(dataset_name)] = ( tf.keras.metrics.Mean()) corrupt_metrics['test/accuracy_{}'.format( dataset_name)] = ( tf.keras.metrics.SparseCategoricalAccuracy()) corrupt_metrics['test/ece_{}'.format(dataset_name)] = ( 
um.ExpectedCalibrationError(num_bins=FLAGS.num_bins)) checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir) initial_epoch = 0 if latest_checkpoint: # checkpoint.restore must be within a strategy.scope() so that optimizer # slot variables are mirrored. checkpoint.restore(latest_checkpoint) logging.info('Loaded checkpoint %s', latest_checkpoint) initial_epoch = optimizer.iterations.numpy() // steps_per_epoch def compute_l2_loss(model): filtered_variables = [] for var in model.trainable_variables: # Apply l2 on the BN parameters and bias terms. This # excludes only fast weight approximate posterior/prior parameters, # but pay caution to their naming scheme. if ('kernel' in var.name or 'batch_norm' in var.name or 'bias' in var.name): filtered_variables.append(tf.reshape(var, (-1, ))) l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss( tf.concat(filtered_variables, axis=0)) return l2_loss @tf.function def train_step(iterator): """Training StepFn.""" def step_fn(inputs): """Per-Replica StepFn.""" images, labels = inputs if FLAGS.ensemble_size > 1: images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1]) labels = tf.tile(labels, [FLAGS.ensemble_size]) with tf.GradientTape() as tape: logits = model(images, training=True) if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) negative_log_likelihood = tf.reduce_mean( tf.keras.losses.sparse_categorical_crossentropy( labels, logits, from_logits=True)) l2_loss = compute_l2_loss(model) kl = sum(model.losses) / train_dataset_size kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype) kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs kl_scale = tf.minimum(1., kl_scale) kl_loss = kl_scale * kl # Scale the loss given the TPUStrategy will reduce sum all gradients. 
loss = negative_log_likelihood + l2_loss + kl_loss scaled_loss = loss / strategy.num_replicas_in_sync elbo = -(negative_log_likelihood + l2_loss + kl) grads = tape.gradient(scaled_loss, model.trainable_variables) # Separate learning rate implementation. if FLAGS.fast_weight_lr_multiplier != 1.0: grads_and_vars = [] for grad, var in zip(grads, model.trainable_variables): # Apply different learning rate on the fast weight approximate # posterior/prior parameters. This is excludes BN and slow weights, # but pay caution to the naming scheme. if ('kernel' not in var.name and 'batch_norm' not in var.name and 'bias' not in var.name): grads_and_vars.append( (grad * FLAGS.fast_weight_lr_multiplier, var)) else: grads_and_vars.append((grad, var)) optimizer.apply_gradients(grads_and_vars) else: optimizer.apply_gradients(zip(grads, model.trainable_variables)) probs = tf.nn.softmax(logits) metrics['train/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['train/kl'].update_state(kl) metrics['train/kl_scale'].update_state(kl_scale) metrics['train/elbo'].update_state(elbo) metrics['train/loss'].update_state(loss) metrics['train/accuracy'].update_state(labels, probs) metrics['train/ece'].update_state(labels, probs) strategy.run(step_fn, args=(next(iterator), )) @tf.function def test_step(iterator, dataset_name): """Evaluation StepFn.""" def step_fn(inputs): """Per-Replica StepFn.""" images, labels = inputs if FLAGS.ensemble_size > 1: images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1]) logits = tf.reshape([ model(images, training=False) for _ in range(FLAGS.num_eval_samples) ], [FLAGS.num_eval_samples, FLAGS.ensemble_size, -1, num_classes]) if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) probs = tf.nn.softmax(logits) if FLAGS.ensemble_size > 1: per_probs = tf.reduce_mean(probs, axis=0) # marginalize samples for i in range(FLAGS.ensemble_size): member_probs = per_probs[i] member_loss = tf.keras.losses.sparse_categorical_crossentropy( labels, 
member_probs) metrics['test/nll_member_{}'.format(i)].update_state( member_loss) metrics['test/accuracy_member_{}'.format(i)].update_state( labels, member_probs) # Negative log marginal likelihood computed in a numerically-stable way. labels_broadcasted = tf.broadcast_to( labels, [FLAGS.num_eval_samples, FLAGS.ensemble_size, labels.shape[0]]) log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy( labels_broadcasted, logits, from_logits=True) negative_log_likelihood = tf.reduce_mean( -tf.reduce_logsumexp(log_likelihoods, axis=[0, 1]) + tf.math.log(float(FLAGS.num_eval_samples * FLAGS.ensemble_size))) probs = tf.math.reduce_mean(probs, axis=[0, 1]) # marginalize l2_loss = compute_l2_loss(model) kl = sum(model.losses) / test_dataset_size elbo = -(negative_log_likelihood + l2_loss + kl) if dataset_name == 'clean': metrics['test/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['test/kl'].update_state(kl) metrics['test/elbo'].update_state(elbo) metrics['test/accuracy'].update_state(labels, probs) metrics['test/ece'].update_state(labels, probs) else: corrupt_metrics['test/nll_{}'.format( dataset_name)].update_state(negative_log_likelihood) corrupt_metrics['test/kl_{}'.format( dataset_name)].update_state(kl) corrupt_metrics['test/elbo_{}'.format( dataset_name)].update_state(elbo) corrupt_metrics['test/accuracy_{}'.format( dataset_name)].update_state(labels, probs) corrupt_metrics['test/ece_{}'.format( dataset_name)].update_state(labels, probs) strategy.run(step_fn, args=(next(iterator), )) train_iterator = iter(train_dataset) start_time = time.time() for epoch in range(initial_epoch, FLAGS.train_epochs + FLAGS.refining_epochs): logging.info('Starting to run epoch: %s', epoch) if epoch in np.linspace(FLAGS.train_epochs, FLAGS.train_epochs + FLAGS.refining_epochs, FLAGS.num_auxiliary_variables, dtype=int): logging.info('Sampling auxiliary variables with ratio %f', FLAGS.auxiliary_variance_ratio) refining.sample_rank1_auxiliaries(model, 
FLAGS.auxiliary_variance_ratio) if FLAGS.freeze_weights_during_refining: refining.freeze_rank1_weights(model) for step in range(steps_per_epoch): train_step(train_iterator) current_step = epoch * steps_per_epoch + (step + 1) max_steps = steps_per_epoch * (FLAGS.train_epochs + FLAGS.refining_epochs) time_elapsed = time.time() - start_time steps_per_sec = float(current_step) / time_elapsed eta_seconds = (max_steps - current_step) / steps_per_sec message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. ' 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format( current_step / max_steps, epoch + 1, FLAGS.train_epochs + FLAGS.refining_epochs, steps_per_sec, eta_seconds / 60, time_elapsed / 60)) if step % 20 == 0: logging.info(message) datasets_to_evaluate = {'clean': test_datasets['clean']} if (FLAGS.corruptions_interval > 0 and (epoch + 1) % FLAGS.corruptions_interval == 0): datasets_to_evaluate = test_datasets for dataset_name, test_dataset in datasets_to_evaluate.items(): test_iterator = iter(test_dataset) logging.info('Testing on dataset %s', dataset_name) for step in range(steps_per_eval): if step % 20 == 0: logging.info('Starting to run eval step %s of epoch: %s', step, epoch) test_step(test_iterator, dataset_name) logging.info('Done with testing on %s', dataset_name) corrupt_results = {} if (FLAGS.corruptions_interval > 0 and (epoch + 1) % FLAGS.corruptions_interval == 0): corrupt_results = utils.aggregate_corrupt_metrics( corrupt_metrics, corruption_types, max_intensity) logging.info('Train Loss: %.4f, Accuracy: %.2f%%', metrics['train/loss'].result(), metrics['train/accuracy'].result() * 100) logging.info('Test NLL: %.4f, Accuracy: %.2f%%', metrics['test/negative_log_likelihood'].result(), metrics['test/accuracy'].result() * 100) if FLAGS.ensemble_size > 1: for i in range(FLAGS.ensemble_size): logging.info( 'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i, metrics['test/nll_member_{}'.format(i)].result(), metrics['test/accuracy_member_{}'.format(i)].result() * 
100) total_results = { name: metric.result() for name, metric in metrics.items() } total_results.update(corrupt_results) with summary_writer.as_default(): for name, result in total_results.items(): tf.summary.scalar(name, result, step=epoch + 1) for metric in metrics.values(): metric.reset_states() if (FLAGS.checkpoint_interval > 0 and (epoch + 1) % FLAGS.checkpoint_interval == 0): checkpoint_name = checkpoint.save( os.path.join(FLAGS.output_dir, 'checkpoint')) logging.info('Saved checkpoint to %s', checkpoint_name) final_checkpoint_name = checkpoint.save( os.path.join(FLAGS.output_dir, 'checkpoint')) logging.info('Saved last checkpoint to %s', final_checkpoint_name)
def _class_calibration_gap(preds, labels, focus_class):
  """Returns accuracy minus confidence on the examples labelled `focus_class`.

  Args:
    preds: Predicted probabilities. Examples are selected along axis 1 and
      classes lie on the last axis (the caller passes the concatenated
      `[1, num_examples, num_classes]` validation predictions).
    labels: Integer labels, one per example on axis 1 of `preds`.
    focus_class: The class whose examples are evaluated.

  Returns:
    `accuracy - confidence` reduced over the example axis. Positive values
    mean the model is under-confident on this class. If no example carries
    `focus_class`, the means are taken over an empty axis (NaN result).
  """
  class_preds = tf.boolean_mask(preds, labels == focus_class, axis=1)
  class_pred_labels = tf.argmax(class_preds, axis=-1)
  confidence = tf.reduce_mean(tf.reduce_max(class_preds, axis=-1), -1)
  accuracy = tf.reduce_mean(
      tf.cast(class_pred_labels == focus_class, tf.float32), axis=-1)
  return accuracy - confidence


def main(argv):
  """Trains and evaluates an SNGP Wide ResNet-28-10 on CIFAR.

  Sets up a GPU (MirroredStrategy) or TPU strategy and loads the train,
  optional validation, clean test and (optionally) corrupted test splits. It
  then builds the model via `ub.models.wide_resnet_sngp` and trains it with
  Nesterov SGD, resuming from the latest checkpoint in `FLAGS.output_dir` if
  one exists.

  Optional per-epoch data adaptation:
    * forget_mixup: tracks per-example forgetting counts, persists them as
      .npy files in `FLAGS.output_dir`, and rebuilds the training input with
      them.
    * adaptive_mixup: measures per-class calibration on the validation split
      and rebuilds the training input with per-class mixup coefficients.

  After each epoch it evaluates, writes TensorBoard summaries and
  periodically checkpoints.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If adaptive_mixup would run without a validation split.
  """
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  # The adaptive_mixup branch reads `val_dataset`/`steps_per_val`, which only
  # exist when --validation is set; fail now instead of after a full epoch.
  if (FLAGS.adaptive_mixup and not FLAGS.forget_mixup and
      not FLAGS.validation):
    raise ValueError('adaptive_mixup requires validation to be enabled.')

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

  aug_params = {
      'augmix': FLAGS.augmix,
      'aug_count': FLAGS.aug_count,
      'augmix_depth': FLAGS.augmix_depth,
      'augmix_prob_coeff': FLAGS.augmix_prob_coeff,
      'augmix_width': FLAGS.augmix_width,
      'ensemble_size': 1,
      'mixup_alpha': FLAGS.mixup_alpha,
      'adaptive_mixup': FLAGS.adaptive_mixup,
      'random_augment': FLAGS.random_augment,
      'forget_mixup': FLAGS.forget_mixup,
      'num_cores': FLAGS.num_cores,
      'threshold': FLAGS.forget_threshold,
  }
  # train_step tiles every training image num_dropout_samples_training times,
  # so the loaded global training batch is shrunk by that factor.
  batch_size = (FLAGS.per_core_batch_size * FLAGS.num_cores //
                FLAGS.num_dropout_samples_training)
  # Test/validation data is not tiled and is loaded with the full batch.
  test_batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  train_input_fn = data_utils.load_input_fn(
      split=tfds.Split.TRAIN,
      name=FLAGS.dataset,
      batch_size=batch_size,
      use_bfloat16=FLAGS.use_bfloat16,
      proportion=FLAGS.train_proportion,
      validation_set=FLAGS.validation,
      aug_params=aug_params)
  if FLAGS.validation:
    validation_input_fn = data_utils.load_input_fn(
        split=tfds.Split.VALIDATION,
        name=FLAGS.dataset,
        batch_size=FLAGS.per_core_batch_size,
        use_bfloat16=FLAGS.use_bfloat16,
        validation_set=True)
    val_dataset = strategy.experimental_distribute_datasets_from_function(
        validation_input_fn)
  clean_test_dataset = utils.load_dataset(
      split=tfds.Split.TEST,
      name=FLAGS.dataset,
      batch_size=test_batch_size,
      use_bfloat16=FLAGS.use_bfloat16)
  train_dataset = strategy.experimental_distribute_dataset(train_input_fn())
  test_datasets = {
      'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
  }
  if FLAGS.corruptions_interval > 0:
    if FLAGS.dataset == 'cifar10':
      load_c_dataset = utils.load_cifar10_c
    else:
      load_c_dataset = functools.partial(utils.load_cifar100_c,
                                         path=FLAGS.cifar100_c_path)
    corruption_types, max_intensity = utils.load_corrupted_test_info(
        FLAGS.dataset)
    for corruption in corruption_types:
      for intensity in range(1, max_intensity + 1):
        dataset = load_c_dataset(
            corruption_name=corruption,
            corruption_intensity=intensity,
            batch_size=test_batch_size,
            use_bfloat16=FLAGS.use_bfloat16)
        test_datasets['{0}_{1}'.format(corruption, intensity)] = (
            strategy.experimental_distribute_dataset(dataset))

  ds_info = tfds.builder(FLAGS.dataset).info
  num_train_examples = ds_info.splits['train'].num_examples
  # Train_proportion is a float so need to convert steps_per_epoch to int.
  if FLAGS.validation:
    # TODO(ywenxu): Remove hard-coding validation images.
    steps_per_epoch = int(
        (num_train_examples * FLAGS.train_proportion - 2500) // batch_size)
    steps_per_val = 2500 // test_batch_size
  else:
    steps_per_epoch = int(
        num_train_examples * FLAGS.train_proportion) // batch_size
  # Test datasets are batched with test_batch_size. Dividing by the smaller
  # training batch would request more eval steps than there are batches.
  steps_per_eval = ds_info.splits['test'].num_examples // test_batch_size
  num_classes = ds_info.features['label'].num_classes

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building ResNet model')
    if FLAGS.use_spec_norm:
      logging.info('Use Spectral Normalization with norm bound %.2f',
                   FLAGS.spec_norm_bound)
    if FLAGS.use_gp_layer:
      logging.info('Use GP layer with hidden units %d', FLAGS.gp_hidden_dim)

    model = ub.models.wide_resnet_sngp(
        input_shape=ds_info.features['image'].shape,
        batch_size=batch_size,
        depth=28,
        width_multiplier=10,
        num_classes=num_classes,
        l2=FLAGS.l2,
        use_mc_dropout=FLAGS.use_mc_dropout,
        dropout_rate=FLAGS.dropout_rate,
        use_gp_layer=FLAGS.use_gp_layer,
        gp_input_dim=FLAGS.gp_input_dim,
        gp_hidden_dim=FLAGS.gp_hidden_dim,
        gp_scale=FLAGS.gp_scale,
        gp_bias=FLAGS.gp_bias,
        gp_input_normalization=FLAGS.gp_input_normalization,
        gp_cov_discount_factor=FLAGS.gp_cov_discount_factor,
        gp_cov_ridge_penalty=FLAGS.gp_cov_ridge_penalty,
        use_spec_norm=FLAGS.use_spec_norm,
        spec_norm_iteration=FLAGS.spec_norm_iteration,
        spec_norm_bound=FLAGS.spec_norm_bound)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate * batch_size / 128
    lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                       for start_epoch_str in FLAGS.lr_decay_epochs]
    lr_schedule = utils.LearningRateSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                        momentum=0.9,
                                        nesterov=True)
    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss': tf.keras.metrics.Mean(),
        'train/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/stddev': tf.keras.metrics.Mean(),
    }
    if FLAGS.corruptions_interval > 0:
      corrupt_metrics = {}
      for intensity in range(1, max_intensity + 1):
        for corruption in corruption_types:
          dataset_name = '{0}_{1}'.format(corruption, intensity)
          corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
              um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
          corrupt_metrics['test/stddev_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  @tf.function
  def train_step(iterator):
    """Training StepFn.

    Returns the per-replica `(accuracy_counts, idx)` when forget_mixup is
    enabled, otherwise None.
    """
    def step_fn(inputs):
      """Per-Replica StepFn."""
      if FLAGS.forget_mixup:
        images, labels, idx = inputs
      else:
        images, labels = inputs
      if FLAGS.augmix and FLAGS.aug_count >= 1:
        # Index 0 at augmix preprocessing is the unperturbed image.
        images = images[:, 1, ...]
        # This is for the case of combining AugMix and Mixup.
        if FLAGS.mixup_alpha > 0:
          labels = tf.split(labels, FLAGS.aug_count + 1, axis=0)[1]
      images = tf.tile(images, [FLAGS.num_dropout_samples_training, 1, 1, 1])
      # With mixup the labels are 2-D (categorical cross-entropy below),
      # otherwise they are 1-D integer labels.
      if FLAGS.mixup_alpha > 0:
        labels = tf.tile(labels, [FLAGS.num_dropout_samples_training, 1])
      else:
        labels = tf.tile(labels, [FLAGS.num_dropout_samples_training])

      with tf.GradientTape() as tape:
        logits, _ = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        if FLAGS.mixup_alpha > 0:
          negative_log_likelihood = tf.reduce_mean(
              tf.keras.losses.categorical_crossentropy(
                  labels, logits, from_logits=True))
        else:
          negative_log_likelihood = tf.reduce_mean(
              tf.keras.losses.sparse_categorical_crossentropy(
                  labels, logits, from_logits=True))

        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + l2_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

      probs = tf.nn.softmax(logits)
      if FLAGS.mixup_alpha > 0:
        labels = tf.argmax(labels, axis=-1)
      metrics['train/ece'].update_state(labels, probs)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, logits)

      if FLAGS.forget_mixup:
        train_predictions = tf.argmax(probs, -1)
        labels = tf.cast(labels, train_predictions.dtype)
        # For each ensemble member (1 here), we accumulate the accuracy counts.
        accuracy_counts = tf.cast(
            tf.reshape((train_predictions == labels), [1, -1]), tf.float32)
        return accuracy_counts, idx

    if FLAGS.forget_mixup:
      return strategy.run(step_fn, args=(next(iterator),))
    else:
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_name):
    """Evaluation StepFn.

    For `dataset_name == 'validation'` no metrics are updated and the
    per-replica `(probs, labels)` are returned instead.
    """
    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs

      logits_list = []
      stddev_list = []
      for _ in range(FLAGS.num_dropout_samples):
        logits, covmat = model(images, training=False)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        logits = ed.layers.utils.mean_field_logits(
            logits, covmat, mean_field_factor=FLAGS.gp_mean_field_factor)
        stddev = tf.sqrt(tf.linalg.diag_part(covmat))

        stddev_list.append(stddev)
        logits_list.append(logits)

      # Logits dimension is (num_samples, batch_size, num_classes).
      logits_list = tf.stack(logits_list, axis=0)
      stddev_list = tf.stack(stddev_list, axis=0)

      stddev = tf.reduce_mean(stddev_list, axis=0)
      probs_list = tf.nn.softmax(logits_list)
      probs = tf.reduce_mean(probs_list, axis=0)

      # Negative log marginal likelihood over the dropout samples, computed
      # with logsumexp for numerical stability.
      labels_broadcasted = tf.broadcast_to(
          labels, [FLAGS.num_dropout_samples, labels.shape[0]])
      log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
          labels_broadcasted, logits_list, from_logits=True)
      negative_log_likelihood = tf.reduce_mean(
          -tf.reduce_logsumexp(log_likelihoods, axis=[0]) +
          tf.math.log(float(FLAGS.num_dropout_samples)))

      if dataset_name == 'clean':
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
        metrics['test/stddev'].update_state(stddev)
      elif dataset_name != 'validation':
        corrupt_metrics['test/nll_{}'.format(
            dataset_name)].update_state(negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(
            dataset_name)].update_state(labels, probs)
        corrupt_metrics['test/ece_{}'.format(
            dataset_name)].update_state(labels, probs)
        corrupt_metrics['test/stddev_{}'.format(
            dataset_name)].update_state(stddev)

      if dataset_name == 'validation':
        return tf.reshape(probs, [1, -1, num_classes]), labels

    if dataset_name == 'validation':
      return strategy.run(step_fn, args=(next(iterator),))
    else:
      strategy.run(step_fn, args=(next(iterator),))

  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  train_iterator = iter(train_dataset)
  forget_counts_history = []
  # Per-example forgetting state for forget_mixup. It is created at epoch 0 or,
  # when resuming at a later epoch, reloaded once from the .npy files that
  # every epoch writes to FLAGS.output_dir.
  forget_counts = None
  last_acc = None
  start_time = time.time()
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    acc_counts_list = []
    idx_list = []
    for step in range(steps_per_epoch):
      if FLAGS.forget_mixup:
        temp_accuracy_counts, temp_idx = train_step(train_iterator)
        acc_counts_list.append(temp_accuracy_counts)
        idx_list.append(temp_idx)
      else:
        train_step(train_iterator)

      current_step = epoch * steps_per_epoch + (step + 1)
      max_steps = steps_per_epoch * FLAGS.train_epochs
      time_elapsed = time.time() - start_time
      steps_per_sec = float(current_step) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps,
                     epoch + 1,
                     FLAGS.train_epochs,
                     steps_per_sec,
                     eta_seconds / 60,
                     time_elapsed / 60))
      if step % 20 == 0:
        logging.info(message)

    # Only one of the forget_mixup and adaptive_mixup can be true.
    if FLAGS.forget_mixup:
      # Merge per-replica outputs of every step into single tensors.
      current_acc = [
          tf.concat(list(replica_counts.values), axis=1)
          for replica_counts in acc_counts_list
      ]
      total_idx = [
          tf.concat(list(replica_idx.values), axis=0)
          for replica_idx in idx_list
      ]
      current_acc = tf.cast(tf.concat(current_acc, axis=1), tf.int32)
      total_idx = tf.concat(total_idx, axis=0)

      current_forget_path = os.path.join(FLAGS.output_dir, 'forget_counts.npy')
      last_acc_path = os.path.join(FLAGS.output_dir, 'last_acc.npy')
      if epoch == 0:
        forget_counts = tf.zeros([1, num_train_examples], dtype=tf.int32)
        last_acc = tf.zeros([1, num_train_examples], dtype=tf.int32)
      else:
        # Resumed run: restore the state saved at the end of the last epoch.
        if last_acc is None:
          with tf.io.gfile.GFile(last_acc_path, 'rb') as f:
            last_acc = np.load(f)
          last_acc = tf.cast(tf.convert_to_tensor(last_acc), tf.int32)
        if forget_counts is None:
          with tf.io.gfile.GFile(current_forget_path, 'rb') as f:
            forget_counts = np.load(f)
          forget_counts = tf.cast(
              tf.convert_to_tensor(forget_counts), tf.int32)

      # An example is "forgotten" if it was right last time but wrong now.
      selected_last_acc = tf.gather(last_acc, total_idx, axis=1)
      forget_this_epoch = tf.cast(current_acc < selected_last_acc, tf.int32)
      forget_this_epoch = tf.transpose(forget_this_epoch)
      target_shape = tf.constant([num_train_examples, 1])
      current_forget_counts = tf.scatter_nd(
          tf.reshape(total_idx, [-1, 1]), forget_this_epoch, target_shape)
      current_forget_counts = tf.transpose(current_forget_counts)
      acc_this_epoch = tf.transpose(current_acc)
      last_acc = tf.scatter_nd(
          tf.reshape(total_idx, [-1, 1]), acc_this_epoch, target_shape)
      # This is lower bound of true acc.
      last_acc = tf.transpose(last_acc)

      # TODO(ywenxu): We count the dropped examples as forget. Fix this later.
      forget_counts += current_forget_counts
      forget_counts_history.append(forget_counts)
      logging.info('forgetting counts')
      logging.info(tf.stack(forget_counts_history, 0))
      with tf.io.gfile.GFile(
          os.path.join(FLAGS.output_dir, 'forget_counts_history.npy'),
          'wb') as f:
        np.save(f, tf.stack(forget_counts_history, 0).numpy())
      with tf.io.gfile.GFile(current_forget_path, 'wb') as f:
        np.save(f, forget_counts.numpy())
      with tf.io.gfile.GFile(last_acc_path, 'wb') as f:
        np.save(f, last_acc.numpy())
      aug_params['forget_counts_dir'] = current_forget_path

      # Rebuild the training input with the same proportion as the initial
      # one so the data still matches steps_per_epoch.
      train_input_fn = data_utils.load_input_fn(
          split=tfds.Split.TRAIN,
          name=FLAGS.dataset,
          batch_size=batch_size,
          use_bfloat16=FLAGS.use_bfloat16,
          proportion=FLAGS.train_proportion,
          validation_set=FLAGS.validation,
          aug_params=aug_params)
      train_dataset = strategy.experimental_distribute_dataset(
          train_input_fn())
      train_iterator = iter(train_dataset)

    elif FLAGS.adaptive_mixup:
      val_iterator = iter(val_dataset)
      logging.info('Testing on validation dataset')
      predictions_list = []
      labels_list = []
      for _ in range(steps_per_val):
        temp_predictions, temp_labels = test_step(val_iterator, 'validation')
        predictions_list.append(temp_predictions)
        labels_list.append(temp_labels)
      predictions = [
          tf.concat(list(replica_preds.values), axis=1)
          for replica_preds in predictions_list
      ]
      labels = [
          tf.concat(list(replica_labels.values), axis=0)
          for replica_labels in labels_list
      ]
      predictions = tf.concat(predictions, axis=1)
      labels = tf.cast(tf.concat(labels, axis=0), tf.int64)

      calibration_per_class = [
          _class_calibration_gap(predictions, labels, i)
          for i in range(num_classes)
      ]
      calibration_per_class = tf.stack(calibration_per_class, axis=1)
      logging.info('calibration per class')
      logging.info(calibration_per_class)
      # Under-confident classes (positive gap) get coefficient 1.0; the rest
      # get mixup_alpha, clipped to [0, 1].
      mixup_coeff = tf.where(calibration_per_class > 0, 1.0, FLAGS.mixup_alpha)
      mixup_coeff = tf.clip_by_value(mixup_coeff, 0, 1)
      logging.info('mixup coeff')
      logging.info(mixup_coeff)
      aug_params['mixup_coeff'] = mixup_coeff
      train_input_fn = data_utils.load_input_fn(
          split=tfds.Split.TRAIN,
          name=FLAGS.dataset,
          batch_size=batch_size,
          use_bfloat16=FLAGS.use_bfloat16,
          proportion=FLAGS.train_proportion,
          validation_set=True,
          aug_params=aug_params)
      train_dataset = strategy.experimental_distribute_dataset(
          train_input_fn())
      train_iterator = iter(train_dataset)

    run_corruptions = (FLAGS.corruptions_interval > 0 and
                       (epoch + 1) % FLAGS.corruptions_interval == 0)
    datasets_to_evaluate = {'clean': test_datasets['clean']}
    if run_corruptions:
      datasets_to_evaluate = test_datasets
    for dataset_name, test_dataset in datasets_to_evaluate.items():
      test_iterator = iter(test_dataset)
      logging.info('Testing on dataset %s', dataset_name)
      for step in range(steps_per_eval):
        if step % 20 == 0:
          logging.info('Starting to run eval step %s of epoch: %s', step,
                       epoch)
        test_start_time = time.time()
        test_step(test_iterator, dataset_name)
        # Seconds -> milliseconds, averaged over one global test batch.
        ms_per_example = ((time.time() - test_start_time) * 1e3 /
                          test_batch_size)
        metrics['test/ms_per_example'].update_state(ms_per_example)

      logging.info('Done with testing on %s', dataset_name)

    corrupt_results = {}
    if run_corruptions:
      corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                        corruption_types,
                                                        max_intensity)

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()
    # Corruption metrics must also start fresh, otherwise each corruption
    # evaluation would aggregate the state of all previous ones.
    if FLAGS.corruptions_interval > 0:
      for metric in corrupt_metrics.values():
        metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)