Example 1
def run(trial_dir: str, flag_string: Optional[str]):
  """Run the experiment.

  Args:
    trial_dir: Path to the directory to write checkpoints to and read from.
    flag_string: Optional string used to record what flags the job was run with.
  """
  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)

  if not FLAGS.eval_frequency:
    FLAGS.eval_frequency = FLAGS.log_frequency

  if FLAGS.eval_frequency % FLAGS.log_frequency != 0:
    raise ValueError(
        'log_frequency ({}) must evenly divide eval_frequency '
        '({}).'.format(FLAGS.log_frequency, FLAGS.eval_frequency))

  strategy = ub.strategy_utils.get_strategy(
      FLAGS.tpu, use_tpu=not FLAGS.use_cpu and not FLAGS.use_gpu)
  with strategy.scope():
    _maybe_setup_trial_dir(strategy, trial_dir, flag_string, FLAGS.mode)

    # TODO(znado): pass all dataset and model kwargs.
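    # Fall back to the training batch size when no eval batch size is given.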
    if FLAGS.eval_batch_size is None:
      eval_batch_size = FLAGS.batch_size
    else:
      eval_batch_size = FLAGS.eval_batch_size
    train_dataset_builder = ub.datasets.get(
        dataset_name=FLAGS.dataset_name,
        split='train',
        validation_percent=FLAGS.validation_percent,
        data_dir=FLAGS.data_dir,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size)
    if FLAGS.validation_percent > 0:
      validation_dataset_builder = ub.datasets.get(
          dataset_name=FLAGS.dataset_name,
          split='validation',
          validation_percent=FLAGS.validation_percent,
          data_dir=FLAGS.data_dir)
    else:
      validation_dataset_builder = None
    test_dataset_builder = ub.datasets.get(
        dataset_name=FLAGS.dataset_name,
        split='test',
        data_dir=FLAGS.data_dir)
    model = ub.models.get(
        FLAGS.model_name,
        batch_size=FLAGS.batch_size,
        num_motifs=FLAGS.num_motifs,
        len_motifs=FLAGS.len_motifs,
        num_denses=FLAGS.num_denses,
        l2_weight=FLAGS.l2_regularization,
        depth=FLAGS.wide_resnet_depth,
        width_multiplier=FLAGS.wide_resnet_width_multiplier,
        version=FLAGS.wide_resnet_version)

    metrics = {
        'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'brier_score': BrierScore(name='brier_score'),
        'loss': tf.keras.metrics.SparseCategoricalCrossentropy(),
    }

    # Record all non-default hparams in tensorboard.
    hparams = _get_hparams()

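    # In eval-only mode, evaluate saved checkpoints and return early.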
    if FLAGS.mode == 'eval':
      _check_batch_replica_divisible(eval_batch_size, strategy)
      eval_lib.run_eval_loop(
          validation_dataset_builder=validation_dataset_builder,
          test_dataset_builder=test_dataset_builder,
          batch_size=eval_batch_size,
          model=model,
          trial_dir=trial_dir,
          train_steps=FLAGS.train_steps,
          strategy=strategy,
          metrics=metrics,
          checkpoint_step=FLAGS.checkpoint_step,
          hparams=hparams)
      return

    _check_batch_replica_divisible(FLAGS.batch_size, strategy)
    if FLAGS.mode == 'train_and_eval':
      _check_batch_replica_divisible(eval_batch_size, strategy)

    steps_per_epoch = train_dataset_builder.num_examples // FLAGS.batch_size
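    # Gather optimizer and schedule hyperparameters from flags prefixed with
    # 'optimizer_hparams_' and 'schedule_hparams_'.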
    optimizer_kwargs = {
        k[len('optimizer_hparams_'):]: FLAGS[k].value for k in FLAGS
        if k.startswith('optimizer_hparams_')
    }
    optimizer_kwargs.update({
        k[len('schedule_hparams_'):]: FLAGS[k].value for k in FLAGS
        if k.startswith('schedule_hparams_')
    })

    optimizer = ub.optimizers.get(
        optimizer_name=FLAGS.optimizer,
        learning_rate_schedule=FLAGS.learning_rate_schedule,
        learning_rate=FLAGS.learning_rate,
        weight_decay=FLAGS.weight_decay,
        steps_per_epoch=steps_per_epoch,
        model=model,
        **optimizer_kwargs)

    train_lib.run_train_loop(
        train_dataset_builder=train_dataset_builder,
        validation_dataset_builder=validation_dataset_builder,
        test_dataset_builder=test_dataset_builder,
        batch_size=FLAGS.batch_size,
        eval_batch_size=eval_batch_size,
        model=model,
        optimizer=optimizer,
        eval_frequency=FLAGS.eval_frequency,
        log_frequency=FLAGS.log_frequency,
        trial_dir=trial_dir,
        train_steps=FLAGS.train_steps,
        mode=FLAGS.mode,
        strategy=strategy,
        metrics=metrics,
        hparams=hparams)
Example 2
def run(trial_dir: str, flag_string: Optional[str]):
    """Run the experiment.

  Args:
    trial_dir: Path to the directory to write checkpoints to and read from.
    flag_string: Optional string used to record what flags the job was run with.
  """
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    if not FLAGS.eval_frequency:
        FLAGS.eval_frequency = FLAGS.log_frequency

    if FLAGS.eval_frequency % FLAGS.log_frequency != 0:
        raise ValueError(
            'log_frequency ({}) must evenly divide eval_frequency '
            '({}).'.format(FLAGS.log_frequency, FLAGS.eval_frequency))

    strategy = ub.strategy_utils.get_strategy(
        FLAGS.tpu, use_tpu=not FLAGS.use_cpu and not FLAGS.use_gpu)
    with strategy.scope():
        _maybe_setup_trial_dir(strategy, trial_dir, flag_string)

        # TODO(znado): pass all dataset and model kwargs.
        train_dataset_builder = ub.datasets.get(
            dataset_name=FLAGS.dataset_name,
            split='train',
            validation_percent=FLAGS.validation_percent,
            shuffle_buffer_size=FLAGS.shuffle_buffer_size)
        if FLAGS.validation_percent > 0:
            validation_dataset_builder = ub.datasets.get(
                dataset_name=FLAGS.dataset_name,
                split='validation',
                validation_percent=FLAGS.validation_percent,
                shuffle_buffer_size=FLAGS.shuffle_buffer_size)
        else:
            validation_dataset_builder = None
        test_dataset_builder = ub.datasets.get(
            dataset_name=FLAGS.dataset_name,
            split='test',
            validation_percent=FLAGS.validation_percent,
            shuffle_buffer_size=FLAGS.shuffle_buffer_size)

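        # Optionally collect spectral-normalization and Gaussian-process
        # output-layer hyperparameters; both stay None when the flags are off.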
        if FLAGS.use_spec_norm:
            logging.info('Use spectral normalization.')
            spec_norm_hparams = {
                'spec_norm_bound': FLAGS.spec_norm_bound,
                'spec_norm_iteration': FLAGS.spec_norm_iteration
            }
        else:
            spec_norm_hparams = None

        if FLAGS.use_gp_layer:
            logging.info('Use GP for output layer.')
            gp_layer_hparams = {
                'gp_input_dim': FLAGS.gp_input_dim,
                'gp_hidden_dim': FLAGS.gp_hidden_dim,
                'gp_scale': FLAGS.gp_scale,
                'gp_bias': FLAGS.gp_bias,
                'gp_input_normalization': FLAGS.gp_input_normalization,
                'gp_cov_discount_factor': FLAGS.gp_cov_discount_factor,
                'gp_cov_ridge_penalty': FLAGS.gp_cov_ridge_penalty
            }
        else:
            gp_layer_hparams = None

        model = ub_smu_models.get(
            FLAGS.model_name,
            num_classes=FLAGS.num_classes,
            batch_size=FLAGS.batch_size,
            len_seqs=FLAGS.len_seqs,
            num_motifs=FLAGS.num_motifs,
            len_motifs=FLAGS.len_motifs,
            num_denses=FLAGS.num_denses,
            depth=FLAGS.wide_resnet_depth,
            width_multiplier=FLAGS.wide_resnet_width_multiplier,
            l2_weight=FLAGS.l2_regularization,
            dropout_rate=FLAGS.dropout_rate,
            before_conv_dropout=FLAGS.before_conv_dropout,
            use_mc_dropout=FLAGS.use_mc_dropout,
            spec_norm_hparams=spec_norm_hparams,
            gp_layer_hparams=gp_layer_hparams)

        metrics = {
            'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
            'brier_score': rm.metrics.Brier(),
            'ece': rm.metrics.ExpectedCalibrationError(
                num_bins=FLAGS.num_bins),
            'loss': tf.keras.metrics.SparseCategoricalCrossentropy(),
        }

        # Record all non-default hparams in tensorboard.
        hparams = _get_hparams()

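        # Optionally build an out-of-distribution (OOD) test set and the
        # metrics used to score OOD detection.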
        ood_dataset_builder = None
        ood_metrics = None
        if FLAGS.run_ood:
            if 'cifar' in FLAGS.dataset_name and FLAGS.ood_dataset_name == 'svhn':
                svhn_normalize_by_cifar = True
            else:
                svhn_normalize_by_cifar = False

            ood_dataset_builder_cls = ub.datasets.DATASETS[
                FLAGS.ood_dataset_name]
            ood_dataset_builder_cls = ub.datasets.make_ood_dataset(
                ood_dataset_builder_cls)
            ood_dataset_builder = ood_dataset_builder_cls(
                in_distribution_dataset=test_dataset_builder,
                split='test',
                validation_percent=FLAGS.validation_percent,
                normalize_by_cifar=svhn_normalize_by_cifar,
                data_mode='ood')
            _check_batch_replica_divisible(FLAGS.eval_batch_size, strategy)

            ood_metrics = {
                'auroc': tf.keras.metrics.AUC(
                    curve='ROC', summation_method='interpolation'),
                'auprc': tf.keras.metrics.AUC(
                    curve='PR', summation_method='interpolation'),
            }

            aux_metrics = [
                ('spec_at_sen', tf.keras.metrics.SpecificityAtSensitivity,
                 FLAGS.sensitivity_thresholds),
                ('sen_at_spec', tf.keras.metrics.SensitivityAtSpecificity,
                 FLAGS.specificity_thresholds),
                ('prec_at_rec', tf.keras.metrics.PrecisionAtRecall,
                 FLAGS.recall_thresholds),
                ('rec_at_prec', tf.keras.metrics.RecallAtPrecision,
                 FLAGS.precision_thresholds)
            ]

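            # Each *_thresholds flag is (start, stop, num_points); sweep
            # np.linspace over the range and add one metric per threshold.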
            for metric_name, metric_fn, threshold_vals in aux_metrics:
                vals = [float(x) for x in threshold_vals]
                thresholds = np.linspace(vals[0], vals[1], int(vals[2]))
                for thresh in thresholds:
                    name = f'{metric_name}_{thresh:.2f}'
                    ood_metrics[name] = metric_fn(thresh)

        if FLAGS.mode == 'eval':
            _check_batch_replica_divisible(FLAGS.eval_batch_size, strategy)
            eval_lib.run_eval_loop(
                validation_dataset_builder=validation_dataset_builder,
                test_dataset_builder=test_dataset_builder,
                batch_size=FLAGS.eval_batch_size,
                model=model,
                trial_dir=trial_dir,
                train_steps=FLAGS.train_steps,
                strategy=strategy,
                metrics=metrics,
                checkpoint_step=FLAGS.checkpoint_step,
                hparams=hparams,
                ood_dataset_builder=ood_dataset_builder,
                ood_metrics=ood_metrics,
                mean_field_factor=FLAGS.gp_mean_field_factor)
            return

        if FLAGS.mode == 'train_and_eval':
            _check_batch_replica_divisible(FLAGS.eval_batch_size, strategy)

        steps_per_epoch = train_dataset_builder.num_examples // FLAGS.batch_size
        optimizer_kwargs = {
            k[len('optimizer_hparams_'):]: FLAGS[k].value
            for k in FLAGS if k.startswith('optimizer_hparams_')
        }
        optimizer_kwargs.update({
            k[len('schedule_hparams_'):]: FLAGS[k].value
            for k in FLAGS if k.startswith('schedule_hparams_')
        })

        optimizer = ub.optimizers.get(
            optimizer_name=FLAGS.optimizer,
            learning_rate_schedule=FLAGS.learning_rate_schedule,
            learning_rate=FLAGS.learning_rate,
            weight_decay=FLAGS.weight_decay,
            steps_per_epoch=steps_per_epoch,
            model=model,
            **optimizer_kwargs)

        train_lib.run_train_loop(
            train_dataset_builder=train_dataset_builder,
            validation_dataset_builder=validation_dataset_builder,
            test_dataset_builder=test_dataset_builder,
            batch_size=FLAGS.batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            model=model,
            optimizer=optimizer,
            eval_frequency=FLAGS.eval_frequency,
            log_frequency=FLAGS.log_frequency,
            trial_dir=trial_dir,
            train_steps=FLAGS.train_steps,
            mode=FLAGS.mode,
            strategy=strategy,
            metrics=metrics,
            hparams=hparams,
            ood_dataset_builder=ood_dataset_builder,
            ood_metrics=ood_metrics,
            focal_loss_gamma=FLAGS.focal_loss_gamma,
            mean_field_factor=FLAGS.gp_mean_field_factor)
Example 3
def run(trial_dir: str, flag_string: Optional[str]):
    """Run the experiment.

  Args:
    trial_dir: Path to the directory to write checkpoints to and read from.
    flag_string: Optional string used to record what flags the job was run with.
  """
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    if not FLAGS.eval_frequency:
        FLAGS.eval_frequency = FLAGS.log_frequency

    if FLAGS.eval_frequency % FLAGS.log_frequency != 0:
        raise ValueError(
            'log_frequency ({}) must evenly divide eval_frequency '
            '({}).'.format(FLAGS.log_frequency, FLAGS.eval_frequency))

    strategy = ub.strategy_utils.get_strategy(FLAGS.tpu, use_tpu=FLAGS.use_tpu)
    with strategy.scope():
        _maybe_setup_trial_dir(strategy, trial_dir, flag_string)

        # TODO(znado): pass all dataset and model kwargs.
        train_dataset_builder = ub.datasets.get(
            dataset_name=FLAGS.dataset_name,
            split='train',
            validation_percent=FLAGS.validation_percent,
            shuffle_buffer_size=FLAGS.shuffle_buffer_size)
        if FLAGS.validation_percent > 0:
            validation_dataset_builder = ub.datasets.get(
                dataset_name=FLAGS.dataset_name,
                split='validation',
                validation_percent=FLAGS.validation_percent)
        else:
            validation_dataset_builder = None
        test_dataset_builder = ub.datasets.get(dataset_name=FLAGS.dataset_name,
                                               split='test')
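        # Build the 10-class model with optional distance-based logits and
        # the loss function selected by FLAGS.loss_name.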
        model = models_lib.create_model(batch_size=FLAGS.batch_size,
                                        num_classes=10,
                                        distance_logits=FLAGS.distance_logits)
        loss_fn = loss_lib.get(FLAGS.loss_name,
                               from_logits=True,
                               dm_alpha=FLAGS.dm_alpha)

        if FLAGS.mode == 'eval':
            _check_batch_replica_divisible(FLAGS.eval_batch_size, strategy)
            eval_lib.run_eval_loop(
                validation_dataset_builder=validation_dataset_builder,
                test_dataset_builder=test_dataset_builder,
                batch_size=FLAGS.eval_batch_size,
                model=model,
                loss_fn=loss_fn,
                trial_dir=trial_dir,
                train_steps=FLAGS.train_steps,
                strategy=strategy,
                metric_names=['accuracy', 'loss'],
                checkpoint_step=FLAGS.checkpoint_step)
            return

        _check_batch_replica_divisible(FLAGS.batch_size, strategy)
        if FLAGS.mode == 'train_and_eval':
            _check_batch_replica_divisible(FLAGS.eval_batch_size, strategy)

        steps_per_epoch = train_dataset_builder.num_examples // FLAGS.batch_size
        optimizer_kwargs = {
            k[len('optimizer_hparams_'):]: FLAGS[k].value
            for k in FLAGS if k.startswith('optimizer_hparams_')
        }
        optimizer = ub.optimizers.get(
            optimizer_name=FLAGS.optimizer,
            learning_rate_schedule=FLAGS.learning_rate_schedule,
            learning_rate=FLAGS.learning_rate,
            weight_decay=FLAGS.weight_decay,
            steps_per_epoch=steps_per_epoch,
            **optimizer_kwargs)

        train_lib.run_train_loop(
            train_dataset_builder=train_dataset_builder,
            validation_dataset_builder=validation_dataset_builder,
            test_dataset_builder=test_dataset_builder,
            batch_size=FLAGS.batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            model=model,
            optimizer=optimizer,
            loss_fn=loss_fn,
            eval_frequency=FLAGS.eval_frequency,
            log_frequency=FLAGS.log_frequency,
            trial_dir=trial_dir,
            train_steps=FLAGS.train_steps,
            mode=FLAGS.mode,
            strategy=strategy,
            metric_names=['accuracy', 'loss'])
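

# A minimal entry-point sketch (an assumption, not taken from the examples
# above): each run() is typically invoked from an absl main that forwards a
# trial directory and a string dump of the parsed flags. FLAGS.output_dir is a
# placeholder flag name; the import normally belongs at module top.
from absl import app


def main(argv):
    del argv  # Unused command-line arguments.
    run(FLAGS.output_dir, FLAGS.flags_into_string())


if __name__ == '__main__':
    app.run(main)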