def trainNetwork(net, trainData, valData, noiseModel, postfix, device,
                 directory='.',
                 numOfEpochs=200, stepsPerEpoch=50,
                 batchSize=4, patchSize=100, learningRate=0.0001,
                 numMaskedPixels=100 * 100 / 32.0,
                 virtualBatchSize=20, valSize=20,
                 augment=True,
                 supervised=False):
    '''
    Train a network using PN2V.

    Parameters
    ----------
    net:
        The network we want to train.
        The number of output channels determines the number of samples that are predicted.
        Its `mean` and `std` attributes are overwritten with the statistics of
        trainData and valData combined.
    trainData: numpy array
        Our training data. A 3D array that is interpreted as a stack of 2D images.
    valData: numpy array
        Our validation data. A 3D array that is interpreted as a stack of 2D images.
    noiseModel: NoiseModel or None
        The noise model we will use during training. It only enters the loss when
        it is not None and `supervised` is False.
    postfix: string
        This identifier is attached to the names of the files that will be saved during training.
    device:
        The device we are using, e.g. a GPU or CPU.
    directory: string
        The directory all files will be saved to.
    numOfEpochs: int
        Number of training epochs.
    stepsPerEpoch: int
        Number of gradient steps per epoch.
    batchSize: int
        The batch size, i.e. the number of patches processed simultaneously on the GPU.
    patchSize: int
        The width and height of the square training patches.
    learningRate: float
        The initial learning rate. A ReduceLROnPlateau scheduler halves it when the
        validation loss has not improved for 10 epochs.
    numMaskedPixels: int or float
        The number of pixels that is to be manipulated/masked N2V style in every
        training patch (the default, 100*100/32.0, is a float).
    virtualBatchSize: int
        The number of batches that are processed before a gradient step is performed.
    valSize: int
        The number of validation batches evaluated after each epoch.
    augment: bool
        Should the patches be randomly flipped and rotated?
    supervised: bool
        Passed through to trainingPred. When True the noise model is not used in the loss.

    Returns
    ----------
    trainHist: list of float
        One entry per epoch: the mean training loss over the virtual batch of the
        last gradient step of that epoch.
    valHist: list of float
        One entry per epoch: the mean validation loss after that epoch.

    Side effects
    ----------
    At the end of every epoch it writes, inside `directory`:
    "last_<postfix>.net" (the whole network), "best_<postfix>.net" (whenever the
    validation loss is a new minimum), and "history<postfix>.npy"
    (epochs, trainHist, valHist).
    '''
    # Calculate mean and std of the data.
    # Everything that is processed by the net will be normalized and denormalized using these numbers.
    combined = np.concatenate((trainData, valData))
    net.mean = np.mean(combined)
    net.std = np.std(combined)

    net.to(device)

    optimizer = optim.Adam(net.parameters(), lr=learningRate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, verbose=True)

    stepCounter = 0
    dataCounter = 0

    trainHist = []
    valHist = []

    # The noise model only enters the loss in the self-supervised PN2V setting.
    pn2v = (noiseModel is not None) and (not supervised)

    while stepCounter / stepsPerEpoch < numOfEpochs:  # loop over the dataset multiple times
        losses = []
        optimizer.zero_grad()
        stepCounter += 1

        # Accumulate gradients over a virtual batch, then take a single optimizer step.
        for _ in range(virtualBatchSize):
            outputs, labels, masks, dataCounter = trainingPred(
                trainData, net, dataCounter, patchSize, batchSize,
                numMaskedPixels, device, augment=augment, supervised=supervised)
            loss = lossFunction(outputs, labels, masks, noiseModel, pn2v, net.std)
            loss.backward()
            losses.append(loss.item())

        optimizer.step()

        # We have reached the end of an epoch. Note: with this condition the first
        # epoch ends after stepsPerEpoch - 1 steps; kept so checkpoint timing is unchanged.
        if stepCounter % stepsPerEpoch == stepsPerEpoch - 1:
            losses = np.array(losses)
            utils.printNow("Epoch " + str(int(stepCounter / stepsPerEpoch)) + " finished")
            utils.printNow("avg. loss: " + str(np.mean(losses)) + "+-(2SEM)"
                           + str(2.0 * np.std(losses) / np.sqrt(losses.size)))
            trainHist.append(np.mean(losses))
            torch.save(net, os.path.join(directory, "last_" + postfix + ".net"))

            # Validation: put the module into eval mode (affects layers such as
            # dropout/batch-norm, if the net has any) and skip autograd bookkeeping,
            # since no backward pass is performed here.
            valCounter = 0
            net.train(False)
            losses = []
            with torch.no_grad():
                for _ in range(valSize):
                    outputs, labels, masks, valCounter = trainingPred(
                        valData, net, valCounter, patchSize, batchSize,
                        numMaskedPixels, device, augment=augment, supervised=supervised)
                    loss = lossFunction(outputs, labels, masks, noiseModel, pn2v, net.std)
                    losses.append(loss.item())
            net.train(True)

            avgValLoss = np.mean(losses)
            if len(valHist) == 0 or avgValLoss < min(valHist):
                torch.save(net, os.path.join(directory, "best_" + postfix + ".net"))
            valHist.append(avgValLoss)

            # Feed the validation loss to the plateau scheduler; without this call
            # the scheduler created above never reduces the learning rate.
            scheduler.step(avgValLoss)

            epoch = stepCounter / stepsPerEpoch
            np.save(os.path.join(directory, "history" + postfix + ".npy"),
                    np.array([np.arange(epoch), trainHist, valHist]))

    utils.printNow('Finished Training')
    return trainHist, valHist
def _evaluateTestSet(net, te_Data_source, te_Data_target, device):
    '''
    Denoise every test image with `net` and score it against its target.

    Each image pair is cropped so both spatial dimensions are even, denoised with
    prediction.tiledPredict (no noise model), and scored with PSNR (using the
    target's value range) and get_SSIM.

    Returns
    ----------
    (PSNR_arr, SSIM_arr, time_arr, denoised_img_arr): four lists with one entry per
    test image; time_arr holds the wall-clock seconds spent per image.
    '''
    PSNR_arr = []
    SSIM_arr = []
    time_arr = []
    denoised_img_arr = []

    _, w, h = te_Data_target.shape
    # Crop each spatial dimension independently to an even size.
    wEven = w - w % 2
    hEven = h - h % 2

    for index in range(te_Data_target.shape[0]):
        start = time.time()
        im = te_Data_source[index, :wEven, :hEven]
        gt = te_Data_target[index, :wEven, :hEven]  # per-image ground truth

        # We are using tiling to fit the image into memory.
        # If you get an error try a smaller patch size (ps).
        n2vResult = prediction.tiledPredict(im, net, ps=256, overlap=48, device=device, noiseModel=None)
        time_arr.append(time.time() - start)
        denoised_img_arr.append(n2vResult)

        rangePSNR = np.max(gt) - np.min(gt)
        PSNR_arr.append(PSNR(gt, n2vResult, rangePSNR))
        SSIM_arr.append(get_SSIM(gt, n2vResult))

    return PSNR_arr, SSIM_arr, time_arr, denoised_img_arr


def trainNetwork(
        net,
        trainData,
        valData,
        te_Data_target,
        te_Data_source,
        noiseModel,
        postfix,
        device,
        directory='.',
        numOfEpochs=100,
        stepsPerEpoch=400,
        batchSize=128,
        patchSize=64,
        learningRate=0.0004,
        numMaskedPixels=64,
        virtualBatchSize=20,
        valSize=20,
        augment=True,
        supervised=False,
        save_file_name=None,
):
    '''
    Train a network using PN2V and evaluate it on a test set after every epoch.

    Parameters
    ----------
    net:
        The network we want to train.
        The number of output channels determines the number of samples that are predicted.
        Its `mean` and `std` attributes are overwritten with the statistics of trainData.
    trainData: numpy array
        Our training data. A 3D array that is interpreted as a stack of 2D images.
        Only this array is used for the normalization statistics.
    valData: numpy array
        Our validation data. A 3D array that is interpreted as a stack of 2D images.
    te_Data_target: numpy array
        Test ground-truth images, indexed as [image, row, column].
    te_Data_source: numpy array
        Noisy test images, indexed like te_Data_target.
    noiseModel: NoiseModel or None
        The noise model we will use during training. It only enters the loss when
        it is not None and `supervised` is False.
    postfix: string
        Identifier for the saved files; used as the file name when save_file_name is None.
    device:
        The device we are using, e.g. a GPU or CPU.
    directory: string
        Not used by this function; results are written to '../../result_data/' and
        weights to '../../weights/' (relative to the current working directory).
    numOfEpochs: int
        Number of training epochs.
    stepsPerEpoch: int
        Number of gradient steps per epoch.
    batchSize: int
        The batch size, i.e. the number of patches processed simultaneously on the GPU.
    patchSize: int
        The width and height of the square training patches.
    learningRate: float
        The initial learning rate. A ReduceLROnPlateau scheduler halves it when the
        validation loss has not improved for 10 epochs.
    numMaskedPixels: int
        The number of pixels that is to be manipulated/masked N2V style in every training patch.
    virtualBatchSize: int
        The number of batches that are processed before a gradient step is performed.
    valSize: int
        The number of validation batches evaluated after each epoch.
    augment: bool
        Should the patches be randomly flipped and rotated?
    supervised: bool
        Passed through to trainingPred. When True the noise model is not used in the loss.
    save_file_name: string or None
        Base name of the saved result (.mat) and weight (.w) files. Defaults to postfix.

    Returns
    ----------
    None. At the end of every epoch it prints the training loss, mean test PSNR,
    SSIM and inference time. It also saves '../../result_data/<save_file_name>_result'
    (loss/PSNR/SSIM/time histories and up to 10 denoised test images) and
    '../../weights/<save_file_name>.w' (the net's state_dict).
    '''
    if save_file_name is None:
        save_file_name = postfix

    # Calculate mean and std of the training data (np.concatenate joins the
    # images along the first axis). Everything that is processed by the net
    # will be normalized and denormalized using these numbers.
    combined = np.concatenate(trainData)
    net.mean = np.mean(combined)
    net.std = np.std(combined)

    net.to(device)

    optimizer = optim.Adam(net.parameters(), lr=learningRate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, verbose=True)

    stepCounter = 0
    dataCounter = 0

    result_psnr_arr = []
    result_ssim_arr = []
    result_time_arr = []
    result_denoised_img_arr = []
    result_tr_loss_arr = []

    # The noise model only enters the loss in the self-supervised PN2V setting.
    pn2v = (noiseModel is not None) and (not supervised)

    while stepCounter / stepsPerEpoch < numOfEpochs:  # loop over the dataset multiple times
        losses = []
        optimizer.zero_grad()
        stepCounter += 1

        # Accumulate gradients over a virtual batch, then take a single optimizer step.
        for _ in range(virtualBatchSize):
            outputs, labels, masks, dataCounter = trainingPred(
                trainData, net, dataCounter, patchSize, batchSize,
                numMaskedPixels, device, augment=augment, supervised=supervised)
            loss = lossFunction(outputs, labels, masks, noiseModel, pn2v, net.std)
            loss.backward()
            losses.append(loss.item())

        optimizer.step()

        # Only the last step of an epoch continues past here. Note: with this
        # condition the first epoch ends after stepsPerEpoch - 1 steps.
        if stepCounter % stepsPerEpoch != stepsPerEpoch - 1:
            continue

        # Training loss reported for the epoch: mean over the virtual batch of
        # this (last) gradient step.
        losses = np.array(losses)
        trainLoss = np.mean(losses)
        utils.printNow("Epoch " + str(int(stepCounter / stepsPerEpoch)) + " finished")
        utils.printNow("avg. loss: " + str(trainLoss) + "+-(2SEM)"
                       + str(2.0 * np.std(losses) / np.sqrt(losses.size)))

        # Validation and test inference: eval mode, and no autograd graph since
        # no backward pass is performed.
        net.train(False)
        valCounter = 0
        valLosses = []
        with torch.no_grad():
            for _ in range(valSize):
                outputs, labels, masks, valCounter = trainingPred(
                    valData, net, valCounter, patchSize, batchSize,
                    numMaskedPixels, device, augment=augment, supervised=supervised)
                loss = lossFunction(outputs, labels, masks, noiseModel, pn2v, net.std)
                valLosses.append(loss.item())

            PSNR_arr, SSIM_arr, time_arr, denoised_img_arr = _evaluateTestSet(
                net, te_Data_source, te_Data_target, device)
        net.train(True)

        mean_psnr = np.mean(PSNR_arr)
        mean_ssim = np.mean(SSIM_arr)
        mean_time = np.mean(time_arr)

        result_denoised_img_arr = denoised_img_arr
        result_tr_loss_arr.append(trainLoss)
        result_psnr_arr.append(mean_psnr)
        result_ssim_arr.append(mean_ssim)
        result_time_arr.append(mean_time)

        scheduler.step(np.mean(valLosses))

        print('Tr loss : ', round(trainLoss, 4), ' PSNR : ', round(mean_psnr, 2), ' SSIM : ',
              round(mean_ssim, 4), ' Best Time : ', round(mean_time, 2))

        sio.savemat(
            '../../result_data/' + save_file_name + '_result',
            {
                'tr_loss_arr': result_tr_loss_arr,
                'psnr_arr': result_psnr_arr,
                'ssim_arr': result_ssim_arr,
                'time_arr': result_time_arr,
                'denoised_img': result_denoised_img_arr[:10],
            })
        torch.save(net.state_dict(), '../../weights/' + save_file_name + '.w')

    utils.printNow('Finished Training')