def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  # Initialize distribution strategy on flag-specified accelerator
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)
  use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

  train_batch_size = FLAGS.train_batch_size * FLAGS.num_cores
  eval_batch_size = FLAGS.eval_batch_size * FLAGS.num_cores

  # Reweighting loss for class imbalance
  class_reweight_mode = FLAGS.class_reweight_mode
  if class_reweight_mode == 'constant':
    class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  else:
    class_weights = None

  # As per the Kaggle challenge, we have split sizes:
  # train: 35,126
  # validation: 10,906 (currently unused)
  # test: 42,670
  ds_info = tfds.builder('diabetic_retinopathy_detection').info
  steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
  steps_per_validation_eval = (
      ds_info.splits['validation'].num_examples // eval_batch_size)
  steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size

  data_dir = FLAGS.data_dir
  dataset_train_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='train', data_dir=data_dir)
  dataset_train = dataset_train_builder.load(batch_size=train_batch_size)

  dataset_validation_builder = ub.datasets.get(
      'diabetic_retinopathy_detection',
      split='validation',
      data_dir=data_dir,
      is_training=not FLAGS.use_validation)
  validation_batch_size = (
      eval_batch_size if FLAGS.use_validation else train_batch_size)
  dataset_validation = dataset_validation_builder.load(
      batch_size=validation_batch_size)
  if FLAGS.use_validation:
    dataset_validation = strategy.experimental_distribute_dataset(
        dataset_validation)
  else:
    # Note that this will not create any mixed batches of train and validation
    # images.
    dataset_train = dataset_train.concatenate(dataset_validation)

  dataset_train = strategy.experimental_distribute_dataset(dataset_train)

  dataset_test_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='test', data_dir=data_dir)
  dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)
  dataset_test = strategy.experimental_distribute_dataset(dataset_test)

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-50 deterministic model.')
    model = ub.models.resnet50_deterministic(
        input_shape=utils.load_input_shape(dataset_train),
        num_classes=1)  # binary classification task
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    base_lr = FLAGS.base_learning_rate
    if FLAGS.lr_schedule == 'step':
      lr_decay_epochs = [
          (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
          for start_epoch_str in FLAGS.lr_decay_epochs
      ]
      lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
          steps_per_epoch,
          base_lr,
          decay_ratio=FLAGS.lr_decay_ratio,
          decay_epochs=lr_decay_epochs,
          warmup_epochs=FLAGS.lr_warmup_epochs)
    else:
      lr_schedule = ub.schedules.WarmUpPolynomialSchedule(
          base_lr,
          end_learning_rate=FLAGS.final_decay_factor * base_lr,
          decay_steps=(
              steps_per_epoch * (FLAGS.train_epochs - FLAGS.lr_warmup_epochs)),
          warmup_steps=steps_per_epoch * FLAGS.lr_warmup_epochs,
          decay_power=1.0)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)
    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=use_tpu,
        num_bins=FLAGS.num_bins,
        use_validation=FLAGS.use_validation)
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope()
      # so that optimizer slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  # Define metrics outside the accelerator scope for CPU eval.
  # This will cause an error on TPU.
  if not use_tpu:
    metrics.update(
        utils.get_diabetic_retinopathy_cpu_metrics(
            use_validation=FLAGS.use_validation))
  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  # Initialize loss function based on class reweighting setting
  loss_fn = utils.get_diabetic_retinopathy_loss_fn(
      class_reweight_mode=class_reweight_mode, class_weights=class_weights)

  @tf.function
  def train_step(iterator):
    """Training step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']

      # For minibatch class reweighting, initialize per-batch loss function
      if class_reweight_mode == 'minibatch':
        batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(labels=labels)
      else:
        batch_loss_fn = loss_fn

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        negative_log_likelihood = tf.reduce_mean(
            batch_loss_fn(
                y_true=tf.expand_dims(labels, axis=-1),
                y_pred=logits,
                from_logits=True))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      probs = tf.squeeze(tf.nn.sigmoid(logits))
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auprc'].update_state(labels, probs)
      metrics['train/auroc'].update_state(labels, probs)
      if not use_tpu:
        metrics['train/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_split, num_steps):
    """Evaluation step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.binary_crossentropy(
              y_true=tf.expand_dims(labels, axis=-1),
              y_pred=logits,
              from_logits=True))
      probs = tf.squeeze(tf.nn.sigmoid(logits))
      metrics[dataset_split + '/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics[dataset_split + '/accuracy'].update_state(labels, probs)
      metrics[dataset_split + '/auprc'].update_state(labels, probs)
      metrics[dataset_split + '/auroc'].update_state(labels, probs)
      if not use_tpu:
        metrics[dataset_split + '/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(num_steps, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  start_time = time.time()

  train_iterator = iter(dataset_train)
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch + 1)
    train_step(train_iterator)

    current_step = (epoch + 1) * steps_per_epoch
    max_steps = steps_per_epoch * FLAGS.train_epochs
    time_elapsed = time.time() - start_time
    steps_per_sec = float(current_step) / time_elapsed
    eta_seconds = (max_steps - current_step) / steps_per_sec
    message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
               'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                   current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                   steps_per_sec, eta_seconds / 60, time_elapsed / 60))
    logging.info(message)

    if FLAGS.use_validation:
      validation_iterator = iter(dataset_validation)
      logging.info('Starting to run validation eval at epoch: %s', epoch + 1)
      test_step(validation_iterator, 'validation', steps_per_validation_eval)

    test_iterator = iter(dataset_test)
    logging.info('Starting to run test eval at epoch: %s', epoch + 1)
    test_start_time = time.time()
    test_step(test_iterator, 'test', steps_per_test_eval)
    ms_per_example = (time.time() - test_start_time) * 1e6 / eval_batch_size
    metrics['test/ms_per_example'].update_state(ms_per_example)

    # Log and write to summary the epoch metrics
    utils.log_epoch_metrics(metrics=metrics, use_tpu=use_tpu)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    # Metrics from Robustness Metrics (like ECE) will return a dict with a
    # single key/value, instead of a scalar.
    total_results = {
        k: (list(v.values())[0] if isinstance(v, dict) else v)
        for k, v in total_results.items()
    }
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

      # TODO(nband): debug checkpointing
      # Also save Keras model, due to checkpoint.save issue
      keras_model_name = os.path.join(FLAGS.output_dir,
                                      f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  keras_model_name = os.path.join(FLAGS.output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)

  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
    })
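# A minimal sketch, not the utils implementation, of what a constant
# class-reweighted binary cross-entropy along the lines of
# utils.get_diabetic_retinopathy_loss_fn(class_reweight_mode='constant', ...)
# could look like. The name `constant_reweighted_bce` is hypothetical; the
# real weights come from
# utils.get_diabetic_retinopathy_class_balance_weights(). Relies on the
# script's existing `tf` import.
def constant_reweighted_bce(class_weights):
  """Returns loss_fn(y_true, y_pred, from_logits) with per-class weights."""

  def loss_fn(y_true, y_pred, from_logits=True):
    # Per-example binary cross-entropy, shape (batch_size,).
    per_example = tf.keras.losses.binary_crossentropy(
        y_true=y_true, y_pred=y_pred, from_logits=from_logits)
    # Upweight the minority (positive) class to counter class imbalance.
    y = tf.squeeze(tf.cast(y_true, tf.float32), axis=-1)
    weights = y * class_weights[1] + (1.0 - y) * class_weights[0]
    return per_example * weights

  return loss_fn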
def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  # Initialize distribution strategy on flag-specified accelerator
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)
  use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

  train_batch_size = FLAGS.train_batch_size * FLAGS.num_cores
  if use_tpu:
    logging.info(
        'Due to TPU requiring static shapes, we must fix the eval batch size '
        'to the train batch size: %d.', train_batch_size)
    eval_batch_size = train_batch_size
  else:
    eval_batch_size = (FLAGS.eval_batch_size *
                       FLAGS.num_cores) // FLAGS.num_dropout_samples_eval

  # As per the Kaggle challenge, we have split sizes:
  # train: 35,126
  # validation: 10,906 (currently unused)
  # test: 42,670
  ds_info = tfds.builder('diabetic_retinopathy_detection').info
  steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
  steps_per_validation_eval = (
      ds_info.splits['validation'].num_examples // eval_batch_size)
  steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size

  data_dir = FLAGS.data_dir
  dataset_train_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='train', data_dir=data_dir)
  dataset_train = dataset_train_builder.load(batch_size=train_batch_size)
  dataset_train = strategy.experimental_distribute_dataset(dataset_train)
  dataset_validation_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='validation', data_dir=data_dir)
  dataset_validation = dataset_validation_builder.load(
      batch_size=eval_batch_size)
  dataset_validation = strategy.experimental_distribute_dataset(
      dataset_validation)
  dataset_test_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='test', data_dir=data_dir)
  dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)
  dataset_test = strategy.experimental_distribute_dataset(dataset_test)

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-50 MC Dropout model.')
    model = ub.models.resnet50_dropout(
        input_shape=utils.load_input_shape(dataset_train),
        num_classes=1,  # binary classification task
        dropout_rate=FLAGS.dropout_rate,
        filterwise_dropout=FLAGS.filterwise_dropout)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate
    lr_decay_epochs = [
        (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
        for start_epoch_str in FLAGS.lr_decay_epochs
    ]
    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)
    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=use_tpu, num_bins=FLAGS.num_bins)
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope()
      # so that optimizer slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  # Finally, define OOD metrics outside the accelerator scope for CPU eval.
  # This will cause an error on TPU.
  if not use_tpu:
    metrics.update({
        'train/auc': tf.keras.metrics.AUC(),
        'validation/auc': tf.keras.metrics.AUC(),
        'test/auc': tf.keras.metrics.AUC()
    })

  @tf.function
  def train_step(iterator):
    """Training step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(
                y_true=tf.expand_dims(labels, axis=-1),
                y_pred=logits,
                from_logits=True))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      probs = tf.nn.sigmoid(logits)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auc'].update_state(labels, probs)
      if not use_tpu:
        metrics['train/ece'].update_state(labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_split):
    """Evaluation step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = tf.convert_to_tensor(inputs['labels'])

      logits_list = []
      for _ in range(FLAGS.num_dropout_samples_eval):
        logits = model(images, training=False)
        logits = tf.squeeze(logits, axis=-1)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        logits_list.append(logits)

      # Logits dimension is (num_samples, batch_size).
      logits_list = tf.stack(logits_list, axis=0)
      probs_list = tf.nn.sigmoid(logits_list)
      probs = tf.reduce_mean(probs_list, axis=0)

      labels_broadcasted = tf.broadcast_to(
          labels, [FLAGS.num_dropout_samples_eval, labels.shape[0]])
      log_likelihoods = -tf.keras.losses.binary_crossentropy(
          labels_broadcasted, logits_list, from_logits=True)
      negative_log_likelihood = tf.reduce_mean(
          -tf.reduce_logsumexp(log_likelihoods, axis=[0]) +
          tf.math.log(float(FLAGS.num_dropout_samples_eval)))

      metrics[dataset_split + '/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics[dataset_split + '/accuracy'].update_state(labels, probs)
      metrics[dataset_split + '/auc'].update_state(labels, probs)
      if not use_tpu:
        metrics[dataset_split + '/ece'].update_state(labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  start_time = time.time()

  train_iterator = iter(dataset_train)
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch + 1)
    for step in range(steps_per_epoch):
      train_step(train_iterator)

      current_step = epoch * steps_per_epoch + (step + 1)
      max_steps = steps_per_epoch * FLAGS.train_epochs
      time_elapsed = time.time() - start_time
      steps_per_sec = float(current_step) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
      if step % 20 == 0:
        logging.info(message)

    validation_iterator = iter(dataset_validation)
    for step in range(steps_per_validation_eval):
      if step % 20 == 0:
        logging.info('Starting to run validation eval step %s of epoch: %s',
                     step, epoch + 1)
      test_step(validation_iterator, 'validation')

    test_iterator = iter(dataset_test)
    for step in range(steps_per_test_eval):
      if step % 20 == 0:
        logging.info('Starting to run test eval step %s of epoch: %s', step,
                     epoch + 1)
      test_start_time = time.time()
      test_step(test_iterator, 'test')
      ms_per_example = (time.time() - test_start_time) * 1e6 / eval_batch_size
      metrics['test/ms_per_example'].update_state(ms_per_example)

    # Log and write to summary the epoch metrics
    utils.log_epoch_metrics(metrics=metrics, use_tpu=use_tpu)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

      # TODO(nband): debug checkpointing
      # Also save Keras model, due to checkpoint.save issue
      keras_model_name = os.path.join(FLAGS.output_dir,
                                      f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  keras_model_name = os.path.join(FLAGS.output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)

  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'dropout_rate': FLAGS.dropout_rate,
        'l2': FLAGS.l2,
    })
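# Standalone restatement of the Monte Carlo predictive NLL computed in
# test_step above: an estimate of -log (1/S) sum_s p(y | x, theta_s),
# evaluated with logsumexp for numerical stability. The function name is
# illustrative only; it relies on the script's existing `tf` import.
def mc_dropout_predictive_nll(labels, logits_samples):
  """labels: (batch_size,); logits_samples: (num_samples, batch_size)."""
  num_samples = tf.shape(logits_samples)[0]
  labels_broadcasted = tf.broadcast_to(labels, tf.shape(logits_samples))
  # binary_crossentropy averages over the trailing (batch) axis, so this holds
  # the mean log-likelihood per MC sample, shape (num_samples,).
  log_likelihoods = -tf.keras.losses.binary_crossentropy(
      labels_broadcasted, logits_samples, from_logits=True)
  # -logsumexp over samples + log(num_samples), as in test_step above.
  return tf.reduce_mean(
      -tf.reduce_logsumexp(log_likelihoods, axis=0) +
      tf.math.log(tf.cast(num_samples, tf.float32)))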
def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  # Initialize distribution strategy on flag-specified accelerator
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)

  train_batch_size = FLAGS.train_batch_size * FLAGS.num_cores
  eval_batch_size = FLAGS.eval_batch_size * FLAGS.num_cores

  # As per the Kaggle challenge, we have split sizes:
  # train: 35,126
  # validation: 10,906 (currently unused)
  # test: 42,670
  ds_info = tfds.builder('diabetic_retinopathy_detection').info
  steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
  steps_per_eval = ds_info.splits['test'].num_examples // eval_batch_size

  dataset_train_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='train', data_dir=FLAGS.data_dir)
  dataset_train = dataset_train_builder.load(batch_size=train_batch_size)
  dataset_train = strategy.experimental_distribute_dataset(dataset_train)
  dataset_test_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='test', data_dir=FLAGS.data_dir)
  dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)
  dataset_test = strategy.experimental_distribute_dataset(dataset_test)

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-50 deterministic model.')
    model = ub.models.resnet50_deterministic(
        input_shape=utils.load_input_shape(dataset_train),
        num_classes=1)  # binary classification task
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = (FLAGS.base_learning_rate *
               train_batch_size) / DEFAULT_TRAIN_BATCH_SIZE
    lr_decay_epochs = [
        (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
        for start_epoch_str in FLAGS.lr_decay_epochs
    ]
    lr_schedule = utils.LearningRateSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=0.9, nesterov=True)
    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.BinaryAccuracy(),
        'train/loss': tf.keras.metrics.Mean(),  # NLL + L2
        'train/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.BinaryAccuracy(),
        'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
    }
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope()
      # so that optimizer slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  # Finally, define OOD metrics outside the accelerator scope for CPU eval.
  metrics.update({
      'train/auc': tf.keras.metrics.AUC(),
      'test/auc': tf.keras.metrics.AUC()
  })

  @tf.function
  def train_step(iterator):
    """Training step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(
                y_true=tf.expand_dims(labels, axis=-1),
                y_pred=logits,
                from_logits=True))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      probs = tf.squeeze(tf.nn.sigmoid(logits))
      metrics['train/ece'].update_state(labels, probs)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auc'].update_state(labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator):
    """Evaluation step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.binary_crossentropy(
              y_true=tf.expand_dims(labels, axis=-1),
              y_pred=logits,
              from_logits=True))
      probs = tf.squeeze(tf.nn.sigmoid(logits))
      metrics['test/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['test/accuracy'].update_state(labels, probs)
      metrics['test/auc'].update_state(labels, probs)
      metrics['test/ece'].update_state(labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  start_time = time.time()

  for epoch in range(initial_epoch, FLAGS.train_epochs):
    train_iterator = iter(dataset_train)
    test_iterator = iter(dataset_test)
    logging.info('Starting to run epoch: %s', epoch + 1)
    for step in range(steps_per_epoch):
      train_step(train_iterator)

      current_step = epoch * steps_per_epoch + (step + 1)
      max_steps = steps_per_epoch * FLAGS.train_epochs
      time_elapsed = time.time() - start_time
      steps_per_sec = float(current_step) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
      if step % 20 == 0:
        logging.info(message)

    for step in range(steps_per_eval):
      if step % 20 == 0:
        logging.info('Starting to run eval step %s of epoch: %s', step,
                     epoch + 1)
      test_start_time = time.time()
      test_step(test_iterator)
      ms_per_example = (time.time() - test_start_time) * 1e6 / eval_batch_size
      metrics['test/ms_per_example'].update_state(ms_per_example)

    logging.info(
        'Train Loss (NLL+L2): %.4f, Accuracy: %.2f%%, AUC: %.2f%%, ECE: %.2f%%',
        metrics['train/loss'].result(),
        metrics['train/accuracy'].result() * 100,
        metrics['train/auc'].result() * 100,
        metrics['train/ece'].result() * 100)
    logging.info(
        'Test NLL: %.4f, Accuracy: %.2f%%, AUC: %.2f%%, ECE: %.2f%%',
        metrics['test/negative_log_likelihood'].result(),
        metrics['test/accuracy'].result() * 100,
        metrics['test/auc'].result() * 100,
        metrics['test/ece'].result() * 100)

    total_results = {name: metric.result() for name, metric in metrics.items()}
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

      # TODO(nband): debug checkpointing
      # Also save Keras model, due to checkpoint.save issue
      keras_model_name = os.path.join(FLAGS.output_dir,
                                      f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  keras_model_name = os.path.join(FLAGS.output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)
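# Worked example of the linear scaling rule applied above, assuming for
# illustration that DEFAULT_TRAIN_BATCH_SIZE = 64 and DEFAULT_NUM_EPOCHS = 90
# (the actual module constants are defined elsewhere in the script). The
# helper name `scale_hypers` is hypothetical.
def scale_hypers(base_learning_rate, train_batch_size, train_epochs,
                 lr_decay_epochs, default_batch_size=64, default_epochs=90):
  """Scales the base LR with batch size and rescales decay epoch indices."""
  base_lr = (base_learning_rate * train_batch_size) / default_batch_size
  decay_epochs = [
      (int(start_epoch) * train_epochs) // default_epochs
      for start_epoch in lr_decay_epochs
  ]
  return base_lr, decay_epochs


# e.g. scale_hypers(0.1, 128, 45, ['30', '60']) -> (0.2, [15, 30])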
def main(argv):
  del argv  # unused arg
  tf.random.set_seed(FLAGS.seed)

  # Wandb Setup
  if FLAGS.use_wandb:
    pathlib.Path(FLAGS.wandb_dir).mkdir(parents=True, exist_ok=True)
    wandb_args = dict(
        project=FLAGS.project,
        entity='uncertainty-baselines',
        dir=FLAGS.wandb_dir,
        reinit=True,
        name=FLAGS.exp_name,
        group=FLAGS.exp_group)
    wandb_run = wandb.init(**wandb_args)
    wandb.config.update(FLAGS, allow_val_change=True)
    output_dir = str(
        os.path.join(
            FLAGS.output_dir,
            datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')))
  else:
    wandb_run = None
    output_dir = FLAGS.output_dir

  tf.io.gfile.makedirs(output_dir)
  logging.info('Saving checkpoints at %s', output_dir)

  # Log Run Hypers
  hypers_dict = {
      'per_core_batch_size': FLAGS.per_core_batch_size,
      'base_learning_rate': FLAGS.base_learning_rate,
      'one_minus_momentum': FLAGS.one_minus_momentum,
      'dropout_rate': FLAGS.dropout_rate,
      'l2': FLAGS.l2,
  }
  logging.info('Hypers:')
  logging.info(pprint.pformat(hypers_dict))

  # Initialize distribution strategy on flag-specified accelerator
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)
  use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

  per_core_batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores

  # Reweighting loss for class imbalance
  class_reweight_mode = FLAGS.class_reweight_mode
  if class_reweight_mode == 'constant':
    class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  else:
    class_weights = None

  # Load in datasets.
  datasets, steps = utils.load_dataset(
      train_batch_size=per_core_batch_size,
      eval_batch_size=per_core_batch_size,
      flags=FLAGS,
      strategy=strategy)
  available_splits = list(datasets.keys())
  test_splits = [split for split in available_splits if 'test' in split]
  eval_splits = [
      split for split in available_splits
      if 'validation' in split or 'test' in split
  ]

  # Iterate eval datasets
  eval_datasets = {split: iter(datasets[split]) for split in eval_splits}
  dataset_train = datasets['train']
  train_steps_per_epoch = steps['train']

  if FLAGS.use_bfloat16:
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

  summary_writer = tf.summary.create_file_writer(
      os.path.join(output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-50 MC Dropout model.')
    model = None
    if FLAGS.load_from_checkpoint:
      initial_epoch, model = utils.load_keras_checkpoints(
          FLAGS.checkpoint_dir, load_ensemble=False, return_epoch=True)
    else:
      initial_epoch = 0
      model = ub.models.resnet50_dropout(
          input_shape=utils.load_input_shape(dataset_train),
          num_classes=1,  # binary classification task
          dropout_rate=FLAGS.dropout_rate,
          filterwise_dropout=FLAGS.filterwise_dropout)

    utils.log_model_init_info(model=model)

    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate
    lr_decay_epochs = [
        (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
        for start_epoch_str in FLAGS.lr_decay_epochs
    ]
    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        train_steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)
    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=use_tpu,
        num_bins=FLAGS.num_bins,
        use_validation=FLAGS.use_validation,
        available_splits=available_splits)

    # TODO(nband): debug or remove
    # checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    # latest_checkpoint = tf.train.latest_checkpoint(output_dir)
    # if latest_checkpoint:
    #   # checkpoint.restore must be within a strategy.scope()
    #   # so that optimizer slot variables are mirrored.
    #   checkpoint.restore(latest_checkpoint)
    #   logging.info('Loaded checkpoint %s', latest_checkpoint)
    #   initial_epoch = optimizer.iterations.numpy() // train_steps_per_epoch

  # Define metrics outside the accelerator scope for CPU eval.
  # This will cause an error on TPU.
  if not use_tpu:
    metrics.update(
        utils.get_diabetic_retinopathy_cpu_metrics(
            available_splits=available_splits,
            use_validation=FLAGS.use_validation))

  for test_split in test_splits:
    metrics.update({f'{test_split}/ms_per_example': tf.keras.metrics.Mean()})

  # Initialize loss function based on class reweighting setting
  loss_fn = utils.get_diabetic_retinopathy_loss_fn(
      class_reweight_mode=class_reweight_mode, class_weights=class_weights)

  # * Prepare for Evaluation *

  # Get the wrapper function which will produce uncertainty estimates for
  # our choice of method and Y/N ensembling.
  uncertainty_estimator_fn = utils.get_uncertainty_estimator(
      'dropout', use_ensemble=False, use_tf=True)

  # Wrap our estimator to predict probabilities (apply sigmoid on logits)
  eval_estimator = utils.wrap_retinopathy_estimator(
      model, use_mixed_precision=FLAGS.use_bfloat16, numpy_outputs=False)

  estimator_args = {'num_samples': FLAGS.num_dropout_samples_eval}

  @tf.function
  def train_step(iterator):
    """Training step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']

      # For minibatch class reweighting, initialize per-batch loss function
      if class_reweight_mode == 'minibatch':
        batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(labels=labels)
      else:
        batch_loss_fn = loss_fn

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        negative_log_likelihood = tf.reduce_mean(
            batch_loss_fn(
                y_true=tf.expand_dims(labels, axis=-1),
                y_pred=logits,
                from_logits=True))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      probs = tf.nn.sigmoid(logits)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auprc'].update_state(labels, probs)
      metrics['train/auroc'].update_state(labels, probs)
      if not use_tpu:
        metrics['train/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(train_steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  start_time = time.time()

  train_iterator = iter(dataset_train)
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch + 1)
    train_step(train_iterator)

    current_step = (epoch + 1) * train_steps_per_epoch
    max_steps = train_steps_per_epoch * FLAGS.train_epochs
    time_elapsed = time.time() - start_time
    steps_per_sec = float(current_step) / time_elapsed
    eta_seconds = (max_steps - current_step) / steps_per_sec
    message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
               'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                   current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                   steps_per_sec, eta_seconds / 60, time_elapsed / 60))
    logging.info(message)

    # Run evaluation on all evaluation datasets, and compute metrics
    per_pred_results, total_results = utils.evaluate_model_and_compute_metrics(
        strategy,
        eval_datasets,
        steps,
        metrics,
        eval_estimator,
        uncertainty_estimator_fn,
        per_core_batch_size,
        available_splits,
        estimator_args=estimator_args,
        call_dataset_iter=False,
        is_deterministic=False,
        num_bins=FLAGS.num_bins,
        use_tpu=use_tpu,
        return_per_pred_results=True)

    # Optionally log to wandb
    if FLAGS.use_wandb:
      wandb.log(total_results, step=epoch)

    with summary_writer.as_default():
      for name, result in total_results.items():
        if result is not None:
          tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      # checkpoint_name = checkpoint.save(
      #     os.path.join(output_dir, 'checkpoint'))
      # logging.info('Saved checkpoint to %s', checkpoint_name)

      # TODO(nband): debug checkpointing
      # Also save Keras model, due to checkpoint.save issue
      keras_model_name = os.path.join(output_dir, f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

      # Save per-prediction metrics
      utils.save_per_prediction_results(
          output_dir, epoch + 1, per_pred_results, verbose=False)

  # final_checkpoint_name = checkpoint.save(
  #     os.path.join(output_dir, 'checkpoint'))
  # logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  keras_model_name = os.path.join(output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)

  # Save per-prediction metrics
  utils.save_per_prediction_results(
      output_dir, FLAGS.train_epochs, per_pred_results, verbose=False)

  with summary_writer.as_default():
    hp.hparams({
        'per_core_batch_size': FLAGS.per_core_batch_size,
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'dropout_rate': FLAGS.dropout_rate,
        'l2': FLAGS.l2,
    })

  if wandb_run is not None:
    wandb_run.finish()
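# Hedged sketch of the kind of computation an MC Dropout uncertainty estimator
# (cf. utils.get_uncertainty_estimator('dropout', ...)) might perform; the
# function name and returned keys below are illustrative, not the utils API.
# Note: training=True keeps an off-the-shelf Keras dropout layer stochastic;
# the baseline's resnet50_dropout model also applies dropout at inference, so
# the scripts above sample with training=False. Relies on the script's
# existing `tf` import.
def mc_dropout_uncertainty(model, images, num_samples):
  """Returns the predictive mean and a simple total-uncertainty proxy."""
  probs_samples = tf.stack(
      [tf.nn.sigmoid(model(images, training=True))
       for _ in range(num_samples)],
      axis=0)  # (num_samples, batch_size, 1)
  probs_mean = tf.reduce_mean(probs_samples, axis=0)
  # Predictive entropy of the averaged Bernoulli.
  entropy = -(probs_mean * tf.math.log(probs_mean + 1e-7) +
              (1.0 - probs_mean) * tf.math.log(1.0 - probs_mean + 1e-7))
  return {'probs': probs_mean, 'predictive_entropy': entropy}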
def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  # Initialize distribution strategy on flag-specified accelerator
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)
  use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

  # Only permit use of L2 regularization with a tied mean prior
  if FLAGS.l2 is not None and FLAGS.l2 > 0 and not FLAGS.tied_mean_prior:
    raise NotImplementedError(
        'For a principled objective, L2 regularization should not be used '
        'when the prior mean is untied from the posterior mean.')

  batch_size = FLAGS.batch_size * FLAGS.num_cores

  # As per the Kaggle challenge, we have split sizes:
  # train: 35,126
  # validation: 10,906 (currently unused)
  # test: 42,670
  ds_info = tfds.builder('diabetic_retinopathy_detection').info
  train_dataset_size = ds_info.splits['train'].num_examples
  steps_per_epoch = train_dataset_size // batch_size
  steps_per_validation_eval = (
      ds_info.splits['validation'].num_examples // batch_size)
  steps_per_test_eval = ds_info.splits['test'].num_examples // batch_size

  data_dir = FLAGS.data_dir
  dataset_train_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='train', data_dir=data_dir)
  dataset_train = dataset_train_builder.load(batch_size=batch_size)
  dataset_train = strategy.experimental_distribute_dataset(dataset_train)
  dataset_validation_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='validation', data_dir=data_dir)
  dataset_validation = dataset_validation_builder.load(batch_size=batch_size)
  dataset_validation = strategy.experimental_distribute_dataset(
      dataset_validation)
  dataset_test_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='test', data_dir=data_dir)
  dataset_test = dataset_test_builder.load(batch_size=batch_size)
  dataset_test = strategy.experimental_distribute_dataset(dataset_test)

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-50 Radial model.')
    if FLAGS.prior_stddev is None:
      logging.info(
          'A fixed prior stddev was not supplied. Computing a prior stddev = '
          'sqrt(2 / fan_in) for each layer. This is recommended over '
          'providing a fixed prior stddev.')
    model = ub.models.resnet50_radial(
        input_shape=utils.load_input_shape(dataset_train),
        num_classes=1,  # binary classification task
        prior_stddev=FLAGS.prior_stddev,
        dataset_size=train_dataset_size,
        stddev_mean_init=FLAGS.stddev_mean_init,
        stddev_stddev_init=FLAGS.stddev_stddev_init,
        tied_mean_prior=FLAGS.tied_mean_prior)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate
    lr_decay_epochs = [
        (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
        for start_epoch_str in FLAGS.lr_decay_epochs
    ]
    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=0.9, nesterov=True)
    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=use_tpu, num_bins=FLAGS.num_bins)
    metrics.update({
        'train/kl': tf.keras.metrics.Mean(),
        'train/kl_scale': tf.keras.metrics.Mean()
    })
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope()
      # so that optimizer slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  # Finally, define OOD metrics outside the accelerator scope for CPU eval.
  # This will cause an error on TPU.
  if not use_tpu:
    metrics.update({
        'train/auc': tf.keras.metrics.AUC(),
        'validation/auc': tf.keras.metrics.AUC(),
        'test/auc': tf.keras.metrics.AUC()
    })

  @tf.function
  def train_step(iterator):
    """Training step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)

        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(
                y_true=tf.expand_dims(labels, axis=-1),
                y_pred=logits,
                from_logits=True))

        filtered_variables = []
        for var in model.trainable_variables:
          # Apply l2 on the BN parameters and bias terms. This
          # excludes only fast weight approximate posterior/prior parameters,
          # but pay caution to their naming scheme.
          if 'bn' in var.name or 'bias' in var.name:
            filtered_variables.append(tf.reshape(var, (-1,)))
        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))

        kl = sum(model.losses)
        kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
        kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs
        kl_scale = tf.minimum(1., kl_scale)
        kl_loss = kl_scale * kl

        loss = negative_log_likelihood + l2_loss + kl_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      probs = tf.squeeze(tf.nn.sigmoid(logits))
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/kl'].update_state(kl)
      metrics['train/kl_scale'].update_state(kl_scale)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auc'].update_state(labels, probs)
      if not use_tpu:
        metrics['train/ece'].update_state(labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_split):
    """Evaluation step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.binary_crossentropy(
              y_true=tf.expand_dims(labels, axis=-1),
              y_pred=logits,
              from_logits=True))
      probs = tf.squeeze(tf.nn.sigmoid(logits))
      metrics[dataset_split + '/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics[dataset_split + '/accuracy'].update_state(labels, probs)
      metrics[dataset_split + '/auc'].update_state(labels, probs)
      if not use_tpu:
        metrics[dataset_split + '/ece'].update_state(labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  start_time = time.time()

  train_iterator = iter(dataset_train)
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch + 1)
    for step in range(steps_per_epoch):
      train_step(train_iterator)

      current_step = epoch * steps_per_epoch + (step + 1)
      max_steps = steps_per_epoch * FLAGS.train_epochs
      time_elapsed = time.time() - start_time
      steps_per_sec = float(current_step) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
      if step % 20 == 0:
        logging.info(message)

    validation_iterator = iter(dataset_validation)
    for step in range(steps_per_validation_eval):
      if step % 20 == 0:
        logging.info('Starting to run validation eval step %s of epoch: %s',
                     step, epoch + 1)
      test_step(validation_iterator, 'validation')

    test_iterator = iter(dataset_test)
    for step in range(steps_per_test_eval):
      if step % 20 == 0:
        logging.info('Starting to run test eval step %s of epoch: %s', step,
                     epoch + 1)
      test_start_time = time.time()
      test_step(test_iterator, 'test')
      ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
      metrics['test/ms_per_example'].update_state(ms_per_example)

    # Log and write to summary the epoch metrics.
    utils.log_epoch_metrics(metrics=metrics, use_tpu=use_tpu)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

      # TODO(nband): debug checkpointing
      # Also save Keras model, due to checkpoint.save issue
      keras_model_name = os.path.join(FLAGS.output_dir,
                                      f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  keras_model_name = os.path.join(FLAGS.output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)

  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
        'stddev_mean_init': FLAGS.stddev_mean_init,
        'stddev_stddev_init': FLAGS.stddev_stddev_init,
    })
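# The KL term in train_step above is annealed linearly over the first
# kl_annealing_epochs:
#   kl_scale = min(1, (iteration + 1) / (steps_per_epoch * kl_annealing_epochs))
# A tiny standalone illustration (names local to this sketch):
def kl_annealing_scale(iteration, steps_per_epoch, kl_annealing_epochs):
  return min(1.0, (iteration + 1) / (steps_per_epoch * kl_annealing_epochs))


# e.g. with steps_per_epoch=100 and kl_annealing_epochs=10:
#   kl_annealing_scale(0, 100, 10)    -> 0.001
#   kl_annealing_scale(499, 100, 10)  -> 0.5
#   kl_annealing_scale(2000, 100, 10) -> 1.0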
def main(argv):
  del argv  # unused arg
  tf.random.set_seed(FLAGS.seed)

  # Wandb Setup
  if FLAGS.use_wandb:
    pathlib.Path(FLAGS.wandb_dir).mkdir(parents=True, exist_ok=True)
    wandb_args = dict(
        project=FLAGS.project,
        entity='uncertainty-baselines',
        dir=FLAGS.wandb_dir,
        reinit=True,
        name=FLAGS.exp_name,
        group=FLAGS.exp_group)
    wandb_run = wandb.init(**wandb_args)
    wandb.config.update(FLAGS, allow_val_change=True)
    output_dir = str(
        os.path.join(
            FLAGS.output_dir,
            datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')))
  else:
    wandb_run = None
    output_dir = FLAGS.output_dir

  tf.io.gfile.makedirs(output_dir)
  logging.info('Saving checkpoints at %s', output_dir)

  # Log Run Hypers
  hypers_dict = {
      'batch_size': FLAGS.batch_size,
      'base_learning_rate': FLAGS.base_learning_rate,
      'one_minus_momentum': FLAGS.one_minus_momentum,
      'l2': FLAGS.l2,
      'stddev_mean_init': FLAGS.stddev_mean_init,
      'stddev_stddev_init': FLAGS.stddev_stddev_init,
  }
  logging.info('Hypers:')
  logging.info(pprint.pformat(hypers_dict))

  # Initialize distribution strategy on flag-specified accelerator
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)
  use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

  # Only permit use of L2 regularization with a tied mean prior
  if FLAGS.l2 is not None and FLAGS.l2 > 0 and not FLAGS.tied_mean_prior:
    raise NotImplementedError(
        'For a principled objective, L2 regularization should not be used '
        'when the prior mean is untied from the posterior mean.')

  batch_size = FLAGS.batch_size * FLAGS.num_cores

  # Reweighting loss for class imbalance
  class_reweight_mode = FLAGS.class_reweight_mode
  if class_reweight_mode == 'constant':
    class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  else:
    class_weights = None

  # Load in datasets.
  datasets, steps = utils.load_dataset(
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      flags=FLAGS,
      strategy=strategy)
  available_splits = list(datasets.keys())
  test_splits = [split for split in available_splits if 'test' in split]
  eval_splits = [
      split for split in available_splits
      if 'validation' in split or 'test' in split
  ]

  # Iterate eval datasets
  eval_datasets = {split: iter(datasets[split]) for split in eval_splits}
  dataset_train = datasets['train']
  train_steps_per_epoch = steps['train']
  train_dataset_size = train_steps_per_epoch * batch_size

  if FLAGS.use_bfloat16:
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

  summary_writer = tf.summary.create_file_writer(
      os.path.join(output_dir, 'summaries'))

  if FLAGS.prior_stddev is None:
    logging.info(
        'A fixed prior stddev was not supplied. Computing a prior stddev = '
        'sqrt(2 / fan_in) for each layer. This is recommended over providing '
        'a fixed prior stddev.')

  with strategy.scope():
    logging.info('Building Keras ResNet-50 Variational Inference model.')
    model = None
    if FLAGS.load_from_checkpoint:
      initial_epoch, model = utils.load_keras_checkpoints(
          FLAGS.checkpoint_dir, load_ensemble=False, return_epoch=True)
    else:
      initial_epoch = 0
      model = ub.models.resnet50_variational(
          input_shape=utils.load_input_shape(dataset_train),
          num_classes=1,  # binary classification task
          prior_stddev=FLAGS.prior_stddev,
          dataset_size=train_dataset_size,
          stddev_mean_init=FLAGS.stddev_mean_init,
          stddev_stddev_init=FLAGS.stddev_stddev_init,
          tied_mean_prior=FLAGS.tied_mean_prior)

    utils.log_model_init_info(model=model)

    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate
    lr_decay_epochs = [
        (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
        for start_epoch_str in FLAGS.lr_decay_epochs
    ]
    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        train_steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)
    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=use_tpu,
        num_bins=FLAGS.num_bins,
        use_validation=FLAGS.use_validation,
        available_splits=available_splits)

    # VI specific metrics
    metrics.update({
        'train/kl': tf.keras.metrics.Mean(),
        'train/kl_scale': tf.keras.metrics.Mean()
    })

    # TODO(nband): debug or remove
    # checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    # latest_checkpoint = tf.train.latest_checkpoint(output_dir)
    # if latest_checkpoint:
    #   # checkpoint.restore must be within a strategy.scope()
    #   # so that optimizer slot variables are mirrored.
    #   checkpoint.restore(latest_checkpoint)
    #   logging.info('Loaded checkpoint %s', latest_checkpoint)
    #   initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  # Define metrics outside the accelerator scope for CPU eval.
  # This will cause an error on TPU.
  if not use_tpu:
    metrics.update(
        utils.get_diabetic_retinopathy_cpu_metrics(
            available_splits=available_splits,
            use_validation=FLAGS.use_validation))

  for test_split in test_splits:
    metrics.update({f'{test_split}/ms_per_example': tf.keras.metrics.Mean()})

  # Initialize loss function based on class reweighting setting
  loss_fn = utils.get_diabetic_retinopathy_loss_fn(
      class_reweight_mode=class_reweight_mode, class_weights=class_weights)

  # * Prepare for Evaluation *

  # Get the wrapper function which will produce uncertainty estimates for
  # our choice of method and Y/N ensembling.
  uncertainty_estimator_fn = utils.get_uncertainty_estimator(
      'variational_inference', use_ensemble=False, use_tf=True)

  # Wrap our estimator to predict probabilities (apply sigmoid on logits)
  eval_estimator = utils.wrap_retinopathy_estimator(
      model, use_mixed_precision=FLAGS.use_bfloat16, numpy_outputs=False)

  estimator_args = {'num_samples': FLAGS.num_mc_samples_eval}

  @tf.function
  def train_step(iterator):
    """Training step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']

      # For minibatch class reweighting, initialize per-batch loss function
      if class_reweight_mode == 'minibatch':
        batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(labels=labels)
      else:
        batch_loss_fn = loss_fn

      with tf.GradientTape() as tape:
        # TODO(nband): TPU-friendly implem
        if FLAGS.num_mc_samples_train > 1:
          # Take multiple MC samples from the approximate posterior.
          logits_arr = tf.TensorArray(
              tf.float32, size=FLAGS.num_mc_samples_train)
          for i in tf.range(FLAGS.num_mc_samples_train):
            logits = model(images, training=True)
            logits_arr = logits_arr.write(i, logits)
          logits_list = logits_arr.stack()
          probs_list = tf.nn.sigmoid(logits_list)
          probs = tf.reduce_mean(probs_list, axis=0)
          negative_log_likelihood = tf.reduce_mean(
              batch_loss_fn(
                  y_true=tf.expand_dims(labels, axis=-1),
                  y_pred=probs,
                  from_logits=False))
        else:
          # Single train step
          logits = model(images, training=True)
          if FLAGS.use_bfloat16:
            logits = tf.cast(logits, tf.float32)
          negative_log_likelihood = tf.reduce_mean(
              batch_loss_fn(
                  y_true=tf.expand_dims(labels, axis=-1),
                  y_pred=logits,
                  from_logits=True))
          probs = tf.squeeze(tf.nn.sigmoid(logits))

        filtered_variables = []
        for var in model.trainable_variables:
          # Apply l2 on the BN parameters and bias terms. This
          # excludes only fast weight approximate posterior/prior parameters,
          # but pay caution to their naming scheme.
          if 'bn' in var.name or 'bias' in var.name:
            filtered_variables.append(tf.reshape(var, (-1,)))
        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))

        kl = sum(model.losses)
        kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
        kl_scale /= train_steps_per_epoch * FLAGS.kl_annealing_epochs
        kl_scale = tf.minimum(1., kl_scale)
        kl_loss = kl_scale * kl

        loss = negative_log_likelihood + l2_loss + kl_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/kl'].update_state(kl)
      metrics['train/kl_scale'].update_state(kl_scale)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auprc'].update_state(labels, probs)
      metrics['train/auroc'].update_state(labels, probs)
      if not use_tpu:
        metrics['train/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(train_steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  start_time = time.time()

  train_iterator = iter(dataset_train)
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch + 1)
    train_step(train_iterator)

    current_step = (epoch + 1) * train_steps_per_epoch
    max_steps = train_steps_per_epoch * FLAGS.train_epochs
    time_elapsed = time.time() - start_time
    steps_per_sec = float(current_step) / time_elapsed
    eta_seconds = (max_steps - current_step) / steps_per_sec
    message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
               'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                   current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                   steps_per_sec, eta_seconds / 60, time_elapsed / 60))
    logging.info(message)

    # Run evaluation on all evaluation datasets, and compute metrics
    per_pred_results, total_results = utils.evaluate_model_and_compute_metrics(
        strategy,
        eval_datasets,
        steps,
        metrics,
        eval_estimator,
        uncertainty_estimator_fn,
        batch_size,
        available_splits,
        estimator_args=estimator_args,
        call_dataset_iter=False,
        is_deterministic=False,
        num_bins=FLAGS.num_bins,
        use_tpu=use_tpu,
        return_per_pred_results=True)

    # Optionally log to wandb
    if FLAGS.use_wandb:
      wandb.log(total_results, step=epoch)

    with summary_writer.as_default():
      for name, result in total_results.items():
        if result is not None:
          tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      # checkpoint_name = checkpoint.save(
      #     os.path.join(output_dir, 'checkpoint'))
      # logging.info('Saved checkpoint to %s', checkpoint_name)

      # TODO(nband): debug checkpointing
      # Also save Keras model, due to checkpoint.save issue.
      keras_model_name = os.path.join(output_dir, f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

      # Save per-prediction metrics
      utils.save_per_prediction_results(
          output_dir, epoch + 1, per_pred_results, verbose=False)

  # final_checkpoint_name = checkpoint.save(
  #     os.path.join(output_dir, 'checkpoint'))
  # logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  keras_model_name = os.path.join(output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)

  # Save per-prediction metrics
  utils.save_per_prediction_results(
      output_dir, FLAGS.train_epochs, per_pred_results, verbose=False)

  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
        'stddev_mean_init': FLAGS.stddev_mean_init,
        'stddev_stddev_init': FLAGS.stddev_stddev_init,
    })

  if wandb_run is not None:
    wandb_run.finish()
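# The filtered L2 penalty in train_step regularizes only batch-norm and bias
# variables; since tf.nn.l2_loss(v) == sum(v ** 2) / 2, the expression
# FLAGS.l2 * 2 * tf.nn.l2_loss(...) equals FLAGS.l2 * sum of squares. A compact
# equivalent, as a sketch only (relies on the script's existing `tf` import):
def filtered_l2(model, l2_coeff):
  """L2 on BN/bias terms only; variational posterior parameters are excluded."""
  penalty = tf.constant(0.0)
  for var in model.trainable_variables:
    if 'bn' in var.name or 'bias' in var.name:
      penalty += tf.reduce_sum(tf.square(var))
  return l2_coeff * penalty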