def double_phase_amplitude_coding(target_phase, target_amp, prop_dist, wavelength, feature_size,
                                  prop_model='ASM', propagator=None,
                                  dtype=torch.float32, precomputed_H=None):
    """
    Encode a target-plane field as an SLM phase pattern with double-phase coding.

    The complex target field is assembled from ``target_amp`` and ``target_phase``
    (via ``utils.polar_to_rect``). It is propagated once by ``prop_dist`` with
    ``utils.propagate_field``. The resulting SLM-plane field is then handed to
    ``double_phase`` with ``three_pi=False`` and ``mean_adjust=True``.

    Input
    -----
    :param target_phase: phase at the target image plane.
    :param target_amp: amplitude at the target image plane, e.g. a (B,C,H,W) tensor.
    :param prop_dist: propagation distance from the target plane to the SLM plane, in m.
    :param wavelength: wavelength, in m.
    :param feature_size: SLM pixel pitch, in m.
    :param prop_model: name of the light propagation model passed to ``utils.propagate_field``.
    :param propagator: propagation function or model instance (e.g. propagation_ASM).
    :param dtype: torch dtype used for the computation.
    :param precomputed_H: optional pre-computed propagation kernel. Compute it once and
        reuse it across iterations/images to save time.

    Output
    ------
    :return: the phase pattern produced by ``double_phase`` for the SLM plane
        (documented as shape (1,1,H,W)).
    """
    target_field = torch.complex(*utils.polar_to_rect(target_amp, target_phase))
    slm_field = utils.propagate_field(target_field, propagator, prop_dist, wavelength,
                                      feature_size, prop_model, dtype, precomputed_H)
    return double_phase(slm_field, three_pi=False, mean_adjust=True)
def combine_zernike_basis(coeffs, basis, return_phase=False):
    """
    Multiplies the Zernike coefficients and basis functions while preserving
    dimensions.

    :param coeffs: torch tensor with coeffs, see propagation_ASM_zernike. A tensor
        with fewer than 3 dims is reshaped to (coeffs.shape[0], 1, 1) so it
        broadcasts against ``basis``.
    :param basis: the output of compute_zernike_basis, must be same length as coeffs
        (assumed (N, H, W) -- TODO confirm against compute_zernike_basis).
    :param return_phase: if True, return the real-valued combined phase map itself
        instead of the unit-amplitude complex field built from it.
    :return: for an (N, H, W) basis, a (1, 1, H, W) tensor. By default it is a
        Complex64 tensor built by ``utils.polar_to_rect(ones, phase)``; with
        ``return_phase=True`` it is the real phase map.
    """
    if coeffs.dim() < 3:
        coeffs = coeffs.reshape(coeffs.shape[0], 1, 1)

    # weighted sum of the basis functions over dim 0 -> (1, H, W),
    # then add a leading dim -> (1, 1, H, W)
    phase = (coeffs * basis).sum(0, keepdim=True).unsqueeze(0)

    if return_phase:
        return phase

    # convert to Pytorch Complex tensor with unit amplitude
    real, imag = utils.polar_to_rect(torch.ones_like(phase), phase)
    return torch.complex(real, imag)
def _asm_frequency_grid(num_y, num_x, feature_size):
    """Build the spatial-frequency grid used by ``propagation_ASM``.

    :param num_y: number of samples along the height axis.
    :param num_x: number of samples along the width axis.
    :param feature_size: (dy, dx) sampling interval in m.
    :return: ``(FX, FY, y, x)``. FX/FY are numpy meshgrids of horizontal/vertical
        frequencies in 1/m, each of shape (num_y, num_x), symmetric about 0.
        y = dy * num_y and x = dx * num_x are the physical field extents in m.
    """
    dy, dx = feature_size
    y, x = (dy * float(num_y), dx * float(num_x))

    fy = np.linspace(-1 / (2 * dy) + 0.5 / (2 * y), 1 / (2 * dy) - 0.5 / (2 * y), num_y)
    fx = np.linspace(-1 / (2 * dx) + 0.5 / (2 * x), 1 / (2 * dx) - 0.5 / (2 * x), num_x)

    # momentum/reciprocal space
    FX, FY = np.meshgrid(fx, fy)
    return FX, FY, y, x


def propagation_ASM(u_in, feature_size, wavelength, z, linear_conv=True,
                    padtype='zero', return_H=False, precomped_H=None,
                    return_H_exp=False, precomped_H_exp=None,
                    dtype=torch.float32):
    """Propagates the input field using the (band-limited) angular spectrum method.

    Inputs
    ------
    u_in: complex field of size (num_images, 1, height, width, 2) where the last
        two channels are real and imaginary values
    feature_size: (height, width) of individual holographic features in m
    wavelength: wavelength in m
    z: propagation distance
    linear_conv: if True, pad the input to obtain a linear convolution
    padtype: 'zero' to pad with zeros, 'median' to pad with median of u_in's amplitude.
        Any other value raises ValueError.
    return_H: end the computation early and return the complete kernel H
        (including distance z, band-limit filter and ifftshift), for use as precomped_H.
    return_H_exp: end the computation early and return the distance-independent
        phase exponent H_exp (without the factor z), for use as precomped_H_exp.
        If both return flags are set, H_exp is returned.
    precomped_H: precomputed H; when given, no kernel is computed.
    precomped_H_exp: precomputed distance-independent H_exp. It is multiplied by z
        and band-limited here, so a kernel precomputed once can be reused for
        several (e.g. learned) distances.
    dtype: torch dtype for computation at different precision

    Output
    ------
    tensor of size (num_images, 1, height, width, 2)

    Raises
    ------
    ValueError: if linear_conv is True and padtype is not 'zero' or 'median'.
    """
    if linear_conv:
        # preprocess with padding for linear conv.
        input_resolution = u_in.size()[-3:-1]
        conv_size = [i * 2 for i in input_resolution]
        if padtype == 'zero':
            padval = 0
        elif padtype == 'median':
            padval = torch.median(torch.pow((u_in**2).sum(-1), 0.5))
        else:
            raise ValueError(f"unsupported padtype {padtype!r}; expected 'zero' or 'median'")
        u_in = utils.pad_image(u_in, conv_size, padval=padval)

    if return_H_exp or precomped_H is None:
        # resolution of input field, should be: (num_images, num_channels, height, width, 2)
        field_resolution = u_in.size()
        num_y, num_x = field_resolution[2], field_resolution[3]

        # the grid is needed both to build H_exp and for the band-limit filter,
        # so compute it even when H_exp itself was precomputed
        FX, FY, y, x = _asm_frequency_grid(num_y, num_x, feature_size)

        if precomped_H_exp is None:
            # transfer function in numpy (omit distance); NaN where
            # FX**2 + FY**2 > 1 / wavelength**2 (evanescent frequencies)
            HH = 2 * math.pi * np.sqrt(1 / wavelength**2 - (FX**2 + FY**2))

            # create tensor & upload to device (GPU), shape (1, 1, num_y, num_x)
            H_exp = torch.tensor(HH, dtype=dtype).to(u_in.device)
            H_exp = torch.reshape(H_exp, (1, 1, *H_exp.size()))
        else:
            H_exp = precomped_H_exp

        # return before multiplying by z so the result is distance-independent
        # and can be fed back through precomped_H_exp
        if return_H_exp:
            return H_exp

    if precomped_H is None:
        # multiply by distance
        H_exp = torch.mul(H_exp, z)

        # band-limited ASM - Matsushima et al. (2009)
        fy_max = 1 / np.sqrt((2 * z * (1 / y))**2 + 1) / wavelength
        fx_max = 1 / np.sqrt((2 * z * (1 / x))**2 + 1) / wavelength
        H_filter = torch.tensor(
            ((np.abs(FX) < fx_max) & (np.abs(FY) < fy_max)).astype(np.uint8), dtype=dtype)

        # get real/img components
        H_real, H_imag = utils.polar_to_rect(H_filter.to(u_in.device), H_exp)

        H = torch.stack((H_real, H_imag), 4)
        H = utils.ifftshift(H)
    else:
        H = precomped_H

    # return for use later as precomputed input
    if return_H:
        return H

    # angular spectrum
    U1 = torch.fft(utils.ifftshift(u_in), 2, True)

    # convolution of the system
    U2 = utils.mul_complex(H, U1)

    # Fourier transform of the convolution to the observation plane
    u_out = utils.fftshift(torch.ifft(U2, 2, True))

    if linear_conv:
        return utils.crop_image(u_out, input_resolution)
    return u_out
def forward(self, phases, skip_lut=False, skip_tm=False):
    """Model the field produced at the target plane by the given SLM phases.

    Stages, each applied only when the corresponding attribute/module is set:
    phase processing (self.process_phase, bypassed when skip_lut is True),
    source amplitude (self.source_amp, otherwise unit amplitude), and Zernike
    aberrations in the primal domain (self.coeffs) and/or the Fourier domain
    (self.coeffs_fourier). Propagation uses self.prop_zernike /
    self.prop_zernike_fourier in training mode and self.prop in eval mode.
    After propagation, an additive constant target field and an additive
    content-dependent field are applied, followed by an amplitude blur
    (self.blur).

    Side effects: lazily caches self.zernike, self.zernike_fourier,
    self.zernike_eval and self.zernike_eval_fourier on first use. Also calls
    self.precompute_H_exp (training with learn_dist) or self.precompute_H
    on every call.

    :param phases: SLM phase tensor; presumably (B, 1, H, W) -- TODO confirm.
    :param skip_lut: if True, use ``phases`` directly instead of self.process_phase.
    :param skip_tm: accepted but not used in this method.
    :return: complex tensor of the modeled target-plane field.
    """
    # Section 5.1.3. Modeling Phase Nonlinearity
    if self.process_phase is not None and not skip_lut:
        if self.latent_code is not None:
            # support mini-batch: tile the latent code along the batch dim
            processed_phase = self.process_phase(phases, self.latent_code.repeat(phases.shape[0], 1, 1, 1))
        else:
            processed_phase = self.process_phase(phases)
    else:
        processed_phase = phases

    # Section 5.1.1. Create Source Amplitude (DC + gaussians)
    if self.source_amp is not None:
        source_amp = self.source_amp(processed_phase)
    else:
        source_amp = torch.ones_like(processed_phase)

    # convert phase to real and imaginary
    real, imag = utils.polar_to_rect(source_amp, processed_phase)
    processed_complex = torch.complex(real, imag)

    # Section 5.1.2. precompute the zernike basis only once
    # (primal-domain basis at the phase resolution)
    if self.zernike is None and self.coeffs is not None:
        self.zernike = compute_zernike_basis(self.coeffs.size()[0], phases.size()[-2:], wo_piston=True)
        self.zernike = self.zernike.to(self.dev).detach()
        self.zernike.requires_grad = False

    # Fourier-domain basis at twice the phase resolution in each dimension
    # (presumably to match the padded propagation kernel -- TODO confirm)
    if self.zernike_fourier is None and self.coeffs_fourier is not None:
        self.zernike_fourier = compute_zernike_basis(self.coeffs_fourier.size()[0],
                                                     [i * 2 for i in phases.size()[-2:]],
                                                     wo_piston=True)
        self.zernike_fourier = self.zernike_fourier.to(self.dev).detach()
        self.zernike_fourier.requires_grad = False

    # eval-only caches: combined aberrations are computed once and reused.
    # NOTE(review): they are not refreshed if self.coeffs / self.coeffs_fourier
    # change after the first eval call.
    if not self.training and self.zernike_eval is None and self.coeffs is not None:
        # sum the phases
        self.zernike_eval = combine_zernike_basis(self.coeffs, self.zernike)
        self.zernike_eval = self.zernike_eval.to(self.coeffs.device).detach()
        self.zernike_eval.requires_grad = False

    if not self.training and self.zernike_eval_fourier is None and self.coeffs_fourier is not None:
        # sum the phases
        self.zernike_eval_fourier = combine_zernike_basis(self.coeffs_fourier, self.zernike_fourier)
        self.zernike_eval_fourier = utils.ifftshift(self.zernike_eval_fourier)
        self.zernike_eval_fourier = self.zernike_eval_fourier.to(self.coeffs_fourier.device).detach()
        self.zernike_eval_fourier.requires_grad = False

    # precompute the kernel only once
    # (H_exp when the distance is learned during training, full H otherwise)
    if self.learn_dist and self.training:
        self.precompute_H_exp(processed_complex)
    else:
        self.precompute_H(processed_complex)

    # Section 5.1.2. apply zernike and propagate
    if self.training:
        # training: the propagator applies the (learnable) coefficients itself
        if self.coeffs_fourier is None:
            output_complex = self.prop_zernike(processed_complex, self.feature_size, self.wavelength,
                                               self.distance, coeffs=self.coeffs, zernike=self.zernike,
                                               precomped_H=self.precomped_H,
                                               precomped_H_exp=self.precomped_H_exp,
                                               linear_conv=self.linear_conv)
        else:
            output_complex = self.prop_zernike_fourier(processed_complex, self.feature_size,
                                                       self.wavelength, self.distance,
                                                       coeffs=self.coeffs_fourier,
                                                       zernike=self.zernike_fourier,
                                                       precomped_H=self.precomped_H,
                                                       precomped_H_exp=self.precomped_H_exp,
                                                       linear_conv=self.linear_conv)
    else:
        # eval: apply the cached combined aberrations explicitly
        if self.coeffs is not None:
            # in primal domain
            processed_zernike = self.zernike_eval * processed_complex
        else:
            processed_zernike = processed_complex

        if self.coeffs_fourier is not None:
            # in fourier domain: fold the aberration into the kernel
            precomped_H = self.zernike_eval_fourier * self.precomped_H
        else:
            precomped_H = self.precomped_H

        output_complex = self.prop(processed_zernike, self.feature_size, self.wavelength, self.distance,
                                   precomped_H=precomped_H, linear_conv=self.linear_conv)

    # Section 5.1.1. Content-independent field at target plane
    if self.target_constant_amp is not None:
        real, imag = utils.polar_to_rect(self.target_constant_amp, self.target_constant_phase)
        target_field = torch.complex(real, imag)
        output_complex = output_complex + target_field

    # Section 5.1.4. Content-dependent Undiffracted light
    if self.content_dependent_field is not None:
        # NOTE(review): this reads self.latent_coords, while the phase-processing
        # branch above reads self.latent_code -- confirm both attributes exist.
        if self.latent_coords is not None:
            cdf = self.content_dependent_field(phases, self.latent_coords.repeat(phases.shape[0], 1, 1, 1))
        else:
            cdf = self.content_dependent_field(phases)
        # last dim of cdf holds (amplitude, phase), converted to a complex field
        real, imag = utils.polar_to_rect(cdf[..., 0], cdf[..., 1])
        cdf_rect = torch.complex(real, imag)
        output_complex = output_complex + cdf_rect

    amp = output_complex.abs()
    _, phase = utils.rect_to_polar(output_complex.real, output_complex.imag)

    # blur acts on the amplitude only; the phase is kept unchanged
    if self.blur is not None:
        amp = self.blur(amp)

    real, imag = utils.polar_to_rect(amp, phase)

    return torch.complex(real, imag)
def forward(self, target_amp):
    """Predict an SLM phase pattern for a target amplitude.

    The target field is built from ``target_amp`` and an initial phase
    (self.initial_phase, or zero phase). self.target_field is subtracted when
    set. The result is propagated with self.prop using a cached kernel,
    optionally multiplied by a cached Fourier-domain Zernike term. The
    SLM-plane amplitude/phase are then either double-phase encoded (when
    self.final_phase_only is None) or concatenated with extra channels and fed
    to self.final_phase_only.

    Side effects: on first use, caches self.precomped_H, self.precomped_H_zernike
    (and self.zernike_basis / self.zernike) and self.source_amp. These caches
    are reused on later calls, so the input resolution is presumably expected to
    stay constant -- TODO confirm.

    :param target_amp: target amplitude tensor; presumably (B, 1, H, W) -- TODO confirm.
    :return: tuple ``(amp, phase)``: the SLM-plane amplitude (divided by
        self.source_amp when manual_aberr_corr is set) and the output of
        double_phase or self.final_phase_only.
    """
    # compute some initial phase, convert to real+imag representation
    if self.initial_phase is not None:
        init_phase = self.initial_phase(target_amp)
        real, imag = utils.polar_to_rect(target_amp, init_phase)
        target_complex = torch.complex(real, imag)
    else:
        init_phase = torch.zeros_like(target_amp)
        # no need to convert, zero phase implies amplitude = real part
        target_complex = torch.complex(target_amp, init_phase)

    # subtract the additional target field
    if self.target_field is not None:
        target_complex_diff = target_complex - self.target_field
    else:
        target_complex_diff = target_complex

    # precompute the propagation kernel only once
    if self.precomped_H is None:
        self.precomped_H = self.prop(target_complex_diff, self.feature_size, self.wavelength, self.distance,
                                     return_H=True, linear_conv=self.linear_conv)
        self.precomped_H = self.precomped_H.to(self.dev).detach()
        self.precomped_H.requires_grad = False

    # kernel with the Fourier-domain Zernike aberration folded in, computed once
    if self.precomped_H_zernike is None:
        if self.zernike is None and self.zernike_coeffs is not None:
            # basis at twice the target resolution in each dimension
            self.zernike_basis = compute_zernike_basis(self.zernike_coeffs.size()[0],
                                                       [i * 2 for i in target_amp.size()[-2:]],
                                                       wo_piston=True)
            self.zernike_basis = self.zernike_basis.to(self.dev).detach()
            self.zernike = combine_zernike_basis(self.zernike_coeffs, self.zernike_basis)
            self.zernike = utils.ifftshift(self.zernike)
            self.zernike = self.zernike.to(self.dev).detach()
            self.zernike.requires_grad = False
            self.precomped_H_zernike = self.zernike * self.precomped_H
            self.precomped_H_zernike = self.precomped_H_zernike.to(self.dev).detach()
            self.precomped_H_zernike.requires_grad = False
        else:
            # no (new) aberration to fold in: use the plain kernel
            self.precomped_H_zernike = self.precomped_H

    # precompute the source amplitude, only once
    if self.source_amp is None and self.source_amplitude is not None:
        self.source_amp = self.source_amplitude(target_amp)
        self.source_amp = self.source_amp.to(self.dev).detach()
        self.source_amp.requires_grad = False

    # implement the basic propagation to the SLM plane
    slm_naive = self.prop(target_complex_diff, self.feature_size, self.wavelength, self.distance,
                          precomped_H=self.precomped_H_zernike, linear_conv=self.linear_conv)

    # switch to amplitude+phase and apply source amplitude adjustment
    amp, ang = utils.rect_to_polar(slm_naive.real, slm_naive.imag)
    # amp, ang = slm_naive.abs(), slm_naive.angle()  # PyTorch 1.7.0 Complex tensor doesn't support
    # the gradient of angle() currently.

    if self.source_amp is not None and self.manual_aberr_corr:
        amp = amp / self.source_amp

    if self.final_phase_only is None:
        # NOTE(review): double_phase is called as (amplitude, phase, three_pi=...);
        # confirm the double_phase in scope here takes amplitude and phase separately.
        return amp, double_phase(amp, ang, three_pi=False)
    else:
        # note the change to usual complex number stacking!
        # We're making this the channel dim via cat instead of stack
        # Python precedence: (zernike is None and source_amp is None) or manual_aberr_corr
        if (self.zernike is None and self.source_amp is None or self.manual_aberr_corr):
            if self.latent_codes is not None:
                slm_amp_phase = torch.cat((amp, ang, self.latent_codes.repeat(amp.shape[0], 1, 1, 1)), -3)
            else:
                slm_amp_phase = torch.cat((amp, ang), -3)
        elif self.zernike is None:
            slm_amp_phase = torch.cat((amp, ang, self.source_amp), -3)
        elif self.source_amp is None:
            slm_amp_phase = torch.cat((amp, ang, self.zernike), -3)
        else:
            slm_amp_phase = torch.cat((amp, ang, self.zernike, self.source_amp), -3)
        return amp, self.final_phase_only(slm_amp_phase)
feature_size=feature_size, wavelength=wavelength, blur=blur).to(device) model_prop.load_state_dict(torch.load(opt.model_path, map_location=device)) # Here, we crop model parameters to match the Holonet resolution, which is slightly different from 1080p. # parameters for CITL model zernike_coeffs = model_prop.coeffs_fourier source_amplitude = model_prop.source_amp latent_codes = model_prop.latent_code latent_codes = utils.crop_image(latent_codes, target_shape=image_res, pytorch=True, stacked_complex=False) # content independent target field (See Section 5.1.1.) u_t_amp = utils.crop_image(model_prop.target_constant_amp, target_shape=image_res, stacked_complex=False) u_t_phase = utils.crop_image(model_prop.target_constant_phase, target_shape=image_res, stacked_complex=False) real, imag = utils.polar_to_rect(u_t_amp, u_t_phase) u_t = torch.complex(real, imag) # match the shape of forward model parameters (1072, 1920) # If you make it nn.Parameter, the Holonet will include these parameters in the torch.save files model_prop.latent_code = nn.Parameter(latent_codes.detach(), requires_grad=False) model_prop.target_constant_amp = nn.Parameter(u_t_amp.detach(), requires_grad=False) model_prop.target_constant_phase = nn.Parameter(u_t_phase.detach(), requires_grad=False) # But as these parameters are already in the CITL-calibrated models, # you can load these from those models avoiding duplicated saves. model_prop.eval() # ensure freezing propagation model # create new phase generator object
print(f' - running for img_{target_idx}...') # crop to ROI target_amp = utils.crop_image(target_amp, target_shape=roi_res, stacked_complex=False).to(device) recon_amp = [] # for each channel, propagate wave from the SLM plane to the image plane and get the reconstructed image. for c in chs: # load and invert phase (our SLM setup) phase_filename = os.path.join(opt.root_path, chan_strs[c], f'{target_idx}.png') slm_phase = skimage.io.imread(phase_filename) / 255. slm_phase = torch.tensor((1 - slm_phase) * 2 * np.pi - np.pi, dtype=dtype).reshape(1, 1, *slm_res).to(device) # propagate field real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase) slm_field = torch.complex(real, imag) if opt.prop_model.upper() == 'MODEL': propagator = propagators[c] # Select CITL-calibrated models for each channel recon_field = utils.propagate_field(slm_field, propagator, prop_dists[c], wavelengths[c], feature_size, opt.prop_model, dtype) # cartesian to polar coordinate recon_amp_c = recon_field.abs() # crop to ROI recon_amp_c = utils.crop_image(recon_amp_c, target_shape=roi_res, stacked_complex=False) # append to list recon_amp.append(recon_amp_c)
def stochastic_gradient_descent(init_phase, target_amp, num_iters, prop_dist, wavelength, feature_size,
                                roi_res=None, phase_path=None, prop_model='ASM', propagator=None,
                                loss=None, lr=0.01, lr_s=0.003, s0=1.0, citl=False, camera_prop=None,
                                writer=None, dtype=torch.float32, precomputed_H=None):
    """
    Given the initial guess, run the SGD algorithm to calculate the optimal phase pattern of spatial light modulator.

    Input
    ------
    :param init_phase: a tensor, in the shape of (1,1,H,W), initial guess for the phase.
        A detached copy is optimized, so the caller's tensor is left unchanged.
    :param target_amp: a tensor, in the shape of (1,1,H,W), the amplitude of the target image.
    :param num_iters: the number of iterations to run the SGD.
    :param prop_dist: propagation distance in m.
    :param wavelength: wavelength in m.
    :param feature_size: the SLM pixel pitch, in meters, e.g. 6.4e-6
    :param roi_res: a tuple of integer, region of interest, like (880, 1600)
    :param phase_path: a string, for saving intermediate phases
    :param prop_model: a string, that indicates the propagation model. ('ASM' or 'MODEL')
    :param propagator: predefined function or model instance for the propagation.
    :param loss: loss function called as loss(prediction, target). If None, a fresh
        nn.MSELoss() (L2) is used.
    :param lr: learning rate for optimization variables
    :param lr_s: learning rate for learnable scale (the scale is not optimized if lr_s <= 0)
    :param s0: initial scale
    :param citl: if True, use the camera-in-the-loop technique: the captured amplitude
        replaces the simulated one in the forward pass while gradients flow through
        the simulation.
    :param camera_prop: callable mapping the SLM phase to a captured amplitude whose size
        matches the cropped reconstruction; required when citl is True.
    :param writer: Tensorboard writer instance
    :param dtype: default torch.float32
    :param precomputed_H: A Pytorch complex64 tensor, pre-computed kernel shape of (1,1,2H,2W) for fast computation.

    Output
    ------
    :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W)
    :raises ValueError: if citl is True and camera_prop is None.
    """
    # fail fast instead of "'NoneType' object is not callable" inside the loop
    if citl and camera_prop is None:
        raise ValueError('camera_prop must be provided when citl=True')

    # build the default loss per call rather than sharing one module instance
    # created at function-definition time
    if loss is None:
        loss = nn.MSELoss()

    device = init_phase.device
    s = torch.tensor(s0, requires_grad=True, device=device)

    # phase at the slm plane: optimize a detached leaf copy so Adam's in-place
    # updates do not modify the caller's init_phase (and non-leaf inputs work)
    slm_phase = init_phase.detach().clone().requires_grad_(True)

    # optimization variables and adam optimizer
    optvars = [{'params': slm_phase}]
    if lr_s > 0:
        optvars += [{'params': s, 'lr': lr_s}]
    optimizer = optim.Adam(optvars, lr=lr)

    # crop target roi
    target_amp = utils.crop_image(target_amp, roi_res, stacked_complex=False)

    # run the iterative algorithm
    for k in range(num_iters):
        optimizer.zero_grad()

        # forward propagation from the SLM plane to the target plane
        real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase)
        slm_field = torch.complex(real, imag)
        recon_field = utils.propagate_field(slm_field, propagator, prop_dist, wavelength, feature_size,
                                            prop_model, dtype, precomputed_H)

        # get amplitude and crop roi
        recon_amp = recon_field.abs()
        recon_amp = utils.crop_image(recon_amp, target_shape=roi_res, stacked_complex=False)

        # camera-in-the-loop technique
        if citl:
            captured_amp = camera_prop(slm_phase)
            # use the gradient of proxy, replacing the amplitudes
            # captured_amp is assumed that its size already matches that of recon_amp
            out_amp = recon_amp + (captured_amp - recon_amp).detach()
        else:
            out_amp = recon_amp

        # calculate loss and backprop
        loss_value = loss(s * out_amp, target_amp)
        loss_value.backward()
        optimizer.step()

        # write to tensorboard / write phase image
        # Note that it takes 0.~ s for writing it to tensorboard
        with torch.no_grad():
            if k % 50 == 0:
                print(k)
                utils.write_sgd_summary(slm_phase, out_amp, target_amp, k,
                                        writer=writer, path=phase_path, s=s, prefix='test')

    return slm_phase
def gerchberg_saxton(init_phase, target_amp, num_iters, prop_dist, wavelength, feature_size=6.4e-6,
                     phase_path=None, prop_model='ASM', propagator=None,
                     writer=None, dtype=torch.float32, precomputed_H_f=None, precomputed_H_b=None):
    """
    Given the initial guess, run the Gerchberg-Saxton (GS) algorithm to calculate the phase
    pattern of a spatial light modulator.

    Each iteration does three things:
    1. Propagate the unit-amplitude SLM field to the image plane by ``prop_dist``.
    2. Replace the image-plane amplitude with ``target_amp``.
    3. Propagate back by ``-prop_dist`` and reset the SLM-plane amplitude to ones.

    :param init_phase: a tensor, in the shape of (1,1,H,W), initial guess for the phase.
    :param target_amp: a tensor, in the shape of (1,1,H,W), the amplitude of the target image.
    :param num_iters: the number of iterations to run the GS.
    :param prop_dist: propagation distance in m.
    :param wavelength: wavelength in m.
    :param feature_size: the SLM pixel pitch, in meters, default 6.4e-6
    :param phase_path: path to save the results. NOTE(review): currently not used by this function.
    :param prop_model: string indicating the light transport model, default 'ASM'. ex) 'ASM', 'fresnel', 'model'
    :param propagator: predefined function or model instance for the propagation.
    :param writer: tensorboard writer
    :param dtype: torch datatype for computation at different precision, default torch.float32.
    :param precomputed_H_f: A Pytorch complex64 tensor, pre-computed kernel for forward prop (SLM to image)
    :param precomputed_H_b: A Pytorch complex64 tensor, pre-computed kernel for backward propagation (image to SLM)

    Output
    ------
    :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W)
    """
    # unit amplitude at the SLM plane; built from the SLM phase so the
    # SLM-plane constraint always matches the SLM field's shape/dtype/device
    slm_amp = torch.ones_like(init_phase)

    # initial guess
    real, imag = utils.polar_to_rect(slm_amp, init_phase)
    slm_field = torch.complex(real, imag)

    # run the GS algorithm
    for k in range(num_iters):
        # SLM plane to image plane
        recon_field = utils.propagate_field(slm_field, propagator, prop_dist, wavelength, feature_size,
                                            prop_model, dtype, precomputed_H_f)

        # write to tensorboard / write phase image
        # Note that it takes 0.~ s for writing it to tensorboard
        if k > 0 and k % 10 == 0:
            print(k)
            utils.write_gs_summary(slm_field, recon_field, target_amp, k, writer, prefix='test')

        # replace amplitude at the image plane
        recon_field = utils.replace_amplitude(recon_field, target_amp)

        # image plane to SLM plane
        slm_field = utils.propagate_field(recon_field, propagator, -prop_dist, wavelength, feature_size,
                                          prop_model, dtype, precomputed_H_b)

        # amplitude constraint at the SLM plane
        slm_field = utils.replace_amplitude(slm_field, slm_amp)

    # return phases
    return slm_field.angle()