예제 #1
0
    def get_plot(self,
                 width=None,
                 height=None,
                 xmin=0.,
                 xmax=None,
                 ymin=0,
                 ymax=None,
                 colours=None,
                 dpi=400,
                 plt=None,
                 fonts=None,
                 style=None,
                 no_base_style=False):
        """Get a :obj:`matplotlib.pyplot` object of the optical spectra.

        One subplot is created per entry of ``self._spec_data``; the subplots
        are stacked vertically and share their x-axis.

        Args:
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            xmin (:obj:`float`, optional): The minimum energy on the x-axis.
            xmax (:obj:`float`, optional): The maximum energy on the x-axis.
                If ``None`` (or otherwise falsy), ``self._xmax`` is used.
            ymin (:obj:`float` or :obj:`list`, optional): The minimum
                intensity on the y-axis. A single number is applied to every
                subplot; a :obj:`list` or :obj:`tuple` gives one value per
                subplot, in the order of the spectra (a one-element sequence
                is applied to every subplot, extra entries are ignored). A
                ``None`` value selects 0 for the 'absorption', 'loss',
                'eps_imag' and 'n_imag' spectra and matplotlib's automatic
                lower limit otherwise.
            ymax (:obj:`float` or :obj:`list`, optional): The maximum
                intensity on the y-axis, following the same conventions as
                ``ymin``. A ``None`` value selects 1e5 for the 'absorption'
                spectrum and matplotlib's automatic upper limit otherwise.
            colours (:obj:`list`, optional): A :obj:`list` (or :obj:`tuple`)
                of colours prepended to the colours of matplotlib's
                ``axes.prop_cycle``. The colours can be specified as a hex
                code, set of rgb values, or any other format supported by
                matplotlib.
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.
            style (:obj:`list`, :obj:`str`, or :obj:`dict`): Any matplotlib
                style specifications, to be composed on top of Sumo base
                style.
            no_base_style (:obj:`bool`, optional): Prevent use of sumo base
                style. This can make alternative styles behave more
                predictably.

        Returns:
            :obj:`matplotlib.pyplot`: The plot of optical spectra.

        Raises:
            ValueError: If ``ymin`` or ``ymax`` is not a number, ``None``, a
                :obj:`list` or a :obj:`tuple`, or if such a sequence has more
                than one entry but fewer entries than there are spectra.
        """
        n_plots = len(self._spec_data)
        plt = pretty_subplot(n_plots,
                             1,
                             sharex=True,
                             sharey=False,
                             width=width,
                             height=height,
                             dpi=dpi,
                             plt=plt)
        fig = plt.gcf()

        # NOTE(review): fonts, style and no_base_style are not read in this
        # body; presumably a styling decorator on this method consumes them.

        # User colours take priority; the matplotlib property cycle supplies
        # the rest. list() lets a tuple of colours be passed as well.
        optics_colours = rcParams['axes.prop_cycle'].by_key()['color']
        if colours is not None:
            optics_colours = list(colours) + optics_colours

        standard_ylabels = {
            'absorption': r'Absorption (cm$^\mathregular{-1}$)',
            'loss': r'Energy-loss',
            'eps_real': r'Re($\epsilon$)',
            'eps_imag': r'Im($\epsilon$)',
            'n_real': r'Re(n)',
            'n_imag': r'Im(n)'
        }

        def _limits_per_plot(limit, name):
            # Expand a y-limit argument into one entry per subplot.
            if limit is None or isinstance(
                    limit, (int, float, np.integer, np.floating)):
                return [limit] * n_plots
            if not isinstance(limit, (list, tuple)):
                raise ValueError(
                    '{} must be a number, None, or a list of per-plot '
                    'values, not {!r}'.format(name, limit))
            if len(limit) == 1:
                return list(limit) * n_plots
            if len(limit) < n_plots:
                # zip() below would otherwise silently skip the remaining
                # spectra, leaving them unplotted.
                raise ValueError(
                    '{} has {} values but there are {} plots'.format(
                        name, len(limit), n_plots))
            return list(limit)

        ymin_series = _limits_per_plot(ymin, 'ymin')
        ymax_series = _limits_per_plot(ymax, 'ymax')

        # The x-range is identical for every (shared-x) subplot.
        xmax = xmax if xmax else self._xmax

        for i, ((spectrum_key, data), plot_ymin, plot_ymax) in enumerate(
                zip(self._spec_data.items(), ymin_series, ymax_series)):
            ax = fig.axes[i]
            _plot_spectrum(data, self._label, self._band_gap, ax,
                           optics_colours)
            ax.set_xlim(xmin, xmax)

            if plot_ymin is None:
                if spectrum_key in ('absorption', 'loss', 'eps_imag',
                                    'n_imag'):
                    plot_ymin = 0
                else:
                    plot_ymin = ax.get_ylim()[0]

            if plot_ymax is None:
                if spectrum_key == 'absorption':
                    plot_ymax = 1e5
                else:
                    plot_ymax = ax.get_ylim()[1]

            ax.set_ylim(plot_ymin, plot_ymax)

            if spectrum_key == 'absorption':
                # Whitney fonts get a plain 'x' as the multiplication sign in
                # the tick labels produced by curry_power_tick.
                font = findfont(FontProperties(family=['sans-serif']))
                if 'Whitney' in font:
                    times_sign = 'x'
                else:
                    times_sign = r'\times'
                ax.yaxis.set_major_formatter(
                    FuncFormatter(curry_power_tick(times_sign=times_sign)))

            ax.yaxis.set_major_locator(MaxNLocator(5))
            ax.xaxis.set_major_locator(MaxNLocator(3))
            ax.yaxis.set_minor_locator(AutoMinorLocator(2))
            ax.xaxis.set_minor_locator(AutoMinorLocator(2))

            ax.set_ylabel(standard_ylabels.get(spectrum_key, spectrum_key))

            if i == 0:
                # Draw a legend on the top subplot only when some label is
                # non-empty or the first spectrum's data[0][1] is more than
                # one-dimensional. (data is the first spectrum when i == 0.)
                if (not np.all(np.array(self._label) == '')
                        or np.array(data[0][1]).ndim > 1):
                    ax.legend(loc='best', frameon=False, ncol=1)

        # The x-axis is shared, so only the bottom subplot is labelled.
        ax.set_xlabel('Energy (eV)')

        # With a single plot, set the aspect ratio so the axes box has the
        # same proportions as the figure canvas.
        if n_plots == 1:
            x0, x1 = ax.get_xlim()
            y0, y1 = ax.get_ylim()
            if width is None:
                width = rcParams['figure.figsize'][0]
            if height is None:
                height = rcParams['figure.figsize'][1]
            ax.set_aspect((height / width) * ((x1 - x0) / (y1 - y0)))

        # tight_layout is applied in every case; for several stacked plots it
        # is the only layout adjustment made.
        plt.tight_layout()
        return plt
예제 #2
0
    def get_plot(
        self,
        units="eV",
        width=None,
        height=None,
        xmin=0.0,
        xmax=None,
        ymin=0,
        ymax=None,
        colours=None,
        dpi=400,
        plt=None,
        fonts=None,
        style=None,
        no_base_style=False,
    ):
        """Get a :obj:`matplotlib.pyplot` object of the optical spectra.

        One subplot is created per entry of ``self._spec_data``; the subplots
        are stacked vertically and share their x-axis.

        Args:
            units (:obj:`str`, optional): X-axis units for the plot. 'eV' (or
                'ev') for energy in electronvolts or 'nm' for wavelength in
                nanometers. Defaults to 'eV'.
            width (:obj:`float`, optional): The width of the plot.
            height (:obj:`float`, optional): The height of the plot.
            xmin (:obj:`float`, optional): The minimum of the x-axis, in
                ``units``. For 'nm', a falsy value (such as the default 0) is
                replaced by the wavelength corresponding to ``self._xmax``.
            xmax (:obj:`float`, optional): The maximum of the x-axis, in
                ``units``. If ``None``, ``self._xmax`` is used for eV and
                2500 for nm.
            ymin (:obj:`float` or :obj:`list`, optional): The minimum
                intensity on the y-axis. A single number is applied to every
                subplot; a :obj:`list` or :obj:`tuple` gives one value per
                subplot, in the order of the spectra (a one-element sequence
                is applied to every subplot, extra entries are ignored). A
                ``None`` value selects 0 for the 'absorption', 'loss',
                'eps_imag' and 'n_imag' spectra and matplotlib's automatic
                lower limit otherwise.
            ymax (:obj:`float` or :obj:`list`, optional): The maximum
                intensity on the y-axis, following the same conventions as
                ``ymin``. A ``None`` value selects 1e5 for the 'absorption'
                spectrum and matplotlib's automatic upper limit otherwise.
            colours (:obj:`list`, optional): A :obj:`list` (or :obj:`tuple`)
                of colours prepended to the colours of matplotlib's
                ``axes.prop_cycle``. The colours can be specified as a hex
                code, set of rgb values, or any other format supported by
                matplotlib.
            dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
                the image.
            plt (:obj:`matplotlib.pyplot`, optional): A
                :obj:`matplotlib.pyplot` object to use for plotting.
            fonts (:obj:`list`, optional): Fonts to use in the plot. Can be
                a single font, specified as a :obj:`str`, or several fonts,
                specified as a :obj:`list` of :obj:`str`.
            style (:obj:`list`, :obj:`str`, or :obj:`dict`): Any matplotlib
                style specifications, to be composed on top of Sumo base
                style.
            no_base_style (:obj:`bool`, optional): Prevent use of sumo base
                style. This can make alternative styles behave more
                predictably.

        Returns:
            :obj:`matplotlib.pyplot`: The plot of optical spectra.

        Raises:
            ValueError: If ``ymin`` or ``ymax`` is not a number, ``None``, a
                :obj:`list` or a :obj:`tuple`, or if such a sequence has more
                than one entry but fewer entries than there are spectra.
        """
        n_plots = len(self._spec_data)
        plt = pretty_subplot(
            n_plots,
            1,
            sharex=True,
            sharey=False,
            width=width,
            height=height,
            dpi=dpi,
            plt=plt,
        )
        fig = plt.gcf()

        # NOTE(review): fonts, style and no_base_style are not read in this
        # body; presumably a styling decorator on this method consumes them.

        # User colours take priority; the matplotlib property cycle supplies
        # the rest. list() lets a tuple of colours be passed as well.
        optics_colours = rcParams["axes.prop_cycle"].by_key()["color"]
        if colours is not None:
            optics_colours = list(colours) + optics_colours

        standard_ylabels = {
            "absorption": r"Absorption (cm$^\mathregular{-1}$)",
            "loss": r"Energy-loss",
            "eps_real": r"Re($\epsilon$)",
            "eps_imag": r"Im($\epsilon$)",
            "n_real": r"Re(n)",
            "n_imag": r"Im(n)",
        }

        def _limits_per_plot(limit, name):
            # Expand a y-limit argument into one entry per subplot.
            if limit is None or isinstance(
                limit, (int, float, np.integer, np.floating)
            ):
                return [limit] * n_plots
            if not isinstance(limit, (list, tuple)):
                raise ValueError(
                    "{} must be a number, None, or a list of per-plot "
                    "values, not {!r}".format(name, limit)
                )
            if len(limit) == 1:
                return list(limit) * n_plots
            if len(limit) < n_plots:
                # zip() below would otherwise silently skip the remaining
                # spectra, leaving them unplotted.
                raise ValueError(
                    "{} has {} values but there are {} plots".format(
                        name, len(limit), n_plots
                    )
                )
            return list(limit)

        ymin_series = _limits_per_plot(ymin, "ymin")
        ymax_series = _limits_per_plot(ymax, "ymax")

        # Decide once whether the x-axis is in energy units, so the axis
        # limits and the axis label agree (both "ev" and "eV" mean energy).
        energy_units = units in ("ev", "eV")

        # The x-range is identical for every (shared-x) subplot.
        if energy_units:
            if xmax is None:
                xmax = self._xmax  # use sumo-determined energy limits
        elif units == "nm":
            # use default minimum energy (max wavelength) of 2500 nm
            xmax = xmax if xmax is not None else 2500
            # convert sumo-determined max energy to min wavelength
            xmin = xmin if xmin else ev_to_nm(self._xmax)

        for i, ((spectrum_key, data), plot_ymin, plot_ymax) in enumerate(
            zip(self._spec_data.items(), ymin_series, ymax_series)
        ):
            ax = fig.axes[i]
            _plot_spectrum(data, self._label, self._band_gap, ax, optics_colours, units)
            ax.set_xlim(xmin, xmax)

            if plot_ymin is None:
                if spectrum_key in ("absorption", "loss", "eps_imag", "n_imag"):
                    plot_ymin = 0
                else:
                    plot_ymin = ax.get_ylim()[0]

            if plot_ymax is None:
                if spectrum_key == "absorption":
                    plot_ymax = 1e5
                else:
                    plot_ymax = ax.get_ylim()[1]

            ax.set_ylim(plot_ymin, plot_ymax)

            if spectrum_key == "absorption":
                # Whitney fonts get a plain "x" as the multiplication sign in
                # the tick labels produced by curry_power_tick.
                font = findfont(FontProperties(family=["sans-serif"]))
                if "Whitney" in font:
                    times_sign = "x"
                else:
                    times_sign = r"\times"
                ax.yaxis.set_major_formatter(
                    FuncFormatter(curry_power_tick(times_sign=times_sign))
                )

            ax.yaxis.set_major_locator(MaxNLocator(5))
            ax.xaxis.set_major_locator(MaxNLocator(3))
            ax.yaxis.set_minor_locator(AutoMinorLocator(2))
            ax.xaxis.set_minor_locator(AutoMinorLocator(2))

            ax.set_ylabel(standard_ylabels.get(spectrum_key, spectrum_key))

            if i == 0:
                # Draw a legend on the top subplot only when some label is
                # non-empty or the first spectrum's data[0][1] is more than
                # one-dimensional. (data is the first spectrum when i == 0.)
                if (
                    not np.all(np.array(self._label) == "")
                    or np.array(data[0][1]).ndim > 1
                ):
                    ax.legend(loc="best", frameon=False, ncol=1)

        # The x-axis is shared, so only the bottom subplot is labelled.
        xlabel = "Energy (eV)" if energy_units else "Wavelength (nm)"
        ax.set_xlabel(xlabel)

        # With a single plot, set the aspect ratio so the axes box has the
        # same proportions as the figure canvas.
        if n_plots == 1:
            x0, x1 = ax.get_xlim()
            y0, y1 = ax.get_ylim()
            if width is None:
                width = rcParams["figure.figsize"][0]
            if height is None:
                height = rcParams["figure.figsize"][1]
            ax.set_aspect((height / width) * ((x1 - x0) / (y1 - y0)))

        # tight_layout is applied in every case; for several stacked plots it
        # is the only layout adjustment made.
        plt.tight_layout()
        return plt