Example #1
0
class PDPlotterTest(unittest.TestCase):

    def setUp(self):
        """Build the phase diagrams and plotters shared by the tests.

        * ``self.pd`` / ``self.plotter``: every entry read from
          ``pdentries_test.csv``.
        * ``self.pd_formation`` / ``self.plotter_formation``: the subset of
          those entries with ``len(e.composition) < 3`` that do not contain
          ``"Fe"``.
        * ``self.pd3d`` / ``self.plotter3d``: all CSV entries plus an extra
          ``PDEntry("C", 0)``.

        Every plotter is created with ``show_unstable=True``.
        """
        # The element list returned alongside the entries is not needed here.
        _elements, entries = PDEntryIO.from_csv(
            os.path.join(module_dir, "pdentries_test.csv"))
        self.pd = PhaseDiagram(entries)
        self.plotter = PDPlotter(self.pd, show_unstable=True)

        # presumably len(composition) counts distinct elements, i.e. this keeps
        # the iron-free unary/binary entries -- TODO confirm.
        entries_lio = [e for e in entries
                       if len(e.composition) < 3 and "Fe" not in e.composition]
        self.pd_formation = PhaseDiagram(entries_lio)
        self.plotter_formation = PDPlotter(self.pd_formation,
                                           show_unstable=True)

        # Copy instead of appending to ``entries`` in place: that list was
        # already handed to ``PhaseDiagram`` for ``self.pd``, and mutating it
        # would leak the extra C entry into ``self.pd`` if the diagram keeps a
        # reference to its input.
        entries_3d = list(entries)
        entries_3d.append(PDEntry("C", 0))
        self.pd3d = PhaseDiagram(entries_3d)
        self.plotter3d = PDPlotter(self.pd3d, show_unstable=True)


    def test_pd_plot_data(self):
        (lines, labels, unstable_entries) = self.plotter.pd_plot_data
        self.assertEqual(len(lines), 22)
        self.assertEqual(len(labels), len(self.pd.stable_entries), "Incorrect number of lines generated!")
        self.assertEqual(len(unstable_entries), len(self.pd.all_entries) - len(self.pd.stable_entries), "Incorrect number of lines generated!")
        (lines, labels, unstable_entries) = self.plotter3d.pd_plot_data
        self.assertEqual(len(lines), 33)
        self.assertEqual(len(labels), len(self.pd3d.stable_entries))
        self.assertEqual(len(unstable_entries),
                         len(self.pd3d.all_entries) - len(self.pd3d.stable_entries))
        (lines, labels, unstable_entries) = self.plotter_formation.pd_plot_data
        self.assertEqual(len(lines), 3)
        self.assertEqual(len(labels), len(self.pd_formation.stable_entries))

    def test_get_plot(self):
        """Smoke-test the plotting entry points.

        These are not real assertions: the test only checks that each plotting
        method can be called without raising, for both the full and 3-D
        plotters and with ``process_attributes`` both on and off.
        """
        import matplotlib
        # Select the non-interactive PDF backend so no display is required.
        matplotlib.use("pdf")
        import matplotlib.pyplot as plt
        # Registered before any plotting so that figures created below are
        # closed even if a call raises; otherwise they pile up across runs.
        self.addCleanup(plt.close, "all")

        self.plotter.get_plot()
        self.plotter3d.get_plot()
        # Same call order as before: full then 3-D, first True then False.
        for process_attributes in (True, False):
            self.plotter.get_plot(energy_colormap="Reds",
                                  process_attributes=process_attributes)
            self.plotter3d.get_plot(energy_colormap="Reds",
                                    process_attributes=process_attributes)
        self.plotter.get_chempot_range_map_plot([Element("Li"), Element("O")])
        self.plotter.get_contour_pd_plot()