def test_get_value(self):
  l = RALoss('name', 3)
  dat = [0, 1, 2, 3, 4, 5, 6, 7]
  for x in dat:
    l.update(x)
  self.assertEqual(dat, l.get_history())
  self.assertEqual(6.0, l.get_value())
  self.assertEqual(3.0, l.get_value(i=4))
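# The RALoss class exercised above is not defined in this section. The sketch
# below is a minimal assumed implementation that is consistent with the test:
# it keeps the full history of values and returns the trailing running average
# over a fixed-length window. The real class may differ.
class RALoss(object):
  """Tracks the running average of a loss over a fixed-length window."""

  def __init__(self, name, run_avg_len):
    self.name = name
    self.run_avg_len = run_avg_len
    self.history = []

  def update(self, value):
    """Appends a new loss value to the history."""
    self.history.append(value)

  def get_history(self):
    """Returns every value seen so far."""
    return self.history

  def get_value(self, i=None):
    """Returns the running average ending at index i (defaults to the last)."""
    end = len(self.history) if i is None else i + 1
    window = self.history[max(0, end - self.run_avg_len):end]
    return float(sum(window)) / len(window)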
def train_vae(vae, itr_train, itr_valid, params):
  """Train a VAE.

  Args:
    vae (VAE): a VAE.
    itr_train (Iterator): an iterator over training data.
    itr_valid (Iterator): an iterator over validation data.
    params (dict): flags for training.
  """
  run_avg_len = params['run_avg_len']
  max_steps = params['max_steps']
  print_freq = params['print_freq']

  # RALoss is an object which tracks the running average of a loss.
  ra_loss = RALoss('elbo', run_avg_len)
  ra_kl = RALoss('kl', run_avg_len)
  ra_recon = RALoss('recon', run_avg_len)
  ra_trainloss = RALoss('train-elbo', run_avg_len)

  min_val_loss = sys.maxsize
  min_val_step = 0
  opt = tf.train.AdamOptimizer(learning_rate=params['lr'])
  finished_training = False
  start_printing = 0

  for i in range(max_steps):
    batch = itr_train.next()
    with tf.GradientTape() as tape:
      train_loss, _, _ = vae.get_loss(batch)
      mean_train_loss = tf.reduce_mean(train_loss)

    val_batch = itr_valid.next()
    valid_loss, kl_loss, recon_loss = vae.get_loss(val_batch)

    loss_list = [ra_loss, ra_kl, ra_recon, ra_trainloss]
    losses = zip(loss_list, [
        tf.reduce_mean(l)
        for l in (valid_loss, kl_loss, recon_loss, train_loss)
    ])
    utils.update_losses(losses)

    grads = tape.gradient(mean_train_loss, vae.weights)
    opt.apply_gradients(zip(grads, vae.weights))

    curr_ra_loss = ra_loss.get_value()
    # Early stopping: stop training when validation loss stops decreasing.
    # The second condition ensures we don't checkpoint every step early on.
    if curr_ra_loss < min_val_loss and \
        i - min_val_step > params['patience'] / 10:
      min_val_loss = curr_ra_loss
      min_val_step = i
      save_path, ckpt = utils.checkpoint_model(vae, params['ckptdir'])
      logging.info('Step {:d}: Checkpointed to {}'.format(i, save_path))
    elif i - min_val_step > params['patience'] or i == max_steps - 1:
      ckpt.restore(save_path)
      logging.info('Best validation loss was {:.3f} at step {:d}'
                   ' - stopping training'.format(min_val_loss, min_val_step))
      finished_training = True

    if i % print_freq == 0 or finished_training:
      utils.print_losses(loss_list, i)
      utils.write_losses_to_log(loss_list, range(start_printing, i + 1),
                                params['logdir'])
      start_printing = i + 1
      utils.plot_losses(params['figdir'], loss_list, params['mpl_format'])
      utils.plot_samples(params['sampledir'], vae, itr_valid,
                         params['mpl_format'])

    if finished_training:
      break
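# Illustrative call to train_vae. The dictionary keys are exactly the ones
# train_vae reads from `params`; the values and paths below are placeholder
# assumptions, and `vae`, `itr_train` and `itr_valid` are assumed to be
# constructed elsewhere in the project.
vae_params = {
    'run_avg_len': 100,               # window for the running-average losses
    'max_steps': 10000,               # maximum number of training steps
    'print_freq': 500,                # how often to log and plot losses
    'lr': 1e-3,                       # Adam learning rate
    'patience': 1000,                 # early-stopping patience, in steps
    'ckptdir': '/tmp/vae/ckpt',       # where checkpoints are written
    'logdir': '/tmp/vae/logs',        # where loss logs are written
    'figdir': '/tmp/vae/figs',        # where loss plots are written
    'sampledir': '/tmp/vae/samples',  # where VAE samples are plotted
    'mpl_format': 'png',              # matplotlib figure format
}
train_vae(vae, itr_train, itr_valid, vae_params)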
def train_classifier(clf, itr_train, itr_valid, params):
  """Train a classifier.

  Args:
    clf (classifier): a classifier we wish to train.
    itr_train (Iterator): an iterator over training data.
    itr_valid (Iterator): an iterator over validation data.
    params (dict): flags for training.
  """
  # Dump the parameters we used to a JSON file for reproducibility.
  params_file = os.path.join(params['results_dir'], 'params.json')
  utils.write_json(params_file, params)

  run_avg_len = params['run_avg_len']
  max_steps = params['max_steps']
  print_freq = params['print_freq']

  # RALoss is an object which tracks the running average of a loss.
  ra_loss = RALoss('loss', run_avg_len)
  ra_error = RALoss('error', run_avg_len)
  ra_trainloss = RALoss('train-loss', run_avg_len)
  ra_trainerr = RALoss('train-err', run_avg_len)

  min_val_loss = sys.maxsize
  min_val_step = 0
  opt = tf.train.AdamOptimizer(learning_rate=params['lr'])
  finished_training = False
  start_printing = 0

  for i in range(max_steps):
    batch_x, batch_y = itr_train.next()
    with tf.GradientTape() as tape:
      train_loss, train_err = clf.get_loss(batch_x, batch_y)
      mean_train_loss = tf.reduce_mean(train_loss)

    val_batch_x, val_batch_y = itr_valid.next()
    valid_loss, valid_err = clf.get_loss(val_batch_x, val_batch_y)

    loss_list = [ra_loss, ra_error, ra_trainloss, ra_trainerr]
    losses = zip(loss_list,
                 [tf.reduce_mean(l)
                  for l in (valid_loss, valid_err, train_loss, train_err)])
    utils.update_losses(losses)

    grads = tape.gradient(mean_train_loss, clf.weights)
    opt.apply_gradients(zip(grads, clf.weights))

    utils.print_losses(loss_list, i)

    if params['early_stopping_metric'] == 'loss':
      curr_ra_loss = ra_loss.get_value()
    elif params['early_stopping_metric'] == 'error':
      curr_ra_loss = ra_error.get_value()
    else:
      raise ValueError('params["early_stopping_metric"] should be either'
                       ' "loss" or "error", but it is "{}"'.format(
                           params['early_stopping_metric']))

    # Early stopping: stop training when validation loss stops decreasing.
    # The second condition ensures we don't checkpoint every step early on.
    if curr_ra_loss < min_val_loss and \
        i - min_val_step > params['patience'] / 10:
      min_val_loss = curr_ra_loss
      min_val_step = i
      save_path, ckpt = utils.checkpoint_model(clf, params['ckptdir'])
      logging.info('Step {:d}: Checkpointed to {}'.format(i, save_path))
    elif i - min_val_step > params['patience'] or i == max_steps - 1:
      ckpt.restore(save_path)
      finished_training = True
      logging.info('Best validation loss was {:.3f} at step {:d}'
                   ' - stopping training'.format(min_val_loss, min_val_step))

    if i % print_freq == 0 or finished_training:
      utils.write_losses_to_log(loss_list, range(start_printing, i + 1),
                                params['logdir'])
      start_printing = i + 1
      utils.plot_losses(params['figdir'], loss_list, params['mpl_format'])
      logging.info('Step {:d}: Wrote losses and plots'.format(i))

    if finished_training:
      break
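# Illustrative call to train_classifier. The keys mirror the ones the function
# reads from `params` (note the extra 'results_dir' and 'early_stopping_metric'
# compared to train_vae); values and paths are placeholder assumptions, and
# `clf` and the iterators are assumed to come from elsewhere in the project.
clf_params = {
    'results_dir': '/tmp/clf',        # where params.json is dumped
    'run_avg_len': 100,               # window for the running-average losses
    'max_steps': 10000,               # maximum number of training steps
    'print_freq': 500,                # how often to write logs and plots
    'lr': 1e-3,                       # Adam learning rate
    'patience': 1000,                 # early-stopping patience, in steps
    'early_stopping_metric': 'loss',  # 'loss' or 'error'
    'ckptdir': '/tmp/clf/ckpt',       # where checkpoints are written
    'logdir': '/tmp/clf/logs',        # where loss logs are written
    'figdir': '/tmp/clf/figs',        # where loss plots are written
    'mpl_format': 'png',              # matplotlib figure format
}
train_classifier(clf, itr_train, itr_valid, clf_params)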