Exemplo n.º 1
0
    def apply(self, data: np.ndarray, tile_slice: Slice):
        """
        Correct ``data`` in place with this object's dark frame, gain map
        and excluded pixels.

        Every correction that is present is first cropped to the signal
        part of ``tile_slice`` so that it lines up with the tile held in
        ``data``. Nothing is done when ``have_corrections()`` is false.

        Parameters
        ----------
        data
            Tile buffer; modified in place.
        tile_slice
            Slice locating ``data``; only its signal part and signal shape
            are used here.
        """
        dark = self.get_dark_frame()
        gain = self.get_gain_map()
        excluded = self.get_excluded_pixels()

        if not self.have_corrections():
            return

        crop = tile_slice.get(sig_only=True)

        # Crop only the corrections that actually exist.
        dark = None if dark is None else dark[crop]
        gain = None if gain is None else gain[crop]
        if excluded is not None:
            # `.coords` of the cropped object is what `correct` receives;
            # presumably a sparse COO array -- confirm against the getter.
            excluded = excluded[crop].coords

        correct(
            buffer=data,
            dark_image=dark,
            gain_map=gain,
            excluded_pixels=excluded,
            inplace=True,
            sig_shape=tuple(tile_slice.shape.sig),
        )
Exemplo n.º 2
0
    def apply(self, data: np.ndarray, tile_slice: Slice):
        """
        Correct ``data`` in place with this object's dark frame, gain map
        and repair descriptor.

        The dark frame and gain map are cropped to the signal part of
        ``tile_slice``. The call returns early, leaving ``data`` untouched,
        when ``have_corrections()`` is false.
        """
        dark = self.get_dark_frame()
        gain = self.get_gain_map()

        if not self.have_corrections():
            return

        crop = tile_slice.get(sig_only=True)
        dark = dark if dark is None else dark[crop]
        gain = gain if gain is None else gain[crop]

        # The repair descriptor is built from the tile slice with its
        # navigation part discarded.
        repair = self.repair_descriptor(tile_slice.discard_nav())
        correct(
            buffer=data,
            dark_image=dark,
            gain_map=gain,
            repair_descriptor=repair,
            inplace=True,
            sig_shape=tuple(tile_slice.shape.sig),
            allow_empty=self._allow_empty,
        )
Exemplo n.º 3
0
def test_detector_correction():
    """
    Check pure dark/gain correction (no excluded pixels) on randomly
    shaped data, both out of place and in place.
    """
    for loop_idx in range(10):
        print(f"Loop number {loop_idx}")
        n_nav = np.random.choice([1, 2, 3])
        n_sig = np.random.choice([1, 2, 3])

        nav_dims = tuple(np.random.randint(low=1, high=16, size=n_nav))
        sig_dims = tuple(np.random.randint(low=1, high=16, size=n_sig))

        data = _make_data(nav_dims, sig_dims)

        # No excluded pixels: only gain and offset correction is exercised.
        exclude = _generate_exclude_pixels(sig_dims=sig_dims, num_excluded=0)

        gain_map = np.random.random(sig_dims) + 1
        dark_image = np.random.random(sig_dims)

        # Apply the inverse of the correction to simulate raw detector data.
        damaged_data = data.copy()
        damaged_data /= gain_map
        damaged_data += dark_image

        # Sanity check: the correction formula undoes the damage.
        assert np.allclose((damaged_data - dark_image) * gain_map, data)

        print("Nav dims: ", nav_dims)
        print("Sig dims:", sig_dims)
        print("Exclude: ", exclude)

        correction_args = dict(
            buffer=damaged_data,
            dark_image=dark_image,
            gain_map=gain_map,
            excluded_pixels=exclude,
        )

        corrected = detector.correct(inplace=False, **correction_args)
        _check_result(data=data, corrected=corrected, atol=1e-8, rtol=1e-5)
        # The out-of-place call must have left damaged_data untouched.
        assert not np.allclose(corrected, damaged_data)

        detector.correct(inplace=True, **correction_args)
        # After the in-place call, damaged_data must equal the earlier result.
        assert np.allclose(corrected, damaged_data)
Exemplo n.º 4
0
def test_detector_uint8():
    """
    Correct uint8 raw data out of place, expecting a float result equal to
    the original data. Then check that an in-place correction of the
    integer buffer raises ``TypeError``.
    """
    for loop in range(REPEATS):
        print(f"Loop number {loop}")
        nav_ndim = np.random.choice([1, 2, 3])
        sig_ndim = np.random.choice([1, 2, 3])

        nav_dims = tuple(np.random.randint(low=8, high=16, size=nav_ndim))
        sig_dims = tuple(np.random.randint(low=8, high=16, size=sig_ndim))

        data = np.ones(nav_dims + sig_dims, dtype=np.uint8)

        exclude = exclude_pixels(sig_dims=sig_dims, num_excluded=2)

        # Unit gain: only the dark image changes the data.
        gain_map = np.ones(sig_dims)
        dark_image = (np.random.random(sig_dims) * 3).astype(np.uint8)
        # An all-zero integer dark image would leave the damaged data equal
        # to the original and break the "not in place" check below
        # (https://github.com/LiberTEM/LiberTEM/issues/910). A float dark
        # image is practically never exactly zero everywhere, so this guard
        # is only needed for the integer case.
        flat_idx = np.random.randint(0, np.prod(sig_dims))
        dark_image[np.unravel_index(flat_idx, sig_dims)] = 1

        # Dividing by gain_map is skipped because it is all ones.
        damaged_data = (data + dark_image).astype(np.uint8)

        print("Nav dims: ", nav_dims)
        print("Sig dims:", sig_dims)
        print("Exclude: ", exclude)

        kwargs = dict(
            buffer=damaged_data,
            dark_image=dark_image,
            gain_map=gain_map,
            excluded_pixels=exclude,
        )
        corrected = detector.correct(inplace=False, **kwargs)

        # Integer input must come back as a floating-point result.
        assert corrected.dtype.kind == 'f'

        _check_result(data=data, corrected=corrected, atol=1e-8, rtol=1e-5)
        # The out-of-place call must not have modified damaged_data.
        assert not np.allclose(corrected, damaged_data)

        # In-place correction of the uint8 buffer is expected to be rejected.
        with pytest.raises(TypeError):
            detector.correct(inplace=True, **kwargs)
Exemplo n.º 5
0
def test_detector_uint8():
    """
    Correct uint8 raw data out of place, expecting a float result equal to
    the original data. Then check that an in-place correction of the
    integer buffer raises ``TypeError``.
    """
    for i in range(10):
        print(f"Loop number {i}")
        num_nav_dims = np.random.choice([1, 2, 3])
        num_sig_dims = np.random.choice([1, 2, 3])

        nav_dims = tuple(np.random.randint(low=8, high=16, size=num_nav_dims))
        sig_dims = tuple(np.random.randint(low=8, high=16, size=num_sig_dims))

        data = np.ones(nav_dims + sig_dims, dtype=np.uint8)

        exclude = _generate_exclude_pixels(sig_dims=sig_dims, num_excluded=2)

        gain_map = np.ones(sig_dims)
        dark_image = (np.random.random(sig_dims) * 3).astype(np.uint8)
        # An integer dark image can be all zero by chance. In that case the
        # damaged data equals the original and the "not in place" assertion
        # below fails spuriously
        # (https://github.com/LiberTEM/LiberTEM/issues/910).
        # Force at least one non-zero dark pixel so the test is not flaky.
        at_least_one = np.random.randint(0, np.prod(sig_dims))
        dark_image[np.unravel_index(at_least_one, sig_dims)] = 1

        damaged_data = data.copy()
        # Dividing by gain_map is skipped because it is all ones.
        damaged_data += dark_image

        damaged_data = damaged_data.astype(np.uint8)

        print("Nav dims: ", nav_dims)
        print("Sig dims:", sig_dims)
        print("Exclude: ", exclude)

        corrected = detector.correct(
            buffer=damaged_data,
            dark_image=dark_image,
            gain_map=gain_map,
            excluded_pixels=exclude,
            inplace=False
        )

        # Integer input must come back as a floating-point result.
        assert corrected.dtype.kind == 'f'

        _check_result(
            data=data, corrected=corrected,
            atol=1e-8, rtol=1e-5
        )
        # Make sure we didn't do it in place
        assert not np.allclose(corrected, damaged_data)

        # In-place correction of the uint8 buffer is expected to be rejected.
        with pytest.raises(TypeError):
            detector.correct(
                buffer=damaged_data,
                dark_image=dark_image,
                gain_map=gain_map,
                excluded_pixels=exclude,
                inplace=True
            )
Exemplo n.º 6
0
 def preprocess(self, data, tile_slice, decoder):
     """
     Apply the decoder's dark frame and gain map to ``data`` in place.

     Each correction that is present is cropped to the signal part of
     ``tile_slice``. No excluded-pixel patching is done. The call returns
     without touching ``data`` when the decoder provides neither
     correction.
     """
     raw_corrections = (decoder.get_dark_frame(), decoder.get_gain_map())
     if all(corr is None for corr in raw_corrections):
         return
     # Crop each available correction; missing ones stay None.
     dark, gain = (
         corr if corr is None else corr[tile_slice.get(sig_only=True)]
         for corr in raw_corrections
     )
     correct(
         buffer=data,
         dark_image=dark,
         gain_map=gain,
         excluded_pixels=None,
         inplace=True,
     )
Exemplo n.º 7
0
def test_detector_patch_too_large():
    """
    Patch 1001 excluded pixels in 2D frames between 128 and 1023 pixels
    per side, with no dark or gain correction.
    """
    for rep in range(REPEATS):
        print(f"Loop number {rep}")
        nav_ndim = np.random.choice([2, 3])
        sig_ndim = 2

        nav_dims = tuple(np.random.randint(low=3, high=5, size=nav_ndim))
        sig_dims = tuple(
            np.random.randint(low=4 * 32, high=1024, size=sig_ndim))

        data = gradient_data(nav_dims, sig_dims)

        exclude = exclude_pixels(sig_dims=sig_dims, num_excluded=1001)

        # Overwrite every excluded pixel, in all frames, with a huge value.
        damaged_data = data.copy()
        damaged_data[(Ellipsis, *exclude)] = 1e24

        print("Nav dims: ", nav_dims)
        print("Sig dims:", sig_dims)
        print("Exclude: ", exclude)

        # Patching only: no dark image or gain map is passed.
        corrected = detector.correct(
            buffer=damaged_data,
            excluded_pixels=exclude,
            sig_shape=sig_dims,
            inplace=False,
        )
        _check_result(data=data, corrected=corrected, atol=1e-8, rtol=1e-5)
Exemplo n.º 8
0
def test_detector_patch():
    """
    Check dark/gain correction combined with patching of three excluded
    pixels on gradient data.
    """
    for rep in range(REPEATS):
        print(f"Loop number {rep}")
        nav_ndim = np.random.choice([2, 3])
        sig_ndim = np.random.choice([2, 3])

        nav_dims = tuple(np.random.randint(low=8, high=16, size=nav_ndim))
        sig_dims = tuple(np.random.randint(low=8, high=16, size=sig_ndim))

        data = gradient_data(nav_dims, sig_dims)

        gain_map = np.random.random(sig_dims) + 1
        dark_image = np.random.random(sig_dims)

        exclude = exclude_pixels(sig_dims=sig_dims, num_excluded=3)

        # Simulate raw data: invert the gain/dark correction, then
        # overwrite the excluded pixels with a huge value.
        damaged_data = data.copy()
        damaged_data /= gain_map
        damaged_data += dark_image
        damaged_data[(Ellipsis, *exclude)] = 1e24

        print("Nav dims: ", nav_dims)
        print("Sig dims:", sig_dims)
        print("Exclude: ", exclude)

        corrected = detector.correct(
            buffer=damaged_data,
            dark_image=dark_image,
            gain_map=gain_map,
            excluded_pixels=exclude,
            inplace=False,
        )
        _check_result(data=data, corrected=corrected, atol=1e-8, rtol=1e-5)
Exemplo n.º 9
0
def dataset_correction_masks(ds, roi, lt_ctx, exclude=None):
    """
    Compare correction applied via sparse mask multiplication with the
    result of the ``correct`` function.

    ``exclude`` holds excluded-pixel coordinates as passed to
    ``sparse.COO`` (one sequence per signal dimension). When it is
    ``None``, two random pixel coordinates (which may coincide) are drawn.
    """
    sig_shape = tuple(ds.shape.sig)
    flat_shape = (-1, *sig_shape)

    # Uncorrected pick of the ROI, used as input for `correct` below.
    no_corrections = CorrectionSet()
    raw = lt_ctx.run_udf(
        udf=PickUDF(), dataset=ds, roi=roi, corrections=no_corrections,
    )

    gain = np.random.random(ds.shape.sig) + 1
    dark = np.random.random(ds.shape.sig) - 0.5

    if exclude is None:
        exclude = [
            tuple(np.random.randint(0, s) for _ in range(2))
            for s in sig_shape
        ]

    exclude_coo = sparse.COO(coords=exclude, data=True, shape=ds.shape.sig)
    corrset = CorrectionSet(dark=dark, gain=gain, excluded_pixels=exclude_coo)

    def identity_masks():
        # One single-pixel mask per signal pixel: an identity matrix
        # reshaped to (n_pixels, *sig_shape).
        pixels = tuple(ds.shape.sig)
        return sparse.eye(np.prod(pixels)).reshape((-1, *pixels))

    # Mask application casts the data to float.
    mask_res = lt_ctx.run_udf(
        udf=ApplyMasksUDF(identity_masks),
        dataset=ds,
        corrections=corrset,
        roi=roi,
    )
    # `correct` works on the native input dtype.
    corrected = correct(
        buffer=raw['intensity'].raw_data.reshape(flat_shape),
        dark_image=dark,
        gain_map=gain,
        excluded_pixels=exclude,
        inplace=False,
    )

    print("Exclude: ", exclude)

    print(mask_res['intensity'].raw_data.dtype)
    print(corrected.dtype)

    assert np.allclose(
        mask_res['intensity'].raw_data.reshape(flat_shape),
        corrected,
    )
Exemplo n.º 10
0
def dataset_correction_verification(ds, roi, lt_ctx, exclude=None):
    """
    Compare the ``correct`` function with a pick that was corrected
    through a ``CorrectionSet`` during the UDF run.

    ``exclude`` holds excluded-pixel coordinates as passed to
    ``sparse.COO`` (one sequence per signal dimension). When it is
    ``None``, two random pixel coordinates (which may coincide) are drawn.
    """
    sig_shape = tuple(ds.shape.sig)
    flat_shape = (-1, *sig_shape)

    # Uncorrected pick of the ROI, used as input for `correct` below.
    no_corrections = CorrectionSet()
    raw = lt_ctx.run_udf(
        udf=PickUDF(), dataset=ds, roi=roi, corrections=no_corrections,
    )

    gain = np.random.random(ds.shape.sig) + 1
    dark = np.random.random(ds.shape.sig) - 0.5

    if exclude is None:
        exclude = [
            tuple(np.random.randint(0, s) for _ in range(2))
            for s in sig_shape
        ]

    exclude_coo = sparse.COO(coords=exclude, data=True, shape=ds.shape.sig)
    corrset = CorrectionSet(dark=dark, gain=gain, excluded_pixels=exclude_coo)

    # Corrections applied by the UDF run, on native input data.
    pick_res = lt_ctx.run_udf(
        udf=PickUDF(), dataset=ds, corrections=corrset, roi=roi,
    )
    # The same corrections applied directly to the uncorrected pick.
    corrected = correct(
        buffer=raw['intensity'].raw_data.reshape(flat_shape),
        dark_image=dark,
        gain_map=gain,
        excluded_pixels=exclude,
        inplace=False,
    )

    print("Exclude: ", exclude)

    print(pick_res['intensity'].raw_data.dtype)
    print(corrected.dtype)

    assert np.allclose(
        pick_res['intensity'].raw_data.reshape(flat_shape),
        corrected,
    )
Exemplo n.º 11
0
def test_detector_patch_overlapping():
    """
    Check dark/gain correction plus patching of three neighbouring
    excluded pixels in constant data.
    """
    for rep in range(10):
        print(f"Loop number {rep}")
        num_nav_dims = np.random.choice([2, 3])
        num_sig_dims = np.random.choice([2, 3])

        nav_dims = tuple(np.random.randint(low=8, high=16, size=num_nav_dims))
        sig_dims = tuple(np.random.randint(low=8, high=16, size=num_sig_dims))

        # Constant data, so patched pixels can be faithfully reconstructed.
        data = np.ones(nav_dims + sig_dims)

        gain_map = np.random.random(sig_dims) + 1
        dark_image = np.random.random(sig_dims)

        # Coordinate array with one row per signal dimension and one column
        # per pixel. The pixels are (1, 1, ...), (2, 1, ...) and
        # (1, 2, ...), i.e. direct neighbours.
        exclude = np.ones((num_sig_dims, 3), dtype=np.int32)
        exclude[0, 1] = 2
        exclude[1, 2] = 2

        # Simulate raw data, then overwrite the excluded pixels.
        damaged_data = data.copy()
        damaged_data /= gain_map
        damaged_data += dark_image
        damaged_data[(Ellipsis, *exclude)] = 1e24

        print("Nav dims: ", nav_dims)
        print("Sig dims:", sig_dims)
        print("Exclude: ", exclude)

        args = dict(
            buffer=damaged_data,
            dark_image=dark_image,
            gain_map=gain_map,
            excluded_pixels=exclude,
        )
        corrected = detector.correct(inplace=False, **args)

        _check_result(data=data, corrected=corrected, atol=1e-8, rtol=1e-5)