Beispiel #1
0
    def test_sanitise_label(self):
        """Check ``SBSPlotter._sanitise_label`` against known label pairs.

        The expectations encode that trailing ``@`` characters are removed,
        an ``@`` inside a label is kept, and a label that starts with ``@``
        maps to ``None``.
        """
        cases = (
            ("X", "X"),
            ("X 1", "X 1"),
            ("X @", "X "),
            ("X@", "X"),
            ("X@@@", "X"),
            ("HEX", "HEX"),
            ("HE@X", "HE@X"),
            ("HE@X 1", "HE@X 1"),
            ("@", None),
            ("@X", None),
            ("@HEX", None),
        )
        for raw, expected in cases:
            self.assertEqual(SBSPlotter._sanitise_label(raw), expected)
Beispiel #2
0
    def test_sanitise_label_group(self):
        """Check ``SBSPlotter._sanitise_label_group`` on single and joined labels.

        Single labels are expected to behave as in the plain label test;
        for labels joined with ``$\\mid$`` the expectations encode that
        members starting with ``@`` are dropped, trailing ``@`` is stripped
        from the remaining members, and ``None`` results when no member
        survives.
        """
        single_cases = (
            ("X", "X"),
            ("X 1", "X 1"),
            ("X @", "X "),
            ("X@", "X"),
            ("X@@@", "X"),
            ("HEX", "HEX"),
            ("HE@X", "HE@X"),
            ("HE@X 1", "HE@X 1"),
            ("@", None),
            ("@X", None),
            ("@HEX", None),
        )
        grouped_cases = (
            (r"X$\mid$Y", r"X$\mid$Y"),
            (r"X$\mid$Y$\mid$Z", r"X$\mid$Y$\mid$Z"),
            (r"@X$\mid$Y$\mid$Z", r"Y$\mid$Z"),
            (r"X$\mid$@Y$\mid$Z", r"X$\mid$Z"),
            (r"X$\mid$@Z", r"X"),
            (r"X@$\mid$Y", r"X$\mid$Y"),
            (r"X@$\mid$Y@@", r"X$\mid$Y"),
            (r"@X@$\mid$Y@", r"Y"),
            (r"X@$\mid$@Y", r"X"),
            (r"@X@$\mid$@Y", None),
        )
        for raw, expected in single_cases + grouped_cases:
            sanitised = SBSPlotter._sanitise_label_group(raw)
            self.assertEqual(sanitised, expected)
Beispiel #3
0
    def process_item(self, mat):
        """
        Build the electronic structure document for a single material.

        The full band structure + DOS plot is generated only when both a
        band structure and a DOS are available. The reduced band structure
        plot needs only the band structure, so it is attempted independently
        of whether the DOS steps succeeded.

        Args:
            mat (dict): material document. Read here are
                ``mat[self.materials.key]`` and the ``"bandstructure"``
                sub-document (``"bs"``, ``"dos"`` and, when present, the
                ``"bs_task"``, ``"dos_task"`` and ``"uniform_task"`` keys).

        Returns:
            (dict): electronic_structure document keyed by
            ``self.electronic_structure.key``; may additionally contain
            ``"plot"``, ``"bs_plot_small"`` and the task id entries.
        """
        mat_id = mat[self.materials.key]
        d = {self.electronic_structure.key: mat_id}
        self.logger.info("Processing: {}".format(mat_id))

        bs = build_bs(mat["bandstructure"]["bs"], mat)
        dos = CompleteDos.from_dict(mat["bandstructure"]["dos"])

        # Bound up front so the reduced-plot section below never hits an
        # unbound name when the DOS is missing or a DOS step raises.
        bs_plotter = None

        if bs and dos:
            try:
                pdos = get_pdos(dos)
                dos_plotter = SDOSPlotter(dos, pdos)
                bs_plotter = SBSPlotter(bs)
                plt = bs_plotter.get_plot(dos_plotter=dos_plotter,
                                          **self.plot_options)
                d["plot"] = image_from_plot(plt)
                plt.close()
            except Exception:
                # Best effort: a failed plot must not abort the whole item.
                traceback.print_exc()
                self.logger.warning(
                    "Caught error in bandstructure plotting for {}: {}".format(
                        mat_id, traceback.format_exc()))

        # Reduced Band structure plot
        try:
            gap = bs.get_band_gap()["energy"]
            if bs_plotter is None:
                # The combined plot was skipped or failed before the plotter
                # was created; the reduced plot only needs the band structure.
                bs_plotter = SBSPlotter(bs)
            plot_data = bs_plotter.bs_plot_data()
            d["bs_plot_small"] = get_small_plot(plot_data, gap)
        except Exception:
            self.logger.warning(
                "Caught error in generating reduced bandstructure plot for {}: {}"
                .format(mat_id, traceback.format_exc()))

        # Store task_ids
        for k in ["bs_task", "dos_task", "uniform_task"]:
            if k in mat["bandstructure"]:
                d[k] = mat["bandstructure"][k]

        return d
Beispiel #4
0
 def get_bs_plotter(self, line_density=100, kpath=None):
     """Return an ``SBSPlotter`` for the interpolated line-mode band structure.

     Args:
         line_density: Forwarded to the interpolater as ``line_density``.
         kpath: Forwarded to the interpolater as ``kpath``.

     Returns:
         An ``SBSPlotter`` wrapping the band structure produced by
         ``self.interpolater.get_line_mode_band_structure``, which is also
         given ``self.symprec`` and ``self.energy_cutoff``.
     """
     interpolate = self.interpolater.get_line_mode_band_structure
     interpolation_options = {
         "line_density": line_density,
         "kpath": kpath,
         "symprec": self.symprec,
         "energy_cutoff": self.energy_cutoff,
     }
     band_structure = interpolate(**interpolation_options)
     return SBSPlotter(band_structure)
Beispiel #5
0
    def test_sanitise_label(self):
        """Verify ``SBSPlotter._sanitise_label`` for a table of labels.

        Keys are the raw labels, values the expected sanitised result
        (``None`` for labels that begin with ``@``).
        """
        expected_by_label = {
            "X": "X",
            "X 1": "X 1",
            "X @": "X ",
            "X@": "X",
            "X@@@": "X",
            "HEX": "HEX",
            "HE@X": "HE@X",
            "HE@X 1": "HE@X 1",
            "@": None,
            "@X": None,
            "@HEX": None,
        }
        for raw_label, expected in expected_by_label.items():
            sanitised = SBSPlotter._sanitise_label(raw_label)
            self.assertEqual(sanitised, expected)
Beispiel #6
0
    def test_sanitise_label_group(self):
        """Verify ``SBSPlotter._sanitise_label_group`` for a table of labels.

        Keys are the raw (possibly ``$\\mid$``-joined) labels, values the
        expected sanitised result; ``None`` means nothing is left to show.
        """
        expected_by_label = {
            # single labels
            "X": "X",
            "X 1": "X 1",
            "X @": "X ",
            "X@": "X",
            "X@@@": "X",
            "HEX": "HEX",
            "HE@X": "HE@X",
            "HE@X 1": "HE@X 1",
            "@": None,
            "@X": None,
            "@HEX": None,
            # labels joined with $\mid$
            r"X$\mid$Y": r"X$\mid$Y",
            r"X$\mid$Y$\mid$Z": r"X$\mid$Y$\mid$Z",
            r"@X$\mid$Y$\mid$Z": r"Y$\mid$Z",
            r"X$\mid$@Y$\mid$Z": r"X$\mid$Z",
            r"X$\mid$@Z": r"X",
            r"X@$\mid$Y": r"X$\mid$Y",
            r"X@$\mid$Y@@": r"X$\mid$Y",
            r"@X@$\mid$Y@": r"Y",
            r"X@$\mid$@Y": r"X",
            r"@X@$\mid$@Y": None,
        }
        for raw_label, expected in expected_by_label.items():
            self.assertEqual(SBSPlotter._sanitise_label_group(raw_label),
                             expected)
Beispiel #7
0
def bandplot(
    filenames=None,
    code="vasp",
    prefix=None,
    directory=None,
    vbm_cbm_marker=False,
    projection_selection=None,
    mode="rgb",
    normalise="all",
    interpolate_factor=4,
    circle_size=150,
    dos_file=None,
    cart_coords=False,
    scissor=None,
    ylabel="Energy (eV)",
    dos_label=None,
    elements=None,
    lm_orbitals=None,
    atoms=None,
    spin=None,
    total_only=False,
    plot_total=True,
    legend_cutoff=3,
    gaussian=None,
    height=None,
    width=None,
    ymin=-6.0,
    ymax=6.0,
    colours=None,
    yscale=1,
    style=None,
    no_base_style=False,
    image_format="pdf",
    dpi=400,
    plt=None,
    fonts=None,
):
    """Plot electronic band structure diagrams from VASP, CASTEP or Questaal output.

    Args:
        filenames (:obj:`str` or :obj:`list`, optional): Path to input files:

            Vasp:
                Use vasprun.xml or vasprun.xml.gz file.
            Questaal:
                Path to a bnds.ext file. The extension will also be used to
                find site.ext and syml.ext files in the same directory.
            Castep:
                Path to a seedname.bands file. The prefix ("seedname") is used
                to locate a seedname.cell file in the same directory and read
                in the positions of high-symmetry points.

            If no filenames are provided, sumo
            will search for vasprun.xml or vasprun.xml.gz files in folders
            named 'split-0*'. Failing that, the code will look for a vasprun in
            the current directory. If a :obj:`list` of vasprun files is
            provided, these will be combined into a single band structure.

        code (:obj:`str`, optional): Calculation type. Default is 'vasp';
            'questaal' and 'castep' also supported (with a reduced
            feature-set).
        prefix (:obj:`str`, optional): Prefix for file names.
        directory (:obj:`str`, optional): The directory in which to save files.
        vbm_cbm_marker (:obj:`bool`, optional): Plot markers to indicate the
            VBM and CBM locations.
        projection_selection (list): A list of :obj:`tuple` or :obj:`string`
            identifying which elements and orbitals to project on to the
            band structure. These can be specified by both element and
            orbital, for example, the following will project the Bi s, p
            and S p orbitals::

                [('Bi', 's'), ('Bi', 'p'), ('S', 'p')]

            If just the element is specified then all the orbitals of
            that element are combined. For example, to sum all the S
            orbitals::

                [('Bi', 's'), ('Bi', 'p'), 'S']

            You can also choose to sum particular orbitals by supplying a
            :obj:`tuple` of orbitals. For example, to sum the S s, p, and
            d orbitals into a single projection::

                [('Bi', 's'), ('Bi', 'p'), ('S', ('s', 'p', 'd'))]

            If ``mode = 'rgb'``, a maximum of 3 orbital/element
            combinations can be plotted simultaneously (one for red, green
            and blue), otherwise an unlimited number of elements/orbitals
            can be selected.
        mode (:obj:`str`, optional): Type of projected band structure to
            plot. Options are:

                "rgb"
                    The band structure line color depends on the character
                    of the band. Each element/orbital contributes either
                    red, green or blue with the corresponding line colour a
                    mixture of all three colours. This mode only supports
                    up to 3 elements/orbitals combinations. The order of
                    the ``selection`` :obj:`tuple` determines which colour
                    is used for each selection.
                "stacked"
                    The element/orbital contributions are drawn as a
                    series of stacked circles, with the colour depending on
                    the composition of the band. The size of the circles
                    can be scaled using the ``circle_size`` option.

        normalise (:obj:`str`, optional): Normalisation of the projections.
            Options are:

              * ``'all'``: Projections normalised against the sum of all
                   other projections.
              * ``'select'``: Projections normalised against the sum of the
                   selected projections.
              * ``None``: No normalisation performed.

        interpolate_factor (:obj:`int`, optional): Interpolation factor
            forwarded to the projected band structure plot. Only used when
            ``projection_selection`` is set.
        circle_size (:obj:`float`, optional): The area of the circles used
            when ``mode = 'stacked'``.
        cart_coords (:obj:`bool`, optional): Whether the k-points are read as
            cartesian or reciprocal coordinates. This is only required for
            Questaal output; Vasp output is less ambiguous. Defaults to
            ``False`` (fractional coordinates).
        scissor (:obj:`float`, optional): Apply a scissor operator (rigid shift
            of the CBM), use with caution if applying to metals.
        ylabel (:obj:`str`, optional): Label for the y-axis, forwarded to the
            plotter. Defaults to ``'Energy (eV)'``.
        dos_label (:obj:`str`, optional): Forwarded to the plotter as
            ``dos_label`` (presumably the label of the DOS axis -- confirm
            against the plotter documentation).
        dos_file (:obj:'str', optional): Path to vasprun.xml file from which to
            read the density of states information. If set, the density of
            states will be plotted alongside the bandstructure.
        elements (:obj:`dict`, optional): The elements and orbitals to extract
            from the projected density of states. Should be provided as a
            :obj:`dict` with the keys as the element names and corresponding
            values as a :obj:`tuple` of orbitals. For example, the following
            would extract the Bi s, px, py and d orbitals::

                {'Bi': ('s', 'px', 'py', 'd')}

            If an element is included with an empty :obj:`tuple`, all orbitals
            for that species will be extracted. If ``elements`` is not set or
            set to ``None``, all elements for all species will be extracted.
        lm_orbitals (:obj:`dict`, optional): The orbitals to decompose into
            their lm contributions (e.g. p -> px, py, pz). Should be provided
            as a :obj:`dict`, with the elements names as keys and a
            :obj:`tuple` of orbitals as the corresponding values. For example,
            the following would be used to decompose the oxygen p and d
            orbitals::

                {'O': ('p', 'd')}

        atoms (:obj:`dict`, optional): Which atomic sites to use when
            calculating the projected density of states. Should be provided as
            a :obj:`dict`, with the element names as keys and a :obj:`tuple` of
            :obj:`int` specifying the atomic indices as the corresponding
            values. The elemental projected density of states will be summed
            only over the atom indices specified. If an element is included
            with an empty :obj:`tuple`, then all sites for that element will
            be included. The indices are 0 based for each element specified in
            the POSCAR. For example, the following will calculate the density
            of states for the first 4 Sn atoms and all O atoms in the
            structure::

                {'Sn': (1, 2, 3, 4), 'O': ()}

            If ``atoms`` is not set or set to ``None`` then all atomic sites
            for all elements will be considered.
        spin (:obj:`Spin`, optional): Plot only one spin channel from a
            spin-polarised calculation; "up" or "1" for spin up only, "down" or
            "-1" for spin down only. Defaults to ``None``.
        total_only (:obj:`bool`, optional): Only extract the total density of
            states. Defaults to ``False``.
        plot_total (:obj:`bool`, optional): Plot the total density of states.
            Defaults to ``True``.
        legend_cutoff (:obj:`float`, optional): The cut-off (in % of the
            maximum density of states within the plotting range) for an
            elemental orbital to be labelled in the legend. This prevents
            the legend from containing labels for orbitals that have very
            little contribution in the plotting range.
        gaussian (:obj:`float`, optional): Broaden the density of states using
            convolution with a gaussian function. This parameter controls the
            sigma or standard deviation of the gaussian distribution.
        height (:obj:`float`, optional): The height of the plot.
        width (:obj:`float`, optional): The width of the plot.
        ymin (:obj:`float`, optional): The minimum energy on the y-axis.
        ymax (:obj:`float`, optional): The maximum energy on the y-axis.
        style (:obj:`list` or :obj:`str`, optional): (List of) matplotlib style
            specifications, to be composed on top of Sumo base style.
        no_base_style (:obj:`bool`, optional): Prevent use of sumo base style.
            This can make alternative styles behave more predictably.
        colours (:obj:`dict`, optional): Use custom colours for specific
            element and orbital combinations. Specified as a :obj:`dict` of
            :obj:`dict` of the colours. For example::

                {
                    'Sn': {'s': 'r', 'p': 'b'},
                    'O': {'s': '#000000'}
                }

            The colour can be a hex code, series of rgb value, or any other
            format supported by matplotlib.
        yscale (:obj:`float`, optional): Scaling factor for the y-axis.
        image_format (:obj:`str`, optional): The image file format. Can be any
            format supported by matplotlib, including: png, jpg, pdf, and svg.
            Defaults to pdf.
        dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
            the image.
        plt (:obj:`matplotlib.pyplot`, optional): A
            :obj:`matplotlib.pyplot` object to use for plotting.
        fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
            a single font, specified as a :obj:`str`, or several fonts,
            specified as a :obj:`list` of :obj:`str`.

    Returns:
        If ``plt`` set then the ``plt`` object will be returned. Otherwise, the
        method will return a :obj:`list` of filenames written to disk.

    Raises:
        ValueError: If ``code`` is not one of the supported codes, or if
            ``dos_file`` is given for a code without DOS support here.
        OSError: If the Questaal site file needed for the lattice is missing.
        SystemExit: If an unsupported combination of projection options is
            requested (split mode with DOS, or more than 3 rgb projections).
    """
    if not filenames:
        filenames = find_vasprun_files()
    elif isinstance(filenames, str):
        filenames = [filenames]

    # only load the orbital projects if we definitely need them
    parse_projected = bool(projection_selection)

    # now load all the band structure data and combine using the
    # get_reconstructed_band_structure function from pymatgen
    bandstructures = []
    if code == "vasp":
        for vr_file in filenames:
            vr = BSVasprun(vr_file, parse_projected_eigen=parse_projected)
            bs = vr.get_band_structure(line_mode=True)
            bandstructures.append(bs)
        bs = get_reconstructed_band_structure(bandstructures)
    elif code == "castep":
        for bands_file in filenames:
            cell_file = _replace_ext(bands_file, "cell")
            if os.path.isfile(cell_file):
                logging.info(f"Found cell file {cell_file}...")
            else:
                logging.info(f"Did not find cell file {cell_file}...")
                cell_file = None
            bs = castep_band_structure(bands_file, cell_file=cell_file)
            bandstructures.append(bs)
        bs = get_reconstructed_band_structure(bandstructures)
    elif code == "questaal":
        bnds_file = filenames[0]
        ext = bnds_file.split(".")[-1]
        # "<file>/.." is collapsed to the containing folder by abspath below
        bnds_folder = os.path.join(bnds_file, os.path.pardir)

        site_file = os.path.abspath(os.path.join(bnds_folder, f"site.{ext}"))

        if os.path.isfile(site_file):
            logging.info("site file found, reading lattice...")
            site_data = QuestaalSite.from_file(site_file)
            bnds_lattice = site_data.structure.lattice
            alat = site_data.alat
        else:
            raise OSError(
                "Site file {} not found: "
                "needed to determine lattice".format(site_file)
            )

        syml_file = os.path.abspath(os.path.join(bnds_folder, f"syml.{ext}"))
        if os.path.isfile(syml_file):
            logging.info("syml file found, reading special-point labels...")
            bnds_labels = labels_from_syml(syml_file)
        else:
            logging.info("syml file not found, band structure lacks labels")
            bnds_labels = {}

        bs = questaal_band_structure(
            bnds_file,
            bnds_lattice,
            alat=alat,
            labels=bnds_labels,
            coords_are_cartesian=cart_coords,
        )
    else:
        # Fail early with a clear message instead of an unbound ``bs`` later.
        raise ValueError(
            f"Unsupported code '{code}': expected 'vasp', 'castep' or "
            "'questaal'."
        )

    # currently not supported as it is a pain to make subplots within subplots,
    # although need to check this is still the case
    if "split" in mode and dos_file:
        logging.error(
            "ERROR: Plotting split projected band structure with DOS"
            " not supported.\nPlease use --projected-rgb or "
            "--projected-stacked options."
        )
        sys.exit()

    if projection_selection and mode == "rgb" and len(projection_selection) > 3:
        logging.error(
            "ERROR: RGB projected band structure only "
            "supports up to 3 elements/orbitals."
            "\nUse alternative --mode setting."
        )
        sys.exit()

    # don't save if pyplot object provided
    save_files = not plt

    dos_plotter = None
    dos_opts = None
    if dos_file:
        if code == "vasp":
            dos, pdos = load_dos(
                dos_file, elements, lm_orbitals, atoms, gaussian, total_only
            )
        elif code == "castep":
            # ``cell_file`` is the one found (or not) for the last bands file
            pdos_file = None
            if cell_file:
                pdos_candidate = _replace_ext(cell_file, "pdos_bin")
                if os.path.isfile(pdos_candidate):
                    pdos_file = pdos_candidate
                    logging.info(f"Found PDOS file {pdos_file}")
                else:
                    # Report the path that was actually looked for.
                    logging.info(
                        f"PDOS file {pdos_candidate} does not exist, "
                        "falling back to TDOS."
                    )
            else:
                logging.info("No cell file available, cannot plot PDOS.")

            dos, pdos = read_castep_dos(
                dos_file,
                pdos_file=pdos_file,
                cell_file=cell_file,
                gaussian=gaussian,
                lm_orbitals=lm_orbitals,
                elements=elements,
                efermi_to_vbm=True,
            )
        else:
            # No DOS loader is wired up for the remaining codes.
            raise ValueError(
                f"Plotting a DOS alongside the band structure is not "
                f"supported for code '{code}'."
            )

        dos_plotter = SDOSPlotter(dos, pdos)
        dos_opts = {
            "plot_total": plot_total,
            "legend_cutoff": legend_cutoff,
            "colours": colours,
            "yscale": yscale,
        }

    if scissor:
        bs = bs.apply_scissor(scissor)

    spin = string_to_spin(spin)  # Convert spin name to pymatgen Spin object
    plotter = SBSPlotter(bs)

    # Keyword arguments shared by the plain and the projected plot.
    plot_opts = {
        "zero_to_efermi": True,
        "ymin": ymin,
        "ymax": ymax,
        "height": height,
        "width": width,
        "vbm_cbm_marker": vbm_cbm_marker,
        "ylabel": ylabel,
        "plt": plt,
        "dos_plotter": dos_plotter,
        "dos_options": dos_opts,
        "dos_label": dos_label,
        "fonts": fonts,
        "style": style,
        "no_base_style": no_base_style,
        "spin": spin,
    }
    if projection_selection:
        plt = plotter.get_projected_plot(
            projection_selection,
            mode=mode,
            normalise=normalise,
            interpolate_factor=interpolate_factor,
            circle_size=circle_size,
            **plot_opts,
        )
    else:
        plt = plotter.get_plot(**plot_opts)

    if not save_files:
        return plt

    basename = f"band.{image_format}"
    filename = f"{prefix}_{basename}" if prefix else basename
    if directory:
        filename = os.path.join(directory, filename)
    plt.savefig(filename, format=image_format, dpi=dpi, bbox_inches="tight")

    written = [filename]
    written += save_data_files(bs, prefix=prefix, directory=directory)
    return written
Beispiel #8
0
    def plot_bs(db_credentials,
                material_id,
                task_query=None,
                filename=None,
                **plot_kwargs):
        """Plot the band structure (plus DOS when available) of a stored material.

        The material is looked up in the ``materials`` collection, its
        "nscf line" tasks provide the band structure and its "nscf uniform"
        tasks, if any, provide the DOS. The last matching task is used in
        each case.

        Args:
            db_credentials (dict): Connection settings. The keys read are
                ``"host"``, ``"port"``, ``"database"``, ``"username"`` and
                ``"password"``.
            material_id: Value matched against the ``material_id`` field of
                the materials collection.
            task_query (dict, optional): Extra query terms merged into both
                the band structure and the DOS task queries.
            filename (str, optional): If given, the figure is saved to this
                path.
            **plot_kwargs: Forwarded to ``SBSPlotter.get_plot``. When a DOS
                is plotted, ``dos_aspect`` defaults to 4 and ``width`` to 8
                unless supplied.

        Returns:
            The object returned by ``get_plot`` when ``filename`` is given;
            otherwise the figure as base64-encoded PNG bytes.

        Raises:
            RuntimeError: If the material is not in the database or it has
                no band structure task.
        """
        task_query = task_query or {}

        # get database connections
        # NOTE(review): this raw client connects without authenticating; the
        # username/password are only handed to VaspCalcDb below -- confirm
        # that is intended.
        db = MongoClient(db_credentials["host"],
                         db_credentials["port"])[db_credentials["database"]]
        calc_db = VaspCalcDb(db_credentials["host"], db_credentials["port"],
                             db_credentials["database"], "tasks",
                             db_credentials["username"],
                             db_credentials["password"])

        material_result = db.materials.find_one({"material_id": material_id})
        if not material_result:
            raise RuntimeError(
                "Material id not found in database: {}".format(material_id))

        # task ids are stored as "<prefix>-<number>"; queries use the number
        tasks_ids = [
            int(tid.split("-")[-1])
            for tid in material_result["_tasksbuilder"]["all_task_ids"]
        ]

        # get all band structure tasks
        bs_query = {
            "task_id": {"$in": tasks_ids},
            "task_label": {"$in": ["nscf line"]},
        }
        bs_query.update(task_query)
        bs_tasks = list(db.tasks.find(bs_query))
        if not bs_tasks:
            raise RuntimeError(
                "No band structure available for: {}".format(material_id))

        # get the band structure of the last band structure task
        band_structure = calc_db.get_band_structure(
            task_id=bs_tasks[-1]["task_id"])
        bs_plotter = SBSPlotter(band_structure)

        # get the DOS tasks
        dos_query = {
            "task_id": {"$in": tasks_ids},
            "task_label": {"$in": ["nscf uniform"]},
        }
        dos_query.update(task_query)
        dos_tasks = list(db.tasks.find(dos_query))

        if dos_tasks:
            # get the DOS for the last DOS task
            dos = calc_db.get_dos(task_id=dos_tasks[-1]["task_id"])
            pdos = get_pdos(dos)

            # generate a combined DOS and band structure plot
            dos_plotter = SDOSPlotter(dos, pdos)

            # set some better defaults for BS+DOS plots but don't overwrite
            # user settings
            plot_kwargs.setdefault('dos_aspect', 4)
            plot_kwargs.setdefault('width', 8)

            plt = bs_plotter.get_plot(dos_plotter=dos_plotter, **plot_kwargs)
        else:
            # if no DOS just plot band structure only
            plt = bs_plotter.get_plot(**plot_kwargs)

        if filename:
            plt.savefig(filename, dpi=400, bbox_inches='tight')
            return plt

        figfile = BytesIO()
        plt.savefig(figfile, format='png', dpi=400, bbox_inches='tight')

        # rewind to beginning of file and base64 encode
        figfile.seek(0)
        return base64.b64encode(figfile.getvalue())
Beispiel #9
0
def bandplot(filenames=None,
             prefix=None,
             directory=None,
             vbm_cbm_marker=False,
             projection_selection=None,
             mode='rgb',
             interpolate_factor=4,
             circle_size=150,
             dos_file=None,
             ylabel='Energy (eV)',
             dos_label=None,
             elements=None,
             lm_orbitals=None,
             atoms=None,
             total_only=False,
             plot_total=True,
             legend_cutoff=3,
             gaussian=None,
             height=6.,
             width=6.,
             ymin=-6.,
             ymax=6.,
             colours=None,
             yscale=1,
             image_format='pdf',
             dpi=400,
             plt=None,
             fonts=None):
    """Plot electronic band structure diagrams from vasprun.xml files.

    Args:
        filenames (:obj:`str` or :obj:`list`, optional): Path to vasprun.xml
            or vasprun.xml.gz file. If no filenames are provided, the code
            will search for vasprun.xml or vasprun.xml.gz files in folders
            named 'split-0*'. Failing that, the code will look for a vasprun in
            the current directory. If a :obj:`list` of vasprun files is
            provided, these will be combined into a single band structure.
        prefix (:obj:`str`, optional): Prefix for file names.
        directory (:obj:`str`, optional): The directory in which to save files.
        vbm_cbm_marker (:obj:`bool`, optional): Plot markers to indicate the
            VBM and CBM locations.
        projection_selection (list): A list of :obj:`tuple` or :obj:`string`
            identifying which elements and orbitals to project on to the
            band structure. These can be specified by both element and
            orbital, for example, the following will project the Bi s, p
            and S p orbitals::

                [('Bi', 's'), ('Bi', 'p'), ('S', 'p')]

            If just the element is specified then all the orbitals of
            that element are combined. For example, to sum all the S
            orbitals::

                [('Bi', 's'), ('Bi', 'p'), 'S']

            You can also choose to sum particular orbitals by supplying a
            :obj:`tuple` of orbitals. For example, to sum the S s, p, and
            d orbitals into a single projection::

                [('Bi', 's'), ('Bi', 'p'), ('S', ('s', 'p', 'd'))]

            If ``mode = 'rgb'``, a maximum of 3 orbital/element
            combinations can be plotted simultaneously (one for red, green
            and blue), otherwise an unlimited number of elements/orbitals
            can be selected.
        mode (:obj:`str`, optional): Type of projected band structure to
            plot. Options are:

                "rgb"
                    The band structure line color depends on the character
                    of the band. Each element/orbital contributes either
                    red, green or blue with the corresponding line colour a
                    mixture of all three colours. This mode only supports
                    up to 3 elements/orbitals combinations. The order of
                    the ``selection`` :obj:`tuple` determines which colour
                    is used for each selection.
                "stacked"
                    The element/orbital contributions are drawn as a
                    series of stacked circles, with the colour depending on
                    the composition of the band. The size of the circles
                    can be scaled using the ``circle_size`` option.
        circle_size (:obj:`float`, optional): The area of the circles used
            when ``mode = 'stacked'``.
        dos_file (:obj:'str', optional): Path to vasprun.xml file from which to
            read the density of states information. If set, the density of
            states will be plotted alongside the bandstructure.
        elements (:obj:`dict`, optional): The elements and orbitals to extract
            from the projected density of states. Should be provided as a
            :obj:`dict` with the keys as the element names and corresponding
            values as a :obj:`tuple` of orbitals. For example, the following
            would extract the Bi s, px, py and d orbitals::

                {'Bi': ('s', 'px', 'py', 'd')}

            If an element is included with an empty :obj:`tuple`, all orbitals
            for that species will be extracted. If ``elements`` is not set or
            set to ``None``, all elements for all species will be extracted.
        lm_orbitals (:obj:`dict`, optional): The orbitals to decompose into
            their lm contributions (e.g. p -> px, py, pz). Should be provided
            as a :obj:`dict`, with the elements names as keys and a
            :obj:`tuple` of orbitals as the corresponding values. For example,
            the following would be used to decompose the oxygen p and d
            orbitals::

                {'O': ('p', 'd')}

        atoms (:obj:`dict`, optional): Which atomic sites to use when
            calculating the projected density of states. Should be provided as
            a :obj:`dict`, with the element names as keys and a :obj:`tuple` of
            :obj:`int` specifying the atomic indices as the corresponding
            values. The elemental projected density of states will be summed
            only over the atom indices specified. If an element is included
            with an empty :obj:`tuple`, then all sites for that element will
            be included. The indices are 0 based for each element specified in
            the POSCAR. For example, the following will calculate the density
            of states for the first 4 Sn atoms and all O atoms in the
            structure::

                {'Sn': (1, 2, 3, 4), 'O': (, )}

            If ``atoms`` is not set or set to ``None`` then all atomic sites
            for all elements will be considered.
        total_only (:obj:`bool`, optional): Only extract the total density of
            states. Defaults to ``False``.
        plot_total (:obj:`bool`, optional): Plot the total density of states.
            Defaults to ``True``.
        legend_cutoff (:obj:`float`, optional): The cut-off (in % of the
            maximum density of states within the plotting range) for an
            elemental orbital to be labelled in the legend. This prevents
            the legend from containing labels for orbitals that have very
            little contribution in the plotting range.
        gaussian (:obj:`float`, optional): Broaden the density of states using
            convolution with a gaussian function. This parameter controls the
            sigma or standard deviation of the gaussian distribution.
        height (:obj:`float`, optional): The height of the plot.
        width (:obj:`float`, optional): The width of the plot.
        ymin (:obj:`float`, optional): The minimum energy on the y-axis.
        ymax (:obj:`float`, optional): The maximum energy on the y-axis.
        colours (:obj:`dict`, optional): Use custom colours for specific
            element and orbital combinations. Specified as a :obj:`dict` of
            :obj:`dict` of the colours. For example::

                {
                    'Sn': {'s': 'r', 'p': 'b'},
                    'O': {'s': '#000000'}
                }

            The colour can be a hex code, series of rgb value, or any other
            format supported by matplotlib.
        yscale (:obj:`float`, optional): Scaling factor for the y-axis.
        image_format (:obj:`str`, optional): The image file format. Can be any
            format supported by matplotlib, including: png, jpg, pdf, and svg.
            Defaults to pdf.
        dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
            the image.
        plt (:obj:`matplotlib.pyplot`, optional): A
            :obj:`matplotlib.pyplot` object to use for plotting.
        fonts (:obj:`list`, optional): Fonts to use in the plot. Can be a
            a single font, specified as a :obj:`str`, or several fonts,
            specified as a :obj:`list` of :obj:`str`.

    Returns:
        If ``plt`` set then the ``plt`` object will be returned. Otherwise, the
        method will return a :obj:`list` of filenames written to disk.
    """
    if not filenames:
        filenames = find_vasprun_files()
    elif type(filenames) == str:
        filenames = [filenames]

    # only load the orbital projects if we definitely need them
    parse_projected = True if projection_selection else False

    # now load all the vaspruns and combine them together using the
    # get_reconstructed_band_structure function from pymatgen
    bandstructures = []
    for vr_file in filenames:
        vr = BSVasprun(vr_file, parse_projected_eigen=parse_projected)
        bs = vr.get_band_structure(line_mode=True)
        bandstructures.append(bs)
    bs = get_reconstructed_band_structure(bandstructures)

    # currently not supported as it is a pain to make subplots within subplots,
    # although need to check this is still the case
    if 'split' in mode and dos_file:
        logging.error('ERROR: Plotting split projected band structure with DOS'
                      ' not supported.\nPlease use --projected-rgb or '
                      '--projected-stacked options.')
        sys.exit()

    if (projection_selection and mode == 'rgb'
            and len(projection_selection) > 3):
        logging.error('ERROR: RGB projected band structure only '
                      'supports up to 3 elements/orbitals.'
                      '\nUse alternative --mode setting.')
        sys.exit()

    # don't save if pyplot object provided
    save_files = False if plt else True

    dos_plotter = None
    dos_opts = None
    if dos_file:
        dos, pdos = load_dos(dos_file, elements, lm_orbitals, atoms, gaussian,
                             total_only)
        dos_plotter = SDOSPlotter(dos, pdos)
        dos_opts = {
            'plot_total': plot_total,
            'legend_cutoff': legend_cutoff,
            'colours': colours,
            'yscale': yscale
        }

    plotter = SBSPlotter(bs)
    if projection_selection:
        plt = plotter.get_projected_plot(projection_selection,
                                         mode=mode,
                                         interpolate_factor=interpolate_factor,
                                         circle_size=circle_size,
                                         zero_to_efermi=True,
                                         ymin=ymin,
                                         ymax=ymax,
                                         height=height,
                                         width=width,
                                         vbm_cbm_marker=vbm_cbm_marker,
                                         ylabel=ylabel,
                                         plt=plt,
                                         dos_plotter=dos_plotter,
                                         dos_options=dos_opts,
                                         dos_label=dos_label,
                                         fonts=fonts)
    else:
        plt = plotter.get_plot(zero_to_efermi=True,
                               ymin=ymin,
                               ymax=ymax,
                               height=height,
                               width=width,
                               vbm_cbm_marker=vbm_cbm_marker,
                               ylabel=ylabel,
                               plt=plt,
                               dos_plotter=dos_plotter,
                               dos_options=dos_opts,
                               dos_label=dos_label,
                               fonts=fonts)

    if save_files:
        basename = 'band.{}'.format(image_format)
        filename = '{}_{}'.format(prefix, basename) if prefix else basename
        if directory:
            filename = os.path.join(directory, filename)
        plt.savefig(filename,
                    format=image_format,
                    dpi=dpi,
                    bbox_inches='tight')

        written = [filename]
        written += save_data_files(vr, bs, prefix=prefix, directory=directory)
        return written

    else:
        return plt
Beispiel #10
0
def _load_kpath_labels(path='./kpath'):
    r"""Read an optional k-point path description from a JSON file.

    The file must provide a ``path`` entry (a sequence of branches whose first
    two items are label names) and a ``kpoints_rel`` entry (a mapping keyed by
    label name). Every ``GAMMA`` label is rewritten to ``\Gamma``.

    Args:
        path (:obj:`str`, optional): Location of the JSON file.

    Returns:
        :obj:`tuple`: ``(kpaths, kpoints_lbls_dict)``, or ``(None, None)`` when
        the file is missing, is not valid JSON or lacks the expected layout.
    """
    try:
        # Context manager so the handle is closed even if parsing fails.
        with open(path, 'r') as handle:
            kpath = json.load(handle)

        kpaths = kpath['path']
        for branch in kpaths:
            for j in (0, 1):
                if branch[j] == 'GAMMA':
                    branch[j] = r'\Gamma'

        kpoints_lbls_dict = {}
        for name, value in kpath['kpoints_rel'].items():
            if name == 'GAMMA':
                name = r'\Gamma'
            kpoints_lbls_dict[name] = value
    except (OSError, ValueError, KeyError, IndexError, TypeError,
            AttributeError):
        # Best effort: the kpath file is optional, so any problem reading it
        # falls back to the interpolator's own defaults.
        return None, None

    return kpaths, kpoints_lbls_dict


def _first_label_indices(bs):
    """Locate where each run of a labelled k-point starts along the path.

    Args:
        bs: Band structure object exposing ``kpoints`` (a sequence of objects
            with a ``label`` attribute) and ``labels_dict`` (a mapping whose
            values also carry a ``label`` attribute).

    Returns:
        :obj:`list`: ``[index, label]`` pairs, one for every k-point whose
        label matches an entry of ``labels_dict`` and differs from the label
        of the preceding k-point.
    """
    known_labels = [point.label for point in bs.labels_dict.values()]
    kpoints = bs.kpoints

    found = []
    for i, kpoint in enumerate(kpoints):
        # The first k-point has no predecessor. Comparing it with
        # kpoints[-1] (what a plain ``i - 1`` index does) would drop its
        # label whenever the path starts and ends on the same point.
        if i > 0 and kpoints[i - 1].label == kpoint.label:
            continue
        for label in known_labels:
            if label == kpoint.label:
                found.append([i, label])
    return found


def bandplot_func(
        filenames=None,
        code='vasp',
        prefix=None,
        directory=None,
        vbm_cbm_marker=False,
        projection_selection=None,
        mode='rgb',
        pred=None,
        interpolate_factor=4,
        circle_size=150,
        dos_file=None,
        cart_coords=False,
        scissor=None,
        ylabel='Energy (eV)',
        dos_label=None,
        elements=None,
        lm_orbitals=None,
        atoms=None,
        spin=None,
        total_only=False,
        plot_total=True,
        legend_cutoff=3,
        gaussian=None,
        height=None,
        width=None,
        ymin=-6.,
        ymax=6.,
        colours=None,
        yscale=1,
        style=None,
        no_base_style=False,
        image_format='pdf',
        dpi=400,
        plt=None,
        fonts=None,
        boltz=None,
        nelec=0):
    """Interpolate a DFT band structure and a model-predicted counterpart.

    The first vasprun file is parsed twice. The eigenvalues of the second copy
    are overwritten with ``pred``; both copies are then interpolated with
    :obj:`BztInterpolator` along the path read from ``./kpath`` (if that file
    can be read) and the resulting band arrays are returned.

    Args:
        filenames (:obj:`str` or :obj:`list`, optional): Path(s) to vasprun
            files. If not set, ``find_vasprun_files()`` is used. Only the
            first file is processed.
        code (:obj:`str`, optional): Electronic structure code. Only ``'vasp'``
            is handled; for any other value the function returns ``None``.
        projection_selection (optional): Only its truthiness is used, to decide
            whether projected eigenvalues are parsed.
        pred (:obj:`numpy.ndarray`): Predicted eigenvalues. Required when
            ``code`` is ``'vasp'``. A trailing axis is appended and the array
            is written over the leading bands of every spin channel.
            NOTE(review): presumably shaped ``(kpoints, bands)`` - confirm
            against the caller.
        boltz (:obj:`dict`, optional): Interpolation settings. Missing keys
            fall back to these defaults::

                {'ifinter': 'T', 'lpfac': '10', 'energy_range': '50',
                 'curvature': '', 'load': 'T', 'ismetaltolerance': '0.01'}

            ``ifinter``, ``curvature`` and ``load`` are evaluated by
            truthiness, so any non-empty string (including ``'F'``) counts as
            true.
        nelec (:obj:`int`, optional): Number of electrons, used to derive the
            number of valence bands as
            ``ceil(nelec / (2 if spin polarised else 1))``. With the default
            of ``0`` the valence band index becomes ``-1``, i.e. the last
            band.

        All remaining parameters (plotting, density of states and output
        options) are accepted for interface compatibility and are not used.

    Returns:
        :obj:`tuple`: ``(dft_bands, model_bands, labels)`` where the first two
        items are the band arrays of the first spin channel and ``labels`` is a
        :obj:`list` of ``[kpoint_index, label]`` pairs. The DFT bands are
        shifted by the (possibly adjusted) Fermi level; the model bands are
        returned unshifted. ``None`` is returned if ``code`` is not ``'vasp'``.

    Raises:
        ValueError: If ``pred`` is ``None`` or no vasprun file is available.
        NotImplementedError: If ``boltz['ifinter']`` is falsy.
    """
    # Function-scope import so the block does not depend on the module's
    # top-level imports.
    import logging

    if not filenames:
        filenames = find_vasprun_files()
    elif isinstance(filenames, str):
        filenames = [filenames]

    if code != 'vasp':
        # No other code path is implemented; callers receive None.
        return None

    if pred is None:
        raise ValueError('bandplot_func requires `pred`: the predicted '
                         'eigenvalues to compare against the DFT bands.')

    # Built per call so that no state is shared between invocations.
    settings = {
        "ifinter": "T",
        "lpfac": "10",
        "energy_range": "50",
        "curvature": "",
        "load": "T",
        'ismetaltolerance': '0.01'
    }
    if boltz is not None:
        settings.update(boltz)

    if not settings['ifinter']:
        # Without interpolation no band structure is ever built.
        raise NotImplementedError('bandplot_func only supports interpolated '
                                  'band structures (boltz["ifinter"]).')

    filenames = list(filenames)
    if not filenames:
        raise ValueError('No vasprun files were supplied.')
    if len(filenames) > 1:
        logging.warning('Only the first of %d vasprun files is used: %s',
                        len(filenames), filenames[0])
    vr_file = filenames[0]

    # only load the orbital projections if we definitely need them
    parse_projected = bool(projection_selection)

    vr = BSVasprun(vr_file, parse_projected_eigen=parse_projected)
    model = BSVasprun(vr_file, parse_projected_eigen=parse_projected)

    # The appended axis has length one, so the assignment below broadcasts the
    # prediction over the whole last axis of each eigenvalue array.
    prediction = np.expand_dims(np.asarray(pred), axis=-1)
    for key in model.eigenvalues:
        nbands = min(model.eigenvalues[key].shape[1], prediction.shape[1])
        logging.debug('Overwriting %d bands for %s', nbands, key)
        model.eigenvalues[key][:, :nbands, :] = prediction[:, :nbands, :]

    lpfac = int(settings['lpfac'])
    energy_range = float(settings['energy_range'])
    curvature = bool(settings['curvature'])
    load_interp = bool(settings['load'])

    b_data = VasprunBSLoader(vr)
    model_data = VasprunBSLoader(model)
    # NOTE(review): both interpolators are created with identical save/load
    # flags. If BztInterpolator persists to a shared default file, the model
    # interpolation may reuse the DFT coefficients - confirm against the
    # pymatgen documentation.
    b_inter = BztInterpolator(b_data,
                              lpfac=lpfac,
                              energy_range=energy_range,
                              curvature=curvature,
                              save_bztInterp=True,
                              load_bztInterp=load_interp)
    model_inter = BztInterpolator(model_data,
                                  lpfac=lpfac,
                                  energy_range=energy_range,
                                  curvature=curvature,
                                  save_bztInterp=True,
                                  load_bztInterp=load_interp)

    kpaths, kpoints_lbls_dict = _load_kpath_labels('./kpath')
    logging.debug('kpaths=%s kpoints_lbls_dict=%s', kpaths, kpoints_lbls_dict)

    bs = b_inter.get_band_structure(kpaths=kpaths,
                                    kpoints_lbls_dict=kpoints_lbls_dict)
    model_bs = model_inter.get_band_structure(
        kpaths=kpaths, kpoints_lbls_dict=kpoints_lbls_dict)
    logging.debug('Interpolated gap: %s', bs.get_band_gap())

    # Highest energy of the top valence band over all spin channels.
    nvb = int(np.ceil(nelec / (int(bs.is_spin_polarized) + 1)))
    vbm = -100
    for band_energies in bs.bands.values():
        vbm = max(vbm, max(band_energies[nvb - 1]))
    logging.debug('vasp fermi %s interpolation vbm %s nelec %s nvb %s',
                  bs.efermi, vbm, nelec, nvb)

    # Pull the Fermi level down to the interpolated VBM where it lies above.
    if vbm < bs.efermi:
        bs.efermi = vbm
    if vbm < model_bs.efermi:
        model_bs.efermi = vbm

    first_spin = list(bs.bands.keys())[0]
    # Only the DFT bands of the first spin channel are referenced to the
    # Fermi level. NOTE(review): the model bands are left unshifted,
    # presumably because the prediction is already Fermi-referenced - confirm.
    bs.bands[first_spin] = bs.bands[first_spin] - bs.efermi

    bs = get_reconstructed_band_structure([bs])
    model_bs = get_reconstructed_band_structure([model_bs])

    # Metadata attached to the band structure objects; not part of the
    # returned arrays.
    tolerance = float(settings['ismetaltolerance'])
    bs.nvb = nvb
    bs.ismetaltolerance = tolerance
    model_bs.nvb = nvb
    model_bs.ismetaltolerance = tolerance

    labels = _first_label_indices(bs)
    logging.debug('dft kpoints=%d model kpoints=%d labels=%s',
                  len(bs.kpoints), len(model_bs.kpoints), labels)

    return bs.bands[first_spin], model_bs.bands[first_spin], labels