def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  params = FLAGS.flag_values_dict()
  tf.set_random_seed(params['seed'])
  params['results_dir'] = utils.make_subdir(params['training_results_dir'],
                                            params['expname'])
  params['figdir'] = utils.make_subdir(params['results_dir'], 'figs')
  params['ckptdir'] = utils.make_subdir(params['results_dir'], 'ckpts')
  params['logdir'] = utils.make_subdir(params['results_dir'], 'logs')
  params['tensordir'] = utils.make_subdir(params['results_dir'], 'tensors')

  # Build the classification model from comma-separated layer-size flags.
  conv_dims = [int(x) for x in params['conv_dims'].split(',')]
  conv_sizes = [int(x) for x in params['conv_sizes'].split(',')]
  dense_sizes = [int(x) for x in params['dense_sizes'].split(',')]
  params['n_layers'] = len(conv_dims)
  clf = classifier.CNN(conv_dims, conv_sizes, dense_sizes,
                       params['n_classes'], onehot=True)
  # Checkpoint the initialized model, in case we want to re-run it from there.
  utils.checkpoint_model(clf, params['ckptdir'], 'initmodel')

  itr_train, itr_valid, itr_test = dataset_utils.load_dset_supervised_onehot()

  train_cnn.train_classifier(clf, itr_train, itr_valid, params)
  train_cnn.test_classifier(clf, itr_test, params, 'test')

  # Save to disk samples of size 1000 from the train, valid and test sets.
  train_data = utils.aggregate_batches(itr_train, 1000,
                                       ['train_x_infl', 'train_y_infl'])
  validation_data = utils.aggregate_batches(itr_valid, 1000,
                                            ['valid_x_infl', 'valid_y_infl'])
  test_data = utils.aggregate_batches(itr_test, 1000,
                                      ['test_x_infl', 'test_y_infl'])
  # Wrap items() in list() so the dict views concatenate under Python 3.
  utils.save_tensors(
      list(train_data.items()) + list(validation_data.items()) +
      list(test_data.items()), params['tensordir'])
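# For reference, a minimal sketch of what a helper like utils.aggregate_batches
# is assumed to do above: pull batches from an iterator until `n_examples` rows
# have been collected, and return them in a dict keyed by the given names. The
# name `aggregate_batches_sketch` and its body are illustrative assumptions,
# not the repo's actual implementation in utils.py.
def aggregate_batches_sketch(itr, n_examples, names):
  """Collects n_examples rows from itr into {names[0]: xs, names[1]: ys}."""
  xs, ys = [], []
  n_seen = 0
  while n_seen < n_examples:
    batch_x, batch_y = itr.next()
    xs.append(batch_x)
    ys.append(batch_y)
    n_seen += int(batch_x.shape[0])
  return {names[0]: tf.concat(xs, axis=0)[:n_examples],
          names[1]: tf.concat(ys, axis=0)[:n_examples]}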
def train_vae(vae, itr_train, itr_valid, params):
  """Train a VAE.

  Args:
    vae (VAE): a VAE.
    itr_train (Iterator): an iterator over training data.
    itr_valid (Iterator): an iterator over validation data.
    params (dict): flags for training.
  """
  run_avg_len = params['run_avg_len']
  max_steps = params['max_steps']
  print_freq = params['print_freq']

  # RALoss is an object which tracks the running average of a loss.
  ra_loss = RALoss('elbo', run_avg_len)
  ra_kl = RALoss('kl', run_avg_len)
  ra_recon = RALoss('recon', run_avg_len)
  ra_trainloss = RALoss('train-elbo', run_avg_len)

  min_val_loss = sys.maxsize
  min_val_step = 0
  save_path, ckpt = None, None  # Set once the first checkpoint is written.
  opt = tf.train.AdamOptimizer(learning_rate=params['lr'])
  finished_training = False
  start_printing = 0
  for i in range(max_steps):
    batch = itr_train.next()
    with tf.GradientTape() as tape:
      train_loss, _, _ = vae.get_loss(batch)
      mean_train_loss = tf.reduce_mean(train_loss)

    val_batch = itr_valid.next()
    valid_loss, kl_loss, recon_loss = vae.get_loss(val_batch)
    loss_list = [ra_loss, ra_kl, ra_recon, ra_trainloss]
    losses = zip(loss_list, [
        tf.reduce_mean(l)
        for l in (valid_loss, kl_loss, recon_loss, train_loss)
    ])
    utils.update_losses(losses)

    grads = tape.gradient(mean_train_loss, vae.weights)
    opt.apply_gradients(zip(grads, vae.weights))

    curr_ra_loss = ra_loss.get_value()
    # Early stopping: stop training when validation loss stops decreasing.
    # The second condition ensures we don't checkpoint every step early on.
    if curr_ra_loss < min_val_loss and \
        i - min_val_step > params['patience'] / 10:
      min_val_loss = curr_ra_loss
      min_val_step = i
      save_path, ckpt = utils.checkpoint_model(vae, params['ckptdir'])
      logging.info('Step {:d}: Checkpointed to {}'.format(i, save_path))
    elif i - min_val_step > params['patience'] or i == max_steps - 1:
      ckpt.restore(save_path)
      logging.info('Best validation loss was {:.3f} at step {:d}'
                   ' - stopping training'.format(min_val_loss, min_val_step))
      finished_training = True

    if i % print_freq == 0 or finished_training:
      utils.print_losses(loss_list, i)
      utils.write_losses_to_log(loss_list, range(start_printing, i + 1),
                                params['logdir'])
      start_printing = i + 1
      utils.plot_losses(params['figdir'], loss_list, params['mpl_format'])
      utils.plot_samples(params['sampledir'], vae, itr_valid,
                         params['mpl_format'])

    if finished_training:
      break
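# A minimal sketch of the RALoss interface the training loops rely on,
# assuming it keeps a fixed-length window of recent values and reports their
# mean. Only get_value is called directly above; update is assumed to be what
# utils.update_losses invokes on each (tracker, value) pair. The class name
# and internals here are illustrative assumptions, not the repo's real class.
import collections

class RALossSketch(object):
  """Tracks the running average of a named loss over the last few steps."""

  def __init__(self, name, run_avg_len):
    self.name = name
    self._window = collections.deque(maxlen=run_avg_len)

  def update(self, value):
    self._window.append(float(value))

  def get_value(self):
    # Mean of the window; guard against calling before any update.
    return sum(self._window) / max(len(self._window), 1)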
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  params = FLAGS.flag_values_dict()
  tf.set_random_seed(params['seed'])
  params['results_dir'] = utils.make_subdir(params['training_results_dir'],
                                            params['expname'])
  params['figdir'] = utils.make_subdir(params['results_dir'], 'figs')
  params['ckptdir'] = utils.make_subdir(params['results_dir'], 'ckpts')
  params['logdir'] = utils.make_subdir(params['results_dir'], 'logs')
  params['tensordir'] = utils.make_subdir(params['results_dir'], 'tensors')

  # Load the classification model.
  conv_dims = [
      int(x) for x in (
          params['conv_dims'].split(',') if params['conv_dims'] else [])
  ]
  conv_sizes = [
      int(x) for x in (
          params['conv_sizes'].split(',') if params['conv_sizes'] else [])
  ]
  dense_sizes = [
      int(x) for x in (
          params['dense_sizes'].split(',') if params['dense_sizes'] else [])
  ]
  params['n_layers'] = len(conv_dims)
  clf = classifier.CNN(conv_dims, conv_sizes, dense_sizes,
                       params['n_classes'], onehot=True)
  # Checkpoint the initialized model, in case we want to re-run it from there.
  utils.checkpoint_model(clf, params['ckptdir'], 'initmodel')

  # Load the "in-distribution" and "out-of-distribution" classes as
  # separate splits.
  ood_classes = [int(x) for x in params['ood_classes'].split(',')]
  # We assume we train on all non-OOD classes.
  all_classes = range(params['n_classes'])
  ind_classes = [x for x in all_classes if x not in ood_classes]
  (itr_train, itr_valid, itr_test,
   itr_test_ood) = dataset_utils.load_dset_ood_supervised_onehot(
       ind_classes,
       ood_classes,
       label_noise=params['label_noise'],
       dset_name=params['dataset_name'])

  # Train and test the model in-distribution, and save test outputs.
  train_cnn.train_classifier(clf, itr_train, itr_valid, params)
  train_cnn.test_classifier(clf, itr_test, params, 'test')

  # Save model outputs on the training set.
  params['tensordir'] = utils.make_subdir(params['results_dir'],
                                          'train_tensors')
  train_cnn.test_classifier(clf, itr_train, params, 'train')

  # Save model outputs on the OOD set.
  params['tensordir'] = utils.make_subdir(params['results_dir'],
                                          'ood_tensors')
  train_cnn.test_classifier(clf, itr_test_ood, params, 'ood')

  params['tensordir'] = utils.make_subdir(params['results_dir'], 'tensors')
  # Save to disk samples of size 1000 from the train, valid, test and OOD sets.
  train_data = utils.aggregate_batches(itr_train, 1000,
                                       ['train_x_infl', 'train_y_infl'])
  validation_data = utils.aggregate_batches(itr_valid, 1000,
                                            ['valid_x_infl', 'valid_y_infl'])
  test_data = utils.aggregate_batches(itr_test, 1000,
                                      ['test_x_infl', 'test_y_infl'])
  ood_data = utils.aggregate_batches(itr_test_ood, 1000,
                                     ['ood_x_infl', 'ood_y_infl'])
  # Wrap items() in list() so the dict views concatenate under Python 3.
  utils.save_tensors(
      list(train_data.items()) + list(validation_data.items()) +
      list(test_data.items()) + list(ood_data.items()), params['tensordir'])
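# The OOD split above is assumed to partition examples by their integer class
# label. A hedged tf.data sketch of that filtering (the function name and the
# `xs`/`labels` inputs are hypothetical, not dataset_utils' real API):
def split_by_class_sketch(xs, labels, ind_classes):
  """Splits (xs, labels) into in-distribution and OOD tf.data.Datasets."""
  dset = tf.data.Dataset.from_tensor_slices((xs, labels))
  ind_table = tf.constant(list(ind_classes), dtype=tf.int64)

  def is_ind(_, y):
    # True iff the label matches any in-distribution class.
    return tf.reduce_any(tf.equal(tf.cast(y, tf.int64), ind_table))

  ind_dset = dset.filter(is_ind)
  ood_dset = dset.filter(lambda x, y: tf.logical_not(is_ind(x, y)))
  return ind_dset, ood_dset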
def train_classifier(clf, itr_train, itr_valid, params):
  """Train a classifier.

  Args:
    clf (classifier): a classifier we wish to train.
    itr_train (Iterator): an iterator over training data.
    itr_valid (Iterator): an iterator over validation data.
    params (dict): flags for training.
  """
  # Dump the parameters we used to a JSON file for reproducibility.
  params_file = os.path.join(params['results_dir'], 'params.json')
  utils.write_json(params_file, params)

  run_avg_len = params['run_avg_len']
  max_steps = params['max_steps']
  print_freq = params['print_freq']

  # RALoss is an object which tracks the running average of a loss.
  ra_loss = RALoss('loss', run_avg_len)
  ra_error = RALoss('error', run_avg_len)
  ra_trainloss = RALoss('train-loss', run_avg_len)
  ra_trainerr = RALoss('train-err', run_avg_len)

  min_val_loss = sys.maxsize
  min_val_step = 0
  save_path, ckpt = None, None  # Set once the first checkpoint is written.
  opt = tf.train.AdamOptimizer(learning_rate=params['lr'])
  finished_training = False
  start_printing = 0
  for i in range(max_steps):
    batch_x, batch_y = itr_train.next()
    with tf.GradientTape() as tape:
      train_loss, train_err = clf.get_loss(batch_x, batch_y)
      mean_train_loss = tf.reduce_mean(train_loss)

    val_batch_x, val_batch_y = itr_valid.next()
    valid_loss, valid_err = clf.get_loss(val_batch_x, val_batch_y)
    loss_list = [ra_loss, ra_error, ra_trainloss, ra_trainerr]
    losses = zip(loss_list,
                 [tf.reduce_mean(l)
                  for l in (valid_loss, valid_err, train_loss, train_err)])
    utils.update_losses(losses)

    grads = tape.gradient(mean_train_loss, clf.weights)
    opt.apply_gradients(zip(grads, clf.weights))
    utils.print_losses(loss_list, i)

    if params['early_stopping_metric'] == 'loss':
      curr_ra_loss = ra_loss.get_value()
    elif params['early_stopping_metric'] == 'error':
      curr_ra_loss = ra_error.get_value()
    else:
      raise ValueError('params["early_stopping_metric"] should be either'
                       ' "loss" or "error", and it is "{}"'.format(
                           params['early_stopping_metric']))

    # Early stopping: stop training when the validation metric stops
    # decreasing. The second condition ensures we don't checkpoint every
    # step early on.
    if curr_ra_loss < min_val_loss and \
        i - min_val_step > params['patience'] / 10:
      min_val_loss = curr_ra_loss
      min_val_step = i
      save_path, ckpt = utils.checkpoint_model(clf, params['ckptdir'])
      logging.info('Step {:d}: Checkpointed to {}'.format(i, save_path))
    elif i - min_val_step > params['patience'] or i == max_steps - 1:
      ckpt.restore(save_path)
      finished_training = True
      logging.info('Best validation loss was {:.3f} at step {:d}'
                   ' - stopping training'.format(min_val_loss, min_val_step))

    if i % print_freq == 0 or finished_training:
      utils.write_losses_to_log(loss_list, range(start_printing, i + 1),
                                params['logdir'])
      start_printing = i + 1
      utils.plot_losses(params['figdir'], loss_list, params['mpl_format'])
      logging.info('Step {:d}: Wrote losses and plots'.format(i))

    if finished_training:
      break
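# Hedged usage sketch: train_classifier only needs the params keys read above.
# The values below are made-up illustrations, not the repo's flag defaults.
params_sketch = {
    'results_dir': '/tmp/clf_run',
    'figdir': '/tmp/clf_run/figs',
    'ckptdir': '/tmp/clf_run/ckpts',
    'logdir': '/tmp/clf_run/logs',
    'run_avg_len': 50,        # window size for the running averages
    'max_steps': 10000,       # hard cap on training steps
    'print_freq': 100,        # how often to log and plot losses
    'lr': 1e-3,               # Adam learning rate
    'patience': 1000,         # early-stopping patience, in steps
    'early_stopping_metric': 'loss',  # or 'error'
    'mpl_format': 'png',      # matplotlib output format
}
# train_classifier(clf, itr_train, itr_valid, params_sketch)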