Beispiel #1
0
def runTraining(args):
    """Train the DAF stacked segmentation network and track validation Dice.

    Builds train/val loaders for the hard-coded ``root_dir``, trains a
    ``DAF_stack`` model with a combined loss, and evaluates after every epoch.
    The combined loss has three parts:

    * cross-entropy on each of the 8 segmentation heads;
    * 0.25 x the MSE between each of the 4 semantic-vector pairs;
    * 0.1 x the MSE between each of the 8 encoder inputs and their outputs.

    Per-epoch evaluation has two parts:

    * 2D Dice via ``inference``;
    * 3D Dice via ``saveImages_for3D`` / ``reconstruct3D`` / ``evaluate3D``.

    Statistics are written as ``.npy`` files under
    ``Results/Statistics/<modelName>``. When the mean 2D validation Dice
    improves and exceeds 0.40, the weights are saved to
    ``model/Best_<modelName>.pth``.

    Args:
        args: namespace providing ``batch_size``, ``lr``, ``epochs`` and
            ``modelName``.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = args.batch_size
    batch_size_val = 1
    batch_size_val_save = 1
    lr = args.lr

    epoch = args.epochs
    root_dir = './DataSet_Challenge/Val_1'
    model_dir = 'model'
    modelName = args.modelName
    directory = 'Results/Statistics/' + modelName

    print(' Dataset: {} '.format(root_dir))

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    train_set = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=True,
                                                      equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=4,
                                        shuffle=False)

    # Initialize
    print("~~~~~~~~~~~ Creating the DAF Stacked model ~~~~~~~~~~")
    net = DAF_stack()
    print(" Model Name: {}".format(args.modelName))
    print(" Model ot create: DAF_Stacked")

    net.apply(weights_init)

    # Class scores are on dim 1: CrossEntropyLoss below consumes the same
    # logits with that layout. Make the axis explicit instead of relying on
    # the deprecated implicit-dim selection.
    softMax = nn.Softmax(dim=1)
    CE_loss = nn.CrossEntropyLoss()
    Dice_loss = computeDiceOneHot()
    mseLoss = nn.MSELoss()

    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_loss.cuda()

    optimizer = Adam(net.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)

    BestDice, BestEpoch = 0, 0
    # Per-class 2D validation Dice of the best epoch (reported in the summary).
    BestDices2D = (0, 0, 0, 0)
    BestDice3D = [0, 0, 0, 0]

    d1Val = []
    d2Val = []
    d3Val = []
    d4Val = []

    d1Val_3D = []
    d2Val_3D = []
    d3Val_3D = []
    d4Val_3D = []

    d1Val_3D_std = []
    d2Val_3D_std = []
    d3Val_3D_std = []
    d4Val_3D_std = []

    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        net.train()
        lossVal = []

        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image, labels, img_names = data

            # prevent batchnorm error for batch of size 1
            if image.size(0) != batch_size:
                continue

            optimizer.zero_grad()
            MRI = to_var(image)
            Segmentation = to_var(labels)

            ################### Train ###################
            net.zero_grad()

            # Network outputs, in the order the model returns them.
            (semVector_1_1, semVector_2_1,
             semVector_1_2, semVector_2_2,
             semVector_1_3, semVector_2_3,
             semVector_1_4, semVector_2_4,
             inp_enc0, inp_enc1, inp_enc2, inp_enc3,
             inp_enc4, inp_enc5, inp_enc6, inp_enc7,
             out_enc0, out_enc1, out_enc2, out_enc3,
             out_enc4, out_enc5, out_enc6, out_enc7,
             outputs0, outputs1, outputs2, outputs3,
             outputs0_2, outputs1_2, outputs2_2, outputs3_2) = net(MRI)

            seg_outputs = (outputs0, outputs1, outputs2, outputs3,
                           outputs0_2, outputs1_2, outputs2_2, outputs3_2)
            sem_pairs = ((semVector_1_1, semVector_2_1),
                         (semVector_1_2, semVector_2_2),
                         (semVector_1_3, semVector_2_3),
                         (semVector_1_4, semVector_2_4))
            rec_pairs = ((inp_enc0, out_enc0), (inp_enc1, out_enc1),
                         (inp_enc2, out_enc2), (inp_enc3, out_enc3),
                         (inp_enc4, out_enc4), (inp_enc5, out_enc5),
                         (inp_enc6, out_enc6), (inp_enc7, out_enc7))

            # Final prediction: the average of the eight segmentation heads.
            segmentation_prediction = sum(seg_outputs) / 8
            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)

            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            # Cross-entropy on every head, MSE between each semantic-vector
            # pair, and MSE between each encoder input and its output.
            lossCE = sum(CE_loss(out, Segmentation_class) for out in seg_outputs)
            lossSemantic = sum(mseLoss(a, b) for a, b in sem_pairs)
            lossRec = sum(mseLoss(a, b) for a, b in rec_pairs)

            lossG = lossCE + 0.25 * lossSemantic + 0.1 * lossRec

            # Compute the DSC
            DicesN, DicesB, DicesW, DicesT, DicesZ = Dice_loss(segmentation_prediction_ones, Segmentation_planes)

            DiceB = DicesToDice(DicesB)
            DiceW = DicesToDice(DicesW)
            DiceT = DicesToDice(DicesT)
            DiceZ = DicesToDice(DicesZ)

            Dice_score = (DiceB + DiceW + DiceT + DiceZ) / 4

            lossG.backward()
            optimizer.step()

            lossVal.append(lossG.cpu().data.numpy())

            printProgressBar(j + 1, totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f}, Dice1: {:.4f} , Dice2: {:.4f}, , Dice3: {:.4f}, Dice4: {:.4f} ".format(
                                 Dice_score.cpu().data.numpy(),
                                 DiceB.data.cpu().data.numpy(),
                                 DiceW.data.cpu().data.numpy(),
                                 DiceT.data.cpu().data.numpy(),
                                 DiceZ.data.cpu().data.numpy(),))

        printProgressBar(totalImages, totalImages,
                         done="[Training] Epoch: {}, LossG: {:.4f}".format(i, np.mean(lossVal)))

        # Save statistics
        Losses.append(np.mean(lossVal))

        d1, d2, d3, d4 = inference(net, val_loader)

        d1Val.append(d1)
        d2Val.append(d2)
        d3Val.append(d3)
        d4Val.append(d4)

        if not os.path.exists(directory):
            os.makedirs(directory)

        np.save(os.path.join(directory, 'Losses.npy'), Losses)
        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd2Val.npy'), d2Val)
        np.save(os.path.join(directory, 'd3Val.npy'), d3Val)
        np.save(os.path.join(directory, 'd4Val.npy'), d4Val)

        currentDice = (d1 + d2 + d3 + d4) / 4

        print("[val] DSC: (1): {:.4f} (2): {:.4f}  (3): {:.4f} (4): {:.4f}".format(d1, d2, d3, d4))  # MRI

        currentDice = currentDice.data.numpy()

        # Evaluate on 3D
        saveImages_for3D(net, val_loader_save_images, batch_size_val_save, 1000, modelName, False, False)
        reconstruct3D(modelName, 1000, isBest=False)
        DSC_3D = evaluate3D(modelName)

        mean_DSC3D = np.mean(DSC_3D, 0)
        std_DSC3D = np.std(DSC_3D, 0)

        d1Val_3D.append(mean_DSC3D[0])
        d2Val_3D.append(mean_DSC3D[1])
        d3Val_3D.append(mean_DSC3D[2])
        d4Val_3D.append(mean_DSC3D[3])
        d1Val_3D_std.append(std_DSC3D[0])
        d2Val_3D_std.append(std_DSC3D[1])
        d3Val_3D_std.append(std_DSC3D[2])
        d4Val_3D_std.append(std_DSC3D[3])

        np.save(os.path.join(directory, 'd0Val_3D.npy'), d1Val_3D)
        np.save(os.path.join(directory, 'd1Val_3D.npy'), d2Val_3D)
        np.save(os.path.join(directory, 'd2Val_3D.npy'), d3Val_3D)
        np.save(os.path.join(directory, 'd3Val_3D.npy'), d4Val_3D)

        np.save(os.path.join(directory, 'd0Val_3D_std.npy'), d1Val_3D_std)
        np.save(os.path.join(directory, 'd1Val_3D_std.npy'), d2Val_3D_std)
        np.save(os.path.join(directory, 'd2Val_3D_std.npy'), d3Val_3D_std)
        np.save(os.path.join(directory, 'd3Val_3D_std.npy'), d4Val_3D_std)

        if currentDice > BestDice:
            BestDice = currentDice
            BestEpoch = i
            BestDices2D = (d1, d2, d3, d4)

            if currentDice > 0.40:

                if np.mean(mean_DSC3D) > np.mean(BestDice3D):
                    BestDice3D = mean_DSC3D

                print("###    In 3D -----> MEAN: {}, Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f}   ###".format(np.mean(mean_DSC3D), mean_DSC3D[0], mean_DSC3D[1], mean_DSC3D[2], mean_DSC3D[3]))

                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                torch.save(net.state_dict(), os.path.join(model_dir, "Best_" + modelName + ".pth"), pickle_module=dill)
                reconstruct3D(modelName, 1000, isBest=True)

        print("###                                                       ###")
        print("###    Best Dice: {:.4f} at epoch {} with Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f}   ###".format(BestDice, BestEpoch, *BestDices2D))
        print("###    Best Dice in 3D: {:.4f} with Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f} ###".format(np.mean(BestDice3D), BestDice3D[0], BestDice3D[1], BestDice3D[2], BestDice3D[3]))
        print("###                                                       ###")

        # Periodically halve the learning rate. Skip epoch 0: BestEpoch starts
        # at 0, so `0 % 50 == 0` would otherwise halve it immediately.
        if i > 0 and i % (BestEpoch + 50) == 0:
            lr = lr * 0.5
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print(' ----------  New learning Rate: {}'.format(lr))
Beispiel #2
0
def _build_model_2_design():
    """Return a fresh layer list for Model 2 (used for the initial build and restarts)."""
    return [Linear(2, 25), ReLU(), Linear(25, 25), Dropout(p=0.5), SeLU(),
            Linear(25, 25), Dropout(p=0.5), ReLU(), Linear(25, 2),
            Sigmoid()]


def _is_nan(value):
    """Return True when *value* is neither >= 0 nor < 0, i.e. it is NaN."""
    return (not value >= 0) and (not value < 0)


def _fit_one_epoch(model, optimizer, criterion, features, labels, batch_size):
    """Run one training epoch in mini-batches.

    Gradients are cleared once per batch and accumulated over every sample
    in it before a single ``optimizer.step()``.

    Returns:
        (loss_sum, num_correct, last_loss_grad). ``last_loss_grad`` is the
        loss gradient of the final sample, or None if *features* is empty.
    """
    loss_sum = 0.
    num_correct = 0
    last_loss_grad = None
    num_samples = len(features)
    for start in range(0, num_samples, batch_size):
        # The last batch may be shorter when batch_size does not divide the data.
        size = min(batch_size, num_samples - start)
        batch_features = features.narrow(0, start, size)
        batch_labels = labels.narrow(0, start, size)

        # Clean parameter gradients once per batch so they accumulate over it.
        optimizer.zero_grad()
        for k in range(size):
            feature = batch_features[k]
            label = batch_labels[k]

            # forward pass to compute loss
            pred = model.forward(feature)
            loss_sum += criterion.forward(pred, label)

            _, pred_cat = torch.max(pred, 0)
            _, label_cat = torch.max(label, 0)
            if pred_cat == label_cat:
                num_correct += 1

            # back-propagate the loss gradient, accumulating parameter gradients
            last_loss_grad = criterion.backward(pred, label)
            model.backward(last_loss_grad)

        # update parameters by optimizer
        optimizer.step()
    return loss_sum, num_correct, last_loss_grad


def _evaluate_model(model, criterion, features, labels):
    """Forward-only pass over a dataset; return (loss_sum, num_correct)."""
    loss_sum = 0.
    num_correct = 0
    for idx in range(len(features)):
        feature = features[idx]
        label = labels[idx]

        pred = model.forward(feature)
        loss_sum += criterion.forward(pred, label)

        _, pred_cat = torch.max(pred, 0)
        _, label_cat = torch.max(label, 0)
        if pred_cat == label_cat:
            num_correct += 1
    return loss_sum, num_correct


def _print_epoch_stats(epoch, num_epochs, train_loss_mean, train_accuracy,
                       test_loss_mean, test_accuracy):
    """Print one line of per-epoch training/test metrics (epoch is 0-based)."""
    print("Epoch: {}/{}..".format(epoch + 1, num_epochs),
          "Training Loss: {:.4f}..".format(train_loss_mean),
          "Training Accuracy: {:.4f}..".format(train_accuracy),
          "Validation/Test Loss: {:.4f}..".format(test_loss_mean),
          "Validation/Test Accuracy: {:.4f}..".format(test_accuracy))


def _plot_test_classification(model, features, targets, title):
    """Scatter *features*, coloured by target class when correct, else by 2."""
    pred_labels = []
    for idx in range(len(features)):
        pred = model.forward(features[idx])
        _, pred_cat = torch.max(pred, 0)
        if targets[idx].int() == pred_cat.int():
            pred_labels.append(int(targets[idx]))
        else:
            # 2 marks a misclassified point with its own colour.
            pred_labels.append(2)

    fig, axes = plt.subplots(1, 1, figsize=(6, 6))
    axes.scatter(features[:, 0], features[:, 1], c=pred_labels)
    axes.set_title(title)
    plt.show()


def main():
    """Train, evaluate and plot two small classifiers built from the custom framework.

    Model 1 uses MSE loss with SGD for 50 epochs. Model 2 uses cross-entropy
    with Adam for 25 epochs. Model 2 is rebuilt from scratch (same
    architecture) and training restarts whenever the final loss gradient of
    an epoch is NaN.
    """
    # generate data and translate labels
    train_features, train_targets = generate_all_datapoints_and_labels()
    test_features, test_targets = generate_all_datapoints_and_labels()
    train_labels, test_labels = convert_labels(train_targets), convert_labels(test_targets)
    num_train = len(train_features)
    num_test = len(test_features)

    stars = '*************************************************************************'
    dashes = '--------------------------------------------------------------------------------'
    batch_size = 1

    for _ in range(5):
        print(stars)
    print('Model: Linear + ReLU + Linear +ReLU + Linear + ReLU + Linear + Tanh')
    print('Loss: MSE')
    print('Optimizer: SGD')
    print(stars)
    print('Training')
    print(stars)
    # build network, loss and optimizer for Model 1
    my_model_1 = Sequential([Linear(2, 25), ReLU(), Linear(25, 25), Dropout(p=0.5), ReLU(),
                             Linear(25, 25), ReLU(), Linear(25, 2), Tanh()])
    optimizer_1 = SGD(my_model_1, lr=1e-3)
    criterion_1 = LossMSE()

    # train Model 1
    num_epochs_1 = 50
    for epoch in range(num_epochs_1):
        train_loss_sum, num_train_correct, _ = _fit_one_epoch(
            my_model_1, optimizer_1, criterion_1, train_features, train_labels, batch_size)
        # evaluate the current model on testing set (forward pass only)
        test_loss_sum, num_test_correct = _evaluate_model(
            my_model_1, criterion_1, test_features, test_labels)
        _print_epoch_stats(epoch, num_epochs_1,
                           train_loss_sum / num_train, num_train_correct / num_train,
                           test_loss_sum / num_test, num_test_correct / num_test)

    # visualize the classification performance of Model 1 on testing set
    _plot_test_classification(my_model_1, test_features, test_targets,
                              'Classification Performance of Model 1')

    for _ in range(5):
        print(stars)
    print('Model: Linear + ReLU + Linear + Dropout+ SeLU + Linear + Dropout + ReLU + Linear + Sigmoid')
    print('Loss: Cross Entropy')
    print('Optimizer: Adam')
    print(stars)
    print('Training')
    print(stars)

    # build network, loss function and optimizer for Model 2
    my_model_2 = Sequential(_build_model_2_design())
    optimizer_2 = Adam(my_model_2, lr=1e-3)
    criterion_2 = CrossEntropy()

    # train Model 2
    num_epochs_2 = 25
    epoch = 0
    while epoch < num_epochs_2:
        train_loss_sum, num_train_correct, last_loss_grad = _fit_one_epoch(
            my_model_2, optimizer_2, criterion_2, train_features, train_labels, batch_size)
        test_loss_sum, num_test_correct = _evaluate_model(
            my_model_2, criterion_2, test_features, test_labels)

        # On gradient explosion (NaN loss gradient), re-initialize the model
        # with the SAME architecture and restart training. This rarely happens.
        if last_loss_grad is not None and _is_nan(last_loss_grad[0]):
            epoch = 0
            my_model_2 = Sequential(_build_model_2_design())
            optimizer_2 = Adam(my_model_2, lr=1e-3)
            criterion_2 = CrossEntropy()
            for _ in range(5):
                print(dashes)
            print('Restart training because of gradient explosion')
            continue

        _print_epoch_stats(epoch, num_epochs_2,
                           train_loss_sum / num_train, num_train_correct / num_train,
                           test_loss_sum / num_test, num_test_correct / num_test)
        epoch += 1

    # visualize the classification performance of Model 2 on testing set
    _plot_test_classification(my_model_2, test_features, test_targets,
                              'Classification Performance of Model 2')
Beispiel #3
0
def runTraining():
    """Train ``UNetG_Dilated_Progressive`` and keep the best checkpoint.

    Trains on ``../DataSet/Bladder_Aug`` with cross-entropy. Optional
    deep-supervision or regression terms exist but are disabled by the flags
    below.

    After each epoch the function:

    * evaluates validation Dice via ``inference``;
    * writes loss and Dice statistics to ``../Results/Statistics/<modelName>``.

    When the mean validation Dice improves and exceeds 0.7, the whole network
    is saved to ``model/Best_<modelName>.pkl`` together with prediction
    images.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = 4
    batch_size_val = 1
    batch_size_val_save = 1
    lr = 0.0001
    epoch = 1000
    root_dir = '../DataSet/Bladder_Aug'
    modelName = 'UNetG_Dilated_Progressive'
    model_dir = 'model'
    directory = '../Results/Statistics/' + modelName

    # Training variants; with both False the network returns a single output.
    deepSupervision = False
    multiTask = False

    transform = transforms.Compose([transforms.ToTensor()])

    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset(
        'train',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset(
        'val',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=5,
                                        shuffle=False)

    # Initialize
    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
    num_classes = 4

    initial_kernels = 32

    # Load network
    netG = UNetG_Dilated_Progressive(1, initial_kernels, num_classes)
    # Class scores are on dim 1 (the layout CrossEntropyLoss consumes below).
    softMax = nn.Softmax(dim=1)
    CE_loss = nn.CrossEntropyLoss()
    Dice_loss = computeDiceOneHot()

    if torch.cuda.is_available():
        netG.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_loss.cuda()

    optimizerG = Adam(netG.parameters(),
                      lr=lr,
                      betas=(0.9, 0.99),
                      amsgrad=False)

    # BestDiceT is initialised so the summary print is valid even before the
    # first improvement.
    BestDice, BestEpoch, BestDiceT = 0, 0, 0

    d1Train = []
    d2Train = []
    d3Train = []
    d1Val = []
    d2Val = []
    d3Val = []

    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        netG.train()
        lossVal = []
        lossVal1 = []
        lossVal05 = []
        lossVal025 = []
        lossVal0125 = []

        # Per-batch training Dice, averaged into d*Train at the end of the epoch.
        d1TrainTemp = []
        d2TrainTemp = []
        d3TrainTemp = []

        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image, labels, img_names = data

            # prevent batchnorm error for batch of size 1
            if image.size(0) != batch_size:
                continue

            optimizerG.zero_grad()
            MRI = to_var(image)
            Segmentation = to_var(labels)

            ################### Train ###################
            netG.zero_grad()

            if deepSupervision:
                # Deep supervision: main output plus three auxiliary outputs.
                segmentation_prediction, seg_3, seg_2, seg_1 = netG(MRI)
            elif multiTask:
                segmentation_prediction, reg_output = netG(MRI)
                # Regression
                feats = getValuesRegression(labels)
                featsVar = to_var(torch.from_numpy(feats).float())
                # NOTE(review): MSE_loss is not created in this function; it
                # must be defined elsewhere before multiTask can be enabled.
                MSE_loss_val = MSE_loss(reg_output, featsVar)
            else:
                # No deep supervision
                segmentation_prediction = netG(MRI)

            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            CE_lossG = CE_loss(segmentation_prediction, Segmentation_class)
            if deepSupervision:
                # Labels resized by factors 2, 4 and 8, presumably to match the
                # auxiliary outputs' resolutions.
                imageLabels_05 = resizeTensorMaskInSingleImage(
                    Segmentation_class, 2)
                imageLabels_025 = resizeTensorMaskInSingleImage(
                    Segmentation_class, 4)
                imageLabels_0125 = resizeTensorMaskInSingleImage(
                    Segmentation_class, 8)

                CE_lossG_3 = CE_loss(seg_3, imageLabels_05)
                CE_lossG_2 = CE_loss(seg_2, imageLabels_025)
                CE_lossG_1 = CE_loss(seg_1, imageLabels_0125)

            # Dice
            DicesN, DicesB, DicesW, DicesT = Dice_loss(
                segmentation_prediction_ones, Segmentation_planes)
            DiceB = DicesToDice(DicesB)
            DiceW = DicesToDice(DicesW)
            DiceT = DicesToDice(DicesT)

            Dice_score = (DiceB + DiceW + DiceT) / 3

            if deepSupervision:
                lossG = CE_lossG + 0.25 * CE_lossG_3 + 0.1 * CE_lossG_2 + 0.1 * CE_lossG_1
            elif multiTask:
                lossG = CE_lossG + 0.000001 * MSE_loss_val
            else:
                lossG = CE_lossG

            lossG.backward()
            optimizerG.step()

            lossVal.append(lossG.data[0])
            lossVal1.append(CE_lossG.data[0])

            # Same ordering as the progress bar's Dice1/Dice2/Dice3 labels.
            d1TrainTemp.append(DiceB.data[0])
            d2TrainTemp.append(DiceW.data[0])
            d3TrainTemp.append(DiceT.data[0])

            if deepSupervision:
                lossVal05.append(CE_lossG_3.data[0])
                lossVal025.append(CE_lossG_2.data[0])
                lossVal0125.append(CE_lossG_1.data[0])

            printProgressBar(
                j + 1,
                totalImages,
                prefix="[Training] Epoch: {} ".format(i),
                length=15,
                suffix=
                " Mean Dice: {:.4f}, Dice1: {:.4f} , Dice2: {:.4f}, , Dice3: {:.4f} "
                .format(Dice_score.data[0], DiceB.data[0], DiceW.data[0],
                        DiceT.data[0]))

        if deepSupervision:
            printProgressBar(
                totalImages,
                totalImages,
                done=
                "[Training] Epoch: {}, LossG: {:.4f}, Loss4: {:.4f}, Loss3: {:.4f}, Loss2: {:.4f}, Loss1: {:.4f}"
                .format(i, np.mean(lossVal), np.mean(lossVal1),
                        np.mean(lossVal05), np.mean(lossVal025),
                        np.mean(lossVal0125)))
        else:
            # lossVal1 holds the cross-entropy term.
            printProgressBar(
                totalImages,
                totalImages,
                done="[Training] Epoch: {}, LossG: {:.4f}, lossCE: {:.4f}".
                format(i, np.mean(lossVal), np.mean(lossVal1)))

        Losses.append(np.mean(lossVal))

        d1, d2, d3 = inference(netG, val_loader, batch_size, i,
                               deepSupervision)

        d1Val.append(d1)
        d2Val.append(d2)
        d3Val.append(d3)

        d1Train.append(np.mean(d1TrainTemp))
        d2Train.append(np.mean(d2TrainTemp))
        d3Train.append(np.mean(d3TrainTemp))

        if not os.path.exists(directory):
            os.makedirs(directory)

        ###### Save statistics  ######
        np.save(os.path.join(directory, 'Losses.npy'), Losses)

        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd2Val.npy'), d2Val)
        np.save(os.path.join(directory, 'd3Val.npy'), d3Val)

        np.save(os.path.join(directory, 'd1Train.npy'), d1Train)
        np.save(os.path.join(directory, 'd2Train.npy'), d2Train)
        np.save(os.path.join(directory, 'd3Train.npy'), d3Train)

        currentDice = (d1 + d2 + d3) / 3

        # How many slices with/without tumor correctly classified
        print("[val] DSC: (1): {:.4f} (2): {:.4f}  (3): {:.4f} ".format(
            d1, d2, d3))

        if currentDice > BestDice:
            BestDice = currentDice
            BestDiceT = d1
            BestEpoch = i
            if currentDice > 0.7:
                print(
                    "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
                )
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                torch.save(
                    netG, os.path.join(model_dir,
                                       "Best_" + modelName + ".pkl"))

                # Save images
                saveImages(netG, val_loader_save_images, batch_size_val_save,
                           i, modelName, deepSupervision)
                saveImagesAsMatlab(netG, val_loader_save_images,
                                   batch_size_val_save, i, modelName,
                                   deepSupervision)

        print("###                                                       ###")
        print("###    Best Dice: {:.4f} at epoch {} with DiceT: {:.4f}    ###".
              format(BestDice, BestEpoch, BestDiceT))
        print("###                                                       ###")

        # This is not as we did it in the MedPhys paper.
        # Halve the learning rate every (BestEpoch + 20) epochs, skipping
        # epoch 0 (where BestEpoch is 0 and 0 % 20 == 0).
        if i > 0 and i % (BestEpoch + 20) == 0:
            lr = lr / 2
            for param_group in optimizerG.param_groups:
                param_group['lr'] = lr
Beispiel #4
0
def runTraining():
    """Train an IVD_Net_asym segmentation model and checkpoint the best epoch.

    All settings (batch sizes, learning rate, number of epochs, number of
    classes, initial kernels, data and model directories) are hard-coded
    below. For every epoch this function:

    * trains on the 'train' split of ``root_dir`` with cross-entropy loss and
      Adam, recording the mean loss and the mean per-batch training Dice
      (``DiceF``);
    * evaluates on the 'val' split via ``inference``;
    * writes ``Losses.npy``, ``d1Val.npy`` and ``d1Train.npy`` under
      ``../Results/Statistics/<modelName>``;
    * when the validation Dice beats the best value so far and exceeds 0.75,
      saves the whole network to ``<model_dir>/Best_<modelName>.pkl`` and
      calls ``saveImages`` on the validation set.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = 4
    batch_size_val = 1
    batch_size_val_save = 1

    lr = 0.0001
    epoch = 200
    num_classes = 2
    initial_kernels = 32

    modelName = 'IVD_Net'

    print('.' * 40)
    print(" ....Model name: {} ........".format(modelName))

    print(' - Num. classes: {}'.format(num_classes))
    print(' - Num. initial kernels: {}'.format(initial_kernels))
    print(' - Batch size: {}'.format(batch_size))
    print(' - Learning rate: {}'.format(lr))
    print(' - Num. epochs: {}'.format(epoch))

    print('.' * 40)
    root_dir = '../Data/Training_PngITK'
    model_dir = 'IVD_Net'

    transform = transforms.Compose([transforms.ToTensor()])
    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset(
        'train',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False)
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset(
        'val',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        equalize=False)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)
    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=5,
                                        shuffle=False)

    # Initialize
    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")

    net = IVD_Net_asym(1, num_classes, initial_kernels)

    # Initialize the weights
    net.apply(weights_init)

    # nn.Softmax() without `dim` is deprecated and implicitly uses dim=1 for
    # inputs of rank 2 or >= 4. The network output is assumed to be a batched
    # (N, C, H, W) map (TODO confirm), so dim=1 gives the same result.
    softMax = nn.Softmax(dim=1)
    CE_loss = nn.CrossEntropyLoss()
    Dice_ = computeDiceOneHotBinary()

    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_.cuda()

    # To resume from a saved checkpoint, replace `net` here with
    # torch.load(os.path.join(model_dir, "Best_" + modelName + ".pkl"))
    # (the file written when a new best model is saved) before the optimizer
    # is created.

    optimizer = Adam(net.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)

    # NOTE(review): this scheduler is never stepped in the training loop
    # below, so it does not currently change the learning rate.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           patience=4,
                                                           verbose=True,
                                                           factor=10**-0.5)

    BestDice, BestEpoch = 0, 0

    d1Train = []
    d1Val = []
    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        net.train()
        lossTrain = []
        d1TrainTemp = []

        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image_f, image_i, image_o, image_w, labels, img_names = data

            # Be sure your data here is between [0,1]
            image_f = image_f.type(torch.FloatTensor)
            image_i = image_i.type(torch.FloatTensor)
            image_o = image_o.type(torch.FloatTensor)
            image_w = image_w.type(torch.FloatTensor)

            # Binarize the ground truth: every positive label value becomes 1.
            labels = labels.numpy()
            labels[labels > 0.0] = 1.0
            labels = torch.from_numpy(labels).type(torch.FloatTensor)

            # optimizer was built from net.parameters(), so this clears all of
            # the network's gradients.
            optimizer.zero_grad()
            MRI = to_var(torch.cat((image_f, image_i, image_o, image_w),
                                   dim=1))
            Segmentation = to_var(labels)

            segmentation_prediction = net(MRI)
            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # CrossEntropyLoss takes the raw logits, not the softmax output.
            Segmentation_class = getTargetSegmentation(Segmentation)
            loss = CE_loss(segmentation_prediction, Segmentation_class)

            # Compute the Dice (so far in a 2D-basis)
            DicesB, DicesF = Dice_(segmentation_prediction_ones,
                                   Segmentation_planes)
            DiceF = DicesToDice(DicesF)

            loss.backward()
            optimizer.step()

            # NOTE(review): `.data[0]` is the pre-0.4 PyTorch way to read a
            # scalar; on PyTorch >= 0.4 a 0-dim loss needs `loss.item()`.
            lossTrain.append(loss.data[0])
            # Keep a plain float per batch so the epoch mean below is a number.
            d1TrainTemp.append(float(DiceF.data[0]))

            printProgressBar(j + 1,
                             totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f},".format(
                                 d1TrainTemp[-1]))

        printProgressBar(totalImages,
                         totalImages,
                         done="[Training] Epoch: {}, LossG: {:.4f}".format(
                             i, np.mean(lossTrain)))

        # Save statistics
        Losses.append(np.mean(lossTrain))
        # val_loader yields batches of batch_size_val, so that is the batch
        # size handed to inference (the training batch_size was passed before).
        d1 = inference(net, val_loader, batch_size_val, i)
        d1Val.append(d1)
        # Mean training Dice over this epoch's batches. Previously the list
        # was never filled, and `.data[0]` on the numpy mean was invalid.
        d1Train.append(np.mean(d1TrainTemp))

        directory = '../Results/Statistics/' + modelName
        if not os.path.exists(directory):
            os.makedirs(directory)

        np.save(os.path.join(directory, 'Losses.npy'), Losses)
        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd1Train.npy'), d1Train)

        # d1[0] is the validation DSC. .item() turns the single-element array
        # into a plain float so it formats and compares directly.
        currentDice = d1[0].numpy().item()

        print("[val] DSC: {:.4f} ".format(currentDice))

        if currentDice > BestDice:
            BestDice = currentDice
            BestEpoch = i
            if currentDice > 0.75:
                print(
                    "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
                )
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                torch.save(
                    net, os.path.join(model_dir, "Best_" + modelName + ".pkl"))
                saveImages(net, val_loader_save_images, batch_size_val_save, i,
                           modelName)

        # NOTE(review): this does not decay the learning rate. It writes the
        # initial `lr` back into every parameter group on each epoch where
        # `i % (BestEpoch + 10)` is non-zero. It is kept as-is to preserve
        # the current training schedule.
        if i % (BestEpoch + 10):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr