def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    params = FLAGS.flag_values_dict()
    tf.set_random_seed(params['seed'])

    params['results_dir'] = utils.make_subdir(params['training_results_dir'],
                                              params['expname'])
    params['figdir'] = utils.make_subdir(params['results_dir'], 'figs')
    params['ckptdir'] = utils.make_subdir(params['results_dir'], 'ckpts')
    params['logdir'] = utils.make_subdir(params['results_dir'], 'logs')
    params['tensordir'] = utils.make_subdir(params['results_dir'], 'tensors')

    conv_dims = [int(x) for x in params['conv_dims'].split(',')]
    conv_sizes = [int(x) for x in params['conv_sizes'].split(',')]
    dense_sizes = [int(x) for x in params['dense_sizes'].split(',')]
    params['n_layers'] = len(conv_dims)
    clf = classifier.CNN(conv_dims,
                         conv_sizes,
                         dense_sizes,
                         params['n_classes'],
                         onehot=True)

    utils.checkpoint_model(clf, params['ckptdir'], 'initmodel')

    itr_train, itr_valid, itr_test = dataset_utils.load_dset_supervised_onehot()

    train_cnn.train_classifier(clf, itr_train, itr_valid, params)
    train_cnn.test_classifier(clf, itr_test, params, 'test')

    train_data = utils.aggregate_batches(itr_train, 1000,
                                         ['train_x_infl', 'train_y_infl'])

    validation_data = utils.aggregate_batches(itr_valid, 1000,
                                              ['valid_x_infl', 'valid_y_infl'])

    test_data = utils.aggregate_batches(itr_test, 1000,
                                        ['test_x_infl', 'test_y_infl'])

    # dict.items() views can't be concatenated with '+', so wrap them in lists.
    utils.save_tensors(
        list(train_data.items()) + list(validation_data.items()) +
        list(test_data.items()), params['tensordir'])
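
For context, utils.aggregate_batches is not shown in this snippet. A minimal sketch of the behavior assumed above (draw batches from the iterator until n examples are collected, then return a dict of stacked tensors keyed by the given names) might look like the following; the name and signature are assumptions, not the project's actual implementation.

import tensorflow as tf


def aggregate_batches_sketch(itr, n_examples, names):
    """Illustrative only: collect ~n_examples from itr into named tensors."""
    collected = [[] for _ in names]
    total = 0
    while total < n_examples:
        batch = itr.next()  # assumed to yield one tensor per name
        for parts, tensor in zip(collected, batch):
            parts.append(tensor)
        total += int(batch[0].shape[0])
    # Stack the pieces and truncate to exactly n_examples per name.
    return {name: tf.concat(parts, axis=0)[:n_examples]
            for name, parts in zip(names, collected)}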
Example #2
def train_vae(vae, itr_train, itr_valid, params):
    """Train a VAE.

  Args:
    vae (VAE): a VAE.
    itr_train (Iterator): an iterator over training data.
    itr_valid (Iterator): an iterator over validation data.
    params (dict): flags for training.

  """
    run_avg_len = params['run_avg_len']
    max_steps = params['max_steps']
    print_freq = params['print_freq']

    # RALoss is an object which tracks the running average of a loss.
    ra_loss = RALoss('elbo', run_avg_len)
    ra_kl = RALoss('kl', run_avg_len)
    ra_recon = RALoss('recon', run_avg_len)
    ra_trainloss = RALoss('train-elbo', run_avg_len)

    min_val_loss = sys.maxsize
    min_val_step = 0
    # Initialized here so the early-stopping branch below can never reference
    # them before the first checkpoint is written.
    save_path, ckpt = None, None
    opt = tf.train.AdamOptimizer(learning_rate=params['lr'])
    finished_training = False
    start_printing = 0
    for i in range(max_steps):
        batch = itr_train.next()
        with tf.GradientTape() as tape:
            train_loss, _, _ = vae.get_loss(batch)
            mean_train_loss = tf.reduce_mean(train_loss)

        val_batch = itr_valid.next()
        valid_loss, kl_loss, recon_loss = vae.get_loss(val_batch)
        loss_list = [ra_loss, ra_kl, ra_recon, ra_trainloss]
        losses = zip(loss_list, [
            tf.reduce_mean(l)
            for l in (valid_loss, kl_loss, recon_loss, train_loss)
        ])
        utils.update_losses(losses)

        grads = tape.gradient(mean_train_loss, vae.weights)
        opt.apply_gradients(zip(grads, vae.weights))

        curr_ra_loss = ra_loss.get_value()
        # Early stopping: stop training when validation loss stops decreasing.
        # The second condition ensures we don't checkpoint every step early on.
        if curr_ra_loss < min_val_loss and \
            i - min_val_step > params['patience'] / 10:
            min_val_loss = curr_ra_loss
            min_val_step = i
            save_path, ckpt = utils.checkpoint_model(vae, params['ckptdir'])
            logging.info('Step {:d}: Checkpointed to {}'.format(i, save_path))
        elif i - min_val_step > params['patience'] or i == max_steps - 1:
            if ckpt is not None:
                ckpt.restore(save_path)
            logging.info('Best validation loss was {:.3f} at step {:d}'
                         ' - stopping training'.format(min_val_loss,
                                                       min_val_step))
            finished_training = True

        if i % print_freq == 0 or finished_training:
            utils.print_losses(loss_list, i)
            utils.write_losses_to_log(loss_list, range(start_printing, i + 1),
                                      params['logdir'])
            start_printing = i + 1
            utils.plot_losses(params['figdir'], loss_list,
                              params['mpl_format'])
            utils.plot_samples(params['sampledir'], vae, itr_valid,
                               params['mpl_format'])

        if finished_training:
            break
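
RALoss is defined elsewhere in the project. As a rough sketch of the interface used above (a named running average over the last run_avg_len values, updated via utils.update_losses and read with get_value), something like the class below would suffice; this is an assumption for illustration, not the actual implementation.

import collections


class RALossSketch(object):
    """Illustrative running-average loss tracker."""

    def __init__(self, name, run_avg_len):
        self.name = name
        # Keep only the most recent run_avg_len values.
        self._window = collections.deque(maxlen=run_avg_len)

    def update(self, value):
        self._window.append(float(value))

    def get_value(self):
        # Average over the current window; 0.0 before any update.
        if not self._window:
            return 0.0
        return sum(self._window) / len(self._window)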
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    params = FLAGS.flag_values_dict()
    tf.set_random_seed(params['seed'])

    params['results_dir'] = utils.make_subdir(params['training_results_dir'],
                                              params['expname'])
    params['figdir'] = utils.make_subdir(params['results_dir'], 'figs')
    params['ckptdir'] = utils.make_subdir(params['results_dir'], 'ckpts')
    params['logdir'] = utils.make_subdir(params['results_dir'], 'logs')
    params['tensordir'] = utils.make_subdir(params['results_dir'], 'tensors')

    # Build the classification model.
    conv_dims = [
        int(x) for x in (
            params['conv_dims'].split(',') if params['conv_dims'] else [])
    ]
    conv_sizes = [
        int(x) for x in (
            params['conv_sizes'].split(',') if params['conv_sizes'] else [])
    ]
    dense_sizes = [
        int(x) for x in (
            params['dense_sizes'].split(',') if params['dense_sizes'] else [])
    ]
    params['n_layers'] = len(conv_dims)
    clf = classifier.CNN(conv_dims,
                         conv_sizes,
                         dense_sizes,
                         params['n_classes'],
                         onehot=True)

    # Checkpoint the initialized model, in case we want to re-run it from there.
    utils.checkpoint_model(clf, params['ckptdir'], 'initmodel')

    # Load the "in-distribution" and "out-of-distribution" classes as
    # separate splits.
    ood_classes = [int(x) for x in params['ood_classes'].split(',')]
    # We assume we train on all non-OOD classes.
    all_classes = range(params['n_classes'])
    ind_classes = [x for x in all_classes if x not in ood_classes]
    (itr_train, itr_valid, itr_test,
     itr_test_ood) = dataset_utils.load_dset_ood_supervised_onehot(
         ind_classes,
         ood_classes,
         label_noise=params['label_noise'],
         dset_name=params['dataset_name'])
    # Train and test the model in-distribution, and save test outputs.
    train_cnn.train_classifier(clf, itr_train, itr_valid, params)
    train_cnn.test_classifier(clf, itr_test, params, 'test')

    # Save model outputs on the training set.
    params['tensordir'] = utils.make_subdir(params['results_dir'],
                                            'train_tensors')
    train_cnn.test_classifier(clf, itr_train, params, 'train')

    # Save model outputs on the OOD set.
    params['tensordir'] = utils.make_subdir(params['results_dir'],
                                            'ood_tensors')
    train_cnn.test_classifier(clf, itr_test_ood, params, 'ood')

    params['tensordir'] = utils.make_subdir(params['results_dir'], 'tensors')

    # Save to disk samples of size 1000 from the train, valid, test and OOD sets.
    train_data = utils.aggregate_batches(itr_train, 1000,
                                         ['train_x_infl', 'train_y_infl'])

    validation_data = utils.aggregate_batches(itr_valid, 1000,
                                              ['valid_x_infl', 'valid_y_infl'])

    test_data = utils.aggregate_batches(itr_test, 1000,
                                        ['test_x_infl', 'test_y_infl'])

    ood_data = utils.aggregate_batches(itr_test_ood, 1000,
                                       ['ood_x_infl', 'ood_y_infl'])
    # dict.items() views can't be concatenated with '+', so wrap them in lists.
    utils.save_tensors(
        list(train_data.items()) + list(validation_data.items()) +
        list(test_data.items()) + list(ood_data.items()),
        params['tensordir'])
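
load_dset_ood_supervised_onehot is project-specific. As a hypothetical sketch of the class-based split it is assumed to perform, a tf.data pipeline can be filtered by label membership before one-hot encoding; the helper below is illustrative only and ignores the label noise and train/valid/test splitting handled by the real loader.

import tensorflow as tf


def split_by_class_sketch(dset, ind_classes, ood_classes):
    """Illustrative only: split (x, y) examples with integer labels y."""
    ind_table = tf.constant(ind_classes, dtype=tf.int64)
    ood_table = tf.constant(ood_classes, dtype=tf.int64)

    def is_ind(x, y):
        return tf.reduce_any(tf.equal(y, ind_table))

    def is_ood(x, y):
        return tf.reduce_any(tf.equal(y, ood_table))

    return dset.filter(is_ind), dset.filter(is_ood)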
Example #4
def train_classifier(clf, itr_train, itr_valid, params):
  """Train a classifier.

  Args:
    clf (classifier): a classifier we wish to train.
    itr_train (Iterator): an iterator over training data.
    itr_valid (Iterator): an iterator over validation data.
    params (dict): flags for training.

  """
  # Dump the parameters we used to a JSON file for reproducibility.
  params_file = os.path.join(params['results_dir'], 'params.json')
  utils.write_json(params_file, params)

  run_avg_len = params['run_avg_len']
  max_steps = params['max_steps']
  print_freq = params['print_freq']

  # RALoss is an object which tracks the running average of a loss.
  ra_loss = RALoss('loss', run_avg_len)
  ra_error = RALoss('error', run_avg_len)
  ra_trainloss = RALoss('train-loss', run_avg_len)
  ra_trainerr = RALoss('train-err', run_avg_len)

  min_val_loss = sys.maxsize
  min_val_step = 0
  # Initialized here so the early-stopping branch below can never reference
  # them before the first checkpoint is written.
  save_path, ckpt = None, None
  opt = tf.train.AdamOptimizer(learning_rate=params['lr'])
  finished_training = False
  start_printing = 0
  for i in range(max_steps):
    batch_x, batch_y = itr_train.next()
    with tf.GradientTape() as tape:
      train_loss, train_err = clf.get_loss(batch_x, batch_y)
      mean_train_loss = tf.reduce_mean(train_loss)

    val_batch_x, val_batch_y = itr_valid.next()
    valid_loss, valid_err = clf.get_loss(val_batch_x, val_batch_y)
    loss_list = [ra_loss, ra_error, ra_trainloss, ra_trainerr]
    losses = zip(loss_list,
                 [tf.reduce_mean(l) for l in
                  (valid_loss, valid_err, train_loss, train_err)])
    utils.update_losses(losses)

    grads = tape.gradient(mean_train_loss, clf.weights)
    opt.apply_gradients(zip(grads, clf.weights))

    utils.print_losses(loss_list, i)
    if params['early_stopping_metric'] == 'loss':
      curr_ra_loss = ra_loss.get_value()
    elif params['early_stopping_metric'] == 'error':
      curr_ra_loss = ra_error.get_value()
    else:
      raise ValueError('params["early_stopping_metric"] must be either "loss"'
                       ' or "error", got "{}"'.format(
                           params['early_stopping_metric']))
    if curr_ra_loss < min_val_loss and \
        i - min_val_step > params['patience'] / 10:
      # Early stopping: stop training when validation loss stops decreasing.
      # The second condition ensures we don't checkpoint every step early on.
      min_val_loss = curr_ra_loss
      min_val_step = i
      save_path, ckpt = utils.checkpoint_model(clf, params['ckptdir'])
      logging.info('Step {:d}: Checkpointed to {}'.format(i, save_path))
    elif i - min_val_step > params['patience'] or i == max_steps - 1:
      if ckpt is not None:
        ckpt.restore(save_path)
      finished_training = True
      logging.info('Best validation loss was {:.3f} at step {:d}'
                   ' - stopping training'.format(min_val_loss, min_val_step))

    if i % print_freq == 0 or finished_training:
      utils.write_losses_to_log(loss_list, range(start_printing, i + 1),
                                params['logdir'])
      start_printing = i + 1
      utils.plot_losses(params['figdir'], loss_list, params['mpl_format'])
      logging.info('Step {:d}: Wrote losses and plots'.format(i))

    if finished_training:
      break
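
utils.checkpoint_model is also external to this snippet; its (save_path, ckpt) return value suggests a thin wrapper around tf.train.Checkpoint, sketched below under that assumption.

import os

import tensorflow as tf


def checkpoint_model_sketch(model, ckptdir, name='model'):
  """Illustrative only: save model weights and return (save_path, ckpt)."""
  ckpt = tf.train.Checkpoint(model=model)
  save_path = ckpt.save(os.path.join(ckptdir, name))
  return save_path, ckpt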