# Example 1
def train(sess,
          config,
          model,
          meta_dataset,
          mvalid=None,
          meta_val_dataset=None,
          log_results=True,
          run_eval=True,
          exp_id=None):
    """Run the meta-training loop and periodically evaluate/checkpoint.

    Args:
        sess: TF session used for every graph execution.
        config: Config object providing learn_rate, lr_decay_steps, lr_list,
            max_train_steps, steps_per_valid, steps_per_log, steps_per_save.
        model: Training model exposing the x_train/y_train/x_test/y_test
            placeholders (optionally x_unlabel) plus loss, prediction,
            train_op and a learning-rate variable driven by the scheduler.
        meta_dataset: Episodic training dataset; `.next()` yields an episode
            whose `.next_batch()` returns the feed batch.
        mvalid: Optional evaluation model. When None, periodic evaluation is
            skipped entirely (the original code crashed in that case because
            it called evaluate() with a None model before checking).
        meta_val_dataset: Optional validation dataset for `mvalid`.
        log_results: Whether to write ExperimentLogger logs under "logs/".
        run_eval: Whether to run periodic evaluation at all.
        exp_id: Experiment identifier; generated from `config` when None.

    Returns:
        The experiment id string actually used.
    """
    lr_scheduler = FixedLearnRateScheduler(sess,
                                           model,
                                           config.learn_rate,
                                           config.lr_decay_steps,
                                           lr_list=config.lr_list)

    if exp_id is None:
        exp_id = gen_id(config)

    saver = tf.train.Saver()
    save_folder = os.path.join(FLAGS.results, exp_id)
    save_config(config, save_folder)
    if log_results:
        logs_folder = os.path.join("logs", exp_id)
        exp_logger = ExperimentLogger(logs_folder)
    it = tqdm(six.moves.xrange(config.max_train_steps), desc=exp_id, ncols=0)

    trn_acc = 0.0
    val_acc = 0.0
    lr = lr_scheduler.lr
    for niter in it:
        lr_scheduler.step(niter)
        dataset = meta_dataset.next()
        batch = dataset.next_batch()
        batch = preprocess_batch(batch)

        feed_dict = {
            model.x_train: batch.x_train,
            model.y_train: batch.y_train,
            model.x_test: batch.x_test,
            model.y_test: batch.y_test
        }
        if hasattr(model, '_x_unlabel'):
            # Fall back to the test images when no unlabeled pool is provided,
            # so the placeholder is always fed.
            if batch.x_unlabel is not None:
                feed_dict[model.x_unlabel] = batch.x_unlabel
            else:
                feed_dict[model.x_unlabel] = batch.x_test

        loss_val, y_pred, _ = sess.run(
            [model.loss, model.prediction, model.train_op],
            feed_dict=feed_dict)

        # BUGFIX: evaluation requires an eval model; previously evaluate()
        # was invoked with mvalid=None before the None check.
        if (niter + 1) % config.steps_per_valid == 0 and run_eval \
                and mvalid is not None:
            train_results = evaluate(sess, mvalid, meta_dataset)
            tensorboard.add_scalar("Accuracy", train_results['acc'],
                                   (niter + 1) // config.steps_per_valid)
            tensorboard.add_scalar("Loss", loss_val,
                                   (niter + 1) // config.steps_per_valid)
            # Track stats regardless of logging so the tqdm postfix is fresh.
            lr = lr_scheduler.lr
            trn_acc = train_results['acc']
            if log_results:
                exp_logger.log_train_acc(niter, train_results['acc'])
                exp_logger.log_learn_rate(niter, lr_scheduler.lr)

            val_results = evaluate(sess, mvalid, meta_val_dataset)
            val_acc = val_results['acc']
            if log_results:
                exp_logger.log_valid_acc(niter, val_results['acc'])
                exp_logger.log_learn_rate(niter, lr_scheduler.lr)
            it.set_postfix()
            # BUGFIX: reset the validation dataset after every evaluation,
            # not only when log_results is True.
            meta_val_dataset.reset()

        if (niter + 1) % config.steps_per_log == 0 and log_results:
            exp_logger.log_train_ce(niter + 1, loss_val)
            it.set_postfix(ce='{:.3e}'.format(loss_val),
                           trn_acc='{:.3f}'.format(trn_acc * 100.0),
                           val_acc='{:.3f}'.format(val_acc * 100.0),
                           lr='{:.3e}'.format(lr))

        if (niter + 1) % config.steps_per_save == 0:
            save(sess, saver, niter, save_folder)
    return exp_id
# Example 2
def train(sess,
          config,
          model,
          meta_dataset,
          mvalid=None,
          meta_val_dataset=None,
          label_dataset=None,
          unlabel_dataset=None,
          log_results=True,
          summarize=True,
          run_eval=True,
          exp_id=None):
    """Run the meta-training loop with optional TF summaries.

    Extends the basic loop with a tf.summary.FileWriter, optional labeled /
    unlabeled auxiliary feeds, and periodic merged-summary dumps.

    Args:
        sess: TF session used for every graph execution.
        config: Config object providing learn_rate, lr_decay_steps, lr_list,
            max_train_steps, steps_per_valid, steps_per_log, steps_per_save.
        model: Training model exposing the x_train/y_train/x_test/y_test
            placeholders (optionally x_unlabel, training_data/labels,
            unlabeled_training_data) plus loss, prediction, train_op and
            merged summaries.
        meta_dataset: Episodic training dataset; `.next()` yields an episode
            whose `.next_batch()` returns the feed batch.
        mvalid: Optional evaluation model. When None, periodic evaluation is
            skipped entirely (the original code crashed in that case because
            it called evaluate() with a None model before checking).
        meta_val_dataset: Optional validation dataset for `mvalid`.
        label_dataset: Iterator of (x, y) labeled batches; required only when
            the model has a `training_data` attribute.
        unlabel_dataset: Iterator of unlabeled batches; required only when
            the model has an `unlabeled_training_data` attribute.
        log_results: Whether to write ExperimentLogger logs under "logs/".
        summarize: Whether to write merged summaries every
            FLAGS.steps_per_summary steps.
        run_eval: Whether to run periodic evaluation at all.
        exp_id: Experiment identifier; generated from `config` when None.

    Returns:
        The experiment id string actually used.
    """
    lr_scheduler = FixedLearnRateScheduler(sess,
                                           model,
                                           config.learn_rate,
                                           config.lr_decay_steps,
                                           lr_list=config.lr_list)

    if exp_id is None:
        exp_id = gen_id(config)

    saver = tf.train.Saver()
    save_folder = os.path.join(FLAGS.results, exp_id)
    save_config(config, save_folder)
    train_writer = tf.summary.FileWriter(os.path.join(save_folder, 'graph'),
                                         sess.graph)

    if log_results:
        logs_folder = os.path.join("logs", exp_id)
        exp_logger = ExperimentLogger(logs_folder)
    it = tqdm(six.moves.xrange(config.max_train_steps), desc=exp_id, ncols=0)

    trn_acc = 0.0
    val_acc = 0.0
    lr = lr_scheduler.lr
    for niter in it:
        with tf.name_scope('Lr-step'):
            lr_scheduler.step(niter)
        dataset = meta_dataset.next()
        batch = dataset.next_batch()
        batch = preprocess_batch(batch)

        feed_dict = {
            model.x_train: batch.x_train,
            model.y_train: batch.y_train,
            model.x_test: batch.x_test,
            model.y_test: batch.y_test
        }
        if hasattr(model, '_x_unlabel'):
            # Fall back to the test images when no unlabeled pool is provided,
            # so the placeholder is always fed.
            if batch.x_unlabel is not None:
                feed_dict[model.x_unlabel] = batch.x_unlabel
            else:
                feed_dict[model.x_unlabel] = batch.x_test
        if hasattr(model, 'training_data'):
            x, y = label_dataset.__next__()
            feed_dict[model.training_data] = x
            feed_dict[model.training_labels] = y
        if hasattr(model, 'unlabeled_training_data'):
            feed_dict[
                model.unlabeled_training_data] = unlabel_dataset.__next__()[0]

        if (niter + 1) % FLAGS.steps_per_summary == 0 and summarize:
            # Periodically also fetch the merged summaries for TensorBoard.
            loss_val, y_pred, _, summary, adv_summary = sess.run(
                [
                    model.loss, model.prediction, model.train_op,
                    model.merged_summary, model.merged_adv_summary
                ],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, niter)
            train_writer.add_summary(adv_summary, niter)
        else:
            loss_val, y_pred, _ = sess.run(
                [model.loss, model.prediction, model.train_op],
                feed_dict=feed_dict)

        # BUGFIX: evaluation requires an eval model; previously evaluate()
        # was invoked with mvalid=None before the None check.
        if (niter + 1) % config.steps_per_valid == 0 and run_eval \
                and mvalid is not None:
            train_results = evaluate(sess, mvalid, meta_dataset)
            # Track stats regardless of logging so the tqdm postfix is fresh.
            lr = lr_scheduler.lr
            trn_acc = train_results['acc']
            if log_results:
                exp_logger.log_train_acc(niter, train_results['acc'])
                exp_logger.log_learn_rate(niter, lr_scheduler.lr)

            val_results = evaluate(sess, mvalid, meta_val_dataset)
            val_acc = val_results['acc']
            if log_results:
                exp_logger.log_valid_acc(niter, val_results['acc'])
                exp_logger.log_learn_rate(niter, lr_scheduler.lr)
            it.set_postfix()
            # BUGFIX: reset the validation dataset after every evaluation,
            # not only when log_results is True.
            meta_val_dataset.reset()

        if (niter + 1) % config.steps_per_log == 0 and log_results:
            exp_logger.log_train_ce(niter + 1, loss_val)
            it.set_postfix(ce='{:.3e}'.format(loss_val),
                           trn_acc='{:.3f}'.format(trn_acc * 100.0),
                           val_acc='{:.3f}'.format(val_acc * 100.0),
                           lr='{:.3e}'.format(lr))

        if (niter + 1) % config.steps_per_save == 0:
            save(sess, saver, niter, save_folder)
    return exp_id