Example #1
0
def myloader(args):
    """Build the training and validation loaders for the dataset named in ``args``.

    Only ``"ISLES"`` is supported. Its PNG slices are read from
    ``./Data/ISLES_TRAINING_png``. Images and masks are both converted with a
    plain ``ToTensor`` and no augmentation or equalization is applied.

    Args:
        args: Namespace-like object. ``args.dataset`` selects the dataset and
            ``args.batch_size`` sets the batch size of both loaders.

    Returns:
        Tuple ``(train_loader, val_loader, training_length, val_length,
        num_classes)``. The lengths are the dataset sizes (not the number of
        batches), and ``num_classes`` is 2.

    Raises:
        NotImplementedError: if ``args.dataset`` is not ``"ISLES"``. The
            message names the rejected dataset.
    """
    if args.dataset != "ISLES":
        message = "The dataset is not supported."
        # Keep the console message, but also attach it (with the offending
        # value) to the exception so it is visible in tracebacks.
        print(message)
        raise NotImplementedError("{} (got {!r})".format(message, args.dataset))

    training_dir = './Data/ISLES_TRAINING_png'
    transform = transforms.Compose([transforms.ToTensor()])
    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset(
        'train',
        training_dir,
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False)
    val_set = medicalDataLoader.MedicalImageDataset(
        'val',
        training_dir,
        transform=transform,
        mask_transform=mask_transform,
        equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=5,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            num_workers=5,
                            shuffle=False)

    num_classes = 2
    return train_loader, val_loader, len(train_set), len(val_set), num_classes
Example #2
0
def runInference(argv):
    """Restore a saved network and export its validation-set predictions.

    The ``'val'`` split of ``../DataSet/Bladder_Aug`` is run through the
    restored network. The results are written by ``saveImages`` (PNG) and
    ``saveImagesAsMatlab``.

    Args:
        argv: Command-line style argument list. ``argv[0]`` is the path given
            to ``torch.load``; it is also forwarded to the save helpers as the
            model name.

    Raises:
        Whatever ``torch.load`` raises when the model cannot be restored. The
        error is logged and re-raised, because there is no network to run
        without it.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the inference... ~~~~~~')
    print('-' * 40)

    batch_size_val_save = 1

    transform = transforms.Compose([transforms.ToTensor()])
    mask_transform = transforms.Compose([transforms.ToTensor()])

    root_dir = '../DataSet/Bladder_Aug'

    val_set = medicalDataLoader.MedicalImageDataset(
        'val',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        equalize=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=5,
                                        shuffle=False)

    modelName = argv[0]

    print('...Loading model...')
    try:
        netG = torch.load(modelName)
    except Exception:
        # Without a restored model nothing below can run; surface the real
        # error instead of failing later on an unbound ``netG``.
        print("--------model not restored--------")
        raise
    print("--------model restored--------")

    netG.cuda()

    # To save images as png
    saveImages(netG, val_loader_save_images, batch_size_val_save, 0, modelName)

    # To save images as Matlab
    saveImagesAsMatlab(netG, val_loader_save_images, batch_size_val_save, 0,
                       modelName)

    print("###                               ###")
    print("###   Images saved      ###")
    print("###                               ###")
Example #3
0
def main():
    """Run the ADMM-style semi-supervised loop on one labeled/unlabeled pair per step.

    ``train_set`` is split into disjoint labeled (``split_ratio`` = 20%) and
    unlabeled subsets, and the pretrained ``Enet(2)`` weights are loaded.
    Then, for 300 iterations, one batch is drawn from each subset; pairs
    whose masks have no foreground are skipped. A ``networks`` wrapper runs
    300 ``update`` steps on the pair and displays its state.

    Relies on names not defined in this function (``train_set``,
    ``root_dir``, ``transform``, ``mask_transform``, ``batch_size_val``,
    ``num_workers``, ``device``, ...), presumably module-level.
    """
    ## Here we have to split the fully annotated dataset and unannotated dataset
    split_ratio = 0.2
    random_index = np.random.permutation(len(train_set))
    n_labeled = int(len(random_index) * split_ratio)
    labeled_dataset = copy.deepcopy(train_set)
    labeled_dataset.imgs = [train_set.imgs[x] for x in random_index[:n_labeled]]
    unlabeled_dataset = copy.deepcopy(train_set)
    unlabeled_dataset.imgs = [train_set.imgs[x] for x in random_index[n_labeled:]]
    assert set(unlabeled_dataset.imgs) & set(
        labeled_dataset.imgs) == set(), \
        "there's intersection between labeled and unlabeled training set."

    labeled_dataLoader = DataLoader(labeled_dataset, batch_size=1, num_workers=num_workers, shuffle=True)
    unlabeled_dataLoader = DataLoader(unlabeled_dataset, batch_size=1, num_workers=num_workers, shuffle=True)
    ## Here we terminate the split of labeled and unlabeled data
    ## the validation set is for computing the dice loss.
    val_set = medicalDataLoader.MedicalImageDataset('val', root_dir, transform=transform, mask_transform=mask_transform,
                                                    equalize=False)
    val_loader = DataLoader(val_set, batch_size=batch_size_val, num_workers=num_workers, shuffle=True)

    neural_net = Enet(2)
    ## Uncomment the following line to pretrain the model with few fully labeled data.
    # pretrain(labeled_dataLoader,neural_net,)

    # Load the checkpoint onto the CPU storage first; moved to ``device`` below.
    map_location = lambda storage, loc: storage
    neural_net.load_state_dict(torch.load('../checkpoint/pretrained_net.pth', map_location=map_location))
    neural_net.to(device)
    plt.ion()

    def _next_batch(loader, iterator):
        """Return ``(batch, iterator)``, starting a new pass over ``loader`` when exhausted."""
        try:
            return next(iterator), iterator
        except StopIteration:
            pass
        iterator = iter(loader)  # new (reshuffled) pass; an empty loader still raises
        return next(iterator), iterator

    # Create each iterator once and keep it across iterations, so samples keep
    # flowing past the end of a pass instead of StopIteration escaping main().
    labeled_iter, unlabeled_iter = iter(labeled_dataLoader), iter(unlabeled_dataLoader)
    for iteration in xrange(300):
        ## choose randomly a batch of image from labeled dataset and unlabeled dataset.
        labeled_batch, labeled_iter = _next_batch(labeled_dataLoader, labeled_iter)
        labeled_img, labeled_mask, labeled_weak_mask = labeled_batch[0:3]
        labeled_img, labeled_mask, labeled_weak_mask = labeled_img.to(device), labeled_mask.to(
            device), labeled_weak_mask.to(device)
        unlabeled_batch, unlabeled_iter = _next_batch(unlabeled_dataLoader, unlabeled_iter)
        unlabeled_img, unlabeled_mask = unlabeled_batch[0:2]
        unlabeled_img, unlabeled_mask = unlabeled_img.to(device), unlabeled_mask.to(device)
        # skip those with no foreground masks
        if labeled_mask.sum() == 0 or unlabeled_mask.sum() == 0:
            continue

        # Initialize the ADMM dummy variables for one-batch training
        net = networks(neural_net, lowerbound=10, upperbound=1000)
        for i in xrange(300):
            net.update((labeled_img, labeled_mask), (unlabeled_img, unlabeled_mask))
            # net.show_labeled_pair()
            net.show_ublabel_image()
            net.show_gamma()
            net.show_u()

        net.reset()
Example #4
0
def main():
    """Run the ADMM updates on one labeled/unlabeled image pair per iteration.

    ``train_set`` is split into disjoint labeled (``split_ratio`` = 20%) and
    unlabeled subsets, and the pretrained ``Enet(2)`` weights are loaded.
    Then, for 10000 iterations, one batch is drawn from each subset; pairs
    whose masks have no foreground are skipped. 200 rounds of the
    ``update_theta/gamma/s/u/v`` steps are run on the pair, showing the
    current predictions after each round.

    The network, the current image/mask batches and ``u_r``/``u_s`` are
    published as module globals. This is presumably because the ``update_*``
    helpers read them; confirm before removing the ``global`` statements.
    """
    global net
    global labeled_img, labeled_mask, labeled_weak_mask, unlabeled_img, unlabeled_mask
    global u_r, u_s

    ## Here we have to split the fully annotated dataset and unannotated dataset
    split_ratio = 0.2
    random_index = np.random.permutation(len(train_set))
    n_labeled = int(len(random_index) * split_ratio)
    labeled_dataset = copy.deepcopy(train_set)
    labeled_dataset.imgs = [train_set.imgs[x] for x in random_index[:n_labeled]]
    unlabeled_dataset = copy.deepcopy(train_set)
    unlabeled_dataset.imgs = [train_set.imgs[x] for x in random_index[n_labeled:]]
    assert set(unlabeled_dataset.imgs) & set(
        labeled_dataset.imgs) == set(), \
        "there's intersection between labeled and unlabeled training set."

    labeled_dataLoader = DataLoader(labeled_dataset, batch_size=1, num_workers=num_workers, shuffle=True)
    unlabeled_dataLoader = DataLoader(unlabeled_dataset, batch_size=1, num_workers=num_workers, shuffle=True)
    ## Here we terminate the split of labeled and unlabeled data
    ## the validation set is for computing the dice loss.
    val_set = medicalDataLoader.MedicalImageDataset('val', root_dir, transform=transform, mask_transform=mask_transform,
                                                    equalize=False)
    val_loader = DataLoader(val_set, batch_size=batch_size_val, num_workers=num_workers, shuffle=True)

    net = Enet(2)
    ## Uncomment the following line to pretrain the model with few fully labeled data.
    # pretrain(labeled_dataLoader,net,)
    # Load the checkpoint onto the CPU storage first; moved to ``device`` below.
    map_location = lambda storage, loc: storage
    net.load_state_dict(torch.load('checkpoint/pretrained_net.pth', map_location=map_location))
    net.to(device)

    def _next_batch(loader, iterator):
        """Return ``(batch, iterator)``, starting a new pass over ``loader`` when exhausted."""
        try:
            return next(iterator), iterator
        except StopIteration:
            pass
        iterator = iter(loader)  # new (reshuffled) pass; an empty loader still raises
        return next(iterator), iterator

    # Create each iterator once and keep it across iterations, so samples keep
    # flowing past the end of a pass instead of StopIteration escaping main().
    labeled_iter, unlabeled_iter = iter(labeled_dataLoader), iter(unlabeled_dataLoader)
    for iteration in xrange(10000):
        ## choose randomly a batch of image from labeled dataset and unlabeled dataset.
        labeled_batch, labeled_iter = _next_batch(labeled_dataLoader, labeled_iter)
        labeled_img, labeled_mask, labeled_weak_mask = labeled_batch[0:3]
        labeled_img, labeled_mask, labeled_weak_mask = labeled_img.to(device), labeled_mask.to(device), labeled_weak_mask.to(device)
        unlabeled_batch, unlabeled_iter = _next_batch(unlabeled_dataLoader, unlabeled_iter)
        unlabeled_img, unlabeled_mask = unlabeled_batch[0:2]
        unlabeled_img, unlabeled_mask = unlabeled_img.to(device), unlabeled_mask.to(device)
        if labeled_mask.sum() == 0 or unlabeled_mask.sum() == 0:
            # skip those with no foreground masks
            continue

        # Initialize the ADMM dummy variables for one-batch training
        f_theta_labeled = net(labeled_img)  # shape b,c,w,h
        f_theta_unlabeled = net(unlabeled_img)  # b,c,w,h
        gamma = pred2segmentation(f_theta_unlabeled).detach()  # b, w, h
        s = gamma  # b w h
        u = np.zeros(list(gamma.shape))  # b w h
        v = np.zeros(u.shape)  # b w h
        u_r = 1
        u_s = 1

        for i in xrange(200):
            # One round of ADMM updates; each step sees the values from the previous ones.
            f_theta_labeled, f_theta_unlabeled = update_theta(f_theta_labeled, f_theta_unlabeled, gamma, s, u, v, )
            gamma = update_gamma(f_theta_labeled, f_theta_unlabeled, gamma, s, u, v)
            s = update_s(f_theta_labeled, f_theta_unlabeled, gamma, s, u, v)
            u = update_u(f_theta_labeled, f_theta_unlabeled, gamma, s, u, v)
            v = update_v(f_theta_labeled, f_theta_unlabeled, gamma, s, u, v)

            show_image_mask(labeled_img,labeled_mask,f_theta_labeled[:,1,:,:])
            show_image_mask(unlabeled_img,unlabeled_mask,f_theta_unlabeled[:,1,:,:])
            print()
Example #5
0
# Dataset root and the directory used for model files.
root_dir = '../ACDC-2D-All'
model_dir = 'model'
# Size limits; how they are used is not visible in this snippet.
size_min, size_max = 5, 20

# Expose only GPU "0" to this process (presumably must run before CUDA is
# initialised to take effect).
cuda_device = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device

color_transform = Colorize()

# Images and masks each get their own plain ToTensor pipeline.
transform = transforms.Compose([transforms.ToTensor()])
mask_transform = transforms.Compose([transforms.ToTensor()])

# Augmented, non-equalized training split shared by main().
train_set = medicalDataLoader.MedicalImageDataset(
    'train',
    root_dir,
    transform=transform,
    mask_transform=mask_transform,
    augment=True,
    equalize=False,
)


def main():
    """Split ``train_set`` into disjoint labeled and unlabeled subsets.

    The first ``int(len(train_set) * 0.2)`` entries of a random permutation
    become the labeled subset and the rest the unlabeled one. A shuffled
    batch-size-1 loader is built over the labeled subset. The visible snippet
    ends there and returns ``None``.
    """
    ## Here we have to split the fully annotated dataset and unannotated dataset
    split_ratio = 0.2
    order = np.random.permutation(len(train_set))
    cut = int(len(order) * split_ratio)

    labeled_dataset = copy.deepcopy(train_set)
    labeled_dataset.imgs = [train_set.imgs[k] for k in order[:cut]]

    unlabeled_dataset = copy.deepcopy(train_set)
    unlabeled_dataset.imgs = [train_set.imgs[k] for k in order[cut:]]

    # Sanity check that no image ended up in both subsets.
    assert not (set(unlabeled_dataset.imgs) & set(labeled_dataset.imgs)) != (set() != set()) and \
        set(unlabeled_dataset.imgs) & set(labeled_dataset.imgs) == set(), \
        "there's intersection between labeled and unlabeled training set."

    labeled_dataLoader = DataLoader(
        labeled_dataset,
        batch_size=1,
        num_workers=num_workers,
        shuffle=True,
    )
Example #6
0
def runTraining(args):
    """Train the DAF stacked segmentation network and track validation Dice.

    Training uses the ``'train'`` split of ``./DataSet_Challenge/Val_1``.
    After every epoch the function:

    * computes 2D validation Dice with ``inference``;
    * computes 3D Dice via ``saveImages_for3D`` / ``reconstruct3D`` /
      ``evaluate3D``;
    * writes the running statistics as ``.npy`` files under
      ``Results/Statistics/<modelName>``;
    * saves the weights to ``model/Best_<modelName>.pth`` whenever the mean 2D
      Dice improves on the best so far and exceeds 0.40.

    Args:
        args: Namespace-like object providing ``batch_size``, ``lr``,
            ``epochs`` and ``modelName``.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = args.batch_size
    batch_size_val = 1
    batch_size_val_save = 1
    lr = args.lr

    epoch = args.epochs
    root_dir = './DataSet_Challenge/Val_1'
    model_dir = 'model'

    print(' Dataset: {} '.format(root_dir))

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    train_set = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=True,
                                                      equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=4,
                                        shuffle=False)

    # Initialize
    print("~~~~~~~~~~~ Creating the DAF Stacked model ~~~~~~~~~~")
    net = DAF_stack()
    print(" Model Name: {}".format(args.modelName))
    print(" Model ot create: DAF_Stacked")

    net.apply(weights_init)

    softMax = nn.Softmax()
    CE_loss = nn.CrossEntropyLoss()
    Dice_loss = computeDiceOneHot()
    mseLoss = nn.MSELoss()

    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_loss.cuda()

    optimizer = Adam(net.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)

    BestDice, BestEpoch = 0, 0
    BestDice3D = [0, 0, 0, 0]

    # Per-epoch history of the four 2D Dice values returned by ``inference``.
    d1Val, d2Val, d3Val, d4Val = [], [], [], []
    # Per-epoch mean / std of the four 3D Dice values from ``evaluate3D``.
    d1Val_3D, d2Val_3D, d3Val_3D, d4Val_3D = [], [], [], []
    d1Val_3D_std, d2Val_3D_std, d3Val_3D_std, d4Val_3D_std = [], [], [], []

    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        net.train()
        lossVal = []

        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image, labels, img_names = data

            # prevent batchnorm error for batch of size 1
            if image.size(0) != batch_size:
                continue

            optimizer.zero_grad()
            MRI = to_var(image)
            Segmentation = to_var(labels)

            ################### Train ###################
            net.zero_grad()

            # Network outputs, unpacked explicitly so a change in the number
            # of returned tensors fails loudly here.
            semVector_1_1, \
            semVector_2_1, \
            semVector_1_2, \
            semVector_2_2, \
            semVector_1_3, \
            semVector_2_3, \
            semVector_1_4, \
            semVector_2_4, \
            inp_enc0, \
            inp_enc1, \
            inp_enc2, \
            inp_enc3, \
            inp_enc4, \
            inp_enc5, \
            inp_enc6, \
            inp_enc7, \
            out_enc0, \
            out_enc1, \
            out_enc2, \
            out_enc3, \
            out_enc4, \
            out_enc5, \
            out_enc6, \
            out_enc7, \
            outputs0, \
            outputs1, \
            outputs2, \
            outputs3, \
            outputs0_2, \
            outputs1_2, \
            outputs2_2, \
            outputs3_2 = net(MRI)

            # The eight segmentation heads; their mean is the prediction scored by Dice.
            seg_outputs = (outputs0, outputs1, outputs2, outputs3,
                           outputs0_2, outputs1_2, outputs2_2, outputs3_2)
            segmentation_prediction = sum(seg_outputs) / 8
            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)

            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            # Cross-entropy on each of the eight heads.
            loss_ce = sum(CE_loss(out, Segmentation_class) for out in seg_outputs)

            # MSE between the paired semantic vectors.
            loss_semantic = sum(mseLoss(first, second) for first, second in (
                (semVector_1_1, semVector_2_1),
                (semVector_1_2, semVector_2_2),
                (semVector_1_3, semVector_2_3),
                (semVector_1_4, semVector_2_4)))

            # MSE between each encoder input and its reconstruction.
            loss_rec = sum(mseLoss(inp, out) for inp, out in (
                (inp_enc0, out_enc0),
                (inp_enc1, out_enc1),
                (inp_enc2, out_enc2),
                (inp_enc3, out_enc3),
                (inp_enc4, out_enc4),
                (inp_enc5, out_enc5),
                (inp_enc6, out_enc6),
                (inp_enc7, out_enc7)))

            lossG = loss_ce + 0.25 * loss_semantic + 0.1 * loss_rec

            # Compute the DSC
            DicesN, DicesB, DicesW, DicesT, DicesZ = Dice_loss(segmentation_prediction_ones, Segmentation_planes)

            DiceB = DicesToDice(DicesB)
            DiceW = DicesToDice(DicesW)
            DiceT = DicesToDice(DicesT)
            DiceZ = DicesToDice(DicesZ)

            Dice_score = (DiceB + DiceW + DiceT + DiceZ) / 4

            lossG.backward()
            optimizer.step()

            lossVal.append(lossG.cpu().data.numpy())

            printProgressBar(j + 1, totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f}, Dice1: {:.4f} , Dice2: {:.4f}, , Dice3: {:.4f}, Dice4: {:.4f} ".format(
                                 Dice_score.cpu().data.numpy(),
                                 DiceB.data.cpu().data.numpy(),
                                 DiceW.data.cpu().data.numpy(),
                                 DiceT.data.cpu().data.numpy(),
                                 DiceZ.data.cpu().data.numpy(),))

        printProgressBar(totalImages, totalImages,
                         done="[Training] Epoch: {}, LossG: {:.4f}".format(i, np.mean(lossVal)))

        # Save statistics
        modelName = args.modelName
        directory = 'Results/Statistics/' + modelName

        Losses.append(np.mean(lossVal))

        d1, d2, d3, d4 = inference(net, val_loader)

        d1Val.append(d1)
        d2Val.append(d2)
        d3Val.append(d3)
        d4Val.append(d4)

        if not os.path.exists(directory):
            os.makedirs(directory)

        np.save(os.path.join(directory, 'Losses.npy'), Losses)
        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd2Val.npy'), d2Val)
        np.save(os.path.join(directory, 'd3Val.npy'), d3Val)
        # The fourth Dice history is collected above, so persist it as well.
        np.save(os.path.join(directory, 'd4Val.npy'), d4Val)

        currentDice = (d1 + d2 + d3 + d4) / 4

        print("[val] DSC: (1): {:.4f} (2): {:.4f}  (3): {:.4f} (4): {:.4f}".format(d1, d2, d3, d4))  # MRI

        currentDice = currentDice.data.numpy()

        # Evaluate on 3D
        saveImages_for3D(net, val_loader_save_images, batch_size_val_save, 1000, modelName, False, False)
        reconstruct3D(modelName, 1000, isBest=False)
        DSC_3D = evaluate3D(modelName)

        mean_DSC3D = np.mean(DSC_3D, 0)
        std_DSC3D = np.std(DSC_3D, 0)

        d1Val_3D.append(mean_DSC3D[0])
        d2Val_3D.append(mean_DSC3D[1])
        d3Val_3D.append(mean_DSC3D[2])
        d4Val_3D.append(mean_DSC3D[3])
        d1Val_3D_std.append(std_DSC3D[0])
        d2Val_3D_std.append(std_DSC3D[1])
        d3Val_3D_std.append(std_DSC3D[2])
        d4Val_3D_std.append(std_DSC3D[3])

        # Existing file names are 0-based (d0..d3) for the 1-based lists.
        np.save(os.path.join(directory, 'd0Val_3D.npy'), d1Val_3D)
        np.save(os.path.join(directory, 'd1Val_3D.npy'), d2Val_3D)
        np.save(os.path.join(directory, 'd2Val_3D.npy'), d3Val_3D)
        np.save(os.path.join(directory, 'd3Val_3D.npy'), d4Val_3D)

        np.save(os.path.join(directory, 'd0Val_3D_std.npy'), d1Val_3D_std)
        np.save(os.path.join(directory, 'd1Val_3D_std.npy'), d2Val_3D_std)
        np.save(os.path.join(directory, 'd2Val_3D_std.npy'), d3Val_3D_std)
        np.save(os.path.join(directory, 'd3Val_3D_std.npy'), d4Val_3D_std)

        if currentDice > BestDice:
            BestDice = currentDice

            BestEpoch = i

            if currentDice > 0.40:

                if np.mean(mean_DSC3D) > np.mean(BestDice3D):
                    BestDice3D = mean_DSC3D

                print("###    In 3D -----> MEAN: {}, Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f}   ###".format(np.mean(mean_DSC3D), mean_DSC3D[0], mean_DSC3D[1], mean_DSC3D[2], mean_DSC3D[3]))

                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                torch.save(net.state_dict(), os.path.join(model_dir, "Best_" + modelName + ".pth"), pickle_module=dill)
                reconstruct3D(modelName, 1000, isBest=True)

        print("###                                                       ###")
        print("###    Best Dice: {:.4f} at epoch {} with Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f}   ###".format(BestDice, BestEpoch, d1, d2, d3, d4))
        print("###    Best Dice in 3D: {:.4f} with Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f} ###".format(np.mean(BestDice3D), BestDice3D[0], BestDice3D[1], BestDice3D[2], BestDice3D[3]))
        print("###                                                       ###")

        # Halve the learning rate every (BestEpoch + 50) epochs. Epoch 0 is
        # excluded: 0 % n == 0 would otherwise halve it right after the first
        # epoch. The new rate is computed once and applied to every group.
        if i > 0 and i % (BestEpoch + 50) == 0:
            lr = lr * 0.5
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print(' ----------  New learning Rate: {}'.format(lr))
def runTraining():
    """Write an augmented copy of the ACDC-2D training split to ``dest_dir``.

    For every training image/mask pair, two pairs are saved as RGB PNGs under
    ``<dest_dir>/train/Img`` and ``<dest_dir>/train/GT``:

    * the unmodified pair, under the original name;
    * the pair produced by ``augment``, with the suffix ``_Augm``.

    The ``val`` folders are created but not filled by this function.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting... ~~~~~~')
    print('-' * 40)

    batch_size = 1
    img_Size = 256  # HEART
    #img_Size = 320  # BLADDER

    # NOTE(review): masks are loaded with ``transform`` (ToTensor) rather than
    # MaskToTensor; the ``* 255`` below presumably undoes ToTensor's [0, 1]
    # scaling, so confirm before switching.
    transform = transforms.Compose([transforms.ToTensor()])

    #root_dir = '/home/AN82520/Projects/pyTorch/SegmentationFramework/DataSet/MICCAI_Bladder'
    #dest_dir = '/home/AN82520/Projects/pyTorch/SegmentationFramework/DataSet/Bladder_Aug'

    root_dir = '/export/livia/home/vision/jdolz/Projects/pyTorch/Corstem/ACDC-2D'
    dest_dir = '/export/livia/home/vision/jdolz/Projects/pyTorch/Corstem/ACDC-2D_Augmented'

    # Create every output folder independently, so a partially existing
    # dest_dir (e.g. from an interrupted run) does not break the saves below.
    for sub_dir in ('', '/train/Img', '/train/GT', '/val/Img', '/val/GT'):
        path = dest_dir + sub_dir
        if not os.path.exists(path):
            os.makedirs(path)

    train_set = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform,
                                                      mask_transform=transform)
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=1,
                              shuffle=True)

    # NOTE(review): the validation loader is built but not used below.
    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=transform)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            num_workers=1,
                            shuffle=False)

    print(" ~~~~~~~~~~~ Augmenting dataset ~~~~~~~~~~")
    for data in train_loader:
        image, labels, img_path = data
        image *= 255
        labels *= 255
        # Non-modified images
        img = Image.fromarray(image.numpy()[0].reshape((img_Size, img_Size)))
        mask = Image.fromarray(labels.numpy()[0].reshape((img_Size, img_Size)))

        image, labels = augment(
            image.numpy()[0].reshape((img_Size, img_Size)),
            labels.numpy()[0].reshape((img_Size, img_Size)))

        # File stem relative to the 'Img' folder, e.g. '/name' for '.../Img/name.png'.
        nameImage = img_path[0].split('.png')[0].split('Img')[1]

        img = img.convert('RGB')
        img.save(dest_dir + '/train/Img' + nameImage + '.png', "PNG")
        mask = mask.convert('RGB')
        mask.save(dest_dir + '/train/GT' + nameImage + '.png', "PNG")

        image = image.convert('RGB')
        image.save(dest_dir + '/train/Img' + nameImage + '_Augm.png', "PNG")
        labels = labels.convert('RGB')
        labels.save(dest_dir + '/train/GT' + nameImage + '_Augm.png', "PNG")
Example #8
0
def runTraining():
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = 4
    batch_size_val = 1
    batch_size_val_save = 1
    batch_size_val_savePng = 4
    lr = 0.0001
    epoch = 1000
    root_dir = '../DataSet/Bladder_Aug'
    modelName = 'UNetG_Dilated_Progressive'
    model_dir = 'model'

    transform = transforms.Compose([transforms.ToTensor()])

    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset(
        'train',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset(
        'val',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=5,
                                        shuffle=False)

    val_loader_save_imagesPng = DataLoader(val_set,
                                           batch_size=batch_size_val_savePng,
                                           num_workers=5,
                                           shuffle=False)
    # Initialize
    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
    num_classes = 4

    initial_kernels = 32

    # Load network
    netG = UNetG_Dilated_Progressive(1, initial_kernels, num_classes)
    softMax = nn.Softmax()
    CE_loss = nn.CrossEntropyLoss()
    Dice_loss = computeDiceOneHot()

    if torch.cuda.is_available():
        netG.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_loss.cuda()
    '''try:
        netG = torch.load('./model/Best_UNetG_Dilated_Progressive_Stride_Residual_ChannelsFirst32.pkl')
        print("--------model restored--------")
    except:
        print("--------model not restored--------")
        pass'''

    optimizerG = Adam(netG.parameters(),
                      lr=lr,
                      betas=(0.9, 0.99),
                      amsgrad=False)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizerG,
                                                           mode='max',
                                                           patience=4,
                                                           verbose=True,
                                                           factor=10**-0.5)

    BestDice, BestEpoch = 0, 0

    d1Train = []
    d2Train = []
    d3Train = []
    d1Val = []
    d2Val = []
    d3Val = []

    Losses = []
    Losses1 = []
    Losses05 = []
    Losses025 = []
    Losses0125 = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        netG.train()
        lossVal = []
        lossValD = []
        lossVal1 = []
        lossVal05 = []
        lossVal025 = []
        lossVal0125 = []

        d1TrainTemp = []
        d2TrainTemp = []
        d3TrainTemp = []

        timesAll = []
        success = 0
        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image, labels, img_names = data

            # prevent batchnorm error for batch of size 1
            if image.size(0) != batch_size:
                continue

            optimizerG.zero_grad()
            MRI = to_var(image)
            Segmentation = to_var(labels)

            target_dice = to_var(torch.ones(1))

            ################### Train ###################
            netG.zero_grad()

            deepSupervision = False
            multiTask = False

            start_time = time.time()
            if deepSupervision == False and multiTask == False:
                # No deep supervision
                segmentation_prediction = netG(MRI)
            else:
                # Deep supervision
                if deepSupervision == True:
                    segmentation_prediction, seg_3, seg_2, seg_1 = netG(MRI)
                else:
                    segmentation_prediction, reg_output = netG(MRI)
                    # Regression
                    feats = getValuesRegression(labels)

                    feats_t = torch.from_numpy(feats).float()
                    featsVar = to_var(feats_t)

                    MSE_loss_val = MSE_loss(reg_output, featsVar)

            predClass_y = softMax(segmentation_prediction)

            spentTime = time.time() - start_time

            timesAll.append(spentTime / batch_size)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            # No deep supervision
            CE_lossG = CE_loss(segmentation_prediction, Segmentation_class)
            if deepSupervision == True:

                imageLabels_05 = resizeTensorMaskInSingleImage(
                    Segmentation_class, 2)
                imageLabels_025 = resizeTensorMaskInSingleImage(
                    Segmentation_class, 4)
                imageLabels_0125 = resizeTensorMaskInSingleImage(
                    Segmentation_class, 8)

                CE_lossG_3 = CE_loss(seg_3, imageLabels_05)
                CE_lossG_2 = CE_loss(seg_2, imageLabels_025)
                CE_lossG_1 = CE_loss(seg_1, imageLabels_0125)
            '''weight = torch.ones(4).cuda() # Num classes
            weight[0] = 0.2
            weight[1] = 0.2
            weight[2] = 1
            weight[3] = 1
            
            CE_loss.weight = weight'''

            # Dice loss
            DicesN, DicesB, DicesW, DicesT = Dice_loss(
                segmentation_prediction_ones, Segmentation_planes)
            DiceN = DicesToDice(DicesN)
            DiceB = DicesToDice(DicesB)
            DiceW = DicesToDice(DicesW)
            DiceT = DicesToDice(DicesT)

            Dice_score = (DiceB + DiceW + DiceT) / 3

            if deepSupervision == False and multiTask == False:
                lossG = CE_lossG
            else:
                # Deep supervision
                if deepSupervision == True:
                    lossG = CE_lossG + 0.25 * CE_lossG_3 + 0.1 * CE_lossG_2 + 0.1 * CE_lossG_1
                else:
                    lossG = CE_lossG + 0.000001 * MSE_loss_val

            lossG.backward()
            optimizerG.step()

            lossVal.append(lossG.data[0])
            lossVal1.append(CE_lossG.data[0])

            if deepSupervision == True:
                lossVal05.append(CE_lossG_3.data[0])
                lossVal025.append(CE_lossG_2.data[0])
                lossVal0125.append(CE_lossG_1.data[0])

            printProgressBar(
                j + 1,
                totalImages,
                prefix="[Training] Epoch: {} ".format(i),
                length=15,
                suffix=
                " Mean Dice: {:.4f}, Dice1: {:.4f} , Dice2: {:.4f}, , Dice3: {:.4f} "
                .format(Dice_score.data[0], DiceB.data[0], DiceW.data[0],
                        DiceT.data[0]))

        if deepSupervision == False:
            '''printProgressBar(totalImages, totalImages,
                             done="[Training] Epoch: {}, LossG: {:.4f},".format(i,np.mean(lossVal),np.mean(lossVal1)))'''
            printProgressBar(
                totalImages,
                totalImages,
                done="[Training] Epoch: {}, LossG: {:.4f}, lossMSE: {:.4f}".
                format(i, np.mean(lossVal), np.mean(lossVal1)))
        else:
            printProgressBar(
                totalImages,
                totalImages,
                done=
                "[Training] Epoch: {}, LossG: {:.4f}, Loss4: {:.4f}, Loss3: {:.4f}, Loss2: {:.4f}, Loss1: {:.4f}"
                .format(i, np.mean(lossVal), np.mean(lossVal1),
                        np.mean(lossVal05), np.mean(lossVal025),
                        np.mean(lossVal0125)))

        Losses.append(np.mean(lossVal))

        d1, d2, d3 = inference(netG, val_loader, batch_size, i,
                               deepSupervision)

        d1Val.append(d1)
        d2Val.append(d2)
        d3Val.append(d3)

        d1Train.append(np.mean(d1TrainTemp).data[0])
        d2Train.append(np.mean(d2TrainTemp).data[0])
        d3Train.append(np.mean(d3TrainTemp).data[0])

        mainPath = '../Results/Statistics/' + modelName

        directory = mainPath
        if not os.path.exists(directory):
            os.makedirs(directory)

        ###### Save statistics  ######
        np.save(os.path.join(directory, 'Losses.npy'), Losses)

        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd2Val.npy'), d2Val)
        np.save(os.path.join(directory, 'd3Val.npy'), d3Val)

        np.save(os.path.join(directory, 'd1Train.npy'), d1Train)
        np.save(os.path.join(directory, 'd2Train.npy'), d2Train)
        np.save(os.path.join(directory, 'd3Train.npy'), d3Train)

        currentDice = (d1 + d2 + d3) / 3

        # How many slices with/without tumor correctly classified
        print("[val] DSC: (1): {:.4f} (2): {:.4f}  (3): {:.4f} ".format(
            d1, d2, d3))

        if currentDice > BestDice:
            BestDice = currentDice
            BestDiceT = d1
            BestEpoch = i
            if currentDice > 0.7:
                print(
                    "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
                )
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                torch.save(
                    netG, os.path.join(model_dir,
                                       "Best_" + modelName + ".pkl"))

                # Save images
                saveImages(netG, val_loader_save_images, batch_size_val_save,
                           i, modelName, deepSupervision)
                saveImagesAsMatlab(netG, val_loader_save_images,
                                   batch_size_val_save, i, modelName,
                                   deepSupervision)

        print("###                                                       ###")
        print("###    Best Dice: {:.4f} at epoch {} with DiceT: {:.4f}    ###".
              format(BestDice, BestEpoch, BestDiceT))
        print("###                                                       ###")

        # This is not as we did it in the MedPhys paper
        if i % (BestEpoch + 20):
            for param_group in optimizerG.param_groups:
                param_group['lr'] = lr / 2
Example #9
0
def runTraining():
    """Train ENet on ACDC-2D with partial (weak) cross-entropy plus a size loss.

    Every training step minimises ``Partial_CE`` on the weak annotations plus
    ``Size_Loss`` with the foreground-size bounds ``[minVal, maxVal]``, both
    evaluated on the temperature-scaled softmax of the network output.
    After each epoch the model is evaluated on the validation split with
    ``inference``. The mean training loss and the validation Dice history are
    written to ``Results/Statistics/MIDL/<modelName>/``. Whenever the
    validation Dice improves, the whole model is saved to
    ``model/Best_<modelName>.pkl``.

    The learning rate stays fixed at ``lr`` for all epochs. Returns ``None``.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    # Batch size for training MUST be 1 in weakly/semi supervised learning if we want to impose constraints.
    batch_size = 1
    batch_size_val = 1
    lr = 0.0005
    epoch = 1000
    # Softmax temperature. It is defined before the training loop so that it
    # is always bound when passed to inference() below, even if every
    # training batch was skipped.
    temperature = 0.1
    # This variant has no deep-supervision heads.
    deepSupervision = False

    root_dir = './ACDC-2D-All'
    model_dir = 'model'

    transform = transforms.Compose([transforms.ToTensor()])
    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset(
        'train',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False)
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=False)

    val_set = medicalDataLoader.MedicalImageDataset(
        'val',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        equalize=False)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    # Bounds of the size constraint. The float values go to inference().
    # The tensors given to the size loss hold the int64-truncated values
    # (97 and 1722).
    minVal = 97.9
    maxVal = 1722.6
    minSize = torch.FloatTensor(1)
    minSize.fill_(np.int64(minVal).item())
    maxSize = torch.FloatTensor(1)
    maxSize.fill_(np.int64(maxVal).item())

    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
    num_classes = 2

    netG = ENet(1, num_classes)
    netG.apply(weights_init)
    # Softmax over the class axis. This assumes the network output is
    # (N, C, H, W); for 4-D input, dim=1 is also what the deprecated
    # implicit-dim nn.Softmax() selected.
    softMax = nn.Softmax(dim=1)
    Dice_loss = computeDiceOneHotBinary()

    modelName = 'WeaklySupervised_CE-2_b'
    print(' Model name: {}'.format(modelName))

    partial_ce = Partial_CE()
    size_loss = Size_Loss()

    if torch.cuda.is_available():
        netG.cuda()
        softMax.cuda()
        Dice_loss.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

    BestDice = 0
    dBAll = []
    Losses = []

    print(" ~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    print(' --------- Params: ---------')
    print(' - Lower bound: {}'.format(minVal))
    print(' - Upper bound: {}'.format(maxVal))
    for i in range(epoch):
        netG.train()
        lossVal = []   # total loss per batch
        lossVal1 = []  # partial cross-entropy term per batch

        totalImages = len(train_loader)
        for j, data in enumerate(train_loader):
            image, labels, weak_labels, img_names = data

            # prevent batchnorm error for batch of size 1
            if image.size(0) != batch_size:
                continue

            # The optimizer owns every parameter of netG, so this clears all
            # gradients of the network.
            optimizerG.zero_grad()

            MRI = to_var(image)
            Segmentation = to_var(labels)
            weakAnnotations = to_var(weak_labels)

            segmentation_prediction = netG(MRI)

            predClass_y = softMax(segmentation_prediction / temperature)
            Segmentation_planes = getOneHot_Encoded_Segmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # Both losses use the softmax probabilities, not the raw logits.
            lossCE_numpy = partial_ce(predClass_y, Segmentation_planes,
                                      weakAnnotations)
            sizeLoss_val = size_loss(predClass_y, Segmentation_planes,
                                     Variable(minSize), Variable(maxSize))

            # Dice is only monitored, never optimized: this Dice loss version
            # does not work as a training objective.
            DicesN, DicesB = Dice_loss(segmentation_prediction_ones,
                                       Segmentation_planes)
            DiceN = DicesToDice(DicesN)
            DiceB = DicesToDice(DicesB)
            Dice_score = (DiceB + DiceN) / 2

            # Other objectives tried with this script: the CE term alone, the
            # size loss alone, and CE + MIL_Loss.
            lossG = lossCE_numpy + sizeLoss_val

            lossG.backward(retain_graph=True)
            optimizerG.step()

            lossVal.append(lossG.data[0])
            lossVal1.append(lossCE_numpy.data[0])

            printProgressBar(
                j + 1,
                totalImages,
                prefix="[Training] Epoch: {} ".format(i),
                length=15,
                suffix=" Mean Dice: {:.4f}, Dice1: {:.4f} ".format(
                    Dice_score.data[0], DiceB.data[0]))

        # The second loss reported here is the partial cross-entropy term.
        printProgressBar(
            totalImages,
            totalImages,
            done=
            f"[Training] Epoch: {i}, LossG: {np.mean(lossVal):.4f}, lossCE: {np.mean(lossVal1):.4f}"
        )

        Losses.append(np.mean(lossVal))
        d1, sizeGT, sizePred = inference(netG, temperature, val_loader,
                                         batch_size, i, deepSupervision,
                                         modelName, minVal, maxVal)
        dBAll.append(d1)

        directory = 'Results/Statistics/MIDL/' + modelName
        os.makedirs(directory, exist_ok=True)
        np.save(os.path.join(directory, modelName + '_Losses.npy'), Losses)
        np.save(os.path.join(directory, modelName + '_dBAll.npy'), dBAll)

        print(" [VAL] DSC: (1): {:.4f} ".format(d1))

        # Keep only the checkpoint with the best validation Dice seen so far.
        if d1 > BestDice:
            BestDice = d1
            os.makedirs(model_dir, exist_ok=True)
            torch.save(netG,
                       os.path.join(model_dir, "Best_" + modelName + ".pkl"))
def runTraining():
    """Train ENet on ACDC-2D with weak cross-entropy and a two-sided log-barrier size loss.

    Each training step minimises ``myCE_Loss_Weakly_numpy_OneHot_SoftMaxPyTorch``
    on the weak annotations plus ``mySize_Loss_LOG_BARRIER_TWO_BOUNDS``. The
    size loss uses the foreground-size bounds ``[minVal, maxVal]`` and the
    barrier parameter ``t``. ``t`` starts at 1.0 and is multiplied by ``mu``
    after every epoch.

    After each epoch:

    - the model is evaluated on the validation split with ``inference``;
    - constraint violations are summarised with ``analyzeViolationContraints``;
    - losses, Dice, violation statistics and predicted/target sizes are saved
      under ``Results/Statistics/MIDL/<modelName>/``.

    Whenever the validation Dice improves, the whole model is saved to
    ``model/Best_<modelName>.pkl`` and validation images are written with
    ``saveImages``. Returns ``None``.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)
    # Batch size for training MUST be 1 in weakly/semi supervised learning if we want to impose constraints.
    batch_size = 1
    batch_size_val = 1
    batch_size_val_savePng = 1
    lr = 0.0005
    epoch = 1000
    # Softmax temperature. It is defined before the training loop so that it
    # is always bound when passed to inference() below, even if every
    # training batch was skipped.
    temperature = 0.1
    # This variant has no deep-supervision heads.
    deepSupervision = False
    root_dir = './ACDC-2D-All'
    model_dir = 'model'

    transform = transforms.Compose([transforms.ToTensor()])
    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset(
        'train',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False)
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=False)

    val_set = medicalDataLoader.MedicalImageDataset(
        'val',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        equalize=False)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)
    val_loader_save_imagesPng = DataLoader(val_set,
                                           batch_size=batch_size_val_savePng,
                                           num_workers=5,
                                           shuffle=False)

    # The weak labels were created offline from the ground-truth foreground
    # masks. Each 256x256 mask is eroded with
    # ndimage.generate_binary_structure(2, 3) for 10 iterations; if nothing
    # survives, the erosion is retried with 7 and then 3 iterations. The
    # result is saved under 'WeaklyAnnotations'.

    print("~~~~~~~~~~~ Getting statistics ~~~~~~~~~~")
    # Size bounds (in pixels) of the constraint. They are hard-coded rather
    # than computed from the training set.
    minVal = 97.9
    maxVal = 1722.6

    # Log-barrier parameter: starts at t and is multiplied by mu every epoch.
    t = 1.0
    mu = 1.001

    # The size bounds are constant, so their tensors are built once. They
    # hold the int64-truncated bounds (97 and 1722).
    minSize = torch.FloatTensor(1)
    minSize.fill_(np.int64(minVal).item())
    maxSize = torch.FloatTensor(1)
    maxSize.fill_(np.int64(maxVal).item())

    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
    num_classes = 2

    # ENet
    netG = ENet(1, num_classes)
    netG.apply(weights_init)
    # Softmax over the class axis of the (N, C, H, W) output. predClass_y is
    # indexed as [:, 1, :, :] below. For 4-D input, dim=1 is also what the
    # deprecated implicit-dim nn.Softmax() selected.
    softMax = nn.Softmax(dim=1)
    Dice_loss = computeDiceOneHotBinary()

    modelName = 'WeaklySupervised_LogBarrier_NIPS3'
    print(' ModelName: {}'.format(modelName))

    CE_loss_weakly_numpy_OneHot = myCE_Loss_Weakly_numpy_OneHot_SoftMaxPyTorch(
    )
    sizeLoss_LOG_BARRIER_Twobounds = mySize_Loss_LOG_BARRIER_TWO_BOUNDS()

    if torch.cuda.is_available():
        netG.cuda()
        softMax.cuda()
        Dice_loss.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
    # NOTE(review): this scheduler is never stepped, so it never changes the
    # learning rate. Call scheduler.step(d1) after validation to enable it.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizerG,
                                                           mode='max',
                                                           patience=4,
                                                           verbose=True,
                                                           factor=10**-0.5)

    BestDice = 0

    dBAll = []
    Losses = []
    violConstraintsNeg = []
    violConstraintsPos = []
    violConstraintsTotal = []
    violConstraintsNeg_Distance = []
    violConstraintsPos_Distance = []

    # Per-epoch lists of per-batch predicted / target foreground sizes.
    predSizeArr_train = []
    predSizeArr_val = []
    targetSizeArr_train = []
    targetSizeArr_val = []

    print(" ~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    print(' --------- Params: ---------')
    print(' - Lower bound: {}'.format(minVal))
    print(' - Upper bound: {}'.format(maxVal))
    print(' - t (logBarrier): {}'.format(t))
    for i in range(epoch):
        netG.train()
        lossVal = []   # total loss per batch
        lossVal1 = []  # weak cross-entropy term per batch
        predSizeArrBatches_train = []
        targetSizeArrBatches_train = []

        # t is constant within an epoch. Its fractional part is kept on
        # purpose: truncating it to an integer would hold the barrier at
        # t = 1 until t * mu**k reaches 2 (about 693 epochs for mu = 1.001).
        t_logB = torch.FloatTensor(1)
        t_logB.fill_(float(t))

        totalImages = len(train_loader)
        for j, data in enumerate(train_loader):
            image, labels, weak_labels, img_names = data

            # prevent batchnorm error for batch of size 1
            if image.size(0) != batch_size:
                continue

            # The optimizer owns every parameter of netG, so this clears all
            # gradients of the network.
            optimizerG.zero_grad()
            MRI = to_var(image)
            Segmentation = to_var(labels)
            weakAnnotations = to_var(weak_labels)

            segmentation_prediction = netG(MRI)
            predClass_y = softMax(segmentation_prediction / temperature)

            # ----- Predicted (probability > 0.5) and target foreground sizes ------
            predSize = torch.sum(
                (predClass_y[:, 1, :, :] > 0.5).type(torch.FloatTensor))
            predSizeArrBatches_train.append(predSize.cpu().data.numpy())
            targetSize = torch.sum((labels == 1).type(torch.FloatTensor))
            targetSizeArrBatches_train.append(targetSize)
            # ---------------------------------------------- #

            Segmentation_planes = getOneHot_Encoded_Segmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # Both losses use the softmax probabilities, not the raw logits.
            lossCE_numpy = CE_loss_weakly_numpy_OneHot(predClass_y,
                                                       Segmentation_planes,
                                                       weakAnnotations)
            sizeLoss_val = sizeLoss_LOG_BARRIER_Twobounds(
                predClass_y, Segmentation_planes, Variable(minSize),
                Variable(maxSize), Variable(t_logB))

            # Dice is only monitored, never optimized: this Dice loss version
            # does not work as a training objective.
            DicesN, DicesB = Dice_loss(segmentation_prediction_ones,
                                       Segmentation_planes)
            DiceN = DicesToDice(DicesN)
            DiceB = DicesToDice(DicesB)
            Dice_score = (DiceB + DiceN) / 2

            lossG = lossCE_numpy + sizeLoss_val

            lossG.backward(retain_graph=True)
            optimizerG.step()

            lossVal.append(lossG.data[0])
            lossVal1.append(lossCE_numpy.data[0])

            printProgressBar(
                j + 1,
                totalImages,
                prefix="[Training] Epoch: {} ".format(i),
                length=15,
                suffix=" Mean Dice: {:.4f}, Dice1: {:.4f} ".format(
                    Dice_score.data[0], DiceB.data[0]))

        predSizeArr_train.append(predSizeArrBatches_train)
        targetSizeArr_train.append(targetSizeArrBatches_train)

        # The second loss reported here is the weak cross-entropy term.
        printProgressBar(
            totalImages,
            totalImages,
            done="[Training] Epoch: {}, LossG: {:.4f}, lossCE: {:.4f}".
            format(i, np.mean(lossVal), np.mean(lossVal1)))

        # Only the cross-entropy term is recorded here, not the total lossG.
        Losses.append(np.mean(lossVal1))

        d1, targetSizeArrBatches_val, predSizeArrBatches_val = inference(
            netG, temperature, val_loader, batch_size, i, deepSupervision,
            modelName, minVal, maxVal)

        predSizeArr_val.append(predSizeArrBatches_val)
        targetSizeArr_val.append(targetSizeArrBatches_val)
        dBAll.append(d1)

        (violPercNeg, violPercPos, violDistanceNeg, violDistanceNeg_min,
         violDistanceNeg_max, violDistancePos, violDistancePos_min,
         violDistancePos_max) = analyzeViolationContraints(
             targetSizeArrBatches_val, predSizeArrBatches_val, minVal, maxVal)

        # Each list receives the statistic of its own sign. This matches the
        # distance lists and the NEGATIVE/POSITIVE log lines below.
        violConstraintsNeg.append(violPercNeg)
        violConstraintsPos.append(violPercPos)
        violConstraintsTotal.append(violPercPos + violPercNeg)
        violConstraintsNeg_Distance.append(violDistanceNeg)
        violConstraintsPos_Distance.append(violDistancePos)

        directory = 'Results/Statistics/MIDL/' + modelName
        os.makedirs(directory, exist_ok=True)

        statistics = {
            '_Losses.npy': Losses,
            '_dBAll.npy': dBAll,
            '_percViolated_Neg.npy': violConstraintsNeg,
            '_percViolated_Pos.npy': violConstraintsPos,
            '_percViolated_Total.npy': violConstraintsTotal,
            '_diffViolated_Neg.npy': violConstraintsNeg_Distance,
            '_diffViolated_Pos.npy': violConstraintsPos_Distance,
            '_predSizes_train.npy': predSizeArr_train,
            '_predSizes_val.npy': predSizeArr_val,
            '_targetSizes_train.npy': targetSizeArr_train,
            '_targetSizes_val.npy': targetSizeArr_val,
        }
        for suffix, values in statistics.items():
            np.save(os.path.join(directory, modelName + suffix), values)

        # Log-barrier schedule: tighten the barrier for the next epoch.
        t = t * mu
        print(' t: {}'.format(t))

        print(" [VAL] DSC: (1): {:.4f} ".format(d1))
        print(
            ' [VAL] NEGATIVE: Constrained violated in {:.4f} % of images ( Mean diff = {})'
            .format(violPercNeg, violDistanceNeg))
        print(
            ' [VAL] POSITIVE: Constrained violated in {:.4f} % of images ( Mean diff = {})'
            .format(violPercPos, violDistancePos))
        print(
            ' [VAL] TOTAL: Constrained violated in {:.4f} % of images '.format(
                violPercNeg + violPercPos))

        # Keep only the checkpoint (and images) with the best validation Dice.
        if d1 > BestDice:
            BestDice = d1
            os.makedirs(model_dir, exist_ok=True)
            torch.save(netG,
                       os.path.join(model_dir, "Best_" + modelName + ".pkl"))
            saveImages(netG, val_loader_save_imagesPng, batch_size_val_savePng,
                       i, modelName, deepSupervision)
Example #11
0
def runInference(argv):
    """Restore a trained ENet and save its segmentations of the validation set.

    The network is loaded from ``./model/Best_<argv[0]>.pkl`` and its
    predictions on the 'val' split of ``./ACDC-2D-All`` are written out by
    ``saveImagesSegmentation`` (labelled with the model name, epoch 0).

    Args:
        argv: command-line style argument list; ``argv[0]`` is the model name,
            used both to locate the checkpoint and to label the saved images.

    If the checkpoint cannot be loaded, a message is printed and the freshly
    initialised (untrained) network is used instead, as before.
    """
    # The previous body halted in pdb.set_trace() right after saving the images
    # and then fell into a copied training loop that referenced names never
    # defined here (train_loader, minVal, maxVal, ...), so it could only end in
    # the debugger or a NameError. Only the inference part is kept.
    print('-' * 40)
    print('~~~~~~~~  Starting the inference... ~~~~~~')
    print('-' * 40)

    batch_size_val_savePng = 1
    root_dir = './ACDC-2D-All'
    modelName = argv[0]
    deepSupervision = False

    transform = transforms.Compose([transforms.ToTensor()])
    mask_transform = transforms.Compose([transforms.ToTensor()])

    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)

    val_loader_save_imagesPng = DataLoader(val_set,
                                           batch_size=batch_size_val_savePng,
                                           num_workers=5,
                                           shuffle=False)

    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
    num_classes = 2
    netG = ENet(1, num_classes)
    netG.apply(weights_init)

    try:
        netG = torch.load('./model/Best_' + modelName + '.pkl')
        print("--------model restored--------")
    except Exception as err:
        # Best-effort restore: keep going with the initialised network, but
        # report why the checkpoint was rejected. (A bare `except:` would also
        # have swallowed KeyboardInterrupt/SystemExit.)
        print("--------model not restored--------")
        print(err)

    # Move the network that will actually be used (possibly the restored one)
    # to the GPU; previously only the discarded initial network was moved.
    if torch.cuda.is_available():
        netG.cuda()

    saveImagesSegmentation(netG, val_loader_save_imagesPng, batch_size_val_savePng, 0, modelName, deepSupervision)
Example #12
0
def runTraining():
    """Train IVD_Net_asym with cross-entropy and track Dice statistics.

    Each epoch:
      * trains on the 'train' split of ``../Data/Training_PngITK``; the four
        images returned per sample (f, i, o, w) are concatenated along dim 1
        and labels are binarised (any positive value -> 1.0);
      * evaluates on the 'val' split via ``inference``;
      * saves mean losses, validation Dice and mean training Dice as .npy files
        under ``../Results/Statistics/<modelName>/``;
      * checkpoints the network to ``IVD_Net/Best_IVD_Net.pkl`` (and saves
        validation images) whenever validation Dice improves and exceeds 0.75.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = 4
    batch_size_val = 1
    batch_size_val_save = 1

    lr = 0.0001
    epoch = 200
    num_classes = 2
    initial_kernels = 32

    modelName = 'IVD_Net'

    print('.' * 40)
    print(" ....Model name: {} ........".format(modelName))

    print(' - Num. classes: {}'.format(num_classes))
    print(' - Num. initial kernels: {}'.format(initial_kernels))
    print(' - Batch size: {}'.format(batch_size))
    print(' - Learning rate: {}'.format(lr))
    print(' - Num. epochs: {}'.format(epoch))

    print('.' * 40)
    root_dir = '../Data/Training_PngITK'
    model_dir = 'IVD_Net'

    transform = transforms.Compose([transforms.ToTensor()])

    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset(
        'train',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset(
        'val',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=5,
                                        shuffle=False)
    # Initialize
    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")

    net = IVD_Net_asym(1, num_classes, initial_kernels)

    # Initialize the weights
    net.apply(weights_init)

    # Explicit class axis. For 4-D input, the implicit-dim nn.Softmax() also
    # picked dim=1; this just removes the deprecation warning. Assumes the
    # network returns (N, C, H, W) logits -- TODO confirm.
    softMax = nn.Softmax(dim=1)
    CE_loss = nn.CrossEntropyLoss()
    Dice_ = computeDiceOneHotBinary()

    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_.cuda()

    # To load a pre-trained model instead of training from scratch:
    #     net = torch.load(<path to .pkl>)

    optimizer = Adam(net.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)

    # NOTE(review): created but never stepped in this function.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           patience=4,
                                                           verbose=True,
                                                           factor=10**-0.5)

    BestDice, BestEpoch = 0, 0

    d1Train = []
    d1Val = []
    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        net.train()
        lossTrain = []
        d1TrainTemp = []  # per-batch training Dice (foreground) of this epoch

        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):

            image_f, image_i, image_o, image_w, labels, img_names = data

            # Be sure your data here is between [0,1]
            image_f = image_f.type(torch.FloatTensor)
            image_i = image_i.type(torch.FloatTensor)
            image_o = image_o.type(torch.FloatTensor)
            image_w = image_w.type(torch.FloatTensor)

            # Binarise the ground truth: every positive label becomes 1.0.
            labels = labels.numpy()
            labels[labels > 0.0] = 1.0
            labels = torch.from_numpy(labels).type(torch.FloatTensor)

            optimizer.zero_grad()
            MRI = to_var(torch.cat((image_f, image_i, image_o, image_w),
                                   dim=1))

            Segmentation = to_var(labels)

            net.zero_grad()

            segmentation_prediction = net(MRI)
            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # CrossEntropyLoss needs the raw logits, not the softmax output.
            Segmentation_class = getTargetSegmentation(Segmentation)
            loss = CE_loss(segmentation_prediction, Segmentation_class)

            # Dice is only monitored (so far on a 2D basis); it is not part of
            # the loss. The first (background) term is not tracked.
            _, DicesF = Dice_(segmentation_prediction_ones,
                              Segmentation_planes)
            DiceF = DicesToDice(DicesF)

            loss.backward()
            optimizer.step()

            # .item() instead of .data[0]: indexing a 0-dim tensor is an error
            # in PyTorch >= 0.5 (amsgrad= already requires PyTorch >= 0.4).
            lossTrain.append(loss.item())
            d1TrainTemp.append(DiceF.item())

            printProgressBar(j + 1,
                             totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f},".format(
                                 d1TrainTemp[-1]))

        printProgressBar(totalImages,
                         totalImages,
                         done="[Training] Epoch: {}, LossG: {:.4f}".format(
                             i, np.mean(lossTrain)))
        # Save statistics
        Losses.append(np.mean(lossTrain))
        d1 = inference(net, val_loader, batch_size, i)
        d1Val.append(d1)
        # Previously d1TrainTemp was never filled, and
        # `np.mean([]).data[0]` raised TypeError at the end of the first epoch.
        d1Train.append(np.mean(d1TrainTemp))

        mainPath = '../Results/Statistics/' + modelName

        directory = mainPath
        if not os.path.exists(directory):
            os.makedirs(directory)

        np.save(os.path.join(directory, 'Losses.npy'), Losses)
        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd1Train.npy'), d1Train)

        # float() works for single-element CPU and CUDA tensors alike, unlike
        # .numpy(), and also makes the "{:.4f}" format below well-defined.
        currentDice = float(d1[0])

        print("[val] DSC: {:.4f} ".format(currentDice))

        if currentDice > BestDice:
            BestDice = currentDice

            BestEpoch = i
            if currentDice > 0.75:
                print(
                    "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
                )
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                torch.save(
                    net, os.path.join(model_dir, "Best_" + modelName + ".pkl"))
                saveImages(net, val_loader_save_images, batch_size_val_save, i,
                           modelName)

        # Two ways of decay the learning rate:
        # NOTE(review): `lr` is never changed, so this only re-applies the
        # initial learning rate; no decay actually happens.
        if i % (BestEpoch + 10):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr