Esempio n. 1
0
def bench_fanbeam_backward(task, dtype, device, *bench_args):
    """Benchmark fan-beam backprojection of a random batch's sinogram.

    A random batch of shape ``(batch_size, size, size)`` is forward projected
    once; the timed operation is ``radon.backward`` applied to that sinogram.

    Args:
        task: dict with keys ``num_angles``, ``det_count``,
            ``source_distance``, ``detector_distance``, ``det_spacing``,
            ``batch_size`` and ``size``.
        dtype: torch dtype of the random batch.
        device: torch device of the random batch.
        *bench_args: extra positional arguments forwarded to ``benchmark``.

    Returns:
        Whatever ``benchmark`` returns.
    """
    num_angles = task["num_angles"]
    det_count = task["det_count"]
    source_dist = task["source_distance"]
    det_dist = task["detector_distance"]
    det_spacing = task["det_spacing"]

    x = torch.randn(task["batch_size"],
                    task["size"],
                    task["size"],
                    dtype=dtype,
                    device=device)
    angles = np.linspace(0, np.pi, num_angles, endpoint=False)

    projection = Projection.fanbeam(source_dist, det_dist, det_count,
                                    det_spacing)
    radon = Radon(angles, task["size"], projection)

    # The backward (backprojection) operator consumes sinograms, so it must
    # be benchmarked on the forward projection of x, not on x itself.
    sino = radon.forward(x)

    return benchmark(radon.backward, sino, *bench_args)
Esempio n. 2
0
 def init_radon(self, beam, circle, det_dist):
     """Create the full-view and sparse-view Radon operators.

     Sets ``self.radon`` (all ``self.n_angles`` views), ``self.radon_sparse``
     (every ``self.sample_ratio``-th view) and ``self.n_angles_sparse``.

     Args:
         beam: ``'parallel'`` or ``'fan'``.
         circle: forwarded to the operators as ``clip_to_circle``.
         det_dist: for ``'fan'``, a pair ``(source_distance, det_distance)``;
             unused for ``'parallel'``.

     Raises:
         ValueError: if ``beam`` is neither ``'parallel'`` nor ``'fan'``
             (a subclass of ``Exception``, so existing handlers still apply).
     """
     if beam == 'parallel':
         # n_angles views evenly spaced over [0, pi).
         angles = np.linspace(0, np.pi, self.n_angles, endpoint=False)
         sparse_angles = angles[::self.sample_ratio]
         self.radon = Radon(self.img_size, angles, clip_to_circle=circle)
         self.radon_sparse = Radon(self.img_size,
                                   sparse_angles,
                                   clip_to_circle=circle)
     elif beam == 'fan':
         # Step is (n_angles * pi / 180) / n_angles = 1 degree per view.
         angles = np.linspace(0, self.n_angles / 180 * np.pi, self.n_angles,
                              endpoint=False)
         sparse_angles = angles[::self.sample_ratio]
         fan_kwargs = dict(source_distance=det_dist[0],
                           det_distance=det_dist[1],
                           clip_to_circle=circle,
                           det_count=self.det_size)
         self.radon = RadonFanbeam(self.img_size, angles, **fan_kwargs)
         self.radon_sparse = RadonFanbeam(self.img_size, sparse_angles,
                                          **fan_kwargs)
     else:
         raise ValueError('projection beam type undefined!')
     self.n_angles_sparse = len(sparse_angles)
Esempio n. 3
0
def test_error(device, batch_size, image_size, angles, spacing, clip_to_circle):
    """Compare this library's projections against the Astra toolbox.

    Random images are forward projected and backprojected by both Astra and
    ``Radon``; relative errors must stay below 1e-2 (forward) and 5e-3
    (backprojection).
    """
    images = generate_random_images(batch_size, image_size, masked=clip_to_circle)

    # Reference results computed with Astra.
    reference = AstraWrapper(angles)
    ref_fp_id, ref_fp = reference.forward(images, spacing)
    ref_bp = reference.backproject(ref_fp_id, image_size, batch_size)
    if clip_to_circle:
        ref_bp *= circle_mask(image_size)

    # Results from this library on the same images.
    radon = Radon(image_size, angles, det_spacing=spacing,
                  clip_to_circle=clip_to_circle)
    images_on_device = torch.FloatTensor(images).to(device)
    fp = radon.forward(images_on_device)
    bp = radon.backprojection(fp)

    forward_error = relative_error(ref_fp, fp.cpu().numpy())
    back_error = relative_error(ref_bp, bp.cpu().numpy())

    print(
        f"batch: {batch_size}, size: {image_size}, angles: {len(angles)}, spacing: {spacing}, circle: {clip_to_circle}, forward: {forward_error}, back: {back_error}")
    # TODO better checks
    assert_less(forward_error, 1e-2)
    assert_less(back_error, 5e-3)
Esempio n. 4
0
def bench_parallel_forward(phantom, det_count, num_angles, warmup, repeats):
    """Benchmark the parallel-beam forward projection of ``phantom``.

    ``phantom.size(1)`` is used as the image size and ``num_angles`` views are
    spread evenly over [0, pi).
    """
    angles = np.linspace(0, np.pi, num_angles, endpoint=False)
    radon = Radon(phantom.size(1), angles, det_count)
    return benchmark(radon.forward, phantom, warmup, repeats)
Esempio n. 5
0
class Predict():
    """Sinogram inpainting with a pretrained UNet followed by FBP reconstruction."""

    def __init__(self, args, dataloader):
        """Load the network checkpoint, build the angle mask and Radon operator.

        Args:
            args: options; reads ``ckpt`` (checkpoint path) and ``height``
                here, and ``outdir`` / ``bs`` in ``inpaint``.
            dataloader: iterable of batches; ``data[0]`` is the sinogram.
        """
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.args = args
        self.dataloader = dataloader

        self.net = UNet(input_nc=1, output_nc=1).to(self.device)
        self.net = nn.DataParallel(self.net)

        pathG = os.path.join(args.ckpt)
        self.net.load_state_dict(torch.load(pathG, map_location=self.device))
        self.net.eval()

        self.gen_mask()

        angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, angles, clip_to_circle=True)

    def gen_mask(self):
        """Set ``self.mask``: a length-180 0/1 vector with every 8th entry set."""
        self.mask = torch.zeros(180, device=self.device)
        self.mask[::8].fill_(1)  # 23 of 180 angles

    def gen_x(self, y):
        """Zero all unsampled angles of ``y`` (mask broadcasts over the last dim)."""
        return self.mask * y

    def crop_sinogram(self, x):
        """Drop 6 entries from both ends of the last dimension."""
        return x[:, :, :, 6:-6]

    def overlay(self, Gx, x):
        """Keep measured angles from ``x``; take the rest from the prediction ``Gx``."""
        result = self.mask * x + (1 - self.mask) * Gx
        return result

    def inpaint(self):
        """Inpaint every batch of the dataloader and save its FBP reconstruction.

        Inference runs under ``torch.no_grad()`` so no autograd graph is kept
        for the network, the overlay or the reconstruction.
        """
        with torch.no_grad():
            for i, data in enumerate(self.dataloader):
                y = data[0].to(self.device)  # 320 x 180

                x = self.gen_x(y)  # same shape as y; unsampled angles zeroed
                Gx = self.net(x)

                Gx = self.overlay(Gx, y)

                # FBP
                Gx = normalize(Gx)  # 0~1
                fbp_Gx = self.radon.backprojection(
                    self.radon.filter_sinogram(Gx.permute(0, 1, 3, 2)))

                print(f'Saving images for batch {i}')

                for j in range(y.size()[0]):
                    # NOTE(review): ``class_name`` and ``fnames`` are not
                    # defined in this class; presumably module-level globals.
                    vutils.save_image(
                        fbp_Gx[j, 0],
                        f'{self.args.outdir}/{class_name}/{fnames[i*self.args.bs+j]}',
                        normalize=True)
Esempio n. 6
0
    def __init__(self, args, dataloader):
        """Load the UNet checkpoint and build Radon operators.

        The UNet ``scale_factor`` is derived from ``args.angles`` and
        ``args.twoends``. Radon operators are built for 180 views over
        [0, pi) and for every 8th, 4th and 2nd of those views.
        """
        self.args = args
        self.dataloader = dataloader
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # 192 / 25 = 7.68 with two-ends padding, 180 / 23 = 7.826... without.
        factor = 192 / (args.angles + 2) if args.twoends else 180 / args.angles

        network = UNet(input_nc=1, output_nc=1, scale_factor=factor)
        self.net = nn.DataParallel(network.to(self.device))
        checkpoint_path = os.path.join(args.ckpt)
        self.net.load_state_dict(
            torch.load(checkpoint_path, map_location=self.device))
        self.net.eval()

        self.gen_mask()

        # Radon operators for the full view set and its 1/8, 1/4, 1/2 subsets.
        full_angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, full_angles, clip_to_circle=True)
        self.radon23 = Radon(args.height, full_angles[::8], clip_to_circle=True)
        self.radon45 = Radon(args.height, full_angles[::4], clip_to_circle=True)
        self.radon90 = Radon(args.height, full_angles[::2], clip_to_circle=True)
Esempio n. 7
0
    def test_differentiation(self):
        """Gradients must flow through both forward projection and backprojection."""
        device = torch.device('cuda')
        # torch.FloatTensor(*sizes) returns uninitialized memory that may
        # contain NaN/Inf; use initialized random data instead.
        x = torch.randn(1, 64, 64, device=device, requires_grad=True)
        angles = torch.FloatTensor(
            np.linspace(0, 2 * np.pi, 10).astype(np.float32)).to(device)

        radon = Radon(64, angles)

        # check that backward is implemented for fp and bp
        y = radon.forward(x)
        z = torch.mean(radon.backprojection(y))
        z.backward()
        self.assertIsNotNone(x.grad)
Esempio n. 8
0
 def __init__(self, image_size, n_angles, sample_ratio, device, circle=False):
     """Build full/sparse Radon operators, a Landweber solver and an angle mask.

     Args:
         image_size: image side length passed to ``Radon``.
         n_angles: number of views evenly spaced over [0, pi).
         sample_ratio: the sparse operator keeps every ``sample_ratio``-th view.
         device: device on which the mask is allocated.
         circle: forwarded to ``Radon`` as ``clip_to_circle``.
     """
     self.device = device
     self.image_size = image_size
     self.sample_ratio = sample_ratio
     self.n_angles = n_angles

     angles = np.linspace(0, np.pi, self.n_angles, endpoint=False)
     sparse_angles = angles[::sample_ratio]
     self.radon = Radon(self.image_size, angles, clip_to_circle=circle)
     self.radon_sparse = Radon(self.image_size, sparse_angles, clip_to_circle=circle)
     self.n_angles_sparse = len(sparse_angles)
     self.landweber = Landweber(self.radon)

     # 0/1 mask over the last (presumably angle) axis marking sampled views.
     # Sized by n_angles rather than a fixed 180 so it matches any view count.
     self.mask = torch.zeros((1, 1, 1, self.n_angles), device=device)
     self.mask[:, :, :, ::sample_ratio] = 1
Esempio n. 9
0
    def __init__(self, args, dataloader):
        """Load the pretrained UNet, build the angle mask and a 180-view Radon operator.

        Args:
            args: options; ``ckpt`` is the checkpoint path and ``height`` the
                image size passed to ``Radon``.
            dataloader: stored for later use.
        """
        self.args = args
        self.dataloader = dataloader
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        model = UNet(input_nc=1, output_nc=1).to(self.device)
        self.net = nn.DataParallel(model)
        state = torch.load(os.path.join(args.ckpt), map_location=self.device)
        self.net.load_state_dict(state)
        self.net.eval()

        self.gen_mask()

        view_angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, view_angles, clip_to_circle=True)
Esempio n. 10
0
def test_noise():
    """Check shapes and dtypes produced by the noise-emulation helpers.

    ``emulate_readings`` must return int32 readings with the sinogram's shape,
    and ``readings_lookup`` must map them back to a float32 tensor.
    """
    device = torch.device('cuda')

    # torch.FloatTensor(*sizes) returns uninitialized memory that may hold
    # NaN/Inf; use initialized, finite, non-negative data instead.
    x = torch.rand(3, 5, 64, 64, device=device)
    lookup_table = torch.rand(128, 64, device=device)
    x.requires_grad = True
    angles = torch.FloatTensor(np.linspace(0, 2 * np.pi, 10).astype(np.float32))

    radon = Radon(64, angles)

    sinogram = radon.forward(x)
    assert_equal(sinogram.size(), (3, 5, 10, 64))

    readings = radon.emulate_readings(sinogram, 5, 10.0)
    assert_equal(readings.size(), (3, 5, 10, 64))
    assert_equal(readings.dtype, torch.int32)

    y = radon.readings_lookup(readings, lookup_table)
    assert_equal(y.size(), (3, 5, 10, 64))
    assert_equal(y.dtype, torch.float32)
Esempio n. 11
0
def bench_parallel_backward(task, dtype, device, *bench_args):
    """Benchmark parallel-beam backprojection of a random batch's sinogram.

    ``task`` supplies ``num_angles``, ``det_count``, ``batch_size`` and
    ``size``; extra ``bench_args`` are forwarded to ``benchmark``.
    """
    n_views = task["num_angles"]
    n_detectors = task["det_count"]
    side = task["size"]

    images = torch.randn(task["batch_size"], side, side,
                         dtype=dtype, device=device)
    view_angles = np.linspace(0, np.pi, n_views, endpoint=False)
    radon = Radon(view_angles, side, Projection.parallel_beam(n_detectors))

    # Backprojection is timed on a real sinogram of the random batch.
    sinogram = radon.forward(images)
    return benchmark(radon.backward, sinogram, *bench_args)
Esempio n. 12
0
 def __init__(self, net, args, dataloader, device):
     """Set up networks, optimizers, losses, the CSV log and the Radon operator.

     Args:
         net: indexable ``(G, DG, DL[, loss_net])``; ``net[3]`` is only read
             when ``args.mode == 'vgg'``.
         args: options; reads ``mode``, ``lr``, ``load``, ``outdir``,
             ``log_fn`` and ``height``.
         dataloader: training data iterable (stored as-is).
         device: device for the loss modules and the angle mask.
     """
     self.netG = net[0]
     self.netDG = net[1]
     self.netDL = net[2]
     if args.mode == 'vgg':
         self.netLoss = net[3]
     self.optimizerG = optim.Adam(self.netG.parameters(), lr=args.lr, betas=(0.5, 0.999))
     self.optimizerDG = optim.Adam(self.netDG.parameters(), lr=args.lr, betas=(0.5, 0.999))
     self.optimizerDL = optim.Adam(self.netDL.parameters(), lr=args.lr, betas=(0.5, 0.999))

     self.dataloader = dataloader
     self.device = device
     self.args = args
     self.save_cp = True
     # Resume right after the loaded epoch, or start from scratch.
     self.start_epoch = args.load + 1 if args.load >= 0 else 0
     self.mask = self.gen_mask().to(self.device)

     self.criterionL1 = torch.nn.L1Loss().to(self.device)
     self.criterionL2 = torch.nn.MSELoss().to(self.device)
     self.criterionGAN = GANLoss('vanilla').to(self.device)

     err_list = ["errDG", "errDL",
                 "errGG_GAN", "errGG_C", "errGG_F", "errGG_P",
                 "errGL_GAN", "errGL_C", "errGL_F", "errGL_P"]
     self.err = dict.fromkeys(err_list, None)

     if self.save_cp:
         try:
             ckpt_dir = os.path.join(args.outdir, 'ckpt')
             if not os.path.exists(ckpt_dir):
                 os.makedirs(ckpt_dir)
                 print('Created checkpoint directory')
             if args.load < 0:  # New log file
                 with open(os.path.join(args.outdir, args.log_fn + '.csv'), 'w', newline='') as f:
                     csvwriter = writer(f)
                     csvwriter.writerow(["epoch", "runtime"] + err_list)
         except OSError as err:
             # Best effort: training may continue, but the failure is reported
             # instead of being silently swallowed.
             print(f'Could not prepare output directory: {err}')

     angles = np.linspace(0, np.pi, 180, endpoint=False)
     self.radon = Radon(args.height, angles, clip_to_circle=True)
Esempio n. 13
0
def test_half(device, batch_size, image_size, angles, spacing, det_count,
              clip_to_circle):
    """Half-precision projections must match single precision to within 1e-3.

    ``det_count`` is given as a multiple of ``image_size`` and converted to
    an absolute detector count.
    """
    det_count = int(det_count * image_size)
    radius = det_count / 2.0 if clip_to_circle else -1
    images = generate_random_images(batch_size, image_size, radius)

    radon = Radon(image_size, angles, det_spacing=spacing,
                  det_count=det_count, clip_to_circle=clip_to_circle)
    images = torch.FloatTensor(images).to(device)

    # Single-precision reference, then the same pipeline in half precision.
    sino_single = radon.forward(images)
    bp_single = radon.backprojection(sino_single)
    sino_half = radon.forward(images.half())
    bp_half = radon.backprojection(sino_half)

    forward_error, back_error = (
        relative_error(ref.cpu().numpy(), approx.cpu().numpy())
        for ref, approx in ((sino_single, sino_half), (bp_single, bp_half)))

    print(
        f"batch: {batch_size}, size: {image_size}, angles: {len(angles)}, spacing: {spacing}, circle: {clip_to_circle}, forward: {forward_error}, back: {back_error}"
    )

    assert_less(forward_error, 1e-3)
    assert_less(back_error, 1e-3)
Esempio n. 14
0
def main():
    """Limited-angle reconstruction of a disc phantom with an ADMM-style loop.

    Each of 100 outer iterations solves a linear system with ``cg`` (up to
    50 steps), soft-thresholds shearlet coefficients with ``shrink`` and
    clamps the image to be non-negative. Prints the running time and the
    final relative error in percent.
    """
    n_angles = 100
    image_size = 512
    circle_radius = 100
    source_dist = 1.5 * image_size  # used only by the fan-beam alternative below
    batch_size = 1
    n_scales = 5

    # n_angles views over [-50, 50) degrees, converted to radians.
    angles = (np.linspace(0., 100., n_angles, endpoint=False) -
              50.0) / 180.0 * np.pi

    phantom = np.zeros((image_size, image_size), dtype=np.float32)
    phantom[circle_mask(image_size, circle_radius)] = 1.0

    radon = Radon(image_size,
                  angles)  # RadonFanbeam(image_size, angles, source_dist)
    shearlet = ShearletTransform(image_size, image_size, [0.5] * n_scales)

    target = torch.from_numpy(phantom).cuda()
    target = target.view(1, image_size, image_size).repeat(batch_size, 1, 1)
    sinogram = radon.forward(target)

    bp = radon.backward(sinogram)
    sc = shearlet.forward(bp)

    # Penalty parameters and per-scale shrinkage weights.
    p_0 = 0.02
    p_1 = 0.1
    w = (3**shearlet.scales / 400).view(1, -1, 1, 1).cuda()

    u_2, z_2 = torch.zeros_like(bp), torch.zeros_like(bp)
    u_1, z_1 = torch.zeros_like(sc), torch.zeros_like(sc)
    f = torch.zeros_like(bp)

    def system_operator(v):
        # Left-hand side of the f-update: p_0 * A^T A v + (1 + p_1) * v.
        return p_0 * radon.backward(radon.forward(v)) + (1 + p_1) * v

    errors = []
    start_time = time.time()
    for _ in range(100):
        rhs = p_0 * bp + p_1 * shearlet.backward(z_1 - u_1) + (z_2 - u_2)
        f = cg(system_operator, f.clone(), rhs, max_iter=50)
        sh_f = shearlet.forward(f)

        z_1 = shrink(sh_f + u_1, p_0 / p_1 * w)
        z_2 = (f + u_2).clamp_min(0)
        u_1 = u_1 + sh_f - z_1
        u_2 = u_2 + f - z_2

        errors.append(
            (torch.norm(target[0] - f[0]) / torch.norm(target[0])).item())

    runtime = time.time() - start_time
    print("Running time:", runtime)
    print("Running time per image:", runtime / batch_size)
    print("Relative error: ", 100 * errors[-1])
    def __init__(self, args, image):
        """Load the UNet checkpoint, move ``image`` to the device and build the mask.

        ``args.angles`` and ``args.twoends`` determine the UNet scale factor;
        ``args.height`` is the image size of the 180-view Radon operator.
        """
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # 192 / 25 = 7.68 with two-ends padding, 180 / 23 = 7.826... without.
        factor = 192 / (args.angles + 2) if args.twoends else 180 / args.angles

        network = UNet(input_nc=1, output_nc=1, scale_factor=factor)
        self.net = nn.DataParallel(network.to(self.device))
        self.net.load_state_dict(
            torch.load(os.path.join(args.ckpt), map_location=self.device))
        self.net.eval()

        self.image = image.to(self.device)
        # twoends must be set before gen_mask(), which reads it.
        self.twoends = args.twoends
        self.mask = self.gen_mask().to(self.device)

        # Radon operator over 180 views in [0, pi).
        view_angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, view_angles, clip_to_circle=True)
Esempio n. 16
0
class Operators():
    """Radon-transform operators for full-view and sparse-view sinograms.

    Sinogram arguments of the methods below are 4-D tensors laid out as
    ``(N, C, image_size, angles)`` (angle axis last); they are permuted to
    put the angle axis before the detector axis when passed to ``Radon``.
    """

    def __init__(self, image_size, n_angles, sample_ratio, device, circle=False):
        """Build the operators, a Landweber solver and the sampling mask.

        Args:
            image_size: image side length passed to ``Radon``.
            n_angles: number of views evenly spaced over [0, pi).
            sample_ratio: the sparse operator keeps every
                ``sample_ratio``-th view.
            device: device for the mask and the step-size estimate.
            circle: forwarded to ``Radon`` as ``clip_to_circle``.
        """
        self.device = device
        self.image_size = image_size
        self.sample_ratio = sample_ratio
        self.n_angles = n_angles

        angles = np.linspace(0, np.pi, self.n_angles, endpoint=False)
        sparse_angles = angles[::sample_ratio]
        self.radon = Radon(self.image_size, angles, clip_to_circle=circle)
        self.radon_sparse = Radon(self.image_size, sparse_angles, clip_to_circle=circle)
        self.n_angles_sparse = len(sparse_angles)
        self.landweber = Landweber(self.radon)

        # Sized by n_angles so the mask matches the angle axis for any view count.
        self.mask = self._angle_mask(self.n_angles, sample_ratio, device)

    @staticmethod
    def _angle_mask(n_angles, sample_ratio, device):
        """Return a ``(1, 1, 1, n_angles)`` 0/1 mask selecting every ``sample_ratio``-th angle."""
        mask = torch.zeros((1, 1, 1, n_angles), device=device)
        mask[:, :, :, ::sample_ratio] = 1
        return mask

    def forward_adjoint(self, input):
        """Apply X^T (backprojection) to a full- or sparse-view sinogram.

        Sparse-view backprojections are rescaled by
        ``n_angles / n_angles_sparse`` to compensate for the fewer views.

        Raises:
            ValueError: if the angle dimension matches neither view count.
        """
        n_views = input.size()[3]
        if n_views == self.n_angles:
            return self.radon.backprojection(input.permute(0, 1, 3, 2))
        if n_views == self.n_angles_sparse:
            sparse_bp = self.radon_sparse.backprojection(input.permute(0, 1, 3, 2))
            return sparse_bp / self.n_angles_sparse * self.n_angles  # scale the angles
        raise ValueError(f'forward_adjoint input dimension wrong! received {input.size()}.')

    def forward_gramian(self, input):
        """Apply X^T X (projection followed by backprojection) to an image.

        Raises:
            ValueError: if ``input.size()[2]`` differs from ``image_size``.
        """
        if input.size()[2] != self.image_size:
            raise ValueError(f'forward_gramian input dimension wrong! received {input.size()}.')

        sinogram = self.radon.forward(input)
        return self.radon.backprojection(sinogram)

    def undersample_model(self, input):
        """Corruption model: keep every ``sample_ratio``-th angle of the sinogram."""
        return input[:, :, :, ::self.sample_ratio]

    def FBP(self, input):
        """Filtered backprojection of a full-view sinogram (expected range (0, 1)).

        Raises:
            ValueError: if the sinogram is not ``(..., image_size, n_angles)``.
        """
        if input.size()[2] != self.image_size or input.size()[3] != self.n_angles:
            raise ValueError(f'FBP input dimension wrong! received {input.size()}.')
        filtered_sinogram = self.radon.filter_sinogram(input.permute(0, 1, 3, 2))
        return self.radon.backprojection(filtered_sinogram)

    def estimate_eta(self):
        """Estimate the Landweber step size eta as a float32 tensor on ``device``."""
        eta = self.landweber.estimate_alpha(self.image_size, self.device)
        return torch.tensor(eta, dtype=torch.float32, device=self.device)
Esempio n. 17
0
    def test_shapes(self):
        """
        Forward projection and backprojection must preserve leading
        batch/channel dimensions, and also work with none.
        """
        device = torch.device('cuda')
        angles = torch.FloatTensor(
            np.linspace(0, 2 * np.pi, 10).astype(np.float32)).to(device)
        radon = Radon(64, angles)

        # (image shape, expected sinogram shape): two batch dims, then none.
        cases = [((2, 3, 64, 64), (2, 3, 10, 64)),
                 ((64, 64), (10, 64))]
        for image_shape, sino_shape in cases:
            image = torch.FloatTensor(*image_shape).to(device)
            sino = radon.forward(image)
            self.assertEqual(sino.size(), sino_shape)
            self.assertEqual(radon.backprojection(sino).size(), image_shape)
Esempio n. 18
0
import numpy as np
import torch

from torch_radon import Radon
from torch_radon.solvers import Landweber
from utils import show_images

# --- Landweber reconstruction demo: configuration ---
# NOTE(review): batch_size is not used in this part of the script.
batch_size = 1
n_angles = 512
image_size = 512

# Phantom image; the code that follows reshapes it to
# (1, 1, image_size, image_size), so it presumably has image_size**2 pixels.
img = np.load("phantom.npy")
device = torch.device('cuda')

# instantiate Radon transform: n_angles views evenly spaced over [0, pi)
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles)

# Iterative solver from torch_radon.solvers built on the Radon operator.
landweber = Landweber(radon)

# estimate step size
alpha = landweber.estimate_alpha(image_size, device)
with torch.no_grad():
    x = torch.FloatTensor(img).reshape(1, 1, image_size, image_size).to(device)
    sinogram = radon.forward(x)

    # use landweber iteration to reconstruct the image
    # values returned by 'callback' are stored inside 'progress'
    reconstruction, progress = landweber.run(
        torch.zeros(x.size(), device=device),
        sinogram,
Esempio n. 19
0
class Inpaint():
    """Adversarial training of a sinogram-inpainting generator.

    ``net`` is ``(G, DG, DL[, loss_net])``. Each batch first trains the
    ``DG``/G pair on the full sinogram ("Global"), then the ``DL``/G pair
    on a cropped patch ("Local").
    """

    def __init__(self, net, args, dataloader, device):
        """Set up networks, optimizers, losses, the CSV log and the Radon operator.

        Args:
            net: indexable ``(G, DG, DL[, loss_net])``; ``net[3]`` is only
                read when ``args.mode == 'vgg'``.
            args: training options (``mode``, ``lr``, ``load``, ``outdir``,
                ``log_fn``, ``height``, ``twoends``, ``epochs``, ...).
            dataloader: iterable of batches; ``data[0]`` is the sinogram.
            device: device for losses, mask and batches.
        """
        self.netG = net[0]
        self.netDG = net[1]
        self.netDL = net[2]
        if args.mode == 'vgg':
            self.netLoss = net[3]
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=args.lr, betas=(0.5, 0.999))
        self.optimizerDG = optim.Adam(self.netDG.parameters(), lr=args.lr, betas=(0.5, 0.999))
        self.optimizerDL = optim.Adam(self.netDL.parameters(), lr=args.lr, betas=(0.5, 0.999))

        self.dataloader = dataloader
        self.device = device
        self.args = args
        self.save_cp = True
        # Resume right after the loaded epoch, or start from scratch.
        self.start_epoch = args.load + 1 if args.load >= 0 else 0
        self.mask = self.gen_mask().to(self.device)

        self.criterionL1 = torch.nn.L1Loss().to(self.device)
        self.criterionL2 = torch.nn.MSELoss().to(self.device)
        self.criterionGAN = GANLoss('vanilla').to(self.device)

        # Insertion order defines the CSV column order used by log2file.
        err_list = ["errDG", "errDL",
                    "errGG_GAN", "errGG_C", "errGG_F", "errGG_P",
                    "errGL_GAN", "errGL_C", "errGL_F", "errGL_P"]
        self.err = dict.fromkeys(err_list, None)

        if self.save_cp:
            try:
                ckpt_dir = os.path.join(args.outdir, 'ckpt')
                if not os.path.exists(ckpt_dir):
                    os.makedirs(ckpt_dir)
                    print('Created checkpoint directory')
                if args.load < 0:  # New log file
                    with open(os.path.join(args.outdir, args.log_fn + '.csv'), 'w', newline='') as f:
                        csvwriter = writer(f)
                        csvwriter.writerow(["epoch", "runtime"] + err_list)
            except OSError as err:
                # Best effort: training may continue, but report the failure.
                print(f'Could not prepare output directory: {err}')

        angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, angles, clip_to_circle=True)

    def gen_mask(self):
        """Return a 0/1 angle mask: every 8th of 180 views (23 ones).

        With ``args.twoends`` the mask is extended to 192 entries (25 ones)
        by prepending its last 6 and appending its first 6 entries.
        """
        mask = torch.zeros(180)
        mask[::8].fill_(1)  # 180/23
        if self.args.twoends:
            mask = torch.cat((mask[-6:], mask, mask[:6]), 0)  # 192/25
        return mask

    def gen_sparse(self, y):
        """Select the masked angles (last dimension) of ``y``."""
        return y[:, :, :, self.mask == 1]

    def append_twoends(self, y):
        """Pad the last dimension with 6 entries on each side.

        The last 6 columns go in front and the first 6 at the back, each
        flipped along dimension 2.
        """
        front = torch.flip(y[:, :, :, :6], [2])
        back = torch.flip(y[:, :, :, -6:], [2])
        return torch.cat((back, y, front), 3)

    def ramp_module(self, sinogram):
        '''
            Ramp-filter a sinogram and rescale the result to (-1, 1).

            Sinogram has dimension: bs x c x height x angle.
            Ramp is 1D but angle number affects normalization for filter_sinogram. Use with caution.

            Raises ValueError if dimension 2 differs from ``args.height``.
        '''
        if sinogram.size()[2] != self.args.height:
            raise ValueError(f'sinogram dimension wrong for filter! received {sinogram.size()}.')
        normalized_sinogram = normalize(sinogram, rto=(0, 1))
        filtered_sinogram = self.radon.filter_sinogram(
            normalized_sinogram.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)  # 320 x 192
        return normalize(filtered_sinogram, rto=(-1, 1))

    def criterionP(self, Gx, y):
        """Perceptual loss: summed MSE between ``netLoss`` features of ``Gx`` and ``y``."""
        y_features = self.netLoss(y)
        Gx_features = self.netLoss(Gx)

        loss = 0.0
        for j in range(len(y_features)):
            loss += self.criterionL2(Gx_features[j], y_features[j][:y.shape[0]])
        return loss

    def criterionDP(self, Gx_features, y_features):
        """Discriminator-feature loss: summed MSE over paired feature maps."""
        loss = 0.0
        for j in range(len(y_features)):
            loss += self.criterionL2(Gx_features[j], y_features[j])
        return loss

    def train_D(self, Gx, y, mode):
        '''
            One discriminator update; mode 'G' uses DG, mode 'L' uses DL.

            Raises ValueError for any other mode.
        '''
        if mode == 'G':
            netD = self.netDG
            optimizer = self.optimizerDG
        elif mode == 'L':
            netD = self.netDL
            optimizer = self.optimizerDL
        else:
            raise ValueError(f"wrong mode! expected 'G' or 'L', got {mode!r}")

        netD.zero_grad()

        ############################
        # Loss_D: L_D = -(log(D(y) + log(1 - D(G(x))))
        ###########################
        # train with fake (detached so no gradient reaches G)
        D_Gx = netD(Gx.detach())[-1]
        errD_fake = self.criterionGAN(D_Gx, False)

        # train with real
        D_y = netD(y)[-1]
        errD_real = self.criterionGAN(D_y, True)

        # backprop
        errD = (errD_real + errD_fake) * 0.5
        errD.backward()
        optimizer.step()
        self.err['errD' + mode] = errD.item()

    def train_G(self, Gx, y, filtered_Gx, filtered_y, mode):
        '''
            One generator update against DG (mode 'G') or DL (mode 'L').

            Total loss = GAN + 50 * L1(Gx, y) + 50 * L1(filtered) + 20 * P, where
            P is a VGG (args.mode == 'vgg') or discriminator-feature
            (args.mode == 'DP') loss, and zero otherwise.

            Raises ValueError for any other mode.
        '''
        if mode == 'G':
            netD = self.netDG
        elif mode == 'L':
            netD = self.netDL
        else:
            raise ValueError(f"wrong mode! expected 'G' or 'L', got {mode!r}")

        self.netG.zero_grad()

        ############################
        # Loss_G_GAN: L_G = -log(D(G(x))  # Fake the D
        ###########################
        Gx_features = netD(Gx)
        errG_GAN = self.criterionGAN(Gx_features[-1], True)

        ############################
        # Loss_G_C: L_C = ||y - G(x)||_1
        ###########################
        errG_C = self.criterionL1(Gx, y) * 50

        ############################
        # Loss_G_DP: Discriminator perceptual feature loss
        ###########################
        if self.args.mode == 'vgg':
            errG_P = self.criterionP(Gx, y) * 20
        elif self.args.mode == 'DP':
            y_features = netD(y)
            errG_P = self.criterionDP(Gx_features[:-1], y_features[:-1]) * 20
        else:
            errG_P = torch.tensor(0.0, device=self.device)

        ############################
        # Loss_G_F: Ramp filtered sinogram loss
        ###########################
        errG_F = self.criterionL1(filtered_Gx, filtered_y) * 50

        # backprop
        errG = errG_GAN + errG_C + errG_F + errG_P
        errG.backward()
        self.optimizerG.step()

        self.err['errG' + mode + '_GAN'] = errG_GAN.item()
        self.err['errG' + mode + '_C'] = errG_C.item()
        self.err['errG' + mode + '_F'] = errG_F.item()
        self.err['errG' + mode + '_P'] = errG_P.item()

    def log(self, epoch, i):
        """Print the most recent losses for ``epoch`` and iteration ``i``."""
        print(f'[{epoch}/{self.args.epochs}][{i}/{len(self.dataloader)}] '
              f'LossDG: {self.err["errDG"]:.4f} '
              f'LossGG_GAN: {self.err["errGG_GAN"]:.4f} '
              f'LossGG_C: {self.err["errGG_C"]:.4f} '
              f'LossGG_F: {self.err["errGG_F"]:.4f} '
              f'LossGG_P: {self.err["errGG_P"]:.4f} '
              f'LossDL: {self.err["errDL"]:.4f} '
              f'LossGL_GAN: {self.err["errGL_GAN"]:.4f} '
              f'LossGL_C: {self.err["errGL_C"]:.4f} '
              f'LossGL_F: {self.err["errGL_F"]:.4f} '
              f'LossGL_P: {self.err["errGL_P"]:.4f} ')

    def log2file(self, fn, epoch, runtime):
        """Append ``[epoch, runtime, *self.err.values()]`` as a CSV row to ``fn``."""
        # list(...) call; ``list[...]`` would build a type alias, not a list.
        new_row = [epoch, runtime] + list(self.err.values())
        with open(fn, 'a+', newline='') as write_obj:
            csv_writer = writer(write_obj)
            csv_writer.writerow(new_row)

    def train(self):
        """Train from ``start_epoch`` up to ``args.epochs``.

        Per batch: train DG then G on the full sinogram, re-run G, then train
        DL then G on a patch chosen by ``gen_hole_area``. After each epoch the
        losses are printed, checkpoints are saved (if ``save_cp``) and a
        sample image is written.
        """
        print(f'''Starting training:
            Epochs:          {self.args.epochs}
            Batch size:      {self.args.batchSize}
            Learning rate:   {self.args.lr}
            Checkpoints:     {self.save_cp}
            Device:          {self.device.type}
        ''')

        for epoch in range(self.start_epoch, self.args.epochs):
            self.D_epochs = 1  # Adjust if you want
            print('D is trained ', str(self.D_epochs), 'times in this epoch.')

            start = time.time()  # log start time
            for i, data in enumerate(self.dataloader):
                y = data[0].to(self.device)  # 320 x 180

                # forward
                if self.args.twoends:
                    y = self.append_twoends(y)  # 320 x 192

                filtered_y = self.ramp_module(y)  # 320 x 192, normalized to -1~1
                x = self.gen_sparse(y)  # 320 x 25

                # Train Global
                Gx = self.netG(x)
                filtered_Gx = self.ramp_module(Gx)  # 320 x 192, normalized to -1~1

                ###### Train D
                set_requires_grad(self.netDG, True)
                for _ in range(self.D_epochs):  # increase D epoch gradually. FOR DP LOSS training
                    self.train_D(Gx, y, mode='G')
                ###### Train G
                set_requires_grad(self.netDG, False)  # D requires no gradients when optimizing G
                self.train_G(Gx, y, filtered_Gx, filtered_y, mode='G')

                # Train Local
                Gx = self.netG(x)
                filtered_Gx = self.ramp_module(Gx)  # 320 x 192, normalized to -1~1
                patch_area = gen_hole_area((y.shape[3] // 4, y.shape[2] // 4), (y.shape[3], y.shape[2]))
                Gx_patch = crop(Gx, patch_area)
                y_patch = crop(y, patch_area)
                filtered_y_patch = crop(filtered_y, patch_area)
                filtered_Gx_patch = crop(filtered_Gx, patch_area)

                ###### Train D
                set_requires_grad(self.netDL, True)
                for _ in range(self.D_epochs):  # increase D epoch gradually. FOR DP LOSS training
                    self.train_D(Gx_patch, y_patch, mode='L')
                ###### Train G
                set_requires_grad(self.netDL, False)  # D requires no gradients when optimizing G
                self.train_G(Gx_patch, y_patch, filtered_Gx_patch, filtered_y_patch, mode='L')

                if i % 100 == 0:
                    self.log(epoch, i)

            end = time.time()  # log end time
            # Per-epoch CSV logging (disabled):
            # self.log2file(os.path.join(self.args.outdir, self.args.log_fn+'.csv'), epoch, str(end-start))

            # Log
            self.log(epoch, i)
            if self.save_cp:
                torch.save(self.netG.state_dict(), f'{self.args.outdir}/ckpt/G_epoch{epoch}.pth')
                torch.save(self.netDG.state_dict(), f'{self.args.outdir}/ckpt/DG_epoch{epoch}.pth')
                torch.save(self.netDL.state_dict(), f'{self.args.outdir}/ckpt/DL_epoch{epoch}.pth')
            vutils.save_image(Gx.detach(), '%s/impainted_samples_epoch_%03d.png' % (self.args.outdir, epoch), normalize=True)
            
class Predict():
    """Sparse-view sinogram inpainting of one image followed by FBP reconstruction."""

    def __init__(self, args, image):
        """Load the UNet checkpoint, move ``image`` to the device and build the mask.

        ``args.angles`` and ``args.twoends`` determine the UNet scale factor;
        ``args.height`` is the image size of the 180-view Radon operator.
        """
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        if args.twoends:
            factor = 192 / (args.angles + 2)  # 7.68
        else:
            factor = 180 / args.angles  # 7.826086956521739

        self.net = UNet(input_nc=1, output_nc=1,
                        scale_factor=factor).to(self.device)
        self.net = nn.DataParallel(self.net)
        pathG = os.path.join(args.ckpt)
        self.net.load_state_dict(torch.load(pathG, map_location=self.device))
        self.net.eval()

        self.image = image.to(self.device)
        # twoends must be set before gen_mask(), which reads it.
        self.twoends = args.twoends
        self.mask = self.gen_mask().to(self.device)

        # Radon Operator
        angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, angles, clip_to_circle=True)

    def gen_mask(self):
        """Return a 0/1 angle mask: every 8th of 180 views (23 ones).

        With ``twoends`` it is extended to 192 entries by prepending its last
        6 and appending its first 6 entries.
        """
        mask = torch.zeros(180)
        mask[::8].fill_(1)  # 180
        if self.twoends:
            mask = torch.cat((mask[-6:], mask, mask[:6]), 0)  # 192
        return mask

    def append_twoends(self, y):
        """Pad the last dimension with 6 entries per side, each flipped along dim 2."""
        front = torch.flip(y[:, :, :, :6], [2])
        back = torch.flip(y[:, :, :, -6:], [2])
        return torch.cat((back, y, front), 3)

    def gen_sparse(self, y):
        """Select the masked angles (last dimension) of ``y``."""
        return y[:, :, :, self.mask == 1]

    def crop_sinogram(self, x):
        """Remove the 6-entry padding added by ``append_twoends``."""
        return x[:, :, :, 6:-6]

    def inpaint(self):
        """Inpaint ``self.image`` and save the sinogram and its FBP reconstruction.

        Writes ``result_reconstruction.png`` and ``result_sinogram.png`` to the
        current directory. Inference runs under ``torch.no_grad()`` so no
        autograd graph is built.
        """
        with torch.no_grad():
            y = self.image  # 320 x 180

            # Two-Ends Preprocessing
            if self.twoends:
                y = self.append_twoends(y)  # 320 x 192

            # Generate Sparse-view Image, forward model
            x = self.gen_sparse(y)
            Gx = self.net(x)

            # Crop Two-Ends
            if self.twoends:
                Gx = self.crop_sinogram(Gx)

            # FBP Reconstruction
            Gx = normalize(Gx)  # 0~1
            fbp_Gx = self.radon.backprojection(
                self.radon.filter_sinogram(Gx.permute(0, 1, 3, 2)))

        # Save Results
        vutils.save_image(fbp_Gx, 'result_reconstruction.png', normalize=True)
        vutils.save_image(Gx, 'result_sinogram.png', normalize=True)
Esempio n. 21
0

def shrink(a, b):
    """Soft-threshold ``a`` by ``b``.

    Each entry's magnitude is reduced by ``b`` and floored at 0; the sign of
    the original entry is kept, so entries with ``|a| <= b`` become 0.
    """
    magnitude = torch.clamp(torch.abs(a) - b, min=0)
    return magnitude * torch.sign(a)


# Problem setup for the conjugate-gradient reconstruction demo.
# NOTE(review): batch_size is not used in the visible code; the batch shape
# of x comes from the stacking below.
batch_size = 1
n_angles = 512
image_size = 512

# Test phantom; the .view(1, 512, 512) below assumes it is 512x512.
img = np.load("phantom.npy")
device = torch.device('cuda')

# instantiate Radon transform
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles)

# Four copies of the phantom arranged as a (2, 2, 512, 512) tensor, i.e. two
# leading batch dimensions (presumably to check batched projection works).
x = torch.FloatTensor(img).to(device).view(1, 512, 512)
x = torch.cat([x] * 4, dim=0).view(2, 2, 512, 512)
print(x.size())
# Simulated measurements (forward projection of x).
y = radon.forward(x)

# Alternative solvers kept for reference:
# CG(radon, 1.0, 0.0, torch.zeros_like(x), radon.backward(y))
# rec = cgne(radon, torch.zeros_like(x), y, tol=1e-2)
s = time.time()
for _ in range(1):
    with torch.no_grad():
        rec, values = cg(lambda z: radon.backward(radon.forward(z)),
                         torch.zeros_like(x),
                         radon.backward(y),
                         callback=lambda x, r: torch.norm(
Esempio n. 22
0

def shrink(a, b):
    """Soft-threshold ``a`` by ``b``: shrink magnitudes by ``b``, keep signs.

    Entries with ``|a| <= b`` become 0.
    """
    reduced = (a.abs() - b).clamp(min=0)
    return reduced * a.sign()


# Limited-angle reconstruction setup with a shearlet representation.
# NOTE(review): batch_size is not used in the visible code.
batch_size = 1
n_angles = 512 // 4
image_size = 512

img = np.load("phantom.npy")
device = torch.device('cuda')

# instantiate Radon transform
# Limited-angle geometry: n_angles views spread over [0, pi/4) only.
angles = np.linspace(0, np.pi / 4, n_angles, endpoint=False)
radon = Radon(image_size, angles)
# Shearlet system for 512x512 images. NOTE(review): the meaning of the
# [0.5] * 5 argument is not visible here — confirm against Shearlet.
shearlet = Shearlet(512, 512, [0.5] * 5, cache=None)  # ".cache")

with torch.no_grad():
    x = torch.FloatTensor(img).reshape(1, image_size, image_size).to(device)
    sinogram = radon.forward(x)
    # Backprojection of the limited-angle sinogram; no sinogram filter is
    # applied here.
    bp = radon.backward(sinogram, extend=False)

    # f, values = CG(radon, 1.0 / 512**2, 0.0001, bp.clone(), bp)
    #
    # print(torch.norm(x - f)/torch.norm(x))
    # Shearlet coefficients of the backprojection.
    sc = shearlet.forward(bp)
    # NOTE(review): p_0 and p_1 are unused in the visible code; presumably
    # step-size / threshold parameters for the iteration that follows.
    p_0 = 0.02
    p_1 = 0.1
    # Per-scale weights growing as 3**scale, reshaped to (1, -1, 1, 1) so they
    # presumably broadcast over the coefficient tensor's scale dimension.
    w = 3**shearlet.scales / 400
    w = w.view(1, -1, 1, 1).to(device)
Esempio n. 23
0
def main():
    """Benchmark torch-radon against the Astra Toolbox and plot the throughput.

    Command-line options select the task(s) (``forward``, ``backward``,
    ``fanbeam forward``, ``fanbeam backward``, ``all`` or ``shearlet``), the
    image size, the number of angles (defaults to the image size), the batch
    size and the number of timed / warm-up iterations. For every selected
    task the Astra wrapper and torch-radon (float32 and float16) are timed,
    the speedups are printed, and the throughput in images/second is plotted
    and shown, or saved to ``--output``.
    """
    parser = argparse.ArgumentParser(
        description='Benchmark and compare with Astra Toolbox')
    parser.add_argument('--task', default="all")
    parser.add_argument('--image-size', default=256, type=int)
    parser.add_argument('--angles', default=-1, type=int)
    parser.add_argument('--batch-size', default=32, type=int)
    parser.add_argument('--samples', default=50, type=int)
    parser.add_argument('--warmup', default=10, type=int)
    parser.add_argument('--output', default="")
    parser.add_argument('--circle', action='store_true')

    args = parser.parse_args()
    if args.angles == -1:
        args.angles = args.image_size

    # The shearlet benchmark does its own setup; return before building the
    # Radon / Astra projectors it does not use.
    if args.task == "shearlet":
        benchmark_shearlet(args)
        return

    device = torch.device("cuda")
    angles = np.linspace(0, 2 * np.pi, args.angles,
                         endpoint=False).astype(np.float32)

    radon = Radon(args.image_size, angles, clip_to_circle=args.circle)
    radon_fb = RadonFanbeam(args.image_size,
                            angles,
                            args.image_size,
                            clip_to_circle=args.circle)

    astra_pw = AstraParallelWrapper(angles, args.image_size)
    astra_fw = AstraFanbeamWrapper(angles, args.image_size)

    if args.task == "all":
        tasks = ["forward", "backward", "fanbeam forward", "fanbeam backward"]
    else:
        tasks = [args.task]

    astra_fps = []
    radon_fps = []
    radon_half_fps = []

    def random_images():
        # Fresh batch of random test images as a float32 tensor on the GPU.
        x = generate_random_images(args.batch_size, args.image_size)
        return torch.FloatTensor(x).to(device)

    def run(message, astra_fn, radon_fn, make_input):
        # Time one operation with Astra and with torch-radon in float32 and
        # float16, record images/second for the plot and print the speedups.
        print(message)
        data = make_input()
        astra_time = benchmark_function(astra_fn, data, args.samples,
                                        args.warmup)
        radon_time = benchmark_function(radon_fn,
                                        data,
                                        args.samples,
                                        args.warmup,
                                        sync=True)
        radon_half_time = benchmark_function(radon_fn,
                                             data.half(),
                                             args.samples,
                                             args.warmup,
                                             sync=True)

        astra_fps.append(args.batch_size / astra_time)
        radon_fps.append(args.batch_size / radon_time)
        radon_half_fps.append(args.batch_size / radon_half_time)

        print("Speedup:", astra_time / radon_time)
        print("Speedup half-precision:", astra_time / radon_half_time)
        print()

    # Backprojection consumes sinograms, not images. The backward tasks
    # forward-project random images (outside the timed region) so the input
    # has the projector's own sinogram shape; its angle dimension follows
    # --angles and only matches the image shape when --angles == --image-size.
    if "forward" in tasks:
        run("Benchmarking forward from device", astra_pw.forward,
            radon.forward, random_images)

    if "backward" in tasks:
        run("Benchmarking backward from device", astra_pw.backward,
            radon.backward, lambda: radon.forward(random_images()))

    if "fanbeam forward" in tasks:
        run("Benchmarking fanbeam forward", astra_fw.forward,
            radon_fb.forward, random_images)

    if "fanbeam backward" in tasks:
        run("Benchmarking fanbeam backward", astra_fw.backward,
            radon_fb.backprojection,
            lambda: radon_fb.forward(random_images()))

    title = f"Image size {args.image_size}x{args.image_size}, {args.angles} angles and batch size {args.batch_size} on a {torch.cuda.get_device_name(0)}"

    plot(tasks, astra_fps, radon_fps, radon_half_fps, title)
    if args.output:
        plt.savefig(args.output, dpi=300)
    else:
        plt.show()
Esempio n. 24
0
import matplotlib.pyplot as plt
import numpy as np
import torch
from utils import show_images

from torch_radon import Radon

device = torch.device('cuda')

# Load the test phantom. image_size is taken from its first dimension and
# passed as the single resolution of Radon (assumes a square phantom).
img = np.load("phantom.npy")
image_size = img.shape[0]
# Use as many projection angles as image rows.
n_angles = image_size

# Instantiate Radon transform. clip_to_circle should be True when using filtered backprojection.
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles, clip_to_circle=True)

with torch.no_grad():
    x = torch.FloatTensor(img).to(device)

    # Filtered backprojection: project, filter the sinogram, backproject.
    sinogram = radon.forward(x)
    filtered_sinogram = radon.filter_sinogram(sinogram)
    fbp = radon.backprojection(filtered_sinogram)

# Absolute (not relative) L2 norm of the reconstruction error.
print("FBP Error", torch.norm(x - fbp).item())

# Show results
# Panel titles; presumably consumed by show_images in code that follows.
titles = [
    "Original Image", "Sinogram", "Filtered Sinogram",
    "Filtered Backprojection"
]
Esempio n. 25
0
def main():
    """Benchmark torch-radon against the Astra Toolbox and plot the throughput.

    Command-line options select the task(s) (``forward``, ``backward``,
    ``fanbeam forward``, ``fanbeam backward`` or ``all``), the image size,
    the number of angles (defaults to the image size), the batch size and the
    number of timed / warm-up iterations. torch-radon is timed in float32 and
    float16 on the GPU. Astra is timed only for the parallel-beam tasks; the
    fanbeam tasks record 0.0 for Astra. The throughput in images/second is
    plotted and shown, or saved to ``--output``.
    """
    parser = argparse.ArgumentParser(description='Benchmark and compare with Astra Toolbox')
    parser.add_argument('--task', default="all")
    parser.add_argument('--image-size', default=256, type=int)
    parser.add_argument('--angles', default=-1, type=int)
    parser.add_argument('--batch-size', default=32, type=int)
    parser.add_argument('--samples', default=50, type=int)
    parser.add_argument('--warmup', default=10, type=int)
    parser.add_argument('--output', default="")
    parser.add_argument('--circle', action='store_true')

    args = parser.parse_args()
    if args.angles == -1:
        args.angles = args.image_size

    device = torch.device("cuda")
    angles = np.linspace(0, 2 * np.pi, args.angles, endpoint=False).astype(np.float32)

    radon = Radon(args.image_size, angles, clip_to_circle=args.circle)
    radon_fb = RadonFanbeam(args.image_size, angles, args.image_size, clip_to_circle=args.circle)
    astra = AstraWrapper(angles)

    if args.task == "all":
        tasks = ["forward", "backward", "fanbeam forward", "fanbeam backward"]
    else:
        tasks = [args.task]

    astra_fps = []
    radon_fps = []
    radon_half_fps = []

    if "forward" in tasks:
        print("Benchmarking forward from device")
        x = generate_random_images(args.batch_size, args.image_size)
        dx = torch.FloatTensor(x).to(device)
        astra_time = benchmark_function(lambda y: astra.forward(y), x, args.samples, args.warmup)
        radon_time = benchmark_function(lambda y: radon.forward(y), dx, args.samples,
                                        args.warmup, sync=True)
        radon_half_time = benchmark_function(lambda y: radon.forward(y), dx.half(),
                                             args.samples, args.warmup, sync=True)

        astra_fps.append(args.batch_size / astra_time)
        radon_fps.append(args.batch_size / radon_time)
        radon_half_fps.append(args.batch_size / radon_half_time)

        print(astra_time, radon_time, radon_half_time)
        astra.clean()

    if "backward" in tasks:
        print("Benchmarking backward from device")
        x = generate_random_images(args.batch_size, args.image_size)
        dx = torch.FloatTensor(x).to(device)
        # Backprojection consumes sinograms, not images: project first so the
        # torch-radon input has its own sinogram shape (the angle dimension
        # follows --angles, not --image-size). Done outside the timed region.
        dsino = radon.forward(dx)
        pid, x = astra.forward(x)

        # The Astra lambda ignores its argument and backprojects the data
        # referenced by ``pid`` from astra.forward above.
        astra_time = benchmark_function(lambda y: astra.backproject(pid, args.image_size, args.batch_size), x,
                                        args.samples, args.warmup)
        radon_time = benchmark_function(lambda y: radon.backward(y), dsino, args.samples,
                                        args.warmup, sync=True)
        radon_half_time = benchmark_function(lambda y: radon.backward(y), dsino.half(),
                                             args.samples, args.warmup, sync=True)

        astra_fps.append(args.batch_size / astra_time)
        radon_fps.append(args.batch_size / radon_time)
        radon_half_fps.append(args.batch_size / radon_half_time)

        print(astra_time, radon_time, radon_half_time)
        astra.clean()

    if "fanbeam forward" in tasks:
        print("Benchmarking fanbeam forward")
        x = generate_random_images(args.batch_size, args.image_size)
        dx = torch.FloatTensor(x).to(device)
        radon_time = benchmark_function(lambda y: radon_fb.forward(y), dx, args.samples,
                                        args.warmup, sync=True)
        radon_half_time = benchmark_function(lambda y: radon_fb.forward(y), dx.half(),
                                             args.samples, args.warmup, sync=True)

        # Astra fanbeam is not benchmarked here; 0.0 keeps astra_fps aligned
        # with the other per-task lists for the plot.
        astra_fps.append(0.0)
        radon_fps.append(args.batch_size / radon_time)
        radon_half_fps.append(args.batch_size / radon_half_time)

        astra.clean()

    if "fanbeam backward" in tasks:
        print("Benchmarking fanbeam backward")
        x = generate_random_images(args.batch_size, args.image_size)
        dx = torch.FloatTensor(x).to(device)
        # Feed a fanbeam sinogram, not an image, to the backprojection.
        dsino = radon_fb.forward(dx)
        radon_time = benchmark_function(lambda y: radon_fb.backprojection(y), dsino, args.samples,
                                        args.warmup, sync=True)
        radon_half_time = benchmark_function(lambda y: radon_fb.backprojection(y), dsino.half(),
                                             args.samples, args.warmup, sync=True)

        # Astra fanbeam is not benchmarked here; 0.0 is a placeholder.
        astra_fps.append(0.0)
        radon_fps.append(args.batch_size / radon_time)
        radon_half_fps.append(args.batch_size / radon_half_time)

        astra.clean()

    title = f"Image size {args.image_size}x{args.image_size}, {args.angles} angles and batch size {args.batch_size} on a {torch.cuda.get_device_name(0)}"

    plot(tasks, astra_fps, radon_fps, radon_half_fps, title)
    if args.output:
        plt.savefig(args.output, dpi=300)
    else:
        plt.show()
Esempio n. 26
0
import numpy as np  # used below (np.load, np.linspace, np.pi) but was never imported
import torch

from torch_radon import Radon
from torch_radon.solvers import Landweber
from utils import show_images

# Filtered-backprojection demo on a 512x512 phantom.
batch_size = 1
n_angles = 512
image_size = 512

img = np.load("phantom.npy")
device = torch.device('cuda')

# instantiate Radon transform
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles)

with torch.no_grad():
    x = torch.FloatTensor(img).reshape(1, 1, image_size, image_size).to(device)

    sinogram = radon.forward(x)
    filtered_sinogram = radon.filter_sinogram(sinogram)
    # The angles are spaced pi / n_angles apart, so scaling the backprojected
    # sum by that step approximates the integral over [0, pi).
    fbp = radon.backprojection(filtered_sinogram,
                               extend=False) * np.pi / n_angles

# Absolute (not relative) L2 norm of the reconstruction error.
print("FBP Error", torch.norm(x - fbp).item())

titles = [
    "Original Image", "Sinogram", "Filtered Sinogram",
    "Filtered Backprojection"
]
Esempio n. 27
0
# rdx *= (alpha_e - alpha_s)
# rdy *= (alpha_e - alpha_s)
#
# print(rsx, rsy, rsx**2 + rsy**2 - v**2)
# print(rdx, rdy, (rsx+rdx)**2 + (rsy+rdy)**2 - v**2)

# Compare torch-radon's parallel-beam forward projection with Astra's.
device = torch.device('cuda')

# NOTE(review): np.linspace defaults to endpoint=True, so the first (0) and
# last (2*pi) angles coincide. Both projectors use the same angles, so the
# comparison is still like-for-like.
angles = np.linspace(0, 2 * np.pi, 180).astype(np.float32)

batch_size = 4
image_size = 256
astraw = AstraWrapper(angles)

# Random test images. masked=True presumably restricts them to the inscribed
# circle, matching clip_to_circle=True below — confirm in generate_random_images.
x = generate_random_images(batch_size, image_size, masked=True)

# Reference projection from Astra; presumably returns (data id, sinograms).
astra_fp_id, astra_fp = astraw.forward(x)

# our implementation
radon = Radon(image_size, angles, clip_to_circle=True)
x = torch.FloatTensor(x).to(device)

our_fp = radon.forward(x)

# Visual side-by-side check of the first sinogram from each implementation.
plt.imshow(astra_fp[0])
plt.figure()
plt.imshow(our_fp[0].cpu().numpy())
plt.show()

# Numerical agreement between the two forward projections.
print(relative_error(astra_fp, our_fp.cpu().numpy()))
Esempio n. 28
0
# End-to-end example: learnable sinogram filter -> backprojection -> image model.
# NOTE(review): batch_size and n_angles are not defined in this snippet and
# must be set before it runs.
image_size = 128
channels = 4

device = torch.device('cuda')
criterion = nn.L1Loss()

# Instantiate a model for the sinogram and one for the image
sino_model = nn.Conv2d(1, channels, 5, padding=2).to(device)
image_model = nn.Conv2d(channels, 1, 3, padding=1).to(device)

# create empty (all-zero) images. torch.FloatTensor(*sizes) would leave the
# memory uninitialized, so the loss below could see garbage or NaN values.
x = torch.zeros(batch_size, 1, image_size, image_size, device=device)

# instantiate Radon transform. endpoint=False: for parallel beam the views at
# 0 and pi are mirror images of each other, so including pi duplicates data.
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles)

# forward projection
sinogram = radon.forward(x)

# apply sino_model to sinograms
filtered_sinogram = sino_model(sinogram)

# backprojection (one output image per sino_model channel)
backprojected = radon.backprojection(filtered_sinogram)

# apply image_model to backprojected images
y = image_model(backprojected)

# backward works as usual
loss = criterion(y, x)
Esempio n. 29
0
class Predict():
    def __init__(self, args, dataloader):
        """Load the inpainting network and set up the masks and Radon operators.

        Args:
            args: parsed options; this constructor reads ``twoends``,
                ``angles``, ``ckpt`` and ``height`` (``gen_mask`` also
                reads ``twoends``).
            dataloader: iterable of batches, stored for ``inpaint``.
        """
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.args = args
        self.dataloader = dataloader

        # Width ratio between network output and sparse input:
        # 192 / 25 = 7.68 with two-ends padding, 180 / 23 ~= 7.83 without
        # (for 23 sparse views).
        factor = 192 / (args.angles + 2) if args.twoends else 180 / args.angles

        net = UNet(input_nc=1, output_nc=1,
                   scale_factor=factor).to(self.device)
        self.net = nn.DataParallel(net)
        # NOTE: torch.load unpickles the checkpoint; only load trusted files.
        checkpoint_path = os.path.join(args.ckpt)
        state = torch.load(checkpoint_path, map_location=self.device)
        self.net.load_state_dict(state)
        self.net.eval()

        self.gen_mask()

        # Radon operators for the full 180 views and for every 8th, 4th and
        # 2nd view (23, 45 and 90 views respectively).
        angles = np.linspace(0, np.pi, 180, endpoint=False)

        def make_radon(view_angles):
            return Radon(args.height, view_angles, clip_to_circle=True)

        self.radon = make_radon(angles)
        self.radon23 = make_radon(angles[::8])
        self.radon45 = make_radon(angles[::4])
        self.radon90 = make_radon(angles[::2])

    def gen_mask(self):
        mask = torch.zeros(180)
        mask[::8].fill_(1)  # 180
        if self.args.twoends:
            self.mask = torch.cat((mask[-6:], mask, mask[:6]),
                                  0).to(self.device)  # 192
        self.mask_sparse = mask

    def append_twoends(self, y):
        """Pad dim 3 of ``y`` with 6 flipped columns on each side.

        The last 6 columns, flipped along dim 2, go in front and the first 6
        columns, flipped along dim 2, go at the end (e.g. 180 -> 192 columns).
        """
        leading = y[:, :, :, -6:].flip(2)
        trailing = y[:, :, :, :6].flip(2)
        return torch.cat([leading, y, trailing], dim=3)

    def gen_input(self, y, mask):
        """Select the columns of dim 3 of ``y`` whose ``mask`` entry is 1."""
        keep = mask.eq(1)
        return y[:, :, :, keep]

    def crop_sinogram(self, x):
        """Remove the 6 two-ends padding columns from both ends of dim 3."""
        margin = 6
        return x[:, :, :, margin:-margin]

    def inpaint(self):
        for i, data in enumerate(self.dataloader):
            y = data[0].to(self.device)  # 320 x 180

            # Two-Ends Preprocessing
            if self.args.twoends:
                y_TE = self.append_twoends(y)  # 320 x 192

            # Forward Model
            x = self.gen_input(y_TE, self.mask)  # input, 320 x 25
            Gx = self.net(x)  # 320 x 192

            # Crop Two-Ends
            if self.args.twoends:
                Gx = self.crop_sinogram(Gx)  # 320 x 180

            # FBP Reconstruction
            Gx = normalize(Gx)  # 0~1
            fbp_Gx = self.radon.backprojection(
                self.radon.filter_sinogram(Gx.permute(0, 1, 3,
                                                      2)))  # 320 x 320

            # FBP for downsampled sinograms
            Gx1 = Gx[:, :, :, ::2]  # 320 x 90
            Gx1 = normalize(Gx1)  # 0~1
            fbp_Gx1 = self.radon90.backprojection(
                self.radon90.filter_sinogram(Gx1.permute(0, 1, 3, 2)))

            Gx2 = Gx[:, :, :, ::4]  # 320 x 45
            Gx2 = normalize(Gx2)  # 0~1
            fbp_Gx2 = self.radon45.backprojection(
                self.radon45.filter_sinogram(Gx2.permute(0, 1, 3, 2)))

            sparse = y[:, :, :, ::8]  # 320 x 23, original sparse-view sinogram
            sparse = normalize(sparse)  # 0~1
            fbp_sparse = self.radon23.backprojection(
                self.radon23.filter_sinogram(sparse.permute(0, 1, 3, 2)))

            print(f'Saving images for batch {i}')

            for j in range(y.size()[0]):
                #                 vutils.save_image(Gx[j,0], f'{self.args.outdir}/{class_name}/{fnames[i*self.args.bs+j]}', normalize=True)
                vutils.save_image(
                    fbp_Gx[j, 0],
                    f'{self.args.outdir}/{class_name}/{fnames[i*self.args.bs+j]}',
                    normalize=True)
                vutils.save_image(
                    fbp_Gx1[j, 0],
                    f'{self.args.outdir}_90/{class_name}/{fnames[i*self.args.bs+j]}',
                    normalize=True)
                vutils.save_image(
                    fbp_Gx2[j, 0],
                    f'{self.args.outdir}_45/{class_name}/{fnames[i*self.args.bs+j]}',
                    normalize=True)
                vutils.save_image(
                    fbp_sparse[j, 0],
                    f'{self.args.outdir}_23/{class_name}/{fnames[i*self.args.bs+j]}',
                    normalize=True)