def runTraining(args):
    print('-' * 40)
    print('~~~~~~~~ Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = args.batch_size
    batch_size_val = 1
    batch_size_val_save = 1
    lr = args.lr
    epoch = args.epochs
    root_dir = './DataSet_Challenge/Val_1'
    model_dir = 'model'

    print(' Dataset: {} '.format(root_dir))

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    train_set = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=True,
                                                      equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=4,
                                        shuffle=False)

    # Initialize
    print("~~~~~~~~~~~ Creating the DAF Stacked model ~~~~~~~~~~")
    net = DAF_stack()
    print(" Model Name: {}".format(args.modelName))
    print(" Model to create: DAF_Stacked")

    net.apply(weights_init)

    softMax = nn.Softmax(dim=1)  # softmax over the class dimension
    CE_loss = nn.CrossEntropyLoss()
    Dice_loss = computeDiceOneHot()
    mseLoss = nn.MSELoss()

    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_loss.cuda()

    optimizer = Adam(net.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)

    BestDice, BestEpoch = 0, 0
    BestDice3D = [0, 0, 0, 0]

    d1Val = []
    d2Val = []
    d3Val = []
    d4Val = []

    d1Val_3D = []
    d2Val_3D = []
    d3Val_3D = []
    d4Val_3D = []

    d1Val_3D_std = []
    d2Val_3D_std = []
    d3Val_3D_std = []
    d4Val_3D_std = []

    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        net.train()
        lossVal = []
        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image, labels, img_names = data

            # Prevent a batchnorm error on a batch of size 1
            if image.size(0) != batch_size:
                continue

            optimizer.zero_grad()
            MRI = to_var(image)
            Segmentation = to_var(labels)

            ################### Train ###################
            net.zero_grad()

            # Network outputs: paired semantic vectors, encoder features,
            # and the eight segmentation heads of the stacked DAF model
            (semVector_1_1, semVector_2_1,
             semVector_1_2, semVector_2_2,
             semVector_1_3, semVector_2_3,
             semVector_1_4, semVector_2_4,
             inp_enc0, inp_enc1, inp_enc2, inp_enc3,
             inp_enc4, inp_enc5, inp_enc6, inp_enc7,
             out_enc0, out_enc1, out_enc2, out_enc3,
             out_enc4, out_enc5, out_enc6, out_enc7,
             outputs0, outputs1, outputs2, outputs3,
             outputs0_2, outputs1_2, outputs2_2, outputs3_2) = net(MRI)

            # Average the eight segmentation heads
            segmentation_prediction = (outputs0 + outputs1 + outputs2 + outputs3 +
                                       outputs0_2 + outputs1_2 + outputs2_2 + outputs3_2) / 8

            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            # Cross-entropy loss, one term per segmentation head
            loss0 = CE_loss(outputs0, Segmentation_class)
            loss1 = CE_loss(outputs1, Segmentation_class)
            loss2 = CE_loss(outputs2, Segmentation_class)
            loss3 = CE_loss(outputs3, Segmentation_class)
            loss0_2 = CE_loss(outputs0_2, Segmentation_class)
            loss1_2 = CE_loss(outputs1_2, Segmentation_class)
            loss2_2 = CE_loss(outputs2_2, Segmentation_class)
            loss3_2 = CE_loss(outputs3_2, Segmentation_class)

            # Semantic-consistency (MSE) losses between paired semantic vectors
            lossSemantic1 = mseLoss(semVector_1_1, semVector_2_1)
            lossSemantic2 = mseLoss(semVector_1_2, semVector_2_2)
            lossSemantic3 = mseLoss(semVector_1_3, semVector_2_3)
            lossSemantic4 = mseLoss(semVector_1_4, semVector_2_4)

            # Encoder reconstruction (MSE) losses
            lossRec0 = mseLoss(inp_enc0, out_enc0)
            lossRec1 = mseLoss(inp_enc1, out_enc1)
            lossRec2 = mseLoss(inp_enc2, out_enc2)
            lossRec3 = mseLoss(inp_enc3, out_enc3)
            lossRec4 = mseLoss(inp_enc4, out_enc4)
            lossRec5 = mseLoss(inp_enc5, out_enc5)
            lossRec6 = mseLoss(inp_enc6, out_enc6)
            lossRec7 = mseLoss(inp_enc7, out_enc7)

            # Total loss: CE terms + 0.25 * semantic terms + 0.1 * reconstruction terms
            lossG = (loss0 + loss1 + loss2 + loss3 +
                     loss0_2 + loss1_2 + loss2_2 + loss3_2 +
                     0.25 * (lossSemantic1 + lossSemantic2 + lossSemantic3 + lossSemantic4) +
                     0.1 * (lossRec0 + lossRec1 + lossRec2 + lossRec3 +
                            lossRec4 + lossRec5 + lossRec6 + lossRec7))

            # Compute the DSC
            DicesN, DicesB, DicesW, DicesT, DicesZ = Dice_loss(segmentation_prediction_ones,
                                                               Segmentation_planes)

            DiceB = DicesToDice(DicesB)
            DiceW = DicesToDice(DicesW)
            DiceT = DicesToDice(DicesT)
            DiceZ = DicesToDice(DicesZ)

            Dice_score = (DiceB + DiceW + DiceT + DiceZ) / 4

            lossG.backward()
            optimizer.step()

            lossVal.append(lossG.cpu().data.numpy())

            printProgressBar(j + 1, totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f}, Dice1: {:.4f}, Dice2: {:.4f}, Dice3: {:.4f}, Dice4: {:.4f}".format(
                                 Dice_score.cpu().data.numpy(),
                                 DiceB.cpu().data.numpy(),
                                 DiceW.cpu().data.numpy(),
                                 DiceT.cpu().data.numpy(),
                                 DiceZ.cpu().data.numpy()))

        printProgressBar(totalImages, totalImages,
                         done="[Training] Epoch: {}, LossG: {:.4f}".format(i, np.mean(lossVal)))

        # Save statistics
        modelName = args.modelName
        directory = 'Results/Statistics/' + modelName

        Losses.append(np.mean(lossVal))

        d1, d2, d3, d4 = inference(net, val_loader)

        d1Val.append(d1)
        d2Val.append(d2)
        d3Val.append(d3)
        d4Val.append(d4)

        if not os.path.exists(directory):
            os.makedirs(directory)

        np.save(os.path.join(directory, 'Losses.npy'), Losses)
        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd2Val.npy'), d2Val)
        np.save(os.path.join(directory, 'd3Val.npy'), d3Val)
        np.save(os.path.join(directory, 'd4Val.npy'), d4Val)

        currentDice = (d1 + d2 + d3 + d4) / 4

        print("[val] DSC: (1): {:.4f} (2): {:.4f} (3): {:.4f} (4): {:.4f}".format(d1, d2, d3, d4))

        currentDice = currentDice.data.numpy()

        # Evaluate in 3D
        saveImages_for3D(net, val_loader_save_images, batch_size_val_save, 1000, modelName, False, False)
        reconstruct3D(modelName, 1000, isBest=False)
        DSC_3D = evaluate3D(modelName)

        mean_DSC3D = np.mean(DSC_3D, 0)
        std_DSC3D = np.std(DSC_3D, 0)

        d1Val_3D.append(mean_DSC3D[0])
        d2Val_3D.append(mean_DSC3D[1])
        d3Val_3D.append(mean_DSC3D[2])
        d4Val_3D.append(mean_DSC3D[3])

        d1Val_3D_std.append(std_DSC3D[0])
        d2Val_3D_std.append(std_DSC3D[1])
        d3Val_3D_std.append(std_DSC3D[2])
        d4Val_3D_std.append(std_DSC3D[3])

        np.save(os.path.join(directory, 'd0Val_3D.npy'), d1Val_3D)
        np.save(os.path.join(directory, 'd1Val_3D.npy'), d2Val_3D)
        np.save(os.path.join(directory, 'd2Val_3D.npy'), d3Val_3D)
        np.save(os.path.join(directory, 'd3Val_3D.npy'), d4Val_3D)

        np.save(os.path.join(directory, 'd0Val_3D_std.npy'), d1Val_3D_std)
        np.save(os.path.join(directory, 'd1Val_3D_std.npy'), d2Val_3D_std)
        np.save(os.path.join(directory, 'd2Val_3D_std.npy'), d3Val_3D_std)
        np.save(os.path.join(directory, 'd3Val_3D_std.npy'), d4Val_3D_std)

        if currentDice > BestDice:
            BestDice = currentDice
            BestEpoch = i

            if currentDice > 0.40:
                if np.mean(mean_DSC3D) > np.mean(BestDice3D):
                    BestDice3D = mean_DSC3D

                print("### In 3D -----> MEAN: {}, Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f} ###".format(
                    np.mean(mean_DSC3D), mean_DSC3D[0], mean_DSC3D[1], mean_DSC3D[2], mean_DSC3D[3]))
                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)

                torch.save(net.state_dict(),
                           os.path.join(model_dir, "Best_" + modelName + ".pth"),
                           pickle_module=dill)
                reconstruct3D(modelName, 1000, isBest=True)

        print("###                                                       ###")
        print("### Best Dice: {:.4f} at epoch {} with Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f} ###".format(
            BestDice, BestEpoch, d1, d2, d3, d4))
        print("### Best Dice in 3D: {:.4f} with Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f} ###".format(
            np.mean(BestDice3D), BestDice3D[0], BestDice3D[1], BestDice3D[2], BestDice3D[3]))
        print("###                                                       ###")

        # Halve the learning rate whenever i is a multiple of (BestEpoch + 50)
        if i % (BestEpoch + 50) == 0:
            for param_group in optimizer.param_groups:
                lr = lr * 0.5
                param_group['lr'] = lr
                print(' ---------- New learning Rate: {}'.format(lr))
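# A minimal sketch of a command-line front-end for the function above, placed in
# whichever script defines it. The flag names and defaults are assumptions; the
# function only reads args.batch_size, args.lr, args.epochs and args.modelName.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='DAF_stack training')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--modelName', type=str, default='DAF_Stacked')
    runTraining(parser.parse_args())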
def main():
    # Generate data and translate labels
    train_features, train_targets = generate_all_datapoints_and_labels()
    test_features, test_targets = generate_all_datapoints_and_labels()
    train_labels, test_labels = convert_labels(train_targets), convert_labels(test_targets)

    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('Model: Linear + ReLU + Linear + Dropout + ReLU + Linear + ReLU + Linear + Tanh')
    print('Loss: MSE')
    print('Optimizer: SGD')
    print('*************************************************************************')
    print('Training')
    print('*************************************************************************')

    # Build network, loss and optimizer for Model 1
    my_model_design_1 = [Linear(2, 25), ReLU(),
                         Linear(25, 25), Dropout(p=0.5), ReLU(),
                         Linear(25, 25), ReLU(),
                         Linear(25, 2), Tanh()]
    my_model_1 = Sequential(my_model_design_1)
    optimizer_1 = SGD(my_model_1, lr=1e-3)
    criterion_1 = LossMSE()

    # Train Model 1
    batch_size = 1
    for epoch in range(50):
        temp_train_loss_sum = 0.
        temp_test_loss_sum = 0.
        num_train_correct = 0
        num_test_correct = 0

        # Trained in batch fashion: here batch size = 1
        for temp_batch in range(0, len(train_features), batch_size):
            temp_train_features = train_features.narrow(0, temp_batch, batch_size)
            temp_train_labels = train_labels.narrow(0, temp_batch, batch_size)

            for i in range(batch_size):
                # Clear parameter gradients before each batch
                optimizer_1.zero_grad()
                temp_train_feature = temp_train_features[i]
                temp_train_label = temp_train_labels[i]

                # Forward pass to compute the loss
                temp_train_pred = my_model_1.forward(temp_train_feature)
                temp_train_loss = criterion_1.forward(temp_train_pred, temp_train_label)
                temp_train_loss_sum += temp_train_loss

                _, temp_train_pred_cat = torch.max(temp_train_pred, 0)
                _, temp_train_label_cat = torch.max(temp_train_label, 0)
                if temp_train_pred_cat == temp_train_label_cat:
                    num_train_correct += 1

                # Gradient of the loss w.r.t. the prediction
                temp_train_loss_grad = criterion_1.backward(temp_train_pred, temp_train_label)
                # Accumulate parameter gradients over the batch
                my_model_1.backward(temp_train_loss_grad)
                # Update parameters with the optimizer
                optimizer_1.step()

        # Evaluate the current model on the test set (only the forward pass is needed)
        for i_test in range(len(test_features)):
            temp_test_feature = test_features[i_test]
            temp_test_label = test_labels[i_test]

            temp_test_pred = my_model_1.forward(temp_test_feature)
            temp_test_loss = criterion_1.forward(temp_test_pred, temp_test_label)
            temp_test_loss_sum += temp_test_loss

            _, temp_test_pred_cat = torch.max(temp_test_pred, 0)
            _, temp_test_label_cat = torch.max(temp_test_label, 0)
            if temp_test_pred_cat == temp_test_label_cat:
                num_test_correct += 1

        temp_train_loss_mean = temp_train_loss_sum / len(train_features)
        temp_test_loss_mean = temp_test_loss_sum / len(test_features)
        temp_train_accuracy = num_train_correct / len(train_features)
        temp_test_accuracy = num_test_correct / len(test_features)

        print("Epoch: {}/{}..".format(epoch + 1, 50),
              "Training Loss: {:.4f}..".format(temp_train_loss_mean),
              "Training Accuracy: {:.4f}..".format(temp_train_accuracy),
              "Validation/Test Loss: {:.4f}..".format(temp_test_loss_mean),
              "Validation/Test Accuracy: {:.4f}..".format(temp_test_accuracy))

    # Visualize the classification performance of Model 1 on the testing set
    test_pred_labels_1 = []
    for i in range(1000):
        temp_test_feature = test_features[i]
        temp_test_label = test_labels[i]

        temp_test_pred = my_model_1.forward(temp_test_feature)
        _, temp_test_pred_cat = torch.max(temp_test_pred, 0)

        if test_targets[i].int() == temp_test_pred_cat.int():
            test_pred_labels_1.append(int(test_targets[i]))
        else:
            # Misclassified points get a third label so they stand out in the plot
            test_pred_labels_1.append(2)

    fig, axes = plt.subplots(1, 1, figsize=(6, 6))
    axes.scatter(test_features[:, 0], test_features[:, 1], c=test_pred_labels_1)
    axes.set_title('Classification Performance of Model 1')
    plt.show()

    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('Model: Linear + ReLU + Linear + Dropout + SeLU + Linear + Dropout + ReLU + Linear + Sigmoid')
    print('Loss: Cross Entropy')
    print('Optimizer: Adam')
    print('*************************************************************************')
    print('Training')
    print('*************************************************************************')

    # Build network, loss function and optimizer for Model 2
    my_model_design_2 = [Linear(2, 25), ReLU(),
                         Linear(25, 25), Dropout(p=0.5), SeLU(),
                         Linear(25, 25), Dropout(p=0.5), ReLU(),
                         Linear(25, 2), Sigmoid()]
    my_model_2 = Sequential(my_model_design_2)
    optimizer_2 = Adam(my_model_2, lr=1e-3)
    criterion_2 = CrossEntropy()

    # Train Model 2
    batch_size = 1
    epoch = 0
    while epoch < 25:
        temp_train_loss_sum = 0.
        temp_test_loss_sum = 0.
        num_train_correct = 0
        num_test_correct = 0

        # Trained in batch fashion: here batch size = 1
        for temp_batch in range(0, len(train_features), batch_size):
            temp_train_features = train_features.narrow(0, temp_batch, batch_size)
            temp_train_labels = train_labels.narrow(0, temp_batch, batch_size)

            for i in range(batch_size):
                # Clear parameter gradients before each batch
                optimizer_2.zero_grad()
                temp_train_feature = temp_train_features[i]
                temp_train_label = temp_train_labels[i]

                # Forward pass to compute the loss
                temp_train_pred = my_model_2.forward(temp_train_feature)
                temp_train_loss = criterion_2.forward(temp_train_pred, temp_train_label)
                temp_train_loss_sum += temp_train_loss

                _, temp_train_pred_cat = torch.max(temp_train_pred, 0)
                _, temp_train_label_cat = torch.max(temp_train_label, 0)
                if temp_train_pred_cat == temp_train_label_cat:
                    num_train_correct += 1

                # Gradient of the loss w.r.t. the prediction
                temp_train_loss_grad = criterion_2.backward(temp_train_pred, temp_train_label)
                '''
                if (not temp_train_loss_grad[0] >= 0) and (not temp_train_loss_grad[0] < 0):
                    continue
                '''
                # Accumulate parameter gradients over the batch
                my_model_2.backward(temp_train_loss_grad)
                # Update parameters with the optimizer
                optimizer_2.step()

        # Evaluate the current model on the test set (only the forward pass is needed)
        for i_test in range(len(test_features)):
            temp_test_feature = test_features[i_test]
            temp_test_label = test_labels[i_test]

            temp_test_pred = my_model_2.forward(temp_test_feature)
            temp_test_loss = criterion_2.forward(temp_test_pred, temp_test_label)
            temp_test_loss_sum += temp_test_loss

            _, temp_test_pred_cat = torch.max(temp_test_pred, 0)
            _, temp_test_label_cat = torch.max(temp_test_label, 0)
            if temp_test_pred_cat == temp_test_label_cat:
                num_test_correct += 1

        temp_train_loss_mean = temp_train_loss_sum / len(train_features)
        temp_test_loss_mean = temp_test_loss_sum / len(test_features)
        temp_train_accuracy = num_train_correct / len(train_features)
        temp_test_accuracy = num_test_correct / len(test_features)

        # In case of a gradient explosion (a NaN loss gradient), re-initialize the
        # model and restart training; this seldom happens. NaN is detected here
        # because it is the only value for which both comparisons are False.
        # (Note: the re-initialized design uses ReLU where the original used SeLU.)
        if (not temp_train_loss_grad[0] >= 0) and (not temp_train_loss_grad[0] < 0):
            epoch = 0
            my_model_design_2 = [Linear(2, 25), ReLU(),
                                 Linear(25, 25), Dropout(p=0.5), ReLU(),
                                 Linear(25, 25), Dropout(p=0.5), ReLU(),
                                 Linear(25, 2), Sigmoid()]
            my_model_2 = Sequential(my_model_design_2)
            optimizer_2 = Adam(my_model_2, lr=1e-3)
            criterion_2 = CrossEntropy()
            print('--------------------------------------------------------------------------------')
            print('--------------------------------------------------------------------------------')
            print('--------------------------------------------------------------------------------')
            print('--------------------------------------------------------------------------------')
            print('--------------------------------------------------------------------------------')
            print('Restart training because of gradient explosion')
            continue

        print("Epoch: {}/{}..".format(epoch + 1, 25),
              "Training Loss: {:.4f}..".format(temp_train_loss_mean),
              "Training Accuracy: {:.4f}..".format(temp_train_accuracy),
              "Validation/Test Loss: {:.4f}..".format(temp_test_loss_mean),
              "Validation/Test Accuracy: {:.4f}..".format(temp_test_accuracy))
        epoch += 1

    # Visualize the classification performance of Model 2 on the testing set
    test_pred_labels_2 = []
    for i in range(1000):
        temp_test_feature = test_features[i]
        temp_test_label = test_labels[i]

        temp_test_pred = my_model_2.forward(temp_test_feature)
        _, temp_test_pred_cat = torch.max(temp_test_pred, 0)

        if test_targets[i].int() == temp_test_pred_cat.int():
            test_pred_labels_2.append(int(test_targets[i]))
        else:
            # Misclassified points get a third label so they stand out in the plot
            test_pred_labels_2.append(2)

    fig, axes = plt.subplots(1, 1, figsize=(6, 6))
    axes.scatter(test_features[:, 0], test_features[:, 1], c=test_pred_labels_2)
    axes.set_title('Classification Performance of Model 2')
    plt.show()
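# The restart logic above relies on a NaN test written with plain comparisons:
# NaN is the only float for which both "x >= 0" and "x < 0" are False. A minimal
# self-contained sketch of that check (the helper name is ours, not the project's):
import math

def gradient_exploded(g):
    """True when a loss-gradient component has become NaN."""
    return (not g >= 0) and (not g < 0)

assert gradient_exploded(float('nan'))               # NaN is caught...
assert not gradient_exploded(1e308)                  # ...while large finite values pass
assert gradient_exploded(-0.5) == math.isnan(-0.5)   # agrees with math.isnan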
def runTraining():
    print('-' * 40)
    print('~~~~~~~~ Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = 4
    batch_size_val = 1
    batch_size_val_save = 1
    batch_size_val_savePng = 4
    lr = 0.0001
    epoch = 1000
    root_dir = '../DataSet/Bladder_Aug'
    modelName = 'UNetG_Dilated_Progressive'
    model_dir = 'model'

    transform = transforms.Compose([transforms.ToTensor()])
    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=False,
                                                      equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=5,
                                        shuffle=False)

    val_loader_save_imagesPng = DataLoader(val_set,
                                           batch_size=batch_size_val_savePng,
                                           num_workers=5,
                                           shuffle=False)

    # Initialize
    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
    num_classes = 4
    initial_kernels = 32

    # Load network
    netG = UNetG_Dilated_Progressive(1, initial_kernels, num_classes)

    softMax = nn.Softmax(dim=1)  # softmax over the class dimension
    CE_loss = nn.CrossEntropyLoss()
    Dice_loss = computeDiceOneHot()
    MSE_loss = nn.MSELoss()  # needed by the multi-task (regression) branch below

    if torch.cuda.is_available():
        netG.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_loss.cuda()

    # To resume training from a saved model
    '''try:
        netG = torch.load('./model/Best_UNetG_Dilated_Progressive_Stride_Residual_ChannelsFirst32.pkl')
        print("--------model restored--------")
    except:
        print("--------model not restored--------")
        pass'''

    optimizerG = Adam(netG.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizerG,
                                                           mode='max',
                                                           patience=4,
                                                           verbose=True,
                                                           factor=10**-0.5)

    BestDice, BestEpoch = 0, 0
    BestDiceT = 0

    d1Train = []
    d2Train = []
    d3Train = []
    d1Val = []
    d2Val = []
    d3Val = []

    Losses = []
    Losses1 = []
    Losses05 = []
    Losses025 = []
    Losses0125 = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        netG.train()
        lossVal = []
        lossValD = []
        lossVal1 = []
        lossVal05 = []
        lossVal025 = []
        lossVal0125 = []

        d1TrainTemp = []
        d2TrainTemp = []
        d3TrainTemp = []

        timesAll = []
        success = 0
        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image, labels, img_names = data

            # Prevent a batchnorm error on a batch of size 1
            if image.size(0) != batch_size:
                continue

            optimizerG.zero_grad()
            MRI = to_var(image)
            Segmentation = to_var(labels)
            target_dice = to_var(torch.ones(1))

            ################### Train ###################
            netG.zero_grad()

            deepSupervision = False
            multiTask = False

            start_time = time.time()
            if deepSupervision == False and multiTask == False:
                # No deep supervision
                segmentation_prediction = netG(MRI)
            elif deepSupervision == True:
                # Deep supervision: auxiliary predictions at three lower resolutions
                segmentation_prediction, seg_3, seg_2, seg_1 = netG(MRI)
            else:
                # Multi-task: segmentation plus a regression head
                segmentation_prediction, reg_output = netG(MRI)
                feats = getValuesRegression(labels)
                feats_t = torch.from_numpy(feats).float()
                featsVar = to_var(feats_t)
                MSE_loss_val = MSE_loss(reg_output, featsVar)

            predClass_y = softMax(segmentation_prediction)
            spentTime = time.time() - start_time
            timesAll.append(spentTime / batch_size)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            # No deep supervision
            CE_lossG = CE_loss(segmentation_prediction, Segmentation_class)

            if deepSupervision == True:
                imageLabels_05 = resizeTensorMaskInSingleImage(Segmentation_class, 2)
                imageLabels_025 = resizeTensorMaskInSingleImage(Segmentation_class, 4)
                imageLabels_0125 = resizeTensorMaskInSingleImage(Segmentation_class, 8)

                CE_lossG_3 = CE_loss(seg_3, imageLabels_05)
                CE_lossG_2 = CE_loss(seg_2, imageLabels_025)
                CE_lossG_1 = CE_loss(seg_1, imageLabels_0125)

            '''weight = torch.ones(4).cuda()  # Num classes
            weight[0] = 0.2
            weight[1] = 0.2
            weight[2] = 1
            weight[3] = 1
            CE_loss.weight = weight'''

            # Dice loss
            DicesN, DicesB, DicesW, DicesT = Dice_loss(segmentation_prediction_ones,
                                                       Segmentation_planes)

            DiceN = DicesToDice(DicesN)
            DiceB = DicesToDice(DicesB)
            DiceW = DicesToDice(DicesW)
            DiceT = DicesToDice(DicesT)

            Dice_score = (DiceB + DiceW + DiceT) / 3

            if deepSupervision == False and multiTask == False:
                lossG = CE_lossG
            elif deepSupervision == True:
                # Deep supervision: down-weight the lower-resolution terms
                lossG = CE_lossG + 0.25 * CE_lossG_3 + 0.1 * CE_lossG_2 + 0.1 * CE_lossG_1
            else:
                lossG = CE_lossG + 0.000001 * MSE_loss_val

            lossG.backward()
            optimizerG.step()

            # .data[0] is the pre-0.4 PyTorch idiom; use .item() on newer versions
            lossVal.append(lossG.data[0])
            lossVal1.append(CE_lossG.data[0])

            if deepSupervision == True:
                lossVal05.append(CE_lossG_3.data[0])
                lossVal025.append(CE_lossG_2.data[0])
                lossVal0125.append(CE_lossG_1.data[0])

            printProgressBar(j + 1, totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f}, Dice1: {:.4f}, Dice2: {:.4f}, Dice3: {:.4f}".format(
                                 Dice_score.data[0], DiceB.data[0], DiceW.data[0], DiceT.data[0]))

        if deepSupervision == False:
            printProgressBar(totalImages, totalImages,
                             done="[Training] Epoch: {}, LossG: {:.4f}, Loss CE: {:.4f}".format(
                                 i, np.mean(lossVal), np.mean(lossVal1)))
        else:
            printProgressBar(totalImages, totalImages,
                             done="[Training] Epoch: {}, LossG: {:.4f}, Loss4: {:.4f}, Loss3: {:.4f}, Loss2: {:.4f}, Loss1: {:.4f}".format(
                                 i, np.mean(lossVal), np.mean(lossVal1),
                                 np.mean(lossVal05), np.mean(lossVal025), np.mean(lossVal0125)))

        Losses.append(np.mean(lossVal))

        d1, d2, d3 = inference(netG, val_loader, batch_size, i, deepSupervision)

        d1Val.append(d1)
        d2Val.append(d2)
        d3Val.append(d3)

        # The d*TrainTemp lists are never filled in this version, so these
        # entries are NaN placeholders
        d1Train.append(np.mean(d1TrainTemp))
        d2Train.append(np.mean(d2TrainTemp))
        d3Train.append(np.mean(d3TrainTemp))

        mainPath = '../Results/Statistics/' + modelName
        directory = mainPath

        if not os.path.exists(directory):
            os.makedirs(directory)

        ###### Save statistics ######
        np.save(os.path.join(directory, 'Losses.npy'), Losses)
        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd2Val.npy'), d2Val)
        np.save(os.path.join(directory, 'd3Val.npy'), d3Val)
        np.save(os.path.join(directory, 'd1Train.npy'), d1Train)
        np.save(os.path.join(directory, 'd2Train.npy'), d2Train)
        np.save(os.path.join(directory, 'd3Train.npy'), d3Train)

        currentDice = (d1 + d2 + d3) / 3

        # How many slices with/without tumor were correctly classified
        print("[val] DSC: (1): {:.4f} (2): {:.4f} (3): {:.4f} ".format(d1, d2, d3))

        if currentDice > BestDice:
            BestDice = currentDice
            BestDiceT = d1
            BestEpoch = i

            if currentDice > 0.7:
                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)

                torch.save(netG, os.path.join(model_dir, "Best_" + modelName + ".pkl"))

                # Save images
                saveImages(netG, val_loader_save_images, batch_size_val_save, i, modelName, deepSupervision)
                saveImagesAsMatlab(netG, val_loader_save_images, batch_size_val_save, i, modelName, deepSupervision)

        print("###                                                       ###")
        print("### Best Dice: {:.4f} at epoch {} with DiceT: {:.4f} ###".format(BestDice, BestEpoch, BestDiceT))
        print("###                                                       ###")

        # This is not how we did it in the MedPhys paper. Note that this condition
        # is true on every epoch where i is NOT a multiple of (BestEpoch + 20),
        # so the learning rate is halved on most epochs.
        if i % (BestEpoch + 20):
            for param_group in optimizerG.param_groups:
                param_group['lr'] = lr / 2
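# to_var is used by all of these training scripts but is not defined in this file.
# A minimal sketch of what such a helper typically looks like in this pre-0.4
# PyTorch code base (an assumption, not the project's actual implementation):
import torch
from torch.autograd import Variable

def to_var(x):
    """Move a tensor to the GPU when available and wrap it for autograd."""
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)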
def runTraining():
    print('-' * 40)
    print('~~~~~~~~ Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = 4
    batch_size_val = 1
    batch_size_val_save = 1
    lr = 0.0001
    epoch = 200
    num_classes = 2
    initial_kernels = 32
    modelName = 'IVD_Net'
    img_names_ALL = []

    print('.' * 40)
    print(" ....Model name: {} ........".format(modelName))
    print(' - Num. classes: {}'.format(num_classes))
    print(' - Num. initial kernels: {}'.format(initial_kernels))
    print(' - Batch size: {}'.format(batch_size))
    print(' - Learning rate: {}'.format(lr))
    print(' - Num. epochs: {}'.format(epoch))
    print('.' * 40)

    root_dir = '../Data/Training_PngITK'
    model_dir = 'IVD_Net'

    transform = transforms.Compose([transforms.ToTensor()])
    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=False,
                                                      equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=5,
                                        shuffle=False)

    # Initialize
    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
    net = IVD_Net_asym(1, num_classes, initial_kernels)

    # Initialize the weights
    net.apply(weights_init)

    softMax = nn.Softmax(dim=1)  # softmax over the class dimension
    CE_loss = nn.CrossEntropyLoss()
    Dice_ = computeDiceOneHotBinary()

    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_.cuda()

    # To load a pre-trained model
    '''try:
        net = torch.load('modelName')
        print("--------model restored--------")
    except:
        print("--------model not restored--------")
        pass'''

    optimizer = Adam(net.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           patience=4,
                                                           verbose=True,
                                                           factor=10**-0.5)

    BestDice, BestEpoch = 0, 0

    d1Train = []
    d1Val = []
    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        net.train()
        lossTrain = []
        d1TrainTemp = []
        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            # Four modalities per sample (presumably fat, in-phase,
            # opposed-phase and water), plus the labels
            image_f, image_i, image_o, image_w, labels, img_names = data

            # Make sure your data here is between [0, 1]
            image_f = image_f.type(torch.FloatTensor)
            image_i = image_i.type(torch.FloatTensor)
            image_o = image_o.type(torch.FloatTensor)
            image_w = image_w.type(torch.FloatTensor)

            # Binarize the labels
            labels = labels.numpy()
            idx = np.where(labels > 0.0)
            labels[idx] = 1.0
            labels = torch.from_numpy(labels)
            labels = labels.type(torch.FloatTensor)

            optimizer.zero_grad()
            # Concatenate the four modalities along the channel dimension
            MRI = to_var(torch.cat((image_f, image_i, image_o, image_w), dim=1))
            Segmentation = to_var(labels)
            target_dice = to_var(torch.ones(1))

            net.zero_grad()

            segmentation_prediction = net(MRI)
            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            CE_loss_ = CE_loss(segmentation_prediction, Segmentation_class)

            # Compute the Dice (so far on a 2D basis)
            DicesB, DicesF = Dice_(segmentation_prediction_ones, Segmentation_planes)
            DiceB = DicesToDice(DicesB)
            DiceF = DicesToDice(DicesF)

            loss = CE_loss_

            loss.backward()
            optimizer.step()

            lossTrain.append(loss.data[0])

            printProgressBar(j + 1, totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f},".format(DiceF.data[0]))

        printProgressBar(totalImages, totalImages,
                         done="[Training] Epoch: {}, LossG: {:.4f}".format(i, np.mean(lossTrain)))

        # Save statistics
        Losses.append(np.mean(lossTrain))

        d1 = inference(net, val_loader, batch_size, i)
        d1Val.append(d1)
        # d1TrainTemp is never filled in this version, so this entry is a NaN placeholder
        d1Train.append(np.mean(d1TrainTemp))

        mainPath = '../Results/Statistics/' + modelName
        directory = mainPath

        if not os.path.exists(directory):
            os.makedirs(directory)

        np.save(os.path.join(directory, 'Losses.npy'), Losses)
        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd1Train.npy'), d1Train)

        currentDice = d1[0].numpy()

        print("[val] DSC: {:.4f} ".format(d1[0]))

        if currentDice > BestDice:
            BestDice = currentDice
            BestEpoch = i

            if currentDice > 0.75:
                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                torch.save(net, os.path.join(model_dir, "Best_" + modelName + ".pkl"))
                saveImages(net, val_loader_save_images, batch_size_val_save, i, modelName)

        # Two ways to decay the learning rate. Note that this condition is true
        # whenever i is NOT a multiple of (BestEpoch + 10), and it re-assigns the
        # unchanged lr, so it is effectively a no-op as written.
        if i % (BestEpoch + 10):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr