コード例 #1
0
    def forward(ctx, x, angles, tex_cache, rays_cfg):
        """Autograd forward: run ``torch_radon_cuda.forward`` and stash state.

        Args:
            ctx: autograd context; receives ``tex_cache`` / ``rays_cfg`` as
                attributes and ``angles`` via ``save_for_backward``.
            x: input tensor handed unchanged to the CUDA extension.
            angles: tensor forwarded to the extension and saved for backward.
            tex_cache: extension-side cache object, stored on ``ctx`` as-is.
            rays_cfg: ray-configuration object, stored on ``ctx`` as-is.

        Returns:
            The value produced by ``torch_radon_cuda.forward`` (called a
            sinogram by the original code -- presumably the projected data;
            confirm against the extension).
        """
        projected = torch_radon_cuda.forward(x, angles, tex_cache, rays_cfg)

        # Plain attributes for the opaque config objects; the tensor goes
        # through save_for_backward so autograd can version-check it.
        ctx.tex_cache, ctx.rays_cfg = tex_cache, rays_cfg
        ctx.save_for_backward(angles)

        return projected
コード例 #2
0
    def backward(ctx, grad_x):
        """Autograd backward: propagate ``grad_x`` via ``torch_radon_cuda.forward``.

        Args:
            ctx: autograd context populated by the matching ``forward``; it
                must carry ``tex_cache`` and ``rays_cfg`` attributes and
                exactly one saved tensor (``angles``).
            grad_x: gradient of the loss w.r.t. this Function's output.

        Returns:
            A 4-tuple, one entry per forward input. The first is the gradient
            for the first input, or ``None`` when it does not require grad.
            The remaining three entries are always ``None``.
        """
        # Skip the CUDA call entirely when no gradient is needed for the
        # only differentiable input.
        if not ctx.needs_input_grad[0]:
            return None, None, None, None

        # saved_tensors replaces the deprecated saved_variables accessor.
        angles, = ctx.saved_tensors
        # The extension is fed a dense tensor; contiguous() returns grad_x
        # itself when it is already contiguous, so no explicit check needed.
        grad_x = grad_x.contiguous()
        # NOTE(review): the adjoint used here is the *forward* kernel, which is
        # right if this Function's forward is a backprojection -- confirm.
        grad = torch_radon_cuda.forward(grad_x, angles, ctx.tex_cache,
                                        ctx.rays_cfg)
        return grad, None, None, None
コード例 #3
0
    def forward(ctx, x, det_count, det_spacing, angles, tex_cache, clip_to_circle):
        """Autograd forward: run ``torch_radon_cuda.forward`` and stash state.

        Args:
            ctx: autograd context; every non-tensor argument is stored on it
                as an attribute and ``angles`` is saved via
                ``save_for_backward``.
            x: input tensor handed unchanged to the CUDA extension.
            det_count: detector count argument, passed through unchanged.
            det_spacing: detector spacing argument, passed through unchanged.
            angles: tensor forwarded to the extension and saved for backward.
            tex_cache: extension-side cache object, stored on ``ctx`` as-is.
            clip_to_circle: flag passed through to the extension unchanged.

        Returns:
            The value produced by ``torch_radon_cuda.forward`` (called a
            sinogram by the original code -- presumably the projected data;
            confirm against the extension).
        """
        projected = torch_radon_cuda.forward(
            x, det_count, det_spacing, angles, tex_cache, clip_to_circle
        )

        # Everything backward() needs again, except the tensor, is kept as a
        # plain ctx attribute.
        for attr_name, attr_value in (
            ("tex_cache", tex_cache),
            ("det_count", det_count),
            ("det_spacing", det_spacing),
            ("clip_to_circle", clip_to_circle),
        ):
            setattr(ctx, attr_name, attr_value)
        ctx.save_for_backward(angles)

        return projected
コード例 #4
0
 def backward(ctx, grad_x):
     """Autograd backward: propagate ``grad_x`` via ``torch_radon_cuda.forward``.

     Args:
         ctx: autograd context populated by the matching ``forward``; it must
             carry ``det_count``, ``det_spacing``, ``tex_cache`` and
             ``clip_to_circle`` attributes and exactly one saved tensor
             (``angles``).
         grad_x: gradient of the loss w.r.t. this Function's output.

     Returns:
         A 6-tuple, one entry per forward input. The first is the gradient for
         the first input, or ``None`` when it does not require grad. The
         remaining five entries are always ``None``.
     """
     # Skip the CUDA call entirely when no gradient is needed for the only
     # differentiable input.
     if not ctx.needs_input_grad[0]:
         return None, None, None, None, None, None

     # saved_tensors replaces the deprecated saved_variables accessor.
     angles, = ctx.saved_tensors
     # Autograd may hand us a non-contiguous gradient; give the extension a
     # dense tensor (no-op copy-free when already contiguous).
     grad_x = grad_x.contiguous()
     grad = torch_radon_cuda.forward(grad_x, ctx.det_count, ctx.det_spacing, angles, ctx.tex_cache,
                                     ctx.clip_to_circle)
     return grad, None, None, None, None, None
コード例 #5
0
 def backward(ctx, grad_x):
     """Autograd backward: propagate ``grad_x`` via ``torch_radon_cuda.forward``.

     Args:
         ctx: autograd context populated by the matching ``forward``; it must
             carry ``tex_cache`` and ``rays_cfg`` attributes and exactly one
             saved tensor (``angles``).
         grad_x: gradient of the loss w.r.t. this Function's output.

     Returns:
         A 4-tuple, one entry per forward input. The first is the gradient for
         the first input, or ``None`` when it does not require grad. The
         remaining three entries are always ``None``.
     """
     # Skip the CUDA call entirely when no gradient is needed for the only
     # differentiable input.
     if not ctx.needs_input_grad[0]:
         return None, None, None, None

     # saved_tensors replaces the deprecated saved_variables accessor.
     angles, = ctx.saved_tensors
     # Autograd may hand us a non-contiguous gradient; give the extension a
     # dense tensor (returns grad_x itself when already contiguous).
     grad_x = grad_x.contiguous()
     grad = torch_radon_cuda.forward(grad_x, angles, ctx.tex_cache,
                                     ctx.rays_cfg)
     return grad, None, None, None