Example #1
    def backward(self, input_field):
        # move the transfer function to the input's device when this module
        # does not pin a device itself
        if (self.dev is None
                and self.backward_transfer_fn.device != input_field.device):
            self.backward_transfer_fn = self.backward_transfer_fn.to(
                input_field.device)

        if self.normalize_output:
            input_magnitude_sum = magnitude_sum(input_field)

        # compute Fourier transform of input field
        fourier_input = self.padded_fft(input_field)

        # apply transfer function for backward prop
        fourier_output = utils.mul_complex(fourier_input,
                                           self.backward_transfer_fn)

        # Fourier transform back to get output
        output_cropped = self.cropped_ifft(fourier_output, self.slm_resolution)

        if self.normalize_output:
            output_magnitude_sum = magnitude_sum(output_cropped)
            output_cropped = output_cropped * (input_magnitude_sum /
                                               output_magnitude_sum)

        return output_cropped
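Every snippet in this listing relies on `utils.mul_complex`, which multiplies complex tensors stored in the pre-1.8 PyTorch convention of a trailing real/imaginary dimension of size 2 (the layout produced by `torch.rfft`). The helper itself is not shown; a minimal sketch of what it presumably computes, elementwise (a + bi)(c + di) = (ac - bd) + (ad + bc)i, assuming that layout:

import torch

def mul_complex_sketch(t1, t2):
    # elementwise product of complex tensors stored as (..., 2) = (real, imag)
    real = t1[..., 0] * t2[..., 0] - t1[..., 1] * t2[..., 1]
    imag = t1[..., 0] * t2[..., 1] + t1[..., 1] * t2[..., 0]
    return torch.stack((real, imag), dim=-1)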
Example #2
def correct_img_torch(x_s, scale, r, s, device, for_dag=True, eps=1e-9, pad='circular'):
    # size of the full (linear) convolution of the two kernels s and r
    conv_shape = (s.shape[2] + r.shape[2] - 1, s.shape[3] + r.shape[3] - 1)
    # spectra of the normalized kernels at the full convolution size (r is flipped first)
    S = utils.fft2(s/s.sum(), conv_shape[1], conv_shape[0])
    R = utils.fft2(utils.flip(r)/r.sum(), conv_shape[1], conv_shape[0])
    # spectrum of the combined kernel, then back to the spatial domain
    Q_unscaled = utils.mul_complex(R, S)
    q_unscaled = torch.irfft(Q_unscaled, signal_ndim=2, normalized=False, onesided=False)
    # subsample the combined kernel by the scale factor
    q = q_unscaled[:, :, np.mod(q_unscaled.shape[2], scale)::scale, np.mod(q_unscaled.shape[3], scale)::scale]
    Q = torch.rfft(q, signal_ndim=2, normalized=False, onesided=False)

    # Q_star = utils.conj(Q)
    # abs2_Q = utils.abs2(Q)
    # H = torch.cat( (Q_star[:,:,:,:,0:1]/(abs2_Q[:,:,:,:,0:1] + eps), Q_star[:,:,:,:,1:2]/(abs2_Q[:,:,:,:,0:1] + eps)), dim=4)

    # regularized complex reciprocal of the subsampled kernel's spectrum
    H = utils.inv_complex(Q, eps)

    # back to the spatial domain; normalize and roll the filter by one pixel in x and y
    h_ = torch.irfft(H, signal_ndim=2, normalized=False, onesided=False)
    h = utils.roll_y(utils.roll_x(h_/h_.sum(), -1), -1)

    # apply the flipped correction filter to the input image x_s
    x_h = utils.filter_2D_torch(x_s, utils.flip(h), device, pad=pad)

    if for_dag:
        # round-trip through bicubic upsampling and downsampling by the same factor
        x_h = utils.bicubic_up(x_h, scale, device)
        x_h = utils.downsample_bicubic(x_h, scale, device)

    return x_h
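The commented-out block above spells out what `utils.inv_complex(Q, eps)` is assumed to compute: a regularized complex reciprocal conj(Q) / (|Q|^2 + eps) in the same stacked layout. A minimal sketch under that assumption:

import torch

def inv_complex_sketch(q, eps=1e-9):
    # regularized reciprocal of a complex tensor stored as (..., 2): conj(q) / (|q|^2 + eps)
    abs2 = q[..., 0] ** 2 + q[..., 1] ** 2
    return torch.stack((q[..., 0] / (abs2 + eps),
                        -q[..., 1] / (abs2 + eps)), dim=-1)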
Example #3
def wiener_filter(img, psf, K):
    """ Performs Wiener filtering with a single PSF shared across image channels
    :param img: pytorch tensor of image (N,C,H,W)
    :param psf: pytorch tensor of psf (H,W)
    :param K: damping factor (can be set as a hyperparameter or learned)
    :return: Wiener-filtered image, real part only (N,C,H,W)
    """
    img = img.to(DEVICE)
    psf = psf.to(DEVICE)
    imag = torch.zeros(img.shape).to(DEVICE)
    img = utils.stack_complex(img, imag)
    img_fft = torch.fft(utils.ifftshift(img), 2)
    img_fft = img_fft.to(DEVICE)

    otf = psf2otf(psf, output_size=img.shape[2:4])
    otf = torch.stack((otf, otf, otf), 0)
    otf = torch.unsqueeze(otf, 0)

    conj_otf = utils.conj(otf)

    otf_img = utils.mul_complex(conj_otf, img_fft)

    denominator = abs_complex(otf)
    denominator[:, :, :, :, 0] += K
    product = utils.div_complex(otf_img, denominator)
    filtered = utils.ifftshift(torch.ifft(product, 2))
    filtered = torch.clamp(filtered, min=1e-5)

    return filtered[:, :, :, :, 0]
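The estimate assembled here follows the usual Wiener deconvolution form conj(H) * Y / (|H|^2 + K), with H the OTF and Y the image spectrum, assuming `abs_complex(otf)` supplies the |H|^2 term in the stacked layout. A minimal sketch of the same formula with native complex tensors (PyTorch >= 1.8), purely for reference and not part of the original code:

import torch

def wiener_estimate(Y, H, K):
    # Fourier-domain Wiener deconvolution: X_hat = conj(H) * Y / (|H|^2 + K)
    return torch.conj(H) * Y / (H.abs() ** 2 + K)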
Example #4
def propagate_through_lens(input_field, phase_delay):
    """
    Provides complex valued wave field upon hitting an optical element
    :param input_field: (H,W) tensor of phase delay of optical element
    :param phase_delay: (H,W) tensor of incoming light field
    :return: (H,W,2) complex valued incident light field
    """
    real, imag = utils.polar_to_rect(1, phase_delay)
    phase_delay = utils.stack_complex(real, imag)

    input_field = utils.stack_complex(input_field,
                                      torch.zeros(input_field.shape))
    return utils.mul_complex(input_field.cpu(), phase_delay.cpu())
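The operation above is a pure phase modulation: the incident field is multiplied elementwise by a unit-amplitude phasor exp(i * phase_delay), built via `polar_to_rect(1, phase_delay)`. A sketch of the same step with native complex tensors (PyTorch >= 1.8), for reference only:

import torch

def propagate_through_lens_native(input_field, phase_delay):
    # multiply the real-valued incident field by the unit-magnitude phasor exp(i * phase_delay)
    return input_field.to(torch.complex64) * torch.exp(1j * phase_delay)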
Example #5
    def forward(self, x):
        # model a point source at infinity (unit-amplitude plane wave)
        input_field = torch.ones((self.resolution, self.resolution))

        phase_delay = utils.heightmap_to_phase(self.heightmap,
                                               self.wavelength,
                                               self.refractive_idc)

        field = optics.propagate_through_lens(input_field, phase_delay)

        field = optics.circular_aperture(field, self.r_cutoff)

        # kernel_type='fresnel_conv' leads to NaNs
        element = Propagation(kernel_type='fresnel',
                              propagation_distances=self.focal_length,
                              slm_resolution=[self.resolution, self.resolution],
                              slm_pixel_pitch=[self.pixel_pitch, self.pixel_pitch],
                              wavelength=self.wavelength)

        field = element.forward(field)
        psf = utils.field_to_intensity(field)

        psf /= psf.sum()

        final = optics.convolve_img(x, psf)
        if not self.use_wiener:
            return final.to(DEVICE)
        else:
            # perform Wiener filtering
            final = final.to(DEVICE)
            imag = torch.zeros(final.shape).to(DEVICE)
            img = utils.stack_complex(final, imag)
            img_fft = torch.fft(utils.ifftshift(img), 2)

            otf = optics.psf2otf(psf, output_size=img.shape[2:4])

            otf = torch.stack((otf, otf, otf), 0)
            otf = torch.unsqueeze(otf, 0)
            conj_otf = utils.conj(otf)

            otf_img = utils.mul_complex(img_fft, conj_otf)

            denominator = optics.abs_complex(otf)
            denominator[:, :, :, :, 0] += self.K
            product = utils.div_complex(otf_img, denominator)

            filtered = utils.ifftshift(torch.ifft(product, 2))
            filtered = torch.clamp(filtered, min=1e-5)

            return filtered[:, :, :, :, 0]
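The helper `utils.heightmap_to_phase` used at the top of this forward pass is not shown; for a refractive element the conventional conversion is phi = 2*pi*(n - 1)*h / lambda, with h the height map, n the refractive index and lambda the wavelength. A minimal sketch under that assumption (the actual helper may differ):

import math

def heightmap_to_phase_sketch(heightmap, wavelength, refractive_idc):
    # phase delay of light crossing a height h of material with index n,
    # relative to the same distance in air: phi = 2 * pi * (n - 1) * h / lambda
    return 2 * math.pi * (refractive_idc - 1) * heightmap / wavelength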