Example #1
def trainSingleModel(model_name, depth_limit, epochs, width, depth, repeat, lr,
                     lr_schedule, train_dataset, train_batch, data_name,
                     data_augmentation_train, data_augmentation_test,
                     train_loader, validate_data, test_data_1, test_data_2,
                     shuffle, loss, norm, log, no_class, input_channel):
    # :param model_name: network tag, e.g. 'unet', 'Segnet' or 'SOASNet'
    # :param depth_limit: maximum number of downsampling stages
    # :param epochs: training total epochs
    # :param width: first encoder channel number
    # :param depth: depth of the network
    # :param repeat: index of the repeat of the same experiment
    # :param lr: learning rate
    # :param lr_schedule: true or false for learning rate schedule
    # :param train_dataset: training data set
    # :param train_batch: batch size
    # :param data_name: name of the dataset
    # :param data_augmentation_train: augmentation tag for training, e.g. containing 'mixup'
    # :param data_augmentation_test: augmentation tag for testing
    # :param train_loader: training loader
    # :param validate_data: validation data
    # :param test_data_1: first testing data set
    # :param test_data_2: second testing data set
    # :param shuffle: shuffle training data or not
    # :param loss: loss function tag, 'ce', 'dice' or 'hybrid'
    # :param norm: normalisation of the model
    # :param log: log tag for recording experiments
    # :param no_class: 2 for binary or the number of classes for multi-class
    # :param input_channel: 4 for BRATS, 3 for CityScapes
    # :return: full path of the saved model
    device = torch.device('cuda:0')

    # side_output_use = False

    if model_name == 'unet':

        model = UNet(n_channels=input_channel,
                     n_classes=no_class,
                     bilinear=True).to(device=device)

        # model = UNet2(in_channels=1, n_classes=1, depth=4, wf=32, padding=False, batch_norm=True, up_mode='upconv').to(device=device)

    elif model_name == 'Segnet':

        model = SegNet(in_ch=input_channel,
                       width=width,
                       norm=norm,
                       depth=4,
                       n_classes=no_class,
                       dropout=True,
                       side_output=False).to(device=device)

    elif model_name == 'SOASNet_single':

        model = SOASNet_ss(in_ch=input_channel,
                           width=width,
                           depth=depth,
                           norm=norm,
                           n_classes=no_class,
                           mode='low_rank_attn',
                           side_output=False,
                           downsampling_limit=depth_limit).to(device=device)

    elif model_name == 'SOASNet':

        model = SOASNet(in_ch=input_channel,
                        width=width,
                        depth=depth,
                        norm=norm,
                        n_classes=no_class,
                        mode='low_rank_attn',
                        side_output=False,
                        downsampling_limit=depth_limit).to(device=device)

    elif model_name == 'SOASNet_large_kernel':

        model = SOASNet_ls(in_ch=input_channel,
                           width=width,
                           depth=depth,
                           norm=norm,
                           n_classes=no_class,
                           mode='low_rank_attn',
                           side_output=False,
                           downsampling_limit=depth_limit).to(device=device)

    elif model_name == 'SOASNet_multi_attn':

        model = SOASNet_ma(in_ch=input_channel,
                           width=width,
                           depth=depth,
                           norm=norm,
                           n_classes=no_class,
                           mode='low_rank_attn',
                           side_output=False,
                           downsampling_limit=depth_limit).to(device=device)

    elif model_name == 'SOASNet_very_large_kernel':

        model = SOASNet_vls(in_ch=input_channel,
                            width=width,
                            depth=depth,
                            norm=norm,
                            n_classes=no_class,
                            mode='low_rank_attn',
                            side_output=False,
                            downsampling_limit=depth_limit).to(device=device)

    elif model_name == 'SOASNet_segnet':

        model = SOASNet_segnet(
            in_ch=input_channel,
            width=width,
            depth=depth,
            norm=norm,
            n_classes=no_class,
            mode='low_rank_attn',
            side_output=False,
            downsampling_limit=depth_limit).to(device=device)

    elif model_name == 'SOASNet_segnet_skip':

        model = SOASNet_segnet_skip(
            in_ch=input_channel,
            width=width,
            depth=depth,
            norm=norm,
            n_classes=no_class,
            mode='low_rank_attn',
            side_output=False,
            downsampling_limit=depth_limit).to(device=device)

    elif model_name == 'RelayNet':

        model = SOASNet_segnet_skip(
            in_ch=input_channel,
            width=width,
            depth=depth,
            norm=norm,
            n_classes=no_class,
            mode='relaynet',
            side_output=False,
            downsampling_limit=depth_limit).to(device=device)

    elif model_name == 'attn_unet':

        model = AttentionUNet(in_ch=input_channel,
                              width=width,
                              visulisation=False,
                              class_no=no_class).to(device=device)

    else:

        # guard against silent fall-through: `model` would otherwise be undefined
        raise NotImplementedError('Unknown model_name: ' + model_name)

    # ==================================
    training_amount = len(train_dataset)
    iteration_amount = training_amount // train_batch
    iteration_amount = iteration_amount - 1

    model_name = model_name + '_Epoch_' + str(epochs) + \
                 '_Dataset_' + data_name + \
                 '_Batch_' + str(train_batch) + \
                 '_Width_' + str(width) + \
                 '_Loss_' + loss + \
                 '_Norm_' + norm + \
                 '_ShuffleTraining_' + str(shuffle) + \
                 '_Data_Augmentation_Train_' + data_augmentation_train + \
                 '_Data_Augmentation_Test_' + data_augmentation_test + \
                 '_lr_' + str(lr) + \
                 '_Repeat_' + str(repeat)

    print(model_name)

    writer = SummaryWriter('../../Log_' + log + '/' + model_name)

    optimizer = AdamW(model.parameters(),
                      lr=lr,
                      betas=(0.9, 0.999),
                      eps=1e-8,
                      weight_decay=1e-5)

    # if lr_schedule is True:
    #     learning_rate_steps = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

    for epoch in range(epochs):

        model.train()

        running_loss = 0

        # i: index of mini batch
        if 'mixup' not in data_augmentation_train:

            for j, (images, labels, imagename) in enumerate(train_loader):

                images = images.to(device=device, dtype=torch.float32)

                if no_class == 2:

                    labels = labels.to(device=device, dtype=torch.float32)

                else:

                    labels = labels.to(device=device, dtype=torch.long)

                outputs_logits = model(images)

                optimizer.zero_grad()

                # calculate the main loss
                if no_class == 2:

                    if loss == 'dice':

                        main_loss = dice_loss(torch.sigmoid(outputs_logits),
                                              labels)

                    elif loss == 'ce':

                        main_loss = nn.BCEWithLogitsLoss(reduction='mean')(
                            outputs_logits, labels)

                    elif loss == 'hybrid':

                        main_loss = dice_loss(
                            torch.sigmoid(outputs_logits),
                            labels) + nn.BCEWithLogitsLoss(reduction='mean')(
                                outputs_logits, labels)

                else:

                    # CrossEntropyLoss expects raw logits: it applies
                    # log-softmax internally, so the outputs are not
                    # softmaxed first
                    main_loss = nn.CrossEntropyLoss(
                        reduction='mean',
                        ignore_index=8)(outputs_logits, labels.squeeze(1))

                # .item() detaches the scalar so the graph is not kept alive
                running_loss += main_loss.item()

                main_loss.backward()

                optimizer.step()

                # ==============================================================================
                # Calculate training and validation metrics at the last iteration of each epoch
                # ==============================================================================

                if (j + 1) % iteration_amount == 0:

                    if no_class == 2:

                        outputs = torch.sigmoid(outputs_logits)

                        # outputs = (outputs > 0.5).float()

                    else:

                        _, outputs = torch.max(outputs_logits, dim=1)

                        # outputs = outputs.unsqueeze(1)

                        labels = labels.squeeze(1)

                    # print(outputs.shape)

                    # print(labels.shape)

                    # mean_iu = segmentation_scores(labels.cpu().detach().numpy(), outputs.cpu().detach().numpy(), no_class)

                    mean_iu = intersectionAndUnion(outputs.cpu().detach(),
                                                   labels.cpu().detach(),
                                                   no_class)

                    validate_iou, validate_f1, validate_recall, validate_precision = evaluate(
                        data=validate_data,
                        model=model,
                        device=device,
                        class_no=no_class)

                    # print(validate_iou.type)

                    print('Step [{}/{}], '
                          'loss: {:.5f}, '
                          'train iou: {:.5f}, '
                          'val iou: {:.5f}'.format(epoch + 1, epochs,
                                                   running_loss / (j + 1),
                                                   mean_iu, validate_iou))

                    writer.add_scalars(
                        'scalars', {
                            'train iou': mean_iu,
                            'val iou': validate_iou,
                            'val f1': validate_f1,
                            'val recall': validate_recall,
                            'val precision': validate_precision
                        }, epoch + 1)

        else:
            # mix-up strategy requires more calculations:

            for j, (images_1, labels_1, imagename_1, images_2, labels_2,
                    mixed_up_image, lam) in enumerate(train_loader):

                mixed_up_image = mixed_up_image.to(device=device,
                                                   dtype=torch.float32)
                lam = lam.to(device=device, dtype=torch.float32)

                if no_class == 2:
                    labels_1 = labels_1.to(device=device, dtype=torch.float32)
                    labels_2 = labels_2.to(device=device, dtype=torch.float32)
                else:
                    labels_1 = labels_1.to(device=device, dtype=torch.long)
                    labels_2 = labels_2.to(device=device, dtype=torch.long)

                outputs_logits = model(mixed_up_image)

                optimizer.zero_grad()

                # calculate the mixup main loss as a lam-weighted combination
                if no_class == 2:

                    if loss == 'dice':

                        main_loss = lam * dice_loss(
                            torch.sigmoid(outputs_logits),
                            labels_1) + (1 - lam) * dice_loss(
                                torch.sigmoid(outputs_logits), labels_2)

                    elif loss == 'ce':

                        main_loss = lam * nn.BCEWithLogitsLoss(
                            reduction='mean')(outputs_logits, labels_1) + (
                                1 - lam) * nn.BCEWithLogitsLoss(
                                    reduction='mean')(outputs_logits, labels_2)

                    elif loss == 'hybrid':

                        main_loss = lam * dice_loss(torch.sigmoid(outputs_logits), labels_1) \
                                    + (1 - lam) * dice_loss(torch.sigmoid(outputs_logits), labels_2) \
                                    + lam * nn.BCEWithLogitsLoss(reduction='mean')(outputs_logits, labels_1) \
                                    + (1 - lam) * nn.BCEWithLogitsLoss(reduction='mean')(outputs_logits, labels_2)

                else:

                    main_loss = lam * nn.CrossEntropyLoss(reduction='mean')(
                        outputs_logits, labels_1.squeeze(1)) + (
                            1 - lam) * nn.CrossEntropyLoss(reduction='mean')(
                                outputs_logits, labels_2.squeeze(1))

                running_loss += main_loss.mean().item()

                main_loss.mean().backward()

                optimizer.step()

                # ==============================================================================
                # Calculate training and validation metrics at the last iteration of each epoch
                # ==============================================================================
                if (j + 1) % iteration_amount == 0:

                    if no_class == 2:

                        outputs = torch.sigmoid(outputs_logits)

                    else:

                        _, outputs = torch.max(outputs_logits, dim=1)

                        outputs = outputs.unsqueeze(1)

                    mean_iu_1 = segmentation_scores(
                        labels_1.cpu().detach().numpy(),
                        outputs.cpu().detach().numpy(), no_class)

                    mean_iu_2 = segmentation_scores(
                        labels_2.cpu().detach().numpy(),
                        outputs.cpu().detach().numpy(), no_class)

                    mean_iu = lam.data.sum() * mean_iu_1 + (
                        1 - lam.data.sum()) * mean_iu_2

                    validate_iou, validate_f1, validate_recall, validate_precision = evaluate(
                        data=validate_data,
                        model=model,
                        device=device,
                        class_no=no_class)

                    mean_iu = mean_iu.item()

                    print('Step [{}/{}], '
                          'loss: {:.4f}, '
                          'train iou: {:.4f}, '
                          'val iou: {:.4f}'.format(epoch + 1, epochs,
                                                   running_loss / (j + 1),
                                                   mean_iu, validate_iou))

                    writer.add_scalars(
                        'scalars', {
                            'train iou': mean_iu,
                            'val iou': validate_iou,
                            'val f1': validate_f1,
                            'val recall': validate_recall,
                            'val precision': validate_precision
                        }, epoch + 1)

        if lr_schedule is True:
            # polynomial decay of the learning rate at the end of each epoch
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr * ((1 - epoch / epochs)**0.999)

    # save model
    save_folder = '../../saved_models_' + log

    try:
        os.makedirs(save_folder)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise

    save_model_name = model_name + '_Final'

    save_model_name_full = save_folder + '/' + save_model_name + '.pt'

    torch.save(model, save_model_name_full)
    # =======================================================================
    # testing (disabled during training, because it is too slow)
    # =======================================================================
    save_results_folder = save_folder + '/testing_results_' + model_name

    try:
        os.makedirs(save_results_folder)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise

    (test_iou_1, test_f1_1, test_recall_1, test_precision_1, mse_1,
     test_iou_2, test_f1_2, test_recall_2, test_precision_2, mse_2,
     outputs_1, outputs_2) = test(
        data_1=test_data_1,
        data_2=test_data_2,
        model=model,
        device=device,
        class_no=no_class,
        save_location=save_results_folder)

    print('test iou data 1: {:.4f}, '
          'test mse data 1: {:.4f}, '
          'test f1 data 1: {:.4f},'
          'test recall data 1: {:.4f}, '
          'test precision data 1: {:.4f}, '.format(test_iou_1, mse_1,
                                                   test_f1_1, test_recall_1,
                                                   test_precision_1))

    print('test iou data 2: {:.4f}, '
          'test mse data 2: {:.4f}, '
          'test f1 data 2: {:.4f},'
          'test recall data 2: {:.4f}, '
          'test precision data 2: {:.4f}, '.format(test_iou_2, mse_2,
                                                   test_f1_2, test_recall_2,
                                                   test_precision_2))

    print('\nTesting finished and results saved.\n')

    return save_model_name_full
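
# Note: the mixup branch above consumes a loader that yields
# (images_1, labels_1, imagename_1, images_2, labels_2, mixed_up_image, lam),
# but the dataset wrapper producing that tuple is not shown in this example.
# A minimal sketch of one plausible implementation follows; the class name,
# the Beta(0.4, 0.4) sampling, the per-sample random pairing, and the
# assumption that the base dataset returns tensors are all illustrative,
# not the original code.
import numpy as np
import torch
from torch.utils.data import Dataset

class MixupSegmentationDataset(Dataset):
    # wraps a base dataset whose __getitem__ returns (image, label, name)
    def __init__(self, base_dataset, alpha=0.4):
        self.base = base_dataset
        self.alpha = alpha

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        image_1, label_1, name_1 = self.base[index]
        # pair each sample with a second, randomly drawn sample
        pair_index = np.random.randint(len(self.base))
        image_2, label_2, _ = self.base[pair_index]
        # mixing coefficient; the default collate_fn stacks the per-sample
        # lam values into the batch tensor used by the training loop
        lam = float(np.random.beta(self.alpha, self.alpha))
        mixed_up_image = lam * image_1 + (1 - lam) * image_2
        return (image_1, label_1, name_1, image_2, label_2,
                mixed_up_image, torch.tensor(lam, dtype=torch.float32))
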
Example #2
def trainSingleModel(model, model_name, num_epochs, learning_rate, datasettag,
                     train_dataset, train_batchsize, trainloader,
                     validateloader, testdata, reverse_mode, lr_schedule,
                     class_no):

    # change log names
    training_amount = len(train_dataset)

    iteration_amount = training_amount // train_batchsize

    iteration_amount = iteration_amount - 1

    device = torch.device('cuda')

    lr_str = str(learning_rate)

    epoches_str = str(num_epochs)

    save_model_name = model_name + '_' + datasettag + '_e' + epoches_str + '_lr' + lr_str

    saved_information_path = './Results'
    try:
        os.mkdir(saved_information_path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass

    saved_information_path = saved_information_path + '/' + save_model_name
    try:
        os.mkdir(saved_information_path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass

    saved_model_path = saved_information_path + '/trained_models'
    try:
        os.mkdir(saved_model_path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass

    print('The current model is:')
    print(save_model_name)
    print('\n')

    writer = SummaryWriter('./Results/Log_' + datasettag + '/' +
                           save_model_name)

    model.to(device)

    threshold = torch.tensor([0.5],
                             dtype=torch.float32,
                             device=device,
                             requires_grad=False)
    upper = torch.tensor([1.0],
                         dtype=torch.float32,
                         device=device,
                         requires_grad=False)
    lower = torch.tensor([0.0],
                         dtype=torch.float32,
                         device=device,
                         requires_grad=False)

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=learning_rate,
                                  betas=(0.9, 0.999),
                                  eps=1e-8,
                                  weight_decay=1e-5)

    if lr_schedule is True:

        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=10, threshold=0.001)
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[num_epochs // 2, 3 * num_epochs // 4],
            gamma=0.1)

    start = timeit.default_timer()

    for epoch in range(num_epochs):

        model.train()
        train_iou = []
        train_loss = []

        # j: index of iteration
        for j, (images, labels, imagename) in enumerate(trainloader):

            optimizer.zero_grad()

            images = images.to(device=device, dtype=torch.float32)

            if class_no == 2:
                labels = labels.to(device=device, dtype=torch.float32)
            else:
                labels = labels.to(device=device, dtype=torch.long)

            outputs = model(images)

            if class_no == 2:
                prob_outputs = torch.sigmoid(outputs)
                loss = dice_loss(prob_outputs, labels)
                class_outputs = torch.where(prob_outputs > threshold, upper,
                                            lower)
            else:
                # CrossEntropyLoss expects raw logits: it applies log-softmax
                # internally, so the outputs are not softmaxed first
                loss = nn.CrossEntropyLoss(reduction='mean')(outputs,
                                                             labels.squeeze(1))
                _, class_outputs = torch.max(outputs, dim=1)

            loss.backward()
            optimizer.step()

            mean_iu_, _, __ = segmentation_scores(labels, class_outputs,
                                                  class_no)
            train_iou.append(mean_iu_)
            train_loss.append(loss.item())

        if lr_schedule is True:
            # scheduler.step(validate_iou)
            scheduler.step()
        else:
            pass

        model.eval()

        with torch.no_grad():

            validate_iou = []
            validate_f1 = []
            validate_h_dist = []

            for i, (val_images, val_label,
                    imagename) in enumerate(validateloader):

                val_img = val_images.to(device=device, dtype=torch.float32)

                if class_no == 2:
                    val_label = val_label.to(device=device,
                                             dtype=torch.float32)
                else:
                    val_label = val_label.to(device=device, dtype=torch.long)

                assert torch.max(val_label) != 100.0

                val_outputs = model(val_img)
                if class_no == 2:
                    val_class_outputs = torch.sigmoid(val_outputs)
                    val_class_outputs = (val_class_outputs > 0.5).float()
                else:
                    val_class_outputs = torch.softmax(val_outputs, dim=1)
                    _, val_class_outputs = torch.max(val_class_outputs, dim=1)

                # b, c, h, w = val_label.size()
                # val_class_outputs = val_class_outputs.reshape(b, c, h, w)

                eval_mean_iu_, _, __ = segmentation_scores(
                    val_label, val_class_outputs, class_no)
                eval_f1_, eval_recall_, eval_precision_, eTP, eTN, eFP, eFN, eP, eN = f1_score(
                    val_label, val_class_outputs, class_no)

                validate_iou.append(eval_mean_iu_)
                validate_f1.append(eval_f1_)

                if (val_class_outputs == 1).sum() > 1 and (
                        val_label == 1).sum() > 1 and class_no == 2:
                    v_dist_ = hd95(val_class_outputs, val_label, class_no)
                    validate_h_dist.append(v_dist_)

        print('Step [{}/{}], '
              'Train loss: {:.4f}, '
              'Train iou: {:.4f}, '
              'val iou:{:.4f}, '.format(epoch + 1, num_epochs,
                                        np.nanmean(train_loss),
                                        np.nanmean(train_iou),
                                        np.nanmean(validate_iou)))

        writer.add_scalars(
            'acc metrics', {
                'train iou': np.nanmean(train_iou),
                'val iou': np.nanmean(validate_iou),
                'val f1': np.nanmean(validate_f1)
            }, epoch + 1)

        # save the models from the last epochs; test() later averages over them
        if epoch > num_epochs - 10:

            save_model_name_full = saved_model_path + '/epoch' + str(epoch)
            save_model_name_full = save_model_name_full + '.pt'
            path_model = save_model_name_full
            torch.save(model, path_model)

    test(testdata,
         saved_model_path,
         device,
         reverse_mode=reverse_mode,
         class_no=class_no,
         save_path=saved_model_path)

    # save model
    stop = timeit.default_timer()
    print('Time: ', stop - start)
    print('\nTraining finished and model saved\n')

    return model
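
# Note: dice_loss is imported from elsewhere in this repository and is not
# shown in these examples. A minimal soft Dice loss with a matching call
# signature (probabilities and same-shaped binary targets) might look like
# the sketch below; the smoothing epsilon is an assumption.
import torch

def dice_loss(probs, targets, eps=1e-8):
    # flatten each sample and compute 1 - the soft Dice coefficient
    probs = probs.contiguous().view(probs.size(0), -1)
    targets = targets.contiguous().view(targets.size(0), -1)
    intersection = (probs * targets).sum(dim=1)
    union = probs.sum(dim=1) + targets.sum(dim=1)
    dice = (2.0 * intersection + eps) / (union + eps)
    return 1.0 - dice.mean()
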
Example #3
def test(testdata, models_path, device, reverse_mode, class_no, save_path):

    all_models = glob.glob(os.path.join(models_path, '*.pt'))

    # with torch.no_grad():

    test_f1 = []
    test_iou = []
    test_h_dist = []
    test_acc = []
    test_w_acc = []
    test_recall = []
    test_precision = []

    test_bf = []

    test_iou_adv = []
    test_h_dist_adv = []

    for model_path in all_models:

        model = torch.load(model_path)
        model.eval()

        for j, (testimg, testlabel, testname) in enumerate(testdata):
            # validate batch size will be set up as 2
            # testimg = torch.from_numpy(testimg).to(device=device, dtype=torch.float32)
            # testlabel = torch.from_numpy(testlabel).to(device=device, dtype=torch.float32)

            testimg = testimg.to(device=device, dtype=torch.float32)

            if class_no == 2:
                testlabel = testlabel.to(device=device, dtype=torch.float32)
            else:
                testlabel = testlabel.to(device=device, dtype=torch.long)

            # b, c, h, w = testimg.size()
            # testimg = testimg[:, 0, :, :].view(b, 1, h, w).contiguous()
            # testlabel = testlabel[:, 0, :, :].view(b, 1, h, w).contiguous()

            if torch.max(testlabel) == 255.:
                testlabel = testlabel / 255.

            testimg.requires_grad = True

            # c, h, w = testimg.size()
            # testimg = testimg.expand(1, c, h, w)

            testoutput_logits = model(testimg)

            if class_no == 2:
                prob_testoutput = torch.sigmoid(testoutput_logits)
                testoutput = (prob_testoutput > 0.5).float()
            else:
                prob_testoutput = torch.softmax(testoutput_logits, dim=1)
                _, testoutput = torch.max(prob_testoutput, dim=1)

            # attack testing data: the FGSM loss is built on the raw logits,
            # because nn.CrossEntropyLoss applies log-softmax internally
            if class_no == 2:
                loss = dice_loss(prob_testoutput, testlabel)
            else:
                loss = nn.CrossEntropyLoss(reduction='mean')(
                    testoutput_logits, testlabel.squeeze(1))

            model.zero_grad()
            loss.backward()
            data_grad = testimg.grad.data
            perturbed_data = fgsm_attack(testimg, 0.2, data_grad)
            output_attack = model(perturbed_data)

            if class_no == 2:
                output_attack = torch.sigmoid(output_attack)
                output_attack = (output_attack > 0.5).float()
            else:
                output_attack = torch.softmax(output_attack, dim=1)
                _, output_attack = torch.max(output_attack, dim=1)

            mean_iu_, acc_, w_acc_ = segmentation_scores(
                testlabel, testoutput, class_no)

            test_iou.append(mean_iu_)
            test_acc.append(acc_)
            test_w_acc.append(w_acc_)

            mean_iu_adv_, _, __ = segmentation_scores(testlabel, output_attack,
                                                      class_no)
            test_iou_adv.append(mean_iu_adv_)

            if (testoutput == 1).sum() > 1 and (
                    testlabel == 1).sum() > 1 and class_no == 2:
                h_dis95_ = hd95(testoutput, testlabel, class_no)
                test_h_dist.append(h_dis95_)

            if (output_attack == 1).sum() > 1 and (
                    testlabel == 1).sum() > 1 and class_no == 2:
                h_dis95_attack_ = hd95(output_attack, testlabel, class_no)
                test_h_dist_adv.append(h_dis95_attack_)

            f1_, recall_, precision_, TP, TN, FP, FN, P, N = f1_score(
                testlabel, testoutput, class_no)

            # small constant guards against zero precision and recall
            bf_ = 2 * precision_ * recall_ / (recall_ + precision_ + 1e-10)

            test_f1.append(f1_)
            test_recall.append(recall_)
            test_precision.append(precision_)
            test_bf.append(bf_)

    prediction_map_path = save_path + '/Test'

    try:

        os.mkdir(prediction_map_path)

    except OSError as exc:

        if exc.errno != errno.EEXIST:

            raise

        pass

    result_dictionary = {
        'Test IoU mean': str(np.mean(test_iou)),
        'Test IoU std': str(np.std(test_iou)),
        'Test Acc mean': str(np.mean(test_acc)),
        'Test Acc std': str(np.std(test_acc)),
        'Test W ACC mean': str(np.mean(test_w_acc)),
        'Test W ACC std': str(np.std(test_w_acc)),
        'Test BF mean': str(np.mean(test_bf)),
        'Test BF std': str(np.std(test_bf)),
        'Test f1 mean': str(np.mean(test_f1)),
        'Test f1 std': str(np.std(test_f1)),
        'Test H-dist mean': str(np.mean(test_h_dist)),
        'Test H-dist std': str(np.std(test_h_dist)),
        'Test precision mean': str(np.mean(test_precision)),
        'Test precision std': str(np.std(test_precision)),
        'Test recall mean': str(np.mean(test_recall)),
        'Test recall std': str(np.std(test_recall)),
        'Test IoU attack mean': str(np.mean(test_iou_adv)),
        'Test IoU attack std': str(np.std(test_iou_adv)),
        'Test H-dist attack mean': str(np.mean(test_h_dist_adv)),
        'Test H-dist attack std': str(np.std(test_h_dist_adv)),
    }

    ff_path = prediction_map_path + '/test_result_data.txt'
    with open(ff_path, 'w') as ff:
        ff.write(str(result_dictionary))

    print('Test h-dist: {:.4f}, '
          'Test iou: {:.4f}, '.format(np.mean(test_h_dist), np.mean(test_iou)))
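
# Note: fgsm_attack(image, 0.2, data_grad) is called throughout with an image,
# an epsilon of 0.2 and the gradient of the loss w.r.t. that image. The
# standard fast gradient sign method matches that signature; a sketch is
# below, and the [0, 1] clamp is an assumption about input normalisation.
import torch

def fgsm_attack(image, epsilon, data_grad):
    # step in the direction of the sign of the gradient, i.e. the
    # direction that locally increases the loss
    perturbed_image = image + epsilon * data_grad.sign()
    # keep the perturbed image in a valid intensity range (assumed [0, 1])
    return torch.clamp(perturbed_image, 0.0, 1.0)
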
Example #4
def evaluate(evaluatedata, model, device, reverse_mode, class_no):

    model.eval()

    f1 = 0
    test_iou = 0
    test_h_dist = 0
    recall = 0
    precision = 0

    FPs_Ns = 0
    FNs_Ps = 0
    FPs_Ps = 0
    FNs_Ns = 0
    TPs = 0
    TNs = 0
    FNs = 0
    FPs = 0
    Ps = 0
    Ns = 0

    test_iou_attack = 0
    test_h_dist_attack = 0

    effective_h = 0
    effective_h_attack = 0

    for j, (testimg, testlabel, testname) in enumerate(evaluatedata):
        # validate batch size will be set up as 2
        # j + 1 counts the batches, used for averaging the metrics returned below


        testimg = testimg.to(device=device, dtype=torch.float32)
        testlabel = testlabel.to(device=device, dtype=torch.float32)

        # b, c, h, w = testimg.size()
        # testimg = testimg[:, 0, :, :].view(b, 1, h, w).contiguous()
        # testlabel = testlabel[:, 0, :, :].view(b, 1, h, w).contiguous()

        # if torch.max(testlabel) == 255.:
        #     testlabel = testlabel / 255.

        testimg.requires_grad = True


        threshold = torch.tensor([0.5],
                                 dtype=torch.float32,
                                 device=device,
                                 requires_grad=False)

        upper = torch.tensor([1.0],
                             dtype=torch.float32,
                             device=device,
                             requires_grad=False)

        lower = torch.tensor([0.0],
                             dtype=torch.float32,
                             device=device,
                             requires_grad=False)

        testoutput = model(testimg)

        prob_testoutput = torch.sigmoid(testoutput)

        # attack testing data:
        loss = dice_loss(prob_testoutput, testlabel)
        model.zero_grad()
        loss.backward()
        data_grad = testimg.grad.data
        perturbed_data = fgsm_attack(testimg, 0.2, data_grad)
        output_attack = model(perturbed_data)
        output_attack = torch.sigmoid(output_attack)

        if reverse_mode is True:

            testoutput = torch.where(prob_testoutput < threshold, upper, lower)
            output_attack = torch.where(output_attack < threshold, upper,
                                        lower)

        else:

            testoutput = torch.where(prob_testoutput > threshold, upper, lower)
            output_attack = torch.where(output_attack > threshold, upper,
                                        lower)

        mean_iu_, _, __ = segmentation_scores(testlabel, testoutput, class_no)
        mean_iu_attack_, _, __ = segmentation_scores(testlabel, output_attack,
                                                     class_no)

        if (testoutput == 1).sum() > 1 and (testlabel == 1).sum() > 1:

            h_dis95_ = hd95(testoutput, testlabel, class_no)
            test_h_dist += h_dis95_
            effective_h = effective_h + 1

        if (output_attack == 1).sum() > 1 and (testlabel == 1).sum() > 1:

            h_dis95_attack_ = hd95(output_attack, testlabel, class_no)
            effective_h_attack = effective_h_attack + 1
            test_h_dist_attack += h_dis95_attack_

        f1_, recall_, precision_, TP, TN, FP, FN, P, N = f1_score(
            testlabel, testoutput, class_no)

        f1 += f1_
        test_iou += mean_iu_
        recall += recall_
        precision += precision_
        TPs += TP
        TNs += TN
        FPs += FP
        FNs += FN
        Ps += P
        Ns += N
        FNs_Ps += (FN + 1e-10) / (P + 1e-10)
        FPs_Ns += (FP + 1e-10) / (N + 1e-10)
        FNs_Ns += (FN + 1e-10) / (N + 1e-10)
        FPs_Ps += (FP + 1e-10) / (P + 1e-10)

        test_iou_attack += mean_iu_attack_

    n_batches = j + 1
    return (test_iou / n_batches, f1 / n_batches, recall / n_batches,
            precision / n_batches, FPs_Ns / n_batches, FPs_Ps / n_batches,
            FNs_Ns / n_batches, FNs_Ps / n_batches, FPs / n_batches,
            FNs / n_batches, TPs / n_batches, TNs / n_batches,
            Ps / n_batches, Ns / n_batches, test_h_dist / (effective_h + 1),
            test_iou_attack / n_batches,
            test_h_dist_attack / (effective_h_attack + 1))
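
# Note: segmentation_scores is used in two forms across these examples: a
# three-value form (mean IoU, accuracy, class-weighted accuracy) in examples
# #2 to #4 and a single-value form in example #5. A minimal sketch of the
# three-value form, built from a confusion matrix, is below; the exact
# averaging in the original helper is an assumption.
import numpy as np
import torch

def segmentation_scores(label_trues, label_preds, n_class):
    # accept torch tensors or numpy arrays of class indices
    if torch.is_tensor(label_trues):
        label_trues = label_trues.detach().cpu().numpy()
    if torch.is_tensor(label_preds):
        label_preds = label_preds.detach().cpu().numpy()
    label_trues = label_trues.astype(np.int64).flatten()
    label_preds = label_preds.astype(np.int64).flatten()
    # confusion matrix: rows are true classes, columns are predictions
    hist = np.bincount(n_class * label_trues + label_preds,
                       minlength=n_class ** 2).reshape(n_class, n_class)
    acc = np.diag(hist).sum() / (hist.sum() + 1e-8)
    per_class_acc = np.diag(hist) / (hist.sum(axis=1) + 1e-8)
    iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0)
                           - np.diag(hist) + 1e-8)
    return np.nanmean(iou), acc, np.nanmean(per_class_acc)
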
Example #5
def trainSingleModel(model,
                     model_name,
                     num_epochs,
                     learning_rate,
                     datasettag,
                     train_dataset,
                     train_batchsize,
                     trainloader,
                     validateloader,
                     testdata,
                     reverse_mode,
                     lr_schedule,
                     class_no):

    # change log names
    training_amount = len(train_dataset)

    iteration_amount = training_amount // train_batchsize

    iteration_amount = iteration_amount - 1

    device = torch.device('cuda')

    lr_str = str(learning_rate)

    epoches_str = str(num_epochs)

    save_model_name = model_name + '_' + datasettag + '_e' + epoches_str + '_lr' + lr_str

    saved_information_path = './Results'
    try:
        os.mkdir(saved_information_path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass

    saved_information_path = saved_information_path + '/' + save_model_name
    try:
        os.mkdir(saved_information_path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass

    saved_model_path = saved_information_path + '/trained_models'
    try:
        os.mkdir(saved_model_path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass

    print('The current model is:')
    print(save_model_name)
    print('\n')

    writer = SummaryWriter('./Results/Log_' + datasettag + '/' + save_model_name)

    model.to(device)

    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5)

    if lr_schedule is True:

        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=10, threshold=0.001)
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[num_epochs // 2, 3*num_epochs // 4], gamma=0.1)

    start = timeit.default_timer()

    for epoch in range(num_epochs):

        model.train()

        h_dists = 0
        f1 = 0
        accuracy_iou = 0
        running_loss = 0
        recall = 0
        precision = 0

        t_FPs_Ns = 0
        t_FPs_Ps = 0
        t_FNs_Ns = 0
        t_FNs_Ps = 0
        t_FPs = 0
        t_FNs = 0
        t_TPs = 0
        t_TNs = 0
        t_Ps = 0
        t_Ns = 0

        effective_h = 0

        # j: index of iteration
        for j, (images, labels, imagename) in enumerate(trainloader):

            # check training data:
            # image = images[0, :, :, :].squeeze().detach().cpu().numpy()
            # label = labels[0, :, :, :].squeeze().detach().cpu().numpy()
            # image = np.transpose(image, (1, 2, 0))
            # label = np.expand_dims(label, axis=2)
            # label = np.concatenate((label, label, label), axis=2)
            # plt.imshow(0.5*image + 0.5*label)
            # plt.show()

            optimizer.zero_grad()

            images = images.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.float32)

            images.requires_grad = True

            if reverse_mode is True:

                inverse_labels = torch.ones_like(labels)
                inverse_labels = inverse_labels.to(device=device, dtype=torch.float32)
                inverse_labels = inverse_labels - labels
            else:
                pass

            outputs = model(images)
            prob_outputs = torch.sigmoid(outputs)

            if reverse_mode is True:
                loss = dice_loss(prob_outputs, inverse_labels)
            else:
                loss = dice_loss(prob_outputs, labels)

            loss.backward()
            optimizer.step()

            # the task of binary segmentation is too easy; to compensate for
            # the simplicity of the task, we add adversarial noise to the
            # images before computing the training metrics:
            data_grad = images.grad.data
            perturbed_data = fgsm_attack(images, 0.2, data_grad)
            prob_outputs = model(perturbed_data)
            prob_outputs = torch.sigmoid(prob_outputs)

            if reverse_mode is True:
                class_outputs = torch.where(prob_outputs < threshold, upper, lower)
            else:
                class_outputs = torch.where(prob_outputs > threshold, upper, lower)

            if class_no == 2:
                # the Hausdorff distance is only defined for the binary case
                if (class_outputs == 1).sum() > 1 and (labels == 1).sum() > 1:
                    dist_ = hd95(class_outputs, labels, class_no)
                    h_dists += dist_
                    effective_h = effective_h + 1

            mean_iu_ = segmentation_scores(labels, class_outputs, class_no)
            f1_, recall_, precision_, TPs_, TNs_, FPs_, FNs_, Ps_, Ns_ = f1_score(labels, class_outputs, class_no)

            running_loss += loss.item()
            f1 += f1_
            accuracy_iou += mean_iu_
            recall += recall_
            precision += precision_
            t_TPs += TPs_
            t_TNs += TNs_
            t_FPs += FPs_
            t_FNs += FNs_
            t_Ps += Ps_
            t_Ns += Ns_
            t_FNs_Ps += (FNs_ + 1e-8) / (Ps_ + 1e-8)
            t_FPs_Ns += (FPs_ + 1e-8) / (Ns_ + 1e-8)
            t_FNs_Ns += (FNs_ + 1e-8) / (Ns_ + 1e-8)
            t_FPs_Ps += (FPs_ + 1e-8) / (Ps_ + 1e-8)

            if (j + 1) % iteration_amount == 0:

                (validate_iou, validate_f1, validate_recall,
                 validate_precision, v_FPs_Ns, v_FPs_Ps, v_FNs_Ns, v_FNs_Ps,
                 v_FPs, v_FNs, v_TPs, v_TNs, v_Ps,
                 v_Ns, v_h_dist) = evaluate(validateloader,
                                            model,
                                            device,
                                            reverse_mode=reverse_mode,
                                            class_no=class_no)

                print(
                    'Step [{}/{}], Train loss: {:.4f}, '
                    'Train iou: {:.4f}, '
                    'Train h-dist: {:.4f}, '
                    'Val iou: {:.4f}, '
                    'Val h-dist: {:.4f}'.format(epoch + 1, num_epochs,
                                                   running_loss / (j + 1),
                                                   accuracy_iou / (j + 1),
                                                   h_dists / (effective_h + 1),
                                                   validate_iou,
                                                   v_h_dist))

                # ================================================================== #
                #                        TensorboardX Logging                        #
                # ================================================================== #

                writer.add_scalars('acc metrics', {'train iou': accuracy_iou / (j+1),
                                                   'train hausdorff dist': h_dists / (effective_h+1),
                                                   'val iou': validate_iou,
                                                   'val hausdorff distance': v_h_dist,
                                                   'loss': running_loss / (j+1)}, epoch + 1)

                writer.add_scalars('train confusion matrices analysis', {'train FPs/Ns': t_FPs_Ns / (j+1),
                                                                         'train FNs/Ps': t_FNs_Ps / (j+1),
                                                                         'train FPs/Ps': t_FPs_Ps / (j+1),
                                                                         'train FNs/Ns': t_FNs_Ns / (j+1),
                                                                         'train FNs': t_FNs / (j+1),
                                                                         'train FPs': t_FPs / (j+1),
                                                                         'train TNs': t_TNs / (j+1),
                                                                         'train TPs': t_TPs / (j+1),
                                                                         'train Ns': t_Ns / (j+1),
                                                                         'train Ps': t_Ps / (j+1),
                                                                         'train imbalance': t_Ps / (t_Ps + t_Ns)}, epoch + 1)

                writer.add_scalars('val confusion matrices analysis', {'val FPs/Ns': v_FPs_Ns,
                                                                       'val FNs/Ps': v_FNs_Ps,
                                                                       'val FPs/Ps': v_FPs_Ps,
                                                                       'val FNs/Ns': v_FNs_Ns,
                                                                       'val FNs': v_FNs,
                                                                       'val FPs': v_FPs,
                                                                       'val TNs': v_TNs,
                                                                       'val TPs': v_TPs,
                                                                       'val Ns': v_Ns,
                                                                       'val Ps': v_Ps,
                                                                       'val imbalance': v_Ps / (v_Ps + v_Ns)}, epoch + 1)
            else:
                pass

            # A learning-rate warm-up plan for the fn attention models:
            # the rate is ramped up linearly across the iterations of each of
            # the first 10 epochs; without the warm-up they are hard to train
            if ('fn' in model_name or 'FN' in model_name) and reverse_mode is True:
                if epoch < 10:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = learning_rate * (j / len(trainloader))

        if lr_schedule is True:
            scheduler.step()
        else:
            pass

        # save models at last 10 epochs
        if epoch >= (num_epochs - 10):
            save_model_name_full = saved_model_path + '/' + save_model_name + '_epoch' + str(epoch) + '.pt'
            path_model = save_model_name_full
            torch.save(model, path_model)

    # Test on all models and average them:
    test(testdata,
         saved_model_path,
         device,
         reverse_mode=reverse_mode,
         class_no=class_no,
         save_path=saved_information_path)

    # save model
    stop = timeit.default_timer()
    print('Time: ', stop - start)
    print('\n')
    print('\nTraining finished and model saved\n')

    return model
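
# Note: hd95(output, label, class_no) appears to be a project-specific
# wrapper around a 95th-percentile Hausdorff distance; it is not shown in
# these examples. A plausible sketch using medpy's binary hd95 is below; the
# dependency on medpy, the 0.5 binarisation, and ignoring class_no in the
# binary case are all assumptions.
from medpy.metric.binary import hd95 as medpy_hd95

def hd95(outputs, labels, class_no):
    # move to numpy binary masks; callers guard against empty masks,
    # which medpy cannot handle
    outputs = outputs.detach().cpu().numpy() > 0.5
    labels = labels.detach().cpu().numpy() > 0.5
    return medpy_hd95(outputs, labels)
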
def test(testdata,
         models_path,
         device,
         reverse_mode,
         class_no,
         save_path):

    all_models = glob.glob(os.path.join(models_path, '*.pt'))

    test_f1 = []
    test_iou = []
    test_h_dist = []
    test_recall = []
    test_precision = []

    test_iou_adv = []
    test_h_dist_adv = []

    for model_path in all_models:

        model = torch.load(model_path)
        model.eval()

        for j, (testimg, testlabel, testname) in enumerate(testdata):
            # validate batch size will be set up as 2
            # testimg = torch.from_numpy(testimg).to(device=device, dtype=torch.float32)
            # testlabel = torch.from_numpy(testlabel).to(device=device, dtype=torch.float32)

            testimg = testimg.to(device=device, dtype=torch.float32)
            testimg.requires_grad = True

            testlabel = testlabel.to(device=device, dtype=torch.float32)

            threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
            upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
            lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

            # c, h, w = testimg.size()
            # testimg = testimg.expand(1, c, h, w)

            testoutput = model(testimg)
            # (todo) add for multi-class
            prob_testoutput = torch.sigmoid(testoutput)

            if class_no == 2:

                if reverse_mode is True:

                    testoutput = torch.where(prob_testoutput < threshold, upper, lower)

                else:

                    testoutput = torch.where(prob_testoutput > threshold, upper, lower)

            # metrics before attack:
            mean_iu_ = segmentation_scores(testlabel, testoutput, class_no)
            test_iou.append(mean_iu_)

            if (testoutput == 1).sum() > 1 and (testlabel == 1).sum() > 1:
                h_dis95_ = hd95(testoutput, testlabel, class_no)
                test_h_dist.append(h_dis95_)

            f1_, recall_, precision_, TP, TN, FP, FN, P, N = f1_score(testlabel, testoutput, class_no)
            test_f1.append(f1_)
            test_recall.append(recall_)
            test_precision.append(precision_)

            # attack testing data:
            loss = dice_loss(prob_testoutput, testlabel)
            model.zero_grad()
            loss.backward()
            data_grad = testimg.grad.data
            perturbed_data = fgsm_attack(testimg, 0.2, data_grad)
            prob_testoutput_adv = model(perturbed_data)
            prob_testoutput_adv = torch.sigmoid(prob_testoutput_adv)

            if class_no == 2:

                if reverse_mode is True:

                    testoutput_adv = torch.where(prob_testoutput_adv < threshold, upper, lower)

                else:

                    testoutput_adv = torch.where(prob_testoutput_adv > threshold, upper, lower)

            mean_iu_adv_ = segmentation_scores(testlabel, testoutput_adv, class_no)
            test_iou_adv.append(mean_iu_adv_)

            if (testoutput_adv == 1).sum() > 1 and (testlabel == 1).sum() > 1:
                h_dis95_adv_ = hd95(testoutput_adv, testlabel, class_no)
                test_h_dist_adv.append(h_dis95_adv_)

    # store the test metrics

    prediction_map_path = save_path + '/Test_result'

    try:

        os.mkdir(prediction_map_path)

    except OSError as exc:

        if exc.errno != errno.EEXIST:

            raise

        pass

    # save numerical results:
    result_dictionary = {
        'Test IoU mean': str(np.mean(test_iou)),
        'Test IoU std': str(np.std(test_iou)),
        'Test f1 mean': str(np.mean(test_f1)),
        'Test f1 std': str(np.std(test_f1)),
        'Test H-dist mean': str(np.mean(test_h_dist)),
        'Test H-dist std': str(np.std(test_h_dist)),
        'Test precision mean': str(np.mean(test_precision)),
        'Test precision std': str(np.std(test_precision)),
        'Test recall mean': str(np.mean(test_recall)),
        'Test recall std': str(np.std(test_recall)),
        'Test IoU attack mean': str(np.mean(test_iou_adv)),
        'Test IoU attack std': str(np.std(test_iou_adv)),
        'Test H-dist attack mean': str(np.mean(test_h_dist_adv)),
        'Test H-dist attack std': str(np.std(test_h_dist_adv)),
    }

    ff_path = prediction_map_path + '/test_results.txt'
    with open(ff_path, 'w') as ff:
        ff.write(str(result_dictionary))

    print(
        'Test h-dist: {:.4f}, '
        'Test iou: {:.4f}, '.format(np.mean(test_h_dist), np.mean(test_iou)))
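
# Note: f1_score is imported from elsewhere and returns nine values:
# (f1, recall, precision, TP, TN, FP, FN, P, N). A minimal binary sketch with
# that return shape is below; the 0.5 thresholding of soft labels and the
# 1e-10 smoothing terms are assumptions, and the original multi-class
# behaviour is not reproduced here.
import torch

def f1_score(label_gt, label_pred, n_class):
    # binarise: anything above 0.5 counts as foreground
    gt = (label_gt > 0.5).float().flatten()
    pred = (label_pred > 0.5).float().flatten()
    tp = float((pred * gt).sum())
    tn = float(((1 - pred) * (1 - gt)).sum())
    fp = float((pred * (1 - gt)).sum())
    fn = float(((1 - pred) * gt).sum())
    p = float(gt.sum())        # positives in the ground truth
    n = float((1 - gt).sum())  # negatives in the ground truth
    precision = tp / (tp + fp + 1e-10)
    recall = tp / (tp + fn + 1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)
    return f1, recall, precision, tp, tn, fp, fn, p, n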