def test_vae(vae, itr_test, params): """Test a trained VAE.""" max_steps_test = params['max_steps_test'] ra_loss = RALoss('elbo', max_steps_test) ra_kl = RALoss('kl', max_steps_test) ra_recon = RALoss('recon', max_steps_test) loss_tensor = [] kl_tensor = [] recon_tensor = [] for i in range(max_steps_test): batch = itr_test.next() loss, kl, recon = vae.get_loss(batch) losses = zip([ra_loss, ra_kl, ra_recon], [tf.reduce_mean(l) for l in (loss, kl, recon)]) utils.update_losses(losses) utils.print_losses([l[0] for l in losses], i) loss_tensor.append(loss) kl_tensor.append(kl) recon_tensor.append(recon) loss_tensor = tf.concat(loss_tensor, 0) kl_tensor = tf.concat(kl_tensor, 0) recon_tensor = tf.concat(recon_tensor, 0) utils.save_tensors( zip([ra_loss, ra_kl, ra_recon], [loss_tensor, kl_tensor, recon_tensor]), params['tensordir'])
def test_classifier(clf, itr_test, params, test_name): """Test a trained classifier.""" max_steps_test = params['max_steps_test'] run_avg_len = params['run_avg_len'] ra_loss = RALoss('loss', run_avg_len) ra_error = RALoss('error', run_avg_len) loss_tensor = [] err_tensor = [] label_tensor = [] preds_tensor = [] reprs_collection = [list() for l in range(params['n_layers'])] for i in range(max_steps_test): batch_x, batch_y = itr_test.next() loss, error, preds, reprs = clf.get_loss(batch_x, batch_y, return_preds=True) losses = zip([ra_loss, ra_error], [tf.reduce_mean(l) for l in (loss, error)]) utils.update_losses(losses) utils.print_losses([l[0] for l in losses], i) loss_tensor.append(loss) err_tensor.append(error) preds_tensor.append(preds) label_tensor.append(batch_y) for l in range(params['n_layers']): reprs_collection[l].append(reprs[l]) loss_tensor = tf.concat(loss_tensor, 0) err_tensor = tf.concat(err_tensor, 0) preds_tensor = tf.concat(preds_tensor, 0) label_tensor = tf.concat(label_tensor, 0) for i in range(params['n_layers']): reprs_collection[i] = tf.concat(reprs_collection[i], 0) utils.save_tensors( zip([ra_loss, ra_error, RALoss('preds', 1), RALoss('labels', 1)], [loss_tensor, err_tensor, preds_tensor, label_tensor]), params['tensordir']) utils.save_tensors( zip([ RALoss('repr_{:d}'.format(l), 1) for l in range(params['n_layers']) ], reprs_collection), params['tensordir']) utils.write_metrics( zip(['loss', 'error'], [np.mean(loss_tensor), np.mean(err_tensor)]), params['logdir'], test_name)
def train_vae(vae, itr_train, itr_valid, params): """Train a VAE. Args: vae (VAE): a VAE. itr_train (Iterator): an iterator over training data. itr_valid (Iterator): an iterator over validation data. params (dict): flags for training. """ run_avg_len = params['run_avg_len'] max_steps = params['max_steps'] print_freq = params['print_freq'] # RALoss is an object which tracks the running average of a loss. ra_loss = RALoss('elbo', run_avg_len) ra_kl = RALoss('kl', run_avg_len) ra_recon = RALoss('recon', run_avg_len) ra_trainloss = RALoss('train-elbo', run_avg_len) min_val_loss = sys.maxsize min_val_step = 0 opt = tf.train.AdamOptimizer(learning_rate=params['lr']) finished_training = False start_printing = 0 for i in range(max_steps): batch = itr_train.next() with tf.GradientTape() as tape: train_loss, _, _ = vae.get_loss(batch) mean_train_loss = tf.reduce_mean(train_loss) val_batch = itr_valid.next() valid_loss, kl_loss, recon_loss = vae.get_loss(val_batch) loss_list = [ra_loss, ra_kl, ra_recon, ra_trainloss] losses = zip(loss_list, [ tf.reduce_mean(l) for l in (valid_loss, kl_loss, recon_loss, train_loss) ]) utils.update_losses(losses) grads = tape.gradient(mean_train_loss, vae.weights) opt.apply_gradients(zip(grads, vae.weights)) curr_ra_loss = ra_loss.get_value() # Early stopping: stop training when validation loss stops decreasing. # The second condition ensures we don't checkpoint every step early on. if curr_ra_loss < min_val_loss and \ i - min_val_step > params['patience'] / 10: min_val_loss = curr_ra_loss min_val_step = i save_path, ckpt = utils.checkpoint_model(vae, params['ckptdir']) logging.info('Step {:d}: Checkpointed to {}'.format(i, save_path)) elif i - min_val_step > params['patience'] or i == max_steps - 1: ckpt.restore(save_path) logging.info('Best validation loss was {:.3f} at step {:d}' ' - stopping training'.format(min_val_loss, min_val_step)) finished_training = True if i % print_freq == 0 or finished_training: utils.print_losses(loss_list, i) utils.write_losses_to_log(loss_list, range(start_printing, i + 1), params['logdir']) start_printing = i + 1 utils.plot_losses(params['figdir'], loss_list, params['mpl_format']) utils.plot_samples(params['sampledir'], vae, itr_valid, params['mpl_format']) if finished_training: break
def train_classifier(clf, itr_train, itr_valid, params): """Train a classifier. Args: clf (classifier): a classifier we wish to train. itr_train (Iterator): an iterator over training data. itr_valid (Iterator): an iterator over validation data. params (dict): flags for training. """ # Dump the parameters we used to a JSON file for reproducibility. params_file = os.path.join(params['results_dir'], 'params.json') utils.write_json(params_file, params) run_avg_len = params['run_avg_len'] max_steps = params['max_steps'] print_freq = params['print_freq'] # RALoss is an object which tracks the running average of a loss. ra_loss = RALoss('loss', run_avg_len) ra_error = RALoss('error', run_avg_len) ra_trainloss = RALoss('train-loss', run_avg_len) ra_trainerr = RALoss('train-err', run_avg_len) min_val_loss = sys.maxsize min_val_step = 0 opt = tf.train.AdamOptimizer(learning_rate=params['lr']) finished_training = False start_printing = 0 for i in range(max_steps): batch_x, batch_y = itr_train.next() with tf.GradientTape() as tape: train_loss, train_err = clf.get_loss(batch_x, batch_y) mean_train_loss = tf.reduce_mean(train_loss) val_batch_x, val_batch_y = itr_valid.next() valid_loss, valid_err = clf.get_loss(val_batch_x, val_batch_y) loss_list = [ra_loss, ra_error, ra_trainloss, ra_trainerr] losses = zip(loss_list, [tf.reduce_mean(l) for l in (valid_loss, valid_err, train_loss, train_err)]) utils.update_losses(losses) grads = tape.gradient(mean_train_loss, clf.weights) opt.apply_gradients(zip(grads, clf.weights)) utils.print_losses(loss_list, i) if params['early_stopping_metric'] == 'loss': curr_ra_loss = ra_loss.get_value() elif params['early_stopping_metric'] == 'error': curr_ra_loss = ra_error.get_value() else: raise ValueError('Params["early_stopping_metric"] should be either "loss"' ' or "error", and it is "{}"'.format( params['early_stopping_metric'])) if curr_ra_loss < min_val_loss and \ i - min_val_step > params['patience'] / 10: # Early stopping: stop training when validation loss stops decreasing. # The second condition ensures we don't checkpoint every step early on. min_val_loss = curr_ra_loss min_val_step = i save_path, ckpt = utils.checkpoint_model(clf, params['ckptdir']) logging.info('Step {:d}: Checkpointed to {}'.format(i, save_path)) elif i - min_val_step > params['patience'] or i == max_steps - 1: ckpt.restore(save_path) finished_training = True logging.info('Best validation loss was {:.3f} at step {:d}' ' - stopping training'.format(min_val_loss, min_val_step)) if i % print_freq == 0 or finished_training: utils.write_losses_to_log(loss_list, range(start_printing, i + 1), params['logdir']) start_printing = i + 1 utils.plot_losses(params['figdir'], loss_list, params['mpl_format']) logging.info('Step {:d}: Wrote losses and plots'.format(i)) if finished_training: break