def _meta_update(model, total_grad, opt, task, maml_batchsize, clip_grad):
    '''
        Aggregate the gradients in total_grad
        Update the initialization in model
    '''
    model['ebd'].train()
    model['clf'].train()

    support, query = task

    # forward once to build a dummy computation graph; the real meta
    # gradients are injected by the hooks registered below
    XS = model['ebd'](support)
    pred = model['clf'](XS)
    loss = torch.sum(pred)  # this value doesn't matter

    # aggregate the gradients (skip tasks whose gradients contain nan)
    avg_grad = {
        'ebd': {key: sum(g[key] for g in total_grad['ebd']
                         if not torch.sum(torch.isnan(g[key])) > 0)
                for key in total_grad['ebd'][0].keys()},
        'clf': {key: sum(g[key] for g in total_grad['clf']
                         if not torch.sum(torch.isnan(g[key])) > 0)
                for key in total_grad['clf'][0].keys()},
    }

    # register a hook on each parameter in the model that replaces
    # the current dummy grad with the meta gradients
    hooks = []
    for model_name in avg_grad.keys():
        for key, value in model[model_name].named_parameters():
            if not value.requires_grad:
                continue

            def get_closure():
                # bind the current key and model name at definition time
                k = key
                n = model_name

                def replace_grad(grad):
                    return avg_grad[n][k] / maml_batchsize

                return replace_grad

            hooks.append(value.register_hook(get_closure()))

    opt.zero_grad()
    loss.backward()

    ebd_grad = get_norm(model['ebd'])
    clf_grad = get_norm(model['clf'])

    if clip_grad is not None:
        nn.utils.clip_grad_value_(grad_param(model, ['ebd', 'clf']),
                                  clip_grad)

    opt.step()

    # remove the hooks before the next training phase
    for h in hooks:
        h.remove()

    total_grad['ebd'] = []
    total_grad['clf'] = []

    return ebd_grad, clf_grad
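# The hook trick above is the core of the meta update: backward() populates a
# dummy gradient, and each hook swaps in the averaged meta gradient before the
# optimizer reads it. A minimal, self-contained sketch of the same mechanism
# on a toy parameter (illustrative only; not part of the training pipeline):
def _demo_replace_grad_hook():
    w = torch.nn.Parameter(torch.ones(3))
    meta_grad = torch.tensor([0.5, 0.5, 0.5])

    # replace whatever gradient backward() computes with meta_grad
    h = w.register_hook(lambda grad: meta_grad)

    loss = w.sum()   # dummy loss; its value does not matter
    loss.backward()

    assert torch.equal(w.grad, meta_grad)
    h.remove()       # remove the hook so later backward passes are unaffected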
def test_one(task, fast, args):
    '''
        Evaluate the model on one sampled task. Return the accuracy.
    '''
    support, query = task

    # map class label into 0,...,num_classes-1
    YS, YQ = fast['clf'].reidx_y(support['label'], query['label'])

    fast['ebd'].train()
    fast['clf'].train()

    opt = torch.optim.SGD(grad_param(fast, ['ebd', 'clf']),
                          lr=args.maml_stepsize)

    # fast adaptation on the support set
    for i in range(args.maml_innersteps):
        XS = fast['ebd'](support)
        pred = fast['clf'](XS)
        loss = F.cross_entropy(pred, YS)

        opt.zero_grad()
        loss.backward()
        opt.step()

    fast['ebd'].eval()
    fast['clf'].eval()

    XQ = fast['ebd'](query)
    pred = fast['clf'](XQ)
    acc = torch.mean((torch.argmax(pred, dim=1) == YQ).float()).item()

    return acc
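# test_one evaluates a single task; a typical caller aggregates over many
# sampled tasks. A hypothetical sketch (the `tasks` iterable and `fast`
# model dict are assumptions about the caller, not code from this repo):
def _demo_eval_many(tasks, fast, args):
    accs = np.array([test_one(task, fast, args) for task in tasks])
    return accs.mean(), accs.std()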
def train_one_fomaml(task, fast, args, total_grad):
    '''
        Update the fast_model based on the support set.
        Return the gradient w.r.t. initializations over the query set
        First-order MAML
    '''
    support, query = task

    # map class label into 0,...,num_classes-1
    YS, YQ = fast['clf'].reidx_y(support['label'], query['label'])

    opt = torch.optim.SGD(grad_param(fast, ['ebd', 'clf']),
                          lr=args.maml_stepsize)

    fast['ebd'].train()
    fast['clf'].train()

    # fast adaptation
    for i in range(args.maml_innersteps):
        opt.zero_grad()

        XS = fast['ebd'](support)
        acc, loss = fast['clf'](XS, YS)

        loss.backward()
        opt.step()

    # forward on the query, to get meta loss
    XQ = fast['ebd'](query)
    acc, loss = fast['clf'](XQ, YQ)
    loss.backward()

    # the pooler does not have a grad in BERT, so skip None grads
    grads_ebd = {name: p.grad for (name, p) in named_grad_param(fast, ['ebd'])
                 if p.grad is not None}
    grads_clf = {name: p.grad for (name, p) in named_grad_param(fast, ['clf'])}

    total_grad['ebd'].append(grads_ebd)
    total_grad['clf'].append(grads_clf)

    return
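# First-order MAML (above) uses the query-set gradient at the adapted weights
# as the meta gradient, skipping back-propagation through the inner updates.
# A minimal sketch of how a batch of such per-task gradient dicts could be
# averaged (illustrative only; the repo's own aggregation is in _meta_update):
def _demo_average_grads(grad_dicts):
    keys = grad_dicts[0].keys()
    return {k: sum(g[k] for g in grad_dicts) / len(grad_dicts) for k in keys}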
def train_one(task, model, opt, args, grad):
    '''
        Train the model on one sampled task.
    '''
    model['ebd'].train()
    if args.classifier != 'nn':
        model['clf'].train()

    opt.zero_grad()

    support, query = task

    # Embed the documents
    XS = model['ebd'](support)
    YS = support['label']

    XQ = model['ebd'](query)
    YQ = query['label']

    # Apply the classifier
    _, loss = model['clf'](XS, YS, XQ, YQ)
    print('loss: ', loss)

    if loss is None:
        return

    loss.backward()

    if torch.isnan(loss):
        # do not update the parameters if the loss is nan
        print("NAN detected")
        print(model['clf'].lam, model['clf'].alpha, model['clf'].beta)
        return

    if args.clip_grad is not None:
        nn.utils.clip_grad_value_(grad_param(model, ['ebd', 'clf']),
                                  args.clip_grad)

    if args.classifier != 'nn':
        grad['clf'].append(get_norm(model['clf']))
    grad['ebd'].append(get_norm(model['ebd']))

    if args.classifier != 'nn':
        opt.step()
def train(train_data, val_data, model, args):
    '''
        Train the model
        Use val_data to do early stopping
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(os.path.join(
        os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    optG = torch.optim.Adam(grad_param(model, ['G', 'clf']), lr=args.lr_g)
    optD = torch.optim.Adam(grad_param(model, ['D']), lr=args.lr_d)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
        schedulerD = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optD, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)
        schedulerD = torch.optim.lr_scheduler.ExponentialLR(
            optD, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(datetime.datetime.now()), flush=True)

    # train_gen = ParallelSampler(train_data, args, args.train_episodes)
    train_gen_val = ParallelSampler_Test(train_data, args, args.val_episodes)
    val_gen = ParallelSampler_Test(val_data, args, args.val_episodes)

    # sampled_classes, source_classes = task_sampler(train_data, args)
    for ep in range(args.train_epochs):
        sampled_classes, source_classes = task_sampler(train_data, args)

        train_gen = ParallelSampler(train_data, args, sampled_classes,
                                    source_classes, args.train_episodes)
        sampled_tasks = train_gen.get_epoch()

        grad = {'clf': [], 'G': [], 'D': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                                 ncols=80, leave=False,
                                 desc=colored('Training on train', 'yellow'))

        d_acc = 0
        for task in sampled_tasks:
            if task is None:
                break
            d_acc += train_one(task, model, optG, optD, args, grad)

        d_acc = d_acc / args.train_episodes
        print("---------------ep:" + str(ep) + " d_acc:" + str(d_acc) +
              "-----------")

        if ep % 10 == 0:
            acc, std, _ = test(train_data, model, args, args.val_episodes,
                               False, train_gen_val.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                datetime.datetime.now(),
                "ep", ep,
                colored("train", "red"),
                colored("acc:", "blue"), acc, std,
                ), flush=True)

        # Evaluate validation accuracy
        cur_acc, cur_std, _ = test(val_data, model, args, args.val_episodes,
                                   False, val_gen.get_epoch())
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
            datetime.datetime.now(),
            "ep", ep,
            colored("val ", "cyan"),
            colored("acc:", "blue"), cur_acc, cur_std,
            colored("train stats", "cyan"),
            colored("G_grad:", "blue"), np.mean(np.array(grad['G'])),
            colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
            ), flush=True)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(
                datetime.datetime.now(), best_path))

            torch.save(model['G'].state_dict(), best_path + '.G')
            torch.save(model['D'].state_dict(), best_path + '.D')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

        if args.lr_scheduler == 'ReduceLROnPlateau':
            schedulerG.step(cur_acc)
            schedulerD.step(cur_acc)
        elif args.lr_scheduler == 'ExponentialLR':
            schedulerG.step()
            schedulerD.step()

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now()), flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))
    model['D'].load_state_dict(torch.load(best_path + '.D'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(os.path.join(
            os.path.curdir, "saved-runs", str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(
            datetime.datetime.now(), best_path), flush=True)

        torch.save(model['G'].state_dict(), best_path + '.G')
        torch.save(model['D'].state_dict(), best_path + '.D')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return
def train(train_data, val_data, model, args):
    '''
        Train the model (MAML)
        Use val_data to do early stopping
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    opt = torch.optim.Adam(grad_param(model, ['ebd', 'clf']), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, 'max', patience=5, factor=0.1, verbose=True)

    # clone the original model
    fast_model = {
        'ebd': copy.deepcopy(model['ebd']),
        'clf': copy.deepcopy(model['clf']),
    }

    print("{}, Start training".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')))

    train_gen = ParallelSampler(train_data, args,
                                args.train_episodes * args.maml_batchsize)
    val_gen = ParallelSampler(val_data, args, args.val_episodes)

    for ep in range(args.train_epochs):
        sampled_tasks = train_gen.get_epoch()
        meta_grad_dict = {'clf': [], 'ebd': []}

        train_episodes = range(args.train_episodes)
        if not args.notqdm:
            train_episodes = tqdm(train_episodes, ncols=80, leave=False,
                                  desc=colored('Training on train', 'yellow'))

        for _ in train_episodes:
            # update the initialization based on a batch of tasks
            total_grad = {'ebd': [], 'clf': []}

            for _ in range(args.maml_batchsize):
                task = next(sampled_tasks)

                # clone the current initialization
                _copy_weights(model['ebd'], fast_model['ebd'])
                _copy_weights(model['clf'], fast_model['clf'])

                # get the meta gradient
                train_one(task, fast_model, args, total_grad)

            ebd_grad, clf_grad = _meta_update(model, total_grad, opt, task,
                                              args.maml_batchsize,
                                              args.clip_grad)
            meta_grad_dict['ebd'].append(ebd_grad)
            meta_grad_dict['clf'].append(clf_grad)

        # evaluate training accuracy
        if ep % 10 == 0:
            acc, std = test(train_data, model, args,
                            args.train_episodes * args.maml_batchsize,
                            False, train_gen.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                "ep", ep,
                colored("train", "red"),
                colored("acc:", "blue"), acc, std,
                ), flush=True)

        # evaluate validation accuracy
        cur_acc, cur_std = test(val_data, model, args, args.val_episodes,
                                False, val_gen.get_epoch())
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            "ep", ep,
            colored("val ", "cyan"),
            colored("acc:", "blue"), cur_acc, cur_std,
            colored("train stats", "cyan"),
            colored("ebd_grad:", "blue"),
            np.mean(np.array(meta_grad_dict['ebd'])),
            colored("clf_grad:", "blue"),
            np.mean(np.array(meta_grad_dict['clf']))), flush=True)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                best_path))

            torch.save(model['ebd'].state_dict(), best_path + '.ebd')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')))

    # restore the best saved model
    model['ebd'].load_state_dict(torch.load(best_path + '.ebd'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            best_path), flush=True)

        torch.save(model['ebd'].state_dict(), best_path + '.ebd')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return
def train_one(task, fast, args, total_grad):
    '''
        Update the fast_model based on the support set.
        Return the gradient w.r.t. initializations over the query set
    '''
    support, query = task

    # map class label into 0,...,num_classes-1
    YS, YQ = fast['clf'].reidx_y(support['label'], query['label'])

    fast['ebd'].train()
    fast['clf'].train()

    # get weights
    fast_weights = {
        'ebd': OrderedDict((name, param) for (name, param)
                           in named_grad_param(fast, ['ebd'])),
        'clf': OrderedDict((name, param) for (name, param)
                           in named_grad_param(fast, ['clf'])),
    }

    num_ebd_w = len(fast_weights['ebd'])
    num_clf_w = len(fast_weights['clf'])

    # fast adaptation
    for i in range(args.maml_innersteps):
        if i == 0:
            XS = fast['ebd'](support)
            pred = fast['clf'](XS)
            loss = F.cross_entropy(pred, YS)
            grads = torch.autograd.grad(loss,
                                        grad_param(fast, ['ebd', 'clf']),
                                        create_graph=True)
        else:
            XS = fast['ebd'](support, fast_weights['ebd'])
            pred = fast['clf'](XS, weights=fast_weights['clf'])
            loss = F.cross_entropy(pred, YS)
            grads = torch.autograd.grad(
                loss,
                itertools.chain(fast_weights['ebd'].values(),
                                fast_weights['clf'].values()),
                create_graph=True)

        if args.maml_firstorder:
            grads = tuple(g.detach() for g in grads)

        # update fast weights
        fast_weights['ebd'] = OrderedDict(
            (name, param - args.maml_stepsize * grad)
            for ((name, param), grad) in zip(fast_weights['ebd'].items(),
                                             grads[:num_ebd_w]))
        fast_weights['clf'] = OrderedDict(
            (name, param - args.maml_stepsize * grad)
            for ((name, param), grad) in zip(fast_weights['clf'].items(),
                                             grads[num_ebd_w:]))

    # forward on the query, to get meta loss
    XQ = fast['ebd'](query, fast_weights['ebd'])
    pred = fast['clf'](XQ, weights=fast_weights['clf'])
    loss = F.cross_entropy(pred, YQ)

    grads = torch.autograd.grad(loss, grad_param(fast, ['ebd', 'clf']))

    grads_ebd = {name: g for ((name, _), g)
                 in zip(named_grad_param(fast, ['ebd']), grads[:num_ebd_w])}
    grads_clf = {name: g for ((name, _), g)
                 in zip(named_grad_param(fast, ['clf']), grads[num_ebd_w:])}

    total_grad['ebd'].append(grads_ebd)
    total_grad['clf'].append(grads_clf)

    return
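# The inner loop above keeps fast weights in OrderedDicts and differentiates
# through the updates (create_graph=True), which is what makes the meta
# gradient second-order. A self-contained toy sketch of the same pattern
# (illustrative only; names are not part of this repo):
def _demo_second_order_step():
    w = torch.nn.Parameter(torch.tensor([1.0]))
    x, y = torch.tensor([2.0]), torch.tensor([0.0])

    # inner step: w_fast = w - lr * dL/dw, keeping the graph
    loss_inner = ((w * x - y) ** 2).sum()
    (g,) = torch.autograd.grad(loss_inner, w, create_graph=True)
    w_fast = w - 0.1 * g

    # the outer loss on the adapted weight still differentiates back to w
    loss_outer = ((w_fast * x - y) ** 2).sum()
    (meta_g,) = torch.autograd.grad(loss_outer, w)
    return meta_g  # includes second-order terms through g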
def train(train_data, val_data, test_data, model, class_names, criterion,
          args):
    '''
        Train the model
        Use val_data to do early stopping
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    if args.STS == True:
        classes_sample_p, example_prob_metrix = pre_calculate(
            train_data, class_names, model['G'], args)
    else:
        classes_sample_p, example_prob_metrix = None, None

    optG = torch.optim.Adam(grad_param(model, ['G']), lr=args.meta_lr,
                            weight_decay=args.weight_decay)
    # optG2 = torch.optim.Adam(grad_param(model, ['G2']), lr=args.task_lr)
    # optCLF = torch.optim.Adam(grad_param(model, ['clf']), lr=args.task_lr)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
        # schedulerCLF = torch.optim.lr_scheduler.ReduceLROnPlateau(
        #     optCLF, 'max', patience=args.patience // 2, factor=0.1,
        #     verbose=True)
    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)
        # schedulerCLF = torch.optim.lr_scheduler.ExponentialLR(
        #     optCLF, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(datetime.datetime.now()), flush=True)

    # sampled_classes, source_classes = task_sampler(train_data, args)
    acc = 0
    loss = 0
    for ep in range(args.train_epochs):
        ep_loss = 0
        for _ in range(args.train_episodes):
            sampled_classes, source_classes = task_sampler(
                train_data, args, classes_sample_p)

            train_gen = SerialSampler(train_data, args, sampled_classes,
                                      source_classes, 1, example_prob_metrix)
            sampled_tasks = train_gen.get_epoch()

            grad = {'clf': [], 'G': []}

            if not args.notqdm:
                sampled_tasks = tqdm(sampled_tasks,
                                     total=train_gen.num_episodes,
                                     ncols=80, leave=False,
                                     desc=colored('Training on train',
                                                  'yellow'))

            for task in sampled_tasks:
                if task is None:
                    break
                q_loss, q_acc = train_one(task, class_names, model, optG,
                                          criterion, args, grad)
                acc += q_acc
                loss = loss + q_loss
                ep_loss = ep_loss + q_loss

        ep_loss = ep_loss / args.train_episodes

        optG.zero_grad()
        ep_loss.backward()
        optG.step()

        if ep % 100 == 0:
            print("{}:".format(colored('--------[TRAIN] ep', 'blue')) +
                  str(ep) + ", loss:" + str(q_loss.item()) +
                  ", acc:" + str(q_acc.item()) + "-----------")

        test_count = 100
        # if (ep % test_count == 0) and (ep != 0):
        if ep % test_count == 0:
            acc = acc / args.train_episodes / test_count
            loss = loss / args.train_episodes / test_count
            print("{}:".format(colored('--------[TRAIN] ep', 'blue')) +
                  str(ep) + ", mean_loss:" + str(loss.item()) +
                  ", mean_acc:" + str(acc.item()) + "-----------")

            net = copy.deepcopy(model)
            # acc, std = test(train_data, class_names, optG, net, criterion,
            #                 args, args.test_epochs, False)
            # print("[TRAIN] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
            #     datetime.datetime.now(),
            #     "ep", ep,
            #     colored("train", "red"),
            #     colored("acc:", "blue"), acc, std,
            #     ), flush=True)

            acc = 0
            loss = 0

            # Evaluate test accuracy
            cur_acc, cur_std = test(test_data, class_names, optG, net,
                                    criterion, args, args.test_epochs, False)
            print(("[TEST] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
                   ).format(
                datetime.datetime.now(),
                "ep", ep,
                colored("test ", "cyan"),
                colored("acc:", "blue"), cur_acc, cur_std,
                ), flush=True)

            # Evaluate validation accuracy
            cur_acc, cur_std = test(val_data, class_names, optG, net,
                                    criterion, args, args.test_epochs, False)
            print(("[EVAL] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
                   ).format(
                datetime.datetime.now(),
                "ep", ep,
                colored("val ", "cyan"),
                colored("acc:", "blue"), cur_acc, cur_std,
                ), flush=True)

            # Update the current best model if val acc is better
            if cur_acc > best_acc:
                best_acc = cur_acc
                best_path = os.path.join(out_dir, str(ep))

                # save current model
                print("{}, Save cur best model to {}".format(
                    datetime.datetime.now(), best_path))

                torch.save(model['G'].state_dict(), best_path + '.G')
                # torch.save(model['G2'].state_dict(), best_path + '.G2')
                # torch.save(model['clf'].state_dict(), best_path + '.clf')

                sub_cycle = 0
            else:
                sub_cycle += 1

            # Break if the val acc hasn't improved in the past patience epochs
            if sub_cycle == args.patience:
                break

            if args.lr_scheduler == 'ReduceLROnPlateau':
                schedulerG.step(cur_acc)
                # schedulerCLF.step(cur_acc)
            elif args.lr_scheduler == 'ExponentialLR':
                schedulerG.step()
                # schedulerCLF.step()

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now()), flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))
    # model['G2'].load_state_dict(torch.load(best_path + '.G2'))
    # model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(datetime.datetime.now(),
                                                 best_path), flush=True)

        torch.save(model['G'].state_dict(), best_path + '.G')
        # torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return optG
def train(train_data, val_data, model, class_names, args):
    '''
        Train the model
        Use val_data to do early stopping
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    optG = torch.optim.Adam(grad_param(model, ['G']), lr=args.meta_lr)
    optG2 = torch.optim.Adam(grad_param(model, ['G2']), lr=args.task_lr)
    optCLF = torch.optim.Adam(grad_param(model, ['clf']), lr=args.task_lr)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
        schedulerCLF = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optCLF, 'max', patience=args.patience // 2, factor=0.1,
            verbose=True)
    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)
        schedulerCLF = torch.optim.lr_scheduler.ExponentialLR(
            optCLF, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(datetime.datetime.now()), flush=True)

    # train_gen = ParallelSampler(train_data, args, args.train_episodes)
    # train_gen_val = ParallelSampler_Test(train_data, args, args.val_episodes)
    # val_gen = ParallelSampler_Test(val_data, args, args.val_episodes)

    # sampled_classes, source_classes = task_sampler(train_data, args)
    acc = 0
    loss = 0
    for ep in range(args.train_epochs):
        sampled_classes, source_classes = task_sampler(train_data, args)

        class_names_dict = {}
        class_names_dict['label'] = class_names['label'][sampled_classes]
        class_names_dict['text'] = class_names['text'][sampled_classes]
        class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
        class_names_dict['is_support'] = False

        train_gen = ParallelSampler(train_data, args, sampled_classes,
                                    source_classes, args.train_episodes)
        sampled_tasks = train_gen.get_epoch()
        class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                           exclude_keys=['is_support'])

        grad = {'clf': [], 'G': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                                 ncols=80, leave=False,
                                 desc=colored('Training on train', 'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            q_loss, q_acc = train_one(task, class_names_dict, model, optG,
                                      optG2, optCLF, args, grad)
            acc += q_acc
            loss += q_loss

        if ep % 100 == 0:
            print("--------[TRAIN] ep:" + str(ep) + ", loss:" +
                  str(q_loss.item()) + ", acc:" + str(q_acc.item()) +
                  "-----------")

        if (ep % 200 == 0) and (ep != 0):
            acc = acc / args.train_episodes / 200
            loss = loss / args.train_episodes / 200
            print("--------[TRAIN] ep:" + str(ep) + ", mean_loss:" +
                  str(loss.item()) + ", mean_acc:" + str(acc.item()) +
                  "-----------")

            net = copy.deepcopy(model)
            acc, std = test(train_data, class_names, optG, optCLF, net, args,
                            args.test_epochs, False)
            print(
                "[TRAIN] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                    datetime.datetime.now(),
                    "ep", ep,
                    colored("train", "red"),
                    colored("acc:", "blue"), acc, std,
                ), flush=True)

            acc = 0
            loss = 0

            # Evaluate validation accuracy
            cur_acc, cur_std = test(val_data, class_names, optG, optCLF, net,
                                    args, args.test_epochs, False)
            print(("[EVAL] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
                   "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
                datetime.datetime.now(),
                "ep", ep,
                colored("val ", "cyan"),
                colored("acc:", "blue"), cur_acc, cur_std,
                colored("train stats", "cyan"),
                colored("G_grad:", "blue"), np.mean(np.array(grad['G'])),
                colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
                ), flush=True)

            # Update the current best model if val acc is better
            if cur_acc > best_acc:
                best_acc = cur_acc
                best_path = os.path.join(out_dir, str(ep))

                # save current model
                print("{}, Save cur best model to {}".format(
                    datetime.datetime.now(), best_path))

                torch.save(model['G'].state_dict(), best_path + '.G')
                torch.save(model['G2'].state_dict(), best_path + '.G2')
                torch.save(model['clf'].state_dict(), best_path + '.clf')

                sub_cycle = 0
            else:
                sub_cycle += 1

            # Break if the val acc hasn't improved in the past patience epochs
            if sub_cycle == args.patience:
                break

            if args.lr_scheduler == 'ReduceLROnPlateau':
                schedulerG.step(cur_acc)
                schedulerCLF.step(cur_acc)
            elif args.lr_scheduler == 'ExponentialLR':
                schedulerG.step()
                schedulerCLF.step()

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now()), flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))
    model['G2'].load_state_dict(torch.load(best_path + '.G2'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(datetime.datetime.now(),
                                                 best_path), flush=True)

        torch.save(model['G'].state_dict(), best_path + '.G')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return optG, optCLF
def train(train_data, val_data, model, args):
    '''
        Train the model
        Use val_data to do early stopping
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(os.path.join(
        os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    opt = torch.optim.Adam(grad_param(model, ['clf']), args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, 'max', patience=args.patience // 2, factor=0.1, verbose=True)

    print("{}, Start training".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')),
        flush=True)

    train_gen = ParallelSampler(train_data, args, args.train_episodes)
    train_gen_val = ParallelSampler(train_data, args, args.val_episodes)
    val_gen = ParallelSampler(val_data, args, args.val_episodes)

    for ep in range(args.train_epochs):
        sampled_tasks = train_gen.get_epoch()

        grad = {'clf': [], 'ebd': []}

        sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                             ncols=80, leave=False,
                             desc=colored('Training on train', 'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            train_one(task, model, opt, args, grad)

        acc, std = test(train_data, model, args, args.val_episodes, False,
                        train_gen_val.get_epoch())
        print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            "ep", ep,
            colored("train", "red"),
            colored("acc:", "blue"), acc, std,
            ), flush=True)

        # Evaluate validation accuracy
        cur_acc, cur_std = test(val_data, model, args, args.val_episodes,
                                False, val_gen.get_epoch())
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            "ep", ep,
            colored("val ", "cyan"),
            colored("acc:", "blue"), cur_acc, cur_std,
            colored("train stats", "cyan"),
            colored("ebd_grad:", "blue"), np.mean(np.array(grad['ebd'])),
            colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
            ), flush=True)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                best_path))

            torch.save(model['ebd'].state_dict(), best_path + '.ebd')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past 20 epochs
        # (patience is hard-coded here, unlike the other train variants)
        if sub_cycle == 20:
            break

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')),
        flush=True)

    # restore the best saved model
    model['ebd'].load_state_dict(torch.load(best_path + '.ebd'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    # save the current model
    out_dir = os.path.abspath(os.path.join(
        os.path.curdir, "saved-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_path = os.path.join(out_dir, 'best')

    print("{}, Save best model to {}".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
        best_path), flush=True)

    torch.save(model['ebd'].state_dict(), best_path + '.ebd')
    torch.save(model['clf'].state_dict(), best_path + '.clf')

    return
def train(train_data, val_data, model, args):
    '''
        Train the model
        Use val_data to do early stopping

        Args:
            model (dict): {'ebd': embedding, 'clf': classifier}
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(os.path.join(
        os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Write results
    # write_acc_tr = 'acc_base.csv'
    # init_csv(write_acc_tr)
    # write_acc_val = 'val_acc_base.csv'
    # init_csv(write_acc_val)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    # grad_param generates the learnable parameters from the classifier
    params_to_opt = grad_param(model, ['ebd', 'clf'])
    opt = torch.optim.Adam(params_to_opt, lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, 'max', patience=args.patience // 2, factor=0.1, verbose=True)

    print("{}, Start training".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')),
        flush=True)

    train_gen = ParallelSampler(train_data, args, args.train_episodes)
    train_gen_val = ParallelSampler(train_data, args, args.val_episodes)
    val_gen = ParallelSampler(val_data, args, args.val_episodes)

    for ep in range(args.train_epochs):
        sampled_tasks = train_gen.get_epoch()

        grad = {'clf': [], 'ebd': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                                 ncols=80, leave=False,
                                 desc=colored('Training on train', 'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            train_one(task, model, opt, args, grad)

        if ep % 10 == 0:
            acc, std = test(train_data, model, args, args.val_episodes,
                            False, train_gen_val.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                "ep", ep,
                colored("train", "red"),
                colored("acc:", "blue"), acc, std,
                ), flush=True)
            # write_csv(write_acc_tr, acc, std, ep)

        # Evaluate validation accuracy
        cur_acc, cur_std = test(val_data, model, args, args.val_episodes,
                                False, val_gen.get_epoch())
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            "ep", ep,
            colored("val ", "cyan"),
            colored("acc:", "blue"), cur_acc, cur_std,
            colored("train stats", "cyan"),
            colored("ebd_grad:", "blue"), np.mean(np.array(grad['ebd'])),
            colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
            ), flush=True)
        # if ep % 10 == 0:
        #     write_csv(write_acc_val, cur_acc, cur_std, ep)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                best_path))

            torch.save(model['ebd'].state_dict(), best_path + '.ebd')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')),
        flush=True)

    # restore the best saved model
    model['ebd'].load_state_dict(torch.load(best_path + '.ebd'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(os.path.join(
            os.path.curdir, "saved-runs", str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            best_path), flush=True)

        torch.save(model['ebd'].state_dict(), best_path + '.ebd')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return