Example #1
import os
import time

import numpy as np
import six
import torch
import torch.optim as optim
from tqdm import tqdm

# `args` and the helpers used below (gen_id, save_config, ExperimentLogger,
# evaluate, preprocess_batch, save) are assumed to come from the surrounding
# module.


def train(config,
          model,
          optimizer,
          meta_dataset,
          meta_val_dataset=None,
          log_results=True,
          run_eval=True,
          exp_id=None):

    # halve the learning rate at each milestone in config.lr_decay_steps
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                  config.lr_decay_steps,
                                                  gamma=0.5)
    if exp_id is None:
        exp_id = gen_id(config)

    save_folder = os.path.join(args.results, exp_id)
    save_config(config, save_folder)

    if args.super_classes:
        total_classes = args.nsuperclassestrain
    else:
        total_classes = args.nclasses_train

    # set up logging and printing
    if log_results:
        logs_folder = os.path.join("logs", exp_id)
        exp_logger = ExperimentLogger(logs_folder)
    it = tqdm(six.moves.xrange(args.accumulation_steps *
                               config.max_train_steps),
              desc=exp_id,
              ncols=0)

    # initialize for the training loop
    model.train()
    time1 = time.time()
    lr = []
    clip = 1000  # max norm for gradient clipping
    best_acc = 0  # for saving the best model
    val_results = None  # set once evaluation has run

    # training loop
    for niter in it:

        if niter % args.accumulation_steps == 0:
            # start of an accumulation cycle: clear gradients and record the
            # current learning rate
            optimizer.zero_grad()
            for param_group in optimizer.param_groups:
                lr += [param_group['lr']]

            dataset = meta_dataset.next_episode(
                within_category=args.super_classes)

        if args.accumulation_steps > 1:
            # each micro-batch draws its own random subset of the classes
            classes = np.random.choice(range(0, total_classes),
                                       args.nclasses_episode,
                                       replace=False)
            batch = dataset.next_batch_separate(classes, args.nclasses_episode)
        else:
            batch = dataset.next_batch()

        batch = preprocess_batch(batch)

        loss, output = model(batch, super_classes=args.super_classes)
        loss.backward()

        # clip_grad_norm was deprecated in favour of the in-place variant
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        if (niter + 1) % args.accumulation_steps == 0:
            optimizer.step()
            # PyTorch >= 1.1 expects the scheduler step after optimizer.step()
            lr_scheduler.step()

        ## LOG and SAVE
        if ((niter + 1) % (args.accumulation_steps *
                           config.steps_per_valid) == 0 and run_eval
                and meta_val_dataset is not None):
            if log_results:
                exp_logger.log_learn_rate(niter, lr[-1])
            val_results = evaluate(model,
                                   meta_val_dataset,
                                   num_episodes=args.num_eval_episode)
            model.train()
            if log_results:
                exp_logger.log_valid_acc(niter, val_results['acc'])
                exp_logger.log_learn_rate(niter, lr[-1])
                val_acc = val_results['acc']
                it.set_postfix()
                meta_val_dataset.reset()

            if (niter + 1) % (args.accumulation_steps *
                              config.steps_per_log) == 0 and log_results:
                exp_logger.log_train_ce(niter + 1, output['loss'])
                it.set_postfix(ce='{:.3e}'.format(output['loss']),
                               val_acc='{:.3f}'.format(val_acc * 100.0),
                               lr='{:.3e}'.format(lr[-1]))
                print('\n')

        if (niter + 1) % (args.accumulation_steps *
                          config.steps_per_save) == 0:
            # only track the best checkpoint once validation has produced
            # results
            if val_results is not None and val_results['acc'] >= best_acc:
                best_acc = val_results['acc']
                save(model, "best", save_folder)

            save(model, niter, save_folder)

    return exp_id
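
A minimal sketch of how this PyTorch version might be invoked, assuming a
hypothetical `build_model` factory and datasets built elsewhere, and assuming
`config` also carries a `learn_rate` for the optimizer (only the scheduling
and logging attributes are actually read above):

model = build_model(config)  # hypothetical model factory
optimizer = optim.SGD(model.parameters(),
                      lr=config.learn_rate,  # assumed config attribute
                      momentum=0.9)
exp_id = train(config, model, optimizer,
               meta_dataset, meta_val_dataset=meta_val_dataset)
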
Example #2
import os

import six
import tensorflow as tf
from tqdm import tqdm

# `FLAGS`, `tensorboard` (a module-level summary writer) and the helpers used
# below (gen_id, save_config, ExperimentLogger, FixedLearnRateScheduler,
# evaluate, preprocess_batch, save) are assumed to come from the surrounding
# module.


def train(sess,
          config,
          model,
          meta_dataset,
          mvalid=None,
          meta_val_dataset=None,
          log_results=True,
          run_eval=True,
          exp_id=None):
    lr_scheduler = FixedLearnRateScheduler(sess,
                                           model,
                                           config.learn_rate,
                                           config.lr_decay_steps,
                                           lr_list=config.lr_list)

    if exp_id is None:
        exp_id = gen_id(config)

    saver = tf.train.Saver()
    save_folder = os.path.join(FLAGS.results, exp_id)
    save_config(config, save_folder)
    if log_results:
        logs_folder = os.path.join("logs", exp_id)
        exp_logger = ExperimentLogger(logs_folder)
    it = tqdm(six.moves.xrange(config.max_train_steps), desc=exp_id, ncols=0)

    trn_acc = 0.0
    val_acc = 0.0
    lr = lr_scheduler.lr
    for niter in it:
        lr_scheduler.step(niter)
        dataset = meta_dataset.next()
        batch = dataset.next_batch()
        batch = preprocess_batch(batch)

        feed_dict = {
            model.x_train: batch.x_train,
            model.y_train: batch.y_train,
            model.x_test: batch.x_test,
            model.y_test: batch.y_test
        }
        if hasattr(model, '_x_unlabel'):
            # semi-supervised models take an unlabeled batch; fall back to
            # the query set when the episode carries none
            if batch.x_unlabel is not None:
                feed_dict[model.x_unlabel] = batch.x_unlabel
            else:
                feed_dict[model.x_unlabel] = batch.x_test

        loss_val, y_pred, _ = sess.run(
            [model.loss, model.prediction, model.train_op],
            feed_dict=feed_dict)

        if ((niter + 1) % config.steps_per_valid == 0 and run_eval
                and mvalid is not None):
            train_results = evaluate(sess, mvalid, meta_dataset)
            tensorboard.add_scalar("Accuracy", train_results['acc'],
                                   (niter + 1) // config.steps_per_valid)
            tensorboard.add_scalar("Loss", loss_val,
                                   (niter + 1) // config.steps_per_valid)
            if log_results:
                exp_logger.log_train_acc(niter, train_results['acc'])
                exp_logger.log_learn_rate(niter, lr_scheduler.lr)
                lr = lr_scheduler.lr
                trn_acc = train_results['acc']

            if meta_val_dataset is not None:
                val_results = evaluate(sess, mvalid, meta_val_dataset)

                if log_results:
                    exp_logger.log_valid_acc(niter, val_results['acc'])
                    exp_logger.log_learn_rate(niter, lr_scheduler.lr)
                    val_acc = val_results['acc']
                    it.set_postfix()
                    meta_val_dataset.reset()

        if (niter + 1) % config.steps_per_log == 0 and log_results:
            exp_logger.log_train_ce(niter + 1, loss_val)
            it.set_postfix(ce='{:.3e}'.format(loss_val),
                           trn_acc='{:.3f}'.format(trn_acc * 100.0),
                           val_acc='{:.3f}'.format(val_acc * 100.0),
                           lr='{:.3e}'.format(lr))

        if (niter + 1) % config.steps_per_save == 0:
            save(sess, saver, niter, save_folder)
    return exp_id
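
A sketch of a typical call for this TensorFlow variant, assuming the model
graphs and meta-datasets (hypothetical names) were built beforehand:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    exp_id = train(sess, config, model, meta_dataset,
                   mvalid=mvalid, meta_val_dataset=meta_val_dataset)
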
Example #3
import os

import six
import tensorflow as tf
from tqdm import tqdm

# As in Example #2, `FLAGS` and the helper functions are assumed to be
# module-level.


def train(sess,
          config,
          model,
          meta_dataset,
          mvalid=None,
          meta_val_dataset=None,
          label_dataset=None,
          unlabel_dataset=None,
          log_results=True,
          summarize=True,
          run_eval=True,
          exp_id=None):
    lr_scheduler = FixedLearnRateScheduler(sess,
                                           model,
                                           config.learn_rate,
                                           config.lr_decay_steps,
                                           lr_list=config.lr_list)

    if exp_id is None:
        exp_id = gen_id(config)

    saver = tf.train.Saver()
    save_folder = os.path.join(FLAGS.results, exp_id)
    save_config(config, save_folder)
    train_writer = tf.summary.FileWriter(os.path.join(save_folder, 'graph'),
                                         sess.graph)

    if log_results:
        logs_folder = os.path.join("logs", exp_id)
        exp_logger = ExperimentLogger(logs_folder)
    it = tqdm(six.moves.xrange(config.max_train_steps), desc=exp_id, ncols=0)

    trn_acc = 0.0
    val_acc = 0.0
    lr = lr_scheduler.lr
    for niter in it:
        lr_scheduler.step(niter)
        dataset = meta_dataset.next()
        batch = dataset.next_batch()
        batch = preprocess_batch(batch)

        feed_dict = {
            model.x_train: batch.x_train,
            model.y_train: batch.y_train,
            model.x_test: batch.x_test,
            model.y_test: batch.y_test
        }
        if hasattr(model, '_x_unlabel'):
            # fall back to the query set when no unlabeled examples are given
            if batch.x_unlabel is not None:
                feed_dict[model.x_unlabel] = batch.x_unlabel
            else:
                feed_dict[model.x_unlabel] = batch.x_test
        if hasattr(model, 'training_data'):
            # extra labelled batch for models with a plain supervised input
            x, y = next(label_dataset)
            feed_dict[model.training_data] = x
            feed_dict[model.training_labels] = y
        if hasattr(model, 'unlabeled_training_data'):
            feed_dict[model.unlabeled_training_data] = next(unlabel_dataset)[0]

        # periodically fetch the merged summaries alongside the training ops
        # and write them to the TensorBoard graph writer
        if (niter + 1) % FLAGS.steps_per_summary == 0 and summarize:
            loss_val, y_pred, _, summary, adv_summary = sess.run(
                [
                    model.loss, model.prediction, model.train_op,
                    model.merged_summary, model.merged_adv_summary
                ],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, niter)
            train_writer.add_summary(adv_summary, niter)
        else:
            loss_val, y_pred, _ = sess.run(
                [model.loss, model.prediction, model.train_op],
                feed_dict=feed_dict)

        if ((niter + 1) % config.steps_per_valid == 0 and run_eval
                and mvalid is not None):
            train_results = evaluate(sess, mvalid, meta_dataset)
            if log_results:
                exp_logger.log_train_acc(niter, train_results['acc'])
                exp_logger.log_learn_rate(niter, lr_scheduler.lr)
                lr = lr_scheduler.lr
                trn_acc = train_results['acc']

            if meta_val_dataset is not None:
                val_results = evaluate(sess, mvalid, meta_val_dataset)

                if log_results:
                    exp_logger.log_valid_acc(niter, val_results['acc'])
                    exp_logger.log_learn_rate(niter, lr_scheduler.lr)
                    val_acc = val_results['acc']
                    it.set_postfix()
                    meta_val_dataset.reset()

        if (niter + 1) % config.steps_per_log == 0 and log_results:
            exp_logger.log_train_ce(niter + 1, loss_val)
            it.set_postfix(ce='{:.3e}'.format(loss_val),
                           trn_acc='{:.3f}'.format(trn_acc * 100.0),
                           val_acc='{:.3f}'.format(val_acc * 100.0),
                           lr='{:.3e}'.format(lr))

        if (niter + 1) % config.steps_per_save == 0:
            save(sess, saver, niter, save_folder)
    return exp_id
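
The extra semi-supervised inputs of this variant might be wired up as below,
assuming `labeled_batches` and `unlabeled_batches` (hypothetical names) are
iterables yielding `(x, y)` and `(x,)` tuples, since the loop advances them
with `next()`:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    exp_id = train(sess, config, model, meta_dataset,
                   mvalid=mvalid, meta_val_dataset=meta_val_dataset,
                   label_dataset=iter(labeled_batches),
                   unlabel_dataset=iter(unlabeled_batches))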