class TestDataAxis:

    def setup_method(self, method):
        """Build a fresh non-uniform axis on the squares 0, 1, 4, ..., 225."""
        squares = np.arange(16)**2
        self._axis = squares
        self.axis = DataAxis(axis=squares)

    def _test_initialisation_parameters(self, axis):
        np.testing.assert_allclose(axis.axis, self._axis)

    def test_initialisation_parameters(self):
        self._test_initialisation_parameters(self.axis)

    def test_create_axis(self):
        """An axis rebuilt with create_axis from the axis dictionary is an
        equivalent DataAxis."""
        axis_dict = self.axis.get_axis_dictionary()
        rebuilt = create_axis(**axis_dict)
        assert isinstance(rebuilt, DataAxis)
        self._test_initialisation_parameters(rebuilt)

    def test_axis_value(self):
        assert_allclose(self.axis.axis, np.arange(16)**2)
        assert self.axis.size == 16
        assert not self.axis.is_uniform

    def test_update_axes(self):
        values = np.arange(20)**2
        self.axis.axis = values.tolist()
        self.axis.update_axis()
        assert self.axis.size == 20
        assert_allclose(self.axis.axis, values)

    def test_update_axes2(self):
        values = np.array([3, 4, 10, 40])
        self.axis.axis = values
        self.axis.update_axis()
        assert_allclose(self.axis.axis, values)

    def test_update_axis_from_list(self):
        values = np.arange(16)**2
        self.axis.axis = values.tolist()
        self.axis.update_axis()
        assert_allclose(self.axis.axis, values)

    def test_unsorted_axis(self):
        """Constructing a DataAxis from unsorted values raises ValueError."""
        unsorted_values = np.array([10, 40, 1, 30, 20])
        with pytest.raises(ValueError):
            DataAxis(axis=unsorted_values)

    def test_index_changed_event(self):
        """index_changed is not emitted for a no-op assignment but is emitted
        when the index actually moves."""
        listener = mock.Mock()
        axis = self.axis
        axis.events.index_changed.connect(listener.trigger_me)
        # Re-assigning the current index must not emit the event.
        axis.index = axis.index
        assert not listener.trigger_me.called
        # Advancing the index by one must emit it.
        axis.index += 1
        assert listener.trigger_me.called

    def test_value_changed_event(self):
        """value_changed is not emitted for a no-op assignment or for
        sub-step nudges, but is emitted when the value is set to the second
        axis point."""
        axis = self.axis
        listener = mock.Mock()
        axis.events.value_changed.connect(listener.trigger_me)

        def shift_by(fraction):
            # Move the value by a fraction of the first axis step.
            axis.value = axis.value + (axis.axis[1] - axis.axis[0]) * fraction

        axis.value = axis.value
        assert not listener.trigger_me.called
        # Nudge by 40 % and then by 50 % of the first step.
        for fraction in (0.4, 0.5):
            shift_by(fraction)
            assert not listener.trigger_me.called
        axis.value = axis.axis[1]
        assert listener.trigger_me.called

    def test_deepcopy(self):
        ac = copy.deepcopy(self.axis)
        np.testing.assert_allclose(ac.axis, np.arange(16)**2)
        ac.name = 'name changed'
        assert ac.name == 'name changed'
        assert self.axis.name != ac.name

    def test_slice_me(self):
        """_slice_me returns the requested slice and the axis is left with
        the selected points."""
        ax = self.axis
        requested = slice(1, 5)
        assert ax._slice_me(requested) == requested
        assert ax.size == 4
        expected_values = np.arange(1, 5)**2
        np.testing.assert_allclose(ax.axis, expected_values)

    def test_slice_me_step(self):
        """_slice_me honours a slice step: every second point of 0..9 is kept."""
        ax = self.axis
        requested = slice(0, 10, 2)
        assert ax._slice_me(requested) == requested
        assert ax.size == 5
        expected_values = np.arange(0, 10, 2)**2
        np.testing.assert_allclose(ax.axis, expected_values)

    def test_convert_to_uniform_axis(self):
        """Converting the non-uniform axis of a signal yields a UniformDataAxis.

        The conversion must keep name, units, size, index_in_array, is_binned
        and navigate. The expected calibration is a scale of
        ``(high_value - low_value) / size`` of the original axis, offset 0 and
        ``high_value == 15 * scale``.
        """
        # Capture the expected values before the axis is converted.
        scale = (self.axis.high_value - self.axis.low_value) / self.axis.size
        is_binned = self.axis.is_binned
        navigate = self.axis.navigate
        self.axis.name = "parrot"
        self.axis.units = "plumage"
        # The data length matches the axis size (16), so the signal is
        # internally consistent with the axis it is given.
        s = Signal1D(np.arange(self.axis.size), axes=[self.axis])
        index_in_array = s.axes_manager[0].index_in_array
        s.axes_manager[0].convert_to_uniform_axis()
        # Re-fetch the axis from the manager after the conversion.
        converted = s.axes_manager[0]
        assert isinstance(converted, UniformDataAxis)
        assert converted.name == "parrot"
        assert converted.units == "plumage"
        assert converted.size == 16
        # scale and high_value are derived floats: compare with a tolerance.
        assert converted.scale == pytest.approx(scale)
        assert converted.offset == 0
        assert converted.low_value == 0
        assert converted.high_value == pytest.approx(15 * scale)
        assert index_in_array == converted.index_in_array
        assert is_binned == converted.is_binned
        assert navigate == converted.navigate

    def test_value2index(self):
        """value2index snaps to the nearest point, honours a custom rounding
        function, returns an integer, and handles negative axis values."""
        ax = self.axis
        assert ax.value2index(10.15) == 3
        assert ax.value2index(60) == 8
        # 2.5 lies between the axis values 1 (index 1) and 4 (index 2).
        for rounding, expected in ((round, 1), (math.ceil, 2), (math.floor, 1)):
            assert ax.value2index(2.5, rounding=rounding) == expected
        # The returned index must be an integer type.
        assert isinstance(ax.value2index(60), (int, np.integer))
        # Shift the axis down by 2 so it starts with -2, -1, 2, ...
        ax.axis = ax.axis - 2
        # test rounding on negative value
        assert ax.value2index(-1.5, rounding=round) == 1

    def test_value2index_error(self):
        """A value above the last axis point (225) raises ValueError."""
        beyond_high_value = 226
        with pytest.raises(ValueError):
            self.axis.value2index(beyond_high_value)

    def test_parse_value_from_relative_string(self):
        """'relX' strings map to the fraction X of the axis range (0 to 225).
        Malformed strings or fractions above 1 raise ValueError."""
        parse = self.axis._parse_value_from_string
        valid_cases = (('rel0.0', 0.0), ('rel0.5', 112.5), ('rel1.0', 225.0))
        for text, expected in valid_cases:
            assert parse(text) == expected
        for invalid in ('rela0.5', 'rel1.5', 'abcd'):
            with pytest.raises(ValueError):
                parse(invalid)

    def test_parse_value_from_string_with_units(self):
        """With the axis in 'nm', parsing a value given in 'um' is expected to
        raise ValueError."""
        ax = self.axis
        ax.units = 'nm'
        foreign_unit_value = '0.02 um'
        with pytest.raises(ValueError):
            ax._parse_value_from_string(foreign_unit_value)

    @pytest.mark.parametrize("use_indices", (False, True))
    def test_crop(self, use_indices):
        """Cropping between 4 and 196, given as values or as the matching
        indices, leaves the 12 points from 4 to 169."""
        axis = DataAxis(axis=self._axis)
        bounds = (4., 196.)
        if use_indices:
            bounds = tuple(axis.value2index(bound) for bound in bounds)
        axis.crop(*bounds)
        assert axis.size == 12
        first, last = axis.axis[0], axis.axis[-1]
        np.testing.assert_almost_equal(first, 4)
        np.testing.assert_almost_equal(last, 169)

    def test_crop_reverses_indexing(self):
        """Negative and mixed-sign indices in crop select the same 12 points
        (4 to 169) as the equivalent positive indices."""
        # First purely negative indices, then mixed positive/negative ones.
        for start, end in ((-14, -2), (2, -2)):
            axis = DataAxis(axis=self._axis)
            axis.crop(start, end)
            assert axis.size == 12
            np.testing.assert_almost_equal(axis.axis[0], 4)
            np.testing.assert_almost_equal(axis.axis[-1], 169)

    def test_error_DataAxis(self):
        """Invalid construction, index lookups and slices raise the expected
        exception types."""
        ax = self.axis
        # Each trigger is deferred in a lambda so that it runs inside its own
        # pytest.raises context, in the listed order.
        expectations = (
            # An explicit axis array combined with _type='UniformDataAxis'.
            (ValueError,
             lambda: DataAxis(axis=np.arange(16)**2, _type='UniformDataAxis')),
            # NOTE(review): presumably because this standalone axis has no
            # axes manager — confirm.
            (AttributeError, lambda: ax.index_in_axes_manager()),
            # -17 is out of range for a 16-point axis.
            (IndexError, lambda: ax._get_positive_index(-17)),
            # Non-integer slice step.
            (ValueError,
             lambda: ax._get_array_slices(slice_=slice(1, 2, 1.5))),
            (IndexError,
             lambda: ax._get_array_slices(slice_=slice(225, -1.1, 1))),
            (IndexError,
             lambda: ax._get_array_slices(slice_=slice(225.1, 0, 1))),
        )
        for expected_exception, trigger in expectations:
            with pytest.raises(expected_exception):
                trigger()

    def test_update_from(self):
        """update_from copies the requested attributes from another axis."""
        source = DataAxis(units="plumage", name="parrot", axis=np.arange(16))
        self.axis.update_from(source, attributes=("units", "name"))
        assert self.axis.units == source.units
        assert self.axis.name == source.name

    def test_calibrate(self):
        """calibrate on a non-uniform axis raises a TypeError."""
        expected_message = "only for uniform axes"
        with pytest.raises(TypeError, match=expected_message):
            self.axis.calibrate(index_tuple=(0, 5), value_tuple=(11, 12))