def get_plot(self, normalize_rxn_coordinate=True, label_barrier=True):
    """
    Returns the NEB plot. Uses Henkelman's approach of spline fitting
    each section of the reaction path based on tangent force and energies.

    Args:
        normalize_rxn_coordinate (bool): Whether to normalize the
            reaction coordinate to between 0 and 1. Defaults to True.
        label_barrier (bool): Whether to label the maximum barrier. The
            label reports the difference between the highest and lowest
            points of the interpolated path, in meV.

    Returns:
        matplotlib.pyplot object.
    """
    plt = pretty_plot(12, 8)
    # Dividing by the last reaction coordinate maps the path onto [0, 1].
    scale = 1 if not normalize_rxn_coordinate else 1 / self.r[-1]
    r_max = np.max(self.r)
    # np.arange never includes its stop value; append r_max so the spline
    # curve reaches the final image instead of ending just short of it.
    x = np.append(np.arange(0, r_max, 0.01), r_max)
    y = self.spline(x) * 1000  # energies are plotted in meV
    relative_energies = self.energies - self.energies[0]
    plt.plot(self.r * scale, relative_energies * 1000, 'ro',
             x * scale, y, 'k-', linewidth=2, markersize=10)
    plt.xlabel("Reaction coordinate")
    plt.ylabel("Energy (meV)")
    plt.ylim((np.min(y) - 10, np.max(y) * 1.02 + 20))
    if label_barrier:
        # Highest point along the interpolated path.
        barrier_x, barrier_y = max(zip(x * scale, y), key=lambda d: d[1])
        plt.plot([0, barrier_x], [barrier_y, barrier_y], 'k--')
        # Barrier measured from the lowest point of the path, so an
        # intermediate minimum below the initial image is accounted for.
        label_pos = (barrier_x / 2, barrier_y * 1.02)
        plt.annotate('%.0f meV' % (np.max(y) - np.min(y)),
                     xy=label_pos, xytext=label_pos,
                     horizontalalignment='center')
    plt.tight_layout()
    return plt
def get_plot(self, width=8, height=8, term_zero=True):
    """
    Returns a plot object with one voltage curve per stored electrode.

    Args:
        width: Width of the plot. Defaults to 8 in.
        height: Height of the plot. Defaults to 8 in.
        term_zero: If True append zero voltage point at the end

    Returns:
        A matplotlib plot object.
    """
    plt = pretty_plot(width, height)
    # Gather every working ion and framework formula so a single x-axis
    # label can be chosen that fits all plotted electrodes.
    wion_symbol = set()
    formula = set()
    for label, electrode in self._electrodes.items():
        x, y = self.get_plot_data(electrode, term_zero=term_zero)
        wion_symbol.add(electrode.working_ion.symbol)
        formula.add(electrode.framework_formula)
        plt.plot(x, y, "-", linewidth=2, label=label)
    plt.legend()
    plt.xlabel(self._choose_best_x_lable(formula=formula, wion_symbol=wion_symbol))
    plt.ylabel("Voltage (V)")
    plt.tight_layout()
    return plt
def get_plot(self, marker="o", markersize=None, units="thz"):
    """
    Will produce a scatter plot of the Grüneisen parameter vs frequency.

    Points are colored on a gradient from blue (first point) to red (last
    point) according to their position in the flattened data.

    Args:
        marker: marker for the depiction
        markersize: size of the marker (a falsy value uses the default)
        units: unit for the plots, accepted units: thz, ev, mev, ha, cm-1, cm^-1

    Returns:
        plot
    """
    u = freq_units(units)
    freqs = self._gruneisen.frequencies.flatten() * u.factor
    gammas = self._gruneisen.gruneisen.flatten()
    plt = pretty_plot(12, 8)
    plt.xlabel(rf"$\mathrm{{Frequency\ ({u.label})}}$")
    plt.ylabel(r"$\mathrm{Grüneisen\ parameter}$")
    # Denominator of the color gradient; clamp to 1 so a single point does
    # not raise ZeroDivisionError (it is then drawn pure blue).
    n = max(len(gammas) - 1, 1)
    extra = {"markersize": markersize} if markersize else {}
    for i, (freq, gamma) in enumerate(zip(freqs, gammas)):
        color = (1.0 / n * i, 0, 1.0 / n * (n - i))
        plt.plot(freq, gamma, marker, color=color, **extra)
    plt.tight_layout()
    return plt
def get_arrhenius_plot(self):
    """
    Returns an Arrhenius plot of ``self.diffusivities`` against ``self.x``.

    The fitted line ``10 ** (self.slope * self.x + self.intercept)`` is
    drawn dashed, error bars come from ``self.diffusivity_errors``, the
    y axis is logarithmic and the activation energy ``self.Ea`` (eV) is
    written on the axes.

    Returns:
        matplotlib.pyplot object.
    """
    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8)
    arr = np.power(10, self.slope * self.x + self.intercept)
    plt.plot(self.x, self.diffusivities, 'ko', self.x, arr, 'k--', markersize=10)
    plt.errorbar(self.x, self.diffusivities, yerr=self.diffusivity_errors,
                 fmt='ko', ecolor='k', capthick=2, linewidth=2)
    # Use the axes that already holds the data. plt.axes() with no
    # arguments can create a brand-new empty axes in recent matplotlib,
    # which would receive the log scale and the text instead.
    ax = plt.gca()
    ax.set_yscale('log')
    plt.text(0.6, 0.85, "E$_a$ = {:.3f} eV".format(self.Ea),
             fontsize=30, transform=ax.transAxes)
    plt.ylabel("D (cm$^2$/s)")
    plt.xlabel("1000/T (K$^{-1}$)")
    plt.tight_layout()
    return plt
def get_framework_rms_plot(self, plt=None, granularity=200, matching_s=None):
    """
    Get the plot of rms framework displacement vs time. Useful for checking
    for melting, especially if framework atoms can move via paddle-wheel
    or similar mechanism (which would show up in max framework displacement
    but doesn't constitute melting).

    Args:
        plt (matplotlib.pyplot): If plt is supplied, changes will be made to an
            existing plot. Otherwise, a new plot will be created.
        granularity (int): Number of structures to match
        matching_s (Structure): Optionally match to a disordered structure
            instead of the first structure in the analyzer. Required when
            a secondary mobile ion is present.

    Returns:
        matplotlib.pyplot object.

    Notes:
        The method doesn't apply to NPT-AIMD simulation analysis.
    """
    from pymatgen.util.plotting import pretty_plot

    if self.lattices is not None and len(self.lattices) > 1:
        warnings.warn(
            "Note the method doesn't apply to NPT-AIMD simulation analysis!"
        )

    plt = pretty_plot(12, 8, plt=plt)
    n_frames = self.corrected_displacements.shape[1]
    # Sample about `granularity` frames. Clamp to >= 1: with fewer frames
    # than `granularity` the integer division yields 0, which would make
    # max_dt (and therefore the whole time axis) collapse to zero; a
    # granularity of 1 would otherwise divide by zero.
    step = max((n_frames - 1) // max(granularity - 1, 1), 1)

    # Reference framework: the mobile species is removed before matching.
    f = (matching_s or self.structure).copy()
    f.remove_species([self.specie])
    sm = StructureMatcher(
        primitive_cell=False,
        stol=0.6,
        comparator=OrderDisorderElementComparator(),
        allow_subset=True,
    )
    rms = []
    for s in self.get_drift_corrected_structures(step=step):
        s.remove_species([self.specie])
        d = sm.get_rms_dist(f, s)
        # A falsy result (no match) is recorded as (1, 1).
        rms.append(d if d else (1, 1))

    max_dt = (len(rms) - 1) * step * self.step_skip * self.time_step
    if max_dt > 100000:
        plot_dt = np.linspace(0, max_dt / 1000, len(rms))
        unit = "ps"
    else:
        plot_dt = np.linspace(0, max_dt, len(rms))
        unit = "fs"
    rms = np.array(rms)
    plt.plot(plot_dt, rms[:, 0], label="RMS")
    plt.plot(plot_dt, rms[:, 1], label="max")
    plt.legend(loc="best")
    plt.xlabel(f"Timestep ({unit})")
    plt.ylabel("normalized distance")
    plt.tight_layout()
    return plt
def get_plot(self, normalize_rxn_coordinate=True, label_barrier=True):
    """
    Returns the NEB plot. Uses Henkelman's approach of spline fitting
    each section of the reaction path based on tangent force and energies.

    Args:
        normalize_rxn_coordinate (bool): Whether to normalize the
            reaction coordinate to between 0 and 1. Defaults to True.
        label_barrier (bool): Whether to label the maximum barrier. The
            label reports max - min of the interpolated path, in meV.

    Returns:
        matplotlib.pyplot object.
    """
    plt = pretty_plot(12, 8)
    scale = 1 if not normalize_rxn_coordinate else 1 / self.r[-1]
    r_max = np.max(self.r)
    # np.arange excludes its stop value; append r_max so the fitted curve
    # extends all the way to the final image.
    x = np.append(np.arange(0, r_max, 0.01), r_max)
    y = self.spline(x) * 1000  # plotted in meV
    relative_energies = self.energies - self.energies[0]
    plt.plot(self.r * scale, relative_energies * 1000, 'ro',
             x * scale, y, 'k-', linewidth=2, markersize=10)
    plt.xlabel("Reaction coordinate")
    plt.ylabel("Energy (meV)")
    plt.ylim((np.min(y) - 10, np.max(y) * 1.02 + 20))
    if label_barrier:
        barrier_x, barrier_y = max(zip(x * scale, y), key=lambda d: d[1])
        plt.plot([0, barrier_x], [barrier_y, barrier_y], 'k--')
        label_pos = (barrier_x / 2, barrier_y * 1.02)
        plt.annotate('%.0f meV' % (np.max(y) - np.min(y)),
                     xy=label_pos, xytext=label_pos,
                     horizontalalignment='center')
    plt.tight_layout()
    return plt
def get_rdf_plot(self, label=None, xlim=(0.0, 8.0), ylim=(-0.005, 3.0)):
    """
    Plot the average RDF function.

    Args:
        label (str): The legend label. If None, the element symbols of the
            first structure that appear in ``self.species`` are joined
            with "-".
        xlim (list): Set the x limits of the current axes.
        ylim (list): Set the y limits of the current axes.

    Returns:
        matplotlib.pyplot object.
    """
    if label is None:
        symbol_list = [e.symbol for e in self.structures[0].composition.keys()]
        symbol_list = [symbol for symbol in symbol_list if symbol in self.species]
        # Joining a single symbol yields the symbol itself.
        label = "-".join(symbol_list)

    plt = pretty_plot(12, 8)
    plt.plot(self.interval, self.rdf, color="r", label=label, linewidth=4.0)
    # Both backslashes escaped: "\A" is an invalid escape sequence and
    # triggers a SyntaxWarning; the resulting text is unchanged.
    plt.xlabel("$r$ ($\\rm\\AA$)")
    plt.ylabel("$g(r)$")
    plt.legend(loc='upper right', fontsize=36)
    plt.xlim(xlim[0], xlim[1])
    plt.ylim(ylim[0], ylim[1])
    plt.tight_layout()
    return plt
def get_chgint_plot(args):
    """
    Plot the integrated charge difference around selected atoms.

    Atom indices come from ``args.inds`` (a comma-separated string in its
    first element) when given; otherwise the first site of every group of
    symmetry-equivalent sites is used.

    Args:
        args: Parsed command-line namespace providing ``chgcar_file``,
            ``inds`` and ``radius``.

    Returns:
        matplotlib.pyplot object.
    """
    chgcar = Chgcar.from_file(args.chgcar_file)
    structure = chgcar.structure

    if args.inds:
        atom_ind = list(map(int, args.inds[0].split(",")))
    else:
        analyzer = SpacegroupAnalyzer(structure, symprec=0.1)
        groups = analyzer.get_symmetrized_structure().equivalent_sites
        atom_ind = [structure.sites.index(group[0]) for group in groups]

    from pymatgen.util.plotting import pretty_plot

    plt = pretty_plot(12, 8)
    for idx in atom_ind:
        integrated = chgcar.get_integrated_diff(idx, args.radius, 30)
        plt.plot(
            integrated[:, 0],
            integrated[:, 1],
            label=f"Atom {idx} - {structure[idx].species_string}",
        )
    plt.legend(loc="upper left")
    plt.xlabel("Radius (A)")
    plt.ylabel("Integrated charge (e)")
    plt.tight_layout()
    return plt
def get_xrd_plot(self, structure, two_theta_range=(0, 90), annotate_peaks=True):
    """
    Returns the XRD plot as a matplotlib.pyplot.

    Args:
        structure: Input structure
        two_theta_range ([float of length 2]): Tuple for range of
            two_thetas to calculate in degrees. Defaults to (0, 90). Set to
            None if you want all diffracted beams within the limiting
            sphere of radius 2 / wavelength.
        annotate_peaks: Whether to annotate the peaks with plane
            information.

    Returns:
        (matplotlib.pyplot)
    """
    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(16, 10)
    for two_theta, i, hkls, _d_hkl in self.get_xrd_data(
            structure, two_theta_range=two_theta_range):
        # None means "no range restriction"; indexing it would raise.
        in_range = (two_theta_range is None
                    or two_theta_range[0] <= two_theta <= two_theta_range[1])
        if not in_range:
            continue
        label = ", ".join(str(hkl) for hkl in hkls.keys())
        plt.plot([two_theta, two_theta], [0, i], color='k',
                 linewidth=3, label=label)
        if annotate_peaks:
            plt.annotate(label, xy=[two_theta, i],
                         xytext=[two_theta, i], fontsize=16)
    plt.xlabel(r"$2\theta$ ($^\circ$)")
    plt.ylabel("Intensities (scaled)")
    plt.tight_layout()
    return plt
def plot_conc_temp(self, me=None, mh=None):
    """
    Plot the concentration of carriers vs temperature both in eq and
    non-eq after quenching at 300K.

    Args:
        me: the effective mass for the electrons as a list of 3
            eigenvalues. Defaults to [1.0, 1.0, 1.0].
        mh: the effective mass for the holes as a list of 3 eigenvalues.
            Defaults to [1.0, 1.0, 1.0].

    Returns:
        a matplotlib object
    """
    # Fresh lists per call instead of shared mutable defaults.
    if me is None:
        me = [1.0, 1.0, 1.0]
    if mh is None:
        mh = [1.0, 1.0, 1.0]

    temps = [i * 100 for i in range(3, 20)]  # 300 K ... 1900 K
    qi = []
    qi_non_eq = []
    for t in temps:
        # The 1e-6 factor presumably converts m^-3 to cm^-3 to match the
        # axis label — TODO confirm against the analyzer's units.
        qi.append(self._analyzer.get_eq_Ef(t, me, mh)['Qi'] * 1e-6)
        qi_non_eq.append(
            self._analyzer.get_non_eq_Ef(t, 300, me, mh)['Qi'] * 1e-6)

    plt = pretty_plot(12, 8)
    plt.xlabel("temperature (K)")
    plt.ylabel("carrier concentration (cm$^{-3}$)")
    plt.semilogy(temps, qi, linewidth=3.0)
    plt.semilogy(temps, qi_non_eq, linewidth=3)
    plt.legend(['eq', 'non-eq'])
    return plt
def plot(self, width=8, height=None, plt=None, dpi=None, **kwargs):
    """
    Plot the equation of state.

    The input (volume, energy) points are drawn as markers, the fitted
    curve as a dashed line over the volume range widened by 1% at each end,
    and a text block summarising the fit parameters is placed on the axes.

    Args:
        width (float): Width of plot in inches. Defaults to 8in.
        height (float): Height of plot in inches. Defaults to width *
            golden ratio.
        plt (matplotlib.pyplot): If plt is supplied, changes will be made
            to an existing plot. Otherwise, a new plot will be created.
        dpi: Resolution forwarded to pretty_plot.
        kwargs (dict): Optional overrides. Keys read here: "color"
            (default "r"), "label" (default "<ClassName> fit") and "text"
            (replaces the fit summary). Other keys are ignored.

    Returns:
        Matplotlib plot object.
    """
    # pylint: disable=E1307
    plt = pretty_plot(width=width, height=height, plt=plt, dpi=dpi)

    cls_name = self.__class__.__name__
    color = kwargs.get("color", "r")
    label = kwargs.get("label", f"{cls_name} fit")

    summary = "\n".join(
        (
            "Equation of State: %s" % cls_name,
            "Minimum energy = %1.2f eV" % self.e0,
            "Minimum or reference volume = %1.2f Ang^3" % self.v0,
            "Bulk modulus = %1.2f eV/Ang^3 = %1.2f GPa" % (self.b0, self.b0_GPa),
            "Derivative of bulk modulus wrt pressure = %1.2f" % self.b1,
        )
    )
    text = kwargs.get("text", summary)

    # Raw data points only (no connecting line).
    plt.plot(self.volumes, self.energies, linestyle="None", marker="o", color=color)

    # Fitted curve, padded by 1% beyond the smallest and largest volume.
    v_lo = min(self.volumes)
    v_hi = max(self.volumes)
    volumes_fit = np.linspace(v_lo - 0.01 * abs(v_lo), v_hi + 0.01 * abs(v_hi), 100)
    plt.plot(volumes_fit, self.func(volumes_fit), linestyle="dashed", color=color, label=label)

    plt.grid(True)
    plt.xlabel("Volume $\\AA^3$")
    plt.ylabel("Energy (eV)")
    plt.legend(loc="best", shadow=True)
    # Fit summary in axes-relative coordinates.
    plt.text(0.4, 0.5, text, transform=plt.gca().transAxes)

    return plt
def get_pourbaix_plot(self, limits=None, title="", label_domains=True, plt=None):
    """
    Plot Pourbaix diagram.

    Draws four reference lines (two red dashed sloped lines offset from
    each other by 1.23 V, a vertical line at pH 7 and a horizontal line at
    E = 0) and the outline of every stable domain, optionally labelled
    with a draggable annotation at the domain's vertex average.

    Args:
        limits: 2D list containing limits of the Pourbaix diagram
            of the form [[xlo, xhi], [ylo, yhi]]
        title (str): Title to display on plot
        label_domains (bool): whether to label pourbaix domains
        plt (pyplot): Pyplot instance for plotting

    Returns:
        plt (pyplot) - matplotlib plot object with pourbaix diagram
    """
    if limits is None:
        limits = [[-2, 16], [-3, 3]]

    plt = plt or pretty_plot(16)

    xlim, ylim = limits[0], limits[1]

    def segment(start, end):
        # Two (x, y) endpoints -> [[x0, x1], [y0, y1]] for plt.plot.
        return np.transpose([start, end])

    # Named h_line / o_line in related code; presumably the hydrogen and
    # oxygen water-stability lines.
    reference_lines = (
        (segment([xlim[0], -xlim[0] * PREFAC], [xlim[1], -xlim[1] * PREFAC]), "r--"),
        (segment([xlim[0], -xlim[0] * PREFAC + 1.23], [xlim[1], -xlim[1] * PREFAC + 1.23]), "r--"),
        (segment([7, ylim[0]], [7, ylim[1]]), "k-."),
        (segment([xlim[0], 0], [xlim[1], 0]), "k-."),
    )

    ax = plt.gca()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    lw = 3

    for (xs, ys), style in reference_lines:
        plt.plot(xs, ys, style, linewidth=lw)

    for entry, corners in self._pbx._stable_domain_vertices.items():
        center = np.average(corners, axis=0)
        # Repeat the first vertex so the outline closes.
        px, py = np.transpose(np.vstack([corners, corners[0]]))
        plt.plot(px, py, 'k-', linewidth=lw)
        if label_domains:
            annotation = plt.annotate(generate_entry_label(entry), center,
                                      ha='center', va='center',
                                      fontsize=20, color="b")
            annotation.draggable()

    plt.xlabel("pH")
    plt.ylabel("E (V)")
    plt.title(title, fontsize=20, fontweight='bold')
    return plt
def get_plot(self, structure, two_theta_range=(0, 90),
             annotate_peaks=True, ax=None, with_labels=True,
             fontsize=16):
    """
    Returns the diffraction plot as a matplotlib.pyplot.

    Args:
        structure: Input structure
        two_theta_range ([float of length 2]): Tuple for range of
            two_thetas to calculate in degrees. Defaults to (0, 90). Set to
            None if you want all diffracted beams within the limiting
            sphere of radius 2 / wavelength.
        annotate_peaks: Whether to annotate the peaks with plane
            information.
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        with_labels: True to add xlabels and ylabels to the plot.
        fontsize: (int) fontsize for peak labels.

    Returns:
        (matplotlib.pyplot)
    """
    if ax is None:
        from pymatgen.util.plotting import pretty_plot
        plt = pretty_plot(16, 10)
        ax = plt.gca()
    else:
        # This to maintain the type of the return value.
        import matplotlib.pyplot as plt

    xrd = self.get_pattern(structure, two_theta_range=two_theta_range)

    for two_theta, i, hkls, _d_hkl in zip(xrd.x, xrd.y, xrd.hkls, xrd.d_hkls):
        # None means "no range restriction"; indexing it would raise.
        in_range = (two_theta_range is None
                    or two_theta_range[0] <= two_theta <= two_theta_range[1])
        if not in_range:
            continue
        label = ", ".join(str(hkl["hkl"]) for hkl in hkls)
        ax.plot([two_theta, two_theta], [0, i], color='k',
                linewidth=3, label=label)
        if annotate_peaks:
            ax.annotate(label, xy=[two_theta, i],
                        xytext=[two_theta, i], fontsize=fontsize)

    if with_labels:
        ax.set_xlabel(r"$2\theta$ ($^\circ$)")
        ax.set_ylabel("Intensities (scaled)")

    # NOTE(review): matplotlib Axes objects do not normally provide
    # tight_layout, so this branch may never run — confirm intent.
    if hasattr(ax, "tight_layout"):
        ax.tight_layout()

    return plt
def _get_matplotlib_figure(self) -> plt.Figure:
    """
    Returns a matplotlib figure of reaction kinks diagram.

    Kinks are drawn as a navy line with markers, the minimum as a red
    star, and each kink is annotated with its products whose reaction
    coefficient is not close to zero. The figure is closed in pyplot
    before being returned.
    """
    pretty_plot(8, 5)
    plt.xlim([-0.05, 1.05])  # plot boundary is 5% wider on each side

    _, x, energy, reactions, _ = zip(*self.get_kinks())  # type: ignore

    plt.plot(x, energy, "o-", markersize=8, c="navy", zorder=1)
    plt.scatter(self.minimum[0], self.minimum[1], marker="*", c="red", s=400, zorder=2)

    arrow_style = {"arrowstyle": "->", "connectionstyle": "arc3,rad=0"}
    for x_coord, y_coord, rxn in zip(x, energy, reactions):
        products = ", ".join(
            latexify(p.reduced_formula)
            for p in rxn.products  # type: ignore
            if not np.isclose(rxn.get_coeff(p), 0)  # type: ignore
        )
        plt.annotate(
            products,
            xy=(x_coord, y_coord),
            xytext=(10, -30),
            textcoords="offset points",
            ha="right",
            va="bottom",
            arrowprops=dict(arrow_style),
        )

    plt.ylabel("Energy (eV/atom)" if self.norm else "Energy (eV/f.u.)")
    plt.xlabel(self._get_xaxis_title())
    plt.ylim(self.minimum[1] + 0.05 * self.minimum[1])  # plot boundary is 5% lower

    fig = plt.gcf()
    plt.close(fig)
    return fig
def get_plot(self, xlim=None, ylim=None):
    """
    Get a matplotlib plot showing the stored spectra.

    Spectrum number ``i`` is shifted vertically by ``self.yshift * i``.
    When ``self.stack`` is true each spectrum is drawn as a filled area on
    top of the running sum of the previous spectra; otherwise as a line.

    Args:
        xlim: Specifies the x-axis limits. Set to None for automatic
            determination.
        ylim: Specifies the y-axis limits.

    Returns:
        matplotlib.pyplot object.
    """
    # Height None asks pretty_plot for its automatic height (the previous
    # call passed 0, which is not a usable figure height).
    plt = pretty_plot(7, None)
    base = 0.0
    for i, (key, sp) in enumerate(self._spectra.items()):
        if self.stack:
            plt.fill_between(
                sp.x,
                base,
                sp.y + self.yshift * i,
                color=self.colors[i],
                label=str(key),
                linewidth=3,
            )
            base = sp.y + base
        else:
            plt.plot(
                sp.x,
                sp.y + self.yshift * i,
                color=self.colors[i],
                label=str(key),
                linewidth=3,
            )

    # Portuguese axis labels: "Wavenumber (cm^-1)" and "Intensity (a.u.)".
    plt.xlabel('Número de onda ' + r'($cm^{-1}$)')
    plt.ylabel('Intensidade (u.a.)')

    if xlim:
        plt.xlim(xlim)
    if ylim:
        plt.ylim(ylim)

    # Legend entries follow the insertion order of self._spectra.
    plt.legend()
    leg = plt.gca().get_legend()
    if leg is not None:
        plt.setp(leg.get_texts(), fontsize=30)
    plt.tight_layout()
    return plt
def get_arrhenius_plot(temps, diffusivities, diffusivity_errors=None,
                       **kwargs):
    r"""
    Returns an Arrhenius plot.

    Args:
        temps ([float]): A sequence of temperatures.
        diffusivities ([float]): A sequence of diffusivities (e.g.,
            from DiffusionAnalyzer.diffusivity).
        diffusivity_errors ([float]): A sequence of errors for the
            diffusivities. If None, no error bar is plotted.
        \\*\\*kwargs:
            Any keyword args supported by matplotlib.pyplot.plot.

    Returns:
        A matplotlib.pyplot object. Do plt.show() to show the plot.
    """
    Ea, c, _ = fit_arrhenius(temps, diffusivities)

    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8)

    # Fitted Arrhenius curve c * exp(-Ea / (kB * T)); const.k / const.e
    # expresses kB in eV/K (assuming const is scipy.constants).
    arr = c * np.exp(-Ea / (const.k / const.e * np.array(temps)))
    t_1 = 1000 / np.array(temps)

    plt.plot(t_1, diffusivities, 'ko', t_1, arr, 'k--', markersize=10,
             **kwargs)
    if diffusivity_errors is not None:
        n = len(diffusivity_errors)
        plt.errorbar(t_1[0:n], diffusivities[0:n], yerr=diffusivity_errors,
                     fmt='ko', ecolor='k', capthick=2, linewidth=2)
    # Use the axes that holds the data; plt.axes() without arguments can
    # create a new empty axes in recent matplotlib versions.
    ax = plt.gca()
    ax.set_yscale('log')
    plt.text(0.6, 0.85, "E$_a$ = {:.0f} meV".format(Ea * 1000),
             fontsize=30, transform=ax.transAxes)
    plt.ylabel("D (cm$^2$/s)")
    plt.xlabel("1000/T (K$^{-1}$)")
    plt.tight_layout()
    return plt
def get_msd_plot(self, plt=None, mode="specie"):
    """
    Get the plot of the smoothed msd vs time graph. Useful for
    checking convergence. This can be written to an image file.

    Args:
        plt: A plot object. Defaults to None, which means one will be
            generated.
        mode (str): "species" plots the mean squared displacement per
            element, "sites" plots each site separately, "mscd" plots the
            smoothed mscd, and any other value (including the default
            "specie") plots the overall msd with its a/b/c components.

    Returns:
        matplotlib.pyplot object.
    """
    from pymatgen.util.plotting import pretty_plot

    plt = pretty_plot(12, 8, plt=plt)

    # Long trajectories are shown in picoseconds instead of femtoseconds.
    long_run = np.max(self.dt) > 100000
    plot_dt = self.dt / 1000 if long_run else self.dt
    unit = "ps" if long_run else "fs"
    legend_opts = {"loc": 2, "prop": {"size": 20}}

    if mode == "species":
        for element in sorted(self.structure.composition.keys()):
            rows = [idx for idx, site in enumerate(self.structure) if site.specie == element]
            mean_sd = np.average(self.sq_disp_ions[rows, :], axis=0)
            plt.plot(plot_dt, mean_sd, label=str(element))
        plt.legend(**legend_opts)
    elif mode == "sites":
        for idx, site in enumerate(self.structure):
            plt.plot(plot_dt, self.sq_disp_ions[idx, :], label="%s - %d" % (str(site.specie), idx))
        plt.legend(**legend_opts)
    elif mode == "mscd":
        plt.plot(plot_dt, self.mscd, "r")
        plt.legend(["Overall"], **legend_opts)
    else:
        # Default / unrecognised mode: overall msd plus directional parts.
        plt.plot(plot_dt, self.msd, "k")
        for column, style in enumerate("rgb"):
            plt.plot(plot_dt, self.msd_components[:, column], style)
        plt.legend(["Overall", "a", "b", "c"], **legend_opts)

    plt.xlabel("Timestep ({})".format(unit))
    plt.ylabel("MSCD ($\\AA^2$)" if mode == "mscd" else "MSD ($\\AA^2$)")
    plt.tight_layout()
    return plt
def get_rdf_plot(
    self,
    label: str = None,
    xlim: tuple = (0.0, 8.0),
    ylim: tuple = (-0.005, 3.0),
    loc_peak: bool = False,
):
    """
    Plot the average RDF function.

    Args:
        label (str): The legend label. When None, built from the element
            symbols of the first structure that are in ``self.species``
            (a single symbol as-is, several joined by "-").
        xlim (list): Set the x limits of the current axes.
        ylim (list): Set the y limits of the current axes.
        loc_peak (bool): Label peaks if True (marks ``self.peak_r`` /
            ``self.peak_rdf``).

    Returns:
        matplotlib.pyplot object.
    """
    if label is None:
        symbols = [
            elem.symbol
            for elem in self.structures[0].composition.keys()
            if elem.symbol in self.species
        ]
        label = symbols[0] if len(symbols) == 1 else "-".join(symbols)

    plt = pretty_plot(12, 8)
    plt.plot(self.interval, self.rdf, label=label, linewidth=4.0, zorder=1)

    if loc_peak:
        peak_style = dict(marker="P", s=240, c="k", linewidths=0.1, alpha=0.7, zorder=2)
        plt.scatter(self.peak_r, self.peak_rdf, label="Peaks", **peak_style)

    plt.xlabel("$r$ ($\\rm\\AA$)")
    plt.ylabel("$g(r)$")
    plt.legend(loc="upper right", fontsize=36)
    plt.xlim(xlim[0], xlim[1])
    plt.ylim(ylim[0], ylim[1])
    plt.tight_layout()
    return plt
def plot_something():
    """
    Draw the line through (1, 1), (2, 2), (3, 3) on a 6 x 5.5 pretty_plot.

    Side effect: switches matplotlib's global rc settings to TeX text
    rendering and a serif font family before plotting.

    Returns:
        matplotlib.pyplot object.
    """
    rc_settings = (("text", {"usetex": True}), ("font", {"family": "serif"}))
    for group, options in rc_settings:
        matplotlib.rc(group, **options)

    plt = pretty_plot(6, 5.5)
    xs = [1, 2, 3]
    ys = [1, 2, 3]
    plt.plot(xs, ys)
    plt.tight_layout()
    return plt
def get_framework_rms_plot(self, plt=None, granularity=200,
                           matching_s=None):
    """
    Get the plot of rms framework displacement vs time. Useful for checking
    for melting, especially if framework atoms can move via paddle-wheel
    or similar mechanism (which would show up in max framework displacement
    but doesn't constitute melting).

    Args:
        plt (matplotlib.pyplot): If plt is supplied, changes will be made to an
            existing plot. Otherwise, a new plot will be created.
        granularity (int): Number of structures to match
        matching_s (Structure): Optionally match to a disordered structure
            instead of the first structure in the analyzer. Required when
            a secondary mobile ion is present.

    Returns:
        matplotlib.pyplot object.

    Notes:
        The method doesn't apply to NPT-AIMD simulation analysis.
    """
    from pymatgen.util.plotting import pretty_plot
    if self.lattices is not None and len(self.lattices) > 1:
        warnings.warn("Note the method doesn't apply to NPT-AIMD "
                      "simulation analysis!")

    plt = pretty_plot(12, 8, plt=plt)
    n_frames = self.corrected_displacements.shape[1]
    # Clamp to >= 1: with fewer frames than `granularity` the division gives
    # 0, which would zero max_dt and collapse the time axis; granularity=1
    # would otherwise raise ZeroDivisionError.
    step = max((n_frames - 1) // max(granularity - 1, 1), 1)

    # Reference framework with the mobile species removed.
    f = (matching_s or self.structure).copy()
    f.remove_species([self.specie])
    sm = StructureMatcher(primitive_cell=False, stol=0.6,
                          comparator=OrderDisorderElementComparator(),
                          allow_subset=True)
    rms = []
    for s in self.get_drift_corrected_structures(step=step):
        s.remove_species([self.specie])
        d = sm.get_rms_dist(f, s)
        # A falsy result (no match) is recorded as (1, 1).
        rms.append(d if d else (1, 1))

    max_dt = (len(rms) - 1) * step * self.step_skip * self.time_step
    if max_dt > 100000:
        plot_dt = np.linspace(0, max_dt / 1000, len(rms))
        unit = 'ps'
    else:
        plot_dt = np.linspace(0, max_dt, len(rms))
        unit = 'fs'
    rms = np.array(rms)
    plt.plot(plot_dt, rms[:, 0], label='RMS')
    plt.plot(plot_dt, rms[:, 1], label='max')
    plt.legend(loc='best')
    plt.xlabel("Timestep ({})".format(unit))
    plt.ylabel("normalized distance")
    plt.tight_layout()
    return plt
def get_locpot_along_slab_plot(self, label_energies=True, plt=None, label_fontsize=10):
    """
    Returns a plot of the local potential (eV) vs the position along the
    c axis of the slab model (Ang).

    Two curves are drawn: the raw ``self.locpot_along_c`` signal (blue,
    dashed) and an averaged signal (red) in which every point is replaced
    by ``self.ave_bulk_p`` when the raw value is below it or when its
    position lies between the first and last entry of ``self.sorted_sites``
    along ``self.direction``.

    Args:
        label_energies (bool): Whether to label relevant energy
            quantities such as the work function, Fermi energy, vacuum
            locpot, bulk-like locpot
        plt (plt): Matplotlib pylab object
        label_fontsize (float): Fontsize of labels

    Returns:
        plt of the locpot vs c axis
    """
    if not plt:
        plt = pretty_plot(width=10, height=8)

    # Raw signal.
    plt.plot(self.along_c, self.locpot_along_c, 'b--')

    # Averaged signal as (position, value) pairs.
    points = []
    for idx, pot in enumerate(self.locpot_along_c):
        pos = self.along_c[idx]
        use_bulk = pot < self.ave_bulk_p or (
            self.sorted_sites[-1].frac_coords[self.direction]
            >= pos
            >= self.sorted_sites[0].frac_coords[self.direction]
        )
        points.append((pos, self.ave_bulk_p if use_bulk else pot))
    xg, yg = zip(*sorted(points))
    plt.plot(xg, yg, 'r', linewidth=2.5, zorder=-1)

    if label_energies:
        plt = self.get_labels(plt, label_fontsize=label_fontsize)
    plt.xlim([0, 1])
    plt.ylim([min(self.locpot_along_c), self.vacuum_locpot + self.ave_locpot * 0.2])

    # Axis label only for the three lattice directions a, b, c.
    for axis_index, letter in enumerate("abc"):
        if self.direction == axis_index:
            plt.xlabel(r"Fractional coordinates ($\hat{%s}$)" % letter, fontsize=25)
            break

    plt.xticks(fontsize=15, rotation=45)
    plt.ylabel(r"Potential (eV)", fontsize=25)
    plt.yticks(fontsize=15)
    return plt
def get_1d_plot(self, mode: str = "distinct", times: List = None, colors: List = None):
    """
    Plot the van Hove function at given r or t.

    Args:
        mode (str): Specify which part of van Hove function to be plotted:
            "distinct" or "self".
        times (list of float): Time moments (in ps) in which the van Hove
            function will be plotted. Defaults to [0.0]. Each time is
            rounded to the nearest stored frame and clamped to the
            available range.
        colors (list strings/tuples): Additional color settings. If not
            set, seaborn.color_palette("Set1", 10) will be used.

    Returns:
        matplotlib.pyplot object.
    """
    # Fresh list per call instead of a shared mutable default.
    if times is None:
        times = [0.0]
    if colors is None:
        import seaborn as sns
        colors = sns.color_palette("Set1", 10)

    assert mode in ["distinct", "self"]
    assert len(times) <= len(colors)

    if mode == "distinct":
        grt = self.gdrt.copy()
        ylabel = "$G_d$($t$,$r$)"
        ylim = [-0.005, 4.0]
    elif mode == "self":
        grt = self.gsrt.copy()
        # Raw string: "\p" is an invalid escape sequence otherwise.
        ylabel = r"4$\pi r^2G_s$($t$,$r$)"
        ylim = [-0.005, 1.0]

    plt = pretty_plot(12, 8)

    last_frame = np.shape(grt)[0] - 1
    for i, time in enumerate(times):
        # Nearest stored frame, clamped to [0, last_frame] so negative
        # times do not silently wrap around to the end of the array.
        index = int(np.round(time / self.timeskip))
        index = min(max(index, 0), last_frame)
        new_time = index * self.timeskip
        label = str(new_time) + " ps"
        plt.plot(self.interval, grt[index], color=colors[i], label=label, linewidth=4.0)

    plt.xlabel(r"$r$ ($\AA$)")
    plt.ylabel(ylabel)
    plt.legend(loc="upper right", fontsize=36)
    plt.xlim(0.0, self.interval[-1] - 1.0)
    plt.ylim(ylim[0], ylim[1])
    plt.tight_layout()
    return plt
def get_msd_plot(self, plt=None, mode="specie"):
    """
    Get the plot of the smoothed msd vs time graph. Useful for
    checking convergence. This can be written to an image file.

    Args:
        plt: A plot object. Defaults to None, which means one will be
            generated.
        mode (str): "species" -> one curve per element, "sites" -> one
            curve per site, "mscd" -> smoothed mscd; any other value
            (including the default "specie") -> overall msd together with
            its a/b/c components.

    Returns:
        matplotlib.pyplot object.
    """
    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8, plt=plt)

    # Switch to picoseconds when the time span is large.
    in_ps = np.max(self.dt) > 100000
    plot_dt = self.dt / 1000 if in_ps else self.dt
    unit = 'ps' if in_ps else 'fs'
    legend_opts = {"loc": 2, "prop": {"size": 20}}

    if mode == "species":
        for sp in sorted(self.structure.composition.keys()):
            members = [n for n, site in enumerate(self.structure) if site.specie == sp]
            avg_sd = np.average(self.sq_disp_ions[members, :], axis=0)
            plt.plot(plot_dt, avg_sd, label=str(sp))
        plt.legend(**legend_opts)
    elif mode == "sites":
        for n, site in enumerate(self.structure):
            plt.plot(plot_dt, self.sq_disp_ions[n, :], label="%s - %d" % (str(site.specie), n))
        plt.legend(**legend_opts)
    elif mode == "mscd":
        plt.plot(plot_dt, self.mscd, 'r')
        plt.legend(["Overall"], **legend_opts)
    else:
        # Default / unrecognised mode.
        plt.plot(plot_dt, self.msd, 'k')
        for col, style in zip((0, 1, 2), ('r', 'g', 'b')):
            plt.plot(plot_dt, self.msd_components[:, col], style)
        plt.legend(["Overall", "a", "b", "c"], **legend_opts)

    plt.xlabel("Timestep ({})".format(unit))
    if mode == "mscd":
        y_label = "MSCD ($\\AA^2$)"
    else:
        y_label = "MSD ($\\AA^2$)"
    plt.ylabel(y_label)
    plt.tight_layout()
    return plt
def get_pourbaix_plot(self, limits=None, title="", label_domains=True, plt=None):
    """
    Plot Pourbaix diagram.

    Draws two red dashed sloped reference lines (offset by 1.23 V), a
    vertical line at pH 7, a horizontal line at E = 0, and the closed
    outline of every stable domain, optionally labelled at the average of
    its vertices.

    Args:
        limits: 2D list containing limits of the Pourbaix diagram
            of the form [[xlo, xhi], [ylo, yhi]]
        title (str): Title to display on plot
        label_domains (bool): whether to label pourbaix domains
        plt (pyplot): Pyplot instance for plotting

    Returns:
        plt (pyplot) - matplotlib plot object with pourbaix diagram
    """
    if limits is None:
        limits = [[-2, 16], [-3, 3]]

    plt = plt or pretty_plot(16)

    xlim = limits[0]
    ylim = limits[1]

    def as_xy(p, q):
        # Pair of (x, y) points -> [[x0, x1], [y0, y1]].
        return np.transpose([p, q])

    guides = [
        (as_xy([xlim[0], -xlim[0] * PREFAC], [xlim[1], -xlim[1] * PREFAC]), "r--"),
        (as_xy([xlim[0], -xlim[0] * PREFAC + 1.23], [xlim[1], -xlim[1] * PREFAC + 1.23]), "r--"),
        (as_xy([7, ylim[0]], [7, ylim[1]]), "k-."),
        (as_xy([xlim[0], 0], [xlim[1], 0]), "k-."),
    ]

    ax = plt.gca()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    lw = 3

    for line, fmt in guides:
        plt.plot(line[0], line[1], fmt, linewidth=lw)

    for entry, corners in self._pd._stable_domain_vertices.items():
        centroid = np.average(corners, axis=0)
        # Append the first vertex again to close the polygon.
        outline_x, outline_y = np.transpose(np.vstack([corners, corners[0]]))
        plt.plot(outline_x, outline_y, 'k-', linewidth=lw)
        if label_domains:
            plt.annotate(generate_entry_label(entry), centroid, ha='center',
                         va='center', fontsize=20, color="b")

    plt.xlabel("pH")
    plt.ylabel("E (V)")
    plt.title(title, fontsize=20, fontweight='bold')
    return plt
def main():
    """
    Main function.

    Reads an xmu file and a feff.inp file named on the command line, then
    plots the single-atom and in-material absorption cross-sections
    against energy and shows the figure.
    """
    parser = argparse.ArgumentParser(description=""" Convenient DOS Plotter for Feff runs. Author: Alan Dozier Version: 1.0 Last updated: April, 2013""")
    parser.add_argument("filename", metavar="filename", type=str, nargs=1,
                        help="xmu file to plot")
    parser.add_argument("filename1", metavar="filename1", type=str, nargs=1,
                        help="feff.inp filename to import")

    plt = pretty_plot(12, 8)
    color_order = ["r", "b", "g", "c", "k", "m", "y"]

    args = parser.parse_args()
    xmu = Xmu.from_file(args.filename[0], args.filename1[0])
    data = xmu.as_dict()

    plt.title(data["calc"] + " Feff9.6 Calculation for " + data["atom"]
              + " in " + data["formula"] + " unit cell")
    plt.xlabel("Energies (eV)")
    plt.ylabel("Absorption Cross-section")

    energies = data["energies"]

    single_xs = data["scross"]
    single_label = "Single " + data["atom"] + " " + data["edge"] + " edge"
    plt.plot(energies, single_xs, color_order[1], label=single_label)

    bulk_xs = data["across"]
    bulk_label = data["atom"] + " " + data["edge"] + " edge in " + data["formula"]
    plt.plot(energies, bulk_xs, color_order[2], label=bulk_label)

    plt.legend()
    legend_texts = plt.gca().get_legend().get_texts()
    plt.setp(legend_texts, fontsize=15)
    plt.tight_layout()
    plt.show()
def get_plot(self, ylim=None):
    """
    Get a matplotlib object for the bandstructure plot.

    Note: enables TeX text rendering in matplotlib's global rc settings.

    Args:
        ylim: Specify the y-axis (frequency) limits; by default None let
            the code choose.

    Returns:
        matplotlib.pyplot object.
    """
    plt = pretty_plot(12, 8)
    from matplotlib import rc

    try:
        rc('text', usetex=True)
    except Exception:
        # Fall back on non Tex if errored.
        rc('text', usetex=False)

    band_linewidth = 1

    data = self.bs_plot_data()
    for d, distances in enumerate(data['distances']):
        for i in range(self._nb_bands):
            plt.plot(distances,
                     [data['frequency'][d][i][j] for j in range(len(distances))],
                     'b-', linewidth=band_linewidth)

    self._maketicks(plt)

    # plot y=0 line
    plt.axhline(0, linewidth=1, color='k')

    # Main X and Y Labels
    plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
    ylabel = r'$\mathrm{Frequency\ (THz)}$'
    plt.ylabel(ylabel, fontsize=30)

    # X range: up to the last distance point.
    x_max = data['distances'][-1][-1]
    plt.xlim(0, x_max)

    if ylim is not None:
        plt.ylim(ylim)

    plt.tight_layout()

    return plt
def plot(self, width=8, height=None, plt=None, dpi=None, **kwargs):
    """
    Plot the equation of state.

    Data points are drawn as markers; the fit is drawn dashed over the
    volume range extended by 1% on either side; a text block with the fit
    parameters is placed at axes position (0.4, 0.5).

    Args:
        width (float): Width of plot in inches. Defaults to 8in.
        height (float): Height of plot in inches. Defaults to width *
            golden ratio.
        plt (matplotlib.pyplot): If plt is supplied, changes will be made
            to an existing plot. Otherwise, a new plot will be created.
        dpi: Resolution forwarded to pretty_plot.
        kwargs (dict): Optional overrides; only "color" (default "r"),
            "label" (default "<ClassName> fit") and "text" (replaces the
            parameter summary) are read.

    Returns:
        Matplotlib plot object.
    """
    plt = pretty_plot(width=width, height=height, plt=plt, dpi=dpi)

    eos_name = self.__class__.__name__
    color = kwargs.get("color", "r")
    label = kwargs.get("label", f"{eos_name} fit")

    summary_lines = [
        "Equation of State: %s" % eos_name,
        "Minimum energy = %1.2f eV" % self.e0,
        "Minimum or reference volume = %1.2f Ang^3" % self.v0,
        "Bulk modulus = %1.2f eV/Ang^3 = %1.2f GPa" % (self.b0, self.b0_GPa),
        "Derivative of bulk modulus wrt pressure = %1.2f" % self.b1,
    ]
    text = kwargs.get("text", "\n".join(summary_lines))

    # Input data, markers only.
    plt.plot(self.volumes, self.energies, linestyle="None", marker="o", color=color)

    # EOS fit across a slightly widened volume window.
    smallest, largest = min(self.volumes), max(self.volumes)
    lower = smallest - 0.01 * abs(smallest)
    upper = largest + 0.01 * abs(largest)
    fit_volumes = np.linspace(lower, upper, 100)
    plt.plot(fit_volumes, self.func(fit_volumes), linestyle="dashed", color=color, label=label)

    plt.grid(True)
    plt.xlabel("Volume $\\AA^3$")
    plt.ylabel("Energy (eV)")
    plt.legend(loc="best", shadow=True)
    plt.text(0.4, 0.5, text, transform=plt.gca().transAxes)

    return plt
def get_plot(self, ylim=None, units="thz"):
    """
    Build a matplotlib plot of the phonon band structure.

    Args:
        ylim: Optional y-axis (frequency) limits; None leaves the automatic
            limits in place.
        units: Units for the frequency axis. Accepted values: thz, ev, mev,
            ha, cm-1, cm^-1.

    Returns:
        The matplotlib.pyplot object produced by ``pretty_plot``.
    """
    unit = freq_units(units)
    plt = pretty_plot(12, 8)

    band_linewidth = 1
    plot_data = self.bs_plot_data()
    distances = plot_data["distances"]
    frequencies = plot_data["frequency"]

    # One line per band per path segment, rescaled into the requested unit.
    for seg in range(len(distances)):
        seg_dist = distances[seg]
        npts = len(seg_dist)
        for band in range(self._nb_bands):
            band_freqs = [frequencies[seg][band][k] * unit.factor
                          for k in range(npts)]
            plt.plot(seg_dist, band_freqs, "b-", linewidth=band_linewidth)

    self._maketicks(plt)

    # Zero-frequency reference line.
    plt.axhline(0, linewidth=1, color="k")

    plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
    plt.ylabel(r"$\mathrm{{Frequencies\ ({})}}$".format(unit.label),
               fontsize=30)

    # x axis runs from 0 to the final distance point of the final segment.
    plt.xlim(0, distances[-1][-1])

    if ylim is not None:
        plt.ylim(ylim)

    plt.tight_layout()
    return plt
def get_plot(self, ylim=None):
    """
    Get a matplotlib object for the bandstructure plot.

    Args:
        ylim: Specify the y-axis (frequency) limits; by default None let
            the code choose.

    Returns:
        The matplotlib.pyplot object from ``pretty_plot`` with every band of
        every path segment drawn, labelled in THz.
    """
    plt = pretty_plot(12, 8)
    from matplotlib import rc

    # Request TeX text rendering; fall back to matplotlib's own renderer if
    # the rc setting is rejected. Catch Exception only, so that
    # KeyboardInterrupt / SystemExit are not swallowed.
    try:
        rc('text', usetex=True)
    except Exception:
        rc('text', usetex=False)

    band_linewidth = 1

    data = self.bs_plot_data()
    for d in range(len(data['distances'])):
        for i in range(self._nb_bands):
            plt.plot(data['distances'][d],
                     [data['frequency'][d][i][j]
                      for j in range(len(data['distances'][d]))],
                     'b-', linewidth=band_linewidth)

    self._maketicks(plt)

    # plot y=0 line
    plt.axhline(0, linewidth=1, color='k')

    # Main X and Y Labels
    plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
    ylabel = r'$\mathrm{Frequency\ (THz)}$'
    plt.ylabel(ylabel, fontsize=30)

    # X range: 0 to the last distance point of the last segment
    x_max = data['distances'][-1][-1]
    plt.xlim(0, x_max)

    if ylim is not None:
        plt.ylim(ylim)

    plt.tight_layout()

    return plt
def main():
    """
    Command-line entry point: read a FEFF xmu file (plus its feff.inp) and
    plot the single-atom and in-structure absorption cross-sections, then
    show the figure.
    """
    description = (' Convenient DOS Plotter for Feff runs.'
                   ' Author: Alan Dozier'
                   ' Version: 1.0'
                   ' Last updated: April, 2013')
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('filename', metavar='filename', type=str, nargs=1,
                        help='xmu file to plot')
    parser.add_argument('filename1', metavar='filename1', type=str, nargs=1,
                        help='feff.inp filename to import')

    plt = pretty_plot(12, 8)
    color_order = ['r', 'b', 'g', 'c', 'k', 'm', 'y']

    args = parser.parse_args()
    xmu = Xmu.from_file(args.filename[0], args.filename1[0])
    data = xmu.to_dict

    plt.title(''.join([data['calc'], ' Feff9.6 Calculation for ',
                       data['atom'], ' in ', data['formula'], ' unit cell']))
    plt.xlabel('Energies (eV)')
    plt.ylabel('Absorption Cross-section')

    energies = data['energies']

    single_label = ''.join(['Single ', data['atom'], ' ', data['edge'], ' edge'])
    plt.plot(energies, data['scross'], color_order[1 % 7], label=single_label)

    cell_label = ''.join([data['atom'], ' ', data['edge'], ' edge in ',
                          data['formula']])
    plt.plot(energies, data['across'], color_order[2 % 7], label=cell_label)

    plt.legend()
    # Enlarge every text entry of the legend just created.
    legend_texts = plt.gca().get_legend().get_texts()
    plt.setp(legend_texts, fontsize=15)
    plt.tight_layout()
    plt.show()
def get_plot_gs(self, ylim=None):
    """
    Build a matplotlib plot of the Grüneisen parameter along the band path.

    Args:
        ylim: Optional y-axis (gruneisen) limits; None leaves the automatic
            limits in place.

    Returns:
        The matplotlib.pyplot object produced by ``pretty_plot``.
    """
    plt = pretty_plot(12, 8)

    data = self.bs_plot_data()
    for d, dist in enumerate(data["distances"]):
        gs_segment = data["gruneisen"][d]
        for band in range(self._nb_bands):
            band_values = [gs_segment[band][j] for j in range(len(dist))]
            plt.plot(
                dist,
                band_values,
                "b-",
                marker="o",
                markersize=2,
                linewidth=2,
            )

    self._maketicks(plt)

    # Zero reference line.
    plt.axhline(0, linewidth=1, color="k")

    plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
    plt.ylabel(r"$\mathrm{Grüneisen\ Parameter}$", fontsize=30)

    # x axis runs from 0 to the last distance point of the last segment.
    plt.xlim(0, data["distances"][-1][-1])

    if ylim is not None:
        plt.ylim(ylim)

    plt.tight_layout()
    return plt
def get_plot(self, xlim=None, ylim=None):
    """
    Get a matplotlib plot showing the spectra.

    Unstacked, each spectrum is drawn as a line offset vertically by
    ``yshift * index``. Stacked, each spectrum is a filled band whose lower
    edge is the running sum of the spectra before it and whose upper edge is
    that running sum plus the spectrum (plus ``yshift * index``).

    Args:
        xlim: Specifies the x-axis limits. Set to None for automatic
            determination.
        ylim: Specifies the y-axis limits.

    Returns:
        matplotlib.pyplot object.
    """
    plt = pretty_plot(12, 8)
    base = 0.0
    for i, (key, sp) in enumerate(self._spectra.items()):
        if not self.stack:
            plt.plot(sp.x, sp.y + self.yshift * i, color=self.colors[i],
                     label=str(key), linewidth=3)
        else:
            # The upper edge must sit on top of the accumulated base so the
            # bands stack; using sp.y alone can put it below its lower edge.
            top = base + sp.y
            plt.fill_between(sp.x, base, top + self.yshift * i,
                             color=self.colors[i], label=str(key),
                             linewidth=3)
            base = top
        plt.xlabel(sp.XLABEL)
        plt.ylabel(sp.YLABEL)

    if xlim:
        plt.xlim(xlim)
    if ylim:
        plt.ylim(ylim)

    plt.legend()
    leg = plt.gca().get_legend()
    ltext = leg.get_texts()  # all the text.Text instance in the legend
    plt.setp(ltext, fontsize=30)
    plt.tight_layout()
    return plt
def get_plot(self, ylim=None, units="thz"):
    """
    Build a matplotlib plot of the phonon band structure.

    Args:
        ylim: Optional y-axis (frequency) limits; None keeps the automatic
            limits.
        units: Frequency units for the y axis. Accepted values: thz, ev,
            mev, ha, cm-1, cm^-1.

    Returns:
        The matplotlib.pyplot object produced by ``pretty_plot``.
    """
    u = freq_units(units)
    plt = pretty_plot(12, 8)
    band_linewidth = 1

    data = self.bs_plot_data()
    for d, dist in enumerate(data['distances']):
        seg_freqs = data['frequency'][d]
        npts = len(dist)
        for i in range(self._nb_bands):
            scaled = [seg_freqs[i][j] * u.factor for j in range(npts)]
            plt.plot(dist, scaled, 'b-', linewidth=band_linewidth)

    self._maketicks(plt)

    # Horizontal line at zero frequency.
    plt.axhline(0, linewidth=1, color='k')

    plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
    ylabel = r'$\mathrm{{Frequencies\ ({})}}$'.format(u.label)
    plt.ylabel(ylabel, fontsize=30)

    # Span the x axis up to the final distance point.
    last_distance = data['distances'][-1][-1]
    plt.xlim(0, last_distance)

    if ylim is not None:
        plt.ylim(ylim)

    plt.tight_layout()
    return plt
def get_arrhenius_plot(temps, diffusivities, diffusivity_errors=None,
                       **kwargs):
    """
    Returns an Arrhenius plot.

    Args:
        temps ([float]): A sequence of temperatures.
        diffusivities ([float]): A sequence of diffusivities (e.g.,
            from DiffusionAnalyzer.diffusivity).
        diffusivity_errors ([float]): A sequence of errors for the
            diffusivities. If None, no error bar is plotted.
        \\*\\*kwargs:
            Any keyword args supported by matplotlib.pyplot.plot.

    Returns:
        A matplotlib.pyplot object. Do plt.show() to show the plot.
    """
    Ea, c, _ = fit_arrhenius(temps, diffusivities)

    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8)

    # Arrhenius fit c * exp(-Ea / (k T)) with k = const.k / const.e
    # (Ea is reported below in meV as Ea * 1000).
    arr = c * np.exp(-Ea / (const.k / const.e * np.array(temps)))
    t_1 = 1000 / np.array(temps)

    plt.plot(t_1, diffusivities, 'ko', t_1, arr, 'k--', markersize=10,
             **kwargs)
    if diffusivity_errors is not None:
        n = len(diffusivity_errors)
        plt.errorbar(t_1[0:n], diffusivities[0:n], yerr=diffusivity_errors,
                     fmt='ko', ecolor='k', capthick=2, linewidth=2)

    # Work on the Axes that already holds the data. plt.axes() with no
    # arguments adds a new Axes instead of returning the current one.
    ax = plt.gca()
    ax.set_yscale('log')
    plt.text(0.6, 0.85, "E$_a$ = {:.0f} meV".format(Ea * 1000),
             fontsize=30, transform=ax.transAxes)
    plt.ylabel("D (cm$^2$/s)")
    plt.xlabel("1000/T (K$^{-1}$)")
    plt.tight_layout()
    return plt
def main():
    """
    Command-line entry point: plot the absorption cross-sections stored in
    a FEFF xmu file (read together with its feff.inp) and show the figure.
    """
    parser = argparse.ArgumentParser(
        description=(' Convenient DOS Plotter for Feff runs.'
                     ' Author: Alan Dozier'
                     ' Version: 1.0'
                     ' Last updated: April, 2013'))
    parser.add_argument('filename', metavar='filename', type=str, nargs=1,
                        help='xmu file to plot')
    parser.add_argument('filename1', metavar='filename1', type=str, nargs=1,
                        help='feff.inp filename to import')

    plt = pretty_plot(12, 8)
    palette = ['r', 'b', 'g', 'c', 'k', 'm', 'y']

    opts = parser.parse_args()
    xmu = Xmu.from_file(opts.filename[0], opts.filename1[0])
    info = xmu.to_dict

    plot_title = ''.join([info['calc'], ' Feff9.6 Calculation for ',
                          info['atom'], ' in ', info['formula'], ' unit cell'])
    plt.title(plot_title)
    plt.xlabel('Energies (eV)')
    plt.ylabel('Absorption Cross-section')

    x = info['energies']

    # Isolated-atom cross-section.
    y_single = info['scross']
    plt.plot(x, y_single, palette[1 % 7],
             label=''.join(['Single ', info['atom'], ' ', info['edge'], ' edge']))

    # Cross-section of the atom within the structure.
    y_cell = info['across']
    plt.plot(x, y_cell, palette[2 % 7],
             label=''.join([info['atom'], ' ', info['edge'], ' edge in ',
                            info['formula']]))

    plt.legend()
    plt.setp(plt.gca().get_legend().get_texts(), fontsize=15)
    plt.tight_layout()
    plt.show()
def get_chgint_plot(args):
    """
    Plot integrated charge versus radius for selected atoms of a CHGCAR.

    Args:
        args: Parsed command-line namespace. ``args.chgcar_file`` is the
            CHGCAR path, ``args.inds`` optionally holds a comma-separated
            string of atom indices as its first element, and ``args.radius``
            is forwarded to ``Chgcar.get_integrated_diff``. Without
            ``args.inds``, one representative atom per symmetry-equivalent
            group is used.

    Returns:
        The matplotlib.pyplot object produced by ``pretty_plot``.
    """
    chgcar = Chgcar.from_file(args.chgcar_file)
    struct = chgcar.structure

    if args.inds:
        atom_ind = [int(tok) for tok in args.inds[0].split(",")]
    else:
        finder = SpacegroupAnalyzer(struct, symprec=0.1)
        equiv_groups = finder.get_symmetrized_structure().equivalent_sites
        representatives = [group[0] for group in equiv_groups]
        atom_ind = [struct.sites.index(site) for site in representatives]

    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8)
    for idx in atom_ind:
        # 30 is forwarded unchanged as the third argument of
        # get_integrated_diff (presumably a point count -- TODO confirm).
        curve = chgcar.get_integrated_diff(idx, args.radius, 30)
        plt.plot(curve[:, 0], curve[:, 1],
                 label="Atom {} - {}".format(idx, struct[idx].species_string))
    plt.legend(loc="upper left")
    plt.xlabel("Radius (A)")
    plt.ylabel("Integrated charge (e)")
    plt.tight_layout()
    return plt
def get_plot(self, xlim=None, ylim=None):
    """
    Get a matplotlib plot showing the spectra.

    Unstacked, each spectrum is drawn as a line offset vertically by
    ``yshift * index``. Stacked, each spectrum is a filled band from the
    running sum of the previous spectra up to that sum plus the spectrum
    (plus ``yshift * index``).

    Args:
        xlim: Specifies the x-axis limits. Set to None for automatic
            determination.
        ylim: Specifies the y-axis limits.

    Returns:
        matplotlib.pyplot object.
    """
    plt = pretty_plot(12, 8)
    base = 0.0
    for i, (key, sp) in enumerate(self._spectra.items()):
        if not self.stack:
            plt.plot(sp.x, sp.y + self.yshift * i, color=self.colors[i],
                     label=str(key), linewidth=3)
        else:
            # Upper edge = accumulated base + this spectrum, so each band is
            # stacked on the previous ones rather than overlapping them.
            top = base + sp.y
            plt.fill_between(sp.x, base, top + self.yshift * i,
                             color=self.colors[i], label=str(key),
                             linewidth=3)
            base = top
        plt.xlabel(sp.XLABEL)
        plt.ylabel(sp.YLABEL)

    if xlim:
        plt.xlim(xlim)
    if ylim:
        plt.ylim(ylim)

    plt.legend()
    leg = plt.gca().get_legend()
    ltext = leg.get_texts()  # all the text.Text instance in the legend
    plt.setp(ltext, fontsize=30)
    plt.tight_layout()
    return plt
def get_plot(self, width=8, height=8):
    """
    Returns a plot object with one voltage curve per stored electrode.

    Args:
        width: Width of the plot. Defaults to 8 in.
        height: Height of the plot. Defaults to 8 in.

    Returns:
        A matplotlib plot object.
    """
    plt = pretty_plot(width, height)
    for label, electrode in self._electrodes.items():
        (x, y) = self.get_plot_data(electrode)
        plt.plot(x, y, '-', linewidth=2, label=label)

    plt.legend()
    # The x quantity depends on the plotter's configured axis type.
    if self.xaxis == "capacity":
        plt.xlabel('Capacity (mAh/g)')
    else:
        plt.xlabel('Fraction')
    plt.ylabel('Voltage (V)')
    plt.tight_layout()
    return plt
def get_pourbaix_mark_passive(self, limits=None, title="", label_domains=True,
                              passive_entry=None):
    """
    Plot the Pourbaix diagram, color-filling selected solid domains.

    Stable domains whose name contains ``passive_entry + "(s)"`` are filled
    with the second palette color. Other solid ("(s)") domains whose entry
    contains O or H are filled with the first palette color. All remaining
    domains are left white.

    Args:
        limits: Optional [[xlo, xhi], [ylo, yhi]] plot limits. Defaults to
            the analyzer's chempot_limits.
        title (str): Plot title.
        label_domains (bool): Whether to annotate each domain with its name.
        passive_entry (str): Name prefix (used as ``passive_entry + "(s)"``)
            of the passivating solid. If None, the key of
            ``self._pd._elt_comp`` with the largest value is used
            (presumably the dominant element symbol -- confirm). If neither
            is available, no domain is marked as passivating.

    Returns:
        matplotlib.pyplot object.
    """
    from matplotlib.patches import Polygon
    from pymatgen import Element

    plt = pretty_plot(16)
    optim_colors = ['#0000FF', '#FF0000', '#00FF00', '#FFFF00', '#FF00FF',
                    '#FF8080', '#DCDCDC', '#800000', '#FF8000']
    optim_font_colors = ['#FFC000', '#00FFFF', '#FF00FF', '#0000FF', '#00FF00',
                         '#007F7F', '#232323', '#7FFFFF', '#007FFF']
    (stable, unstable) = self.pourbaix_plot_data(limits)
    mark_passive = {key: 0 for key in stable.keys()}

    # Derive a default only when the caller did not pass one; the caller's
    # value must not be overwritten.
    if passive_entry is None and self._pd._elt_comp:
        maxval = max(self._pd._elt_comp.values())
        passive_entry = [k for k, v in self._pd._elt_comp.items()
                         if v == maxval][0]

    def list_elts(entry):
        # Elements of an entry; for a MultiEntry, the union over its members.
        if isinstance(entry, MultiEntry):
            return {el for e in entry.entrylist
                    for el in e.composition.elements}
        return entry.composition.elements

    o_and_h = {Element("O"), Element("H")}
    passive_name = None if passive_entry is None else passive_entry + "(s)"
    for entry in stable:
        if passive_name is not None and passive_name in entry.name:
            mark_passive[entry] = 2
        elif "(s)" in entry.name and o_and_h.intersection(list_elts(entry)):
            mark_passive[entry] = 1

    if limits:
        xlim = limits[0]
        ylim = limits[1]
    else:
        xlim = self._analyzer.chempot_limits[0]
        ylim = self._analyzer.chempot_limits[1]

    h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                           [xlim[1], -xlim[1] * PREFAC]])
    o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                           [xlim[1], -xlim[1] * PREFAC + 1.23]])
    neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
    V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

    ax = plt.gca()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    for e in stable.keys():
        xy = self.domain_vertices(e)
        c = self.get_center(stable[e])
        if mark_passive[e] == 1:
            color, fontcolor, colorfill = optim_colors[0], optim_font_colors[0], True
        elif mark_passive[e] == 2:
            color, fontcolor, colorfill = optim_colors[1], optim_font_colors[1], True
        else:
            color, fontcolor, colorfill = "w", "k", False
        patch = Polygon(xy, facecolor=color, closed=True, lw=3.0,
                        fill=colorfill)
        ax.add_patch(patch)
        if label_domains:
            plt.annotate(self.print_name(e), c, color=fontcolor, fontsize=20)

    lw = 3
    plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
    plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
    plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
    plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

    plt.xlabel("pH")
    plt.ylabel("E (V)")
    plt.title(title, fontsize=20, fontweight='bold')
    return plt
def get_pourbaix_plot_colorfill_by_domain_name(self, limits=None, title="",
                                               label_domains=True,
                                               label_color='k',
                                               domain_color=None,
                                               domain_fontsize=None,
                                               domain_edge_lw=0.5,
                                               bold_domains=None,
                                               cluster_domains=(),
                                               add_h2o_stablity_line=True,
                                               add_center_line=False,
                                               h2o_lw=0.5, fill_domain=True,
                                               width=8, height=None,
                                               font_family='Times New Roman'):
    """
    Color domains by the colors specified in the domain_color dict.

    Args:
        limits: 2D list containing limits of the Pourbaix diagram
            of the form [[xlo, xhi], [ylo, yhi]]
        label_domains (bool): whether to add text labels for domains
        label_color (str): color of domain labels, defaults to black
        domain_color (dict): colors of each domain e.g {"Al(s)": "#FF1100"}.
            If set to None the default color set is used. Domains missing
            from a supplied dict are drawn white.
        domain_fontsize (int or dict): Font size used in domain text labels,
            either one size for every domain or a {domain name: size} dict.
            Domains missing from the dict use the default size (12).
        domain_edge_lw (float): line width for the boundaries between domains.
        bold_domains (list): List of domain names to use bold text style for
            domain labels. If None, every non-solid domain is bold. If set
            to False, no domain will be bold.
        cluster_domains (list): List of domain names in cluster phase
        add_h2o_stablity_line (bool): whether to plot the H2O stability lines
        add_center_line (bool): whether to plot the pH 7 and E = 0 lines
        h2o_lw (float): line width for H2O stability line and center lines
        fill_domain (bool): if False, domains are drawn without color fill.
        width (float): Width of plot in inches. Defaults to 8in.
        height (float): Height of plot in inches. Defaults to width * golden
            ratio.
        font_family (str): Font family of the labels

    Returns:
        matplotlib.pyplot object.
    """
    import os
    from matplotlib.patches import Polygon

    def special_lines(xlim, ylim):
        # Lines drawn across the x range: h_line (slope -PREFAC), o_line
        # (same slope, offset by 1.23), the pH 7 line and the E = 0 line.
        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])
        return h_line, o_line, neutral_line, V0_line

    default_domain_font_size = 12
    default_solid_phase_color = '#b8f9e7'    # slightly darker than the MP scheme, to
    default_cluster_phase_color = '#d0fbef'  # avoid making the cluster phase too light

    plt = pretty_plot(width=width, height=height, dpi=300)

    (stable, unstable) = self.pourbaix_plot_data(limits)

    xlim, ylim = limits[:2] if limits else self._analyzer.chempot_limits[:2]
    h_line, o_line, neutral_line, V0_line = special_lines(xlim, ylim)
    ax = plt.gca()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax.tick_params(direction='out')
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    # Highest-energy entries are drawn first.
    sorted_entry = sorted(stable.keys(), key=lambda en: en.energy, reverse=True)

    if domain_fontsize is None:
        domain_fontsize = {en.name: default_domain_font_size
                           for en in sorted_entry}
    elif isinstance(domain_fontsize, dict):
        domain_fontsize = {en.name: domain_fontsize.get(en.name,
                                                        default_domain_font_size)
                           for en in sorted_entry}
    else:
        domain_fontsize = {en.name: domain_fontsize for en in sorted_entry}

    if domain_color is None:
        domain_color = {
            en.name: (default_solid_phase_color if '(s)' in en.name
                      else default_cluster_phase_color
                      if en.name in cluster_domains else 'w')
            for en in sorted_entry}
    else:
        domain_color = {en.name: domain_color.get(en.name, "w")
                        for en in sorted_entry}

    if bold_domains is None:
        bold_domains = [en.name for en in sorted_entry if '(s)' not in en.name]
    elif bold_domains is False:
        bold_domains = []

    use_mac_font_files = (platform.system() == 'Darwin'
                          and font_family == "Times New Roman")

    def label_font(name):
        # Build the FontProperties for one domain label.
        size = domain_fontsize[name]
        bold = name in bold_domains
        if use_mac_font_files:
            # On Mac OS X the font is loaded from a hard-coded file path.
            # Only use it when the file exists; otherwise fall back to a
            # family-name lookup rather than pointing at a missing file.
            fname = ('/Library/Fonts/Times New Roman Bold.ttf' if bold
                     else '/Library/Fonts/Times New Roman.ttf')
            if os.path.exists(fname):
                return FontProperties(fname=fname, size=size)
        return FontProperties(family=font_family,
                              weight='bold' if bold else 'regular',
                              size=size)

    for entry in sorted_entry:
        xy = self.domain_vertices(entry)
        if add_h2o_stablity_line:
            c = self.get_distribution_corrected_center(stable[entry], h_line,
                                                       o_line, 0.3)
        else:
            c = self.get_distribution_corrected_center(stable[entry])
        patch = Polygon(xy, facecolor=domain_color[entry.name],
                        edgecolor="black", closed=True, lw=domain_edge_lw,
                        fill=fill_domain, antialiased=True)
        ax.add_patch(patch)
        if label_domains:
            plt.text(*c, s=self.print_name(entry),
                     fontproperties=label_font(entry.name),
                     horizontalalignment="center", verticalalignment="center",
                     multialignment="center", color=label_color)

    if add_h2o_stablity_line:
        dashes = (3, 1.5)
        line, = plt.plot(h_line[0], h_line[1], "k--", linewidth=h2o_lw,
                         antialiased=True)
        line.set_dashes(dashes)
        line, = plt.plot(o_line[0], o_line[1], "k--", linewidth=h2o_lw,
                         antialiased=True)
        line.set_dashes(dashes)
    if add_center_line:
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=h2o_lw,
                 antialiased=False)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=h2o_lw,
                 antialiased=False)

    plt.xlabel("pH", fontname=font_family, fontsize=18)
    plt.ylabel("E (V vs SHE)", fontname=font_family, fontsize=18)
    plt.xticks(fontname=font_family, fontsize=16)
    plt.yticks(fontname=font_family, fontsize=16)
    plt.title(title, fontsize=20, fontweight='bold', fontname=font_family)
    return plt
def get_pourbaix_plot_colorfill_by_element(self, limits=None, title="",
                                           label_domains=True, element=None):
    """
    Color domains by element.

    Stable entries are grouped by name. A MultiEntry is added to the group
    of every member whose composition contains ``element``; any other entry
    forms its own group. Groups whose formula has exactly one element
    besides H and O are solid-filled, others are hatched. Each group gets
    one label at the mean of its domain centers.

    Args:
        limits: Optional [[xlo, xhi], [ylo, yhi]] plot limits. Defaults to
            the analyzer's chempot_limits.
        title (str): Plot title.
        label_domains (bool): Whether to label each group.
        element: Element used to select which MultiEntry members are grouped.

    Returns:
        matplotlib.pyplot object.
    """
    from matplotlib.patches import Polygon
    from pymatgen import Composition, Element
    from pymatgen.core.ion import Ion

    entry_dict_of_multientries = collections.defaultdict(list)
    plt = pretty_plot(16)
    optim_colors = ['#0000FF', '#FF0000', '#00FF00', '#FFFF00', '#FF00FF',
                    '#FF8080', '#DCDCDC', '#800000', '#FF8000']
    optim_font_color = ['#FFFFA0', '#00FFFF', '#FF00FF', '#0000FF', '#00FF00',
                        '#007F7F', '#232323', '#7FFFFF', '#007FFF']
    hatch = ['/', '\\', '|', '-', '+', 'o', '*']
    (stable, unstable) = self.pourbaix_plot_data(limits)

    for entry in stable:
        if isinstance(entry, MultiEntry):
            for e in entry.entrylist:
                if element in e.composition.elements:
                    entry_dict_of_multientries[e.name].append(entry)
        else:
            entry_dict_of_multientries[entry.name].append(entry)

    if limits:
        xlim = limits[0]
        ylim = limits[1]
    else:
        xlim = self._analyzer.chempot_limits[0]
        ylim = self._analyzer.chempot_limits[1]

    h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                           [xlim[1], -xlim[1] * PREFAC]])
    o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                           [xlim[1], -xlim[1] * PREFAC + 1.23]])
    neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
    V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

    ax = plt.gca()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    def len_elts(entry):
        # Number of elements in the named phase other than H and O.
        if "(s)" in entry:
            comp = Composition(entry[:-3])
        else:
            comp = Ion.from_formula(entry)
        return len([el for el in comp.elements
                    if el not in [Element("H"), Element("O")]])

    # dict.keys() has no .sort() on Python 3; sorted() accepts any iterable.
    sorted_entry = sorted(entry_dict_of_multientries, key=len_elts)

    for i, entry in enumerate(sorted_entry):
        color_indx = 0
        x_coord = 0.0
        y_coord = 0.0
        npts = 0
        single_element = len_elts(entry) == 1
        for e in entry_dict_of_multientries[entry]:
            hc = 0
            fc = 0
            bc = 0
            xy = self.domain_vertices(e)
            c = self.get_center(stable[e])
            x_coord += c[0]
            y_coord += c[1]
            npts += 1
            if single_element:
                color_indx = i % len(optim_colors)
                patch = Polygon(xy, facecolor=optim_colors[color_indx],
                                closed=True, lw=3.0, fill=True)
                bc = optim_colors[color_indx]
            else:
                color_indx = i % len(hatch)
                patch = Polygon(xy, hatch=hatch[color_indx], closed=True,
                                lw=3.0, fill=False)
                hc = hatch[color_indx]
            ax.add_patch(patch)

        xy_center = (x_coord / npts, y_coord / npts)
        if label_domains:
            color_indx %= len(optim_colors)
            fc = optim_font_color[color_indx]
            if bc and not hc:
                bbox = dict(boxstyle="round", fc=fc)
            if hc and not bc:
                bc = 'k'
                fc = 'w'
                bbox = dict(boxstyle="round", hatch=hc, fill=False)
            # NOTE(review): because bc was just set to 'k' above, hatched
            # groups also reach this branch and end up with a white-filled
            # box; the fill=False box above is always overwritten. Confirm
            # which look is intended.
            if bc and hc:
                bbox = dict(boxstyle="round", hatch=hc, fc=fc)
            plt.annotate(latexify_ion(latexify(entry)), xy_center,
                         color=bc, fontsize=30, bbox=bbox)

    lw = 3
    plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
    plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
    plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
    plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

    plt.xlabel("pH")
    plt.ylabel("E (V)")
    plt.title(title, fontsize=20, fontweight='bold')
    return plt
def area_frac_vs_chempot_plot(self, cmap=cm.jet, at_intersections=False,
                              increments=10):
    """
    Plots the change in the area contribution of each facet as a function
    of chemical potential.

    Args:
        cmap (cm): A matplotlib colormap object, defaults to jet.
        at_intersections (bool): Whether to generate a Wulff shape for each
            intersection of surface energy for a specific facet (eg. at the
            point where a (111) stoichiometric surface energy plot intersects
            with the (111) nonstoichiometric plot) or to just generate two
            Wulff shapes, one at the min and max chemical potential.
        increments (int): Number of data points between min/max or between
            consecutive points of intersection. Defaults to 10 points.

    Returns:
        matplotlib.pyplot object.
    """
    # Choose unique colors for each facet
    f = [int(i) for i in np.linspace(0, 255, len(self.vasprun_dict.keys()))]

    # Get all points of min/max chempot and intersections
    chempot_intersections = list(self.chempot_range)
    for hkl in self.vasprun_dict.keys():
        chempot_intersections.extend([ints[0] for ints in
                                      self.get_intersections(hkl)])
    chempot_intersections = sorted(chempot_intersections)

    # Get all chempots
    if at_intersections:
        all_chempots = []
        for lo, hi in zip(chempot_intersections[:-1],
                          chempot_intersections[1:]):
            all_chempots.extend(np.linspace(lo, hi, increments))
    else:
        all_chempots = np.linspace(min(self.chempot_range),
                                   max(self.chempot_range), increments)

    # initialize a dictionary of lists of fractional areas for each hkl
    hkl_area_dict = {hkl: [] for hkl in self.vasprun_dict.keys()}

    # Get plot points for each Miller index
    for u in all_chempots:
        wulffshape = self.wulff_shape_from_chempot(u)
        for hkl, frac in wulffshape.area_fraction_dict.items():
            hkl_area_dict[hkl].append(frac)

    # Plot the area fraction vs chemical potential for each facet
    plt = pretty_plot()
    for i, hkl in enumerate(self.vasprun_dict.keys()):
        # Ignore any facets that never show up on the
        # Wulff shape regardless of chemical potential
        if all(a == 0 for a in hkl_area_dict[hkl]):
            continue
        plt.plot(all_chempots, hkl_area_dict[hkl], '--',
                 color=cmap(f[i]), label=str(hkl))

    # Make the figure look nice
    plt.ylim([0, 1])
    plt.xlim(self.chempot_range)
    plt.ylabel(r"Fractional area $A^{Wulff}_{hkl}/A^{Wulff}$")
    plt.xlabel(r"Chemical potential $\Delta\mu_{%s}$ (eV)" % (self.ref_element))
    plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)

    return plt
def get_plot(self, xlim=None, ylim=None):
    """
    Get a matplotlib plot showing the DOS.

    Args:
        xlim: Specifies the x-axis limits. Set to None for automatic
            determination.
        ylim: Specifies the y-axis limits. If not given, the y range is fit
            to the densities lying strictly inside the current x range.

    Returns:
        matplotlib.pyplot object.
    """
    import prettyplotlib as ppl
    from prettyplotlib import brewer2mpl
    ncolors = max(3, len(self._doses))
    ncolors = min(9, ncolors)
    colors = brewer2mpl.get_map('Set1', 'qualitative', ncolors).mpl_colors

    y = None
    alldensities = []
    allfrequencies = []
    plt = pretty_plot(12, 8)

    # Note that this complicated processing of frequencies is to allow for
    # stacked plots in matplotlib.
    for key, dos in self._doses.items():
        frequencies = dos['frequencies']
        densities = dos['densities']
        if y is None:
            y = np.zeros(frequencies.shape)
        if self.stack:
            y += densities
            newdens = y.copy()
        else:
            newdens = densities
        allfrequencies.append(frequencies)
        alldensities.append(newdens)

    keys = list(self._doses.keys())
    keys.reverse()
    alldensities.reverse()
    allfrequencies.reverse()
    allpts = []
    for i, (key, frequencies, densities) in enumerate(zip(keys, allfrequencies,
                                                          alldensities)):
        allpts.extend(list(zip(frequencies, densities)))
        if self.stack:
            plt.fill(frequencies, densities, color=colors[i % ncolors],
                     label=str(key))
        else:
            ppl.plot(frequencies, densities, color=colors[i % ncolors],
                     label=str(key), linewidth=3)

    if xlim:
        plt.xlim(xlim)
    if ylim:
        plt.ylim(ylim)
    else:
        xlim = plt.xlim()
        relevanty = [p[1] for p in allpts if xlim[0] < p[0] < xlim[1]]
        # min()/max() raise ValueError on an empty list (e.g. an xlim that
        # excludes every point); keep the automatic y limits in that case.
        if relevanty:
            plt.ylim((min(relevanty), max(relevanty)))

    ylim = plt.ylim()
    plt.plot([0, 0], ylim, 'k--', linewidth=2)

    plt.xlabel('Frequencies (THz)')
    plt.ylabel('Density of states')

    plt.legend()
    leg = plt.gca().get_legend()
    ltext = leg.get_texts()  # all the text.Text instance in the legend
    plt.setp(ltext, fontsize=30)
    plt.tight_layout()
    return plt
def chempot_vs_gamma_plot(self, cmap=cm.jet, show_unstable_points=False):
    """
    Plots the surface energy of all facets as a function of chemical
    potential. Each facet will be associated with its own distinct colors.
    Dashed lines will represent stoichiometries different from that of the
    mpid's compound.

    Args:
        cmap (cm): A matplotlib colormap object, defaults to jet.
        show_unstable_points (bool): Intended to restrict the plot to the
            most stable surface energy at each chemical potential. NOTE:
            currently not read by this method; every slab is plotted.

    Returns:
        matplotlib.pyplot object.
    """
    plt = pretty_plot()

    # Choose one color per slab (vasprun); repeated labels reuse a color.
    n_slabs = sum(len(vaspruns) for vaspruns in self.vasprun_dict.values())
    f = [int(i) for i in np.linspace(0, 255, n_slabs)]
    i, already_labelled, colors = 0, [], []
    for hkl, vaspruns in self.vasprun_dict.items():
        for vasprun in vaspruns:
            slab = vasprun.final_structure

            # Generate a label for the type of slab
            label = str(hkl)
            # use dashed lines for slabs that are not stoichiometric
            # wrt bulk. Label with formula if nonstoichiometric
            if slab.composition.reduced_composition != \
                    self.ucell_entry.composition.reduced_composition:
                mark = '--'
                label += " %s" % (slab.composition.reduced_composition)
            else:
                mark = '-'

            # TODO: label the chemical environment at the surface if it
            # differs from the bulk.

            if label in already_labelled:
                c = colors[already_labelled.index(label)]
                label = None
            else:
                already_labelled.append(label)
                c = cmap(f[i])
                colors.append(c)

            se_range = self.calculate_gamma(vasprun)
            plt.plot(self.chempot_range, se_range, mark, color=c, label=label)
            i += 1

    # Make the figure look nice
    axes = plt.gca()
    ylim = axes.get_ylim()
    plt.ylim(ylim)
    plt.xlim(self.chempot_range)
    # Surface energy is an energy per unit area, hence eV/Angstrom^2.
    plt.ylabel(r"Surface energy (eV/$\AA^{2}$)")
    plt.xlabel(r"Chemical potential $\Delta\mu_{%s}$ (eV)" % (self.ref_element))
    plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)

    return plt
def get_plot(self, xlim=None, ylim=None, units="thz"):
    """
    Get a matplotlib plot showing the DOS.

    Args:
        xlim: Specifies the x-axis limits. Set to None for automatic
            determination.
        ylim: Specifies the y-axis limits. If not given, the y range is fit
            to the densities lying strictly inside the current x range.
        units: units for the frequencies. Accepted values thz, ev, mev, ha,
            cm-1, cm^-1.

    Returns:
        matplotlib.pyplot object.
    """
    u = freq_units(units)

    ncolors = max(3, len(self._doses))
    ncolors = min(9, ncolors)

    import palettable

    colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors

    y = None
    alldensities = []
    allfrequencies = []
    plt = pretty_plot(12, 8)

    # Note that this complicated processing of frequencies is to allow for
    # stacked plots in matplotlib.
    for key, dos in self._doses.items():
        frequencies = dos['frequencies'] * u.factor
        densities = dos['densities']
        if y is None:
            y = np.zeros(frequencies.shape)
        if self.stack:
            y += densities
            newdens = y.copy()
        else:
            newdens = densities
        allfrequencies.append(frequencies)
        alldensities.append(newdens)

    keys = list(self._doses.keys())
    keys.reverse()
    alldensities.reverse()
    allfrequencies.reverse()
    allpts = []
    for i, (key, frequencies, densities) in enumerate(zip(keys, allfrequencies,
                                                          alldensities)):
        allpts.extend(list(zip(frequencies, densities)))
        if self.stack:
            plt.fill(frequencies, densities, color=colors[i % ncolors],
                     label=str(key))
        else:
            plt.plot(frequencies, densities, color=colors[i % ncolors],
                     label=str(key), linewidth=3)

    if xlim:
        plt.xlim(xlim)
    if ylim:
        plt.ylim(ylim)
    else:
        xlim = plt.xlim()
        relevanty = [p[1] for p in allpts if xlim[0] < p[0] < xlim[1]]
        # Guard against an x range containing no points: min()/max() of an
        # empty list raise ValueError. Keep the automatic y limits then.
        if relevanty:
            plt.ylim((min(relevanty), max(relevanty)))

    ylim = plt.ylim()
    plt.plot([0, 0], ylim, 'k--', linewidth=2)

    plt.xlabel(r'$\mathrm{{Frequencies\ ({})}}$'.format(u.label))
    plt.ylabel(r'$\mathrm{Density\ of\ states}$')

    plt.legend()
    leg = plt.gca().get_legend()
    ltext = leg.get_texts()  # all the text.Text instance in the legend
    plt.setp(ltext, fontsize=30)
    plt.tight_layout()
    return plt
def _get_2d_plot(self, label_stable=True, label_unstable=True,
                 ordering=None, energy_colormap=None, vmin_mev=-60.0,
                 vmax_mev=60.0, show_colorbar=True,
                 process_attributes=False):
    """
    Draw the phase diagram data from ``self.pd_plot_data`` on a 2D
    matplotlib plot and return the pyplot module.

    For 3-element diagrams the axes are hidden and an equal aspect is used;
    otherwise the x axis is "Fraction" and the y axis "Formation energy".
    matplotlib modules are imported inside the method so that loading this
    module does not require matplotlib.

    Args:
        label_stable (bool): Whether to annotate stable entries with their
            names.
        label_unstable (bool): Whether to annotate the plotted unstable
            entries (only used when ``self.show_unstable`` is truthy).
        ordering: If not None, passed to ``order_phase_diagram`` together
            with the plot data to reorder it.
        energy_colormap: None draws plain markers. 'default' uses a built-in
            green/red colormap; any other value is passed as ``cmap`` to
            ``ScalarMappable`` to color entries by energy.
        vmin_mev (float): Lower bound (meV) of the color normalization.
        vmax_mev (float): Upper bound (meV) of the color normalization.
        show_colorbar (bool): Whether to draw a colorbar (only when
            ``energy_colormap`` is not None).
        process_attributes (bool): If True, entries whose ``attribute`` is
            None or "existing" are drawn as circles and all others as stars;
            stable entries with attribute 'new' are labelled in green.

    Returns:
        matplotlib.pyplot object.
    """
    plt = pretty_plot(8, 6)
    from matplotlib.font_manager import FontProperties

    def _label_layout(coords, center):
        # Offset (in points, length 10) pointing away from the diagram
        # center, plus text alignments so the label sits on the outer side
        # of the point.
        vec = np.array(coords) - center
        norm = np.linalg.norm(vec)
        vec = vec / norm * 10 if norm != 0 else vec
        valign = "bottom" if vec[1] > 0 else "top"
        if vec[0] < -0.01:
            halign = "right"
        elif vec[0] > 0.01:
            halign = "left"
        else:
            halign = "center"
        return vec, halign, valign

    if ordering is None:
        (lines, labels, unstable) = self.pd_plot_data
    else:
        (_lines, _labels, _unstable) = self.pd_plot_data
        (lines, labels, unstable) = order_phase_diagram(
            _lines, _labels, _unstable, ordering)

    if energy_colormap is None:
        if process_attributes:
            for x, y in lines:
                plt.plot(x, y, "k-", linewidth=3, markeredgecolor="k")
            # One should think about a clever way to have "complex"
            # attributes with complex processing options but with a clear
            # logic. At this moment, I just use the attributes to know
            # whether an entry is a new compound or an existing (from the
            # ICSD or from the MP) one.
            for x, y in labels.keys():
                if labels[(x, y)].attribute is None or \
                        labels[(x, y)].attribute == "existing":
                    plt.plot(x, y, "ko", linewidth=3, markeredgecolor="k",
                             markerfacecolor="b", markersize=12)
                else:
                    plt.plot(x, y, "k*", linewidth=3, markeredgecolor="k",
                             markerfacecolor="g", markersize=18)
        else:
            for x, y in lines:
                plt.plot(x, y, "ko-", linewidth=3, markeredgecolor="k",
                         markerfacecolor="b", markersize=15)
    else:
        from matplotlib.colors import Normalize, LinearSegmentedColormap
        from matplotlib.cm import ScalarMappable
        pda = PDAnalyzer(self._pd)
        for x, y in lines:
            plt.plot(x, y, "k-", linewidth=3, markeredgecolor="k")
        vmin = vmin_mev / 1000.0
        vmax = vmax_mev / 1000.0
        if energy_colormap == 'default':
            # Sharp green/red transition at the position of 0 in [vmin, vmax].
            mid = - vmin / (vmax - vmin)
            cmap = LinearSegmentedColormap.from_list(
                'my_colormap', [(0.0, '#005500'), (mid, '#55FF55'),
                                (mid, '#FFAAAA'), (1.0, '#FF0000')])
        else:
            cmap = energy_colormap
        norm = Normalize(vmin=vmin, vmax=vmax)
        _map = ScalarMappable(norm=norm, cmap=cmap)
        _energies = [pda.get_equilibrium_reaction_energy(entry)
                     for coord, entry in labels.items()]
        # Non-negative values are clamped just below zero so every stable
        # entry maps to the "green" half of the colormap.
        energies = [en if en < 0.0 else -0.00000001 for en in _energies]
        vals_stable = _map.to_rgba(energies)
        # vals_stable is aligned with the iteration order of labels.
        if process_attributes:
            for ii, (x, y) in enumerate(labels.keys()):
                if labels[(x, y)].attribute is None or \
                        labels[(x, y)].attribute == "existing":
                    plt.plot(x, y, "o", markerfacecolor=vals_stable[ii],
                             markersize=12)
                else:
                    plt.plot(x, y, "*", markerfacecolor=vals_stable[ii],
                             markersize=18)
        else:
            for ii, (x, y) in enumerate(labels.keys()):
                plt.plot(x, y, "o", markerfacecolor=vals_stable[ii],
                         markersize=15)

    font = FontProperties()
    font.set_weight("bold")
    font.set_size(24)

    # Sets a nice layout depending on the type of PD. Also defines a
    # "center" for the PD, which then allows the annotations to be spread
    # out in a nice manner.
    if len(self._pd.elements) == 3:
        plt.axis("equal")
        plt.xlim((-0.1, 1.2))
        plt.ylim((-0.1, 1.0))
        plt.axis("off")
        center = (0.5, math.sqrt(3) / 6)
    else:
        all_coords = labels.keys()
        miny = min([c[1] for c in all_coords])
        ybuffer = max(abs(miny) * 0.1, 0.1)
        plt.xlim((-0.1, 1.1))
        plt.ylim((miny - ybuffer, ybuffer))
        center = (0.5, miny / 2)
        plt.xlabel("Fraction", fontsize=28, fontweight='bold')
        plt.ylabel("Formation energy (eV/fu)", fontsize=28,
                   fontweight='bold')

    for coords in sorted(labels.keys(), key=lambda x: -x[1]):
        entry = labels[coords]
        label = entry.name
        vec, halign, valign = _label_layout(coords, center)
        if label_stable:
            if process_attributes and entry.attribute == 'new':
                plt.annotate(latexify(label), coords, xytext=vec,
                             textcoords="offset points",
                             horizontalalignment=halign,
                             verticalalignment=valign,
                             fontproperties=font,
                             color='g')
            else:
                plt.annotate(latexify(label), coords, xytext=vec,
                             textcoords="offset points",
                             horizontalalignment=halign,
                             verticalalignment=valign,
                             fontproperties=font)

    if self.show_unstable:
        font = FontProperties()
        font.set_size(16)
        pda = PDAnalyzer(self._pd)
        energies_unstable = [pda.get_e_above_hull(entry)
                             for entry, coord in unstable.items()]
        if energy_colormap is not None:
            energies.extend(energies_unstable)
            vals_unstable = _map.to_rgba(energies_unstable)
        # vals_unstable covers every unstable entry, so index it by the
        # entry's position rather than by the number of entries drawn.
        for ii, ((entry, coords), ehull) in enumerate(
                zip(unstable.items(), energies_unstable)):
            if ehull < self.show_unstable:
                # Alignment must come from this entry's own offset, not from
                # whatever stable label was processed last.
                vec, halign, valign = _label_layout(coords, center)
                label = entry.name
                if energy_colormap is None:
                    plt.plot(coords[0], coords[1], "ks", linewidth=3,
                             markeredgecolor="k", markerfacecolor="r",
                             markersize=8)
                else:
                    plt.plot(coords[0], coords[1], "s", linewidth=3,
                             markeredgecolor="k",
                             markerfacecolor=vals_unstable[ii],
                             markersize=8)
                if label_unstable:
                    plt.annotate(latexify(label), coords, xytext=vec,
                                 textcoords="offset points",
                                 horizontalalignment=halign, color="b",
                                 verticalalignment=valign,
                                 fontproperties=font)

    if energy_colormap is not None and show_colorbar:
        # NOTE(review): the colormap is normalized on vmin_mev / 1000 and
        # vmax_mev / 1000, so the colorbar ticks are presumably in eV/atom
        # despite the meV label — confirm.
        _map.set_array(energies)
        cbar = plt.colorbar(_map)
        cbar.set_label(
            'Energy [meV/at] above hull (in red)\nInverse energy ['
            'meV/at] above hull (in green)',
            rotation=-90, ha='left', va='center')
    f = plt.gcf()
    f.set_size_inches((8, 6))
    plt.subplots_adjust(left=0.09, right=0.98, top=0.98, bottom=0.07)
    return plt
def get_pourbaix_plot_colorfill_by_domain_name(self, limits=None, title="",
                                               label_domains=True,
                                               label_color='k',
                                               domain_color=None,
                                               domain_fontsize=None,
                                               domain_edge_lw=0.5,
                                               bold_domains=None,
                                               cluster_domains=(),
                                               add_h2o_stablity_line=True,
                                               add_center_line=False,
                                               h2o_lw=0.5):
    """
    Plot the Pourbaix diagram with every domain filled by a per-domain color.

    Args:
        limits: 2D list containing limits of the Pourbaix diagram
            of the form [[xlo, xhi], [ylo, yhi]]. If None, the analyzer's
            chempot_limits are used.
        title (str): Title of the plot.
        label_domains (bool): Whether to add a text label for each domain.
        label_color (str): Color of the domain labels; defaults to black.
        domain_color (dict): Colors of the domains, e.g.
            {"Al(s)": "#FF1100"}. Domains missing from the dict (all domains
            if None) use the default scheme: '#b8f9e7' for solids ("(s)" in
            the name), '#d0fbef' for names in ``cluster_domains`` and white
            otherwise.
        domain_fontsize (int or dict): Font size of the domain labels,
            either a single size for every domain or a dict keyed by domain
            name. Domains without a size use 12.
        domain_edge_lw (float): Line width for the boundaries between
            domains.
        bold_domains (list): Domain names whose labels use bold text.
            Defaults to every domain whose name does not contain "(s)".
        cluster_domains (list): List of domain names in cluster phase.
        add_h2o_stablity_line (bool): Whether to plot the H2O stability
            lines (also passes them to the domain-center correction).
        add_center_line (bool): Whether to plot lines at pH = 7 and E = 0.
        h2o_lw (float): Line width for the H2O stability and center lines.

    Returns:
        matplotlib.pyplot object.
    """
    import numbers

    from matplotlib.patches import Polygon
    from pymatgen import Composition, Element
    from pymatgen.core.ion import Ion

    def len_elts(entry):
        # Number of elements other than H and O; used to sort domain names.
        comp = Composition(entry[:-3]) if "(s)" in entry \
            else Ion.from_formula(entry)
        return len(set(comp.elements) - {Element("H"), Element("O")})

    def special_lines(xlim, ylim):
        # H2/H2O and H2O/O2 lines (slope -PREFAC per pH unit, the latter
        # shifted by +1.23 V), the neutral pH line and the E = 0 line.
        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])
        return h_line, o_line, neutral_line, V0_line

    def domain_font(size, bold):
        # Times New Roman label font in the requested size and weight.
        if platform.system() == 'Darwin':
            # Have to hack to the hard coded font path to get current font
            # On Mac OS X
            fname = '/Library/Fonts/Times New Roman Bold.ttf' if bold \
                else '/Library/Fonts/Times New Roman.ttf'
            return FontProperties(fname=fname, size=size)
        return FontProperties(family='Times New Roman',
                              weight='bold' if bold else 'regular',
                              size=size)

    default_domain_font_size = 12
    # This is slightly darker than the MP scheme, to avoid making the
    # cluster phase too light.
    default_solid_phase_color = '#b8f9e7'
    default_cluster_phase_color = '#d0fbef'

    plt = pretty_plot(8, dpi=300)

    (stable, _unstable) = self.pourbaix_plot_data(limits)

    # Group stable entries by domain name; a MultiEntry is registered under
    # the name of every entry it contains.
    entry_dict_of_multientries = collections.defaultdict(list)
    for entry in stable:
        if isinstance(entry, MultiEntry):
            for e in entry.entrylist:
                entry_dict_of_multientries[e.name].append(entry)
        else:
            entry_dict_of_multientries[entry.name].append(entry)

    xlim, ylim = limits[:2] if limits else self._analyzer.chempot_limits[:2]
    h_line, o_line, neutral_line, V0_line = special_lines(xlim, ylim)
    ax = plt.gca()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax.tick_params(direction='out')
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    sorted_entry = sorted(entry_dict_of_multientries.keys(), key=len_elts)

    # Per-domain font sizes: defaults, overridden by a scalar or a
    # (possibly partial) dict supplied by the caller.
    fontsizes = {en: default_domain_font_size for en in sorted_entry}
    if isinstance(domain_fontsize, numbers.Number):
        fontsizes = {en: domain_fontsize for en in sorted_entry}
    elif domain_fontsize:
        fontsizes.update(domain_fontsize)

    # Per-domain colors: defaults, overridden by a (possibly partial) dict.
    colors = {en: default_solid_phase_color if '(s)' in en
              else (default_cluster_phase_color if en in cluster_domains
                    else 'w')
              for en in sorted_entry}
    if domain_color:
        colors.update(domain_color)

    if bold_domains is None:
        bold_domains = [en for en in sorted_entry if '(s)' not in en]

    for entry in sorted_entry:
        x_coord, y_coord, npts = 0.0, 0.0, 0
        for e in entry_dict_of_multientries[entry]:
            xy = self.domain_vertices(e)
            if add_h2o_stablity_line:
                c = self.get_distribution_corrected_center(stable[e], h_line,
                                                           o_line, 0.3)
            else:
                c = self.get_distribution_corrected_center(stable[e])
            x_coord += c[0]
            y_coord += c[1]
            npts += 1
            patch = Polygon(xy, facecolor=colors[entry], closed=True,
                            lw=domain_edge_lw, fill=True, antialiased=True)
            ax.add_patch(patch)
        # Label position: mean of the corrected centers of all patches.
        xy_center = (x_coord / npts, y_coord / npts)
        if label_domains:
            font = domain_font(fontsizes[entry], entry in bold_domains)
            plt.text(*xy_center, s=latexify_ion(latexify(entry)),
                     fontproperties=font,
                     horizontalalignment="center", verticalalignment="center",
                     multialignment="center", color=label_color)

    if add_h2o_stablity_line:
        dashes = (3, 1.5)
        line, = plt.plot(h_line[0], h_line[1], "k--", linewidth=h2o_lw,
                         antialiased=True)
        line.set_dashes(dashes)
        line, = plt.plot(o_line[0], o_line[1], "k--", linewidth=h2o_lw,
                         antialiased=True)
        line.set_dashes(dashes)
    if add_center_line:
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=h2o_lw,
                 antialiased=False)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=h2o_lw,
                 antialiased=False)

    plt.xlabel("pH", fontname="Times New Roman", fontsize=18)
    plt.ylabel("E (V)", fontname="Times New Roman", fontsize=18)
    plt.xticks(fontname="Times New Roman", fontsize=16)
    plt.yticks(fontname="Times New Roman", fontsize=16)
    plt.title(title, fontsize=20, fontweight='bold',
              fontname="Times New Roman")
    return plt
def get_chempot_range_map_plot(self, elements, referenced=True):
    """
    Returns a plot of the chemical potential range map. Currently works
    only for 3-component PDs.

    Args:
        elements: Sequence of elements to be considered as independent
            variables. E.g., if you want to show the stability ranges of
            all Li-Co-O phases wrt to uLi and uO, you will supply
            [Element("Li"), Element("O")]
        referenced: if True, gives the results with a reference being the
            energy of the elemental phase. If False, gives absolute values
            (the axis labels are adjusted accordingly).

    Returns:
        A matplotlib plot object.
    """
    plt = pretty_plot(12, 8)
    analyzer = PDAnalyzer(self._pd)
    chempot_ranges = analyzer.get_chempot_range_map(
        elements, referenced=referenced)
    missing_lines = {}
    excluded_region = []
    for entry, lines in chempot_ranges.items():
        comp = entry.composition
        center_x = 0
        center_y = 0
        coords = []
        contain_zero = any(comp.get_atomic_fraction(el) == 0
                           for el in elements)
        is_boundary = (not contain_zero) and \
            sum(comp.get_atomic_fraction(el) for el in elements) == 1
        for line in lines:
            (x, y) = line.coords.transpose()
            plt.plot(x, y, "k-")

            # Accumulate the unique vertices to place the label at their mean.
            for coord in line.coords:
                if not in_coord_list(coords, coord):
                    coords.append(coord.tolist())
                    center_x += coord[0]
                    center_y += coord[1]
            if is_boundary:
                excluded_region.extend(line.coords)

        if not coords:
            # No vertices: nothing to label (and the mean would divide by 0).
            continue
        if contain_zero:
            # Domains touching the axes are completed and labelled below.
            missing_lines[entry] = coords
        else:
            xy = (center_x / len(coords), center_y / len(coords))
            plt.annotate(latexify(entry.name), xy, fontsize=22)

    ax = plt.gca()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    # Shade the forbidden chemical potential regions.
    excluded_region.append([xlim[1], ylim[1]])
    excluded_region = sorted(excluded_region, key=lambda c: c[0])
    (x, y) = np.transpose(excluded_region)
    plt.fill(x, y, "0.80")

    # The hull does not generate the missing horizontal and vertical lines.
    # The following code fixes this.
    el0 = elements[0]
    el1 = elements[1]
    for entry, coords in missing_lines.items():
        center_x = sum(c[0] for c in coords)
        center_y = sum(c[1] for c in coords)
        comp = entry.composition
        is_x = comp.get_atomic_fraction(el0) < 0.01
        is_y = comp.get_atomic_fraction(el1) < 0.01
        n = len(coords)
        if not (is_x and is_y):
            if is_x:
                # Extend the lowest and highest vertices horizontally to the
                # left edge of the plot.
                coords = sorted(coords, key=lambda c: c[1])
                for i in [0, -1]:
                    x = [min(xlim), coords[i][0]]
                    y = [coords[i][1], coords[i][1]]
                    plt.plot(x, y, "k")
                    center_x += min(xlim)
                    center_y += coords[i][1]
            elif is_y:
                # Extend the left- and right-most vertices vertically down
                # to the bottom edge of the plot.
                coords = sorted(coords, key=lambda c: c[0])
                for i in [0, -1]:
                    x = [coords[i][0], coords[i][0]]
                    y = [coords[i][1], min(ylim)]
                    plt.plot(x, y, "k")
                    center_x += coords[i][0]
                    center_y += min(ylim)
            xy = (center_x / (n + 2), center_y / (n + 2))
        else:
            center_x = sum(coord[0] for coord in coords) + xlim[0]
            center_y = sum(coord[1] for coord in coords) + ylim[0]
            xy = (center_x / (n + 1), center_y / (n + 1))

        plt.annotate(latexify(entry.name), xy,
                     horizontalalignment="center",
                     verticalalignment="center", fontsize=22)

    # Referenced potentials are relative to the elemental phase; absolute
    # ones are not, so the "- mu^0" term is only shown when referenced.
    if referenced:
        label_fmt = "$\\mu_{{{0}}} - \\mu_{{{0}}}^0$ (eV)"
    else:
        label_fmt = "$\\mu_{{{0}}}$ (eV)"
    plt.xlabel(label_fmt.format(el0.symbol))
    plt.ylabel(label_fmt.format(el1.symbol))
    plt.tight_layout()
    return plt
def get_pourbaix_plot(self, limits=None, title="", label_domains=True):
    """
    Plot Pourbaix diagram.

    Args:
        limits: 2D list containing limits of the Pourbaix diagram
            of the form [[xlo, xhi], [ylo, yhi]]. If None, the analyzer's
            chempot_limits are used.
        title (str): Title of the plot.
        label_domains (bool): Whether to annotate each stable domain whose
            vertex mean lies strictly inside the plot limits.

    Returns:
        plt: matplotlib plot object
    """
    plt = pretty_plot(16)

    (stable, _unstable) = self.pourbaix_plot_data(limits)
    if limits:
        xlim = limits[0]
        ylim = limits[1]
    else:
        xlim = self._analyzer.chempot_limits[0]
        ylim = self._analyzer.chempot_limits[1]

    # H2/H2O and H2O/O2 lines (slope -PREFAC per pH unit, the latter
    # shifted by +1.23 V), the neutral pH line and the E = 0 line.
    h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                           [xlim[1], -xlim[1] * PREFAC]])
    o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                           [xlim[1], -xlim[1] * PREFAC + 1.23]])
    neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
    V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

    ax = plt.gca()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    lw = 3
    plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
    plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
    plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
    plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

    for entry, lines in stable.items():
        center_x = 0.0
        center_y = 0.0
        coords = []
        for line in lines:
            (x, y) = line
            plt.plot(x, y, "k-", linewidth=lw)
            # Accumulate the unique vertices of the domain boundary.
            for coord in np.array(line).T:
                if not in_coord_list(coords, coord):
                    coords.append(coord.tolist())
                    center_x += coord[0]
                    center_y += coord[1]
        if not coords:
            # A domain without vertices has no meaningful label position;
            # previously it was labelled at the origin (0, 0).
            continue
        center_x /= len(coords)
        center_y /= len(coords)
        # Skip labels whose position is on or outside the plot boundary.
        if (center_x <= xlim[0] or center_x >= xlim[1]
                or center_y <= ylim[0] or center_y >= ylim[1]):
            continue
        if label_domains:
            plt.annotate(self.print_name(entry), (center_x, center_y),
                         fontsize=20, color="b")

    plt.xlabel("pH")
    plt.ylabel("E (V)")
    plt.title(title, fontsize=20, fontweight='bold')
    return plt