class DownloadChooser(ipw.HBox): """Download chooser for structure download To be able to have the download button work no matter the widget's final environment, (as long as it supports JavaScript), the very helpful insight from the following page is used: https://stackoverflow.com/questions/2906582/how-to-create-an-html-button-that-acts-like-a-link """ chosen_format = traitlets.Tuple(traitlets.Unicode(), traitlets.Dict()) structure = traitlets.Instance(Structure, allow_none=True) _formats = [ ( "Crystallographic Information File v1.0 (.cif)", { "ext": ".cif", "adapter_format": "cif" }, ), ("Protein Data Bank (.pdb)", { "ext": ".pdb", "adapter_format": "pdb" }), ( "Crystallographic Information File v1.0 [via ASE] (.cif)", { "ext": ".cif", "adapter_format": "ase", "final_format": "cif" }, ), ( "Protein Data Bank [via ASE] (.pdb)", { "ext": ".pdb", "adapter_format": "ase", "final_format": "proteindatabank" }, ), ( "XMol XYZ File [via ASE] (.xyz)", { "ext": ".xyz", "adapter_format": "ase", "final_format": "xyz" }, ), ( "XCrySDen Structure File [via ASE] (.xsf)", { "ext": ".xsf", "adapter_format": "ase", "final_format": "xsf" }, ), ( "WIEN2k Structure File [via ASE] (.struct)", { "ext": ".struct", "adapter_format": "ase", "final_format": "struct" }, ), ( "VASP POSCAR File [via ASE]", { "ext": "", "adapter_format": "ase", "final_format": "vasp" }, ), ( "Quantum ESPRESSO File [via ASE] (.in)", { "ext": ".in", "adapter_format": "ase", "final_format": "espresso-in" }, ), # Not yet implemented: # ( # "Protein Data Bank, macromolecular CIF v1.1 (PDBx/mmCIF) (.cif)", # {"ext": "cif", "adapter_format": "pdbx_mmcif"}, # ), ] _download_button_format = """ <input type="button" class="jupyter-widgets jupyter-button widget-button" value="Download" title="Download structure" style="width:auto;" {disabled} onclick="var link = document.createElement('a'); link.href = 'data:charset={encoding};base64,{data}'; link.download = '{filename}'; document.body.appendChild(link); link.click(); document.body.removeChild(link);" /> """ def __init__(self, **kwargs): self.dropdown = ipw.Dropdown(options=("Select a format", {}), width="auto") self.download_button = ipw.HTML( self._download_button_format.format(disabled="disabled", encoding="", data="", filename="")) self.children = (self.dropdown, self.download_button) super().__init__(children=self.children, layout={"width": "auto"}) self.reset() self.dropdown.observe(self._update_download_button, names="value") @traitlets.observe("structure") def _on_change_structure(self, change: dict): """Update widget when a new structure is chosen""" if change["new"] is None: self.reset() else: self._update_options() self.unfreeze() def _update_options(self): """Update options according to chosen structure""" # Disordered structures not usable with ASE if "disorder" in self.structure.structure_features: options = sorted([ option for option in self._formats if option[1].get("adapter_format", "") != "ase" ]) options.insert(0, ("Select a format", {})) else: options = sorted(self._formats) options.insert(0, ("Select a format", {})) self.dropdown.options = options def _update_download_button(self, change: dict): """Update Download button with correct onclick value The whole parsing process from `Structure` to desired format, is wrapped in a try/except, which is further wrapped in a `warnings.catch_warnings()`. This is in order to be able to log any warnings that might be thrown by the adapter in `optimade-python-tools` and/or any related exceptions. 
""" desired_format = change["new"] if not desired_format or desired_format is None: self.download_button.value = self._download_button_format.format( disabled="disabled", encoding="", data="", filename="") return with warnings.catch_warnings(): warnings.filterwarnings("error") try: output = getattr(self.structure, f"as_{desired_format['adapter_format']}") if desired_format["adapter_format"] in ( "ase", "pymatgen", "aiida_structuredata", ): # output is not a file, but a proxy Python class func = getattr( self, f"_get_via_{desired_format['adapter_format']}") output = func( output, desired_format=desired_format["final_format"]) encoding = "utf-8" # Specifically for CIF: v1.x CIF needs to be in "latin-1" formatting if desired_format["ext"] == ".cif": encoding = "latin-1" filename = ( f"optimade_structure_{self.structure.id}{desired_format['ext']}" ) if isinstance(output, str): output = output.encode(encoding) data = base64.b64encode(output).decode() except Warning as warn: self.download_button.value = self._download_button_format.format( disabled="disabled", encoding="", data="", filename="") warnings.warn(OptimadeClientWarning(warn)) except Exception as exc: self.download_button.value = self._download_button_format.format( disabled="disabled", encoding="", data="", filename="") if isinstance(exc, exceptions.OptimadeClientError): raise exc # Else wrap the exception to make sure to log it. raise exceptions.OptimadeClientError(exc) else: self.download_button.value = self._download_button_format.format( disabled="", encoding=encoding, data=data, filename=filename) @staticmethod def _get_via_pymatgen( structure_molecule: Union[pymatgenStructure, pymatgenMolecule], desired_format: str, ) -> str: """Use pymatgen.[Structure,Molecule].to() method""" molecule_only_formats = ["xyz", "pdb"] structure_only_formats = ["xsf", "cif"] if desired_format in molecule_only_formats and not isinstance( structure_molecule, pymatgenMolecule): raise exceptions.WrongPymatgenType( f"Converting to '{desired_format}' format is only possible with a pymatgen." f"Molecule, instead got {type(structure_molecule)}") if desired_format in structure_only_formats and not isinstance( structure_molecule, pymatgenStructure): raise exceptions.WrongPymatgenType( f"Converting to '{desired_format}' format is only possible with a pymatgen." f"Structure, instead got {type(structure_molecule)}.") return structure_molecule.to(fmt=desired_format) @staticmethod def _get_via_ase(atoms: aseAtoms, desired_format: str) -> Union[str, bytes]: """Use ase.Atoms.write() method""" with tempfile.NamedTemporaryFile(mode="w+b") as temp_file: atoms.write(temp_file.name, format=desired_format) res = temp_file.read() return res def freeze(self): """Disable widget""" for widget in self.children: widget.disabled = True def unfreeze(self): """Activate widget (in its current state)""" for widget in self.children: widget.disabled = False def reset(self): """Reset widget""" self.dropdown.index = 0 self.freeze()
class ProviderImplementationChooser( # pylint: disable=too-many-instance-attributes ipw.VBox): """List all OPTIMADE providers and their implementations""" provider = traitlets.Instance(LinksResourceAttributes, allow_none=True) database = traitlets.Tuple( traitlets.Unicode(), traitlets.Instance(LinksResourceAttributes, allow_none=True), default_value=("", None), ) HINT = {"provider": "Select a provider", "child_dbs": "Select a database"} INITIAL_CHILD_DBS = [("", (("No provider chosen", None), ))] def __init__( self, child_db_limit: int = None, disable_providers: List[str] = None, skip_providers: List[str] = None, skip_databases: Dict[str, List[str]] = None, provider_database_groupings: Dict[str, Dict[str, List[str]]] = None, **kwargs, ): self.child_db_limit = (child_db_limit if child_db_limit and child_db_limit > 0 else 10) self.skip_child_dbs = skip_databases or {} self.child_db_groupings = (provider_database_groupings if provider_database_groupings is not None else PROVIDER_DATABASE_GROUPINGS) self.offset = 0 self.number = 1 self.__perform_query = True self.__cached_child_dbs = {} self.debug = bool(os.environ.get("OPTIMADE_CLIENT_DEBUG", None)) providers = [] providers, invalid_providers = get_list_of_valid_providers( disable_providers=disable_providers, skip_providers=skip_providers) providers.insert(0, (self.HINT["provider"], {})) if self.debug: from optimade_client.utils import VERSION_PARTS local_provider = LinksResourceAttributes( **{ "name": "Local server", "description": "Local server, running aiida-optimade", "base_url": f"http://localhost:5000{VERSION_PARTS[0][0]}", "homepage": "https://example.org", "link_type": "external", }) providers.insert(1, ("Local server", local_provider)) self.providers = DropdownExtended( options=providers, disabled_options=invalid_providers, layout=ipw.Layout(width="auto"), ) self.show_child_dbs = ipw.Layout(width="auto", display="none") self.child_dbs = DropdownExtended(grouping=self.INITIAL_CHILD_DBS, layout=self.show_child_dbs) self.page_chooser = ResultsPageChooser(page_limit=self.child_db_limit, layout=self.show_child_dbs) self.providers.observe(self._observe_providers, names="value") self.child_dbs.observe(self._observe_child_dbs, names="value") self.page_chooser.observe( self._get_more_child_dbs, names=["page_link", "page_offset", "page_number"]) self.error_or_status_messages = ipw.HTML("") super().__init__( children=( self.providers, self.child_dbs, self.page_chooser, self.error_or_status_messages, ), layout=ipw.Layout(width="auto"), **kwargs, ) def freeze(self): """Disable widget""" self.providers.disabled = True self.show_child_dbs.display = "none" self.page_chooser.freeze() def unfreeze(self): """Activate widget (in its current state)""" self.providers.disabled = False self.show_child_dbs.display = None self.page_chooser.unfreeze() def reset(self): """Reset widget""" self.page_chooser.reset() self.offset = 0 self.number = 1 self.providers.disabled = False self.providers.index = 0 self.show_child_dbs.display = "none" self.child_dbs.grouping = self.INITIAL_CHILD_DBS def _observe_providers(self, change: dict): """Update child database dropdown upon changing provider""" value = change["new"] self.show_child_dbs.display = "none" self.provider = value if value is None or not value: self.show_child_dbs.display = "none" self.child_dbs.grouping = self.INITIAL_CHILD_DBS self.providers.index = 0 self.child_dbs.index = 0 else: self._initialize_child_dbs() if sum([len(_[1]) for _ in self.child_dbs.grouping]) <= 2: # The provider either has 0 or 1 
implementations # or we have failed to retrieve any implementations. # Automatically choose the 1 implementation (if there), # while otherwise keeping the dropdown disabled. self.show_child_dbs.display = "none" try: self.child_dbs.index = 1 LOGGER.debug("Changed child_dbs index. New child_dbs: %s", self.child_dbs) except IndexError: pass else: self.show_child_dbs.display = None def _observe_child_dbs(self, change: dict): """Update database traitlet with base URL for chosen child database""" value = change["new"] if value is None or not value: self.database = "", None else: self.database = self.child_dbs.label.strip(), self.child_dbs.value @staticmethod def _remove_current_dropdown_option(dropdown: ipw.Dropdown) -> tuple: """Remove the current option from a Dropdown widget and return updated options Since Dropdown.options is a tuple there is a need to go through a list. """ list_of_options = list(dropdown.options) list_of_options.pop(dropdown.index) return tuple(list_of_options) def _initialize_child_dbs(self): """New provider chosen; initialize child DB dropdown""" self.offset = 0 self.number = 1 try: # Freeze and disable list of structures in dropdown widget # We don't want changes leading to weird things happening prior to the query ending self.freeze() # Reset the error or status message if self.error_or_status_messages.value: self.error_or_status_messages.value = "" if self.provider.base_url in self.__cached_child_dbs: cache = self.__cached_child_dbs[self.provider.base_url] LOGGER.debug( "Initializing child DBs for %s. Using cached info:\n%r", self.provider.name, cache, ) self._set_child_dbs(cache["child_dbs"]) data_returned = cache["data_returned"] data_available = cache["data_available"] links = cache["links"] else: LOGGER.debug("Initializing child DBs for %s.", self.provider.name) # Query database and get child_dbs child_dbs, links, data_returned, data_available = self._query() while True: # Update list of structures in dropdown widget exclude_child_dbs, final_child_dbs = self._update_child_dbs( data=child_dbs, skip_dbs=self.skip_child_dbs.get( self.provider.name, []), ) LOGGER.debug("Exclude child DBs: %r", exclude_child_dbs) data_returned -= len(exclude_child_dbs) data_available -= len(exclude_child_dbs) if exclude_child_dbs and data_returned: child_dbs, links, data_returned, _ = self._query( exclude_ids=exclude_child_dbs) else: break self._set_child_dbs(final_child_dbs) # Cache initial child_dbs and related information self.__cached_child_dbs[self.provider.base_url] = { "child_dbs": final_child_dbs, "data_returned": data_returned, "data_available": data_available, "links": links, } LOGGER.debug( "Found the following, which has now been cached:\n%r", self.__cached_child_dbs[self.provider.base_url], ) # Update pageing self.page_chooser.set_pagination_data( data_returned=data_returned, data_available=data_available, links_to_page=links, reset_cache=True, ) except QueryError as exc: LOGGER.debug( "Trying to initalize child DBs. QueryError caught: %r", exc) if exc.remove_target: LOGGER.debug( "Remove target: %r. Will remove target at %r: %r", exc.remove_target, self.providers.index, self.providers.value, ) self.providers.options = self._remove_current_dropdown_option( self.providers) self.reset() else: LOGGER.debug( "Remove target: %r. 
Will NOT remove target at %r: %r", exc.remove_target, self.providers.index, self.providers.value, ) self.show_child_dbs.display = "none" self.child_dbs.grouping = self.INITIAL_CHILD_DBS else: self.unfreeze() def _set_child_dbs( self, data: List[Tuple[str, List[Tuple[str, LinksResourceAttributes]]]], ) -> None: """Update the child_dbs options with `data`""" first_choice = (self.HINT["child_dbs"] if data else "No valid implementations found") new_data = list(data) new_data.insert(0, ("", [(first_choice, None)])) self.child_dbs.grouping = new_data def _update_child_dbs( self, data: List[dict], skip_dbs: List[str] = None ) -> Tuple[List[str], List[List[Union[str, List[Tuple[ str, LinksResourceAttributes]]]]], ]: """Update child DB dropdown from response data""" child_dbs = ({ "": [] } if self.providers.label not in self.child_db_groupings else deepcopy( self.child_db_groupings[self.providers.label])) exclude_dbs = [] skip_dbs = skip_dbs or [] for entry in data: child_db = update_old_links_resources(entry) if child_db is None: continue # Skip if desired by user if child_db.id in skip_dbs: exclude_dbs.append(child_db.id) continue attributes = child_db.attributes # Skip if not a 'child' link_type database if attributes.link_type != LinkType.CHILD: LOGGER.debug( "Skip %s: Links resource not a %r link_type, instead: %r", attributes.name, LinkType.CHILD, attributes.link_type, ) continue # Skip if there is no base_url if attributes.base_url is None: LOGGER.debug( "Skip %s: Base URL found to be None for child DB: %r", attributes.name, child_db, ) exclude_dbs.append(child_db.id) continue versioned_base_url = get_versioned_base_url(attributes.base_url) if versioned_base_url: attributes.base_url = versioned_base_url else: # Not a valid/supported child DB: skip LOGGER.debug( "Skip %s: Could not determine versioned base URL for child DB: %r", attributes.name, child_db, ) exclude_dbs.append(child_db.id) continue if self.providers.label in self.child_db_groupings: for group, ids in self.child_db_groupings[ self.providers.label].items(): if child_db.id in ids: index = child_dbs[group].index(child_db.id) child_dbs[group][index] = (attributes.name, attributes) break else: if "" in child_dbs: child_dbs[""].append((attributes.name, attributes)) else: child_dbs[""] = [(attributes.name, attributes)] else: child_dbs[""].append((attributes.name, attributes)) if self.providers.label in self.child_db_groupings: for group, ids in tuple(child_dbs.items()): child_dbs[group] = [_ for _ in ids if isinstance(_, tuple)] child_dbs = list(child_dbs.items()) LOGGER.debug("Final updated child_dbs: %s", child_dbs) return exclude_dbs, child_dbs def _get_more_child_dbs(self, change): """Query for more child DBs according to page_offset""" if self.providers.value is None: # This may be called if a provider is suddenly removed (bad provider) return if not self.__perform_query: self.__perform_query = True LOGGER.debug( "Will not perform query with pageing: name=%s value=%s", change["name"], change["new"], ) return pageing: Union[int, str] = change["new"] LOGGER.debug( "Detected change in page_chooser: name=%s value=%s", change["name"], pageing, ) if change["name"] == "page_offset": LOGGER.debug( "Got offset %d to retrieve more child DBs from %r", pageing, self.providers.value, ) self.offset = pageing pageing = None elif change["name"] == "page_number": LOGGER.debug( "Got number %d to retrieve more child DBs from %r", pageing, self.providers.value, ) self.number = pageing pageing = None else: LOGGER.debug( "Got link %r to retrieve more 
child DBs from %r", pageing, self.providers.value, ) # It is needed to update page_offset, but we do not wish to query again with self.hold_trait_notifications(): self.__perform_query = False self.page_chooser.update_offset() try: # Freeze and disable both dropdown widgets # We don't want changes leading to weird things happening prior to the query ending self.freeze() # Query index meta-database LOGGER.debug("Querying for more child DBs using pageing: %r", pageing) child_dbs, links, _, _ = self._query(pageing) data_returned = self.page_chooser.data_returned while True: # Update list of child DBs in dropdown widget exclude_child_dbs, final_child_dbs = self._update_child_dbs( child_dbs) data_returned -= len(exclude_child_dbs) if exclude_child_dbs and data_returned: child_dbs, links, data_returned, _ = self._query( link=pageing, exclude_ids=exclude_child_dbs) else: break self._set_child_dbs(final_child_dbs) # Update pageing self.page_chooser.set_pagination_data(data_returned=data_returned, links_to_page=links) except QueryError as exc: LOGGER.debug( "Trying to retrieve more child DBs (new page). QueryError caught: %r", exc, ) if exc.remove_target: LOGGER.debug( "Remove target: %r. Will remove target at %r: %r", exc.remove_target, self.providers.index, self.providers.value, ) self.providers.options = self._remove_current_dropdown_option( self.providers) self.reset() else: LOGGER.debug( "Remove target: %r. Will NOT remove target at %r: %r", exc.remove_target, self.providers.index, self.providers.value, ) self.show_child_dbs.display = "none" self.child_dbs.grouping = self.INITIAL_CHILD_DBS else: self.unfreeze() def _query( # pylint: disable=too-many-locals,too-many-branches,too-many-statements self, link: str = None, exclude_ids: List[str] = None) -> Tuple[List[dict], dict, int, int]: """Query helper function""" # If a complete link is provided, use it straight up if link is not None: try: if exclude_ids: filter_value = " AND ".join( [f'NOT id="{id_}"' for id_ in exclude_ids]) parsed_url = urllib.parse.urlparse(link) queries = urllib.parse.parse_qs(parsed_url.query) # Since parse_qs wraps all values in a list, # this extracts the values from the list(s). 
queries = {key: value[0] for key, value in queries.items()} if "filter" in queries: queries[ "filter"] = f"( {queries['filter']} ) AND ( {filter_value} )" else: queries["filter"] = filter_value parsed_query = urllib.parse.urlencode(queries) link = ( f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}" f"?{parsed_query}") link = ordered_query_url(link) response = SESSION.get(link, timeout=TIMEOUT_SECONDS) if response.from_cache: LOGGER.debug("Request to %s was taken from cache !", link) response = response.json() except ( requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout, ) as exc: response = { "errors": { "msg": "CLIENT: Connection error or timeout.", "url": link, "Exception": repr(exc), } } except json.JSONDecodeError as exc: response = { "errors": { "msg": "CLIENT: Could not decode response to JSON.", "url": link, "Exception": repr(exc), } } else: filter_ = '( link_type="child" OR type="child" )' if exclude_ids: filter_ += ( " AND ( " + " AND ".join([f'NOT id="{id_}"' for id_ in exclude_ids]) + " )") response = perform_optimade_query( filter=filter_, base_url=self.provider.base_url, endpoint="/links", page_limit=self.child_db_limit, page_offset=self.offset, page_number=self.number, ) msg, http_errors = handle_errors(response) if msg: if 404 in http_errors: # If /links not found move on pass else: self.error_or_status_messages.value = msg raise QueryError(msg=msg, remove_target=True) # Check implementation API version msg = validate_api_version(response.get("meta", {}).get("api_version", ""), raise_on_fail=False) if msg: self.error_or_status_messages.value = ( f"{msg}<br>The provider has been removed.") raise QueryError(msg=msg, remove_target=True) LOGGER.debug( "Manually remove `exclude_ids` if filters are not supported") child_db_data = { impl.get("id", "N/A"): impl for impl in response.get("data", []) } if exclude_ids: for links_id in exclude_ids: if links_id in list(child_db_data.keys()): child_db_data.pop(links_id) LOGGER.debug("child_db_data after popping: %r", child_db_data) response["data"] = list(child_db_data.values()) if "meta" in response: if "data_available" in response["meta"]: old_data_available = response["meta"].get( "data_available", 0) if len(response["data"]) > old_data_available: LOGGER.debug("raising OptimadeClientError") raise OptimadeClientError( f"Reported data_available ({old_data_available}) is smaller than " f"curated list of responses ({len(response['data'])}).", ) response["meta"]["data_available"] = len(response["data"]) else: raise OptimadeClientError( "'meta' not found in response. 
Bad response") LOGGER.debug( "Attempt for %r (in /links): Found implementations (names+base_url only):\n%s", self.provider.name, [ f"(id: {name}; base_url: {base_url}) " for name, base_url in [( impl.get("id", "N/A"), impl.get("attributes", {}).get("base_url", "N/A"), ) for impl in response.get("data", [])] ], ) # Return all implementations of link_type "child" implementations = [ implementation for implementation in response.get("data", []) if (implementation.get("attributes", {}).get("link_type", "") == "child" or implementation.get("type", "") == "child") ] LOGGER.debug( "After curating for implementations which are of 'link_type' = 'child' or 'type' == " "'child' (old style):\n%s", [ f"(id: {name}; base_url: {base_url}) " for name, base_url in [( impl.get("id", "N/A"), impl.get("attributes", {}).get("base_url", "N/A"), ) for impl in implementations] ], ) # Get links, data_returned, and data_available links = response.get("links", {}) data_returned = response.get("meta", {}).get("data_returned", len(implementations)) if data_returned > 0 and not implementations: # Most probably dealing with pre-v1.0.0-rc.2 implementations data_returned = 0 data_available = response.get("meta", {}).get("data_available", 0) return implementations, links, data_returned, data_available
class MultiToggleButtons_AllOrSome(Box):
    description = traitlets.Unicode()
    value = traitlets.Tuple()
    options = traitlets.Union([traitlets.List(), traitlets.Dict()])
    style = traitlets.Dict()

    def __init__(self, *, short_label_map=None, **kwargs):
        if short_label_map is None:
            short_label_map = {}
        super().__init__(**kwargs)
        self._selection_obj = widget_selection._MultipleSelection()
        traitlets.link((self, 'options'), (self._selection_obj, 'options'))
        traitlets.link((self, 'value'), (self._selection_obj, 'value'))

        # ``observer`` is assumed to be a module-level helper decorator that
        # registers the decorated callback on the given widget(s) and trait.
        @observer(self, 'options')
        def _(*_):
            self.buttons = []
            for label in self._selection_obj._options_labels:
                short_label = short_label_map.get(label, label)
                self.buttons.append(
                    ToggleButton(
                        # Truncate long labels; the full text stays in the tooltip.
                        description=short_label
                        if len(short_label) < 15 else short_label[:12] + "…",
                        tooltip=label,
                        layout=Layout(
                            margin='1',
                            width='auto',
                        ),
                    ))
            if self.description:
                self.label = Label(
                    self.description,
                    layout=Layout(
                        width=self.style.get('description_width', '100px')))
            else:
                self.label = Label(
                    self.description,
                    layout=Layout(
                        width=self.style.get('description_width', '0px')))
            self.children = [self.label] + self.buttons

            # Re-register on the freshly created buttons whenever options change.
            @observer(self.buttons, 'value')
            def _(*_):
                proposed_value = tuple(
                    value for btn, value in zip(
                        self.buttons, self._selection_obj._options_values)
                    if btn.value)
                # When nothing is selected, treat as if everything is selected.
                if len(proposed_value) == 0:
                    proposed_value = tuple(
                        value for btn, value in zip(
                            self.buttons, self._selection_obj._options_values))
                self.value = proposed_value

        self.add_class('btn-group')

    def reset(self):
        opts = self.options
        self.options = []
        self.options = opts

    def set_value(self, x):
        for b, opt in zip(self.buttons, self.options):
            b.value = (opt in x)

    def set_all_on(self):
        for b, opt in zip(self.buttons, self.options):
            b.value = True

    def set_all_off(self):
        for b, opt in zip(self.buttons, self.options):
            b.value = False
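# ---------------------------------------------------------------------------
# Minimal usage sketch for ``MultiToggleButtons_AllOrSome`` (illustrative
# only). The widget's deliberate quirk: once buttons exist, toggling
# everything off makes ``value`` report *all* options, so downstream filters
# can treat "nothing selected" as "no restriction".
def _example_multi_toggle():
    from IPython.display import display

    buttons = MultiToggleButtons_AllOrSome(
        description="Elements",
        short_label_map={"Hydrogen": "H"},  # hypothetical short labels
    )
    buttons.options = ["Hydrogen", "Carbon", "Oxygen"]
    buttons.set_all_on()
    display(buttons)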
class ZarrOutputMixin(tl.HasTraits):
    """
    This class assumes that the node has an 'output_format' attribute
    (currently the "Lambda" Node and the "Process" Node).

    Attributes
    ----------
    zarr_file : str
        Path to the output zarr file that collects all of the computed
        results. This can reside on S3.
    dataset : ZarrGroup
        A handle to the zarr group pointing to the output file.
    fill_output : bool, optional
        Default is False (unlike parent class). If True, will collect the
        output data and return it as an xarray.
    init_file_mode : str, optional
        Default is 'a'. Mode used for initializing the zarr file.
    zarr_chunks : dict
        Size of the chunks in the zarr file for each dimension.
    zarr_shape : dict, optional
        Default is {coordinates.dims: coordinates.shape}, where the
        coordinates are those used as part of the eval call. This does not
        need to be specified unless the Node modifies the input coordinates
        (as part of a Reduce operation, for example). The result can be
        incorrect and requires care/checking by the user.
    zarr_coordinates : podpac.Coordinates, optional
        Default is None. If the node modifies the shape of the input
        coordinates, this allows users to set the coordinates in the output
        zarr file. This can be incorrect and requires care by the user.
    skip_existing : bool
        Default is True. If True, this will check to see if the results
        already exist, and if so, it will not submit a job for that
        particular coordinate evaluation. This assumes
        self.chunks == self.zarr_chunks.
    list_dir : bool, optional
        Default is False. If skip_existing is True, by default existing files
        are checked by asking for an 'exists' call. If list_dir is True, then
        at the first opportunity a "list_dir" is performed on the directory
        and the results are cached.
    """

    zarr_file = tl.Unicode().tag(attr=True)
    dataset = tl.Any()
    zarr_node = NodeTrait()
    zarr_data_key = tl.Union([tl.Unicode(), tl.List()])
    fill_output = tl.Bool(False)
    init_file_mode = tl.Unicode("a").tag(attr=True)
    zarr_chunks = tl.Dict(default_value=None, allow_none=True).tag(attr=True)
    zarr_shape = tl.Dict(allow_none=True, default_value=None).tag(attr=True)
    zarr_coordinates = tl.Instance(Coordinates, allow_none=True,
                                   default_value=None).tag(attr=True)
    zarr_dtype = tl.Unicode("f4")
    skip_existing = tl.Bool(True).tag(attr=True)
    list_dir = tl.Bool(False)
    _list_dir = tl.List(allow_none=True, default_value=[])
    _shape = tl.Tuple()
    _chunks = tl.List()
    aws_client_kwargs = tl.Dict()
    aws_config_kwargs = tl.Dict()

    def eval(self, coordinates, **kwargs):
        output = kwargs.get("output")
        if self.zarr_shape is None:
            self._shape = coordinates.shape
        else:
            self._shape = tuple(self.zarr_shape.values())

        # initialize zarr file
        if self.zarr_chunks is None:
            chunks = [self.chunks[d] for d in coordinates]
        else:
            chunks = [self.zarr_chunks[d] for d in coordinates]
        self._chunks = chunks
        zf, data_key, zn = self.initialize_zarr_array(self._shape, chunks)
        self.dataset = zf
        self.zarr_data_key = data_key
        self.zarr_node = zn
        zn.keys  # attribute access appears intentional (warms the node's key cache)

        # eval
        _log.debug("Starting parallel eval.")
        missing_dims = [
            d for d in coordinates.dims if d not in self.chunks.keys()
        ]
        if self.zarr_coordinates is not None:
            missing_dims = missing_dims + [
                d for d in self.zarr_coordinates.dims if d not in missing_dims
            ]
            set_coords = merge_dims(
                [coordinates.drop(missing_dims), self.zarr_coordinates])
        else:
            set_coords = coordinates.drop(missing_dims)
        set_coords.transpose(*coordinates.dims)

        self.set_zarr_coordinates(set_coords, data_key)
        if self.list_dir:
            dk = data_key
            if isinstance(dk, list):
                dk = dk[0]
            self._list_dir = self.zarr_node.list_dir(dk)

        output = super(ZarrOutputMixin,
self).eval(coordinates, output=output) # fill in the coordinates, this is guaranteed to be correct even if the user messed up. if output is not None: self.set_zarr_coordinates(Coordinates.from_xarray(output.coords), data_key) else: return zf return output def set_zarr_coordinates(self, coordinates, data_key): # Fill in metadata for dk in data_key: self.dataset[dk].attrs["_ARRAY_DIMENSIONS"] = coordinates.dims for d in coordinates.dims: # TODO ADD UNITS AND TIME DECODING INFORMATION self.dataset.create_dataset(d, shape=coordinates[d].size, overwrite=True) self.dataset[d][:] = coordinates[d].coordinates def initialize_zarr_array(self, shape, chunks): _log.debug("Creating Zarr file.") zn = Zarr(source=self.zarr_file, file_mode=self.init_file_mode, aws_client_kwargs=self.aws_client_kwargs) if self.source.output or getattr(self.source, "data_key", None): data_key = self.source.output if data_key is None: data_key = self.source.data_key if not isinstance(data_key, list): data_key = [data_key] elif self.source.outputs: # If someone restricted the outputs for this node, we need to know data_key = [dk for dk in data_key if dk in self.source.outputs] elif self.source.outputs: data_key = self.source.outputs else: data_key = ["data"] zf = zarr.open(zn._get_store(), mode=self.init_file_mode) # Intialize the output zarr arrays for dk in data_key: try: arr = zf.create_dataset( dk, shape=shape, chunks=chunks, fill_value=np.nan, dtype=self.zarr_dtype, overwrite=not self.skip_existing, ) except ValueError: pass # Dataset already exists # Recompute any cached properties zn = Zarr(source=self.zarr_file, file_mode=self.init_file_mode, aws_client_kwargs=self.aws_client_kwargs) return zf, data_key, zn def eval_source(self, coordinates, coordinates_index, out, i, source=None): if source is None: source = self.source if self.skip_existing: # This section allows previously computed chunks to be skipped dk = self.zarr_data_key if isinstance(dk, list): dk = dk[0] try: exists = self.zarr_node.chunk_exists(coordinates_index, data_key=dk, list_dir=self._list_dir, chunks=self._chunks) except ValueError as e: # This was needed in cases where a poor internet connection caused read errors exists = False if exists: _log.info("Skipping {} (already exists)".format(i)) return out, coordinates_index # Make a copy to prevent any possibility of memory corruption source = Node.from_definition(source.definition) _log.debug("Creating output format.") output = dict( format="zarr_part", format_kwargs=dict( part=[[s.start, min(s.stop, self._shape[i]), s.step] for i, s in enumerate(coordinates_index)], source=self.zarr_file, mode="a", ), ) _log.debug("Finished creating output format.") if source.has_trait("output_format"): source.set_trait("output_format", output) _log.debug("output: {}, coordinates.shape: {}".format( output, coordinates.shape)) _log.debug("Evaluating node.") o, slc = super(ZarrOutputMixin, self).eval_source(coordinates, coordinates_index, out, i, source) if not source.has_trait("output_format"): o.to_format(output["format"], **output["format_kwargs"]) return o, slc
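# ---------------------------------------------------------------------------
# Sketch of how the mixin is intended to be configured (illustrative;
# ``make_parallel_node`` stands in for however the surrounding code base
# combines ``ZarrOutputMixin`` with a node that defines ``chunks`` and
# ``eval_source``).
def _example_zarr_output(make_parallel_node, coordinates):
    node = make_parallel_node(
        zarr_file="s3://my-bucket/output.zarr",  # hypothetical location
        zarr_chunks={"lat": 128, "lon": 128},    # hypothetical chunking
        skip_existing=True,  # do not recompute chunks already in the store
        list_dir=True,       # one directory listing instead of per-chunk checks
    )
    # Returns the zarr group when fill_output is False (see ``eval`` above).
    return node.eval(coordinates)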
class XELK(ElkTransformer): """NetworkX DiGraphs to ELK dictionary structure""" HIDDEN_ATTR = "hidden" hoist_hidden_edges: bool = True source = T.Tuple(T.Instance(nx.Graph), T.Instance(nx.DiGraph, allow_none=True)) layouts = T.Dict() # keys: networkx nodes {ElkElements: {layout options}} css_classes = T.Dict() port_scale = T.Int(default_value=5) label_key = T.Unicode(default_value="labels") port_key = T.Unicode(default_value="ports") @T.default("source") def _default_source(self): return (nx.Graph(), None) @T.default("text_sizer") def _default_text_sizer(self): return ElkTextSizer() @T.default("layouts") def _default_layouts(self): parent_opts = opt.OptionsWidget( identifier="parents", options=[ opt.HierarchyHandling(), ], ) label_opts = opt.OptionsWidget( identifier=ElkLabel, options=[opt.NodeLabelPlacement(horizontal="center")]) node_opts = opt.OptionsWidget( identifier=ElkNode, options=[ opt.NodeSizeConstraints(), ], ) default = opt.OptionsWidget( options=[parent_opts, node_opts, label_opts]) return {ElkRoot: default.value} def node_id(self, node: Hashable) -> str: """Get the element id for a node in the main graph for use in elk :param node: Node in main graph :type node: Hashable :return: Element ID :rtype: str """ g, tree = self.source if node is ElkRoot: return self.ELK_ROOT_ID elif node in g: return g.nodes.get(node, {}).get("id", f"{node}") return f"{node}" def port_id(self, node: Hashable, port: Optional[Hashable] = None) -> str: """Get a suitable Elk identifier from the node and port :param node: Node from the incoming networkx graph :type node: Hashable :param port: Port identifier, defaults to None :type port: Optional[Hashable], optional :return: If no port is provided will refer to the node :rtype: str """ if port is None: return self.node_id(node) else: return f"{self.node_id(node)}.{port}" def edge_id(self, edge: Edge): # TODO probably will need more sophisticated id generation in future return "{}__{}___{}__{}".format(edge.source, edge.source_port, edge.target, edge.target_port) async def transform(self) -> ElkNode: """Generate ELK dictionary structure :return: Root Elk node :rtype: ElkNode """ # TODO decide behavior for nodes that exist in the tree but not g g, tree = self.source self.clear_registry() visible_edges, hidden_edges = self.collect_edges() # Process visible networkx nodes into elknodes visible_nodes = [ n for n in g.nodes() if not is_hidden(tree, n, self.HIDDEN_ATTR) ] # make elknodes then connect their hierarchy elknodes: NodeMap = {} ports: PortMap = {} for node in visible_nodes: elknode, node_ports = await self.make_elknode(node) for key, value in node_ports.items(): ports[key] = value elknodes[node] = elknode # make top level ElkNode and attach all others as children elknodes[ElkRoot] = top = ElkNode( id=self.ELK_ROOT_ID, children=build_hierarchy(g, tree, elknodes, self.HIDDEN_ATTR), layoutOptions=self.get_layout(ElkRoot, "parents"), ) # map of original nodes to the generated elknodes for node, elk_node in elknodes.items(): self.register(elk_node, node) elknodes, ports = await self.process_edges(elknodes, ports, visible_edges) # process edges with one or both original endpoints are hidden if self.hoist_hidden_edges: elknodes, ports = await self.process_edges( elknodes, ports, hidden_edges, edge_style={"slack-edge"}, port_style={"slack-port"}, ) # iterate through the port map and add ports to ElkNodes for port_id, port in ports.items(): owner = port.node elkport = port.elkport if owner not in elknodes: # TODO skip generating port to begin with break 
elknode = elknodes[owner] if elknode.ports is None: elknode.ports = [] layout = self.get_layout(owner, ElkPort) elkport.layoutOptions = merge(elkport.layoutOptions, layout) elknode.ports += [elkport] # map of ports to the generated elkports self.register(elkport, port) # bulk calculate label sizes await size_labels(self.text_sizer, collect_labels([top])) return top async def make_elknode(self, node) -> Tuple[ElkNode, PortMap]: # merge layout options defined on the node data with default layout # options node_data = self.get_node_data(node) layout = merge( node_data.get("layoutOptions", {}), self.get_layout(node, ElkNode), ) labels = await self.make_labels(node) # update port map with declared ports in the networkx node data node_ports = await self.collect_ports(node) properties = self.get_properties(node, self.get_css(node, ElkNode)) or None elk_node = ElkNode( id=self.node_id(node), labels=labels, layoutOptions=layout, properties=properties, width=node_data.get("width", None), height=node_data.get("height", None), ) return elk_node, node_ports def get_layout(self, node: Hashable, elk_type: Type[ElkGraphElement]) -> Optional[Dict]: """Get the Elk Layout Options appropriate for given networkx node and filter by given elk_type :param node: [description] :type node: Hashable :param elk_type: [description] :type elk_type: Type[ElkGraphElement] :return: [description] :rtype: [type] """ # TODO look at self.source hierarchy and resolve layout with added # infomation. until then use root node `None` for layout options if node not in self.layouts: node = ElkRoot type_opts = self.layouts.get(node, {}) options = {**type_opts.get(elk_type, {})} if options: return options def get_properties( self, element: Optional[Union[Hashable, str]], dom_classes: Optional[Set[str]] = None, ) -> Dict: """Get the properties for a graph element :param element: Networkx node or edge :type node: Hashable :param dom_classes: Set of base CSS DOM classes to merge, defaults to Set[str]=None :type dom_classes: [type], optional :return: Set of CSS Classes to apply :rtype: Set[str] """ g, tree = self.source properties = [] if g and element in g: g_props = g.nodes[element].get("properties", {}) if g_props: properties += [g_props] if hasattr(element, "data"): properties += [element.data.get("properties", {})] elif isinstance(element, dict): properties += [element.get("properties", {})] if dom_classes: properties += [{"cssClasses": " ".join(dom_classes)}] if not properties: return {} elif len(properties) == 1: return properties[0] merged_properties = {} for props in properties[::-1]: merged_properties = merge(props, merged_properties) return merged_properties def get_css( self, node: Hashable, elk_type: Type[ElkGraphElement], dom_classes: Set[str] = None, ) -> Set[str]: """Get the CSS Classes appropriate for given networkx node given elk_type :param node: Networkx node :type node: Hashable :param elk_type: ElkGraphElement to get appropriate css classes :type elk_type: Type[ElkGraphElement] :param dom_classes: Set of base CSS DOM classes to merge, defaults to Set[str]=None :type dom_classes: [type], optional :return: Set of CSS Classes to apply :rtype: Set[str] """ typed_css = self.css_classes.get(node, {}) css_classes = set(typed_css.get(elk_type, [])) if dom_classes is None: return css_classes return css_classes | dom_classes async def process_edges( self, nodes: NodeMap, ports: PortMap, edges: EdgeMap, edge_style: Set[str] = None, port_style: Set[str] = None, ) -> Tuple[NodeMap, PortMap]: for owner, edge_list in edges.items(): 
edge_css = self.get_css(owner, ElkEdge, edge_style) port_css = self.get_css(owner, ElkPort, port_style) for edge in edge_list: elknode = nodes[owner] if elknode.edges is None: elknode.edges = [] if edge.source_port is not None: port_id = self.port_id(edge.source, edge.source_port) if port_id not in ports: ports[port_id] = await self.make_port( edge.source, edge.source_port, port_css) if edge.target_port is not None: port_id = self.port_id(edge.target, edge.target_port) if port_id not in ports: ports[port_id] = await self.make_port( edge.target, edge.target_port, port_css) elknode.edges += [await self.make_edge(edge, edge_css)] return nodes, ports async def make_edge(self, edge: Edge, styles: Optional[Set[str]] = None) -> ElkExtendedEdge: """Make the associated Elk edge for the given Edge :param edge: Edge object to wrap :type edge: Edge :param styles: List of css classes to add to given Elk edge, defaults to None :type styles: Optional[List[str]], optional :return: Elk edge :rtype: ElkExtendedEdge """ labels = [] properties = self.get_properties(edge, styles) or None label_layout_options = self.get_layout( edge.owner, ElkLabel) # TODO add edgelabel type? edge_layout_options = self.get_layout(edge.owner, ElkEdge) for i, label in enumerate(edge.data.get(self.label_key, [])): if isinstance(label, ElkLabel): label = label.to_dict() # used to create copy of label if isinstance(label, dict): label = ElkLabel.from_dict(label) if isinstance(label, str): label = ElkLabel(id=f"{edge.owner}_label_{i}_{label}", text=label) label.layoutOptions = merge(label.layoutOptions, label_layout_options) labels.append(label) for label in labels: self.register(label, edge) elk_edge = ElkExtendedEdge( id=edge.data.get("id", self.edge_id(edge)), sources=[self.port_id(edge.source, edge.source_port)], targets=[self.port_id(edge.target, edge.target_port)], properties=properties, layoutOptions=merge(edge.data.get("layoutOptions"), edge_layout_options), labels=compact(labels), ) self.register(elk_edge, edge) return elk_edge async def make_port(self, owner: Hashable, port: Hashable, styles: Optional[Set[str]] = None) -> Port: """Make the associated elk port for the given owner node and port :param owner: [description] :type owner: Hashable :param port: [description] :type port: Hashable :param styles: list of css classes to apply to given ElkPort :type styles: List[str] :return: [description] :rtype: ElkPort """ port_id = self.port_id(owner, port) properties = self.get_properties(port, styles) or None elk_port = ElkPort( id=port_id, height=self.port_scale, width=self.port_scale, properties=properties, # TODO labels ) return Port(node=owner, elkport=elk_port) def get_node_data(self, node: Hashable) -> Dict: g, tree = self.source return g.nodes.get(node, {}) async def collect_ports(self, *nodes) -> PortMap: ports: PortMap = {} for node in nodes: values = self.get_node_data(node).get(self.port_key, []) for i, port in enumerate(values): if isinstance(port, ElkPort): # generate a fresh copy of the port to prevent mutating original port = port.to_dict() elif isinstance(port, str): port_id = self.port_id(node, port) port = { "id": port_id, "labels": [{ "text": port, "id": f"{port}_label_{i}", }], } if isinstance(port, dict): elkport = ElkPort.from_dict(port) if elkport.width is None: elkport.width = self.port_scale if elkport.height is None: elkport.height = self.port_scale ports[elkport.id] = Port(node=node, elkport=elkport) return ports async def make_labels(self, node: Hashable) -> Optional[List[ElkLabel]]: labels = [] g = 
self.source[0] data = g.nodes.get(node, {}) values = data.get(self.label_key, [data.get("_id", f"{node}")]) properties = {} css_classes = self.get_css(node, ElkLabel) if css_classes: properties["cssClasses"] = " ".join(css_classes) if isinstance(values, (str, ElkLabel)): values = [values] # get node labels for i, label in enumerate(values): if isinstance(label, str): label = ElkLabel( id=f"{label}_label_{i}_{node}", text=label, ) elif isinstance(label, ElkLabel): # prevent mutating original label in the node data label = label.to_dict() if isinstance(label, dict): label = ElkLabel.from_dict(label) # add css classes and layout options label.layoutOptions = merge(label.layoutOptions, self.get_layout(node, ElkLabel)) merged_props = merge(label.properties, properties) if merged_props is not None: merged_props = ElkProperties.from_dict(merged_props) label.properties = merged_props labels.append(label) self.register(label, node) return labels def collect_edges(self) -> Tuple[EdgeMap, EdgeMap]: """[summary] :return: [description] :rtype: Tuple[ Dict[Hashable, List[ElkExtendedEdge]], Dict[Hashable, List[ElkExtendedEdge]] ] """ visible: EdgeMap = defaultdict( list) # will index edges by nx.lowest_common_ancestor hidden: EdgeMap = defaultdict( list) # will index edges by nx.lowest_common_ancestor g, tree = self.source attr = self.HIDDEN_ATTR hidden: List[Tuple[List, List]] = [] visible: EdgeMap = defaultdict( list) # will index edges by nx.lowest_common_ancestor hidden: EdgeMap = defaultdict( list) # will index edges by nx.lowest_common_ancestor closest_visible = map_visible(g, tree, attr) @lru_cache() def closest_common_visible(nodes: Tuple[Hashable]) -> Hashable: if tree is None: return ElkRoot result = lowest_common_ancestor(tree, nodes) return result for source, target, edge_data in g.edges(data=True): source_port, target_port = get_ports(edge_data) vis_source = closest_visible[source] vis_target = closest_visible[target] shidden = vis_source != source thidden = vis_target != target owner = closest_common_visible((vis_source, vis_target)) if source == target and source == owner: if owner in tree: for p in tree.predecessors(owner): # need to make this edge's owner to it's parent owner = p if shidden or thidden: # create new slack ports if source or target is remapped if vis_source != source: source_port = (source, source_port) if vis_target != target: target_port = (target, target_port) if vis_source != vis_target: hidden[owner].append( Edge( source=vis_source, source_port=source_port, target=vis_target, target_port=target_port, data=edge_data, owner=owner, )) else: visible[owner].append( Edge( source=source, source_port=source_port, target=target, target_port=target_port, data=edge_data, owner=owner, )) return visible, hidden
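# ---------------------------------------------------------------------------
# Illustrative driver for the ``XELK`` transformer (assumes an event loop, as
# in a Jupyter cell with top-level ``await``, and a live frontend since label
# measurement happens browser-side). The second element of ``source`` is an
# optional hierarchy tree and may be None.
async def _example_xelk():
    g = nx.DiGraph()
    g.add_edge("a", "b")
    transformer = XELK(source=(g, None))
    root = await transformer.transform()  # root ElkNode with children/edges
    return root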
class Widget(DOMWidget):
    _view_name = traitlets.Unicode('ReboundView').tag(sync=True)
    _view_module = traitlets.Unicode('rebound').tag(sync=True)
    count = traitlets.Int(0).tag(sync=True)
    t = traitlets.Float().tag(sync=True)
    N = traitlets.Int().tag(sync=True)
    width = traitlets.Float().tag(sync=True)
    height = traitlets.Float().tag(sync=True)
    scale = traitlets.Float().tag(sync=True)
    particle_data = traitlets.Bytes().tag(sync=True)
    orbit_data = traitlets.Bytes().tag(sync=True)
    orientation = traitlets.Tuple().tag(sync=True)
    orbits = traitlets.Int().tag(sync=True)

    def __init__(self, simulation, size=(200, 200), orientation=(0., 0., 0., 1.),
                 scale=None, autorefresh=True, orbits=True):
        """
        Initializes a Widget.

        Widgets provide real-time 3D interactive visualizations for REBOUND
        simulations within Jupyter Notebooks. To use widgets, the ipywidgets
        package needs to be installed and enabled in your Jupyter notebook
        server.

        Parameters
        ----------
        size : (int, int), optional
            Specify the size of the widget in pixels. The default is 200 by
            200 pixels.
        orientation : (float, float, float, float), optional
            Specify the initial orientation of the view. The four floats
            correspond to the x, y, z, and w components of a quaternion. The
            quaternion will be normalized.
        scale : float, optional
            Set the initial scale of the view. If not set, the widget will
            determine the scale automatically based on current particle
            positions.
        autorefresh : bool, optional
            The default value is True. The view is updated whenever a particle
            is added or removed, and every 100th of a second while a
            simulation is running. If set to False, the user needs to manually
            call the refresh() function on the widget. This might be useful if
            performance is an issue.
        orbits : bool, optional
            The default value is True, in which case the widget will draw the
            instantaneous orbits of the particles. For simulations in which
            particles are not on Keplerian orbits, the orbits shown will not
            be accurate.
        """
        self.width, self.height = size
        self.t, self.N = simulation.t, simulation.N
        self.orientation = orientation
        self.autorefresh = autorefresh
        self.orbits = orbits
        self.simp = pointer(simulation)
        clibrebound.reb_display_copy_data.restype = c_int
        if scale is None:
            self.scale = simulation.display_data.contents.scale
        else:
            self.scale = scale
        self.count += 1
        super(Widget, self).__init__()

    def refresh(self, simp=None, isauto=0):
        """
        Manually refreshes a widget.

        Note that this function can also be called using the wrapper function
        of the Simulation object: sim.refreshWidgets().
        """
        if simp is None:
            simp = self.simp
        if self.autorefresh == 0 and isauto == 1:
            return
        sim = simp.contents
        size_changed = clibrebound.reb_display_copy_data(simp)
        clibrebound.reb_display_prepare_data(simp, c_int(self.orbits))
        if sim.N > 0:
            self.particle_data = (c_char * (4 * 7 * sim.N)).from_address(
                sim.display_data.contents.particle_data).raw
            if self.orbits:
                self.orbit_data = (c_char * (4 * 9 * (sim.N - 1))).from_address(
                    sim.display_data.contents.orbit_data).raw
        if size_changed:
            # TODO: Implement better GPU size change
            pass
        self.N = sim.N
        self.t = sim.t
        self.count += 1

    @staticmethod
    def getClientCode():
        return shader_code + js_code
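# ---------------------------------------------------------------------------
# Illustrative notebook usage (the canonical entry point in REBOUND is
# ``Simulation.getWidget()``, which constructs and wires up this class):
#
#     >>> import rebound
#     >>> sim = rebound.Simulation()
#     >>> sim.add(m=1.)
#     >>> sim.add(m=1.e-3, a=1.)
#     >>> w = sim.getWidget()
#     >>> w                   # display the widget
#     >>> sim.integrate(10.)  # the view refreshes while integrating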
class RobustScaler(Transformer):
    '''The RobustScaler removes the median and scales the data according to a
    given percentile range. By default, the scaling is done between the 25th
    and the 75th percentile. Centering and scaling happens independently for
    each feature (column).

    Example:

    >>> import vaex
    >>> df = vaex.from_arrays(x=[2,5,7,2,15], y=[-2,3,0,0,10])
    >>> df
     #    x    y
     0    2   -2
     1    5    3
     2    7    0
     3    2    0
     4   15   10
    >>> scaler = vaex.ml.RobustScaler(features=['x', 'y'])
    >>> scaler.fit_transform(df)
     #    x    y    robust_scaled_x    robust_scaled_y
     0    2   -2      -0.333686          -0.266302
     1    5    3      -0.000596934        0.399453
     2    7    0       0.221462           0
     3    2    0      -0.333686           0
     4   15   10       1.1097             1.33151
    '''
    with_centering = traitlets.CBool(
        default_value=True,
        help='If True, remove the median.').tag(ui='Checkbox')
    with_scaling = traitlets.CBool(
        default_value=True,
        help='If True, scale each feature between the specified percentile range.'
    ).tag(ui='Checkbox')
    percentile_range = traitlets.Tuple(
        default_value=(25, 75),
        help='The percentile range to which to scale each feature.').tag(
            ui='FloatRangeSlider')
    prefix = traitlets.Unicode(default_value="robust_scaled_",
                               help=help_prefix).tag(ui='Text')
    center_ = traitlets.List(
        traitlets.CFloat(),
        default_value=None,
        help='The median of each feature.').tag(output=True)
    scale_ = traitlets.List(
        traitlets.CFloat(),
        default_value=None,
        help='The percentile range for each feature.').tag(output=True)

    def fit(self, df):
        '''
        Fit RobustScaler to the DataFrame.

        :param df: A vaex DataFrame.
        '''
        # check the quantile range
        q_min, q_max = self.percentile_range
        if not 0 <= q_min <= q_max <= 100:
            raise ValueError('Invalid percentile range: %s' %
                             (str(self.percentile_range)))

        if self.with_centering:
            self.center_ = df.percentile_approx(expression=self.features,
                                                percentage=50).tolist()

        if self.with_scaling:
            self.scale_ = (df.percentile_approx(expression=self.features,
                                                percentage=q_max) -
                           df.percentile_approx(expression=self.features,
                                                percentage=q_min)).tolist()

    def transform(self, df):
        '''
        Transform a DataFrame with a fitted RobustScaler.

        :param df: A vaex DataFrame.
        :returns copy: a shallow copy of the DataFrame that includes the scaled features.
        :rtype: DataFrame
        '''
        copy = df.copy()
        for i in range(len(self.features)):
            name = self.prefix + self.features[i]
            expr = copy[self.features[i]]
            if self.with_centering:
                expr = expr - self.center_[i]
            if self.with_scaling:
                expr = expr / self.scale_[i]
            copy[name] = expr
        return copy
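# ---------------------------------------------------------------------------
# Short supplementary example (illustrative): widening ``percentile_range``
# makes the scale estimate less robust to outliers but closer to the full
# data range.
def _example_robust_scaler():
    import vaex
    import vaex.ml

    df = vaex.from_arrays(x=[2, 5, 7, 2, 15], y=[-2, 3, 0, 0, 10])
    scaler = vaex.ml.RobustScaler(features=['x', 'y'],
                                  percentile_range=(10, 90))
    return scaler.fit_transform(df)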
class MolViz2DBaseWidget(MessageWidget): """ This is actually the D3.js graphics driver for the 2D base widget. It should be refactored with an abstract base class if there's a chance of adding another graphics driver. """ _view_name = Unicode('MolWidget2DView').tag(sync=True) _model_name = Unicode('MolWidget2DModel').tag(sync=True) _view_module = Unicode('nbmolviz-js').tag(sync=True) _model_module = Unicode('nbmolviz-js').tag(sync=True) charge = traitlets.Float().tag(sync=True) uuid = traitlets.Unicode().tag(sync=True) graph = traitlets.Dict().tag(sync=True) clicked_atom_index = traitlets.Int(-1).tag(sync=True) clicked_bond_indices = traitlets.Tuple((-1, -1)).tag(sync=True) _atom_colors = traitlets.Dict({}).tag(sync=True) width = traitlets.Float().tag(sync=True) height = traitlets.Float().tag(sync=True) def __init__(self, atoms, charge=-150, width=400, height=350, **kwargs): super(MolViz2DBaseWidget, self).__init__(width=width, height=height, **kwargs) try: self.atoms = atoms.atoms except AttributeError: self.atoms = atoms else: self.entity = atoms self.width = width self.height = height self.uuid = 'mol2d'+str(uuid.uuid4()) self.charge = charge self._clicks_enabled = False self.graph = self.to_graph(self.atoms) def to_graph(self, atoms): """Turn a set of atoms into a graph Should return a dict of the form {nodes:[a1,a2,a3...], links:[b1,b2,b3...]} where ai = {atom:[atom name],color='black',size=1,index:i} and bi = {bond:[order],source:[i1],dest:[i2], color/category='black',distance=22.0,strength=1.0} You can assign an explicit color with "color" OR get automatically assigned unique colors using "category" """ raise NotImplementedError("This method must be implemented by the interface class") def set_atom_style(self, atoms=None, fill_color=None, outline_color=None): if atoms is None: indices = range(len(self.atoms)) else: indices = map(self.get_atom_index, atoms) spec = {} if fill_color is not None: spec['fill'] = translate_color(fill_color, prefix='#') if outline_color is not None: spec['stroke'] = translate_color(outline_color, prefix='#') self.viewer('setAtomStyle', [indices, spec]) def set_bond_style(self, bonds, color=None, width=None, dash_length=None, opacity=None): """ :param bonds: List of atoms :param color: :param width: :param dash_length: :return: """ atom_pairs = [map(self.get_atom_index, pair) for pair in bonds] spec = {} if width is not None: spec['stroke-width'] = str(width)+'px' if color is not None: spec['stroke'] = color if dash_length is not None: spec['stroke-dasharray'] = str(dash_length)+'px' if opacity is not None: spec['opacity'] = opacity if not spec: raise ValueError('No bond style specified!') self.viewer('setBondStyle', [atom_pairs, spec]) def set_atom_label(self, atom, text=None, text_color=None, size=None, font=None): atomidx = self.get_atom_index(atom) self._change_label('setAtomLabel', atomidx, text, text_color, size, font) def set_bond_label(self, bond, text=None, text_color=None, size=None, font=None): bondids = map(self.get_atom_index, bond) self._change_label('setBondLabel', bondids, text, text_color, size, font) def _change_label(self, driver_function, obj_index, text, text_color, size, font): spec = {} if size is not None: if type(size) is not str: size = str(size)+'pt' spec['font-size'] = size if text_color is not None: spec['fill'] = text_color # this strangely doesn't always work if you send it a name if font is not None: spec['font'] = font self.viewer(driver_function, [obj_index, text, spec]) def highlight_atoms(self, atoms): indices = 
list(map(self.get_atom_index, atoms))  # list() so the indices serialize under Python 3
        self.viewer('updateHighlightAtoms', [indices])

    def get_atom_index(self, atom):
        # NotImplemented is a value, not an exception; raise NotImplementedError
        raise NotImplementedError(
            "This method must be implemented by the interface class")

    def set_click_callback(self, callback=None, enabled=True):
        """
        :param callback: Callback can have signature (), (trait_name),
            (trait_name, old), or (trait_name, old, new)
        :type callback: callable
        :param enabled:
        :return:
        """
        if not enabled:
            return

        # TODO: FIX THIS
        assert callable(callback)
        self._clicks_enabled = True
        self.on_trait_change(callback, 'clicked_atom_index')
        self.click_callback = callback

    def set_color(self, color, atoms=None, render=None):
        self.set_atom_style(fill_color=color, atoms=atoms)

    def set_colors(self, colormap, render=True):
        """
        Args:
            colormap(Mapping[str,List[Atoms]]): mapping of colors to atoms
        """
        # dict.iteritems() is Python 2 only; use items()
        for color, atoms in colormap.items():
            self.set_color(atoms=atoms, color=color)
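# ---------------------------------------------------------------------------
# Sketch of the interface contract expected by the base class (illustrative;
# the node/link dictionary shapes follow the ``to_graph`` docstring above,
# and ``atom.name``/``self.bonds`` are hypothetical attributes of the
# interface's own atom model).
class ExampleMolViz2D(MolViz2DBaseWidget):
    def to_graph(self, atoms):
        nodes = [{'atom': atom.name, 'color': 'black', 'size': 1, 'index': i}
                 for i, atom in enumerate(atoms)]
        links = [{'bond': 1, 'source': i1, 'dest': i2, 'color': 'black',
                  'distance': 22.0, 'strength': 1.0}
                 for i1, i2 in getattr(self, 'bonds', [])]
        return {'nodes': nodes, 'links': links}

    def get_atom_index(self, atom):
        return list(self.atoms).index(atom)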
class MinMaxScaler(Transformer):
    '''Will scale a set of features to a given range.

    Example:

    >>> import vaex
    >>> df = vaex.from_arrays(x=[2,5,7,2,15], y=[-2,3,0,0,10])
    >>> df
     #    x    y
     0    2   -2
     1    5    3
     2    7    0
     3    2    0
     4   15   10
    >>> scaler = vaex.ml.MinMaxScaler(features=['x', 'y'])
    >>> scaler.fit_transform(df)
     #    x    y    minmax_scaled_x    minmax_scaled_y
     0    2   -2       0                  0
     1    5    3       0.230769           0.416667
     2    7    0       0.384615           0.166667
     3    2    0       0                  0.166667
     4   15   10       1                  1
    '''
    # title = Unicode(default_value='MinMax Scaler', read_only=True).tag(ui='HTML')
    feature_range = traitlets.Tuple(
        default_value=(0, 1),
        help='The range the features are scaled to.').tag(
            ui='FloatRangeSlider')
    prefix = traitlets.Unicode(default_value="minmax_scaled_",
                               help=help_prefix).tag(ui='Text')
    fmax_ = traitlets.List(
        traitlets.CFloat(),
        help='The maximum value of a feature.').tag(output=True)
    fmin_ = traitlets.List(
        traitlets.CFloat(),
        help='The minimum value of a feature.').tag(output=True)

    def fit(self, df):
        '''
        Fit MinMaxScaler to the DataFrame.

        :param df: A vaex DataFrame.
        '''
        assert len(self.feature_range) == 2, \
            'feature_range must have 2 elements only'
        minmax = df.minmax(self.features)
        self.fmin_ = minmax[:, 0].tolist()
        self.fmax_ = minmax[:, 1].tolist()

    def transform(self, df):
        '''
        Transform a DataFrame with a fitted MinMaxScaler.

        :param df: A vaex DataFrame.
        :return copy: a shallow copy of the DataFrame that includes the scaled features.
        :rtype: DataFrame
        '''
        copy = df.copy()
        for i in range(len(self.features)):
            name = self.prefix + self.features[i]
            a = self.feature_range[0]
            b = self.feature_range[1]
            expr = copy[self.features[i]]
            expr = (b - a) * (expr - self.fmin_[i]) / \
                (self.fmax_[i] - self.fmin_[i]) + a
            copy[name] = expr
        return copy
class MinMaxScaler(Transformer):
    '''Will scale a set of features to a given range.

    Example:

    >>> import vaex
    >>> df = vaex.from_arrays(x=[2,5,7,2,15], y=[-2,3,0,0,10])
    >>> df
     #    x    y
     0    2   -2
     1    5    3
     2    7    0
     3    2    0
     4   15   10
    >>> scaler = vaex.ml.MinMaxScaler(features=['x', 'y'])
    >>> scaler.fit_transform(df)
     #    x    y    minmax_scaled_x    minmax_scaled_y
     0    2   -2       0                  0
     1    5    3       0.230769           0.416667
     2    7    0       0.384615           0.166667
     3    2    0       0                  0.166667
     4   15   10       1                  1
    '''
    snake_name = 'minmax_scaler'
    # title = Unicode(default_value='MinMax Scaler', read_only=True).tag(ui='HTML')
    feature_range = traitlets.Tuple(
        default_value=(0, 1),
        help='The range the features are scaled to.').tag(
            ui='FloatRangeSlider')
    prefix = traitlets.Unicode(default_value="minmax_scaled_",
                               help=help_prefix).tag(ui='Text')
    fmax_ = traitlets.List(
        traitlets.CFloat(),
        help='The maximum value of a feature.').tag(output=True)
    fmin_ = traitlets.List(
        traitlets.CFloat(),
        help='The minimum value of a feature.').tag(output=True)

    def fit(self, df):
        '''
        Fit MinMaxScaler to the DataFrame.

        :param df: A vaex DataFrame.
        '''
        minmax = []
        for feat in self.features:
            minmax.append(df.minmax(feat, delay=True))

        @vaex.delayed
        def assign(minmax):
            self.fmin_ = [elem[0] for elem in minmax]
            self.fmax_ = [elem[1] for elem in minmax]

        assign(minmax)
        df.execute()

    def transform(self, df):
        '''
        Transform a DataFrame with a fitted MinMaxScaler.

        :param df: A vaex DataFrame.
        :return copy: a shallow copy of the DataFrame that includes the scaled features.
        :rtype: DataFrame
        '''
        copy = df.copy()
        for i, feature in enumerate(self.features):
            name = self.prefix + feature
            a = self.feature_range[0]
            b = self.feature_range[1]
            expr = copy[feature]
            expr = (b - a) * (expr - self.fmin_[i]) / \
                (self.fmax_[i] - self.fmin_[i]) + a
            copy[name] = expr
        return copy
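# ---------------------------------------------------------------------------
# Short supplementary example (illustrative): scaling into a symmetric range
# instead of the default (0, 1).
def _example_minmax_symmetric():
    import vaex
    import vaex.ml

    df = vaex.from_arrays(x=[2, 5, 7, 2, 15], y=[-2, 3, 0, 0, 10])
    scaler = vaex.ml.MinMaxScaler(features=['x', 'y'], feature_range=(-1, 1))
    return scaler.fit_transform(df)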
class Widget(DOMWidget):
    _view_name = traitlets.Unicode('ReboundView').tag(sync=True)
    _view_module = traitlets.Unicode('rebound').tag(sync=True)
    count = traitlets.Int(0).tag(sync=True)
    screenshotcount = traitlets.Int(0).tag(sync=True)
    t = traitlets.Float().tag(sync=True)
    N = traitlets.Int().tag(sync=True)
    width = traitlets.Float().tag(sync=True)
    height = traitlets.Float().tag(sync=True)
    scale = traitlets.Float().tag(sync=True)
    particle_data = traitlets.CBytes(allow_none=True).tag(sync=True)
    orbit_data = traitlets.CBytes(allow_none=True).tag(sync=True)
    orientation = traitlets.Tuple().tag(sync=True)
    orbits = traitlets.Int().tag(sync=True)
    screenshot = traitlets.Unicode().tag(sync=True)

    def __init__(self,
                 simulation,
                 size=(200, 200),
                 orientation=(0., 0., 0., 1.),
                 scale=None,
                 autorefresh=True,
                 orbits=True):
        """
        Initializes a Widget.

        Widgets provide real-time 3D interactive visualizations for REBOUND
        simulations within Jupyter Notebooks. To use widgets, the ipywidgets
        package needs to be installed and enabled in your Jupyter notebook
        server.

        Parameters
        ----------
        size : (int, int), optional
            Specify the size of the widget in pixels. The default is 200 by
            200 pixels.
        orientation : (float, float, float, float), optional
            Specify the initial orientation of the view. The four floats
            correspond to the x, y, z, and w components of a quaternion. The
            quaternion will be normalized.
        scale : float, optional
            Set the initial scale of the view. If not set, the widget will
            determine the scale automatically based on current particle
            positions.
        autorefresh : bool, optional
            The default value is True. The view is updated whenever a particle
            is added or removed, and every 100th of a second while a
            simulation is running. If set to False, the user needs to manually
            call the refresh() function on the widget. This might be useful if
            performance is an issue.
        orbits : bool, optional
            The default value is True, in which case the widget will draw the
            instantaneous orbits of the particles. For simulations in which
            particles are not on Keplerian orbits, the orbits shown will not
            be accurate.
        """
        self.screenshotcountall = 0
        self.width, self.height = size
        self.t, self.N = simulation.t, simulation.N
        self.orientation = orientation
        self.autorefresh = autorefresh
        self.orbits = orbits
        self.simp = pointer(simulation)
        clibrebound.reb_display_copy_data.restype = c_int
        if scale is None:
            self.scale = simulation.display_data.contents.scale
        else:
            self.scale = scale
        self.count += 1
        super(Widget, self).__init__()

    def refresh(self, simp=None, isauto=0):
        """
        Manually refreshes a widget.

        Note that this function can also be called using the wrapper function
        of the Simulation object: sim.refreshWidgets().
        """
        if simp is None:
            simp = self.simp
        if self.autorefresh == 0 and isauto == 1:
            return
        sim = simp.contents
        size_changed = clibrebound.reb_display_copy_data(simp)
        clibrebound.reb_display_prepare_data(simp, c_int(self.orbits))
        if sim.N > 0:
            self.particle_data = (c_char * (4 * 7 * sim.N)).from_address(
                sim.display_data.contents.particle_data).raw
            if self.orbits:
                self.orbit_data = (
                    c_char * (4 * 9 * (sim.N - 1))).from_address(
                        sim.display_data.contents.orbit_data).raw
        if size_changed:
            # TODO: Implement better GPU size change
            pass
        self.N = sim.N
        self.t = sim.t
        self.count += 1

    def takeScreenshot(self,
                       times=None,
                       prefix="./screenshot",
                       resetCounter=False,
                       archive=None,
                       mode="snapshot"):
        """
        Take one or more screenshots of the widget and save the images to a
        file. The images can be used to create a video.
        This function cannot be called multiple times within one cell.

        Note: this is a new feature and might not work on all systems. It was
        tested with python 2.7.10 and 3.5.2 on MacOSX.

        Parameters
        ----------
        times : (float, list), optional
            If this argument is not given, a screenshot of the widget will be
            made as it is (without integrating the simulation). If a float is
            given, the simulation will be integrated to that time and then a
            screenshot will be taken. If a list of floats is given, the
            simulation will be integrated to each time specified in the list
            and a separate screenshot saved for each time.
        prefix : (str), optional
            This string will be part of the output filename for each image,
            followed by a five-digit integer and the suffix .png. By default
            the prefix is './screenshot', which outputs images in the current
            directory with the filenames screenshot00000.png,
            screenshot00001.png, ... Note that the prefix can include a
            directory.
        resetCounter : (bool), optional
            Resets the output counter to 0.
        archive : (rebound.SimulationArchive), optional
            Use a REBOUND SimulationArchive. Instead of integrating the
            Simulation from the current time, snapshots are loaded from the
            SimulationArchive. See the examples for usage.
        mode : (string), optional
            Mode to use when querying the SimulationArchive. See the
            SimulationArchive documentation for details. The default value is
            "snapshot".

        Examples
        --------
        First, create a simulation and widget. All of the following can go in
        one cell.

        >>> sim = rebound.Simulation()
        >>> sim.add(m=1.)
        >>> sim.add(m=1.e-3, x=1., vy=1.)
        >>> w = sim.getWidget()
        >>> w

        The widget should show up. To take a screenshot, simply call

        >>> w.takeScreenshot()

        A new file with the name screenshot00000.png will appear in the
        current directory. Note that the takeScreenshot command needs to be in
        a separate cell, i.e. after you see the widget.

        You can pass a list of times to the function. This allows you to take
        multiple screenshots, for example to create a movie:

        >>> times = [0, 10, 100]
        >>> w.takeScreenshot(times)
        """
        self.archive = archive
        if resetCounter:
            self.screenshotcountall = 0
        self.screenshotprefix = prefix
        self.screenshotcount = 0
        self.screenshot = ""
        if archive is None:
            if times is None:
                times = self.simp.contents.t
            try:
                len(times)  # a list of times was passed
            except TypeError:
                times = [times]  # a single float was passed
            self.times = times
            self.observe(savescreenshot, names="screenshot")
            self.simp.contents.integrate(times[0])
            self.screenshotcount += 1  # triggers first screenshot
        else:
            if times is None:
                raise ValueError("Need times argument for archive mode.")
            try:
                len(times)
            except TypeError:
                raise ValueError("Need a list of times for archive mode.")
            self.times = times
            self.mode = mode
            self.observe(savescreenshot, names="screenshot")
            sim = archive.getSimulation(times[0], mode=mode)
            self.refresh(pointer(sim))
            self.screenshotcount += 1  # triggers first screenshot

    @staticmethod
    def getClientCode():
        return shader_code + js_code
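# --- Usage sketch for archive mode (hedged) ---
# The archive path and frame prefix are illustrative; rebound's
# SimulationArchive and sim.getWidget() are assumed per the docstring above.
# Display the widget in one cell, then take screenshots in a later cell.
import rebound

sa = rebound.SimulationArchive("archive.bin")  # hypothetical archive file
sim = sa[0]
w = sim.getWidget()
w

# In a separate cell:
# w.takeScreenshot([0., 10., 100.], prefix="./frames/frame",
#                  archive=sa, mode="snapshot")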
class Unit(traitlets.HasTraits):
    attack_power = traitlets.Integer(default_value=3)
    hit_points = traitlets.Integer(default_value=200)
    location = traitlets.Tuple(traitlets.Integer(), traitlets.Integer())  # y, x
    dead = traitlets.Bool(default_value=False)
    members = []
    opponents = traitlets.Type('__main__.Unit')

    @classmethod
    def append(cls, other):
        cls.members.append(other)

    def attack(self, other):
        other.hit_points -= self.attack_power
        if other.hit_points <= 0:
            other.dead = True
            self.opponents.members.remove(other)
            print(self, 'killed', other)

    def distance(self, other):
        return cityblock(self.location, other.location)

    @property
    def target(self):
        opponent_distances = [
            self.distance(foe) for foe in self.opponents.members
        ]
        potential_targets = [
            foe for foe, distance in zip(self.opponents.members,
                                         opponent_distances)
            if distance == 1
        ]
        if not potential_targets:
            return None
        elif len(potential_targets) == 1:
            return potential_targets[0]
        else:
            # break ties by hit points, then by reading order (y, x)
            return sorted(
                potential_targets,
                key=lambda u: (u.hit_points, *u.location)
            )[0]

    def move(self):
        # first, block out your buddies
        current_dungeon = DUNGEON.copy()
        allies = np.array([
            friend.location for friend in self.members if friend is not self
        ])
        if allies.size:
            # locations are stored as y, x, so:
            current_dungeon[allies[:, 0], allies[:, 1]] = -1
        foe_locations = np.array([
            foe.location for foe in self.opponents.members
        ])
        # and now find the costs
        mcp = MCP(current_dungeon, fully_connected=False)
        cum_costs, traceback = mcp.find_costs(
            starts=[self.location],
            # ends=foe_locations,
            find_all_ends=True
        )
        foe_distances = cum_costs[foe_locations[:, 0], foe_locations[:, 1]]
        if np.isinf(foe_distances.min()):
            return  # no route available to any foe
        closest_foes = np.arange(len(foe_distances))[
            foe_distances == foe_distances.min()]
        closest_foe = sorted(
            self.opponents.members[i] for i in closest_foes
        )[0]
        # now you have one closest foe; reverse the distance calculation
        # and move one step closer
        mcp = MCP(current_dungeon, fully_connected=False)
        cum_costs, traceback = mcp.find_costs(
            ends=[self.location],
            starts=[closest_foe.location],
            find_all_ends=False
        )
        target_locations = np.argwhere(cum_costs == foe_distances.min() - 1)
        valid_locations = target_locations[(
            (target_locations >= np.array(self.location) - 1) &
            (target_locations <= np.array(self.location) + 1)
        ).all(axis=1)]
        y, x = (sorted(tuple(coords) for coords in valid_locations))[0]
        print(self, 'moving to', y, x)
        self.location = (int(y), int(x))

    def __eq__(self, other):
        return (*self.location, self.hit_points) == (*other.location,
                                                     other.hit_points)

    def __lt__(self, other):
        return (*self.location, self.hit_points) < (*other.location,
                                                    other.hit_points)

    def __gt__(self, other):
        return (*self.location, self.hit_points) > (*other.location,
                                                    other.hit_points)

    def __repr__(self):
        return (f'<{self.__class__.__name__} ap{self.attack_power} '
                f'hp{self.hit_points} loc{self.location}>')

    def __add__(self, other):
        return self.hit_points + other.hit_points

    def __radd__(self, other):
        return self.hit_points + other
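# --- Usage sketch (hedged) ---
# Unit's dunder methods make two common operations trivial: sorted() yields
# reading order (y, then x) via __lt__, and sum() totals hit points via
# __add__/__radd__. The Elf/Goblin subclasses below are illustrative, not
# from the snippet above.
class Elf(Unit):
    members = []
    opponents = traitlets.Type('__main__.Goblin')

class Goblin(Unit):
    members = []
    opponents = traitlets.Type('__main__.Elf')

for y, x in [(1, 3), (1, 1), (2, 2)]:
    Elf.append(Elf(location=(y, x)))

turn_order = sorted(Elf.members)  # reading order: (1, 1), (1, 3), (2, 2)
total_hp = sum(Elf.members)       # 600 with the default 200 hp per unit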
class OptimadeQueryFilterWidget(  # pylint: disable=too-many-instance-attributes
        ipw.VBox):
    """Structure search and import widget for OPTIMADE

    NOTE: Only supports offset- and number-pagination at the moment.
    """

    structure = traitlets.Instance(Structure, allow_none=True)
    database = traitlets.Tuple(
        traitlets.Unicode(),
        traitlets.Instance(LinksResourceAttributes, allow_none=True),
    )

    def __init__(
        self,
        result_limit: int = None,
        button_style: Union[ButtonStyle, str] = None,
        embedded: bool = False,
        subparts_order: List[str] = None,
        **kwargs,
    ):
        self.page_limit = result_limit if result_limit else 25
        if button_style:
            if isinstance(button_style, str):
                button_style = ButtonStyle[button_style.upper()]
            elif isinstance(button_style, ButtonStyle):
                pass
            else:
                raise TypeError(
                    "button_style should be either a string or a ButtonStyle Enum. "
                    f"You passed type {type(button_style)!r}.")
        else:
            button_style = ButtonStyle.PRIMARY

        subparts_order = subparts_order or QueryFilterWidgetOrder.default_order(
            as_str=True)

        self.offset = 0
        self.number = 1
        self._data_available = None
        self.__perform_query = True
        self.__cached_ranges = {}
        self.__cached_versions = {}
        self.database_version = ""

        self.filter_header = ipw.HTML(
            '<h4 style="margin:0px;padding:0px;">Apply filters</h4>')
        self.filters = FilterTabs(show_large_filters=not embedded)
        self.filters.freeze()
        self.filters.on_submit(self.retrieve_data)

        self.query_button = ipw.Button(
            description="Search",
            button_style=button_style.value,
            icon="search",
            disabled=True,
            tooltip="Search - No database chosen",
        )
        self.query_button.on_click(self.retrieve_data)

        self.structures_header = ipw.HTML(
            '<h4 style="margin-bottom:0px;padding:0px;">Results</h4>')

        self.sort_selector = SortSelector(disabled=True)
        self.sorting = self.sort_selector.value
        self.sort_selector.observe(self._sort, names="value")

        self.structure_drop = StructureDropdown(disabled=True)
        self.structure_drop.observe(self._on_structure_select, names="value")
        self.error_or_status_messages = ipw.HTML("")

        self.structure_page_chooser = ResultsPageChooser(self.page_limit)
        self.structure_page_chooser.observe(
            self._get_more_results,
            names=["page_link", "page_offset", "page_number"])

        for subpart in subparts_order:
            if not hasattr(self, subpart):
                raise ValueError(
                    f"Wrongly specified subparts_order: {subpart!r}. Available subparts "
                    f"(and default order): {QueryFilterWidgetOrder.default_order(as_str=True)}"
                )

        super().__init__(
            children=[getattr(self, _) for _ in subparts_order],
            layout=ipw.Layout(width="auto", height="auto"),
            **kwargs,
        )

    @traitlets.observe("database")
    def _on_database_select(self, _):
        """Load chosen database"""
        self.structure_drop.reset()

        if (self.database[1] is None
                or getattr(self.database[1], "base_url", None) is None):
            self.query_button.tooltip = "Search - No database chosen"
            self.freeze()
        else:
            self.offset = 0
            self.number = 1
            self.structure_page_chooser.silent_reset()
            try:
                self.freeze()

                self.query_button.description = "Updating ..."
                self.query_button.icon = "cog"
                self.query_button.tooltip = "Updating filters ..."
                self._set_intslider_ranges()
                self._set_version()
            except Exception as exc:  # pylint: disable=broad-except
                LOGGER.exception(
                    "Exception raised while setting IntSlider ranges: %s",
                    exc)
            finally:
                self.query_button.description = "Search"
                self.query_button.icon = "search"
                self.query_button.tooltip = "Search"
                self.sort_selector.valid_fields = sorted(
                    get_sortable_fields(self.database[1].base_url))
                self.unfreeze()

    def _on_structure_select(self, change):
        """Update structure trait with chosen structure dropdown value"""
        chosen_structure = change["new"]
        if chosen_structure is None:
            self.structure = None
            with self.hold_trait_notifications():
                self.structure_drop.index = 0
        else:
            self.structure = chosen_structure["structure"]

    def _get_more_results(self, change):
        """Query for more results according to paging"""
        if not self.__perform_query:
            self.__perform_query = True
            LOGGER.debug(
                "NOT going to perform query with change: name=%s value=%s",
                change["name"],
                change["new"],
            )
            return

        paging: Union[int, str] = change["new"]
        LOGGER.debug(
            "Updating results with paging change: name=%s value=%s",
            change["name"],
            paging,
        )
        if change["name"] == "page_offset":
            self.offset = paging
            paging = None
        elif change["name"] == "page_number":
            self.number = paging
            paging = None
        else:
            # We need to update page_offset, but we do not wish to query again
            with self.hold_trait_notifications():
                self.__perform_query = False
                self.structure_page_chooser.update_offset()

        try:
            # Freeze and disable list of structures in dropdown widget.
            # We don't want changes leading to weird things happening prior to
            # the query ending.
            self.freeze()

            # Update button text and icon
            self.query_button.description = "Updating ... "
            self.query_button.icon = "cog"
            self.query_button.tooltip = "Please wait ..."
            # Query database
            response = self._query(paging)
            msg, _ = handle_errors(response)
            if msg:
                self.error_or_status_messages.value = msg
                return

            # Update list of structures in dropdown widget
            self._update_structures(response["data"])

            # Update paging
            self.structure_page_chooser.set_pagination_data(
                links_to_page=response.get("links", {}),
            )
        finally:
            self.query_button.description = "Search"
            self.query_button.icon = "search"
            self.query_button.tooltip = "Search"
            self.unfreeze()

    def _sort(self, change: dict) -> None:
        """Perform a new query with the new sorting"""
        sort = change["new"]
        self.sorting = sort
        self.retrieve_data({})

    def freeze(self):
        """Disable widget"""
        self.query_button.disabled = True
        self.filters.freeze()
        self.structure_drop.freeze()
        self.structure_page_chooser.freeze()
        self.sort_selector.freeze()

    def unfreeze(self):
        """Activate widget (in its current state)"""
        self.query_button.disabled = False
        self.filters.unfreeze()
        self.structure_drop.unfreeze()
        self.structure_page_chooser.unfreeze()
        self.sort_selector.unfreeze()

    def reset(self):
        """Reset widget"""
        self.offset = 0
        self.number = 1
        with self.hold_trait_notifications():
            self.query_button.disabled = False
            self.query_button.tooltip = "Search - No database chosen"
            self.filters.reset()
            self.structure_drop.reset()
            self.structure_page_chooser.reset()
            self.sort_selector.reset()

    def _uses_new_structure_features(self) -> bool:
        """Check whether self.database_version is >= v1.0.0-rc.2"""
        critical_version = SemanticVersion("1.0.0-rc.2")
        version = SemanticVersion(self.database_version)
        LOGGER.debug("Semantic version: %r", version)
        if version.base_version > critical_version.base_version:
            # Major.Minor.Patch is higher than the critical version
            return True
        if version.base_version == critical_version.base_version:
            if version.prerelease:
                return version.prerelease >= critical_version.prerelease
            # Version equals the critical base version and is not a pre-release
            return True
        # Major.Minor.Patch is lower than the critical version
        return False

    def _set_version(self):
        """Set self.database_version from an /info query"""
        base_url = self.database[1].base_url
        if base_url not in self.__cached_versions:
            # Retrieve and cache version
            response = perform_optimade_query(
                base_url=self.database[1].base_url, endpoint="/info")
            msg, _ = handle_errors(response)
            if msg:
                raise QueryError(msg)

            if "meta" not in response:
                raise QueryError(
                    f"'meta' field not found in /info endpoint for base URL: {base_url}"
                )
            if "api_version" not in response["meta"]:
                raise QueryError(
                    f"'api_version' field not found in 'meta' for base URL: {base_url}"
                )

            version = response["meta"]["api_version"]
            if version.startswith("v"):
                version = version[1:]
            self.__cached_versions[base_url] = version
            LOGGER.debug(
                "Cached version %r for base URL: %r",
                self.__cached_versions[base_url],
                base_url,
            )

        self.database_version = self.__cached_versions[base_url]

    def _set_intslider_ranges(self):
        """Update IntRangeSlider ranges according to the chosen database

        Query the database to retrieve ranges.
        Cache ranges in self.__cached_ranges.
""" defaults = { "nsites": { "min": 0, "max": 10000 }, "nelements": { "min": 0, "max": len(CHEMICAL_SYMBOLS) }, } db_base_url = self.database[1].base_url if db_base_url not in self.__cached_ranges: self.__cached_ranges[db_base_url] = {} sortable_fields = check_entry_properties( base_url=db_base_url, entry_endpoint="structures", properties=["nsites", "nelements"], checks=["sort"], ) for response_field in sortable_fields: if response_field in self.__cached_ranges[db_base_url]: # Use cached value(s) continue page_limit = 1 new_range = {} for extremum, sort in [ ("min", response_field), ("max", f"-{response_field}"), ]: query_params = { "base_url": db_base_url, "page_limit": page_limit, "response_fields": response_field, "sort": sort, } LOGGER.debug( "Querying %s to get %s of %s.\nParameters: %r", self.database[0], extremum, response_field, query_params, ) response = perform_optimade_query(**query_params) msg, _ = handle_errors(response) if msg: raise QueryError(msg) if not response.get("meta", {}).get("data_available", 0): new_range[extremum] = defaults[response_field][extremum] else: new_range[extremum] = (response.get("data", [{}])[0].get( "attributes", {}).get(response_field, None)) # Cache new values LOGGER.debug( "Caching newly found range values for %s\nValue: %r", db_base_url, {response_field: new_range}, ) self.__cached_ranges[db_base_url].update( {response_field: new_range}) if not self.__cached_ranges[db_base_url]: LOGGER.debug("No values found for %s, storing default values.", db_base_url) self.__cached_ranges[db_base_url].update({ "nsites": { "min": 0, "max": 10000 }, "nelements": { "min": 0, "max": len(CHEMICAL_SYMBOLS) }, }) # Set widget's new extrema LOGGER.debug( "Updating range extrema for %s\nValues: %r", db_base_url, self.__cached_ranges[db_base_url], ) self.filters.update_range_filters(self.__cached_ranges[db_base_url]) def _query(self, link: str = None) -> dict: """Query helper function""" # If a complete link is provided, use it straight up if link is not None: try: link = ordered_query_url(link) response = SESSION.get(link, timeout=TIMEOUT_SECONDS) if response.from_cache: LOGGER.debug("Request to %s was taken from cache !", link) response = response.json() except ( requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout, ) as exc: response = { "errors": { "msg": "CLIENT: Connection error or timeout.", "url": link, "Exception": repr(exc), } } except JSONDecodeError as exc: response = { "errors": { "msg": "CLIENT: Could not decode response to JSON.", "url": link, "Exception": repr(exc), } } return response # Avoid structures with null positions and with assemblies. 
add_to_filter = 'NOT structure_features HAS ANY "assemblies"' if not self._uses_new_structure_features(): add_to_filter += ',"unknown_positions"' optimade_filter = self.filters.collect_value() optimade_filter = ("( {} ) AND ( {} )".format(optimade_filter, add_to_filter) if optimade_filter and add_to_filter else optimade_filter or add_to_filter or None) LOGGER.debug("Querying with filter: %s", optimade_filter) # OPTIMADE queries queries = { "base_url": self.database[1].base_url, "filter": optimade_filter, "page_limit": self.page_limit, "page_offset": self.offset, "page_number": self.number, "sort": self.sorting, } LOGGER.debug( "Parameters (excluding filter) sent to query util func: %s", {key: value for key, value in queries.items() if key != "filter"}, ) return perform_optimade_query(**queries) @staticmethod def _check_species_mass(structure: dict) -> dict: """Ensure species.mass is using OPTIMADE API v1.0.1 type""" if structure.get("attributes", {}).get("species", False): for species in structure["attributes"]["species"] or []: if not isinstance(species.get("mass", None), (list, type(None))): species.pop("mass", None) return structure def _update_structures(self, data: list): """Update structures dropdown from response data""" structures = [] for entry in data: # XXX: THIS IS TEMPORARY AND SHOULD BE REMOVED ASAP entry["attributes"]["chemical_formula_anonymous"] = None structure = Structure(self._check_species_mass(entry)) formula = structure.attributes.chemical_formula_descriptive if formula is None: formula = structure.attributes.chemical_formula_reduced if formula is None: formula = structure.attributes.chemical_formula_anonymous if formula is None: formula = structure.attributes.chemical_formula_hill if formula is None: raise BadResource( resource=structure, fields=[ "chemical_formula_descriptive", "chemical_formula_reduced", "chemical_formula_anonymous", "chemical_formula_hill", ], msg="At least one of the following chemical formula fields " "should have a valid value", ) entry_name = f"{formula} (id={structure.id})" structures.append((entry_name, {"structure": structure})) # Update list of structures in dropdown widget self.structure_drop.set_options(structures) def retrieve_data(self, _): """Perform query and retrieve data""" self.offset = 0 self.number = 1 try: # Freeze and disable list of structures in dropdown widget # We don't want changes leading to weird things happening prior to the query ending self.freeze() # Reset the error or status message if self.error_or_status_messages.value: self.error_or_status_messages.value = "" # Update button text and icon self.query_button.description = "Querying ... " self.query_button.icon = "cog" self.query_button.tooltip = "Please wait ..." 
            # Query database
            response = self._query()
            msg, _ = handle_errors(response)
            if msg:
                self.error_or_status_messages.value = msg
                raise QueryError(msg)

            # Update list of structures in dropdown widget
            self._update_structures(response["data"])

            # Update paging
            if self._data_available is None:
                self._data_available = response.get("meta", {}).get(
                    "data_available", None)
            data_returned = response.get("meta", {}).get(
                "data_returned", len(response.get("data", [])))
            self.structure_page_chooser.set_pagination_data(
                data_returned=data_returned,
                data_available=self._data_available,
                links_to_page=response.get("links", {}),
                reset_cache=True,
            )
        except QueryError:
            self.structure_drop.reset()
            self.structure_page_chooser.reset()
            raise
        except Exception as exc:
            self.structure_drop.reset()
            self.structure_page_chooser.reset()
            raise QueryError(
                f"Unexpected error during query: {traceback.format_exc()}"
            ) from exc
        finally:
            self.query_button.description = "Search"
            self.query_button.icon = "search"
            self.query_button.tooltip = "Search"
            self.unfreeze()
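# --- Usage sketch (hedged; notebook cell) ---
# How the widget above is typically wired up. The database tuple contents and
# `mp_links_resource` are illustrative; in practice the
# (name, LinksResourceAttributes) pair would come from an OPTIMADE providers
# list.
from IPython.display import display

query_widget = OptimadeQueryFilterWidget(result_limit=10,
                                         button_style="primary")

def on_structure(change):
    structure = change["new"]
    if structure is not None:
        print("Chosen structure:", structure.id)

query_widget.observe(on_structure, names="structure")

# Assigning the `database` trait triggers the /info version lookup and the
# IntSlider range updates handled above:
# query_widget.database = ("Example provider", mp_links_resource)

display(query_widget)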