def bench_fanbeam_backward(task, dtype, device, *bench_args):
    """Benchmark fan-beam backprojection.

    A random image batch is forward-projected once to obtain a sinogram; the
    timed function then backprojects that sinogram.

    Args:
        task: dict with keys ``num_angles``, ``det_count``, ``source_distance``,
            ``detector_distance``, ``det_spacing``, ``batch_size`` and ``size``.
        dtype: dtype of the random input batch.
        device: device on which the input batch is allocated.
        *bench_args: forwarded unchanged to ``benchmark``.

    Returns:
        Whatever ``benchmark`` returns.
    """
    num_angles = task["num_angles"]
    det_count = task["det_count"]
    source_dist = task["source_distance"]
    det_dist = task["detector_distance"]
    det_spacing = task["det_spacing"]

    x = torch.randn(task["batch_size"], task["size"], task["size"], dtype=dtype, device=device)
    angles = np.linspace(0, np.pi, num_angles, endpoint=False)
    projection = Projection.fanbeam(source_dist, det_dist, det_count, det_spacing)
    radon = Radon(angles, task["size"], projection)
    sino = radon.forward(x)

    def f(sinogram):
        return radon.backward(sinogram)

    # Backprojection consumes a sinogram: time it on `sino`, not on the image `x`
    # (consistent with bench_parallel_backward).
    return benchmark(f, sino, *bench_args)
def init_radon(self, beam, circle, det_dist):
    """Create the full-view and sparse-view Radon operators.

    Args:
        beam: ``'parallel'`` or ``'fan'``.
        circle: forwarded as ``clip_to_circle`` to both operators.
        det_dist: used only for ``'fan'``; ``det_dist[0]`` is the source
            distance and ``det_dist[1]`` the detector distance.

    Sets ``self.radon``, ``self.radon_sparse`` (every ``self.sample_ratio``-th
    angle) and ``self.n_angles_sparse``. Reads ``self.n_angles``,
    ``self.sample_ratio``, ``self.img_size`` and, for fan beam, ``self.det_size``.

    Raises:
        ValueError: if ``beam`` is not recognised. ValueError subclasses
            Exception, so existing ``except Exception`` handlers still work.
    """
    if beam == 'parallel':
        angles = np.linspace(0, np.pi, self.n_angles, endpoint=False)
    elif beam == 'fan':
        # n_angles views at 1-degree steps, i.e. spanning n_angles degrees
        angles = np.linspace(0, self.n_angles / 180 * np.pi, self.n_angles, False)
    else:
        raise ValueError(f'projection beam type undefined! received {beam!r}')

    sparse_angles = angles[::self.sample_ratio]
    if beam == 'parallel':
        self.radon = Radon(self.img_size, angles, clip_to_circle=circle)
        self.radon_sparse = Radon(self.img_size, sparse_angles, clip_to_circle=circle)
    else:
        fan_kwargs = dict(source_distance=det_dist[0], det_distance=det_dist[1],
                          clip_to_circle=circle, det_count=self.det_size)
        self.radon = RadonFanbeam(self.img_size, angles, **fan_kwargs)
        self.radon_sparse = RadonFanbeam(self.img_size, sparse_angles, **fan_kwargs)
    self.n_angles_sparse = len(sparse_angles)
def test_error(device, batch_size, image_size, angles, spacing, clip_to_circle):
    """Compare our forward projection and backprojection against the Astra reference.

    Relative errors must stay below 1e-2 (forward) and 5e-3 (backprojection).
    """
    images = generate_random_images(batch_size, image_size, masked=clip_to_circle)

    # Reference results from Astra.
    reference = AstraWrapper(angles)
    ref_fp_id, ref_fp = reference.forward(images, spacing)
    ref_bp = reference.backproject(ref_fp_id, image_size, batch_size)
    if clip_to_circle:
        ref_bp *= circle_mask(image_size)

    # Results from this implementation.
    radon = Radon(image_size, angles, det_spacing=spacing, clip_to_circle=clip_to_circle)
    gpu_images = torch.FloatTensor(images).to(device)
    our_fp = radon.forward(gpu_images)
    our_bp = radon.backprojection(our_fp)

    forward_error = relative_error(ref_fp, our_fp.cpu().numpy())
    back_error = relative_error(ref_bp, our_bp.cpu().numpy())

    print(
        f"batch: {batch_size}, size: {image_size}, angles: {len(angles)}, spacing: {spacing}, circle: {clip_to_circle}, forward: {forward_error}, back: {back_error}")

    # TODO: tighten these checks
    assert_less(forward_error, 1e-2)
    assert_less(back_error, 5e-3)
def bench_parallel_forward(phantom, det_count, num_angles, warmup, repeats):
    """Time parallel-beam forward projection of ``phantom``.

    The operator uses ``num_angles`` equally spaced angles in [0, pi) and
    ``det_count`` detectors; ``warmup``/``repeats`` go straight to ``benchmark``.
    """
    image_size = phantom.size(1)
    angles = np.linspace(0, np.pi, num_angles, endpoint=False)
    radon = Radon(image_size, angles, det_count)

    def project(x):
        return radon.forward(x)

    return benchmark(project, phantom, warmup, repeats)
class Predict():
    """Inpaint sparse-view sinograms from a dataloader and save their FBP reconstructions."""

    def __init__(self, args, dataloader):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.args = args
        self.dataloader = dataloader
        self.net = UNet(input_nc=1, output_nc=1).to(self.device)
        self.net = nn.DataParallel(self.net)
        pathG = os.path.join(args.ckpt)
        self.net.load_state_dict(torch.load(pathG, map_location=self.device))
        self.net.eval()
        self.gen_mask()
        angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, angles, clip_to_circle=True)

    def gen_mask(self):
        """Set ``self.mask``: 180 entries, 1 on every 8th angle (23 ones), 0 elsewhere."""
        self.mask = torch.zeros(180, device=self.device)
        self.mask[::8].fill_(1)

    def gen_x(self, y):
        """Zero out the unsampled angles of ``y`` (shape is unchanged)."""
        return self.mask * y

    def crop_sinogram(self, x):
        """Drop 6 entries at both ends of the last axis."""
        return x[:, :, :, 6:-6]

    def overlay(self, Gx, x):
        """Keep measured angles from ``x`` and take the remaining angles from ``Gx``."""
        result = self.mask * x + (1 - self.mask) * Gx
        return result

    def inpaint(self):
        """Run the network over every batch and save each sample's FBP reconstruction."""
        # Pure inference: no autograd graph is needed, which saves memory and time.
        with torch.no_grad():
            for i, data in enumerate(self.dataloader):
                y = data[0].to(self.device)  # 320 x 180
                x = self.gen_x(y)  # input, same shape as y with unsampled angles zeroed
                Gx = self.net(x)
                Gx = self.overlay(Gx, y)
                # FBP
                Gx = normalize(Gx)  # 0~1
                fbp_Gx = self.radon.backprojection(
                    self.radon.filter_sinogram(Gx.permute(0, 1, 3, 2)))
                print(f'Saving images for batch {i}')
                for j in range(y.size()[0]):
                    # NOTE(review): `class_name` and `fnames` are not defined in this class;
                    # presumably module-level globals -- confirm, otherwise this raises NameError.
                    vutils.save_image(
                        fbp_Gx[j, 0], f'{self.args.outdir}/{class_name}/{fnames[i*self.args.bs+j]}', normalize=True)
def __init__(self, args, dataloader):
    """Load the trained UNet and build Radon operators for several angle subsamplings.

    Reads ``args.twoends``, ``args.angles``, ``args.ckpt`` and ``args.height``.
    Creates ``self.radon`` (180 views) plus ``radon23``/``radon45``/``radon90``
    using every 8th/4th/2nd of those views.
    """
    use_cuda = torch.cuda.is_available()
    self.device = torch.device('cuda' if use_cuda else 'cpu')
    self.args = args
    self.dataloader = dataloader

    # NOTE(review): presumably the upsampling ratio from the sparse views to 180
    # (192 when two-ends padding adds 6 views on each side) -- confirm.
    factor = 192 / (args.angles + 2) if args.twoends else 180 / args.angles
    model = UNet(input_nc=1, output_nc=1, scale_factor=factor).to(self.device)
    self.net = nn.DataParallel(model)
    weights = torch.load(os.path.join(args.ckpt), map_location=self.device)
    self.net.load_state_dict(weights)
    self.net.eval()
    self.gen_mask()

    full_angles = np.linspace(0, np.pi, 180, endpoint=False)
    self.radon = Radon(args.height, full_angles, clip_to_circle=True)
    for stride, attr_name in ((8, 'radon23'), (4, 'radon45'), (2, 'radon90')):
        setattr(self, attr_name, Radon(args.height, full_angles[::stride], clip_to_circle=True))
def test_differentiation(self):
    """Gradients must propagate through both forward projection and backprojection."""
    device = torch.device('cuda')
    # torch.FloatTensor(shape) hands back uninitialised memory (possibly NaN/Inf);
    # use well-defined random values instead.
    x = torch.randn(1, 64, 64, device=device, requires_grad=True)
    angles = torch.FloatTensor(
        np.linspace(0, 2 * np.pi, 10).astype(np.float32)).to(device)
    radon = Radon(64, angles)

    # check that backward is implemented for fp and bp
    y = radon.forward(x)
    z = torch.mean(radon.backprojection(y))
    z.backward()
    self.assertIsNotNone(x.grad)
def __init__(self, image_size, n_angles, sample_ratio, device, circle=False):
    """Build full-view and sparse-view Radon operators and the angle sampling mask.

    Args:
        image_size: image side length passed to ``Radon``.
        n_angles: number of equally spaced angles in [0, pi).
        sample_ratio: keep every ``sample_ratio``-th angle for the sparse operator.
        device: device for the mask.
        circle: forwarded as ``clip_to_circle``.
    """
    self.device = device
    self.image_size = image_size
    self.sample_ratio = sample_ratio
    self.n_angles = n_angles
    angles = np.linspace(0, np.pi, self.n_angles, endpoint=False)
    sparse_angles = angles[::sample_ratio]
    self.radon = Radon(self.image_size, angles, clip_to_circle=circle)
    self.radon_sparse = Radon(self.image_size, sparse_angles, clip_to_circle=circle)
    self.n_angles_sparse = len(sparse_angles)
    self.landweber = Landweber(self.radon)
    # The mask spans the angle axis, so size it from n_angles (it was hard-coded to 180,
    # which broke any n_angles != 180).
    self.mask = torch.zeros((1, 1, 1, self.n_angles), device=device)
    self.mask[:, :, :, ::sample_ratio].fill_(1)
def __init__(self, args, dataloader):
    """Load the trained UNet from ``args.ckpt`` and build the 180-view Radon operator.

    Also calls ``self.gen_mask()`` and stores ``args`` and ``dataloader``.
    """
    use_cuda = torch.cuda.is_available()
    self.device = torch.device('cuda' if use_cuda else 'cpu')
    self.args = args
    self.dataloader = dataloader

    model = UNet(input_nc=1, output_nc=1).to(self.device)
    self.net = nn.DataParallel(model)
    weights = torch.load(os.path.join(args.ckpt), map_location=self.device)
    self.net.load_state_dict(weights)
    self.net.eval()

    self.gen_mask()

    full_angles = np.linspace(0, np.pi, 180, endpoint=False)
    self.radon = Radon(args.height, full_angles, clip_to_circle=True)
def test_noise():
    """Shape/dtype checks for emulated readings and lookup-table decoding."""
    device = torch.device('cuda')
    # torch.FloatTensor(shape) is uninitialised memory (it may hold NaN/Inf); use
    # uniform [0, 1) values so the inputs are well-defined on every run.
    x = torch.rand(3, 5, 64, 64, device=device, requires_grad=True)
    lookup_table = torch.rand(128, 64, device=device)
    angles = torch.FloatTensor(np.linspace(0, 2 * np.pi, 10).astype(np.float32))
    radon = Radon(64, angles)

    sinogram = radon.forward(x)
    assert_equal(sinogram.size(), (3, 5, 10, 64))

    readings = radon.emulate_readings(sinogram, 5, 10.0)
    assert_equal(readings.size(), (3, 5, 10, 64))
    assert_equal(readings.dtype, torch.int32)

    y = radon.readings_lookup(readings, lookup_table)
    assert_equal(y.size(), (3, 5, 10, 64))
    assert_equal(y.dtype, torch.float32)
def bench_parallel_backward(task, dtype, device, *bench_args):
    """Time parallel-beam backprojection of the sinogram of a random image batch.

    ``task`` supplies ``num_angles``, ``det_count``, ``batch_size`` and ``size``;
    ``bench_args`` are passed through to ``benchmark``.
    """
    num_angles, det_count = task["num_angles"], task["det_count"]
    side = task["size"]
    images = torch.randn(task["batch_size"], side, side, dtype=dtype, device=device)

    angles = np.linspace(0, np.pi, num_angles, endpoint=False)
    radon = Radon(angles, side, Projection.parallel_beam(det_count))
    sinogram = radon.forward(images)

    def backproject(s):
        return radon.backward(s)

    return benchmark(backproject, sinogram, *bench_args)
def __init__(self, net, args, dataloader, device):
    """Set up networks, optimizers, losses, output directories and the Radon operator.

    ``net`` is indexed as generator, global discriminator, local discriminator and,
    when ``args.mode == 'vgg'``, a feature network for the perceptual loss.
    """
    self.netG = net[0]
    self.netDG = net[1]
    self.netDL = net[2]
    if args.mode == 'vgg':
        self.netLoss = net[3]
    self.optimizerG = optim.Adam(self.netG.parameters(), lr=args.lr, betas=(0.5, 0.999))
    self.optimizerDG = optim.Adam(self.netDG.parameters(), lr=args.lr, betas=(0.5, 0.999))
    self.optimizerDL = optim.Adam(self.netDL.parameters(), lr=args.lr, betas=(0.5, 0.999))
    self.dataloader = dataloader
    self.device = device
    self.args = args
    self.save_cp = True
    # resume after checkpoint `args.load`; a negative value means start from epoch 0
    self.start_epoch = args.load + 1 if args.load >= 0 else 0
    self.mask = self.gen_mask().to(self.device)
    self.criterionL1 = torch.nn.L1Loss().to(self.device)
    self.criterionL2 = torch.nn.MSELoss().to(self.device)
    self.criterionGAN = GANLoss('vanilla').to(self.device)
    err_list = ["errDG", "errDL", "errGG_GAN", "errGG_C", "errGG_F", "errGG_P",
                "errGL_GAN", "errGL_C", "errGL_F", "errGL_P"]
    self.err = dict.fromkeys(err_list, None)
    if self.save_cp:
        ckpt_dir = os.path.join(args.outdir, 'ckpt')
        try:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)
                print('Created checkpoint directory')
            if args.load < 0:
                # New log file
                with open(os.path.join(args.outdir, args.log_fn + '.csv'), 'w', newline='') as f:
                    csvwriter = writer(f)
                    csvwriter.writerow(["epoch", "runtime"] + err_list)
        except OSError as exc:
            # Still best-effort, but no longer silent.
            print(f'Warning: could not prepare output files in {args.outdir}: {exc}')
    angles = np.linspace(0, np.pi, 180, endpoint=False)
    self.radon = Radon(args.height, angles, clip_to_circle=True)
def test_half(device, batch_size, image_size, angles, spacing, det_count, clip_to_circle):
    """Half-precision projections must agree with single precision to a 1e-3 relative error.

    ``det_count`` is a fraction of ``image_size``; it is converted to an absolute count here.
    """
    det_count = int(det_count * image_size)
    mask_radius = det_count / 2.0 if clip_to_circle else -1
    images = generate_random_images(batch_size, image_size, mask_radius)

    radon = Radon(image_size, angles, det_spacing=spacing, det_count=det_count, clip_to_circle=clip_to_circle)
    x = torch.FloatTensor(images).to(device)

    sinogram = radon.forward(x)
    single_precision = radon.backprojection(sinogram)
    h_sino = radon.forward(x.half())
    half_precision = radon.backprojection(h_sino)

    def as_numpy(t):
        return t.cpu().numpy()

    forward_error = relative_error(as_numpy(sinogram), as_numpy(h_sino))
    back_error = relative_error(as_numpy(single_precision), as_numpy(half_precision))

    print(
        f"batch: {batch_size}, size: {image_size}, angles: {len(angles)}, spacing: {spacing}, circle: {clip_to_circle}, forward: {forward_error}, back: {back_error}"
    )

    assert_less(forward_error, 1e-3)
    assert_less(back_error, 1e-3)
def main():
    """Reconstruct a disc phantom from 100 views spanning 100 degrees.

    Uses ADMM-style iterations with a shearlet sparsity term and a non-negativity
    constraint, then prints the runtime and the final relative error (in percent).
    """
    n_angles = 100
    image_size = 512
    circle_radius = 100
    source_dist = 1.5 * image_size
    batch_size = 1
    n_scales = 5

    # 100 views over [-50, 50) degrees, converted to radians
    angles = (np.linspace(0., 100., n_angles, endpoint=False) - 50.0) / 180.0 * np.pi

    phantom = np.zeros((image_size, image_size), dtype=np.float32)
    phantom[circle_mask(image_size, circle_radius)] = 1.0

    radon = Radon(image_size, angles)  # RadonFanbeam(image_size, angles, source_dist)
    shearlet = ShearletTransform(image_size, image_size, [0.5] * n_scales)

    torch_x = torch.from_numpy(phantom).cuda().view(1, image_size, image_size).repeat(batch_size, 1, 1)

    sinogram = radon.forward(torch_x)
    bp = radon.backward(sinogram)
    sc = shearlet.forward(bp)

    p_0, p_1 = 0.02, 0.1
    w = (3**shearlet.scales / 400).view(1, -1, 1, 1).cuda()

    u_2 = torch.zeros_like(bp)
    z_2 = torch.zeros_like(bp)
    u_1 = torch.zeros_like(sc)
    z_1 = torch.zeros_like(sc)
    f = torch.zeros_like(bp)

    def normal_operator(v):
        return p_0 * radon.backward(radon.forward(v)) + (1 + p_1) * v

    errors = []
    start_time = time.time()
    for _ in range(100):
        rhs = p_0 * bp + p_1 * shearlet.backward(z_1 - u_1) + (z_2 - u_2)
        f = cg(normal_operator, f.clone(), rhs, max_iter=50)
        sh_f = shearlet.forward(f)

        z_1 = shrink(sh_f + u_1, p_0 / p_1 * w)
        z_2 = (f + u_2).clamp_min(0)

        u_1 = u_1 + sh_f - z_1
        u_2 = u_2 + f - z_2

        errors.append((torch.norm(torch_x[0] - f[0]) / torch.norm(torch_x[0])).item())

    runtime = time.time() - start_time
    print("Running time:", runtime)
    print("Running time per image:", runtime / batch_size)
    print("Relative error: ", 100 * errors[-1])
def __init__(self, args, image):
    """Load the trained UNet, move ``image`` to the device and build the 180-view Radon operator.

    Reads ``args.twoends``, ``args.angles``, ``args.ckpt`` and ``args.height``.
    """
    use_cuda = torch.cuda.is_available()
    self.device = torch.device('cuda' if use_cuda else 'cpu')

    # 7.68 with two-ends padding (192 views), 7.826086956521739 otherwise (180 views)
    factor = 192 / (args.angles + 2) if args.twoends else 180 / args.angles
    model = UNet(input_nc=1, output_nc=1, scale_factor=factor).to(self.device)
    self.net = nn.DataParallel(model)
    weights = torch.load(os.path.join(args.ckpt), map_location=self.device)
    self.net.load_state_dict(weights)
    self.net.eval()

    self.image = image.to(self.device)
    self.twoends = args.twoends
    # gen_mask reads self.twoends, so it is set just above
    self.mask = self.gen_mask().to(self.device)

    full_angles = np.linspace(0, np.pi, 180, endpoint=False)
    self.radon = Radon(args.height, full_angles, clip_to_circle=True)
class Operators():
    """Radon-transform operators for full-view and sparse-view sinograms.

    The full operator uses ``n_angles`` equally spaced angles in [0, pi); the sparse
    operator keeps every ``sample_ratio``-th of them. Sinogram inputs to the methods
    below are 4-D and laid out as (batch, channel, detector, angle); they are permuted
    to (..., angle, detector) before being handed to ``Radon``.
    """

    def __init__(self, image_size, n_angles, sample_ratio, device, circle=False):
        self.device = device
        self.image_size = image_size
        self.sample_ratio = sample_ratio
        self.n_angles = n_angles
        angles = np.linspace(0, np.pi, self.n_angles, endpoint=False)
        sparse_angles = angles[::sample_ratio]
        self.radon = Radon(self.image_size, angles, clip_to_circle=circle)
        self.radon_sparse = Radon(self.image_size, sparse_angles, clip_to_circle=circle)
        self.n_angles_sparse = len(sparse_angles)
        self.landweber = Landweber(self.radon)
        # sized from n_angles (was hard-coded to 180, which broke any other angle count)
        self.mask = self._make_mask(self.n_angles, sample_ratio, device)

    @staticmethod
    def _make_mask(n_angles, sample_ratio, device):
        """Return a (1, 1, 1, n_angles) mask with 1 on every ``sample_ratio``-th angle, 0 elsewhere."""
        mask = torch.zeros((1, 1, 1, n_angles), device=device)
        mask[:, :, :, ::sample_ratio] = 1
        return mask

    # X^T (): adjoint of the forward model (unfiltered backprojection)
    def forward_adjoint(self, input):
        """Backproject a full- or sparse-view sinogram, chosen by the size of its last axis.

        Sparse backprojections are rescaled by n_angles / n_angles_sparse so they have the
        same magnitude as full-view ones.

        Raises:
            ValueError: if the angle axis matches neither operator (ValueError subclasses
                Exception, so existing handlers still catch it).
        """
        if input.size()[3] == self.n_angles:
            return self.radon.backprojection(input.permute(0, 1, 3, 2))
        if input.size()[3] == self.n_angles_sparse:
            back = self.radon_sparse.backprojection(input.permute(0, 1, 3, 2))
            return back / self.n_angles_sparse * self.n_angles  # scale the angles
        raise ValueError(f'forward_adjoint input dimension wrong! received {input.size()}.')

    # X^T X (): forward projection followed by backprojection
    def forward_gramian(self, input):
        """Apply backprojection(forward(input)) with the full-view operator.

        Raises:
            ValueError: if ``input.size()[2] != image_size``.
        """
        if input.size()[2] != self.image_size:
            raise ValueError(f'forward_gramian input dimension wrong! received {input.size()}.')
        sinogram = self.radon.forward(input)
        return self.radon.backprojection(sinogram)

    # Corruption model: keep every sample_ratio-th angle of the sinogram
    def undersample_model(self, input):
        """Subsample the last (angle) axis with stride ``sample_ratio``."""
        return input[:, :, :, ::self.sample_ratio]

    # Filtered backprojection. Input sinogram range = (0, 1)
    def FBP(self, input):
        """Filtered backprojection of a full-view (batch, channel, detector, angle) sinogram.

        Raises:
            ValueError: if the detector axis is not ``image_size`` or the angle axis
                is not ``n_angles``.
        """
        if input.size()[2] != self.image_size or input.size()[3] != self.n_angles:
            raise ValueError(f'FBP input dimension wrong! received {input.size()}.')
        filtered_sinogram = self.radon.filter_sinogram(input.permute(0, 1, 3, 2))
        return self.radon.backprojection(filtered_sinogram)

    # estimate step size eta
    def estimate_eta(self):
        """Return the Landweber step-size estimate as a float32 tensor on ``self.device``."""
        eta = self.landweber.estimate_alpha(self.image_size, self.device)
        return torch.tensor(eta, dtype=torch.float32, device=self.device)
def test_shapes(self):
    """Projections must keep any leading batch/channel dimensions, or none at all."""
    device = torch.device('cuda')
    angles = torch.FloatTensor(
        np.linspace(0, 2 * np.pi, 10).astype(np.float32)).to(device)
    radon = Radon(64, angles)

    cases = (
        ((2, 3, 64, 64), (2, 3, 10, 64)),  # two batch dimensions
        ((64, 64), (10, 64)),  # no batch dimensions
    )
    for image_shape, sinogram_shape in cases:
        image = torch.FloatTensor(*image_shape).to(device)
        sinogram = radon.forward(image)
        self.assertEqual(sinogram.size(), sinogram_shape)
        back = radon.backprojection(sinogram)
        self.assertEqual(back.size(), image_shape)
import torch
from torch_radon import Radon
from torch_radon.solvers import Landweber
from utils import show_images

# Example: reconstruct the phantom from its parallel-beam sinogram with Landweber iteration.
# NOTE(review): `np` is used below but `import numpy as np` is not visible in this snippet;
# presumably it is imported earlier -- confirm.
batch_size = 1
n_angles = 512
image_size = 512

img = np.load("phantom.npy")
device = torch.device('cuda')

# instantiate Radon transform
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles)
landweber = Landweber(radon)

# estimate step size
alpha = landweber.estimate_alpha(image_size, device)

with torch.no_grad():
    # phantom reshaped to (1, 1, image_size, image_size)
    x = torch.FloatTensor(img).reshape(1, 1, image_size, image_size).to(device)
    sinogram = radon.forward(x)

    # use landweber iteration to reconstruct the image
    # values returned by 'callback' are stored inside 'progress'
    reconstruction, progress = landweber.run(
        torch.zeros(x.size(), device=device),  # start from an all-zero image
        sinogram,
        # NOTE(review): the remaining arguments of this call are cut off in this snippet
class Inpaint():
    """Adversarial trainer for sparse-view sinogram inpainting.

    ``net`` is indexed as ``net[0]`` generator, ``net[1]`` global discriminator,
    ``net[2]`` local (patch) discriminator and, when ``args.mode == 'vgg'``,
    ``net[3]`` a feature network for the perceptual loss.
    """

    def __init__(self, net, args, dataloader, device):
        self.netG = net[0]
        self.netDG = net[1]
        self.netDL = net[2]
        if args.mode == 'vgg':
            self.netLoss = net[3]
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=args.lr, betas=(0.5, 0.999))
        self.optimizerDG = optim.Adam(self.netDG.parameters(), lr=args.lr, betas=(0.5, 0.999))
        self.optimizerDL = optim.Adam(self.netDL.parameters(), lr=args.lr, betas=(0.5, 0.999))
        self.dataloader = dataloader
        self.device = device
        self.args = args
        self.save_cp = True
        # resume after checkpoint `args.load`; a negative value means start from epoch 0
        self.start_epoch = args.load + 1 if args.load >= 0 else 0
        self.mask = self.gen_mask().to(self.device)  # gen_mask reads self.args
        self.criterionL1 = torch.nn.L1Loss().to(self.device)
        self.criterionL2 = torch.nn.MSELoss().to(self.device)
        self.criterionGAN = GANLoss('vanilla').to(self.device)
        err_list = ["errDG", "errDL", "errGG_GAN", "errGG_C", "errGG_F", "errGG_P",
                    "errGL_GAN", "errGL_C", "errGL_F", "errGL_P"]
        self.err = dict.fromkeys(err_list, None)
        if self.save_cp:
            ckpt_dir = os.path.join(args.outdir, 'ckpt')
            try:
                if not os.path.exists(ckpt_dir):
                    os.makedirs(ckpt_dir)
                    print('Created checkpoint directory')
                if args.load < 0:
                    # New log file
                    with open(os.path.join(args.outdir, args.log_fn + '.csv'), 'w', newline='') as f:
                        csvwriter = writer(f)
                        csvwriter.writerow(["epoch", "runtime"] + err_list)
            except OSError as exc:
                # Still best-effort, but no longer silent.
                print(f'Warning: could not prepare output files in {args.outdir}: {exc}')
        angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, angles, clip_to_circle=True)

    def gen_mask(self):
        """Angle mask: 1 on every 8th of 180 angles (23 ones).

        With ``args.twoends`` it is padded to 192 entries (25 ones) with its own last 6
        and first 6 entries.
        """
        mask = torch.zeros(180)
        mask[::8].fill_(1)  # 180/23
        if self.args.twoends:
            mask = torch.cat((mask[-6:], mask, mask[:6]), 0)  # 192/25
        return mask

    def gen_sparse(self, y):
        """Select the sampled angles (last axis) of ``y``."""
        return y[:, :, :, self.mask == 1]

    def append_twoends(self, y):
        """Pad the angle axis with 6 views at each end, flipped along dim 2.

        The first 6 views, flipped, are appended at the end; the last 6, flipped, are
        prepended at the start.
        """
        front = torch.flip(y[:, :, :, :6], [2])
        back = torch.flip(y[:, :, :, -6:], [2])
        return torch.cat((back, y, front), 3)

    def ramp_module(self, sinogram):
        '''
        Sinogram has dimension: bs x c x height x angle.
        Ramp is 1D but angle number affects normalization for filter_sinogram.
        Use with caution.

        Raises ValueError when dim 2 is not ``args.height`` (previously this printed a
        message and then failed with UnboundLocalError).
        '''
        if sinogram.size()[2] != self.args.height:
            raise ValueError(f'sinogram dimension wrong for filter! received {tuple(sinogram.size())}')
        normalized_sinogram = normalize(sinogram, rto=(0, 1))
        filtered_sinogram = self.radon.filter_sinogram(
            normalized_sinogram.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)  # 320 x 192
        return normalize(filtered_sinogram, rto=(-1, 1))

    def criterionP(self, Gx, y):
        """Perceptual loss: summed MSE between ``netLoss`` features of ``Gx`` and ``y``."""
        y_features = self.netLoss(y)
        Gx_features = self.netLoss(Gx)
        loss = 0.0
        for j in range(len(y_features)):
            loss += self.criterionL2(Gx_features[j], y_features[j][:y.shape[0]])
        return loss

    def criterionDP(self, Gx_features, y_features):
        """Discriminator-feature loss: summed MSE over paired feature maps."""
        loss = 0.0
        for j in range(len(y_features)):
            loss += self.criterionL2(Gx_features[j], y_features[j])
        return loss

    def _select_discriminator(self, mode):
        """Return ``(discriminator, optimizer)`` for mode 'G' (global) or 'L' (local)."""
        if mode == 'G':
            return self.netDG, self.optimizerDG
        if mode == 'L':
            return self.netDL, self.optimizerDL
        # previously printed 'wrong mode!' and then crashed on an undefined name
        raise ValueError(f"wrong mode! expected 'G' or 'L', got {mode!r}")

    def train_D(self, Gx, y, mode):
        ''' mode is G/L. '''
        netD, optimizer = self._select_discriminator(mode)
        netD.zero_grad()
        # Loss_D: L_D = -(log(D(y)) + log(1 - D(G(x))))
        # train with fake (detached so no gradient reaches the generator)
        D_Gx = netD(Gx.detach())[-1]
        errD_fake = self.criterionGAN(D_Gx, False)
        # train with real
        D_y = netD(y)[-1]
        errD_real = self.criterionGAN(D_y, True)
        # backprop
        errD = (errD_real + errD_fake) * 0.5
        errD.backward()
        optimizer.step()
        self.err['errD' + mode] = errD.item()

    def train_G(self, Gx, y, filtered_Gx, filtered_y, mode):
        ''' mode is G/L. '''
        netD, _ = self._select_discriminator(mode)
        self.netG.zero_grad()
        # Loss_G_GAN: L_G = -log(D(G(x))) -- fool the discriminator
        Gx_features = netD(Gx)
        errG_GAN = self.criterionGAN(Gx_features[-1], True)
        # Loss_G_C: L_C = ||y - G(x)||_1
        errG_C = self.criterionL1(Gx, y) * 50
        # Loss_G_P: perceptual loss (VGG features or discriminator features)
        if self.args.mode == 'vgg':
            errG_P = self.criterionP(Gx, y) * 20
        elif self.args.mode == 'DP':
            y_features = netD(y)
            errG_P = self.criterionDP(Gx_features[:-1], y_features[:-1]) * 20
        else:
            errG_P = torch.tensor(0).to(self.device)
        # Loss_G_F: L1 loss between ramp-filtered sinograms
        errG_F = self.criterionL1(filtered_Gx, filtered_y) * 50
        # backprop
        errG = errG_GAN + errG_C + errG_F + errG_P
        errG.backward()
        self.optimizerG.step()
        self.err['errG' + mode + '_GAN'] = errG_GAN.item()
        self.err['errG' + mode + '_C'] = errG_C.item()
        self.err['errG' + mode + '_F'] = errG_F.item()
        self.err['errG' + mode + '_P'] = errG_P.item()

    def log(self, epoch, i):
        """Print all current losses for batch ``i`` of ``epoch``."""
        print(f'[{epoch}/{self.args.epochs}][{i}/{len(self.dataloader)}] '
              f'LossDG: {self.err["errDG"]:.4f} '
              f'LossGG_GAN: {self.err["errGG_GAN"]:.4f} '
              f'LossGG_C: {self.err["errGG_C"]:.4f} '
              f'LossGG_F: {self.err["errGG_F"]:.4f} '
              f'LossGG_P: {self.err["errGG_P"]:.4f} '
              f'LossDL: {self.err["errDL"]:.4f} '
              f'LossGL_GAN: {self.err["errGL_GAN"]:.4f} '
              f'LossGL_C: {self.err["errGL_C"]:.4f} '
              f'LossGL_F: {self.err["errGL_F"]:.4f} '
              f'LossGL_P: {self.err["errGL_P"]:.4f} ')

    def log2file(self, fn, epoch, runtime):
        """Append ``[epoch, runtime, *losses]`` as one CSV row to ``fn``."""
        # list(...) -- the original list[...] built a typing alias and raised TypeError
        new_row = [epoch, runtime] + list(self.err.values())
        with open(fn, 'a+', newline='') as write_obj:
            csv_writer = writer(write_obj)
            csv_writer.writerow(new_row)

    def train(self):
        """Alternate global and local (patch) discriminator/generator updates for each batch."""
        print(f'''Starting training:
    Epochs: {self.args.epochs}
    Batch size: {self.args.batchSize}
    Learning rate: {self.args.lr}
    Checkpoints: {self.save_cp}
    Device: {self.device.type}
''')
        for epoch in range(self.start_epoch, self.args.epochs):
            self.D_epochs = 1  # Adjust if you want
            print('D is trained ', str(self.D_epochs), 'times in this epoch.')
            start = time.time()  # log start time
            i = None  # stays None if the dataloader yields no batches
            for i, data in enumerate(self.dataloader):
                y = data[0].to(self.device)  # 320 x 180
                # forward
                if self.args.twoends:
                    y = self.append_twoends(y)  # 320 x 192
                filtered_y = self.ramp_module(y)  # 320 x 192, normalized to -1~1
                x = self.gen_sparse(y)  # 320 x 25

                # Train Global
                Gx = self.netG(x)
                filtered_Gx = self.ramp_module(Gx)  # 320 x 192, normalized to -1~1
                # Train D
                set_requires_grad(self.netDG, True)
                for _ in range(self.D_epochs):
                    self.train_D(Gx, y, mode='G')
                # Train G (D requires no gradients when optimizing G)
                set_requires_grad(self.netDG, False)
                self.train_G(Gx, y, filtered_Gx, filtered_y, mode='G')

                # Train Local
                Gx = self.netG(x)
                filtered_Gx = self.ramp_module(Gx)  # 320 x 192, normalized to -1~1
                patch_area = gen_hole_area((y.shape[3] // 4, y.shape[2] // 4), (y.shape[3], y.shape[2]))
                Gx_patch = crop(Gx, patch_area)
                y_patch = crop(y, patch_area)
                filtered_y_patch = crop(filtered_y, patch_area)
                filtered_Gx_patch = crop(filtered_Gx, patch_area)
                # Train D
                set_requires_grad(self.netDL, True)
                for _ in range(self.D_epochs):
                    self.train_D(Gx_patch, y_patch, mode='L')
                # Train G (D requires no gradients when optimizing G)
                set_requires_grad(self.netDL, False)
                self.train_G(Gx_patch, y_patch, filtered_Gx_patch, filtered_y_patch, mode='L')

                if i % 100 == 0:
                    self.log(epoch, i)

            end = time.time()  # log end time
            # log2file now works; re-enable for per-epoch CSV rows:
            # self.log2file(os.path.join(self.args.outdir, self.args.log_fn+'.csv'), epoch , str(end-start))
            if i is None:
                # Nothing was trained: there are no losses to log and no Gx to save.
                print(f'Epoch {epoch}: dataloader produced no batches; skipping logging and checkpoints.')
                continue
            # Log
            self.log(epoch, i)
            if self.save_cp:
                torch.save(self.netG.state_dict(), f'{self.args.outdir}/ckpt/G_epoch{epoch}.pth')
                torch.save(self.netDG.state_dict(), f'{self.args.outdir}/ckpt/DG_epoch{epoch}.pth')
                torch.save(self.netDL.state_dict(), f'{self.args.outdir}/ckpt/DL_epoch{epoch}.pth')
            vutils.save_image(Gx.detach(), '%s/impainted_samples_epoch_%03d.png' % (self.args.outdir, epoch), normalize=True)
class Predict():
    """Inpaint a single sparse-view sinogram with a trained UNet and reconstruct it by FBP."""

    def __init__(self, args, image):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        if args.twoends:
            factor = 192 / (args.angles + 2)  # 7.68
        else:
            factor = 180 / args.angles  # 7.826086956521739
        self.net = UNet(input_nc=1, output_nc=1, scale_factor=factor).to(self.device)
        self.net = nn.DataParallel(self.net)
        pathG = os.path.join(args.ckpt)
        self.net.load_state_dict(torch.load(pathG, map_location=self.device))
        self.net.eval()
        self.image = image.to(self.device)
        self.twoends = args.twoends
        self.mask = self.gen_mask().to(self.device)  # gen_mask reads self.twoends
        # Radon Operator
        angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, angles, clip_to_circle=True)

    def gen_mask(self):
        """Angle mask: 1 on every 8th of 180 angles; padded to 192 entries when ``twoends``."""
        mask = torch.zeros(180)
        mask[::8].fill_(1)  # 180
        if self.twoends:
            mask = torch.cat((mask[-6:], mask, mask[:6]), 0)  # 192
        return mask

    def append_twoends(self, y):
        """Pad the angle axis with 6 views at each end, flipped along dim 2."""
        front = torch.flip(y[:, :, :, :6], [2])
        back = torch.flip(y[:, :, :, -6:], [2])
        return torch.cat((back, y, front), 3)

    def gen_sparse(self, y):
        """Select the sampled angles (last axis) of ``y``."""
        return y[:, :, :, self.mask == 1]

    def crop_sinogram(self, x):
        """Remove the 6-view padding added by ``append_twoends``."""
        return x[:, :, :, 6:-6]

    def inpaint(self, reconstruction_path='result_reconstruction.png',
                sinogram_path='result_sinogram.png'):
        """Inpaint ``self.image``, run FBP, and save both results as images.

        Args:
            reconstruction_path: output file for the FBP reconstruction.
            sinogram_path: output file for the inpainted sinogram.
        """
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            y = self.image  # 320 x 180
            # Two-Ends Preprocessing
            if self.twoends:
                y = self.append_twoends(y)  # 320 x 192
            # Generate Sparse-view Image, forward model
            x = self.gen_sparse(y)
            Gx = self.net(x)
            # Crop Two-Ends
            if self.twoends:
                Gx = self.crop_sinogram(Gx)
            # FBP Reconstruction
            Gx = normalize(Gx)  # 0~1
            fbp_Gx = self.radon.backprojection(
                self.radon.filter_sinogram(Gx.permute(0, 1, 3, 2)))
        # Save Results
        vutils.save_image(fbp_Gx, reconstruction_path, normalize=True)
        vutils.save_image(Gx, sinogram_path, normalize=True)
def shrink(a, b):
    # Soft-thresholding: reduce |a| by b, clamp at zero, and restore the sign of a.
    return (torch.abs(a) - b).clamp_min(0) * torch.sign(a)


batch_size = 1
n_angles = 512
image_size = 512

img = np.load("phantom.npy")
device = torch.device('cuda')

# instantiate Radon transform
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles)

# Replicate the 512x512 phantom 4 times and arrange the copies as a (2, 2, 512, 512) batch.
x = torch.FloatTensor(img).to(device).view(1, 512, 512)
x = torch.cat([x] * 4, dim=0).view(2, 2, 512, 512)
print(x.size())
y = radon.forward(x)

# CG(radon, 1.0, 0.0, torch.zeros_like(x), radon.backward(y))
# rec = cgne(radon, torch.zeros_like(x), y, tol=1e-2)

s = time.time()
for _ in range(1):
    with torch.no_grad():
        # NOTE(review): presumably conjugate gradient on the normal equations
        # backward(forward(z)) = backward(y), starting from zeros -- confirm `cg`'s contract.
        rec, values = cg(lambda z: radon.backward(radon.forward(z)), torch.zeros_like(x), radon.backward(y),
                         callback=lambda x, r: torch.norm(
            # NOTE(review): the snippet is cut off here; the rest of this call is missing
def shrink(a, b):
    # Soft-thresholding: reduce |a| by b, clamp at zero, and restore the sign of a.
    return (torch.abs(a) - b).clamp_min(0) * torch.sign(a)


batch_size = 1
n_angles = 512 // 4
image_size = 512

img = np.load("phantom.npy")
device = torch.device('cuda')

# instantiate Radon transform
# limited-angle setup: 128 views spanning [0, pi/4)
angles = np.linspace(0, np.pi / 4, n_angles, endpoint=False)
radon = Radon(image_size, angles)
shearlet = Shearlet(512, 512, [0.5] * 5, cache=None)  # ".cache")

with torch.no_grad():
    x = torch.FloatTensor(img).reshape(1, image_size, image_size).to(device)
    sinogram = radon.forward(x)
    bp = radon.backward(sinogram, extend=False)

    # f, values = CG(radon, 1.0 / 512**2, 0.0001, bp.clone(), bp)
    #
    # print(torch.norm(x - f)/torch.norm(x))

    # shearlet coefficients of the backprojection (used to shape the auxiliary variables)
    sc = shearlet.forward(bp)

    # penalty parameters
    p_0 = 0.02
    p_1 = 0.1
    # per-scale threshold weights, reshaped to broadcast over the coefficient axis
    w = 3**shearlet.scales / 400
    w = w.view(1, -1, 1, 1).to(device)
def _benchmark_task(label, astra_fn, radon_fn, args, device):
    """Time one task for Astra (optional) and torch-radon in float32 and float16; print speedups.

    A fresh batch of random images is generated for the task.

    Returns:
        ``(astra_time, radon_time, radon_half_time)``. ``astra_time`` is NaN when
        ``astra_fn`` is None, i.e. there is no Astra baseline for this task.
    """
    print(label)
    x = generate_random_images(args.batch_size, args.image_size)
    dx = torch.FloatTensor(x).to(device)
    if astra_fn is None:
        astra_time = float("nan")
    else:
        astra_time = benchmark_function(astra_fn, dx, args.samples, args.warmup)
    radon_time = benchmark_function(radon_fn, dx, args.samples, args.warmup, sync=True)
    radon_half_time = benchmark_function(radon_fn, dx.half(), args.samples, args.warmup, sync=True)

    if astra_fn is None:
        print("Speedup: n/a (no Astra baseline for this task)")
    else:
        print("Speedup:", astra_time / radon_time)
        print("Speedup half-precision:", astra_time / radon_half_time)
    print()
    return astra_time, radon_time, radon_half_time


def main():
    """Benchmark torch-radon against the Astra Toolbox and plot images/second per task."""
    parser = argparse.ArgumentParser(
        description='Benchmark and compare with Astra Toolbox')
    parser.add_argument('--task', default="all")
    parser.add_argument('--image-size', default=256, type=int)
    parser.add_argument('--angles', default=-1, type=int)
    parser.add_argument('--batch-size', default=32, type=int)
    parser.add_argument('--samples', default=50, type=int)
    parser.add_argument('--warmup', default=10, type=int)
    parser.add_argument('--output', default="")
    parser.add_argument('--circle', action='store_true')

    args = parser.parse_args()
    if args.angles == -1:
        args.angles = args.image_size

    device = torch.device("cuda")
    angles = np.linspace(0, 2 * np.pi, args.angles, endpoint=False).astype(np.float32)

    radon = Radon(args.image_size, angles, clip_to_circle=args.circle)
    radon_fb = RadonFanbeam(args.image_size, angles, args.image_size, clip_to_circle=args.circle)

    astra_pw = AstraParallelWrapper(angles, args.image_size)
    # NOTE(review): built but unused while the Astra fanbeam timings stay disabled
    astra_fw = AstraFanbeamWrapper(angles, args.image_size)

    if args.task == "all":
        tasks = ["forward", "backward", "fanbeam forward", "fanbeam backward"]
    elif args.task == "shearlet":
        benchmark_shearlet(args)
        return
    else:
        tasks = [args.task]

    # task name -> (banner, Astra callable or None, torch-radon callable).
    # The Astra fanbeam calls were disabled, which previously left the fanbeam tasks
    # reusing a stale `astra_time` (or raising NameError when run alone); they now
    # report NaN for Astra instead.
    # NOTE(review): the backward tasks feed a random *image* as the sinogram; the
    # shapes only agree when the sinogram shape equals the image shape (e.g. the
    # default angles == image_size) -- confirm.
    task_table = {
        "forward": ("Benchmarking forward from device",
                    lambda y: astra_pw.forward(y), lambda y: radon.forward(y)),
        "backward": ("Benchmarking backward from device",
                     lambda y: astra_pw.backward(y), lambda y: radon.backward(y)),
        "fanbeam forward": ("Benchmarking fanbeam forward",
                            None, lambda y: radon_fb.forward(y)),
        "fanbeam backward": ("Benchmarking fanbeam backward",
                             None, lambda y: radon_fb.backprojection(y)),
    }

    astra_fps = []
    radon_fps = []
    radon_half_fps = []
    for name, (label, astra_fn, radon_fn) in task_table.items():
        if name not in tasks:
            continue
        astra_time, radon_time, radon_half_time = _benchmark_task(label, astra_fn, radon_fn, args, device)
        astra_fps.append(args.batch_size / astra_time)
        radon_fps.append(args.batch_size / radon_time)
        radon_half_fps.append(args.batch_size / radon_half_time)

    title = f"Image size {args.image_size}x{args.image_size}, {args.angles} angles and batch size {args.batch_size} on a {torch.cuda.get_device_name(0)}"

    plot(tasks, astra_fps, radon_fps, radon_half_fps, title)
    if args.output:
        plt.savefig(args.output, dpi=300)
    else:
        plt.show()
# Example: filtered backprojection (FBP) of the phantom with a parallel-beam Radon transform.
import matplotlib.pyplot as plt
import numpy as np
import torch

from utils import show_images

from torch_radon import Radon

device = torch.device('cuda')

img = np.load("phantom.npy")
image_size = img.shape[0]
n_angles = image_size  # one view per image row

# Instantiate Radon transform. clip_to_circle should be True when using filtered backprojection.
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles, clip_to_circle=True)

with torch.no_grad():
    x = torch.FloatTensor(img).to(device)

    sinogram = radon.forward(x)
    filtered_sinogram = radon.filter_sinogram(sinogram)
    fbp = radon.backprojection(filtered_sinogram)

# absolute L2 norm of the reconstruction error (not normalised by ||x||)
print("FBP Error", torch.norm(x - fbp).item())

# Show results
titles = [
    "Original Image", "Sinogram", "Filtered Sinogram", "Filtered Backprojection"
]
def main():
    """Benchmark torch_radon against the Astra Toolbox and plot the throughput.

    Command-line options select the task(s) ("forward", "backward",
    "fanbeam forward", "fanbeam backward" or "all"), the image size, the
    number of angles (default: one per image row), the batch size and the
    benchmark sample / warm-up counts. For every task, throughput in images
    per second is collected for Astra, torch_radon in float32 and torch_radon
    in float16, and passed to ``plot``. The figure is saved to ``--output`` if
    given, otherwise shown.
    """
    parser = argparse.ArgumentParser(description='Benchmark and compare with Astra Toolbox')
    parser.add_argument('--task', default="all")
    parser.add_argument('--image-size', default=256, type=int)
    parser.add_argument('--angles', default=-1, type=int)
    parser.add_argument('--batch-size', default=32, type=int)
    parser.add_argument('--samples', default=50, type=int)
    parser.add_argument('--warmup', default=10, type=int)
    parser.add_argument('--output', default="")
    parser.add_argument('--circle', action='store_true')
    args = parser.parse_args()

    # -1 means "as many angles as image rows".
    if args.angles == -1:
        args.angles = args.image_size

    device = torch.device("cuda")
    angles = np.linspace(0, 2 * np.pi, args.angles, endpoint=False).astype(np.float32)
    radon = Radon(args.image_size, angles, clip_to_circle=args.circle)
    radon_fb = RadonFanbeam(args.image_size, angles, args.image_size, clip_to_circle=args.circle)
    astra = AstraWrapper(angles)

    if args.task == "all":
        tasks = ["forward", "backward", "fanbeam forward", "fanbeam backward"]
    else:
        tasks = [args.task]

    astra_fps = []
    radon_fps = []
    radon_half_fps = []

    def bench_radon(fn, dx):
        """Time ``fn`` on device tensor ``dx`` in float32 and in float16."""
        radon_time = benchmark_function(fn, dx, args.samples, args.warmup, sync=True)
        radon_half_time = benchmark_function(fn, dx.half(), args.samples, args.warmup, sync=True)
        return radon_time, radon_half_time

    def record(astra_time, radon_time, radon_half_time):
        """Append images/second for one task; ``astra_time=None`` records 0 for Astra."""
        astra_fps.append(0.0 if astra_time is None else args.batch_size / astra_time)
        radon_fps.append(args.batch_size / radon_time)
        radon_half_fps.append(args.batch_size / radon_half_time)

    if "forward" in tasks:
        print("Benchmarking forward from device")
        x = generate_random_images(args.batch_size, args.image_size)
        dx = torch.FloatTensor(x).to(device)
        astra_time = benchmark_function(lambda y: astra.forward(y), x, args.samples, args.warmup)
        radon_time, radon_half_time = bench_radon(lambda y: radon.forward(y), dx)
        record(astra_time, radon_time, radon_half_time)
        print(astra_time, radon_time, radon_half_time)
        astra.clean()

    if "backward" in tasks:
        print("Benchmarking backward from device")
        x = generate_random_images(args.batch_size, args.image_size)
        # Backprojection consumes sinograms, so time torch_radon on a real
        # sinogram. Timing it on the image only works when the image shape
        # happens to match the sinogram shape.
        dx = radon.forward(torch.FloatTensor(x).to(device))
        pid, sino = astra.forward(x)
        astra_time = benchmark_function(lambda y: astra.backproject(pid, args.image_size, args.batch_size), sino,
                                        args.samples, args.warmup)
        radon_time, radon_half_time = bench_radon(lambda y: radon.backward(y), dx)
        record(astra_time, radon_time, radon_half_time)
        print(astra_time, radon_time, radon_half_time)
        astra.clean()

    if "fanbeam forward" in tasks:
        print("Benchmarking fanbeam forward")
        x = generate_random_images(args.batch_size, args.image_size)
        dx = torch.FloatTensor(x).to(device)
        # No Astra fan-beam counterpart is timed; its throughput is recorded as 0.
        radon_time, radon_half_time = bench_radon(lambda y: radon_fb.forward(y), dx)
        record(None, radon_time, radon_half_time)
        astra.clean()

    if "fanbeam backward" in tasks:
        print("Benchmarking fanbeam backward")
        x = generate_random_images(args.batch_size, args.image_size)
        # As above: backproject an actual fan-beam sinogram, not the image.
        dx = radon_fb.forward(torch.FloatTensor(x).to(device))
        radon_time, radon_half_time = bench_radon(lambda y: radon_fb.backprojection(y), dx)
        record(None, radon_time, radon_half_time)
        astra.clean()

    title = f"Image size {args.image_size}x{args.image_size}, {args.angles} angles and batch size {args.batch_size} on a {torch.cuda.get_device_name(0)}"

    plot(tasks, astra_fps, radon_fps, radon_half_fps, title)
    if args.output:
        plt.savefig(args.output, dpi=300)
    else:
        plt.show()
import numpy as np
import torch
from torch_radon import Radon
from torch_radon.solvers import Landweber
from utils import show_images

batch_size = 1
n_angles = 512

img = np.load("phantom.npy")
# Take the size from the phantom itself so the reshape below matches any
# square phantom instead of assuming 512x512.
image_size = img.shape[0]
device = torch.device('cuda')

# instantiate Radon transform
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles)

with torch.no_grad():
    x = torch.FloatTensor(img).reshape(1, 1, image_size, image_size).to(device)
    sinogram = radon.forward(x)
    filtered_sinogram = radon.filter_sinogram(sinogram)
    # pi / n_angles is the spacing of `angles` over [0, pi): the quadrature
    # weight of the discrete backprojection sum.
    fbp = radon.backprojection(filtered_sinogram, extend=False) * np.pi / n_angles

print("FBP Error", torch.norm(x - fbp).item())

titles = [
    "Original Image", "Sinogram", "Filtered Sinogram", "Filtered Backprojection"
]
# rdx *= (alpha_e - alpha_s)
# rdy *= (alpha_e - alpha_s)
#
# print(rsx, rsy, rsx**2 + rsy**2 - v**2)
# print(rdx, rdy, (rsx+rdx)**2 + (rsy+rdy)**2 - v**2)

# Compare torch_radon's forward projection (Radon) with Astra's on random
# images generated with masked=True: show both sinograms, print the error.
batch_size = 4
image_size = 256
device = torch.device("cuda")
angles = np.linspace(0, 2 * np.pi, 180).astype(np.float32)

astraw = AstraWrapper(angles)
x = generate_random_images(batch_size, image_size, masked=True)
astra_fp_id, astra_fp = astraw.forward(x)

# our implementation
radon = Radon(image_size, angles, clip_to_circle=True)
x = torch.FloatTensor(x).to(device)
our_fp = radon.forward(x)
our_fp_host = our_fp.cpu().numpy()  # host copy, reused for display and error

plt.imshow(astra_fp[0])
plt.figure()
plt.imshow(our_fp_host[0])
plt.show()

print(relative_error(astra_fp, our_fp_host))
# NOTE(review): batch_size and n_angles are used below but not defined in
# this part of the script; they are presumably set earlier — confirm.
image_size = 128
channels = 4

device = torch.device('cuda')
criterion = nn.L1Loss()

# Instantiate a model for the sinogram and one for the image
sino_model = nn.Conv2d(1, channels, 5, padding=2).to(device)
image_model = nn.Conv2d(channels, 1, 3, padding=1).to(device)

# Create blank (all-zero) images. torch.FloatTensor(*sizes) would return
# uninitialised memory, which may hold NaN/inf and poison the loss.
x = torch.zeros(batch_size, 1, image_size, image_size, dtype=torch.float32, device=device)

# instantiate Radon transform
angles = np.linspace(0, np.pi, n_angles)
radon = Radon(image_size, angles)

# forward projection
sinogram = radon.forward(x)

# apply sino_model to sinograms
filtered_sinogram = sino_model(sinogram)

# backprojection
backprojected = radon.backprojection(filtered_sinogram)

# apply image_model to backprojected images
y = image_model(backprojected)

# backward works as usual
loss = criterion(y, x)
loss.backward()
class Predict:
    """Inpaint sparse-view sinograms with a trained UNet and save FBP images.

    The network input is sampled from every 8th of 180 projection angles.
    For each batch from ``dataloader`` the inpainted sinogram is reconstructed
    with filtered backprojection (FBP) at 180, 90 and 45 angles. The original
    23-angle sparse sinogram is reconstructed too, and all results are written
    with ``vutils.save_image``.
    """

    def __init__(self, args, dataloader):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.args = args
        self.dataloader = dataloader

        # Ratio between network output width and input width along the angle
        # axis (192 / 25 with two-ends padding, otherwise 180 / 23).
        if args.twoends:
            factor = 192 / (args.angles + 2)  # 7.68
        else:
            factor = 180 / args.angles  # 7.826086956521739
        self.net = UNet(input_nc=1, output_nc=1, scale_factor=factor).to(self.device)
        self.net = nn.DataParallel(self.net)
        # torch.load unpickles the checkpoint: only load trusted files.
        self.net.load_state_dict(torch.load(args.ckpt, map_location=self.device))
        self.net.eval()
        self.gen_mask()

        # Radon operators for the full angle set and its downsampled variants.
        angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, angles, clip_to_circle=True)
        self.radon23 = Radon(args.height, angles[::8], clip_to_circle=True)
        self.radon45 = Radon(args.height, angles[::4], clip_to_circle=True)
        self.radon90 = Radon(args.height, angles[::2], clip_to_circle=True)

    def gen_mask(self):
        """Build the angle-sampling mask (1 at every 8th of 180 angles).

        Sets ``self.mask`` (192 entries with two-ends padding, else 180; on
        ``self.device``) and ``self.mask_sparse`` (the plain 180-entry mask).
        """
        mask = torch.zeros(180)
        mask[::8].fill_(1)  # 180
        if self.args.twoends:
            # Pad with the same 6 angles append_twoends adds on each side.
            self.mask = torch.cat((mask[-6:], mask, mask[:6]), 0).to(self.device)  # 192
        else:
            self.mask = mask.to(self.device)  # 180
        self.mask_sparse = mask

    def append_twoends(self, y):
        """Wrap 6 angles around each end of ``y`` (last dim: angles).

        The last 6 angles are prepended and the first 6 appended, each flipped
        along dim 2 (presumably the detector axis, using the parallel-beam
        symmetry p(theta + pi, s) = p(theta, -s)).
        """
        front = torch.flip(y[:, :, :, :6], [2])
        back = torch.flip(y[:, :, :, -6:], [2])
        return torch.cat((back, y, front), 3)

    def gen_input(self, y, mask):
        """Keep only the angle columns of ``y`` where ``mask`` is 1."""
        return y[:, :, :, mask == 1]

    def crop_sinogram(self, x):
        """Remove the 6 padded angles on each side added by append_twoends."""
        return x[:, :, :, 6:-6]

    @staticmethod
    def _fbp(radon, sinogram):
        """Filtered backprojection of a sinogram stored with angles in the last dim."""
        # The Radon operator is applied to the transposed sinogram
        # (last two dims swapped).
        return radon.backprojection(radon.filter_sinogram(sinogram.permute(0, 1, 3, 2)))

    def inpaint(self):
        """Run inference over the dataloader and save FBP reconstructions.

        NOTE(review): ``class_name`` and ``fnames`` are not defined in this
        class; they are presumably module-level globals — confirm.
        """
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for i, data in enumerate(self.dataloader):
                y = data[0].to(self.device)  # 320 x 180

                # Two-Ends Preprocessing
                y_in = self.append_twoends(y) if self.args.twoends else y  # 320 x 192 (or 320 x 180)

                # Forward Model
                x = self.gen_input(y_in, self.mask)  # input, 320 x 25 (320 x 23 without two-ends)
                Gx = self.net(x)  # 320 x 192 (or 320 x 180)

                # Crop Two-Ends
                if self.args.twoends:
                    Gx = self.crop_sinogram(Gx)  # 320 x 180

                # FBP Reconstruction
                Gx = normalize(Gx)  # 0~1
                fbp_Gx = self._fbp(self.radon, Gx)  # 320 x 320

                # FBP for downsampled sinograms
                Gx1 = normalize(Gx[:, :, :, ::2])  # 320 x 90, 0~1
                fbp_Gx1 = self._fbp(self.radon90, Gx1)
                Gx2 = normalize(Gx[:, :, :, ::4])  # 320 x 45, 0~1
                fbp_Gx2 = self._fbp(self.radon45, Gx2)
                sparse = normalize(y[:, :, :, ::8])  # 320 x 23, original sparse-view sinogram
                fbp_sparse = self._fbp(self.radon23, sparse)

                print(f'Saving images for batch {i}')
                outputs = ((fbp_Gx, ''), (fbp_Gx1, '_90'), (fbp_Gx2, '_45'), (fbp_sparse, '_23'))
                for j in range(y.size(0)):
                    fname = fnames[i * self.args.bs + j]
                    for image, suffix in outputs:
                        vutils.save_image(
                            image[j, 0], f'{self.args.outdir}{suffix}/{class_name}/{fname}', normalize=True)