def generate_entry_label(entry):
    """
    Build the LaTeX label shown for an entry on the Pourbaix plot.

    Args:
        entry (PourbaixEntry or MultiEntry): entry to label.

    Returns:
        str: the name passed through ``latexify`` and then ``latexify_ion``;
        for a MultiEntry, the formatted names of its members joined by " + ".
    """
    def _fmt(name):
        return latexify_ion(latexify(name))

    if not isinstance(entry, MultiEntry):
        return _fmt(entry.name)
    return " + ".join(_fmt(member.name) for member in entry.entry_list)
def _get_xaxis_title(self, latex: bool = True) -> str:
    """Return the x-axis title describing the mixing ratio of c1 and c2.

    Args:
        latex: when truthy, format with LaTeX math markup (``latexify``);
            otherwise use HTML markup (``htmlify``).

    Returns:
        "x in x<c1> + (1-x)<c2>" rendered in the chosen markup.
    """
    formatter = latexify if latex else htmlify
    first = formatter(self.c1.reduced_formula)
    second = formatter(self.c2.reduced_formula)
    if latex:
        return f"$x$ in $x${first} + $(1-x)${second}"
    return f"<i>x</i> in <i>x</i>{first} + (1-<i>x</i>){second}"
def _get_plot(self, label_stable=True, label_unstable=False):
    """
    Plot the convex hull of the Pourbaix diagram entries in 3D.

    Args:
        label_stable (bool): number each stable entry on the plot and list
            "number : name" in a text block at the lower-left of the figure.
        label_unstable (bool): additionally draw the unstable entries as
            green markers and number/list them the same way.

    Returns:
        The ``matplotlib.pyplot`` module with the figure drawn on it.
    """
    import matplotlib.pyplot as plt
    # Importing mplot3d registers the "3d" projection on older matplotlib.
    import mpl_toolkits.mplot3d.axes3d  # noqa: F401

    fig = plt.figure()
    # ``Axes3D(fig)`` stopped adding itself to the figure in recent matplotlib
    # (auto_add_to_figure deprecated in 3.4, removed in 3.6), which left an
    # empty figure; add_subplot works on old and new versions alike.
    ax = fig.add_subplot(111, projection="3d")

    lines, labels, unstable = self.pourbaix_hull_plot_data
    count = 1
    newlabels = []
    for x, y, z in lines:
        ax.plot(x, y, z, "bo-", linewidth=3, markeredgecolor="b",
                markerfacecolor="r", markersize=10)

    for coords in sorted(labels.keys()):
        label = self.print_name(labels[coords])
        if label_stable:
            ax.text(coords[0], coords[1], coords[2], str(count))
            newlabels.append("{} : {}".format(
                count, latexify_ion(latexify(label))))
            count += 1

    if label_unstable:
        for entry, coords in unstable.items():
            label = self.print_name(entry)
            ax.plot([coords[0], coords[0]], [coords[1], coords[1]],
                    [coords[2], coords[2]], "bo", markerfacecolor="g",
                    markersize=10)
            ax.text(coords[0], coords[1], coords[2], str(count))
            newlabels.append("{} : {}".format(
                count, latexify_ion(latexify(label))))
            count += 1

    plt.figtext(0.01, 0.01, "\n".join(newlabels))
    plt.xlabel("pH")
    plt.ylabel("V")
    return plt
def print_name(self, entry):
    """
    Return a LaTeX-formatted display name for ``entry``.

    A single entry yields its latexified name. A MultiEntry with at most two
    members yields the members' latexified names joined by " + "; one with
    more members yields, as a string, its index in ``self._pd.qhull_entries``.
    """
    if not isinstance(entry, MultiEntry):
        return latexify_ion(latexify(entry.name))
    members = entry.entrylist
    if len(members) > 2:
        return str(self._pd.qhull_entries.index(entry))
    return " + ".join(latexify_ion(latexify(member.name)) for member in members)
def _get_3d_plot(self, label_stable=True):
    """
    Draw the 3D phase-diagram plot with matplotlib.

    matplotlib is imported inside the method because it is an expensive
    library to load and is not installed on every machine.

    Args:
        label_stable (bool): label the stable entries. Single-element entries
            get their name at their coordinates; other entries get a number,
            and "number : name" is listed in a text block on the figure.

    Returns:
        The ``matplotlib.pyplot`` module with the figure drawn on it.
    """
    import matplotlib.pyplot as plt
    # Importing mplot3d registers the "3d" projection on older matplotlib.
    import mpl_toolkits.mplot3d.axes3d  # noqa: F401

    fig = plt.figure()
    # ``Axes3D(fig)`` no longer attaches itself to the figure on recent
    # matplotlib (removed auto_add_to_figure), producing an empty plot.
    ax = fig.add_subplot(111, projection="3d")

    lines, labels, _ = self.pd_plot_data
    count = 1
    newlabels = []
    for x, y, z in lines:
        ax.plot(x, y, z, "bo-", linewidth=3, markeredgecolor="b",
                markerfacecolor="r", markersize=10)
    for coords in sorted(labels.keys()):
        entry = labels[coords]
        label = entry.name
        if label_stable:
            if len(entry.composition.elements) == 1:
                ax.text(coords[0], coords[1], coords[2], label)
            else:
                ax.text(coords[0], coords[1], coords[2], str(count))
                newlabels.append("{} : {}".format(count, latexify(label)))
                count += 1
    plt.figtext(0.01, 0.01, "\n".join(newlabels))
    ax.axis("off")
    return plt
def plot_planes(self):
    """
    Plot each entry's free-energy facet as a surface over (pH, V) and show it.

    For each entry, the plane with normal (-PREFAC * npH, -nPhi, 1) and
    offset g0 is evaluated on a 100x100 grid with pH in [-10, 28] and V in
    [-3, 3]. All entries are drawn when ``self.show_unstable`` is truthy,
    otherwise only the stable ones. The figure is displayed via ``plt.show()``.
    """
    if self.show_unstable:
        entries = self._pd._all_entries
    else:
        entries = self._pd.stable_entries
    num_plots = len(entries)

    import matplotlib.pyplot as plt
    # Importing mplot3d registers the "3d" projection on older matplotlib.
    import mpl_toolkits.mplot3d  # noqa: F401

    colormap = plt.cm.gist_ncar
    # Figure.gca() no longer accepts keyword arguments (matplotlib >= 3.6);
    # add_subplot(projection=...) works on old and new versions.
    ax = plt.figure().add_subplot(111, projection="3d")
    color_array = [colormap(i) for i in np.linspace(0, 0.9, num_plots)]
    # The evaluation grid is the same for every entry: build it once.
    pH, V = np.meshgrid(np.linspace(-10, 28, 100), np.linspace(-3, 3, 100))

    labels = []
    for color, entry in zip(color_array, entries):
        normal = np.array([-PREFAC * entry.npH, -entry.nPhi, +1])
        d = entry.g0
        g = (-normal[0] * pH - normal[1] * V + d) / normal[2]
        lbl = latexify_ion(
            latexify(entry._entry.composition.reduced_formula))
        labels.append(lbl)
        ax.plot_surface(pH, V, g, color=color, label=lbl)

    plt.legend(labels)
    plt.xlabel("pH")
    plt.ylabel("E (V)")
    plt.show()
def _plot(self):
    """Draw every composition's simplex and label it at its centre.

    For each (composition, position) pair in ``self.info.comp_centers``, the
    simplex is drawn via ``draw_simplex``. The latexified reduced formula is
    then placed at ``position`` in the colour given by ``_text_color``, using
    ``self._mpl_defaults.label`` and ``self._text_kwargs`` as text options.
    """
    centers = self.info.comp_centers
    for comp, pos in centers.items():
        self.draw_simplex(comp)
        text = latexify(comp.reduced_formula)
        self._ax.text(
            *pos,
            text,
            color=self._text_color(comp),
            **self._mpl_defaults.label,
            **self._text_kwargs,
        )
def generate_entry_label(entry):
    """
    Generates a label for the Pourbaix plotter.

    Every name is passed through ``latexify`` (stoichiometric subscripts) and
    then ``latexify_ion`` (ionic charges). This applies both to a single entry
    and to each member of a MultiEntry, so multi-phase domains are formatted
    like single-phase ones.

    Args:
        entry (PourbaixEntry or MultiEntry): entry to get a label for

    Returns:
        str: the formatted label; MultiEntry members are joined by " + ".
    """
    if isinstance(entry, MultiEntry):
        # latexify was previously skipped here, leaving the subscripts of
        # multi-entry members unformatted.
        return " + ".join(latexify_ion(latexify(e.name))
                          for e in entry.entry_list)
    return latexify_ion(latexify(entry.name))
def get_all_component_descriptions(self) -> str:
    """Describe every component in the structure.

    With exactly one component, the result is
    ``get_component_description(index, single_component=True)`` for that
    component. Otherwise each component of each component group is described
    in turn. Groups with a known ``molecule_name`` are skipped. Each
    description is prefixed "In the/each/<n> of the <formula> <shape(s)>, "
    with the formula formatted according to ``self.fmt``.

    Returns:
        The component descriptions joined by single spaces.
    """
    if len(self._da.components) == 1:
        only = self._da.get_component_groups()[0].components[0]
        return self.get_component_description(only.index, single_component=True)

    pieces = []
    for group in self._da.get_component_groups():
        for component in group.components:
            # don't describe known molecules
            if group.molecule_name:
                continue

            shape = dimensionality_to_shape[group.dimensionality]

            formula = group.formula
            if self.fmt == "latex":
                formula = latexify(formula)
            elif self.fmt == "unicode":
                formula = unicodeify(formula)
            elif self.fmt == "html":
                formula = htmlify(formula)

            if group.count != component.count:
                filler = "{} of the".format(en.number_to_words(component.count))
                shape = en.plural(shape)
            elif group.count == 1:
                filler = "the"
            else:
                filler = "each"

            text = f"In {filler} {formula} {shape}, "
            text += self.get_component_description(component.index)
            pieces.append(text)
    return " ".join(pieces)
def _get_matplotlib_figure(self) -> plt.Figure:
    """Build a matplotlib figure of the reaction-kinks diagram.

    The kinks are drawn as a navy marker line and ``self.minimum`` as a red
    star. Each kink is annotated with the latexified formulas of the products
    whose coefficient is not close to zero. The figure is closed before being
    returned so pyplot does not display it.

    Returns:
        The populated ``plt.Figure``.
    """
    pretty_plot(8, 5)
    plt.xlim([-0.05, 1.05])  # plot boundary is 5% wider on each side

    _, xs, energies, rxns, _ = list(zip(*self.get_kinks()))  # type: ignore

    plt.plot(xs, energies, "o-", markersize=8, c="navy", zorder=1)
    plt.scatter(self.minimum[0], self.minimum[1], marker="*", c="red", s=400, zorder=2)

    for x_coord, y_coord, rxn in zip(xs, energies, rxns):
        kept = [
            p
            for p in rxn.products  # type: ignore
            if not np.isclose(rxn.get_coeff(p), 0)  # type: ignore
        ]
        label = ", ".join(latexify(p.reduced_formula) for p in kept)
        plt.annotate(
            label,
            xy=(x_coord, y_coord),
            xytext=(10, -30),
            textcoords="offset points",
            ha="right",
            va="bottom",
            arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=0"),
        )

    plt.ylabel("Energy (eV/atom)" if self.norm else "Energy (eV/f.u.)")
    plt.xlabel(self._get_xaxis_title())
    plt.ylim(self.minimum[1] + 0.05 * self.minimum[1])  # plot boundary is 5% lower

    figure = plt.gcf()
    plt.close(figure)
    return figure
def make_defect_formation_energy(args):
    """Compute defect formation energies, then print or plot them.

    Args:
        args: parsed command-line namespace. Reads ``perfect_calc_results``,
            ``chem_pot_diag`` (YAML path), ``label``, ``dirs``,
            ``skip_shallow``, ``print``, ``unitcell`` and ``y_range``.

    With ``skip_shallow``, directories whose ``band_edge_states.json`` reports
    ``is_shallow`` are ignored. With ``print``, the energies and their cross
    points are printed and nothing is plotted. Otherwise the plot is saved to
    ``energy_<label>.pdf``.
    """
    perfect = args.perfect_calc_results
    title = latexify(perfect.structure.composition.reduced_formula)
    cpd = ChemPotDiag.from_yaml(args.chem_pot_diag)
    abs_chem_pot = cpd.abs_chem_pot_dict(args.label)

    def _is_skipped(directory):
        return args.skip_shallow and loadfn(
            directory / "band_edge_states.json").is_shallow

    def _single_energy(directory):
        return make_single_defect_energy(
            perfect,
            loadfn(directory / "calc_results.json"),
            loadfn(directory / "defect_entry.json"),
            abs_chem_pot,
            loadfn(directory / "correction.json"))

    single_energies = [_single_energy(d) for d in args.dirs
                       if not _is_skipped(d)]
    defect_energies = make_defect_energies(single_energies)

    if args.print:
        print(" charge E_f correction ")
        for energy in defect_energies:
            print(energy)
            print("")
        print("-- cross points -- ")
        for energy in defect_energies:
            print(energy.name)
            print(energy.cross_points(args.unitcell.vbm, args.unitcell.cbm))
            print("")
        return

    plotter = DefectEnergyMplPlotter(
        title=title,
        defect_energies=defect_energies,
        vbm=args.unitcell.vbm,
        cbm=args.unitcell.cbm,
        supercell_vbm=perfect.vbm,
        supercell_cbm=perfect.cbm,
        y_range=args.y_range)
    plotter.construct_plot()
    plotter.plt.savefig(f"energy_{args.label}.pdf")
def defect_energy_plotters():
    """Build a matplotlib and a plotly defect-energy plotter from shared data.

    Three sample defects (Va_O1, Va_Mg1, Mg_i1) are used, titled MgAl2O4,
    with vbm=1.5, cbm=5.5 and supercell band edges 1.0 / 6.0.

    Returns:
        tuple: (DefectEnergyMplPlotter, DefectEnergyPlotlyPlotter), both
        constructed from the same keyword arguments.
    """
    defects = [
        DefectEnergy(name="Va_O1", charges=[0, 1, 2], energies=[5, 2, -5],
                     corrections=[1, 1, 1]),
        DefectEnergy(name="Va_Mg1", charges=[-2, -1, 0], energies=[5, 2, 0],
                     corrections=[-1, -1, -1]),
        DefectEnergy(name="Mg_i1", charges=[1], energies=[4], corrections=[1]),
    ]
    kwargs = {
        "title": latexify("MgAl2O4"),
        "defect_energies": defects,
        "vbm": 1.5,
        "cbm": 5.5,
        "supercell_vbm": 1.0,
        "supercell_cbm": 6.0,
    }
    return DefectEnergyMplPlotter(**kwargs), DefectEnergyPlotlyPlotter(**kwargs)
def get_mineral_description(self) -> str:
    """Gets the mineral name and space group description.

    The formula and space-group symbol are formatted according to
    ``self.fmt`` ("latex", "unicode" or "html"; any other value leaves them
    unchanged). If ``get_mineral_name(self._da.mineral)`` returns a name, the
    description says the compound is "<mineral> structured". That helper is
    expected to apply these rules:

    - A perfect match for a known prototype (i.e. a distance parameter of
      -1) gives the prototype name.
    - A similar but imperfect match appends "-like".
    - A good match with a different number of element types appends
      "-derived".

    Returns:
        The description of the mineral name and space group.
    """
    da = self._da
    spg_symbol, formula = da.spg_symbol, da.formula
    if self.fmt == "latex":
        spg_symbol, formula = latexify_spacegroup(da.spg_symbol), latexify(formula)
    elif self.fmt == "unicode":
        spg_symbol, formula = unicodeify_spacegroup(da.spg_symbol), unicodeify(formula)
    elif self.fmt == "html":
        spg_symbol, formula = htmlify_spacegroup(da.spg_symbol), htmlify(formula)

    mineral_name = get_mineral_name(da.mineral)
    if mineral_name:
        prefix = f"{formula} is {mineral_name} structured and"
    else:
        prefix = f"{formula}"
    return prefix + " crystallizes in the {} {} space group.".format(
        da.crystal_system, spg_symbol
    )
def get_pourbaix_plot_colorfill_by_domain_name(self, limits=None, title="",
                                               label_domains=True, label_color='k',
                                               domain_color=None, domain_fontsize=None,
                                               domain_edge_lw=0.5, bold_domains=None,
                                               cluster_domains=(),
                                               add_h2o_stablity_line=True,
                                               add_center_line=False, h2o_lw=0.5):
    """
    Color domains by the colors specified in the domain_color dict.

    Args:
        limits: 2D list containing limits of the Pourbaix diagram
            of the form [[xlo, xhi], [ylo, yhi]]
        title (str): title of the plot.
        label_domains (bool): whether to add a text label for each domain.
        label_color (str): color of the domain labels, defaults to black.
        domain_color (dict): colors of each domain, e.g. {"Al(s)": "#FF1100"}.
            Domains not listed (all domains if None) get the default
            colors: '#b8f9e7' for solids ("(s)" in the name), '#d0fbef'
            for names in ``cluster_domains``, white otherwise.
        domain_fontsize (int or dict): font size of the domain labels, either
            a single size for all domains or a {domain name: size} dict.
            Domains not listed (all domains if None) use size 12.
        domain_edge_lw (float): line width of the boundaries between domains.
        bold_domains (list): domain names whose labels use bold text. If
            None, every domain without "(s)" in its name is bold.
        cluster_domains (list): domain names that are cluster phases.
        add_h2o_stablity_line (bool): whether to plot the H2O stability lines.
        add_center_line (bool): whether to plot lines at pH 7 and E = 0 V.
        h2o_lw (float): line width of the H2O stability and center lines.

    Returns:
        The ``matplotlib.pyplot`` module with the diagram drawn on it.
    """
    from matplotlib.patches import Polygon
    from pymatgen import Composition, Element
    from pymatgen.core.ion import Ion

    # helper functions
    def len_elts(entry):
        # Number of elements other than H and O in a domain name; solid
        # names end in "(s)", which is stripped before parsing.
        comp = Composition(entry[:-3]) if "(s)" in entry else Ion.from_formula(entry)
        return len(set(comp.elements) - {Element("H"), Element("O")})

    def special_lines(xlim, ylim):
        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])
        return h_line, o_line, neutral_line, V0_line

    default_domain_font_size = 12
    # slightly darker than the MP scheme, to avoid making the cluster
    # phase too light
    default_solid_phase_color = '#b8f9e7'
    default_cluster_phase_color = '#d0fbef'

    plt = pretty_plot(8, dpi=300)
    stable, _ = self.pourbaix_plot_data(limits)

    # Group stable entries by domain name; a MultiEntry is listed under the
    # name of each of its members.
    entry_dict_of_multientries = collections.defaultdict(list)
    for entry in stable:
        if isinstance(entry, MultiEntry):
            for e in entry.entrylist:
                entry_dict_of_multientries[e.name].append(entry)
        else:
            entry_dict_of_multientries[entry.name].append(entry)

    xlim, ylim = limits[:2] if limits else self._analyzer.chempot_limits[:2]
    h_line, o_line, neutral_line, V0_line = special_lines(xlim, ylim)
    ax = plt.gca()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax.tick_params(direction='out')
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    sorted_entry = sorted(entry_dict_of_multientries, key=len_elts)

    # Font sizes: defaults, overridden by a per-domain dict, or a single
    # size for every domain (as documented). Previously a plain int crashed
    # because it was indexed like a dict.
    fontsizes = {en: default_domain_font_size for en in sorted_entry}
    if isinstance(domain_fontsize, dict):
        fontsizes.update(domain_fontsize)
    elif domain_fontsize is not None:
        fontsizes = dict.fromkeys(sorted_entry, domain_fontsize)

    # Colors: defaults, overridden by any user-supplied entries, so a
    # partial dict no longer raises KeyError.
    colors = {en: default_solid_phase_color if '(s)' in en else
              (default_cluster_phase_color if en in cluster_domains else 'w')
              for en in sorted_entry}
    if domain_color is not None:
        colors.update(domain_color)

    if bold_domains is None:
        bold_domains = [en for en in sorted_entry if '(s)' not in en]

    for entry in sorted_entry:
        x_coord, y_coord, npts = 0.0, 0.0, 0
        for e in entry_dict_of_multientries[entry]:
            xy = self.domain_vertices(e)
            if add_h2o_stablity_line:
                c = self.get_distribution_corrected_center(stable[e], h_line, o_line, 0.3)
            else:
                c = self.get_distribution_corrected_center(stable[e])
            x_coord += c[0]
            y_coord += c[1]
            npts += 1
            patch = Polygon(xy, facecolor=colors[entry], closed=True,
                            lw=domain_edge_lw, fill=True, antialiased=True)
            ax.add_patch(patch)
        xy_center = (x_coord / npts, y_coord / npts)
        if label_domains:
            is_bold = entry in bold_domains
            if platform.system() == 'Darwin':
                # Have to hack to the hard coded font path to get current font on Mac OS X
                if is_bold:
                    fname = '/Library/Fonts/Times New Roman Bold.ttf'
                else:
                    fname = '/Library/Fonts/Times New Roman.ttf'
                font = FontProperties(fname=fname, size=fontsizes[entry])
            else:
                font = FontProperties(family='Times New Roman',
                                      weight='bold' if is_bold else 'regular',
                                      size=fontsizes[entry])
            plt.text(*xy_center, s=latexify_ion(latexify(entry)), fontproperties=font,
                     horizontalalignment="center", verticalalignment="center",
                     multialignment="center", color=label_color)

    if add_h2o_stablity_line:
        dashes = (3, 1.5)
        line, = plt.plot(h_line[0], h_line[1], "k--", linewidth=h2o_lw, antialiased=True)
        line.set_dashes(dashes)
        line, = plt.plot(o_line[0], o_line[1], "k--", linewidth=h2o_lw, antialiased=True)
        line.set_dashes(dashes)
    if add_center_line:
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=h2o_lw, antialiased=False)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=h2o_lw, antialiased=False)

    plt.xlabel("pH", fontname="Times New Roman", fontsize=18)
    plt.ylabel("E (V)", fontname="Times New Roman", fontsize=18)
    plt.xticks(fontname="Times New Roman", fontsize=16)
    plt.yticks(fontname="Times New Roman", fontsize=16)
    plt.title(title, fontsize=20, fontweight='bold', fontname="Times New Roman")
    return plt
def _get_2d_plot(self, label_stable=True, label_unstable=True, ordering=None,
                 energy_colormap=None, vmin_mev=-60.0, vmax_mev=60.0,
                 show_colorbar=True, process_attributes=False):
    """
    Draw the 2D phase diagram with matplotlib.

    matplotlib is imported inside the method because it is an expensive
    library to load and is not installed on every machine.

    Args:
        label_stable (bool): annotate stable entries with their names.
        label_unstable (bool): annotate the unstable entries that are drawn.
        ordering: optional ordering passed to ``order_phase_diagram`` to
            rearrange the plot data.
        energy_colormap: None for plain markers; 'default' for the built-in
            green/red map; otherwise a matplotlib colormap used to colour
            entries by equilibrium reaction energy / energy above hull.
        vmin_mev, vmax_mev (float): colour-scale limits in meV/atom.
        show_colorbar (bool): draw a colorbar when a colormap is used.
        process_attributes (bool): draw entries whose ``attribute`` is
            neither None nor "existing" with star markers. Labels of entries
            whose attribute is 'new' are drawn in green.

    Returns:
        The ``matplotlib.pyplot`` module with the figure drawn on it.
    """
    plt = pretty_plot(8, 6)
    from matplotlib.font_manager import FontProperties

    def _offset(coords):
        # Annotation offset (points) pointing away from the diagram centre.
        vec = np.array(coords) - center
        norm = np.linalg.norm(vec)
        return vec / norm * 10 if norm != 0 else vec

    def _alignment(vec):
        # Text alignment placing the annotation on the outward side.
        valign = "bottom" if vec[1] > 0 else "top"
        if vec[0] < -0.01:
            halign = "right"
        elif vec[0] > 0.01:
            halign = "left"
        else:
            halign = "center"
        return halign, valign

    if ordering is None:
        lines, labels, unstable = self.pd_plot_data
    else:
        _lines, _labels, _unstable = self.pd_plot_data
        lines, labels, unstable = order_phase_diagram(
            _lines, _labels, _unstable, ordering)

    def _is_existing(coords):
        attribute = labels[coords].attribute
        return attribute is None or attribute == "existing"

    if energy_colormap is None:
        if process_attributes:
            for x, y in lines:
                plt.plot(x, y, "k-", linewidth=3, markeredgecolor="k")
            # Attributes are only used to tell whether an entry is a new
            # compound or an existing one (from the ICSD or the MP).
            for x, y in labels.keys():
                if _is_existing((x, y)):
                    plt.plot(x, y, "ko", linewidth=3, markeredgecolor="k",
                             markerfacecolor="b", markersize=12)
                else:
                    plt.plot(x, y, "k*", linewidth=3, markeredgecolor="k",
                             markerfacecolor="g", markersize=18)
        else:
            for x, y in lines:
                plt.plot(x, y, "ko-", linewidth=3, markeredgecolor="k",
                         markerfacecolor="b", markersize=15)
    else:
        from matplotlib.colors import Normalize, LinearSegmentedColormap
        from matplotlib.cm import ScalarMappable
        pda = PDAnalyzer(self._pd)
        for x, y in lines:
            plt.plot(x, y, "k-", linewidth=3, markeredgecolor="k")
        vmin = vmin_mev / 1000.0
        vmax = vmax_mev / 1000.0
        if energy_colormap == 'default':
            mid = - vmin / (vmax - vmin)
            cmap = LinearSegmentedColormap.from_list(
                'my_colormap', [(0.0, '#005500'), (mid, '#55FF55'),
                                (mid, '#FFAAAA'), (1.0, '#FF0000')])
        else:
            cmap = energy_colormap
        norm = Normalize(vmin=vmin, vmax=vmax)
        _map = ScalarMappable(norm=norm, cmap=cmap)
        _energies = [pda.get_equilibrium_reaction_energy(entry)
                     for coord, entry in labels.items()]
        energies = [en if en < 0.0 else -0.00000001 for en in _energies]
        vals_stable = _map.to_rgba(energies)
        for (x, y), color in zip(labels.keys(), vals_stable):
            if not process_attributes:
                plt.plot(x, y, "o", markerfacecolor=color, markersize=15)
            elif _is_existing((x, y)):
                plt.plot(x, y, "o", markerfacecolor=color, markersize=12)
            else:
                plt.plot(x, y, "*", markerfacecolor=color, markersize=18)

    font = FontProperties()
    font.set_weight("bold")
    font.set_size(24)

    # Sets a nice layout depending on the type of PD. Also defines a
    # "center" for the PD, which then allows the annotations to be spread
    # out in a nice manner.
    if len(self._pd.elements) == 3:
        plt.axis("equal")
        plt.xlim((-0.1, 1.2))
        plt.ylim((-0.1, 1.0))
        plt.axis("off")
        center = (0.5, math.sqrt(3) / 6)
    else:
        all_coords = labels.keys()
        miny = min([c[1] for c in all_coords])
        ybuffer = max(abs(miny) * 0.1, 0.1)
        plt.xlim((-0.1, 1.1))
        plt.ylim((miny - ybuffer, ybuffer))
        center = (0.5, miny / 2)
        plt.xlabel("Fraction", fontsize=28, fontweight='bold')
        plt.ylabel("Formation energy (eV/fu)", fontsize=28, fontweight='bold')

    for coords in sorted(labels.keys(), key=lambda x: -x[1]):
        entry = labels[coords]
        vec = _offset(coords)
        halign, valign = _alignment(vec)
        if label_stable:
            extra = {"color": 'g'} if (process_attributes and entry.attribute == 'new') else {}
            plt.annotate(latexify(entry.name), coords, xytext=vec,
                         textcoords="offset points",
                         horizontalalignment=halign,
                         verticalalignment=valign,
                         fontproperties=font, **extra)

    if self.show_unstable:
        font = FontProperties()
        font.set_size(16)
        pda = PDAnalyzer(self._pd)
        energies_unstable = [pda.get_e_above_hull(entry) for entry in unstable]
        if energy_colormap is not None:
            energies.extend(energies_unstable)
            vals_unstable = _map.to_rgba(energies_unstable)
        # enumerate keeps the colour index aligned with vals_unstable even
        # for entries that are filtered out below.
        for ii, ((entry, coords), ehull) in enumerate(
                zip(unstable.items(), energies_unstable)):
            if not ehull < self.show_unstable:
                continue
            vec = _offset(coords)
            # Alignment must be computed for this point; previously the
            # value left over from the last stable label was reused.
            halign, valign = _alignment(vec)
            if energy_colormap is None:
                plt.plot(coords[0], coords[1], "ks", linewidth=3,
                         markeredgecolor="k", markerfacecolor="r",
                         markersize=8)
            else:
                plt.plot(coords[0], coords[1], "s", linewidth=3,
                         markeredgecolor="k",
                         markerfacecolor=vals_unstable[ii], markersize=8)
            if label_unstable:
                plt.annotate(latexify(entry.name), coords, xytext=vec,
                             textcoords="offset points",
                             horizontalalignment=halign, color="b",
                             verticalalignment=valign,
                             fontproperties=font)

    if energy_colormap is not None and show_colorbar:
        _map.set_array(energies)
        # Passing ax explicitly is required for a detached ScalarMappable on
        # newer matplotlib; older versions defaulted to the current axes.
        cbar = plt.colorbar(_map, ax=plt.gca())
        cbar.set_label(
            'Energy [meV/at] above hull (in red)\nInverse energy ['
            'meV/at] above hull (in green)',
            rotation=-90, ha='left', va='center')
    f = plt.gcf()
    f.set_size_inches((8, 6))
    plt.subplots_adjust(left=0.09, right=0.98, top=0.98, bottom=0.07)
    return plt
def get_pourbaix_plot_colorfill_by_domain_name(self, limits=None, title="",
                                               label_domains=True, label_color='k',
                                               domain_color=None, domain_fontsize=None,
                                               domain_edge_lw=0.5, bold_domains=None,
                                               cluster_domains=(),
                                               add_h2o_stablity_line=True,
                                               add_center_line=False, h2o_lw=0.5):
    """
    Color domains by the colors specified in the domain_color dict.

    Args:
        limits: 2D list containing limits of the Pourbaix diagram
            of the form [[xlo, xhi], [ylo, yhi]]
        title (str): title of the plot.
        label_domains (bool): whether to add a text label for each domain.
        label_color (str): color of the domain labels, defaults to black.
        domain_color (dict): colors of each domain, e.g. {"Al(s)": "#FF1100"}.
            Domains not listed (all domains if None) get the default
            colors: '#b8f9e7' for solids ("(s)" in the name), '#d0fbef'
            for names in ``cluster_domains``, white otherwise.
        domain_fontsize (int or dict): font size of the domain labels, either
            a single size for all domains or a {domain name: size} dict.
            Domains not listed (all domains if None) use size 12.
        domain_edge_lw (float): line width of the boundaries between domains.
        bold_domains (list): domain names whose labels use bold text. If
            None, every domain without "(s)" in its name is bold.
        cluster_domains (list): domain names that are cluster phases.
        add_h2o_stablity_line (bool): whether to plot the H2O stability lines.
        add_center_line (bool): whether to plot lines at pH 7 and E = 0 V.
        h2o_lw (float): line width of the H2O stability and center lines.

    Returns:
        The ``matplotlib.pyplot`` module with the diagram drawn on it.
    """
    from matplotlib.patches import Polygon
    from pymatgen import Composition, Element
    from pymatgen.core.ion import Ion

    # helper functions
    def len_elts(entry):
        # Number of elements other than H and O in a domain name; solid
        # names end in "(s)", which is stripped before parsing.
        comp = Composition(entry[:-3]) if "(s)" in entry else Ion.from_formula(entry)
        return len(set(comp.elements) - {Element("H"), Element("O")})

    def special_lines(xlim, ylim):
        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])
        return h_line, o_line, neutral_line, V0_line

    default_domain_font_size = 12
    # slightly darker than the MP scheme, to avoid making the cluster
    # phase too light
    default_solid_phase_color = '#b8f9e7'
    default_cluster_phase_color = '#d0fbef'

    plt = pretty_plot(8, dpi=300)
    stable, _ = self.pourbaix_plot_data(limits)

    # Group stable entries by domain name; a MultiEntry is listed under the
    # name of each of its members.
    entry_dict_of_multientries = collections.defaultdict(list)
    for entry in stable:
        if isinstance(entry, MultiEntry):
            for e in entry.entrylist:
                entry_dict_of_multientries[e.name].append(entry)
        else:
            entry_dict_of_multientries[entry.name].append(entry)

    xlim, ylim = limits[:2] if limits else self._analyzer.chempot_limits[:2]
    h_line, o_line, neutral_line, V0_line = special_lines(xlim, ylim)
    ax = plt.gca()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax.tick_params(direction='out')
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    sorted_entry = sorted(entry_dict_of_multientries, key=len_elts)

    # Font sizes: defaults, overridden by a per-domain dict, or a single
    # size for every domain (as documented). Previously a plain int crashed
    # because it was indexed like a dict.
    fontsizes = {en: default_domain_font_size for en in sorted_entry}
    if isinstance(domain_fontsize, dict):
        fontsizes.update(domain_fontsize)
    elif domain_fontsize is not None:
        fontsizes = dict.fromkeys(sorted_entry, domain_fontsize)

    # Colors: defaults, overridden by any user-supplied entries, so a
    # partial dict no longer raises KeyError.
    colors = {en: default_solid_phase_color if '(s)' in en else
              (default_cluster_phase_color if en in cluster_domains else 'w')
              for en in sorted_entry}
    if domain_color is not None:
        colors.update(domain_color)

    if bold_domains is None:
        bold_domains = [en for en in sorted_entry if '(s)' not in en]

    for entry in sorted_entry:
        x_coord, y_coord, npts = 0.0, 0.0, 0
        for e in entry_dict_of_multientries[entry]:
            xy = self.domain_vertices(e)
            if add_h2o_stablity_line:
                c = self.get_distribution_corrected_center(stable[e], h_line, o_line, 0.3)
            else:
                c = self.get_distribution_corrected_center(stable[e])
            x_coord += c[0]
            y_coord += c[1]
            npts += 1
            patch = Polygon(xy, facecolor=colors[entry], closed=True,
                            lw=domain_edge_lw, fill=True, antialiased=True)
            ax.add_patch(patch)
        xy_center = (x_coord / npts, y_coord / npts)
        if label_domains:
            is_bold = entry in bold_domains
            if platform.system() == 'Darwin':
                # Have to hack to the hard coded font path to get current font on Mac OS X
                if is_bold:
                    fname = '/Library/Fonts/Times New Roman Bold.ttf'
                else:
                    fname = '/Library/Fonts/Times New Roman.ttf'
                font = FontProperties(fname=fname, size=fontsizes[entry])
            else:
                font = FontProperties(family='Times New Roman',
                                      weight='bold' if is_bold else 'regular',
                                      size=fontsizes[entry])
            plt.text(*xy_center, s=latexify_ion(latexify(entry)), fontproperties=font,
                     horizontalalignment="center", verticalalignment="center",
                     multialignment="center", color=label_color)

    if add_h2o_stablity_line:
        dashes = (3, 1.5)
        line, = plt.plot(h_line[0], h_line[1], "k--", linewidth=h2o_lw, antialiased=True)
        line.set_dashes(dashes)
        line, = plt.plot(o_line[0], o_line[1], "k--", linewidth=h2o_lw, antialiased=True)
        line.set_dashes(dashes)
    if add_center_line:
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=h2o_lw, antialiased=False)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=h2o_lw, antialiased=False)

    plt.xlabel("pH", fontname="Times New Roman", fontsize=18)
    plt.ylabel("E (V)", fontname="Times New Roman", fontsize=18)
    plt.xticks(fontname="Times New Roman", fontsize=16)
    plt.yticks(fontname="Times New Roman", fontsize=16)
    plt.title(title, fontsize=20, fontweight='bold', fontname="Times New Roman")
    return plt
def make_defect_formation_energy(args):
    """Compute defect formation energies and serve, print, or plot them.

    Args:
        args: parsed command-line namespace. Reads ``perfect_calc_results``,
            ``cpd_yaml``, ``dirs``, ``skip_shallow``, ``web_gui``, ``port``,
            ``label``, ``print``, ``unitcell``, ``y_range``,
            ``supercell_edge``, ``label_line`` and ``add_charges``.

    With ``skip_shallow``, directories whose ``band_edge_states.yaml`` reports
    ``is_shallow`` are excluded from every output. Then exactly one of the
    following happens:

    * ``web_gui``: a Dash app combining the chemical-potential diagram with
      the defect data is served on ``args.port``.
    * ``print``: the energies (passed through ``slide_energy`` with the
      unit-cell VBM) and their cross points are printed.
    * otherwise: the plot is saved to ``energy_<label>.pdf``.
    """
    formula = args.perfect_calc_results.structure.composition.reduced_formula
    chem_pot_diag = ChemPotDiag.from_yaml(args.cpd_yaml)
    pcr = args.perfect_calc_results

    defects, defect_entries, corrections, kept_dirs = [], [], [], []
    for d in args.dirs:
        if args.skip_shallow:
            band_edge_states = BandEdgeStates.from_yaml(d / "band_edge_states.yaml")
            if band_edge_states.is_shallow:
                continue
        defects.append(loadfn(d / "calc_results.json"))
        defect_entries.append(loadfn(d / "defect_entry.json"))
        corrections.append(loadfn(d / "correction.json"))
        # Track surviving directories so the GUI's band-edge states stay
        # aligned with ``defects`` when shallow defects are skipped.
        kept_dirs.append(d)

    if args.web_gui:
        from crystal_toolkit.settings import SETTINGS
        import dash_html_components as html
        from crystal_toolkit.helpers.layouts import Column
        import crystal_toolkit.components as ctc
        import dash

        edge_states = [BandEdgeStates.from_yaml(d / "band_edge_states.yaml")
                       for d in kept_dirs]
        app = dash.Dash(__name__,
                        suppress_callback_exceptions=True,
                        assets_folder=SETTINGS.ASSETS_PATH,
                        external_stylesheets=[
                            'https://codepen.io/chriddyp/pen/bWLwgP.css'
                        ])
        cpd_plot_info = CpdPlotInfo(chem_pot_diag)
        cpd_e_component = CpdEnergyComponent(cpd_plot_info, pcr, defects,
                                             defect_entries, corrections,
                                             args.unitcell.vbm,
                                             args.unitcell.cbm, edge_states)
        my_layout = html.Div([Column(cpd_e_component.layout)])
        ctc.register_crystal_toolkit(app=app, layout=my_layout, cache=None)
        app.run_server(port=args.port)
        return

    abs_chem_pot = chem_pot_diag.abs_chem_pot_dict(args.label)
    title = " ".join([latexify(formula), "point", args.label])
    defect_energies = make_energies(pcr, defects, defect_entries, corrections,
                                    abs_chem_pot)

    if args.print:
        defect_energies = slide_energy(defect_energies, args.unitcell.vbm)
        print(" charge E_f correction ")
        for e in defect_energies:
            print(e)
            print("")
        print("-- cross points -- ")
        for e in defect_energies:
            print(e.name)
            print(
                e.cross_points(ef_min=args.unitcell.vbm,
                               ef_max=args.unitcell.cbm,
                               base_ef=args.unitcell.vbm))
            print("")
        return

    plotter = DefectEnergyMplPlotter(title=title,
                                     defect_energies=defect_energies,
                                     vbm=args.unitcell.vbm,
                                     cbm=args.unitcell.cbm,
                                     supercell_vbm=pcr.vbm,
                                     supercell_cbm=pcr.cbm,
                                     y_range=args.y_range,
                                     supercell_edge=args.supercell_edge,
                                     label_line=args.label_line,
                                     add_charges=args.add_charges)
    plotter.construct_plot()
    plotter.plt.savefig(f"energy_{args.label}.pdf")
def _title(self):
    """Return the reduced formula of ``self._composition`` passed through latexify."""
    reduced = self._composition.reduced_formula
    return latexify(reduced)
def test_latexify(self):
    """latexify wraps integer and decimal stoichiometries in $_{...}$."""
    cases = [
        ("Li3Fe2(PO4)3", "Li$_{3}$Fe$_{2}$(PO$_{4}$)$_{3}$"),
        ("Li0.2Na0.8Cl", "Li$_{0.2}$Na$_{0.8}$Cl"),
    ]
    for formula, expected in cases:
        self.assertEqual(latexify(formula), expected)
def get_pourbaix_plot_colorfill_by_element(self, limits=None, title="",
                                           label_domains=True, element=None):
    """
    Color Pourbaix domains by element.

    Domains whose name contains exactly one element besides H and O are
    filled with a solid color. All other domains are drawn hatched. Colors
    and hatches are assigned in sorted-domain order and cycle when exhausted.

    Args:
        limits: [[xlo, xhi], [ylo, yhi]] pH / potential limits; defaults to
            the analyzer's chempot limits.
        title (str): plot title.
        label_domains (bool): annotate each domain with its name.
        element: for a MultiEntry, only members whose composition contains
            this element define domains; single entries are always included.

    Returns:
        The ``matplotlib.pyplot`` module with the diagram drawn on it.
    """
    from matplotlib.patches import Polygon
    from pymatgen import Composition, Element
    from pymatgen.core.ion import Ion

    entry_dict_of_multientries = collections.defaultdict(list)
    plt = pretty_plot(16)
    optim_colors = ['#0000FF', '#FF0000', '#00FF00', '#FFFF00', '#FF00FF',
                    '#FF8080', '#DCDCDC', '#800000', '#FF8000']
    optim_font_color = ['#FFFFA0', '#00FFFF', '#FF00FF', '#0000FF', '#00FF00',
                        '#007F7F', '#232323', '#7FFFFF', '#007FFF']
    hatch = ['/', '\\', '|', '-', '+', 'o', '*']
    stable, _ = self.pourbaix_plot_data(limits)
    for entry in stable:
        if isinstance(entry, MultiEntry):
            for e in entry.entrylist:
                if element in e.composition.elements:
                    entry_dict_of_multientries[e.name].append(entry)
        else:
            entry_dict_of_multientries[entry.name].append(entry)

    if limits:
        xlim = limits[0]
        ylim = limits[1]
    else:
        xlim = self._analyzer.chempot_limits[0]
        ylim = self._analyzer.chempot_limits[1]

    h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                           [xlim[1], -xlim[1] * PREFAC]])
    o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                           [xlim[1], -xlim[1] * PREFAC + 1.23]])
    neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
    V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

    ax = plt.gca()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    def non_ho_elements(name):
        # Elements other than H and O in a domain name; solid names end in
        # "(s)", which is stripped before parsing.
        comp = Composition(name[:-3]) if "(s)" in name else Ion.from_formula(name)
        return [el for el in comp.elements if el not in [Element("H"), Element("O")]]

    # dict_keys has no .sort() in Python 3 (AttributeError); sorted() works
    # on any iterable and is equally stable.
    sorted_entry = sorted(entry_dict_of_multientries,
                          key=lambda name: len(non_ho_elements(name)))

    for i, entry in enumerate(sorted_entry):
        color_indx = 0
        x_coord = 0.0
        y_coord = 0.0
        npts = 0
        single_element = len(non_ho_elements(entry)) == 1
        for e in entry_dict_of_multientries[entry]:
            hc = 0
            fc = 0
            bc = 0
            xy = self.domain_vertices(e)
            c = self.get_center(stable[e])
            x_coord += c[0]
            y_coord += c[1]
            npts += 1
            if single_element:
                color_indx = i % len(optim_colors)
                patch = Polygon(xy, facecolor=optim_colors[color_indx],
                                closed=True, lw=3.0, fill=True)
                bc = optim_colors[color_indx]
            else:
                color_indx = i % len(hatch)
                patch = Polygon(xy, hatch=hatch[color_indx], closed=True,
                                lw=3.0, fill=False)
                hc = hatch[color_indx]
            ax.add_patch(patch)

        xy_center = (x_coord / npts, y_coord / npts)
        if label_domains:
            # color_indx is already < len(hatch) <= len(optim_font_color).
            fc = optim_font_color[color_indx]
            if hc:
                # Hatched domain: black text on a white, hatched round box.
                bc = 'k'
                fc = 'w'
                bbox = dict(boxstyle="round", hatch=hc, fc=fc)
            else:
                # Solid domain: text in the fill color on a font-color box.
                bbox = dict(boxstyle="round", fc=fc)
            plt.annotate(latexify_ion(latexify(entry)), xy_center,
                         color=bc, fontsize=30, bbox=bbox)

    lw = 3
    plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
    plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
    plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
    plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

    plt.xlabel("pH")
    plt.ylabel("E (V)")
    plt.title(title, fontsize=20, fontweight='bold')
    return plt
def _get_poly_site_description(self, site_index: int) -> str:
    """Gets a description of a connected polyhedral site.

    If the site likeness (order parameter) is less than ``distorted_tol``,
    "distorted" will be added to the geometry description.

    Args:
        site_index: An inequivalent site index.

    Returns:
        A description of the polyhedral site, including connectivity.
    """

    def format_poly_formula(formula):
        # Apply the markup for self.fmt to a polyhedron formula; any other
        # format (e.g. "raw") leaves the formula unchanged.
        if self.fmt == "latex":
            return latexify(formula)
        if self.fmt == "unicode":
            return unicodeify(formula)
        if self.fmt == "html":
            return htmlify(formula)
        return formula

    site = self._da.sites[site_index]
    nnn_details = self._da.get_next_nearest_neighbor_details(
        site_index, group=not self.describe_symmetry_labels
    )

    from_element = get_formatted_el(
        site["element"],
        self._da.sym_labels[site_index],
        use_oxi_state=self.describe_oxidation_state,
        use_sym_label=self.describe_symmetry_labels,
        fmt=self.fmt,
    )
    from_poly_formula = format_poly_formula(site["poly_formula"])
    s_from_poly_formula = get_el(site["element"]) + from_poly_formula

    if site["geometry"]["likeness"] < self.distorted_tol:
        s_distorted = "distorted "
    else:
        s_distorted = ""
    s_polyhedra = geometry_to_polyhedra[site["geometry"]["type"]]
    s_polyhedra = polyhedra_plurals[s_polyhedra]

    nn_desc = self._get_nearest_neighbor_description(site_index)
    desc = f"{from_element} is bonded to {nn_desc} to form "

    # Handle the case where we are connected only to polyhedra of a single
    # (element, formula) type whose element matches this site's element.
    same_polyhedra_only = nnn_details[0].element == site["element"] and (
        len({(n.element, n.poly_formula) for n in nnn_details}) == 1
    )
    if same_polyhedra_only:
        connectivities = list({nnn_site.connectivity for nnn_site in nnn_details})
        s_mixture = "a mixture of " if len(connectivities) != 1 else ""
        s_connectivities = en.join(connectivities)
        desc += (
            f"{s_mixture}{s_distorted}{s_connectivities}-sharing "
            f"{s_from_poly_formula} {s_polyhedra}"
        )
        return desc

    # otherwise loop through nnn connectivities and describe individually
    desc += f"{s_distorted}{s_from_poly_formula} {s_polyhedra} that share "
    nnn_descriptions = []
    for nnn_site in nnn_details:
        to_element = get_formatted_el(
            nnn_site.element,
            nnn_site.sym_label,
            use_oxi_state=False,
            use_sym_label=self.describe_symmetry_labels,
        )
        to_poly_formula = to_element + format_poly_formula(nnn_site.poly_formula)
        to_shape = geometry_to_polyhedra[nnn_site.geometry]

        if len(nnn_site.sites) == 1 and nnn_site.count != 1:
            s_equivalent = " equivalent "
        else:
            s_equivalent = " "

        if nnn_site.count == 1:
            # e.g. " a corner with one ... octahedron"
            s_an = f" {en.an(nnn_site.connectivity)}"
        else:
            # e.g. "corners with two ... octahedra"
            s_an = ""
            to_shape = polyhedra_plurals[to_shape]

        nnn_descriptions.append(
            "{}{} with {}{}{} {}".format(
                s_an,
                en.plural(nnn_site.connectivity, nnn_site.count),
                en.number_to_words(nnn_site.count),
                s_equivalent,
                to_poly_formula,
                to_shape,
            )
        )

    return desc + en.join(nnn_descriptions)
def get_chempot_range_map_plot(self, elements, referenced=True):
    """
    Returns a plot of the chemical potential range _map. Currently works
    only for 3-component PDs.

    Args:
        elements: Sequence of elements to be considered as independent
            variables. E.g., if you want to show the stability ranges of
            all Li-Co-O phases wrt to uLi and uO, you will supply
            [Element("Li"), Element("O")]
        referenced: if True, gives the results with a reference being the
            energy of the elemental phase. If False, gives absolute values.
            The axis labels follow this choice.

    Returns:
        A matplotlib plot object.
    """
    plt = pretty_plot(12, 8)
    analyzer = PDAnalyzer(self._pd)
    chempot_ranges = analyzer.get_chempot_range_map(
        elements, referenced=referenced)
    missing_lines = {}
    excluded_region = []
    for entry, lines in chempot_ranges.items():
        comp = entry.composition
        center_x = 0
        center_y = 0
        coords = []
        contain_zero = any(comp.get_atomic_fraction(el) == 0
                           for el in elements)
        is_boundary = (not contain_zero) and \
            sum(comp.get_atomic_fraction(el) for el in elements) == 1
        for line in lines:
            (x, y) = line.coords.transpose()
            plt.plot(x, y, "k-")

            # Accumulate unique vertices for the label centroid.
            for coord in line.coords:
                if not in_coord_list(coords, coord):
                    coords.append(coord.tolist())
                    center_x += coord[0]
                    center_y += coord[1]
            if is_boundary:
                excluded_region.extend(line.coords)

        if coords and contain_zero:
            # Region touches an axis: its label is placed further below,
            # once the missing axis-aligned edges have been drawn.
            missing_lines[entry] = coords
        elif coords:
            xy = (center_x / len(coords), center_y / len(coords))
            plt.annotate(latexify(entry.name), xy, fontsize=22)
        # An entry without any range lines has no region to label (its
        # centroid would divide by zero), so it is skipped.

    ax = plt.gca()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    # Shade the forbidden chemical potential regions.
    excluded_region.append([xlim[1], ylim[1]])
    excluded_region = sorted(excluded_region, key=lambda c: c[0])
    (x, y) = np.transpose(excluded_region)
    plt.fill(x, y, "0.80")

    # The hull does not generate the missing horizontal and vertical lines.
    # The following code fixes this.
    el0 = elements[0]
    el1 = elements[1]
    for entry, coords in missing_lines.items():
        center_x = sum(c[0] for c in coords)
        center_y = sum(c[1] for c in coords)
        comp = entry.composition
        is_x = comp.get_atomic_fraction(el0) < 0.01
        is_y = comp.get_atomic_fraction(el1) < 0.01
        n = len(coords)
        if not (is_x and is_y):
            if is_x:
                coords = sorted(coords, key=lambda c: c[1])
                for i in [0, -1]:
                    x = [min(xlim), coords[i][0]]
                    y = [coords[i][1], coords[i][1]]
                    plt.plot(x, y, "k")
                    center_x += min(xlim)
                    center_y += coords[i][1]
            elif is_y:
                coords = sorted(coords, key=lambda c: c[0])
                for i in [0, -1]:
                    x = [coords[i][0], coords[i][0]]
                    y = [coords[i][1], min(ylim)]
                    plt.plot(x, y, "k")
                    center_x += coords[i][0]
                    center_y += min(ylim)
            # Two extra points were added to the centroid sums above.
            xy = (center_x / (n + 2), center_y / (n + 2))
        else:
            center_x = sum(coord[0] for coord in coords) + xlim[0]
            center_y = sum(coord[1] for coord in coords) + ylim[0]
            xy = (center_x / (n + 1), center_y / (n + 1))

        plt.annotate(latexify(entry.name), xy,
                     horizontalalignment="center",
                     verticalalignment="center", fontsize=22)

    if referenced:
        plt.xlabel("$\\mu_{{{0}}} - \\mu_{{{0}}}^0$ (eV)"
                   .format(el0.symbol))
        plt.ylabel("$\\mu_{{{0}}} - \\mu_{{{0}}}^0$ (eV)"
                   .format(el1.symbol))
    else:
        # Absolute chemical potentials have no elemental reference term.
        plt.xlabel("$\\mu_{{{0}}}$ (eV)".format(el0.symbol))
        plt.ylabel("$\\mu_{{{0}}}$ (eV)".format(el1.symbol))
    plt.tight_layout()
    return plt
def get_component_makeup_summary(self) -> str:
    """Gets a summary of the makeup of components in a structure.

    Returns:
        A description of the number of components and their
        dimensionalities and orientations. An empty string is returned when
        the structure consists of a single three-dimensional component.
    """
    component_groups = self._da.get_component_groups()

    if (
        len(component_groups) == 1
        and component_groups[0].count == 1
        and component_groups[0].dimensionality == 3
    ):
        # A single 3D framework needs no makeup summary.
        return ""

    if self._da.dimensionality == 3:
        desc = "The structure consists of "
    else:
        desc = "The structure is {}-dimensional and consists of ".format(
            en.number_to_words(self._da.dimensionality)
        )

    nframeworks = sum(
        1
        for g in component_groups
        for c in g.components
        if c.dimensionality == 3
    )

    component_makeup_summaries = []
    for component_group in component_groups:
        if nframeworks == 1 and component_group.dimensionality == 3:
            s_count = "a"
        else:
            s_count = en.number_to_words(component_group.count)

        dimensionality = component_group.dimensionality

        if component_group.molecule_name:
            if component_group.nsites == 1:
                shape = "atom"
            else:
                shape = "molecule"
            shape = en.plural(shape, s_count)
            formula = component_group.molecule_name
        else:
            shape = en.plural(dimensionality_to_shape[dimensionality], s_count)
            formula = component_group.formula

        if self.fmt == "latex":
            formula = latexify(formula)
        elif self.fmt == "unicode":
            formula = unicodeify(formula)
        elif self.fmt == "html":
            formula = htmlify(formula)

        comp_desc = f"{s_count} {formula} {shape}"

        if component_group.dimensionality in [1, 2]:
            orientations = list(
                {c.orientation for c in component_group.components}
            )
            s_direction = en.plural("direction", len(orientations))
            comp_desc += " oriented in the {} {}".format(
                en.join(orientations), s_direction
            )

        component_makeup_summaries.append(comp_desc)

    if nframeworks == 1 and len(component_makeup_summaries) > 1:
        # when there is a single framework, make the description read
        # "... and 8 Sn atoms inside a SnO2 framework" instead of
        # "..., 8 Sn atoms and one SnO2 framework"
        # This works because the component summaries are sorted by
        # dimensionality
        desc += en.join(component_makeup_summaries[:-1])
        desc += f" inside {component_makeup_summaries[-1]}."
    else:
        desc += en.join(component_makeup_summaries) + "."

    return desc
def optplot(modes=('absorption', ), filenames=None, prefix=None,
            directory=None, gaussian=None, band_gaps=None, labels=None,
            average=True, height=6, width=6, xmin=0, xmax=None, ymin=0,
            ymax=1e5, colours=None, style=None, no_base_style=None,
            image_format='pdf', dpi=400, plt=None, fonts=None):
    """A script to plot optical absorption spectra from VASP calculations.

    Args:
        modes (:obj:`list` or :obj:`tuple`):
            Ordered list of :obj:`str` determining properties to plot.
            Accepted options are 'absorption' (default), 'eps', 'eps-real',
            'eps-im', 'n', 'n-real', 'n-im', 'loss' (equivalent to n-im).
        filenames (:obj:`str` or :obj:`list`, optional): Path to vasprun.xml
            file (can be gzipped). Alternatively, a list of paths can be
            provided, in which case the absorption spectra for each will be
            plotted concurrently.
        prefix (:obj:`str`, optional): Prefix for file names.
        directory (:obj:`str`, optional): The directory in which to save files.
        gaussian (:obj:`float`): Standard deviation for gaussian broadening.
        band_gaps (:obj:`float` or :obj:`list`, optional): The band gap as a
            :obj:`float`, plotted as a dashed line. If plotting multiple
            spectra then a :obj:`list` of band gaps can be provided. An empty
            list reads the gaps from the vasprun files; a list of vasprun
            paths reads the gaps from those files.
        labels (:obj:`str` or :obj:`list`): A label to identify the spectra.
            If plotting multiple spectra then a :obj:`list` of labels can
            be provided. When several files are plotted and no labels are
            given, the reduced formulas are used.
        average (:obj:`bool`, optional): Average the dielectric response across
            all lattice directions. Defaults to ``True``.
        height (:obj:`float`, optional): The height of the plot.
        width (:obj:`float`, optional): The width of the plot.
        xmin (:obj:`float`, optional): The minimum energy on the x-axis.
        xmax (:obj:`float`, optional): The maximum energy on the x-axis.
        ymin (:obj:`float`, optional): The minimum absorption intensity on the
            y-axis.
        ymax (:obj:`float`, optional): The maximum absorption intensity on the
            y-axis.
        colours (:obj:`list`, optional): A :obj:`list` of colours to use in the
            plot. The colours can be specified as a hex code, set of rgb
            values, or any other format supported by matplotlib.
        style (:obj:`list` or :obj:`str`, optional): (List of) matplotlib style
            specifications, to be composed on top of Sumo base style.
        no_base_style (:obj:`bool`, optional): Prevent use of sumo base style.
            This can make alternative styles behave more predictably.
        image_format (:obj:`str`, optional): The image file format. Can be any
            format supported by matplotlib, including: png, jpg, pdf, and svg.
            Defaults to pdf.
        dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
            the image.
        plt (:obj:`matplotlib.pyplot`, optional): A
            :obj:`matplotlib.pyplot` object to use for plotting.
        fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
            single font, specified as a :obj:`str`, or several fonts,
            specified as a :obj:`list` of :obj:`str`.

    Returns:
        A matplotlib pyplot object when ``plt`` was supplied. Otherwise the
        image and data files are written and ``None`` is returned.
    """
    if not filenames:
        if os.path.exists('vasprun.xml'):
            filenames = ['vasprun.xml']
        elif os.path.exists('vasprun.xml.gz'):
            filenames = ['vasprun.xml.gz']
        else:
            logging.error('ERROR: No vasprun.xml found!')
            sys.exit()
    elif isinstance(filenames, str):
        filenames = [filenames]

    vrs = [Vasprun(f) for f in filenames]
    dielectrics = [vr.dielectric for vr in vrs]

    if gaussian:
        dielectrics = [broaden_eps(d, gaussian) for d in dielectrics]

    # initialize spectrum data ready to append from each dataset
    abs_data = OrderedDict()
    for mode in modes:
        abs_data.update({mode: []})

    # for each calculation, get all required properties and append to data
    for d in dielectrics:
        for mode, spectrum in calculate_dielectric_properties(
                d, set(modes), average=average).items():
            abs_data[mode].append(spectrum)

    if isinstance(band_gaps, list) and not band_gaps:
        # empty list therefore get bandgap from vasprun files
        band_gaps = [vr.get_band_structure().get_band_gap()['energy']
                     for vr in vrs]
    elif (isinstance(band_gaps, list) and isinstance(band_gaps[0], str)
          and 'vasprun' in band_gaps[0]):
        # band_gaps contains list of vasprun files
        bg_vrs = [Vasprun(f) for f in band_gaps]
        band_gaps = [vr.get_band_structure().get_band_gap()['energy']
                     for vr in bg_vrs]
    elif isinstance(band_gaps, list):
        # band_gaps is non empty list w. no vaspruns; presume floats
        band_gaps = [float(i) for i in band_gaps]

    save_files = False if plt else True

    # abs_data is keyed by mode, so the number of datasets is the number of
    # vasprun files, not len(abs_data).
    if len(vrs) > 1 and not labels:
        labels = [
            latexify(vr.final_structure.composition.reduced_formula).replace(
                '$_', r'$_\mathregular') for vr in vrs
        ]

    plotter = SOpticsPlotter(abs_data, band_gap=band_gaps, label=labels)
    plt = plotter.get_plot(width=width, height=height, xmin=xmin,
                           xmax=xmax, ymin=ymin, ymax=ymax,
                           colours=colours, dpi=dpi, plt=plt, fonts=fonts,
                           style=style, no_base_style=no_base_style)

    if save_files:
        basename = 'absorption'
        if prefix:
            basename = '{}_{}'.format(prefix, basename)
        image_filename = '{}.{}'.format(basename, image_format)

        if directory:
            image_filename = os.path.join(directory, image_filename)
        plt.savefig(image_filename, format=image_format, dpi=dpi)
        for mode, data in abs_data.items():
            basename = 'absorption' if mode == 'abs' else mode
            write_files(data, basename=basename, prefix=prefix,
                        directory=directory)
    else:
        return plt
def optplot(
    modes=("absorption", ),
    filenames=None,
    codes="vasp",
    prefix=None,
    directory=None,
    gaussian=None,
    band_gaps=None,
    labels=None,
    average=True,
    height=6,
    width=6,
    xmin=0,
    xmax=None,
    ymin=0,
    ymax=1e5,
    colours=None,
    style=None,
    no_base_style=None,
    image_format="pdf",
    dpi=400,
    plt=None,
    fonts=None,
    units="eV",
):
    """A script to plot optical absorption spectra from VASP or Questaal calculations.

    Args:
        modes (:obj:`list` or :obj:`tuple`):
            Ordered list of :obj:`str` determining properties to plot.
            Accepted options are 'absorption' (default), 'eps', 'eps-real',
            'eps-im', 'n', 'n-real', 'n-im', 'loss' (equivalent to n-im).
        filenames (:obj:`str` or :obj:`list`, optional): Path to data file.
            For VASP this is a *vasprun.xml* file (can be gzipped); for
            Questaal the *opt.ext* file from *lmf* or *eps_BSE.out* from
            *bethesalpeter* may be used.
            Alternatively, a list of paths can be
            provided, in which case the absorption spectra for each will be
            plotted concurrently.
        codes (:obj:`str` or :obj:`list`, optional): Original
            calculator. Accepted values are 'vasp' and 'questaal'. Items should
            correspond to filenames.
        prefix (:obj:`str`, optional): Prefix for file names.
        directory (:obj:`str`, optional): The directory in which to save files.
        gaussian (:obj:`float`): Standard deviation for gaussian broadening.
        band_gaps (:obj:`float`, :obj:`str` or :obj:`list`, optional): The band
            gap as a :obj:`float`, in eV, plotted as a dashed line. If plotting
            multiple spectra then a :obj:`list` of band gaps can be provided.
            Band gaps can be provided as a floating-point number or as a path
            to a *vasprun.xml* file. To skip over a line, set its bandgap to
            zero or a negative number to place it outside the visible range.
            A list passed here is not modified.
        labels (:obj:`str` or :obj:`list`): A label to identify the spectra.
            If plotting multiple spectra then a :obj:`list` of labels can
            be provided.
        average (:obj:`bool`, optional): Average the dielectric response across
            all lattice directions. Defaults to ``True``.
        height (:obj:`float`, optional): The height of the plot.
        width (:obj:`float`, optional): The width of the plot.
        xmin (:obj:`float`, optional): The minimum energy on the x-axis.
        xmax (:obj:`float`, optional): The maximum energy on the x-axis.
        ymin (:obj:`float`, optional): The minimum absorption intensity on the
            y-axis.
        ymax (:obj:`float`, optional): The maximum absorption intensity on the
            y-axis.
        colours (:obj:`list`, optional): A :obj:`list` of colours to use in the
            plot. The colours can be specified as a hex code, set of rgb
            values, or any other format supported by matplotlib.
        style (:obj:`list` or :obj:`str`, optional): (List of) matplotlib style
            specifications, to be composed on top of Sumo base style.
        no_base_style (:obj:`bool`, optional): Prevent use of sumo base style.
            This can make alternative styles behave more predictably.
        image_format (:obj:`str`, optional): The image file format. Can be any
            format supported by matplotlib, including: png, jpg, pdf, and svg.
            Defaults to pdf.
        dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
            the image.
        plt (:obj:`matplotlib.pyplot`, optional): A
            :obj:`matplotlib.pyplot` object to use for plotting.
        fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
            single font, specified as a :obj:`str`, or several fonts,
            specified as a :obj:`list` of :obj:`str`.
        units (:obj:`str`, optional): X-axis units for the plot. 'eV' for
            energy in electronvolts or 'nm' for wavelength in nanometers.
            Defaults to 'eV'.

    Returns:
        A matplotlib pyplot object when ``plt`` was supplied. Otherwise the
        image and data files are written and ``None`` is returned.

    Raises:
        ValueError: If a code or a band-gap item is not recognised.
    """

    # Don't write files if this is being done to manipulate existing plt
    save_files = False if plt else True

    # BUILD LIST OF FILES AUTOMATICALLY IF NECESSARY
    if codes == "vasp":
        if not filenames:
            if os.path.exists("vasprun.xml"):
                filenames = ["vasprun.xml"]
            elif os.path.exists("vasprun.xml.gz"):
                filenames = ["vasprun.xml.gz"]
            else:
                logging.error("ERROR: No vasprun.xml found!")
                sys.exit()

    elif codes == "questaal":
        if not filenames:
            # Sorted so that the plotting/labelling order is reproducible.
            filenames = sorted(glob("opt.*"))
            if not filenames:
                # Otherwise filenames stays empty and the code below fails
                # with an unhelpful TypeError.
                logging.error("ERROR: No opt.* file found!")
                sys.exit()
            if len(filenames) == 1:
                logging.info("Found optics file: " + filenames[0])
            else:
                logging.info("Found optics files: " + ", ".join(filenames))

    if isinstance(filenames, str):
        filenames = [filenames]

    if isinstance(codes, str):
        codes = [codes] * len(filenames)
    elif len(codes) == 1:
        codes = list(codes) * len(filenames)

    # ITERATE OVER FILES READING DIELECTRIC DATA
    dielectrics = []
    auto_labels = []
    auto_band_gaps = []
    for i, (filename, code) in enumerate(zip(filenames, codes)):
        if code == "vasp":
            vr = Vasprun(filename)
            dielectrics.append(vr.dielectric)

            auto_labels.append(
                latexify(vr.final_structure.composition.reduced_formula).replace(
                    "$_", r"$_\mathregular"))

            if isinstance(band_gaps, list) and not band_gaps:
                # band_gaps = [], auto band gap requested
                auto_band_gaps.append(
                    vr.get_band_structure().get_band_gap()["energy"])
            else:
                auto_band_gaps.append(None)

        elif code == "questaal":
            if not save_files:
                out_filename = None
            elif len(filenames) == 1:
                out_filename = "dielectric.dat"
            else:
                out_filename = f"dielectric_{i + 1}.dat"

            dielectrics.append(
                questaal.dielectric_from_file(filename, out_filename))

            auto_band_gaps.append(None)
            auto_labels.append(filename.split(".")[-1])
            if isinstance(band_gaps, list) and not band_gaps:
                logging.info("Bandgap requested but not supported for Questaal"
                             " file {}: skipping...".format(filename))

        else:
            raise ValueError(f'Code selection "{code}" not recognised')

    if not labels and len(filenames) > 1:
        labels = auto_labels

    # PROCESS DIELECTRIC DATA: BROADENING AND DERIVED PROPERTIES

    if gaussian:
        dielectrics = [broaden_eps(d, gaussian) for d in dielectrics]

    # initialize spectrum data ready to append from each dataset
    abs_data = OrderedDict()

    for mode in modes:
        abs_data.update({mode: []})

    # for each calculation, get all required properties and append to data
    for d in dielectrics:
        # TODO: add support for other eigs and full modes
        energies, properties = calculate_dielectric_properties(
            d, set(modes), mode="average" if average else "trace")
        for mode, spectrum in properties.items():
            abs_data[mode].append((energies, spectrum))

    if isinstance(band_gaps, list) and not band_gaps:
        # empty list therefore use bandgaps collected from vasprun files
        band_gaps = auto_band_gaps
    elif isinstance(band_gaps, list):
        # list containing filenames and/or values: resolve into a new list so
        # the caller's list is left untouched
        resolved_band_gaps = []
        for item in band_gaps:
            if item is None:
                resolved_band_gaps.append(None)
            elif _floatable(item):
                resolved_band_gaps.append(float(item))
            elif isinstance(item, str) and "vasprun" in item:
                resolved_band_gaps.append(
                    Vasprun(item).get_band_structure().get_band_gap()["energy"])
            else:
                raise ValueError(
                    f"Format not recognised for auto bandgap: {item}.")
        band_gaps = resolved_band_gaps

    plotter = SOpticsPlotter(abs_data, band_gap=band_gaps, label=labels)
    plt = plotter.get_plot(
        width=width,
        height=height,
        xmin=xmin,
        xmax=xmax,
        ymin=ymin,
        ymax=ymax,
        colours=colours,
        dpi=dpi,
        plt=plt,
        fonts=fonts,
        style=style,
        no_base_style=no_base_style,
        units=units,
    )

    if save_files:
        basename = "absorption"
        if prefix:
            basename = f"{prefix}_{basename}"
        image_filename = f"{basename}.{image_format}"

        if directory:
            image_filename = os.path.join(directory, image_filename)
        plt.savefig(image_filename, format=image_format, dpi=dpi)
        for mode, data in abs_data.items():
            basename = "absorption" if mode == "abs" else mode
            write_files(data,
                        basename=basename,
                        prefix=prefix,
                        directory=directory)
    else:
        return plt
def optplot(filenames=None, prefix=None, directory=None,
            gaussian=None, band_gaps=None, labels=None, average=True, height=6,
            width=6, xmin=0, xmax=None, ymin=0, ymax=1e5, colours=None,
            image_format='pdf', dpi=400, plt=None, fonts=None):
    """A script to plot optical absorption spectra from VASP calculations.

    Args:
        filenames (:obj:`str` or :obj:`list`, optional): Path to vasprun.xml
            file (can be gzipped). Alternatively, a list of paths can be
            provided, in which case the absorption spectra for each will be
            plotted concurrently.
        prefix (:obj:`str`, optional): Prefix for file names.
        directory (:obj:`str`, optional): The directory in which to save files.
        gaussian (:obj:`float`): Standard deviation for gaussian broadening.
        band_gaps (:obj:`float` or :obj:`list`, optional): The band gap as a
            :obj:`float`, plotted as a dashed line. If plotting multiple
            spectra then a :obj:`list` of band gaps can be provided. An empty
            list reads the gaps from the vasprun files; a list of vasprun
            paths reads the gaps from those files.
        labels (:obj:`str` or :obj:`list`): A label to identify the spectra.
            If plotting multiple spectra then a :obj:`list` of labels can
            be provided.
        average (:obj:`bool`, optional): Average the dielectric response across
            all lattice directions. Defaults to ``True``.
        height (:obj:`float`, optional): The height of the plot.
        width (:obj:`float`, optional): The width of the plot.
        xmin (:obj:`float`, optional): The minimum energy on the x-axis.
        xmax (:obj:`float`, optional): The maximum energy on the x-axis.
        ymin (:obj:`float`, optional): The minimum absorption intensity on the
            y-axis.
        ymax (:obj:`float`, optional): The maximum absorption intensity on the
            y-axis.
        colours (:obj:`list`, optional): A :obj:`list` of colours to use in the
            plot. The colours can be specified as a hex code, set of rgb
            values, or any other format supported by matplotlib.
        image_format (:obj:`str`, optional): The image file format. Can be any
            format supported by matplotlib, including: png, jpg, pdf, and svg.
            Defaults to pdf.
        dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
            the image.
        plt (:obj:`matplotlib.pyplot`, optional): A
            :obj:`matplotlib.pyplot` object to use for plotting.
        fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
            single font, specified as a :obj:`str`, or several fonts,
            specified as a :obj:`list` of :obj:`str`.

    Returns:
        A matplotlib pyplot object when ``plt`` was supplied. Otherwise the
        image and data files are written and ``None`` is returned.
    """
    if not filenames:
        if os.path.exists('vasprun.xml'):
            filenames = ['vasprun.xml']
        elif os.path.exists('vasprun.xml.gz'):
            filenames = ['vasprun.xml.gz']
        else:
            logging.error('ERROR: No vasprun.xml found!')
            sys.exit()
    elif isinstance(filenames, str):
        filenames = [filenames]

    vrs = [Vasprun(f) for f in filenames]
    dielectrics = [vr.dielectric for vr in vrs]

    if gaussian:
        dielectrics = [broaden_eps(d, gaussian) for d in dielectrics]

    # One absorption dataset per vasprun file.
    abs_data = [calculate_alpha(d, average=average) for d in dielectrics]

    if isinstance(band_gaps, list) and not band_gaps:
        # empty list therefore get bandgap from vasprun files
        band_gaps = [vr.get_band_structure().get_band_gap()['energy']
                     for vr in vrs]
    elif (isinstance(band_gaps, list) and isinstance(band_gaps[0], str)
          and 'vasprun' in band_gaps[0]):
        # band_gaps contains list of vasprun files
        bg_vrs = [Vasprun(f) for f in band_gaps]
        band_gaps = [vr.get_band_structure().get_band_gap()['energy']
                     for vr in bg_vrs]
    elif isinstance(band_gaps, list):
        # band_gaps is non empty list w. no vaspruns; presume floats
        band_gaps = [float(i) for i in band_gaps]

    save_files = False if plt else True

    if len(abs_data) > 1 and not labels:
        labels = [
            latexify(vr.final_structure.composition.reduced_formula).replace(
                '$_', r'$_\mathregular') for vr in vrs
        ]

    plotter = SOpticsPlotter(abs_data, band_gap=band_gaps, label=labels)
    plt = plotter.get_plot(width=width, height=height, xmin=xmin,
                           xmax=xmax, ymin=ymin, ymax=ymax,
                           colours=colours, dpi=dpi, plt=plt, fonts=fonts)

    if save_files:
        basename = 'absorption.{}'.format(image_format)
        filename = '{}_{}'.format(prefix, basename) if prefix else basename
        if directory:
            filename = os.path.join(directory, filename)
        plt.savefig(filename, format=image_format, dpi=dpi)
        write_files(abs_data, prefix=prefix, directory=directory)
    else:
        return plt