Esempio n. 1
0
def evaluate(data, model, device, class_no):
    """Return batch-averaged IoU, F1, recall and precision of ``model`` on ``data``.

    Args:
        data: iterable yielding ``(image, label, image_name)`` batches.
        model: segmentation network; it is switched to eval mode here.
        device: device the image and label tensors are moved to.
        class_no: ``2`` means the output is passed through a sigmoid and
            thresholded at 0.5; any other value takes the arg-max over dim 1.

    Returns:
        ``(iou, f1, recall, precision)``, each averaged over the batches.

    Raises:
        ValueError: if ``data`` yields no batches (previously this surfaced as
            an ``UnboundLocalError`` on the loop index).
    """
    model.eval()

    f1 = 0
    test_iou = 0
    recall = 0
    precision = 0
    n_batches = 0

    with torch.no_grad():
        for testimg, testlabel, _ in data:
            testimg = testimg.to(device=device, dtype=torch.float32)
            testlabel = testlabel.to(device=device, dtype=torch.float32)

            testoutput = model(testimg)

            if class_no == 2:
                testoutput = (torch.sigmoid(testoutput) > 0.5).float()
            else:
                _, testoutput = torch.max(testoutput, dim=1)

            pred = testoutput.cpu().detach()
            label = testlabel.cpu().detach()

            mean_iu_ = intersectionAndUnion(pred, label, class_no)
            f1_, recall_, precision_ = f1_score(label.numpy(), pred.numpy(), class_no)

            f1 += f1_
            test_iou += mean_iu_
            recall += recall_
            precision += precision_
            n_batches += 1

    if n_batches == 0:
        raise ValueError('evaluate() received no batches to score')

    return test_iou / n_batches, f1 / n_batches, recall / n_batches, precision / n_batches
Esempio n. 2
0
def trainSingleModel(model, model_name, num_epochs, learning_rate, datasettag,
                     train_dataset, train_batchsize, trainloader,
                     validateloader, testdata, reverse_mode, lr_schedule,
                     class_no):
    """Train ``model`` on CUDA, log per-epoch metrics and test late checkpoints.

    Files written:

    * ``./Results/<save_model_name>/trained_models/epoch<k>.pt``: the whole
      model, saved for every epoch index ``k > num_epochs - 10``;
    * ``./Results/Log_<datasettag>/<save_model_name>``: TensorBoard scalars;
    * whatever :func:`test` writes when run on the saved checkpoints.

    Args:
        model: network to train; moved to CUDA.
        model_name, datasettag: combined with the epoch count and learning
            rate into ``save_model_name``.
        num_epochs: number of training epochs.
        learning_rate: initial AdamW learning rate.
        train_dataset, train_batchsize: currently unused; kept so existing
            callers keep working.
        trainloader, validateloader: iterables of ``(image, label, name)``.
        testdata, reverse_mode: forwarded unchanged to :func:`test`.
        lr_schedule: if exactly ``True``, the learning rate is multiplied by
            0.1 at ``num_epochs // 2`` and ``3 * num_epochs // 4``.
        class_no: ``2`` selects sigmoid + dice loss; any other value selects
            cross-entropy on the raw logits with arg-max predictions.

    Returns:
        The trained model.
    """
    device = torch.device('cuda')

    save_model_name = (model_name + '_' + datasettag + '_e' + str(num_epochs)
                       + '_lr' + str(learning_rate))

    saved_information_path = './Results/' + save_model_name
    saved_model_path = saved_information_path + '/trained_models'
    # One call creates ./Results, the run folder and trained_models.
    os.makedirs(saved_model_path, exist_ok=True)

    print('The current model is:')
    print(save_model_name)
    print('\n')

    writer = SummaryWriter('./Results/Log_' + datasettag + '/' + save_model_name)

    model.to(device)

    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=learning_rate,
                                  betas=(0.9, 0.999),
                                  eps=1e-8,
                                  weight_decay=1e-5)

    scheduler = None
    if lr_schedule is True:
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[num_epochs // 2, 3 * num_epochs // 4],
            gamma=0.1)

    # nn.CrossEntropyLoss applies log-softmax internally, so it must receive
    # raw logits; applying softmax first would squash the gradients.
    criterion = nn.CrossEntropyLoss(reduction='mean')
    label_dtype = torch.float32 if class_no == 2 else torch.long

    start = timeit.default_timer()

    for epoch in range(num_epochs):

        model.train()
        train_iou = []
        train_loss = []

        for images, labels, _ in trainloader:
            optimizer.zero_grad()

            images = images.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=label_dtype)

            outputs = model(images)

            if class_no == 2:
                prob_outputs = torch.sigmoid(outputs)
                loss = dice_loss(prob_outputs, labels)
                class_outputs = torch.where(prob_outputs > threshold, upper, lower)
            else:
                loss = criterion(outputs, labels.squeeze(1))
                _, class_outputs = torch.max(outputs, dim=1)

            loss.backward()
            optimizer.step()

            mean_iu_, _, __ = segmentation_scores(labels, class_outputs, class_no)
            train_iou.append(mean_iu_)
            train_loss.append(loss.item())

        if scheduler is not None:
            scheduler.step()

        model.eval()

        validate_iou = []
        validate_f1 = []
        validate_h_dist = []

        with torch.no_grad():
            for val_images, val_label, _ in validateloader:

                val_img = val_images.to(device=device, dtype=torch.float32)
                val_label = val_label.to(device=device, dtype=label_dtype)

                # Sanity check on the labels: the value 100 must not appear.
                assert torch.max(val_label) != 100.0

                val_outputs = model(val_img)
                if class_no == 2:
                    val_class_outputs = (torch.sigmoid(val_outputs) > 0.5).float()
                else:
                    # softmax is monotonic, so arg-max of the logits is identical.
                    _, val_class_outputs = torch.max(val_outputs, dim=1)

                eval_mean_iu_, _, __ = segmentation_scores(
                    val_label, val_class_outputs, class_no)
                eval_f1_, eval_recall_, eval_precision_, eTP, eTN, eFP, eFN, eP, eN = f1_score(
                    val_label, val_class_outputs, class_no)

                validate_iou.append(eval_mean_iu_)
                validate_f1.append(eval_f1_)

                if (val_class_outputs == 1).sum() > 1 and (
                        val_label == 1).sum() > 1 and class_no == 2:
                    validate_h_dist.append(hd95(val_class_outputs, val_label, class_no))

        print('Step [{}/{}], '
              'Train loss: {:.4f}, '
              'Train iou: {:.4f}, '
              'val iou:{:.4f}, '.format(epoch + 1, num_epochs,
                                        np.nanmean(train_loss),
                                        np.nanmean(train_iou),
                                        np.nanmean(validate_iou)))

        writer.add_scalars(
            'acc metrics', {
                'train iou': np.nanmean(train_iou),
                'val iou': np.nanmean(validate_iou),
                'val f1': np.nanmean(validate_f1)
            }, epoch + 1)

        if epoch > num_epochs - 10:
            torch.save(model, saved_model_path + '/epoch' + str(epoch) + '.pt')

    # Flush pending TensorBoard events and release the log file handles.
    writer.close()

    test(testdata,
         saved_model_path,
         device,
         reverse_mode=reverse_mode,
         class_no=class_no,
         save_path=saved_model_path)

    stop = timeit.default_timer()
    print('Time: ', stop - start)
    print('\nTraining finished and model saved\n')

    return model
Esempio n. 3
0
def test(testdata, models_path, device, reverse_mode, class_no, save_path):

    all_models = glob.glob(os.path.join(models_path, '*.pt'))

    # with torch.no_grad():

    test_f1 = []
    test_iou = []
    test_h_dist = []
    test_acc = []
    test_w_acc = []
    test_recall = []
    test_precision = []

    test_bf = []

    test_iou_adv = []
    test_h_dist_adv = []

    for model in all_models:

        model = torch.load(model)
        model.eval()

        for j, (testimg, testlabel, testname) in enumerate(testdata):
            # validate batch size will be set up as 2
            # testimg = torch.from_numpy(testimg).to(device=device, dtype=torch.float32)
            # testlabel = torch.from_numpy(testlabel).to(device=device, dtype=torch.float32)

            testimg = testimg.to(device=device, dtype=torch.float32)

            if class_no == 2:
                testlabel = testlabel.to(device=device, dtype=torch.float32)
            else:
                testlabel = testlabel.to(device=device, dtype=torch.long)

            # b, c, h, w = testimg.size()
            # testimg = testimg[:, 0, :, :].view(b, 1, h, w).contiguous()
            # testlabel = testlabel[:, 0, :, :].view(b, 1, h, w).contiguous()

            if torch.max(testlabel) == 255.:
                testlabel = testlabel / 255.

            testimg.requires_grad = True

            threshold = torch.tensor([0.5],
                                     dtype=torch.float32,
                                     device=device,
                                     requires_grad=False)

            upper = torch.tensor([1.0],
                                 dtype=torch.float32,
                                 device=device,
                                 requires_grad=False)

            lower = torch.tensor([0.0],
                                 dtype=torch.float32,
                                 device=device,
                                 requires_grad=False)

            # c, h, w = testimg.size()
            # testimg = testimg.expand(1, c, h, w)

            testoutput = model(testimg)

            if class_no == 2:
                prob_testoutput = torch.sigmoid(testoutput)
                testoutput = (prob_testoutput > 0.5).float()
            else:
                prob_testoutput = torch.softmax(testoutput, dim=1)
                _, testoutput = torch.max(prob_testoutput, dim=1)

            # attack testing data:
            if class_no == 2:
                loss = dice_loss(prob_testoutput, testlabel)
            else:
                loss = nn.CrossEntropyLoss(reduction='mean')(
                    prob_testoutput, testlabel.squeeze(1))

            model.zero_grad()
            loss.backward()
            data_grad = testimg.grad.data
            perturbed_data = fgsm_attack(testimg, 0.2, data_grad)
            output_attack = model(perturbed_data)

            if class_no == 2:
                output_attack = torch.sigmoid(output_attack)
                output_attack = (output_attack > 0.5).float()
            else:
                output_attack = torch.softmax(output_attack, dim=1)
                _, output_attack = torch.max(output_attack, dim=1)

            mean_iu_, acc_, w_acc_ = segmentation_scores(
                testlabel, testoutput, class_no)

            test_iou.append(mean_iu_)
            test_acc.append(acc_)
            test_w_acc.append(w_acc_)

            mean_iu_adv_, _, __ = segmentation_scores(testlabel, output_attack,
                                                      class_no)
            test_iou_adv.append(mean_iu_adv_)

            if (testoutput == 1).sum() > 1 and (
                    testlabel == 1).sum() > 1 and class_no == 2:
                h_dis95_ = hd95(testoutput, testlabel, class_no)
                test_h_dist.append(h_dis95_)

            if (output_attack == 1).sum() > 1 and (
                    testlabel == 1).sum() > 1 and class_no == 2:
                h_dis95_attack_ = hd95(output_attack, testlabel, class_no)
                test_h_dist_adv.append(h_dis95_attack_)

            f1_, recall_, precision_, TP, TN, FP, FN, P, N = f1_score(
                testlabel, testoutput, class_no)

            bf_ = 2 * precision_ * recall_ / (recall_ + precision_)

            test_f1.append(f1_)
            test_recall.append(recall_)
            test_precision.append(precision_)
            test_bf.append(bf_)

    prediction_map_path = save_path + '/Test'

    try:

        os.mkdir(prediction_map_path)

    except OSError as exc:

        if exc.errno != errno.EEXIST:

            raise

        pass

    result_dictionary = {
        'Test IoU mean': str(np.mean(test_iou)),
        'Test IoU std': str(np.std(test_iou)),
        'Test Acc mean': str(np.mean(test_acc)),
        'Test Acc std': str(np.std(test_acc)),
        'Test W ACC mean': str(np.mean(test_w_acc)),
        'Test W ACC std': str(np.std(test_w_acc)),
        'Test BF mean': str(np.mean(test_bf)),
        'Test BF std': str(np.std(test_bf)),
        'Test f1 mean': str(np.mean(test_f1)),
        'Test f1 std': str(np.std(test_f1)),
        'Test H-dist mean': str(np.mean(test_h_dist)),
        'Test H-dist std': str(np.std(test_h_dist)),
        'Test precision mean': str(np.mean(test_precision)),
        'Test precision std': str(np.std(test_precision)),
        'Test recall mean': str(np.mean(test_recall)),
        'Test recall std': str(np.std(test_recall)),
        'Test IoU attack mean': str(np.mean(test_iou_adv)),
        'Test IoU attack std': str(np.std(test_iou_adv)),
        'Test H-dist attack mean': str(np.mean(test_h_dist_adv)),
        'Test H-dist attack std': str(np.std(test_h_dist_adv)),
    }

    ff_path = prediction_map_path + '/test_result_data.txt'
    ff = open(ff_path, 'w')
    ff.write(str(result_dictionary))
    ff.close()

    print('Test h-dist: {:.4f}, '
          'Test iou: {:.4f}, '.format(np.mean(test_h_dist), np.mean(test_iou)))
Esempio n. 4
0
def evaluate(evaluatedata, model, device, reverse_mode, class_no):
    """Score a sigmoid-output model on clean and adversarially perturbed inputs.

    Per batch:

    * the sigmoid output is thresholded at 0.5, with the mask inverted when
      ``reverse_mode is True``;
    * a perturbed copy of the input is built with ``fgsm_attack`` (step 0.2)
      from the gradient of the dice loss w.r.t. the input, and thresholded
      the same way.

    Args:
        evaluatedata: iterable of ``(image, label, name)`` batches.
        model: segmentation network; switched to eval mode here.
        device: device inputs/labels (and threshold constants) live on.
        reverse_mode: if exactly ``True``, invert the thresholded masks.
        class_no: forwarded to the metric helpers.

    Returns:
        A 17-tuple, averaged per batch unless noted:
        ``iou, f1, recall, precision,
        FP/N, FP/P, FN/N, FN/P`` (ratios taken per batch, then averaged),
        ``FP, FN, TP, TN, P, N,
        hd95`` (mean over batches where both prediction and label have more
        than one positive pixel; 0.0 if there were none),
        ``attacked iou, attacked hd95`` (same rule as ``hd95``).

    Raises:
        ValueError: if ``evaluatedata`` yields no batches.
    """
    model.eval()

    # Thresholding constants are loop-invariant: build them once.
    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    f1 = 0
    test_iou = 0
    recall = 0
    precision = 0

    FPs_Ns = 0
    FNs_Ps = 0
    FPs_Ps = 0
    FNs_Ns = 0
    TPs = 0
    TNs = 0
    FNs = 0
    FPs = 0
    Ps = 0
    Ns = 0

    test_iou_attack = 0
    test_h_dist = 0
    test_h_dist_attack = 0

    effective_h = 0
    effective_h_attack = 0
    n_batches = 0

    for testimg, testlabel, _ in evaluatedata:

        testimg = testimg.to(device=device, dtype=torch.float32)
        testlabel = testlabel.to(device=device, dtype=torch.float32)

        # Track gradients w.r.t. the input so the attack can use them.
        testimg.requires_grad = True

        prob_testoutput = torch.sigmoid(model(testimg))

        loss = dice_loss(prob_testoutput, testlabel)
        model.zero_grad()
        loss.backward()
        perturbed_data = fgsm_attack(testimg, 0.2, testimg.grad)

        # The attacked prediction is only scored, so skip building a graph.
        with torch.no_grad():
            prob_attack = torch.sigmoid(model(perturbed_data))

        if reverse_mode is True:
            testoutput = torch.where(prob_testoutput < threshold, upper, lower)
            output_attack = torch.where(prob_attack < threshold, upper, lower)
        else:
            testoutput = torch.where(prob_testoutput > threshold, upper, lower)
            output_attack = torch.where(prob_attack > threshold, upper, lower)

        mean_iu_, _, __ = segmentation_scores(testlabel, testoutput, class_no)
        mean_iu_attack_, _, __ = segmentation_scores(testlabel, output_attack, class_no)

        if (testoutput == 1).sum() > 1 and (testlabel == 1).sum() > 1:
            test_h_dist += hd95(testoutput, testlabel, class_no)
            effective_h += 1

        if (output_attack == 1).sum() > 1 and (testlabel == 1).sum() > 1:
            test_h_dist_attack += hd95(output_attack, testlabel, class_no)
            effective_h_attack += 1

        f1_, recall_, precision_, TP, TN, FP, FN, P, N = f1_score(
            testlabel, testoutput, class_no)

        f1 += f1_
        test_iou += mean_iu_
        recall += recall_
        precision += precision_
        TPs += TP
        TNs += TN
        FPs += FP
        FNs += FN
        Ps += P
        Ns += N
        # 1e-10 on both sides keeps the ratio finite (and 1.0 for 0/0).
        FNs_Ps += (FN + 1e-10) / (P + 1e-10)
        FPs_Ns += (FP + 1e-10) / (N + 1e-10)
        FNs_Ns += (FN + 1e-10) / (N + 1e-10)
        FPs_Ps += (FP + 1e-10) / (P + 1e-10)

        test_iou_attack += mean_iu_attack_
        n_batches += 1

    if n_batches == 0:
        raise ValueError('evaluate() received no batches to score')

    # Average distances only over the batches where they were computed.
    h_dist_mean = test_h_dist / effective_h if effective_h else 0.0
    h_dist_attack_mean = (test_h_dist_attack / effective_h_attack
                          if effective_h_attack else 0.0)

    n = n_batches
    return (test_iou / n, f1 / n, recall / n, precision / n,
            FPs_Ns / n, FPs_Ps / n, FNs_Ns / n, FNs_Ps / n,
            FPs / n, FNs / n, TPs / n, TNs / n, Ps / n, Ns / n,
            h_dist_mean, test_iou_attack / n, h_dist_attack_mean)
Esempio n. 5
0
def test1(data_1, model, device, class_no, save_location):

    model.eval()

    data_1_testoutputs = []

    with torch.no_grad():

        f1_1 = 0
        test_iou_1 = 0
        # test_h_dist_1 = 0
        recall_1 = 0
        precision_1 = 0
        mse_1 = 0

        # ==============================================
        evaluate_index_all_1 = range(0, len(data_1) - 1)
        #
        # ==============================================
        # evaluate_index_all_2 = range(0, len(data_2) - 1)
        #
        for j, (testimg, testlabel, testimgname) in enumerate(data_1):
            # extract a few random indexs every time in a range of the data
            # ========================================================================
            # ========================================================================

            testimg = torch.from_numpy(testimg).to(device=device, dtype=torch.float32)

            testlabel = torch.from_numpy(testlabel).to(device=device, dtype=torch.float32)

            c, h, w = testimg.size()
            testimg = testimg.expand(1, c, h, w)

            testoutput_original = model(testimg)

            if class_no == 2:

                testoutput = torch.sigmoid(testoutput_original.view(1, h, w))

                testoutput = (testoutput > 0.5).float()

                data_1_testoutputs.append(testoutput)
                #
            else:
                #
                _, testoutput = torch.max(testoutput_original, dim=1)
                #
            mean_iu_ = intersectionAndUnion(testoutput.cpu().detach(), testlabel.cpu().detach(), class_no)

            f1_, recall_, precision_ = f1_score(testlabel.cpu().detach().numpy(), testoutput.cpu().detach().numpy(), class_no)

            mse_ = (np.square(testlabel.cpu().detach().numpy() - testoutput.cpu().detach().numpy())).mean()

            f1_1 += f1_
            test_iou_1 += mean_iu_
            recall_1 += recall_
            precision_1 += precision_
            mse_1 += mse_
            #
            # # Plotting segmentation:
            # testoutput_original = np.asarray(testoutput_original.cpu().detach().numpy(), dtype=np.uint8)
            # testoutput_original = np.squeeze(testoutput_original, axis=0)
            # testoutput_original = np.repeat(testoutput_original[:, :, np.newaxis], 3, axis=2)
            # #
            # if class_no == 2:
            #     segmentation_map = np.zeros((h, w, 3), dtype=np.uint8)
            #     #
            #     segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 255
            #     segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0
            #     segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0
            #     #
            # else:
            #     segmentation_map = np.zeros((h, w, 3), dtype=np.uint8)
            #     if class_no == 4:
            #         # multi class for brats 2018
            #         segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 255
            #         segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0
            #         segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0
            #         #
            #         segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 0
            #         segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 255
            #         segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 0
            #         #
            #         segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 0
            #         segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 0
            #         segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 255
            #         #
            #     elif class_no == 8:
            #         # multi class for cityscapes
            #         segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 0, testoutput_original[:, :, 1] == 0, testoutput_original[:, :, 2] == 0)] = 255
            #         segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 0, testoutput_original[:, :, 1] == 0, testoutput_original[:, :, 2] == 0)] = 0
            #         segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 0, testoutput_original[:, :, 1] == 0, testoutput_original[:, :, 2] == 0)] = 0
            #         #
            #         segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0
            #         segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 255
            #         segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0
            #         #
            #         segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 0
            #         segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 0
            #         segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 255
            #         #
            #         segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 255
            #         segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 255
            #         segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 0
            #         #
            #         segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 4, testoutput_original[:, :, 1] == 4, testoutput_original[:, :, 2] == 4)] = 153
            #         segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 4, testoutput_original[:, :, 1] == 4, testoutput_original[:, :, 2] == 4)] = 51
            #         segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 4, testoutput_original[:, :, 1] == 4, testoutput_original[:, :, 2] == 4)] = 255
            #         #
            #         segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 5, testoutput_original[:, :, 1] == 5, testoutput_original[:, :, 2] == 5)] = 255
            #         segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 5, testoutput_original[:, :, 1] == 5, testoutput_original[:, :, 2] == 5)] = 102
            #         segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 5, testoutput_original[:, :, 1] == 5, testoutput_original[:, :, 2] == 5)] = 178
            #         #
            #         segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 6, testoutput_original[:, :, 1] == 6, testoutput_original[:, :, 2] == 6)] = 102
            #         segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 6, testoutput_original[:, :, 1] == 6, testoutput_original[:, :, 2] == 6)] = 255
            #         segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 6, testoutput_original[:, :, 1] == 6, testoutput_original[:, :, 2] == 6)] = 102
            #         #
            # prediction_name = 'seg_' + test_imagename + '.png'
            # full_error_map_name = os.path.join(prediction_map_path, prediction_name)
            # imageio.imsave(full_error_map_name, segmentation_map)

    #
    prediction_map_path = save_location + '/' + 'Results_map'
    #
    try:
        os.mkdir(prediction_map_path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass
    # save numerical results:
    result_dictionary = {'Test IoU data 1': str(test_iou_1 / len(evaluate_index_all_1)),
                         'Test f1 data 1': str(f1_1 / len(evaluate_index_all_1)),
                         'Test recall data 1': str(recall_1 / len(evaluate_index_all_1)),
                         'Test Precision data 1': str(precision_1 / len(evaluate_index_all_1)),
                         'Test MSE data 1': str(mse_1 / len(evaluate_index_all_1))
                         }

    ff_path = prediction_map_path + '/test_result_data.txt'
    ff = open(ff_path, 'w')
    ff.write(str(result_dictionary))
    ff.close()

    return test_iou_1 / len(evaluate_index_all_1), f1_1 / len(evaluate_index_all_1), recall_1 / len(evaluate_index_all_1), precision_1 / len(evaluate_index_all_1), mse_1 / len(evaluate_index_all_1), \
           data_1_testoutputs
Esempio n. 6
0
def trainSingleModel(model,
                     model_name,
                     num_epochs,
                     learning_rate,
                     datasettag,
                     train_dataset,
                     train_batchsize,
                     trainloader,
                     validateloader,
                     testdata,
                     reverse_mode,
                     lr_schedule,
                     class_no):
    """Train a segmentation model with Dice loss, validate periodically, then test.

    Training metrics (IoU, F1, HD95, confusion-matrix ratios) are measured on
    FGSM-perturbed copies of each training batch. Validation on
    ``validateloader`` is run whenever ``(j + 1)`` is a multiple of
    ``len(train_dataset) // train_batchsize - 1`` (at least 1), and the
    results are logged to TensorBoard. Whole models from the last 10 epochs are
    saved to ``./Results/<save_model_name>/trained_models`` and evaluated by
    ``test`` once training finishes.

    Args:
        model: network to train; moved to CUDA, or CPU when CUDA is unavailable.
        model_name: prefix of the run name. Names containing 'fn' or 'FN'
            get a learning-rate warm-up during the first 10 epochs when
            ``reverse_mode`` is True.
        num_epochs: number of training epochs.
        learning_rate: initial AdamW learning rate.
        datasettag: dataset tag used in the run name and log directory.
        train_dataset: training dataset (only its length is used here).
        train_batchsize: training batch size (used for the validation period).
        trainloader: iterable of ``(images, labels, names)`` training batches.
        validateloader: validation data passed to ``evaluate``.
        testdata: test data passed to ``test``.
        reverse_mode: if True, the model is trained against ``1 - labels`` and
            its sigmoid output is thresholded with ``< 0.5`` instead of ``> 0.5``.
        lr_schedule: if True, decay the learning rate by 10x at 1/2 and 3/4 of
            the epochs (MultiStepLR).
        class_no: number of classes; HD95 is only computed when it equals 2.

    Returns:
        The trained model.
    """
    # Validation runs whenever (iteration index + 1) is a multiple of this
    # value. Clamped to at least 1: with fewer than two full batches the raw
    # value would be 0 and the modulo below would raise ZeroDivisionError.
    iteration_amount = max(len(train_dataset) // train_batchsize - 1, 1)

    # Fall back to CPU so the function also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    save_model_name = model_name + '_' + datasettag + '_e' + str(num_epochs) + '_lr' + str(learning_rate)

    saved_information_path = './Results' + '/' + save_model_name
    saved_model_path = saved_information_path + '/trained_models'
    # Creates ./Results, the per-run folder and trained_models in one call;
    # directories that already exist are accepted.
    os.makedirs(saved_model_path, exist_ok=True)

    print('The current model is:')
    print(save_model_name)
    print('\n')

    writer = SummaryWriter('./Results/Log_' + datasettag + '/' + save_model_name)

    model.to(device)

    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5)

    if lr_schedule is True:
        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=10, threshold=0.001)
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[num_epochs // 2, 3 * num_epochs // 4], gamma=0.1)

    # Learning-rate warm-up is only used for fn-attention models in reverse mode.
    use_fn_warmup = ('fn' in model_name or 'FN' in model_name) and reverse_mode is True

    start = timeit.default_timer()

    for epoch in range(num_epochs):

        model.train()

        # Per-epoch running sums; divided by the iteration count when logged.
        h_dists = 0
        f1 = 0
        accuracy_iou = 0
        running_loss = 0.0
        recall = 0
        precision = 0

        t_FPs_Ns = 0
        t_FPs_Ps = 0
        t_FNs_Ns = 0
        t_FNs_Ps = 0
        t_FPs = 0
        t_FNs = 0
        t_TPs = 0
        t_TNs = 0
        t_Ps = 0
        t_Ns = 0

        # Number of batches for which an HD95 value was actually computed.
        effective_h = 0

        # j: index of iteration
        for j, (images, labels, _imagename) in enumerate(trainloader):

            optimizer.zero_grad()

            images = images.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.float32)

            # Input gradients are needed for the FGSM perturbation below.
            images.requires_grad = True

            prob_outputs = torch.sigmoid(model(images))

            if reverse_mode is True:
                # Reverse mode: the network is trained on the inverted mask.
                loss = dice_loss(prob_outputs, torch.ones_like(labels) - labels)
            else:
                loss = dice_loss(prob_outputs, labels)

            loss.backward()
            optimizer.step()

            # The task of binary segmentation is too easy. To compensate for its
            # simplicity, training metrics are measured on FGSM-perturbed inputs.
            # This pass only feeds the metrics, so no autograd graph is built.
            with torch.no_grad():
                perturbed_data = fgsm_attack(images, 0.2, images.grad.data)
                prob_outputs = torch.sigmoid(model(perturbed_data))

            if reverse_mode is True:
                class_outputs = torch.where(prob_outputs < threshold, upper, lower)
            else:
                class_outputs = torch.where(prob_outputs > threshold, upper, lower)

            # HD95 is computed for binary tasks only, and only when both the
            # prediction and the label contain more than one foreground pixel.
            if class_no == 2 and (class_outputs == 1).sum() > 1 and (labels == 1).sum() > 1:
                h_dists += hd95(class_outputs, labels, class_no)
                effective_h += 1

            mean_iu_ = segmentation_scores(labels, class_outputs, class_no)
            f1_, recall_, precision_, TPs_, TNs_, FPs_, FNs_, Ps_, Ns_ = f1_score(labels, class_outputs, class_no)

            # .item() stores a plain float; accumulating the loss tensor itself
            # would carry autograd history across iterations.
            running_loss += loss.item()
            f1 += f1_
            accuracy_iou += mean_iu_
            recall += recall_
            precision += precision_
            t_TPs += TPs_
            t_TNs += TNs_
            t_FPs += FPs_
            t_FNs += FNs_
            t_Ps += Ps_
            t_Ns += Ns_
            # The 1e-8 terms avoid division by zero when a batch has no P / N pixels.
            t_FNs_Ps += (FNs_ + 1e-8) / (Ps_ + 1e-8)
            t_FPs_Ns += (FPs_ + 1e-8) / (Ns_ + 1e-8)
            t_FNs_Ns += (FNs_ + 1e-8) / (Ns_ + 1e-8)
            t_FPs_Ps += (FPs_ + 1e-8) / (Ps_ + 1e-8)

            if (j + 1) % iteration_amount == 0:

                (validate_iou, validate_f1, validate_recall, validate_precision,
                 v_FPs_Ns, v_FPs_Ps, v_FNs_Ns, v_FNs_Ps,
                 v_FPs, v_FNs, v_TPs, v_TNs, v_Ps, v_Ns,
                 v_h_dist) = evaluate(validateloader, model, device, reverse_mode=reverse_mode, class_no=class_no)

                # evaluate() presumably switches the model to eval mode; restore
                # training mode for any remaining batches of this epoch.
                model.train()

                train_loss = running_loss / (j + 1)
                train_iou = accuracy_iou / (j + 1)
                # Mean HD95 over the batches where it was actually computed.
                train_h_dist = h_dists / max(effective_h, 1)

                print(
                    'Step [{}/{}], Train loss: {:.4f}, '
                    'Train iou: {:.4f}, '
                    'Train h-dist:{:.4f}, '
                    'Val iou: {:.4f},'
                    'Val h-dist: {:.4f}'.format(epoch + 1, num_epochs,
                                                   train_loss,
                                                   train_iou,
                                                   train_h_dist,
                                                   validate_iou,
                                                   v_h_dist))

                # # # ================================================================== #
                # # #                        TensorboardX Logging                        #
                # # # # ================================================================ #

                writer.add_scalars('acc metrics', {'train iou': train_iou,
                                                   'train hausdorff dist': train_h_dist,
                                                   'val iou': validate_iou,
                                                   'val hasudorff distance': v_h_dist,
                                                   'loss': train_loss}, epoch + 1)

                writer.add_scalars('train confusion matrices analysis', {'train FPs/Ns': t_FPs_Ns / (j+1),
                                                                         'train FNs/Ps': t_FNs_Ps / (j+1),
                                                                         'train FPs/Ps': t_FPs_Ps / (j+1),
                                                                         'train FNs/Ns': t_FNs_Ns / (j+1),
                                                                         'train FNs': t_FNs / (j+1),
                                                                         'train FPs': t_FPs / (j+1),
                                                                         'train TNs': t_TNs / (j+1),
                                                                         'train TPs': t_TPs / (j+1),
                                                                         'train Ns': t_Ns / (j+1),
                                                                         'train Ps': t_Ps / (j+1),
                                                                         'train imbalance': t_Ps / (t_Ps + t_Ns)}, epoch + 1)

                writer.add_scalars('val confusion matrices analysis', {'val FPs/Ns': v_FPs_Ns,
                                                                       'val FNs/Ps': v_FNs_Ps,
                                                                       'val FPs/Ps': v_FPs_Ps,
                                                                       'val FNs/Ns': v_FNs_Ns,
                                                                       'val FNs': v_FNs,
                                                                       'val FPs': v_FPs,
                                                                       'val TNs': v_TNs,
                                                                       'val TPs': v_TPs,
                                                                       'val Ns': v_Ns,
                                                                       'val Ps': v_Ps,
                                                                       'val imbalance': v_Ps / (v_Ps + v_Ns)}, epoch + 1)

            # Learning-rate warm-up for fn attention: in the first 10 epochs the
            # LR is ramped linearly from 0 within each epoch. Without the
            # warm-up, training is sometimes hard.
            if use_fn_warmup and epoch < 10:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rate * (j / len(trainloader))

        if lr_schedule is True:
            scheduler.step()

        # Save whole models from the last 10 epochs; they are all tested below.
        if epoch >= (num_epochs - 10):
            save_model_name_full = saved_model_path + '/' + save_model_name + '_epoch' + str(epoch) + '.pt'
            torch.save(model, save_model_name_full)

    # Flush and release the TensorBoard event file.
    writer.close()

    # Test on all saved models and aggregate their results:
    test(testdata,
         saved_model_path,
         device,
         reverse_mode=reverse_mode,
         class_no=class_no,
         save_path=saved_information_path)

    stop = timeit.default_timer()
    print('Time: ', stop - start)
    print('\n')
    print('\nTraining finished and model saved\n')

    return model
def test(
        testdata,
         models_path,
         device,
         reverse_mode,
         class_no,
         save_path):
    """Evaluate every saved model in ``models_path`` on clean and FGSM-attacked data.

    Each ``*.pt`` file is loaded as a whole pickled model with ``torch.load``.
    It is run on every batch of ``testdata``. IoU, F1, recall, precision and
    HD95 are collected on the clean predictions; IoU and HD95 are collected on
    FGSM-perturbed inputs (epsilon 0.2). Per-batch values are pooled across
    all models. Their mean and std are written as a dict string to
    ``<save_path>/Test_result/test_results.txt``.

    Args:
        testdata: iterable of ``(images, labels, names)`` batches.
        models_path: directory containing the saved ``*.pt`` models.
        device: torch device the models and data are moved to.
        reverse_mode: if True, the models predict the inverted mask; the sigmoid
            output is thresholded with ``< 0.5``, and the FGSM loss uses the
            inverted labels, matching the training objective.
        class_no: number of classes; only 2 (binary) is supported.
        save_path: run directory under which ``Test_result`` is created.

    Raises:
        NotImplementedError: if ``class_no != 2``.
    """
    if class_no != 2:
        # (todo) add for multi-class: predictions are only binarised for the
        # binary case, so the attacked predictions would otherwise be undefined.
        raise NotImplementedError('test() only supports binary segmentation (class_no == 2)')

    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    def _binarize(prob):
        # Reverse-mode models predict the inverted mask, so the comparison flips.
        if reverse_mode is True:
            return torch.where(prob < threshold, upper, lower)
        return torch.where(prob > threshold, upper, lower)

    test_f1 = []
    test_iou = []
    test_h_dist = []
    test_recall = []
    test_precision = []

    test_iou_adv = []
    test_h_dist_adv = []

    # Sorted so that runs are reproducible regardless of filesystem order.
    for model_file in sorted(glob.glob(os.path.join(models_path, '*.pt'))):

        # NOTE(review): on PyTorch >= 2.6 torch.load defaults to
        # weights_only=True, which rejects whole pickled models; pass
        # weights_only=False there.
        model = torch.load(model_file, map_location=device)
        model.eval()

        for testimg, testlabel, _testname in testdata:

            testimg = testimg.to(device=device, dtype=torch.float32)
            # Input gradients are needed for the FGSM perturbation.
            testimg.requires_grad = True

            testlabel = testlabel.to(device=device, dtype=torch.float32)

            prob_testoutput = torch.sigmoid(model(testimg))
            testoutput = _binarize(prob_testoutput)

            # metrics before attack:
            test_iou.append(segmentation_scores(testlabel, testoutput, class_no))

            if (testoutput == 1).sum() > 1 and (testlabel == 1).sum() > 1:
                test_h_dist.append(hd95(testoutput, testlabel, class_no))

            f1_, recall_, precision_, _TP, _TN, _FP, _FN, _P, _N = f1_score(testlabel, testoutput, class_no)
            test_f1.append(f1_)
            test_recall.append(recall_)
            test_precision.append(precision_)

            # Attack testing data. The loss target must match the training
            # objective (inverted labels in reverse mode); otherwise the FGSM
            # step would push predictions towards the correct answer.
            if reverse_mode is True:
                attack_target = torch.ones_like(testlabel) - testlabel
            else:
                attack_target = testlabel
            loss = dice_loss(prob_testoutput, attack_target)
            model.zero_grad()
            loss.backward()

            # Forward pass on the perturbed input is for metrics only.
            with torch.no_grad():
                perturbed_data = fgsm_attack(testimg, 0.2, testimg.grad.data)
                testoutput_adv = _binarize(torch.sigmoid(model(perturbed_data)))

            test_iou_adv.append(segmentation_scores(testlabel, testoutput_adv, class_no))

            if (testoutput_adv == 1).sum() > 1 and (testlabel == 1).sum() > 1:
                test_h_dist_adv.append(hd95(testoutput_adv, testlabel, class_no))

    # store the test metrics
    prediction_map_path = save_path + '/Test_result'
    os.makedirs(prediction_map_path, exist_ok=True)

    # save numerical results:
    result_dictionary = {
        'Test IoU mean': str(np.mean(test_iou)),
        'Test IoU std': str(np.std(test_iou)),
        'Test f1 mean': str(np.mean(test_f1)),
        'Test f1 std': str(np.std(test_f1)),
        'Test H-dist mean': str(np.mean(test_h_dist)),
        'Test H-dist std': str(np.std(test_h_dist)),
        'Test precision mean': str(np.mean(test_precision)),
        'Test precision std': str(np.std(test_precision)),
        'Test recall mean': str(np.mean(test_recall)),
        'Test recall std': str(np.std(test_recall)),
        'Test IoU attack mean': str(np.mean(test_iou_adv)),
        'Test IoU attack std': str(np.std(test_iou_adv)),
        'Test H-dist attack mean': str(np.mean(test_h_dist_adv)),
        'Test H-dist attack std': str(np.std(test_h_dist_adv)),
    }

    with open(prediction_map_path + '/test_results.txt', 'w') as ff:
        ff.write(str(result_dictionary))

    print(
        'Test h-dist: {:.4f}, '
        'Test iou: {:.4f}, '.format(np.mean(test_h_dist), np.mean(test_iou)))