# PyTorch training loop. Repo-local helpers (gen_id, save_config,
# ExperimentLogger, preprocess_batch, evaluate, save) and the argparse
# namespace `args` are assumed to be defined elsewhere in the module.
import os
import time

import numpy as np
import six
import torch
import torch.optim as optim
from tqdm import tqdm


def train(config, model, optimizer, meta_dataset, meta_val_dataset=None,
          log_results=True, run_eval=True, exp_id=None):
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, config.lr_decay_steps, gamma=0.5)
    if exp_id is None:
        exp_id = gen_id(config)
    save_folder = os.path.join(args.results, exp_id)
    save_config(config, save_folder)
    if args.super_classes:
        total_classes = args.nsuperclassestrain
    else:
        total_classes = args.nclasses_train

    # Set up logging and printing.
    if log_results:
        logs_folder = os.path.join("logs", exp_id)
        exp_logger = ExperimentLogger(logs_folder)
    it = tqdm(six.moves.xrange(args.accumulation_steps * config.max_train_steps),
              desc=exp_id, ncols=0)

    # Initialize for the training loop.
    model.train()
    time1 = time.time()
    lr = []
    clip = 1000  # threshold for gradient norm clipping
    best_acc = 0  # for saving the best model
    val_acc = 0.0
    val_results = None

    # Training loop.
    for niter in it:
        if niter % args.accumulation_steps == 0:
            optimizer.zero_grad()
            lr_scheduler.step()
            for param_group in optimizer.param_groups:
                lr += [param_group['lr']]
        dataset = meta_dataset.next_episode(within_category=args.super_classes)
        if args.accumulation_steps > 1:
            classes = np.random.choice(range(0, total_classes),
                                       args.nclasses_episode, replace=False)
            batch = dataset.next_batch_separate(classes, args.nclasses_episode)
        else:
            batch = dataset.next_batch()
        batch = preprocess_batch(batch)
        loss, output = model(batch, super_classes=args.super_classes)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        if (niter + 1) % args.accumulation_steps == 0:
            optimizer.step()

        # Log and save.
        if (niter + 1) % (args.accumulation_steps * config.steps_per_valid) == 0 and run_eval:
            if log_results:
                exp_logger.log_learn_rate(niter, lr[-1])
            val_results = evaluate(model, meta_val_dataset,
                                   num_episodes=args.num_eval_episode)
            model.train()
            if log_results:
                exp_logger.log_valid_acc(niter, val_results['acc'])
                exp_logger.log_learn_rate(niter, lr[-1])
            val_acc = val_results['acc']
            it.set_postfix()
            meta_val_dataset.reset()
        if (niter + 1) % (args.accumulation_steps * config.steps_per_log) == 0 and log_results:
            exp_logger.log_train_ce(niter + 1, output['loss'])
            it.set_postfix(ce='{:.3e}'.format(output['loss']),
                           val_acc='{:.3f}'.format(val_acc * 100.0),
                           lr='{:.3e}'.format(lr[-1]))
            print('\n')
        if (niter + 1) % (args.accumulation_steps * config.steps_per_save) == 0:
            if val_results is not None and val_results['acc'] >= best_acc:
                best_acc = val_results['acc']
                save(model, "best", save_folder)
            save(model, niter, save_folder)
    return exp_id
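# --- Illustration (not part of the training code above) ----------------------
# A minimal, self-contained sketch of the gradient-accumulation pattern the
# PyTorch train() uses: gradients from several iterations are summed before a
# single optimizer step, with the gradient norm clipped on every backward pass.
# The toy model, data, milestones, and step counts are hypothetical
# placeholders; unlike train() above, the loss here is divided by
# accumulation_steps so the accumulated update matches one large-batch step.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


def _accumulation_sketch(accumulation_steps=4, total_iters=100, clip=1000.0):
    model = nn.Linear(8, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.5)
    for niter in range(total_iters):
        if niter % accumulation_steps == 0:
            optimizer.zero_grad()
        x = torch.randn(16, 8)
        y = torch.randint(0, 2, (16,))
        loss = F.cross_entropy(model(x), y) / accumulation_steps
        loss.backward()  # gradients accumulate across iterations
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        if (niter + 1) % accumulation_steps == 0:
            optimizer.step()   # one parameter update per accumulation window
            scheduler.step()   # decay after the optimizer step (PyTorch >= 1.1 order)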
# TensorFlow session-based variant. Assumes module-level `tensorflow as tf`,
# the `FLAGS` namespace, the same repo helpers as above, and a module-level
# `tensorboard` scalar writer (e.g. a tensorboardX SummaryWriter).
def train(sess, config, model, meta_dataset, mvalid=None, meta_val_dataset=None,
          log_results=True, run_eval=True, exp_id=None):
    lr_scheduler = FixedLearnRateScheduler(sess, model, config.learn_rate,
                                           config.lr_decay_steps, lr_list=config.lr_list)
    if exp_id is None:
        exp_id = gen_id(config)
    saver = tf.train.Saver()
    save_folder = os.path.join(FLAGS.results, exp_id)
    save_config(config, save_folder)
    if log_results:
        logs_folder = os.path.join("logs", exp_id)
        exp_logger = ExperimentLogger(logs_folder)
    it = tqdm(six.moves.xrange(config.max_train_steps), desc=exp_id, ncols=0)
    trn_acc = 0.0
    val_acc = 0.0
    lr = lr_scheduler.lr

    for niter in it:
        lr_scheduler.step(niter)
        dataset = meta_dataset.next()
        batch = dataset.next_batch()
        batch = preprocess_batch(batch)
        feed_dict = {
            model.x_train: batch.x_train,
            model.y_train: batch.y_train,
            model.x_test: batch.x_test,
            model.y_test: batch.y_test
        }
        if hasattr(model, '_x_unlabel'):
            if batch.x_unlabel is not None:
                feed_dict[model.x_unlabel] = batch.x_unlabel
            else:
                feed_dict[model.x_unlabel] = batch.x_test
        loss_val, y_pred, _ = sess.run(
            [model.loss, model.prediction, model.train_op], feed_dict=feed_dict)

        if (niter + 1) % config.steps_per_valid == 0 and run_eval:
            train_results = evaluate(sess, mvalid, meta_dataset)
            tensorboard.add_scalar("Accuracy", train_results['acc'],
                                   (niter + 1) // config.steps_per_valid)
            tensorboard.add_scalar("Loss", loss_val,
                                   (niter + 1) // config.steps_per_valid)
            if log_results:
                exp_logger.log_train_acc(niter, train_results['acc'])
                exp_logger.log_learn_rate(niter, lr_scheduler.lr)
            lr = lr_scheduler.lr
            trn_acc = train_results['acc']
            if mvalid is not None:
                val_results = evaluate(sess, mvalid, meta_val_dataset)
                if log_results:
                    exp_logger.log_valid_acc(niter, val_results['acc'])
                    exp_logger.log_learn_rate(niter, lr_scheduler.lr)
                val_acc = val_results['acc']
                it.set_postfix()
                meta_val_dataset.reset()
        if (niter + 1) % config.steps_per_log == 0 and log_results:
            exp_logger.log_train_ce(niter + 1, loss_val)
            it.set_postfix(ce='{:.3e}'.format(loss_val),
                           trn_acc='{:.3f}'.format(trn_acc * 100.0),
                           val_acc='{:.3f}'.format(val_acc * 100.0),
                           lr='{:.3e}'.format(lr))
        if (niter + 1) % config.steps_per_save == 0:
            save(sess, saver, niter, save_folder)
    return exp_id
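# --- Illustration (not part of the training code above) ----------------------
# FixedLearnRateScheduler is defined elsewhere in the repo; the class below is
# only a hypothetical sketch of a scheduler with the same interface
# (constructed with sess/model/base rate/decay steps/lr_list, stepped with the
# iteration number, exposing the current rate as .lr). It assumes the model
# exposes a `learn_rate` tf.Variable that its optimizer reads; the compat.v1
# namespace is used so the sketch also runs under TF2.
import tensorflow.compat.v1 as tf


class StepListLearnRateScheduler(object):
    def __init__(self, sess, model, base_lr, decay_steps, lr_list=None):
        self._sess = sess
        self._model = model
        self._decay_steps = list(decay_steps)
        # If no explicit list is given, fall back to halving at each decay step.
        if lr_list is None:
            lr_list = [base_lr * (0.5 ** (i + 1)) for i in range(len(self._decay_steps))]
        self._lr_list = list(lr_list)
        self.lr = base_lr
        self._assign(base_lr)

    def _assign(self, value):
        # Push the new rate into the graph-side learn_rate variable.
        self._sess.run(tf.assign(self._model.learn_rate, value))

    def step(self, niter):
        # Switch to the i-th rate in lr_list once the i-th decay step is reached.
        for decay_step, new_lr in zip(self._decay_steps, self._lr_list):
            if niter + 1 == decay_step:
                self.lr = new_lr
                self._assign(new_lr)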
# TensorFlow variant with summary writing and extra labeled/unlabeled feeds.
# Same module-level assumptions as the variant above.
def train(sess, config, model, meta_dataset, mvalid=None, meta_val_dataset=None,
          label_dataset=None, unlabel_dataset=None, log_results=True,
          summarize=True, run_eval=True, exp_id=None):
    lr_scheduler = FixedLearnRateScheduler(sess, model, config.learn_rate,
                                           config.lr_decay_steps, lr_list=config.lr_list)
    if exp_id is None:
        exp_id = gen_id(config)
    saver = tf.train.Saver()
    save_folder = os.path.join(FLAGS.results, exp_id)
    save_config(config, save_folder)
    train_writer = tf.summary.FileWriter(os.path.join(save_folder, 'graph'), sess.graph)
    if log_results:
        logs_folder = os.path.join("logs", exp_id)
        exp_logger = ExperimentLogger(logs_folder)
    it = tqdm(six.moves.xrange(config.max_train_steps), desc=exp_id, ncols=0)
    trn_acc = 0.0
    val_acc = 0.0
    lr = lr_scheduler.lr

    for niter in it:
        with tf.name_scope('Lr-step'):
            lr_scheduler.step(niter)
        dataset = meta_dataset.next()
        batch = dataset.next_batch()
        batch = preprocess_batch(batch)
        feed_dict = {
            model.x_train: batch.x_train,
            model.y_train: batch.y_train,
            model.x_test: batch.x_test,
            model.y_test: batch.y_test
        }
        if hasattr(model, '_x_unlabel'):
            if batch.x_unlabel is not None:
                feed_dict[model.x_unlabel] = batch.x_unlabel
            else:
                feed_dict[model.x_unlabel] = batch.x_test
        if hasattr(model, 'training_data'):
            x, y = next(label_dataset)
            feed_dict[model.training_data] = x
            feed_dict[model.training_labels] = y
        if hasattr(model, 'unlabeled_training_data'):
            feed_dict[model.unlabeled_training_data] = next(unlabel_dataset)[0]

        if (niter + 1) % FLAGS.steps_per_summary == 0 and summarize:
            loss_val, y_pred, _, summary, adv_summary = sess.run(
                [model.loss, model.prediction, model.train_op,
                 model.merged_summary, model.merged_adv_summary],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, niter)
            train_writer.add_summary(adv_summary, niter)
        else:
            loss_val, y_pred, _ = sess.run(
                [model.loss, model.prediction, model.train_op], feed_dict=feed_dict)

        if (niter + 1) % config.steps_per_valid == 0 and run_eval:
            train_results = evaluate(sess, mvalid, meta_dataset)
            if log_results:
                exp_logger.log_train_acc(niter, train_results['acc'])
                exp_logger.log_learn_rate(niter, lr_scheduler.lr)
            lr = lr_scheduler.lr
            trn_acc = train_results['acc']
            if mvalid is not None:
                val_results = evaluate(sess, mvalid, meta_val_dataset)
                if log_results:
                    exp_logger.log_valid_acc(niter, val_results['acc'])
                    exp_logger.log_learn_rate(niter, lr_scheduler.lr)
                val_acc = val_results['acc']
                it.set_postfix()
                meta_val_dataset.reset()
        if (niter + 1) % config.steps_per_log == 0 and log_results:
            exp_logger.log_train_ce(niter + 1, loss_val)
            it.set_postfix(ce='{:.3e}'.format(loss_val),
                           trn_acc='{:.3f}'.format(trn_acc * 100.0),
                           val_acc='{:.3f}'.format(val_acc * 100.0),
                           lr='{:.3e}'.format(lr))
        if (niter + 1) % config.steps_per_save == 0:
            save(sess, saver, niter, save_folder)
    return exp_id
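# --- Illustration (not part of the training code above) ----------------------
# A minimal, self-contained sketch of the TF1 summary pattern used above:
# merge scalar summaries, fetch the merged op together with the train op, and
# write the result with a FileWriter so it appears in TensorBoard. The toy
# graph and log directory are hypothetical placeholders; the compat.v1
# namespace is used so the sketch also runs under TF2.
import numpy as np
import tensorflow.compat.v1 as tf


def _summary_sketch(logdir="logs/summary_sketch", steps=10):
    tf.disable_eager_execution()
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None, 1])
        w = tf.Variable(tf.zeros([1, 1]))
        loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - 1.0))
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
        tf.summary.scalar("loss", loss)
        merged_summary = tf.summary.merge_all()
        init_op = tf.global_variables_initializer()
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        writer = tf.summary.FileWriter(logdir, sess.graph)
        for step in range(steps):
            _, summary = sess.run([train_op, merged_summary],
                                  feed_dict={x: np.ones((4, 1), dtype=np.float32)})
            writer.add_summary(summary, step)  # one scalar point per step
        writer.close()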