Пример #1
0
def generate_angular_spectrum_kernel(shape, pixel_size, wavelength, \
                                     numerical_aperture=None,  flag_band_limited=True, \
                                     dtype=torch.float32, device=torch.device('cuda')):
    """
    Generate the angular spectrum propagation kernel WITHOUT the distance.

    The angular spectrum propagator has the form
        p = exp(distance * kernel)
        kernel = 1j * 2 * pi * sqrt((ri/wavelength)**2 - x**2 - y**2)
    and this function returns the kernel only.

    Args:
        shape: two-element grid size (anything else fails the assertion).
        pixel_size: forwarded to util.generate_grid_2d / generate_hard_pupil.
        wavelength: wavelength used in the (1/wavelength)**2 term and for the
            pupil cutoff.
        numerical_aperture: required when flag_band_limited is True.
        flag_band_limited: when True the kernel is multiplied by a hard pupil
            built from numerical_aperture; when False no cropping is applied.
        dtype, device: forwarded to both the frequency grid and the pupil.

    Returns:
        The result of op.multiply_complex(op._j, 2*pi * pupil * sqrt(...)).
    """
    assert len(shape) == 2, "pupil should be two dimensional!"
    ky_lin, kx_lin = util.generate_grid_2d(shape,
                                           pixel_size,
                                           flag_fourier=True,
                                           dtype=dtype,
                                           device=device)
    if flag_band_limited:
        assert numerical_aperture is not None, "need to provide numerical aperture of the system!"
        # Forward dtype/device so the pupil is created alongside the frequency
        # grids instead of silently falling back to generate_hard_pupil's
        # defaults (float32 on 'cuda').
        # NOTE(review): the generate_hard_pupil visible in this file already
        # returns op.r2c(...); confirm that wrapping it again here is intended.
        pupil_crop = op.r2c(
            generate_hard_pupil(shape, pixel_size, numerical_aperture,
                                wavelength, dtype=dtype, device=device))
    else:
        pupil_crop = 1.0
    prop_kernel = 2.0 * np.pi * pupil_crop * \
                  op.exponentiate(op.r2c((1./wavelength)**2 - kx_lin**2 - ky_lin**2), 0.5)
    return op.multiply_complex(op._j, prop_kernel)
Пример #2
0
 def _shift_stack_inplace(self, stack, shift_list):
     for img_idx in range(stack.shape[2]):
         y_shift = shift_list[0, img_idx]
         x_shift = shift_list[1, img_idx]
         kernel = op.exp(
             op.multiply_complex(
                 op._j,
                 op.r2c(2 * np.pi *
                        (self.kx_lin * x_shift + self.ky_lin * y_shift))))
         stack[..., img_idx] = op.convolve_kernel(op.r2c(stack[...,
                                                               img_idx]),
                                                  kernel,
                                                  n_dim=2)[..., 0]
     return stack
Пример #3
0
 def __init__(self, shape, pixel_size, wavelength, \
              numerical_aperture = 1.0, pupil = None, \
              dtype=torch.float32, device=torch.device('cuda'), **kwargs):
     """
     Set up self.pupil.

     When no `pupil` is supplied, a hard pupil is generated from shape,
     pixel_size, numerical_aperture and wavelength. A supplied pupil is cast
     to `dtype`, moved to `device`, and wrapped with op.r2c if it has exactly
     two dimensions. Extra keyword arguments are accepted but not used in
     the lines shown here.
     """
     super(Pupil, self).__init__()
     if pupil is None:
         self.pupil = generate_hard_pupil(
             shape, pixel_size, numerical_aperture, wavelength, dtype, device)
     else:
         self.pupil = pupil.type(dtype).to(device)
         is_two_dimensional = len(self.pupil.shape) == 2
         if is_two_dimensional:
             self.pupil = op.r2c(self.pupil)
    def backward(self, obj):
        """
        Rotate `obj` by the negative of ``self.theta`` (given in degrees).

        The rotation is applied tile by tile along ``self.axis`` (at most
        ``self.slice_per_tile`` slices at a time) through ``self._rotate_3d``,
        and every rotated tile is written back into `obj` in place.

        Args:
            obj: volume to rotate. A tensor that is not on a CUDA device is
                copied to one for the computation.

        Returns:
            `obj` itself when the angle is zero; otherwise the rotated
            tensor, moved back to the CPU if the input was not on a CUDA
            device.
        """
        theta = -1 * self.theta
        if theta == 0:
            return obj

        # Record the input placement BEFORE moving it. The original re-checked
        # obj.is_cuda after the .cuda() call, so the copy back to the CPU
        # could never run.
        flag_input_on_cpu = not obj.is_cuda
        if flag_input_on_cpu:
            obj = obj.cuda()

        # Degrees -> radians, then the two shear coefficients.
        theta *= np.pi / 180.0
        alpha = 1.0 * np.tan(theta / 2.0)
        beta = np.sin(-1.0 * theta)

        shear_phase_1 = op.exp(
            op.multiply_complex(op._j, self.coord_phase_1 * alpha))
        shear_phase_2 = op.exp(
            op.multiply_complex(op._j, self.coord_phase_2 * beta))

        # Zero-filled buffer sized for one full tile; presumably used as
        # scratch space by self._rotate_3d (not visible here).
        self.dim[self.axis] = self.slice_per_tile
        self.obj_rotate = op.r2c(
            torch.zeros([self.dim[0], self.dim[1], self.dim[2]],
                        dtype=self.dtype,
                        device=self.device))

        for idx_start in range(0, obj.shape[self.axis],
                               self.slice_per_tile):
            # The last tile may hold fewer than slice_per_tile slices.
            idx_end = min(obj.shape[self.axis],
                          idx_start + self.slice_per_tile)
            idx_slice = slice(idx_start, idx_end)
            self.dim[self.axis] = int(idx_end - idx_start)
            if self.axis == 0:
                self.range_crop_y = slice(0, self.dim[self.axis])
                obj[idx_slice, :, :] = self._rotate_3d(
                    obj[idx_slice, :, :], alpha, beta, shear_phase_1,
                    shear_phase_2)
            elif self.axis == 1:
                self.range_crop_x = slice(0, self.dim[self.axis])
                obj[:, idx_slice, :] = self._rotate_3d(
                    obj[:, idx_slice, :], alpha, beta, shear_phase_1,
                    shear_phase_2)
            elif self.axis == 2:
                self.range_crop_z = slice(0, self.dim[self.axis])
                obj[:, :, idx_slice] = self._rotate_3d(
                    obj[:, :, idx_slice], alpha, beta, shear_phase_1,
                    shear_phase_2)
            # Reset the buffer before the next tile.
            self.obj_rotate[:] = 0.0

        # Restore the full extent along the tiled axis and drop the buffer.
        self.dim[self.axis] = obj.shape[self.axis]
        self.obj_rotate = None
        if flag_input_on_cpu:
            obj = obj.cpu()
        return obj
    def __init__(self,
                 shape,
                 axis=0,
                 pad=True,
                 pad_value=0,
                 dtype=torch.float32,
                 device=torch.device('cuda')):
        """
        Precompute sizes, crop ranges, coordinate grids and phase ramps.

        Args:
            shape: size of the volume; indices 0, 1, 2 are used as y, x, z.
            axis: axis (0, 1 or 2) that is left unpadded and along which the
                work is tiled.
            pad: when True, every axis except `axis` is padded on both sides
                by ceil(size / 2).
            pad_value: stored as self.pad_value.
            dtype: dtype requested for the 1-D grids; stored as self.dtype.
            device: stored as self.device.
        """
        self.dim = np.array(shape)
        self.axis = axis
        self.pad_value = pad_value
        if not pad:
            self.pad_size = np.asarray([0, 0, 0])
        else:
            self.pad_size = np.ceil(self.dim / 2.0).astype('int')
            self.pad_size[self.axis] = 0
            self.dim += 2 * self.pad_size

        # Plain Python ints from here on.
        self.dim = [int(size) for size in self.dim]

        # Location of the unpadded volume inside the padded one.
        pad_y = self.pad_size[0]
        self.range_crop_y = slice(pad_y, pad_y + shape[0])
        pad_x = self.pad_size[1]
        self.range_crop_x = slice(pad_x, pad_x + shape[1])
        pad_z = self.pad_size[2]
        self.range_crop_z = slice(pad_z, pad_z + shape[2])

        # 1-D grids given singleton axes so that y varies along axis 0,
        # x along axis 1 and z along axis 2.
        grid_y = generate_grid_1d(self.dim[0], dtype=dtype)
        self.y = grid_y.unsqueeze(-1).unsqueeze(-1)
        grid_x = generate_grid_1d(self.dim[1], dtype=dtype)
        self.x = grid_x.unsqueeze(0).unsqueeze(-1)
        grid_z = generate_grid_1d(self.dim[2], dtype=dtype)
        self.z = grid_z.unsqueeze(0).unsqueeze(0)

        grid_ky = generate_grid_1d(self.dim[0], flag_fourier=True, dtype=dtype)
        self.ky = grid_ky.unsqueeze(-1).unsqueeze(-1)
        grid_kx = generate_grid_1d(self.dim[1], flag_fourier=True, dtype=dtype)
        self.kx = grid_kx.unsqueeze(0).unsqueeze(-1)
        grid_kz = generate_grid_1d(self.dim[2], flag_fourier=True, dtype=dtype)
        self.kz = grid_kz.unsqueeze(0).unsqueeze(0)

        # Compute FFTs sequentially if the object size is too large: cap the
        # number of slices per tile using MAX_DIM.
        size_along_axis = self.dim[self.axis]
        tile_limit = np.floor(MAX_DIM * size_along_axis / np.prod(self.dim))
        self.slice_per_tile = int(np.min([tile_limit, size_along_axis]))
        self.dtype = dtype
        self.device = device

        # (frequency grid, spatial grid) pairs for the two phase ramps,
        # indexed by axis. Only references are stored here; the products are
        # formed for the selected axis alone.
        phase_operands = (
            (self.kz, self.x, self.kx, self.z),
            (self.kz, self.y, self.ky, self.z),
            (self.kx, self.y, self.ky, self.x),
        )
        for axis_id, (freq_1, coord_1, freq_2, coord_2) in enumerate(
                phase_operands):
            if self.axis == axis_id:
                self.coord_phase_1 = op.r2c(-2.0 * np.pi * freq_1 * coord_1)
                self.coord_phase_2 = op.r2c(-2.0 * np.pi * freq_2 * coord_2)
                break
Пример #6
0
	def compute_prox(self, x):
		"""
		Apply the proximal step to x and return it on x's original device.

		x is indexed as [..., 0] for the real plane and [..., 1] for the
		imaginary plane. Depending on the flags, only the real part, only the
		imaginary part, only the amplitude or only the phase is processed
		with self._compute_prox_real; with no flag set, the real and
		imaginary planes are processed independently.

		When self.parameter_list is set, the entry at self.itr_count is
		installed through self.set_parameter first. self.itr_count is
		incremented on every call.
		"""
		if self.parameter_list is not None:
			self.set_parameter(self.parameter_list[self.itr_count])
		x_device = x.device
		x = x.to(device=self.device)
		if self.pure_real:
			x[...,0] = self._compute_prox_real(op.real(x), self.realProjector)
			x[...,1] = 0.0
		elif self.pure_imag:
			x[...,0] = 0.0
			# Write the projected imaginary plane directly. The original
			# assigned multiply_complex(_j, r2c(...)), a complex tensor with
			# an extra trailing axis, into the single plane x[...,1].
			x[...,1] = self._compute_prox_real(op.imag(x), self.imagProjector)
		elif self.pure_amplitude:
			x[...,0] = self._compute_prox_real(op.abs(x), self.realProjector)
			x[...,1] = 0.0
		elif self.pure_phase:
			x = op.exp(op.multiply_complex(op._j, op.r2c(self._compute_prox_real(op.angle(x), self.realProjector))))
		else:
			x[...,0] = self._compute_prox_real(op.real(x), self.realProjector)
			# The / 1.0 and * 1.0 factors leave the value unchanged; the
			# calls are kept so set_parameter still runs around the
			# imaginary-part step as before.
			self.set_parameter(self.parameter / 1.0, self.maxitr)
			x[...,1] = self._compute_prox_real(op.imag(x), self.imagProjector)
			self.set_parameter(self.parameter * 1.0, self.maxitr)
		self.itr_count += 1
		return x.to(x_device)
Пример #7
0
def generate_hard_pupil(shape, pixel_size, numerical_aperture, wavelength, \
                   dtype=torch.float32, device=torch.device('cuda')):
    """
    Build a circular (hard-edged) pupil from shape, pixel_size, NA and wavelength.

    The pupil is 1 where kx**2 + ky**2 <= (numerical_aperture / wavelength)**2
    on the Fourier grid returned by util.generate_grid_2d, and 0 elsewhere.
    The mask is cast to `dtype` and returned wrapped with op.r2c.
    """
    assert len(shape) == 2, "pupil should be two dimensional!"
    ky_lin, kx_lin = util.generate_grid_2d(
        shape, pixel_size, flag_fourier=True, dtype=dtype, device=device)

    cutoff = numerical_aperture / wavelength
    radius_squared = kx_lin**2 + ky_lin**2
    inside_cutoff = radius_squared <= cutoff**2
    return op.r2c(inside_cutoff.type(dtype))
Пример #8
0
 def forward(self, field, shift=None):
     """
     Apply the estimated per-image shift to a copy of `field`.

     Input parameters:
         - field: refocused field, before cropping
         - shift: estimated shift [y_shift, x_shift], default None (shift estimation off)

     Returns `field` itself when `shift` is None; otherwise a clone in which
     each image along axis 2 is replaced by the output of complex_conv with
     the kernel exp(1j * 2*pi * (kx_lin * x_shift + ky_lin * y_shift)).
     """
     if shift is None:
         return field
     shifted_field = field.clone()
     num_images = field.shape[2]
     for idx in range(num_images):
         dy = shift[0, idx]
         dx = shift[1, idx]
         phase = 2 * np.pi * (self.kx_lin * dx + self.ky_lin * dy)
         # NOTE(review): the `compelx_mul` spelling is kept exactly as found;
         # confirm it matches the name this module actually imports.
         kernel = complex_exp(compelx_mul(op._j, op.r2c(phase)))
         shifted_field[..., idx, :] = complex_conv(field[..., idx, :],
                                                   kernel, 2, True)
     return shifted_field
    def run(self, obj_init=None, forward_only=False, callback=None):
        """
        Run the tomography solver.

        Args:
            obj_init: initial object. When None, a zero object of self.shape
                is created on the GPU. A 3-dimensional input is wrapped with
                op.r2c. Required (asserted) when forward_only is True.
            forward_only: True  -- only runs the forward model on the
                                   supplied object
                          False -- runs the reconstruction
            callback: optional callable invoked once per iteration as
                callback(self.obj.cpu().detach(), error).

        Returns:
            forward_only=True: the estimated amplitudes of every dataloader
                item from the first pass, concatenated along a new last axis.
            forward_only=False: (self.obj.cpu().detach(), error), where
                `error` holds the accumulated cost of each iteration.
        """
        if forward_only:
            assert obj_init is not None
            # Keep the dataset order so outputs line up with the inputs.
            self.shuffle = False
            amplitude_list = []

        self.dataloader = DataLoader(self.dataset,
                                     batch_size=1,
                                     shuffle=self.shuffle)

        # Accumulated cost per outer iteration.
        error = []
        #initialize object
        self.obj = obj_init
        if self.obj is None:
            self.obj = op.r2c(torch.zeros(self.shape).cuda())
        else:
            if not self.obj.is_cuda:
                self.obj = self.obj.cuda()
            if len(self.obj.shape) == 3:
                self.obj = op.r2c(self.obj)

        #initialize shift parameters
        self.yx_shifts = None
        if self.shift_align:
            self.sa_pixel_count = []
            self.yx_shift_all = []
            self.yx_shifts = torch.zeros(
                (2, self.num_defocus, self.num_rotation))

        # initialize in-plane transform parameters
        if self.transform_align:
            self.xy_transform_all = []
            self.xy_transforms = torch.zeros(
                (6, self.num_defocus, self.num_rotation))
#			self.xy_transforms = torch.zeros((3, self.num_defocus, self.num_rotation))
# TEMPP
# defocus_list_grad = torch.zeros((self.num_defocus, self.num_rotation), dtype = torch.float32)
        # Index of the first item seen whose rotation angle is within 1e-2 of
        # zero; alignment estimates are skipped for that item.
        ref_rot_idx = None
        #begin iteration
        for itr_idx in range(self.optim_max_itr):
            sys.stdout.flush()
            running_cost = 0.0
            #defocus_list_grad[:] = 0.0
            if self.shift_align and itr_idx in self.sa_iterations:
                running_sa_pixel_count = 0.0
            for data_idx, data in enumerate(self.dataloader, 0):
                #parse data
                if not forward_only:
                    amplitudes, rotation_angle, defocus_list, rotation_idx = data
                    if ref_rot_idx is None and abs(rotation_angle -
                                                   0.0) < 1e-2:
                        ref_rot_idx = rotation_idx
                        print("reference index is:", ref_rot_idx)
                    amplitudes = torch.squeeze(amplitudes)
                    # squeeze may have removed a singleton last axis; put one
                    # back so amplitudes has at least three dimensions
                    if len(amplitudes.shape) < 3:
                        amplitudes = amplitudes.unsqueeze(-1)

                else:
                    # only the last three entries of the item are used here
                    rotation_angle, defocus_list, rotation_idx = data[-3:]
                #prepare tilt specific parameters
                defocus_list = torch.flatten(defocus_list).cuda()
                rotation_angle = rotation_angle.item()
                yx_shift = None
                if self.shift_align and self.sa_method == "gradient" and itr_idx in self.sa_iterations:
                    yx_shift = self.yx_shifts[:, :, rotation_idx]
                    yx_shift = yx_shift.cuda()
                    yx_shift.requires_grad_()
                if self.defocus_refine and self.dr_method == "gradient" and itr_idx in self.dr_iterations:
                    defocus_list.requires_grad_()
                #rotate object
                # previous_angle is first assigned at the end of this loop
                # body; the data_idx == 0 branch is taken before it is read.
                if data_idx == 0:
                    self.obj = self.rotation_obj.forward(
                        self.obj, rotation_angle)
                else:
                    if abs(rotation_angle - previous_angle) > 90:
                        # more than 90 apart: undo the previous rotation,
                        # then apply the new angle in a second step
                        self.obj = self.rotation_obj.forward(
                            self.obj, -1 * previous_angle)
                        self.obj = self.rotation_obj.forward(
                            self.obj, rotation_angle)
                    else:
                        # otherwise rotate by the difference in one step
                        self.obj = self.rotation_obj.forward(
                            self.obj, rotation_angle - previous_angle)
                if not forward_only:
                    #define optimizer
                    # a new SGD optimizer is built for every item, holding
                    # only the parameters enabled for this iteration
                    optimizer_params = []

                    if itr_idx in self.obj_update_iterations:
                        self.obj.requires_grad_()
                        optimizer_params.append({
                            'params': self.obj,
                            'lr': self.optim_step_size
                        })
                    if self.shift_align and self.sa_method == "gradient" and itr_idx in self.sa_iterations:
                        optimizer_params.append({
                            'params': yx_shift,
                            'lr': self.sa_step_size
                        })
                    if self.defocus_refine and self.dr_method == "gradient" and itr_idx in self.dr_iterations:
                        optimizer_params.append({
                            'params': defocus_list,
                            'lr': self.dr_step_size
                        })
                    optimizer = optim.SGD(optimizer_params)

                #forward scattering
                estimated_amplitudes = self.tomography_obj(
                    self.obj, defocus_list, yx_shift)
                #in-plane rotation estimation
                if not forward_only:
                    if self.transform_align and itr_idx in self.ta_iterations:
                        if rotation_idx != ref_rot_idx:
                            amplitudes, xy_transform = self.transform_obj.estimate(
                                estimated_amplitudes, amplitudes)
                            xy_transform = xy_transform.unsqueeze(-1)
#						self.dataset.update_amplitudes(amplitudes, rotation_idx)
#Correlation based shift estimation
                    if self.shift_align and shift.is_correlation_method(
                            self.sa_method) and itr_idx in self.sa_iterations:
                        if rotation_idx != ref_rot_idx:
                            amplitudes, yx_shift, _ = self.shift_obj.estimate(
                                estimated_amplitudes, amplitudes)
                            yx_shift = yx_shift.unsqueeze(-1)


#						self.dataset.update_amplitudes(amplitudes, rotation_idx)
                    # on the final iteration, write the current amplitudes
                    # back to the dataset
                    if itr_idx == self.optim_max_itr - 1:
                        print("Last iteration: updated amplitudes")
                        self.dataset.update_amplitudes(amplitudes,
                                                       rotation_idx)

                    #compute cost
                    cost = self.cost_function(estimated_amplitudes,
                                              amplitudes.cuda())
                    running_cost += cost.item()

                    #backpropagation
                    cost.backward()
                    #update object
                    # if itr_idx >= self.dr_start_iteration:
                    # 	# print(torch.norm(defocus_list.grad.data))
                    # 	defocus_list_grad[:,data_idx] = defocus_list.grad.data *  self.dr_step_size
                    optimizer.step()
                    optimizer.zero_grad()
                    del cost
                else:
                    #store measurement
                    amplitude_list.append(estimated_amplitudes.cpu().detach())
                del estimated_amplitudes
                # turn gradient tracking off again before the next rotation
                self.obj.requires_grad = False
                if not forward_only:
                    #keep track of shift alignment for the tilt
                    if self.shift_align and itr_idx in self.sa_iterations:
                        if yx_shift is not None:
                            yx_shift.requires_grad = False
                            if rotation_idx != ref_rot_idx:
                                self.yx_shifts[:, :,
                                               rotation_idx] = yx_shift[:].cpu(
                                               )
                                running_sa_pixel_count += torch.sum(
                                    torch.abs(yx_shift.cpu().flatten()))

                    #keep track of transform alignment for the tilt
                    if self.transform_align and itr_idx in self.ta_iterations:
                        if rotation_idx != ref_rot_idx:
                            self.xy_transforms[
                                ..., rotation_idx] = xy_transform[:].cpu()

                    #keep track of defocus alignment for the tilt
                    if self.defocus_refine and itr_idx in self.dr_iterations:
                        defocus_list.requires_grad = False
                        self.dataset.update_defocus_list(
                            defocus_list[:].cpu().detach(), rotation_idx)

                previous_angle = rotation_angle

                #rotate object back
                # after the last item, undo its rotation and reset the
                # remembered angle to zero
                if data_idx == (self.dataset.__len__() - 1):
                    previous_angle = 0.0
                    self.obj = self.rotation_obj.forward(
                        self.obj, -1.0 * rotation_angle)
                print("Rotation {:03d}/{:03d}.".format(data_idx + 1,
                                                       self.dataset.__len__()),
                      end="\r")

            #apply regularization
            # drop the reference to the last amplitudes before emptying the
            # CUDA cache
            amplitudes = None
            torch.cuda.empty_cache()
            if not forward_only:
                if itr_idx in self.obj_update_iterations:
                    self.obj = self.regularizer_obj.apply(self.obj)
            error.append(running_cost)

            #keep track of shift alignment results
            if self.shift_align and itr_idx in self.sa_iterations:
                self.sa_pixel_count.append(running_sa_pixel_count)
                self.yx_shift_all.append(np.array(self.yx_shifts).copy())

            #keep track of transform alignment results
            if self.transform_align and itr_idx in self.ta_iterations:
                self.xy_transform_all.append(
                    np.array(self.xy_transforms).copy())

            if callback is not None:
                callback(self.obj.cpu().detach(), error)
                #TEMPPPPP
                # callback(defocus_list_grad, self.dataset.get_all_defocus_lists(), error)
            # forward-only mode stops after the first pass and returns the
            # collected amplitudes stacked along a new last axis
            if forward_only and itr_idx == 0:
                return torch.cat([
                    torch.unsqueeze(amplitude_list[idx], -1)
                    for idx in range(len(amplitude_list))
                ],
                                 axis=-1)
            # the printed error is log10 of the accumulated cost
            print("Iteration {:03d}/{:03d}. Error: {:03f}".format(
                itr_idx + 1, self.optim_max_itr, np.log10(running_cost)))

        self.defocus_list = self.dataset.get_all_defocus_lists()
        return self.obj.cpu().detach(), error
Пример #10
0
	def compute_prox(self, x):
		"""Return op.exp(1j * op.angle(x)), i.e. x rebuilt from its phase alone."""
		phase = op.angle(x)
		phasor = op.exp(op.multiply_complex(op._j, op.r2c(phase)))
		return phasor