def main():
    args = parser.parse_args()
    device = f'cuda:{args.gpu}'

    # network G: compact residual U-Net with circular padding
    unet = UNet(in_channels=1, out_channels=1, compact=4, residual=True,
                circular_padding=True, cat=True).to(device)
    # forward operator (physics): sparse-view CT, 50 views on 128x128 images
    forw = CT(img_width=128, radon_view=50, circle=False, device=device)

    dataloader = torch.utils.data.DataLoader(dataset=CTData(mode='test'),
                                             batch_size=1, shuffle=False)

    def test(net, ckp, fbp, adv=False):
        checkpoint = torch.load(ckp, map_location=device)
        net.load_state_dict(checkpoint['state_dict_G' if adv else 'state_dict'])
        net.to(device).eval()
        return net(fbp)

    for i, x in enumerate(dataloader):
        if i not in args.sample_to_show:
            continue
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        x = x.type(torch.float).to(device)     # ground truth
        y = forw.A(x)                          # measurement (sinogram)
        fbp = forw.A_dagger(y)                 # FBP reconstruction A^+y
        x_hat = test(unet, args.ckp_net, fbp)  # network reconstruction

        plt.subplot(1, 4, 1)
        plt.imshow(y[0].squeeze().detach().cpu().numpy())
        plt.title('y')
        plt.subplot(1, 4, 2)
        plt.imshow(fbp[0].squeeze().detach().cpu().numpy())
        plt.title('FBP ({:.2f})'.format(cal_psnr(x, fbp)))
        plt.subplot(1, 4, 3)
        plt.imshow(x_hat[0].squeeze().detach().cpu().numpy())
        plt.title('{} ({:.2f})'.format(args.model_name, cal_psnr(x, x_hat)))
        plt.subplot(1, 4, 4)
        plt.imshow(x[0].squeeze().detach().cpu().numpy())
        plt.title('x (GT)')

        ax = plt.gca()
        ax.set_xticks([])
        ax.set_yticks([])
        plt.subplots_adjust(left=0.1, bottom=0.1, top=0.9, right=0.9,
                            hspace=0.02, wspace=0.02)
        plt.show()
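# A minimal argument-parser sketch for the CT test script above, provided
# because `parser` is referenced but not defined in this excerpt. The flag
# names are inferred from the attributes used in main() (args.gpu,
# args.ckp_net, args.model_name, args.sample_to_show); the defaults and
# help strings are illustrative assumptions, not the repo's actual values.
import argparse

parser = argparse.ArgumentParser(description='sparse-view CT test (sketch)')
parser.add_argument('--gpu', type=int, default=0, help='CUDA device index')
parser.add_argument('--ckp-net', type=str, default='./ckp/ct.pth.tar',
                    help='path to a trained checkpoint (illustrative path)')
parser.add_argument('--model-name', type=str, default='EI')
parser.add_argument('--sample-to-show', type=int, nargs='+', default=[0, 1, 2])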
def closure_sup_ei(net, dataloader, physics, transform, optimizer,
                   criterion_fc, criterion_ei, alpha, dtype, device,
                   reportpsnr=False):
    loss_x_seq, loss_ei_seq, loss_seq, psnr_seq, mse_seq = [], [], [], [], []
    for i, x in enumerate(dataloader):
        x = x[0] if isinstance(x, list) else x
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        x = x.type(dtype).to(device)  # ground-truth signal x
        y0 = physics.A(x)             # generate measurement y
        x0 = physics.A_dagger(y0)     # range input A^+y
        x1 = net(x0)                  # reconstruction

        # equivariant imaging: x2, x3
        x2 = transform.apply(x1)
        x3 = net(physics.A_dagger(physics.A(x2)))

        loss_x = criterion_fc(x1, x)    # supervised loss on x
        loss_ei = criterion_ei(x3, x2)  # equivariance loss
        loss = loss_x + alpha['ei'] * loss_ei

        loss_x_seq.append(loss_x.item())
        loss_ei_seq.append(loss_ei.item())
        loss_seq.append(loss.item())
        if reportpsnr:
            psnr_seq.append(cal_psnr(x1, x))
            mse_seq.append(cal_mse(x1, x))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_closure = [np.mean(loss_x_seq), np.mean(loss_ei_seq), np.mean(loss_seq)]
    if reportpsnr:
        loss_closure.append(np.mean(psnr_seq))
        loss_closure.append(np.mean(mse_seq))
    return loss_closure
def closure_ei(net, dataloader, physics, transform, optimizer,
               criterion_mc, criterion_ei, alpha, dtype, device,
               reportpsnr=False):
    loss_mc_seq, loss_ei_seq, loss_seq, psnr_seq, mse_seq = [], [], [], [], []
    for i, x in enumerate(dataloader):
        x = x[0] if isinstance(x, list) else x
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        x = x.type(dtype).to(device)  # ground-truth signal x
        y0 = physics.A(x)             # generate measurement y
        x0 = physics.A_dagger(y0)     # range input A^+y
        x1 = net(x0)                  # reconstruction
        y1 = physics.A(x1)            # re-measured reconstruction

        # equivariant imaging: x2, x3
        x2 = transform.apply(x1)
        x3 = net(physics.A_dagger(physics.A(x2)))

        loss_mc = criterion_mc(y1, y0)  # measurement consistency
        loss_ei = criterion_ei(x3, x2)  # equivariance loss
        loss = loss_mc + alpha['ei'] * loss_ei

        loss_mc_seq.append(loss_mc.item())
        loss_ei_seq.append(loss_ei.item())
        loss_seq.append(loss.item())
        if reportpsnr:
            psnr_seq.append(cal_psnr(x1, x))
            mse_seq.append(cal_mse(x1, x))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_closure = [np.mean(loss_mc_seq), np.mean(loss_ei_seq), np.mean(loss_seq)]
    if reportpsnr:
        loss_closure.append(np.mean(psnr_seq))
        loss_closure.append(np.mean(mse_seq))
    return loss_closure
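# A minimal sketch of how closure_ei might be driven for a full training run,
# assuming torch and the closures above are in scope. The optimizer choice,
# learning rate, epoch count, and EI weight are illustrative assumptions;
# the repo's own training script may differ.
def train_ei_sketch(net, dataloader, physics, transform, device,
                    epochs=500, lr=1e-4, ei_weight=1.0):
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    mse = torch.nn.MSELoss()
    for epoch in range(epochs):
        # closure_ei returns [mean loss_mc, mean loss_ei, mean total] per epoch
        loss_mc, loss_ei, loss = closure_ei(
            net, dataloader, physics, transform, optimizer,
            criterion_mc=mse, criterion_ei=mse, alpha={'ei': ei_weight},
            dtype=torch.float, device=device)
        print(f'epoch {epoch}: mc={loss_mc:.4e} ei={loss_ei:.4e} total={loss:.4e}')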
def closure_dip(net, dataloader, z, physics, optimizer, criterion_mc,
                dtype, device, reportpsnr=False):
    loss_dip_seq = []
    for i, x in enumerate(dataloader):
        x = x[0] if isinstance(x, list) else x
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        x = x.type(dtype).to(device)  # ground truth (used for metrics only)
        y0 = physics.A(x)             # measurement y

        # DIP: the network maps a fixed random input z (e.g. torch.rand_like(x),
        # created once by the caller) to a reconstruction.
        x1 = net(z)
        y1 = physics.A(x1)

        if reportpsnr:
            psnr = cal_psnr(x1, x)
            mse = torch.nn.MSELoss()(x1, x).item()

        loss_mc = criterion_mc(y1, y0)
        loss_dip_seq.append(loss_mc.item())

        optimizer.zero_grad()
        loss_mc.backward()
        optimizer.step()

    loss_closure = [np.mean(loss_dip_seq)]
    if reportpsnr:
        loss_closure.append(psnr)
        loss_closure.append(mse)
    return loss_closure
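# A minimal sketch of a DIP run with closure_dip, assuming torch and the
# closure above are in scope. DIP fits the network to the measurements from a
# fixed random input z (cf. the torch.rand_like hint in the closure), so z is
# created once and reused across iterations; the iteration count and learning
# rate are illustrative assumptions.
def run_dip_sketch(net, dataloader, physics, device, iters=5000, lr=1e-3):
    x = next(iter(dataloader))
    x = x[0] if isinstance(x, list) else x
    if len(x.shape) == 3:
        x = x.unsqueeze(1)
    z = torch.rand(x.shape, device=device)  # fixed noise input to the network
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for it in range(iters):
        logs = closure_dip(net, dataloader, z, physics, optimizer,
                           criterion_mc=torch.nn.MSELoss(),
                           dtype=torch.float, device=device)
        if it % 100 == 0:
            print(f'iter {it}: mc={logs[0]:.4e}')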
def closure_mc(net, dataloader, physics, optimizer, criterion_mc,
               dtype, device, reportpsnr=False):
    loss_mc_seq, psnr_seq, mse_seq = [], [], []
    for i, x in enumerate(dataloader):
        x = x[0] if isinstance(x, list) else x
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        x = x.type(dtype).to(device)  # ground truth
        y0 = physics.A(x)             # measurement
        x0 = physics.A_dagger(y0)     # range input A^+y
        x1 = net(x0)                  # reconstruction
        y1 = physics.A(x1)            # re-measured reconstruction

        loss_mc = criterion_mc(y1, y0)  # measurement consistency only
        loss_mc_seq.append(loss_mc.item())
        if reportpsnr:
            psnr_seq.append(cal_psnr(x1, x))
            mse_seq.append(cal_mse(x1, x))

        optimizer.zero_grad()
        loss_mc.backward()
        optimizer.step()

    loss_closure = [np.mean(loss_mc_seq)]
    if reportpsnr:
        loss_closure.append(np.mean(psnr_seq))
        loss_closure.append(np.mean(mse_seq))
    return loss_closure
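# An illustrative toy check (not from the repo) of why measurement consistency
# alone is not enough: with an inpainting-style operator A, two images that
# differ only on masked pixels (the nullspace of A) produce identical
# measurements, hence identical loss_mc. This is the gap the EI loss targets.
import torch

mask = (torch.rand(1, 1, 8, 8) > 0.3).float()    # hypothetical 30%-missing mask
A = lambda im: mask * im                         # toy forward operator
x_a = torch.rand(1, 1, 8, 8)
x_b = x_a + (1 - mask) * torch.rand(1, 1, 8, 8)  # perturb masked pixels only
assert torch.equal(A(x_a), A(x_b))               # same y for different x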
def main():
    args = parser.parse_args()
    device = f'cuda:{args.gpu}'

    # dataloader (e.g. 'urban100': first 90 images for training, last 10 for testing)
    dataloader = CVDB_ICCV(dataset_name=args.dataset_name, mode='test',
                           batch_size=1, shuffle=False)
    # forward operator (i.e. physics): inpainting with a 0.3 masking rate
    forw = Inpainting(img_heigth=256, img_width=256, mask_rate=0.3, device=device)
    # network G (i.e. the residual U-Net in the paper)
    unet = UNet(in_channels=3, out_channels=3, compact=4, residual=True,
                circular_padding=True, cat=True).to(device)

    psnr_fbp, psnr_net = [], []

    def test(net, ckp, fbp, adv=False):
        checkpoint = torch.load(ckp, map_location=device)
        net.load_state_dict(checkpoint['state_dict_G' if adv else 'state_dict'])
        net.to(device).eval()
        return net(fbp)

    for i, x in enumerate(dataloader):
        x = x[0] if isinstance(x, list) else x
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        x = x.type(torch.float).to(device)  # ground truth
        y = forw.A(x)                       # measurement
        fbp = forw.A_dagger(y)              # A^+y
        x_hat = test(unet, args.ckp, fbp)   # network reconstruction

        psnr_fbp.append(cal_psnr(x, fbp))
        psnr_net.append(cal_psnr(x, x_hat))

        if i in args.sample_to_show:
            plt.subplot(1, 4, 1)
            plt.imshow(y.squeeze().detach().permute(1, 2, 0).cpu().numpy())
            plt.title('y')
            plt.subplot(1, 4, 2)
            plt.imshow(fbp.squeeze().detach().permute(1, 2, 0).cpu().numpy())
            plt.title('A^+y ({:.2f})'.format(cal_psnr(x, fbp)))
            plt.subplot(1, 4, 3)
            plt.imshow(x_hat.squeeze().detach().permute(1, 2, 0).cpu().numpy())
            plt.title('{} ({:.2f})'.format(args.model_name, cal_psnr(x, x_hat)))
            plt.subplot(1, 4, 4)
            plt.imshow(x.squeeze().detach().permute(1, 2, 0).cpu().numpy())
            plt.title('x (GT)')

            ax = plt.gca()
            ax.set_xticks([])
            ax.set_yticks([])
            plt.subplots_adjust(left=0.1, bottom=0.1, top=0.9, right=0.9,
                                hspace=0.02, wspace=0.02)
            plt.show()

    print('Inpainting (0.3) AVG-PSNR: A^+y={:.2f}\t{}={:.2f}'.format(
        np.mean(psnr_fbp), args.model_name, np.mean(psnr_net)))
def closure_ei_adv(generator, discriminator, dataloader, physics, transform,
                   optimizer_G, optimizer_D, criterion_mc, criterion_ei,
                   criterion_gan, alpha, dtype, device, reportpsnr=False):
    loss_mc_seq, loss_ei_seq, loss_g_seq, loss_G_seq, loss_D_seq = [], [], [], [], []
    psnr_seq, mse_seq = [], []
    for i, x in enumerate(dataloader):
        x = x[0] if isinstance(x, list) else x
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        x = x.type(dtype).to(device)

        # measurements and model-range input A^+y
        y0 = physics.A(x)
        x0 = Variable(physics.A_dagger(y0))

        # adversarial ground truths (Variable is a legacy no-op wrapper here)
        valid = torch.ones(x.shape[0],
                           *discriminator.output_shape).type(dtype).to(device)
        valid_ei = torch.ones(x.shape[0] * transform.n_trans,
                              *discriminator.output_shape).type(dtype).to(device)
        fake_ei = torch.zeros(x.shape[0] * transform.n_trans,
                              *discriminator.output_shape).type(dtype).to(device)
        valid = Variable(valid, requires_grad=False)
        valid_ei = Variable(valid_ei, requires_grad=False)
        fake_ei = Variable(fake_ei, requires_grad=False)

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # generate a batch of images from the range input A^+y
        x1 = generator(x0)
        y1 = physics.A(x1)

        # EI: x2, x3
        x2 = transform.apply(x1)
        x3 = generator(physics.A_dagger(physics.A(x2)))

        # measurement-consistency and equivariance losses
        loss_mc = criterion_mc(y1, y0)
        loss_ei = criterion_ei(x3, x2)
        # generator's ability to fool the discriminator
        loss_g = criterion_gan(discriminator(x2), valid_ei)

        loss_G = loss_mc + alpha['ei'] * loss_ei + alpha['adv'] * loss_g
        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # discriminator's ability to tell reconstructions from transformed ones
        real_loss = criterion_gan(discriminator(x1.detach()), valid)
        fake_loss = criterion_gan(discriminator(x2.detach()), fake_ei)
        loss_D = 0.5 * alpha['adv'] * (real_loss + fake_loss)
        loss_D.backward()
        optimizer_D.step()

        if reportpsnr:
            psnr_seq.append(cal_psnr(x1, x))
            mse_seq.append(cal_mse(x1, x))

        # --------------
        #  Log Progress
        # --------------
        loss_mc_seq.append(loss_mc.item())
        loss_ei_seq.append(loss_ei.item())
        loss_g_seq.append(loss_g.item())
        loss_G_seq.append(loss_G.item())  # total loss for G
        loss_D_seq.append(loss_D.item())  # total loss for D

    # losses: loss_mc, loss_ei, loss_g, loss_G, loss_D
    loss_closure = [np.mean(loss_mc_seq), np.mean(loss_ei_seq), np.mean(loss_g_seq),
                    np.mean(loss_G_seq), np.mean(loss_D_seq)]
    if reportpsnr:
        loss_closure.append(np.mean(psnr_seq))
        loss_closure.append(np.mean(mse_seq))
    return loss_closure
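# A minimal sketch of driving closure_ei_adv, assuming torch and the closure
# above are in scope. The discriminator (with its output_shape attribute), the
# loss choices, learning rates, and loss weights are illustrative assumptions;
# the repo defines its own discriminator and hyper-parameters.
def train_ei_adv_sketch(generator, discriminator, dataloader, physics,
                        transform, device, epochs=500,
                        ei_weight=1.0, adv_weight=1e-3):
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
    mse = torch.nn.MSELoss()
    bce = torch.nn.BCEWithLogitsLoss()  # assumes a logit-output discriminator
    for epoch in range(epochs):
        # logs = [mean loss_mc, loss_ei, loss_g, loss_G, loss_D] per epoch
        logs = closure_ei_adv(generator, discriminator, dataloader, physics,
                              transform, optimizer_G, optimizer_D,
                              criterion_mc=mse, criterion_ei=mse,
                              criterion_gan=bce,
                              alpha={'ei': ei_weight, 'adv': adv_weight},
                              dtype=torch.float, device=device)
        print(f'epoch {epoch}: mc={logs[0]:.4e} ei={logs[1]:.4e} '
              f'g={logs[2]:.4e} G={logs[3]:.4e} D={logs[4]:.4e}')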