Example #1
def _configure_hparams(logdir, dicts, 
                       metrics=["linear_classification_accuracy",
                                "alignment", "uniformity"]):
    """
    Set up the tensorboard hyperparameter interface
    
    :logdir: string; path to log directory
    :dicts: list of dictionaries containing hyperparameter values
    :metrics: list of strings; metric names
    """
    metrics = [hp.Metric(m) for m in metrics]
    params = {}
    # for each parameter dictionary
    for d in dicts:
        # for each parameter:
        for k in d:
            # is it a categorical?
            if k in SPECIAL_HPARAMS:
                params[hp.HParam(k, hp.Discrete(SPECIAL_HPARAMS[k]))] = d[k]
            elif isinstance(d[k], bool):
                params[hp.HParam(k, hp.Discrete([True, False]))] = d[k]
            elif isinstance(d[k], int):
                params[hp.HParam(k, hp.IntInterval(1, 1000000))] = d[k]
            elif isinstance(d[k], float):
                params[hp.HParam(k, hp.RealInterval(0., 10000000.))] = d[k]
    # write the experiment-level hparams configuration
    hp.hparams_config(hparams=list(params.keys()),
                      metrics=metrics)
                
    # get a name for the run
    base_dir, run_name = os.path.split(logdir)
    if len(run_name) == 0:
        base_dir, run_name = os.path.split(base_dir)
    # record hyperparameters
    hp.hparams(params, trial_id=run_name)
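Both hp.hparams_config and hp.hparams write to the default summary writer, so _configure_hparams presumably runs inside a writer context. A minimal, hypothetical call site (SPECIAL_HPARAMS is assumed here to be a module-level dict mapping categorical hyperparameter names to their allowed values):

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

SPECIAL_HPARAMS = {"optimizer": ["adam", "sgd"]}  # assumed categorical domains

logdir = "logs/run_01"
with tf.summary.create_file_writer(logdir).as_default():
    _configure_hparams(logdir,
                       dicts=[{"optimizer": "adam",   # categorical -> Discrete
                               "batch_size": 64,      # int -> IntInterval
                               "lr": 1e-3,            # float -> RealInterval
                               "augment": True}])     # bool -> Discrete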
Example #2
    def on_result(self, result):
        if self._file_writer is None:
            from tensorflow.python.eager import context
            from tensorboard.plugins.hparams import api as hp
            self._context = context
            self._file_writer = tf.summary.create_file_writer(self.logdir)
        with tf.device("/CPU:0"):
            with tf.summary.record_if(True), self._file_writer.as_default():
                step = result.get(
                    TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]

                tmp = result.copy()
                if not self._hp_logged:
                    if self.trial and self.trial.evaluated_params:
                        try:
                            hp.hparams(self.trial.evaluated_params,
                                       trial_id=self.trial.trial_id)
                        except Exception as exc:
                            logger.error("HParams failed with %s", exc)
                    self._hp_logged = True

                for k in [
                        "config", "pid", "timestamp", TIME_TOTAL_S,
                        TRAINING_ITERATION
                ]:
                    if k in tmp:
                        del tmp[k]  # not useful to log these

                flat_result = flatten_dict(tmp, delimiter="/")
                path = ["ray", "tune"]
                for attr, value in flat_result.items():
                    if type(value) in VALID_SUMMARY_TYPES:
                        tf.summary.scalar("/".join(path + [attr]),
                                          value,
                                          step=step)
        self._file_writer.flush()
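flatten_dict above comes from Ray's utilities. For illustration, a minimal equivalent that joins nested keys with the delimiter might look like this (a sketch, not the actual Ray implementation):

def flatten_dict(d, delimiter="/", _parent=""):
    # Recursively flatten a nested dict, joining key paths with the delimiter.
    flat = {}
    for k, v in d.items():
        key = f"{_parent}{delimiter}{k}" if _parent else str(k)
        if isinstance(v, dict):
            flat.update(flatten_dict(v, delimiter=delimiter, _parent=key))
        else:
            flat[key] = v
    return flat

# flatten_dict({"a": {"b": 1}, "c": 2}) -> {"a/b": 1, "c": 2}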
Example #3
def _setup_outputs(root_output_dir, experiment_name, hparam_dict):
    """Set up directories for experiment loops, write hyperparameters to disk."""

    if not experiment_name:
        raise ValueError('experiment_name must be specified.')

    create_if_not_exists(root_output_dir)

    checkpoint_dir = os.path.join(root_output_dir, 'checkpoints',
                                  experiment_name)
    create_if_not_exists(checkpoint_dir)
    checkpoint_mngr = tff.simulation.FileCheckpointManager(checkpoint_dir)

    results_dir = os.path.join(root_output_dir, 'results', experiment_name)
    create_if_not_exists(results_dir)
    csv_file = os.path.join(results_dir, 'experiment.metrics.csv')
    metrics_mngr = tff.simulation.CSVMetricsManager(csv_file)

    summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
    create_if_not_exists(summary_logdir)
    tensorboard_mngr = tff.simulation.TensorBoardManager(summary_logdir)

    if hparam_dict:
        summary_writer = tf.summary.create_file_writer(summary_logdir)
        hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
        hparams_file = os.path.join(results_dir, 'hparams.csv')
        utils_impl.atomic_write_series_to_csv(hparam_dict, hparams_file)
        with summary_writer.as_default():
            hp.hparams({k: v for k, v in hparam_dict.items() if v is not None})

    logging.info('Writing...')
    logging.info('    checkpoints to: %s', checkpoint_dir)
    logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
    logging.info('    summaries to: %s', summary_logdir)

    return checkpoint_mngr, metrics_mngr, tensorboard_mngr
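A hypothetical call site for _setup_outputs (the directory and hyperparameter names below are illustrative, not from the original project):

checkpoint_mngr, metrics_mngr, tensorboard_mngr = _setup_outputs(
    root_output_dir='/tmp/exp',
    experiment_name='fedavg_cifar100',
    hparam_dict={'client_learning_rate': 0.1,
                 'server_learning_rate': 1.0,
                 'total_rounds': 100})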
Example #4
def train_model(run_dir, hparams):

    # hp.hparams writes to the default summary writer; open one for this run
    # so the hyperparameters are actually recorded.
    with tf.summary.create_file_writer(run_dir).as_default():
        hp.hparams(hparams)
    [X_train, y_train] = create_data(data_unscaled=data, start_train=start_train,
                                     end_train=end_train, n_windows=hparams[HP_WINDOW],
                                     n_outputs=hparams[HP_OUTPUT])
    [X_test, y_test] = create_data(data_unscaled=data, start_train=end_train,
                                   end_train=end_test, n_windows=hparams[HP_WINDOW],
                                   n_outputs=hparams[HP_OUTPUT])

    tf.compat.v1.keras.backend.clear_session()
    model = Sequential()
    model.add(TimeDistributed(Masking(mask_value=0., input_shape=(hparams[HP_WINDOW], n_inputs + 1)),
                              input_shape=(n_company, hparams[HP_WINDOW], n_inputs + 1)))
    model.add(TimeDistributed(LSTM(hparams[HP_NUM_UNITS], stateful=False, activation='tanh',
                                   return_sequences=True, input_shape=(hparams[HP_WINDOW], n_inputs + 1),
                                   kernel_initializer='TruncatedNormal',
                                   bias_initializer=initializers.Constant(value=0.1),
                                   dropout=hparams[HP_DROPOUT], recurrent_dropout=hparams[HP_DROPOUT])))
    model.add(TimeDistributed(LSTM(hparams[HP_NUM_UNITS], stateful=False, activation='tanh',
                                   return_sequences=False, kernel_initializer='TruncatedNormal',
                                   bias_initializer=initializers.Constant(value=0.1),
                                   dropout=hparams[HP_DROPOUT], recurrent_dropout=hparams[HP_DROPOUT])))
    # Optionally stack two more identical LSTM layers here (return_sequences=True,
    # then False), with the same initializers and dropout as above.
    model.add(Dense(units=1, activation='linear'))
    model.compile(optimizer=hparams[HP_OPTIMIZER], loss='mse')  # alternatives: loss=get_weighted_loss(weights=weights), metrics=['mae']
    model.build(input_shape=(None, n_company, hparams[HP_WINDOW], n_inputs+1))
    model.summary()
    # For a quick smoke test, run the same fit with epochs=1.
    model.fit(X_train, y_train[:, :, hparams[HP_WINDOW] - 1:hparams[HP_WINDOW], 0],
              epochs=200, batch_size=batch_size,
              validation_data=(X_test, y_test[:, :, hparams[HP_WINDOW] - 1:hparams[HP_WINDOW], 0]),
              callbacks=[
                  TensorBoard(log_dir=run_dir, histogram_freq=10, write_graph=True,
                              write_grads=True, update_freq='epoch'),
                  hp.KerasCallback(writer=run_dir, hparams=hparams)])
    model.save('model_{}_{}_{}_{}_{}.h5'.format(
        hparams[HP_NUM_UNITS], hparams[HP_DROPOUT], hparams[HP_OPTIMIZER],
        hparams[HP_WINDOW], hparams[HP_OUTPUT]))

    return 0
Example #5
    def log_summaries(self, step):
        # Write every tracked metric as a scalar summary at this step.
        with self.summary_writer.as_default():
            for metric in (self.training_episode_length,
                           self.training_critic_loss,
                           self.training_actor_loss,
                           self.training_critic_signal_to_noise,
                           self.testing_episode_length,
                           self.testing_total_episode_reward,
                           self.testing_mean_total_episode_reward):
                tf.summary.scalar(metric.name, metric.result(), step=step)
            hparams = {
                **self.agent._hparams,
                **self.replay_buffer._hparams,
                **self._hparams,
                "env_id": self.env.spec.id,
            }
            hp.hparams(hparams)
        # Reset per-interval metrics; testing_mean_total_episode_reward is left
        # untouched so it keeps accumulating a running mean.
        for metric in (self.training_episode_length,
                       self.training_critic_loss,
                       self.training_actor_loss,
                       self.training_critic_signal_to_noise,
                       self.testing_episode_length,
                       self.testing_total_episode_reward):
            metric.reset_states()
Example #6
def run_eval_epoch(
    current_step: int,
    test_fn: EvalStepFn,
    test_dataset: tf.data.Dataset,
    test_summary_writer: tf.summary.SummaryWriter,
    val_fn: Optional[EvalStepFn] = None,
    val_dataset: Optional[tf.data.Dataset] = None,
    val_summary_writer: Optional[tf.summary.SummaryWriter] = None,
    ood_fn: Optional[EvalStepFn] = None,
    ood_dataset: Optional[tf.data.Dataset] = None,
    ood_summary_writer: Optional[tf.summary.SummaryWriter] = None,
    hparams: Optional[Dict[str, Any]] = None,
):
    """Run one evaluation epoch on the test and optionally validation splits."""
    val_outputs_np = None
    if val_dataset:
        val_iterator = iter(val_dataset)
        val_outputs = val_fn(val_iterator)
        with val_summary_writer.as_default():  # pytype: disable=attribute-error
            if hparams:
                hp.hparams(hparams)
            for name, metric in val_outputs.items():
                tf.summary.scalar(name, metric, step=current_step)
        val_outputs_np = {k: v.numpy() for k, v in val_outputs.items()}
        logging.info('Validation metrics for step %d: %s', current_step,
                     val_outputs_np)
    if ood_dataset:
        ood_iterator = iter(ood_dataset)
        ood_outputs = ood_fn(ood_iterator)
        with ood_summary_writer.as_default():  # pytype: disable=attribute-error
            if hparams:
                hp.hparams(hparams)
            for name, metric in ood_outputs.items():
                tf.summary.scalar(name, metric, step=current_step)
        ood_outputs_np = {k: v.numpy() for k, v in ood_outputs.items()}
        logging.info('OOD metrics for step %d: %s', current_step,
                     ood_outputs_np)

    test_iterator = iter(test_dataset)
    test_outputs = test_fn(test_iterator)
    with test_summary_writer.as_default():
        if hparams:
            hp.hparams(hparams)
        for name, metric in test_outputs.items():
            tf.summary.scalar(name, metric, step=current_step)
    test_outputs_np = {k: v.numpy() for k, v in test_outputs.items()}
    return val_outputs_np, ood_outputs_np, test_outputs_np
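EvalStepFn is defined elsewhere in this module. A minimal sketch of a compatible function, assuming it consumes a dataset iterator for a fixed number of batches and returns a dict of scalar tensors (the model/metric wiring below is illustrative):

import tensorflow as tf

def make_eval_fn(model, num_batches):
    """Builds a hypothetical EvalStepFn-style callable for illustration."""
    def eval_fn(iterator):
        accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        for _ in range(num_batches):
            batch = next(iterator)
            logits = model(batch['features'], training=False)
            accuracy.update_state(batch['labels'], logits)
        return {'accuracy': accuracy.result()}
    return eval_fn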
Example #7
def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  # Initialize distribution strategy on flag-specified accelerator
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)
  use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

  train_batch_size = FLAGS.train_batch_size * FLAGS.num_cores
  eval_batch_size = FLAGS.eval_batch_size * FLAGS.num_cores

  # Reweighting loss for class imbalance
  class_reweight_mode = FLAGS.class_reweight_mode
  if class_reweight_mode == 'constant':
    class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  else:
    class_weights = None

  # As per the Kaggle challenge, we have split sizes:
  # train: 35,126
  # validation: 10,906 (currently unused)
  # test: 42,670
  ds_info = tfds.builder('diabetic_retinopathy_detection').info
  steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
  steps_per_validation_eval = (
      ds_info.splits['validation'].num_examples // eval_batch_size)
  steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size

  data_dir = FLAGS.data_dir

  dataset_train_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='train', data_dir=data_dir)
  dataset_train = dataset_train_builder.load(batch_size=train_batch_size)

  dataset_validation_builder = ub.datasets.get(
      'diabetic_retinopathy_detection',
      split='validation',
      data_dir=data_dir,
      is_training=not FLAGS.use_validation)
  validation_batch_size = (
      eval_batch_size if FLAGS.use_validation else train_batch_size)
  dataset_validation = dataset_validation_builder.load(
      batch_size=validation_batch_size)
  if FLAGS.use_validation:
    dataset_validation = strategy.experimental_distribute_dataset(
        dataset_validation)
  else:
    # Note that this will not create any mixed batches of train and validation
    # images.
    dataset_train = dataset_train.concatenate(dataset_validation)

  dataset_train = strategy.experimental_distribute_dataset(dataset_train)

  dataset_test_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='test', data_dir=data_dir)
  dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)
  dataset_test = strategy.experimental_distribute_dataset(dataset_test)

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-50 deterministic model.')
    model = ub.models.resnet50_deterministic(
        input_shape=utils.load_input_shape(dataset_train),
        num_classes=1)  # binary classification task
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    base_lr = FLAGS.base_learning_rate
    if FLAGS.lr_schedule == 'step':
      lr_decay_epochs = [
          (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
          for start_epoch_str in FLAGS.lr_decay_epochs
      ]
      lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
          steps_per_epoch,
          base_lr,
          decay_ratio=FLAGS.lr_decay_ratio,
          decay_epochs=lr_decay_epochs,
          warmup_epochs=FLAGS.lr_warmup_epochs)
    else:
      lr_schedule = ub.schedules.WarmUpPolynomialSchedule(
          base_lr,
          end_learning_rate=FLAGS.final_decay_factor * base_lr,
          decay_steps=(
              steps_per_epoch * (FLAGS.train_epochs - FLAGS.lr_warmup_epochs)),
          warmup_steps=steps_per_epoch * FLAGS.lr_warmup_epochs,
          decay_power=1.0)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)
    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=use_tpu,
        num_bins=FLAGS.num_bins,
        use_validation=FLAGS.use_validation)
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope()
      # so that optimizer slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  # Define CPU-only metrics outside the accelerator scope;
  # creating them under a TPU strategy would raise an error.
  if not use_tpu:
    metrics.update(
        utils.get_diabetic_retinopathy_cpu_metrics(
            use_validation=FLAGS.use_validation))
  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  # Initialize loss function based on class reweighting setting
  loss_fn = utils.get_diabetic_retinopathy_loss_fn(
      class_reweight_mode=class_reweight_mode, class_weights=class_weights)

  @tf.function
  def train_step(iterator):
    """Training step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']

      # For minibatch class reweighting, initialize per-batch loss function
      if class_reweight_mode == 'minibatch':
        batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(labels=labels)
      else:
        batch_loss_fn = loss_fn

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)

        negative_log_likelihood = tf.reduce_mean(
            batch_loss_fn(
                y_true=tf.expand_dims(labels, axis=-1),
                y_pred=logits,
                from_logits=True))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)

        # Scale the loss, since the TPUStrategy sum-reduces gradients across replicas.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      probs = tf.squeeze(tf.nn.sigmoid(logits))

      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auprc'].update_state(labels, probs)
      metrics['train/auroc'].update_state(labels, probs)

      if not use_tpu:
        metrics['train/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_split, num_steps):
    """Evaluation step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)

      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.binary_crossentropy(
              y_true=tf.expand_dims(labels, axis=-1),
              y_pred=logits,
              from_logits=True))
      probs = tf.squeeze(tf.nn.sigmoid(logits))

      metrics[dataset_split + '/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics[dataset_split + '/accuracy'].update_state(labels, probs)
      metrics[dataset_split + '/auprc'].update_state(labels, probs)
      metrics[dataset_split + '/auroc'].update_state(labels, probs)

      if not use_tpu:
        metrics[dataset_split + '/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(num_steps, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  start_time = time.time()

  train_iterator = iter(dataset_train)
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch + 1)
    train_step(train_iterator)

    current_step = (epoch + 1) * steps_per_epoch
    max_steps = steps_per_epoch * FLAGS.train_epochs
    time_elapsed = time.time() - start_time
    steps_per_sec = float(current_step) / time_elapsed
    eta_seconds = (max_steps - current_step) / steps_per_sec
    message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
               'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                   current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                   steps_per_sec, eta_seconds / 60, time_elapsed / 60))
    logging.info(message)

    if FLAGS.use_validation:
      validation_iterator = iter(dataset_validation)
      logging.info('Starting to run validation eval at epoch: %s', epoch + 1)
      test_step(validation_iterator, 'validation', steps_per_validation_eval)

    test_iterator = iter(dataset_test)
    logging.info('Starting to run test eval at epoch: %s', epoch + 1)
    test_start_time = time.time()
    test_step(test_iterator, 'test', steps_per_test_eval)
    ms_per_example = (time.time() - test_start_time) * 1e6 / eval_batch_size
    metrics['test/ms_per_example'].update_state(ms_per_example)

    # Log and write to summary the epoch metrics
    utils.log_epoch_metrics(metrics=metrics, use_tpu=use_tpu)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    # Metrics from Robustness Metrics (like ECE) will return a dict with a
    # single key/value, instead of a scalar.
    total_results = {
        k: (list(v.values())[0] if isinstance(v, dict) else v)
        for k, v in total_results.items()
    }
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

      # TODO(nband): debug checkpointing
      # Also save Keras model, due to checkpoint.save issue
      keras_model_name = os.path.join(FLAGS.output_dir,
                                      f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  keras_model_name = os.path.join(FLAGS.output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)
  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
    })
Example #8
def run(run_dir, hparams):
  with tf.summary.create_file_writer(run_dir).as_default():
    hp.hparams(hparams)  # record the values used in this trial
    accuracy = train_test_model(hparams)
    tf.summary.scalar(METRIC_ACCURACY, accuracy, step=1)
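run() follows the TensorBoard HParams tutorial pattern; a typical driver loop over discrete domains would look roughly like this (the specific HParam objects below are illustrative, assuming train_test_model and METRIC_ACCURACY are defined as in the tutorial):

from tensorboard.plugins.hparams import api as hp

HP_NUM_UNITS = hp.HParam('num_units', hp.Discrete([16, 32]))
HP_OPTIMIZER = hp.HParam('optimizer', hp.Discrete(['adam', 'sgd']))

session_num = 0
for num_units in HP_NUM_UNITS.domain.values:
    for optimizer in HP_OPTIMIZER.domain.values:
        hparams = {HP_NUM_UNITS: num_units, HP_OPTIMIZER: optimizer}
        run('logs/hparam_tuning/run-%d' % session_num, hparams)
        session_num += 1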
Example #9
def main(argv):
    del argv  # unused arg
    tf.random.set_seed(FLAGS.seed)

    per_core_batch_size = FLAGS.per_core_batch_size // FLAGS.ensemble_size
    batch_size = per_core_batch_size * FLAGS.num_cores
    steps_per_epoch = APPROX_IMAGENET_TRAIN_IMAGES // batch_size
    steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

    logging.info('Saving checkpoints at %s', FLAGS.output_dir)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    train_builder = ub.datasets.ImageNetDataset(
        split=tfds.Split.TRAIN, use_bfloat16=FLAGS.use_bfloat16)
    train_dataset = train_builder.load(batch_size=batch_size,
                                       strategy=strategy)
    test_builder = ub.datasets.ImageNetDataset(split=tfds.Split.TEST,
                                               use_bfloat16=FLAGS.use_bfloat16)
    clean_test_dataset = test_builder.load(batch_size=batch_size,
                                           strategy=strategy)
    test_datasets = {'clean': clean_test_dataset}
    if FLAGS.corruptions_interval > 0:
        corruption_types, max_intensity = utils.load_corrupted_test_info()
        for name in corruption_types:
            for intensity in range(1, max_intensity + 1):
                dataset_name = '{0}_{1}'.format(name, intensity)
                dataset = utils.load_corrupted_test_dataset(
                    batch_size=batch_size,
                    corruption_name=name,
                    corruption_intensity=intensity,
                    use_bfloat16=FLAGS.use_bfloat16)
                test_datasets[dataset_name] = (
                    strategy.experimental_distribute_dataset(dataset))

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building Keras ResNet-50 model')
        model = ub.models.resnet50_het_rank1(
            input_shape=(224, 224, 3),
            num_classes=NUM_CLASSES,
            alpha_initializer=FLAGS.alpha_initializer,
            gamma_initializer=FLAGS.gamma_initializer,
            alpha_regularizer=FLAGS.alpha_regularizer,
            gamma_regularizer=FLAGS.gamma_regularizer,
            use_additive_perturbation=FLAGS.use_additive_perturbation,
            ensemble_size=FLAGS.ensemble_size,
            random_sign_init=FLAGS.random_sign_init,
            dropout_rate=FLAGS.dropout_rate,
            prior_stddev=FLAGS.prior_stddev,
            use_tpu=not FLAGS.use_gpu,
            use_ensemble_bn=FLAGS.use_ensemble_bn,
            num_factors=FLAGS.num_factors,
            temperature=FLAGS.temperature,
            num_mc_samples=FLAGS.num_mc_samples)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Scale learning rate and decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * batch_size / 256
        decay_epochs = [
            (FLAGS.train_epochs * 30) // 90,
            (FLAGS.train_epochs * 60) // 90,
            (FLAGS.train_epochs * 80) // 90,
        ]
        learning_rate = ub.schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch=steps_per_epoch,
            base_learning_rate=base_lr,
            decay_ratio=0.1,
            decay_epochs=decay_epochs,
            warmup_epochs=5)
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                            momentum=1.0 -
                                            FLAGS.one_minus_momentum,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/kl':
            tf.keras.metrics.Mean(),
            'train/kl_scale':
            tf.keras.metrics.Mean(),
            'train/elbo':
            tf.keras.metrics.Mean(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'train/diversity':
            rm.metrics.AveragePairwiseDiversity(),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/kl':
            tf.keras.metrics.Mean(),
            'test/elbo':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/diversity':
            rm.metrics.AveragePairwiseDiversity(),
            'test/member_accuracy_mean':
            (tf.keras.metrics.SparseCategoricalAccuracy()),
            'test/member_ece_mean':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        }
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, max_intensity + 1):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/kl_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/elbo_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        rm.metrics.ExpectedCalibrationError(
                            num_bins=FLAGS.num_bins))

        if FLAGS.ensemble_size > 1:
            for i in range(FLAGS.ensemble_size):
                metrics['test/nll_member_{}'.format(
                    i)] = tf.keras.metrics.Mean()
                metrics['test/accuracy_member_{}'.format(i)] = (
                    tf.keras.metrics.SparseCategoricalAccuracy())

        logging.info('Finished building Keras ResNet-50 model')

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    def compute_l2_loss(model):
        filtered_variables = []
        for var in model.trainable_variables:
            # Apply l2 to the kernel, batch-norm, and bias terms. This excludes
            # only the fast-weight approximate posterior/prior parameters, but
            # be careful with their naming scheme.
            if ('kernel' in var.name or 'batch_norm' in var.name
                    or 'bias' in var.name):
                filtered_variables.append(tf.reshape(var, (-1, )))
        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
        return l2_loss

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            if FLAGS.ensemble_size > 1:
                images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
                labels = tf.tile(labels, [FLAGS.ensemble_size])

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)

                probs = tf.nn.softmax(logits)
                if FLAGS.ensemble_size > 1:
                    per_probs = tf.reshape(
                        probs,
                        tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]],
                                  0))
                    metrics['train/diversity'].add_batch(per_probs)

                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))
                l2_loss = compute_l2_loss(model)
                kl = sum(model.losses) / APPROX_IMAGENET_TRAIN_IMAGES
                kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
                kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs
                kl_scale = tf.minimum(1., kl_scale)
                kl_loss = kl_scale * kl

                # Scale the loss, since the TPUStrategy sum-reduces gradients across replicas.
                loss = negative_log_likelihood + l2_loss + kl_loss
                scaled_loss = loss / strategy.num_replicas_in_sync
                elbo = -(negative_log_likelihood + l2_loss + kl)

            grads = tape.gradient(scaled_loss, model.trainable_variables)

            # Separate learning rate implementation.
            if FLAGS.fast_weight_lr_multiplier != 1.0:
                grads_and_vars = []
                for grad, var in zip(grads, model.trainable_variables):
                    # Apply a different learning rate to the fast weights. This
                    # excludes BN and slow weights, but be careful with the
                    # naming scheme.
                    if ('batch_norm' not in var.name
                            and 'kernel' not in var.name):
                        grads_and_vars.append(
                            (grad * FLAGS.fast_weight_lr_multiplier, var))
                    else:
                        grads_and_vars.append((grad, var))
                optimizer.apply_gradients(grads_and_vars)
            else:
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))

            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/kl'].update_state(kl)
            metrics['train/kl_scale'].update_state(kl_scale)
            metrics['train/elbo'].update_state(elbo)
            metrics['train/loss'].update_state(loss)
            metrics['train/accuracy'].update_state(labels, logits)
            metrics['train/ece'].add_batch(probs, label=labels)

        for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            if FLAGS.ensemble_size > 1:
                images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
            logits = tf.reshape([
                model(images, training=False)
                for _ in range(FLAGS.num_eval_samples)
            ], [FLAGS.num_eval_samples, FLAGS.ensemble_size, -1, NUM_CLASSES])
            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)
            all_probs = tf.nn.softmax(logits)
            probs = tf.math.reduce_mean(all_probs, axis=[0, 1])  # marginalize

            # Negative log marginal likelihood computed in a numerically-stable way.
            labels_broadcasted = tf.broadcast_to(
                labels,
                [FLAGS.num_eval_samples, FLAGS.ensemble_size, labels.shape[0]])
            log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
                labels_broadcasted, logits, from_logits=True)
            negative_log_likelihood = tf.reduce_mean(
                -tf.reduce_logsumexp(log_likelihoods, axis=[0, 1]) +
                tf.math.log(float(FLAGS.num_eval_samples *
                                  FLAGS.ensemble_size)))

            l2_loss = compute_l2_loss(model)
            kl = sum(model.losses) / IMAGENET_VALIDATION_IMAGES
            elbo = -(negative_log_likelihood + l2_loss + kl)

            if dataset_name == 'clean':
                if FLAGS.ensemble_size > 1:
                    per_probs = tf.reduce_mean(all_probs,
                                               axis=0)  # marginalize samples
                    metrics['test/diversity'].add_batch(per_probs)
                    for i in range(FLAGS.ensemble_size):
                        member_probs = per_probs[i]
                        member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                            labels, member_probs)
                        metrics['test/nll_member_{}'.format(i)].update_state(
                            member_loss)
                        metrics['test/accuracy_member_{}'.format(
                            i)].update_state(labels, member_probs)
                        metrics['test/member_accuracy_mean'].update_state(
                            labels, member_probs)
                        metrics['test/member_ece_mean'].add_batch(member_probs,
                                                                  label=labels)

                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/kl'].update_state(kl)
                metrics['test/elbo'].update_state(elbo)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].add_batch(probs, label=labels)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/kl_{}'.format(
                    dataset_name)].update_state(kl)
                corrupt_metrics['test/elbo_{}'.format(
                    dataset_name)].update_state(elbo)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch(
                    probs, label=labels)

        for _ in tf.range(tf.cast(steps_per_eval, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    train_iterator = iter(train_dataset)
    start_time = time.time()

    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        train_step(train_iterator)

        current_step = (epoch + 1) * steps_per_epoch
        max_steps = steps_per_epoch * FLAGS.train_epochs
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                       steps_per_sec, eta_seconds / 60, time_elapsed / 60))
        logging.info(message)

        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            logging.info('Testing on dataset %s', dataset_name)
            test_iterator = iter(test_dataset)
            logging.info('Starting to run eval at epoch: %s', epoch)
            test_step(test_iterator, dataset_name)
            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types, max_intensity,
                FLAGS.alexnet_errors_path)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)

        if FLAGS.ensemble_size > 1:  # member metrics only exist when ensemble_size > 1
            for i in range(FLAGS.ensemble_size):
                logging.info(
                    'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
                    metrics['test/nll_member_{}'.format(i)].result(),
                    metrics['test/accuracy_member_{}'.format(i)].result() * 100)

        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        # Results from Robustness Metrics themselves return a dict, so flatten them.
        total_results = utils.flatten_dictionary(total_results)
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'l2': FLAGS.l2,
            'fast_weight_lr_multiplier': FLAGS.fast_weight_lr_multiplier,
            'num_eval_samples': FLAGS.num_eval_samples,
        })
Example #10
def run(run_dir, hparams):
    with tf.summary.create_file_writer(run_dir).as_default():
        hp.hparams(hparams)
        accuracy_score = train_test_model(hparams)
        tf.summary.scalar(METRIC_ACCURACY, accuracy_score, step=1)
Example #11
    def write_hparams(self, hparams: Mapping[str, Any]):
        with self._summary_writer.as_default():
            hparams_api.hparams(dict(utils.flatten_dict(hparams)))
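hparams_api.hparams only accepts flat scalar values, which is why the nested mapping is flattened first. A hypothetical call, with writer being an instance of the class above (the exact key delimiter depends on utils.flatten_dict):

writer.write_hparams({'optimizer': {'name': 'adam', 'lr': 1e-3}, 'seed': 0})
# logged with flattened keys such as 'optimizer.name' / 'optimizer.lr',
# assuming utils.flatten_dict joins nested keys with a delimiter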
Example #12
def random_hparam_search(cfg, data, callbacks, log_dir):
    '''
    Conduct a random hyperparameter search over the ranges given for the hyperparameters in config.yml and log results
    in TensorBoard. The model is trained REPEATS times for each of COMBINATIONS random combinations of hyperparameters.
    :param cfg: Project config
    :param data: Dict containing the partitioned datasets
    :param callbacks: List of callbacks for Keras model (excluding TensorBoard)
    :param log_dir: Base directory in which to store logs
    :return: (Last model trained, resultant test set metrics, test data generator)
    '''

    # Define HParam objects for each hyperparameter we wish to tune.
    hp_ranges = cfg['HP_SEARCH']['RANGES']
    HPARAMS = []
    HPARAMS.append(hp.HParam('KERNEL_SIZE', hp.Discrete(hp_ranges['KERNEL_SIZE'])))
    HPARAMS.append(hp.HParam('MAXPOOL_SIZE', hp.Discrete(hp_ranges['MAXPOOL_SIZE'])))
    HPARAMS.append(hp.HParam('INIT_FILTERS', hp.Discrete(hp_ranges['INIT_FILTERS'])))
    HPARAMS.append(hp.HParam('FILTER_EXP_BASE', hp.IntInterval(hp_ranges['FILTER_EXP_BASE'][0], hp_ranges['FILTER_EXP_BASE'][1])))
    HPARAMS.append(hp.HParam('NODES_DENSE0', hp.Discrete(hp_ranges['NODES_DENSE0'])))
    HPARAMS.append(hp.HParam('CONV_BLOCKS', hp.IntInterval(hp_ranges['CONV_BLOCKS'][0], hp_ranges['CONV_BLOCKS'][1])))
    HPARAMS.append(hp.HParam('DROPOUT', hp.Discrete(hp_ranges['DROPOUT'])))
    HPARAMS.append(hp.HParam('LR', hp.RealInterval(hp_ranges['LR'][0], hp_ranges['LR'][1])))
    HPARAMS.append(hp.HParam('OPTIMIZER', hp.Discrete(hp_ranges['OPTIMIZER'])))
    HPARAMS.append(hp.HParam('L2_LAMBDA', hp.Discrete(hp_ranges['L2_LAMBDA'])))
    HPARAMS.append(hp.HParam('BATCH_SIZE', hp.Discrete(hp_ranges['BATCH_SIZE'])))
    HPARAMS.append(hp.HParam('IMB_STRATEGY', hp.Discrete(hp_ranges['IMB_STRATEGY'])))

    # Define test set metrics that we wish to log to TensorBoard for each training run
    HP_METRICS = [hp.Metric(metric, display_name='Test ' + metric) for metric in cfg['HP_SEARCH']['METRICS']]

    # Configure TensorBoard to log the results
    with tf.summary.create_file_writer(log_dir).as_default():
        hp.hparams_config(hparams=HPARAMS, metrics=HP_METRICS)

    # Complete a number of training runs at different hparam values and log the results.
    repeats_per_combo = cfg['HP_SEARCH']['REPEATS']   # Number of times to train the model per combination of hparams
    num_combos = cfg['HP_SEARCH']['COMBINATIONS']     # Number of random combinations of hparams to attempt
    num_sessions = num_combos * repeats_per_combo       # Total number of runs in this experiment
    model_type = 'DCNN_BINARY' if cfg['TRAIN']['CLASS_MODE'] == 'binary' else 'DCNN_MULTICLASS'
    trial_id = 0
    for group_idx in range(num_combos):
        rand = random.Random()
        # Sample a value from each hyperparameter's domain (HParam object -> value);
        # use a new name rather than rebinding the HPARAMS list to a dict.
        hparam_samples = {h: h.domain.sample_uniform(rand) for h in HPARAMS}
        hparams = {h.name: hparam_samples[h] for h in hparam_samples}  # To pass to model definition
        for repeat_idx in range(repeats_per_combo):
            trial_id += 1
            print("Running training session %d/%d" % (trial_id, num_sessions))
            print("Hparam values: ", {h.name: HPARAMS[h] for h in HPARAMS})
            trial_logdir = os.path.join(log_dir, str(trial_id))     # Need specific logdir for each trial
            callbacks_hp = callbacks + [TensorBoard(log_dir=trial_logdir, profile_batch=0, write_graph=False)]

            # Set values of hyperparameters for this run in config file.
            for h in hparams:
                if h in ['LR', 'L2_LAMBDA']:
                    val = 10 ** hparams[h]      # These hyperparameters are sampled on the log scale.
                else:
                    val = hparams[h]
                cfg['NN'][model_type][h] = val

            # Set some hyperparameters that are not specified in model definition.
            cfg['TRAIN']['BATCH_SIZE'] = hparams['BATCH_SIZE']
            cfg['TRAIN']['IMB_STRATEGY'] = hparams['IMB_STRATEGY']

            # Run a training session and log the performance metrics on the test set to HParams dashboard in TensorBoard
            with tf.summary.create_file_writer(trial_logdir).as_default():
                hp.hparams(hparam_samples, trial_id=str(trial_id))
                model, test_metrics, test_generator = train_model(cfg, data, callbacks_hp, verbose=0)
                for metric in HP_METRICS:
                    if metric._tag in test_metrics:
                        tf.summary.scalar(metric._tag, test_metrics[metric._tag], step=1)   # Log test metric
    return model, test_metrics, test_generator
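The search relies on Domain.sample_uniform from the hparams API, the same call used in the loop above. A quick standalone illustration of how the domains sample (values below are arbitrary):

import random
from tensorboard.plugins.hparams import api as hp

rng = random.Random(42)
# LR is stored on a log10 scale above, hence the 10 ** val conversion.
lr_exponent = hp.RealInterval(-4.0, -2.0).sample_uniform(rng)
print(10 ** lr_exponent)                                 # learning rate
print(hp.Discrete([32, 64, 128]).sample_uniform(rng))    # e.g. BATCH_SIZE
print(hp.IntInterval(3, 5).sample_uniform(rng))          # e.g. CONV_BLOCKS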
Example #13
def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  per_core_batch_size = FLAGS.per_core_batch_size // FLAGS.ensemble_size
  batch_size = per_core_batch_size * FLAGS.num_cores

  data_dir = FLAGS.data_dir
  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

  train_builder = ub.datasets.get(
      FLAGS.dataset,
      data_dir=data_dir,
      download_data=FLAGS.download_data,
      split=tfds.Split.TRAIN,
      validation_percent=1. - FLAGS.train_proportion)
  train_dataset = train_builder.load(batch_size=batch_size)
  train_dataset = strategy.experimental_distribute_dataset(train_dataset)

  validation_dataset = None
  steps_per_validation = 0
  if FLAGS.train_proportion < 1.0:
    validation_builder = ub.datasets.get(
        FLAGS.dataset,
        data_dir=data_dir,
        split=tfds.Split.VALIDATION,
        validation_percent=1. - FLAGS.train_proportion)
    validation_dataset = validation_builder.load(batch_size=batch_size)
    validation_dataset = strategy.experimental_distribute_dataset(
        validation_dataset)
    steps_per_validation = validation_builder.num_examples // batch_size

  clean_test_builder = ub.datasets.get(
      FLAGS.dataset,
      data_dir=data_dir,
      split=tfds.Split.TEST)
  clean_test_dataset = clean_test_builder.load(batch_size=batch_size)
  test_datasets = {
      'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
  }
  steps_per_epoch = train_builder.num_examples // batch_size
  steps_per_eval = clean_test_builder.num_examples // batch_size
  num_classes = 100 if FLAGS.dataset == 'cifar100' else 10
  if FLAGS.corruptions_interval > 0:
    if FLAGS.dataset == 'cifar100':
      data_dir = FLAGS.cifar100_c_path
    corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
    for corruption_type in corruption_types:
      for severity in range(1, 6):
        dataset = ub.datasets.get(
            f'{FLAGS.dataset}_corrupted',
            corruption_type=corruption_type,
            data_dir=data_dir,
            severity=severity,
            split=tfds.Split.TEST).load(batch_size=batch_size)
        test_datasets[f'{corruption_type}_{severity}'] = (
            strategy.experimental_distribute_dataset(dataset))

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras model')
    model = ub.models.wide_resnet_batchensemble(
        input_shape=(32, 32, 3),
        depth=28,
        width_multiplier=10,
        num_classes=num_classes,
        ensemble_size=FLAGS.ensemble_size,
        random_sign_init=FLAGS.random_sign_init,
        l2=FLAGS.l2)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate * batch_size / 128
    lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                       for start_epoch_str in FLAGS.lr_decay_epochs]
    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                        momentum=1.0 - FLAGS.one_minus_momentum,
                                        nesterov=True)
    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss': tf.keras.metrics.Mean(),
        'train/ece': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
    }
    eval_dataset_splits = ['test']
    if validation_dataset:
      metrics.update({
          'validation/negative_log_likelihood': tf.keras.metrics.Mean(),
          'validation/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
          'validation/ece': rm.metrics.ExpectedCalibrationError(
              num_bins=FLAGS.num_bins),
      })
      eval_dataset_splits += ['validation']
    for i in range(FLAGS.ensemble_size):
      for dataset_split in eval_dataset_splits:
        metrics[f'{dataset_split}/nll_member_{i}'] = tf.keras.metrics.Mean()
        metrics[f'{dataset_split}/accuracy_member_{i}'] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
    if FLAGS.corruptions_interval > 0:
      corrupt_metrics = {}
      for intensity in range(1, 6):
        for corruption in corruption_types:
          dataset_name = '{0}_{1}'.format(corruption, intensity)
          corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
              rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  @tf.function
  def train_step(iterator):
    """Training StepFn."""
    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']
      images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
      labels = tf.tile(labels, [FLAGS.ensemble_size])

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels,
                                                            logits,
                                                            from_logits=True))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + l2_loss
        # Scale the loss, since the TPUStrategy sum-reduces gradients across replicas.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)

      # Separate learning rate implementation.
      if FLAGS.fast_weight_lr_multiplier != 1.0:
        grads_and_vars = []
        for grad, var in zip(grads, model.trainable_variables):
          # Apply a different learning rate to the fast weight approximate
          # posterior/prior parameters. This excludes BN and slow weights,
          # but be careful with the naming scheme.
          if ('batch_norm' not in var.name and 'kernel' not in var.name):
            grads_and_vars.append((grad * FLAGS.fast_weight_lr_multiplier, var))
          else:
            grads_and_vars.append((grad, var))
        optimizer.apply_gradients(grads_and_vars)
      else:
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

      probs = tf.nn.softmax(logits)
      metrics['train/ece'].add_batch(probs, label=labels)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, logits)

    for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_split, dataset_name, num_steps):
    """Evaluation StepFn."""
    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']
      images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
      logits = model(images, training=False)
      probs = tf.nn.softmax(logits)
      per_probs = tf.split(probs,
                           num_or_size_splits=FLAGS.ensemble_size,
                           axis=0)
      for i in range(FLAGS.ensemble_size):
        member_probs = per_probs[i]
        member_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, member_probs)
        metrics[f'{dataset_split}/nll_member_{i}'].update_state(member_loss)
        metrics[f'{dataset_split}/accuracy_member_{i}'].update_state(
            labels, member_probs)

      probs = tf.reduce_mean(per_probs, axis=0)
      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
      if dataset_name == 'clean':
        metrics[f'{dataset_split}/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics[f'{dataset_split}/accuracy'].update_state(labels, probs)
        metrics[f'{dataset_split}/ece'].add_batch(probs, label=labels)
      else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch(
            probs, label=labels)

    for _ in tf.range(tf.cast(num_steps, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  train_iterator = iter(train_dataset)
  start_time = time.time()
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    train_step(train_iterator)

    current_step = (epoch + 1) * steps_per_epoch
    max_steps = steps_per_epoch * FLAGS.train_epochs
    time_elapsed = time.time() - start_time
    steps_per_sec = float(current_step) / time_elapsed
    eta_seconds = (max_steps - current_step) / steps_per_sec
    message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
               'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                   current_step / max_steps,
                   epoch + 1,
                   FLAGS.train_epochs,
                   steps_per_sec,
                   eta_seconds / 60,
                   time_elapsed / 60))
    logging.info(message)

    if validation_dataset:
      validation_iterator = iter(validation_dataset)
      test_step(
          validation_iterator, 'validation', 'clean', steps_per_validation)
    datasets_to_evaluate = {'clean': test_datasets['clean']}
    if (FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0):
      datasets_to_evaluate = test_datasets
    for dataset_name, test_dataset in datasets_to_evaluate.items():
      test_iterator = iter(test_dataset)
      logging.info('Testing on dataset %s', dataset_name)
      logging.info('Starting to run eval at epoch: %s', epoch)
      test_start_time = time.time()
      test_step(test_iterator, 'test', dataset_name, steps_per_eval)
      ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
      metrics['test/ms_per_example'].update_state(ms_per_example)

      logging.info('Done with testing on %s', dataset_name)

    corrupt_results = {}
    if (FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0):
      corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                        corruption_types)

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    for i in range(FLAGS.ensemble_size):
      logging.info('Member %d Test Loss: %.4f, Accuracy: %.2f%%',
                   i, metrics['test/nll_member_{}'.format(i)].result(),
                   metrics['test/accuracy_member_{}'.format(i)].result() * 100)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    # Metrics from Robustness Metrics (like ECE) will return a dict with a
    # single key/value, instead of a scalar.
    total_results = {
        k: (list(v.values())[0] if isinstance(v, dict) else v)
        for k, v in total_results.items()
    }
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)
  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
        'random_sign_init': FLAGS.random_sign_init,
        'fast_weight_lr_multiplier': FLAGS.fast_weight_lr_multiplier,
    })
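The example above ends by writing the trial's hyperparameter values into the run's summary directory. A minimal, self-contained sketch of that pattern (the path, names, and values below are illustrative assumptions, not taken from the example): write hp.hparams and a final metric to the same run directory so the HParams dashboard can join them.

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

def log_trial(logdir, hparams, final_accuracy):
    """Record one trial's hyperparameters and final metric for TensorBoard."""
    with tf.summary.create_file_writer(logdir).as_default():
        hp.hparams(hparams)  # registers this run in the HParams dashboard
        tf.summary.scalar('final_accuracy', final_accuracy, step=1)

# Hypothetical usage:
log_trial('logs/trial_0', {'base_learning_rate': 0.1, 'l2': 2e-4}, 0.91)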
Example #14
def main(hParams, n_run):
    nsteps = hParams['N_STEPS']
    nenv = hParams[HP_N_ENV]
    n_epochs = hParams['N_EPOCHS']
    total_timesteps = int(n_epochs * nsteps * nenv)
    nbatch = nenv * nsteps

    update_int = 1
    save_int = 5
    test_int = 10

    gamma = 0.99
    lr = hParams[HP_LEARNING_RATE]
    vf_coef = hParams[HP_VF_COEF]
    ent_coef = hParams[HP_ENT_COEF]
    save_dir = 'lr' + str(lr) + 'vc' + str(vf_coef) + 'ec' + str(
        ent_coef) + 'env' + str(nenv)

    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = 'logs/cart_hparam_tuning/run-' + str(n_run)
    summ_writer = tf.summary.create_file_writer(log_dir)

    envfn = lambda: gym.make('CartPole-v1')
    env = SubprocVecEnv([envfn] * nenv)
    state_size = env.observation_space.shape
    num_actions = env.action_space.n
    pgnet = SimplePGNet(state_size,
                        num_actions,
                        learning_rate=lr,
                        vf_coef=vf_coef,
                        ent_coef=ent_coef)

    runner = SonicEnvRunner(env, pgnet, nsteps, gamma)

    print("Total updates to run: ", total_timesteps // nbatch)
    for update in range(1, total_timesteps // nbatch + 1):

        print("\nUpdate #{}:".format(update))
        states_mb, actions_mb, values_mb, rewards_mb, next_dones_mb = runner.run()

        tf.summary.trace_on(graph=True)
        policy_loss, entropy_loss, vf_loss, loss = pgnet.fit_gradient(
            states_mb, actions_mb, rewards_mb, values_mb)
        if update == 1:
            with summ_writer.as_default():
                tf.summary.trace_export(name="grad_trace", step=0)

        WeightWriter(summ_writer, pgnet, (Conv2D, Dense), global_step=update)

        with summ_writer.as_default():
            tf.summary.scalar("PolicyLoss", policy_loss, step=update)
            tf.summary.scalar("EntropyLoss", entropy_loss, step=update)
            tf.summary.scalar("ValueFunctionLoss", vf_loss, step=update)
            tf.summary.scalar("Loss", loss, step=update)

        if update % update_int == 0:
            print("PolicyLoss:", policy_loss)
            print("EntropyLoss: ", entropy_loss)
            print("ValueFunctionLoss: ", vf_loss)
            print("Loss: ", loss)

        if update % save_int == 0:
            pgnet.model.save_weights('cart_hparams_tuning_models/' + save_dir +
                                     '/my_checkpoint')
            print("Model Saved")

        if update % test_int == 0:
            test_rewards = TestRewardWriter(summ_writer,
                                            envfn,
                                            pgnet,
                                            20,
                                            global_step=update)
            print("Test Rewards: ", test_rewards)

    with summ_writer.as_default():
        hp.hparams(hParams)

    env.close()
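The HP_LEARNING_RATE, HP_VF_COEF, HP_ENT_COEF, and HP_N_ENV keys used above are defined outside this snippet. A plausible sketch of how they would be declared with the hparams plugin; the domains and values below are illustrative assumptions.

from tensorboard.plugins.hparams import api as hp

# Hypothetical declarations; the real domains are not shown in the snippet.
HP_LEARNING_RATE = hp.HParam('learning_rate', hp.Discrete([1e-4, 5e-4, 1e-3]))
HP_VF_COEF = hp.HParam('vf_coef', hp.RealInterval(0.1, 1.0))
HP_ENT_COEF = hp.HParam('ent_coef', hp.RealInterval(0.0, 0.1))
HP_N_ENV = hp.HParam('n_env', hp.Discrete([4, 8, 16]))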
Example #15
    def run(self):
        run_dir = os.path.join(
            self.ph['log_dir'],
            'run_{}'.format(str(self.ph['curr_run_number'])))

        # set up tensorboard summary writer scope
        with tf.summary.create_file_writer(run_dir).as_default():
            hp.hparams(self.ph.get_hparams())

            # data loading
            train_client_data, test_dataset = dta.get_client_data(
                dataset_name=self.ph['dataset'],
                mask_by=self.ph['mask_by'],
                mask_ratios={
                    'unsupervised': self.ph['unsupervised_mask_ratio'],
                    'supervised': self.ph['supervised_mask_ratio']
                },
                sample_client_data=self.ph['sample_client_data'])
            test_dataset = self.dataloader.preprocess_dataset(test_dataset)

            sample_batch = self.dataloader.get_sample_batch(train_client_data)
            model_fn = functools.partial(
                self.keras_model_fn.create_tff_model_fn, sample_batch)

            # federated training
            iterative_process = tff.learning.build_federated_averaging_process(
                model_fn)
            state = iterative_process.initialize()

            for round_num in range(self.num_rounds):
                sample_clients = np.random.choice(
                    train_client_data.client_ids,
                    size=min(self.num_clients_per_round,
                             len(train_client_data.client_ids)),
                    replace=False)
                federated_train_data = self.dataloader.make_federated_data(
                    train_client_data, sample_clients)
                state, metrics = iterative_process.next(
                    state, federated_train_data)

                if not round_num % self.log_every:
                    print('\nround {:2d}, metrics={}'.format(
                        round_num, metrics))
                    tf.summary.scalar('train_accuracy',
                                      metrics[0],
                                      step=round_num)
                    tf.summary.scalar('train_loss', metrics[1], step=round_num)

                    test_loss, test_accuracy = self.evaluate_central(
                        test_dataset, state)
                    tf.summary.scalar('test_accuracy',
                                      test_accuracy,
                                      step=round_num)
                    tf.summary.scalar('test_loss', test_loss, step=round_num)

            print('\nround {:2d}, metrics={}'.format(round_num, metrics))
            tf.summary.scalar('train_accuracy', metrics[0], step=round_num)
            tf.summary.scalar('train_loss', metrics[1], step=round_num)

            test_loss, test_accuracy = self.evaluate_central(
                test_dataset, state)
            tf.summary.scalar('test_accuracy', test_accuracy, step=round_num)
            tf.summary.scalar('test_loss', test_loss, step=round_num)

        model_fp = os.path.join(run_dir, self.ph['model_fp'])
        self.keras_model_fn.save_model_weights(model_fp, state)
        return
Example #16
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    per_core_batch_size = FLAGS.per_core_batch_size // FLAGS.ensemble_size
    batch_size = per_core_batch_size * FLAGS.num_cores
    check_bool = FLAGS.train_proportion > 0 and FLAGS.train_proportion <= 1
    assert check_bool, 'Proportion of train set has to meet 0 < prop <= 1.'

    drop_remainder_validation = True
    if not FLAGS.use_gpu:
        # This has to be True for TPU training; otherwise the batch size of
        # images in the validation set can't be determined at TPU compile time.
        assert drop_remainder_validation, 'drop_remainder must be True in TPU mode.'

    validation_percent = 1 - FLAGS.train_proportion
    train_dataset = ub.datasets.get(
        FLAGS.dataset,
        split=tfds.Split.TRAIN,
        validation_percent=validation_percent).load(batch_size=batch_size)
    validation_dataset = ub.datasets.get(
        FLAGS.dataset,
        split=tfds.Split.VALIDATION,
        validation_percent=validation_percent,
        drop_remainder=drop_remainder_validation).load(batch_size=batch_size)
    validation_dataset = validation_dataset.repeat()
    clean_test_dataset = ub.datasets.get(
        FLAGS.dataset, split=tfds.Split.TEST).load(batch_size=batch_size)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    validation_dataset = strategy.experimental_distribute_dataset(
        validation_dataset)
    test_datasets = {
        'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
    }
    if FLAGS.corruptions_interval > 0:
        extra_kwargs = {}
        if FLAGS.dataset == 'cifar100':
            extra_kwargs['data_dir'] = FLAGS.cifar100_c_path
        corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
        for corruption_type in corruption_types:
            for severity in range(1, 6):
                dataset = ub.datasets.get(
                    f'{FLAGS.dataset}_corrupted',
                    corruption_type=corruption_type,
                    severity=severity,
                    split=tfds.Split.TEST,
                    **extra_kwargs).load(batch_size=batch_size)
                test_datasets[f'{corruption_type}_{severity}'] = (
                    strategy.experimental_distribute_dataset(dataset))

    ds_info = tfds.builder(FLAGS.dataset).info
    train_sample_size = ds_info.splits[
        'train'].num_examples * FLAGS.train_proportion
    steps_per_epoch = int(train_sample_size / batch_size)
    train_sample_size = int(train_sample_size)

    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    logging.info('Building Keras model.')
    depth = 28
    width = 10

    dict_ranges = {'min': FLAGS.min_l2_range, 'max': FLAGS.max_l2_range}
    ranges = [dict_ranges for _ in range(6)]  # 6 independent l2 parameters
    model_config = {
        'key_to_index': {
            'input_conv_l2_kernel': 0,
            'group_l2_kernel': 1,
            'group_1_l2_kernel': 2,
            'group_2_l2_kernel': 3,
            'dense_l2_kernel': 4,
            'dense_l2_bias': 5,
        },
        'ranges': ranges,
        'test': None
    }
    lambdas_config = LambdaConfig(model_config['ranges'],
                                  model_config['key_to_index'])

    if FLAGS.e_body_hidden_units > 0:
        e_body_arch = '({},)'.format(FLAGS.e_body_hidden_units)
    else:
        e_body_arch = '()'
    e_shared_arch = '()'
    e_activation = 'tanh'
    filters_resnet = [16]
    for i in range(0, 3):  # 3 groups of blocks
        filters_resnet.extend([16 * width * 2**i] *
                              9)  # 9 layers in each block
    # e_head dim for conv2d is just the number of filters (only
    # kernel) and twice num of classes for the last dense layer (kernel + bias)
    e_head_dims = [x for x in filters_resnet] + [2 * num_classes]

    with strategy.scope():
        e_models = e_factory(
            lambdas_config.input_shape,
            e_head_dims=e_head_dims,
            e_body_arch=eval(e_body_arch),  # pylint: disable=eval-used
            e_shared_arch=eval(e_shared_arch),  # pylint: disable=eval-used
            activation=e_activation,
            use_bias=FLAGS.e_model_use_bias,
            e_head_init=FLAGS.init_emodels_stddev)

        model = wide_resnet_hyperbatchensemble(
            input_shape=ds_info.features['image'].shape,
            depth=depth,
            width_multiplier=width,
            num_classes=num_classes,
            ensemble_size=FLAGS.ensemble_size,
            random_sign_init=FLAGS.random_sign_init,
            config=lambdas_config,
            e_models=e_models,
            l2_batchnorm_layer=FLAGS.l2_batchnorm,
            regularize_fast_weights=FLAGS.regularize_fast_weights,
            fast_weights_eq_contraint=FLAGS.fast_weights_eq_contraint,
            version=2)

        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # hyper-batchensemble build complete -------------------------

        # Initialize Lambda distributions for tuning
        lambdas_mean = tf.reduce_mean(
            log_uniform_mean([lambdas_config.log_min, lambdas_config.log_max]))
        lambdas0 = tf.random.normal((FLAGS.ensemble_size, lambdas_config.dim),
                                    lambdas_mean,
                                    0.1 * FLAGS.ens_init_delta_bounds)
        lower0 = lambdas0 - tf.constant(FLAGS.ens_init_delta_bounds)
        lower0 = tf.maximum(lower0, 1e-8)
        upper0 = lambdas0 + tf.constant(FLAGS.ens_init_delta_bounds)

        log_lower = tf.Variable(tf.math.log(lower0))
        log_upper = tf.Variable(tf.math.log(upper0))
        lambda_parameters = [log_lower, log_upper]  # these variables are tuned
        clip_lambda_parameters(lambda_parameters, lambdas_config)

        # Optimizer settings to train model weights
        # Linearly scale learning rate and the decay epochs by vanilla settings.
        # Note: Here, we don't divide the epochs by 200 as for the other uncertainty
        # baselines.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_decay_epochs = [int(l) for l in FLAGS.lr_decay_epochs]

        lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=1.0 -
                                            FLAGS.one_minus_momentum,
                                            nesterov=True)

        # tuner used for optimizing lambda_parameters
        tuner = tf.keras.optimizers.Adam(FLAGS.lr_tuning)

        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'train/disagreement':
            tf.keras.metrics.Mean(),
            'train/average_kl':
            tf.keras.metrics.Mean(),
            'train/cosine_similarity':
            tf.keras.metrics.Mean(),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/gibbs_nll':
            tf.keras.metrics.Mean(),
            'test/gibbs_accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/disagreement':
            tf.keras.metrics.Mean(),
            'test/average_kl':
            tf.keras.metrics.Mean(),
            'test/cosine_similarity':
            tf.keras.metrics.Mean(),
            'validation/loss':
            tf.keras.metrics.Mean(),
            'validation/loss_entropy':
            tf.keras.metrics.Mean(),
            'validation/loss_ce':
            tf.keras.metrics.Mean()
        }
        corrupt_metrics = {}

        for i in range(FLAGS.ensemble_size):
            metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
            metrics['test/accuracy_member_{}'.format(i)] = (
                tf.keras.metrics.SparseCategoricalAccuracy())
        if FLAGS.corruptions_interval > 0:
            for intensity in range(1, 6):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        rm.metrics.ExpectedCalibrationError(
                            num_bins=FLAGS.num_bins))

        checkpoint = tf.train.Checkpoint(model=model,
                                         lambda_parameters=lambda_parameters,
                                         optimizer=optimizer)

        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint and FLAGS.restore_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])

            # generate lambdas
            lambdas = log_uniform_sample(per_core_batch_size,
                                         lambda_parameters)
            lambdas = tf.reshape(lambdas,
                                 (FLAGS.ensemble_size * per_core_batch_size,
                                  lambdas_config.dim))

            with tf.GradientTape() as tape:
                logits = model([images, lambdas], training=True)

                if FLAGS.use_gibbs_ce:
                    # Average of single model CEs
                    # tiling of labels should be only done for Gibbs CE loss
                    labels = tf.tile(labels, [FLAGS.ensemble_size])
                    negative_log_likelihood = tf.reduce_mean(
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, logits, from_logits=True))
                else:
                    # Ensemble CE uses no tiling of the labels
                    negative_log_likelihood = ensemble_crossentropy(
                        labels, logits, FLAGS.ensemble_size)
                # Note: divide l2_loss by the train sample size (this differs
                # from the uncertainty_baselines implementation).
                l2_loss = sum(model.losses) / train_sample_size
                loss = negative_log_likelihood + l2_loss
                # Scale the loss, given that the TPUStrategy will reduce-sum
                # all gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)

            # Separate learning rate for fast weights.
            grads_and_vars = []
            for grad, var in zip(grads, model.trainable_variables):
                if (('alpha' in var.name or 'gamma' in var.name)
                        and 'batch_norm' not in var.name):
                    grads_and_vars.append(
                        (grad * FLAGS.fast_weight_lr_multiplier, var))
                else:
                    grads_and_vars.append((grad, var))
            optimizer.apply_gradients(grads_and_vars)

            probs = tf.nn.softmax(logits)
            per_probs = tf.split(probs,
                                 num_or_size_splits=FLAGS.ensemble_size,
                                 axis=0)
            per_probs_stacked = tf.stack(per_probs, axis=0)
            metrics['train/ece'].add_batch(probs, label=labels)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, logits)
            diversity = rm.metrics.AveragePairwiseDiversity()
            diversity.add_batch(per_probs_stacked,
                                num_models=FLAGS.ensemble_size)
            diversity_results = diversity.result()
            for k, v in diversity_results.items():
                metrics['train/' + k].update_state(v)

            if grads_and_vars:
                grads, _ = zip(*grads_and_vars)

        strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def tuning_step(iterator):
        """Tuning StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])

            with tf.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(lambda_parameters)

                # sample lambdas
                if FLAGS.sample_and_tune:
                    lambdas = log_uniform_sample(per_core_batch_size,
                                                 lambda_parameters)
                else:
                    lambdas = log_uniform_mean(lambda_parameters)
                    lambdas = tf.repeat(lambdas, per_core_batch_size, axis=0)
                lambdas = tf.reshape(lambdas,
                                     (FLAGS.ensemble_size *
                                      per_core_batch_size, lambdas_config.dim))
                # ensemble CE
                logits = model([images, lambdas], training=False)
                ce = ensemble_crossentropy(labels, logits, FLAGS.ensemble_size)
                # entropy penalty for lambda distribution
                entropy = FLAGS.tau * log_uniform_entropy(lambda_parameters)
                loss = ce - entropy
                scaled_loss = loss / strategy.num_replicas_in_sync

            gradients = tape.gradient(loss, lambda_parameters)
            tuner.apply_gradients(zip(gradients, lambda_parameters))

            metrics['validation/loss_ce'].update_state(
                ce / strategy.num_replicas_in_sync)
            metrics['validation/loss_entropy'].update_state(
                entropy / strategy.num_replicas_in_sync)
            metrics['validation/loss'].update_state(scaled_loss)

        strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name, num_eval_samples=0):
        """Evaluation StepFn."""

        n_samples = num_eval_samples if num_eval_samples >= 0 else -num_eval_samples
        if num_eval_samples >= 0:
            # the +1 accounts for the fact that we add the mean of lambdas
            ensemble_size = FLAGS.ensemble_size * (1 + n_samples)
        else:
            ensemble_size = FLAGS.ensemble_size * n_samples

        def step_fn(inputs):
            """Per-Replica StepFn."""
            # Note that we don't use tf.tile for labels here
            images = inputs['features']
            labels = inputs['labels']
            images = tf.tile(images, [ensemble_size, 1, 1, 1])

            # get lambdas
            samples = log_uniform_sample(n_samples, lambda_parameters)
            if num_eval_samples >= 0:
                lambdas = log_uniform_mean(lambda_parameters)
                lambdas = tf.expand_dims(lambdas, 1)
                lambdas = tf.concat((lambdas, samples), 1)
            else:
                lambdas = samples

            # lambdas with shape (ens size, samples, dim of lambdas)
            rep_lambdas = tf.repeat(lambdas, per_core_batch_size, axis=1)
            rep_lambdas = tf.reshape(rep_lambdas,
                                     (ensemble_size * per_core_batch_size, -1))

            # evaluate on the test sets
            logits = model([images, rep_lambdas], training=False)
            probs = tf.nn.softmax(logits)
            per_probs = tf.split(probs,
                                 num_or_size_splits=ensemble_size,
                                 axis=0)

            # Per-member performance and Gibbs performance (average per-member perf)
            if dataset_name == 'clean':
                for i in range(FLAGS.ensemble_size):
                    # we record the first sample of lambdas per batch-ens member
                    first_member_index = i * (ensemble_size //
                                              FLAGS.ensemble_size)
                    member_probs = per_probs[first_member_index]
                    member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                        labels, member_probs)
                    metrics['test/nll_member_{}'.format(i)].update_state(
                        member_loss)
                    metrics['test/accuracy_member_{}'.format(i)].update_state(
                        labels, member_probs)

                labels_tile = tf.tile(labels, [ensemble_size])
                metrics['test/gibbs_nll'].update_state(
                    tf.reduce_mean(
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels_tile, logits, from_logits=True)))
                metrics['test/gibbs_accuracy'].update_state(labels_tile, probs)

            # ensemble performance
            negative_log_likelihood = ensemble_crossentropy(
                labels, logits, ensemble_size)
            probs = tf.reduce_mean(per_probs, axis=0)
            if dataset_name == 'clean':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].add_batch(probs, label=labels)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch(
                    probs, label=labels)

            if dataset_name == 'clean':
                per_probs_stacked = tf.stack(per_probs, axis=0)
                diversity = rm.metrics.AveragePairwiseDiversity()
                diversity.add_batch(per_probs_stacked,
                                    num_models=ensemble_size)
                diversity_results = diversity.result()
                for k, v in diversity_results.items():
                    metrics['test/' + k].update_state(v)

        strategy.run(step_fn, args=(next(iterator), ))

    logging.info('--- Starting training using %d examples. ---',
                 train_sample_size)
    train_iterator = iter(train_dataset)
    validation_iterator = iter(validation_dataset)
    start_time = time.time()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        for step in range(steps_per_epoch):
            train_step(train_iterator)
            do_tuning = (epoch >= FLAGS.tuning_warmup_epochs)
            if do_tuning and ((step + 1) % FLAGS.tuning_every_x_step == 0):
                tuning_step(validation_iterator)
                # clip lambda parameters if outside of range
                clip_lambda_parameters(lambda_parameters, lambdas_config)

            current_step = epoch * steps_per_epoch + (step + 1)
            max_steps = steps_per_epoch * FLAGS.train_epochs
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                       'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                           current_step / max_steps, epoch + 1,
                           FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                           time_elapsed / 60))
            if step % 20 == 0:
                logging.info(message)

        # evaluate on test data
        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_step(test_iterator, dataset_name, FLAGS.num_eval_samples)
            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types)
        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Validation Loss: %.4f, CE: %.4f, Entropy: %.4f',
                     metrics['validation/loss'].result(),
                     metrics['validation/loss_ce'].result(),
                     metrics['validation/loss_entropy'].result())
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        for i in range(FLAGS.ensemble_size):
            logging.info(
                'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
                metrics['test/nll_member_{}'.format(i)].result(),
                metrics['test/accuracy_member_{}'.format(i)].result() * 100)

        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update({
            name: metric.result()
            for name, metric in corrupt_metrics.items()
        })
        total_results.update(corrupt_results)
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        # save checkpoint and lambdas config
        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            lambdas_cf = lambdas_config.get_config()
            filepath = os.path.join(FLAGS.output_dir, 'lambdas_config.p')
            with tf.io.gfile.GFile(filepath, 'wb') as fp:
                pickle.dump(lambdas_cf, fp, protocol=pickle.HIGHEST_PROTOCOL)
            logging.info('Saved checkpoint to %s', checkpoint_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate':
            FLAGS.base_learning_rate,
            'one_minus_momentum':
            FLAGS.one_minus_momentum,
            'l2':
            FLAGS.l2,
            'random_sign_init':
            FLAGS.random_sign_init,
            'fast_weight_lr_multiplier':
            FLAGS.fast_weight_lr_multiplier,
        })
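The call above records the hparam values for a single run but does not declare the experiment schema. A hedged sketch of the companion hp.hparams_config call that would register the hparam and metric columns once at the top-level log directory; the path, domains, and metric tag are assumptions.

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

with tf.summary.create_file_writer('logs/summaries').as_default():  # illustrative path
    hp.hparams_config(
        hparams=[
            hp.HParam('base_learning_rate', hp.RealInterval(1e-4, 1.0)),
            hp.HParam('one_minus_momentum', hp.RealInterval(1e-3, 1e-1)),
            hp.HParam('fast_weight_lr_multiplier', hp.RealInterval(0.1, 10.0)),
        ],
        metrics=[hp.Metric('test/accuracy', display_name='test accuracy')],
    )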
Example #17
def main(hParams, n_run, total_timesteps):
    nsteps = hParams['N_STEPS']
    n_epochs = hParams['N_EPOCHS']
    n_train = 4
    n_minibatch = 8

    log_loss_int = 1
    save_int = 5
    test_int = 10
    test_episodes = 5

    gamma = 0.95
    lr = hParams[HP_LEARNING_RATE]
    vf_coef = hParams[HP_VF_COEF]
    ent_coef = hParams[HP_ENT_COEF]
    save_dir = 'lr' + str(lr) + 'vc' + str(vf_coef) + 'ec' + str(ent_coef)
    testenvfn = SonicEnv.make_env_3

    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = 'logs/sonic_long_test/run-' + str(n_run)
    summ_writer = tf.summary.create_file_writer(log_dir)

    env = SubprocVecEnv([SonicEnv.make_env_3])

    nenv = env.num_envs
    state_size = env.observation_space.shape
    num_actions = env.action_space.n
    pgnet = PGNetwork(state_size,
                      num_actions,
                      lr=lr,
                      vf_coef=vf_coef,
                      ent_coef=ent_coef)

    #  Runner used to create training data
    runner = SonicEnvRunner(env, pgnet, nsteps, gamma)

    # total_timesteps = int(n_epochs * nsteps * nenv)
    nbatch = nenv * nsteps

    print("Total updates to run: ", total_timesteps // nbatch)
    for update in range(1, total_timesteps // nbatch + 1):

        print("\nUpdate #{}:".format(update))
        states_mb, actions_mb, values_mb, rewards_mb, next_dones_mb = runner.run()

        for _ in range(n_train):
            indices = np.arange(nbatch)

            np.random.shuffle(indices)

            for start in range(0, nbatch, nbatch // n_minibatch):
                end = start + nbatch // n_minibatch
                bind = indices[start:end]
                policy_loss, entropy_loss, vf_loss, loss = pgnet.fit_gradient(
                    states_mb[bind], actions_mb[bind], rewards_mb[bind],
                    values_mb[bind])

        WeightWriter(summ_writer, pgnet, (Conv2D, Dense), global_step=update)

        r2 = 1 - (np.var(rewards_mb - values_mb) / np.var(rewards_mb))

        with summ_writer.as_default():
            tf.summary.scalar("PolicyLoss", policy_loss, step=update)
            tf.summary.scalar("EntropyLoss", entropy_loss, step=update)
            tf.summary.scalar("ValueFunctionLoss", vf_loss, step=update)
            tf.summary.scalar("Loss", loss, step=update)
            tf.summary.scalar("R-squared", r2, step=update)

        if update % log_loss_int == 0:
            print("PolicyLoss:", policy_loss)
            print("EntropyLoss: ", entropy_loss)
            print("ValueFunctionLoss: ", vf_loss)
            print("Loss: ", loss)

        if update % save_int == 0:
            pgnet.model.save_weights('sonic_long_test/' + save_dir +
                                     '/my_checkpoint')
            print("Model Saved")

        if update % test_int == 0:
            TestRewardWriter(summ_writer,
                             testenvfn,
                             pgnet,
                             test_episodes,
                             global_step=update)

    with summ_writer.as_default():
        hp.hparams(hParams)

    env.close()
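A script like this is typically driven by a small grid search that builds one hParams dict per run and calls main() once per combination. A sketch of such a driver, assuming the HP_* constants sketched earlier; the grid values and total_timesteps are illustrative.

import itertools

n_run = 0
for lr, vf, ent in itertools.product([1e-4, 5e-4], [0.25, 0.5], [0.01, 0.05]):
    hParams = {
        'N_STEPS': 128,  # illustrative values
        'N_EPOCHS': 10,
        HP_LEARNING_RATE: lr,
        HP_VF_COEF: vf,
        HP_ENT_COEF: ent,
    }
    main(hParams, n_run, total_timesteps=200_000)
    n_run += 1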
Example #18
def main(argv):
    fmt = '[%(filename)s:%(lineno)s] %(message)s'
    formatter = logging.PythonFormatter(fmt)
    logging.get_absl_handler().setFormatter(formatter)
    del argv  # unused arg

    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    data_dir = None
    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        data_dir = FLAGS.data_dir
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    train_dataset_size = (ds_info.splits['train'].num_examples *
                          FLAGS.train_proportion)
    steps_per_epoch = int(train_dataset_size / batch_size)
    logging.info('Steps per epoch %s', steps_per_epoch)
    logging.info('Size of the dataset %s',
                 ds_info.splits['train'].num_examples)
    logging.info('Train proportion %s', FLAGS.train_proportion)
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    aug_params = {
        'augmix': FLAGS.augmix,
        'aug_count': FLAGS.aug_count,
        'augmix_depth': FLAGS.augmix_depth,
        'augmix_prob_coeff': FLAGS.augmix_prob_coeff,
        'augmix_width': FLAGS.augmix_width,
    }

    # Note that stateless_{fold_in,split} may incur a performance cost, but a
    # quick side-by-side test seemed to imply this was minimal.
    seeds = tf.random.experimental.stateless_split(
        [FLAGS.seed, FLAGS.seed + 1], 2)[:, 0]
    train_builder = ub.datasets.get(FLAGS.dataset,
                                    download_data=FLAGS.download_data,
                                    split=tfds.Split.TRAIN,
                                    seed=seeds[0],
                                    aug_params=aug_params,
                                    validation_percent=1. -
                                    FLAGS.train_proportion,
                                    data_dir=data_dir)
    train_dataset = train_builder.load(batch_size=batch_size)
    validation_dataset = None
    steps_per_validation = 0
    if FLAGS.train_proportion < 1.0:
        validation_builder = ub.datasets.get(FLAGS.dataset,
                                             split=tfds.Split.VALIDATION,
                                             validation_percent=1. -
                                             FLAGS.train_proportion,
                                             data_dir=data_dir)
        validation_dataset = validation_builder.load(batch_size=batch_size)
        validation_dataset = strategy.experimental_distribute_dataset(
            validation_dataset)
        steps_per_validation = validation_builder.num_examples // batch_size
    clean_test_builder = ub.datasets.get(FLAGS.dataset,
                                         split=tfds.Split.TEST,
                                         data_dir=data_dir)
    clean_test_dataset = clean_test_builder.load(batch_size=batch_size)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_datasets = {
        'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
    }
    steps_per_epoch = train_builder.num_examples // batch_size
    steps_per_eval = clean_test_builder.num_examples // batch_size
    num_classes = 100 if FLAGS.dataset == 'cifar100' else 10
    if FLAGS.corruptions_interval > 0:
        if FLAGS.dataset == 'cifar100':
            data_dir = FLAGS.cifar100_c_path
        corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
        for corruption_type in corruption_types:
            for severity in range(1, 6):
                dataset = ub.datasets.get(
                    f'{FLAGS.dataset}_corrupted',
                    corruption_type=corruption_type,
                    severity=severity,
                    split=tfds.Split.TEST,
                    data_dir=data_dir).load(batch_size=batch_size)
                test_datasets[f'{corruption_type}_{severity}'] = (
                    strategy.experimental_distribute_dataset(dataset))

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building ResNet model')
        model = ub.models.wide_resnet(input_shape=(32, 32, 3),
                                      depth=28,
                                      width_multiplier=10,
                                      num_classes=num_classes,
                                      l2=FLAGS.l2,
                                      hps=_extract_hyperparameter_dictionary(),
                                      seed=seeds[1],
                                      version=2)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                           for start_epoch_str in FLAGS.lr_decay_epochs]
        lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=1.0 -
                                            FLAGS.one_minus_momentum,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        }
        if validation_dataset:
            metrics.update({
                'validation/negative_log_likelihood':
                tf.keras.metrics.Mean(),
                'validation/accuracy':
                tf.keras.metrics.SparseCategoricalAccuracy(),
                'validation/ece':
                rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            })
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, 6):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        rm.metrics.ExpectedCalibrationError(
                            num_bins=FLAGS.num_bins))

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']

            if FLAGS.augmix and FLAGS.aug_count >= 1:
                # Index 0 of the augmix output is the unperturbed image.
                # We take just one augmented image from the returned stack.
                images = images[:, 1, ...]
            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.label_smoothing == 0.:
                    negative_log_likelihood = tf.reduce_mean(
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, logits, from_logits=True))
                else:
                    one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32),
                                                num_classes)
                    negative_log_likelihood = tf.reduce_mean(
                        tf.keras.losses.categorical_crossentropy(
                            one_hot_labels,
                            logits,
                            from_logits=True,
                            label_smoothing=FLAGS.label_smoothing))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + l2_loss
                # Scale the loss, given that the TPUStrategy will reduce-sum
                # all gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.nn.softmax(logits)
            metrics['train/ece'].add_batch(probs, label=labels)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, logits)

        for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_split, dataset_name, num_steps):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            logits = model(images, training=False)
            probs = tf.nn.softmax(logits)
            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

            if dataset_name == 'clean':
                metrics[
                    f'{dataset_split}/negative_log_likelihood'].update_state(
                        negative_log_likelihood)
                metrics[f'{dataset_split}/accuracy'].update_state(
                    labels, probs)
                metrics[f'{dataset_split}/ece'].add_batch(probs, label=labels)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch(
                    probs, label=labels)

        for _ in tf.range(tf.cast(num_steps, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})
    metrics.update({'train/ms_per_example': tf.keras.metrics.Mean()})

    train_iterator = iter(train_dataset)
    start_time = time.time()
    tb_callback = None
    if FLAGS.collect_profile:
        tb_callback = tf.keras.callbacks.TensorBoard(profile_batch=(100, 102),
                                                     log_dir=os.path.join(
                                                         FLAGS.output_dir,
                                                         'logs'))
        tb_callback.set_model(model)
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        if tb_callback:
            tb_callback.on_epoch_begin(epoch)
        train_start_time = time.time()
        train_step(train_iterator)
        ms_per_example = (time.time() - train_start_time) * 1e6 / batch_size
        metrics['train/ms_per_example'].update_state(ms_per_example)

        current_step = (epoch + 1) * steps_per_epoch
        max_steps = steps_per_epoch * FLAGS.train_epochs
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                       steps_per_sec, eta_seconds / 60, time_elapsed / 60))
        logging.info(message)
        if tb_callback:
            tb_callback.on_epoch_end(epoch)

        if validation_dataset:
            validation_iterator = iter(validation_dataset)
            test_step(validation_iterator, 'validation', 'clean',
                      steps_per_validation)
        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            logging.info('Starting to run eval at epoch: %s', epoch)
            test_start_time = time.time()
            test_step(test_iterator, 'test', dataset_name, steps_per_eval)
            ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
            metrics['test/ms_per_example'].update_state(ms_per_example)

            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'l2': FLAGS.l2,
        })
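Note that hp.hparams is called here without a trial_id, so TensorBoard derives one automatically. Passing an explicit trial_id keeps the dashboard entry stable if the run is restarted; a minimal sketch, where the path and trial name are hypothetical:

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

with tf.summary.create_file_writer('logs/summaries').as_default():
    hp.hparams({'base_learning_rate': 0.1, 'l2': 2e-4},
               trial_id='wrn28x10_seed0')  # hypothetical stable trial name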
Example #19
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    # Initialize distribution strategy on flag-specified accelerator
    strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                                FLAGS.use_gpu, FLAGS.tpu)
    use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

    train_batch_size = (FLAGS.train_batch_size * FLAGS.num_cores)

    if use_tpu:
        logging.info(
            'Because the TPU requires static shapes, the eval batch size must '
            'be fixed to the train batch size: %d.', train_batch_size)
        eval_batch_size = train_batch_size
    else:
        eval_batch_size = (FLAGS.eval_batch_size *
                           FLAGS.num_cores) // FLAGS.num_dropout_samples_eval

    # As per the Kaggle challenge, we have split sizes:
    # train: 35,126
    # validation: 10,906 (currently unused)
    # test: 42,670
    ds_info = tfds.builder('diabetic_retinopathy_detection').info
    steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
    steps_per_validation_eval = (ds_info.splits['validation'].num_examples //
                                 eval_batch_size)
    steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size

    data_dir = FLAGS.data_dir

    dataset_train_builder = ub.datasets.get('diabetic_retinopathy_detection',
                                            split='train',
                                            data_dir=data_dir)
    dataset_train = dataset_train_builder.load(batch_size=train_batch_size)
    dataset_train = strategy.experimental_distribute_dataset(dataset_train)

    dataset_validation_builder = ub.datasets.get(
        'diabetic_retinopathy_detection',
        split='validation',
        data_dir=data_dir)
    dataset_validation = dataset_validation_builder.load(
        batch_size=eval_batch_size)
    dataset_validation = strategy.experimental_distribute_dataset(
        dataset_validation)

    dataset_test_builder = ub.datasets.get('diabetic_retinopathy_detection',
                                           split='test',
                                           data_dir=data_dir)
    dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)
    dataset_test = strategy.experimental_distribute_dataset(dataset_test)

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building Keras ResNet-50 MC Dropout model.')
        model = ub.models.resnet50_dropout(
            input_shape=utils.load_input_shape(dataset_train),
            num_classes=1,  # binary classification task
            dropout_rate=FLAGS.dropout_rate,
            filterwise_dropout=FLAGS.filterwise_dropout)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())

        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate
        lr_decay_epochs = [
            (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
            for start_epoch_str in FLAGS.lr_decay_epochs
        ]
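        # e.g., if DEFAULT_NUM_EPOCHS were 90, a decay entry of '30' with
        # train_epochs=45 would scale to epoch 15.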

        lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(
            lr_schedule,
            momentum=1.0 - FLAGS.one_minus_momentum,
            nesterov=True)
        metrics = utils.get_diabetic_retinopathy_base_metrics(
            use_tpu=use_tpu, num_bins=FLAGS.num_bins)
        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope()
            # so that optimizer slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    # Finally, define AUC metrics outside the accelerator scope for CPU eval;
    # creating them under a TPU strategy would raise an error.
    if not use_tpu:
        metrics.update({
            'train/auc': tf.keras.metrics.AUC(),
            'validation/auc': tf.keras.metrics.AUC(),
            'test/auc': tf.keras.metrics.AUC()
        })

    @tf.function
    def train_step(iterator):
        """Training step function."""
        def step_fn(inputs):
            """Per-replica step function."""
            images = inputs['features']
            labels = inputs['labels']

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)

                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.binary_crossentropy(y_true=tf.expand_dims(
                        labels, axis=-1),
                                                        y_pred=logits,
                                                        from_logits=True))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)

                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            probs = tf.nn.sigmoid(logits)

            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, probs)

            if not use_tpu:
                metrics['train/auc'].update_state(labels, probs)
                metrics['train/ece'].update_state(labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_split):
        """Evaluation step function."""
        def step_fn(inputs):
            """Per-replica step function."""
            images = inputs['features']
            labels = tf.convert_to_tensor(inputs['labels'])

            logits_list = []
            for _ in range(FLAGS.num_dropout_samples_eval):
                logits = model(images, training=False)
                logits = tf.squeeze(logits, axis=-1)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)

                logits_list.append(logits)

            # Logits dimension is (num_samples, batch_size).
            logits_list = tf.stack(logits_list, axis=0)
            probs_list = tf.nn.sigmoid(logits_list)
            probs = tf.reduce_mean(probs_list, axis=0)
            labels_broadcasted = tf.broadcast_to(
                labels, [FLAGS.num_dropout_samples_eval, labels.shape[0]])
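            # Marginal likelihood over S dropout samples, computed stably:
            # log p(y|x) = log((1/S) * sum_s p(y|x, mask_s))
            #            = logsumexp_s(log p(y|x, mask_s)) - log(S).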
            log_likelihoods = -tf.keras.losses.binary_crossentropy(
                labels_broadcasted, logits_list, from_logits=True)
            negative_log_likelihood = tf.reduce_mean(
                -tf.reduce_logsumexp(log_likelihoods, axis=[0]) +
                tf.math.log(float(FLAGS.num_dropout_samples_eval)))
            metrics[dataset_split + '/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics[dataset_split + '/accuracy'].update_state(labels, probs)

            if not use_tpu:
                metrics[dataset_split + '/auc'].update_state(labels, probs)
                metrics[dataset_split + '/ece'].update_state(labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})
    start_time = time.time()

    train_iterator = iter(dataset_train)
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch + 1)
        for step in range(steps_per_epoch):
            train_step(train_iterator)

            current_step = epoch * steps_per_epoch + (step + 1)
            max_steps = steps_per_epoch * FLAGS.train_epochs
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                       'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                           current_step / max_steps, epoch + 1,
                           FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                           time_elapsed / 60))
            if step % 20 == 0:
                logging.info(message)

        validation_iterator = iter(dataset_validation)
        for step in range(steps_per_validation_eval):
            if step % 20 == 0:
                logging.info(
                    'Starting to run validation eval step %s of epoch: %s',
                    step, epoch + 1)
            test_step(validation_iterator, 'validation')

        test_iterator = iter(dataset_test)
        for step in range(steps_per_test_eval):
            if step % 20 == 0:
                logging.info('Starting to run test eval step %s of epoch: %s',
                             step, epoch + 1)
            test_start_time = time.time()
            test_step(test_iterator, 'test')
            ms_per_example = (time.time() -
                              test_start_time) * 1e6 / eval_batch_size
            metrics['test/ms_per_example'].update_state(ms_per_example)

        # Log and write to summary the epoch metrics
        utils.log_epoch_metrics(metrics=metrics, use_tpu=use_tpu)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

            # TODO(nband): debug checkpointing
            # Also save Keras model, due to checkpoint.save issue
            keras_model_name = os.path.join(FLAGS.output_dir,
                                            f'keras_model_{epoch + 1}')
            model.save(keras_model_name)
            logging.info('Saved keras model to %s', keras_model_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)

    keras_model_name = os.path.join(FLAGS.output_dir,
                                    f'keras_model_{FLAGS.train_epochs}')
    model.save(keras_model_name)
    logging.info('Saved keras model to %s', keras_model_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'dropout_rate': FLAGS.dropout_rate,
            'l2': FLAGS.l2,
        })
Example #20
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    train_batch_size = (FLAGS.per_core_batch_size * FLAGS.num_cores //
                        FLAGS.batch_repetitions)
    test_batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    steps_per_epoch = APPROX_IMAGENET_TRAIN_IMAGES // train_batch_size
    steps_per_eval = IMAGENET_VALIDATION_IMAGES // test_batch_size

    data_dir = FLAGS.data_dir
    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    train_builder = ub.datasets.ImageNetDataset(
        split=tfds.Split.TRAIN,
        use_bfloat16=FLAGS.use_bfloat16,
        data_dir=data_dir)
    train_dataset = train_builder.load(batch_size=train_batch_size,
                                       strategy=strategy)
    test_builder = ub.datasets.ImageNetDataset(split=tfds.Split.TEST,
                                               use_bfloat16=FLAGS.use_bfloat16,
                                               data_dir=data_dir)
    test_dataset = test_builder.load(batch_size=test_batch_size,
                                     strategy=strategy)

    if FLAGS.use_bfloat16:
        tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    with strategy.scope():
        logging.info('Building Keras ResNet-50 model')
        model = ub.models.resnet50_het_mimo(
            input_shape=(FLAGS.ensemble_size, 224, 224, 3),
            num_classes=NUM_CLASSES,
            ensemble_size=FLAGS.ensemble_size,
            num_factors=FLAGS.num_factors,
            temperature=FLAGS.temperature,
            num_mc_samples=FLAGS.num_mc_samples,
            share_het_layer=FLAGS.share_het_layer,
            width_multiplier=FLAGS.width_multiplier)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Scale learning rate and decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * train_batch_size / 256
        decay_epochs = [
            (FLAGS.train_epochs * 30) // 90,
            (FLAGS.train_epochs * 60) // 90,
            (FLAGS.train_epochs * 80) // 90,
        ]
        learning_rate = ub.schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch=steps_per_epoch,
            base_learning_rate=base_lr,
            decay_ratio=0.1,
            decay_epochs=decay_epochs,
            warmup_epochs=5)
        optimizer = tf.keras.optimizers.SGD(
            learning_rate=learning_rate,
            momentum=1.0 - FLAGS.one_minus_momentum,
            nesterov=True)
        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/diversity':
            rm.metrics.AveragePairwiseDiversity(),
        }

        for i in range(FLAGS.ensemble_size):
            metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
            metrics['test/accuracy_member_{}'.format(i)] = (
                tf.keras.metrics.SparseCategoricalAccuracy())
        logging.info('Finished building Keras ResNet-50 model')

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            batch_size = tf.shape(images)[0]
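            # MIMO batch repetition: repeat the batch indices, then reshuffle
            # the first (1 - input_repetition_probability) fraction of them
            # independently per ensemble member, so members see mostly
            # independent examples while sharing the remainder.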
            main_shuffle = tf.random.shuffle(
                tf.tile(tf.range(batch_size), [FLAGS.batch_repetitions]))
            to_shuffle = tf.cast(
                tf.cast(tf.shape(main_shuffle)[0], tf.float32) *
                (1. - FLAGS.input_repetition_probability), tf.int32)
            shuffle_indices = [
                tf.concat([
                    tf.random.shuffle(main_shuffle[:to_shuffle]),
                    main_shuffle[to_shuffle:]
                ],
                          axis=0) for _ in range(FLAGS.ensemble_size)
            ]
            images = tf.stack([
                tf.gather(images, indices, axis=0)
                for indices in shuffle_indices
            ],
                              axis=1)
            labels = tf.stack([
                tf.gather(labels, indices, axis=0)
                for indices in shuffle_indices
            ],
                              axis=1)

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)

                negative_log_likelihood = tf.reduce_mean(
                    tf.reduce_sum(
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, logits, from_logits=True),
                        axis=1))
                filtered_variables = []
                for var in model.trainable_variables:
                    # Apply l2 to conv/dense kernels and biases; BN parameters
                    # ('gamma'/'beta') are excluded, which relies on this
                    # naming scheme.
                    if 'kernel' in var.name or 'bias' in var.name:
                        filtered_variables.append(tf.reshape(var, (-1, )))

                l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
                    tf.concat(filtered_variables, axis=0))
                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                loss = negative_log_likelihood + l2_loss
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.nn.softmax(tf.reshape(logits, [-1, NUM_CLASSES]))
            flat_labels = tf.reshape(labels, [-1])
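            # Robustness Metrics objects use add_batch(...) rather than the
            # Keras update_state(...) API.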
            metrics['train/ece'].add_batch(probs, label=flat_labels)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(flat_labels, probs)

        for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            images = tf.tile(tf.expand_dims(images, 1),
                             [1, FLAGS.ensemble_size, 1, 1, 1])
            logits = model(images, training=False)
            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)
            probs = tf.nn.softmax(logits)

            per_probs = tf.transpose(probs, perm=[1, 0, 2])
            metrics['test/diversity'].add_batch(per_probs)

            for i in range(FLAGS.ensemble_size):
                member_probs = probs[:, i]
                member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, member_probs)
                metrics['test/nll_member_{}'.format(i)].update_state(
                    member_loss)
                metrics['test/accuracy_member_{}'.format(i)].update_state(
                    labels, member_probs)

            # Negative log marginal likelihood computed in a numerically-stable way.
            labels_tiled = tf.tile(tf.expand_dims(labels, 1),
                                   [1, FLAGS.ensemble_size])
            log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
                labels_tiled, logits, from_logits=True)
            negative_log_likelihood = tf.reduce_mean(
                -tf.reduce_logsumexp(log_likelihoods, axis=[1]) +
                tf.math.log(float(FLAGS.ensemble_size)))
            probs = tf.math.reduce_mean(probs, axis=1)  # marginalize

            metrics['test/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['test/accuracy'].update_state(labels, probs)
            metrics['test/ece'].add_batch(probs, label=labels)

        for _ in tf.range(tf.cast(steps_per_eval, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

    train_iterator = iter(train_dataset)
    start_time = time.time()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        train_step(train_iterator)

        current_step = (epoch + 1) * steps_per_epoch
        max_steps = steps_per_epoch * FLAGS.train_epochs
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                       steps_per_sec, eta_seconds / 60, time_elapsed / 60))
        logging.info(message)

        test_iterator = iter(test_dataset)
        logging.info('Starting to run eval of epoch: %s', epoch)
        test_start_time = time.time()
        test_step(test_iterator)
        ms_per_example = (time.time() -
                          test_start_time) * 1e6 / test_batch_size
        metrics['test/ms_per_example'].update_state(ms_per_example)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        for i in range(FLAGS.ensemble_size):
            logging.info(
                'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
                metrics['test/nll_member_{}'.format(i)].result(),
                metrics['test/accuracy_member_{}'.format(i)].result() * 100)

        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        # Results from Robustness Metrics themselves return a dict, so flatten them.
        total_results = utils.flatten_dictionary(total_results)
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_save_name = os.path.join(FLAGS.output_dir, 'model')
    model.save(final_save_name)
    logging.info('Saved model to %s', final_save_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'l2': FLAGS.l2,
            'batch_repetitions': FLAGS.batch_repetitions,
        })
Example #21
def write_hparams_v2(writer, hparams: dict):
    hparams = _copy_and_clean_hparams(hparams)
    hparams = _set_precision_if_missing(hparams)

    with writer.as_default():
        hp.hparams(hparams)
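
# For context, a minimal, self-contained sketch of the pattern the snippets in
# this collection rely on. The log directory, hyperparameter names, values,
# and metric tag below are illustrative assumptions, not taken from any
# snippet above.
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

writer = tf.summary.create_file_writer('logs/run_0')  # hypothetical path
hparams = {'learning_rate': 1e-3, 'batch_size': 64}   # hypothetical values

with writer.as_default():
    # Record this trial's hyperparameter values...
    hp.hparams(hparams)
    # ...and at least one scalar so the HParams dashboard has a metric to show.
    tf.summary.scalar('accuracy', 0.9, step=1)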
Example #22
def main():

    print("Tensorflow version {}".format(tf.__version__))

    # Configuration
    MODEL_NAME    = "ALAE_CONV_V1"
    generate_mnist_samples = False
    generate_samples_tensorboard = True
    PRINT_IT  = 50
    RESULT_IT = 100
    SAVE_WEIGHT_IT = 5000

    # Network configuration & hyper parameters
    EPOCHS       = 100
    BATCH_SIZE   = 128
    Z_DIM        = 100
    LATENT_DIM   = 50
    GAMMA_GP     = 10
    K_RECONST_KL = 0.5  # Ratio between pure reconstruction and Kullback-Leibler terms (latent space quality)
    # Learning rates for Discriminator, Generator, Latent Space
    LR_D_G_L     = [0.0001, 0.0004, 0.0002]  # Best

    #
    # Manage folders
    #
    original_mnist_samples = os.path.join("results", "mnist_original")
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    mnist_samples_ = os.path.join("results", MODEL_NAME + "_" + current_time)
    folder_to_create = [original_mnist_samples,mnist_samples_]
    # Create folders
    for folder in folder_to_create:
        if not os.path.exists(folder):
            os.makedirs(folder)
    # Define folders
    checkpoint_path        = os.path.join("checkpoint", MODEL_NAME)
    original_mnist_samples = os.path.join(original_mnist_samples, "samples_{}.png")
    mnist_samples = os.path.join(mnist_samples_, "alae_samples_{}.png")
    static_mnist_samples = os.path.join(mnist_samples_, "static_alae_samples_{}.png")
    train_log_dir = os.path.join("logs", "tensorboard", MODEL_NAME + "_" + current_time)
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)

    # Log hyper parameters
    HP_BATCH_SIZE = hp.HParam('HP_BATCH_SIZE', hp.Discrete([64, 128, 256, 512, 1024]))
    HP_Z_DIM      = hp.HParam('HP_Z_DIM',      hp.Discrete([50, 100, 200]))
    HP_LATENT_DIM = hp.HParam('HP_LATENT_DIM', hp.Discrete([30, 50, 70]))
    HP_GAMMA_GP   = hp.HParam('HP_GAMMA_GP',   hp.Discrete([2, 5, 10]))
    HP_K_RECONST_KL     = hp.HParam('HP_K_RECONST_KL',     hp.RealInterval(0., 1.))
    HP_LR_GENERATOR     = hp.HParam('HP_LR_GENERATOR',     hp.RealInterval(0., 0.1))
    HP_LR_DISCRIMINATOR = hp.HParam('HP_LR_DISCRIMINATOR', hp.RealInterval(0., 0.1))
    HP_LR_LATENT        = hp.HParam('HP_LR_LATENT',        hp.RealInterval(0., 0.1))

    hparams = {
        HP_BATCH_SIZE:       BATCH_SIZE,
        HP_Z_DIM:            Z_DIM,
        HP_LATENT_DIM:       LATENT_DIM,
        HP_GAMMA_GP:         GAMMA_GP,
        HP_K_RECONST_KL:     K_RECONST_KL,
        HP_LR_DISCRIMINATOR: LR_D_G_L[0],
        HP_LR_GENERATOR:     LR_D_G_L[1],
        HP_LR_LATENT:        LR_D_G_L[2],
    }

    # Log hyper parameters
    METRIC_LATENT_LOST = 'latent_loss'
    with train_summary_writer.as_default():
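        # hparams_config registers the hyperparameter domains and metrics for
        # the experiment; hp.hparams records this trial's actual values.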
        hp.hparams_config(hparams, metrics=[hp.Metric(METRIC_LATENT_LOST, display_name='Latent Loss')])
        hp.hparams(hparams, trial_id = MODEL_NAME + "_" + current_time)
        tf.summary.scalar(METRIC_LATENT_LOST, 0, step=1)


    # Do useful stuff
    seed = 2020
    np.random.seed(seed)
    tf.random.set_seed(seed)

    #
    # Load data
    #
    (x_train, _), (_, _) = datasets.mnist.load_data()

    # Prepare the data
    train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
    train_dataset = train_dataset.shuffle(buffer_size=1024)  # shuffle examples before batching
    train_dataset = train_dataset.batch(BATCH_SIZE)
    train_dataset = train_dataset.map(utils.tf_process_images)

    # x_train = utils.process_images(x_train)
    IMAGE_DIM = 32 * 32
    # x_train = tf.dtypes.cast(tf.reshape(x_train, (len(x_train), IMAGE_DIM) ), tf.float32)

    # A way to plot some real examples of MNIST
    if generate_mnist_samples:
        for index in range(16):
            source_indexes = np.random.permutation(len(x_train))[0:64]
            utils.plot_mnist_grid(x_train[source_indexes], target_file = original_mnist_samples.format(index))

    #
    # Prepare the models
    #
    conf_dict = {"Z_DIM": Z_DIM, "LATENT_DIM": LATENT_DIM,
                 "IMAGE_DIM": IMAGE_DIM,
                 "GAMMA_GP": GAMMA_GP,
                 "LR_D_G_L": LR_D_G_L,
                 "K_RECONST_KL": K_RECONST_KL, }

    generator       = alae_tf2_models.Generator(conf_dict, train_summary_writer)
    discriminator   = alae_tf2_models.Discriminator(conf_dict, train_summary_writer)
    E_encoder       = alae_tf2_models.E_encoder(conf_dict, train_summary_writer)
    F_encoder       = alae_tf2_models.F_encoder(conf_dict, train_summary_writer)

    alae_helper = alae.alae_helper({"generator": generator,
                                    "discriminator": discriminator,
                                    "E_encoder": E_encoder,
                                    "F_encoder": F_encoder}, conf_dict)

    # Prepare to save the weights
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator,
                                     discriminator_optimizer=discriminator,
                                     E_encoder_optimizer=E_encoder,
                                     F_encoder_optimizer=F_encoder,
                                     generator=generator,
                                     discriminator=discriminator,
                                     E_encoder=E_encoder,
                                     F_encoder=F_encoder,
                                     step=tf.Variable(1))

    manager = tf.train.CheckpointManager(checkpoint, checkpoint_path, max_to_keep=3)

    z_samples_static = alae_helper.sample_Z(64, Z_DIM)

    # ------------------------------------------------------------------------------------
    # Start of the training loop
    # ------------------------------------------------------------------------------------
    it = 1
    for epoch in range(EPOCHS):

        for x in train_dataset:

            if it % RESULT_IT == 0:
                # 8x8 = 64 samples shown as a grid of generated digits
                z_samples = alae_helper.sample_Z(64, Z_DIM)
                samples = generator(F_encoder(z_samples, training=False), training=False)
                img = samples.numpy()
                utils.plot_mnist_grid(img, target_file=mnist_samples.format(str(it).zfill(3)))

                samples = generator(F_encoder(z_samples_static, training=False), training=False)
                img = samples.numpy()
                utils.plot_mnist_grid(img, target_file=static_mnist_samples.format(str(it).zfill(3)))

                if generate_samples_tensorboard:
                    # Add results into Tensorboard
                    # (batch_size, height, width, channels)
                    img = np.reshape(img, [64, 32, 32, 1])
                    with train_summary_writer.as_default():
                        tf.summary.image("Generated Image", img, step=it)

            # The job is done here, x are real samples
            losses = alae_helper.trainstep(x)

            if it % PRINT_IT == 0:
                print('Epoch: {}   it: {}     L_loss: {:.4f}     D_loss: {:.4f}     G_loss: {:.4f}'.format(
                    epoch + 1, it, losses["latent"], losses["disc"], losses["gen"]))
                with train_summary_writer.as_default():

                    tf.summary.scalar('discriminator_loss', losses["disc"], step=it)
                    tf.summary.scalar('generator_loss', losses["gen"], step=it)
                    tf.summary.scalar('latent_loss', losses["latent"], step=it)
                    tf.summary.scalar('latent_loss_reconst', losses["latent_reconst"], step=it)
                    tf.summary.scalar('latent_loss_kl', losses["latent_kl"], step=it)
                    tf.summary.histogram('real_samples', x, step=it)

            if it % SAVE_WEIGHT_IT == 0:
                # Save the weights
                checkpoint.step.assign_add(1)
                save_path = manager.save()
                print("Saved checkpoint for step {}: {}".format(int(checkpoint.step), save_path))

            it += 1  # Next iteration
Example #23
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    train_dataset_builder = ub.datasets.ClincIntentDetectionDataset(
        split='train', data_dir=FLAGS.data_dir, data_mode='ind')
    ind_dataset_builder = ub.datasets.ClincIntentDetectionDataset(
        split='test', data_dir=FLAGS.data_dir, data_mode='ind')
    ood_dataset_builder = ub.datasets.ClincIntentDetectionDataset(
        split='test', data_dir=FLAGS.data_dir, data_mode='ood')
    all_dataset_builder = ub.datasets.ClincIntentDetectionDataset(
        split='test', data_dir=FLAGS.data_dir, data_mode='all')

    dataset_builders = {
        'clean': ind_dataset_builder,
        'ood': ood_dataset_builder,
        'all': all_dataset_builder
    }

    train_dataset = train_dataset_builder.load(
        batch_size=FLAGS.per_core_batch_size)

    ds_info = train_dataset_builder.tfds_info
    feature_size = ds_info.metadata['feature_size']
    # num_classes is number of valid intents plus out-of-scope intent
    num_classes = ds_info.features['intent_label'].num_classes + 1

    steps_per_epoch = train_dataset_builder.num_examples // batch_size

    test_datasets = {}
    steps_per_eval = {}
    for dataset_name, dataset_builder in dataset_builders.items():
        test_datasets[dataset_name] = dataset_builder.load(
            batch_size=FLAGS.eval_batch_size)
        steps_per_eval[dataset_name] = (dataset_builder.num_examples //
                                        FLAGS.eval_batch_size)

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building BERT model')

        bert_config_dir, bert_ckpt_dir = resolve_bert_ckpt_and_config_dir(
            FLAGS.bert_dir, FLAGS.bert_config_dir, FLAGS.bert_ckpt_dir)
        bert_config = bert_utils.create_config(bert_config_dir)
        bert_config.hidden_dropout_prob = FLAGS.dropout_rate
        bert_config.attention_probs_dropout_prob = FLAGS.dropout_rate
        model, bert_encoder = ub.models.DropoutBertBuilder(
            num_classes=num_classes,
            bert_config=bert_config,
            use_mc_dropout_mha=FLAGS.use_mc_dropout_mha,
            use_mc_dropout_att=FLAGS.use_mc_dropout_att,
            use_mc_dropout_ffn=FLAGS.use_mc_dropout_ffn,
            use_mc_dropout_output=FLAGS.use_mc_dropout_output,
            channel_wise_dropout_mha=FLAGS.channel_wise_dropout_mha,
            channel_wise_dropout_att=FLAGS.channel_wise_dropout_att,
            channel_wise_dropout_ffn=FLAGS.channel_wise_dropout_ffn)
        # Create an AdamW optimizer with beta_2=0.999, epsilon=1e-6.
        optimizer = bert_utils.create_optimizer(
            FLAGS.base_learning_rate,
            steps_per_epoch=steps_per_epoch,
            epochs=FLAGS.train_epochs,
            warmup_proportion=FLAGS.warmup_proportion,
            beta_1=1.0 - FLAGS.one_minus_momentum)

        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())

        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        }

        for dataset_name, test_dataset in test_datasets.items():
            if dataset_name != 'clean':
                metrics.update({
                    'test/nll_{}'.format(dataset_name):
                    tf.keras.metrics.Mean(),
                    'test/accuracy_{}'.format(dataset_name):
                    tf.keras.metrics.SparseCategoricalAccuracy(),
                    'test/ece_{}'.format(dataset_name):
                    rm.metrics.ExpectedCalibrationError(
                        num_bins=FLAGS.num_bins)
                })

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
        else:
            # load BERT from initial checkpoint
            bert_checkpoint = tf.train.Checkpoint(model=bert_encoder)
            bert_checkpoint.restore(
                bert_ckpt_dir).assert_existing_objects_matched()
            logging.info('Loaded BERT checkpoint %s', bert_ckpt_dir)

    # Finally, define OOD metrics outside the accelerator scope for CPU eval.
    metrics.update({
        'test/auroc_all': tf.keras.metrics.AUC(curve='ROC'),
        'test/auprc_all': tf.keras.metrics.AUC(curve='PR')
    })

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            features, labels = bert_utils.create_feature_and_label(
                inputs, feature_size)

            with tf.GradientTape() as tape:
                # Set learning phase to enable dropout etc during training.
                logits = model(features, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)
                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + l2_loss
                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.nn.softmax(logits)
            metrics['train/ece'].add_batch(probs, label=labels)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, logits)

        for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name, num_steps):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            features, labels = bert_utils.create_feature_and_label(
                inputs, feature_size)

            # Compute ensemble prediction over Monte Carlo dropout samples.
            logits_list = []
            for _ in range(FLAGS.num_dropout_samples):
                logits = model(features, training=False)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)
                logits_list.append(logits)

            # Logits dimension is (num_samples, batch_size, num_classes).
            logits_list = tf.stack(logits_list, axis=0)
            probs_list = tf.nn.softmax(logits_list)
            probs = tf.reduce_mean(probs_list, axis=0)

            labels_broadcasted = tf.broadcast_to(
                labels, [FLAGS.num_dropout_samples, labels.shape[0]])
            log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
                labels_broadcasted, logits_list, from_logits=True)
            negative_log_likelihood = tf.reduce_mean(
                -tf.reduce_logsumexp(log_likelihoods, axis=[0]) +
                tf.math.log(float(FLAGS.num_dropout_samples)))

            if dataset_name == 'clean':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].add_batch(probs, label=labels)
            else:
                metrics['test/nll_{}'.format(dataset_name)].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy_{}'.format(dataset_name)].update_state(
                    labels, probs)
                metrics['test/ece_{}'.format(dataset_name)].add_batch(
                    probs, label=labels)

            if dataset_name == 'all':
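                # Label 150 is the out-of-scope intent; score OOD-ness with
                # one minus the maximum softmax probability (MSP).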
                ood_labels = tf.cast(labels == 150, labels.dtype)
                ood_probs = 1. - tf.reduce_max(probs, axis=-1)
                metrics['test/auroc_{}'.format(dataset_name)].update_state(
                    ood_labels, ood_probs)
                metrics['test/auprc_{}'.format(dataset_name)].update_state(
                    ood_labels, ood_probs)

        for _ in tf.range(tf.cast(num_steps, tf.int32)):
            step_fn(next(iterator))

    train_iterator = iter(train_dataset)
    start_time = time.time()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        train_step(train_iterator)

        current_step = (epoch + 1) * steps_per_epoch
        max_steps = steps_per_epoch * FLAGS.train_epochs
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                       steps_per_sec, eta_seconds / 60, time_elapsed / 60))
        logging.info(message)

        if epoch % FLAGS.evaluation_interval == 0:
            for dataset_name, test_dataset in test_datasets.items():
                test_iterator = iter(test_dataset)
                logging.info('Testing on dataset %s', dataset_name)
                logging.info('Starting to run eval at epoch: %s', epoch)
                test_step(test_iterator, dataset_name,
                          steps_per_eval[dataset_name])
                logging.info('Done with testing on %s', dataset_name)

            logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                         metrics['train/loss'].result(),
                         metrics['train/accuracy'].result() * 100)
            logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                         metrics['test/negative_log_likelihood'].result(),
                         metrics['test/accuracy'].result() * 100)
            total_results = {
                name: metric.result()
                for name, metric in metrics.items()
            }
            # Metrics from Robustness Metrics (like ECE) will return a dict with a
            # single key/value, instead of a scalar.
            total_results = {
                k: (list(v.values())[0] if isinstance(v, dict) else v)
                for k, v in total_results.items()
            }
            with summary_writer.as_default():
                for name, result in total_results.items():
                    tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'dropout_rate': FLAGS.dropout_rate,
            'num_dropout_samples': FLAGS.num_dropout_samples,
        })
Example #24
def main(_):

    configure_environment(FLAGS.fp16_run)

    hparams = {
        HP_TOKEN_TYPE: HP_TOKEN_TYPE.domain.values[1],

        # Preprocessing
        HP_MEL_BINS: HP_MEL_BINS.domain.values[0],
        HP_FRAME_LENGTH: HP_FRAME_LENGTH.domain.values[0],
        HP_FRAME_STEP: HP_FRAME_STEP.domain.values[0],
        HP_HERTZ_LOW: HP_HERTZ_LOW.domain.values[0],
        HP_HERTZ_HIGH: HP_HERTZ_HIGH.domain.values[0],

        # Model
        HP_EMBEDDING_SIZE: HP_EMBEDDING_SIZE.domain.values[0],
        HP_ENCODER_LAYERS: HP_ENCODER_LAYERS.domain.values[0],
        HP_ENCODER_SIZE: HP_ENCODER_SIZE.domain.values[0],
        HP_PROJECTION_SIZE: HP_PROJECTION_SIZE.domain.values[0],
        HP_TIME_REDUCT_INDEX: HP_TIME_REDUCT_INDEX.domain.values[0],
        HP_TIME_REDUCT_FACTOR: HP_TIME_REDUCT_FACTOR.domain.values[0],
        HP_PRED_NET_LAYERS: HP_PRED_NET_LAYERS.domain.values[0],
        HP_PRED_NET_SIZE: HP_PRED_NET_SIZE.domain.values[0],
        HP_JOINT_NET_SIZE: HP_JOINT_NET_SIZE.domain.values[0],
        HP_LEARNING_RATE: HP_LEARNING_RATE.domain.values[0]
    }

    with tf.summary.create_file_writer(
            os.path.join(FLAGS.tb_log_dir, 'hparams_tuning')).as_default():
        hp.hparams_config(
            hparams=[
                HP_TOKEN_TYPE, HP_VOCAB_SIZE, HP_EMBEDDING_SIZE,
                HP_ENCODER_LAYERS, HP_ENCODER_SIZE, HP_PROJECTION_SIZE,
                HP_TIME_REDUCT_INDEX, HP_TIME_REDUCT_FACTOR,
                HP_PRED_NET_LAYERS, HP_PRED_NET_SIZE, HP_JOINT_NET_SIZE
            ],
            metrics=[
                hp.Metric(METRIC_ACCURACY, display_name='Accuracy'),
                hp.Metric(METRIC_CER, display_name='CER'),
                hp.Metric(METRIC_WER, display_name='WER'),
            ],
        )

    _hparams = {k.name: v for k, v in hparams.items()}

    if not FLAGS.gpus:
        gpus = [
            x.name.replace('/physical_device:', '')
            for x in tf.config.experimental.list_physical_devices('GPU')
        ]
    else:
        gpus = ['GPU:' + str(gpu_id) for gpu_id in FLAGS.gpus]

    strategy = tf.distribute.MirroredStrategy(devices=gpus)
    # strategy = None

    dtype = tf.float32
    if FLAGS.fp16_run:
        dtype = tf.float16

    # initializer = tf.keras.initializers.RandomUniform(
    #     minval=-0.1, maxval=0.1)
    initializer = None

    if FLAGS.checkpoint is not None:

        checkpoint_dir = os.path.dirname(os.path.realpath(FLAGS.checkpoint))

        _hparams = model_utils.load_hparams(checkpoint_dir)
        encoder_fn, idx_to_text, vocab_size = encoding.load_encoder(
            checkpoint_dir, hparams=_hparams)

        if strategy is not None:
            with strategy.scope():
                model = build_keras_model(_hparams,
                                          initializer=initializer,
                                          dtype=dtype)
                model.load_weights(FLAGS.checkpoint)
        else:
            model = build_keras_model(_hparams,
                                      initializer=initializer,
                                      dtype=dtype)
            model.load_weights(FLAGS.checkpoint)

        logging.info('Restored weights from {}.'.format(FLAGS.checkpoint))

    else:

        os.makedirs(FLAGS.output_dir, exist_ok=True)

        shutil.copy(os.path.join(FLAGS.data_dir, 'encoder.subwords'),
                    os.path.join(FLAGS.output_dir, 'encoder.subwords'))

        encoder_fn, idx_to_text, vocab_size = encoding.load_encoder(
            FLAGS.output_dir, hparams=_hparams)
        _hparams[HP_VOCAB_SIZE.name] = vocab_size

        if strategy is not None:
            with strategy.scope():
                model = build_keras_model(_hparams,
                                          initializer=initializer,
                                          dtype=dtype)
        else:
            model = build_keras_model(_hparams,
                                      initializer=initializer,
                                      dtype=dtype)
        model_utils.save_hparams(_hparams, FLAGS.output_dir)

    logging.info('Using {} encoder with vocab size: {}'.format(
        _hparams[HP_TOKEN_TYPE.name], vocab_size))

    loss_fn = get_loss_fn(
        reduction_factor=_hparams[HP_TIME_REDUCT_FACTOR.name])

    start_token = encoder_fn('')[0]
    decode_fn = decoding.greedy_decode_fn(model, start_token=start_token)

    accuracy_fn = metrics.build_accuracy_fn(decode_fn)
    cer_fn = metrics.build_cer_fn(decode_fn, idx_to_text)
    wer_fn = metrics.build_wer_fn(decode_fn, idx_to_text)

    optimizer = tf.keras.optimizers.SGD(_hparams[HP_LEARNING_RATE.name],
                                        momentum=0.9)

    if FLAGS.fp16_run:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer,
                                                       loss_scale='dynamic')

    encoder = model.layers[2]
    prediction_network = model.layers[3]

    encoder.summary()
    prediction_network.summary()

    model.summary()

    dev_dataset, _ = get_dataset(FLAGS.data_dir,
                                 'dev',
                                 batch_size=FLAGS.batch_size,
                                 n_epochs=FLAGS.n_epochs,
                                 strategy=strategy,
                                 max_size=FLAGS.eval_size)
    # dev_steps = dev_specs['size'] // FLAGS.batch_size

    log_dir = os.path.join(FLAGS.tb_log_dir,
                           datetime.now().strftime('%Y%m%d-%H%M%S'))

    with tf.summary.create_file_writer(log_dir).as_default():

        hp.hparams(hparams)

        if FLAGS.mode == 'train':

            train_dataset, _ = get_dataset(FLAGS.data_dir,
                                           'train',
                                           batch_size=FLAGS.batch_size,
                                           n_epochs=FLAGS.n_epochs,
                                           strategy=strategy)
            # train_steps = train_specs['size'] // FLAGS.batch_size

            os.makedirs(FLAGS.output_dir, exist_ok=True)
            checkpoint_template = os.path.join(
                FLAGS.output_dir, 'checkpoint_{step}_{val_loss:.4f}.hdf5')

            run_training(model=model,
                         optimizer=optimizer,
                         loss_fn=loss_fn,
                         train_dataset=train_dataset,
                         batch_size=FLAGS.batch_size,
                         n_epochs=FLAGS.n_epochs,
                         checkpoint_template=checkpoint_template,
                         strategy=strategy,
                         steps_per_log=FLAGS.steps_per_log,
                         steps_per_checkpoint=FLAGS.steps_per_checkpoint,
                         eval_dataset=dev_dataset,
                         train_metrics=[],
                         eval_metrics=[accuracy_fn, cer_fn, wer_fn],
                         gpus=gpus)

        elif FLAGS.mode == 'eval' or FLAGS.mode == 'test':

            if FLAGS.checkpoint is None:
                raise Exception(
                    'You must provide a checkpoint to perform eval.')

            if FLAGS.mode == 'test':
                dataset, test_specs = get_dataset(FLAGS.data_dir,
                                                  'test',
                                                  batch_size=FLAGS.batch_size,
                                                  n_epochs=FLAGS.n_epochs)
            else:
                dataset = dev_dataset

            eval_start_time = time.time()

            eval_loss, eval_metrics_results = run_evaluate(
                model=model,
                optimizer=optimizer,
                loss_fn=loss_fn,
                eval_dataset=dataset,
                batch_size=FLAGS.batch_size,
                strategy=strategy,
                metrics=[accuracy_fn, cer_fn, wer_fn],
                gpus=gpus)

            validation_log_str = 'VALIDATION RESULTS: Time: {:.4f}, Loss: {:.4f}'.format(
                time.time() - eval_start_time, eval_loss)
            for metric_name, metric_result in eval_metrics_results.items():
                validation_log_str += ', {}: {:.4f}'.format(
                    metric_name, metric_result)

            print(validation_log_str)
Example #25
    def _train_n_steps(self, num_steps: int):
        """Runs training for `num_steps` steps.

    Also prints/logs updates about training progress, and summarizes training
    output (if output is returned from `self.trainer.train()`, and if
    `self.summary_dir` is set).

    Args:
      num_steps: An integer specifying how many steps of training to run.

    Raises:
      RuntimeError: If `global_step` is not properly incremented by `num_steps`
        after calling `self.trainer.train(num_steps)`.
    """
        if not self.step_timer:
            self.step_timer = StepTimer(self.global_step)
        current_step = self.global_step.numpy()

        with self.summary_manager.summary_writer().as_default():
            should_record = False  # Allows static optimization in no-summary cases.
            write_hparams = False
            if self.summary_interval:
                # Create a predicate to determine when summaries should be written.
                should_record = lambda: (
                    self.global_step % self.summary_interval == 0)
                write_hparams = lambda: (self.global_step == 0)
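            # Under record_if(write_hparams), the hparams summary is written
            # only at global step 0 (or never, if summaries are disabled).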
            with tf.summary.record_if(write_hparams):
                hp.hparams(self.hparams)
            with tf.summary.record_if(should_record):
                num_steps_tensor = tf.convert_to_tensor(num_steps,
                                                        dtype=tf.int32)
                train_output = self.trainer.train(num_steps_tensor)

        # Verify that global_step was updated properly, then update current_step.
        expected_step = current_step + num_steps
        if self.global_step.numpy() != expected_step:
            message = (
                f"`trainer.train({num_steps})` did not update `global_step` by "
                f"{num_steps}. Old value was {current_step}, expected updated value "
                f"to be {expected_step}, but it was {self.global_step.numpy()}."
            )
            logging.warning(message)
            return

        train_output = train_output or {}
        for action in self.train_actions:
            action(train_output)
        train_output = tf.nest.map_structure(utils.get_value, train_output)

        current_step = expected_step
        steps_per_second = self.step_timer.steps_per_second()
        _log(f"train | step: {current_step: 6d} | "
             f"steps/sec: {steps_per_second: 6.1f} | "
             f"output: {_format_output(train_output)}")

        train_output["steps_per_second"] = steps_per_second

        # Create logs matching the tensorboard log parser format;
        # see tensorboard_for_parser.md.
        train_output["global_step/sec"] = steps_per_second
        train_output[
            "examples/sec"] = steps_per_second * self.hparams["batch_size"]
        train_output["loss"] = train_output["total_loss"]

        self.summary_manager.write_summaries(train_output)
        self.summary_manager.flush()
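A minimal, hedged sketch of the gating pattern used above: `hp.hparams` emits a regular summary event, so wrapping it in `tf.summary.record_if` suppresses it except at step 0. The log directory and hyperparameter values below are illustrative assumptions, not taken from the example.

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

global_step = tf.Variable(0, dtype=tf.int64)
writer = tf.summary.create_file_writer('/tmp/logdir')  # hypothetical path
with writer.as_default():
    # hp.hparams writes through the default summary writer, so it is
    # silenced unless the record_if predicate holds -- here, only step 0.
    with tf.summary.record_if(lambda: global_step == 0):
        hp.hparams({'batch_size': 32, 'learning_rate': 1e-3})
writer.flush()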
Example #26
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)

    # Set seeds
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    torch.manual_seed(FLAGS.seed)

    # Resolve CUDA device(s)
    if FLAGS.use_gpu and torch.cuda.is_available():
        print('Running model with CUDA.')
        device = 'cuda:0'
    else:
        print('Running model on CPU.')
        device = 'cpu'

    train_batch_size = FLAGS.train_batch_size
    eval_batch_size = FLAGS.eval_batch_size // FLAGS.num_dropout_samples_eval

    # As per the Kaggle challenge, we have split sizes:
    # train: 35,126
    # validation: 10,906
    # test: 42,670
    ds_info = tfds.builder('diabetic_retinopathy_detection').info
    steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
    steps_per_validation_eval = (ds_info.splits['validation'].num_examples //
                                 eval_batch_size)
    steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size

    data_dir = FLAGS.data_dir

    dataset_train_builder = ub.datasets.get('diabetic_retinopathy_detection',
                                            split='train',
                                            data_dir=data_dir)
    dataset_train = dataset_train_builder.load(batch_size=train_batch_size)

    dataset_validation_builder = ub.datasets.get(
        'diabetic_retinopathy_detection',
        split='validation',
        data_dir=data_dir,
        is_training=not FLAGS.use_validation)
    validation_batch_size = (eval_batch_size
                             if FLAGS.use_validation else train_batch_size)
    dataset_validation = dataset_validation_builder.load(
        batch_size=validation_batch_size)
    if not FLAGS.use_validation:
        # Note that this will not create any mixed batches of train and validation
        # images.
        dataset_train = dataset_train.concatenate(dataset_validation)

    dataset_test_builder = ub.datasets.get('diabetic_retinopathy_detection',
                                           split='test',
                                           data_dir=data_dir)
    dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    # MC Dropout ResNet50 based on PyTorch Vision implementation
    logging.info('Building Torch ResNet-50 MC Dropout model.')
    model = ub.models.resnet50_dropout_torch(num_classes=1,
                                             dropout_rate=FLAGS.dropout_rate)
    logging.info('Model number of weights: %s',
                 torch_utils.count_parameters(model))

    # SGD with Nesterov momentum; LR follows cosine annealing with warm restarts.
    base_lr = FLAGS.base_learning_rate
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=base_lr,
                                momentum=1.0 - FLAGS.one_minus_momentum,
                                nesterov=True)
    steps_to_lr_peak = int(steps_per_epoch * FLAGS.lr_warmup_epochs)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, steps_to_lr_peak, T_mult=2)

    model = model.to(device)

    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=False,
        num_bins=FLAGS.num_bins,
        use_validation=FLAGS.use_validation)

    # Define additional metrics that would fail in a TF TPU implementation.
    metrics.update(
        utils.get_diabetic_retinopathy_cpu_metrics(
            use_validation=FLAGS.use_validation))

    # Binary cross-entropy loss on sigmoid probabilities (no class reweighting).
    loss_fn = torch.nn.BCELoss()
    sigmoid = torch.nn.Sigmoid()
    max_steps = steps_per_epoch * FLAGS.train_epochs
    image_h = 512
    image_w = 512

    def run_train_epoch(iterator):
        def train_step(inputs):
            images = inputs['features']
            labels = inputs['labels']
            images = torch.from_numpy(
                images._numpy()).view(  # pylint: disable=protected-access
                    train_batch_size, 3, image_h, image_w).to(device)
            labels = torch.from_numpy(
                labels._numpy()).to(device).float()  # pylint: disable=protected-access

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward
            logits = model(images)
            probs = sigmoid(logits).squeeze(-1)

            # Add L2 regularization loss to NLL
            negative_log_likelihood = loss_fn(probs, labels)
            l2_loss = sum(p.pow(2.0).sum() for p in model.parameters())
            loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)

            # Backward/optimizer
            loss.backward()
            optimizer.step()

            # Convert to NumPy for metrics updates
            loss = loss.detach()
            negative_log_likelihood = negative_log_likelihood.detach()
            labels = labels.detach()
            probs = probs.detach()

            if device != 'cpu':
                loss = loss.cpu()
                negative_log_likelihood = negative_log_likelihood.cpu()
                labels = labels.cpu()
                probs = probs.cpu()

            loss = loss.numpy()
            negative_log_likelihood = negative_log_likelihood.numpy()
            labels = labels.numpy()
            probs = probs.numpy()

            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, probs)
            metrics['train/auprc'].update_state(labels, probs)
            metrics['train/auroc'].update_state(labels, probs)
            metrics['train/ece'].add_batch(probs, label=labels)

        for step in range(steps_per_epoch):
            train_step(next(iterator))

            if step % 100 == 0:
                current_step = epoch * steps_per_epoch + step + 1
                time_elapsed = time.time() - start_time
                steps_per_sec = float(current_step) / time_elapsed
                eta_seconds = ((max_steps - current_step) / steps_per_sec
                               if steps_per_sec else 0)
                message = (
                    '{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                    'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                        current_step / max_steps, epoch + 1,
                        FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                        time_elapsed / 60))
                logging.info(message)

    def run_eval_epoch(iterator, dataset_split, num_steps):
        def eval_step(inputs, model):
            images = inputs['features']
            labels = inputs['labels']
            images = torch.from_numpy(
                images._numpy()).view(  # pylint: disable=protected-access
                    eval_batch_size, 3, image_h, image_w).to(device)
            labels = torch.from_numpy(
                labels._numpy()).to(device).float().unsqueeze(-1)  # pylint: disable=protected-access

            with torch.no_grad():
                logits = torch.stack(
                    [model(images) for _ in range(FLAGS.num_dropout_samples_eval)],
                    dim=-1)

            # Logits dimension is (batch_size, 1, num_dropout_samples);
            # squeeze the singleton class dimension explicitly so a batch
            # size of 1 is preserved.
            logits = logits.squeeze(1)

            # It is now (batch_size, num_dropout_samples).
            probs = sigmoid(logits)

            # labels_tiled shape is (batch_size, num_dropout_samples).
            labels_tiled = torch.tile(labels,
                                      (1, FLAGS.num_dropout_samples_eval))

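            # NOTE: the logsumexp marginalization below assumes per-element
            # log-likelihoods of shape (batch_size, num_dropout_samples);
            # with BCELoss's default reduction='mean' it would collapse to a
            # single mean-BCE scalar.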
            log_likelihoods = -loss_fn(probs, labels_tiled)
            negative_log_likelihood = torch.mean(
                -torch.logsumexp(log_likelihoods, dim=-1) +
                torch.log(torch.tensor(float(FLAGS.num_dropout_samples_eval))))

            probs = torch.mean(probs, dim=-1)

            # Convert to NumPy for metrics updates
            negative_log_likelihood = negative_log_likelihood.detach()
            labels = labels.detach()
            probs = probs.detach()

            if device != 'cpu':
                negative_log_likelihood = negative_log_likelihood.cpu()
                labels = labels.cpu()
                probs = probs.cpu()

            negative_log_likelihood = negative_log_likelihood.numpy()
            labels = labels.numpy()
            probs = probs.numpy()

            metrics[dataset_split + '/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics[dataset_split + '/accuracy'].update_state(labels, probs)
            metrics[dataset_split + '/auprc'].update_state(labels, probs)
            metrics[dataset_split + '/auroc'].update_state(labels, probs)
            metrics[dataset_split + '/ece'].add_batch(probs, label=labels)

        for _ in range(num_steps):
            eval_step(next(iterator), model=model)

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})
    start_time = time.time()
    initial_epoch = 0
    train_iterator = iter(dataset_train)
    model.train()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch + 1)

        run_train_epoch(train_iterator)

        if FLAGS.use_validation:
            validation_iterator = iter(dataset_validation)
            logging.info('Starting to run validation eval at epoch: %s',
                         epoch + 1)
            run_eval_epoch(validation_iterator, 'validation',
                           steps_per_validation_eval)

        test_iterator = iter(dataset_test)
        logging.info('Starting to run test eval at epoch: %s', epoch + 1)
        test_start_time = time.time()
        run_eval_epoch(test_iterator, 'test', steps_per_test_eval)
        ms_per_example = (time.time() -
                          test_start_time) * 1e6 / eval_batch_size
        metrics['test/ms_per_example'].update_state(ms_per_example)

        # Step scheduler
        scheduler.step()

        # Log and write to summary the epoch metrics
        utils.log_epoch_metrics(metrics=metrics, use_tpu=False)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):

            checkpoint_path = os.path.join(FLAGS.output_dir,
                                           f'model_{epoch + 1}.pt')
            torch_utils.checkpoint_torch_model(model=model,
                                               optimizer=optimizer,
                                               epoch=epoch + 1,
                                               checkpoint_path=checkpoint_path)
            logging.info('Saved Torch checkpoint to %s', checkpoint_path)

    final_checkpoint_path = os.path.join(FLAGS.output_dir,
                                         f'model_{FLAGS.train_epochs}.pt')
    torch_utils.checkpoint_torch_model(model=model,
                                       optimizer=optimizer,
                                       epoch=FLAGS.train_epochs,
                                       checkpoint_path=final_checkpoint_path)
    logging.info('Saved last checkpoint to %s', final_checkpoint_path)

    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'dropout_rate': FLAGS.dropout_rate,
            'l2': FLAGS.l2,
            'lr_warmup_epochs': FLAGS.lr_warmup_epochs
        })
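For reference, a self-contained sketch of the marginal NLL the eval step above computes, assuming per-sample log-likelihoods of shape (batch_size, num_samples): nll = -log((1/S) * sum_s exp(log_lik_s)), evaluated stably with logsumexp.

import torch

def mc_marginal_nll(log_likelihoods: torch.Tensor) -> torch.Tensor:
    # log_likelihoods: (batch_size, num_samples) per-sample log-likelihoods.
    num_samples = log_likelihoods.shape[-1]
    # -log( (1/S) * sum_s exp(log_lik_s) ), averaged over the batch.
    return torch.mean(
        -torch.logsumexp(log_likelihoods, dim=-1)
        + torch.log(torch.tensor(float(num_samples))))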
Example #27
def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  test_batch_size = batch_size
  data_buffer_size = batch_size * 10

  train_dataset_builder = ds.WikipediaToxicityDataset(
      split='train',
      data_dir=FLAGS.in_dataset_dir,
      shuffle_buffer_size=data_buffer_size)
  ind_dataset_builder = ds.WikipediaToxicityDataset(
      split='test',
      data_dir=FLAGS.in_dataset_dir,
      shuffle_buffer_size=data_buffer_size)
  ood_dataset_builder = ds.CivilCommentsDataset(
      split='test',
      data_dir=FLAGS.ood_dataset_dir,
      shuffle_buffer_size=data_buffer_size)
  ood_identity_dataset_builder = ds.CivilCommentsIdentitiesDataset(
      split='test',
      data_dir=FLAGS.identity_dataset_dir,
      shuffle_buffer_size=data_buffer_size)

  train_dataset_builders = {
      'wikipedia_toxicity_subtypes': train_dataset_builder
  }
  test_dataset_builders = {
      'ind': ind_dataset_builder,
      'ood': ood_dataset_builder,
      'ood_identity': ood_identity_dataset_builder,
  }
  if FLAGS.prediction_mode and FLAGS.identity_prediction:
    for dataset_name in utils.IDENTITY_LABELS:
      if utils.NUM_EXAMPLES[dataset_name]['test'] > 100:
        test_dataset_builders[dataset_name] = ds.CivilCommentsIdentitiesDataset(
            split='test',
            data_dir=os.path.join(
                FLAGS.identity_specific_dataset_dir, dataset_name),
            shuffle_buffer_size=data_buffer_size)
    for dataset_name in utils.IDENTITY_TYPES:
      if utils.NUM_EXAMPLES[dataset_name]['test'] > 100:
        test_dataset_builders[dataset_name] = ds.CivilCommentsIdentitiesDataset(
            split='test',
            data_dir=os.path.join(
                FLAGS.identity_type_dataset_dir, dataset_name),
            shuffle_buffer_size=data_buffer_size)

  class_weight = utils.create_class_weight(
      train_dataset_builders, test_dataset_builders)
  logging.info('class_weight: %s', str(class_weight))

  ds_info = train_dataset_builder.tfds_info
  # Positive and negative classes.
  num_classes = ds_info.metadata['num_classes']

  train_datasets = {}
  dataset_steps_per_epoch = {}
  total_steps_per_epoch = 0

  # TODO(jereliu): Apply strategy.experimental_distribute_dataset to the
  # dataset_builders.
  for dataset_name, dataset_builder in train_dataset_builders.items():
    train_datasets[dataset_name] = dataset_builder.load(
        batch_size=FLAGS.per_core_batch_size)
    dataset_steps_per_epoch[dataset_name] = (
        dataset_builder.num_examples // batch_size)
    total_steps_per_epoch += dataset_steps_per_epoch[dataset_name]

  test_datasets = {}
  steps_per_eval = {}
  for dataset_name, dataset_builder in test_dataset_builders.items():
    test_datasets[dataset_name] = dataset_builder.load(
        batch_size=test_batch_size)
    if dataset_name in ['ind', 'ood', 'ood_identity']:
      steps_per_eval[dataset_name] = (
          dataset_builder.num_examples // test_batch_size)
    else:
      steps_per_eval[dataset_name] = (
          utils.NUM_EXAMPLES[dataset_name]['test'] // test_batch_size)

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building BERT %s model', FLAGS.bert_model_type)
    logging.info('use_gp_layer=%s', FLAGS.use_gp_layer)
    logging.info('use_spec_norm_att=%s', FLAGS.use_spec_norm_att)
    logging.info('use_spec_norm_ffn=%s', FLAGS.use_spec_norm_ffn)
    logging.info('use_layer_norm_att=%s', FLAGS.use_layer_norm_att)
    logging.info('use_layer_norm_ffn=%s', FLAGS.use_layer_norm_ffn)

    bert_config_dir, bert_ckpt_dir = utils.resolve_bert_ckpt_and_config_dir(
        FLAGS.bert_model_type, FLAGS.bert_dir, FLAGS.bert_config_dir,
        FLAGS.bert_ckpt_dir)
    bert_config = utils.create_config(bert_config_dir)

    gp_layer_kwargs = dict(
        num_inducing=FLAGS.gp_hidden_dim,
        gp_kernel_scale=FLAGS.gp_scale,
        gp_output_bias=FLAGS.gp_bias,
        normalize_input=FLAGS.gp_input_normalization,
        gp_cov_momentum=FLAGS.gp_cov_discount_factor,
        gp_cov_ridge_penalty=FLAGS.gp_cov_ridge_penalty)
    spec_norm_kwargs = dict(
        iteration=FLAGS.spec_norm_iteration,
        norm_multiplier=FLAGS.spec_norm_bound)

    model, bert_encoder = ub.models.SngpBertBuilder(
        num_classes=num_classes,
        bert_config=bert_config,
        gp_layer_kwargs=gp_layer_kwargs,
        spec_norm_kwargs=spec_norm_kwargs,
        use_gp_layer=FLAGS.use_gp_layer,
        use_spec_norm_att=FLAGS.use_spec_norm_att,
        use_spec_norm_ffn=FLAGS.use_spec_norm_ffn,
        use_layer_norm_att=FLAGS.use_layer_norm_att,
        use_layer_norm_ffn=FLAGS.use_layer_norm_ffn,
        use_spec_norm_plr=FLAGS.use_spec_norm_plr)
    # Create an AdamW optimizer with beta_2=0.999, epsilon=1e-6.
    optimizer = utils.create_optimizer(
        FLAGS.base_learning_rate,
        steps_per_epoch=total_steps_per_epoch,
        epochs=FLAGS.train_epochs,
        warmup_proportion=FLAGS.warmup_proportion,
        beta_1=1.0 - FLAGS.one_minus_momentum)

    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.Accuracy(),
        'train/accuracy_weighted': tf.keras.metrics.Accuracy(),
        'train/auroc': tf.keras.metrics.AUC(),
        'train/loss': tf.keras.metrics.Mean(),
        'train/ece': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'train/precision': tf.keras.metrics.Precision(),
        'train/recall': tf.keras.metrics.Recall(),
        'train/f1': tfa_metrics.F1Score(
            num_classes=num_classes, average='micro',
            threshold=FLAGS.ece_label_threshold),
    }

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    if FLAGS.prediction_mode:
      latest_checkpoint = tf.train.latest_checkpoint(FLAGS.eval_checkpoint_dir)
    else:
      latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // total_steps_per_epoch
    else:
      # load BERT from initial checkpoint
      bert_encoder, _, _ = utils.load_bert_weight_from_ckpt(
          bert_model=bert_encoder,
          bert_ckpt_dir=bert_ckpt_dir,
          repl_patterns=ub.models.bert_sngp.CHECKPOINT_REPL_PATTERNS)
      logging.info('Loaded BERT checkpoint %s', bert_ckpt_dir)

    metrics.update({
        'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
        'test/auroc':
            tf.keras.metrics.AUC(curve='ROC'),
        'test/aupr':
            tf.keras.metrics.AUC(curve='PR'),
        'test/brier':
            tf.keras.metrics.MeanSquaredError(),
        'test/brier_weighted':
            tf.keras.metrics.MeanSquaredError(),
        'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/acc':
            tf.keras.metrics.Accuracy(),
        'test/acc_weighted':
            tf.keras.metrics.Accuracy(),
        'test/eval_time':
            tf.keras.metrics.Mean(),
        'test/stddev':
            tf.keras.metrics.Mean(),
        'test/precision':
            tf.keras.metrics.Precision(),
        'test/recall':
            tf.keras.metrics.Recall(),
        'test/f1':
            tfa_metrics.F1Score(
                num_classes=num_classes,
                average='micro',
                threshold=FLAGS.ece_label_threshold),
        'test/calibration_auroc':
            tc_metrics.CalibrationAUC(curve='ROC'),
        'test/calibration_auprc':
            tc_metrics.CalibrationAUC(curve='PR')
    })
    for fraction in FLAGS.fractions:
      metrics.update({
          'test_collab_acc/collab_acc_{}'.format(fraction):
              rm.metrics.OracleCollaborativeAccuracy(
                  fraction=float(fraction), num_bins=FLAGS.num_bins)
      })
      metrics.update({
          'test_abstain_prec/abstain_prec_{}'.format(fraction):
              tc_metrics.AbstainPrecision(abstain_fraction=float(fraction))
      })
      metrics.update({
          'test_abstain_recall/abstain_recall_{}'.format(fraction):
              tc_metrics.AbstainRecall(abstain_fraction=float(fraction))
      })

    for dataset_name, test_dataset in test_datasets.items():
      if dataset_name != 'ind':
        metrics.update({
            'test/nll_{}'.format(dataset_name):
                tf.keras.metrics.Mean(),
            'test/auroc_{}'.format(dataset_name):
                tf.keras.metrics.AUC(curve='ROC'),
            'test/aupr_{}'.format(dataset_name):
                tf.keras.metrics.AUC(curve='PR'),
            'test/brier_{}'.format(dataset_name):
                tf.keras.metrics.MeanSquaredError(),
            'test/brier_weighted_{}'.format(dataset_name):
                tf.keras.metrics.MeanSquaredError(),
            'test/ece_{}'.format(dataset_name):
                rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/acc_{}'.format(dataset_name):
                tf.keras.metrics.Accuracy(),
            'test/acc_weighted_{}'.format(dataset_name):
                tf.keras.metrics.Accuracy(),
            'test/eval_time_{}'.format(dataset_name):
                tf.keras.metrics.Mean(),
            'test/stddev_{}'.format(dataset_name):
                tf.keras.metrics.Mean(),
            'test/precision_{}'.format(dataset_name):
                tf.keras.metrics.Precision(),
            'test/recall_{}'.format(dataset_name):
                tf.keras.metrics.Recall(),
            'test/f1_{}'.format(dataset_name):
                tfa_metrics.F1Score(
                    num_classes=num_classes,
                    average='micro',
                    threshold=FLAGS.ece_label_threshold),
            'test/calibration_auroc_{}'.format(dataset_name):
                tc_metrics.CalibrationAUC(curve='ROC'),
            'test/calibration_auprc_{}'.format(dataset_name):
                tc_metrics.CalibrationAUC(curve='PR'),
        })
        for fraction in FLAGS.fractions:
          metrics.update({
              'test_collab_acc/collab_acc_{}_{}'.format(fraction, dataset_name):
                  rm.metrics.OracleCollaborativeAccuracy(
                      fraction=float(fraction), num_bins=FLAGS.num_bins)
          })
          metrics.update({
              'test_abstain_prec/abstain_prec_{}_{}'.format(
                  fraction, dataset_name):
                  tc_metrics.AbstainPrecision(abstain_fraction=float(fraction))
          })
          metrics.update({
              'test_abstain_recall/abstain_recall_{}_{}'.format(
                  fraction, dataset_name):
                  tc_metrics.AbstainRecall(abstain_fraction=float(fraction))
          })

  @tf.function
  def generate_sample_weight(labels, class_weight, label_threshold=0.7):
    """Generate sample weight for weighted accuracy calculation."""
    if label_threshold != 0.7:
      logging.warning('The class weights were computed with `label_threshold` '
                      '= 0.7, the value recommended by the Jigsaw Conversation '
                      'AI team; weighted accuracy/brier metrics will be '
                      'meaningless for any other threshold.')
    labels_int = tf.cast(labels > label_threshold, tf.int32)
    sample_weight = tf.gather(class_weight, labels_int)
    return sample_weight

  @tf.function
  def train_step(iterator, dataset_name, num_steps):
    """Training StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      features, labels, _ = utils.create_feature_and_label(inputs)

      with tf.GradientTape() as tape:
        logits = model(features, training=True)

        if isinstance(logits, (list, tuple)):
          # If model returns a tuple of (logits, covmat), extract logits
          logits, _ = logits
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)

        loss_logits = tf.squeeze(logits, axis=1)
        if FLAGS.loss_type == 'cross_entropy':
          logging.info('Using cross entropy loss')
          negative_log_likelihood = tf.nn.sigmoid_cross_entropy_with_logits(
              labels, loss_logits)
        elif FLAGS.loss_type == 'focal_cross_entropy':
          logging.info('Using focal cross entropy loss')
          negative_log_likelihood = tfa_losses.sigmoid_focal_crossentropy(
              labels,
              loss_logits,
              alpha=FLAGS.focal_loss_alpha,
              gamma=FLAGS.focal_loss_gamma,
              from_logits=True)
        elif FLAGS.loss_type == 'mse':
          logging.info('Using mean squared error loss')
          loss_probs = tf.nn.sigmoid(loss_logits)
          negative_log_likelihood = tf.keras.losses.mean_squared_error(
              labels, loss_probs)
        elif FLAGS.loss_type == 'mae':
          logging.info('Using mean absolute error loss')
          loss_probs = tf.nn.sigmoid(loss_logits)
          negative_log_likelihood = tf.keras.losses.mean_absolute_error(
              labels, loss_probs)

        negative_log_likelihood = tf.reduce_mean(negative_log_likelihood)

        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + l2_loss
        # Scale the loss, since TPUStrategy sum-reduces gradients across replicas.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

      probs = tf.nn.sigmoid(logits)
      # Cast labels to discrete for ECE computation.
      ece_labels = tf.cast(labels > FLAGS.ece_label_threshold, tf.float32)
      one_hot_labels = tf.one_hot(tf.cast(ece_labels, tf.int32),
                                  depth=num_classes)
      ece_probs = tf.concat([1. - probs, probs], axis=1)
      auc_probs = tf.squeeze(probs, axis=1)
      pred_labels = tf.math.argmax(ece_probs, axis=-1)

      sample_weight = generate_sample_weight(
          labels, class_weight['train/{}'.format(dataset_name)],
          FLAGS.ece_label_threshold)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, pred_labels)
      metrics['train/accuracy_weighted'].update_state(
          ece_labels, pred_labels, sample_weight=sample_weight)
      metrics['train/auroc'].update_state(labels, auc_probs)
      metrics['train/loss'].update_state(loss)
      metrics['train/ece'].add_batch(ece_probs, label=ece_labels)
      metrics['train/precision'].update_state(ece_labels, pred_labels)
      metrics['train/recall'].update_state(ece_labels, pred_labels)
      metrics['train/f1'].update_state(one_hot_labels, ece_probs)

    for _ in tf.range(tf.cast(num_steps, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_name):
    """Evaluation StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      features, labels, _ = utils.create_feature_and_label(inputs)

      eval_start_time = time.time()
      # Compute ensemble prediction over Monte Carlo forward-pass samples.
      logits_list = []
      stddev_list = []
      for _ in range(FLAGS.num_mc_samples):
        logits = model(features, training=False)

        if isinstance(logits, (list, tuple)):
          # If model returns a tuple of (logits, covmat), extract both.
          logits, covmat = logits
        else:
          covmat = tf.eye(test_batch_size)

        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
          covmat = tf.cast(covmat, tf.float32)

        logits = ed.layers.utils.mean_field_logits(
            logits, covmat, mean_field_factor=FLAGS.gp_mean_field_factor)
        stddev = tf.sqrt(tf.linalg.diag_part(covmat))

        logits_list.append(logits)
        stddev_list.append(stddev)

      eval_time = (time.time() - eval_start_time) / FLAGS.per_core_batch_size
      # Logits dimension is (num_samples, batch_size, num_classes).
      logits_list = tf.stack(logits_list, axis=0)
      stddev_list = tf.stack(stddev_list, axis=0)

      stddev = tf.reduce_mean(stddev_list, axis=0)
      probs_list = tf.nn.sigmoid(logits_list)
      probs = tf.reduce_mean(probs_list, axis=0)
      # Cast labels to discrete for ECE computation.
      ece_labels = tf.cast(labels > FLAGS.ece_label_threshold, tf.float32)
      one_hot_labels = tf.one_hot(tf.cast(ece_labels, tf.int32),
                                  depth=num_classes)
      ece_probs = tf.concat([1. - probs, probs], axis=1)
      pred_labels = tf.math.argmax(ece_probs, axis=-1)
      auc_probs = tf.squeeze(probs, axis=1)

      # Use normalized binary predictive variance as the confidence score.
      # Since the prediction variance p*(1-p) is within range (0, 0.25),
      # normalize it by maximum value so the confidence is between (0, 1).
      calib_confidence = 1. - probs * (1. - probs) / .25

      ce = tf.nn.sigmoid_cross_entropy_with_logits(
          labels=tf.broadcast_to(
              labels, [FLAGS.num_mc_samples, labels.shape[0]]),
          logits=tf.squeeze(logits_list, axis=-1)
      )
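      # Marginal NLL over the MC samples: -log((1/S) * sum_s exp(-ce_s)),
      # computed stably via logsumexp.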
      negative_log_likelihood = -tf.reduce_logsumexp(
          -ce, axis=0) + tf.math.log(float(FLAGS.num_mc_samples))
      negative_log_likelihood = tf.reduce_mean(negative_log_likelihood)

      sample_weight = generate_sample_weight(
          labels, class_weight['test/{}'.format(dataset_name)],
          FLAGS.ece_label_threshold)
      if dataset_name == 'ind':
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/auroc'].update_state(labels, auc_probs)
        metrics['test/aupr'].update_state(labels, auc_probs)
        metrics['test/brier'].update_state(labels, auc_probs)
        metrics['test/brier_weighted'].update_state(
            tf.expand_dims(labels, -1), probs, sample_weight=sample_weight)
        metrics['test/ece'].add_batch(ece_probs, label=ece_labels)
        metrics['test/acc'].update_state(ece_labels, pred_labels)
        metrics['test/acc_weighted'].update_state(
            ece_labels, pred_labels, sample_weight=sample_weight)
        metrics['test/eval_time'].update_state(eval_time)
        metrics['test/stddev'].update_state(stddev)
        metrics['test/precision'].update_state(ece_labels, pred_labels)
        metrics['test/recall'].update_state(ece_labels, pred_labels)
        metrics['test/f1'].update_state(one_hot_labels, ece_probs)
        metrics['test/calibration_auroc'].update_state(ece_labels, pred_labels,
                                                       calib_confidence)
        metrics['test/calibration_auprc'].update_state(ece_labels, pred_labels,
                                                       calib_confidence)
        for fraction in FLAGS.fractions:
          metrics['test_collab_acc/collab_acc_{}'.format(
              fraction)].add_batch(ece_probs, label=ece_labels)
          metrics['test_abstain_prec/abstain_prec_{}'.format(
              fraction)].update_state(ece_labels, pred_labels, calib_confidence)
          metrics['test_abstain_recall/abstain_recall_{}'.format(
              fraction)].update_state(ece_labels, pred_labels, calib_confidence)

      else:
        metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        metrics['test/auroc_{}'.format(dataset_name)].update_state(
            labels, auc_probs)
        metrics['test/aupr_{}'.format(dataset_name)].update_state(
            labels, auc_probs)
        metrics['test/brier_{}'.format(dataset_name)].update_state(
            labels, auc_probs)
        metrics['test/brier_weighted_{}'.format(dataset_name)].update_state(
            tf.expand_dims(labels, -1), probs, sample_weight=sample_weight)
        metrics['test/ece_{}'.format(dataset_name)].add_batch(
            ece_probs, label=ece_labels)
        metrics['test/acc_{}'.format(dataset_name)].update_state(
            ece_labels, pred_labels)
        metrics['test/acc_weighted_{}'.format(dataset_name)].update_state(
            ece_labels, pred_labels, sample_weight=sample_weight)
        metrics['test/eval_time_{}'.format(dataset_name)].update_state(
            eval_time)
        metrics['test/stddev_{}'.format(dataset_name)].update_state(stddev)
        metrics['test/precision_{}'.format(dataset_name)].update_state(
            ece_labels, pred_labels)
        metrics['test/recall_{}'.format(dataset_name)].update_state(
            ece_labels, pred_labels)
        metrics['test/f1_{}'.format(dataset_name)].update_state(
            one_hot_labels, ece_probs)
        metrics['test/calibration_auroc_{}'.format(dataset_name)].update_state(
            ece_labels, pred_labels, calib_confidence)
        metrics['test/calibration_auprc_{}'.format(dataset_name)].update_state(
            ece_labels, pred_labels, calib_confidence)
        for fraction in FLAGS.fractions:
          metrics['test_collab_acc/collab_acc_{}_{}'.format(
              fraction, dataset_name)].add_batch(ece_probs, label=ece_labels)
          metrics['test_abstain_prec/abstain_prec_{}_{}'.format(
              fraction, dataset_name)].update_state(ece_labels, pred_labels,
                                                    calib_confidence)
          metrics['test_abstain_recall/abstain_recall_{}_{}'.format(
              fraction, dataset_name)].update_state(ece_labels, pred_labels,
                                                    calib_confidence)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def final_eval_step(iterator):
    """Final Evaluation StepFn to save prediction to directory."""

    def step_fn(inputs):
      bert_features, labels, additional_labels = utils.create_feature_and_label(
          inputs)
      logits = model(bert_features, training=False)
      if isinstance(logits, (list, tuple)):
        # If model returns a tuple of (logits, covmat), extract both.
        logits, covmat = logits
      else:
        covmat = tf.eye(test_batch_size)

      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
        covmat = tf.cast(covmat, tf.float32)

      logits = ed.layers.utils.mean_field_logits(
          logits, covmat, mean_field_factor=FLAGS.gp_mean_field_factor)
      features = inputs['input_ids']
      return features, logits, labels, additional_labels

    (per_replica_texts, per_replica_logits, per_replica_labels,
     per_replica_additional_labels) = (
         strategy.run(step_fn, args=(next(iterator),)))

    if strategy.num_replicas_in_sync > 1:
      texts_list = tf.concat(per_replica_texts.values, axis=0)
      logits_list = tf.concat(per_replica_logits.values, axis=0)
      labels_list = tf.concat(per_replica_labels.values, axis=0)
      additional_labels_dict = {}
      for additional_label in utils.IDENTITY_LABELS:
        if additional_label in per_replica_additional_labels:
          additional_labels_dict[additional_label] = tf.concat(
              per_replica_additional_labels[additional_label], axis=0)
    else:
      texts_list = per_replica_texts
      logits_list = per_replica_logits
      labels_list = per_replica_labels
      additional_labels_dict = {}
      for additional_label in utils.IDENTITY_LABELS:
        if additional_label in per_replica_additional_labels:
          additional_labels_dict[
              additional_label] = per_replica_additional_labels[
                  additional_label]

    return texts_list, logits_list, labels_list, additional_labels_dict

  if FLAGS.prediction_mode:
    # Prediction and exit.
    for dataset_name, test_dataset in test_datasets.items():
      test_iterator = iter(test_dataset)  # pytype: disable=wrong-arg-types
      message = 'Final eval on dataset {}'.format(dataset_name)
      logging.info(message)

      texts_all = []
      logits_all = []
      labels_all = []
      additional_labels_all_dict = {}
      if 'identity' in dataset_name:
        for identity_label_name in utils.IDENTITY_LABELS:
          additional_labels_all_dict[identity_label_name] = []

      try:
        with tf.experimental.async_scope():
          for step in range(steps_per_eval[dataset_name]):
            if step % 20 == 0:
              message = 'Starting to run eval step {}/{} of dataset: {}'.format(
                  step, steps_per_eval[dataset_name], dataset_name)
              logging.info(message)

            (text_step, logits_step, labels_step,
             additional_labels_dict_step) = final_eval_step(test_iterator)

            texts_all.append(text_step)
            logits_all.append(logits_step)
            labels_all.append(labels_step)
            if 'identity' in dataset_name:
              for identity_label_name in utils.IDENTITY_LABELS:
                additional_labels_all_dict[identity_label_name].append(
                    additional_labels_dict_step[identity_label_name])

      except (StopIteration, tf.errors.OutOfRangeError):
        tf.experimental.async_clear_error()
        logging.info('Done with eval on %s', dataset_name)

      texts_all = tf.concat(texts_all, axis=0)
      logits_all = tf.concat(logits_all, axis=0)
      labels_all = tf.concat(labels_all, axis=0)
      additional_labels_all = []
      if additional_labels_all_dict:
        for identity_label_name in utils.IDENTITY_LABELS:
          additional_labels_all.append(
              tf.concat(
                  additional_labels_all_dict[identity_label_name], axis=0))
      additional_labels_all = tf.convert_to_tensor(additional_labels_all)

      utils.save_prediction(
          texts_all.numpy(),
          path=os.path.join(FLAGS.output_dir, 'texts_{}'.format(dataset_name)))
      utils.save_prediction(
          labels_all.numpy(),
          path=os.path.join(FLAGS.output_dir, 'labels_{}'.format(dataset_name)))
      utils.save_prediction(
          logits_all.numpy(),
          path=os.path.join(FLAGS.output_dir, 'logits_{}'.format(dataset_name)))
      if 'identity' in dataset_name:
        utils.save_prediction(
            additional_labels_all.numpy(),
            path=os.path.join(FLAGS.output_dir,
                              'additional_labels_{}'.format(dataset_name)))
      logging.info('Done with testing on %s', dataset_name)

  else:
    # Execute train / eval loop.
    start_time = time.time()
    train_iterators = {}
    for dataset_name, train_dataset in train_datasets.items():
      train_iterators[dataset_name] = iter(train_dataset)
    for epoch in range(initial_epoch, FLAGS.train_epochs):
      logging.info('Starting to run epoch: %s', epoch)
      for dataset_name, train_iterator in train_iterators.items():
        try:
          with tf.experimental.async_scope():
            train_step(
                train_iterator,
                dataset_name,
                dataset_steps_per_epoch[dataset_name])

            current_step = (
                epoch * total_steps_per_epoch +
                dataset_steps_per_epoch[dataset_name])
            max_steps = total_steps_per_epoch * FLAGS.train_epochs
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                       'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                           current_step / max_steps, epoch + 1,
                           FLAGS.train_epochs, steps_per_sec,
                           eta_seconds / 60, time_elapsed / 60))
            logging.info(message)

        except (StopIteration, tf.errors.OutOfRangeError):
          tf.experimental.async_clear_error()
          logging.info('Done with training on %s', dataset_name)

      if epoch % FLAGS.evaluation_interval == 0:
        for dataset_name, test_dataset in test_datasets.items():
          test_iterator = iter(test_dataset)
          logging.info('Testing on dataset %s', dataset_name)

          try:
            with tf.experimental.async_scope():
              for step in range(steps_per_eval[dataset_name]):
                if step % 20 == 0:
                  logging.info('Starting to run eval step %s/%s of epoch: %s',
                               step, steps_per_eval[dataset_name], epoch)
                test_step(test_iterator, dataset_name)
          except (StopIteration, tf.errors.OutOfRangeError):
            tf.experimental.async_clear_error()
            logging.info('Done with testing on %s', dataset_name)

        logging.info('Train Loss: %.4f, ECE: %.2f, Accuracy: %.2f',
                     metrics['train/loss'].result(),
                     metrics['train/ece'].result(),
                     metrics['train/accuracy'].result())

        total_results = {
            name: metric.result() for name, metric in metrics.items()
        }
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
          for name, result in total_results.items():
            tf.summary.scalar(name, result, step=epoch + 1)

      for metric in metrics.values():
        metric.reset_states()

      checkpoint_interval = min(FLAGS.checkpoint_interval, FLAGS.train_epochs)
      if checkpoint_interval > 0 and (epoch + 1) % checkpoint_interval == 0:
        checkpoint_name = checkpoint.save(
            os.path.join(FLAGS.output_dir, 'checkpoint'))
        logging.info('Saved checkpoint to %s', checkpoint_name)

    # Save model in SavedModel format on exit.
    final_save_name = os.path.join(FLAGS.output_dir, 'model')
    model.save(final_save_name)
    logging.info('Saved model to %s', final_save_name)
  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'gp_mean_field_factor': FLAGS.gp_mean_field_factor,
    })
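The `ed.layers.utils.mean_field_logits` calls above temper the logits by the GP predictive variance before the sigmoid. Below is a rough sketch of the standard mean-field adjustment, under the assumption that it matches edward2's behavior; the actual implementation may differ in details.

import tensorflow as tf

def mean_field_logits_sketch(logits, covmat, mean_field_factor=1.0):
    # Scale each logit by sqrt(1 + lambda * var): larger predictive
    # variance pulls the sigmoid output toward 0.5.
    variances = tf.linalg.diag_part(covmat)
    scale = tf.sqrt(1. + mean_field_factor * variances)
    return logits / tf.expand_dims(scale, axis=-1)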
Example #28
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    train_dataset_size = ds_info.splits['train'].num_examples
    steps_per_epoch = train_dataset_size // batch_size
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    train_dataset = ub.datasets.get(
        FLAGS.dataset, split=tfds.Split.TRAIN).load(batch_size=batch_size)
    clean_test_dataset = ub.datasets.get(
        FLAGS.dataset, split=tfds.Split.TEST).load(batch_size=batch_size)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_datasets = {
        'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
    }
    if FLAGS.corruptions_interval > 0:
        extra_kwargs = {}
        if FLAGS.dataset == 'cifar100':
            extra_kwargs['data_dir'] = FLAGS.cifar100_c_path
        corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
        for corruption_type in corruption_types:
            for severity in range(1, 6):
                dataset = ub.datasets.get(
                    f'{FLAGS.dataset}_corrupted',
                    corruption_type=corruption_type,
                    severity=severity,
                    split=tfds.Split.TEST,
                    **extra_kwargs).load(batch_size=batch_size)
                test_datasets[f'{corruption_type}_{severity}'] = (
                    strategy.experimental_distribute_dataset(dataset))

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building ResNet model')
        model = ub.models.wide_resnet_variational(
            input_shape=ds_info.features['image'].shape,
            depth=28,
            width_multiplier=10,
            num_classes=num_classes,
            prior_stddev=FLAGS.prior_stddev,
            dataset_size=train_dataset_size,
            stddev_init=FLAGS.stddev_init)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                           for start_epoch_str in FLAGS.lr_decay_epochs]
        lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(
            lr_schedule,
            momentum=1.0 - FLAGS.one_minus_momentum,
            nesterov=True)
        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'train/kl':
            tf.keras.metrics.Mean(),
            'train/kl_scale':
            tf.keras.metrics.Mean(),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        }
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, 6):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        rm.metrics.ExpectedCalibrationError(
                            num_bins=FLAGS.num_bins))

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))

                filtered_variables = []
                for var in model.trainable_variables:
                    # Apply L2 only to the batch-norm parameters and bias
                    # terms; the approximate posterior/prior (fast weight)
                    # parameters are regularized by the KL term instead. Note
                    # that this filter relies on the variable naming scheme.
                    if 'batch_norm' in var.name or 'bias' in var.name:
                        filtered_variables.append(tf.reshape(var, (-1, )))

                l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
                    tf.concat(filtered_variables, axis=0))
                kl = sum(model.losses)
                kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
                kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs
                kl_scale = tf.minimum(1., kl_scale)
                kl_loss = kl_scale * kl

                # Scale the loss, since TPUStrategy sum-reduces gradients across replicas.
                loss = negative_log_likelihood + l2_loss + kl_loss
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.nn.softmax(logits)
            metrics['train/ece'].add_batch(probs, label=labels)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/kl'].update_state(kl)
            metrics['train/kl_scale'].update_state(kl_scale)
            metrics['train/accuracy'].update_state(labels, logits)

        for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            # TODO(trandustin): Use more eval samples only on corrupted predictions;
            # it's expensive but a one-time compute if scheduled post-training.
            if FLAGS.num_eval_samples > 1 and dataset_name != 'clean':
                logits = tf.stack(
                    [model(images, training=False)
                     for _ in range(FLAGS.num_eval_samples)],
                    axis=0)
            else:
                logits = model(images, training=False)
            probs = tf.nn.softmax(logits)
            if FLAGS.num_eval_samples > 1 and dataset_name != 'clean':
                probs = tf.reduce_mean(probs, axis=0)
            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

            if dataset_name == 'clean':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].add_batch(probs, label=labels)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch(
                    probs, label=labels)

        for _ in tf.range(tf.cast(steps_per_eval, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

    train_iterator = iter(train_dataset)
    start_time = time.time()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        train_step(train_iterator)

        current_step = (epoch + 1) * steps_per_epoch
        max_steps = steps_per_epoch * FLAGS.train_epochs
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                       steps_per_sec, eta_seconds / 60, time_elapsed / 60))
        logging.info(message)

        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            logging.info('Starting to run eval at epoch: %s', epoch)
            test_start_time = time.time()
            test_step(test_iterator, dataset_name)
            ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
            metrics['test/ms_per_example'].update_state(ms_per_example)

            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'l2': FLAGS.l2,
            'prior_stddev': FLAGS.prior_stddev,
            'stddev_init': FLAGS.stddev_init,
        })
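
For reference, a minimal sketch of the hparams/metric pairing the snippet above relies on: hp.hparams and the metric scalars must be written under the same run directory for the TensorBoard HParams dashboard to join them (the path and values here are illustrative).

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

logdir = '/tmp/hparams_demo/run-0'  # illustrative run directory
with tf.summary.create_file_writer(logdir).as_default():
    # trial values; keys should match the HParams declared for the experiment
    hp.hparams({'base_learning_rate': 0.1, 'l2': 2e-4})
    # paired metric; this tag is what the dashboard displays per trial
    tf.summary.scalar('test/accuracy', 0.91, step=1)
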
Example #29
0
def train_and_evaluate(
    model,
    num_epochs,
    steps_per_epoch,
    train_data,
    validation_steps,
    eval_data,
    output_dir,
    n_steps_history,
    FLAGS,
    decay_type,
    learning_rate=3e-5,
    s=1,
    n_batch_decay=1,
    metric_accuracy='metric',
):
    """
    Train and evaluate a compiled Keras model, with optional TensorBoard
    logging, checkpointing, and hyperparameter-tuning hooks.
    """
    logging.info('training the model ...')
    model_callbacks = []

    # create meta data dictionary
    dict_model = {}
    dict_data = {}
    dict_parameter = {}
    dict_hardware = {}
    dict_results = {}
    dict_type_job = {}
    dict_software = {}

    # feature toggles (for debugging only)
    activate_tensorboard = True
    activate_hp_tensorboard = False  # set True to re-enable
    activate_lr = False
    save_checkpoints = False  # set True to re-enable
    save_history_per_step = False  # forced True below when tf.summary hp is on
    save_metadata = False  # set True to re-enable
    activate_timing = False  # set True to re-enable
    # the official hp method is not working, so use the tf.summary fallback
    activate_tf_summary_hp = True
    # hardcoded way of reporting the hp metric
    activate_hardcoded_hp = True

    # dependencies
    if activate_tf_summary_hp:
        save_history_per_step = True

    if FLAGS.is_hyperparameter_tuning:
        # get trial ID
        suffix = mu.get_trial_id()

        if suffix == '':
            logging.error('No trial ID for hyperparameter tuning job!')
            FLAGS.is_hyperparameter_tuning = False
        else:
            # callback for hp
            logging.info('Creating a callback to store the metric!')
            if activate_tf_summary_hp:
                hp_metric = mu.HP_metric(metric_accuracy)
                model_callbacks.append(hp_metric)

    if output_dir:
        if activate_tensorboard:
            # tensorflow callback
            log_dir = os.path.join(output_dir, 'tensorboard')
            if FLAGS.is_hyperparameter_tuning:
                log_dir = os.path.join(log_dir, suffix)
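            # update_freq='batch' logs train metrics after every batch;
            # profile_batch='10, 20' profiles batches 10 through 20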
            tensorboard_callback = tf.keras.callbacks.TensorBoard(
                log_dir=log_dir,
                histogram_freq=1,
                embeddings_freq=0,
                write_graph=True,
                update_freq='batch',
                profile_batch='10, 20')
            model_callbacks.append(tensorboard_callback)

        if save_checkpoints:
            # checkpoints callback
            checkpoint_dir = os.path.join(output_dir, 'checkpoint_model')
            if not FLAGS.is_hyperparameter_tuning:
                # not saving the model during hyperparameter tuning
                # checkpoint_dir = os.path.join(checkpoint_dir, suffix)
                checkpoint_prefix = os.path.join(checkpoint_dir,
                                                 'ckpt_{epoch:02d}')
                checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
                    filepath=checkpoint_prefix,
                    verbose=1,
                    save_weights_only=True)
                model_callbacks.append(checkpoint_callback)

    if activate_lr:
        # decay learning rate callback

        # code snippet to make the switching between different learning rate decays possible
        if decay_type == 'exponential':
            decay_fn = mu.exponential_decay(lr0=learning_rate, s=s)
        elif decay_type == 'stepwise':
            decay_fn = mu.step_decay(lr0=learning_rate, s=s)
        elif decay_type == 'timebased':
            decay_fn = mu.time_decay(lr0=learning_rate, s=s)
        else:
            decay_fn = mu.no_decay(lr0=learning_rate)

        # schedule the decay per batch (applied every n_batch_decay batches)
        lr_decay_batch = mu.LearningRateSchedulerPerBatch(decay_fn,
                                                          n_batch_decay,
                                                          verbose=1)
        model_callbacks.append(lr_decay_batch)

        # disabled alternatives: a per-epoch tf.keras.callbacks.LearningRateScheduler,
        # mu.PrintLR() for printing the rate, and mu.LR_per_step() for
        # recording every learning rate

    if save_history_per_step:
        # callback to record the history per step (not per epoch)
        histories_per_step = mu.History_per_step(eval_data, n_steps_history)
        model_callbacks.append(histories_per_step)

    if activate_timing:
        # callback to time each epoch
        timing = mu.TimingCallback()
        model_callbacks.append(timing)

    # log the configured callbacks
    logging.info('model callbacks:\n %s', model_callbacks)

    # train the model
    # time the function
    start_time = time.time()

    logging.info('starting model.fit')
    # verbose = 0 (silent)
    # verbose = 1 (progress bar)
    # verbose = 2 (one line per epoch)
    verbose = 1
    history = model.fit(train_data,
                        epochs=num_epochs,
                        steps_per_epoch=steps_per_epoch,
                        validation_data=eval_data,
                        validation_steps=validation_steps,
                        verbose=verbose,
                        callbacks=model_callbacks)

    # print execution time
    elapsed_time_secs = time.time() - start_time
    logging.info('\nexecution time: {}'.format(
        timedelta(seconds=round(elapsed_time_secs))))

    # check model (model.summary() prints through the given print_fn;
    # it returns None, so don't log its return value)
    model.summary(print_fn=logging.info)
    logging.info('model inputs ={}'.format(model.inputs))
    logging.info('model outputs ={}'.format(model.outputs))

    # debugging only; to be removed
    logging.info('\ndebugging ... :')
    pp.print_info_data(train_data)

    if activate_timing:
        logging.info('timing per epoch:\n{}'.format(
            list(
                map(lambda x: str(timedelta(seconds=round(x))),
                    timing.timing_epoch))))
        logging.info('timing per validation:\n{}'.format(
            list(
                map(lambda x: str(timedelta(seconds=round(x))),
                    timing.timing_valid))))
        logging.info('sum timing over all epochs:\n{}'.format(
            timedelta(seconds=round(sum(timing.timing_epoch)))))

    # for hyperparameter tuning in TensorBoard
    if FLAGS.is_hyperparameter_tuning:
        logging.info('setup hyperparameter tuning!')
        # test
        #params = json.loads(os.environ.get("CLUSTER_SPEC", "{}")).get("job", {})
        #print('debug: CLUSTER_SPEC1:', params)
        #params = json.loads(os.environ.get("CLUSTER_SPEC", "{}")).get("job", {}).get("job_args", {})
        #print('debug: CLUSTER_SPEC2:', params)
        logging.info('debug: os.environ.items(): %s', list(os.environ.items()))
        #
        if activate_hardcoded_hp:
            # trick to bypass ai platform bug
            logging.info('hardcoded hyperparameter tuning!')
            value_accuracy = histories_per_step.accuracies[-1]
            hpt = hypertune.HyperTune()
            hpt.report_hyperparameter_tuning_metric(
                hyperparameter_metric_tag=metric_accuracy,
                metric_value=value_accuracy,
                global_step=0)
        else:
            # should be extracted from /var/hypertune/output.metric
            logging.info('standard hyperparameter tuning!')
            # is this needed ?
            # value_accuracy = histories_per_step.accuracies[-1]

        # look at the content of the file
        path_metric = '/var/hypertune/output.metric'
        logging.info('checking if /var/hypertune/output.metric exists!')
        if os.path.isfile(path_metric):
            logging.info('file {} exists!'.format(path_metric))
            with open(path_metric, 'r') as f:
                logging.info('content of output.metric: {}'.format(f.read()))

        if activate_hp_tensorboard:
            logging.info('setup TensorBoard for hyperparameter tuning!')
            # CAIP
            #params = json.loads(os.environ.get("TF_CONFIG", "{}")).get("job", {}).get("hyperparameters", {}).get("params", {})
            #uCAIP
            # NOTE: CLUSTER_SPEC parses to a dict, but the loop below expects a
            # list of parameter dicts (see the commented .get(...) chain), so
            # this disabled branch likely still needs the nested lookup.
            params = json.loads(
                os.environ.get("CLUSTER_SPEC", "{}")
            )  #.get("job", {}).get("hyperparameters", {}).get("params", {})
            logging.info('debug: CLUSTER_SPEC: %s', params)
            list_hp = []
            hparams = {}
            for el in params:
                hp_dict = dict(el)
                if hp_dict.get('type') == 'DOUBLE':
                    key_hp = hp.HParam(
                        hp_dict.get('parameter_name'),
                        hp.RealInterval(hp_dict.get('min_value'),
                                        hp_dict.get('max_value')))
                    list_hp.append(key_hp)
                    try:
                        hparams[key_hp] = FLAGS[hp_dict.get(
                            'parameter_name')].value
                    except KeyError:
                        logging.error(
                            'hyperparameter key {} doesn\'t exist'.format(
                                hp_dict.get('parameter_name')))

            hparams_dir = os.path.join(output_dir, 'hparams_tuning')
            with tf.summary.create_file_writer(hparams_dir).as_default():
                hp.hparams_config(
                    hparams=list_hp,
                    metrics=[
                        hp.Metric(metric_accuracy,
                                  display_name=metric_accuracy)
                    ],
                )

            hparams_dir = os.path.join(hparams_dir, suffix)
            with tf.summary.create_file_writer(hparams_dir).as_default():
                # record the values used in this trial
                hp.hparams(hparams)
                tf.summary.scalar(metric_accuracy, value_accuracy, step=1)

    if save_history_per_step:
        # save the history in a file
        search = re.search('gs://(.*?)/(.*)', output_dir)
        if search is not None:
            # temporary local folder; to be moved to GCS later
            history_dir = os.path.join('./', model.name)
            os.makedirs(history_dir, exist_ok=True)
        else:
            # locally
            history_dir = os.path.join(output_dir, model.name)
            os.makedirs(history_dir, exist_ok=True)
        logging.debug('history_dir: \n {}'.format(history_dir))
        with open(history_dir + '/history', 'wb') as file:
            model_history = mu.History_trained_model(history.history,
                                                     history.epoch,
                                                     history.params)
            pickle.dump(model_history, file, pickle.HIGHEST_PROTOCOL)
        with open(history_dir + '/history_per_step', 'wb') as file:
            model_history_per_step = mu.History_per_steps_trained_model(
                histories_per_step.steps,
                histories_per_step.losses,
                histories_per_step.accuracies,
                histories_per_step.val_steps,
                histories_per_step.val_losses,
                histories_per_step.val_accuracies,
                0,  # all_learning_rates.all_lr,
                0,  # all_learning_rates.all_lr_alternative,
                0)  # all_learning_rates.all_lr_logs)
            pickle.dump(model_history_per_step, file, pickle.HIGHEST_PROTOCOL)

    if output_dir:
        # save the model
        savemodel_path = os.path.join(output_dir, 'saved_model')

        if not FLAGS.is_hyperparameter_tuning:
            # not saving the model during hyperparameter tuning
            # savemodel_path = os.path.join(savemodel_path, suffix)
            model.save(os.path.join(savemodel_path, model.name))

            model2 = tf.keras.models.load_model(
                os.path.join(savemodel_path, model.name))
            # check the reloaded model (summary prints through print_fn)
            model2.summary(print_fn=logging.info)
            logging.info('model2 input ={}'.format(model2.inputs))
            logging.info('model2 outputs ={}'.format(model2.outputs))

            logging.info('model2 signature outputs ={}'.format(
                model2.signatures['serving_default'].structured_outputs))
            logging.info('model2 inputs ={}'.format(
                model2.signatures['serving_default'].inputs[0]))

        if save_history_per_step:
            # save history
            search = re.search('gs://(.*?)/(.*)', output_dir)
            if search is not None:
                bucket_name = search.group(1)
                blob_name = search.group(2)
                output_folder = blob_name + '/history'
                if FLAGS.is_hyperparameter_tuning:
                    output_folder = os.path.join(output_folder, suffix)
                mu.copy_local_directory_to_gcs(history_dir, bucket_name,
                                               output_folder)

    if save_metadata:
        # add meta data
        dict_model['pretrained_transformer_model'] = FLAGS.pretrained_model_dir
        dict_model['num_classes'] = FLAGS.num_classes

        dict_data['train'] = FLAGS.input_train_tfrecords
        dict_data['eval'] = FLAGS.input_eval_tfrecords

        dict_parameter[
            'use_decay_learning_rate'] = FLAGS.use_decay_learning_rate
        dict_parameter['epochs'] = FLAGS.epochs
        dict_parameter['steps_per_epoch_train'] = FLAGS.steps_per_epoch_train
        dict_parameter['steps_per_epoch_eval'] = FLAGS.steps_per_epoch_eval
        dict_parameter['n_steps_history'] = FLAGS.n_steps_history
        dict_parameter['batch_size_train'] = FLAGS.batch_size_train
        dict_parameter['batch_size_eval'] = FLAGS.batch_size_eval
        dict_parameter['learning_rate'] = FLAGS.learning_rate
        dict_parameter['epsilon'] = FLAGS.epsilon

        dict_hardware['is_tpu'] = FLAGS.use_tpu

        dict_type_job[
            'is_hyperparameter_tuning'] = FLAGS.is_hyperparameter_tuning
        dict_type_job['is_tpu'] = FLAGS.use_tpu

        dict_software['tensorflow'] = tf.__version__
        dict_software['transformer'] = __version__
        dict_software['python'] = sys.version

        # aggregate dictionaries
        dict_all = {
            'model': dict_model,
            'data': dict_data,
            'parameter': dict_parameter,
            'hardware': dict_hardware,
            'results': dict_results,
            'type_job': dict_type_job,
            'software': dict_software
        }

        # save metadata
        search = re.search('gs://(.*?)/(.*)', output_dir)
        if search is not None:
            bucket_name = search.group(1)
            blob_name = search.group(2)
            output_folder = blob_name + '/metadata'

            storage_client = storage.Client()
            bucket = storage_client.bucket(bucket_name)
            blob = bucket.blob(output_folder + '/model_job_metadata.json')
            blob.upload_from_string(data=json.dumps(dict_all),
                                    content_type='application/json')
Example #30
0
def train_and_validate(train_x, train_y, test_x, test_y, hparams):
    unique_items = len(train_y[0])
    model = LSTMRec(
        vocabulary_size=unique_items,
        emb_output_dim=hparams['emb_dim'],
        lstm_units=hparams['lstm_units'],
        lstm_activation=hparams['lstm_activation'],
        lstm_recurrent_activation=hparams['lstm_recurrent_activation'],
        lstm_dropout=hparams['lstm_dropout'],
        lstm_recurrent_dropout=hparams['lstm_recurrent_dropout'],
        dense_activation=hparams['dense_activation'])
    model.compile(optimizer=Adam(learning_rate=hparams['learning_rate'],
                                 beta_1=hparams['adam_beta_1'],
                                 beta_2=hparams['adam_beta_2'],
                                 epsilon=hparams['adam_epsilon']),
                  loss='binary_crossentropy',
                  metrics=[
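                      # with top_k=k, Precision/Recall count the k highest-
                      # probability items as the predicted positives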
                      Precision(top_k=1, name='P_at_1'),
                      Precision(top_k=3, name='P_at_3'),
                      Precision(top_k=5, name='P_at_5'),
                      Precision(top_k=10, name='P_at_10'),
                      Recall(top_k=10, name='R_at_10'),
                      Recall(top_k=50, name='R_at_50'),
                      Recall(top_k=100, name='R_at_100')
                  ])
    hst = model.fit(
        x=train_x,
        y=train_y,
        batch_size=hparams['batch_size'],
        epochs=250,
        callbacks=[
            EarlyStopping(monitor='val_R_at_10',
                          patience=10,
                          mode='max',
                          restore_best_weights=True,
                          verbose=True),
            ModelCheckpoint(filepath=os.path.join(os.pardir, os.pardir,
                                                  'models',
                                                  hparams['run_id'] + '.ckpt'),
                            monitor='val_R_at_10',
                            mode='max',
                            save_best_only=True,
                            save_weights_only=True,
                            verbose=True),
            TensorBoard(log_dir=os.path.join(os.pardir, os.pardir, 'logs',
                                             hparams['run_id']),
                        histogram_freq=1)
        ],
        validation_split=0.2)
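    # EarlyStopping(restore_best_weights=True) left the model at the epoch
    # that maximized val_R_at_10, so evaluate and report from that epoch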
    val_best_epoch = np.argmax(hst.history['val_R_at_10'])
    test_results = model.evaluate(test_x, test_y)
    with tf.summary.create_file_writer(
            os.path.join(os.pardir, os.pardir, 'logs', hparams['run_id'],
                         'hparams')).as_default():
        hp.hparams(hparams)
        tf.summary.scalar('train.final_loss',
                          hst.history["val_loss"][val_best_epoch],
                          step=val_best_epoch)
        tf.summary.scalar('train.final_P_at_1',
                          hst.history["val_P_at_1"][val_best_epoch],
                          step=val_best_epoch)
        tf.summary.scalar('train.final_P_at_3',
                          hst.history["val_P_at_3"][val_best_epoch],
                          step=val_best_epoch)
        tf.summary.scalar('train.final_P_at_5',
                          hst.history["val_P_at_5"][val_best_epoch],
                          step=val_best_epoch)
        tf.summary.scalar('train.final_P_at_10',
                          hst.history["val_P_at_10"][val_best_epoch],
                          step=val_best_epoch)
        tf.summary.scalar('train.final_R_at_10',
                          hst.history["val_R_at_10"][val_best_epoch],
                          step=val_best_epoch)
        tf.summary.scalar('train.final_R_at_50',
                          hst.history["val_R_at_50"][val_best_epoch],
                          step=val_best_epoch)
        tf.summary.scalar('train.final_R_at_100',
                          hst.history["val_R_at_100"][val_best_epoch],
                          step=val_best_epoch)

        tf.summary.scalar('test.final_loss',
                          test_results[0],
                          step=val_best_epoch)
        tf.summary.scalar('test.final_P_at_1',
                          test_results[1],
                          step=val_best_epoch)
        tf.summary.scalar('test.final_P_at_3',
                          test_results[2],
                          step=val_best_epoch)
        tf.summary.scalar('test.final_P_at_5',
                          test_results[3],
                          step=val_best_epoch)
        tf.summary.scalar('test.final_P_at_10',
                          test_results[4],
                          step=val_best_epoch)
        tf.summary.scalar('test.final_R_at_10',
                          test_results[5],
                          step=val_best_epoch)
        tf.summary.scalar('test.final_R_at_50',
                          test_results[6],
                          step=val_best_epoch)
        tf.summary.scalar('test.final_R_at_100',
                          test_results[7],
                          step=val_best_epoch)

    return val_best_epoch, test_results
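
A hypothetical driver for train_and_validate above: declare the search space once with hp.hparams_config, then run one trial per grid combination. The grid values and run ids are illustrative, and the remaining hparams keys (activations, dropouts, Adam settings, batch size) would come from a defaults dict.

import itertools
import os
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

GRID = {
    'emb_dim': [64, 128],
    'lstm_units': [128, 256],
    'learning_rate': [1e-3, 3e-4],
}

# declare the search space and the headline metric once, at the log root
log_root = os.path.join(os.pardir, os.pardir, 'logs')
with tf.summary.create_file_writer(log_root).as_default():
    hp.hparams_config(
        hparams=[hp.HParam(k, hp.Discrete(v)) for k, v in GRID.items()],
        metrics=[hp.Metric('test.final_R_at_10')])

for i, values in enumerate(itertools.product(*GRID.values())):
    trial = dict(zip(GRID.keys(), values))
    trial['run_id'] = 'grid-{:03d}'.format(i)
    # merge with a defaults dict for the remaining keys, then:
    # train_and_validate(train_x, train_y, test_x, test_y, {**defaults, **trial})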