def train_epoch(epoch, NUM_Of_Nets, device, startpic, nets, optimizer, losses, train_data):
    '''
    Trains networks for one epoch.

    For every batch, the reconstruction starts at ``startpic``. Each network
    in turn receives the difference between the Fourier magnitude of the
    current estimate and that of the target. Its output is added to the
    estimate (even steps) or subtracted from it (odd steps), and the result
    is clamped to [0, 1]. Each network is optimized only on its own step:
    the estimate is detached before every step.

    NOTE(review): another function named ``train_epoch`` with a different
    signature appears later in this file and would shadow this one if both
    live in the same module -- confirm which is intended.

    Parameters:
    -----------
        epoch: int
            Current epoch number (used only for the printed log line)
        NUM_Of_Nets: int
            Number of networks
        device: torch.device
            Device to train on
        startpic: float array
            Start image to feed to first network
        nets: nn.module array
            Array of networks to train
        optimizer: torch.optim array
            Array of optimizer for training, one per network
        losses: (loss-)function array
            Array of losses used to train each network
        train_data: dataloader
            Dataloader of training data

    Returns:
    --------
        -
    '''
    tot_loss = 0
    for net in nets:
        net.train()

    # The start image is identical for every batch. Transfer it to the
    # device once instead of once per batch. No in-place ops touch it below.
    start = startpic.to(device)

    for target in train_data:
        target = target.to(device)
        x = start
        measurement = magnitude(fft2WoSq(target)).to(device)
        for step in range(NUM_Of_Nets):
            # Cut the graph so each network is trained only on its own step.
            x = x.detach()
            delta_y = magnitude(fft2WoSq(x)) - measurement
            optimizer[step].zero_grad()
            out = nets[step](delta_y)
            # Alternating addition and subtraction: +1 on even steps, -1 on odd ones.
            out = out * ((0.5 - (step % 2)) * 2)
            x = torch.clamp(x + out, 0, 1)
            criterion = losses[step]
            loss = criterion(x, target)
            loss.backward()
            optimizer[step].step()
            tot_loss = tot_loss + loss.item()
    print('Epoche: {:3.0f} | Loss: {:.6f}'.format(epoch, tot_loss / len(train_data)))
def downsamplingsaveToPDF(nets, dataloader, NUM_Of_Nets, device, im_size, idx=4):
    '''
    Saves a target image, its (shifted) Fourier representation, and the
    per-stage targets and reconstructions of the downsampling networks as
    PNG files under ``plots/``.

    Parameters:
    -----------
        nets: nn.module array
            Networks to test
        dataloader: dataloader
            Dataloader of data to use; only its first batch is used
        NUM_Of_Nets: int
            Number of networks to test
        device: torch.device
            Device to test on
        im_size: int array
            Array of reconstruction sizes, one entry per network
        idx: int, optional
            Index of the sample within the first batch to save (default 4)

    Returns:
    --------
        -

    Raises:
    -------
        ValueError
            If the dataloader yields no batches.
    '''
    for net in nets:
        net.eval()

    # Only the first batch is visualized.
    try:
        target = next(iter(dataloader)).to(device)
    except StopIteration:
        raise ValueError('dataloader yielded no batches') from None

    # Pure inference: do not build autograd graphs.
    with torch.no_grad():
        torchvision.utils.save_image(target[idx], 'plots/Target.png')

        # Shift only the spatial axes. With the default axes=None, fftshift also
        # rolls the batch axis (so ft[idx] would not match target[idx]) and the
        # trailing real/imag axis (swapping the Real and Img plots).
        # Assumes fft2WoSq returns (N, C, H, W, 2) with real/imag last -- TODO confirm.
        ft = torch.from_numpy(
            np.fft.fftshift(fft2WoSq(target).cpu().numpy(), axes=(-3, -2)))
        real = ft[..., 0][idx]
        torchvision.utils.save_image(real / torch.max(real), 'plots/Real.png')
        img = ft[..., 1][idx]
        torchvision.utils.save_image(img / torch.max(img), 'plots/Img.png')
        # NOTE(review): real/imag parts and phase can be negative; dividing by the
        # max does not map them into [0, 1], and save_image clamps negatives.
        phs = phase(ft)[idx]
        norm_magn = magnitude(ft)[idx]
        torchvision.utils.save_image(phs / torch.max(phs), 'plots/Phase.png')
        torchvision.utils.save_image(norm_magn / torch.max(norm_magn), 'plots/Magnitude.png')

        # The measurement depends only on the target, so compute it once.
        measurement = magnitude(fft2WoSq(target)).to(device)
        for step in range(NUM_Of_Nets):
            stage_target = F.interpolate(target, size=(im_size[step], im_size[step]))
            torchvision.utils.save_image(
                stage_target[idx], 'plots/StageTarget{}.png'.format(step))
            data = measurement
            if step > 0:
                # Resize the previous output to im_size[-1] and concatenate it
                # with the measurement along dim 1.
                out = F.interpolate(out, size=(im_size[-1], im_size[-1]))
                data = torch.cat([measurement, out], 1)
            out = nets[step](data)
            torchvision.utils.save_image(
                out[idx], 'plots/Reconstruction{}.png'.format(step))
def DFPRsaveToPDF(nets, dataloader, startpic, NUM_Of_Nets, device, idx=4):
    '''
    Saves a target image, its Fourier magnitude and phase, the start image,
    and the reconstruction after every DFPR step as PNG files under
    ``plots/``.

    At each step, a network receives the concatenation (along dim 1) of the
    magnitude residual and the phase of the current estimate. Its output is
    added to the estimate (even steps) or subtracted from it (odd steps), and
    the result is clamped to [0, 1].

    Parameters:
    -----------
        nets: nn.module array
            Networks to test
        dataloader: dataloader
            Dataloader of data to use; only its first batch is used
        startpic: float array
            Start image to feed to first network; must be indexable at ``idx``
        NUM_Of_Nets: int
            Number of networks to test
        device: torch.device
            Device to test on
        idx: int, optional
            Index of the sample within the batch to save (default 4)

    Returns:
    --------
        -

    Raises:
    -------
        ValueError
            If the dataloader yields no batches.
    '''
    for net in nets:
        net.eval()

    # Only the first batch is visualized.
    try:
        target = next(iter(dataloader)).to(device)
    except StopIteration:
        raise ValueError('dataloader yielded no batches') from None

    # Pure inference: no autograd graph is needed, so no detach() calls either.
    with torch.no_grad():
        torchvision.utils.save_image(target[idx], 'plots/Target.png')
        x = startpic.to(device)
        measurement = magnitude(fft2WoSq(target)).to(device)
        torchvision.utils.save_image(measurement[idx], 'plots/Magnitude.png')
        phs = phase(fft2WoSq(target))
        torchvision.utils.save_image(phs[idx], 'plots/Phase.png')
        torchvision.utils.save_image(x[idx], 'plots/startpic.png')
        for step in range(NUM_Of_Nets):
            fourier_t = fft2WoSq(x)
            phase_x = phase(fourier_t)
            delta_y = magnitude(fourier_t) - measurement
            data = torch.cat([delta_y, phase_x], 1)
            out = nets[step](data)
            # Alternating addition and subtraction: +1 on even steps, -1 on odd ones.
            out = out * ((0.5 - (step % 2)) * 2)
            x = torch.clamp(x + out, 0, 1)
            torchvision.utils.save_image(
                x[idx], 'plots/Reconstruction{}.png'.format(step))
def downsampling_test(nets, test_data, NUM_Of_Nets, device, im_size, pics, save):
    '''
    Tests and benchmarks performance of the downsampling networks.

    Runs the full chain of networks on every batch of ``test_data`` and
    keeps the output of the last network as the prediction. The predictions
    are then compared against the ground truth with ``benchmark``.

    Parameters:
    -----------
        nets: nn.module array
            Networks to test
        test_data: dataloader
            Dataloader of test data
        NUM_Of_Nets: int
            Number of networks to test
        device: torch.device
            Device to test on
        im_size: int array
            Array of reconstruction sizes
        pics: boolean
            Show grid of pics for visualization
        save: boolean
            Save grid of reconstruction to png

    Returns:
    --------
        -

    Raises:
    -------
        ValueError
            If ``test_data`` yields no batches.
    '''
    print("========================================")
    print("Running Tests")
    # Collect per-batch tensors and concatenate once at the end. Repeated
    # torch.cat on a growing tensor is quadratic in the dataset size.
    true_batches = []
    pred_batches = []
    for net in nets:
        net.eval()
    with torch.no_grad():
        for target in test_data:
            # Keep the ground truth as yielded by the loader (before .to(device)).
            true_batches.append(target)
            target = target.to(device)
            # The measurement depends only on the target, so compute it once per batch.
            measurement = magnitude(fft2WoSq(target))
            for step in range(NUM_Of_Nets):
                data = measurement
                if step > 0:
                    # Resize the previous output to im_size[-1] and concatenate
                    # it with the measurement along dim 1.
                    out = F.interpolate(out, size=(im_size[-1], im_size[-1]))
                    data = torch.cat([measurement, out], 1)
                out = nets[step](data)
            # Only the last network's output is the final prediction.
            pred_batches.append(out.cpu())
    if not true_batches:
        raise ValueError('test_data yielded no batches')
    true = torch.cat(true_batches).cpu().numpy()
    pred = torch.cat(pred_batches).numpy()
    if pics:
        plot_grid(true[:8], figsize = (15,15), grid_size = 8)
        plot_grid(pred[:8], figsize = (15,15), grid_size = 8, save = save)
    benchmark(pred, true, check_all = True)
    print("========================================")
def train_epoch(epoch, NUM_Of_Nets, device, im_size, beta, nets, optimizer, losses, train_data):
    '''
    Trains the downsampling networks for one epoch.

    Network ``step`` is trained to reproduce the target resized to
    ``im_size[step]``. From the second network on, its input is the target's
    Fourier magnitude concatenated (along dim 1) with the previous network's
    output, resized to ``im_size[-1]``. Inputs are detached, so each network
    is optimized independently.

    Parameters:
    -----------
        epoch: int
            Current epoch number (used only for the printed log line)
        NUM_Of_Nets: int
            Number of networks
        device: torch.device
            Device to train on
        im_size: int array
            Array of reconstruction sizes
        beta: float
            Weight of the Fourier-magnitude regularizer. The regularizer is
            the mean squared difference between the squared Fourier
            magnitudes of the output and of the stage target. It is
            disabled when ``beta <= 0``.
        nets: nn.module array
            Array of networks to train
        optimizer: torch.optim array
            Array of optimizer for training, one per network
        losses: (loss-)function array
            Array of losses used to train each network
        train_data: dataloader
            Dataloader of training data

    Returns:
    --------
        -
    '''
    for net in nets:
        net.train()
    tot_loss = 0
    for target in train_data:
        target = target.to(device)
        # The network input depends only on the target, so run the FFT once per
        # batch instead of once per step.
        measurement = magnitude(fft2WoSq(target)).to(device)
        for step in range(NUM_Of_Nets):
            stage_target = F.interpolate(target, size=(im_size[step], im_size[step]))
            data = measurement
            if step > 0:
                out = F.interpolate(out, size=(im_size[-1], im_size[-1]))
                data = torch.cat([measurement, out], 1)
            optimizer[step].zero_grad()
            # Detach so gradients do not flow back into earlier networks.
            data = data.detach()
            out = nets[step](data)
            criterion = losses[step]
            loss = criterion(out, stage_target)
            if beta > 0:
                target_mags = squaredmagnitude(fft2(stage_target)).detach()
                reg = beta * torch.mean((squaredmagnitude(fft2(out)) - target_mags) ** 2)
                loss = loss + reg
            loss.backward()
            optimizer[step].step()
            tot_loss = tot_loss + loss.item()
    print('Epoche: {:3.0f} | Loss: {:.6f}'.format(epoch, tot_loss / len(train_data)))