Exemple #1
0
    def get_plot(self, normalize_rxn_coordinate=True, label_barrier=True):
        """
        Build the NEB energy-profile plot.

        The image energies, taken relative to the first image, are drawn as
        red markers. ``self.spline`` is sampled every 0.01 along the reaction
        coordinate and drawn as a black curve. Both are multiplied by 1000
        and the axis is labelled in meV.

        Args:
            normalize_rxn_coordinate (bool): If True (default), the reaction
                coordinate is divided by its last value ``self.r[-1]``.
            label_barrier (bool): If True (default), draw a dashed horizontal
                line from x = 0 to the highest sampled point of the spline
                and annotate it with that point's energy.

        Returns:
            The plot object returned by ``pretty_plot``.
        """
        plt = pretty_plot(12, 8)
        if normalize_rxn_coordinate:
            scale = 1 / self.r[-1]
        else:
            scale = 1
        grid = np.arange(0, np.max(self.r), 0.01)
        curve = self.spline(grid) * 1000
        image_energies = (self.energies - self.energies[0]) * 1000
        plt.plot(self.r * scale, image_energies, 'ro',
                 grid * scale, curve, 'k-', linewidth=2, markersize=10)
        plt.xlabel("Reaction coordinate")
        plt.ylabel("Energy (meV)")
        # Leave some head-room above the curve for the barrier annotation.
        lower = np.min(curve) - 10
        upper = np.max(curve) * 1.02 + 20
        plt.ylim((lower, upper))
        if label_barrier:
            # First sampled point carrying the highest spline energy.
            top_x, top_e = max(zip(grid * scale, curve),
                               key=lambda pt: pt[1])
            label_pos = (top_x / 2, top_e * 1.02)
            plt.plot([0, top_x], [top_e, top_e], 'k--')
            plt.annotate('%.0f meV' % top_e,
                         xy=label_pos,
                         xytext=label_pos,
                         horizontalalignment='center')
        plt.tight_layout()
        return plt
Exemple #2
0
    def get_plot(self, width=8, height=8, term_zero=True):
        """
        Plot one voltage curve per electrode stored in ``self._electrodes``.

        Args:
            width: Width of the plot. Defaults to 8 in.
            height: Height of the plot. Defaults to 8 in.
            term_zero: Forwarded to ``get_plot_data``; if True, a zero
                voltage point is appended at the end.

        Returns:
            A matplotlib plot object.
        """
        plt = pretty_plot(width, height)
        # Collected over all electrodes so the x-axis label can be chosen
        # from the full set of working ions and framework formulas.
        working_ions = set()
        frameworks = set()

        for name, electrode in self._electrodes.items():
            xs, ys = self.get_plot_data(electrode, term_zero=term_zero)
            working_ions.add(electrode.working_ion.symbol)
            frameworks.add(electrode.framework_formula)
            plt.plot(xs, ys, "-", linewidth=2, label=name)

        plt.legend()
        x_label = self._choose_best_x_lable(formula=frameworks,
                                            wion_symbol=working_ions)
        plt.xlabel(x_label)
        plt.ylabel("Voltage (V)")
        plt.tight_layout()
        return plt
Exemple #3
0
    def get_plot(self, marker="o", markersize=None, units="thz"):
        """
        Plot the Grüneisen parameter of every data point against its frequency.

        Points are colored by their position in the flattened data, running
        from blue (first point) to red (last point).

        Args:
            marker: marker for the depiction
            markersize: size of the marker; if falsy, matplotlib's default
                size is used
            units: unit for the plots, accepted units: thz, ev, mev, ha,
                cm-1, cm^-1

        Returns:
            The plot object returned by ``pretty_plot``.
        """

        u = freq_units(units)

        freqs = self._gruneisen.frequencies.flatten() * u.factor
        grun = self._gruneisen.gruneisen.flatten()

        plt = pretty_plot(12, 8)

        plt.xlabel(rf"$\mathrm{{Frequency\ ({u.label})}}$")
        plt.ylabel(r"$\mathrm{Grüneisen\ parameter}$")

        n = len(grun) - 1
        for i, (g, f) in enumerate(zip(grun, freqs)):
            if n > 0:
                color = (1.0 / n * i, 0, 1.0 / n * (n - i))
            else:
                # A single data point has no gradient to interpolate (and
                # 1.0 / n would divide by zero); use the gradient's start color.
                color = (0, 0, 1.0)

            if markersize:
                plt.plot(f, g, marker, color=color, markersize=markersize)
            else:
                plt.plot(f, g, marker, color=color)

        plt.tight_layout()

        return plt
Exemple #4
0
 def get_arrhenius_plot(self):
     """
     Plot the stored diffusivities and their fitted line on a log y-axis.

     ``self.diffusivities`` (with ``self.diffusivity_errors`` as error
     bars) are drawn against ``self.x`` as black circles, together with the
     dashed curve ``10 ** (self.slope * self.x + self.intercept)``. The
     activation energy ``self.Ea`` is written inside the axes.

     Returns:
         The plot object returned by ``pretty_plot``.
     """
     from pymatgen.util.plotting import pretty_plot
     plt = pretty_plot(12, 8)
     arr = np.power(10, self.slope * self.x + self.intercept)
     plt.plot(self.x,
              self.diffusivities,
              'ko',
              self.x,
              arr,
              'k--',
              markersize=10)
     plt.errorbar(self.x,
                  self.diffusivities,
                  yerr=self.diffusivity_errors,
                  fmt='ko',
                  ecolor='k',
                  capthick=2,
                  linewidth=2)
     # Use the axes that already hold the data. plt.axes() adds a new Axes
     # on current matplotlib, which would leave the plotted axes linear and
     # anchor the text to an empty overlay.
     ax = plt.gca()
     ax.set_yscale('log')
     plt.text(0.6,
              0.85,
              "E$_a$ = {:.3f} eV".format(self.Ea),
              fontsize=30,
              transform=ax.transAxes)
     plt.ylabel("D (cm$^2$/s)")
     plt.xlabel("1000/T (K$^{-1}$)")
     plt.tight_layout()
     return plt
    def get_framework_rms_plot(self,
                               plt=None,
                               granularity=200,
                               matching_s=None):
        """
        Get the plot of rms framework displacement vs time. Useful for checking
        for melting, especially if framework atoms can move via paddle-wheel
        or similar mechanism (which would show up in max framework displacement
        but doesn't constitute melting).

        Args:
            plt (matplotlib.pyplot): If plt is supplied, changes will be made
                to an existing plot. Otherwise, a new plot will be created.
            granularity (int): Number of structures to match
            matching_s (Structure): Optionally match to a disordered structure
                instead of the first structure in the analyzer. Required when
                a secondary mobile ion is present.

        Returns:
            The plot object with the "RMS" and "max" curves added.

        Notes:
            The method doesn't apply to NPT-AIMD simulation analysis.
        """
        from pymatgen.util.plotting import pretty_plot

        if self.lattices is not None and len(self.lattices) > 1:
            warnings.warn(
                "Note the method doesn't apply to NPT-AIMD simulation analysis!"
            )

        plt = pretty_plot(12, 8, plt=plt)
        n_frames = self.corrected_displacements.shape[1]
        # Clamp the divisor so granularity=1 cannot divide by zero, and clamp
        # the stride to at least 1 so trajectories shorter than `granularity`
        # do not yield step == 0 (which would zero out the time axis below).
        step = max((n_frames - 1) // max(granularity - 1, 1), 1)
        # Reference framework: the structure with the mobile specie removed.
        f = (matching_s or self.structure).copy()
        f.remove_species([self.specie])
        sm = StructureMatcher(
            primitive_cell=False,
            stol=0.6,
            comparator=OrderDisorderElementComparator(),
            allow_subset=True,
        )
        rms = []
        for s in self.get_drift_corrected_structures(step=step):
            s.remove_species([self.specie])
            d = sm.get_rms_dist(f, s)
            if d:
                rms.append(d)
            else:
                # No result from the matcher: record (1, 1) as a placeholder.
                rms.append((1, 1))
        max_dt = (len(rms) - 1) * step * self.step_skip * self.time_step
        # Long runs are shown in ps (values divided by 1000), short ones in fs.
        if max_dt > 100000:
            plot_dt = np.linspace(0, max_dt / 1000, len(rms))
            unit = "ps"
        else:
            plot_dt = np.linspace(0, max_dt, len(rms))
            unit = "fs"
        rms = np.array(rms)
        plt.plot(plot_dt, rms[:, 0], label="RMS")
        plt.plot(plot_dt, rms[:, 1], label="max")
        plt.legend(loc="best")
        plt.xlabel(f"Timestep ({unit})")
        plt.ylabel("normalized distance")
        plt.tight_layout()
        return plt
    def get_plot(self, normalize_rxn_coordinate=True, label_barrier=True):
        """
        Return the NEB plot: image energies plus the spline through them.

        Energies are shown relative to the first image and multiplied by
        1000 (axis labelled in meV). ``self.spline`` is evaluated on a grid
        of spacing 0.01 along the reaction coordinate.

        Args:
            normalize_rxn_coordinate (bool): Whether to divide the reaction
                coordinate by its last value ``self.r[-1]``. Defaults to
                True.
            label_barrier (bool): Whether to mark the highest point of the
                sampled spline with a dashed line and annotate it with the
                spread (maximum minus minimum) of the sampled spline.

        Returns:
            matplotlib.pyplot object.
        """
        plt = pretty_plot(12, 8)
        scale = 1 / self.r[-1] if normalize_rxn_coordinate else 1
        samples = np.arange(0, np.max(self.r), 0.01)
        fit_mev = self.spline(samples) * 1000
        images_mev = (self.energies - self.energies[0]) * 1000
        plt.plot(self.r * scale, images_mev, 'ro',
                 samples * scale, fit_mev, 'k-', linewidth=2, markersize=10)
        plt.xlabel("Reaction coordinate")
        plt.ylabel("Energy (meV)")
        e_min = np.min(fit_mev)
        e_max = np.max(fit_mev)
        plt.ylim((e_min - 10, e_max * 1.02 + 20))
        if label_barrier:
            peak = max(zip(samples * scale, fit_mev),
                       key=lambda pair: pair[1])
            peak_x, peak_e = peak[0], peak[1]
            text_at = (peak_x / 2, peak_e * 1.02)
            plt.plot([0, peak_x], [peak_e, peak_e], 'k--')
            plt.annotate('%.0f meV' % (e_max - e_min),
                         xy=text_at,
                         xytext=text_at,
                         horizontalalignment='center')
        plt.tight_layout()
        return plt
    def get_rdf_plot(self, label=None, xlim=(0.0, 8.0), ylim=(-0.005, 3.0)):
        """
        Plot the average RDF function.

        Args:
            label (str): The legend label. If None, it is built from the
                element symbols of the first structure's composition that
                also appear in ``self.species``, joined with "-".
            xlim (list): Set the x limits of the current axes.
            ylim (list): Set the y limits of the current axes.

        Returns:
            The plot object returned by ``pretty_plot``.
        """

        if label is None:
            symbol_list = [e.symbol for e in
                           self.structures[0].composition.keys()]
            symbol_list = [symbol for symbol in symbol_list if
                           symbol in self.species]

            if len(symbol_list) == 1:
                label = symbol_list[0]
            else:
                label = "-".join(symbol_list)

        plt = pretty_plot(12, 8)
        plt.plot(self.interval, self.rdf, color="r", label=label, linewidth=4.0)
        # "\\AA" spells the same runtime text as the former "\AA" but avoids
        # the invalid "\A" escape sequence that newer Python versions warn on.
        plt.xlabel("$r$ ($\\rm\\AA$)")
        plt.ylabel("$g(r)$")
        plt.legend(loc='upper right', fontsize=36)
        plt.xlim(xlim[0], xlim[1])
        plt.ylim(ylim[0], ylim[1])
        plt.tight_layout()

        return plt
Exemple #8
0
def get_chgint_plot(args):
    """
    Plot ``Chgcar.get_integrated_diff`` curves for a selection of atoms.

    Args:
        args: Object providing ``chgcar_file`` (path handed to
            ``Chgcar.from_file``), ``radius`` (forwarded to
            ``get_integrated_diff``) and ``inds``. When ``inds`` is truthy,
            its first element is a comma-separated string of atom indices;
            otherwise one site per symmetry-equivalent group (found with
            ``symprec=0.1``) is used.

    Returns:
        The plot object returned by ``pretty_plot``, one curve per atom.
    """
    chgcar = Chgcar.from_file(args.chgcar_file)
    structure = chgcar.structure

    if args.inds:
        atom_ind = [int(token) for token in args.inds[0].split(",")]
    else:
        symmetrized = SpacegroupAnalyzer(
            structure, symprec=0.1).get_symmetrized_structure()
        # Take the first member of every group of equivalent sites.
        representatives = [group[0] for group in symmetrized.equivalent_sites]
        atom_ind = [structure.sites.index(rep) for rep in representatives]

    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8)
    for idx in atom_ind:
        curve = chgcar.get_integrated_diff(idx, args.radius, 30)
        legend_text = "Atom {} - {}".format(idx, structure[idx].species_string)
        plt.plot(curve[:, 0], curve[:, 1], label=legend_text)
    plt.legend(loc="upper left")
    plt.xlabel("Radius (A)")
    plt.ylabel("Integrated charge (e)")
    plt.tight_layout()
    return plt
Exemple #9
0
    def get_xrd_plot(self, structure, two_theta_range=(0, 90),
                     annotate_peaks=True):
        """
        Returns the XRD plot as a matplotlib.pyplot.

        Args:
            structure: Input structure
            two_theta_range ([float of length 2]): Tuple for range of
                two_thetas to calculate in degrees. Defaults to (0, 90). Set to
                None if you want all diffracted beams within the limiting
                sphere of radius 2 / wavelength.
            annotate_peaks: Whether to annotate the peaks with plane
                information.

        Returns:
            (matplotlib.pyplot)
        """
        from pymatgen.util.plotting import pretty_plot
        plt = pretty_plot(16, 10)
        for two_theta, i, hkls, _d_hkl in self.get_xrd_data(
                structure, two_theta_range=two_theta_range):
            # None is a documented value meaning "no angular restriction";
            # subscripting it would raise TypeError, so skip the filter then.
            if two_theta_range is not None and not (
                    two_theta_range[0] <= two_theta <= two_theta_range[1]):
                continue
            label = ", ".join([str(hkl) for hkl in hkls.keys()])
            # Each peak is a vertical stick from 0 up to its intensity.
            plt.plot([two_theta, two_theta], [0, i], color='k',
                     linewidth=3, label=label)
            if annotate_peaks:
                plt.annotate(label, xy=[two_theta, i],
                             xytext=[two_theta, i], fontsize=16)
        plt.xlabel(r"$2\theta$ ($^\circ$)")
        plt.ylabel("Intensities (scaled)")
        plt.tight_layout()

        return plt
Exemple #10
0
    def plot_conc_temp(self, me=[1.0, 1.0, 1.0], mh=[1.0, 1.0, 1.0]):
        """
        Plot the carrier concentration against temperature, both at
        equilibrium and out of equilibrium after quenching at 300K.

        Temperatures run from 300 to 1900 in steps of 100. For each one the
        'Qi' entry returned by the analyzer is multiplied by 1e-6 before
        plotting on a logarithmic y-axis (presumably a conversion to the
        cm^-3 shown on the axis label; confirm against the analyzer).

        Args:
            me: the effective mass for the electrons as a list of 3
                eigenvalues
            mh: the effective mass for the holes as a list of 3 eigenvalues

        Returns:
            a matplotlib object
        """
        # NOTE: the list defaults are shared objects; this method only
        # forwards them to the analyzer and never mutates them itself.
        temps = list(range(300, 2000, 100))
        analyzer = self._analyzer
        # Equilibrium and quenched values are evaluated pairwise, per
        # temperature, in that order.
        pairs = [
            (analyzer.get_eq_Ef(t, me, mh)['Qi'] * 1e-6,
             analyzer.get_non_eq_Ef(t, 300, me, mh)['Qi'] * 1e-6)
            for t in temps
        ]
        qi = [pair[0] for pair in pairs]
        qi_non_eq = [pair[1] for pair in pairs]

        plt = pretty_plot(12, 8)
        plt.xlabel("temperature (K)")
        plt.ylabel("carrier concentration (cm$^{-3}$)")
        plt.semilogy(temps, qi, linewidth=3.0)
        plt.semilogy(temps, qi_non_eq, linewidth=3)
        plt.legend(['eq', 'non-eq'])
        return plt
Exemple #11
0
    def plot(self, width=8, height=None, plt=None, dpi=None, **kwargs):
        """
        Plot the input energy-volume data and the fitted equation of state.

        The input points are drawn as circles and the fit as a dashed line
        over the volume range widened by 1% on each side. A text box with
        the fit parameters is placed inside the axes.

        Args:
            width (float): Width of plot in inches. Defaults to 8in.
            height (float): Height of plot in inches. Forwarded to
                ``pretty_plot``.
            plt (matplotlib.pyplot): If plt is supplied, changes will be made
                to an existing plot. Otherwise, a new plot will be created.
            dpi: Forwarded to ``pretty_plot``.
            kwargs (dict): Optional overrides. Keys read here: ``color``
                (default "r"), ``label`` (default "<ClassName> fit") and
                ``text`` (default: the fit-parameter summary).

        Returns:
            Matplotlib plot object.
        """
        # pylint: disable=E1307
        plt = pretty_plot(width=width, height=height, plt=plt, dpi=dpi)

        eos_name = self.__class__.__name__
        color = kwargs.get("color", "r")
        label = kwargs.get("label", "{} fit".format(eos_name))

        # The summary is always built, then possibly replaced by the caller.
        summary = "\n".join((
            "Equation of State: %s" % eos_name,
            "Minimum energy = %1.2f eV" % self.e0,
            "Minimum or reference volume = %1.2f Ang^3" % self.v0,
            "Bulk modulus = %1.2f eV/Ang^3 = %1.2f GPa" %
            (self.b0, self.b0_GPa),
            "Derivative of bulk modulus wrt pressure = %1.2f" % self.b1,
        ))
        text = kwargs.get("text", summary)

        # Raw data points.
        plt.plot(self.volumes,
                 self.energies,
                 linestyle="None",
                 marker="o",
                 color=color)

        # Fitted curve on a grid slightly wider than the data.
        raw_lo = min(self.volumes)
        raw_hi = max(self.volumes)
        v_lo = raw_lo - 0.01 * abs(raw_lo)
        v_hi = raw_hi + 0.01 * abs(raw_hi)
        vfit = np.linspace(v_lo, v_hi, 100)
        plt.plot(vfit,
                 self.func(vfit),
                 linestyle="dashed",
                 color=color,
                 label=label)

        plt.grid(True)
        plt.xlabel("Volume $\\AA^3$")
        plt.ylabel("Energy (eV)")
        plt.legend(loc="best", shadow=True)
        # Fit parameters, positioned in axes coordinates.
        plt.text(0.4, 0.5, text, transform=plt.gca().transAxes)

        return plt
Exemple #12
0
    def get_pourbaix_plot(self,
                          limits=None,
                          title="",
                          label_domains=True,
                          plt=None):
        """
        Plot Pourbaix diagram.

        Four reference lines are drawn first: two dashed red lines with
        slope ``-PREFAC`` (one through the origin, one offset by 1.23), a
        vertical line at pH 7 and a horizontal line at 0 V. Then the outline
        of every stable domain is drawn.

        Args:
            limits: 2D list containing limits of the Pourbaix diagram
                of the form [[xlo, xhi], [ylo, yhi]]. Defaults to
                [[-2, 16], [-3, 3]].
            title (str): Title to display on plot
            label_domains (bool): whether to label pourbaix domains
            plt (pyplot): Pyplot instance for plotting; a new one is
                created with ``pretty_plot(16)`` when falsy.

        Returns:
            plt (pyplot) - matplotlib plot object with pourbaix diagram
        """
        if limits is None:
            limits = [[-2, 16], [-3, 3]]

        plt = plt or pretty_plot(16)

        xlim, ylim = limits[0], limits[1]
        x_lo, x_hi = xlim[0], xlim[1]

        # Each line is stored as a 2-row array: row 0 = x values, row 1 = y.
        h_line = np.array([[x_lo, -x_lo * PREFAC],
                           [x_hi, -x_hi * PREFAC]]).T
        o_line = np.array([[x_lo, -x_lo * PREFAC + 1.23],
                           [x_hi, -x_hi * PREFAC + 1.23]]).T
        neutral_line = np.array([[7, ylim[0]], [7, ylim[1]]]).T
        V0_line = np.array([[x_lo, 0], [x_hi, 0]]).T

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        lw = 3
        reference_lines = (
            (h_line, "r--"),
            (o_line, "r--"),
            (neutral_line, "k-."),
            (V0_line, "k-."),
        )
        for line, fmt in reference_lines:
            plt.plot(line[0], line[1], fmt, linewidth=lw)

        for entry, vertices in self._pbx._stable_domain_vertices.items():
            center = np.average(vertices, axis=0)
            # Repeat the first vertex at the end so the outline is closed.
            outline = np.vstack([vertices, vertices[0]])
            xs, ys = outline.T
            plt.plot(xs, ys, 'k-', linewidth=lw)

            if label_domains:
                annotation = plt.annotate(generate_entry_label(entry),
                                          center,
                                          ha='center',
                                          va='center',
                                          fontsize=20,
                                          color="b")
                annotation.draggable()

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
Exemple #13
0
    def get_plot(self,
                 structure,
                 two_theta_range=(0, 90),
                 annotate_peaks=True,
                 ax=None,
                 with_labels=True,
                 fontsize=16):
        """
        Returns the diffraction plot as a matplotlib.pyplot.

        Args:
            structure: Input structure
            two_theta_range ([float of length 2]): Tuple for range of
                two_thetas to calculate in degrees. Defaults to (0, 90). Set to
                None if you want all diffracted beams within the limiting
                sphere of radius 2 / wavelength.
            annotate_peaks: Whether to annotate the peaks with plane
                information.
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            with_labels: True to add xlabels and ylabels to the plot.
            fontsize: (int) fontsize for peak labels.

        Returns:
            (matplotlib.pyplot)
        """
        if ax is None:
            from pymatgen.util.plotting import pretty_plot
            plt = pretty_plot(16, 10)
            ax = plt.gca()
        else:
            # This to maintain the type of the return value.
            import matplotlib.pyplot as plt

        xrd = self.get_pattern(structure, two_theta_range=two_theta_range)

        for two_theta, i, hkls, _d_hkl in zip(xrd.x, xrd.y, xrd.hkls,
                                              xrd.d_hkls):
            # None is a documented value meaning "no angular restriction";
            # subscripting it would raise TypeError, so skip the filter then.
            if two_theta_range is not None and not (
                    two_theta_range[0] <= two_theta <= two_theta_range[1]):
                continue
            label = ", ".join([str(hkl["hkl"]) for hkl in hkls])
            # Each peak is a vertical stick from 0 up to its intensity.
            ax.plot([two_theta, two_theta], [0, i],
                    color='k',
                    linewidth=3,
                    label=label)
            if annotate_peaks:
                ax.annotate(label,
                            xy=[two_theta, i],
                            xytext=[two_theta, i],
                            fontsize=fontsize)

        if with_labels:
            ax.set_xlabel(r"$2\theta$ ($^\circ$)")
            ax.set_ylabel("Intensities (scaled)")

        # Only invoked if the supplied object actually offers tight_layout.
        if hasattr(ax, "tight_layout"):
            ax.tight_layout()

        return plt
Exemple #14
0
    def _get_matplotlib_figure(self) -> plt.Figure:
        """
        Build the reaction kinks diagram as a matplotlib figure.

        The kinks from ``self.get_kinks()`` are drawn as a connected line
        with markers, ``self.minimum`` is highlighted with a red star, and
        each kink is annotated with the formulas of the reaction products
        whose coefficient is not close to zero.

        Returns:
            The current pyplot figure, closed before it is returned.
        """
        pretty_plot(8, 5)
        plt.xlim([-0.05, 1.05])  # plot boundary is 5% wider on each side

        # Transpose the kink records into per-field tuples.
        _, x, energy, reactions, _ = zip(*self.get_kinks())  # type: ignore

        plt.plot(x, energy, "o-", markersize=8, c="navy", zorder=1)
        min_x, min_energy = self.minimum[0], self.minimum[1]
        plt.scatter(min_x, min_energy, marker="*", c="red", s=400, zorder=2)

        for kink_x, kink_energy, rxn in zip(x, energy, reactions):
            formulas = []
            for product in rxn.products:  # type: ignore
                if np.isclose(rxn.get_coeff(product), 0):  # type: ignore
                    continue
                formulas.append(latexify(product.reduced_formula))
            plt.annotate(
                ", ".join(formulas),
                xy=(kink_x, kink_energy),
                xytext=(10, -30),
                textcoords="offset points",
                ha="right",
                va="bottom",
                arrowprops={"arrowstyle": "->",
                            "connectionstyle": "arc3,rad=0"},
            )

        plt.ylabel("Energy (eV/atom)" if self.norm else "Energy (eV/f.u.)")
        plt.xlabel(self._get_xaxis_title())

        # plot boundary is 5% lower
        lower_bound = self.minimum[1] + 0.05 * self.minimum[1]
        plt.ylim(lower_bound)

        fig = plt.gcf()
        plt.close(fig)

        return fig
Exemple #15
0
    def get_plot(self, xlim=None, ylim=None):
        """
        Get a matplotlib plot showing all spectra in ``self._spectra``.

        With ``self.stack`` false, each spectrum is drawn as a line shifted
        vertically by ``self.yshift`` times its index. With ``self.stack``
        true, each spectrum is drawn as a filled band on top of the running
        sum of the previous spectra.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.

        Returns:
            The plot object returned by ``pretty_plot``.
        """

        plt = pretty_plot(7, 0)

        # The axis labels do not depend on the spectra, so set them once
        # (this also labels the axes when there is nothing to plot).
        plt.xlabel('Número de onda ' + r'($cm^{-1}$)')
        plt.ylabel('Intensidadade (u.a.)')

        base = 0.0
        for i, (key, sp) in enumerate(self._spectra.items()):
            shifted = sp.y + self.yshift * i
            if not self.stack:
                plt.plot(
                    sp.x,
                    shifted,
                    color=self.colors[i],
                    label=str(key),
                    linewidth=3,
                )
            else:
                plt.fill_between(
                    sp.x,
                    base,
                    shifted,
                    color=self.colors[i],
                    label=str(key),
                    linewidth=3,
                )
                # Running (unshifted) sum used as the floor of the next band.
                base = sp.y + base

        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)

        # Legend entries follow plotting order. To reverse them, fetch the
        # handles and labels from plt.gca().get_legend_handles_labels() and
        # pass both lists reversed to plt.legend().
        plt.legend()
        leg = plt.gca().get_legend()
        ltext = leg.get_texts()  # all the text.Text instance in the legend
        plt.setp(ltext, fontsize=30)
        plt.tight_layout()
        return plt
Exemple #16
0
def get_arrhenius_plot(temps,
                       diffusivities,
                       diffusivity_errors=None,
                       **kwargs):
    """
    Returns an Arrhenius plot.

    The diffusivities are drawn as black circles against 1000/T on a
    logarithmic y-axis, together with the dashed Arrhenius fit obtained from
    ``fit_arrhenius``. The fitted activation energy is written inside the
    axes in meV.

    Args:
        temps ([float]): A sequence of temperatures.
        diffusivities ([float]): A sequence of diffusivities (e.g.,
            from DiffusionAnalyzer.diffusivity).
        diffusivity_errors ([float]): A sequence of errors for the
            diffusivities. If None, no error bar is plotted. If shorter
            than ``diffusivities``, error bars are drawn for the leading
            points only.
        **kwargs:
            Any keyword args supported by matplotlib.pyplot.plot.

    Returns:
        A matplotlib.pyplot object. Do plt.show() to show the plot.
    """
    Ea, c, _ = fit_arrhenius(temps, diffusivities)

    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8)

    # Arrhenius fit evaluated at the input temperatures; const.k / const.e
    # is the Boltzmann constant divided by the elementary charge.
    arr = c * np.exp(-Ea / (const.k / const.e * np.array(temps)))

    t_1 = 1000 / np.array(temps)

    plt.plot(t_1,
             diffusivities,
             'ko',
             t_1,
             arr,
             'k--',
             markersize=10,
             **kwargs)
    if diffusivity_errors is not None:
        n = len(diffusivity_errors)
        plt.errorbar(t_1[0:n],
                     diffusivities[0:n],
                     yerr=diffusivity_errors,
                     fmt='ko',
                     ecolor='k',
                     capthick=2,
                     linewidth=2)
    # Use the axes that already hold the data. plt.axes() adds a new Axes on
    # current matplotlib, which would leave the plotted axes linear and
    # anchor the text to an empty overlay.
    ax = plt.gca()
    ax.set_yscale('log')
    plt.text(0.6,
             0.85,
             "E$_a$ = {:.0f} meV".format(Ea * 1000),
             fontsize=30,
             transform=ax.transAxes)
    plt.ylabel("D (cm$^2$/s)")
    plt.xlabel("1000/T (K$^{-1}$)")
    plt.tight_layout()
    return plt
    def get_msd_plot(self, plt=None, mode="specie"):
        """
        Plot the squared-displacement data stored on the analyzer vs. time.
        Useful for checking convergence. This can be written to an image
        file.

        Args:
            plt: A plot object. Defaults to None, which means one will be
                generated.
            mode (str): Selects what is plotted:

                * "species": site-averaged squared displacement per species.
                * "sites": squared displacement of every individual site.
                * "mscd": ``self.mscd``.
                * anything else: ``self.msd`` plus its three components,
                  labelled a, b and c.

                Note that the default value "specie" is not one of the named
                modes, so the default plot is the last (per-direction) one.

        Returns:
            The plot object.
        """
        from pymatgen.util.plotting import pretty_plot

        plt = pretty_plot(12, 8, plt=plt)

        # Long runs are shown in ps (values divided by 1000), short ones in fs.
        use_ps = np.max(self.dt) > 100000
        plot_dt = self.dt / 1000 if use_ps else self.dt
        unit = "ps" if use_ps else "fs"
        legend_style = {"loc": 2, "prop": {"size": 20}}

        if mode == "species":
            for sp in sorted(self.structure.composition.keys()):
                members = [
                    idx for idx, site in enumerate(self.structure)
                    if site.specie == sp
                ]
                averaged = np.average(self.sq_disp_ions[members, :], axis=0)
                plt.plot(plot_dt, averaged, label=str(sp))
            plt.legend(**legend_style)
        elif mode == "sites":
            for idx, site in enumerate(self.structure):
                site_label = "%s - %d" % (str(site.specie), idx)
                plt.plot(plot_dt, self.sq_disp_ions[idx, :], label=site_label)
            plt.legend(**legend_style)
        elif mode == "mscd":
            plt.plot(plot_dt, self.mscd, "r")
            plt.legend(["Overall"], **legend_style)
        else:
            # Default / unrecognized mode: overall curve plus one per axis.
            plt.plot(plot_dt, self.msd, "k")
            for axis, fmt in enumerate("rgb"):
                plt.plot(plot_dt, self.msd_components[:, axis], fmt)
            plt.legend(["Overall", "a", "b", "c"], **legend_style)

        plt.xlabel("Timestep ({})".format(unit))
        plt.ylabel("MSCD ($\\AA^2$)" if mode == "mscd" else "MSD ($\\AA^2$)")
        plt.tight_layout()
        return plt
    def get_rdf_plot(
            self,
            label: str = None,
            xlim: tuple = (0.0, 8.0),
            ylim: tuple = (-0.005, 3.0),
            loc_peak: bool = False,
    ):
        """
        Plot the average RDF function, optionally marking its peaks.

        Args:
            label (str): The legend label. If None, it is built from the
                element symbols of the first structure's composition that
                also appear in ``self.species``, joined with "-".
            xlim (list): Set the x limits of the current axes.
            ylim (list): Set the y limits of the current axes.
            loc_peak (bool): If True, scatter ``self.peak_r`` against
                ``self.peak_rdf`` on top of the curve.

        Returns:
            The plot object returned by ``pretty_plot``.
        """

        if label is None:
            all_symbols = [
                el.symbol for el in self.structures[0].composition.keys()
            ]
            selected = [sym for sym in all_symbols if sym in self.species]
            label = selected[0] if len(selected) == 1 else "-".join(selected)

        plt = pretty_plot(12, 8)
        plt.plot(self.interval, self.rdf, label=label, linewidth=4.0, zorder=1)

        if loc_peak:
            # Drawn with a higher zorder so the markers sit above the curve.
            peak_style = {
                "marker": "P",
                "s": 240,
                "c": "k",
                "linewidths": 0.1,
                "alpha": 0.7,
                "zorder": 2,
                "label": "Peaks",
            }
            plt.scatter(self.peak_r, self.peak_rdf, **peak_style)

        plt.xlabel("$r$ ($\\rm\\AA$)")
        plt.ylabel("$g(r)$")
        plt.legend(loc="upper right", fontsize=36)
        x_lo, x_hi = xlim[0], xlim[1]
        y_lo, y_hi = ylim[0], ylim[1]
        plt.xlim(x_lo, x_hi)
        plt.ylim(y_lo, y_hi)
        plt.tight_layout()

        return plt
Exemple #19
0
def plot_something():
    """
    Build a small demonstration figure with TeX text rendering enabled.

    Sets the global matplotlib rc options ``text.usetex=True`` and
    ``font.family='serif'``, then plots the points (1, 1), (2, 2), (3, 3)
    on a plot created by ``pretty_plot(6, 5.5)``.

    Returns:
        The plot object returned by ``pretty_plot``.
    """
    rc_overrides = (
        ('text', {'usetex': True}),
        ('font', {'family': 'serif'}),
    )
    for group, options in rc_overrides:
        matplotlib.rc(group, **options)

    canvas = pretty_plot(6, 5.5)
    canvas.plot([1, 2, 3], [1, 2, 3])
    canvas.tight_layout()
    return canvas
    def get_framework_rms_plot(self, plt=None, granularity=200,
                               matching_s=None):
        """
        Get the plot of rms framework displacement vs time. Useful for checking
        for melting, especially if framework atoms can move via paddle-wheel
        or similar mechanism (which would show up in max framework displacement
        but doesn't constitute melting).

        Args:
            plt (matplotlib.pyplot): If plt is supplied, changes will be made
                to an existing plot. Otherwise, a new plot will be created.
            granularity (int): Number of structures to match. The sampling
                stride derived from it is never allowed to drop below one
                frame, and a non-positive divisor (granularity <= 1) is
                treated as 1.
            matching_s (Structure): Optionally match to a disordered structure
                instead of ``self.structure``. Required when a secondary
                mobile ion is present.

        Returns:
            The plot object with two curves, labelled "RMS" and "max".

        Notes:
            The method doesn't apply to NPT-AIMD simulation analysis; a
            warning is emitted when more than one lattice is stored.
        """
        from pymatgen.util.plotting import pretty_plot
        if self.lattices is not None and len(self.lattices) > 1:
            warnings.warn("Note the method doesn't apply to NPT-AIMD "
                          "simulation analysis!")

        plt = pretty_plot(12, 8, plt=plt)

        # Stride between sampled frames. Without the clamps a trajectory
        # shorter than `granularity` gives step == 0 (which makes max_dt
        # below 0 and collapses the time axis), and granularity == 1 raises
        # ZeroDivisionError.
        n_intervals = self.corrected_displacements.shape[1] - 1
        step = max(1, n_intervals // max(1, granularity - 1))

        # Reference framework: the chosen structure without the mobile specie.
        f = (matching_s or self.structure).copy()
        f.remove_species([self.specie])
        sm = StructureMatcher(primitive_cell=False, stol=0.6,
                              comparator=OrderDisorderElementComparator(),
                              allow_subset=True)
        rms = []
        for s in self.get_drift_corrected_structures(step=step):
            s.remove_species([self.specie])
            d = sm.get_rms_dist(f, s)
            # (1, 1) is used as a placeholder when the matcher returns a
            # falsy result for this frame.
            rms.append(d if d else (1, 1))

        max_dt = (len(rms) - 1) * step * self.step_skip * self.time_step
        if max_dt > 100000:
            plot_dt = np.linspace(0, max_dt / 1000, len(rms))
            unit = 'ps'
        else:
            plot_dt = np.linspace(0, max_dt, len(rms))
            unit = 'fs'
        rms = np.array(rms)
        plt.plot(plot_dt, rms[:, 0], label='RMS')
        plt.plot(plot_dt, rms[:, 1], label='max')
        plt.legend(loc='best')
        plt.xlabel("Timestep ({})".format(unit))
        plt.ylabel("normalized distance")
        plt.tight_layout()
        return plt
Exemple #21
0
    def get_locpot_along_slab_plot(self, label_energies=True,
                                   plt=None, label_fontsize=10):
        """
        Plot the local potential against the fractional position along the
        slab direction.

        Two curves are drawn: the raw signal (blue dashed) and a flattened
        signal (red) in which every point whose potential is below
        ``self.ave_bulk_p``, or whose position lies between the first and
        last entries of ``self.sorted_sites`` along ``self.direction``, is
        replaced by ``self.ave_bulk_p``.

        Args:
            label_energies (bool): Whether to pass the plot through
                ``self.get_labels`` to annotate energy quantities.
            plt (plt): Existing plot object to draw on. A new one is created
                with ``pretty_plot(width=10, height=8)`` when falsy.
            label_fontsize (float): Font size forwarded to
                ``self.get_labels``.

        Returns:
            The plot object with the local potential curves.
        """
        if not plt:
            plt = pretty_plot(width=10, height=8)

        # Raw local potential signal.
        plt.plot(self.along_c, self.locpot_along_c, 'b--')

        # Flattened signal: (position, value) pairs, sorted by position
        # before plotting.
        flattened = []
        for idx, potential in enumerate(self.locpot_along_c):
            below_bulk = potential < self.ave_bulk_p
            if below_bulk or (
                    self.sorted_sites[-1].frac_coords[self.direction]
                    >= self.along_c[idx]
                    >= self.sorted_sites[0].frac_coords[self.direction]):
                level = self.ave_bulk_p
            else:
                level = potential
            flattened.append((self.along_c[idx], level))
        xg, yg = zip(*sorted(flattened))
        plt.plot(xg, yg, 'r', linewidth=2.5, zorder=-1)

        # Cosmetics.
        if label_energies:
            plt = self.get_labels(plt, label_fontsize=label_fontsize)
        plt.xlim([0, 1])
        y_floor = min(self.locpot_along_c)
        y_ceiling = self.vacuum_locpot + self.ave_locpot * 0.2
        plt.ylim([y_floor, y_ceiling])

        # One x label per lattice direction; no label for any other value.
        axis_titles = (
            r"Fractional coordinates ($\hat{a}$)",
            r"Fractional coordinates ($\hat{b}$)",
            r"Fractional coordinates ($\hat{c}$)",
        )
        for axis_index, title in enumerate(axis_titles):
            if self.direction == axis_index:
                plt.xlabel(title, fontsize=25)
                break
        plt.xticks(fontsize=15, rotation=45)
        plt.ylabel(r"Potential (eV)", fontsize=25)
        plt.yticks(fontsize=15)

        return plt
    def get_1d_plot(self,
                    mode: str = "distinct",
                    times: List = None,
                    colors: List = None):
        """
        Plot the van Hove function against r at one or more time moments.

        Args:
            mode (str): Which part of the van Hove function to plot, either
                "distinct" (uses ``self.gdrt``) or "self" (uses
                ``self.gsrt``).
            times (list of float): Time moments (in ps) at which the van Hove
                function will be plotted. Each is rounded to the nearest
                multiple of ``self.timeskip`` and capped at the last
                available row; the legend shows the time actually used.
                Defaults to [0.0].
            colors (list strings/tuples): One colour per entry in ``times``.
                If not set, seaborn.color_palette("Set1", 10) will be used.

        Returns:
            The plot object obtained from ``pretty_plot(12, 8)``.

        Raises:
            AssertionError: If ``mode`` is not recognised or more times than
                colours are supplied.
        """
        # None sentinel instead of a shared mutable default list.
        if times is None:
            times = [0.0]

        if colors is None:
            import seaborn as sns

            colors = sns.color_palette("Set1", 10)

        assert mode in ["distinct", "self"], \
            "mode must be 'distinct' or 'self'"
        assert len(times) <= len(colors), \
            "need at least one color per requested time"

        if mode == "distinct":
            grt = self.gdrt.copy()
            ylabel = "$G_d$($t$,$r$)"
            ylim = [-0.005, 4.0]
        elif mode == "self":
            grt = self.gsrt.copy()
            # Raw string: "\p" is not a valid escape sequence in a normal
            # string literal. The resulting text is unchanged.
            ylabel = r"4$\pi r^2G_s$($t$,$r$)"
            ylim = [-0.005, 1.0]

        plt = pretty_plot(12, 8)

        last_row = np.shape(grt)[0] - 1
        for i, time in enumerate(times):
            index = min(int(np.round(time / self.timeskip)), last_row)
            label = str(index * self.timeskip) + " ps"
            plt.plot(self.interval,
                     grt[index],
                     color=colors[i],
                     label=label,
                     linewidth=4.0)

        plt.xlabel(r"$r$ ($\AA$)")
        plt.ylabel(ylabel)
        plt.legend(loc="upper right", fontsize=36)
        plt.xlim(0.0, self.interval[-1] - 1.0)
        plt.ylim(ylim[0], ylim[1])
        plt.tight_layout()

        return plt
Exemple #23
0
    def get_msd_plot(self, plt=None, mode="specie"):
        """
        Plot the stored mean square displacement data against time. Useful
        for checking convergence.

        Args:
            plt: A plot object. Defaults to None, which means one will be
                generated.
            mode (str): Selects what is drawn:
                "species" - one curve per species, averaged over its sites;
                "sites" - one curve per site;
                "mscd" - ``self.mscd`` only;
                any other value (including the default "specie") - the
                overall ``self.msd`` plus its three components, labelled
                "a", "b" and "c".

        Returns:
            The plot object. The time axis is in ps when the largest value
            of ``self.dt`` exceeds 100000, otherwise in fs.
        """
        from pymatgen.util.plotting import pretty_plot
        plt = pretty_plot(12, 8, plt=plt)

        in_ps = np.max(self.dt) > 100000
        if in_ps:
            plot_dt, unit = self.dt / 1000, 'ps'
        else:
            plot_dt, unit = self.dt, 'fs'

        legend_opts = dict(loc=2, prop={"size": 20})

        if mode == "species":
            for sp in sorted(self.structure.composition.keys()):
                member_rows = [idx for idx, site in enumerate(self.structure)
                               if site.specie == sp]
                averaged = np.average(self.sq_disp_ions[member_rows, :],
                                      axis=0)
                plt.plot(plot_dt, averaged, label=str(sp))
            plt.legend(**legend_opts)
        elif mode == "sites":
            for idx, site in enumerate(self.structure):
                site_label = "%s - %d" % (str(site.specie), idx)
                plt.plot(plot_dt, self.sq_disp_ions[idx, :],
                         label=site_label)
            plt.legend(**legend_opts)
        elif mode == "mscd":
            plt.plot(plot_dt, self.mscd, 'r')
            plt.legend(["Overall"], **legend_opts)
        else:
            # Default / unrecognised mode: overall curve plus components.
            plt.plot(plot_dt, self.msd, 'k')
            for column, colour in enumerate('rgb'):
                plt.plot(plot_dt, self.msd_components[:, column], colour)
            plt.legend(["Overall", "a", "b", "c"], **legend_opts)

        plt.xlabel("Timestep ({})".format(unit))
        plt.ylabel("MSCD ($\\AA^2$)" if mode == "mscd" else "MSD ($\\AA^2$)")
        plt.tight_layout()
        return plt
Exemple #24
0
    def get_pourbaix_plot(self, limits=None, title="",
                          label_domains=True, plt=None):
        """
        Draw the Pourbaix diagram held by this plotter.

        Four guide lines are drawn first: two red dashed lines with slope
        ``-PREFAC`` (the second shifted up by 1.23), a vertical line at
        x = 7 and a horizontal line at y = 0. Then the outline of every
        entry in ``self._pd._stable_domain_vertices`` is drawn.

        Args:
            limits: 2D list containing limits of the Pourbaix diagram
                of the form [[xlo, xhi], [ylo, yhi]]. Defaults to
                [[-2, 16], [-3, 3]].
            title (str): Title to display on plot.
            label_domains (bool): Whether to annotate each domain at the
                average of its vertices.
            plt (pyplot): Plot object to draw on; ``pretty_plot(16)`` is
                used when falsy.

        Returns:
            plt (pyplot) - plot object with the Pourbaix diagram.
        """
        if limits is None:
            limits = [[-2, 16], [-3, 3]]
        if not plt:
            plt = pretty_plot(16)

        xlim, ylim = limits[0], limits[1]

        def segment(start, end):
            # Two (x, y) points -> array whose rows are the xs and the ys.
            return np.transpose([start, end])

        h_line = segment([xlim[0], -xlim[0] * PREFAC],
                         [xlim[1], -xlim[1] * PREFAC])
        o_line = segment([xlim[0], -xlim[0] * PREFAC + 1.23],
                         [xlim[1], -xlim[1] * PREFAC + 1.23])
        neutral_line = segment([7, ylim[0]], [7, ylim[1]])
        V0_line = segment([xlim[0], 0], [xlim[1], 0])

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        lw = 3
        guides = (
            (h_line, "r--"),
            (o_line, "r--"),
            (neutral_line, "k-."),
            (V0_line, "k-."),
        )
        for (guide_x, guide_y), fmt in guides:
            plt.plot(guide_x, guide_y, fmt, linewidth=lw)

        for entry, vertices in self._pd._stable_domain_vertices.items():
            center = np.average(vertices, axis=0)
            # Repeat the first vertex so the outline closes.
            closed = np.vstack([vertices, vertices[0]])
            x, y = np.transpose(closed)
            plt.plot(x, y, 'k-', linewidth=lw)
            if label_domains:
                plt.annotate(generate_entry_label(entry), center, ha='center',
                             va='center', fontsize=20, color="b")

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
def main():
    """
    Command-line entry point.

    Reads two positional arguments (an xmu file and a feff.inp file), loads
    them with ``Xmu.from_file`` and shows a plot of the "scross" and
    "across" series against "energies" from ``Xmu.as_dict()``.
    """
    description = """
    Convenient DOS Plotter for Feff runs.
    Author: Alan Dozier
    Version: 1.0
    Last updated: April, 2013"""
    parser = argparse.ArgumentParser(description=description)

    positionals = (
        ("filename", "xmu file to plot"),
        ("filename1", "feff.inp filename to import"),
    )
    for arg_name, help_text in positionals:
        parser.add_argument(arg_name, metavar=arg_name, type=str, nargs=1,
                            help=help_text)

    plt = pretty_plot(12, 8)
    color_order = ["r", "b", "g", "c", "k", "m", "y"]

    args = parser.parse_args()
    xmu = Xmu.from_file(args.filename[0], args.filename1[0])

    data = xmu.as_dict()

    plt.title(data["calc"] + " Feff9.6 Calculation for " + data["atom"] +
              " in " + data["formula"] + " unit cell")
    plt.xlabel("Energies (eV)")
    plt.ylabel("Absorption Cross-section")

    x = data["energies"]
    atom, edge = data["atom"], data["edge"]
    # (data key, legend label, colour); indices 1 and 2 of color_order.
    curves = (
        ("scross", "Single " + atom + " " + edge + " edge", color_order[1]),
        ("across", atom + " " + edge + " edge in " + data["formula"],
         color_order[2]),
    )
    for key, legend_label, colour in curves:
        plt.plot(x, data[key], colour, label=legend_label)

    plt.legend()
    legend_texts = plt.gca().get_legend().get_texts()
    plt.setp(legend_texts, fontsize=15)
    plt.tight_layout()
    plt.show()
Exemple #26
0
    def get_plot(self, ylim=None):
        """
        Get a matplotlib object for the bandstructure plot.

        Every band of every branch returned by ``self.bs_plot_data()`` is
        drawn as a blue line, with a horizontal line at zero frequency and
        ticks placed by ``self._maketicks``.

        Args:
            ylim: Specify the y-axis (frequency) limits; by default None let
                the code choose.

        Returns:
            The plot object obtained from ``pretty_plot(12, 8)``.
        """
        plt = pretty_plot(12, 8)
        from matplotlib import rc
        try:
            rc('text', usetex=True)
        except Exception:
            # Fall back on non Tex if errored. Catching Exception rather
            # than using a bare except keeps KeyboardInterrupt and
            # SystemExit propagating.
            rc('text', usetex=False)

        band_linewidth = 1

        data = self.bs_plot_data()
        for d, branch_distances in enumerate(data['distances']):
            branch_frequencies = data['frequency'][d]
            n_points = len(branch_distances)
            for i in range(self._nb_bands):
                band = [branch_frequencies[i][j] for j in range(n_points)]
                plt.plot(branch_distances, band, 'b-',
                         linewidth=band_linewidth)

        self._maketicks(plt)

        # plot y=0 line
        plt.axhline(0, linewidth=1, color='k')

        # Main X and Y Labels
        plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
        ylabel = r'$\mathrm{Frequency\ (THz)}$'
        plt.ylabel(ylabel, fontsize=30)

        # X range: from 0 to the last distance point
        x_max = data['distances'][-1][-1]
        plt.xlim(0, x_max)

        if ylim is not None:
            plt.ylim(ylim)

        plt.tight_layout()

        return plt
Exemple #27
0
    def plot(self, width=8, height=None, plt=None, dpi=None, **kwargs):
        """
        Plot the equation of state: input data as markers, the fit as a
        dashed line, and a text box summarising the fit parameters.

        Args:
            width (float): Width of plot, forwarded to ``pretty_plot``.
            height (float): Height of plot, forwarded to ``pretty_plot``.
            plt (matplotlib.pyplot): If plt is supplied, changes will be made
                to an existing plot. Otherwise, a new plot will be created.
            dpi: Forwarded to ``pretty_plot``.
            kwargs (dict): Optional overrides. The keys read here are
                "color" (default "r"), "label" (default "<class name> fit")
                and "text" (default: the generated parameter summary).

        Returns:
            Matplotlib plot object.
        """
        plt = pretty_plot(width=width, height=height, plt=plt, dpi=dpi)

        eos_name = self.__class__.__name__
        color = kwargs.get("color", "r")
        label = kwargs.get("label", "{} fit".format(eos_name))

        # The summary is always built, even if "text" overrides it.
        summary = "\n".join([
            "Equation of State: %s" % eos_name,
            "Minimum energy = %1.2f eV" % self.e0,
            "Minimum or reference volume = %1.2f Ang^3" % self.v0,
            "Bulk modulus = %1.2f eV/Ang^3 = %1.2f GPa" %
            (self.b0, self.b0_GPa),
            "Derivative of bulk modulus wrt pressure = %1.2f" % self.b1,
        ])
        text = kwargs.get("text", summary)

        # Input data points.
        plt.plot(self.volumes, self.energies, linestyle="None", marker="o",
                 color=color)

        # Fitted curve, sampled over the data range widened by 1 % of the
        # magnitude of each end.
        v_lo, v_hi = min(self.volumes), max(self.volumes)
        vfit = np.linspace(v_lo - 0.01 * abs(v_lo),
                           v_hi + 0.01 * abs(v_hi), 100)
        plt.plot(vfit, self.func(vfit), linestyle="dashed", color=color,
                 label=label)

        plt.grid(True)
        plt.xlabel("Volume $\\AA^3$")
        plt.ylabel("Energy (eV)")
        plt.legend(loc="best", shadow=True)
        # Fit parameters, positioned in axes coordinates.
        plt.text(0.4, 0.5, text, transform=plt.gca().transAxes)

        return plt
Exemple #28
0
    def get_plot(self, ylim=None, units="thz"):
        """
        Build the band structure plot with frequencies converted to the
        requested units.

        Args:
            ylim: Specify the y-axis (frequency) limits; by default None let
                the code choose.
            units: units for the frequencies, passed to ``freq_units``.
                Accepted values thz, ev, mev, ha, cm-1, cm^-1.

        Returns:
            The plot object obtained from ``pretty_plot(12, 8)``.
        """
        u = freq_units(units)
        plt = pretty_plot(12, 8)

        data = self.bs_plot_data()
        for seg_idx, seg_x in enumerate(data["distances"]):
            n_points = len(seg_x)
            for band in range(self._nb_bands):
                seg_y = [
                    data["frequency"][seg_idx][band][k] * u.factor
                    for k in range(n_points)
                ]
                plt.plot(seg_x, seg_y, "b-", linewidth=1)

        self._maketicks(plt)

        # Reference line at zero frequency.
        plt.axhline(0, linewidth=1, color="k")

        # Axis titles.
        plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
        plt.ylabel(r"$\mathrm{{Frequencies\ ({})}}$".format(u.label),
                   fontsize=30)

        # The x axis runs from 0 to the final distance value.
        plt.xlim(0, data["distances"][-1][-1])

        if ylim is not None:
            plt.ylim(ylim)

        plt.tight_layout()

        return plt
Exemple #29
0
    def get_plot(self, ylim=None):
        """
        Get a matplotlib object for the bandstructure plot.

        Each band of each branch from ``self.bs_plot_data()`` is plotted as
        a blue line; a horizontal line marks zero frequency and
        ``self._maketicks`` sets the ticks.

        Args:
            ylim: Specify the y-axis (frequency) limits; by default None let
                the code choose.

        Returns:
            The plot object obtained from ``pretty_plot(12, 8)``.
        """
        plt = pretty_plot(12, 8)
        from matplotlib import rc
        try:
            rc('text', usetex=True)
        except Exception:
            # Fall back on non Tex if errored. A bare except would also
            # swallow KeyboardInterrupt / SystemExit.
            rc('text', usetex=False)

        band_linewidth = 1

        data = self.bs_plot_data()
        for d, distances in enumerate(data['distances']):
            frequencies = data['frequency'][d]
            point_range = range(len(distances))
            for i in range(self._nb_bands):
                plt.plot(distances,
                         [frequencies[i][j] for j in point_range], 'b-',
                         linewidth=band_linewidth)

        self._maketicks(plt)

        # plot y=0 line
        plt.axhline(0, linewidth=1, color='k')

        # Main X and Y Labels
        plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
        ylabel = r'$\mathrm{Frequency\ (THz)}$'
        plt.ylabel(ylabel, fontsize=30)

        # X range: from 0 to the last distance point
        x_max = data['distances'][-1][-1]
        plt.xlim(0, x_max)

        if ylim is not None:
            plt.ylim(ylim)

        plt.tight_layout()

        return plt
Exemple #30
0
    def get_plot(self, structure, two_theta_range=(0, 90),
                 annotate_peaks=True, ax=None, with_labels=True,
                 fontsize=16):
        """
        Returns the diffraction plot as a matplotlib.pyplot.

        Args:
            structure: Input structure
            two_theta_range ([float of length 2]): Tuple for range of
                two_thetas to calculate in degrees. Defaults to (0, 90). Set to
                None if you want all diffracted beams within the limiting
                sphere of radius 2 / wavelength; in that case every peak
                returned by ``self.get_pattern`` is drawn.
            annotate_peaks: Whether to annotate the peaks with plane
                information.
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            with_labels: True to add xlabels and ylabels to the plot.
            fontsize: (int) fontsize for peak labels.

        Returns:
            (matplotlib.pyplot)
        """
        if ax is None:
            from pymatgen.util.plotting import pretty_plot
            plt = pretty_plot(16, 10)
            ax = plt.gca()
        else:
            # This to maintain the type of the return value.
            import matplotlib.pyplot as plt

        xrd = self.get_pattern(structure, two_theta_range=two_theta_range)

        for two_theta, i, hkls, _d_hkl in zip(xrd.x, xrd.y, xrd.hkls,
                                              xrd.d_hkls):
            # The documented two_theta_range=None means "no restriction";
            # indexing None here used to raise TypeError.
            if two_theta_range is not None and not (
                    two_theta_range[0] <= two_theta <= two_theta_range[1]):
                continue
            label = ", ".join(str(hkl["hkl"]) for hkl in hkls)
            ax.plot([two_theta, two_theta], [0, i], color='k',
                    linewidth=3, label=label)
            if annotate_peaks:
                ax.annotate(label, xy=[two_theta, i],
                            xytext=[two_theta, i], fontsize=fontsize)

        if with_labels:
            ax.set_xlabel(r"$2\theta$ ($^\circ$)")
            ax.set_ylabel("Intensities (scaled)")

        # NOTE(review): this only fires for ax objects that expose a
        # tight_layout attribute; confirm whether plt.tight_layout() was
        # intended instead.
        if hasattr(ax, "tight_layout"):
            ax.tight_layout()

        return plt
def main():
    """
    Script entry point.

    Parses two positional file names, loads them through ``Xmu.from_file``
    and displays the 'scross' and 'across' series of ``Xmu.to_dict``
    plotted against 'energies'.
    """
    parser = argparse.ArgumentParser(description='''
    Convenient DOS Plotter for Feff runs.
    Author: Alan Dozier
    Version: 1.0
    Last updated: April, 2013''')

    # Both arguments are single-valued strings.
    shared_options = dict(type=str, nargs=1)
    parser.add_argument('filename', metavar='filename',
                        help='xmu file to plot', **shared_options)
    parser.add_argument('filename1', metavar='filename1',
                        help='feff.inp filename to import', **shared_options)

    plt = pretty_plot(12, 8)
    color_order = ['r', 'b', 'g', 'c', 'k', 'm', 'y']

    args = parser.parse_args()
    xmu_path, feff_path = args.filename[0], args.filename1[0]
    xmu = Xmu.from_file(xmu_path, feff_path)

    data = xmu.to_dict

    title_parts = [data['calc'], ' Feff9.6 Calculation for ', data['atom'],
                   ' in ', data['formula'], ' unit cell']
    plt.title(''.join(title_parts))
    plt.xlabel('Energies (eV)')
    plt.ylabel('Absorption Cross-section')

    energies = data['energies']

    single_label = ''.join(['Single ', data['atom'], ' ', data['edge'],
                            ' edge'])
    plt.plot(energies, data['scross'], color_order[1], label=single_label)

    total_label = ''.join([data['atom'], ' ', data['edge'], ' edge in ',
                           data['formula']])
    plt.plot(energies, data['across'], color_order[2], label=total_label)

    plt.legend()
    legend = plt.gca().get_legend()
    # Resize every text entry of the legend.
    plt.setp(legend.get_texts(), fontsize=15)
    plt.tight_layout()
    plt.show()
Exemple #32
0
    def get_plot_gs(self, ylim=None):
        """
        Build the Gruneisen band structure plot.

        The "gruneisen" values of every band in every branch returned by
        ``self.bs_plot_data()`` are drawn as blue lines with small circular
        markers.

        Args:
            ylim: Specify the y-axis (gruneisen) limits; by default None let
                the code choose.

        Returns:
            The plot object obtained from ``pretty_plot(12, 8)``.
        """
        plt = pretty_plot(12, 8)

        line_style = dict(marker="o", markersize=2, linewidth=2)

        data = self.bs_plot_data()
        for seg_idx, seg_x in enumerate(data["distances"]):
            point_ids = range(len(seg_x))
            for band in range(self._nb_bands):
                seg_y = [data["gruneisen"][seg_idx][band][p]
                         for p in point_ids]
                plt.plot(seg_x, seg_y, "b-", **line_style)

        self._maketicks(plt)

        # Reference line at zero.
        plt.axhline(0, linewidth=1, color="k")

        # Axis titles.
        plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
        plt.ylabel(r"$\mathrm{Grüneisen\ Parameter}$", fontsize=30)

        # The x axis runs from 0 to the final distance value.
        plt.xlim(0, data["distances"][-1][-1])

        if ylim is not None:
            plt.ylim(ylim)

        plt.tight_layout()

        return plt
Exemple #33
0
    def get_plot(self, xlim=None, ylim=None):
        """
        Plot every spectrum stored in ``self._spectra``.

        The i-th spectrum is shifted upward by ``self.yshift * i``. When
        ``self.stack`` is true each spectrum is drawn as a filled area
        starting from a running baseline; otherwise as a line.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.

        Returns:
            The plot object obtained from ``pretty_plot(12, 8)``.
        """
        plt = pretty_plot(12, 8)
        base = 0.0
        for i, (key, sp) in enumerate(self._spectra.items()):
            shifted = sp.y + self.yshift * i
            if self.stack:
                plt.fill_between(sp.x, base, shifted, color=self.colors[i],
                                 label=str(key), linewidth=3)
                # The baseline grows by the unshifted intensities.
                base = sp.y + base
            else:
                plt.plot(sp.x, shifted, color=self.colors[i],
                         label=str(key), linewidth=3)
            # Axis labels come from the spectrum; the last one plotted wins.
            plt.xlabel(sp.XLABEL)
            plt.ylabel(sp.YLABEL)

        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)

        plt.legend()
        legend_texts = plt.gca().get_legend().get_texts()
        plt.setp(legend_texts, fontsize=30)
        plt.tight_layout()
        return plt
Exemple #34
0
    def get_plot(self, ylim=None, units="thz"):
        """
        Build the band structure plot, scaling frequencies by the factor of
        the unit object returned by ``freq_units(units)``.

        Args:
            ylim: Specify the y-axis (frequency) limits; by default None let
                the code choose.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.

        Returns:
            The plot object obtained from ``pretty_plot(12, 8)``.
        """
        u = freq_units(units)

        plt = pretty_plot(12, 8)

        data = self.bs_plot_data()

        def band_curves():
            # Lazily yield (x values, scaled y values) for each band of
            # each branch, in branch-major order.
            for branch in range(len(data['distances'])):
                xs = data['distances'][branch]
                for band in range(self._nb_bands):
                    ys = [data['frequency'][branch][band][k] * u.factor
                          for k in range(len(xs))]
                    yield xs, ys

        for xs, ys in band_curves():
            plt.plot(xs, ys, 'b-', linewidth=1)

        self._maketicks(plt)

        # Horizontal reference at zero frequency.
        plt.axhline(0, linewidth=1, color='k')

        # Axis titles; the y title carries the unit label.
        plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
        plt.ylabel(r'$\mathrm{{Frequencies\ ({})}}$'.format(u.label),
                   fontsize=30)

        # x spans from 0 to the last distance of the last branch.
        last_distance = data['distances'][-1][-1]
        plt.xlim(0, last_distance)

        if ylim is not None:
            plt.ylim(ylim)

        plt.tight_layout()

        return plt
def get_arrhenius_plot(temps, diffusivities, diffusivity_errors=None,
                       **kwargs):
    """
    Returns an Arrhenius plot.

    The data points and the fit from ``fit_arrhenius`` are plotted against
    1000/T on a logarithmic y axis, and the fitted activation energy is
    written on the plot in meV.

    Args:
        temps ([float]): A sequence of temperatures.
        diffusivities ([float]): A sequence of diffusivities (e.g.,
            from DiffusionAnalyzer.diffusivity).
        diffusivity_errors ([float]): A sequence of errors for the
            diffusivities. If None, no error bar is plotted. If shorter
            than ``temps``, error bars are drawn for the leading points only.
        \\*\\*kwargs:
            Any keyword args supported by matplotlib.pyplot.plot.

    Returns:
        A matplotlib.pyplot object. Do plt.show() to show the plot.
    """
    Ea, c, _ = fit_arrhenius(temps, diffusivities)

    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8)

    temps_arr = np.array(temps)

    # Arrhenius fit evaluated at the supplied temperatures.
    # NOTE(review): const.k / const.e is presumably the Boltzmann constant
    # in eV/K (const looks like scipy.constants) - confirm.
    arr = c * np.exp(-Ea / (const.k / const.e * temps_arr))

    t_1 = 1000 / temps_arr

    plt.plot(t_1, diffusivities, 'ko', t_1, arr, 'k--', markersize=10,
             **kwargs)
    if diffusivity_errors is not None:
        n = len(diffusivity_errors)
        plt.errorbar(t_1[0:n], diffusivities[0:n], yerr=diffusivity_errors,
                     fmt='ko', ecolor='k', capthick=2, linewidth=2)

    # Use the Axes that already holds the data. plt.axes() adds a new Axes
    # on current matplotlib, so the log scale and the annotation would land
    # on an empty overlay instead of the plotted curve.
    ax = plt.gca()
    ax.set_yscale('log')
    plt.text(0.6, 0.85, "E$_a$ = {:.0f} meV".format(Ea * 1000),
             fontsize=30, transform=ax.transAxes)
    plt.ylabel("D (cm$^2$/s)")
    plt.xlabel("1000/T (K$^{-1}$)")
    plt.tight_layout()
    return plt
def main():
    """
    Script entry point.

    Takes an xmu file and a feff.inp file on the command line, loads them
    with ``Xmu.from_file`` and shows the 'scross' and 'across' series of
    ``Xmu.to_dict`` plotted against 'energies'.
    """
    parser = argparse.ArgumentParser(description='''
    Convenient DOS Plotter for Feff runs.
    Author: Alan Dozier
    Version: 1.0
    Last updated: April, 2013''')

    for arg_name, help_text in (
            ('filename', 'xmu file to plot'),
            ('filename1', 'feff.inp filename to import')):
        parser.add_argument(arg_name, metavar=arg_name, type=str, nargs=1,
                            help=help_text)

    plt = pretty_plot(12, 8)
    color_order = ['r', 'b', 'g', 'c', 'k', 'm', 'y']

    args = parser.parse_args()
    xmu = Xmu.from_file(args.filename[0], args.filename1[0])

    data = xmu.to_dict

    plt.title(data['calc'] + ' Feff9.6 Calculation for ' + data['atom'] + ' in ' +
              data['formula'] + ' unit cell')
    plt.xlabel('Energies (eV)')
    plt.ylabel('Absorption Cross-section')

    x = data['energies']

    def draw(series_key, legend_label, color_index):
        # Plot one series from the data dict against the shared energy axis.
        plt.plot(x, data[series_key], color_order[color_index],
                 label=legend_label)

    draw('scross',
         'Single ' + data['atom'] + ' ' + data['edge'] + ' edge',
         1)
    draw('across',
         data['atom'] + ' ' + data['edge'] + ' edge in ' + data['formula'],
         2)

    plt.legend()
    current_legend = plt.gca().get_legend()
    # Apply one font size to every legend text entry.
    plt.setp(current_legend.get_texts(), fontsize=15)
    plt.tight_layout()
    plt.show()
Exemple #37
0
    def get_xrd_plot(self,
                     structure,
                     two_theta_range=(0, 90),
                     annotate_peaks=True):
        """
        Returns the XRD plot as a matplotlib.pyplot.

        Args:
            structure: Input structure
            two_theta_range ([float of length 2]): Tuple for range of
                two_thetas to calculate in degrees. Defaults to (0, 90). Set to
                None if you want all diffracted beams within the limiting
                sphere of radius 2 / wavelength; in that case every peak
                returned by ``self.get_xrd_data`` is drawn.
            annotate_peaks: Whether to annotate the peaks with plane
                information.

        Returns:
            (matplotlib.pyplot)
        """
        from pymatgen.util.plotting import pretty_plot
        plt = pretty_plot(16, 10)
        for two_theta, i, hkls, _d_hkl in self.get_xrd_data(
                structure, two_theta_range=two_theta_range):
            # The documented two_theta_range=None means "no restriction";
            # indexing None here used to raise TypeError.
            if two_theta_range is not None and not (
                    two_theta_range[0] <= two_theta <= two_theta_range[1]):
                continue
            label = ", ".join(str(hkl) for hkl in hkls.keys())
            # Each peak is a vertical stick from 0 up to its intensity.
            plt.plot([two_theta, two_theta], [0, i],
                     color='k',
                     linewidth=3,
                     label=label)
            if annotate_peaks:
                plt.annotate(label,
                             xy=[two_theta, i],
                             xytext=[two_theta, i],
                             fontsize=16)
        plt.xlabel(r"$2\theta$ ($^\circ$)")
        plt.ylabel("Intensities (scaled)")
        plt.tight_layout()

        return plt
Exemple #38
0
def get_chgint_plot(args):
    """
    Plot the integrated charge difference around selected atoms of a CHGCAR.

    Args:
        args: Object exposing ``chgcar_file``, ``inds`` and ``radius``.
            When ``args.inds`` is truthy its first element is split on
            commas and parsed as integer atom indices. Otherwise the first
            site of each symmetry-equivalent group (found with
            SpacegroupAnalyzer, symprec=0.1) is used.

    Returns:
        The plot object produced by pretty_plot.
    """
    chgcar = Chgcar.from_file(args.chgcar_file)
    structure = chgcar.structure

    if not args.inds:
        analyzer = SpacegroupAnalyzer(structure, symprec=0.1)
        groups = analyzer.get_symmetrized_structure().equivalent_sites
        # One representative (the first site) per equivalent group.
        representatives = [group[0] for group in groups]
        atom_ind = [structure.sites.index(rep) for rep in representatives]
    else:
        atom_ind = [int(token) for token in args.inds[0].split(",")]

    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8)
    for idx in atom_ind:
        # Columns 0 and 1 of the returned array are plotted as x and y.
        integrated = chgcar.get_integrated_diff(idx, args.radius, 30)
        legend_text = "Atom {} - {}".format(idx, structure[idx].species_string)
        plt.plot(integrated[:, 0], integrated[:, 1], label=legend_text)
    plt.legend(loc="upper left")
    plt.xlabel("Radius (A)")
    plt.ylabel("Integrated charge (e)")
    plt.tight_layout()
    return plt
Exemple #39
0
    def get_plot(self, xlim=None, ylim=None):
        """
        Draw every stored spectrum on a single matplotlib figure.

        Spectra are read from ``self._spectra`` in iteration order. The i-th
        spectrum is shifted vertically by ``self.yshift * i`` and coloured
        with ``self.colors[i]``. When ``self.stack`` is set, the region
        between a running baseline and the shifted curve is filled instead
        of drawing a line.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.

        Returns:
            The plot object produced by pretty_plot.
        """

        plt = pretty_plot(12, 8)
        baseline = 0.0
        for index, (name, spectrum) in enumerate(self._spectra.items()):
            shifted = spectrum.y + self.yshift * index
            if self.stack:
                plt.fill_between(spectrum.x, baseline, shifted,
                                 color=self.colors[index],
                                 label=str(name), linewidth=3)
                # The baseline accumulates the unshifted intensities.
                baseline = spectrum.y + baseline
            else:
                plt.plot(spectrum.x, shifted, color=self.colors[index],
                         label=str(name), linewidth=3)
            # Axis labels come from the spectrum itself; the last one wins.
            plt.xlabel(spectrum.XLABEL)
            plt.ylabel(spectrum.YLABEL)

        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)

        plt.legend()
        legend_texts = plt.gca().get_legend().get_texts()
        plt.setp(legend_texts, fontsize=30)
        plt.tight_layout()
        return plt
Exemple #40
0
    def get_plot(self, width=8, height=8):
        """
        Plot one voltage curve per registered electrode.

        Args:
            width: Width of the plot in inches. Defaults to 8.
            height: Height of the plot in inches. Defaults to 8.

        Returns:
            A matplotlib plot object.
        """
        plt = pretty_plot(width, height)
        for name in self._electrodes:
            xs, ys = self.get_plot_data(self._electrodes[name])
            plt.plot(xs, ys, '-', linewidth=2, label=name)

        plt.legend()
        # The x axis shows either capacity or fraction, per self.xaxis.
        x_caption = 'Capacity (mAh/g)' if self.xaxis == "capacity" else 'Fraction'
        plt.xlabel(x_caption)
        plt.ylabel('Voltage (V)')
        plt.tight_layout()
        return plt
Exemple #41
0
    def get_plot(self, width=8, height=8):
        """
        Build a figure with a voltage profile for each stored electrode.

        Args:
            width: Figure width in inches. Defaults to 8.
            height: Figure height in inches. Defaults to 8.

        Returns:
            A matplotlib plot object.
        """
        plt = pretty_plot(width, height)
        for legend_name, electrode in self._electrodes.items():
            curve = self.get_plot_data(electrode)
            x_values, y_values = curve
            plt.plot(x_values, y_values, '-', linewidth=2, label=legend_name)

        plt.legend()
        if self.xaxis != "capacity":
            plt.xlabel('Fraction')
        else:
            plt.xlabel('Capacity (mAh/g)')
        plt.ylabel('Voltage (V)')
        plt.tight_layout()
        return plt
Exemple #42
0
    def get_pourbaix_mark_passive(self, limits=None, title="", label_domains=True, passive_entry=None):
        """
        Plot the Pourbaix diagram with passive (solid) domains highlighted.

        Stable domains are split into three groups:

        * entries whose name contains ``passive_entry + "(s)"`` are filled
          with '#FF0000';
        * other entries with "(s)" in their name that contain O or H are
          filled with '#0000FF';
        * all remaining domains are drawn unfilled.

        Args:
            limits: 2D list of the form [[xlo, xhi], [ylo, yhi]]. When falsy,
                ``self._analyzer.chempot_limits`` is used.
            title (str): Title of the plot.
            label_domains (bool): Whether to annotate each domain with its
                name.
            passive_entry (str): Name prefix identifying the passive solid.
                When None, the key with the largest value in
                ``self._pd._elt_comp`` is used (first one on ties).

        Returns:
            The plot object produced by pretty_plot.

        Raises:
            ValueError: If ``passive_entry`` is None and
                ``self._pd._elt_comp`` is empty.
        """
        from matplotlib.patches import Polygon
        from pymatgen import Element
        from itertools import chain

        plt = pretty_plot(16)
        optim_colors = ['#0000FF', '#FF0000', '#00FF00', '#FFFF00', '#FF00FF',
                        '#FF8080', '#DCDCDC', '#800000', '#FF8000']
        optim_font_colors = ['#FFC000', '#00FFFF', '#FF00FF', '#0000FF', '#00FF00',
                            '#007F7F', '#232323', '#7FFFFF', '#007FFF']
        (stable, unstable) = self.pourbaix_plot_data(limits)
        mark_passive = {key: 0 for key in stable.keys()}

        # Honour an explicitly supplied passive_entry; only fall back to the
        # most abundant key when the caller did not provide one.
        if passive_entry is None:
            elt_comp = self._pd._elt_comp
            if not elt_comp:
                raise ValueError("passive_entry was not given and cannot be "
                                 "derived from an empty element composition.")
            # max() returns the first key holding the largest value.
            passive_entry = max(elt_comp, key=lambda k: elt_comp[k])

        def list_elts(entry):
            """Return the set of elements present in a (multi-)entry."""
            if isinstance(entry, MultiEntry):
                return set(chain.from_iterable(
                    e.composition.elements for e in entry.entrylist))
            return set(entry.composition.elements)

        water_elts = {Element("O"), Element("H")}
        passive_tag = passive_entry + "(s)"
        for entry in stable:
            if passive_tag in entry.name:
                mark_passive[entry] = 2
            elif "(s)" in entry.name and water_elts & list_elts(entry):
                mark_passive[entry] = 1

        if limits:
            xlim = limits[0]
            ylim = limits[1]
        else:
            xlim = self._analyzer.chempot_limits[0]
            ylim = self._analyzer.chempot_limits[1]

        # Reference lines: two sloped lines offset by 1.23, a vertical line
        # at x = 7 and a horizontal line at y = 0.
        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

        # mark value -> (face colour, label colour, filled?)
        domain_styles = {
            1: (optim_colors[0], optim_font_colors[0], True),
            2: (optim_colors[1], optim_font_colors[1], True),
        }
        unmarked_style = ("w", "k", False)

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        for e in stable.keys():
            xy = self.domain_vertices(e)
            c = self.get_center(stable[e])
            color, fontcolor, colorfill = domain_styles.get(
                mark_passive[e], unmarked_style)
            patch = Polygon(xy, facecolor=color, closed=True, lw=3.0, fill=colorfill)
            ax.add_patch(patch)
            if label_domains:
                plt.annotate(self.print_name(e), c, color=fontcolor, fontsize=20)

        lw = 3
        plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
        plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
Exemple #43
0
    def get_pourbaix_plot_colorfill_by_domain_name(self, limits=None, title="",
            label_domains=True, label_color='k', domain_color=None, domain_fontsize=None,
            domain_edge_lw=0.5, bold_domains=None, cluster_domains=(),
            add_h2o_stablity_line=True, add_center_line=False, h2o_lw=0.5,
            fill_domain=True, width=8, height=None, font_family='Times New Roman'):
        """
        Draw the Pourbaix diagram with each domain filled by a per-name colour.

        Args:
            limits: 2D list containing limits of the Pourbaix diagram
                of the form [[xlo, xhi], [ylo, yhi]]. When falsy,
                ``self._analyzer.chempot_limits`` is used.
            title (str): Title of the plot.
            label_domains (bool): Whether to add a text label to each domain.
            label_color (str): Colour of the domain labels. Defaults to black.
            domain_color (dict): Colour of each domain keyed by domain name,
                e.g. {"Al(s)": "#FF1100"}; names missing from the dict are
                drawn white. If None, names containing "(s)" get '#b8f9e7',
                names listed in ``cluster_domains`` get '#d0fbef' and the
                rest are white.
            domain_fontsize (int or dict): Font size of the domain labels,
                either one value for all domains or a dict keyed by domain
                name. If None, 12 is used.
            domain_edge_lw (float): Line width of the domain boundaries.
            bold_domains (list): Domain names whose label is rendered bold.
                If None, every name without "(s)" is bold. If False, no
                label is bold.
            cluster_domains (list): Domain names treated as cluster phases
                by the default colour scheme.
            add_h2o_stablity_line (bool): Whether to plot the two H2O
                stability lines.
            add_center_line (bool): Whether to plot the x = 7 and y = 0
                centre lines.
            h2o_lw (float): Line width of the H2O stability and centre lines.
            fill_domain (bool): If False, the domain polygons are not filled.
            width (float): Width forwarded to pretty_plot.
            height (float): Height forwarded to pretty_plot.
            font_family (str): Font family of the labels.

        Returns:
            The plot object produced by pretty_plot.
        """
        from matplotlib.patches import Polygon

        fallback_font_size = 12
        solid_fill = '#b8f9e7'    # slightly darker than the MP scheme so that
        cluster_fill = '#d0fbef'  # the cluster phase does not become too light

        plt = pretty_plot(width=width, height=height, dpi=300)

        (stable, unstable) = self.pourbaix_plot_data(limits)

        if limits:
            xlim, ylim = limits[:2]
        else:
            xlim, ylim = self._analyzer.chempot_limits[:2]

        # Reference lines, each stored as [xs, ys].
        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        ax.tick_params(direction='out')
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')

        # Highest-energy entries are drawn first.
        sorted_entry = sorted(stable.keys(), key=lambda en: en.energy, reverse=True)

        if not isinstance(domain_fontsize, dict):
            uniform_size = (fallback_font_size if domain_fontsize is None
                            else domain_fontsize)
            domain_fontsize = {en.name: uniform_size for en in sorted_entry}

        if domain_color is None:
            def default_fill(name):
                if '(s)' in name:
                    return solid_fill
                if name in cluster_domains:
                    return cluster_fill
                return 'w'

            domain_color = {en.name: default_fill(en.name) for en in sorted_entry}
        else:
            requested = domain_color
            domain_color = {en.name: requested[en.name] if en.name in requested else "w"
                            for en in sorted_entry}

        if bold_domains is None:
            bold_domains = [en.name for en in sorted_entry if '(s)' not in en.name]
        # Equality (not identity) comparison is kept on purpose.
        if bold_domains == False:
            bold_domains = []

        for entry in sorted_entry:
            xy = self.domain_vertices(entry)
            if add_h2o_stablity_line:
                c = self.get_distribution_corrected_center(stable[entry], h_line, o_line, 0.3)
            else:
                c = self.get_distribution_corrected_center(stable[entry])
            ax.add_patch(Polygon(xy, facecolor=domain_color[entry.name], edgecolor="black",
                                 closed=True, lw=domain_edge_lw, fill=fill_domain,
                                 antialiased=True))
            if not label_domains:
                continue

            is_bold = entry.name in bold_domains
            label_size = domain_fontsize[entry.name]
            if platform.system() == 'Darwin' and font_family == "Times New Roman":
                # On Mac OS X the font is loaded from a hard-coded file path.
                if is_bold:
                    font_file = '/Library/Fonts/Times New Roman Bold.ttf'
                else:
                    font_file = '/Library/Fonts/Times New Roman.ttf'
                font = FontProperties(fname=font_file, size=label_size)
            else:
                font = FontProperties(family=font_family,
                                      weight='bold' if is_bold else 'regular',
                                      size=label_size)
            plt.text(*c, s=self.print_name(entry),
                     fontproperties=font,
                     horizontalalignment="center", verticalalignment="center",
                     multialignment="center", color=label_color)

        if add_h2o_stablity_line:
            for water_line in (h_line, o_line):
                line, = plt.plot(water_line[0], water_line[1], "k--",
                                 linewidth=h2o_lw, antialiased=True)
                line.set_dashes((3, 1.5))
        if add_center_line:
            for center_line in (neutral_line, V0_line):
                plt.plot(center_line[0], center_line[1], "k-.",
                         linewidth=h2o_lw, antialiased=False)

        plt.xlabel("pH", fontname=font_family, fontsize=18)
        plt.ylabel("E (V vs SHE)", fontname=font_family, fontsize=18)
        plt.xticks(fontname=font_family, fontsize=16)
        plt.yticks(fontname=font_family, fontsize=16)
        plt.title(title, fontsize=20, fontweight='bold', fontname=font_family)
        return plt
Exemple #44
0
    def get_pourbaix_plot_colorfill_by_element(self, limits=None, title="",
                                                label_domains=True, element=None):
        """
        Plot the Pourbaix diagram with domains grouped and coloured by entry
        name.

        Stable domains are grouped by name: a MultiEntry domain is filed
        under the name of each of its sub-entries that contains ``element``,
        any other domain under its own name. Groups are ordered by the number
        of elements other than H and O in the name. A group with exactly one
        such element is drawn with a solid fill colour; any other group is
        drawn unfilled with a hatch pattern. One label per group is placed at
        the mean of its domain centres.

        Args:
            limits: 2D list of the form [[xlo, xhi], [ylo, yhi]]. When falsy,
                ``self._analyzer.chempot_limits`` is used.
            title (str): Title of the plot.
            label_domains (bool): Whether to annotate each group of domains.
            element: Element used to select which sub-entries of a
                MultiEntry the domain is grouped under.

        Returns:
            The plot object produced by pretty_plot.
        """
        from matplotlib.patches import Polygon
        from pymatgen import Composition, Element
        from pymatgen.core.ion import Ion

        entry_dict_of_multientries = collections.defaultdict(list)
        plt = pretty_plot(16)
        optim_colors = ['#0000FF', '#FF0000', '#00FF00', '#FFFF00', '#FF00FF',
                         '#FF8080', '#DCDCDC', '#800000', '#FF8000']
        optim_font_color = ['#FFFFA0', '#00FFFF', '#FF00FF', '#0000FF', '#00FF00',
                            '#007F7F', '#232323', '#7FFFFF', '#007FFF']
        hatch = ['/', '\\', '|', '-', '+', 'o', '*']
        (stable, unstable) = self.pourbaix_plot_data(limits)
        for entry in stable:
            if isinstance(entry, MultiEntry):
                for e in entry.entrylist:
                    if element in e.composition.elements:
                        entry_dict_of_multientries[e.name].append(entry)
            else:
                entry_dict_of_multientries[entry.name].append(entry)
        if limits:
            xlim = limits[0]
            ylim = limits[1]
        else:
            xlim = self._analyzer.chempot_limits[0]
            ylim = self._analyzer.chempot_limits[1]

        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        water_elts = [Element("H"), Element("O")]

        def len_elts(entry):
            """Count the elements other than H and O in an entry name."""
            if "(s)" in entry:
                # Strip the trailing "(s)" marker before parsing.
                comp = Composition(entry[:-3])
            else:
                comp = Ion.from_formula(entry)
            return len([el for el in comp.elements if el not in water_elts])

        # sorted() works on both Python 2 and 3; dict.keys() has no .sort()
        # on Python 3.
        sorted_entry = sorted(entry_dict_of_multientries, key=len_elts)
        for i, entry in enumerate(sorted_entry):
            single_element = len_elts(entry) == 1
            if single_element:
                # Solid fill; cycle through the colour table.
                color_indx = i % len(optim_colors)
                face = optim_colors[color_indx]
            else:
                # Hatched outline; cycle through the hatch table.
                color_indx = i % len(hatch)
                pattern = hatch[color_indx]

            x_coord = 0.0
            y_coord = 0.0
            npts = 0
            for e in entry_dict_of_multientries[entry]:
                xy = self.domain_vertices(e)
                c = self.get_center(stable[e])
                x_coord += c[0]
                y_coord += c[1]
                npts += 1
                if single_element:
                    patch = Polygon(xy, facecolor=face,
                                    closed=True, lw=3.0, fill=True)
                else:
                    patch = Polygon(xy, hatch=pattern, closed=True, lw=3.0, fill=False)
                ax.add_patch(patch)

            xy_center = (x_coord / npts, y_coord / npts)
            if label_domains:
                if single_element:
                    text_color = face
                    bbox = dict(boxstyle="round", fc=optim_font_color[color_indx])
                else:
                    # Hatched groups get black text on a white, hatched box.
                    text_color = 'k'
                    bbox = dict(boxstyle="round", hatch=pattern, fc='w')
                plt.annotate(latexify_ion(latexify(entry)), xy_center,
                             color=text_color, fontsize=30, bbox=bbox)

        lw = 3
        plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
        plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
Exemple #45
0
    def area_frac_vs_chempot_plot(self, cmap=cm.jet, at_intersections=False,
                                  increments=10):
        """
        Plot the fractional Wulff-shape area of each facet against chemical
        potential.

        Args:
            cmap (cm): A matplotlib colormap object, defaults to jet.
            at_intersections (bool): If True, chemical potentials are sampled
                between every pair of consecutive points in the sorted list
                made of ``self.chempot_range`` and the intersections returned
                by ``self.get_intersections`` for each facet. If False, they
                are sampled once between the min and max of
                ``self.chempot_range``.
            increments (int): Number of points generated by ``np.linspace``
                for each sampled interval. Defaults to 10.

        Returns:
            The plot object produced by pretty_plot.
        """

        # One colormap index per facet, spread evenly over 0..255.
        color_indices = [int(v) for v in
                         np.linspace(0, 255, len(self.vasprun_dict.keys()))]

        # Range endpoints plus the first coordinate of every intersection.
        boundaries = list(self.chempot_range)
        for hkl in self.vasprun_dict.keys():
            boundaries.extend([ints[0] for ints in self.get_intersections(hkl)])
        boundaries = sorted(boundaries)

        if not at_intersections:
            all_chempots = np.linspace(min(self.chempot_range),
                                       max(self.chempot_range), increments)
        else:
            all_chempots = []
            for lower, upper in zip(boundaries, boundaries[1:]):
                all_chempots.extend(np.linspace(lower, upper, increments))

        # Fractional areas collected per Miller index, one per chempot.
        hkl_area_dict = {hkl: [] for hkl in self.vasprun_dict.keys()}
        for chempot in all_chempots:
            fractions = self.wulff_shape_from_chempot(chempot).area_fraction_dict
            for hkl in fractions.keys():
                hkl_area_dict[hkl].append(fractions[hkl])

        plt = pretty_plot()
        for i, hkl in enumerate(self.vasprun_dict.keys()):
            # Facets with zero area at every sampled chempot are not drawn.
            if not all([a == 0 for a in hkl_area_dict[hkl]]):
                plt.plot(all_chempots, hkl_area_dict[hkl],
                         '--', color=cmap(color_indices[i]), label=str(hkl))

        # Axis cosmetics and a legend anchored outside the axes.
        plt.ylim([0, 1])
        plt.xlim(self.chempot_range)
        plt.ylabel(r"Fractional area $A^{Wulff}_{hkl}/A^{Wulff}$")
        plt.xlabel(r"Chemical potential $\Delta\mu_{%s}$ (eV)" % (self.ref_element))
        plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)

        return plt
Exemple #46
0
    def get_plot(self, xlim=None, ylim=None):
        """
        Build a matplotlib figure of every DOS stored in ``self._doses``.

        When ``self.stack`` is set the densities are accumulated and drawn
        as filled areas; otherwise each DOS is drawn as a line.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits. When not given, the y range is
                taken from the points lying strictly inside the x limits.

        Returns:
            The plot object produced by pretty_plot.
        """
        import prettyplotlib as ppl
        from prettyplotlib import brewer2mpl

        # Between 3 and 9 colours are requested from the 'Set1' map.
        ncolors = min(9, max(3, len(self._doses)))
        colors = brewer2mpl.get_map('Set1', 'qualitative', ncolors).mpl_colors

        running_total = None
        density_series = []
        frequency_series = []
        plt = pretty_plot(12, 8)

        # A cumulative sum is kept so that stacked curves can be drawn.
        for dos in self._doses.values():
            frequencies = dos['frequencies']
            densities = dos['densities']
            if running_total is None:
                running_total = np.zeros(frequencies.shape)
            if not self.stack:
                curve = densities
            else:
                running_total += densities
                curve = running_total.copy()
            frequency_series.append(frequencies)
            density_series.append(curve)

        # Draw in reverse insertion order.
        labels = list(self._doses.keys())[::-1]
        density_series = density_series[::-1]
        frequency_series = frequency_series[::-1]
        allpts = []
        triples = zip(labels, frequency_series, density_series)
        for i, (key, frequencies, densities) in enumerate(triples):
            allpts.extend(zip(frequencies, densities))
            shade = colors[i % ncolors]
            if not self.stack:
                ppl.plot(frequencies, densities, color=shade,
                         label=str(key), linewidth=3)
            else:
                plt.fill(frequencies, densities, color=shade,
                         label=str(key))

        if xlim:
            plt.xlim(xlim)
        if not ylim:
            xlim = plt.xlim()
            visible = [pt[1] for pt in allpts if xlim[0] < pt[0] < xlim[1]]
            plt.ylim((min(visible), max(visible)))
        else:
            plt.ylim(ylim)

        # Dashed vertical marker at zero frequency spanning the y range.
        plt.plot([0, 0], plt.ylim(), 'k--', linewidth=2)

        plt.xlabel('Frequencies (THz)')
        plt.ylabel('Density of states')

        plt.legend()
        legend_texts = plt.gca().get_legend().get_texts()
        plt.setp(legend_texts, fontsize=30)
        plt.tight_layout()
        return plt
Exemple #47
0
    def chempot_vs_gamma_plot(self, cmap=cm.jet, show_unstable_points=False):
        """
        Plot the surface energy of every slab calculation against chemical
        potential.

        One line is drawn per vasprun in ``self.vasprun_dict`` using
        ``self.calculate_gamma(vasprun)`` over ``self.chempot_range``. Slabs
        whose reduced composition differs from that of ``self.ucell_entry``
        are dashed and have the reduced composition appended to their label;
        the others are solid. Lines sharing a label share one colour and a
        single legend entry.

        Args:
            cmap (cm): A matplotlib colormap object, defaults to jet.
            show_unstable_points (bool): Not used by this method.

        Returns:
            The plot object produced by pretty_plot.
        """

        plt = pretty_plot()
        # One colormap index per vasprun, spread evenly over 0..255.
        total_runs = sum([len(vaspruns) for vaspruns in
                          self.vasprun_dict.values()])
        color_indices = [int(v) for v in np.linspace(0, 255, total_runs)]

        label_colors = {}
        runs = ((hkl, vasprun)
                for hkl in self.vasprun_dict.keys()
                for vasprun in self.vasprun_dict[hkl])
        for i, (hkl, vasprun) in enumerate(runs):
            slab = vasprun.final_structure
            label = str(hkl)
            # Dashed line and formula in the label for slabs that are not
            # stoichiometric with respect to the bulk.
            if slab.composition.reduced_composition == \
                    self.ucell_entry.composition.reduced_composition:
                mark = '-'
            else:
                mark = '--'
                label += " %s" % (slab.composition.reduced_composition)

            # NOTE: labelling by surface composition is not implemented here.

            if label in label_colors:
                # Reuse the colour and suppress the duplicate legend entry.
                c = label_colors[label]
                label = None
            else:
                c = cmap(color_indices[i])
                label_colors[label] = c

            se_range = self.calculate_gamma(vasprun)
            plt.plot(self.chempot_range, se_range, mark, color=c, label=label)

        # Freeze the automatically chosen y limits, then set the x limits.
        plt.ylim(plt.gca().get_ylim())
        plt.xlim(self.chempot_range)
        plt.ylabel(r"Surface energy (eV/$\AA$)")
        plt.xlabel(r"Chemical potential $\Delta\mu_{%s}$ (eV)" % (self.ref_element))
        plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)

        return plt
Exemple #48
0
    def get_plot(self, xlim=None, ylim=None, units="thz"):
        """
        Build a matplotlib figure of every DOS stored in ``self._doses``.

        Frequencies are multiplied by the factor of the requested unit.
        When ``self.stack`` is set the densities are accumulated and drawn
        as filled areas; otherwise each DOS is drawn as a line.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits. When not given, the y range is
                taken from the points lying strictly inside the x limits.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.

        Returns:
            The plot object produced by pretty_plot.
        """

        unit = freq_units(units)

        # The colour index cycles over between 3 and 9 entries.
        ncolors = min(9, max(3, len(self._doses)))

        import palettable

        colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors

        cumulative = None
        stacked_or_raw = []
        scaled_frequencies = []
        plt = pretty_plot(12, 8)

        # A cumulative sum is kept so that stacked curves can be drawn.
        for dos in self._doses.values():
            frequencies = dos['frequencies'] * unit.factor
            densities = dos['densities']
            if cumulative is None:
                cumulative = np.zeros(frequencies.shape)
            if self.stack:
                cumulative += densities
                shown = cumulative.copy()
            else:
                shown = densities
            scaled_frequencies.append(frequencies)
            stacked_or_raw.append(shown)

        # Draw in reverse insertion order.
        names = list(reversed(list(self._doses.keys())))
        stacked_or_raw.reverse()
        scaled_frequencies.reverse()
        allpts = []
        for i, key in enumerate(names[:min(len(scaled_frequencies), len(stacked_or_raw))]):
            frequencies = scaled_frequencies[i]
            densities = stacked_or_raw[i]
            allpts.extend(list(zip(frequencies, densities)))
            draw = plt.fill if self.stack else plt.plot
            extra = {} if self.stack else {"linewidth": 3}
            draw(frequencies, densities, color=colors[i % ncolors],
                 label=str(key), **extra)

        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)
        else:
            xlim = plt.xlim()
            inside = [y_val for x_val, y_val in allpts
                      if xlim[0] < x_val < xlim[1]]
            plt.ylim((min(inside), max(inside)))

        # Dashed vertical marker at zero frequency spanning the y range.
        y_span = plt.ylim()
        plt.plot([0, 0], y_span, 'k--', linewidth=2)

        plt.xlabel(r'$\mathrm{{Frequencies\ ({})}}$'.format(unit.label))
        plt.ylabel(r'$\mathrm{Density\ of\ states}$')

        plt.legend()
        legend_texts = plt.gca().get_legend().get_texts()
        plt.setp(legend_texts, fontsize=30)
        plt.tight_layout()
        return plt
Exemple #49
0
    def _get_2d_plot(self, label_stable=True, label_unstable=True,
                     ordering=None, energy_colormap=None, vmin_mev=-60.0,
                     vmax_mev=60.0, show_colorbar=True,
                     process_attributes=False):
        """
        Builds the 2D phase diagram plot.

        Plotting imports are done inside the method because matplotlib is
        fairly expensive to load and is not installed on every machine.

        Args:
            label_stable (bool): Whether to annotate the entries found in the
                ``labels`` part of ``self.pd_plot_data`` with their names.
            label_unstable (bool): Whether to annotate the unstable entries
                that are drawn (only relevant if ``self.show_unstable`` is
                truthy).
            ordering: If not None, the plot data is passed through
                ``order_phase_diagram`` together with this value.
            energy_colormap: None draws markers with fixed colors.
                ``'default'`` builds a green/red segmented colormap; any
                other value is handed to ``ScalarMappable`` as the colormap.
            vmin_mev (float): Lower limit of the color normalization; divided
                by 1000 before use.
            vmax_mev (float): Upper limit of the color normalization; divided
                by 1000 before use.
            show_colorbar (bool): Whether to draw a colorbar when a colormap
                is in use.
            process_attributes (bool): If True, ``entry.attribute`` selects
                the marker (None or "existing" -> circle, anything else ->
                star) and entries whose attribute is 'new' get green labels.

        Returns:
            The pyplot object returned by ``pretty_plot`` with the diagram
            drawn on it.
        """

        plt = pretty_plot(8, 6)
        from matplotlib.font_manager import FontProperties
        if ordering is None:
            (lines, labels, unstable) = self.pd_plot_data
        else:
            (_lines, _labels, _unstable) = self.pd_plot_data
            (lines, labels, unstable) = order_phase_diagram(
                _lines, _labels, _unstable, ordering)
        if energy_colormap is None:
            if process_attributes:
                for x, y in lines:
                    plt.plot(x, y, "k-", linewidth=3, markeredgecolor="k")
                # One should think about a clever way to have "complex"
                # attributes with complex processing options but with a clear
                # logic. At this moment, the attributes are only used to know
                # whether an entry is a new compound or an existing (from the
                # ICSD or from the MP) one.
                for (x, y), entry in labels.items():
                    if entry.attribute is None or \
                            entry.attribute == "existing":
                        plt.plot(x, y, "ko", linewidth=3, markeredgecolor="k",
                                 markerfacecolor="b", markersize=12)
                    else:
                        plt.plot(x, y, "k*", linewidth=3, markeredgecolor="k",
                                 markerfacecolor="g", markersize=18)
            else:
                for x, y in lines:
                    plt.plot(x, y, "ko-", linewidth=3, markeredgecolor="k",
                             markerfacecolor="b", markersize=15)
        else:
            from matplotlib.colors import Normalize, LinearSegmentedColormap
            from matplotlib.cm import ScalarMappable
            pda = PDAnalyzer(self._pd)
            for x, y in lines:
                plt.plot(x, y, "k-", linewidth=3, markeredgecolor="k")
            vmin = vmin_mev / 1000.0
            vmax = vmax_mev / 1000.0
            if energy_colormap == 'default':
                # Position of zero energy inside the normalized [0, 1] range;
                # the colormap switches from green to red at that point.
                mid = - vmin / (vmax - vmin)
                cmap = LinearSegmentedColormap.from_list(
                    'my_colormap', [(0.0, '#005500'), (mid, '#55FF55'),
                                    (mid, '#FFAAAA'), (1.0, '#FF0000')])
            else:
                cmap = energy_colormap
            norm = Normalize(vmin=vmin, vmax=vmax)
            _map = ScalarMappable(norm=norm, cmap=cmap)
            _energies = [pda.get_equilibrium_reaction_energy(entry)
                         for entry in labels.values()]
            # Non-negative values are nudged just below zero so that they
            # fall on the green side of the colormap.
            energies = [en if en < 0.0 else -0.00000001 for en in _energies]
            vals_stable = _map.to_rgba(energies)
            # labels.items() iterates in the same order as labels.values()
            # above, so colors and entries stay paired.
            for ((x, y), entry), facecolor in zip(labels.items(),
                                                  vals_stable):
                if not process_attributes:
                    plt.plot(x, y, "o", markerfacecolor=facecolor,
                             markersize=15)
                elif entry.attribute is None or \
                        entry.attribute == "existing":
                    plt.plot(x, y, "o", markerfacecolor=facecolor,
                             markersize=12)
                else:
                    plt.plot(x, y, "*", markerfacecolor=facecolor,
                             markersize=18)

        font = FontProperties()
        font.set_weight("bold")
        font.set_size(24)

        # Sets a nice layout depending on the type of PD. Also defines a
        # "center" for the PD, which then allows the annotations to be spread
        # out in a nice manner.
        if len(self._pd.elements) == 3:
            plt.axis("equal")
            plt.xlim((-0.1, 1.2))
            plt.ylim((-0.1, 1.0))
            plt.axis("off")
            center = (0.5, math.sqrt(3) / 6)
        else:
            miny = min([c[1] for c in labels.keys()])
            ybuffer = max(abs(miny) * 0.1, 0.1)
            plt.xlim((-0.1, 1.1))
            plt.ylim((miny - ybuffer, ybuffer))
            center = (0.5, miny / 2)
            plt.xlabel("Fraction", fontsize=28, fontweight='bold')
            plt.ylabel("Formation energy (eV/fu)", fontsize=28,
                       fontweight='bold')

        def offset_and_alignment(point):
            # Offset (in points) pushing the annotation away from the center
            # of the PD, plus the text alignment matching that direction.
            vec = np.array(point) - center
            length = np.linalg.norm(vec)
            if length != 0:
                vec = vec / length * 10
            valign = "bottom" if vec[1] > 0 else "top"
            if vec[0] < -0.01:
                halign = "right"
            elif vec[0] > 0.01:
                halign = "left"
            else:
                halign = "center"
            return vec, halign, valign

        if label_stable:
            for coords in sorted(labels.keys(), key=lambda x: -x[1]):
                entry = labels[coords]
                vec, halign, valign = offset_and_alignment(coords)
                annotate_kwargs = dict(xytext=vec,
                                       textcoords="offset points",
                                       horizontalalignment=halign,
                                       verticalalignment=valign,
                                       fontproperties=font)
                if process_attributes and entry.attribute == 'new':
                    annotate_kwargs["color"] = 'g'
                plt.annotate(latexify(entry.name), coords, **annotate_kwargs)

        if self.show_unstable:
            font = FontProperties()
            font.set_size(16)
            pda = PDAnalyzer(self._pd)
            energies_unstable = [pda.get_e_above_hull(entry)
                                 for entry in unstable]
            if energy_colormap is not None:
                energies.extend(energies_unstable)
                vals_unstable = _map.to_rgba(energies_unstable)
            # vals_unstable holds one color per unstable entry, so the index
            # must advance for every entry, including those filtered out by
            # the show_unstable threshold.
            for idx, ((entry, coords), ehull) in enumerate(
                    zip(unstable.items(), energies_unstable)):
                if not (ehull < self.show_unstable):
                    continue
                if energy_colormap is None:
                    plt.plot(coords[0], coords[1], "ks", linewidth=3,
                             markeredgecolor="k", markerfacecolor="r",
                             markersize=8)
                else:
                    plt.plot(coords[0], coords[1], "s", linewidth=3,
                             markeredgecolor="k",
                             markerfacecolor=vals_unstable[idx],
                             markersize=8)
                if label_unstable:
                    # Alignment is derived from this entry's own offset
                    # rather than from whatever stable label came last.
                    vec, halign, valign = offset_and_alignment(coords)
                    plt.annotate(latexify(entry.name), coords, xytext=vec,
                                 textcoords="offset points",
                                 horizontalalignment=halign, color="b",
                                 verticalalignment=valign,
                                 fontproperties=font)
        if energy_colormap is not None and show_colorbar:
            _map.set_array(energies)
            cbar = plt.colorbar(_map)
            # NOTE(review): the label says meV/at while the normalization
            # limits are vmin_mev / 1000 and vmax_mev / 1000; tick labels are
            # not rescaled here -- confirm the intended units.
            cbar.set_label(
                'Energy [meV/at] above hull (in red)\nInverse energy ['
                'meV/at] above hull (in green)',
                rotation=-90, ha='left', va='center')
        f = plt.gcf()
        f.set_size_inches((8, 6))
        plt.subplots_adjust(left=0.09, right=0.98, top=0.98, bottom=0.07)
        return plt
Exemple #50
0
    def get_pourbaix_plot_colorfill_by_domain_name(self, limits=None, title="",
            label_domains=True, label_color='k', domain_color=None, domain_fontsize=None,
            domain_edge_lw=0.5, bold_domains=None, cluster_domains=(),
            add_h2o_stablity_line=True, add_center_line=False, h2o_lw=0.5):
        """
        Color domains by the colors specified in the domain_color dict.

        Args:
            limits: 2D list containing limits of the Pourbaix diagram
                of the form [[xlo, xhi], [ylo, yhi]]
            title (str): Title of the plot.
            label_domains (bool): whether to add the text label for domains
            label_color (str): color of domain labels, defaults to black
            domain_color (dict): colors of each domain e.g {"Al(s)": "#FF1100"}.
                Domains missing from the dict are drawn white. If set to None
                the default color set will be used.
            domain_fontsize (int or dict): Font size used in domain text labels,
                either one size for every domain or a dict keyed by domain name.
            domain_edge_lw (int): line width for the boundaries between domains.
            bold_domains (list): List of domain names to use bold text style for
                domain labels. If None, every domain without "(s)" in its name
                is bold.
            cluster_domains (list): List of domain names in cluster phase
            add_h2o_stablity_line (bool): whether to plot the H2O stability lines
            add_center_line (bool): whether to plot the lines through pH = 7
                and E = 0
            h2o_lw (int): line width for H2O stability line and center lines

        Returns:
            The pyplot object returned by ``pretty_plot`` with the diagram
            drawn on it.
        """
        from matplotlib.patches import Polygon
        from pymatgen import Composition, Element
        from pymatgen.core.ion import Ion

        # helper functions
        def len_elts(entry):
            # Number of distinct elements other than H and O in a domain name.
            comp = Composition(entry[:-3]) if "(s)" in entry else Ion.from_formula(entry)
            return len(set(comp.elements) - {Element("H"), Element("O")})

        def special_lines(xlim, ylim):
            # Each line is returned as a 2 x 2 array: row 0 = x, row 1 = y.
            h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                                   [xlim[1], -xlim[1] * PREFAC]])
            o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                                   [xlim[1], -xlim[1] * PREFAC + 1.23]])
            neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
            V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])
            return h_line, o_line, neutral_line, V0_line

        default_domain_font_size = 12
        default_solid_phase_color = '#b8f9e7'    # this slighly darker than the MP scheme, to
        default_cluster_phase_color = '#d0fbef'  # avoid making the cluster phase too light

        plt = pretty_plot(8, dpi=300)

        (stable, _unstable) = self.pourbaix_plot_data(limits)
        # Group the stable entries by domain name; a MultiEntry is filed under
        # the name of each of its constituent entries.
        entry_dict_of_multientries = collections.defaultdict(list)
        for entry in stable:
            if isinstance(entry, MultiEntry):
                for e in entry.entrylist:
                    entry_dict_of_multientries[e.name].append(entry)
            else:
                entry_dict_of_multientries[entry.name].append(entry)

        xlim, ylim = limits[:2] if limits else self._analyzer.chempot_limits[:2]
        h_line, o_line, neutral_line, V0_line = special_lines(xlim, ylim)
        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        ax.tick_params(direction='out')
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')

        # Domains with fewer non-H/O elements are drawn first.
        sorted_entry = sorted(entry_dict_of_multientries, key=len_elts)

        if domain_fontsize is None:
            domain_fontsize = {en: default_domain_font_size for en in sorted_entry}
        elif not isinstance(domain_fontsize, dict):
            # A single size, as the documented int type suggests: apply it to
            # every domain instead of failing on the per-domain lookup below.
            domain_fontsize = {en: domain_fontsize for en in sorted_entry}
        if domain_color is None:
            domain_color = {en: default_solid_phase_color if '(s)' in en else
                            (default_cluster_phase_color if en in cluster_domains else 'w')
                            for en in sorted_entry}
        else:
            # Tolerate a partial mapping: unspecified domains are drawn white.
            domain_color = {en: domain_color.get(en, "w") for en in sorted_entry}
        if bold_domains is None:
            bold_domains = [en for en in sorted_entry if '(s)' not in en]

        for entry in sorted_entry:
            x_coord, y_coord, npts = 0.0, 0.0, 0
            for e in entry_dict_of_multientries[entry]:
                xy = self.domain_vertices(e)
                if add_h2o_stablity_line:
                    c = self.get_distribution_corrected_center(stable[e], h_line, o_line, 0.3)
                else:
                    c = self.get_distribution_corrected_center(stable[e])
                x_coord += c[0]
                y_coord += c[1]
                npts += 1
                patch = Polygon(xy, facecolor=domain_color[entry],
                                closed=True, lw=domain_edge_lw, fill=True, antialiased=True)
                ax.add_patch(patch)
            # The label sits at the mean of the centers of all sub-domains.
            xy_center = (x_coord / npts, y_coord / npts)
            if label_domains:
                is_bold = entry in bold_domains
                if platform.system() == 'Darwin':
                    # Have to hack to the hard coded font path to get current font On Mac OS X
                    if is_bold:
                        font_path = '/Library/Fonts/Times New Roman Bold.ttf'
                    else:
                        font_path = '/Library/Fonts/Times New Roman.ttf'
                    font = FontProperties(fname=font_path,
                                          size=domain_fontsize[entry])
                else:
                    font = FontProperties(family='Times New Roman',
                                          weight='bold' if is_bold else 'regular',
                                          size=domain_fontsize[entry])
                plt.text(*xy_center, s=latexify_ion(latexify(entry)), fontproperties=font,
                         horizontalalignment="center", verticalalignment="center",
                         multialignment="center", color=label_color)

        if add_h2o_stablity_line:
            dashes = (3, 1.5)
            line, = plt.plot(h_line[0], h_line[1], "k--", linewidth=h2o_lw, antialiased=True)
            line.set_dashes(dashes)
            line, = plt.plot(o_line[0], o_line[1], "k--", linewidth=h2o_lw, antialiased=True)
            line.set_dashes(dashes)
        if add_center_line:
            plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=h2o_lw, antialiased=False)
            plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=h2o_lw, antialiased=False)

        plt.xlabel("pH", fontname="Times New Roman", fontsize=18)
        plt.ylabel("E (V)", fontname="Times New Roman", fontsize=18)
        plt.xticks(fontname="Times New Roman", fontsize=16)
        plt.yticks(fontname="Times New Roman", fontsize=16)
        plt.title(title, fontsize=20, fontweight='bold', fontname="Times New Roman")
        return plt
Exemple #51
0
    def get_chempot_range_map_plot(self, elements, referenced=True):
        """
        Returns a plot of the chemical potential range _map. Currently works
        only for 3-component PDs.

        Args:
            elements: Sequence of elements to be considered as independent
                variables. E.g., if you want to show the stability ranges of
                all Li-Co-O phases wrt to uLi and uO, you will supply
                [Element("Li"), Element("O")]
            referenced: if True, gives the results with a reference being the
                        energy of the elemental phase. If False, gives absolute values.

        Returns:
            A matplotlib plot object.
        """

        plt = pretty_plot(12, 8)
        analyzer = PDAnalyzer(self._pd)
        chempot_ranges = analyzer.get_chempot_range_map(
            elements, referenced=referenced)
        missing_lines = {}
        excluded_region = []
        for entry, lines in chempot_ranges.items():
            comp = entry.composition
            center_x = 0
            center_y = 0
            coords = []
            contain_zero = any([comp.get_atomic_fraction(el) == 0
                                for el in elements])
            is_boundary = (not contain_zero) and \
                sum([comp.get_atomic_fraction(el) for el in elements]) == 1
            for line in lines:
                (x, y) = line.coords.transpose()
                plt.plot(x, y, "k-")

                # Accumulate each distinct vertex once for the label centroid.
                for coord in line.coords:
                    if not in_coord_list(coords, coord):
                        coords.append(coord.tolist())
                        center_x += coord[0]
                        center_y += coord[1]
                if is_boundary:
                    excluded_region.extend(line.coords)

            if coords and contain_zero:
                # Labelled later, once the missing edges have been added.
                missing_lines[entry] = coords
            elif coords:
                # Entries without any vertex have nothing to label; skipping
                # them avoids dividing by zero.
                xy = (center_x / len(coords), center_y / len(coords))
                plt.annotate(latexify(entry.name), xy, fontsize=22)

        ax = plt.gca()
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        # Shade the forbidden chemical potential regions.
        excluded_region.append([xlim[1], ylim[1]])
        excluded_region = sorted(excluded_region, key=lambda c: c[0])
        (x, y) = np.transpose(excluded_region)
        plt.fill(x, y, "0.80")

        # The hull does not generate the missing horizontal and vertical lines.
        # The following code fixes this.
        el0 = elements[0]
        el1 = elements[1]
        for entry, coords in missing_lines.items():
            center_x = sum([c[0] for c in coords])
            center_y = sum([c[1] for c in coords])
            comp = entry.composition
            is_x = comp.get_atomic_fraction(el0) < 0.01
            is_y = comp.get_atomic_fraction(el1) < 0.01
            n = len(coords)
            if not (is_x and is_y):
                if is_x:
                    # Extend the lowest and highest vertex horizontally to the
                    # left edge of the plot.
                    coords = sorted(coords, key=lambda c: c[1])
                    for i in [0, -1]:
                        x = [min(xlim), coords[i][0]]
                        y = [coords[i][1], coords[i][1]]
                        plt.plot(x, y, "k")
                        center_x += min(xlim)
                        center_y += coords[i][1]
                elif is_y:
                    # Extend the leftmost and rightmost vertex vertically to
                    # the bottom edge of the plot.
                    coords = sorted(coords, key=lambda c: c[0])
                    for i in [0, -1]:
                        x = [coords[i][0], coords[i][0]]
                        y = [coords[i][1], min(ylim)]
                        plt.plot(x, y, "k")
                        center_x += coords[i][0]
                        center_y += min(ylim)
                xy = (center_x / (n + 2), center_y / (n + 2))
            else:
                # Both fractions are ~0: the plot's lower-left corner is used
                # as the one extra point of the region.
                center_x = sum(coord[0] for coord in coords) + xlim[0]
                center_y = sum(coord[1] for coord in coords) + ylim[0]
                xy = (center_x / (n + 1), center_y / (n + 1))

            plt.annotate(latexify(entry.name), xy,
                         horizontalalignment="center",
                         verticalalignment="center", fontsize=22)

        # Only claim a "mu - mu^0" axis when the values really are referenced
        # to the elemental phases.
        if referenced:
            label_template = "$\\mu_{{{0}}} - \\mu_{{{0}}}^0$ (eV)"
        else:
            label_template = "$\\mu_{{{0}}}$ (eV)"
        plt.xlabel(label_template.format(el0.symbol))
        plt.ylabel(label_template.format(el1.symbol))
        plt.tight_layout()
        return plt
Exemple #52
0
    def get_pourbaix_plot(self, limits=None, title="", label_domains=True):
        """
        Draw the Pourbaix diagram.

        Args:
            limits: 2D list containing limits of the Pourbaix diagram
                of the form [[xlo, xhi], [ylo, yhi]]. When falsy, the first
                two entries of ``self._analyzer.chempot_limits`` are used.
            title (str): Title of the plot.
            label_domains (bool): Whether to annotate each stable domain
                with ``self.print_name(entry)``.

        Returns:
            plt:
                matplotlib plot object
        """
        plt = pretty_plot(16)
        stable, _unstable = self.pourbaix_plot_data(limits)
        bounds = limits if limits else self._analyzer.chempot_limits
        xlim = bounds[0]
        ylim = bounds[1]
        x_lo, x_hi = xlim[0], xlim[1]
        y_lo, y_hi = ylim[0], ylim[1]

        # Reference lines, each as a 2 x 2 array (row 0 = x, row 1 = y),
        # paired with the line style used to draw it.
        guides = [
            (np.transpose([[x_lo, -x_lo * PREFAC],
                           [x_hi, -x_hi * PREFAC]]), "r--"),
            (np.transpose([[x_lo, -x_lo * PREFAC + 1.23],
                           [x_hi, -x_hi * PREFAC + 1.23]]), "r--"),
            (np.transpose([[7, y_lo], [7, y_hi]]), "k-."),
            (np.transpose([[x_lo, 0], [x_hi, 0]]), "k-."),
        ]

        axes = plt.gca()
        axes.set_xlim(xlim)
        axes.set_ylim(ylim)
        lw = 3
        for guide, fmt in guides:
            plt.plot(guide[0], guide[1], fmt, linewidth=lw)

        for entry, lines in stable.items():
            vertices = []
            sum_x = 0.0
            sum_y = 0.0
            for line in lines:
                (x, y) = line
                plt.plot(x, y, "k-", linewidth=lw)
                # Every distinct vertex contributes once to the centroid.
                for vertex in np.array(line).T:
                    if in_coord_list(vertices, vertex):
                        continue
                    vertices.append(vertex.tolist())
                    sum_x += vertex[0]
                    sum_y += vertex[1]
            # A domain without vertices keeps its centroid at the origin.
            divisor = float(len(vertices)) if vertices else 1.0
            center_x = sum_x / divisor
            center_y = sum_y / divisor
            off_canvas = (center_x <= x_lo or center_x >= x_hi or
                          center_y <= y_lo or center_y >= y_hi)
            if off_canvas:
                continue
            if label_domains:
                plt.annotate(self.print_name(entry), (center_x, center_y),
                             fontsize=20, color="b")

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
Exemple #53
0
    def get_pourbaix_plot_colorfill_by_element(self, limits=None, title="",
                                                label_domains=True, element=None):
        """
        Color domains by element.

        Domains whose name contains exactly one element other than H and O
        are filled with a solid color; all other domains are drawn hatched.

        Args:
            limits: 2D list containing limits of the Pourbaix diagram
                of the form [[xlo, xhi], [ylo, yhi]]. When falsy, the first
                two entries of ``self._analyzer.chempot_limits`` are used.
            title (str): Title of the plot.
            label_domains (bool): Whether to annotate each domain with its
                name.
            element: For a MultiEntry, only the constituent entries whose
                composition contains this element define a domain. Entries
                that are not a MultiEntry are always included.

        Returns:
            The pyplot object returned by ``pretty_plot`` with the diagram
            drawn on it.
        """
        from matplotlib.patches import Polygon
        from pymatgen import Composition, Element
        from pymatgen.core.ion import Ion

        entry_dict_of_multientries = collections.defaultdict(list)
        plt = pretty_plot(16)
        optim_colors = ['#0000FF', '#FF0000', '#00FF00', '#FFFF00', '#FF00FF',
                         '#FF8080', '#DCDCDC', '#800000', '#FF8000']
        optim_font_color = ['#FFFFA0', '#00FFFF', '#FF00FF', '#0000FF', '#00FF00',
                            '#007F7F', '#232323', '#7FFFFF', '#007FFF']
        hatch = ['/', '\\', '|', '-', '+', 'o', '*']
        (stable, _unstable) = self.pourbaix_plot_data(limits)
        for entry in stable:
            if isinstance(entry, MultiEntry):
                for e in entry.entrylist:
                    if element in e.composition.elements:
                        entry_dict_of_multientries[e.name].append(entry)
            else:
                entry_dict_of_multientries[entry.name].append(entry)
        if limits:
            xlim = limits[0]
            ylim = limits[1]
        else:
            xlim = self._analyzer.chempot_limits[0]
            ylim = self._analyzer.chempot_limits[1]

        # Reference lines, each as a 2 x 2 array (row 0 = x, row 1 = y).
        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        def len_elts(entry):
            # Number of elements other than H and O in a domain name.
            if "(s)" in entry:
                comp = Composition(entry[:-3])
            else:
                comp = Ion.from_formula(entry)
            return len([el for el in comp.elements if el not in
                        [Element("H"), Element("O")]])

        # sorted() works on the dict's keys on both Python 2 and 3, whereas
        # calling .sort() on dict.keys() fails on Python 3.
        sorted_entry = sorted(entry_dict_of_multientries, key=len_elts)
        for i, entry in enumerate(sorted_entry):
            # The fill style depends only on the domain name, so it is decided
            # once per domain; the style index wraps around its palette.
            single_element = len_elts(entry) == 1
            if single_element:
                color_indx = i % len(optim_colors)
            else:
                color_indx = i % len(hatch)
            x_coord = 0.0
            y_coord = 0.0
            npts = 0
            for e in entry_dict_of_multientries[entry]:
                xy = self.domain_vertices(e)
                c = self.get_center(stable[e])
                x_coord += c[0]
                y_coord += c[1]
                npts += 1
                if single_element:
                    patch = Polygon(xy, facecolor=optim_colors[color_indx],
                                    closed=True, lw=3.0, fill=True)
                else:
                    patch = Polygon(xy, hatch=hatch[color_indx], closed=True,
                                    lw=3.0, fill=False)
                ax.add_patch(patch)

            # The label sits at the mean of the centers of all sub-domains.
            xy_center = (x_coord / npts, y_coord / npts)
            if label_domains:
                if single_element:
                    text_color = optim_colors[color_indx]
                    bbox = dict(boxstyle="round",
                                fc=optim_font_color[color_indx])
                else:
                    text_color = 'k'
                    bbox = dict(boxstyle="round", hatch=hatch[color_indx],
                                fill=False)
                plt.annotate(latexify_ion(latexify(entry)), xy_center,
                             color=text_color, fontsize=30, bbox=bbox)

        lw = 3
        plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
        plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
Exemple #54
0
    def get_pourbaix_plot_colorfill_by_domain_name(self, limits=None, title="",
            label_domains=True, label_color='k', domain_color=None, domain_fontsize=None,
            domain_edge_lw=0.5, bold_domains=None, cluster_domains=(),
            add_h2o_stablity_line=True, add_center_line=False, h2o_lw=0.5,
            fill_domain=True, width=8, height=None, font_family='Times New Roman'):
        """
        Color domains by the colors specific by the domain_color dict

        Args:
            limits: 2D list containing limits of the Pourbaix diagram
                of the form [[xlo, xhi], [ylo, yhi]]
            lable_domains (Bool): whether add the text lable for domains
            label_color (str): color of domain lables, defaults to be black
            domain_color (dict): colors of each domain e.g {"Al(s)": "#FF1100"}. If set
                to None default color set will be used.
            domain_fontsize (int): Font size used in domain text labels.
            domain_edge_lw (int): line width for the boundaries between domains.
            bold_domains (list): List of domain names to use bold text style for domain
                lables. If set to False, no domain will be bold.
            cluster_domains (list): List of domain names in cluster phase
            add_h2o_stablity_line (Bool): whether plot H2O stability line
            add_center_line (Bool): whether plot lines shows the center coordinate
            h2o_lw (int): line width for H2O stability line and center lines
            fill_domain (bool): a version without color will be product if set 
                to False.
            width (float): Width of plot in inches. Defaults to 8in.
                height (float): Height of plot in inches. Defaults to width * golden
                ratio.
            font_family (str): Font family of the labels
        """

        def special_lines(xlim, ylim):
            h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                                   [xlim[1], -xlim[1] * PREFAC]])
            o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                                   [xlim[1], -xlim[1] * PREFAC + 1.23]])
            neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
            V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])
            return h_line, o_line, neutral_line, V0_line

        from matplotlib.patches import Polygon
        import copy

        default_domain_font_size = 12
        default_solid_phase_color = '#b8f9e7'    # this slighly darker than the MP scheme, to
        default_cluster_phase_color = '#d0fbef'  # avoid making the cluster phase too light

        plt = pretty_plot(width=width, height=height, dpi=300)

        (stable, unstable) = self.pourbaix_plot_data(limits)

        xlim, ylim = limits[:2] if limits else self._analyzer.chempot_limits[:2]
        h_line, o_line, neutral_line, V0_line = special_lines(xlim, ylim)
        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        ax.tick_params(direction='out')
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')

        sorted_entry = list(stable.keys())
        sorted_entry.sort(key=lambda en: en.energy, reverse=True)

        if domain_fontsize is None:
            domain_fontsize = {en.name: default_domain_font_size for en in sorted_entry}
        elif not isinstance(domain_fontsize, dict):
            domain_fontsize = {en.name: domain_fontsize for en in sorted_entry}
        if domain_color is None:
            domain_color = {en.name: default_solid_phase_color if '(s)' in en.name else
                            (default_cluster_phase_color if en.name in cluster_domains else 'w')
                            for i, en in enumerate(sorted_entry)}
        else:
            domain_color = {en.name: domain_color[en.name] if en.name in domain_color else "w"
                            for i, en in enumerate(sorted_entry)}
        if bold_domains is None:
            bold_domains = [en.name for en in sorted_entry if '(s)' not in en.name]
        if bold_domains == False:
            bold_domains = []

        for entry in sorted_entry:
            xy = self.domain_vertices(entry)
            if add_h2o_stablity_line:
                c = self.get_distribution_corrected_center(stable[entry], h_line, o_line, 0.3)
            else:
                c = self.get_distribution_corrected_center(stable[entry])
            patch = Polygon(xy, facecolor=domain_color[entry.name], edgecolor="black",
                            closed=True, lw=domain_edge_lw, fill=fill_domain, antialiased=True)
            ax.add_patch(patch)
            if label_domains:
                if platform.system() == 'Darwin' and font_family == "Times New Roman":
                    # Have to hack to the hard coded font path to get current font On Mac OS X
                    if entry.name in bold_domains:
                        font = FontProperties(fname='/Library/Fonts/Times New Roman Bold.ttf',
                                              size=domain_fontsize[entry.name])
                    else:
                        font = FontProperties(fname='/Library/Fonts/Times New Roman.ttf',
                                              size=domain_fontsize[entry.name])
                else:
                    if entry.name in bold_domains:
                        font = FontProperties(family=font_family,
                                              weight='bold',
                                              size=domain_fontsize[entry.name])
                    else:
                        font = FontProperties(family=font_family,
                                              weight='regular',
                                              size=domain_fontsize[entry.name])
                plt.text(*c, s=self.print_name(entry),
                         fontproperties=font,
                         horizontalalignment="center", verticalalignment="center",
                         multialignment="center", color=label_color)

        if add_h2o_stablity_line:
            dashes = (3, 1.5)
            line, = plt.plot(h_line[0], h_line[1], "k--", linewidth=h2o_lw, antialiased=True)
            line.set_dashes(dashes)
            line, = plt.plot(o_line[0], o_line[1], "k--", linewidth=h2o_lw, antialiased=True)
            line.set_dashes(dashes)
        if add_center_line:
            plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=h2o_lw, antialiased=False)
            plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=h2o_lw, antialiased=False)

        plt.xlabel("pH", fontname=font_family, fontsize=18)
        plt.ylabel("E (V vs SHE)", fontname=font_family, fontsize=18)
        plt.xticks(fontname=font_family, fontsize=16)
        plt.yticks(fontname=font_family, fontsize=16)
        plt.title(title, fontsize=20, fontweight='bold', fontname=font_family)
        return plt
Exemple #55
0
    def get_pourbaix_mark_passive(self, limits=None, title="", label_domains=True, passive_entry=None):
        """
        Plot the stable domains, highlighting solid domains by a "passive" mark.

        Every stable entry gets one of three marks:

        * 2 -- ``passive_entry + "(s)"`` occurs in the entry name; the domain
          is filled with '#FF0000' and labelled in '#00FFFF'.
        * 1 -- any other entry whose name contains "(s)" and whose elements
          include O or H; filled with '#0000FF' and labelled in '#FFC000'.
        * 0 -- everything else; drawn unfilled with a black label.

        The dashed red lines and dash-dotted black lines (pH = 7 and E = 0)
        are always drawn.

        Args:
            limits: Optional ``[xlim, ylim]`` pair passed to
                ``pourbaix_plot_data`` and used as the axis limits. When
                falsy, ``self._analyzer.chempot_limits`` is used instead.
            title (str): Title of the plot.
            label_domains (bool): Whether to annotate each domain with
                ``self.print_name(entry)``.
            passive_entry (str): Name (without the "(s)" suffix) used to
                identify the mark-2 domains. Defaults to the key of
                ``self._pd._elt_comp`` with the largest value (the first such
                key in iteration order if several tie).

        Returns:
            The plot object returned by ``pretty_plot``.

        Raises:
            ValueError: If ``passive_entry`` is None and
                ``self._pd._elt_comp`` is empty, so no default can be chosen.
        """
        from matplotlib.patches import Polygon
        from pymatgen import Element
        from itertools import chain

        # Honour an explicitly supplied passive_entry; only fall back to the
        # dominant key of the element composition when none is given.
        if passive_entry is None:
            elt_comp = self._pd._elt_comp
            if not elt_comp:
                raise ValueError("passive_entry was not given and cannot be "
                                 "inferred because self._pd._elt_comp is empty.")
            # max() returns the first maximal key, matching "first key whose
            # value equals the maximum".
            passive_entry = max(elt_comp, key=elt_comp.get)

        # mark -> (face colour, label colour, fill the polygon?)
        styles = {
            1: ('#0000FF', '#FFC000', True),
            2: ('#FF0000', '#00FFFF', True),
        }
        default_style = ("w", "k", False)

        def list_elts(entry):
            """Return the set of elements in an entry (or in all sub-entries of a MultiEntry)."""
            if isinstance(entry, MultiEntry):
                return set(chain.from_iterable(e.composition.elements
                                               for e in entry.entrylist))
            return set(entry.composition.elements)

        plt = pretty_plot(16)
        stable, _unstable = self.pourbaix_plot_data(limits)

        passive_tag = passive_entry + "(s)"
        water_elts = {Element("O"), Element("H")}
        mark_passive = {}
        for entry in stable:
            if passive_tag in entry.name:
                mark_passive[entry] = 2
            elif "(s)" in entry.name and water_elts & list_elts(entry):
                mark_passive[entry] = 1
            else:
                mark_passive[entry] = 0

        if limits:
            xlim, ylim = limits[0], limits[1]
        else:
            xlim = self._analyzer.chempot_limits[0]
            ylim = self._analyzer.chempot_limits[1]

        # Each line is a 2 x 2 array: row 0 holds the x values, row 1 the y values.
        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        for entry in stable:
            xy = self.domain_vertices(entry)
            center = self.get_center(stable[entry])
            color, fontcolor, colorfill = styles.get(mark_passive[entry], default_style)
            patch = Polygon(xy, facecolor=color, closed=True, lw=3.0, fill=colorfill)
            ax.add_patch(patch)
            if label_domains:
                plt.annotate(self.print_name(entry), center, color=fontcolor, fontsize=20)

        lw = 3
        plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
        plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt