Ejemplo n.º 1
0
def _questaal_file_ext(path):
    """Return the Questaal file extension of ``path``, skipping compression.

    Questaal files are named ``<stem>.<ext>`` (e.g. ``dos.mnpt``), optionally
    followed by a compression suffix (``.gz``, ``.z`` or ``.bz2``). Only the
    basename is inspected, so dots in directory names are ignored.
    """
    parts = os.path.basename(path).split('.')
    if len(parts) > 2 and parts[-1].lower() in ('gz', 'z', 'bz2'):
        return parts[-2]
    return parts[-1]


def dosplot(filename=None,
            code='vasp',
            prefix=None,
            directory=None,
            elements=None,
            lm_orbitals=None,
            atoms=None,
            spin=None,
            subplot=False,
            shift=True,
            total_only=False,
            plot_total=True,
            legend_on=True,
            legend_frame_on=False,
            legend_cutoff=3.,
            gaussian=None,
            height=6.,
            width=8.,
            xmin=-6.,
            xmax=6.,
            num_columns=2,
            colours=None,
            yscale=1,
            xlabel='Energy (eV)',
            ylabel='Arb. units',
            style=None,
            no_base_style=False,
            image_format='pdf',
            dpi=400,
            plt=None,
            fonts=None):
    """Plot the density of states from electronic structure output.

    Args:
        filename (:obj:`str`, optional): Path to a DOS data file (can be
            gzipped). The preferred file type depends on the electronic
            structure code: vasprun.xml (VASP); *.bands (CASTEP); dos.*
            (Questaal). If not set, a suitable file is searched for in the
            current directory.
        code (:obj:`str`, optional): Electronic structure code used ('vasp',
            'castep' or 'questaal', case-insensitive). Note that for Castep
            only a rough TDOS is available, assembled by sampling the
            eigenvalues.
        prefix (:obj:`str`, optional): Prefix for file names.
        directory (:obj:`str`, optional): The directory in which to save files.
        elements (:obj:`dict`, optional): The elements and orbitals to extract
            from the projected density of states. Should be provided as a
            :obj:`dict` with the keys as the element names and corresponding
            values as a :obj:`tuple` of orbitals. For example, the following
            would extract the Bi s, px, py and d orbitals::

                {'Bi': ('s', 'px', 'py', 'd')}

            If an element is included with an empty :obj:`tuple`, all orbitals
            for that species will be extracted. If ``elements`` is not set or
            set to ``None``, all elements for all species will be extracted.
        lm_orbitals (:obj:`dict`, optional): The orbitals to decompose into
            their lm contributions (e.g. p -> px, py, pz). Should be provided
            as a :obj:`dict`, with the elements names as keys and a
            :obj:`tuple` of orbitals as the corresponding values. For example,
            the following would be used to decompose the oxygen p and d
            orbitals::

                {'O': ('p', 'd')}

        atoms (:obj:`dict`, optional): Which atomic sites to use when
            calculating the projected density of states. Should be provided as
            a :obj:`dict`, with the element names as keys and a :obj:`tuple` of
            :obj:`int` specifying the atomic indices as the corresponding
            values. The elemental projected density of states will be summed
            only over the atom indices specified. If an element is included
            with an empty :obj:`tuple`, then all sites for that element will
            be included. The indices are 0 based for each element specified in
            the POSCAR. For example, the following will calculate the density
            of states for the first 4 Sn atoms and all O atoms in the
            structure::

                {'Sn': (0, 1, 2, 3), 'O': ()}

            If ``atoms`` is not set or set to ``None`` then all atomic sites
            for all elements will be considered.
        spin (:obj:`str`, optional): Plot only one spin channel from a
            spin-polarised calculation; "up" or "1" for spin up only, "down" or
            "-1" for spin down only. Defaults to ``None``.
        subplot (:obj:`bool`, optional): Plot the density of states for each
            element on separate subplots. Defaults to ``False``.
        shift (:obj:`bool`, optional): Shift the energies such that the valence
            band maximum (or Fermi level for metals) is at 0 eV. Defaults to
            ``True``. This function only consults the flag for Questaal
            input, where shifting is not implemented and a warning is logged.
        total_only (:obj:`bool`, optional): Only extract the total density of
            states. Defaults to ``False``.
        plot_total (:obj:`bool`, optional): Plot the total density of states.
            Defaults to ``True``.
        legend_on (:obj:`bool`, optional): Plot the graph legend. Defaults
            to ``True``.
        legend_frame_on (:obj:`bool`, optional): Plot a frame around the
            graph legend. Defaults to ``False``.
        legend_cutoff (:obj:`float`, optional): The cut-off (in % of the
            maximum density of states within the plotting range) for an
            elemental orbital to be labelled in the legend. This prevents
            the legend from containing labels for orbitals that have very
            little contribution in the plotting range.
        gaussian (:obj:`float`, optional): Broaden the density of states using
            convolution with a gaussian function. This parameter controls the
            sigma or standard deviation of the gaussian distribution.
        height (:obj:`float`, optional): The height of the plot.
        width (:obj:`float`, optional): The width of the plot.
        xmin (:obj:`float`, optional): The minimum energy on the x-axis.
        xmax (:obj:`float`, optional): The maximum energy on the x-axis.
        num_columns (:obj:`int`, optional): The number of columns in the
            legend.
        colours (:obj:`dict`, optional): Use custom colours for specific
            element and orbital combinations. Specified as a :obj:`dict` of
            :obj:`dict` of the colours. For example::

                {
                    'Sn': {'s': 'r', 'p': 'b'},
                    'O': {'s': '#000000'}
                }

            The colour can be a hex code, series of rgb value, or any other
            format supported by matplotlib.
        xlabel (:obj:`str`, optional): Label/units for x-axis (i.e. energy)
        ylabel (:obj:`str`, optional): Label/units for y-axis (i.e. DOS)
        yscale (:obj:`float`, optional): Scaling factor for the y-axis.
        style (:obj:`list` or :obj:`str`, optional): (List of) matplotlib style
            specifications, to be composed on top of Sumo base style.
        no_base_style (:obj:`bool`, optional): Prevent use of sumo base style.
            This can make alternative styles behave more predictably.
        image_format (:obj:`str`, optional): The image file format. Can be any
            format supported by matplotlib, including: png, jpg, pdf, and svg.
            Defaults to pdf.
        dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
            the image.
        plt (:obj:`matplotlib.pyplot`, optional): A
            :obj:`matplotlib.pyplot` object to use for plotting. If provided,
            nothing is written to disk and the plot object is returned.
        fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
            a single font, specified as a :obj:`str`, or several fonts,
            specified as a :obj:`list` of :obj:`str`.

    Returns:
        The matplotlib pyplot object if ``plt`` was provided. Otherwise the
        plot image and DOS data files are written to disk and ``None`` is
        returned.

    Raises:
        ValueError: If ``code`` is not a supported code, or if no Questaal
            DOS file could be found.
    """
    code_name = code.lower()

    if code_name == 'vasp':
        if not filename:
            if os.path.exists('vasprun.xml'):
                filename = 'vasprun.xml'
            elif os.path.exists('vasprun.xml.gz'):
                filename = 'vasprun.xml.gz'
            else:
                logging.error('ERROR: No vasprun.xml found!')
                sys.exit()

        dos, pdos = load_dos(filename, elements, lm_orbitals, atoms, gaussian,
                             total_only)

    elif code_name == 'castep':
        # locals() exposes the arguments by name, so any option the caller
        # set that CASTEP cannot honour is detected here.
        for arg in sumo.io.castep.unsupported_dosplot_args:
            if locals().get(arg, None) is not None:
                logging.error('Cannot set "{}" for CASTEP DOS; only total DOS '
                              'is available.'.format(arg))
                sys.exit()

        if filename:
            bands_file = filename
        else:
            band_candidates = glob('*.bands')
            if len(band_candidates) == 0:
                logging.error('ERROR: No *.bands file found!')
                sys.exit()
            elif len(band_candidates) == 1:
                bands_file = band_candidates[0]
            else:
                logging.error('ERROR: Too many *.bands files found!')
                sys.exit()
        dos = sumo.io.castep.read_tdos(bands_file,
                                       gaussian=gaussian,
                                       emin=xmin,
                                       emax=xmax)
        pdos = {}

    elif code_name == 'questaal':
        if filename:
            pdos_file = filename
        else:
            # Take the first dos.* file (sorted for determinism) that is not
            # an image produced by a previous plotting run.
            for candidate in sorted(glob('dos.*')):
                suffix = candidate.split('.')[-1].lower()
                if suffix not in ('pdf', 'png', 'svg', 'jpg', 'jpeg'):
                    pdos_file = candidate
                    break
            else:
                raise ValueError("No questaal dos file found")

        # Companion files share the (uncompressed) extension of the DOS file
        # and are looked for in the same directory as it.
        ext = _questaal_file_ext(pdos_file)
        folder = os.path.dirname(pdos_file)

        tdos_file = os.path.join(folder, 'tdos.{}'.format(ext))
        if not os.path.exists(tdos_file):
            tdos_file = None
        site_file = os.path.join(folder, 'site.{}'.format(ext))
        if not os.path.exists(site_file):
            site_file = None

        if shift:
            logging.warning("Fermi level shift requested, but not implemented "
                            "for Questaal DOS.")

        dos, pdos = sumo.io.questaal.read_dos(pdos_file=pdos_file,
                                              tdos_file=tdos_file,
                                              site_file=site_file,
                                              ry=True,
                                              gaussian=gaussian,
                                              total_only=total_only,
                                              elements=elements,
                                              lm_orbitals=lm_orbitals,
                                              atoms=atoms)

    else:
        raise ValueError(
            "Unsupported code '{}'; expected 'vasp', 'castep' or "
            "'questaal'.".format(code))

    # Files are only written when the caller did not supply a pyplot object.
    save_files = not plt

    spin = string_to_spin(spin)  # Convert spin argument to pymatgen Spin
    plotter = SDOSPlotter(dos, pdos)
    plt = plotter.get_plot(subplot=subplot,
                           width=width,
                           height=height,
                           xmin=xmin,
                           xmax=xmax,
                           yscale=yscale,
                           colours=colours,
                           plot_total=plot_total,
                           legend_on=legend_on,
                           num_columns=num_columns,
                           legend_frame_on=legend_frame_on,
                           xlabel=xlabel,
                           ylabel=ylabel,
                           legend_cutoff=legend_cutoff,
                           dpi=dpi,
                           plt=plt,
                           fonts=fonts,
                           style=style,
                           no_base_style=no_base_style,
                           spin=spin)

    if not save_files:
        return plt

    basename = 'dos.{}'.format(image_format)
    filename = '{}_{}'.format(prefix, basename) if prefix else basename
    if directory:
        filename = os.path.join(directory, filename)
    plt.savefig(filename,
                format=image_format,
                dpi=dpi,
                bbox_inches='tight')
    write_files(dos, pdos, prefix=prefix, directory=directory)
    return None
Ejemplo n.º 2
0
def bandplot(
    filenames=None,
    code="vasp",
    prefix=None,
    directory=None,
    vbm_cbm_marker=False,
    projection_selection=None,
    mode="rgb",
    normalise="all",
    interpolate_factor=4,
    circle_size=150,
    dos_file=None,
    cart_coords=False,
    scissor=None,
    ylabel="Energy (eV)",
    dos_label=None,
    elements=None,
    lm_orbitals=None,
    atoms=None,
    spin=None,
    total_only=False,
    plot_total=True,
    legend_cutoff=3,
    gaussian=None,
    height=None,
    width=None,
    ymin=-6.0,
    ymax=6.0,
    colours=None,
    yscale=1,
    style=None,
    no_base_style=False,
    image_format="pdf",
    dpi=400,
    plt=None,
    fonts=None,
):
    """Plot electronic band structure diagrams.

    Args:
        filenames (:obj:`str` or :obj:`list`, optional): Path to input files:

            Vasp:
                Use vasprun.xml or vasprun.xml.gz file.
            Questaal:
                Path to a bnds.ext file. The extension will also be used to
                find site.ext and syml.ext files in the same directory. Only
                the first file is used; further files are ignored with a
                warning.
            Castep:
                Path to a seedname.bands file. The prefix ("seedname") is used
                to locate a seedname.cell file in the same directory and read
                in the positions of high-symmetry points.

            If no filenames are provided, sumo
            will search for vasprun.xml or vasprun.xml.gz files in folders
            named 'split-0*'. Failing that, the code will look for a vasprun in
            the current directory. If a :obj:`list` of vasprun files is
            provided, these will be combined into a single band structure.

        code (:obj:`str`, optional): Calculation type. Default is 'vasp';
            'questaal' and 'castep' also supported (with a reduced
            feature-set). Matching is case-insensitive.
        prefix (:obj:`str`, optional): Prefix for file names.
        directory (:obj:`str`, optional): The directory in which to save files.
        vbm_cbm_marker (:obj:`bool`, optional): Plot markers to indicate the
            VBM and CBM locations.
        projection_selection (list): A list of :obj:`tuple` or :obj:`string`
            identifying which elements and orbitals to project on to the
            band structure. These can be specified by both element and
            orbital, for example, the following will project the Bi s, p
            and S p orbitals::

                [('Bi', 's'), ('Bi', 'p'), ('S', 'p')]

            If just the element is specified then all the orbitals of
            that element are combined. For example, to sum all the S
            orbitals::

                [('Bi', 's'), ('Bi', 'p'), 'S']

            You can also choose to sum particular orbitals by supplying a
            :obj:`tuple` of orbitals. For example, to sum the S s, p, and
            d orbitals into a single projection::

                [('Bi', 's'), ('Bi', 'p'), ('S', ('s', 'p', 'd'))]

            If ``mode = 'rgb'``, a maximum of 3 orbital/element
            combinations can be plotted simultaneously (one for red, green
            and blue), otherwise an unlimited number of elements/orbitals
            can be selected.
        mode (:obj:`str`, optional): Type of projected band structure to
            plot. Options are:

                "rgb"
                    The band structure line color depends on the character
                    of the band. Each element/orbital contributes either
                    red, green or blue with the corresponding line colour a
                    mixture of all three colours. This mode only supports
                    up to 3 elements/orbitals combinations. The order of
                    the ``selection`` :obj:`tuple` determines which colour
                    is used for each selection.
                "stacked"
                    The element/orbital contributions are drawn as a
                    series of stacked circles, with the colour depending on
                    the composition of the band. The size of the circles
                    can be scaled using the ``circle_size`` option.

            Modes containing "split" cannot be combined with ``dos_file``.
        normalise (:obj:`str`, optional): Normalisation the projections.
            Options are:

              * ``'all'``: Projections normalised against the sum of all
                   other projections.
              * ``'select'``: Projections normalised against the sum of the
                   selected projections.
              * ``None``: No normalisation performed.

        interpolate_factor (:obj:`int`, optional): Passed through to the
            projected band structure plotter; only used when
            ``projection_selection`` is set.
        circle_size (:obj:`float`, optional): The area of the circles used
            when ``mode = 'stacked'``.
        dos_file (:obj:`str`, optional): Path to the file from which to
            read the density of states information (vasprun.xml for VASP).
            If set, the density of states will be plotted alongside the
            bandstructure. Not supported for Questaal.
        cart_coords (:obj:`bool`, optional): Whether the k-points are read as
            cartesian or reciprocal coordinates. This is only required for
            Questaal output; Vasp output is less ambiguous. Defaults to
            ``False`` (fractional coordinates).
        scissor (:obj:`float`, optional): Apply a scissor operator (rigid shift
            of the CBM), use with caution if applying to metals.
        ylabel (:obj:`str`, optional): Label for the energy (y) axis.
        dos_label (:obj:`str`, optional): Label passed to the plotter for the
            DOS panel when ``dos_file`` is set.
        elements (:obj:`dict`, optional): The elements and orbitals to extract
            from the projected density of states. Should be provided as a
            :obj:`dict` with the keys as the element names and corresponding
            values as a :obj:`tuple` of orbitals. For example, the following
            would extract the Bi s, px, py and d orbitals::

                {'Bi': ('s', 'px', 'py', 'd')}

            If an element is included with an empty :obj:`tuple`, all orbitals
            for that species will be extracted. If ``elements`` is not set or
            set to ``None``, all elements for all species will be extracted.
        lm_orbitals (:obj:`dict`, optional): The orbitals to decompose into
            their lm contributions (e.g. p -> px, py, pz). Should be provided
            as a :obj:`dict`, with the elements names as keys and a
            :obj:`tuple` of orbitals as the corresponding values. For example,
            the following would be used to decompose the oxygen p and d
            orbitals::

                {'O': ('p', 'd')}

        atoms (:obj:`dict`, optional): Which atomic sites to use when
            calculating the projected density of states. Should be provided as
            a :obj:`dict`, with the element names as keys and a :obj:`tuple` of
            :obj:`int` specifying the atomic indices as the corresponding
            values. The elemental projected density of states will be summed
            only over the atom indices specified. If an element is included
            with an empty :obj:`tuple`, then all sites for that element will
            be included. The indices are 0 based for each element specified in
            the POSCAR. For example, the following will calculate the density
            of states for the first 4 Sn atoms and all O atoms in the
            structure::

                {'Sn': (0, 1, 2, 3), 'O': ()}

            If ``atoms`` is not set or set to ``None`` then all atomic sites
            for all elements will be considered.
        spin (:obj:`str`, optional): Plot only one spin channel from a
            spin-polarised calculation; "up" or "1" for spin up only, "down" or
            "-1" for spin down only. Defaults to ``None``.
        total_only (:obj:`bool`, optional): Only extract the total density of
            states. Defaults to ``False``.
        plot_total (:obj:`bool`, optional): Plot the total density of states.
            Defaults to ``True``.
        legend_cutoff (:obj:`float`, optional): The cut-off (in % of the
            maximum density of states within the plotting range) for an
            elemental orbital to be labelled in the legend. This prevents
            the legend from containing labels for orbitals that have very
            little contribution in the plotting range.
        gaussian (:obj:`float`, optional): Broaden the density of states using
            convolution with a gaussian function. This parameter controls the
            sigma or standard deviation of the gaussian distribution.
        height (:obj:`float`, optional): The height of the plot.
        width (:obj:`float`, optional): The width of the plot.
        ymin (:obj:`float`, optional): The minimum energy on the y-axis.
        ymax (:obj:`float`, optional): The maximum energy on the y-axis.
        style (:obj:`list` or :obj:`str`, optional): (List of) matplotlib style
            specifications, to be composed on top of Sumo base style.
        no_base_style (:obj:`bool`, optional): Prevent use of sumo base style.
            This can make alternative styles behave more predictably.
        colours (:obj:`dict`, optional): Use custom colours for specific
            element and orbital combinations. Specified as a :obj:`dict` of
            :obj:`dict` of the colours. For example::

                {
                    'Sn': {'s': 'r', 'p': 'b'},
                    'O': {'s': '#000000'}
                }

            The colour can be a hex code, series of rgb value, or any other
            format supported by matplotlib.
        yscale (:obj:`float`, optional): Scaling factor for the y-axis.
        image_format (:obj:`str`, optional): The image file format. Can be any
            format supported by matplotlib, including: png, jpg, pdf, and svg.
            Defaults to pdf.
        dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
            the image.
        plt (:obj:`matplotlib.pyplot`, optional): A
            :obj:`matplotlib.pyplot` object to use for plotting.
        fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
            a single font, specified as a :obj:`str`, or several fonts,
            specified as a :obj:`list` of :obj:`str`.

    Returns:
        If ``plt`` set then the ``plt`` object will be returned. Otherwise, the
        method will return a :obj:`list` of filenames written to disk.

    Raises:
        ValueError: If ``code`` is not one of 'vasp', 'castep' or 'questaal'.
    """
    code = code.lower()
    if code not in ("vasp", "castep", "questaal"):
        raise ValueError(
            f"Unsupported code '{code}'; expected 'vasp', 'castep' or "
            "'questaal'."
        )

    # Validate option combinations before any (potentially slow) parsing.
    # Plotting split projections with a DOS is currently not supported as it
    # is a pain to make subplots within subplots, although need to check this
    # is still the case.
    if mode and "split" in mode and dos_file:
        logging.error(
            "ERROR: Plotting split projected band structure with DOS"
            " not supported.\nPlease use --projected-rgb or "
            "--projected-stacked options."
        )
        sys.exit()

    if projection_selection and mode == "rgb" and len(projection_selection) > 3:
        logging.error(
            "ERROR: RGB projected band structure only "
            "supports up to 3 elements/orbitals."
            "\nUse alternative --mode setting."
        )
        sys.exit()

    if dos_file and code == "questaal":
        logging.error(
            "ERROR: Plotting the DOS alongside the band structure is not "
            "supported for Questaal."
        )
        sys.exit()

    if not filenames:
        filenames = find_vasprun_files()
    elif isinstance(filenames, str):
        filenames = [filenames]

    # only load the orbital projections if we definitely need them
    parse_projected = bool(projection_selection)

    # Load all the band structure data; multiple VASP/CASTEP files are
    # combined with pymatgen's get_reconstructed_band_structure.
    cell_file = None  # set by the CASTEP loader, reused for the PDOS below
    if code == "vasp":
        bandstructures = []
        for vr_file in filenames:
            vr = BSVasprun(vr_file, parse_projected_eigen=parse_projected)
            bandstructures.append(vr.get_band_structure(line_mode=True))
        bs = get_reconstructed_band_structure(bandstructures)

    elif code == "castep":
        bandstructures = []
        for bands_file in filenames:
            cell_file = _replace_ext(bands_file, "cell")
            if os.path.isfile(cell_file):
                logging.info(f"Found cell file {cell_file}...")
            else:
                logging.info(f"Did not find cell file {cell_file}...")
                cell_file = None
            bandstructures.append(
                castep_band_structure(bands_file, cell_file=cell_file)
            )
        bs = get_reconstructed_band_structure(bandstructures)

    else:  # questaal
        if len(filenames) > 1:
            logging.warning(
                f"Questaal band structures cannot be combined; only "
                f"{filenames[0]} will be plotted."
            )
        bnds_file = filenames[0]
        # Companion site/syml files share the extension of the bnds file and
        # live in the same directory.
        ext = os.path.basename(bnds_file).split(".")[-1]
        bnds_folder = os.path.dirname(bnds_file)

        site_file = os.path.abspath(os.path.join(bnds_folder, f"site.{ext}"))
        if os.path.isfile(site_file):
            logging.info("site file found, reading lattice...")
            site_data = QuestaalSite.from_file(site_file)
            bnds_lattice = site_data.structure.lattice
            alat = site_data.alat
        else:
            raise OSError(
                "Site file {} not found: "
                "needed to determine lattice".format(site_file)
            )

        syml_file = os.path.abspath(os.path.join(bnds_folder, f"syml.{ext}"))
        if os.path.isfile(syml_file):
            logging.info("syml file found, reading special-point labels...")
            bnds_labels = labels_from_syml(syml_file)
        else:
            logging.info("syml file not found, band structure lacks labels")
            bnds_labels = {}

        bs = questaal_band_structure(
            bnds_file,
            bnds_lattice,
            alat=alat,
            labels=bnds_labels,
            coords_are_cartesian=cart_coords,
        )

    # don't save if pyplot object provided
    save_files = not plt

    dos_plotter = None
    dos_opts = None
    if dos_file:
        if code == "vasp":
            dos, pdos = load_dos(
                dos_file, elements, lm_orbitals, atoms, gaussian, total_only
            )
        else:  # castep (questaal was rejected above)
            pdos_file = None
            if cell_file:
                pdos_candidate = _replace_ext(cell_file, "pdos_bin")
                if os.path.isfile(pdos_candidate):
                    pdos_file = pdos_candidate
                    logging.info(f"Found PDOS file {pdos_file}")
                else:
                    logging.info(
                        f"PDOS file {pdos_candidate} does not exist, "
                        "falling back to TDOS."
                    )
            else:
                logging.info("No cell file found, cannot plot PDOS.")

            dos, pdos = read_castep_dos(
                dos_file,
                pdos_file=pdos_file,
                cell_file=cell_file,
                gaussian=gaussian,
                lm_orbitals=lm_orbitals,
                elements=elements,
                efermi_to_vbm=True,
            )

        dos_plotter = SDOSPlotter(dos, pdos)
        dos_opts = {
            "plot_total": plot_total,
            "legend_cutoff": legend_cutoff,
            "colours": colours,
            "yscale": yscale,
        }

    if scissor:
        bs = bs.apply_scissor(scissor)

    spin = string_to_spin(spin)  # Convert spin name to pymatgen Spin object
    plotter = SBSPlotter(bs)

    # Options shared by both the plain and the projected plots.
    common_opts = {
        "zero_to_efermi": True,
        "ymin": ymin,
        "ymax": ymax,
        "height": height,
        "width": width,
        "vbm_cbm_marker": vbm_cbm_marker,
        "ylabel": ylabel,
        "plt": plt,
        "dos_plotter": dos_plotter,
        "dos_options": dos_opts,
        "dos_label": dos_label,
        "fonts": fonts,
        "style": style,
        "no_base_style": no_base_style,
        "spin": spin,
    }
    if projection_selection:
        plt = plotter.get_projected_plot(
            projection_selection,
            mode=mode,
            normalise=normalise,
            interpolate_factor=interpolate_factor,
            circle_size=circle_size,
            **common_opts,
        )
    else:
        plt = plotter.get_plot(**common_opts)

    if not save_files:
        return plt

    basename = f"band.{image_format}"
    filename = f"{prefix}_{basename}" if prefix else basename
    if directory:
        filename = os.path.join(directory, filename)
    plt.savefig(filename, format=image_format, dpi=dpi, bbox_inches="tight")

    written = [filename]
    written += save_data_files(bs, prefix=prefix, directory=directory)
    return written
Ejemplo n.º 3
0
def _questaal_ext(path):
    """Return the Questaal file extension of ``path``, ignoring compression.

    The extension is the last dot-separated component of the file's base
    name. If that component is a compression suffix (``gz``, ``z`` or
    ``bz2``, case-insensitive), the component before it is used instead.
    For example, both ``dos.mno`` and ``dos.mno.gz`` give ``"mno"``.
    """
    parts = os.path.basename(path).split(".")
    if len(parts) > 1 and parts[-1].lower() in ("gz", "z", "bz2"):
        return parts[-2]
    return parts[-1]


def dosplot(
    filename=None,
    code="vasp",
    prefix=None,
    directory=None,
    elements=None,
    lm_orbitals=None,
    atoms=None,
    spin=None,
    subplot=False,
    shift=True,
    total_only=False,
    plot_total=True,
    legend_on=True,
    legend_frame_on=False,
    legend_cutoff=3.0,
    gaussian=None,
    colours=None,
    height=6.0,
    width=8.0,
    xmin=-6.0,
    xmax=6.0,
    num_columns=2,
    xlabel="Energy (eV)",
    ylabel="DOS",
    yscale=1,
    zero_energy=None,
    zero_line=False,
    style=None,
    no_base_style=False,
    image_format="pdf",
    dpi=400,
    plt=None,
    fonts=None,
):
    """Plot the density of states from a VASP, CASTEP or Questaal calculation.

    Args:
        filename (:obj:`str`, optional): Path to a DOS data file (can be
            gzipped). The preferred file type depends on the electronic
            structure code: vasprun.xml (VASP); *.bands (CASTEP); dos.*
            (Questaal). If not given, a suitable file is searched for in the
            current directory.
        code (:obj:`str`, optional): Electronic structure code used ('vasp',
            'castep' or 'questaal'; case-insensitive). For CASTEP the total
            DOS is a rough estimate assembled by sampling the eigenvalues in
            the .bands file; a projected DOS is only included when matching
            .pdos_bin and .cell files are found next to it.
        prefix (:obj:`str`, optional): Prefix for file names.
        directory (:obj:`str`, optional): The directory in which to save files.
        elements (:obj:`dict`, optional): The elements and orbitals to extract
            from the projected density of states. Should be provided as a
            :obj:`dict` with the keys as the element names and corresponding
            values as a :obj:`tuple` of orbitals. For example, the following
            would extract the Bi s, px, py and d orbitals::

                {'Bi': ('s', 'px', 'py', 'd')}

            If an element is included with an empty :obj:`tuple`, all orbitals
            for that species will be extracted. If ``elements`` is not set or
            set to ``None``, all orbitals for all elements will be extracted.
        lm_orbitals (:obj:`dict`, optional): The orbitals to decompose into
            their lm contributions (e.g. p -> px, py, pz). Should be provided
            as a :obj:`dict`, with the elements names as keys and a
            :obj:`tuple` of orbitals as the corresponding values. For example,
            the following would be used to decompose the oxygen p and d
            orbitals::

                {'O': ('p', 'd')}

        atoms (:obj:`dict`, optional): Which atomic sites to use when
            calculating the projected density of states. Should be provided as
            a :obj:`dict`, with the element names as keys and a :obj:`tuple` of
            :obj:`int` specifying the atomic indices as the corresponding
            values. The elemental projected density of states will be summed
            only over the atom indices specified. If an element is included
            with an empty :obj:`tuple`, then all sites for that element will
            be included. The indices are 0 based for each element specified in
            the POSCAR. For example, the following will calculate the density
            of states for the first 4 Sn atoms and all O atoms in the
            structure::

                {'Sn': (0, 1, 2, 3), 'O': ()}

            If ``atoms`` is not set or set to ``None`` then all atomic sites
            for all elements will be considered.
        spin (:obj:`str`, optional): Plot only one spin channel from a
            spin-polarised calculation; "up" or "1" for spin up only, "down" or
            "-1" for spin down only. Defaults to ``None``.
        subplot (:obj:`bool`, optional): Plot the density of states for each
            element on separate subplots. Defaults to ``False``.
        shift (:obj:`bool`, optional): Shift the energies such that the valence
            band maximum (or Fermi level for metals) is at 0 eV. Defaults to
            ``True``. Only a warning is logged for Questaal, where the shift is
            not implemented.
        total_only (:obj:`bool`, optional): Only extract the total density of
            states. Defaults to ``False``.
        plot_total (:obj:`bool`, optional): Plot the total density of states.
            Defaults to ``True``.
        legend_on (:obj:`bool`, optional): Plot the graph legend. Defaults
            to ``True``.
        legend_frame_on (:obj:`bool`, optional): Plot a frame around the
            graph legend. Defaults to ``False``.
        legend_cutoff (:obj:`float`, optional): The cut-off (in % of the
            maximum density of states within the plotting range) for an
            elemental orbital to be labelled in the legend. This prevents
            the legend from containing labels for orbitals that have very
            little contribution in the plotting range.
        gaussian (:obj:`float`, optional): Broaden the density of states using
            convolution with a gaussian function. This parameter controls the
            sigma or standard deviation of the gaussian distribution.
        colours (:obj:`dict`, optional): Use custom colours for specific
            element and orbital combinations. Specified as a :obj:`dict` of
            :obj:`dict` of the colours. For example::

                {
                    'Sn': {'s': 'r', 'p': 'b'},
                    'O': {'s': '#000000'}
                }

            The colour can be a hex code, series of rgb value, or any other
            format supported by matplotlib.
        height (:obj:`float`, optional): The height of the plot.
        width (:obj:`float`, optional): The width of the plot.
        xmin (:obj:`float`, optional): The minimum energy on the x-axis.
        xmax (:obj:`float`, optional): The maximum energy on the x-axis.
        num_columns (:obj:`int`, optional): The number of columns in the
            legend.
        xlabel (:obj:`str`, optional): Label/units for x-axis (i.e. energy)
        ylabel (:obj:`str`, optional): Label/units for y-axis (i.e. DOS)
        yscale (:obj:`float`, optional): Scaling factor for the y-axis.
        zero_energy (:obj:`float`, optional): Zero energy reference (e.g. Fermi
            energy from sc-fermi.) If not given, behaviour is determined by
            boolean ``shift``.
        zero_line (:obj:`bool`, optional): Plot vertical line at energy zero.
        style (:obj:`list` or :obj:`str`, optional): (List of) matplotlib style
            specifications, to be composed on top of Sumo base style.
        no_base_style (:obj:`bool`, optional): Prevent use of sumo base style.
            This can make alternative styles behave more predictably.
        image_format (:obj:`str`, optional): The image file format. Can be any
            format supported by matplotlib, including: png, jpg, pdf, and svg.
            Defaults to pdf.
        dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
            the image.
        plt (:obj:`matplotlib.pyplot`, optional): A
            :obj:`matplotlib.pyplot` object to use for plotting. When given,
            nothing is written to disk and the object is returned.
        fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
            single font, specified as a :obj:`str`, or several fonts,
            specified as a :obj:`list` of :obj:`str`.

    Returns:
        If ``plt`` was provided, the :obj:`matplotlib.pyplot` object with the
        DOS drawn on it. Otherwise the image (``[prefix_]dos.<image_format>``)
        and the DOS data files are written and ``None`` is returned. ``None``
        is also returned (after logging an error) for an unrecognised
        ``code``.

    Raises:
        ValueError: If ``code`` is 'questaal', no ``filename`` is given and no
            suitable ``dos.*`` file exists in the current directory.
        SystemExit: If the VASP or CASTEP input file cannot be located.
    """
    code_name = code.lower()

    if code_name == "vasp":
        if not filename:
            if os.path.exists("vasprun.xml"):
                filename = "vasprun.xml"
            elif os.path.exists("vasprun.xml.gz"):
                filename = "vasprun.xml.gz"
            else:
                logging.error("ERROR: No vasprun.xml found!")
                sys.exit()

        dos, pdos = load_dos(filename, elements, lm_orbitals, atoms, gaussian,
                             total_only)

    elif code_name == "castep":

        if filename:
            bands_file = filename
        else:
            band_candidates = glob("*.bands")
            if len(band_candidates) == 0:
                logging.error("ERROR: No *.bands file found!")
                sys.exit()
            elif len(band_candidates) == 1:
                bands_file = band_candidates[0]
            else:
                logging.error("ERROR: Too many *.bands files found!")
                sys.exit()
        pdos_file = _replace_ext(bands_file, "pdos_bin")
        cell_file = _replace_ext(bands_file, "cell")
        pdos_file = pdos_file if os.path.isfile(pdos_file) else None
        cell_file = cell_file if os.path.isfile(cell_file) else None

        if not total_only:
            # Check if both pdos_bin and cell files are present.
            # If not, we cannot plot the PDOS
            if pdos_file is not None:
                if cell_file is None:
                    logging.info("Plotting PDOS requires the .cell file to be "
                                 "present; falling back to TDOS.")
                    pdos_file = None
                else:
                    logging.info(f"Found PDOS binary file {pdos_file}; "
                                 "including PDOS in the plot.")
            else:
                logging.info("PDOS not available, falling back to TDOS.")

        dos, pdos = sumo.io.castep.read_dos(
            bands_file,
            pdos_file=pdos_file,
            cell_file=cell_file,
            gaussian=gaussian,
            emin=xmin,
            emax=xmax,
            lm_orbitals=lm_orbitals,
            elements=elements,
            total_only=total_only,
            atoms=atoms,
        )

    elif code_name == "questaal":
        if filename:
            pdos_file = filename
        else:
            # Skip previously written plots; sort so that the choice is
            # deterministic when several dos.* files are present.
            image_exts = ("pdf", "png", "svg", "jpg", "jpeg")
            pdos_candidates = [
                candidate
                for candidate in sorted(glob("dos.*"))
                if candidate.split(".")[-1].lower() not in image_exts
            ]
            if not pdos_candidates:
                raise ValueError("No questaal dos file found")
            pdos_file = pdos_candidates[0]

        # Same extension rule for user-supplied and discovered files, so that
        # e.g. dos.mno.gz pairs with tdos.mno / site.mno rather than tdos.gz.
        ext = _questaal_ext(pdos_file)

        tdos_file = f"tdos.{ext}" if os.path.exists(f"tdos.{ext}") else None
        site_file = f"site.{ext}" if os.path.exists(f"site.{ext}") else None

        if shift:
            logging.warning(
                "Fermi level shift requested, but not implemented for Questaal DOS."
            )

        dos, pdos = sumo.io.questaal.read_dos(
            pdos_file=pdos_file,
            tdos_file=tdos_file,
            site_file=site_file,
            ry=True,
            gaussian=gaussian,
            total_only=total_only,
            elements=elements,
            lm_orbitals=lm_orbitals,
            atoms=atoms,
        )

    else:
        logging.error(f"ERROR: Unrecognised code: {code}")
        return

    save_files = False if plt else True  # don't save if pyplot object provided

    spin = string_to_spin(spin)  # Convert spin name to pymatgen Spin object
    plotter = SDOSPlotter(dos, pdos)

    plt = plotter.get_plot(
        subplot=subplot,
        width=width,
        height=height,
        xmin=xmin,
        xmax=xmax,
        yscale=yscale,
        colours=colours,
        plot_total=plot_total,
        legend_on=legend_on,
        num_columns=num_columns,
        legend_frame_on=legend_frame_on,
        xlabel=xlabel,
        ylabel=ylabel,
        zero_line=zero_line,
        zero_to_efermi=shift,
        zero_energy=zero_energy,
        legend_cutoff=legend_cutoff,
        dpi=dpi,
        plt=plt,
        fonts=fonts,
        style=style,
        no_base_style=no_base_style,
        spin=spin,
    )

    if save_files:
        basename = f"dos.{image_format}"
        filename = f"{prefix}_{basename}" if prefix else basename
        if directory:
            filename = os.path.join(directory, filename)
        plt.savefig(filename,
                    format=image_format,
                    dpi=dpi,
                    bbox_inches="tight")
        write_files(dos, pdos, prefix=prefix, directory=directory)
    else:
        return plt