Example #1
0
def filter_cursor_by_chempots(species, cursor):
    """ Drop structures from cursor that cannot be expressed in terms of the
    requested chemical potentials, and annotate the ones that can.

    Each surviving document is modified in place and gains two keys:
    ``num_chempots``, the value returned by ``get_number_of_chempots``, and
    ``concentration``, all but the last entry of ``num_chempots`` divided by
    the sum of all its entries. Concentrations within ``EPS`` of 0 or 1 are
    snapped to exactly 0.0 or 1.0.

    Parameters:
        species (list): list of chemical potential formulae.
        cursor (list): list of matador documents to filter.

    Returns:
        list: the documents for which ``get_number_of_chempots`` did not raise
        a RuntimeError, in their original order.

    """
    from matador.utils.chem_utils import get_number_of_chempots

    # parse every chemical potential formula once, up front
    chempot_stoichiometries = [get_stoich_from_formula(label) for label in species]

    rejected = set()
    for position, entry in enumerate(cursor):
        try:
            entry['num_chempots'] = get_number_of_chempots(entry, chempot_stoichiometries)
        except RuntimeError:
            # structure contains an element with no matching chemical potential
            rejected.add(position)
            continue

        counts = entry['num_chempots']
        fractions = (counts[:-1] / np.sum(counts)).tolist()
        # snap values that are numerically at the boundaries onto them exactly
        entry['concentration'] = [
            0.0 if frac < 0 + EPS else (1.0 if frac > 1 - EPS else frac)
            for frac in fractions
        ]

    return [entry for position, entry in enumerate(cursor) if position not in rejected]
Example #2
0
    def fake_chempots(self, custom_elem=None):
        """ Spoof documents for command-line chemical potentials.

        One placeholder document is built per formula in ``custom_elem``,
        using the matching energy from ``self.args['chempots']``. The
        documents replace the contents of :attr:`chempot_cursor`, and a
        one-line summary is printed.

        Keyword arguments:
            custom_elem (list(str)): list of chemical potential formulae to
                generate chempots for, defaults to :attr:`species`.

        Raises:
            RuntimeError: if ``custom_elem`` does not have one entry per
                species, or if ``self.args['chempots']`` is missing or does
                not have exactly one energy per formula.

        """
        self.chempot_cursor = []
        print('Generating fake chempots...')

        if custom_elem is None:
            custom_elem = self.species

        chempot_energies = self.args.get('chempots')
        # Every formula needs exactly one energy. Without the last two checks,
        # too few energies would silently yield too few chempots, and too many
        # would raise an unhelpful IndexError.
        if (len(custom_elem) != len(self.species)
                or chempot_energies is None
                or len(chempot_energies) != len(custom_elem)):
            raise RuntimeError(
                'Wrong number of compounds/chemical potentials specified: {} vs {}'
                .format(custom_elem, chempot_energies))

        for i, (formula, energy) in enumerate(zip(custom_elem, chempot_energies)):
            chempot = dict()
            chempot['stoichiometry'] = get_stoich_from_formula(formula)
            # chemical potentials are always stored as negative energies
            chempot[self.energy_key] = -1 * abs(energy)
            # extensive energy = energy_key value times the sum of element counts
            chempot[self._extensive_energy_key] = chempot[self.energy_key] * \
                sum(elem[1] for elem in chempot['stoichiometry'])
            chempot['num_fu'] = 1
            chempot['num_atoms'] = 1
            chempot['text_id'] = ['command', 'line']
            chempot['_id'] = None
            chempot['source'] = ['command_line']
            chempot['space_group'] = 'xxx'
            chempot[self._extensive_energy_key + '_per_b'] = chempot[self.energy_key]
            chempot['num_a'] = 0
            chempot['cell_volume'] = 1
            # one-hot concentration of length dimension - 1; the last chempot is all zeros
            chempot['concentration'] = [
                1 if i == ind else 0 for ind in range(self._dimension - 1)
            ]
            self.chempot_cursor.append(chempot)

        self.chempot_cursor[0]['num_a'] = float('inf')
        notify = 'Custom chempots:'
        for chempot in self.chempot_cursor:
            notify += '{:3} = {} eV/fu, '.format(
                get_formula_from_stoich(chempot['stoichiometry'], sort=False),
                chempot[self._extensive_energy_key])

        if self.args.get('debug'):
            for match in self.chempot_cursor:
                print(match)
        print(len(notify) * '─')
        print(notify)
        print(len(notify) * '─')
Example #3
0
    def _query_stoichiometry(self, custom_stoich=None, partial_formula=None):
        """ Build a database query for a particular stoichiometry.

        Only the first entry of the stoichiometry list is used. Bracketed
        elements (e.g. ``[I]`` or ``[Li,Na]``) are expanded into an ``$or``
        over the members of that periodic-table group or comma-separated list.

        Keyword arguments:
            custom_stoich (list(str)): formula to query; when None, falls back
                to ``self.args['stoichiometry']`` (wrapped in a list if a str).
            partial_formula (bool): when falsy, also require the stoichiometry
                to have exactly as many entries as the parsed formula; defaults
                to ``self.args['partial_formula']``.

        Raises:
            RuntimeError: if the formula contains ':'.

        Returns:
            dict: query dictionary with a single '$and' list of clauses.

        """
        if custom_stoich is not None:
            stoich = custom_stoich
        else:
            stoich = self.args.get('stoichiometry')
            if isinstance(stoich, str):
                stoich = [stoich]

        if partial_formula is None:
            partial_formula = self.args.get('partial_formula')

        formula = stoich[0]
        if ':' in formula:
            raise RuntimeError('Formula cannot contain ":", you probably meant to query composition.')

        parsed = get_stoich_from_formula(formula, sort=False)

        def _stoich_clause(element, count):
            """ Clause matching documents that contain [element, count]. """
            return {'stoichiometry': {'$in': [[element, count]]}}

        clauses = []
        for entry in parsed:
            raw_elem, count = entry[0], int(entry[1])

            if '[' not in raw_elem and ']' not in raw_elem:
                clauses.append(_stoich_clause(raw_elem, count))
                continue

            group = raw_elem.strip('[').strip(']')
            if group in self._periodic_table:
                members = self._periodic_table[group]
            elif ',' in group:
                members = group.split(',')
            else:
                # unrecognised bracketed group: no clause is added
                continue
            clauses.append({'$or': [_stoich_clause(member, count) for member in members]})

        query_dict = {'$and': clauses}
        if not partial_formula:
            query_dict['$and'].append({'stoichiometry': {'$size': len(parsed)}})

        return query_dict
Example #4
0
def get_single_structure(request: Request,
                         entry_id: str,
                         params: SingleEntryQueryParams = Depends()):
    """ Render the HTML page for a single structure.

    Parameters:
        request: the incoming request.
        entry_id: the ID of the structure to display.
        params: single-entry query parameters forwarded to ``get_single_entry``.

    Returns:
        TemplateResponse: ``structure_not_found.html`` if no entry was returned,
        otherwise ``structure.html`` populated with the structure's data.

    Raises:
        RuntimeError: if the descriptive chemical formula contains a
            non-integer amount of any element.

    """
    response = get_single_entry(
        collection=structures_coll,
        entry_id=entry_id,
        request=request,
        params=params,
        response=StructureResponseOne,
    )

    context = {"request": request, "entry_id": entry_id}

    if response.meta.data_returned < 1:
        return TEMPLATES.TemplateResponse("structure_not_found.html", context)

    stoichiometry = get_stoich_from_formula(
        response.data.attributes.chemical_formula_descriptive)
    for ind, (_, num) in enumerate(stoichiometry):
        # Compare against the *nearest* integer: int() truncates, so a value
        # such as 1.9999999 would otherwise be rejected (or cast down to 1).
        nearest = int(round(num))
        if abs(num - nearest) > 1e-5:
            raise RuntimeError("Unable to cast formula to correct format")
        stoichiometry[ind][1] = nearest

    context.update({
        "odbx_title": "odbx",
        "odbx_blurb": "the open database of xtals",
        "odbx_about":
        'odbx is a public database of crystal structures from the group of <a href="https://ajm143.github.io">Dr Andrew Morris</a> at the University of Birmingham.',
        "odbx_cif_string": optimade_to_basic_cif(response.data),
        "structure_info": dict(response),
        "cif_link": str(request.url).replace("structures/", "cif/"),
        "stoichiometry": stoichiometry,
    })

    return TEMPLATES.TemplateResponse("structure.html", context)
Example #5
0
    def set_chempots(self, energy_key=None):
        """ Search for chemical potentials that match the structures in
        the query cursor and add them to the cursor. Also set the concentration
        of chemical potentials in :attr:`cursor`, if not already set.

        Chemical potentials are taken from the first applicable source:

        1. the command line, if ``self.args['chempots']`` is set, via
           :meth:`fake_chempots`;
        2. :attr:`cursor` itself, if :attr:`from_cursor` is set: for each
           species, the matching document with the lowest ``energy_key``;
        3. otherwise the database, queried through :attr:`_query`, again
           taking the lowest-``energy_key`` match for each species.

        The chosen documents are stored in :attr:`chempot_cursor`, and
        :attr:`elements` / :attr:`num_elements` are set from the elements
        they contain.

        NOTE(review): within this method ``concentration`` is only written on
        the command-line path (inside :meth:`fake_chempots`); confirm whether
        the first paragraph about setting concentrations is still accurate.

        Keyword arguments:
            energy_key (str): key used to rank candidate chemical potentials
                (lowest first), defaults to :attr:`energy_key`.

        Raises:
            RuntimeError: if no suitable chemical potential is found for some
                species.

        """
        if energy_key is None:
            energy_key = self.energy_key
        query = self._query
        query_dict = dict()
        # each species formula parsed into sorted [element, count] pairs; presumably
        # sorted to match the ordering of stored 'stoichiometry' fields — confirm
        species_stoich = [sorted(get_stoich_from_formula(spec, sort=False)) for spec in self.species]
        self.chempot_cursor = []

        if self.args.get('chempots') is not None:
            # energies given on the command line: build placeholder documents
            self.fake_chempots(custom_elem=self.species)

        elif self.from_cursor:
            # candidate chempots are cursor documents whose stoichiometry matches
            # a species, ordered from lowest to highest energy_key
            chempot_cursor = sorted([doc for doc in self.cursor if doc['stoichiometry'] in species_stoich],
                                    key=lambda doc: recursive_get(doc, energy_key))

            # keep the first (lowest-energy) candidate for each species, in species order
            for species in species_stoich:
                for doc in chempot_cursor:
                    if doc['stoichiometry'] == species:
                        self.chempot_cursor.append(doc)
                        break

            if len(self.chempot_cursor) != len(self.species):
                raise RuntimeError('Found {} of {} required chemical potentials'
                                   .format(len(self.chempot_cursor), len(self.species)))

            if self.args.get('debug'):
                print([mu['stoichiometry'] for mu in self.chempot_cursor])

        else:
            print(60 * '─')
            self.chempot_cursor = len(self.species) * [None]
            # scan for suitable chem pots in database
            for ind, elem in enumerate(self.species):

                print('Scanning for suitable', elem, 'chemical potential...')
                # each species' query starts from a fresh copy of query.calc_dict['$and']
                query_dict['$and'] = deepcopy(list(query.calc_dict['$and']))

                if not self.args.get('ignore_warnings'):
                    query_dict['$and'].append(query.query_quality())

                # single-element species are matched by composition, compounds by stoichiometry
                if len(species_stoich[ind]) == 1:
                    query_dict['$and'].append(query.query_composition(custom_elem=[elem]))
                else:
                    query_dict['$and'].append(query.query_stoichiometry(custom_stoich=[elem]))

                # if oqmd, only query composition, not parameters
                if query.args.get('tags') is not None:
                    query_dict['$and'].append(query.query_tags())

                mu_cursor = query.repo.find(SON(query_dict)).sort(energy_key, pm.ASCENDING)
                # NOTE(review): Cursor.count() was deprecated in PyMongo 3.7 and removed
                # in PyMongo 4.0 (Collection.count_documents() replaces it); confirm the
                # pinned pymongo version still provides it.
                if mu_cursor.count() == 0:
                    print_notify('Failed... searching without spin polarization field...')
                    # Remove the 'spin_polarized' key from any $and clause that has one
                    # (the emptied clause dict stays in the list), then retry the query.
                    # scanned is set once the last clause is visited, so the while loop
                    # makes a single pass over the (always non-empty) clause list.
                    scanned = False
                    while not scanned:
                        for idx, dicts in enumerate(query_dict['$and']):
                            for key in dicts:
                                if key == 'spin_polarized':
                                    del query_dict['$and'][idx][key]
                                    break
                            if idx == len(query_dict['$and']) - 1:
                                scanned = True
                    mu_cursor = query.repo.find(SON(query_dict)).sort(energy_key, pm.ASCENDING)

                if mu_cursor.count() == 0:
                    raise RuntimeError('No chemical potentials found for {}...'.format(elem))

                # results are sorted ascending in energy_key, so index 0 is the lowest-energy match
                self.chempot_cursor[ind] = mu_cursor[0]
                # NOTE(review): indexing a pymongo cursor raises IndexError when there is
                # no such document rather than returning None, so the else branch below
                # appears unreachable.
                if self.chempot_cursor[ind] is not None:
                    print('Using', ''.join([self.chempot_cursor[ind]['text_id'][0], ' ',
                                            self.chempot_cursor[ind]['text_id'][1]]), 'as chem pot for', elem)
                    print(60 * '─')
                else:
                    raise RuntimeError('No possible chem pots available for {}.'.format(elem))

            # copy each chempot's energy_key value to '<extensive key>_per_b' and zero num_a
            # (same fields that fake_chempots sets for command-line chempots)
            for i, mu in enumerate(self.chempot_cursor):
                self.chempot_cursor[i][self._extensive_energy_key + '_per_b'] = mu[energy_key]
                self.chempot_cursor[i]['num_a'] = 0

            # the first chempot is given num_a = inf, as in fake_chempots
            self.chempot_cursor[0]['num_a'] = float('inf')

        # don't check for IDs if we're loading from cursor
        # otherwise add any chempot not already in the cursor (by _id, or with no _id);
        # the first chempot goes at the front, the rest at the end
        if not self.from_cursor:
            ids = [doc['_id'] for doc in self.cursor]
            if self.chempot_cursor[0]['_id'] is None or self.chempot_cursor[0]['_id'] not in ids:
                self.cursor.insert(0, self.chempot_cursor[0])
            for match in self.chempot_cursor[1:]:
                if match['_id'] is None or match['_id'] not in ids:
                    self.cursor.append(match)

        # add faked chempots to overall cursor
        elif self.args.get('chempots') is not None:
            self.cursor.insert(0, self.chempot_cursor[0])
            self.cursor.extend(self.chempot_cursor[1:])

        # find all elements present in the chemical potentials
        # (order of first appearance, no duplicates)
        elements = []
        for mu in self.chempot_cursor:
            for elem, _ in mu['stoichiometry']:
                if elem not in elements:
                    elements.append(elem)
        self.elements = elements
        self.num_elements = len(elements)
Example #6
0
def plot_2d_hull(hull,
                 ax=None,
                 show=True,
                 plot_points=True,
                 plot_tie_line=True,
                 plot_hull_points=True,
                 labels=None,
                 label_cutoff=None,
                 colour_by_source=False,
                 sources=None,
                 hull_label=None,
                 source_labels=None,
                 title=True,
                 plot_fname=None,
                 show_cbar=True,
                 label_offset=(1.15, 0.05),
                 eform_limits=None,
                 legend_kwargs=None,
                 **kwargs):
    """ Plot calculated hull, returning the axis for further editing.

    Parameters:
        hull (matador.hull.QueryConvexHull): matador hull object.

    Keyword arguments:
        ax (matplotlib.axes.Axes): an existing axis on which to plot,
            a new figure and axis are created if None.
        show (bool): whether or not to display the plot in an X window,
        plot_points (bool): whether or not to display off-hull structures,
        plot_tie_line (bool): whether or not to draw the tie line joining
            the on-hull structures,
        plot_hull_points (bool): whether or not to display on-hull structures,
        labels (bool): whether to label formulae of hull structures, also read from
            hull.args.
        label_cutoff (float/:obj:`tuple` of :obj:`float`): draw labels less than or
            between these distances from the hull, also read from hull.args.
        colour_by_source (bool): plot and label points by their sources
        alpha (float): alpha value of points when colour_by_source is True
        sources (list): list of possible provenances to colour when colour_by_source
            is True (others will be grey)
        hull_label (str): legend label for the tie line.
        source_labels (list): forwarded to the colour-by-source plotter.
        title (str/bool): whether to include a plot title.
        png/pdf/svg (bool): whether or not to write the plot to a file.
        plot_fname (str): filename to write plot to, without file extension.
        show_cbar (bool): whether to draw a hull-distance colourbar (only used
            when colouring by hull distance, i.e. hull.hull_cutoff == 0).
        label_offset (tuple): (scale, shift) used to place labels at height
            ``label_offset[0] * (formation_energy - label_offset[1])``.
        eform_limits (tuple): y-axis (formation energy) limits; chosen from
            the data if None.
        legend_kwargs (dict): forwarded to the colour-by-source plotter.

    Returns:
        matplotlib.axes.Axes: matplotlib axis with plot.

    """
    import matplotlib.pyplot as plt
    import matplotlib.colors as colours

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)

    if not hasattr(hull, 'colours'):
        hull.colours = list(plt.rcParams['axes.prop_cycle'].by_key()['color'])
    hull.default_cmap_list = get_linear_cmap(hull.colours[1:4], list_only=True)
    hull.default_cmap = get_linear_cmap(hull.colours[1:4], list_only=False)

    if labels is None:
        labels = hull.args.get('labels', False)
    if label_cutoff is None:
        label_cutoff = hull.args.get('label_cutoff')

    scale = 1
    scatter = []
    chempot_labels = [
        get_formula_from_stoich(get_stoich_from_formula(species, sort=False),
                                tex=True) for species in hull.species
    ]
    tie_line = hull.convex_hull.points[hull.convex_hull.vertices]
    # tie line vertices ordered by concentration, shared by the line plots below
    tie_line_order = np.argsort(tie_line[:, 0])
    tie_line_x = tie_line[tie_line_order, 0]
    tie_line_y = tie_line[tie_line_order, 1]

    # plot hull structures
    if plot_hull_points:
        ax.scatter(tie_line[:, 0],
                   tie_line[:, 1],
                   c=hull.colours[1],
                   marker='o',
                   zorder=99999,
                   edgecolor='k',
                   s=scale * 40,
                   lw=1.5)

    # draw the tie line exactly once, so hull_label appears once in any legend
    if plot_tie_line:
        if plot_hull_points:
            ax.plot(tie_line_x,
                    tie_line_y,
                    c=hull.colours[0],
                    zorder=1,
                    label=hull_label,
                    marker='o',
                    markerfacecolor=hull.colours[0],
                    markeredgecolor='k',
                    markeredgewidth=1.5,
                    markersize=np.sqrt(scale * 40))
        else:
            ax.plot(tie_line_x,
                    tie_line_y,
                    c=hull.colours[0],
                    zorder=1,
                    label=hull_label,
                    markersize=0)

    if hull.hull_cutoff > 0:
        # dashed line marking the hull cutoff above the tie line
        ax.plot(tie_line_x,
                tie_line_y + hull.hull_cutoff,
                '--',
                c=hull.colours[1],
                alpha=0.5,
                zorder=1,
                label='')

    # annotate hull structures
    if labels or label_cutoff is not None:
        label_cursor = _get_hull_labels(hull,
                                        num_species=2,
                                        label_cutoff=label_cutoff)
        # concentration of the lowest tie-line point; labels are placed relative to it
        min_comp = tie_line[np.argmin(tie_line[:, 1]), 0]
        already_labelled = []
        for doc in label_cursor:
            formula = get_formula_from_stoich(doc['stoichiometry'], sort=True)
            if formula in already_labelled:
                continue
            arrowprops = dict(arrowstyle="-|>",
                              lw=2,
                              alpha=1,
                              zorder=1,
                              shrinkA=2,
                              shrinkB=4)
            e_f = doc['formation_' + str(hull.energy_key)]
            conc = doc['concentration'][0]
            label_y = label_offset[0] * (e_f - label_offset[1])
            if conc < min_comp:
                position = (0.8 * conc, label_y)
            elif conc == min_comp:
                position = (conc, label_y)
            else:
                position = (min(1.1 * conc + 0.15, 0.95), label_y)
            ax.annotate(get_formula_from_stoich(
                doc['stoichiometry'],
                latex_sub_style=r'\mathregular',
                tex=True,
                elements=hull.species,
                sort=False),
                        xy=(conc, e_f),
                        xytext=position,
                        textcoords='data',
                        ha='right',
                        va='bottom',
                        arrowprops=arrowprops,
                        zorder=1)
            already_labelled.append(formula)

    # points for off hull structures; we either colour by source or by energy
    if plot_points and not colour_by_source:

        if hull.hull_cutoff == 0:
            # if no specified hull cutoff, ignore labels and colour by hull distance
            cmap = hull.default_cmap
            scatter = ax.scatter(
                hull.structures[np.argsort(hull.hull_dist), 0][::-1],
                hull.structures[np.argsort(hull.hull_dist), -1][::-1],
                s=scale * 40,
                c=np.sort(hull.hull_dist)[::-1],
                zorder=10000,
                cmap=cmap,
                norm=colours.LogNorm(0.02, 2))

            if show_cbar:
                # attach the colourbar to ax's own figure, not whatever figure is current
                cbar = ax.figure.colorbar(
                    scatter,
                    ax=ax,
                    aspect=30,
                    pad=0.02,
                    ticks=[0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64, 1.28])
                cbar.ax.tick_params(length=0)
                cbar.ax.set_yticklabels(
                    [0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64, 1.28])
                cbar.ax.yaxis.set_ticks_position('right')
                cbar.ax.set_frame_on(False)
                cbar.outline.set_visible(False)
                cbar.set_label('Distance from hull (eV/atom)')

        else:
            # if specified hull cutoff colour those below
            c = hull.colours[1]
            for ind in range(len(hull.structures)):
                if hull.hull_dist[ind] <= hull.hull_cutoff:
                    scatter.append(
                        ax.scatter(hull.structures[ind, 0],
                                   hull.structures[ind, 1],
                                   s=scale * 40,
                                   alpha=0.9,
                                   c=c,
                                   zorder=300))
            ax.scatter(hull.structures[1:-1, 0],
                       hull.structures[1:-1, 1],
                       s=scale * 30,
                       lw=0,
                       alpha=0.3,
                       c=hull.colours[-2],
                       edgecolor='k',
                       zorder=10)

    elif colour_by_source:
        _scatter_plot_by_source(hull,
                                ax,
                                scale,
                                kwargs,
                                sources=sources,
                                source_labels=source_labels,
                                plot_hull_points=plot_hull_points,
                                legend_kwargs=legend_kwargs)

    if eform_limits is None:
        eform_limits = (np.min(hull.structures[:, 1]),
                        np.max(hull.structures[:, 1]))
        lims = (-0.1 if eform_limits[0] >= 0 else 1.4 * eform_limits[0],
                eform_limits[1] if eform_limits[0] >= 0 else 0.1)
    else:
        lims = sorted(eform_limits)
    ax.set_ylim(lims)

    if isinstance(title, bool) and title:
        if hull._non_elemental:
            ax.set_title(
                r'({d[0]})$_\mathrm{{x}}$({d[1]})$_\mathrm{{1-x}}$'.format(
                    d=chempot_labels))
        else:
            ax.set_title(
                r'{d[0]}$_\mathrm{{x}}${d[1]}$_\mathrm{{1-x}}$'.format(
                    d=chempot_labels))
    elif isinstance(title, str) and title != '':
        ax.set_title(title)

    ax.locator_params(nbins=3)
    if hull._non_elemental:
        ax.set_xlabel(
            r'x in ({d[0]})$_\mathrm{{x}}$({d[1]})$_\mathrm{{1-x}}$'.format(
                d=chempot_labels))
    else:
        ax.set_xlabel(
            r'x in {d[0]}$_\mathrm{{x}}${d[1]}$_\mathrm{{1-x}}$'.format(
                d=chempot_labels))

    ax.grid(False)
    ax.set_xlim(-0.05, 1.05)
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
    ax.set_xticklabels(ax.get_xticks())
    ax.set_ylabel('Formation energy (eV/atom)')

    if hull.savefig or any([kwargs.get(ext) for ext in SAVE_EXTS]):
        import os
        fname = plot_fname or ''.join(hull.species) + '_hull'
        for ext in SAVE_EXTS:
            if hull.args.get(ext) or kwargs.get(ext):
                # append an increasing integer until the filename is unused
                fname_tmp = fname
                ind = 0
                while os.path.isfile('{}.{}'.format(fname_tmp, ext)):
                    ind += 1
                    fname_tmp = fname + str(ind)

                fname = fname_tmp
                # save ax's own figure rather than whichever figure is current
                ax.figure.savefig('{}.{}'.format(fname, ext),
                                  bbox_inches='tight',
                                  transparent=True)
                print('Wrote {}.{}'.format(fname, ext))

    if show:
        plt.show()

    return ax
Example #7
0
def plot_ternary_hull(hull,
                      axis=None,
                      show=True,
                      plot_points=True,
                      hull_cutoff=None,
                      fig_height=None,
                      label_cutoff=None,
                      label_corners=True,
                      expecting_cbar=True,
                      labels=None,
                      plot_fname=None,
                      hull_dist_unit="meV",
                      efmap=None,
                      sampmap=None,
                      capmap=None,
                      pathways=False,
                      **kwargs):
    """ Plot calculated ternary hull as a 2D projection.

    Parameters:
        hull (matador.hull.QueryConvexHull): matador hull object.

    Keyword arguments:
        axis (matplotlib.axes.Axes): matplotlib axis object on which to plot.
        show (bool): whether or not to show plot in X window.
        plot_points (bool): whether or not to plot each structure as a point.
        label_cutoff (float/:obj:`tuple` of :obj:`float`): draw labels less than or
            between these distances form the hull, also read from hull.args.
        expecting_cbar (bool): whether or not to space out the plot to preserve
            aspect ratio if a colourbar is present.
        labels (bool): whether or not to label on-hull structures
        label_corners (bool): whether or not to put axis labels on corners or edges.
        hull_dist_unit (str): either "eV" or "meV",
        png/pdf/svg (bool): whether or not to write the plot to a file.
        plot_fname (str): filename to write plot to.
        efmap (bool): plot heatmap of formation energy,
        sampmap (bool): plot heatmap showing sampling density,
        capmap (bool): plot heatmap showing gravimetric capacity.
        pathways (bool): plot the pathway from the starting electrode to active ion.

    Returns:
        matplotlib.axes.Axes: matplotlib axis with plot.

    """
    import ternary
    import matplotlib.pyplot as plt
    import matplotlib.colors as colours
    from matador.utils.chem_utils import get_generic_grav_capacity

    plt.rcParams['axes.linewidth'] = 0
    plt.rcParams['xtick.major.size'] = 0
    plt.rcParams['ytick.major.size'] = 0
    plt.rcParams['xtick.minor.size'] = 0
    plt.rcParams['ytick.minor.size'] = 0

    if efmap is None:
        efmap = hull.args.get('efmap')
    if sampmap is None:
        sampmap = hull.args.get('sampmap')
    if capmap is None:
        capmap = hull.args.get('capmap')
    if pathways is None:
        pathways = hull.args.get('pathways')

    if labels is None:
        labels = hull.args.get('labels')
    if label_cutoff is None:
        label_cutoff = hull.args.get('label_cutoff')
        if label_cutoff is None:
            label_cutoff = 0
    else:
        labels = True

    if hull_cutoff is None and hull.hull_cutoff is None:
        hull_cutoff = 0
    else:
        hull_cutoff = hull.hull_cutoff

    print('Plotting ternary hull...')
    if capmap or efmap:
        scale = 100
    elif sampmap:
        scale = 20
    else:
        scale = 1

    if axis is not None:
        fig, ax = ternary.figure(scale=scale, ax=axis)
    else:
        fig, ax = ternary.figure(scale=scale)

    # maintain aspect ratio of triangle
    if fig_height is None:
        _user_height = plt.rcParams.get("figure.figsize", (8, 6))[0]
    else:
        _user_height = fig_height
    if capmap or efmap or sampmap:
        fig.set_size_inches(_user_height, 5 / 8 * _user_height)
    elif not expecting_cbar:
        fig.set_size_inches(_user_height, _user_height)
    else:
        fig.set_size_inches(_user_height, 5 / 6.67 * _user_height)

    ax.boundary(linewidth=2.0, zorder=99)
    ax.clear_matplotlib_ticks()

    chempot_labels = [
        get_formula_from_stoich(get_stoich_from_formula(species, sort=False),
                                sort=False,
                                tex=True) for species in hull.species
    ]

    ax.gridlines(color='black', multiple=scale * 0.1, linewidth=0.5)
    ticks = [float(val) for val in np.linspace(0, 1, 6)]
    if label_corners:
        # remove 0 and 1 ticks when labelling corners
        ticks = ticks[1:-1]
        ax.left_corner_label(chempot_labels[2], fontsize='large')
        ax.right_corner_label(chempot_labels[0], fontsize='large')
        ax.top_corner_label(chempot_labels[1], fontsize='large', offset=0.16)
    else:
        ax.left_axis_label(chempot_labels[2], fontsize='large', offset=0.12)
        ax.right_axis_label(chempot_labels[1], fontsize='large', offset=0.12)
        ax.bottom_axis_label(chempot_labels[0], fontsize='large', offset=0.08)
        ax.set_title('-'.join(['{}'.format(label)
                               for label in chempot_labels]),
                     fontsize='large',
                     y=1.02)

    ax.ticks(axis='lbr',
             linewidth=1,
             offset=0.025,
             fontsize='small',
             locations=(scale * np.asarray(ticks)).tolist(),
             ticks=ticks,
             tick_formats='%.1f')

    concs = np.zeros((len(hull.structures), 3))
    concs[:, :-1] = hull.structures[:, :-1]
    for i in range(len(concs)):
        # set third triangular coordinate
        concs[i, -1] = 1 - concs[i, 0] - concs[i, 1]

    stable = concs[np.where(hull.hull_dist <= 0 + EPS)]

    # sort by hull distances so things are plotting the right order
    concs = concs[np.argsort(hull.hull_dist)].tolist()
    hull_dist = np.sort(hull.hull_dist)

    filtered_concs = []
    filtered_hull_dists = []
    for ind, conc in enumerate(concs):
        if conc not in filtered_concs:
            if hull_dist[ind] <= hull.hull_cutoff or (hull.hull_cutoff == 0 and
                                                      hull_dist[ind] < 0.1):
                filtered_concs.append(conc)
                filtered_hull_dists.append(hull_dist[ind])
    if hull.args.get('debug'):
        print('Trying to plot {} points...'.format(len(filtered_concs)))

    concs = np.asarray(filtered_concs)
    hull_dist = np.asarray(filtered_hull_dists)

    min_cut = 0.0
    max_cut = 0.2

    if hull_dist_unit.lower() == "mev":
        hull_dist *= 1000
        min_cut *= 1000
        max_cut *= 1000

    hull.colours = list(plt.rcParams['axes.prop_cycle'].by_key()['color'])
    hull.default_cmap_list = get_linear_cmap(hull.colours[1:4], list_only=True)
    hull.default_cmap = get_linear_cmap(hull.colours[1:4], list_only=False)
    n_colours = len(hull.default_cmap_list)
    colours_hull = hull.default_cmap_list

    cmap = hull.default_cmap
    cmap_full = plt.cm.get_cmap('Pastel2')
    pastel_cmap = colours.LinearSegmentedColormap.from_list(
        'Pastel2', cmap_full.colors)

    for plane in hull.convex_hull.planes:
        plane.append(plane[0])
        plane = np.asarray(plane)
        ax.plot(scale * plane, c=hull.colours[0], lw=1.5, alpha=1, zorder=98)

    if pathways:
        for phase in stable:
            if phase[0] == 0 and phase[1] != 0 and phase[2] != 0:
                ax.plot([scale * phase, [scale, 0, 0]],
                        c='r',
                        alpha=0.2,
                        lw=6,
                        zorder=99)

    # add points
    if plot_points:
        colours_list = []
        colour_metric = hull_dist
        for i, _ in enumerate(colour_metric):
            if hull_dist[i] >= max_cut:
                colours_list.append(n_colours - 1)
            elif hull_dist[i] <= min_cut:
                colours_list.append(0)
            else:
                colours_list.append(
                    int((n_colours - 1) * (hull_dist[i] / max_cut)))
        colours_list = np.asarray(colours_list)
        ax.scatter(scale * stable,
                   marker='o',
                   color=hull.colours[1],
                   edgecolors='black',
                   zorder=9999999,
                   s=150,
                   lw=1.5)
        ax.scatter(scale * concs,
                   colormap=cmap,
                   colorbar=True,
                   cbarlabel='Distance from hull ({}eV/atom)'.format(
                       "m" if hull_dist_unit.lower() == "mev" else ""),
                   c=colour_metric,
                   vmax=max_cut,
                   vmin=min_cut,
                   zorder=1000,
                   s=40,
                   alpha=0)
        for i, _ in enumerate(concs):
            ax.scatter(scale * concs[i].reshape(1, 3),
                       color=colours_hull[colours_list[i]],
                       marker='o',
                       zorder=10000 - colours_list[i],
                       s=70 * (1 - float(colours_list[i]) / n_colours) + 15,
                       lw=1,
                       edgecolors='black')

    # add colourmaps
    if capmap:
        capacities = dict()
        from ternary.helpers import simplex_iterator
        for (i, j, k) in simplex_iterator(scale):
            capacities[(i, j, k)] = get_generic_grav_capacity([
                float(i) / scale,
                float(j) / scale,
                float(scale - i - j) / scale
            ], hull.species)
        ax.heatmap(capacities,
                   style="hexagonal",
                   cbarlabel='Gravimetric capacity (mAh/g)',
                   vmin=0,
                   vmax=3000,
                   cmap=pastel_cmap)
    elif efmap:
        energies = dict()
        fake_structures = []
        from ternary.helpers import simplex_iterator
        for (i, j, k) in simplex_iterator(scale):
            fake_structures.append([float(i) / scale, float(j) / scale, 0.0])
        fake_structures = np.asarray(fake_structures)
        plane_energies = hull.get_hull_distances(fake_structures,
                                                 precompute=False)
        ind = 0
        for (i, j, k) in simplex_iterator(scale):
            energies[(i, j, k)] = -1 * plane_energies[ind]
            ind += 1
        if isinstance(efmap, str):
            efmap = efmap
        else:
            efmap = 'BuPu_r'
        ax.heatmap(energies,
                   style="hexagonal",
                   cbarlabel='Formation energy (eV/atom)',
                   vmax=0,
                   cmap=efmap)
    elif sampmap:
        sampling = dict()
        from ternary.helpers import simplex_iterator
        eps = 1.0 / float(scale)
        for (i, j, k) in simplex_iterator(scale):
            sampling[(i, j, k)] = np.size(
                np.where((concs[:, 0] <= float(i) / scale + eps) *
                         (concs[:, 0] >= float(i) / scale - eps) *
                         (concs[:, 1] <= float(j) / scale + eps) *
                         (concs[:, 1] >= float(j) / scale - eps) *
                         (concs[:, 2] <= float(k) / scale + eps) *
                         (concs[:, 2] >= float(k) / scale - eps)))
        ax.heatmap(sampling,
                   style="hexagonal",
                   cbarlabel='Number of structures',
                   cmap='afmhot')

    # add labels
    if labels:
        label_cursor = _get_hull_labels(hull, label_cutoff=label_cutoff)
        if len(label_cursor) == 1:
            label_coords = [[0.25, 0.5]]
        else:
            label_coords = [[
                0.1 + (val - 0.5) * 0.3, val
            ] for val in np.linspace(0.5, 0.8,
                                     int(round(len(label_cursor) / 2.) + 1))]
            label_coords += [[0.9 - (val - 0.5) * 0.3, val + 0.2]
                             for val in np.linspace(
                                 0.5, 0.8, int(round(len(label_cursor) / 2.)))]
        from matador.utils.hull_utils import barycentric2cart
        for ind, doc in enumerate(label_cursor):
            conc = np.asarray(doc['concentration'] +
                              [1 - sum(doc['concentration'])])
            formula = get_formula_from_stoich(doc['stoichiometry'],
                                              sort=False,
                                              tex=True,
                                              latex_sub_style=r'\mathregular',
                                              elements=hull.species)
            arrowprops = dict(arrowstyle="-|>",
                              color='k',
                              lw=2,
                              alpha=0.5,
                              zorder=1,
                              shrinkA=2,
                              shrinkB=4)
            cart = barycentric2cart([doc['concentration'] + [0]])[0][:2]
            min_dist = 1e20
            closest_label = 0
            for coord_ind, coord in enumerate(label_coords):
                dist = np.sqrt((cart[0] - coord[0])**2 +
                               (cart[1] - coord[1])**2)
                if dist < min_dist:
                    min_dist = dist
                    closest_label = coord_ind
            ax.annotate(
                formula,
                scale * conc,
                textcoords='data',
                xytext=[scale * val for val in label_coords[closest_label]],
                ha='right',
                va='bottom',
                arrowprops=arrowprops)
            del label_coords[closest_label]

    plt.tight_layout(w_pad=0.2)
    # important for retaining labels if exporting to PDF
    # see https://github.com/marcharper/python-ternary/issues/36
    ax._redraw_labels()  # noqa

    if hull.savefig:
        fname = plot_fname or ''.join(hull.species) + '_hull'
        for ext in SAVE_EXTS:
            if hull.args.get(ext) or kwargs.get(ext):
                plt.savefig('{}.{}'.format(fname, ext),
                            bbox_inches='tight',
                            transparent=True)
                print('Wrote {}.{}'.format(fname, ext))
    elif show:
        print('Showing plot...')
        plt.show()

    return ax
Example #8
0
 def test_form_to_stoich(self):
     """Check that "Li12P1N18" parses, with sort=False, into [element, count]
     pairs that keep the element order of the formula string."""
     expected_stoich = [["Li", 12], ["P", 1], ["N", 18]]
     parsed_stoich = get_stoich_from_formula("Li12P1N18", sort=False)
     self.assertEqual(expected_stoich, parsed_stoich)