示例#1
0
    def plot_densities(self, ax=None, timesr2=False, **kwargs):
        """
        Plot every density stored in ``self.densities`` on axis ``ax``.

        Args:
            ax: matplotlib Axes, or None to let ``get_ax_fig_plt`` create a new figure.
            timesr2: If True, plot r^2 * n(r) instead of n(r).
            kwargs: Accepted for interface uniformity; not used.

        Returns:
            The matplotlib figure returned by ``get_ax_fig_plt``.
        """
        ax, fig, _ = get_ax_fig_plt(ax)

        handles, labels = [], []
        for label, rho in self.densities.items():
            ydata = rho.values * rho.rmesh**2 if timesr2 else rho.values
            handle, = ax.plot(rho.rmesh, ydata,
                              linewidth=self.linewidth,
                              markersize=self.markersize)
            handles.append(handle)
            labels.append(label)

        decorate_ax(ax,
                    xlabel="r [Bohr]",
                    ylabel="$r^2 n(r)$" if timesr2 else "$n(r)$",
                    title="Charge densities",
                    lines=handles,
                    legends=labels)

        return fig
示例#2
0
    def plot_key(self, key, ax=None, **kwargs):
        """
        Plot a single quantity specified by ``key``.

        Dispatches to the method named ``"plot_" + key`` with ``ax`` and ``kwargs``,
        then calls ``self._mplt.show()``.

        Args:
            key: Suffix of the plotting method to call (``"plot_" + key``).
            ax: matplotlib Axes, or None to let ``get_ax_fig_plt`` create a new figure.
            kwargs: Forwarded to the ``plot_<key>`` method.

        Returns:
            The matplotlib figure associated with ``ax``.
        """
        ax, fig, plt = get_ax_fig_plt(ax)

        # key --> self.plot_<key>()
        getattr(self, "plot_" + key)(ax=ax, **kwargs)
        self._mplt.show()

        # Return the figure for consistency with the other plot_* methods.
        return fig
示例#3
0
    def plot_stacked_hist(self, key="wall_time", nmax=5, ax=None, **kwargs):
        """
        Plot stacked histogram of the different timers.

        One bar position is used per timer of MPI rank "0". Each bar stacks the `key`
        values of the first `nmax` sections plus an `others` entry holding the sum of
        the remaining sections.

        Args:
            key: Keyword used to extract data from the timers. Only the first `nmax`
                sections with largest value are shown.
            nmax: Maximum number of sections to show. Other entries are grouped together
                in the `others` section.
            ax: matplotlib :class:`Axes` or None if a new figure should be created.

        Returns:
            `matplotlib` figure
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        mpi_rank = "0"
        timers = self.timers(mpi_rank=mpi_rank)
        n = len(timers)

        names, values = [], []
        rest = np.zeros(n)

        for idx, sname in enumerate(self.section_names(ordkey=key)):
            sections = self.get_sections(sname)
            svals = np.asarray([s.__dict__[key] for s in sections])
            if idx < nmax:
                names.append(sname)
                values.append(svals)
            else:
                rest += svals

        names.append("others (nmax=%d)" % nmax)
        values.append(rest)

        # The dataset is stored in values. Now create the stacked histogram.
        ind = np.arange(n)  # the locations for the groups
        width = 0.35        # the width of the bars
        # Cycle through a fixed palette. Replicating the palette nmax times leaves
        # no colors at all when nmax <= 0, which broke the `others` bar.
        palette = ['r', 'g', 'b', 'c', 'k', 'y', 'm']

        bars = []
        bottom = np.zeros(n)
        for idx, vals in enumerate(values):
            color = palette[idx % len(palette)]
            # align="center" is explicit so the ticks placed at `ind` below sit
            # under the bar centers.
            bar = ax.bar(ind, vals, width, color=color, bottom=bottom, align="center")
            bars.append(bar)
            bottom += vals

        ax.set_ylabel(key)
        ax.set_title("Stacked histogram with the %d most important sections" % nmax)

        labels = ["MPI=%d, OMP=%d" % (t.mpi_nprocs, t.omp_nthreads) for t in timers]
        ax.set_xticks(ind)
        ax.set_xticklabels(labels, rotation=15)

        # Add legend.
        ax.legend([bar[0] for bar in bars], names, loc="best")

        return fig
示例#4
0
    def plot_radial_wfs(self, ax=None, **kwargs):
        """
        Plot ae and ps radial wavefunctions on axis ax.

        Args:
            ax: matplotlib Axes, or None to let ``get_ax_fig_plt`` create a new figure.
            kwargs:
                lselect: List of l channels to SKIP. Channels whose l is in this list
                    are not plotted. Defaults to an empty list, i.e. plot every channel.

        Returns:
            matplotlib figure.
        """
        ax, fig, plt = get_ax_fig_plt(ax)

        ae_wfs, ps_wfs = self.radial_wfs.ae, self.radial_wfs.ps
        lselect = kwargs.get("lselect", [])

        lines, legends = [], []
        for nlk, ae_wf in ae_wfs.items():
            ps_wf, l, k = ps_wfs[nlk], nlk.l, nlk.k
            # Despite its name, lselect lists the channels to exclude.
            if l in lselect:
                continue

            ae_line, = ax.plot(ae_wf.rmesh, ae_wf.values, **self._wf_pltopts(l, "ae"))
            ps_line, = ax.plot(ps_wf.rmesh, ps_wf.values, **self._wf_pltopts(l, "ps"))

            lines.extend([ae_line, ps_line])
            if k is None:
                legends.extend(["AE l=%s" % l, "PS l=%s" % l])
            else:
                legends.extend(["AE l=%s, k=%s" % (l, k), "PS l=%s, k=%s" % (l, k)])

        # Raw string: "\p" is an invalid escape sequence in a plain string literal
        # (DeprecationWarning, SyntaxWarning since Python 3.12). The value is unchanged.
        decorate_ax(ax, xlabel="r [Bohr]", ylabel=r"$\phi(r)$", title="Wave Functions",
                    lines=lines, legends=legends)

        return fig
示例#5
0
    def plot_der_densities(self, ax=None, order=1, **kwargs):
        """
        Plot the derivatives of the densities on axis ax.
        Used to analyze possible derivative discontinuities.

        Only the entry of ``self.densities`` named "rhoM" is plotted. It is splined onto
        a linear mesh four times denser than its own mesh before finite differencing.

        Args:
            ax: matplotlib Axes, or None to let ``get_ax_fig_plt`` create a new figure.
            order: Derivative order passed to ``finite_diff``.

        Returns:
            matplotlib figure.
        """
        ax, fig, plt = get_ax_fig_plt(ax)

        # Same local imports as plot_der_potentials; finite_diff was otherwise unresolved here.
        from abipy.tools.derivatives import finite_diff
        from scipy.interpolate import UnivariateSpline

        lines, legends = [], []
        for name, rho in self.densities.items():
            if name != "rhoM":
                continue
            # Need linear mesh for finite_difference --> Spline input densities on lin_rmesh
            lin_rmesh, h = np.linspace(rho.rmesh[0], rho.rmesh[-1], num=len(rho.rmesh) * 4, retstep=True)
            spline = UnivariateSpline(rho.rmesh, rho.values, s=0)
            lin_values = spline(lin_rmesh)
            vder = finite_diff(lin_values, h, order=order, acc=4)
            line, = ax.plot(lin_rmesh, vder)
            lines.append(line)

            legends.append("%s-order derivative of %s" % (order, name))

        # Plain "n(r)": a backslash here ("\n") would insert a newline into the mathtext label.
        decorate_ax(ax, xlabel="r [Bohr]", ylabel="$D^%s n(r)$" % order, title="Derivative of the charge densities",
                    lines=lines, legends=legends)
        return fig
示例#6
0
文件: plotting.py 项目: zbwang/abipy
def plot_xy_with_hue(data,
                     x,
                     y,
                     hue,
                     decimals=None,
                     ax=None,
                     xlims=None,
                     ylims=None,
                     fontsize=12,
                     **kwargs):
    """
    Plot y = f(x) relation for different values of `hue`.
    Useful for convergence tests done wrt to two parameters.

    Args:
        data: |pandas-DataFrame| containing columns `x`, `y`, and `hue`.
        x: Name of the column used as x-value
        y: Name of the column used as y-value
        hue: Variable that define subsets of the data, which will be drawn on separate lines
        decimals: Number of decimal places to round `hue` columns. Ignore if None
        ax: |matplotlib-Axes| or None if a new figure should be created.
        xlims ylims: Set the data limits for the x(y)-axis. Accept tuple e.g. `(left, right)`
            or scalar e.g. `left`. If left (right) is None, default values are used
        fontsize: Legend fontsize.
        kwargs: Keyword arguments are passed to ax.plot method. If empty, the 'o-' style is used.

    Returns: |matplotlib-Figure|

    Raises:
        ValueError: if `x`, `y` or `hue` is not a column of `data`.
    """
    # Check here because pandas error messages are a bit cryptic.
    miss = [k for k in (x, y, hue) if k not in data]
    if miss:
        raise ValueError(
            "Cannot find `%s` in dataframe.\nAvailable keys are: %s" %
            (str(miss), str(data.keys())))

    # Truncate values in hue column so that we can group.
    if decimals is not None:
        data = data.round({hue: decimals})

    ax, fig, plt = get_ax_fig_plt(ax=ax)
    for key, grp in data.groupby(hue):
        # Sort by x with a stable sort (like the builtin `sorted`), keeping each column's
        # own dtype. Packing both columns into a single np.array would coerce them to a
        # common dtype, e.g. turning numeric y values into strings for a string x column.
        grp = grp.sort_values(by=x, kind="mergesort")
        xvals, yvals = grp[x].values, grp[y].values

        label = "{} = {}".format(hue, key)
        if not kwargs:
            ax.plot(xvals, yvals, 'o-', label=label)
        else:
            ax.plot(xvals, yvals, label=label, **kwargs)

    ax.grid(True)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    set_axlims(ax, xlims, "x")
    set_axlims(ax, ylims, "y")
    ax.legend(loc="best", fontsize=fontsize, shadow=True)

    return fig
示例#7
0
    def plot(self,
             ax=None,
             components=('xx', ),
             reim="reim",
             vertical_lines=True,
             **kwargs):
        """
        Plot components of ``self.ir_spectra_tensor`` (labelled epsilon) versus frequency.

        Args:
            ax: matplotlib Axes or None if a new figure should be created.
            components: Iterable of two-character strings built from 'x', 'y', 'z' selecting
                the tensor elements (i, j) to plot, e.g. ('xx', 'yz').
            reim: String selecting the parts to plot: the real part is plotted if it contains
                "re", the imaginary part if it contains "im".
            vertical_lines: If True, mark ``ir_spectra_generator.phfreqs[3:]`` with scatter
                markers at y = 0 (the first three frequencies are skipped — presumably the
                acoustic modes; verify).
            kwargs: Passed to ``ax.plot``.

        Returns:
            matplotlib figure.
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)
        directions_map = {'x': 0, 'y': 1, 'z': 2}
        functions_map = {'re': lambda x: x.real, 'im': lambda x: x.imag}
        reim_label = {'re': 'Re', 'im': 'Im'}
        for component in components:
            i, j = [directions_map[direction] for direction in component]
            for fstr in functions_map:
                if fstr in reim:
                    f = functions_map[fstr]
                    # Raw string: "\e" is an invalid escape sequence in a plain literal.
                    label = r"%s{$\epsilon_{%s}$}" % (reim_label[fstr], component)
                    # The x axis is labelled meV; frequencies are scaled by 1000
                    # (presumably stored in eV — verify).
                    ax.plot(self.frequencies * 1000,
                            f(self.ir_spectra_tensor[:, i, j]),
                            label=label,
                            **kwargs)

        if vertical_lines:
            phfreqs = self.ir_spectra_generator.phfreqs[3:]
            ax.scatter(phfreqs * 1000, np.zeros_like(phfreqs))
        # The epsilon label belongs on the y axis; it used to be overwritten by a second set_xlabel.
        ax.set_ylabel(r'$\epsilon(\omega)$')
        ax.set_xlabel('Frequency (meV)')
        ax.legend()
        return fig
示例#8
0
    def plot_atan_logders(self, ax=None, **kwargs):
        """
        Plot the arctan of the AE and PS logarithmic derivatives on axis ax.

        The curves of channel l are shifted upwards by (l + 1) so that the
        different channels do not overlap.

        Returns:
            matplotlib figure.
        """
        ae_alogs, ps_alogs = self.atan_logders.ae, self.atan_logders.ps

        ax, fig, _ = get_ax_fig_plt(ax)

        handles, labels = [], []
        for l, ae_alog in ae_alogs.items():
            ps_alog = ps_alogs[l]
            # Vertical offset that keeps the channels apart.
            shift = (l + 1) * 1.0

            for alog, tag in ((ae_alog, "ae"), (ps_alog, "ps")):
                handle, = ax.plot(alog.energies, alog.values + shift,
                                  **self._wf_pltopts(l, tag))
                handles.append(handle)
            labels.extend(["AE l=%s" % str(l), "PS l=%s" % str(l)])

        decorate_ax(ax,
                    xlabel="Energy [Ha]",
                    ylabel="ATAN(LogDer)",
                    title="ATAN(Log Derivative)",
                    lines=handles,
                    legends=labels)

        return fig
示例#9
0
    def plot_errors_for_structure(self, struct_type, ax=None, **kwargs):
        """
        Plot the errors for a given crystalline structure.

        For the rows whose "struct_type" equals ``struct_type``, plot the relative error in %,
        100 * (col - ae) / ae, of the "this" and "gbrv_paw" columns against "formula".

        Args:
            struct_type: Value of the "struct_type" column used to select rows.
            ax: matplotlib Axes or None if a new figure should be created.

        Returns:
            matplotlib figure, or None if no row matches ``struct_type``.
        """
        data = self[self["struct_type"] == struct_type].copy()
        if not len(data):
            print("No results available for struct_type:", struct_type)
            return None

        # Create the axes only once we know there is something to plot,
        # so that no empty figure is left behind.
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        colnames = ["this", "gbrv_paw"]
        for col in colnames:
            data[col + "_rel_err"] = 100 * (data[col] - data["ae"]) / data["ae"]
            data.plot(x="formula", y=col + "_rel_err", ax=ax, style="o-", grid=True)

        labels = data['formula'].values
        ax.set_ylabel("relative error %% for %s" % struct_type)

        # Put the even-position formulas on the bottom axis and the odd-position ones on a
        # twin top axis so that neighbouring (rotated) labels do not overlap.
        ntick = len(data.index)
        ticks1 = range(0, ntick, 2)
        ticks2 = range(1, ntick, 2)
        labels1 = [labels[i] for i in ticks1]
        labels2 = [labels[i] for i in ticks2]

        ax.set_xticks(ticks1)
        ax.set_xticklabels(labels1, rotation=90)
        ax2 = ax.twiny()
        ax2.set_zorder(-1)
        ax2.set_xticks(ticks2)
        ax2.set_xticklabels(labels2, rotation=90)
        ax2.set_xlim(ax.get_xlim())

        return fig
示例#10
0
    def plot_hist(self, struct_type, ax=None, errtxt=True, **kwargs):
        """
        Histogram plot.

        For the rows with the given ``struct_type``, draw one seaborn ``distplot`` curve
        (``hist=False``, ``rug=True``) of the relative error in %, 100 * (code - ae) / ae,
        for each of the codes "this" and "gbrv_paw".

        Args:
            struct_type: Value of the "struct_type" column used to select rows.
            ax: matplotlib Axes or None if a new figure should be created.
            errtxt: If True, write the root-mean-square relative error (RMSRE) of each
                code as text on the axes.

        Returns:
            matplotlib figure.
        """
        ax, fig, _ = get_ax_fig_plt(ax)
        import seaborn as sns

        subset = self[self["struct_type"] == struct_type].copy()
        text_y = 0.8
        for code in ("this", "gbrv_paw"):
            rel_err = (100 * (subset[code] - subset["ae"]) / subset["ae"]).dropna()
            sns.distplot(rel_err, ax=ax, rug=True, hist=False, label=code)

            if not errtxt:
                continue
            rmsre = np.sqrt((rel_err**2).mean())
            ax.text(0.6, text_y, "%s RMSRE = %.2f" % (code, rmsre), transform=ax.transAxes)
            text_y -= 0.1

        ax.grid(True)
        ax.set_xlabel("relative error %")
        ax.set_xlim(-0.8, 0.8)

        return fig
示例#11
0
    def plot_eos(self, ax=None, accuracy="normal", **kwargs):
        """
        Plot the equation of state.

        The energies and volumes stored for ``accuracy`` are divided by the number of
        sites and fitted with ``EOS.Quadratic``; the fit is drawn on ``ax``.

        Args:
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            accuracy: Accuracy level whose data is plotted. If ``self.has_data(accuracy)``
                is false, the (empty) figure is returned.

        Returns:
            `matplotlib` figure.
        """
        ax, fig, plt = get_ax_fig_plt(ax)
        if not self.has_data(accuracy):
            return fig
        # Index with the `accuracy` argument (not the literal string "accuracy") so the
        # data matches the has_data() check above.
        d = self[accuracy]

        num_sites = d["num_sites"]
        volumes, etotals = np.array(d["volumes"]), np.array(d["etotals"])

        # Perform quadratic fit (per-site quantities).
        eos = EOS.Quadratic()
        eos_fit = eos.fit(volumes / num_sites, etotals / num_sites)

        label = "ecut %.1f" % d["ecut"]
        eos_fit.plot(ax=ax, text=False, label=label, show=False)
        return fig
示例#12
0
    def plot_stacked_hist(self, key="wall_time", nmax=5, ax=None, **kwargs):
        """
        Plot stacked histogram of the different timers.

        One bar position is used per timer of MPI rank "0". Each bar stacks the `key`
        values of the first `nmax` sections plus an `others` entry holding the sum of
        the remaining sections.

        Args:
            key: Keyword used to extract data from the timers. Only the first `nmax`
                sections with largest value are shown.
            nmax: Maximum number of sections to show. Other entries are grouped together
                in the `others` section.
            ax: matplotlib :class:`Axes` or None if a new figure should be created.

        Returns:
            `matplotlib` figure
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        mpi_rank = "0"
        timers = self.timers(mpi_rank=mpi_rank)
        n = len(timers)

        names, values = [], []
        rest = np.zeros(n)

        for idx, sname in enumerate(self.section_names(ordkey=key)):
            sections = self.get_sections(sname)
            svals = np.asarray([s.__dict__[key] for s in sections])
            if idx < nmax:
                names.append(sname)
                values.append(svals)
            else:
                rest += svals

        names.append("others (nmax=%d)" % nmax)
        values.append(rest)

        # The dataset is stored in values. Now create the stacked histogram.
        ind = np.arange(n)  # the locations for the groups
        width = 0.35        # the width of the bars
        # Cycle through a fixed palette. Replicating the palette nmax times leaves
        # no colors at all when nmax <= 0, which broke the `others` bar.
        palette = ['r', 'g', 'b', 'c', 'k', 'y', 'm']

        bars = []
        bottom = np.zeros(n)
        for idx, vals in enumerate(values):
            color = palette[idx % len(palette)]
            # align="center" is explicit so the ticks placed at `ind` below sit
            # under the bar centers.
            bar = ax.bar(ind, vals, width, color=color, bottom=bottom, align="center")
            bars.append(bar)
            bottom += vals

        ax.set_ylabel(key)
        ax.set_title("Stacked histogram with the %d most important sections" % nmax)

        labels = ["MPI=%d, OMP=%d" % (t.mpi_nprocs, t.omp_nthreads) for t in timers]
        ax.set_xticks(ind)
        ax.set_xticklabels(labels, rotation=15)

        # Add legend.
        ax.legend([bar[0] for bar in bars], names, loc="best")

        return fig
示例#13
0
    def plot_der_potentials(self, ax=None, order=1, **kwargs):
        """
        Plot the derivatives of vl and vloc potentials on axis ax.
        Used to analyze the derivative discontinuity introduced by the RRKJ method at rc.

        Each potential is splined onto a linear mesh four times denser than its own mesh
        before finite differencing. The entry with l == -1 is labelled Vloc.

        Args:
            ax: matplotlib Axes, or None to let ``get_ax_fig_plt`` create a new figure.
            order: Derivative order passed to ``finite_diff``.

        Returns:
            matplotlib figure.
        """
        ax, fig, plt = get_ax_fig_plt(ax)
        from abipy.tools.derivatives import finite_diff
        from scipy.interpolate import UnivariateSpline
        lines, legends = [], []
        for l, pot in self.potentials.items():
            # Need linear mesh for finite_difference --> Spline input potentials on lin_rmesh
            lin_rmesh, h = np.linspace(pot.rmesh[0], pot.rmesh[-1], num=len(pot.rmesh) * 4, retstep=True)
            spline = UnivariateSpline(pot.rmesh, pot.values, s=0)
            lin_values = spline(lin_rmesh)
            vder = finite_diff(lin_values, h, order=order, acc=4)
            line, = ax.plot(lin_rmesh, vder, **self._wf_pltopts(l, "ae"))
            lines.append(line)

            if l == -1:
                legends.append("%s-order derivative Vloc" % order)
            else:
                # "%s" (not "$s") so that the derivative order actually appears in the legend.
                legends.append("%s-order derivative PS l=%s" % (order, str(l)))

        # Raw string: "\p" is an invalid escape sequence in a plain literal; value unchanged.
        decorate_ax(ax, xlabel="r [Bohr]", ylabel=r"$D^%s \phi(r)$" % order,
                    title="Derivative of the ion Pseudopotentials",
                    lines=lines, legends=legends)
        return fig
示例#14
0
    def cpuwall_histogram(self, ax=None, **kwargs):
        """
        Plot histogram with cpu- and wall-time of each section on axis `ax`.

        Args:
            ax: matplotlib :class:`Axes` or None if a new figure should be created.

        Returns: `matplotlib` figure
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        nk = len(self.sections)
        ind = np.arange(nk)  # the x locations for the groups
        width = 0.35  # the width of the bars

        # Draw on `ax`: pyplot's current axes need not be the axes passed by the caller.
        cpu_times = self.get_values("cpu_time")
        rects1 = ax.bar(ind, cpu_times, width, color='r', align="center")

        wall_times = self.get_values("wall_time")
        rects2 = ax.bar(ind + width, wall_times, width, color='y', align="center")

        # Add ylabel
        ax.set_ylabel('Time (s)')

        # Center each section name under its pair of bars. Positions and labels are set
        # separately because before matplotlib 3.5 the second positional argument of
        # set_xticks was `minor`, not the labels.
        names = self.get_values("name")
        ax.set_xticks(ind + width / 2.0)
        ax.set_xticklabels(names)

        ax.legend((rects1[0], rects2[0]), ('CPU', 'Wall'), loc="best")

        return fig
示例#15
0
    def cpuwall_histogram(self, ax=None, **kwargs):
        """
        Plot histogram with cpu- and wall-time of each section on axis `ax`.

        Args:
            ax: matplotlib :class:`Axes` or None if a new figure should be created.

        Returns: `matplotlib` figure
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        nk = len(self.sections)
        ind = np.arange(nk)  # the x locations for the groups
        width = 0.35         # the width of the bars

        # Draw on `ax`: pyplot's current axes need not be the axes passed by the caller.
        cpu_times = self.get_values("cpu_time")
        rects1 = ax.bar(ind, cpu_times, width, color='r', align="center")

        wall_times = self.get_values("wall_time")
        rects2 = ax.bar(ind + width, wall_times, width, color='y', align="center")

        # Add ylabel
        ax.set_ylabel('Time (s)')

        # Center each section name under its pair of bars. Positions and labels are set
        # separately because before matplotlib 3.5 the second positional argument of
        # set_xticks was `minor`, not the labels.
        names = self.get_values("name")
        ax.set_xticks(ind + width / 2.0)
        ax.set_xticklabels(names)

        ax.legend((rects1[0], rects2[0]), ('CPU', 'Wall'), loc="best")

        return fig
示例#16
0
    def plot_den_formfact(self, ecut=60, ax=None, **kwargs):
        """
        Plot the density form factor as function of ecut (Ha units). Return matplotlib Figure.

        Every entry of ``self.densities`` except "rhoC" is transformed with
        ``get_intr2j0(ecut=ecut)``, divided by 4*pi and plotted. As a diagnostic, the
        last value of ``r2f_integral()`` and the first value of the form factor are
        printed for each plotted density.
        """
        ax, fig, _ = get_ax_fig_plt(ax)

        handles, labels = [], []
        for name, rho in self.densities.items():
            if name == "rhoC":
                continue
            formfact = rho.get_intr2j0(ecut=ecut) / (4 * np.pi)
            handle, = ax.plot(formfact.mesh, formfact.values,
                              linewidth=self.linewidth, markersize=self.markersize)
            handles.append(handle)
            labels.append(name)

            # Diagnostic output.
            print("r2 f integral: ", rho.r2f_integral()[-1])
            print("form_factor(0): ", name, formfact.values[0])

        decorate_ax(ax, xlabel="Ecut [Ha]", ylabel="$n(q)$", title="Form factor, l=0 ",
                    lines=handles, legends=labels)
        return fig
示例#17
0
    def cpuwall_histogram(self, ax=None, **kwargs):
        """
        Plot histogram with cpu- and wall-time of each section on axis `ax`.

        Args:
            ax: matplotlib :class:`Axes` or None if a new figure should be created.

        Returns: `matplotlib` figure
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        nk = len(self.sections)
        ind = np.arange(nk)  # the x locations for the groups
        width = 0.35  # the width of the bars

        # Draw on `ax`: pyplot's current axes need not be the axes passed by the caller.
        cpu_times = self.get_values("cpu_time")
        rects1 = ax.bar(ind, cpu_times, width, color='r', align="center")

        wall_times = self.get_values("wall_time")
        rects2 = ax.bar(ind + width, wall_times, width, color='y', align="center")

        # Add ylabel
        ax.set_ylabel('Time (s)')

        # Center each section name under its pair of bars. Positions and labels are set
        # separately because before matplotlib 3.5 the second positional argument of
        # set_xticks was `minor`, not the labels.
        names = self.get_values("name")
        ax.set_xticks(ind + width / 2.0)
        ax.set_xticklabels(names)

        ax.legend((rects1[0], rects2[0]), ('CPU', 'Wall'), loc="best")

        return fig
示例#18
0
    def plot_projectors(self, ax=None, **kwargs):
        """
        Plot oncvpsp projectors on axis ax.

        Args:
            ax: matplotlib Axes, or None to let ``get_ax_fig_plt`` create a new figure.
            kwargs:
                lselect: List of l channels to SKIP. Projectors whose l is in this list
                    are not plotted. Defaults to an empty list, i.e. plot everything.

        Returns:
            matplotlib figure.
        """
        ax, fig, plt = get_ax_fig_plt(ax)

        lselect = kwargs.get("lselect", [])

        # Line style encodes the projector index n. Any n other than 1 or 2 falls back
        # to "dotted" instead of raising KeyError.
        linestyle = {1: "solid", 2: "dashed"}
        lines, legends = [], []
        for nlk, proj in self.projectors.items():
            # Despite its name, lselect lists the channels to exclude.
            if nlk.l in lselect:
                continue
            line, = ax.plot(proj.rmesh,
                            proj.values,
                            color=self.color_l.get(nlk.l, 'black'),
                            linestyle=linestyle.get(nlk.n, "dotted"),
                            linewidth=self.linewidth,
                            markersize=self.markersize)
            lines.append(line)
            legends.append("Proj %s" % str(nlk))

        decorate_ax(ax,
                    xlabel="r [Bohr]",
                    ylabel="$p(r)$",
                    title="Projector Wave Functions",
                    lines=lines,
                    legends=legends)
        return fig
示例#19
0
    def plot_hist(self, struct_type, ax=None, errtxt=True, **kwargs):
        """
        Histogram plot.

        For the rows with the given ``struct_type``, draw one seaborn ``distplot`` curve
        (``hist=False``, ``rug=True``) of the relative error in %, 100 * (code - ae) / ae,
        for each of the codes "this" and "gbrv_paw".

        Args:
            struct_type: Value of the "struct_type" column used to select rows.
            ax: matplotlib Axes or None if a new figure should be created.
            errtxt: If True, write the root-mean-square relative error (RMSRE) of each
                code as text on the axes.

        Returns:
            matplotlib figure.
        """
        ax, fig, _ = get_ax_fig_plt(ax)
        import seaborn as sns

        subset = self[self["struct_type"] == struct_type].copy()
        text_y = 0.8
        for code in ("this", "gbrv_paw"):
            rel_err = (100 * (subset[code] - subset["ae"]) / subset["ae"]).dropna()
            sns.distplot(rel_err, ax=ax, rug=True, hist=False, label=code)

            if not errtxt:
                continue
            rmsre = np.sqrt((rel_err**2).mean())
            ax.text(0.6, text_y, "%s RMSRE = %.2f" % (code, rmsre), transform=ax.transAxes)
            text_y -= 0.1

        ax.grid(True)
        ax.set_xlabel("relative error %")
        ax.set_xlim(-0.8, 0.8)

        return fig
示例#20
0
    def plot_stacked_hist(self, key="wall_time", nmax=5, ax=None, **kwargs):
        """
        Stacked histogram of the different timers.

        One bar position is used per timer of MPI rank "0". Each bar stacks the `key`
        values of the first `nmax` sections plus an `others` entry holding the sum of
        the remaining sections.

        Args:
            key: Keyword used to extract data from the timers.
            nmax: Maximum number of sections shown individually.
            ax: matplotlib Axes or None if a new figure should be created.

        Returns:
            matplotlib figure.
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        mpi_rank = "0"
        timers = self.timers(mpi_rank=mpi_rank)
        n = len(timers)

        names, values = [], []
        rest = np.zeros(n)

        for idx, sname in enumerate(self.section_names(ordkey=key)):
            sections = self.get_sections(sname)
            svals = np.asarray([s.__dict__[key] for s in sections])

            if idx < nmax:
                names.append(sname)
                values.append(svals)
            else:
                rest += svals

        names.append("others (nmax = %d)" % nmax)
        values.append(rest)

        # The dataset is stored in values. Now create the stacked histogram.
        ind = np.arange(n)  # the locations for the groups
        width = 0.35  # the width of the bars

        # Cycle through a fixed palette. Replicating it nmax times leaves no colors
        # at all when nmax <= 0, which broke the `others` bar.
        palette = ['r', 'g', 'b', 'c', 'k', 'y', 'm']

        bars = []
        bottom = np.zeros(n)
        for idx, vals in enumerate(values):
            # Draw on `ax` rather than pyplot's current axes, which need not be `ax`.
            bar = ax.bar(ind, vals, width, color=palette[idx % len(palette)],
                         bottom=bottom, align="center")
            bars.append(bar)
            bottom += vals

        ax.set_ylabel(key)

        labels = [
            "MPI = %d, OMP = %d" % (t.mpi_nprocs, t.omp_nthreads)
            for t in timers
        ]
        # Bars are centered on `ind`, so the labels go there as well.
        ax.set_xticks(ind)
        ax.set_xticklabels(labels, rotation=15)

        ax.legend([bar[0] for bar in bars], names, loc="best")

        return fig
示例#21
0
    def scatter_hist(self, ax=None, **kwargs):
        """
        Scatter plot + histogram.

        Scatter plot of wall-time vs cpu-time, with a histogram of the cpu-times on top
        and of the wall-times on the right.

        Args:
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            kwargs:
                binwidth: Width of the histogram bins (default 0.25, in the units of the
                    values returned by ``get_values``).

        Returns: `matplotlib` figure
        """
        from mpl_toolkits.axes_grid1 import make_axes_locatable

        ax, fig, plt = get_ax_fig_plt(ax=ax)

        x = np.asarray(self.get_values("cpu_time"))
        y = np.asarray(self.get_values("wall_time"))

        # The scatter plot goes on the axes we were given; plt.subplot(1, 1, 1) would
        # target pyplot's current figure instead of `ax`.
        axScatter = ax
        axScatter.scatter(x, y)
        axScatter.set_aspect("auto")

        # Nothing to histogram (np.max below would raise on empty arrays).
        if x.size == 0 or y.size == 0:
            return fig

        # create new axes on the right and on the top of the current axes.
        # The second argument of append_axes is the size of the new axes in inches.
        divider = make_axes_locatable(axScatter)
        axHistx = divider.append_axes("top", 1.2, pad=0.1, sharex=axScatter)
        axHisty = divider.append_axes("right", 1.2, pad=0.1, sharey=axScatter)

        # Hide the tick labels that duplicate those of the shared scatter axes.
        # tick_params also applies to ticks created later, unlike per-label set_visible.
        axHistx.tick_params(axis="x", labelbottom=False)
        axHisty.tick_params(axis="y", labelleft=False)

        # now determine nice limits by hand:
        binwidth = kwargs.get("binwidth", 0.25)
        xymax = np.max([np.max(np.fabs(x)), np.max(np.fabs(y))])
        lim = (int(xymax / binwidth) + 1) * binwidth

        bins = np.arange(-lim, lim + binwidth, binwidth)
        axHistx.hist(x, bins=bins)
        axHisty.hist(y, bins=bins, orientation="horizontal")

        # The xaxis of axHistx and yaxis of axHisty are shared with axScatter,
        # thus there is no need to manually adjust the xlim and ylim of these axis.
        # The count axes keep matplotlib's automatic ticks. The fixed [0, 50, 100] ticks
        # copied from the matplotlib demo stretched them far beyond typical section counts.
        return fig
示例#22
0
    def plot(self, ax=None, **kwargs):
        """
        Plot the histogram with matplotlib, returns `matplotlib` figure.

        The y values are the lengths of the entries of ``self.values``, plotted against
        ``self.binvals``. ``kwargs`` are passed to ``ax.plot``.
        """
        ax, fig, _ = get_ax_fig_plt(ax)
        counts = list(map(len, self.values))
        ax.plot(self.binvals, counts, **kwargs)
        return fig
示例#23
0
文件: utils.py 项目: ExpHP/pymatgen
    def plot(self, ax=None, **kwargs):
        """
        Plot the histogram with matplotlib, returns `matplotlib` figure.

        The y values are the lengths of the entries of ``self.values``, plotted against
        ``self.binvals``. ``kwargs`` are passed to ``ax.plot``.
        """
        ax, fig, _ = get_ax_fig_plt(ax)
        counts = list(map(len, self.values))
        ax.plot(self.binvals, counts, **kwargs)
        return fig
示例#24
0
    def plot_ax(self, ax=None, fontsize=12, **kwargs):
        """
        Plot the equation of state on axis `ax`.

        The input (volume, energy) points are drawn as markers. The fitted curve
        ``self.func`` is drawn as a dashed line over the volume range widened by 1% on
        each side. A text box with the fit parameters is placed in the middle of the axes.

        Args:
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            fontsize: Font size of the fit-parameter text box.
            color (str): plot color (default "r").
            label (str): Legend label of the fitted curve.
            text (str): Text shown in the box; defaults to a summary of the fit parameters.

        Returns:
            Matplotlib figure object.
        """
        # pylint: disable=E1307
        ax, fig, _ = get_ax_fig_plt(ax=ax)

        eos_name = self.__class__.__name__
        color = kwargs.get("color", "r")
        label = kwargs.get("label", f"{eos_name} fit")
        summary = "\n".join([
            "Equation of State: %s" % eos_name,
            "Minimum energy = %1.2f eV" % self.e0,
            "Minimum or reference volume = %1.2f Ang^3" % self.v0,
            f"Bulk modulus = {self.b0:1.2f} eV/Ang^3 = {self.b0_GPa:1.2f} GPa",
            "Derivative of bulk modulus wrt pressure = %1.2f" % self.b1,
        ])
        text = kwargs.get("text", summary)

        # Input data as markers only.
        ax.plot(self.volumes, self.energies, linestyle="None", marker="o", color=color)

        # Fitted curve over the data range widened by 1% on each side.
        vlo, vhi = min(self.volumes), max(self.volumes)
        vfit = np.linspace(vlo - 0.01 * abs(vlo), vhi + 0.01 * abs(vhi), 100)
        ax.plot(vfit, self.func(vfit), linestyle="dashed", color=color, label=label)

        ax.grid(True)
        ax.set_xlabel("Volume $\\AA^3$")
        ax.set_ylabel("Energy (eV)")
        ax.legend(loc="best", shadow=True)

        # Fit-parameter text box, centered in axes coordinates.
        ax.text(0.5, 0.5, text,
                fontsize=fontsize,
                horizontalalignment="center",
                verticalalignment="center",
                transform=ax.transAxes)

        return fig
示例#25
0
    def plot_errors_for_elements(self, ax=None, **kwargs):
        """
        Plot the relative errors associated to the chemical elements.

        For every row, the relative error 100 * (this - ae) / ae (in %) is attributed
        to each distinct element of the row's formula. The errors are shown per element
        (ordered by Z) as a box plot, with the individual points overlaid and colored
        by ``struct_type``.

        Args:
            ax: matplotlib Axes or None if a new figure should be created.

        Returns:
            matplotlib figure.
        """
        dict_list = []
        for _, row in self.iterrows():
            rerr = 100 * (row["this"] - row["ae"]) / row["ae"]
            for symbol in set(species_from_formula(row.formula)):
                dict_list.append(dict(
                    element=symbol,
                    rerr=rerr,
                    formula=row.formula,
                    struct_type=row.struct_type,
                ))

        frame = DataFrame(dict_list)
        order = sort_symbols_by_Z(set(frame["element"]))

        import seaborn as sns
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        # Box plot
        ax = sns.boxplot(x="element", y="rerr", data=frame, ax=ax, order=order, whis=np.inf, color="c")
        # Add in points to show each observation
        sns.stripplot(x="element", y="rerr", data=frame, ax=ax, order=order, hue='struct_type',
                      jitter=0, size=4, color=".3", linewidth=0, palette=sns.color_palette("muted"))

        sns.despine(left=True)
        ax.set_ylabel("Relative error %")

        # Alternate element labels between the bottom axis (even positions) and a twin
        # top axis (odd positions) so that neighbouring labels do not overlap.
        labels = ax.get_xticklabels()
        # get_xticks() may return a float array, which range() rejects; the categorical
        # positions are integral, so cast them.
        ticks = [int(round(t)) for t in ax.get_xticks()]
        ticks1 = range(min(ticks), max(ticks) + 1, 2)
        ticks2 = range(min(ticks) + 1, max(ticks) + 1, 2)
        labels1 = [labels[i].get_text() for i in ticks1]
        labels2 = [labels[i].get_text() for i in ticks2]

        ax.set_xticks(ticks1)
        ax.set_xticklabels(labels1, rotation=90)
        ax2 = ax.twiny()
        ax2.set_zorder(-1)
        ax2.set_xticks(ticks2)
        ax2.set_xticklabels(labels2, rotation=90)
        ax2.set_xlim(ax.get_xlim())

        ax.grid(True)
        return fig
示例#26
0
    def plot_hints(self, with_soc=False, ax=None, **kwargs):
        """
        Box plot (with individual points) of the "normal" accuracy ecut hints per symbol.

        Pseudopotentials without a dojo report are skipped with a message. The collected
        table is printed before plotting. It contains basename, symbol, Z, Z_val, l_max,
        the last deltafactor and GBRV (fcc and bcc) results, and the ecut/pawecutdg hints.

        Args:
            with_soc: Forwarded to the report getters to select results with/without SOC.
            ax: matplotlib Axes or None if a new figure should be created.

        Returns:
            matplotlib figure.
        """
        # Build pandas dataframe with results.
        rows = []
        for p in self:
            if not p.has_dojo_report:
                cprint("Cannot find dojo_report in %s" % p.basename, "magenta")
                continue
            report = p.dojo_report
            row = {att: getattr(p, att) for att in ("basename", "symbol", "Z", "Z_val", "l_max")}

            # Get deltafactor data with/without SOC
            df_dict = report.get_last_df_results(with_soc=with_soc)
            row.update(df_dict)
            # Merge the results of every structure type; updating after the loop kept
            # only the last one.
            for struct_type in ["fcc", "bcc"]:
                gbrv_dict = report.get_last_gbrv_results(struct_type, with_soc=with_soc)
                row.update(gbrv_dict)

            # Get the hints
            hint = p.hint_for_accuracy(accuracy="normal")
            row.update(dict(ecut=hint.ecut, pawecutdg=hint.pawecutdg))

            rows.append(row)

        import pandas as pd
        frame = pd.DataFrame(rows)

        # Print the whole table (all rows and columns).
        with pd.option_context('display.max_rows', len(frame),
                               'display.max_columns', len(list(frame.keys()))):
            print(frame)

        import seaborn as sns
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        # Box plot
        ax = sns.boxplot(x="symbol", y="ecut", data=frame, ax=ax,
                         whis=np.inf, color="c")
        # Add in points to show each observation
        sns.stripplot(x="symbol", y="ecut", data=frame, ax=ax,
                      jitter=True, size=5, color=".3", linewidth=0)

        sns.despine(left=True)
        # The y data are the ecut hints, not relative errors.
        ax.set_ylabel("ecut")
        ax.grid(True)

        return fig
示例#27
0
def plot_xy_with_hue(data, x, y, hue, decimals=None, ax=None,
                     xlims=None, ylims=None, fontsize=12, **kwargs):
    """
    Draw one y-versus-x curve for each distinct value found in the `hue` column.

    Handy for convergence studies performed with respect to two parameters.

    Args:
        data: |pandas-DataFrame| that must contain the `x`, `y` and `hue` columns.
        x: Name of the column used for the abscissa.
        y: Name of the column used for the ordinate.
        hue: Column whose distinct values split the data into separate lines.
        decimals: If not None, the `hue` column is rounded to this many decimal
            places before grouping.
        ax: |matplotlib-Axes| or None if a new figure should be created.
        xlims ylims: Data limits for the x (y) axis. Either a tuple `(left, right)`
            or a scalar `left`. A None entry keeps the default value.
        fontsize: Legend fontsize.
        kwargs: Keyword arguments forwarded to `ax.plot`. When none are given,
            the `'o-'` line style is used.

    Returns: |matplotlib-Figure|

    Raises:
        ValueError: if `x`, `y` or `hue` is not a column of `data`.
    """
    # Validate up front: the error messages raised by pandas are rather cryptic.
    missing = [name for name in (x, y, hue) if name not in data]
    if missing:
        raise ValueError("Cannot find `%s` in dataframe.\nAvailable keys are: %s" % (str(missing), str(data.keys())))

    # Round the hue values so that nearly-equal floats end up in the same group.
    if decimals is not None:
        data = data.round({hue: decimals})

    ax, fig, _ = get_ax_fig_plt(ax=ax)
    for hue_value, group in data.groupby(hue):
        # Order the points by x, carrying the matching y values along.
        points = np.array(sorted(zip(group[x], group[y]), key=lambda pair: pair[0]))
        abscissa, ordinate = points[:, 0], points[:, 1]
        curve_label = "{} = {}".format(hue, hue_value)
        if kwargs:
            ax.plot(abscissa, ordinate, label=curve_label, **kwargs)
        else:
            ax.plot(abscissa, ordinate, 'o-', label=curve_label)

    ax.grid(True)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    set_axlims(ax, xlims, "x")
    set_axlims(ax, ylims, "y")
    ax.legend(loc="best", fontsize=fontsize, shadow=True)

    return fig
示例#28
0
    def plot_ene_vs_ecut(self, ax=None, **kwargs):
        """
        Plot the convergence of the energy with respect to ecut on axis ax.

        One curve is drawn for each `l` in `self.ene_vs_ecut`. The y-axis uses a
        logarithmic scale.

        Args:
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            kwargs: Currently unused.

        Returns:
            `matplotlib` figure
        """
        ax, fig, _ = get_ax_fig_plt(ax)
        lines, legends = [], []
        for l, data in self.ene_vs_ecut.items():
            line, = ax.plot(data.energies, data.values, **self._wf_pltopts(l, "ae"))
            lines.append(line)
            legends.append("Conv l=%s" % str(l))

        # Raw string: "\D" is an invalid escape sequence in a plain literal
        # (DeprecationWarning, SyntaxWarning since Python 3.12). Same runtime value.
        decorate_ax(ax, xlabel="Ecut [Ha]", ylabel=r"$\Delta E$", title="Energy error per electron [Ha]",
                    lines=lines, legends=legends)

        ax.set_yscale("log")
        return fig
示例#29
0
    def _plot_thermo(self,
                     func,
                     temperatures,
                     factor=1,
                     ax=None,
                     ylabel=None,
                     label=None,
                     ylim=None,
                     **kwargs):
        """
        Plots a thermodynamic property for a generic function from a PhononDos instance.

        Args:
            func: the thermodynamic function to be used to calculate the property.
                It is called as `func(t, structure=self.structure)`.
            temperatures: a list of temperatures
            factor: a multiplicative factor applied to the thermodynamic property calculated. Used to change
                the units.
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            ylabel: label for the y axis
            label: label of the plot
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        ax, fig, _ = get_ax_fig_plt(ax)

        values = [func(t, structure=self.structure) * factor for t in temperatures]

        ax.plot(temperatures, values, label=label, **kwargs)

        if ylim:
            ax.set_ylim(ylim)

        ax.set_xlim((np.min(temperatures), np.max(temperatures)))

        # Draw a y=0 reference line if zero lies inside the visible range.
        # Query and draw on `ax` itself: plt.ylim()/plt.plot() operate on the
        # *current* pyplot axes, which is not necessarily `ax`.
        ymin, ymax = ax.get_ylim()
        if ymin < 0 < ymax:
            ax.plot(ax.get_xlim(), [0, 0], "k-", linewidth=1)

        ax.set_xlabel(r"$T$ (K)")
        if ylabel:
            ax.set_ylabel(ylabel)

        return fig
示例#30
0
    def scatter_hist(self, ax=None, **kwargs):
        """
        Scatter plot of wall_time vs cpu_time with marginal histograms.

        Args:
            ax: matplotlib :class:`Axes` used for the scatter plot, or None if a
                new figure should be created. The histograms are appended on
                the top and on the right of this axes.
            kwargs: Currently unused.

        Returns:
            `matplotlib` figure
        """
        from mpl_toolkits.axes_grid1 import make_axes_locatable
        ax, fig, _ = get_ax_fig_plt(ax=ax)

        x = np.asarray(self.get_values("cpu_time"))
        y = np.asarray(self.get_values("wall_time"))

        # The scatter plot goes on the axes we were given (plt.subplot(1, 1, 1)
        # would ignore `ax` and draw on the current figure instead).
        axScatter = ax
        axScatter.scatter(x, y)
        axScatter.set_aspect("auto")

        # Create new axes on the top and on the right of the scatter axes.
        # The second argument of append_axes is the size of the new axes.
        divider = make_axes_locatable(axScatter)
        axHistx = divider.append_axes("top", 1.2, pad=0.1, sharex=axScatter)
        axHisty = divider.append_axes("right", 1.2, pad=0.1, sharey=axScatter)

        # The histograms share an axis with the scatter plot: hide their
        # redundant tick labels (tick_params settings persist across redraws).
        axHistx.tick_params(labelbottom=False)
        axHisty.tick_params(labelleft=False)

        # Determine nice, symmetric bin limits by hand.
        binwidth = 0.25
        xymax = np.max([np.max(np.fabs(x)), np.max(np.fabs(y))])
        lim = (int(xymax / binwidth) + 1) * binwidth

        bins = np.arange(-lim, lim + binwidth, binwidth)
        axHistx.hist(x, bins=bins)
        axHisty.hist(y, bins=bins, orientation='horizontal')

        # The xaxis of axHistx and yaxis of axHisty are shared with axScatter,
        # thus there is no need to manually adjust the xlim and ylim of these axis.
        # Fixed ticks on the count axes; set once (these calls were previously
        # nested inside the tick-label loops by mistake).
        axHistx.set_yticks([0, 50, 100])
        axHisty.set_xticks([0, 50, 100])

        return fig
示例#31
0
    def plot_potentials(self, ax=None, **kwargs):
        """
        Draw the pseudopotentials stored in `self.potentials` on axis ax.

        The entry with l == -1 is labelled "Vloc"; every other entry is
        labelled "PS l=<l>".

        Args:
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            kwargs: Currently unused.

        Returns:
            `matplotlib` figure
        """
        ax, fig, _ = get_ax_fig_plt(ax)

        handles = []
        labels = []
        for ang_mom, potential in self.potentials.items():
            (curve,) = ax.plot(potential.rmesh, potential.values, **self._wf_pltopts(ang_mom, "ae"))
            handles.append(curve)
            labels.append("Vloc" if ang_mom == -1 else "PS l=%s" % str(ang_mom))

        decorate_ax(ax, xlabel="r [Bohr]", ylabel="$v_l(r)$", title="Ion Pseudopotentials",
                    lines=handles, legends=labels)
        return fig
示例#32
0
文件: eos.py 项目: ExpHP/pymatgen
    def plot_ax(self, ax=None, fontsize=12, **kwargs):
        """
        Plot the equation of state on axis `ax`

        Args:
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            fontsize: Fontsize of the text with the fit parameters.
            color (str): plot color. Default: "r".
            label (str): Plot label of the fitted curve.
            text (str): Text drawn at the center of the axes. Defaults to a
                summary of the fit parameters. Pass an empty string, None or
                False to omit it.

        Returns:
            Matplotlib figure object.
        """
        ax, fig, _ = get_ax_fig_plt(ax=ax)

        color = kwargs.get("color", "r")
        label = kwargs.get("label", "{} fit".format(self.__class__.__name__))
        lines = ["Equation of State: %s" % self.__class__.__name__,
                 "Minimum energy = %1.2f eV" % self.e0,
                 "Minimum or reference volume = %1.2f Ang^3" % self.v0,
                 "Bulk modulus = %1.2f eV/Ang^3 = %1.2f GPa" %
                 (self.b0, self.b0_GPa),
                 "Derivative of bulk modulus wrt pressure = %1.2f" % self.b1]
        text = kwargs.get("text", "\n".join(lines))

        # Plot input data.
        ax.plot(self.volumes, self.energies, linestyle="None", marker="o", color=color)

        # Plot eos fit on a volume range extended by 1% on each side.
        vmin, vmax = min(self.volumes), max(self.volumes)
        vmin, vmax = (vmin - 0.01 * abs(vmin), vmax + 0.01 * abs(vmax))
        vfit = np.linspace(vmin, vmax, 100)

        ax.plot(vfit, self.func(vfit), linestyle="dashed", color=color, label=label)

        ax.grid(True)
        ax.set_xlabel("Volume $\\AA^3$")
        ax.set_ylabel("Energy (eV)")
        ax.legend(loc="best", shadow=True)

        # Add text with fit parameters. A falsy value (e.g. text=False) disables
        # it instead of drawing the literal string "False".
        if text:
            ax.text(0.5, 0.5, text, fontsize=fontsize, horizontalalignment='center',
                    verticalalignment='center', transform=ax.transAxes)

        return fig
示例#33
0
    def pie(self, key="wall_time", minfract=0.05, ax=None, **kwargs):
        """
        Draw a pie chart with the relative weight of the sections of this timer.

        Args:
            key: Keyword used to extract data from the timer.
            minfract: Sections whose relative weight is below minfract are not shown.
            ax: matplotlib :class:`Axes` or None if a new figure should be created.

        Returns: `matplotlib` figure
        """
        ax, fig, _ = get_ax_fig_plt(ax=ax)
        # Equal aspect ratio: the pie is drawn as a circle rather than an ellipse.
        ax.axis("equal")
        # names_and_values applies the minfract filter.
        section_names, section_values = self.names_and_values(key, minfract=minfract)
        ax.pie(section_values, explode=None, labels=section_names,
               autopct="%1.1f%%", shadow=True)
        return fig
示例#34
0
    def plot_pie(self, key="wall_time", minfract=0.05, ax=None, **kwargs):
        """
        Pie charts of the different timers, one per row of a new figure.

        Args:
            key: Keyword used to extract data from the timers.
            minfract: Don't show sections whose relative weight is less than minfract.
            ax: Ignored, kept for backward compatibility. A single Axes cannot
                hold several pies, so a new figure with one Axes per timer is created.
            kwargs: Currently unused.

        Returns:
            `matplotlib` figure
        """
        import matplotlib.pyplot as plt

        timers = self.timers()

        # A fresh figure (instead of the global figure number 1) with one Axes
        # per timer. squeeze=False keeps `axes` 2D even when there is one timer.
        fig, axes = plt.subplots(len(timers), 1, figsize=(6, 6), squeeze=False)

        for timer, timer_ax in zip(timers, axes[:, 0]):
            timer_ax.set_title(str(timer))
            # Pass the Axes explicitly so that the pie is drawn in this figure.
            timer.pie(key=key, minfract=minfract, ax=timer_ax)

        return fig
示例#35
0
    def plot(self, ax=None, **kwargs):
        """
        Plot `self.total` as a function of the step index.

        Args:
            ax (Axes): a matplotlib Axes or None if a new figure should be created.
            kwargs: forwarded to the matplotlib plot function.

        Returns:
            A matplotlib Figure
        """
        ax, fig, _ = get_ax_fig_plt(ax=ax)

        steps = range(self.n_steps)
        ax.plot(steps, self.total, **kwargs)
        ax.set(xlabel="Steps", ylabel="Energy")

        return fig
示例#36
0
    def pie(self, key="wall_time", minfract=0.05, ax=None, **kwargs):
        """
        Pie chart of the sections of this timer.

        Args:
            key: Keyword used to extract data from the timer.
            minfract: Sections whose relative weight is less than minfract are hidden.
            ax: matplotlib :class:`Axes` or None if a new figure should be created.

        Returns:
            `matplotlib` figure
        """
        ax, fig, _ = get_ax_fig_plt(ax=ax)
        ax.axis("equal")  # keep the pie circular
        wedge_labels, wedge_sizes = self.names_and_values(key, minfract=minfract)
        ax.pie(
            wedge_sizes,
            explode=None,
            labels=wedge_labels,
            autopct='%1.1f%%',
            shadow=True,
        )
        return fig
示例#37
0
    def _plot_thermo(self, func, temperatures, factor=1, ax=None, ylabel=None, label=None, ylim=None, **kwargs):
        """
        Plots a thermodynamic property for a generic function from a PhononDos instance.

        Args:
            func: the thermodynamic function to be used to calculate the property.
                It is called as `func(t, structure=self.structure)`.
            temperatures: a list of temperatures
            factor: a multiplicative factor applied to the thermodynamic property calculated. Used to change
                the units.
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            ylabel: label for the y axis
            label: label of the plot
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        ax, fig, _ = get_ax_fig_plt(ax)

        values = [func(t, structure=self.structure) * factor for t in temperatures]

        ax.plot(temperatures, values, label=label, **kwargs)

        if ylim:
            ax.set_ylim(ylim)

        ax.set_xlim((np.min(temperatures), np.max(temperatures)))

        # Draw a y=0 reference line if zero lies inside the visible range.
        # Use `ax` explicitly: plt.ylim()/plt.plot() act on the current pyplot
        # axes, which differs from `ax` when the caller passes its own axes.
        ymin, ymax = ax.get_ylim()
        if ymin < 0 < ymax:
            ax.plot(ax.get_xlim(), [0, 0], 'k-', linewidth=1)

        ax.set_xlabel(r"$T$ (K)")
        if ylabel:
            ax.set_ylabel(ylabel)

        return fig
示例#38
0
    def plot_errors_for_structure(self, struct_type, ax=None, **kwargs):
        """
        Plot the errors for a given crystalline structure.

        For each formula, the relative error (in %) of the "this" and "gbrv_paw"
        columns with respect to the "ae" column is plotted. Tick labels alternate
        between the bottom axis and a twin top axis so that they don't overlap.

        Args:
            struct_type: Value of the "struct_type" column to select.
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            kwargs: Currently unused.

        Returns:
            `matplotlib` figure, or None if there are no results for struct_type.
        """
        data = self[self["struct_type"] == struct_type].copy()
        if data.empty:
            # Check before creating the figure so that no empty figure is left behind.
            print("No results available for struct_type:", struct_type)
            return None

        ax, fig, _ = get_ax_fig_plt(ax=ax)

        for col in ("this", "gbrv_paw"):
            data[col + "_rel_err"] = 100 * (data[col] - data["ae"]) / data["ae"]
            data.plot(x="formula", y=col + "_rel_err", ax=ax, style="o-", grid=True)

        ax.set_ylabel("relative error %% for %s" % struct_type)

        # Even positions get their label on the bottom axis, odd positions on a
        # twin top axis.
        labels = data["formula"].values
        ticks1 = range(0, len(labels), 2)
        ticks2 = range(1, len(labels), 2)
        ax.set_xticks(ticks1)
        ax.set_xticklabels([labels[i] for i in ticks1], rotation=90)
        ax2 = ax.twiny()
        ax2.set_zorder(-1)
        ax2.set_xticks(ticks2)
        ax2.set_xticklabels([labels[i] for i in ticks2], rotation=90)
        ax2.set_xlim(ax.get_xlim())

        return fig
示例#39
0
    def plot_eos(self, ax=None, accuracy="normal", **kwargs):
        """
        Plot the equation of state computed with the given accuracy.

        A quadratic EOS is fitted to the volumes and total energies per site.

        Args:
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            accuracy: Key selecting the dataset to plot (default: "normal").
            kwargs: Currently unused.

        Returns:
            `matplotlib` figure. Nothing is drawn if there is no data for `accuracy`.
        """
        ax, fig, _ = get_ax_fig_plt(ax)
        if not self.has_data(accuracy):
            return fig

        # Select the entry of the requested accuracy (the literal key "accuracy"
        # was used before, ignoring the argument).
        d = self[accuracy]

        num_sites = d["num_sites"]
        volumes, etotals = np.array(d["volumes"]), np.array(d["etotals"])

        # Perform quadratic fit on per-site quantities.
        eos_fit = EOS.Quadratic().fit(volumes / num_sites, etotals / num_sites)

        label = "ecut %.1f" % d["ecut"]
        eos_fit.plot(ax=ax, text=False, label=label, show=False)
        return fig
示例#40
0
    def plot_errors_for_elements(self, ax=None, **kwargs):
        """
        Plot the relative errors associated to the chemical elements.

        For each row, the relative error (in %) of "this" with respect to "ae" is
        attributed to every distinct element of the formula. The errors are shown
        as one box plot per element, in the order given by sort_symbols_by_Z. The
        individual observations are overlaid and colored by "struct_type".

        Args:
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            kwargs: Currently unused.

        Returns:
            `matplotlib` figure, or None if there are no results.
        """
        dict_list = []
        for _, row in self.iterrows():
            rerr = 100 * (row["this"] - row["ae"]) / row["ae"]
            for symbol in set(species_from_formula(row.formula)):
                dict_list.append(dict(element=symbol, rerr=rerr,
                                      formula=row.formula, struct_type=row.struct_type))

        if not dict_list:
            # An empty frame has no "element" column: nothing to plot.
            print("No results available: nothing to plot.")
            return None

        frame = DataFrame(dict_list)
        order = list(sort_symbols_by_Z(set(frame["element"])))

        import seaborn as sns
        ax, fig, _ = get_ax_fig_plt(ax=ax)

        # Box plot; whis=np.inf extends the whiskers over the full data range.
        ax = sns.boxplot(x="element", y="rerr", data=frame, ax=ax, order=order,
                         whis=np.inf, color="c")
        # Add in points to show each observation, colored by structure type.
        sns.stripplot(x="element", y="rerr", data=frame, ax=ax, order=order,
                      hue='struct_type', jitter=0, size=4, color=".3", linewidth=0,
                      palette=sns.color_palette("muted"))

        sns.despine(left=True)
        ax.set_ylabel("Relative error %")

        # Seaborn draws the categories of `order` at positions 0..len(order)-1.
        # Build the alternating labels from `order` directly rather than from
        # ax.get_xticklabels(), whose texts may still be empty before the figure
        # is drawn.
        ticks1 = range(0, len(order), 2)
        ticks2 = range(1, len(order), 2)
        ax.set_xticks(ticks1)
        ax.set_xticklabels([order[i] for i in ticks1], rotation=90)
        ax2 = ax.twiny()
        ax2.set_zorder(-1)
        ax2.set_xticks(ticks2)
        ax2.set_xticklabels([order[i] for i in ticks2], rotation=90)
        ax2.set_xlim(ax.get_xlim())

        ax.grid(True)
        return fig
示例#41
0
    def plot_efficiency(self, key="wall_time", what="good+bad", nmax=5, ax=None, **kwargs):
        """
        Plot the parallel efficiency of the timers versus the MPI/OMP configuration.

        Args:
            key: Timer quantity used to compute the parallel efficiency.
            what: Selects the sections to plot: `good` for sections with good parallel
                efficiency, `bad` for sections with bad efficiency. Options can be
                concatenated with `+`. The "total" curve is always shown.
            nmax: Maximum number of entries in plot
            ax: matplotlib :class:`Axes` or None if a new figure should be created.

        ================  ====================================================
        kwargs            Meaning
        ================  ====================================================
        linewidth         matplotlib linewidth. Default: 2.0
        markersize        matplotlib markersize. Default: 10
        ================  ====================================================

        Returns:
            `matplotlib` figure
        """
        ax, fig, _ = get_ax_fig_plt(ax=ax)
        linewidth = kwargs.pop("linewidth", 2.0)
        markersize = kwargs.pop("markersize", 10)
        selected = what.split("+")

        timers = self.timers()
        peff = self.pefficiency()
        positions = np.arange(len(timers))

        ax.set_prop_cycle(color=['g', 'b', 'c', 'm', 'y', 'k'])

        # (option in `what`, name of the peff method listing the sections, line format)
        categories = (("good", "good_sections", "-->"),
                      ("bad", "bad_sections", "-.<"))

        handles, legend_entries = [], []
        for option, method_name, fmt in categories:
            if option not in selected:
                continue
            for section in getattr(peff, method_name)(key=key, nmax=nmax):
                handle, = ax.plot(positions, peff[section][key], fmt,
                                  linewidth=linewidth, markersize=markersize)
                handles.append(handle)
                legend_entries.append(section)

        # Always show the overall efficiency, unless it was already plotted.
        if "total" not in legend_entries:
            total_handle, = ax.plot(positions, peff["total"][key], "r",
                                    linewidth=linewidth, markersize=markersize)
            handles.append(total_handle)
            legend_entries.append("total")

        ax.legend(handles, legend_entries, loc="best", shadow=True)

        ax.set_xlabel('Total_NCPUs')
        ax.set_ylabel('Efficiency')
        ax.grid(True)

        # One tick per timer, labelled with its MPI/OMP configuration.
        tick_labels = ["MPI=%d, OMP=%d" % (timer.mpi_nprocs, timer.omp_nthreads) for timer in timers]
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, fontdict=None, minor=False, rotation=15)

        return fig
示例#42
0
    def plot_efficiency(self,
                        key="wall_time",
                        what="good+bad",
                        nmax=5,
                        ax=None,
                        **kwargs):
        """
        Plot the parallel efficiency for each MPI/OMP configuration of the timers.

        Args:
            key: Timer quantity on which the parallel efficiency is based.
            what: `good` selects the sections with good parallel efficiency,
                `bad` those with bad efficiency. Join options with `+`.
                The "total" efficiency is always added.
            nmax: Maximum number of entries in plot
            ax: matplotlib :class:`Axes` or None if a new figure should be created.

        ================  ====================================================
        kwargs            Meaning
        ================  ====================================================
        linewidth         matplotlib linewidth. Default: 2.0
        markersize        matplotlib markersize. Default: 10
        ================  ====================================================

        Returns:
            `matplotlib` figure
        """
        ax, fig, _ = get_ax_fig_plt(ax=ax)
        lw = kwargs.pop("linewidth", 2.0)
        msize = kwargs.pop("markersize", 10)
        options = what.split("+")

        timers = self.timers()
        peff = self.pefficiency()
        xx = np.arange(len(timers))

        ax.set_prop_cycle(color=['g', 'b', 'c', 'm', 'y', 'k'])

        lines = []
        legend_entries = []

        def add_curves(section_names, fmt):
            """Plot peff[name][key] for each section and record it for the legend."""
            for name in section_names:
                curve, = ax.plot(xx, peff[name][key], fmt, linewidth=lw, markersize=msize)
                lines.append(curve)
                legend_entries.append(name)

        if "good" in options:
            add_curves(peff.good_sections(key=key, nmax=nmax), "-->")
        if "bad" in options:
            add_curves(peff.bad_sections(key=key, nmax=nmax), "-.<")

        # Add the total curve if it is not among the sections already drawn.
        if "total" not in legend_entries:
            curve, = ax.plot(xx, peff["total"][key], "r", linewidth=lw, markersize=msize)
            lines.append(curve)
            legend_entries.append("total")

        ax.legend(lines, legend_entries, loc="best", shadow=True)

        ax.set_xlabel('Total_NCPUs')
        ax.set_ylabel('Efficiency')
        ax.grid(True)

        ax.set_xticks(xx)
        ax.set_xticklabels(
            ["MPI=%d, OMP=%d" % (t.mpi_nprocs, t.omp_nthreads) for t in timers],
            fontdict=None, minor=False, rotation=15)

        return fig
示例#43
0
    def plot(self, ax=None, **kwargs):
        """
        Uses Matplotlib to plot the energy curve.

        The input (volume, energy) points are plotted as markers. The fitted EOS
        is drawn as a dashed line on the volume range extended by 1% on each side.

        Args:
            ax: :class:`Axes` object. If ax is None, a new figure is produced.

        ================  ==============================================================
        kwargs            Meaning
        ================  ==============================================================
        color             Color of data points and fitted curve. Default: "r"
        label             Label of the fitted curve. Default: "<self.name> fit"
        text              If True (default), add a text with the fit parameters.
        ================  ==============================================================

        Other keyword arguments are ignored.

        Returns:
            Matplotlib figure.
        """
        ax, fig, _ = get_ax_fig_plt(ax)

        vmin, vmax = self.volumes.min(), self.volumes.max()
        vmin, vmax = (vmin - 0.01 * abs(vmin), vmax + 0.01 * abs(vmax))

        color = kwargs.pop("color", "r")
        label = kwargs.pop("label", None)

        # Plot input data.
        ax.plot(self.volumes, self.energies, linestyle="None", marker="o", color=color)

        # Plot EOS.
        vfit = np.linspace(vmin, vmax, 100)
        if label is None:
            label = self.name + ' fit'

        if self.eos_name == "deltafactor":
            # The deltafactor fit is a polynomial in V^(-2/3).
            efit = np.polyval(self.eos_params, vfit**(-2. / 3.))
        else:
            efit = self.func(vfit, *self.eos_params)
        ax.plot(vfit, efit, linestyle="dashed", color=color, label=label)

        ax.grid(True)
        # Raw strings: "\A" is an invalid escape sequence in a plain literal
        # (DeprecationWarning, SyntaxWarning since Python 3.12). Same runtime value.
        ax.set_xlabel(r"Volume $\AA^3$")
        ax.set_ylabel("Energy (eV)")

        ax.legend(loc="best", shadow=True)

        # Add text with fit parameters.
        if kwargs.pop("text", True):
            text = [
                r"Min Volume = %1.2f $\AA^3$" % self.v0,
                r"Bulk modulus = %1.2f eV/$\AA^3$ = %1.2f GPa" % (self.b0, self.b0_GPa),
                "B1 = %1.2f" % self.b1,
            ]
            fig.text(0.4, 0.5, "\n".join(text), transform=ax.transAxes)

        return fig