示例#1
0
    def test_plot_one_axis(self):
        """Plotting a 1D Gaussian window shows its coefficients as a column.

        Both the window data and the first column of the plotted image array
        must match the module-level reference ``GAUSS5_STD2``.
        """
        window_1d = Window(window="gaussian", shape=(5, ), std=2)
        _fig, image, _cbar = window_1d.plot()

        plotted_column = image.get_array().data[:, 0]
        for actual in (window_1d, plotted_column):
            np.testing.assert_array_almost_equal(actual, GAUSS5_STD2)
示例#2
0
    def test_make_circular(self, window, shape, answer_coeff, answer_circular,
                           answer_type):
        """``make_circular()`` updates the coefficients, name and flag."""
        kernel = Window(window=window, shape=shape)
        kernel.make_circular()

        expected_name, expected_circular = answer_type, answer_circular
        np.testing.assert_array_almost_equal(kernel, answer_coeff)
        assert kernel.name == expected_name
        assert kernel.circular is expected_circular
示例#3
0
    def test_plot_default_values(self):
        """Plotting a default window uses viridis and returns mpl objects."""
        default_window = Window()
        figure, image, colorbar = default_window.plot()

        np.testing.assert_array_almost_equal(
            default_window, image.get_array().data
        )
        assert image.cmap.name == "viridis"
        for obj, expected_type in (
            (figure, Figure),
            (image, AxesImage),
            (colorbar, Colorbar),
        ):
            assert isinstance(obj, expected_type)
示例#4
0
 def test_init_general_gaussian(self):
     """A 5x5 general Gaussian window (p=0.5, std=2) matches the reference."""
     name = "general_gaussian"
     expected_shape = (5, 5)
     # Keyword order (p, then std) is kept as in the original call
     params = dict(p=0.5, std=2)
     w = Window(window=name, shape=expected_shape, **params)

     assert w.is_valid()
     np.testing.assert_array_almost_equal(
         w.data, GENERAL_GAUSS55_PWR05_STD2
     )
     assert w.name == name
     assert w.shape == expected_shape
示例#5
0
    def test_fft_filter_frequency(
        self,
        dummy_signal,
        shift,
        transfer_function,
        kwargs,
        dtype_out,
        expected_spectrum_sum,
    ):
        """Frequency-domain FFT filtering with a Window keeps type and dtype.

        ``dtype_out`` defaults to ``np.float32`` when parametrized as None.
        The signal is cast to that dtype, then filtered, and the spectrum sum
        of the first pattern is compared with ``expected_spectrum_sum``.
        """
        out_dtype = np.float32 if dtype_out is None else dtype_out
        dummy_signal.data = dummy_signal.data.astype(out_dtype)

        signal_shape = dummy_signal.axes_manager.signal_shape
        window = Window(transfer_function, shape=signal_shape, **kwargs)
        dummy_signal.fft_filter(
            transfer_function=window,
            function_domain="frequency",
            shift=shift,
        )

        assert isinstance(dummy_signal, EBSD)
        assert dummy_signal.data.dtype == out_dtype
        first_spectrum = fft_spectrum(dummy_signal.inav[0, 0].data)
        assert np.allclose(
            np.sum(first_spectrum), expected_spectrum_sum, atol=1e-4
        )
示例#6
0
    def test_fft_filter(
        self, shift, transfer_function, kwargs, dtype_out, expected_spectrum_sum
    ):
        """Chunk-level FFT filtering of a float64 stack with one bright pixel.

        Pattern index 2 gets a single pixel set to 2. After filtering, its
        dtype and spectrum sum are checked.
        """
        pattern_shape = (101, 101)
        target = 2
        patterns = np.ones((4,) + pattern_shape, dtype=np.float64)
        patterns[target, 50, 50] = 2

        window = Window(transfer_function, shape=pattern_shape, **kwargs)

        filtered = chunk.fft_filter(
            patterns=patterns,
            filter_func=fft_filter,
            transfer_function=window,
            shift=shift,
            dtype_out=dtype_out,
        )
        filtered_target = filtered[target]

        # With dtype_out=None the test expects the input dtype (float64)
        expected_dtype = np.float64 if dtype_out is None else dtype_out
        assert filtered_target.dtype == expected_dtype
        assert np.allclose(
            np.sum(fft_spectrum.py_func(filtered_target)),
            expected_spectrum_sum,
            atol=1e-4,
        )
示例#7
0
def _dynamic_background_frequency_space_setup(
    pattern_shape: Union[List[int], Tuple[int, int]],
    std: Union[int, float],
    truncate: Union[int, float],
) -> Tuple[Tuple[int, int], Tuple[int, int], np.ndarray, Tuple[int, int],
           Tuple[int, int], ]:
    """Build a normalized Gaussian window and its FFT filter set-up.

    Parameters
    ----------
    pattern_shape
        Shape of the patterns to be filtered.
    std
        Standard deviation of the Gaussian window.
    truncate
        The square window's side length is ``int(truncate * std)``.

    Returns
    -------
    fft_shape, window_shape, transfer_function, offset_before_fft, offset_after_ifft
        ``window_shape`` is the shape of the Gaussian window. The other four
        values are passed through unchanged from ``_fft_filter_setup``.
    """
    side = int(truncate * std)
    gaussian = Window("gaussian", std=std, shape=(side, side))
    # Scale by the 2D Gaussian prefactor, then normalize to unit sum
    gaussian = gaussian / (2 * np.pi * std**2)
    gaussian = gaussian / np.sum(gaussian)

    setup = _fft_filter_setup(pattern_shape, gaussian)
    fft_shape, transfer_function, offset_before_fft, offset_after_ifft = setup

    return (
        fft_shape,
        gaussian.shape,
        transfer_function,
        offset_before_fft,
        offset_after_ifft,
    )
示例#8
0
    def test_highpass_fft_filter_equal(self):
        """``Window("highpass", ...)`` equals ``highpass_fft_filter(...)``."""
        shape = (96, 96)
        cutoff = 30
        from_window = Window(
            "highpass", cutoff=cutoff, cutoff_width=cutoff // 2, shape=shape
        )
        from_function = highpass_fft_filter(shape=shape, cutoff=cutoff)

        assert np.allclose(from_window, from_function)
示例#9
0
    def test_pad_window(self):
        """Padding a 5x5 Gaussian to 10x10 keeps its sum in the top-left."""
        window_shape = (5, 5)
        padded_shape = (10, 10)
        window = Window("gaussian", shape=window_shape)
        padded = barnes._pad_window(window=window, fft_shape=padded_shape)

        rows, cols = window_shape
        assert padded.shape == padded_shape
        assert np.allclose(
            np.sum(padded[:rows, :cols]), np.sum(window), atol=1e-5
        )
示例#10
0
    def test_plot(self, window, answer_coeff, cmap, textcolors, cmap_label,
                  tmp_path):
        """Plot a window with custom colours/label and round-trip it to PNG.

        The PNG is written to an explicit path inside pytest's ``tmp_path``
        instead of calling ``os.chdir``. This leaves the process-wide working
        directory unchanged for subsequent tests.
        """
        w = Window(window=window)

        fig, im, cbar = w.plot(cmap=cmap,
                               textcolors=textcolors,
                               cmap_label=cmap_label)

        np.testing.assert_array_almost_equal(w, answer_coeff)
        np.testing.assert_array_almost_equal(im.get_array().data, answer_coeff)
        assert isinstance(fig, Figure)
        assert isinstance(im, AxesImage)
        assert isinstance(cbar, Colorbar)

        # Check that the figure can be written to and read from file
        fname = str(tmp_path / "tests.png")
        fig.savefig(fname)
        _ = imread(fname)
示例#11
0
    def test_init(
        self,
        window,
        window_type,
        shape,
        kwargs,
        answer_shape,
        answer_coeff,
        answer_circular,
    ):
        """Window construction yields the expected name, shape and values.

        When ``kwargs`` is not None it is forwarded to ``Window`` under the
        keyword name ``kwargs``. Otherwise only ``window`` and ``shape`` are
        passed.
        """
        extra = {} if kwargs is None else {"kwargs": kwargs}
        w = Window(window=window, shape=shape, **extra)

        assert w.is_valid()
        assert w.name == window_type
        assert w.shape == answer_shape
        assert w.circular is answer_circular
        np.testing.assert_array_almost_equal(w.data, answer_coeff)
示例#12
0
    def test_average_neighbour_patterns_chunk(self, dummy_signal, dtype_in):
        """Average neighbour patterns chunk-wise via Dask ``map_blocks``."""
        window = Window()

        dask_array = _get_dask_array(dummy_signal)
        dtype_out = dask_array.dtype
        axes = dummy_signal.axes_manager
        sig_dim = axes.signal_dimension

        # Per-position sum of window coefficients over the navigation grid,
        # padding outside the grid with zeros (mode="constant", cval=0)
        window_sums = convolve(
            input=np.ones(axes.navigation_shape[::-1]),
            weights=window.data,
            mode="constant",
            cval=0,
        )
        # Append one length-1 axis per signal dimension
        window_sums = window_sums.reshape(window_sums.shape + (1,) * sig_dim)
        window_sums = da.from_array(window_sums, chunks=dask_array.chunksize)

        # Add signal dimensions to the window so it can be used with Dask's
        # map_blocks()
        window = window.reshape(window.shape + (1,) * sig_dim)

        averaged_patterns = dask_array.map_blocks(
            func=chunk.average_neighbour_patterns,
            window_sums=window_sums,
            window=window,
            dtype_out=dtype_in,
            dtype=dtype_out,
        )

        expected = np.array(
            [7, 4, 6, 6, 3, 7, 7, 3, 2], dtype=np.uint8
        ).reshape((3, 3))

        # Check for correct data type and expected output intensities
        assert averaged_patterns.dtype == dtype_out
        assert np.allclose(averaged_patterns[0, 0].compute(), expected)
示例#13
0
    def test_fft_pattern_apodization_window(self, dummy_signal, window):
        """Passing an apodization window equals pre-multiplying the pattern.

        Both windowed spectra keep the pattern shape and differ noticeably
        from the spectrum without a window.
        """
        pattern = dummy_signal.inav[0, 0].data
        apodization = Window(window, shape=pattern.shape)

        via_argument = fft(
            pattern=pattern, apodization_window=apodization, shift=True
        )
        via_product = fft(pattern=pattern * apodization, shift=True)
        unwindowed = fft(pattern=pattern, shift=True)

        windowed = (via_argument, via_product)
        for spectrum in windowed:
            assert spectrum.shape == pattern.shape
        assert np.allclose(via_argument, via_product)
        for spectrum in windowed:
            assert not np.allclose(spectrum, unwindowed, atol=1e-1)
示例#14
0
    def test_init_from_array(self):
        """A Window built from an array is "custom", and slicing keeps that."""
        source = np.arange(5)
        custom = Window(source)

        assert isinstance(custom, Window)
        assert custom.name == "custom"
        assert custom.circular is False
        assert np.sum(source) == np.sum(custom)

        # Slicing must return a Window that keeps the "custom" name
        tail = custom[1:]
        assert isinstance(tail, Window)
        assert tail.name == "custom"
        assert np.sum(source[1:]) == np.sum(tail)
示例#15
0
    def test_is_valid(self):
        """Changing one attribute at a time makes a window invalid.

        An unmodified default window is valid. An int name, a third axis, or
        a str ``circular`` value each make ``is_valid()`` return False.
        """

        def _unchanged(win):
            return win

        def _int_name(win):  # Set type from str to int
            win.name = 1
            return win

        def _third_axis(win):  # Add a third axis
            return np.expand_dims(win, 1)

        def _str_circular(win):  # Change circular boolean value to str
            win.circular = "True"
            return win

        cases = [
            (_unchanged, True),
            (_int_name, False),
            (_third_axis, False),
            (_str_circular, False),
        ]
        for modify, expected_valid in cases:
            w = modify(Window())
            assert w.is_valid() == expected_valid
示例#16
0
    def test_pad_image(self, image_shape, expected_shape):
        """Padding an image keeps its sum in the top-left image-sized region."""
        image = np.ones(image_shape, dtype=np.uint8)
        window = Window("gaussian", shape=(21, 23), std=5)

        fft_shape, _transfer, offset_before, _offset_after = (
            barnes._fft_filter_setup(image_shape=image.shape, window=window)
        )
        padded = barnes._pad_image(
            image=image,
            fft_shape=fft_shape,
            window_shape=window.shape,
            offset_before_fft=offset_before,
        )

        rows, cols = image.shape
        assert padded.shape == expected_shape
        assert np.allclose(np.sum(padded[:rows, :cols]), np.sum(image))
示例#17
0
 def test_average_neighbour_patterns_pass_window(self, dummy_signal):
     """Averaging with an explicitly passed default ``Window`` matches a
     pre-computed result and keeps the uint8 data type."""
     dummy_signal.average_neighbour_patterns(Window())

     # fmt: off
     flat_answer = [
         7, 4, 6, 6, 3, 7, 7, 3, 2, 4, 4, 6, 4, 2, 5, 4, 3, 5, 5, 5, 3,
         5, 3, 8, 6, 5, 5, 5, 2, 6, 4, 3, 3, 4, 1, 1, 6, 4, 6, 4, 3, 4,
         5, 5, 3, 5, 3, 3, 3, 3, 5, 3, 4, 5, 5, 3, 7, 4, 4, 2, 3, 4, 1,
         5, 3, 6, 3, 4, 1, 1, 4, 4, 7, 6, 3, 4, 6, 4, 3, 6, 3,
     ]
     # fmt: on
     expected = np.array(flat_answer, dtype=np.uint8)
     expected = expected.reshape(dummy_signal.axes_manager.shape)

     assert np.allclose(dummy_signal.data, expected)
     assert dummy_signal.data.dtype == expected.dtype
示例#18
0
    def test_fft_filter_private(self):
        """``barnes._fft_filter`` returns an image of the input shape."""
        image = np.ones((60, 60), dtype=np.uint8)
        window = Window("gaussian", shape=(10, 10), std=5)

        fft_shape, transfer, offset_before, offset_after = (
            barnes._fft_filter_setup(image_shape=image.shape, window=window)
        )
        filtered = barnes._fft_filter(
            image=image,
            fft_shape=fft_shape,
            window_shape=window.shape,
            transfer_function=transfer,
            offset_before_fft=offset_before,
            offset_after_ifft=offset_after,
        )

        assert filtered.shape == image.shape
示例#19
0
    def test_fft_filter_setup(self):
        """``barnes._fft_filter_setup`` returns the FFT shape, a transfer
        function with non-zero imaginary part, and two signed-int offsets."""
        image = np.ones((60, 60), dtype=np.uint8)
        window = Window("gaussian", shape=(10, 10), std=5)

        fft_shape, window_rfft, off1, off2 = barnes._fft_filter_setup(
            image_shape=image.shape, window=window
        )

        assert isinstance(fft_shape, tuple)
        assert fft_shape == (72, 72)

        assert np.sum(window_rfft.imag) != 0

        for offset in (off1, off2):
            assert isinstance(offset, tuple)
            assert len(offset) == 2
            assert np.issubdtype(type(offset[0]), np.signedinteger)
示例#20
0
    def test_average_neighbour_patterns_chunk(self, dummy_signal, dtype_in):
        """Average neighbour patterns chunk-wise via Dask ``map_blocks``."""
        window = Window()

        dask_array = get_dask_array(dummy_signal)
        dtype_out = dask_array.dtype

        # Per-position sum of window coefficients over the navigation grid
        axes = dummy_signal.axes_manager
        window_sums = convolve(
            input=np.ones(axes.navigation_shape[::-1], dtype=int),
            weights=window.data,
            mode="constant",
            cval=0,
        )

        # Append one length-1 axis per signal dimension to both arrays so
        # they can be used with Dask's map_blocks()
        sig_dim = axes.signal_dimension
        nav_dim = axes.navigation_dimension
        trailing = (1, ) * sig_dim
        window_sums = window_sums.reshape(window_sums.shape + trailing)
        window = window.reshape(window.shape + trailing)
        window_sums = da.from_array(
            window_sums, chunks=dask_array.chunks[:nav_dim] + trailing
        )

        averaged_patterns = dask_array.map_blocks(
            func=chunk.average_neighbour_patterns,
            window_sums=window_sums,
            window=window,
            dtype_out=dtype_in,
            dtype=dtype_out,
        )

        expected = np.array(
            [255, 109, 218, 218, 36, 236, 255, 36, 0], dtype=np.uint8
        ).reshape((3, 3))

        # Check for correct data type and expected output intensities
        assert averaged_patterns.dtype == dtype_out
        assert np.allclose(averaged_patterns[0, 0].compute(), expected)
示例#21
0
class TestWindow:
    """Tests for ``Window`` and the window helper functions
    (``lowpass_fft_filter``, ``highpass_fft_filter``, ``modified_hann``,
    ``distance_to_origin``)."""

    @pytest.mark.parametrize(
        ("window, window_type, shape, kwargs, answer_shape, "
         "answer_coeff, answer_circular"),
        [
            ("circular", "circular", (3, 3), None, (3, 3), CIRCULAR33, True),
            (CUSTOM, "custom", (10, 20), None, CUSTOM.shape, CUSTOM, False),
            ("gaussian", "gaussian", (5, 5), 2, (5, 5), GAUSS55_STD2, False),
        ],
    )
    def test_init(
        self,
        window,
        window_type,
        shape,
        kwargs,
        answer_shape,
        answer_coeff,
        answer_circular,
    ):
        """Window construction yields the expected name, shape and values."""
        # kwargs, when given, is forwarded under the keyword name "kwargs"
        if kwargs is None:
            w = Window(window=window, shape=shape)
        else:
            w = Window(window=window, shape=shape, kwargs=kwargs)

        assert w.is_valid()
        assert w.name == window_type
        assert w.shape == answer_shape
        assert w.circular is answer_circular
        np.testing.assert_array_almost_equal(w.data, answer_coeff)

    @pytest.mark.parametrize(
        "window, shape, error_type, match",
        [
            (
                [[0, 1, 0], [1, 1, 1], [0, 1, 0]],
                (5, 5),
                ValueError,
                "Window <class 'list'> must be of type numpy.ndarray,",
            ),
            (
                "boxcar",
                (5, -5),
                ValueError,
                "All window axes .* must be > 0",
            ),
            (
                "boxcar",
                (5, 5.1),
                TypeError,
                "Window shape .* must be a sequence of ints.",
            ),
        ],
    )
    def test_init_raises_errors(self, window, shape, error_type, match):
        """Invalid window types or shapes raise with informative messages."""
        with pytest.raises(error_type, match=match):
            _ = Window(window=window, shape=shape)

    @pytest.mark.parametrize("Nx", [3, 5, 7, 8])
    def test_init_passing_nx(self, Nx):
        """Passing ``Nx`` gives a one-dimensional window of that length."""
        w = Window(Nx=Nx)
        assert w.shape == (Nx, )

    def test_init_from_array(self):
        """A Window built from an array is "custom", and slicing keeps that."""
        a = np.arange(5)
        w = Window(a)

        assert isinstance(w, Window)
        assert w.name == "custom"
        assert w.circular is False
        assert np.sum(a) == np.sum(w)

        w2 = w[1:]
        assert isinstance(w2, Window)
        assert w2.name == "custom"
        assert np.sum(a[1:]) == np.sum(w2)

    def test_init_cast_with_view(self):
        """Viewing a plain array as ``Window`` gives a Window instance."""
        a = np.arange(5)
        w = a.view(Window)
        assert isinstance(w, Window)

    def test_array_finalize_returns_none(self):
        """Calling ``__array_finalize__(None)`` directly returns None."""
        w = Window()
        assert w.__array_finalize__(None) is None

    def test_init_general_gaussian(self):
        """A 5x5 general Gaussian window (p=0.5, std=2) matches the reference."""
        window = "general_gaussian"
        shape = (5, 5)
        w = Window(
            window=window,
            shape=shape,
            p=0.5,
            std=2,
        )
        assert w.is_valid()
        np.testing.assert_array_almost_equal(w.data,
                                             GENERAL_GAUSS55_PWR05_STD2)
        assert w.name == window
        assert w.shape == shape

    def test_representation(self):
        """``repr`` shows class name, shape, window name and coefficients."""
        w = Window()
        # Use the class's own name rather than parsing str(type(w)), which
        # only yields the bare name when the class lives in a dotted module
        object_type = type(w).__name__
        assert w.__repr__() == (f"{object_type} {w.shape} {w.name}\n"
                                "[[0. 1. 0.]\n [1. 1. 1.]\n [0. 1. 0.]]")

    def test_is_valid(self):
        """Changing one attribute at a time makes the window invalid."""
        change_attribute = np.array([0, 0, 0, 1])

        # Change one attribute at a time and check whether the window is valid
        for i in range(len(change_attribute)):
            w = Window()

            valid_window = True
            if sum(change_attribute[:3]) == 1:
                valid_window = False

            if change_attribute[0]:  # Set type from str to int
                w.name = 1
            elif change_attribute[1]:  # Add a third axis
                w = np.expand_dims(w, 1)
            elif change_attribute[2]:  # Change circular boolean value to str
                w.circular = "True"

            # Roll axis to change which attribute to change next time
            change_attribute = np.roll(change_attribute, 1)

            assert w.is_valid() == valid_window

    @pytest.mark.parametrize(
        "window, shape, answer_coeff, answer_circular, answer_type",
        [
            # Changes type as well
            ("rectangular", (3, 3), CIRCULAR33, True, "circular"),
            ("boxcar", (3, 3), CIRCULAR33, True, "circular"),
            # Does nothing since window has only one axis
            ("rectangular", (3, ), RECTANGULAR3, False, "rectangular"),
            # Behaves as expected
            ("gaussian", (3, 3), GAUSS33_CIRCULAR, True, "gaussian"),
            # Even axis
            ("rectangular", (5, 4), CIRCULAR54, True, "circular"),
        ],
    )
    def test_make_circular(self, window, shape, answer_coeff, answer_circular,
                           answer_type):
        """``make_circular()`` updates the coefficients, name and flag."""
        k = Window(window=window, shape=shape)
        k.make_circular()

        np.testing.assert_array_almost_equal(k, answer_coeff)
        assert k.name == answer_type
        assert k.circular is answer_circular

    @pytest.mark.parametrize(
        "shape, compatible",
        [
            ((3, ), True),
            ((3, 3), True),
            ((3, 4), False),
            ((4, 3), False),
            ((4, 4), False),
        ],
    )
    def test_shape_compatible(self, dummy_signal, shape, compatible):
        """``shape_compatible`` reports whether the window fits the nav shape."""
        w = Window(shape=shape)
        assert (w.shape_compatible(
            dummy_signal.axes_manager.navigation_shape) == compatible)

    def test_plot_default_values(self):
        """Plotting a default window uses viridis and returns mpl objects."""
        w = Window()
        fig, im, cbar = w.plot()

        np.testing.assert_array_almost_equal(w, im.get_array().data)
        assert im.cmap.name == "viridis"
        assert isinstance(fig, Figure)
        assert isinstance(im, AxesImage)
        assert isinstance(cbar, Colorbar)

    def test_plot_invalid_window(self):
        """Plotting a window whose name is not a string raises ValueError."""
        w = Window()
        w.name = 1
        assert w.is_valid() is False
        with pytest.raises(ValueError, match="Window is invalid."):
            w.plot()

    @pytest.mark.parametrize(
        "window, answer_coeff, cmap, textcolors, cmap_label",
        [
            (
                "circular",
                CIRCULAR33,
                "viridis",
                ["k", "w"],
                "Coefficient",
            ),
            (
                "rectangular",
                RECTANGULAR33,
                "inferno",
                ["b", "r"],
                "Coeff.",
            ),
        ],
    )
    def test_plot(self, window, answer_coeff, cmap, textcolors, cmap_label,
                  tmp_path):
        """Plot a window with custom colours/label and round-trip it to PNG.

        The PNG is written to an explicit path inside ``tmp_path`` rather
        than calling ``os.chdir``, so the working directory is not leaked
        into other tests.
        """
        w = Window(window=window)

        fig, im, cbar = w.plot(cmap=cmap,
                               textcolors=textcolors,
                               cmap_label=cmap_label)

        np.testing.assert_array_almost_equal(w, answer_coeff)
        np.testing.assert_array_almost_equal(im.get_array().data, answer_coeff)
        assert isinstance(fig, Figure)
        assert isinstance(im, AxesImage)
        assert isinstance(cbar, Colorbar)

        # Check that the figure can be written to and read from file
        fname = str(tmp_path / "tests.png")
        fig.savefig(fname)
        _ = imread(fname)

    def test_plot_one_axis(self):
        """Plotting a 1D Gaussian window shows its coefficients as a column."""
        w = Window(window="gaussian", shape=(5, ), std=2)
        fig, im, cbar = w.plot()

        # Compare to global window GAUSS5_STD2
        np.testing.assert_array_almost_equal(w, GAUSS5_STD2)
        np.testing.assert_array_almost_equal(im.get_array().data[:, 0],
                                             GAUSS5_STD2)

    @pytest.mark.parametrize(
        "shape, c, w_c, answer",
        [
            (
                (5, 5),
                1,
                1,
                # fmt: off
                np.array([
                    [0.0012, 0.0470, 0.1353, 0.0470, 0.0012],
                    [0.0470, 0.7095, 1., 0.7095, 0.0470],
                    [0.1353, 1., 1., 1., 0.1353],
                    [0.0470, 0.7095, 1., 0.7095, 0.0470],
                    [0.0012, 0.0470, 0.1353, 0.0470, 0.0012],
                ])
                # fmt: on
            ),
            (
                (6, 5),
                2,
                1,
                # fmt: off
                np.array([
                    [0.0057, 0.0670, 0.1353, 0.0670, 0.0057],
                    [0.2534, 0.8945, 1., 0.8945, 0.2534],
                    [0.8945, 1., 1., 1., 0.8945],
                    [1., 1., 1., 1., 1.],
                    [0.8945, 1., 1., 1., 0.8945],
                    [0.2534, 0.8945, 1., 0.8945, 0.2534],
                ])
                # fmt: on
            ),
        ],
    )
    def test_lowpass_fft_filter_direct(self, shape, c, w_c, answer):
        """``lowpass_fft_filter`` matches tabulated coefficients."""
        w = lowpass_fft_filter(shape=shape, cutoff=c, cutoff_width=w_c)

        assert w.shape == answer.shape
        assert np.allclose(w, answer, atol=1e-4)

    def test_lowpass_fft_filter_equal(self):
        """``Window("lowpass", ...)`` equals ``lowpass_fft_filter(...)``."""
        shape = (96, 96)
        c = 30
        w_c = c // 2
        w1 = Window("lowpass", cutoff=c, cutoff_width=w_c, shape=shape)
        w2 = lowpass_fft_filter(shape=shape, cutoff=c)

        assert np.allclose(w1, w2)

    @pytest.mark.parametrize(
        "shape, c, w_c, answer",
        [
            (
                (5, 5),
                2,
                2,
                # fmt: off
                np.array([
                    [1, 1, 1, 1, 1],
                    [1, 0.8423, 0.6065, 0.8423, 1],
                    [1, 0.6065, 0.1353, 0.6065, 1],
                    [1, 0.8423, 0.6065, 0.8423, 1],
                    [1, 1, 1, 1, 1],
                ])
                # fmt: on
            ),
            (
                (6, 5),
                2,
                1,
                # fmt: off
                np.array([
                    [1, 1, 1, 1, 1],
                    [1, 1, 1, 1, 1],
                    [1, 0.5034, 0.1353, 0.5034, 1],
                    [1, 0.1353, 0.0003, 0.1353, 1],
                    [1, 0.5034, 0.1353, 0.5034, 1],
                    [1, 1, 1, 1, 1],
                ])
                # fmt: on
            ),
        ],
    )
    def test_highpass_fft_filter_direct(self, shape, c, w_c, answer):
        """``highpass_fft_filter`` matches tabulated coefficients."""
        w = highpass_fft_filter(shape=shape, cutoff=c, cutoff_width=w_c)

        assert w.shape == answer.shape
        assert np.allclose(w, answer, atol=1e-4)

    def test_highpass_fft_filter_equal(self):
        """``Window("highpass", ...)`` equals ``highpass_fft_filter(...)``."""
        shape = (96, 96)
        c = 30
        w_c = c // 2
        w1 = Window("highpass", cutoff=c, cutoff_width=w_c, shape=shape)
        w2 = highpass_fft_filter(shape=shape, cutoff=c)

        assert np.allclose(w1, w2)

    @pytest.mark.parametrize(
        "Nx, answer",
        [
            (3, np.array([0.5, 1, 0.5])),
            # fmt: off
            (11,
             np.array([
                 0.1423, 0.4154, 0.6548, 0.8412, 0.9594, 1., 0.9594, 0.8412,
                 0.6548, 0.4154, 0.1423
             ])),
            # fmt: on
        ],
    )
    def test_modified_hann_direct(self, Nx, answer):
        """``modified_hann`` matches tabulated coefficients."""
        w = modified_hann.py_func(Nx)

        assert np.allclose(w, answer, atol=1e-4)

    @pytest.mark.parametrize(
        "Nx, answer",
        [(96, 61.1182), (801, 509.9328)],
    )
    def test_modified_hann_direct_sum(self, Nx, answer):
        """The sum of ``modified_hann`` coefficients matches tabulated values."""
        # py_func ensures coverage for a Numba decorated function
        w = modified_hann.py_func(Nx)

        assert np.allclose(np.sum(w), answer, atol=1e-4)

    def test_modified_hann_equal(self):
        """``Window("modified_hann", ...)`` equals ``modified_hann(...)``."""
        w1 = Window("modified_hann", shape=(30, ))
        w2 = modified_hann(Nx=30)

        assert np.allclose(w1, w2)

    @pytest.mark.parametrize(
        "shape, origin, answer",
        [
            (
                (5, 5),
                None,
                np.array([
                    [2.8284, 2.2360, 2, 2.2360, 2.8284],
                    [2.2360, 1.4142, 1, 1.4142, 2.2360],
                    [2, 1, 0, 1, 2],
                    [2.2360, 1.4142, 1, 1.4142, 2.2360],
                    [2.8284, 2.2360, 2, 2.2360, 2.8284],
                ]),
            ),
            (
                (5, ),
                (2, ),
                np.array([2, 1, 0, 1, 2]),
            ),
            (
                (4, 4),
                (2, 3),
                np.array([
                    [3.6055, 2.8284, 2.2360, 2],
                    [3.1622, 2.2360, 1.4142, 1],
                    [3, 2, 1, 0],
                    [3.1622, 2.2360, 1.4142, 1],
                ]),
            ),
        ],
    )
    def test_distance_to_origin(self, shape, origin, answer):
        """``distance_to_origin`` matches tabulated distances."""
        r = distance_to_origin(shape=shape, origin=origin)
        assert np.allclose(r, answer, atol=1e-4)

    @pytest.mark.parametrize(
        "std, shape, answer",
        [
            (0.001, (1, ), np.array([[1]])),
            (-0.5, (1, ), Window("gaussian", std=0.5, shape=(1, ))),
            (
                0.5,
                (3, 3),
                Window(
                    np.array([
                        [0.01134374, 0.08381951, 0.01134374],
                        [0.08381951, 0.61934703, 0.08381951],
                        [0.01134374, 0.08381951, 0.01134374],
                    ])),
            ),
        ],
    )
    def test_gaussian(self, std, shape, answer):
        """A Gaussian window normalized to unit sum matches ``answer``."""
        w = Window("gaussian", std=std, shape=shape)
        w = w / (2 * np.pi * std**2)
        w = w / np.sum(w)

        assert np.allclose(w, answer)

    @pytest.mark.parametrize(
        "shape, desired_n_neighbours",
        [
            ((3, 3), (1, 1)),
            ((3, ), (1, )),
            (
                (7, 5),
                (3, 2),
            ),
            ((6, 5), (2, 2)),
            ((5, 7), (2, 3)),
        ],
    )
    def test_n_neighbours(self, shape, desired_n_neighbours):
        """``n_neighbours`` matches the expected count per axis."""
        assert Window(shape=shape).n_neighbours == desired_n_neighbours
示例#22
0
 def test_pad_window_raises(self):
     """Padding a 5x5 window into a smaller 4x4 FFT shape raises."""
     too_small = (4, 4)
     gaussian = Window("gaussian", shape=(5, 5))
     expected_msg = "could not broadcast input array"
     with pytest.raises(ValueError, match=expected_msg):
         _ = barnes._pad_window(window=gaussian, fft_shape=too_small)
示例#23
0
 def test_shape_compatible(self, dummy_signal, shape, compatible):
     """``shape_compatible`` reports whether the window fits the nav shape."""
     window = Window(shape=shape)
     nav_shape = dummy_signal.axes_manager.navigation_shape
     is_compatible = window.shape_compatible(nav_shape)
     assert is_compatible == compatible
示例#24
0
 def test_plot_invalid_window(self):
     """Plotting a window whose name is not a string raises ValueError."""
     invalid = Window()
     invalid.name = 1  # an int name makes the window invalid
     assert invalid.is_valid() is False
     with pytest.raises(ValueError, match="Window is invalid."):
         invalid.plot()
示例#25
0
 def test_array_finalize_returns_none(self):
     """Calling ``__array_finalize__(None)`` directly returns None."""
     result = Window().__array_finalize__(None)
     assert result is None
示例#26
0
 def test_init_passing_nx(self, Nx):
     """Passing ``Nx`` gives a one-dimensional window of that length."""
     expected_shape = (Nx, )
     assert Window(Nx=Nx).shape == expected_shape
示例#27
0
    def test_modified_hann_equal(self):
        """``Window("modified_hann", ...)`` equals ``modified_hann(...)``."""
        length = 30
        from_window = Window("modified_hann", shape=(length, ))
        from_function = modified_hann(Nx=length)

        assert np.allclose(from_window, from_function)
示例#28
0
    def test_gaussian(self, std, shape, answer):
        """A Gaussian window normalized to unit sum matches ``answer``."""
        gauss = Window("gaussian", std=std, shape=shape)
        scaled = gauss / (2 * np.pi * std**2)
        normalized = scaled / np.sum(scaled)

        assert np.allclose(normalized, answer)
示例#29
0
 def test_n_neighbours(self, shape, desired_n_neighbours):
     """``n_neighbours`` matches the expected count per axis."""
     n_neighbours = Window(shape=shape).n_neighbours
     assert n_neighbours == desired_n_neighbours
示例#30
0
 def test_representation(self):
     """``repr`` shows class name, shape, window name and coefficients."""
     w = Window()
     # Use the class's own name rather than parsing str(type(w)), which only
     # yields the bare name when the class lives in a dotted module path
     object_type = type(w).__name__
     assert w.__repr__() == (f"{object_type} {w.shape} {w.name}\n"
                             "[[0. 1. 0.]\n [1. 1. 1.]\n [0. 1. 0.]]")