Esempio n. 1
0
class PhononBSPlotterTest(unittest.TestCase):
    """Tests for PhononBSPlotter using the NaCl phonon band structure file."""

    def setUp(self):
        """Load the band structure from JSON and build a plotter around it."""
        path = os.path.join(test_dir, "NaCl_phonon_bandstructure.json")
        # Only the parsing needs the file handle; release it before building
        # the band structure and plotter objects.
        with open(path, "r") as f:
            d = json.load(f)
        self.bs = PhononBandStructureSymmLine.from_dict(d)
        self.plotter = PhononBSPlotter(self.bs)

    def test_bs_plot_data(self):
        """Check branch/distance counts and tick labels in the plot data."""
        # Compute the plot data once; every assertion inspects the same dict.
        data = self.plotter.bs_plot_data()
        distances = data['distances']
        labels = data['ticks']['label']

        self.assertEqual(len(distances[0]), 51,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(distances), 4, "wrong number of branches")
        self.assertEqual(sum(len(e) for e in distances), 204,
                         "wrong number of distances")
        self.assertEqual(labels[4], "Y", "wrong tick label")
        self.assertEqual(len(labels), 8, "wrong number of tick labels")

    def test_plot(self):
        """Smoke test: building the plot in meV must not raise."""
        # Disabling latex for testing.
        from matplotlib import rc
        rc('text', usetex=False)
        self.plotter.get_plot(units="mev")
Esempio n. 2
0
class PhononBSPlotterTest(unittest.TestCase):
    """Exercises PhononBSPlotter with the NaCl phonon band structure file."""

    def setUp(self):
        """Read the JSON fixture and create the band structure and plotter."""
        json_path = os.path.join(test_dir, "NaCl_phonon_bandstructure.json")
        with open(json_path, "r") as f:
            d = json.load(f)
        # The file is closed by now; object construction only needs the dict.
        self.bs = PhononBandStructureSymmLine.from_dict(d)
        self.plotter = PhononBSPlotter(self.bs)

    def test_bs_plot_data(self):
        """Validate the shape of the distances and the tick labels."""
        # One call is enough: the assertions below all read the same result.
        plot_data = self.plotter.bs_plot_data()
        branches = plot_data['distances']
        tick_labels = plot_data['ticks']['label']

        self.assertEqual(len(branches[0]), 51,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(branches), 4, "wrong number of branches")
        total_points = sum(len(branch) for branch in branches)
        self.assertEqual(total_points, 204, "wrong number of distances")
        self.assertEqual(tick_labels[4], "Y", "wrong tick label")
        self.assertEqual(len(tick_labels), 8, "wrong number of tick labels")

    def test_plot(self):
        """Smoke test: get_plot with meV units should run without error."""
        # Disabling latex for testing.
        from matplotlib import rc
        rc('text', usetex=False)
        self.plotter.get_plot(units="mev")
Esempio n. 3
0
class PhononBSPlotterTest(unittest.TestCase):
    """Tests PhononBSPlotter data extraction and plotting on the NaCl file."""

    def setUp(self):
        """Build the band structure and plotter from the JSON test file."""
        fixture = os.path.join(PymatgenTest.TEST_FILES_DIR,
                               "NaCl_phonon_bandstructure.json")
        with open(fixture) as f:
            d = json.load(f)
        # Construct the objects after the file handle has been released.
        self.bs = PhononBandStructureSymmLine.from_dict(d)
        self.plotter = PhononBSPlotter(self.bs)

    def test_bs_plot_data(self):
        """Check the number of branches, distances and tick labels."""
        # Evaluate the plot data a single time and reuse it in all checks.
        data = self.plotter.bs_plot_data()
        distances = data["distances"]
        labels = data["ticks"]["label"]

        self.assertEqual(
            len(distances[0]),
            51,
            "wrong number of distances in the first branch",
        )
        self.assertEqual(len(distances), 4, "wrong number of branches")
        self.assertEqual(
            sum(len(e) for e in distances),
            204,
            "wrong number of distances",
        )
        self.assertEqual(labels[4], "Y", "wrong tick label")
        self.assertEqual(len(labels), 8, "wrong number of tick labels")

    def test_plot(self):
        """Smoke test: get_plot in meV must not raise."""
        # Disabling latex for testing.
        from matplotlib import rc

        rc("text", usetex=False)
        self.plotter.get_plot(units="mev")

    def test_plot_compare(self):
        """Smoke test: comparing the plotter with itself must not raise."""
        # Disabling latex for testing.
        from matplotlib import rc

        rc("text", usetex=False)
        self.plotter.plot_compare(self.plotter, units="mev")
Esempio n. 4
0
    def rms_kdep_plot(self, whichkpath=1, filename="rms.eps", format="eps"):
        """Plot the values from ``self.rms_kdep()`` along a band-structure path.

        The x-axis distances and tick labels are taken from a
        ``PhononBSPlotter`` built on either ``self.bs1`` or ``self.bs2``.

        Args:
            whichkpath (int): 1 to use ``self.bs1`` for the path, 2 to use
                ``self.bs2``.
            filename (str): Destination passed to ``plt.savefig``.
            format (str): Image format passed to ``plt.savefig``.

        Raises:
            ValueError: If ``whichkpath`` is neither 1 nor 2.
        """
        # Validate up front: previously any other value left ``plotter``
        # unbound and failed later with an opaque UnboundLocalError.
        if whichkpath == 1:
            bs = self.bs1
        elif whichkpath == 2:
            bs = self.bs2
        else:
            raise ValueError(
                "whichkpath must be 1 or 2, got {!r}".format(whichkpath))

        rms = self.rms_kdep()
        plotter = PhononBSPlotter(bs=bs)

        # Compute the plot data once and reuse it for distances and ticks.
        plot_data = plotter.bs_plot_data()
        ticks = plot_data["ticks"]

        # Flatten the per-branch distance lists into one x-axis sequence.
        distances = [d for branch in plot_data["distances"] for d in branch]

        import matplotlib.pyplot as plt
        plt.close("all")
        plt.plot(distances, rms)
        plt.xticks(ticks=ticks["distance"], labels=ticks["label"])
        plt.xlabel("Wave vector")
        plt.ylabel("Phonons RMS (THz)")
        plt.savefig(filename, format=format)
Esempio n. 5
0
 def process_item(self, item):
     """Build the web document and a PNG plot for one phonon band structure.

     Args:
         item (dict): Must provide ``'mp-id'`` and ``'ph_bs'``; the latter is
             passed through ``MontyDecoder.process_decoded``.

     Returns:
         dict: ``mp_id``, ``web_doc`` (result of ``as_phononwebsite()``) and
         ``image`` (the PNG bytes wrapped in ``Binary``).
     """
     mp_id = item['mp-id']
     # Lazy %-style args: the message is only formatted if DEBUG is enabled.
     self.logger.debug("Processing %s", mp_id)
     decoder = MontyDecoder()
     ph_bs = decoder.process_decoded(item['ph_bs'])
     web_doc = ph_bs.as_phononwebsite()
     plotter = PhononBSPlotter(ph_bs)
     # The y-axis spans from zero to the largest frequency in the plot data.
     ylim = (0, max(py_.flatten_deep(plotter.bs_plot_data()['frequency'])))
     # The context manager closes the buffer even if saving the plot fails.
     with io.BytesIO() as filelike:
         plotter.save_plot(filelike, ylim=ylim, img_format="png")
         image = Binary(filelike.getvalue())
     return dict(mp_id=mp_id, web_doc=web_doc, image=image)
Esempio n. 6
0
    def get_plot(self, units='THz', ymin=None, ymax=None, width=None,
                 height=None, dpi=None, plt=None, fonts=None, dos=None,
                 dos_aspect=3, color=None, style=None, no_base_style=False,
                 from_json=None, legend=None):
        """Get a :obj:`matplotlib.pyplot` object of the phonon band structure.

        Args:
            units (:obj:`str`, optional): Units of phonon frequency. Accepted
                (case-insensitive) values are Thz, cm-1, eV, meV.
            ymin (:obj:`float`, optional): The minimum energy on the y-axis.
            ymax (:obj:`float`, optional): The maximum energy on the y-axis.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            dos (:obj:`np.ndarray`): 2D Numpy array of total DOS data
            dos_aspect (float): Width division for vertical DOS
            color (:obj:`str` or :obj:`tuple`, optional): Line/fill colour in
                any matplotlib-accepted format
            style (:obj:`list`, :obj:`str`, or :obj:`dict`): Any matplotlib
                style specifications, to be composed on top of Sumo base
                style.
            no_base_style (:obj:`bool`, optional): Prevent use of sumo base
                style. This can make alternative styles behave more
                predictably.
            from_json (:obj:`list` or :obj:`None`, optional): List of paths to
                :obj:`pymatgen.phonon.bandstructure.PhononBandStructureSymmline`
                JSON dump files. These are used to generate additional plots
                displayed under the data attached to this plotter.
                The k-point path should be the same as the main plot; the
                reciprocal lattice is adjusted to fit the scaling of the main
                data input.
            legend (:obj:`list` or :obj:`None`, optional): Legend labels.
                Either one label per ``from_json`` entry (the main data then
                gets an empty label), or one label for the main data followed
                by one per ``from_json`` entry. No legend is drawn when every
                label is empty.

        Returns:
            :obj:`matplotlib.pyplot`: The phonon band structure plot.

        Raises:
            ValueError: If ``legend`` has an inappropriate number of entries,
                or a ``from_json`` band structure has a different number of
                bands from the main data.
        """
        # NOTE(review): ``fonts``, ``style`` and ``no_base_style`` are not read
        # in this body; presumably they are consumed by a styling decorator
        # outside this view -- confirm.
        if from_json is None:
            from_json = []

        n_series = len(from_json) + 1  # main data plus one per JSON file
        if legend is None:
            legend = [''] * n_series
        elif len(legend) == len(from_json):
            # Labels given for the JSON overlays only: pad the main series.
            legend = [''] + list(legend)
        elif len(legend) != n_series:
            raise ValueError('Inappropriate number of legend entries')

        if color is None:
            color = 'C0'  # Default to first colour in matplotlib series

        if dos is not None:
            plt = pretty_subplot(1, 2, width=width, height=height,
                                 sharex=False, sharey=True, dpi=dpi, plt=plt,
                                 gridspec_kw={'width_ratios': [dos_aspect, 1],
                                              'wspace': 0})
            ax = plt.gcf().axes[0]
        else:
            plt = pretty_plot(width, height, dpi=dpi, plt=plt)
            ax = plt.gca()

        def _plot_lines(data, ax, color=None, zorder=1):
            """Pull data from any PhononBSPlotter and add to axis"""
            # Outer loop over branches, inner loop over bands.
            for branch_dists, branch_freqs in zip(data['distances'],
                                                  data['frequency']):
                for nb in range(self._nb_bands):
                    ax.plot(branch_dists, branch_freqs[nb], ls='-', c=color,
                            zorder=zorder)

        data = self.bs_plot_data()
        _plot_lines(data, ax, color=color)

        for i, bs_json in enumerate(from_json):
            with open(bs_json, 'rt') as f:
                json_data = json.load(f)

            # Replace the reciprocal lattice with that of the main band
            # structure so the overlay shares the main plot's scaling.
            json_data['lattice_rec'] = json.loads(
                self._bs.lattice_rec.to_json())
            bs = PhononBandStructureSymmLine.from_dict(json_data)

            json_plotter = PhononBSPlotter(bs)
            if json_plotter._nb_bands != self._nb_bands:
                raise ValueError('Number of bands in {} does not match '
                                 'main plot'.format(bs_json))
            # Overlays take the colours after C0 and sit below the main lines.
            _plot_lines(json_plotter.bs_plot_data(), ax,
                        color='C{}'.format(i + 1),
                        zorder=0.5)

        if any(legend):  # Don't show legend if all entries are empty string
            from matplotlib.lines import Line2D
            # The main series is drawn in ``color`` (which may be
            # user-supplied), so its legend handle must use it, not 'C0'.
            series_colors = [color] + ['C{}'.format(i + 1)
                                       for i in range(len(from_json))]
            handles = [Line2D([0], [0], color=c) for c in series_colors]
            ax.legend(handles, legend)

        self._maketicks(ax, units=units)
        self._makeplot(ax, plt.gcf(), data, width=width, height=height,
                       ymin=ymin, ymax=ymax, dos=dos, color=color)
        plt.tight_layout()
        plt.subplots_adjust(wspace=0)

        return plt