def generate_angular_spectrum_kernel(shape, pixel_size, wavelength,
                                     numerical_aperture=None, flag_band_limited=True,
                                     dtype=torch.float32, device=torch.device('cuda')):
    """
    Function that generates angular spectrum propagation kernel WITHOUT the distance.

    The angular spectrum propagator has the form
        p = exp(distance * kernel)
        kernel = 1j * 2 * pi * sqrt((ri/wavelength)**2 - kx**2 - ky**2)
    and this function generates the kernel only. The code uses ri = 1, i.e. the
    term under the square root is (1/wavelength)**2 - kx**2 - ky**2.

    Args:
        shape: 2-element (y, x) grid shape.
        pixel_size: real-space pixel size, forwarded to util.generate_grid_2d.
        wavelength: illumination wavelength (same length unit as pixel_size).
        numerical_aperture: required when flag_band_limited is True; frequencies
            with kx**2 + ky**2 > (NA/wavelength)**2 are zeroed.
        flag_band_limited: if True, apply the hard pupil mask to the kernel.
        dtype, device: forwarded to every grid/pupil constructor so the whole
            kernel is built with one dtype on one device.

    Returns:
        Complex kernel in the op.r2c (real, imag) layout, shape (*shape, 2).
    """
    assert len(shape) == 2, "pupil should be two dimensional!"
    ky_lin, kx_lin = util.generate_grid_2d(shape, pixel_size, flag_fourier=True, dtype=dtype, device=device)
    if flag_band_limited:
        assert numerical_aperture is not None, "need to provide numerical aperture of the system!"
        hard_pupil = generate_hard_pupil(shape, pixel_size, numerical_aperture, wavelength,
                                         dtype=dtype, device=device)
        # generate_hard_pupil already returns an r2c (real, imag) tensor. Keep only
        # its real 0/1 mask, with a unit trailing axis, so it scales both the real
        # and imaginary components of the complex kernel below. Re-wrapping it in
        # op.r2c would add a spurious axis that no longer broadcasts.
        pupil_crop = hard_pupil[..., :1]
    else:
        pupil_crop = 1.0
    prop_kernel = 2.0 * np.pi * pupil_crop * \
        op.exponentiate(op.r2c((1./wavelength)**2 - kx_lin**2 - ky_lin**2), 0.5)
    return op.multiply_complex(op._j, prop_kernel)
def _shift_stack_inplace(self, stack, shift_list):
    """Translate every image of ``stack`` by its own (y, x) shift, writing back in place.

    For image k, a phase ramp exp(j * 2*pi * (kx_lin * dx + ky_lin * dy)) is built,
    with dy = shift_list[0, k] and dx = shift_list[1, k]. The ramp is handed to
    ``op.convolve_kernel`` together with the image; presumably this applies the
    ramp in Fourier space (Fourier shift theorem). Confirm against op. Only the
    real part of the result is stored back into ``stack[..., k]``.

    Returns the same ``stack`` object that was passed in.
    """
    two_pi = 2 * np.pi
    num_images = stack.shape[2]
    for k in range(num_images):
        dy = shift_list[0, k]
        dx = shift_list[1, k]
        phase = op.r2c(two_pi * (self.kx_lin * dx + self.ky_lin * dy))
        ramp = op.exp(op.multiply_complex(op._j, phase))
        shifted = op.convolve_kernel(op.r2c(stack[..., k]), ramp, n_dim=2)
        # keep the real component only
        stack[..., k] = shifted[..., 0]
    return stack
def __init__(self, shape, pixel_size, wavelength,
             numerical_aperture=1.0, pupil=None,
             dtype=torch.float32, device=torch.device('cuda'), **kwargs):
    """Store the complex pupil used by this module.

    If ``pupil`` is None, a circular hard pupil of radius numerical_aperture /
    wavelength is built via ``generate_hard_pupil``. Otherwise the supplied tensor
    is cast to ``dtype`` and moved to ``device``. A 2-D supplied pupil is treated
    as real-valued and promoted to the op.r2c (real, imag) layout. Any other rank
    is kept as given. Extra ``kwargs`` are accepted and ignored.
    """
    super(Pupil, self).__init__()
    if pupil is None:
        self.pupil = generate_hard_pupil(shape, pixel_size, numerical_aperture, wavelength, dtype, device)
        return
    converted = pupil.type(dtype).to(device)
    self.pupil = op.r2c(converted) if len(converted.shape) == 2 else converted
def backward(self, obj):
    """Rotate ``obj`` by ``-self.theta`` degrees about ``self.axis``.

    ``self.theta`` is set outside this block. Presumably it is the angle of the
    last ``forward`` call, so this undoes that rotation (verify against forward).
    The volume is processed in tiles of ``self.slice_per_tile`` slices along
    ``self.axis``. Each tile is passed to ``self._rotate_3d`` with two shear
    phase ramps built from alpha = tan(theta/2) and beta = sin(-theta). Tiles are
    written back into ``obj`` in place.

    A CPU input is copied to CUDA for the computation and the result is returned
    on the CPU again. A zero angle returns ``obj`` unchanged.
    """
    theta = -1 * self.theta
    if theta == 0:
        return obj
    # Record the caller's placement BEFORE moving the tensor. Checking
    # obj.is_cuda after obj.cuda() would always be True, so CPU inputs would
    # otherwise come back as CUDA tensors.
    flag_cpu = not obj.is_cuda
    if flag_cpu:
        obj = obj.cuda()
    theta *= np.pi / 180.0  # degrees -> radians
    alpha = 1.0 * np.tan(theta / 2.0)
    beta = np.sin(-1.0 * theta)
    shear_phase_1 = op.exp(
        op.multiply_complex(op._j, self.coord_phase_1 * alpha))
    shear_phase_2 = op.exp(
        op.multiply_complex(op._j, self.coord_phase_2 * beta))
    # Scratch buffer sized for one tile along the rotation axis.
    self.dim[self.axis] = self.slice_per_tile
    self.obj_rotate = op.r2c(
        torch.zeros([self.dim[0], self.dim[1], self.dim[2]],
                    dtype=self.dtype,
                    device=self.device))
    for idx_start in range(0, obj.shape[self.axis], self.slice_per_tile):
        idx_end = np.min(
            [obj.shape[self.axis], idx_start + self.slice_per_tile])
        idx_slice = slice(idx_start, idx_end)
        # The last tile may be shorter than slice_per_tile.
        self.dim[self.axis] = int(idx_end - idx_start)
        if self.axis == 0:
            self.range_crop_y = slice(0, self.dim[self.axis])
            obj[idx_slice, :, :] = self._rotate_3d(
                obj[idx_slice, :, :], alpha, beta, shear_phase_1,
                shear_phase_2)
        elif self.axis == 1:
            self.range_crop_x = slice(0, self.dim[self.axis])
            obj[:, idx_slice, :] = self._rotate_3d(
                obj[:, idx_slice, :], alpha, beta, shear_phase_1,
                shear_phase_2)
        elif self.axis == 2:
            self.range_crop_z = slice(0, self.dim[self.axis])
            obj[:, :, idx_slice] = self._rotate_3d(obj[:, :, idx_slice],
                                                   alpha, beta,
                                                   shear_phase_1,
                                                   shear_phase_2)
        # reset the scratch buffer before the next tile
        self.obj_rotate[:] = 0.0
    self.dim[self.axis] = obj.shape[self.axis]
    self.obj_rotate = None
    if flag_cpu:
        obj = obj.cpu()
    return obj
def __init__(self, shape, axis=0, pad=True, pad_value=0, dtype=torch.float32, device=torch.device('cuda')):
    """Precompute padding, 1-D grids and shear phase ramps for rotating a 3-D volume.

    Args:
        shape: 3-element volume shape (dim0, dim1, dim2).
        axis: rotation axis, one of 0, 1, 2.
        pad: if True, pad every non-rotation dimension by ceil(d/2) on each side.
        pad_value: stored on the instance (not used in this block).
        dtype: dtype for the grids and the scratch buffer built later.
        device: device for the scratch buffer built later. NOTE(review): the grids
            below are created without a device argument, so they presumably live
            on generate_grid_1d's default device. Confirm.

    Raises:
        ValueError: if ``axis`` is not 0, 1 or 2.
    """
    if axis not in (0, 1, 2):
        raise ValueError("axis must be 0, 1 or 2, got {}".format(axis))
    self.dim = np.array(shape)
    self.axis = axis
    self.pad_value = pad_value
    if pad:
        self.pad_size = np.ceil(self.dim / 2.0).astype('int')
        self.pad_size[self.axis] = 0  # never pad along the rotation axis
        self.dim += 2 * self.pad_size
    else:
        self.pad_size = np.asarray([0, 0, 0])
    self.dim = [int(size) for size in self.dim]
    # Slices selecting the un-padded volume inside the padded one.
    self.range_crop_y = slice(self.pad_size[0], self.pad_size[0] + shape[0])
    self.range_crop_x = slice(self.pad_size[1], self.pad_size[1] + shape[1])
    self.range_crop_z = slice(self.pad_size[2], self.pad_size[2] + shape[2])
    # Real-space and Fourier-space 1-D grids, unsqueezed to broadcast over (dim0, dim1, dim2).
    self.y = generate_grid_1d(self.dim[0], dtype=dtype).unsqueeze(-1).unsqueeze(-1)
    self.x = generate_grid_1d(self.dim[1], dtype=dtype).unsqueeze(0).unsqueeze(-1)
    self.z = generate_grid_1d(self.dim[2], dtype=dtype).unsqueeze(0).unsqueeze(0)
    self.ky = generate_grid_1d(self.dim[0], flag_fourier=True, dtype=dtype).unsqueeze(-1).unsqueeze(-1)
    self.kx = generate_grid_1d(self.dim[1], flag_fourier=True, dtype=dtype).unsqueeze(0).unsqueeze(-1)
    self.kz = generate_grid_1d(self.dim[2], flag_fourier=True, dtype=dtype).unsqueeze(0).unsqueeze(0)
    # Compute FFTs sequentially if the object is too large: each tile along the
    # rotation axis holds at most ~MAX_DIM elements. Clamp to >= 1, because a
    # single slice larger than MAX_DIM would otherwise give 0, and
    # range(..., step=0) in forward/backward raises ValueError.
    tile = int(
        np.min([
            np.floor(MAX_DIM * self.dim[self.axis] / np.prod(self.dim)),
            self.dim[self.axis]
        ]))
    self.slice_per_tile = max(tile, 1)
    self.dtype = dtype
    self.device = device
    # Phase-ramp coordinates for the two shears in the plane orthogonal to axis.
    if self.axis == 0:
        self.coord_phase_1 = op.r2c(-2.0 * np.pi * self.kz * self.x)
        self.coord_phase_2 = op.r2c(-2.0 * np.pi * self.kx * self.z)
    elif self.axis == 1:
        self.coord_phase_1 = op.r2c(-2.0 * np.pi * self.kz * self.y)
        self.coord_phase_2 = op.r2c(-2.0 * np.pi * self.ky * self.z)
    else:  # axis == 2 (validated above)
        self.coord_phase_1 = op.r2c(-2.0 * np.pi * self.kx * self.y)
        self.coord_phase_2 = op.r2c(-2.0 * np.pi * self.ky * self.x)
def compute_prox(self, x):
    """Apply the regularizer's proximal step to the complex tensor ``x``.

    ``x`` is in the (real, imag) last-axis layout: x[..., 0] is the real part and
    x[..., 1] the imaginary part. The mode flags select what is regularized:
      - pure_real:      prox of the real part; the imaginary part is zeroed.
      - pure_imag:      prox of the imaginary part; the real part is zeroed.
      - pure_amplitude: prox of |x| stored as the real part; the imaginary part is zeroed.
      - pure_phase:     x <- exp(j * prox(angle(x))).
      - otherwise:      real and imaginary parts are processed separately.
    If ``parameter_list`` is set, the parameter for the current iteration is
    loaded first. The computation runs on ``self.device`` and the result is moved
    back to ``x``'s original device. ``x`` may be modified in place.
    Increments ``self.itr_count``.
    """
    if self.parameter_list is not None:
        self.set_parameter(self.parameter_list[self.itr_count])
    x_device = x.device
    x = x.to(device=self.device)
    if self.pure_real:
        x[...,0] = self._compute_prox_real(op.real(x), self.realProjector)
        x[...,1] = 0.0
    elif self.pure_imag:
        x[...,0] = 0.0
        # Store the real-valued prox result directly in the imaginary slot.
        # Previously this assigned a full complex tensor (shape [..., 2]) into
        # x[..., 1] (shape [...]), which fails with a shape mismatch.
        x[...,1] = self._compute_prox_real(op.imag(x), self.imagProjector)
    elif self.pure_amplitude:
        x[...,0] = self._compute_prox_real(op.abs(x), self.realProjector)
        x[...,1] = 0.0
    elif self.pure_phase:
        x = op.exp(op.multiply_complex(op._j, op.r2c(self._compute_prox_real(op.angle(x), self.realProjector))))
    else:
        x[...,0] = self._compute_prox_real(op.real(x), self.realProjector)
        # NOTE(review): the /1.0 and *1.0 scalings are numerically no-ops; they
        # look like hooks for weighting the imaginary part differently.
        self.set_parameter(self.parameter / 1.0, self.maxitr)
        x[...,1] = self._compute_prox_real(op.imag(x), self.imagProjector)
        self.set_parameter(self.parameter * 1.0, self.maxitr)
    self.itr_count += 1
    return x.to(x_device)
def generate_hard_pupil(shape, pixel_size, numerical_aperture, wavelength,
                        dtype=torch.float32, device=torch.device('cuda')):
    """Build a circular (hard-edged) pupil on a 2-D spatial-frequency grid.

    A frequency (kx, ky) is inside the pupil when kx**2 + ky**2 <= (NA/wavelength)**2.
    Inside points get 1 and outside points get 0. The mask is cast to ``dtype`` and
    returned in the op.r2c (real, imag) layout, on the device used for the grid.
    """
    assert len(shape) == 2, "pupil should be two dimensional!"
    ky_lin, kx_lin = util.generate_grid_2d(shape, pixel_size, flag_fourier=True, dtype=dtype, device=device)
    cutoff = numerical_aperture / wavelength
    radial_freq_sq = kx_lin**2 + ky_lin**2
    inside = radial_freq_sq <= cutoff**2
    return op.r2c(inside.type(dtype))
def forward(self, field, shift=None):
    """
    Apply a per-image (y, x) shift to ``field`` via a phase-ramp kernel.

    Input parameters:
        - field: refocused field, before cropping. Images are indexed along
          axis 2, and the last axis holds the (real, imag) components.
        - shift: estimated shift [y_shift, x_shift], indexed as shift[0, k] and
          shift[1, k] for image k. Default None (shift estimation off).

    Returns:
        ``field`` itself when ``shift`` is None. Otherwise a new tensor in which
        each image k has been passed through complex_conv with the kernel
        exp(j * 2*pi * (kx_lin * x_shift + ky_lin * y_shift)).
    """
    if shift is None:
        return field
    field_out = field.clone()
    for img_idx in range(field.shape[2]):
        y_shift = shift[0, img_idx]
        x_shift = shift[1, img_idx]
        # Previously called the undefined name `compelx_mul` (NameError whenever
        # a shift was given). Use op.multiply_complex, as in the identical
        # phase-ramp construction elsewhere in this module.
        kernel = complex_exp(
            op.multiply_complex(
                op._j,
                op.r2c(2 * np.pi * (self.kx_lin * x_shift + self.ky_lin * y_shift))))
        field_out[..., img_idx, :] = complex_conv(field[..., img_idx, :], kernel, 2, True)
    return field_out
def run(self, obj_init=None, forward_only=False, callback=None):
    """
    run tomography solver

    For each of ``self.optim_max_itr`` iterations, loop over the dataset (one
    rotation per batch, batch_size=1). For each rotation:
      1. rotate ``self.obj`` to that rotation's angle;
      2. run the forward model ``self.tomography_obj``;
      3. optionally estimate in-plane transforms or shifts;
      4. compute the cost and take one SGD step on the enabled parameters.
    After each pass over the data, ``self.regularizer_obj`` is applied to the
    object (in object-update iterations).

    Args:
        obj_init: initial object. If None, a zero complex object of ``self.shape``
            is created on CUDA. A 3-D real tensor is promoted with op.r2c.
        forward_only: True  -- only runs forward model on estimated object
                      False -- runs reconstruction
        callback: if given, called once per iteration as
            callback(self.obj.cpu().detach(), error).

    Returns:
        forward_only=True: the estimated amplitudes of every rotation from the
            first iteration, each unsqueezed and concatenated along a new last axis.
        forward_only=False: (self.obj.cpu().detach(), error), where ``error`` holds
            one summed cost per iteration.
    """
    if forward_only:
        assert obj_init is not None
        # keep dataset order so the returned amplitudes line up with the rotations
        self.shuffle = False
        amplitude_list = []

    self.dataloader = DataLoader(self.dataset, batch_size=1, shuffle=self.shuffle)

    error = []
    #initialize object
    self.obj = obj_init
    if self.obj is None:
        self.obj = op.r2c(torch.zeros(self.shape).cuda())
    else:
        if not self.obj.is_cuda:
            self.obj = self.obj.cuda()
        if len(self.obj.shape) == 3:
            self.obj = op.r2c(self.obj)

    #initialize shift parameters
    self.yx_shifts = None
    if self.shift_align:
        self.sa_pixel_count = []
        self.yx_shift_all = []
        self.yx_shifts = torch.zeros(
            (2, self.num_defocus, self.num_rotation))

    if self.transform_align:
        self.xy_transform_all = []
        self.xy_transforms = torch.zeros(
            (6, self.num_defocus, self.num_rotation))
        # self.xy_transforms = torch.zeros((3, self.num_defocus, self.num_rotation))

    # TEMPP
    # defocus_list_grad = torch.zeros((self.num_defocus, self.num_rotation), dtype = torch.float32)

    # Index of the rotation whose angle is ~0 deg. Alignment estimates are not
    # stored for it (see the `rotation_idx != ref_rot_idx` checks below).
    ref_rot_idx = None

    #begin iteration
    for itr_idx in range(self.optim_max_itr):
        sys.stdout.flush()
        running_cost = 0.0
        #defocus_list_grad[:] = 0.0
        if self.shift_align and itr_idx in self.sa_iterations:
            running_sa_pixel_count = 0.0
        for data_idx, data in enumerate(self.dataloader, 0):
            #parse data
            if not forward_only:
                amplitudes, rotation_angle, defocus_list, rotation_idx = data
                if ref_rot_idx is None and abs(rotation_angle - 0.0) < 1e-2:
                    ref_rot_idx = rotation_idx
                    print("reference index is:", ref_rot_idx)
                amplitudes = torch.squeeze(amplitudes)
                # ensure a trailing axis even when there is a single defocus image
                if len(amplitudes.shape) < 3:
                    amplitudes = amplitudes.unsqueeze(-1)
            else:
                rotation_angle, defocus_list, rotation_idx = data[-3:]

            #prepare tilt specific parameters
            defocus_list = torch.flatten(defocus_list).cuda()
            rotation_angle = rotation_angle.item()
            yx_shift = None
            if self.shift_align and self.sa_method == "gradient" and itr_idx in self.sa_iterations:
                yx_shift = self.yx_shifts[:, :, rotation_idx]
                yx_shift = yx_shift.cuda()
                yx_shift.requires_grad_()
            if self.defocus_refine and self.dr_method == "gradient" and itr_idx in self.dr_iterations:
                defocus_list.requires_grad_()

            #rotate object
            # The object is rotated incrementally from the previous rotation's
            # angle. For jumps larger than 90 deg it is first rotated back to 0
            # and then to the new angle. previous_angle is assigned at the end of
            # the first data item, so it is defined whenever data_idx > 0.
            if data_idx == 0:
                self.obj = self.rotation_obj.forward(
                    self.obj, rotation_angle)
            else:
                if abs(rotation_angle - previous_angle) > 90:
                    self.obj = self.rotation_obj.forward(
                        self.obj, -1 * previous_angle)
                    self.obj = self.rotation_obj.forward(
                        self.obj, rotation_angle)
                else:
                    self.obj = self.rotation_obj.forward(
                        self.obj, rotation_angle - previous_angle)

            if not forward_only:
                #define optimizer
                # NOTE(review): if this iteration updates neither the object nor
                # any gradient-based alignment, optimizer_params stays empty.
                # torch.optim.SGD then likely raises "optimizer got an empty
                # parameter list". Confirm the iteration schedules always overlap.
                optimizer_params = []
                if itr_idx in self.obj_update_iterations:
                    self.obj.requires_grad_()
                    optimizer_params.append({
                        'params': self.obj,
                        'lr': self.optim_step_size
                    })
                if self.shift_align and self.sa_method == "gradient" and itr_idx in self.sa_iterations:
                    optimizer_params.append({
                        'params': yx_shift,
                        'lr': self.sa_step_size
                    })
                if self.defocus_refine and self.dr_method == "gradient" and itr_idx in self.dr_iterations:
                    optimizer_params.append({
                        'params': defocus_list,
                        'lr': self.dr_step_size
                    })
                optimizer = optim.SGD(optimizer_params)

            #forward scattering
            estimated_amplitudes = self.tomography_obj(
                self.obj, defocus_list, yx_shift)

            #in-plane rotation estimation
            if not forward_only:
                if self.transform_align and itr_idx in self.ta_iterations:
                    if rotation_idx != ref_rot_idx:
                        # xy_transform is only defined on this path; its later
                        # use is guarded by the same rotation_idx check.
                        amplitudes, xy_transform = self.transform_obj.estimate(
                            estimated_amplitudes, amplitudes)
                        xy_transform = xy_transform.unsqueeze(-1)
                        # self.dataset.update_amplitudes(amplitudes, rotation_idx)

                #Correlation based shift estimation
                if self.shift_align and shift.is_correlation_method(
                        self.sa_method) and itr_idx in self.sa_iterations:
                    if rotation_idx != ref_rot_idx:
                        amplitudes, yx_shift, _ = self.shift_obj.estimate(
                            estimated_amplitudes, amplitudes)
                        yx_shift = yx_shift.unsqueeze(-1)
                        # self.dataset.update_amplitudes(amplitudes, rotation_idx)

                # on the final iteration, write the (possibly aligned) amplitudes back to the dataset
                if itr_idx == self.optim_max_itr - 1:
                    print("Last iteration: updated amplitudes")
                    self.dataset.update_amplitudes(amplitudes, rotation_idx)

                #compute cost
                cost = self.cost_function(estimated_amplitudes, amplitudes.cuda())
                running_cost += cost.item()

                #backpropagation
                cost.backward()

                #update object
                # if itr_idx >= self.dr_start_iteration:
                #     # print(torch.norm(defocus_list.grad.data))
                #     defocus_list_grad[:,data_idx] = defocus_list.grad.data * self.dr_step_size
                optimizer.step()
                optimizer.zero_grad()
                del cost
            else:
                #store measurement
                amplitude_list.append(estimated_amplitudes.cpu().detach())
            del estimated_amplitudes
            self.obj.requires_grad = False

            if not forward_only:
                #keep track of shift alignment for the tilt
                if self.shift_align and itr_idx in self.sa_iterations:
                    if yx_shift is not None:
                        yx_shift.requires_grad = False
                        if rotation_idx != ref_rot_idx:
                            self.yx_shifts[:, :, rotation_idx] = yx_shift[:].cpu()
                            running_sa_pixel_count += torch.sum(
                                torch.abs(yx_shift.cpu().flatten()))

                #keep track of transform alignment for the tilt
                if self.transform_align and itr_idx in self.ta_iterations:
                    if rotation_idx != ref_rot_idx:
                        self.xy_transforms[
                            ..., rotation_idx] = xy_transform[:].cpu()

                #keep track of defocus alignment for the tilt
                if self.defocus_refine and itr_idx in self.dr_iterations:
                    defocus_list.requires_grad = False
                    self.dataset.update_defocus_list(
                        defocus_list[:].cpu().detach(), rotation_idx)

            previous_angle = rotation_angle

            #rotate object back
            # after the last rotation, return the object to 0 deg for regularization
            if data_idx == (self.dataset.__len__() - 1):
                previous_angle = 0.0
                self.obj = self.rotation_obj.forward(
                    self.obj, -1.0 * rotation_angle)
            print("Rotation {:03d}/{:03d}.".format(data_idx + 1,
                                                   self.dataset.__len__()),
                  end="\r")

        #apply regularization
        amplitudes = None
        torch.cuda.empty_cache()
        if not forward_only:
            if itr_idx in self.obj_update_iterations:
                self.obj = self.regularizer_obj.apply(self.obj)
        error.append(running_cost)

        #keep track of shift alignment results
        if self.shift_align and itr_idx in self.sa_iterations:
            self.sa_pixel_count.append(running_sa_pixel_count)
            self.yx_shift_all.append(np.array(self.yx_shifts).copy())

        #keep track of transform alignment results
        if self.transform_align and itr_idx in self.ta_iterations:
            self.xy_transform_all.append(
                np.array(self.xy_transforms).copy())

        if callback is not None:
            callback(self.obj.cpu().detach(), error)
            #TEMPPPPP
            # callback(defocus_list_grad, self.dataset.get_all_defocus_lists(), error)

        # forward-only mode needs just one pass over the data
        if forward_only and itr_idx == 0:
            return torch.cat([
                torch.unsqueeze(amplitude_list[idx], -1)
                for idx in range(len(amplitude_list))
            ], axis=-1)
        # NOTE(review): np.log10 yields -inf (with a RuntimeWarning) if running_cost is 0.
        print("Iteration {:03d}/{:03d}. Error: {:03f}".format(
            itr_idx + 1, self.optim_max_itr, np.log10(running_cost)))

    self.defocus_list = self.dataset.get_all_defocus_lists()
    return self.obj.cpu().detach(), error
def compute_prox(self, x):
    """Replace ``x`` by exp(j * angle(x)).

    Only the phase of each element is kept. Presumably this yields unit-modulus
    values, assuming op.exp is the complex exponential.
    """
    phase = op.r2c(op.angle(x))
    return op.exp(op.multiply_complex(op._j, phase))