예제 #1
0
def double_phase_amplitude_coding(target_phase,
                                  target_amp,
                                  prop_dist,
                                  wavelength,
                                  feature_size,
                                  prop_model='ASM',
                                  propagator=None,
                                  dtype=torch.float32,
                                  precomputed_H=None):
    """
    Use a single propagation and converts amplitude and phase to double phase coding

    Input
    -----
    :param target_phase: The phase at the target image plane
    :param target_amp: A tensor, (B,C,H,W), the amplitude at the target image plane.
    :param prop_dist: propagation distance, in m.
    :param wavelength: wavelength, in m.
    :param feature_size: The SLM pixel pitch, in meters.
    :param prop_model: The light propagation model to use for prop from target plane to slm plane
    :param propagator: propagation_ASM
    :param dtype: torch datatype for computation at different precision.
    :param precomputed_H: pre-computed kernel - to make it faster over multiple iteration/images - calculate it once

    Output
    ------
    :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W)
    """
    real, imag = utils.polar_to_rect(target_amp, target_phase)
    target_field = torch.complex(real, imag)

    slm_field = utils.propagate_field(target_field, propagator, prop_dist,
                                      wavelength, feature_size, prop_model,
                                      dtype, precomputed_H)

    slm_phase = double_phase(slm_field, three_pi=False, mean_adjust=True)

    return slm_phase
예제 #2
0
    recon_amp = []

    # for each channel, propagate wave from the SLM plane to the image plane and get the reconstructed image.
    for c in chs:
        # load and invert phase (our SLM setup)
        phase_filename = os.path.join(opt.root_path, chan_strs[c], f'{target_idx}.png')
        slm_phase = skimage.io.imread(phase_filename) / 255.
        slm_phase = torch.tensor((1 - slm_phase) * 2 * np.pi - np.pi, dtype=dtype).reshape(1, 1, *slm_res).to(device)

        # propagate field
        real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase)
        slm_field = torch.complex(real, imag)

        if opt.prop_model.upper() == 'MODEL':
            propagator = propagators[c]  # Select CITL-calibrated models for each channel
        recon_field = utils.propagate_field(slm_field, propagator, prop_dists[c], wavelengths[c], feature_size,
                                            opt.prop_model, dtype)

        # cartesian to polar coordinate
        recon_amp_c = recon_field.abs()

        # crop to ROI
        recon_amp_c = utils.crop_image(recon_amp_c, target_shape=roi_res, stacked_complex=False)

        # append to list
        recon_amp.append(recon_amp_c)

    # list to tensor, scaling
    recon_amp = torch.cat(recon_amp, dim=1)
    recon_amp *= (torch.sum(recon_amp * target_amp, (-2, -1), keepdim=True)
                  / torch.sum(recon_amp * recon_amp, (-2, -1), keepdim=True))
예제 #3
0
def stochastic_gradient_descent(init_phase,
                                target_amp,
                                num_iters,
                                prop_dist,
                                wavelength,
                                feature_size,
                                roi_res=None,
                                phase_path=None,
                                prop_model='ASM',
                                propagator=None,
                                loss=nn.MSELoss(),
                                lr=0.01,
                                lr_s=0.003,
                                s0=1.0,
                                citl=False,
                                camera_prop=None,
                                writer=None,
                                dtype=torch.float32,
                                precomputed_H=None):
    """
    Given the initial guess, run the SGD algorithm to calculate the optimal phase pattern of spatial light modulator.

    Input
    ------
    :param init_phase: a tensor, in the shape of (1,1,H,W), initial guess for the phase.
    :param target_amp: a tensor, in the shape of (1,1,H,W), the amplitude of the target image.
    :param num_iters: the number of iterations to run the SGD.
    :param prop_dist: propagation distance in m.
    :param wavelength: wavelength in m.
    :param feature_size: the SLM pixel pitch, in meters, default 6.4e-6
    :param roi_res: a tuple of integer, region of interest, like (880, 1600)
    :param phase_path: a string, for saving intermediate phases
    :param prop_model: a string, that indicates the propagation model. ('ASM' or 'MODEL')
    :param propagator: predefined function or model instance for the propagation.
    :param loss: loss function, default L2
    :param lr: learning rate for optimization variables
    :param lr_s: learning rate for learnable scale
    :param s0: initial scale
    :param writer: Tensorboard writer instance
    :param dtype: default torch.float32
    :param precomputed_H: A Pytorch complex64 tensor, pre-computed kernel shape of (1,1,2H,2W) for fast computation.

    Output
    ------
    :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W)
    """

    device = init_phase.device
    s = torch.tensor(s0, requires_grad=True, device=device)

    # phase at the slm plane
    slm_phase = init_phase.requires_grad_(True)

    # optimization variables and adam optimizer
    optvars = [{'params': slm_phase}]
    if lr_s > 0:
        optvars += [{'params': s, 'lr': lr_s}]
    optimizer = optim.Adam(optvars, lr=lr)

    # crop target roi
    target_amp = utils.crop_image(target_amp, roi_res, stacked_complex=False)

    # run the iterative algorithm
    for k in range(num_iters):
        optimizer.zero_grad()
        # forward propagation from the SLM plane to the target plane
        real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase)
        slm_field = torch.complex(real, imag)

        recon_field = utils.propagate_field(slm_field, propagator, prop_dist,
                                            wavelength, feature_size,
                                            prop_model, dtype, precomputed_H)

        # get amplitude
        recon_amp = recon_field.abs()

        # crop roi
        recon_amp = utils.crop_image(recon_amp,
                                     target_shape=roi_res,
                                     stacked_complex=False)

        # camera-in-the-loop technique
        if citl:
            captured_amp = camera_prop(slm_phase)

            # use the gradient of proxy, replacing the amplitudes
            # captured_amp is assumed that its size already matches that of recon_amp
            out_amp = recon_amp + (captured_amp - recon_amp).detach()
        else:
            out_amp = recon_amp

        # calculate loss and backprop
        lossValue = loss(s * out_amp, target_amp)
        lossValue.backward()
        optimizer.step()

        # write to tensorboard / write phase image
        # Note that it takes 0.~ s for writing it to tensorboard
        with torch.no_grad():
            if k % 50 == 0:
                print(k)
                utils.write_sgd_summary(slm_phase,
                                        out_amp,
                                        target_amp,
                                        k,
                                        writer=writer,
                                        path=phase_path,
                                        s=s,
                                        prefix='test')

    return slm_phase
예제 #4
0
def gerchberg_saxton(init_phase,
                     target_amp,
                     num_iters,
                     prop_dist,
                     wavelength,
                     feature_size=6.4e-6,
                     phase_path=None,
                     prop_model='ASM',
                     propagator=None,
                     writer=None,
                     dtype=torch.float32,
                     precomputed_H_f=None,
                     precomputed_H_b=None):
    """
    Given the initial guess, run the SGD algorithm to calculate the optimal phase pattern of spatial light modulator

    :param init_phase: a tensor, in the shape of (1,1,H,W), initial guess for the phase.
    :param target_amp: a tensor, in the shape of (1,1,H,W), the amplitude of the target image.
    :param num_iters: the number of iterations to run the GS.
    :param prop_dist: propagation distance in m.
    :param wavelength: wavelength in m.
    :param feature_size: the SLM pixel pitch, in meters, default 6.4e-6
    :param phase_path: path to save the results.
    :param prop_model: string indicating the light transport model, default 'ASM'. ex) 'ASM', 'fresnel', 'model'
    :param propagator: predefined function or model instance for the propagation.
    :param writer: tensorboard writer
    :param dtype: torch datatype for computation at different precision, default torch.float32.
    :param precomputed_H_f: A Pytorch complex64 tensor, pre-computed kernel for forward prop (SLM to image)
    :param precomputed_H_b: A Pytorch complex64 tensor, pre-computed kernel for backward propagation (image to SLM)

    Output
    ------
    :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W)
    """

    # initial guess; random phase
    real, imag = utils.polar_to_rect(torch.ones_like(init_phase), init_phase)
    slm_field = torch.complex(real, imag)

    # run the GS algorithm
    for k in range(num_iters):
        # SLM plane to image plane
        recon_field = utils.propagate_field(slm_field, propagator, prop_dist,
                                            wavelength, feature_size,
                                            prop_model, dtype, precomputed_H_f)

        # write to tensorboard / write phase image
        # Note that it takes 0.~ s for writing it to tensorboard
        if k > 0 and k % 10 == 0:
            print(k)
            utils.write_gs_summary(slm_field,
                                   recon_field,
                                   target_amp,
                                   k,
                                   writer,
                                   prefix='test')

        # replace amplitude at the image plane
        recon_field = utils.replace_amplitude(recon_field, target_amp)

        # image plane to SLM plane
        slm_field = utils.propagate_field(recon_field, propagator, -prop_dist,
                                          wavelength, feature_size, prop_model,
                                          dtype, precomputed_H_b)

        # amplitude constraint at the SLM plane
        slm_field = utils.replace_amplitude(slm_field,
                                            torch.ones_like(target_amp))

    # return phases
    return slm_field.angle()