Пример #1
0
def trainNetwork(net,
                 trainData,
                 valData,
                 noiseModel,
                 postfix,
                 device,
                 directory='.',
                 numOfEpochs=200,
                 stepsPerEpoch=50,
                 batchSize=4,
                 patchSize=100,
                 learningRate=0.0001,
                 numMaskedPixels=100 * 100 / 32.0,
                 virtualBatchSize=20,
                 valSize=20,
                 augment=True,
                 supervised=False):
    '''
    Train a network using PN2V

    Parameters
    ----------
    net:
        The network we want to train.
        The number of output channels determines the number of samples that are predicted.
        Its ``mean`` and ``std`` attributes are overwritten with the statistics
        of the combined training and validation data.
    trainData: numpy array
        Our training image. A 3D array that is interpreted as a stack of 2D images.
    valData: numpy array
        Our validation image. A 3D array that is interpreted as a stack of 2D images.
    noiseModel: NoiseModel
        The noise model we will use during training.
        The PN2V loss is only selected when this is not None and ``supervised`` is False.
    postfix: string
        This identifier is attached to the names of the files that will be saved during training.
    device:
        The device we are using, e.g. a GPU or CPU
    directory: string
        The directory all files will be saved to.
    numOfEpochs: int
        Number of training epochs.
    stepsPerEpoch: int
        Number of gradient steps per epoch.
    batchSize: int
        The batch size, i.e. the number of patches processed simultaneously on the GPU.
    patchSize: int
        The width and height of the square training patches.
    learningRate: float
        The learning rate.
    numMaskedPixels: int
        The number of pixels that is to be manipulated/masked N2V style in every training patch.
    virtualBatchSize: int
        The number of batches that are processed before a gradient step is performed.
    valSize: int
        The number of validation patches processed after each epoch.
    augment: bool
        Should the patches be randomly flipped and rotated?
    supervised: bool
        Forwarded to ``trainingPred``; when True the PN2V loss is disabled
        even if a noise model is given.

    Returns
    ----------
    trainHist: list of float
        One entry per epoch: the mean training loss over the last virtual
        batch processed before the epoch ended (the per-step loss list is
        reset at every gradient step).
    valHist: list of float
        One entry per epoch: the mean validation loss after that epoch.
    '''

    # Calculate mean and std of image.
    # Everything that is processed by the net will be normalized and denormalized using these numbers.
    combined = np.concatenate((trainData, valData))
    net.mean = np.mean(combined)
    net.std = np.std(combined)

    net.to(device)

    optimizer = optim.Adam(net.parameters(), lr=learningRate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     patience=10,
                                                     factor=0.5,
                                                     verbose=True)

    stepCounter = 0
    dataCounter = 0

    trainHist = []
    valHist = []

    pn2v = (noiseModel is not None) and (not supervised)

    while stepCounter / stepsPerEpoch < numOfEpochs:  # loop over the dataset multiple times
        losses = []
        optimizer.zero_grad()
        stepCounter += 1

        # Loop over our virtual batch: gradients accumulate across these
        # backward passes and are applied in a single optimizer step.
        for _ in range(virtualBatchSize):
            outputs, labels, masks, dataCounter = trainingPred(
                trainData,
                net,
                dataCounter,
                patchSize,
                batchSize,
                numMaskedPixels,
                device,
                augment=augment,
                supervised=supervised)
            loss = lossFunction(outputs, labels, masks, noiseModel, pn2v,
                                net.std)
            loss.backward()
            losses.append(loss.item())

        optimizer.step()

        # We have reached the end of an epoch
        if stepCounter % stepsPerEpoch == stepsPerEpoch - 1:
            losses = np.array(losses)
            utils.printNow("Epoch " + str(int(stepCounter / stepsPerEpoch)) +
                           " finished")
            utils.printNow("avg. loss: " + str(np.mean(losses)) + "+-(2SEM)" +
                           str(2.0 * np.std(losses) / np.sqrt(losses.size)))
            trainHist.append(np.mean(losses))
            torch.save(net, os.path.join(directory,
                                         "last_" + postfix + ".net"))

            valCounter = 0
            # Switch the module to evaluation mode for validation; assigning
            # an ad-hoc `trainable` attribute would not change its behaviour.
            net.train(False)
            losses = []
            # No backward pass happens here, so skip building autograd graphs.
            with torch.no_grad():
                for _ in range(valSize):
                    outputs, labels, masks, valCounter = trainingPred(
                        valData,
                        net,
                        valCounter,
                        patchSize,
                        batchSize,
                        numMaskedPixels,
                        device,
                        augment=augment,
                        supervised=supervised)
                    loss = lossFunction(outputs, labels, masks, noiseModel,
                                        pn2v, net.std)
                    losses.append(loss.item())
            net.train(True)

            avgValLoss = np.mean(losses)
            # Keep a separate copy of the model with the lowest validation loss so far.
            if len(valHist) == 0 or avgValLoss < np.min(np.array(valHist)):
                torch.save(net,
                           os.path.join(directory, "best_" + postfix + ".net"))
            valHist.append(avgValLoss)

            # Let the plateau scheduler see the validation loss; without this
            # call the learning rate would never be reduced.
            scheduler.step(avgValLoss)

            epoch = (stepCounter / stepsPerEpoch)
            np.save(os.path.join(directory, "history" + postfix + ".npy"),
                    (np.array([np.arange(epoch), trainHist, valHist])))

    utils.printNow('Finished Training')
    return trainHist, valHist
Пример #2
0
def trainNetwork(
    net,
    trainData,
    valData,
    te_Data_target,
    te_Data_source,
    noiseModel,
    postfix,
    device,
    directory='.',
    numOfEpochs=100,
    stepsPerEpoch=400,
    batchSize=128,
    patchSize=64,
    learningRate=0.0004,
    numMaskedPixels=64,
    virtualBatchSize=20,
    valSize=20,
    augment=True,
    supervised=False,
    save_file_name=None,
):
    '''
    Train a network using PN2V and evaluate it on a test set after every epoch.

    After each epoch the per-epoch metrics collected so far are written to
    '../../result_data/<save_file_name>_result' (MATLAB format) and the
    network weights to '../../weights/<save_file_name>.w'.

    Parameters
    ----------
    net:
        The network we want to train.
        The number of output channels determines the number of samples that are predicted.
        Its ``mean`` and ``std`` attributes are overwritten with the statistics
        of the training data.
    trainData: numpy array
        Our training data. A 3D array that is interpreted as a stack of 2D images.
    valData: numpy array
        Our validation data. A 3D array that is interpreted as a stack of 2D images.
    te_Data_target: numpy array
        Ground-truth test images, a 3D array indexed as (image, axis1, axis2).
        Used as the reference for PSNR and SSIM.
    te_Data_source: numpy array
        Noisy test images, indexed like ``te_Data_target``; these are denoised
        with tiled prediction.
    noiseModel: NoiseModel
        The noise model we will use during training.
        It is not used for the test-set prediction, which always passes None.
    postfix: string
        Used as the output file name when ``save_file_name`` is None.
    device:
        The device we are using, e.g. a GPU or CPU
    directory: string
        Currently unused; outputs go to the fixed relative paths named above.
    numOfEpochs: int
        Number of training epochs.
    stepsPerEpoch: int
        Number of gradient steps per epoch.
    batchSize: int
        The batch size, i.e. the number of patches processed simultaneously on the GPU.
    patchSize: int
        The width and height of the square training patches.
    learningRate: float
        The learning rate.
    numMaskedPixels: int
        The number of pixels that is to be manipulated/masked N2V style in every training patch.
    virtualBatchSize: int
        The number of batches that are processed before a gradient step is performed.
    valSize: int
        The number of validation patches processed after each epoch.
    augment: bool
        Should the patches be randomly flipped and rotated?
    supervised: bool
        Forwarded to ``trainingPred``; when True the PN2V loss is disabled
        even if a noise model is given.
    save_file_name: string or None
        Base name of the result and weight files. Falls back to ``postfix``
        when None.

    Returns
    ----------
    None
        All results are written to disk.
    '''

    # The default of None would otherwise only fail at the first save, after
    # a whole epoch of training has already been spent.
    if save_file_name is None:
        save_file_name = postfix

    # Calculate mean and std of data.
    # Everything that is processed by the net will be normalized and denormalized using these numbers.
    combined = np.concatenate(trainData)
    net.mean = np.mean(combined)
    net.std = np.std(combined)

    net.to(device)

    optimizer = optim.Adam(net.parameters(), lr=learningRate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     patience=10,
                                                     factor=0.5,
                                                     verbose=True)

    stepCounter = 0
    dataCounter = 0

    # One entry per finished epoch.
    result_psnr_arr = []
    result_ssim_arr = []
    result_time_arr = []
    result_tr_loss_arr = []

    pn2v = (noiseModel is not None) and (not supervised)

    while stepCounter / stepsPerEpoch < numOfEpochs:  # loop over the dataset multiple times

        losses = []
        optimizer.zero_grad()
        stepCounter += 1

        # Loop over our virtual batch: gradients accumulate across these
        # backward passes and are applied in a single optimizer step.
        for _ in range(virtualBatchSize):
            outputs, labels, masks, dataCounter = trainingPred(
                trainData,
                net,
                dataCounter,
                patchSize,
                batchSize,
                numMaskedPixels,
                device,
                augment=augment,
                supervised=supervised)
            loss = lossFunction(outputs, labels, masks, noiseModel, pn2v,
                                net.std)
            loss.backward()
            losses.append(loss.item())

        optimizer.step()

        # We have reached the end of an epoch
        if stepCounter % stepsPerEpoch == stepsPerEpoch - 1:
            # `losses` only holds the virtual batch of the current step, so
            # this is the mean training loss of the epoch's last step.
            losses = np.array(losses)
            trainLoss = np.mean(losses)
            utils.printNow("Epoch " + str(int(stepCounter / stepsPerEpoch)) +
                           " finished")
            utils.printNow("avg. loss: " + str(trainLoss) + "+-(2SEM)" +
                           str(2.0 * np.std(losses) / np.sqrt(losses.size)))

            valCounter = 0
            net.train(False)
            valLosses = []
            # No backward pass happens here, so skip building autograd graphs.
            with torch.no_grad():
                for _ in range(valSize):
                    outputs, labels, masks, valCounter = trainingPred(
                        valData,
                        net,
                        valCounter,
                        patchSize,
                        batchSize,
                        numMaskedPixels,
                        device,
                        augment=augment,
                        supervised=supervised)
                    loss = lossFunction(outputs, labels, masks, noiseModel,
                                        pn2v, net.std)
                    valLosses.append(loss.item())

            PSNR_arr = []
            SSIM_arr = []
            denoised_img_arr = []
            time_arr = []

            # Crop each spatial axis to an even length, using that axis' own
            # parity (previously the first axis' remainder was applied to both).
            _, dim1, dim2 = te_Data_target.shape
            evenDim1 = dim1 - dim1 % 2
            evenDim2 = dim2 - dim2 % 2

            for index in range(te_Data_target.shape[0]):

                start = time.time()

                im = te_Data_source[index, :evenDim1, :evenDim2]
                gt = te_Data_target[index, :evenDim1, :evenDim2]

                # We are using tiling to fit the image into memory
                # If you get an error try a smaller patch size (ps)
                n2vResult = prediction.tiledPredict(im,
                                                    net,
                                                    ps=256,
                                                    overlap=48,
                                                    device=device,
                                                    noiseModel=None)

                time_arr.append(time.time() - start)
                denoised_img_arr.append(n2vResult)

                # PSNR is computed relative to the value range of the ground truth.
                rangePSNR = np.max(gt) - np.min(gt)
                PSNR_arr.append(PSNR(gt, n2vResult, rangePSNR))
                SSIM_arr.append(get_SSIM(gt, n2vResult))

            mean_psnr = np.mean(PSNR_arr)
            mean_ssim = np.mean(SSIM_arr)
            mean_time = np.mean(time_arr)

            result_psnr_arr.append(mean_psnr)
            result_ssim_arr.append(mean_ssim)
            result_time_arr.append(mean_time)
            result_tr_loss_arr.append(trainLoss)

            net.train(True)
            avgValLoss = np.mean(valLosses)
            scheduler.step(avgValLoss)

            print('Tr loss : ', round(trainLoss, 4), ' PSNR : ',
                  round(mean_psnr, 2), ' SSIM : ', round(mean_ssim, 4),
                  ' Best Time : ', round(mean_time, 2))

            # Only the first ten denoised images of the current epoch are
            # stored, which keeps the result file small.
            sio.savemat(
                '../../result_data/' + save_file_name + '_result', {
                    'tr_loss_arr': result_tr_loss_arr,
                    'psnr_arr': result_psnr_arr,
                    'ssim_arr': result_ssim_arr,
                    'time_arr': result_time_arr,
                    'denoised_img': denoised_img_arr[:10]
                })
            torch.save(net.state_dict(),
                       '../../weights/' + save_file_name + '.w')

    utils.printNow('Finished Training')
    return