def test_flux(self):
    """Plot the reactive flux of the fixture system and check the layout.

    Verifies that ``plot_flux`` returns a matplotlib Figure and that the
    x coordinate of every node equals that state's committor value.
    """
    import matplotlib.pyplot as plt

    r = tpt(self.P, self.A, self.B)
    fig, pos = plot_flux(r)
    # Close the figure when the test finishes so figures created by the
    # plotting code do not accumulate over the test suite.
    self.addCleanup(plt.close, fig)
    self.assertIsInstance(fig, matplotlib.figure.Figure)
    # x values should be close to the committor
    np.testing.assert_allclose(pos[:, 0], r.committor)
def test_state_labels_flux_auto(self):
    """Ensure auto-generated labels show up in the plot.

    With ``state_labels='auto'`` every state in the source set A must be
    drawn with the text "A" and every state in the sink set B with "B".
    """
    import matplotlib.pyplot as plt

    A = [0, 1]
    B = [2, 4]
    flux = tpt(self.msm, A, B)
    fig, pos = plot_flux(flux, state_labels='auto')
    # Close the figure when the test finishes so figures do not pile up.
    self.addCleanup(plt.close, fig)
    labels_in_fig = np.array([text.get_text() for text in fig.axes[0].texts])
    self.assertEqual((labels_in_fig == "A").sum(), len(A))
    self.assertEqual((labels_in_fig == "B").sum(), len(B))
def test_state_labels_flux(self):
    """Ensure user-supplied state labels show up in the plot.

    Each of the five labels passed via ``state_labels`` must appear
    exactly once among the text objects of the first axes.
    """
    import matplotlib.pyplot as plt

    flux = tpt(self.msm, [0, 1], [2, 4])
    labels = ['foo', '0', '1', '2', 'bar']
    fig, pos = plot_flux(flux, state_labels=labels)
    # Close the figure when the test finishes so figures do not pile up.
    self.addCleanup(plt.close, fig)
    labels_in_fig = np.array([text.get_text() for text in fig.axes[0].texts])
    for label in labels:
        self.assertEqual((labels_in_fig == label).sum(), 1)
def test_random(self):
    """Smoke-test ``plot_flux`` on a reversible 10-state random system.

    A count matrix is drawn from a locally seeded generator so the test
    is reproducible and does not disturb NumPy's global random state.
    Flux goes from the first to the last state.
    """
    import matplotlib.pyplot as plt

    rng = np.random.RandomState(42)
    C = rng.randint(0, 1000, size=(10, 10))
    P = msmtools.estimation.transition_matrix(C, reversible=True)
    r = tpt(P, [0], [len(C) - 1])
    fig, pos = plot_flux(r)
    # Close the figure when the test finishes so figures do not pile up.
    self.addCleanup(plt.close, fig)
    self.assertIsInstance(fig, matplotlib.figure.Figure)
    # One node position per state.
    self.assertEqual(pos.shape[0], len(C))