def correct_dot_masks(masks, gain_map, excluded_pixels=None, allow_empty=False):
    """
    Fold a gain map and the repair of excluded pixels into a mask stack.

    The corrected masks are meant to be applied to dark-subtracted raw data:
    the dot product of such data with the corrected masks reproduces the dot
    product of gain-corrected data (with excluded pixels interpolated from
    their repair neighbors) with the original masks.

    Parameters
    ----------
    masks : array-like, dense or sparse
        Mask stack. It is flattened to ``(-1, prod(gain_map.shape))``, so its
        trailing dimensions must match the signal shape of ``gain_map``.
    gain_map : numpy.ndarray
        Per-pixel gain factor; its shape defines the signal shape.
    excluded_pixels : optional
        Coordinates of the pixels to exclude, passed on to
        ``RepairDescriptor``. ``None`` skips the repair and only applies the
        gain map.
    allow_empty : bool
        Passed on to ``RepairDescriptor``. Presumably this permits excluded
        pixels without any repair neighbor; such pixels are simply zeroed.

    Returns
    -------
    Corrected masks with the same shape as ``masks``. When pixels are
    excluded and ``masks`` is sparse, the result is a ``sparse.COO`` array.
    When pixels are excluded, integer or boolean inputs are promoted to a
    floating dtype so that the fractional repair weights are not truncated.
    """
    mask_shape = masks.shape
    sig_shape = gain_map.shape
    masks = masks.reshape((-1, prod(sig_shape)))

    if excluded_pixels is None:
        result = masks
    else:
        # Repair weights are fractional (``masks[:, e] / c``). Accumulate them
        # in a floating dtype, because integer or boolean masks would silently
        # truncate them on assignment. Promoting together with the gain map
        # keeps the dtype of the final product unchanged for float gain maps.
        work_dtype = np.result_type(masks.dtype, gain_map.dtype)
        if not np.issubdtype(work_dtype, np.inexact):
            work_dtype = np.result_type(work_dtype, np.float64)

        if is_sparse(masks):
            if masks.dtype != work_dtype:
                masks = masks.astype(work_dtype)
            result = sparse.DOK(masks)
        else:
            # ndarray.astype() always returns a copy, so ``masks`` stays intact
            result = masks.astype(work_dtype)

        desc = RepairDescriptor(sig_shape, excluded_pixels=excluded_pixels, allow_empty=allow_empty)
        dense = not is_sparse(result)
        for e, r, c in zip(desc.exclude_flat, desc.repair_flat, desc.repair_counts):
            result[:, e] = 0
            if c == 0:
                # No repair neighbors: nothing to redistribute, and dividing
                # by zero would only produce warnings and nan/inf values
                continue
            targets = r[:c]
            rep = masks[:, e] / c
            if dense:
                # Unbuffered add: repeated target indices accumulate exactly
                # like they do in the element-wise loop below
                np.add.at(result, (slice(None), targets), rep[:, np.newaxis])
            else:
                # We have to loop because of sparse.pydata limitations
                for m in range(result.shape[0]):
                    for rr in targets:
                        result[m, rr] = result[m, rr] + rep[m]
        if not dense:
            result = sparse.COO(result)

    result = result * gain_map.flatten()
    return result.reshape(mask_shape)
def _compute_masks(self):
    """
    Call the mask factories and combine their output into one mask stack.

    ``self.mask_factories`` is either a single callable that returns the
    whole stack at once, or an iterable of callables that each return one
    mask. Single masks are stacked along a new first axis.

    All masks are converted uniformly to dense or sparse storage. If
    ``self._use_sparse`` is already set to anything other than ``None``,
    that setting takes precedence. If it is ``None``, sparse storage is
    chosen only if every factory returned a sparse mask, and
    ``self._use_sparse`` is updated accordingly.

    Returns
    -------
    masks
        One array holding the whole stack, concatenated along the first
        axis. It is a ``sparse.pydata`` array if ``self.use_sparse`` is
        truthy, and a dense ``numpy.ndarray`` otherwise.

    Raises
    ------
    ValueError
        If ``self.mask_factories`` is an empty iterable.
    """
    # Stays sparse unless at least one factory returns a dense mask
    default_sparse = 'scipy.sparse'
    if callable(self.mask_factories):
        raw_masks = self.mask_factories()
        if not is_sparse(raw_masks):
            default_sparse = False
        mask_slices = [raw_masks]
    else:
        mask_slices = []
        for factory in self.mask_factories:
            m = factory()
            # scipy.sparse is always 2D, so convert to sparse.pydata before
            # adding the stack axis
            if scipy.sparse.issparse(m):
                m = sparse.COO.from_scipy_sparse(m)
            # A stack of height 1, so that all slices concatenate uniformly
            m = m.reshape((1, ) + m.shape)
            if not is_sparse(m):
                default_sparse = False
            mask_slices.append(m)
        if not mask_slices:
            # Fail early with a clear message, before mutating _use_sparse
            raise ValueError(
                "mask_factories is empty, at least one mask factory is required"
            )

    if self._use_sparse is None:
        self._use_sparse = default_sparse

    if self.use_sparse:
        # Conversion to the correct back-end happens later. Use sparse.pydata
        # here because it implements the array interface, which makes mask
        # handling easier.
        masks = sparse.concatenate([to_sparse(m) for m in mask_slices])
    else:
        masks = np.concatenate([to_dense(m) for m in mask_slices])
    return masks
def test_uses_sparse_mixed_default(lt_ctx):
    """Mixing a scipy.sparse mask with a dense mask must yield dense masks."""
    frames = _mk_random(size=(16, 16, 16, 16), dtype="<u2")
    sparse_mask = sp.csr_matrix(_mk_random(size=(16, 16)))
    dense_mask = _mk_random(size=(16, 16))

    def sparse_factory():
        return sparse_mask

    def dense_factory():
        return dense_mask

    ds = MemoryDataSet(
        data=frames,
        tileshape=(16, 4, 4),
        num_partitions=2,
    )
    analysis = lt_ctx.create_mask_analysis(
        dataset=ds,
        factories=[sparse_factory, dense_factory],
    )
    computed = _mask_from_analysis(ds, analysis)
    assert not is_sparse(computed)
def test_uses_sparse_sparse_false(lt_ctx):
    """An explicit ``use_sparse=False`` overrides all-sparse mask factories."""
    frames = _mk_random(size=(16, 16, 16, 16), dtype="<u2")
    first_mask = sparse.COO.from_numpy(_mk_random(size=(16, 16)))
    second_mask = sparse.COO.from_numpy(_mk_random(size=(16, 16)))

    def first_factory():
        return first_mask

    def second_factory():
        return second_mask

    ds = MemoryDataSet(
        data=frames,
        tileshape=(16, 4, 4),
        num_partitions=2,
    )
    analysis = lt_ctx.create_mask_analysis(
        dataset=ds,
        factories=[first_factory, second_factory],
        use_sparse=False,
    )
    computed = _mask_from_analysis(ds, analysis)
    assert not is_sparse(computed)
def test_mask_patch_overlapping():
    """
    Mask-based correction must also patch clusters of adjacent excluded pixels.

    The data is constant, so interpolating the damaged pixels from their
    neighbors reconstructs them faithfully.
    """
    for repetition in range(REPEATS):
        print(f"Loop number {repetition}")
        nav_rank = np.random.choice([1, 2, 3])
        sig_rank = np.random.choice([2, 3])
        nav_dims = tuple(np.random.randint(low=8, high=16, size=nav_rank))
        sig_dims = tuple(np.random.randint(low=8, high=16, size=sig_rank))
        nav_size = np.prod(nav_dims)
        sig_size = np.prod(sig_dims)

        # Work in float64: the mask-based correction creates numerical
        # instabilities otherwise
        data = np.ones(nav_dims + sig_dims, dtype=np.float64)
        gain_map = (np.random.random(sig_dims) + 1).astype(np.float64)
        dark_image = np.random.random(sig_dims).astype(np.float64)

        # Three neighboring excluded pixels, one per column:
        # (1, 1, ...), (2, 1, ...) and (1, 2, ...)
        exclude = np.ones((sig_rank, 3), dtype=np.int32)
        exclude[0, 1] += 1
        exclude[1, 2] += 1

        # Simulated raw data: gain undone, dark frame added, excluded pixels hot
        damaged_data = data / gain_map + dark_image
        damaged_data[(Ellipsis, *exclude)] = 1e24

        print("Nav dims: ", nav_dims)
        print("Sig dims:", sig_dims)
        print("Exclude: ", exclude)

        masks = (np.random.random((2, ) + sig_dims) - 0.5).astype(np.float64)

        data_flat = data.reshape((nav_size, sig_size))
        damaged_flat = damaged_data.reshape((nav_size, sig_size))
        correct_dot = np.dot(data_flat, masks.reshape((-1, sig_size)).T)

        corrected_masks = detector.correct_dot_masks(masks, gain_map, exclude)
        assert not is_sparse(corrected_masks)

        corrected_t = corrected_masks.reshape((-1, sig_size)).T
        reconstructed_dot = (
            np.dot(damaged_flat, corrected_t)
            - np.dot(dark_image.flatten(), corrected_t)
        )
        _check_result(
            data=correct_dot,
            corrected=reconstructed_dot,
            atol=1e-8,
            rtol=1e-5,
        )
def test_mask_patch_sparse():
    """Mask-based correction must work for sparse.pydata mask stacks."""
    for repetition in range(REPEATS):
        print(f"Loop number {repetition}")
        nav_rank = np.random.choice([1, 2, 3])
        sig_rank = np.random.choice([2, 3])
        nav_dims = tuple(np.random.randint(low=8, high=16, size=nav_rank))
        sig_dims = tuple(np.random.randint(low=8, high=16, size=sig_rank))
        nav_size = np.prod(nav_dims)
        sig_size = np.prod(sig_dims)

        # Work in float64: the mask-based correction creates numerical
        # instabilities otherwise
        data = gradient_data(nav_dims, sig_dims).astype(np.float64)
        gain_map = (np.random.random(sig_dims) + 1).astype(np.float64)
        dark_image = np.random.random(sig_dims).astype(np.float64)

        exclude = exclude_pixels(sig_dims=sig_dims, num_excluded=3)

        # Simulated raw data: gain undone, dark frame added, excluded pixels hot
        damaged_data = data / gain_map + dark_image
        damaged_data[(Ellipsis, *exclude)] = 1e24

        print("Nav dims: ", nav_dims)
        print("Sig dims:", sig_dims)
        print("Exclude: ", exclude)

        # Sparse stack of 20 masks with ones at random coordinates
        stack_shape = (20, ) + sig_dims
        builder = sparse.DOK(sparse.zeros(stack_shape, dtype=np.float64))
        coords = [np.random.randint(low=0, high=s, size=s//2) for s in stack_shape]
        for point in zip(*coords):
            builder[point] = 1
        masks = builder.to_coo()

        data_flat = data.reshape((nav_size, sig_size))
        damaged_flat = damaged_data.reshape((nav_size, sig_size))
        correct_dot = sparse.dot(data_flat, masks.reshape((-1, sig_size)).T)

        corrected_masks = detector.correct_dot_masks(masks, gain_map, exclude)
        assert is_sparse(corrected_masks)

        corrected_t = corrected_masks.reshape((-1, sig_size)).T
        reconstructed_dot = (
            sparse.dot(damaged_flat, corrected_t)
            - sparse.dot(dark_image.flatten(), corrected_t)
        )
        _check_result(
            data=correct_dot,
            corrected=reconstructed_dot,
            atol=1e-8,
            rtol=1e-5,
        )
def _get_masks_for_slice(slice_):
    """
    Return the part of ``computed_masks`` selected by ``slice_``, ready for use.

    The signal part selected by ``slice_`` is flattened per mask to shape
    ``(stack_height, -1)``, and transposed if ``transpose`` is set. Sparse
    results are built by ``_build_sparse``. Dense results are cast to
    ``dtype`` and returned as a numpy array (``backend == 'numpy'``) or a
    cupy array (``backend == 'cupy'``).

    Raises
    ------
    ValueError
        If ``backend`` is not supported for dense masks. Previously the
        function silently returned ``None`` in that case.
    """
    stack_height = computed_masks.shape[0]
    m = slice_.get(computed_masks, sig_only=True)
    # We need the mask's signal dimension flattened
    m = m.reshape((stack_height, -1))
    if transpose:
        # We need the stack transposed in the next step
        m = m.T

    if is_sparse(m):
        return _build_sparse(m, dtype, sparse_backend, backend)
    if backend == 'numpy':
        return m.astype(dtype)
    if backend == 'cupy':
        # Avoid importing if possible
        import cupy
        return cupy.array(m.astype(dtype))
    raise ValueError(f"Backend {backend!r} is not supported for dense masks.")
def test_mask_correction():
    """Mask-based gain and dark correction without any excluded pixels."""
    for repetition in range(REPEATS):
        print(f"Loop number {repetition}")
        nav_rank = np.random.choice([1, 2, 3])
        sig_rank = np.random.choice([2, 3])
        nav_dims = tuple(np.random.randint(low=8, high=16, size=nav_rank))
        sig_dims = tuple(np.random.randint(low=8, high=16, size=sig_rank))
        nav_size = np.prod(nav_dims)
        sig_size = np.prod(sig_dims)

        # Work in float64: the mask-based correction creates numerical
        # instabilities otherwise
        data = gradient_data(nav_dims, sig_dims).astype(np.float64)
        gain_map = (np.random.random(sig_dims) + 1).astype(np.float64)
        dark_image = np.random.random(sig_dims).astype(np.float64)

        exclude = exclude_pixels(sig_dims=sig_dims, num_excluded=0)

        # Simulated raw data: gain undone and dark frame added
        damaged_data = data / gain_map + dark_image
        # Sanity check: the conventional correction restores the data
        assert np.allclose((damaged_data - dark_image) * gain_map, data)

        print("Nav dims: ", nav_dims)
        print("Sig dims:", sig_dims)
        print("Exclude: ", exclude)

        masks = (np.random.random((2, ) + sig_dims) - 0.5).astype(np.float64)

        data_flat = data.reshape((nav_size, sig_size))
        damaged_flat = damaged_data.reshape((nav_size, sig_size))
        correct_dot = np.dot(data_flat, masks.reshape((-1, sig_size)).T)

        corrected_masks = detector.correct_dot_masks(masks, gain_map, exclude)
        assert not is_sparse(corrected_masks)

        corrected_t = corrected_masks.reshape((-1, sig_size)).T
        reconstructed_dot = (
            np.dot(damaged_flat, corrected_t)
            - np.dot(dark_image.flatten(), corrected_t)
        )
        _check_result(
            data=correct_dot,
            corrected=reconstructed_dot,
            atol=1e-8,
            rtol=1e-5,
        )
def test_sparse_dok_is_sparse():
    """A sparse.pydata DOK array is recognized as sparse."""
    assert is_sparse(sparse.DOK.from_numpy(_mk_random(size=(16, 16))))
def test_scipy_is_sparse():
    """A scipy.sparse CSR matrix is recognized as sparse."""
    assert is_sparse(sp.csr_matrix(_mk_random(size=(16, 16))))
def test_numpy_is_sparse():
    """A plain dense array is not considered sparse."""
    dense = _mk_random(size=(16, 16))
    assert is_sparse(dense) is False or not is_sparse(dense)