def train_epoch(epoch, NUM_Of_Nets, device, startpic, nets, optimizer, losses,
                train_data):
    '''
    Trains a cascade of networks for one epoch.

    For every batch the estimate starts at ``startpic``. At each step the
    current network receives the difference between the Fourier magnitude
    of the current estimate and that of the target. Its output is added to
    the estimate on even steps and subtracted on odd steps, and the result
    is clamped to [0, 1]. Each network has its own loss and optimizer, and
    the estimate is detached before every step so gradients never flow
    from one network into an earlier one.

    Parameters:
    -----------
        epoch: int
            Current epoch number (only used in the progress print)
        NUM_Of_Nets: int
            Number of networks applied in sequence; must be >= 1
        device: torch.device
            Device to train on
        startpic: torch.Tensor
            Start image fed to the first network (presumably broadcastable
            to the target batch -- confirm against the caller)
        nets: nn.Module array
            Networks to train, indexed by step
        optimizer: torch.optim array
            One optimizer per network
        losses: (loss-)function array
            One loss per network, called as ``loss(estimate, target)``
        train_data: dataloader
            Dataloader of training data
    Returns:
    --------
        -
        Prints the final-stage loss averaged over the batches.
    Raises:
    -------
        ValueError
            If ``NUM_Of_Nets`` is smaller than 1.
    '''
    if NUM_Of_Nets < 1:
        raise ValueError(
            'NUM_Of_Nets must be at least 1, got {}'.format(NUM_Of_Nets))

    for net in nets:
        net.train()

    # startpic is never modified in place, so it only has to be moved once.
    start = startpic.to(device)
    tot_loss = 0

    for target in train_data:
        target = target.to(device)
        x = start
        measurement = magnitude(fft2WoSq(target)).to(device)

        for step in range(NUM_Of_Nets):
            # Cut the graph so each network is trained on its own step only.
            x = x.detach()
            delta_y = magnitude(fft2WoSq(x)) - measurement

            optimizer[step].zero_grad()
            out = nets[step](delta_y)

            # Alternating addition (even steps) and subtraction (odd steps).
            sign = 1.0 if step % 2 == 0 else -1.0
            x = torch.clamp(x + out * sign, 0, 1)

            loss = losses[step](x, target)
            loss.backward()
            optimizer[step].step()

        # Only the loss of the final stage is accumulated per batch.
        tot_loss = tot_loss + loss.item()

    n_batches = len(train_data)
    mean_loss = tot_loss / n_batches if n_batches else 0.0
    print('Epoche: {:3.0f} | Loss: {:.6f}'.format(epoch, mean_loss))
コード例 #2
0
def downsamplingsaveToPDF(nets, dataloader, NUM_Of_Nets, device, im_size,
                          idx=4):
    '''
    Saves target, Fourier views and per-stage reconstructions as PNGs.

    Takes the first batch of ``dataloader``, writes the sample at ``idx``
    and its (centred) real part, imaginary part, phase and magnitude to
    ``plots/``, then runs the downsampling cascade. The first network
    receives the target's Fourier magnitude. Every later network also
    receives the previous reconstruction, resized to ``im_size[-1]``. The
    per-stage target and reconstruction are saved for every step.

    Parameters:
    -----------
        nets: nn.module array
            Networks to test
        dataloader: dataloader
            Dataloader of data to use (only the first batch is used)
        NUM_Of_Nets: int
            Number of networks to test
        device: torch.device
            Device to test on
        im_size: int array
            Array of reconstruction sizes, one per stage
        idx: int, optional
            Index of the sample in the batch to save (default 4); the batch
            must contain more than ``idx`` samples
    Returns:
    --------
        -
    Raises:
    -------
        ValueError
            If ``dataloader`` yields no batches.
    '''
    for net in nets:
        net.eval()

    try:
        target = next(iter(dataloader)).to(device)
    except StopIteration:
        raise ValueError('dataloader yielded no batches') from None

    with torch.no_grad():
        torchvision.utils.save_image(target[idx], 'plots/Target.png')

        target_ft = fft2WoSq(target)
        # Shift only the spatial axes. A plain fftshift would also roll the
        # batch axis (so [idx] would pick another sample) and the trailing
        # real/imag axis (swapping real and imaginary parts).
        # Assumes fft2WoSq returns (..., H, W, 2), as the [..., 0] / [..., 1]
        # indexing below implies -- confirm if that helper changes.
        ft = torch.from_numpy(
            np.fft.fftshift(target_ft.cpu().numpy(), axes=(-3, -2)))
        real = ft[..., 0][idx]
        torchvision.utils.save_image(real / torch.max(real), 'plots/Real.png')
        img = ft[..., 1][idx]
        torchvision.utils.save_image(img / torch.max(img), 'plots/Img.png')
        phs = phase(ft)[idx]
        norm_magn = magnitude(ft)[idx]
        torchvision.utils.save_image(phs / torch.max(phs), 'plots/Phase.png')
        torchvision.utils.save_image(norm_magn / torch.max(norm_magn),
                                     'plots/Magnitude.png')

        # The network input depends only on the batch, not on the stage.
        measurement = magnitude(target_ft).to(device)
        out = None
        for step in range(NUM_Of_Nets):
            stage_target = F.interpolate(target,
                                         size=(im_size[step], im_size[step]))
            path = 'plots/StageTarget{}.png'.format(step)
            torchvision.utils.save_image(stage_target[idx], path)

            data = measurement
            if step > 0:
                # Later stages also see the previous reconstruction.
                out = F.interpolate(out, size=(im_size[-1], im_size[-1]))
                data = torch.cat([data, out], 1)
            out = nets[step](data)
            path = 'plots/Reconstruction{}.png'.format(step)
            torchvision.utils.save_image(out[idx], path)
コード例 #3
0
def DFPRsaveToPDF(nets, dataloader, startpic, NUM_Of_Nets, device, idx=4):
    '''
    Saves target, measurement and per-step reconstructions as PNGs.

    Takes the first batch of ``dataloader`` and writes the sample at
    ``idx`` to ``plots/``: the target, its Fourier magnitude and phase, and
    the start image. It then runs the iterative cascade. Each network
    receives the magnitude difference to the measurement concatenated with
    the phase of the current estimate. Its output is added on even steps
    and subtracted on odd steps, and the estimate is clamped to [0, 1]. The
    estimate is saved after every step.

    Parameters:
    -----------
        nets: nn.module array
            Networks to test
        dataloader: dataloader
            Dataloader of data to use (only the first batch is used)
        startpic: float array size of images
            Start image to feed to first network
        NUM_Of_Nets: int
            Number of networks to test
        device: torch.device
            Device to test on
        idx: int, optional
            Index of the sample in the batch to save (default 4); the batch
            must contain more than ``idx`` samples
    Returns:
    --------
        -
    Raises:
    -------
        ValueError
            If ``dataloader`` yields no batches.
    '''
    for net in nets:
        net.eval()

    try:
        target = next(iter(dataloader)).to(device)
    except StopIteration:
        raise ValueError('dataloader yielded no batches') from None

    # Pure visualization: no autograd graph is needed.
    with torch.no_grad():
        torchvision.utils.save_image(target[idx], 'plots/Target.png')
        x = startpic.to(device)
        target_ft = fft2WoSq(target)
        measurement = magnitude(target_ft).to(device)
        torchvision.utils.save_image(measurement[idx], 'plots/Magnitude.png')
        phs = phase(target_ft)
        torchvision.utils.save_image(phs[idx], 'plots/Phase.png')
        torchvision.utils.save_image(x[idx], 'plots/startpic.png')

        for step in range(NUM_Of_Nets):
            fourier_t = fft2WoSq(x)
            phase_x = phase(fourier_t)
            delta_y = magnitude(fourier_t) - measurement
            data = torch.cat([delta_y, phase_x], 1)
            out = nets[step](data)

            # Alternating addition (even steps) and subtraction (odd steps).
            sign = 1.0 if step % 2 == 0 else -1.0
            x = torch.clamp(x + out * sign, 0, 1)

            path = 'plots/Reconstruction{}.png'.format(step)
            torchvision.utils.save_image(x[idx], path)
コード例 #4
0
def downsampling_test(nets, test_data, NUM_Of_Nets, device, im_size, pics, save):
    '''
    Tests and benchmarks the downsampling cascade of networks.

    Each test batch goes through the cascade. The first network receives
    the target's Fourier magnitude. Every later network also receives the
    previous reconstruction, resized to ``im_size[-1]``. The last network's
    outputs are collected and compared to the targets with ``benchmark``.

    Parameters:
    -----------
        nets: nn.module array
            Networks to test
        test_data: dataloader
            Dataloader of test data
        NUM_Of_Nets: int
            Number of networks to test; must be >= 1
        device: torch.device
            Device to test on
        im_size: int array
            Array of reconstruction sizes
        pics: boolean
            Show grid of the first 8 targets and reconstructions
        save: boolean
            Save grid of reconstruction to png
    Returns:
    --------
        -
    Raises:
    -------
        ValueError
            If ``NUM_Of_Nets`` is smaller than 1 or ``test_data`` is empty.
    '''
    if NUM_Of_Nets < 1:
        raise ValueError(
            'NUM_Of_Nets must be at least 1, got {}'.format(NUM_Of_Nets))

    print("========================================")
    print("Running Tests")
    true_batches = []
    pred_batches = []
    for net in nets:
        net.eval()
    with torch.no_grad():
        for target in test_data:
            true_batches.append(target.cpu())
            target = target.to(device)
            # The network input depends only on the batch, not on the stage.
            measurement = magnitude(fft2WoSq(target))
            out = None
            for step in range(NUM_Of_Nets):
                data = measurement
                if step > 0:
                    out = F.interpolate(out, size=(im_size[-1], im_size[-1]))
                    data = torch.cat([data, out], 1)
                out = nets[step](data)
            pred_batches.append(out.cpu())

    if not true_batches:
        raise ValueError('test_data yielded no batches')

    # Concatenate once instead of re-copying a growing tensor every batch.
    true = torch.cat(true_batches).numpy()
    pred = torch.cat(pred_batches).numpy()

    if pics:
        plot_grid(true[:8], figsize = (15,15), grid_size = 8)
        plot_grid(pred[:8], figsize = (15,15), grid_size = 8, save = save)

    benchmark(pred, true, check_all = True)
    print("========================================")
def train_epoch(epoch, NUM_Of_Nets, device, im_size, beta, nets, optimizer,
                losses, train_data):
    '''
    Trains the downsampling cascade of networks for one epoch.

    Stage ``step`` is trained to reproduce the target resized to
    ``im_size[step]``. The first network receives the target's Fourier
    magnitude. Every later network also receives the previous stage's
    output, resized to ``im_size[-1]``. The input is detached, so every
    network is optimised independently with its own loss and optimizer.

    Parameters:
    -----------
        epoch: int
            Current epoch number (only used in the progress print)
        NUM_Of_Nets: int
            Number of networks; must be >= 1
        device: torch.device
            Device to train on
        im_size: int array
            Array of reconstruction sizes, one per stage
        beta: float
            Weight of the regulariser penalising the squared difference
            between the squared Fourier magnitudes of output and stage
            target; disabled when ``beta <= 0``
        nets: nn.module array
            Array of networks to train
        optimizer: torch.optim array
            One optimizer per network
        losses: (loss-)function array
            One loss per network, called as ``loss(output, stage_target)``
        train_data: dataloader
            Dataloader of training data
    Returns:
    --------
        -
        Prints the final-stage loss averaged over the batches.
    Raises:
    -------
        ValueError
            If ``NUM_Of_Nets`` is smaller than 1.
    '''
    if NUM_Of_Nets < 1:
        raise ValueError(
            'NUM_Of_Nets must be at least 1, got {}'.format(NUM_Of_Nets))

    for net in nets:
        net.train()

    tot_loss = 0

    for target in train_data:
        target = target.to(device)
        # The network input depends only on the batch, not on the stage.
        measurement = magnitude(fft2WoSq(target)).to(device)

        for step in range(NUM_Of_Nets):
            stage_target = F.interpolate(target,
                                         size=(im_size[step], im_size[step]))
            data = measurement
            if step > 0:
                out = F.interpolate(out, size=(im_size[-1], im_size[-1]))
                data = torch.cat([data, out], 1)

            optimizer[step].zero_grad()
            # Detach so gradients never reach earlier networks.
            out = nets[step](data.detach())

            loss = losses[step](out, stage_target)

            if beta > 0:
                mags = squaredmagnitude(fft2(stage_target)).detach()
                reg = beta * torch.mean(
                    (squaredmagnitude(fft2(out)) - mags)**2)
                loss = loss + reg

            loss.backward()
            optimizer[step].step()

        # Only the loss of the final stage is accumulated per batch.
        tot_loss = tot_loss + loss.item()

    n_batches = len(train_data)
    mean_loss = tot_loss / n_batches if n_batches else 0.0
    print('Epoche: {:3.0f} | Loss: {:.6f}'.format(epoch, mean_loss))