Example #1
def compute_inverse_filter_basic_fft(ker, eps, ps):
    """
    ft (nks, ps, ps)
    """
    nks = ker.shape[0]
    device = ker.device
    inv_ker = []
    for n in range(nks):
        # OTF of the n-th kernel, zero-padded to (ps, ps)
        K = psf2otf(ker[n], (ps, ps))
        # regularized inverse in the frequency domain: conj(K) / (|K|^2 + eps)
        D = utils.conj(K) / (
            utils.prod(utils.conj(K), K).sum(-1, keepdim=True) + eps)
        # back to a spatial-domain filter
        d = otf2psf(D, (ps, ps))
        inv_ker.append(d)
    inv_ker = torch.cat(inv_ker)
    return inv_ker
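
For orientation, here is a minimal self-contained sketch of the same conj(K) / (|K|^2 + eps) construction written against the modern torch.fft complex-tensor API instead of the project's psf2otf/otf2psf helpers and the legacy (real, imag) layout. The plain zero-padding below skips the circular-shift centering those helpers presumably perform, so it is illustrative only.

import torch

def inverse_filter_sketch(ker, eps, ps):
    # zero-pad the kernel to (ps, ps) and take its FFT (simplified psf2otf)
    K = torch.fft.fft2(ker, s=(ps, ps))
    # regularized inverse: conj(K) / (|K|^2 + eps)
    D = K.conj() / (K.abs() ** 2 + eps)
    # back to a spatial-domain filter (simplified otf2psf)
    return torch.fft.ifft2(D).real

kernels = torch.rand(4, 5, 5)
inv = torch.stack([inverse_filter_sketch(k, 1e-3, 32) for k in kernels])
print(inv.shape)  # torch.Size([4, 32, 32])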
Example #2
def wiener_filter(img, psf, K):
    """ Performs Wiener filtering on a single channel
    :param img: pytorch tensor of image (N,C,H,W)
    :param psf: pytorch tensor of psf (H,W)
    :param K: damping factor (can be input through hyps or learned)
    :return: Wiener filtered image in one channel (N,C,H,W)
    """
    img = img.to(DEVICE)
    psf = psf.to(DEVICE)
    # promote the real image to a complex (real, imag) pair and take its FFT
    imag = torch.zeros(img.shape).to(DEVICE)
    img = utils.stack_complex(img, imag)
    img_fft = torch.fft(utils.ifftshift(img), 2)
    img_fft = img_fft.to(DEVICE)

    # OTF of the PSF, replicated across the batch and color channels
    otf = psf2otf(psf, output_size=img.shape[2:4])
    otf = torch.stack((otf, otf, otf), 0)
    otf = torch.unsqueeze(otf, 0)

    conj_otf = utils.conj(otf)

    otf_img = utils.mul_complex(conj_otf, img_fft)

    # Wiener denominator: add the damping factor K to the OTF magnitude term
    denominator = abs_complex(otf)
    denominator[:, :, :, :, 0] += K
    product = utils.div_complex(otf_img, denominator)
    filtered = utils.ifftshift(torch.ifft(product, 2))
    filtered = torch.clamp(filtered, min=1e-5)

    return filtered[:, :, :, :, 0]
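
A hedged usage sketch of the same conj(H) * F(img) / (|H|^2 + K) pipeline with the modern torch.fft API: blur an image by pointwise spectral multiplication, then restore it with the damped inverse. The corner-anchored PSF here stands in for the ifftshift/psf2otf centering used above.

import torch

def wiener_restore(blurred, psf, K=1e-3):
    H = torch.fft.rfft2(psf, s=blurred.shape[-2:])
    B = torch.fft.rfft2(blurred)
    X = H.conj() * B / (H.abs() ** 2 + K)
    return torch.fft.irfft2(X, s=blurred.shape[-2:])

img = torch.rand(1, 3, 64, 64)
psf = torch.zeros(64, 64)
psf[:5, :5] = 1.0 / 25.0  # 5x5 box blur, corner-anchored
blurred = torch.fft.irfft2(torch.fft.rfft2(img) * torch.fft.rfft2(psf))
restored = wiener_restore(blurred, psf)
# the residual is small wherever |H|^2 dominates the damping term K
print(((restored - img) ** 2).mean())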
Example #3
    def backward(ctx, grad_output, grad_c=None):

        if ctx.needs_input_grad[2] or ctx.needs_input_grad[3]:
            alpha, B, G, Y, wshape = ctx.intermediate_results
            channels = Y.size(2)
        elif ctx.needs_input_grad[0]:
            alpha, B, G = ctx.intermediate_results

        grad_input = grad_weights = grad_alpha = None

        if (ctx.needs_input_grad[0] or ctx.needs_input_grad[2]
                or ctx.needs_input_grad[3]):
            # D = |B|^2 + alpha * sum_d |G_d|^2, the Wiener denominator
            D = cabs(B).pow(2).unsqueeze(1)
            T = cabs(G).pow(2).sum(dim=1).unsqueeze(0)
            T = T.mul(alpha.unsqueeze(0).unsqueeze(-1).unsqueeze(-1))
            D = D + T
            del T
            D = D.unsqueeze(-1)

        if ctx.needs_input_grad[0] or ctx.needs_input_grad[2]:
            Z = torch.rfft(grad_output, 2)

        if ctx.needs_input_grad[0]:
            # gradient w.r.t. the input image
            grad_input = torch.irfft(cmul(B.unsqueeze(1), Z).div(D), 2,
                                     signal_sizes=grad_output.shape[-2:])
            grad_input = grad_input.sum(dim=1)

        if 'B' in locals(): del B
        if ctx.needs_input_grad[2]:
            # gradient w.r.t. the regularization kernels
            ws = tuple(int(i) for i in -(np.asarray(wshape[-2:]) // 2))
            ws = (0, 0, 0, 0) + ws
            U = cmul(conj(Z), Y.div(D.pow(2)))
            U = U[..., 0].unsqueeze(-1).unsqueeze(2)
            U = U.mul(G.unsqueeze(0))
            U = torch.irfft(U, 2, signal_sizes=grad_output.shape[-2:])
            U = utils.shift_transpose(U, ws, bc='circular')
            U = U[..., 0:wshape[3], 0:wshape[4]]
            grad_weights = -2 * U.mul(
                alpha.unsqueeze(0).unsqueeze(2).unsqueeze(-1).unsqueeze(-1))
            del U
            grad_weights = grad_weights.sum(dim=0)
            if wshape[2] == 1:
                grad_weights = grad_weights.sum(dim=2, keepdim=True)
            if wshape[0] == 1 and alpha.size(0) != 1:
                grad_weights = grad_weights.sum(dim=0)

        if 'Z' in locals(): del Z
        if ctx.needs_input_grad[3]:
            # gradient w.r.t. alpha (exponentiated in the forward pass)
            Y = Y.mul(cabs(G).pow(2).sum(dim=1).unsqueeze(0).unsqueeze(-1))
            Y = Y.div(D.pow(2))
            Y = torch.irfft(Y, 2, signal_sizes=grad_output.shape[-2:])
            Y = Y.mul(-alpha.unsqueeze(0).unsqueeze(-1).unsqueeze(-1))
            Y = Y.mul(grad_output)
            grad_alpha = Y.sum(dim=4).sum(dim=3).sum(dim=0)
            if channels != 1 and alpha.size(-1) == 1:
                grad_alpha = grad_alpha.sum(dim=-1, keepdim=True)

        return grad_input, None, grad_weights, grad_alpha
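
A hand-written backward like the one above is usually validated with torch.autograd.gradcheck. The toy Function below only illustrates that pattern, frequency-domain pointwise scaling and its adjoint; it is not this project's Wiener layer, and it uses the modern torch.fft API.

import torch

class FreqScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, s):
        ctx.save_for_backward(s)
        ctx.out_shape = x.shape[-2:]
        return torch.fft.irfft2(torch.fft.rfft2(x) * s, s=x.shape[-2:])

    @staticmethod
    def backward(ctx, grad_output):
        (s,) = ctx.saved_tensors
        # the adjoint of pointwise multiplication by s is multiplication
        # by conj(s)
        grad_x = torch.fft.irfft2(torch.fft.rfft2(grad_output) * s.conj(),
                                  s=ctx.out_shape)
        return grad_x, None

x = torch.rand(8, 8, dtype=torch.double, requires_grad=True)
s = torch.randn(8, 5, dtype=torch.cdouble)  # rfft2 keeps W//2 + 1 columns
print(torch.autograd.gradcheck(FreqScale.apply, (x, s)))  # True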
Example #4
    def forward(self, x):
        # model a point source at infinity as a uniform plane-wave input field
        input_field = torch.ones((self.resolution, self.resolution))

        phase_delay = utils.heightmap_to_phase(self.heightmap,
                                               self.wavelength,
                                               self.refractive_idc)

        field = optics.propagate_through_lens(input_field, phase_delay)

        field = optics.circular_aperture(field, self.r_cutoff)

        # kernel_type='fresnel_conv' leads to NaNs
        element = Propagation(kernel_type='fresnel',
                              propagation_distances=self.focal_length,
                              slm_resolution=[self.resolution, self.resolution],
                              slm_pixel_pitch=[self.pixel_pitch, self.pixel_pitch],
                              wavelength=self.wavelength)

        field = element.forward(field)
        psf = utils.field_to_intensity(field)

        # normalize the PSF to unit energy
        psf /= psf.sum()

        final = optics.convolve_img(x, psf)
        if not self.use_wiener:
            return final.to(DEVICE)
        else:
            # perform Wiener filtering
            final = final.to(DEVICE)
            imag = torch.zeros(final.shape).to(DEVICE)
            img = utils.stack_complex(final, imag)
            img_fft = torch.fft(utils.ifftshift(img), 2)

            otf = optics.psf2otf(psf, output_size=img.shape[2:4])

            otf = torch.stack((otf, otf, otf), 0)
            otf = torch.unsqueeze(otf, 0)
            conj_otf = utils.conj(otf)

            otf_img = utils.mul_complex(img_fft, conj_otf)

            denominator = optics.abs_complex(otf)
            denominator[:, :, :, :, 0] += self.K
            product = utils.div_complex(otf_img, denominator)

            filtered = utils.ifftshift(torch.ifft(product, 2))
            filtered = torch.clamp(filtered, min=1e-5)

            return filtered[:, :, :, :, 0]
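
The heightmap-to-phase step follows the thin-element approximation; below is a hypothetical stand-in for utils.heightmap_to_phase (the project's helper may differ in conventions, e.g. phase wrapping or sign).

import math
import torch

def heightmap_to_phase_sketch(heightmap, wavelength, refractive_idc):
    # thin-element approximation: traversing a height h of material with
    # index n (relative to air) delays the phase by 2*pi*(n - 1)*h / wavelength
    return 2 * math.pi * (refractive_idc - 1) * heightmap / wavelength

phase = heightmap_to_phase_sketch(torch.full((8, 8), 1.0e-6), 532e-9, 1.46)
field = torch.exp(1j * phase)  # unit-amplitude field after the element
print(phase.mean(), field.shape)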
Example #5
    def run(self):
        # run confocal diffuse tomography reconstruction

        with h5py.File('./data/' + self.scene + '.mat', 'r') as f:
            meas = np.array(f['meas']).transpose(2, 1, 0)

        # trim scene to 1 meter along the z-dimension
        # and downsample to ~50 ps time binning from 16 ps
        b = meas[:417, :, :]
        downsampled = np.zeros((self.Nz, 32, 32))
        x = np.linspace(0, 1, self.Nz)
        xp = np.linspace(0, 1, 417)
        for i in range(meas.shape[1]):
            for j in range(meas.shape[2]):
                downsampled[:, i, j] = np.interp(x, xp, b[:, i, j].squeeze())
        b = downsampled
        b /= np.max(b)  # normalize to 0 to 1

        # initialize pytorch arrays
        b = torch.from_numpy(b).to(self.device)[None, None, :, :, :].float()
        x = torch.zeros(b.size()[0], 1, 2 * self.Nz, 2 * self.Nx,
                        2 * self.Ny).to(self.device)

        # construct inverse psf for Wiener filtering:
        # invpsf = conj(F) / (|F|^2 + 1/SNR)
        tmp = compl_mul(self.diffusion_fpsf, conj(self.diffusion_fpsf))
        tmp = tmp + 1 / self.snr
        invpsf = compl_mul(conj(self.diffusion_fpsf), 1 / tmp)

        # measure inversion runtime
        if self.device.type == 'cpu':
            start = time.time()
        else:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

        # pad measurements
        x = self.MT(b)

        # perform f-k migration on measurements
        x_fk = self.AT(x)

        # perform deconvolution
        x_deconv = compl_mul(x.rfft(3, onesided=False),
                             invpsf).ifft(3)[:, :, :, :, :, 0]

        # confocal inverse filter
        x = self.AT(x_deconv)

        # measure elapsed time
        if self.device.type == 'cpu':
            stop = time.time()
            print('Elapsed time: %.02f ms' % (1000 * (stop - start)))
        else:
            end.record()
            torch.cuda.synchronize()
            print('Elapsed time: %.02f ms' % (start.elapsed_time(end)))

        # plot results
        x_npy = x.cpu().data.numpy().squeeze()[:self.Nz, :self.Nx, :self.Ny]
        b_npy = b.cpu().data.numpy().squeeze()
        x_deconv_npy = x_deconv.cpu().data.numpy().squeeze()[
            :self.Nz, :self.Nx, :self.Ny]
        x_fk_npy = x_fk.cpu().data.numpy().squeeze()[
            :self.Nz, :self.Nx, :self.Ny]

        # trim any amplified noise at the very end of the volume
        x_npy[-15:, :, :] = 0

        if self.pause > 0:
            plt.suptitle('Measurements and reconstruction')
            plt.subplot(231)
            plt.imshow(np.max(b_npy, axis=0),
                       cmap='gray',
                       extent=[self.xmin, self.xmax, self.ymin, self.ymax])
            plt.xlabel('x (m)')
            plt.ylabel('y (m)')
            plt.subplot(232)
            plt.imshow(
                np.max(b_npy, axis=1),
                aspect=(self.xmax - self.xmin) / (self.zmax / 3e8 * 1e9),
                cmap='gray',
                extent=[
                    self.xmin, self.xmax, self.zmax / 3e8 * 1e9, self.zmin
                ])
            plt.xlabel('x (m)')
            plt.ylabel('t (ns)')
            plt.subplot(233)
            plt.imshow(
                np.max(b_npy, axis=2),
                aspect=(self.ymax - self.ymin) / (self.zmax / 3e8 * 1e9),
                cmap='gray',
                extent=[
                    self.ymin, self.ymax, self.zmax / 3e8 * 1e9, self.zmin
                ])
            plt.xlabel('y (m)')
            plt.ylabel('t (ns)')

            plt.subplot(234)
            plt.imshow(np.max(x_npy, axis=0),
                       cmap='gray',
                       extent=[self.xmin, self.xmax, self.ymin, self.ymax])
            plt.xlabel('x (m)')
            plt.ylabel('y (m)')
            plt.subplot(235)
            plt.imshow(np.max(x_npy, axis=1),
                       aspect=(self.xmax - self.xmin) / (self.zmax / 2),
                       cmap='gray',
                       extent=[self.xmin, self.xmax, self.zmax / 2, self.zmin])
            plt.xlabel('x (m)')
            plt.ylabel('z (m)')
            plt.subplot(236)
            plt.imshow(np.max(x_npy, axis=2),
                       aspect=(self.ymax - self.ymin) / (self.zmax / 2),
                       cmap='gray',
                       extent=[self.ymin, self.ymax, self.zmax / 2, self.zmin])
            plt.xlabel('y (m)')
            plt.ylabel('z (m)')
            plt.tight_layout()

            plt.pause(self.pause)

        # return measurements, f-k migration, deconvolved meas, reconstruction
        return b_npy, x_fk_npy, x_deconv_npy, x_npy
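
The timing logic above generalizes to a small helper; this is a sketch of that device-aware pattern (CUDA kernels launch asynchronously, so GPU timing needs cuda.Event plus a synchronize, while CPU timing can use time.time() directly).

import time
import torch

def timed_ms(fn, device):
    # return fn() plus elapsed milliseconds, measured appropriately per device
    if device.type == 'cpu':
        start = time.time()
        out = fn()
        elapsed = 1000 * (time.time() - start)
    else:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = fn()
        end.record()
        torch.cuda.synchronize()  # wait for queued kernels before reading
        elapsed = start.elapsed_time(end)
    return out, elapsed

out, ms = timed_ms(lambda: torch.rand(512, 512) @ torch.rand(512, 512),
                   torch.device('cpu'))
print('Elapsed time: %.02f ms' % ms)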
Example #6
def compute_inverse_filter_fft_penalized(ker, eps, ps, betas):
    """
    fts (len(betas), 3, ps, ps)
    """
    nks, hks, wks = ker.shape

    # pad kernels smaller than 3x3 up to at least 3x3 (assumes odd sizes)
    if wks < 3 or hks < 3:
        hei = max(3, hks)
        wid = max(3, wks)
        ker2 = torch.zeros(nks, hei, wid, device=ker.device)
        ker2[:, hei // 2 - hks // 2:hei // 2 + hks // 2 + 1,
             wid // 2 - wks // 2:wid // 2 + wks // 2 + 1] = ker
        ker = ker2
        _, hks, wks = ker.shape
    centx = wks // 2
    centy = hks // 2

    grad_y = torch.zeros(1, 3, 3, device=ker.device)
    grad_y[0, 1, 0] = -1
    grad_y[0, 1, 1] = 1
    grad_x = grad_y.transpose(1, 2)

    grad = torch.zeros(2, hks, wks, device=ker.device)
    # place the 3x3 gradient stencils at the kernel center (height, width)
    grad[0, centy - 1:centy + 2, centx - 1:centx + 2] = grad_y
    grad[1, centy - 1:centy + 2, centx - 1:centx + 2] = grad_x

    # compute denom
    otfks = []
    for n in range(nks):
        otfks.append(psf2otf(ker[n], (ps, ps)))
    otfks.append(psf2otf(grad_y[0], (ps, ps)))
    otfks.append(psf2otf(grad_x[0], (ps, ps)))

    mod2_otfks = []
    for n in range(nks + 2):
        K = otfks[n]
        mod2_otfks.append(utils.prod(utils.conj(K), K))
    sum_mod2 = torch.stack(mod2_otfks).sum(0)

    inv_filters = {'beta': [], 'ker': [], 'fts': []}
    inv_filters['beta'] = betas
    inv_filters['ker'] = torch.cat([ker, grad])

    for beta in betas:
        denom = sum_mod2.mul(beta) + eps
        ft = torch.zeros(nks, 3, ps, ps, device=ker.device)
        # the gradient-penalty inverse filters are shared by all kernels
        D = utils.conj(otfks[-2]) / denom
        d_gy = otf2psf(D, (ps, ps))
        D = utils.conj(otfks[-1]) / denom
        d_gx = otf2psf(D, (ps, ps))
        for n in range(nks):
            K = otfks[n]
            D = utils.conj(K) / denom
            ft[n, 0] = otf2psf(D, (ps, ps))
            ft[n, 1] = d_gy
            ft[n, 2] = d_gx
        # append once per beta; appending inside the kernel loop would
        # duplicate ft nks times
        inv_filters['fts'].append(ft)
    inv_filters['fts'] = torch.cat(inv_filters['fts'])
    return inv_filters
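
Condensed with modern torch.fft, the construction amounts to dividing each OTF's conjugate by the shared penalized denominator beta * (|K|^2 + |Gy|^2 + |Gx|^2) + eps. The sketch below assumes a single kernel and skips the psf2otf centering convention.

import torch

def penalized_inverse_sketch(ker, eps, ps, beta):
    gy = torch.tensor([[0., 0., 0.], [-1., 1., 0.], [0., 0., 0.]])
    gx = gy.T
    otfs = [torch.fft.fft2(k, s=(ps, ps)) for k in (ker, gy, gx)]
    # shared denominator over the kernel and both gradient penalties
    denom = beta * sum(o.abs() ** 2 for o in otfs) + eps
    # one (ps, ps) spatial inverse filter per OTF, as in ft[n, 0..2] above
    return torch.stack([torch.fft.ifft2(o.conj() / denom).real for o in otfs])

ft = penalized_inverse_sketch(torch.rand(5, 5), eps=1e-3, ps=32, beta=1.0)
print(ft.shape)  # torch.Size([3, 32, 32])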
Example #7
    def forward(ctx, input, blurKernel, weights, alpha):
        """
        Wiener Filter for a batch of input images. (Filtering is taking place in the Frequency domain under the
                 assumption of periodic boundary conditions for the input image.)

        input: (torch.(cuda.)Tensor) Input image tensor of size B x C x H x W
        blurKernel: (torch.(cuda.)Tensor) PSFs tensor of size B x C x Hk x Wk
        weights: (torch.(cuda.)Tensor) Regularization kernels of size D x C x Hw x Ww
        alpha: (float) Regularization parameter of shape 1 x 1
        returns: (torch.(cuda.)Tensor) Wiener filter output tensor B x 1 x C x H x H

        output = F^H (B^H*F(input)/(|B|^2+exp(alpha)*|W|^2))
        """

        assert (input.dim() < 5), "The input must be at most a 4D tensor."
        while input.dim() < 4:
            input = input.unsqueeze(0)

        batch = input.size(0)
        channels = input.size(1)

        assert (blurKernel.dim() < 5), \
            "The blurring kernel must be at most a 4D tensor."
        while blurKernel.dim() < 4:
            blurKernel = blurKernel.unsqueeze(0)

        bshape = tuple(blurKernel.shape)
        assert (bshape[0] in (1, batch) and bshape[1] in (1, channels)), \
            "Invalid blurring kernel dimensions."

        N = alpha.size(0)
        assert (alpha.dim() == 2 and alpha.size(-1) in (1, channels)), \
            "Invalid dimensions for the alpha parameter. The expected shape of the " \
            + "tensor is {} x [{}|{}]".format(N, 1, channels)
        alpha = alpha.exp()

        assert (weights.dim() > 3 and weights.dim() < 6), \
            "The regularization kernel must be a 4D or 5D tensor."

        if weights.dim() < 5:
            weights = weights.unsqueeze(0)

        wshape = tuple(weights.shape)
        assert (wshape[0] in (1, N) and wshape[2] in (1, channels)), \
            "Invalid regularization kernel dimensions."

        # Zero-padding of the blur kernel to match the input size
        B = torch.zeros(bshape[0], bshape[1], input.size(2),
                        input.size(3)).type_as(blurKernel)
        B[..., 0:bshape[2], 0:bshape[3]] = blurKernel
        del blurKernel
        # Circular shift of the zero-padded blur kernel
        bs = tuple(int(i) for i in -(np.asarray(bshape[-2:]) // 2))
        bs = (0, 0) + bs
        B = utils.shift(B, bs, bc='circular')
        # FFT of B
        B = torch.rfft(B, 2)

        # Zero-padding of the spatial dimensions of the weights to match the input size
        G = torch.zeros(wshape[0], wshape[1], wshape[2], input.size(2),
                        input.size(3)).type_as(weights)
        G[..., 0:wshape[3], 0:wshape[4]] = weights
        del weights
        # Circular shift of the zero-padded weights
        ws = tuple(int(i) for i in -(np.asarray(wshape[-2:]) // 2))
        ws = (0, 0, 0) + ws
        G = utils.shift(G, ws, bc='circular')
        # FFT of G
        G = torch.rfft(G, 2)

        Y = cmul(conj(B), torch.rfft(input, 2)).unsqueeze(1)

        ctx.intermediate_results = tuple()
        if ctx.needs_input_grad[2] or ctx.needs_input_grad[3]:
            ctx.intermediate_results += (alpha, B, G, Y, wshape)
        elif ctx.needs_input_grad[0]:
            ctx.intermediate_results += (alpha, B, G)

        # assemble the denominator |B|^2 + alpha * sum_d |W_d|^2 (alpha was
        # exponentiated above) and divide in the frequency domain
        B = cabs(B).unsqueeze(-1)
        G = cabs(G).pow(2).sum(dim=1)

        G = G.mul(alpha.unsqueeze(-1).unsqueeze(-1)).unsqueeze(0).unsqueeze(-1)

        G = G + B.pow(2).unsqueeze(1)
        return torch.irfft(Y.div(G), 2, signal_sizes=input.shape[-2:])
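
For reference, the closed-form solution in the docstring, output = F^H(B^H F(input) / (|B|^2 + exp(alpha) |W|^2)), restated with the modern torch.fft API for a single image, kernel, and regularizer; the circular-shift bookkeeping of the padded kernels above is omitted in this sketch.

import torch

def wiener_regularized_sketch(x, b, w, log_alpha):
    # x: (H, W) image, b: blur kernel, w: regularization kernel
    B = torch.fft.rfft2(b, s=x.shape)
    W = torch.fft.rfft2(w, s=x.shape)
    X = torch.fft.rfft2(x)
    Y = B.conj() * X / (B.abs() ** 2 + torch.exp(log_alpha) * W.abs() ** 2)
    return torch.fft.irfft2(Y, s=x.shape)

x = torch.rand(32, 32)
b = torch.ones(5, 5) / 25.0  # box blur PSF
w = torch.tensor([[0., -1., 0.], [-1., 4., -1.], [0., -1., 0.]])  # Laplacian
print(wiener_regularized_sketch(x, b, w, torch.tensor(-2.0)).shape)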