def runTraining(args):
    print('-' * 40)
    print('~~~~~~~~ Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = args.batch_size
    batch_size_val = 1
    batch_size_val_save = 1
    lr = args.lr
    epoch = args.epochs
    root_dir = './DataSet_Challenge/Val_1'
    model_dir = 'model'

    print(' Dataset: {} '.format(root_dir))

    transform = transforms.Compose([transforms.ToTensor()])
    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset('train', root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=True, equalize=False)
    train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=5, shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val', root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)
    val_loader = DataLoader(val_set, batch_size=batch_size_val, num_workers=5, shuffle=False)
    val_loader_save_images = DataLoader(val_set, batch_size=batch_size_val_save, num_workers=4, shuffle=False)

    # Initialize the network
    print("~~~~~~~~~~~ Creating the DAF Stacked model ~~~~~~~~~~")
    net = DAF_stack()
    print(" Model Name: {}".format(args.modelName))
    print(" Model to create: DAF_Stacked")

    net.apply(weights_init)

    softMax = nn.Softmax()
    CE_loss = nn.CrossEntropyLoss()
    Dice_loss = computeDiceOneHot()
    mseLoss = nn.MSELoss()

    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_loss.cuda()

    optimizer = Adam(net.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)

    BestDice, BestEpoch = 0, 0
    BestDice3D = [0, 0, 0, 0]

    d1Val = []
    d2Val = []
    d3Val = []
    d4Val = []

    d1Val_3D = []
    d2Val_3D = []
    d3Val_3D = []
    d4Val_3D = []

    d1Val_3D_std = []
    d2Val_3D_std = []
    d3Val_3D_std = []
    d4Val_3D_std = []

    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        net.train()
        lossVal = []
        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image, labels, img_names = data

            # Prevent batchnorm error for batch of size 1
            if image.size(0) != batch_size:
                continue

            optimizer.zero_grad()
            MRI = to_var(image)
            Segmentation = to_var(labels)

            ################### Train ###################
            net.zero_grad()

            # Network outputs: paired semantic vectors, encoder inputs/outputs,
            # and the eight side segmentation outputs
            semVector_1_1, semVector_2_1, semVector_1_2, semVector_2_2, \
            semVector_1_3, semVector_2_3, semVector_1_4, semVector_2_4, \
            inp_enc0, inp_enc1, inp_enc2, inp_enc3, inp_enc4, inp_enc5, inp_enc6, inp_enc7, \
            out_enc0, out_enc1, out_enc2, out_enc3, out_enc4, out_enc5, out_enc6, out_enc7, \
            outputs0, outputs1, outputs2, outputs3, \
            outputs0_2, outputs1_2, outputs2_2, outputs3_2 = net(MRI)

            # Average the eight side outputs to obtain the final prediction
            segmentation_prediction = (outputs0 + outputs1 + outputs2 + outputs3 +
                                       outputs0_2 + outputs1_2 + outputs2_2 + outputs3_2) / 8
            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            # Cross-entropy loss on each side output
            loss0 = CE_loss(outputs0, Segmentation_class)
            loss1 = CE_loss(outputs1, Segmentation_class)
            loss2 = CE_loss(outputs2, Segmentation_class)
            loss3 = CE_loss(outputs3, Segmentation_class)
            loss0_2 = CE_loss(outputs0_2, Segmentation_class)
            loss1_2 = CE_loss(outputs1_2, Segmentation_class)
            loss2_2 = CE_loss(outputs2_2, Segmentation_class)
            loss3_2 = CE_loss(outputs3_2, Segmentation_class)

            # Semantic-guidance losses between paired semantic vectors
            lossSemantic1 = mseLoss(semVector_1_1, semVector_2_1)
            lossSemantic2 = mseLoss(semVector_1_2, semVector_2_2)
            lossSemantic3 = mseLoss(semVector_1_3, semVector_2_3)
            lossSemantic4 = mseLoss(semVector_1_4, semVector_2_4)

            # Reconstruction losses between encoder inputs and outputs
            lossRec0 = mseLoss(inp_enc0, out_enc0)
            lossRec1 = mseLoss(inp_enc1, out_enc1)
            lossRec2 = mseLoss(inp_enc2, out_enc2)
            lossRec3 = mseLoss(inp_enc3, out_enc3)
            lossRec4 = mseLoss(inp_enc4, out_enc4)
            lossRec5 = mseLoss(inp_enc5, out_enc5)
            lossRec6 = mseLoss(inp_enc6, out_enc6)
            lossRec7 = mseLoss(inp_enc7, out_enc7)

            lossG = (loss0 + loss1 + loss2 + loss3 + loss0_2 + loss1_2 + loss2_2 + loss3_2) \
                + 0.25 * (lossSemantic1 + lossSemantic2 + lossSemantic3 + lossSemantic4) \
                + 0.1 * (lossRec0 + lossRec1 + lossRec2 + lossRec3 +
                         lossRec4 + lossRec5 + lossRec6 + lossRec7)

            # Compute the DSC
            DicesN, DicesB, DicesW, DicesT, DicesZ = Dice_loss(segmentation_prediction_ones, Segmentation_planes)

            DiceB = DicesToDice(DicesB)
            DiceW = DicesToDice(DicesW)
            DiceT = DicesToDice(DicesT)
            DiceZ = DicesToDice(DicesZ)

            Dice_score = (DiceB + DiceW + DiceT + DiceZ) / 4

            lossG.backward()
            optimizer.step()

            lossVal.append(lossG.cpu().data.numpy())

            printProgressBar(j + 1, totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f}, Dice1: {:.4f}, Dice2: {:.4f}, Dice3: {:.4f}, Dice4: {:.4f} ".format(
                                 Dice_score.cpu().data.numpy(),
                                 DiceB.data.cpu().data.numpy(),
                                 DiceW.data.cpu().data.numpy(),
                                 DiceT.data.cpu().data.numpy(),
                                 DiceZ.data.cpu().data.numpy()))

        printProgressBar(totalImages, totalImages,
                         done="[Training] Epoch: {}, LossG: {:.4f}".format(i, np.mean(lossVal)))

        # Save statistics
        modelName = args.modelName
        directory = 'Results/Statistics/' + modelName

        Losses.append(np.mean(lossVal))

        d1, d2, d3, d4 = inference(net, val_loader)

        d1Val.append(d1)
        d2Val.append(d2)
        d3Val.append(d3)
        d4Val.append(d4)

        if not os.path.exists(directory):
            os.makedirs(directory)

        np.save(os.path.join(directory, 'Losses.npy'), Losses)
        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd2Val.npy'), d2Val)
        np.save(os.path.join(directory, 'd3Val.npy'), d3Val)
        np.save(os.path.join(directory, 'd4Val.npy'), d4Val)

        currentDice = (d1 + d2 + d3 + d4) / 4

        print("[val] DSC: (1): {:.4f} (2): {:.4f} (3): {:.4f} (4): {:.4f}".format(d1, d2, d3, d4))

        currentDice = currentDice.data.numpy()

        # Evaluate on 3D
        saveImages_for3D(net, val_loader_save_images, batch_size_val_save, 1000, modelName, False, False)
        reconstruct3D(modelName, 1000, isBest=False)
        DSC_3D = evaluate3D(modelName)

        mean_DSC3D = np.mean(DSC_3D, 0)
        std_DSC3D = np.std(DSC_3D, 0)

        d1Val_3D.append(mean_DSC3D[0])
        d2Val_3D.append(mean_DSC3D[1])
        d3Val_3D.append(mean_DSC3D[2])
        d4Val_3D.append(mean_DSC3D[3])
        d1Val_3D_std.append(std_DSC3D[0])
        d2Val_3D_std.append(std_DSC3D[1])
        d3Val_3D_std.append(std_DSC3D[2])
        d4Val_3D_std.append(std_DSC3D[3])

        np.save(os.path.join(directory, 'd0Val_3D.npy'), d1Val_3D)
        np.save(os.path.join(directory, 'd1Val_3D.npy'), d2Val_3D)
        np.save(os.path.join(directory, 'd2Val_3D.npy'), d3Val_3D)
        np.save(os.path.join(directory, 'd3Val_3D.npy'), d4Val_3D)

        np.save(os.path.join(directory, 'd0Val_3D_std.npy'), d1Val_3D_std)
        np.save(os.path.join(directory, 'd1Val_3D_std.npy'), d2Val_3D_std)
        np.save(os.path.join(directory, 'd2Val_3D_std.npy'), d3Val_3D_std)
        np.save(os.path.join(directory, 'd3Val_3D_std.npy'), d4Val_3D_std)

        if currentDice > BestDice:
            BestDice = currentDice
            BestEpoch = i

            if currentDice > 0.40:
                if np.mean(mean_DSC3D) > np.mean(BestDice3D):
                    BestDice3D = mean_DSC3D

                print("### In 3D -----> MEAN: {}, Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f} ###".format(
                    np.mean(mean_DSC3D), mean_DSC3D[0], mean_DSC3D[1], mean_DSC3D[2], mean_DSC3D[3]))
                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)

                torch.save(net.state_dict(),
                           os.path.join(model_dir, "Best_" + modelName + ".pth"),
                           pickle_module=dill)
                reconstruct3D(modelName, 1000, isBest=True)

        print("### ###")
        print("### Best Dice: {:.4f} at epoch {} with Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f} ###".format(
            BestDice, BestEpoch, d1, d2, d3, d4))
        print("### Best Dice in 3D: {:.4f} with Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f} ###".format(
            np.mean(BestDice3D), BestDice3D[0], BestDice3D[1], BestDice3D[2], BestDice3D[3]))
        print("### ###")

        # Decay the learning rate on a schedule tied to the best epoch
        if i % (BestEpoch + 50) == 0:
            for param_group in optimizer.param_groups:
                lr = lr * 0.5
                param_group['lr'] = lr
                print(' ---------- New learning Rate: {}'.format(lr))

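# Minimal invocation sketch for the entry point above. The argument names
# (--modelName, --batch_size, --epochs, --lr) mirror the attributes read from
# `args` in runTraining(); the default values below are illustrative
# assumptions, not the original command-line interface.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='DAF_stack training')
    parser.add_argument('--modelName', type=str, default='DAF_Stacked')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--lr', type=float, default=1e-4)
    runTraining(parser.parse_args())
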
def runTraining():
    print('-' * 40)
    print('~~~~~~~~ Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = 4
    batch_size_val = 1
    batch_size_val_save = 1
    batch_size_val_savePng = 4
    lr = 0.0001
    epoch = 1000
    root_dir = '../DataSet/Bladder_Aug'
    modelName = 'UNetG_Dilated_Progressive'
    model_dir = 'model'

    transform = transforms.Compose([transforms.ToTensor()])
    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset('train', root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=False, equalize=False)
    train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=5, shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val', root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)
    val_loader = DataLoader(val_set, batch_size=batch_size_val, num_workers=5, shuffle=False)
    val_loader_save_images = DataLoader(val_set, batch_size=batch_size_val_save, num_workers=5, shuffle=False)
    val_loader_save_imagesPng = DataLoader(val_set, batch_size=batch_size_val_savePng, num_workers=5, shuffle=False)

    # Initialize
    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
    num_classes = 4
    initial_kernels = 32

    # Load network
    netG = UNetG_Dilated_Progressive(1, initial_kernels, num_classes)

    softMax = nn.Softmax()
    CE_loss = nn.CrossEntropyLoss()
    Dice_loss = computeDiceOneHot()

    if torch.cuda.is_available():
        netG.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_loss.cuda()

    # To restore a pre-trained model
    '''try:
        netG = torch.load('./model/Best_UNetG_Dilated_Progressive_Stride_Residual_ChannelsFirst32.pkl')
        print("--------model restored--------")
    except:
        print("--------model not restored--------")
        pass'''

    optimizerG = Adam(netG.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizerG, mode='max',
                                                           patience=4, verbose=True,
                                                           factor=10 ** -0.5)

    BestDice, BestEpoch = 0, 0

    d1Train = []
    d2Train = []
    d3Train = []
    d1Val = []
    d2Val = []
    d3Val = []

    Losses = []
    Losses1 = []
    Losses05 = []
    Losses025 = []
    Losses0125 = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        netG.train()
        lossVal = []
        lossValD = []
        lossVal1 = []
        lossVal05 = []
        lossVal025 = []
        lossVal0125 = []

        d1TrainTemp = []
        d2TrainTemp = []
        d3TrainTemp = []

        timesAll = []
        success = 0
        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image, labels, img_names = data

            # Prevent batchnorm error for batch of size 1
            if image.size(0) != batch_size:
                continue

            optimizerG.zero_grad()
            MRI = to_var(image)
            Segmentation = to_var(labels)
            target_dice = to_var(torch.ones(1))

            ################### Train ###################
            netG.zero_grad()

            deepSupervision = False
            multiTask = False

            start_time = time.time()

            if deepSupervision == False and multiTask == False:
                # No deep supervision
                segmentation_prediction = netG(MRI)
            else:
                if deepSupervision == True:
                    # Deep supervision
                    segmentation_prediction, seg_3, seg_2, seg_1 = netG(MRI)
                else:
                    # Multi-task: segmentation plus regression head
                    segmentation_prediction, reg_output = netG(MRI)

                    # Regression
                    feats = getValuesRegression(labels)
                    feats_t = torch.from_numpy(feats).float()
                    featsVar = to_var(feats_t)

                    MSE_loss_val = MSE_loss(reg_output, featsVar)

            predClass_y = softMax(segmentation_prediction)

            spentTime = time.time() - start_time
            timesAll.append(spentTime / batch_size)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            # No deep supervision
            CE_lossG = CE_loss(segmentation_prediction, Segmentation_class)

            if deepSupervision == True:
                imageLabels_05 = resizeTensorMaskInSingleImage(Segmentation_class, 2)
                imageLabels_025 = resizeTensorMaskInSingleImage(Segmentation_class, 4)
                imageLabels_0125 = resizeTensorMaskInSingleImage(Segmentation_class, 8)

                CE_lossG_3 = CE_loss(seg_3, imageLabels_05)
                CE_lossG_2 = CE_loss(seg_2, imageLabels_025)
                CE_lossG_1 = CE_loss(seg_1, imageLabels_0125)

            '''weight = torch.ones(4).cuda()  # Num classes
            weight[0] = 0.2
            weight[1] = 0.2
            weight[2] = 1
            weight[3] = 1
            CE_loss.weight = weight'''

            # Dice loss
            DicesN, DicesB, DicesW, DicesT = Dice_loss(segmentation_prediction_ones, Segmentation_planes)

            DiceN = DicesToDice(DicesN)
            DiceB = DicesToDice(DicesB)
            DiceW = DicesToDice(DicesW)
            DiceT = DicesToDice(DicesT)

            Dice_score = (DiceB + DiceW + DiceT) / 3

            if deepSupervision == False and multiTask == False:
                lossG = CE_lossG
            else:
                if deepSupervision == True:
                    # Deep supervision
                    lossG = CE_lossG + 0.25 * CE_lossG_3 + 0.1 * CE_lossG_2 + 0.1 * CE_lossG_1
                else:
                    lossG = CE_lossG + 0.000001 * MSE_loss_val

            lossG.backward()
            optimizerG.step()

            lossVal.append(lossG.data[0])
            lossVal1.append(CE_lossG.data[0])

            if deepSupervision == True:
                lossVal05.append(CE_lossG_3.data[0])
                lossVal025.append(CE_lossG_2.data[0])
                lossVal0125.append(CE_lossG_1.data[0])

            printProgressBar(j + 1, totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f}, Dice1: {:.4f}, Dice2: {:.4f}, Dice3: {:.4f} ".format(
                                 Dice_score.data[0], DiceB.data[0], DiceW.data[0], DiceT.data[0]))

        if deepSupervision == False:
            '''printProgressBar(totalImages, totalImages,
                             done="[Training] Epoch: {}, LossG: {:.4f},".format(i, np.mean(lossVal), np.mean(lossVal1)))'''
            printProgressBar(totalImages, totalImages,
                             done="[Training] Epoch: {}, LossG: {:.4f}, lossMSE: {:.4f}".format(
                                 i, np.mean(lossVal), np.mean(lossVal1)))
        else:
            printProgressBar(totalImages, totalImages,
                             done="[Training] Epoch: {}, LossG: {:.4f}, Loss4: {:.4f}, Loss3: {:.4f}, Loss2: {:.4f}, Loss1: {:.4f}".format(
                                 i, np.mean(lossVal), np.mean(lossVal1),
                                 np.mean(lossVal05), np.mean(lossVal025), np.mean(lossVal0125)))

        Losses.append(np.mean(lossVal))

        d1, d2, d3 = inference(netG, val_loader, batch_size, i, deepSupervision)

        d1Val.append(d1)
        d2Val.append(d2)
        d3Val.append(d3)

        d1Train.append(np.mean(d1TrainTemp).data[0])
        d2Train.append(np.mean(d2TrainTemp).data[0])
        d3Train.append(np.mean(d3TrainTemp).data[0])

        mainPath = '../Results/Statistics/' + modelName
        directory = mainPath

        if not os.path.exists(directory):
            os.makedirs(directory)

        ###### Save statistics ######
        np.save(os.path.join(directory, 'Losses.npy'), Losses)
        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd2Val.npy'), d2Val)
        np.save(os.path.join(directory, 'd3Val.npy'), d3Val)
        np.save(os.path.join(directory, 'd1Train.npy'), d1Train)
        np.save(os.path.join(directory, 'd2Train.npy'), d2Train)
        np.save(os.path.join(directory, 'd3Train.npy'), d3Train)

        currentDice = (d1 + d2 + d3) / 3

        # How many slices with/without tumor correctly classified
        print("[val] DSC: (1): {:.4f} (2): {:.4f} (3): {:.4f} ".format(d1, d2, d3))

        if currentDice > BestDice:
            BestDice = currentDice
            BestDiceT = d1
            BestEpoch = i

            if currentDice > 0.7:
                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)

                torch.save(netG, os.path.join(model_dir, "Best_" + modelName + ".pkl"))

                # Save images
                saveImages(netG, val_loader_save_images, batch_size_val_save, i, modelName, deepSupervision)
                saveImagesAsMatlab(netG, val_loader_save_images, batch_size_val_save, i, modelName, deepSupervision)

        print("### ###")
        print("### Best Dice: {:.4f} at epoch {} with DiceT: {:.4f} ###".format(BestDice, BestEpoch, BestDiceT))
        print("### ###")

        # This is not as we did it in the MedPhys paper
        if i % (BestEpoch + 20):
            for param_group in optimizerG.param_groups:
                param_group['lr'] = lr / 2

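# Note: the ReduceLROnPlateau scheduler created in the function above is never
# stepped, so only the manual adjustment at the end of each epoch changes the
# learning rate. A small, self-contained sketch (illustrative only; the toy model
# and _scheduler_usage_sketch are not part of the original script) of how such a
# scheduler is normally driven by the validation Dice:
def _scheduler_usage_sketch():
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)  # stand-in for netG
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max',
                                                       patience=4, factor=10 ** -0.5)
    for val_dice in (0.60, 0.62, 0.61):  # stand-in validation Dice per epoch
        # ... training and validation for one epoch would run here ...
        sched.step(val_dice)             # step on the monitored metric ('max' mode)
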
def train(self, targ, inp, mode='online', epochs=500, par=None, out=None, Jrout=None, track=False):
    '''
    This is the main function of the model: it is used to train the system
    given a target and an input. Two training modes can be selected:
    (offline, online). The first uses the exact likelihood gradient (which is
    non-local in time, and thus not biologically plausible); the second is the
    online approximation of the gradient as described in the article.

    Args:
        targ: numpy.array of shape (N, T), where N is the number of neurons
              in the network and T is the time length of the sequence.
        inp : numpy.array of shape (N, T) that collects the input signal to the neurons.

    Keywords:
        mode  : (default: online) The training mode to use, either 'offline' or 'online'.
        epochs: (default: 500) The number of epochs of training.
        par   : (default: None) Optional dictionary collecting the training parameters:
                {dv, alpha, alpha_rout, beta_ro, offT}. If not provided, it defaults
                to the parameter dictionary of the model.
        out   : (default: None) Output target trajectories, numpy.array of shape
                (K, T), where K is the dimension of the output trajectories. This
                parameter should be specified if either Jrout is not None or track is True.
        Jrout : (default: None) Pre-trained readout connection matrix. If not provided,
                a new matrix is built and trained simultaneously with the recurrent
                connections. If Jrout is provided, the out parameter should be
                specified as it is needed to compute the output error.
        track : (default: False) Flag that signals whether to track the evolution of
                the output MSE over training epochs. If track is True, the out
                parameter should be specified as it is needed to compute the output error.
    '''
    assert targ.shape == inp.shape

    par = self.par if par is None else par

    dv = par['dv']

    itau_m = self.itau_m
    itau_s = self.itau_s

    sigm = self._sigm

    alpha = par['alpha']
    alpha_rout = par['alpha_rout']
    beta_ro = par['beta_ro']
    offT = par['offT']

    self.S[:, 0] = targ[:, 0].copy()
    self.S_hat[:, 0] = self.S[:, 0] * itau_s

    Tmax = np.shape(targ)[-1]

    dH = np.zeros((self.N, self.T))

    track = np.zeros(epochs) if track else None

    opt_rec = Adam(alpha=alpha, drop=0.9,
                   drop_time=100 * Tmax if mode == 'online' else 100)

    if Jrout is None:
        S_rout = ut.sfilter(targ, itau=beta_ro)
        J_rout = np.random.normal(0., 0.1, size=(out.shape[0], self.N))
        opt = Adam(alpha=alpha_rout, drop=0.9,
                   drop_time=20 * Tmax if mode == 'online' else 20)
    else:
        J_rout = Jrout

    for epoch in trange(epochs, leave=False, desc='Training {}'.format(mode)):
        if Jrout is None:
            # Here we train the readout
            dJrout = (out - J_rout @ S_rout) @ S_rout.T
            J_rout = opt.step(J_rout, dJrout)

        # Here we train the network
        for t in range(Tmax - 1):
            # Filtered target spikes
            self.S_hat[:, t] = self.S_hat[:, t - 1] * itau_s + targ[:, t] * (1. - itau_s) if t > 0 else self.S_hat[:, 0]

            # Membrane dynamics driven by recurrent, external and reset terms
            self.H[:, t + 1] = self.H[:, t] * (1. - itau_m) \
                + itau_m * (self.J @ self.S_hat[:, t] + inp[:, t] + self.h[:, t]) \
                + self.Jreset @ targ[:, t]

            # Eligibility trace of the filtered spikes
            dH[:, t + 1] = dH[:, t] * (1. - itau_m) + itau_m * self.S_hat[:, t]

            if mode == 'online':
                dJ = np.outer(targ[:, t + 1] - sigm(self.H[:, t + 1], dv=dv), dH[:, t + 1])
                self.J = opt_rec.step(self.J, dJ)
                np.fill_diagonal(self.J, 0.)

        if mode == 'offline':
            dJ = (targ - sigm(self.H, dv=dv)) @ dH.T
            self.J = opt_rec.step(self.J, dJ)
            np.fill_diagonal(self.J, 0.)

        # Here we track the MSE
        if track is not None:
            S_gen = self.compute(inp, init=np.zeros(self.N))
            track[epoch] = np.mean((out - J_rout @ ut.sfilter(S_gen, itau=beta_ro))[:, offT:] ** 2.)

    return (J_rout, track) if Jrout is None else track

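# Self-contained toy illustration (not part of the model above) of the 'online'
# update used in train(): the recurrent gradient at each time step is a rank-1
# outer product between the prediction error and the filtered-spike eligibility
# trace dH. The plain sigmoid stands in for self._sigm, and input/reset terms
# are omitted for brevity.
def _online_update_sketch(N=10, T=50, alpha=1e-2, itau_m=0.1, itau_s=0.3, seed=0):
    import numpy as np

    rng = np.random.default_rng(seed)
    targ = rng.binomial(1, 0.2, size=(N, T)).astype(float)    # toy target spike trains
    J = np.zeros((N, N))
    H = np.zeros(N)
    S_hat = np.zeros(N)
    dH = np.zeros(N)

    for t in range(T - 1):
        S_hat = S_hat * itau_s + targ[:, t] * (1. - itau_s)   # filtered target spikes
        H = H * (1. - itau_m) + itau_m * (J @ S_hat)          # membrane-like variable
        dH = dH * (1. - itau_m) + itau_m * S_hat              # eligibility trace
        err = targ[:, t + 1] - 1. / (1. + np.exp(-H))         # spike-probability error
        J += alpha * np.outer(err, dH)                        # rank-1 online update
        np.fill_diagonal(J, 0.)                               # no self-connections
    return J
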
def main():
    # Generate data and translate labels
    train_features, train_targets = generate_all_datapoints_and_labels()
    test_features, test_targets = generate_all_datapoints_and_labels()
    train_labels, test_labels = convert_labels(train_targets), convert_labels(test_targets)

    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('Model: Linear + ReLU + Linear + ReLU + Linear + ReLU + Linear + Tanh')
    print('Loss: MSE')
    print('Optimizer: SGD')
    print('*************************************************************************')
    print('Training')
    print('*************************************************************************')

    # Build network, loss and optimizer for Model 1
    my_model_design_1 = [Linear(2, 25), ReLU(), Linear(25, 25), Dropout(p=0.5), ReLU(),
                         Linear(25, 25), ReLU(), Linear(25, 2), Tanh()]
    my_model_1 = Sequential(my_model_design_1)
    optimizer_1 = SGD(my_model_1, lr=1e-3)
    criterion_1 = LossMSE()

    # Train Model 1
    batch_size = 1
    for epoch in range(50):
        temp_train_loss_sum = 0.
        temp_test_loss_sum = 0.
        num_train_correct = 0
        num_test_correct = 0

        # Trained in batch fashion: here batch size = 1
        for temp_batch in range(0, len(train_features), batch_size):
            temp_train_features = train_features.narrow(0, temp_batch, batch_size)
            temp_train_labels = train_labels.narrow(0, temp_batch, batch_size)

            for i in range(batch_size):
                # Clean parameter gradients before each batch
                optimizer_1.zero_grad()
                temp_train_feature = temp_train_features[i]
                temp_train_label = temp_train_labels[i]

                # Forward pass to compute the loss
                temp_train_pred = my_model_1.forward(temp_train_feature)
                temp_train_loss = criterion_1.forward(temp_train_pred, temp_train_label)
                temp_train_loss_sum += temp_train_loss

                _, temp_train_pred_cat = torch.max(temp_train_pred, 0)
                _, temp_train_label_cat = torch.max(temp_train_label, 0)
                if temp_train_pred_cat == temp_train_label_cat:
                    num_train_correct += 1

                # Calculate gradients according to the loss gradient
                temp_train_loss_grad = criterion_1.backward(temp_train_pred, temp_train_label)
                # Accumulate parameter gradients over the batch
                my_model_1.backward(temp_train_loss_grad)

            # Update parameters by optimizer
            optimizer_1.step()

        # Evaluate the current model on the testing set
        # (only the forward pass is needed)
        for i_test in range(len(test_features)):
            temp_test_feature = test_features[i_test]
            temp_test_label = test_labels[i_test]

            temp_test_pred = my_model_1.forward(temp_test_feature)
            temp_test_loss = criterion_1.forward(temp_test_pred, temp_test_label)
            temp_test_loss_sum += temp_test_loss

            _, temp_test_pred_cat = torch.max(temp_test_pred, 0)
            _, temp_test_label_cat = torch.max(temp_test_label, 0)
            if temp_test_pred_cat == temp_test_label_cat:
                num_test_correct += 1

        temp_train_loss_mean = temp_train_loss_sum / len(train_features)
        temp_test_loss_mean = temp_test_loss_sum / len(test_features)

        temp_train_accuracy = num_train_correct / len(train_features)
        temp_test_accuracy = num_test_correct / len(test_features)

        print("Epoch: {}/{}..".format(epoch + 1, 50),
              "Training Loss: {:.4f}..".format(temp_train_loss_mean),
              "Training Accuracy: {:.4f}..".format(temp_train_accuracy),
              "Validation/Test Loss: {:.4f}..".format(temp_test_loss_mean),
              "Validation/Test Accuracy: {:.4f}..".format(temp_test_accuracy))

    # Visualize the classification performance of Model 1 on the testing set
    test_pred_labels_1 = []
    for i in range(1000):
        temp_test_feature = test_features[i]
        temp_test_label = test_labels[i]
        temp_test_pred = my_model_1.forward(temp_test_feature)
        _, temp_train_pred_cat = torch.max(temp_test_pred, 0)
        if test_targets[i].int() == temp_train_pred_cat.int():
            test_pred_labels_1.append(int(test_targets[i]))
        else:
            # Misclassified points get a third colour
            test_pred_labels_1.append(2)

    fig, axes = plt.subplots(1, 1, figsize=(6, 6))
    axes.scatter(test_features[:, 0], test_features[:, 1], c=test_pred_labels_1)
    axes.set_title('Classification Performance of Model 1')
    plt.show()

    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('*************************************************************************')
    print('Model: Linear + ReLU + Linear + Dropout + SeLU + Linear + Dropout + ReLU + Linear + Sigmoid')
    print('Loss: Cross Entropy')
    print('Optimizer: Adam')
    print('*************************************************************************')
    print('Training')
    print('*************************************************************************')

    # Build network, loss function and optimizer for Model 2
    my_model_design_2 = [Linear(2, 25), ReLU(), Linear(25, 25), Dropout(p=0.5), SeLU(),
                         Linear(25, 25), Dropout(p=0.5), ReLU(), Linear(25, 2), Sigmoid()]
    my_model_2 = Sequential(my_model_design_2)
    optimizer_2 = Adam(my_model_2, lr=1e-3)
    criterion_2 = CrossEntropy()

    # Train Model 2
    batch_size = 1
    epoch = 0
    while epoch < 25:
        temp_train_loss_sum = 0.
        temp_test_loss_sum = 0.
        num_train_correct = 0
        num_test_correct = 0

        # Trained in batch fashion: here batch size = 1
        for temp_batch in range(0, len(train_features), batch_size):
            temp_train_features = train_features.narrow(0, temp_batch, batch_size)
            temp_train_labels = train_labels.narrow(0, temp_batch, batch_size)

            for i in range(batch_size):
                # Clean parameter gradients before each batch
                optimizer_2.zero_grad()
                temp_train_feature = temp_train_features[i]
                temp_train_label = temp_train_labels[i]

                # Forward pass to compute the loss
                temp_train_pred = my_model_2.forward(temp_train_feature)
                temp_train_loss = criterion_2.forward(temp_train_pred, temp_train_label)
                temp_train_loss_sum += temp_train_loss

                _, temp_train_pred_cat = torch.max(temp_train_pred, 0)
                _, temp_train_label_cat = torch.max(temp_train_label, 0)
                if temp_train_pred_cat == temp_train_label_cat:
                    num_train_correct += 1

                # Calculate gradients according to the loss gradient
                temp_train_loss_grad = criterion_2.backward(temp_train_pred, temp_train_label)
                '''
                if (not temp_train_loss_grad[0] >= 0) and (not temp_train_loss_grad[0] < 0):
                    continue
                '''
                # Accumulate parameter gradients over the batch
                my_model_2.backward(temp_train_loss_grad)

            # Update parameters by optimizer
            optimizer_2.step()

        # Evaluate the current model on the testing set
        # (only the forward pass is needed)
        for i_test in range(len(test_features)):
            temp_test_feature = test_features[i_test]
            temp_test_label = test_labels[i_test]

            temp_test_pred = my_model_2.forward(temp_test_feature)
            temp_test_loss = criterion_2.forward(temp_test_pred, temp_test_label)
            temp_test_loss_sum += temp_test_loss

            _, temp_test_pred_cat = torch.max(temp_test_pred, 0)
            _, temp_test_label_cat = torch.max(temp_test_label, 0)
            if temp_test_pred_cat == temp_test_label_cat:
                num_test_correct += 1

        temp_train_loss_mean = temp_train_loss_sum / len(train_features)
        temp_test_loss_mean = temp_test_loss_sum / len(test_features)

        temp_train_accuracy = num_train_correct / len(train_features)
        temp_test_accuracy = num_test_correct / len(test_features)

        # In case a gradient-explosion (NaN gradient) problem appears, re-initialize the
        # model and restart training; this situation seldom happens
        if (not temp_train_loss_grad[0] >= 0) and (not temp_train_loss_grad[0] < 0):
            epoch = 0
            my_model_design_2 = [Linear(2, 25), ReLU(), Linear(25, 25), Dropout(p=0.5), ReLU(),
                                 Linear(25, 25), Dropout(p=0.5), ReLU(), Linear(25, 2), Sigmoid()]
            my_model_2 = Sequential(my_model_design_2)
            optimizer_2 = Adam(my_model_2, lr=1e-3)
            criterion_2 = CrossEntropy()
            print('--------------------------------------------------------------------------------')
            print('--------------------------------------------------------------------------------')
            print('--------------------------------------------------------------------------------')
            print('--------------------------------------------------------------------------------')
            print('--------------------------------------------------------------------------------')
            print('Restart training because of gradient explosion')
            continue

        print("Epoch: {}/{}..".format(epoch + 1, 25),
              "Training Loss: {:.4f}..".format(temp_train_loss_mean),
              "Training Accuracy: {:.4f}..".format(temp_train_accuracy),
              "Validation/Test Loss: {:.4f}..".format(temp_test_loss_mean),
              "Validation/Test Accuracy: {:.4f}..".format(temp_test_accuracy))
        epoch += 1

    # Visualize the classification performance of Model 2 on the testing set
    test_pred_labels_2 = []
    for i in range(1000):
        temp_test_feature = test_features[i]
        temp_test_label = test_labels[i]
        temp_test_pred = my_model_2.forward(temp_test_feature)
        _, temp_train_pred_cat = torch.max(temp_test_pred, 0)
        if test_targets[i].int() == temp_train_pred_cat.int():
            test_pred_labels_2.append(int(test_targets[i]))
        else:
            # Misclassified points get a third colour
            test_pred_labels_2.append(2)

    fig, axes = plt.subplots(1, 1, figsize=(6, 6))
    axes.scatter(test_features[:, 0], test_features[:, 1], c=test_pred_labels_2)
    axes.set_title('Classification Performance of Model 2')
    plt.show()

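# Small illustration (not taken from the original project) of the prediction/label
# comparison used in the loops above: torch.max(x, 0) returns (values, indices) and
# the index is used as the predicted/true class of the 2-dimensional output. The
# one-hot label below is an assumption about what convert_labels() produces.
def _argmax_accuracy_sketch():
    import torch

    pred = torch.tensor([0.2, 0.8])            # network output for one sample
    one_hot_label = torch.tensor([0., 1.])     # assumed one-hot encoded label
    _, pred_cat = torch.max(pred, 0)
    _, label_cat = torch.max(one_hot_label, 0)
    return int(pred_cat == label_cat)          # 1 if the sample is correctly classified
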
def runTraining():
    print('-' * 40)
    print('~~~~~~~~ Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = 4
    batch_size_val = 1
    batch_size_val_save = 1
    lr = 0.0001
    epoch = 200
    num_classes = 2
    initial_kernels = 32
    modelName = 'IVD_Net'
    img_names_ALL = []

    print('.' * 40)
    print(" ....Model name: {} ........".format(modelName))
    print(' - Num. classes: {}'.format(num_classes))
    print(' - Num. initial kernels: {}'.format(initial_kernels))
    print(' - Batch size: {}'.format(batch_size))
    print(' - Learning rate: {}'.format(lr))
    print(' - Num. epochs: {}'.format(epoch))
    print('.' * 40)

    root_dir = '../Data/Training_PngITK'
    model_dir = 'IVD_Net'

    transform = transforms.Compose([transforms.ToTensor()])
    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset('train', root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=False, equalize=False)
    train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=5, shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val', root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)
    val_loader = DataLoader(val_set, batch_size=batch_size_val, num_workers=5, shuffle=False)
    val_loader_save_images = DataLoader(val_set, batch_size=batch_size_val_save, num_workers=5, shuffle=False)

    # Initialize
    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
    net = IVD_Net_asym(1, num_classes, initial_kernels)

    # Initialize the weights
    net.apply(weights_init)

    softMax = nn.Softmax()
    CE_loss = nn.CrossEntropyLoss()
    Dice_ = computeDiceOneHotBinary()

    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_.cuda()

    # To load a pre-trained model
    '''try:
        net = torch.load('modelName')
        print("--------model restored--------")
    except:
        print("--------model not restored--------")
        pass'''

    optimizer = Adam(net.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
                                                           patience=4, verbose=True,
                                                           factor=10 ** -0.5)

    BestDice, BestEpoch = 0, 0

    d1Train = []
    d1Val = []
    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        net.train()
        lossTrain = []
        d1TrainTemp = []

        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image_f, image_i, image_o, image_w, labels, img_names = data

            # Be sure your data here is between [0,1]
            image_f = image_f.type(torch.FloatTensor)
            image_i = image_i.type(torch.FloatTensor)
            image_o = image_o.type(torch.FloatTensor)
            image_w = image_w.type(torch.FloatTensor)

            # Binarize the ground-truth labels
            labels = labels.numpy()
            idx = np.where(labels > 0.0)
            labels[idx] = 1.0
            labels = torch.from_numpy(labels)
            labels = labels.type(torch.FloatTensor)

            optimizer.zero_grad()
            # Concatenate the four modalities along the channel dimension
            MRI = to_var(torch.cat((image_f, image_i, image_o, image_w), dim=1))
            Segmentation = to_var(labels)

            target_dice = to_var(torch.ones(1))

            net.zero_grad()

            segmentation_prediction = net(MRI)
            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            CE_loss_ = CE_loss(segmentation_prediction, Segmentation_class)

            # Compute the Dice (so far on a 2D basis)
            DicesB, DicesF = Dice_(segmentation_prediction_ones, Segmentation_planes)
            DiceB = DicesToDice(DicesB)
            DiceF = DicesToDice(DicesF)

            loss = CE_loss_

            loss.backward()
            optimizer.step()

            lossTrain.append(loss.data[0])

            printProgressBar(j + 1, totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f},".format(DiceF.data[0]))

        printProgressBar(totalImages, totalImages,
                         done="[Training] Epoch: {}, LossG: {:.4f}".format(i, np.mean(lossTrain)))

        # Save statistics
        Losses.append(np.mean(lossTrain))

        d1 = inference(net, val_loader, batch_size, i)

        d1Val.append(d1)
        d1Train.append(np.mean(d1TrainTemp).data[0])

        mainPath = '../Results/Statistics/' + modelName
        directory = mainPath

        if not os.path.exists(directory):
            os.makedirs(directory)

        np.save(os.path.join(directory, 'Losses.npy'), Losses)
        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd1Train.npy'), d1Train)

        currentDice = d1[0].numpy()

        print("[val] DSC: {:.4f} ".format(d1[0]))

        if currentDice > BestDice:
            BestDice = currentDice
            BestEpoch = i

            if currentDice > 0.75:
                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)

                torch.save(net, os.path.join(model_dir, "Best_" + modelName + ".pkl"))

                saveImages(net, val_loader_save_images, batch_size_val_save, i, modelName)

        # Two ways of decaying the learning rate:
        if i % (BestEpoch + 10):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

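# Toy shape check (illustrative, not part of the training script) for the
# multi-modal input built above: the four MR modalities (presumably fat, in-phase,
# opposed-phase and water, after the variable names image_f/i/o/w) are concatenated
# along the channel dimension before being fed to IVD_Net_asym. The spatial size
# used here is an arbitrary assumption.
def _modality_concat_sketch(batch=4, h=256, w=256):
    import torch

    image_f = torch.rand(batch, 1, h, w)
    image_i = torch.rand(batch, 1, h, w)
    image_o = torch.rand(batch, 1, h, w)
    image_w = torch.rand(batch, 1, h, w)
    mri = torch.cat((image_f, image_i, image_o, image_w), dim=1)
    assert mri.shape == (batch, 4, h, w)   # one channel per modality
    return mri
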
def train(self, targs, inps, mode='online', epochs=500, par=None, outs=None, Jrout=None, track=False):
    '''
    This is the main function of the model: it is used to train the system
    given targets and inputs. Two training modes can be selected:
    (offline, online). The first uses the exact likelihood gradient (which is
    non-local in time, and thus not biologically plausible); the second is the
    online approximation of the gradient as described in the article.

    Args:
        targs: collection of numpy.array of shape (N, T), where N is the number
               of neurons in the network and T is the time length of the sequence.
        inps : collection of numpy.array of shape (N, T) that collect the input
               signals to the neurons, one per target pattern.

    Keywords:
        mode  : (default: online) The training mode to use, either 'offline' or 'online'.
        epochs: (default: 500) The number of epochs of training.
        par   : (default: None) Optional dictionary collecting the training parameters:
                {dv, alpha, alpha_rout, beta_ro, offT}. If not provided, it defaults to
                the parameter dictionary of the model.
        outs  : (default: None) Output target trajectories, a collection of numpy.array
                of shape (K, T), where K is the dimension of the output trajectories.
                This parameter should be specified if either Jrout is not None or track is True.
        Jrout : (default: None) Pre-trained readout connection matrix. If not provided,
                a new matrix is built and trained simultaneously with the recurrent
                connections. If Jrout is provided, the outs parameter should be specified
                as it is needed to compute the output error.
        track : (default: False) Flag that signals whether to track the evolution of the
                output MSE over training epochs. If track is True, the outs parameter
                should be specified as it is needed to compute the output error.
    '''
    par = self.par if par is None else par

    dv = par['dv']

    itau_m = self.itau_m
    itau_s = self.itau_s

    sigm = self._sigm

    alpha = par['alpha']
    alpha_rout = par['alpha_rout']
    beta_ro = par['beta_ro']
    offT = par['offT']

    # Initialize the recurrent matrix: excitatory columns in the first half,
    # inhibitory columns in the second half, no self-connections
    self.J = np.random.normal(0., .1 / np.sqrt(self.N), size=(self.N, self.N))
    self.J[:, :self.N // 2] = np.abs(self.J[:, :self.N // 2])
    self.J[:, self.N // 2:] = -np.abs(self.J[:, self.N // 2:])
    np.fill_diagonal(self.J, 0.)

    # Fixed sign masks for the excitatory and inhibitory synapses
    ndxe = self.J > 0
    ndxi = self.J < 0

    track = np.zeros((epochs, 2)) if track else None

    Tmax = np.shape(targs[0])[-1]

    opt_recE = SimpleGradient(alpha=alpha)
    opt_recI = SimpleGradient(alpha=alpha)

    if Jrout is None:
        S_rout = [ut.sfilter(targ, itau=beta_ro) for targ in targs]
        J_rout = np.random.normal(0., 0.01, size=(outs[0].shape[0], self.N))
        opt = Adam(alpha=alpha_rout, drop=0.9,
                   drop_time=epochs // 5 * Tmax if mode == 'online' else 1e10)
    else:
        J_rout = Jrout

    for epoch in trange(epochs, leave=False, desc='Training {}'.format(mode)):
        if Jrout is None:
            ut.shuffle((inps, outs, targs, S_rout))
        else:
            ut.shuffle((inps, targs))

        if Jrout is None:
            # Here we train the readout
            for out, s_rout in zip(outs, S_rout):
                dJrout = (out - J_rout @ s_rout) @ s_rout.T
                J_rout = opt.step(J_rout, dJrout)

        # Here we train the network
        for inp, targ in zip(inps, targs):
            self.S[:, 0] = targ[:, 0].copy()
            self.S_hat[:, 0] = self.S[:, 0] * itau_s

            dHe = np.zeros((self.N, self.N))
            dHi = np.zeros((self.N, self.N))

            dJe = dJi = 0.

            for t in range(Tmax - 1):
                # Current excitatory part and inhibitory magnitude of J
                Jp = ndxe * self.J
                Jm = -(ndxi * self.J)

                self.S_hat[:, t] = self.S_hat[:, t - 1] * itau_s + targ[:, t] * (1. - itau_s) if t > 0 else self.S_hat[:, 0]

                # Conductance-based membrane update with reversal potentials Ve and Vi
                self.H[:, t + 1] = self.H[:, t] * (1. - itau_m) + \
                    itau_m * ((self.Ve - self.H[:, t]).reshape(self.N, 1) / abs(self.Ve) * Jp @ self.S_hat[:, t] +
                              (self.Vi - self.H[:, t]).reshape(self.N, 1) / abs(self.Vi) * Jm @ self.S_hat[:, t] +
                              inp[:, t] + self.h[:, t]) + \
                    self.Jreset @ targ[:, t]

                # Eligibility traces for the excitatory and inhibitory synapses
                dHe[...] = dHe * (1. - itau_m * (1 + ((Jp + Jm) @ self.S_hat[:, t]).reshape(self.N, 1) / abs(self.Ve))) + \
                    itau_m * ((self.Ve - self.H[:, t + 1]).reshape(self.N, 1) / abs(self.Ve) *
                              ndxe * self.S_hat[:, t].reshape(1, self.N))

                dHi[...] = dHi * (1. - itau_m * (1 + ((Jp + Jm) @ self.S_hat[:, t]).reshape(self.N, 1) / abs(self.Vi))) + \
                    itau_m * ((self.Vi - self.H[:, t + 1]).reshape(self.N, 1) / abs(self.Vi) *
                              ndxi * self.S_hat[:, t].reshape(1, self.N))

                if mode == 'online':
                    dJe = (targ[:, t + 1] - sigm(self.H[:, t + 1], dv=dv)).reshape(self.N, 1) * dHe
                    dJi = (targ[:, t + 1] - sigm(self.H[:, t + 1], dv=dv)).reshape(self.N, 1) * dHi

                    Jp = opt_recE.step(Jp, dJe * ndxe)
                    Jm = opt_recI.step(Jm, dJi * ndxi)

                    # Clip to preserve the excitatory/inhibitory sign constraints
                    Jp = np.maximum(Jp, 0.)
                    Jm = np.maximum(Jm, 0.)

                    self.J = Jp - Jm
                    np.fill_diagonal(self.J, 0.)

                elif mode == 'offline':
                    dJe += (targ[:, t + 1] - sigm(self.H[:, t + 1], dv=dv)).reshape(self.N, 1) * dHe
                    dJi += (targ[:, t + 1] - sigm(self.H[:, t + 1], dv=dv)).reshape(self.N, 1) * dHi

            if mode == 'offline':
                Jp = opt_recE.step(Jp, dJe * ndxe)
                Jm = opt_recI.step(Jm, dJi * ndxi)

                # Clip to preserve the excitatory/inhibitory sign constraints
                Jp = np.maximum(Jp, 0.)
                Jm = np.maximum(Jm, 0.)

                self.J = Jp - Jm
                np.fill_diagonal(self.J, 0.)

        # Here we track the MSE
        if track is not None:
            S_gen = self.compute(inp, init=np.zeros(self.N))
            track[epoch, 0] = np.mean((out - J_rout @ ut.sfilter(S_gen, itau=beta_ro))[:, offT:] ** 2.)
            track[epoch, 1] = np.sum(np.abs(targ - S_gen)) / (self.N * self.T)

    return (J_rout, track) if Jrout is None else track

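# Standalone toy sketch (not part of the class above) of the sign-constrained update
# pattern used in train(): the recurrent matrix is split into an excitatory part Jp
# and an inhibitory magnitude Jm through fixed masks, each part is updated with its
# own gradient, clipped at zero to preserve Dale's sign constraints, and the signed
# matrix is then reassembled. The gradients here are random stand-ins.
def _dale_constrained_update_sketch(N=8, alpha=1e-2, seed=0):
    import numpy as np

    rng = np.random.default_rng(seed)
    J = rng.normal(0., 0.1 / np.sqrt(N), size=(N, N))
    J[:, :N // 2] = np.abs(J[:, :N // 2])      # excitatory columns
    J[:, N // 2:] = -np.abs(J[:, N // 2:])     # inhibitory columns
    np.fill_diagonal(J, 0.)
    ndxe, ndxi = J > 0, J < 0                  # fixed sign masks

    dJe = rng.normal(size=(N, N))              # stand-in excitatory gradient
    dJi = rng.normal(size=(N, N))              # stand-in inhibitory gradient

    Jp = ndxe * J
    Jm = -(ndxi * J)
    Jp = np.maximum(Jp + alpha * (dJe * ndxe), 0.)   # excitatory weights stay >= 0
    Jm = np.maximum(Jm + alpha * (dJi * ndxi), 0.)   # inhibitory magnitudes stay >= 0
    J = Jp - Jm
    np.fill_diagonal(J, 0.)
    return J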