Exemple #1
0
    def set_nn_info(self):
        """Conformance check over every direct ``NearNeighbors`` subclass.

        Each subclass is instantiated with its default arguments and asked
        for the neighbors of site 0 in ``self.diamond``.  The check implicitly
        assumes that every strategy can identify the bonds in diamond; if one
        cannot, there are probably bigger problems.
        """
        for strategy_cls in NearNeighbors.__subclasses__():
            first_neighbor = strategy_cls().get_nn_info(self.diamond, 0)[0]
            # the first reported neighbor must be site 1, and the first
            # component of its 'image' entry must be 1
            self.assertEqual(first_neighbor['site_index'], 1)
            self.assertEqual(first_neighbor['image'][0], 1)
Exemple #2
0
    def set_nn_info(self):
        """Check that all near-neighbor strategies agree on diamond.

        Walks the direct subclasses of ``NearNeighbors``, builds each one with
        default options and inspects the first entry returned for site 0 of
        ``self.diamond``.  This implicitly assumes every subclass correctly
        identifies bonds in diamond; a failure here likely points at a larger
        problem with that strategy.
        """
        expected_site_index = 1
        expected_first_image_component = 1

        for nn_cls in NearNeighbors.__subclasses__():
            neighbors = nn_cls().get_nn_info(self.diamond, 0)
            leading = neighbors[0]
            self.assertEqual(leading['site_index'], expected_site_index)
            self.assertEqual(leading['image'][0],
                             expected_first_image_component)
    def set_nn_info(self):
        """Conformance check for the near-neighbor strategies on diamond.

        Every direct ``NearNeighbors`` subclass (except those whose string
        representation contains ``'Critic2'``) is built with default options
        and queried for site 0 of ``self.diamond``.  The check implicitly
        assumes each strategy can identify the bonds in diamond; if it cannot,
        there are probably bigger problems.
        """
        for strategy_cls in NearNeighbors.__subclasses__():
            # Critic2NN has an external dependency and is tested separately
            if 'Critic2' in str(strategy_cls):
                continue
            first_neighbor = strategy_cls().get_nn_info(self.diamond, 0)[0]
            self.assertEqual(first_neighbor['site_index'], 1)
            self.assertEqual(first_neighbor['image'][0], 1)
Exemple #4
0
    def set_nn_info(self):
        """Verify that the testable strategies find the expected diamond bond.

        The direct subclasses of ``NearNeighbors`` are collected first, leaving
        out any whose string representation mentions ``'Critic2'`` (Critic2NN
        has an external dependency and is tested separately).  Each remaining
        class is instantiated with defaults and must report, as its first
        neighbor of site 0 in ``self.diamond``, site index 1 with a leading
        image component of 1.  This implicitly assumes all strategies identify
        bonds in diamond correctly.
        """
        testable = [
            nn_cls for nn_cls in NearNeighbors.__subclasses__()
            if 'Critic2' not in str(nn_cls)
        ]
        for nn_cls in testable:
            neighbors = nn_cls().get_nn_info(self.diamond, 0)
            leading = neighbors[0]
            self.assertEqual(leading['site_index'], 1)
            self.assertEqual(leading['image'][0], 1)
Exemple #5
0
    def __init__(self,
                 materials,
                 bonding,
                 strategies=('MinimumDistanceNN', 'MinimumOKeeffeNN', 'JMolNN',
                             'MinimumVIRENN', 'VoronoiNN', 'CrystalNN',
                             'EconNN', 'BrunnerNN_real', 'BrunnerNN_relative',
                             'BrunnerNN_reciprocal', 'Critic2NN'),
                 query=None,
                 **kwargs):
        """
        Builder to calculate bonding in a crystallographic
        structure via near neighbor strategies, including those
        in pymatgen.analysis.local_env and using the critic2 tool.

        Args:
            materials (Store): Store of materials documents
            bonding (Store): Store of topology data
            strategies (list): List of NearNeighbor classes to use (can be
                an instance of a NearNeighbor class or its name as a string,
                in which case it will be instantiated with default
                arguments). An empty or None value selects every direct
                NearNeighbors subclass, each instantiated with defaults.
            query (dict): dictionary to limit materials to be analyzed
            **kwargs: forwarded unchanged to the parent class ``__init__``

        Raises:
            KeyError: if a strategy name is not the name of a direct
                NearNeighbors subclass.
        """

        self.chunk_size = 100

        self.materials = materials
        self.bonding = bonding
        self.query = query or {}

        # name -> class lookup; __subclasses__() only lists direct subclasses
        available_strategies = {
            nn.__name__: nn
            for nn in NearNeighbors.__subclasses__()
        }

        if strategies:
            # use the instance if passed directly (e.g. with custom kwargs),
            # otherwise instantiate class with default options
            self.strategies = [
                strategy if isinstance(strategy, NearNeighbors) else
                available_strategies[strategy]() for strategy in strategies
            ]
        else:
            # calculate all the strategies: instantiate each class so that
            # self.strategies is a list of instances in both branches
            # (previously this branch stored a dict view of the classes)
            self.strategies = [nn() for nn in available_strategies.values()]

        bonding.validator = BondValidator()

        super().__init__(sources=[materials], targets=[bonding], **kwargs)
Exemple #6
0
    def __init__(self,
                 materials,
                 bonding,
                 strategies=("CrystalNN", ),
                 **kwargs):
        """
        Builder to calculate bonding in a crystallographic
        structure via near neighbor strategies, including those
        in pymatgen.analysis.local_env and using the critic2 tool.

        Args:
            materials (Store): Store of materials documents
            bonding (Store): Store of topology data; its ``validator``
                attribute is replaced with a JSON-schema validator here
            strategies (list): NearNeighbor strategies to use; each entry is
                either an instance of a NearNeighbor class or the name of a
                direct NearNeighbors subclass as a string, in which case it
                is instantiated with default arguments
            **kwargs: passed through to the parent class ``__init__``
        """

        self.materials = materials
        self.bonding = bonding
        self.bonding.validator = JSONSchemaValidator(loadfn(BOND_SCHEMA))

        # class name -> class, for every direct subclass of NearNeighbors
        strategy_classes_by_name = {}
        for nn_class in NearNeighbors.__subclasses__():
            strategy_classes_by_name[nn_class.__name__] = nn_class

        # keep ready-made instances (they may carry custom kwargs); anything
        # else is treated as a class name and built with default options
        resolved = []
        for strategy in strategies:
            if not isinstance(strategy, NearNeighbors):
                strategy = strategy_classes_by_name[strategy]()
            resolved.append(strategy)
        self.strategies = resolved

        self.strategy_names = [
            resolved_strategy.__class__.__name__
            for resolved_strategy in self.strategies
        ]

        # Voronoi-based strategies can cause some structures to cause crash
        self.bad_task_ids = []

        super().__init__(source=materials,
                         target=bonding,
                         ufn=self.calc,
                         projection=["structure"],
                         **kwargs)
Exemple #7
0
class StructureMoleculeComponent(MPComponent):
    """Component that builds a bonding graph and a 3D scene for a structure or molecule."""

    # name -> class for every direct subclass of NearNeighbors, evaluated
    # once when this class body runs
    available_bonding_strategies = {
        nn_cls.__name__: nn_cls
        for nn_cls in NearNeighbors.__subclasses__()
    }

    # accepted values for the ``radius_strategy`` display option
    available_radius_strategies = (
        "atomic",
        "specified_or_average_ionic",
        "covalent",
        "van_der_waals",
        "atomic_calculated",
        "uniform",
    )

    # TODO ...
    available_polyhedra_rules = ("prefer_large_polyhedra", "only_same_species")

    # settings handed to the scene component as its initial configuration
    default_scene_settings = {
        "lights": [
            {"type": "DirectionalLight",
             "args": ["#ffffff", 0.15],
             "position": [-10, 10, 10]},
            {"type": "DirectionalLight",
             "args": ["#ffffff", 0.15],
             "position": [0, 0, -10]},
            {"type": "HemisphereLight",
             "args": ["#eeeeee", "#999999", 1.0]},
        ],
        "material": {
            "type": "MeshStandardMaterial",
            "parameters": {"roughness": 0.07, "metalness": 0.0},
        },
        "objectScale": 1.0,
        "cylinderScale": 0.1,
        "defaultSurfaceOpacity": 0.5,
        "staticScene": True,
    }

    def __init__(
        self,
        struct_or_mol=None,
        id=None,
        origin_component=None,
        scene_additions=None,
        bonding_strategy="BrunnerNN_reciprocal",
        bonding_strategy_kwargs=None,
        color_scheme="Jmol",
        color_scale=None,
        radius_strategy="uniform",
        draw_image_atoms=True,
        bonded_sites_outside_unit_cell=False,
    ):
        """
        Set up the component's stores and its initial graph, scene and legend.

        Args:
            struct_or_mol: object to display; may be omitted, in which case
                an empty scene is generated
            id: identifier forwarded to the parent component
            origin_component: forwarded to the parent component
            scene_additions: optional extra scene content; when truthy it is
                stored via its ``to_json()`` output
            bonding_strategy (str): name of the bonding strategy used to
                build the graph
            bonding_strategy_kwargs (dict): keyword arguments for that strategy
            color_scheme, color_scale, radius_strategy, draw_image_atoms,
            bonded_sites_outside_unit_cell: initial display options, passed
                to ``get_scene_and_legend``
        """

        super().__init__(id=id,
                         contents=struct_or_mol,
                         origin_component=origin_component)

        self.default_title = "Crystal Toolkit"

        # note: this is the class-level dict itself, not a copy
        self.initial_scene_settings = StructureMoleculeComponent.default_scene_settings
        self.create_store("scene_settings",
                          initial_data=self.initial_scene_settings)

        self.initial_graph_generation_options = dict(
            bonding_strategy=bonding_strategy,
            bonding_strategy_kwargs=bonding_strategy_kwargs,
        )
        self.create_store(
            "graph_generation_options",
            initial_data=self.initial_graph_generation_options,
        )

        self.initial_display_options = dict(
            color_scheme=color_scheme,
            color_scale=color_scale,
            radius_strategy=radius_strategy,
            draw_image_atoms=draw_image_atoms,
            bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
        )
        self.create_store("display_options",
                          initial_data=self.initial_display_options)

        # The graph is cached explicitly; this isn't necessary but is an
        # optimization so that the graph is only re-generated if the bonding
        # algorithm changes.  Without an input there is no graph, and an
        # empty scene is displayed instead.
        graph = (self._preprocess_input_to_graph(
            struct_or_mol,
            bonding_strategy=bonding_strategy,
            bonding_strategy_kwargs=bonding_strategy_kwargs,
        ) if struct_or_mol else None)

        scene, legend = self.get_scene_and_legend(
            graph, name=self.id(), **self.initial_display_options)

        self.initial_legend = legend
        self.create_store("legend_data", initial_data=self.initial_legend)

        self.initial_scene_data = scene.to_json()

        self.initial_graph = graph
        self.create_store("graph", initial_data=self.to_data(graph))

        if scene_additions:
            self.initial_scene_additions = scene_additions
            self.create_store("scene_additions",
                              initial_data=scene_additions.to_json())

    def _generate_callbacks(self, app, cache):
        """
        Register this component's Dash callbacks on ``app``.

        The callbacks wire the stores and controls together:
        graph options / cell choice / repeats / input -> graph store,
        graph + display options -> scene data and legend data,
        graph -> available color-scheme options,
        display controls -> display-options store,
        screenshot button -> scene download request,
        hide/show checklist -> scene visibility,
        legend data -> title and legend containers.

        Args:
            app: Dash application providing the ``callback`` decorator
            cache: not used by this method
        """

        @app.callback(
            Output(self.id("graph"), "data"),
            [
                Input(self.id("graph_generation_options"), "data"),
                Input(self.id("unit-cell-choice"), "value"),
                Input(self.id("repeats"), "value"),
                Input(self.id(), "data"),
            ],
        )
        def update_graph(graph_generation_options, unit_cell_choice, repeats,
                         struct_or_mol):

            struct_or_mol = self.from_data(struct_or_mol)
            graph_generation_options = self.from_data(graph_generation_options)
            repeats = int(repeats)

            # cell transformations only make sense for periodic structures;
            # any choice other than "primitive"/"conventional" (e.g. "input")
            # leaves the cell as supplied
            if isinstance(struct_or_mol, Structure):
                if unit_cell_choice == "primitive":
                    struct_or_mol = struct_or_mol.get_primitive_structure()
                elif unit_cell_choice == "conventional":
                    sga = SpacegroupAnalyzer(struct_or_mol)
                    struct_or_mol = sga.get_conventional_standard_structure()
                if repeats != 1:
                    struct_or_mol = struct_or_mol * (repeats, repeats, repeats)

            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=graph_generation_options["bonding_strategy"],
                bonding_strategy_kwargs=graph_generation_options[
                    "bonding_strategy_kwargs"],
            )

            return self.to_data(graph)

        @app.callback(
            Output(self.id("scene"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
            ],
        )
        def update_scene(graph, display_options):
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            # the legend half of the result is produced by update_legend
            scene, _ = self.get_scene_and_legend(graph, **display_options)
            return scene.to_json()

        @app.callback(
            Output(self.id("legend_data"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
            ],
        )
        def update_legend(graph, display_options):
            # TODO: more cleanly split legend from scene generation
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            struct_or_mol = self._get_struct_or_mol(graph)
            site_prop_types = self._analyze_site_props(struct_or_mol)
            # only the legend is needed here, the per-site colors are dropped
            _, legend = self._get_display_colors_and_legend_for_sites(
                struct_or_mol,
                site_prop_types,
                color_scheme=display_options.get("color_scheme", None),
                color_scale=display_options.get("color_scale", None),
            )
            return self.to_data(legend)

        @app.callback(
            Output(self.id("color-scheme"), "options"),
            [Input(self.id("graph"), "data")],
        )
        def update_color_options(graph):

            options = [
                {
                    "label": "Jmol",
                    "value": "Jmol"
                },
                {
                    "label": "VESTA",
                    "value": "VESTA"
                },
            ]
            graph = self.from_data(graph)
            struct_or_mol = self._get_struct_or_mol(graph)
            site_props = self._analyze_site_props(struct_or_mol)
            # every scalar site property is offered as an extra color scheme
            options.extend({
                "label": f"Site property: {prop}",
                "value": prop
            } for prop in site_props.get("scalar", []))

            return options

        @app.callback(
            Output(self.id("display_options"), "data"),
            [
                Input(self.id("color-scheme"), "value"),
                Input(self.id("radius_strategy"), "value"),
                Input(self.id("draw_options"), "values"),
            ],
            [State(self.id("display_options"), "data")],
        )
        def update_display_options(color_scheme, radius_strategy, draw_options,
                                   display_options):
            display_options = self.from_data(display_options)
            display_options.update({
                "color_scheme":
                color_scheme,
                "radius_strategy":
                radius_strategy,
                "draw_image_atoms":
                "draw_image_atoms" in draw_options,
                "bonded_sites_outside_unit_cell":
                "bonded_sites_outside_unit_cell" in draw_options,
            })
            return self.to_data(display_options)

        @app.callback(
            Output(self.id("scene"), "downloadRequest"),
            [Input(self.id("screenshot_button"), "n_clicks")],
            [
                State(self.id("scene"), "downloadRequest"),
                State(self.id(), "data")
            ],
        )
        def screenshot_callback(n_clicks, current_requests, struct_or_mol):
            # the button has never been clicked: nothing to download
            if n_clicks is None:
                raise PreventUpdate
            struct_or_mol = self.from_data(struct_or_mol)
            # TODO: this will break if store is structure/molecule graph ...
            formula = struct_or_mol.composition.reduced_formula
            if hasattr(struct_or_mol, "get_space_group_info"):
                spgrp = struct_or_mol.get_space_group_info()[0]
            else:
                spgrp = ""
            request_filename = f"{formula}-{spgrp}-crystal-toolkit.png"
            # running counter carried in the request payload
            if not current_requests:
                n_requests = 1
            else:
                n_requests = current_requests["n_requests"] + 1
            return {
                "n_requests": n_requests,
                "filename": request_filename,
                "filetype": "png",
            }

        @app.callback(
            Output(self.id("scene"), "toggleVisibility"),
            [Input(self.id("hide-show"), "values")],
            [State(self.id("hide-show"), "options")],
        )
        def update_visibility(values, options):
            # map every available option to whether it is currently ticked
            return {opt["value"]: (opt["value"] in values) for opt in options}

        @app.callback(
            Output(self.id("title_container"), "children"),
            [Input(self.id("legend_data"), "data")],
        )
        def update_title(legend):
            legend = self.from_data(legend)
            return self._make_title(legend)

        # named distinctly from update_legend above, which this definition
        # previously shadowed
        @app.callback(
            Output(self.id("legend_container"), "children"),
            [Input(self.id("legend_data"), "data")],
        )
        def update_legend_container(legend):
            legend = self.from_data(legend)
            return self._make_legend(legend)

    def _make_legend(self, legend):
        """
        Render the color legend as a row of static, colored buttons.

        Args:
            legend (dict): expected to carry a ``"colors"`` mapping of
                color hex code -> name and a ``"composition"`` dict

        Returns:
            An empty ``html.Div`` when there is no legend or it has no
            colors, otherwise a grouped ``Field`` of buttons ordered by where
            each name first occurs in the reduced formula.
        """

        if legend is None or not legend.get("colors"):
            return html.Div(id=self.id("legend"))

        def contrasting_font_color(hex_code):
            # pick black or white text depending on the weighted brightness
            # of the background color
            red, green, blue = (int(hex_code[1:][i:i + 2], 16)
                                for i in (0, 2, 4))
            brightness = (red * 0.299 + green * 0.587 + blue * 0.114) / 255
            return "#000000" if 1 - brightness < 0.5 else "#ffffff"

        formula = Composition.from_dict(legend["composition"]).reduced_formula

        # names missing from the formula give find() == -1 and sort first
        ordered_entries = sorted(legend["colors"].items(),
                                 key=lambda entry: formula.find(entry[1]))

        controls = []
        for color, name in ordered_entries:
            swatch = Button(
                html.Span(name,
                          className="icon",
                          style={"color": contrasting_font_color(color)}),
                kind="static",
                style={"background-color": color},
            )
            controls.append(swatch)

        return Field(
            [
                Control(swatch, style={"margin-right": "0.2rem"})
                for swatch in controls
            ],
            id=self.id("legend"),
            grouped=True,
        )

    def _make_title(self, legend):
        """
        Build the title heading, showing the reduced formula with subscripts.

        Args:
            legend (dict): legend data; its ``"composition"`` entry may be a
                dict (converted with ``Composition.from_dict``) or an object
                exposing ``reduced_formula``

        Returns:
            An ``H1`` with the default title when no composition is
            available, otherwise an ``H1`` whose children alternate between
            ``html.Span`` (element symbols and other text) and ``html.Sub``
            (amounts).
        """

        if not legend or (not legend.get("composition", None)):
            return H1(self.default_title, id=self.id("title"))

        composition = legend["composition"]
        if isinstance(composition, dict):
            composition = Composition.from_dict(composition)

        formula = composition.reduced_formula
        # Amounts are matched first and may carry a decimal part, so that a
        # fractional amount such as "0.5" stays one subscript instead of
        # being split into "0", "." and "5". Underscores are dropped.
        formula_parts = re.findall(r"\d+(?:\.\d+)?|[^\d_]+", formula)

        formula_components = [
            html.Sub(part)
            if part.replace(".", "", 1).isnumeric() else html.Span(part)
            for part in formula_parts
        ]

        return H1(formula_components,
                  id=self.id("title"),
                  style={"display": "inline-block"})

    @property
    def all_layouts(self):
        """
        Build every layout fragment this component offers.

        Returns:
            dict: components keyed by ``"struct"`` (the 3D scene),
            ``"screenshot"`` (download button), ``"options"`` (cell, color,
            radius, draw and visibility controls), ``"title"`` and
            ``"legend"``.
        """

        def choices(*label_value_pairs):
            # expand (label, value) pairs into the option dicts used below
            return [{
                "label": label,
                "value": value
            } for label, value in label_value_pairs]

        def heading(text):
            return html.Label(text, className="mpc-label")

        scene_component = Simple3DSceneComponent(
            id=self.id("scene"),
            data=self.initial_scene_data,
            settings=self.initial_scene_settings,
        )
        struct_layout = html.Div(
            scene_component,
            style={
                "width": "100%",
                "height": "100%",
                "overflow": "hidden",
                "margin": "0 auto",
            },
        )

        download_button = Button(
            [Icon(), html.Span(), "Download Image"],
            kind="primary",
            id=self.id("screenshot_button"),
        )
        # TODO: change to "bottom" when dropdown included
        screenshot_layout = html.Div(
            [download_button],
            style={
                "vertical-align": "top",
                "display": "inline-block"
            },
        )

        title_layout = html.Div(self._make_title(self.initial_legend),
                                id=self.id("title_container"))

        legend_layout = html.Div(self._make_legend(self.initial_legend),
                                 id=self.id("legend_container"))

        unit_cell_radio = dcc.RadioItems(
            options=choices(
                ("Input cell", "input"),
                ("Primitive cell", "primitive"),
                ("Conventional cell", "conventional"),
            ),
            value="conventional",
            id=self.id("unit-cell-choice"),
            labelStyle={"display": "block"},
            inputClassName="mpc-radio",
        )

        repeats_radio = dcc.RadioItems(
            options=choices(("1×1×1", "1"), ("2×2×2", "2")),
            value="1",
            id=self.id("repeats"),
            labelStyle={"display": "block"},
            inputClassName="mpc-radio",
        )

        # NOTE(review): this control starts at "VESTA" while __init__'s
        # color_scheme default is "Jmol" - confirm which one is intended
        color_scheme_dropdown = dcc.Dropdown(
            options=choices(("VESTA", "VESTA"), ("Jmol", "Jmol")),
            value="VESTA",
            clearable=False,
            id=self.id("color-scheme"),
        )

        radius_dropdown = dcc.Dropdown(
            options=choices(
                ("Ionic", "specified_or_average_ionic"),
                ("Covalent", "covalent"),
                ("Van der Waals", "van_der_waals"),
                ("Uniform (0.5Å)", "uniform"),
            ),
            value="uniform",
            clearable=False,
            id=self.id("radius_strategy"),
        )

        draw_checklist = dcc.Checklist(
            options=choices(
                ("Draw repeats of atoms on periodic boundaries",
                 "draw_image_atoms"),
                ("Draw atoms outside unit cell bonded to atoms within unit cell",
                 "bonded_sites_outside_unit_cell"),
            ),
            values=["draw_image_atoms"],
            labelStyle={"display": "block"},
            inputClassName="mpc-radio",
            id=self.id("draw_options"),
        )

        visibility_checklist = dcc.Checklist(
            options=choices(
                ("Atoms", "atoms"),
                ("Bonds", "bonds"),
                ("Unit cell", "unit_cell"),
                ("Polyhedra", "polyhedra"),
            ),
            values=["atoms", "bonds", "unit_cell", "polyhedra"],
            labelStyle={"display": "block"},
            inputClassName="mpc-radio",
            id=self.id("hide-show"),
        )

        options_layout = Field([
            # original note on the next two controls: hide if molecule
            heading("Change unit cell:"),
            html.Div(unit_cell_radio, className="mpc-control"),
            heading("Change number of repeats:"),
            html.Div(repeats_radio, className="mpc-control"),
            heading("Change color scheme:"),
            html.Div(color_scheme_dropdown, className="mpc-control"),
            heading("Change atomic radii:"),
            html.Div(radius_dropdown, className="mpc-control"),
            heading("Draw options:"),
            # this wrapper carries no className, unlike the other controls
            html.Div([draw_checklist]),
            heading("Hide/show:"),
            html.Div([visibility_checklist], className="mpc-control"),
        ])

        return {
            "struct": struct_layout,
            "screenshot": screenshot_layout,
            "options": options_layout,
            "title": title_layout,
            "legend": legend_layout,
        }

    @property
    def standard_layout(self):
        """Return the default layout: the ``"struct"`` entry of ``all_layouts``."""
        layouts = self.all_layouts
        return layouts["struct"]

    @staticmethod
    def _preprocess_input_to_graph(
        input: Union[Structure, StructureGraph, Molecule, MoleculeGraph],
        bonding_strategy: str = "CrystalNN",
        bonding_strategy_kwargs: Optional[Dict] = None,
    ) -> Union[StructureGraph, MoleculeGraph]:
        """
        Turn the supplied object into a graph carrying bonding information.

        Args:
            input: a Structure or Molecule to compute bonds for, or an
                already-built StructureGraph / MoleculeGraph, which is
                returned unchanged
            bonding_strategy: name of a key in
                ``available_bonding_strategies``; replaced by
                ``"CutOffDictNN"`` for disordered structures
            bonding_strategy_kwargs: keyword arguments for the strategy
                class. For ``"CutOffDictNN"``, ``cut_off_dict`` may be given
                as an iterable of ``(a, b, cutoff)`` triples, which is
                converted to a dict keyed by ``(a, b)``. The supplied dict
                is never modified.

        Returns:
            The graph; if computing bonds raises, a graph without any bonds.

        Raises:
            ValueError: if ``bonding_strategy`` is not a supported name.
        """

        if isinstance(input, Structure):

            # ensure fractional co-ordinates are normalized to be in [0,1)
            # (this is actually not guaranteed by Structure)
            input = input.as_dict(verbosity=0)
            for site in input["sites"]:
                site["abc"] = np.mod(site["abc"], 1)
            input = Structure.from_dict(input)

            if not input.is_ordered:
                # calculating bonds in disordered structures is currently very flaky
                bonding_strategy = "CutOffDictNN"

        # we assume most uses of this class will give a structure as an input argument,
        # meaning we have to calculate the graph for bonding information, however if
        # the graph is already known and supplied, we will use that
        if isinstance(input, (StructureGraph, MoleculeGraph)):
            return input

        available = StructureMoleculeComponent.available_bonding_strategies
        if bonding_strategy not in available:
            raise ValueError(
                "Bonding strategy not supported. Please supply a name "
                "of a NearNeighbor subclass, choose from: {}".format(
                    ", ".join(available)))

        # Work on a shallow copy: the caller's kwargs dict is also kept in a
        # store and may be passed in again, so it must not be rewritten here.
        strategy_kwargs = dict(bonding_strategy_kwargs or {})
        if bonding_strategy == "CutOffDictNN" and "cut_off_dict" in strategy_kwargs:
            cut_offs = strategy_kwargs["cut_off_dict"]
            # TODO: remove this hack by making args properly JSON serializable
            # (a dict is assumed to be in its final form already)
            if not isinstance(cut_offs, dict):
                strategy_kwargs["cut_off_dict"] = {
                    (x[0], x[1]): x[2]
                    for x in cut_offs
                }
        strategy = available[bonding_strategy](**strategy_kwargs)

        graph_cls = StructureGraph if isinstance(input,
                                                 Structure) else MoleculeGraph
        try:
            graph = graph_cls.with_local_env_strategy(input, strategy)
        except Exception:
            # for some reason computing bonds failed, so let's not have any bonds(!)
            # (Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate)
            graph = graph_cls.with_empty_graph(input)

        return graph

    @staticmethod
    def _analyze_site_props(struct_or_mol):

        # store list of site props that are vectors, so these can be displayed as arrows
        # (implicitly assumes all site props for a given key are same type)
        site_prop_names = defaultdict(list)
        for name, props in struct_or_mol.site_properties.items():
            if isinstance(props[0], float) or isinstance(props[0], int):
                site_prop_names["scalar"].append(name)
            elif isinstance(props[0], list) and len(props[0]) == 3:
                if isinstance(props[0][0], list) and len(props[0][0]) == 3:
                    site_prop_names["matrix"].append(name)
                else:
                    site_prop_names["vector"].append(name)
            elif isinstance(props[0], str):
                site_prop_names["categorical"].append(name)

        return dict(site_prop_names)

    @staticmethod
    def _get_origin(struct_or_mol):
        """
        Return the point used as the geometric center of the input.

        For a Structure this is the Cartesian position of the fractional
        point (0.5, 0.5, 0.5) in its lattice; for a Molecule it is the
        average of its Cartesian coordinates; for anything else it is
        ``(0, 0, 0)``.
        """
        if isinstance(struct_or_mol, Structure):
            cell_center_frac = (0.5, 0.5, 0.5)
            return struct_or_mol.lattice.get_cartesian_coords(cell_center_frac)

        if isinstance(struct_or_mol, Molecule):
            return np.average(struct_or_mol.cart_coords, axis=0)

        return (0, 0, 0)

    @staticmethod
    def _get_struct_or_mol(graph) -> Union[Structure, Molecule]:
        """
        Unwrap the structure or molecule held by a graph.

        Returns ``graph.structure`` for a StructureGraph and
        ``graph.molecule`` for a MoleculeGraph.

        Raises:
            ValueError: if ``graph`` is neither kind of graph.
        """
        if isinstance(graph, StructureGraph):
            return graph.structure
        if isinstance(graph, MoleculeGraph):
            return graph.molecule
        raise ValueError

    @staticmethod
    def _get_display_colors_and_legend_for_sites(
            struct_or_mol,
            site_prop_types,
            color_scheme="Jmol",
            color_scale=None) -> Tuple[List[List[str]], Dict]:
        """
        Work out the display color(s) of every site and a matching legend.

        Note this returns a list of lists of strings since each
        site might have multiple colors defined if the site is
        disordered.

        The legend is a dictionary with a "composition" entry and a "colors"
        entry; "colors" maps each hex color to the corresponding element
        name or (rounded) property value, depending on the color scheme
        chosen.

        :param struct_or_mol: object whose sites are to be colored; if it
        already has a "display_color" site property that is returned as-is
        together with a legend whose "colors" entry is empty
        :param site_prop_types: mapping of site property kind ("scalar",
        "categorical", ...) to the names of the site properties of that kind
        :param color_scheme: "VESTA", "Jmol", "colorblind_friendly" or the
        name of a scalar or categorical site property
        :param color_scale: name of the colormap used for scalar site
        properties, "coolwarm" if not given
        :return: (colors, legend)
        :raises ValueError: if color_scheme is a site property and the object
        is disordered, or if color_scheme is not recognised at all
        :raises NotImplementedError: for "colorblind_friendly" and for
        categorical site properties
        """

        # TODO: check to see if there is a bug here due to Composition being unordered(?)

        legend = {
            "composition": struct_or_mol.composition.as_dict(),
            "colors": {}
        }

        # don't calculate color if one is explicitly supplied
        if "display_color" in struct_or_mol.site_properties:
            # don't know what the color legend (meaning) is, so return empty legend
            return (struct_or_mol.site_properties["display_color"], legend)

        def get_color_hex(x):
            return "#{:02x}{:02x}{:02x}".format(*x)

        if color_scheme not in ("VESTA", "Jmol", "colorblind_friendly"):

            if not struct_or_mol.is_ordered:
                raise ValueError(
                    "Can only use VESTA, Jmol or colorblind_friendly color "
                    "schemes for disordered structures or molecules, color "
                    "schemes based on site properties are ill-defined.")

            if (color_scheme not in site_prop_types.get(
                    "scalar", [])) and (color_scheme
                                        not in site_prop_types.get(
                                            "categorical", [])):

                raise ValueError(
                    "Unsupported color scheme. Should be VESTA, Jmol, "
                    "colorblind_friendly or a scalar (float) or categorical "
                    "(string) site property.")

        if color_scheme in ("VESTA", "Jmol"):

            colors = []
            for site in struct_or_mol:
                elements = [
                    sp.as_dict()["element"]
                    for sp, _ in site.species_and_occu.items()
                ]
                site_colors = [
                    get_color_hex(EL_COLORS[color_scheme][element])
                    for element in elements
                ]
                colors.append(site_colors)
                # construct legend, re-using the colors computed just above
                for element, color in zip(elements, site_colors):
                    legend["colors"][color] = element

        elif color_scheme in site_prop_types.get("scalar", []):

            props = np.array(struct_or_mol.site_properties[color_scheme])

            # by default, use blue-grey-red color scheme,
            # so that zero is ~ grey, and positive/negative
            # are red/blue
            color_scale = color_scale or "coolwarm"
            # try to keep color scheme symmetric around 0
            color_max = max([abs(min(props)), max(props)])
            if color_max == 0:
                # every value is zero, so the symmetric range would have zero
                # width and the normalisation below would divide by zero;
                # any non-zero half-width maps zero onto the middle of the
                # colormap, which is where zero belongs anyway
                color_max = 1.0
            color_min = -color_max
            color_range = color_max - color_min

            cmap = get_cmap(color_scale)
            # normalize in [0, 1] range, as expected by cmap
            props_normed = (props - color_min) / color_range

            def get_color_cmap(x):
                return [int(c * 255) for c in cmap(x)[0:3]]

            colors = [[get_color_hex(get_color_cmap(x))] for x in props_normed]

            # construct legend

            # max/min only:
            # c = get_color_hex(get_color_cmap(color_min))
            # legend["colors"][c] = "{:.1f}".format(color_min)
            # if color_max != color_min:
            #    c = get_color_hex(get_color_cmap(color_max))
            #    legend["colors"][c] = "{:.1f}".format(color_max)

            # all colors: one legend entry per distinct value after rounding
            # to one decimal place
            rounded_props = sorted({np.around(p, decimals=1) for p in props})
            for prop in rounded_props:
                prop_normed = (prop - color_min) / color_range
                c = get_color_hex(get_color_cmap(prop_normed))
                legend["colors"][c] = "{:.1f}".format(prop)

        elif color_scheme == "colorblind_friendly":
            raise NotImplementedError

        elif color_scheme in site_prop_types.get("categorical", []):
            # iter() a palettable  palettable.colorbrewer.qualitative
            # cmap.colors, check len, Set1_9 ?
            raise NotImplementedError

        return colors, legend

    @staticmethod
    def _primitives_from_lattice(lattice, origin=(0, 0, 0), **kwargs):
        """
        Build the wireframe outline of a unit cell.

        :param lattice: object whose ``matrix`` rows are the three cell
        vectors
        :param origin: point subtracted from every corner of the cell
        :param kwargs: passed through to Lines
        :return: a Lines primitive made of the twelve cell edges, given as a
        flat list of alternating start and end points
        """

        corner = -np.array(origin)
        vec_a, vec_b, vec_c = (lattice.matrix[axis] for axis in range(3))

        # one (start, end) tuple per cell edge; the order of the edges and of
        # the additions inside each expression is significant for
        # reproducing the same floating point values
        edges = [
            (corner, corner + vec_a),
            (corner, corner + vec_b),
            (corner, corner + vec_c),
            (corner + vec_a, corner + vec_a + vec_b),
            (corner + vec_a, corner + vec_a + vec_c),
            (corner + vec_b, corner + vec_b + vec_a),
            (corner + vec_b, corner + vec_b + vec_c),
            (corner + vec_c, corner + vec_c + vec_a),
            (corner + vec_c, corner + vec_c + vec_b),
            (corner + vec_a + vec_b, corner + vec_a + vec_b + vec_c),
            (corner + vec_a + vec_c, corner + vec_a + vec_b + vec_c),
            (corner + vec_b + vec_c, corner + vec_a + vec_b + vec_c),
        ]
        flat_points = [point.tolist() for edge in edges for point in edge]

        return Lines(flat_points, **kwargs)

    @staticmethod
    def _get_ellipsoids_from_matrix(matrix):
        """
        Placeholder for turning a matrix-valued site property into ellipsoid
        display data; not implemented yet.

        :param matrix: currently unused
        :raises NotImplementedError: always
        """
        raise NotImplementedError
        # sketch of the intended starting point, kept for reference:
        # matrix = np.array(matrix)
        # eigenvalues, eigenvectors = np.linalg.eig(matrix)

    @staticmethod
    def _primitives_from_site(
        site,
        connected_sites=None,
        origin=(0, 0, 0),
        ellipsoid_site_prop=None,
        all_connected_sites_present=True,
        explicitly_calculate_polyhedra_hull=False,
    ):
        """
        Build the geometric primitives (atoms, half-bonds and coordination
        polyhedron) for a single site.

        Sites must have display_radius and display_color site properties.

        :param site: site to draw
        :param connected_sites: sites bonded to this site; each entry must
        expose the bonded site as ``.site``. A bond is drawn from this site
        to the midpoint of each connection only.
        :param origin: point subtracted from all coordinates
        :param ellipsoid_site_prop: (beta) name of the site property to draw
        ellipsoids from; this currently raises NotImplementedError
        :param all_connected_sites_present: if False, will not calculate
        polyhedra since this would be misleading
        :param explicitly_calculate_polyhedra_hull: if True, triangulate the
        convex hull here and emit a Surface, otherwise emit a Convex
        primitive that is given the bare positions
        :return: dict with "atoms", "bonds" and "polyhedra" keys, each a list
        of primitives
        """

        atoms = []
        bonds = []
        polyhedron = []

        # for disordered structures
        is_ordered = site.is_ordered
        occu_start = 0.0

        # for thermal ellipsoids etc.
        if ellipsoid_site_prop:
            matrix = site.properties[ellipsoid_site_prop]
            ellipsoids = StructureMoleculeComponent._get_ellipsoids_from_matrix(
                matrix)
        else:
            ellipsoids = None

        position = np.subtract(site.coords, origin).tolist()

        # site_color is used for bonds and polyhedra, if multiple colors are
        # defined for site (e.g. a disordered site), then we use grey
        all_colors = set(site.properties["display_color"])
        if len(all_colors) > 1:
            site_color = "#555555"
        else:
            site_color = list(all_colors)[0]

        for idx, (sp, occu) in enumerate(site.species_and_occu.items()):

            if isinstance(sp, DummySpecie):

                cube = Cubes(positions=[position])
                atoms.append(cube)

            else:

                color = site.properties["display_color"][idx]
                radius = site.properties["display_radius"][idx]

                # TODO: make optional/default to None
                # in disordered structures, we fractionally color-code spheres,
                # drawing a sphere segment from phi_end to phi_start
                # (think a sphere pie chart)
                if not is_ordered:
                    phi_frac_end = occu_start + occu
                    phi_frac_start = occu_start
                    occu_start = phi_frac_end
                    phiStart = phi_frac_start * np.pi * 2
                    phiEnd = phi_frac_end * np.pi * 2
                else:
                    phiStart, phiEnd = None, None

                # TODO: add names for labels
                # name = "{}".format(sp)
                # if occu != 1.0:
                #    name += " ({}% occupancy)".format(occu)

                sphere = Spheres(
                    positions=[position],
                    color=color,
                    radius=radius,
                    phiStart=phiStart,
                    phiEnd=phiEnd,
                    ellipsoids=ellipsoids,
                )
                atoms.append(sphere)

        if connected_sites:

            all_positions = [position]
            for connected_site in connected_sites:

                connected_position = np.subtract(connected_site.site.coords,
                                                 origin)
                bond_midpoint = np.add(position, connected_position) / 2

                cylinder = Cylinders(
                    positionPairs=[[position, bond_midpoint.tolist()]],
                    color=site_color)
                bonds.append(cylinder)
                all_positions.append(connected_position.tolist())

            if len(connected_sites) > 3 and all_connected_sites_present:
                if explicitly_calculate_polyhedra_hull:

                    # all_positions = [[0, 0, 0], [0, 0, 10], [0, 10, 0], [10, 0, 0]]
                    # gives...
                    # .convex_hull = [[2, 3, 0], [1, 3, 0], [1, 2, 0], [1, 2, 3]]
                    #
                    # .vertex_neighbor_vertices is NOT a flat list of vertex
                    # indices but an (indptr, indices) pair, so it cannot be
                    # used to index all_positions; the hull facets are what
                    # is needed here
                    try:
                        hull_facets = Delaunay(all_positions).convex_hull
                    except Exception:
                        # best effort: the triangulation can fail for
                        # degenerate point sets, in which case no polyhedron
                        # is drawn for this site
                        hull_facets = []

                    if len(hull_facets):
                        # three consecutive positions per triangular facet,
                        # i.e. the layout shown for the tetrahedron above
                        vertices = [
                            all_positions[int(vertex_idx)]
                            for facet in hull_facets
                            for vertex_idx in facet
                        ]
                        polyhedron = [
                            Surface(positions=vertices, color=site_color)
                        ]

                else:

                    polyhedron = [
                        Convex(positions=all_positions, color=site_color)
                    ]

        return {"atoms": atoms, "bonds": bonds, "polyhedra": polyhedron}

    @staticmethod
    def _get_display_radii_for_sites(
            struct_or_mol,
            radius_strategy="specified_or_average_ionic") -> List[List[float]]:
        """
        Work out the display radius of every species on every site.

        The result is nested (one inner list per site, one radius per
        species on that site) because a disordered site carries several
        species.

        :param struct_or_mol: object whose sites need radii; if it already
        has a "display_radius" site property that is returned unchanged
        :param radius_strategy: one of
        StructureMoleculeComponent.available_radius_strategies
        :return: list of lists of radii
        :raises ValueError: if radius_strategy is not a known strategy
        """

        # an explicitly supplied radius always takes precedence
        if "display_radius" in struct_or_mol.site_properties:
            return struct_or_mol.site_properties["display_radius"]

        known_strategies = StructureMoleculeComponent.available_radius_strategies
        if radius_strategy not in known_strategies:
            raise ValueError(
                "Unknown radius strategy {}, choose from: {}".format(
                    radius_strategy,
                    known_strategies,
                ))

        def lookup_radius(sp):
            # returns None when the strategy has no branch here; only the
            # attribute belonging to the chosen strategy is ever touched
            if radius_strategy == "uniform":
                return 0.5
            if radius_strategy == "atomic":
                return sp.atomic_radius
            if radius_strategy == "specified_or_average_ionic":
                if isinstance(sp, Specie) and sp.oxi_state:
                    return sp.ionic_radius
                return sp.average_ionic_radius
            if radius_strategy == "covalent":
                el = str(getattr(sp, "element", sp))
                return CovalentRadius.radius[el]
            if radius_strategy == "van_der_waals":
                return sp.van_der_waals_radius
            if radius_strategy == "atomic_calculated":
                return sp.atomic_radius_calculated
            return None

        all_radii = []
        for site in struct_or_mol:
            radii_on_site = []
            for sp, _ in site.species_and_occu.items():
                radius = lookup_radius(sp)
                if not radius:
                    # missing (or zero) radius: fall back to a fixed size
                    warnings.warn("Radius unknown for {} and strategy {}, "
                                  "setting to 1.0.".format(
                                      sp, radius_strategy))
                    radius = 1.0
                radii_on_site.append(radius)
            all_radii.append(radii_on_site)

        return all_radii

    @staticmethod
    def _get_sites_to_draw(
        struct_or_mol: Union[Structure, Molecule],
        graph: Union[StructureGraph, MoleculeGraph],
        draw_image_atoms=True,
        bonded_sites_outside_unit_cell=True,
    ):
        """
        Returns a list of site indices and image vectors.

        Every entry is an ``(index, image)`` tuple. Each entry appears only
        once, in order of first appearance.

        :param struct_or_mol: object whose sites are drawn; for a Molecule
        the result is simply every site with a zero image
        :param graph: bonding graph, queried with ``get_connected_sites``
        :param draw_image_atoms: also draw the periodic repeats of sites
        whose fractional coordinates are within 0.05 of 0 or of 1
        :param bonded_sites_outside_unit_cell: also draw sites in other
        cells that are bonded to a site already selected for drawing (one
        level of bonds is followed)
        :return: list of (site index, image vector) tuples
        """

        sites_to_draw = [(idx, (0, 0, 0)) for idx in range(len(struct_or_mol))]

        # trivial in this case
        if isinstance(struct_or_mol, Molecule):
            return sites_to_draw

        if draw_image_atoms:

            for idx, site in enumerate(struct_or_mol):

                # coordinates close to 0 get repeated one cell up (+1),
                # coordinates close to 1 get repeated one cell down (-1)
                for boundary, shift in ((0, 1), (1, -1)):

                    boundary_axes = [
                        axis for axis, f in enumerate(site.frac_coords)
                        if np.allclose(f, boundary, atol=0.05)
                    ]

                    # every non-empty combination of the affected axes
                    for size in range(1, len(boundary_axes) + 1):
                        for axes in combinations(boundary_axes, size):
                            image = tuple(
                                shift * int(axis in axes) for axis in range(3))
                            sites_to_draw.append((idx, image))

        if bonded_sites_outside_unit_cell:

            # TODO: subtle bug here, see mp-5020, expansion logic not quite right
            sites_to_append = []
            for (n, jimage) in sites_to_draw:
                connected_sites = graph.get_connected_sites(n, jimage=jimage)
                for connected_site in connected_sites:
                    if connected_site.jimage != (0, 0, 0):
                        sites_to_append.append(
                            (connected_site.index, connected_site.jimage))
            sites_to_draw += sites_to_append

        # following bonds reaches the same outside site once per bonded
        # neighbour, so drop the repeats: drawing an entry twice only stacks
        # identical primitives on top of each other
        seen = set()
        unique_sites_to_draw = []
        for entry in sites_to_draw:
            if entry not in seen:
                seen.add(entry)
                unique_sites_to_draw.append(entry)

        return unique_sites_to_draw

    @staticmethod
    def get_scene_and_legend(
        graph: Union[StructureGraph, MoleculeGraph],
        name="StructureMoleculeComponent",
        color_scheme="Jmol",
        color_scale=None,
        radius_strategy="specified_or_average_ionic",
        ellipsoid_site_prop=None,
        draw_image_atoms=True,
        bonded_sites_outside_unit_cell=True,
        hide_incomplete_bonds=False,
        explicitly_calculate_polyhedra_hull=False,
    ) -> Tuple[Scene, Dict[str, str]]:
        """
        Build the 3D scene and the color legend for a bonding graph.

        Note that the Structure or Molecule inside the graph is modified:
        the computed "display_radius" and "display_color" site properties
        are added to it.

        :param graph: bonding graph to draw, or None
        :param name: name of the returned scene
        :param color_scheme: see _get_display_colors_and_legend_for_sites
        :param color_scale: see _get_display_colors_and_legend_for_sites
        :param radius_strategy: see _get_display_radii_for_sites
        :param ellipsoid_site_prop: see _primitives_from_site
        :param draw_image_atoms: see _get_sites_to_draw
        :param bonded_sites_outside_unit_cell: see _get_sites_to_draw
        :param hide_incomplete_bonds: only draw a bond if the site at its
        far end is being drawn too
        :param explicitly_calculate_polyhedra_hull: see _primitives_from_site
        :return: (scene, legend); an empty scene and an empty dict if graph
        is None
        """

        scene = Scene(name=name)

        if graph is None:
            return scene, {}

        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)
        site_prop_types = StructureMoleculeComponent._analyze_site_props(
            struct_or_mol)

        # NOTE(review): both helpers return an existing "display_radius" /
        # "display_color" site property untouched, and both properties are
        # stored on struct_or_mol below -- presumably a second call on the
        # same object therefore ignores a changed strategy/scheme; confirm
        radii = StructureMoleculeComponent._get_display_radii_for_sites(
            struct_or_mol, radius_strategy=radius_strategy)
        colors, legend = StructureMoleculeComponent._get_display_colors_and_legend_for_sites(
            struct_or_mol,
            site_prop_types,
            color_scale=color_scale,
            color_scheme=color_scheme,
        )

        struct_or_mol.add_site_property("display_radius", radii)
        struct_or_mol.add_site_property("display_color", colors)

        origin = StructureMoleculeComponent._get_origin(struct_or_mol)

        primitives = defaultdict(list)
        sites_to_draw = StructureMoleculeComponent._get_sites_to_draw(
            struct_or_mol,
            graph,
            draw_image_atoms=draw_image_atoms,
            bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
        )
        # membership is tested once per bond below, so build the lookup once
        # instead of scanning the list every time
        drawn_sites = set(sites_to_draw)

        for (idx, jimage) in sites_to_draw:

            site = struct_or_mol[idx]
            if jimage != (0, 0, 0):
                connected_sites = graph.get_connected_sites(idx, jimage=jimage)
                # periodic repeat: same species and properties, shifted by
                # the image vector
                site = PeriodicSite(
                    site.species_and_occu,
                    np.add(site.frac_coords, jimage),
                    site.lattice,
                    properties=site.properties,
                )
            else:
                connected_sites = graph.get_connected_sites(idx)

            connected_sites_being_drawn = [
                cs for cs in connected_sites
                if (cs.index, cs.jimage) in drawn_sites
            ]
            # polyhedra are only meaningful if every neighbour is in the scene
            all_connected_sites_present = (
                len(connected_sites) == len(connected_sites_being_drawn))
            if hide_incomplete_bonds:
                # only draw bonds if the destination site is also being drawn
                connected_sites = connected_sites_being_drawn

            site_primitives = StructureMoleculeComponent._primitives_from_site(
                site,
                connected_sites=connected_sites,
                all_connected_sites_present=all_connected_sites_present,
                origin=origin,
                ellipsoid_site_prop=ellipsoid_site_prop,
                explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
            )
            for k, v in site_primitives.items():
                primitives[k] += v

        # we are here ...
        # select polyhedra
        # split by atom type at center
        # see if any intersect, if yes split further
        # order sets, with each choice, go to add second set etc if don't intersect
        # they intersect if centre atom forms vertex of another atom (caveat: centre atom may not actually be inside polyhedra! not checking for this, add todo)
        # def _set_intersects() ->bool:
        # def _split_set() ->List: (by type, then..?)
        # def _order_sets()... pick 1, ask can add 2? etc

        if isinstance(struct_or_mol, Structure):
            primitives["unit_cell"].append(
                StructureMoleculeComponent._primitives_from_lattice(
                    struct_or_mol.lattice, origin=origin))

        # one sub-scene per kind of primitive
        sub_scenes = [Scene(name=k, contents=v) for k, v in primitives.items()]
        scene.contents = sub_scenes

        return scene, legend
Exemple #8
0
class StructureMoleculeComponent(MPComponent):
    """
    A component to display pymatgen Structure, Molecule, StructureGraph
    and MoleculeGraph objects.
    """

    # lookup of bonding strategies by class name: one entry per direct
    # subclass of NearNeighbors that exists when this class body is executed
    available_bonding_strategies = {
        subclass.__name__: subclass for subclass in NearNeighbors.__subclasses__()
    }

    # base scene settings; __init__ copies this dict and applies its
    # scene_settings argument on top of the copy
    default_scene_settings = {}

    # whether to persist options such as atomic radii etc.
    persistence = False
    # NOTE(review): presumably the Dash persistence_type for the controls
    # -- confirm where this is consumed
    persistence_type = "local"

    def __init__(
        self,
        struct_or_mol: Optional[
            Union[Structure, StructureGraph, Molecule, MoleculeGraph]
        ] = None,
        id: str = None,
        scene_additions: Optional[Scene] = None,
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: Optional[dict] = None,
        color_scheme: str = DEFAULTS["color_scheme"],
        color_scale: Optional[str] = None,
        radius_strategy: str = DEFAULTS["radius_strategy"],
        unit_cell_choice: str = DEFAULTS["unit_cell_choice"],
        draw_image_atoms: bool = DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell: bool = DEFAULTS[
            "bonded_sites_outside_unit_cell"
        ],
        hide_incomplete_bonds: bool = DEFAULTS["hide_incomplete_bonds"],
        show_compass: bool = DEFAULTS["show_compass"],
        scene_settings: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Set up the stores and the initial scene for a structure or molecule.

        The component may be created without any input, in which case the
        initial scene is empty.

        :param struct_or_mol: input structure or molecule (or a graph of
        either), optional
        :param id: canonical id
        :param scene_additions: extra geometric elements to add to the 3D scene
        :param bonding_strategy: bonding strategy from pymatgen NearNeighbors class
        :param bonding_strategy_kwargs: options for the bonding strategy
        :param color_scheme: color scheme, see Legend class
        :param color_scale: color scale, see Legend class
        :param radius_strategy: radius strategy, see Legend class
        :param unit_cell_choice: which cell to display; the graph callback
        recognises "input", "primitive", "conventional" and "reduced"
        :param draw_image_atoms: whether to draw repeats of atoms on periodic images
        :param bonded_sites_outside_unit_cell: whether to draw sites bonded outside the unit cell
        :param hide_incomplete_bonds: whether to hide or show incomplete bonds
        :param show_compass: whether to hide or show the compass
        :param scene_settings: scene settings (lighting etc.) to pass to
        Simple3DScene, applied on top of default_scene_settings
        :param kwargs: extra keyword arguments to pass to MPComponent
        """

        super().__init__(id=id, default_data=struct_or_mol, **kwargs)

        # what to show for the title_layout if structure/molecule not loaded
        self.default_title = "Crystal Toolkit"

        # class-level defaults first, caller overrides second
        merged_scene_settings = self.default_scene_settings.copy()
        if scene_settings:
            merged_scene_settings.update(scene_settings)
        self.initial_scene_settings = merged_scene_settings

        self.create_store("scene_settings", initial_data=self.initial_scene_settings)

        # unit cell choice and bonding algorithms need to come from a settings
        # object (in a dcc.Store) guaranteed to be present in layout, rather
        # than from the controls themselves -- since these are optional and
        # may not be present in the layout
        graph_generation_options = {
            "bonding_strategy": bonding_strategy,
            "bonding_strategy_kwargs": bonding_strategy_kwargs,
            "unit_cell_choice": unit_cell_choice,
        }
        self.create_store(
            "graph_generation_options", initial_data=graph_generation_options
        )

        display_options = {
            "color_scheme": color_scheme,
            "color_scale": color_scale,
            "radius_strategy": radius_strategy,
            "draw_image_atoms": draw_image_atoms,
            "bonded_sites_outside_unit_cell": bonded_sites_outside_unit_cell,
            "hide_incomplete_bonds": hide_incomplete_bonds,
            "show_compass": show_compass,
        }
        self.create_store("display_options", initial_data=display_options)

        # extra scene contents are stored wrapped in their own named Scene
        initial_scene_additions = (
            Scene(name="scene_additions", contents=scene_additions).to_json()
            if scene_additions
            else None
        )
        self.create_store("scene_additions", initial_data=initial_scene_additions)

        if struct_or_mol:
            # graph is cached explicitly, this isn't necessary but is an
            # optimization so that graph is only re-generated if bonding
            # algorithm changes
            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=bonding_strategy,
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )
        else:
            # component could be initialized without a structure, in which case
            # an empty scene should be displayed
            graph = None

        scene, legend = self.get_scene_and_legend(
            graph,
            name=self.id(),
            scene_additions=self.initial_data["scene_additions"],
            **self.initial_data["display_options"],
        )

        # only inputs that carry a lattice have one to remember
        if struct_or_mol and hasattr(struct_or_mol, "lattice"):
            self._lattice = struct_or_mol.lattice

        self.create_store("legend_data", initial_data=legend)
        self.create_store("graph", initial_data=graph)

        # this is used by a Simple3DScene component, not a dcc.Store
        self._initial_data["scene"] = scene

    def generate_callbacks(self, app, cache):

        # a lot of the verbosity in this callback is to support custom bonding
        # this is not the format CutOffDictNN expects (since that is not JSON
        # serializable), so we store as a list of tuples instead
        # TODO: make CutOffDictNN args JSON serializable
        app.clientside_callback(
            """
            function (bonding_strategy, custom_cutoffs_rows, unit_cell_choice) {
            
                const bonding_strategy_kwargs = {}
                if (bonding_strategy === 'CutOffDictNN') {
                    const cut_off_dict = []
                    custom_cutoffs_rows.forEach(function(row) {
                        cut_off_dict.push([row['A'], row['B'], parseFloat(row['A—B'])])
                    })
                    bonding_strategy_kwargs.cut_off_dict = cut_off_dict
                }
            
                return {
                    bonding_strategy: bonding_strategy,
                    bonding_strategy_kwargs: bonding_strategy_kwargs,
                    unit_cell_choice: unit_cell_choice
                }
            }
            """,
            Output(self.id("graph_generation_options"), "data"),
            [
                Input(self.id("bonding_algorithm"), "value"),
                Input(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Input(self.id("unit-cell-choice"), "value"),
            ],
        )

        app.clientside_callback(
            """
            function (values, options) {
                const visibility = {}
                options.forEach(function (opt) {
                    visibility[opt.value] = Boolean(values.includes(opt.value))
                })
                return visibility
            }
            """,
            Output(self.id("scene"), "toggleVisibility"),
            [Input(self.id("hide-show"), "value")],
            [State(self.id("hide-show"), "options")],
        )

        app.clientside_callback(
            """
            function (colorScheme, radiusStrategy, drawOptions, displayOptions) {
            
                const newDisplayOptions = Object.assign({}, displayOptions);
                newDisplayOptions.color_scheme = colorScheme
                newDisplayOptions.radius_strategy = radiusStrategy
                newDisplayOptions.draw_image_atoms = drawOptions.includes('draw_image_atoms')
                newDisplayOptions.bonded_sites_outside_unit_cell =  drawOptions.includes('bonded_sites_outside_unit_cell')
                newDisplayOptions.hide_incomplete_bonds = drawOptions.includes('hide_incomplete_bonds')

                return newDisplayOptions
            }
            """,
            Output(self.id("display_options"), "data"),
            [
                Input(self.id("color-scheme"), "value"),
                Input(self.id("radius_strategy"), "value"),
                Input(self.id("draw_options"), "value"),
            ],
            [State(self.id("display_options"), "data")],
        )

        @app.callback(
            Output(self.id("graph"), "data"),
            [
                Input(self.id("graph_generation_options"), "data"),
                Input(self.id(), "data"),
            ],
            [State(self.id("graph"), "data")],
        )
        @cache.memoize()
        def update_graph(graph_generation_options, struct_or_mol, current_graph):

            if not struct_or_mol:
                raise PreventUpdate

            struct_or_mol = self.from_data(struct_or_mol)
            current_graph = self.from_data(current_graph)

            bonding_strategy_kwargs = graph_generation_options[
                "bonding_strategy_kwargs"
            ]

            # TODO: add additional check here?
            unit_cell_choice = graph_generation_options["unit_cell_choice"]
            if isinstance(struct_or_mol, Structure):
                if unit_cell_choice != "input":
                    if unit_cell_choice == "primitive":
                        struct_or_mol = struct_or_mol.get_primitive_structure()
                    elif unit_cell_choice == "conventional":
                        sga = SpacegroupAnalyzer(struct_or_mol)
                        struct_or_mol = sga.get_conventional_standard_structure()
                    elif unit_cell_choice == "reduced":
                        struct_or_mol = struct_or_mol.get_reduced_structure()

            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=graph_generation_options["bonding_strategy"],
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )

            if (
                current_graph
                and graph.structure == current_graph.structure
                and graph == current_graph
            ):
                raise PreventUpdate

            return graph

        @app.callback(
            [
                Output(self.id("scene"), "data"),
                Output(self.id("legend_data"), "data"),
                Output(self.id("color-scheme"), "options"),
            ],
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
                Input(self.id("scene_additions"), "data"),
            ],
        )
        @cache.memoize()
        def update_scene_and_legend_and_colors(graph, display_options, scene_additions):
            if not graph or not display_options:
                raise PreventUpdate
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            scene, legend = self.get_scene_and_legend(
                graph,
                name=self.id(),
                **display_options,
                scene_additions=scene_additions,
            )

            color_options = [
                {"label": "Jmol", "value": "Jmol"},
                {"label": "VESTA", "value": "VESTA"},
                {"label": "Accessible", "value": "accessible"},
            ]
            struct_or_mol = self._get_struct_or_mol(graph)
            site_props = Legend(struct_or_mol).analyze_site_props(struct_or_mol)
            for site_prop_type in ("scalar", "categorical"):
                if site_prop_type in site_props:
                    for prop in site_props[site_prop_type]:
                        color_options += [
                            {"label": f"Site property: {prop}", "value": prop}
                        ]

            return scene, legend, color_options

        @app.callback(
            Output(self.id("scene"), "downloadRequest"),
            [Input(self.id("screenshot_button"), "n_clicks")],
            [State(self.id("scene"), "downloadRequest"), State(self.id(), "data")],
        )
        @cache.memoize()
        def trigger_screenshot(n_clicks, current_requests, struct_or_mol):
            """
            Build a new PNG download request for the scene after a button click.

            :param n_clicks: click count of the screenshot button (None if never clicked)
            :param current_requests: the scene's previous download request, if any
            :param struct_or_mol: serialized contents of the component's main store
            :return: a download request dict with an incremented request counter
            """
            if n_clicks is None:
                raise PreventUpdate

            # TODO: this will break if store is structure/molecule graph ...
            loaded = self.from_data(struct_or_mol)
            formula = loaded.composition.reduced_formula
            # only objects exposing get_space_group_info contribute a space group
            spgrp = (
                loaded.get_space_group_info()[0]
                if hasattr(loaded, "get_space_group_info")
                else ""
            )

            # continue counting from the previous request, starting at 1
            previous_count = current_requests["n_requests"] if current_requests else 0

            return {
                "n_requests": previous_count + 1,
                "filename": f"{formula}-{spgrp}-crystal-toolkit.png",
                "filetype": "png",
            }

        @app.callback(
            [
                Output(self.id("legend_container"), "children"),
                Output(self.id("title_container"), "children"),
            ],
            [Input(self.id("legend_data"), "data")],
        )
        @cache.memoize()
        def update_legend_and_title(legend):
            """
            Re-render the legend and the title when the legend store changes.

            :param legend: serialized contents of the "legend_data" store
            :return: new children for the legend container and the title container
            """
            if legend:
                legend_data = self.from_data(legend)
                return self._make_legend(legend_data), self._make_title(legend_data)

            # nothing stored, so leave both containers untouched
            raise PreventUpdate

        @app.callback(
            [
                Output(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Output(self.id("bonding_algorithm_custom_cutoffs_container"), "style"),
            ],
            [Input(self.id("bonding_algorithm"), "value")],
            [
                State(self.id("graph"), "data"),
                State(self.id("bonding_algorithm_custom_cutoffs_container"), "style"),
            ],
        )
        @cache.memoize()
        def update_custom_bond_options(bonding_algorithm, graph, current_style):
            """
            Show or hide the custom cut-off table and refresh its rows.

            :param bonding_algorithm: value of the bonding algorithm dropdown
            :param graph: serialized contents of the "graph" store
            :param current_style: current style of the table's container
            :return: the table rows and the new container style
            """
            if not graph:
                raise PreventUpdate

            hidden = {"display": "none"}
            if bonding_algorithm == "CutOffDictNN":
                # the table is only relevant for user-defined cut-offs
                style = {}
            elif current_style == hidden:
                # no need to update rows if we're not showing them
                raise PreventUpdate
            else:
                style = hidden

            rows = self._make_bonding_algorithm_custom_cuffoff_data(
                self.from_data(graph)
            )

            return rows, style

    def _make_legend(self, legend):
        """
        Build the legend as a row of static, colored buttons.

        :param legend: legend data holding a "colors" mapping (hex color string
        to label) and a "composition" dict; may be empty or None
        :return: a Field with one colored button per legend entry, or an empty
        html.Div if no legend data is available
        """

        if not legend:
            return html.Div(id=self.id("legend"))

        def get_font_color(hex_code):
            # ensures contrasting font color for background color
            # hex_code is read as "#rrggbb"; the weights below estimate how
            # bright the color appears, bright backgrounds get black text
            c = tuple(int(hex_code[1:][i : i + 2], 16) for i in (0, 2, 4))
            if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5:
                return "#000000"
            return "#ffffff"

        try:
            formula = Composition.from_dict(legend["composition"]).reduced_formula
        except Exception:
            # deliberately best-effort, but no longer a bare except so that
            # KeyboardInterrupt/SystemExit are not swallowed
            # TODO: fix legend for Dummy Specie compositions
            formula = "Unknown"

        # order the entries by where their label first appears in the formula;
        # labels not found in the formula (find() == -1) sort to the front
        legend_colors = OrderedDict(
            sorted(legend["colors"].items(), key=lambda item: formula.find(item[1]))
        )

        legend_elements = [
            Button(
                html.Span(
                    name, className="icon", style={"color": get_font_color(color)}
                ),
                kind="static",
                style={"backgroundColor": color},
            )
            for color, name in legend_colors.items()
        ]

        return Field(
            [Control(el, style={"marginRight": "0.2rem"}) for el in legend_elements],
            id=self.id("legend"),
            grouped=True,
        )

    def _make_title(self, legend):
        """
        Build the title heading from the composition stored in the legend.

        :param legend: legend data, expected to hold a "composition" entry;
        may be empty or None
        :return: an H1 with the formula (digit runs rendered as subscripts), or
        an H1 with the default title if no composition is available
        """

        if not legend or (not legend.get("composition", None)):
            return H1(self.default_title, id=self.id("title"))

        composition = legend["composition"]
        if isinstance(composition, dict):

            # TODO: make Composition handle DummySpecie for title
            try:
                formula = Composition.from_dict(composition).iupac_formula
                # split the formula into alternating text and digit runs
                formula_parts = re.findall(r"[^\d_]+|\d+", formula)
                formula_components = [
                    html.Sub(part) if part.isnumeric() else html.Span(part)
                    for part in formula_parts
                ]
            except Exception:
                # fall back to the raw keys of the composition dict; str()
                # guarantees the children are plain strings
                formula_components = [str(key) for key in composition]
        else:
            # previously this case left formula_components undefined and
            # raised UnboundLocalError; show the value as plain text instead
            formula_components = str(composition)

        return H1(
            formula_components, id=self.id("title"), style={"display": "inline-block"}
        )

    @staticmethod
    def _make_bonding_algorithm_custom_cuffoff_data(graph):
        if not graph:
            return [{"A": None, "B": None, "A—B": None}]
        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)
        # can't use type_of_specie because it doesn't work with disordered structures
        species = set(
            map(
                str,
                chain.from_iterable(
                    [list(c.keys()) for c in struct_or_mol.species_and_occu]
                ),
            )
        )
        rows = [
            {"A": combination[0], "B": combination[1], "A—B": 0}
            for combination in combinations_with_replacement(species, 2)
        ]
        return rows

    @property
    def _sub_layouts(self):
        """
        Build the individual layout pieces of this component.

        :return: a dict with the keys "struct", "screenshot", "options",
        "title" and "legend", each mapping to a Dash layout element

        Every access builds new layout elements, nothing is cached here.
        """

        # the 3D scene itself, filling whatever container it is placed in
        struct_layout = html.Div(
            Simple3DSceneComponent(
                id=self.id("scene"),
                data=self.initial_data["scene"],
                settings=self.initial_scene_settings,
            ),
            style={
                "width": "100%",
                "height": "100%",
                "overflow": "hidden",
                "margin": "0 auto",
            },
        )

        # button whose id "screenshot_button" is referenced by the
        # trigger_screenshot callback
        screenshot_layout = html.Div(
            [
                Button(
                    [Icon(), html.Span(), "Download Image"],
                    kind="primary",
                    id=self.id("screenshot_button"),
                )
            ],
            # TODO: change to "bottom" when dropdown included
            style={"verticalAlign": "top", "display": "inline-block"},
        )

        # NOTE(review): title and legend read self._initial_data while the rest
        # of this property reads self.initial_data -- confirm both refer to the
        # same mapping
        title_layout = html.Div(
            self._make_title(self._initial_data["legend_data"]),
            id=self.id("title_container"),
        )

        legend_layout = html.Div(
            self._make_legend(self._initial_data["legend_data"]),
            id=self.id("legend_container"),
        )

        # human-readable dropdown label -> bonding strategy name
        nn_mapping = {
            "CrystalNN": "CrystalNN",
            "Custom Bonds": "CutOffDictNN",
            "Jmol Bonding": "JmolNN",
            "Minimum Distance (10% tolerance)": "MinimumDistanceNN",
            "O'Keeffe's Algorithm": "MinimumOKeeffeNN",
            "Hoppe's ECoN Algorithm": "EconNN",
            "Brunner's Reciprocal Algorithm": "BrunnerNN_reciprocal",
        }

        bonding_algorithm = dcc.Dropdown(
            options=[{"label": k, "value": v} for k, v in nn_mapping.items()],
            value=self.initial_data["graph_generation_options"]["bonding_strategy"],
            clearable=False,
            id=self.id("bonding_algorithm"),
            persistence=self.persistence,
            persistence_type=self.persistence_type,
        )

        # editable table of pairwise cut-offs, hidden initially; its container
        # style is an output of the update_custom_bond_options callback
        bonding_algorithm_custom_cutoffs = html.Div(
            [
                html.Br(),
                dt.DataTable(
                    columns=[
                        {"name": "A", "id": "A"},
                        {"name": "B", "id": "B"},
                        {"name": "A—B /Å", "id": "A—B"},
                    ],
                    editable=True,
                    # NOTE(review): presumably "default" holds the initial
                    # structure/molecule; a missing key yields a placeholder row
                    data=self._make_bonding_algorithm_custom_cuffoff_data(
                        self.initial_data.get("default")
                    ),
                    id=self.id("bonding_algorithm_custom_cutoffs"),
                ),
                html.Br(),
            ],
            id=self.id("bonding_algorithm_custom_cutoffs_container"),
            style={"display": "none"},
        )

        # all user-facing controls, in the order they are displayed
        options_layout = Field(
            [
                #  TODO: hide if molecule
                html.Label("Change unit cell:", className="mpc-label"),
                html.Div(
                    dcc.Dropdown(
                        options=[
                            {"label": "Input cell", "value": "input"},
                            {"label": "Primitive cell", "value": "primitive"},
                            {"label": "Conventional cell", "value": "conventional"},
                            {"label": "Reduced cell", "value": "reduced"},
                        ],
                        value="input",
                        clearable=False,
                        id=self.id("unit-cell-choice"),
                        persistence=self.persistence,
                        persistence_type=self.persistence_type,
                    ),
                    className="mpc-control",
                ),
                html.Div(
                    [
                        html.Label("Change bonding algorithm: ", className="mpc-label"),
                        bonding_algorithm,
                        bonding_algorithm_custom_cutoffs,
                    ]
                ),
                html.Label("Change color scheme:", className="mpc-label"),
                html.Div(
                    dcc.Dropdown(
                        options=[
                            {"label": "VESTA", "value": "VESTA"},
                            {"label": "Jmol", "value": "Jmol"},
                            {"label": "Accessible", "value": "accessible"},
                        ],
                        value=self.initial_data["display_options"]["color_scheme"],
                        clearable=False,
                        persistence=self.persistence,
                        persistence_type=self.persistence_type,
                        id=self.id("color-scheme"),
                    ),
                    className="mpc-control",
                ),
                html.Label("Change atomic radii:", className="mpc-label"),
                html.Div(
                    dcc.Dropdown(
                        options=[
                            {"label": "Ionic", "value": "specified_or_average_ionic"},
                            {"label": "Covalent", "value": "covalent"},
                            {"label": "Van der Waals", "value": "van_der_waals"},
                            {
                                "label": f"Uniform ({Legend.uniform_radius}Å)",
                                "value": "uniform",
                            },
                        ],
                        value=self.initial_data["display_options"]["radius_strategy"],
                        clearable=False,
                        persistence=self.persistence,
                        persistence_type=self.persistence_type,
                        id=self.id("radius_strategy"),
                    ),
                    className="mpc-control",
                ),
                html.Label("Draw options:", className="mpc-label"),
                html.Div(
                    [
                        dcc.Checklist(
                            options=[
                                {
                                    "label": "Draw repeats of atoms on periodic boundaries",
                                    "value": "draw_image_atoms",
                                },
                                {
                                    "label": "Draw atoms outside unit cell bonded to "
                                    "atoms within unit cell",
                                    "value": "bonded_sites_outside_unit_cell",
                                },
                                {
                                    "label": "Hide bonds where destination atoms are not shown",
                                    "value": "hide_incomplete_bonds",
                                },
                            ],
                            # tick exactly those options that are enabled in the
                            # initial display options
                            value=[
                                opt
                                for opt in (
                                    "draw_image_atoms",
                                    "bonded_sites_outside_unit_cell",
                                    "hide_incomplete_bonds",
                                )
                                if self.initial_data["display_options"][opt]
                            ],
                            labelStyle={"display": "block"},
                            inputClassName="mpc-radio",
                            id=self.id("draw_options"),
                            persistence=self.persistence,
                            persistence_type=self.persistence_type,
                        )
                    ]
                ),
                html.Label("Hide/show:", className="mpc-label"),
                html.Div(
                    [
                        dcc.Checklist(
                            options=[
                                {"label": "Atoms", "value": "atoms"},
                                {"label": "Bonds", "value": "bonds"},
                                {"label": "Unit cell", "value": "unit_cell"},
                                {"label": "Polyhedra", "value": "polyhedra"},
                                {"label": "Axes", "value": "axes"},
                            ],
                            # "axes" is offered as an option but not ticked
                            # by default
                            value=["atoms", "bonds", "unit_cell", "polyhedra"],
                            labelStyle={"display": "block"},
                            inputClassName="mpc-radio",
                            id=self.id("hide-show"),
                            persistence=self.persistence,
                            persistence_type=self.persistence_type,
                        )
                    ],
                    className="mpc-control",
                ),
            ]
        )

        return {
            "struct": struct_layout,
            "screenshot": screenshot_layout,
            "options": options_layout,
            "title": title_layout,
            "legend": legend_layout,
        }

    def layout(self, size: str = "400px") -> html.Div:
        """
        Wrap the 3D scene in a square container.

        :param size: a CSS string used for both the width and the height of the Div
        :return: A html.Div containing the 3D structure or molecule
        """
        scene_layout = self._sub_layouts["struct"]
        square = {"width": size, "height": size}
        return html.Div(scene_layout, style=square)

    @staticmethod
    def _preprocess_input_to_graph(
        input: Union[Structure, StructureGraph, Molecule, MoleculeGraph],
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: Optional[Dict] = None,
    ) -> Union[StructureGraph, MoleculeGraph]:
        """
        Turn a structure or molecule into a graph with bonding information.

        :param input: a structure, molecule, or an already computed graph;
        a graph is returned unchanged
        :param bonding_strategy: key of available_bonding_strategies; ignored
        for disordered structures, which always use "CutOffDictNN"
        :param bonding_strategy_kwargs: keyword arguments for the bonding
        strategy; for "CutOffDictNN" the "cut_off_dict" entry may be given as
        an iterable of (A, B, cut-off) entries. This dict is not modified.
        :return: a StructureGraph or MoleculeGraph; if computing bonds fails,
        a graph without any bonds
        :raises ValueError: if the bonding strategy is not supported
        """

        if isinstance(input, Structure):

            # ensure fractional co-ordinates are normalized to be in [0,1)
            # (this is actually not guaranteed by Structure)
            try:
                input = input.as_dict(verbosity=0)
            except TypeError:
                # TODO: remove this, necessary for Slab(?), some structure subclasses don't have verbosity
                input = input.as_dict()
            for site in input["sites"]:
                site["abc"] = np.mod(site["abc"], 1)
            input = Structure.from_dict(input)

            if not input.is_ordered:
                # calculating bonds in disordered structures is currently very flaky
                bonding_strategy = "CutOffDictNN"

        # we assume most uses of this class will give a structure as an input argument,
        # meaning we have to calculate the graph for bonding information, however if
        # the graph is already known and supplied, we will use that
        if isinstance(input, (StructureGraph, MoleculeGraph)):
            return input

        strategies = StructureMoleculeComponent.available_bonding_strategies
        if bonding_strategy not in strategies:
            raise ValueError(
                "Bonding strategy not supported. Please supply a name "
                "of a NearNeighbor subclass, choose from: {}".format(
                    ", ".join(strategies.keys())
                )
            )

        # work on a shallow copy: converting "cut_off_dict" in the caller's
        # dict would break the next call that re-uses the same kwargs
        strategy_kwargs = dict(bonding_strategy_kwargs or {})
        if bonding_strategy == "CutOffDictNN" and "cut_off_dict" in strategy_kwargs:
            cut_offs = strategy_kwargs["cut_off_dict"]
            # a dict is assumed to be keyed by (A, B) tuples already
            if not isinstance(cut_offs, dict):
                # TODO: remove this hack by making args properly JSON serializable
                strategy_kwargs["cut_off_dict"] = {
                    (x[0], x[1]): x[2] for x in cut_offs
                }
        strategy = strategies[bonding_strategy](**strategy_kwargs)

        graph_cls = StructureGraph if isinstance(input, Structure) else MoleculeGraph
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                graph = graph_cls.with_local_env_strategy(input, strategy)
        except Exception:
            # for some reason computing bonds failed, so let's not have any bonds(!)
            # (Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate)
            graph = graph_cls.with_empty_graph(input)

        return graph

    @staticmethod
    def _get_struct_or_mol(
        graph: Union[StructureGraph, MoleculeGraph, Structure, Molecule]
    ) -> Union[Structure, Molecule]:
        """
        Return the structure or molecule underlying a graph.

        :param graph: a StructureGraph, MoleculeGraph, Structure or Molecule
        :return: the graph's structure or molecule; a Structure or Molecule is
        returned unchanged
        :raises ValueError: if the input is none of the supported types
        """
        if isinstance(graph, StructureGraph):
            return graph.structure
        if isinstance(graph, MoleculeGraph):
            return graph.molecule
        if isinstance(graph, (Structure, Molecule)):
            return graph
        # same exception type as before, now with a message for debugging
        raise ValueError(
            "Expected a Structure, Molecule, StructureGraph or MoleculeGraph, "
            f"got {type(graph).__name__}"
        )

    @staticmethod
    def get_scene_and_legend(
        graph: Optional[Union[StructureGraph, MoleculeGraph]],
        name,
        color_scheme=DEFAULTS["color_scheme"],
        color_scale=None,
        radius_strategy=DEFAULTS["radius_strategy"],
        draw_image_atoms=DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell=DEFAULTS["bonded_sites_outside_unit_cell"],
        hide_incomplete_bonds=DEFAULTS["hide_incomplete_bonds"],
        explicitly_calculate_polyhedra_hull=False,
        scene_additions=None,
        show_compass=DEFAULTS["show_compass"],
    ) -> Tuple[Scene, Dict[str, str]]:
        """
        Generate the 3D scene and the legend for a graph.

        :param graph: graph to draw; if None an empty scene and an empty
        legend are returned
        :param name: name for the scene, by default the id of the component
        :param color_scheme: color scheme, passed to Legend
        :param color_scale: passed to Legend as cmap_range
        :param radius_strategy: passed to Legend as radius_scheme
        :param draw_image_atoms: passed to StructureGraph.get_scene
        :param bonded_sites_outside_unit_cell: passed to StructureGraph.get_scene
        :param hide_incomplete_bonds: passed to StructureGraph.get_scene as
        hide_incomplete_edges
        :param explicitly_calculate_polyhedra_hull: passed to
        StructureGraph.get_scene
        :param scene_additions: not supported yet, any truthy value raises
        NotImplementedError (only checked if a graph is given)
        :param show_compass: whether the lattice axes are visible
        :return: (scene.to_json(), legend.get_legend()) for a graph, or
        (Scene, {}) if graph is None
        :raises NotImplementedError: if scene_additions is given with a graph
        """

        # default scene name will be name of component, "_ct_..."
        # strip leading _ since this will cause problems in JavaScript land
        scene = Scene(name=name[1:])

        if graph is None:
            # NOTE(review): this returns the Scene object itself while the
            # path below returns scene.to_json() -- confirm callers handle both
            return scene, {}

        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)

        # TODO: add radius_scale
        legend = Legend(
            struct_or_mol,
            color_scheme=color_scheme,
            radius_scheme=radius_strategy,
            cmap_range=color_scale,
        )

        if isinstance(graph, StructureGraph):
            scene = graph.get_scene(
                draw_image_atoms=draw_image_atoms,
                bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
                hide_incomplete_edges=hide_incomplete_bonds,
                explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
                legend=legend,
            )
        elif isinstance(graph, MoleculeGraph):
            scene = graph.get_scene(legend=legend)

        # NOTE(review): unlike the empty scene above, the full name including
        # its leading character is used here -- confirm this is intended
        scene.name = name

        if hasattr(struct_or_mol, "lattice"):
            axes = struct_or_mol.lattice._axes_from_lattice()
            # TODO: fix pop-in ?
            axes.visible = show_compass
            scene.contents.append(axes)

        if scene_additions:
            # TODO: need a Scene.from_json() to make this work
            # (the append that used to follow this raise was unreachable)
            raise NotImplementedError

        return scene.to_json(), legend.get_legend()

    def screenshot_layout(self):
        """
        :return: The sub-layout holding the button that triggers a screenshot download.
        """
        sub_layouts = self._sub_layouts
        return sub_layouts["screenshot"]

    def options_layout(self):
        """
        :return: The sub-layout holding the options to change the appearance, bonding, etc.
        """
        sub_layouts = self._sub_layouts
        return sub_layouts["options"]

    def title_layout(self):
        """
        :return: The sub-layout showing the composition of the structure/molecule as a title.
        """
        sub_layouts = self._sub_layouts
        return sub_layouts["title"]

    def legend_layout(self):
        """
        :return: The sub-layout showing the legend for the structure/molecule.
        """
        sub_layouts = self._sub_layouts
        return sub_layouts["legend"]
Exemple #9
0
# the Dash application; everything below runs once at import time
app = dash.Dash('')
app.title = "MP Viewer"

# serve component scripts from this server, and add an external stylesheet
app.scripts.config.serve_locally = True
app.css.append_css(
    {'external_url': 'https://codepen.io/chriddyp/pen/bWLwgP.css'})

# NOTE(review): MPRester() is created without arguments, so presumably the
# API key comes from the environment or a config file -- confirm
mpr = MPRester()

# the path is relative, so it is resolved against the current working directory
DEFAULT_STRUCTURE = loadfn('default_structure.json')
DEFAULT_COLOR_SCHEME = 'VESTA'
DEFAULT_BONDING_METHOD = 'MinimumOKeeffeNN'

# class names of the direct NearNeighbors subclasses that exist at import time
AVAILABLE_BONDING_METHODS = [
    str(c.__name__) for c in NearNeighbors.__subclasses__()
]

# to help with readability, each component of the app is defined
# in a modular way below; these are all then included in the app.layout

# text box plus button to load a structure by formula or mp-id
LAYOUT_FORMULA_INPUT = html.Div([
    dcc.Input(id='input-box',
              type='text',
              placeholder='Enter a formula or mp-id'),
    html.Span(' '),
    html.Button('Load', id='button')
])

LAYOUT_VISIBILITY_OPTIONS = html.Div([
    dcc.Checklist(id='visibility_options',
Exemple #10
0
class StructureMoleculeComponent(MPComponent):
    """
    A component to display pymatgen Structure, Molecule, StructureGraph
    and MoleculeGraph objects.
    """

    # bonding strategy name -> NearNeighbors subclass; __subclasses__() only
    # lists direct subclasses that exist when this class body is executed
    available_bonding_strategies = {
        subclass.__name__: subclass
        for subclass in NearNeighbors.__subclasses__()
    }

    # defaults for the scene; __init__ copies this dict and applies the
    # scene_settings argument on top
    default_scene_settings = {
        "extractAxis": True,
        # For visual diff testing, we change the renderer
        # to SVG since this WebGL support is more difficult
        # in headless browsers / CI.
        "renderer": "svg" if SETTINGS.TEST_MODE else "webgl",
        "secondaryObjectView": False,
    }

    # what to show for the title_layout if structure/molecule not loaded
    default_title = "Crystal Toolkit"

    # human-readable label to file extension
    # downloading Molecules has not yet been added
    # NOTE(review): each value looks like keyword arguments for the export
    # call ("fmt", "symprec") -- confirm against the code that reads this
    download_options = {
        "Structure": {
            "CIF (Symmetrized)": {
                "fmt": "cif",
                "symprec": EmmetSettings().SYMPREC
            },
            "CIF": {
                "fmt": "cif"
            },
            "POSCAR": {
                "fmt": "poscar"
            },
            "JSON": {
                "fmt": "json"
            },
            "Prismatic": {
                "fmt": "prismatic"
            },
            "VASP Input Set (MPRelaxSet)": {},  # special
        }
    }

    def __init__(
        self,
        struct_or_mol: Optional[Union[Structure, StructureGraph, Molecule,
                                      MoleculeGraph]] = None,
        id: Optional[str] = None,
        className: str = "box",
        scene_additions: Optional[Scene] = None,
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: Optional[dict] = None,
        color_scheme: str = DEFAULTS["color_scheme"],
        color_scale: Optional[str] = None,
        radius_strategy: str = DEFAULTS["radius_strategy"],
        unit_cell_choice: str = DEFAULTS["unit_cell_choice"],
        draw_image_atoms: bool = DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell: bool = DEFAULTS[
            "bonded_sites_outside_unit_cell"],
        hide_incomplete_bonds: bool = DEFAULTS["hide_incomplete_bonds"],
        show_compass: bool = DEFAULTS["show_compass"],
        scene_settings: Optional[Dict] = None,
        group_by_site_property: Optional[str] = None,
        show_legend: bool = DEFAULTS["show_legend"],
        show_settings: bool = DEFAULTS["show_settings"],
        show_controls: bool = DEFAULTS["show_controls"],
        show_expand_button: bool = DEFAULTS["show_expand_button"],
        show_image_button: bool = DEFAULTS["show_image_button"],
        show_export_button: bool = DEFAULTS["show_export_button"],
        show_position_button: bool = DEFAULTS["show_position_button"],
        **kwargs,
    ):
        """
        Create a StructureMoleculeComponent from a structure or molecule.

        :param struct_or_mol: input structure or molecule, if omitted an empty
        scene is generated
        :param id: canonical id
        :param className: stored on the component as self.className
        :param scene_additions: extra geometric elements to add to the 3D scene
        :param bonding_strategy: bonding strategy from pymatgen NearNeighbors class
        :param bonding_strategy_kwargs: options for the bonding strategy
        :param color_scheme: color scheme, see Legend class
        :param color_scale: color scale, see Legend class
        :param radius_strategy: radius strategy, see Legend class
        :param unit_cell_choice: unit cell choice, passed to _preprocess_structure
        :param draw_image_atoms: whether to draw repeats of atoms on periodic images
        :param bonded_sites_outside_unit_cell: whether to draw sites bonded outside the unit cell
        :param hide_incomplete_bonds: whether to hide or show incomplete bonds
        :param show_compass: whether to hide or show the compass
        :param scene_settings: scene settings (lighting etc.) to pass to CrystalToolkitScene
        :param group_by_site_property: a site property used for grouping of atoms for mouseover/interaction,
        e.g. Wyckoff label
        :param show_legend: show or hide legend panel within the scene
        :param show_settings: stored on the component as self.show_settings
        :param show_controls: show or hide scene control bar
        :param show_expand_button: show or hide the full screen button within the scene control bar
        :param show_image_button: show or hide the image download button within the scene control bar
        :param show_export_button: show or hide the file export button within the scene control bar
        :param show_position_button: show or hide the revert position button within the scene control bar
        :param kwargs: extra keyword arguments to pass to MPComponent
        """

        super().__init__(id=id, default_data=struct_or_mol, **kwargs)
        self.className = className
        self.show_legend = show_legend
        self.show_settings = show_settings
        self.show_controls = show_controls
        self.show_expand_button = show_expand_button
        self.show_image_button = show_image_button
        self.show_export_button = show_export_button
        self.show_position_button = show_position_button

        # copy so that the class-level defaults are never modified
        self.initial_scene_settings = self.default_scene_settings.copy()
        if scene_settings:
            self.initial_scene_settings.update(scene_settings)

        self.create_store("scene_settings",
                          initial_data=self.initial_scene_settings)

        # unit cell choice and bonding algorithms need to come from a settings
        # object (in a dcc.Store) guaranteed to be present in layout, rather
        # than from the controls themselves -- since these are optional and
        # may not be present in the layout
        self.create_store(
            "graph_generation_options",
            initial_data={
                "bonding_strategy": bonding_strategy,
                "bonding_strategy_kwargs": bonding_strategy_kwargs,
                "unit_cell_choice": unit_cell_choice,
            },
        )

        # the keys of this store are expanded below as keyword arguments of
        # get_scene_and_legend
        self.create_store(
            "display_options",
            initial_data={
                "color_scheme": color_scheme,
                "color_scale": color_scale,
                "radius_strategy": radius_strategy,
                "draw_image_atoms": draw_image_atoms,
                "bonded_sites_outside_unit_cell":
                bonded_sites_outside_unit_cell,
                "hide_incomplete_bonds": hide_incomplete_bonds,
                "show_compass": show_compass,
                "group_by_site_property": group_by_site_property,
            },
        )

        # scene additions are wrapped in their own Scene and stored as JSON
        if scene_additions:
            initial_scene_additions = Scene(
                name="scene_additions", contents=scene_additions).to_json()
        else:
            initial_scene_additions = None
        self.create_store("scene_additions",
                          initial_data=initial_scene_additions)

        if struct_or_mol:
            # graph is cached explicitly, this isn't necessary but is an
            # optimization so that graph is only re-generated if bonding
            # algorithm changes
            struct_or_mol = self._preprocess_structure(
                struct_or_mol, unit_cell_choice=unit_cell_choice)
            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=bonding_strategy,
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )
            scene, legend = self.get_scene_and_legend(
                graph,
                scene_additions=self.initial_data["scene_additions"],
                **self.initial_data["display_options"],
            )
            # self._lattice is only set when the input has a lattice
            if hasattr(struct_or_mol, "lattice"):
                self._lattice = struct_or_mol.lattice
        else:
            # component could be initialized without a structure, in which case
            # an empty scene should be displayed
            graph = None
            scene, legend = self.get_scene_and_legend(
                None,
                scene_additions=self.initial_data["scene_additions"],
                **self.initial_data["display_options"],
            )

        self.create_store("legend_data", initial_data=legend)
        self.create_store("graph", initial_data=graph)

        # this is used by a CrystalToolkitScene component, not a dcc.Store
        self._initial_data["scene"] = scene

        # hide axes inset for molecules
        if isinstance(struct_or_mol, Molecule) or isinstance(
                struct_or_mol, MoleculeGraph):
            self.scene_kwargs = {"axisView": "HIDDEN"}
        else:
            self.scene_kwargs = {}

    def generate_callbacks(self, app, cache):
        """
        Register every Dash callback this component needs.

        Three client-side callbacks assemble the graph-generation options,
        the scene visibility toggles and the display options. The remaining
        server-side callbacks rebuild the bonded graph, the scene, the legend,
        the title, the colour-scheme choices, the custom cut-off table and
        serve image / structure downloads.

        :param app: Dash app on which the callbacks are registered
        :param cache: cache object whose ``memoize()`` decorator wraps the
            server-side callbacks that are safe to cache
        """

        # a lot of the verbosity in this callback is to support custom bonding
        # this is not the format CutOffDictNN expects (since that is not JSON
        # serializable), so we store as a list of tuples instead
        # TODO: make CutOffDictNN args JSON serializable
        app.clientside_callback(
            """
            function (bonding_strategy, custom_cutoffs_rows, unit_cell_choice) {

                const bonding_strategy_kwargs = {}
                if (bonding_strategy === 'CutOffDictNN') {
                    const cut_off_dict = []
                    custom_cutoffs_rows.forEach(function(row) {
                        cut_off_dict.push([row['A'], row['B'], parseFloat(row['A—B'])])
                    })
                    bonding_strategy_kwargs.cut_off_dict = cut_off_dict
                }

                return {
                    bonding_strategy: bonding_strategy,
                    bonding_strategy_kwargs: bonding_strategy_kwargs,
                    unit_cell_choice: unit_cell_choice
                }
            }
            """,
            Output(self.id("graph_generation_options"), "data"),
            [
                Input(self.id("bonding_algorithm"), "value"),
                Input(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Input(self.id("unit-cell-choice"), "value"),
            ],
        )

        # map every hide/show option to True (ticked) or False (unticked)
        app.clientside_callback(
            """
            function (values, options) {
                const visibility = {}
                options.forEach(function (opt) {
                    visibility[opt.value] = Boolean(values.includes(opt.value))
                })
                return visibility
            }
            """,
            Output(self.id("scene"), "toggleVisibility"),
            [Input(self.id("hide-show"), "value")],
            [State(self.id("hide-show"), "options")],
        )

        # merge the individual display controls into the display_options store
        app.clientside_callback(
            """
            function (colorScheme, radiusStrategy, drawOptions, displayOptions) {

                const newDisplayOptions = Object.assign({}, displayOptions);
                newDisplayOptions.color_scheme = colorScheme
                newDisplayOptions.radius_strategy = radiusStrategy
                newDisplayOptions.draw_image_atoms = drawOptions.includes('draw_image_atoms')
                newDisplayOptions.bonded_sites_outside_unit_cell =  drawOptions.includes('bonded_sites_outside_unit_cell')
                newDisplayOptions.hide_incomplete_bonds = drawOptions.includes('hide_incomplete_bonds')

                return newDisplayOptions
            }
            """,
            Output(self.id("display_options"), "data"),
            [
                Input(self.id("color-scheme"), "value"),
                Input(self.id("radius_strategy"), "value"),
                Input(self.id("draw_options"), "value"),
            ],
            [State(self.id("display_options"), "data")],
        )

        @app.callback(
            Output(self.id("graph"), "data"),
            [
                Input(self.id("graph_generation_options"), "data"),
                Input(self.id(), "data"),
            ],
            [State(self.id("graph"), "data")],
        )
        @cache.memoize()
        def update_graph(graph_generation_options, struct_or_mol,
                         current_graph):

            if not struct_or_mol:
                raise PreventUpdate

            struct_or_mol = self.from_data(struct_or_mol)
            current_graph = self.from_data(current_graph)

            bonding_strategy_kwargs = graph_generation_options[
                "bonding_strategy_kwargs"]

            # TODO: add additional check here?
            unit_cell_choice = graph_generation_options["unit_cell_choice"]
            struct_or_mol = self._preprocess_structure(struct_or_mol,
                                                       unit_cell_choice)

            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=graph_generation_options["bonding_strategy"],
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )

            # compare through _get_struct_or_mol rather than ``.structure`` so
            # that molecule graphs (which only expose ``.molecule``) work too
            if (current_graph and self._get_struct_or_mol(graph)
                    == self._get_struct_or_mol(current_graph)
                    and graph == current_graph):
                raise PreventUpdate

            return graph

        @app.callback(
            Output(self.id("scene"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
                Input(self.id("scene_additions"), "data"),
            ],
        )
        @cache.memoize()
        def update_scene(graph, display_options, scene_additions):
            if not graph or not display_options:
                raise PreventUpdate
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            scene, legend = self.get_scene_and_legend(
                graph, **display_options, scene_additions=scene_additions)
            return scene

        @app.callback(
            Output(self.id("legend_data"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
                Input(self.id("scene_additions"), "data"),
            ],
        )
        @cache.memoize()
        def update_legend_and_colors(graph, display_options, scene_additions):
            if not graph or not display_options:
                raise PreventUpdate
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            scene, legend = self.get_scene_and_legend(
                graph, **display_options, scene_additions=scene_additions)
            return legend

        @app.callback(
            Output(self.id("color-scheme"), "options"),
            [Input(self.id("legend_data"), "data")],
        )
        def update_color_options(legend_data):

            # TODO: make client-side
            color_options = [
                {
                    "label": "Jmol",
                    "value": "Jmol"
                },
                {
                    "label": "VESTA",
                    "value": "VESTA"
                },
                {
                    "label": "Accessible",
                    "value": "accessible"
                },
            ]

            if not legend_data:
                return color_options

            # offer every colour scheme listed in the legend data as well
            color_options.extend({
                "label": f"Site property: {option}",
                "value": option
            } for option in legend_data["available_color_schemes"])

            return color_options

        # app.clientside_callback(
        #     """
        #     function (legendData) {
        #
        #         var colorOptions = [
        #             {label: "Jmol", value: "Jmol"},
        #             {label: "VESTA", value: "VESTA"},
        #             {label: "Accessible", value: "accessible"},
        #         ]
        #
        #
        #
        #         return colorOptions
        #     }
        #     """,
        #     Output(self.id("color-scheme"), "options"),
        #     [Input(self.id("legend_data"), "data")]
        # )

        @app.callback(
            Output(self.id("download-image"), "data"),
            Input(self.id("scene"), "imageDataTimestamp"),
            [
                State(self.id("scene"), "imageData"),
                State(self.id(), "data"),
            ],
        )
        def download_image(image_data_timestamp, image_data, data):
            # nothing to send until the scene has actually produced an image
            if not image_data_timestamp or not image_data:
                raise PreventUpdate

            struct_or_mol = self.from_data(data)
            if isinstance(struct_or_mol, StructureGraph):
                formula = struct_or_mol.structure.composition.reduced_formula
            elif isinstance(struct_or_mol, MoleculeGraph):
                formula = struct_or_mol.molecule.composition.reduced_formula
            else:
                formula = struct_or_mol.composition.reduced_formula
            if hasattr(struct_or_mol, "get_space_group_info"):
                spgrp = struct_or_mol.get_space_group_info()[0]
            else:
                spgrp = ""
            request_filename = f"{formula}-{spgrp}-crystal-toolkit.png"

            return {
                # strip the data-URL prefix, leaving only the base64 payload
                "content": image_data[len("data:image/png;base64,"):],
                "filename": request_filename,
                "base64": True,
                "type": "image/png",
            }

        @app.callback(
            Output(self.id("download-structure"), "data"),
            Input(self.id("scene"), "fileTimestamp"),
            [
                State(self.id("scene"), "fileType"),
                State(self.id(), "data"),
            ],
        )
        def download_structure(file_timestamp, download_option, data):
            if not file_timestamp or not download_option:
                raise PreventUpdate

            structure = self.from_data(data)
            # graphs wrap the object we want to export
            if isinstance(structure, StructureGraph):
                structure = structure.structure
            elif isinstance(structure, MoleculeGraph):
                structure = structure.molecule

            file_prefix = structure.composition.reduced_formula

            if "VASP" not in download_option:

                options = self.download_options["Structure"][download_option]
                extension = options["fmt"]

                try:
                    contents = structure.to(**options)
                except Exception as exc:
                    # don't fail silently, tell user what went wrong: the
                    # downloaded file then contains the error message
                    contents = str(exc)

                encoded = b64encode(contents.encode("utf-8")).decode("ascii")

                download_data = {
                    "content": encoded,
                    "base64": True,
                    "type": "text/plain",
                    "filename": f"{file_prefix}.{extension}",
                }

            else:

                if "Relax" in download_option:
                    vis = MPRelaxSet(structure)
                    expected_filename = "MPRelaxSet.zip"
                else:
                    raise ValueError(
                        "No other VASP input sets currently supported.")

                # the zip only exists inside the temporary directory, so it
                # has to be read before the ``with`` block exits
                with TemporaryDirectory() as tmpdir:
                    vis.write_input(tmpdir, potcar_spec=True, zip_output=True)
                    path = Path(tmpdir) / expected_filename
                    zip_payload = b64encode(path.read_bytes()).decode("ascii")

                download_data = {
                    "content": zip_payload,
                    "base64": True,
                    "type": "application/zip",
                    "filename": f"{file_prefix} {expected_filename}",
                }

            return download_data

        @app.callback(
            Output(self.id("title_container"), "children"),
            [Input(self.id("legend_data"), "data")],
        )
        @cache.memoize()
        def update_title(legend):

            if not legend:
                raise PreventUpdate

            legend = self.from_data(legend)

            return self._make_title(legend)

        @app.callback(
            Output(self.id("legend_container"), "children"),
            [Input(self.id("legend_data"), "data")],
        )
        @cache.memoize()
        def update_legend(legend):

            if not legend:
                raise PreventUpdate

            legend = self.from_data(legend)

            return self._make_legend(legend)

        @app.callback(
            [
                Output(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Output(self.id("bonding_algorithm_custom_cutoffs_container"),
                       "style"),
            ],
            [Input(self.id("bonding_algorithm"), "value")],
            [
                State(self.id("graph"), "data"),
                State(self.id("bonding_algorithm_custom_cutoffs_container"),
                      "style"),
            ],
        )
        @cache.memoize()
        def update_custom_bond_options(bonding_algorithm, graph,
                                       current_style):

            if not graph:
                raise PreventUpdate

            if bonding_algorithm == "CutOffDictNN":
                style = {}
            else:
                style = {"display": "none"}
                if style == current_style:
                    # no need to update rows if we're not showing them
                    raise PreventUpdate

            graph = self.from_data(graph)
            rows = self._make_bonding_algorithm_custom_cuffoff_data(graph)

            return rows, style

    def _make_legend(self, legend):
        """
        Build the legend: one coloured, rounded button per legend entry.

        :param legend: legend data; the code reads its "composition" and
            "colors" entries ("colors" maps a colour hex code to a name)
        :return: a html.Div with id ``self.id("legend")``; it has no children
            when ``legend`` is falsy
        """

        if not legend:
            return html.Div(id=self.id("legend"))

        def get_font_color(hex_code):
            # ensures contrasting font color for background color
            # (expects a "#rrggbb" string: the leading character is dropped
            # and the rest is read as three two-digit hex channels)
            c = tuple(int(hex_code[1:][i:i + 2], 16) for i in (0, 2, 4))
            if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5:
                font_color = "#000000"
            else:
                font_color = "#ffffff"
            return font_color

        try:
            formula = Composition.from_dict(
                legend["composition"]).reduced_formula
        except Exception:
            # deliberately best-effort, but no longer a bare ``except`` so
            # KeyboardInterrupt / SystemExit still propagate
            # TODO: fix legend for Dummy Specie compositions
            formula = "Unknown"

        # order entries by where their name first occurs in the formula
        # (names that do not occur give -1 and therefore sort first)
        legend_colors = OrderedDict(
            sorted(legend["colors"].items(),
                   key=lambda item: formula.find(item[1])))

        legend_elements = [
            html.Span(
                html.Span(name,
                          className="icon",
                          style={"color": get_font_color(color)}),
                className="button is-static is-rounded",
                style={"backgroundColor": color},
            ) for color, name in legend_colors.items()
        ]

        return html.Div(legend_elements,
                        id=self.id("legend"),
                        style={"display": "flex"},
                        className="buttons")

    def _make_title(self, legend):
        """
        Build the title heading from the composition stored in the legend.

        :param legend: legend data; only its "composition" entry is read
        :return: a H2 with id ``self.id("title")``. It shows
            ``self.default_title`` when no composition is available, otherwise
            the reduced formula with the numeric parts as subscripts
        """

        if not legend or (not legend.get("composition", None)):
            return H2(self.default_title, id=self.id("title"))

        composition = legend["composition"]
        if isinstance(composition, (dict, Composition)):

            try:
                if isinstance(composition, dict):
                    composition = Composition.from_dict(composition)

                # strip DummySpecie if present (TODO: should be method in pymatgen)
                composition = Composition({
                    el: amt
                    for el, amt in composition.items()
                    if not isinstance(el, DummySpecie)
                })
                composition = composition.get_reduced_composition_and_factor(
                )[0]
                formula = composition.reduced_formula
                # split the formula into alternating text / digit runs
                formula_parts = re.findall(r"[^\d_]+|\d+", formula)
                formula_components = [
                    html.Sub(part.strip())
                    if part.isnumeric() else html.Span(part.strip())
                    for part in formula_parts
                ]
            except Exception:
                # best effort: fall back to listing the raw species keys
                formula_components = list(map(str, composition.keys()))

        else:
            # previously ``formula_components`` was left unbound on this path,
            # which raised UnboundLocalError; show the value as plain text
            formula_components = [str(composition)]

        return H2(formula_components,
                  id=self.id("title"),
                  style={"display": "inline-block"})

    @staticmethod
    def _make_bonding_algorithm_custom_cuffoff_data(graph):
        """
        Build the initial rows of the custom bond cut-off table.

        :param graph: anything accepted by ``_get_struct_or_mol``, or a falsy
            value when nothing is available yet
        :return: list of row dicts with keys "A", "B" and "A—B"; one row per
            unordered pair of species (a species paired with itself included),
            each with a cut-off of 0. A single all-None row is returned when
            ``graph`` is falsy.
        """
        if not graph:
            return [{"A": None, "B": None, "A—B": None}]
        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)
        # can't use type_of_specie because it doesn't work with disordered structures
        species = {
            str(specie)
            for occupancy in struct_or_mol.species_and_occu
            for specie in occupancy.keys()
        }
        # sort so the rows come out in the same order in every process:
        # iterating a set of strings directly depends on hash randomisation
        return [{
            "A": species_a,
            "B": species_b,
            "A—B": 0
        } for species_a, species_b in combinations_with_replacement(
            sorted(species), 2)]

    @property
    def _sub_layouts(self):
        """
        Assemble the individual layouts that make up this component.

        :return: dict with the keys "struct" (the scene plus the two download
            targets), "options" (the settings panel), "title" and "legend".
            "options" is None when ``self.show_settings`` is falsy and
            "legend" is None when ``self.show_legend`` is falsy.
        """

        def as_options(*pairs):
            # expand (label, value) pairs into the option dicts dcc expects
            return [{"label": text, "value": value} for text, value in pairs]

        def field_label(text):
            return html.Label(text, className="mpc-label")

        title_layout = html.Div(
            self._make_title(self._initial_data["legend_data"]),
            id=self.id("title_container"),
        )

        # persistence settings shared by every control below
        persist = {
            "persistence": SETTINGS.PERSISTENCE,
            "persistence_type": SETTINGS.PERSISTENCE_TYPE,
        }

        # label shown to the user -> name of the bonding strategy
        bonding_algorithm = dcc.Dropdown(
            options=as_options(
                ("CrystalNN", "CrystalNN"),
                ("Custom Bonds", "CutOffDictNN"),
                ("Jmol Bonding", "JmolNN"),
                ("Minimum Distance (10% tolerance)", "MinimumDistanceNN"),
                ("O'Keeffe's Algorithm", "MinimumOKeeffeNN"),
                ("Hoppe's ECoN Algorithm", "EconNN"),
                ("Brunner's Reciprocal Algorithm", "BrunnerNN_reciprocal"),
            ),
            value=self.initial_data["graph_generation_options"]
            ["bonding_strategy"],
            clearable=False,
            id=self.id("bonding_algorithm"),
            **persist,
        )

        # editable cut-off table, hidden until its container style is changed
        cutoff_table = dt.DataTable(
            columns=[
                {"name": "A", "id": "A"},
                {"name": "B", "id": "B"},
                {"name": "A—B /Å", "id": "A—B"},
            ],
            editable=True,
            data=self._make_bonding_algorithm_custom_cuffoff_data(
                self.initial_data.get("default")),
            id=self.id("bonding_algorithm_custom_cutoffs"),
        )
        bonding_algorithm_custom_cutoffs = html.Div(
            [html.Br(), cutoff_table, html.Br()],
            id=self.id("bonding_algorithm_custom_cutoffs_container"),
            style={"display": "none"},
        )

        options_layout = None
        if self.show_settings:
            display_options = self.initial_data["display_options"]

            unit_cell_dropdown = dcc.Dropdown(
                options=as_options(
                    ("Input cell", "input"),
                    ("Primitive cell", "primitive"),
                    ("Conventional cell", "conventional"),
                    ("Reduced cell (Niggli)", "reduced_niggli"),
                    ("Reduced cell (LLL)", "reduced_lll"),
                ),
                value="input",
                clearable=False,
                id=self.id("unit-cell-choice"),
                **persist,
            )

            color_scheme_dropdown = dcc.Dropdown(
                options=as_options(
                    ("VESTA", "VESTA"),
                    ("Jmol", "Jmol"),
                    ("Accessible", "accessible"),
                ),
                value=display_options["color_scheme"],
                clearable=False,
                id=self.id("color-scheme"),
                **persist,
            )

            radius_dropdown = dcc.Dropdown(
                options=as_options(
                    ("Ionic", "specified_or_average_ionic"),
                    ("Covalent", "covalent"),
                    ("Van der Waals", "van_der_waals"),
                    (f"Uniform ({Legend.uniform_radius}Å)", "uniform"),
                ),
                value=display_options["radius_strategy"],
                clearable=False,
                id=self.id("radius_strategy"),
                **persist,
            )

            draw_flags = (
                "draw_image_atoms",
                "bonded_sites_outside_unit_cell",
                "hide_incomplete_bonds",
            )
            draw_checklist = dcc.Checklist(
                options=as_options(
                    ("Draw repeats of atoms on periodic boundaries",
                     "draw_image_atoms"),
                    ("Draw atoms outside unit cell bonded to "
                     "atoms within unit cell",
                     "bonded_sites_outside_unit_cell"),
                    ("Hide bonds where destination atoms are not shown",
                     "hide_incomplete_bonds"),
                ),
                # tick the flags that are enabled in the initial display options
                value=[flag for flag in draw_flags if display_options[flag]],
                labelStyle={"display": "block"},
                inputClassName="mpc-radio",
                id=self.id("draw_options"),
                **persist,
            )

            visibility_checklist = dcc.Checklist(
                options=as_options(
                    ("Atoms", "atoms"),
                    ("Bonds", "bonds"),
                    ("Unit cell", "unit_cell"),
                    ("Polyhedra", "polyhedra"),
                    ("Axes", "axes"),
                ),
                # "axes" is the only entry left unticked initially
                value=["atoms", "bonds", "unit_cell", "polyhedra"],
                labelStyle={"display": "block"},
                inputClassName="mpc-radio",
                id=self.id("hide-show"),
                **persist,
            )

            options_layout = Field([
                #  TODO: hide if molecule
                field_label("Change unit cell:"),
                html.Div(unit_cell_dropdown, className="mpc-control"),
                html.Div([
                    field_label("Change bonding algorithm: "),
                    bonding_algorithm,
                    bonding_algorithm_custom_cutoffs,
                ]),
                field_label("Change color scheme:"),
                html.Div(color_scheme_dropdown, className="mpc-control"),
                field_label("Change atomic radii:"),
                html.Div(radius_dropdown, className="mpc-control"),
                field_label("Draw options:"),
                html.Div([draw_checklist]),
                field_label("Hide/show:"),
                html.Div([visibility_checklist], className="mpc-control"),
            ])

        legend_layout = None
        if self.show_legend:
            legend_layout = html.Div(
                self._make_legend(self._initial_data["legend_data"]),
                id=self.id("legend_container"),
            )

        scene_component = CrystalToolkitScene(
            [options_layout, legend_layout],
            id=self.id("scene"),
            className=self.className,
            data=self.initial_data["scene"],
            settings=self.initial_scene_settings,
            sceneSize="100%",
            fileOptions=list(self.download_options["Structure"].keys()),
            showControls=self.show_controls,
            showExpandButton=self.show_expand_button,
            showImageButton=self.show_image_button,
            showExportButton=self.show_export_button,
            showPositionButton=self.show_position_button,
            **self.scene_kwargs,
        )
        struct_layout = html.Div([
            scene_component,
            dcc.Download(id=self.id("download-image")),
            dcc.Download(id=self.id("download-structure")),
        ])

        return {
            "struct": struct_layout,
            "options": options_layout,
            "title": title_layout,
            "legend": legend_layout,
        }

    def layout(self, size: str = "500px") -> html.Div:
        """
        Wrap the "struct" sub-layout in a square container.

        :param size: a CSS string used for both the width and the height
        :return: A html.Div containing the 3D structure or molecule
        """
        scene_layout = self._sub_layouts["struct"]
        square = {"width": size, "height": size}
        return html.Div(scene_layout, style=square)

    @staticmethod
    def _preprocess_structure(
        struct_or_mol: Union[Structure, StructureGraph, Molecule,
                             MoleculeGraph],
        unit_cell_choice: Literal["input", "primitive", "conventional",
                                  "reduced_niggli", "reduced_lll"] = "input",
    ):
        """
        Convert a Structure to the requested unit cell setting.

        :param struct_or_mol: object to convert; anything that is not a
            Structure is returned untouched
        :param unit_cell_choice: which cell to use; "input" or an unrecognised
            value leaves the structure as it is
        :return: the converted Structure, or the input object itself
        """
        # only a Structure can be converted, and "input" means "leave alone"
        if not isinstance(struct_or_mol, Structure):
            return struct_or_mol
        if unit_cell_choice == "input":
            return struct_or_mol

        if unit_cell_choice == "primitive":
            return struct_or_mol.get_primitive_structure()
        if unit_cell_choice == "conventional":
            analyzer = SpacegroupAnalyzer(struct_or_mol)
            return analyzer.get_conventional_standard_structure()
        if unit_cell_choice == "reduced_niggli":
            return struct_or_mol.get_reduced_structure(
                reduction_algo="niggli")
        if unit_cell_choice == "reduced_lll":
            return struct_or_mol.get_reduced_structure(reduction_algo="LLL")

        # any other choice falls through unchanged
        return struct_or_mol

    @staticmethod
    def _preprocess_input_to_graph(
        input: Union[Structure, StructureGraph, Molecule, MoleculeGraph],
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: Optional[Dict] = None,
    ) -> Union[StructureGraph, MoleculeGraph]:
        """
        Turn a structure or molecule into a graph that carries bonds.

        :param input: object to convert; a StructureGraph or MoleculeGraph is
            returned as it is, without recomputing bonds
        :param bonding_strategy: key of
            ``StructureMoleculeComponent.available_bonding_strategies``;
            replaced by "CutOffDictNN" for disordered structures
        :param bonding_strategy_kwargs: keyword arguments for the bonding
            strategy. For "CutOffDictNN" the "cut_off_dict" entry may be a
            list of ``[A, B, cutoff]`` entries (the JSON-friendly form) or a
            dict already keyed by ``(A, B)``. The dict passed in is never
            modified.
        :return: a graph with bonds from the chosen strategy, or a graph
            without bonds if computing them failed
        :raises ValueError: if ``bonding_strategy`` is not an available
            strategy
        """

        if isinstance(input, Structure):

            # ensure fractional co-ordinates are normalized to be in [0,1)
            # (this is actually not guaranteed by Structure)
            try:
                structure_dict = input.as_dict(verbosity=0)
            except TypeError:
                # TODO: remove this, necessary for Slab(?), some structure subclasses don't have verbosity
                structure_dict = input.as_dict()
            for site in structure_dict["sites"]:
                site["abc"] = np.mod(site["abc"], 1)
            input = Structure.from_dict(structure_dict)

            if not input.is_ordered:
                # calculating bonds in disordered structures is currently very flaky
                bonding_strategy = "CutOffDictNN"

        # we assume most uses of this class will give a structure as an input argument,
        # meaning we have to calculate the graph for bonding information, however if
        # the graph is already known and supplied, we will use that
        if isinstance(input, (StructureGraph, MoleculeGraph)):
            return input

        strategies = StructureMoleculeComponent.available_bonding_strategies
        if bonding_strategy not in strategies:
            raise ValueError(
                "Bonding strategy not supported. Please supply a name "
                "of a NearNeighbor subclass, choose from: {}".format(
                    ", ".join(strategies.keys())))

        # work on a copy: rewriting "cut_off_dict" in the caller's dict made a
        # second call with the same kwargs fail while re-reading the entries
        strategy_kwargs = dict(bonding_strategy_kwargs or {})
        if (bonding_strategy == "CutOffDictNN"
                and "cut_off_dict" in strategy_kwargs):
            cut_off_dict = strategy_kwargs["cut_off_dict"]
            if not isinstance(cut_off_dict, dict):
                # TODO: remove this hack by making args properly JSON serializable
                strategy_kwargs["cut_off_dict"] = {
                    (entry[0], entry[1]): entry[2]
                    for entry in cut_off_dict
                }

        strategy = strategies[bonding_strategy](**strategy_kwargs)

        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if isinstance(input, Structure):
                    graph = StructureGraph.with_local_env_strategy(
                        input, strategy)
                else:
                    graph = MoleculeGraph.with_local_env_strategy(
                        input, strategy, reorder=False)
        except Exception:
            # for some reason computing bonds failed, so let's not have any bonds(!)
            # (``except Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are not swallowed)
            if isinstance(input, Structure):
                graph = StructureGraph.with_empty_graph(input)
            else:
                graph = MoleculeGraph.with_empty_graph(input)

        return graph

    @staticmethod
    def _get_struct_or_mol(
        graph: Union[StructureGraph, MoleculeGraph, Structure, Molecule]
    ) -> Union[Structure, Molecule]:
        """
        Return the structure or molecule behind ``graph``.

        :param graph: a StructureGraph, MoleculeGraph, Structure or Molecule
        :return: ``graph.structure`` / ``graph.molecule`` for the two graph
            types, or ``graph`` itself if it already is a Structure or
            Molecule
        :raises ValueError: for any other type
        """
        # graph wrappers first: unwrap to the object they hold
        if isinstance(graph, StructureGraph):
            return graph.structure
        if isinstance(graph, MoleculeGraph):
            return graph.molecule

        # bare structures and molecules pass straight through
        if isinstance(graph, (Structure, Molecule)):
            return graph

        raise ValueError

    @staticmethod
    def get_scene_and_legend(
        graph: Optional[Union[StructureGraph, MoleculeGraph]],
        color_scheme=DEFAULTS["color_scheme"],
        color_scale=None,
        radius_strategy=DEFAULTS["radius_strategy"],
        draw_image_atoms=DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell=DEFAULTS[
            "bonded_sites_outside_unit_cell"],
        hide_incomplete_bonds=DEFAULTS["hide_incomplete_bonds"],
        explicitly_calculate_polyhedra_hull=False,
        scene_additions=None,
        show_compass=DEFAULTS["show_compass"],
        group_by_site_property=None,
    ) -> Tuple[Scene, Dict[str, str]]:
        """
        Build the scene and the legend data for a graph.

        :param graph: graph to draw, or None for an empty scene
        :param color_scheme: passed to Legend as ``color_scheme``
        :param color_scale: passed to Legend as ``cmap_range``
        :param radius_strategy: passed to Legend as ``radius_scheme``
        :param draw_image_atoms: forwarded to StructureGraph.get_scene
        :param bonded_sites_outside_unit_cell: forwarded to
            StructureGraph.get_scene
        :param hide_incomplete_bonds: forwarded to StructureGraph.get_scene
            as ``hide_incomplete_edges``
        :param explicitly_calculate_polyhedra_hull: forwarded to
            StructureGraph.get_scene
        :param scene_additions: if truthy, appended to the "contents" of the
            scene JSON
        :param show_compass: sets ``visible`` on the lattice axes
        :param group_by_site_property: forwarded to StructureGraph.get_scene
        :return: ``(scene, legend)``. When ``graph`` is None this is an empty
            Scene object and an empty dict; otherwise it is the output of
            ``scene.to_json()`` and ``legend.get_legend()``
        """
        scene_name = "StructureMoleculeComponentScene"
        scene = Scene(name=scene_name)

        if graph is None:
            # NOTE(review): this path returns the Scene object itself, the
            # other path returns its JSON form -- confirm callers expect that
            return scene, {}

        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)

        # TODO: add radius_scale
        legend = Legend(
            struct_or_mol,
            color_scheme=color_scheme,
            radius_scheme=radius_strategy,
            cmap_range=color_scale,
        )

        # any other graph type keeps the empty scene created above
        if isinstance(graph, StructureGraph):
            structure_scene_kwargs = dict(
                draw_image_atoms=draw_image_atoms,
                bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
                hide_incomplete_edges=hide_incomplete_bonds,
                explicitly_calculate_polyhedra_hull=
                explicitly_calculate_polyhedra_hull,
                group_by_site_property=group_by_site_property,
                legend=legend,
            )
            scene = graph.get_scene(**structure_scene_kwargs)
        elif isinstance(graph, MoleculeGraph):
            scene = graph.get_scene(legend=legend)

        scene.name = scene_name

        # objects with a lattice also get an axes object, shown or hidden
        # according to show_compass
        if hasattr(struct_or_mol, "lattice"):
            axes = struct_or_mol.lattice._axes_from_lattice()
            axes.visible = show_compass
            scene.contents.append(axes)

        scene_json = scene.to_json()

        if scene_additions:
            # TODO: this might be cleaner if we had a Scene.from_json() method
            scene_json["contents"].append(scene_additions)

        legend_data = legend.get_legend()
        return scene_json, legend_data

    def title_layout(self):
        """
        :return: The "title" entry of the sub-layouts, i.e. a layout including
            the composition of the structure/molecule as a title.
        """
        sub_layouts = self._sub_layouts
        return sub_layouts["title"]
class StructureComponent(MPComponent):
    """
    A component to display pymatgen Structure objects.
    """

    # Lookup table from class name to class for every direct subclass of
    # NearNeighbors that exists when this class body is executed; bonding
    # strategies are selected by these names.
    available_bonding_strategies = \
        {subcls.__name__: subcls for subcls in NearNeighbors.__subclasses__()}

    # Baseline settings for the 3D scene; ``__init__`` copies this dict and
    # layers any user-supplied ``scene_settings`` on top of the copy.
    default_scene_settings = {
        "extractAxis": True,
        # For visual diff testing, we change the renderer
        # to SVG since this WebGL support is more difficult
        # in headless browsers / CI.
        "renderer": "svg" if SETTINGS.TEST_MODE else "webgl",
    }

    # Title text used by ``_make_title`` when no composition is available.
    default_title = "Structure"

    def __init__(
        self,
        structure: Optional[Structure] = None,
        id: str = None,
        scene_additions: Optional[Scene] = None,
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: Optional[dict] = None,
        color_scale: Optional[str] = None,
        radius_strategy: str = DEFAULTS["radius_strategy"],
        unit_cell_choice: str = DEFAULTS["unit_cell_choice"],
        draw_image_atoms: bool = DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell: bool = DEFAULTS[
            "bonded_sites_outside_unit_cell"],
        hide_incomplete_bonds: bool = DEFAULTS["hide_incomplete_bonds"],
        show_compass: bool = DEFAULTS["show_compass"],
        scene_settings: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Create the stores and the initial scene for the component.

        :param structure: object to display; forwarded to the parent class as
            ``default_data``. When falsy, an empty scene is prepared.
        :param id: component id, forwarded to the parent class.
        :param scene_additions: extra scene contents; wrapped in a ``Scene``
            named "scene_additions" and stored as JSON.
        :param bonding_strategy: name of the bonding algorithm, kept in the
            "graph_generation_options" store and used to build the graph.
        :param bonding_strategy_kwargs: keyword arguments for the bonding
            algorithm, kept in the same store.
        :param color_scale: kept in the "display_options" store.
        :param radius_strategy: kept in the "display_options" store.
        :param unit_cell_choice: kept in the "graph_generation_options" store.
        :param draw_image_atoms: kept in the "display_options" store.
        :param bonded_sites_outside_unit_cell: kept in the "display_options"
            store.
        :param hide_incomplete_bonds: kept in the "display_options" store.
        :param show_compass: kept in the "display_options" store.
        :param scene_settings: overrides merged on top of a copy of
            ``default_scene_settings``.
        :param kwargs: forwarded unchanged to the parent class.
        """

        super().__init__(id=id, default_data=structure, **kwargs)

        # class-level defaults first, user overrides second
        merged_settings = self.default_scene_settings.copy()
        if scene_settings:
            merged_settings.update(scene_settings)
        self.initial_scene_settings = merged_settings

        self.create_store("scene_settings",
                          initial_data=self.initial_scene_settings)

        # unit cell choice and bonding algorithms need to come from a settings
        # object (in a dcc.Store) guaranteed to be present in layout, rather
        # than from the controls themselves -- since these are optional and
        # may not be present in the layout
        graph_generation_options = {
            "bonding_strategy": bonding_strategy,
            "bonding_strategy_kwargs": bonding_strategy_kwargs,
            "unit_cell_choice": unit_cell_choice,
        }
        self.create_store("graph_generation_options",
                          initial_data=graph_generation_options)

        display_options = {
            "color_scale": color_scale,
            "radius_strategy": radius_strategy,
            "draw_image_atoms": draw_image_atoms,
            "bonded_sites_outside_unit_cell": bonded_sites_outside_unit_cell,
            "hide_incomplete_bonds": hide_incomplete_bonds,
            "show_compass": show_compass,
        }
        self.create_store("display_options", initial_data=display_options)

        initial_scene_additions = None
        if scene_additions:
            wrapped = Scene(name="scene_additions", contents=scene_additions)
            initial_scene_additions = wrapped.to_json()
        self.create_store("scene_additions",
                          initial_data=initial_scene_additions)

        has_structure = bool(structure)

        # graph is cached explicitly, this isn't necessary but is an
        # optimization so that graph is only re-generated if bonding
        # algorithm changes; without a structure there is no graph and an
        # empty scene is displayed instead
        graph = None
        if has_structure:
            graph = self._preprocess_input_to_graph(
                structure,
                bonding_strategy=bonding_strategy,
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )

        scene, legend = self.get_scene_and_legend(
            graph,
            scene_additions=self.initial_data["scene_additions"],
            **self.initial_data["display_options"],
        )

        if has_structure and hasattr(structure, "lattice"):
            self._lattice = structure.lattice

        self.create_store("legend_data", initial_data=legend)
        self.create_store("graph", initial_data=graph)

        # this is used by a Simple3DScene component, not a dcc.Store
        self._initial_data["scene"] = scene

        # hide axes inset for molecules
        is_molecular = isinstance(structure, (Molecule, MoleculeGraph))
        self.scene_kwargs = {"axisView": "HIDDEN"} if is_molecular else {}

    def _make_legend(self, legend):
        """
        Build the colour legend as a row of static, colour-filled buttons.

        :param legend: legend data. Its ``"colors"`` mapping is read as
            ``{hex colour: label}`` and its ``"composition"`` entry is used to
            order the labels.
        :return: an empty ``html.Div`` with the legend id when ``legend`` is
            falsy, otherwise a grouped ``Field`` with one button per colour.
        """

        if not legend:
            return html.Div(id=self.id("legend"))

        def get_font_color(hex_code):
            # ensures contrasting font color for background color:
            # weighted RGB brightness decides between black and white text
            c = tuple(int(hex_code[1:][i:i + 2], 16) for i in (0, 2, 4))
            if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5:
                font_color = "#000000"
            else:
                font_color = "#ffffff"
            return font_color

        try:
            formula = Composition.from_dict(
                legend["composition"]).reduced_formula
        except Exception:
            # Deliberately best-effort, but no longer a bare ``except`` so
            # that KeyboardInterrupt/SystemExit are not swallowed.
            # TODO: fix legend for Dummy Specie compositions
            formula = "Unknown"

        # order labels by where they first appear in the reduced formula;
        # labels not found in the formula get -1 and therefore sort first
        legend_colors = OrderedDict(
            sorted(legend["colors"].items(),
                   key=lambda x: formula.find(x[1])))

        legend_elements = [
            Button(
                html.Span(name,
                          className="icon",
                          style={"color": get_font_color(color)}),
                kind="static",
                style={"backgroundColor": color},
            ) for color, name in legend_colors.items()
        ]

        return Field(
            [
                Control(el, style={"marginRight": "0.2rem"})
                for el in legend_elements
            ],
            id=self.id("legend"),
            grouped=True,
        )

    def _make_title(self, legend):
        """
        Build the title heading from the composition stored in the legend.

        :param legend: legend data; only its ``"composition"`` entry is read.
        :return: an ``H1`` with ``default_title`` when there is no legend, no
            composition, or the composition is not a dict; otherwise an
            ``H1`` holding the reduced formula with digits as subscripts. If
            the formula cannot be built, the heading lists the composition's
            keys as plain strings.
        """

        if not legend or (not legend.get("composition", None)):
            return H1(self.default_title, id=self.id("title"))

        composition = legend["composition"]
        if not isinstance(composition, dict):
            # Only dict compositions can be parsed below. Without this guard
            # ``formula_components`` was never assigned and the final return
            # raised UnboundLocalError.
            return H1(self.default_title, id=self.id("title"))

        try:
            composition = Composition.from_dict(composition)

            # strip DummySpecie if present (TODO: should be method in pymatgen)
            composition = Composition({
                el: amt
                for el, amt in composition.items()
                if not isinstance(el, DummySpecie)
            })
            composition = composition.get_reduced_composition_and_factor()[0]
            formula = composition.reduced_formula
            # split into alternating runs of non-digits and digits
            formula_parts = re.findall(r"[^\d_]+|\d+", formula)
            formula_components = [
                html.Sub(part.strip())
                if part.isnumeric() else html.Span(part.strip())
                for part in formula_parts
            ]
        except Exception:
            # best-effort fallback (no longer a bare ``except``): show the
            # keys of whatever ``composition`` currently is
            formula_components = list(map(str, composition.keys()))

        return H1(formula_components,
                  id=self.id("title"),
                  style={"display": "inline-block"})

    @staticmethod
    def _make_bonding_algorithm_custom_cuffoff_data(graph):
        """
        Build the rows of the custom cut-off table, one per species pair.

        :param graph: graph (or structure) accepted by ``_get_structure``;
            when falsy a single placeholder row of ``None`` values is
            returned.
        :return: list of dicts with keys "A", "B" and "A—B"; the cut-off
            value is initialised to 0 for every unordered pair of species,
            including a species paired with itself.
        """
        if not graph:
            return [{"A": None, "B": None, "A—B": None}]
        struct_or_mol = StructureComponent._get_structure(graph)
        # can't use type_of_specie because it doesn't work with disordered structures
        # Sorted so that row order (and which species lands in column "A")
        # does not depend on set iteration order, which varies between
        # interpreter runs for strings.
        species = sorted({
            str(sp)
            for occupancy in struct_or_mol.species_and_occu
            for sp in occupancy.keys()
        })
        return [{
            "A": first,
            "B": second,
            "A—B": 0
        } for first, second in combinations_with_replacement(species, 2)]

    @property
    def _sub_layouts(self):
        """
        Assemble the individual layouts this component is made of.

        :return: dict with the keys "struct" (the 3D scene), "title" and
            "legend", each mapped to an ``html.Div``.
        """

        scene_view = Simple3DScene(
            id=self.id("scene"),
            data=self.initial_data["scene"],
            settings=self.initial_scene_settings,
            sceneSize="100%",
            **self.scene_kwargs,
        )
        fill_parent = {
            "width": "100%",
            "height": "100%",
            "overflow": "hidden",
            "margin": "0 auto",
        }
        struct_layout = html.Div(scene_view, style=fill_parent)

        title = self._make_title(self._initial_data["legend_data"])
        title_layout = html.Div(title, id=self.id("title_container"))

        legend = self._make_legend(self._initial_data["legend_data"])
        legend_layout = html.Div(legend, id=self.id("legend_container"))

        return dict(struct=struct_layout,
                    title=title_layout,
                    legend=legend_layout)

    def layout(self, size: str = "500px") -> html.Div:
        """
        Wrap the 3D scene in a square container.

        :param size: a CSS string used for both the width and the height of
            the Div
        :return: A html.Div containing the 3D structure or molecule
        """
        square = {"width": size, "height": size}
        return html.Div(self._sub_layouts["struct"], style=square)

    @staticmethod
    def _preprocess_input_to_graph(
        input: Union[Structure, StructureGraph],
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: Optional[Dict] = None,
    ) -> Union[StructureGraph, MoleculeGraph]:
        """
        Normalise the input structure and compute its bonding graph.

        :param input: object providing ``as_dict()`` whose result can be read
            back with ``Structure.from_dict``.
        :param bonding_strategy: key of ``available_bonding_strategies``;
            forced to "CutOffDictNN" for disordered structures.
        :param bonding_strategy_kwargs: keyword arguments for the strategy
            class. For "CutOffDictNN" the "cut_off_dict" entry may be given
            as an iterable of ``(A, B, cut_off)`` triples (converted to a
            dict keyed by ``(A, B)``) or as a ready-made dict. The caller's
            dict is never modified.
        :return: the ``StructureGraph`` built with the chosen strategy, or a
            graph without bonds if building it fails.
        :raises ValueError: if ``bonding_strategy`` is not a known name.
        """

        # ensure fractional co-ordinates are normalized to be in [0,1)
        # (this is actually not guaranteed by Structure)
        try:
            input = input.as_dict(verbosity=0)
        except TypeError:
            # TODO: remove this, necessary for Slab(?), some structure subclasses don't have verbosity
            input = input.as_dict()
        for site in input["sites"]:
            site["abc"] = np.mod(site["abc"], 1)
        input = Structure.from_dict(input)

        if not input.is_ordered:
            # calculating bonds in disordered structures is currently very flaky
            bonding_strategy = "CutOffDictNN"

        strategies = StructureComponent.available_bonding_strategies
        if bonding_strategy not in strategies:
            raise ValueError(
                "Bonding strategy not supported. Please supply a name "
                "of a NearNeighbor subclass, choose from: {}".format(", ".join(
                    strategies.keys())))

        # Work on a shallow copy: the caller's dict is also kept in a store
        # and must stay in its original (JSON-friendly) form; rewriting it in
        # place also broke a second call with the same dict.
        strategy_kwargs = dict(bonding_strategy_kwargs or {})
        if bonding_strategy == "CutOffDictNN" and "cut_off_dict" in strategy_kwargs:
            cut_offs = strategy_kwargs["cut_off_dict"]
            if not isinstance(cut_offs, dict):
                # TODO: remove this hack by making args properly JSON serializable
                strategy_kwargs["cut_off_dict"] = {
                    (x[0], x[1]): x[2]
                    for x in cut_offs
                }

        strategy = strategies[bonding_strategy](**strategy_kwargs)
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                graph = StructureGraph.with_local_env_strategy(
                    input, strategy)
        except Exception:
            # for some reason computing bonds failed, so let's not have any bonds(!)
            # (best-effort on purpose, but KeyboardInterrupt/SystemExit now
            # propagate instead of being swallowed by a bare ``except``)
            graph = StructureGraph.with_empty_graph(input)

        return graph

    def generate_callbacks(self, app, cache):
        """
        Intentionally does nothing: no callbacks are registered here.

        :param app: unused
        :param cache: unused
        """
        return None

    @staticmethod
    def _get_structure(
            graph: Union[StructureGraph, Structure]
    ) -> Union[Structure, Molecule]:
        """
        Return the structure behind ``graph``.

        :param graph: a ``StructureGraph`` (its ``structure`` attribute is
            returned) or a ``Structure`` (returned unchanged).
        :raises ValueError: for any other type.
        """
        if isinstance(graph, StructureGraph):
            return graph.structure
        if isinstance(graph, Structure):
            return graph
        raise ValueError

    @staticmethod
    def get_scene_and_legend(
        graph: Optional[Union[StructureGraph, MoleculeGraph]],
        color_scale=None,
        radius_strategy=DEFAULTS["radius_strategy"],
        draw_image_atoms=DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell=DEFAULTS[
            "bonded_sites_outside_unit_cell"],
        hide_incomplete_bonds=DEFAULTS["hide_incomplete_bonds"],
        explicitly_calculate_polyhedra_hull=False,
        scene_additions=None,
        show_compass=DEFAULTS["show_compass"],
    ) -> Tuple[Scene, Dict[str, str]]:
        """
        Build the scene and the legend data for a graph.

        :param graph: graph to draw; ``None`` yields an empty scene and an
            empty legend dict.
        :param color_scale: passed to ``Legend`` as ``cmap_range``.
        :param radius_strategy: passed to ``Legend`` as ``radius_scheme``.
        :param draw_image_atoms: forwarded to ``graph.get_scene``.
        :param bonded_sites_outside_unit_cell: forwarded to
            ``graph.get_scene``.
        :param hide_incomplete_bonds: forwarded to ``graph.get_scene`` as
            ``hide_incomplete_edges``.
        :param explicitly_calculate_polyhedra_hull: forwarded to
            ``graph.get_scene``.
        :param scene_additions: if truthy, appended to the "contents" of the
            scene JSON.
        :param show_compass: sets the ``visible`` flag of the lattice axes.
        :return: ``(scene, legend)``.
            NOTE(review): with a graph the first element is the result of
            ``scene.to_json()``, while for ``graph=None`` it is a ``Scene``
            object and ``scene_additions`` is ignored -- confirm callers cope
            with both.
        """

        if graph is None:
            return Scene(name="AceStructureMoleculeComponentScene"), {}

        structure = StructureComponent._get_structure(graph)

        # TODO: add radius_scale
        legend = Legend(
            structure,
            color_scheme="VESTA",
            radius_scheme=radius_strategy,
            cmap_range=color_scale,
        )

        scene_options = dict(
            draw_image_atoms=draw_image_atoms,
            bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
            hide_incomplete_edges=hide_incomplete_bonds,
            explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
            legend=legend,
        )
        scene = graph.get_scene(**scene_options)
        scene.name = "StructureComponentScene"

        if hasattr(structure, "lattice"):
            compass = structure.lattice._axes_from_lattice()
            compass.visible = show_compass
            scene.contents.append(compass)

        scene_json = scene.to_json()
        if scene_additions:
            # TODO: need a Scene.from_json() to make this work
            # raise NotImplementedError
            scene_json["contents"].append(scene_additions)

        return scene_json, legend.get_legend()

    def title_layout(self):
        """
        Return the title sub-layout of this component.

        :return: The ``"title"`` entry of ``self._sub_layouts``, i.e. the
            layout showing the composition of the structure/molecule as a
            title.
        """
        sub_layouts = self._sub_layouts
        return sub_layouts["title"]

    def legend_layout(self):
        """
        Return the legend sub-layout of this component.

        :return: The ``"legend"`` entry of ``self._sub_layouts``, i.e. the
            layout holding the legend for the structure/molecule.
        """
        sub_layouts = self._sub_layouts
        return sub_layouts["legend"]
Exemple #12
0
class PymatgenVisualizationIntermediateFormat:
    """
    Class takes a Structure or StructureGraph and outputs a
    list of primitives (spheres, cylinders, etc.) for drawing purposes.
    """

    # Name -> class lookup for every direct NearNeighbors subclass that exists
    # when this class body is executed; the constructor checks
    # ``bonding_strategy`` against these names.
    available_bonding_strategies = {
        subclass.__name__: subclass
        for subclass in NearNeighbors.__subclasses__()
    }

    # Values of ``radius_strategy`` that ``_generate_radii`` accepts.
    available_radius_strategies = ('atomic', 'bvanalyzer_ionic',
                                   'average_ionic', 'covalent',
                                   'van_der_waals', 'atomic_calculated')

    def __init__(
        self,
        structure,
        bonding_strategy='MinimumDistanceNN',
        bonding_strategy_kwargs=None,
        color_scheme="VESTA",
        color_scale=None,
        radius_strategy="average_ionic",
        draw_image_atoms=True,
        repeat_of_atoms_on_boundaries=True,
        bonded_sites_outside_display_area=True,
        symmetrize=True,  # TODO
        display_repeats=((0, 2), (0, 2), (0, 2))):
        """
        Prepare a structure for export as generic JSON of geometric
        primitives that a 3D rendering tool such as Three.js or Blender can
        parse.

        The intent is that any crystallographic knowledge is handled by
        pymatgen (e.g. finding Cartesian co-ordinates of atoms, suggesting
        colors and radii for the atoms, detecting whether and where bonds
        should be present, detecting oxidation environments, calculating
        coordination polyhedra, etc.). The 3D rendering code therefore needs
        no knowledge of crystallography; it only has to parse the JSON and
        draw the corresponding geometric primitives.

        The resulting JSON should be self-explanatory and constructed such
        that an appropriate scene graph can be made without difficulty. The
        'type' keys can be either 'sphere', 'cylinder', 'geometry' or 'text',
        the 'position' key is always a vector with respect to global
        Cartesian co-ordinates.

        Bonds and co-ordination polyhedra reference the index of the
        appropriate atom, rather than duplicating position vectors.

        Args:
            structure: an instance of Structure or StructureGraph (the latter
                for when bonding information is already known). A Structure
                has its fractional co-ordinates wrapped into [0, 1) first.
            bonding_strategy: name of a NearNeighbors subclass used to
                calculate bonding; only used if ``structure`` is not already
                a StructureGraph. Must be a key of
                ``available_bonding_strategies``.
            bonding_strategy_kwargs: kwargs passed to the NearNeighbors
                subclass.
            color_scheme: "VESTA", "Jmol" or the name of a site property to
                color-code by.
            color_scale: only relevant when color-coding by site property;
                name of a matplotlib color scale.
            radius_strategy: one of ``available_radius_strategies``.
            draw_image_atoms: stored on the instance.
            repeat_of_atoms_on_boundaries: not read by this constructor.
            bonded_sites_outside_display_area: stored on the instance.
            symmetrize: not read by this constructor (marked TODO).
            display_repeats: three (min, max) pairs in fractional
                co-ordinates, stored as ``self.display_range``.

        Raises:
            ValueError: if a graph has to be computed and
                ``bonding_strategy`` is not a known name.
        """
        # TODO: update docstring

        # ensure fractional co-ordinates are normalized to be in [0,1)
        # (this is actually not guaranteed by Structure)
        if isinstance(structure, Structure):
            structure_dict = structure.as_dict(verbosity=0)
            for site in structure_dict['sites']:
                site['abc'] = np.mod(site['abc'], 1)
            structure = Structure.from_dict(structure_dict)

        # we assume most uses of this class will give a structure as an input argument,
        # meaning we have to calculate the graph for bonding information, however if
        # the graph is already known and supplied, we will use that
        if isinstance(structure, StructureGraph):
            self.structure_graph = structure
        else:
            strategies = self.available_bonding_strategies
            if bonding_strategy not in strategies:
                raise ValueError(
                    "Bonding strategy not supported. Please supply a name "
                    "of a NearNeighbor subclass, choose from: {}".format(
                        ", ".join(strategies.keys())))

            strategy_cls = strategies[bonding_strategy]
            strategy = strategy_cls(**(bonding_strategy_kwargs or {}))
            self.structure_graph = StructureGraph.with_local_env_strategy(
                structure, strategy)

            # record the coordination number of every site as a site property
            coordination_numbers = []
            for site_idx in range(len(structure)):
                coordination_numbers.append(
                    self.structure_graph.get_coordination_of_site(site_idx))
            self.structure_graph.structure.add_site_property(
                'coordination_no', coordination_numbers)

        self.structure = self.structure_graph.structure
        self.lattice = self.structure_graph.structure.lattice

        self.color_scheme = color_scheme  # TODO: add coord as option
        self.color_scale = color_scale
        self.radius_strategy = radius_strategy
        self.draw_image_atoms = draw_image_atoms
        self.bonded_sites_outside_display_area = bonded_sites_outside_display_area
        self.display_range = display_repeats

        # categorize site properties so we know which can be used for color schemes etc.
        self.site_prop_names = self._analyze_site_props()

        # adds 'display_color' site prop
        self.color_legend = self._generate_colors()
        # adds 'display_radius' site prop
        self._generate_radii()

    @property
    def json(self):
        """
        Assemble the complete drawing description.

        Returns:
            dict with the keys 'atoms', 'bonds', 'polyhedra', 'unit_cell',
            'color_legend' and 'site_props'.
        """

        atom_map = self._generate_atoms()
        bond_map = self._generate_bonds(atom_map)
        polyhedra = self._generate_polyhedra(atom_map, bond_map)
        unit_cell = self._generate_unit_cell()

        return dict(
            atoms=list(atom_map.values()),
            bonds=list(bond_map.values()),
            polyhedra=polyhedra,
            unit_cell=unit_cell,
            color_legend=self.color_legend,
            site_props=self.site_prop_names,
        )

    @property
    def graph_json(self):
        """
        Describe the bonding graph as plain node and edge records.

        Returns:
            dict with 'nodes' (id, label and hex color per graph node) and
            'edges' (from, to, arrows, length and, for edges whose image is
            not (0, 0, 0), a label).
        """

        graph = self.structure_graph.graph
        sites = self.structure_graph.structure

        nodes = []
        for node in graph.nodes():
            # first entry of 'display_color' is used as the node color
            red, green, blue = sites[node].properties['display_color'][0]
            nodes.append({
                'id': node,
                'label': node,
                'color': "#{:02x}{:02x}{:02x}".format(red, green, blue)
            })

        edges = []
        for start, end, data in graph.edges(data=True):
            to_jimage = data['to_jimage']

            # TODO: check these edge weights
            distance = self.structure.get_distance(start, end, to_jimage)

            edge = {'from': start, 'to': end, 'arrows': ''}
            edge['length'] = 50 * distance
            if to_jimage != (0, 0, 0):
                # edge leaves the (0, 0, 0) image: draw it directed and
                # label it with the image it points to
                edge.update({'arrows': 'to', 'label': str(to_jimage)})

            edges.append(edge)

        return {'nodes': nodes, 'edges': edges}

    def _analyze_site_props(self):
        """
        Sort the structure's site property names by the kind of value.

        Only the first value of each property is inspected (implicitly
        assumes all site props for a given key are the same type).

        Returns:
            dict mapping 'scalar', 'vector', 'matrix' and/or 'categorical'
            to lists of property names; kinds with no property are absent,
            and properties matching none of the kinds are left out.
        """

        grouped = defaultdict(list)
        site_properties = self.structure_graph.structure.site_properties

        for name, values in site_properties.items():
            first = values[0]

            if isinstance(first, (float, int)):
                kind = 'scalar'
            elif isinstance(first, list) and len(first) == 3:
                # three rows of three -> matrix, otherwise a 3-vector
                # (vectors can be displayed as arrows)
                inner = first[0]
                is_matrix = isinstance(inner, list) and len(inner) == 3
                kind = 'matrix' if is_matrix else 'vector'
            elif isinstance(first, str):
                kind = 'categorical'
            else:
                continue

            grouped[kind].append(name)

        return dict(grouped)

    def _generate_atoms(self):
        """
        Build one sphere primitive per site image inside the display range.

        Side effect: sets ``self.geometric_center`` to the Cartesian
        co-ordinates of the midpoint of ``self.display_range``; all atom
        positions are given relative to it.

        Returns:
            OrderedDict keyed by ``(site_idx, image)``. Each value is a dict
            with 'type', 'idx' (insertion position), 'position',
            'bond_color', 'fragments' (one entry per species on the site),
            'vectors' and 'matrices'.
        """

        lower_corner = np.array([min(r) for r in self.display_range])
        upper_corner = np.array([max(r) for r in self.display_range])

        # to translate atoms so that geometric center at (0, 0, 0)
        # in global co-ordinate system: the center of the displayed box is
        # the midpoint (min + max) / 2 of each range. (max - min) / 2 is only
        # the half-extent and coincides with the midpoint solely for ranges
        # starting at 0.
        self.geometric_center = self.lattice.get_cartesian_coords(
            tuple(0.5 * (lower_corner + upper_corner)))

        ranges = [
            range(int(np.sign(r[0]) * np.ceil(np.abs(r[0]))),
                  1 + int(np.sign(r[1]) * np.ceil(np.abs(r[1]))))
            for r in self.display_range
        ]
        possible_images = list(itertools.product(*ranges))

        # for every site, the periodic images that fall inside the box
        # (boundaries included)
        site_images_to_draw = defaultdict(list)
        for idx, site in enumerate(self.structure):
            for image in possible_images:
                frac_coords = np.add(image, site.frac_coords)
                if np.all(np.less_equal(lower_corner, frac_coords)) \
                        and np.all(np.less_equal(frac_coords, upper_corner)):
                    site_images_to_draw[idx].append(image)

        # NOTE(review): self.bonded_sites_outside_display_area currently has
        # no effect here. An earlier pass collected extra images of bonded
        # sites into a local that was never merged into the draw list, and
        # could raise RuntimeError by inserting keys into the defaultdict
        # while iterating over it, so it has been removed. TODO: implement
        # drawing of bonded sites outside the display area.

        atoms = OrderedDict()
        for site_idx, images in site_images_to_draw.items():

            site = self.structure[site_idx]

            # for disordered structures
            occu_start = 0.0
            fragments = []

            for comp_idx, (sp,
                           occu) in enumerate(site.species_and_occu.items()):

                # in disordered structures, we fractionally color-code spheres,
                # drawing a sphere segment from phi_end to phi_start
                # (think a sphere pie chart)
                phi_frac_end = occu_start + occu
                phi_frac_start = occu_start
                occu_start = phi_frac_end

                radius = site.properties['display_radius'][comp_idx]
                color = site.properties['display_color'][comp_idx]

                name = "{}".format(sp)
                if occu != 1.0:
                    # occupancy is a fraction, so scale it to a percentage
                    name += " ({:g}% occupancy)".format(100 * occu)

                fragments.append({
                    'radius': radius,
                    'color': color,
                    'name': name,
                    'phi_start': phi_frac_start * np.pi * 2,
                    'phi_end': phi_frac_end * np.pi * 2
                })

            # bonds take the site color, or a fixed grey on disordered sites
            bond_color = fragments[0]['color'] if site.is_ordered else [
                55, 55, 55
            ]

            # TODO: do some appropriate scaling here
            if 'vector' in self.site_prop_names:
                vectors = {
                    name: site.properties[name]
                    for name in self.site_prop_names['vector']
                }
            else:
                vectors = None

            if 'matrix' in self.site_prop_names:
                matrices = {
                    name: site.properties[name]
                    for name in self.site_prop_names['matrix']
                }
            else:
                matrices = None

            for image in images:

                position_cart = list(
                    np.subtract(
                        self.lattice.get_cartesian_coords(
                            np.add(site.frac_coords, image)),
                        self.geometric_center))

                atoms[(site_idx, image)] = {
                    'type': 'sphere',
                    'idx': len(atoms),
                    'position': position_cart,
                    'bond_color': bond_color,
                    'fragments': fragments,
                    'vectors': vectors,
                    'matrices': matrices
                }

        return atoms

    def _generate_bonds(self, atoms):
        """
        Build the bond records between the atoms that are being drawn.

        Args:
            atoms: mapping keyed by ``(site_idx, image)`` as returned by
                ``_generate_atoms``; only its keys and their order are used.

        Returns:
            OrderedDict keyed by a 2-tuple of atom keys, each mapped to
            ``{'from_atom_index': ..., 'to_atom_index': ...}`` where the
            indices are positions in ``atoms``. Bonds with an end point that
            is not drawn are left out.
        """

        # position of every drawn atom, for O(1) lookups instead of a
        # list.index() scan per bond end
        atom_positions = {key: pos for pos, key in enumerate(atoms.keys())}

        # frozensets so that a bond found from both of its ends is kept once
        bonds_set = set()
        for site_idx, image in atom_positions:
            for u, v, d in self.structure_graph.graph.edges(nbunch=site_idx,
                                                            data=True):

                to_image = tuple(np.add(d['to_jimage'], image).astype(int))

                bonds_set.add(frozenset({(u, image), (v, to_image)}))

        bonds = OrderedDict()
        for bond in bonds_set:

            bond = tuple(bond)
            if len(bond) != 2:
                # both ends are the same atom (self-bond): nothing to draw
                continue

            from_atom_idx = atom_positions.get(bond[0])
            to_atom_idx = atom_positions.get(bond[1])
            if from_atom_idx is None or to_atom_idx is None:
                continue  # one of the atoms in the bond isn't being drawn

            bonds[bond] = {
                'from_atom_index': from_atom_idx,
                'to_atom_index': to_atom_idx
            }

        return bonds

    def _generate_radii(self):
        """
        Add a 'display_radius' site property according to the radius strategy.

        Does nothing if the structure already has a 'display_radius'
        property. Otherwise one radius per species on each site is stored;
        species without a known radius get 1.0 and a warning. For
        'bvanalyzer_ionic', oxidation states are assigned first; if that
        fails, ``self.radius_strategy`` is switched to 'average_ionic'.

        Raises:
            ValueError: if ``self.radius_strategy`` is not one of
                ``available_radius_strategies``.
        """

        structure = self.structure_graph.structure

        # don't calculate radius if one is explicitly supplied
        if 'display_radius' in structure.site_properties:
            return

        # validated once up front rather than for every species
        if self.radius_strategy not in self.available_radius_strategies:
            raise ValueError(
                "Unknown radius strategy {}, choose from: {}".format(
                    self.radius_strategy,
                    self.available_radius_strategies))

        # NOTE: strategies are compared with ``==``; ``is`` tests object
        # identity and only matched by accident of string interning.
        if self.radius_strategy == 'bvanalyzer_ionic':

            trans = AutoOxiStateDecorationTransformation()
            try:
                structure = trans.apply_transformation(
                    self.structure_graph.structure)
            except Exception:
                # if we can't assign valences use average ionic
                self.radius_strategy = 'average_ionic'

        strategy = self.radius_strategy

        radii = []
        for site in structure:

            site_radii = []

            for sp in site.species_and_occu.keys():

                radius = None

                if strategy == 'atomic':
                    radius = sp.atomic_radius
                elif strategy == 'bvanalyzer_ionic' and isinstance(
                        sp, Specie):
                    radius = sp.ionic_radius
                elif strategy == 'average_ionic':
                    radius = sp.average_ionic_radius
                elif strategy == 'covalent':
                    # look up by element symbol, whether or not ``sp``
                    # carries an oxidation state
                    el = str(getattr(sp, 'element', sp))
                    radius = CovalentRadius.radius[el]
                elif strategy == 'van_der_waals':
                    radius = sp.van_der_waals_radius
                elif strategy == 'atomic_calculated':
                    radius = sp.atomic_radius_calculated

                if not radius:
                    warnings.warn('Radius unknown for {} and strategy {}, '
                                  'setting to 1.0.'.format(
                                      sp, strategy))
                    radius = 1.0

                site_radii.append(radius)

            radii.append(site_radii)

        self.structure_graph.structure.add_site_property(
            'display_radius', radii)

    def _generate_colors(self):

        structure = self.structure_graph.structure
        legend = {}

        # don't calculate color if one is explicitly supplied
        if 'display_color' in structure.site_properties:
            return legend  # don't know what the color legend (meaning) is, so return empty legend

        if self.color_scheme not in ('VESTA', 'Jmol'):

            if not structure.is_ordered:
                raise ValueError(
                    'Can only use VESTA or Jmol color schemes '
                    'for disordered structures, color schemes based '
                    'on site properties are ill-defined.')

            if self.color_scheme in self.site_prop_names.get('scalar', []):

                props = np.array(structure.site_properties[self.color_scheme])

                if min(props) < 0 and max(props) > 0:
                    # by default, use blue-grey-red color scheme,
                    # so that zero is ~ grey, and positive/negative
                    # are red/blue
                    color_scale = self.color_scale or 'coolwarm'
                    # try to keep color scheme symmetric around 0
                    color_max = max([abs(min(props)), max(props)])
                    color_min = -color_max
                else:
                    # but if all values are positive, use a
                    # perceptually-uniform color scale by default
                    # like viridis
                    color_scale = self.color_scale or 'viridis'
                    color_max = max(props)
                    color_min = min(props)

                cmap = get_cmap(color_scale)
                # normalize in [0, 1] range, as expected by cmap
                props = (props - min(props)) / (max(props) - min(props))

                def _get_color(x):
                    return [int(c * 255) for c in cmap(x)[0:3]]

                colors = [[_get_color(x)] for x in props]

                # construct legend
                c = "#{:02x}{:02x}{:02x}".format(*_get_color(color_min))
                legend[c] = "{}".format(color_min)
                if color_max != color_min:

                    c = "#{:02x}{:02x}{:02x}".format(*_get_color(color_max))
                    legend[c] = "{}".format(color_max)

                    color_mid = (color_max - color_min) / 2
                    if color_max % 1 == 0 and color_min % 1 == 0 and color_max - color_min > 1:
                        color_mid = int(color_mid)

                    c = "#{:02x}{:02x}{:02x}".format(*_get_color(color_mid))
                    legend[c] = "{}".format(color_mid)

            elif self.color_scheme in self.site_prop_names.get(
                    'categorical', []):
                raise NotImplementedError
                # iter() a palettable  palettable.colorbrewer.qualitative cmap.colors, check len, Set1_9 ?

            else:
                raise ValueError(
                    'Unsupported color scheme. Should be "VESTA", "Jmol" or '
                    'a scalar or categorical site property.')
        else:

            colors = []
            for site in structure:
                elements = [
                    sp.as_dict()['element']
                    for sp, _ in site.species_and_occu.items()
                ]
                colors.append([
                    EL_COLORS[self.color_scheme][element]
                    for element in elements
                ])

                # construct legend
                for element in elements:
                    color = "#{:02x}{:02x}{:02x}".format(
                        *EL_COLORS[self.color_scheme][element])
                    legend[color] = element

        self.structure_graph.structure.add_site_property(
            'display_color', colors)

        return legend

    def _generate_unit_cell(self):

        o = -self.geometric_center
        a, b, c = self.lattice.matrix[0], self.lattice.matrix[
            1], self.lattice.matrix[2]

        line_pairs = [
            o, o + a, o, o + b, o, o + c, o + a, o + a + b, o + a, o + a + c,
            o + b, o + b + a, o + b, o + b + c, o + c, o + c + a, o + c,
            o + c + b, o + a + b, o + a + b + c, o + a + c, o + a + b + c,
            o + b + c, o + a + b + c
        ]

        line_pairs = [line.tolist() for line in line_pairs]

        unit_cell = {'type': 'lines', 'lines': line_pairs}

        return unit_cell

    def _generate_polyhedra(self, atoms, bonds):

        # TODO: this function is a bit confusing
        # mostly due to number of similarly-named data structures ... rethink?

        potential_polyhedra_by_site = {}
        for idx, site in enumerate(self.structure):
            connected_sites = self.structure_graph.get_connected_sites(idx)
            neighbors_sp = [cn[0].species_string for cn in connected_sites]
            neighbors_idx = [cn.index for cn in connected_sites]
            # could enforce len(set(neighbors_sp)) == 1 here if we want to only
            # draw polyhedra when neighboring atoms are all the same
            if len(neighbors_sp) > 2:
                # store num expected vertices, we don't want to draw incomplete polyhedra
                potential_polyhedra_by_site[idx] = len(neighbors_sp)

        polyhedra = defaultdict(list)
        for ((from_site_idx, from_image), (to_site_idx,
                                           to_image)), d in bonds.items():
            if from_site_idx in potential_polyhedra_by_site:
                polyhedra[(from_site_idx, from_image)].append(
                    (to_site_idx, to_image))
            if to_site_idx in potential_polyhedra_by_site:
                polyhedra[(to_site_idx, to_image)].append(
                    (from_site_idx, from_image))

        # discard polyhedra with incorrect coordination (e.g. half the polyhedra's atoms are
        # not in the draw range so would be cut off)
        polyhedra = {
            k: v
            for k, v in polyhedra.items()
            if len(v) == potential_polyhedra_by_site[k[0]]
        }

        polyhedra_by_species = defaultdict(list)
        for k in polyhedra.keys():
            polyhedra_by_species[self.structure[k[0]].species_string].append(k)

        polyhedra_json_by_species = {}
        polyhedra_by_species_vertices = {}
        polyhedra_by_species_centres = {}
        for sp, polyhedra_centres in polyhedra_by_species.items():
            polyhedra_json = []
            polyhedra_vertices = []
            for polyhedron_centre in polyhedra_centres:

                # book-keeping to prevent intersecting polyhedra
                polyhedron_vertices = polyhedra[polyhedron_centre]
                polyhedra_vertices += polyhedron_vertices

                polyhedron_points_cart = [
                    atoms[vert]['position'] for vert in polyhedron_vertices
                ]
                polyhedron_points_idx = [
                    atoms[vert]['idx'] for vert in polyhedron_vertices
                ]
                polyhedron_center_idx = atoms[polyhedron_centre]['idx']

                # Delaunay can fail in some edge cases
                try:

                    hull = Delaunay(
                        polyhedron_points_cart).convex_hull.tolist()

                    # TODO: storing duplicate info here ... ?
                    polyhedra_json.append({
                        'type': 'convex',
                        'points_idx': polyhedron_points_idx,
                        'points': polyhedron_points_cart,
                        'hull': hull,
                        'center': polyhedron_center_idx
                    })

                except Exception as e:
                    print(e)

            polyhedron_centres = set(polyhedra_centres)
            polyhedron_vertices = set(polyhedra_vertices)

            if (not polyhedron_vertices.intersection(polyhedra_centres)) \
                    and len(polyhedra_json) > 0 :
                name = "{}-centered".format(sp)
                polyhedra_json_by_species[name] = polyhedra_json
                polyhedra_by_species_centres[name] = polyhedron_centres
                polyhedra_by_species_vertices[name] = polyhedron_vertices

        if polyhedra_json_by_species:

            # get compatible sets of polyhedra
            compatible_subsets = {
                (k, ): len(v)
                for k, v in polyhedra_json_by_species.items()
            }
            for r in range(2, len(polyhedra_json_by_species) + 1):
                for subset in itertools.combinations(
                        polyhedra_json_by_species.keys(), r):
                    compatible = True
                    all_centres = set.union(
                        *[polyhedra_by_species_vertices[sp] for sp in subset])
                    all_verts = set.union(
                        *[polyhedra_by_species_centres[sp] for sp in subset])
                    if not all_verts.intersection(all_centres):
                        compatible_subsets[tuple(subset)] = len(all_centres)

            # sort by longest subset, secondary sort by radius
            compatible_subsets = sorted(compatible_subsets.items(),
                                        key=lambda s: -s[1])
            default_polyhedra = list(compatible_subsets[0][0])

        else:

            default_polyhedra = []

        return {
            'polyhedra_by_type': polyhedra_json_by_species,
            'polyhedra_types': list(polyhedra_json_by_species.keys()),
            'default_polyhedra_types': default_polyhedra
        }
Exemple #13
0
class MPVisualizer:
    """
    Class takes a Structure or StructureGraph and outputs a
    list of primitives (spheres, cylinders, etc.) for drawing purposes.

    The primitives are exposed through the ``json`` property.
    """

    # maps the class name of every direct NearNeighbors subclass known at
    # class-creation time to the class itself; used to resolve the
    # ``bonding_strategy`` name given to __init__
    allowed_bonding_strategies = {
        subclass.__name__: subclass
        for subclass in NearNeighbors.__subclasses__()
    }

    def __init__(self,
                 structure,
                 bonding_strategy='MinimumOKeeffeNN',
                 bonding_strategy_kwargs=None,
                 color_scheme="VESTA",
                 color_scale=None,
                 radius_strategy="ionic",
                 coordination_polyhedra=False,
                 draw_image_atoms=True,
                 repeat_of_atoms_on_boundaries=True,
                 bonded_atoms_outside_unit_cell=True,
                 scale=None):
        """
        Generates a generic JSON of geometric primitives that can be parsed
        by a 3D rendering tool such as Three.js or Blender.

        The intent is that any crystallographic knowledge is handled by
        pymatgen (e.g. finding Cartesian co-ordinates of atoms, suggesting
        colors and radii for the atoms, detecting whether and where bonds
        should be present, detecting oxidation environments, calculating
        coordination polyhedra, etc.) Therefore, the 3D rendering code does
        not need any knowledge of crystallography: it only has to parse the
        JSON and draw the corresponding geometric primitives.

        Positions are vectors in a global Cartesian co-ordinate system.
        Bonds and co-ordination polyhedra reference the index of the
        appropriate atom, rather than duplicating position vectors.

        Args:
            structure: an instance of Structure or StructureGraph (the
                latter for when bonding information is already known)
            bonding_strategy: name of a NearNeighbors subclass used to
                calculate bonding; only used when a Structure (not a
                StructureGraph) is supplied. Must be a key of
                ``allowed_bonding_strategies``.
            bonding_strategy_kwargs: kwargs passed to the NearNeighbors
                subclass when it is instantiated
            color_scheme: "VESTA", "Jmol" or the name of a site property
                to color-code by that site property
            color_scale: only relevant if color-coding by site property;
                name of a matplotlib color scale
            radius_strategy: "ionic" to use ionic radii; any other value
                selects van der Waals radii
            coordination_polyhedra: not implemented yet, any truthy value
                raises NotImplementedError
            draw_image_atoms: stored on the instance
            repeat_of_atoms_on_boundaries: if True, atoms lying on the
                periodic boundaries are repeated on the opposite boundary
            bonded_atoms_outside_unit_cell: if True, also draw bonds (and
                the atoms they lead to) that leave the unit cell
            scale: if given, the structure is multiplied by this value to
                draw periodic repeats

        Raises:
            ValueError: if ``bonding_strategy`` is not a known strategy name
            NotImplementedError: if ``coordination_polyhedra`` is truthy
        """

        # draw periodic repeats if requested
        if scale:
            structure = structure * scale

        if isinstance(structure, StructureGraph):
            # bonding is already known, use the supplied graph as-is
            self.structure_graph = structure
            self.lattice = self.structure_graph.structure.lattice
        else:
            # most callers pass a plain structure, so the bonding graph
            # has to be calculated with the requested strategy
            known_strategies = self.allowed_bonding_strategies
            if bonding_strategy not in known_strategies:
                raise ValueError(
                    "Bonding strategy not supported. Please supply a name "
                    "of a NearNeighbor subclass, choose from: {}".format(
                        ", ".join(known_strategies.keys())))

            strategy_kwargs = bonding_strategy_kwargs or {}
            strategy = known_strategies[bonding_strategy](**strategy_kwargs)
            self.structure_graph = StructureGraph.with_local_env_strategy(
                structure, strategy, decorate=False)
            self.lattice = structure.lattice

        if coordination_polyhedra:
            # TODO: coming soon
            raise NotImplementedError

        # display options
        self.color_scheme = color_scheme
        self.color_scale = color_scale
        self.radius_strategy = radius_strategy
        self.coordination_polyhedra = coordination_polyhedra
        self.draw_image_atoms = draw_image_atoms
        self.repeat_of_atoms_on_boundaries = repeat_of_atoms_on_boundaries
        self.bonded_atoms_outside_unit_cell = bonded_atoms_outside_unit_cell

    @property
    def json(self):
        """
        Assemble the full scene description.

        Colors are generated first (as site properties), the book-keeping
        attributes are reset, and then the atoms, bonds, polyhedra and unit
        cell are generated in that order and merged into a single dict.

        Returns:
            dict combining the outputs of the individual generators
        """

        # adds display colors as site properties
        self._generate_colors()

        # reset book-keeping used to keep track of atoms outside
        # periodic boundaries for bonding
        self._atom_indexes = {}
        self._site_images_to_draw = {}
        self._bonds = []
        self._polyhedra = []

        scene = {}
        scene.update(self._generate_atoms())
        scene.update(self._generate_bonds())
        scene.update(self._generate_polyhedra())
        scene.update(self._generate_unit_cell())
        return scene

    def _generate_atoms(self):
        """
        Decide which atoms (including periodic images) are drawn and
        describe each one as a sphere made of one fragment per species.

        Because the decision whether to draw an image atom depends on the
        bonds present, bonds are collected here too.

        Side effects: resets ``self._bonds`` and fills it with
        (from_site_idx, from_image, to_site_idx, to_image) tuples; sets
        ``self.site_images_to_draw`` and ``self._atoms_cart``; fills
        ``self._atom_indexes`` (which must already exist) with the position
        of each (site index, image) in the returned list.

        Returns:
            dict with a single key 'atoms', a list of sphere primitives
        """

        structure = self.structure_graph.structure

        # try to work out oxidation states; the transformation has to be
        # instantiated before it can be applied
        try:
            bv_structure = AutoOxiStateDecorationTransformation(
            ).apply_transformation(structure)
        except Exception:
            # if we can't assign valences just use original structure
            bv_structure = structure

        self._bonds = []

        # to translate atoms so that geometric center at (0, 0, 0)
        # in global co-ordinate system
        lattice = structure.lattice
        geometric_center = lattice.get_cartesian_coords((0.5, 0.5, 0.5))

        # used to easily find positions of atoms that lie on periodic boundaries
        adjacent_images = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0),
                           (0, 0, 1), (0, 0, -1), (1, 1, 0), (1, -1, 0),
                           (-1, 1, 0), (-1, -1, 0), (0, 1, 1), (0, 1, -1),
                           (0, -1, 1), (0, -1, -1), (1, 0, 1), (1, 0, -1),
                           (-1, 0, 1), (-1, 0, -1), (1, 1, 1), (1, 1, -1),
                           (1, -1, 1), (-1, 1, 1), (1, -1, -1), (-1, 1, -1),
                           (-1, -1, 1), (-1, -1, -1)]

        self.site_images_to_draw = {
            i: []
            for i in range(len(self.structure_graph))
        }
        for site_idx, site in enumerate(structure):

            images_to_draw = [(0, 0, 0)]
            if self.repeat_of_atoms_on_boundaries:
                possible_positions = [
                    site.frac_coords + adjacent_image
                    for adjacent_image in adjacent_images
                ]
                # essentially find which atoms lie on periodic boundaries, and
                # draw their repeats
                images_to_draw += [
                    image
                    for image, p in zip(adjacent_images, possible_positions)
                    if 0 <= p[0] <= 1 and 0 <= p[1] <= 1 and 0 <= p[2] <= 1
                ]

            # bonds from earlier sites may already have registered some of
            # these images for this site; only add the missing ones so no
            # atom is drawn twice
            for image in images_to_draw:
                if image not in self.site_images_to_draw[site_idx]:
                    self.site_images_to_draw[site_idx].append(image)

            # get bond information for site
            # why? to know if we want to draw an image atom,
            # we have to know what bonds are present
            for image in images_to_draw:

                for u, v, d in self.structure_graph.graph.edges(
                        nbunch=site_idx, data=True):

                    to_jimage = tuple(np.add(d['to_jimage'], image))
                    self._bonds.append((site_idx, image, v, to_jimage))

                    if to_jimage not in self.site_images_to_draw[v]:
                        self.site_images_to_draw[v].append(to_jimage)

                    if self.bonded_atoms_outside_unit_cell and \
                            to_jimage != (0, 0, 0) and image == (0, 0, 0):

                        # also draw the equivalent bond entering the cell
                        # from the opposite side
                        from_image_complement = tuple(
                            np.multiply(-1, to_jimage))
                        self._bonds.append(
                            (site_idx, from_image_complement, v, (0, 0, 0)))

                        if from_image_complement not in \
                                self.site_images_to_draw[site_idx]:
                            self.site_images_to_draw[site_idx].append(
                                from_image_complement)

        atoms = []
        self._atoms_cart = {}
        for site_idx, images in self.site_images_to_draw.items():

            site = structure[site_idx]
            bv_site = bv_structure[site_idx]

            for image in images:

                self._atom_indexes[(site_idx, image)] = len(atoms)

                position_cart = list(
                    np.subtract(
                        lattice.get_cartesian_coords(
                            np.add(site.frac_coords, image)),
                        geometric_center))

                self._atoms_cart[(site_idx, image)] = position_cart

                # for disordered structures
                occu_start = 0.0
                fragments = []

                for comp_idx, (sp, occu) in enumerate(
                        site.species_and_occu.items()):

                    # in disordered structures, we fractionally color-code spheres,
                    # drawing a sphere segment from phi_start to phi_end
                    # (think a sphere pie chart)
                    phi_frac_start = occu_start
                    phi_frac_end = occu_start + occu
                    occu_start = phi_frac_end

                    # a single (possibly oxidation-state decorated) species
                    # only exists for ordered sites
                    radius_sp = bv_site.specie if bv_site.is_ordered else sp

                    # get radius of sphere we want to draw
                    if self.radius_strategy == 'ionic':
                        radius = getattr(radius_sp, "ionic_radius", None)
                        if radius is None:
                            radius = radius_sp.average_ionic_radius
                    else:
                        radius = sp.van_der_waals_radius

                    color = site.properties['display_color'][comp_idx]

                    # generate a label (e.g. to use for mouse-over text)
                    if bv_structure.is_ordered and sp != bv_site.specie:
                        # if we can guess what oxidation states are present,
                        # add them as a label ... we only attempt this for
                        # ordered structures
                        name = "{} (detected as likely {})".format(
                            sp, bv_site.specie)
                    else:
                        name = "{}".format(sp)

                    if occu != 1.0:
                        # occu is a fraction, report it as a percentage
                        name += " ({:g}% occupancy)".format(occu * 100)

                    fragments.append({
                        'radius': radius,
                        'color': color,
                        'name': name,
                        'phi_start': phi_frac_start * np.pi * 2,
                        'phi_end': phi_frac_end * np.pi * 2
                    })

                atoms.append({
                    'type': 'sphere',
                    'position': position_cart,
                    'bond_color':
                    fragments[0]['color'] if site.is_ordered else [55, 55, 55],
                    'fragments': fragments,
                    # atoms outside the original cell are flagged as ghosts
                    'ghost': image != (0, 0, 0)
                })

        return {'atoms': atoms}

    def _generate_bonds(self):
        """
        Translate the bonds collected by _generate_atoms into pairs of
        atom indexes.

        Most of the bonding logic is done inside _generate_atoms, because
        deciding which atoms to draw requires constructing the bonds first.
        Here each (site index, image) end point in ``self._bonds`` is looked
        up in ``self._atom_indexes`` and duplicate bonds are dropped,
        keeping the first occurrence.

        Returns:
            dict with a single key 'bonds', a list of dicts with
            'from_atom_index' and 'to_atom_index' keys
        """

        bonds = []
        # set of (from, to) index pairs already emitted, so duplicates are
        # detected in constant time instead of scanning the output list
        seen_pairs = set()

        for from_site_idx, from_image, to_site_idx, to_image in self._bonds:

            pair = (self._atom_indexes[(from_site_idx, from_image)],
                    self._atom_indexes[(to_site_idx, to_image)])
            if pair in seen_pairs:
                continue

            seen_pairs.add(pair)
            bonds.append({
                'from_atom_index': pair[0],
                'to_atom_index': pair[1]
            })

        return {'bonds': bonds}

    def _generate_colors(self):
        """
        Store a 'display_color' site property on the structure.

        The property holds one list per site, with one [r, g, b] entry
        (ints in 0-255) per species on that site. For the 'VESTA' and 'Jmol'
        color schemes colors are looked up per element in EL_COLORS; any
        other color scheme must be the name of a site property, whose values
        are mapped through a matplotlib color scale.

        Raises:
            ValueError: if a site-property color scheme is requested for a
                disordered structure, or the color scheme is neither
                'VESTA', 'Jmol' nor a site property.
        """

        structure = self.structure_graph.structure

        if self.color_scheme not in ('VESTA', 'Jmol'):

            if not structure.is_ordered:
                raise ValueError(
                    'Can only use VESTA or Jmol color schemes '
                    'for disordered structures, color schemes based '
                    'on site properties are ill-defined.')

            if self.color_scheme not in structure.site_properties:
                raise ValueError(
                    'Unsupported color scheme. Should be "VESTA", "Jmol" or '
                    'a site property.')

            props = np.array(structure.site_properties[self.color_scheme])
            prop_min, prop_max = min(props), max(props)

            if prop_min < 0:
                # by default, use blue-grey-red color scheme,
                # so that zero is ~ grey, and positive/negative
                # are red/blue
                color_scale = self.color_scale or 'coolwarm'
                # keep color scheme symmetric around 0
                color_max = max([abs(prop_min), prop_max])
                color_min = -color_max
            else:
                # but if all values are positive, use a
                # perceptually-uniform color scale by default
                # like viridis
                color_scale = self.color_scale or 'viridis'
                color_max = prop_max
                color_min = prop_min

            cmap = get_cmap(color_scale)

            # normalize in [0, 1] range, as expected by cmap, using the
            # color limits chosen above
            color_range = color_max - color_min
            if color_range == 0:
                # every site has the same value: use mid-scale rather
                # than dividing by zero
                props_normed = np.full(len(props), 0.5)
            else:
                props_normed = (props - color_min) / color_range

            # a single cmap call per site; drop the alpha channel
            colors = [[[int(channel * 255) for channel in cmap(x)[0:3]]]
                      for x in props_normed]

        else:

            element_colors = EL_COLORS[self.color_scheme]
            colors = []
            for site in structure:
                elements = [
                    sp.as_dict()['element']
                    for sp, _ in site.species_and_occu.items()
                ]
                colors.append([element_colors[element] for element in elements])

        self.structure_graph.structure.add_site_property(
            'display_color', colors)

    def _generate_unit_cell(self):
        """
        Describe the unit cell as a convex shape.

        The eight cell corners are shifted by (-0.5, -0.5, -0.5) in
        fractional co-ordinates, converted to Cartesian co-ordinates with
        ``self.lattice``, and their convex hull is computed with Delaunay.

        Returns:
            dict with a single key 'unit_cell' holding the primitive
            ('type', 'points' and 'hull' keys)
        """

        half_cell = (0.5, 0.5, 0.5)
        corners_frac = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0],
                        [0, 1, 1], [1, 0, 1], [1, 1, 1]]

        corners_cart = []
        for corner in corners_frac:
            centred_corner = np.subtract(corner, half_cell)
            corners_cart.append(
                list(self.lattice.get_cartesian_coords(centred_corner)))

        hull_facets = Delaunay(corners_cart).convex_hull.tolist()

        return {
            'unit_cell': {
                'type': 'convex',
                'points': corners_cart,
                'hull': hull_facets,
            }
        }

    def _generate_polyhedra(self):
        """
        Build co-ordination polyhedra from the bonds collected by
        _generate_atoms.

        Sites are visited in index order; once a site becomes the centre of
        a polyhedron, its vertices that lie in the original cell are no
        longer considered as centres, so polyhedra do not share a
        centre/vertex relationship.

        Reads ``self._bonds``, ``self._atoms_cart`` and
        ``self._atom_indexes``.

        Returns:
            dict with a single key 'polyhedra', holding 'polyhedra_list'
            (the convex primitives) and 'polyhedra_types' (the distinct
            '<species>-centered' names)

        Raises:
            ValueError: if ``bonded_atoms_outside_unit_cell`` is False
        """

        if not self.bonded_atoms_outside_unit_cell:
            raise ValueError(
                "All bonds from unit cell must be drawn to be able to reliably "
                "draw co-ordination polyhedra, please set "
                "bonded_atoms_outside_unit_cell to True.")

        origin_image = (0, 0, 0)

        # this just creates a list of all bonded atoms from each site
        # (this isn't used for the bonding itself, otherwise we'd be
        # double counting bonds, but for polyhedra, in principle each
        # site has its own polyhedra defined)
        potential_polyhedra = {i: [] for i in range(len(self.structure_graph))}

        for from_site_idx, from_image, to_site_idx, to_image in self._bonds:
            if from_image == origin_image:
                potential_polyhedra[from_site_idx].append(
                    (to_site_idx, to_image))
            if to_image == origin_image:
                potential_polyhedra[to_site_idx].append(
                    (from_site_idx, from_image))

        # sort sites for which polyhedra we should prioritize: we don't want
        # polyhedra to intersect each other
        # TODO: placeholder
        sorted_sites_idxs = list(range(len(self.structure_graph)))

        # now we actually store the polyhedra we want!
        # a single polyhedron is simply a list of atom indexes that form
        # its vertices; a convex hull algorithm is necessary to actually
        # construct the faces
        polyhedra = []
        polyhedra_types = set()
        for site_idx in sorted_sites_idxs:

            if site_idx not in potential_polyhedra:
                # already used as a vertex of an earlier polyhedron
                continue

            polyhedron_points = potential_polyhedra[site_idx]

            # remove the polyhedron's vertices from the list of potential
            # polyhedra centers
            for vertex_idx, vertex_image in polyhedron_points:
                if vertex_image == origin_image \
                        and vertex_idx in potential_polyhedra:
                    del potential_polyhedra[vertex_idx]

            # and look up the actual positions in the array of atoms
            # we're drawing
            polyhedron_points_cart = [
                self._atoms_cart[vertex] for vertex in polyhedron_points
            ]
            polyhedron_points_idx = [
                self._atom_indexes[vertex] for vertex in polyhedron_points
            ]

            # calculating the hull can fail (e.g. too few vertices); such
            # sites are reported and skipped
            try:
                tri = Delaunay(polyhedron_points_cart)

                # a pretty name, helps with filtering too
                centre_site = self.structure_graph.structure[site_idx]
                species = ", ".join(
                    map(str, list(centre_site.species_and_occu.keys())))
                polyhedron_type = '{}-centered'.format(species)
                polyhedra_types.add(polyhedron_type)

                polyhedra.append({
                    'type': 'convex',
                    'points_idx': polyhedron_points_idx,
                    'points': polyhedron_points_cart,
                    # plain nested lists, as for the unit cell, so the
                    # result stays JSON-serialisable
                    'hull': tri.convex_hull.tolist(),
                    'name': polyhedron_type,
                    'center': site_idx
                })

            except Exception as e:
                print(e)

        return {
            'polyhedra': {
                'polyhedra_list': polyhedra,
                'polyhedra_types': list(polyhedra_types)
            }
        }
Exemple #14
0
class StructureMoleculeComponent(MPComponent):

    # Maps the class name of every direct NearNeighbors subclass to the class
    # itself; a bonding strategy is selected by looking its name up here.
    available_bonding_strategies = {
        subclass.__name__: subclass for subclass in NearNeighbors.__subclasses__()
    }

    # Names of the radius strategies this component recognises.
    available_radius_strategies = (
        "atomic",
        "specified_or_average_ionic",
        "covalent",
        "van_der_waals",
        "atomic_calculated",
        "uniform",
    )

    # Baseline scene settings; __init__ copies this dict and overlays any
    # user-supplied scene_settings on top of the copy.
    default_scene_settings = {"cylinderScale": 0.1}

    def __init__(
        self,
        struct_or_mol=None,
        id=None,
        origin_component=None,
        scene_additions=None,
        bonding_strategy="MinimumDistanceNN",
        bonding_strategy_kwargs=None,
        color_scheme="VESTA",
        color_scale=None,
        radius_strategy="uniform",
        radius_scale=1.0,
        draw_image_atoms=True,
        bonded_sites_outside_unit_cell=False,
        hide_incomplete_bonds=False,
        show_compass=False,
        scene_settings=None,
        **kwargs,
    ):
        """
        Set up the component's stores and its initial scene, legend and graph.

        :param struct_or_mol: object to display; when falsy an empty scene is built
        :param id: component id, forwarded to the parent constructor
        :param origin_component: forwarded to the parent constructor
        :param scene_additions: optional contents for the "scene_additions" Scene
        :param bonding_strategy: name of the bonding strategy used for the graph
        :param bonding_strategy_kwargs: keyword arguments for that strategy
        :param color_scheme: stored as a display option
        :param color_scale: stored as a display option
        :param radius_strategy: stored as a display option
        :param radius_scale: stored as a display option
        :param draw_image_atoms: stored as a display option
        :param bonded_sites_outside_unit_cell: stored as a display option
        :param hide_incomplete_bonds: stored as a display option
        :param show_compass: stored as a display option
        :param scene_settings: overrides applied on top of default_scene_settings
        :param kwargs: forwarded to the parent constructor
        """
        super().__init__(
            id=id, contents=struct_or_mol, origin_component=origin_component, **kwargs
        )

        self.default_title = "Crystal Toolkit"

        # class-level defaults first, caller overrides second
        merged_settings = dict(StructureMoleculeComponent.default_scene_settings)
        if scene_settings:
            merged_settings.update(scene_settings)
        self.initial_scene_settings = merged_settings
        self.create_store("scene_settings", initial_data=self.initial_scene_settings)

        self.initial_graph_generation_options = dict(
            bonding_strategy=bonding_strategy,
            bonding_strategy_kwargs=bonding_strategy_kwargs,
        )
        self.create_store(
            "graph_generation_options",
            initial_data=self.initial_graph_generation_options,
        )

        self.initial_display_options = dict(
            color_scheme=color_scheme,
            color_scale=color_scale,
            radius_strategy=radius_strategy,
            radius_scale=radius_scale,
            draw_image_atoms=draw_image_atoms,
            bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
            hide_incomplete_bonds=hide_incomplete_bonds,
            show_compass=show_compass,
        )
        self.create_store("display_options", initial_data=self.initial_display_options)

        # only hand `contents` to Scene when there is something to add
        additions_kwargs = {"name": "scene_additions"}
        if scene_additions:
            additions_kwargs["contents"] = scene_additions
        self.initial_scene_additions = Scene(**additions_kwargs)
        self.create_store(
            "scene_additions", initial_data=self.initial_scene_additions.to_json()
        )

        has_input = bool(struct_or_mol)

        # The graph is cached explicitly so that it only has to be regenerated
        # when the bonding algorithm changes. Without an input there is no
        # graph and get_scene_and_legend is asked for an empty scene instead.
        graph = None
        if has_input:
            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=bonding_strategy,
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )

        scene, legend = self.get_scene_and_legend(
            graph,
            name=self.id(),
            scene_additions=self.initial_scene_additions,
            **self.initial_display_options,
        )

        if has_input and hasattr(struct_or_mol, "lattice"):
            self._lattice = struct_or_mol.lattice

        self.initial_legend = legend
        self.create_store("legend_data", initial_data=self.initial_legend)

        self.initial_scene_data = scene.to_json()

        self.initial_graph = graph
        self.create_store("graph", initial_data=self.to_data(graph))

    def generate_callbacks(self, app, cache):
        """
        Register the callbacks that keep this component's stores, scene,
        legend and option controls in sync.

        :param app: app object whose ``callback`` decorator is used for registration
        :param cache: not used by any callback registered here
        """

        @app.callback(
            Output(self.id("graph"), "data"),
            [
                Input(self.id("graph_generation_options"), "data"),
                Input(self.id("unit-cell-choice"), "value"),
                Input(self.id(), "data"),
            ],
        )
        def update_graph(graph_generation_options, unit_cell_choice, struct_or_mol):
            # Rebuild the bonding graph when the input, the unit cell choice or
            # the graph generation options change.

            if not struct_or_mol:
                raise PreventUpdate

            struct_or_mol = self.from_data(struct_or_mol)
            graph_generation_options = self.from_data(graph_generation_options)

            # unit cell transformations only make sense for periodic structures
            if isinstance(struct_or_mol, Structure):
                if unit_cell_choice != "input":
                    if unit_cell_choice == "primitive":
                        struct_or_mol = struct_or_mol.get_primitive_structure()
                    elif unit_cell_choice == "conventional":
                        sga = SpacegroupAnalyzer(struct_or_mol)
                        struct_or_mol = sga.get_conventional_standard_structure()
                    elif unit_cell_choice == "reduced":
                        struct_or_mol = struct_or_mol.get_reduced_structure()

            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=graph_generation_options["bonding_strategy"],
                bonding_strategy_kwargs=graph_generation_options[
                    "bonding_strategy_kwargs"
                ],
            )

            self.logger.debug("Constructed graph")

            return self.to_data(graph)

        @app.callback(
            Output(self.id("scene"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
            ],
        )
        def update_scene(graph, display_options):
            # Re-render the scene; the legend computed alongside is discarded here.
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            scene, legend = self.get_scene_and_legend(graph, **display_options)
            return scene.to_json()

        @app.callback(
            Output(self.id("legend_data"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
            ],
        )
        def update_legend_data(graph, display_options):
            # TODO: more cleanly split legend from scene generation

            # The component may be created without a structure, in which case
            # there is no graph and _get_struct_or_mol would raise ValueError.
            if not graph:
                raise PreventUpdate

            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            if graph is None:
                raise PreventUpdate

            struct_or_mol = self._get_struct_or_mol(graph)
            site_prop_types = self._analyze_site_props(struct_or_mol)
            colors, legend = self._get_display_colors_and_legend_for_sites(
                struct_or_mol,
                site_prop_types,
                color_scheme=display_options.get("color_scheme", None),
                color_scale=display_options.get("color_scale", None),
            )
            return self.to_data(legend)

        @app.callback(
            Output(self.id("color-scheme"), "options"),
            [Input(self.id("graph"), "data")],
        )
        def update_color_options(graph):
            # Offer the built-in schemes plus one entry per scalar/categorical
            # site property found on the current structure or molecule.

            # same empty-graph guard as update_custom_bond_options
            if not graph:
                raise PreventUpdate

            graph = self.from_data(graph)
            if graph is None:
                raise PreventUpdate

            options = [
                {"label": "Jmol", "value": "Jmol"},
                {"label": "VESTA", "value": "VESTA"},
                {"label": "Colorblind-friendly", "value": "colorblind_friendly"},
            ]
            struct_or_mol = self._get_struct_or_mol(graph)
            site_props = self._analyze_site_props(struct_or_mol)
            for site_prop_type in ("scalar", "categorical"):
                if site_prop_type in site_props:
                    for prop in site_props[site_prop_type]:
                        options += [{"label": f"Site property: {prop}", "value": prop}]

            return options

        @app.callback(
            Output(self.id("display_options"), "data"),
            [
                Input(self.id("color-scheme"), "value"),
                Input(self.id("radius_strategy"), "value"),
                Input(self.id("draw_options"), "value"),
            ],
            [State(self.id("display_options"), "data")],
        )
        def update_display_options(
            color_scheme, radius_strategy, draw_options, display_options
        ):
            # Fold the current control values into the stored display options.
            display_options = self.from_data(display_options)
            display_options.update(
                {
                    "color_scheme": color_scheme,
                    "radius_strategy": radius_strategy,
                    "draw_image_atoms": "draw_image_atoms" in draw_options,
                    "bonded_sites_outside_unit_cell": "bonded_sites_outside_unit_cell"
                    in draw_options,
                    "hide_incomplete_bonds": "hide_incomplete_bonds" in draw_options,
                }
            )

            # NOTE(review): this also suppresses the update when the user
            # switches back to the initial options - confirm that is intended.
            if display_options == self.initial_display_options:
                raise PreventUpdate

            self.logger.debug("Display options updated")

            return self.to_data(display_options)

        @app.callback(
            Output(self.id("scene"), "downloadRequest"),
            [Input(self.id("screenshot_button"), "n_clicks")],
            [State(self.id("scene"), "downloadRequest"), State(self.id(), "data")],
        )
        def screenshot_callback(n_clicks, current_requests, struct_or_mol):
            # Build a download request whose filename is derived from the
            # formula and, when available, the space group symbol.

            # nothing to name the file after if no structure is loaded
            if n_clicks is None or not struct_or_mol:
                raise PreventUpdate

            struct_or_mol = self.from_data(struct_or_mol)
            # TODO: this will break if store is structure/molecule graph ...
            formula = struct_or_mol.composition.reduced_formula
            if hasattr(struct_or_mol, "get_space_group_info"):
                spgrp = struct_or_mol.get_space_group_info()[0]
            else:
                spgrp = ""
            request_filename = "{}-{}-crystal-toolkit.png".format(formula, spgrp)

            # the counter increases with every click so each request is distinct
            if not current_requests:
                n_requests = 1
            else:
                n_requests = current_requests["n_requests"] + 1
            return {
                "n_requests": n_requests,
                "filename": request_filename,
                "filetype": "png",
            }

        @app.callback(
            Output(self.id("scene"), "toggleVisibility"),
            [Input(self.id("hide-show"), "value")],
            [State(self.id("hide-show"), "options")],
        )
        def update_visibility(values, options):
            # Map every checklist option to whether it is currently ticked.
            visibility = {opt["value"]: (opt["value"] in values) for opt in options}
            return visibility

        @app.callback(
            [
                Output(self.id("legend_container"), "children"),
                Output(self.id("title_container"), "children"),
            ],
            [Input(self.id("legend_data"), "data")],
        )
        def update_legend_and_title(legend):
            # Named differently from update_legend_data so that the two nested
            # functions no longer shadow each other.

            legend = self.from_data(legend)

            if legend == self.initial_legend:
                raise PreventUpdate

            return self._make_legend(legend), self._make_title(legend)

        @app.callback(
            Output(self.id("graph_generation_options"), "data"),
            [
                Input(self.id("bonding_algorithm"), "value"),
                Input(self.id("bonding_algorithm_custom_cutoffs"), "data"),
            ],
        )
        def update_structure_viewer_data(bonding_algorithm, custom_cutoffs_rows):

            graph_generation_options = {
                "bonding_strategy": bonding_algorithm,
                "bonding_strategy_kwargs": None,
            }

            # NOTE(review): compared before any custom cut-offs are filled in,
            # so only the strategy name (with kwargs None) is matched here.
            if graph_generation_options == self.initial_graph_generation_options:
                raise PreventUpdate

            if bonding_algorithm == "CutOffDictNN":
                # this is not the format CutOffDictNN expects (since that is not JSON
                # serializable), so we store as a list of tuples instead
                # TODO: make CutOffDictNN args JSON serializable
                # the table may not have been populated yet, hence the `or []`
                custom_cutoffs = [
                    (row["A"], row["B"], float(row["A—B"]))
                    for row in (custom_cutoffs_rows or [])
                ]
                graph_generation_options["bonding_strategy_kwargs"] = {
                    "cut_off_dict": custom_cutoffs
                }
            return self.to_data(graph_generation_options)

        @app.callback(
            [
                Output(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Output(self.id("bonding_algorithm_custom_cutoffs_container"), "style"),
            ],
            [Input(self.id("bonding_algorithm"), "value")],
            [State(self.id("graph"), "data")],
        )
        def update_custom_bond_options(bonding_algorithm, graph):
            # Show the cut-off table only for CutOffDictNN and pre-fill it with
            # one zero-valued row per unordered pair of species.

            if not graph:
                raise PreventUpdate

            if bonding_algorithm == "CutOffDictNN":
                style = {}
            else:
                style = {"display": "none"}

            graph = self.from_data(graph)
            struct_or_mol = self._get_struct_or_mol(graph)
            # can't use type_of_specie because it doesn't work with disordered structures
            species = set(
                map(
                    str,
                    chain.from_iterable(
                        [list(c.keys()) for c in struct_or_mol.species_and_occu]
                    ),
                )
            )
            rows = [
                {"A": combination[0], "B": combination[1], "A—B": 0}
                for combination in combinations_with_replacement(species, 2)
            ]
            return rows, style

    def _make_legend(self, legend):
        """
        Build the legend: one static, colored button per legend entry.

        :param legend: dict with a "colors" mapping of color hex code to label
            and a "composition" dict; may be None
        :return: an empty Div with the legend id when there is nothing to show,
            otherwise a grouped Field of buttons
        """

        if legend is None or (not legend.get("colors", None)):
            return html.Div(id=self.id("legend"))

        def get_font_color(hex_code):
            # ensures contrasting font color for background color
            c = tuple(int(hex_code[1:][i : i + 2], 16) for i in (0, 2, 4))
            # weighted sum of the RGB channels: black text on light
            # backgrounds, white text on dark ones
            if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5:
                font_color = "#000000"
            else:
                font_color = "#ffffff"
            return font_color

        try:
            formula = Composition.from_dict(legend["composition"]).reduced_formula
        except Exception:
            # TODO: fix for Dummy Specie compositions
            # (Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate)
            formula = "Unknown"

        # Order entries by where their label first appears in the formula;
        # labels that are absent get -1 from str.find and therefore sort first.
        # str() keeps the lookup valid should a label not be a string.
        legend_colors = OrderedDict(
            sorted(legend["colors"].items(), key=lambda x: formula.find(str(x[1])))
        )

        legend_elements = [
            Button(
                html.Span(
                    name, className="icon", style={"color": get_font_color(color)}
                ),
                kind="static",
                style={"background-color": color},
            )
            for color, name in legend_colors.items()
        ]

        return Field(
            [Control(el, style={"margin-right": "0.2rem"}) for el in legend_elements],
            id=self.id("legend"),
            grouped=True,
        )

    def _make_title(self, legend):
        """
        Build the title heading from the legend's composition.

        :param legend: dict that may hold a "composition" entry; may be None
        :return: an H1 with the default title when no composition is available,
            otherwise an H1 with the reduced formula, digits rendered as subscripts
        """

        if not legend or (not legend.get("composition", None)):
            return H1(self.default_title, id=self.id("title"))

        composition = legend["composition"]
        if isinstance(composition, dict):

            # TODO: make Composition handle DummySpecie
            try:
                formula = Composition.from_dict(composition).reduced_formula
                # split the formula into alternating text and number runs
                formula_parts = re.findall(r"[^\d_]+|\d+", formula)
                formula_components = [
                    html.Sub(part) if part.isnumeric() else html.Span(part)
                    for part in formula_parts
                ]
            except Exception:
                # fall back to the plain keys of the composition dict
                formula_components = [str(key) for key in composition.keys()]
        else:
            # Previously formula_components was left unbound on this path,
            # which raised UnboundLocalError at the return below.
            formula_components = [str(composition)]

        return H1(
            formula_components, id=self.id("title"), style={"display": "inline-block"}
        )

    @property
    def all_layouts(self):
        """
        Build every layout fragment of the component.

        :return: dict with the keys "struct", "screenshot", "options",
            "title" and "legend", each mapped to a layout element
        """

        struct_layout = html.Div(
            Simple3DSceneComponent(
                id=self.id("scene"),
                data=self.initial_scene_data,
                settings=self.initial_scene_settings,
            ),
            style={
                "width": "100%",
                "height": "100%",
                "overflow": "hidden",
                "margin": "0 auto",
            },
        )

        screenshot_layout = html.Div(
            [
                Button(
                    [Icon(), html.Span(), "Download Image"],
                    kind="primary",
                    id=self.id("screenshot_button"),
                )
            ],
            # TODO: change to "bottom" when dropdown included
            style={"vertical-align": "top", "display": "inline-block"},
        )

        title_layout = html.Div(
            self._make_title(self.initial_legend), id=self.id("title_container")
        )

        legend_layout = html.Div(
            self._make_legend(self.initial_legend), id=self.id("legend_container")
        )

        # human-readable label -> name of the bonding strategy class
        nn_mapping = {
            "CrystalNN": "CrystalNN",
            "Custom Bonds": "CutOffDictNN",
            "Jmol Bonding": "JmolNN",
            "Minimum Distance (10% tolerance)": "MinimumDistanceNN",
            "O'Keeffe's Algorithm": "MinimumOKeeffeNN",
            "Hoppe's ECoN Algorithm": "EconNN",
            "Brunner's Reciprocal Algorithm": "BrunnerNN_reciprocal",
        }

        # Start from the strategy the component was constructed with, as the
        # other controls below do with their initial options. A hard-coded
        # "CrystalNN" here would not match the strategy used for the initial graph.
        bonding_algorithm = dcc.Dropdown(
            options=[{"label": k, "value": v} for k, v in nn_mapping.items()],
            value=self.initial_graph_generation_options["bonding_strategy"],
            id=self.id("bonding_algorithm"),
        )

        # hidden by default; its container style is switched by a callback
        bonding_algorithm_custom_cutoffs = html.Div(
            [
                html.Br(),
                dt.DataTable(
                    columns=[
                        {"name": "A", "id": "A"},
                        {"name": "B", "id": "B"},
                        {"name": "A—B /Å", "id": "A—B"},
                    ],
                    editable=True,
                    id=self.id("bonding_algorithm_custom_cutoffs"),
                ),
                html.Br(),
            ],
            id=self.id("bonding_algorithm_custom_cutoffs_container"),
            style={"display": "none"},
        )

        options_layout = Field(
            [
                #  TODO: hide if molecule
                html.Label("Change unit cell:", className="mpc-label"),
                html.Div(
                    dcc.RadioItems(
                        options=[
                            {"label": "Input cell", "value": "input"},
                            {"label": "Primitive cell", "value": "primitive"},
                            {"label": "Conventional cell", "value": "conventional"},
                            {"label": "Reduced cell", "value": "reduced"},
                        ],
                        value="input",
                        id=self.id("unit-cell-choice"),
                        labelStyle={"display": "block"},
                        inputClassName="mpc-radio",
                    ),
                    className="mpc-control",
                ),
                html.Div(
                    [
                        html.Label("Change bonding algorithm: ", className="mpc-label"),
                        bonding_algorithm,
                        bonding_algorithm_custom_cutoffs,
                    ]
                ),
                html.Label("Change color scheme:", className="mpc-label"),
                html.Div(
                    dcc.Dropdown(
                        options=[
                            {"label": "VESTA", "value": "VESTA"},
                            {"label": "Jmol", "value": "Jmol"},
                        ],
                        value=self.initial_display_options["color_scheme"],
                        clearable=False,
                        id=self.id("color-scheme"),
                    ),
                    className="mpc-control",
                ),
                html.Label("Change atomic radii:", className="mpc-label"),
                html.Div(
                    dcc.Dropdown(
                        options=[
                            {"label": "Ionic", "value": "specified_or_average_ionic"},
                            {"label": "Covalent", "value": "covalent"},
                            {"label": "Van der Waals", "value": "van_der_waals"},
                            {"label": "Uniform (0.5Å)", "value": "uniform"},
                        ],
                        value=self.initial_display_options["radius_strategy"],
                        clearable=False,
                        id=self.id("radius_strategy"),
                    ),
                    className="mpc-control",
                ),
                html.Label("Draw options:", className="mpc-label"),
                html.Div(
                    [
                        dcc.Checklist(
                            options=[
                                {
                                    "label": "Draw repeats of atoms on periodic boundaries",
                                    "value": "draw_image_atoms",
                                },
                                {
                                    "label": "Draw atoms outside unit cell bonded to "
                                    "atoms within unit cell",
                                    "value": "bonded_sites_outside_unit_cell",
                                },
                                {
                                    "label": "Hide bonds where destination atoms are not shown",
                                    "value": "hide_incomplete_bonds",
                                },
                            ],
                            # tick exactly the options enabled at construction
                            value=[
                                opt
                                for opt in (
                                    "draw_image_atoms",
                                    "bonded_sites_outside_unit_cell",
                                    "hide_incomplete_bonds",
                                )
                                if self.initial_display_options[opt]
                            ],
                            labelStyle={"display": "block"},
                            inputClassName="mpc-radio",
                            id=self.id("draw_options"),
                        )
                    ]
                ),
                html.Label("Hide/show:", className="mpc-label"),
                html.Div(
                    [
                        dcc.Checklist(
                            options=[
                                {"label": "Atoms", "value": "atoms"},
                                {"label": "Bonds", "value": "bonds"},
                                {"label": "Unit cell", "value": "unit_cell"},
                                {"label": "Polyhedra", "value": "polyhedra"},
                            ],
                            value=["atoms", "bonds", "unit_cell", "polyhedra"],
                            labelStyle={"display": "block"},
                            inputClassName="mpc-radio",
                            id=self.id("hide-show"),
                        )
                    ],
                    className="mpc-control",
                ),
            ]
        )

        return {
            "struct": struct_layout,
            "screenshot": screenshot_layout,
            "options": options_layout,
            "title": title_layout,
            "legend": legend_layout,
        }

    @property
    def standard_layout(self):
        """Return the "struct" layout wrapped in a Div sized to the full viewport."""
        viewport_style = {"width": "100vw", "height": "100vh"}
        scene_layout = self.all_layouts["struct"]
        return html.Div(scene_layout, style=viewport_style)

    @staticmethod
    def _preprocess_input_to_graph(
        input: Union[Structure, StructureGraph, Molecule, MoleculeGraph],
        bonding_strategy: str = "CrystalNN",
        bonding_strategy_kwargs: Optional[Dict] = None,
    ) -> Union[StructureGraph, MoleculeGraph]:
        """
        Turn the input into a graph carrying bonding information.

        A graph passed in is returned unchanged. Otherwise the named bonding
        strategy is instantiated and applied; if computing bonds fails, a
        graph without any bonds is returned instead.

        :param input: structure, molecule, or an already-built graph
        :param bonding_strategy: key of ``available_bonding_strategies``
        :param bonding_strategy_kwargs: keyword arguments for the strategy; for
            "CutOffDictNN" the "cut_off_dict" entry may be given as an iterable
            of (A, B, cutoff) items. This dict is not modified.
        :return: the resulting graph
        :raises ValueError: if ``bonding_strategy`` is not an available strategy
        """

        if isinstance(input, Structure):

            # ensure fractional co-ordinates are normalized to be in [0,1)
            # (this is actually not guaranteed by Structure)
            try:
                input = input.as_dict(verbosity=0)
            except TypeError:
                # TODO: remove this, necessary for Slab(?), some structure subclasses don't have verbosity
                input = input.as_dict()
            for site in input["sites"]:
                site["abc"] = np.mod(site["abc"], 1)
            input = Structure.from_dict(input)

            if not input.is_ordered:
                # calculating bonds in disordered structures is currently very flaky
                bonding_strategy = "CutOffDictNN"

        # we assume most uses of this class will give a structure as an input argument,
        # meaning we have to calculate the graph for bonding information, however if
        # the graph is already known and supplied, we will use that
        if isinstance(input, (StructureGraph, MoleculeGraph)):
            return input

        available = StructureMoleculeComponent.available_bonding_strategies
        if bonding_strategy not in available:
            raise ValueError(
                "Bonding strategy not supported. Please supply a name "
                "of a NearNeighbor subclass, choose from: {}".format(
                    ", ".join(available)
                )
            )

        # Work on a shallow copy: the caller's dict may be kept elsewhere (e.g.
        # as the initial graph generation options) and must not be rewritten
        # into the tuple-keyed form below.
        strategy_kwargs = dict(bonding_strategy_kwargs or {})
        if bonding_strategy == "CutOffDictNN" and "cut_off_dict" in strategy_kwargs:
            # TODO: remove this hack by making args properly JSON serializable
            cut_offs = strategy_kwargs["cut_off_dict"]
            # convert only the serializable list form; a dict (or None) is
            # passed through untouched
            if cut_offs is not None and not isinstance(cut_offs, dict):
                cut_offs = {(x[0], x[1]): x[2] for x in cut_offs}
            strategy_kwargs["cut_off_dict"] = cut_offs

        strategy = available[bonding_strategy](**strategy_kwargs)

        try:
            if isinstance(input, Structure):
                graph = StructureGraph.with_local_env_strategy(input, strategy)
            else:
                graph = MoleculeGraph.with_local_env_strategy(input, strategy)
        except Exception:
            # for some reason computing bonds failed, so let's not have any bonds(!)
            # (Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate)
            if isinstance(input, Structure):
                graph = StructureGraph.with_empty_graph(input)
            else:
                graph = MoleculeGraph.with_empty_graph(input)

        return graph

    @staticmethod
    def _analyze_site_props(struct_or_mol):

        # store list of site props that are vectors, so these can be displayed as arrows
        # (implicitly assumes all site props for a given key are same type)
        site_prop_names = defaultdict(list)
        for name, props in struct_or_mol.site_properties.items():
            if isinstance(props[0], float) or isinstance(props[0], int):
                site_prop_names["scalar"].append(name)
            elif isinstance(props[0], list) and len(props[0]) == 3:
                if isinstance(props[0][0], list) and len(props[0][0]) == 3:
                    site_prop_names["matrix"].append(name)
                else:
                    site_prop_names["vector"].append(name)
            elif isinstance(props[0], str):
                site_prop_names["categorical"].append(name)

        return dict(site_prop_names)

    @staticmethod
    def _get_origin(struct_or_mol):
        """
        Return the point used as the geometric center of the input.

        :param struct_or_mol: a Structure, a Molecule, or anything else
        :return: for a Structure, the result of the lattice's
            ``get_cartesian_coords`` for fractional (0.5, 0.5, 0.5); for a
            Molecule, the average of its ``cart_coords`` along axis 0;
            otherwise the tuple (0, 0, 0)
        """
        if isinstance(struct_or_mol, Structure):
            cell_midpoint = (0.5, 0.5, 0.5)
            return struct_or_mol.lattice.get_cartesian_coords(cell_midpoint)

        if isinstance(struct_or_mol, Molecule):
            return np.average(struct_or_mol.cart_coords, axis=0)

        return (0, 0, 0)

    @staticmethod
    def _get_struct_or_mol(graph) -> Union[Structure, Molecule]:
        """
        Return the structure or molecule that a graph wraps.

        :param graph: a StructureGraph or MoleculeGraph
        :return: ``graph.structure`` or ``graph.molecule`` respectively
        :raises ValueError: if ``graph`` is neither of the two graph types
        """
        if isinstance(graph, StructureGraph):
            return graph.structure
        if isinstance(graph, MoleculeGraph):
            return graph.molecule
        raise ValueError

    @staticmethod
    def _compass_from_lattice(
        lattice,
        origin=(0, 0, 0),
        scale=0.7,
        offset=0.15,
        compass_style="corner",
        **kwargs,
    ):
        """
        Build the compass: one arrow per lattice vector plus a sphere at the
        point the arrows start from.

        :param lattice: object whose ``matrix`` rows are the three lattice vectors
        :param origin: reference position; the compass is placed relative to
            its negation
        :param scale: length of each arrow direction (the lattice vectors are
            normalized first, so all three are the same size); it also scales
            the arrow and sphere dimensions
        :param offset: fraction of the sum of the three lattice vectors by
            which the compass is shifted away from the reference position
        :param compass_style: accepted but not used in this function
        :param kwargs: forwarded to every Arrows object
        :return: a Scene named 'compass' holding the three arrows (red, blue,
            green for the first, second and third lattice vector) and the sphere
        """
        # TODO: add along lattice
        vectors = lattice.matrix

        # starting point shared by all three arrows
        anchor = -np.array(origin)
        anchor = anchor - offset * (vectors[0] + vectors[1] + vectors[2])

        directions = [
            vector / np.linalg.norm(vector) * scale
            for vector in (vectors[0], vectors[1], vectors[2])
        ]

        anchor_sphere = Spheres(positions=[anchor], color="black", radius=0.1 * scale)

        arrows = [
            Arrows(
                [[anchor, anchor + direction]],
                color=color,
                radius=0.7 * scale,
                headLength=2.3 * scale,
                headWidth=1.4 * scale,
                **kwargs,
            )
            for direction, color in zip(directions, ("red", "blue", "green"))
        ]

        return Scene(name='compass', contents=arrows + [anchor_sphere])

    @staticmethod
    def _get_display_colors_and_legend_for_sites(
        struct_or_mol, site_prop_types, color_scheme="Jmol", color_scale=None
    ) -> Tuple[List[List[str]], Dict]:
        """
        Choose a display color for every species on every site and build
        the legend that explains those colors.

        Note this returns a list of lists of strings since each
        site might have multiple colors defined if the site is
        disordered.

        Args:
            struct_or_mol: object exposing ``composition``,
                ``site_properties``, ``is_ordered`` and iteration over
                its sites.
            site_prop_types: dict that may contain the keys ``"scalar"``
                and ``"categorical"``, each a list of site property names
                usable as a color scheme.
            color_scheme: ``"VESTA"``, ``"Jmol"``,
                ``"colorblind_friendly"`` or one of the property names in
                ``site_prop_types``. Anything else triggers a warning and
                a fallback to ``"Jmol"``.
            color_scale: colormap name handed to ``get_cmap`` for scalar
                schemes; ``"coolwarm"`` is used when not given.

        Returns:
            ``(colors, legend)``. ``colors`` holds one list of hex strings
            per site. ``legend`` is a dict with a ``"composition"`` entry
            and a ``"colors"`` entry mapping each hex color to an element
            name or property value, depending on the color scheme chosen.
            If the ``display_color`` site property is already present it
            is returned unchanged and ``legend["colors"]`` is left empty.

        Raises:
            ValueError: if a site-property color scheme is requested for
                a disordered structure or molecule.
        """

        # TODO: check to see if there is a bug here due to Composition being unordered(?)

        legend = {"composition": struct_or_mol.composition.as_dict(), "colors": {}}

        # don't calculate color if one is explicitly supplied
        if "display_color" in struct_or_mol.site_properties:
            # don't know what the color legend (meaning) is, so return empty legend
            return (struct_or_mol.site_properties["display_color"], legend)

        def get_color_hex(x):
            return "#{:02x}{:02x}{:02x}".format(*x)

        builtin_schemes = ("VESTA", "Jmol", "colorblind_friendly")
        scalar_props = site_prop_types.get("scalar", [])
        categorical_props = site_prop_types.get("categorical", [])
        default_scheme = "Jmol"

        if (
            color_scheme not in builtin_schemes
            and color_scheme not in scalar_props
            and color_scheme not in categorical_props
        ):
            warnings.warn(
                f"Color scheme {color_scheme} not available, falling back to {default_scheme}."
            )
            color_scheme = default_scheme

        # from here on color_scheme is either built in or a known site
        # property, so no further "unsupported scheme" check is needed
        if color_scheme not in builtin_schemes and not struct_or_mol.is_ordered:
            raise ValueError(
                "Can only use VESTA, Jmol or colorblind_friendly color "
                "schemes for disordered structures or molecules, color "
                "schemes based on site properties are ill-defined."
            )

        if color_scheme in ("VESTA", "Jmol"):

            #  TODO: define fallback color as global variable
            # TODO: maybe fallback categorical based on letter, for DummySpecie?

            scheme_colors = EL_COLORS[color_scheme]

            colors = []
            for site in struct_or_mol:
                label = unicodeify_species(site.species_string)
                site_colors = []
                for sp, _ in site.species.items():
                    element = sp.as_dict()["element"]
                    # elements missing from the scheme are drawn in black
                    color = get_color_hex(scheme_colors.get(element, [0, 0, 0]))
                    site_colors.append(color)
                    # construct legend
                    if color in legend["colors"] and legend["colors"][color] != label:
                        legend["colors"][
                            color
                        ] = f"{element}ˣ"  # TODO: mixed valence, improve this
                    else:
                        legend["colors"][color] = label
                colors.append(site_colors)

        elif color_scheme == "colorblind_friendly":

            # thanks to https://doi.org/10.1038/nmeth.1618
            palette = [
                [0, 0, 0],  # 0, black
                [230, 159, 0],  # 1, orange
                [86, 180, 233],  # 2, sky blue
                [0, 158, 115],  # 3, bluish green
                [240, 228, 66],  # 4, yellow
                [0, 114, 178],  # 5, blue
                [213, 94, 0],  # 6, vermillion
                [204, 121, 167],  # 7, reddish purple
                [255, 255, 255],  # 8, white
            ]

            # similar to CPK
            preferred_colors = {
                "O": 6,
                "N": 2,
                "C": 0,
                "H": 8,
                "F": 3,
                "Cl": 3,
                "Fe": 1,
                "Br": 7,
                "I": 7,
                "P": 1,
                "S": 4,
            }

            elements_per_site = [
                [sp.as_dict()["element"] for sp, _ in site.species.items()]
                for site in struct_or_mol
            ]
            # sorted so that the assignment does not depend on site order
            distinct_elements = sorted(
                {element for elements in elements_per_site for element in elements}
            )

            if len(distinct_elements) > len(palette):
                warnings.warn(
                    "Too many distinct types of site to use a color-blind friendly color scheme."
                )

            free_indices = list(range(len(palette)))
            palette_index = {}

            # first pass: honour the preferred color of each element while
            # that palette entry is still unused
            for element in distinct_elements:
                wanted = preferred_colors.get(element)
                if wanted is not None and wanted in free_indices:
                    palette_index[element] = wanted
                    free_indices.remove(wanted)

            # second pass: remaining elements take the next unused entry;
            # once the palette is exhausted they all share black
            for element in distinct_elements:
                if element not in palette_index:
                    palette_index[element] = free_indices.pop(0) if free_indices else 0

            colors = [
                [get_color_hex(palette[palette_index[element]]) for element in elements]
                for elements in elements_per_site
            ]

            # construct legend; colors are only shared in the overflow case
            for element in distinct_elements:
                color = get_color_hex(palette[palette_index[element]])
                if color in legend["colors"]:
                    legend["colors"][color] += f", {element}"
                else:
                    legend["colors"][color] = element

        elif color_scheme in scalar_props:

            props = np.array(struct_or_mol.site_properties[color_scheme])

            # by default, use blue-grey-red color scheme,
            # so that zero is ~ grey, and positive/negative
            # are red/blue
            color_scale = color_scale or "coolwarm"
            # try to keep color scheme symmetric around 0
            prop_max = max(abs(min(props)), max(props)) if len(props) else 0
            if not prop_max:
                # every value is zero (or there are no sites): use a unit
                # range so the normalization below does not divide by zero
                prop_max = 1.0
            prop_min = -prop_max
            prop_range = prop_max - prop_min

            cmap = get_cmap(color_scale)
            # normalize in [0, 1] range, as expected by cmap
            props_normed = (props - prop_min) / prop_range

            def get_color_cmap(x):
                return [int(c * 255) for c in cmap(x)[0:3]]

            colors = [[get_color_hex(get_color_cmap(x))] for x in props_normed]

            # construct legend
            rounded_props = sorted({np.around(p, decimals=1) for p in props})
            for prop in rounded_props:
                prop_normed = (prop - prop_min) / prop_range
                c = get_color_hex(get_color_cmap(prop_normed))
                legend["colors"][c] = "{:.1f}".format(prop)

        elif color_scheme in categorical_props:

            props = np.array(struct_or_mol.site_properties[color_scheme])

            palette = [get_color_hex(c) for c in Set1_9.colors]

            le = LabelEncoder()
            le.fit(props)
            transformed_props = le.transform(props)

            # if we have more categories than available colors,
            # arbitrarily group some categories together
            if len(le.classes_) > len(palette):
                warnings.warn(
                    "Too many categories for a complete categorical " "color scheme."
                )
            # overflowing categories all map to the last palette entry
            transformed_props = [
                p if p < len(palette) else -1 for p in transformed_props
            ]

            colors = [[palette[p]] for p in transformed_props]

            for category, p in zip(props, transformed_props):
                legend["colors"][palette[p]] = category

        return colors, legend

    @staticmethod
    def _get_display_radii_for_sites(
        struct_or_mol, radius_strategy="specified_or_average_ionic", radius_scale=1.0
    ) -> List[List[float]]:
        """
        Work out the display radius of every species on every site.

        Note this returns a list of lists of floats since each
        site might have multiple radii defined if the site is
        disordered.

        Args:
            struct_or_mol: object exposing ``site_properties`` and
                iteration over its sites.
            radius_strategy: one of
                ``StructureMoleculeComponent.available_radius_strategies``.
            radius_scale: factor every computed radius is multiplied by.

        Returns:
            One list of radii per site, one radius per species on that
            site. If the ``display_radius`` site property is already
            present it is returned unchanged (and unscaled).

        Raises:
            ValueError: if ``radius_strategy`` is not an available
                strategy.
        """

        # don't calculate radius if one is explicitly supplied
        if "display_radius" in struct_or_mol.site_properties:
            return struct_or_mol.site_properties["display_radius"]

        available_strategies = StructureMoleculeComponent.available_radius_strategies
        if radius_strategy not in available_strategies:
            raise ValueError(
                "Unknown radius strategy {}, choose from: {}".format(
                    radius_strategy, available_strategies
                )
            )

        radii = []

        for site in struct_or_mol:

            site_radii = []

            for sp, _ in site.species.items():

                radius = None

                if radius_strategy == "uniform":
                    radius = 0.5
                elif radius_strategy == "atomic":
                    radius = sp.atomic_radius
                elif radius_strategy == "specified_or_average_ionic":
                    # the oxidation-state specific radius is only used for
                    # a Specie with a non-zero oxidation state
                    if isinstance(sp, Specie) and sp.oxi_state:
                        radius = sp.ionic_radius
                    else:
                        radius = sp.average_ionic_radius
                elif radius_strategy == "covalent":
                    el = str(getattr(sp, "element", sp))
                    # .get() so an element missing from the table reaches
                    # the fallback below instead of raising KeyError
                    radius = CovalentRadius.radius.get(el)
                elif radius_strategy == "van_der_waals":
                    radius = sp.van_der_waals_radius
                elif radius_strategy == "atomic_calculated":
                    radius = sp.atomic_radius_calculated

                # covers both a missing (None) and a zero radius
                if not radius:
                    warnings.warn(
                        "Radius unknown for {} and strategy {}, "
                        "setting to 1.0.".format(sp, radius_strategy)
                    )
                    radius = 1.0

                site_radii.append(radius * radius_scale)

            radii.append(site_radii)

        return radii

    @staticmethod
    def get_scene_and_legend(
        graph: Union[StructureGraph, MoleculeGraph],
        name="StructureMoleculeComponent",
        color_scheme="Jmol",
        color_scale=None,
        radius_strategy="specified_or_average_ionic",
        radius_scale=1.0,
        ellipsoid_site_prop=None,
        draw_image_atoms=True,
        bonded_sites_outside_unit_cell=True,
        hide_incomplete_bonds=False,
        explicitly_calculate_polyhedra_hull=False,
        scene_additions=None,
        show_compass=True,
    ) -> Tuple[Scene, Dict[str, str]]:
        """
        Build the scene for a structure or molecule graph together with
        the legend describing its site colors.

        Args:
            graph: graph to draw; ``None`` gives an empty scene and an
                empty legend.
            name: name given to the returned scene.
            color_scheme: forwarded to
                ``_get_display_colors_and_legend_for_sites``.
            color_scale: forwarded to
                ``_get_display_colors_and_legend_for_sites``.
            radius_strategy: forwarded to ``_get_display_radii_for_sites``.
            radius_scale: forwarded to ``_get_display_radii_for_sites``.
            ellipsoid_site_prop: accepted but not used by this method.
            draw_image_atoms: forwarded to ``graph.get_scene``.
            bonded_sites_outside_unit_cell: forwarded to
                ``graph.get_scene``.
            hide_incomplete_bonds: forwarded to ``graph.get_scene`` as
                ``hide_incomplete_edges``.
            explicitly_calculate_polyhedra_hull: forwarded to
                ``graph.get_scene``.
            scene_additions: if truthy, appended to the scene contents as
                a single item.
            show_compass: if True and the structure has a lattice, a
                compass built by ``_compass_from_lattice`` is appended to
                the scene contents.

        Returns:
            ``(scene, legend)``.

        Note:
            The ``display_radius`` and ``display_color`` site properties
            are written onto the object returned by
            ``_get_struct_or_mol(graph)``.
            NOTE(review): whether that object is shared with ``graph`` is
            not visible here; confirm before reusing the same graph with
            different color or radius options.
        """

        if graph is None:
            return Scene(name=name), {}

        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)
        site_prop_types = StructureMoleculeComponent._analyze_site_props(struct_or_mol)

        radii = StructureMoleculeComponent._get_display_radii_for_sites(
            struct_or_mol, radius_strategy=radius_strategy, radius_scale=radius_scale
        )
        colors, legend = StructureMoleculeComponent._get_display_colors_and_legend_for_sites(
            struct_or_mol,
            site_prop_types,
            color_scale=color_scale,
            color_scheme=color_scheme,
        )

        # TODO: add set_display_color option, set_display_radius, set_ellipsoid
        # call it "set_display_options" ?
        # sets legend too! display_legend

        struct_or_mol.add_site_property("display_radius", radii)
        struct_or_mol.add_site_property("display_color", colors)

        origin = StructureMoleculeComponent._get_origin(struct_or_mol)

        scene = graph.get_scene(
            draw_image_atoms=draw_image_atoms,
            bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
            hide_incomplete_edges=hide_incomplete_bonds,
            explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
            origin=origin,
        )

        scene.name = name
        scene.origin = origin

        # molecules have no lattice, so there is nothing to build a
        # compass from; skip it rather than fail on the attribute lookup
        lattice = getattr(struct_or_mol, "lattice", None)
        if show_compass and lattice is not None:
            scene.contents.append(
                StructureMoleculeComponent._compass_from_lattice(lattice, origin=origin)
            )

        if scene_additions:
            scene.contents.append(scene_additions)

        return scene, legend