Example #1
0
class PDPlotterTest(unittest.TestCase):
    def setUp(self):
        """Create the phase diagrams and plotters shared by the tests.

        Entries are read from ``pdentries_test.csv`` and used to build three
        ``PhaseDiagram``/``PDPlotter`` pairs:

        * ``pd`` / ``plotter``: all entries from the CSV file;
        * ``pd_formation`` / ``plotter_formation``: only the entries whose
          composition has fewer than three members and does not contain Fe;
        * ``pd3d`` / ``plotter3d``: all CSV entries plus one extra
          ``PDEntry("C", 0)``.

        Every plotter is created with ``show_unstable=True``.
        """
        csv_path = os.path.join(module_dir, "pdentries_test.csv")
        # The first item returned by from_csv is not needed here.
        _, entries = PDEntryIO.from_csv(csv_path)

        self.pd = PhaseDiagram(entries)
        self.plotter = PDPlotter(self.pd, show_unstable=True)

        fe_free_entries = []
        for entry in entries:
            if len(entry.composition) < 3 and "Fe" not in entry.composition:
                fe_free_entries.append(entry)
        self.pd_formation = PhaseDiagram(fe_free_entries)
        self.plotter_formation = PDPlotter(
            self.pd_formation, show_unstable=True)

        # The entry list is extended only after the first two diagrams have
        # been built from it.
        entries.append(PDEntry("C", 0))
        self.pd3d = PhaseDiagram(entries)
        self.plotter3d = PDPlotter(self.pd3d, show_unstable=True)

    def test_pd_plot_data(self):
        (lines, labels, unstable_entries) = self.plotter.pd_plot_data
        self.assertEqual(len(lines), 22)
        self.assertEqual(len(labels), len(self.pd.stable_entries),
                         "Incorrect number of lines generated!")
        self.assertEqual(
            len(unstable_entries),
            len(self.pd.all_entries) - len(self.pd.stable_entries),
            "Incorrect number of lines generated!")
        (lines, labels, unstable_entries) = self.plotter3d.pd_plot_data
        self.assertEqual(len(lines), 33)
        self.assertEqual(len(labels), len(self.pd3d.stable_entries))
        self.assertEqual(
            len(unstable_entries),
            len(self.pd3d.all_entries) - len(self.pd3d.stable_entries))
        (lines, labels, unstable_entries) = self.plotter_formation.pd_plot_data
        self.assertEqual(len(lines), 3)
        self.assertEqual(len(labels), len(self.pd_formation.stable_entries))

    # Skipped (with reason "Need display") when the DISPLAY environment
    # variable is not set.
    @unittest.skipIf("DISPLAY" not in os.environ, "Need display")
    def test_get_plot(self):
        # Some very basic non-tests. Just to make sure the methods are callable.
        # Nothing is asserted about the returned objects; the test only fails
        # if one of the calls below raises.
        import matplotlib
        # Select the "pdf" backend before any plotting method is called.
        matplotlib.use("pdf")

        # Default plots for the plotters built from all CSV entries and from
        # the entry list extended with PDEntry("C", 0).
        self.plotter.get_plot()
        self.plotter3d.get_plot()
        # Same two plotters with an energy colormap, once with
        # process_attributes=True and once with process_attributes=False.
        self.plotter.get_plot(energy_colormap="Reds", process_attributes=True)
        # NOTE(review): `plt` is bound here and below but never read in the
        # lines visible in this view - confirm whether it is used later.
        plt = self.plotter3d.get_plot(energy_colormap="Reds",
                                      process_attributes=True)
        self.plotter.get_plot(energy_colormap="Reds", process_attributes=False)
        plt = self.plotter3d.get_plot(energy_colormap="Reds",
                                      process_attributes=False)
        # Chemical-potential range map requested for the Li and O elements.
        self.plotter.get_chempot_range_map_plot([Element("Li"), Element("O")])