Exemple #1
0
def Test():
    """Report per-sample Dice scores of the saved model on every data split.

    For each of the train, validation and test loaders the model is run over
    all batches, both model outputs are passed through ``BinaryPred``, and a
    Dice score is computed per sample for the ROI mask and the prostate mask.
    The mean of each score list is printed and a histogram of each is shown.

    Relies on names defined outside this function: ``torch``, ``device``,
    ``model_path``, ``MultiSegModel``, ``LoadTestData``, ``LoadTVData`` and
    ``BinaryPred``.

    Returns:
        None. Results are printed and plotted.
    """
    from Metric.Dice import Dice
    import matplotlib.pyplot as plt

    dice = Dice()
    test_loader = LoadTestData(is_test=True)
    train_loader, validation_loader = LoadTVData(is_test=True)

    model = MultiSegModel(in_channels=3, out_channels=1).to(device)
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine instead of failing during deserialisation.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    name_list = ['Train', 'Validation', 'Test']
    loader_list = [train_loader, validation_loader, test_loader]

    for name, loader in zip(name_list, loader_list):
        prostate_list, prostate_pred_list = [], []
        roi_list, roi_pred_list = [], []

        # Evaluation only: no_grad stops autograd from recording a graph for
        # every forward pass, which would otherwise just waste memory.
        with torch.no_grad():
            for inputs, outputs in loader:
                t2, dwi, adc = inputs[0], inputs[1], inputs[2]
                # Ground-truth masks are only compared on the CPU, so they are
                # never copied to the compute device.
                roi, prostate = outputs[0].cpu(), outputs[1].cpu()

                # The three modalities are concatenated along dim 1.
                inputs = torch.cat([t2, dwi, adc], dim=1)
                inputs = inputs.type(torch.FloatTensor).to(device)

                roi_out, prostate_out = model(inputs)

                # Iterating a tensor yields one entry per index of dim 0, so
                # each list ends up holding one item per sample.
                prostate_pred_list.extend(BinaryPred(prostate_out).cpu().detach())
                prostate_list.extend(prostate)

                roi_pred_list.extend(BinaryPred(roi_out).cpu().detach())
                roi_list.extend(roi)

        if not roi_list:
            # An empty loader would otherwise end in a ZeroDivisionError below.
            print('no samples in', name, '- skipped')
            continue

        roi_dice = [
            dice(target, pred).numpy()
            for target, pred in zip(roi_list, roi_pred_list)
        ]
        prostate_dice = [
            dice(target, pred).numpy()
            for target, pred in zip(prostate_list, prostate_pred_list)
        ]

        print('average dice of roi in', name, ':', sum(roi_dice)/len(roi_dice))
        print('average dice of prostate in', name, ':', sum(prostate_dice) / len(prostate_dice))

        plt.hist(roi_dice)
        plt.title('Dice of Prostate Cancer in ' + name)
        plt.show()

        plt.hist(prostate_dice)
        plt.title('Dice of Prostate in ' + name)
        plt.show()
Exemple #2
0
def ShowPicture():
    """Run the saved multi-task model over the test set and plot the ``ece`` labels.

    For every test batch the function records element ``[0][0]`` of the
    squeezed ``ece`` label and element ``[0][0]`` of the softmax of the
    model's classification output, then shows a histogram of the recorded
    label values.

    Relies on names defined outside this function: ``torch``, ``nn``, ``plt``,
    ``model_path``, ``MultiTaskModel`` and ``LoadTestData``.

    Returns:
        None. The histogram is displayed with ``plt.show()``.
    """
    test_loader = LoadTestData()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = MultiTaskModel(in_channels=3, out_channels=1).to(device)
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine instead of failing during deserialisation.
    model.load_state_dict(torch.load(model_path, map_location=device))
    # Inference must run in eval mode so that layers which behave differently
    # during training (e.g. dropout, batch norm) use their inference behaviour.
    model.eval()

    # NOTE(review): the predicted values are collected but never plotted or
    # returned - confirm whether a second histogram was intended.
    ece_pre_list = []
    ece_list = []

    # Evaluation only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for inputs, outputs in test_loader:
            t2, dwi, adc = inputs[0], inputs[1], inputs[2]
            # The label is only read back on the CPU, so it is not sent to the
            # compute device.
            ece = outputs[0].squeeze(dim=1).type(torch.FloatTensor)

            # The three modalities are concatenated along dim 1.
            inputs = torch.cat([t2, dwi, adc], dim=1)
            inputs = inputs.type(torch.FloatTensor).to(device)

            _, _, class_out, _ = model(inputs)
            class_out_softmax = nn.functional.softmax(class_out, dim=1)

            # Only the first sample of each batch is recorded.
            # NOTE(review): presumably the loader uses batch size 1 - confirm.
            ece_pre_list.append(class_out_softmax.cpu().detach().numpy()[0][0])
            ece_list.append(ece.numpy()[0][0])

    # NOTE: a per-sample contour overlay of the masks on the T2 image used to
    # live here in commented-out form; only the label histogram is drawn.
    plt.hist(ece_list)
    plt.show()
Exemple #3
0
def Test():
    """Report Dice scores and ROC curves of the saved multi-task model per split.

    For each of the train, validation and test loaders the model is run over
    all batches. Both segmentation outputs are passed through ``BinaryPred``
    and a per-sample Dice score is computed for the ROI and prostate masks.
    The mean of each score list is printed and a histogram of each is shown.
    The class label is the argmax over dim 1 of the squeezed ``ece`` target
    and the class score is index 1 of the softmax of the classification
    output. Both are handed to ``get_auc``, and the resulting curves of all
    evaluated splits are passed to ``draw_roc``.

    Relies on names defined outside this function: ``torch``, ``nn``,
    ``device``, ``model_path``, ``MultiTaskModel``, ``LoadTestData``,
    ``LoadTVData``, ``BinaryPred``, ``get_auc`` and ``draw_roc``.

    Returns:
        None. Results are printed and plotted.
    """
    from Metric.Dice import Dice
    import matplotlib.pyplot as plt

    dice = Dice()
    test_loader = LoadTestData()
    train_loader, validation_loader = LoadTVData(is_test=True)

    model = MultiTaskModel(in_channels=3, out_channels=1).to(device)
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine instead of failing during deserialisation.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    fpr_list, tpr_list, auc_list = [], [], []
    # Names of the splits that actually produced a curve, kept in step with
    # the three lists above so draw_roc never receives mismatched lengths.
    evaluated_names = []

    name_list = ['Train', 'Validation', 'Test']
    loader_list = [train_loader, validation_loader, test_loader]

    for name, loader in zip(name_list, loader_list):
        class_list, class_pred_list = [], []
        prostate_list, prostate_pred_list = [], []
        roi_list, roi_pred_list = [], []

        # Evaluation only: no_grad stops autograd from recording a graph for
        # every forward pass, which would otherwise just waste memory.
        with torch.no_grad():
            for inputs, outputs in loader:
                t2, dwi, adc = inputs[0], inputs[1], inputs[2]
                # Targets are only compared on the CPU, so they are never
                # copied to the compute device.
                ece = outputs[0].squeeze(dim=1)
                roi, prostate = outputs[1].cpu(), outputs[2].cpu()

                # The three modalities are concatenated along dim 1.
                inputs = torch.cat([t2, dwi, adc], dim=1)
                inputs = inputs.type(torch.FloatTensor).to(device)

                # Class index per sample, taken over dim 1 of the target.
                labels = torch.argmax(ece, dim=1)

                roi_out, prostate_out, class_out, _ = model(inputs)
                class_out_softmax = nn.functional.softmax(class_out, dim=1)

                class_list.extend(labels.cpu().numpy())
                # Index 1 of the softmax is used as the score for the ROC.
                class_pred_list.extend(
                    class_out_softmax.cpu().detach().numpy()[..., 1])

                # Iterating a tensor yields one entry per index of dim 0, so
                # each list ends up holding one item per sample.
                prostate_pred_list.extend(BinaryPred(prostate_out).cpu().detach())
                prostate_list.extend(prostate)

                roi_pred_list.extend(BinaryPred(roi_out).cpu().detach())
                roi_list.extend(roi)

        if not roi_list:
            # An empty loader would otherwise end in a ZeroDivisionError below.
            print('no samples in', name, '- skipped')
            continue

        roi_dice = [
            dice(target, pred).numpy()
            for target, pred in zip(roi_list, roi_pred_list)
        ]
        prostate_dice = [
            dice(target, pred).numpy()
            for target, pred in zip(prostate_list, prostate_pred_list)
        ]

        print('average dice of roi:', sum(roi_dice) / len(roi_dice))
        print('average dice of prostate:',
              sum(prostate_dice) / len(prostate_dice))

        plt.hist(roi_dice)
        plt.title('Dice of Prostate Cancer in ' + name)
        plt.show()

        plt.hist(prostate_dice)
        plt.title('Dice of Prostate in ' + name)
        plt.show()

        fpr, tpr, auc = get_auc(class_pred_list, class_list)
        fpr_list.append(fpr)
        tpr_list.append(tpr)
        auc_list.append(auc)
        evaluated_names.append(name)

    if fpr_list:
        draw_roc(fpr_list, tpr_list, auc_list, evaluated_names)