def train(info: TrainInformation, split, fold):
    """Train and evaluate a classifier for the given cross-validation split/fold.

    Trains for ``info.EPOCH`` epochs (checkpointing every epoch via
    ``train_step``), reloads the checkpoint with the best test AUC, recomputes
    test AUC/PR-AUC, plots the ROC curve, and writes per-variable contribution
    scores next to the checkpoint.

    Args:
        info: TrainInformation carrying all hyperparameters and result storage.
        split: cross-validation split index (also stored on the datasets).
        fold: fold index within the split.

    Returns:
        The accumulated TrainResult.
    """
    bs = info.BS
    init_lr = info.INIT_LR
    lr_decay = info.LR_DECAY
    momentum = info.MOMENTUM
    weight_decay = info.WEIGHT_DECAY
    optimizer_method = info.OPTIMIZER_METHOD
    epoch = info.EPOCH
    nchs = info.NCHS
    filename = info.FILENAME
    model_name = info.MODEL_NAME
    exp_name = info.NAME

    print("Using File {}".format(filename))
    train_dataset = Dataset(
        split=split,
        fold=fold,
        phase="train",
        filename=filename,
        use_data_dropout=info.USE_DATA_DROPOUT,
    )
    test_dataset = Dataset(
        split=split,
        fold=fold,
        phase="test",
        filename=filename,
        use_data_dropout=False,
    )

    model = get_classifier_model(
        model_name, train_dataset.feature_size, nchs, info.ACTIVATION
    )
    print(model)

    # Configure the optimizer.
    optimizer = set_optimizer(
        optimizer_method, model, init_lr, weight_decay, momentum=momentum
    )

    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=True, num_workers=0, drop_last=True
    )
    bce_loss = torch.nn.BCEWithLogitsLoss().cuda()

    train_result = TrainResult()
    train_result.set_sizes(len(train_dataset.data), 0, len(test_dataset.data))

    # BUG FIX: prev_plot used to be reset to 0 inside the epoch loop, which
    # clobbered the `prev_plot = ep` update made by train_step and defeated
    # its "plot at most once per 10 epochs" throttle. Initialize it once.
    global prev_plot
    prev_plot = 0
    for ep in range(epoch):
        train_step(
            exp_name,
            ep,
            model,
            train_dataset,
            test_dataset,
            optimizer,
            init_lr,
            lr_decay,
            data_loader,
            bce_loss,
            train_result,
        )

    savedir = "/content/drive/My Drive/research/frontiers/checkpoints/%s" % exp_name
    best_test_epoch = train_result.best_test_epoch
    savepath = "%s/epoch_%04d_fold_%02d.pt" % (
        savedir,
        best_test_epoch,
        train_dataset.split,
    )
    # NOTE(review): torch.load unpickles arbitrary objects — only ever load
    # checkpoints this pipeline produced itself. Saving/loading a state_dict
    # would be safer, but requires changing the save side in train_step too.
    model = torch.load(savepath)
    model.eval()

    # Dataset layout: column 0 is the label, columns 1: are the features
    # (established by the slicing below — confirm against Dataset).
    test_preds = train_utils.get_preds(test_dataset.data[:, 1:], model)
    test_AUC = train_utils.compute_AUC(test_dataset.data[:, :1], test_preds)
    test_PRAUC = train_utils.compute_PRAUC(test_dataset.data[:, :1], test_preds)
    train_utils.plot_AUC(
        test_dataset,
        test_preds,
        test_AUC,
        savepath=savepath.replace(".pt", "_AUC.tiff"),
    )

    contributing_variables = compute_contributing_variables(model, test_dataset)
    with open(
        os.path.join(
            savedir,
            "contributing_variables_epoch_%04d_fold_%02d.txt"
            % (best_test_epoch, train_dataset.split),
        ),
        "w",
    ) as f:
        for (v, auc) in contributing_variables:
            f.write("%s %f\n" % (v, auc))

    info.split_index = split
    info.result_dict = train_result
    info.save_result()
    return train_result
def train_step(
    exp_name,
    ep,
    model,
    train_dataset,
    test_dataset,
    optimizer,
    init_lr,
    lr_decay,
    data_loader,
    bce_loss,
    train_result: TrainResult,
):
    """Run one training epoch, checkpoint the model, and log metrics.

    Side effects: updates ``train_result`` (loss EMA, iteration count,
    best-test-AUC bookkeeping), decays the optimizer learning rate for the
    following epoch, saves the full model into the experiment checkpoint
    directory, and updates the module-level ``prev_plot`` plot throttle.
    """
    global prev_plot
    model.train(True)

    for X, y in data_loader:
        optimizer.zero_grad()
        pred_out = model(X.cuda()).view(X.shape[0])
        loss = bce_loss(pred_out, y.cuda())
        loss.backward()
        # BUG FIX: the loss EMA was computed from train_result.avg_loss but
        # never written back, so every iteration averaged against the stale
        # initial value instead of accumulating. Persist the running average.
        avg_loss = (
            train_result.avg_loss * 0.98 + loss.detach().cpu().numpy() * 0.02
        )
        train_result.avg_loss = avg_loss
        optimizer.step()

        train_result.total_iter += len(y)
        # NOTE(review): this only fires when the batch size divides 10000
        # exactly; with other batch sizes the periodic log is skipped.
        if train_result.total_iter % 10000 == 0:
            print(
                "Loss Iter %05d: %.4f\r" % (train_result.total_iter, avg_loss),
                end="",
            )
            train_result.loss_list.append(
                (train_result.total_iter, "{:.4f}".format(avg_loss))
            )
    print("")

    # Exponential LR decay; set after the epoch's batches, so it takes effect
    # from the NEXT epoch onwards.
    lr = init_lr * (lr_decay ** ep)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    print("Learning rate = %f" % lr)

    (
        train_AUC,
        test_AUC,
        test_PRAUC,
        train_accuracy,
        test_accuracy,
        test_preds,
        test_TP,
        test_TN,
        test_FN,
        test_FP,
    ) = print_metrics(model, train_dataset, test_dataset, train_result)

    savedir = "/content/drive/My Drive/research/frontiers/checkpoints/%s" % exp_name
    os.makedirs(savedir, exist_ok=True)
    split = train_dataset.split
    savepath = "%s/epoch_%04d_fold_%02d.pt" % (savedir, ep, split)
    torch.save(model, savepath)

    if train_result.best_test_AUC < test_AUC:
        train_result.best_test_AUC = test_AUC
        train_result.best_test_epoch = ep
        # Throttle: only mark a new plot epoch when 10+ epochs have passed
        # since the last one (plotting itself is currently disabled).
        if ep - prev_plot > 10:
            prev_plot = ep

    print(
        "Epoch %03d: test_AUC: %.4f (best: %.4f epoch: %d), train_AUC: %.4f"
        % (
            ep,
            test_AUC,
            train_result.best_test_AUC,
            train_result.best_test_epoch,
            train_AUC,
        )
    )
    print(
        " test_accuracy {:.4f}, train_accuracy {:.4f}".format(
            test_accuracy,
            train_accuracy,
        )
    )
    print(
        " test_TP {}, test_TN {}, test_FN {}, test_FP {},".format(
            test_TP, test_TN, test_FN, test_FP
        )
    )