def main(args): """ The main training function. Only works for single node (be it single or multi-GPU) Parameters ---------- args : Parsed arguments """ # setup ngpus = torch.cuda.device_count() if ngpus == 0: raise RuntimeWarning("This will not be able to run on CPU only") print(f"Working with {ngpus} GPUs") if args.optim.lower() == "ranger": # No warm up if ranger optimizer args.warm = 0 current_experiment_time = datetime.now().strftime('%Y%m%d_%T').replace(":", "") args.exp_name = f"{'debug_' if args.debug else ''}{current_experiment_time}_" \ f"_fold{args.fold if not args.full else 'FULL'}" \ f"_{args.arch}_{args.width}" \ f"_batch{args.batch_size}" \ f"_optim{args.optim}" \ f"_{args.optim}" \ f"_lr{args.lr}-wd{args.weight_decay}_epochs{args.epochs}_deepsup{args.deep_sup}" \ f"_{'fp16' if not args.no_fp16 else 'fp32'}" \ f"_warm{args.warm}_" \ f"_norm{args.norm_layer}{'_swa' + str(args.swa_repeat) if args.swa else ''}" \ f"_dropout{args.dropout}" \ f"_warm_restart{args.warm_restart}" \ f"{'_' + args.com.replace(' ', '_') if args.com else ''}" args.save_folder = pathlib.Path(f"./runs/{args.exp_name}") args.save_folder.mkdir(parents=True, exist_ok=True) args.seg_folder = args.save_folder / "segs" args.seg_folder.mkdir(parents=True, exist_ok=True) args.save_folder = args.save_folder.resolve() save_args(args) t_writer = SummaryWriter(str(args.save_folder)) # Create model print(f"Creating {args.arch}") model_maker = getattr(models, args.arch) model = model_maker( 4, 3, width=args.width, deep_supervision=args.deep_sup, norm_layer=get_norm_layer(args.norm_layer), dropout=args.dropout) print(f"total number of trainable parameters {count_parameters(model)}") if args.swa: # Create the average model swa_model = model_maker( 4, 3, width=args.width, deep_supervision=args.deep_sup, norm_layer=get_norm_layer(args.norm_layer)) for param in swa_model.parameters(): param.detach_() swa_model = swa_model.cuda() swa_model_optim = WeightSWA(swa_model) if ngpus > 1: model = torch.nn.DataParallel(model).cuda() else: model = model.cuda() print(model) model_file = args.save_folder / "model.txt" with model_file.open("w") as f: print(model, file=f) criterion = EDiceLoss().cuda() metric = criterion.metric print(metric) rangered = False # needed because LR scheduling scheme is different for this optimizer if args.optim == "adam": optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=1e-4) elif args.optim == "sgd": optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9, nesterov=True) elif args.optim == "adamw": print(f"weight decay argument will not be used. 
Default is 11e-2") optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) elif args.optim == "ranger": optimizer = Ranger(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) rangered = True # optionally resume from a checkpoint if args.resume: reload_ckpt(args, model, optimizer) if args.debug: args.epochs = 2 args.warm = 0 args.val = 1 if args.full: train_dataset, bench_dataset = get_datasets(args.seed, args.debug, full=True) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=False, drop_last=True) bench_loader = torch.utils.data.DataLoader( bench_dataset, batch_size=1, num_workers=args.workers) else: train_dataset, val_dataset, bench_dataset = get_datasets(args.seed, args.debug, fold_number=args.fold) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=False, drop_last=True) val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=max(1, args.batch_size // 2), shuffle=False, pin_memory=False, num_workers=args.workers, collate_fn=determinist_collate) bench_loader = torch.utils.data.DataLoader( bench_dataset, batch_size=1, num_workers=args.workers) print("Val dataset number of batch:", len(val_loader)) print("Train dataset number of batch:", len(train_loader)) # create grad scaler scaler = GradScaler() # Actual Train loop best = np.inf print("start warm-up now!") if args.warm != 0: tot_iter_train = len(train_loader) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda cur_iter: (1 + cur_iter) / (tot_iter_train * args.warm)) patients_perf = [] if not args.resume: for epoch in range(args.warm): ts = time.perf_counter() model.train() training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer, scaler, scheduler, save_folder=args.save_folder, no_fp16=args.no_fp16, patients_perf=patients_perf) te = time.perf_counter() print(f"Train Epoch done in {te - ts} s") # Validate at the end of epoch every val step if (epoch + 1) % args.val == 0 and not args.full: model.eval() with torch.no_grad(): validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer, save_folder=args.save_folder, no_fp16=args.no_fp16) t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, epoch) if args.warm_restart: print('Total number of epochs should be divisible by 30, else it will do odd things') scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 30, eta_min=1e-7) else: scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs + 30 if not rangered else round( args.epochs * 0.5)) print("start training now!") if args.swa: # c = 15, k=3, repeat = 5 c, k, repeat = 30, 3, args.swa_repeat epochs_done = args.epochs reboot_lr = 0 if args.debug: c, k, repeat = 2, 1, 2 for epoch in range(args.start_epoch + args.warm, args.epochs + args.warm): try: # do_epoch for one epoch ts = time.perf_counter() model.train() training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer, scaler, save_folder=args.save_folder, no_fp16=args.no_fp16, patients_perf=patients_perf) te = time.perf_counter() print(f"Train Epoch done in {te - ts} s") # Validate at the end of epoch every val step if (epoch + 1) % args.val == 0 and not args.full: model.eval() with torch.no_grad(): validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer, 
epoch, t_writer, save_folder=args.save_folder, no_fp16=args.no_fp16, patients_perf=patients_perf) t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, epoch) if validation_loss < best: best = validation_loss model_dict = model.state_dict() save_checkpoint( dict( epoch=epoch, arch=args.arch, state_dict=model_dict, optimizer=optimizer.state_dict(), scheduler=scheduler.state_dict(), ), save_folder=args.save_folder, ) ts = time.perf_counter() print(f"Val epoch done in {ts - te} s") if args.swa: if (args.epochs - epoch - c) == 0: reboot_lr = optimizer.param_groups[0]['lr'] if not rangered: scheduler.step() print("scheduler stepped!") else: if epoch / args.epochs > 0.5: scheduler.step() print("scheduler stepped!") except KeyboardInterrupt: print("Stopping training loop, doing benchmark") break if args.swa: swa_model_optim.update(model) print("SWA Model initialised!") for i in range(repeat): optimizer = torch.optim.Adam(model.parameters(), args.lr / 2, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, c + 10) for swa_epoch in range(c): # do_epoch for one epoch ts = time.perf_counter() model.train() swa_model.train() current_epoch = epochs_done + i * c + swa_epoch training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, current_epoch, t_writer, scaler, no_fp16=args.no_fp16, patients_perf=patients_perf) te = time.perf_counter() print(f"Train Epoch done in {te - ts} s") t_writer.add_scalar(f"SummaryLoss/train", training_loss, current_epoch) # update every k epochs and val: print(f"cycle number: {i}, swa_epoch: {swa_epoch}, total_cycle_to_do {repeat}") if (swa_epoch + 1) % k == 0: swa_model_optim.update(model) if not args.full: model.eval() swa_model.eval() with torch.no_grad(): validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer, current_epoch, t_writer, save_folder=args.save_folder, no_fp16=args.no_fp16) swa_model_loss = step(val_loader, swa_model, criterion, metric, args.deep_sup, optimizer, current_epoch, t_writer, swa=True, save_folder=args.save_folder, no_fp16=args.no_fp16) t_writer.add_scalar(f"SummaryLoss/val", validation_loss, current_epoch) t_writer.add_scalar(f"SummaryLoss/swa", swa_model_loss, current_epoch) t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, current_epoch) t_writer.add_scalar(f"SummaryLoss/overfit_swa", swa_model_loss - training_loss, current_epoch) scheduler.step() epochs_added = c * repeat save_checkpoint( dict( epoch=args.epochs + epochs_added, arch=args.arch, state_dict=swa_model.state_dict(), optimizer=optimizer.state_dict() ), save_folder=args.save_folder, ) else: save_checkpoint( dict( epoch=args.epochs, arch=args.arch, state_dict=model.state_dict(), optimizer=optimizer.state_dict() ), save_folder=args.save_folder, ) try: df_individual_perf = pd.DataFrame.from_records(patients_perf) print(df_individual_perf) df_individual_perf.to_csv(f'{str(args.save_folder)}/patients_indiv_perf.csv') reload_ckpt_bis(f'{str(args.save_folder)}/model_best.pth.tar', model) generate_segmentations(bench_loader, model, t_writer, args) except KeyboardInterrupt: print("Stopping right now!")
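# --- Hedged illustration (not part of the original training script) ---
# main() above wraps the optimizer in a LambdaLR whose multiplier grows
# linearly with the iteration count during the warm-up epochs. The minimal,
# self-contained sketch below isolates that schedule; the toy model, the
# iteration counts and the base learning rate are assumptions made purely
# for demonstration.
def _warmup_schedule_demo(tot_iter_train=100, warm_epochs=3, base_lr=1e-4):
    toy_model = torch.nn.Linear(4, 3)  # stand-in for the segmentation network
    opt = torch.optim.Adam(toy_model.parameters(), lr=base_lr)
    # Same lambda as in main(): the LR ramps from ~0 up to base_lr over
    # warm_epochs * tot_iter_train scheduler steps.
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt, lambda cur_iter: (1 + cur_iter) / (tot_iter_train * warm_epochs))
    lrs = []
    for _ in range(tot_iter_train * warm_epochs):
        opt.step()    # a real run would compute a loss and backprop first
        sched.step()  # one scheduler step per training iteration
        lrs.append(opt.param_groups[0]['lr'])
    return lrs  # linearly increasing, ending at (approximately) base_lr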
def train(args): """ :param args: :return: """ grammar = semQL.Grammar() sql_data, table_data, val_sql_data, val_table_data = utils.load_dataset( args.dataset, use_small=args.toy) model = IRNet(args, grammar) if args.cuda: model.cuda() # now get the optimizer optimizer_cls = eval('torch.optim.%s' % args.optimizer) optimizer = optimizer_cls(model.parameters(), lr=args.lr) print('Enable Learning Rate Scheduler: ', args.lr_scheduler) if args.lr_scheduler: scheduler = optim.lr_scheduler.MultiStepLR( optimizer, milestones=[21, 41], gamma=args.lr_scheduler_gammar) else: scheduler = None print('Loss epoch threshold: %d' % args.loss_epoch_threshold) print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient) if args.load_model: print('load pretrained model from %s' % (args.load_model)) pretrained_model = torch.load( args.load_model, map_location=lambda storage, loc: storage) pretrained_modeled = copy.deepcopy(pretrained_model) for k in pretrained_model.keys(): if k not in model.state_dict().keys(): del pretrained_modeled[k] model.load_state_dict(pretrained_modeled) model.word_emb = utils.load_word_emb(args.glove_embed_path) # begin train model_save_path = utils.init_log_checkpoint_path(args) utils.save_args(args, os.path.join(model_save_path, 'config.json')) best_dev_acc = .0 try: with open(os.path.join(model_save_path, 'epoch.log'), 'w') as epoch_fd: for epoch in tqdm.tqdm(range(args.epoch)): if args.lr_scheduler: scheduler.step() epoch_begin = time.time() loss = utils.epoch_train( model, optimizer, args.batch_size, sql_data, table_data, args, loss_epoch_threshold=args.loss_epoch_threshold, sketch_loss_coefficient=args.sketch_loss_coefficient) epoch_end = time.time() json_datas, sketch_acc, acc, counts, corrects = utils.epoch_acc( model, args.batch_size, val_sql_data, val_table_data, beam_size=args.beam_size) # acc = utils.eval_acc(json_datas, val_sql_data) if acc > best_dev_acc: utils.save_checkpoint( model, os.path.join(model_save_path, 'best_model.model')) best_dev_acc = acc utils.save_checkpoint( model, os.path.join(model_save_path, '{%s}_{%s}.model') % (epoch, acc)) log_str = 'Epoch: %d, Loss: %f, Sketch Acc: %f, Acc: %f, time: %f\n' % ( epoch + 1, loss, sketch_acc, acc, epoch_end - epoch_begin) tqdm.tqdm.write(log_str) epoch_fd.write(log_str) epoch_fd.flush() except Exception as e: # Save model utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model')) print(e) tb = traceback.format_exc() print(tb) else: utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model')) json_datas, sketch_acc, acc, counts, corrects = utils.epoch_acc( model, args.batch_size, val_sql_data, val_table_data, beam_size=args.beam_size) # acc = utils.eval_acc(json_datas, val_sql_data) print("Sketch Acc: %f, Acc: %f, Beam Acc: %f" % ( sketch_acc, acc, acc, ))
def train(args, model, optimizer, bert_optimizer, data):
    '''
    :param args:
    :param model:
    :param data:
    :param grammar:
    :return:
    '''
    print('Loss epoch threshold: %d' % args.loss_epoch_threshold)
    print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient)

    if args.load_model and not args.resume:
        print('load pretrained model from %s' % (args.load_model))
        pretrained_model = torch.load(
            args.load_model, map_location=lambda storage, loc: storage)
        pretrained_modeled = copy.deepcopy(pretrained_model)
        for k in pretrained_model.keys():
            if k not in model.state_dict().keys():
                del pretrained_modeled[k]
        model.load_state_dict(pretrained_modeled)

    # ==============data==============
    if args.interaction_level:
        batch_size = 1
        train_batchs, train_sample_batchs = data.get_interaction_batches(
            batch_size, sample_num=args.train_evaluation_size, use_small=args.toy)
        valid_batchs = data.get_all_interactions(
            data.val_sql_data, data.val_table_data, _type='test', use_small=args.toy)
    else:
        batch_size = args.batch_size
        train_batchs = data.get_utterance_batches(batch_size)
        valid_batchs = data.get_all_utterances(data.val_sql_data)

    print(len(train_batchs), len(train_sample_batchs), len(valid_batchs))

    start_epoch = 1
    best_question_match = .0
    lr = args.initial_lr
    stage = 1
    if args.resume:
        model_save_path = utils.init_log_checkpoint_path(args)
        current_w = torch.load(
            os.path.join(model_save_path, args.current_model_name))
        best_w = torch.load(os.path.join(model_save_path, args.best_model_name))
        best_question_match = best_w['question_match']
        start_epoch = current_w['epoch'] + 1
        lr = current_w['lr']
        utils.adjust_learning_rate(optimizer, lr)
        stage = current_w['stage']
        model.load_state_dict(current_w['state_dict'])
        # If the interruption happened exactly at a stage-transition epoch
        if start_epoch - 1 in args.stage_epoch:
            stage += 1
            lr /= args.lr_decay
            utils.adjust_learning_rate(optimizer, lr)
            model.load_state_dict(best_w['state_dict'])
        print("=> Loading resume model from epoch {} ...".format(start_epoch - 1))

    # model.word_emb = utils.load_word_emb(args.glove_embed_path, use_small=args.use_small)

    # begin train
    model_save_path = utils.init_log_checkpoint_path(args)
    utils.save_args(args, os.path.join(model_save_path, 'config.json'))

    file_mode = 'a' if args.resume else 'w'
    log = Logger(os.path.join(args.save, args.logfile), file_mode)
    # log_pred_gt = Logger(os.path.join(args.save, args.log_pred_gt), file_mode)

    with open(os.path.join(model_save_path, 'epoch.log'), 'w') as epoch_fd:
        for epoch in range(start_epoch, args.epoch + 1):
            epoch_begin = time.time()
            # model.set_dropout(args.dropout_amount)
            model.dropout_ratio = args.dropout_amount
            if args.interaction_level:
                loss = utils.epoch_train_with_interaction(
                    epoch, log, model, optimizer, bert_optimizer, train_batchs, args)
                # loss = 2.
            else:
                pass
            model.dropout_ratio = 0.
            # model.set_dropout(0.)
            epoch_end = time.time()

            s = time.time()
            sample_sketch_acc, sample_lf_acc, sample_interaction_lf_acc = utils.epoch_acc_with_interaction(
                epoch, model, train_sample_batchs, args,
                beam_size=args.beam_size, use_predicted_queries=True)
            log_str = '[Epoch: %d(sample predicted), Sample ratio[%d]: %f], Sketch Acc: %f, Acc: %f, Interaction Acc: %f, Train time: %f, Sample predict time: %f\n' % (
                epoch, len(train_sample_batchs), len(train_sample_batchs) / len(train_batchs),
                sample_sketch_acc, sample_lf_acc, sample_interaction_lf_acc,
                epoch_end - epoch_begin, time.time() - s)
            print(log_str)
            log.put(log_str)
            epoch_fd.write(log_str)
            epoch_fd.flush()

            # s = time.time()
            #
            # gold_sketch_acc, gold_lf_acc, gold_interaction_lf_acc = utils.epoch_acc_with_interaction(
            #     epoch, model, valid_batchs, args, beam_size=args.beam_size)
            #
            # log_str = '[Epoch: %d(gold)], Loss: %f, Sketch Acc: %f, Acc: %f, Interaction Acc: %f, Gold predict time: %f\n' % (
            #     epoch, loss, gold_sketch_acc, gold_lf_acc, gold_interaction_lf_acc, time.time() - s)
            # print(log_str)
            # log.put(log_str)
            # epoch_fd.write(log_str)
            # epoch_fd.flush()

            s = time.time()
            valid_jsonf, pred_sketch_acc, pred_lf_acc, pred_interaction_lf_acc = utils.epoch_acc_with_interaction_save_json(
                epoch, model, valid_batchs, args,
                beam_size=args.beam_size, use_predicted_queries=True)
            question_match, interaction_match = utils.semQL2SQL_question_and_interaction_match(
                valid_jsonf, args)
            log_str = '[Epoch: %d(predicted)], Loss: %f, lr: %.3e, Sketch Acc: %f, Acc: %f, Interaction Acc: %f, Question Match: %f, Interaction Match: %f, Predicted predict time: %f\n\n' % (
                epoch, loss, optimizer.param_groups[0]["lr"],
                pred_sketch_acc, pred_lf_acc, pred_interaction_lf_acc,
                question_match, interaction_match, time.time() - s)
            print(log_str)
            log.put(log_str)
            epoch_fd.write(log_str)
            epoch_fd.flush()

            state = {
                "state_dict": model.state_dict(),
                "epoch": epoch,
                "question_match": question_match,
                "interaction_match": interaction_match,
                "lr": lr,
                'stage': stage
            }
            current_w_name = os.path.join(
                model_save_path, '{}_{:.3f}.pth'.format(epoch, question_match))
            best_w_name = os.path.join(model_save_path, args.best_model_name)
            current_w_name2 = os.path.join(model_save_path, args.current_model_name)
            utils.save_ckpt(state, best_question_match < question_match,
                            current_w_name, best_w_name, current_w_name2)
            best_question_match = max(best_question_match, question_match)

            if epoch in args.stage_epoch:
                stage += 1
                lr /= args.lr_decay
                best_w_name = os.path.join(model_save_path, args.best_model_name)
                model.load_state_dict(torch.load(best_w_name)['state_dict'])
                print("*" * 10, "step into stage%02d lr %.3e" % (stage, lr))
                utils.adjust_learning_rate(optimizer, lr)

    log.put("Finished training!")
    log.close()
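# --- Hedged illustration (not part of the original interaction-level trainer) ---
# The stage transitions above divide the learning rate by args.lr_decay, reload
# the best weights, and push the new value into the optimizer via
# utils.adjust_learning_rate, which is not shown in this file. The helper and
# loop below are only an assumption about how that staged decay behaves; the
# toy model, epochs, and decay factor are made up for the demonstration.
def _adjust_learning_rate_sketch(optimizer, new_lr):
    # Overwrite the lr of every parameter group in place.
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

def _stage_decay_demo(initial_lr=1e-3, lr_decay=2.0, stage_epochs=(10, 20), total_epochs=25):
    toy_model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(toy_model.parameters(), lr=initial_lr)
    lr = initial_lr
    for epoch in range(1, total_epochs + 1):
        # ... train one epoch, evaluate, save checkpoints ...
        if epoch in stage_epochs:
            lr /= lr_decay  # same update rule as the training loop above
            _adjust_learning_rate_sketch(optimizer, lr)
    return optimizer.param_groups[0]['lr']  # 1e-3 / (2.0 ** 2) after both stages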