def main(argv):
  """Trains and evaluates a deterministic ResNet-50 on Diabetic Retinopathy.

  All configuration comes from `FLAGS`. Checkpoints, Keras SavedModels and
  TensorBoard summaries are written under `FLAGS.output_dir`; if a checkpoint
  already exists there, training resumes from the epoch it recorded.

  Args:
    argv: Unused positional command-line arguments.
  """
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  # Initialize distribution strategy on flag-specified accelerator
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)
  use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

  # Global batch sizes: per-core flag values scaled by the number of cores.
  train_batch_size = FLAGS.train_batch_size * FLAGS.num_cores
  eval_batch_size = FLAGS.eval_batch_size * FLAGS.num_cores

  # Reweighting loss for class imbalance. Only 'constant' mode needs
  # precomputed weights; 'minibatch' mode builds its loss per batch below.
  class_reweight_mode = FLAGS.class_reweight_mode
  if class_reweight_mode == 'constant':
    class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  else:
    class_weights = None

  # As per the Kaggle challenge, we have split sizes:
  # train: 35,126
  # validation: 10,906 (evaluated separately when --use_validation is set,
  #   otherwise appended to the training data)
  # test: 42,670
  ds_info = tfds.builder('diabetic_retinopathy_detection').info
  steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
  steps_per_validation_eval = (
      ds_info.splits['validation'].num_examples // eval_batch_size)
  steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size

  data_dir = FLAGS.data_dir

  dataset_train_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='train', data_dir=data_dir)
  dataset_train = dataset_train_builder.load(batch_size=train_batch_size)

  # When validation is not used for evaluation, it is loaded with training
  # settings and the train batch size so it can be appended to the train set.
  dataset_validation_builder = ub.datasets.get(
      'diabetic_retinopathy_detection',
      split='validation',
      data_dir=data_dir,
      is_training=not FLAGS.use_validation)
  validation_batch_size = (
      eval_batch_size if FLAGS.use_validation else train_batch_size)
  dataset_validation = dataset_validation_builder.load(
      batch_size=validation_batch_size)
  if FLAGS.use_validation:
    dataset_validation = strategy.experimental_distribute_dataset(
        dataset_validation)
  else:
    # Note that this will not create any mixed batches of train and validation
    # images.
    # NOTE(review): if `load` returns an indefinitely repeating training
    # dataset, the appended validation batches are never reached -- confirm.
    dataset_train = dataset_train.concatenate(dataset_validation)

  dataset_train = strategy.experimental_distribute_dataset(dataset_train)

  dataset_test_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='test', data_dir=data_dir)
  dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)
  dataset_test = strategy.experimental_distribute_dataset(dataset_test)

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-50 deterministic model.')
    model = ub.models.resnet50_deterministic(
        input_shape=utils.load_input_shape(dataset_train),
        num_classes=1)  # binary classification task
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    base_lr = FLAGS.base_learning_rate
    if FLAGS.lr_schedule == 'step':
      # Decay epochs are given relative to DEFAULT_NUM_EPOCHS; rescale them
      # to the requested number of training epochs.
      lr_decay_epochs = [
          (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
          for start_epoch_str in FLAGS.lr_decay_epochs
      ]
      lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
          steps_per_epoch,
          base_lr,
          decay_ratio=FLAGS.lr_decay_ratio,
          decay_epochs=lr_decay_epochs,
          warmup_epochs=FLAGS.lr_warmup_epochs)
    else:
      lr_schedule = ub.schedules.WarmUpPolynomialSchedule(
          base_lr,
          end_learning_rate=FLAGS.final_decay_factor * base_lr,
          decay_steps=(
              steps_per_epoch * (FLAGS.train_epochs - FLAGS.lr_warmup_epochs)),
          warmup_steps=steps_per_epoch * FLAGS.lr_warmup_epochs,
          decay_power=1.0)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)
    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=use_tpu,
        num_bins=FLAGS.num_bins,
        use_validation=FLAGS.use_validation)
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope()
      # so that optimizer slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  # Define metrics outside the accelerator scope for CPU eval.
  # This will cause an error on TPU.
  if not use_tpu:
    metrics.update(
        utils.get_diabetic_retinopathy_cpu_metrics(
            use_validation=FLAGS.use_validation))
  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  # Initialize loss function based on class reweighting setting
  loss_fn = utils.get_diabetic_retinopathy_loss_fn(
      class_reweight_mode=class_reweight_mode, class_weights=class_weights)

  @tf.function
  def train_step(iterator):
    """Runs `steps_per_epoch` distributed training steps from `iterator`."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']

      # For minibatch class reweighting, initialize per-batch loss function
      if class_reweight_mode == 'minibatch':
        batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(labels=labels)
      else:
        batch_loss_fn = loss_fn

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)

        negative_log_likelihood = tf.reduce_mean(
            batch_loss_fn(
                y_true=tf.expand_dims(labels, axis=-1),
                y_pred=logits,
                from_logits=True))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)

        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      # Drop only the single-logit axis so a per-replica batch of one
      # keeps its batch dimension.
      probs = tf.squeeze(tf.nn.sigmoid(logits), axis=-1)

      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auprc'].update_state(labels, probs)
      metrics['train/auroc'].update_state(labels, probs)

      if not use_tpu:
        metrics['train/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_split, num_steps):
    """Runs `num_steps` eval steps, updating `dataset_split`'s metrics only."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)

      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.binary_crossentropy(
              y_true=tf.expand_dims(labels, axis=-1),
              y_pred=logits,
              from_logits=True))
      probs = tf.squeeze(tf.nn.sigmoid(logits), axis=-1)

      # Only the metrics of the split being evaluated are updated, so
      # validation batches never leak into the test metrics.
      metrics[dataset_split + '/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics[dataset_split + '/accuracy'].update_state(labels, probs)
      metrics[dataset_split + '/auprc'].update_state(labels, probs)
      metrics[dataset_split + '/auroc'].update_state(labels, probs)

      if not use_tpu:
        metrics[dataset_split + '/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(num_steps, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  max_steps = steps_per_epoch * FLAGS.train_epochs
  # Steps completed before this run (non-zero when resuming); excluded from
  # the throughput estimate because start_time only covers this run.
  resumed_steps = initial_epoch * steps_per_epoch
  start_time = time.time()

  train_iterator = iter(dataset_train)
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch + 1)
    train_step(train_iterator)

    current_step = (epoch + 1) * steps_per_epoch
    time_elapsed = time.time() - start_time
    steps_per_sec = float(current_step - resumed_steps) / time_elapsed
    eta_seconds = ((max_steps - current_step) / steps_per_sec
                   if steps_per_sec else 0.0)
    message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
               'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                   current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                   steps_per_sec, eta_seconds / 60, time_elapsed / 60))
    logging.info(message)

    if FLAGS.use_validation:
      validation_iterator = iter(dataset_validation)
      logging.info('Starting to run validation eval at epoch: %s', epoch + 1)
      test_step(validation_iterator, 'validation', steps_per_validation_eval)

    test_iterator = iter(dataset_test)
    logging.info('Starting to run test eval at epoch: %s', epoch + 1)
    test_start_time = time.time()
    test_step(test_iterator, 'test', steps_per_test_eval)
    # test_step covers the whole test split, so normalize the elapsed wall
    # time (converted s -> ms) by the total number of examples evaluated.
    num_test_examples = max(steps_per_test_eval * eval_batch_size, 1)
    ms_per_example = (time.time() - test_start_time) * 1e3 / num_test_examples
    metrics['test/ms_per_example'].update_state(ms_per_example)

    # Log and write to summary the epoch metrics
    utils.log_epoch_metrics(metrics=metrics, use_tpu=use_tpu)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    # Metrics from Robustness Metrics (like ECE) will return a dict with a
    # single key/value, instead of a scalar.
    total_results = {
        k: (list(v.values())[0] if isinstance(v, dict) else v)
        for k, v in total_results.items()
    }
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

      # TODO(nband): debug checkpointing
      # Also save Keras model, due to checkpoint.save issue
      keras_model_name = os.path.join(FLAGS.output_dir,
                                      f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  keras_model_name = os.path.join(FLAGS.output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)
  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
    })
Esempio n. 2
0
def main(argv):
    """Trains and evaluates an MC Dropout ResNet-50 on Diabetic Retinopathy.

    All configuration comes from `FLAGS`. Each epoch trains for
    `steps_per_epoch` steps, then evaluates on the full validation and test
    splits, averaging the sigmoid outputs of `FLAGS.num_dropout_samples_eval`
    forward passes per example. Checkpoints, Keras SavedModels and
    TensorBoard summaries are written under `FLAGS.output_dir`; training
    resumes from the latest checkpoint found there.

    Args:
        argv: Unused positional command-line arguments.
    """
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    # Initialize distribution strategy on flag-specified accelerator
    strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                                FLAGS.use_gpu, FLAGS.tpu)
    use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

    train_batch_size = (FLAGS.train_batch_size * FLAGS.num_cores)

    if use_tpu:
        logging.info(
            'Due to TPU requiring static shapes, we must fix the eval batch size '
            'to the train batch size: %d.', train_batch_size)
        eval_batch_size = train_batch_size
    else:
        # The eval batch is divided by the number of MC samples, since every
        # eval example is run through the model that many times (test_step).
        eval_batch_size = (FLAGS.eval_batch_size *
                           FLAGS.num_cores) // FLAGS.num_dropout_samples_eval

    # As per the Kaggle challenge, we have split sizes:
    # train: 35,126
    # validation: 10,906 (evaluated every epoch)
    # test: 42,670
    ds_info = tfds.builder('diabetic_retinopathy_detection').info
    steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
    steps_per_validation_eval = (ds_info.splits['validation'].num_examples //
                                 eval_batch_size)
    steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size

    data_dir = FLAGS.data_dir

    dataset_train_builder = ub.datasets.get('diabetic_retinopathy_detection',
                                            split='train',
                                            data_dir=data_dir)
    dataset_train = dataset_train_builder.load(batch_size=train_batch_size)
    dataset_train = strategy.experimental_distribute_dataset(dataset_train)

    dataset_validation_builder = ub.datasets.get(
        'diabetic_retinopathy_detection',
        split='validation',
        data_dir=data_dir)
    dataset_validation = dataset_validation_builder.load(
        batch_size=eval_batch_size)
    dataset_validation = strategy.experimental_distribute_dataset(
        dataset_validation)

    dataset_test_builder = ub.datasets.get('diabetic_retinopathy_detection',
                                           split='test',
                                           data_dir=data_dir)
    dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)
    dataset_test = strategy.experimental_distribute_dataset(dataset_test)

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building Keras ResNet-50 MC Dropout model.')
        model = ub.models.resnet50_dropout(
            input_shape=utils.load_input_shape(dataset_train),
            num_classes=1,  # binary classification task
            dropout_rate=FLAGS.dropout_rate,
            filterwise_dropout=FLAGS.filterwise_dropout)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())

        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate
        lr_decay_epochs = [
            (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
            for start_epoch_str in FLAGS.lr_decay_epochs
        ]

        lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=1.0 -
                                            FLAGS.one_minus_momentum,
                                            nesterov=True)
        metrics = utils.get_diabetic_retinopathy_base_metrics(
            use_tpu=use_tpu, num_bins=FLAGS.num_bins)
        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope()
            # so that optimizer slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    # Finally, define AUC metrics outside the accelerator scope for CPU eval.
    # This will cause an error on TPU.
    if not use_tpu:
        metrics.update({
            'train/auc': tf.keras.metrics.AUC(),
            'validation/auc': tf.keras.metrics.AUC(),
            'test/auc': tf.keras.metrics.AUC()
        })

    @tf.function
    def train_step(iterator):
        """Runs one distributed training step on the next batch."""
        def step_fn(inputs):
            """Per-replica step function."""
            images = inputs['features']
            labels = inputs['labels']

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)

                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.binary_crossentropy(
                        y_true=tf.expand_dims(labels, axis=-1),
                        y_pred=logits,
                        from_logits=True))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)

                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            # Drop the single-logit axis so probs are (batch_size,), matching
            # the labels and the probs passed to the metrics in test_step.
            probs = tf.squeeze(tf.nn.sigmoid(logits), axis=-1)

            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, probs)
            # NOTE(review): 'train/auc' is only added above when not use_tpu;
            # unless the base metrics also define it, this lookup fails on
            # TPU -- confirm.
            metrics['train/auc'].update_state(labels, probs)

            if not use_tpu:
                metrics['train/ece'].update_state(labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_split):
        """Runs one MC-dropout eval step, updating `dataset_split` metrics."""
        def step_fn(inputs):
            """Per-replica step function."""
            images = inputs['features']
            labels = tf.convert_to_tensor(inputs['labels'])

            # NOTE(review): the samples differ only if the model's dropout
            # layers stay active under training=False (MC Dropout) -- confirm
            # in ub.models.resnet50_dropout.
            logits_list = []
            for _ in range(FLAGS.num_dropout_samples_eval):
                logits = model(images, training=False)
                logits = tf.squeeze(logits, axis=-1)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)

                logits_list.append(logits)

            # Logits dimension is (num_samples, batch_size).
            logits_list = tf.stack(logits_list, axis=0)
            probs_list = tf.nn.sigmoid(logits_list)
            probs = tf.reduce_mean(probs_list, axis=0)
            # Repeat the labels once per sample. The dynamic shape is used so
            # this also works when the static batch dimension is unknown.
            labels_broadcasted = tf.broadcast_to(labels[tf.newaxis, ...],
                                                 tf.shape(logits_list))
            # Element-wise log-likelihoods, shape (num_samples, batch_size).
            # keras binary_crossentropy would instead average over the last
            # (batch) axis before the log-sum-exp over samples below.
            log_likelihoods = -tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.cast(labels_broadcasted, logits_list.dtype),
                logits=logits_list)
            # Per-example NLL of the MC-averaged predictive distribution,
            # -log((1/S) * sum_s p_s), then averaged over the batch.
            negative_log_likelihood = tf.reduce_mean(
                -tf.reduce_logsumexp(log_likelihoods, axis=[0]) +
                tf.math.log(float(FLAGS.num_dropout_samples_eval)))
            metrics[dataset_split + '/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics[dataset_split + '/accuracy'].update_state(labels, probs)
            metrics[dataset_split + '/auc'].update_state(labels, probs)

            if not use_tpu:
                metrics[dataset_split + '/ece'].update_state(labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})
    max_steps = steps_per_epoch * FLAGS.train_epochs
    # Steps completed before this run (non-zero when resuming); excluded from
    # the throughput estimate because start_time only covers this run.
    resumed_steps = initial_epoch * steps_per_epoch
    start_time = time.time()

    train_iterator = iter(dataset_train)
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch + 1)
        for step in range(steps_per_epoch):
            train_step(train_iterator)

            # Only build the progress message on the steps that log it.
            if step % 20 == 0:
                current_step = epoch * steps_per_epoch + (step + 1)
                time_elapsed = time.time() - start_time
                steps_per_sec = (float(current_step - resumed_steps) /
                                 time_elapsed)
                eta_seconds = ((max_steps - current_step) / steps_per_sec
                               if steps_per_sec else 0.0)
                message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                           'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                               current_step / max_steps, epoch + 1,
                               FLAGS.train_epochs, steps_per_sec,
                               eta_seconds / 60, time_elapsed / 60))
                logging.info(message)

        validation_iterator = iter(dataset_validation)
        for step in range(steps_per_validation_eval):
            if step % 20 == 0:
                logging.info(
                    'Starting to run validation eval step %s of epoch: %s',
                    step, epoch + 1)
            test_step(validation_iterator, 'validation')

        test_iterator = iter(dataset_test)
        for step in range(steps_per_test_eval):
            if step % 20 == 0:
                logging.info('Starting to run test eval step %s of epoch: %s',
                             step, epoch + 1)
            test_start_time = time.time()
            test_step(test_iterator, 'test')
            # Wall time of one eval batch, converted s -> ms, per example.
            ms_per_example = (time.time() -
                              test_start_time) * 1e3 / eval_batch_size
            metrics['test/ms_per_example'].update_state(ms_per_example)

        # Log and write to summary the epoch metrics
        utils.log_epoch_metrics(metrics=metrics, use_tpu=use_tpu)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

            # TODO(nband): debug checkpointing
            # Also save Keras model, due to checkpoint.save issue
            keras_model_name = os.path.join(FLAGS.output_dir,
                                            f'keras_model_{epoch + 1}')
            model.save(keras_model_name)
            logging.info('Saved keras model to %s', keras_model_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)

    keras_model_name = os.path.join(FLAGS.output_dir,
                                    f'keras_model_{FLAGS.train_epochs}')
    model.save(keras_model_name)
    logging.info('Saved keras model to %s', keras_model_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'dropout_rate': FLAGS.dropout_rate,
            'l2': FLAGS.l2,
        })
Esempio n. 3
0
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)

    # Set seeds
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    torch.manual_seed(FLAGS.seed)

    # Resolve CUDA device(s)
    if FLAGS.use_gpu and torch.cuda.is_available():
        print('Running model with CUDA.')
        device = 'cuda:0'
    else:
        print('Running model on CPU.')
        device = 'cpu'

    train_batch_size = FLAGS.train_batch_size
    eval_batch_size = FLAGS.eval_batch_size // FLAGS.num_dropout_samples_eval

    # As per the Kaggle challenge, we have split sizes:
    # train: 35,126
    # validation: 10,906
    # test: 42,670
    ds_info = tfds.builder('diabetic_retinopathy_detection').info
    steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
    steps_per_validation_eval = (ds_info.splits['validation'].num_examples //
                                 eval_batch_size)
    steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size

    data_dir = FLAGS.data_dir

    dataset_train_builder = ub.datasets.get('diabetic_retinopathy_detection',
                                            split='train',
                                            data_dir=data_dir)
    dataset_train = dataset_train_builder.load(batch_size=train_batch_size)

    dataset_validation_builder = ub.datasets.get(
        'diabetic_retinopathy_detection',
        split='validation',
        data_dir=data_dir,
        is_training=not FLAGS.use_validation)
    validation_batch_size = (eval_batch_size
                             if FLAGS.use_validation else train_batch_size)
    dataset_validation = dataset_validation_builder.load(
        batch_size=validation_batch_size)
    if not FLAGS.use_validation:
        # Note that this will not create any mixed batches of train and validation
        # images.
        dataset_train = dataset_train.concatenate(dataset_validation)

    dataset_test_builder = ub.datasets.get('diabetic_retinopathy_detection',
                                           split='test',
                                           data_dir=data_dir)
    dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    # MC Dropout ResNet50 based on PyTorch Vision implementation
    logging.info('Building Torch ResNet-50 MC Dropout model.')
    model = ub.models.resnet50_dropout_torch(num_classes=1,
                                             dropout_rate=FLAGS.dropout_rate)
    logging.info('Model number of weights: %s',
                 torch_utils.count_parameters(model))

    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=base_lr,
                                momentum=1.0 - FLAGS.one_minus_momentum,
                                nesterov=True)
    steps_to_lr_peak = int(steps_per_epoch * FLAGS.lr_warmup_epochs)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, steps_to_lr_peak, T_mult=2)

    model = model.to(device)

    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=False,
        num_bins=FLAGS.num_bins,
        use_validation=FLAGS.use_validation)

    # Define additional metrics that would fail in a TF TPU implementation.
    metrics.update(
        utils.get_diabetic_retinopathy_cpu_metrics(
            use_validation=FLAGS.use_validation))

    # Initialize loss function based on class reweighting setting
    loss_fn = torch.nn.BCELoss()
    sigmoid = torch.nn.Sigmoid()
    max_steps = steps_per_epoch * FLAGS.train_epochs
    image_h = 512
    image_w = 512

    def run_train_epoch(iterator):
        def train_step(inputs):
            images = inputs['features']
            labels = inputs['labels']
            images = torch.from_numpy(images._numpy()).view(
                train_batch_size,
                3,  # pylint: disable=protected-access
                image_h,
                image_w).to(device)
            labels = torch.from_numpy(labels._numpy()).to(device).float()  # pylint: disable=protected-access

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward
            logits = model(images)
            probs = sigmoid(logits).squeeze(-1)

            # Add L2 regularization loss to NLL
            negative_log_likelihood = loss_fn(probs, labels)
            l2_loss = sum(p.pow(2.0).sum() for p in model.parameters())
            loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)

            # Backward/optimizer
            loss.backward()
            optimizer.step()

            # Convert to NumPy for metrics updates
            loss = loss.detach()
            negative_log_likelihood = negative_log_likelihood.detach()
            labels = labels.detach()
            probs = probs.detach()

            if device != 'cpu':
                loss = loss.cpu()
                negative_log_likelihood = negative_log_likelihood.cpu()
                labels = labels.cpu()
                probs = probs.cpu()

            loss = loss.numpy()
            negative_log_likelihood = negative_log_likelihood.numpy()
            labels = labels.numpy()
            probs = probs.numpy()

            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, probs)
            metrics['train/auprc'].update_state(labels, probs)
            metrics['train/auroc'].update_state(labels, probs)
            metrics['train/ece'].add_batch(probs, label=labels)

        for step in range(steps_per_epoch):
            train_step(next(iterator))

            if step % 100 == 0:
                current_step = (epoch + 1) * step
                time_elapsed = time.time() - start_time
                steps_per_sec = float(current_step) / time_elapsed
                eta_seconds = (max_steps - current_step
                               ) / steps_per_sec if steps_per_sec else 0
                message = (
                    '{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                    'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                        current_step / max_steps, epoch + 1,
                        FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                        time_elapsed / 60))
                logging.info(message)

    def run_eval_epoch(iterator, dataset_split, num_steps):
        def eval_step(inputs, model):
            images = inputs['features']
            labels = inputs['labels']
            images = torch.from_numpy(images._numpy()).view(
                eval_batch_size,
                3,  # pylint: disable=protected-access
                image_h,
                image_w).to(device)
            labels = torch.from_numpy(
                labels._numpy()).to(device).float().unsqueeze(-1)  # pylint: disable=protected-access

            with torch.no_grad():
                logits = torch.stack([
                    model(images)
                    for _ in range(FLAGS.num_dropout_samples_eval)
                ],
                                     dim=-1)

            # Logits dimension is (batch_size, 1, num_dropout_samples).
            logits = logits.squeeze()

            # It is now (batch_size, num_dropout_samples).
            probs = sigmoid(logits)

            # labels_tiled shape is (batch_size, num_dropout_samples).
            labels_tiled = torch.tile(labels,
                                      (1, FLAGS.num_dropout_samples_eval))

            log_likelihoods = -loss_fn(probs, labels_tiled)
            negative_log_likelihood = torch.mean(
                -torch.logsumexp(log_likelihoods, dim=-1) +
                torch.log(torch.tensor(float(FLAGS.num_dropout_samples_eval))))

            probs = torch.mean(probs, dim=-1)

            # Convert to NumPy for metrics updates
            negative_log_likelihood = negative_log_likelihood.detach()
            labels = labels.detach()
            probs = probs.detach()

            if device != 'cpu':
                negative_log_likelihood = negative_log_likelihood.cpu()
                labels = labels.cpu()
                probs = probs.cpu()

            negative_log_likelihood = negative_log_likelihood.numpy()
            labels = labels.numpy()
            probs = probs.numpy()

            metrics[dataset_split + '/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics[dataset_split + '/accuracy'].update_state(labels, probs)
            metrics[dataset_split + '/auprc'].update_state(labels, probs)
            metrics[dataset_split + '/auroc'].update_state(labels, probs)
            metrics[dataset_split + '/ece'].add_batch(probs, label=labels)

        for _ in range(num_steps):
            eval_step(next(iterator), model=model)

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})
    start_time = time.time()
    initial_epoch = 0
    train_iterator = iter(dataset_train)
    model.train()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch + 1)

        run_train_epoch(train_iterator)

        if FLAGS.use_validation:
            validation_iterator = iter(dataset_validation)
            logging.info('Starting to run validation eval at epoch: %s',
                         epoch + 1)
            run_eval_epoch(validation_iterator, 'validation',
                           steps_per_validation_eval)

        test_iterator = iter(dataset_test)
        logging.info('Starting to run test eval at epoch: %s', epoch + 1)
        test_start_time = time.time()
        run_eval_epoch(test_iterator, 'test', steps_per_test_eval)
        ms_per_example = (time.time() -
                          test_start_time) * 1e6 / eval_batch_size
        metrics['test/ms_per_example'].update_state(ms_per_example)

        # Step scheduler
        scheduler.step()

        # Log and write to summary the epoch metrics
        utils.log_epoch_metrics(metrics=metrics, use_tpu=False)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):

            checkpoint_path = os.path.join(FLAGS.output_dir,
                                           f'model_{epoch + 1}.pt')
            torch_utils.checkpoint_torch_model(model=model,
                                               optimizer=optimizer,
                                               epoch=epoch + 1,
                                               checkpoint_path=checkpoint_path)
            logging.info('Saved Torch checkpoint to %s', checkpoint_path)

    final_checkpoint_path = os.path.join(FLAGS.output_dir,
                                         f'model_{FLAGS.train_epochs}.pt')
    torch_utils.checkpoint_torch_model(model=model,
                                       optimizer=optimizer,
                                       epoch=FLAGS.train_epochs,
                                       checkpoint_path=final_checkpoint_path)
    logging.info('Saved last checkpoint to %s', final_checkpoint_path)

    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'dropout_rate': FLAGS.dropout_rate,
            'l2': FLAGS.l2,
            'lr_warmup_epochs': FLAGS.lr_warmup_epochs
        })
# Example no. 4
# 0
def main(argv):
    """Train and evaluate a Keras ResNet-50 MC Dropout retinopathy model.

    Optionally initializes a Weights & Biases run, builds (or restores from
    `FLAGS.checkpoint_dir`) a ResNet-50 with dropout under the flag-selected
    distribution strategy, then for each epoch runs one training pass,
    evaluates every validation/test split via
    `utils.evaluate_model_and_compute_metrics`, and writes TensorBoard
    summaries. Keras model snapshots and per-prediction results are saved
    every `FLAGS.checkpoint_interval` epochs and once more after training.

    Args:
      argv: Unused positional command-line arguments.
    """
    del argv  # unused arg
    tf.random.set_seed(FLAGS.seed)

    # Wandb Setup
    if FLAGS.use_wandb:
        pathlib.Path(FLAGS.wandb_dir).mkdir(parents=True, exist_ok=True)
        wandb_args = dict(project=FLAGS.project,
                          entity='uncertainty-baselines',
                          dir=FLAGS.wandb_dir,
                          reinit=True,
                          name=FLAGS.exp_name,
                          group=FLAGS.exp_group)
        wandb_run = wandb.init(**wandb_args)
        wandb.config.update(FLAGS, allow_val_change=True)
        # Each wandb run writes into its own timestamped subdirectory.
        output_dir = str(
            os.path.join(
                FLAGS.output_dir,
                datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')))
    else:
        wandb_run = None
        output_dir = FLAGS.output_dir

    tf.io.gfile.makedirs(output_dir)
    logging.info('Saving checkpoints at %s', output_dir)

    # Log Run Hypers
    hypers_dict = {
        'per_core_batch_size': FLAGS.per_core_batch_size,
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'dropout_rate': FLAGS.dropout_rate,
        'l2': FLAGS.l2,
    }
    logging.info('Hypers:')
    logging.info(pprint.pformat(hypers_dict))

    # Initialize distribution strategy on flag-specified accelerator
    strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                                FLAGS.use_gpu, FLAGS.tpu)
    use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

    # This is the *global* batch size: the per-core batch size multiplied by
    # the number of cores. It is used for both training and evaluation.
    global_batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores

    # Reweighting loss for class imbalance
    class_reweight_mode = FLAGS.class_reweight_mode
    if class_reweight_mode == 'constant':
        class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
    else:
        class_weights = None

    # Load in datasets.
    datasets, steps = utils.load_dataset(train_batch_size=global_batch_size,
                                         eval_batch_size=global_batch_size,
                                         flags=FLAGS,
                                         strategy=strategy)
    available_splits = list(datasets.keys())
    test_splits = [split for split in available_splits if 'test' in split]
    eval_splits = [
        split for split in available_splits
        if 'validation' in split or 'test' in split
    ]

    # Iterate eval datasets
    eval_datasets = {split: iter(datasets[split]) for split in eval_splits}
    dataset_train = datasets['train']
    train_steps_per_epoch = steps['train']

    if FLAGS.use_bfloat16:
        tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    summary_writer = tf.summary.create_file_writer(
        os.path.join(output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building Keras ResNet-50 MC Dropout model.')
        if FLAGS.load_from_checkpoint:
            initial_epoch, model = utils.load_keras_checkpoints(
                FLAGS.checkpoint_dir, load_ensemble=False, return_epoch=True)
        else:
            initial_epoch = 0
            model = ub.models.resnet50_dropout(
                input_shape=utils.load_input_shape(dataset_train),
                num_classes=1,  # binary classification task
                dropout_rate=FLAGS.dropout_rate,
                filterwise_dropout=FLAGS.filterwise_dropout)
            utils.log_model_init_info(model=model)

        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate
        lr_decay_epochs = [
            (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
            for start_epoch_str in FLAGS.lr_decay_epochs
        ]
        lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
            train_steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=1.0 -
                                            FLAGS.one_minus_momentum,
                                            nesterov=True)
        metrics = utils.get_diabetic_retinopathy_base_metrics(
            use_tpu=use_tpu,
            num_bins=FLAGS.num_bins,
            use_validation=FLAGS.use_validation,
            available_splits=available_splits)

        # TODO(nband): debug or remove
        # checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        # latest_checkpoint = tf.train.latest_checkpoint(output_dir)
        # if latest_checkpoint:
        #   # checkpoint.restore must be within a strategy.scope()
        #   # so that optimizer slot variables are mirrored.
        #   checkpoint.restore(latest_checkpoint)
        #   logging.info('Loaded checkpoint %s', latest_checkpoint)
        #   initial_epoch = optimizer.iterations.numpy() // train_steps_per_epoch

    # Define metrics outside the accelerator scope for CPU eval.
    # This will cause an error on TPU.
    if not use_tpu:
        metrics.update(
            utils.get_diabetic_retinopathy_cpu_metrics(
                available_splits=available_splits,
                use_validation=FLAGS.use_validation))

    for test_split in test_splits:
        metrics.update(
            {f'{test_split}/ms_per_example': tf.keras.metrics.Mean()})

    # Initialize loss function based on class reweighting setting
    loss_fn = utils.get_diabetic_retinopathy_loss_fn(
        class_reweight_mode=class_reweight_mode, class_weights=class_weights)

    # * Prepare for Evaluation *

    # Get the wrapper function which will produce uncertainty estimates for
    # our choice of method and Y/N ensembling.
    uncertainty_estimator_fn = utils.get_uncertainty_estimator(
        'dropout', use_ensemble=False, use_tf=True)

    # Wrap our estimator to predict probabilities (apply sigmoid on logits)
    eval_estimator = utils.wrap_retinopathy_estimator(
        model, use_mixed_precision=FLAGS.use_bfloat16, numpy_outputs=False)

    estimator_args = {'num_samples': FLAGS.num_dropout_samples_eval}

    @tf.function
    def train_step(iterator):
        """Run one full training epoch (`train_steps_per_epoch` steps)."""
        def step_fn(inputs):
            """Per-replica step function."""
            images = inputs['features']
            labels = inputs['labels']

            # For minibatch class reweighting, initialize per-batch loss function
            if class_reweight_mode == 'minibatch':
                batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(
                    labels=labels)
            else:
                batch_loss_fn = loss_fn

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)

                negative_log_likelihood = tf.reduce_mean(
                    batch_loss_fn(y_true=tf.expand_dims(labels, axis=-1),
                                  y_pred=logits,
                                  from_logits=True))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)

                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            probs = tf.nn.sigmoid(logits)

            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, probs)
            metrics['train/auprc'].update_state(labels, probs)
            metrics['train/auroc'].update_state(labels, probs)

            if not use_tpu:
                metrics['train/ece'].add_batch(probs, label=labels)

        for _ in tf.range(tf.cast(train_steps_per_epoch, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    start_time = time.time()
    max_steps = train_steps_per_epoch * FLAGS.train_epochs
    # Stays None when no epoch runs (e.g. a restored checkpoint is already at
    # FLAGS.train_epochs). The final save below checks for this, so it no
    # longer fails with a NameError.
    per_pred_results = None

    train_iterator = iter(dataset_train)
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch + 1)
        train_step(train_iterator)

        current_step = (epoch + 1) * train_steps_per_epoch
        # Throughput must only count steps run by this process. When resuming
        # from a checkpoint, the first `initial_epoch` epochs were trained
        # before `start_time` was recorded.
        steps_this_run = (epoch + 1 - initial_epoch) * train_steps_per_epoch
        time_elapsed = time.time() - start_time
        steps_per_sec = float(steps_this_run) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                       steps_per_sec, eta_seconds / 60, time_elapsed / 60))
        logging.info(message)

        # Run evaluation on all evaluation datasets, and compute metrics
        per_pred_results, total_results = utils.evaluate_model_and_compute_metrics(
            strategy,
            eval_datasets,
            steps,
            metrics,
            eval_estimator,
            uncertainty_estimator_fn,
            global_batch_size,
            available_splits,
            estimator_args=estimator_args,
            call_dataset_iter=False,
            is_deterministic=False,
            num_bins=FLAGS.num_bins,
            use_tpu=use_tpu,
            return_per_pred_results=True)

        # Optionally log to wandb.
        # NOTE: wandb is logged with a 0-based epoch step, while the
        # TensorBoard summaries below use epoch + 1.
        if FLAGS.use_wandb:
            wandb.log(total_results, step=epoch)

        with summary_writer.as_default():
            for name, result in total_results.items():
                if result is not None:
                    tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            # checkpoint_name = checkpoint.save(
            #     os.path.join(output_dir, 'checkpoint'))
            # logging.info('Saved checkpoint to %s', checkpoint_name)

            # TODO(nband): debug checkpointing
            # Also save Keras model, due to checkpoint.save issue
            keras_model_name = os.path.join(output_dir,
                                            f'keras_model_{epoch + 1}')
            model.save(keras_model_name)
            logging.info('Saved keras model to %s', keras_model_name)

            # Save per-prediction metrics
            utils.save_per_prediction_results(output_dir,
                                              epoch + 1,
                                              per_pred_results,
                                              verbose=False)

    # final_checkpoint_name = checkpoint.save(
    #     os.path.join(output_dir, 'checkpoint'))
    # logging.info('Saved last checkpoint to %s', final_checkpoint_name)

    keras_model_name = os.path.join(output_dir,
                                    f'keras_model_{FLAGS.train_epochs}')
    model.save(keras_model_name)
    logging.info('Saved keras model to %s', keras_model_name)

    # Save per-prediction metrics (only available if at least one epoch ran).
    if per_pred_results is not None:
        utils.save_per_prediction_results(output_dir,
                                          FLAGS.train_epochs,
                                          per_pred_results,
                                          verbose=False)
    else:
        logging.info('No training epochs were run; skipping final '
                     'per-prediction results.')

    with summary_writer.as_default():
        hp.hparams({
            'per_core_batch_size': FLAGS.per_core_batch_size,
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'dropout_rate': FLAGS.dropout_rate,
            'l2': FLAGS.l2,
        })

    if wandb_run is not None:
        wandb_run.finish()
# Example no. 5
# 0
def main(argv):
  """Train and evaluate a Keras ResNet-50 Radial BNN retinopathy model.

  Builds a ResNet-50 with a radial approximate posterior under the
  flag-selected distribution strategy. If a `tf.train.Checkpoint` exists in
  `FLAGS.output_dir`, training resumes from it. Each epoch trains on the train
  split (NLL + optional L2 + annealed KL), evaluates on the validation and test
  splits, and writes TensorBoard summaries. Checkpoints and Keras model
  snapshots are written every `FLAGS.checkpoint_interval` epochs and once more
  after training.

  Args:
    argv: Unused positional command-line arguments.

  Raises:
    NotImplementedError: If a positive `FLAGS.l2` is requested without
      `FLAGS.tied_mean_prior`.
  """
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  # Initialize distribution strategy on flag-specified accelerator
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)
  use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

  # Only permit use of L2 regularization with a tied mean prior
  if FLAGS.l2 is not None and FLAGS.l2 > 0 and not FLAGS.tied_mean_prior:
    raise NotImplementedError(
        'For a principled objective, L2 regularization should not be used '
        'when the prior mean is untied from the posterior mean.')
  # `FLAGS.l2` may be None (the check above allows for it). Treat None as
  # "no L2" instead of failing later on `None * 2` inside the training step.
  l2_coefficient = FLAGS.l2 if FLAGS.l2 is not None else 0.0

  batch_size = FLAGS.batch_size * FLAGS.num_cores

  # As per the Kaggle challenge, we have split sizes:
  # train: 35,126
  # validation: 10,906
  # test: 42,670
  ds_info = tfds.builder('diabetic_retinopathy_detection').info
  train_dataset_size = ds_info.splits['train'].num_examples
  steps_per_epoch = train_dataset_size // batch_size
  steps_per_validation_eval = (
      ds_info.splits['validation'].num_examples // batch_size)
  steps_per_test_eval = ds_info.splits['test'].num_examples // batch_size

  data_dir = FLAGS.data_dir

  dataset_train_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='train', data_dir=data_dir)
  dataset_train = dataset_train_builder.load(batch_size=batch_size)
  dataset_train = strategy.experimental_distribute_dataset(dataset_train)

  dataset_validation_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='validation', data_dir=data_dir)
  dataset_validation = dataset_validation_builder.load(
      batch_size=batch_size)
  dataset_validation = strategy.experimental_distribute_dataset(
      dataset_validation)

  dataset_test_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='test', data_dir=data_dir)
  dataset_test = dataset_test_builder.load(batch_size=batch_size)
  dataset_test = strategy.experimental_distribute_dataset(dataset_test)

  if FLAGS.use_bfloat16:
    # Non-experimental mixed-precision API, the same call used by the
    # MC Dropout baseline; the `experimental` namespace is deprecated.
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-50 Radial model.')

    if FLAGS.prior_stddev is None:
      logging.info(
          'A fixed prior stddev was not supplied. Computing a prior stddev = '
          'sqrt(2 / fan_in) for each layer. This is recommended over providing '
          'a fixed prior stddev.')

    model = ub.models.resnet50_radial(
        input_shape=utils.load_input_shape(dataset_train),
        num_classes=1,  # binary classification task
        prior_stddev=FLAGS.prior_stddev,
        dataset_size=train_dataset_size,
        stddev_mean_init=FLAGS.stddev_mean_init,
        stddev_stddev_init=FLAGS.stddev_stddev_init,
        tied_mean_prior=FLAGS.tied_mean_prior)

    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate
    lr_decay_epochs = [
        (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
        for start_epoch_str in FLAGS.lr_decay_epochs
    ]

    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    # Momentum comes from the flag, not a hard-coded 0.9. The value then
    # matches the `one_minus_momentum` hparam logged at the end of training.
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)
    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=use_tpu, num_bins=FLAGS.num_bins)
    metrics.update({
        'train/kl': tf.keras.metrics.Mean(),
        'train/kl_scale': tf.keras.metrics.Mean()
    })
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope()
      # so that optimizer slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  # Finally, define OOD metrics outside the accelerator scope for CPU eval.
  # This will cause an error on TPU.
  if not use_tpu:
    metrics.update({
        'train/auc': tf.keras.metrics.AUC(),
        'validation/auc': tf.keras.metrics.AUC(),
        'test/auc': tf.keras.metrics.AUC()
    })

  @tf.function
  def train_step(iterator):
    """Training step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)

        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(
                y_true=tf.expand_dims(labels, axis=-1),
                y_pred=logits,
                from_logits=True))

        filtered_variables = []
        for var in model.trainable_variables:
          # Apply l2 on the BN parameters and bias terms. This
          # excludes only fast weight approximate posterior/prior parameters,
          # but pay caution to their naming scheme.
          if 'bn' in var.name or 'bias' in var.name:
            filtered_variables.append(tf.reshape(var, (-1,)))

        l2_loss = l2_coefficient * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
        kl = sum(model.losses)
        # Linearly anneal the KL weight from ~0 to 1 over
        # `FLAGS.kl_annealing_epochs` epochs of optimizer steps.
        kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
        kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs
        kl_scale = tf.minimum(1., kl_scale)
        kl_loss = kl_scale * kl

        loss = negative_log_likelihood + l2_loss + kl_loss

        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      probs = tf.squeeze(tf.nn.sigmoid(logits))

      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/kl'].update_state(kl)
      metrics['train/kl_scale'].update_state(kl_scale)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auc'].update_state(labels, probs)

      if not use_tpu:
        metrics['train/ece'].update_state(labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_split):
    """Evaluation step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)

      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.binary_crossentropy(
              y_true=tf.expand_dims(labels, axis=-1),
              y_pred=logits,
              from_logits=True))
      probs = tf.squeeze(tf.nn.sigmoid(logits))

      metrics[dataset_split + '/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics[dataset_split + '/accuracy'].update_state(labels, probs)
      metrics[dataset_split + '/auc'].update_state(labels, probs)

      if not use_tpu:
        metrics[dataset_split + '/ece'].update_state(labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})
  start_time = time.time()
  max_steps = steps_per_epoch * FLAGS.train_epochs

  train_iterator = iter(dataset_train)
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch + 1)
    for step in range(steps_per_epoch):
      train_step(train_iterator)

      current_step = epoch * steps_per_epoch + (step + 1)
      # Throughput must only count steps run by this process. When resuming
      # from a checkpoint, the first `initial_epoch` epochs were trained
      # before `start_time` was recorded.
      steps_this_run = (epoch - initial_epoch) * steps_per_epoch + (step + 1)
      time_elapsed = time.time() - start_time
      steps_per_sec = float(steps_this_run) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
      if step % 20 == 0:
        logging.info(message)

    validation_iterator = iter(dataset_validation)
    for step in range(steps_per_validation_eval):
      if step % 20 == 0:
        logging.info('Starting to run validation eval step %s of epoch: %s',
                     step, epoch + 1)
      test_step(validation_iterator, 'validation')

    test_iterator = iter(dataset_test)
    for step in range(steps_per_test_eval):
      if step % 20 == 0:
        logging.info('Starting to run test eval step %s of epoch: %s', step,
                     epoch + 1)
      test_start_time = time.time()
      test_step(test_iterator, 'test')
      # NOTE(review): the 1e6 factor yields microseconds per example despite
      # the "ms" metric name; kept as-is for comparability with earlier runs.
      ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
      metrics['test/ms_per_example'].update_state(ms_per_example)

    # Log and write to summary the epoch metrics.
    utils.log_epoch_metrics(metrics=metrics, use_tpu=use_tpu)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

      # TODO(nband): debug checkpointing
      # Also save Keras model, due to checkpoint.save issue
      keras_model_name = os.path.join(FLAGS.output_dir,
                                      f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  keras_model_name = os.path.join(FLAGS.output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)
  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        # Log the coefficient actually used in training (None -> 0.0).
        'l2': l2_coefficient,
        'stddev_mean_init': FLAGS.stddev_mean_init,
        'stddev_stddev_init': FLAGS.stddev_stddev_init,
    })
def main(argv):
  """Trains and evaluates a ResNet-50 variational-inference model.

  Flow: optional Weights & Biases setup, output directory creation, dataset
  loading via `utils.load_dataset`, model / optimizer / metric construction
  inside the distribution strategy scope, then a per-epoch loop of training,
  evaluation on every validation/test split, summary logging and periodic
  Keras model saving. A final Keras model and the last per-prediction
  results are saved after the loop.

  Args:
    argv: Command-line arguments; unused (configuration comes from FLAGS).

  Raises:
    NotImplementedError: If L2 regularization is requested while the prior
      mean is untied from the posterior mean.
  """
  del argv  # unused arg
  tf.random.set_seed(FLAGS.seed)

  # Wandb Setup
  if FLAGS.use_wandb:
    pathlib.Path(FLAGS.wandb_dir).mkdir(parents=True, exist_ok=True)
    wandb_args = dict(
        project=FLAGS.project,
        entity='uncertainty-baselines',
        dir=FLAGS.wandb_dir,
        reinit=True,
        name=FLAGS.exp_name,
        group=FLAGS.exp_group)
    wandb_run = wandb.init(**wandb_args)
    wandb.config.update(FLAGS, allow_val_change=True)
    # Each wandb run writes into its own timestamped subdirectory.
    output_dir = str(
        os.path.join(FLAGS.output_dir,
                     datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')))
  else:
    wandb_run = None
    output_dir = FLAGS.output_dir

  tf.io.gfile.makedirs(output_dir)
  logging.info('Saving checkpoints at %s', output_dir)

  # Log Run Hypers
  hypers_dict = {
      'batch_size': FLAGS.batch_size,
      'base_learning_rate': FLAGS.base_learning_rate,
      'one_minus_momentum': FLAGS.one_minus_momentum,
      'l2': FLAGS.l2,
      'stddev_mean_init': FLAGS.stddev_mean_init,
      'stddev_stddev_init': FLAGS.stddev_stddev_init,
  }
  logging.info('Hypers:')
  logging.info(pprint.pformat(hypers_dict))

  # Initialize distribution strategy on flag-specified accelerator
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)
  use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

  # Only permit use of L2 regularization with a tied mean prior
  if FLAGS.l2 is not None and FLAGS.l2 > 0 and not FLAGS.tied_mean_prior:
    raise NotImplementedError(
        'For a principled objective, L2 regularization should not be used '
        'when the prior mean is untied from the posterior mean.')
  # FLAGS.l2 may be None (permitted by the check above); treat that as "no L2
  # penalty" instead of failing later on `None * 2` inside the train step.
  l2_coefficient = FLAGS.l2 if FLAGS.l2 is not None else 0.0

  # Global batch size across all replicas.
  batch_size = FLAGS.batch_size * FLAGS.num_cores

  # Reweighting loss for class imbalance
  class_reweight_mode = FLAGS.class_reweight_mode
  if class_reweight_mode == 'constant':
    class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  else:
    class_weights = None

  # Load in datasets.
  datasets, steps = utils.load_dataset(
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      flags=FLAGS,
      strategy=strategy)
  available_splits = list(datasets.keys())
  test_splits = [split for split in available_splits if 'test' in split]
  eval_splits = [
      split for split in available_splits
      if 'validation' in split or 'test' in split
  ]

  # Iterate eval datasets
  eval_datasets = {split: iter(datasets[split]) for split in eval_splits}
  dataset_train = datasets['train']
  train_steps_per_epoch = steps['train']
  # Approximate training-set size (whole batches only); used for KL scaling.
  train_dataset_size = train_steps_per_epoch * batch_size

  if FLAGS.use_bfloat16:
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

  summary_writer = tf.summary.create_file_writer(
      os.path.join(output_dir, 'summaries'))

  if FLAGS.prior_stddev is None:
    logging.info(
        'A fixed prior stddev was not supplied. Computing a prior stddev = '
        'sqrt(2 / fan_in) for each layer. This is recommended over providing '
        'a fixed prior stddev.')

  with strategy.scope():
    logging.info('Building Keras ResNet-50 Variational Inference model.')
    if FLAGS.load_from_checkpoint:
      initial_epoch, model = utils.load_keras_checkpoints(
          FLAGS.checkpoint_dir, load_ensemble=False, return_epoch=True)
      # NOTE(review): the optimizer built below is fresh, so its iteration
      # counter (which drives the LR schedule and KL annealing in the train
      # step) restarts at 0 on resume. Confirm whether this is intended.
    else:
      initial_epoch = 0
      model = ub.models.resnet50_variational(
          input_shape=utils.load_input_shape(dataset_train),
          num_classes=1,  # binary classification task
          prior_stddev=FLAGS.prior_stddev,
          dataset_size=train_dataset_size,
          stddev_mean_init=FLAGS.stddev_mean_init,
          stddev_stddev_init=FLAGS.stddev_stddev_init,
          tied_mean_prior=FLAGS.tied_mean_prior)
      utils.log_model_init_info(model=model)

    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate
    lr_decay_epochs = [
        (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
        for start_epoch_str in FLAGS.lr_decay_epochs
    ]

    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        train_steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)
    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=use_tpu,
        num_bins=FLAGS.num_bins,
        use_validation=FLAGS.use_validation,
        available_splits=available_splits)

    # VI specific metrics
    metrics.update({
        'train/kl': tf.keras.metrics.Mean(),
        'train/kl_scale': tf.keras.metrics.Mean()
    })

    # TODO(nband): debug tf.train.Checkpoint-based save/restore (which must
    # happen inside strategy.scope() so optimizer slots are mirrored). Until
    # then the run is persisted only via Keras `model.save` below.

  # Define metrics outside the accelerator scope for CPU eval.
  # This will cause an error on TPU.
  if not use_tpu:
    metrics.update(
        utils.get_diabetic_retinopathy_cpu_metrics(
            available_splits=available_splits,
            use_validation=FLAGS.use_validation))

  for test_split in test_splits:
    metrics.update({f'{test_split}/ms_per_example': tf.keras.metrics.Mean()})

  # Initialize loss function based on class reweighting setting
  loss_fn = utils.get_diabetic_retinopathy_loss_fn(
      class_reweight_mode=class_reweight_mode, class_weights=class_weights)

  # * Prepare for Evaluation *

  # Get the wrapper function which will produce uncertainty estimates for
  # our choice of method and Y/N ensembling.
  uncertainty_estimator_fn = utils.get_uncertainty_estimator(
      'variational_inference', use_ensemble=False, use_tf=True)

  # Wrap our estimator to predict probabilities (apply sigmoid on logits)
  eval_estimator = utils.wrap_retinopathy_estimator(
      model, use_mixed_precision=FLAGS.use_bfloat16, numpy_outputs=False)

  estimator_args = {'num_samples': FLAGS.num_mc_samples_eval}

  @tf.function
  def train_step(iterator):
    """Runs one epoch (`train_steps_per_epoch` steps) of distributed training."""

    def step_fn(inputs):
      """Per-replica step: forward pass, ELBO-style loss, update, metrics."""
      images = inputs['features']
      labels = inputs['labels']

      # For minibatch class reweighting, initialize per-batch loss function
      if class_reweight_mode == 'minibatch':
        batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(labels=labels)
      else:
        batch_loss_fn = loss_fn

      with tf.GradientTape() as tape:
        # TODO(nband): TPU-friendly implem
        if FLAGS.num_mc_samples_train > 1:
          logits_arr = tf.TensorArray(
              tf.float32, size=FLAGS.num_mc_samples_train)

          for i in tf.range(FLAGS.num_mc_samples_train):
            logits = model(images, training=True)
            if FLAGS.use_bfloat16:
              # The TensorArray stores float32; mixed-precision logits are
              # bfloat16 and must be cast before being written.
              logits = tf.cast(logits, tf.float32)
            logits_arr = logits_arr.write(i, logits)

          # Stacks the per-sample logits along a new leading sample axis.
          logits_list = logits_arr.stack()

          probs_list = tf.nn.sigmoid(logits_list)
          # Average predictive probability over the MC samples; keeps the
          # per-sample logits shape so it lines up with the expanded labels.
          mean_probs = tf.reduce_mean(probs_list, axis=0)
          negative_log_likelihood = tf.reduce_mean(
              batch_loss_fn(
                  y_true=tf.expand_dims(labels, axis=-1),
                  y_pred=mean_probs,
                  from_logits=False))
          # Squeeze exactly like the single-sample branch so the metrics
          # below receive probabilities of the same shape in both branches.
          probs = tf.squeeze(mean_probs)
        else:
          # Single train step
          logits = model(images, training=True)
          if FLAGS.use_bfloat16:
            logits = tf.cast(logits, tf.float32)
          negative_log_likelihood = tf.reduce_mean(
              batch_loss_fn(
                  y_true=tf.expand_dims(labels, axis=-1),
                  y_pred=logits,
                  from_logits=True))
          probs = tf.squeeze(tf.nn.sigmoid(logits))

        # Apply l2 on the BN parameters and bias terms. This excludes only
        # fast weight approximate posterior/prior parameters, but pay caution
        # to their naming scheme.
        filtered_variables = [
            tf.reshape(var, (-1,))
            for var in model.trainable_variables
            if 'bn' in var.name or 'bias' in var.name
        ]
        if filtered_variables:
          l2_loss = l2_coefficient * 2 * tf.nn.l2_loss(
              tf.concat(filtered_variables, axis=0))
        else:
          # tf.concat rejects an empty list; nothing to regularize.
          l2_loss = 0.0

        # KL term from the model's layer losses, linearly annealed from ~0 to
        # 1 over `kl_annealing_epochs` epochs of optimizer steps.
        kl = sum(model.losses)
        kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
        kl_scale /= train_steps_per_epoch * FLAGS.kl_annealing_epochs
        kl_scale = tf.minimum(1., kl_scale)
        kl_loss = kl_scale * kl

        loss = negative_log_likelihood + l2_loss + kl_loss

        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/kl'].update_state(kl)
      metrics['train/kl_scale'].update_state(kl_scale)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auprc'].update_state(labels, probs)
      metrics['train/auroc'].update_state(labels, probs)

      if not use_tpu:
        metrics['train/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(train_steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  start_time = time.time()

  train_iterator = iter(dataset_train)
  max_steps = train_steps_per_epoch * FLAGS.train_epochs
  # Stays None if no epoch runs (e.g. resuming at or past FLAGS.train_epochs),
  # in which case there are no per-prediction results to save at the end.
  per_pred_results = None

  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch + 1)
    train_step(train_iterator)

    current_step = (epoch + 1) * train_steps_per_epoch
    # Throughput is measured over the steps run by this process only, so a
    # run resumed at `initial_epoch` does not overstate steps/s (and
    # therefore understate the ETA).
    steps_this_run = (epoch + 1 - initial_epoch) * train_steps_per_epoch
    time_elapsed = time.time() - start_time
    steps_per_sec = float(steps_this_run) / time_elapsed
    eta_seconds = (max_steps - current_step) / steps_per_sec
    message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
               'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                   current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                   steps_per_sec, eta_seconds / 60, time_elapsed / 60))
    logging.info(message)

    # Run evaluation on all evaluation datasets, and compute metrics
    per_pred_results, total_results = utils.evaluate_model_and_compute_metrics(
        strategy,
        eval_datasets,
        steps,
        metrics,
        eval_estimator,
        uncertainty_estimator_fn,
        batch_size,
        available_splits,
        estimator_args=estimator_args,
        call_dataset_iter=False,
        is_deterministic=False,
        num_bins=FLAGS.num_bins,
        use_tpu=use_tpu,
        return_per_pred_results=True)

    # Optionally log to wandb. Note: wandb uses step=epoch while the
    # TensorBoard summaries below use step=epoch + 1.
    if FLAGS.use_wandb:
      wandb.log(total_results, step=epoch)

    with summary_writer.as_default():
      for name, result in total_results.items():
        if result is not None:
          tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      # TODO(nband): debug checkpointing
      # Save the Keras model, due to checkpoint.save issue.
      keras_model_name = os.path.join(output_dir, f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

      # Save per-prediction metrics
      utils.save_per_prediction_results(
          output_dir, epoch + 1, per_pred_results, verbose=False)

  keras_model_name = os.path.join(output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)

  # Save per-prediction metrics from the last evaluated epoch, if any.
  if per_pred_results is not None:
    utils.save_per_prediction_results(
        output_dir, FLAGS.train_epochs, per_pred_results, verbose=False)

  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
        'stddev_mean_init': FLAGS.stddev_mean_init,
        'stddev_stddev_init': FLAGS.stddev_stddev_init,
    })

  if wandb_run is not None:
    wandb_run.finish()