Example #1
0
def print_metrics(model, train_dataset, test_dataset, train_result, train_sample_size=1000):
    """Evaluate ``model`` on the test set and a train subset; record test metrics.

    For both datasets, column 0 of ``dataset.data`` is used as the label and
    columns ``1:`` as the model input (see the slicing below).

    Despite its name, this function prints nothing: it appends the
    ``"%.04f"``-formatted test metrics to ``train_result`` and returns the raw
    values.

    Args:
        model: Model to evaluate; it is switched to evaluation mode.
        train_dataset: Dataset whose first ``train_sample_size`` rows are scored.
        test_dataset: Dataset that is scored in full.
        train_result: Accumulator whose ``test_*_list`` attributes receive the
            formatted test metrics.
        train_sample_size: Number of leading training rows to evaluate
            (defaults to 1000).

    Returns:
        Tuple ``(train_AUC, test_AUC, test_PRAUC, train_accuracy,
        test_accuracy, test_preds, test_TP, test_TN, test_FN, test_FP)``.
    """
    model.eval()  # equivalent to model.train(False)

    # --- Test set: full evaluation --------------------------------------
    test_labels = test_dataset.data[:, :1]
    test_preds = train_utils.get_preds(test_dataset.data[:, 1:], model)
    test_AUC = train_utils.compute_AUC(test_labels, test_preds)
    test_PRAUC = train_utils.compute_PRAUC(test_labels, test_preds)
    test_accuracy = train_utils.compute_accuracy(test_labels, test_preds)
    test_TP, test_TN, test_FN, test_FP = train_utils.compute_confusion(test_labels, test_preds)

    # --- Train set: only a leading subset is scored (presumably to keep
    # per-epoch evaluation cheap) -----------------------------------------
    train_labels = train_dataset.data[:train_sample_size, :1]
    train_preds = train_utils.get_preds(train_dataset.data[:train_sample_size, 1:], model)
    train_AUC = train_utils.compute_AUC(train_labels, train_preds)
    train_accuracy = train_utils.compute_accuracy(train_labels, train_preds)

    # Each metric goes into its own list on train_result. TrainResult is
    # defined elsewhere, so any list it does not already provide is created
    # on first use instead of raising AttributeError.
    recorded = (
        ("test_AUC_list", test_AUC),
        ("test_PRAUC_list", test_PRAUC),
        ("test_accuracy_list", test_accuracy),
        ("test_TP_list", test_TP),
        ("test_TN_list", test_TN),
        ("test_FN_list", test_FN),
        ("test_FP_list", test_FP),
    )
    for list_name, value in recorded:
        if not hasattr(train_result, list_name):
            setattr(train_result, list_name, [])
        getattr(train_result, list_name).append("%.04f" % value)

    return train_AUC, test_AUC, test_PRAUC, train_accuracy, test_accuracy, test_preds, test_TP, test_TN, test_FN, test_FP
Example #2
0
def train(info: TrainInformation, split, fold):
    """Train and test a model on the given split/fold.

    Hyperparameters are read from ``info``. After ``info.EPOCH`` calls to
    ``train_step``, the checkpoint of ``train_result.best_test_epoch`` is
    reloaded from the experiment's checkpoint directory. The function then:

    - passes the reloaded model's test predictions to ``train_utils.plot_AUC``,
      saved as a ``.tiff`` next to the checkpoint;
    - writes the output of ``compute_contributing_variables`` to a text file in
      the same directory;
    - stores the result on ``info`` and calls ``info.save_result()``.

    Args:
        info: Experiment configuration and result sink.
        split: Split index passed to ``Dataset``; also stored as
            ``info.split_index``.
        fold: Fold index passed to ``Dataset``.

    Returns:
        The ``TrainResult`` populated during training.
    """
    # Module-level counter reset before every epoch; presumably read/updated
    # by train_step.
    global prev_plot

    bs = info.BS
    init_lr = info.INIT_LR
    lr_decay = info.LR_DECAY
    momentum = info.MOMENTUM
    weight_decay = info.WEIGHT_DECAY
    optimizer_method = info.OPTIMIZER_METHOD
    epoch = info.EPOCH
    nchs = info.NCHS
    filename = info.FILENAME
    model_name = info.MODEL_NAME
    exp_name = info.NAME

    print("Using File {}".format(filename))

    train_dataset = Dataset(split=split, fold=fold, phase="train", filename=filename, use_data_dropout=info.USE_DATA_DROPOUT)
    test_dataset = Dataset(split=split, fold=fold, phase="test", filename=filename, use_data_dropout=False)

    model = get_classifier_model(model_name, train_dataset.feature_size, nchs, info.ACTIVATION)
    print(model)

    # Optimizer setup
    optimizer = set_optimizer(
        optimizer_method, model, init_lr, weight_decay, momentum=momentum
    )

    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=True, num_workers=0, drop_last=True
    )

    # .cuda() makes this function require a CUDA device.
    bce_loss = torch.nn.BCEWithLogitsLoss().cuda()
    train_result = TrainResult()
    # The middle size is 0 — presumably the (unused) validation set.
    train_result.set_sizes(
        len(train_dataset.data), 0, len(test_dataset.data)
    )

    for ep in range(epoch):
        prev_plot = 0
        train_step(
            exp_name,
            ep,
            model,
            train_dataset,
            test_dataset,
            optimizer,
            init_lr,
            lr_decay,
            data_loader,
            bce_loss,
            train_result,
        )

    savedir = "/content/drive/My Drive/research/frontiers/checkpoints/%s" % exp_name
    best_test_epoch = train_result.best_test_epoch
    savepath = "%s/epoch_%04d_fold_%02d.pt" % (savedir, best_test_epoch, train_dataset.split)
    # The checkpoint is loaded as a whole model object (not a state_dict).
    # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True), which
    # rejects full-module checkpoints; on such versions pass
    # weights_only=False (acceptable here only because the file is produced
    # by this training run).
    model = torch.load(savepath)
    model.eval()

    test_preds = train_utils.get_preds(test_dataset.data[:, 1:], model)
    test_AUC = train_utils.compute_AUC(test_dataset.data[:, :1], test_preds)

    # Swap only the file extension so a ".pt" elsewhere in the path is untouched.
    plot_path = os.path.splitext(savepath)[0] + "_AUC.tiff"
    train_utils.plot_AUC(test_dataset, test_preds, test_AUC, savepath=plot_path)

    contributing_variables = compute_contributing_variables(model, test_dataset)
    contrib_path = os.path.join(
        savedir,
        "contributing_variables_epoch_%04d_fold_%02d.txt" % (best_test_epoch, train_dataset.split),
    )
    # Explicit encoding: variable names may be non-ASCII, and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(contrib_path, "w", encoding="utf-8") as f:
        for v, auc in contributing_variables:
            f.write("%s %f\n" % (v, auc))

    info.split_index = split
    info.result_dict = train_result
    info.save_result()
    return train_result