Example #1
0
class PDPlotterTest(unittest.TestCase):
    """Tests for ``PDPlotter`` plot data and plotting entry points.

    Three diagrams are built from ``pdentries_test.csv``: the full set of
    entries, a reduced set (``len(e.composition) < 3`` and no Fe), and the
    full set plus an extra ``PDEntry("C", 0)``.
    """

    def setUp(self):
        _, entries = PDEntryIO.from_csv(
            os.path.join(module_dir, "pdentries_test.csv"))
        self.pd = PhaseDiagram(entries)
        self.plotter = PDPlotter(self.pd, show_unstable=True)

        # Reduced entry set: fewer than three elements and no Fe.
        entrieslio = [
            e for e in entries
            if len(e.composition) < 3 and ("Fe" not in e.composition)
        ]
        self.pd_formation = PhaseDiagram(entrieslio)
        self.plotter_formation = PDPlotter(self.pd_formation,
                                           show_unstable=True)

        # Copy before adding C: ``entries`` was already handed to ``self.pd``,
        # and appending to it in place would leak the C entry into that
        # diagram if PhaseDiagram keeps a reference to its input list.
        entries3d = list(entries)
        entries3d.append(PDEntry("C", 0))
        self.pd3d = PhaseDiagram(entries3d)
        self.plotter3d = PDPlotter(self.pd3d, show_unstable=True)

    def test_pd_plot_data(self):
        """Check line/label/unstable-entry counts from ``pd_plot_data``."""
        (lines, labels, unstable_entries) = self.plotter.pd_plot_data
        self.assertEqual(len(lines), 22)
        self.assertEqual(len(labels), len(self.pd.stable_entries),
                         "Incorrect number of labels generated!")
        self.assertEqual(
            len(unstable_entries),
            len(self.pd.all_entries) - len(self.pd.stable_entries),
            "Incorrect number of unstable entries generated!")
        (lines, labels, unstable_entries) = self.plotter3d.pd_plot_data
        self.assertEqual(len(lines), 33)
        self.assertEqual(len(labels), len(self.pd3d.stable_entries))
        self.assertEqual(
            len(unstable_entries),
            len(self.pd3d.all_entries) - len(self.pd3d.stable_entries))
        (lines, labels, unstable_entries) = self.plotter_formation.pd_plot_data
        self.assertEqual(len(lines), 3)
        self.assertEqual(len(labels), len(self.pd_formation.stable_entries))

    @unittest.skipIf("DISPLAY" not in os.environ, "Need display")
    def test_get_plot(self):
        """Smoke test: the plotting methods can be called without raising."""
        # Some very basic non-tests. Just to make sure the methods are callable.
        import matplotlib
        matplotlib.use("pdf")

        self.plotter.get_plot()
        self.plotter3d.get_plot()
        # Same call order as before: 2D then 3D, first with
        # process_attributes=True, then with False.
        for process_attributes in (True, False):
            self.plotter.get_plot(energy_colormap="Reds",
                                  process_attributes=process_attributes)
            self.plotter3d.get_plot(energy_colormap="Reds",
                                    process_attributes=process_attributes)
        self.plotter.get_chempot_range_map_plot([Element("Li"), Element("O")])
Example #2
0
class PDPlotterTest(unittest.TestCase):
    """Tests for ``PDPlotter`` plot data and plotting entry points.

    Three diagrams are built from ``pdentries_test.csv``: the full set of
    entries, a reduced set (``len(e.composition) < 3`` and no Fe), and the
    full set plus an extra ``PDEntry("C", 0)``.
    """

    def setUp(self):
        _, entries = PDEntryIO.from_csv(os.path.join(
            module_dir, "pdentries_test.csv"))
        self.pd = PhaseDiagram(entries)
        self.plotter = PDPlotter(self.pd, show_unstable=True)
        # Reduced entry set: fewer than three elements and no Fe.
        entrieslio = [e for e in entries
                      if len(e.composition) < 3 and ("Fe" not in e.composition)]

        self.pd_formation = PhaseDiagram(entrieslio)
        self.plotter_formation = PDPlotter(self.pd_formation, show_unstable=True)
        # Copy before adding C so the list already given to self.pd is not
        # mutated (it would leak into that diagram if PhaseDiagram keeps it).
        entries3d = list(entries)
        entries3d.append(PDEntry("C", 0))
        self.pd3d = PhaseDiagram(entries3d)
        self.plotter3d = PDPlotter(self.pd3d, show_unstable=True)

    def test_pd_plot_data(self):
        """Check line/label/unstable-entry counts from ``pd_plot_data``."""
        (lines, labels, unstable_entries) = self.plotter.pd_plot_data
        self.assertEqual(len(lines), 22)
        self.assertEqual(len(labels), len(self.pd.stable_entries),
                         "Incorrect number of labels generated!")
        self.assertEqual(len(unstable_entries),
                         len(self.pd.all_entries) - len(self.pd.stable_entries),
                         "Incorrect number of unstable entries generated!")
        (lines, labels, unstable_entries) = self.plotter3d.pd_plot_data
        self.assertEqual(len(lines), 33)
        self.assertEqual(len(labels), len(self.pd3d.stable_entries))
        self.assertEqual(len(unstable_entries),
                         len(self.pd3d.all_entries) - len(self.pd3d.stable_entries))
        (lines, labels, unstable_entries) = self.plotter_formation.pd_plot_data
        self.assertEqual(len(lines), 3)
        self.assertEqual(len(labels), len(self.pd_formation.stable_entries))

    def test_get_plot(self):
        """Smoke test: the plotting methods can be called without raising."""
        self.plotter.get_plot()
        self.plotter3d.get_plot()
        # NOTE(review): get_plot(energy_colormap="Reds",
        # process_attributes=True/False) calls on both plotters are disabled
        # here; the reason is not recorded. Re-enable or drop once resolved.
        self.plotter.get_chempot_range_map_plot([Element("Li"), Element("O")])
Example #3
0
    def get_phase_diagram_plot(self):
        """
        Returns a phase diagram plot, as a matplotlib plot object.

        Reads ``self.lines``:

        * line 0: the whitespace-separated tokens after the last
          ``endpoints:`` token are the endpoint compositions (collected from
          the end of the line backwards);
        * lines 4 onward: each non-blank line's second and third fields are a
          composition and its total energy.

        Raises:
            ValueError: if fewer than 2 endpoint compositions are found.
        """
        # parse the composition space endpoints, scanning the first line's
        # tokens from the end back to the 'endpoints:' marker
        endpoints = []
        for word in reversed(self.lines[0].split()):
            if word == 'endpoints:':
                break
            endpoints.append(Composition(word))

        if len(endpoints) < 2:
            # raise instead of quit(): quit() is an interactive-shell helper
            # and would terminate any program that calls this method
            raise ValueError('There must be at least 2 endpoint compositions '
                             'to make a phase diagram.')

        # parse the compositions and total energies of all the structures;
        # rows start at index 4 (lines 1-3 are presumably header lines —
        # confirm against the file format)
        pdentries = []
        for raw_line in self.lines[4:]:
            fields = raw_line.split()
            if not fields:
                # tolerate blank (e.g. trailing) lines instead of IndexError
                continue
            pdentries.append(PDEntry(Composition(fields[1]), float(fields[2])))

        # set the font to Times, rendered with Latex. This changes global
        # matplotlib rc state, so it is done only after the input is validated.
        plt.rc('font', **{'family': 'serif', 'serif': ['Times']})
        plt.rc('text', usetex=True)

        # make a CompoundPhaseDiagram
        compound_pd = CompoundPhaseDiagram(pdentries, endpoints)

        # make a PhaseDiagramPlotter; show_unstable=100 is presumably a large
        # energy-above-hull cutoff so all unstable entries are drawn (confirm)
        pd_plotter = PDPlotter(compound_pd, show_unstable=100)
        return pd_plotter.get_plot(label_unstable=False)
Example #4
0
            gy.append(x2 * math.sqrt(3.0) / 2.0)
    grid_triang = tri.Triangulation(gx, gy)

    fields_strings = [
        "xas normalization to min and max -> normalization factor",
        "xas xmcd minmax -> xmcd max",
        "xas xmcd minmax -> xmcd_min",
        'sum([xmcd max - xmcd min]*f*N) ~ "total magnetic moment"',
    ]
    norms, xmcd_diffs, mag = {}, {}, 0.0
    factors = {"Co": 1.7 / 3.2 / 0.6, "Fe": 2.1 / 3.9 / 0.6}
    for fldidx, fields_str in enumerate(fields_strings):
        fields = fields_str.split(" -> ")
        for elidx, elem in enumerate(chemsys[:-1]):
            print fields_str, elem
            plt = plotter.get_plot()
            title = elem + ": " + fields_str
            if fldidx == 3 and elidx == 1:
                title = fields_str
            plt.suptitle(title, fontsize=24)
            plt.triplot(grid_triang, "k:")

            # heatmap
            x, y, z = [], [], []
            for idx, (comp, cid) in enumerate(comps_cids):
                comp_str = comp if args.dev else doc[cid]["_id"]
                composition = Composition(comp_str)
                x0, x1, x2 = [
                    composition.get_atomic_fraction(el) for el in chemsys
                ]
                x.append(x0 +
            gx.append(x0+x2/2.) # NOTE x0 might need to be replace with x1
            gy.append(x2*math.sqrt(3.)/2.)
    grid_triang = tri.Triangulation(gx, gy)

    fields_strings = [
        'xas normalization to min and max -> normalization factor',
        'xas xmcd minmax -> xmcd max', 'xas xmcd minmax -> xmcd_min',
        'sum([xmcd max - xmcd min]*f*N) ~ "total magnetic moment"'
    ]
    norms, xmcd_diffs, mag = {}, {}, 0.
    factors = {'Co': 1.7/3.2/0.6, 'Fe': 2.1/3.9/0.6}
    for fldidx,fields_str in enumerate(fields_strings):
        fields = fields_str.split(' -> ')
        for elidx,elem in enumerate(chemsys[:-1]):
            print fields_str, elem
            plt = plotter.get_plot()
            title = elem+': '+fields_str
            if fldidx == 3 and elidx == 1: title = fields_str
            plt.suptitle(title, fontsize=24)
            plt.triplot(grid_triang, 'k:')

            # heatmap
            x, y, z = [], [], []
            for idx,(comp,cid) in enumerate(comps_cids):
                comp_str = comp if args.dev else doc[cid]['_id']
                composition = Composition(comp_str)
                x0, x1, x2 = [composition.get_atomic_fraction(el) for el in chemsys]
                x.append(x0+x2/2.) # NOTE x0 might need to be replace with x1
                y.append(x2*math.sqrt(3.)/2.)
                if fldidx < 3:
                    zval = mpfile.document[comp_str]['{} XMCD'.format(elem)][fields[0]][fields[1]]