def test_get_value(self):
    """RALoss keeps the full history and averages over a trailing window."""
    window = 3
    tracker = RALoss('name', window)

    values = list(range(8))
    for value in values:
        tracker.update(value)

    # Every update is kept, not just the last `window` of them.
    self.assertEqual(values, tracker.get_history())
    # Default window ends at the newest value: mean(5, 6, 7) == 6.
    self.assertEqual(6.0, tracker.get_value())
    # Explicit end index; 3.0 is consistent with mean(2, 3, 4).
    self.assertEqual(3.0, tracker.get_value(i=4))
Example #2
0
def train_vae(vae, itr_train, itr_valid, params):
    """Train a VAE, stopping early when the running validation loss stalls.

    Each step takes one optimizer step on a training batch and evaluates one
    validation batch. Running averages (RALoss) are kept for:

    - the validation loss ('elbo'),
    - its KL and reconstruction terms,
    - the training loss.

    A checkpoint is written whenever the running validation loss reaches a new
    minimum. Successive checkpoints are more than ``patience / 10`` steps apart.
    Training stops when more than ``patience`` steps have passed since the
    last checkpoint (or since step 0), or at the final step. At that point the
    last checkpoint, if one was ever written, is restored.

    Args:
      vae (VAE): a VAE.
      itr_train (Iterator): an iterator over training data.
      itr_valid (Iterator): an iterator over validation data.
      params (dict): flags for training. Keys read here: 'run_avg_len',
        'max_steps', 'print_freq', 'lr', 'patience', 'ckptdir', 'logdir',
        'figdir', 'sampledir', 'mpl_format'.
    """
    run_avg_len = params['run_avg_len']
    max_steps = params['max_steps']
    print_freq = params['print_freq']
    patience = params['patience']

    # RALoss is an object which tracks the running average of a loss.
    ra_loss = RALoss('elbo', run_avg_len)
    ra_kl = RALoss('kl', run_avg_len)
    ra_recon = RALoss('recon', run_avg_len)
    ra_trainloss = RALoss('train-elbo', run_avg_len)
    loss_list = [ra_loss, ra_kl, ra_recon, ra_trainloss]

    # Any finite running loss counts as an improvement on the first check.
    min_val_loss = float('inf')
    min_val_step = 0
    # Bound by the first checkpoint. They stay None if no checkpoint is ever
    # written, and then there is nothing to restore when training stops.
    save_path, ckpt = None, None
    opt = tf.train.AdamOptimizer(learning_rate=params['lr'])
    finished_training = False
    start_printing = 0
    for i in range(max_steps):
        is_last_step = i == max_steps - 1

        batch = itr_train.next()
        with tf.GradientTape() as tape:
            train_loss, _, _ = vae.get_loss(batch)
            mean_train_loss = tf.reduce_mean(train_loss)

        # Evaluated outside the tape: only the training loss is differentiated.
        val_batch = itr_valid.next()
        valid_loss, kl_loss, recon_loss = vae.get_loss(val_batch)
        losses = zip(loss_list, [
            tf.reduce_mean(t)
            for t in (valid_loss, kl_loss, recon_loss, train_loss)
        ])
        utils.update_losses(losses)

        grads = tape.gradient(mean_train_loss, vae.weights)
        opt.apply_gradients(zip(grads, vae.weights))

        curr_ra_loss = ra_loss.get_value()
        # Early stopping: stop training when validation loss stops decreasing.
        # The second condition ensures we don't checkpoint every step early on.
        if curr_ra_loss < min_val_loss and i - min_val_step > patience / 10:
            min_val_loss = curr_ra_loss
            min_val_step = i
            save_path, ckpt = utils.checkpoint_model(vae, params['ckptdir'])
            logging.info('Step {:d}: Checkpointed to {}'.format(i, save_path))
        elif i - min_val_step > patience or is_last_step:
            if ckpt is None:
                # E.g. max_steps - 1 <= patience / 10, or the running loss was
                # NaN (NaN < x is always False), so no checkpoint exists.
                logging.warning('Step {:d}: no checkpoint was saved - stopping '
                                'training with the current weights'.format(i))
            else:
                ckpt.restore(save_path)
                logging.info('Best validation loss was {:.3f} at step {:d}'
                             ' - stopping training'.format(min_val_loss,
                                                           min_val_step))
            finished_training = True

        # Also flush on the last step. If that step checkpointed,
        # finished_training is still False, and the losses since the previous
        # print would otherwise never be logged.
        if i % print_freq == 0 or finished_training or is_last_step:
            utils.print_losses(loss_list, i)
            utils.write_losses_to_log(loss_list, range(start_printing, i + 1),
                                      params['logdir'])
            start_printing = i + 1
            utils.plot_losses(params['figdir'], loss_list,
                              params['mpl_format'])
            utils.plot_samples(params['sampledir'], vae, itr_valid,
                               params['mpl_format'])

        if finished_training:
            break
Example #3
0
def train_classifier(clf, itr_train, itr_valid, params):
  """Train a classifier, stopping early when a validation metric stalls.

  Each step takes one optimizer step on a training batch and evaluates one
  validation batch. Running averages (RALoss) are kept for the validation
  loss/error and the training loss/error. The running validation metric
  selected by ``params['early_stopping_metric']`` ('loss' or 'error') drives
  checkpointing: a checkpoint is written on each new minimum, with successive
  checkpoints more than ``patience / 10`` steps apart. Training stops when more
  than ``patience`` steps have passed since the last checkpoint (or since step
  0), or at the final step. At that point the last checkpoint, if any, is
  restored.

  Args:
    clf (classifier): a classifier we wish to train.
    itr_train (Iterator): an iterator over training data.
    itr_valid (Iterator): an iterator over validation data.
    params (dict): flags for training. Keys read here: 'results_dir',
      'early_stopping_metric', 'run_avg_len', 'max_steps', 'print_freq',
      'lr', 'patience', 'ckptdir', 'logdir', 'figdir', 'mpl_format'.

  Raises:
    ValueError: if params['early_stopping_metric'] is not 'loss' or 'error'.
  """
  # The metric never changes between steps, so reject a bad value before any
  # file is written or any training step is taken.
  metric_name = params['early_stopping_metric']
  if metric_name not in ('loss', 'error'):
    raise ValueError('Params["early_stopping_metric"] should be either "loss"'
                     ' or "error", and it is "{}"'.format(metric_name))

  # Dump the parameters we used to a JSON file for reproducibility.
  params_file = os.path.join(params['results_dir'], 'params.json')
  utils.write_json(params_file, params)

  run_avg_len = params['run_avg_len']
  max_steps = params['max_steps']
  print_freq = params['print_freq']
  patience = params['patience']

  # RALoss is an object which tracks the running average of a loss.
  ra_loss = RALoss('loss', run_avg_len)
  ra_error = RALoss('error', run_avg_len)
  ra_trainloss = RALoss('train-loss', run_avg_len)
  ra_trainerr = RALoss('train-err', run_avg_len)
  loss_list = [ra_loss, ra_error, ra_trainloss, ra_trainerr]
  ra_metric = ra_loss if metric_name == 'loss' else ra_error

  # Any finite running metric counts as an improvement on the first check.
  min_val_loss = float('inf')
  min_val_step = 0
  # Bound by the first checkpoint. They stay None if no checkpoint is ever
  # written, and then there is nothing to restore when training stops.
  save_path, ckpt = None, None
  opt = tf.train.AdamOptimizer(learning_rate=params['lr'])
  finished_training = False
  start_printing = 0
  for i in range(max_steps):
    is_last_step = i == max_steps - 1

    batch_x, batch_y = itr_train.next()
    with tf.GradientTape() as tape:
      train_loss, train_err = clf.get_loss(batch_x, batch_y)
      mean_train_loss = tf.reduce_mean(train_loss)

    # Evaluated outside the tape: only the training loss is differentiated.
    val_batch_x, val_batch_y = itr_valid.next()
    valid_loss, valid_err = clf.get_loss(val_batch_x, val_batch_y)
    losses = zip(loss_list,
                 [tf.reduce_mean(t) for t in
                  (valid_loss, valid_err, train_loss, train_err)])
    utils.update_losses(losses)

    grads = tape.gradient(mean_train_loss, clf.weights)
    opt.apply_gradients(zip(grads, clf.weights))

    utils.print_losses(loss_list, i)
    curr_ra_loss = ra_metric.get_value()
    if curr_ra_loss < min_val_loss and i - min_val_step > patience / 10:
      # Early stopping: stop training when validation loss stops decreasing.
      # The second condition ensures we don't checkpoint every step early on.
      min_val_loss = curr_ra_loss
      min_val_step = i
      save_path, ckpt = utils.checkpoint_model(clf, params['ckptdir'])
      logging.info('Step {:d}: Checkpointed to {}'.format(i, save_path))
    elif i - min_val_step > patience or is_last_step:
      if ckpt is None:
        # E.g. max_steps - 1 <= patience / 10, or the running metric was NaN
        # (NaN < x is always False), so no checkpoint exists.
        logging.warning('Step {:d}: no checkpoint was saved - stopping '
                        'training with the current weights'.format(i))
      else:
        ckpt.restore(save_path)
        logging.info('Best validation loss was {:.3f} at step {:d}'
                     ' - stopping training'.format(min_val_loss, min_val_step))
      finished_training = True

    # Also flush on the last step. If that step checkpointed, finished_training
    # is still False, and the losses since the previous write would be lost.
    if i % print_freq == 0 or finished_training or is_last_step:
      utils.write_losses_to_log(loss_list, range(start_printing, i + 1),
                                params['logdir'])
      start_printing = i + 1
      utils.plot_losses(params['figdir'], loss_list, params['mpl_format'])
      logging.info('Step {:d}: Wrote losses and plots'.format(i))

    if finished_training:
      break