Exemplo n.º 1
0
    def test_equals_and_identical(self):
        # Walks through variants of a reference array and states, for each
        # pair, whether equals() and identical() are expected to hold.
        def check(reference, candidate, equal, identical):
            # equals() is evaluated before identical(), one assert each.
            (self.assertTrue if equal else self.assertFalse)(
                reference.equals(candidate))
            (self.assertTrue if identical else self.assertFalse)(
                reference.identical(candidate))

        base = DataArray(np.arange(5.0), {'a': 42}, dims='x')

        # A plain copy matches in both senses.
        check(base, base.copy(), True, True)

        # rename('baz') keeps the pair equal but not identical.
        check(base, base.rename('baz'), True, False)

        # Renaming 'x' to 'xxx' breaks both comparisons.
        check(base, base.rename({'x': 'xxx'}), False, False)

        # An extra attribute only affects identical().
        with_attr = base.copy()
        with_attr.attrs['foo'] = 'bar'
        check(base, with_attr, True, False)

        # Different values stored under 'x'.
        other_x = base.copy()
        other_x['x'] = ('x', -np.arange(5))
        check(base, other_x, False, False)

        # reset_coords(drop=True) gives an array that no longer matches.
        check(base, base.reset_coords(drop=True), False, False)

        # A NaN at the same position on both sides still compares as a match.
        with_nan = base.copy()
        with_nan[0] = np.nan
        nan_reference = with_nan.copy()
        check(nan_reference, with_nan, True, True)

        # All-NaN data differs from the single-NaN reference.
        with_nan[:] = np.nan
        check(nan_reference, with_nan, False, False)

        # A different value stored under 'a'.
        other_a = nan_reference.copy()
        other_a['a'] = 100000
        check(nan_reference, other_a, False, False)
Exemplo n.º 2
0
    def test_equals_and_identical(self):
        # Walks through variants of a reference array and states, for each
        # pair, whether equals() and identical() are expected to hold.
        def check(reference, candidate, equal, identical):
            # equals() is evaluated before identical(), one assert each.
            (self.assertTrue if equal else self.assertFalse)(
                reference.equals(candidate))
            (self.assertTrue if identical else self.assertFalse)(
                reference.identical(candidate))

        base = DataArray(np.arange(5.0), {'a': 42}, dims='x')

        # A plain copy matches in both senses.
        check(base, base.copy(), True, True)

        # rename('baz') keeps the pair equal but not identical.
        check(base, base.rename('baz'), True, False)

        # Renaming 'x' to 'xxx' breaks both comparisons.
        check(base, base.rename({'x': 'xxx'}), False, False)

        # An extra attribute only affects identical().
        with_attr = base.copy()
        with_attr.attrs['foo'] = 'bar'
        check(base, with_attr, True, False)

        # Different values stored under 'x'.
        other_x = base.copy()
        other_x['x'] = ('x', -np.arange(5))
        check(base, other_x, False, False)

        # reset_coords(drop=True) gives an array that no longer matches.
        check(base, base.reset_coords(drop=True), False, False)

        # A NaN at the same position on both sides still compares as a match.
        with_nan = base.copy()
        with_nan[0] = np.nan
        nan_reference = with_nan.copy()
        check(nan_reference, with_nan, True, True)

        # All-NaN data differs from the single-NaN reference.
        with_nan[:] = np.nan
        check(nan_reference, with_nan, False, False)

        # A different value stored under 'a'.
        other_a = nan_reference.copy()
        other_a['a'] = 100000
        check(nan_reference, other_a, False, False)
Exemplo n.º 3
0
    def test_reset_coords(self):
        # Covers reset_coords() called with no names, a list of names, a
        # single name, drop=True, inplace=True, and the error cases.
        def zeros():
            return np.zeros((3, 4))

        def bar():
            return ('x', ['a', 'b', 'c'])

        def baz():
            return ('y', range(4))

        data = DataArray(zeros(), {'bar': bar(), 'baz': baz()},
                         dims=['x', 'y'], name='foo')

        # No argument and the full list of names give the same Dataset.
        all_reset = Dataset({
            'foo': (['x', 'y'], zeros()),
            'bar': bar(),
            'baz': baz()
        })
        self.assertDatasetIdentical(data.reset_coords(), all_reset)
        self.assertDatasetIdentical(data.reset_coords(['bar', 'baz']),
                                    all_reset)

        # Resetting only 'bar', given as a string or as a one-item list,
        # leaves 'baz' in the second mapping passed to Dataset.
        bar_reset = Dataset({'foo': (['x', 'y'], zeros()), 'bar': bar()},
                            {'baz': baz()})
        for names in ('bar', ['bar']):
            self.assertDatasetIdentical(data.reset_coords(names), bar_reset)

        # drop=True yields a DataArray carrying neither 'bar' nor 'baz',
        # both as a returned object and when applied in place to a copy.
        stripped = DataArray(zeros(), dims=['x', 'y'], name='foo')
        self.assertDataArrayIdentical(data.reset_coords(drop=True), stripped)

        in_place = data.copy()
        in_place.reset_coords(drop=True, inplace=True)
        self.assertDataArrayIdentical(in_place, stripped)

        # Dropping only 'bar' keeps 'baz'.
        only_baz = DataArray(zeros(), {'baz': baz()},
                             dims=['x', 'y'], name='foo')
        self.assertDataArrayIdentical(data.reset_coords('bar', drop=True),
                                      only_baz)

        # Invalid requests.
        with self.assertRaisesRegexp(ValueError, 'cannot reset coord'):
            data.reset_coords(inplace=True)
        with self.assertRaises(KeyError):
            data.reset_coords('foo', drop=True)
        with self.assertRaisesRegexp(ValueError, 'cannot be found'):
            data.reset_coords('not_found')
        with self.assertRaisesRegexp(ValueError, 'cannot remove index'):
            data.reset_coords('y')
Exemplo n.º 4
0
    def test_drop_coordinates(self):
        # Dropping an added coordinate restores the original array; invalid
        # drop targets raise ValueError.
        plain = DataArray(np.random.randn(2, 3), dims=['x', 'y'])
        with_z = plain.copy()
        with_z.coords['z'] = 2
        self.assertDataArrayIdentical(plain, with_z.drop('z'))

        with self.assertRaises(ValueError):
            with_z.drop('not found')

        # Neither None nor the name given via rename() can be dropped.
        cannot_drop = ((with_z, None), (with_z.rename('foo'), 'foo'))
        for array, label in cannot_drop:
            with self.assertRaisesRegexp(ValueError, 'cannot drop'):
                array.drop(label)
Exemplo n.º 5
0
    def test_reset_coords(self):
        # Covers reset_coords() called with no names, a list of names, a
        # single name, drop=True, inplace=True, and the error cases.
        def zeros():
            return np.zeros((3, 4))

        def bar():
            return ('x', ['a', 'b', 'c'])

        def baz():
            return ('y', range(4))

        data = DataArray(zeros(), {'bar': bar(), 'baz': baz()},
                         dims=['x', 'y'], name='foo')

        # No argument and the full list of names give the same Dataset.
        all_reset = Dataset({
            'foo': (['x', 'y'], zeros()),
            'bar': bar(),
            'baz': baz()
        })
        self.assertDatasetIdentical(data.reset_coords(), all_reset)
        self.assertDatasetIdentical(data.reset_coords(['bar', 'baz']),
                                    all_reset)

        # Resetting only 'bar', given as a string or as a one-item list,
        # leaves 'baz' in the second mapping passed to Dataset.
        bar_reset = Dataset({'foo': (['x', 'y'], zeros()), 'bar': bar()},
                            {'baz': baz()})
        for names in ('bar', ['bar']):
            self.assertDatasetIdentical(data.reset_coords(names), bar_reset)

        # drop=True yields a DataArray carrying neither 'bar' nor 'baz',
        # both as a returned object and when applied in place to a copy.
        stripped = DataArray(zeros(), dims=['x', 'y'], name='foo')
        self.assertDataArrayIdentical(data.reset_coords(drop=True), stripped)

        in_place = data.copy()
        in_place.reset_coords(drop=True, inplace=True)
        self.assertDataArrayIdentical(in_place, stripped)

        # Dropping only 'bar' keeps 'baz'.
        only_baz = DataArray(zeros(), {'baz': baz()},
                             dims=['x', 'y'], name='foo')
        self.assertDataArrayIdentical(data.reset_coords('bar', drop=True),
                                      only_baz)

        # Invalid requests.
        with self.assertRaisesRegexp(ValueError, 'cannot reset coord'):
            data.reset_coords(inplace=True)
        with self.assertRaises(KeyError):
            data.reset_coords('foo', drop=True)
        with self.assertRaisesRegexp(ValueError, 'cannot be found'):
            data.reset_coords('not_found')
        with self.assertRaisesRegexp(ValueError, 'cannot remove index'):
            data.reset_coords('y')
Exemplo n.º 6
0
class Common2dMixin:
    """
    Test cases shared by every 2d plotting function.

    The concrete test class is expected to provide the function under test
    as the staticmethod `self.plotfunc`; `setUp` looks up the attribute with
    the same name on `self.darray.plot` and stores it as `self.plotmethod`.
    """
    def setUp(self):
        # A fixed seed keeps the 10 x 15 sample data reproducible.
        values = np.random.RandomState(123).randn(10, 15)
        self.darray = DataArray(values, dims=['y', 'x'])
        method_name = self.plotfunc.__name__
        self.plotmethod = getattr(self.darray.plot, method_name)

    def test_label_names(self):
        # Axis labels are expected to be the dimension names.
        self.plotmethod()
        axes = plt.gca()
        self.assertEqual('x', axes.get_xlabel())
        self.assertEqual('y', axes.get_ylabel())

    def test_1d_raises_valueerror(self):
        # A single row of the sample data must be rejected.
        with self.assertRaisesRegexp(ValueError, r'[Dd]im'):
            single_row = self.darray[0, :]
            self.plotfunc(single_row)

    def test_3d_raises_valueerror(self):
        # A three-dimensional array must be rejected as well.
        cube = DataArray(np.random.randn(2, 3, 4))
        with self.assertRaisesRegexp(ValueError, r'[Dd]im'):
            self.plotfunc(cube)

    def test_nonnumeric_index_raises_typeerror(self):
        # String coordinate values are expected to raise TypeError.
        labelled = DataArray(np.random.randn(3, 2),
                             coords=[['a', 'b', 'c'], ['d', 'e']])
        with self.assertRaisesRegexp(TypeError, r'[Pp]lot'):
            self.plotfunc(labelled)

    def test_can_pass_in_axis(self):
        # pass_in_axis is supplied by the concrete test class.
        self.pass_in_axis(self.plotmethod)

    def test_xyincrease_false_changes_axes(self):
        # Limits should be within 1 of (14, 0) for x and (9, 0) for y.
        self.plotmethod(xincrease=False, yincrease=False)
        axes = plt.gca()
        x_first, x_last = axes.get_xlim()
        y_first, y_last = axes.get_ylim()
        offsets = [x_first - 14, x_last - 0, y_first - 9, y_last - 0]
        self.assertTrue(all(abs(off) < 1 for off in offsets))

    def test_xyincrease_true_changes_axes(self):
        # Limits should be within 1 of (0, 14) for x and (0, 9) for y.
        self.plotmethod(xincrease=True, yincrease=True)
        axes = plt.gca()
        x_first, x_last = axes.get_xlim()
        y_first, y_last = axes.get_ylim()
        offsets = [x_first - 0, x_last - 14, y_first - 0, y_last - 9]
        self.assertTrue(all(abs(off) < 1 for off in offsets))

    def test_plot_nans(self):
        # Replacing rows 5 onwards with NaN must give the same color limits
        # as plotting only the first five rows.
        first_rows = self.darray[:5]
        nan_padded = self.darray.copy()
        nan_padded[5:] = np.nan

        limits_first = self.plotfunc(first_rows).get_clim()
        limits_padded = self.plotfunc(nan_padded).get_clim()
        self.assertEqual(limits_first, limits_padded)

    def test_viridis_cmap(self):
        artist = self.plotmethod(cmap='viridis')
        self.assertEqual('viridis', artist.get_cmap().name)

    def test_default_cmap(self):
        # The sample data defaults to 'RdBu_r' ...
        default_artist = self.plotmethod()
        self.assertEqual('RdBu_r', default_artist.get_cmap().name)

        # ... while its absolute value defaults to 'viridis'.
        abs_artist = self.plotfunc(abs(self.darray))
        self.assertEqual('viridis', abs_artist.get_cmap().name)

    def test_can_change_default_cmap(self):
        artist = self.plotmethod(cmap='Blues')
        self.assertEqual('Blues', artist.get_cmap().name)

    def test_diverging_color_limits(self):
        # Color limits are expected to be symmetric around zero.
        low, high = self.plotmethod().get_clim()
        self.assertAlmostEqual(-low, high)

    def test_default_title(self):
        # The title should mention the selected 'c' index and the value
        # stored under 'd'.
        cube = DataArray(np.random.randn(4, 3, 2), dims=['a', 'b', 'c'])
        cube.coords['d'] = 10
        self.plotfunc(cube.isel(c=1))
        self.assertEqual('c = 1, d = 10', plt.gca().get_title())

    def test_colorbar_label(self):
        # The array name has to appear among the figure's text objects.
        self.darray.name = 'testvar'
        self.plotmethod()
        text_objects = plt.gcf().findobj(mpl.text.Text)
        # set() around a generator: set comprehensions break Python 2.6
        figure_text = set(obj.get_text() for obj in text_objects)
        self.assertIn(self.darray.name, figure_text)

    def test_no_labels(self):
        # With add_labels=False none of these strings may be drawn.
        self.darray.name = 'testvar'
        self.plotmethod(add_labels=False)
        text_objects = plt.gcf().findobj(mpl.text.Text)
        figure_text = set(obj.get_text() for obj in text_objects)
        for unwanted in ['x', 'y', 'testvar']:
            self.assertNotIn(unwanted, figure_text)
Exemplo n.º 7
0
class Common2dMixin:
    """
    Common tests for 2d plotting go here.

    These tests assume that a staticmethod for `self.plotfunc` exists.
    Should have the same name as the method.
    """

    def setUp(self):
        # Sample 10 x 15 array with dims ('y', 'x'). easy_array is defined
        # elsewhere; presumably start=-1 makes the data span negative and
        # positive values, which the default-colormap tests below appear to
        # rely on -- TODO confirm against easy_array.
        self.darray = DataArray(easy_array(
            (10, 15), start=-1), dims=['y', 'x'])
        # Plot method carrying the same name as the function under test.
        self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__)

    def test_label_names(self):
        # Axis labels are expected to be the dimension names.
        self.plotmethod()
        self.assertEqual('x', plt.gca().get_xlabel())
        self.assertEqual('y', plt.gca().get_ylabel())

    def test_1d_raises_valueerror(self):
        # A single row of the sample data must be rejected.
        with self.assertRaisesRegexp(ValueError, r'[Dd]im'):
            self.plotfunc(self.darray[0, :])

    def test_3d_raises_valueerror(self):
        # A three-dimensional array must be rejected as well.
        a = DataArray(easy_array((2, 3, 4)))
        with self.assertRaisesRegexp(ValueError, r'[Dd]im'):
            self.plotfunc(a)

    def test_nonnumeric_index_raises_typeerror(self):
        # String coordinate values are expected to raise TypeError.
        a = DataArray(easy_array((3, 2)),
                      coords=[['a', 'b', 'c'], ['d', 'e']])
        with self.assertRaisesRegexp(TypeError, r'[Pp]lot'):
            self.plotfunc(a)

    def test_can_pass_in_axis(self):
        # pass_in_axis is supplied by the concrete test class.
        self.pass_in_axis(self.plotmethod)

    def test_xyincrease_false_changes_axes(self):
        # Limits should be within 1 of (14, 0) for x and (9, 0) for y.
        self.plotmethod(xincrease=False, yincrease=False)
        xlim = plt.gca().get_xlim()
        ylim = plt.gca().get_ylim()
        diffs = xlim[0] - 14, xlim[1] - 0, ylim[0] - 9, ylim[1] - 0
        self.assertTrue(all(abs(x) < 1 for x in diffs))

    def test_xyincrease_true_changes_axes(self):
        # Limits should be within 1 of (0, 14) for x and (0, 9) for y.
        self.plotmethod(xincrease=True, yincrease=True)
        xlim = plt.gca().get_xlim()
        ylim = plt.gca().get_ylim()
        diffs = xlim[0] - 0, xlim[1] - 14, ylim[0] - 0, ylim[1] - 9
        self.assertTrue(all(abs(x) < 1 for x in diffs))

    def test_plot_nans(self):
        # Replacing rows 5 onwards with NaN must give the same color limits
        # as plotting only the first five rows.
        x1 = self.darray[:5]
        x2 = self.darray.copy()
        x2[5:] = np.nan

        clim1 = self.plotfunc(x1).get_clim()
        clim2 = self.plotfunc(x2).get_clim()
        self.assertEqual(clim1, clim2)

    def test_viridis_cmap(self):
        cmap_name = self.plotmethod(cmap='viridis').get_cmap().name
        self.assertEqual('viridis', cmap_name)

    def test_default_cmap(self):
        # The sample data defaults to 'RdBu_r' ...
        cmap_name = self.plotmethod().get_cmap().name
        self.assertEqual('RdBu_r', cmap_name)

        # ... while its absolute value defaults to 'viridis'.
        cmap_name = self.plotfunc(abs(self.darray)).get_cmap().name
        self.assertEqual('viridis', cmap_name)

    def test_seaborn_palette_as_cmap(self):
        # Only the import is guarded: an ImportError raised by the plotting
        # call itself must surface as a failure instead of being swallowed.
        try:
            import seaborn  # noqa: F401 -- imported to detect availability
        except ImportError:
            # seaborn is optional; without it there is nothing to check.
            return
        cmap_name = self.plotmethod(
            levels=2, cmap='husl').get_cmap().name
        self.assertEqual('husl', cmap_name)

    def test_can_change_default_cmap(self):
        cmap_name = self.plotmethod(cmap='Blues').get_cmap().name
        self.assertEqual('Blues', cmap_name)

    def test_diverging_color_limits(self):
        # Color limits are expected to be symmetric around zero.
        artist = self.plotmethod()
        vmin, vmax = artist.get_clim()
        self.assertAlmostEqual(-vmin, vmax)

    def test_xy_strings(self):
        # Passing 'y' then 'x' positionally puts 'y' on the x axis.
        self.plotmethod('y', 'x')
        ax = plt.gca()
        self.assertEqual('y', ax.get_xlabel())
        self.assertEqual('x', ax.get_ylabel())

    def test_positional_x_string(self):
        # Giving only 'y' positionally leaves 'x' for the y axis.
        self.plotmethod('y')
        ax = plt.gca()
        self.assertEqual('y', ax.get_xlabel())
        self.assertEqual('x', ax.get_ylabel())

    def test_y_string(self):
        # Giving only y='x' leaves 'y' for the x axis.
        self.plotmethod(y='x')
        ax = plt.gca()
        self.assertEqual('y', ax.get_xlabel())
        self.assertEqual('x', ax.get_ylabel())

    def test_bad_x_string_exception(self):
        # An unknown name raises KeyError with a message matching 'y'.
        with self.assertRaisesRegexp(KeyError, r'y'):
            self.plotmethod('not_a_real_dim')

        # So does 'z' once a single value has been stored under it.
        self.darray.coords['z'] = 100
        with self.assertRaisesRegexp(KeyError, r'y'):
            self.plotmethod('z')

    def test_default_title(self):
        # The title should mention the selected 'c' index and the value
        # stored under 'd'.
        a = DataArray(easy_array((4, 3, 2)), dims=['a', 'b', 'c'])
        a.coords['d'] = u'foo'
        self.plotfunc(a.isel(c=1))
        title = plt.gca().get_title()
        self.assertEqual('c = 1, d = foo', title)

    def test_colorbar_label(self):
        # The array name has to appear in the text returned by text_in_fig().
        self.darray.name = 'testvar'
        self.plotmethod()
        self.assertIn(self.darray.name, text_in_fig())

    def test_no_labels(self):
        # With add_labels=False none of these strings may appear.
        self.darray.name = 'testvar'
        self.plotmethod(add_labels=False)
        alltxt = text_in_fig()
        for string in ['x', 'y', 'testvar']:
            self.assertNotIn(string, alltxt)

    def test_verbose_facetgrid(self):
        # Build the FacetGrid explicitly and map the plot function over it;
        # every axes in the grid must end up with data.
        a = easy_array((10, 15, 3))
        d = DataArray(a, dims=['y', 'x', 'z'])
        g = xplt.FacetGrid(d, col='z')
        g.map_dataarray(self.plotfunc, 'x', 'y')
        for ax in g.axes.flat:
            self.assertTrue(ax.has_data())

    @incompatible_2_6
    def test_2d_function_and_method_signature_same(self):
        # Function and method must resolve to the same call arguments once
        # the entries for the array / plot-methods object are removed.
        func_sig = inspect.getcallargs(self.plotfunc, self.darray)
        method_sig = inspect.getcallargs(self.plotmethod)
        del method_sig['_PlotMethods_obj']
        del func_sig['darray']
        self.assertEqual(func_sig, method_sig)

    def test_convenient_facetgrid(self):
        # Four 'z' entries with col_wrap=2 should give a 2 x 2 grid.
        a = easy_array((10, 15, 4))
        d = DataArray(a, dims=['y', 'x', 'z'])
        g = self.plotfunc(d, x='x', y='y', col='z', col_wrap=2)

        self.assertArrayEqual(g.axes.shape, [2, 2])
        for ax in g.axes.flat:
            self.assertTrue(ax.has_data())

    def test_convenient_facetgrid_4d(self):
        # Three 'rows' by two 'columns' should give a 3 x 2 grid.
        a = easy_array((10, 15, 2, 3))
        d = DataArray(a, dims=['y', 'x', 'columns', 'rows'])
        g = self.plotfunc(d, x='x', y='y', col='columns', row='rows')

        self.assertArrayEqual(g.axes.shape, [3, 2])
        for ax in g.axes.flat:
            self.assertTrue(ax.has_data())
Exemplo n.º 8
0
class Common2dMixin:
    """
    Common tests for 2d plotting go here.

    These tests assume that a staticmethod for `self.plotfunc` exists.
    Should have the same name as the method.
    """

    def setUp(self):
        # Sample 10 x 15 array with dims ('y', 'x'). easy_array is defined
        # elsewhere; presumably start=-1 makes the data span negative and
        # positive values, which the default-colormap tests below appear to
        # rely on -- TODO confirm against easy_array.
        self.darray = DataArray(easy_array(
            (10, 15), start=-1), dims=['y', 'x'])
        # Plot method carrying the same name as the function under test.
        self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__)

    def test_label_names(self):
        # Axis labels are expected to be the dimension names.
        self.plotmethod()
        self.assertEqual('x', plt.gca().get_xlabel())
        self.assertEqual('y', plt.gca().get_ylabel())

    def test_1d_raises_valueerror(self):
        # A single row of the sample data must be rejected.
        with self.assertRaisesRegexp(ValueError, r'[Dd]im'):
            self.plotfunc(self.darray[0, :])

    def test_3d_raises_valueerror(self):
        # A three-dimensional array must be rejected as well.
        a = DataArray(easy_array((2, 3, 4)))
        with self.assertRaisesRegexp(ValueError, r'[Dd]im'):
            self.plotfunc(a)

    def test_nonnumeric_index_raises_typeerror(self):
        # String coordinate values are expected to raise TypeError.
        a = DataArray(easy_array((3, 2)),
                      coords=[['a', 'b', 'c'], ['d', 'e']])
        with self.assertRaisesRegexp(TypeError, r'[Pp]lot'):
            self.plotfunc(a)

    def test_can_pass_in_axis(self):
        # pass_in_axis is supplied by the concrete test class.
        self.pass_in_axis(self.plotmethod)

    def test_xyincrease_false_changes_axes(self):
        # Limits should be within 1 of (14, 0) for x and (9, 0) for y.
        self.plotmethod(xincrease=False, yincrease=False)
        xlim = plt.gca().get_xlim()
        ylim = plt.gca().get_ylim()
        diffs = xlim[0] - 14, xlim[1] - 0, ylim[0] - 9, ylim[1] - 0
        self.assertTrue(all(abs(x) < 1 for x in diffs))

    def test_xyincrease_true_changes_axes(self):
        # Limits should be within 1 of (0, 14) for x and (0, 9) for y.
        self.plotmethod(xincrease=True, yincrease=True)
        xlim = plt.gca().get_xlim()
        ylim = plt.gca().get_ylim()
        diffs = xlim[0] - 0, xlim[1] - 14, ylim[0] - 0, ylim[1] - 9
        self.assertTrue(all(abs(x) < 1 for x in diffs))

    def test_plot_nans(self):
        # Replacing rows 5 onwards with NaN must give the same color limits
        # as plotting only the first five rows.
        x1 = self.darray[:5]
        x2 = self.darray.copy()
        x2[5:] = np.nan

        clim1 = self.plotfunc(x1).get_clim()
        clim2 = self.plotfunc(x2).get_clim()
        self.assertEqual(clim1, clim2)

    def test_viridis_cmap(self):
        cmap_name = self.plotmethod(cmap='viridis').get_cmap().name
        self.assertEqual('viridis', cmap_name)

    def test_default_cmap(self):
        # The sample data defaults to 'RdBu_r' ...
        cmap_name = self.plotmethod().get_cmap().name
        self.assertEqual('RdBu_r', cmap_name)

        # ... while its absolute value defaults to 'viridis'.
        cmap_name = self.plotfunc(abs(self.darray)).get_cmap().name
        self.assertEqual('viridis', cmap_name)

    def test_seaborn_palette_as_cmap(self):
        # Only the import is guarded: an ImportError raised by the plotting
        # call itself must surface as a failure instead of being swallowed.
        try:
            import seaborn  # noqa: F401 -- imported to detect availability
        except ImportError:
            # seaborn is optional; without it there is nothing to check.
            return
        cmap_name = self.plotmethod(
            levels=2, cmap='husl').get_cmap().name
        self.assertEqual('husl', cmap_name)

    def test_can_change_default_cmap(self):
        cmap_name = self.plotmethod(cmap='Blues').get_cmap().name
        self.assertEqual('Blues', cmap_name)

    def test_diverging_color_limits(self):
        # Color limits are expected to be symmetric around zero.
        artist = self.plotmethod()
        vmin, vmax = artist.get_clim()
        self.assertAlmostEqual(-vmin, vmax)

    def test_xy_strings(self):
        # Passing 'y' then 'x' positionally puts 'y' on the x axis.
        self.plotmethod('y', 'x')
        ax = plt.gca()
        self.assertEqual('y', ax.get_xlabel())
        self.assertEqual('x', ax.get_ylabel())

    def test_positional_x_string(self):
        # Giving only 'y' positionally leaves 'x' for the y axis.
        self.plotmethod('y')
        ax = plt.gca()
        self.assertEqual('y', ax.get_xlabel())
        self.assertEqual('x', ax.get_ylabel())

    def test_y_string(self):
        # Giving only y='x' leaves 'y' for the x axis.
        self.plotmethod(y='x')
        ax = plt.gca()
        self.assertEqual('y', ax.get_xlabel())
        self.assertEqual('x', ax.get_ylabel())

    def test_bad_x_string_exception(self):
        # An unknown name raises KeyError with a message matching 'y'.
        with self.assertRaisesRegexp(KeyError, r'y'):
            self.plotmethod('not_a_real_dim')

        # So does 'z' once a single value has been stored under it.
        self.darray.coords['z'] = 100
        with self.assertRaisesRegexp(KeyError, r'y'):
            self.plotmethod('z')

    def test_default_title(self):
        # The title should mention the selected 'c' index and the value
        # stored under 'd'.
        a = DataArray(easy_array((4, 3, 2)), dims=['a', 'b', 'c'])
        a.coords['d'] = u'foo'
        self.plotfunc(a.isel(c=1))
        title = plt.gca().get_title()
        self.assertEqual('c = 1, d = foo', title)

    def test_colorbar_label(self):
        # The array name has to appear in the text returned by text_in_fig().
        self.darray.name = 'testvar'
        self.plotmethod()
        self.assertIn(self.darray.name, text_in_fig())

    def test_no_labels(self):
        # With add_labels=False none of these strings may appear.
        self.darray.name = 'testvar'
        self.plotmethod(add_labels=False)
        alltxt = text_in_fig()
        for string in ['x', 'y', 'testvar']:
            self.assertNotIn(string, alltxt)

    def test_verbose_facetgrid(self):
        # Build the FacetGrid explicitly and map the plot function over it;
        # every axes in the grid must end up with data.
        a = easy_array((10, 15, 3))
        d = DataArray(a, dims=['y', 'x', 'z'])
        g = xplt.FacetGrid(d, col='z')
        g.map_dataarray(self.plotfunc, 'x', 'y')
        for ax in g.axes.flat:
            self.assertTrue(ax.has_data())

    @incompatible_2_6
    def test_2d_function_and_method_signature_same(self):
        # Function and method must resolve to the same call arguments once
        # the entries for the array / plot-methods object are removed.
        func_sig = inspect.getcallargs(self.plotfunc, self.darray)
        method_sig = inspect.getcallargs(self.plotmethod)
        del method_sig['_PlotMethods_obj']
        del func_sig['darray']
        self.assertEqual(func_sig, method_sig)

    def test_convenient_facetgrid(self):
        # Four 'z' entries with col_wrap=2 should give a 2 x 2 grid.
        a = easy_array((10, 15, 4))
        d = DataArray(a, dims=['y', 'x', 'z'])
        g = self.plotfunc(d, x='x', y='y', col='z', col_wrap=2)

        self.assertArrayEqual(g.axes.shape, [2, 2])
        for ax in g.axes.flat:
            self.assertTrue(ax.has_data())

    def test_convenient_facetgrid_4d(self):
        # Three 'rows' by two 'columns' should give a 3 x 2 grid.
        a = easy_array((10, 15, 2, 3))
        d = DataArray(a, dims=['y', 'x', 'columns', 'rows'])
        g = self.plotfunc(d, x='x', y='y', col='columns', row='rows')

        self.assertArrayEqual(g.axes.shape, [3, 2])
        for ax in g.axes.flat:
            self.assertTrue(ax.has_data())

    def test_facetgrid_cmap(self):
        # Regression test for GH592
        # NOTE(review): this always plots with pcolormesh rather than
        # self.plotfunc -- confirm that is intentional.
        data = (np.random.random(size=(20, 25, 12)) + np.linspace(-3, 3, 12))
        d = DataArray(data, dims=['x', 'y', 'time'])
        fg = d.plot.pcolormesh(col='time')
        # check that all color limits are the same
        # (assertEqual reports the actual count when the check fails)
        self.assertEqual(len(set(m.get_clim() for m in fg._mappables)), 1)
        # check that all colormaps are the same
        self.assertEqual(
            len(set(m.get_cmap().name for m in fg._mappables)), 1)
Exemplo n.º 9
0
class Common2dMixin:
    """
    Test cases shared by every 2d plotting function.

    The concrete test class is expected to provide the function under test
    as the staticmethod `self.plotfunc`; `setUp` looks up the attribute with
    the same name on `self.darray.plot` and stores it as `self.plotmethod`.
    """
    def setUp(self):
        # A fixed seed keeps the 10 x 15 sample data reproducible.
        values = np.random.RandomState(123).randn(10, 15)
        self.darray = DataArray(values, dims=['y', 'x'])
        method_name = self.plotfunc.__name__
        self.plotmethod = getattr(self.darray.plot, method_name)

    def test_label_names(self):
        # Axis labels are expected to be the dimension names.
        self.plotmethod()
        axes = plt.gca()
        self.assertEqual('x', axes.get_xlabel())
        self.assertEqual('y', axes.get_ylabel())

    def test_1d_raises_valueerror(self):
        # A single row of the sample data must be rejected.
        with self.assertRaisesRegexp(ValueError, r'[Dd]im'):
            single_row = self.darray[0, :]
            self.plotfunc(single_row)

    def test_3d_raises_valueerror(self):
        # A three-dimensional array must be rejected as well.
        cube = DataArray(np.random.randn(2, 3, 4))
        with self.assertRaisesRegexp(ValueError, r'[Dd]im'):
            self.plotfunc(cube)

    def test_nonnumeric_index_raises_typeerror(self):
        # String coordinate values are expected to raise TypeError.
        labelled = DataArray(np.random.randn(3, 2),
                             coords=[['a', 'b', 'c'], ['d', 'e']])
        with self.assertRaisesRegexp(TypeError, r'[Pp]lot'):
            self.plotfunc(labelled)

    def test_can_pass_in_axis(self):
        # pass_in_axis is supplied by the concrete test class.
        self.pass_in_axis(self.plotmethod)

    def test_xyincrease_false_changes_axes(self):
        # Limits should be within 1 of (14, 0) for x and (9, 0) for y.
        self.plotmethod(xincrease=False, yincrease=False)
        axes = plt.gca()
        x_first, x_last = axes.get_xlim()
        y_first, y_last = axes.get_ylim()
        offsets = [x_first - 14, x_last - 0, y_first - 9, y_last - 0]
        self.assertTrue(all(abs(off) < 1 for off in offsets))

    def test_xyincrease_true_changes_axes(self):
        # Limits should be within 1 of (0, 14) for x and (0, 9) for y.
        self.plotmethod(xincrease=True, yincrease=True)
        axes = plt.gca()
        x_first, x_last = axes.get_xlim()
        y_first, y_last = axes.get_ylim()
        offsets = [x_first - 0, x_last - 14, y_first - 0, y_last - 9]
        self.assertTrue(all(abs(off) < 1 for off in offsets))

    def test_plot_nans(self):
        # Replacing rows 5 onwards with NaN must give the same color limits
        # as plotting only the first five rows.
        first_rows = self.darray[:5]
        nan_padded = self.darray.copy()
        nan_padded[5:] = np.nan

        limits_first = self.plotfunc(first_rows).get_clim()
        limits_padded = self.plotfunc(nan_padded).get_clim()
        self.assertEqual(limits_first, limits_padded)

    def test_viridis_cmap(self):
        artist = self.plotmethod(cmap='viridis')
        self.assertEqual('viridis', artist.get_cmap().name)

    def test_default_cmap(self):
        # The sample data defaults to 'RdBu_r' ...
        default_artist = self.plotmethod()
        self.assertEqual('RdBu_r', default_artist.get_cmap().name)

        # ... while its absolute value defaults to 'viridis'.
        abs_artist = self.plotfunc(abs(self.darray))
        self.assertEqual('viridis', abs_artist.get_cmap().name)

    def test_can_change_default_cmap(self):
        artist = self.plotmethod(cmap='Blues')
        self.assertEqual('Blues', artist.get_cmap().name)

    def test_diverging_color_limits(self):
        # Color limits are expected to be symmetric around zero.
        low, high = self.plotmethod().get_clim()
        self.assertAlmostEqual(-low, high)

    def test_default_title(self):
        # The title should mention the selected 'c' index and the value
        # stored under 'd'.
        cube = DataArray(np.random.randn(4, 3, 2), dims=['a', 'b', 'c'])
        cube.coords['d'] = 10
        self.plotfunc(cube.isel(c=1))
        self.assertEqual('c = 1, d = 10', plt.gca().get_title())

    def test_colorbar_label(self):
        # The array name has to appear among the figure's text objects.
        self.darray.name = 'testvar'
        self.plotmethod()
        text_objects = plt.gcf().findobj(mpl.text.Text)
        # set() around a generator: set comprehensions break Python 2.6
        figure_text = set(obj.get_text() for obj in text_objects)
        self.assertIn(self.darray.name, figure_text)

    def test_no_labels(self):
        # With add_labels=False none of these strings may be drawn.
        self.darray.name = 'testvar'
        self.plotmethod(add_labels=False)
        text_objects = plt.gcf().findobj(mpl.text.Text)
        figure_text = set(obj.get_text() for obj in text_objects)
        for unwanted in ['x', 'y', 'testvar']:
            self.assertNotIn(unwanted, figure_text)
Exemplo n.º 10
0
class Common2dMixin:
    """
    Test cases shared by every 2d plotting function.

    The concrete test class is expected to provide the function under test
    as the staticmethod `self.plotfunc`; `setUp` looks up the attribute with
    the same name on `self.darray.plot` and stores it as `self.plotmethod`.
    """
    def setUp(self):
        # A fixed seed keeps the 10 x 15 sample data reproducible.
        values = np.random.RandomState(123).randn(10, 15)
        self.darray = DataArray(values, dims=['y', 'x'])
        method_name = self.plotfunc.__name__
        self.plotmethod = getattr(self.darray.plot, method_name)

    def test_label_names(self):
        # Axis labels are expected to be the dimension names.
        self.plotmethod()
        axes = plt.gca()
        self.assertEqual('x', axes.get_xlabel())
        self.assertEqual('y', axes.get_ylabel())

    def test_1d_raises_valueerror(self):
        # A single row of the sample data must be rejected.
        with self.assertRaisesRegexp(ValueError, r'[Dd]im'):
            single_row = self.darray[0, :]
            self.plotfunc(single_row)

    def test_3d_raises_valueerror(self):
        # A three-dimensional array must be rejected as well.
        cube = DataArray(np.random.randn(2, 3, 4))
        with self.assertRaisesRegexp(ValueError, r'[Dd]im'):
            self.plotfunc(cube)

    def test_nonnumeric_index_raises_typeerror(self):
        # String coordinate values are expected to raise TypeError.
        labelled = DataArray(np.random.randn(3, 2),
                             coords=[['a', 'b', 'c'], ['d', 'e']])
        with self.assertRaisesRegexp(TypeError, r'[Pp]lot'):
            self.plotfunc(labelled)

    def test_can_pass_in_axis(self):
        # pass_in_axis is supplied by the concrete test class.
        self.pass_in_axis(self.plotmethod)

    def test_xyincrease_false_changes_axes(self):
        # Limits should be within 1 of (14, 0) for x and (9, 0) for y.
        self.plotmethod(xincrease=False, yincrease=False)
        axes = plt.gca()
        x_first, x_last = axes.get_xlim()
        y_first, y_last = axes.get_ylim()
        offsets = [x_first - 14, x_last - 0, y_first - 9, y_last - 0]
        self.assertTrue(all(abs(off) < 1 for off in offsets))

    def test_xyincrease_true_changes_axes(self):
        # Limits should be within 1 of (0, 14) for x and (0, 9) for y.
        self.plotmethod(xincrease=True, yincrease=True)
        axes = plt.gca()
        x_first, x_last = axes.get_xlim()
        y_first, y_last = axes.get_ylim()
        offsets = [x_first - 0, x_last - 14, y_first - 0, y_last - 9]
        self.assertTrue(all(abs(off) < 1 for off in offsets))

    def test_plot_nans(self):
        # Replacing rows 5 onwards with NaN must give the same color limits
        # as plotting only the first five rows.
        first_rows = self.darray[:5]
        nan_padded = self.darray.copy()
        nan_padded[5:] = np.nan

        limits_first = self.plotfunc(first_rows).get_clim()
        limits_padded = self.plotfunc(nan_padded).get_clim()
        self.assertEqual(limits_first, limits_padded)

    def test_viridis_cmap(self):
        artist = self.plotmethod(cmap='viridis')
        self.assertEqual('viridis', artist.get_cmap().name)

    def test_default_cmap(self):
        # The sample data defaults to 'RdBu_r' ...
        default_artist = self.plotmethod()
        self.assertEqual('RdBu_r', default_artist.get_cmap().name)

        # ... while its absolute value defaults to 'viridis'.
        abs_artist = self.plotfunc(abs(self.darray))
        self.assertEqual('viridis', abs_artist.get_cmap().name)

    def test_can_change_default_cmap(self):
        artist = self.plotmethod(cmap='jet')
        self.assertEqual('jet', artist.get_cmap().name)

    def test_diverging_color_limits(self):
        # Color limits are expected to be symmetric around zero.
        low, high = self.plotmethod().get_clim()
        self.assertAlmostEqual(-low, high)