def test_lmplot_facets(self):
    """Check the shape of the axes array for several faceting layouts."""
    layouts = [
        # Row and column facets together: a two-dimensional grid is expected.
        (dict(row="g", col="h"), (3, 2)),
        # Wrapped columns: a flat, one-dimensional axes array is expected.
        (dict(col="u", col_wrap=4), (6,)),
        # Hue plus column facets: a single row of axes is expected.
        (dict(hue="h", col="u"), (1, 6)),
    ]
    for facet_vars, expected_shape in layouts:
        grid = lm.lmplot(x="x", y="y", data=self.df, **facet_vars)
        assert grid.axes.shape == expected_shape
def test_lmplot_markers(self):
    """Check how ``markers`` ends up in ``hue_kws`` and that a bad length raises.

    A single marker string is expected to be repeated for both hue levels,
    a two-item list is expected to be used as given, and a three-item list
    is expected to raise ``ValueError``.
    """
    shared = dict(x="x", y="y", data=self.df, hue="h")

    single = lm.lmplot(markers="s", **shared)
    assert single.hue_kws == {"marker": ["s", "s"]}

    paired = lm.lmplot(markers=["o", "s"], **shared)
    assert paired.hue_kws == {"marker": ["o", "s"]}

    with pytest.raises(ValueError):
        lm.lmplot(markers=["o", "s", "d"], **shared)
def test_lmplot_hue(self):
    """Check the artist counts on the single facet when a hue variable is set."""
    first_axes = lm.lmplot(x="x", y="y", data=self.df, hue="h").axes[0, 0]

    n_lines = len(first_axes.lines)
    assert n_lines == 2

    n_collections = len(first_axes.collections)
    assert n_collections == 4
def test_lmplot_scatter_kws(self):
    """Check that each hue level's scatter uses the matching palette color.

    The expected colors are the two entries of ``color_palette(n_colors=2)``;
    only the RGB part of each collection's first face color is compared.
    """
    grid = lm.lmplot(x="x", y="y", hue="h", data=self.df, ci=None)

    # With ``ci=None`` exactly two collections are expected on the axes;
    # the unpacking fails if there are more or fewer.
    first_points, second_points = grid.axes[0, 0].collections
    first_color, second_color = color_palette(n_colors=2)

    expectations = (
        (first_color, first_points),
        (second_color, second_points),
    )
    for expected_rgb, points in expectations:
        npt.assert_array_equal(expected_rgb, points.get_facecolors()[0, :3])
def test_lmplot_facet_kws(self):
    """Check that an ``xlim`` passed through ``facet_kws`` reaches every facet."""
    expected_limits = (-4, 20)
    grid = lm.lmplot(
        data=self.df,
        x="x",
        y="y",
        col="h",
        facet_kws=dict(xlim=expected_limits),
    )
    for facet_axes in grid.axes.flat:
        actual_limits = facet_axes.get_xlim()
        assert actual_limits == expected_limits
def test_lmplot_marker_linewidths(self):
    """Check the line width of the second scatter against the rcParams default.

    The plot uses the markers "o" and "+" with no regression fit; the second
    collection (presumably the "+" scatter) is expected to have a line width
    equal to ``mpl.rcParams["lines.linewidth"]``.
    """
    grid = lm.lmplot(
        x="x",
        y="y",
        data=self.df,
        hue="h",
        fit_reg=False,
        markers=["o", "+"],
    )
    second_points = grid.axes[0, 0].collections[1]
    drawn_width = second_points.get_linewidths()[0]
    assert drawn_width == mpl.rcParams["lines.linewidth"]
def test_lmplot_basic(self):
    """Check artist counts and plotted points for a plain ``lmplot`` call."""
    first_axes = lm.lmplot(x="x", y="y", data=self.df).axes[0, 0]

    assert len(first_axes.lines) == 1
    assert len(first_axes.collections) == 2

    # The first collection's offsets must match the input columns exactly.
    offsets = first_axes.collections[0].get_offsets()
    plotted_x, plotted_y = offsets.T
    npt.assert_array_equal(plotted_x, self.df.x)
    npt.assert_array_equal(plotted_y, self.df.y)
def test_lmplot_sharey(self):
    """Check that ``sharey=False`` warns and gives each facet its own y range.

    The "a" group spans a much smaller y range than the "b" group, so the
    first facet's limits are expected to lie strictly inside the second's.
    """
    frame = pd.DataFrame({
        "x": [0, 1, 2, 0, 1, 2],
        "y": [1, -1, 0, -100, 200, 0],
        "z": ["a", "a", "a", "b", "b", "b"],
    })

    # Passing ``sharey`` directly is expected to emit a UserWarning.
    with pytest.warns(UserWarning):
        grid = lm.lmplot(data=frame, x="x", y="y", col="z", sharey=False)

    narrow_axes, wide_axes = grid.axes.flat
    narrow_low, narrow_high = narrow_axes.get_ylim()
    wide_low, wide_high = wide_axes.get_ylim()
    assert narrow_low > wide_low
    assert narrow_high < wide_high
def test_lmplot_facet_truncate(self, sharex):
    """Check that untruncated lines span each facet's full x limits.

    ``sharex`` is supplied from outside this method (its source is not
    visible here) and is forwarded to the grid through ``facet_kws``.
    """
    grid = lm.lmplot(
        data=self.df,
        x="x",
        y="y",
        hue="g",
        col="h",
        truncate=False,
        facet_kws={"sharex": sharex},
    )

    axes_and_lines = (
        (facet_axes, fit_line)
        for facet_axes in grid.axes.flat
        for fit_line in facet_axes.lines
    )
    for facet_axes, fit_line in axes_and_lines:
        # Compare the first and last x values of the line to the axes limits.
        endpoints = fit_line.get_xdata()[[0, -1]]
        assert facet_axes.get_xlim() == tuple(endpoints)
def test_lmplot_hue_col_nolegend(self):
    """Check that no legend is made when ``hue`` and ``col`` share a variable."""
    shared_variable = "h"
    grid = lm.lmplot(
        x="x",
        y="y",
        data=self.df,
        col=shared_variable,
        hue=shared_variable,
    )
    assert grid._legend is None
def test_lmplot_no_data(self):
    """Check that calling ``lmplot`` without ``data`` raises ``TypeError``."""
    variables_only = dict(x="x", y="y")
    # The keyword argument `data` is required, so omitting it must fail.
    with pytest.raises(TypeError):
        lm.lmplot(**variables_only)