def run(trial_dir: str, flag_string: Optional[str]): """Run the experiment. Args: trial_dir: String to the dir to write checkpoints to and read them from. flag_string: Optional string used to record what flags the job was run with. """ tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) if not FLAGS.eval_frequency: FLAGS.eval_frequency = FLAGS.log_frequency if FLAGS.eval_frequency % FLAGS.log_frequency != 0: raise ValueError( 'log_frequency ({}) must evenly divide eval_frequency ' '({}).'.format(FLAGS.log_frequency, FLAGS.eval_frequency)) strategy = ub.strategy_utils.get_strategy( FLAGS.tpu, use_tpu=not FLAGS.use_cpu and not FLAGS.use_gpu) with strategy.scope(): _maybe_setup_trial_dir(strategy, trial_dir, flag_string, FLAGS.mode) # TODO(znado): pass all dataset and model kwargs. if FLAGS.eval_batch_size is None: eval_batch_size = FLAGS.batch_size else: eval_batch_size = FLAGS.eval_batch_size train_dataset_builder = ub.datasets.get( dataset_name=FLAGS.dataset_name, split='train', validation_percent=FLAGS.validation_percent, data_dir=FLAGS.data_dir, shuffle_buffer_size=FLAGS.shuffle_buffer_size) if FLAGS.validation_percent > 0: validation_dataset_builder = ub.datasets.get( dataset_name=FLAGS.dataset_name, split='validation', validation_percent=FLAGS.validation_percent, data_dir=FLAGS.data_dir) else: validation_dataset_builder = None test_dataset_builder = ub.datasets.get( dataset_name=FLAGS.dataset_name, split='test', data_dir=FLAGS.data_dir) model = ub.models.get( FLAGS.model_name, batch_size=FLAGS.batch_size, num_motifs=FLAGS.num_motifs, len_motifs=FLAGS.len_motifs, num_denses=FLAGS.num_denses, l2_weight=FLAGS.l2_regularization, depth=FLAGS.wide_resnet_depth, width_multiplier=FLAGS.wide_resnet_width_multiplier, version=FLAGS.wide_resnet_version) metrics = { 'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'brier_score': BrierScore(name='brier_score'), 'loss': tf.keras.metrics.SparseCategoricalCrossentropy(), } # Record all non-default hparams in tensorboard. 
    hparams = _get_hparams()

    if FLAGS.mode == 'eval':
      _check_batch_replica_divisible(eval_batch_size, strategy)
      eval_lib.run_eval_loop(
          validation_dataset_builder=validation_dataset_builder,
          test_dataset_builder=test_dataset_builder,
          batch_size=eval_batch_size,
          model=model,
          trial_dir=trial_dir,
          train_steps=FLAGS.train_steps,
          strategy=strategy,
          metrics=metrics,
          checkpoint_step=FLAGS.checkpoint_step,
          hparams=hparams)
      return

    _check_batch_replica_divisible(FLAGS.batch_size, strategy)
    if FLAGS.mode == 'train_and_eval':
      _check_batch_replica_divisible(eval_batch_size, strategy)

    steps_per_epoch = train_dataset_builder.num_examples // FLAGS.batch_size
    optimizer_kwargs = {
        k[len('optimizer_hparams_'):]: FLAGS[k].value
        for k in FLAGS
        if k.startswith('optimizer_hparams_')
    }
    optimizer_kwargs.update({
        k[len('schedule_hparams_'):]: FLAGS[k].value
        for k in FLAGS
        if k.startswith('schedule_hparams_')
    })
    optimizer = ub.optimizers.get(
        optimizer_name=FLAGS.optimizer,
        learning_rate_schedule=FLAGS.learning_rate_schedule,
        learning_rate=FLAGS.learning_rate,
        weight_decay=FLAGS.weight_decay,
        steps_per_epoch=steps_per_epoch,
        model=model,
        **optimizer_kwargs)

    train_lib.run_train_loop(
        train_dataset_builder=train_dataset_builder,
        validation_dataset_builder=validation_dataset_builder,
        test_dataset_builder=test_dataset_builder,
        batch_size=FLAGS.batch_size,
        eval_batch_size=eval_batch_size,
        model=model,
        optimizer=optimizer,
        eval_frequency=FLAGS.eval_frequency,
        log_frequency=FLAGS.log_frequency,
        trial_dir=trial_dir,
        train_steps=FLAGS.train_steps,
        mode=FLAGS.mode,
        strategy=strategy,
        metrics=metrics,
        hparams=hparams)
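
# Illustration only (not part of the original script): the
# 'optimizer_hparams_' / 'schedule_hparams_' prefix stripping above turns
# flag names into keyword arguments for ub.optimizers.get. The flag names in
# this standalone sketch are hypothetical and a plain dict stands in for
# FLAGS.
def _example_optimizer_kwargs_from_flags():
  fake_flag_values = {
      'optimizer_hparams_beta_1': 0.9,
      'schedule_hparams_warmup_epochs': 5,
      'seed': 42,  # unrelated flags are ignored by the prefix filter
  }
  kwargs = {
      k[len('optimizer_hparams_'):]: v
      for k, v in fake_flag_values.items()
      if k.startswith('optimizer_hparams_')
  }
  kwargs.update({
      k[len('schedule_hparams_'):]: v
      for k, v in fake_flag_values.items()
      if k.startswith('schedule_hparams_')
  })
  return kwargs  # {'beta_1': 0.9, 'warmup_epochs': 5}
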
def run(trial_dir: str, flag_string: Optional[str]): """Run the experiment. Args: trial_dir: String to the dir to write checkpoints to and read them from. flag_string: Optional string used to record what flags the job was run with. """ tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) if not FLAGS.eval_frequency: FLAGS.eval_frequency = FLAGS.log_frequency if FLAGS.eval_frequency % FLAGS.log_frequency != 0: raise ValueError( 'log_frequency ({}) must evenly divide eval_frequency ' '({}).'.format(FLAGS.log_frequency, FLAGS.eval_frequency)) strategy = ub.strategy_utils.get_strategy(FLAGS.tpu, use_tpu=not FLAGS.use_cpu and not FLAGS.use_gpu) with strategy.scope(): _maybe_setup_trial_dir(strategy, trial_dir, flag_string) # TODO(znado): pass all dataset and model kwargs. train_dataset_builder = ub.datasets.get( dataset_name=FLAGS.dataset_name, split='train', validation_percent=FLAGS.validation_percent, shuffle_buffer_size=FLAGS.shuffle_buffer_size) if FLAGS.validation_percent > 0: validation_dataset_builder = ub.datasets.get( dataset_name=FLAGS.dataset_name, split='validation', validation_percent=FLAGS.validation_percent, shuffle_buffer_size=FLAGS.shuffle_buffer_size) else: validation_dataset_builder = None test_dataset_builder = ub.datasets.get( dataset_name=FLAGS.dataset_name, split='test', validation_percent=FLAGS.validation_percent, shuffle_buffer_size=FLAGS.shuffle_buffer_size) if FLAGS.use_spec_norm: logging.info('Use spectral normalization.') spec_norm_hparams = { 'spec_norm_bound': FLAGS.spec_norm_bound, 'spec_norm_iteration': FLAGS.spec_norm_iteration } else: spec_norm_hparams = None if FLAGS.use_gp_layer: logging.info('Use GP for output layer.') gp_layer_hparams = { 'gp_input_dim': FLAGS.gp_input_dim, 'gp_hidden_dim': FLAGS.gp_hidden_dim, 'gp_scale': FLAGS.gp_scale, 'gp_bias': FLAGS.gp_bias, 'gp_input_normalization': FLAGS.gp_input_normalization, 'gp_cov_discount_factor': FLAGS.gp_cov_discount_factor, 'gp_cov_ridge_penalty': FLAGS.gp_cov_ridge_penalty } else: gp_layer_hparams = None model = ub_smu_models.get( FLAGS.model_name, num_classes=FLAGS.num_classes, batch_size=FLAGS.batch_size, len_seqs=FLAGS.len_seqs, num_motifs=FLAGS.num_motifs, len_motifs=FLAGS.len_motifs, num_denses=FLAGS.num_denses, depth=FLAGS.wide_resnet_depth, width_multiplier=FLAGS.wide_resnet_width_multiplier, l2_weight=FLAGS.l2_regularization, dropout_rate=FLAGS.dropout_rate, before_conv_dropout=FLAGS.before_conv_dropout, use_mc_dropout=FLAGS.use_mc_dropout, spec_norm_hparams=spec_norm_hparams, gp_layer_hparams=gp_layer_hparams) metrics = { 'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'brier_score': rm.metrics.Brier(), 'ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins), 'loss': tf.keras.metrics.SparseCategoricalCrossentropy(), } # Record all non-default hparams in tensorboard. 
    hparams = _get_hparams()

    ood_dataset_builder = None
    ood_metrics = None
    if FLAGS.run_ood:
      if 'cifar' in FLAGS.dataset_name and FLAGS.ood_dataset_name == 'svhn':
        svhn_normalize_by_cifar = True
      else:
        svhn_normalize_by_cifar = False

      ood_dataset_builder_cls = ub.datasets.DATASETS[FLAGS.ood_dataset_name]
      ood_dataset_builder_cls = ub.datasets.make_ood_dataset(
          ood_dataset_builder_cls)
      ood_dataset_builder = ood_dataset_builder_cls(
          in_distribution_dataset=test_dataset_builder,
          split='test',
          validation_percent=FLAGS.validation_percent,
          normalize_by_cifar=svhn_normalize_by_cifar,
          data_mode='ood')

      _check_batch_replica_divisible(FLAGS.eval_batch_size, strategy)
      ood_metrics = {
          'auroc': tf.keras.metrics.AUC(
              curve='ROC', summation_method='interpolation'),
          'auprc': tf.keras.metrics.AUC(
              curve='PR', summation_method='interpolation')
      }

      aux_metrics = [
          ('spec_at_sen', tf.keras.metrics.SpecificityAtSensitivity,
           FLAGS.sensitivity_thresholds),
          ('sen_at_spec', tf.keras.metrics.SensitivityAtSpecificity,
           FLAGS.specificity_thresholds),
          ('prec_at_rec', tf.keras.metrics.PrecisionAtRecall,
           FLAGS.recall_thresholds),
          ('rec_at_prec', tf.keras.metrics.RecallAtPrecision,
           FLAGS.precision_thresholds)
      ]
      for metric_name, metric_fn, threshold_vals in aux_metrics:
        vals = [float(x) for x in threshold_vals]
        thresholds = np.linspace(vals[0], vals[1], int(vals[2]))
        for thresh in thresholds:
          name = f'{metric_name}_{thresh:.2f}'
          ood_metrics[name] = metric_fn(thresh)

    if FLAGS.mode == 'eval':
      _check_batch_replica_divisible(FLAGS.eval_batch_size, strategy)
      eval_lib.run_eval_loop(
          validation_dataset_builder=validation_dataset_builder,
          test_dataset_builder=test_dataset_builder,
          batch_size=FLAGS.eval_batch_size,
          model=model,
          trial_dir=trial_dir,
          train_steps=FLAGS.train_steps,
          strategy=strategy,
          metrics=metrics,
          checkpoint_step=FLAGS.checkpoint_step,
          hparams=hparams,
          ood_dataset_builder=ood_dataset_builder,
          ood_metrics=ood_metrics,
          mean_field_factor=FLAGS.gp_mean_field_factor)
      return

    if FLAGS.mode == 'train_and_eval':
      _check_batch_replica_divisible(FLAGS.eval_batch_size, strategy)

    steps_per_epoch = train_dataset_builder.num_examples // FLAGS.batch_size
    optimizer_kwargs = {
        k[len('optimizer_hparams_'):]: FLAGS[k].value
        for k in FLAGS
        if k.startswith('optimizer_hparams_')
    }
    optimizer_kwargs.update({
        k[len('schedule_hparams_'):]: FLAGS[k].value
        for k in FLAGS
        if k.startswith('schedule_hparams_')
    })
    optimizer = ub.optimizers.get(
        optimizer_name=FLAGS.optimizer,
        learning_rate_schedule=FLAGS.learning_rate_schedule,
        learning_rate=FLAGS.learning_rate,
        weight_decay=FLAGS.weight_decay,
        steps_per_epoch=steps_per_epoch,
        model=model,
        **optimizer_kwargs)

    train_lib.run_train_loop(
        train_dataset_builder=train_dataset_builder,
        validation_dataset_builder=validation_dataset_builder,
        test_dataset_builder=test_dataset_builder,
        batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        model=model,
        optimizer=optimizer,
        eval_frequency=FLAGS.eval_frequency,
        log_frequency=FLAGS.log_frequency,
        trial_dir=trial_dir,
        train_steps=FLAGS.train_steps,
        mode=FLAGS.mode,
        strategy=strategy,
        metrics=metrics,
        hparams=hparams,
        ood_dataset_builder=ood_dataset_builder,
        ood_metrics=ood_metrics,
        focal_loss_gamma=FLAGS.focal_loss_gamma,
        mean_field_factor=FLAGS.gp_mean_field_factor)
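
# Illustration only (not part of the original script): the auxiliary OOD
# metric loop above interprets each threshold flag as a (start, stop, count)
# triple and registers one Keras metric per threshold. The flag value in this
# sketch is hypothetical; it relies on the module-level `np` import.
def _example_aux_metric_names():
  threshold_flag = ['0.05', '0.95', '4']  # e.g. a --sensitivity_thresholds value
  vals = [float(x) for x in threshold_flag]
  thresholds = np.linspace(vals[0], vals[1], int(vals[2]))
  return [f'spec_at_sen_{t:.2f}' for t in thresholds]
  # -> ['spec_at_sen_0.05', 'spec_at_sen_0.35', 'spec_at_sen_0.65',
  #     'spec_at_sen_0.95']
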
def run(trial_dir: str, flag_string: Optional[str]): """Run the experiment. Args: trial_dir: String to the dir to write checkpoints to and read them from. flag_string: Optional string used to record what flags the job was run with. """ tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) if not FLAGS.eval_frequency: FLAGS.eval_frequency = FLAGS.log_frequency if FLAGS.eval_frequency % FLAGS.log_frequency != 0: raise ValueError( 'log_frequency ({}) must evenly divide eval_frequency ' '({}).'.format(FLAGS.log_frequency, FLAGS.eval_frequency)) strategy = ub.strategy_utils.get_strategy(FLAGS.tpu, FLAGS.use_tpu) with strategy.scope(): _maybe_setup_trial_dir(strategy, trial_dir, flag_string) # TODO(znado): pass all dataset and model kwargs. train_dataset_builder = ub.datasets.get( dataset_name=FLAGS.dataset_name, split='train', validation_percent=FLAGS.validation_percent, shuffle_buffer_size=FLAGS.shuffle_buffer_size) if FLAGS.validation_percent > 0: validation_dataset_builder = ub.datasets.get( dataset_name=FLAGS.dataset_name, split='validation', validation_percent=FLAGS.validation_percent) else: validation_dataset_builder = None test_dataset_builder = ub.datasets.get(dataset_name=FLAGS.dataset_name, split='test') model = models_lib.create_model(batch_size=FLAGS.batch_size, num_classes=10, distance_logits=FLAGS.distance_logits) loss_fn = loss_lib.get(FLAGS.loss_name, from_logits=True, dm_alpha=FLAGS.dm_alpha) if FLAGS.mode == 'eval': _check_batch_replica_divisible(FLAGS.eval_batch_size, strategy) eval_lib.run_eval_loop( validation_dataset_builder=validation_dataset_builder, test_dataset_builder=test_dataset_builder, batch_size=FLAGS.eval_batch_size, model=model, loss_fn=loss_fn, trial_dir=trial_dir, train_steps=FLAGS.train_steps, strategy=strategy, metric_names=['accuracy', 'loss'], checkpoint_step=FLAGS.checkpoint_step) return _check_batch_replica_divisible(FLAGS.batch_size, strategy) if FLAGS.mode == 'train_and_eval': _check_batch_replica_divisible(FLAGS.eval_batch_size, strategy) steps_per_epoch = train_dataset_builder.num_examples // FLAGS.batch_size optimizer_kwargs = { k[len('optimizer_hparams_'):]: FLAGS[k].value for k in FLAGS if k.startswith('optimizer_hparams_') } optimizer = ub.optimizers.get( optimizer_name=FLAGS.optimizer, learning_rate_schedule=FLAGS.learning_rate_schedule, learning_rate=FLAGS.learning_rate, weight_decay=FLAGS.weight_decay, steps_per_epoch=steps_per_epoch, **optimizer_kwargs) train_lib.run_train_loop( train_dataset_builder=train_dataset_builder, validation_dataset_builder=validation_dataset_builder, test_dataset_builder=test_dataset_builder, batch_size=FLAGS.batch_size, eval_batch_size=FLAGS.eval_batch_size, model=model, optimizer=optimizer, loss_fn=loss_fn, eval_frequency=FLAGS.eval_frequency, log_frequency=FLAGS.log_frequency, trial_dir=trial_dir, train_steps=FLAGS.train_steps, mode=FLAGS.mode, strategy=strategy, metric_names=['accuracy', 'loss'])