Exemplo n.º 1
0
def double_phase_amplitude_coding(target_phase,
                                  target_amp,
                                  prop_dist,
                                  wavelength,
                                  feature_size,
                                  prop_model='ASM',
                                  propagator=None,
                                  dtype=torch.float32,
                                  precomputed_H=None):
    """
    Compute an SLM phase pattern with double phase-amplitude coding.

    The target amplitude and phase are merged into one complex field, that
    field is propagated a single time with ``utils.propagate_field``, and the
    result is handed to ``double_phase`` for encoding.

    Input
    -----
    :param target_phase: the phase at the target image plane.
    :param target_amp: a tensor, (B,C,H,W), the amplitude at the target image plane.
    :param prop_dist: propagation distance, in m.
    :param wavelength: wavelength, in m.
    :param feature_size: the SLM pixel pitch, in meters.
    :param prop_model: the light propagation model used to go from the target
        plane to the SLM plane.
    :param propagator: propagation function or model forwarded to
        ``utils.propagate_field`` (e.g. propagation_ASM).
    :param dtype: torch datatype for computation at different precision.
    :param precomputed_H: pre-computed propagation kernel; computing it once
        and reusing it saves time over multiple iterations/images.

    Output
    ------
    :return: a tensor, the phase pattern at the SLM plane as returned by
        ``double_phase``.
    """
    # polar (amplitude, phase) -> complex field at the target plane
    field_re, field_im = utils.polar_to_rect(target_amp, target_phase)
    field_at_target = torch.complex(field_re, field_im)

    # the one and only propagation: target plane -> SLM plane
    field_at_slm = utils.propagate_field(field_at_target, propagator,
                                         prop_dist, wavelength, feature_size,
                                         prop_model, dtype, precomputed_H)

    # encode the complex SLM field as a phase-only pattern
    return double_phase(field_at_slm, three_pi=False, mean_adjust=True)
Exemplo n.º 2
0
def combine_zernike_basis(coeffs, basis, return_phase=False):
    """
    Weight each Zernike basis function by its coefficient, sum the weighted
    functions into one phase map and convert that map to a complex tensor.

    :param coeffs: torch tensor with coeffs, see propagation_ASM_zernike. If
        it has fewer than three dimensions it is reshaped to
        (coeffs.shape[0], 1, 1) so that it broadcasts against ``basis``.
    :param basis: the output of compute_zernike_basis, must be same length as coeffs
    :param return_phase: accepted for interface compatibility; this
        implementation never reads it.
    :return: complex tensor built with ``torch.complex`` from
        ``utils.polar_to_rect(ones, phase_map)``.
    """
    if coeffs.ndim < 3:
        coeffs = coeffs.reshape((coeffs.shape[0], 1, 1))

    # weighted sum over the coefficient axis (kept as a singleton axis),
    # followed by one more leading singleton axis
    phase_map = torch.sum(coeffs * basis, 0, keepdim=True).unsqueeze(0)

    # unit amplitude + summed phase -> complex tensor
    unit_amp = torch.ones_like(phase_map)
    re_part, im_part = utils.polar_to_rect(unit_amp, phase_map)
    return torch.complex(re_part, im_part)
def propagation_ASM(u_in,
                    feature_size,
                    wavelength,
                    z,
                    linear_conv=True,
                    padtype='zero',
                    return_H=False,
                    precomped_H=None,
                    return_H_exp=False,
                    precomped_H_exp=None,
                    dtype=torch.float32):
    """Propagates the input field using the angular spectrum method

    Inputs
    ------
    u_in: complex field of size (num_images, 1, height, width, 2)
        where the last two channels are real and imaginary values
    feature_size: (height, width) of individual holographic features in m
    wavelength: wavelength in m
    z: propagation distance
    linear_conv: if True, pad the input to obtain a linear convolution
    padtype: 'zero' to pad with zeros, 'median' to pad with median of u_in's
        amplitude
    return_H[_exp]: used for precomputing H or H_exp, ends the computation early
        and returns the desired variable. When H is computed in this call,
        the returned H_exp has already been multiplied by z.
    precomped_H[_exp]: the precomputed value for H or H_exp. If precomped_H
        is given it is used as-is and precomped_H_exp is only passed through
        for return_H_exp.
    dtype: torch dtype for computation at different precision

    Output
    ------
    tensor of size (num_images, 1, height, width, 2)

    Raises
    ------
    ValueError: if linear_conv is True and padtype is neither 'zero' nor
        'median', or if return_H_exp is requested while only precomped_H
        (and no precomped_H_exp) was supplied.
    """

    if linear_conv:
        # choose the padding value first so an unsupported padtype fails
        # with a clear message instead of an unbound local further down
        if padtype == 'zero':
            padval = 0
        elif padtype == 'median':
            padval = torch.median(torch.pow((u_in**2).sum(-1), 0.5))
        else:
            raise ValueError(
                f"padtype must be 'zero' or 'median', got {padtype!r}")

        # preprocess with padding for linear conv.
        input_resolution = u_in.size()[-3:-1]
        conv_size = [i * 2 for i in input_resolution]
        u_in = utils.pad_image(u_in, conv_size, padval=padval)

    H_exp = precomped_H_exp

    if precomped_H is None:
        # The sampling grid is needed both for H_exp and for the band-limit
        # filter, so it is built whenever H itself has to be computed -- also
        # when a precomputed H_exp is supplied.

        # resolution of input field, should be: (num_images, num_channels, height, width, 2)
        field_resolution = u_in.size()

        # number of pixels
        num_y, num_x = field_resolution[2], field_resolution[3]

        # sampling interval size
        dy, dx = feature_size

        # size of the field
        y, x = (dy * float(num_y), dx * float(num_x))

        # frequency coordinates sampling
        fy = np.linspace(-1 / (2 * dy) + 0.5 / (2 * y),
                         1 / (2 * dy) - 0.5 / (2 * y), num_y)
        fx = np.linspace(-1 / (2 * dx) + 0.5 / (2 * x),
                         1 / (2 * dx) - 0.5 / (2 * x), num_x)

        # momentum/reciprocal space
        FX, FY = np.meshgrid(fx, fy)

        if H_exp is None:
            # transfer function in numpy (omit distance)
            HH = 2 * math.pi * np.sqrt(1 / wavelength**2 - (FX**2 + FY**2))

            # create tensor & upload to device (GPU)
            H_exp = torch.tensor(HH, dtype=dtype).to(u_in.device)

            ###
            # here one may iterate over multiple distances, once H_exp is uploaded on GPU

            # reshape tensor and multiply
            H_exp = torch.reshape(H_exp, (1, 1, *H_exp.size()))

        # multiply by distance
        H_exp = torch.mul(H_exp, z)

        # band-limited ASM - Matsushima et al. (2009)
        fy_max = 1 / np.sqrt((2 * z * (1 / y))**2 + 1) / wavelength
        fx_max = 1 / np.sqrt((2 * z * (1 / x))**2 + 1) / wavelength
        H_filter = torch.tensor(
            ((np.abs(FX) < fx_max) & (np.abs(FY) < fy_max)).astype(np.uint8),
            dtype=dtype)

        # get real/img components
        H_real, H_imag = utils.polar_to_rect(H_filter.to(u_in.device), H_exp)

        H = torch.stack((H_real, H_imag), 4)
        H = utils.ifftshift(H)
    else:
        H = precomped_H

    # return for use later as precomputed inputs
    if return_H_exp:
        if H_exp is None:
            raise ValueError(
                'return_H_exp=True needs H_exp, but only precomped_H was '
                'supplied; pass precomped_H_exp as well or omit precomped_H')
        return H_exp
    if return_H:
        return H

    # angular spectrum
    # NOTE(review): torch.fft(...) / torch.ifft(...) called as functions is
    # the legacy (pre-1.8) PyTorch API for stacked real/imag tensors.
    U1 = torch.fft(utils.ifftshift(u_in), 2, True)

    # convolution of the system
    U2 = utils.mul_complex(H, U1)

    # Fourier transform of the convolution to the observation plane
    u_out = utils.fftshift(torch.ifft(U2, 2, True))

    if linear_conv:
        return utils.crop_image(u_out, input_resolution)
    else:
        return u_out
    def forward(self, phases, skip_lut=False, skip_tm=False):
        """
        Simulate the complex field that the given SLM phases produce at the
        target plane.

        Steps, in order: optional phase processing (``self.process_phase``),
        optional source amplitude, lazy construction of the Zernike bases and
        of the propagation kernel, propagation with the Zernike terms, addition
        of the optional content-independent and content-dependent fields, and
        an optional blur of the output amplitude.

        :param phases: SLM phase tensor; its first dimension is used as the
            batch size and its last two dimensions as the spatial resolution.
        :param skip_lut: if True, bypass ``self.process_phase``.
        :param skip_tm: not read anywhere inside this method.
        :return: complex tensor assembled from the (optionally blurred)
            amplitude and the phase of the simulated field.
        """
        # Section 5.1.3. Modeling Phase Nonlinearity
        if self.process_phase is None or skip_lut:
            processed_phase = phases
        elif self.latent_code is None:
            processed_phase = self.process_phase(phases)
        else:
            # support mini-batch: tile the latent code along the batch axis
            tiled_code = self.latent_code.repeat(phases.shape[0], 1, 1, 1)
            processed_phase = self.process_phase(phases, tiled_code)

        # Section 5.1.1. Create Source Amplitude (DC + gaussians)
        if self.source_amp is None:
            source_amp = torch.ones_like(processed_phase)
        else:
            source_amp = self.source_amp(processed_phase)

        # amplitude + phase -> complex field at the SLM
        slm_re, slm_im = utils.polar_to_rect(source_amp, processed_phase)
        processed_complex = torch.complex(slm_re, slm_im)

        # Section 5.1.2. the Zernike bases are built on first use and cached
        if self.zernike is None and self.coeffs is not None:
            primal_basis = compute_zernike_basis(self.coeffs.size()[0],
                                                 phases.size()[-2:],
                                                 wo_piston=True)
            self.zernike = primal_basis.to(self.dev).detach()
            self.zernike.requires_grad = False

        if self.zernike_fourier is None and self.coeffs_fourier is not None:
            # the Fourier-domain basis uses twice the spatial resolution
            doubled_res = [2 * n for n in phases.size()[-2:]]
            fourier_basis = compute_zernike_basis(self.coeffs_fourier.size()[0],
                                                  doubled_res,
                                                  wo_piston=True)
            self.zernike_fourier = fourier_basis.to(self.dev).detach()
            self.zernike_fourier.requires_grad = False

        in_training = self.training

        # outside training, the combined Zernike terms are cached as well
        if not in_training and self.zernike_eval is None and self.coeffs is not None:
            combined = combine_zernike_basis(self.coeffs, self.zernike)
            self.zernike_eval = combined.to(self.coeffs.device).detach()
            self.zernike_eval.requires_grad = False

        if (not in_training and self.zernike_eval_fourier is None
                and self.coeffs_fourier is not None):
            combined = combine_zernike_basis(self.coeffs_fourier,
                                             self.zernike_fourier)
            combined = utils.ifftshift(combined)
            self.zernike_eval_fourier = combined.to(self.coeffs_fourier.device).detach()
            self.zernike_eval_fourier.requires_grad = False

        # precompute the kernel only once
        if self.learn_dist and in_training:
            self.precompute_H_exp(processed_complex)
        else:
            self.precompute_H(processed_complex)

        # Section 5.1.2. apply zernike and propagate
        if in_training:
            # keyword arguments common to both training-time propagators
            shared_kwargs = dict(precomped_H=self.precomped_H,
                                 precomped_H_exp=self.precomped_H_exp,
                                 linear_conv=self.linear_conv)
            if self.coeffs_fourier is None:
                output_complex = self.prop_zernike(
                    processed_complex, self.feature_size, self.wavelength,
                    self.distance,
                    coeffs=self.coeffs, zernike=self.zernike,
                    **shared_kwargs)
            else:
                output_complex = self.prop_zernike_fourier(
                    processed_complex, self.feature_size, self.wavelength,
                    self.distance,
                    coeffs=self.coeffs_fourier, zernike=self.zernike_fourier,
                    **shared_kwargs)
        else:
            # primal-domain Zernike term multiplies the field ...
            processed_zernike = (self.zernike_eval * processed_complex
                                 if self.coeffs is not None
                                 else processed_complex)
            # ... the Fourier-domain term multiplies the kernel
            kernel = (self.zernike_eval_fourier * self.precomped_H
                      if self.coeffs_fourier is not None
                      else self.precomped_H)

            output_complex = self.prop(processed_zernike,
                                       self.feature_size,
                                       self.wavelength,
                                       self.distance,
                                       precomped_H=kernel,
                                       linear_conv=self.linear_conv)

        # Section 5.1.1. Content-independent field at target plane
        if self.target_constant_amp is not None:
            const_re, const_im = utils.polar_to_rect(self.target_constant_amp,
                                                     self.target_constant_phase)
            output_complex = output_complex + torch.complex(const_re, const_im)

        # Section 5.1.4. Content-dependent Undiffracted light
        if self.content_dependent_field is not None:
            if self.latent_coords is None:
                cdf = self.content_dependent_field(phases)
            else:
                tiled_coords = self.latent_coords.repeat(phases.shape[0], 1, 1, 1)
                cdf = self.content_dependent_field(phases, tiled_coords)
            # last axis of cdf holds (amplitude, phase)
            cdf_re, cdf_im = utils.polar_to_rect(cdf[..., 0], cdf[..., 1])
            output_complex = output_complex + torch.complex(cdf_re, cdf_im)

        # split into amplitude/phase so that only the amplitude gets blurred
        amp = output_complex.abs()
        _, phase = utils.rect_to_polar(output_complex.real, output_complex.imag)
        if self.blur is not None:
            amp = self.blur(amp)

        out_re, out_im = utils.polar_to_rect(amp, phase)
        return torch.complex(out_re, out_im)
Exemplo n.º 5
0
    def forward(self, target_amp):
        """
        Propagate a target amplitude to the SLM plane and turn the resulting
        field into a phase pattern.

        The propagation kernel, the optional Zernike-corrected kernel and the
        optional source amplitude are computed on the first call and cached
        on ``self``.

        :param target_amp: target amplitude tensor; its last two dimensions
            are used as the spatial resolution.
        :return: tuple ``(amp, phase)`` where ``amp`` is the amplitude at the
            SLM plane and ``phase`` comes from ``double_phase`` when
            ``self.final_phase_only`` is None, otherwise from
            ``self.final_phase_only``.
        """
        # compute some initial phase, convert to real+imag representation
        if self.initial_phase is None:
            init_phase = torch.zeros_like(target_amp)
            # zero phase: the amplitude is the real part, imaginary part is 0
            target_complex = torch.complex(target_amp, init_phase)
        else:
            init_phase = self.initial_phase(target_amp)
            tgt_re, tgt_im = utils.polar_to_rect(target_amp, init_phase)
            target_complex = torch.complex(tgt_re, tgt_im)

        # subtract the additional target field, if there is one
        if self.target_field is None:
            target_complex_diff = target_complex
        else:
            target_complex_diff = target_complex - self.target_field

        # precompute the propagation kernel only once
        if self.precomped_H is None:
            kernel = self.prop(target_complex_diff,
                               self.feature_size,
                               self.wavelength,
                               self.distance,
                               return_H=True,
                               linear_conv=self.linear_conv)
            self.precomped_H = kernel.to(self.dev).detach()
            self.precomped_H.requires_grad = False

        # kernel with the Zernike term folded in, also computed only once
        if self.precomped_H_zernike is None:
            if self.zernike is None and self.zernike_coeffs is not None:
                doubled_res = [2 * n for n in target_amp.size()[-2:]]
                basis = compute_zernike_basis(self.zernike_coeffs.size()[0],
                                              doubled_res, wo_piston=True)
                self.zernike_basis = basis.to(self.dev).detach()

                combined = combine_zernike_basis(self.zernike_coeffs,
                                                 self.zernike_basis)
                combined = utils.ifftshift(combined)
                self.zernike = combined.to(self.dev).detach()
                self.zernike.requires_grad = False

                corrected = self.zernike * self.precomped_H
                self.precomped_H_zernike = corrected.to(self.dev).detach()
                self.precomped_H_zernike.requires_grad = False
            else:
                self.precomped_H_zernike = self.precomped_H

        # precompute the source amplitude, only once
        if self.source_amp is None and self.source_amplitude is not None:
            src = self.source_amplitude(target_amp)
            self.source_amp = src.to(self.dev).detach()
            self.source_amp.requires_grad = False

        # implement the basic propagation to the SLM plane
        slm_naive = self.prop(target_complex_diff, self.feature_size,
                              self.wavelength, self.distance,
                              precomped_H=self.precomped_H_zernike,
                              linear_conv=self.linear_conv)

        # switch to amplitude+phase and apply source amplitude adjustment
        # (utils.rect_to_polar is used instead of .abs()/.angle() because
        # PyTorch 1.7.0 complex tensors do not support the gradient of angle())
        amp, ang = utils.rect_to_polar(slm_naive.real, slm_naive.imag)

        if self.source_amp is not None and self.manual_aberr_corr:
            amp = amp / self.source_amp

        if self.final_phase_only is None:
            return amp, double_phase(amp, ang, three_pi=False)

        # note the change to usual complex number stacking!
        # The pieces are concatenated along the channel dim (cat, not stack).
        only_amp_and_phase = (self.zernike is None and self.source_amp is None
                              or self.manual_aberr_corr)
        if only_amp_and_phase:
            channels = (amp, ang)
            if self.latent_codes is not None:
                channels += (self.latent_codes.repeat(amp.shape[0], 1, 1, 1),)
        elif self.zernike is None:
            channels = (amp, ang, self.source_amp)
        elif self.source_amp is None:
            channels = (amp, ang, self.zernike)
        else:
            channels = (amp, ang, self.zernike, self.source_amp)

        slm_amp_phase = torch.cat(channels, -3)
        return amp, self.final_phase_only(slm_amp_phase)
Exemplo n.º 6
0
                                feature_size=feature_size,
                                wavelength=wavelength,
                                blur=blur).to(device)
    # Restore the propagation model's weights from the checkpoint at opt.model_path.
    # NOTE(review): torch.load unpickles the file -- only point this at trusted checkpoints.
    model_prop.load_state_dict(torch.load(opt.model_path, map_location=device))

    # Here, we crop model parameters to match the Holonet resolution, which is slightly different from 1080p.
    # parameters for CITL model
    # (zernike_coeffs and source_amplitude are not used in the visible part of this
    # script; presumably they are consumed further down -- confirm)
    zernike_coeffs = model_prop.coeffs_fourier
    source_amplitude = model_prop.source_amp
    latent_codes = model_prop.latent_code
    latent_codes = utils.crop_image(latent_codes, target_shape=image_res, pytorch=True, stacked_complex=False)

    # content independent target field (See Section 5.1.1.)
    # amplitude and phase are cropped separately, then combined into the complex field u_t
    u_t_amp = utils.crop_image(model_prop.target_constant_amp, target_shape=image_res, stacked_complex=False)
    u_t_phase = utils.crop_image(model_prop.target_constant_phase, target_shape=image_res, stacked_complex=False)
    real, imag = utils.polar_to_rect(u_t_amp, u_t_phase)
    u_t = torch.complex(real, imag)

    # match the shape of forward model parameters (1072, 1920)

    # Write the cropped tensors back into the model as frozen (requires_grad=False) parameters.
    # If you make it nn.Parameter, the Holonet will include these parameters in the torch.save files
    model_prop.latent_code = nn.Parameter(latent_codes.detach(), requires_grad=False)
    model_prop.target_constant_amp = nn.Parameter(u_t_amp.detach(), requires_grad=False)
    model_prop.target_constant_phase = nn.Parameter(u_t_phase.detach(), requires_grad=False)

    # But as these parameters are already in the CITL-calibrated models,
    # you can load these from those models avoiding duplicated saves.

model_prop.eval()  # ensure freezing propagation model

# create new phase generator object
Exemplo n.º 7
0
    # Progress message for the image currently being processed.
    print(f'    - running for img_{target_idx}...')

    # crop to ROI
    target_amp = utils.crop_image(target_amp, target_shape=roi_res, stacked_complex=False).to(device)

    # one reconstructed amplitude per channel, collected in the order of chs
    recon_amp = []

    # for each channel, propagate wave from the SLM plane to the image plane and get the reconstructed image.
    for c in chs:
        # load and invert phase (our SLM setup)
        phase_filename = os.path.join(opt.root_path, chan_strs[c], f'{target_idx}.png')
        # scale pixel values by 1/255 (assumes an 8-bit image -- TODO confirm)
        slm_phase = skimage.io.imread(phase_filename) / 255.
        # (1 - p) * 2*pi - pi maps p = 0 to +pi and p = 1 to -pi, i.e. the stored phase is inverted
        slm_phase = torch.tensor((1 - slm_phase) * 2 * np.pi - np.pi, dtype=dtype).reshape(1, 1, *slm_res).to(device)

        # propagate field
        # unit amplitude + loaded phase -> complex field at the SLM
        real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase)
        slm_field = torch.complex(real, imag)

        # NOTE(review): `propagator` is only assigned here for the 'MODEL' case; for any
        # other prop_model it must already be defined outside this excerpt -- confirm.
        if opt.prop_model.upper() == 'MODEL':
            propagator = propagators[c]  # Select CITL-calibrated models for each channel
        recon_field = utils.propagate_field(slm_field, propagator, prop_dists[c], wavelengths[c], feature_size,
                                            opt.prop_model, dtype)

        # cartesian to polar coordinate
        # (only the amplitude is kept; the phase of the reconstruction is discarded)
        recon_amp_c = recon_field.abs()

        # crop to ROI
        recon_amp_c = utils.crop_image(recon_amp_c, target_shape=roi_res, stacked_complex=False)

        # append to list
        recon_amp.append(recon_amp_c)
Exemplo n.º 8
0
def stochastic_gradient_descent(init_phase,
                                target_amp,
                                num_iters,
                                prop_dist,
                                wavelength,
                                feature_size,
                                roi_res=None,
                                phase_path=None,
                                prop_model='ASM',
                                propagator=None,
                                loss=nn.MSELoss(),
                                lr=0.01,
                                lr_s=0.003,
                                s0=1.0,
                                citl=False,
                                camera_prop=None,
                                writer=None,
                                dtype=torch.float32,
                                precomputed_H=None):
    """
    Given the initial guess, run the SGD algorithm to calculate the optimal phase pattern of spatial light modulator.

    Input
    ------
    :param init_phase: a tensor, in the shape of (1,1,H,W), initial guess for the phase.
        It is optimized in place: ``requires_grad_(True)`` is called on it and
        the very same tensor object is returned.
    :param target_amp: a tensor, in the shape of (1,1,H,W), the amplitude of the target image.
    :param num_iters: the number of iterations to run the SGD.
    :param prop_dist: propagation distance in m.
    :param wavelength: wavelength in m.
    :param feature_size: the SLM pixel pitch, in meters.
    :param roi_res: a tuple of integer, region of interest, like (880, 1600)
    :param phase_path: a string, for saving intermediate phases
    :param prop_model: a string, that indicates the propagation model. ('ASM' or 'MODEL')
    :param propagator: predefined function or model instance for the propagation.
    :param loss: loss function, default L2
    :param lr: learning rate for optimization variables
    :param lr_s: learning rate for learnable scale; the scale is only
        optimized when this is greater than 0.
    :param s0: initial scale
    :param citl: if True, use the camera-in-the-loop technique: the amplitude
        returned by ``camera_prop`` replaces the simulated one in the loss,
        while gradients still flow through the simulated amplitude.
    :param camera_prop: callable taking the SLM phase and returning the
        captured amplitude; required when ``citl`` is True.
    :param writer: Tensorboard writer instance
    :param dtype: default torch.float32
    :param precomputed_H: A Pytorch complex64 tensor, pre-computed kernel shape of (1,1,2H,2W) for fast computation.

    Output
    ------
    :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W)

    :raises ValueError: if ``citl`` is True but ``camera_prop`` is None.
    """

    # Fail fast, before init_phase is touched: otherwise the loop would die
    # on its first iteration with "'NoneType' object is not callable".
    if citl and camera_prop is None:
        raise ValueError('citl=True requires camera_prop to be provided')

    device = init_phase.device
    s = torch.tensor(s0, requires_grad=True, device=device)

    # phase at the slm plane
    slm_phase = init_phase.requires_grad_(True)

    # optimization variables and adam optimizer
    optvars = [{'params': slm_phase}]
    if lr_s > 0:
        optvars += [{'params': s, 'lr': lr_s}]
    optimizer = optim.Adam(optvars, lr=lr)

    # crop target roi
    target_amp = utils.crop_image(target_amp, roi_res, stacked_complex=False)

    # run the iterative algorithm
    for k in range(num_iters):
        optimizer.zero_grad()
        # forward propagation from the SLM plane to the target plane
        real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase)
        slm_field = torch.complex(real, imag)

        recon_field = utils.propagate_field(slm_field, propagator, prop_dist,
                                            wavelength, feature_size,
                                            prop_model, dtype, precomputed_H)

        # get amplitude
        recon_amp = recon_field.abs()

        # crop roi
        recon_amp = utils.crop_image(recon_amp,
                                     target_shape=roi_res,
                                     stacked_complex=False)

        # camera-in-the-loop technique
        if citl:
            captured_amp = camera_prop(slm_phase)

            # use the gradient of proxy, replacing the amplitudes
            # captured_amp is assumed that its size already matches that of recon_amp
            # (value equals captured_amp; the detached difference carries no gradient)
            out_amp = recon_amp + (captured_amp - recon_amp).detach()
        else:
            out_amp = recon_amp

        # calculate loss and backprop
        lossValue = loss(s * out_amp, target_amp)
        lossValue.backward()
        optimizer.step()

        # write to tensorboard / write phase image
        # Note that it takes 0.~ s for writing it to tensorboard
        with torch.no_grad():
            if k % 50 == 0:
                print(k)
                utils.write_sgd_summary(slm_phase,
                                        out_amp,
                                        target_amp,
                                        k,
                                        writer=writer,
                                        path=phase_path,
                                        s=s,
                                        prefix='test')

    return slm_phase
Exemplo n.º 9
0
def gerchberg_saxton(init_phase,
                     target_amp,
                     num_iters,
                     prop_dist,
                     wavelength,
                     feature_size=6.4e-6,
                     phase_path=None,
                     prop_model='ASM',
                     propagator=None,
                     writer=None,
                     dtype=torch.float32,
                     precomputed_H_f=None,
                     precomputed_H_b=None):
    """
    Given the initial guess, run the Gerchberg-Saxton (GS) algorithm to
    calculate the phase pattern of the spatial light modulator.

    Every iteration propagates the SLM field to the image plane, replaces its
    amplitude with the target amplitude, propagates back over ``-prop_dist``
    and replaces the amplitude at the SLM plane with ones.

    :param init_phase: a tensor, in the shape of (1,1,H,W), initial guess for the phase.
    :param target_amp: a tensor, in the shape of (1,1,H,W), the amplitude of the target image.
    :param num_iters: the number of iterations to run the GS.
    :param prop_dist: propagation distance in m.
    :param wavelength: wavelength in m.
    :param feature_size: the SLM pixel pitch, in meters, default 6.4e-6
    :param phase_path: path to save the results (not read inside this function).
    :param prop_model: string indicating the light transport model, default 'ASM'. ex) 'ASM', 'fresnel', 'model'
    :param propagator: predefined function or model instance for the propagation.
    :param writer: tensorboard writer
    :param dtype: torch datatype for computation at different precision, default torch.float32.
    :param precomputed_H_f: A Pytorch complex64 tensor, pre-computed kernel for forward prop (SLM to image)
    :param precomputed_H_b: A Pytorch complex64 tensor, pre-computed kernel for backward propagation (image to SLM)

    Output
    ------
    :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W)
    """

    # start from a unit-amplitude field that carries the initial phase guess
    unit_amp = torch.ones_like(init_phase)
    field_re, field_im = utils.polar_to_rect(unit_amp, init_phase)
    slm_field = torch.complex(field_re, field_im)

    for it in range(num_iters):
        # forward: SLM plane -> image plane
        image_field = utils.propagate_field(slm_field, propagator, prop_dist,
                                            wavelength, feature_size,
                                            prop_model, dtype, precomputed_H_f)

        # every 10th iteration (but not the very first one) log progress;
        # writing to tensorboard takes a fraction of a second
        if it > 0 and it % 10 == 0:
            print(it)
            utils.write_gs_summary(slm_field,
                                   image_field,
                                   target_amp,
                                   it,
                                   writer,
                                   prefix='test')

        # image-plane constraint: keep the phase, impose the target amplitude
        image_field = utils.replace_amplitude(image_field, target_amp)

        # backward: image plane -> SLM plane, then the SLM-plane constraint
        # (unit amplitude everywhere)
        back_field = utils.propagate_field(image_field, propagator, -prop_dist,
                                           wavelength, feature_size,
                                           prop_model, dtype, precomputed_H_b)
        slm_field = utils.replace_amplitude(back_field,
                                            torch.ones_like(target_amp))

    # only the phase of the final SLM field is returned
    return slm_field.angle()