Exemplo n.º 1
0
    def get_projected_plot(
        self,
        selection,
        mode="rgb",
        normalise="all",
        interpolate_factor=4,
        circle_size=150,
        projection_cutoff=0.001,
        zero_to_efermi=True,
        ymin=-6.0,
        ymax=6.0,
        width=None,
        height=None,
        vbm_cbm_marker=False,
        ylabel="Energy (eV)",
        dpi=400,
        plt=None,
        dos_plotter=None,
        dos_options=None,
        dos_label=None,
        plot_dos_legend=True,
        dos_aspect=3,
        aspect=None,
        fonts=None,
        style=None,
        no_base_style=False,
        spin=None,
    ):
        """Get a :obj:`matplotlib.pyplot` of the projected band structure.

        If the system is spin polarised, no spin has been specified and
        ``mode = 'rgb'`` spin up and spin down bands are differentiated by
        solid and dashed lines, respectively.
        For the other modes, spin up and spin down are plotted separately.

        Args:
            selection (list): A list of :obj:`tuple` or :obj:`string`
                identifying which elements and orbitals to project on to the
                band structure. These can be specified by both element and
                orbital, for example, the following will project the Bi s, p
                and S p orbitals::

                    [('Bi', 's'), ('Bi', 'p'), ('S', 'p')]

                If just the element is specified then all the orbitals of
                that element are combined. For example, to sum all the S
                orbitals::

                    [('Bi', 's'), ('Bi', 'p'), 'S']

                You can also choose to sum particular orbitals by supplying a
                :obj:`tuple` of orbitals. For example, to sum the S s, p, and
                d orbitals into a single projection::

                  [('Bi', 's'), ('Bi', 'p'), ('S', ('s', 'p', 'd'))]

                If ``mode = 'rgb'``, a maximum of 3 orbital/element
                combinations can be plotted simultaneously (one for red, green
                and blue), otherwise an unlimited number of elements/orbitals
                can be selected. At least one selection is required.
            mode (:obj:`str`, optional): Type of projected band structure to
                plot. Options are:

                    "rgb"
                        The band structure line color depends on the character
                        of the band. Each element/orbital contributes either
                        red, green or blue with the corresponding line colour a
                        mixture of all three colours. This mode only supports
                        up to 3 elements/orbitals combinations. The order of
                        the ``selection`` :obj:`tuple` determines which colour
                        is used for each selection.
                    "stacked"
                        The element/orbital contributions are drawn as a
                        series of stacked circles, with the colour depending on
                        the composition of the band. The size of the circles
                        can be scaled using the ``circle_size`` option.

                Any other value raises :obj:`ValueError`.
            normalise (:obj:`str`, optional): How to normalise the
                projections. Options are:

                  * ``'all'``: Projections normalised against the sum of all
                       other projections.
                  * ``'select'``: Projections normalised against the sum of the
                       selected projections.
                  * ``None``: No normalisation performed.

            interpolate_factor (:obj:`int`, optional): The factor by which to
                interpolate the band structure (necessary to make smooth
                lines). A larger number indicates greater interpolation.
            circle_size (:obj:`float`, optional): The area of the circles used
                when ``mode = 'stacked'``.
            projection_cutoff (:obj:`float`): Don't plot projections with
                intensities below this number. This option is useful for
                stacked plots, where small projections clutter the plot.
            zero_to_efermi (:obj:`bool`): Normalise the plot such that the
                valence band maximum is set as 0 eV.
            ymin (:obj:`float`, optional): The minimum energy on the y-axis.
            ymax (:obj:`float`, optional): The maximum energy on the y-axis.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            vbm_cbm_marker (:obj:`bool`, optional): Plot markers to indicate
                the VBM and CBM locations.
            ylabel (:obj:`str`, optional): y-axis (i.e. energy) label/units
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            dos_plotter (:obj:`~sumo.plotting.dos_plotter.SDOSPlotter`, \
                optional): Plot the density of states alongside the band
                structure. This should be a
                :obj:`~sumo.plotting.dos_plotter.SDOSPlotter` object
                initialised with the data to plot.
            dos_options (:obj:`dict`, optional): The options for density of
                states plotting. This should be formatted as a :obj:`dict`
                containing any of the following keys:

                    "yscale" (:obj:`float`)
                        Scaling factor for the y-axis.
                    "xmin" (:obj:`float`)
                        The minimum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "xmax" (:obj:`float`)
                        The maximum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "colours" (:obj:`dict`)
                        Use custom colours for specific element and orbital
                        combinations. Specified as a :obj:`dict` of
                        :obj:`dict` of the colours. For example::

                           {
                                'Sn': {'s': 'r', 'p': 'b'},
                                'O': {'s': '#000000'}
                            }

                        The colour can be a hex code, series of rgb value, or
                        any other format supported by matplotlib.
                    "plot_total" (:obj:`bool`)
                        Plot the total density of states. Defaults to ``True``.
                    "legend_cutoff" (:obj:`float`)
                        The cut-off (in % of the maximum density of states
                        within the plotting range) for an elemental orbital to
                        be labelled in the legend. This prevents the legend
                        from containing labels for orbitals that have very
                        little contribution in the plotting range.
                    "subplot" (:obj:`bool`)
                        Plot the density of states for each element on separate
                        subplots. Defaults to ``False``.

            dos_label (:obj:`str`, optional): DOS axis label/units
            plot_dos_legend (:obj:`bool`): Whether to plot the dos legend.
            dos_aspect (:obj:`float`, optional): Aspect ratio for the band
                structure and density of states subplot. For example,
                ``dos_aspect = 3``, results in a ratio of 3:1, for the band
                structure:dos plots.
            aspect (:obj:`float`, optional): The aspect ratio of the band
                structure plot. By default the dimensions of the figure size
                are used to determine the aspect ratio. Set to ``1`` to force
                the plot to be square.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.
            style (:obj:`list`, :obj:`str`, or :obj:`dict`): Any matplotlib
                style specifications, to be composed on top of Sumo base
                style.
            no_base_style (:obj:`bool`, optional): Prevent use of sumo base
                style. This can make alternative styles behave more
                predictably.
            spin (:obj:`Spin`, :obj:`str` or :obj:`int`, optional): Plot only
                one spin channel of a spin-polarised band structure:
                ``Spin.up``, ``"up"`` or ``"1"`` (or ``1``) for spin up only;
                ``Spin.down``, ``"down"`` or ``"-1"`` (or ``-1``) for spin
                down only. Defaults to ``None`` (all spin channels).

        Returns:
            :obj:`matplotlib.pyplot`: The projected electronic band structure
            plot.

        Raises:
            ValueError: If ``selection`` is empty, more than three projections
                are requested in ``'rgb'`` mode, a DOS is requested in
                ``'solo'`` mode, ``mode`` is not supported, or ``spin`` is
                unrecognised or given for a non-spin-polarised calculation.
        """
        # NOTE(review): fonts, style and no_base_style are not read in this
        # body; presumably they are consumed by a decorator outside this
        # view -- confirm before removing them.
        if not selection:
            raise ValueError("No elements/orbitals specified")
        if mode == "rgb" and len(selection) > 3:
            raise ValueError("Too many elements/orbitals specified (max 3)")
        elif mode == "solo" and dos_plotter:
            raise ValueError("Solo mode plotting with DOS not supported")
        elif mode not in ("rgb", "stacked"):
            # Only "rgb" and "stacked" are drawn below; any other mode would
            # otherwise fail late with no colours defined for the legend.
            raise ValueError("Unsupported projection mode: {!r}".format(mode))

        if spin is not None and not isinstance(spin, Spin):
            # Accept the string/integer aliases listed in the docstring.
            spin_aliases = {
                "up": Spin.up,
                "1": Spin.up,
                "down": Spin.down,
                "-1": Spin.down,
            }
            try:
                spin = spin_aliases[str(spin).strip().lower()]
            except KeyError:
                raise ValueError("Unrecognised spin: {!r}".format(spin)) from None

        # Ensure we do spin up first, then spin down
        spins = sorted(self.bs.bands.keys(), key=lambda s: -s.value)
        if spin is not None and len(spins) == 1:
            raise ValueError(
                "Spin-selection only possible with spin-polarised "
                "calculation results")
        if spin is Spin.up:
            spins = [spins[0]]
        elif spin is Spin.down:
            spins = [spins[1]]

        if dos_plotter:
            plt = pretty_subplot(
                1,
                2,
                width,
                height,
                sharex=False,
                dpi=dpi,
                plt=plt,
                gridspec_kw={
                    "width_ratios": [dos_aspect, 1],
                    "wspace": 0
                },
            )
            ax = plt.gcf().axes[0]
        else:
            plt = pretty_plot(width, height, dpi=dpi, plt=plt)
            ax = plt.gca()

        data = self.bs_plot_data(zero_to_efermi)
        nbranches = len(data["distances"])

        proj = get_projections_by_branches(self.bs,
                                           selection,
                                           normalise=normalise)

        # Colours depend only on the mode; they are also used for the legend.
        if mode == "rgb":
            if len(selection) == 2:
                # two projections are drawn in red and blue (green left empty)
                colours = ["#ff0000", "#0000ff"]
            else:
                colours = ["#ff0000", "#00ff00", "#0000ff"]
        else:
            # use some nice custom colours first, then default colours
            colours = ["#3952A3", "#FAA41A", "#67BC47", "#6ECCDD", "#ED2025"]
            colours.extend(rcParams["axes.prop_cycle"].by_key()["color"])

        # nd is branch index
        for spin_channel, nd in it.product(spins, range(nbranches)):

            # mask data to reduce plotting load
            bands = np.array(data["energy"][str(spin_channel)][nd])
            mask = np.where(
                np.any(bands > ymin - 0.05, axis=1)
                & np.any(bands < ymax + 0.05, axis=1))
            distances = data["distances"][nd]
            bands = bands[mask]
            if len(bands) == 0:
                # no band of this branch enters the energy window
                continue
            weights = [
                proj[nd][i][spin_channel][mask] for i in range(len(selection))
            ]

            if len(distances) > 2:  # Only interpolate if it makes sense to do so
                # interpolate band structure to improve smoothness
                temp_dists = np.linspace(distances[0], distances[-1],
                                         len(distances) * interpolate_factor)
                bands = interp1d(
                    distances,
                    bands,
                    axis=1,
                    bounds_error=False,
                    fill_value="extrapolate",
                )(temp_dists)
                weights = interp1d(
                    distances,
                    weights,
                    axis=2,
                    bounds_error=False,
                    fill_value="extrapolate",
                )(temp_dists)
                distances = temp_dists

            else:  # change from list to array if we skipped the scipy interpolation
                weights = np.array(weights)
                bands = np.array(bands)
                distances = np.array(distances)

            # sometimes VASP produces very small negative weights
            weights[weights < 0] = 0

            if mode == "rgb":
                # rgbline needs exactly three channels; pad missing ones with 0
                if len(weights) == 2:
                    # two projections: red and blue, with an empty green
                    weights = np.insert(weights,
                                        1,
                                        np.zeros(weights[0].shape),
                                        axis=0)
                elif len(weights) == 1:
                    # one projection: red only
                    weights = np.concatenate(
                        [weights, np.zeros((2,) + weights.shape[1:])], axis=0)

                ls = "-" if spin_channel == Spin.up else "--"
                lc = rgbline(
                    distances,
                    bands,
                    weights[0],
                    weights[1],
                    weights[2],
                    alpha=1,
                    linestyles=ls,
                    linewidth=(rcParams["lines.linewidth"] * 1.25),
                )
                ax.add_collection(lc)

            elif mode == "stacked":
                # TODO: Handle spin

                # very small circles only clutter the plot
                weights[weights < projection_cutoff] = 0

                # flatten so that each (distance, energy) pair is one point
                distances = list(distances) * len(bands)
                bands = bands.flatten()
                zorders = range(-len(weights), 0)
                for w, c, z in zip(weights, colours, zorders):
                    ax.scatter(
                        distances,
                        bands,
                        c=c,
                        s=circle_size * w**2,
                        zorder=z,
                        rasterized=True,
                    )

        # plot the legend using off-screen dummy points
        for c, spec in zip(colours, selection):
            if isinstance(spec, str):
                label = spec
            else:
                # a single orbital string (e.g. "px") must not be split
                # into characters by join
                orbitals = (spec[1] if isinstance(spec[1], str) else
                            " + ".join(spec[1]))
                label = "{} ({})".format(spec[0], orbitals)
            ax.scatter([-10000], [-10000],
                       c=c,
                       s=50,
                       label=label,
                       edgecolors="none")

        if dos_plotter:
            loc = 1
            anchor_point = (-0.2, 1)
        else:
            loc = 2
            anchor_point = (0.95, 1)

        ax.legend(
            bbox_to_anchor=anchor_point,
            loc=loc,
            frameon=False,
            handletextpad=0.1,
            scatterpoints=1,
        )

        # finish and tidy plot
        self._maketicks(ax, ylabel=ylabel)
        self._makeplot(
            ax,
            plt.gcf(),
            data,
            zero_to_efermi=zero_to_efermi,
            vbm_cbm_marker=vbm_cbm_marker,
            width=width,
            height=height,
            ymin=ymin,
            ymax=ymax,
            dos_plotter=dos_plotter,
            dos_options=dos_options,
            dos_label=dos_label,
            plot_dos_legend=plot_dos_legend,
            aspect=aspect,
        )
        return plt
Exemplo n.º 2
0
    def get_projected_plot(self,
                           selection,
                           mode='rgb',
                           interpolate_factor=4,
                           circle_size=150,
                           projection_cutoff=0.001,
                           zero_to_efermi=True,
                           ymin=-6.,
                           ymax=6.,
                           width=6.,
                           height=6.,
                           vbm_cbm_marker=False,
                           ylabel='Energy (eV)',
                           dpi=400,
                           plt=None,
                           dos_plotter=None,
                           dos_options=None,
                           dos_label=None,
                           dos_aspect=3,
                           fonts=None):
        """Get a :obj:`matplotlib.pyplot` of the projected band structure.

        If the system is spin polarised and ``mode = 'rgb'`` spin up and spin
        down bands are differentiated by solid and dashed lines, respectively.
        For the other modes, spin up and spin down are plotted separately.

        Args:
            selection (list): A list of :obj:`tuple` or :obj:`string`
                identifying which elements and orbitals to project on to the
                band structure. These can be specified by both element and
                orbital, for example, the following will project the Bi s, p
                and S p orbitals::

                    [('Bi', 's'), ('Bi', 'p'), ('S', 'p')]

                If just the element is specified then all the orbitals of
                that element are combined. For example, to sum all the S
                orbitals::

                    [('Bi', 's'), ('Bi', 'p'), 'S']

                You can also choose to sum particular orbitals by supplying a
                :obj:`tuple` of orbitals. For example, to sum the S s, p, and
                d orbitals into a single projection::

                  [('Bi', 's'), ('Bi', 'p'), ('S', ('s', 'p', 'd'))]

                If ``mode = 'rgb'``, a maximum of 3 orbital/element
                combinations can be plotted simultaneously (one for red, green
                and blue), otherwise an unlimited number of elements/orbitals
                can be selected. At least one selection is required.
                Projections are always normalised against the sum of the
                selected projections.
            mode (:obj:`str`, optional): Type of projected band structure to
                plot. Options are:

                    "rgb"
                        The band structure line color depends on the character
                        of the band. Each element/orbital contributes either
                        red, green or blue with the corresponding line colour a
                        mixture of all three colours. This mode only supports
                        up to 3 elements/orbitals combinations. The order of
                        the ``selection`` :obj:`tuple` determines which colour
                        is used for each selection.
                    "stacked"
                        The element/orbital contributions are drawn as a
                        series of stacked circles, with the colour depending on
                        the composition of the band. The size of the circles
                        can be scaled using the ``circle_size`` option.

                Any other value raises :obj:`ValueError`.
            interpolate_factor (:obj:`int`, optional): The factor by which to
                interpolate the band structure (necessary to make smooth
                lines). A larger number indicates greater interpolation.
            circle_size (:obj:`float`, optional): The area of the circles used
                when ``mode = 'stacked'``.
            projection_cutoff (:obj:`float`): Don't plot projections with
                intensities below this number. This option is useful for
                stacked plots, where small projections clutter the plot.
            zero_to_efermi (:obj:`bool`): Normalise the plot such that the
                valence band maximum is set as 0 eV.
            ymin (:obj:`float`, optional): The minimum energy on the y-axis.
            ymax (:obj:`float`, optional): The maximum energy on the y-axis.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            vbm_cbm_marker (:obj:`bool`, optional): Plot markers to indicate
                the VBM and CBM locations.
            ylabel (:obj:`str`, optional): y-axis (i.e. energy) label/units
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            dos_plotter (:obj:`~sumo.plotting.dos_plotter.SDOSPlotter`, \
                optional): Plot the density of states alongside the band
                structure. This should be a
                :obj:`~sumo.plotting.dos_plotter.SDOSPlotter` object
                initialised with the data to plot.
            dos_options (:obj:`dict`, optional): The options for density of
                states plotting. This should be formatted as a :obj:`dict`
                containing any of the following keys:

                    "yscale" (:obj:`float`)
                        Scaling factor for the y-axis.
                    "xmin" (:obj:`float`)
                        The minimum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "xmax" (:obj:`float`)
                        The maximum energy to mask the energy and density of
                        states data (reduces plotting load).
                    "colours" (:obj:`dict`)
                        Use custom colours for specific element and orbital
                        combinations. Specified as a :obj:`dict` of
                        :obj:`dict` of the colours. For example::

                           {
                                'Sn': {'s': 'r', 'p': 'b'},
                                'O': {'s': '#000000'}
                            }

                        The colour can be a hex code, series of rgb value, or
                        any other format supported by matplotlib.
                    "plot_total" (:obj:`bool`)
                        Plot the total density of states. Defaults to ``True``.
                    "legend_cutoff" (:obj:`float`)
                        The cut-off (in % of the maximum density of states
                        within the plotting range) for an elemental orbital to
                        be labelled in the legend. This prevents the legend
                        from containing labels for orbitals that have very
                        little contribution in the plotting range.
                    "subplot" (:obj:`bool`)
                        Plot the density of states for each element on separate
                        subplots. Defaults to ``False``.

            dos_label (:obj:`str`, optional): DOS axis label/units
            dos_aspect (:obj:`float`, optional): Aspect ratio for the band
                structure and density of states subplot. For example,
                ``dos_aspect = 3``, results in a ratio of 3:1, for the band
                structure:dos plots.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
                single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.

        Returns:
            :obj:`matplotlib.pyplot`: The projected electronic band structure
            plot.

        Raises:
            ValueError: If ``selection`` is empty, more than three projections
                are requested in ``'rgb'`` mode, a DOS is requested in
                ``'solo'`` mode, or ``mode`` is not supported.
        """
        if not selection:
            raise ValueError('No elements/orbitals specified')
        if mode == 'rgb' and len(selection) > 3:
            raise ValueError('Too many elements/orbitals specified (max 3)')
        elif mode == 'solo' and dos_plotter:
            raise ValueError('Solo mode plotting with DOS not supported')
        elif mode not in ('rgb', 'stacked'):
            # Only "rgb" and "stacked" are drawn below; any other mode would
            # otherwise fail late with no colours defined for the legend.
            raise ValueError('Unsupported projection mode: {!r}'.format(mode))

        if dos_plotter:
            plt = pretty_subplot(1,
                                 2,
                                 width,
                                 height,
                                 sharex=False,
                                 dpi=dpi,
                                 plt=plt,
                                 fonts=fonts,
                                 gridspec_kw={
                                     'width_ratios': [dos_aspect, 1],
                                     'wspace': 0
                                 })
            ax = plt.gcf().axes[0]
        else:
            plt = pretty_plot(width, height, dpi=dpi, plt=plt, fonts=fonts)
            ax = plt.gca()

        data = self.bs_plot_data(zero_to_efermi)
        nbranches = len(data['distances'])

        # Ensure we do spin up first, then spin down
        spins = sorted(self._bs.bands.keys(), key=lambda spin: -spin.value)

        proj = get_projections_by_branches(self._bs,
                                           selection,
                                           normalise='select')

        # Colours depend only on the mode; they are also used for the legend.
        if mode == 'rgb':
            if len(selection) == 2:
                # two projections are drawn in red and blue (green left empty)
                colours = ['#ff0000', '#0000ff']
            else:
                colours = ['#ff0000', '#00ff00', '#0000ff']
        else:
            # use some nice custom colours first, then default colours
            colours = ['#3952A3', '#FAA41A', '#67BC47', '#6ECCDD', '#ED2025']
            colours.extend(np.array(default_colours) / 255)

        # nd is branch index
        for spin, nd in it.product(spins, range(nbranches)):

            # mask data to reduce plotting load
            bands = np.array(data['energy'][nd][str(spin)])
            mask = np.where(
                np.any(bands > ymin - 0.05, axis=1)
                & np.any(bands < ymax + 0.05, axis=1))
            distances = data['distances'][nd]
            bands = bands[mask]
            if len(bands) == 0:
                # no band of this branch enters the energy window
                continue
            weights = [proj[nd][i][spin][mask] for i in range(len(selection))]

            if len(distances) > 2:  # Only interpolate if it makes sense to do so
                # Interpolate onto an evenly spaced grid that includes both
                # branch end points, so no gap is left at high-symmetry points.
                temp_dists = np.linspace(distances[0], distances[-1],
                                         len(distances) * interpolate_factor)
                bands = interp1d(distances,
                                 bands,
                                 axis=1,
                                 bounds_error=False,
                                 fill_value='extrapolate')(temp_dists)
                weights = interp1d(distances,
                                   weights,
                                   axis=2,
                                   bounds_error=False,
                                   fill_value='extrapolate')(temp_dists)
                distances = temp_dists
            else:  # change from list to array if we skipped the interpolation
                weights = np.array(weights)
                bands = np.array(bands)
                distances = np.array(distances)

            # sometimes VASP produces very small negative weights
            weights[weights < 0] = 0

            if mode == 'rgb':
                # rgbline needs exactly three channels; pad missing ones with 0
                if len(weights) == 2:
                    # two projections: red and blue, with an empty green
                    weights = np.insert(weights,
                                        1,
                                        np.zeros(weights[0].shape),
                                        axis=0)
                elif len(weights) == 1:
                    # one projection: red only
                    weights = np.concatenate(
                        [weights, np.zeros((2,) + weights.shape[1:])], axis=0)

                ls = '-' if spin == Spin.up else '--'
                lc = rgbline(distances,
                             bands,
                             weights[0],
                             weights[1],
                             weights[2],
                             alpha=1,
                             linestyles=ls,
                             linewidth=2.5)
                ax.add_collection(lc)

            elif mode == 'stacked':
                # TODO: Handle spin

                # very small circles only clutter the plot
                weights[weights < projection_cutoff] = 0

                # flatten so that each (distance, energy) pair is one point
                distances = list(distances) * len(bands)
                bands = bands.flatten()
                zorders = range(-len(weights), 0)
                for w, c, z in zip(weights, colours, zorders):
                    ax.scatter(distances,
                               bands,
                               c=c,
                               s=circle_size * w**2,
                               zorder=z,
                               rasterized=True)

        # plot the legend using off-screen dummy points
        for c, spec in zip(colours, selection):
            if isinstance(spec, str):
                label = spec
            else:
                # a single orbital string (e.g. "px") must not be split
                # into characters by join
                orbitals = (spec[1] if isinstance(spec[1], str) else
                            " + ".join(spec[1]))
                label = '{} ({})'.format(spec[0], orbitals)
            ax.scatter([-10000], [-10000],
                       c=c,
                       s=50,
                       label=label,
                       edgecolors='none')

        if dos_plotter:
            loc = 1
            anchor_point = (-0.2, 1)
        else:
            loc = 2
            anchor_point = (0.95, 1)

        # NOTE(review): label_size is not defined in this method; presumably
        # a module-level constant -- confirm.
        ax.legend(bbox_to_anchor=anchor_point,
                  loc=loc,
                  frameon=False,
                  prop={'size': label_size - 2},
                  handletextpad=0.1,
                  scatterpoints=1)

        # finish and tidy plot
        self._maketicks(ax, ylabel=ylabel)
        self._makeplot(ax,
                       plt.gcf(),
                       data,
                       zero_to_efermi=zero_to_efermi,
                       vbm_cbm_marker=vbm_cbm_marker,
                       width=width,
                       height=height,
                       ymin=ymin,
                       ymax=ymax,
                       dos_plotter=dos_plotter,
                       dos_options=dos_options,
                       dos_label=dos_label)
        return plt