def set_nn_info(self):
    """Conformance check: every NearNeighbors subclass finds diamond's first bond.

    Implicitly assumes that all ``NearNeighbors`` subclasses will correctly
    identify bonds in ``self.diamond``; if one cannot, there are probably
    bigger problems. For each direct subclass, the first neighbour of site 0
    must be site 1 in an image whose first component is 1.

    NOTE(review): the name does not start with ``test``, so a default
    unittest loader will not collect this method -- confirm that is intended.
    """
    for subclass in NearNeighbors.__subclasses__():
        # subTest reports which strategy failed and keeps checking the rest
        with self.subTest(strategy=subclass.__name__):
            nn_info = subclass().get_nn_info(self.diamond, 0)
            # fail as an assertion rather than an IndexError on empty results
            self.assertGreater(len(nn_info), 0)
            self.assertEqual(nn_info[0]["site_index"], 1)
            self.assertEqual(nn_info[0]["image"][0], 1)
def set_nn_info(self):
    """Conformance check: every NearNeighbors subclass finds diamond's first bond.

    Implicitly assumes that all ``NearNeighbors`` subclasses will correctly
    identify bonds in ``self.diamond``; if one cannot, there are probably
    bigger problems. For each direct subclass, the first neighbour of site 0
    must be site 1 in an image whose first component is 1.

    NOTE(review): the name does not start with ``test``, so a default
    unittest loader will not collect this method -- confirm that is intended.
    """
    for subclass in NearNeighbors.__subclasses__():
        # subTest reports which strategy failed and keeps checking the rest
        with self.subTest(strategy=subclass.__name__):
            nn_info = subclass().get_nn_info(self.diamond, 0)
            # fail as an assertion rather than an IndexError on empty results
            self.assertGreater(len(nn_info), 0)
            self.assertEqual(nn_info[0]["site_index"], 1)
            self.assertEqual(nn_info[0]["image"][0], 1)
def set_nn_info(self):
    """Conformance check: every NearNeighbors subclass finds diamond's first bond.

    Implicitly assumes that all ``NearNeighbors`` subclasses will correctly
    identify bonds in ``self.diamond``; if one cannot, there are probably
    bigger problems. For each direct subclass, the first neighbour of site 0
    must be site 1 in an image whose first component is 1. ``Critic2NN`` is
    skipped because it has an external dependency and is tested separately.

    NOTE(review): the name does not start with ``test``, so a default
    unittest loader will not collect this method -- confirm that is intended.
    """
    for subclass in NearNeighbors.__subclasses__():
        # match on the class name only; str(subclass) also includes the module path
        if "Critic2" in subclass.__name__:
            continue
        # subTest reports which strategy failed and keeps checking the rest
        with self.subTest(strategy=subclass.__name__):
            nn_info = subclass().get_nn_info(self.diamond, 0)
            # fail as an assertion rather than an IndexError on empty results
            self.assertGreater(len(nn_info), 0)
            self.assertEqual(nn_info[0]["site_index"], 1)
            self.assertEqual(nn_info[0]["image"][0], 1)
def set_nn_info(self):
    """Conformance check: every NearNeighbors subclass finds diamond's first bond.

    Implicitly assumes that all ``NearNeighbors`` subclasses will correctly
    identify bonds in ``self.diamond``; if one cannot, there are probably
    bigger problems. For each direct subclass, the first neighbour of site 0
    must be site 1 in an image whose first component is 1. ``Critic2NN`` is
    skipped because it has an external dependency and is tested separately.

    NOTE(review): the name does not start with ``test``, so a default
    unittest loader will not collect this method -- confirm that is intended.
    """
    for subclass in NearNeighbors.__subclasses__():
        # match on the class name only; str(subclass) also includes the module path
        if "Critic2" in subclass.__name__:
            continue
        # subTest reports which strategy failed and keeps checking the rest
        with self.subTest(strategy=subclass.__name__):
            nn_info = subclass().get_nn_info(self.diamond, 0)
            # fail as an assertion rather than an IndexError on empty results
            self.assertGreater(len(nn_info), 0)
            self.assertEqual(nn_info[0]["site_index"], 1)
            self.assertEqual(nn_info[0]["image"][0], 1)
def __init__(self, materials, bonding,
             strategies=('MinimumDistanceNN', 'MinimumOKeeffeNN', 'JMolNN',
                         'MinimumVIRENN', 'VoronoiNN', 'CrystalNN', 'EconNN',
                         'BrunnerNN_real', 'BrunnerNN_relative',
                         'BrunnerNN_reciprocal', 'Critic2NN'),
             query=None, **kwargs):
    """
    Builder to calculate bonding in a crystallographic structure via near
    neighbor strategies, including those in pymatgen.analysis.local_env and
    using the critic2 tool.

    Args:
        materials (Store): Store of materials documents
        bonding (Store): Store of topology data
        strategies (list): List of NearNeighbor classes to use (can be an
            instance of a NearNeighbor class or its name as a string, in
            which case it will be instantiated with default arguments).
            If empty or None, every direct subclass of NearNeighbors is
            instantiated with default arguments.
        query (dict): dictionary to limit materials to be analyzed
        **kwargs: passed through to the parent builder's __init__

    Raises:
        KeyError: if a strategy name is not the name of a direct
            NearNeighbors subclass.
    """
    self.chunk_size = 100
    self.materials = materials
    self.bonding = bonding
    self.query = query or {}
    # NOTE(review): these attributes are set before super().__init__; if the
    # base builder assigns chunk_size/query itself they will be overwritten --
    # confirm against the base class.

    available_strategies = {
        nn.__name__: nn for nn in NearNeighbors.__subclasses__()
    }

    if strategies:
        # use the instance if passed directly (e.g. with custom kwargs),
        # otherwise instantiate class with default options
        self.strategies = [
            strategy if isinstance(strategy, NearNeighbors)
            else available_strategies[strategy]()
            for strategy in strategies
        ]
    else:
        # calculate all the strategies; instantiate them so that
        # self.strategies always holds NearNeighbors instances
        self.strategies = [nn() for nn in available_strategies.values()]

    bonding.validator = BondValidator()

    super().__init__(sources=[materials], targets=[bonding], **kwargs)
def __init__(self, materials, bonding, strategies=("CrystalNN", ), **kwargs):
    """
    Builder to calculate bonding in a crystallographic structure via near
    neighbor strategies, including those in pymatgen.analysis.local_env and
    using the critic2 tool.

    Args:
        materials (Store): Store of materials documents
        bonding (Store): Store of topology data; a JSON-schema validator
            built from BOND_SCHEMA is attached to it
        strategies (list): List of NearNeighbor classes to use (can be an
            instance of a NearNeighbor class or its name as a string, in
            which case it will be instantiated with default arguments)
        **kwargs: passed through to the parent builder's __init__

    Raises:
        KeyError: if a strategy name is not the name of a direct
            NearNeighbors subclass.
    """
    self.materials = materials
    self.bonding = bonding
    self.bonding.validator = JSONSchemaValidator(loadfn(BOND_SCHEMA))

    available_strategies = {
        nn.__name__: nn for nn in NearNeighbors.__subclasses__()
    }

    # use the instance if passed directly (e.g. with custom kwargs),
    # otherwise instantiate class with default options
    self.strategies = [
        strategy if isinstance(strategy, NearNeighbors)
        else available_strategies[strategy]()
        for strategy in strategies
    ]
    self.strategy_names = [
        strategy.__class__.__name__ for strategy in self.strategies
    ]

    # Voronoi-based strategies can cause some structures to cause crash
    self.bad_task_ids = []

    super().__init__(source=materials, target=bonding, ufn=self.calc,
                     projection=["structure"], **kwargs)
class StructureMoleculeComponent(MPComponent):
    """
    Component that renders a pymatgen Structure, Molecule, StructureGraph or
    MoleculeGraph as a 3D scene, together with a formula title, a color
    legend and controls for unit cell choice, repeats, color scheme, atomic
    radii, draw options and visibility of scene groups.

    Bonds are computed with a pymatgen NearNeighbors strategy chosen by class
    name (see ``available_bonding_strategies``).
    """

    # every direct NearNeighbors subclass, keyed by class name
    available_bonding_strategies = {
        subclass.__name__: subclass
        for subclass in NearNeighbors.__subclasses__()
    }

    # radius strategies accepted by _get_display_radii_for_sites
    available_radius_strategies = (
        "atomic",
        "specified_or_average_ionic",
        "covalent",
        "van_der_waals",
        "atomic_calculated",
        "uniform",
    )

    # TODO: not yet referenced by the rendering code in this class
    available_polyhedra_rules = ("prefer_large_polyhedra", "only_same_species")

    default_scene_settings = {
        "lights": [
            {
                "type": "DirectionalLight",
                "args": ["#ffffff", 0.15],
                "position": [-10, 10, 10],
            },
            {
                "type": "DirectionalLight",
                "args": ["#ffffff", 0.15],
                "position": [0, 0, -10],
            },
            {
                "type": "HemisphereLight",
                "args": ["#eeeeee", "#999999", 1.0]
            },
        ],
        "material": {
            "type": "MeshStandardMaterial",
            "parameters": {
                "roughness": 0.07,
                "metalness": 0.00
            },
        },
        "objectScale": 1.0,
        "cylinderScale": 0.1,
        "defaultSurfaceOpacity": 0.5,
        "staticScene": True,
    }

    def __init__(
            self,
            struct_or_mol=None,
            id=None,
            origin_component=None,
            scene_additions=None,
            bonding_strategy="BrunnerNN_reciprocal",
            bonding_strategy_kwargs=None,
            color_scheme="Jmol",
            color_scale=None,
            radius_strategy="uniform",
            draw_image_atoms=True,
            bonded_sites_outside_unit_cell=False,
    ):
        """
        :param struct_or_mol: Structure, Molecule, StructureGraph or
            MoleculeGraph to display; if falsy, an empty scene is created
        :param id: component id, passed to MPComponent
        :param origin_component: passed to MPComponent
        :param scene_additions: optional extra scene, stored as JSON in the
            "scene_additions" store
        :param bonding_strategy: name of a NearNeighbors subclass
        :param bonding_strategy_kwargs: kwargs for the bonding strategy
        :param color_scheme: "VESTA", "Jmol" or a scalar site property name
        :param color_scale: colormap name for scalar site properties
        :param radius_strategy: one of available_radius_strategies
        :param draw_image_atoms: draw periodic images of atoms on cell faces
        :param bonded_sites_outside_unit_cell: draw atoms outside the unit
            cell that are bonded to atoms inside it
        """
        super().__init__(id=id, contents=struct_or_mol,
                         origin_component=origin_component)

        self.default_title = "Crystal Toolkit"

        # copy so per-instance changes never leak into the class-level default
        self.initial_scene_settings = dict(
            StructureMoleculeComponent.default_scene_settings)
        self.create_store("scene_settings",
                          initial_data=self.initial_scene_settings)

        self.initial_graph_generation_options = {
            "bonding_strategy": bonding_strategy,
            "bonding_strategy_kwargs": bonding_strategy_kwargs,
        }
        self.create_store(
            "graph_generation_options",
            initial_data=self.initial_graph_generation_options,
        )

        self.initial_display_options = {
            "color_scheme": color_scheme,
            "color_scale": color_scale,
            "radius_strategy": radius_strategy,
            "draw_image_atoms": draw_image_atoms,
            "bonded_sites_outside_unit_cell": bonded_sites_outside_unit_cell,
        }
        self.create_store("display_options",
                          initial_data=self.initial_display_options)

        if struct_or_mol:
            # graph is cached explicitly, this isn't necessary but is an
            # optimization so that graph is only re-generated if bonding
            # algorithm changes
            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=bonding_strategy,
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )
        else:
            # component could be initialized without a structure, in which
            # case an empty scene should be displayed
            graph = None
        scene, legend = self.get_scene_and_legend(
            graph, name=self.id(), **self.initial_display_options)

        self.initial_legend = legend
        self.create_store("legend_data", initial_data=self.initial_legend)

        self.initial_scene_data = scene.to_json()

        self.initial_graph = graph
        self.create_store("graph", initial_data=self.to_data(graph))

        if scene_additions:
            self.initial_scene_additions = scene_additions
            self.create_store("scene_additions",
                              initial_data=scene_additions.to_json())

    def _generate_callbacks(self, app, cache):
        """Register this component's callbacks on ``app``."""

        @app.callback(
            Output(self.id("graph"), "data"),
            [
                Input(self.id("graph_generation_options"), "data"),
                Input(self.id("unit-cell-choice"), "value"),
                Input(self.id("repeats"), "value"),
                Input(self.id(), "data"),
            ],
        )
        def update_graph(graph_generation_options, unit_cell_choice, repeats,
                         struct_or_mol):
            struct_or_mol = self.from_data(struct_or_mol)
            graph_generation_options = self.from_data(graph_generation_options)
            repeats = int(repeats)

            if isinstance(struct_or_mol, Structure):
                if unit_cell_choice == "primitive":
                    struct_or_mol = struct_or_mol.get_primitive_structure()
                elif unit_cell_choice == "conventional":
                    sga = SpacegroupAnalyzer(struct_or_mol)
                    struct_or_mol = sga.get_conventional_standard_structure()
                if repeats != 1:
                    struct_or_mol = struct_or_mol * (repeats, repeats, repeats)

            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=graph_generation_options["bonding_strategy"],
                bonding_strategy_kwargs=graph_generation_options[
                    "bonding_strategy_kwargs"],
            )
            return self.to_data(graph)

        @app.callback(
            Output(self.id("scene"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
            ],
        )
        def update_scene(graph, display_options):
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            scene, _legend = self.get_scene_and_legend(graph, **display_options)
            return scene.to_json()

        @app.callback(
            Output(self.id("legend_data"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
            ],
        )
        def update_legend(graph, display_options):
            # TODO: more cleanly split legend from scene generation
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            struct_or_mol = self._get_struct_or_mol(graph)
            site_prop_types = self._analyze_site_props(struct_or_mol)
            _colors, legend = self._get_display_colors_and_legend_for_sites(
                struct_or_mol,
                site_prop_types,
                color_scheme=display_options.get("color_scheme", None),
                color_scale=display_options.get("color_scale", None),
            )
            return self.to_data(legend)

        @app.callback(
            Output(self.id("color-scheme"), "options"),
            [Input(self.id("graph"), "data")],
        )
        def update_color_options(graph):
            options = [
                {"label": "Jmol", "value": "Jmol"},
                {"label": "VESTA", "value": "VESTA"},
            ]
            graph = self.from_data(graph)
            struct_or_mol = self._get_struct_or_mol(graph)
            site_props = self._analyze_site_props(struct_or_mol)
            for prop in site_props.get("scalar", []):
                options.append({
                    "label": f"Site property: {prop}",
                    "value": prop
                })
            return options

        @app.callback(
            Output(self.id("display_options"), "data"),
            [
                Input(self.id("color-scheme"), "value"),
                Input(self.id("radius_strategy"), "value"),
                Input(self.id("draw_options"), "values"),
            ],
            [State(self.id("display_options"), "data")],
        )
        def update_display_options(color_scheme, radius_strategy, draw_options,
                                   display_options):
            display_options = self.from_data(display_options)
            display_options.update({
                "color_scheme": color_scheme,
                "radius_strategy": radius_strategy,
                "draw_image_atoms": "draw_image_atoms" in draw_options,
                "bonded_sites_outside_unit_cell":
                    "bonded_sites_outside_unit_cell" in draw_options,
            })
            return self.to_data(display_options)

        @app.callback(
            Output(self.id("scene"), "downloadRequest"),
            [Input(self.id("screenshot_button"), "n_clicks")],
            [
                State(self.id("scene"), "downloadRequest"),
                State(self.id(), "data")
            ],
        )
        def screenshot_callback(n_clicks, current_requests, struct_or_mol):
            if n_clicks is None:
                raise PreventUpdate
            struct_or_mol = self.from_data(struct_or_mol)
            # the store may hold a graph rather than a bare structure/molecule
            if isinstance(struct_or_mol, (StructureGraph, MoleculeGraph)):
                struct_or_mol = self._get_struct_or_mol(struct_or_mol)
            formula = struct_or_mol.composition.reduced_formula
            if hasattr(struct_or_mol, "get_space_group_info"):
                spgrp = struct_or_mol.get_space_group_info()[0]
            else:
                spgrp = ""
            request_filename = "{}-{}-crystal-toolkit.png".format(
                formula, spgrp)
            if not current_requests:
                n_requests = 1
            else:
                n_requests = current_requests["n_requests"] + 1
            return {
                "n_requests": n_requests,
                "filename": request_filename,
                "filetype": "png",
            }

        @app.callback(
            Output(self.id("scene"), "toggleVisibility"),
            [Input(self.id("hide-show"), "values")],
            [State(self.id("hide-show"), "options")],
        )
        def update_visibility(values, options):
            return {opt["value"]: (opt["value"] in values) for opt in options}

        @app.callback(
            Output(self.id("title_container"), "children"),
            [Input(self.id("legend_data"), "data")],
        )
        def update_title(legend):
            legend = self.from_data(legend)
            return self._make_title(legend)

        # distinct name so it does not shadow the legend_data callback above
        @app.callback(
            Output(self.id("legend_container"), "children"),
            [Input(self.id("legend_data"), "data")],
        )
        def update_legend_layout(legend):
            legend = self.from_data(legend)
            return self._make_legend(legend)

    def _make_legend(self, legend):
        """Build a row of colored buttons, one per legend entry."""
        if legend is None or (not legend.get("colors", None)):
            return html.Div(id=self.id("legend"))

        def get_font_color(hex_code):
            # ensures contrasting font color for background color
            c = tuple(int(hex_code[1:][i:i + 2], 16) for i in (0, 2, 4))
            if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5:
                font_color = "#000000"
            else:
                font_color = "#ffffff"
            return font_color

        formula = Composition.from_dict(legend["composition"]).reduced_formula
        # order entries by where their label first appears in the formula
        legend_colors = OrderedDict(
            sorted(legend["colors"].items(),
                   key=lambda x: formula.find(x[1])))

        legend_elements = [
            Button(
                html.Span(name,
                          className="icon",
                          style={"color": get_font_color(color)}),
                kind="static",
                style={"background-color": color},
            ) for color, name in legend_colors.items()
        ]

        return Field(
            [
                Control(el, style={"margin-right": "0.2rem"})
                for el in legend_elements
            ],
            id=self.id("legend"),
            grouped=True,
        )

    def _make_title(self, legend):
        """Build the H1 title: the reduced formula with subscripted counts."""
        if not legend or (not legend.get("composition", None)):
            return H1(self.default_title, id=self.id("title"))

        composition = legend["composition"]
        if isinstance(composition, dict):
            composition = Composition.from_dict(composition)
        formula = composition.reduced_formula
        formula_parts = re.findall(r"[^\d_]+|\d+", formula)

        formula_components = [
            html.Sub(part) if part.isnumeric() else html.Span(part)
            for part in formula_parts
        ]

        return H1(formula_components,
                  id=self.id("title"),
                  style={"display": "inline-block"})

    @property
    def all_layouts(self):
        """Dict of the layout pieces: struct, screenshot, options, title, legend."""
        struct_layout = html.Div(
            Simple3DSceneComponent(
                id=self.id("scene"),
                data=self.initial_scene_data,
                settings=self.initial_scene_settings,
            ),
            style={
                "width": "100%",
                "height": "100%",
                "overflow": "hidden",
                "margin": "0 auto",
            },
        )

        screenshot_layout = html.Div(
            [
                Button(
                    [Icon(), html.Span(), "Download Image"],
                    kind="primary",
                    id=self.id("screenshot_button"),
                )
            ],
            # TODO: change to "bottom" when dropdown included
            style={
                "vertical-align": "top",
                "display": "inline-block"
            },
        )

        title_layout = html.Div(self._make_title(self.initial_legend),
                                id=self.id("title_container"))

        legend_layout = html.Div(self._make_legend(self.initial_legend),
                                 id=self.id("legend_container"))

        options_layout = Field([
            # hide if molecule
            html.Label("Change unit cell:", className="mpc-label"),
            html.Div(
                dcc.RadioItems(
                    options=[
                        {"label": "Input cell", "value": "input"},
                        {"label": "Primitive cell", "value": "primitive"},
                        {"label": "Conventional cell", "value": "conventional"},
                    ],
                    value="conventional",
                    id=self.id("unit-cell-choice"),
                    labelStyle={"display": "block"},
                    inputClassName="mpc-radio",
                ),
                className="mpc-control",
            ),
            # hide if molecule
            html.Label("Change number of repeats:", className="mpc-label"),
            html.Div(
                dcc.RadioItems(
                    options=[
                        {"label": "1×1×1", "value": "1"},
                        {"label": "2×2×2", "value": "2"},
                    ],
                    value="1",
                    id=self.id("repeats"),
                    labelStyle={"display": "block"},
                    inputClassName="mpc-radio",
                ),
                className="mpc-control",
            ),
            html.Label("Change color scheme:", className="mpc-label"),
            html.Div(
                dcc.Dropdown(
                    options=[
                        {"label": "VESTA", "value": "VESTA"},
                        {"label": "Jmol", "value": "Jmol"},
                    ],
                    value="VESTA",
                    clearable=False,
                    id=self.id("color-scheme"),
                ),
                className="mpc-control",
            ),
            html.Label("Change atomic radii:", className="mpc-label"),
            html.Div(
                dcc.Dropdown(
                    options=[
                        {"label": "Ionic", "value": "specified_or_average_ionic"},
                        {"label": "Covalent", "value": "covalent"},
                        {"label": "Van der Waals", "value": "van_der_waals"},
                        {"label": "Uniform (0.5Å)", "value": "uniform"},
                    ],
                    value="uniform",
                    clearable=False,
                    id=self.id("radius_strategy"),
                ),
                className="mpc-control",
            ),
            html.Label("Draw options:", className="mpc-label"),
            html.Div([
                dcc.Checklist(
                    options=[
                        {
                            "label": "Draw repeats of atoms on periodic boundaries",
                            "value": "draw_image_atoms",
                        },
                        {
                            "label": "Draw atoms outside unit cell bonded to "
                                     "atoms within unit cell",
                            "value": "bonded_sites_outside_unit_cell",
                        },
                    ],
                    values=["draw_image_atoms"],
                    labelStyle={"display": "block"},
                    inputClassName="mpc-radio",
                    id=self.id("draw_options"),
                )
            ]),
            html.Label("Hide/show:", className="mpc-label"),
            html.Div(
                [
                    dcc.Checklist(
                        options=[
                            {"label": "Atoms", "value": "atoms"},
                            {"label": "Bonds", "value": "bonds"},
                            {"label": "Unit cell", "value": "unit_cell"},
                            {"label": "Polyhedra", "value": "polyhedra"},
                        ],
                        values=["atoms", "bonds", "unit_cell", "polyhedra"],
                        labelStyle={"display": "block"},
                        inputClassName="mpc-radio",
                        id=self.id("hide-show"),
                    )
                ],
                className="mpc-control",
            ),
        ])

        return {
            "struct": struct_layout,
            "screenshot": screenshot_layout,
            "options": options_layout,
            "title": title_layout,
            "legend": legend_layout,
        }

    @property
    def standard_layout(self):
        return self.all_layouts["struct"]

    @staticmethod
    def _preprocess_input_to_graph(
            input: Union[Structure, StructureGraph, Molecule, MoleculeGraph],
            bonding_strategy: str = "CrystalNN",
            bonding_strategy_kwargs: Optional[Dict] = None,
    ) -> Union[StructureGraph, MoleculeGraph]:
        """
        Return a bonded graph for ``input``.

        Graphs are returned unchanged. Structures have their fractional
        coordinates wrapped into [0, 1); disordered structures always use
        CutOffDictNN. If the bonding strategy fails, a graph with no bonds is
        returned instead.

        :raises ValueError: if ``bonding_strategy`` is not the name of a
            direct NearNeighbors subclass
        """
        if isinstance(input, Structure):
            # ensure fractional co-ordinates are normalized to be in [0,1)
            # (this is actually not guaranteed by Structure)
            input = input.as_dict(verbosity=0)
            for site in input["sites"]:
                site["abc"] = np.mod(site["abc"], 1)
            input = Structure.from_dict(input)

            if not input.is_ordered:
                # calculating bonds in disordered structures is currently very flaky
                bonding_strategy = "CutOffDictNN"

        # we assume most uses of this class will give a structure as an input
        # argument, meaning we have to calculate the graph for bonding
        # information, however if the graph is already known and supplied,
        # we will use that
        if isinstance(input, (StructureGraph, MoleculeGraph)):
            return input

        strategies = StructureMoleculeComponent.available_bonding_strategies
        if bonding_strategy not in strategies:
            raise ValueError(
                "Bonding strategy not supported. Please supply a name "
                "of a NearNeighbor subclass, choose from: {}".format(
                    ", ".join(strategies.keys())))

        # copy so the caller's dict (e.g. one held in a component store) is
        # never mutated by the conversion below
        bonding_strategy_kwargs = dict(bonding_strategy_kwargs or {})
        if bonding_strategy == "CutOffDictNN":
            cut_off_dict = bonding_strategy_kwargs.get("cut_off_dict")
            # only convert the serialized [el1, el2, cutoff] form; a dict
            # with tuple keys is already in the form CutOffDictNN expects
            if cut_off_dict is not None and not isinstance(cut_off_dict, dict):
                # TODO: remove this hack by making args properly JSON serializable
                bonding_strategy_kwargs["cut_off_dict"] = {
                    (x[0], x[1]): x[2] for x in cut_off_dict
                }

        nn_strategy = strategies[bonding_strategy](**bonding_strategy_kwargs)
        graph_class = (StructureGraph
                       if isinstance(input, Structure) else MoleculeGraph)
        try:
            graph = graph_class.with_local_env_strategy(input, nn_strategy)
        except Exception:
            # best effort: if computing bonds fails, show sites without bonds
            graph = graph_class.with_empty_graph(input)

        return graph

    @staticmethod
    def _analyze_site_props(struct_or_mol):
        """
        Group site property names by kind: "scalar" (int/float), "vector"
        (length-3 list), "matrix" (3x3 nested list) or "categorical" (str).
        Kinds with no properties are omitted. Implicitly assumes all site
        props for a given key are the same type (only the first is checked).
        """
        site_prop_names = defaultdict(list)
        for name, props in struct_or_mol.site_properties.items():
            if len(props) == 0:
                continue
            first = props[0]
            if isinstance(first, (float, int)):
                site_prop_names["scalar"].append(name)
            elif isinstance(first, list) and len(first) == 3:
                if isinstance(first[0], list) and len(first[0]) == 3:
                    site_prop_names["matrix"].append(name)
                else:
                    site_prop_names["vector"].append(name)
            elif isinstance(first, str):
                site_prop_names["categorical"].append(name)

        return dict(site_prop_names)

    @staticmethod
    def _get_origin(struct_or_mol):
        """Cartesian centre of the unit cell, or mean position for molecules."""
        if isinstance(struct_or_mol, Structure):
            return struct_or_mol.lattice.get_cartesian_coords((0.5, 0.5, 0.5))
        if isinstance(struct_or_mol, Molecule):
            return np.average(struct_or_mol.cart_coords, axis=0)
        return (0, 0, 0)

    @staticmethod
    def _get_struct_or_mol(graph) -> Union[Structure, Molecule]:
        """Return the Structure/Molecule wrapped by a graph."""
        if isinstance(graph, StructureGraph):
            return graph.structure
        if isinstance(graph, MoleculeGraph):
            return graph.molecule
        raise ValueError(
            "Expected a StructureGraph or MoleculeGraph, got {}".format(
                type(graph).__name__))

    @staticmethod
    def _get_display_colors_and_legend_for_sites(
            struct_or_mol,
            site_prop_types,
            color_scheme="Jmol",
            color_scale=None) -> Tuple[List[List[str]], Dict]:
        """
        Note this returns a list of lists of strings since each site might
        have multiple colors defined if the site is disordered.

        The legend is a dictionary whose keys are colors and values are
        corresponding element names or values, depending on the color scheme
        chosen.
        """

        # TODO: check to see if there is a bug here due to Composition being unordered(?)

        legend = {
            "composition": struct_or_mol.composition.as_dict(),
            "colors": {}
        }

        # don't calculate color if one is explicitly supplied
        if "display_color" in struct_or_mol.site_properties:
            # don't know what the color legend (meaning) is, so return empty legend
            return (struct_or_mol.site_properties["display_color"], legend)

        def get_color_hex(x):
            return "#{:02x}{:02x}{:02x}".format(*x)

        if color_scheme not in ("VESTA", "Jmol", "colorblind_friendly"):

            if not struct_or_mol.is_ordered:
                raise ValueError(
                    "Can only use VESTA, Jmol or colorblind_friendly color "
                    "schemes for disordered structures or molecules, color "
                    "schemes based on site properties are ill-defined.")

            if (color_scheme not in site_prop_types.get("scalar", [])) and (
                    color_scheme not in site_prop_types.get("categorical", [])):
                raise ValueError(
                    "Unsupported color scheme. Should be VESTA, Jmol, "
                    "colorblind_friendly or a scalar (float) or categorical "
                    "(string) site property.")

        if color_scheme in ("VESTA", "Jmol"):

            colors = []
            for site in struct_or_mol:
                elements = [
                    sp.as_dict()["element"]
                    for sp, _ in site.species_and_occu.items()
                ]
                colors.append([
                    get_color_hex(EL_COLORS[color_scheme][element])
                    for element in elements
                ])
                # construct legend
                for element in elements:
                    color = get_color_hex(EL_COLORS[color_scheme][element])
                    legend["colors"][color] = element

        elif color_scheme in site_prop_types.get("scalar", []):

            props = np.array(struct_or_mol.site_properties[color_scheme])

            # by default, use blue-grey-red color scheme,
            # so that zero is ~ grey, and positive/negative
            # are red/blue
            color_scale = color_scale or "coolwarm"
            # try to keep color scheme symmetric around 0
            color_max = max([abs(min(props)), max(props)])
            if color_max == 0:
                # every value is zero: widen the range so values map to the
                # centre of the colormap instead of dividing 0 by 0
                color_max = 1.0
            color_min = -color_max

            cmap = get_cmap(color_scale)

            # normalize in [0, 1] range, as expected by cmap
            props_normed = (props - color_min) / (color_max - color_min)

            def get_color_cmap(x):
                return [int(c * 255) for c in cmap(x)[0:3]]

            colors = [[get_color_hex(get_color_cmap(x))] for x in props_normed]

            # construct legend: one entry per distinct value rounded to 1 d.p.
            rounded_props = sorted({np.around(p, decimals=1) for p in props})
            for prop in rounded_props:
                prop_normed = (prop - color_min) / (color_max - color_min)
                c = get_color_hex(get_color_cmap(prop_normed))
                legend["colors"][c] = "{:.1f}".format(prop)

        elif color_scheme == "colorblind_friendly":
            raise NotImplementedError

        elif color_scheme in site_prop_types.get("categorical", []):
            # TODO: iterate a palettable.colorbrewer.qualitative palette
            # (e.g. Set1_9), checking it has enough colors
            raise NotImplementedError

        return colors, legend

    @staticmethod
    def _primitives_from_lattice(lattice, origin=(0, 0, 0), **kwargs):
        """Lines for the 12 edges of the unit cell, shifted by -origin."""
        o = -np.array(origin)
        a, b, c = lattice.matrix[0], lattice.matrix[1], lattice.matrix[2]
        # consecutive entries form (start, end) pairs of each edge
        line_pairs = [
            o, o + a,
            o, o + b,
            o, o + c,
            o + a, o + a + b,
            o + a, o + a + c,
            o + b, o + b + a,
            o + b, o + b + c,
            o + c, o + c + a,
            o + c, o + c + b,
            o + a + b, o + a + b + c,
            o + a + c, o + a + b + c,
            o + b + c, o + a + b + c,
        ]
        line_pairs = [line.tolist() for line in line_pairs]
        return Lines(line_pairs, **kwargs)

    @staticmethod
    def _get_ellipsoids_from_matrix(matrix):
        # TODO: derive ellipsoids from the eigen-decomposition of matrix
        raise NotImplementedError

    @staticmethod
    def _primitives_from_site(
            site,
            connected_sites=None,
            origin=(0, 0, 0),
            ellipsoid_site_prop=None,
            all_connected_sites_present=True,
            explicitly_calculate_polyhedra_hull=False,
    ):
        """
        Sites must have display_radius and display_color site properties.

        :param site: site to draw
        :param connected_sites: ConnectedSite-like objects bonded to the site
        :param origin: cartesian origin subtracted from all positions
        :param ellipsoid_site_prop: (beta) site property holding a matrix
        :param all_connected_sites_present: if False, will not calculate
            polyhedra since this would be misleading
        :param explicitly_calculate_polyhedra_hull: build polyhedra from a
            Delaunay triangulation instead of a Convex primitive
        :return: dict with "atoms", "bonds" and "polyhedra" primitive lists
        """

        atoms = []
        bonds = []
        polyhedron = []

        # for disordered structures
        is_ordered = site.is_ordered
        occu_start = 0.0

        # for thermal ellipsoids etc.
        if ellipsoid_site_prop:
            matrix = site.properties[ellipsoid_site_prop]
            ellipsoids = StructureMoleculeComponent._get_ellipsoids_from_matrix(
                matrix)
        else:
            ellipsoids = None

        position = np.subtract(site.coords, origin).tolist()

        # site_color is used for bonds and polyhedra, if multiple colors are
        # defined for site (e.g. a disordered site), then we use grey
        all_colors = set(site.properties["display_color"])
        if len(all_colors) > 1:
            site_color = "#555555"
        else:
            site_color = list(all_colors)[0]

        for idx, (sp, occu) in enumerate(site.species_and_occu.items()):

            if isinstance(sp, DummySpecie):

                cube = Cubes(positions=[position])
                atoms.append(cube)

            else:

                color = site.properties["display_color"][idx]
                radius = site.properties["display_radius"][idx]

                # TODO: make optional/default to None
                # in disordered structures, we fractionally color-code spheres,
                # drawing a sphere segment from phi_end to phi_start
                # (think a sphere pie chart)
                if not is_ordered:
                    phi_frac_end = occu_start + occu
                    phi_frac_start = occu_start
                    occu_start = phi_frac_end
                    phiStart = phi_frac_start * np.pi * 2
                    phiEnd = phi_frac_end * np.pi * 2
                else:
                    phiStart, phiEnd = None, None

                sphere = Spheres(
                    positions=[position],
                    color=color,
                    radius=radius,
                    phiStart=phiStart,
                    phiEnd=phiEnd,
                    ellipsoids=ellipsoids,
                )
                atoms.append(sphere)

        if connected_sites:

            all_positions = [position]
            for connected_site in connected_sites:
                connected_position = np.subtract(connected_site.site.coords,
                                                 origin)
                # each site draws only its half of the bond
                bond_midpoint = np.add(position, connected_position) / 2

                cylinder = Cylinders(
                    positionPairs=[[position, bond_midpoint.tolist()]],
                    color=site_color)
                bonds.append(cylinder)
                all_positions.append(connected_position.tolist())

            if len(connected_sites) > 3 and all_connected_sites_present:
                if explicitly_calculate_polyhedra_hull:
                    # NOTE(review): scipy's vertex_neighbor_vertices is an
                    # (indptr, indices) pair of arrays, so indexing
                    # all_positions with it looks like it always raises and
                    # falls back to no polyhedron; .convex_hull (triangle
                    # indices) may have been intended -- confirm.
                    try:
                        vertices_indices = Delaunay(
                            all_positions).vertex_neighbor_vertices
                        vertices = [
                            all_positions[idx] for idx in vertices_indices
                        ]
                        polyhedron = [
                            Surface(
                                positions=vertices,
                                color=site.properties["display_color"][0],
                            )
                        ]
                    except Exception:
                        # best effort: draw no polyhedron if the hull fails
                        polyhedron = []

                else:
                    polyhedron = [
                        Convex(positions=all_positions, color=site_color)
                    ]

        return {"atoms": atoms, "bonds": bonds, "polyhedra": polyhedron}

    @staticmethod
    def _get_display_radii_for_sites(
            struct_or_mol,
            radius_strategy="specified_or_average_ionic") -> List[List[float]]:
        """
        Note this returns a list of lists of floats since each site might
        have multiple radii defined if the site is disordered.

        Radii that are unknown (None or 0) for the chosen strategy are
        replaced by 1.0 with a warning.

        :raises ValueError: if radius_strategy is not one of
            available_radius_strategies
        """

        # don't calculate radius if one is explicitly supplied
        if "display_radius" in struct_or_mol.site_properties:
            return struct_or_mol.site_properties["display_radius"]

        if (radius_strategy not in
                StructureMoleculeComponent.available_radius_strategies):
            raise ValueError(
                "Unknown radius strategy {}, choose from: {}".format(
                    radius_strategy,
                    StructureMoleculeComponent.available_radius_strategies,
                ))
        radii = []

        for site in struct_or_mol:

            site_radii = []

            for sp, _occu in site.species_and_occu.items():

                radius = None

                if radius_strategy == "uniform":
                    radius = 0.5
                elif radius_strategy == "atomic":
                    radius = sp.atomic_radius
                elif (radius_strategy == "specified_or_average_ionic" and
                      isinstance(sp, Specie) and sp.oxi_state):
                    radius = sp.ionic_radius
                elif radius_strategy == "specified_or_average_ionic":
                    radius = sp.average_ionic_radius
                elif radius_strategy == "covalent":
                    el = str(getattr(sp, "element", sp))
                    # unknown elements fall through to the 1.0 default below
                    radius = CovalentRadius.radius.get(el)
                elif radius_strategy == "van_der_waals":
                    radius = sp.van_der_waals_radius
                elif radius_strategy == "atomic_calculated":
                    radius = sp.atomic_radius_calculated

                if not radius:
                    warnings.warn("Radius unknown for {} and strategy {}, "
                                  "setting to 1.0.".format(
                                      sp, radius_strategy))
                    radius = 1.0

                site_radii.append(radius)

            radii.append(site_radii)

        return radii

    @staticmethod
    def _get_sites_to_draw(
            struct_or_mol: Union[Structure, Molecule],
            graph: Union[StructureGraph, MoleculeGraph],
            draw_image_atoms=True,
            bonded_sites_outside_unit_cell=True,
    ):
        """
        Returns a list of (site index, image vector) tuples.

        Every site is drawn in the (0, 0, 0) image. For structures,
        ``draw_image_atoms`` adds periodic images of sites whose fractional
        coordinates lie within 0.05 of 0 or 1, and
        ``bonded_sites_outside_unit_cell`` adds (once each) sites in other
        images that are bonded to a site already in the list.
        """

        sites_to_draw = [(idx, (0, 0, 0)) for idx in range(len(struct_or_mol))]

        # trivial in this case
        if isinstance(struct_or_mol, Molecule):
            return sites_to_draw

        if draw_image_atoms:

            for idx, site in enumerate(struct_or_mol):

                zero_elements = [
                    axis for axis, f in enumerate(site.frac_coords)
                    if np.allclose(f, 0, atol=0.05)
                ]

                coord_permutations = [
                    x for size in range(1, len(zero_elements) + 1)
                    for x in combinations(zero_elements, size)
                ]

                for perm in coord_permutations:
                    sites_to_draw.append(
                        (idx, (int(0 in perm), int(1 in perm), int(2 in perm))))

                one_elements = [
                    axis for axis, f in enumerate(site.frac_coords)
                    if np.allclose(f, 1, atol=0.05)
                ]

                coord_permutations = [
                    x for size in range(1, len(one_elements) + 1)
                    for x in combinations(one_elements, size)
                ]

                for perm in coord_permutations:
                    sites_to_draw.append(
                        (idx, (-int(0 in perm), -int(1 in perm),
                               -int(2 in perm))))

        if bonded_sites_outside_unit_cell:

            # TODO: subtle bug here, see mp-5020, expansion logic not quite right
            already_listed = set(sites_to_draw)
            sites_to_append = []
            for (n, jimage) in sites_to_draw:
                for connected_site in graph.get_connected_sites(n, jimage=jimage):
                    key = (connected_site.index, connected_site.jimage)
                    # skip duplicates so the same atom is not drawn twice
                    if connected_site.jimage != (0, 0, 0) and key not in already_listed:
                        already_listed.add(key)
                        sites_to_append.append(key)
            sites_to_draw += sites_to_append

        return sites_to_draw

    @staticmethod
    def get_scene_and_legend(
            graph: Union[StructureGraph, MoleculeGraph],
            name="StructureMoleculeComponent",
            color_scheme="Jmol",
            color_scale=None,
            radius_strategy="specified_or_average_ionic",
            ellipsoid_site_prop=None,
            draw_image_atoms=True,
            bonded_sites_outside_unit_cell=True,
            hide_incomplete_bonds=False,
            explicitly_calculate_polyhedra_hull=False,
    ) -> Tuple[Scene, Dict[str, str]]:
        """
        Build the 3D scene and the color legend for ``graph``.

        The scene contains one sub-scene per primitive group ("atoms",
        "bonds", "polyhedra" and, for structures, "unit_cell"). Returns an
        empty scene and empty legend if ``graph`` is None. The input graph is
        not modified: display properties are added to a copy of its
        structure/molecule.
        """

        scene = Scene(name=name)

        if graph is None:
            return scene, {}

        # work on a copy so display_* site properties are not written back
        # into the caller's graph (where they would pin colors/radii and
        # short-circuit later recomputation)
        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(
            graph).copy()
        site_prop_types = StructureMoleculeComponent._analyze_site_props(
            struct_or_mol)

        radii = StructureMoleculeComponent._get_display_radii_for_sites(
            struct_or_mol, radius_strategy=radius_strategy)
        colors, legend = StructureMoleculeComponent._get_display_colors_and_legend_for_sites(
            struct_or_mol,
            site_prop_types,
            color_scale=color_scale,
            color_scheme=color_scheme,
        )

        struct_or_mol.add_site_property("display_radius", radii)
        struct_or_mol.add_site_property("display_color", colors)

        origin = StructureMoleculeComponent._get_origin(struct_or_mol)

        primitives = defaultdict(list)

        sites_to_draw = StructureMoleculeComponent._get_sites_to_draw(
            struct_or_mol,
            graph,
            draw_image_atoms=draw_image_atoms,
            bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
        )
        # set for O(1) membership tests in the loop below
        sites_being_drawn = set(sites_to_draw)

        for (idx, jimage) in sites_to_draw:

            site = struct_or_mol[idx]
            if jimage != (0, 0, 0):
                connected_sites = graph.get_connected_sites(idx, jimage=jimage)
                site = PeriodicSite(
                    site.species_and_occu,
                    np.add(site.frac_coords, jimage),
                    site.lattice,
                    properties=site.properties,
                )
            else:
                connected_sites = graph.get_connected_sites(idx)

            connected_sites_being_drawn = [
                cs for cs in connected_sites
                if (cs.index, cs.jimage) in sites_being_drawn
            ]
            all_connected_sites_present = (
                len(connected_sites) == len(connected_sites_being_drawn))

            if hide_incomplete_bonds:
                # only draw bonds if the destination site is also being drawn
                connected_sites = connected_sites_being_drawn

            site_primitives = StructureMoleculeComponent._primitives_from_site(
                site,
                connected_sites=connected_sites,
                all_connected_sites_present=all_connected_sites_present,
                origin=origin,
                ellipsoid_site_prop=ellipsoid_site_prop,
                explicitly_calculate_polyhedra_hull=
                explicitly_calculate_polyhedra_hull,
            )
            for k, v in site_primitives.items():
                primitives[k] += v

        # TODO: select polyhedra -- split by central atom type, split further
        # where polyhedra intersect (a centre atom forming a vertex of another
        # polyhedron; the centre may not actually lie inside its polyhedron),
        # then add non-intersecting sets in order

        if isinstance(struct_or_mol, Structure):
            primitives["unit_cell"].append(
                StructureMoleculeComponent._primitives_from_lattice(
                    struct_or_mol.lattice, origin=origin))

        sub_scenes = [Scene(name=k, contents=v) for k, v in primitives.items()]
        scene.contents = sub_scenes

        return scene, legend
class StructureMoleculeComponent(MPComponent):
    """
    A component to display pymatgen Structure, Molecule, StructureGraph
    and MoleculeGraph objects.
    """

    # maps class name -> class for every NearNeighbors subclass known when this
    # class body is executed; used to validate and instantiate bonding strategies
    available_bonding_strategies = {
        subclass.__name__: subclass for subclass in NearNeighbors.__subclasses__()
    }

    default_scene_settings = {}

    # whether to persist options such as atomic radii etc.
    persistence = False
    persistence_type = "local"

    def __init__(
        self,
        struct_or_mol: Optional[
            Union[Structure, StructureGraph, Molecule, MoleculeGraph]
        ] = None,
        id: str = None,
        scene_additions: Optional[Scene] = None,
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: Optional[dict] = None,
        color_scheme: str = DEFAULTS["color_scheme"],
        color_scale: Optional[str] = None,
        radius_strategy: str = DEFAULTS["radius_strategy"],
        unit_cell_choice: str = DEFAULTS["unit_cell_choice"],
        draw_image_atoms: bool = DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell: bool = DEFAULTS[
            "bonded_sites_outside_unit_cell"
        ],
        hide_incomplete_bonds: bool = DEFAULTS["hide_incomplete_bonds"],
        show_compass: bool = DEFAULTS["show_compass"],
        scene_settings: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Create a StructureMoleculeComponent from a structure or molecule.

        :param struct_or_mol: input structure or molecule
        :param id: canonical id
        :param scene_additions: extra geometric elements to add to the 3D scene
        :param bonding_strategy: bonding strategy from pymatgen NearNeighbors class
        :param bonding_strategy_kwargs: options for the bonding strategy
        :param color_scheme: color scheme, see Legend class
        :param color_scale: color scale, see Legend class
        :param radius_strategy: radius strategy, see Legend class
        :param unit_cell_choice: unit cell option ("input", "primitive",
            "conventional" or "reduced") stored in the graph generation options;
            it is applied when the graph is regenerated by the update_graph
            callback (the initial graph is built from the input cell)
        :param draw_image_atoms: whether to draw repeats of atoms on periodic images
        :param bonded_sites_outside_unit_cell: whether to draw sites bonded outside the unit cell
        :param hide_incomplete_bonds: whether to hide or show incomplete bonds
        :param show_compass: whether to hide or show the compass
        :param scene_settings: scene settings (lighting etc.) to pass to Simple3DScene
        :param kwargs: extra keyword arguments to pass to MPComponent
        """

        super().__init__(id=id, default_data=struct_or_mol, **kwargs)

        # what to show for the title_layout if structure/molecule not loaded
        self.default_title = "Crystal Toolkit"

        self.initial_scene_settings = self.default_scene_settings.copy()
        if scene_settings:
            self.initial_scene_settings.update(scene_settings)

        self.create_store("scene_settings", initial_data=self.initial_scene_settings)

        # unit cell choice and bonding algorithms need to come from a settings
        # object (in a dcc.Store) guaranteed to be present in layout, rather
        # than from the controls themselves -- since these are optional and
        # may not be present in the layout
        self.create_store(
            "graph_generation_options",
            initial_data={
                "bonding_strategy": bonding_strategy,
                "bonding_strategy_kwargs": bonding_strategy_kwargs,
                "unit_cell_choice": unit_cell_choice,
            },
        )

        self.create_store(
            "display_options",
            initial_data={
                "color_scheme": color_scheme,
                "color_scale": color_scale,
                "radius_strategy": radius_strategy,
                "draw_image_atoms": draw_image_atoms,
                "bonded_sites_outside_unit_cell": bonded_sites_outside_unit_cell,
                "hide_incomplete_bonds": hide_incomplete_bonds,
                "show_compass": show_compass,
            },
        )

        if scene_additions:
            initial_scene_additions = Scene(
                name="scene_additions", contents=scene_additions
            ).to_json()
        else:
            initial_scene_additions = None
        self.create_store("scene_additions", initial_data=initial_scene_additions)

        if struct_or_mol:
            # graph is cached explicitly, this isn't necessary but is an
            # optimization so that graph is only re-generated if bonding
            # algorithm changes
            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=bonding_strategy,
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )
            scene, legend = self.get_scene_and_legend(
                graph,
                name=self.id(),
                scene_additions=self.initial_data["scene_additions"],
                **self.initial_data["display_options"],
            )
            if hasattr(struct_or_mol, "lattice"):
                self._lattice = struct_or_mol.lattice
        else:
            # component could be initialized without a structure, in which case
            # an empty scene should be displayed
            graph = None
            scene, legend = self.get_scene_and_legend(
                None,
                name=self.id(),
                scene_additions=self.initial_data["scene_additions"],
                **self.initial_data["display_options"],
            )

        self.create_store("legend_data", initial_data=legend)
        self.create_store("graph", initial_data=graph)

        # this is used by a Simple3DScene component, not a dcc.Store
        self._initial_data["scene"] = scene

    def generate_callbacks(self, app, cache):
        """
        Register this component's client-side and server-side callbacks.

        :param app: the Dash app to register callbacks on
        :param cache: object providing a ``memoize()`` decorator, used to
            memoize the server-side callbacks
        """

        # a lot of the verbosity in this callback is to support custom bonding
        # this is not the format CutOffDictNN expects (since that is not JSON
        # serializable), so we store as a list of tuples instead
        # TODO: make CutOffDictNN args JSON serializable
        app.clientside_callback(
            """
            function (bonding_strategy, custom_cutoffs_rows, unit_cell_choice) {

                const bonding_strategy_kwargs = {}
                if (bonding_strategy === 'CutOffDictNN') {
                    const cut_off_dict = []
                    custom_cutoffs_rows.forEach(function(row) {
                        cut_off_dict.push([row['A'], row['B'], parseFloat(row['A—B'])])
                    })
                    bonding_strategy_kwargs.cut_off_dict = cut_off_dict
                }

                return {
                    bonding_strategy: bonding_strategy,
                    bonding_strategy_kwargs: bonding_strategy_kwargs,
                    unit_cell_choice: unit_cell_choice
                }
            }
            """,
            Output(self.id("graph_generation_options"), "data"),
            [
                Input(self.id("bonding_algorithm"), "value"),
                Input(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Input(self.id("unit-cell-choice"), "value"),
            ],
        )

        app.clientside_callback(
            """
            function (values, options) {
                const visibility = {}
                options.forEach(function (opt) {
                    visibility[opt.value] = Boolean(values.includes(opt.value))
                })
                return visibility
            }
            """,
            Output(self.id("scene"), "toggleVisibility"),
            [Input(self.id("hide-show"), "value")],
            [State(self.id("hide-show"), "options")],
        )

        app.clientside_callback(
            """
            function (colorScheme, radiusStrategy, drawOptions, displayOptions) {

                const newDisplayOptions = Object.assign({}, displayOptions);
                newDisplayOptions.color_scheme = colorScheme
                newDisplayOptions.radius_strategy = radiusStrategy
                newDisplayOptions.draw_image_atoms = drawOptions.includes('draw_image_atoms')
                newDisplayOptions.bonded_sites_outside_unit_cell = drawOptions.includes('bonded_sites_outside_unit_cell')
                newDisplayOptions.hide_incomplete_bonds = drawOptions.includes('hide_incomplete_bonds')

                return newDisplayOptions
            }
            """,
            Output(self.id("display_options"), "data"),
            [
                Input(self.id("color-scheme"), "value"),
                Input(self.id("radius_strategy"), "value"),
                Input(self.id("draw_options"), "value"),
            ],
            [State(self.id("display_options"), "data")],
        )

        @app.callback(
            Output(self.id("graph"), "data"),
            [
                Input(self.id("graph_generation_options"), "data"),
                Input(self.id(), "data"),
            ],
            [State(self.id("graph"), "data")],
        )
        @cache.memoize()
        def update_graph(graph_generation_options, struct_or_mol, current_graph):

            if not struct_or_mol:
                raise PreventUpdate

            struct_or_mol = self.from_data(struct_or_mol)
            current_graph = self.from_data(current_graph)

            bonding_strategy_kwargs = graph_generation_options[
                "bonding_strategy_kwargs"
            ]

            # TODO: add additional check here?
            unit_cell_choice = graph_generation_options["unit_cell_choice"]
            struct_or_mol = self._preprocess_structure(struct_or_mol, unit_cell_choice)

            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=graph_generation_options["bonding_strategy"],
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )

            # compare via _get_struct_or_mol: MoleculeGraph has no .structure,
            # and graphs of different kinds (structure vs molecule) are never
            # equal, so check the type before comparing contents
            if (
                current_graph
                and type(graph) is type(current_graph)
                and self._get_struct_or_mol(graph)
                == self._get_struct_or_mol(current_graph)
                and graph == current_graph
            ):
                raise PreventUpdate

            return graph

        @app.callback(
            [
                Output(self.id("scene"), "data"),
                Output(self.id("legend_data"), "data"),
                Output(self.id("color-scheme"), "options"),
            ],
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
                Input(self.id("scene_additions"), "data"),
            ],
        )
        @cache.memoize()
        def update_scene_and_legend_and_colors(graph, display_options, scene_additions):

            if not graph or not display_options:
                raise PreventUpdate

            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            scene, legend = self.get_scene_and_legend(
                graph,
                name=self.id(),
                **display_options,
                scene_additions=scene_additions,
            )

            color_options = [
                {"label": "Jmol", "value": "Jmol"},
                {"label": "VESTA", "value": "VESTA"},
                {"label": "Accessible", "value": "accessible"},
            ]
            struct_or_mol = self._get_struct_or_mol(graph)
            site_props = Legend(struct_or_mol).analyze_site_props(struct_or_mol)
            for site_prop_type in ("scalar", "categorical"):
                if site_prop_type in site_props:
                    for prop in site_props[site_prop_type]:
                        color_options += [
                            {"label": f"Site property: {prop}", "value": prop}
                        ]

            return scene, legend, color_options

        @app.callback(
            Output(self.id("scene"), "downloadRequest"),
            [Input(self.id("screenshot_button"), "n_clicks")],
            [State(self.id("scene"), "downloadRequest"), State(self.id(), "data")],
        )
        @cache.memoize()
        def trigger_screenshot(n_clicks, current_requests, struct_or_mol):
            if n_clicks is None:
                raise PreventUpdate
            struct_or_mol = self.from_data(struct_or_mol)
            # the store may hold a StructureGraph/MoleculeGraph, which have no
            # .composition, so unwrap to the underlying Structure/Molecule
            struct_or_mol = self._get_struct_or_mol(struct_or_mol)
            formula = struct_or_mol.composition.reduced_formula
            if hasattr(struct_or_mol, "get_space_group_info"):
                spgrp = struct_or_mol.get_space_group_info()[0]
            else:
                spgrp = ""
            request_filename = f"{formula}-{spgrp}-crystal-toolkit.png"
            if not current_requests:
                n_requests = 1
            else:
                n_requests = current_requests["n_requests"] + 1
            return {
                "n_requests": n_requests,
                "filename": request_filename,
                "filetype": "png",
            }

        @app.callback(
            [
                Output(self.id("legend_container"), "children"),
                Output(self.id("title_container"), "children"),
            ],
            [Input(self.id("legend_data"), "data")],
        )
        @cache.memoize()
        def update_legend_and_title(legend):
            if not legend:
                raise PreventUpdate
            legend = self.from_data(legend)
            return self._make_legend(legend), self._make_title(legend)

        @app.callback(
            [
                Output(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Output(self.id("bonding_algorithm_custom_cutoffs_container"), "style"),
            ],
            [Input(self.id("bonding_algorithm"), "value")],
            [
                State(self.id("graph"), "data"),
                State(self.id("bonding_algorithm_custom_cutoffs_container"), "style"),
            ],
        )
        @cache.memoize()
        def update_custom_bond_options(bonding_algorithm, graph, current_style):
            if not graph:
                raise PreventUpdate
            if bonding_algorithm == "CutOffDictNN":
                style = {}
            else:
                style = {"display": "none"}
            if style == current_style:
                # no need to update rows if we're not showing them
                raise PreventUpdate
            graph = self.from_data(graph)
            rows = self._make_bonding_algorithm_custom_cuffoff_data(graph)
            return rows, style

    def _make_legend(self, legend):
        """Build the colored legend buttons (one per species) from legend data."""

        if not legend:
            return html.Div(id=self.id("legend"))

        def get_font_color(hex_code):
            # ensures contrasting font color for background color
            c = tuple(int(hex_code[1:][i : i + 2], 16) for i in (0, 2, 4))
            if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5:
                font_color = "#000000"
            else:
                font_color = "#ffffff"
            return font_color

        try:
            formula = Composition.from_dict(legend["composition"]).reduced_formula
        except Exception:
            # TODO: fix legend for Dummy Specie compositions
            formula = "Unknown"

        # order legend entries by where each species name first appears in
        # the formula
        legend_colors = OrderedDict(
            sorted(list(legend["colors"].items()), key=lambda x: formula.find(x[1]))
        )

        legend_elements = [
            Button(
                html.Span(
                    name, className="icon", style={"color": get_font_color(color)}
                ),
                kind="static",
                style={"backgroundColor": color},
            )
            for color, name in legend_colors.items()
        ]

        return Field(
            [Control(el, style={"marginRight": "0.2rem"}) for el in legend_elements],
            id=self.id("legend"),
            grouped=True,
        )

    def _make_title(self, legend):
        """Build an H1 title from the legend's composition, with numeric parts subscripted."""

        if not legend or (not legend.get("composition", None)):
            return H1(self.default_title, id=self.id("title"))

        composition = legend["composition"]
        # fallback so the title is always defined, even for compositions that
        # are not dicts
        formula_components = [str(composition)]
        if isinstance(composition, dict):

            # TODO: make Composition handle DummySpecie for title
            try:
                formula = Composition.from_dict(composition).iupac_formula
                formula_parts = re.findall(r"[^\d_]+|\d+", formula)
                formula_components = [
                    html.Sub(part) if part.isnumeric() else html.Span(part)
                    for part in formula_parts
                ]
            except Exception:
                # use the raw (string) dict keys rather than parsed species
                formula_components = list(composition.keys())

        return H1(
            formula_components, id=self.id("title"), style={"display": "inline-block"}
        )

    @staticmethod
    def _make_bonding_algorithm_custom_cuffoff_data(graph):
        """
        Build rows for the custom bond cut-off table: one row per unordered
        pair of species (including self-pairs), with an initial cut-off of 0.
        """
        if not graph:
            return [{"A": None, "B": None, "A—B": None}]
        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)
        # can't use type_of_specie because it doesn't work with disordered structures
        # sorted so the table rows have a stable order between runs
        species = sorted(
            {str(sp) for occu in struct_or_mol.species_and_occu for sp in occu.keys()}
        )
        rows = [
            {"A": combination[0], "B": combination[1], "A—B": 0}
            for combination in combinations_with_replacement(species, 2)
        ]
        return rows

    @property
    def _sub_layouts(self):

        struct_layout = html.Div(
            Simple3DSceneComponent(
                id=self.id("scene"),
                data=self.initial_data["scene"],
                settings=self.initial_scene_settings,
            ),
            style={
                "width": "100%",
                "height": "100%",
                "overflow": "hidden",
                "margin": "0 auto",
            },
        )

        screenshot_layout = html.Div(
            [
                Button(
                    [Icon(), html.Span(), "Download Image"],
                    kind="primary",
                    id=self.id("screenshot_button"),
                )
            ],
            # TODO: change to "bottom" when dropdown included
            style={"verticalAlign": "top", "display": "inline-block"},
        )

        title_layout = html.Div(
            self._make_title(self._initial_data["legend_data"]),
            id=self.id("title_container"),
        )

        legend_layout = html.Div(
            self._make_legend(self._initial_data["legend_data"]),
            id=self.id("legend_container"),
        )

        nn_mapping = {
            "CrystalNN": "CrystalNN",
            "Custom Bonds": "CutOffDictNN",
            "Jmol Bonding": "JmolNN",
            "Minimum Distance (10% tolerance)": "MinimumDistanceNN",
            "O'Keeffe's Algorithm": "MinimumOKeeffeNN",
            "Hoppe's ECoN Algorithm": "EconNN",
            "Brunner's Reciprocal Algorithm": "BrunnerNN_reciprocal",
        }

        bonding_algorithm = dcc.Dropdown(
            options=[{"label": k, "value": v} for k, v in nn_mapping.items()],
            value=self.initial_data["graph_generation_options"]["bonding_strategy"],
            clearable=False,
            id=self.id("bonding_algorithm"),
            persistence=self.persistence,
            persistence_type=self.persistence_type,
        )

        bonding_algorithm_custom_cutoffs = html.Div(
            [
                html.Br(),
                dt.DataTable(
                    columns=[
                        {"name": "A", "id": "A"},
                        {"name": "B", "id": "B"},
                        {"name": "A—B /Å", "id": "A—B"},
                    ],
                    editable=True,
                    data=self._make_bonding_algorithm_custom_cuffoff_data(
                        self.initial_data.get("default")
                    ),
                    id=self.id("bonding_algorithm_custom_cutoffs"),
                ),
                html.Br(),
            ],
            id=self.id("bonding_algorithm_custom_cutoffs_container"),
            style={"display": "none"},
        )

        options_layout = Field(
            [
                # TODO: hide if molecule
                html.Label("Change unit cell:", className="mpc-label"),
                html.Div(
                    dcc.Dropdown(
                        options=[
                            {"label": "Input cell", "value": "input"},
                            {"label": "Primitive cell", "value": "primitive"},
                            {"label": "Conventional cell", "value": "conventional"},
                            {"label": "Reduced cell", "value": "reduced"},
                        ],
                        value="input",
                        clearable=False,
                        id=self.id("unit-cell-choice"),
                        persistence=self.persistence,
                        persistence_type=self.persistence_type,
                    ),
                    className="mpc-control",
                ),
                html.Div(
                    [
                        html.Label("Change bonding algorithm: ", className="mpc-label"),
                        bonding_algorithm,
                        bonding_algorithm_custom_cutoffs,
                    ]
                ),
                html.Label("Change color scheme:", className="mpc-label"),
                html.Div(
                    dcc.Dropdown(
                        options=[
                            {"label": "VESTA", "value": "VESTA"},
                            {"label": "Jmol", "value": "Jmol"},
                            {"label": "Accessible", "value": "accessible"},
                        ],
                        value=self.initial_data["display_options"]["color_scheme"],
                        clearable=False,
                        persistence=self.persistence,
                        persistence_type=self.persistence_type,
                        id=self.id("color-scheme"),
                    ),
                    className="mpc-control",
                ),
                html.Label("Change atomic radii:", className="mpc-label"),
                html.Div(
                    dcc.Dropdown(
                        options=[
                            {"label": "Ionic", "value": "specified_or_average_ionic"},
                            {"label": "Covalent", "value": "covalent"},
                            {"label": "Van der Waals", "value": "van_der_waals"},
                            {
                                "label": f"Uniform ({Legend.uniform_radius}Å)",
                                "value": "uniform",
                            },
                        ],
                        value=self.initial_data["display_options"]["radius_strategy"],
                        clearable=False,
                        persistence=self.persistence,
                        persistence_type=self.persistence_type,
                        id=self.id("radius_strategy"),
                    ),
                    className="mpc-control",
                ),
                html.Label("Draw options:", className="mpc-label"),
                html.Div(
                    [
                        dcc.Checklist(
                            options=[
                                {
                                    "label": "Draw repeats of atoms on periodic boundaries",
                                    "value": "draw_image_atoms",
                                },
                                {
                                    "label": "Draw atoms outside unit cell bonded to "
                                    "atoms within unit cell",
                                    "value": "bonded_sites_outside_unit_cell",
                                },
                                {
                                    "label": "Hide bonds where destination atoms are not shown",
                                    "value": "hide_incomplete_bonds",
                                },
                            ],
                            value=[
                                opt
                                for opt in (
                                    "draw_image_atoms",
                                    "bonded_sites_outside_unit_cell",
                                    "hide_incomplete_bonds",
                                )
                                if self.initial_data["display_options"][opt]
                            ],
                            labelStyle={"display": "block"},
                            inputClassName="mpc-radio",
                            id=self.id("draw_options"),
                            persistence=self.persistence,
                            persistence_type=self.persistence_type,
                        )
                    ]
                ),
                html.Label("Hide/show:", className="mpc-label"),
                html.Div(
                    [
                        dcc.Checklist(
                            options=[
                                {"label": "Atoms", "value": "atoms"},
                                {"label": "Bonds", "value": "bonds"},
                                {"label": "Unit cell", "value": "unit_cell"},
                                {"label": "Polyhedra", "value": "polyhedra"},
                                {"label": "Axes", "value": "axes"},
                            ],
                            value=["atoms", "bonds", "unit_cell", "polyhedra"],
                            labelStyle={"display": "block"},
                            inputClassName="mpc-radio",
                            id=self.id("hide-show"),
                            persistence=self.persistence,
                            persistence_type=self.persistence_type,
                        )
                    ],
                    className="mpc-control",
                ),
            ]
        )

        return {
            "struct": struct_layout,
            "screenshot": screenshot_layout,
            "options": options_layout,
            "title": title_layout,
            "legend": legend_layout,
        }

    def layout(self, size: str = "400px") -> html.Div:
        """
        :param size: a CSS string specifying width/height of Div
        :return: A html.Div containing the 3D structure or molecule
        """
        return html.Div(
            self._sub_layouts["struct"], style={"width": size, "height": size}
        )

    @staticmethod
    def _preprocess_structure(struct_or_mol, unit_cell_choice: str = "input"):
        """
        Apply the requested unit cell transformation to a Structure.

        :param struct_or_mol: any supported input; only Structure instances are
            transformed, everything else is returned unchanged
        :param unit_cell_choice: "input", "primitive", "conventional" or
            "reduced"; unrecognized values leave the structure unchanged
        :return: the (possibly transformed) input
        """
        if isinstance(struct_or_mol, Structure):
            if unit_cell_choice == "primitive":
                return struct_or_mol.get_primitive_structure()
            if unit_cell_choice == "conventional":
                sga = SpacegroupAnalyzer(struct_or_mol)
                return sga.get_conventional_standard_structure()
            if unit_cell_choice == "reduced":
                return struct_or_mol.get_reduced_structure()
        return struct_or_mol

    @staticmethod
    def _preprocess_input_to_graph(
        input: Union[Structure, StructureGraph, Molecule, MoleculeGraph],
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: Optional[Dict] = None,
    ) -> Union[StructureGraph, MoleculeGraph]:
        """
        Convert the input into a StructureGraph/MoleculeGraph.

        Graph inputs are returned as-is. Structures have fractional coordinates
        wrapped into [0, 1). Disordered structures always use CutOffDictNN.
        If bond calculation fails, a graph with no bonds is returned.

        :param input: structure, molecule or an existing graph
        :param bonding_strategy: name of a NearNeighbors subclass
        :param bonding_strategy_kwargs: keyword arguments for that class; for
            CutOffDictNN, "cut_off_dict" is expected as a list of
            [species_a, species_b, distance] entries. Not mutated.
        :raises ValueError: if bonding_strategy is not a known NearNeighbors subclass
        """

        if isinstance(input, Structure):

            # ensure fractional co-ordinates are normalized to be in [0,1)
            # (this is actually not guaranteed by Structure)
            try:
                input = input.as_dict(verbosity=0)
            except TypeError:
                # TODO: remove this, necessary for Slab(?), some structure subclasses don't have verbosity
                input = input.as_dict()
            for site in input["sites"]:
                site["abc"] = np.mod(site["abc"], 1)
            input = Structure.from_dict(input)

            if not input.is_ordered:
                # calculating bonds in disordered structures is currently very flaky
                if bonding_strategy != "CutOffDictNN":
                    # kwargs were intended for a different strategy and would
                    # not be accepted by the CutOffDictNN constructor
                    bonding_strategy_kwargs = None
                bonding_strategy = "CutOffDictNN"

        # we assume most uses of this class will give a structure as an input argument,
        # meaning we have to calculate the graph for bonding information, however if
        # the graph is already known and supplied, we will use that
        if isinstance(input, (StructureGraph, MoleculeGraph)):
            return input

        available = StructureMoleculeComponent.available_bonding_strategies
        if bonding_strategy not in available:
            raise ValueError(
                "Bonding strategy not supported. Please supply a name "
                "of a NearNeighbor subclass, choose from: {}".format(
                    ", ".join(available.keys())
                )
            )

        # copy so the caller's dict (which may also be held in a dcc.Store as
        # JSON-serializable data) is never mutated by the conversion below
        bonding_strategy_kwargs = dict(bonding_strategy_kwargs or {})
        if (
            bonding_strategy == "CutOffDictNN"
            and "cut_off_dict" in bonding_strategy_kwargs
        ):
            # TODO: remove this hack by making args properly JSON serializable
            bonding_strategy_kwargs["cut_off_dict"] = {
                (x[0], x[1]): x[2] for x in bonding_strategy_kwargs["cut_off_dict"]
            }
        nn_strategy = available[bonding_strategy](**bonding_strategy_kwargs)

        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if isinstance(input, Structure):
                    graph = StructureGraph.with_local_env_strategy(input, nn_strategy)
                else:
                    graph = MoleculeGraph.with_local_env_strategy(input, nn_strategy)
        except Exception:
            # for some reason computing bonds failed, so let's not have any bonds(!)
            if isinstance(input, Structure):
                graph = StructureGraph.with_empty_graph(input)
            else:
                graph = MoleculeGraph.with_empty_graph(input)

        return graph

    @staticmethod
    def _get_struct_or_mol(
        graph: Union[StructureGraph, MoleculeGraph, Structure, Molecule]
    ) -> Union[Structure, Molecule]:
        """Return the underlying Structure/Molecule of a graph (or the input itself)."""
        if isinstance(graph, StructureGraph):
            return graph.structure
        if isinstance(graph, MoleculeGraph):
            return graph.molecule
        if isinstance(graph, (Structure, Molecule)):
            return graph
        raise ValueError(
            "Expected a Structure, Molecule, StructureGraph or MoleculeGraph, "
            f"got {type(graph).__name__}"
        )

    @staticmethod
    def get_scene_and_legend(
        graph: Optional[Union[StructureGraph, MoleculeGraph]],
        name,
        color_scheme=DEFAULTS["color_scheme"],
        color_scale=None,
        radius_strategy=DEFAULTS["radius_strategy"],
        draw_image_atoms=DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell=DEFAULTS["bonded_sites_outside_unit_cell"],
        hide_incomplete_bonds=DEFAULTS["hide_incomplete_bonds"],
        explicitly_calculate_polyhedra_hull=False,
        scene_additions=None,
        show_compass=DEFAULTS["show_compass"],
    ) -> Tuple[Scene, Dict[str, str]]:
        """
        Build the 3D scene and legend data for a graph.

        :return: (scene serialized with Scene.to_json(), legend dict). For
            graph=None an empty scene and an empty legend dict are returned.
        :raises NotImplementedError: if scene_additions is given (not yet supported)
        """

        # default scene name will be name of component, "_ct_..."
        # strip leading _ since this will cause problems in JavaScript land
        # NOTE(review): the non-empty path below overwrites this with the
        # unstripped name -- confirm which is intended
        scene = Scene(name=name[1:])

        if graph is None:
            # serialize like the non-empty path so callers always get JSON
            return scene.to_json(), {}

        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)

        # TODO: add radius_scale
        legend = Legend(
            struct_or_mol,
            color_scheme=color_scheme,
            radius_scheme=radius_strategy,
            cmap_range=color_scale,
        )

        if isinstance(graph, StructureGraph):
            scene = graph.get_scene(
                draw_image_atoms=draw_image_atoms,
                bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
                hide_incomplete_edges=hide_incomplete_bonds,
                explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
                legend=legend,
            )
        elif isinstance(graph, MoleculeGraph):
            scene = graph.get_scene(legend=legend)

        scene.name = name

        if hasattr(struct_or_mol, "lattice"):
            axes = struct_or_mol.lattice._axes_from_lattice()
            # TODO: fix pop-in ?
            axes.visible = show_compass
            scene.contents.append(axes)

        if scene_additions:
            # TODO: need a Scene.from_json() to make this work, then append
            # scene_additions to the scene contents
            raise NotImplementedError

        return scene.to_json(), legend.get_legend()

    def screenshot_layout(self):
        """
        :return: A layout including a button to trigger a screenshot download.
        """
        return self._sub_layouts["screenshot"]

    def options_layout(self):
        """
        :return: A layout including options to change the appearance, bonding, etc.
        """
        return self._sub_layouts["options"]

    def title_layout(self):
        """
        :return: A layout including the composition of the structure/molecule as a title.
        """
        return self._sub_layouts["title"]

    def legend_layout(self):
        """
        :return: A layout including a legend for the structure/molecule.
        """
        return self._sub_layouts["legend"]
app = dash.Dash('') app.title = "MP Viewer" app.scripts.config.serve_locally = True app.css.append_css( {'external_url': 'https://codepen.io/chriddyp/pen/bWLwgP.css'}) mpr = MPRester() DEFAULT_STRUCTURE = loadfn('default_structure.json') DEFAULT_COLOR_SCHEME = 'VESTA' DEFAULT_BONDING_METHOD = 'MinimumOKeeffeNN' AVAILABLE_BONDING_METHODS = [ str(c.__name__) for c in NearNeighbors.__subclasses__() ] # to help with readability, each component of the app is defined # in a modular way below; these are all then included in the app.layout LAYOUT_FORMULA_INPUT = html.Div([ dcc.Input(id='input-box', type='text', placeholder='Enter a formula or mp-id'), html.Span(' '), html.Button('Load', id='button') ]) LAYOUT_VISIBILITY_OPTIONS = html.Div([ dcc.Checklist(id='visibility_options',
class StructureMoleculeComponent(MPComponent): """ A component to display pymatgen Structure, Molecule, StructureGraph and MoleculeGraph objects. """ available_bonding_strategies = { subclass.__name__: subclass for subclass in NearNeighbors.__subclasses__() } default_scene_settings = { "extractAxis": True, # For visual diff testing, we change the renderer # to SVG since this WebGL support is more difficult # in headless browsers / CI. "renderer": "svg" if SETTINGS.TEST_MODE else "webgl", "secondaryObjectView": False, } # what to show for the title_layout if structure/molecule not loaded default_title = "Crystal Toolkit" # human-readable label to file extension # downloading Molecules has not yet been added download_options = { "Structure": { "CIF (Symmetrized)": { "fmt": "cif", "symprec": EmmetSettings().SYMPREC }, "CIF": { "fmt": "cif" }, "POSCAR": { "fmt": "poscar" }, "JSON": { "fmt": "json" }, "Prismatic": { "fmt": "prismatic" }, "VASP Input Set (MPRelaxSet)": {}, # special } } def __init__( self, struct_or_mol: Optional[Union[Structure, StructureGraph, Molecule, MoleculeGraph]] = None, id: str = None, className: str = "box", scene_additions: Optional[Scene] = None, bonding_strategy: str = DEFAULTS["bonding_strategy"], bonding_strategy_kwargs: Optional[dict] = None, color_scheme: str = DEFAULTS["color_scheme"], color_scale: Optional[str] = None, radius_strategy: str = DEFAULTS["radius_strategy"], unit_cell_choice: str = DEFAULTS["unit_cell_choice"], draw_image_atoms: bool = DEFAULTS["draw_image_atoms"], bonded_sites_outside_unit_cell: bool = DEFAULTS[ "bonded_sites_outside_unit_cell"], hide_incomplete_bonds: bool = DEFAULTS["hide_incomplete_bonds"], show_compass: bool = DEFAULTS["show_compass"], scene_settings: Optional[Dict] = None, group_by_site_property: Optional[str] = None, show_legend: bool = DEFAULTS["show_legend"], show_settings: bool = DEFAULTS["show_settings"], show_controls: bool = DEFAULTS["show_controls"], show_expand_button: bool = 
DEFAULTS["show_expand_button"], show_image_button: bool = DEFAULTS["show_image_button"], show_export_button: bool = DEFAULTS["show_export_button"], show_position_button: bool = DEFAULTS["show_position_button"], **kwargs, ): """ Create a StructureMoleculeComponent from a structure or molecule. :param struct_or_mol: input structure or molecule :param id: canonical id :param scene_additions: extra geometric elements to add to the 3D scene :param bonding_strategy: bonding strategy from pymatgen NearNeighbors class :param bonding_strategy_kwargs: options for the bonding strategy :param color_scheme: color scheme, see Legend class :param color_scale: color scale, see Legend class :param radius_strategy: radius strategy, see Legend class :param draw_image_atoms: whether to draw repeats of atoms on periodic images :param bonded_sites_outside_unit_cell: whether to draw sites bonded outside the unit cell :param hide_incomplete_bonds: whether to hide or show incomplete bonds :param show_compass: whether to hide or show the compass :param scene_settings: scene settings (lighting etc.) to pass to CrystalToolkitScene :param group_by_site_property: a site property used for grouping of atoms for mouseover/interaction, :param show_legend: show or hide legend panel within the scene :param show_controls: show or hide scene control bar :param show_expand_button: show or hide the full screen button within the scene control bar :param show_image_button: show or hide the image download button within the scene control bar :param show_export_button: show or hide the file export button within the scene control bar :param show_position_button: show or hide the revert position button within the scene control bar e.g. 
Wyckoff label :param kwargs: extra keyword arguments to pass to MPComponent """ super().__init__(id=id, default_data=struct_or_mol, **kwargs) self.className = className self.show_legend = show_legend self.show_settings = show_settings self.show_controls = show_controls self.show_expand_button = show_expand_button self.show_image_button = show_image_button self.show_export_button = show_export_button self.show_position_button = show_position_button self.initial_scene_settings = self.default_scene_settings.copy() if scene_settings: self.initial_scene_settings.update(scene_settings) self.create_store("scene_settings", initial_data=self.initial_scene_settings) # unit cell choice and bonding algorithms need to come from a settings # object (in a dcc.Store) guaranteed to be present in layout, rather # than from the controls themselves -- since these are optional and # may not be present in the layout self.create_store( "graph_generation_options", initial_data={ "bonding_strategy": bonding_strategy, "bonding_strategy_kwargs": bonding_strategy_kwargs, "unit_cell_choice": unit_cell_choice, }, ) self.create_store( "display_options", initial_data={ "color_scheme": color_scheme, "color_scale": color_scale, "radius_strategy": radius_strategy, "draw_image_atoms": draw_image_atoms, "bonded_sites_outside_unit_cell": bonded_sites_outside_unit_cell, "hide_incomplete_bonds": hide_incomplete_bonds, "show_compass": show_compass, "group_by_site_property": group_by_site_property, }, ) if scene_additions: initial_scene_additions = Scene( name="scene_additions", contents=scene_additions).to_json() else: initial_scene_additions = None self.create_store("scene_additions", initial_data=initial_scene_additions) if struct_or_mol: # graph is cached explicitly, this isn't necessary but is an # optimization so that graph is only re-generated if bonding # algorithm changes struct_or_mol = self._preprocess_structure( struct_or_mol, unit_cell_choice=unit_cell_choice) graph = 
self._preprocess_input_to_graph( struct_or_mol, bonding_strategy=bonding_strategy, bonding_strategy_kwargs=bonding_strategy_kwargs, ) scene, legend = self.get_scene_and_legend( graph, scene_additions=self.initial_data["scene_additions"], **self.initial_data["display_options"], ) if hasattr(struct_or_mol, "lattice"): self._lattice = struct_or_mol.lattice else: # component could be initialized without a structure, in which case # an empty scene should be displayed graph = None scene, legend = self.get_scene_and_legend( None, scene_additions=self.initial_data["scene_additions"], **self.initial_data["display_options"], ) self.create_store("legend_data", initial_data=legend) self.create_store("graph", initial_data=graph) # this is used by a CrystalToolkitScene component, not a dcc.Store self._initial_data["scene"] = scene # hide axes inset for molecules if isinstance(struct_or_mol, Molecule) or isinstance( struct_or_mol, MoleculeGraph): self.scene_kwargs = {"axisView": "HIDDEN"} else: self.scene_kwargs = {} def generate_callbacks(self, app, cache): # a lot of the verbosity in this callback is to support custom bonding # this is not the format CutOffDictNN expects (since that is not JSON # serializable), so we store as a list of tuples instead # TODO: make CutOffDictNN args JSON serializable app.clientside_callback( """ function (bonding_strategy, custom_cutoffs_rows, unit_cell_choice) { const bonding_strategy_kwargs = {} if (bonding_strategy === 'CutOffDictNN') { const cut_off_dict = [] custom_cutoffs_rows.forEach(function(row) { cut_off_dict.push([row['A'], row['B'], parseFloat(row['A—B'])]) }) bonding_strategy_kwargs.cut_off_dict = cut_off_dict } return { bonding_strategy: bonding_strategy, bonding_strategy_kwargs: bonding_strategy_kwargs, unit_cell_choice: unit_cell_choice } } """, Output(self.id("graph_generation_options"), "data"), [ Input(self.id("bonding_algorithm"), "value"), Input(self.id("bonding_algorithm_custom_cutoffs"), "data"), 
Input(self.id("unit-cell-choice"), "value"), ], ) app.clientside_callback( """ function (values, options) { const visibility = {} options.forEach(function (opt) { visibility[opt.value] = Boolean(values.includes(opt.value)) }) return visibility } """, Output(self.id("scene"), "toggleVisibility"), [Input(self.id("hide-show"), "value")], [State(self.id("hide-show"), "options")], ) app.clientside_callback( """ function (colorScheme, radiusStrategy, drawOptions, displayOptions) { const newDisplayOptions = Object.assign({}, displayOptions); newDisplayOptions.color_scheme = colorScheme newDisplayOptions.radius_strategy = radiusStrategy newDisplayOptions.draw_image_atoms = drawOptions.includes('draw_image_atoms') newDisplayOptions.bonded_sites_outside_unit_cell = drawOptions.includes('bonded_sites_outside_unit_cell') newDisplayOptions.hide_incomplete_bonds = drawOptions.includes('hide_incomplete_bonds') return newDisplayOptions } """, Output(self.id("display_options"), "data"), [ Input(self.id("color-scheme"), "value"), Input(self.id("radius_strategy"), "value"), Input(self.id("draw_options"), "value"), ], [State(self.id("display_options"), "data")], ) @app.callback( Output(self.id("graph"), "data"), [ Input(self.id("graph_generation_options"), "data"), Input(self.id(), "data"), ], [State(self.id("graph"), "data")], ) @cache.memoize() def update_graph(graph_generation_options, struct_or_mol, current_graph): if not struct_or_mol: raise PreventUpdate struct_or_mol = self.from_data(struct_or_mol) current_graph = self.from_data(current_graph) bonding_strategy_kwargs = graph_generation_options[ "bonding_strategy_kwargs"] # TODO: add additional check here? 
unit_cell_choice = graph_generation_options["unit_cell_choice"] struct_or_mol = self._preprocess_structure(struct_or_mol, unit_cell_choice) graph = self._preprocess_input_to_graph( struct_or_mol, bonding_strategy=graph_generation_options["bonding_strategy"], bonding_strategy_kwargs=bonding_strategy_kwargs, ) if (current_graph and graph.structure == current_graph.structure and graph == current_graph): raise PreventUpdate return graph @app.callback( Output(self.id("scene"), "data"), [ Input(self.id("graph"), "data"), Input(self.id("display_options"), "data"), Input(self.id("scene_additions"), "data"), ], ) @cache.memoize() def update_scene(graph, display_options, scene_additions): if not graph or not display_options: raise PreventUpdate display_options = self.from_data(display_options) graph = self.from_data(graph) scene, legend = self.get_scene_and_legend( graph, **display_options, scene_additions=scene_additions) return scene @app.callback( Output(self.id("legend_data"), "data"), [ Input(self.id("graph"), "data"), Input(self.id("display_options"), "data"), Input(self.id("scene_additions"), "data"), ], ) @cache.memoize() def update_legend_and_colors(graph, display_options, scene_additions): if not graph or not display_options: raise PreventUpdate display_options = self.from_data(display_options) graph = self.from_data(graph) scene, legend = self.get_scene_and_legend( graph, **display_options, scene_additions=scene_additions) return legend @app.callback( Output(self.id("color-scheme"), "options"), [Input(self.id("legend_data"), "data")], ) def update_color_options(legend_data): # TODO: make client-side color_options = [ { "label": "Jmol", "value": "Jmol" }, { "label": "VESTA", "value": "VESTA" }, { "label": "Accessible", "value": "accessible" }, ] if not legend_data: return color_options for option in legend_data["available_color_schemes"]: color_options += [{ "label": f"Site property: {option}", "value": option }] return color_options # app.clientside_callback( # 
""" # function (legendData) { # # var colorOptions = [ # {label: "Jmol", value: "Jmol"}, # {label: "VESTA", value: "VESTA"}, # {label: "Accessible", value: "accessible"}, # ] # # # # return colorOptions # } # """, # Output(self.id("color-scheme"), "options"), # [Input(self.id("legend_data"), "data")] # ) @app.callback( Output(self.id("download-image"), "data"), Input(self.id("scene"), "imageDataTimestamp"), [ State(self.id("scene"), "imageData"), State(self.id(), "data"), ], ) def download_image(image_data_timestamp, image_data, data): if not image_data_timestamp: raise PreventUpdate struct_or_mol = self.from_data(data) if isinstance(struct_or_mol, StructureGraph): formula = struct_or_mol.structure.composition.reduced_formula elif isinstance(struct_or_mol, MoleculeGraph): formula = struct_or_mol.molecule.composition.reduced_Formula else: formula = struct_or_mol.composition.reduced_formula if hasattr(struct_or_mol, "get_space_group_info"): spgrp = struct_or_mol.get_space_group_info()[0] else: spgrp = "" request_filename = "{}-{}-crystal-toolkit.png".format( formula, spgrp) return { "content": image_data[len("data:image/png;base64,"):], "filename": request_filename, "base64": True, "type": "image/png", } @app.callback( Output(self.id("download-structure"), "data"), Input(self.id("scene"), "fileTimestamp"), [ State(self.id("scene"), "fileType"), State(self.id(), "data"), ], ) def download_structure(file_timestamp, download_option, data): if not file_timestamp: raise PreventUpdate structure = self.from_data(data) if isinstance(structure, StructureGraph): structure = structure.structure file_prefix = structure.composition.reduced_formula if "VASP" not in download_option: extension = self.download_options["Structure"][ download_option]["fmt"] options = self.download_options["Structure"][download_option] try: contents = structure.to(**options) except Exception as exc: # don't fail silently, tell user what went wrong contents = exc base64 = 
b64encode(contents.encode("utf-8")).decode("ascii") download_data = { "content": base64, "base64": True, "type": "text/plain", "filename": f"{file_prefix}.{extension}", } else: if "Relax" in download_option: vis = MPRelaxSet(structure) expected_filename = "MPRelaxSet.zip" else: raise ValueError( "No other VASP input sets currently supported.") with TemporaryDirectory() as tmpdir: vis.write_input(tmpdir, potcar_spec=True, zip_output=True) path = Path(tmpdir) / expected_filename bytes = b64encode(path.read_bytes()).decode("ascii") download_data = { "content": bytes, "base64": True, "type": "application/zip", "filename": f"{file_prefix} {expected_filename}", } return download_data @app.callback( Output(self.id("title_container"), "children"), [Input(self.id("legend_data"), "data")], ) @cache.memoize() def update_title(legend): if not legend: raise PreventUpdate legend = self.from_data(legend) return self._make_title(legend) @app.callback( Output(self.id("legend_container"), "children"), [Input(self.id("legend_data"), "data")], ) @cache.memoize() def update_legend(legend): if not legend: raise PreventUpdate legend = self.from_data(legend) return self._make_legend(legend) @app.callback( [ Output(self.id("bonding_algorithm_custom_cutoffs"), "data"), Output(self.id("bonding_algorithm_custom_cutoffs_container"), "style"), ], [Input(self.id("bonding_algorithm"), "value")], [ State(self.id("graph"), "data"), State(self.id("bonding_algorithm_custom_cutoffs_container"), "style"), ], ) @cache.memoize() def update_custom_bond_options(bonding_algorithm, graph, current_style): if not graph: raise PreventUpdate if bonding_algorithm == "CutOffDictNN": style = {} else: style = {"display": "none"} if style == current_style: # no need to update rows if we're not showing them raise PreventUpdate graph = self.from_data(graph) rows = self._make_bonding_algorithm_custom_cuffoff_data(graph) return rows, style def _make_legend(self, legend): if not legend: return 
html.Div(id=self.id("legend")) def get_font_color(hex_code): # ensures contrasting font color for background color c = tuple(int(hex_code[1:][i:i + 2], 16) for i in (0, 2, 4)) if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5: font_color = "#000000" else: font_color = "#ffffff" return font_color try: formula = Composition.from_dict( legend["composition"]).reduced_formula except: # TODO: fix legend for Dummy Specie compositions formula = "Unknown" legend_colors = OrderedDict( sorted(list(legend["colors"].items()), key=lambda x: formula.find(x[1]))) legend_elements = [ html.Span( html.Span(name, className="icon", style={"color": get_font_color(color)}), className="button is-static is-rounded", style={"backgroundColor": color}, ) for color, name in legend_colors.items() ] return html.Div(legend_elements, id=self.id("legend"), style={"display": "flex"}, className="buttons") def _make_title(self, legend): if not legend or (not legend.get("composition", None)): return H2(self.default_title, id=self.id("title")) composition = legend["composition"] if isinstance(composition, dict): try: composition = Composition.from_dict(composition) # strip DummySpecie if present (TODO: should be method in pymatgen) composition = Composition({ el: amt for el, amt in composition.items() if not isinstance(el, DummySpecie) }) composition = composition.get_reduced_composition_and_factor( )[0] formula = composition.reduced_formula formula_parts = re.findall(r"[^\d_]+|\d+", formula) formula_components = [ html.Sub(part.strip()) if part.isnumeric() else html.Span(part.strip()) for part in formula_parts ] except: formula_components = list(map(str, composition.keys())) return H2(formula_components, id=self.id("title"), style={"display": "inline-block"}) @staticmethod def _make_bonding_algorithm_custom_cuffoff_data(graph): if not graph: return [{"A": None, "B": None, "A—B": None}] struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph) # can't use type_of_specie because 
it doesn't work with disordered structures species = set( map( str, chain.from_iterable( [list(c.keys()) for c in struct_or_mol.species_and_occu]), )) rows = [{ "A": combination[0], "B": combination[1], "A—B": 0 } for combination in combinations_with_replacement(species, 2)] return rows @property def _sub_layouts(self): title_layout = html.Div( self._make_title(self._initial_data["legend_data"]), id=self.id("title_container"), ) nn_mapping = { "CrystalNN": "CrystalNN", "Custom Bonds": "CutOffDictNN", "Jmol Bonding": "JmolNN", "Minimum Distance (10% tolerance)": "MinimumDistanceNN", "O'Keeffe's Algorithm": "MinimumOKeeffeNN", "Hoppe's ECoN Algorithm": "EconNN", "Brunner's Reciprocal Algorithm": "BrunnerNN_reciprocal", } bonding_algorithm = dcc.Dropdown( options=[{ "label": k, "value": v } for k, v in nn_mapping.items()], value=self.initial_data["graph_generation_options"] ["bonding_strategy"], clearable=False, id=self.id("bonding_algorithm"), persistence=SETTINGS.PERSISTENCE, persistence_type=SETTINGS.PERSISTENCE_TYPE, ) bonding_algorithm_custom_cutoffs = html.Div( [ html.Br(), dt.DataTable( columns=[ { "name": "A", "id": "A" }, { "name": "B", "id": "B" }, { "name": "A—B /Å", "id": "A—B" }, ], editable=True, data=self._make_bonding_algorithm_custom_cuffoff_data( self.initial_data.get("default")), id=self.id("bonding_algorithm_custom_cutoffs"), ), html.Br(), ], id=self.id("bonding_algorithm_custom_cutoffs_container"), style={"display": "none"}, ) if self.show_settings: options_layout = Field([ # TODO: hide if molecule html.Label("Change unit cell:", className="mpc-label"), html.Div( dcc.Dropdown( options=[ { "label": "Input cell", "value": "input" }, { "label": "Primitive cell", "value": "primitive" }, { "label": "Conventional cell", "value": "conventional" }, { "label": "Reduced cell (Niggli)", "value": "reduced_niggli", }, { "label": "Reduced cell (LLL)", "value": "reduced_lll" }, ], value="input", clearable=False, id=self.id("unit-cell-choice"), 
persistence=SETTINGS.PERSISTENCE, persistence_type=SETTINGS.PERSISTENCE_TYPE, ), className="mpc-control", ), html.Div([ html.Label("Change bonding algorithm: ", className="mpc-label"), bonding_algorithm, bonding_algorithm_custom_cutoffs, ]), html.Label("Change color scheme:", className="mpc-label"), html.Div( dcc.Dropdown( options=[ { "label": "VESTA", "value": "VESTA" }, { "label": "Jmol", "value": "Jmol" }, { "label": "Accessible", "value": "accessible" }, ], value=self.initial_data["display_options"] ["color_scheme"], clearable=False, persistence=SETTINGS.PERSISTENCE, persistence_type=SETTINGS.PERSISTENCE_TYPE, id=self.id("color-scheme"), ), className="mpc-control", ), html.Label("Change atomic radii:", className="mpc-label"), html.Div( dcc.Dropdown( options=[ { "label": "Ionic", "value": "specified_or_average_ionic" }, { "label": "Covalent", "value": "covalent" }, { "label": "Van der Waals", "value": "van_der_waals" }, { "label": f"Uniform ({Legend.uniform_radius}Å)", "value": "uniform", }, ], value=self.initial_data["display_options"] ["radius_strategy"], clearable=False, persistence=SETTINGS.PERSISTENCE, persistence_type=SETTINGS.PERSISTENCE_TYPE, id=self.id("radius_strategy"), ), className="mpc-control", ), html.Label("Draw options:", className="mpc-label"), html.Div([ dcc.Checklist( options=[ { "label": "Draw repeats of atoms on periodic boundaries", "value": "draw_image_atoms", }, { "label": "Draw atoms outside unit cell bonded to " "atoms within unit cell", "value": "bonded_sites_outside_unit_cell", }, { "label": "Hide bonds where destination atoms are not shown", "value": "hide_incomplete_bonds", }, ], value=[ opt for opt in ( "draw_image_atoms", "bonded_sites_outside_unit_cell", "hide_incomplete_bonds", ) if self.initial_data["display_options"][opt] ], labelStyle={"display": "block"}, inputClassName="mpc-radio", id=self.id("draw_options"), persistence=SETTINGS.PERSISTENCE, persistence_type=SETTINGS.PERSISTENCE_TYPE, ) ]), html.Label("Hide/show:", 
className="mpc-label"), html.Div( [ dcc.Checklist( options=[ { "label": "Atoms", "value": "atoms" }, { "label": "Bonds", "value": "bonds" }, { "label": "Unit cell", "value": "unit_cell" }, { "label": "Polyhedra", "value": "polyhedra" }, { "label": "Axes", "value": "axes" }, ], value=["atoms", "bonds", "unit_cell", "polyhedra"], labelStyle={"display": "block"}, inputClassName="mpc-radio", id=self.id("hide-show"), persistence=SETTINGS.PERSISTENCE, persistence_type=SETTINGS.PERSISTENCE_TYPE, ) ], className="mpc-control", ), ]) else: options_layout = None if self.show_legend: legend_layout = html.Div( self._make_legend(self._initial_data["legend_data"]), id=self.id("legend_container"), ) else: legend_layout = None struct_layout = html.Div([ CrystalToolkitScene( [options_layout, legend_layout], id=self.id("scene"), className=self.className, data=self.initial_data["scene"], settings=self.initial_scene_settings, sceneSize="100%", fileOptions=list(self.download_options["Structure"].keys()), showControls=self.show_controls, showExpandButton=self.show_expand_button, showImageButton=self.show_image_button, showExportButton=self.show_export_button, showPositionButton=self.show_position_button, **self.scene_kwargs, ), dcc.Download(id=self.id("download-image")), dcc.Download(id=self.id("download-structure")) ]) return { "struct": struct_layout, "options": options_layout, "title": title_layout, "legend": legend_layout, } def layout(self, size: str = "500px") -> html.Div: """ :param size: a CSS string specifying width/height of Div :return: A html.Div containing the 3D structure or molecule """ return html.Div(self._sub_layouts["struct"], style={ "width": size, "height": size }) @staticmethod def _preprocess_structure( struct_or_mol: Union[Structure, StructureGraph, Molecule, MoleculeGraph], unit_cell_choice: Literal["input", "primitive", "conventional", "reduced_niggli", "reduced_lll"] = "input", ): if isinstance(struct_or_mol, Structure): if unit_cell_choice != "input": if 
unit_cell_choice == "primitive": struct_or_mol = struct_or_mol.get_primitive_structure() elif unit_cell_choice == "conventional": sga = SpacegroupAnalyzer(struct_or_mol) struct_or_mol = sga.get_conventional_standard_structure() elif unit_cell_choice == "reduced_niggli": struct_or_mol = struct_or_mol.get_reduced_structure( reduction_algo="niggli") elif unit_cell_choice == "reduced_lll": struct_or_mol = struct_or_mol.get_reduced_structure( reduction_algo="LLL") return struct_or_mol @staticmethod def _preprocess_input_to_graph( input: Union[Structure, StructureGraph, Molecule, MoleculeGraph], bonding_strategy: str = DEFAULTS["bonding_strategy"], bonding_strategy_kwargs: Optional[Dict] = None, ) -> Union[StructureGraph, MoleculeGraph]: if isinstance(input, Structure): # ensure fractional co-ordinates are normalized to be in [0,1) # (this is actually not guaranteed by Structure) try: input = input.as_dict(verbosity=0) except TypeError: # TODO: remove this, necessary for Slab(?), some structure subclasses don't have verbosity input = input.as_dict() for site in input["sites"]: site["abc"] = np.mod(site["abc"], 1) input = Structure.from_dict(input) if not input.is_ordered: # calculating bonds in disordered structures is currently very flaky bonding_strategy = "CutOffDictNN" # we assume most uses of this class will give a structure as an input argument, # meaning we have to calculate the graph for bonding information, however if # the graph is already known and supplied, we will use that if isinstance(input, StructureGraph) or isinstance( input, MoleculeGraph): graph = input else: if (bonding_strategy not in StructureMoleculeComponent. available_bonding_strategies.keys()): raise ValueError( "Bonding strategy not supported. Please supply a name " "of a NearNeighbor subclass, choose from: {}".format( ", ".join(StructureMoleculeComponent. 
available_bonding_strategies.keys()))) else: bonding_strategy_kwargs = bonding_strategy_kwargs or {} if bonding_strategy == "CutOffDictNN": if "cut_off_dict" in bonding_strategy_kwargs: # TODO: remove this hack by making args properly JSON serializable bonding_strategy_kwargs["cut_off_dict"] = { (x[0], x[1]): x[2] for x in bonding_strategy_kwargs["cut_off_dict"] } bonding_strategy = StructureMoleculeComponent.available_bonding_strategies[ bonding_strategy](**bonding_strategy_kwargs) try: with warnings.catch_warnings(): warnings.simplefilter("ignore") if isinstance(input, Structure): graph = StructureGraph.with_local_env_strategy( input, bonding_strategy) else: graph = MoleculeGraph.with_local_env_strategy( input, bonding_strategy, reorder=False) except: # for some reason computing bonds failed, so let's not have any bonds(!) if isinstance(input, Structure): graph = StructureGraph.with_empty_graph(input) else: graph = MoleculeGraph.with_empty_graph(input) return graph @staticmethod def _get_struct_or_mol( graph: Union[StructureGraph, MoleculeGraph, Structure, Molecule] ) -> Union[Structure, Molecule]: if isinstance(graph, StructureGraph): return graph.structure elif isinstance(graph, MoleculeGraph): return graph.molecule elif isinstance(graph, Structure) or isinstance(graph, Molecule): return graph else: raise ValueError @staticmethod def get_scene_and_legend( graph: Optional[Union[StructureGraph, MoleculeGraph]], color_scheme=DEFAULTS["color_scheme"], color_scale=None, radius_strategy=DEFAULTS["radius_strategy"], draw_image_atoms=DEFAULTS["draw_image_atoms"], bonded_sites_outside_unit_cell=DEFAULTS[ "bonded_sites_outside_unit_cell"], hide_incomplete_bonds=DEFAULTS["hide_incomplete_bonds"], explicitly_calculate_polyhedra_hull=False, scene_additions=None, show_compass=DEFAULTS["show_compass"], group_by_site_property=None, ) -> Tuple[Scene, Dict[str, str]]: scene = Scene(name="StructureMoleculeComponentScene") if graph is None: return scene, {} struct_or_mol = 
StructureMoleculeComponent._get_struct_or_mol(graph) # TODO: add radius_scale legend = Legend( struct_or_mol, color_scheme=color_scheme, radius_scheme=radius_strategy, cmap_range=color_scale, ) if isinstance(graph, StructureGraph): scene = graph.get_scene( draw_image_atoms=draw_image_atoms, bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell, hide_incomplete_edges=hide_incomplete_bonds, explicitly_calculate_polyhedra_hull= explicitly_calculate_polyhedra_hull, group_by_site_property=group_by_site_property, legend=legend, ) elif isinstance(graph, MoleculeGraph): scene = graph.get_scene(legend=legend) scene.name = "StructureMoleculeComponentScene" if hasattr(struct_or_mol, "lattice"): axes = struct_or_mol.lattice._axes_from_lattice() axes.visible = show_compass scene.contents.append(axes) scene_json = scene.to_json() if scene_additions: # TODO: this might be cleaner if we had a Scene.from_json() method scene_json["contents"].append(scene_additions) return scene_json, legend.get_legend() def title_layout(self): """ :return: A layout including the composition of the structure/molecule as a title. """ return self._sub_layouts["title"]
class StructureComponent(MPComponent):
    """
    A component to display pymatgen Structure objects.

    Bonding is computed with a NearNeighbors strategy (selected by class
    name) and rendered via a Simple3DScene. Unlike StructureMoleculeComponent,
    this component registers no callbacks (see ``generate_callbacks``).
    """

    # maps NearNeighbors subclass name -> class; only *direct* subclasses are
    # discovered by ``__subclasses__()``
    available_bonding_strategies = {
        subcls.__name__: subcls for subcls in NearNeighbors.__subclasses__()
    }

    default_scene_settings = {
        "extractAxis": True,
        # For visual diff testing, we change the renderer
        # to SVG since this WebGL support is more difficult
        # in headless browsers / CI.
        "renderer": "svg" if SETTINGS.TEST_MODE else "webgl",
    }

    default_title = "Structure"

    def __init__(
        self,
        structure: Optional[Structure] = None,
        id: Optional[str] = None,
        scene_additions: Optional[Scene] = None,
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: Optional[dict] = None,
        color_scale: Optional[str] = None,
        radius_strategy: str = DEFAULTS["radius_strategy"],
        unit_cell_choice: str = DEFAULTS["unit_cell_choice"],
        draw_image_atoms: bool = DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell: bool = DEFAULTS[
            "bonded_sites_outside_unit_cell"],
        hide_incomplete_bonds: bool = DEFAULTS["hide_incomplete_bonds"],
        show_compass: bool = DEFAULTS["show_compass"],
        scene_settings: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Args:
            structure: Structure to display; if falsy, an empty scene is shown.
            id: component id forwarded to MPComponent.
            scene_additions: extra scene contents, wrapped in a Scene named
                "scene_additions" and appended to the rendered scene.
            bonding_strategy: name of a key in ``available_bonding_strategies``.
            bonding_strategy_kwargs: kwargs for the bonding strategy class.
            color_scale, radius_strategy, draw_image_atoms,
            bonded_sites_outside_unit_cell, hide_incomplete_bonds,
            show_compass: display options passed to ``get_scene_and_legend``.
            unit_cell_choice: stored in "graph_generation_options" only; it is
                not applied anywhere in this class.
            scene_settings: overrides merged into ``default_scene_settings``.
            **kwargs: forwarded to MPComponent.
        """
        super().__init__(id=id, default_data=structure, **kwargs)

        self.initial_scene_settings = self.default_scene_settings.copy()
        if scene_settings:
            self.initial_scene_settings.update(scene_settings)

        self.create_store("scene_settings",
                          initial_data=self.initial_scene_settings)

        # unit cell choice and bonding algorithms need to come from a settings
        # object (in a dcc.Store) guaranteed to be present in layout, rather
        # than from the controls themselves -- since these are optional and
        # may not be present in the layout
        self.create_store(
            "graph_generation_options",
            initial_data={
                "bonding_strategy": bonding_strategy,
                "bonding_strategy_kwargs": bonding_strategy_kwargs,
                "unit_cell_choice": unit_cell_choice,
            },
        )

        self.create_store(
            "display_options",
            initial_data={
                "color_scale": color_scale,
                "radius_strategy": radius_strategy,
                "draw_image_atoms": draw_image_atoms,
                "bonded_sites_outside_unit_cell":
                    bonded_sites_outside_unit_cell,
                "hide_incomplete_bonds": hide_incomplete_bonds,
                "show_compass": show_compass,
            },
        )

        if scene_additions:
            initial_scene_additions = Scene(
                name="scene_additions", contents=scene_additions).to_json()
        else:
            initial_scene_additions = None
        self.create_store("scene_additions",
                          initial_data=initial_scene_additions)

        if structure:
            # graph is cached explicitly, this isn't necessary but is an
            # optimization so that graph is only re-generated if bonding
            # algorithm changes
            graph = self._preprocess_input_to_graph(
                structure,
                bonding_strategy=bonding_strategy,
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )
            scene, legend = self.get_scene_and_legend(
                graph,
                scene_additions=self.initial_data["scene_additions"],
                **self.initial_data["display_options"],
            )
            if hasattr(structure, "lattice"):
                self._lattice = structure.lattice
        else:
            # component could be initialized without a structure, in which case
            # an empty scene should be displayed
            graph = None
            scene, legend = self.get_scene_and_legend(
                None,
                scene_additions=self.initial_data["scene_additions"],
                **self.initial_data["display_options"],
            )

        self.create_store("legend_data", initial_data=legend)
        self.create_store("graph", initial_data=graph)

        # this is used by a Simple3DScene component, not a dcc.Store
        self._initial_data["scene"] = scene

        # hide axes inset for molecules
        if isinstance(structure, (Molecule, MoleculeGraph)):
            self.scene_kwargs = {"axisView": "HIDDEN"}
        else:
            self.scene_kwargs = {}

    def _make_legend(self, legend):
        """
        Build the legend: one colored static button per entry of
        ``legend["colors"]`` (a color -> label mapping), ordered by where the
        label first appears in the reduced formula.
        """
        if not legend:
            return html.Div(id=self.id("legend"))

        def get_font_color(hex_code):
            # ensures contrasting font color for background color
            c = tuple(int(hex_code[1:][i:i + 2], 16) for i in (0, 2, 4))
            if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5:
                font_color = "#000000"
            else:
                font_color = "#ffffff"
            return font_color

        try:
            formula = Composition.from_dict(
                legend["composition"]).reduced_formula
        except Exception:
            # TODO: fix legend for Dummy Specie compositions
            formula = "Unknown"

        legend_colors = OrderedDict(
            sorted(list(legend["colors"].items()),
                   key=lambda x: formula.find(x[1])))

        legend_elements = [
            Button(
                html.Span(name,
                          className="icon",
                          style={"color": get_font_color(color)}),
                kind="static",
                style={"backgroundColor": color},
            ) for color, name in legend_colors.items()
        ]

        return Field(
            [
                Control(el, style={"marginRight": "0.2rem"})
                for el in legend_elements
            ],
            id=self.id("legend"),
            grouped=True,
        )

    def _make_title(self, legend):
        """
        Build an H1 title from ``legend["composition"]``.

        Digits in the reduced formula are rendered as subscripts.
        DummySpecie entries are stripped. The class default title is used
        when no composition is available.
        """
        if not legend or (not legend.get("composition", None)):
            return H1(self.default_title, id=self.id("title"))

        composition = legend["composition"]
        if isinstance(composition, dict):
            try:
                composition = Composition.from_dict(composition)
                # strip DummySpecie if present (TODO: should be method in pymatgen)
                composition = Composition({
                    el: amt
                    for el, amt in composition.items()
                    if not isinstance(el, DummySpecie)
                })
                composition = composition.get_reduced_composition_and_factor(
                )[0]
                formula = composition.reduced_formula
                formula_parts = re.findall(r"[^\d_]+|\d+", formula)
                formula_components = [
                    html.Sub(part.strip())
                    if part.isnumeric() else html.Span(part.strip())
                    for part in formula_parts
                ]
            except Exception:
                formula_components = list(map(str, composition.keys()))
        else:
            # previously left ``formula_components`` unbound (UnboundLocalError)
            formula_components = str(composition)

        return H1(formula_components,
                  id=self.id("title"),
                  style={"display": "inline-block"})

    @staticmethod
    def _make_bonding_algorithm_custom_cuffoff_data(graph):
        """
        Rows for a custom-cutoff table: one row per unordered pair of species
        (including self-pairs), each with a cutoff of 0. A single blank row is
        returned when ``graph`` is falsy.
        """
        if not graph:
            return [{"A": None, "B": None, "A—B": None}]
        struct_or_mol = StructureComponent._get_structure(graph)
        # can't use type_of_specie because it doesn't work with disordered structures
        species = set(
            map(
                str,
                chain.from_iterable(
                    [list(c.keys()) for c in struct_or_mol.species_and_occu]),
            ))
        # sorted so that row order is stable across processes (set iteration
        # order of strings depends on hash randomization)
        rows = [{
            "A": combination[0],
            "B": combination[1],
            "A—B": 0
        } for combination in combinations_with_replacement(sorted(species), 2)]
        return rows

    @property
    def _sub_layouts(self):
        """Dict of the "struct", "title" and "legend" layout pieces."""
        struct_layout = html.Div(
            Simple3DScene(
                id=self.id("scene"),
                data=self.initial_data["scene"],
                settings=self.initial_scene_settings,
                sceneSize="100%",
                **self.scene_kwargs,
            ),
            style={
                "width": "100%",
                "height": "100%",
                "overflow": "hidden",
                "margin": "0 auto",
            },
        )

        title_layout = html.Div(
            self._make_title(self._initial_data["legend_data"]),
            id=self.id("title_container"),
        )

        legend_layout = html.Div(
            self._make_legend(self._initial_data["legend_data"]),
            id=self.id("legend_container"),
        )

        return {
            "struct": struct_layout,
            "title": title_layout,
            "legend": legend_layout,
        }

    def layout(self, size: str = "500px") -> html.Div:
        """
        :param size: a CSS string specifying width/height of Div
        :return: A html.Div containing the 3D structure or molecule
        """
        return html.Div(self._sub_layouts["struct"],
                        style={
                            "width": size,
                            "height": size
                        })

    @staticmethod
    def _preprocess_input_to_graph(
        input: Union[Structure, StructureGraph],
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: Optional[Dict] = None,
    ) -> Union[StructureGraph, MoleculeGraph]:
        """
        Build a StructureGraph for ``input`` using a named bonding strategy.

        Args:
            input: a Structure; a StructureGraph is returned unchanged since
                its bonding is already known.
            bonding_strategy: name of a key of ``available_bonding_strategies``;
                forced to "CutOffDictNN" for disordered structures.
            bonding_strategy_kwargs: kwargs for the strategy class. This dict is
                NOT mutated. For "CutOffDictNN", a "cut_off_dict" given as a
                list of ``[species_a, species_b, distance]`` entries is
                converted to the ``{(species_a, species_b): distance}`` form.

        Returns:
            A StructureGraph; a graph without bonds if bonding fails.

        Raises:
            ValueError: if ``bonding_strategy`` is not a known strategy name.
        """
        if isinstance(input, StructureGraph):
            # bonding already supplied (previously crashed with KeyError,
            # since a StructureGraph dict has no "sites" key)
            return input

        # ensure fractional co-ordinates are normalized to be in [0,1)
        # (this is actually not guaranteed by Structure)
        try:
            input = input.as_dict(verbosity=0)
        except TypeError:
            # TODO: remove this, necessary for Slab(?), some structure subclasses don't have verbosity
            input = input.as_dict()
        for site in input["sites"]:
            site["abc"] = np.mod(site["abc"], 1)
        input = Structure.from_dict(input)

        if not input.is_ordered:
            # calculating bonds in disordered structures is currently very flaky
            bonding_strategy = "CutOffDictNN"

        if bonding_strategy not in StructureComponent.available_bonding_strategies:
            raise ValueError(
                "Bonding strategy not supported. Please supply a name "
                "of a NearNeighbor subclass, choose from: {}".format(", ".join(
                    StructureComponent.available_bonding_strategies.keys())))

        # work on a copy: the caller's dict (e.g. the one stored in the
        # "graph_generation_options" store) must not be rewritten in place
        bonding_strategy_kwargs = dict(bonding_strategy_kwargs or {})
        if bonding_strategy == "CutOffDictNN":
            cut_off_dict = bonding_strategy_kwargs.get("cut_off_dict")
            if cut_off_dict is not None and not isinstance(cut_off_dict, dict):
                # TODO: remove this hack by making args properly JSON serializable
                bonding_strategy_kwargs["cut_off_dict"] = {
                    (x[0], x[1]): x[2] for x in cut_off_dict
                }
        bonding_strategy = StructureComponent.available_bonding_strategies[
            bonding_strategy](**bonding_strategy_kwargs)

        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                graph = StructureGraph.with_local_env_strategy(
                    input, bonding_strategy)
        except Exception:
            # for some reason computing bonds failed, so let's not have any bonds(!)
            graph = StructureGraph.with_empty_graph(input)

        return graph

    def generate_callbacks(self, app, cache):
        # this component registers no callbacks
        pass

    @staticmethod
    def _get_structure(
            graph: Union[StructureGraph, Structure]) -> Union[Structure, Molecule]:
        """Return the Structure underlying ``graph`` (or ``graph`` itself)."""
        if isinstance(graph, StructureGraph):
            return graph.structure
        elif isinstance(graph, Structure):
            return graph
        else:
            raise ValueError(
                f"Expected a Structure or StructureGraph, got {type(graph).__name__}")

    @staticmethod
    def get_scene_and_legend(
        graph: Optional[Union[StructureGraph, MoleculeGraph]],
        color_scale=None,
        radius_strategy=DEFAULTS["radius_strategy"],
        draw_image_atoms=DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell=DEFAULTS[
            "bonded_sites_outside_unit_cell"],
        hide_incomplete_bonds=DEFAULTS["hide_incomplete_bonds"],
        explicitly_calculate_polyhedra_hull=False,
        scene_additions=None,
        show_compass=DEFAULTS["show_compass"],
    ) -> Tuple[Union[Scene, Dict], Dict]:
        """
        Render ``graph`` to a scene and build its legend (VESTA colors).

        Returns:
            ``(scene, legend)``. When ``graph`` is None this is an empty
            ``Scene`` object and ``{}``. Otherwise the scene is the JSON form
            from ``Scene.to_json()``, with ``scene_additions`` appended to its
            contents, and the legend comes from ``Legend.get_legend()``.
        """
        scene = Scene(name="AceStructureMoleculeComponentScene")

        if graph is None:
            # NOTE(review): returns a Scene object here, not JSON as below
            return scene, {}

        structure = StructureComponent._get_structure(graph)

        # TODO: add radius_scale
        legend = Legend(
            structure,
            color_scheme="VESTA",
            radius_scheme=radius_strategy,
            cmap_range=color_scale,
        )

        scene = graph.get_scene(
            draw_image_atoms=draw_image_atoms,
            bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
            hide_incomplete_edges=hide_incomplete_bonds,
            explicitly_calculate_polyhedra_hull=
            explicitly_calculate_polyhedra_hull,
            legend=legend,
        )

        scene.name = "StructureComponentScene"

        if hasattr(structure, "lattice"):
            axes = structure.lattice._axes_from_lattice()
            axes.visible = show_compass
            scene.contents.append(axes)

        scene = scene.to_json()
        if scene_additions:
            # TODO: need a Scene.from_json() to make this work
            # raise NotImplementedError
            scene["contents"].append(scene_additions)

        return scene, legend.get_legend()

    def title_layout(self):
        """
        :return: A layout including the composition of the structure/molecule as a title.
        """
        return self._sub_layouts["title"]

    def legend_layout(self):
        """
        :return: A layout including a legend for the structure/molecule.
        """
        return self._sub_layouts["legend"]
class PymatgenVisualizationIntermediateFormat:
    """
    Class takes a Structure or StructureGraph and outputs a list of primitives
    (spheres, cylinders, etc.) for drawing purposes.
    """

    # name -> class for every direct NearNeighbors subclass, used to look up
    # the bonding_strategy passed to __init__ by name
    available_bonding_strategies = {
        subclass.__name__: subclass
        for subclass in NearNeighbors.__subclasses__()
    }

    # values accepted for radius_strategy (validated in _generate_radii)
    available_radius_strategies = ('atomic', 'bvanalyzer_ionic',
                                   'average_ionic', 'covalent',
                                   'van_der_waals', 'atomic_calculated')

    def __init__(
            self,
            structure,
            bonding_strategy='MinimumDistanceNN',
            bonding_strategy_kwargs=None,
            color_scheme="VESTA",
            color_scale=None,
            radius_strategy="average_ionic",
            draw_image_atoms=True,
            repeat_of_atoms_on_boundaries=True,
            bonded_sites_outside_display_area=True,
            symmetrize=True,  # TODO
            display_repeats=((0, 2), (0, 2), (0, 2))):
        """
        This class is used to generate a generic JSON of geometric primitives
        that can be parsed by a 3D rendering tool such as Three.js or Blender.

        The intent is that any crystallographic knowledge is handled by
        pymatgen (e.g. finding Cartesian co-ordinates of atoms, suggesting
        colors and radii for the atoms, detecting whether and where bonds
        should be present, detecting oxidation environments, calculating
        coordination polyhedra, etc.) Therefore, the 3D rendering code does
        not need to have any knowledge of crystallography, but just has to
        parse the JSON and draw the corresponding geometric primitives.

        The resulting JSON should be self-explanatory and constructed such
        that an appropriate scene graph can be made without difficulty. The
        'type' keys can be either 'sphere', 'cylinder', 'geometry' or 'text',
        the 'position' key is always a vector with respect to global Cartesian
        co-ordinates.

        Bonds and co-ordination polyhedra reference the index of the
        appropriate atom, rather than duplicating position vectors.

        Args:
            structure: an instance of Structure or StructureGraph (the latter
                for when bonding information is already known); the
                fractional co-ordinates of a Structure are first wrapped into
                [0, 1)
            bonding_strategy: name of a NearNeighbor subclass used to
                calculate bonding, only used if a Structure is supplied (see
                available_bonding_strategies, e.g. 'JMolNN',
                'MinimumDistanceNN', 'MinimumOKeeffeNN', 'MinimumVIRENN',
                'VoronoiNN', and pymatgen.analysis.local_env for the latest
                options)
            bonding_strategy_kwargs: kwargs to be passed to the NearNeighbor
                subclass
            color_scheme: "VESTA", "Jmol" or the name of a scalar site
                property to color-code by that site property
            color_scale: only relevant if color-coding by site property; name
                of a matplotlib color map. By default 'coolwarm' is used when
                the property takes both signs (kept symmetric around zero),
                otherwise 'viridis'
            radius_strategy: one of available_radius_strategies; if
                'bvanalyzer_ionic' is requested but oxidation states cannot be
                assigned, 'average_ionic' is used instead
            draw_image_atoms: stored as self.draw_image_atoms
            repeat_of_atoms_on_boundaries: currently unused
            bonded_sites_outside_display_area: if True, atoms outside the
                display range that are bonded to a drawn atom are drawn too
            symmetrize: currently unused
            display_repeats: ((a_min, a_max), (b_min, b_max), (c_min, c_max)),
                the range of fractional co-ordinates to draw

        Raises:
            ValueError: if bonding_strategy is not a known NearNeighbor
                subclass name, or if the color/radius options are invalid
        """
        # ensure fractional co-ordinates are normalized to be in [0,1)
        # (this is actually not guaranteed by Structure)
        if isinstance(structure, Structure):
            structure = structure.as_dict(verbosity=0)
            for site in structure['sites']:
                site['abc'] = np.mod(site['abc'], 1)
            structure = Structure.from_dict(structure)

        # we assume most uses of this class will give a structure as an input
        # argument, meaning we have to calculate the graph for bonding
        # information, however if the graph is already known and supplied, we
        # will use that
        if not isinstance(structure, StructureGraph):
            if bonding_strategy not in self.available_bonding_strategies:
                raise ValueError(
                    "Bonding strategy not supported. Please supply a name "
                    "of a NearNeighbor subclass, choose from: {}".format(
                        ", ".join(self.available_bonding_strategies.keys())))
            bonding_strategy_kwargs = bonding_strategy_kwargs or {}
            bonding_strategy = self.available_bonding_strategies[
                bonding_strategy](**bonding_strategy_kwargs)
            self.structure_graph = StructureGraph.with_local_env_strategy(
                structure, bonding_strategy)
            cns = [
                self.structure_graph.get_coordination_of_site(i)
                for i in range(len(structure))
            ]
            self.structure_graph.structure.add_site_property(
                'coordination_no', cns)
        else:
            self.structure_graph = structure

        self.structure = self.structure_graph.structure
        self.lattice = self.structure_graph.structure.lattice

        self.color_scheme = color_scheme  # TODO: add coord as option
        self.color_scale = color_scale
        self.radius_strategy = radius_strategy
        self.draw_image_atoms = draw_image_atoms
        self.bonded_sites_outside_display_area = bonded_sites_outside_display_area
        self.display_range = display_repeats

        # categorize site properties so we know which can be used for color
        # schemes etc.; must run before display_* properties are added
        self.site_prop_names = self._analyze_site_props()
        self.color_legend = self._generate_colors()  # adds 'display_color' site prop
        self._generate_radii()  # adds 'display_radius' site prop

    @property
    def json(self):
        """
        Assemble the full drawing description.

        Returns:
            dict with keys 'atoms', 'bonds', 'polyhedra', 'unit_cell',
            'color_legend' and 'site_props'.
        """
        # _generate_atoms sets self.geometric_center, which
        # _generate_unit_cell relies on, so the order here matters
        atoms = self._generate_atoms()
        bonds = self._generate_bonds(atoms)
        polyhedra_json = self._generate_polyhedra(atoms, bonds)
        unit_cell_json = self._generate_unit_cell()

        return {
            'atoms': list(atoms.values()),
            'bonds': list(bonds.values()),
            'polyhedra': polyhedra_json,
            'unit_cell': unit_cell_json,
            'color_legend': self.color_legend,
            'site_props': self.site_prop_names
        }

    @property
    def graph_json(self):
        """
        Describe the bonding graph as nodes and edges.

        Returns:
            {'nodes': [...], 'edges': [...]}. Each node carries its site index
            and the hex color of its first species. Each edge carries a
            length proportional to the bond distance. Edges crossing a
            periodic boundary get an arrow and a label with their image.
        """
        nodes = []
        for node in self.structure_graph.graph.nodes():
            r, g, b = self.structure_graph.structure[node].properties[
                'display_color'][0]
            color = "#{:02x}{:02x}{:02x}".format(r, g, b)
            nodes.append({'id': node, 'label': node, 'color': color})

        edges = []
        for u, v, d in self.structure_graph.graph.edges(data=True):
            edge = {'from': u, 'to': v, 'arrows': ''}
            to_jimage = d['to_jimage']
            # TODO: check these edge weights
            dist = self.structure.get_distance(u, v, to_jimage)
            edge['length'] = 50 * dist
            if to_jimage != (0, 0, 0):
                edge['arrows'] = 'to'
                edge['label'] = str(to_jimage)
            edges.append(edge)

        return {'nodes': nodes, 'edges': edges}

    def _analyze_site_props(self):
        """
        Group site property names by kind.

        The kinds are 'scalar', 'vector' (length-3 lists), 'matrix' (3x3
        nested lists) and 'categorical' (strings). Vector properties can then
        be displayed as arrows.

        Implicitly assumes all values for a given property have the same type
        as the value on the first site.
        """
        site_prop_names = defaultdict(list)
        for name, props in self.structure_graph.structure.site_properties.items():
            first = props[0]
            if isinstance(first, (float, int)):
                site_prop_names['scalar'].append(name)
            elif isinstance(first, list) and len(first) == 3:
                if isinstance(first[0], list) and len(first[0]) == 3:
                    site_prop_names['matrix'].append(name)
                else:
                    site_prop_names['vector'].append(name)
            elif isinstance(first, str):
                site_prop_names['categorical'].append(name)
        return dict(site_prop_names)

    def _add_bonded_images(self, site_images_to_draw):
        """
        Extend ``site_images_to_draw`` in place with bonded neighbours.

        For every bond with one end on an atom that is already drawn, the atom
        at the other end is drawn too.

        Bonds are checked in both directions:
          - v at ``image + to_jimage`` for each drawn image of u;
          - u at ``image - to_jimage`` for each drawn image of v.

        Only images present on entry are used as starting points, so a single
        shell of neighbours is added rather than growing to closure.
        """
        drawn = {idx: set(images) for idx, images in site_images_to_draw.items()}
        # site index -> insertion-ordered set (dict keys) of images to add
        extra = defaultdict(dict)

        def _request(site_idx, image):
            # normalise numpy ints to plain ints so keys compare/hash cleanly
            image = tuple(int(x) for x in image)
            if image not in drawn.get(site_idx, ()):
                extra[site_idx][image] = None

        for u, v, d in self.structure_graph.graph.edges(data=True):
            to_jimage = d['to_jimage']
            # bond leaves a drawn image of u: draw v at the far end
            for image in site_images_to_draw.get(u, []):
                _request(v, np.add(image, to_jimage))
            # bond arrives at a drawn image of v: draw u at the near end
            for image in site_images_to_draw.get(v, []):
                _request(u, np.subtract(image, to_jimage))

        # merge only after the scan so the dict is not mutated while iterated
        for site_idx, images in extra.items():
            site_images_to_draw[site_idx].extend(images)

    def _generate_atoms(self):
        """
        Build the sphere primitives for every atom to be drawn.

        Sets self.geometric_center (used by _generate_unit_cell).

        Returns:
            OrderedDict keyed by (site_idx, image), values are sphere dicts
            whose 'idx' is the atom's position in the ordering.
        """
        # to translate atoms so that the geometric center of the display
        # range sits at (0, 0, 0) in the global co-ordinate system
        center_frac = tuple(0.5 * (max(r) + min(r)) for r in self.display_range)
        self.geometric_center = self.lattice.get_cartesian_coords(center_frac)

        # integer lattice images that can contain a point of the display range
        ranges = [
            range(int(np.sign(r[0]) * np.ceil(np.abs(r[0]))),
                  1 + int(np.sign(r[1]) * np.ceil(np.abs(r[1]))))
            for r in self.display_range
        ]
        possible_images = list(itertools.product(*ranges))

        site_images_to_draw = defaultdict(list)
        lower_corner = np.array([min(r) for r in self.display_range])
        upper_corner = np.array([max(r) for r in self.display_range])
        for idx, site in enumerate(self.structure):
            for image in possible_images:
                frac_coords = np.add(image, site.frac_coords)
                if (np.all(np.less_equal(lower_corner, frac_coords))
                        and np.all(np.less_equal(frac_coords, upper_corner))):
                    site_images_to_draw[idx].append(image)

        if self.bonded_sites_outside_display_area:
            self._add_bonded_images(site_images_to_draw)

        atoms = OrderedDict()
        for site_idx, images in site_images_to_draw.items():
            site = self.structure[site_idx]

            # in disordered structures, we fractionally color-code spheres,
            # drawing a sphere segment from phi_start to phi_end per species
            # (think a sphere pie chart)
            occu_start = 0.0
            fragments = []
            for comp_idx, (sp, occu) in enumerate(site.species_and_occu.items()):
                phi_frac_start = occu_start
                phi_frac_end = occu_start + occu
                occu_start = phi_frac_end

                name = "{}".format(sp)
                if occu != 1.0:
                    # occu is a fraction; show it as a percentage
                    name += " ({:g}% occupancy)".format(occu * 100)

                fragments.append({
                    'radius': site.properties['display_radius'][comp_idx],
                    'color': site.properties['display_color'][comp_idx],
                    'name': name,
                    'phi_start': phi_frac_start * np.pi * 2,
                    'phi_end': phi_frac_end * np.pi * 2
                })

            bond_color = fragments[0]['color'] if site.is_ordered else [
                55, 55, 55
            ]

            # TODO: do some appropriate scaling here
            if 'vector' in self.site_prop_names:
                vectors = {
                    name: site.properties[name]
                    for name in self.site_prop_names['vector']
                }
            else:
                vectors = None

            if 'matrix' in self.site_prop_names:
                matrices = {
                    name: site.properties[name]
                    for name in self.site_prop_names['matrix']
                }
            else:
                matrices = None

            for image in images:
                position_cart = list(
                    np.subtract(
                        self.lattice.get_cartesian_coords(
                            np.add(site.frac_coords, image)),
                        self.geometric_center))
                atoms[(site_idx, image)] = {
                    'type': 'sphere',
                    'idx': len(atoms),
                    'position': position_cart,
                    'bond_color': bond_color,
                    'fragments': fragments,
                    'vectors': vectors,
                    'matrices': matrices
                }

        return atoms

    def _generate_bonds(self, atoms):
        """
        Find bonds between atoms that are both being drawn.

        Args:
            atoms: OrderedDict from _generate_atoms, keyed by (site_idx, image)

        Returns:
            OrderedDict keyed by a pair of (site_idx, image) tuples, values
            are {'from_atom_index': int, 'to_atom_index': int}.
        """
        # position of each drawn atom in the ordering, O(1) lookup
        atom_positions = {key: pos for pos, key in enumerate(atoms)}

        bonds_set = set()
        for site_idx, image in atom_positions:
            for u, v, d in self.structure_graph.graph.edges(nbunch=site_idx,
                                                            data=True):
                to_image = tuple(np.add(d['to_jimage'], image).astype(int))
                # frozenset so u->v and v->u are stored once
                bonds_set.add(frozenset({(u, image), (v, to_image)}))

        bonds = OrderedDict()
        for bond in bonds_set:
            bond = tuple(bond)
            if len(bond) != 2:
                continue  # degenerate bond from an atom to itself
            from_atom_idx = atom_positions.get(bond[0])
            to_atom_idx = atom_positions.get(bond[1])
            if from_atom_idx is None or to_atom_idx is None:
                continue  # one of the atoms in the bond isn't being drawn
            bonds[bond] = {
                'from_atom_index': from_atom_idx,
                'to_atom_index': to_atom_idx
            }
        return bonds

    def _generate_radii(self):
        """
        Add a 'display_radius' site property (one radius per species).

        Does nothing if the property is already present. Unknown radii fall
        back to 1.0 with a warning.

        Raises:
            ValueError: if self.radius_strategy is not in
                available_radius_strategies.
        """
        structure = self.structure_graph.structure

        # don't calculate radius if one is explicitly supplied
        if 'display_radius' in structure.site_properties:
            return

        if self.radius_strategy not in self.available_radius_strategies:
            raise ValueError(
                "Unknown radius strategy {}, choose from: {}".format(
                    self.radius_strategy, self.available_radius_strategies))

        if self.radius_strategy == 'bvanalyzer_ionic':
            trans = AutoOxiStateDecorationTransformation()
            try:
                structure = trans.apply_transformation(
                    self.structure_graph.structure)
            except Exception:
                # if we can't assign valences use average ionic
                self.radius_strategy = 'average_ionic'

        strategy = self.radius_strategy
        radii = []
        for site in structure:
            site_radii = []
            for sp in site.species_and_occu:
                # compare by value: `is` on strings is unreliable
                if strategy == 'atomic':
                    radius = sp.atomic_radius
                elif strategy == 'bvanalyzer_ionic':
                    radius = sp.ionic_radius if isinstance(sp, Specie) else None
                elif strategy == 'average_ionic':
                    radius = sp.average_ionic_radius
                elif strategy == 'covalent':
                    el = str(getattr(sp, 'element', sp))
                    radius = CovalentRadius.radius[el]
                elif strategy == 'van_der_waals':
                    radius = sp.van_der_waals_radius
                else:  # 'atomic_calculated'
                    radius = sp.atomic_radius_calculated

                if not radius:
                    warnings.warn('Radius unknown for {} and strategy {}, '
                                  'setting to 1.0.'.format(
                                      sp, self.radius_strategy))
                    radius = 1.0

                site_radii.append(radius)
            radii.append(site_radii)

        self.structure_graph.structure.add_site_property(
            'display_radius', radii)

    def _generate_colors(self):
        """
        Add a 'display_color' site property and build a color legend.

        Adds one [r, g, b] list per species on each site. Does nothing if the
        property is already present.

        Returns:
            dict mapping hex color strings to legend labels (empty if
            'display_color' was supplied).

        Raises:
            ValueError: for site-property color schemes on disordered
                structures, or for unsupported color schemes.
            NotImplementedError: for categorical site properties.
        """
        structure = self.structure_graph.structure
        legend = {}

        # don't calculate color if one is explicitly supplied
        if 'display_color' in structure.site_properties:
            # don't know what the color legend (meaning) is, so return empty legend
            return legend

        if self.color_scheme not in ('VESTA', 'Jmol'):

            if not structure.is_ordered:
                raise ValueError(
                    'Can only use VESTA or Jmol color schemes '
                    'for disordered structures, color schemes based '
                    'on site properties are ill-defined.')

            if self.color_scheme in self.site_prop_names.get('scalar', []):

                props = np.array(structure.site_properties[self.color_scheme])

                if min(props) < 0 and max(props) > 0:
                    # by default, use blue-grey-red color scheme,
                    # so that zero is ~ grey, and positive/negative
                    # are red/blue
                    color_scale = self.color_scale or 'coolwarm'
                    # try to keep color scheme symmetric around 0
                    color_max = max([abs(min(props)), max(props)])
                    color_min = -color_max
                else:
                    # but if all values have the same sign, use a
                    # perceptually-uniform color scale by default
                    # like viridis
                    color_scale = self.color_scale or 'viridis'
                    color_max = max(props)
                    color_min = min(props)

                cmap = get_cmap(color_scale)
                color_range = color_max - color_min

                def _normalize(x):
                    # map [color_min, color_max] onto [0, 1] as expected by
                    # cmap; a constant property maps to the middle of the scale
                    return (x - color_min) / color_range if color_range else 0.5

                def _get_color(x):
                    return [int(c * 255) for c in cmap(_normalize(x))[0:3]]

                colors = [[_get_color(x)] for x in props]

                # construct legend
                c = "#{:02x}{:02x}{:02x}".format(*_get_color(color_min))
                legend[c] = "{}".format(color_min)
                if color_max != color_min:
                    c = "#{:02x}{:02x}{:02x}".format(*_get_color(color_max))
                    legend[c] = "{}".format(color_max)

                    color_mid = (color_max + color_min) / 2
                    if (color_max % 1 == 0 and color_min % 1 == 0
                            and color_max - color_min > 1):
                        color_mid = int(color_mid)
                    c = "#{:02x}{:02x}{:02x}".format(*_get_color(color_mid))
                    legend[c] = "{}".format(color_mid)

            elif self.color_scheme in self.site_prop_names.get(
                    'categorical', []):
                raise NotImplementedError
                # iter() a palettable palettable.colorbrewer.qualitative
                # cmap.colors, check len, Set1_9 ?

            else:
                raise ValueError(
                    'Unsupported color scheme. Should be "VESTA", "Jmol" or '
                    'a scalar or categorical site property.')
        else:
            colors = []
            for site in structure:
                elements = [
                    sp.as_dict()['element']
                    for sp, _ in site.species_and_occu.items()
                ]
                colors.append([
                    EL_COLORS[self.color_scheme][element]
                    for element in elements
                ])
                # construct legend
                for element in elements:
                    color = "#{:02x}{:02x}{:02x}".format(
                        *EL_COLORS[self.color_scheme][element])
                    legend[color] = element

        self.structure_graph.structure.add_site_property(
            'display_color', colors)

        return legend

    def _generate_unit_cell(self):
        """
        Describe the unit cell as 12 line segments.

        The segments are given as a flat list of start/end point pairs,
        translated by -self.geometric_center, so _generate_atoms must run
        first.
        """
        o = -self.geometric_center
        a, b, c = self.lattice.matrix[0], self.lattice.matrix[
            1], self.lattice.matrix[2]

        line_pairs = [
            o, o + a, o, o + b, o, o + c, o + a, o + a + b, o + a, o + a + c,
            o + b, o + b + a, o + b, o + b + c, o + c, o + c + a, o + c,
            o + c + b, o + a + b, o + a + b + c, o + a + c, o + a + b + c,
            o + b + c, o + a + b + c
        ]
        line_pairs = [line.tolist() for line in line_pairs]

        return {'type': 'lines', 'lines': line_pairs}

    def _generate_polyhedra(self, atoms, bonds):
        """
        Build coordination polyhedra from the drawn atoms and bonds.

        Only complete polyhedra are kept: the centre must have more than two
        neighbours, and all of them must be drawn. They are grouped by centre
        species, and a species is dropped if any of its polyhedron centres is
        also one of its vertices.

        Args:
            atoms: output of _generate_atoms
            bonds: output of _generate_bonds

        Returns:
            dict with 'polyhedra_by_type' (name -> list of convex hull dicts),
            'polyhedra_types' (list of names), and 'default_polyhedra_types'.
            The defaults are the mutually compatible set of types containing
            the most polyhedra.
        """
        # sites with more than two neighbours are potential polyhedron
        # centres; store the expected vertex count so incomplete polyhedra
        # can be discarded later
        # (could enforce that all neighbours are the same species here)
        expected_vertices = {}
        for idx in range(len(self.structure)):
            n_neighbors = len(self.structure_graph.get_connected_sites(idx))
            if n_neighbors > 2:
                expected_vertices[idx] = n_neighbors

        polyhedra = defaultdict(list)
        for (from_site_idx, from_image), (to_site_idx, to_image) in bonds:
            if from_site_idx in expected_vertices:
                polyhedra[(from_site_idx, from_image)].append(
                    (to_site_idx, to_image))
            if to_site_idx in expected_vertices:
                polyhedra[(to_site_idx, to_image)].append(
                    (from_site_idx, from_image))

        # discard polyhedra with incorrect coordination (e.g. half the
        # polyhedron's atoms are not in the draw range so would be cut off)
        polyhedra = {
            centre: verts
            for centre, verts in polyhedra.items()
            if len(verts) == expected_vertices[centre[0]]
        }

        centres_by_species = defaultdict(list)
        for centre in polyhedra:
            centres_by_species[self.structure[centre[0]].species_string].append(
                centre)

        polyhedra_json_by_species = {}
        polyhedra_by_species_vertices = {}
        polyhedra_by_species_centres = {}
        for sp, centres in centres_by_species.items():
            polyhedra_json = []
            all_vertices = []
            for centre in centres:
                # book-keeping to prevent intersecting polyhedra
                vertices = polyhedra[centre]
                all_vertices += vertices

                points_cart = [atoms[vert]['position'] for vert in vertices]
                points_idx = [atoms[vert]['idx'] for vert in vertices]

                # Delaunay can fail in some edge cases
                try:
                    hull = Delaunay(points_cart).convex_hull.tolist()
                except Exception as exc:
                    warnings.warn(
                        "Could not construct polyhedron centred on {}: {}".format(
                            centre, exc))
                    continue

                # TODO: storing duplicate info here ... ?
                polyhedra_json.append({
                    'type': 'convex',
                    'points_idx': points_idx,
                    'points': points_cart,
                    'hull': hull,
                    'center': atoms[centre]['idx']
                })

            centre_set = set(centres)
            vertex_set = set(all_vertices)
            if polyhedra_json and not vertex_set.intersection(centre_set):
                name = "{}-centered".format(sp)
                polyhedra_json_by_species[name] = polyhedra_json
                polyhedra_by_species_centres[name] = centre_set
                polyhedra_by_species_vertices[name] = vertex_set

        if polyhedra_json_by_species:
            # score every set of polyhedron types that can be drawn together
            # (no type's centre is another type's vertex) by how many
            # polyhedra it contains
            names = list(polyhedra_json_by_species)
            compatible_subsets = {
                (name, ): len(polyhedra_json_by_species[name])
                for name in names
            }
            for r in range(2, len(names) + 1):
                for subset in itertools.combinations(names, r):
                    all_centres = set.union(
                        *[polyhedra_by_species_centres[name] for name in subset])
                    all_verts = set.union(
                        *[polyhedra_by_species_vertices[name] for name in subset])
                    if not all_verts.intersection(all_centres):
                        compatible_subsets[subset] = sum(
                            len(polyhedra_json_by_species[name])
                            for name in subset)

            # the highest-scoring subset wins; ties go to the first one found
            best_subset, _ = max(compatible_subsets.items(), key=lambda s: s[1])
            default_polyhedra = list(best_subset)
        else:
            default_polyhedra = []

        return {
            'polyhedra_by_type': polyhedra_json_by_species,
            'polyhedra_types': list(polyhedra_json_by_species.keys()),
            'default_polyhedra_types': default_polyhedra
        }
class MPVisualizer:
    """
    Class takes a Structure or StructureGraph and outputs a list of primitives
    (spheres, cylinders, etc.) for drawing purposes.
    """

    # name -> class for every direct NearNeighbors subclass, used to look up
    # the bonding_strategy passed to __init__ by name
    allowed_bonding_strategies = {
        subclass.__name__: subclass
        for subclass in NearNeighbors.__subclasses__()
    }

    def __init__(self,
                 structure,
                 bonding_strategy='MinimumOKeeffeNN',
                 bonding_strategy_kwargs=None,
                 color_scheme="VESTA",
                 color_scale=None,
                 radius_strategy="ionic",
                 coordination_polyhedra=False,
                 draw_image_atoms=True,
                 repeat_of_atoms_on_boundaries=True,
                 bonded_atoms_outside_unit_cell=True,
                 scale=None):
        """
        This class is used to generate a generic JSON of geometric primitives
        that can be parsed by a 3D rendering tool such as Three.js or Blender.

        The intent is that any crystallographic knowledge is handled by
        pymatgen (e.g. finding Cartesian co-ordinates of atoms, suggesting
        colors and radii for the atoms, detecting whether and where bonds
        should be present, detecting oxidation environments, calculating
        coordination polyhedra, etc.) Therefore, the 3D rendering code does
        not need to have any knowledge of crystallography, but just has to
        parse the JSON and draw the corresponding geometric primitives.

        The resulting JSON should be self-explanatory and constructed such
        that an appropriate scene graph can be made without difficulty. The
        'type' keys can be either 'sphere', 'cylinder', 'geometry' or 'text',
        the 'position' key is always a vector with respect to global Cartesian
        co-ordinates.

        Bonds and co-ordination polyhedra reference the index of the
        appropriate atom, rather than duplicating position vectors.

        Args:
            structure: an instance of Structure or StructureGraph (the latter
                for when bonding information is already known)
            bonding_strategy: name of a NearNeighbor subclass used to
                calculate bonding, only used if bonding is not already defined
                (see allowed_bonding_strategies, e.g. 'JMolNN',
                'MinimumDistanceNN', 'MinimumOKeeffeNN', 'MinimumVIRENN',
                'VoronoiNN', and pymatgen.analysis.local_env for the latest
                options)
            bonding_strategy_kwargs: kwargs to be passed to the NearNeighbor
                subclass
            color_scheme: "VESTA", "Jmol" or the name of a site property to
                color-code by that site property
            color_scale: only relevant if color-coding by site property; name
                of a matplotlib color map. By default 'coolwarm' is used when
                the property has negative values (kept symmetric around zero),
                otherwise 'viridis'
            radius_strategy: "ionic" uses the ionic radius of the detected
                oxidation state when available, else the average ionic radius;
                any other value uses the van der Waals radius
            coordination_polyhedra: not yet implemented; any truthy value
                raises NotImplementedError
            draw_image_atoms: stored as self.draw_image_atoms
            repeat_of_atoms_on_boundaries: if True, atoms lying on a cell
                boundary are also drawn at the equivalent opposite boundary
            bonded_atoms_outside_unit_cell: if True, atoms outside the unit
                cell that are bonded to atoms inside it are drawn too; needed
                for co-ordination polyhedra
            scale: if given, the structure is multiplied by it (periodic
                repeats) before drawing

        Raises:
            ValueError: if bonding_strategy is not a known NearNeighbor
                subclass name
            NotImplementedError: if coordination_polyhedra is truthy
        """
        # draw periodic repeats if requested
        if scale:
            structure = structure * scale

        # we assume most uses of this class will give a structure as an input
        # argument, meaning we have to calculate the graph for bonding
        # information, however if the graph is already known and supplied, we
        # will use that
        if not isinstance(structure, StructureGraph):
            if bonding_strategy not in self.allowed_bonding_strategies:
                raise ValueError(
                    "Bonding strategy not supported. Please supply a name "
                    "of a NearNeighbor subclass, choose from: {}".format(
                        ", ".join(self.allowed_bonding_strategies.keys())))
            bonding_strategy_kwargs = bonding_strategy_kwargs or {}
            bonding_strategy = self.allowed_bonding_strategies[
                bonding_strategy](**bonding_strategy_kwargs)
            self.structure_graph = StructureGraph.with_local_env_strategy(
                structure, bonding_strategy, decorate=False)
            self.lattice = structure.lattice
        else:
            self.structure_graph = structure
            self.lattice = self.structure_graph.structure.lattice

        if coordination_polyhedra:
            # TODO: coming soon
            raise NotImplementedError

        self.color_scheme = color_scheme
        self.color_scale = color_scale
        self.radius_strategy = radius_strategy
        self.coordination_polyhedra = coordination_polyhedra
        self.draw_image_atoms = draw_image_atoms
        self.repeat_of_atoms_on_boundaries = repeat_of_atoms_on_boundaries
        self.bonded_atoms_outside_unit_cell = bonded_atoms_outside_unit_cell

    @property
    def json(self):
        """
        Assemble the full drawing description.

        Returns:
            dict with keys 'atoms', 'bonds', 'polyhedra' and 'unit_cell'.
            'polyhedra' is empty when bonded_atoms_outside_unit_cell is
            False, since polyhedra cannot be built reliably then.
        """
        output = {}

        # adds display colors as site properties
        self._generate_colors()

        # used to keep track of atoms outside periodic boundaries for bonding
        self._atom_indexes = {}
        self._site_images_to_draw = {}
        self._bonds = []
        self._polyhedra = []

        output.update(self._generate_atoms())
        output.update(self._generate_bonds())
        if self.bonded_atoms_outside_unit_cell:
            output.update(self._generate_polyhedra())
        else:
            # _generate_polyhedra would raise; don't make the whole JSON fail
            output['polyhedra'] = {'polyhedra_list': [], 'polyhedra_types': []}
        output.update(self._generate_unit_cell())

        return output

    def _generate_atoms(self):
        """
        Decide which atom images to draw and build their sphere primitives.

        Side effects:
          - fills self._bonds with (site, image, site, image) tuples;
          - fills self.site_images_to_draw, self._atom_indexes and
            self._atoms_cart.

        Returns:
            {'atoms': [...]}
        """
        # try to work out oxidation states
        trans = AutoOxiStateDecorationTransformation()
        try:
            bv_structure = trans.apply_transformation(
                self.structure_graph.structure)
        except Exception:
            # if we can't assign valences just use original structure
            bv_structure = self.structure_graph.structure

        self._bonds = []

        # to translate atoms so that geometric center at (0, 0, 0)
        # in global co-ordinate system
        lattice = self.structure_graph.structure.lattice
        geometric_center = lattice.get_cartesian_coords((0.5, 0.5, 0.5))

        # used to easily find positions of atoms that lie on periodic boundaries
        adjacent_images = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0),
                           (0, 0, 1), (0, 0, -1), (1, 1, 0), (1, -1, 0),
                           (-1, 1, 0), (-1, -1, 0), (0, 1, 1), (0, 1, -1),
                           (0, -1, 1), (0, -1, -1), (1, 0, 1), (1, 0, -1),
                           (-1, 0, 1), (-1, 0, -1), (1, 1, 1), (1, 1, -1),
                           (1, -1, 1), (-1, 1, 1), (1, -1, -1), (-1, 1, -1),
                           (-1, -1, 1), (-1, -1, -1)]

        self.site_images_to_draw = {
            i: [] for i in range(len(self.structure_graph))
        }

        def _add_image(site_idx, image):
            # an image listed twice would be drawn as two overlapping atoms
            images = self.site_images_to_draw[site_idx]
            if image not in images:
                images.append(image)

        for site_idx, site in enumerate(self.structure_graph.structure):

            images_to_draw = [(0, 0, 0)]
            if self.repeat_of_atoms_on_boundaries:
                # essentially find which atoms lie on periodic boundaries, and
                # draw their repeats
                for adjacent_image in adjacent_images:
                    p = site.frac_coords + adjacent_image
                    if 0 <= p[0] <= 1 and 0 <= p[1] <= 1 and 0 <= p[2] <= 1:
                        images_to_draw.append(adjacent_image)

            for image in images_to_draw:
                _add_image(site_idx, image)

            # get bond information for site
            # why? to know if we want to draw an image atom,
            # we have to know what bonds are present
            site_edges = list(
                self.structure_graph.graph.edges(nbunch=site_idx, data=True))
            for image in images_to_draw:
                for _, v, d in site_edges:
                    to_jimage = tuple(
                        int(x) for x in np.add(d['to_jimage'], image))
                    self._bonds.append((site_idx, image, v, to_jimage))
                    _add_image(v, to_jimage)
                    if (self.bonded_atoms_outside_unit_cell
                            and to_jimage != (0, 0, 0) and image == (0, 0, 0)):
                        # the same bond seen from v's home cell
                        from_image_complement = tuple(-x for x in to_jimage)
                        self._bonds.append(
                            (site_idx, from_image_complement, v, (0, 0, 0)))
                        _add_image(site_idx, from_image_complement)

        atoms = []
        self._atoms_cart = {}
        for site_idx, images in self.site_images_to_draw.items():
            site = self.structure_graph.structure[site_idx]
            bv_site = bv_structure[site_idx]
            # per-species lookup, works for disordered sites too
            bv_species = list(bv_site.species_and_occu.keys())

            for image in images:

                self._atom_indexes[(site_idx, image)] = len(atoms)

                position_cart = list(
                    np.subtract(
                        lattice.get_cartesian_coords(
                            np.add(site.frac_coords, image)),
                        geometric_center))
                self._atoms_cart[(site_idx, image)] = position_cart

                # for disordered structures
                occu_start = 0.0
                fragments = []
                for comp_idx, (sp, occu) in enumerate(
                        site.species_and_occu.items()):

                    # in disordered structures, we fractionally color-code
                    # spheres, drawing a sphere segment from phi_start to
                    # phi_end (think a sphere pie chart)
                    phi_frac_start = occu_start
                    phi_frac_end = occu_start + occu
                    occu_start = phi_frac_end

                    bv_sp = (bv_species[comp_idx]
                             if comp_idx < len(bv_species) else sp)

                    # get radius of sphere we want to draw
                    if self.radius_strategy == 'ionic':
                        radius = (getattr(bv_sp, "ionic_radius", None)
                                  or bv_sp.average_ionic_radius)
                    else:
                        radius = sp.van_der_waals_radius

                    color = site.properties['display_color'][comp_idx]

                    # generate a label (e.g. to use for mouse-over text)
                    if bv_structure.is_ordered and sp != bv_sp:
                        # if we can guess what oxidation states are present,
                        # add them as a label ... we only attempt this for
                        # ordered structures
                        name = "{} (detected as likely {})".format(sp, bv_sp)
                    else:
                        name = "{}".format(sp)

                    if occu != 1.0:
                        # occu is a fraction; show it as a percentage
                        name += " ({:g}% occupancy)".format(occu * 100)

                    fragments.append({
                        'radius': radius,
                        'color': color,
                        'name': name,
                        'phi_start': phi_frac_start * np.pi * 2,
                        'phi_end': phi_frac_end * np.pi * 2
                    })

                atoms.append({
                    'type': 'sphere',
                    'position': position_cart,
                    'bond_color': fragments[0]['color']
                    if site.is_ordered else [55, 55, 55],
                    'fragments': fragments,
                    'ghost': image != (0, 0, 0)
                })

        return {'atoms': atoms}

    def _generate_bonds(self):
        """
        Convert self._bonds into unique pairs of atom indexes.

        Most of the bonding logic lives in _generate_atoms, because deciding
        which atoms to draw requires the bonds first. Must run after it.

        Returns:
            {'bonds': [{'from_atom_index': int, 'to_atom_index': int}, ...]}
        """
        bonds = []
        seen = set()
        for from_site_idx, from_image, to_site_idx, to_image in self._bonds:
            pair = (self._atom_indexes[(from_site_idx, from_image)],
                    self._atom_indexes[(to_site_idx, to_image)])
            if pair in seen:
                continue
            seen.add(pair)
            bonds.append({'from_atom_index': pair[0], 'to_atom_index': pair[1]})
        return {'bonds': bonds}

    def _generate_colors(self):
        """
        Add a 'display_color' site property (one [r, g, b] per species).

        Raises:
            ValueError: for site-property color schemes on disordered
                structures, or if color_scheme is neither "VESTA", "Jmol" nor
                a site property.
        """
        structure = self.structure_graph.structure

        if self.color_scheme not in ('VESTA', 'Jmol'):

            if not structure.is_ordered:
                raise ValueError(
                    'Can only use VESTA or Jmol color schemes '
                    'for disordered structures, color schemes based '
                    'on site properties are ill-defined.')

            if self.color_scheme not in structure.site_properties:
                raise ValueError(
                    'Unsupported color scheme. Should be "VESTA", "Jmol" or '
                    'a site property.')

            props = np.array(structure.site_properties[self.color_scheme])

            if min(props) < 0:
                # by default, use blue-grey-red color scheme,
                # so that zero is ~ grey, and positive/negative
                # are red/blue
                color_scale = self.color_scale or 'coolwarm'
                # try to keep color scheme symmetric around 0
                color_max = max([abs(min(props)), max(props)])
                color_min = -color_max
            else:
                # but if all values are positive, use a
                # perceptually-uniform color scale by default
                # like viridis
                color_scale = self.color_scale or 'viridis'
                color_max = max(props)
                color_min = min(props)

            cmap = get_cmap(color_scale)
            color_range = color_max - color_min

            colors = []
            for x in props:
                # normalize [color_min, color_max] to [0, 1] as expected by
                # cmap; a constant property maps to the middle of the scale
                rgba = cmap((x - color_min) / color_range if color_range else 0.5)
                colors.append([[
                    int(rgba[0] * 255),
                    int(rgba[1] * 255),
                    int(rgba[2] * 255)
                ]])
        else:
            colors = []
            for site in structure:
                elements = [
                    sp.as_dict()['element']
                    for sp, _ in site.species_and_occu.items()
                ]
                colors.append([
                    EL_COLORS[self.color_scheme][element]
                    for element in elements
                ])

        self.structure_graph.structure.add_site_property(
            'display_color', colors)

    def _generate_unit_cell(self):
        """
        Describe the unit cell as a convex hull of its 8 corners.

        The corners are centred on the origin.

        Returns:
            {'unit_cell': {'type': 'convex', 'points': ..., 'hull': ...}}
        """
        frac_vertices = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
                         [1, 1, 0], [0, 1, 1], [1, 0, 1], [1, 1, 1]]
        cart_vertices = [
            list(
                self.lattice.get_cartesian_coords(
                    np.subtract(vert, (0.5, 0.5, 0.5))))
            for vert in frac_vertices
        ]

        tri = Delaunay(cart_vertices)

        unit_cell = {
            'type': 'convex',
            'points': cart_vertices,
            'hull': tri.convex_hull.tolist(),
        }

        return {'unit_cell': unit_cell}

    def _generate_polyhedra(self):
        """
        Build one coordination polyhedron per eligible site in the home cell.

        Each home-cell site's bonded atoms are the vertices of its
        polyhedron. A home-cell site used as a vertex is no longer considered
        as a centre, to avoid intersecting polyhedra. Must run after
        _generate_atoms.

        Returns:
            {'polyhedra': {'polyhedra_list': [...], 'polyhedra_types': [...]}}

        Raises:
            ValueError: if bonded_atoms_outside_unit_cell is False.
        """
        if not self.bonded_atoms_outside_unit_cell:
            raise ValueError(
                "All bonds from unit cell must be drawn to be able to reliably "
                "draw co-ordination polyhedra, please set "
                "bonded_atoms_outside_unit_cell to True.")

        # this just creates a list of all bonded atoms from each site
        # (this isn't used for the bonding itself, otherwise we'd be
        # double counting bonds, but for polyhedra, in principle each
        # site has its own polyhedra defined)
        potential_polyhedra = {i: [] for i in range(len(self.structure_graph))}
        for from_site_idx, from_image, to_site_idx, to_image in self._bonds:
            if from_image == (0, 0, 0):
                potential_polyhedra[from_site_idx].append(
                    (to_site_idx, to_image))
            if to_image == (0, 0, 0):
                potential_polyhedra[to_site_idx].append(
                    (from_site_idx, from_image))

        # sort sites for which polyhedra we should prioritize: we don't want
        # polyhedra to intersect each other
        # TODO: placeholder
        sorted_sites_idxs = list(range(len(self.structure_graph)))

        # now we actually store the polyhedra we want!
        # a single polyhedron is simply a list of atom indexes that form
        # its vertices; a convex hull algorithm is necessary to actually
        # construct the faces
        polyhedra = []
        polyhedra_types = set()
        for site_idx in sorted_sites_idxs:
            if site_idx not in potential_polyhedra:
                continue

            polyhedron_points = potential_polyhedra[site_idx]

            # remove the polyhedron's vertices from the list of potential
            # polyhedra centers
            for vert_site, vert_image in polyhedron_points:
                if vert_image == (0, 0, 0) and vert_site in potential_polyhedra:
                    del potential_polyhedra[vert_site]

            # and look up the actual positions in the array of atoms we're
            # drawing
            points_cart = [self._atoms_cart[point] for point in polyhedron_points]
            points_idx = [self._atom_indexes[point] for point in polyhedron_points]

            # calculate the hull; Delaunay can fail (e.g. too few points)
            try:
                tri = Delaunay(points_cart)
            except Exception as exc:
                warnings.warn(
                    "Could not construct polyhedron centred on site {}: {}".format(
                        site_idx, exc))
                continue

            # a pretty name, helps with filtering too
            centre_site = self.structure_graph.structure[site_idx]
            species = ", ".join(
                map(str, list(centre_site.species_and_occu.keys())))
            polyhedron_type = '{}-centered'.format(species)
            polyhedra_types.add(polyhedron_type)

            polyhedra.append({
                'type': 'convex',
                'points_idx': points_idx,
                'points': points_cart,
                'hull': tri.convex_hull.tolist(),
                'name': polyhedron_type,
                'center': site_idx
            })

        return {
            'polyhedra': {
                'polyhedra_list': polyhedra,
                'polyhedra_types': list(polyhedra_types)
            }
        }
class StructureMoleculeComponent(MPComponent):
    """Dash component rendering a Structure/Molecule (or a StructureGraph /
    MoleculeGraph) as an interactive 3D scene.

    Bonding is computed with a pymatgen NearNeighbors strategy, and sites are
    colored and sized according to user-selectable display options. Graph
    generation options, display options, legend data and the graph itself are
    kept in Dash stores so callbacks can update the scene.
    """

    # every direct NearNeighbors subclass, keyed by class name, so a strategy
    # can be chosen by string (e.g. from a dropdown)
    available_bonding_strategies = {
        subclass.__name__: subclass for subclass in NearNeighbors.__subclasses__()
    }

    available_radius_strategies = (
        "atomic",
        "specified_or_average_ionic",
        "covalent",
        "van_der_waals",
        "atomic_calculated",
        "uniform",
    )

    default_scene_settings = {"cylinderScale": 0.1}

    def __init__(
        self,
        struct_or_mol=None,
        id=None,
        origin_component=None,
        scene_additions=None,
        bonding_strategy="MinimumDistanceNN",
        bonding_strategy_kwargs=None,
        color_scheme="VESTA",
        color_scale=None,
        radius_strategy="uniform",
        radius_scale=1.0,
        draw_image_atoms=True,
        bonded_sites_outside_unit_cell=False,
        hide_incomplete_bonds=False,
        show_compass=False,
        scene_settings=None,
        **kwargs,
    ):
        """Create the component and its initial stores/scene.

        :param struct_or_mol: Structure, Molecule, StructureGraph or
            MoleculeGraph to display; if falsy an empty scene is shown
        :param scene_additions: extra scene contents appended to the scene
        :param bonding_strategy: name of a NearNeighbors subclass
        :param bonding_strategy_kwargs: kwargs for that strategy
        :param scene_settings: overrides merged into default_scene_settings
        Remaining keyword arguments are display options (see
        get_scene_and_legend) or are forwarded to MPComponent.
        """
        super().__init__(
            id=id, contents=struct_or_mol, origin_component=origin_component, **kwargs
        )

        self.default_title = "Crystal Toolkit"

        self.initial_scene_settings = (
            StructureMoleculeComponent.default_scene_settings.copy()
        )
        if scene_settings:
            self.initial_scene_settings.update(scene_settings)

        self.create_store("scene_settings", initial_data=self.initial_scene_settings)

        self.initial_graph_generation_options = {
            "bonding_strategy": bonding_strategy,
            "bonding_strategy_kwargs": bonding_strategy_kwargs,
        }
        self.create_store(
            "graph_generation_options",
            initial_data=self.initial_graph_generation_options,
        )

        self.initial_display_options = {
            "color_scheme": color_scheme,
            "color_scale": color_scale,
            "radius_strategy": radius_strategy,
            "radius_scale": radius_scale,
            "draw_image_atoms": draw_image_atoms,
            "bonded_sites_outside_unit_cell": bonded_sites_outside_unit_cell,
            "hide_incomplete_bonds": hide_incomplete_bonds,
            "show_compass": show_compass,
        }
        self.create_store("display_options", initial_data=self.initial_display_options)

        if scene_additions:
            self.initial_scene_additions = Scene(
                name="scene_additions", contents=scene_additions
            )
        else:
            self.initial_scene_additions = Scene(name="scene_additions")
        self.create_store(
            "scene_additions", initial_data=self.initial_scene_additions.to_json()
        )

        if struct_or_mol:
            # graph is cached explicitly, this isn't necessary but is an
            # optimization so that graph is only re-generated if bonding
            # algorithm changes
            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=bonding_strategy,
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )
            scene, legend = self.get_scene_and_legend(
                graph,
                name=self.id(),
                scene_additions=self.initial_scene_additions,
                **self.initial_display_options,
            )
            if hasattr(struct_or_mol, "lattice"):
                self._lattice = struct_or_mol.lattice
        else:
            # component could be initialized without a structure, in which case
            # an empty scene should be displayed
            graph = None
            scene, legend = self.get_scene_and_legend(
                None,
                name=self.id(),
                scene_additions=self.initial_scene_additions,
                **self.initial_display_options,
            )

        self.initial_legend = legend
        self.create_store("legend_data", initial_data=self.initial_legend)

        self.initial_scene_data = scene.to_json()

        self.initial_graph = graph
        self.create_store("graph", initial_data=self.to_data(graph))

    def generate_callbacks(self, app, cache):
        """Register this component's Dash callbacks on ``app``."""

        @app.callback(
            Output(self.id("graph"), "data"),
            [
                Input(self.id("graph_generation_options"), "data"),
                Input(self.id("unit-cell-choice"), "value"),
                Input(self.id(), "data"),
            ],
        )
        def update_graph(graph_generation_options, unit_cell_choice, struct_or_mol):
            # rebuild the bonding graph when the input, the chosen cell or the
            # bonding options change
            if not struct_or_mol:
                raise PreventUpdate
            struct_or_mol = self.from_data(struct_or_mol)
            graph_generation_options = self.from_data(graph_generation_options)

            if isinstance(struct_or_mol, Structure):
                if unit_cell_choice == "primitive":
                    struct_or_mol = struct_or_mol.get_primitive_structure()
                elif unit_cell_choice == "conventional":
                    sga = SpacegroupAnalyzer(struct_or_mol)
                    struct_or_mol = sga.get_conventional_standard_structure()
                elif unit_cell_choice == "reduced":
                    struct_or_mol = struct_or_mol.get_reduced_structure()

            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=graph_generation_options["bonding_strategy"],
                bonding_strategy_kwargs=graph_generation_options[
                    "bonding_strategy_kwargs"
                ],
            )

            self.logger.debug("Constructed graph")

            return self.to_data(graph)

        @app.callback(
            Output(self.id("scene"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
            ],
        )
        def update_scene(graph, display_options):
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            scene, legend = self.get_scene_and_legend(graph, **display_options)
            return scene.to_json()

        @app.callback(
            Output(self.id("legend_data"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
            ],
        )
        def update_legend(graph, display_options):
            # TODO: more cleanly split legend from scene generation
            # no graph yet (component created without a structure): nothing to
            # describe, and _get_struct_or_mol would raise
            if not graph:
                raise PreventUpdate
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            struct_or_mol = self._get_struct_or_mol(graph)
            site_prop_types = self._analyze_site_props(struct_or_mol)
            colors, legend = self._get_display_colors_and_legend_for_sites(
                struct_or_mol,
                site_prop_types,
                color_scheme=display_options.get("color_scheme", None),
                color_scale=display_options.get("color_scale", None),
            )
            return self.to_data(legend)

        @app.callback(
            Output(self.id("color-scheme"), "options"),
            [Input(self.id("graph"), "data")],
        )
        def update_color_options(graph):
            # offer element-based schemes plus any scalar/categorical site property
            if not graph:
                raise PreventUpdate
            options = [
                {"label": "Jmol", "value": "Jmol"},
                {"label": "VESTA", "value": "VESTA"},
                {"label": "Colorblind-friendly", "value": "colorblind_friendly"},
            ]
            graph = self.from_data(graph)
            struct_or_mol = self._get_struct_or_mol(graph)
            site_props = self._analyze_site_props(struct_or_mol)
            for site_prop_type in ("scalar", "categorical"):
                for prop in site_props.get(site_prop_type, []):
                    options.append({"label": f"Site property: {prop}", "value": prop})
            return options

        @app.callback(
            Output(self.id("display_options"), "data"),
            [
                Input(self.id("color-scheme"), "value"),
                Input(self.id("radius_strategy"), "value"),
                Input(self.id("draw_options"), "value"),
            ],
            [State(self.id("display_options"), "data")],
        )
        def update_display_options(
            color_scheme, radius_strategy, draw_options, display_options
        ):
            display_options = self.from_data(display_options)
            display_options.update(
                {
                    "color_scheme": color_scheme,
                    "radius_strategy": radius_strategy,
                    "draw_image_atoms": "draw_image_atoms" in draw_options,
                    "bonded_sites_outside_unit_cell": "bonded_sites_outside_unit_cell"
                    in draw_options,
                    "hide_incomplete_bonds": "hide_incomplete_bonds" in draw_options,
                }
            )

            if display_options == self.initial_display_options:
                raise PreventUpdate

            self.logger.debug("Display options updated")

            return self.to_data(display_options)

        @app.callback(
            Output(self.id("scene"), "downloadRequest"),
            [Input(self.id("screenshot_button"), "n_clicks")],
            [State(self.id("scene"), "downloadRequest"), State(self.id(), "data")],
        )
        def screenshot_callback(n_clicks, current_requests, struct_or_mol):
            if n_clicks is None:
                raise PreventUpdate
            struct_or_mol = self.from_data(struct_or_mol)
            # TODO: this will break if store is structure/molecule graph ...
            formula = struct_or_mol.composition.reduced_formula
            if hasattr(struct_or_mol, "get_space_group_info"):
                spgrp = struct_or_mol.get_space_group_info()[0]
            else:
                spgrp = ""
            request_filename = "{}-{}-crystal-toolkit.png".format(formula, spgrp)
            if not current_requests:
                n_requests = 1
            else:
                n_requests = current_requests["n_requests"] + 1
            return {
                "n_requests": n_requests,
                "filename": request_filename,
                "filetype": "png",
            }

        @app.callback(
            Output(self.id("scene"), "toggleVisibility"),
            [Input(self.id("hide-show"), "value")],
            [State(self.id("hide-show"), "options")],
        )
        def update_visibility(values, options):
            return {opt["value"]: (opt["value"] in values) for opt in options}

        @app.callback(
            [
                Output(self.id("legend_container"), "children"),
                Output(self.id("title_container"), "children"),
            ],
            [Input(self.id("legend_data"), "data")],
        )
        def update_legend_and_title(legend):
            # distinct name from update_legend above (previously both were
            # called update_legend, the second shadowing the first)
            legend = self.from_data(legend)
            if legend == self.initial_legend:
                raise PreventUpdate
            return self._make_legend(legend), self._make_title(legend)

        @app.callback(
            Output(self.id("graph_generation_options"), "data"),
            [
                Input(self.id("bonding_algorithm"), "value"),
                Input(self.id("bonding_algorithm_custom_cutoffs"), "data"),
            ],
        )
        def update_structure_viewer_data(bonding_algorithm, custom_cutoffs_rows):
            graph_generation_options = {
                "bonding_strategy": bonding_algorithm,
                "bonding_strategy_kwargs": None,
            }
            if graph_generation_options == self.initial_graph_generation_options:
                raise PreventUpdate
            if bonding_algorithm == "CutOffDictNN":
                # this is not the format CutOffDictNN expects (since that is not JSON
                # serializable), so we store as a list of tuples instead
                # TODO: make CutOffDictNN args JSON serializable
                custom_cutoffs = [
                    (row["A"], row["B"], float(row["A—B"]))
                    # the data table may not have been populated yet
                    for row in (custom_cutoffs_rows or [])
                ]
                graph_generation_options["bonding_strategy_kwargs"] = {
                    "cut_off_dict": custom_cutoffs
                }
            return self.to_data(graph_generation_options)

        @app.callback(
            [
                Output(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Output(self.id("bonding_algorithm_custom_cutoffs_container"), "style"),
            ],
            [Input(self.id("bonding_algorithm"), "value")],
            [State(self.id("graph"), "data")],
        )
        def update_custom_bond_options(bonding_algorithm, graph):
            if not graph:
                raise PreventUpdate

            if bonding_algorithm == "CutOffDictNN":
                style = {}
            else:
                style = {"display": "none"}

            graph = self.from_data(graph)
            struct_or_mol = self._get_struct_or_mol(graph)
            # can't use type_of_specie because it doesn't work with disordered structures
            species = set(
                map(
                    str,
                    chain.from_iterable(
                        [list(c.keys()) for c in struct_or_mol.species_and_occu]
                    ),
                )
            )
            # sorted so the table rows have a stable order between calls
            rows = [
                {"A": combination[0], "B": combination[1], "A—B": 0}
                for combination in combinations_with_replacement(sorted(species), 2)
            ]
            return rows, style

    def _make_legend(self, legend):
        """Build the color legend: one static button per legend color."""
        if legend is None or (not legend.get("colors", None)):
            return html.Div(id=self.id("legend"))

        def get_font_color(hex_code):
            # ensures contrasting font color for background color
            c = tuple(int(hex_code[1:][i : i + 2], 16) for i in (0, 2, 4))
            if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5:
                return "#000000"
            return "#ffffff"

        try:
            formula = Composition.from_dict(legend["composition"]).reduced_formula
        except Exception:
            # TODO: fix for Dummy Specie compositions
            formula = "Unknown"

        # order legend entries by where their label appears in the formula;
        # str() because categorical labels are not guaranteed to be strings
        legend_colors = OrderedDict(
            sorted(legend["colors"].items(), key=lambda x: formula.find(str(x[1])))
        )

        legend_elements = [
            Button(
                html.Span(
                    name, className="icon", style={"color": get_font_color(color)}
                ),
                kind="static",
                style={"background-color": color},
            )
            for color, name in legend_colors.items()
        ]

        return Field(
            [Control(el, style={"margin-right": "0.2rem"}) for el in legend_elements],
            id=self.id("legend"),
            grouped=True,
        )

    def _make_title(self, legend):
        """Build the H1 title; the reduced formula with numeric parts subscripted."""
        if not legend or (not legend.get("composition", None)):
            return H1(self.default_title, id=self.id("title"))

        composition = legend["composition"]
        # fallback so formula_components is always bound, even if the
        # composition is not a dict
        formula_components = [str(composition)]
        if isinstance(composition, dict):
            # TODO: make Composition handle DummySpecie
            try:
                formula = Composition.from_dict(composition).reduced_formula
                formula_parts = re.findall(r"[^\d_]+|\d+", formula)
                formula_components = [
                    html.Sub(part) if part.isnumeric() else html.Span(part)
                    for part in formula_parts
                ]
            except Exception:
                formula_components = list(composition.keys())

        return H1(
            formula_components, id=self.id("title"), style={"display": "inline-block"}
        )

    @property
    def all_layouts(self):
        """Return a dict of the component's sub-layouts: struct, screenshot,
        options, title and legend."""

        struct_layout = html.Div(
            Simple3DSceneComponent(
                id=self.id("scene"),
                data=self.initial_scene_data,
                settings=self.initial_scene_settings,
            ),
            style={
                "width": "100%",
                "height": "100%",
                "overflow": "hidden",
                "margin": "0 auto",
            },
        )

        screenshot_layout = html.Div(
            [
                Button(
                    [Icon(), html.Span(), "Download Image"],
                    kind="primary",
                    id=self.id("screenshot_button"),
                )
            ],
            # TODO: change to "bottom" when dropdown included
            style={"vertical-align": "top", "display": "inline-block"},
        )

        title_layout = html.Div(
            self._make_title(self.initial_legend), id=self.id("title_container")
        )

        legend_layout = html.Div(
            self._make_legend(self.initial_legend), id=self.id("legend_container")
        )

        nn_mapping = {
            "CrystalNN": "CrystalNN",
            "Custom Bonds": "CutOffDictNN",
            "Jmol Bonding": "JmolNN",
            "Minimum Distance (10% tolerance)": "MinimumDistanceNN",
            "O'Keeffe's Algorithm": "MinimumOKeeffeNN",
            "Hoppe's ECoN Algorithm": "EconNN",
            "Brunner's Reciprocal Algorithm": "BrunnerNN_reciprocal",
        }

        bonding_algorithm = dcc.Dropdown(
            options=[{"label": k, "value": v} for k, v in nn_mapping.items()],
            value="CrystalNN",
            id=self.id("bonding_algorithm"),
        )

        bonding_algorithm_custom_cutoffs = html.Div(
            [
                html.Br(),
                dt.DataTable(
                    columns=[
                        {"name": "A", "id": "A"},
                        {"name": "B", "id": "B"},
                        {"name": "A—B /Å", "id": "A—B"},
                    ],
                    editable=True,
                    id=self.id("bonding_algorithm_custom_cutoffs"),
                ),
                html.Br(),
            ],
            id=self.id("bonding_algorithm_custom_cutoffs_container"),
            style={"display": "none"},
        )

        options_layout = Field(
            [
                # TODO: hide if molecule
                html.Label("Change unit cell:", className="mpc-label"),
                html.Div(
                    dcc.RadioItems(
                        options=[
                            {"label": "Input cell", "value": "input"},
                            {"label": "Primitive cell", "value": "primitive"},
                            {"label": "Conventional cell", "value": "conventional"},
                            {"label": "Reduced cell", "value": "reduced"},
                        ],
                        value="input",
                        id=self.id("unit-cell-choice"),
                        labelStyle={"display": "block"},
                        inputClassName="mpc-radio",
                    ),
                    className="mpc-control",
                ),
                html.Div(
                    [
                        html.Label("Change bonding algorithm: ", className="mpc-label"),
                        bonding_algorithm,
                        bonding_algorithm_custom_cutoffs,
                    ]
                ),
                html.Label("Change color scheme:", className="mpc-label"),
                html.Div(
                    dcc.Dropdown(
                        options=[
                            {"label": "VESTA", "value": "VESTA"},
                            {"label": "Jmol", "value": "Jmol"},
                        ],
                        value=self.initial_display_options["color_scheme"],
                        clearable=False,
                        id=self.id("color-scheme"),
                    ),
                    className="mpc-control",
                ),
                html.Label("Change atomic radii:", className="mpc-label"),
                html.Div(
                    dcc.Dropdown(
                        options=[
                            {"label": "Ionic", "value": "specified_or_average_ionic"},
                            {"label": "Covalent", "value": "covalent"},
                            {"label": "Van der Waals", "value": "van_der_waals"},
                            {"label": "Uniform (0.5Å)", "value": "uniform"},
                        ],
                        value=self.initial_display_options["radius_strategy"],
                        clearable=False,
                        id=self.id("radius_strategy"),
                    ),
                    className="mpc-control",
                ),
                html.Label("Draw options:", className="mpc-label"),
                html.Div(
                    [
                        dcc.Checklist(
                            options=[
                                {
                                    "label": "Draw repeats of atoms on periodic boundaries",
                                    "value": "draw_image_atoms",
                                },
                                {
                                    "label": "Draw atoms outside unit cell bonded to "
                                    "atoms within unit cell",
                                    "value": "bonded_sites_outside_unit_cell",
                                },
                                {
                                    "label": "Hide bonds where destination atoms are not shown",
                                    "value": "hide_incomplete_bonds",
                                },
                            ],
                            value=[
                                opt
                                for opt in (
                                    "draw_image_atoms",
                                    "bonded_sites_outside_unit_cell",
                                    "hide_incomplete_bonds",
                                )
                                if self.initial_display_options[opt]
                            ],
                            labelStyle={"display": "block"},
                            inputClassName="mpc-radio",
                            id=self.id("draw_options"),
                        )
                    ]
                ),
                html.Label("Hide/show:", className="mpc-label"),
                html.Div(
                    [
                        dcc.Checklist(
                            options=[
                                {"label": "Atoms", "value": "atoms"},
                                {"label": "Bonds", "value": "bonds"},
                                {"label": "Unit cell", "value": "unit_cell"},
                                {"label": "Polyhedra", "value": "polyhedra"},
                            ],
                            value=["atoms", "bonds", "unit_cell", "polyhedra"],
                            labelStyle={"display": "block"},
                            inputClassName="mpc-radio",
                            id=self.id("hide-show"),
                        )
                    ],
                    className="mpc-control",
                ),
            ]
        )

        return {
            "struct": struct_layout,
            "screenshot": screenshot_layout,
            "options": options_layout,
            "title": title_layout,
            "legend": legend_layout,
        }

    @property
    def standard_layout(self):
        """Full-viewport layout containing only the 3D scene."""
        return html.Div(
            self.all_layouts["struct"], style={"width": "100vw", "height": "100vh"}
        )

    @staticmethod
    def _preprocess_input_to_graph(
        input: Union[Structure, StructureGraph, Molecule, MoleculeGraph],
        bonding_strategy: str = "CrystalNN",
        bonding_strategy_kwargs: Optional[Dict] = None,
    ) -> Union[StructureGraph, MoleculeGraph]:
        """Turn the input into a StructureGraph/MoleculeGraph with bonds.

        Structures have their fractional coordinates wrapped into [0, 1) and,
        if disordered, are bonded with CutOffDictNN. Graphs are returned as-is.

        :param input: object to convert
        :param bonding_strategy: name of a key in available_bonding_strategies
        :param bonding_strategy_kwargs: kwargs for the strategy; for
            CutOffDictNN, "cut_off_dict" may be a list of (A, B, cutoff)
            triples (the JSON-friendly form) or a {(A, B): cutoff} dict.
            This dict is never mutated.
        :return: the bonded graph, or an empty graph if bonding failed
        :raises ValueError: if bonding_strategy is not a known strategy name
        """

        if isinstance(input, Structure):

            # ensure fractional co-ordinates are normalized to be in [0,1)
            # (this is actually not guaranteed by Structure)
            try:
                input = input.as_dict(verbosity=0)
            except TypeError:
                # TODO: remove this, necessary for Slab(?), some structure subclasses don't have verbosity
                input = input.as_dict()
            for site in input["sites"]:
                site["abc"] = np.mod(site["abc"], 1)
            input = Structure.from_dict(input)

            if not input.is_ordered and bonding_strategy != "CutOffDictNN":
                # calculating bonds in disordered structures is currently very flaky;
                # kwargs intended for another strategy would be rejected by
                # CutOffDictNN's constructor, so drop them as well
                bonding_strategy = "CutOffDictNN"
                bonding_strategy_kwargs = None

        # we assume most uses of this class will give a structure as an input argument,
        # meaning we have to calculate the graph for bonding information, however if
        # the graph is already known and supplied, we will use that
        if isinstance(input, (StructureGraph, MoleculeGraph)):
            return input

        available = StructureMoleculeComponent.available_bonding_strategies
        if bonding_strategy not in available:
            raise ValueError(
                "Bonding strategy not supported. Please supply a name "
                "of a NearNeighbor subclass, choose from: {}".format(
                    ", ".join(available.keys())
                )
            )

        # copy so the caller's dict (e.g. stored graph generation options)
        # is not modified
        strategy_kwargs = dict(bonding_strategy_kwargs or {})
        if bonding_strategy == "CutOffDictNN" and "cut_off_dict" in strategy_kwargs:
            cut_off_dict = strategy_kwargs["cut_off_dict"]
            if not isinstance(cut_off_dict, dict):
                # TODO: remove this hack by making args properly JSON serializable
                strategy_kwargs["cut_off_dict"] = {
                    (x[0], x[1]): x[2] for x in cut_off_dict
                }

        strategy = available[bonding_strategy](**strategy_kwargs)

        try:
            if isinstance(input, Structure):
                graph = StructureGraph.with_local_env_strategy(input, strategy)
            else:
                graph = MoleculeGraph.with_local_env_strategy(input, strategy)
        except Exception:
            # for some reason computing bonds failed, so let's not have any bonds(!)
            if isinstance(input, Structure):
                graph = StructureGraph.with_empty_graph(input)
            else:
                graph = MoleculeGraph.with_empty_graph(input)

        return graph

    @staticmethod
    def _analyze_site_props(struct_or_mol):
        """Classify site properties by the type of their first value.

        :return: dict mapping "scalar" (int/float), "vector" (3-list),
            "matrix" (3x3 nested list) and "categorical" (str) to lists of
            property names; keys with no matching properties are absent.
        """
        # store list of site props that are vectors, so these can be displayed as arrows
        # (implicitly assumes all site props for a given key are same type)
        site_prop_names = defaultdict(list)
        for name, props in struct_or_mol.site_properties.items():
            if len(props) == 0:
                continue
            first = props[0]
            if isinstance(first, (float, int)):
                site_prop_names["scalar"].append(name)
            elif isinstance(first, list) and len(first) == 3:
                if isinstance(first[0], list) and len(first[0]) == 3:
                    site_prop_names["matrix"].append(name)
                else:
                    site_prop_names["vector"].append(name)
            elif isinstance(first, str):
                site_prop_names["categorical"].append(name)

        return dict(site_prop_names)

    @staticmethod
    def _get_origin(struct_or_mol):
        """Return the point the scene is centred on.

        Cartesian centre of the unit cell for a Structure, mean Cartesian
        position for a Molecule, (0, 0, 0) otherwise.
        """
        if isinstance(struct_or_mol, Structure):
            return struct_or_mol.lattice.get_cartesian_coords((0.5, 0.5, 0.5))
        if isinstance(struct_or_mol, Molecule):
            return np.average(struct_or_mol.cart_coords, axis=0)
        return (0, 0, 0)

    @staticmethod
    def _get_struct_or_mol(graph) -> Union[Structure, Molecule]:
        """Return the Structure/Molecule wrapped by a graph.

        :raises ValueError: if graph is neither a StructureGraph nor a
            MoleculeGraph (including None)
        """
        if isinstance(graph, StructureGraph):
            return graph.structure
        if isinstance(graph, MoleculeGraph):
            return graph.molecule
        raise ValueError(
            "Expected a StructureGraph or MoleculeGraph, got {}".format(
                type(graph).__name__
            )
        )

    @staticmethod
    def _compass_from_lattice(
        lattice,
        origin=(0, 0, 0),
        scale=0.7,
        offset=0.15,
        compass_style="corner",
        **kwargs,
    ):
        """Get the display components of the compass.

        TODO: add along lattice. ``compass_style`` is currently unused.

        :param lattice: the pymatgen Lattice object that contains the primitive
            lattice vectors
        :param origin: the reference position to place the compass
        :param scale: scale all the geometric objects that make up the compass;
            the lattice vectors are normalized before the scaling so everything
            should be the same size
        :param offset: shift the compass from the origin by a ratio of the
            diagonal of the cell relative the size
        :param kwargs: forwarded to each Arrows object
        :return: a Scene named "compass" containing red/blue/green arrows
            along a/b/c and a black sphere at the compass origin
        """
        o = -np.array(origin)
        o = o - offset * (lattice.matrix[0] + lattice.matrix[1] + lattice.matrix[2])

        arrows = []
        for vector, color in zip(lattice.matrix, ("red", "blue", "green")):
            unit = vector / np.linalg.norm(vector) * scale
            arrows.append(
                Arrows(
                    [[o, o + unit]],
                    color=color,
                    radius=0.7 * scale,
                    headLength=2.3 * scale,
                    headWidth=1.4 * scale,
                    **kwargs,
                )
            )

        o_sphere = Spheres(positions=[o], color="black", radius=0.1 * scale)

        return Scene(name="compass", contents=arrows + [o_sphere])

    @staticmethod
    def _get_display_colors_and_legend_for_sites(
        struct_or_mol, site_prop_types, color_scheme="Jmol", color_scale=None
    ) -> Tuple[List[List[str]], Dict]:
        """Compute per-site display colors and a color legend.

        Note this returns a list of lists of strings since each site might
        have multiple colors defined if the site is disordered.

        The legend is a dictionary with a "composition" entry and a "colors"
        entry whose keys are hex colors and values are corresponding element
        names or values, depending on the color scheme chosen.

        :param site_prop_types: output of _analyze_site_props
        :param color_scheme: "VESTA", "Jmol", "colorblind_friendly" or the name
            of a scalar/categorical site property; unknown values fall back to
            "Jmol" with a warning
        :param color_scale: matplotlib colormap name for scalar properties
            (default "coolwarm")
        :raises ValueError: for a site-property scheme on a disordered input
        """
        # TODO: check to see if there is a bug here due to Composition being unordered(?)
        legend = {"composition": struct_or_mol.composition.as_dict(), "colors": {}}

        # don't calculate color if one is explicitly supplied
        if "display_color" in struct_or_mol.site_properties:
            # don't know what the color legend (meaning) is, so return empty legend
            return (struct_or_mol.site_properties["display_color"], legend)

        def get_color_hex(x):
            return "#{:02x}{:02x}{:02x}".format(*x)

        scalar_props = site_prop_types.get("scalar", [])
        categorical_props = site_prop_types.get("categorical", [])

        allowed_schemes = (
            ["VESTA", "Jmol", "colorblind_friendly"] + scalar_props + categorical_props
        )
        default_scheme = "Jmol"
        if color_scheme not in allowed_schemes:
            warnings.warn(
                f"Color scheme {color_scheme} not available, falling back to {default_scheme}."
            )
            color_scheme = default_scheme

        # after the fallback above, any other scheme is a scalar or
        # categorical site property
        if color_scheme not in ("VESTA", "Jmol", "colorblind_friendly"):
            if not struct_or_mol.is_ordered:
                raise ValueError(
                    "Can only use VESTA, Jmol or colorblind_friendly color "
                    "schemes for disordered structures or molecules, color "
                    "schemes based on site properties are ill-defined."
                )

        if color_scheme in ("VESTA", "Jmol"):

            # TODO: define fallback color as global variable
            # TODO: maybe fallback categorical based on letter, for DummySpecie?

            colors = []
            for site in struct_or_mol:
                elements = [sp.as_dict()["element"] for sp, _ in site.species.items()]
                site_colors = [
                    get_color_hex(EL_COLORS[color_scheme].get(element, [0, 0, 0]))
                    for element in elements
                ]
                colors.append(site_colors)

                # construct legend
                label = unicodeify_species(site.species_string)
                for element, color in zip(elements, site_colors):
                    if color in legend["colors"] and legend["colors"][color] != label:
                        # TODO: mixed valence, improve this
                        legend["colors"][color] = f"{element}ˣ"
                    else:
                        legend["colors"][color] = label

        elif color_scheme == "colorblind_friendly":

            labels = [site.species_string for site in struct_or_mol]

            # thanks to https://doi.org/10.1038/nmeth.1618
            palette = [
                [0, 0, 0],  # 0, black
                [230, 159, 0],  # 1, orange
                [86, 180, 233],  # 2, sky blue
                [0, 158, 115],  # 3, bluish green
                [240, 228, 66],  # 4, yellow
                [0, 114, 178],  # 5, blue
                [213, 94, 0],  # 6, vermillion
                [204, 121, 167],  # 7, reddish purple
                [255, 255, 255],  # 8, white
            ]

            # similar to CPK
            preferred_colors = {
                "O": 6,
                "N": 2,
                "C": 0,
                "H": 8,
                "F": 3,
                "Cl": 3,
                "Fe": 1,
                "Br": 7,
                "I": 7,
                "P": 1,
                "S": 4,
            }

            unique_labels = sorted(set(labels))
            if len(unique_labels) > len(palette):
                warnings.warn(
                    "Too many distinct types of site to use a color-blind friendly color scheme."
                )

            # first give labels their preferred palette slot, keyed on the
            # leading element symbol of the label (e.g. "Fe" for "Fe2+"), as
            # long as that slot has not already been taken
            label_to_idx = {}
            for label in unique_labels:
                match = re.match(r"[A-Z][a-z]?", label)
                idx = preferred_colors.get(match.group(0)) if match else None
                if idx is not None and idx not in label_to_idx.values():
                    label_to_idx[label] = idx

            # then hand out the remaining slots in order; once the palette is
            # exhausted colors are reused cyclically (warned about above)
            free_idxs = [
                i for i in range(len(palette)) if i not in label_to_idx.values()
            ]
            n_reused = 0
            for label in unique_labels:
                if label in label_to_idx:
                    continue
                if free_idxs:
                    label_to_idx[label] = free_idxs.pop(0)
                else:
                    label_to_idx[label] = n_reused % len(palette)
                    n_reused += 1

            colors = [
                [get_color_hex(palette[label_to_idx[label]])] for label in labels
            ]

            # construct legend
            for label in unique_labels:
                color = get_color_hex(palette[label_to_idx[label]])
                legend["colors"][color] = unicodeify_species(label)

        elif color_scheme in scalar_props:

            props = np.array(struct_or_mol.site_properties[color_scheme])

            # by default, use blue-grey-red color scheme,
            # so that zero is ~ grey, and positive/negative
            # are red/blue
            color_scale = color_scale or "coolwarm"
            # try to keep color scheme symmetric around 0
            prop_max = max([abs(min(props)), max(props)])
            if prop_max == 0:
                # all values are zero: avoid a zero-width range (division by
                # zero); every site then maps to the middle of the color scale
                prop_max = 1
            prop_min = -prop_max

            cmap = get_cmap(color_scale)

            def normalize(x):
                # map [prop_min, prop_max] to [0, 1], as expected by cmap
                return (x - prop_min) / (prop_max - prop_min)

            def get_color_cmap(x):
                return [int(c * 255) for c in cmap(x)[0:3]]

            colors = [[get_color_hex(get_color_cmap(x))] for x in normalize(props)]

            # construct legend
            rounded_props = sorted({np.around(p, decimals=1) for p in props})
            for prop in rounded_props:
                c = get_color_hex(get_color_cmap(normalize(prop)))
                legend["colors"][c] = "{:.1f}".format(prop)

        elif color_scheme in categorical_props:

            props = np.array(struct_or_mol.site_properties[color_scheme])

            palette = [get_color_hex(c) for c in Set1_9.colors]

            le = LabelEncoder()
            le.fit(props)
            transformed_props = le.transform(props)

            # if we have more categories than available colors,
            # arbitrarily group some categories together
            if len(le.classes_) > len(palette):
                warnings.warn(
                    "Too many categories for a complete categorical "
                    "color scheme."
                )
            transformed_props = [
                p if p < len(palette) else -1 for p in transformed_props
            ]

            colors = [[palette[p]] for p in transformed_props]

            for category, p in zip(props, transformed_props):
                legend["colors"][palette[p]] = category

        return colors, legend

    @staticmethod
    def _get_display_radii_for_sites(
        struct_or_mol, radius_strategy="specified_or_average_ionic", radius_scale=1.0
    ) -> List[List[float]]:
        """Compute per-site display radii.

        Note this returns a list of lists of floats since each site might have
        multiple radii defined if the site is disordered. Species whose radius
        is unknown (falsy) for the strategy get 1.0 with a warning; every
        radius is multiplied by radius_scale.

        :raises ValueError: if radius_strategy is not in
            available_radius_strategies
        """

        # don't calculate radius if one is explicitly supplied
        if "display_radius" in struct_or_mol.site_properties:
            return struct_or_mol.site_properties["display_radius"]

        if (
            radius_strategy
            not in StructureMoleculeComponent.available_radius_strategies
        ):
            raise ValueError(
                "Unknown radius strategy {}, choose from: {}".format(
                    radius_strategy,
                    StructureMoleculeComponent.available_radius_strategies,
                )
            )

        radii = []

        for site in struct_or_mol:

            site_radii = []

            for sp in site.species.keys():

                radius = None

                if radius_strategy == "uniform":
                    radius = 0.5
                elif radius_strategy == "atomic":
                    radius = sp.atomic_radius
                elif (
                    radius_strategy == "specified_or_average_ionic"
                    and isinstance(sp, Specie)
                    and sp.oxi_state
                ):
                    radius = sp.ionic_radius
                elif radius_strategy == "specified_or_average_ionic":
                    radius = sp.average_ionic_radius
                elif radius_strategy == "covalent":
                    el = str(getattr(sp, "element", sp))
                    # unknown symbols (e.g. dummy species) fall through to the
                    # 1.0 default below instead of raising KeyError
                    radius = CovalentRadius.radius.get(el)
                elif radius_strategy == "van_der_waals":
                    radius = sp.van_der_waals_radius
                elif radius_strategy == "atomic_calculated":
                    radius = sp.atomic_radius_calculated

                if not radius:
                    warnings.warn(
                        "Radius unknown for {} and strategy {}, "
                        "setting to 1.0.".format(sp, radius_strategy)
                    )
                    radius = 1.0

                site_radii.append(radius * radius_scale)

            radii.append(site_radii)

        return radii

    @staticmethod
    def get_scene_and_legend(
        graph: Union[StructureGraph, MoleculeGraph],
        name="StructureMoleculeComponent",
        color_scheme="Jmol",
        color_scale=None,
        radius_strategy="specified_or_average_ionic",
        radius_scale=1.0,
        ellipsoid_site_prop=None,
        draw_image_atoms=True,
        bonded_sites_outside_unit_cell=True,
        hide_incomplete_bonds=False,
        explicitly_calculate_polyhedra_hull=False,
        scene_additions=None,
        show_compass=True,
    ) -> Tuple[Scene, Dict[str, str]]:
        """Build the Scene and color legend for a graph.

        Adds "display_radius" and "display_color" site properties to the
        graph's structure/molecule (mutating it) before calling
        graph.get_scene. ``ellipsoid_site_prop`` is currently unused.
        The compass is only drawn for periodic structures, since it is derived
        from the lattice.

        :return: (scene, legend); for graph=None an empty scene and {}
        """

        scene = Scene(name=name)

        if graph is None:
            return scene, {}

        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)
        site_prop_types = StructureMoleculeComponent._analyze_site_props(struct_or_mol)

        radii = StructureMoleculeComponent._get_display_radii_for_sites(
            struct_or_mol, radius_strategy=radius_strategy, radius_scale=radius_scale
        )
        colors, legend = StructureMoleculeComponent._get_display_colors_and_legend_for_sites(
            struct_or_mol,
            site_prop_types,
            color_scale=color_scale,
            color_scheme=color_scheme,
        )

        # TODO: add set_display_color option, set_display_radius, set_ellipsoid
        # call it "set_display_options" ?
        # sets legend too! display_legend
        struct_or_mol.add_site_property("display_radius", radii)
        struct_or_mol.add_site_property("display_color", colors)

        origin = StructureMoleculeComponent._get_origin(struct_or_mol)

        scene = graph.get_scene(
            draw_image_atoms=draw_image_atoms,
            bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
            hide_incomplete_edges=hide_incomplete_bonds,
            explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
            origin=origin,
        )

        scene.name = name

        # TODO: ...
        scene.origin = origin

        # molecules have no lattice, so a compass cannot be built for them
        if show_compass and isinstance(struct_or_mol, Structure):
            scene.contents.append(
                StructureMoleculeComponent._compass_from_lattice(
                    struct_or_mol.lattice, origin=origin
                )
            )

        if scene_additions:
            scene.contents.append(scene_additions)

        return scene, legend