Пример #1
0
    def gridplot_eos(self, eos_names="all", fontsize=6, **kwargs):
        """
        Draw each requested EOS fit in its own panel of a grid, with captions showing the final results.

        A single fit is drawn on a 1x1 grid; more fits use two columns and as many rows as needed.

        Args:
            eos_names: String or list of strings with EOS names. See pymatgen.analysis.EOS
            fontsize: Fontsize used for caption text.

        Returns: |matplotlib-Figure|
        """
        fits = self.get_eos_fits_dataframe(eos_names=eos_names).fits
        num_plots = len(fits)

        if num_plots > 1:
            ncols = 2
            nrows = num_plots // ncols + num_plots % ncols
        else:
            ncols = nrows = 1

        ax_mat, fig, _ = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                             sharex=False, sharey=False, squeeze=False)
        axes = ax_mat.ravel()

        for ax, fit in zip(axes, fits):
            fit.plot_ax(ax=ax, fontsize=fontsize, label="", show=False)

        # An odd number of fits leaves the bottom-right cell empty: hide it.
        if num_plots % ncols:
            axes[-1].axis('off')

        return fig
Пример #2
0
    def plot_gkq2_diff(self, iref=0, **kwargs):
        """
        Wraps gkq.plot_diff_with_other.

        Produce scatter and histogram plots comparing the gkq matrix elements stored in the
        file with index ``iref`` against every other file contained in the robot.
        All files are assumed to have the same q-point (checked with ``_check_qpoints_equal``).

        Args:
            iref: Index of the reference file. Negative values count from the end,
                as for Python sequences.
            kwargs: passed to ``plot_diff_with_other``.

        Returns: |matplotlib-Figure| or None if the robot contains fewer than two files.

        Raises:
            IndexError: if ``iref`` is out of range.
        """
        if len(self) <= 1: return None
        self._check_qpoints_equal()

        # Normalize (and validate) iref as a non-negative index. Without this, a negative iref
        # selects a valid reference file but never matches ``ifile`` in the loop below, so the
        # reference would be compared with itself and ax_mat would be indexed past its last row.
        iref = range(len(self))[iref]

        # One row for each non-reference file; each row (two axes) is passed as ax_list.
        ncols, nrows = 2, len(self) - 1
        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=False, sharey=False, squeeze=False)

        ref_gkq, ref_label = self.abifiles[iref], self.labels[iref]
        irow = 0
        for ifile, (other_label, other_gkq) in enumerate(zip(self.labels, self.abifiles)):
            if ifile == iref: continue
            ref_gkq.plot_diff_with_other(other_gkq,
                                         ax_list=ax_mat[irow],
                                         labels=[ref_label, other_label],
                                         show=False,
                                         **kwargs)
            irow += 1

        return fig
Пример #3
0
    def plot_ekmap_temps(self, temp_inds=None, spins=None, estep=0.02, with_colorbar=True,
                        ylims=None, fontsize=8, **kwargs):
        """
        Plot (k, e) color maps, one panel per selected temperature.

        Args:
            temp_inds: Indices of the temperatures to plot. None to plot all ``self.ntemp`` temperatures.
            spins, estep, with_colorbar, ylims: forwarded to ``plot_ekmap_itemp``.
            fontsize (int): fontsize for titles and legend
            kwargs: forwarded to ``plot_ekmap_itemp``.

        Return: |matplotlib-Figure|
        """
        if temp_inds is None:
            temp_inds = range(self.ntemp)

        # Grid layout: 1x1 for a single panel, otherwise two columns.
        num_plots = len(temp_inds)
        ncols, nrows = 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots + ncols - 1) // ncols

        ax_mat, fig, _ = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                             sharex=True, sharey=True, squeeze=False)
        axes = ax_mat.ravel()

        # An odd number of panels leaves the last grid cell empty: hide it.
        if num_plots % ncols != 0:
            axes[-1].axis("off")

        for itemp, ax in zip(temp_inds, axes):
            self.plot_ekmap_itemp(itemp=itemp, spins=spins, estep=estep, ax=ax, ylims=ylims,
                                  with_colorbar=with_colorbar, show=False, **kwargs)
            ax.set_title("T = %.1f K" % self.tmesh[itemp], fontsize=fontsize)

        return fig
Пример #4
0
    def plot(self, **kwargs):
        """
        Driver routine: call four plot_* methods, each drawing on one panel of a 2x2 grid.

        The panels are filled row by row with ``plot_tcore_rspace``, ``plot_tcore_qspace``,
        ``plot_ffspl`` and ``plot_vlocq``.

        Args:
            ecut_ffnl: Max cutoff energy for ffnl plot (optional). Forwarded to all four methods.

        Return: |matplotlib-Figure|
        """
        method_names = ("plot_tcore_rspace", "plot_tcore_qspace", "plot_ffspl", "plot_vlocq")

        ax_mat, fig, _ = get_axarray_fig_plt(None, nrows=2, ncols=2,
                                             sharex=False, sharey=False, squeeze=True)

        ecut_ffnl = kwargs.pop("ecut_ffnl", None)
        for name, ax in zip(method_names, ax_mat.ravel()):
            plot_method = getattr(self, name)
            plot_method(ax=ax, ecut_ffnl=ecut_ffnl, show=False)

        return fig
Пример #5
0
    def plot_ekmap_temps(self, temp_inds=None, spins=None, estep=0.02, with_colorbar=True,
                        ylims=None, fontsize=8, **kwargs):
        """
        Plot (k, e) color maps for several temperatures on a grid (two columns if more than one).

        Args:
            temp_inds: Indices of the temperatures to plot. None selects all ``self.ntemp`` temperatures.
            spins, estep, with_colorbar, ylims: forwarded to ``plot_ekmap_itemp``.
            fontsize (int): fontsize for titles and legend
            kwargs: forwarded to ``plot_ekmap_itemp``.

        Return: |matplotlib-Figure|
        """
        temp_inds = range(self.ntemp) if temp_inds is None else temp_inds

        num_plots = len(temp_inds)
        if num_plots > 1:
            ncols = 2
            nrows = num_plots // 2 + num_plots % 2
        else:
            ncols = nrows = 1

        grid, fig, _ = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                           sharex=True, sharey=True, squeeze=False)
        panels = grid.ravel()

        # Hide the trailing empty cell when the panels do not fill the last row.
        if num_plots % ncols != 0:
            panels[-1].axis("off")

        for itemp, panel in zip(temp_inds, panels):
            self.plot_ekmap_itemp(itemp=itemp, spins=spins, estep=estep, ax=panel, ylims=ylims,
                                  with_colorbar=with_colorbar, show=False, **kwargs)
            title = "T = %.1f K" % self.tmesh[itemp]
            panel.set_title(title, fontsize=fontsize)

        return fig
Пример #6
0
    def gridplot_eos(self, eos_names="all", fontsize=6, **kwargs):
        """
        Show every requested EOS fit on a grid of subplots with captions showing the final results.

        Args:
            eos_names: String or list of strings with EOS names. See pymatgen.analysis.EOS
            fontsize: Fontsize used for caption text.

        Returns: |matplotlib-Figure|
        """
        eos_fits = self.get_eos_fits_dataframe(eos_names=eos_names).fits

        nplots = len(eos_fits)
        ncols = 2 if nplots > 1 else 1
        # Ceiling division: enough rows to hold all the fits.
        nrows = -(-nplots // ncols) if nplots > 1 else 1

        grid, fig, _ = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                           sharex=False, sharey=False, squeeze=False)
        panels = grid.ravel()

        for fit, panel in zip(eos_fits, panels):
            fit.plot_ax(ax=panel, fontsize=fontsize, label="", show=False)

        # Turn off the unused last cell when the number of fits is odd.
        if nplots % ncols != 0:
            panels[-1].axis('off')

        return fig
Пример #7
0
    def plot_phonons_occ(self, temps=(100, 200, 300, 400), fontsize=8, **kwargs):
        """
        Plot phonon band structure with markers proportional to the occupation
        of each phonon mode for different temperatures (one panel per temperature).

        Args:
            temps: Temperature or list of temperatures in Kelvin.
            fontsize: fontsize passed to ``self.phb_qpath.plot``.

        Return: |matplotlib-Figure|
        """
        # Flatten to 1D so that the number of panels always matches the number of
        # temperatures iterated below (also accepts a scalar or a nested sequence).
        temps = np.asarray(temps).ravel()
        num_plots = len(temps)

        # Build plot grid: 1x1 for a single temperature, otherwise two columns.
        ncols, nrows = 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=True,
                                                sharey=True,
                                                squeeze=False)
        ax_list = ax_list.ravel()

        # An odd number of temperatures leaves the last grid cell empty: hide it,
        # as done by the other grid plots.
        if num_plots % ncols != 0: ax_list[-1].axis("off")

        for ax, temp in zip(ax_list, temps):
            self.phb_qpath.plot(ax=ax,
                                units="eV",
                                temp=temp,
                                fontsize=fontsize,
                                show=False)

        return fig
Пример #8
0
    def plot_linopt(self, select="all", itemp=0, xlims=None, **kwargs):
        """
        Subplots with all linear optic quantities selected by ``select`` at temperature ``itemp``.

        Args:
            select: String or list of strings with the quantities to plot (one panel each).
                "all" selects all the keys of ``LINEPS_WHAT2EFUNC``.
            itemp: Temperature index.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                or scalar e.g. ``left``. If left (right) is None, default values are used.

        Returns: |matplotlib-Figure| or None if no "linopt" component has been computed.
        """
        key = "linopt"
        if not self.reader.computed_components[key]: return None
        if select == "all": select = list(LINEPS_WHAT2EFUNC.keys())
        select = list_strings(select)

        # squeeze=False guarantees an array of Axes even when a single quantity is selected:
        # with squeeze=True a 1x1 grid yields a bare Axes object, which cannot be iterated.
        nrows, ncols = len(select), 1
        ax_mat, fig, plt = get_axarray_fig_plt(None,
                                               nrows=nrows,
                                               ncols=ncols,
                                               sharex=True,
                                               sharey=False,
                                               squeeze=False)
        ax_list = ax_mat.ravel()

        components = self.reader.computed_components[key]
        for i, (what, ax) in enumerate(zip(select, ax_list)):
            self.plot_linear_epsilon(what=what,
                                     itemp=itemp,
                                     components=components,
                                     ax=ax,
                                     xlims=xlims,
                                     # x-label only on the bottom panel (shared x-axis).
                                     with_xlabel=(i == len(select) - 1),
                                     show=False)
        return fig
Пример #9
0
    def plot_line_neighbors(self, site_index, radius, num=200, max_nn=10, fontsize=12, **kwargs):
        """
        Plot (interpolated) density/potential in real space along the lines connecting
        an atom specified by ``site_index`` and all neighbors within a sphere of given ``radius``.
        One panel per neighbor, sorted by distance.

        .. warning::

            This routine can produce lots of plots!
            Be careful with the value of ``radius``. See also ``max_nn``.

        Args:
            site_index: Index of the atom in the structure.
            radius: Radius of the sphere in Angstrom
            num: Number of points sampled along the line.
            max_nn: By default, only the first `max_nn` neighbors are shown. None to show all of them.
            fontsize: legend and title fontsize

        Return: |matplotlib-Figure| or None if no neighbor is found within ``radius``.
        """
        site = self.structure[site_index]
        nn_list = self.structure.get_neighbors(site, radius, include_index=True)
        if not nn_list:
            cprint("Zero neighbors found for radius %s Ang. Returning None." % radius, "yellow")
            return None
        # Sort sites by distance (second item of each neighbor entry).
        nn_list = sorted(nn_list, key=lambda t: t[1])

        if max_nn is not None and len(nn_list) > max_nn:
            cprint("For radius %s, found %s neighbors but only max_nn %s sites are show." %
                    (radius, len(nn_list), max_nn), "yellow")
            nn_list = nn_list[:max_nn]

        # Get grid of axes: one column, one row per neighbor.
        # squeeze=False guarantees an array of Axes even with a single neighbor: with
        # squeeze=True a 1x1 grid yields a bare Axes object that has no ``ravel`` method.
        nrows, ncols = len(nn_list), 1
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=True, squeeze=False)
        ax_list = ax_list.ravel()

        interpolator = self.get_interpolator()

        for i, (nn, ax) in enumerate(zip(nn_list, ax_list)):
            nn_site, nn_dist, _ = nn
            title = "%s, %s, dist=%.3f A" % (nn_site.species_string, str(nn_site.frac_coords), nn_dist)

            r = interpolator.eval_line(site.frac_coords, nn_site.frac_coords, num=num, kpoint=None)

            for ispden in range(self.nspden):
                # Only the curves in the first panel carry labels.
                ax.plot(r.dist, r.values[ispden],
                        label=latexlabel_ispden(ispden, self.nspden) if i == 0 else None)

            ax.set_title(title, fontsize=fontsize)
            ax.grid(True)

            # The legend goes on the panel whose lines are labeled (the first one). Placing it
            # on the last panel produced an empty legend whenever there was more than one neighbor.
            if i == 0 and self.nspden > 1:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

            if i == nrows - 1:
                ax.set_xlabel("Distance from site_index %s [Angstrom]" % site_index)
                ax.set_ylabel(self.latex_label)

        return fig
Пример #10
0
    def plot_linopt(self, select="all", itemp=0, xlims=None, **kwargs):
        """
        Subplots with all linear optic quantities selected by ``select`` at temperature ``itemp``.

        Args:
            select: String or list of strings with the quantities to plot (one panel each).
                "all" selects all the keys of ``LINEPS_WHAT2EFUNC``.
            itemp: Temperature index.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                or scalar e.g. ``left``. If left (right) is None, default values are used.

        Returns: |matplotlib-Figure| or None if no "linopt" component has been computed.
        """
        key = "linopt"
        if not self.reader.computed_components[key]: return None
        if select == "all": select = list(LINEPS_WHAT2EFUNC.keys())
        select = list_strings(select)

        # squeeze=False guarantees an array of Axes even when a single quantity is selected:
        # with squeeze=True a 1x1 grid yields a bare Axes object, which cannot be iterated.
        nrows, ncols = len(select), 1
        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=True, sharey=False, squeeze=False)
        ax_list = ax_mat.ravel()

        components = self.reader.computed_components[key]
        for i, (what, ax) in enumerate(zip(select, ax_list)):
            # x-label only on the bottom panel (shared x-axis).
            self.plot_linear_epsilon(what=what, itemp=itemp, components=components,
                                     ax=ax, xlims=xlims, with_xlabel=(i == len(select) - 1),
                                     show=False)
        return fig
Пример #11
0
    def plot_ibte_mrta_serta_conv(self,  what="resistivity", fontsize=8, **kwargs):
        """
        Compare SERTA, MRTA and IBTE resistivities as a function of temperature for all
        the files in the robot: one panel per method, one curve per file.

        Only the (0, 0) Cartesian component of the tensors is plotted.

        Args:
            what: Currently not used: the "resistivity" and "ibte_rho" variables are always read.
            fontsize: fontsize for titles and legend.

        Return: |matplotlib-Figure|
        """
        # A single row with three panels: SERTA, MRTA, IBTE.
        ax_mat, fig, _ = get_axarray_fig_plt(None, nrows=1, ncols=3,
                                             sharex=True, sharey=True, squeeze=False)
        axes = ax_mat.ravel()

        from collections import defaultdict

        # Cartesian indices of the tensor component to plot.
        i = j = 0
        curves = defaultdict(list)
        for abifile in self.abifiles:
            rta_vals = abifile.reader.read_value("resistivity")
            curves["serta"].append(rta_vals[0, :, j, i])
            curves["mrta"].append(rta_vals[1, :, j, i])
            curves["ibte"].append(abifile.reader.read_value("ibte_rho")[:, j, i])

        tmesh = self.get_same_tmesh()
        for iax, (method, ax) in enumerate(zip(["serta", "mrta", "ibte"], axes)):
            ax.grid(True)
            ax.set_title(method.upper(), fontsize=fontsize)
            for ifile, ys in enumerate(curves[method]):
                ax.plot(tmesh, ys, marker="o", label=self.labels[ifile])
            ax.set_xlabel("Temperature (K)")
            # y-label and legend only on the first (leftmost) panel.
            if iax == 0:
                ax.set_ylabel(r"Resistivity ($\mu\Omega\;cm$)")
                ax.legend(loc="best", shadow=True, fontsize=fontsize)

        return fig
Пример #12
0
    def plot(self, ax_list=None, fontsize=8, **kwargs):
        """
        Plot the evolution of structural parameters (lattice lengths, angles and volume)
        as well as pressure, info on forces and total energy.

        Args:
            ax_list: List of |matplotlib-Axes| with nrows * ncols = 6 entries (3 rows x 2 columns).
                If None, a new figure is created.
            fontsize: fontsize for legend

        Returns: |matplotlib-Figure|
        """
        what_list = ["abc", "angles", "volume", "pressure", "forces", "energy"]
        nrows, ncols = 3, 2
        # Forward the user-provided axes: previously None was always passed here,
        # so the documented ax_list argument was silently ignored.
        ax_list, fig, plt = get_axarray_fig_plt(ax_list,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=True,
                                                sharey=False,
                                                squeeze=False)
        ax_list = ax_list.ravel()
        assert len(ax_list) == len(what_list)

        for what, ax in zip(what_list, ax_list):
            self.plot_ax(ax, what, fontsize=fontsize, marker="o")

        return fig
Пример #13
0
    def plot_ak_vs_temp(self, temp_inds=None, spins=None, band_inds=None, kpt_inds=None,
                        apad=1.0, estep=0.02, colormap="jet", fontsize=8, **kwargs):
        """
        Plot, for each selected k-point, the curves returned by ``get_atw`` on the energy mesh,
        one curve per temperature.

        Curves at consecutive temperatures are shifted vertically by ``apad``.
        Curves for spin index 0 keep their sign; curves for any other spin index are multiplied by -1.

        Args:
            temp_inds: List of temperature indices. None for all temperatures.
            spins: List of spin indices. None for all spins.
            band_inds: Band indices, passed to ``get_atw``.
            kpt_inds: List of k-point indices (one panel per k-point). None for all k-points.
            apad: Vertical shift between curves at consecutive temperatures.
            estep: Energy step passed to ``get_emesh_eminmax`` to build the energy mesh.
            colormap: matplotlib colormap used to color the temperatures.
            fontsize (int): fontsize for titles and legend

        Return: |matplotlib-Figure|
        """
        temp_inds = range(self.ntemp) if temp_inds is None else temp_inds
        ntemp = len(temp_inds)
        spins = range(self.ebands.nsppol) if spins is None else spins
        kpt_inds = range(self.ebands.nkpt) if kpt_inds is None else kpt_inds
        nkpt = len(kpt_inds)

        # Only the energy mesh is used here.
        xs, _, _ = self.get_emesh_eminmax(estep)

        # Build plot grid: 1x1 for a single k-point, otherwise two columns.
        num_plots, ncols, nrows = nkpt, 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=True, squeeze=False)
        ax_list = np.array(ax_list).ravel()

        # An odd number of k-points leaves the last grid cell empty: hide it,
        # as done by the other grid plots.
        if num_plots % ncols != 0: ax_list[-1].axis("off")

        cmap = plt.get_cmap(colormap)

        for isp, spin in enumerate(spins):
            spin_sign = +1 if spin == 0 else -1
            for ik, (ikpt, ax) in enumerate(zip(kpt_inds, ax_list)):
                ax.grid(True)
                atw = self.get_atw(xs, spin, ikpt, band_inds, temp_inds)
                for it, itemp in enumerate(temp_inds):
                    ys = spin_sign * atw[it] + (it * apad)
                    # Temperature labels only in the first panel of the first spin.
                    ax.plot(xs, ys, lw=2, alpha=0.8, color=cmap(float(it) / ntemp),
                            label="T = %.1f K" % self.tmesh[itemp] if (ik, isp) == (0, 0) else None)

                if spin == 0:
                    kpt = self.ebands.kpoints[ikpt]
                    ax.set_title("k:%s" % (repr(kpt)), fontsize=fontsize)

                if (ik, isp) == (0, 0):
                    ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig
Пример #14
0
    def plot_ak_vs_temp(self, temp_inds=None, spins=None, band_inds=None, kpt_inds=None,
                        apad=1.0, estep=0.02, colormap="jet", fontsize=8, **kwargs):
        """
        Plot, for each selected k-point, the curves returned by ``get_atw`` on the energy mesh,
        one curve per temperature.

        Curves at consecutive temperatures are shifted vertically by ``apad``.
        Curves for spin index 0 keep their sign; curves for any other spin index are multiplied by -1.

        Args:
            temp_inds: List of temperature indices. None for all temperatures.
            spins: List of spin indices. None for all spins.
            band_inds: Band indices, passed to ``get_atw``.
            kpt_inds: List of k-point indices (one panel per k-point). None for all k-points.
            apad: Vertical shift between curves at consecutive temperatures.
            estep: Energy step passed to ``get_emesh_eminmax`` to build the energy mesh.
            colormap: matplotlib colormap used to color the temperatures.
            fontsize (int): fontsize for titles and legend

        Return: |matplotlib-Figure|
        """
        temp_inds = range(self.ntemp) if temp_inds is None else temp_inds
        ntemp = len(temp_inds)
        spins = range(self.ebands.nsppol) if spins is None else spins
        kpt_inds = range(self.ebands.nkpt) if kpt_inds is None else kpt_inds
        nkpt = len(kpt_inds)

        # Only the energy mesh is used here.
        xs, _, _ = self.get_emesh_eminmax(estep)

        # Build plot grid: 1x1 for a single k-point, otherwise two columns.
        num_plots, ncols, nrows = nkpt, 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=True, squeeze=False)
        ax_list = np.array(ax_list).ravel()

        # An odd number of k-points leaves the last grid cell empty: hide it,
        # as done by the other grid plots.
        if num_plots % ncols != 0: ax_list[-1].axis("off")

        cmap = plt.get_cmap(colormap)

        for isp, spin in enumerate(spins):
            spin_sign = +1 if spin == 0 else -1
            for ik, (ikpt, ax) in enumerate(zip(kpt_inds, ax_list)):
                ax.grid(True)
                atw = self.get_atw(xs, spin, ikpt, band_inds, temp_inds)
                for it, itemp in enumerate(temp_inds):
                    ys = spin_sign * atw[it] + (it * apad)
                    # Temperature labels only in the first panel of the first spin.
                    ax.plot(xs, ys, lw=2, alpha=0.8, color=cmap(float(it) / ntemp),
                            label="T = %.1f K" % self.tmesh[itemp] if (ik, isp) == (0, 0) else None)

                if spin == 0:
                    kpt = self.ebands.kpoints[ikpt]
                    ax.set_title("k:%s" % (repr(kpt)), fontsize=fontsize)

                if (ik, isp) == (0, 0):
                    ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig
Пример #15
0
    def plot_ibte_vs_rta_rho(self, component="xx", fontsize=8, **kwargs):
        """
        Call ``plot_ibte_vs_rta_rho`` for each file in the robot, one panel per file in a single row.

        Args:
            component: Cartesian component forwarded to each file's ``plot_ibte_vs_rta_rho`` e.g. "xx".
            fontsize: fontsize for legends and titles.

        Return: |matplotlib-Figure|
        """
        # One row, one column per file.
        nrows, ncols = 1, len(self)
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=True, squeeze=False)
        ax_list = ax_list.ravel()

        for abifile, ax in zip(self.abifiles, ax_list):
            # Forward the requested component (it was previously hard-coded to "xx").
            abifile.plot_ibte_vs_rta_rho(component=component, fontsize=fontsize, ax=ax, show=False)

        return fig
Пример #16
0
    def combiplot(self, what_list=None, colormap="jet", fontsize=6, **kwargs):
        """
        Plot multiple HIST.nc_ files on a grid. One plot for each ``what`` value.

        Args:
            what_list: List of strings with the quantities to plot. If None, all quantities are plotted.
            colormap: matplotlib color map.
            fontsize: fontsize for legend.

        Returns: |matplotlib-Figure|.
        """
        what_list = (list_strings(what_list) if what_list is not None else [
            "energy", "a", "b", "c", "alpha", "beta", "gamma", "volume",
            "pressure"
        ])

        # Grid layout: 1x1 for a single quantity, otherwise two columns.
        num_plots, ncols, nrows = len(what_list), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=True,
                                                sharey=False,
                                                squeeze=False)
        ax_list = ax_list.ravel()
        cmap = plt.get_cmap(colormap)

        for i, (ax, what) in enumerate(zip(ax_list, what_list)):
            for ih, hist in enumerate(self.abifiles):
                # Label curves only in the first panel so that each file appears once in the legend.
                label = None if i != 0 else hist.relpath
                hist.plot_ax(ax,
                             what,
                             color=cmap(ih / len(self)),
                             label=label,
                             fontsize=fontsize)

            # Test the panel index directly: relying on the loop variable ``label`` raised
            # UnboundLocalError when the robot contained no file.
            if i == 0 and len(self) > 0:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

            # Label the last panel that holds data. Comparing with len(ax_list) - 1 never matched
            # for an odd number of quantities, because the last grid cell is empty and never visited.
            if i == num_plots - 1:
                ax.set_xlabel("Step")
            else:
                ax.set_xlabel("")

        # Get around a bug in matplotlib: hide the empty last cell when num_plots is odd.
        if num_plots % ncols != 0: ax_list[-1].axis('off')

        return fig
Пример #17
0
    def plot(self, mdf_type="exc", qview="avg", xlims=None, ylims=None, fontsize=8, **kwargs):
        """
        Plot all macroscopic dielectric functions (MDF) stored in the plotter.

        The real part is drawn in the left column, the imaginary part in the right column.

        Args:
            mdf_type: Selects the type of dielectric function.
                "exc" for the MDF with excitonic effects.
                "rpa" for RPA with KS energies.
                "gwrpa" for RPA with GW (or KS-corrected) results.
            qview: "avg" to plot the results averaged over q-points. "all" to plot q-point dependence
                (one row per q-point).
            xlims: Set the data limits for the x-axis. Accept tuple e.g. `(left, right)`
                  or scalar e.g. `left`. If left (right) is None, default values are used
            ylims: Same meaning as `xlims` but for the y-axis
            fontsize: fontsize for titles and legend.

        Return: |matplotlib-Figure|

        Raises:
            ValueError: if ``qview`` is neither "avg" nor "all".
        """
        # Build plot grid (this also validates qview).
        if qview == "avg":
            ncols, nrows = 2, 1
        elif qview == "all":
            qpoints = self._get_qpoints()
            ncols, nrows = 2, len(qpoints)
        else:
            raise ValueError("Invalid value of qview: %s" % str(qview))

        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=True, sharey=True, squeeze=False)

        if qview == "avg":
            # Plot averaged values
            self.plot_mdftype_cplx(mdf_type, "Re", ax=ax_mat[0, 0], xlims=xlims, ylims=ylims,
                                   fontsize=fontsize, with_legend=True, show=False)
            self.plot_mdftype_cplx(mdf_type, "Im", ax=ax_mat[0, 1], xlims=xlims, ylims=ylims,
                                   fontsize=fontsize, with_legend=False, show=False)
        else:
            # qview == "all" (validated above): plot MDF(q), one row per q-point.
            # Legend only in the first row, axis labels only in the last one.
            nqpt = len(qpoints)
            for iq, qpt in enumerate(qpoints):
                islast = (iq == nqpt - 1)
                self.plot_mdftype_cplx(mdf_type, "Re", qpoint=qpt, ax=ax_mat[iq, 0], xlims=xlims, ylims=ylims,
                    fontsize=fontsize, with_legend=(iq == 0), with_xlabel=islast, with_ylabel=islast, show=False)
                self.plot_mdftype_cplx(mdf_type, "Im", qpoint=qpt, ax=ax_mat[iq, 1], xlims=xlims, ylims=ylims,
                    fontsize=fontsize, with_legend=False, with_xlabel=islast, with_ylabel=islast, show=False)

        return fig
Пример #18
0
    def plot_transport_tensors_mu(self, component="xx", spin=0,
                                  what_list=("sigma", "seebeck", "kappa", "pi"),
                                  colormap="jet", fontsize=8, **kwargs):
        """
        Plot a selected Cartesian component of transport tensors as a function of the
        chemical potential mu. Each panel shows one tensor, with one curve per temperature in
        ``self.tmesh`` (colored with ``colormap``) and one line style per RTA approximation.

        Args:
            component: Cartesian component to plot e.g. "xx".
            spin: Spin index.
            what_list: Names of the tensors to plot (one panel each).
            colormap: matplotlib colormap used to color the temperatures.
            fontsize: fontsize for legends and titles

        Return: |matplotlib-Figure|
        """
        i, j = abu.s2itup(component)

        num_plots, ncols, nrows, what_list = x2_grid(what_list)
        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=True, sharey=False, squeeze=False)
        axes = ax_mat.ravel()
        # Hide the empty last cell when the panels do not fill the grid.
        if num_plots % ncols != 0:
            axes[-1].axis("off")

        cmap = plt.get_cmap(colormap)

        for iax, (what, ax) in enumerate(zip(what_list, axes)):
            # nctkarr_t(what, "dp", "three, three, edos_nw, ntemp, nsppol, nrta") in Fortran order,
            # hence indexed as [irta, spin, itemp, mu, j, i] here.
            tensor = self.reader.read_variable(what)

            for irta in range(self.nrta):
                for itemp, temp in enumerate(self.tmesh):
                    ys = tensor[irta, spin, itemp, :, j, i]
                    if itemp == 0:
                        label = "T = %dK (%s)" % (temp, irta2s(irta))
                    elif irta == 0:
                        label = None
                    else:
                        label = "T = %dK" % temp
                    ax.plot(self.edos_mesh_eV, ys, c=cmap(itemp / self.ntemp), label=label,
                            **style_for_irta(irta))

            ax.grid(True)
            ax.set_ylabel(transptens2latex(what, component))

            ax.legend(loc="best", fontsize=fontsize, shadow=True)
            # x-label only on panels of the bottom row.
            if iax // ncols == nrows - 1:
                ax.set_xlabel(r"$\mu$ (eV)")

            self._add_vline_at_bandedge(ax, spin, "both")

        if "title" not in kwargs:
            fig.suptitle("Transport tensors", fontsize=fontsize)

        return fig
Пример #19
0
    def plot_doses(self, xlims=None, dos_names="all", with_idos=True, **kwargs):
        r"""
        Plot the DOSes stored in the GRUNS.nc file, one panel per DOS (single column, shared x-axis).

        Args:
            xlims: Set the data limits for the x-axis in eV. Accept tuple e.g. ``(left, right)``
                or scalar e.g. ``left``. If left (right) is None, default values are used
            dos_names: List of strings defining the DOSes to plot. Use "all" to plot all DOSes available.
            with_idos: True to display integrated doses on a secondary y-axis.

        Return: |matplotlib-Figure| or None if no DOS is available.
        """
        if not self.doses:
            return None

        if dos_names == "all":
            dos_names = _ALL_DOS_NAMES.keys()
        else:
            dos_names = list_strings(dos_names)
        wmesh = self.doses["wmesh"]

        ax_mat, fig, _ = get_axarray_fig_plt(None, nrows=len(dos_names), ncols=1,
                                             sharex=True, sharey=False, squeeze=False)
        last = len(dos_names) - 1

        for idx, (name, ax) in enumerate(zip(dos_names, ax_mat.ravel())):
            entry = self.doses[name]
            dos, idos = entry[0], entry[1]

            ax.plot(wmesh, dos, color="k")
            ax.grid(True)
            set_axlims(ax, xlims, "x")
            ax.set_ylabel(_ALL_DOS_NAMES[name]["latex"])

            if with_idos:
                # Integrated DOS drawn on a twin y-axis.
                twin = ax.twinx()
                twin.plot(wmesh, idos, color="k")
                twin.set_ylabel(_ALL_DOS_NAMES[name]["latex"].replace("DOS", "IDOS"))

            # x-label only on the bottom panel.
            if idx == last:
                ax.set_xlabel(r"$\omega$ (eV)")

        return fig
Пример #20
0
    def plot(self, fontsize=12, **kwargs):
        """
        Plot the convergence of the Wannierise cycle: one panel per quantity versus iteration step.

        When disentanglement is used, an extra "omegaI_i" panel is added, with an inset
        showing "delta_frac". The first iteration is skipped in all the curves.

        Args:
            fontsize: legend and label fontsize (currently not used by this method).

        Returns: |matplotlib-Figure| or None if no Wannierization cycle is found.
        """
        if self._parse_iterations() != 0:
            print("Wout files does not contain Wannierization cycles. Returning None")
            return None

        items = ["delta_spread", "rms_gradient", "spread"]
        if self.use_disentangle:
            items.append("omegaI_i")

        # Grid layout: 1x1 for a single panel, otherwise two columns.
        num_plots = len(items)
        ncols, nrows = (2, num_plots // 2 + num_plots % 2) if num_plots > 1 else (1, 1)

        ax_mat, fig, _ = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                             sharex=True, sharey=False, squeeze=False)
        axes = ax_mat.ravel()

        # Hide the empty grid cell left over when num_plots is odd.
        if num_plots % ncols != 0:
            axes[-1].axis("off")

        marker = "."
        start = 1
        for ax, item in zip(axes, items):
            ax.grid(True)
            ax.set_xlabel("Iteration Step")
            ax.set_ylabel(item)

            if item != "omegaI_i":
                ax.plot(self.conv_df.iter[start:], self.conv_df[item][start:], marker=marker)
                continue

            # Disentanglement cycles, with delta_frac shown in an inset.
            ax.plot(self.dis_df.iter[start:], self.dis_df[item][start:], marker=marker)
            from mpl_toolkits.axes_grid1.inset_locator import inset_axes
            inset = inset_axes(ax, width="60%", height="40%", loc="upper right")
            inset.grid(True)
            inset.set_title("delta_frac", fontsize=8)
            inset.plot(self.dis_df.iter[start:], self.dis_df["delta_frac"][start:], marker=marker)

        return fig
Пример #21
0
    def plot_centers_spread(self, fontsize=8, **kwargs):
        """
        Plot the evolution of the Wannier centers and spreads vs the iteration number.

        The grid holds one subplot per Wannier function (x, y, z components of its
        center) followed by a final subplot with the spreads of all the functions.

        Args:
            fontsize: fontsize of the legends.

        Returns: |matplotlib-Figure| or None if no Wannierization cycles are found.
        """
        if self._parse_iterations() != 0:
            print("Wout files does not contain Wannierization cycles. Returning None")
            return None

        # nwan center subplots + one subplot collecting all the spreads.
        num_plots = self.nwan + 1
        if num_plots > 1:
            ncols = 2
            nrows = num_plots // ncols + num_plots % ncols
        else:
            ncols = nrows = 1

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # Odd number of subplots: the last grid slot stays empty, hide it.
        if num_plots % ncols != 0:
            ax_list[-1].axis("off")

        marker = "."
        first = 1  # the first entry of each series is not plotted
        xyz = ("x", "y", "z")
        for iax in range(num_plots):
            ax = ax_list[iax]
            ax.grid(True)
            ax.set_xlabel("Iteration Step")
            if iax >= self.nwan:
                # Last subplot: spreads of all the Wannier functions.
                ax.set_ylabel("WF Spread")
                for iw in range(self.nwan):
                    ax.plot(self.conv_df.iter[first:], self.wf_spreads[iw, first:],
                            marker=marker, label="WF#%d" % (iw + 1))
            else:
                ax.set_ylabel("Center of WF #%s" % (iax + 1))
                for idir, comp in enumerate(xyz):
                    # Component labels are attached only in the first subplot.
                    ax.plot(self.conv_df.iter[first:], self.wf_centers[iax, first:, idir],
                            marker=marker, label=comp if iax == 0 else None)

            if iax == 0 or iax == self.nwan:
                ax.legend(loc="best", shadow=True, fontsize=fontsize)

        return fig
Пример #22
0
    def gridplot(self,
                 what_list=None,
                 sharex="row",
                 sharey="row",
                 fontsize=8,
                 **kwargs):
        """
        Plot the ``what`` value extracted from multiple HIST.nc_ files on a grid.

        Rows correspond to the entries of ``what_list``, columns to the files of the robot.

        Args:
            what_list: Quantity or list of quantities to plot.
                Must be in ["energy", "abc", "angles", "volume", "pressure", "forces"].
                If None, ``self.what_list`` is used.
            sharex: x-axis sharing policy forwarded to the grid builder (e.g. "row" or True).
            sharey: y-axis sharing policy forwarded to the grid builder.
            fontsize: fontsize for the titles, also forwarded to ``hist.plot_ax``.

        Returns: |matplotlib-Figure|
        """
        if what_list is None:
            what_list = self.what_list
        else:
            what_list = list_strings(what_list)

        nrows = len(what_list)
        ncols = len(self)
        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=sharex, sharey=sharey, squeeze=False)
        ax_mat = np.reshape(ax_mat, (nrows, ncols))
        last_row = nrows - 1

        for irow, what in enumerate(what_list):
            for icol, hist in enumerate(self.abifiles):
                ax = ax_mat[irow, icol]
                ax.grid(True)
                hist.plot_ax(ax, what, fontsize=fontsize, marker="o")

                # File names on top, x labels only on the bottom row,
                # y labels only on the first column.
                if irow == 0:
                    ax.set_title(hist.relpath, fontsize=fontsize)
                if irow != last_row:
                    set_visible(ax, False, "xlabel")
                if icol != 0:
                    set_visible(ax, False, "ylabel")

        return fig
Пример #23
0
    def plot_tau_isoe(self, ax_list=None, colormap="jet", fontsize=8, **kwargs):
        r"""
        Plot tau(e). Energy-dependent scattering rate defined by:

            $\tau(\epsilon) = \frac{1}{N_k} \sum_{nk} \tau_{nk}\,\delta(\epsilon - \epsilon_{nk})$

        One subplot for each relaxation-time approximation (SERTA, MRTA), one curve
        per temperature. For spin-polarized data the second spin is drawn with a
        negative sign.

        Args:
            ax_list: List of |matplotlib-Axes| or None if a new figure should be created.
            colormap: name of the matplotlib colormap used to color the temperatures.
            fontsize (int): fontsize for titles, legend and the RTA text label.

        Return: |matplotlib-Figure|
        """
        ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=self.nrta, ncols=1,
                                                sharex=True, sharey=True, squeeze=False)
        ax_list = ax_list.ravel()
        cmap = plt.get_cmap(colormap)

        # netCDF declaration (Fortran order): nctkarr_t('tau_dos', "dp", "edos_nw, ntemp, nsppol, nrta")
        # so the array is indexed here as tau_dos[irta, spin, itemp].
        tau_dos = self.reader.read_value("tau_dos")
        last_irta = len(ax_list) - 1

        for irta, ax in enumerate(ax_list):
            for spin in range(self.nsppol):
                sign = 1 if spin == 0 else -1
                for it, temp in enumerate(self.tmesh):
                    # Convert to femtoseconds.
                    ys = sign * tau_dos[irta, spin, it] * abu.Time_Sec * 1e+15
                    label = "T = %dK" % temp if spin == 0 else None
                    ax.plot(self.edos_mesh_eV, ys, c=cmap(it / self.ntemp), label=label)

            ax.grid(True)
            ax.legend(loc="best", shadow=True, fontsize=fontsize)
            if irta == last_irta:
                ax.set_xlabel("Energy (eV)")
                ax.set_ylabel(r"$\tau(\epsilon)\, (fms)$")

            # NOTE(review): ``spin`` is the value left over from the spin loop above
            # (i.e. the last spin index); confirm this is intended.
            self._add_vline_at_bandedge(ax, spin, "both")

            ax.text(0.1, 0.9, irta2s(irta), fontsize=fontsize,
                    horizontalalignment="center", verticalalignment="center",
                    transform=ax.transAxes, bbox=dict(alpha=0.5))

        if "title" not in kwargs:
            title = r"$\tau(\epsilon) = \frac{1}{N_k} \sum_{nk} \tau_{nk}\,\delta(\epsilon - \epsilon_{nk})$"
            fig.suptitle(title, fontsize=fontsize)

        return fig
Пример #24
0
    def plot_convergence_rank3(self, key, components="all", itemp=0, what_list=("abs",),
                               sortby="nkpt", decompose=False, xlims=None, **kwargs):
        """
        Plot convergence of arbitrary rank3 tensor. This is a low-level routine used in other plot methods.

        Builds a grid with one row per tensor component and one column per entry of ``what_list``.
        Every subplot contains the curves of all the files; only the top-left legend is shown.

        Args:
            key: Name of the quantity to analyze.
            components: List of cartesian tensor components to plot e.g. ["xxx", "xyz"].
                A single string selects a single component.
                "all" to plot all the components available in every file.
            itemp: Temperature index.
            what_list: List of quantities to plot. "re" for real part, "im" for imaginary.
                Accepts also "abs", "angle". A single string is treated as a one-element list.
            sortby: Define the convergence parameter, sort files and produce plot labels. Can be None, string or function.
                If None, no sorting is performed.
                If string, it's assumed that the ncfile has an attribute with the same name and ``getattr`` is invoked.
                If callable, the output of callable(ncfile) is used.
            decompose: True to plot individual contributions.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                or scalar e.g. ``left``. If left (right) is None, default values are used.

        Returns: |matplotlib-Figure|
        """
        if isinstance(components, str) and components == "all":
            # Components computed in every file of the robot.
            components = self.computed_components_intersection[key]
        elif isinstance(components, str):
            components = [components]
        else:
            # Previously the argument was silently overwritten: honour the user selection.
            components = list(components)

        # A bare string such as "abs" must not be split into single characters.
        if isinstance(what_list, str):
            what_list = [what_list]

        # Build grid plot: components along the rows, what_list along the columns.
        nrows, ncols = len(components), len(what_list)
        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=True, sharey=False, squeeze=False)

        # Materialize once: the same sequence is traversed for every subplot.
        label_ncfile_param = list(self.sortby(sortby))
        for i, comp in enumerate(components):
            for j, what in enumerate(what_list):
                ax = ax_mat[i, j]
                for ifile, (_, ncfile, param) in enumerate(label_ncfile_param):

                    ncfile.plot_chi2(key=key, components=comp, what=what, itemp=itemp, decompose=decompose,
                        ax=ax, xlims=xlims, with_xlabel=(i == len(components) - 1),
                        label="%s %s" % (sortby, param) if not callable(sortby) else str(param),
                        show=False, **kwargs)

                    if ifile == 0:
                        ax.set_title(ncfile.get_chi2_latex_label(key, what, comp))

                # Keep a single legend (top-left subplot) to avoid clutter.
                if (i, j) != (0, 0):
                    ax.legend().set_visible(False)

        return fig
Пример #25
0
    def plot_convergence_rank3(self, key, components="all", itemp=0, what_list=("abs",),
                               sortby="nkpt", decompose=False, xlims=None, **kwargs):
        """
        Plot convergence of arbitrary rank3 tensor. This is a low-level routine used in other plot methods.

        Builds a grid with one row per tensor component and one column per entry of ``what_list``.
        Every subplot contains the curves of all the files; only the top-left legend is shown.

        Args:
            key: Name of the quantity to analyze.
            components: List of cartesian tensor components to plot e.g. ["xxx", "xyz"].
                A single string selects a single component.
                "all" to plot all the components available in every file.
            itemp: Temperature index.
            what_list: List of quantities to plot. "re" for real part, "im" for imaginary.
                Accepts also "abs", "angle". A single string is treated as a one-element list.
            sortby: Define the convergence parameter, sort files and produce plot labels. Can be None, string or function.
                If None, no sorting is performed.
                If string, it's assumed that the ncfile has an attribute with the same name and ``getattr`` is invoked.
                If callable, the output of callable(ncfile) is used.
            decompose: True to plot individual contributions.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                or scalar e.g. ``left``. If left (right) is None, default values are used.

        Returns: |matplotlib-Figure|
        """
        if isinstance(components, str) and components == "all":
            # Components computed in every file of the robot.
            components = self.computed_components_intersection[key]
        elif isinstance(components, str):
            components = [components]
        else:
            # Previously the argument was silently overwritten: honour the user selection.
            components = list(components)

        # A bare string such as "abs" must not be split into single characters.
        if isinstance(what_list, str):
            what_list = [what_list]

        # Build grid plot: components along the rows, what_list along the columns.
        nrows, ncols = len(components), len(what_list)
        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=True, sharey=False, squeeze=False)

        # Materialize once: the same sequence is traversed for every subplot.
        label_ncfile_param = list(self.sortby(sortby))
        for i, comp in enumerate(components):
            for j, what in enumerate(what_list):
                ax = ax_mat[i, j]
                for ifile, (_, ncfile, param) in enumerate(label_ncfile_param):

                    ncfile.plot_chi2(key=key, components=comp, what=what, itemp=itemp, decompose=decompose,
                        ax=ax, xlims=xlims, with_xlabel=(i == len(components) - 1),
                        label="%s %s" % (sortby, param) if not callable(sortby) else str(param),
                        show=False, **kwargs)

                    if ifile == 0:
                        ax.set_title(ncfile.get_chi2_latex_label(key, what, comp))

                # Keep a single legend (top-left subplot) to avoid clutter.
                if (i, j) != (0, 0):
                    ax.legend().set_visible(False)

        return fig
Пример #26
0
    def plot_elastic_properties(self, fontsize=10, **kwargs):
        """
        Plot the numeric elastic properties of the files in the robot.

        One subplot per numeric column of ``get_elastic_dataframe`` (the
        ``fitted_to_structure`` column, if present, is excluded) on a three-column
        grid. The x axis enumerates the files; their labels are shown only on the
        last subplot with data. Unused grid slots are hidden.

        Args:
            fontsize: fontsize for the y labels and the x tick labels.

        Returns: |matplotlib-Figure|
        """
        df = self.get_elastic_dataframe(with_geo=False,
                                        abspath=False,
                                        with_params=False)
        from pandas.api.types import is_numeric_dtype
        # Filter explicitly: ``list.index`` raises ValueError when the item is missing
        # (it never returns -1), so the previous lookup crashed without that column.
        keys = [k for k in df.keys()
                if is_numeric_dtype(df[k]) and k != "fitted_to_structure"]

        num_plots, ncols, nrows = len(keys), 1, 1
        if num_plots > 1:
            ncols = 3
            # Ceiling division: avoids creating whole rows of empty subplots.
            nrows = (num_plots + ncols - 1) // ncols

        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=False,
                                                sharey=False,
                                                squeeze=False)
        ax_list = ax_list.ravel()

        xn = range(len(df.index))
        for key, ax in zip(keys, ax_list):
            ax.plot(xn, df[key], marker="o")
            ax.grid(True)
            ax.set_xticks(xn)
            ax.set_ylabel(key, fontsize=fontsize)
            ax.set_xticklabels([])

        if keys:
            # Label the files only on the last subplot that contains data.
            last_ax = ax_list[num_plots - 1]
            last_ax.set_xticklabels(self.keys(), fontsize=fontsize)
            rotate_ticklabels(last_ax, 15)

        # Hide the grid slots that received no data.
        for ax in ax_list[num_plots:]:
            ax.axis('off')

        return fig
Пример #27
0
    def plot_energies(self, **kwargs):
        """
        Plot the energies as a function of volume at different temperatures,
        one subplot per q-mesh (all subplots share both axes).

        kwargs are propagated to the analogous method of QHA.

        Returns: |matplotlib-Figure|
        """
        self._consistency_check()
        # Build grid of plots: one row per q-mesh.
        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=self.num_qmeshes,
                                                ncols=1,
                                                sharex=True,
                                                sharey=True,
                                                squeeze=False)
        ax_list = ax_list.ravel()

        for qha, ngqpt, ax in zip(self.qha_list, self.ngqpt_list, ax_list):
            qha.plot_energies(ax=ax, show=False, **kwargs)
            # Title typo fixed: the variable is ngqpt, not "ngpqt".
            ax.set_title("ngqpt: %s" % str(ngqpt), fontsize=self.fontsize)

        return fig
Пример #28
0
    def combiplot(self, what_list=None, colormap="jet", fontsize=6, **kwargs):
        """
        Plot multiple HIST.nc_ files on a grid. One plot for each ``what`` value.

        All the files are drawn in the same subplot, colored with ``colormap``;
        the legend (file paths) is shown only in the first subplot.

        Args:
            what_list: List of strings with the quantities to plot. If None, all quantities are plotted.
            colormap: matplotlib color map.
            fontsize: fontsize for legend.

        Returns: |matplotlib-Figure|.
        """
        what_list = (list_strings(what_list) if what_list is not None
            else ["energy", "a", "b", "c", "alpha", "beta", "gamma", "volume", "pressure"])

        num_plots, ncols, nrows = len(what_list), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()
        cmap = plt.get_cmap(colormap)

        for i, (ax, what) in enumerate(zip(ax_list, what_list)):
            plotted = False
            for ih, hist in enumerate(self.abifiles):
                # Label the curves only in the first subplot: one legend is enough.
                label = hist.relpath if i == 0 else None
                hist.plot_ax(ax, what, color=cmap(ih / len(self)), label=label, fontsize=fontsize)
                plotted = True

            # ``plotted`` avoids referencing an undefined label when there are no files.
            if i == 0 and plotted:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

            # Put "Step" on the subplots that have no other plotted subplot below them.
            # Comparing with len(ax_list) - 1 missed every subplot when num_plots is odd,
            # because that last grid slot is hidden and never plotted.
            if i + ncols >= num_plots:
                ax.set_xlabel("Step")
            else:
                ax.set_xlabel("")

        # Hide the unused grid slot when num_plots is odd.
        if num_plots % ncols != 0: ax_list[-1].axis('off')

        return fig
Пример #29
0
    def plot(self, what_list=None, ax_list=None, fontsize=8, **kwargs):
        """
        Plot the evolution of structural parameters (lattice lengths, angles and volume)
        as well as pressure, info on forces and total energy.

        Args:
            what_list: quantity or list of quantities to plot, among
                "abc", "angles", "volume", "pressure", "forces", "energy".
                None to plot all of them.
            ax_list: List of |matplotlib-Axes|. If None, a new figure is created.
            fontsize: fontsize for legend

        Returns: |matplotlib-Figure|

        Raises:
            ValueError: if the grid has fewer axes than quantities to plot.
        """
        if what_list is None:
            what_list = [
                "abc", "angles", "volume", "pressure", "forces", "energy"
            ]
        else:
            what_list = list_strings(what_list)

        nplots = len(what_list)
        nrows, ncols = 1, 1
        if nplots > 1:
            ncols = 2
            nrows = nplots // ncols + nplots % ncols

        ax_list, fig, plt = get_axarray_fig_plt(ax_list,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=True,
                                                sharey=False,
                                                squeeze=False)
        ax_list = ax_list.ravel()
        # With an odd number of plots the grid has one extra slot, so only require
        # enough axes (the former equality assert always failed in that case).
        if len(ax_list) < nplots:
            raise ValueError("Need at least %d axes, got %d" % (nplots, len(ax_list)))

        # Hide the grid slots that will not receive a plot.
        for ax in ax_list[nplots:]:
            ax.axis("off")

        for what, ax in zip(what_list, ax_list):
            self.plot_ax(ax, what, fontsize=fontsize, marker="o")

        return fig
Пример #30
0
    def plot_elastic_properties(self, fontsize=10, **kwargs):
        """
        Plot the numeric elastic properties of the files in the robot.

        One subplot per numeric column of ``get_elastic_dataframe`` (the
        ``fitted_to_structure`` column, if present, is excluded) on a three-column
        grid. The x axis enumerates the files; their labels are shown only on the
        last subplot with data. Unused grid slots are hidden.

        Args:
            fontsize: fontsize for the y labels and the x tick labels.

        Returns: |matplotlib-Figure|
        """
        df = self.get_elastic_dataframe(with_geo=False, abspath=False, with_params=False)
        from pandas.api.types import is_numeric_dtype
        # Filter explicitly: ``list.index`` raises ValueError when the item is missing
        # (it never returns -1), so the previous lookup crashed without that column.
        keys = [k for k in df.keys()
                if is_numeric_dtype(df[k]) and k != "fitted_to_structure"]

        num_plots, ncols, nrows = len(keys), 1, 1
        if num_plots > 1:
            ncols = 3
            # Ceiling division: avoids creating whole rows of empty subplots.
            nrows = (num_plots + ncols - 1) // ncols

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=False, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        xn = range(len(df.index))
        for key, ax in zip(keys, ax_list):
            ax.plot(xn, df[key], marker="o")
            ax.grid(True)
            ax.set_xticks(xn)
            ax.set_ylabel(key, fontsize=fontsize)
            ax.set_xticklabels([])

        if keys:
            # Label the files only on the last subplot that contains data.
            last_ax = ax_list[num_plots - 1]
            last_ax.set_xticklabels(self.keys(), fontsize=fontsize)
            rotate_ticklabels(last_ax, 15)

        # Hide the grid slots that received no data.
        for ax in ax_list[num_plots:]:
            ax.axis('off')

        return fig
Пример #31
0
    def plot_emass(self, acc=4, fontsize=6, colormap="viridis", **kwargs):
        """
        Plot electronic dispersion and quadratic curve based on the
        effective masses computed along each segment.

        One subplot per segment on a two-column grid sharing the y axis.

        Args:
            acc: forwarded to ``segment.plot_emass`` (presumably the accuracy of the
                finite-difference derivatives — TODO confirm).
            fontsize: legend and title fontsize.
            colormap: matplotlib colormap

        Returns: |matplotlib-Figure|
        """
        self._consistency_check()

        # Build grid of plots.
        num_plots, ncols, nrows = len(self.segments), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=False,
                                                sharey=True,
                                                squeeze=False)
        ax_list = ax_list.ravel()

        for iseg, (segment, ax) in enumerate(zip(self.segments, ax_list)):
            icol = iseg % ncols
            segment.plot_emass(ax=ax,
                               acc=acc,
                               fontsize=fontsize,
                               colormap=colormap,
                               show=False)
            # The y axis is shared: keep the y label only on the first column
            # (previously it was hidden on every subplot but the first one).
            if icol != 0: set_visible(ax, False, "ylabel")
            # Keep x tick labels only where no other segment is plotted below
            # (the slot below may be the hidden one when num_plots is odd).
            if iseg + ncols < num_plots: set_visible(ax, False, "xticklabels")

        # Don't show the last ax if num_plots is odd.
        if num_plots % ncols != 0: ax_list[-1].axis("off")

        return fig
Пример #32
0
    def plot_doses(self, xlims=None, dos_names="all", with_idos=True, **kwargs):
        r"""
        Plot the DOSes stored in the GRUNS.nc file, one subplot per DOS.

        Args:
            xlims: Set the data limits for the x-axis in eV. Accept tuple e.g. ``(left, right)``
                or scalar e.g. ``left``. If left (right) is None, default values are used
            dos_names: List of strings defining the DOSes to plot. Use "all" to plot all DOSes available.
            with_idos: True to also draw the integrated DOS on a twin y axis.

        Return: |matplotlib-Figure| or None if no DOS is available.
        """
        if not self.doses:
            return None

        if dos_names == "all":
            dos_names = _ALL_DOS_NAMES.keys()
        else:
            dos_names = list_strings(dos_names)
        wmesh = self.doses["wmesh"]
        last = len(dos_names) - 1

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=len(dos_names), ncols=1,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        for i, (name, ax) in enumerate(zip(dos_names, ax_list)):
            entry = self.doses[name]
            dos, idos = entry[0], entry[1]
            latex = _ALL_DOS_NAMES[name]["latex"]

            ax.plot(wmesh, dos, color="k")
            ax.grid(True)
            set_axlims(ax, xlims, "x")
            ax.set_ylabel(latex)

            if with_idos:
                # Integrated DOS on a secondary y axis sharing the same x axis.
                twin = ax.twinx()
                twin.plot(wmesh, idos, color="k")
                twin.set_ylabel(latex.replace("DOS", "IDOS"))

            if i == last:
                ax.set_xlabel(r"$\omega$ (eV)")

        return fig
Пример #33
0
    def compare(self, others, **kwargs):
        """
        Produce matplotlib plot comparing self with another list of pseudos ``others``.

        The 2x2 grid shows, in order: tcore in real space, tcore in q-space, vlocq
        and ffspl. ``self`` is drawn first in every panel; with a single other pseudo
        the colors are red/blue, otherwise they come from the "jet" colormap.

        Args:
            others: a pseudo or a list/tuple of pseudos.

        Returns: |matplotlib-Figure|
        """
        if not isinstance(others, (list, tuple)):
            others = [others]

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=2, ncols=2,
                                                sharex=False, sharey=False, squeeze=True)
        ax_list = ax_list.ravel()

        npseudos = 1 + len(others)

        def mkcolor(count):
            if npseudos <= 2:
                return {0: "red", 1: "blue"}[count]
            return plt.get_cmap("jet")(float(count) / npseudos)

        # (plotting method, extra keyword arguments) in grid order.
        panels = [
            ("plot_tcore_rspace", dict(with_legend=False)),
            ("plot_tcore_qspace", dict(with_qn=0)),
            ("plot_vlocq", dict(with_qn=0)),
            ("plot_ffspl", dict(with_qn=0)),
        ]
        for ax, (method, extra) in zip(ax_list, panels):
            getattr(self, method)(ax=ax, color=mkcolor(0), show=False, **extra)
            for count, other in enumerate(others, start=1):
                getattr(other, method)(ax=ax, color=mkcolor(count), show=False, **extra)

        return fig
Пример #34
0
    def plot(self, ax_list=None, fontsize=8, **kwargs):
        """
        Plot the evolution of structural parameters (lattice lengths, angles and volume)
        as well as pressure, info on forces and total energy on a 3x2 grid.

        Args:
            ax_list: List of six |matplotlib-Axes|. If None, a new figure is created.
            fontsize: fontsize for legend

        Returns: |matplotlib-Figure|

        Raises:
            ValueError: if the number of axes differs from the number of quantities.
        """
        what_list = ["abc", "angles", "volume", "pressure", "forces", "energy"]
        nrows, ncols = 3, 2
        # Forward ax_list: it was documented but previously ignored (None was always passed).
        ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()
        if len(ax_list) != len(what_list):
            raise ValueError("Expecting %d axes, got %d" % (len(what_list), len(ax_list)))

        for what, ax in zip(what_list, ax_list):
            self.plot_ax(ax, what, fontsize=fontsize, marker="o")

        return fig
Пример #35
0
    def gridplot(self, what_list=None, sharex="row", sharey="row", fontsize=8, **kwargs):
        """
        Plot the ``what`` value extracted from multiple HIST.nc_ files on a grid.

        Rows correspond to the entries of ``what_list``, columns to the files of the robot.

        Args:
            what_list: Quantity or list of quantities to plot.
                Must be in ["energy", "abc", "angles", "volume", "pressure", "forces"].
                If None, ``self.what_list`` is used.
            sharex: x-axis sharing policy forwarded to the grid builder (e.g. "row" or True).
            sharey: y-axis sharing policy forwarded to the grid builder.
            fontsize: fontsize for the titles, also forwarded to ``hist.plot_ax``.

        Returns: |matplotlib-Figure|
        """
        if what_list is None:
            what_list = self.what_list
        else:
            what_list = list_strings(what_list)

        nrows = len(what_list)
        ncols = len(self)
        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=sharex, sharey=sharey, squeeze=False)
        ax_mat = np.reshape(ax_mat, (nrows, ncols))
        last_row = nrows - 1

        for irow, what in enumerate(what_list):
            for icol, hist in enumerate(self.abifiles):
                ax = ax_mat[irow, icol]
                ax.grid(True)
                hist.plot_ax(ax, what, fontsize=fontsize, marker="o")

                # File names on top, x labels only on the bottom row,
                # y labels only on the first column.
                if irow == 0:
                    ax.set_title(hist.relpath, fontsize=fontsize)
                if irow != last_row:
                    set_visible(ax, False, "xlabel")
                if icol != 0:
                    set_visible(ax, False, "ylabel")

        return fig
Пример #36
0
    def plot(self, **kwargs):
        """
        Driver routine to plot several quantities on the same figure (2x2 grid):
        tcore in real space, tcore in q-space, ffspl and vlocq.

        Args:
            ecut_ffnl: Max cutoff energy for ffnl plot (optional), forwarded to every panel.

        Return: |matplotlib-Figure|
        """
        panels = ("plot_tcore_rspace", "plot_tcore_qspace", "plot_ffspl", "plot_vlocq")

        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=2, ncols=2,
                                               sharex=False, sharey=False, squeeze=True)

        ecut_ffnl = kwargs.pop("ecut_ffnl", None)
        for name, ax in zip(panels, ax_mat.ravel()):
            draw_panel = getattr(self, name)
            draw_panel(ax=ax, ecut_ffnl=ecut_ffnl, show=False)

        return fig
Пример #37
0
    def plot_centers_spread(self, fontsize=8, **kwargs):
        """
        Plot the evolution of the Wannier centers and spreads vs the iteration number.

        The grid holds one subplot per Wannier function (x, y, z components of its
        center) followed by a final subplot with the spreads of all the functions.

        Args:
            fontsize: fontsize of the legends.

        Returns: |matplotlib-Figure| or None if no Wannierization cycles are found.
        """
        if self._parse_iterations() != 0:
            print("Wout files does not contain Wannierization cycles. Returning None")
            return None

        # nwan center subplots + one subplot collecting all the spreads.
        num_plots = self.nwan + 1
        if num_plots > 1:
            ncols = 2
            nrows = num_plots // ncols + num_plots % ncols
        else:
            ncols = nrows = 1

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # Odd number of subplots: the last grid slot stays empty, hide it.
        if num_plots % ncols != 0:
            ax_list[-1].axis("off")

        marker = "."
        first = 1  # the first entry of each series is not plotted
        xyz = ("x", "y", "z")
        for iax in range(num_plots):
            ax = ax_list[iax]
            ax.grid(True)
            ax.set_xlabel("Iteration Step")
            if iax >= self.nwan:
                # Last subplot: spreads of all the Wannier functions.
                ax.set_ylabel("WF Spread")
                for iw in range(self.nwan):
                    ax.plot(self.conv_df.iter[first:], self.wf_spreads[iw, first:],
                            marker=marker, label="WF#%d" % (iw + 1))
            else:
                ax.set_ylabel("Center of WF #%s" % (iax + 1))
                for idir, comp in enumerate(xyz):
                    # Component labels are attached only in the first subplot.
                    ax.plot(self.conv_df.iter[first:], self.wf_centers[iax, first:, idir],
                            marker=marker, label=comp if iax == 0 else None)

            if iax == 0 or iax == self.nwan:
                ax.legend(loc="best", shadow=True, fontsize=fontsize)

        return fig
Пример #38
0
    def plot(self, fontsize=12, **kwargs):
        """
        Plot the convergence of the Wannierise cycle.

        One subplot per quantity: ``delta_spread``, ``rms_gradient`` and ``spread``
        taken from ``self.conv_df`` and, if disentanglement is used, ``omegaI_i``
        taken from ``self.dis_df`` with an inset showing ``delta_frac``.

        Args:
            fontsize: fontsize of the axis labels.

        Returns: |matplotlib-Figure| or None if the file does not contain Wannierization cycles.
        """
        if self._parse_iterations() != 0:
            print(
                "Wout files does not contain Wannierization cycles. Returning None"
            )
            return None

        items = ["delta_spread", "rms_gradient", "spread"]
        if self.use_disentangle:
            items += ["omegaI_i"]

        # Build grid of plots: two columns as soon as there is more than one quantity.
        num_plots, ncols, nrows = len(items), 1, 1
        if num_plots > 1:
            ncols = 2
            # Ceiling division: just enough rows to host all the subplots.
            nrows = (num_plots + ncols - 1) // ncols

        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=True,
                                                sharey=False,
                                                squeeze=False)
        ax_list = ax_list.ravel()

        # Hide the unused grid slot when num_plots is odd.
        if num_plots % ncols != 0:
            ax_list[-1].axis("off")

        marker = "."
        # Series are plotted starting from index 1, i.e. the first entry is skipped.
        s = 1
        for ax, item in zip(ax_list, items):
            ax.grid(True)
            # fontsize is documented as the label fontsize but was previously never used.
            ax.set_xlabel("Iteration Step", fontsize=fontsize)
            ax.set_ylabel(item, fontsize=fontsize)
            if item == "omegaI_i":
                # Disentanglement cycles, with delta_frac shown in an inset.
                ax.plot(self.dis_df.iter[s:],
                        self.dis_df[item][s:],
                        marker=marker)
                from mpl_toolkits.axes_grid1.inset_locator import inset_axes
                ax2 = inset_axes(ax,
                                 width="60%",
                                 height="40%",
                                 loc="upper right")
                ax2.grid(True)
                ax2.set_title("delta_frac", fontsize=8)
                ax2.plot(self.dis_df.iter[s:],
                         self.dis_df["delta_frac"][s:],
                         marker=marker)
            else:
                ax.plot(self.conv_df.iter[s:],
                        self.conv_df[item][s:],
                        marker=marker)

        return fig
Пример #39
0
    def plot_pots_at_qpoint(self, qpoint=0, fontsize=8, **kwargs):
        """
        Compare the original DFPT potentials (``origin_v1scf``) with the reconstructed
        ones (``recons_v1scf``) at a given q-point.

        One subplot per perturbation nu = 3 * iatom + idir. Each subplot shows the difference
        between the complex phases (in degrees) of the two potentials, restricted to the points
        where ``|origin_v1| > 1e-5``, plus a text box with statistics of the absolute difference
        ``|origin_v1 - recons_v1|`` (max, min, mean, std and relative L1 error).

        Args:
            qpoint: q-point selector passed to ``self._find_iqpt_qpoint``
                (the default 0 presumably selects the first q-point).
            fontsize: fontsize for legends and titles

        Return: |matplotlib-Figure|
        """
        iq, qpoint = self._find_iqpt_qpoint(qpoint)

        # complex arrays with shape: (natom3, nspden * nfft)
        origin_v1 = self.read_v1_at_iq("origin_v1scf", iq, reshape_nfft_nspden=True)
        symm_v1 = self.read_v1_at_iq("recons_v1scf", iq, reshape_nfft_nspden=True)

        ncols, nrows = 3, self.natom3 // 3
        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=False,
                                                sharey=False,
                                                squeeze=False)

        natom = len(self.structure)
        for nu, ax in enumerate(ax_list.ravel()):
            # nu = 3 * ipert + idir
            ipert, idir = divmod(nu, 3)

            # Relative L1 error: l1_rerr = \int |f1 - f2| dr / \int |f1| dr
            # with f1 = original potential and f2 = reconstructed potential.
            abs_diff = np.abs(origin_v1[nu] - symm_v1[nu])
            abs_origin = np.abs(origin_v1[nu])
            l1_rerr = np.sum(abs_diff) / np.sum(abs_origin)

            stats = OrderedDict([
                ("max", abs_diff.max()),
                ("min", abs_diff.min()),
                ("mean", abs_diff.mean()),
                ("std", abs_diff.std()),
                ("L1_rerr", l1_rerr),
            ])

            ax.grid(True)
            ax.set_title("idir: %d, iat: %d, pertsy: %d" %
                         (idir, ipert, self.pertsy_qpt[iq, ipert, idir]),
                         fontsize=fontsize)

            # Phase difference in degrees, kept only where the original potential is not
            # negligible: the phase of a (nearly) zero complex number carries no information.
            data = np.angle(origin_v1[nu], deg=True) - np.angle(symm_v1[nu], deg=True)
            data = data[abs_origin > 1e-5]
            ax.plot(np.arange(len(data)),
                    data,
                    linestyle="--",
                    color="red",
                    alpha=0.4,
                    label="diff angle degrees" if nu == 0 else None)

            if nu == 0:
                ax.set_ylabel(r"Phase diff (deg)")
                ax.legend(loc="best", fontsize=fontsize, shadow=True)
            if ipert == natom - 1:
                ax.set_xlabel(r"Filtered FFT index")

            # Statistics of the absolute difference, in axes coordinates.
            ax.text(0.7,
                    0.7,
                    "\n".join("%s = %.1E" % item for item in stats.items()),
                    fontsize=fontsize,
                    horizontalalignment='center',
                    verticalalignment='center',
                    transform=ax.transAxes)

        fig.suptitle("qpoint: %s" % repr(qpoint))
        return fig
Пример #40
0
    def plot_convergence_items(self, items, sortby=None, hue=None, fontsize=6, **kwargs):
        """
        Plot the convergence of a list of ``items`` wrt to the ``sortby`` parameter.
        Values can optionally be grouped by ``hue``.

        Args:
            items: List of attributes (or callables) to be analyzed.
            sortby: Define the convergence parameter, sort files and produce plot labels.
                Can be None, string or function. If None, no sorting is performed.
                If string and not empty it's assumed that the abifile has an attribute
                with the same name and `getattr` is invoked.
                If callable, the output of sortby(abifile) is used.
            hue: Variable that define subsets of the data, which will be drawn on separate lines.
                Accepts callable or string
                If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
                Dot notation is also supported e.g. hue="structure.formula" --> abifile.structure.formula
                If callable, the output of hue(abifile) is used.
            fontsize: legend and label fontsize.
            kwargs: keyword arguments are passed to ax.plot (``marker`` defaults to "o").

        Returns: |matplotlib-Figure|
        """
        # Note: in principle one could call plot_convergence inside a loop but
        # this one is faster as sorting is done only once.

        def get_yvals(item, abifiles):
            """Return the values of ``item`` (callable or dotted attribute name) for each file."""
            if callable(item):
                return [float(item(abifile)) for abifile in abifiles]
            return [getattrd(abifile, item) for abifile in abifiles]

        # Build grid plot.
        nrows, ncols = len(items), 1
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # Sort and group files if hue.
        if hue is None:
            # The files are returned in the same order as params: the y-values must be
            # extracted from this sorted list (not from self.abifiles) to stay aligned with x.
            _, sorted_abifiles, params = self.sortby(sortby, unpack=True)
        else:
            groups = self.group_and_sortby(hue, sortby)

        marker = kwargs.pop("marker", "o")
        for i, (ax, item) in enumerate(zip(ax_list, items)):
            if hue is None:
                yvals = get_yvals(item, sorted_abifiles)
                if not is_string(params[0]):
                    ax.plot(params, yvals, marker=marker, **kwargs)
                else:
                    # Must handle list of strings in a different way.
                    xn = range(len(params))
                    ax.plot(xn, yvals, marker=marker, **kwargs)
                    ax.set_xticks(xn)
                    ax.set_xticklabels(params, fontsize=fontsize)
            else:
                for g in groups:
                    yvals = get_yvals(item, g.abifiles)
                    label = "%s: %s" % (self._get_label(hue), g.hvalue)
                    ax.plot(g.xvalues, yvals, label=label, marker=marker, **kwargs)

            ax.grid(True)
            ax.set_ylabel(self._get_label(item))
            if i == len(items) - 1:
                ax.set_xlabel("%s" % self._get_label(sortby))
                if sortby is None: rotate_ticklabels(ax, 15)
            if i == 0 and hue is not None:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig
Пример #41
0
    def plot_lattice_convergence(self, what_list=None, sortby=None, hue=None, fontsize=8, **kwargs):
        """
        Plot the convergence of the lattice parameters (a, b, c, volume, alpha, beta, gamma)
        wrt the ``sortby`` parameter. Values can optionally be grouped by ``hue``.

        Args:
            what_list: List of strings with the quantities to plot e.g. ["a", "alpha", "beta"].
                Allowed names: "a", "b", "c", "volume", "alpha", "beta", "gamma".
                None means all.
            sortby: Define the convergence parameter, sort files and produce plot labels.
                Can be None, string or function.
                If None, no sorting is performed.
                If string and not empty it's assumed that the abifile has an attribute
                with the same name and `getattr` is invoked.
                If callable, the output of sortby(abifile) is used.
            hue: Variable that define subsets of the data, which will be drawn on separate lines.
                Accepts callable or string
                If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
                Dot notation is also supported e.g. hue="structure.formula" --> abifile.structure.formula
                If callable, the output of hue(abifile) is used.
            fontsize: legend and label fontsize.
            kwargs: only ``marker`` (default "o") is used; other keywords are ignored.

        Returns: |matplotlib-Figure| or None if the robot contains no files.

        Raises:
            TypeError: if the files expose neither ``structure`` nor ``final_structure``.
            ValueError: if ``what_list`` contains an unknown name.

        Example:

             robot.plot_lattice_convergence()

             robot.plot_lattice_convergence(sortby="nkpt")

             robot.plot_lattice_convergence(sortby="nkpt", hue="tsmear")
        """
        if not self.abifiles: return None

        # The majority of AbiPy files have a structure object
        # whereas Hist.nc defines final_structure. Use getattr and key to extract structure object.
        key = "structure"
        if not hasattr(self.abifiles[0], "structure"):
            if hasattr(self.abifiles[0], "final_structure"):
                key = "final_structure"
            else:
                raise TypeError("Don't know how to extract structure from %s" % type(self.abifiles[0]))

        # Define callbacks. docstrings will be used as ylabels.
        def a(afile):
            "a (Ang)"
            return getattr(afile, key).lattice.a
        def b(afile):
            "b (Ang)"
            return getattr(afile, key).lattice.b
        def c(afile):
            "c (Ang)"
            return getattr(afile, key).lattice.c
        def volume(afile):
            r"$V$"
            return getattr(afile, key).lattice.volume
        def alpha(afile):
            r"$\alpha$"
            return getattr(afile, key).lattice.alpha
        def beta(afile):
            r"$\beta$"
            return getattr(afile, key).lattice.beta
        def gamma(afile):
            r"$\gamma$"
            return getattr(afile, key).lattice.gamma

        items = [a, b, c, volume, alpha, beta, gamma]
        if what_list is not None:
            # Explicit name --> callback map. Looking names up in locals() would also
            # accept unrelated local names (self, key, sortby, ...) and fail obscurely later.
            name2item = {func.__name__: func for func in items}
            selected = []
            for what in list_strings(what_list):
                if what not in name2item:
                    raise ValueError("Invalid entry `%s` in what_list. Choose among: %s" %
                                     (what, list(name2item)))
                selected.append(name2item[what])
            items = selected

        # Build plot grid.
        nrows, ncols = len(items), 1
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)

        marker = kwargs.pop("marker", "o")
        for i, (ax, item) in enumerate(zip(ax_list.ravel(), items)):
            self.plot_convergence(item, sortby=sortby, hue=hue, ax=ax, fontsize=fontsize,
                                  marker=marker, show=False)
            # Keep the legend only in the first subplot and the x-label only in the last one.
            if i != 0:
                set_visible(ax, False, "legend")
            if i != len(items) - 1:
                set_visible(ax, False, "xlabel")

        return fig
Пример #42
0
    def plot_tensor(self,
                    tstart=0,
                    tstop=600,
                    num=50,
                    components="all",
                    what="displ",
                    view="inequivalent",
                    select_symbols=None,
                    colormap="jet",
                    xlims=None,
                    ylims=None,
                    fontsize=10,
                    verbose=0,
                    **kwargs):
        """
        Plot tensor(T) for each atom in the unit cell.
        One subplot for each component, each subplot show all inequivalent sites.
        By default, only "inequivalent" atoms are shown.

        Args:
            tstart: The starting value (in Kelvin) of the temperature mesh.
            tstop: The end value (in Kelvin) of the mesh.
            num: int, optional Number of samples to generate.
            components: "all" for all components. "diag" for diagonal elements, "offdiag" for off-diagonal terms only.
            what: "displ" for displacement, "vel" for velocity.
            view: "inequivalent" to show only inequivalent atoms. "all" for all sites.
            select_symbols: String or list of strings with chemical symbols. Used to select only atoms of this type.
            colormap: matplotlib colormap.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used.
            ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used
            fontsize: Legend and title fontsize.
            verbose: Verbosity level.

        Returns: |matplotlib-Figure|
        """
        # Select atoms.
        aview = self._get_atomview(view,
                                   select_symbols=select_symbols,
                                   verbose=verbose)

        # One subplot for each component
        diag = ["xx", "yy", "zz"]
        offdiag = ["xy", "xz", "yz"]
        components = {
            "all": diag + offdiag,
            "diag": diag,
            "offdiag": offdiag,
        }[components]

        components = self._get_components(components)
        # Components are arranged on a grid with 3 columns.
        shape = np.reshape(components, (-1, 3)).shape
        nrows, ncols = shape[0], shape[1]

        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=True,
                                                sharey=True,
                                                squeeze=True)
        # squeeze=True may drop the row axis: normalize to a flat array of axes.
        ax_list = np.reshape(ax_list, (nrows, ncols)).ravel()
        cmap = plt.get_cmap(colormap)

        # Compute U(T)
        tmesh = np.linspace(tstart, tstop, num=num)
        msq = self.get_msq_tmesh(tmesh,
                                 iatom_list=aview.iatom_list,
                                 what_list=what)
        # [natom,3,3,nt] array
        values = getattr(msq, what)

        nsites = len(aview.iatom_list)
        for ix, (ax, comp) in enumerate(zip(ax_list, components)):
            irow = ix // ncols
            ax.grid(True)
            set_axlims(ax, xlims, "x")
            set_axlims(ax, ylims, "y")
            ylabel = comp.get_tavg_label(what, with_units=True)
            ax.set_ylabel(ylabel, fontsize=fontsize)

            # Plot this component for all inequivalent atoms on the same subplot.
            for ii, (iatom, site_label) in enumerate(zip(aview.iatom_list, aview.site_labels)):
                color = cmap(float(ii) / max(nsites - 1, 1))
                ys = comp.eval33w(values[iatom])
                ax.plot(msq.tmesh,
                        ys,
                        label=site_label if ix == 0 else None,
                        color=color)

            # Site labels are attached only to the first subplot: build its legend once.
            if ix == 0 and nsites:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

            # The x-axis is shared: show label and tick labels only on the bottom row.
            # (Hard-coding row 1 hid them everywhere when the grid has a single row.)
            if irow == nrows - 1:
                ax.set_xlabel('Temperature (K)')
            else:
                set_visible(ax, False, "xlabel", "xticklabels")

        return fig
Пример #43
0
    def plot_uiso(self,
                  tstart=0,
                  tstop=600,
                  num=50,
                  what="displ",
                  view="inequivalent",
                  select_symbols=None,
                  colormap="jet",
                  xlims=None,
                  ylims=None,
                  sharey=False,
                  fontsize=10,
                  verbose=0,
                  **kwargs):
        """
        Plot the isotropic value and the anisotropy factor of the tensor(T) for each atom.

        Two subplots are produced, each showing all the selected sites:

            * top: isotropic value, computed as the mean of the diagonal elements
              (trace / 3) of the 3x3 tensor (U_iso for ``what="displ"``, V_iso for ``what="vel"``).
            * bottom: anisotropy factor, defined as the ratio between the maximum and the
              minimum eigenvalue of the tensor. A ratio of 1 corresponds to an isotropic tensor.

        By default, only "inequivalent" atoms are shown.

        Args:
            tstart: The starting value (in Kelvin) of the temperature mesh.
            tstop: The end value (in Kelvin) of the mesh.
            num: int, optional Number of samples to generate.
            what: "displ" for displacement, "vel" for velocity.
            view: "inequivalent" to show only inequivalent atoms. "all" for all sites.
            select_symbols: String or list of strings with chemical symbols. Used to select only atoms of this type.
            colormap: matplotlib colormap.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used.
            ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used
            sharey: True if y-axis should be shared.
            fontsize: Legend and title fontsize.
            verbose: Verbosity level.

        Returns: |matplotlib-Figure|

        Raises:
            ValueError: if ``what`` is neither "displ" nor "vel".
        """
        # Validate `what` before doing any (expensive) work and select the y-axis labels.
        # Note: braces in the mathtext string must be balanced or matplotlib fails at draw time.
        aniso_label = r"Anisotropy factor ($\dfrac{\epsilon_{max}}{\epsilon_{min}}$)"
        if what == "displ":
            ylabels = [r"$U_{iso}\;(\AA^2)$", aniso_label]
        elif what == "vel":
            ylabels = [r"$V_{iso}\;(m/s)^2$", aniso_label]
        else:
            raise ValueError("Unknown value for what: `%s`" % str(what))

        # Select atoms.
        aview = self._get_atomview(view,
                                   select_symbols=select_symbols,
                                   verbose=verbose)

        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=2,
                                                ncols=1,
                                                sharex=True,
                                                sharey=sharey,
                                                squeeze=True)
        cmap = plt.get_cmap(colormap)

        # Compute U(T)
        tmesh = np.linspace(tstart, tstop, num=num)
        msq = self.get_msq_tmesh(tmesh,
                                 iatom_list=aview.iatom_list,
                                 what_list=what)
        # [natom, 3, 3, nt]
        values = getattr(msq, what)

        nsites = len(aview.iatom_list)
        for ix, (ax, ylabel) in enumerate(zip(ax_list, ylabels)):
            ax.grid(True)
            set_axlims(ax, xlims, "x")
            set_axlims(ax, ylims, "y")
            ax.set_ylabel(ylabel, fontsize=fontsize)

            # Plot this quantity for all selected atoms on the same subplot.
            for ii, (iatom, site_label) in enumerate(zip(aview.iatom_list, aview.site_labels)):
                color = cmap(float(ii) / max(nsites - 1, 1))
                tensor = values[iatom]  # [3, 3, nt]
                if ix == 0:
                    # Isotropic value: mean of the diagonal elements.
                    ys = np.trace(tensor) / 3.0
                else:
                    # Ratio between the largest and the smallest eigenvalue at each temperature.
                    # eigvalsh accepts stacks of matrices: move the temperature axis first.
                    eigs = np.linalg.eigvalsh(np.moveaxis(tensor, -1, 0), UPLO="U")
                    ys = eigs.max(axis=-1) / eigs.min(axis=-1)

                ax.plot(msq.tmesh,
                        ys,
                        label=site_label if ix == 0 else None,
                        color=color)

            # Site labels are attached only to the first subplot: build its legend once.
            if ix == 0 and nsites:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

            if ix == len(ax_list) - 1:
                ax.set_xlabel("Temperature (K)")
            else:
                set_visible(ax, False, "xlabel", "xticklabels")

        return fig
Пример #44
0
    def plot_lattice_convergence(self,
                                 what_list=None,
                                 sortby=None,
                                 hue=None,
                                 fontsize=8,
                                 **kwargs):
        """
        Plot the convergence of the lattice parameters (a, b, c, volume, alpha, beta, gamma)
        wrt the ``sortby`` parameter. Values can optionally be grouped by ``hue``.

        Args:
            what_list: List of strings with the quantities to plot e.g. ["a", "alpha", "beta"].
                Allowed names: "a", "b", "c", "volume", "alpha", "beta", "gamma".
                None means all.
            sortby: Define the convergence parameter, sort files and produce plot labels.
                Can be None, string or function.
                If None, no sorting is performed.
                If string and not empty it's assumed that the abifile has an attribute
                with the same name and `getattr` is invoked.
                If callable, the output of sortby(abifile) is used.
            hue: Variable that define subsets of the data, which will be drawn on separate lines.
                Accepts callable or string
                If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
                Dot notation is also supported e.g. hue="structure.formula" --> abifile.structure.formula
                If callable, the output of hue(abifile) is used.
            fontsize: legend and label fontsize.
            kwargs: only ``marker`` (default "o") is used; other keywords are ignored.

        Returns: |matplotlib-Figure| or None if the robot contains no files.

        Raises:
            TypeError: if the files expose neither ``structure`` nor ``final_structure``.
            ValueError: if ``what_list`` contains an unknown name.

        Example:

             robot.plot_lattice_convergence()

             robot.plot_lattice_convergence(sortby="nkpt")

             robot.plot_lattice_convergence(sortby="nkpt", hue="tsmear")
        """
        if not self.abifiles: return None

        # The majority of AbiPy files have a structure object
        # whereas Hist.nc defines final_structure. Use getattr and key to extract structure object.
        key = "structure"
        if not hasattr(self.abifiles[0], "structure"):
            if hasattr(self.abifiles[0], "final_structure"):
                key = "final_structure"
            else:
                raise TypeError("Don't know how to extract structure from %s" %
                                type(self.abifiles[0]))

        # Define callbacks. docstrings will be used as ylabels.
        def a(afile):
            "a (Ang)"
            return getattr(afile, key).lattice.a

        def b(afile):
            "b (Ang)"
            return getattr(afile, key).lattice.b

        def c(afile):
            "c (Ang)"
            return getattr(afile, key).lattice.c

        def volume(afile):
            r"$V$"
            return getattr(afile, key).lattice.volume

        def alpha(afile):
            r"$\alpha$"
            return getattr(afile, key).lattice.alpha

        def beta(afile):
            r"$\beta$"
            return getattr(afile, key).lattice.beta

        def gamma(afile):
            r"$\gamma$"
            return getattr(afile, key).lattice.gamma

        items = [a, b, c, volume, alpha, beta, gamma]
        if what_list is not None:
            # Explicit name --> callback map. Looking names up in locals() would also
            # accept unrelated local names (self, key, sortby, ...) and fail obscurely later.
            name2item = {func.__name__: func for func in items}
            selected = []
            for what in list_strings(what_list):
                if what not in name2item:
                    raise ValueError("Invalid entry `%s` in what_list. Choose among: %s" %
                                     (what, list(name2item)))
                selected.append(name2item[what])
            items = selected

        # Build plot grid.
        nrows, ncols = len(items), 1
        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=True,
                                                sharey=False,
                                                squeeze=False)

        marker = kwargs.pop("marker", "o")
        for i, (ax, item) in enumerate(zip(ax_list.ravel(), items)):
            self.plot_convergence(item,
                                  sortby=sortby,
                                  hue=hue,
                                  ax=ax,
                                  fontsize=fontsize,
                                  marker=marker,
                                  show=False)
            # Keep the legend only in the first subplot and the x-label only in the last one.
            if i != 0:
                set_visible(ax, False, "legend")
            if i != len(items) - 1:
                set_visible(ax, False, "xlabel")

        return fig
Пример #45
0
    def plot_convergence_items(self,
                               items,
                               sortby=None,
                               hue=None,
                               fontsize=6,
                               **kwargs):
        """
        Plot the convergence of a list of ``items`` wrt to the ``sortby`` parameter.
        Values can optionally be grouped by ``hue``.

        Args:
            items: List of attributes (or callables) to be analyzed.
            sortby: Define the convergence parameter, sort files and produce plot labels.
                Can be None, string or function. If None, no sorting is performed.
                If string and not empty it's assumed that the abifile has an attribute
                with the same name and `getattr` is invoked.
                If callable, the output of sortby(abifile) is used.
            hue: Variable that define subsets of the data, which will be drawn on separate lines.
                Accepts callable or string
                If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
                Dot notation is also supported e.g. hue="structure.formula" --> abifile.structure.formula
                If callable, the output of hue(abifile) is used.
            fontsize: legend and label fontsize.
            kwargs: keyword arguments are passed to ax.plot (``marker`` defaults to "o").

        Returns: |matplotlib-Figure|
        """
        # Note: in principle one could call plot_convergence inside a loop but
        # this one is faster as sorting is done only once.

        def get_yvals(item, abifiles):
            """Return the values of ``item`` (callable or dotted attribute name) for each file."""
            if callable(item):
                return [float(item(abifile)) for abifile in abifiles]
            return [getattrd(abifile, item) for abifile in abifiles]

        # Build grid plot.
        nrows, ncols = len(items), 1
        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=True,
                                                sharey=False,
                                                squeeze=False)
        ax_list = ax_list.ravel()

        # Sort and group files if hue.
        if hue is None:
            # The files are returned in the same order as params: the y-values must be
            # extracted from this sorted list (not from self.abifiles) to stay aligned with x.
            _, sorted_abifiles, params = self.sortby(sortby, unpack=True)
        else:
            groups = self.group_and_sortby(hue, sortby)

        marker = kwargs.pop("marker", "o")
        for i, (ax, item) in enumerate(zip(ax_list, items)):
            if hue is None:
                yvals = get_yvals(item, sorted_abifiles)
                if not is_string(params[0]):
                    ax.plot(params, yvals, marker=marker, **kwargs)
                else:
                    # Must handle list of strings in a different way.
                    xn = range(len(params))
                    ax.plot(xn, yvals, marker=marker, **kwargs)
                    ax.set_xticks(xn)
                    ax.set_xticklabels(params, fontsize=fontsize)
            else:
                for g in groups:
                    yvals = get_yvals(item, g.abifiles)
                    label = "%s: %s" % (self._get_label(hue), g.hvalue)
                    ax.plot(g.xvalues,
                            yvals,
                            label=label,
                            marker=marker,
                            **kwargs)

            ax.grid(True)
            ax.set_ylabel(self._get_label(item))
            if i == len(items) - 1:
                ax.set_xlabel("%s" % self._get_label(sortby))
                if sortby is None: rotate_ticklabels(ax, 15)
            if i == 0 and hue is not None:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig
Пример #46
0
    def plot(self,
             mdf_type="exc",
             qview="avg",
             xlims=None,
             ylims=None,
             fontsize=8,
             **kwargs):
        """
        Plot all macroscopic dielectric functions (MDF) stored in the plotter.

        Args:
            mdf_type: Selects the type of dielectric function.
                "exc" for the MDF with excitonic effects.
                "rpa" for RPA with KS energies.
                "gwrpa" for RPA with GW (or KS-corrected) results.
            qview: "avg" to plot the results averaged over q-points. "all" to plot q-point dependence.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. `(left, right)`
                  or scalar e.g. `left`. If left (right) is None, default values are used
            ylims: Same meaning as `xlims` but for the y-axis
            fontsize: fontsize for titles and legend.

        Return: |matplotlib-Figure|

        Raises:
            ValueError: if ``qview`` is neither "avg" nor "all".
        """
        # Validate qview and build the plot grid.
        # Column 0 holds the real part and column 1 the imaginary part.
        # There is a single row for "avg" and one row per q-point for "all".
        if qview == "avg":
            qpoints = None
            ncols, nrows = 2, 1
        elif qview == "all":
            qpoints = self._get_qpoints()
            ncols, nrows = 2, len(qpoints)
        else:
            raise ValueError("Invalid value of qview: %s" % str(qview))

        ax_mat, fig, plt = get_axarray_fig_plt(None,
                                               nrows=nrows,
                                               ncols=ncols,
                                               sharex=True,
                                               sharey=True,
                                               squeeze=False)

        if qview == "avg":
            # Averaged values: legend only on the real-part panel.
            for icol, cplx in enumerate(("Re", "Im")):
                self.plot_mdftype_cplx(mdf_type,
                                       cplx,
                                       ax=ax_mat[0, icol],
                                       xlims=xlims,
                                       ylims=ylims,
                                       fontsize=fontsize,
                                       with_legend=(icol == 0),
                                       show=False)
        else:
            # qview == "all" (validated above): one row per q-point, MDF(q).
            nqpt = len(qpoints)
            for iq, qpt in enumerate(qpoints):
                islast = (iq == nqpt - 1)
                for icol, cplx in enumerate(("Re", "Im")):
                    self.plot_mdftype_cplx(mdf_type,
                                           cplx,
                                           qpoint=qpt,
                                           ax=ax_mat[iq, icol],
                                           xlims=xlims,
                                           ylims=ylims,
                                           fontsize=fontsize,
                                           # Legend only once, on the first real-part panel.
                                           with_legend=(iq == 0 and icol == 0),
                                           # Axis labels only on the bottom row.
                                           with_xlabel=islast,
                                           with_ylabel=islast,
                                           show=False)

        return fig
Пример #47
0
    def plot_line_neighbors(self, site_index, radius, num=200, with_krphase=False, max_nn=10, fontsize=12, **kwargs):
        """
        Plot (interpolated) density/potential in real space along the lines connecting
        an atom specified by ``site_index`` and all neighbors within a sphere of given ``radius``.

        .. warning:

            This routine can produce lots of plots! Be careful with the value of ``radius``.
            See also ``max_nn``.

        Args:
            site_index: Index of the atom in the structure.
            radius: Radius of the sphere in Angstrom.
            num: Number of points sampled along the line.
            with_krphase: True to include the :math:`e^{ikr}` phase-factor.
            max_nn: By default, only the first ``max_nn`` neighbors are shown (None for all).
            fontsize: legend and label fontsize.

        Return: |matplotlib-Figure| or None if no neighbor is found within ``radius``.
        """
        site = self.structure[site_index]
        nn_list = self.structure.get_neighbors(site, radius, include_index=True)
        if not nn_list:
            cprint("Zero neighbors found for radius %s Ang. Returning None." % radius, "yellow")
            return None

        # Sort neighbors by distance (item 1 of each neighbor entry).
        nn_list = sorted(nn_list, key=lambda t: t[1])
        if max_nn is not None and len(nn_list) > max_nn:
            cprint("For radius %s, found %s neighbors but only max_nn %s sites are shown." %
                    (radius, len(nn_list), max_nn), "yellow")
            nn_list = nn_list[:max_nn]

        # Grid of axes with one row per neighbor.
        # squeeze=False guarantees an ndarray of axes, even when there is a single neighbor.
        # squeeze=True would return a bare Axes for a 1x1 grid, and the Axes has no `ravel`.
        nrows, ncols = len(nn_list), 1
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=True, squeeze=False)
        ax_list = ax_list.ravel()

        interpolator = self.get_interpolator()
        kpoint = None if not with_krphase else self.kpoint
        which = r"\psi(r)" if with_krphase else "u(r)"

        # For each neighbor, plot psi along the line connecting site to nn.
        for i, (nn, ax) in enumerate(zip(nn_list, ax_list)):
            # Index explicitly instead of unpacking into 3 names.
            # Depending on the pymatgen version, a neighbor entry may carry extra items.
            nn_site, nn_dist = nn[0], nn[1]
            title = "%s, %s, dist=%.3f A" % (nn_site.species_string, str(nn_site.frac_coords), nn_dist)

            r = interpolator.eval_line(site.frac_coords, nn_site.frac_coords, num=num, kpoint=kpoint)

            for ispinor in range(self.nspinor):
                spinor_label = latex_label_ispinor(ispinor, self.nspinor)
                ur = r.values[ispinor]
                ax.plot(r.dist, ur.real, label=r"$\Re %s$ %s" % (which, spinor_label))
                ax.plot(r.dist, ur.imag, label=r"$\Im %s$ %s" % (which, spinor_label))
                ax.plot(r.dist, ur.real**2 + ur.imag**2, label=r"$|\psi(r)|^2$ %s" % spinor_label)

            ax.set_title(title, fontsize=fontsize)
            ax.grid(True)

            # x-label and legend only on the bottom panel.
            if i == nrows - 1:
                ax.set_xlabel("Distance from site_index %s [Angstrom]" % site_index)
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig
Пример #48
0
    def plot_diff_at_qpoint(self, qpoint=0, fontsize=8, **kwargs):
        """
        Compare the ``origin_v1scf`` and ``recons_v1scf`` potentials at a single q-point.

        Draw one histogram panel per perturbation of the absolute difference between the two
        potentials. Each panel is annotated with the max/min/mean/std of that difference and
        the L1 relative error.

        Args:
            qpoint: q-point selector forwarded to ``self._find_iqpt_qpoint``.
                The default is 0; presumably an index or reduced coordinates, so verify against that helper.
            fontsize: fontsize for legends and titles

        Return: |matplotlib-Figure|
        """
        iq, qpoint = self._find_iqpt_qpoint(qpoint)

        # complex arrays with shape: (natom3, nspden * nfft)
        origin_v1 = self.read_v1_at_iq("origin_v1scf",
                                       iq,
                                       reshape_nfft_nspden=True)
        symm_v1 = self.read_v1_at_iq("recons_v1scf",
                                     iq,
                                     reshape_nfft_nspden=True)

        # One panel per perturbation: 3 columns (idir) and natom3 // 3 rows.
        ncols, nrows = 3, self.natom3 // 3
        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=False,
                                                sharey=False,
                                                squeeze=False)

        for nu, ax in enumerate(ax_list.ravel()):
            # Decompose the perturbation index as nu = 3 * ipert + idir.
            idir = nu % 3
            ipert = (nu - idir) // 3

            # L1 relative error: \int |f1 - f2| dr / \int |f1| dr,
            # with f1 = origin_v1 (reference) and f2 = symm_v1 (reconstructed).
            abs_diff = np.abs(origin_v1[nu] - symm_v1[nu])
            l1_rerr = np.sum(abs_diff) / np.sum(np.abs(origin_v1[nu]))

            stats = OrderedDict([
                ("max", abs_diff.max()),
                ("min", abs_diff.min()),
                ("mean", abs_diff.mean()),
                ("std", abs_diff.std()),
                ("L1_rerr", l1_rerr),
            ])

            ax.hist(abs_diff, facecolor='g', alpha=0.75)
            ax.grid(True)
            ax.set_title("idir: %d, iat: %d, pertsy: %d" %
                         (idir, ipert, self.pertsy_qpt[iq, ipert, idir]),
                         fontsize=fontsize)

            # Mark the mean absolute difference with a dashed vertical line.
            ax.axvline(stats["mean"],
                       color='k',
                       linestyle='dashed',
                       linewidth=1)
            # Statistics box placed in axes coordinates so it is independent of the data range.
            ax.text(0.7,
                    0.7,
                    "\n".join("%s = %.1E" % item for item in stats.items()),
                    fontsize=fontsize,
                    horizontalalignment='center',
                    verticalalignment='center',
                    transform=ax.transAxes)

        fig.suptitle("qpoint: %s" % repr(qpoint))
        return fig