def deconvolve_single_gpu(vol_a: Volume, n: int, psf_a: np.ndarray) -> Volume:
    """
    Perform Richardson-Lucy deconvolution of a single volume with a single PSF on the GPU

    The computation runs on arrayfire in float32. Each iteration updates
    ``estimate <- estimate * conv(view / (conv(estimate, psf) + 1), psf_mirrored)``,
    where ``conv`` is ``af.fft_convolve3`` and ``psf_mirrored`` is the PSF flipped along every axis.

    :param vol_a: The volume to deconvolve
    :param n: The number of Richardson-Lucy iterations
    :param psf_a: The PSF of the volume; it is normalized to unit sum before use
    :return: The deconvolved volume as uint16 (values clipped to [0, 65535]), with vol_a's spacing
    :raises ValueError: If the PSF does not have a positive sum
    """
    from dispim.metrics import DECONV_MSE_DELTA
    import arrayfire as af

    logger.debug(f'Deconvolving volume of shape {vol_a.shape}')

    psf_sum = float(np.sum(psf_a, dtype=np.float64))
    if not psf_sum > 0:
        raise ValueError(f'PSF must have a positive sum, got {psf_sum}')

    # Convert to float32 on the host (np.float was removed in NumPy 1.24). Uploading float32
    # directly also avoids requiring double-precision support on the device.
    view_np = np.ascontiguousarray(vol_a, dtype=np.float32)
    psf_np = (np.asarray(psf_a, dtype=np.float64) / psf_sum).astype(np.float32)
    # The RL correction step convolves with the PSF mirrored along every axis
    psf_mirrored_np = np.ascontiguousarray(psf_np[::-1, ::-1, ::-1])

    view = af.from_ndarray(view_np)
    psf = af.from_ndarray(psf_np)
    psf_mirrored = af.from_ndarray(psf_mirrored_np)

    track_mse = metrack.is_tracked(DECONV_MSE_DELTA)
    estimate = view
    with progressbar.ProgressBar(max_value=n, redirect_stderr=True) as bar:
        for i in bar(range(n)):
            if track_mse:
                prev = estimate
            # The +1 in the denominator guards against dividing by (near) zero where the
            # blurred estimate is dark.
            estimate = estimate * af.fft_convolve3(
                view / (af.fft_convolve3(estimate, psf) + 1), psf_mirrored)
            if track_mse:
                # af.mean over the whole array yields a host scalar, so only one number leaves
                # the device instead of the full volume.
                mse = float(af.mean((prev - estimate) ** 2))
                metrack.append_metric(DECONV_MSE_DELTA, (i, mse))

    # Remove the progress bar's leftover line from the terminal
    CURSOR_UP_ONE = '\x1b[1A'
    ERASE_LINE = '\x1b[2K'
    print(CURSOR_UP_ONE + ERASE_LINE + CURSOR_UP_ONE)

    result = estimate.to_ndarray()
    del estimate

    # Statistics are computed on the host copy that already exists, avoiding extra device transfers
    logger.debug(
        f'Deconved min: {np.min(result)}, max: {np.max(result)}, has nan: {np.any(np.isnan(result))}'
    )

    # Clip before casting: float values outside the uint16 range would otherwise wrap or be undefined
    result = np.clip(result, 0, np.iinfo(np.uint16).max)
    return Volume(result.astype(np.uint16), inverted=False, spacing=vol_a.spacing, is_skewed=False)
def extract_psf(vol: Volume, min_size: int = 5, max_size: int = 30, psf_half_width: int = 10) -> np.ndarray:
    """
    Attempt to extract the PSF from a volume by looking at small objects that are representative
    of the volume's PSF

    The volume is thresholded with Otsu's method. Connected components whose size lies in
    [min_size, max_size] are treated as point-like sources. At most 12000 of them are chosen at
    random, and a cube of half-width psf_half_width is cut out around each (integer-truncated)
    centroid. The voxel-wise median of those cubes is returned.

    :param vol: The volume to extract the PSF from
    :param min_size: The minimum size (regionprops ``area``, i.e. voxel count) of objects to consider
    :param max_size: The maximum size (regionprops ``area``, i.e. voxel count) of objects to consider
    :param psf_half_width: The half-width of the PSF in all axes
    :return: The estimated PSF, shape = (psf_half_width*2+1, ) * 3
    :raises ValueError: If no suitable objects are found
    """
    from skimage.measure import label, regionprops
    from dispim.util import extract_3d, threshold_otsu
    from dispim.metrics import PSF_SIGMA_XY, PSF_SIGMA_Z

    max_points = 12000
    expected_shape = (2 * psf_half_width + 1,) * 3

    data = vol
    thr = threshold_otsu(data)
    data_bin = data > thr
    # np.int was removed in NumPy 1.24; the builtin int is what it aliased
    points = np.array([
        np.array(r.centroid, dtype=int)
        for r in regionprops(label(data_bin))
        if min_size <= r.area <= max_size
    ])
    logger.debug(f'Found {len(points)} objects')
    if len(points) == 0:
        raise ValueError(
            f'No objects with a size between {min_size} and {max_size} voxels were found; '
            f'cannot estimate the PSF')

    # Randomly select (at most max_points) objects; when there are fewer, this just shuffles them
    points = points[np.random.choice(len(points), min(len(points), max_points), replace=False), :]

    track_sigma = metrack.is_tracked(PSF_SIGMA_XY) or metrack.is_tracked(PSF_SIGMA_Z)
    blob_images = []
    for point in points:
        blob = extract_3d(data, point, psf_half_width)
        # np.median below needs equally shaped cut-outs; skip any that do not have the full PSF
        # shape (e.g. if extract_3d returns a cropped cube near the volume boundary — confirm).
        if np.shape(blob) != expected_shape:
            continue
        blob_images.append(blob)

        if track_sigma:
            # Fit a 2D Gaussian to the central slice along axis 0 of the cut-out
            height, center_x, center_y, width_x, width_y, rotation = fitgaussian(
                blob[psf_half_width, :, :])
            # NOTE(review): scaling a pixel width by vol.shape[0] (a voxel count, not a spacing)
            # looks suspicious — confirm the intended scale factor.
            scale = vol.shape[0]
            # The wider of the two fitted widths is recorded as Z, the narrower as XY
            if width_x > width_y:
                sigma_z, sigma_xy = width_x, width_y
            else:
                sigma_z, sigma_xy = width_y, width_x
            metrack.append_metric(PSF_SIGMA_Z, (None, sigma_z * scale))
            metrack.append_metric(PSF_SIGMA_XY, (None, sigma_xy * scale))

    if not blob_images:
        raise ValueError(f'No object cut-outs of shape {expected_shape} could be extracted')

    median_blob = np.median(blob_images, axis=0)
    logger.debug(
        f'PSF mean: {median_blob.mean()}, median: {np.median(median_blob)}, min: {median_blob.min()}, max: {median_blob.max()}'
    )
    return median_blob
def process(self, data: ProcessData, path: str, save_intermediate=False) -> ProcessData:
    """
    Run every step in ``self.steps`` on ``data``, in order

    Before each step runs, the current data's arity is checked against the step's
    ``accepts_single`` / ``accepts_dual`` flags. The data is treated as dual when ``len(data) == 2``
    and as single when ``len(data) == 1``. The elapsed time of each step is recorded under the
    PROCESS_TIME metric.

    :param data: The input data
    :param path: Passed as ``path`` to ``save_tiff`` when intermediate results are saved
    :param save_intermediate: If True, save each step's output via ``save_tiff``, named
                              ``<StepClass>_A`` (and ``<StepClass>_B`` for a second volume)
    :return: The output of the last step
    :raises ValueError: If a step cannot accept the input data or the previous step's output
    """
    import gc
    import time
    from dispim.metrics import PROCESS_TIME
    from dispim import metrack

    with metrack.Context('Processor'):
        for i, step in enumerate(self.steps):
            step_name = step.__class__.__name__
            # TODO: Check this BEFORE processing...
            logger.info("Performing step {} on {} data".format(step_name, "dual" if len(data) == 2 else "single"))
            if ((not step.accepts_dual and len(data) == 2) or
                    (not step.accepts_single and len(data) == 1)):
                if i > 0:
                    raise ValueError('Step {} is incompatible with the output of step {}'
                                     .format(step_name, self.steps[i - 1].__class__.__name__))
                else:
                    raise ValueError("Step {} is incompatible with the input data"
                                     .format(step_name))

            # perf_counter is monotonic, unlike time.time(), so durations cannot be skewed by
            # system clock adjustments
            start = time.perf_counter()
            with metrack.Context(f'{step_name} ({i})'):
                data = step.process(data)
            elapsed = time.perf_counter() - start
            metrack.append_metric(PROCESS_TIME, (step_name, elapsed))

            # Release memory held by the previous step's intermediates before the next step
            gc.collect()

            if save_intermediate:
                data[0].save_tiff(step_name + "_A", path=path)
                if len(data) > 1:
                    data[1].save_tiff(step_name + "_B", path=path)

    return data
def callback(value: float, gradient: float):
    """
    Record one (value, gradient) sample of the mutual information metric

    The value is appended to MUTUAL_INFORMATION_METRIC first, then the gradient to
    MUTUAL_INFORMATION_GRADIENT_METRIC. Both samples use ``None`` as their key.
    """
    record = metrack.append_metric
    record(MUTUAL_INFORMATION_METRIC, (None, value))
    record(MUTUAL_INFORMATION_GRADIENT_METRIC, (None, gradient))