コード例 #1
0
ファイル: test_detector.py プロジェクト: LiberTEM/LiberTEM
def test_mask_patch_sparse():
    """Check sparse mask correction against an undamaged reference.

    For REPEATS random nav/sig shapes: build float64 gradient data, make a
    damaged copy (divided by a gain map, offset by a dark image, and with the
    pixels from ``exclude_pixels(..., num_excluded=3)`` set to 1e24), and
    build 20 random sparse binary masks. ``detector.correct_dot_masks`` must
    return sparse masks whose dot product with the damaged data, minus the
    dark image's contribution, matches the clean data dotted with the
    original masks.
    """
    for i in range(REPEATS):
        print(f"Loop number {i}")
        n_nav = np.random.choice([1, 2, 3])
        n_sig = np.random.choice([2, 3])

        nav_dims = tuple(np.random.randint(low=8, high=16, size=n_nav))
        sig_dims = tuple(np.random.randint(low=8, high=16, size=n_sig))
        nav_size = np.prod(nav_dims)
        sig_size = np.prod(sig_dims)

        # Everything runs in float64: the mask-based correction is
        # numerically unstable at lower precision.
        data = gradient_data(nav_dims, sig_dims).astype(np.float64)
        gain_map = (np.random.random(sig_dims) + 1).astype(np.float64)
        dark_image = np.random.random(sig_dims).astype(np.float64)
        exclude = exclude_pixels(sig_dims=sig_dims, num_excluded=3)

        # Damage the data: divide by the gain map, add the dark image and
        # blow up the excluded pixels so any leak into the result is obvious.
        damaged_data = data / gain_map + dark_image
        damaged_data[(..., *exclude)] = 1e24

        print("Nav dims: ", nav_dims)
        print("Sig dims:", sig_dims)
        print("Exclude: ", exclude)

        mask_shape = (20, ) + sig_dims
        masks = sparse.DOK(sparse.zeros(mask_shape, dtype=np.float64))
        # One random index array per mask axis; zip() pairs them up to the
        # shortest length and each resulting coordinate is set to 1.
        coordinate_axes = [
            np.random.randint(low=0, high=extent, size=extent // 2)
            for extent in mask_shape
        ]
        for coord in zip(*coordinate_axes):
            masks[coord] = 1
        masks = masks.to_coo()

        data_flat = data.reshape((nav_size, sig_size))
        damaged_flat = damaged_data.reshape((nav_size, sig_size))

        correct_dot = sparse.dot(data_flat, masks.reshape((-1, sig_size)).T)
        corrected_masks = detector.correct_dot_masks(masks, gain_map, exclude)
        assert is_sparse(corrected_masks)

        # Flattened, transposed corrected masks, shared by both products.
        corrected_flat_t = corrected_masks.reshape((-1, sig_size)).T
        reconstructed_dot = (
            sparse.dot(damaged_flat, corrected_flat_t)
            - sparse.dot(dark_image.flatten(), corrected_flat_t)
        )

        _check_result(
            data=correct_dot,
            corrected=reconstructed_dot,
            atol=1e-8,
            rtol=1e-5,
        )
コード例 #2
0
def dot_wl(mat, vec):
    """Apply a matrix to one vector per wavelength.

    For each wavelength index ``i1`` along the first axis of ``vec``, compute
    ``dot(M, vec[i1])``. ``M`` is ``mat[i1]`` when ``mat`` is 3-D (one matrix
    per wavelength). It is ``mat`` itself when ``mat`` is 2-D (one matrix
    shared by all wavelengths).

    Parameters
    ----------
    mat : ndarray, shape (nwl, m, n) or (m, n)
    vec : ndarray, shape (nwl, n)

    Returns
    -------
    ndarray, shape (nwl, m)
        The dtype is at least float64 and becomes complex when either input
        is complex, so imaginary parts are not discarded.

    Raises
    ------
    ValueError
        If ``mat`` is neither 2-D nor 3-D.
    """
    ndim = len(mat.shape)
    if ndim not in (2, 3):
        raise ValueError(f"mat must be 2-D or 3-D, got {ndim}-D")

    n_wl = vec.shape[0]
    # Each product has one entry per matrix row, i.e. the second-to-last axis
    # of ``mat`` for both the 3-D (nwl, m, n) and the 2-D (m, n) layout.
    out_shape = (n_wl, mat.shape[-2])
    # Promote to at least float64 (matching the old float64 buffer for real
    # inputs) while keeping complex inputs complex.
    out_dtype = np.result_type(mat.dtype, vec.dtype, np.float64)
    result = np.empty(out_shape, dtype=out_dtype)

    for i1 in range(n_wl):  # loop over wavelengths
        matrix = mat[i1] if ndim == 3 else mat
        result[i1, :] = dot(matrix, vec[i1])

    return result
コード例 #3
0
ファイル: masks.py プロジェクト: woozey/LiberTEM
 def __call__(self):
     """Apply all masks to every tile of this partition and sum the results.

     Each tile's flattened data is multiplied with
     ``self.masks.get(tile, self.mask_dtype)``. The product is computed with:

     - ``sparse.dot`` for pydata.sparse masks,
     - ``*`` for scipy.sparse matrices,
     - ``torch.mm`` when ``self.use_torch`` is set,
     - ``ndarray.dot`` otherwise.

     Products are accumulated into a buffer of shape
     ``(number of masks,) + partition nav shape``.

     Returns
     -------
     list
         A single ``MaskResultTile`` covering the nav part of the partition
         slice.
     """
     def _tile_product(flat_data, masks):
         # Choose the product implementation for this tile's masks.
         if isinstance(masks, sparse.SparseArray):
             return sparse.dot(flat_data, masks)
         if scipy.sparse.issparse(masks):
             # scipy.sparse matrices use the old matrix interface, where "*"
             # is the dot product.
             return flat_data * masks
         if self.use_torch:
             return torch.mm(
                 torch.from_numpy(flat_data),
                 torch.from_numpy(masks),
             ).numpy()
         return flat_data.dot(masks)

     mask_count = len(self.masks)
     nav_shape = tuple(self.partition.shape.nav)
     part = zeros_aligned((mask_count, ) + nav_shape, dtype=self.dtype)

     tiles = self.partition.get_tiles(mmap=True, dest_dtype=self.read_dtype)
     for tile in tiles:
         product = _tile_product(
             tile.flat_data,
             self.masks.get(tile, self.mask_dtype),
         )
         target = tile.tile_slice.shift(self.partition.slice)
         reshaped = self.reshaped_data(data=product, dest_slice=target)
         # The leading Ellipsis covers the per-mask axis of ``part``.
         part[(..., ) + target.get(nav_only=True)] += reshaped

     result_slice = self.partition.slice.get(nav_only=True)
     return [MaskResultTile(data=part, dest_slice=result_slice)]
コード例 #4
0
ファイル: masks.py プロジェクト: ozej8y/LiberTEM
 def __call__(self):
     """Apply all masks to every tile of this partition and sum the results.

     Tiles are read in the partition's dtype when that dtype is floating point
     or complex, and as float32 otherwise. The accumulation buffer
     (``(number of masks,) + partition nav shape``) uses the same dtype.
     The per-tile product uses:

     - ``sparse.dot`` when ``self.masks.use_sparse`` is set,
     - ``torch.mm`` when ``self.use_torch`` is set,
     - ``ndarray.dot`` otherwise.

     Returns
     -------
     list
         A single ``MaskResultTile`` covering the nav part of the partition
         slice.
     """
     mask_count = len(self.masks)
     partition_dtype = np.dtype(self.partition.dtype)
     # Only floating point ('f') and complex ('c') dtypes are kept as-is;
     # anything else is read and accumulated as float32.
     dest_dtype = (
         partition_dtype if partition_dtype.kind in ('c', 'f') else 'float32'
     )
     part = zeros_aligned(
         (mask_count, ) + tuple(self.partition.shape.nav),
         dtype=dest_dtype,
     )

     tiles = self.partition.get_tiles(mmap=True, dest_dtype=dest_dtype)
     for tile in tiles:
         flat = tile.flat_data
         tile_masks = self.masks[tile]
         if self.masks.use_sparse:
             product = sparse.dot(flat, tile_masks)
         elif self.use_torch:
             lhs = torch.from_numpy(flat)
             rhs = torch.from_numpy(tile_masks)
             product = torch.mm(lhs, rhs).numpy()
         else:
             product = flat.dot(tile_masks)
         target = tile.tile_slice.shift(self.partition.slice)
         reshaped = self.reshaped_data(data=product, dest_slice=target)
         # The leading Ellipsis spans the per-mask axis of ``part``.
         part[(..., ) + target.get(nav_only=True)] += reshaped

     nav_slice = self.partition.slice.get(nav_only=True)
     return [MaskResultTile(data=part, dest_slice=nav_slice)]
コード例 #5
0
ファイル: masks.py プロジェクト: bryanfleming99/TEMproject
    def __call__(self):
        """Apply all masks to every tile of this partition and sum the results.

        If ``self.tiling_scheme`` is None, a scheme is built from a
        single-frame tileshape ``(1,) + sig shape``. Tiles are read as
        ``self.read_dtype`` and processed while ``set_num_threads(1)`` is
        active. Each product is accumulated into a buffer of shape
        ``(number of masks,) + partition nav shape``.

        Returns
        -------
        list
            A single ``MaskResultTile`` covering the nav part of the
            partition slice.
        """
        part = zeros_aligned(
            (len(self.masks),) + tuple(self.partition.shape.nav),
            dtype=self.dtype,
        )

        # FIXME: tileshape negotiation!
        partition_shape = self.partition.shape
        frame_tileshape = Shape(
            (1,) + tuple(partition_shape.sig),
            sig_dims=partition_shape.sig.dims,
        )
        scheme = self.tiling_scheme
        if scheme is None:
            scheme = TilingScheme.make_for_shape(
                tileshape=frame_tileshape,
                dataset_shape=partition_shape,
            )

        tiles = self.partition.get_tiles(
            tiling_scheme=scheme,
            dest_dtype=self.read_dtype,
        )

        with set_num_threads(1):
            # torch is optional here; it is bound to None when not installed.
            try:
                import torch
            except ImportError:
                torch = None
            for tile in tiles:
                flat = tile.flat_data
                tile_masks = self.masks.get(tile, self.mask_dtype)
                if isinstance(tile_masks, sparse.SparseArray):
                    product = sparse.dot(flat, tile_masks)
                elif scipy.sparse.issparse(tile_masks):
                    # Old scipy.sparse matrix interface: "*" is the dot
                    # product.
                    product = flat * tile_masks
                elif self.use_torch:
                    # NOTE(review): if the import above failed, ``torch`` is
                    # None and this raises AttributeError; presumably
                    # use_torch is only set when torch is available - verify.
                    lhs = torch.from_numpy(flat)
                    rhs = torch.from_numpy(tile_masks)
                    product = torch.mm(lhs, rhs).numpy()
                else:
                    product = flat.dot(tile_masks)
                target = tile.tile_slice.shift(self.partition.slice)
                reshaped = self.reshaped_data(data=product, dest_slice=target)
                # The leading Ellipsis matches the per-mask axis of ``part``.
                part[(...,) + target.get(nav_only=True)] += reshaped
            nav_slice = self.partition.slice.get(nav_only=True)
            return [MaskResultTile(data=part, dest_slice=nav_slice)]
コード例 #6
0
def test_dot_with_sparse():
    """Check ``utils.dot`` against ``sparse.dot`` for mixed operand types.

    Covers sparse x sparse, sparse x dask and dask x sparse operands.
    """
    A = sparse.random((1024, 64))
    # ``(64)`` is just the int 64, so spell the 1-D shape as a tuple.
    # The default density (0.01) leaves a 64-element vector with (almost) no
    # stored entries. That would make every comparison below trivially
    # all-zero, so B uses a higher density.
    B = sparse.random((64,), density=0.25)
    ans = sparse.dot(A, B)

    # dot(sparse.array, sparse.array)
    res = utils.dot(A, B)
    assert_eq(ans, res)

    # dot(sparse.array, dask.array)
    res = utils.dot(A, da.from_array(B, chunks=B.shape))
    assert_eq(ans, res.compute())

    # dot(dask.array, sparse.array)
    res = utils.dot(da.from_array(A, chunks=A.shape), B)
    assert_eq(ans, res.compute())
コード例 #7
0
ファイル: masks.py プロジェクト: twentyse7en/LiberTEM
 def process_tile(self, tile):
     """Add this tile's mask dot products to ``self.results.intensity``.

     The masks for ``self.meta.slice`` are fetched transposed, and the tile is
     flattened to ``(tile.shape[0], -1)``. The product is then computed with:

     - ``sparse.dot`` for pydata.sparse masks,
     - ``*`` for scipy.sparse matrices,
     - ``torch.mm`` when ``self.task_data.use_torch`` is set,
     - ``ndarray.dot`` otherwise.
     """
     masks = self.task_data.masks.get(self.meta.slice, transpose=True)
     rows = tile.reshape((tile.shape[0], -1))

     if isinstance(masks, sparse.SparseArray):
         product = sparse.dot(rows, masks)
     elif scipy.sparse.issparse(masks):
         # Old scipy.sparse matrix interface, where "*" is the dot product.
         product = rows * masks
     elif not self.task_data.use_torch:
         product = rows.dot(masks)
     else:
         lhs, rhs = torch.from_numpy(rows), torch.from_numpy(masks)
         product = torch.mm(lhs, rhs).numpy()

     # Partial dot products from separate tiles combine by summation.
     self.results.intensity[:] += product
コード例 #8
0
ファイル: numpy_backend.py プロジェクト: sz144/tensorly
 def dot(self, x, y):
     """Return the dot product of ``x`` and ``y``, dispatching on sparsity.

     Uses ``sparse.dot`` when ``is_sparse`` reports either operand as sparse,
     and ``np.dot`` otherwise.
     """
     backend_dot = sparse.dot if (is_sparse(x) or is_sparse(y)) else np.dot
     return backend_dot(x, y)
コード例 #9
0
def dot_wl_prof(mat, vec):
    """Compute ``dot(mat[i1], vec[i1])`` for every wavelength index ``i1``.

    The per-wavelength products are stacked along a new first axis of length
    ``vec.shape[0]``.

    Parameters
    ----------
    mat : ndarray
        Indexed by wavelength along axis 0; must have at least 4 axes.
    vec : ndarray
        Indexed by wavelength along axis 0.

    Returns
    -------
    ndarray, shape (vec.shape[0], mat.shape[1], mat.shape[3])
        The dtype is at least float64 and becomes complex when either input
        is complex, so imaginary parts are not discarded.
    """
    # NOTE(review): each dot(mat[i1], vec[i1]) must produce shape
    # (mat.shape[1], mat.shape[3]). With numpy.dot and a
    # (nwl, nprof, m, n) layout this only holds for square matrices (m == n).
    # Confirm the intended layout of ``mat`` and the semantics of ``dot``.
    n_wl = vec.shape[0]
    out_dtype = np.result_type(mat.dtype, vec.dtype, np.float64)
    result = np.empty((n_wl, mat.shape[1], mat.shape[3]), dtype=out_dtype)
    for i1 in range(n_wl):  # loop over wavelengths
        result[i1, :, :] = dot(mat[i1], vec[i1])
    return result