Exemplo n.º 1
0
 def init_radon(self, beam, circle, det_dist):
     """Create the full-view and sparse-view projection operators.

     Sets ``self.radon`` (all angles), ``self.radon_sparse`` (every
     ``self.sample_ratio``-th angle) and ``self.n_angles_sparse``.

     Args:
         beam: ``'parallel'`` or ``'fan'``.
         circle: forwarded as ``clip_to_circle`` to the operators.
         det_dist: indexable pair; item 0 is passed as ``source_distance``
             and item 1 as ``det_distance``. Only read for ``'fan'``.

     Raises:
         ValueError: if ``beam`` is neither ``'parallel'`` nor ``'fan'``.
     """
     # Fail fast, before any attribute is read or written.
     if beam not in ('parallel', 'fan'):
         raise ValueError('projection beam type undefined!')

     if beam == 'parallel':
         # n_angles views evenly covering [0, pi).
         angles = np.linspace(0, np.pi, self.n_angles, endpoint=False)

         def build(view_angles):
             return Radon(self.img_size, view_angles, clip_to_circle=circle)
     else:
         # n_angles views spaced one degree apart, starting at 0.
         angles = np.linspace(0, self.n_angles / 180 * np.pi, self.n_angles,
                              False)

         def build(view_angles):
             return RadonFanbeam(self.img_size,
                                 view_angles,
                                 source_distance=det_dist[0],
                                 det_distance=det_dist[1],
                                 clip_to_circle=circle,
                                 det_count=self.det_size)

     sparse_angles = angles[::self.sample_ratio]
     self.radon = build(angles)
     self.radon_sparse = build(sparse_angles)
     self.n_angles_sparse = len(sparse_angles)
Exemplo n.º 2
0
    def __init__(self, args, dataloader):
        """Load the trained network and build the Radon operators.

        Reads ``args.twoends``, ``args.angles``, ``args.ckpt`` and
        ``args.height``. Weights are loaded from ``args.ckpt`` onto CUDA when
        available, otherwise onto the CPU, and the network is put in eval
        mode.
        """
        self.args = args
        self.dataloader = dataloader
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Scale factor handed to the UNet: 7.68 with two ends and 23 angles,
        # 7.826... without.
        factor = (192 / (args.angles + 2)
                  if args.twoends else 180 / args.angles)

        unet = UNet(input_nc=1, output_nc=1, scale_factor=factor)
        self.net = nn.DataParallel(unet.to(self.device))
        checkpoint_path = os.path.join(args.ckpt)
        weights = torch.load(checkpoint_path, map_location=self.device)
        self.net.load_state_dict(weights)
        self.net.eval()

        self.gen_mask()

        # One operator per angular subsampling of the 180-view grid:
        # every 8th view (23 views), every 4th (45) and every 2nd (90).
        full_angles = np.linspace(0, np.pi, 180, endpoint=False)
        size = args.height
        self.radon = Radon(size, full_angles, clip_to_circle=True)
        self.radon23 = Radon(size, full_angles[::8], clip_to_circle=True)
        self.radon45 = Radon(size, full_angles[::4], clip_to_circle=True)
        self.radon90 = Radon(size, full_angles[::2], clip_to_circle=True)
Exemplo n.º 3
0
 def __init__(self, image_size, n_angles, sample_ratio, device, circle=False):
     """Build full/sparse Radon operators, a Landweber solver and a view mask.

     Args:
         image_size: passed to both Radon operators.
         n_angles: number of views evenly covering [0, pi).
         sample_ratio: stride used to pick the sparse subset of views.
         device: device the mask tensor is moved to.
         circle: forwarded as ``clip_to_circle``.
     """
     self.image_size = image_size
     self.n_angles = n_angles
     self.sample_ratio = sample_ratio
     self.device = device

     full_angles = np.linspace(0, np.pi, n_angles, endpoint=False)
     sparse_angles = full_angles[::sample_ratio]

     self.radon = Radon(image_size, full_angles, clip_to_circle=circle)
     self.radon_sparse = Radon(image_size, sparse_angles, clip_to_circle=circle)
     self.n_angles_sparse = len(sparse_angles)
     self.landweber = Landweber(self.radon)

     # Ones at every sample_ratio-th position of the last axis, zeros elsewhere.
     # NOTE(review): the mask width is fixed at 180 regardless of n_angles -
     # confirm this is intended when n_angles != 180.
     view_mask = torch.zeros((1, 1, 1, 180)).to(device)
     view_mask[:, :, :, ::sample_ratio].fill_(1)
     self.mask = view_mask
Exemplo n.º 4
0
def test_half(device, batch_size, image_size, angles, spacing, det_count,
              clip_to_circle):
    """Compare half-precision projections against single precision.

    ``det_count`` arrives as a multiple of ``image_size`` and is converted to
    an absolute detector count. Forward and backprojection results computed
    from a float16 input must stay within a relative error of 1e-3 of the
    float32 results.
    """
    det_count = int(det_count * image_size)
    if clip_to_circle:
        mask_radius = det_count / 2.0
    else:
        mask_radius = -1
    images = generate_random_images(batch_size, image_size, mask_radius)

    radon = Radon(image_size, angles, det_spacing=spacing,
                  det_count=det_count, clip_to_circle=clip_to_circle)
    x = torch.FloatTensor(images).to(device)

    # Single-precision reference.
    sinogram = radon.forward(x)
    single_precision = radon.backprojection(sinogram)

    # Same pipeline starting from a half-precision input.
    h_sino = radon.forward(x.half())
    half_precision = radon.backprojection(h_sino)

    def to_numpy(tensor):
        return tensor.cpu().numpy()

    forward_error = relative_error(to_numpy(sinogram), to_numpy(h_sino))
    back_error = relative_error(to_numpy(single_precision),
                                to_numpy(half_precision))

    print(
        f"batch: {batch_size}, size: {image_size}, angles: {len(angles)}, spacing: {spacing}, circle: {clip_to_circle}, forward: {forward_error}, back: {back_error}"
    )

    assert_less(forward_error, 1e-3)
    assert_less(back_error, 1e-3)
Exemplo n.º 5
0
def bench_fanbeam_backward(task, dtype, device, *bench_args):
    """Benchmark fan-beam backprojection for one task configuration.

    Args:
        task: mapping with keys ``num_angles``, ``det_count``,
            ``source_distance``, ``detector_distance``, ``det_spacing``,
            ``batch_size`` and ``size``.
        dtype: dtype of the random input batch.
        device: device the input batch is created on.
        *bench_args: forwarded unchanged to ``benchmark``.

    Returns:
        Whatever ``benchmark`` returns for the timed backprojection.
    """
    size = task["size"]

    images = torch.randn(task["batch_size"],
                         size,
                         size,
                         dtype=dtype,
                         device=device)
    angles = np.linspace(0, np.pi, task["num_angles"], endpoint=False)

    projection = Projection.fanbeam(task["source_distance"],
                                    task["detector_distance"],
                                    task["det_count"], task["det_spacing"])
    radon = Radon(angles, size, projection)

    # Backprojection consumes sinograms, so time it on a sinogram of the
    # right shape. The original computed one but timed on the image batch.
    sino = radon.forward(images)

    def f(x):
        return radon.backward(x)

    return benchmark(f, sino, *bench_args)
Exemplo n.º 6
0
def test_error(device, batch_size, image_size, angles, spacing, clip_to_circle):
    """Check forward and backprojection against the Astra reference.

    The forward relative error must be below 1e-2 and the backprojection
    relative error below 5e-3.
    """
    images = generate_random_images(batch_size, image_size, masked=clip_to_circle)

    # Reference projections computed with Astra.
    astra = AstraWrapper(angles)
    astra_fp_id, astra_fp = astra.forward(images, spacing)
    astra_bp = astra.backproject(astra_fp_id, image_size, batch_size)
    if clip_to_circle:
        # Mask the reference the same way before comparing.
        astra_bp *= circle_mask(image_size)

    # Same pipeline with the operator under test.
    radon = Radon(image_size, angles, det_spacing=spacing, clip_to_circle=clip_to_circle)
    our_fp = radon.forward(torch.FloatTensor(images).to(device))
    our_bp = radon.backprojection(our_fp)

    forward_error = relative_error(astra_fp, our_fp.cpu().numpy())
    back_error = relative_error(astra_bp, our_bp.cpu().numpy())

    print(
        f"batch: {batch_size}, size: {image_size}, angles: {len(angles)}, spacing: {spacing}, circle: {clip_to_circle}, forward: {forward_error}, back: {back_error}")
    # TODO better checks
    assert_less(forward_error, 1e-2)
    assert_less(back_error, 5e-3)
Exemplo n.º 7
0
def bench_parallel_forward(phantom, det_count, num_angles, warmup, repeats):
    """Benchmark the forward projection of ``phantom``.

    The operator uses ``num_angles`` views evenly covering [0, pi); its size
    argument is ``phantom.size(1)``. Returns the result of ``benchmark``.
    """
    size = phantom.size(1)
    view_angles = np.linspace(0, np.pi, num_angles, endpoint=False)
    radon = Radon(size, view_angles, det_count)

    def project(batch):
        return radon.forward(batch)

    return benchmark(project, phantom, warmup, repeats)
Exemplo n.º 8
0
def main():
    """Reconstruct a synthetic phantom from a limited range of views.

    Builds a binary phantom, projects it over 100 views, then runs 100
    iterations of a splitting scheme that alternates a conjugate-gradient
    solve with a shrinkage step on shearlet coefficients and a
    non-negativity clamp. Prints the running time and the final relative
    error (in percent). Requires CUDA.
    NOTE(review): the update pattern looks like ADMM - confirm.
    """
    n_angles = 100
    image_size = 512
    circle_radius = 100
    # Only referenced by the commented-out fan-beam alternative below.
    source_dist = 1.5 * image_size
    batch_size = 1
    n_scales = 5

    # Views from -50 to +49 degrees in one-degree steps, in radians.
    angles = (np.linspace(0., 100., n_angles, endpoint=False) -
              50.0) / 180.0 * np.pi

    # Phantom: ones where circle_mask selects, zeros elsewhere.
    x = np.zeros((image_size, image_size), dtype=np.float32)
    x[circle_mask(image_size, circle_radius)] = 1.0

    radon = Radon(image_size,
                  angles)  # alternative: RadonFanbeam(image_size, angles, source_dist)
    shearlet = ShearletTransform(image_size, image_size, [0.5] * n_scales)

    # Move the phantom to the GPU and replicate it along a batch axis.
    torch_x = torch.from_numpy(x).cuda()
    torch_x = torch_x.view(1, image_size, image_size).repeat(batch_size, 1, 1)
    sinogram = radon.forward(torch_x)

    # Backprojected data and its shearlet coefficients; sc is only used to
    # shape the coefficient-domain variables below.
    bp = radon.backward(sinogram)
    sc = shearlet.forward(bp)

    # Weights of the data term (p_0) and the shearlet term (p_1).
    p_0 = 0.02
    p_1 = 0.1
    # Per-scale shrinkage weights, reshaped to broadcast along axis 1.
    w = 3**shearlet.scales / 400
    w = w.view(1, -1, 1, 1).cuda()

    # Auxiliary (z) and accumulated-residual (u) variables: index 1 lives in
    # the shearlet coefficient domain, index 2 in the image domain.
    u_2 = torch.zeros_like(bp)
    z_2 = torch.zeros_like(bp)
    u_1 = torch.zeros_like(sc)
    z_1 = torch.zeros_like(sc)
    f = torch.zeros_like(bp)

    relative_error = []
    start_time = time.time()
    for i in range(100):
        # Right-hand side of the linear system solved for the image f.
        cg_y = p_0 * bp + p_1 * shearlet.backward(z_1 - u_1) + (z_2 - u_2)
        # Solve with conjugate gradient, warm-started from the previous f.
        f = cg(lambda x: p_0 * radon.backward(radon.forward(x)) +
               (1 + p_1) * x,
               f.clone(),
               cg_y,
               max_iter=50)
        sh_f = shearlet.forward(f)

        # Shrink the shearlet coefficients, clamp the image to be
        # non-negative, then update the residual accumulators.
        z_1 = shrink(sh_f + u_1, p_0 / p_1 * w)
        z_2 = (f + u_2).clamp_min(0)
        u_1 = u_1 + sh_f - z_1
        u_2 = u_2 + f - z_2

        # Relative L2 error of the first batch item against the phantom.
        relative_error.append(
            (torch.norm(torch_x[0] - f[0]) / torch.norm(torch_x[0])).item())

    runtime = time.time() - start_time
    print("Running time:", runtime)
    print("Running time per image:", runtime / batch_size)
    print("Relative error: ", 100 * relative_error[-1])
Exemplo n.º 9
0
    def test_differentiation(self):
        """Gradients must reach the input through forward and backprojection."""
        device = torch.device('cuda')

        angle_values = np.linspace(0, 2 * np.pi, 10).astype(np.float32)
        angles = torch.FloatTensor(angle_values).to(device)
        radon = Radon(64, angles)

        x = torch.FloatTensor(1, 64, 64).to(device)
        x.requires_grad = True

        # Chain both operators and backpropagate a scalar through them.
        sinogram = radon.forward(x)
        loss = torch.mean(radon.backprojection(sinogram))
        loss.backward()

        self.assertIsNotNone(x.grad)
Exemplo n.º 10
0
    def __init__(self, args, dataloader):
        """Load the trained network and build the 180-view Radon operator.

        Reads ``args.ckpt`` and ``args.height``. Weights are loaded onto CUDA
        when available, otherwise onto the CPU, and the network is put in
        eval mode.
        """
        self.args = args
        self.dataloader = dataloader
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        unet = UNet(input_nc=1, output_nc=1)
        self.net = nn.DataParallel(unet.to(self.device))

        checkpoint_path = os.path.join(args.ckpt)
        weights = torch.load(checkpoint_path, map_location=self.device)
        self.net.load_state_dict(weights)
        self.net.eval()

        self.gen_mask()

        # 180 views evenly covering [0, pi).
        view_angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, view_angles, clip_to_circle=True)
Exemplo n.º 11
0
def bench_parallel_forward(task, dtype, device, *bench_args):
    """Benchmark the parallel-beam forward projection for one task.

    Args:
        task: mapping with keys ``num_angles``, ``det_count``,
            ``batch_size`` and ``size``.
        dtype: dtype of the random input batch.
        device: device the input batch is created on.
        *bench_args: forwarded unchanged to ``benchmark``.

    Returns:
        Whatever ``benchmark`` returns for the timed forward projection.
    """
    num_angles = task["num_angles"]
    det_count = task["det_count"]
    size = task["size"]

    images = torch.randn(task["batch_size"], size, size,
                         dtype=dtype, device=device)

    # num_angles views evenly covering [0, pi).
    view_angles = np.linspace(0, np.pi, num_angles, endpoint=False)
    radon = Radon(view_angles, size, Projection.parallel_beam(det_count))

    def project(batch):
        return radon.forward(batch)

    return benchmark(project, images, *bench_args)
Exemplo n.º 12
0
def test_noise():
    """Check shapes and dtypes of emulated readings and their lookup."""
    device = torch.device('cuda')
    sinogram_shape = (3, 5, 10, 64)

    x = torch.FloatTensor(3, 5, 64, 64).to(device)
    x.requires_grad = True
    lookup_table = torch.FloatTensor(128, 64).to(device)

    angle_values = np.linspace(0, 2 * np.pi, 10).astype(np.float32)
    radon = Radon(64, torch.FloatTensor(angle_values))

    # Forward projection keeps the two leading batch dimensions.
    sinogram = radon.forward(x)
    assert_equal(sinogram.size(), sinogram_shape)

    # Emulated readings are int32 and shaped like the sinogram.
    readings = radon.emulate_readings(sinogram, 5, 10.0)
    assert_equal(readings.size(), sinogram_shape)
    assert_equal(readings.dtype, torch.int32)

    # The table lookup maps readings back to float32 values.
    y = radon.readings_lookup(readings, lookup_table)
    assert_equal(y.size(), sinogram_shape)
    assert_equal(y.dtype, torch.float32)
Exemplo n.º 13
0
 def __init__(self, net, args, dataloader, device):
     """Set up networks, optimizers, losses, logging and the Radon operator.

     Args:
         net: indexable collection; items 0-2 are stored as ``netG``,
             ``netDG`` and ``netDL``, and item 3 as ``netLoss`` when
             ``args.mode == 'vgg'``.
         args: reads ``mode``, ``lr``, ``load``, ``outdir``, ``log_fn`` and
             ``height``.
         dataloader: stored as ``self.dataloader``.
         device: device for the mask and the loss modules.

     Creating the checkpoint directory and the CSV log header is
     best-effort: an ``OSError`` is reported on stdout and does not abort
     construction.
     """
     self.netG = net[0]
     self.netDG = net[1]
     self.netDL = net[2]
     if args.mode == 'vgg':
         self.netLoss = net[3]

     def make_adam(module):
         # All three networks share the same Adam settings.
         return optim.Adam(module.parameters(), lr=args.lr, betas=(0.5, 0.999))

     self.optimizerG = make_adam(self.netG)
     self.optimizerDG = make_adam(self.netDG)
     self.optimizerDL = make_adam(self.netDL)

     self.dataloader = dataloader
     self.device = device
     self.args = args
     self.save_cp = True
     # Resume after the loaded epoch, or start from 0 when nothing is loaded.
     self.start_epoch = args.load + 1 if args.load >= 0 else 0
     self.mask = self.gen_mask().to(self.device)

     self.criterionL1 = torch.nn.L1Loss().to(self.device)
     self.criterionL2 = torch.nn.MSELoss().to(self.device)
     self.criterionGAN = GANLoss('vanilla').to(self.device)

     err_list = ["errDG", "errDL",
                 "errGG_GAN", "errGG_C", "errGG_F", "errGG_P",
                 "errGL_GAN", "errGL_C", "errGL_F", "errGL_P"]
     self.err = dict.fromkeys(err_list, None)

     if self.save_cp:
         ckpt_dir = os.path.join(args.outdir, 'ckpt')
         try:
             if not os.path.exists(ckpt_dir):
                 os.makedirs(ckpt_dir)
                 print('Created checkpoint directory')
             if args.load < 0:  # New log file
                 log_path = os.path.join(args.outdir, args.log_fn + '.csv')
                 with open(log_path, 'w', newline='') as f:
                     csvwriter = writer(f)
                     csvwriter.writerow(["epoch", "runtime"] + err_list)
         except OSError as err:
             # Keep going, but do not hide the failure: later checkpoint or
             # log writes would otherwise fail without any earlier hint.
             print('Warning: could not prepare checkpoint directory or log file:', err)

     # 180 views evenly covering [0, pi).
     angles = np.linspace(0, np.pi, 180, endpoint=False)
     self.radon = Radon(args.height, angles, clip_to_circle=True)
Exemplo n.º 14
0
    def test_shapes(self):
        """
        Leading batch/channel dimensions must be preserved by both operators.
        """
        device = torch.device('cuda')
        angle_values = np.linspace(0, 2 * np.pi, 10).astype(np.float32)
        radon = Radon(64, torch.FloatTensor(angle_values).to(device))

        # (image shape, expected sinogram shape): first with two batch
        # dimensions, then with none.
        cases = (
            ((2, 3, 64, 64), (2, 3, 10, 64)),
            ((64, 64), (10, 64)),
        )
        for image_shape, sinogram_shape in cases:
            x = torch.FloatTensor(*image_shape).to(device)
            y = radon.forward(x)
            self.assertEqual(y.size(), sinogram_shape)
            z = radon.backprojection(y)
            self.assertEqual(z.size(), image_shape)
Exemplo n.º 15
0
    def __init__(self, args, image):
        """Load the trained network and prepare the image, mask and operator.

        Reads ``args.twoends``, ``args.angles``, ``args.ckpt`` and
        ``args.height``. Weights are loaded onto CUDA when available,
        otherwise onto the CPU, and the network is put in eval mode.
        """
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Scale factor handed to the UNet: 7.68 with two ends and 23 angles,
        # 7.826... without.
        factor = (192 / (args.angles + 2)
                  if args.twoends else 180 / args.angles)

        unet = UNet(input_nc=1, output_nc=1, scale_factor=factor)
        self.net = nn.DataParallel(unet.to(self.device))
        checkpoint_path = os.path.join(args.ckpt)
        weights = torch.load(checkpoint_path, map_location=self.device)
        self.net.load_state_dict(weights)
        self.net.eval()

        self.image = image.to(self.device)
        self.twoends = args.twoends
        self.mask = self.gen_mask().to(self.device)

        # Radon operator over 180 views evenly covering [0, pi).
        view_angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, view_angles, clip_to_circle=True)
Exemplo n.º 16
0
# Example: place the Radon operator between two learnable layers.
# NOTE(review): batch_size and n_angles are used below but not defined in
# this excerpt - they must be set before this point.
image_size = 128
channels = 4

device = torch.device('cuda')
criterion = nn.L1Loss()

# One convolution acting on sinograms (1 -> channels) and one acting on
# backprojected images (channels -> 1).
sino_model = nn.Conv2d(1, channels, 5, padding=2).to(device)
image_model = nn.Conv2d(channels, 1, 3, padding=1).to(device)

# Allocate an input batch; the values are left uninitialized.
x = torch.FloatTensor(batch_size, 1, image_size, image_size).to(device)

# Instantiate the Radon transform with n_angles views from 0 to pi
# (both ends included).
angles = np.linspace(0, np.pi, n_angles)
radon = Radon(image_size, angles)

# Forward projection
sinogram = radon.forward(x)

# Apply sino_model to the sinograms
filtered_sinogram = sino_model(sinogram)

# Backprojection
backprojected = radon.backprojection(filtered_sinogram)

# Apply image_model to the backprojected images
y = image_model(backprojected)

# The loss is built as usual; this excerpt stops before any call to
# loss.backward().
loss = criterion(y, x)
Exemplo n.º 17
0
import matplotlib.pyplot as plt
import numpy as np
import torch
from utils import show_images

from torch_radon import Radon

# Example: filtered backprojection of a phantom loaded from disk.
device = torch.device('cuda')

# The image side length is taken from the first axis and reused as the
# number of views.
img = np.load("phantom.npy")
image_size = img.shape[0]
n_angles = image_size

# Instantiate Radon transform. clip_to_circle should be True when using filtered backprojection.
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles, clip_to_circle=True)

# No gradients are needed for a plain reconstruction.
with torch.no_grad():
    x = torch.FloatTensor(img).to(device)

    # Project, filter the sinogram, then backproject it.
    sinogram = radon.forward(x)
    filtered_sinogram = radon.filter_sinogram(sinogram)
    fbp = radon.backprojection(filtered_sinogram)

# Norm of the difference between the phantom and its reconstruction.
print("FBP Error", torch.norm(x - fbp).item())

# Show results
# Panel titles; the plotting call itself is outside this excerpt.
titles = [
    "Original Image", "Sinogram", "Filtered Sinogram",
    "Filtered Backprojection"
]
Exemplo n.º 18
0
def main():
    """Benchmark the Radon operators against the Astra wrappers and plot.

    Command-line options: ``--task`` (``all``, ``shearlet`` or a single task
    name), ``--image-size``, ``--angles`` (``-1`` means "same as image
    size"), ``--batch-size``, ``--samples``, ``--warmup``, ``--output``
    (figure path; empty shows the plot) and ``--circle``.

    For every selected task the Astra time, the single-precision time and
    the half-precision time are measured on a fresh random batch and turned
    into images per second. Requires CUDA.
    """
    parser = argparse.ArgumentParser(
        description='Benchmark and compare with Astra Toolbox')
    parser.add_argument('--task', default="all")
    parser.add_argument('--image-size', default=256, type=int)
    parser.add_argument('--angles', default=-1, type=int)
    parser.add_argument('--batch-size', default=32, type=int)
    parser.add_argument('--samples', default=50, type=int)
    parser.add_argument('--warmup', default=10, type=int)
    parser.add_argument('--output', default="")
    parser.add_argument('--circle', action='store_true')

    args = parser.parse_args()
    if args.angles == -1:
        args.angles = args.image_size

    device = torch.device("cuda")
    angles = np.linspace(0, 2 * np.pi, args.angles,
                         endpoint=False).astype(np.float32)

    radon = Radon(args.image_size, angles, clip_to_circle=args.circle)
    radon_fb = RadonFanbeam(args.image_size,
                            angles,
                            args.image_size,
                            clip_to_circle=args.circle)

    astra_pw = AstraParallelWrapper(angles, args.image_size)
    astra_fw = AstraFanbeamWrapper(angles, args.image_size)

    if args.task == "all":
        tasks = ["forward", "backward", "fanbeam forward", "fanbeam backward"]
    elif args.task == "shearlet":
        benchmark_shearlet(args)
        return
    else:
        tasks = [args.task]

    # (task name, banner, Astra callable, torch-radon callable), in the
    # order the results are plotted. Lambdas keep the attribute lookups
    # lazy, so an operation is only resolved if its task actually runs.
    benchmarks = [
        ("forward", "Benchmarking forward from device",
         lambda y: astra_pw.forward(y), lambda y: radon.forward(y)),
        ("backward", "Benchmarking backward from device",
         lambda y: astra_pw.backward(y), lambda y: radon.backward(y)),
        ("fanbeam forward", "Benchmarking fanbeam forward",
         lambda y: astra_fw.forward(y), lambda y: radon_fb.forward(y)),
        ("fanbeam backward", "Benchmarking fanbeam backward",
         lambda y: astra_fw.backward(y),
         lambda y: radon_fb.backprojection(y)),
    ]

    astra_fps = []
    radon_fps = []
    radon_half_fps = []

    for name, banner, astra_op, radon_op in benchmarks:
        if name not in tasks:
            continue

        print(banner)
        x = generate_random_images(args.batch_size, args.image_size)
        dx = torch.FloatTensor(x).to(device)

        astra_time = benchmark_function(astra_op, dx, args.samples,
                                        args.warmup)
        radon_time = benchmark_function(radon_op,
                                        dx,
                                        args.samples,
                                        args.warmup,
                                        sync=True)
        radon_half_time = benchmark_function(radon_op,
                                             dx.half(),
                                             args.samples,
                                             args.warmup,
                                             sync=True)

        # Convert seconds per batch into images per second.
        astra_fps.append(args.batch_size / astra_time)
        radon_fps.append(args.batch_size / radon_time)
        radon_half_fps.append(args.batch_size / radon_half_time)

        print("Speedup:", astra_time / radon_time)
        print("Speedup half-precision:", astra_time / radon_half_time)
        print()

    title = f"Image size {args.image_size}x{args.image_size}, {args.angles} angles and batch size {args.batch_size} on a {torch.cuda.get_device_name(0)}"

    plot(tasks, astra_fps, radon_fps, radon_half_fps, title)
    if args.output:
        plt.savefig(args.output, dpi=300)
    else:
        plt.show()
Exemplo n.º 19
0
def main():
    """Benchmark the Radon operators against Astra and plot the results.

    Command-line options: ``--task`` (``all`` or a single task name),
    ``--image-size``, ``--angles`` (``-1`` means "same as image size"),
    ``--batch-size``, ``--samples``, ``--warmup``, ``--output`` (figure
    path; empty shows the plot) and ``--circle``.

    Astra is only timed for the parallel-beam tasks; the fan-beam tasks
    record 0.0 for it. Requires CUDA.
    """
    parser = argparse.ArgumentParser(description='Benchmark and compare with Astra Toolbox')
    parser.add_argument('--task', default="all")
    parser.add_argument('--image-size', default=256, type=int)
    parser.add_argument('--angles', default=-1, type=int)
    parser.add_argument('--batch-size', default=32, type=int)
    parser.add_argument('--samples', default=50, type=int)
    parser.add_argument('--warmup', default=10, type=int)
    parser.add_argument('--output', default="")
    parser.add_argument('--circle', action='store_true')

    args = parser.parse_args()
    if args.angles == -1:
        args.angles = args.image_size

    device = torch.device("cuda")
    angles = np.linspace(0, 2 * np.pi, args.angles, endpoint=False).astype(np.float32)

    radon = Radon(args.image_size, angles, clip_to_circle=args.circle)
    radon_fb = RadonFanbeam(args.image_size, angles, args.image_size, clip_to_circle=args.circle)
    astra = AstraWrapper(angles)

    all_tasks = ["forward", "backward", "fanbeam forward", "fanbeam backward"]
    tasks = all_tasks if args.task == "all" else [args.task]

    astra_fps = []
    radon_fps = []
    radon_half_fps = []

    def new_batch():
        # Fresh random images, on the host and as a float tensor on the device.
        host = generate_random_images(args.batch_size, args.image_size)
        return host, torch.FloatTensor(host).to(device)

    def time_radon(op, batch):
        # Time `op` in single precision, then on a half-precision copy.
        single = benchmark_function(op, batch, args.samples, args.warmup, sync=True)
        half = benchmark_function(op, batch.half(), args.samples, args.warmup, sync=True)
        return single, half

    def record(astra_rate, single_time, half_time):
        # Store throughputs in images per second.
        astra_fps.append(astra_rate)
        radon_fps.append(args.batch_size / single_time)
        radon_half_fps.append(args.batch_size / half_time)

    if "forward" in tasks:
        print("Benchmarking forward from device")
        x, dx = new_batch()
        astra_time = benchmark_function(lambda y: astra.forward(y), x, args.samples, args.warmup)
        radon_time, radon_half_time = time_radon(lambda y: radon.forward(y), dx)

        record(args.batch_size / astra_time, radon_time, radon_half_time)

        print(astra_time, radon_time, radon_half_time)
        astra.clean()

    if "backward" in tasks:
        print("Benchmarking backward from device")
        x, dx = new_batch()
        # Astra backprojects from the projection id returned by its forward.
        pid, x = astra.forward(x)

        astra_time = benchmark_function(lambda y: astra.backproject(pid, args.image_size, args.batch_size), x,
                                        args.samples, args.warmup)
        radon_time, radon_half_time = time_radon(lambda y: radon.backward(y), dx)

        record(args.batch_size / astra_time, radon_time, radon_half_time)

        print(astra_time, radon_time, radon_half_time)
        astra.clean()

    if "fanbeam forward" in tasks:
        print("Benchmarking fanbeam forward")
        x, dx = new_batch()
        radon_time, radon_half_time = time_radon(lambda y: radon_fb.forward(y), dx)

        # No Astra measurement for fan-beam.
        record(0.0, radon_time, radon_half_time)
        astra.clean()

    if "fanbeam backward" in tasks:
        print("Benchmarking fanbeam backward")
        x, dx = new_batch()
        radon_time, radon_half_time = time_radon(lambda y: radon_fb.backprojection(y), dx)

        # No Astra measurement for fan-beam.
        record(0.0, radon_time, radon_half_time)
        astra.clean()

    title = f"Image size {args.image_size}x{args.image_size}, {args.angles} angles and batch size {args.batch_size} on a {torch.cuda.get_device_name(0)}"

    plot(tasks, astra_fps, radon_fps, radon_half_fps, title)
    if args.output:
        plt.savefig(args.output, dpi=300)
    else:
        plt.show()