Esempio n. 1
0
    def get_plot(
        self,
        n_idx,
        t_idx,
        x_property="mean free path",
        y_property="conductivity",
        height=6,
        width=6,
        xlabel=None,
        ylabel=None,
        xlim=None,
        ylim=None,
        logx=False,
        plt=None,
        style=None,
        no_base_style=False,
        fonts=None,
    ):
        """Plot one property against another as a single line.

        Args:
            n_idx: Index forwarded to ``self.get_plot_data`` (presumably the
                doping/carrier-concentration index — TODO confirm).
            t_idx: Index forwarded to ``self.get_plot_data`` (presumably the
                temperature index — TODO confirm).
            x_property (str): Property plotted on the x-axis. Its lower-cased
                form selects the default x-axis label.
            y_property (str): Property plotted on the y-axis. Its lower-cased
                form selects the default y-axis label.
            height (float): Figure height, passed to ``pretty_plot``.
            width (float): Figure width, passed to ``pretty_plot``.
            xlabel (str, optional): x-axis label. Falsy values fall back to
                the default label for ``x_property``.
            ylabel (str, optional): y-axis label. Falsy values fall back to
                the default label for ``y_property``.
            xlim (tuple, optional): x-axis limits.
            ylim (tuple, optional): y-axis limits.
            logx (bool): Use a logarithmic x-axis.
            plt: Optional existing pyplot object, passed to ``pretty_plot``.
            style, no_base_style, fonts: Not read in this body; presumably
                consumed by a styling decorator outside this view (TODO
                confirm).

        Returns:
            The pyplot object returned by ``pretty_plot`` with the data
            plotted on its current axes.
        """
        x_values, y_values = self.get_plot_data(n_idx, t_idx, x_property,
                                                y_property)

        plt = pretty_plot(width=width, height=height, plt=plt)
        ax = plt.gca()
        ax.plot(x_values, y_values)

        # Empty/None labels fall back to the per-property defaults.
        xlabel = xlabel if xlabel else _x_labels[x_property.lower()]
        ylabel = ylabel if ylabel else _y_labels[y_property.lower()]
        ax.set(xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim)

        if logx:
            ax.semilogx()

        # Return the plot like the other get_plot methods, so callers can
        # save or further modify it.
        return plt
Esempio n. 2
0
    def get_plot(
        self,
        n_idx,
        t_idx,
        zero_to_efermi=True,
        estep=0.01,
        line_density=100,
        height=3.2,
        width=3.2,
        emin=None,
        emax=None,
        amin=5e-5,
        amax=1e-1,
        ylabel="Energy (eV)",
        plt=None,
        aspect=None,
        kpath=None,
        cmap="viridis",
        colorbar=True,
        style=None,
        no_base_style=False,
        fonts=None,
    ):
        """Plot a lifetime-broadened band structure as a colour map.

        Bands and scattering rates are interpolated along a line-mode k-path.
        Each band energy is broadened by a Lorentzian whose width is
        ``rate * hbar / 2``. The summed intensity (in 1/meV) is drawn with
        ``pcolormesh`` on a logarithmic colour scale.

        Args:
            n_idx: Index forwarded to ``self._get_interpolater`` (presumably
                the doping index — TODO confirm).
            t_idx: Index forwarded to ``self._get_interpolater`` (presumably
                the temperature index — TODO confirm).
            zero_to_efermi (bool): Shift energies so the Fermi level is 0.
            estep (float): Energy grid spacing used to sample the broadening.
            line_density (int): k-point density along the interpolated path.
            height (float): Figure height.
            width (float): Figure width.
            emin (float, optional): Lower energy limit. ``None`` uses the
                lower value of ``self.fd_cutoffs`` (shifted by the Fermi level
                when ``zero_to_efermi``).
            emax (float, optional): Upper energy limit. ``None`` uses the
                upper value of ``self.fd_cutoffs`` (shifted as above).
            amin (float): Lower limit of the logarithmic colour scale.
            amax (float): Upper limit of the logarithmic colour scale.
            ylabel (str): Energy-axis label.
            plt: Either an existing Axes to draw into, or a pyplot object
                passed to ``pretty_plot``.
            aspect (float, optional): Aspect ratio, passed to ``_makeplot``.
            kpath: Optional k-path for the line-mode interpolation.
            cmap (str): Colour map name.
            colorbar (bool): Add a colorbar to the right of the axes.
            style, no_base_style, fonts: Not read in this body; presumably
                consumed by a styling decorator outside this view (TODO
                confirm).

        Returns:
            The Axes passed in as ``plt``, or the pyplot object returned by
            ``pretty_plot``.
        """
        interpolater = self._get_interpolater(n_idx, t_idx)

        bs, prop = interpolater.get_line_mode_band_structure(
            line_density=line_density,
            return_other_properties=True,
            kpath=kpath,
            symprec=self.symprec,
        )
        bs, rates = force_branches(bs, {s: p["rates"] for s, p in prop.items()})

        # Use explicit None checks so that a user-supplied limit of 0.0 is
        # honoured rather than replaced by the default cutoff.
        fd_emin, fd_emax = self.fd_cutoffs
        if emin is None:
            emin = fd_emin
            if zero_to_efermi:
                emin -= bs.efermi

        if emax is None:
            emax = fd_emax
            if zero_to_efermi:
                emax -= bs.efermi

        logger.info("Plotting band structure")
        if isinstance(plt, (Axis, SubplotBase)):
            ax = plt
        else:
            plt = pretty_plot(width=width, height=height, plt=plt)
            ax = plt.gca()

        if zero_to_efermi:
            bs.bands = {s: b - bs.efermi for s, b in bs.bands.items()}
            bs.efermi = 0

        bs_plotter = BSPlotter(bs)
        plot_data = bs_plotter.bs_plot_data(zero_to_efermi=zero_to_efermi)

        energies = np.linspace(emin, emax, int((emax - emin) / estep))
        distances = np.array([d for x in plot_data["distances"] for d in x])

        # Rates are stored as log10(rate), hence the 10 ** below.
        mesh_data = np.zeros((len(distances), len(energies)))
        for spin in self.spins:
            for spin_energies, spin_rates in zip(bs.bands[spin], rates[spin]):
                for d_idx in range(len(distances)):
                    energy = spin_energies[d_idx]
                    linewidth = 10 ** spin_rates[d_idx] * hbar / 2
                    broadening = lorentzian(energies, energy, linewidth)
                    broadening /= 1000  # convert 1/eV to 1/meV
                    mesh_data[d_idx] += broadening

        im = ax.pcolormesh(
            distances,
            energies,
            mesh_data.T,
            rasterized=True,
            cmap=cmap,
            norm=LogNorm(vmin=amin, vmax=amax),
            shading="auto",
        )
        if colorbar:
            # Use the axes' own figure: ``plt`` may be an Axes (see above),
            # which has no ``gcf``/``colorbar`` methods.
            fig = ax.figure
            pos = ax.get_position()
            cax = fig.add_axes([pos.x1 + 0.035, pos.y0, 0.035, pos.height])
            cbar = fig.colorbar(im, cax=cax)
            cbar.ax.tick_params(axis="y", length=rcParams["ytick.major.size"] * 0.5)
            cbar.ax.set_ylabel(
                r"$A_\mathbf{k}$ (meV$^{-1}$)", rotation=270, va="bottom"
            )

        _maketicks(ax, bs_plotter, ylabel=ylabel)
        _makeplot(
            ax,
            plot_data,
            bs,
            zero_to_efermi=zero_to_efermi,
            width=width,
            height=height,
            ymin=emin,
            ymax=emax,
            aspect=aspect,
        )
        return plt
Esempio n. 3
0
    def get_plot(self,
                 subplot=False,
                 width=None,
                 height=None,
                 xmin=-6.,
                 xmax=6.,
                 yscale=1,
                 colours=None,
                 plot_total=True,
                 legend_on=True,
                 num_columns=2,
                 legend_frame_on=False,
                 legend_cutoff=3,
                 xlabel='Energy (eV)',
                 ylabel='Arb. units',
                 zero_to_efermi=True,
                 dpi=400,
                 fonts=None,
                 plt=None,
                 style=None,
                 no_base_style=False,
                 spin=None):
        """Get a :obj:`matplotlib.pyplot` object of the density of states.

        Args:
            subplot (:obj:`bool`, optional): Plot the density of states for
                each element on separate subplots. Defaults to ``False``.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            xmin (:obj:`float`, optional): The minimum energy on the x-axis.
            xmax (:obj:`float`, optional): The maximum energy on the x-axis.
            yscale (:obj:`float`, optional): Scaling factor for the y-axis.
            colours (:obj:`dict`, optional): Use custom colours for specific
                element and orbital combinations. Specified as a :obj:`dict` of
                :obj:`dict` of the colours. For example::

                    {
                        'Sn': {'s': 'r', 'p': 'b'},
                        'O': {'s': '#000000'}
                    }

                The colour can be a hex code, series of rgb value, or any other
                format supported by matplotlib.
            plot_total (:obj:`bool`, optional): Plot the total density of
                states. Defaults to ``True``.
            legend_on (:obj:`bool`, optional): Plot the graph legend. Defaults
                to ``True``.
            num_columns (:obj:`int`, optional): The number of columns in the
                legend.
            legend_frame_on (:obj:`bool`, optional): Plot a frame around the
                graph legend. Defaults to ``False``.
            legend_cutoff (:obj:`float`, optional): The cut-off (in % of the
                maximum density of states within the plotting range) for an
                elemental orbital to be labelled in the legend. This prevents
                the legend from containing labels for orbitals that have very
                little contribution in the plotting range.
            xlabel (:obj:`str`, optional): Label/units for x-axis (i.e. energy)
            ylabel (:obj:`str`, optional): Label/units for y-axis (i.e. DOS)
            zero_to_efermi (:obj:`bool`, optional): Normalise the plot such
                that the Fermi level is set as 0 eV.
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            style (:obj:`list`, :obj:`str`, or :obj:`dict`): Any matplotlib
                style specifications, to be composed on top of Sumo base
                style.
            no_base_style (:obj:`bool`, optional): Prevent use of sumo base
                style. This can make alternative styles behave more
                predictably.
            spin (:obj:`Spin`, optional): Plot a single spin channel of a
                spin-polarised density of states, "up" or "1" for spin up
                only, "down" or "-1" for spin down only. Defaults to ``None``
                (both channels). The value is used directly as a key into the
                line densities, so it should match their keys.

        Returns:
            :obj:`matplotlib.pyplot`: The density of states plot.
        """
        plot_data = self.dos_plot_data(yscale=yscale,
                                       xmin=xmin,
                                       xmax=xmax,
                                       colours=colours,
                                       plot_total=plot_total,
                                       legend_cutoff=legend_cutoff,
                                       subplot=subplot,
                                       zero_to_efermi=zero_to_efermi,
                                       spin=spin)
        lines = plot_data['lines']

        if subplot:
            nplots = len(lines)
            plt = pretty_subplot(nplots,
                                 1,
                                 width=width,
                                 height=height,
                                 dpi=dpi,
                                 plt=plt)
        else:
            plt = pretty_plot(width=width, height=height, dpi=dpi, plt=plt)

        mask = plot_data['mask']
        energies = plot_data['energies'][mask]
        fig = plt.gcf()

        # Decide which spin channels to draw: only spin up for
        # non-polarised data, the requested channel if one was given,
        # otherwise both.
        if len(lines[0][0]['dens']) == 1:
            spins = [Spin.up]
        elif spin is not None:
            spins = [spin]
        else:
            spins = [Spin.up, Spin.down]
        single_spin = len(spins) == 1

        for i, line_set in enumerate(lines):
            if subplot:
                ax = fig.axes[i]
            else:
                ax = plt.gca()

            # Line-major order (each line with all its spin channels) fixes
            # both drawing order and legend order. The loop variable is
            # named spin_channel so it does not shadow the ``spin`` argument.
            for line, spin_channel in itertools.product(line_set, spins):
                if not single_spin and spin_channel is Spin.down:
                    # Mirror spin down below the axis. It is left
                    # unlabelled so each line appears once in the legend.
                    label = ""
                    densities = -line['dens'][spin_channel][mask]
                else:
                    label = line['label']
                    densities = line['dens'][spin_channel][mask]
                ax.fill_between(energies,
                                densities,
                                lw=0,
                                facecolor=line['colour'],
                                alpha=line['alpha'])
                ax.plot(energies, densities, label=label, color=line['colour'])

            ax.set_xlim(xmin, xmax)
            if single_spin:
                ax.set_ylim(0, plot_data['ymax'])
            else:
                ax.set_ylim(plot_data['ymin'], plot_data['ymax'])

            ax.tick_params(axis='y', labelleft=False)
            ax.yaxis.set_minor_locator(AutoMinorLocator(2))
            ax.xaxis.set_minor_locator(AutoMinorLocator(2))

            loc = 'upper right' if subplot else 'best'
            ncol = 1 if subplot else num_columns
            if legend_on:
                ax.legend(loc=loc, frameon=legend_frame_on, ncol=ncol)

        # now add axis labels and sort out ticks
        if subplot:
            ax.set_xlabel(xlabel)
            fig.subplots_adjust(hspace=0)
            # only the bottom subplot keeps its x tick labels
            plt.setp([a.get_xticklabels() for a in fig.axes[:-1]],
                     visible=False)
            if 'axes.labelcolor' in matplotlib.rcParams:
                ylabelcolor = matplotlib.rcParams['axes.labelcolor']
            else:
                ylabelcolor = None

            # A single shared y label is placed as free text rather than
            # per-axis.
            fig.text(0.08,
                     0.5,
                     ylabel,
                     ha='left',
                     color=ylabelcolor,
                     va='center',
                     rotation='vertical',
                     transform=ax.transAxes)
        else:
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)

        return plt
Esempio n. 4
0
    def get_plot(
        self,
        zero_to_efermi=True,
        ymin=-6.0,
        ymax=6.0,
        width=None,
        height=None,
        vbm_cbm_marker=False,
        ylabel="Energy (eV)",
        dpi=None,
        plt=None,
        plot_dos_legend=True,
        dos_plotter=None,
        dos_options=None,
        dos_label=None,
        dos_aspect=3,
        aspect=None,
        spin=None,
        fonts=None,
        style=None,
        no_base_style=False,
    ):
        """Get a :obj:`matplotlib.pyplot` object of the band structure.

        If the system is spin polarised, and no spin has been specified, orange
        lines are spin up, dashed blue lines are spin down. For metals, all
        bands are coloured blue. For semiconductors, blue lines indicate
        valence bands and orange lines indicates conduction bands.

        Args:
            zero_to_efermi (:obj:`bool`): Normalise the plot such that the
                valence band maximum is set as 0 eV.
            ymin (:obj:`float`, optional): The minimum energy on the y-axis.
            ymax (:obj:`float`, optional): The maximum energy on the y-axis.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            vbm_cbm_marker (:obj:`bool`, optional): Plot markers to indicate
                the VBM and CBM locations.
            ylabel (:obj:`str`, optional): y-axis (i.e. energy) label/units
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            dos_plotter (:obj:`~sumo.plotting.dos_plotter.SDOSPlotter`, \
                optional): Plot the density of states alongside the band
                structure. This should be a
                :obj:`~sumo.plotting.dos_plotter.SDOSPlotter` object
                initialised with the data to plot.
            dos_options (:obj:`dict`, optional): The options for density of
                states plotting. This should be formatted as a :obj:`dict`
                containing any of the following keys:

                    "yscale" (:obj:`float`)
                        Scaling factor for the y-axis.
                    "xmin" (:obj:`float`)
                        The minimum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "xmax" (:obj:`float`)
                        The maximum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "colours" (:obj:`dict`)
                        Use custom colours for specific element and orbital
                        combinations. Specified as a :obj:`dict` of
                        :obj:`dict` of the colours. For example::

                            {
                                'Sn': {'s': 'r', 'p': 'b'},
                                'O': {'s': '#000000'}
                            }

                        The colour can be a hex code, series of rgb value, or
                        any other format supported by matplotlib.
                    "plot_total" (:obj:`bool`)
                        Plot the total density of states. Defaults to ``True``.
                    "legend_cutoff" (:obj:`float`)
                        The cut-off (in % of the maximum density of states
                        within the plotting range) for an elemental orbital to
                        be labelled in the legend. This prevents the legend
                        from containing labels for orbitals that have very
                        little contribution in the plotting range.
                    "subplot" (:obj:`bool`)
                        Plot the density of states for each element on separate
                        subplots. Defaults to ``False``.

            dos_label (:obj:`str`, optional): DOS axis label/units
            dos_aspect (:obj:`float`, optional): Aspect ratio for the band
                structure and density of states subplot. For example,
                ``dos_aspect = 3``, results in a ratio of 3:1, for the band
                structure:dos plots.
            plot_dos_legend (:obj:`bool`): Whether to plot the dos legend.
            aspect (:obj:`float`, optional): The aspect ratio of the band
                structure plot. By default the dimensions of the figure size
                are used to determine the aspect ratio. Set to ``1`` to force
                the plot to be square.
            spin (:obj:`Spin`, optional): Plot a spin-polarised band structure,
                "up" or "1" for spin up only, "down" or "-1" for spin down only.
                Defaults to ``None``.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.
            style (:obj:`list`, :obj:`str`, or :obj:`dict`): Any matplotlib
                style specifications, to be composed on top of Sumo base
                style.
            no_base_style (:obj:`bool`, optional): Prevent use of sumo base
                style. This can make alternative styles behave more
                predictably.

        Raises:
            ValueError: If ``spin`` is given for a non-spin-polarised band
                structure.

        Returns:
            :obj:`matplotlib.pyplot`: The electronic band structure plot.
        """
        # Validate before creating any figure so a bad call leaves no
        # half-built plot behind.
        if spin is not None and not self.bs.is_spin_polarized:
            raise ValueError(
                "Spin-selection only possible with spin-polarised "
                "calculation results")

        if dos_plotter:
            plt = pretty_subplot(
                1,
                2,
                width=width,
                height=height,
                sharex=False,
                dpi=dpi,
                plt=plt,
                gridspec_kw={
                    "width_ratios": [dos_aspect, 1],
                    "wspace": 0
                },
            )
            ax = plt.gcf().axes[0]
        else:
            plt = pretty_plot(width=width, height=height, dpi=dpi, plt=plt)
            ax = plt.gca()

        data = self.bs_plot_data(zero_to_efermi)
        dists = data["distances"]
        eners = data["energy"]

        # Hoisted out of the plotting loop: both are fixed for this call.
        is_metal = self.bs.is_metal()
        overlay_spin_down = self.bs.is_spin_polarized and spin is None
        # Spin channel drawn in the first pass (spin up unless one is chosen).
        plot_spin = Spin.up if spin is None else spin

        if is_metal or overlay_spin_down:
            # Colours do not depend on valence/conduction character here.
            is_vb = [True]
        else:
            # Not a metal, so only one channel is plotted: the selected spin,
            # or spin up for a closed-shell calculation.
            is_vb = self.bs.bands[plot_spin] <= self.bs.get_vbm()["energy"]

        # nd is branch index, nb is band index
        for nd, nb in it.product(range(len(dists)), range(self.nbands)):
            e = eners[str(plot_spin)][nd][nb]

            # For closed-shell calculations with a bandgap, colour valence
            # bands blue (C0) and conduction bands orange (C1)
            #
            # For closed-shell calculations with no bandgap, colour with C0
            #
            # For spin-polarized calculations, colour spin up channel with C1
            # and overlay with C0 (dashed) spin down channel
            if overlay_spin_down:
                c = "C1"
            elif is_metal or np.all(is_vb[nb]):
                c = "C0"
            else:
                c = "C1"

            ax.plot(dists[nd], e, ls="-", c=c, zorder=1)

        # Plot second spin channel if it exists and no spin selected
        if overlay_spin_down:
            for nd, nb in it.product(range(len(dists)), range(self.nbands)):
                e = eners[str(Spin.down)][nd][nb]
                ax.plot(dists[nd], e, c="C0", linestyle="--", zorder=2)

        self._maketicks(ax, ylabel=ylabel)
        self._makeplot(
            ax,
            plt.gcf(),
            data,
            zero_to_efermi=zero_to_efermi,
            vbm_cbm_marker=vbm_cbm_marker,
            width=width,
            height=height,
            ymin=ymin,
            ymax=ymax,
            dos_plotter=dos_plotter,
            dos_options=dos_options,
            plot_dos_legend=plot_dos_legend,
            dos_label=dos_label,
            aspect=aspect,
        )
        return plt
Esempio n. 5
0
    def get_projected_plot(
        self,
        selection,
        mode="rgb",
        normalise="all",
        interpolate_factor=4,
        circle_size=150,
        projection_cutoff=0.001,
        zero_to_efermi=True,
        ymin=-6.0,
        ymax=6.0,
        width=None,
        height=None,
        vbm_cbm_marker=False,
        ylabel="Energy (eV)",
        dpi=400,
        plt=None,
        dos_plotter=None,
        dos_options=None,
        dos_label=None,
        plot_dos_legend=True,
        dos_aspect=3,
        aspect=None,
        fonts=None,
        style=None,
        no_base_style=False,
        spin=None,
    ):
        """Get a :obj:`matplotlib.pyplot` of the projected band structure.

        If the system is spin polarised, no spin has been specified and
        ``mode = 'rgb'`` spin up and spin down bands are differentiated by
        solid and dashed lines, respectively.
        For the other modes, spin up and spin down are plotted separately.

        Args:
            selection (list): A list of :obj:`tuple` or :obj:`string`
                identifying which elements and orbitals to project on to the
                band structure. These can be specified by both element and
                orbital, for example, the following will project the Bi s, p
                and S p orbitals::

                    [('Bi', 's'), ('Bi', 'p'), ('S', 'p')]

                If just the element is specified then all the orbitals of
                that element are combined. For example, to sum all the S
                orbitals::

                    [('Bi', 's'), ('Bi', 'p'), 'S']

                You can also choose to sum particular orbitals by supplying a
                :obj:`tuple` of orbitals. For example, to sum the S s, p, and
                d orbitals into a single projection::

                  [('Bi', 's'), ('Bi', 'p'), ('S', ('s', 'p', 'd'))]

                If ``mode = 'rgb'``, a maximum of 3 orbital/element
                combinations can be plotted simultaneously (one for red, green
                and blue), otherwise an unlimited number of elements/orbitals
                can be selected.
            mode (:obj:`str`, optional): Type of projected band structure to
                plot. Options are:

                    "rgb"
                        The band structure line color depends on the character
                        of the band. Each element/orbital contributes either
                        red, green or blue with the corresponding line colour a
                        mixture of all three colours. This mode only supports
                        up to 3 elements/orbitals combinations. The order of
                        the ``selection`` :obj:`tuple` determines which colour
                        is used for each selection.
                    "stacked"
                        The element/orbital contributions are drawn as a
                        series of stacked circles, with the colour depending on
                        the composition of the band. The size of the circles
                        can be scaled using the ``circle_size`` option.

            normalise (:obj:`str`, optional): Normalisation the projections.
                Options are:

                  * ``'all'``: Projections normalised against the sum of all
                       other projections.
                  * ``'select'``: Projections normalised against the sum of the
                       selected projections.
                  * ``None``: No normalisation performed.

            interpolate_factor (:obj:`int`, optional): The factor by which to
                interpolate the band structure (necessary to make smooth
                lines). A larger number indicates greater interpolation.
            circle_size (:obj:`float`, optional): The area of the circles used
                when ``mode = 'stacked'``.
            projection_cutoff (:obj:`float`): Don't plot projections with
                intensities below this number. This option is useful for
                stacked plots, where small projections clutter the plot.
            zero_to_efermi (:obj:`bool`): Normalise the plot such that the
                valence band maximum is set as 0 eV.
            ymin (:obj:`float`, optional): The minimum energy on the y-axis.
            ymax (:obj:`float`, optional): The maximum energy on the y-axis.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            vbm_cbm_marker (:obj:`bool`, optional): Plot markers to indicate
                the VBM and CBM locations.
            ylabel (:obj:`str`, optional): y-axis (i.e. energy) label/units
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            dos_plotter (:obj:`~sumo.plotting.dos_plotter.SDOSPlotter`, \
                optional): Plot the density of states alongside the band
                structure. This should be a
                :obj:`~sumo.plotting.dos_plotter.SDOSPlotter` object
                initialised with the data to plot.
            dos_options (:obj:`dict`, optional): The options for density of
                states plotting. This should be formatted as a :obj:`dict`
                containing any of the following keys:

                    "yscale" (:obj:`float`)
                        Scaling factor for the y-axis.
                    "xmin" (:obj:`float`)
                        The minimum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "xmax" (:obj:`float`)
                        The maximum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "colours" (:obj:`dict`)
                        Use custom colours for specific element and orbital
                        combinations. Specified as a :obj:`dict` of
                        :obj:`dict` of the colours. For example::

                           {
                                'Sn': {'s': 'r', 'p': 'b'},
                                'O': {'s': '#000000'}
                            }

                        The colour can be a hex code, series of rgb value, or
                        any other format supported by matplotlib.
                    "plot_total" (:obj:`bool`)
                        Plot the total density of states. Defaults to ``True``.
                    "legend_cutoff" (:obj:`float`)
                        The cut-off (in % of the maximum density of states
                        within the plotting range) for an elemental orbital to
                        be labelled in the legend. This prevents the legend
                        from containing labels for orbitals that have very
                        little contribution in the plotting range.
                    "subplot" (:obj:`bool`)
                        Plot the density of states for each element on separate
                        subplots. Defaults to ``False``.

            dos_label (:obj:`str`, optional): DOS axis label/units
            plot_dos_legend (:obj:`bool`): Whether to plot the dos legend.
            dos_aspect (:obj:`float`, optional): Aspect ratio for the band
                structure and density of states subplot. For example,
                ``dos_aspect = 3``, results in a ratio of 3:1, for the band
                structure:dos plots.
            aspect (:obj:`float`, optional): The aspect ratio of the band
                structure plot. By default the dimensions of the figure size
                are used to determine the aspect ratio. Set to ``1`` to force
                the plot to be square.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.
            style (:obj:`list`, :obj:`str`, or :obj:`dict`): Any matplotlib
                style specifications, to be composed on top of Sumo base
                style.
            no_base_style (:obj:`bool`, optional): Prevent use of sumo base
                style. This can make alternative styles behave more
                predictably.
            spin (:obj:`Spin`, optional): Plot a spin-polarised band structure,
                "up" or "1" for spin up only, "down" or "-1" for spin down only.
                Defaults to ``None``.

        Returns:
            :obj:`matplotlib.pyplot`: The projected electronic band structure
            plot.
        """
        if mode == "rgb" and len(selection) > 3:
            raise ValueError("Too many elements/orbitals specified (max 3)")
        elif mode == "solo" and dos_plotter:
            raise ValueError("Solo mode plotting with DOS not supported")

        if dos_plotter:
            plt = pretty_subplot(
                1,
                2,
                width,
                height,
                sharex=False,
                dpi=dpi,
                plt=plt,
                gridspec_kw={
                    "width_ratios": [dos_aspect, 1],
                    "wspace": 0
                },
            )
            ax = plt.gcf().axes[0]
        else:
            plt = pretty_plot(width, height, dpi=dpi, plt=plt)
            ax = plt.gca()

        data = self.bs_plot_data(zero_to_efermi)
        nbranches = len(data["distances"])

        # Ensure we do spin up first, then spin down
        spins = sorted(self.bs.bands.keys(), key=lambda s: -s.value)
        if spin is not None and len(spins) == 1:
            raise ValueError(
                "Spin-selection only possible with spin-polarised "
                "calculation results")
        if spin is Spin.up:
            spins = [spins[0]]
        elif spin is Spin.down:
            spins = [spins[1]]

        proj = get_projections_by_branches(self.bs,
                                           selection,
                                           normalise=normalise)

        # nd is branch index
        for spin, nd in it.product(spins, range(nbranches)):

            # mask data to reduce plotting load
            bands = np.array(data["energy"][str(spin)][nd])
            mask = np.where(
                np.any(bands > ymin - 0.05, axis=1)
                & np.any(bands < ymax + 0.05, axis=1))
            distances = data["distances"][nd]
            bands = bands[mask]
            weights = [proj[nd][i][spin][mask] for i in range(len(selection))]

            if len(distances
                   ) > 2:  # Only interpolate if it makes sense to do so
                # interpolate band structure to improve smoothness
                temp_dists = np.linspace(distances[0], distances[-1],
                                         len(distances) * interpolate_factor)
                bands = interp1d(
                    distances,
                    bands,
                    axis=1,
                    bounds_error=False,
                    fill_value="extrapolate",
                )(temp_dists)
                weights = interp1d(
                    distances,
                    weights,
                    axis=2,
                    bounds_error=False,
                    fill_value="extrapolate",
                )(temp_dists)
                distances = temp_dists

            else:  # change from list to array if we skipped the scipy interpolation
                weights = np.array(weights)
                bands = np.array(bands)
                distances = np.array(distances)

            # sometimes VASP produces very small negative weights
            weights[weights < 0] = 0

            if mode == "rgb":

                # colours aren't used now but needed later for legend
                colours = ["#ff0000", "#00ff00", "#0000ff"]

                # if only two orbitals then just use red and blue
                if len(weights) == 2:
                    weights = np.insert(weights,
                                        1,
                                        np.zeros(weights[0].shape),
                                        axis=0)
                    colours = ["#ff0000", "#0000ff"]

                ls = "-" if spin == Spin.up else "--"
                lc = rgbline(
                    distances,
                    bands,
                    weights[0],
                    weights[1],
                    weights[2],
                    alpha=1,
                    linestyles=ls,
                    linewidth=(rcParams["lines.linewidth"] * 1.25),
                )
                ax.add_collection(lc)

            elif mode == "stacked":
                # TODO: Handle spin

                # use some nice custom colours first, then default colours
                colours = [
                    "#3952A3", "#FAA41A", "#67BC47", "#6ECCDD", "#ED2025"
                ]
                colour_series = rcParams["axes.prop_cycle"].by_key()["color"]
                colours.extend(colour_series)

                # very small circles look crap
                weights[weights < projection_cutoff] = 0

                distances = list(distances) * len(bands)
                bands = bands.flatten()
                zorders = range(-len(weights), 0)
                for w, c, z in zip(weights, colours, zorders):
                    ax.scatter(
                        distances,
                        bands,
                        c=c,
                        s=circle_size * w**2,
                        zorder=z,
                        rasterized=True,
                    )

        # plot the legend
        for c, spec in zip(colours, selection):
            if isinstance(spec, str):
                label = spec
            else:
                label = "{} ({})".format(spec[0], " + ".join(spec[1]))
            ax.scatter([-10000], [-10000],
                       c=c,
                       s=50,
                       label=label,
                       edgecolors="none")

        if dos_plotter:
            loc = 1
            anchor_point = (-0.2, 1)
        else:
            loc = 2
            anchor_point = (0.95, 1)

        ax.legend(
            bbox_to_anchor=anchor_point,
            loc=loc,
            frameon=False,
            handletextpad=0.1,
            scatterpoints=1,
        )

        # finish and tidy plot
        self._maketicks(ax, ylabel=ylabel)
        self._makeplot(
            ax,
            plt.gcf(),
            data,
            zero_to_efermi=zero_to_efermi,
            vbm_cbm_marker=vbm_cbm_marker,
            width=width,
            height=height,
            ymin=ymin,
            ymax=ymax,
            dos_plotter=dos_plotter,
            dos_options=dos_options,
            dos_label=dos_label,
            plot_dos_legend=plot_dos_legend,
            aspect=aspect,
        )
        return plt
Esempio n. 6
0
    def get_plot(
        self,
        n_idx,
        t_idx,
        zero_to_efermi=True,
        estep=0.01,
        line_density=100,
        height=6,
        width=6,
        emin=None,
        emax=None,
        ylabel="Energy (eV)",
        plt=None,
        aspect=None,
        distance_factor=10,
        kpath=None,
        style=None,
        no_base_style=False,
        fonts=None,
    ):
        """Plot the band structure broadened by scattering-rate linewidths.

        Each band is interpolated along the k-point path and its scattering
        rate is smoothed and converted to a linewidth. A Lorentzian of that
        width is then evaluated on a regular energy grid. The point-wise
        maximum over all bands is drawn as a log-normalised colour mesh.

        Args:
            n_idx: First index forwarded to ``self._get_interpolater``
                (presumably the doping index -- confirm).
            t_idx: Second index forwarded to ``self._get_interpolater``
                (presumably the temperature index -- confirm).
            zero_to_efermi: Shift band energies, and the default energy
                window, so that the Fermi level sits at 0.
            estep: Spacing of the energy grid, in the same units as
                ``emin``/``emax``.
            line_density: Passed to
                ``interpolater.get_line_mode_band_structure``.
            height: Figure height passed to ``pretty_plot``.
            width: Figure width passed to ``pretty_plot``.
            emin: Lower energy limit. If ``None``, it is taken from
                ``self.fd_cutoffs`` (converted from Hartree to eV).
            emax: Upper energy limit. If ``None``, it is taken from
                ``self.fd_cutoffs`` (converted from Hartree to eV).
            ylabel: Energy-axis label.
            plt: Existing pyplot object to draw into, passed to
                ``pretty_plot``.
            aspect: Passed to ``_makeplot``.
            distance_factor: Multiplier on the number of k-path points used
                for the interpolated distance grid.
            kpath: Passed to ``interpolater.get_line_mode_band_structure``.
            style: Not referenced in this method body.
            no_base_style: Not referenced in this method body.
            fonts: Not referenced in this method body.

        Returns:
            The pyplot object returned by ``pretty_plot``, with the plot drawn.
        """
        interpolater = self._get_interpolater(n_idx, t_idx)

        bs, prop = interpolater.get_line_mode_band_structure(
            line_density=line_density,
            return_other_properties=True,
            kpath=kpath)

        # The default energy window comes from the Fermi-Dirac cutoffs, which
        # are stored in Hartree. Compare against None so that an explicit
        # bound of 0 (e.g. the Fermi level when zero_to_efermi=True) is
        # respected rather than silently replaced.
        fd_emin, fd_emax = self.fd_cutoffs
        if emin is None:
            emin = fd_emin * hartree_to_ev
            if zero_to_efermi:
                emin -= bs.efermi

        if emax is None:
            emax = fd_emax * hartree_to_ev
            if zero_to_efermi:
                emax -= bs.efermi

        logger.info("Plotting band structure")
        plt = pretty_plot(width=width, height=height, plt=plt)
        ax = plt.gca()

        if zero_to_efermi:
            bs.bands = {s: b - bs.efermi for s, b in bs.bands.items()}
            bs.efermi = 0

        bs_plotter = BSPlotter(bs)
        plot_data = bs_plotter.bs_plot_data(zero_to_efermi=zero_to_efermi)

        energies = np.linspace(emin, emax, int((emax - emin) / estep))
        distances = np.array([d for x in plot_data["distances"] for d in x])

        # Rates are stored as log10(rate). Replace non-positive values with
        # the smallest positive one, and clip the top end at 15.
        rates = {}
        for spin, spin_data in prop.items():
            rates[spin] = spin_data["rates"]
            rates[spin][rates[spin] <= 0] = np.min(
                rates[spin][rates[spin] > 0])
            rates[spin][rates[spin] >= 15] = 15

        interp_distances = np.linspace(distances.min(), distances.max(),
                                       int(len(distances) * distance_factor))

        # Savitzky-Golay window: start from min(len(distances) - 2, 71),
        # then bump it up to an odd length.
        window = np.min([len(distances) - 2, 71])
        window += window % 2 + 1
        mesh_data = np.full((len(interp_distances), len(energies)), 1e-2)

        for spin in self.spins:
            for spin_energies, spin_rates in zip(bs.bands[spin], rates[spin]):
                interp_energies = interp1d(distances,
                                           spin_energies)(interp_distances)
                spin_rates = savgol_filter(spin_rates, window, 3)
                interp_rates = interp1d(distances,
                                        spin_rates)(interp_distances)
                linewidths = 10**interp_rates * hbar / 2

                # Keep the largest Lorentzian contribution at each grid point.
                for d_idx, (energy, linewidth) in enumerate(
                        zip(interp_energies, linewidths)):
                    broadening = lorentzian(energies, energy, linewidth)
                    mesh_data[d_idx] = np.maximum(broadening, mesh_data[d_idx])

        ax.pcolormesh(
            interp_distances,
            energies,
            mesh_data.T,
            rasterized=True,
            norm=LogNorm(vmin=mesh_data.min(), vmax=mesh_data.max()),
        )

        _maketicks(ax, bs_plotter, ylabel=ylabel)
        _makeplot(
            ax,
            plot_data,
            bs,
            zero_to_efermi=zero_to_efermi,
            width=width,
            height=height,
            ymin=emin,
            ymax=emax,
            aspect=aspect,
        )
        return plt
Esempio n. 7
0
    def get_plot(self,
                 ymin=None,
                 ymax=None,
                 width=6.,
                 height=6.,
                 dpi=400,
                 plt=None,
                 fonts=None,
                 dos=None,
                 dos_aspect=3,
                 color=None):
        """Build a :obj:`matplotlib.pyplot` plot of the phonon band structure.

        One line is drawn for every (branch, band) pair. When ``dos`` is
        given, a second panel sharing the frequency axis is created next to
        the band structure for the density of states.

        Args:
            ymin (:obj:`float`, optional): The minimum energy on the y-axis.
            ymax (:obj:`float`, optional): The maximum energy on the y-axis.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            dos (:obj:`np.ndarray`, optional): 2D Numpy array of total DOS
                data. If given, a DOS panel is added beside the bands.
            dos_aspect (float): Width ratio of the band panel to the DOS
                panel.
            color (:obj:`str` or :obj:`tuple`, optional): Line/fill colour in
                any matplotlib-accepted format. Defaults to ``'C2'``, the
                third colour of matplotlib's default colour cycle.

        Returns:
            :obj:`matplotlib.pyplot`: The phonon band structure plot.
        """
        color = 'C2' if color is None else color

        if dos is None:
            plt = pretty_plot(width, height, dpi=dpi, plt=plt, fonts=fonts)
            ax = plt.gca()
        else:
            # The two panels share the y (frequency) axis; the band panel is
            # dos_aspect times wider than the DOS panel.
            panel_layout = {'width_ratios': [dos_aspect, 1], 'wspace': 0}
            plt = pretty_subplot(1, 2, width, height,
                                 sharex=False, sharey=True,
                                 dpi=dpi, plt=plt, fonts=fonts,
                                 gridspec_kw=panel_layout)
            ax = plt.gcf().axes[0]

        data = self.bs_plot_data()
        branch_distances = data['distances']
        branch_frequencies = data['frequency']

        # Outer loop over branches, inner loop over bands.
        for branch_idx, x_values in enumerate(branch_distances):
            frequencies = branch_frequencies[branch_idx]
            for band_idx in range(self._nb_bands):
                ax.plot(x_values, frequencies[band_idx],
                        ls='-', c=color, linewidth=band_linewidth)

        self._maketicks(ax)
        self._makeplot(ax, plt.gcf(), data,
                       width=width, height=height,
                       ymin=ymin, ymax=ymax,
                       dos=dos, color=color)
        plt.tight_layout()
        plt.subplots_adjust(wspace=0)

        return plt
Esempio n. 8
0
    def get_plot(self, units='THz', ymin=None, ymax=None, width=None,
                 height=None, dpi=None, plt=None, fonts=None, dos=None,
                 dos_aspect=3, color=None, style=None, no_base_style=False,
                 from_json=None, legend=None):
        """Get a :obj:`matplotlib.pyplot` object of the phonon band structure.

        Args:
            units (:obj:`str`, optional): Units of phonon frequency. Accepted
                (case-insensitive) values are Thz, cm-1, eV, meV.
            ymin (:obj:`float`, optional): The minimum energy on the y-axis.
            ymax (:obj:`float`, optional): The maximum energy on the y-axis.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            dos (:obj:`np.ndarray`): 2D Numpy array of total DOS data
            dos_aspect (float): Width division for vertical DOS
            color (:obj:`str` or :obj:`tuple`, optional): Line/fill colour in
                any matplotlib-accepted format. Defaults to ``'C0'``.
            style (:obj:`list`, :obj:`str`, or :obj:`dict`): Any matplotlib
                style specifications, to be composed on top of Sumo base
                style.
            no_base_style (:obj:`bool`, optional): Prevent use of sumo base
                style. This can make alternative styles behave more
                predictably.
            from_json (:obj:`list` or :obj:`None`, optional): List of paths to
                :obj:`pymatgen.phonon.bandstructure.PhononBandStructureSymmline`
                JSON dump files. These are used to generate additional plots
                displayed under the data attached to this plotter.
                The k-point path should be the same as the main plot; the
                reciprocal lattice is adjusted to fit the scaling of the main
                data input. The i-th file (0-based) is drawn in colour
                ``'C{i+1}'``.
            legend (:obj:`list` of :obj:`str`, optional): Legend labels.
                Supply either ``1 + len(from_json)`` labels (main plot first)
                or ``len(from_json)`` labels, in which case the main plot is
                given an empty label. No legend is drawn if every label is an
                empty string.

        Returns:
            :obj:`matplotlib.pyplot`: The phonon band structure plot.

        Raises:
            ValueError: If ``legend`` has the wrong number of entries, or if a
                ``from_json`` band structure has a different number of bands
                from the main plot.
        """
        if from_json is None:
            from_json = []

        if legend is None:
            legend = [''] * (len(from_json) + 1)
        elif len(legend) == len(from_json):
            # Labels only given for the extra structures: blank main label.
            legend = [''] + list(legend)
        elif len(legend) != 1 + len(from_json):
            raise ValueError('Inappropriate number of legend entries')

        if color is None:
            color = 'C0'  # Default to first colour in matplotlib series

        if dos is not None:
            plt = pretty_subplot(1, 2, width=width, height=height,
                                 sharex=False, sharey=True, dpi=dpi, plt=plt,
                                 gridspec_kw={'width_ratios': [dos_aspect, 1],
                                              'wspace': 0})
            ax = plt.gcf().axes[0]
        else:
            plt = pretty_plot(width, height, dpi=dpi, plt=plt)
            ax = plt.gca()

        def _plot_lines(data, ax, color=None, alpha=1, zorder=1):
            """Pull data from any PhononBSPlotter and add to axis"""
            dists = data['distances']
            freqs = data['frequency']

            # nd is branch index, nb is band index
            for nd, nb in itertools.product(range(len(dists)),
                                            range(self._nb_bands)):
                ax.plot(dists[nd], freqs[nd][nb], ls='-', c=color,
                        zorder=zorder)

        data = self.bs_plot_data()
        _plot_lines(data, ax, color=color)

        for i, bs_json in enumerate(from_json):
            with open(bs_json, 'rt') as f:
                json_data = json.load(f)
            # Reuse the main plot's reciprocal lattice so that the k-point
            # distances of the extra structure are on the same scale.
            json_data['lattice_rec'] = json.loads(
                self._bs.lattice_rec.to_json())
            bs = PhononBandStructureSymmLine.from_dict(json_data)

            json_plotter = PhononBSPlotter(bs)
            if json_plotter._nb_bands != self._nb_bands:
                raise ValueError('Number of bands in {} does not match '
                                 'main plot'.format(bs_json))
            # zorder below the main data (zorder=1) so the main plot is on top.
            _plot_lines(json_plotter.bs_plot_data(), ax,
                        color='C{}'.format(i + 1),
                        zorder=0.5)

        if any(legend):  # Don't show legend if all entries are empty string
            from matplotlib.lines import Line2D
            # Match the legend handles to the lines actually drawn: the main
            # data uses ``color``, the i-th from_json entry uses 'C{i+1}'.
            handle_colours = [color] + ['C{}'.format(i)
                                        for i in range(1, len(legend))]
            ax.legend([Line2D([0], [0], color=c) for c in handle_colours],
                      legend)

        self._maketicks(ax, units=units)
        self._makeplot(ax, plt.gcf(), data, width=width, height=height,
                       ymin=ymin, ymax=ymax, dos=dos, color=color)
        plt.tight_layout()
        plt.subplots_adjust(wspace=0)

        return plt
Esempio n. 9
0
    def get_plot(self,
                 width=6.,
                 height=6.,
                 xmin=0.,
                 xmax=None,
                 ymin=0,
                 ymax=1e5,
                 colours=None,
                 dpi=400,
                 plt=None,
                 fonts=None):
        """Get a :obj:`matplotlib.pyplot` object of the optical spectra.

        A spectrum whose absorption array is one-dimensional is drawn as a
        single line. Otherwise its first three columns are drawn as the xx, yy
        and zz components with different line styles. If a spectrum has a
        non-zero band gap, a dotted vertical line is drawn at that energy.

        Args:
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            xmin (:obj:`float`, optional): The minimum energy on the x-axis.
            xmax (:obj:`float`, optional): The maximum energy on the x-axis.
            ymin (:obj:`float`, optional): The minimum absorption intensity on
                the y-axis.
            ymax (:obj:`float`, optional): The maximum absorption intensity on
                the y-axis.
            colours (:obj:`list`, optional): A :obj:`list` of colours to use in
                the plot. The colours can be specified as a hex code, set of
                rgb values, or any other format supported by matplotlib.
                These are used first, followed by the default optics colours.
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.

        Returns:
            :obj:`matplotlib.pyplot`: The plot of optical spectra.
        """
        plt = pretty_plot(width=width,
                          height=height,
                          dpi=dpi,
                          plt=plt,
                          fonts=fonts)
        ax = plt.gca()

        # User colours take priority, followed by the default palette. list()
        # allows tuples (or other sequences) of colours to be passed in.
        if colours:
            colours = list(colours) + list(optics_colours)
        else:
            colours = optics_colours

        for (ener, alpha), abs_label, bg, c in zip(self._abs_data, self._label,
                                                   self._band_gap, colours):
            if len(alpha.shape) == 1:
                # if averaged optics only plot one line
                ax.plot(ener, alpha, lw=line_width, label=abs_label, c=c)

            else:
                data = zip(range(3), ['xx', 'yy', 'zz'], ['-', '--', '-.'])

                for direction_id, direction_label, ls in data:
                    if abs_label:
                        # e.g. "GaAs$_\mathregular{xx}$": subscript the
                        # direction onto the spectrum label.
                        label = r'{}$_\mathregular{{{}}}$'.format(
                            abs_label, direction_label)
                    else:
                        label = direction_label

                    ax.plot(ener,
                            alpha[:, direction_id],
                            lw=line_width,
                            ls=ls,
                            label=label,
                            c=c)

            if bg:
                # plot band gap line
                ax.plot([bg, bg], [ymin, ymax], lw=line_width, ls=':', c=c)

        xmax = xmax if xmax else self._xmax
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)

        # matplotlib expects booleans here; the legacy 'off' strings are no
        # longer reliably interpreted as False.
        ax.tick_params(axis='x', which='both', top=False)
        ax.tick_params(axis='x', which='both', right=False)
        ax.yaxis.set_major_formatter(FuncFormatter(power_tick))
        ax.yaxis.set_major_locator(MaxNLocator(4))

        ax.set_xlabel('Energy (eV)')
        ax.set_ylabel(r'Absorption (cm$^\mathregular{-1}$)')

        # Show a legend if any spectrum is labelled or if any directional
        # (xx/yy/zz) lines were drawn.
        if (not np.all(np.array(self._label) == '')
                or len(np.array(self._abs_data[0][1]).shape) > 1):
            ax.legend(loc='best',
                      frameon=False,
                      ncol=1,
                      prop={'size': label_size - 3})

        # Fix the aspect ratio so that the axes box matches height/width.
        x0, x1 = ax.get_xlim()
        y0, y1 = ax.get_ylim()
        ax.set_aspect((height / width) * ((x1 - x0) / (y1 - y0)))
        plt.tight_layout()

        return plt
Esempio n. 10
0
    def get_plot(self,
                 subplot=False,
                 width=6.,
                 height=8.,
                 xmin=-6.,
                 xmax=6.,
                 yscale=1,
                 colours=None,
                 plot_total=True,
                 legend_on=True,
                 num_columns=2,
                 legend_frame_on=False,
                 legend_cutoff=3,
                 xlabel='Energy (eV)',
                 ylabel='Arb. units',
                 dpi=400,
                 fonts=None,
                 plt=None):
        """Get a :obj:`matplotlib.pyplot` object of the density of states.

        Spin-down densities (when present) are plotted as negative values
        without a legend label.

        Args:
            subplot (:obj:`bool`, optional): Plot the density of states for
                each element on separate subplots. Defaults to ``False``.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            xmin (:obj:`float`, optional): The minimum energy on the x-axis.
            xmax (:obj:`float`, optional): The maximum energy on the x-axis.
            yscale (:obj:`float`, optional): Scaling factor for the y-axis.
            colours (:obj:`dict`, optional): Use custom colours for specific
                element and orbital combinations. Specified as a :obj:`dict` of
                :obj:`dict` of the colours. For example::

                    {
                        'Sn': {'s': 'r', 'p': 'b'},
                        'O': {'s': '#000000'}
                    }

                The colour can be a hex code, series of rgb value, or any other
                format supported by matplotlib.
            plot_total (:obj:`bool`, optional): Plot the total density of
                states. Defaults to ``True``.
            legend_on (:obj:`bool`, optional): Plot the graph legend. Defaults
                to ``True``.
            num_columns (:obj:`int`, optional): The number of columns in the
                legend.
            legend_frame_on (:obj:`bool`, optional): Plot a frame around the
                graph legend. Defaults to ``False``.
            legend_cutoff (:obj:`float`, optional): The cut-off (in % of the
                maximum density of states within the plotting range) for an
                elemental orbital to be labelled in the legend. This prevents
                the legend from containing labels for orbitals that have very
                little contribution in the plotting range.
            xlabel (:obj:`str`, optional): Label/units for x-axis (i.e. energy)
            ylabel (:obj:`str`, optional): Label/units for y-axis (i.e. DOS)
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.

        Returns:
            :obj:`matplotlib.pyplot`: The density of states plot.
        """
        plot_data = self.dos_plot_data(yscale=yscale,
                                       xmin=xmin,
                                       xmax=xmax,
                                       colours=colours,
                                       plot_total=plot_total,
                                       legend_cutoff=legend_cutoff,
                                       subplot=subplot)

        if subplot:
            nplots = len(plot_data['lines'])
            plt = pretty_subplot(nplots,
                                 1,
                                 width=width,
                                 height=height,
                                 dpi=dpi,
                                 plt=plt,
                                 fonts=fonts)
        else:
            plt = pretty_plot(width=width,
                              height=height,
                              dpi=dpi,
                              plt=plt,
                              fonts=fonts)

        mask = plot_data['mask']
        energies = plot_data['energies'][mask]
        fig = plt.gcf()
        lines = plot_data['lines']
        # A single density entry per line means non-spin-polarised data.
        spins = [Spin.up] if len(lines[0][0]['dens']) == 1 else \
            [Spin.up, Spin.down]

        for i, line_set in enumerate(lines):
            if subplot:
                ax = fig.axes[i]
            else:
                ax = plt.gca()

            for line, spin in itertools.product(line_set, spins):
                if spin == Spin.up:
                    label = line['label']
                    densities = line['dens'][spin][mask]
                elif spin == Spin.down:
                    # Spin down is mirrored below the axis and left unlabelled
                    # so each orbital appears only once in the legend.
                    label = ""
                    densities = -line['dens'][spin][mask]
                ax.fill_between(energies,
                                densities,
                                lw=0,
                                facecolor=line['colour'],
                                alpha=line['alpha'])
                ax.plot(energies,
                        densities,
                        label=label,
                        color=line['colour'],
                        lw=line_width)

            ax.set_ylim(plot_data['ymin'], plot_data['ymax'])
            ax.set_xlim(xmin, xmax)

            # matplotlib expects booleans here; the legacy 'off' strings are
            # no longer reliably interpreted as False.
            ax.tick_params(axis='x', which='both', top=False)
            ax.tick_params(axis='y',
                           which='both',
                           labelleft=False,
                           labelright=False,
                           left=False,
                           right=False)

            loc = 'upper right' if subplot else 'best'
            ncol = 1 if subplot else num_columns
            if legend_on:
                ax.legend(loc=loc,
                          frameon=legend_frame_on,
                          ncol=ncol,
                          handlelength=2,
                          prop={'size': label_size - 3})

        # now add axis labels and sort out ticks
        if subplot:
            ax.set_xlabel(xlabel, fontsize=label_size)
            fig.subplots_adjust(hspace=0)
            # Only the bottom subplot keeps its x tick labels.
            plt.setp([a.get_xticklabels() for a in fig.axes[:-1]],
                     visible=False)
            fig.text(0.08,
                     0.5,
                     ylabel,
                     fontsize=label_size,
                     ha='left',
                     va='center',
                     rotation='vertical',
                     transform=ax.transAxes)
        else:
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)

        return plt
Esempio n. 11
0
    def get_plot(self,
                 zero_to_efermi=True,
                 ymin=-6.,
                 ymax=6.,
                 width=6.,
                 height=6.,
                 vbm_cbm_marker=False,
                 ylabel='Energy (eV)',
                 dpi=400,
                 plt=None,
                 dos_plotter=None,
                 dos_options=None,
                 dos_label=None,
                 dos_aspect=3,
                 fonts=None):
        """Get a :obj:`matplotlib.pyplot` object of the band structure.

        If the system is spin polarised, blue lines are spin up, red lines are
        spin down. For metals, all bands are coloured blue. For semiconductors,
        blue lines indicate valence bands and orange lines indicate conduction
        bands.

        Args:
            zero_to_efermi (:obj:`bool`): Normalise the plot such that the
                valence band maximum is set as 0 eV.
            ymin (:obj:`float`, optional): The minimum energy on the y-axis.
            ymax (:obj:`float`, optional): The maximum energy on the y-axis.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            vbm_cbm_marker (:obj:`bool`, optional): Plot markers to indicate
                the VBM and CBM locations.
            ylabel (:obj:`str`, optional): y-axis (i.e. energy) label/units
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            dos_plotter (:obj:`~sumo.plotting.dos_plotter.SDOSPlotter`, \
                optional): Plot the density of states alongside the band
                structure. This should be a
                :obj:`~sumo.plotting.dos_plotter.SDOSPlotter` object
                initialised with the data to plot.
            dos_options (:obj:`dict`, optional): The options for density of
                states plotting. This should be formatted as a :obj:`dict`
                containing any of the following keys:

                    "yscale" (:obj:`float`)
                        Scaling factor for the y-axis.
                    "xmin" (:obj:`float`)
                        The minimum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "xmax" (:obj:`float`)
                        The maximum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "colours" (:obj:`dict`)
                        Use custom colours for specific element and orbital
                        combinations. Specified as a :obj:`dict` of
                        :obj:`dict` of the colours. For example::

                            {
                                'Sn': {'s': 'r', 'p': 'b'},
                                'O': {'s': '#000000'}
                            }

                        The colour can be a hex code, series of rgb value, or
                        any other format supported by matplotlib.
                    "plot_total" (:obj:`bool`)
                        Plot the total density of states. Defaults to ``True``.
                    "legend_cutoff" (:obj:`float`)
                        The cut-off (in % of the maximum density of states
                        within the plotting range) for an elemental orbital to
                        be labelled in the legend. This prevents the legend
                        from containing labels for orbitals that have very
                        little contribution in the plotting range.
                    "subplot" (:obj:`bool`)
                        Plot the density of states for each element on separate
                        subplots. Defaults to ``False``.

            dos_label (:obj:`str`, optional): DOS axis label/units
            dos_aspect (:obj:`float`, optional): Aspect ratio for the band
                structure and density of states subplot. For example,
                ``dos_aspect = 3``, results in a ratio of 3:1, for the band
                structure:dos plots.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.

        Returns:
            :obj:`matplotlib.pyplot`: The electronic band structure plot.
        """
        if dos_plotter:
            plt = pretty_subplot(1,
                                 2,
                                 width,
                                 height,
                                 sharex=False,
                                 dpi=dpi,
                                 plt=plt,
                                 fonts=fonts,
                                 gridspec_kw={
                                     'width_ratios': [dos_aspect, 1],
                                     'wspace': 0
                                 })
            ax = plt.gcf().axes[0]
        else:
            plt = pretty_plot(width, height, dpi=dpi, plt=plt, fonts=fonts)
            ax = plt.gca()

        data = self.bs_plot_data(zero_to_efermi)
        dists = data['distances']
        eners = data['energy']

        # is_metal() is slow, so evaluate the single-colour condition once
        # instead of once per (branch, band) pair.
        is_spin_polarized = self._bs.is_spin_polarized
        single_colour = is_spin_polarized or self._bs.is_metal()
        if single_colour:
            is_vb = None
        else:
            # Boolean array: True where a spin-up band energy is at or below
            # the VBM energy.
            is_vb = self._bs.bands[Spin.up] <= self._bs.get_vbm()['energy']

        # nd is branch index, nb is band index
        for nd, nb in it.product(range(len(dists)), range(self._nb_bands)):
            e = eners[nd][str(Spin.up)][nb]

            # colour valence bands blue and conduction bands orange
            if single_colour or np.all(is_vb[nb]):
                c = '#3953A4'
            else:
                c = '#FAA316'

            # plot band data
            ax.plot(dists[nd], e, ls='-', c=c, linewidth=band_linewidth)
            if is_spin_polarized:
                e = eners[nd][str(Spin.down)][nb]
                ax.plot(dists[nd], e, 'r--', linewidth=band_linewidth)

        self._maketicks(ax, ylabel=ylabel)
        self._makeplot(ax,
                       plt.gcf(),
                       data,
                       zero_to_efermi=zero_to_efermi,
                       vbm_cbm_marker=vbm_cbm_marker,
                       width=width,
                       height=height,
                       ymin=ymin,
                       ymax=ymax,
                       dos_plotter=dos_plotter,
                       dos_options=dos_options,
                       dos_label=dos_label)
        return plt
Esempio n. 12
0
    def get_projected_plot(self,
                           selection,
                           mode='rgb',
                           interpolate_factor=4,
                           circle_size=150,
                           projection_cutoff=0.001,
                           zero_to_efermi=True,
                           ymin=-6.,
                           ymax=6.,
                           width=6.,
                           height=6.,
                           vbm_cbm_marker=False,
                           ylabel='Energy (eV)',
                           dpi=400,
                           plt=None,
                           dos_plotter=None,
                           dos_options=None,
                           dos_label=None,
                           dos_aspect=3,
                           fonts=None):
        """Get a :obj:`matplotlib.pyplot` of the projected band structure.

        If the system is spin polarised and ``mode = 'rgb'`` spin up and spin
        down bands are differentiated by solid and dashed lines, respectively.
        For the other modes, spin up and spin down are plotted separately.

        Args:
            selection (list): A list of :obj:`tuple` or :obj:`string`
                identifying which elements and orbitals to project on to the
                band structure. These can be specified by both element and
                orbital, for example, the following will project the Bi s, p
                and S p orbitals::

                    [('Bi', 's'), ('Bi', 'p'), ('S', 'p')]

                If just the element is specified then all the orbitals of
                that element are combined. For example, to sum all the S
                orbitals::

                    [('Bi', 's'), ('Bi', 'p'), 'S']

                You can also choose to sum particular orbitals by supplying a
                :obj:`tuple` of orbitals. For example, to sum the S s, p, and
                d orbitals into a single projection::

                  [('Bi', 's'), ('Bi', 'p'), ('S', ('s', 'p', 'd'))]

                If ``mode = 'rgb'``, a maximum of 3 orbital/element
                combinations can be plotted simultaneously (one for red, green
                and blue), otherwise an unlimited number of elements/orbitals
                can be selected.
            mode (:obj:`str`, optional): Type of projected band structure to
                plot. Options are:

                    "rgb"
                        The band structure line color depends on the character
                        of the band. Each element/orbital contributes either
                        red, green or blue with the corresponding line colour a
                        mixture of all three colours. This mode only supports
                        up to 3 elements/orbitals combinations. The order of
                        the ``selection`` :obj:`tuple` determines which colour
                        is used for each selection.
                    "stacked"
                        The element/orbital contributions are drawn as a
                        series of stacked circles, with the colour depending on
                        the composition of the band. The size of the circles
                        can be scaled using the ``circle_size`` option.

            interpolate_factor (:obj:`int`, optional): The factor by which to
                interpolate the band structure (necessary to make smooth
                lines). A larger number indicates greater interpolation.
            circle_size (:obj:`float`, optional): The area of the circles used
                when ``mode = 'stacked'``.
            projection_cutoff (:obj:`float`): Don't plot projections with
                intensities below this number. This option is useful for
                stacked plots, where small projections clutter the plot.
            zero_to_efermi (:obj:`bool`): Normalise the plot such that the
                valence band maximum is set as 0 eV.
            ymin (:obj:`float`, optional): The minimum energy on the y-axis.
            ymax (:obj:`float`, optional): The maximum energy on the y-axis.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            vbm_cbm_marker (:obj:`bool`, optional): Plot markers to indicate
                the VBM and CBM locations.
            ylabel (:obj:`str`, optional): y-axis (i.e. energy) label/units
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            dos_plotter (:obj:`~sumo.plotting.dos_plotter.SDOSPlotter`, \
                optional): Plot the density of states alongside the band
                structure. This should be a
                :obj:`~sumo.plotting.dos_plotter.SDOSPlotter` object
                initialised with the data to plot.
            dos_options (:obj:`dict`, optional): The options for density of
                states plotting. This should be formatted as a :obj:`dict`
                containing any of the following keys:

                    "yscale" (:obj:`float`)
                        Scaling factor for the y-axis.
                    "xmin" (:obj:`float`)
                        The minimum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "xmax" (:obj:`float`)
                        The maximum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "colours" (:obj:`dict`)
                        Use custom colours for specific element and orbital
                        combinations. Specified as a :obj:`dict` of
                        :obj:`dict` of the colours. For example::

                           {
                                'Sn': {'s': 'r', 'p': 'b'},
                                'O': {'s': '#000000'}
                            }

                        The colour can be a hex code, series of rgb value, or
                        any other format supported by matplotlib.
                    "plot_total" (:obj:`bool`)
                        Plot the total density of states. Defaults to ``True``.
                    "legend_cutoff" (:obj:`float`)
                        The cut-off (in % of the maximum density of states
                        within the plotting range) for an elemental orbital to
                        be labelled in the legend. This prevents the legend
                        from containing labels for orbitals that have very
                        little contribution in the plotting range.
                    "subplot" (:obj:`bool`)
                        Plot the density of states for each element on separate
                        subplots. Defaults to ``False``.

            dos_label (:obj:`str`, optional): DOS axis label/units
            dos_aspect (:obj:`float`, optional): Width ratio of the band
                structure subplot to the density of states subplot.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.

        Returns:
            :obj:`matplotlib.pyplot`: The projected electronic band structure
            plot.

        Raises:
            ValueError: If ``selection`` is empty, if more than three
                selections are given in ``'rgb'`` mode, if ``mode`` is
                ``'solo'`` together with a ``dos_plotter``, or if ``mode`` is
                not one of the supported modes (``'rgb'``, ``'stacked'``).
        """
        # Validate up front. Previously, an unsupported mode left `colours`
        # unbound and failed with NameError when the legend was drawn.
        if len(selection) == 0:
            raise ValueError('No elements/orbitals specified')
        if mode == 'rgb' and len(selection) > 3:
            raise ValueError('Too many elements/orbitals specified (max 3)')
        elif mode == 'solo' and dos_plotter:
            raise ValueError('Solo mode plotting with DOS not supported')
        elif mode not in ('rgb', 'stacked'):
            raise ValueError('Unsupported projection mode: {}'.format(mode))

        if dos_plotter:
            plt = pretty_subplot(1,
                                 2,
                                 width,
                                 height,
                                 sharex=False,
                                 dpi=dpi,
                                 plt=plt,
                                 fonts=fonts,
                                 gridspec_kw={
                                     'width_ratios': [dos_aspect, 1],
                                     'wspace': 0
                                 })
            ax = plt.gcf().axes[0]
        else:
            plt = pretty_plot(width, height, dpi=dpi, plt=plt, fonts=fonts)
            ax = plt.gca()

        # The legend colours depend only on the mode and the number of
        # selections, so they are computed once rather than per branch.
        if mode == 'rgb':
            if len(selection) == 2:
                # with only two orbitals the green channel is left empty,
                # so red and blue are used
                colours = ['#ff0000', '#0000ff']
            else:
                colours = ['#ff0000', '#00ff00', '#0000ff']
        else:
            # use some nice custom colours first, then default colours
            colours = [
                '#3952A3', '#FAA41A', '#67BC47', '#6ECCDD', '#ED2025'
            ]
            colours.extend(np.array(default_colours) / 255)

        data = self.bs_plot_data(zero_to_efermi)
        nbranches = len(data['distances'])

        # Ensure we do spin up first, then spin down
        spins = sorted(self._bs.bands.keys(), key=lambda spin: -spin.value)

        proj = get_projections_by_branches(self._bs,
                                           selection,
                                           normalise='select')

        # nd is branch index
        for spin, nd in it.product(spins, range(nbranches)):

            # mask data to reduce plotting load: keep only bands that have
            # at least one point above ymin and at least one below ymax
            bands = np.array(data['energy'][nd][str(spin)])
            mask = np.where(
                np.any(bands > ymin - 0.05, axis=1)
                & np.any(bands < ymax + 0.05, axis=1))
            distances = data['distances'][nd]
            bands = bands[mask]
            weights = [proj[nd][i][spin][mask] for i in range(len(selection))]

            # interpolate band structure to improve smoothness
            dx = (distances[1] - distances[0]) / interpolate_factor
            temp_dists = np.arange(distances[0], distances[-1], dx)
            bands = interp1d(distances, bands, axis=1)(temp_dists)
            weights = interp1d(distances, weights, axis=2)(temp_dists)
            distances = temp_dists

            # sometimes VASP produces very small negative weights
            weights[weights < 0] = 0

            if mode == 'rgb':
                # rgbline needs exactly three channels; pad missing ones
                # with zero weight
                if len(weights) == 2:
                    # two selections: red and blue, empty green channel
                    weights = np.insert(weights,
                                        1,
                                        np.zeros(weights[0].shape),
                                        axis=0)
                elif len(weights) == 1:
                    # single selection: red only (previously IndexError)
                    weights = np.concatenate(
                        (weights, np.zeros((2, ) + weights[0].shape)),
                        axis=0)

                ls = '-' if spin == Spin.up else '--'
                lc = rgbline(distances,
                             bands,
                             weights[0],
                             weights[1],
                             weights[2],
                             alpha=1,
                             linestyles=ls,
                             linewidth=2.5)
                ax.add_collection(lc)

            elif mode == 'stacked':
                # TODO: Handle spin

                # very small circles look crap
                weights[weights < projection_cutoff] = 0

                distances = list(distances) * len(bands)
                bands = bands.flatten()
                # earlier selections get lower zorder, so they are drawn
                # underneath the later ones
                zorders = range(-len(weights), 0)
                for w, c, z in zip(weights, colours, zorders):
                    ax.scatter(distances,
                               bands,
                               c=c,
                               s=circle_size * w**2,
                               zorder=z,
                               rasterized=True)

        # plot the legend: off-screen dummy points carry one label per
        # selection in the matching colour
        for c, spec in zip(colours, selection):
            if isinstance(spec, str):
                label = spec
            else:
                orbitals = spec[1]
                # a single orbital string (e.g. 'px') must not be split
                # into characters by join
                if not isinstance(orbitals, str):
                    orbitals = " + ".join(orbitals)
                label = '{} ({})'.format(spec[0], orbitals)
            ax.scatter([-10000], [-10000],
                       c=c,
                       s=50,
                       label=label,
                       edgecolors='none')

        if dos_plotter:
            loc = 1
            anchor_point = (-0.2, 1)
        else:
            loc = 2
            anchor_point = (0.95, 1)

        # NOTE(review): `label_size` is not defined in this method;
        # presumably a module-level name outside this view — confirm.
        ax.legend(bbox_to_anchor=anchor_point,
                  loc=loc,
                  frameon=False,
                  prop={'size': label_size - 2},
                  handletextpad=0.1,
                  scatterpoints=1)

        # finish and tidy plot
        self._maketicks(ax, ylabel=ylabel)
        self._makeplot(ax,
                       plt.gcf(),
                       data,
                       zero_to_efermi=zero_to_efermi,
                       vbm_cbm_marker=vbm_cbm_marker,
                       width=width,
                       height=height,
                       ymin=ymin,
                       ymax=ymax,
                       dos_plotter=dos_plotter,
                       dos_options=dos_options,
                       dos_label=dos_label)
        return plt
Esempio n. 13
0
    def get_plot(self, width=None, height=None, xmin=0., xmax=None, ymin=0,
                 ymax=1e5, colours=None, dpi=400, plt=None, fonts=None,
                 style=None, no_base_style=False):
        """Get a :obj:`matplotlib.pyplot` object of the optical spectra.

        Each spectrum in ``self._abs_data`` is drawn in its own colour. A 1D
        absorption array gives a single line. A 2D array gives one line per
        column (labelled xx, yy and zz, drawn with different line styles).
        When a spectrum has a non-zero band gap, it is marked with a dotted
        vertical line.

        Args:
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            xmin (:obj:`float`, optional): The minimum energy on the x-axis.
            xmax (:obj:`float`, optional): The maximum energy on the x-axis.
            ymin (:obj:`float`, optional): The minimum absorption intensity on
                the y-axis.
            ymax (:obj:`float`, optional): The maximum absorption intensity on
                the y-axis.
            colours (:obj:`list`, optional): A :obj:`list` of colours to use in
                the plot. The colours can be specified as a hex code, set of
                rgb values, or any other format supported by matplotlib.
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.
            style (:obj:`list`, :obj:`str`, or :obj:`dict`): Any matplotlib
                style specifications, to be composed on top of Sumo base
                style.
            no_base_style (:obj:`bool`, optional): Prevent use of sumo base
                style. This can make alternative styles behave more
                predictably.

        Returns:
            :obj:`matplotlib.pyplot`: The plot of optical spectra.
        """
        # NOTE(review): fonts, style and no_base_style are not read in this
        # body; presumably consumed by a decorator outside this view — confirm.
        plt = pretty_plot(width=width, height=height, dpi=dpi, plt=plt)
        ax = plt.gca()

        # user-supplied colours take priority; the active matplotlib colour
        # cycle supplies colours for any remaining spectra
        optics_colours = rcParams['axes.prop_cycle'].by_key()['color']
        if colours is not None:
            # list() so tuples (or other sequences) of colours are accepted
            optics_colours = list(colours) + list(optics_colours)

        for (ener, alpha), abs_label, bg, c in zip(self._abs_data,
                                                   self._label,
                                                   self._band_gap,
                                                   optics_colours):
            if len(alpha.shape) == 1:
                # if averaged optics only plot one line
                ax.plot(ener, alpha, label=abs_label, c=c)

            else:
                directions = zip(range(3), ['xx', 'yy', 'zz'],
                                 ['-', '--', '-.'])

                for direction_id, direction_label, ls in directions:
                    if not abs_label:
                        label = direction_label
                    else:
                        # spectrum label with the direction as a subscript,
                        # e.g. "<abs_label>$_\mathregular{xx}$". Previously
                        # the formatted string was discarded and the raw
                        # template was used as the legend label.
                        label = r'{}$_\mathregular{{{}}}$'.format(
                            abs_label, direction_label)

                    ax.plot(ener, alpha[:, direction_id], ls=ls,
                            label=label, c=c)

            if bg:
                # plot band gap line
                ax.plot([bg, bg], [ymin, ymax], ls=':', c=c)

        xmax = xmax if xmax else self._xmax
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)

        ax.yaxis.set_major_formatter(FuncFormatter(power_tick))
        ax.yaxis.set_major_locator(MaxNLocator(5))
        ax.xaxis.set_major_locator(MaxNLocator(3))
        ax.yaxis.set_minor_locator(AutoMinorLocator(2))
        ax.xaxis.set_minor_locator(AutoMinorLocator(2))

        ax.set_xlabel('Energy (eV)')
        ax.set_ylabel(r'Absorption (cm$^\mathregular{-1}$)')

        # only show a legend when there is something to label: a non-empty
        # spectrum label, or direction-resolved (2D) data in the first
        # spectrum
        if (not np.all(np.array(self._label) == '')
                or len(np.array(self._abs_data[0][1]).shape) > 1):
            ax.legend(loc='best', frameon=False, ncol=1)

        # choose the data aspect so the axes box has the requested
        # height/width ratio, falling back on the rcParams figure size
        x0, x1 = ax.get_xlim()
        y0, y1 = ax.get_ylim()
        if width is None:
            width = rcParams['figure.figsize'][0]
        if height is None:
            height = rcParams['figure.figsize'][1]
        ax.set_aspect((height/width) * ((x1-x0)/(y1-y0)))
        plt.tight_layout()

        return plt