class PhononBSPlotterTest(unittest.TestCase):
    """Tests for PhononBSPlotter using the NaCl phonon band structure fixture."""

    def setUp(self):
        with open(os.path.join(test_dir, "NaCl_phonon_bandstructure.json")) as f:
            d = json.load(f)
        self.bs = PhononBandStructureSymmLine.from_dict(d)
        self.plotter = PhononBSPlotter(self.bs)

    def test_bs_plot_data(self):
        # Call bs_plot_data() once and reuse the result rather than
        # recomputing it for every assertion.
        data = self.plotter.bs_plot_data()
        distances = data['distances']
        labels = data['ticks']['label']
        self.assertEqual(len(distances[0]), 51,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(distances), 4, "wrong number of branches")
        self.assertEqual(sum(len(e) for e in distances), 204,
                         "wrong number of distances")
        self.assertEqual(labels[4], "Y", "wrong tick label")
        self.assertEqual(len(labels), 8, "wrong number of tick labels")

    def test_plot(self):
        # Disabling latex for testing.
        import matplotlib.pyplot as plt
        from matplotlib import rc

        rc('text', usetex=False)
        # Close the figure(s) when the test finishes so they don't accumulate.
        self.addCleanup(plt.close, "all")
        self.plotter.get_plot(units="mev")
class PhononBSPlotterTest(unittest.TestCase):
    """Tests for PhononBSPlotter built from the NaCl phonon band structure fixture."""

    def setUp(self):
        with open(os.path.join(PymatgenTest.TEST_FILES_DIR, "NaCl_phonon_bandstructure.json")) as f:
            d = json.load(f)
        self.bs = PhononBandStructureSymmLine.from_dict(d)
        self.plotter = PhononBSPlotter(self.bs)

    def test_bs_plot_data(self):
        # Call bs_plot_data() once and reuse the result for every assertion.
        data = self.plotter.bs_plot_data()
        distances = data["distances"]
        labels = data["ticks"]["label"]
        self.assertEqual(len(distances[0]), 51, "wrong number of distances in the first branch")
        self.assertEqual(len(distances), 4, "wrong number of branches")
        self.assertEqual(sum(len(e) for e in distances), 204, "wrong number of distances")
        self.assertEqual(labels[4], "Y", "wrong tick label")
        self.assertEqual(len(labels), 8, "wrong number of tick labels")

    def _disable_latex_and_close_figures(self):
        """Turn off LaTeX text rendering and close all figures when the test ends."""
        import matplotlib.pyplot as plt
        from matplotlib import rc

        rc("text", usetex=False)
        self.addCleanup(plt.close, "all")

    def test_plot(self):
        self._disable_latex_and_close_figures()
        self.plotter.get_plot(units="mev")

    def test_plot_compare(self):
        self._disable_latex_and_close_figures()
        self.plotter.plot_compare(self.plotter, units="mev")
def rms_kdep_plot(self, whichkpath=1, filename="rms.eps", format="eps"):
    """Plot the RMS returned by ``self.rms_kdep()`` along a band-structure path and save it.

    Args:
        whichkpath: Which band structure supplies the x-axis distances and the
            tick positions/labels: 1 for ``self.bs1``, 2 for ``self.bs2``.
        filename: Output path passed to ``matplotlib.pyplot.savefig``.
        format: Image format passed to ``savefig`` (e.g. ``"eps"``).

    Raises:
        ValueError: If ``whichkpath`` is neither 1 nor 2.
    """
    # Validate up front; any other value would otherwise leave ``plotter``
    # unbound and fail with an UnboundLocalError further down.
    if whichkpath not in (1, 2):
        raise ValueError(f"whichkpath must be 1 or 2, got {whichkpath!r}")

    rms = self.rms_kdep()
    bs = self.bs1 if whichkpath == 1 else self.bs2
    plotter = PhononBSPlotter(bs=bs)

    # Compute the plot data once; it provides both the x values and the ticks.
    plot_data = plotter.bs_plot_data()
    # Concatenate the per-branch distance sequences into one x-axis list.
    distances = [d for branch in plot_data["distances"] for d in branch]

    import matplotlib.pyplot as plt

    plt.close("all")
    plt.plot(distances, rms)
    plt.xticks(ticks=plot_data["ticks"]["distance"], labels=plot_data["ticks"]["label"])
    plt.xlabel("Wave vector")
    plt.ylabel("Phonons RMS (THz)")
    plt.savefig(filename, format=format)
def process_item(self, item):
    """Build the phonon-website document and a PNG band-structure image for one item.

    Args:
        item (dict): Document with an ``"mp-id"`` key and a ``"ph_bs"`` entry
            that is decoded with ``MontyDecoder.process_decoded``.

    Returns:
        dict: ``mp_id``, ``web_doc`` (from ``as_phononwebsite()``) and
        ``image`` (the PNG bytes wrapped in ``Binary``).
    """
    mp_id = item['mp-id']
    # Lazy %-style arguments: the message is only formatted if DEBUG is enabled.
    self.logger.debug("Processing %s", mp_id)

    ph_bs = MontyDecoder().process_decoded(item['ph_bs'])
    web_doc = ph_bs.as_phononwebsite()

    plotter = PhononBSPlotter(ph_bs)
    # The y-axis runs from 0 up to the largest frequency in the plot data.
    ylim = (0, max(py_.flatten_deep(plotter.bs_plot_data()['frequency'])))

    # The context manager closes the buffer even if plotting raises.
    with io.BytesIO() as filelike:
        plotter.save_plot(filelike, ylim=ylim, img_format="png")
        image = Binary(filelike.getvalue())

    return dict(mp_id=mp_id, web_doc=web_doc, image=image)
def get_plot(self, units='THz', ymin=None, ymax=None, width=None,
             height=None, dpi=None, plt=None, fonts=None, dos=None,
             dos_aspect=3, color=None, style=None, no_base_style=False,
             from_json=None, legend=None):
    """Get a :obj:`matplotlib.pyplot` object of the phonon band structure.

    Args:
        units (:obj:`str`, optional): Units of phonon frequency. Accepted
            (case-insensitive) values are Thz, cm-1, eV, meV.
        ymin (:obj:`float`, optional): The minimum energy on the y-axis.
        ymax (:obj:`float`, optional): The maximum energy on the y-axis.
        width (:obj:`float`, optional): The width of the plot.
        height (:obj:`float`, optional): The height of the plot.
        dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
            the image.
        fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
            single font, specified as a :obj:`str`, or several fonts,
            specified as a :obj:`list` of :obj:`str`.
        plt (:obj:`matplotlib.pyplot`, optional): A
            :obj:`matplotlib.pyplot` object to use for plotting.
        dos (:obj:`np.ndarray`): 2D Numpy array of total DOS data.
        dos_aspect (float): Width division for vertical DOS.
        color (:obj:`str` or :obj:`tuple`, optional): Line/fill colour in
            any matplotlib-accepted format. Defaults to ``'C0'``.
        style (:obj:`list`, :obj:`str`, or :obj:`dict`): Any matplotlib
            style specifications, to be composed on top of Sumo base
            style.
        no_base_style (:obj:`bool`, optional): Prevent use of sumo base
            style. This can make alternative styles behave more
            predictably.
        from_json (:obj:`list` or :obj:`None`, optional): List of paths to
            :obj:`pymatgen.phonon.bandstructure.PhononBandStructureSymmline`
            JSON dump files. These are used to generate additional plots
            displayed under the data attached to this plotter. The k-point
            path should be the same as the main plot; the reciprocal
            lattice is adjusted to fit the scaling of the main data input.
        legend (:obj:`list` or :obj:`None`, optional): Legend labels. Either
            one entry per dataset (the main data followed by each
            ``from_json`` file, in order) or one entry per ``from_json``
            file, in which case the main data gets an empty label. A legend
            is drawn only if at least one label is non-empty.

    Returns:
        :obj:`matplotlib.pyplot`: The phonon band structure plot.

    Raises:
        ValueError: If ``legend`` has the wrong number of entries, or if a
            ``from_json`` band structure has a different number of bands
            from the main data.
    """
    # NOTE(review): fonts, style and no_base_style are not read in this body;
    # presumably they are consumed by a decorator outside this view - confirm.
    if from_json is None:
        from_json = []

    # Normalise legend to exactly one label per dataset: main data first,
    # then each from_json file in order.
    if legend is None:
        legend = [''] * (len(from_json) + 1)
    elif len(legend) == len(from_json):
        legend = [''] + list(legend)
    elif len(legend) != 1 + len(from_json):
        raise ValueError('Inappropriate number of legend entries')

    if color is None:
        color = 'C0'  # Default to first colour in matplotlib series

    if dos is not None:
        plt = pretty_subplot(1, 2, width=width, height=height,
                             sharex=False, sharey=True, dpi=dpi, plt=plt,
                             gridspec_kw={'width_ratios': [dos_aspect, 1],
                                          'wspace': 0})
        ax = plt.gcf().axes[0]
    else:
        plt = pretty_plot(width, height, dpi=dpi, plt=plt)
        ax = plt.gca()

    def _plot_lines(data, ax, color=None, zorder=1):
        """Draw every band of every branch from bs_plot_data()-style data onto ax."""
        dists = data['distances']
        freqs = data['frequency']
        # nd is the branch index, nb the band index.
        for nd, nb in itertools.product(range(len(dists)),
                                        range(self._nb_bands)):
            ax.plot(dists[nd], freqs[nd][nb], ls='-', c=color, zorder=zorder)

    data = self.bs_plot_data()
    _plot_lines(data, ax, color=color)

    for i, bs_json in enumerate(from_json):
        with open(bs_json, 'rt') as f:
            bs_dict = json.load(f)
        # Use the main data's reciprocal lattice so the extra band structure
        # is scaled to match the main plot.
        bs_dict['lattice_rec'] = json.loads(self._bs.lattice_rec.to_json())
        bs = PhononBandStructureSymmLine.from_dict(bs_dict)

        json_plotter = PhononBSPlotter(bs)
        if json_plotter._nb_bands != self._nb_bands:
            raise ValueError('Number of bands in {} does not match '
                             'main plot'.format(bs_json))
        json_data = json_plotter.bs_plot_data()
        # Extra datasets take the next colours in the series (C1, C2, ...)
        # and are drawn behind the main bands (zorder 0.5 < 1).
        _plot_lines(json_data, ax, color='C{}'.format(i + 1), zorder=0.5)

    if any(legend):  # Don't show legend if all entries are empty string
        from matplotlib.lines import Line2D
        # The main data is drawn in `color`, which may not be 'C0', so its
        # legend handle must use that colour; from_json entries use C1, C2, ...
        handles = [Line2D([0], [0], color=color)]
        handles += [Line2D([0], [0], color='C{}'.format(i))
                    for i in range(1, len(legend))]
        ax.legend(handles, legend)

    self._maketicks(ax, units=units)
    self._makeplot(ax, plt.gcf(), data, width=width, height=height,
                   ymin=ymin, ymax=ymax, dos=dos, color=color)
    plt.tight_layout()
    plt.subplots_adjust(wspace=0)
    return plt