def _configure_hparams(logdir, dicts,
                       metrics=["linear_classification_accuracy",
                                "alignment", "uniformity"]):
    """
    Set up the TensorBoard hyperparameter interface.

    :logdir: string; path to log directory
    :dicts: list of dictionaries containing hyperparameter values
    :metrics: list of strings; metric names
    """
    metrics = [hp.Metric(m) for m in metrics]
    params = {}
    # for each parameter dictionary
    for d in dicts:
        # for each parameter:
        for k in d:
            # is it a categorical?
            if k in SPECIAL_HPARAMS:
                params[hp.HParam(k, hp.Discrete(SPECIAL_HPARAMS[k]))] = d[k]
            # check bool before int: bool is a subclass of int
            elif isinstance(d[k], bool):
                params[hp.HParam(k, hp.Discrete([True, False]))] = d[k]
            elif isinstance(d[k], int):
                params[hp.HParam(k, hp.IntInterval(1, 1000000))] = d[k]
            elif isinstance(d[k], float):
                params[hp.HParam(k, hp.RealInterval(0., 10000000.))] = d[k]

    hparams_config = hp.hparams_config(
        hparams=list(params.keys()), metrics=metrics)
    # get a name for the run
    base_dir, run_name = os.path.split(logdir)
    if len(run_name) == 0:
        base_dir, run_name = os.path.split(base_dir)
    # record hyperparameters
    hp.hparams(params, trial_id=run_name)
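# --- Usage sketch (not from the source): a minimal, hypothetical driver for
# _configure_hparams above. SPECIAL_HPARAMS and the example values are
# assumptions standing in for whatever the surrounding module defines.
import os
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

SPECIAL_HPARAMS = {"optimizer": ["adam", "sgd"]}  # hypothetical categoricals

logdir = "logs/run_001"
with tf.summary.create_file_writer(logdir).as_default():
    _configure_hparams(
        logdir, [{"optimizer": "adam", "lr": 1e-3, "batch_size": 64}])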
def on_result(self, result):
    if self._file_writer is None:
        from tensorflow.python.eager import context
        from tensorboard.plugins.hparams import api as hp
        self._context = context
        self._file_writer = tf.summary.create_file_writer(self.logdir)
    with tf.device("/CPU:0"):
        with tf.summary.record_if(True), self._file_writer.as_default():
            step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]

            tmp = result.copy()
            if not self._hp_logged:
                if self.trial and self.trial.evaluated_params:
                    try:
                        hp.hparams(self.trial.evaluated_params,
                                   trial_id=self.trial.trial_id)
                    except Exception as exc:
                        logger.error("HParams failed with %s", exc)
                self._hp_logged = True

            for k in ["config", "pid", "timestamp", TIME_TOTAL_S,
                      TRAINING_ITERATION]:
                if k in tmp:
                    del tmp[k]  # not useful to log these

            flat_result = flatten_dict(tmp, delimiter="/")
            path = ["ray", "tune"]
            for attr, value in flat_result.items():
                if type(value) in VALID_SUMMARY_TYPES:
                    tf.summary.scalar(
                        "/".join(path + [attr]), value, step=step)
    self._file_writer.flush()
def _setup_outputs(root_output_dir, experiment_name, hparam_dict):
  """Set up directories for experiment loops, write hyperparameters to disk."""
  if not experiment_name:
    raise ValueError('experiment_name must be specified.')
  create_if_not_exists(root_output_dir)

  checkpoint_dir = os.path.join(root_output_dir, 'checkpoints',
                                experiment_name)
  create_if_not_exists(checkpoint_dir)
  checkpoint_mngr = tff.simulation.FileCheckpointManager(checkpoint_dir)

  results_dir = os.path.join(root_output_dir, 'results', experiment_name)
  create_if_not_exists(results_dir)
  csv_file = os.path.join(results_dir, 'experiment.metrics.csv')
  metrics_mngr = tff.simulation.CSVMetricsManager(csv_file)

  summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
  create_if_not_exists(summary_logdir)
  tensorboard_mngr = tff.simulation.TensorBoardManager(summary_logdir)

  if hparam_dict:
    summary_writer = tf.summary.create_file_writer(summary_logdir)
    hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
    hparams_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_series_to_csv(hparam_dict, hparams_file)
    with summary_writer.as_default():
      hp.hparams({k: v for k, v in hparam_dict.items() if v is not None})

  logging.info('Writing...')
  logging.info('    checkpoints to: %s', checkpoint_dir)
  logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
  logging.info('    summaries to: %s', summary_logdir)

  return checkpoint_mngr, metrics_mngr, tensorboard_mngr
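# --- Usage sketch (not from the source): how the three managers returned by
# _setup_outputs might be obtained. The paths and hyperparameter values here
# are illustrative assumptions only.
checkpoint_mngr, metrics_mngr, tensorboard_mngr = _setup_outputs(
    root_output_dir='/tmp/fed_logs',
    experiment_name='my_experiment',
    hparam_dict={'client_learning_rate': 0.01, 'clients_per_round': 10})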
def train_model(run_dir, hparams):
    hp.hparams(hparams)
    [X_train, y_train] = create_data(data_unscaled=data,
                                     start_train=start_train,
                                     end_train=end_train,
                                     n_windows=hparams[HP_WINDOW],
                                     n_outputs=hparams[HP_OUTPUT])
    [X_test, y_test] = create_data(data_unscaled=data,
                                   start_train=end_train,
                                   end_train=end_test,
                                   n_windows=hparams[HP_WINDOW],
                                   n_outputs=hparams[HP_OUTPUT])
    tf.compat.v1.keras.backend.clear_session()

    model = Sequential()
    model.add(TimeDistributed(
        Masking(mask_value=0., input_shape=(hparams[HP_WINDOW], n_inputs + 1)),
        input_shape=(n_company, hparams[HP_WINDOW], n_inputs + 1)))
    model.add(TimeDistributed(
        LSTM(hparams[HP_NUM_UNITS], stateful=False, activation='tanh',
             return_sequences=True,
             input_shape=(hparams[HP_WINDOW], n_inputs + 1),
             kernel_initializer='TruncatedNormal',
             bias_initializer=initializers.Constant(value=0.1),
             dropout=hparams[HP_DROPOUT],
             recurrent_dropout=hparams[HP_DROPOUT])))
    model.add(TimeDistributed(
        LSTM(hparams[HP_NUM_UNITS], stateful=False, activation='tanh',
             return_sequences=False,
             kernel_initializer='TruncatedNormal',
             bias_initializer=initializers.Constant(value=0.1),
             dropout=hparams[HP_DROPOUT],
             recurrent_dropout=hparams[HP_DROPOUT])))
    # model.add(TimeDistributed(LSTM(hparams[HP_NUM_UNITS], stateful=False,
    #     activation='tanh', return_sequences=True,
    #     kernel_initializer='TruncatedNormal',
    #     bias_initializer=initializers.Constant(value=0.1),
    #     dropout=hparams[HP_DROPOUT], recurrent_dropout=hparams[HP_DROPOUT])))
    # model.add(TimeDistributed(LSTM(hparams[HP_NUM_UNITS], stateful=False,
    #     activation='tanh', return_sequences=False,
    #     kernel_initializer='TruncatedNormal',
    #     bias_initializer=initializers.Constant(value=0.1),
    #     dropout=hparams[HP_DROPOUT], recurrent_dropout=hparams[HP_DROPOUT])))
    model.add(Dense(units=1, activation='linear'))
    # alternatives: loss=get_weighted_loss(weights=weights), metrics=['mae']
    model.compile(optimizer=hparams[HP_OPTIMIZER], loss='mse')
    model.build(input_shape=(None, n_company, hparams[HP_WINDOW],
                             n_inputs + 1))
    model.summary()

    # model.fit(X_train,
    #           y_train[:, :, hparams[HP_WINDOW]-1:hparams[HP_WINDOW], 0],
    #           epochs=1, batch_size=batch_size,
    #           validation_data=(X_test,
    #               y_test[:, :, hparams[HP_WINDOW]-1:hparams[HP_WINDOW], 0]))
    model.fit(X_train,
              y_train[:, :, hparams[HP_WINDOW] - 1:hparams[HP_WINDOW], 0],
              epochs=200,
              batch_size=batch_size,
              validation_data=(
                  X_test,
                  y_test[:, :, hparams[HP_WINDOW] - 1:hparams[HP_WINDOW], 0]),
              callbacks=[
                  TensorBoard(log_dir=run_dir, histogram_freq=10,
                              write_graph=True, write_grads=True,
                              update_freq='epoch'),
                  hp.KerasCallback(writer=run_dir, hparams=hparams)])
    model.save('model_' + str(hparams[HP_NUM_UNITS]) + '_'
               + str(hparams[HP_DROPOUT]) + '_'
               + str(hparams[HP_OPTIMIZER]) + '_'
               + str(hparams[HP_WINDOW]) + '_'
               + str(hparams[HP_OUTPUT]) + '.h5')
    return 0
def log_summaries(self, step):
    with self.summary_writer.as_default():
        tf.summary.scalar(self.training_episode_length.name,
                          self.training_episode_length.result(), step=step)
        tf.summary.scalar(self.training_critic_loss.name,
                          self.training_critic_loss.result(), step=step)
        tf.summary.scalar(self.training_actor_loss.name,
                          self.training_actor_loss.result(), step=step)
        tf.summary.scalar(self.training_critic_signal_to_noise.name,
                          self.training_critic_signal_to_noise.result(),
                          step=step)
        tf.summary.scalar(self.testing_episode_length.name,
                          self.testing_episode_length.result(), step=step)
        tf.summary.scalar(self.testing_total_episode_reward.name,
                          self.testing_total_episode_reward.result(),
                          step=step)
        tf.summary.scalar(self.testing_mean_total_episode_reward.name,
                          self.testing_mean_total_episode_reward.result(),
                          step=step)
        hparams = {
            **self.agent._hparams,
            **self.replay_buffer._hparams,
            **self._hparams,
            "env_id": self.env.spec.id,
        }
        hp.hparams(hparams)

    self.training_episode_length.reset_states()
    self.training_critic_loss.reset_states()
    self.training_actor_loss.reset_states()
    self.training_critic_signal_to_noise.reset_states()
    self.testing_episode_length.reset_states()
    self.testing_total_episode_reward.reset_states()
def run_eval_epoch(
    current_step: int,
    test_fn: EvalStepFn,
    test_dataset: tf.data.Dataset,
    test_summary_writer: tf.summary.SummaryWriter,
    val_fn: Optional[EvalStepFn] = None,
    val_dataset: Optional[tf.data.Dataset] = None,
    val_summary_writer: Optional[tf.summary.SummaryWriter] = None,
    ood_fn: Optional[EvalStepFn] = None,
    ood_dataset: Optional[tf.data.Dataset] = None,
    ood_summary_writer: Optional[tf.summary.SummaryWriter] = None,
    hparams: Optional[Dict[str, Any]] = None,
):
  """Run one evaluation epoch on the test and optionally validation splits."""
  val_outputs_np = None
  ood_outputs_np = None  # stays None when no OOD dataset is provided
  if val_dataset:
    val_iterator = iter(val_dataset)
    val_outputs = val_fn(val_iterator)
    with val_summary_writer.as_default():  # pytype: disable=attribute-error
      if hparams:
        hp.hparams(hparams)
      for name, metric in val_outputs.items():
        tf.summary.scalar(name, metric, step=current_step)
    val_outputs_np = {k: v.numpy() for k, v in val_outputs.items()}
    logging.info('Validation metrics for step %d: %s', current_step,
                 val_outputs_np)
  if ood_dataset:
    ood_iterator = iter(ood_dataset)
    ood_outputs = ood_fn(ood_iterator)
    with ood_summary_writer.as_default():  # pytype: disable=attribute-error
      if hparams:
        hp.hparams(hparams)
      for name, metric in ood_outputs.items():
        tf.summary.scalar(name, metric, step=current_step)
    ood_outputs_np = {k: v.numpy() for k, v in ood_outputs.items()}
    logging.info('OOD metrics for step %d: %s', current_step, ood_outputs_np)

  test_iterator = iter(test_dataset)
  test_outputs = test_fn(test_iterator)
  with test_summary_writer.as_default():
    if hparams:
      hp.hparams(hparams)
    for name, metric in test_outputs.items():
      tf.summary.scalar(name, metric, step=current_step)
  test_outputs_np = {k: v.numpy() for k, v in test_outputs.items()}
  return val_outputs_np, ood_outputs_np, test_outputs_np
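# --- Usage sketch (not from the source): wiring run_eval_epoch to a toy
# dataset. 'eval_step' is a stand-in for a real EvalStepFn; only the test
# split is exercised, so the val/ood outputs come back as None.
import tensorflow as tf

toy_ds = tf.data.Dataset.from_tensor_slices(
    {'labels': tf.constant([0., 1., 1., 0.])}).batch(2)

def eval_step(iterator):
    batch = next(iterator)
    return {'mean_label': tf.reduce_mean(batch['labels'])}

test_writer = tf.summary.create_file_writer('/tmp/summaries/test')
val_np, ood_np, test_np = run_eval_epoch(
    current_step=100,
    test_fn=eval_step,
    test_dataset=toy_ds,
    test_summary_writer=test_writer,
    hparams={'base_learning_rate': 0.1})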
def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  # Initialize distribution strategy on flag-specified accelerator.
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)
  use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

  train_batch_size = FLAGS.train_batch_size * FLAGS.num_cores
  eval_batch_size = FLAGS.eval_batch_size * FLAGS.num_cores

  # Reweighting loss for class imbalance.
  class_reweight_mode = FLAGS.class_reweight_mode
  if class_reweight_mode == 'constant':
    class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  else:
    class_weights = None

  # As per the Kaggle challenge, we have split sizes:
  #   train: 35,126
  #   validation: 10,906 (currently unused)
  #   test: 42,670
  ds_info = tfds.builder('diabetic_retinopathy_detection').info
  steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
  steps_per_validation_eval = (
      ds_info.splits['validation'].num_examples // eval_batch_size)
  steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size

  data_dir = FLAGS.data_dir
  dataset_train_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='train', data_dir=data_dir)
  dataset_train = dataset_train_builder.load(batch_size=train_batch_size)
  dataset_validation_builder = ub.datasets.get(
      'diabetic_retinopathy_detection',
      split='validation',
      data_dir=data_dir,
      is_training=not FLAGS.use_validation)
  validation_batch_size = (
      eval_batch_size if FLAGS.use_validation else train_batch_size)
  dataset_validation = dataset_validation_builder.load(
      batch_size=validation_batch_size)
  if FLAGS.use_validation:
    dataset_validation = strategy.experimental_distribute_dataset(
        dataset_validation)
  else:
    # Note that this will not create any mixed batches of train and
    # validation images.
    dataset_train = dataset_train.concatenate(dataset_validation)

  dataset_train = strategy.experimental_distribute_dataset(dataset_train)
  dataset_test_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='test', data_dir=data_dir)
  dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)
  dataset_test = strategy.experimental_distribute_dataset(dataset_test)

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-50 deterministic model.')
    model = ub.models.resnet50_deterministic(
        input_shape=utils.load_input_shape(dataset_train),
        num_classes=1)  # binary classification task
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    base_lr = FLAGS.base_learning_rate
    if FLAGS.lr_schedule == 'step':
      lr_decay_epochs = [
          (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
          for start_epoch_str in FLAGS.lr_decay_epochs
      ]
      lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
          steps_per_epoch,
          base_lr,
          decay_ratio=FLAGS.lr_decay_ratio,
          decay_epochs=lr_decay_epochs,
          warmup_epochs=FLAGS.lr_warmup_epochs)
    else:
      lr_schedule = ub.schedules.WarmUpPolynomialSchedule(
          base_lr,
          end_learning_rate=FLAGS.final_decay_factor * base_lr,
          decay_steps=(
              steps_per_epoch * (FLAGS.train_epochs - FLAGS.lr_warmup_epochs)),
          warmup_steps=steps_per_epoch * FLAGS.lr_warmup_epochs,
          decay_power=1.0)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)
    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=use_tpu,
        num_bins=FLAGS.num_bins,
        use_validation=FLAGS.use_validation)
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope()
      # so that optimizer slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  # Define metrics outside the accelerator scope for CPU eval.
  # This will cause an error on TPU.
  if not use_tpu:
    metrics.update(
        utils.get_diabetic_retinopathy_cpu_metrics(
            use_validation=FLAGS.use_validation))
  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  # Initialize loss function based on class reweighting setting.
  loss_fn = utils.get_diabetic_retinopathy_loss_fn(
      class_reweight_mode=class_reweight_mode, class_weights=class_weights)

  @tf.function
  def train_step(iterator):
    """Training step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']

      # For minibatch class reweighting, initialize per-batch loss function.
      if class_reweight_mode == 'minibatch':
        batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(labels=labels)
      else:
        batch_loss_fn = loss_fn

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        negative_log_likelihood = tf.reduce_mean(
            batch_loss_fn(
                y_true=tf.expand_dims(labels, axis=-1),
                y_pred=logits,
                from_logits=True))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      probs = tf.squeeze(tf.nn.sigmoid(logits))
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auprc'].update_state(labels, probs)
      metrics['train/auroc'].update_state(labels, probs)
      if not use_tpu:
        metrics['train/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_split, num_steps):
    """Evaluation step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.binary_crossentropy(
              y_true=tf.expand_dims(labels, axis=-1),
              y_pred=logits,
              from_logits=True))
      probs = tf.squeeze(tf.nn.sigmoid(logits))
      metrics[dataset_split + '/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics[dataset_split + '/accuracy'].update_state(labels, probs)
      metrics[dataset_split + '/auprc'].update_state(labels, probs)
      metrics[dataset_split + '/auroc'].update_state(labels, probs)
      if not use_tpu:
        metrics[dataset_split + '/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(num_steps, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  start_time = time.time()
  train_iterator = iter(dataset_train)
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch + 1)
    train_step(train_iterator)

    current_step = (epoch + 1) * steps_per_epoch
    max_steps = steps_per_epoch * FLAGS.train_epochs
    time_elapsed = time.time() - start_time
    steps_per_sec = float(current_step) / time_elapsed
    eta_seconds = (max_steps - current_step) / steps_per_sec
    message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
               'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                   current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                   steps_per_sec, eta_seconds / 60, time_elapsed / 60))
    logging.info(message)

    if FLAGS.use_validation:
      validation_iterator = iter(dataset_validation)
      logging.info('Starting to run validation eval at epoch: %s', epoch + 1)
      test_step(validation_iterator, 'validation', steps_per_validation_eval)

    test_iterator = iter(dataset_test)
    logging.info('Starting to run test eval at epoch: %s', epoch + 1)
    test_start_time = time.time()
    test_step(test_iterator, 'test', steps_per_test_eval)
    ms_per_example = (time.time() - test_start_time) * 1e6 / eval_batch_size
    metrics['test/ms_per_example'].update_state(ms_per_example)

    # Log and write to summary the epoch metrics.
    utils.log_epoch_metrics(metrics=metrics, use_tpu=use_tpu)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    # Metrics from Robustness Metrics (like ECE) will return a dict with a
    # single key/value, instead of a scalar.
    total_results = {
        k: (list(v.values())[0] if isinstance(v, dict) else v)
        for k, v in total_results.items()
    }
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

      # TODO(nband): debug checkpointing
      # Also save Keras model, due to checkpoint.save issue.
      keras_model_name = os.path.join(FLAGS.output_dir,
                                      f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  keras_model_name = os.path.join(FLAGS.output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)

  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
    })
def run(run_dir, hparams):
    with tf.summary.create_file_writer(run_dir).as_default():
        hp.hparams(hparams)  # record the values used in this trial
        accuracy = train_test_model(hparams)
        tf.summary.scalar(METRIC_ACCURACY, accuracy, step=1)
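# --- Usage sketch (not from the source): a grid-search driver for run()
# above, in the style of the TensorBoard hparams tutorial. The HParam
# definitions and metric name are assumptions; run() additionally requires a
# train_test_model function to be defined.
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

HP_NUM_UNITS = hp.HParam('num_units', hp.Discrete([16, 32]))
HP_DROPOUT = hp.HParam('dropout', hp.RealInterval(0.1, 0.2))
METRIC_ACCURACY = 'accuracy'

session_num = 0
for num_units in HP_NUM_UNITS.domain.values:
    for dropout_rate in (HP_DROPOUT.domain.min_value,
                         HP_DROPOUT.domain.max_value):
        hparams = {HP_NUM_UNITS: num_units, HP_DROPOUT: dropout_rate}
        run('logs/hparam_tuning/run-%d' % session_num, hparams)
        session_num += 1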
def main(argv):
  del argv  # unused arg
  tf.random.set_seed(FLAGS.seed)

  per_core_batch_size = FLAGS.per_core_batch_size // FLAGS.ensemble_size
  batch_size = per_core_batch_size * FLAGS.num_cores
  steps_per_epoch = APPROX_IMAGENET_TRAIN_IMAGES // batch_size
  steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

  logging.info('Saving checkpoints at %s', FLAGS.output_dir)

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

  train_builder = ub.datasets.ImageNetDataset(
      split=tfds.Split.TRAIN, use_bfloat16=FLAGS.use_bfloat16)
  train_dataset = train_builder.load(batch_size=batch_size, strategy=strategy)
  test_builder = ub.datasets.ImageNetDataset(
      split=tfds.Split.TEST, use_bfloat16=FLAGS.use_bfloat16)
  clean_test_dataset = test_builder.load(
      batch_size=batch_size, strategy=strategy)
  test_datasets = {'clean': clean_test_dataset}
  if FLAGS.corruptions_interval > 0:
    corruption_types, max_intensity = utils.load_corrupted_test_info()
    for name in corruption_types:
      for intensity in range(1, max_intensity + 1):
        dataset_name = '{0}_{1}'.format(name, intensity)
        dataset = utils.load_corrupted_test_dataset(
            batch_size=batch_size,
            corruption_name=name,
            corruption_intensity=intensity,
            use_bfloat16=FLAGS.use_bfloat16)
        test_datasets[dataset_name] = (
            strategy.experimental_distribute_dataset(dataset))

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-50 model')
    model = ub.models.resnet50_het_rank1(
        input_shape=(224, 224, 3),
        num_classes=NUM_CLASSES,
        alpha_initializer=FLAGS.alpha_initializer,
        gamma_initializer=FLAGS.gamma_initializer,
        alpha_regularizer=FLAGS.alpha_regularizer,
        gamma_regularizer=FLAGS.gamma_regularizer,
        use_additive_perturbation=FLAGS.use_additive_perturbation,
        ensemble_size=FLAGS.ensemble_size,
        random_sign_init=FLAGS.random_sign_init,
        dropout_rate=FLAGS.dropout_rate,
        prior_stddev=FLAGS.prior_stddev,
        use_tpu=not FLAGS.use_gpu,
        use_ensemble_bn=FLAGS.use_ensemble_bn,
        num_factors=FLAGS.num_factors,
        temperature=FLAGS.temperature,
        num_mc_samples=FLAGS.num_mc_samples)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Scale learning rate and decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate * batch_size / 256
    decay_epochs = [
        (FLAGS.train_epochs * 30) // 90,
        (FLAGS.train_epochs * 60) // 90,
        (FLAGS.train_epochs * 80) // 90,
    ]
    learning_rate = ub.schedules.WarmUpPiecewiseConstantSchedule(
        steps_per_epoch=steps_per_epoch,
        base_learning_rate=base_lr,
        decay_ratio=0.1,
        decay_epochs=decay_epochs,
        warmup_epochs=5)
    optimizer = tf.keras.optimizers.SGD(
        learning_rate=learning_rate,
        momentum=1.0 - FLAGS.one_minus_momentum,
        nesterov=True)
    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/kl': tf.keras.metrics.Mean(),
        'train/kl_scale': tf.keras.metrics.Mean(),
        'train/elbo': tf.keras.metrics.Mean(),
        'train/loss': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/ece': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'train/diversity': rm.metrics.AveragePairwiseDiversity(),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/kl': tf.keras.metrics.Mean(),
        'test/elbo': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'test/diversity': rm.metrics.AveragePairwiseDiversity(),
        'test/member_accuracy_mean': (
            tf.keras.metrics.SparseCategoricalAccuracy()),
        'test/member_ece_mean': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
    }
    if FLAGS.corruptions_interval > 0:
      corrupt_metrics = {}
      for intensity in range(1, max_intensity + 1):
        for corruption in corruption_types:
          dataset_name = '{0}_{1}'.format(corruption, intensity)
          corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/kl_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/elbo_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
              rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
    if FLAGS.ensemble_size > 1:
      for i in range(FLAGS.ensemble_size):
        metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
        metrics['test/accuracy_member_{}'.format(i)] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
    logging.info('Finished building Keras ResNet-50 model')

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  def compute_l2_loss(model):
    filtered_variables = []
    for var in model.trainable_variables:
      # Apply l2 on the BN parameters and bias terms. This excludes only the
      # fast-weight approximate posterior/prior parameters, so be careful
      # with the naming scheme.
      if ('kernel' in var.name or 'batch_norm' in var.name or
          'bias' in var.name):
        filtered_variables.append(tf.reshape(var, (-1,)))
    l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
        tf.concat(filtered_variables, axis=0))
    return l2_loss

  @tf.function
  def train_step(iterator):
    """Training StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']
      if FLAGS.ensemble_size > 1:
        images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
        labels = tf.tile(labels, [FLAGS.ensemble_size])

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        probs = tf.nn.softmax(logits)
        if FLAGS.ensemble_size > 1:
          per_probs = tf.reshape(
              probs,
              tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
          metrics['train/diversity'].add_batch(per_probs)
        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True))
        l2_loss = compute_l2_loss(model)
        kl = sum(model.losses) / APPROX_IMAGENET_TRAIN_IMAGES
        kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
        kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs
        kl_scale = tf.minimum(1., kl_scale)
        kl_loss = kl_scale * kl
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        loss = negative_log_likelihood + l2_loss + kl_loss
        scaled_loss = loss / strategy.num_replicas_in_sync
        elbo = -(negative_log_likelihood + l2_loss + kl)

      grads = tape.gradient(scaled_loss, model.trainable_variables)

      # Separate learning rate implementation.
      if FLAGS.fast_weight_lr_multiplier != 1.0:
        grads_and_vars = []
        for grad, var in zip(grads, model.trainable_variables):
          # Apply a different learning rate on the fast weights. This
          # excludes BN and slow weights, so be careful with the naming
          # scheme.
          if 'batch_norm' not in var.name and 'kernel' not in var.name:
            grads_and_vars.append(
                (grad * FLAGS.fast_weight_lr_multiplier, var))
          else:
            grads_and_vars.append((grad, var))
        optimizer.apply_gradients(grads_and_vars)
      else:
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/kl'].update_state(kl)
      metrics['train/kl_scale'].update_state(kl_scale)
      metrics['train/elbo'].update_state(elbo)
      metrics['train/loss'].update_state(loss)
      metrics['train/accuracy'].update_state(labels, logits)
      metrics['train/ece'].add_batch(probs, label=labels)

    for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_name):
    """Evaluation StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']
      if FLAGS.ensemble_size > 1:
        images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
      logits = tf.reshape([
          model(images, training=False)
          for _ in range(FLAGS.num_eval_samples)
      ], [FLAGS.num_eval_samples, FLAGS.ensemble_size, -1, NUM_CLASSES])
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
      all_probs = tf.nn.softmax(logits)
      probs = tf.math.reduce_mean(all_probs, axis=[0, 1])  # marginalize

      # Negative log marginal likelihood computed in a numerically-stable way.
      labels_broadcasted = tf.broadcast_to(
          labels,
          [FLAGS.num_eval_samples, FLAGS.ensemble_size, labels.shape[0]])
      log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
          labels_broadcasted, logits, from_logits=True)
      negative_log_likelihood = tf.reduce_mean(
          -tf.reduce_logsumexp(log_likelihoods, axis=[0, 1]) +
          tf.math.log(float(FLAGS.num_eval_samples * FLAGS.ensemble_size)))
      l2_loss = compute_l2_loss(model)
      kl = sum(model.losses) / IMAGENET_VALIDATION_IMAGES
      elbo = -(negative_log_likelihood + l2_loss + kl)

      if dataset_name == 'clean':
        if FLAGS.ensemble_size > 1:
          per_probs = tf.reduce_mean(all_probs, axis=0)  # marginalize samples
          metrics['test/diversity'].add_batch(per_probs)
          for i in range(FLAGS.ensemble_size):
            member_probs = per_probs[i]
            member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                labels, member_probs)
            metrics['test/nll_member_{}'.format(i)].update_state(member_loss)
            metrics['test/accuracy_member_{}'.format(i)].update_state(
                labels, member_probs)
            metrics['test/member_accuracy_mean'].update_state(
                labels, member_probs)
            metrics['test/member_ece_mean'].add_batch(
                member_probs, label=labels)
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/kl'].update_state(kl)
        metrics['test/elbo'].update_state(elbo)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].add_batch(probs, label=labels)
      else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/kl_{}'.format(dataset_name)].update_state(kl)
        corrupt_metrics['test/elbo_{}'.format(dataset_name)].update_state(
            elbo)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch(
            probs, label=labels)

    for _ in tf.range(tf.cast(steps_per_eval, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  train_iterator = iter(train_dataset)
  start_time = time.time()
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    train_step(train_iterator)

    current_step = (epoch + 1) * steps_per_epoch
    max_steps = steps_per_epoch * FLAGS.train_epochs
    time_elapsed = time.time() - start_time
    steps_per_sec = float(current_step) / time_elapsed
    eta_seconds = (max_steps - current_step) / steps_per_sec
    message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
               'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                   current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                   steps_per_sec, eta_seconds / 60, time_elapsed / 60))
    logging.info(message)

    datasets_to_evaluate = {'clean': test_datasets['clean']}
    if (FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0):
      datasets_to_evaluate = test_datasets
    for dataset_name, test_dataset in datasets_to_evaluate.items():
      logging.info('Testing on dataset %s', dataset_name)
      test_iterator = iter(test_dataset)
      logging.info('Starting to run eval at epoch: %s', epoch)
      test_step(test_iterator, dataset_name)
      logging.info('Done with testing on %s', dataset_name)

    corrupt_results = {}
    if (FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0):
      corrupt_results = utils.aggregate_corrupt_metrics(
          corrupt_metrics, corruption_types, max_intensity,
          FLAGS.alexnet_errors_path)

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    for i in range(FLAGS.ensemble_size):
      logging.info(
          'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
          metrics['test/nll_member_{}'.format(i)].result(),
          metrics['test/accuracy_member_{}'.format(i)].result() * 100)

    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    # Results from Robustness Metrics themselves return a dict, so flatten
    # them.
    total_results = utils.flatten_dictionary(total_results)
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)
  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
        'fast_weight_lr_multiplier': FLAGS.fast_weight_lr_multiplier,
        'num_eval_samples': FLAGS.num_eval_samples,
    })
def run(run_dir, hparams):
    with tf.summary.create_file_writer(run_dir).as_default():
        hp.hparams(hparams)
        accuracy_score = train_test_model(hparams)
        tf.summary.scalar(METRIC_ACCURACY, accuracy_score, step=1)
def write_hparams(self, hparams: Mapping[str, Any]):
    with self._summary_writer.as_default():
        hparams_api.hparams(dict(utils.flatten_dict(hparams)))
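# --- Illustration (not from the source): a minimal stand-in for the
# utils.flatten_dict call above. The real helper lives in the project's utils
# module and the separator choice here is an assumption. The hparams API only
# accepts flat mappings, hence the flattening.
def flatten_dict(d, parent_key='', sep='.'):
    """Flatten {'opt': {'lr': 0.1}} into {'opt.lr': 0.1}."""
    items = {}
    for k, v in d.items():
        key = f'{parent_key}{sep}{k}' if parent_key else k
        if isinstance(v, dict):
            items.update(flatten_dict(v, key, sep=sep))
        else:
            items[key] = v
    return items

assert flatten_dict({'opt': {'lr': 0.1}, 'seed': 0}) == {
    'opt.lr': 0.1, 'seed': 0}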
def random_hparam_search(cfg, data, callbacks, log_dir):
    '''
    Conduct a random hyperparameter search over the ranges given for the
    hyperparameters in config.yml and log results in TensorBoard. Model is
    trained x times for y random combinations of hyperparameters.
    :param cfg: Project config
    :param data: Dict containing the partitioned datasets
    :param callbacks: List of callbacks for Keras model (excluding TensorBoard)
    :param log_dir: Base directory in which to store logs
    :return: (Last model trained, resultant test set metrics, test data generator)
    '''

    # Define HParam objects for each hyperparameter we wish to tune.
    hp_ranges = cfg['HP_SEARCH']['RANGES']
    HPARAMS = []
    HPARAMS.append(hp.HParam('KERNEL_SIZE',
                             hp.Discrete(hp_ranges['KERNEL_SIZE'])))
    HPARAMS.append(hp.HParam('MAXPOOL_SIZE',
                             hp.Discrete(hp_ranges['MAXPOOL_SIZE'])))
    HPARAMS.append(hp.HParam('INIT_FILTERS',
                             hp.Discrete(hp_ranges['INIT_FILTERS'])))
    HPARAMS.append(hp.HParam('FILTER_EXP_BASE',
                             hp.IntInterval(hp_ranges['FILTER_EXP_BASE'][0],
                                            hp_ranges['FILTER_EXP_BASE'][1])))
    HPARAMS.append(hp.HParam('NODES_DENSE0',
                             hp.Discrete(hp_ranges['NODES_DENSE0'])))
    HPARAMS.append(hp.HParam('CONV_BLOCKS',
                             hp.IntInterval(hp_ranges['CONV_BLOCKS'][0],
                                            hp_ranges['CONV_BLOCKS'][1])))
    HPARAMS.append(hp.HParam('DROPOUT', hp.Discrete(hp_ranges['DROPOUT'])))
    HPARAMS.append(hp.HParam('LR', hp.RealInterval(hp_ranges['LR'][0],
                                                   hp_ranges['LR'][1])))
    HPARAMS.append(hp.HParam('OPTIMIZER', hp.Discrete(hp_ranges['OPTIMIZER'])))
    HPARAMS.append(hp.HParam('L2_LAMBDA', hp.Discrete(hp_ranges['L2_LAMBDA'])))
    HPARAMS.append(hp.HParam('BATCH_SIZE',
                             hp.Discrete(hp_ranges['BATCH_SIZE'])))
    HPARAMS.append(hp.HParam('IMB_STRATEGY',
                             hp.Discrete(hp_ranges['IMB_STRATEGY'])))

    # Define test set metrics that we wish to log to TensorBoard for each
    # training run.
    HP_METRICS = [hp.Metric(metric, display_name='Test ' + metric)
                  for metric in cfg['HP_SEARCH']['METRICS']]

    # Configure TensorBoard to log the results.
    with tf.summary.create_file_writer(log_dir).as_default():
        hp.hparams_config(hparams=HPARAMS, metrics=HP_METRICS)

    # Complete a number of training runs at different hparam values and log
    # the results.
    repeats_per_combo = cfg['HP_SEARCH']['REPEATS']   # Number of times to train the model per combination of hparams
    num_combos = cfg['HP_SEARCH']['COMBINATIONS']     # Number of random combinations of hparams to attempt
    num_sessions = num_combos * repeats_per_combo     # Total number of runs in this experiment
    model_type = ('DCNN_BINARY' if cfg['TRAIN']['CLASS_MODE'] == 'binary'
                  else 'DCNN_MULTICLASS')
    trial_id = 0
    for group_idx in range(num_combos):
        rand = random.Random()
        HPARAMS = {h: h.domain.sample_uniform(rand) for h in HPARAMS}
        hparams = {h.name: HPARAMS[h] for h in HPARAMS}  # To pass to model definition
        for repeat_idx in range(repeats_per_combo):
            trial_id += 1
            print("Running training session %d/%d" % (trial_id, num_sessions))
            print("Hparam values: ", {h.name: HPARAMS[h] for h in HPARAMS})
            trial_logdir = os.path.join(log_dir, str(trial_id))  # Need specific logdir for each trial
            callbacks_hp = callbacks + [TensorBoard(log_dir=trial_logdir,
                                                    profile_batch=0,
                                                    write_graph=False)]

            # Set values of hyperparameters for this run in config file.
            for h in hparams:
                if h in ['LR', 'L2_LAMBDA']:
                    val = 10 ** hparams[h]  # These hyperparameters are sampled on the log scale.
                else:
                    val = hparams[h]
                cfg['NN'][model_type][h] = val

            # Set some hyperparameters that are not specified in the model
            # definition.
            cfg['TRAIN']['BATCH_SIZE'] = hparams['BATCH_SIZE']
            cfg['TRAIN']['IMB_STRATEGY'] = hparams['IMB_STRATEGY']

            # Run a training session and log the performance metrics on the
            # test set to the HParams dashboard in TensorBoard.
            with tf.summary.create_file_writer(trial_logdir).as_default():
                hp.hparams(HPARAMS, trial_id=str(trial_id))
                model, test_metrics, test_generator = train_model(
                    cfg, data, callbacks_hp, verbose=0)
                for metric in HP_METRICS:
                    if metric._tag in test_metrics:
                        tf.summary.scalar(metric._tag,
                                          test_metrics[metric._tag],
                                          step=1)  # Log test metric
    return
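# --- Illustration (not from the source): a hypothetical minimal slice of the
# cfg structure consumed by random_hparam_search, written as a Python dict.
# All values are made up; only the keys mirror the lookups above. LR and
# L2_LAMBDA are given on the log10 scale, matching the 10 ** hparams[h]
# transformation in the search loop.
cfg = {
    'HP_SEARCH': {
        'RANGES': {
            'KERNEL_SIZE': ['3,3', '5,5'],
            'MAXPOOL_SIZE': ['2,2'],
            'INIT_FILTERS': [16, 32],
            'FILTER_EXP_BASE': [2, 3],
            'NODES_DENSE0': [128, 256],
            'CONV_BLOCKS': [3, 5],
            'DROPOUT': [0.25, 0.4],
            'LR': [-4.0, -2.5],
            'L2_LAMBDA': [-4.0, -2.0],
            'OPTIMIZER': ['adam'],
            'BATCH_SIZE': [16, 32],
            'IMB_STRATEGY': ['class_weight'],
        },
        'METRICS': ['loss', 'accuracy'],
        'REPEATS': 2,
        'COMBINATIONS': 10,
    },
    'TRAIN': {'CLASS_MODE': 'binary', 'BATCH_SIZE': 16,
              'IMB_STRATEGY': 'class_weight'},
    'NN': {'DCNN_BINARY': {}},
}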
def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  per_core_batch_size = FLAGS.per_core_batch_size // FLAGS.ensemble_size
  batch_size = per_core_batch_size * FLAGS.num_cores
  data_dir = FLAGS.data_dir
  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

  train_builder = ub.datasets.get(
      FLAGS.dataset,
      data_dir=data_dir,
      download_data=FLAGS.download_data,
      split=tfds.Split.TRAIN,
      validation_percent=1. - FLAGS.train_proportion)
  train_dataset = train_builder.load(batch_size=batch_size)
  train_dataset = strategy.experimental_distribute_dataset(train_dataset)

  validation_dataset = None
  steps_per_validation = 0
  if FLAGS.train_proportion < 1.0:
    validation_builder = ub.datasets.get(
        FLAGS.dataset,
        data_dir=data_dir,
        split=tfds.Split.VALIDATION,
        validation_percent=1. - FLAGS.train_proportion)
    validation_dataset = validation_builder.load(batch_size=batch_size)
    validation_dataset = strategy.experimental_distribute_dataset(
        validation_dataset)
    steps_per_validation = validation_builder.num_examples // batch_size

  clean_test_builder = ub.datasets.get(
      FLAGS.dataset, data_dir=data_dir, split=tfds.Split.TEST)
  clean_test_dataset = clean_test_builder.load(batch_size=batch_size)
  test_datasets = {
      'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
  }
  steps_per_epoch = train_builder.num_examples // batch_size
  steps_per_eval = clean_test_builder.num_examples // batch_size
  num_classes = 100 if FLAGS.dataset == 'cifar100' else 10
  if FLAGS.corruptions_interval > 0:
    if FLAGS.dataset == 'cifar100':
      data_dir = FLAGS.cifar100_c_path
    corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
    for corruption_type in corruption_types:
      for severity in range(1, 6):
        dataset = ub.datasets.get(
            f'{FLAGS.dataset}_corrupted',
            corruption_type=corruption_type,
            data_dir=data_dir,
            severity=severity,
            split=tfds.Split.TEST).load(batch_size=batch_size)
        test_datasets[f'{corruption_type}_{severity}'] = (
            strategy.experimental_distribute_dataset(dataset))

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras model')
    model = ub.models.wide_resnet_batchensemble(
        input_shape=(32, 32, 3),
        depth=28,
        width_multiplier=10,
        num_classes=num_classes,
        ensemble_size=FLAGS.ensemble_size,
        random_sign_init=FLAGS.random_sign_init,
        l2=FLAGS.l2)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate * batch_size / 128
    lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                       for start_epoch_str in FLAGS.lr_decay_epochs]
    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)

    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss': tf.keras.metrics.Mean(),
        'train/ece': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
    }
    eval_dataset_splits = ['test']
    if validation_dataset:
      metrics.update({
          'validation/negative_log_likelihood': tf.keras.metrics.Mean(),
          'validation/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
          'validation/ece': rm.metrics.ExpectedCalibrationError(
              num_bins=FLAGS.num_bins),
      })
      eval_dataset_splits += ['validation']
    for i in range(FLAGS.ensemble_size):
      for dataset_split in eval_dataset_splits:
        metrics[f'{dataset_split}/nll_member_{i}'] = tf.keras.metrics.Mean()
        metrics[f'{dataset_split}/accuracy_member_{i}'] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
    if FLAGS.corruptions_interval > 0:
      corrupt_metrics = {}
      for intensity in range(1, 6):
        for corruption in corruption_types:
          dataset_name = '{0}_{1}'.format(corruption, intensity)
          corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
              rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  @tf.function
  def train_step(iterator):
    """Training StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']
      images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
      labels = tf.tile(labels, [FLAGS.ensemble_size])

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + l2_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)

      # Separate learning rate implementation.
      if FLAGS.fast_weight_lr_multiplier != 1.0:
        grads_and_vars = []
        for grad, var in zip(grads, model.trainable_variables):
          # Apply a different learning rate on the fast weight approximate
          # posterior/prior parameters. This excludes BN and slow weights,
          # so be careful with the naming scheme.
          if 'batch_norm' not in var.name and 'kernel' not in var.name:
            grads_and_vars.append(
                (grad * FLAGS.fast_weight_lr_multiplier, var))
          else:
            grads_and_vars.append((grad, var))
        optimizer.apply_gradients(grads_and_vars)
      else:
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

      probs = tf.nn.softmax(logits)
      metrics['train/ece'].add_batch(probs, label=labels)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, logits)

    for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_split, dataset_name, num_steps):
    """Evaluation StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']
      images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
      logits = model(images, training=False)
      probs = tf.nn.softmax(logits)

      per_probs = tf.split(
          probs, num_or_size_splits=FLAGS.ensemble_size, axis=0)
      for i in range(FLAGS.ensemble_size):
        member_probs = per_probs[i]
        member_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, member_probs)
        metrics[f'{dataset_split}/nll_member_{i}'].update_state(member_loss)
        metrics[f'{dataset_split}/accuracy_member_{i}'].update_state(
            labels, member_probs)

      probs = tf.reduce_mean(per_probs, axis=0)
      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
      if dataset_name == 'clean':
        metrics[f'{dataset_split}/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics[f'{dataset_split}/accuracy'].update_state(labels, probs)
        metrics[f'{dataset_split}/ece'].add_batch(probs, label=labels)
      else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch(
            probs, label=labels)

    for _ in tf.range(tf.cast(num_steps, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  train_iterator = iter(train_dataset)
  start_time = time.time()
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    train_step(train_iterator)

    current_step = (epoch + 1) * steps_per_epoch
    max_steps = steps_per_epoch * FLAGS.train_epochs
    time_elapsed = time.time() - start_time
    steps_per_sec = float(current_step) / time_elapsed
    eta_seconds = (max_steps - current_step) / steps_per_sec
    message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
               'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                   current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                   steps_per_sec, eta_seconds / 60, time_elapsed / 60))
    logging.info(message)

    if validation_dataset:
      validation_iterator = iter(validation_dataset)
      test_step(validation_iterator, 'validation', 'clean',
                steps_per_validation)
    datasets_to_evaluate = {'clean': test_datasets['clean']}
    if (FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0):
      datasets_to_evaluate = test_datasets
    for dataset_name, test_dataset in datasets_to_evaluate.items():
      test_iterator = iter(test_dataset)
      logging.info('Testing on dataset %s', dataset_name)
      logging.info('Starting to run eval at epoch: %s', epoch)
      test_start_time = time.time()
      test_step(test_iterator, 'test', dataset_name, steps_per_eval)
      ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
      metrics['test/ms_per_example'].update_state(ms_per_example)
      logging.info('Done with testing on %s', dataset_name)

    corrupt_results = {}
    if (FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0):
      corrupt_results = utils.aggregate_corrupt_metrics(
          corrupt_metrics, corruption_types)

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    for i in range(FLAGS.ensemble_size):
      logging.info(
          'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
          metrics['test/nll_member_{}'.format(i)].result(),
          metrics['test/accuracy_member_{}'.format(i)].result() * 100)

    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    # Metrics from Robustness Metrics (like ECE) will return a dict with a
    # single key/value, instead of a scalar.
    total_results = {
        k: (list(v.values())[0] if isinstance(v, dict) else v)
        for k, v in total_results.items()
    }
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)
  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
        'random_sign_init': FLAGS.random_sign_init,
        'fast_weight_lr_multiplier': FLAGS.fast_weight_lr_multiplier,
    })
def main(hParams, n_run):
    nsteps = hParams['N_STEPS']
    nenv = hParams[HP_N_ENV]
    n_epochs = hParams['N_EPOCHS']
    total_timesteps = int(n_epochs * nsteps * nenv)
    nbatch = nenv * nsteps
    update_int = 1
    save_int = 5
    test_int = 10
    gamma = 0.99
    lr = hParams[HP_LEARNING_RATE]
    vf_coef = hParams[HP_VF_COEF]
    ent_coef = hParams[HP_ENT_COEF]
    save_dir = ('lr' + str(lr) + 'vc' + str(vf_coef) + 'ec' + str(ent_coef)
                + 'env' + str(nenv))
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = 'logs/cart_hparam_tuning/run-' + str(n_run)
    summ_writer = tf.summary.create_file_writer(log_dir)

    envfn = lambda: gym.make('CartPole-v1')
    env = SubprocVecEnv([envfn] * nenv)
    state_size = env.observation_space.shape
    num_actions = env.action_space.n
    pgnet = SimplePGNet(state_size, num_actions, learning_rate=lr,
                        vf_coef=vf_coef, ent_coef=ent_coef)
    runner = SonicEnvRunner(env, pgnet, nsteps, gamma)

    print("Total updates to run: ", total_timesteps // nbatch)
    for update in range(1, total_timesteps // nbatch + 1):
        print("\nUpdate #{}:".format(update))
        states_mb, actions_mb, values_mb, rewards_mb, next_dones_mb = (
            runner.run())
        tf.summary.trace_on(graph=True)
        policy_loss, entropy_loss, vf_loss, loss = pgnet.fit_gradient(
            states_mb, actions_mb, rewards_mb, values_mb)
        if update == 1:
            with summ_writer.as_default():
                tf.summary.trace_export(name="grad_trace", step=0)
        WeightWriter(summ_writer, pgnet, (Conv2D, Dense), global_step=update)
        with summ_writer.as_default():
            tf.summary.scalar("PolicyLoss", policy_loss, step=update)
            tf.summary.scalar("EntropyLoss", entropy_loss, step=update)
            tf.summary.scalar("ValueFunctionLoss", vf_loss, step=update)
            tf.summary.scalar("Loss", loss, step=update)
        if update % update_int == 0:
            print("PolicyLoss:", policy_loss)
            print("EntropyLoss: ", entropy_loss)
            print("ValueFunctionLoss: ", vf_loss)
            print("Loss: ", loss)
        if update % save_int == 0:
            pgnet.model.save_weights('cart_hparams_tuning_models/' + save_dir
                                     + '/my_checkpoint')
            print("Model Saved")
        if update % test_int == 0:
            test_rewards = TestRewardWriter(summ_writer, envfn, pgnet, 20,
                                            global_step=update)
            print("Test Rewards: ", test_rewards)

    with summ_writer.as_default():
        hp.hparams(hParams)
    env.close()
def run(self):
    run_dir = os.path.join(
        self.ph['log_dir'],
        'run_{}'.format(str(self.ph['curr_run_number'])))

    # set up tensorboard summary writer scope
    with tf.summary.create_file_writer(run_dir).as_default():
        hp.hparams(self.ph.get_hparams())

        # data loading
        train_client_data, test_dataset = dta.get_client_data(
            dataset_name=self.ph['dataset'],
            mask_by=self.ph['mask_by'],
            mask_ratios={
                'unsupervised': self.ph['unsupervised_mask_ratio'],
                'supervised': self.ph['supervised_mask_ratio']
            },
            sample_client_data=self.ph['sample_client_data'])
        test_dataset = self.dataloader.preprocess_dataset(test_dataset)
        sample_batch = self.dataloader.get_sample_batch(train_client_data)
        model_fn = functools.partial(
            self.keras_model_fn.create_tff_model_fn, sample_batch)

        # federated training
        iterative_process = tff.learning.build_federated_averaging_process(
            model_fn)
        state = iterative_process.initialize()
        for round_num in range(self.num_rounds):
            sample_clients = np.random.choice(
                train_client_data.client_ids,
                size=min(self.num_clients_per_round,
                         len(train_client_data.client_ids)),
                replace=False)
            federated_train_data = self.dataloader.make_federated_data(
                train_client_data, sample_clients)
            state, metrics = iterative_process.next(
                state, federated_train_data)
            if not round_num % self.log_every:
                print('\nround {:2d}, metrics={}'.format(round_num, metrics))
                tf.summary.scalar('train_accuracy', metrics[0],
                                  step=round_num)
                tf.summary.scalar('train_loss', metrics[1], step=round_num)
                test_loss, test_accuracy = self.evaluate_central(
                    test_dataset, state)
                tf.summary.scalar('test_accuracy', test_accuracy,
                                  step=round_num)
                tf.summary.scalar('test_loss', test_loss, step=round_num)

        model_fp = os.path.join(run_dir, self.ph['model_fp'])
        self.keras_model_fn.save_model_weights(model_fp, state)
    return
def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

  per_core_batch_size = FLAGS.per_core_batch_size // FLAGS.ensemble_size
  batch_size = per_core_batch_size * FLAGS.num_cores
  check_bool = FLAGS.train_proportion > 0 and FLAGS.train_proportion <= 1
  assert check_bool, 'Proportion of train set has to meet 0 < prop <= 1.'

  drop_remainder_validation = True
  if not FLAGS.use_gpu:
    # This has to be True for TPU training, otherwise the batch size of
    # images in the validation set can't be determined by the TPU compiler.
    assert drop_remainder_validation, 'drop_remainder must be True in TPU mode.'

  validation_percent = 1 - FLAGS.train_proportion
  train_dataset = ub.datasets.get(
      FLAGS.dataset,
      split=tfds.Split.TRAIN,
      validation_percent=validation_percent).load(batch_size=batch_size)
  validation_dataset = ub.datasets.get(
      FLAGS.dataset,
      split=tfds.Split.VALIDATION,
      validation_percent=validation_percent,
      drop_remainder=drop_remainder_validation).load(batch_size=batch_size)
  validation_dataset = validation_dataset.repeat()
  clean_test_dataset = ub.datasets.get(
      FLAGS.dataset, split=tfds.Split.TEST).load(batch_size=batch_size)
  train_dataset = strategy.experimental_distribute_dataset(train_dataset)
  validation_dataset = strategy.experimental_distribute_dataset(
      validation_dataset)
  test_datasets = {
      'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
  }
  if FLAGS.corruptions_interval > 0:
    extra_kwargs = {}
    if FLAGS.dataset == 'cifar100':
      extra_kwargs['data_dir'] = FLAGS.cifar100_c_path
    corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
    for corruption_type in corruption_types:
      for severity in range(1, 6):
        dataset = ub.datasets.get(
            f'{FLAGS.dataset}_corrupted',
            corruption_type=corruption_type,
            severity=severity,
            split=tfds.Split.TEST,
            **extra_kwargs).load(batch_size=batch_size)
        test_datasets[f'{corruption_type}_{severity}'] = (
            strategy.experimental_distribute_dataset(dataset))

  ds_info = tfds.builder(FLAGS.dataset).info
  train_sample_size = ds_info.splits[
      'train'].num_examples * FLAGS.train_proportion
  steps_per_epoch = int(train_sample_size / batch_size)
  train_sample_size = int(train_sample_size)
  steps_per_eval = ds_info.splits['test'].num_examples // batch_size
  num_classes = ds_info.features['label'].num_classes

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  logging.info('Building Keras model.')
  depth = 28
  width = 10

  dict_ranges = {'min': FLAGS.min_l2_range, 'max': FLAGS.max_l2_range}
  ranges = [dict_ranges for _ in range(6)]  # 6 independent l2 parameters
  model_config = {
      'key_to_index': {
          'input_conv_l2_kernel': 0,
          'group_l2_kernel': 1,
          'group_1_l2_kernel': 2,
          'group_2_l2_kernel': 3,
          'dense_l2_kernel': 4,
          'dense_l2_bias': 5,
      },
      'ranges': ranges,
      'test': None
  }
  lambdas_config = LambdaConfig(model_config['ranges'],
                                model_config['key_to_index'])

  if FLAGS.e_body_hidden_units > 0:
    e_body_arch = '({},)'.format(FLAGS.e_body_hidden_units)
  else:
    e_body_arch = '()'
  e_shared_arch = '()'
  e_activation = 'tanh'
  filters_resnet = [16]
  for i in range(0, 3):  # 3 groups of blocks
    filters_resnet.extend([16 * width * 2**i] * 9)  # 9 layers in each block
  # e_head dim for conv2d is just the number of filters (only kernel) and
  # twice the number of classes for the last dense layer (kernel + bias).
  e_head_dims = [x for x in filters_resnet] + [2 * num_classes]

  with strategy.scope():
    e_models = e_factory(
        lambdas_config.input_shape,
        e_head_dims=e_head_dims,
        e_body_arch=eval(e_body_arch),  # pylint: disable=eval-used
        e_shared_arch=eval(e_shared_arch),  # pylint: disable=eval-used
        activation=e_activation,
        use_bias=FLAGS.e_model_use_bias,
        e_head_init=FLAGS.init_emodels_stddev)

    model = wide_resnet_hyperbatchensemble(
        input_shape=ds_info.features['image'].shape,
        depth=depth,
        width_multiplier=width,
        num_classes=num_classes,
        ensemble_size=FLAGS.ensemble_size,
        random_sign_init=FLAGS.random_sign_init,
        config=lambdas_config,
        e_models=e_models,
        l2_batchnorm_layer=FLAGS.l2_batchnorm,
        regularize_fast_weights=FLAGS.regularize_fast_weights,
        fast_weights_eq_contraint=FLAGS.fast_weights_eq_contraint,
        version=2)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # build hyper-batchensemble complete -------------------------

    # Initialize Lambda distributions for tuning.
    lambdas_mean = tf.reduce_mean(
        log_uniform_mean([lambdas_config.log_min, lambdas_config.log_max]))
    lambdas0 = tf.random.normal((FLAGS.ensemble_size, lambdas_config.dim),
                                lambdas_mean,
                                0.1 * FLAGS.ens_init_delta_bounds)
    lower0 = lambdas0 - tf.constant(FLAGS.ens_init_delta_bounds)
    lower0 = tf.maximum(lower0, 1e-8)
    upper0 = lambdas0 + tf.constant(FLAGS.ens_init_delta_bounds)
    log_lower = tf.Variable(tf.math.log(lower0))
    log_upper = tf.Variable(tf.math.log(upper0))
    lambda_parameters = [log_lower, log_upper]  # these variables are tuned
    clip_lambda_parameters(lambda_parameters, lambdas_config)

    # Optimizer settings to train model weights.
    # Linearly scale learning rate and the decay epochs by vanilla settings.
    # Note: Here, we don't divide the epochs by 200 as for the other
    # uncertainty baselines.
    base_lr = FLAGS.base_learning_rate * batch_size / 128
    lr_decay_epochs = [int(l) for l in FLAGS.lr_decay_epochs]
    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)
    # tuner used for optimizing lambda_parameters
    tuner = tf.keras.optimizers.Adam(FLAGS.lr_tuning)

    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss': tf.keras.metrics.Mean(),
        'train/ece': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'train/disagreement': tf.keras.metrics.Mean(),
        'train/average_kl': tf.keras.metrics.Mean(),
        'train/cosine_similarity': tf.keras.metrics.Mean(),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': rm.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'test/gibbs_nll': tf.keras.metrics.Mean(),
        'test/gibbs_accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/disagreement': tf.keras.metrics.Mean(),
        'test/average_kl': tf.keras.metrics.Mean(),
        'test/cosine_similarity': tf.keras.metrics.Mean(),
        'validation/loss': tf.keras.metrics.Mean(),
        'validation/loss_entropy': tf.keras.metrics.Mean(),
        'validation/loss_ce': tf.keras.metrics.Mean()
    }
    corrupt_metrics = {}

    for i in range(FLAGS.ensemble_size):
      metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
      metrics['test/accuracy_member_{}'.format(i)] = (
          tf.keras.metrics.SparseCategoricalAccuracy())
    if FLAGS.corruptions_interval > 0:
      for intensity in range(1, 6):
        for corruption in corruption_types:
          dataset_name = '{0}_{1}'.format(corruption, intensity)
          corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
              rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

    checkpoint = tf.train.Checkpoint(
        model=model, lambda_parameters=lambda_parameters, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint and FLAGS.restore_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  @tf.function
  def train_step(iterator):
    """Training StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']
      images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])

      # generate lambdas
      lambdas = log_uniform_sample(per_core_batch_size, lambda_parameters)
      lambdas = tf.reshape(
          lambdas,
          (FLAGS.ensemble_size * per_core_batch_size, lambdas_config.dim))

      with tf.GradientTape() as tape:
        logits = model([images, lambdas], training=True)
        if FLAGS.use_gibbs_ce:
          # Average of single-model CEs. Tiling of the labels should only be
          # done for the Gibbs CE loss.
          labels = tf.tile(labels, [FLAGS.ensemble_size])
          negative_log_likelihood = tf.reduce_mean(
              tf.keras.losses.sparse_categorical_crossentropy(
                  labels, logits, from_logits=True))
        else:
          # Ensemble CE uses no tiling of the labels.
          negative_log_likelihood = ensemble_crossentropy(
              labels, logits, FLAGS.ensemble_size)
        # Note: Divide l2_loss by sample_size (this differs from the
        # uncertainty_baselines implementation).
        l2_loss = sum(model.losses) / train_sample_size
        loss = negative_log_likelihood + l2_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)

      # Separate learning rate for fast weights.
      grads_and_vars = []
      for grad, var in zip(grads, model.trainable_variables):
        if (('alpha' in var.name or 'gamma' in var.name) and
            'batch_norm' not in var.name):
          grads_and_vars.append((grad * FLAGS.fast_weight_lr_multiplier, var))
        else:
          grads_and_vars.append((grad, var))
      optimizer.apply_gradients(grads_and_vars)

      probs = tf.nn.softmax(logits)
      per_probs = tf.split(
          probs, num_or_size_splits=FLAGS.ensemble_size, axis=0)
      per_probs_stacked = tf.stack(per_probs, axis=0)
      metrics['train/ece'].add_batch(probs, label=labels)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, logits)
      diversity = rm.metrics.AveragePairwiseDiversity()
      diversity.add_batch(per_probs_stacked, num_models=FLAGS.ensemble_size)
      diversity_results = diversity.result()
      for k, v in diversity_results.items():
        metrics['train/' + k].update_state(v)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def tuning_step(iterator):
    """Tuning StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']
      images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])

      with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(lambda_parameters)

        # sample lambdas
        if FLAGS.sample_and_tune:
          lambdas = log_uniform_sample(per_core_batch_size, lambda_parameters)
        else:
          lambdas = log_uniform_mean(lambda_parameters)
          lambdas = tf.repeat(lambdas, per_core_batch_size, axis=0)
        lambdas = tf.reshape(
            lambdas,
            (FLAGS.ensemble_size * per_core_batch_size, lambdas_config.dim))
        # ensemble CE
        logits = model([images, lambdas], training=False)
        ce = ensemble_crossentropy(labels, logits, FLAGS.ensemble_size)
        # entropy penalty for the lambda distribution
        entropy = FLAGS.tau * log_uniform_entropy(lambda_parameters)
        loss = ce - entropy
        scaled_loss = loss / strategy.num_replicas_in_sync

      gradients = tape.gradient(loss, lambda_parameters)
      tuner.apply_gradients(zip(gradients, lambda_parameters))
metrics['validation/loss_ce'].update_state( ce / strategy.num_replicas_in_sync) metrics['validation/loss_entropy'].update_state( entropy / strategy.num_replicas_in_sync) metrics['validation/loss'].update_state(scaled_loss) strategy.run(step_fn, args=(next(iterator), )) @tf.function def test_step(iterator, dataset_name, num_eval_samples=0): """Evaluation StepFn.""" n_samples = num_eval_samples if num_eval_samples >= 0 else -num_eval_samples if num_eval_samples >= 0: # the +1 accounts for the fact that we add the mean of lambdas ensemble_size = FLAGS.ensemble_size * (1 + n_samples) else: ensemble_size = FLAGS.ensemble_size * n_samples def step_fn(inputs): """Per-Replica StepFn.""" # Note that we don't use tf.tile for labels here images = inputs['features'] labels = inputs['labels'] images = tf.tile(images, [ensemble_size, 1, 1, 1]) # get lambdas samples = log_uniform_sample(n_samples, lambda_parameters) if num_eval_samples >= 0: lambdas = log_uniform_mean(lambda_parameters) lambdas = tf.expand_dims(lambdas, 1) lambdas = tf.concat((lambdas, samples), 1) else: lambdas = samples # lambdas with shape (ens size, samples, dim of lambdas) rep_lambdas = tf.repeat(lambdas, per_core_batch_size, axis=1) rep_lambdas = tf.reshape(rep_lambdas, (ensemble_size * per_core_batch_size, -1)) # eval on testsets logits = model([images, rep_lambdas], training=False) probs = tf.nn.softmax(logits) per_probs = tf.split(probs, num_or_size_splits=ensemble_size, axis=0) # per member performance and gibbs performance (average per member perf) if dataset_name == 'clean': for i in range(FLAGS.ensemble_size): # we record the first sample of lambdas per batch-ens member first_member_index = i * (ensemble_size // FLAGS.ensemble_size) member_probs = per_probs[first_member_index] member_loss = tf.keras.losses.sparse_categorical_crossentropy( labels, member_probs) metrics['test/nll_member_{}'.format(i)].update_state( member_loss) metrics['test/accuracy_member_{}'.format(i)].update_state( labels, member_probs) labels_tile = tf.tile(labels, [ensemble_size]) metrics['test/gibbs_nll'].update_state( tf.reduce_mean( tf.keras.losses.sparse_categorical_crossentropy( labels_tile, logits, from_logits=True))) metrics['test/gibbs_accuracy'].update_state(labels_tile, probs) # ensemble performance negative_log_likelihood = ensemble_crossentropy( labels, logits, ensemble_size) probs = tf.reduce_mean(per_probs, axis=0) if dataset_name == 'clean': metrics['test/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['test/accuracy'].update_state(labels, probs) metrics['test/ece'].add_batch(probs, label=labels) else: corrupt_metrics['test/nll_{}'.format( dataset_name)].update_state(negative_log_likelihood) corrupt_metrics['test/accuracy_{}'.format( dataset_name)].update_state(labels, probs) corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch( probs, label=labels) if dataset_name == 'clean': per_probs_stacked = tf.stack(per_probs, axis=0) diversity = rm.metrics.AveragePairwiseDiversity() diversity.add_batch(per_probs_stacked, num_models=ensemble_size) diversity_results = diversity.result() for k, v in diversity_results.items(): metrics['test/' + k].update_state(v) strategy.run(step_fn, args=(next(iterator), )) logging.info('--- Starting training using %d examples. 
---', train_sample_size) train_iterator = iter(train_dataset) validation_iterator = iter(validation_dataset) start_time = time.time() for epoch in range(initial_epoch, FLAGS.train_epochs): logging.info('Starting to run epoch: %s', epoch) for step in range(steps_per_epoch): train_step(train_iterator) do_tuning = (epoch >= FLAGS.tuning_warmup_epochs) if do_tuning and ((step + 1) % FLAGS.tuning_every_x_step == 0): tuning_step(validation_iterator) # clip lambda parameters if outside of range clip_lambda_parameters(lambda_parameters, lambdas_config) current_step = epoch * steps_per_epoch + (step + 1) max_steps = steps_per_epoch * FLAGS.train_epochs time_elapsed = time.time() - start_time steps_per_sec = float(current_step) / time_elapsed eta_seconds = (max_steps - current_step) / steps_per_sec message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. ' 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format( current_step / max_steps, epoch + 1, FLAGS.train_epochs, steps_per_sec, eta_seconds / 60, time_elapsed / 60)) if step % 20 == 0: logging.info(message) # evaluate on test data datasets_to_evaluate = {'clean': test_datasets['clean']} if (FLAGS.corruptions_interval > 0 and (epoch + 1) % FLAGS.corruptions_interval == 0): datasets_to_evaluate = test_datasets for dataset_name, test_dataset in datasets_to_evaluate.items(): test_iterator = iter(test_dataset) logging.info('Testing on dataset %s', dataset_name) for step in range(steps_per_eval): if step % 20 == 0: logging.info('Starting to run eval step %s of epoch: %s', step, epoch) test_step(test_iterator, dataset_name, FLAGS.num_eval_samples) logging.info('Done with testing on %s', dataset_name) corrupt_results = {} if (FLAGS.corruptions_interval > 0 and (epoch + 1) % FLAGS.corruptions_interval == 0): corrupt_results = utils.aggregate_corrupt_metrics( corrupt_metrics, corruption_types) logging.info('Train Loss: %.4f, Accuracy: %.2f%%', metrics['train/loss'].result(), metrics['train/accuracy'].result() * 100) logging.info('Validation Loss: %.4f, CE: %.4f, Entropy: %.4f', metrics['validation/loss'].result(), metrics['validation/loss_ce'].result(), metrics['validation/loss_entropy'].result()) logging.info('Test NLL: %.4f, Accuracy: %.2f%%', metrics['test/negative_log_likelihood'].result(), metrics['test/accuracy'].result() * 100) for i in range(FLAGS.ensemble_size): logging.info( 'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i, metrics['test/nll_member_{}'.format(i)].result(), metrics['test/accuracy_member_{}'.format(i)].result() * 100) total_results = { name: metric.result() for name, metric in metrics.items() } total_results.update({ name: metric.result() for name, metric in corrupt_metrics.items() }) total_results.update(corrupt_results) # Metrics from Robustness Metrics (like ECE) will return a dict with a # single key/value, instead of a scalar. 
total_results = { k: (list(v.values())[0] if isinstance(v, dict) else v) for k, v in total_results.items() } with summary_writer.as_default(): for name, result in total_results.items(): tf.summary.scalar(name, result, step=epoch + 1) for metric in metrics.values(): metric.reset_states() # save checkpoint and lambdas config if (FLAGS.checkpoint_interval > 0 and (epoch + 1) % FLAGS.checkpoint_interval == 0): checkpoint_name = checkpoint.save( os.path.join(FLAGS.output_dir, 'checkpoint')) lambdas_cf = lambdas_config.get_config() filepath = os.path.join(FLAGS.output_dir, 'lambdas_config.p') with tf.io.gfile.GFile(filepath, 'wb') as fp: pickle.dump(lambdas_cf, fp, protocol=pickle.HIGHEST_PROTOCOL) logging.info('Saved checkpoint to %s', checkpoint_name) with summary_writer.as_default(): hp.hparams({ 'base_learning_rate': FLAGS.base_learning_rate, 'one_minus_momentum': FLAGS.one_minus_momentum, 'l2': FLAGS.l2, 'random_sign_init': FLAGS.random_sign_init, 'fast_weight_lr_multiplier': FLAGS.fast_weight_lr_multiplier, })
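# A minimal sketch (not part of the script above) of the linear learning-rate
# scaling rule it applies before building the warm-up schedule, assuming the
# reference batch size of 128 that the script hard-codes in
# `base_lr = FLAGS.base_learning_rate * batch_size / 128`.
def scale_learning_rate(base_learning_rate, batch_size,
                        reference_batch_size=128):
    """Linearly scale the learning rate with the global batch size."""
    return base_learning_rate * batch_size / reference_batch_size

# Example: a base rate of 0.1 at a global batch size of 512 becomes 0.4.
print(scale_learning_rate(0.1, 512))  # 0.4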
def main(hParams, n_run, total_timesteps): nsteps = hParams['N_STEPS'] n_epochs = hParams['N_EPOCHS'] n_train = 4 n_minibatch = 8 log_loss_int = 1 save_int = 5 test_int = 10 test_episodes = 5 gamma = 0.95 lr = hParams[HP_LEARNING_RATE] vf_coef = hParams[HP_VF_COEF] ent_coef = hParams[HP_ENT_COEF] save_dir = 'lr' + str(lr) + 'vc' + str(vf_coef) + 'ec' + str(ent_coef) testenvfn = SonicEnv.make_env_3 current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") log_dir = 'logs/sonic_long_test/run-' + str(n_run) summ_writer = tf.summary.create_file_writer(log_dir) env = SubprocVecEnv([SonicEnv.make_env_3]) nenv = env.num_envs state_size = env.observation_space.shape num_actions = env.action_space.n pgnet = PGNetwork(state_size, num_actions, lr=lr, vf_coef=vf_coef, ent_coef=ent_coef) # Runner used to create training data runner = SonicEnvRunner(env, pgnet, nsteps, gamma) # total_timesteps = int(n_epochs * nsteps * nenv) nbatch = nenv * nsteps print("Total updates to run: ", total_timesteps // nbatch) for update in range(1, total_timesteps // nbatch + 1): print("\nUpdate #{}:".format(update)) states_mb, actions_mb, values_mb, rewards_mb, next_dones_mb = runner.run( ) for _ in range(n_train): indices = np.arange(nbatch) np.random.shuffle(indices) for start in range(0, nbatch, nbatch // n_minibatch): end = start + nbatch // n_minibatch bind = indices[start:end] policy_loss, entropy_loss, vf_loss, loss = pgnet.fit_gradient( states_mb[bind], actions_mb[bind], rewards_mb[bind], values_mb[bind]) WeightWriter(summ_writer, pgnet, (Conv2D, Dense), global_step=update) r2 = 1 - (np.var(rewards_mb - values_mb) / np.var(rewards_mb)) with summ_writer.as_default(): tf.summary.scalar("PolicyLoss", policy_loss, step=update) tf.summary.scalar("EntropyLoss", entropy_loss, step=update) tf.summary.scalar("ValueFunctionLoss", vf_loss, step=update) tf.summary.scalar("Loss", loss, step=update) tf.summary.scalar("R-squared", r2, step=update) if update % log_loss_int == 0: print("PolicyLoss:", policy_loss) print("EntropyLoss: ", entropy_loss) print("ValueFunctionLoss: ", vf_loss) print("Loss: ", loss) if update % save_int == 0: pgnet.model.save_weights('sonic_long_test/' + save_dir + '/my_checkpoint') print("Model Saved") if update % test_int == 0: TestRewardWriter(summ_writer, testenvfn, pgnet, test_episodes, global_step=update) with summ_writer.as_default(): hp.hparams(hParams) env.close()
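# A standalone NumPy sketch of the R-squared (explained-variance) diagnostic
# logged above; `explained_variance` is a name chosen here for illustration,
# not one from the original script.
import numpy as np

def explained_variance(returns, values):
    """Returns 1 - Var[returns - values] / Var[returns]; 1.0 is a perfect value fit."""
    returns = np.asarray(returns)
    values = np.asarray(values)
    variance = np.var(returns)
    if variance == 0:
        return np.nan  # undefined when all returns are identical
    return 1.0 - np.var(returns - values) / variance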
def main(argv): fmt = '[%(filename)s:%(lineno)s] %(message)s' formatter = logging.PythonFormatter(fmt) logging.get_absl_handler().setFormatter(formatter) del argv # unused arg tf.io.gfile.makedirs(FLAGS.output_dir) logging.info('Saving checkpoints at %s', FLAGS.output_dir) tf.random.set_seed(FLAGS.seed) data_dir = None if FLAGS.use_gpu: logging.info('Use GPU') strategy = tf.distribute.MirroredStrategy() else: logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local') data_dir = FLAGS.data_dir resolver = tf.distribute.cluster_resolver.TPUClusterResolver( tpu=FLAGS.tpu) tf.config.experimental_connect_to_cluster(resolver) tf.tpu.experimental.initialize_tpu_system(resolver) strategy = tf.distribute.TPUStrategy(resolver) ds_info = tfds.builder(FLAGS.dataset).info batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores train_dataset_size = (ds_info.splits['train'].num_examples * FLAGS.train_proportion) steps_per_epoch = int(train_dataset_size / batch_size) logging.info('Steps per epoch %s', steps_per_epoch) logging.info('Size of the dataset %s', ds_info.splits['train'].num_examples) logging.info('Train proportion %s', FLAGS.train_proportion) steps_per_eval = ds_info.splits['test'].num_examples // batch_size num_classes = ds_info.features['label'].num_classes aug_params = { 'augmix': FLAGS.augmix, 'aug_count': FLAGS.aug_count, 'augmix_depth': FLAGS.augmix_depth, 'augmix_prob_coeff': FLAGS.augmix_prob_coeff, 'augmix_width': FLAGS.augmix_width, } # Note that stateless_{fold_in,split} may incur a performance cost, but a # quick side-by-side test seemed to imply this was minimal. seeds = tf.random.experimental.stateless_split( [FLAGS.seed, FLAGS.seed + 1], 2)[:, 0] train_builder = ub.datasets.get(FLAGS.dataset, download_data=FLAGS.download_data, split=tfds.Split.TRAIN, seed=seeds[0], aug_params=aug_params, validation_percent=1. - FLAGS.train_proportion, data_dir=data_dir) train_dataset = train_builder.load(batch_size=batch_size) validation_dataset = None steps_per_validation = 0 if FLAGS.train_proportion < 1.0: validation_builder = ub.datasets.get(FLAGS.dataset, split=tfds.Split.VALIDATION, validation_percent=1. 
- FLAGS.train_proportion, data_dir=data_dir) validation_dataset = validation_builder.load(batch_size=batch_size) validation_dataset = strategy.experimental_distribute_dataset( validation_dataset) steps_per_validation = validation_builder.num_examples // batch_size clean_test_builder = ub.datasets.get(FLAGS.dataset, split=tfds.Split.TEST, data_dir=data_dir) clean_test_dataset = clean_test_builder.load(batch_size=batch_size) train_dataset = strategy.experimental_distribute_dataset(train_dataset) test_datasets = { 'clean': strategy.experimental_distribute_dataset(clean_test_dataset), } steps_per_epoch = train_builder.num_examples // batch_size steps_per_eval = clean_test_builder.num_examples // batch_size num_classes = 100 if FLAGS.dataset == 'cifar100' else 10 if FLAGS.corruptions_interval > 0: if FLAGS.dataset == 'cifar100': data_dir = FLAGS.cifar100_c_path corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset) for corruption_type in corruption_types: for severity in range(1, 6): dataset = ub.datasets.get( f'{FLAGS.dataset}_corrupted', corruption_type=corruption_type, severity=severity, split=tfds.Split.TEST, data_dir=data_dir).load(batch_size=batch_size) test_datasets[f'{corruption_type}_{severity}'] = ( strategy.experimental_distribute_dataset(dataset)) summary_writer = tf.summary.create_file_writer( os.path.join(FLAGS.output_dir, 'summaries')) with strategy.scope(): logging.info('Building ResNet model') model = ub.models.wide_resnet(input_shape=(32, 32, 3), depth=28, width_multiplier=10, num_classes=num_classes, l2=FLAGS.l2, hps=_extract_hyperparameter_dictionary(), seed=seeds[1], version=2) logging.info('Model input shape: %s', model.input_shape) logging.info('Model output shape: %s', model.output_shape) logging.info('Model number of weights: %s', model.count_params()) # Linearly scale learning rate and the decay epochs by vanilla settings. 
base_lr = FLAGS.base_learning_rate * batch_size / 128 lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200 for start_epoch_str in FLAGS.lr_decay_epochs] lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule( steps_per_epoch, base_lr, decay_ratio=FLAGS.lr_decay_ratio, decay_epochs=lr_decay_epochs, warmup_epochs=FLAGS.lr_warmup_epochs) optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True) metrics = { 'train/negative_log_likelihood': tf.keras.metrics.Mean(), 'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'train/loss': tf.keras.metrics.Mean(), 'train/ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins), 'test/negative_log_likelihood': tf.keras.metrics.Mean(), 'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'test/ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins), } if validation_dataset: metrics.update({ 'validation/negative_log_likelihood': tf.keras.metrics.Mean(), 'validation/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'validation/ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins), }) if FLAGS.corruptions_interval > 0: corrupt_metrics = {} for intensity in range(1, 6): for corruption in corruption_types: dataset_name = '{0}_{1}'.format(corruption, intensity) corrupt_metrics['test/nll_{}'.format(dataset_name)] = ( tf.keras.metrics.Mean()) corrupt_metrics['test/accuracy_{}'.format( dataset_name)] = ( tf.keras.metrics.SparseCategoricalAccuracy()) corrupt_metrics['test/ece_{}'.format(dataset_name)] = ( rm.metrics.ExpectedCalibrationError( num_bins=FLAGS.num_bins)) checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir) initial_epoch = 0 if latest_checkpoint: # checkpoint.restore must be within a strategy.scope() so that optimizer # slot variables are mirrored. checkpoint.restore(latest_checkpoint) logging.info('Loaded checkpoint %s', latest_checkpoint) initial_epoch = optimizer.iterations.numpy() // steps_per_epoch @tf.function def train_step(iterator): """Training StepFn.""" def step_fn(inputs): """Per-Replica StepFn.""" images = inputs['features'] labels = inputs['labels'] if FLAGS.augmix and FLAGS.aug_count >= 1: # Index 0 at augmix processing is the unperturbed image. # We take just 1 augmented image from the returned augmented images. images = images[:, 1, ...] with tf.GradientTape() as tape: logits = model(images, training=True) if FLAGS.label_smoothing == 0.: negative_log_likelihood = tf.reduce_mean( tf.keras.losses.sparse_categorical_crossentropy( labels, logits, from_logits=True)) else: one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32), num_classes) negative_log_likelihood = tf.reduce_mean( tf.keras.losses.categorical_crossentropy( one_hot_labels, logits, from_logits=True, label_smoothing=FLAGS.label_smoothing)) l2_loss = sum(model.losses) loss = negative_log_likelihood + l2_loss # Scale the loss given the TPUStrategy will reduce sum all gradients. 
scaled_loss = loss / strategy.num_replicas_in_sync grads = tape.gradient(scaled_loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) probs = tf.nn.softmax(logits) metrics['train/ece'].add_batch(probs, label=labels) metrics['train/loss'].update_state(loss) metrics['train/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['train/accuracy'].update_state(labels, logits) for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)): strategy.run(step_fn, args=(next(iterator), )) @tf.function def test_step(iterator, dataset_split, dataset_name, num_steps): """Evaluation StepFn.""" def step_fn(inputs): """Per-Replica StepFn.""" images = inputs['features'] labels = inputs['labels'] logits = model(images, training=False) probs = tf.nn.softmax(logits) negative_log_likelihood = tf.reduce_mean( tf.keras.losses.sparse_categorical_crossentropy(labels, probs)) if dataset_name == 'clean': metrics[ f'{dataset_split}/negative_log_likelihood'].update_state( negative_log_likelihood) metrics[f'{dataset_split}/accuracy'].update_state( labels, probs) metrics[f'{dataset_split}/ece'].add_batch(probs, label=labels) else: corrupt_metrics['test/nll_{}'.format( dataset_name)].update_state(negative_log_likelihood) corrupt_metrics['test/accuracy_{}'.format( dataset_name)].update_state(labels, probs) corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch( probs, label=labels) for _ in tf.range(tf.cast(num_steps, tf.int32)): strategy.run(step_fn, args=(next(iterator), )) metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()}) metrics.update({'train/ms_per_example': tf.keras.metrics.Mean()}) train_iterator = iter(train_dataset) start_time = time.time() tb_callback = None if FLAGS.collect_profile: tb_callback = tf.keras.callbacks.TensorBoard(profile_batch=(100, 102), log_dir=os.path.join( FLAGS.output_dir, 'logs')) tb_callback.set_model(model) for epoch in range(initial_epoch, FLAGS.train_epochs): logging.info('Starting to run epoch: %s', epoch) if tb_callback: tb_callback.on_epoch_begin(epoch) train_start_time = time.time() train_step(train_iterator) ms_per_example = (time.time() - train_start_time) * 1e6 / batch_size metrics['train/ms_per_example'].update_state(ms_per_example) current_step = (epoch + 1) * steps_per_epoch max_steps = steps_per_epoch * FLAGS.train_epochs time_elapsed = time.time() - start_time steps_per_sec = float(current_step) / time_elapsed eta_seconds = (max_steps - current_step) / steps_per_sec message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. ' 'ETA: {:.0f} min. 
Time elapsed: {:.0f} min'.format( current_step / max_steps, epoch + 1, FLAGS.train_epochs, steps_per_sec, eta_seconds / 60, time_elapsed / 60)) logging.info(message) if tb_callback: tb_callback.on_epoch_end(epoch) if validation_dataset: validation_iterator = iter(validation_dataset) test_step(validation_iterator, 'validation', 'clean', steps_per_validation) datasets_to_evaluate = {'clean': test_datasets['clean']} if (FLAGS.corruptions_interval > 0 and (epoch + 1) % FLAGS.corruptions_interval == 0): datasets_to_evaluate = test_datasets for dataset_name, test_dataset in datasets_to_evaluate.items(): test_iterator = iter(test_dataset) logging.info('Testing on dataset %s', dataset_name) logging.info('Starting to run eval at epoch: %s', epoch) test_start_time = time.time() test_step(test_iterator, 'test', dataset_name, steps_per_eval) ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size metrics['test/ms_per_example'].update_state(ms_per_example) logging.info('Done with testing on %s', dataset_name) corrupt_results = {} if (FLAGS.corruptions_interval > 0 and (epoch + 1) % FLAGS.corruptions_interval == 0): corrupt_results = utils.aggregate_corrupt_metrics( corrupt_metrics, corruption_types) logging.info('Train Loss: %.4f, Accuracy: %.2f%%', metrics['train/loss'].result(), metrics['train/accuracy'].result() * 100) logging.info('Test NLL: %.4f, Accuracy: %.2f%%', metrics['test/negative_log_likelihood'].result(), metrics['test/accuracy'].result() * 100) total_results = { name: metric.result() for name, metric in metrics.items() } total_results.update(corrupt_results) # Metrics from Robustness Metrics (like ECE) will return a dict with a # single key/value, instead of a scalar. total_results = { k: (list(v.values())[0] if isinstance(v, dict) else v) for k, v in total_results.items() } with summary_writer.as_default(): for name, result in total_results.items(): tf.summary.scalar(name, result, step=epoch + 1) for metric in metrics.values(): metric.reset_states() if (FLAGS.checkpoint_interval > 0 and (epoch + 1) % FLAGS.checkpoint_interval == 0): checkpoint_name = checkpoint.save( os.path.join(FLAGS.output_dir, 'checkpoint')) logging.info('Saved checkpoint to %s', checkpoint_name) final_checkpoint_name = checkpoint.save( os.path.join(FLAGS.output_dir, 'checkpoint')) logging.info('Saved last checkpoint to %s', final_checkpoint_name) with summary_writer.as_default(): hp.hparams({ 'base_learning_rate': FLAGS.base_learning_rate, 'one_minus_momentum': FLAGS.one_minus_momentum, 'l2': FLAGS.l2, })
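# A small sketch of the result-flattening idiom used above: Robustness Metrics
# objects (such as ExpectedCalibrationError) return a one-entry dict from
# result(), while Keras metrics return a scalar; this normalizes both to
# scalars before writing summaries.
def to_scalar_results(results):
    return {
        k: (list(v.values())[0] if isinstance(v, dict) else v)
        for k, v in results.items()
    }

# Example: {'test/ece': {'ece': 0.02}, 'test/accuracy': 0.95}
# becomes  {'test/ece': 0.02, 'test/accuracy': 0.95}.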
def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  # Initialize distribution strategy on flag-specified accelerator.
  strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                              FLAGS.use_gpu, FLAGS.tpu)
  use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

  train_batch_size = FLAGS.train_batch_size * FLAGS.num_cores
  if use_tpu:
    logging.info(
        'Due to TPU requiring static shapes, we must fix the eval batch size '
        'to the train batch size: %d', train_batch_size)
    eval_batch_size = train_batch_size
  else:
    eval_batch_size = (FLAGS.eval_batch_size *
                       FLAGS.num_cores) // FLAGS.num_dropout_samples_eval

  # As per the Kaggle challenge, we have split sizes:
  # train: 35,126
  # validation: 10,906 (currently unused)
  # test: 42,670
  ds_info = tfds.builder('diabetic_retinopathy_detection').info
  steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
  steps_per_validation_eval = (
      ds_info.splits['validation'].num_examples // eval_batch_size)
  steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size

  data_dir = FLAGS.data_dir
  dataset_train_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='train', data_dir=data_dir)
  dataset_train = dataset_train_builder.load(batch_size=train_batch_size)
  dataset_train = strategy.experimental_distribute_dataset(dataset_train)
  dataset_validation_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='validation', data_dir=data_dir)
  dataset_validation = dataset_validation_builder.load(
      batch_size=eval_batch_size)
  dataset_validation = strategy.experimental_distribute_dataset(
      dataset_validation)
  dataset_test_builder = ub.datasets.get(
      'diabetic_retinopathy_detection', split='test', data_dir=data_dir)
  dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)
  dataset_test = strategy.experimental_distribute_dataset(dataset_test)

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-50 MC Dropout model.')
    model = ub.models.resnet50_dropout(
        input_shape=utils.load_input_shape(dataset_train),
        num_classes=1,  # binary classification task
        dropout_rate=FLAGS.dropout_rate,
        filterwise_dropout=FLAGS.filterwise_dropout)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Scale the decay epochs relative to the default training length.
    base_lr = FLAGS.base_learning_rate
    lr_decay_epochs = [
        (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
        for start_epoch_str in FLAGS.lr_decay_epochs
    ]
    lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True)
    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=use_tpu, num_bins=FLAGS.num_bins)

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope()
      # so that optimizer slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  # Finally, define the AUC metrics outside the accelerator scope for CPU/GPU
  # eval; tf.keras.metrics.AUC would cause an error on TPU.
  if not use_tpu:
    metrics.update({
        'train/auc': tf.keras.metrics.AUC(),
        'validation/auc': tf.keras.metrics.AUC(),
        'test/auc': tf.keras.metrics.AUC()
    })

  @tf.function
  def train_step(iterator):
    """Training step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(
                y_true=tf.expand_dims(labels, axis=-1),
                y_pred=logits,
                from_logits=True))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

      probs = tf.nn.sigmoid(logits)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, probs)
      # The AUC and ECE metrics are only defined for CPU/GPU runs above.
      if not use_tpu:
        metrics['train/auc'].update_state(labels, probs)
        metrics['train/ece'].update_state(labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_split):
    """Evaluation step function."""

    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = tf.convert_to_tensor(inputs['labels'])
      logits_list = []
      for _ in range(FLAGS.num_dropout_samples_eval):
        logits = model(images, training=False)
        logits = tf.squeeze(logits, axis=-1)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        logits_list.append(logits)

      # Logits dimension is (num_samples, batch_size).
      logits_list = tf.stack(logits_list, axis=0)
      probs_list = tf.nn.sigmoid(logits_list)
      probs = tf.reduce_mean(probs_list, axis=0)
      labels_broadcasted = tf.broadcast_to(
          labels, [FLAGS.num_dropout_samples_eval, labels.shape[0]])
      log_likelihoods = -tf.keras.losses.binary_crossentropy(
          labels_broadcasted, logits_list, from_logits=True)
      negative_log_likelihood = tf.reduce_mean(
          -tf.reduce_logsumexp(log_likelihoods, axis=[0]) +
          tf.math.log(float(FLAGS.num_dropout_samples_eval)))
      metrics[dataset_split + '/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics[dataset_split + '/accuracy'].update_state(labels, probs)
      # The AUC and ECE metrics are only defined for CPU/GPU runs above.
      if not use_tpu:
        metrics[dataset_split + '/auc'].update_state(labels, probs)
        metrics[dataset_split + '/ece'].update_state(labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

  start_time = time.time()
  train_iterator = iter(dataset_train)
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch + 1)
    for step in range(steps_per_epoch):
      train_step(train_iterator)

      current_step = epoch * steps_per_epoch + (step + 1)
      max_steps = steps_per_epoch * FLAGS.train_epochs
      time_elapsed = time.time() - start_time
      steps_per_sec = float(current_step) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
      if step % 20 == 0:
        logging.info(message)

    validation_iterator = iter(dataset_validation)
    for step in range(steps_per_validation_eval):
      if step % 20 == 0:
        logging.info('Starting to run validation eval step %s of epoch: %s',
                     step, epoch + 1)
      test_step(validation_iterator, 'validation')

    test_iterator = iter(dataset_test)
    for step in range(steps_per_test_eval):
      if step % 20 == 0:
        logging.info('Starting to run test eval step %s of epoch: %s', step,
                     epoch + 1)
      test_start_time = time.time()
      test_step(test_iterator, 'test')
      ms_per_example = (time.time() - test_start_time) * 1e6 / eval_batch_size
      metrics['test/ms_per_example'].update_state(ms_per_example)

    # Log and write to summary the epoch metrics.
    utils.log_epoch_metrics(metrics=metrics, use_tpu=use_tpu)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

      # TODO(nband): debug checkpointing
      # Also save the Keras model, due to a checkpoint.save issue.
      keras_model_name = os.path.join(FLAGS.output_dir,
                                      f'keras_model_{epoch + 1}')
      model.save(keras_model_name)
      logging.info('Saved keras model to %s', keras_model_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)

  keras_model_name = os.path.join(FLAGS.output_dir,
                                  f'keras_model_{FLAGS.train_epochs}')
  model.save(keras_model_name)
  logging.info('Saved keras model to %s', keras_model_name)

  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'dropout_rate': FLAGS.dropout_rate,
        'l2': FLAGS.l2,
    })
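# A NumPy sketch of the Monte Carlo marginal likelihood computed in the
# evaluation step above: NLL = -log((1/S) * sum_s p_s) = -logsumexp_s(log p_s)
# + log(S), evaluated per example and averaged over the batch. The function
# name is chosen here for illustration.
import numpy as np
from scipy.special import logsumexp

def mc_negative_log_likelihood(log_likelihoods):
    """log_likelihoods: array of shape (num_samples, batch_size)."""
    num_samples = log_likelihoods.shape[0]
    return np.mean(-logsumexp(log_likelihoods, axis=0) + np.log(num_samples))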
def main(argv): del argv # unused arg tf.io.gfile.makedirs(FLAGS.output_dir) logging.info('Saving checkpoints at %s', FLAGS.output_dir) tf.random.set_seed(FLAGS.seed) train_batch_size = (FLAGS.per_core_batch_size * FLAGS.num_cores // FLAGS.batch_repetitions) test_batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores steps_per_epoch = APPROX_IMAGENET_TRAIN_IMAGES // train_batch_size steps_per_eval = IMAGENET_VALIDATION_IMAGES // test_batch_size data_dir = FLAGS.data_dir if FLAGS.use_gpu: logging.info('Use GPU') strategy = tf.distribute.MirroredStrategy() else: logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local') resolver = tf.distribute.cluster_resolver.TPUClusterResolver( tpu=FLAGS.tpu) tf.config.experimental_connect_to_cluster(resolver) tf.tpu.experimental.initialize_tpu_system(resolver) strategy = tf.distribute.TPUStrategy(resolver) train_builder = ub.datasets.ImageNetDataset( split=tfds.Split.TRAIN, use_bfloat16=FLAGS.use_bfloat16, data_dir=data_dir) train_dataset = train_builder.load(batch_size=train_batch_size, strategy=strategy) test_builder = ub.datasets.ImageNetDataset(split=tfds.Split.TEST, use_bfloat16=FLAGS.use_bfloat16, data_dir=data_dir) test_dataset = test_builder.load(batch_size=test_batch_size, strategy=strategy) if FLAGS.use_bfloat16: tf.keras.mixed_precision.set_global_policy('mixed_bfloat16') with strategy.scope(): logging.info('Building Keras ResNet-50 model') model = ub.models.resnet50_het_mimo( input_shape=(FLAGS.ensemble_size, 224, 224, 3), num_classes=NUM_CLASSES, ensemble_size=FLAGS.ensemble_size, num_factors=FLAGS.num_factors, temperature=FLAGS.temperature, num_mc_samples=FLAGS.num_mc_samples, share_het_layer=FLAGS.share_het_layer, width_multiplier=FLAGS.width_multiplier) logging.info('Model input shape: %s', model.input_shape) logging.info('Model output shape: %s', model.output_shape) logging.info('Model number of weights: %s', model.count_params()) # Scale learning rate and decay epochs by vanilla settings. base_lr = FLAGS.base_learning_rate * train_batch_size / 256 decay_epochs = [ (FLAGS.train_epochs * 30) // 90, (FLAGS.train_epochs * 60) // 90, (FLAGS.train_epochs * 80) // 90, ] learning_rate = ub.schedules.WarmUpPiecewiseConstantSchedule( steps_per_epoch=steps_per_epoch, base_learning_rate=base_lr, decay_ratio=0.1, decay_epochs=decay_epochs, warmup_epochs=5) optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True) metrics = { 'train/negative_log_likelihood': tf.keras.metrics.Mean(), 'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'train/loss': tf.keras.metrics.Mean(), 'train/ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins), 'test/negative_log_likelihood': tf.keras.metrics.Mean(), 'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'test/ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins), 'test/diversity': rm.metrics.AveragePairwiseDiversity(), } for i in range(FLAGS.ensemble_size): metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean() metrics['test/accuracy_member_{}'.format(i)] = ( tf.keras.metrics.SparseCategoricalAccuracy()) logging.info('Finished building Keras ResNet-50 model') checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir) initial_epoch = 0 if latest_checkpoint: # checkpoint.restore must be within a strategy.scope() so that optimizer # slot variables are mirrored. 
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  @tf.function
  def train_step(iterator):
    """Training StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']
      batch_size = tf.shape(images)[0]

      main_shuffle = tf.random.shuffle(
          tf.tile(tf.range(batch_size), [FLAGS.batch_repetitions]))
      to_shuffle = tf.cast(
          tf.cast(tf.shape(main_shuffle)[0], tf.float32) *
          (1. - FLAGS.input_repetition_probability), tf.int32)
      shuffle_indices = [
          tf.concat(
              [tf.random.shuffle(main_shuffle[:to_shuffle]),
               main_shuffle[to_shuffle:]],
              axis=0) for _ in range(FLAGS.ensemble_size)
      ]
      images = tf.stack(
          [tf.gather(images, indices, axis=0) for indices in shuffle_indices],
          axis=1)
      labels = tf.stack(
          [tf.gather(labels, indices, axis=0) for indices in shuffle_indices],
          axis=1)

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)

        negative_log_likelihood = tf.reduce_mean(
            tf.reduce_sum(
                tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits, from_logits=True),
                axis=1))
        filtered_variables = []
        for var in model.trainable_variables:
          # Apply l2 on the conv/dense kernels and bias terms. This excludes
          # the BN parameters (gamma/beta), but pay caution to the naming
          # scheme when changing architectures.
          if 'kernel' in var.name or 'bias' in var.name:
            filtered_variables.append(tf.reshape(var, (-1,)))

        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        loss = negative_log_likelihood + l2_loss
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

      probs = tf.nn.softmax(tf.reshape(logits, [-1, NUM_CLASSES]))
      flat_labels = tf.reshape(labels, [-1])
      metrics['train/ece'].add_batch(probs, label=flat_labels)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(flat_labels, probs)

    for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator):
    """Evaluation StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']
      images = tf.tile(
          tf.expand_dims(images, 1), [1, FLAGS.ensemble_size, 1, 1, 1])
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
      probs = tf.nn.softmax(logits)

      per_probs = tf.transpose(probs, perm=[1, 0, 2])
      metrics['test/diversity'].add_batch(per_probs)

      for i in range(FLAGS.ensemble_size):
        member_probs = probs[:, i]
        member_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, member_probs)
        metrics['test/nll_member_{}'.format(i)].update_state(member_loss)
        metrics['test/accuracy_member_{}'.format(i)].update_state(
            labels, member_probs)

      # Negative log marginal likelihood computed in a numerically-stable way.
labels_tiled = tf.tile(tf.expand_dims(labels, 1), [1, FLAGS.ensemble_size]) log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy( labels_tiled, logits, from_logits=True) negative_log_likelihood = tf.reduce_mean( -tf.reduce_logsumexp(log_likelihoods, axis=[1]) + tf.math.log(float(FLAGS.ensemble_size))) probs = tf.math.reduce_mean(probs, axis=1) # marginalize metrics['test/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['test/accuracy'].update_state(labels, probs) metrics['test/ece'].add_batch(probs, label=labels) for _ in tf.range(tf.cast(steps_per_eval, tf.int32)): strategy.run(step_fn, args=(next(iterator), )) metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()}) train_iterator = iter(train_dataset) start_time = time.time() for epoch in range(initial_epoch, FLAGS.train_epochs): logging.info('Starting to run epoch: %s', epoch) train_step(train_iterator) current_step = (epoch + 1) * steps_per_epoch max_steps = steps_per_epoch * FLAGS.train_epochs time_elapsed = time.time() - start_time steps_per_sec = float(current_step) / time_elapsed eta_seconds = (max_steps - current_step) / steps_per_sec message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. ' 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format( current_step / max_steps, epoch + 1, FLAGS.train_epochs, steps_per_sec, eta_seconds / 60, time_elapsed / 60)) logging.info(message) test_iterator = iter(test_dataset) logging.info('Starting to run eval of epoch: %s', epoch) test_start_time = time.time() test_step(test_iterator) ms_per_example = (time.time() - test_start_time) * 1e6 / test_batch_size metrics['test/ms_per_example'].update_state(ms_per_example) logging.info('Train Loss: %.4f, Accuracy: %.2f%%', metrics['train/loss'].result(), metrics['train/accuracy'].result() * 100) logging.info('Test NLL: %.4f, Accuracy: %.2f%%', metrics['test/negative_log_likelihood'].result(), metrics['test/accuracy'].result() * 100) for i in range(FLAGS.ensemble_size): logging.info( 'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i, metrics['test/nll_member_{}'.format(i)].result(), metrics['test/accuracy_member_{}'.format(i)].result() * 100) total_results = { name: metric.result() for name, metric in metrics.items() } # Results from Robustness Metrics themselves return a dict, so flatten them. total_results = utils.flatten_dictionary(total_results) with summary_writer.as_default(): for name, result in total_results.items(): tf.summary.scalar(name, result, step=epoch + 1) for _, metric in metrics.items(): metric.reset_states() if (FLAGS.checkpoint_interval > 0 and (epoch + 1) % FLAGS.checkpoint_interval == 0): checkpoint_name = checkpoint.save( os.path.join(FLAGS.output_dir, 'checkpoint')) logging.info('Saved checkpoint to %s', checkpoint_name) final_save_name = os.path.join(FLAGS.output_dir, 'model') model.save(final_save_name) logging.info('Saved model to %s', final_save_name) with summary_writer.as_default(): hp.hparams({ 'base_learning_rate': FLAGS.base_learning_rate, 'one_minus_momentum': FLAGS.one_minus_momentum, 'l2': FLAGS.l2, 'batch_repetitions': FLAGS.batch_repetitions, })
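# A NumPy sketch of the MIMO input pipeline in the training step above: each
# ensemble member sees an independently shuffled copy of the batch (repeated
# `batch_repetitions` times), except for a shared tail whose size is set by
# the input-repetition probability. The function name is chosen here for
# illustration.
import numpy as np

def mimo_shuffle_indices(batch_size, ensemble_size, batch_repetitions,
                         input_repetition_probability, rng):
    main = rng.permutation(np.tile(np.arange(batch_size), batch_repetitions))
    to_shuffle = int(len(main) * (1.0 - input_repetition_probability))
    return [
        np.concatenate(
            [rng.permutation(main[:to_shuffle]), main[to_shuffle:]])
        for _ in range(ensemble_size)
    ]

# Example: mimo_shuffle_indices(8, 3, 2, 0.5, np.random.default_rng(0))
# returns three index arrays of length 16 that agree on their last 8 entries.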
def write_hparams_v2(writer, hparams: dict): hparams = _copy_and_clean_hparams(hparams) hparams = _set_precision_if_missing(hparams) with writer.as_default(): hp.hparams(hparams)
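# A hypothetical usage sketch for write_hparams_v2; it assumes the helpers
# _copy_and_clean_hparams and _set_precision_if_missing are defined in the
# same module, as referenced above.
import tensorflow as tf

writer = tf.summary.create_file_writer('/tmp/hparams_demo')
write_hparams_v2(writer, {'learning_rate': 1e-3, 'batch_size': 128})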
def main(): print("Tensorflow version {}".format(tf.__version__)) # Configuration MODEL_NAME = "ALAE_CONV_V1" generate_mnist_samples = False generate_samples_tensorboard = True PRINT_IT = 50 RESULT_IT = 100 SAVE_WEIGHT_IT = 5000 # Network configuration & hyper parameters EPOCHS = 100 BATCH_SIZE = 128 Z_DIM = 100 LATENT_DIM = 50 GAMMA_GP = 10 K_RECONST_KL = 0.5 # Latent space quality, Pure reconstruction & Kullback Leibler ratio # Learning Rate for Discriminator, Generator, Latent Space LR_D_G_L = [0.0001,0.0004,0.0002] # Best # # Manage folders # original_mnist_samples = os.path.join("results", "mnist_original") current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") mnist_samples_ = os.path.join("results", MODEL_NAME + "_" + current_time) folder_to_create = [original_mnist_samples,mnist_samples_] # Create folders for folder in folder_to_create: if not os.path.exists(folder): os.makedirs(folder) # Define folders checkpoint_path = os.path.join("checkpoint", MODEL_NAME) original_mnist_samples = os.path.join(original_mnist_samples, "samples_{}.png") mnist_samples = os.path.join(mnist_samples_, "alae_samples_{}.png") static_mnist_samples = os.path.join(mnist_samples_, "static_alae_samples_{}.png") train_log_dir = os.path.join("logs", "tensorboard", MODEL_NAME + "_" + current_time) train_summary_writer = tf.summary.create_file_writer(train_log_dir) # Log hyper parameters HP_BATCH_SIZE = hp.HParam('HP_BATCH_SIZE', hp.Discrete([64,128,256,512,1024])) HP_Z_DIM = hp.HParam('HP_Z_DIM', hp.Discrete([50,100,200])) HP_LATENT_DIM = hp.HParam('HP_LATENT_DIM', hp.Discrete([30,50,70])) HP_GAMMA_GP = hp.HParam('HP_GAMMA_GP', hp.Discrete([2,5,10])) HP_K_RECONST_KL = hp.HParam('HP_K_RECONST_KL', hp.RealInterval (0.,1.)) HP_LR_GENERATOR = hp.HParam('HP_LR_GENERATOR', hp.RealInterval (0.,0.1)) HP_LR_DISCRIMINATOR = hp.HParam('HP_LR_DISCRIMINATOR', hp.RealInterval (0.,0.1)) HP_LR_LATENT = hp.HParam('HP_LR_LATENT', hp.RealInterval (0.,0.1)) hparams = { HP_BATCH_SIZE: BATCH_SIZE, HP_Z_DIM: Z_DIM, HP_LATENT_DIM: LATENT_DIM, HP_GAMMA_GP: GAMMA_GP, HP_K_RECONST_KL: K_RECONST_KL, HP_LR_DISCRIMINATOR: LR_D_G_L[0], HP_LR_GENERATOR: LR_D_G_L[1], HP_LR_LATENT: LR_D_G_L[2], } # Log hyper parameters METRIC_LATENT_LOST = 'latent_loss' with train_summary_writer.as_default(): hp.hparams_config(hparams, metrics=[hp.Metric(METRIC_LATENT_LOST, display_name='Latent Loss')]) hp.hparams(hparams, trial_id = MODEL_NAME + "_" + current_time) tf.summary.scalar(METRIC_LATENT_LOST, 0, step=1) # Do useful stuff seed = 2020 np.random.seed(seed) tf.random.set_seed(seed) # # Load data # (x_train, _), (_, _) = datasets.mnist.load_data() # Prepare the data train_dataset = tf.data.Dataset.from_tensor_slices(x_train) train_dataset = train_dataset.batch(BATCH_SIZE) train_dataset = train_dataset.shuffle(buffer_size=1024) train_dataset = train_dataset.map(utils.tf_process_images) # x_train = utils.process_images(x_train) IMAGE_DIM = 32 * 32 # x_train = tf.dtypes.cast(tf.reshape(x_train, (len(x_train), IMAGE_DIM) ), tf.float32) # A way to plot some real examples of MNIST if generate_mnist_samples: for index in range(16): source_indexes = np.random.permutation(len(x_train.numpy()))[0:64] utils.plot_mnist_grid(x_train.numpy()[source_indexes], target_file = original_mnist_samples.format(index)) # # Prepare the models # conf_dict = {"Z_DIM": Z_DIM, "LATENT_DIM": LATENT_DIM, "IMAGE_DIM": IMAGE_DIM, "GAMMA_GP": GAMMA_GP, "LR_D_G_L": LR_D_G_L, "K_RECONST_KL": K_RECONST_KL, } generator = alae_tf2_models.Generator(conf_dict,train_summary_writer) 
    discriminator = alae_tf2_models.Discriminator(conf_dict, train_summary_writer)
    E_encoder = alae_tf2_models.E_encoder(conf_dict, train_summary_writer)
    F_encoder = alae_tf2_models.F_encoder(conf_dict, train_summary_writer)

    alae_helper = alae.alae_helper({"generator": generator,
                                    "discriminator": discriminator,
                                    "E_encoder": E_encoder,
                                    "F_encoder": F_encoder},
                                   conf_dict)

    # Prepare to save the weights. Note: the *_optimizer slots below are
    # populated with the model objects, not optimizers, so optimizer state is
    # not checkpointed here.
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator,
                                     discriminator_optimizer=discriminator,
                                     E_encoder_optimizer=E_encoder,
                                     F_encoder_optimizer=F_encoder,
                                     generator=generator,
                                     discriminator=discriminator,
                                     E_encoder=E_encoder,
                                     F_encoder=F_encoder,
                                     step=tf.Variable(1))
    manager = tf.train.CheckpointManager(checkpoint, checkpoint_path, max_to_keep=3)

    z_samples_static = alae_helper.sample_Z(64, Z_DIM)

    # ------------------------------------------------------------------------
    # Start of the training loop
    # ------------------------------------------------------------------------
    it = 1
    for epoch in range(EPOCHS):
        for x in train_dataset:
            if it % RESULT_IT == 0:
                # 8x8 = 64 as a grid containing figures 0 to 9
                z_samples = alae_helper.sample_Z(64, Z_DIM)
                samples = generator(F_encoder(z_samples, training=False), training=False)
                img = samples.numpy()
                utils.plot_mnist_grid(img, target_file=mnist_samples.format(str(it).zfill(3)))

                samples = generator(F_encoder(z_samples_static, training=False), training=False)
                img = samples.numpy()
                utils.plot_mnist_grid(img, target_file=static_mnist_samples.format(str(it).zfill(3)))

                if generate_samples_tensorboard:
                    # Add results into TensorBoard: (batch_size, height, width, channels)
                    img = np.reshape(img, [64, 32, 32, 1])
                    with train_summary_writer.as_default():
                        tf.summary.image("Generated Image", img, step=it)

            # The actual training happens here; x are real samples
            losses = alae_helper.trainstep(x)

            if it % PRINT_IT == 0:
                print('Epoch: {} it: {} L_loss: {:.4f} D_loss: {:.4f} G_loss: {:.4f}'.format(
                    1 + (it * BATCH_SIZE) // x_train.shape[0], it,
                    losses["latent"], losses["disc"], losses["gen"]))
                with train_summary_writer.as_default():
                    tf.summary.scalar('discriminator_loss', losses["disc"], step=it)
                    tf.summary.scalar('generator_loss', losses["gen"], step=it)
                    tf.summary.scalar('latent_loss', losses["latent"], step=it)
                    tf.summary.scalar('latent_loss_reconst', losses["latent_reconst"], step=it)
                    tf.summary.scalar('latent_loss_kl', losses["latent_kl"], step=it)
                    tf.summary.histogram('real_samples', x, step=it)

            if it % SAVE_WEIGHT_IT == 0:
                # Save the weights
                checkpoint.step.assign_add(1)
                save_path = manager.save()
                print("Saved checkpoint for step {}: {}".format(int(checkpoint.step), save_path))

            it += 1  # Next iteration
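# A sketch (not in the original training loop above) of how the
# CheckpointManager it creates could resume from the latest saved state
# before training starts, following the standard tf.train pattern:
checkpoint.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")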
def main(argv): del argv # unused arg tf.io.gfile.makedirs(FLAGS.output_dir) logging.info('Saving checkpoints at %s', FLAGS.output_dir) tf.random.set_seed(FLAGS.seed) if FLAGS.use_gpu: logging.info('Use GPU') strategy = tf.distribute.MirroredStrategy() else: logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local') resolver = tf.distribute.cluster_resolver.TPUClusterResolver( tpu=FLAGS.tpu) tf.config.experimental_connect_to_cluster(resolver) tf.tpu.experimental.initialize_tpu_system(resolver) strategy = tf.distribute.TPUStrategy(resolver) batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores train_dataset_builder = ub.datasets.ClincIntentDetectionDataset( split='train', data_dir=FLAGS.data_dir, data_mode='ind') ind_dataset_builder = ub.datasets.ClincIntentDetectionDataset( split='test', data_dir=FLAGS.data_dir, data_mode='ind') ood_dataset_builder = ub.datasets.ClincIntentDetectionDataset( split='test', data_dir=FLAGS.data_dir, data_mode='ood') all_dataset_builder = ub.datasets.ClincIntentDetectionDataset( split='test', data_dir=FLAGS.data_dir, data_mode='all') dataset_builders = { 'clean': ind_dataset_builder, 'ood': ood_dataset_builder, 'all': all_dataset_builder } train_dataset = train_dataset_builder.load( batch_size=FLAGS.per_core_batch_size) ds_info = train_dataset_builder.tfds_info feature_size = ds_info.metadata['feature_size'] # num_classes is number of valid intents plus out-of-scope intent num_classes = ds_info.features['intent_label'].num_classes + 1 steps_per_epoch = train_dataset_builder.num_examples // batch_size test_datasets = {} steps_per_eval = {} for dataset_name, dataset_builder in dataset_builders.items(): test_datasets[dataset_name] = dataset_builder.load( batch_size=FLAGS.eval_batch_size) steps_per_eval[dataset_name] = (dataset_builder.num_examples // FLAGS.eval_batch_size) if FLAGS.use_bfloat16: policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16') tf.keras.mixed_precision.experimental.set_policy(policy) summary_writer = tf.summary.create_file_writer( os.path.join(FLAGS.output_dir, 'summaries')) with strategy.scope(): logging.info('Building BERT model') bert_config_dir, bert_ckpt_dir = resolve_bert_ckpt_and_config_dir( FLAGS.bert_dir, FLAGS.bert_config_dir, FLAGS.bert_ckpt_dir) bert_config = bert_utils.create_config(bert_config_dir) bert_config.hidden_dropout_prob = FLAGS.dropout_rate bert_config.attention_probs_dropout_prob = FLAGS.dropout_rate model, bert_encoder = ub.models.DropoutBertBuilder( num_classes=num_classes, bert_config=bert_config, use_mc_dropout_mha=FLAGS.use_mc_dropout_mha, use_mc_dropout_att=FLAGS.use_mc_dropout_att, use_mc_dropout_ffn=FLAGS.use_mc_dropout_ffn, use_mc_dropout_output=FLAGS.use_mc_dropout_output, channel_wise_dropout_mha=FLAGS.channel_wise_dropout_mha, channel_wise_dropout_att=FLAGS.channel_wise_dropout_att, channel_wise_dropout_ffn=FLAGS.channel_wise_dropout_ffn) # Create an AdamW optimizer with beta_2=0.999, epsilon=1e-6. 
optimizer = bert_utils.create_optimizer( FLAGS.base_learning_rate, steps_per_epoch=steps_per_epoch, epochs=FLAGS.train_epochs, warmup_proportion=FLAGS.warmup_proportion, beta_1=1.0 - FLAGS.one_minus_momentum) logging.info('Model input shape: %s', model.input_shape) logging.info('Model output shape: %s', model.output_shape) logging.info('Model number of weights: %s', model.count_params()) metrics = { 'train/negative_log_likelihood': tf.keras.metrics.Mean(), 'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'train/loss': tf.keras.metrics.Mean(), 'train/ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins), 'test/negative_log_likelihood': tf.keras.metrics.Mean(), 'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'test/ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins), } for dataset_name, test_dataset in test_datasets.items(): if dataset_name != 'clean': metrics.update({ 'test/nll_{}'.format(dataset_name): tf.keras.metrics.Mean(), 'test/accuracy_{}'.format(dataset_name): tf.keras.metrics.SparseCategoricalAccuracy(), 'test/ece_{}'.format(dataset_name): rm.metrics.ExpectedCalibrationError( num_bins=FLAGS.num_bins) }) checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir) initial_epoch = 0 if latest_checkpoint: # checkpoint.restore must be within a strategy.scope() so that optimizer # slot variables are mirrored. checkpoint.restore(latest_checkpoint) logging.info('Loaded checkpoint %s', latest_checkpoint) initial_epoch = optimizer.iterations.numpy() // steps_per_epoch else: # load BERT from initial checkpoint bert_checkpoint = tf.train.Checkpoint(model=bert_encoder) bert_checkpoint.restore( bert_ckpt_dir).assert_existing_objects_matched() logging.info('Loaded BERT checkpoint %s', bert_ckpt_dir) # Finally, define OOD metrics outside the accelerator scope for CPU eval. metrics.update({ 'test/auroc_all': tf.keras.metrics.AUC(curve='ROC'), 'test/auprc_all': tf.keras.metrics.AUC(curve='PR') }) @tf.function def train_step(iterator): """Training StepFn.""" def step_fn(inputs): """Per-Replica StepFn.""" features, labels = bert_utils.create_feature_and_label( inputs, feature_size) with tf.GradientTape() as tape: # Set learning phase to enable dropout etc during training. logits = model(features, training=True) if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) negative_log_likelihood = tf.reduce_mean( tf.keras.losses.sparse_categorical_crossentropy( labels, logits, from_logits=True)) l2_loss = sum(model.losses) loss = negative_log_likelihood + l2_loss # Scale the loss given the TPUStrategy will reduce sum all gradients. scaled_loss = loss / strategy.num_replicas_in_sync grads = tape.gradient(scaled_loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) probs = tf.nn.softmax(logits) metrics['train/ece'].add_batch(probs, label=labels) metrics['train/loss'].update_state(loss) metrics['train/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['train/accuracy'].update_state(labels, logits) for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)): strategy.run(step_fn, args=(next(iterator), )) @tf.function def test_step(iterator, dataset_name, num_steps): """Evaluation StepFn.""" def step_fn(inputs): """Per-Replica StepFn.""" features, labels = bert_utils.create_feature_and_label( inputs, feature_size) # Compute ensemble prediction over Monte Carlo dropout samples. 
logits_list = [] for _ in range(FLAGS.num_dropout_samples): logits = model(features, training=False) if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) logits_list.append(logits) # Logits dimension is (num_samples, batch_size, num_classes). logits_list = tf.stack(logits_list, axis=0) probs_list = tf.nn.softmax(logits_list) probs = tf.reduce_mean(probs_list, axis=0) labels_broadcasted = tf.broadcast_to( labels, [FLAGS.num_dropout_samples, labels.shape[0]]) log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy( labels_broadcasted, logits_list, from_logits=True) negative_log_likelihood = tf.reduce_mean( -tf.reduce_logsumexp(log_likelihoods, axis=[0]) + tf.math.log(float(FLAGS.num_dropout_samples))) if dataset_name == 'clean': metrics['test/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['test/accuracy'].update_state(labels, probs) metrics['test/ece'].add_batch(probs, label=labels) else: metrics['test/nll_{}'.format(dataset_name)].update_state( negative_log_likelihood) metrics['test/accuracy_{}'.format(dataset_name)].update_state( labels, probs) metrics['test/ece_{}'.format(dataset_name)].add_batch( probs, label=labels) if dataset_name == 'all': ood_labels = tf.cast(labels == 150, labels.dtype) ood_probs = 1. - tf.reduce_max(probs, axis=-1) metrics['test/auroc_{}'.format(dataset_name)].update_state( ood_labels, ood_probs) metrics['test/auprc_{}'.format(dataset_name)].update_state( ood_labels, ood_probs) for _ in tf.range(tf.cast(num_steps, tf.int32)): step_fn(next(iterator)) train_iterator = iter(train_dataset) start_time = time.time() for epoch in range(initial_epoch, FLAGS.train_epochs): logging.info('Starting to run epoch: %s', epoch) train_step(train_iterator) current_step = (epoch + 1) * steps_per_epoch max_steps = steps_per_epoch * FLAGS.train_epochs time_elapsed = time.time() - start_time steps_per_sec = float(current_step) / time_elapsed eta_seconds = (max_steps - current_step) / steps_per_sec message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. ' 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format( current_step / max_steps, epoch + 1, FLAGS.train_epochs, steps_per_sec, eta_seconds / 60, time_elapsed / 60)) logging.info(message) if epoch % FLAGS.evaluation_interval == 0: for dataset_name, test_dataset in test_datasets.items(): test_iterator = iter(test_dataset) logging.info('Testing on dataset %s', dataset_name) logging.info('Starting to run eval at epoch: %s', epoch) test_step(test_iterator, dataset_name, steps_per_eval[dataset_name]) logging.info('Done with testing on %s', dataset_name) logging.info('Train Loss: %.4f, Accuracy: %.2f%%', metrics['train/loss'].result(), metrics['train/accuracy'].result() * 100) logging.info('Test NLL: %.4f, Accuracy: %.2f%%', metrics['test/negative_log_likelihood'].result(), metrics['test/accuracy'].result() * 100) total_results = { name: metric.result() for name, metric in metrics.items() } # Metrics from Robustness Metrics (like ECE) will return a dict with a # single key/value, instead of a scalar. 
total_results = { k: (list(v.values())[0] if isinstance(v, dict) else v) for k, v in total_results.items() } with summary_writer.as_default(): for name, result in total_results.items(): tf.summary.scalar(name, result, step=epoch + 1) for metric in metrics.values(): metric.reset_states() if (FLAGS.checkpoint_interval > 0 and (epoch + 1) % FLAGS.checkpoint_interval == 0): checkpoint_name = checkpoint.save( os.path.join(FLAGS.output_dir, 'checkpoint')) logging.info('Saved checkpoint to %s', checkpoint_name) with summary_writer.as_default(): hp.hparams({ 'base_learning_rate': FLAGS.base_learning_rate, 'one_minus_momentum': FLAGS.one_minus_momentum, 'dropout_rate': FLAGS.dropout_rate, 'num_dropout_samples': FLAGS.num_dropout_samples, })
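# A small illustration of why the dict-unwrapping above is needed: metrics from
# robustness_metrics return a `{name: value}` dict from `.result()`, while
# Keras metrics return scalars. The key name 'ece' here is assumed for
# illustration only.
example_results = {'test/accuracy': 0.91, 'test/ece': {'ece': 0.034}}
example_results = {
    k: (list(v.values())[0] if isinstance(v, dict) else v)
    for k, v in example_results.items()
}
assert example_results['test/ece'] == 0.034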
def main(_):
    configure_environment(FLAGS.fp16_run)

    hparams = {
        HP_TOKEN_TYPE: HP_TOKEN_TYPE.domain.values[1],
        # Preprocessing
        HP_MEL_BINS: HP_MEL_BINS.domain.values[0],
        HP_FRAME_LENGTH: HP_FRAME_LENGTH.domain.values[0],
        HP_FRAME_STEP: HP_FRAME_STEP.domain.values[0],
        HP_HERTZ_LOW: HP_HERTZ_LOW.domain.values[0],
        HP_HERTZ_HIGH: HP_HERTZ_HIGH.domain.values[0],
        # Model
        HP_EMBEDDING_SIZE: HP_EMBEDDING_SIZE.domain.values[0],
        HP_ENCODER_LAYERS: HP_ENCODER_LAYERS.domain.values[0],
        HP_ENCODER_SIZE: HP_ENCODER_SIZE.domain.values[0],
        HP_PROJECTION_SIZE: HP_PROJECTION_SIZE.domain.values[0],
        HP_TIME_REDUCT_INDEX: HP_TIME_REDUCT_INDEX.domain.values[0],
        HP_TIME_REDUCT_FACTOR: HP_TIME_REDUCT_FACTOR.domain.values[0],
        HP_PRED_NET_LAYERS: HP_PRED_NET_LAYERS.domain.values[0],
        HP_PRED_NET_SIZE: HP_PRED_NET_SIZE.domain.values[0],
        HP_JOINT_NET_SIZE: HP_JOINT_NET_SIZE.domain.values[0],
        HP_LEARNING_RATE: HP_LEARNING_RATE.domain.values[0]
    }

    with tf.summary.create_file_writer(
            os.path.join(FLAGS.tb_log_dir, 'hparams_tuning')).as_default():
        hp.hparams_config(
            hparams=[
                HP_TOKEN_TYPE, HP_VOCAB_SIZE, HP_EMBEDDING_SIZE,
                HP_ENCODER_LAYERS, HP_ENCODER_SIZE, HP_PROJECTION_SIZE,
                HP_TIME_REDUCT_INDEX, HP_TIME_REDUCT_FACTOR,
                HP_PRED_NET_LAYERS, HP_PRED_NET_SIZE, HP_JOINT_NET_SIZE
            ],
            metrics=[
                hp.Metric(METRIC_ACCURACY, display_name='Accuracy'),
                hp.Metric(METRIC_CER, display_name='CER'),
                hp.Metric(METRIC_WER, display_name='WER'),
            ],
        )

    _hparams = {k.name: v for k, v in hparams.items()}

    if not FLAGS.gpus:
        gpus = [
            x.name.replace('/physical_device:', '')
            for x in tf.config.experimental.list_physical_devices('GPU')
        ]
    else:
        gpus = ['GPU:' + str(gpu_id) for gpu_id in FLAGS.gpus]
    strategy = tf.distribute.MirroredStrategy(devices=gpus)
    # strategy = None

    dtype = tf.float32
    if FLAGS.fp16_run:
        dtype = tf.float16

    # initializer = tf.keras.initializers.RandomUniform(minval=-0.1, maxval=0.1)
    initializer = None

    if FLAGS.checkpoint is not None:
        checkpoint_dir = os.path.dirname(os.path.realpath(FLAGS.checkpoint))
        _hparams = model_utils.load_hparams(checkpoint_dir)
        encoder_fn, idx_to_text, vocab_size = encoding.load_encoder(
            checkpoint_dir, hparams=_hparams)
        if strategy is not None:
            with strategy.scope():
                model = build_keras_model(_hparams, initializer=initializer, dtype=dtype)
                model.load_weights(FLAGS.checkpoint)
        else:
            model = build_keras_model(_hparams, initializer=initializer, dtype=dtype)
            model.load_weights(FLAGS.checkpoint)
        logging.info('Restored weights from {}.'.format(FLAGS.checkpoint))
    else:
        os.makedirs(FLAGS.output_dir, exist_ok=True)
        shutil.copy(os.path.join(FLAGS.data_dir, 'encoder.subwords'),
                    os.path.join(FLAGS.output_dir, 'encoder.subwords'))
        encoder_fn, idx_to_text, vocab_size = encoding.load_encoder(
            FLAGS.output_dir, hparams=_hparams)
        _hparams[HP_VOCAB_SIZE.name] = vocab_size
        if strategy is not None:
            with strategy.scope():
                model = build_keras_model(_hparams, initializer=initializer, dtype=dtype)
        else:
            model = build_keras_model(_hparams, initializer=initializer, dtype=dtype)
        model_utils.save_hparams(_hparams, FLAGS.output_dir)

    logging.info('Using {} encoder with vocab size: {}'.format(
        _hparams[HP_TOKEN_TYPE.name], vocab_size))

    loss_fn = get_loss_fn(reduction_factor=_hparams[HP_TIME_REDUCT_FACTOR.name])

    start_token = encoder_fn('')[0]
    decode_fn = decoding.greedy_decode_fn(model, start_token=start_token)
    accuracy_fn = metrics.build_accuracy_fn(decode_fn)
    cer_fn = metrics.build_cer_fn(decode_fn, idx_to_text)
    wer_fn = metrics.build_wer_fn(decode_fn, idx_to_text)

    optimizer = tf.keras.optimizers.SGD(_hparams[HP_LEARNING_RATE.name],
momentum=0.9) if FLAGS.fp16_run: optimizer = mixed_precision.LossScaleOptimizer(optimizer, loss_scale='dynamic') encoder = model.layers[2] prediction_network = model.layers[3] encoder.summary() prediction_network.summary() model.summary() dev_dataset, _ = get_dataset(FLAGS.data_dir, 'dev', batch_size=FLAGS.batch_size, n_epochs=FLAGS.n_epochs, strategy=strategy, max_size=FLAGS.eval_size) # dev_steps = dev_specs['size'] // FLAGS.batch_size log_dir = os.path.join(FLAGS.tb_log_dir, datetime.now().strftime('%Y%m%d-%H%M%S')) with tf.summary.create_file_writer(log_dir).as_default(): hp.hparams(hparams) if FLAGS.mode == 'train': train_dataset, _ = get_dataset(FLAGS.data_dir, 'train', batch_size=FLAGS.batch_size, n_epochs=FLAGS.n_epochs, strategy=strategy) # train_steps = train_specs['size'] // FLAGS.batch_size os.makedirs(FLAGS.output_dir, exist_ok=True) checkpoint_template = os.path.join( FLAGS.output_dir, 'checkpoint_{step}_{val_loss:.4f}.hdf5') run_training(model=model, optimizer=optimizer, loss_fn=loss_fn, train_dataset=train_dataset, batch_size=FLAGS.batch_size, n_epochs=FLAGS.n_epochs, checkpoint_template=checkpoint_template, strategy=strategy, steps_per_log=FLAGS.steps_per_log, steps_per_checkpoint=FLAGS.steps_per_checkpoint, eval_dataset=dev_dataset, train_metrics=[], eval_metrics=[accuracy_fn, cer_fn, wer_fn], gpus=gpus) elif FLAGS.mode == 'eval' or FLAGS.mode == 'test': if FLAGS.checkpoint is None: raise Exception( 'You must provide a checkpoint to perform eval.') if FLAGS.mode == 'test': dataset, test_specs = get_dataset(FLAGS.data_dir, 'test', batch_size=FLAGS.batch_size, n_epochs=FLAGS.n_epochs) else: dataset = dev_dataset eval_start_time = time.time() eval_loss, eval_metrics_results = run_evaluate( model=model, optimizer=optimizer, loss_fn=loss_fn, eval_dataset=dataset, batch_size=FLAGS.batch_size, strategy=strategy, metrics=[accuracy_fn, cer_fn, wer_fn], gpus=gpus) validation_log_str = 'VALIDATION RESULTS: Time: {:.4f}, Loss: {:.4f}'.format( time.time() - eval_start_time, eval_loss) for metric_name, metric_result in eval_metrics_results.items(): validation_log_str += ', {}: {:.4f}'.format( metric_name, metric_result) print(validation_log_str)
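# For reference, a minimal sketch of how a dynamic LossScaleOptimizer, like the
# one wrapped above, is used in a custom training step: the loss is scaled
# before the tape computes gradients, and the gradients are unscaled before
# being applied. The argument names here are placeholders, not taken from this
# script.
def fp16_train_step(model, optimizer, loss_fn, features, labels):
    with tf.GradientTape() as tape:
        loss = loss_fn(labels, model(features, training=True))
        # Multiply by the current loss scale to avoid float16 gradient underflow.
        scaled_loss = optimizer.get_scaled_loss(loss)
    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    # Divide the gradients by the loss scale before the optimizer update.
    grads = optimizer.get_unscaled_gradients(scaled_grads)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss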
def _train_n_steps(self, num_steps: int):
    """Runs training for `num_steps` steps.

    Also prints/logs updates about training progress, and summarizes training
    output (if output is returned from `self.trainer.train()`, and if
    `self.summary_dir` is set). If `global_step` is not incremented by
    `num_steps` after calling `self.trainer.train(num_steps)`, a warning is
    logged and the step's output is discarded.

    Args:
      num_steps: An integer specifying how many steps of training to run.
    """
    if not self.step_timer:
        self.step_timer = StepTimer(self.global_step)
    current_step = self.global_step.numpy()

    with self.summary_manager.summary_writer().as_default():
        should_record = False  # Allows static optimization in no-summary cases.
        write_hparams = False
        if self.summary_interval:
            # Create a predicate to determine when summaries should be written.
            should_record = lambda: (self.global_step % self.summary_interval == 0)
            write_hparams = lambda: (self.global_step == 0)
        with tf.summary.record_if(write_hparams):
            hp.hparams(self.hparams)
        with tf.summary.record_if(should_record):
            num_steps_tensor = tf.convert_to_tensor(num_steps, dtype=tf.int32)
            train_output = self.trainer.train(num_steps_tensor)

    # Verify that global_step was updated properly, then update current_step.
    expected_step = current_step + num_steps
    if self.global_step.numpy() != expected_step:
        message = (
            f"`trainer.train({num_steps})` did not update `global_step` by "
            f"{num_steps}. Old value was {current_step}, expected updated value "
            f"to be {expected_step}, but it was {self.global_step.numpy()}.")
        logging.warning(message)
        return

    train_output = train_output or {}
    for action in self.train_actions:
        action(train_output)
    train_output = tf.nest.map_structure(utils.get_value, train_output)

    current_step = expected_step
    steps_per_second = self.step_timer.steps_per_second()
    _log(f"train | step: {current_step: 6d} | "
         f"steps/sec: {steps_per_second: 6.1f} | "
         f"output: {_format_output(train_output)}")
    train_output["steps_per_second"] = steps_per_second
    # Create logs matching the tensorboard log parser format;
    # see tensorboard_for_parser.md.
    train_output["global_step/sec"] = steps_per_second
    train_output["examples/sec"] = steps_per_second * self.hparams["batch_size"]
    train_output["loss"] = train_output["total_loss"]
    self.summary_manager.write_summaries(train_output)
    self.summary_manager.flush()
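# A standalone sketch of the `record_if` pattern used above: when the predicate
# is a callable it is re-evaluated at each summary call, so per-step scalar
# writes become cheap no-ops between summary intervals. The log path and
# interval below are illustrative.
step = tf.Variable(0, dtype=tf.int64)
writer = tf.summary.create_file_writer('/tmp/record_if_demo')
with writer.as_default():
    with tf.summary.record_if(lambda: step % 100 == 0):
        for _ in range(1000):
            tf.summary.scalar('loss', 0.5, step=step)  # written every 100 steps
            step.assign_add(1)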
def main(argv): del argv # unused arg tf.io.gfile.makedirs(FLAGS.output_dir) logging.info('Saving checkpoints at %s', FLAGS.output_dir) # Set seeds tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) torch.manual_seed(FLAGS.seed) # Resolve CUDA device(s) if FLAGS.use_gpu and torch.cuda.is_available(): print('Running model with CUDA.') device = 'cuda:0' else: print('Running model on CPU.') device = 'cpu' train_batch_size = FLAGS.train_batch_size eval_batch_size = FLAGS.eval_batch_size // FLAGS.num_dropout_samples_eval # As per the Kaggle challenge, we have split sizes: # train: 35,126 # validation: 10,906 # test: 42,670 ds_info = tfds.builder('diabetic_retinopathy_detection').info steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size steps_per_validation_eval = (ds_info.splits['validation'].num_examples // eval_batch_size) steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size data_dir = FLAGS.data_dir dataset_train_builder = ub.datasets.get('diabetic_retinopathy_detection', split='train', data_dir=data_dir) dataset_train = dataset_train_builder.load(batch_size=train_batch_size) dataset_validation_builder = ub.datasets.get( 'diabetic_retinopathy_detection', split='validation', data_dir=data_dir, is_training=not FLAGS.use_validation) validation_batch_size = (eval_batch_size if FLAGS.use_validation else train_batch_size) dataset_validation = dataset_validation_builder.load( batch_size=validation_batch_size) if not FLAGS.use_validation: # Note that this will not create any mixed batches of train and validation # images. dataset_train = dataset_train.concatenate(dataset_validation) dataset_test_builder = ub.datasets.get('diabetic_retinopathy_detection', split='test', data_dir=data_dir) dataset_test = dataset_test_builder.load(batch_size=eval_batch_size) summary_writer = tf.summary.create_file_writer( os.path.join(FLAGS.output_dir, 'summaries')) # MC Dropout ResNet50 based on PyTorch Vision implementation logging.info('Building Torch ResNet-50 MC Dropout model.') model = ub.models.resnet50_dropout_torch(num_classes=1, dropout_rate=FLAGS.dropout_rate) logging.info('Model number of weights: %s', torch_utils.count_parameters(model)) # Linearly scale learning rate and the decay epochs by vanilla settings. base_lr = FLAGS.base_learning_rate optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=1.0 - FLAGS.one_minus_momentum, nesterov=True) steps_to_lr_peak = int(steps_per_epoch * FLAGS.lr_warmup_epochs) scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, steps_to_lr_peak, T_mult=2) model = model.to(device) metrics = utils.get_diabetic_retinopathy_base_metrics( use_tpu=False, num_bins=FLAGS.num_bins, use_validation=FLAGS.use_validation) # Define additional metrics that would fail in a TF TPU implementation. 
metrics.update(
    utils.get_diabetic_retinopathy_cpu_metrics(
        use_validation=FLAGS.use_validation))

# Initialize the loss function.
loss_fn = torch.nn.BCELoss()
sigmoid = torch.nn.Sigmoid()
max_steps = steps_per_epoch * FLAGS.train_epochs
image_h = 512
image_w = 512

def run_train_epoch(iterator):

  def train_step(inputs):
    images = inputs['features']
    labels = inputs['labels']
    images = torch.from_numpy(images._numpy()).view(train_batch_size, 3,  # pylint: disable=protected-access
                                                    image_h, image_w).to(device)
    labels = torch.from_numpy(labels._numpy()).to(device).float()  # pylint: disable=protected-access

    # Zero the parameter gradients
    optimizer.zero_grad()

    # Forward
    logits = model(images)
    probs = sigmoid(logits).squeeze(-1)

    # Add L2 regularization loss to NLL
    negative_log_likelihood = loss_fn(probs, labels)
    l2_loss = sum(p.pow(2.0).sum() for p in model.parameters())
    loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)

    # Backward/optimizer
    loss.backward()
    optimizer.step()

    # Convert to NumPy for metrics updates
    loss = loss.detach()
    negative_log_likelihood = negative_log_likelihood.detach()
    labels = labels.detach()
    probs = probs.detach()
    if device != 'cpu':
      loss = loss.cpu()
      negative_log_likelihood = negative_log_likelihood.cpu()
      labels = labels.cpu()
      probs = probs.cpu()
    loss = loss.numpy()
    negative_log_likelihood = negative_log_likelihood.numpy()
    labels = labels.numpy()
    probs = probs.numpy()

    metrics['train/loss'].update_state(loss)
    metrics['train/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['train/accuracy'].update_state(labels, probs)
    metrics['train/auprc'].update_state(labels, probs)
    metrics['train/auroc'].update_state(labels, probs)
    metrics['train/ece'].add_batch(probs, label=labels)

  for step in range(steps_per_epoch):
    train_step(next(iterator))
    if step % 100 == 0:
      current_step = epoch * steps_per_epoch + (step + 1)
      time_elapsed = time.time() - start_time
      steps_per_sec = float(current_step) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec if steps_per_sec else 0
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
      logging.info(message)

def run_eval_epoch(iterator, dataset_split, num_steps):

  def eval_step(inputs, model):
    images = inputs['features']
    labels = inputs['labels']
    images = torch.from_numpy(images._numpy()).view(eval_batch_size, 3,  # pylint: disable=protected-access
                                                    image_h, image_w).to(device)
    labels = torch.from_numpy(
        labels._numpy()).to(device).float().unsqueeze(-1)  # pylint: disable=protected-access

    with torch.no_grad():
      logits = torch.stack([
          model(images) for _ in range(FLAGS.num_dropout_samples_eval)
      ], dim=-1)

    # Logits dimension is (batch_size, 1, num_dropout_samples).
    logits = logits.squeeze()
    # It is now (batch_size, num_dropout_samples).
    probs = sigmoid(logits)

    # labels_tiled shape is (batch_size, num_dropout_samples).
    labels_tiled = torch.tile(labels, (1, FLAGS.num_dropout_samples_eval))

    # Per-sample log-likelihoods are needed for the logsumexp mixture below,
    # so use an unreduced BCE here rather than the mean-reduced `loss_fn`.
    log_likelihoods = -torch.nn.functional.binary_cross_entropy(
        probs, labels_tiled, reduction='none')
    negative_log_likelihood = torch.mean(
        -torch.logsumexp(log_likelihoods, dim=-1) +
        torch.log(torch.tensor(float(FLAGS.num_dropout_samples_eval))))

    probs = torch.mean(probs, dim=-1)

    # Convert to NumPy for metrics updates
    negative_log_likelihood = negative_log_likelihood.detach()
    labels = labels.detach()
    probs = probs.detach()
    if device != 'cpu':
      negative_log_likelihood = negative_log_likelihood.cpu()
      labels = labels.cpu()
      probs = probs.cpu()
    negative_log_likelihood = negative_log_likelihood.numpy()
    labels = labels.numpy()
    probs = probs.numpy()

    metrics[dataset_split + '/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics[dataset_split + '/accuracy'].update_state(labels, probs)
    metrics[dataset_split + '/auprc'].update_state(labels, probs)
    metrics[dataset_split + '/auroc'].update_state(labels, probs)
    metrics[dataset_split + '/ece'].add_batch(probs, label=labels)

  for _ in range(num_steps):
    eval_step(next(iterator), model=model)

metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

start_time = time.time()
initial_epoch = 0
train_iterator = iter(dataset_train)
# Dropout stays active at evaluation time as well (the model is never put in
# eval() mode), which is exactly what Monte Carlo dropout sampling requires.
model.train()
for epoch in range(initial_epoch, FLAGS.train_epochs):
  logging.info('Starting to run epoch: %s', epoch + 1)
  run_train_epoch(train_iterator)

  if FLAGS.use_validation:
    validation_iterator = iter(dataset_validation)
    logging.info('Starting to run validation eval at epoch: %s', epoch + 1)
    run_eval_epoch(validation_iterator, 'validation', steps_per_validation_eval)

  test_iterator = iter(dataset_test)
  logging.info('Starting to run test eval at epoch: %s', epoch + 1)
  test_start_time = time.time()
  run_eval_epoch(test_iterator, 'test', steps_per_test_eval)
  ms_per_example = (time.time() - test_start_time) * 1e6 / eval_batch_size
  metrics['test/ms_per_example'].update_state(ms_per_example)

  # Step scheduler
  scheduler.step()

  # Log and write to summary the epoch metrics
  utils.log_epoch_metrics(metrics=metrics, use_tpu=False)
  total_results = {name: metric.result() for name, metric in metrics.items()}
  # Metrics from Robustness Metrics (like ECE) will return a dict with a
  # single key/value, instead of a scalar.
  total_results = {
      k: (list(v.values())[0] if isinstance(v, dict) else v)
      for k, v in total_results.items()
  }
  with summary_writer.as_default():
    for name, result in total_results.items():
      tf.summary.scalar(name, result, step=epoch + 1)

  for metric in metrics.values():
    metric.reset_states()

  if (FLAGS.checkpoint_interval > 0 and
      (epoch + 1) % FLAGS.checkpoint_interval == 0):
    checkpoint_path = os.path.join(FLAGS.output_dir, f'model_{epoch + 1}.pt')
    torch_utils.checkpoint_torch_model(model=model,
                                       optimizer=optimizer,
                                       epoch=epoch + 1,
                                       checkpoint_path=checkpoint_path)
    logging.info('Saved Torch checkpoint to %s', checkpoint_path)

final_checkpoint_path = os.path.join(FLAGS.output_dir,
                                     f'model_{FLAGS.train_epochs}.pt')
torch_utils.checkpoint_torch_model(model=model,
                                   optimizer=optimizer,
                                   epoch=FLAGS.train_epochs,
                                   checkpoint_path=final_checkpoint_path)
logging.info('Saved last checkpoint to %s', final_checkpoint_path)

with summary_writer.as_default():
  hp.hparams({
      'base_learning_rate': FLAGS.base_learning_rate,
      'one_minus_momentum': FLAGS.one_minus_momentum,
      'dropout_rate': FLAGS.dropout_rate,
      'l2': FLAGS.l2,
      'lr_warmup_epochs': FLAGS.lr_warmup_epochs
  })
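# A quick self-contained check (illustrative values, not from the script) that
# the per-sample BCE + logsumexp used in `eval_step` computes -log of the mean
# predictive probability across the dropout samples, i.e. the NLL of the
# Monte Carlo mixture.
probs_demo = torch.tensor([[0.6, 0.5, 0.7]])   # (batch=1, num_samples=3)
labels_demo = torch.ones(1, 3)                 # positive label, tiled
log_lik = -torch.nn.functional.binary_cross_entropy(
    probs_demo, labels_demo, reduction='none')
nll = -torch.logsumexp(log_lik, dim=-1) + torch.log(torch.tensor(3.))
assert torch.isclose(nll, -torch.log(probs_demo.mean(dim=-1))).all()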
def main(argv): del argv # unused arg tf.io.gfile.makedirs(FLAGS.output_dir) logging.info('Saving checkpoints at %s', FLAGS.output_dir) tf.random.set_seed(FLAGS.seed) if FLAGS.use_gpu: logging.info('Use GPU') strategy = tf.distribute.MirroredStrategy() else: logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local') resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu) tf.config.experimental_connect_to_cluster(resolver) tf.tpu.experimental.initialize_tpu_system(resolver) strategy = tf.distribute.TPUStrategy(resolver) batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores test_batch_size = batch_size data_buffer_size = batch_size * 10 train_dataset_builder = ds.WikipediaToxicityDataset( split='train', data_dir=FLAGS.in_dataset_dir, shuffle_buffer_size=data_buffer_size) ind_dataset_builder = ds.WikipediaToxicityDataset( split='test', data_dir=FLAGS.in_dataset_dir, shuffle_buffer_size=data_buffer_size) ood_dataset_builder = ds.CivilCommentsDataset( split='test', data_dir=FLAGS.ood_dataset_dir, shuffle_buffer_size=data_buffer_size) ood_identity_dataset_builder = ds.CivilCommentsIdentitiesDataset( split='test', data_dir=FLAGS.identity_dataset_dir, shuffle_buffer_size=data_buffer_size) train_dataset_builders = { 'wikipedia_toxicity_subtypes': train_dataset_builder } test_dataset_builders = { 'ind': ind_dataset_builder, 'ood': ood_dataset_builder, 'ood_identity': ood_identity_dataset_builder, } if FLAGS.prediction_mode and FLAGS.identity_prediction: for dataset_name in utils.IDENTITY_LABELS: if utils.NUM_EXAMPLES[dataset_name]['test'] > 100: test_dataset_builders[dataset_name] = ds.CivilCommentsIdentitiesDataset( split='test', data_dir=os.path.join( FLAGS.identity_specific_dataset_dir, dataset_name), shuffle_buffer_size=data_buffer_size) for dataset_name in utils.IDENTITY_TYPES: if utils.NUM_EXAMPLES[dataset_name]['test'] > 100: test_dataset_builders[dataset_name] = ds.CivilCommentsIdentitiesDataset( split='test', data_dir=os.path.join( FLAGS.identity_type_dataset_dir, dataset_name), shuffle_buffer_size=data_buffer_size) class_weight = utils.create_class_weight( train_dataset_builders, test_dataset_builders) logging.info('class_weight: %s', str(class_weight)) ds_info = train_dataset_builder.tfds_info # Positive and negative classes. num_classes = ds_info.metadata['num_classes'] train_datasets = {} dataset_steps_per_epoch = {} total_steps_per_epoch = 0 # TODO(jereliu): Apply strategy.experimental_distribute_dataset to the # dataset_builders. 
for dataset_name, dataset_builder in train_dataset_builders.items(): train_datasets[dataset_name] = dataset_builder.load( batch_size=FLAGS.per_core_batch_size) dataset_steps_per_epoch[dataset_name] = ( dataset_builder.num_examples // batch_size) total_steps_per_epoch += dataset_steps_per_epoch[dataset_name] test_datasets = {} steps_per_eval = {} for dataset_name, dataset_builder in test_dataset_builders.items(): test_datasets[dataset_name] = dataset_builder.load( batch_size=test_batch_size) if dataset_name in ['ind', 'ood', 'ood_identity']: steps_per_eval[dataset_name] = ( dataset_builder.num_examples // test_batch_size) else: steps_per_eval[dataset_name] = ( utils.NUM_EXAMPLES[dataset_name]['test'] // test_batch_size) if FLAGS.use_bfloat16: policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16') tf.keras.mixed_precision.experimental.set_policy(policy) summary_writer = tf.summary.create_file_writer( os.path.join(FLAGS.output_dir, 'summaries')) with strategy.scope(): logging.info('Building BERT %s model', FLAGS.bert_model_type) logging.info('use_gp_layer=%s', FLAGS.use_gp_layer) logging.info('use_spec_norm_att=%s', FLAGS.use_spec_norm_att) logging.info('use_spec_norm_ffn=%s', FLAGS.use_spec_norm_ffn) logging.info('use_layer_norm_att=%s', FLAGS.use_layer_norm_att) logging.info('use_layer_norm_ffn=%s', FLAGS.use_layer_norm_ffn) bert_config_dir, bert_ckpt_dir = utils.resolve_bert_ckpt_and_config_dir( FLAGS.bert_model_type, FLAGS.bert_dir, FLAGS.bert_config_dir, FLAGS.bert_ckpt_dir) bert_config = utils.create_config(bert_config_dir) gp_layer_kwargs = dict( num_inducing=FLAGS.gp_hidden_dim, gp_kernel_scale=FLAGS.gp_scale, gp_output_bias=FLAGS.gp_bias, normalize_input=FLAGS.gp_input_normalization, gp_cov_momentum=FLAGS.gp_cov_discount_factor, gp_cov_ridge_penalty=FLAGS.gp_cov_ridge_penalty) spec_norm_kwargs = dict( iteration=FLAGS.spec_norm_iteration, norm_multiplier=FLAGS.spec_norm_bound) model, bert_encoder = ub.models.SngpBertBuilder( num_classes=num_classes, bert_config=bert_config, gp_layer_kwargs=gp_layer_kwargs, spec_norm_kwargs=spec_norm_kwargs, use_gp_layer=FLAGS.use_gp_layer, use_spec_norm_att=FLAGS.use_spec_norm_att, use_spec_norm_ffn=FLAGS.use_spec_norm_ffn, use_layer_norm_att=FLAGS.use_layer_norm_att, use_layer_norm_ffn=FLAGS.use_layer_norm_ffn, use_spec_norm_plr=FLAGS.use_spec_norm_plr) # Create an AdamW optimizer with beta_2=0.999, epsilon=1e-6. 
optimizer = utils.create_optimizer( FLAGS.base_learning_rate, steps_per_epoch=total_steps_per_epoch, epochs=FLAGS.train_epochs, warmup_proportion=FLAGS.warmup_proportion, beta_1=1.0 - FLAGS.one_minus_momentum) logging.info('Model input shape: %s', model.input_shape) logging.info('Model output shape: %s', model.output_shape) logging.info('Model number of weights: %s', model.count_params()) metrics = { 'train/negative_log_likelihood': tf.keras.metrics.Mean(), 'train/accuracy': tf.keras.metrics.Accuracy(), 'train/accuracy_weighted': tf.keras.metrics.Accuracy(), 'train/auroc': tf.keras.metrics.AUC(), 'train/loss': tf.keras.metrics.Mean(), 'train/ece': rm.metrics.ExpectedCalibrationError( num_bins=FLAGS.num_bins), 'train/precision': tf.keras.metrics.Precision(), 'train/recall': tf.keras.metrics.Recall(), 'train/f1': tfa_metrics.F1Score( num_classes=num_classes, average='micro', threshold=FLAGS.ece_label_threshold), } checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) if FLAGS.prediction_mode: latest_checkpoint = tf.train.latest_checkpoint(FLAGS.eval_checkpoint_dir) else: latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir) initial_epoch = 0 if latest_checkpoint: # checkpoint.restore must be within a strategy.scope() so that optimizer # slot variables are mirrored. checkpoint.restore(latest_checkpoint) logging.info('Loaded checkpoint %s', latest_checkpoint) initial_epoch = optimizer.iterations.numpy() // total_steps_per_epoch else: # load BERT from initial checkpoint bert_encoder, _, _ = utils.load_bert_weight_from_ckpt( bert_model=bert_encoder, bert_ckpt_dir=bert_ckpt_dir, repl_patterns=ub.models.bert_sngp.CHECKPOINT_REPL_PATTERNS) logging.info('Loaded BERT checkpoint %s', bert_ckpt_dir) metrics.update({ 'test/negative_log_likelihood': tf.keras.metrics.Mean(), 'test/auroc': tf.keras.metrics.AUC(curve='ROC'), 'test/aupr': tf.keras.metrics.AUC(curve='PR'), 'test/brier': tf.keras.metrics.MeanSquaredError(), 'test/brier_weighted': tf.keras.metrics.MeanSquaredError(), 'test/ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins), 'test/acc': tf.keras.metrics.Accuracy(), 'test/acc_weighted': tf.keras.metrics.Accuracy(), 'test/eval_time': tf.keras.metrics.Mean(), 'test/stddev': tf.keras.metrics.Mean(), 'test/precision': tf.keras.metrics.Precision(), 'test/recall': tf.keras.metrics.Recall(), 'test/f1': tfa_metrics.F1Score( num_classes=num_classes, average='micro', threshold=FLAGS.ece_label_threshold), 'test/calibration_auroc': tc_metrics.CalibrationAUC(curve='ROC'), 'test/calibration_auprc': tc_metrics.CalibrationAUC(curve='PR') }) for fraction in FLAGS.fractions: metrics.update({ 'test_collab_acc/collab_acc_{}'.format(fraction): rm.metrics.OracleCollaborativeAccuracy( fraction=float(fraction), num_bins=FLAGS.num_bins) }) metrics.update({ 'test_abstain_prec/abstain_prec_{}'.format(fraction): tc_metrics.AbstainPrecision(abstain_fraction=float(fraction)) }) metrics.update({ 'test_abstain_recall/abstain_recall_{}'.format(fraction): tc_metrics.AbstainRecall(abstain_fraction=float(fraction)) }) for dataset_name, test_dataset in test_datasets.items(): if dataset_name != 'ind': metrics.update({ 'test/nll_{}'.format(dataset_name): tf.keras.metrics.Mean(), 'test/auroc_{}'.format(dataset_name): tf.keras.metrics.AUC(curve='ROC'), 'test/aupr_{}'.format(dataset_name): tf.keras.metrics.AUC(curve='PR'), 'test/brier_{}'.format(dataset_name): tf.keras.metrics.MeanSquaredError(), 'test/brier_weighted_{}'.format(dataset_name): tf.keras.metrics.MeanSquaredError(), 
'test/ece_{}'.format(dataset_name): rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins), 'test/acc_{}'.format(dataset_name): tf.keras.metrics.Accuracy(), 'test/acc_weighted_{}'.format(dataset_name): tf.keras.metrics.Accuracy(), 'test/eval_time_{}'.format(dataset_name): tf.keras.metrics.Mean(), 'test/stddev_{}'.format(dataset_name): tf.keras.metrics.Mean(), 'test/precision_{}'.format(dataset_name): tf.keras.metrics.Precision(), 'test/recall_{}'.format(dataset_name): tf.keras.metrics.Recall(), 'test/f1_{}'.format(dataset_name): tfa_metrics.F1Score( num_classes=num_classes, average='micro', threshold=FLAGS.ece_label_threshold), 'test/calibration_auroc_{}'.format(dataset_name): tc_metrics.CalibrationAUC(curve='ROC'), 'test/calibration_auprc_{}'.format(dataset_name): tc_metrics.CalibrationAUC(curve='PR'), }) for fraction in FLAGS.fractions: metrics.update({ 'test_collab_acc/collab_acc_{}_{}'.format(fraction, dataset_name): rm.metrics.OracleCollaborativeAccuracy( fraction=float(fraction), num_bins=FLAGS.num_bins) }) metrics.update({ 'test_abstain_prec/abstain_prec_{}_{}'.format( fraction, dataset_name): tc_metrics.AbstainPrecision(abstain_fraction=float(fraction)) }) metrics.update({ 'test_abstain_recall/abstain_recall_{}_{}'.format( fraction, dataset_name): tc_metrics.AbstainRecall(abstain_fraction=float(fraction)) }) @tf.function def generate_sample_weight(labels, class_weight, label_threshold=0.7): """Generate sample weight for weighted accuracy calculation.""" if label_threshold != 0.7: logging.warning('The class weight was based on `label_threshold` = 0.7, ' 'and weighted accuracy/brier will be meaningless if ' '`label_threshold` is not equal to this value, which is ' 'recommended by Jigsaw Conversation AI team.') labels_int = tf.cast(labels > label_threshold, tf.int32) sample_weight = tf.gather(class_weight, labels_int) return sample_weight @tf.function def train_step(iterator, dataset_name, num_steps): """Training StepFn.""" def step_fn(inputs): """Per-Replica StepFn.""" features, labels, _ = utils.create_feature_and_label(inputs) with tf.GradientTape() as tape: logits = model(features, training=True) if isinstance(logits, (list, tuple)): # If model returns a tuple of (logits, covmat), extract logits logits, _ = logits if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) loss_logits = tf.squeeze(logits, axis=1) if FLAGS.loss_type == 'cross_entropy': logging.info('Using cross entropy loss') negative_log_likelihood = tf.nn.sigmoid_cross_entropy_with_logits( labels, loss_logits) elif FLAGS.loss_type == 'focal_cross_entropy': logging.info('Using focal cross entropy loss') negative_log_likelihood = tfa_losses.sigmoid_focal_crossentropy( labels, loss_logits, alpha=FLAGS.focal_loss_alpha, gamma=FLAGS.focal_loss_gamma, from_logits=True) elif FLAGS.loss_type == 'mse': logging.info('Using mean squared error loss') loss_probs = tf.nn.sigmoid(loss_logits) negative_log_likelihood = tf.keras.losses.mean_squared_error( labels, loss_probs) elif FLAGS.loss_type == 'mae': logging.info('Using mean absolute error loss') loss_probs = tf.nn.sigmoid(loss_logits) negative_log_likelihood = tf.keras.losses.mean_absolute_error( labels, loss_probs) negative_log_likelihood = tf.reduce_mean(negative_log_likelihood) l2_loss = sum(model.losses) loss = negative_log_likelihood + l2_loss # Scale the loss given the TPUStrategy will reduce sum all gradients. 
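        # Gradients are summed across replicas by `apply_gradients`, so the
        # per-replica loss is divided by `num_replicas_in_sync` below; the
        # aggregated update then equals the gradient of the global-batch mean
        # loss.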
scaled_loss = loss / strategy.num_replicas_in_sync grads = tape.gradient(scaled_loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) probs = tf.nn.sigmoid(logits) # Cast labels to discrete for ECE computation. ece_labels = tf.cast(labels > FLAGS.ece_label_threshold, tf.float32) one_hot_labels = tf.one_hot(tf.cast(ece_labels, tf.int32), depth=num_classes) ece_probs = tf.concat([1. - probs, probs], axis=1) auc_probs = tf.squeeze(probs, axis=1) pred_labels = tf.math.argmax(ece_probs, axis=-1) sample_weight = generate_sample_weight( labels, class_weight['train/{}'.format(dataset_name)], FLAGS.ece_label_threshold) metrics['train/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['train/accuracy'].update_state(labels, pred_labels) metrics['train/accuracy_weighted'].update_state( ece_labels, pred_labels, sample_weight=sample_weight) metrics['train/auroc'].update_state(labels, auc_probs) metrics['train/loss'].update_state(loss) metrics['train/ece'].add_batch(ece_probs, label=ece_labels) metrics['train/precision'].update_state(ece_labels, pred_labels) metrics['train/recall'].update_state(ece_labels, pred_labels) metrics['train/f1'].update_state(one_hot_labels, ece_probs) for _ in tf.range(tf.cast(num_steps, tf.int32)): strategy.run(step_fn, args=(next(iterator),)) @tf.function def test_step(iterator, dataset_name): """Evaluation StepFn.""" def step_fn(inputs): """Per-Replica StepFn.""" features, labels, _ = utils.create_feature_and_label(inputs) eval_start_time = time.time() # Compute ensemble prediction over Monte Carlo forward-pass samples. logits_list = [] stddev_list = [] for _ in range(FLAGS.num_mc_samples): logits = model(features, training=False) if isinstance(logits, (list, tuple)): # If model returns a tuple of (logits, covmat), extract both. logits, covmat = logits else: covmat = tf.eye(test_batch_size) if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) covmat = tf.cast(covmat, tf.float32) logits = ed.layers.utils.mean_field_logits( logits, covmat, mean_field_factor=FLAGS.gp_mean_field_factor) stddev = tf.sqrt(tf.linalg.diag_part(covmat)) logits_list.append(logits) stddev_list.append(stddev) eval_time = (time.time() - eval_start_time) / FLAGS.per_core_batch_size # Logits dimension is (num_samples, batch_size, num_classes). logits_list = tf.stack(logits_list, axis=0) stddev_list = tf.stack(stddev_list, axis=0) stddev = tf.reduce_mean(stddev_list, axis=0) probs_list = tf.nn.sigmoid(logits_list) probs = tf.reduce_mean(probs_list, axis=0) # Cast labels to discrete for ECE computation. ece_labels = tf.cast(labels > FLAGS.ece_label_threshold, tf.float32) one_hot_labels = tf.one_hot(tf.cast(ece_labels, tf.int32), depth=num_classes) ece_probs = tf.concat([1. - probs, probs], axis=1) pred_labels = tf.math.argmax(ece_probs, axis=-1) auc_probs = tf.squeeze(probs, axis=1) # Use normalized binary predictive variance as the confidence score. # Since the prediction variance p*(1-p) is within range (0, 0.25), # normalize it by maximum value so the confidence is between (0, 1). calib_confidence = 1. - probs * (1. 
- probs) / .25 ce = tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.broadcast_to( labels, [FLAGS.num_mc_samples, labels.shape[0]]), logits=tf.squeeze(logits_list, axis=-1) ) negative_log_likelihood = -tf.reduce_logsumexp( -ce, axis=0) + tf.math.log(float(FLAGS.num_mc_samples)) negative_log_likelihood = tf.reduce_mean(negative_log_likelihood) sample_weight = generate_sample_weight( labels, class_weight['test/{}'.format(dataset_name)], FLAGS.ece_label_threshold) if dataset_name == 'ind': metrics['test/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['test/auroc'].update_state(labels, auc_probs) metrics['test/aupr'].update_state(labels, auc_probs) metrics['test/brier'].update_state(labels, auc_probs) metrics['test/brier_weighted'].update_state( tf.expand_dims(labels, -1), probs, sample_weight=sample_weight) metrics['test/ece'].add_batch(ece_probs, label=ece_labels) metrics['test/acc'].update_state(ece_labels, pred_labels) metrics['test/acc_weighted'].update_state( ece_labels, pred_labels, sample_weight=sample_weight) metrics['test/eval_time'].update_state(eval_time) metrics['test/stddev'].update_state(stddev) metrics['test/precision'].update_state(ece_labels, pred_labels) metrics['test/recall'].update_state(ece_labels, pred_labels) metrics['test/f1'].update_state(one_hot_labels, ece_probs) metrics['test/calibration_auroc'].update_state(ece_labels, pred_labels, calib_confidence) metrics['test/calibration_auprc'].update_state(ece_labels, pred_labels, calib_confidence) for fraction in FLAGS.fractions: metrics['test_collab_acc/collab_acc_{}'.format( fraction)].add_batch(ece_probs, label=ece_labels) metrics['test_abstain_prec/abstain_prec_{}'.format( fraction)].update_state(ece_labels, pred_labels, calib_confidence) metrics['test_abstain_recall/abstain_recall_{}'.format( fraction)].update_state(ece_labels, pred_labels, calib_confidence) else: metrics['test/nll_{}'.format(dataset_name)].update_state( negative_log_likelihood) metrics['test/auroc_{}'.format(dataset_name)].update_state( labels, auc_probs) metrics['test/aupr_{}'.format(dataset_name)].update_state( labels, auc_probs) metrics['test/brier_{}'.format(dataset_name)].update_state( labels, auc_probs) metrics['test/brier_weighted_{}'.format(dataset_name)].update_state( tf.expand_dims(labels, -1), probs, sample_weight=sample_weight) metrics['test/ece_{}'.format(dataset_name)].add_batch( ece_probs, label=ece_labels) metrics['test/acc_{}'.format(dataset_name)].update_state( ece_labels, pred_labels) metrics['test/acc_weighted_{}'.format(dataset_name)].update_state( ece_labels, pred_labels, sample_weight=sample_weight) metrics['test/eval_time_{}'.format(dataset_name)].update_state( eval_time) metrics['test/stddev_{}'.format(dataset_name)].update_state(stddev) metrics['test/precision_{}'.format(dataset_name)].update_state( ece_labels, pred_labels) metrics['test/recall_{}'.format(dataset_name)].update_state( ece_labels, pred_labels) metrics['test/f1_{}'.format(dataset_name)].update_state( one_hot_labels, ece_probs) metrics['test/calibration_auroc_{}'.format(dataset_name)].update_state( ece_labels, pred_labels, calib_confidence) metrics['test/calibration_auprc_{}'.format(dataset_name)].update_state( ece_labels, pred_labels, calib_confidence) for fraction in FLAGS.fractions: metrics['test_collab_acc/collab_acc_{}_{}'.format( fraction, dataset_name)].add_batch(ece_probs, label=ece_labels) metrics['test_abstain_prec/abstain_prec_{}_{}'.format( fraction, dataset_name)].update_state(ece_labels, pred_labels, calib_confidence) 
metrics['test_abstain_recall/abstain_recall_{}_{}'.format( fraction, dataset_name)].update_state(ece_labels, pred_labels, calib_confidence) strategy.run(step_fn, args=(next(iterator),)) @tf.function def final_eval_step(iterator): """Final Evaluation StepFn to save prediction to directory.""" def step_fn(inputs): bert_features, labels, additional_labels = utils.create_feature_and_label( inputs) logits = model(bert_features, training=False) if isinstance(logits, (list, tuple)): # If model returns a tuple of (logits, covmat), extract both. logits, covmat = logits else: covmat = tf.eye(test_batch_size) if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) covmat = tf.cast(covmat, tf.float32) logits = ed.layers.utils.mean_field_logits( logits, covmat, mean_field_factor=FLAGS.gp_mean_field_factor) features = inputs['input_ids'] return features, logits, labels, additional_labels (per_replica_texts, per_replica_logits, per_replica_labels, per_replica_additional_labels) = ( strategy.run(step_fn, args=(next(iterator),))) if strategy.num_replicas_in_sync > 1: texts_list = tf.concat(per_replica_texts.values, axis=0) logits_list = tf.concat(per_replica_logits.values, axis=0) labels_list = tf.concat(per_replica_labels.values, axis=0) additional_labels_dict = {} for additional_label in utils.IDENTITY_LABELS: if additional_label in per_replica_additional_labels: additional_labels_dict[additional_label] = tf.concat( per_replica_additional_labels[additional_label], axis=0) else: texts_list = per_replica_texts logits_list = per_replica_logits labels_list = per_replica_labels additional_labels_dict = {} for additional_label in utils.IDENTITY_LABELS: if additional_label in per_replica_additional_labels: additional_labels_dict[ additional_label] = per_replica_additional_labels[ additional_label] return texts_list, logits_list, labels_list, additional_labels_dict if FLAGS.prediction_mode: # Prediction and exit. 
    for dataset_name, test_dataset in test_datasets.items():
      test_iterator = iter(test_dataset)  # pytype: disable=wrong-arg-types
      message = 'Final eval on dataset {}'.format(dataset_name)
      logging.info(message)
      texts_all = []
      logits_all = []
      labels_all = []
      additional_labels_all_dict = {}
      if 'identity' in dataset_name:
        for identity_label_name in utils.IDENTITY_LABELS:
          additional_labels_all_dict[identity_label_name] = []
      try:
        with tf.experimental.async_scope():
          for step in range(steps_per_eval[dataset_name]):
            if step % 20 == 0:
              message = 'Starting to run eval step {}/{} of dataset: {}'.format(
                  step, steps_per_eval[dataset_name], dataset_name)
              logging.info(message)
            (text_step, logits_step, labels_step,
             additional_labels_dict_step) = final_eval_step(test_iterator)
            texts_all.append(text_step)
            logits_all.append(logits_step)
            labels_all.append(labels_step)
            if 'identity' in dataset_name:
              for identity_label_name in utils.IDENTITY_LABELS:
                additional_labels_all_dict[identity_label_name].append(
                    additional_labels_dict_step[identity_label_name])
      except (StopIteration, tf.errors.OutOfRangeError):
        tf.experimental.async_clear_error()
        logging.info('Done with eval on %s', dataset_name)

      texts_all = tf.concat(texts_all, axis=0)
      logits_all = tf.concat(logits_all, axis=0)
      labels_all = tf.concat(labels_all, axis=0)
      additional_labels_all = []
      if additional_labels_all_dict:
        for identity_label_name in utils.IDENTITY_LABELS:
          additional_labels_all.append(
              tf.concat(additional_labels_all_dict[identity_label_name], axis=0))
      additional_labels_all = tf.convert_to_tensor(additional_labels_all)

      utils.save_prediction(
          texts_all.numpy(),
          path=os.path.join(FLAGS.output_dir, 'texts_{}'.format(dataset_name)))
      utils.save_prediction(
          labels_all.numpy(),
          path=os.path.join(FLAGS.output_dir, 'labels_{}'.format(dataset_name)))
      utils.save_prediction(
          logits_all.numpy(),
          path=os.path.join(FLAGS.output_dir, 'logits_{}'.format(dataset_name)))
      if 'identity' in dataset_name:
        utils.save_prediction(
            additional_labels_all.numpy(),
            path=os.path.join(FLAGS.output_dir,
                              'additional_labels_{}'.format(dataset_name)))
      logging.info('Done with testing on %s', dataset_name)
  else:
    # Execute train / eval loop.
    start_time = time.time()
    train_iterators = {}
    for dataset_name, train_dataset in train_datasets.items():
      train_iterators[dataset_name] = iter(train_dataset)
    for epoch in range(initial_epoch, FLAGS.train_epochs):
      logging.info('Starting to run epoch: %s', epoch)
      for dataset_name, train_iterator in train_iterators.items():
        try:
          with tf.experimental.async_scope():
            train_step(train_iterator, dataset_name,
                       dataset_steps_per_epoch[dataset_name])
            current_step = (epoch * total_steps_per_epoch +
                            dataset_steps_per_epoch[dataset_name])
            max_steps = total_steps_per_epoch * FLAGS.train_epochs
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                       'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                           current_step / max_steps, epoch + 1,
                           FLAGS.train_epochs, steps_per_sec,
                           eta_seconds / 60, time_elapsed / 60))
            logging.info(message)
        except (StopIteration, tf.errors.OutOfRangeError):
          tf.experimental.async_clear_error()
          logging.info('Done with training on %s', dataset_name)

      if epoch % FLAGS.evaluation_interval == 0:
        for dataset_name, test_dataset in test_datasets.items():
          test_iterator = iter(test_dataset)
          logging.info('Testing on dataset %s', dataset_name)
          try:
            with tf.experimental.async_scope():
              for step in range(steps_per_eval[dataset_name]):
                if step % 20 == 0:
                  logging.info('Starting to run eval step %s/%s of epoch: %s',
                               step, steps_per_eval[dataset_name], epoch)
                test_step(test_iterator, dataset_name)
          except (StopIteration, tf.errors.OutOfRangeError):
            tf.experimental.async_clear_error()
            logging.info('Done with testing on %s', dataset_name)

        logging.info('Train Loss: %.4f, ECE: %.2f, Accuracy: %.2f',
                     metrics['train/loss'].result(),
                     metrics['train/ece'].result(),
                     metrics['train/accuracy'].result())

        total_results = {
            name: metric.result() for name, metric in metrics.items()
        }
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
          for name, result in total_results.items():
            tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
          metric.reset_states()

      checkpoint_interval = min(FLAGS.checkpoint_interval, FLAGS.train_epochs)
      if checkpoint_interval > 0 and (epoch + 1) % checkpoint_interval == 0:
        checkpoint_name = checkpoint.save(
            os.path.join(FLAGS.output_dir, 'checkpoint'))
        logging.info('Saved checkpoint to %s', checkpoint_name)

    # Save model in SavedModel format on exit.
    final_save_name = os.path.join(FLAGS.output_dir, 'model')
    model.save(final_save_name)
    logging.info('Saved model to %s', final_save_name)

  with summary_writer.as_default():
    hp.hparams({
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'gp_mean_field_factor': FLAGS.gp_mean_field_factor,
    })
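# A hedged sketch of the mean-field adjustment applied in `test_step` above.
# SNGP divides each logit by sqrt(1 + lambda * var), with the variance taken
# from the diagonal of the GP covariance, so predictions are tempered where the
# model is uncertain. This mirrors how ed.layers.utils.mean_field_logits is
# called in this script, but it is a simplified re-implementation, not the
# library code itself.
def mean_field_logits_sketch(logits, covmat, mean_field_factor):
  # logits: (batch_size, num_classes); covmat: (batch_size, batch_size).
  variances = tf.linalg.diag_part(covmat)
  logits_scale = tf.sqrt(1. + variances * mean_field_factor)
  return logits / tf.expand_dims(logits_scale, axis=-1)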
def main(argv): del argv # unused arg tf.io.gfile.makedirs(FLAGS.output_dir) logging.info('Saving checkpoints at %s', FLAGS.output_dir) tf.random.set_seed(FLAGS.seed) if FLAGS.use_gpu: logging.info('Use GPU') strategy = tf.distribute.MirroredStrategy() else: logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local') resolver = tf.distribute.cluster_resolver.TPUClusterResolver( tpu=FLAGS.tpu) tf.config.experimental_connect_to_cluster(resolver) tf.tpu.experimental.initialize_tpu_system(resolver) strategy = tf.distribute.TPUStrategy(resolver) ds_info = tfds.builder(FLAGS.dataset).info batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores train_dataset_size = ds_info.splits['train'].num_examples steps_per_epoch = train_dataset_size // batch_size steps_per_eval = ds_info.splits['test'].num_examples // batch_size num_classes = ds_info.features['label'].num_classes train_dataset = ub.datasets.get( FLAGS.dataset, split=tfds.Split.TRAIN).load(batch_size=batch_size) clean_test_dataset = ub.datasets.get( FLAGS.dataset, split=tfds.Split.TEST).load(batch_size=batch_size) train_dataset = strategy.experimental_distribute_dataset(train_dataset) test_datasets = { 'clean': strategy.experimental_distribute_dataset(clean_test_dataset), } if FLAGS.corruptions_interval > 0: extra_kwargs = {} if FLAGS.dataset == 'cifar100': extra_kwargs['data_dir'] = FLAGS.cifar100_c_path corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset) for corruption_type in corruption_types: for severity in range(1, 6): dataset = ub.datasets.get( f'{FLAGS.dataset}_corrupted', corruption_type=corruption_type, severity=severity, split=tfds.Split.TEST, **extra_kwargs).load(batch_size=batch_size) test_datasets[f'{corruption_type}_{severity}'] = ( strategy.experimental_distribute_dataset(dataset)) summary_writer = tf.summary.create_file_writer( os.path.join(FLAGS.output_dir, 'summaries')) with strategy.scope(): logging.info('Building ResNet model') model = ub.models.wide_resnet_variational( input_shape=ds_info.features['image'].shape, depth=28, width_multiplier=10, num_classes=num_classes, prior_stddev=FLAGS.prior_stddev, dataset_size=train_dataset_size, stddev_init=FLAGS.stddev_init) logging.info('Model input shape: %s', model.input_shape) logging.info('Model output shape: %s', model.output_shape) logging.info('Model number of weights: %s', model.count_params()) # Linearly scale learning rate and the decay epochs by vanilla settings. 
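  # The reference configuration was tuned at batch size 128, so the learning
  # rate is multiplied by batch_size / 128 (the linear scaling rule), and the
  # decay epochs, expressed for a 200-epoch schedule, are rescaled to
  # FLAGS.train_epochs.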
  base_lr = FLAGS.base_learning_rate * batch_size / 128
  lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                     for start_epoch_str in FLAGS.lr_decay_epochs]
  lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
      steps_per_epoch,
      base_lr,
      decay_ratio=FLAGS.lr_decay_ratio,
      decay_epochs=lr_decay_epochs,
      warmup_epochs=FLAGS.lr_warmup_epochs)
  optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                      momentum=1.0 - FLAGS.one_minus_momentum,
                                      nesterov=True)
  metrics = {
      'train/negative_log_likelihood': tf.keras.metrics.Mean(),
      'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'train/loss': tf.keras.metrics.Mean(),
      'train/ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
      'train/kl': tf.keras.metrics.Mean(),
      'train/kl_scale': tf.keras.metrics.Mean(),
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  if FLAGS.corruptions_interval > 0:
    corrupt_metrics = {}
    for intensity in range(1, 6):
      for corruption in corruption_types:
        dataset_name = '{0}_{1}'.format(corruption, intensity)
        corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
            tf.keras.metrics.Mean())
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
        corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
  latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
  initial_epoch = 0
  if latest_checkpoint:
    # checkpoint.restore must be within a strategy.scope() so that optimizer
    # slot variables are mirrored.
    checkpoint.restore(latest_checkpoint)
    logging.info('Loaded checkpoint %s', latest_checkpoint)
    initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  @tf.function
  def train_step(iterator):
    """Training StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images = inputs['features']
      labels = inputs['labels']
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True))

        filtered_variables = []
        for var in model.trainable_variables:
          # Apply l2 on the BN parameters and bias terms. This excludes only
          # the fast-weight approximate posterior/prior parameters, but be
          # careful about their naming scheme.
          if 'batch_norm' in var.name or 'bias' in var.name:
            filtered_variables.append(tf.reshape(var, (-1,)))

        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
        kl = sum(model.losses)
        kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
        kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs
        kl_scale = tf.minimum(1., kl_scale)
        kl_loss = kl_scale * kl
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
loss = negative_log_likelihood + l2_loss + kl_loss scaled_loss = loss / strategy.num_replicas_in_sync grads = tape.gradient(scaled_loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) probs = tf.nn.softmax(logits) metrics['train/ece'].add_batch(probs, label=labels) metrics['train/loss'].update_state(loss) metrics['train/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['train/kl'].update_state(kl) metrics['train/kl_scale'].update_state(kl_scale) metrics['train/accuracy'].update_state(labels, logits) for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)): strategy.run(step_fn, args=(next(iterator), )) @tf.function def test_step(iterator, dataset_name): """Evaluation StepFn.""" def step_fn(inputs): """Per-Replica StepFn.""" images = inputs['features'] labels = inputs['labels'] # TODO(trandustin): Use more eval samples only on corrupted predictions; # it's expensive but a one-time compute if scheduled post-training. if FLAGS.num_eval_samples > 1 and dataset_name != 'clean': logits = tf.stack([ model(images, training=False) for _ in range(FLAGS.num_eval_samples) ], axis=0) else: logits = model(images, training=False) probs = tf.nn.softmax(logits) if FLAGS.num_eval_samples > 1 and dataset_name != 'clean': probs = tf.reduce_mean(probs, axis=0) negative_log_likelihood = tf.reduce_mean( tf.keras.losses.sparse_categorical_crossentropy(labels, probs)) if dataset_name == 'clean': metrics['test/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['test/accuracy'].update_state(labels, probs) metrics['test/ece'].add_batch(probs, label=labels) else: corrupt_metrics['test/nll_{}'.format( dataset_name)].update_state(negative_log_likelihood) corrupt_metrics['test/accuracy_{}'.format( dataset_name)].update_state(labels, probs) corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch( probs, label=labels) for _ in tf.range(tf.cast(steps_per_eval, tf.int32)): strategy.run(step_fn, args=(next(iterator), )) metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()}) train_iterator = iter(train_dataset) start_time = time.time() for epoch in range(initial_epoch, FLAGS.train_epochs): logging.info('Starting to run epoch: %s', epoch) train_step(train_iterator) current_step = (epoch + 1) * steps_per_epoch max_steps = steps_per_epoch * FLAGS.train_epochs time_elapsed = time.time() - start_time steps_per_sec = float(current_step) / time_elapsed eta_seconds = (max_steps - current_step) / steps_per_sec message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. ' 'ETA: {:.0f} min. 
Time elapsed: {:.0f} min'.format( current_step / max_steps, epoch + 1, FLAGS.train_epochs, steps_per_sec, eta_seconds / 60, time_elapsed / 60)) logging.info(message) datasets_to_evaluate = {'clean': test_datasets['clean']} if (FLAGS.corruptions_interval > 0 and (epoch + 1) % FLAGS.corruptions_interval == 0): datasets_to_evaluate = test_datasets for dataset_name, test_dataset in datasets_to_evaluate.items(): test_iterator = iter(test_dataset) logging.info('Testing on dataset %s', dataset_name) logging.info('Starting to run eval at epoch: %s', epoch) test_start_time = time.time() test_step(test_iterator, dataset_name) ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size metrics['test/ms_per_example'].update_state(ms_per_example) logging.info('Done with testing on %s', dataset_name) corrupt_results = {} if (FLAGS.corruptions_interval > 0 and (epoch + 1) % FLAGS.corruptions_interval == 0): corrupt_results = utils.aggregate_corrupt_metrics( corrupt_metrics, corruption_types) logging.info('Train Loss: %.4f, Accuracy: %.2f%%', metrics['train/loss'].result(), metrics['train/accuracy'].result() * 100) logging.info('Test NLL: %.4f, Accuracy: %.2f%%', metrics['test/negative_log_likelihood'].result(), metrics['test/accuracy'].result() * 100) total_results = { name: metric.result() for name, metric in metrics.items() } total_results.update(corrupt_results) # Metrics from Robustness Metrics (like ECE) will return a dict with a # single key/value, instead of a scalar. total_results = { k: (list(v.values())[0] if isinstance(v, dict) else v) for k, v in total_results.items() } with summary_writer.as_default(): for name, result in total_results.items(): tf.summary.scalar(name, result, step=epoch + 1) for metric in metrics.values(): metric.reset_states() if (FLAGS.checkpoint_interval > 0 and (epoch + 1) % FLAGS.checkpoint_interval == 0): checkpoint_name = checkpoint.save( os.path.join(FLAGS.output_dir, 'checkpoint')) logging.info('Saved checkpoint to %s', checkpoint_name) final_checkpoint_name = checkpoint.save( os.path.join(FLAGS.output_dir, 'checkpoint')) logging.info('Saved last checkpoint to %s', final_checkpoint_name) with summary_writer.as_default(): hp.hparams({ 'base_learning_rate': FLAGS.base_learning_rate, 'one_minus_momentum': FLAGS.one_minus_momentum, 'l2': FLAGS.l2, 'prior_stddev': FLAGS.prior_stddev, 'stddev_init': FLAGS.stddev_init, })
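# ---------------------------------------------------------------------------
# A minimal, self-contained sketch (not part of the script above) of the
# linear KL annealing that train_step applies: the KL term is scaled from ~0
# to 1 over `kl_annealing_epochs` epochs of optimizer steps. The constants
# steps_per_epoch=100 and kl_annealing_epochs=10 are illustrative
# assumptions, not values taken from the flags.
import tensorflow as tf

steps_per_epoch = 100     # assumed value
kl_annealing_epochs = 10  # assumed value

def kl_scale_for_step(step):
    """Linear KL warm-up: min(1, (step + 1) / total annealing steps)."""
    scale = tf.cast(step + 1, tf.float32)
    scale /= steps_per_epoch * kl_annealing_epochs
    return tf.minimum(1.0, scale)

# e.g. kl_scale_for_step(0) == 0.001 and kl_scale_for_step(999) == 1.0;
# the annealed loss is then nll + l2_loss + kl_scale_for_step(step) * kl.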
def train_and_evaluate(model,
                       num_epochs,
                       steps_per_epoch,
                       train_data,
                       validation_steps,
                       eval_data,
                       output_dir,
                       n_steps_history,
                       FLAGS,
                       decay_type,
                       learning_rate=3e-5,
                       s=1,
                       n_batch_decay=1,
                       metric_accuracy='metric'):
    """Compiles the keras model and loads data into it for training."""
    logging.info('training the model ...')
    model_callbacks = []

    # create the metadata dictionaries
    dict_model = {}
    dict_data = {}
    dict_parameter = {}
    dict_hardware = {}
    dict_results = {}
    dict_type_job = {}
    dict_software = {}

    # feature switches, for debugging only
    activate_tensorboard = True
    activate_hp_tensorboard = False  # True
    activate_lr = False
    save_checkpoints = False  # True
    save_history_per_step = False  # True
    save_metadata = False  # True
    activate_timing = False  # True
    # drop the official method, which is not working
    activate_tf_summary_hp = True  # False
    # hardcoded way of reporting the hp metric
    activate_hardcoded_hp = True

    # dependencies
    if activate_tf_summary_hp:
        save_history_per_step = True

    if FLAGS.is_hyperparameter_tuning:
        # get the trial ID
        suffix = mu.get_trial_id()
        if suffix == '':
            logging.error('No trial ID for hyperparameter job!')
            FLAGS.is_hyperparameter_tuning = False
        else:
            # callback for hp
            logging.info('Creating a callback to store the metric!')
            if activate_tf_summary_hp:
                hp_metric = mu.HP_metric(metric_accuracy)
                model_callbacks.append(hp_metric)

    if output_dir:
        if activate_tensorboard:
            # tensorboard callback
            log_dir = os.path.join(output_dir, 'tensorboard')
            if FLAGS.is_hyperparameter_tuning:
                log_dir = os.path.join(log_dir, suffix)
            tensorboard_callback = tf.keras.callbacks.TensorBoard(
                log_dir=log_dir,
                histogram_freq=1,
                embeddings_freq=0,
                write_graph=True,
                update_freq='batch',
                profile_batch='10, 20')
            model_callbacks.append(tensorboard_callback)

        if save_checkpoints:
            # checkpoint callback; the model is not saved during
            # hyperparameter tuning
            checkpoint_dir = os.path.join(output_dir, 'checkpoint_model')
            if not FLAGS.is_hyperparameter_tuning:
                # checkpoint_dir = os.path.join(checkpoint_dir, suffix)
                checkpoint_prefix = os.path.join(checkpoint_dir,
                                                 'ckpt_{epoch:02d}')
                checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
                    filepath=checkpoint_prefix,
                    verbose=1,
                    save_weights_only=True)
                model_callbacks.append(checkpoint_callback)

    if activate_lr:
        # learning-rate decay callback; switch between the decay schedules
        if decay_type == 'exponential':
            decay_fn = mu.exponential_decay(lr0=learning_rate, s=s)
        elif decay_type == 'stepwise':
            decay_fn = mu.step_decay(lr0=learning_rate, s=s)
        elif decay_type == 'timebased':
            decay_fn = mu.time_decay(lr0=learning_rate, s=s)
        else:
            decay_fn = mu.no_decay(lr0=learning_rate)
        # update the learning rate per batch rather than per epoch
        lr_decay_batch = mu.LearningRateSchedulerPerBatch(
            decay_fn, n_batch_decay, verbose=1)
        model_callbacks.append(lr_decay_batch)

    # callback to store all the learning rates (disabled)
    # all_learning_rates = mu.LR_per_step()
    # model_callbacks.append(all_learning_rates)

    if save_history_per_step:
        # callback to create a history per step (not per epoch)
        histories_per_step = mu.History_per_step(eval_data, n_steps_history)
        model_callbacks.append(histories_per_step)

    if activate_timing:
        # callback to time each epoch
        timing = mu.TimingCallback()
        model_callbacks.append(timing)

    # check the model's callbacks
    logging.info("model's callbacks:\n {}".format(str(model_callbacks)))

    # train the model and time the run
    start_time = time.time()
    logging.info('starting model.fit')
    # verbose = 0 (silent), 1 (progress bar), 2 (one line per epoch)
    verbose = 1
    history = model.fit(train_data,
                        epochs=num_epochs,
                        steps_per_epoch=steps_per_epoch,
                        validation_data=eval_data,
                        validation_steps=validation_steps,
                        verbose=verbose,
                        callbacks=model_callbacks)

    # print the execution time
    elapsed_time_secs = time.time() - start_time
    logging.info('\nexecution time: {}'.format(
        timedelta(seconds=round(elapsed_time_secs))))

    # check the model
    logging.info('model summary ={}'.format(model.summary()))
    logging.info('model inputs ={}'.format(model.inputs))
    logging.info('model outputs ={}'.format(model.outputs))

    # to be removed; debugging only
    logging.info('\ndebugging .... : ')
    pp.print_info_data(train_data)

    if activate_timing:
        logging.info('timing per epoch:\n{}'.format(
            list(map(lambda x: str(timedelta(seconds=round(x))),
                     timing.timing_epoch))))
        logging.info('timing per validation:\n{}'.format(
            list(map(lambda x: str(timedelta(seconds=round(x))),
                     timing.timing_valid))))
        logging.info('sum timing over all epochs:\n{}'.format(
            timedelta(seconds=round(sum(timing.timing_epoch)))))

    # hyperparameter tuning in TensorBoard
    if FLAGS.is_hyperparameter_tuning:
        logging.info('setup hyperparameter tuning!')
        logging.info('debug: os.environ.items(): %s', os.environ.items())

        if activate_hardcoded_hp:
            # trick to bypass an AI Platform bug
            logging.info('hardcoded hyperparameter tuning!')
            value_accuracy = histories_per_step.accuracies[-1]
            hpt = hypertune.HyperTune()
            hpt.report_hyperparameter_tuning_metric(
                hyperparameter_metric_tag=metric_accuracy,
                metric_value=value_accuracy,
                global_step=0)
        else:
            # the metric should be extracted from /var/hypertune/output.metric
            logging.info('standard hyperparameter tuning!')
            # value_accuracy = histories_per_step.accuracies[-1]
            path_metric = '/var/hypertune/output.metric'
            logging.info('checking if /var/hypertune/output.metric exists!')
            if os.path.isfile(path_metric):
                logging.info('file {} exists!'.format(path_metric))
                with open(path_metric, 'r') as f:
                    logging.info('content of output.metric: {}'.format(
                        f.read()))

        if activate_hp_tensorboard:
            logging.info('setup TensorBoard for hyperparameter tuning!')
            # CAIP:
            # params = json.loads(os.environ.get('TF_CONFIG', '{}')).get(
            #     'job', {}).get('hyperparameters', {}).get('params', {})
            # uCAIP:
            params = json.loads(os.environ.get('CLUSTER_SPEC', '{}'))
            print('debug: CLUSTER_SPEC:', params)
            list_hp = []
            hparams = {}
            for el in params:
                hp_dict = dict(el)
                if hp_dict.get('type') == 'DOUBLE':
                    key_hp = hp.HParam(
                        hp_dict.get('parameter_name'),
                        hp.RealInterval(hp_dict.get('min_value'),
                                        hp_dict.get('max_value')))
                    list_hp.append(key_hp)
                    try:
                        hparams[key_hp] = FLAGS[hp_dict.get(
                            'parameter_name')].value
                    except KeyError:
                        logging.error(
                            "hyperparameter key {} doesn't exist".format(
                                hp_dict.get('parameter_name')))

            hparams_dir = os.path.join(output_dir, 'hparams_tuning')
            with tf.summary.create_file_writer(hparams_dir).as_default():
                hp.hparams_config(
                    hparams=list_hp,
                    metrics=[hp.Metric(metric_accuracy,
                                       display_name=metric_accuracy)])

            hparams_dir = os.path.join(hparams_dir, suffix)
            with tf.summary.create_file_writer(hparams_dir).as_default():
                # record the values used in this trial; note that
                # value_accuracy is only set by the hardcoded branch above
                hp.hparams(hparams)
                tf.summary.scalar(metric_accuracy, value_accuracy, step=1)

    if save_history_per_step:
        # save the history in a file
        search = re.search('gs://(.*?)/(.*)', output_dir)
        if search is not None:
            # temporary local folder, to be moved to GCS later
            history_dir = os.path.join('./', model.name)
        else:
            # local output
            history_dir = os.path.join(output_dir, model.name)
        os.makedirs(history_dir, exist_ok=True)
        logging.debug('history_dir: \n {}'.format(history_dir))
        with open(history_dir + '/history', 'wb') as file:
            model_history = mu.History_trained_model(history.history,
                                                     history.epoch,
                                                     history.params)
            pickle.dump(model_history, file, pickle.HIGHEST_PROTOCOL)
        with open(history_dir + '/history_per_step', 'wb') as file:
            model_history_per_step = mu.History_per_steps_trained_model(
                histories_per_step.steps,
                histories_per_step.losses,
                histories_per_step.accuracies,
                histories_per_step.val_steps,
                histories_per_step.val_losses,
                histories_per_step.val_accuracies,
                0,  # all_learning_rates.all_lr
                0,  # all_learning_rates.all_lr_alternative
                0)  # all_learning_rates.all_lr_logs
            pickle.dump(model_history_per_step, file, pickle.HIGHEST_PROTOCOL)

    if output_dir:
        # save the model
        savemodel_path = os.path.join(output_dir, 'saved_model')
        if not FLAGS.is_hyperparameter_tuning:
            # the model is not saved during hyperparameter tuning
            # savemodel_path = os.path.join(savemodel_path, suffix)
            model.save(os.path.join(savemodel_path, model.name))
            model2 = tf.keras.models.load_model(
                os.path.join(savemodel_path, model.name))
            # check the reloaded model
            logging.info('model2 summary ={}'.format(model2.summary()))
            logging.info('model2 inputs ={}'.format(model2.inputs))
            logging.info('model2 outputs ={}'.format(model2.outputs))
            logging.info('model2 signature outputs ={}'.format(
                model2.signatures['serving_default'].structured_outputs))
            logging.info('model2 signature inputs ={}'.format(
                model2.signatures['serving_default'].inputs[0]))

        if save_history_per_step:
            # copy the saved history to GCS
            search = re.search('gs://(.*?)/(.*)', output_dir)
            if search is not None:
                bucket_name = search.group(1)
                blob_name = search.group(2)
                output_folder = blob_name + '/history'
                if FLAGS.is_hyperparameter_tuning:
                    output_folder = os.path.join(output_folder, suffix)
                mu.copy_local_directory_to_gcs(history_dir, bucket_name,
                                               output_folder)

    if save_metadata:
        # add the metadata
        dict_model['pretrained_transformer_model'] = (
            FLAGS.pretrained_model_dir)
        dict_model['num_classes'] = FLAGS.num_classes
        dict_data['train'] = FLAGS.input_train_tfrecords
        dict_data['eval'] = FLAGS.input_eval_tfrecords
        dict_parameter['use_decay_learning_rate'] = (
            FLAGS.use_decay_learning_rate)
        dict_parameter['epochs'] = FLAGS.epochs
        dict_parameter['steps_per_epoch_train'] = FLAGS.steps_per_epoch_train
        dict_parameter['steps_per_epoch_eval'] = FLAGS.steps_per_epoch_eval
        dict_parameter['n_steps_history'] = FLAGS.n_steps_history
        dict_parameter['batch_size_train'] = FLAGS.batch_size_train
        dict_parameter['batch_size_eval'] = FLAGS.batch_size_eval
        dict_parameter['learning_rate'] = FLAGS.learning_rate
        dict_parameter['epsilon'] = FLAGS.epsilon
        dict_hardware['is_tpu'] = FLAGS.use_tpu
        dict_type_job['is_hyperparameter_tuning'] = (
            FLAGS.is_hyperparameter_tuning)
        dict_type_job['is_tpu'] = FLAGS.use_tpu
        dict_software['tensorflow'] = tf.__version__
        dict_software['transformer'] = __version__
        dict_software['python'] = sys.version

        # aggregate the dictionaries
        dict_all = {
            'model': dict_model,
            'data': dict_data,
            'parameter': dict_parameter,
            'hardware': dict_hardware,
            'results': dict_results,
            'type_job': dict_type_job,
            'software': dict_software
        }

        # save the metadata to GCS
        search = re.search('gs://(.*?)/(.*)', output_dir)
        if search is not None:
            bucket_name = search.group(1)
            blob_name = search.group(2)
            output_folder = blob_name + '/metadata'
            storage_client = storage.Client()
            bucket = storage_client.bucket(bucket_name)
            blob = bucket.blob(output_folder + '/model_job_metadata.json')
            blob.upload_from_string(data=json.dumps(dict_all),
                                    content_type='application/json')
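# ---------------------------------------------------------------------------
# A minimal, runnable sketch of the two-step TensorBoard HParams pattern the
# activate_hp_tensorboard branch above implements: declare the search space
# once at the root log directory, then log each trial's values and metric
# under its own subdirectory so TensorBoard groups them into one table. The
# log path, the 'accuracy' metric name, and the trial values below are
# illustrative assumptions, not taken from the job environment.
import os

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

root_logdir = './hparams_demo'  # assumed path
hp_lr = hp.HParam('learning_rate', hp.RealInterval(1e-5, 1e-3))

# 1) declare the search space and the metric once, at the root
with tf.summary.create_file_writer(root_logdir).as_default():
    hp.hparams_config(
        hparams=[hp_lr],
        metrics=[hp.Metric('accuracy', display_name='accuracy')])

# 2) record each trial under its own subdirectory
for trial_id, lr in enumerate([1e-4, 3e-4]):
    trial_dir = os.path.join(root_logdir, 'trial_{}'.format(trial_id))
    with tf.summary.create_file_writer(trial_dir).as_default():
        hp.hparams({hp_lr: lr}, trial_id=str(trial_id))
        tf.summary.scalar('accuracy', 0.5 + 0.1 * trial_id, step=1)  # dummy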
def train_and_validate(train_x, train_y, test_x, test_y, hparams):
    unique_items = len(train_y[0])
    model = LSTMRec(
        vocabulary_size=unique_items,
        emb_output_dim=hparams['emb_dim'],
        lstm_units=hparams['lstm_units'],
        lstm_activation=hparams['lstm_activation'],
        lstm_recurrent_activation=hparams['lstm_recurrent_activation'],
        lstm_dropout=hparams['lstm_dropout'],
        lstm_recurrent_dropout=hparams['lstm_recurrent_dropout'],
        dense_activation=hparams['dense_activation'])
    model.compile(optimizer=Adam(learning_rate=hparams['learning_rate'],
                                 beta_1=hparams['adam_beta_1'],
                                 beta_2=hparams['adam_beta_2'],
                                 epsilon=hparams['adam_epsilon']),
                  loss='binary_crossentropy',
                  metrics=[
                      Precision(top_k=1, name='P_at_1'),
                      Precision(top_k=3, name='P_at_3'),
                      Precision(top_k=5, name='P_at_5'),
                      Precision(top_k=10, name='P_at_10'),
                      Recall(top_k=10, name='R_at_10'),
                      Recall(top_k=50, name='R_at_50'),
                      Recall(top_k=100, name='R_at_100')
                  ])
    hst = model.fit(
        x=train_x,
        y=train_y,
        batch_size=hparams['batch_size'],
        epochs=250,
        callbacks=[
            EarlyStopping(monitor='val_R_at_10',
                          patience=10,
                          mode='max',
                          restore_best_weights=True,
                          verbose=True),
            ModelCheckpoint(filepath=os.path.join(os.pardir, os.pardir,
                                                  'models',
                                                  hparams['run_id'] + '.ckpt'),
                            monitor='val_R_at_10',
                            mode='max',
                            save_best_only=True,
                            save_weights_only=True,
                            verbose=True),
            TensorBoard(log_dir=os.path.join(os.pardir, os.pardir, 'logs',
                                             hparams['run_id']),
                        histogram_freq=1)
        ],
        validation_split=0.2)

    val_best_epoch = np.argmax(hst.history['val_R_at_10'])
    test_results = model.evaluate(test_x, test_y)

    with tf.summary.create_file_writer(
            os.path.join(os.pardir, os.pardir, 'logs', hparams['run_id'],
                         'hparams')).as_default():
        hp.hparams(hparams)
        # log the validation metrics at the best epoch and the test metrics,
        # in the order model.compile declared them
        metric_names = ['loss', 'P_at_1', 'P_at_3', 'P_at_5', 'P_at_10',
                        'R_at_10', 'R_at_50', 'R_at_100']
        for name in metric_names:
            tf.summary.scalar(
                'train.final_{}'.format(name),
                hst.history['val_{}'.format(name)][val_best_epoch],
                step=val_best_epoch)
        # model.evaluate returns [loss, P_at_1, ..., R_at_100] in that order
        for i, name in enumerate(metric_names):
            tf.summary.scalar('test.final_{}'.format(name),
                              test_results[i],
                              step=val_best_epoch)

    return val_best_epoch, test_results
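# ---------------------------------------------------------------------------
# A hypothetical driver for train_and_validate above, showing the shape of
# the hparams dict it expects. Every key below is read inside the function,
# but the values, the grid, and the assumption that train_x/train_y/test_x/
# test_y are already loaded are illustrative, not from the original code.
import itertools

base_hparams = {
    'emb_dim': 64,
    'lstm_units': 128,
    'lstm_activation': 'tanh',
    'lstm_recurrent_activation': 'sigmoid',
    'lstm_dropout': 0.2,
    'lstm_recurrent_dropout': 0.2,
    'dense_activation': 'sigmoid',
    'adam_beta_1': 0.9,
    'adam_beta_2': 0.999,
    'adam_epsilon': 1e-7,
    'batch_size': 256,
}

# small assumed grid over learning rate and LSTM width; each run gets its
# own run_id so its checkpoints and TensorBoard logs stay separate
for i, (lr, units) in enumerate(itertools.product([1e-3, 1e-4], [64, 128])):
    hparams = dict(base_hparams,
                   learning_rate=lr,
                   lstm_units=units,
                   run_id='run_{:02d}'.format(i))
    best_epoch, test_results = train_and_validate(
        train_x, train_y, test_x, test_y, hparams)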