def get_rdf_plot(self, label=None, xlim=[0.0, 8.0], ylim=[-0.005, 3.0]):
        """
        Plot the average RDF function.

        Args:
            label (str): Legend label of the curve. When None, the label is
                built from the element symbols of the first structure's
                composition that are also found in ``self.species``, joined
                with "-".
            xlim (sequence of 2 floats): Lower and upper limit of the x axis.
                Only indexed, never modified.
            ylim (sequence of 2 floats): Lower and upper limit of the y axis.
                Only indexed, never modified.

        Returns:
            The plot object obtained from get_publication_quality_plot with
            ``self.rdf`` drawn against ``self.interval``.
        """

        if label is None:
            # Joining also covers the single-symbol case: no separator is
            # inserted when there is only one entry.
            label = "-".join(
                el.symbol for el in self.structures[0].composition.keys()
                if el.symbol in self.species)

        plt = get_publication_quality_plot(12, 8)
        plt.plot(self.interval, self.rdf, color="r", label=label, linewidth=4.0)
        # Raw string: "\A" is not a valid Python escape sequence.
        plt.xlabel(r"$r$ ($\AA$)")
        plt.ylabel("$g(r)$")
        plt.legend(loc='upper right', fontsize=36)
        plt.xlim(xlim[0], xlim[1])
        plt.ylim(ylim[0], ylim[1])
        plt.tight_layout()

        return plt
Ejemplo n.º 2
0
    def get_scan_plot(self, coords=None):
        """
        Build a matplotlib plot of the potential energy surface scan.

        Args:
            coords: Name of the internal coordinate to use as abscissa. When
                it is falsy or not one of the scanned coordinates, the point
                index is used instead.

        Returns:
            The plot object. Energies are drawn relative to the lowest one
            and multiplied by HARTREE_TO_ELECTRON_VOLT.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot

        plt = get_publication_quality_plot(12, 8)

        scan = self.read_scan()

        if not coords or coords not in scan["coords"]:
            # No usable coordinate name: fall back to the point index.
            abscissa = range(len(scan["energies"]))
            plt.xlabel("points")
        else:
            abscissa = scan["coords"][coords]
            plt.xlabel(coords)

        plt.ylabel("Energy   /   eV")

        lowest = min(scan["energies"])
        relative = []
        for energy in scan["energies"]:
            relative.append((energy - lowest) * HARTREE_TO_ELECTRON_VOLT)

        plt.plot(abscissa, relative, "ro--")
        return plt
Ejemplo n.º 3
0
    def get_rdf_plot(self, label=None, xlim=[0.0, 8.0], ylim=[-0.005, 3.0]):
        """
        Plot the average RDF function.

        Args:
            label (str): Legend label of the curve. When None, the label is
                built from the element symbols of the first structure's
                composition that are also found in ``self.species``, joined
                with "-".
            xlim (sequence of 2 floats): Lower and upper limit of the x axis.
                Only indexed, never modified.
            ylim (sequence of 2 floats): Lower and upper limit of the y axis.
                Only indexed, never modified.

        Returns:
            The plot object obtained from get_publication_quality_plot with
            ``self.rdf`` drawn against ``self.interval``.
        """

        if label is None:
            # Joining also covers the single-symbol case: no separator is
            # inserted when there is only one entry.
            label = "-".join(
                el.symbol for el in self.structures[0].composition.keys()
                if el.symbol in self.species)

        plt = get_publication_quality_plot(12, 8)
        plt.plot(self.interval, self.rdf, color="r", label=label, linewidth=4.0)
        # Raw string: "\A" is not a valid Python escape sequence.
        plt.xlabel(r"$r$ ($\AA$)")
        plt.ylabel("$g(r)$")
        plt.legend(loc='upper right', fontsize=36)
        plt.xlim(xlim[0], xlim[1])
        plt.ylim(ylim[0], ylim[1])
        plt.tight_layout()

        return plt
Ejemplo n.º 4
0
    def get_xrd_plot(self, structure, two_theta_range=(0, 90),
                     annotate_peaks=True):
        """
        Returns the XRD plot as a matplotlib.pyplot.

        Args:
            structure: Input structure
            two_theta_range ([float of length 2]): Tuple for range of
                two_thetas to calculate in degrees. Defaults to (0, 90). Set to
                None if you want all diffracted beams within the limiting
                sphere of radius 2 / wavelength; in that case every peak
                returned by get_xrd_data is drawn.
            annotate_peaks: Whether to annotate the peaks with plane
                information.

        Returns:
            (matplotlib.pyplot)
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(16, 10)
        for two_theta, i, hkls, _ in self.get_xrd_data(
                structure, two_theta_range=two_theta_range):
            # A None range means "no restriction"; indexing it would raise a
            # TypeError, so only filter when bounds were actually supplied.
            if two_theta_range is not None and not (
                    two_theta_range[0] <= two_theta <= two_theta_range[1]):
                continue
            label = ", ".join([str(hkl) for hkl in hkls.keys()])
            # Each peak is a vertical stick from 0 up to its intensity.
            plt.plot([two_theta, two_theta], [0, i], color='k',
                     linewidth=3, label=label)
            if annotate_peaks:
                plt.annotate(label, xy=[two_theta, i],
                             xytext=[two_theta, i], fontsize=16)
        # \theta must sit inside $...$ to be rendered as a Greek letter.
        plt.xlabel(r"$2\theta$ (degrees)")
        plt.ylabel("Intensities (scaled)")
        plt.tight_layout()

        return plt
Ejemplo n.º 5
0
    def get_msd_plot(self, plt=None, mode="specie"):
        """
        Get the plot of the smoothed msd vs time graph. Useful for
        checking convergence. This can be written to an image file.

        Args:
            plt: A plot object. Defaults to None, which means one will be
                generated.
            mode (str): Selects what is drawn.
                "species": one curve per species, averaging the squared
                displacements of the sites of that species.
                "ions": one curve per site.
                Any other value: the overall MSD plus its three components.
                Note that the default value "specie" matches neither of the
                two named modes and therefore selects the overall plot.

        Returns:
            The plot object.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(12, 8, plt=plt)

        if mode == "species":
            for sp in sorted(self.structure.composition.keys()):
                indices = [i for i, site in enumerate(self.structure) if
                           site.specie == sp]
                sd = np.average(self.sq_disp_ions[indices, :], axis=0)
                plt.plot(self.dt, sd, label=str(sp))
            plt.legend(loc=2, prop={"size": 20})
        elif mode == "ions":
            for i, site in enumerate(self.structure):
                sd = self.sq_disp_ions[i, :]
                plt.plot(self.dt, sd, label="%s - %d" % (
                    str(site.specie), i))
            plt.legend(loc=2, prop={"size": 20})
        else:  # Handle default / invalid mode case
            plt.plot(self.dt, self.msd, 'k')
            plt.plot(self.dt, self.msd_components[:, 0], 'r')
            plt.plot(self.dt, self.msd_components[:, 1], 'g')
            plt.plot(self.dt, self.msd_components[:, 2], 'b')
            plt.legend(["Overall", "a", "b", "c"], loc=2, prop={"size": 20})
        plt.xlabel("Timestep (fs)")
        # Raw string: "\A" is not a valid Python escape sequence.
        plt.ylabel(r"MSD ($\AA^2$)")
        plt.tight_layout()
        return plt
Ejemplo n.º 6
0
    def get_plot(self, width, height):
        """
        Draw one voltage curve per stored electrode.

        Args:
            width:
                Width of the plot.
            height:
                Height of the plot.

        Returns:
            A matplotlib plot object.
        """
        plt = get_publication_quality_plot(width, height)
        for name, electrode in self._electrodes.items():
            xs, ys = self.get_plot_data(electrode)
            plt.plot(xs, ys, '-', linewidth=2, label=name)

        plt.legend()
        # The x axis shows capacity only when explicitly requested.
        x_title = 'Capacity (mAh/g)' if self.xaxis == "capacity" else 'Fraction'
        plt.xlabel(x_title)
        plt.ylabel('Voltage (V)')
        plt.tight_layout()
        return plt
    def get_plot(self, plt=None, mode="specie", element="Na"):
        """
        Get a plot of squared displacements versus time.

        Args:
            plt: A plot object forwarded to get_publication_quality_plot.
                Defaults to None.
            mode (str): Selects what is drawn.
                "species": one curve per species, averaging the squared
                displacements of the sites of that species.
                "ions": one curve per site whose species string equals
                ``element``.
                Any other value: the overall MSD plus its three components.
                Note that the default value "specie" matches neither of the
                two named modes and therefore selects the overall plot.
            element (str): Species string used to filter sites in "ions"
                mode. Ignored by the other modes.

        Returns:
            The plot object.
        """

        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(12, 8, plt=plt)

        if mode == "species":
            for sp in sorted(self.material.composition.keys()):
                indices = [i for i, site in enumerate(self.material) if site.specie == sp]
                sd = np.average(self.sq_disp_ions[indices, :], axis=0)
                plt.plot(self.dt, sd, label=str(sp))
            plt.legend(loc=2, prop={"size": 20})
        elif mode == "ions":
            drawn = False
            for i, site in enumerate(self.material):
                specie_name = str(site.specie)
                if specie_name == element:
                    # Label every trace so the legend below has entries;
                    # unlabeled lines leave the legend empty.
                    plt.plot(self.dt, self.sq_disp_ions[i, :],
                             label="%s - %d" % (specie_name, i))
                    drawn = True
            if drawn:
                plt.legend(loc=2, prop={"size": 20})
        else:
            plt.plot(self.dt, self.msd, 'k')
            plt.plot(self.dt, self.msd_comp[:, 0], 'r')
            plt.plot(self.dt, self.msd_comp[:, 1], 'g')
            plt.plot(self.dt, self.msd_comp[:, 2], 'b')
            plt.legend(["Overall", "a", "b", "c"], loc=2, prop={"size": 20})
        plt.xlabel("Timestep (fs)")
        # Raw string: "\A" is not a valid Python escape sequence.
        plt.ylabel(r"MSD ($\AA^2$)")
        plt.tight_layout()
        return plt
Ejemplo n.º 8
0
    def get_msd_plot(self, plt=None, mode="specie"):
        """
        Get the plot of the smoothed msd vs time graph. Useful for
        checking convergence. This can be written to an image file.

        Args:
            plt: A plot object. Defaults to None, which means one will be
                generated.
            mode (str): Selects what is drawn.
                "species": one curve per species, averaging the squared
                displacements of the sites of that species.
                "ions": one curve per site.
                Any other value: the overall MSD plus its three components.
                Note that the default value "specie" matches neither of the
                two named modes and therefore selects the overall plot.

        Returns:
            The plot object.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(12, 8, plt=plt)

        if mode == "species":
            for sp in sorted(self.structure.composition.keys()):
                indices = [i for i, site in enumerate(self.structure) if
                           site.specie == sp]
                sd = np.average(self.sq_disp_ions[indices, :], axis=0)
                plt.plot(self.dt, sd, label=str(sp))
            plt.legend(loc=2, prop={"size": 20})
        elif mode == "ions":
            for i, site in enumerate(self.structure):
                sd = self.sq_disp_ions[i, :]
                plt.plot(self.dt, sd, label="%s - %d" % (
                    str(site.specie), i))
            plt.legend(loc=2, prop={"size": 20})
        else:  # Handle default / invalid mode case
            plt.plot(self.dt, self.msd, 'k')
            plt.plot(self.dt, self.msd_components[:, 0], 'r')
            plt.plot(self.dt, self.msd_components[:, 1], 'g')
            plt.plot(self.dt, self.msd_components[:, 2], 'b')
            plt.legend(["Overall", "a", "b", "c"], loc=2, prop={"size": 20})
        plt.xlabel("Timestep (fs)")
        # Raw string: "\A" is not a valid Python escape sequence.
        plt.ylabel(r"MSD ($\AA^2$)")
        plt.tight_layout()
        return plt
Ejemplo n.º 9
0
    def get_plot(self, normalize_rxn_coordinate=True, label_barrier=True):
        """
        Returns the NEB plot. Uses Henkelman's approach of spline fitting
        each section of the reaction path based on tangent force and energies.
        The fitted curve is obtained by evaluating ``self.spline``.

        Args:
            normalize_rxn_coordinate (bool): Whether to normalize the
                reaction coordinate to between 0 and 1. Defaults to True.
            label_barrier (bool): Whether to label the maximum barrier.

        Returns:
            matplotlib.pyplot object.
        """
        plt = get_publication_quality_plot(12, 8)
        if normalize_rxn_coordinate:
            scale = 1 / self.r[-1]
        else:
            scale = 1
        # Sample the spline on a fine grid; energies are shown scaled by 1000.
        grid = np.arange(0, np.max(self.r), 0.01)
        fitted = self.spline(grid) * 1000
        plt.plot(self.r * scale, self.energies * 1000, 'ro',
                 grid * scale, fitted, 'k-', linewidth=2, markersize=10)
        plt.xlabel("Reaction coordinate")
        plt.ylabel("Energy (meV)")
        plt.ylim((np.min(fitted) - 10, np.max(fitted) * 1.02 + 20))
        if label_barrier:
            # Highest point of the sampled curve.
            peak_x, peak_e = max(zip(grid * scale, fitted),
                                 key=lambda point: point[1])
            plt.plot([0, peak_x], [peak_e, peak_e], 'k--')
            anchor = (peak_x / 2, peak_e * 1.02)
            plt.annotate('%.0f meV' % peak_e,
                         xy=anchor,
                         xytext=anchor,
                         horizontalalignment='center')
        plt.tight_layout()
        return plt
Ejemplo n.º 10
0
    def plot_conc_temp(self, me=[1.0, 1.0, 1.0], mh=[1.0, 1.0, 1.0]):
        """
        Plot the carrier concentration versus temperature, both in
        equilibrium and out of equilibrium after quenching at 300K.

        Args:
            me:
                the effective mass for the electrons as a list of 3 eigenvalues
            mh:
                the effective mass for the holes as a list of 3 eigenvalues

        Returns:
            a matplotlib object
        """
        temperatures = list(range(300, 2000, 100))
        eq_conc = []
        non_eq_conc = []
        for temperature in temperatures:
            # Both 'Qi' values are multiplied by 1e-6 before plotting.
            eq_result = self._analyzer.get_eq_Ef(temperature, me, mh)
            eq_conc.append(eq_result['Qi'] * 1e-6)
            non_eq_result = self._analyzer.get_non_eq_Ef(
                temperature, 300, me, mh)
            non_eq_conc.append(non_eq_result['Qi'] * 1e-6)

        plt = get_publication_quality_plot(12, 8)
        plt.xlabel("temperature (K)")
        plt.ylabel("carrier concentration (cm$^{-3}$)")
        for series, width in ((eq_conc, 3.0), (non_eq_conc, 3)):
            plt.semilogy(temperatures, series, linewidth=width)
        plt.legend(['eq', 'non-eq'])
        return plt
Ejemplo n.º 11
0
def get_arrhenius_plot(temps, diffusivities, **kwargs):
    r"""
    Returns an Arrhenius plot.

    Args:
        temps ([float]): A sequence of temperatures.
        diffusivities ([float]): A sequence of diffusivities (e.g.,
            from DiffusionAnalyzer.diffusivity).
        \*\*kwargs:
            Any keyword args supported by matplotlib.pyplot.plot.

    Returns:
        A matplotlib.pyplot object. Do plt.show() to show the plot.
    """
    # Float numerator so integer temperatures are never floor-divided.
    t_1 = 1000.0 / np.array(temps)
    logd = np.log10(diffusivities)
    # Do a least squares regression of log10(D) vs 1000/T.
    A = np.array([t_1, np.ones(len(temps))]).T
    w = np.array(np.linalg.lstsq(A, logd)[0])
    from pymatgen.util.plotting_utils import get_publication_quality_plot
    plt = get_publication_quality_plot(12, 8)
    plt.plot(t_1, logd, 'ko', t_1, np.dot(A, w), 'k--', markersize=10,
             **kwargs)
    # Calculate the activation energy in meV = negative of the slope,
    # * kB (/ electron charge to convert to eV), * 1000 (inv. temperature
    # scale), * 1000 (eV -> meV), * math.log(10) (the regression is carried
    # out in base 10 for easier reading of the diffusivity scale,
    # but the Arrhenius relationship is in base e).
    actv_energy = - w[0] * phyc.k_b / phyc.e * 1e6 * math.log(10)
    # gca() returns the axes just drawn on; axes() may create a new,
    # overlapping axes instead.
    plt.text(0.6, 0.85, "E$_a$ = {:.0f} meV".format(actv_energy),
             fontsize=30, transform=plt.gca().transAxes)
    plt.ylabel("log(D (cm$^2$/s))")
    plt.xlabel("1000/T (K$^{-1}$)")
    plt.tight_layout()
    return plt
Ejemplo n.º 12
0
    def get_plot(self, width, height):
        """
        Draw one voltage curve per stored electrode.

        Args:
            width:
                Width of the plot.
            height:
                Height of the plot.

        Returns:
            A matplotlib plot object.
        """
        plt = get_publication_quality_plot(width, height)
        for name, electrode in self._electrodes.items():
            xs, ys = self.get_plot_data(electrode)
            plt.plot(xs, ys, '-', linewidth=2, label=name)

        plt.legend()
        # The x axis shows capacity only when explicitly requested.
        x_title = 'Capacity (mAh/g)' if self.xaxis == "capacity" else 'Fraction'
        plt.xlabel(x_title)
        plt.ylabel('Voltage (V)')
        plt.tight_layout()
        return plt
Ejemplo n.º 13
0
    def get_scan_plot(self, coords=None):
        """
        Build a matplotlib plot of the potential energy surface scan.

        Args:
            coords: Name of the internal coordinate to use as abscissa. When
                it is falsy or not one of the scanned coordinates, the point
                index is used instead.

        Returns:
            The plot object. Energies are drawn relative to the lowest one
            and multiplied by HARTREE_TO_ELECTRON_VOLT.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot

        plt = get_publication_quality_plot(12, 8)

        scan = self.read_scan()

        if not coords or coords not in scan["coords"]:
            # No usable coordinate name: fall back to the point index.
            abscissa = range(len(scan["energies"]))
            plt.xlabel("points")
        else:
            abscissa = scan["coords"][coords]
            plt.xlabel(coords)

        plt.ylabel("Energy   /   eV")

        lowest = min(scan["energies"])
        relative = []
        for energy in scan["energies"]:
            relative.append((energy - lowest) * HARTREE_TO_ELECTRON_VOLT)

        plt.plot(abscissa, relative, "ro--")
        return plt
Ejemplo n.º 14
0
    def get_plot(self, normalize_rxn_coordinate=True, label_barrier=True):
        """
        Returns the NEB plot. Uses Henkelman's approach of spline fitting
        each section of the reaction path based on tangent force and energies.
        The fitted curve is obtained by evaluating ``self.spline``.

        Args:
            normalize_rxn_coordinate (bool): Whether to normalize the
                reaction coordinate to between 0 and 1. Defaults to True.
            label_barrier (bool): Whether to label the maximum barrier.

        Returns:
            matplotlib.pyplot object.
        """
        plt = get_publication_quality_plot(12, 8)
        if normalize_rxn_coordinate:
            scale = 1 / self.r[-1]
        else:
            scale = 1
        # Sample the spline on a fine grid; energies are shown scaled by 1000.
        grid = np.arange(0, np.max(self.r), 0.01)
        fitted = self.spline(grid) * 1000
        plt.plot(self.r * scale, self.energies * 1000, 'ro',
                 grid * scale, fitted, 'k-', linewidth=2, markersize=10)
        plt.xlabel("Reaction coordinate")
        plt.ylabel("Energy (meV)")
        plt.ylim((np.min(fitted) - 10, np.max(fitted) * 1.02 + 20))
        if label_barrier:
            # Highest point of the sampled curve.
            peak_x, peak_e = max(zip(grid * scale, fitted),
                                 key=lambda point: point[1])
            plt.plot([0, peak_x], [peak_e, peak_e], 'k--')
            anchor = (peak_x / 2, peak_e * 1.02)
            plt.annotate('%.0f meV' % peak_e,
                         xy=anchor,
                         xytext=anchor,
                         horizontalalignment='center')
        plt.tight_layout()
        return plt
Ejemplo n.º 15
0
    def plot(self, width=8, height=None, plt=None, dpi=None, **kwargs):
        """
        Plot the equation of state.

        Args:
            width (float): Width of plot in inches. Defaults to 8in.
            height (float): Height of plot in inches. Defaults to width *
                golden ratio.
            plt (matplotlib.pyplot): If plt is supplied, changes will be made
                to an existing plot. Otherwise, a new plot will be created.
            dpi: Forwarded unchanged to get_publication_quality_plot.
            kwargs (dict): Keys read by this method:
                color: color of the data markers and of the fitted curve
                    (default "r").
                label: legend label of the fitted curve (default
                    "<self.name> fit").
                text: when falsy or absent, ``str(self)`` is written on the
                    axes; a truthy value only suppresses that annotation.
                    NOTE(review): a supplied text is not itself drawn -
                    confirm whether that is intended.

        Returns:
            Matplotlib plot object.
        """
        plt = get_publication_quality_plot(width=width,
                                           height=height,
                                           plt=plt,
                                           dpi=dpi)

        color = kwargs.get("color", "r")
        label = kwargs.get("label", "{} fit".format(self.name))
        text = kwargs.get("text", None)

        # Plot input data.
        plt.plot(self.volumes,
                 self.energies,
                 linestyle="None",
                 marker="o",
                 color=color)

        # Plot EOS on a volume range padded by 1% on each side.
        vmin, vmax = self.volumes.min(), self.volumes.max()
        vmin, vmax = (vmin - 0.01 * abs(vmin), vmax + 0.01 * abs(vmax))
        vfit = np.linspace(vmin, vmax, 100)

        # The "deltafactor" fit is a polynomial in V^(-2/3); every other
        # fit is evaluated through self.func.
        if self.eos_name == "deltafactor":
            efit = np.polyval(self.eos_params, vfit**(-2. / 3.))
        else:
            efit = self.func(vfit, *self.eos_params)
        plt.plot(vfit,
                 efit,
                 linestyle="dashed",
                 color=color,
                 label=label)

        plt.grid(True)
        # Raw string: "\A" is not a valid Python escape sequence.
        plt.xlabel(r"Volume $\AA^3$")
        plt.ylabel("Energy (eV)")
        plt.legend(loc="best", shadow=True)
        # Add text with fit parameters.
        if not text:
            plt.text(0.4, 0.5, str(self), transform=plt.gca().transAxes)

        return plt
Ejemplo n.º 16
0
def plot_chgint(args):
    """
    Show the integrated charge difference around selected atoms.

    Reads the charge file named by ``args.filename[0]``. The atoms are the
    comma-separated indices in ``args.inds[0]`` when given; otherwise the
    first site of every group of symmetrically equivalent sites is used.
    The plot is displayed with plt.show(); nothing is returned.
    """
    chgcar = Chgcar.from_file(args.filename[0])
    structure = chgcar.structure

    if not args.inds:
        symmetrized = SymmetryFinder(
            structure, symprec=0.1).get_symmetrized_structure()
        representatives = [group[0]
                           for group in symmetrized.equivalent_sites]
        atom_ind = [structure.sites.index(site) for site in representatives]
    else:
        atom_ind = map(int, args.inds[0].split(","))

    from pymatgen.util.plotting_utils import get_publication_quality_plot
    plt = get_publication_quality_plot(12, 8)
    for index in atom_ind:
        curve = chgcar.get_integrated_diff(index, args.radius, 30)
        legend_entry = "Atom {} - {}".format(
            index, structure[index].species_string)
        # Column 0 is drawn on the x axis, column 1 on the y axis.
        plt.plot(curve[:, 0], curve[:, 1], label=legend_entry)
    plt.legend(loc="upper left")
    plt.xlabel("Radius (A)")
    plt.ylabel("Integrated charge (e)")
    plt.tight_layout()
    plt.show()
Ejemplo n.º 17
0
    def get_xrd_plot(self, structure, two_theta_range=(0, 90),
                     annotate_peaks=True):
        """
        Returns the XRD plot as a matplotlib.pyplot.

        Args:
            structure: Input structure
            two_theta_range ([float of length 2]): Tuple for range of
                two_thetas to calculate in degrees. Defaults to (0, 90). Set to
                None if you want all diffracted beams within the limiting
                sphere of radius 2 / wavelength; in that case every peak
                returned by get_xrd_data is drawn.
            annotate_peaks: Whether to annotate the peaks with plane
                information.

        Returns:
            (matplotlib.pyplot)
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(16, 10)
        for two_theta, i, hkls, _ in self.get_xrd_data(
                structure, two_theta_range=two_theta_range):
            # A None range means "no restriction"; indexing it would raise a
            # TypeError, so only filter when bounds were actually supplied.
            if two_theta_range is not None and not (
                    two_theta_range[0] <= two_theta <= two_theta_range[1]):
                continue
            label = ", ".join([str(hkl) for hkl in hkls.keys()])
            # Each peak is a vertical stick from 0 up to its intensity.
            plt.plot([two_theta, two_theta], [0, i], color='k',
                     linewidth=3, label=label)
            if annotate_peaks:
                plt.annotate(label, xy=[two_theta, i],
                             xytext=[two_theta, i], fontsize=16)
        plt.xlabel(r"$2\theta$ ($^\circ$)")
        plt.ylabel("Intensities (scaled)")
        plt.tight_layout()

        return plt
Ejemplo n.º 18
0
def get_convergence_plot(df, x, y, tol):
    '''
    Convenient method to directly draw the convergence test plot.

    The converged value found by get_converged_param is printed and marked
    with a vertical dashed line; two horizontal dashed lines mark the last
    y value minus and plus ``tol``.

    :param df (pd.df): Pandas DataFrame. For energy cutoff, use the df from
        analyzer. For others, you need to update the input parameter using
        appropriate analyze_*() method.
    :param x (str): Tested input parameter. Choose among ['nkpts','ecut',
        'nvac','nslab'].
    :param y (str): Tested output. Choose among ['energy','total_force',
        'surf_energy'].
    :param tol (float): Convergence criteria.
    :return: Convergence test plot as a matplotlib.pyplot object.
    '''
    i = get_converged_param(df, x, y, tol)
    # Function-call form works on both Python 2 and Python 3.
    print(i)
    plt = get_publication_quality_plot(8, 6)
    # Keep the column names in x / y; use separate names for the data.
    x_values = df[x].values
    y_values = df[y].values
    plt.plot(x_values, y_values, 'bo-', fillstyle='none')
    ax = plt.gca()
    xmin, xmax = ax.get_xlim()
    # Tolerance band around the last (reference) value.
    for e in [y_values[-1] - tol, y_values[-1] + tol]:
        plt.plot([xmin, xmax], [e, e], 'k--', lw=2)
    plt.axvline(x=i, color='k', linestyle='dashed', lw=2)
    return plt
Ejemplo n.º 19
0
def get_convergence_plot(df, x, y, tol):
    '''
    Convenient method to directly draw the convergence test plot.

    The converged value found by get_converged_param is printed and marked
    with a vertical dashed line; two horizontal dashed lines mark the last
    y value minus and plus ``tol``.

    :param df (pd.df): Pandas DataFrame. For energy cutoff, use the df from
        analyzer. For others, you need to update the input parameter using
        appropriate analyze_*() method.
    :param x (str): Tested input parameter. Choose among ['nkpts','ecut',
        'nvac','nslab'].
    :param y (str): Tested output. Choose among ['energy','total_force',
        'surf_energy'].
    :param tol (float): Convergence criteria.
    :return: Convergence test plot as a matplotlib.pyplot object.
    '''
    i = get_converged_param(df, x, y, tol)
    # Function-call form works on both Python 2 and Python 3.
    print(i)
    plt = get_publication_quality_plot(8, 6)
    # Keep the column names in x / y; use separate names for the data.
    x_values = df[x].values
    y_values = df[y].values
    plt.plot(x_values, y_values, 'bo-', fillstyle='none')
    ax = plt.gca()
    xmin, xmax = ax.get_xlim()
    # Tolerance band around the last (reference) value.
    for e in [y_values[-1] - tol, y_values[-1] + tol]:
        plt.plot([xmin, xmax], [e, e], 'k--', lw=2)
    plt.axvline(x=i, color='k', linestyle='dashed', lw=2)
    return plt
Ejemplo n.º 20
0
def get_arrhenius_plot(temps,
                       diffusivities,
                       diffusivity_errors=None,
                       **kwargs):
    r"""
    Returns an Arrhenius plot.

    Args:
        temps ([float]): A sequence of temperatures.
        diffusivities ([float]): A sequence of diffusivities (e.g.,
            from DiffusionAnalyzer.diffusivity).
        diffusivity_errors ([float]): A sequence of errors for the
            diffusivities. If None, no error bar is plotted. When given,
            error bars are drawn for the first len(diffusivity_errors)
            points.
        \*\*kwargs:
            Any keyword args supported by matplotlib.pyplot.plot.

    Returns:
        A matplotlib.pyplot object. Do plt.show() to show the plot.
    """
    Ea, c, _ = fit_arrhenius(temps, diffusivities)

    from pymatgen.util.plotting_utils import get_publication_quality_plot
    plt = get_publication_quality_plot(12, 8)

    # Arrhenius fit evaluated at the input temperatures (the log scaling
    # is applied to the y axis below, not to these values).
    arr = c * np.exp(-Ea / (const.k / const.e * np.array(temps)))

    # Float numerator so integer temperatures are never floor-divided.
    t_1 = 1000.0 / np.array(temps)

    plt.plot(t_1,
             diffusivities,
             'ko',
             t_1,
             arr,
             'k--',
             markersize=10,
             **kwargs)
    if diffusivity_errors is not None:
        n = len(diffusivity_errors)
        plt.errorbar(t_1[0:n],
                     diffusivities[0:n],
                     yerr=diffusivity_errors,
                     fmt='ko',
                     ecolor='k',
                     capthick=2,
                     linewidth=2)
    # gca() returns the axes the data was drawn on; axes() may create a
    # new, overlapping axes instead.
    ax = plt.gca()
    ax.set_yscale('log')
    plt.text(0.6,
             0.85,
             "E$_a$ = {:.0f} meV".format(Ea * 1000),
             fontsize=30,
             transform=ax.transAxes)
    plt.ylabel("D (cm$^2$/s)")
    plt.xlabel("1000/T (K$^{-1}$)")
    plt.tight_layout()
    return plt
Ejemplo n.º 21
0
    def get_spectre_plot(self, sigma=0.05, step=0.01):
        """
        Get a matplotlib plot of the UV-visible spectra. Transitions are
        plotted as vertical lines and as a sum of normal functions. The
        broadening is applied in energy and the spectra is plotted as a
        function of the wavelength.

        Args:
            sigma: Width of the normal functions in eV. It is used as the
                standard deviation of each normal function (earlier
                documentation called it the full width at half maximum).
            step: bin interval in eV

        Returns:
            A tuple (data, plt):
                data: {"energies": values, "lambda": values, "spectra": values}
                    where values are the abscissa (energies, lambda) and
                    the sum of gaussian functions (spectra), scaled so that
                    its maximum is 1.
                plt: A matplotlib plot.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot

        plt = get_publication_quality_plot(12, 8)

        transitions = self.read_excitation_energies()

        # Energy window padded by 5 sigma on each side of the transitions.
        minval = min([val[0] for val in transitions]) - 5.0 * sigma
        maxval = max([val[0] for val in transitions]) + 5.0 * sigma
        npts = int((maxval - minval) / step) + 1

        eneval = np.linspace(minval, maxval, npts)  # in eV
        lambdaval = [cst.h * cst.c / (val * cst.e) * 1.0e9 for val in eneval]  # in nm

        # Sum of gaussian functions. The normal density is written out with
        # numpy because matplotlib.mlab.normpdf was removed in Matplotlib 3.1.
        spectre = np.zeros(npts)
        for trans in transitions:
            gaussian = 1.0 / (np.sqrt(2 * np.pi) * sigma) * np.exp(
                -0.5 * (1.0 / sigma * (eneval - trans[0])) ** 2)
            spectre += trans[2] * gaussian
        spectre /= spectre.max()
        plt.plot(lambdaval, spectre, "r-", label="spectre")

        data = {"energies": eneval, "lambda": lambdaval, "spectra": spectre}

        # plot transitions as vlines
        plt.vlines(
            [val[1] for val in transitions],
            0.0,
            [val[2] for val in transitions],
            color="blue",
            label="transitions",
            linewidth=2,
        )

        # Raw string: "\l" is not a valid Python escape sequence.
        plt.xlabel(r"$\lambda$ (nm)")
        plt.ylabel("Arbitrary unit")
        plt.legend()

        return data, plt
Ejemplo n.º 22
0
def banddos(pref='',storedir=None):
    """
    Draw the band structure from the VASP output in the working directory.

    Reads "vasprun.xml" and "KPOINTS", prints the band gap and the largest
    distance value of the plot data, plots every band (a second, dashed
    curve per band when the band structure is spin polarized), marks the
    CBM and VBM points and saves the figure as "BAND.png".

    Args:
        pref: Not used by this function.
        storedir: Not used by this function.
    """
    ru = str("vasprun.xml")
    kpfile = str("KPOINTS")

    run = Vasprun(ru, parse_projected_eigen=True)
    bands = run.get_band_structure(kpfile, line_mode=True, efermi=run.efermi)
    bsp = BSPlotter(bands)
    zero_to_efermi = True
    bandgap = str(round(bands.get_band_gap()['energy'], 3))
    # Single formatted argument prints the same text on Python 2 and 3.
    print("bg= {}".format(bandgap))
    data = bsp.bs_plot_data(zero_to_efermi)
    plt = get_publication_quality_plot(12, 8)
    band_linewidth = 3
    x_max = data['distances'][-1][-1]
    print(x_max)
    for d, distances in enumerate(data['distances']):
        npoints = len(distances)
        for i in range(bsp._nb_bands):
            plt.plot(distances,
                     [data['energy'][d]['1'][i][j] for j in range(npoints)],
                     'b-', linewidth=band_linewidth)
            if bsp._bs.is_spin_polarized:
                plt.plot(distances,
                         [data['energy'][d]['-1'][i][j]
                          for j in range(npoints)],
                         'r--', linewidth=band_linewidth)
    bsp._maketicks(plt)

    # Mark band edges. The two loops are independent: nesting the VBM loop
    # inside the CBM loop would redraw every VBM once per CBM and skip the
    # VBMs entirely when there is no CBM.
    for cbm in data['cbm']:
        plt.scatter(cbm[0], cbm[1], color='r', marker='o', s=100)
    for vbm in data['vbm']:
        plt.scatter(vbm[0], vbm[1], color='g', marker='o', s=100)

    plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
    ylabel = r'$\mathrm{E\ -\ E_f\ (eV)}$' if zero_to_efermi \
        else r'$\mathrm{Energy\ (eV)}$'
    plt.ylabel(ylabel, fontsize=30)
    plt.ylim(-4, 4)
    plt.xlim(0, x_max)
    plt.tight_layout()
    # savefig selects the output type with "format"; "img_format" is not
    # one of its keywords.
    plt.savefig('BAND.png', format="png")

    plt.close()
Ejemplo n.º 23
0
    def get_spectre_plot(self, sigma=0.05, step=0.01):
        """
        Get a matplotlib plot of the UV-visible spectra. Transitions are
        plotted as vertical lines and as a sum of normal functions. The
        broadening is applied in energy and the spectra is plotted as a
        function of the wavelength.

        Args:
            sigma: Width of the normal functions in eV. It is used as the
                standard deviation of each normal function (earlier
                documentation called it the full width at half maximum).
            step: bin interval in eV

        Returns:
            A tuple (data, plt):
                data: {"energies": values, "lambda": values, "spectra": values}
                    where values are the abscissa (energies, lambda) and
                    the sum of gaussian functions (spectra), scaled so that
                    its maximum is 1.
                plt: A matplotlib plot.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(12, 8)

        transitions = self.read_excitation_energies()

        # Energy window padded by 5 sigma on each side of the transitions.
        minval = min([val[0] for val in transitions]) - 5.0 * sigma
        maxval = max([val[0] for val in transitions]) + 5.0 * sigma
        npts = int((maxval - minval) / step) + 1

        eneval = np.linspace(minval, maxval, npts)  # in eV
        lambdaval = [cst.h * cst.c / (val * cst.e) * 1.e9
                     for val in eneval]  # in nm

        # Sum of gaussian functions. The normal density is written out with
        # numpy because matplotlib.mlab.normpdf was removed in Matplotlib 3.1.
        spectre = np.zeros(npts)
        for trans in transitions:
            gaussian = 1.0 / (np.sqrt(2 * np.pi) * sigma) * np.exp(
                -0.5 * (1.0 / sigma * (eneval - trans[0])) ** 2)
            spectre += trans[2] * gaussian
        spectre /= spectre.max()
        plt.plot(lambdaval, spectre, "r-", label="spectre")

        data = {"energies": eneval, "lambda": lambdaval, "spectra": spectre}

        # plot transitions as vlines
        plt.vlines([val[1] for val in transitions],
                   0.,
                   [val[2] for val in transitions],
                   color="blue",
                   label="transitions",
                   linewidth=2)

        # Raw string: "\l" is not a valid Python escape sequence.
        plt.xlabel(r"$\lambda$ (nm)")
        plt.ylabel("Arbitrary unit")
        plt.legend()

        return data, plt
Ejemplo n.º 24
0
    def get_plot(self, ylim=None):
        """
        Get a matplotlib object for the bandstructure plot.

        Args:
            ylim: Specify the y-axis (frequency) limits; by default None let
                the code choose.

        Returns:
            The plot object with one curve per band and per segment of
            ``bs_plot_data()['distances']``.
        """
        plt = get_publication_quality_plot(12, 8)
        from matplotlib import rc
        try:
            rc('text', usetex=True)
        except Exception:
            # Fall back on non Tex if errored. Only ordinary errors are
            # caught so KeyboardInterrupt / SystemExit still propagate.
            rc('text', usetex=False)

        band_linewidth = 1

        data = self.bs_plot_data()
        for d, distances in enumerate(data['distances']):
            npoints = len(distances)
            for i in range(self._nb_bands):
                plt.plot(distances,
                         [data['frequency'][d][i][j] for j in range(npoints)],
                         'b-',
                         linewidth=band_linewidth)

        self._maketicks(plt)

        # plot y=0 line
        plt.axhline(0, linewidth=1, color='k')

        # Main X and Y Labels
        plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
        ylabel = r'$\mathrm{Frequency\ (THz)}$'
        plt.ylabel(ylabel, fontsize=30)

        # X range (K)
        # last distance point
        x_max = data['distances'][-1][-1]
        plt.xlim(0, x_max)

        if ylim is not None:
            plt.ylim(ylim)

        plt.tight_layout()

        return plt
Ejemplo n.º 25
0
 def plot_smoothed_msd(self):
     """
     Display the smoothed MSD curves against ``self.dt``.

     The overall curve (``self.s_msd``) is black; columns 0, 1 and 2 of
     ``self.s_msd_components`` are red, green and blue and are labelled
     "a", "b" and "c" in the legend. Handy for a visual convergence check.
     The figure is shown immediately; nothing is returned.
     """
     from pymatgen.util.plotting_utils import get_publication_quality_plot
     canvas = get_publication_quality_plot(12, 8)
     canvas.plot(self.dt, self.s_msd, 'k')
     # One curve per component column, coloured r/g/b in that order.
     for column, colour in enumerate('rgb'):
         canvas.plot(self.dt, self.s_msd_components[:, column], colour)
     canvas.legend(["Overall", "a", "b", "c"], loc=2, prop={"size": 20})
     canvas.xlabel("Timestep")
     canvas.ylabel("MSD")
     canvas.tight_layout()
     canvas.show()
Ejemplo n.º 26
0
 def plot_smoothed_msd(self):
     """
     Show a figure of the smoothed mean square displacement.

     ``self.s_msd`` is drawn in black and the three columns of
     ``self.s_msd_components`` in red, green and blue (legend entries
     "Overall", "a", "b", "c"), all against ``self.dt``. Useful for
     checking convergence. Calls ``show()`` and returns nothing.
     """
     from pymatgen.util.plotting_utils import get_publication_quality_plot
     figure = get_publication_quality_plot(12, 8)
     curves = [(self.s_msd, 'k')]
     curves.extend((self.s_msd_components[:, k], fmt)
                   for k, fmt in zip((0, 1, 2), ('r', 'g', 'b')))
     for values, fmt in curves:
         figure.plot(self.dt, values, fmt)
     figure.legend(["Overall", "a", "b", "c"], loc=2, prop={"size": 20})
     figure.xlabel("Timestep")
     figure.ylabel("MSD")
     figure.tight_layout()
     figure.show()
Ejemplo n.º 27
0
    def get_1d_plot(self, type="distinct", times=[0.0], colors=None):
        """
        Plot the van Hove function against r at the requested times.

        Args:
            type (str): Specify which part of van Hove function to be plotted;
                either "distinct" (uses ``self.gdrt``) or "self" (uses
                ``self.gsrt``).
            times (list of float): Time moments (in ps) in which the van Hove
                            function will be plotted. The default list is
                            only read, never modified.
            colors (list strings/tuples): Additional color settings. If not set,
                            seaborn.color_palette("Set1", 10) will be used.

        Returns:
            The plot object obtained from ``get_publication_quality_plot``.

        Raises:
            AssertionError: if ``type`` is not one of the two accepted values
                or if more times than colors are given.
        """
        if colors is None:
            import seaborn as sns
            colors = sns.color_palette("Set1", 10)

        assert type in ["distinct", "self"]
        assert len(times) <= len(colors)

        if type == "distinct":
            grt = self.gdrt.copy()
            ylabel = "$G_d$($t$,$r$)"
            ylim = [-0.005, 4.0]
        else:
            # Only "self" can reach this branch because of the assert above.
            # Raw string: "\p" is not a valid Python escape sequence.
            grt = self.gsrt.copy()
            ylabel = r"4$\pi r^2G_s$($t$,$r$)"
            ylim = [-0.005, 1.0]

        plt = get_publication_quality_plot(12, 8)

        for i, time in enumerate(times):
            # Convert the requested time to the nearest stored row and clamp
            # it to the last available row.
            index = int(np.round(time / self.timeskip))
            index = min(index, np.shape(grt)[0] - 1)
            # The legend reports the time actually plotted, not the request.
            new_time = index * self.timeskip
            label = str(new_time) + " ps"
            plt.plot(self.interval,
                     grt[index],
                     color=colors[i],
                     label=label,
                     linewidth=4.0)

        # Raw string: "\A" is not a valid Python escape sequence.
        plt.xlabel(r"$r$ ($\AA$)")
        plt.ylabel(ylabel)
        plt.legend(loc='upper right', fontsize=36)
        plt.xlim(0.0, self.interval[-1] - 1.0)
        plt.ylim(ylim[0], ylim[1])
        plt.tight_layout()

        return plt
Ejemplo n.º 28
0
    def get_framework_rms_plot(self,
                               plt=None,
                               granularity=200,
                               matching_s=None):
        """
        Plot the RMS and maximum framework displacement against time.

        The reference (``matching_s`` or ``self.structure``) and each
        drift-corrected structure have ``self.specie`` removed before they
        are compared with a StructureMatcher. Useful for checking for
        melting, especially if framework atoms can move via paddle-wheel or
        similar mechanism (which would show up in max framework displacement
        but doesn't constitute melting).

        Args:
            plt: Plot object forwarded to get_publication_quality_plot.
            granularity (int): Number of structures to match
            matching_s (Structure): Optionally match to a disordered structure
                instead of the first structure in the analyzer. Required when
                a secondary mobile ion is present.

        Returns:
            The plot object.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(12, 8, plt=plt)
        n_intervals = self.corrected_displacements.shape[1] - 1
        step = n_intervals // (granularity - 1)
        reference = (matching_s or self.structure).copy()
        reference.remove_species([self.specie])
        matcher = StructureMatcher(primitive_cell=False,
                                   stol=0.6,
                                   comparator=OrderDisorderElementComparator(),
                                   allow_subset=True)
        distances = []
        for snapshot in self.get_drift_corrected_structures(step=step):
            snapshot.remove_species([self.specie])
            matched = matcher.get_rms_dist(reference, snapshot)
            # A falsy result is recorded as the placeholder pair (1, 1).
            distances.append(matched if matched else (1, 1))
        max_dt = (len(distances) - 1) * step * self.step_skip * self.time_step
        # Long runs are rescaled by 1000 and labelled in ps instead of fs.
        use_ps = max_dt > 100000
        unit = 'ps' if use_ps else 'fs'
        if use_ps:
            plot_dt = np.linspace(0, max_dt / 1000, len(distances))
        else:
            plot_dt = np.linspace(0, max_dt, len(distances))
        distances = np.array(distances)
        for column, name in enumerate(('RMS', 'max')):
            plt.plot(plot_dt, distances[:, column], label=name)
        plt.legend(loc='best')
        plt.xlabel("Timestep ({})".format(unit))
        plt.ylabel("normalized distance")
        plt.tight_layout()
        return plt
Ejemplo n.º 29
0
def get_arrhenius_plot(temps, diffusivites, **kwargs):
    r"""
    Returns an Arrhenius plot.

    log10 of the diffusivities is plotted against 1000/T together with a
    linear least-squares fit, and the activation energy derived from the
    fitted slope is annotated next to the last data point.

    Args:
        temps:
            A sequence of temperatures.
        diffusivites:
            A sequence of diffusivities (e.g., from DiffusionAnalyzer
            .diffusivity). Note the spelling of the parameter name, which is
            kept for backward compatibility.
        \*\*kwargs:
            Any keyword args supported by matplotlib.pyplot.plot.

    Returns:
        A matplotlib.pyplot object. Do plt.show() to show the plot.
    """
    t_1 = 1000 / np.array(temps)
    logd = np.log10(diffusivites)
    # Do a least squares regression of log(D) vs 1000/T. The columns of A are
    # [1000/T, 1], so w[0] is the slope and w[1] the intercept.
    A = np.array([t_1, np.ones(len(temps))]).T
    w = np.array(np.linalg.lstsq(A, logd)[0])
    from pymatgen.util.plotting_utils import get_publication_quality_plot
    plt = get_publication_quality_plot(12, 8)
    # Data points as black circles, fitted line as a black dashed line.
    plt.plot(t_1,
             logd,
             'ko',
             t_1,
             np.dot(A, w),
             'k--',
             markersize=10,
             **kwargs)
    # Calculate the activation energy in meV = negative of the slope,
    # * kB (/ electron charge to convert to eV), * 1000 (inv. temperature
    # scale), * 1000 (eV -> meV), * math.log(10) (the regression is carried
    # out in base 10 for easier reading of the diffusivity scale,
    # but the Arrhenius relationship is in base e).
    actv_energy = -w[0] * phyc.k_b / phyc.e * 1e6 * math.log(10)
    # Anchor the label on the fitted line at the last temperature and shift
    # it 100 points to the right.
    plt.annotate("E$_a$ = {:.0f} meV".format(actv_energy),
                 (t_1[-1], w[0] * t_1[-1] + w[1]),
                 xytext=(100, 0),
                 xycoords='data',
                 textcoords='offset points',
                 fontsize=30)
    plt.ylabel("log(D (cm$^2$/s))")
    plt.xlabel("1000/T (K$^{-1}$)")
    plt.tight_layout()
    return plt
Ejemplo n.º 30
0
def main():
    """
    Command-line entry point: plot the two cross-section curves of a Feff run.

    Takes two positional arguments (the xmu file and the feff.inp file),
    loads them through ``Xmu.from_file`` and shows the "scross" and "across"
    data against "energies" in one figure.
    """
    parser = argparse.ArgumentParser(
        description='''Convenient DOS Plotter for Feff runs.
    Author: Alan Dozier
    Version: 1.0
    Last updated: April, 2013''')

    positional = (('filename', 'xmu file to plot'),
                  ('filename1', 'feff.inp filename to import'))
    for arg_name, arg_help in positional:
        parser.add_argument(arg_name,
                            metavar=arg_name,
                            type=str,
                            nargs=1,
                            help=arg_help)

    plt = get_publication_quality_plot(12, 8)
    color_order = ['r', 'b', 'g', 'c', 'k', 'm', 'y']

    args = parser.parse_args()
    # nargs=1 stores each positional argument as a one-element list.
    xmu = Xmu.from_file(args.filename[0], args.filename1[0])

    data = xmu.to_dict

    plt.title(''.join([data['calc'], ' Feff9.6 Calculation for ',
                       data['atom'], ' in ', data['formula'], ' unit cell']))
    plt.xlabel('Energies (eV)')
    plt.ylabel('Absorption Cross-section')

    energies = data['energies']

    single = data['scross']
    single_label = ' '.join(['Single', data['atom'], data['edge'], 'edge'])
    plt.plot(energies, single, color_order[1], label=single_label)

    embedded = data['across']
    embedded_label = ' '.join([data['atom'], data['edge'], 'edge in',
                               data['formula']])
    plt.plot(energies, embedded, color_order[2], label=embedded_label)

    plt.legend()
    # Shrink every text entry of the legend that was just created.
    legend_texts = plt.gca().get_legend().get_texts()
    plt.setp(legend_texts, fontsize=15)
    plt.tight_layout()
    plt.show()
Ejemplo n.º 31
0
def plot_enc_convergence(directory='../vasp/examples/SiOptb88/', plot=False,
                         filename='.'):
    """
    Plot convergence for plane-wave cut-off data
    Works only if jobs run through jarvis-tools framework

    Every ``*.json`` file in ``directory`` is expected to sit next to a run
    folder of the same name (without the ``.json`` suffix). The folder whose
    json name contains 'MAIN-RELAX' provides the converged ENCUT (read from
    its INCAR); folders whose json name contains 'ENCUT' provide one
    (ENCUT, final energy) data point each.

    Args:
       directory: parent directory for job run
       plot: if true, the figure is saved to ``filename`` and then closed
       filename: output path used when ``plot`` is true
    Returns:
           matplotlib object, converged cut-off value
    Raises:
        ValueError: if no 'MAIN-RELAX' json file is found in ``directory``,
            so that no converged cut-off can be reported.
    """

    convg_encut = None
    x = []
    y = []
    for json_path in glob.glob(str(directory) + '/*.json'):
        run_dir = str(json_path.split(".json")[0])
        if 'MAIN-RELAX' in json_path:
            main_inc_obj = Incar.from_file(run_dir + "/INCAR")
            convg_encut = float(main_inc_obj['ENCUT'])
        elif 'ENCUT' in json_path:
            vrun = Vasprun(run_dir + "/vasprun.xml")
            inc = Incar.from_file(run_dir + "/INCAR")
            x.append(inc['ENCUT'])
            # Total final energy; it is not normalised per atom.
            y.append(float(vrun.final_energy))

    if convg_encut is None:
        # Fail early with a clear message instead of an UnboundLocalError
        # when building the title below.
        raise ValueError("No MAIN-RELAX json file found in " + str(directory))

    # Sort the data points by increasing ENCUT before plotting.
    order = np.argsort(x)
    xs = np.array(x)[order]
    ys = np.array(y)[order]
    plt = get_publication_quality_plot(14, 10)
    plt.ylabel('Energy (eV)')
    plt.plot(xs, ys, 's-', linewidth=2, markersize=10)
    plt.xlabel('Increment in ENCUT ')
    ax = plt.gca()
    # Show absolute energies on the y axis rather than an offset notation.
    ax.get_yaxis().get_major_formatter().set_useOffset(False)
    plt.title("Converged at " + str(int(convg_encut)) + "eV", fontsize=26)
    plt.tight_layout()
    if plot:
        plt.savefig(filename)
        plt.close()

    return plt, convg_encut
Ejemplo n.º 32
0
def get_ea_plot(csv_file, e_scale=1):
    '''
    Convenient method to directly draw the energy as a function of lattice
    constant plot from a csv file.

    The lattice constant of the minimum-energy configuration is printed and
    marked with a dashed vertical line.

    :param csv_file (str): CSV file name.
    :param e_scale (float): The factor of total energy in case
            various scales (meV/f.u. or meV/atom) are needed.
    :return: E-a plot as a matplotlib.pyplot object.
    :raises AssertionError: if all 'alat' values are identical.
    '''
    b = BasicAnalyzer(csv_file, 'alat', e_scale)
    # A curve needs at least two distinct lattice constants.
    assert b['alat'].max() != b['alat'].min()
    a0 = b.emin_config.alat
    # Parenthesised form works as a statement on Python 2 and as a function
    # call on Python 3.
    print(a0)
    plt = get_publication_quality_plot(8, 6)
    plt.plot(b['alat'], b['energy'], 'bo--', fillstyle='none')
    plt.axvline(x=a0, color='k', linestyle='dashed', lw=2)
    # Raw string: "\A" is not a valid Python escape sequence.
    plt.xlabel(r'Lattice constant $a$ ($\AA$)')
    return plt
Ejemplo n.º 33
0
def get_ea_plot(csv_file, e_scale=1):
    '''
    Convenient method to directly draw the energy as a function of lattice
    constant plot from a csv file.

    The 'alat' value of the minimum-energy configuration is printed to
    standard output and highlighted by a dashed vertical line.

    :param csv_file (str): CSV file name.
    :param e_scale (float): The factor of total energy in case
            various scales (meV/f.u. or meV/atom) are needed.
    :return: E-a plot as a matplotlib.pyplot object.
    :raises AssertionError: if the 'alat' column holds a single value only.
    '''
    b = BasicAnalyzer(csv_file, 'alat', e_scale)
    # Guard against a degenerate data set with one lattice constant.
    assert b['alat'].max() != b['alat'].min()
    a0 = b.emin_config.alat
    # print(...) with one argument is valid on both Python 2 and Python 3.
    print(a0)
    plt = get_publication_quality_plot(8, 6)
    plt.plot(b['alat'], b['energy'], 'bo--', fillstyle='none')
    plt.axvline(x=a0, color='k', linestyle='dashed', lw=2)
    # Raw string avoids the invalid "\A" escape sequence.
    plt.xlabel(r'Lattice constant $a$ ($\AA$)')
    return plt
Ejemplo n.º 34
0
    def get_framework_rms_plot(self, plt=None, granularity=200, matching_s=None):
        """
        Build the framework displacement plot (RMS and max) versus time.

        ``self.specie`` is stripped from the reference structure and from
        every drift-corrected structure before they are matched. Useful for
        checking for melting, especially if framework atoms can move via
        paddle-wheel or similar mechanism (which would show up in max
        framework displacement but doesn't constitute melting).

        Args:
            plt: Plot object passed through to get_publication_quality_plot.
            granularity (int): Number of structures to match
            matching_s (Structure): Optionally match to a disordered structure
                instead of the first structure in the analyzer. Required when
                a secondary mobile ion is present.

        Returns:
            The plot object.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(12, 8, plt=plt)
        step = ((self.corrected_displacements.shape[1] - 1)
                // (granularity - 1))
        framework = (matching_s or self.structure).copy()
        framework.remove_species([self.specie])
        sm = StructureMatcher(primitive_cell=False, stol=0.6,
                              comparator=OrderDisorderElementComparator(),
                              allow_subset=True)
        pairs = []
        for frame in self.get_drift_corrected_structures(step=step):
            frame.remove_species([self.specie])
            dist = sm.get_rms_dist(framework, frame)
            if not dist:
                # No usable result: store the placeholder pair instead.
                dist = (1, 1)
            pairs.append(dist)
        n_frames = len(pairs)
        max_dt = (n_frames - 1) * step * self.step_skip * self.time_step
        if max_dt <= 100000:
            unit = 'fs'
            plot_dt = np.linspace(0, max_dt, n_frames)
        else:
            # Large spans are divided by 1000 and labelled ps.
            unit = 'ps'
            plot_dt = np.linspace(0, max_dt/1000, n_frames)
        pairs = np.array(pairs)
        plt.plot(plot_dt, pairs[:, 0], label='RMS')
        plt.plot(plot_dt, pairs[:, 1], label='max')
        plt.legend(loc='best')
        plt.xlabel("Timestep ({})".format(unit))
        plt.ylabel("normalized distance")
        plt.tight_layout()
        return plt
Ejemplo n.º 35
0
    def plot(self, sg):
        """
        Show a 2D sketch of the symmetry operations of a space group.

        A lattice is built with the ``Lattice`` constructor named after the
        crystal system of ``sg`` ("Trigonal" is mapped to "rhombohedral"),
        using fixed demonstration parameters. The x/y components of the first
        two lattice vectors are drawn, then a small spiral motif and its
        images under every symmetry operation, repeated over the four
        (0/1, 0/1) translations. An image is red when the third fractional
        coordinate of its first point exceeds 0.5, blue otherwise.

        The module-level ``plt`` is used for drawing and is rebound to the
        result of ``get_publication_quality_plot``. The lattice lengths and
        angles are printed after the figure is shown.

        Args:
            sg: Object exposing ``crystal_system`` and ``symmetry_ops``.
                NOTE(review): each op is applied with ``np.dot`` to 4-row
                (x, y, z, 1) columns, so it is presumably a 4x4 matrix;
                confirm against the caller.
        """
        cs = sg.crystal_system
        params = {
            "a": 10,
            "b": 12,
            "c": 14,
            "alpha": 20,
            "beta": 30,
            "gamma": 40
        }
        cs = "rhombohedral" if cs == "Trigonal" else cs
        func = getattr(Lattice, cs.lower())
        # inspect.getargspec was removed in Python 3.11; use getfullargspec
        # when it exists and fall back to getargspec on Python 2.
        argspec = getattr(inspect, "getfullargspec", None) or inspect.getargspec
        # Pass only the parameters the chosen constructor actually accepts.
        kw = {k: params[k] for k in argspec(func).args}
        lattice = func(**kw)
        global plt
        for i in range(2):
            plt.plot([0, lattice.matrix[i][0]], [0, lattice.matrix[i][1]],
                     'k-')

        # Spiral motif: the radius grows from 0 while the angle sweeps a
        # full turn starting at -pi/2.
        l = np.arange(0, 0.02, 0.02 / 100)
        theta = np.arange(0, 2 * np.pi, 2 * np.pi / 100) - np.pi / 2

        x = l * np.cos(theta) + 0.025
        y = l * np.sin(theta) + 0.031
        # list(...) is required on Python 3, where zip returns an iterator
        # that np.array would wrap as a single object.
        d = np.array(list(zip(x, y, [0.01] * len(x), [1] * len(x))))

        for op in sg.symmetry_ops:
            dd = np.dot(op, d.T).T
            for tx, ty in itertools.product((0, 1), (0, 1)):
                ddd = dd[:, 0:3] + np.array([tx, ty, 0])[None, :]
                color = "r" if ddd[0, 2] > 0.5 else "b"
                coords = lattice.get_cartesian_coords(ddd[:, 0:3])
                plt.plot(coords[:, 0], coords[:, 1], color + "-")

        plt.plot(x, y, 'k-')
        max_l = max(params['a'], params['b'])
        plt = get_publication_quality_plot(8, 8, plt)
        # Same limits on both axes, with a 10% margin around [0, max_l].
        lim = [-max_l * 0.1, max_l * 1.1]
        plt.xlim(lim)
        plt.ylim(lim)
        plt.tight_layout()
        plt.show()
        # Parenthesised form is valid on both Python 2 and Python 3.
        print(lattice.lengths_and_angles)
Ejemplo n.º 36
0
def get_arrhenius_plot(temps, diffusivities, diffusivity_errors=None,
                       **kwargs):
    r"""
    Returns an Arrhenius plot.

    The diffusivities are plotted on a logarithmic y axis against 1000/T,
    together with the curve obtained from ``fit_arrhenius``. The fitted
    activation energy is written inside the axes in meV.

    Args:
        temps ([float]): A sequence of temperatures.
        diffusivities ([float]): A sequence of diffusivities (e.g.,
            from DiffusionAnalyzer.diffusivity).
        diffusivity_errors ([float]): A sequence of errors for the
            diffusivities. If None, no error bar is plotted. If it is
            shorter than the data, error bars are drawn for the first
            ``len(diffusivity_errors)`` points only.
        \*\*kwargs:
            Any keyword args supported by matplotlib.pyplot.plot.

    Returns:
        A matplotlib.pyplot object. Do plt.show() to show the plot.
    """
    Ea, c = fit_arrhenius(temps, diffusivities)

    from pymatgen.util.plotting_utils import get_publication_quality_plot
    plt = get_publication_quality_plot(12, 8)

    # Arrhenius fit c * exp(-Ea / (k_B T)) evaluated at the input
    # temperatures (the diffusivity itself, not its logarithm).
    arr = c * np.exp(-Ea / (phyc.k_b / phyc.e * np.array(temps)))

    t_1 = 1000 / np.array(temps)

    plt.plot(t_1, diffusivities, 'ko', t_1, arr, 'k--', markersize=10,
             **kwargs)
    if diffusivity_errors is not None:
        n = len(diffusivity_errors)
        plt.errorbar(t_1[0:n], diffusivities[0:n], yerr=diffusivity_errors,
                     fmt='ko', ecolor='k', capthick=2, linewidth=2)
    # gca() always returns the axes that were just drawn on; the result of
    # a no-argument plt.axes() call has changed between matplotlib versions.
    ax = plt.gca()
    ax.set_yscale('log')
    # Position is given in axes coordinates (fractions of the axes box).
    plt.text(0.6, 0.85, "E$_a$ = {:.0f} meV".format(Ea * 1000),
             fontsize=30, transform=ax.transAxes)
    plt.ylabel("D (cm$^2$/s)")
    plt.xlabel("1000/T (K$^{-1}$)")
    plt.tight_layout()
    return plt
Ejemplo n.º 37
0
    def get_smoothed_msd_plot(self, plt=None):
        """
        Build the smoothed MSD versus time plot and return it.

        The overall curve (``self.s_msd``) is black and the three columns of
        ``self.s_msd_components`` are red, green and blue. Useful for
        checking convergence. This can be written to an image file.

        Args:
            plt: A plot object. Defaults to None, which means one will be
                generated.

        Returns:
            The plot object.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        canvas = get_publication_quality_plot(12, 8, plt=plt)
        canvas.plot(self.dt, self.s_msd, 'k')
        # Component columns 0, 1, 2 map to the colours r, g, b.
        for column, colour in enumerate('rgb'):
            canvas.plot(self.dt, self.s_msd_components[:, column], colour)
        canvas.legend(["Overall", "a", "b", "c"], loc=2, prop={"size": 20})
        canvas.xlabel("Timestep")
        canvas.ylabel("MSD")
        canvas.tight_layout()
        return canvas
Ejemplo n.º 38
0
    def get_smoothed_msd_plot(self, plt=None):
        """
        Return a plot of the smoothed mean square displacement.

        ``self.s_msd`` and the three columns of ``self.s_msd_components``
        are drawn against ``self.dt`` with the legend entries "Overall",
        "a", "b" and "c". Useful for checking convergence. This can be
        written to an image file.

        Args:
            plt: A plot object. Defaults to None, which means one will be
                generated.

        Returns:
            The plot object.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        figure = get_publication_quality_plot(12, 8, plt=plt)
        series = [(self.s_msd, 'k')]
        series.extend((self.s_msd_components[:, k], fmt)
                      for k, fmt in zip((0, 1, 2), ('r', 'g', 'b')))
        for values, fmt in series:
            figure.plot(self.dt, values, fmt)
        figure.legend(["Overall", "a", "b", "c"], loc=2, prop={"size": 20})
        figure.xlabel("Timestep")
        figure.ylabel("MSD")
        figure.tight_layout()
        return figure
    def get_1d_plot(self, type="distinct", times=[0.0], colors=["r", "g", "b"]):
        """
        Plot the van Hove function against r at the requested times.

        Args:
            type (str): Specify which part of van Hove function to be plotted;
                either "distinct" (uses ``self.gdrt``) or "self" (uses
                ``self.gsrt``).
            times (list of float): Time moments (in ps) in which the van Hove
                            function will be plotted.
            colors (list of str): Default list of colors for plotting. At
                            least one color per requested time is required.

        Both default lists are only read, never modified.

        Returns:
            The plot object obtained from ``get_publication_quality_plot``.

        Raises:
            AssertionError: if ``type`` is not one of the two accepted values
                or if more times than colors are given.
        """

        assert type in ["distinct", "self"]
        assert len(times) <= len(colors)

        if type == "distinct":
            grt = self.gdrt.copy()
            ylabel = "$G_d$($t$,$r$)"
            ylim = [-0.005, 4.0]
        else:
            # Only "self" can reach this branch because of the assert above.
            # Raw string: "\p" is not a valid Python escape sequence.
            grt = self.gsrt.copy()
            ylabel = r"4$\pi r^2G_s$($t$,$r$)"
            ylim = [-0.005, 1.0]

        plt = get_publication_quality_plot(12, 8)

        for i, time in enumerate(times):
            # Nearest stored row for the requested time, clamped to the
            # last available row.
            index = int(np.round(time / self.timeskip))
            index = min(index, np.shape(grt)[0] - 1)
            # The legend shows the time actually plotted.
            new_time = index * self.timeskip
            label = str(new_time) + " ps"
            plt.plot(self.interval, grt[index], color=colors[i], label=label,
                     linewidth=4.0)

        # Raw string: "\A" is not a valid Python escape sequence.
        plt.xlabel(r"$r$ ($\AA$)")
        plt.ylabel(ylabel)
        plt.legend(loc='upper right', fontsize=36)
        plt.xlim(0.0, self.interval[-1] - 1.0)
        plt.ylim(ylim[0], ylim[1])
        plt.tight_layout()

        return plt
Ejemplo n.º 40
0
def main():
    """
    Script entry point: plot the cross-section curves of a Feff calculation.

    Expects two positional command-line arguments, the xmu file and the
    feff.inp file. Both are handed to ``Xmu.from_file``; the "scross" and
    "across" series are then plotted against "energies" and displayed.
    """
    parser = argparse.ArgumentParser(description='''
    Convenient DOS Plotter for Feff runs.
    Author: Alan Dozier
    Version: 1.0
    Last updated: April, 2013''')

    for arg_name, arg_help in (('filename', 'xmu file to plot'),
                               ('filename1', 'feff.inp filename to import')):
        parser.add_argument(arg_name, metavar=arg_name, type=str, nargs=1,
                            help=arg_help)

    plt = get_publication_quality_plot(12, 8)
    color_order = ['r', 'b', 'g', 'c', 'k', 'm', 'y']

    args = parser.parse_args()
    # Each positional argument arrives as a one-element list (nargs=1).
    xmu = Xmu.from_file(args.filename[0], args.filename1[0])

    data = xmu.to_dict

    title_parts = [data['calc'], ' Feff9.6 Calculation for ', data['atom'],
                   ' in ', data['formula'], ' unit cell']
    plt.title(''.join(title_parts))
    plt.xlabel('Energies (eV)')
    plt.ylabel('Absorption Cross-section')

    xs = data['energies']

    ys = data['scross']
    label = ' '.join(['Single', data['atom'], data['edge'], 'edge'])
    plt.plot(xs, ys, color_order[1], label=label)

    ys = data['across']
    label = ' '.join([data['atom'], data['edge'], 'edge in', data['formula']])
    plt.plot(xs, ys, color_order[2], label=label)

    plt.legend()
    # Apply a smaller font to all text entries of the legend.
    plt.setp(plt.gca().get_legend().get_texts(), fontsize=15)
    plt.tight_layout()
    plt.show()
Ejemplo n.º 41
0
def plot_chgint(args):
    """
    Show the integrated charge difference around selected atoms of a CHGCAR.

    Args:
        args: Parsed command-line arguments. ``args.filename[0]`` is the
            file to read, ``args.inds`` optionally holds a comma-separated
            string of atom indices in its first element, and ``args.radius``
            is forwarded to ``get_integrated_diff``. Without indices, the
            first site of each group of symmetry-equivalent sites is used.
    """
    chgcar = Chgcar.from_file(args.filename[0])
    structure = chgcar.structure

    if not args.inds:
        finder = SymmetryFinder(structure, symprec=0.1)
        groups = finder.get_symmetrized_structure().equivalent_sites
        representatives = [group[0] for group in groups]
        atom_ind = [structure.sites.index(site) for site in representatives]
    else:
        atom_ind = map(int, args.inds[0].split(","))

    from pymatgen.util.plotting_utils import get_publication_quality_plot

    plt = get_publication_quality_plot(12, 8)
    for idx in atom_ind:
        curve = chgcar.get_integrated_diff(idx, args.radius, 30)
        legend_entry = "Atom {} - {}".format(idx,
                                             structure[idx].species_string)
        # Column 0 goes on the x axis, column 1 on the y axis.
        plt.plot(curve[:, 0], curve[:, 1], label=legend_entry)
    plt.legend(loc="upper left")
    plt.xlabel("Radius (A)")
    plt.ylabel("Integrated charge (e)")
    plt.tight_layout()
    plt.show()
Ejemplo n.º 42
0
 def plot_carriers_ef(self,
                      temp=300,
                      me=[1.0, 1.0, 1.0],
                      mh=[1.0, 1.0, 1.0]):
     """
     Draw ``self._analyzer.get_Qi`` as a function of the fermi energy on a
     logarithmic y axis, for 100 fermi energies from 0 in steps of 0.01.

     Args:
         temp:
             temperature
         me:
             the effective mass for the electrons as a list of 3 eigenvalues
         mh:
             the effective mass for the holes as a list of 3 eigenvalues
     Returns:
         whatever ``plt.semilogy`` returns for the plotted curve
     """
     plt = get_publication_quality_plot(12, 8)
     fermi_levels = [k * 0.01 for k in range(0, 100)]
     # NOTE(review): the 1e-6 factor looks like a unit conversion - confirm.
     scaled = [self._analyzer.get_Qi(level, temp, me, mh) * 1e-6
               for level in fermi_levels]
     plt.ylim([1e14, 1e22])
     return plt.semilogy(fermi_levels, scaled)
Ejemplo n.º 43
0
    def get_elt_projected_plots(self, zero_to_efermi=True, ylim=None):
        """
        Method returning a plot composed of subplots along different elements

        One subplot of a 2 x 2 grid is used per element of the structure.
        The projections are requested on the s, p and d orbitals and summed
        for each element; the result sets the marker size of every point.

        Args:
            zero_to_efermi: Passed to ``bs_plot_data``; also selects the
                y-range used for metals when ``ylim`` is None.
            ylim: The y-axis limits. If None, metals use a window of +/- 10
                and other materials span from VBM - 4 to CBM + 4, with the
                CBM points marked in red and the VBM points in green.

        Returns:
            a pylab object with different subfigures for each projection
            The blue and red colors are for spin up and spin down
            The bigger the red or blue dot in the band structure the higher
            character for the corresponding element and orbital
        """
        band_linewidth = 1.0
        elements = self._bs._structure.composition.elements
        proj = self._get_projections_by_branches(
            {e.symbol: ['s', 'p', 'd'] for e in elements})
        # Parenthesised form is valid on both Python 2 and Python 3.
        print(proj)
        data = self.bs_plot_data(zero_to_efermi)
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(12, 8)
        is_metal = self._bs.is_metal()
        e_min, e_max = (-10, 10) if is_metal else (-4, 4)
        up = str(Spin.up)
        down = str(Spin.down)

        for count, el in enumerate(elements, start=1):
            # 220 + count addresses cell `count` of a 2 x 2 grid.
            plt.subplot(220 + count)
            self._maketicks(plt)
            el_key = str(el)
            for b, distances in enumerate(data['distances']):
                energies = data['energy'][b]
                n_points = len(distances)
                for i in range(self._nb_bands):
                    up_band = energies[up][i]
                    plt.plot(distances,
                             [up_band[j] for j in range(n_points)], 'b-',
                             linewidth=band_linewidth)
                    if self._bs.is_spin_polarized:
                        down_band = energies[down][i]
                        plt.plot(distances,
                                 [down_band[j] for j in range(n_points)],
                                 'r--', linewidth=band_linewidth)
                        # The point count is taken from the spin-up band,
                        # as in the spin-up loop below.
                        for j in range(len(up_band)):
                            weights = proj[b][down][i][j][el_key]
                            plt.plot(distances[j], down_band[j], 'ro',
                                     markersize=sum(
                                         [weights[o] for o in weights]) * 15.0)
                    for j in range(len(up_band)):
                        weights = proj[b][up][i][j][el_key]
                        plt.plot(distances[j], up_band[j], 'bo',
                                 markersize=sum(
                                     [weights[o] for o in weights]) * 15.0)
            if ylim is None:
                if is_metal:
                    if zero_to_efermi:
                        plt.ylim(e_min, e_max)
                    else:
                        # NOTE(review): mixes `efermi` and `_efermi`;
                        # confirm both attributes exist on the band
                        # structure object.
                        plt.ylim(self._bs.efermi + e_min, self._bs._efermi
                                 + e_max)
                else:

                    for cbm in data['cbm']:
                        plt.scatter(cbm[0], cbm[1], color='r', marker='o',
                                    s=100)

                    for vbm in data['vbm']:
                        plt.scatter(vbm[0], vbm[1], color='g', marker='o',
                                    s=100)

                    plt.ylim(data['vbm'][0][1] + e_min, data['cbm'][0][1]
                             + e_max)
            else:
                plt.ylim(ylim)
            plt.title(str(el))

        return plt
Ejemplo n.º 44
0
    def get_projected_plots_dots(self, dictio, zero_to_efermi=True, ylim=None):
        """
        Method returning a plot composed of subplots along different elements
        and orbitals.

        The subplots are laid out on a grid with two columns and as many
        rows as needed. Each (element, orbital) pair is printed as its
        subplot is drawn.

        Args:
            dictio: The element and orbitals you want a projection on. The
                format is {Element:[Orbitals]} for instance
                {'Cu':['d','s'],'O':['p']} will give projections for Cu on
                d and s orbitals and on oxygen p.
            zero_to_efermi: Passed to ``bs_plot_data``; also selects the
                y-range used for metals when ``ylim`` is None.
            ylim: The y-axis limits. If None, metals use a window of +/- 10
                and other materials span from VBM - 4 to CBM + 4, with the
                CBM points marked in red and the VBM points in green.

        Returns:
            a pylab object with different subfigures for each projection
            The blue and red colors are for spin up and spin down.
            The bigger the red or blue dot in the band structure the higher
            character for the corresponding element and orbital.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        band_linewidth = 1.0
        fig_number = sum([len(v) for v in dictio.values()])
        proj = self._get_projections_by_branches(dictio)
        data = self.bs_plot_data(zero_to_efermi)
        plt = get_publication_quality_plot(12, 8)
        is_metal = self._bs.is_metal()
        e_min, e_max = (-10, 10) if is_metal else (-4, 4)
        # Integer ceiling of fig_number / 2. Unlike math.ceil(fig_number / 2)
        # this also rounds up under Python 2 integer division.
        n_rows = (fig_number + 1) // 2
        up = str(Spin.up)
        down = str(Spin.down)
        count = 1

        for el in dictio:
            el_key = str(el)
            for o in dictio[el]:
                # Same output as the Python 2 statement `print el, o`.
                print("{} {}".format(el, o))
                # The three-argument form has no limit of nine subplots.
                plt.subplot(n_rows, 2, count)
                self._maketicks(plt)
                for b, distances in enumerate(data['distances']):
                    energies = data['energy'][b]
                    n_points = len(distances)
                    for i in range(self._nb_bands):
                        up_band = energies[up][i]
                        plt.plot(distances,
                                 [up_band[j] for j in range(n_points)], 'b-',
                                 linewidth=band_linewidth)
                        if self._bs.is_spin_polarized:
                            down_band = energies[down][i]
                            plt.plot(distances,
                                     [down_band[j] for j in range(n_points)],
                                     'r--', linewidth=band_linewidth)
                            # The point count is taken from the spin-up
                            # band, as in the spin-up loop below.
                            for j in range(len(up_band)):
                                plt.plot(distances[j], down_band[j], 'ro',
                                         markersize=proj[b][down][i][j][el_key][o] * 15.0)
                        for j in range(len(up_band)):
                            plt.plot(distances[j], up_band[j], 'bo',
                                     markersize=proj[b][up][i][j][el_key][o] * 15.0)
                if ylim is None:
                    if is_metal:
                        if zero_to_efermi:
                            plt.ylim(e_min, e_max)
                        else:
                            # NOTE(review): mixes `efermi` and `_efermi`;
                            # confirm both attributes exist on the band
                            # structure object.
                            plt.ylim(self._bs.efermi + e_min, self._bs._efermi
                                     + e_max)
                    else:

                        for cbm in data['cbm']:
                            plt.scatter(cbm[0], cbm[1], color='r', marker='o',
                                        s=100)

                        for vbm in data['vbm']:
                            plt.scatter(vbm[0], vbm[1], color='g', marker='o',
                                        s=100)

                        plt.ylim(data['vbm'][0][1] + e_min, data['cbm'][0][1]
                                 + e_max)
                else:
                    plt.ylim(ylim)
                plt.title(str(el) + " " + str(o))
                count += 1
        return plt
Ejemplo n.º 45
0
    def get_plot(self, zero_to_efermi=True, ylim=None, smooth=False):
        """
        Get a matplotlib object for the bandstructure plot.
        Blue lines are up spin, red lines are down
        spin.

        Args:
            zero_to_efermi: Automatically subtract off the Fermi energy from
                the eigenvalues and plot (E-Ef).
            ylim: Specify the y-axis (energy) limits; by default None let
                the code choose. It is vbm-4 and cbm+4 if insulator
                efermi-10 and efermi+10 if metal
            smooth: interpolates the bands by a spline (scipy ``splrep`` /
                ``splev``) sampled at 1000 points per branch.

        Returns:
            The plot object obtained from get_publication_quality_plot, with
            bands, ticks, axis labels and limits applied.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(12, 8)
        from matplotlib import rc
        import scipy.interpolate as scint

        rc('text', usetex=True)

        # Main internal config options: offsets added to the reference
        # energies for the automatic y-window, and the band line width.
        is_metal = self._bs.is_metal()
        if is_metal:
            e_min, e_max = -10, 10
        else:
            e_min, e_max = -4, 4
        band_linewidth = 3
        n_samples = 1000

        # One (data key, line style) pair per spin channel to draw.
        spin_styles = [(str(Spin.up), 'b-')]
        if self._bs.is_spin_polarized:
            spin_styles.append((str(Spin.down), 'r--'))

        def _band_curve(distances, band):
            # Return the (x, y) series to draw for one band on one branch.
            energies = [band[j] for j in range(len(distances))]
            if not smooth:
                return distances, energies
            tck = scint.splrep(distances, energies)
            start = distances[0]
            step = (distances[-1] - start) / n_samples
            # Samples stop one step short of the branch end, as range()
            # excludes n_samples.
            xs = [k * step + start for k in range(n_samples)]
            return xs, scint.splev(xs, tck, der=0)

        data = self.bs_plot_data(zero_to_efermi)
        for d, distances in enumerate(data['distances']):
            for i in range(self._nb_bands):
                for spin_key, style in spin_styles:
                    xs, ys = _band_curve(distances,
                                         data['energy'][d][spin_key][i])
                    plt.plot(xs, ys, style, linewidth=band_linewidth)
        self._maketicks(plt)

        # Main X and Y Labels
        plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
        ylabel = r'$\mathrm{E\ -\ E_f\ (eV)}$' if zero_to_efermi \
            else r'$\mathrm{Energy\ (eV)}$'
        plt.ylabel(ylabel, fontsize=30)

        # Draw Fermi energy, only if not the zero
        if not zero_to_efermi:
            plt.axhline(self._bs.efermi, linewidth=2, color='k')

        # X range (K): from 0 to the last distance point of the last branch
        plt.xlim(0, data['distances'][-1][-1])

        if ylim is not None:
            plt.ylim(ylim)
        elif is_metal:
            # Plot A Metal: window centred on the Fermi level
            if zero_to_efermi:
                plt.ylim(e_min, e_max)
            else:
                # Use the public accessor for both bounds.
                efermi = self._bs.efermi
                plt.ylim(efermi + e_min, efermi + e_max)
        else:
            # Mark the band edges and frame the window around them.
            for cbm in data['cbm']:
                plt.scatter(cbm[0], cbm[1], color='r', marker='o', s=100)

            for vbm in data['vbm']:
                plt.scatter(vbm[0], vbm[1], color='g', marker='o', s=100)
            plt.ylim(data['vbm'][0][1] + e_min, data['cbm'][0][1] + e_max)

        plt.tight_layout()

        return plt
Ejemplo n.º 46
0
    def get_plot(self, xlim=None, ylim=None):
        """
        Get a matplotlib plot showing the DOS.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits. Set to None to fit the y-axis
                to the densities whose energies fall strictly inside the
                x-axis limits.

        Returns:
            The plot object obtained from get_publication_quality_plot.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(12, 8)
        color_order = ['r', 'b', 'g', 'c', 'm', 'k']

        running_total = None
        alldensities = []
        allenergies = []
        # Note that this complicated processing of energies is to allow for
        # stacked plots in matplotlib.
        for dos in self._doses.values():
            energies = dos['energies']
            densities = dos['densities']
            if running_total is None:
                running_total = {Spin.up: np.zeros(energies.shape),
                                 Spin.down: np.zeros(energies.shape)}
            newdens = {}
            for spin in [Spin.up, Spin.down]:
                if spin in densities:
                    if self.stack:
                        running_total[spin] += densities[spin]
                        newdens[spin] = running_total[spin].copy()
                    else:
                        newdens[spin] = densities[spin]
            allenergies.append(energies)
            alldensities.append(newdens)

        # Draw in reverse insertion order.
        keys = list(self._doses.keys())
        keys.reverse()
        alldensities.reverse()
        allenergies.reverse()
        allpts = []
        for i, key in enumerate(keys):
            color = color_order[i % len(color_order)]
            xs = []
            ys = []
            for spin in [Spin.up, Spin.down]:
                if spin in alldensities[i]:
                    # int(spin) scales the densities by the spin's integer
                    # value before plotting.
                    densities = list(int(spin) * alldensities[i][spin])
                    energies = list(allenergies[i])
                    if spin == Spin.down:
                        # Reversed so up + down form one continuous outline.
                        energies.reverse()
                        densities.reverse()
                    xs.extend(energies)
                    ys.extend(densities)
            allpts.extend(zip(xs, ys))
            if self.stack:
                plt.fill(xs, ys, color=color, label=str(key))
            else:
                plt.plot(xs, ys, color=color, label=str(key), linewidth=3)
            if not self.zero_at_efermi:
                # Vertical marker at this DOS's Fermi level, in the same
                # colour as its curve. The current axis limits go in a local
                # so the caller's ``ylim`` argument is not overwritten.
                efermi = self._doses[key]['efermi']
                current_ylim = plt.ylim()
                plt.plot([efermi, efermi], current_ylim,
                         color + '--', linewidth=2)

        plt.xlabel('Energies (eV)')
        plt.ylabel('Density of states')
        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)
        else:
            xlim = plt.xlim()
            relevanty = [p[1] for p in allpts
                         if xlim[0] < p[0] < xlim[1]]
            # With no point inside the x-window keep matplotlib's limits
            # instead of failing on min()/max() of an empty list.
            if relevanty:
                plt.ylim((min(relevanty), max(relevanty)))

        if self.zero_at_efermi:
            # Energies are taken relative to the Fermi level: mark E = 0.
            plt.plot([0, 0], plt.ylim(), 'k--', linewidth=2)

        plt.legend()
        leg = plt.gca().get_legend()
        ltext = leg.get_texts()  # all the text.Text instance in the legend
        plt.setp(ltext, fontsize=30)
        plt.tight_layout()
        return plt
Ejemplo n.º 47
0
    def get_plot(self, xlim=None, ylim=None):
        """
        Build a matplotlib plot of the stored densities of states against
        frequency.

        Args:
            xlim: x-axis limits; None keeps matplotlib's automatic limits.
            ylim: y-axis limits; None fits the axis to the densities whose
                frequencies lie strictly inside the x-axis limits.

        Returns:
            The plot object obtained from get_publication_quality_plot.
        """
        import prettyplotlib as ppl
        from prettyplotlib import brewer2mpl

        # Palette size is clamped to the 3..9 range.
        ncolors = min(9, max(3, len(self._doses)))
        colors = brewer2mpl.get_map('Set1', 'qualitative', ncolors).mpl_colors

        plt = get_publication_quality_plot(12, 8)

        # Accumulate a running total so that, when stacking, each curve sits
        # on top of the ones added before it.
        running_total = None
        curves = []
        for key, dos in self._doses.items():
            freqs = dos['frequencies']
            dens = dos['densities']
            if running_total is None:
                running_total = np.zeros(freqs.shape)
            if self.stack:
                running_total += dens
                dens = running_total.copy()
            curves.append((key, freqs, dens))

        # Draw in reverse insertion order.
        curves.reverse()
        allpts = []
        for idx, (key, freqs, dens) in enumerate(curves):
            allpts.extend(list(zip(freqs, dens)))
            shade = colors[idx % ncolors]
            if self.stack:
                plt.fill(freqs, dens, color=shade, label=str(key))
            else:
                ppl.plot(freqs, dens, color=shade, label=str(key),
                         linewidth=3)

        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)
        else:
            lo, hi = plt.xlim()
            relevanty = [pt[1] for pt in allpts if lo < pt[0] < hi]
            plt.ylim((min(relevanty), max(relevanty)))

        # Vertical dashed marker at zero frequency spanning the y-range.
        plt.plot([0, 0], plt.ylim(), 'k--', linewidth=2)

        plt.xlabel('Frequencies (THz)')
        plt.ylabel('Density of states')

        plt.legend()
        legend_texts = plt.gca().get_legend().get_texts()
        plt.setp(legend_texts, fontsize=30)
        plt.tight_layout()
        return plt
Ejemplo n.º 48
0
    def get_chempot_range_map_plot(self, elements):
        """
        Returns a plot of the chemical potential range map. Currently works
        only for 3-component PDs.

        Args:
            elements:
                Sequence of elements to be considered as independent variables.
                E.g., if you want to show the stability ranges of all Li-Co-O
                phases wrt to uLi and uO, you will supply
                [Element("Li"), Element("O")]
        Returns:
            A matplotlib plot object.
        """

        plt = get_publication_quality_plot(12, 8)
        analyzer = PDAnalyzer(self._pd)
        chempot_ranges = analyzer.get_chempot_range_map(elements)
        missing_lines = {}
        excluded_region = []
        for entry, lines in chempot_ranges.items():
            comp = entry.composition
            center_x = 0
            center_y = 0
            coords = []
            fractions = [comp.get_atomic_fraction(el) for el in elements]
            # True when the entry lacks at least one of the chosen elements.
            contain_zero = any(frac == 0 for frac in fractions)
            # True when the entry is made up of the chosen elements only.
            is_boundary = (not contain_zero) and sum(fractions) == 1
            for line in lines:
                (x, y) = line.coords.transpose()
                plt.plot(x, y, "k-")

                # Collect each distinct vertex once for the label centroid.
                for coord in line.coords:
                    if not in_coord_list(coords, coord):
                        coords.append(coord.tolist())
                        center_x += coord[0]
                        center_y += coord[1]
                if is_boundary:
                    excluded_region.extend(line.coords)

            if coords and contain_zero:
                # Labelled later, once the missing edges have been drawn.
                missing_lines[entry] = coords
            elif coords:
                xy = (center_x / len(coords), center_y / len(coords))
                plt.annotate(latexify(entry.name), xy, fontsize=22)
            # An entry without any vertex has no centroid; nothing to label.

        ax = plt.gca()
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        # Shade the forbidden chemical potential regions.
        excluded_region.append([xlim[1], ylim[1]])
        excluded_region = sorted(excluded_region, key=lambda c: c[0])
        (x, y) = np.transpose(excluded_region)
        plt.fill(x, y, "0.80")

        # The hull does not generate the missing horizontal and vertical lines.
        # The following code fixes this.
        el0 = elements[0]
        el1 = elements[1]
        for entry, coords in missing_lines.items():
            center_x = sum([c[0] for c in coords])
            center_y = sum([c[1] for c in coords])
            comp = entry.composition
            is_x = comp.get_atomic_fraction(el0) < 0.01
            is_y = comp.get_atomic_fraction(el1) < 0.01
            n = len(coords)
            if not (is_x and is_y):
                if is_x:
                    # Extend the lowest and highest vertex horizontally to
                    # the left edge of the axes.
                    coords = sorted(coords, key=lambda c: c[1])
                    for i in [0, -1]:
                        x = [min(xlim), coords[i][0]]
                        y = [coords[i][1], coords[i][1]]
                        plt.plot(x, y, "k")
                        center_x += min(xlim)
                        center_y += coords[i][1]
                elif is_y:
                    # Extend the leftmost and rightmost vertex vertically to
                    # the bottom edge of the axes.
                    coords = sorted(coords, key=lambda c: c[0])
                    for i in [0, -1]:
                        x = [coords[i][0], coords[i][0]]
                        y = [coords[i][1], min(ylim)]
                        plt.plot(x, y, "k")
                        center_x += coords[i][0]
                        center_y += min(ylim)
                xy = (center_x / (n + 2), center_y / (n + 2))
            else:
                # Neither chosen element is present: average the vertices
                # together with the (xlim[0], ylim[0]) corner.
                center_x += xlim[0]
                center_y += ylim[0]
                xy = (center_x / (n + 1), center_y / (n + 1))

            plt.annotate(
                latexify(entry.name), xy, horizontalalignment="center", verticalalignment="center", fontsize=22
            )

        # Raw strings: "\m" is not a valid escape sequence in a normal literal.
        plt.xlabel(r"$\mu_{{{0}}} - \mu_{{{0}}}^0$ (eV)".format(el0.symbol))
        plt.ylabel(r"$\mu_{{{0}}} - \mu_{{{0}}}^0$ (eV)".format(el1.symbol))
        plt.tight_layout()
        return plt
Ejemplo n.º 49
0
import argparse

from pymatgen.io.feffio import Xmu
from pymatgen.util.plotting_utils import get_publication_quality_plot

# Command-line script: reads a FEFF xmu file together with its feff.inp and
# prepares a plot of the data (title, axis labels and data series).
# NOTE(review): the description below says "DOS Plotter" while the axis labels
# describe an absorption cross-section - confirm which wording is intended.
parser = argparse.ArgumentParser(description='''Convenient DOS Plotter for Feff runs.
Author: Alan Dozier
Version: 1.0
Last updated: August, 2012''')

# Both positionals use nargs=1, so argparse stores each one as a one-element
# list; hence the [0] indexing further down.
parser.add_argument('filename', metavar='filename', type=str, nargs=1,
                    help='xmu file to plot')
parser.add_argument('filename1', metavar='filename1', type=str, nargs=1,
                    help='feff.inp filename to import')

plt = get_publication_quality_plot(12, 8)
color_order = ['r', 'b', 'g', 'c', 'k', 'm', 'y']

args = parser.parse_args()
xmu = Xmu.from_file(args.filename[0], args.filename1[0])

# to_dict is read as an attribute here, not called.
data = xmu.to_dict

plt.title(data['calc'] + ' Feff9.6 Calculation for ' + data['atom'] + ' in ' +
          data['formula'] + ' unit cell')
plt.xlabel('Energies (eV)')
plt.ylabel('Absorption Cross-section')

# Series to plot and the label text describing it.
x = data['energies']
y = data['scross']
tle = 'Single ' + data['atom'] + ' ' + data['edge'] + ' edge'
Ejemplo n.º 50
0
    def get_pourbaix_plot_colorfill_by_domain_name(self,
                                                   limits=None,
                                                   title="",
                                                   label_domains=True,
                                                   label_color='k',
                                                   domain_color=None,
                                                   domain_fontsize=None,
                                                   domain_edge_lw=0.5,
                                                   bold_domains=None,
                                                   cluster_domains=(),
                                                   add_h2o_stablity_line=True,
                                                   add_center_line=False,
                                                   h2o_lw=0.5):
        """
        Color domains by the colors specified in the domain_color dict.

        Args:
            limits: 2D list containing limits of the Pourbaix diagram
                of the form [[xlo, xhi], [ylo, yhi]]. If not given, the
                first two entries of self._analyzer.chempot_limits are used.
            title (str): Title of the plot.
            label_domains (bool): whether to add the text label for domains
            label_color (str): color of domain labels, defaults to black
            domain_color (dict): colors of each domain e.g {"Al(s)": "#FF1100"}. If set
                to None default color set will be used.
            domain_fontsize (int or dict): Font size used in domain text
                labels. Either a single size applied to every domain or a
                dict mapping each domain name to its size. Defaults to 12.
            domain_edge_lw (float): line width for the boundaries between domains.
            bold_domains (list): List of domain names to use bold text style for domain
                labels. If None, every domain whose name does not contain
                "(s)" is bold.
            cluster_domains (list): List of domain names in cluster phase
            add_h2o_stablity_line (bool): whether to plot the H2O stability lines
            add_center_line (bool): whether to plot the pH = 7 and E = 0 lines
            h2o_lw (float): line width for H2O stability line and center lines

        Returns:
            The plot object obtained from get_publication_quality_plot.
        """
        from matplotlib.patches import Polygon
        from pymatgen import Composition, Element
        from pymatgen.core.ion import Ion

        # helper functions
        def len_elts(entry):
            # Number of elements other than H and O in a domain name.
            comp = Composition(
                entry[:-3]) if "(s)" in entry else Ion.from_formula(entry)
            return len(set(comp.elements) - {Element("H"), Element("O")})

        def special_lines(xlim, ylim):
            # Reference lines, each transposed to (xs, ys) for plotting.
            h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                                   [xlim[1], -xlim[1] * PREFAC]])
            o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                                   [xlim[1], -xlim[1] * PREFAC + 1.23]])
            neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
            V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])
            return h_line, o_line, neutral_line, V0_line

        default_domain_font_size = 12
        default_solid_phase_color = '#b8f9e7'  # this slightly darker than the MP scheme, to
        default_cluster_phase_color = '#d0fbef'  # avoid making the cluster phase too light

        plt = get_publication_quality_plot(8, dpi=300)

        stable, _ = self.pourbaix_plot_data(limits)
        # Group the stable entries by domain name; a MultiEntry is listed
        # under the name of each of its sub-entries.
        entry_dict_of_multientries = collections.defaultdict(list)
        for entry in stable:
            if isinstance(entry, MultiEntry):
                for e in entry.entrylist:
                    entry_dict_of_multientries[e.name].append(entry)
            else:
                entry_dict_of_multientries[entry.name].append(entry)

        xlim, ylim = limits[:2] if limits else self._analyzer.chempot_limits[:2]
        h_line, o_line, neutral_line, V0_line = special_lines(xlim, ylim)
        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        ax.tick_params(direction='out')
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')

        sorted_entry = sorted(entry_dict_of_multientries, key=len_elts)

        if domain_fontsize is None:
            domain_fontsize = default_domain_font_size
        if not isinstance(domain_fontsize, dict):
            # A single size applies to every domain.
            domain_fontsize = {en: domain_fontsize for en in sorted_entry}
        if domain_color is None:
            domain_color = {
                en: default_solid_phase_color if '(s)' in en else
                (default_cluster_phase_color if en in cluster_domains else 'w')
                for en in sorted_entry
            }
        if bold_domains is None:
            bold_domains = [en for en in sorted_entry if '(s)' not in en]

        for entry in sorted_entry:
            x_coord, y_coord, npts = 0.0, 0.0, 0
            for e in entry_dict_of_multientries[entry]:
                xy = self.domain_vertices(e)
                if add_h2o_stablity_line:
                    c = self.get_distribution_corrected_center(
                        stable[e], h_line, o_line, 0.3)
                else:
                    c = self.get_distribution_corrected_center(stable[e])
                x_coord += c[0]
                y_coord += c[1]
                npts += 1
                patch = Polygon(xy,
                                facecolor=domain_color[entry],
                                closed=True,
                                lw=domain_edge_lw,
                                fill=True,
                                antialiased=True)
                ax.add_patch(patch)
            # Label position: mean of the centers of all regions of the domain.
            xy_center = (x_coord / npts, y_coord / npts)
            if label_domains:
                is_bold = entry in bold_domains
                if platform.system() == 'Darwin':
                    # Have to hack to the hard coded font path to get current font On Mac OS X
                    font = FontProperties(
                        fname=('/Library/Fonts/Times New Roman Bold.ttf'
                               if is_bold else
                               '/Library/Fonts/Times New Roman.ttf'),
                        size=domain_fontsize[entry])
                else:
                    font = FontProperties(
                        family='Times New Roman',
                        weight='bold' if is_bold else 'regular',
                        size=domain_fontsize[entry])
                plt.text(*xy_center,
                         s=latexify_ion(latexify(entry)),
                         fontproperties=font,
                         horizontalalignment="center",
                         verticalalignment="center",
                         multialignment="center",
                         color=label_color)

        if add_h2o_stablity_line:
            dashes = (3, 1.5)
            for ref_line in (h_line, o_line):
                line, = plt.plot(ref_line[0],
                                 ref_line[1],
                                 "k--",
                                 linewidth=h2o_lw,
                                 antialiased=True)
                line.set_dashes(dashes)
        if add_center_line:
            for ref_line in (neutral_line, V0_line):
                plt.plot(ref_line[0],
                         ref_line[1],
                         "k-.",
                         linewidth=h2o_lw,
                         antialiased=False)

        plt.xlabel("pH", fontname="Times New Roman", fontsize=18)
        plt.ylabel("E (V)", fontname="Times New Roman", fontsize=18)
        plt.xticks(fontname="Times New Roman", fontsize=16)
        plt.yticks(fontname="Times New Roman", fontsize=16)
        plt.title(title,
                  fontsize=20,
                  fontweight='bold',
                  fontname="Times New Roman")
        return plt
Ejemplo n.º 51
0
    def get_pourbaix_plot_colorfill_by_element(self,
                                               limits=None,
                                               title="",
                                               label_domains=True,
                                               element=None):
        """
        Color domains by element.

        Domains whose name contains exactly one element other than H and O
        are drawn as solid-filled polygons; all other domains are drawn as
        hatched, unfilled polygons.

        Args:
            limits: 2D list containing limits of the Pourbaix diagram
                of the form [[xlo, xhi], [ylo, yhi]]. If not given,
                self._analyzer.chempot_limits is used.
            title (str): Title of the plot.
            label_domains (bool): whether to annotate each domain with its
                name.
            element: for a MultiEntry, only the sub-entries whose composition
                contains this element are used to group the domains.

        Returns:
            The plot object obtained from get_publication_quality_plot.
        """
        from matplotlib.patches import Polygon
        from pymatgen import Composition, Element
        from pymatgen.core.ion import Ion

        entry_dict_of_multientries = collections.defaultdict(list)
        plt = get_publication_quality_plot(16)
        optim_colors = [
            '#0000FF', '#FF0000', '#00FF00', '#FFFF00', '#FF00FF', '#FF8080',
            '#DCDCDC', '#800000', '#FF8000'
        ]
        optim_font_color = [
            '#FFFFA0', '#00FFFF', '#FF00FF', '#0000FF', '#00FF00', '#007F7F',
            '#232323', '#7FFFFF', '#007FFF'
        ]
        hatch = ['/', '\\', '|', '-', '+', 'o', '*']
        stable, _ = self.pourbaix_plot_data(limits)
        for entry in stable:
            if isinstance(entry, MultiEntry):
                for e in entry.entrylist:
                    if element in e.composition.elements:
                        entry_dict_of_multientries[e.name].append(entry)
            else:
                entry_dict_of_multientries[entry.name].append(entry)
        if limits:
            xlim = limits[0]
            ylim = limits[1]
        else:
            xlim = self._analyzer.chempot_limits[0]
            ylim = self._analyzer.chempot_limits[1]

        # Reference lines, each transposed to (xs, ys) for plotting.
        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        def len_elts(entry):
            # Number of elements other than H and O in a domain name.
            if "(s)" in entry:
                comp = Composition(entry[:-3])
            else:
                comp = Ion.from_formula(entry)
            return len([
                el for el in comp.elements
                if el not in [Element("H"), Element("O")]
            ])

        # sorted() works on any iterable; dict views have no .sort() on
        # Python 3.
        sorted_entry = sorted(entry_dict_of_multientries, key=len_elts)
        for i, entry in enumerate(sorted_entry):
            single_element = len_elts(entry) == 1
            # Cycle through the palette (solid) or the hatch list (hatched).
            if single_element:
                color_indx = i % len(optim_colors)
            else:
                color_indx = i % len(hatch)

            x_coord = 0.0
            y_coord = 0.0
            npts = 0
            for e in entry_dict_of_multientries[entry]:
                xy = self.domain_vertices(e)
                c = self.get_center(stable[e])
                x_coord += c[0]
                y_coord += c[1]
                npts += 1
                if single_element:
                    patch = Polygon(xy,
                                    facecolor=optim_colors[color_indx],
                                    closed=True,
                                    lw=3.0,
                                    fill=True)
                else:
                    patch = Polygon(xy,
                                    hatch=hatch[color_indx],
                                    closed=True,
                                    lw=3.0,
                                    fill=False)
                ax.add_patch(patch)

            # Label position: mean of the centers of all regions of the domain.
            xy_center = (x_coord / npts, y_coord / npts)
            if label_domains:
                if single_element:
                    # Text in the domain colour on a contrasting box.
                    text_color = optim_colors[color_indx]
                    bbox = dict(boxstyle="round",
                                fc=optim_font_color[color_indx])
                else:
                    # Black text on a white, hatched box.
                    text_color = 'k'
                    bbox = dict(boxstyle="round",
                                hatch=hatch[color_indx],
                                fc='w')
                plt.annotate(latexify_ion(latexify(entry)),
                             xy_center,
                             color=text_color,
                             fontsize=30,
                             bbox=bbox)

        lw = 3
        plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
        plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
Ejemplo n.º 52
0
    def get_pourbaix_plot(self, limits=None, title="", label_domains=True):
        """
        Draw the Pourbaix diagram: reference lines, domain boundaries and,
        optionally, a name label per domain.

        Args:
            limits: 2D list containing limits of the Pourbaix diagram
                of the form [[xlo, xhi], [ylo, yhi]]. If not given,
                self._analyzer.chempot_limits is used.
            title: Title of the plot.
            label_domains: whether to annotate each domain with its name.

        Returns:
            plt:
                matplotlib plot object
        """
        plt = get_publication_quality_plot(16)
        stable, _unstable = self.pourbaix_plot_data(limits)

        bounds = limits if limits else self._analyzer.chempot_limits
        xlim = bounds[0]
        ylim = bounds[1]

        # Reference lines, each transposed to (xs, ys), paired with the
        # matplotlib format used to draw them.
        reference_lines = [
            (np.transpose([[xlim[0], -xlim[0] * PREFAC],
                           [xlim[1], -xlim[1] * PREFAC]]), "r--"),
            (np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                           [xlim[1], -xlim[1] * PREFAC + 1.23]]), "r--"),
            (np.transpose([[7, ylim[0]], [7, ylim[1]]]), "k-."),
            (np.transpose([[xlim[0], 0], [xlim[1], 0]]), "k-."),
        ]

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        lw = 3
        for ref, fmt in reference_lines:
            plt.plot(ref[0], ref[1], fmt, linewidth=lw)

        for entry, lines in stable.items():
            vertices = []
            sum_x = 0.0
            sum_y = 0.0
            for line in lines:
                (x, y) = line
                plt.plot(x, y, "k-", linewidth=lw)
                # Each distinct vertex counts once towards the label centre.
                for pt in np.array(line).T:
                    if in_coord_list(vertices, pt):
                        continue
                    vertices.append(pt.tolist())
                    sum_x += pt[0]
                    sum_y += pt[1]

            # With no vertex the divisor falls back to 1, leaving (0, 0).
            n_vertices = float(len(vertices)) or 1.0
            center_x = sum_x / n_vertices
            center_y = sum_y / n_vertices

            # No label when the centre is on or outside the plot limits.
            off_plot = (center_x <= xlim[0] or center_x >= xlim[1] or
                        center_y <= ylim[0] or center_y >= ylim[1])
            if off_plot or not label_domains:
                continue
            plt.annotate(self.print_name(entry),
                         (center_x, center_y),
                         fontsize=20,
                         color="b")

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
Ejemplo n.º 53
0
    def get_elt_projected_plots_color(self, zero_to_efermi=True,
                                      elt_ordered=None):
        """
        Return a pylab plot object with one plot where the band structure
        line color depends on the character of the band (along different
        elements). Each element is associated with red, green or blue
        and the corresponding rgb color depending on the character of the band
        is used. The method can only deal with compounds containing at most
        three elements.

        Spin up and spin down are differentiated by a '-' and a '--' line.

        Args:
            zero_to_efermi: Forwarded unchanged to ``self.bs_plot_data``.
            elt_ordered: A list of Element ordered. The first one is red,
                second green, last blue. With two elements the first is red
                and the second blue. Defaults to the elements of the
                structure's composition.

        Returns:
            a pylab object

        Raises:
            ValueError: If the structure contains more than three elements.
        """
        band_linewidth = 3.0
        elements = self._bs._structure.composition.elements
        if len(elements) > 3:
            raise ValueError(
                "Element-projected color plots support at most 3 elements, "
                "got %d" % len(elements))
        if elt_ordered is None:
            elt_ordered = elements
        proj = self._get_projections_by_branches(
            {e.symbol: ['s', 'p', 'd'] for e in elements})
        data = self.bs_plot_data(zero_to_efermi)
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(12, 8)

        spins = [Spin.up]
        if self._bs.is_spin_polarized:
            spins = [Spin.up, Spin.down]
        self._maketicks(plt)
        for s in spins:
            spin_key = str(s)
            sign = '--' if s == Spin.down else '-'
            for b, distances in enumerate(data['distances']):
                for i in range(self._nb_bands):
                    energies = data['energy'][b][spin_key][i]
                    band_proj = proj[b][spin_key][i]
                    # Each consecutive pair of points is drawn as its own
                    # segment so that it can carry its own color.
                    for j in range(len(energies) - 1):
                        # Summed orbital weight of every element at point j,
                        # computed once and reused for the normalization.
                        weights = []
                        for el in elt_ordered:
                            el_proj = band_proj[j][str(el)]
                            weights.append(sum([el_proj[o] for o in el_proj]))
                        total = sum(weights)
                        if total == 0.0:
                            color = [0.0] * len(weights)
                        else:
                            color = [w / total for w in weights]
                        if len(color) == 2:
                            # Two elements: red and blue, green left at 0.
                            color = [color[0], 0.0, color[1]]
                        else:
                            # Pad to a full RGB triple; a shorter list is
                            # not a valid matplotlib color.
                            color.extend([0.0] * (3 - len(color)))
                        plt.plot([distances[j], distances[j + 1]],
                                 [energies[j], energies[j + 1]], sign,
                                 color=color, linewidth=band_linewidth)

        plt.ylim(data['vbm'][0][1] - 4.0, data['cbm'][0][1] + 2.0)
        return plt
Ejemplo n.º 54
0
    def get_pourbaix_mark_passive(self, limits=None, title="", label_domains=True, passive_entry=None):
        """
        Plot a Pourbaix diagram in which solid domains are color-filled.

        A stable domain is filled with red ('#FF0000') when its name contains
        ``passive_entry + "(s)"``, and with blue ('#0000FF') when it is any
        other solid (name contains "(s)") whose elements include O or H.
        All remaining domains are drawn unfilled.

        Args:
            limits: 2D list of the form [[xlo, xhi], [ylo, yhi]]. When not
                given, self._analyzer.chempot_limits is used.
            title: Title of the plot.
            label_domains: Whether to annotate each domain with its name.
            passive_entry: Name that, with "(s)" appended, identifies the
                passive domains. When None, the key with the largest value
                in self._pd._elt_comp is used.

        Returns:
            matplotlib plot object

        Raises:
            ValueError: If passive_entry is None and self._pd._elt_comp is
                empty, so that no default can be derived.
        """
        from matplotlib.patches import Polygon
        from pymatgen import Element
        from itertools import chain
        import operator

        # Only derive a default when the caller did not choose one; the
        # explicit argument must not be silently overwritten.
        if passive_entry is None:
            elt_comp = self._pd._elt_comp
            if not elt_comp:
                raise ValueError("passive_entry was not given and cannot be "
                                 "derived from an empty element composition")
            # .items() works on both Python 2 and 3 (iteritems is 2-only).
            passive_entry = max(elt_comp.items(),
                                key=operator.itemgetter(1))[0]
        passive_name = str(passive_entry) + "(s)"

        plt = get_publication_quality_plot(16)
        stable, _ = self.pourbaix_plot_data(limits)

        # kind -> (fill color, label color)
        fill_styles = {1: ('#0000FF', '#FFC000'),
                       2: ('#FF0000', '#00FFFF')}
        water_elts = set([Element("O"), Element("H")])

        def list_elts(entry):
            """Elements of an entry; merged over all parts of a MultiEntry."""
            if isinstance(entry, MultiEntry):
                return set(chain.from_iterable(
                    e.composition.elements for e in entry.entrylist))
            return entry.composition.elements

        def domain_kind(entry):
            """2 for the passive solid, 1 for other O/H solids, else 0."""
            if passive_name in entry.name:
                return 2
            if "(s)" in entry.name and \
                    water_elts.intersection(list_elts(entry)):
                return 1
            return 0

        if limits:
            xlim = limits[0]
            ylim = limits[1]
        else:
            xlim = self._analyzer.chempot_limits[0]
            ylim = self._analyzer.chempot_limits[1]

        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        for e in stable.keys():
            xy = self.domain_vertices(e)
            c = self.get_center(stable[e])
            kind = domain_kind(e)
            if kind in fill_styles:
                color, fontcolor = fill_styles[kind]
                colorfill = True
            else:
                color = "w"
                fontcolor = "k"
                colorfill = False
            patch = Polygon(xy, facecolor=color, closed=True, lw=3.0, fill=colorfill)
            ax.add_patch(patch)
            if label_domains:
                plt.annotate(self.print_name(e), c, color=fontcolor, fontsize=20)

        # Reference lines are drawn after the patches so they stay visible.
        lw = 3
        plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
        plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
Ejemplo n.º 55
0
    def get_pourbaix_mark_passive(self,
                                  limits=None,
                                  title="",
                                  label_domains=True,
                                  passive_entry=None):
        """
        Plot a Pourbaix diagram in which solid domains are color-filled.

        A stable domain is filled with red ('#FF0000') when its name contains
        ``passive_entry + "(s)"``, and with blue ('#0000FF') when it is any
        other solid (name contains "(s)") whose elements include O or H.
        All remaining domains are drawn unfilled.

        Args:
            limits: 2D list of the form [[xlo, xhi], [ylo, yhi]]. When not
                given, self._analyzer.chempot_limits is used.
            title: Title of the plot.
            label_domains: Whether to annotate each domain with its name.
            passive_entry: Name that, with "(s)" appended, identifies the
                passive domains. When None, the key with the largest value
                in self._pd._elt_comp is used.

        Returns:
            matplotlib plot object

        Raises:
            ValueError: If passive_entry is None and self._pd._elt_comp is
                empty, so that no default can be derived.
        """
        from matplotlib.patches import Polygon
        from pymatgen import Element
        from itertools import chain
        import operator

        # Respect an explicit passive_entry; previously the argument was
        # always overwritten, and an empty composition hit an unbound name.
        if passive_entry is None:
            elt_comp = self._pd._elt_comp
            if not elt_comp:
                raise ValueError("passive_entry was not given and cannot be "
                                 "derived from an empty element composition")
            passive_entry = max(elt_comp.items(),
                                key=operator.itemgetter(1))[0]
        passive_name = str(passive_entry) + "(s)"

        plt = get_publication_quality_plot(16)
        stable, _ = self.pourbaix_plot_data(limits)

        passive_style = ('#FF0000', '#00FFFF')  # (fill color, label color)
        oxide_style = ('#0000FF', '#FFC000')
        plain_style = ("w", "k")
        water_elts = set([Element("O"), Element("H")])

        def list_elts(entry):
            """Elements of an entry; merged over all parts of a MultiEntry."""
            if isinstance(entry, MultiEntry):
                return set(chain.from_iterable(
                    e.composition.elements for e in entry.entrylist))
            return entry.composition.elements

        def style_for(entry):
            """Return ((fill color, label color), filled?) for a domain."""
            if passive_name in entry.name:
                return passive_style, True
            if "(s)" in entry.name and \
                    water_elts.intersection(list_elts(entry)):
                return oxide_style, True
            return plain_style, False

        if limits:
            xlim = limits[0]
            ylim = limits[1]
        else:
            xlim = self._analyzer.chempot_limits[0]
            ylim = self._analyzer.chempot_limits[1]

        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        for e in stable.keys():
            xy = self.domain_vertices(e)
            c = self.get_center(stable[e])
            (color, fontcolor), colorfill = style_for(e)
            patch = Polygon(xy,
                            facecolor=color,
                            closed=True,
                            lw=3.0,
                            fill=colorfill)
            ax.add_patch(patch)
            if label_domains:
                plt.annotate(self.print_name(e),
                             c,
                             color=fontcolor,
                             fontsize=20)

        # Reference lines are drawn after the patches so they stay visible.
        lw = 3
        plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
        plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
Ejemplo n.º 56
0
    def show(self, xlim=None, ylim=None):
        """
        Show the plot using matplotlib.

        Args:
            xlim:
                Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim:
                Specifies the y-axis limits. When not given, the y range is
                fitted to the plotted points lying strictly inside the
                x-axis limits.
        """
        plt = get_publication_quality_plot(12, 8)
        color_order = ["r", "b", "g", "c"]
        ncolors = len(color_order)

        # Note that this complicated processing of energies is to allow for
        # stacked plots in matplotlib: when stacking, every DOS is drawn as
        # the running total of itself and all the DOS before it.
        running = None
        alldensities = []
        allenergies = []
        for dos in self._doses.values():
            energies = dos["energies"]
            densities = dos["densities"]
            if running is None:
                running = {Spin.up: np.zeros(energies.shape),
                           Spin.down: np.zeros(energies.shape)}
            newdens = {}
            for spin in [Spin.up, Spin.down]:
                if spin in densities:
                    if self.stack:
                        running[spin] += densities[spin]
                        newdens[spin] = running[spin].copy()
                    else:
                        newdens[spin] = densities[spin]
            allenergies.append(energies)
            alldensities.append(newdens)

        # Plot in reverse insertion order.
        keys = list(self._doses.keys())
        keys.reverse()
        alldensities.reverse()
        allenergies.reverse()
        allpts = []
        for i, key in enumerate(keys):
            x = []
            y = []
            for spin in [Spin.up, Spin.down]:
                if spin in alldensities[i]:
                    # int(spin) scales the densities by the spin's integer
                    # value before they are appended to the outline.
                    densities = list(int(spin) * alldensities[i][spin])
                    energies = list(allenergies[i])
                    if spin == Spin.down:
                        # Walk back along the energy axis so that both spin
                        # channels form one continuous outline.
                        energies.reverse()
                        densities.reverse()
                    x.extend(energies)
                    y.extend(densities)
            allpts.extend(zip(x, y))
            color = color_order[i % ncolors]
            if self.stack:
                plt.fill(x, y, color=color, label=str(key))
            else:
                plt.plot(x, y, color=color, label=str(key))
            if not self.zero_at_efermi:
                efermi = self._doses[key]["efermi"]
                # Use the current axis extent for the marker line without
                # touching the caller-supplied ylim argument.
                plt.plot([efermi, efermi], plt.ylim(), color + "--",
                         linewidth=2)

        plt.xlabel("Energies (eV)")
        plt.ylabel("Density of states")
        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)
        else:
            xlim = plt.xlim()
            relevanty = [p[1] for p in allpts
                         if xlim[0] < p[0] < xlim[1]]
            # Nothing inside the window: keep matplotlib's automatic range.
            if relevanty:
                plt.ylim((min(relevanty), max(relevanty)))

        if self.zero_at_efermi:
            plt.plot([0, 0], plt.ylim(), "k--", linewidth=2)

        plt.legend()
        leg = plt.gca().get_legend()
        ltext = leg.get_texts()  # all the text.Text instance in the legend
        plt.setp(ltext, fontsize=30)
        plt.tight_layout()
        plt.show()
Ejemplo n.º 57
0
    def get_pourbaix_plot(self, limits=None, title="", label_domains=True):
        """
        Build the Pourbaix diagram plot.

        Four reference lines are drawn first (two "r--" lines with slope
        -PREFAC, the second offset by 1.23, plus "k-." lines at pH 7 and
        at 0 V), followed by the boundary of every stable domain.

        Args:
            limits: 2D list containing limits of the Pourbaix diagram
                of the form [[xlo, xhi], [ylo, yhi]]. When not given,
                self._analyzer.chempot_limits is used.
            title: Title of the plot.
            label_domains: Whether to write each domain's name at the mean
                of its distinct boundary vertices. Domains whose center
                falls on or outside the limits are never labelled.

        Returns:
            plt:
                matplotlib plot object
        """
        plt = get_publication_quality_plot(16)
        (stable, _unstable) = self.pourbaix_plot_data(limits)
        if limits:
            xlim, ylim = limits[0], limits[1]
        else:
            xlim = self._analyzer.chempot_limits[0]
            ylim = self._analyzer.chempot_limits[1]

        line_width = 3
        reference_lines = [
            (np.transpose([[xlim[0], -xlim[0] * PREFAC],
                           [xlim[1], -xlim[1] * PREFAC]]), "r--"),
            (np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                           [xlim[1], -xlim[1] * PREFAC + 1.23]]), "r--"),
            (np.transpose([[7, ylim[0]], [7, ylim[1]]]), "k-."),
            (np.transpose([[xlim[0], 0], [xlim[1], 0]]), "k-."),
        ]

        axes = plt.gca()
        axes.set_xlim(xlim)
        axes.set_ylim(ylim)
        for ref, fmt in reference_lines:
            plt.plot(ref[0], ref[1], fmt, linewidth=line_width)

        for entry, boundary in stable.items():
            seen = []
            sum_x = 0.0
            sum_y = 0.0
            for segment in boundary:
                (xs, ys) = segment
                plt.plot(xs, ys, "k-", linewidth=line_width)
                # Each distinct vertex contributes once to the center.
                for vertex in np.array(segment).T:
                    if in_coord_list(seen, vertex):
                        continue
                    seen.append(vertex.tolist())
                    sum_x += vertex[0]
                    sum_y += vertex[1]
            # A domain without vertices keeps its center at the origin.
            n_vertices = float(len(seen)) or 1.0
            center_x = sum_x / n_vertices
            center_y = sum_y / n_vertices
            inside = (xlim[0] < center_x < xlim[1] and
                      ylim[0] < center_y < ylim[1])
            if inside and label_domains:
                plt.annotate(self.print_name(entry), (center_x, center_y),
                             fontsize=20, color="b")

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
Ejemplo n.º 58
0
    def get_pourbaix_plot_colorfill_by_element(self, limits=None, title="",
                                                label_domains=True, element=None):
        """
        Plot a Pourbaix diagram with stable domains grouped by entry name.

        Stable domains are grouped under the name of each entry they
        contain: a MultiEntry is listed under every sub-entry whose
        composition includes ``element``, any other entry under its own
        name. Groups whose name has exactly one element besides H and O
        are drawn as color-filled polygons; all other groups are drawn
        hatched. Each group is labelled with a capital letter at the mean
        of its domain centers.

        Args:
            limits: 2D list of the form [[xlo, xhi], [ylo, yhi]]. When not
                given, self._analyzer.chempot_limits is used.
            title: Title of the plot.
            label_domains: Whether to draw the letter label of each group.
            element: Element used to select the sub-entries of a MultiEntry.

        Returns:
            matplotlib plot object
        """
        from matplotlib.patches import Polygon
        from pymatgen import Composition, Element
        from pymatgen.core.ion import Ion

        plt = get_publication_quality_plot(16)
        optim_colors = ['#0000FF', '#FF0000', '#00FF00', '#FFFF00', '#FF00FF',
                        '#FF8080', '#DCDCDC', '#800000', '#FF8000']
        optim_font_color = ['#FFFFA0', '#00FFFF', '#FF00FF', '#0000FF', '#00FF00',
                            '#007F7F', '#232323', '#7FFFFF', '#007FFF']
        hatch = ['/', '\\', '|', '-', '+', 'o', '*']
        stable, _ = self.pourbaix_plot_data(limits)

        entry_dict_of_multientries = collections.defaultdict(list)
        for entry in stable:
            if isinstance(entry, MultiEntry):
                for e in entry.entrylist:
                    if element in e.composition.elements:
                        entry_dict_of_multientries[e.name].append(entry)
            else:
                entry_dict_of_multientries[entry.name].append(entry)

        if limits:
            xlim = limits[0]
            ylim = limits[1]
        else:
            xlim = self._analyzer.chempot_limits[0]
            ylim = self._analyzer.chempot_limits[1]

        h_line = np.transpose([[xlim[0], -xlim[0] * PREFAC],
                               [xlim[1], -xlim[1] * PREFAC]])
        o_line = np.transpose([[xlim[0], -xlim[0] * PREFAC + 1.23],
                               [xlim[1], -xlim[1] * PREFAC + 1.23]])
        neutral_line = np.transpose([[7, ylim[0]], [7, ylim[1]]])
        V0_line = np.transpose([[xlim[0], 0], [xlim[1], 0]])

        ax = plt.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        water_elts = [Element("H"), Element("O")]

        def len_elts(name):
            """Number of elements other than H and O in the named entry."""
            if "(s)" in name:
                comp = Composition(name[:-3])
            else:
                comp = Ion.from_formula(name)
            return len([el for el in comp.elements if el not in water_elts])

        # Parse every name once; the count drives both the ordering and
        # the fill/hatch decision below.
        n_elts = dict((name, len_elts(name))
                      for name in entry_dict_of_multientries)
        # sorted() and list(map()) work on Python 2 and 3; calling .sort()
        # on dict.keys() or indexing a bare map object fails on Python 3.
        sorted_entry = sorted(entry_dict_of_multientries,
                              key=lambda name: n_elts[name])
        label_chr = list(map(chr, range(65, 91)))

        for i, entry in enumerate(sorted_entry):
            domains = entry_dict_of_multientries[entry]
            single_element = n_elts[entry] == 1
            # Cycle through the palette when there are more groups than
            # colors or hatch patterns.
            if single_element:
                color_indx = i % len(optim_colors)
            else:
                color_indx = i % len(hatch)

            x_coord = 0.0
            y_coord = 0.0
            for e in domains:
                xy = self.domain_vertices(e)
                c = self.get_center(stable[e])
                x_coord += c[0]
                y_coord += c[1]
                if single_element:
                    patch = Polygon(xy, facecolor=optim_colors[color_indx],
                                    closed=True, lw=3.0, fill=True)
                else:
                    patch = Polygon(xy, hatch=hatch[color_indx],
                                    closed=True, lw=3.0, fill=False)
                ax.add_patch(patch)

            if label_domains:
                xy_center = (x_coord / len(domains), y_coord / len(domains))
                if single_element:
                    text_color = optim_colors[color_indx]
                    bbox = dict(boxstyle="round",
                                fc=optim_font_color[color_indx])
                else:
                    # Hatched groups get black text on a white hatched box.
                    text_color = 'k'
                    bbox = dict(boxstyle="round", hatch=hatch[color_indx],
                                fc='w')
                plt.annotate(label_chr[i], xy_center,
                             color=text_color, fontsize=30, bbox=bbox)

        lw = 3
        plt.plot(h_line[0], h_line[1], "r--", linewidth=lw)
        plt.plot(o_line[0], o_line[1], "r--", linewidth=lw)
        plt.plot(neutral_line[0], neutral_line[1], "k-.", linewidth=lw)
        plt.plot(V0_line[0], V0_line[1], "k-.", linewidth=lw)

        plt.xlabel("pH")
        plt.ylabel("E (V)")
        plt.title(title, fontsize=20, fontweight='bold')
        return plt
Ejemplo n.º 59
0
    def get_elt_projected_plots_color(self,
                                      zero_to_efermi=True,
                                      elt_ordered=None):
        """
        Return a pylab plot object with one plot where the band structure
        line color depends on the character of the band (along different
        elements). Each element is associated with red, green or blue
        and the corresponding rgb color depending on the character of the band
        is used. The method can only deal with compounds containing at most
        three elements.

        Spin up and spin down are differentiated by a '-' and a '--' line.

        Args:
            zero_to_efermi: Forwarded unchanged to ``self.bs_plot_data``.
            elt_ordered: A list of Element ordered. The first one is red,
                second green, last blue. With two elements the first is red
                and the second blue. Defaults to the elements of the
                structure's composition.

        Returns:
            a pylab object

        Raises:
            ValueError: If the structure contains more than three elements.
        """
        band_linewidth = 3.0
        elements = self._bs._structure.composition.elements
        if len(elements) > 3:
            raise ValueError(
                "Element-projected color plots support at most 3 elements, "
                "got %d" % len(elements))
        if elt_ordered is None:
            elt_ordered = elements
        proj = self._get_projections_by_branches({
            e.symbol: ['s', 'p', 'd']
            for e in elements
        })
        data = self.bs_plot_data(zero_to_efermi)
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = get_publication_quality_plot(12, 8)

        spins = [Spin.up]
        if self._bs.is_spin_polarized:
            spins = [Spin.up, Spin.down]
        self._maketicks(plt)
        for s in spins:
            spin_key = str(s)
            sign = '--' if s == Spin.down else '-'
            for b, distances in enumerate(data['distances']):
                for i in range(self._nb_bands):
                    energies = data['energy'][b][spin_key][i]
                    band_proj = proj[b][spin_key][i]
                    # Each consecutive pair of points is drawn as its own
                    # segment so that it can carry its own color.
                    for j in range(len(energies) - 1):
                        # Summed orbital weight of every element at point j,
                        # computed once and reused for the normalization.
                        weights = []
                        for el in elt_ordered:
                            el_proj = band_proj[j][str(el)]
                            weights.append(sum([el_proj[o] for o in el_proj]))
                        total = sum(weights)
                        if total == 0.0:
                            color = [0.0] * len(weights)
                        else:
                            color = [w / total for w in weights]
                        if len(color) == 2:
                            # Two elements: red and blue, green left at 0.
                            color = [color[0], 0.0, color[1]]
                        else:
                            # Pad to a full RGB triple; a shorter list is
                            # not a valid matplotlib color.
                            color.extend([0.0] * (3 - len(color)))
                        plt.plot(
                            [distances[j], distances[j + 1]],
                            [energies[j], energies[j + 1]],
                            sign,
                            color=color,
                            linewidth=band_linewidth)

        plt.ylim(data['vbm'][0][1] - 4.0, data['cbm'][0][1] + 2.0)
        return plt
Ejemplo n.º 60
0
    def _get_2d_plot(self, label_stable=True, label_unstable=True):
        """
        Build the 2D phase diagram plot using pylab. Usually I won't do
        imports in methods, but since plotting is a fairly expensive library
        to load and not all machines have matplotlib installed, I have done
        it this way.

        Args:
            label_stable: Whether to annotate the stable entries with their
                names.
            label_unstable: Whether to annotate the unstable entries with
                their names. Unstable entries are only drawn at all when
                self.show_unstable is set.

        Returns:
            matplotlib plot object
        """

        plt = get_publication_quality_plot(8, 6)
        from matplotlib.font_manager import FontProperties

        (lines, labels, unstable) = self.pd_plot_data
        for x, y in lines:
            plt.plot(x, y, "ko-", linewidth=3, markeredgecolor="k", markerfacecolor="b", markersize=15)
        font = FontProperties()
        font.set_weight("bold")
        font.set_size(24)

        # Sets a nice layout depending on the type of PD. Also defines a
        # "center" for the PD, which then allows the annotations to be spread
        # out in a nice manner.
        if len(self._pd.elements) == 3:
            plt.axis("equal")
            plt.xlim((-0.1, 1.2))
            plt.ylim((-0.1, 1.0))
            plt.axis("off")
            center = (0.5, math.sqrt(3) / 6)
        else:
            miny = min([c[1] for c in labels.keys()])
            ybuffer = max(abs(miny) * 0.1, 0.1)
            plt.xlim((-0.1, 1.1))
            plt.ylim((miny - ybuffer, ybuffer))
            center = (0.5, miny / 2)
            plt.xlabel("Fraction", fontsize=28, fontweight="bold")
            plt.ylabel("Formation energy (eV/fu)", fontsize=28, fontweight="bold")

        def label_placement(coords):
            """
            Return (offset, halign, valign) for a label at coords.

            The offset is 10 points along the direction from the center of
            the PD to coords, which results in fairly nice layouts for the
            most part. A point exactly at the center gets a zero offset.
            """
            vec = np.array(coords) - center
            norm = np.linalg.norm(vec)
            if norm != 0:
                vec = vec / norm * 10
            valign = "bottom" if vec[1] > 0 else "top"
            if vec[0] < -0.01:
                halign = "right"
            elif vec[0] > 0.01:
                halign = "left"
            else:
                halign = "center"
            return vec, halign, valign

        if label_stable:
            for coords in sorted(labels.keys(), key=lambda x: -x[1]):
                vec, halign, valign = label_placement(coords)
                plt.annotate(
                    latexify(labels[coords].name),
                    coords,
                    xytext=vec,
                    textcoords="offset points",
                    horizontalalignment=halign,
                    verticalalignment=valign,
                    fontproperties=font,
                )

        if self.show_unstable:
            font = FontProperties()
            font.set_size(16)
            for entry, coords in unstable.items():
                plt.plot(
                    coords[0], coords[1], "ks", linewidth=3, markeredgecolor="k", markerfacecolor="r", markersize=8
                )
                if label_unstable:
                    # Alignment is computed for this point itself rather
                    # than inherited from the last stable label.
                    vec, halign, valign = label_placement(coords)
                    plt.annotate(
                        latexify(entry.name),
                        coords,
                        xytext=vec,
                        textcoords="offset points",
                        horizontalalignment=halign,
                        color="b",
                        verticalalignment=valign,
                        fontproperties=font,
                    )
        F = plt.gcf()
        F.set_size_inches((8, 6))
        plt.subplots_adjust(left=0.09, right=0.98, top=0.98, bottom=0.07)
        return plt