Beispiel #1
0
    def plot(self, use_index=False, fontsize=8, **kwargs):
        """
        Plot all arrays, one ax per dataset stored in ``self.od``.

        Args:
            use_index: If False (default), the x-values are taken from the first array
                and the remaining arrays are plotted as y-values.
                If True, the x-values are the row index and all arrays are plotted.
            fontsize: fontsize for title.
            kwargs: options passed to ``ax.plot``.

        Return: |matplotlib-figure|
        """
        # Build grid of plots: two columns as soon as there is more than one dataset.
        num_plots, ncols, nrows = len(self.od), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, _ = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                              sharex=False, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # With an odd number of datasets the last cell of the grid is unused: hide it.
        if num_plots % ncols != 0:
            ax_list[-1].axis("off")

        for ax, (key, arr) in zip(ax_list, self.od.items()):
            ax.set_title(key, fontsize=fontsize)
            ax.grid(True)
            if use_index:
                xs, ys_list = list(range(len(arr[0]))), arr
            else:
                xs, ys_list = arr[0], arr[1:]
            for ys in ys_list:
                # Forward the user options, as promised in the docstring.
                ax.plot(xs, ys, **kwargs)

        return fig
Beispiel #2
0
    def plot(self, use_index=False, fontsize=8, colormap="viridis", **kwargs):
        """
        Plot all arrays, one ax per key shared by all files.

        Each file is drawn with its own color sampled from ``colormap``. Within a file,
        the i-th array always gets the same line style, so curves can be matched
        across files and axes.

        Args:
            use_index: If False (default), the x-values are taken from the first array
                and the remaining arrays are plotted as y-values.
                If True, the x-values are the row index and all arrays are plotted.
            fontsize: fontsize for title and legend.
            colormap: matplotlib color map.
            kwargs: options passed to ``ax.plot``. They override the default
                color, linestyle and label.

        Return: |matplotlib-figure| or None if there is nothing to plot.
        """
        if not self.odlist: return None

        # Keys present in every file, kept in the order of the first file
        # so that the layout of the figure is reproducible.
        first, others = self.odlist[0], self.odlist[1:]
        keys = [k for k in first if all(k in od for od in others)]
        if not keys:
            print("Warning: cannot find common keys in files. Check input data")
            return None

        # Build grid of plots: two columns as soon as there is more than one key.
        num_plots, ncols, nrows = len(keys), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=False, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # With an odd number of keys the last cell of the grid is unused: hide it.
        if num_plots % ncols != 0:
            ax_list[-1].axis("off")

        cmap = plt.get_cmap(colormap)
        # Style is chosen by the array position inside a file (not by a cycle shared
        # across files), so the i-th array looks the same in every file and ax.
        linestyles = ["-", ":", "--", "-."]
        num_files = len(self.odlist)

        for ax, key in zip(ax_list, keys):
            ax.set_title(key, fontsize=fontsize)
            ax.grid(True)
            for iod, (od, filepath) in enumerate(zip(self.odlist, self.filepaths)):
                arr = od[key]
                color = cmap(iod / num_files)
                if use_index:
                    xvals, arr_list = list(range(len(arr[0]))), arr
                else:
                    xvals, arr_list = arr[0], arr[1:]
                for iarr, ys in enumerate(arr_list):
                    # Only the first curve of each file gets a legend entry.
                    plot_kws = dict(color=color,
                                    linestyle=linestyles[iarr % len(linestyles)],
                                    label=os.path.relpath(filepath) if iarr == 0 else None)
                    plot_kws.update(kwargs)
                    ax.plot(xvals, ys, **plot_kws)

            ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig
Beispiel #3
0
    def plot(self, ax_list=None, fontsize=12, **kwargs):
        """
        Uses matplotlib to plot the evolution of the SCF cycle.

        One ax per variable stored in the cycle, laid out on a two-column grid.
        The first iteration is not shown when there is more than one.

        Args:
            ax_list: List of axes. If None a new figure is produced.
            fontsize: legend fontsize.
            kwargs: keyword arguments are passed to ax.plot. If ``label`` is given,
                it is attached to the curve of the first ax only, and only that ax
                shows a legend.

        Returns: matplotlib figure
        """
        # Build grid of plots.
        num_plots, ncols, nrows = len(self), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = num_plots // ncols + num_plots % ncols

        ax_list, fig, _ = get_axarray_fig_plt(ax_list, nrows=nrows, ncols=ncols,
                                              sharex=True, sharey=False, squeeze=False)
        ax_list = np.array(ax_list).ravel()

        iter_num = np.arange(1, self.num_iterations + 1)
        label = kwargs.pop("label", None)

        for i, ((key, values), ax) in enumerate(zip(self.items(), ax_list)):
            ax.grid(True)
            ax.set_xlabel('Iteration Step')
            ax.set_xticks(iter_num, minor=False)
            ax.set_ylabel(key)

            # Array form so that the elementwise comparison and min/max below
            # also work when `values` is a plain list.
            xx, yy = iter_num, np.asarray(values)
            if self.num_iterations > 1:
                # Don't show the first iteration since it's not very useful.
                xx, yy = xx[1:], yy[1:]

            if not kwargs and label is None:
                ax.plot(xx, yy, "-o", lw=2.0)
            else:
                ax.plot(xx, yy, label=label if i == 0 else None, **kwargs)

            # Log scale only if every plotted value is strictly positive (> 1e-22).
            if key in _VARS_SUPPORTING_LOGSCALE and np.all(yy > 1e-22):
                ax.set_yscale("log")

            # Clip the y-range when the data span exceeds the reference range.
            if key in _VARS_WITH_YRANGE:
                ymin, ymax = _VARS_WITH_YRANGE[key]
                val_min, val_max = np.min(yy), np.max(yy)
                if abs(val_max - val_min) > abs(ymax - ymin):
                    ax.set_ylim(ymin, ymax)

            # The label lives on the first ax only; a legend elsewhere would be empty.
            if label is not None and i == 0:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

        # Get around a bug in matplotlib: draw an invisible line on the unused
        # last ax before hiding it.
        if num_plots % ncols != 0:
            ax_list[-1].plot(xx, yy, lw=0.0)
            ax_list[-1].axis('off')

        return fig
Beispiel #4
0
    def plot(self, ax_list=None, fontsize=12, **kwargs):
        """
        Plot relaxation history i.e. the results of the last iteration of each SCF cycle.

        One ax per variable stored in ``self.history``, laid out on a two-column grid.

        Args:
            ax_list: List of axes. If None a new figure is produced.
            fontsize: legend fontsize.
            kwargs: keyword arguments are passed to ax.plot. If ``label`` is given,
                it is attached to the curve of the first ax only, and only that ax
                shows a legend.

        Returns: matplotlib figure
        """
        history = self.history

        # Build grid of plots.
        num_plots, ncols, nrows = len(history), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = num_plots // ncols + num_plots % ncols

        ax_list, fig, _ = get_axarray_fig_plt(ax_list, nrows=nrows, ncols=ncols,
                                              sharex=True, sharey=False, squeeze=False)
        ax_list = np.array(ax_list).ravel()

        iter_num = np.arange(1, self.num_iterations + 1)
        label = kwargs.pop("label", None)

        for i, ((key, values), ax) in enumerate(zip(history.items(), ax_list)):
            ax.grid(True)
            ax.set_xlabel('Relaxation Step')
            ax.set_xticks(iter_num, minor=False)
            ax.set_ylabel(key)

            # Array form so that the elementwise comparison and min/max below
            # also work when `values` is a plain list.
            xx, yy = iter_num, np.asarray(values)
            if not kwargs and label is None:
                ax.plot(xx, yy, "-o", lw=2.0)
            else:
                ax.plot(xx, yy, label=label if i == 0 else None, **kwargs)

            # Log scale only if every plotted value is strictly positive (> 1e-22).
            if key in _VARS_SUPPORTING_LOGSCALE and np.all(yy > 1e-22):
                ax.set_yscale("log")

            # Clip the y-range when the data span exceeds the reference range.
            if key in _VARS_WITH_YRANGE:
                ymin, ymax = _VARS_WITH_YRANGE[key]
                val_min, val_max = np.min(yy), np.max(yy)
                if abs(val_max - val_min) > abs(ymax - ymin):
                    ax.set_ylim(ymin, ymax)

            # The label lives on the first ax only; a legend elsewhere would be empty.
            if label is not None and i == 0:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

        # Get around a bug in matplotlib: draw an invisible line on the unused
        # last ax before hiding it.
        if num_plots % ncols != 0:
            ax_list[-1].plot(xx, yy, lw=0.0)
            ax_list[-1].axis('off')

        return fig
Beispiel #5
0
    def plot(self, use_index=False, fontsize=8, **kwargs):
        """
        Plot all arrays, one ax per dataset stored in ``self.od``.

        Args:
            use_index: If False (default), the x-values are taken from the first array
                and the remaining arrays are plotted as y-values.
                If True, the x-values are the row index and all arrays are plotted.
            fontsize: fontsize for title.
            kwargs: options passed to ``ax.plot``.

        Return: |matplotlib-figure|
        """
        # Build grid of plots: two columns as soon as there is more than one dataset.
        num_plots, ncols, nrows = len(self.od), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, _ = get_axarray_fig_plt(None,
                                              nrows=nrows,
                                              ncols=ncols,
                                              sharex=False,
                                              sharey=False,
                                              squeeze=False)
        ax_list = ax_list.ravel()

        # With an odd number of datasets the last cell of the grid is unused: hide it.
        if num_plots % ncols != 0:
            ax_list[-1].axis("off")

        for ax, (key, arr) in zip(ax_list, self.od.items()):
            ax.set_title(key, fontsize=fontsize)
            ax.grid(True)
            if use_index:
                xs, ys_list = list(range(len(arr[0]))), arr
            else:
                xs, ys_list = arr[0], arr[1:]
            for ys in ys_list:
                # Forward the user options, as promised in the docstring.
                ax.plot(xs, ys, **kwargs)

        return fig
Beispiel #6
0
#%%

from matplotlib import pyplot as plt
import numpy as np
from pymatgen.electronic_structure.plotter import DosPlotter
from pymatgen.util.plotting import get_axarray_fig_plt, pretty_plot
import os
os.chdir('/home/jinho93/oxides/perobskite/lanthanum-aluminate/slab/14')

# Column 0 holds the y-values (energy axis); columns 1..num are the curves,
# drawn side by side, one panel per column.
arr = np.genfromtxt('lsmo.dat')
num = len(arr[0]) - 1
ax_array, fig, plt = get_axarray_fig_plt(None, ncols=num, sharey=True)
plt = pretty_plot(12, 12, plt=plt)
# Draw on the axes created above so that every data column (1..num) gets
# exactly one panel and no overlapping subplots are added on top of them.
for ax, col in zip(np.array(ax_array).ravel(), range(1, num + 1)):
    ax.plot(arr[:, col], arr[:, 0])
    ax.set_ylim((-5, 3))
    ax.set_xlim(0, 3)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')
fig.subplots_adjust(wspace=0)
fig.savefig('figure.png')
plt.show()

# %%
Beispiel #7
0
    def plot(self, use_index=False, fontsize=8, colormap="viridis", **kwargs):
        """
        Plot all arrays, one ax per key shared by all files.

        Each file is drawn with its own color sampled from ``colormap``. Within a file,
        the i-th array always gets the same line style, so curves can be matched
        across files and axes.

        Args:
            use_index: If False (default), the x-values are taken from the first array
                and the remaining arrays are plotted as y-values.
                If True, the x-values are the row index and all arrays are plotted.
            fontsize: fontsize for title and legend.
            colormap: matplotlib color map.
            kwargs: options passed to ``ax.plot``. They override the default
                color, linestyle and label.

        Return: |matplotlib-figure| or None if there is nothing to plot.
        """
        if not self.odlist: return None

        # Keys present in every file, kept in the order of the first file
        # so that the layout of the figure is reproducible.
        first, others = self.odlist[0], self.odlist[1:]
        keys = [k for k in first if all(k in od for od in others)]
        if not keys:
            print(
                "Warning: cannot find common keys in files. Check input data")
            return None

        # Build grid of plots: two columns as soon as there is more than one key.
        num_plots, ncols, nrows = len(keys), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None,
                                                nrows=nrows,
                                                ncols=ncols,
                                                sharex=False,
                                                sharey=False,
                                                squeeze=False)
        ax_list = ax_list.ravel()

        # With an odd number of keys the last cell of the grid is unused: hide it.
        if num_plots % ncols != 0:
            ax_list[-1].axis("off")

        cmap = plt.get_cmap(colormap)
        # Style is chosen by the array position inside a file (not by a cycle shared
        # across files), so the i-th array looks the same in every file and ax.
        linestyles = [
            "-",
            ":",
            "--",
            "-.",
        ]
        num_files = len(self.odlist)

        for ax, key in zip(ax_list, keys):
            ax.set_title(key, fontsize=fontsize)
            ax.grid(True)
            for iod, (od,
                      filepath) in enumerate(zip(self.odlist, self.filepaths)):
                arr = od[key]
                color = cmap(iod / num_files)
                if use_index:
                    xvals, arr_list = list(range(len(arr[0]))), arr
                else:
                    xvals, arr_list = arr[0], arr[1:]
                for iarr, ys in enumerate(arr_list):
                    # Only the first curve of each file gets a legend entry.
                    plot_kws = dict(
                        color=color,
                        linestyle=linestyles[iarr % len(linestyles)],
                        label=os.path.relpath(filepath) if iarr == 0 else None)
                    plot_kws.update(kwargs)
                    ax.plot(xvals, ys, **plot_kws)

            ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig
Beispiel #8
0
def plot_xy_with_hue(data,
                     x,
                     y,
                     hue,
                     decimals=None,
                     ax=None,
                     xlims=None,
                     ylims=None,
                     fontsize=12,
                     **kwargs):
    """
    Plot y = f(x) relation for different values of `hue`.
    Useful for convergence tests done wrt two parameters.

    Args:
        data: |pandas-DataFrame| containing columns `x`, `y`, and `hue`.
        x: Name of the column used as x-value
        y: Name of the column(s) used as y-value. If a list/tuple is given,
            one ax per name is created on a two-column grid.
        hue: Variable that define subsets of the data, which will be drawn on separate lines
        decimals: Number of decimal places to round `hue` columns. Ignore if None
        ax: |matplotlib-Axes| or None if a new figure should be created.
        xlims ylims: Set the data limits for the x(y)-axis. Accept tuple e.g. `(left, right)`
            or scalar e.g. `left`. If left (right) is None, default values are used
        fontsize: Legend fontsize.
        kwargs: Keyword arguments are passed to ax.plot method.
            A ``show`` entry, if present, is discarded and never forwarded to ax.plot.

    Returns: |matplotlib-Figure|

    Raises:
        ValueError: if `x`, `y` or `hue` is not a column of `data`.
    """
    # `show` is passed by the recursive call below but is not a valid ax.plot
    # option: drop it so that it never reaches matplotlib (and is not duplicated
    # when forwarded again in the recursive call).
    kwargs.pop("show", None)

    if isinstance(y, (list, tuple)):
        # Recursive call for each ax in ax_list.
        num_plots, ncols, nrows = len(y), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, _ = get_axarray_fig_plt(None,
                                              nrows=nrows,
                                              ncols=ncols,
                                              sharex=False,
                                              sharey=False,
                                              squeeze=False)

        ax_list = ax_list.ravel()
        # With an odd number of panels the last cell of the grid is unused: hide it.
        if num_plots % ncols != 0:
            ax_list[-1].axis('off')

        for yname, ax in zip(y, ax_list):
            plot_xy_with_hue(data,
                             x,
                             str(yname),
                             hue,
                             decimals=decimals,
                             ax=ax,
                             xlims=xlims,
                             ylims=ylims,
                             fontsize=fontsize,
                             show=False,
                             **kwargs)
        return fig

    # Check here because pandas error messages are a bit cryptic.
    miss = [k for k in (x, y, hue) if k not in data]
    if miss:
        raise ValueError(
            "Cannot find `%s` in dataframe.\nAvailable keys are: %s" %
            (str(miss), str(data.keys())))

    # Truncate values in hue column so that we can group.
    if decimals is not None:
        data = data.round({hue: decimals})

    ax, fig, _ = get_ax_fig_plt(ax=ax)
    for key, grp in data.groupby(hue):
        # Sort by x and rearrange the y-values accordingly.
        xy = np.array(sorted(zip(grp[x], grp[y]), key=lambda t: t[0]))
        xvals, yvals = xy[:, 0], xy[:, 1]

        label = str(key)
        if not kwargs:
            ax.plot(xvals, yvals, 'o-', label=label)
        else:
            ax.plot(xvals, yvals, label=label, **kwargs)

    ax.grid(True)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    set_axlims(ax, xlims, "x")
    set_axlims(ax, ylims, "y")
    ax.legend(loc="best", fontsize=fontsize, shadow=True)

    return fig
Beispiel #9
0
def plot_xy_with_hue(data, x, y, hue, decimals=None, ax=None,
                     xlims=None, ylims=None, fontsize=12, **kwargs):
    """
    Plot y = f(x) relation for different values of `hue`.
    Useful for convergence tests done wrt two parameters.

    Args:
        data: |pandas-DataFrame| containing columns `x`, `y`, and `hue`.
        x: Name of the column used as x-value
        y: Name of the column(s) used as y-value. If a list/tuple is given,
            one ax per name is created on a two-column grid.
        hue: Variable that define subsets of the data, which will be drawn on separate lines
        decimals: Number of decimal places to round `hue` columns. Ignore if None
        ax: |matplotlib-Axes| or None if a new figure should be created.
        xlims ylims: Set the data limits for the x(y)-axis. Accept tuple e.g. `(left, right)`
            or scalar e.g. `left`. If left (right) is None, default values are used
        fontsize: Legend fontsize.
        kwargs: Keyword arguments are passed to ax.plot method.
            A ``show`` entry, if present, is discarded and never forwarded to ax.plot.

    Returns: |matplotlib-Figure|

    Raises:
        ValueError: if `x`, `y` or `hue` is not a column of `data`.
    """
    # `show` is passed by the recursive call below but is not a valid ax.plot
    # option: drop it so that it never reaches matplotlib (and is not duplicated
    # when forwarded again in the recursive call).
    kwargs.pop("show", None)

    if isinstance(y, (list, tuple)):
        # Recursive call for each ax in ax_list.
        num_plots, ncols, nrows = len(y), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, _ = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                              sharex=False, sharey=False, squeeze=False)

        ax_list = ax_list.ravel()
        # With an odd number of panels the last cell of the grid is unused: hide it.
        if num_plots % ncols != 0:
            ax_list[-1].axis('off')

        for yname, ax in zip(y, ax_list):
            plot_xy_with_hue(data, x, str(yname), hue, decimals=decimals, ax=ax,
                             xlims=xlims, ylims=ylims, fontsize=fontsize, show=False, **kwargs)
        return fig

    # Check here because pandas error messages are a bit cryptic.
    miss = [k for k in (x, y, hue) if k not in data]
    if miss:
        raise ValueError("Cannot find `%s` in dataframe.\nAvailable keys are: %s" % (str(miss), str(data.keys())))

    # Truncate values in hue column so that we can group.
    if decimals is not None:
        data = data.round({hue: decimals})

    ax, fig, _ = get_ax_fig_plt(ax=ax)
    for key, grp in data.groupby(hue):
        # Sort by x and rearrange the y-values accordingly.
        xy = np.array(sorted(zip(grp[x], grp[y]), key=lambda t: t[0]))
        xvals, yvals = xy[:, 0], xy[:, 1]

        label = str(key)
        if not kwargs:
            ax.plot(xvals, yvals, 'o-', label=label)
        else:
            ax.plot(xvals, yvals, label=label, **kwargs)

    ax.grid(True)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    set_axlims(ax, xlims, "x")
    set_axlims(ax, ylims, "y")
    ax.legend(loc="best", fontsize=fontsize, shadow=True)

    return fig
Beispiel #10
0
os.chdir(
    '/home/ksrc5/FTJ/1.bfo/111-dir/junction/sto/vasp/vac/conf3/4.node03/dense_k_dos/again'
)

# Layer-resolved DOS: split the cell along c into `num` equal slices and sum the
# site-projected DOS of the sites whose fractional c falls in each slice.
vrun = Vasprun('vasprun.xml')
s: Structure = vrun.final_structure
cdos = vrun.complete_dos
doss = {}
num = 16
arr = np.linspace(0, 1, num, endpoint=False)
darr = arr[1] - arr[0]
for j in arr:
    # NOTE(review): sites whose fractional c lies outside [0, 1) match no slice
    # and are ignored.
    densities = [cdos.get_site_dos(site).get_densities()
                 for site in s.sites if j + darr > site.c >= j]
    # np.sum([], axis=0) would give the scalar 0.0; an empty slice needs a
    # zero array on the energy grid instead.
    if densities:
        densities = np.sum(densities, axis=0)
    else:
        densities = np.zeros(len(cdos.energies))
    doss[f'{j:.2f}'] = Dos(cdos.efermi, cdos.energies, {Spin.up: densities})

ax_array, fig, plt = get_axarray_fig_plt(None, nrows=num, sharex=True)
plt = pretty_plot(12, 6, plt=plt)
# dict.popitem() is LIFO, so the first (top) panel shows the slice with the largest c.
for i in range(num):
    label, dos = doss.popitem()
    # A fresh plotter per panel, so each panel shows only its own slice.
    dsp = DosPlotter(sigma=0.05)
    dsp.add_dos(label, dos)
    fig.add_subplot(num, 1, i + 1)
    dsp.get_plot(xlim=(-1, 2), plt=plt)
plt.savefig('figure.png')
plt.show()