Example #1
0
class DownloadChooser(ipw.HBox):
    """Download chooser for structure download

    To be able to have the download button work no matter the widget's final environment,
    (as long as it supports JavaScript), the very helpful insight from the following page is used:
    https://stackoverflow.com/questions/2906582/how-to-create-an-html-button-that-acts-like-a-link

    The widget is a format `Dropdown` next to an `HTML` "Download" button. Choosing a format
    converts `structure` to that format, base64-encodes the result and embeds it as a `data:`
    link in the button's `onclick` handler, so clicking the button only runs JavaScript.
    """

    chosen_format = traitlets.Tuple(traitlets.Unicode(), traitlets.Dict())
    structure = traitlets.Instance(Structure, allow_none=True)

    # First dropdown entry. Its empty-dict value means "no format chosen", which keeps
    # the download button disabled (see `_update_download_button()`).
    _hint_option = ("Select a format", {})

    _formats = [
        (
            "Crystallographic Information File v1.0 (.cif)",
            {
                "ext": ".cif",
                "adapter_format": "cif"
            },
        ),
        ("Protein Data Bank (.pdb)", {
            "ext": ".pdb",
            "adapter_format": "pdb"
        }),
        (
            "Crystallographic Information File v1.0 [via ASE] (.cif)",
            {
                "ext": ".cif",
                "adapter_format": "ase",
                "final_format": "cif"
            },
        ),
        (
            "Protein Data Bank [via ASE] (.pdb)",
            {
                "ext": ".pdb",
                "adapter_format": "ase",
                "final_format": "proteindatabank"
            },
        ),
        (
            "XMol XYZ File [via ASE] (.xyz)",
            {
                "ext": ".xyz",
                "adapter_format": "ase",
                "final_format": "xyz"
            },
        ),
        (
            "XCrySDen Structure File [via ASE] (.xsf)",
            {
                "ext": ".xsf",
                "adapter_format": "ase",
                "final_format": "xsf"
            },
        ),
        (
            "WIEN2k Structure File [via ASE] (.struct)",
            {
                "ext": ".struct",
                "adapter_format": "ase",
                "final_format": "struct"
            },
        ),
        (
            "VASP POSCAR File [via ASE]",
            {
                "ext": "",
                "adapter_format": "ase",
                "final_format": "vasp"
            },
        ),
        (
            "Quantum ESPRESSO File [via ASE] (.in)",
            {
                "ext": ".in",
                "adapter_format": "ase",
                "final_format": "espresso-in"
            },
        ),
        # Not yet implemented:
        # (
        #     "Protein Data Bank, macromolecular CIF v1.1 (PDBx/mmCIF) (.cif)",
        #     {"ext": "cif", "adapter_format": "pdbx_mmcif"},
        # ),
    ]
    _download_button_format = """
<input type="button" class="jupyter-widgets jupyter-button widget-button" value="Download" title="Download structure" style="width:auto;" {disabled}
onclick="var link = document.createElement('a');
link.href = 'data:charset={encoding};base64,{data}';
link.download = '{filename}';
document.body.appendChild(link);
link.click();
document.body.removeChild(link);" />
"""

    def __init__(self, **kwargs):
        """Create the format dropdown and the (initially disabled) download button

        Keyword arguments other than `children` are forwarded to `ipw.HBox`. A `layout`
        given by the caller replaces the default `{"width": "auto"}`.
        """
        # `options` must be a sequence of (label, value) pairs. Passing the bare
        # ("Select a format", {}) pair would create two separate options instead.
        self.dropdown = ipw.Dropdown(
            options=[self._hint_option], layout={"width": "auto"}
        )
        self.download_button = ipw.HTML(self._button_html())

        # The children are fixed by this widget.
        kwargs.pop("children", None)
        kwargs.setdefault("layout", {"width": "auto"})
        super().__init__(children=(self.dropdown, self.download_button), **kwargs)

        # A `structure` passed as keyword argument has already populated and unfrozen
        # the widget through `_on_change_structure()`. Only reset when there is none.
        if self.structure is None:
            self.reset()

        self.dropdown.observe(self._update_download_button, names="value")

    @traitlets.observe("structure")
    def _on_change_structure(self, change: dict):
        """Update widget when a new structure is chosen"""
        if change["new"] is None:
            self.reset()
        else:
            self._update_options()
            self.unfreeze()

    def _update_options(self):
        """Update options according to chosen structure

        Formats are listed alphabetically by label after the hint option. ASE-based
        formats are left out for structures listing "disorder" in their
        `structure_features`, since disordered structures are not usable with ASE.
        """
        formats = self._formats
        if "disorder" in (self.structure.structure_features or []):
            formats = [
                option for option in formats
                if option[1].get("adapter_format", "") != "ase"
            ]
        self.dropdown.options = [self._hint_option] + sorted(
            formats, key=lambda option: option[0])

    def _button_html(
        self,
        enabled: bool = False,
        encoding: str = "",
        data: str = "",
        filename: str = "",
    ) -> str:
        """Render `_download_button_format`; the button is disabled unless `enabled`"""
        return self._download_button_format.format(
            disabled="" if enabled else "disabled",
            encoding=encoding,
            data=data,
            filename=filename,
        )

    def _encode_structure(self, desired_format: dict) -> Tuple[str, str, str]:
        """Convert `structure` into `desired_format`

        Parameters:
            desired_format: One of the option values in `_formats`, i.e. a dict with the
                keys "ext", "adapter_format" and (for proxy-class adapters) "final_format".

        Returns:
            The text encoding used, the base64-encoded file content and the file name.
        """
        adapter_format = desired_format["adapter_format"]
        output = getattr(self.structure, f"as_{adapter_format}")

        if adapter_format in ("ase", "pymatgen", "aiida_structuredata"):
            # output is not a file, but a proxy Python class
            converter = getattr(self, f"_get_via_{adapter_format}")
            output = converter(output, desired_format=desired_format["final_format"])

        # Specifically for CIF: v1.x CIF needs to be in "latin-1" formatting
        encoding = "latin-1" if desired_format["ext"] == ".cif" else "utf-8"

        filename = f"optimade_structure_{self.structure.id}{desired_format['ext']}"

        if isinstance(output, str):
            output = output.encode(encoding)
        return encoding, base64.b64encode(output).decode(), filename

    def _update_download_button(self, change: dict):
        """Update Download button with correct onclick value

        The whole parsing process from `Structure` to desired format, is wrapped in a try/except,
        which is further wrapped in a `warnings.catch_warnings()`.
        This is in order to be able to log any warnings that might be thrown by the adapter in
        `optimade-python-tools` and/or any related exceptions.

        Any warning or exception leaves the button disabled. A caught warning is re-emitted as
        `OptimadeClientWarning` only *after* leaving the `catch_warnings()` context, because
        inside it the "error" filter would turn the re-emitted warning into a raised exception.

        Raises:
            exceptions.OptimadeClientError: If the conversion fails. Non-client errors are
                wrapped so they get logged.
        """
        desired_format = change["new"]
        if not desired_format:
            # Hint option (empty dict) or no value: nothing to download
            self.download_button.value = self._button_html()
            return

        caught_warning = None
        with warnings.catch_warnings():
            warnings.filterwarnings("error")

            try:
                encoding, data, filename = self._encode_structure(desired_format)
            except Warning as warn:
                caught_warning = warn
            except Exception as exc:
                self.download_button.value = self._button_html()
                if isinstance(exc, exceptions.OptimadeClientError):
                    raise
                # Else wrap the exception to make sure to log it.
                raise exceptions.OptimadeClientError(exc) from exc

        if caught_warning is not None:
            self.download_button.value = self._button_html()
            warnings.warn(OptimadeClientWarning(caught_warning))
            return

        self.download_button.value = self._button_html(
            enabled=True, encoding=encoding, data=data, filename=filename)

    @staticmethod
    def _get_via_pymatgen(
        structure_molecule: Union[pymatgenStructure, pymatgenMolecule],
        desired_format: str,
    ) -> str:
        """Use pymatgen.[Structure,Molecule].to() method

        Raises:
            exceptions.WrongPymatgenType: If `desired_format` is only supported by the
                other pymatgen class ("xyz"/"pdb" need a Molecule, "xsf"/"cif" a Structure).
        """
        molecule_only_formats = ["xyz", "pdb"]
        structure_only_formats = ["xsf", "cif"]
        if desired_format in molecule_only_formats and not isinstance(
                structure_molecule, pymatgenMolecule):
            raise exceptions.WrongPymatgenType(
                f"Converting to '{desired_format}' format is only possible with a pymatgen."
                f"Molecule, instead got {type(structure_molecule)}")
        if desired_format in structure_only_formats and not isinstance(
                structure_molecule, pymatgenStructure):
            raise exceptions.WrongPymatgenType(
                f"Converting to '{desired_format}' format is only possible with a pymatgen."
                f"Structure, instead got {type(structure_molecule)}.")

        return structure_molecule.to(fmt=desired_format)

    @staticmethod
    def _get_via_ase(atoms: aseAtoms,
                     desired_format: str) -> Union[str, bytes]:
        """Use ase.Atoms.write() method

        ASE writes to a file path inside a private temporary directory, which is then read
        back as bytes. Unlike re-opening a `NamedTemporaryFile` by name, this also works on
        platforms where an open temporary file cannot be opened a second time (Windows).
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "structure")
            atoms.write(path, format=desired_format)
            with open(path, "rb") as handle:
                return handle.read()

    def freeze(self):
        """Disable widget"""
        # NOTE(review): the download button is an `ipw.HTML`; its clickability is
        # governed by the `disabled` attribute in its markup, so setting `.disabled` on it
        # may have no visible effect — confirm.
        for widget in self.children:
            widget.disabled = True

    def unfreeze(self):
        """Activate widget (in its current state)"""
        for widget in self.children:
            widget.disabled = False

    def reset(self):
        """Reset widget"""
        self.dropdown.index = 0
        self.freeze()
class ProviderImplementationChooser(  # pylint: disable=too-many-instance-attributes
        ipw.VBox):
    """List all OPTIMADE providers and their implementations

    The widget stacks:
    - a provider dropdown;
    - a (possibly grouped) dropdown of the chosen provider's child databases
      ("implementations"), as listed by the provider's `/links` endpoint;
    - a page chooser for those child databases;
    - an HTML area for error/status messages.

    Traits:
        provider: `LinksResourceAttributes` of the currently chosen provider.
        database: `(name, attributes)` of the currently chosen child database,
            or `("", None)` when none is chosen.
    """

    provider = traitlets.Instance(LinksResourceAttributes, allow_none=True)
    database = traitlets.Tuple(
        traitlets.Unicode(),
        traitlets.Instance(LinksResourceAttributes, allow_none=True),
        default_value=("", None),
    )

    HINT = {"provider": "Select a provider", "child_dbs": "Select a database"}
    INITIAL_CHILD_DBS = [("", (("No provider chosen", None), ))]

    def __init__(
        self,
        child_db_limit: int = None,
        disable_providers: List[str] = None,
        skip_providers: List[str] = None,
        skip_databases: Dict[str, List[str]] = None,
        provider_database_groupings: Dict[str, Dict[str, List[str]]] = None,
        **kwargs,
    ):
        """Set up the provider and child database dropdowns

        Parameters:
            child_db_limit: Page size used when querying a provider's `/links` endpoint.
                Missing or non-positive values fall back to 10.
            disable_providers: Passed on to `get_list_of_valid_providers()`.
            skip_providers: Passed on to `get_list_of_valid_providers()`.
            skip_databases: Mapping of provider name to child database IDs to hide.
            provider_database_groupings: Mapping of provider label (as shown in the
                provider dropdown) to `{group label: [child database IDs]}`.
                Defaults to `PROVIDER_DATABASE_GROUPINGS`.
            **kwargs: Forwarded to `ipw.VBox`. A given `layout` replaces the default.
        """
        self.child_db_limit = (child_db_limit if child_db_limit
                               and child_db_limit > 0 else 10)
        self.skip_child_dbs = skip_databases or {}
        self.child_db_groupings = (provider_database_groupings
                                   if provider_database_groupings is not None
                                   else PROVIDER_DATABASE_GROUPINGS)
        # Pagination state used by `_query()` when no explicit page link is given
        self.offset = 0
        self.number = 1
        # Set to False to ignore the next page-chooser change (see `_get_more_child_dbs()`)
        self.__perform_query = True
        # Curated first page of child DBs per provider base URL
        self.__cached_child_dbs = {}

        self.debug = bool(os.environ.get("OPTIMADE_CLIENT_DEBUG", None))

        providers, invalid_providers = get_list_of_valid_providers(
            disable_providers=disable_providers, skip_providers=skip_providers)
        providers.insert(0, (self.HINT["provider"], {}))
        if self.debug:
            # In debug mode, also offer a locally running server as a provider
            from optimade_client.utils import VERSION_PARTS

            local_provider = LinksResourceAttributes(
                **{
                    "name": "Local server",
                    "description": "Local server, running aiida-optimade",
                    "base_url": f"http://localhost:5000{VERSION_PARTS[0][0]}",
                    "homepage": "https://example.org",
                    "link_type": "external",
                })
            providers.insert(1, ("Local server", local_provider))

        self.providers = DropdownExtended(
            options=providers,
            disabled_options=invalid_providers,
            layout=ipw.Layout(width="auto"),
        )

        # Shared layout: toggling its `display` shows/hides both the child DB
        # dropdown and the page chooser at once.
        self.show_child_dbs = ipw.Layout(width="auto", display="none")
        self.child_dbs = DropdownExtended(grouping=self.INITIAL_CHILD_DBS,
                                          layout=self.show_child_dbs)
        self.page_chooser = ResultsPageChooser(page_limit=self.child_db_limit,
                                               layout=self.show_child_dbs)

        self.providers.observe(self._observe_providers, names="value")
        self.child_dbs.observe(self._observe_child_dbs, names="value")
        self.page_chooser.observe(
            self._get_more_child_dbs,
            names=["page_link", "page_offset", "page_number"])
        self.error_or_status_messages = ipw.HTML("")

        kwargs.setdefault("layout", ipw.Layout(width="auto"))
        super().__init__(
            children=(
                self.providers,
                self.child_dbs,
                self.page_chooser,
                self.error_or_status_messages,
            ),
            **kwargs,
        )

    def freeze(self):
        """Disable widget"""
        self.providers.disabled = True
        self.show_child_dbs.display = "none"
        self.page_chooser.freeze()

    def unfreeze(self):
        """Activate widget (in its current state)"""
        self.providers.disabled = False
        self.show_child_dbs.display = None
        self.page_chooser.unfreeze()

    def reset(self):
        """Reset widget"""
        self.page_chooser.reset()
        self.offset = 0
        self.number = 1

        self.providers.disabled = False
        self.providers.index = 0

        self.show_child_dbs.display = "none"
        self.child_dbs.grouping = self.INITIAL_CHILD_DBS

    def _observe_providers(self, change: dict):
        """Update child database dropdown upon changing provider

        If the provider has at most one child database, the dropdown stays hidden and
        the single child database (if any) is selected automatically.
        """
        value = change["new"]
        self.show_child_dbs.display = "none"
        self.provider = value
        if not value:
            # Hint option ({}) or no provider
            self.child_dbs.grouping = self.INITIAL_CHILD_DBS
            self.providers.index = 0
            self.child_dbs.index = 0
            return

        self._initialize_child_dbs()
        # The first entry of the grouping is always the hint option
        if sum(len(group[1]) for group in self.child_dbs.grouping) <= 2:
            # The provider either has 0 or 1 implementations
            # or we have failed to retrieve any implementations.
            # Automatically choose the 1 implementation (if there),
            # while otherwise keeping the dropdown disabled.
            self.show_child_dbs.display = "none"
            try:
                self.child_dbs.index = 1
                LOGGER.debug("Changed child_dbs index. New child_dbs: %s",
                             self.child_dbs)
            except IndexError:
                pass
        else:
            self.show_child_dbs.display = None

    def _observe_child_dbs(self, change: dict):
        """Update database traitlet with base URL for chosen child database"""
        value = change["new"]
        if not value:
            self.database = "", None
        else:
            self.database = self.child_dbs.label.strip(), self.child_dbs.value

    @staticmethod
    def _remove_current_dropdown_option(dropdown: ipw.Dropdown) -> tuple:
        """Remove the current option from a Dropdown widget and return updated options

        Since Dropdown.options is a tuple there is a need to go through a list.
        """
        list_of_options = list(dropdown.options)
        del list_of_options[dropdown.index]
        return tuple(list_of_options)

    def _curate_child_dbs(
        self,
        child_dbs: List[dict],
        links: dict,
        data_returned: int,
        link: str = None,
    ) -> Tuple[List[Tuple[str, List[Tuple[str, LinksResourceAttributes]]]],
               dict, int, int]:
        """Drop unwanted child DBs, re-querying without them until a page needs no exclusions

        All IDs excluded so far are passed to every re-query. Passing only the latest batch
        would let earlier exclusions reappear and the loop could alternate forever.

        Parameters:
            child_dbs: Raw `/links` entries from a previous `_query()`.
            links: Pagination links returned by that query.
            data_returned: Number of entries to count down as entries are excluded.
            link: Page link used for the previous query (None: offset/number based query).

        Returns:
            Final grouped child DBs, pagination links, the updated `data_returned`, and
            the total number of excluded child DBs.
        """
        skip_dbs = self.skip_child_dbs.get(self.provider.name, [])
        excluded: List[str] = []
        while True:
            exclude_child_dbs, final_child_dbs = self._update_child_dbs(
                data=child_dbs, skip_dbs=skip_dbs)
            LOGGER.debug("Exclude child DBs: %r", exclude_child_dbs)

            new_excludes = [id_ for id_ in exclude_child_dbs if id_ not in excluded]
            data_returned -= len(new_excludes)
            excluded.extend(new_excludes)

            if not new_excludes or not data_returned:
                return final_child_dbs, links, data_returned, len(excluded)

            child_dbs, links, data_returned, _ = self._query(
                link=link, exclude_ids=excluded)

    def _handle_query_error(self, exc: QueryError) -> None:
        """Recover the widget after a `QueryError` for the current provider

        With `exc.remove_target` the provider is removed from the dropdown and the widget is
        reset. Otherwise the child DB dropdown is cleared and the provider dropdown, which
        `freeze()` disabled, is re-enabled so that another provider can be chosen.
        """
        if exc.remove_target:
            LOGGER.debug(
                "Remove target: %r. Will remove target at %r: %r",
                exc.remove_target,
                self.providers.index,
                self.providers.value,
            )
            self.providers.options = self._remove_current_dropdown_option(
                self.providers)
            self.reset()
        else:
            LOGGER.debug(
                "Remove target: %r. Will NOT remove target at %r: %r",
                exc.remove_target,
                self.providers.index,
                self.providers.value,
            )
            self.show_child_dbs.display = "none"
            self.child_dbs.grouping = self.INITIAL_CHILD_DBS
            self.providers.disabled = False

    def _initialize_child_dbs(self):
        """New provider chosen; initialize child DB dropdown

        Uses the per-base-URL cache when available. Otherwise it queries the provider's
        `/links` endpoint, curates the result and caches it. Pagination is then reset
        from the resulting counts and links.
        """
        self.offset = 0
        self.number = 1
        try:
            # Freeze and disable list of structures in dropdown widget
            # We don't want changes leading to weird things happening prior to the query ending
            self.freeze()

            # Reset the error or status message
            if self.error_or_status_messages.value:
                self.error_or_status_messages.value = ""

            if self.provider.base_url in self.__cached_child_dbs:
                cache = self.__cached_child_dbs[self.provider.base_url]

                LOGGER.debug(
                    "Initializing child DBs for %s. Using cached info:\n%r",
                    self.provider.name,
                    cache,
                )

                self._set_child_dbs(cache["child_dbs"])
                data_returned = cache["data_returned"]
                data_available = cache["data_available"]
                links = cache["links"]
            else:
                LOGGER.debug("Initializing child DBs for %s.",
                             self.provider.name)

                # Query database and get child_dbs
                child_dbs, links, data_returned, data_available = self._query()

                final_child_dbs, links, data_returned, number_excluded = (
                    self._curate_child_dbs(child_dbs, links, data_returned))
                data_available -= number_excluded
                self._set_child_dbs(final_child_dbs)

                # Cache initial child_dbs and related information
                self.__cached_child_dbs[self.provider.base_url] = {
                    "child_dbs": final_child_dbs,
                    "data_returned": data_returned,
                    "data_available": data_available,
                    "links": links,
                }

                LOGGER.debug(
                    "Found the following, which has now been cached:\n%r",
                    self.__cached_child_dbs[self.provider.base_url],
                )

            # Update pageing
            self.page_chooser.set_pagination_data(
                data_returned=data_returned,
                data_available=data_available,
                links_to_page=links,
                reset_cache=True,
            )

        except QueryError as exc:
            LOGGER.debug(
                "Trying to initalize child DBs. QueryError caught: %r", exc)
            self._handle_query_error(exc)

        else:
            self.unfreeze()

    def _set_child_dbs(
        self,
        data: List[Tuple[str, List[Tuple[str, LinksResourceAttributes]]]],
    ) -> None:
        """Update the child_dbs options with `data`, prefixed by a hint option"""
        first_choice = (self.HINT["child_dbs"]
                        if data else "No valid implementations found")
        new_data = list(data)
        new_data.insert(0, ("", [(first_choice, None)]))
        self.child_dbs.grouping = new_data

    def _update_child_dbs(
        self,
        data: List[dict],
        skip_dbs: List[str] = None
    ) -> Tuple[List[str], List[Tuple[str, List[Tuple[
            str, LinksResourceAttributes]]]]]:
        """Update child DB dropdown from response data

        Entries are placed in the provider's configured groups (if any), or otherwise
        in the ungrouped "" group.

        Parameters:
            data: Raw `/links` entries.
            skip_dbs: Child DB IDs to exclude on request of the user.

        Returns:
            The IDs to exclude from further queries (user-skipped, missing base URL, or
            no determinable versioned base URL), and the `(group, [(name, attributes)])`
            list for the dropdown. Non-child entries are dropped without being excluded.
        """
        grouped = self.providers.label in self.child_db_groupings
        child_dbs = (deepcopy(self.child_db_groupings[self.providers.label])
                     if grouped else {"": []})
        exclude_dbs = []
        skip_dbs = skip_dbs or []

        for entry in data:
            child_db = update_old_links_resources(entry)
            if child_db is None:
                continue

            # Skip if desired by user
            if child_db.id in skip_dbs:
                exclude_dbs.append(child_db.id)
                continue

            attributes = child_db.attributes

            # Skip if not a 'child' link_type database
            if attributes.link_type != LinkType.CHILD:
                LOGGER.debug(
                    "Skip %s: Links resource not a %r link_type, instead: %r",
                    attributes.name,
                    LinkType.CHILD,
                    attributes.link_type,
                )
                continue

            # Skip if there is no base_url
            if attributes.base_url is None:
                LOGGER.debug(
                    "Skip %s: Base URL found to be None for child DB: %r",
                    attributes.name,
                    child_db,
                )
                exclude_dbs.append(child_db.id)
                continue

            versioned_base_url = get_versioned_base_url(attributes.base_url)
            if not versioned_base_url:
                # Not a valid/supported child DB: skip
                LOGGER.debug(
                    "Skip %s: Could not determine versioned base URL for child DB: %r",
                    attributes.name,
                    child_db,
                )
                exclude_dbs.append(child_db.id)
                continue
            attributes.base_url = versioned_base_url

            if grouped:
                for group, ids in self.child_db_groupings[
                        self.providers.label].items():
                    if child_db.id in ids:
                        # Replace the placeholder ID by the actual entry
                        index = child_dbs[group].index(child_db.id)
                        child_dbs[group][index] = (attributes.name, attributes)
                        break
                else:
                    child_dbs.setdefault("", []).append(
                        (attributes.name, attributes))
            else:
                child_dbs[""].append((attributes.name, attributes))

        if grouped:
            # Drop placeholder IDs for child DBs that were not found
            for group, ids in tuple(child_dbs.items()):
                child_dbs[group] = [_ for _ in ids if isinstance(_, tuple)]
        child_dbs = list(child_dbs.items())

        LOGGER.debug("Final updated child_dbs: %s", child_dbs)

        return exclude_dbs, child_dbs

    def _get_more_child_dbs(self, change):
        """Query for more child DBs according to page_offset

        Observer for the page chooser's "page_link", "page_offset" and "page_number".
        Offset/number changes update the pagination state used by `_query()`. A page link
        is queried directly, and the page chooser's offset is synchronized without
        triggering another query.
        """
        if self.providers.value is None:
            # This may be called if a provider is suddenly removed (bad provider)
            return

        if not self.__perform_query:
            self.__perform_query = True
            LOGGER.debug(
                "Will not perform query with pageing: name=%s value=%s",
                change["name"],
                change["new"],
            )
            return

        pageing: Union[int, str] = change["new"]
        LOGGER.debug(
            "Detected change in page_chooser: name=%s value=%s",
            change["name"],
            pageing,
        )
        if change["name"] == "page_offset":
            LOGGER.debug(
                "Got offset %d to retrieve more child DBs from %r",
                pageing,
                self.providers.value,
            )
            self.offset = pageing
            pageing = None
        elif change["name"] == "page_number":
            LOGGER.debug(
                "Got number %d to retrieve more child DBs from %r",
                pageing,
                self.providers.value,
            )
            self.number = pageing
            pageing = None
        else:
            LOGGER.debug(
                "Got link %r to retrieve more child DBs from %r",
                pageing,
                self.providers.value,
            )
            # It is needed to update page_offset, but we do not wish to query again
            with self.hold_trait_notifications():
                self.__perform_query = False
                self.page_chooser.update_offset()

        try:
            # Freeze and disable both dropdown widgets
            # We don't want changes leading to weird things happening prior to the query ending
            self.freeze()

            # Query index meta-database
            LOGGER.debug("Querying for more child DBs using pageing: %r",
                         pageing)
            child_dbs, links, _, _ = self._query(pageing)

            # Update list of child DBs in dropdown widget. The user's skipped
            # databases apply here just as they do for the first page.
            final_child_dbs, links, data_returned, _ = self._curate_child_dbs(
                child_dbs, links, self.page_chooser.data_returned, link=pageing)
            self._set_child_dbs(final_child_dbs)

            # Update pageing
            self.page_chooser.set_pagination_data(data_returned=data_returned,
                                                  links_to_page=links)

        except QueryError as exc:
            LOGGER.debug(
                "Trying to retrieve more child DBs (new page). QueryError caught: %r",
                exc,
            )
            self._handle_query_error(exc)

        else:
            self.unfreeze()

    def _query(  # pylint: disable=too-many-locals,too-many-branches,too-many-statements
            self,
            link: str = None,
            exclude_ids: List[str] = None) -> Tuple[List[dict], dict, int,
                                                    int]:
        """Query helper function

        Queries `link` directly if given; otherwise queries the provider's `/links` endpoint
        using `child_db_limit`, `offset` and `number`. `exclude_ids` are added to the filter
        and are also removed manually from the response, in case filtering is not supported.

        Returns:
            The child implementations, pagination links, `data_returned`, and
            `data_available`.

        Raises:
            QueryError: On a non-404 error response or an unsupported API version. The
                message is also shown in `error_or_status_messages`.
            OptimadeClientError: If the response is malformed while excluding IDs.
        """
        # If a complete link is provided, use it straight up
        if link is not None:
            try:
                if exclude_ids:
                    filter_value = " AND ".join(
                        [f'NOT id="{id_}"' for id_ in exclude_ids])

                    parsed_url = urllib.parse.urlparse(link)
                    queries = urllib.parse.parse_qs(parsed_url.query)
                    # Since parse_qs wraps all values in a list,
                    # this extracts the values from the list(s).
                    queries = {key: value[0] for key, value in queries.items()}

                    if "filter" in queries:
                        queries[
                            "filter"] = f"( {queries['filter']} ) AND ( {filter_value} )"
                    else:
                        queries["filter"] = filter_value

                    parsed_query = urllib.parse.urlencode(queries)

                    link = (
                        f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}"
                        f"?{parsed_query}")

                link = ordered_query_url(link)
                response = SESSION.get(link, timeout=TIMEOUT_SECONDS)
                if response.from_cache:
                    LOGGER.debug("Request to %s was taken from cache !", link)
                response = response.json()
            except (
                    requests.exceptions.ConnectTimeout,
                    requests.exceptions.ConnectionError,
                    requests.exceptions.ReadTimeout,
            ) as exc:
                response = {
                    "errors": {
                        "msg": "CLIENT: Connection error or timeout.",
                        "url": link,
                        "Exception": repr(exc),
                    }
                }
            except json.JSONDecodeError as exc:
                response = {
                    "errors": {
                        "msg": "CLIENT: Could not decode response to JSON.",
                        "url": link,
                        "Exception": repr(exc),
                    }
                }
        else:
            filter_ = '( link_type="child" OR type="child" )'
            if exclude_ids:
                filter_ += (
                    " AND ( " +
                    " AND ".join([f'NOT id="{id_}"'
                                  for id_ in exclude_ids]) + " )")

            response = perform_optimade_query(
                filter=filter_,
                base_url=self.provider.base_url,
                endpoint="/links",
                page_limit=self.child_db_limit,
                page_offset=self.offset,
                page_number=self.number,
            )
        msg, http_errors = handle_errors(response)
        # If /links is not found (404), move on with whatever data there is
        if msg and 404 not in http_errors:
            self.error_or_status_messages.value = msg
            raise QueryError(msg=msg, remove_target=True)

        # Check implementation API version
        msg = validate_api_version(response.get("meta",
                                                {}).get("api_version", ""),
                                   raise_on_fail=False)
        if msg:
            self.error_or_status_messages.value = (
                f"{msg}<br>The provider has been removed.")
            raise QueryError(msg=msg, remove_target=True)

        LOGGER.debug(
            "Manually remove `exclude_ids` if filters are not supported")
        child_db_data = {
            impl.get("id", "N/A"): impl
            for impl in response.get("data", [])
        }
        if exclude_ids:
            for links_id in exclude_ids:
                child_db_data.pop(links_id, None)
            LOGGER.debug("child_db_data after popping: %r", child_db_data)
            response["data"] = list(child_db_data.values())
            if "meta" in response:
                if "data_available" in response["meta"]:
                    old_data_available = response["meta"].get(
                        "data_available", 0)
                    if len(response["data"]) > old_data_available:
                        LOGGER.debug("raising OptimadeClientError")
                        raise OptimadeClientError(
                            f"Reported data_available ({old_data_available}) is smaller than "
                            f"curated list of responses ({len(response['data'])}).",
                        )
                response["meta"]["data_available"] = len(response["data"])
            else:
                raise OptimadeClientError(
                    "'meta' not found in response. Bad response")

        LOGGER.debug(
            "Attempt for %r (in /links): Found implementations (names+base_url only):\n%s",
            self.provider.name,
            [
                f"(id: {name}; base_url: {base_url}) " for name, base_url in [(
                    impl.get("id", "N/A"),
                    impl.get("attributes", {}).get("base_url", "N/A"),
                ) for impl in response.get("data", [])]
            ],
        )
        # Return all implementations of link_type "child"
        # ("type" == "child" is the pre-v1.0 style)
        implementations = [
            implementation for implementation in response.get("data", [])
            if (implementation.get("attributes", {}).get("link_type", "") ==
                "child" or implementation.get("type", "") == "child")
        ]
        LOGGER.debug(
            "After curating for implementations which are of 'link_type' = 'child' or 'type' == "
            "'child' (old style):\n%s",
            [
                f"(id: {name}; base_url: {base_url}) " for name, base_url in [(
                    impl.get("id", "N/A"),
                    impl.get("attributes", {}).get("base_url", "N/A"),
                ) for impl in implementations]
            ],
        )

        # Get links, data_returned, and data_available
        links = response.get("links", {})
        data_returned = response.get("meta", {}).get("data_returned",
                                                     len(implementations))
        if data_returned > 0 and not implementations:
            # Most probably dealing with pre-v1.0.0-rc.2 implementations
            data_returned = 0
        data_available = response.get("meta", {}).get("data_available", 0)

        return implementations, links, data_returned, data_available
Example #3
0
class MultiToggleButtons_AllOrSome(Box):
    """A row of toggle buttons behaving as a multiple-selection widget.

    One ``ToggleButton`` is created per option. ``value`` is the tuple of
    option values whose buttons are pressed; when no button is pressed, the
    widget treats every option as selected ("all or some" semantics).

    Parameters
    ----------
    short_label_map : dict, optional
        Maps full option labels to shorter button captions. Captions of 15
        characters or more are truncated to their first 12 characters plus an
        ellipsis; the full label is always used as the button tooltip.
    **kwargs
        Forwarded to ``Box``.
    """

    description = traitlets.Unicode()
    value = traitlets.Tuple()
    options = traitlets.Union([traitlets.List(), traitlets.Dict()])
    style = traitlets.Dict()

    def __init__(self, *, short_label_map=None, **kwargs):
        if short_label_map is None:
            short_label_map = {}
        super().__init__(**kwargs)
        # Delegate option/value normalisation (labels vs. values) to the
        # ipywidgets multiple-selection model and keep it in sync with us.
        self._selection_obj = widget_selection._MultipleSelection()
        traitlets.link((self, 'options'), (self._selection_obj, 'options'))
        traitlets.link((self, 'value'), (self._selection_obj, 'value'))

        @observer(self, 'options')
        def _(*_):
            # Rebuild one button per option label whenever options change.
            self.buttons = []
            for label in self._selection_obj._options_labels:
                short_label = short_label_map.get(label, label)
                self.buttons.append(
                    ToggleButton(
                        description=short_label
                        if len(short_label) < 15 else short_label[:12] + "…",
                        tooltip=label,
                        layout=Layout(
                            margin='1',
                            width='auto',
                        ),
                    ))
            # Only reserve horizontal room for the description when there is one.
            default_width = '100px' if self.description else '0px'
            self.label = Label(
                self.description,
                layout=Layout(
                    width=self.style.get('description_width', default_width)))
            self.children = [self.label] + self.buttons

            @observer(self.buttons, 'value')
            def _(*_):
                option_values = self._selection_obj._options_values
                proposed_value = tuple(
                    value for btn, value in zip(self.buttons, option_values)
                    if btn.value)
                # When nothing is selected, treat as if everything is selected.
                if len(proposed_value) == 0:
                    proposed_value = tuple(
                        value for btn, value in zip(self.buttons, option_values))
                self.value = proposed_value

        self.add_class('btn-group')

    def reset(self):
        """Rebuild the buttons by clearing and re-assigning the same options.

        Assigning an empty list first guarantees the ``options`` trait sees a
        change, so the options observer fires even though the final options
        are unchanged.
        """
        opts = self.options
        self.options = []
        self.options = opts

    def set_value(self, x):
        """Press exactly the buttons whose option *value* is contained in ``x``.

        Option values (not raw ``options`` entries) are compared, so this
        works the same way for ``["a", "b"]``, ``[(label, value), ...]`` and
        dict-style options, and matches the values placed into ``self.value``.
        """
        option_values = self._selection_obj._options_values
        for button, option_value in zip(self.buttons, option_values):
            button.value = option_value in x

    def set_all_on(self):
        """Press every button."""
        for button in self.buttons:
            button.value = True

    def set_all_off(self):
        """Release every button (which the widget treats as "all selected")."""
        for button in self.buttons:
            button.value = False
Example #4
0
class ZarrOutputMixin(tl.HasTraits):
    """
    Mixin that writes a node's evaluated output into a (possibly remote) zarr file.

    This class assumes that the source node has an 'output_format' attribute
    (currently the "Lambda" Node, and the "Process" Node); when the source node
    lacks that trait, each evaluated piece is written with
    ``to_format("zarr_part", ...)`` after evaluation instead. The class this is
    mixed into must provide ``source``, ``chunks`` and parent ``eval`` /
    ``eval_source`` implementations (they are reached via ``super()``).

    Attributes
    -----------
    zarr_file: str
        Path to the output zarr file that collects all of the computed results. This can reside on S3.
    dataset: ZarrGroup
        A handle to the zarr group pointing to the output file
    zarr_node: Zarr
        Zarr node for the output file; used to check whether chunks already exist.
    zarr_data_key: str or list
        Key(s) of the output arrays inside the zarr file.
    fill_output: bool, optional
        Default is False (unlike parent class). If True, will collect the output data and return it as an xarray.
    init_file_mode: str, optional
        Default is 'a'. Mode used for initializing the zarr file.
    zarr_chunks: dict
        Size of the chunks in the zarr file for each dimension. When None, ``self.chunks`` is used.
    zarr_shape: dict, optional
        Default is the {coordinates.dims: coordinates.shape}, where coordinates used as part of the eval call. This
        does not need to be specified unless the Node modifies the input coordinates (as part of a Reduce operation,
        for example). The result can be incorrect and requires care/checking by the user.
    zarr_coordinates: podpac.Coordinates, optional
        Default is None. If the node modifies the shape of the input coordinates, this allows users to set the
        coordinates in the output zarr file. This can be incorrect and requires care by the user.
    zarr_dtype: str
        Default is 'f4'. dtype of the output zarr arrays.
    skip_existing: bool
        Default is True. If true, this will check to see if the results already exist. And if so, it will not
        submit a job for that particular coordinate evaluation. This assumes self.chunks == self.zarr_chunks
    list_dir: bool, optional
        Default is False. If skip_existing is True, by default existing files are checked by asking for an 'exists' call.
        If list_dir is True, then at the first opportunity a "list_dir" is performed on the directory and the results
        are cached.
    aws_client_kwargs: dict
        Forwarded to the ``Zarr`` node as ``aws_client_kwargs``.
    aws_config_kwargs: dict
        Not referenced by the methods of this class.
    """

    zarr_file = tl.Unicode().tag(attr=True)
    dataset = tl.Any()
    zarr_node = NodeTrait()
    zarr_data_key = tl.Union([tl.Unicode(), tl.List()])
    fill_output = tl.Bool(False)
    init_file_mode = tl.Unicode("a").tag(attr=True)
    zarr_chunks = tl.Dict(default_value=None, allow_none=True).tag(attr=True)
    zarr_shape = tl.Dict(allow_none=True, default_value=None).tag(attr=True)
    zarr_coordinates = tl.Instance(Coordinates,
                                   allow_none=True,
                                   default_value=None).tag(attr=True)
    zarr_dtype = tl.Unicode("f4")
    skip_existing = tl.Bool(True).tag(attr=True)
    list_dir = tl.Bool(False)
    _list_dir = tl.List(allow_none=True, default_value=[])
    _shape = tl.Tuple()
    _chunks = tl.List()
    aws_client_kwargs = tl.Dict()
    aws_config_kwargs = tl.Dict()

    def eval(self, coordinates, **kwargs):
        """Evaluate over ``coordinates``, writing every result into ``zarr_file``.

        Parameters
        ----------
        coordinates : podpac.Coordinates
            Coordinates to evaluate.
        **kwargs
            Only ``output`` is read; it is forwarded to the parent ``eval``.

        Returns
        -------
        The parent ``eval`` output when it is not None, otherwise the zarr group.
        """
        output = kwargs.get("output")
        if self.zarr_shape is None:
            self._shape = coordinates.shape
        else:
            self._shape = tuple(self.zarr_shape.values())

        # initialize zarr file
        chunk_sizes = self.chunks if self.zarr_chunks is None else self.zarr_chunks
        chunks = [chunk_sizes[d] for d in coordinates]
        self._chunks = chunks
        zf, data_key, zn = self.initialize_zarr_array(self._shape, chunks)
        self.dataset = zf
        self.zarr_data_key = data_key
        self.zarr_node = zn
        # NOTE(review): bare attribute access with no visible effect; kept in case
        # ``Zarr.keys`` is a caching property whose side effect is relied upon.
        zn.keys

        # eval
        _log.debug("Starting parallel eval.")
        missing_dims = [
            d for d in coordinates.dims if d not in self.chunks.keys()
        ]
        if self.zarr_coordinates is not None:
            missing_dims = missing_dims + [
                d for d in self.zarr_coordinates.dims if d not in missing_dims
            ]
            set_coords = merge_dims(
                [coordinates.drop(missing_dims), self.zarr_coordinates])
        else:
            set_coords = coordinates.drop(missing_dims)
        # ``transpose`` returns a new Coordinates object; keep the reordered result
        # so the zarr dimension metadata follows the evaluation dimension order.
        set_coords = set_coords.transpose(*coordinates.dims)

        self.set_zarr_coordinates(set_coords, data_key)
        if self.list_dir:
            dk = data_key
            if isinstance(dk, list):
                dk = dk[0]
            self._list_dir = self.zarr_node.list_dir(dk)

        output = super(ZarrOutputMixin, self).eval(coordinates, output=output)

        # fill in the coordinates, this is guaranteed to be correct even if the user messed up.
        if output is None:
            return zf
        self.set_zarr_coordinates(Coordinates.from_xarray(output.coords),
                                  data_key)
        return output

    def set_zarr_coordinates(self, coordinates, data_key):
        """Write dimension metadata and coordinate arrays into ``self.dataset``.

        Parameters
        ----------
        coordinates : podpac.Coordinates
            Coordinates whose dims/values are stored.
        data_key : str or list of str
            Output array key(s) that receive the ``_ARRAY_DIMENSIONS`` attribute.
        """
        # A bare string key must not be iterated character by character.
        if isinstance(data_key, str):
            data_key = [data_key]
        # Fill in metadata
        for dk in data_key:
            self.dataset[dk].attrs["_ARRAY_DIMENSIONS"] = coordinates.dims
        for d in coordinates.dims:
            # TODO ADD UNITS AND TIME DECODING INFORMATION
            self.dataset.create_dataset(d,
                                        shape=coordinates[d].size,
                                        overwrite=True)
            self.dataset[d][:] = coordinates[d].coordinates

    def initialize_zarr_array(self, shape, chunks):
        """Open the output zarr file and make sure every output array exists.

        Parameters
        ----------
        shape : tuple
            Shape of each output array.
        chunks : list
            Chunk size per dimension.

        Returns
        -------
        (zf, data_key, zn)
            The opened zarr group, the list of output array keys, and a freshly
            created ``Zarr`` node for the file.
        """
        _log.debug("Creating Zarr file.")
        zn = Zarr(source=self.zarr_file,
                  file_mode=self.init_file_mode,
                  aws_client_kwargs=self.aws_client_kwargs)
        if self.source.output or getattr(self.source, "data_key", None):
            data_key = self.source.output
            if data_key is None:
                data_key = self.source.data_key
            if not isinstance(data_key, list):
                data_key = [data_key]
            elif self.source.outputs:  # If someone restricted the outputs for this node, we need to know
                data_key = [dk for dk in data_key if dk in self.source.outputs]
        elif self.source.outputs:
            data_key = self.source.outputs
        else:
            data_key = ["data"]

        zf = zarr.open(zn._get_store(), mode=self.init_file_mode)

        # Initialize the output zarr arrays
        for dk in data_key:
            try:
                zf.create_dataset(
                    dk,
                    shape=shape,
                    chunks=chunks,
                    fill_value=np.nan,
                    dtype=self.zarr_dtype,
                    overwrite=not self.skip_existing,
                )
            except ValueError:
                # Deliberate: the array already exists (only possible when
                # skip_existing is True, i.e. overwrite=False); reuse it.
                pass

        # Recompute any cached properties
        zn = Zarr(source=self.zarr_file,
                  file_mode=self.init_file_mode,
                  aws_client_kwargs=self.aws_client_kwargs)
        return zf, data_key, zn

    def eval_source(self, coordinates, coordinates_index, out, i, source=None):
        """Evaluate one chunk, writing it directly into the zarr file.

        When ``skip_existing`` is True and the chunk is already present, the
        evaluation is skipped and ``(out, coordinates_index)`` is returned.
        Otherwise a copy of ``source`` is evaluated via the parent
        ``eval_source`` and its output is written as a ``zarr_part``.
        """
        if source is None:
            source = self.source

        if self.skip_existing:  # This section allows previously computed chunks to be skipped
            dk = self.zarr_data_key
            if isinstance(dk, list):
                dk = dk[0]
            try:
                exists = self.zarr_node.chunk_exists(coordinates_index,
                                                     data_key=dk,
                                                     list_dir=self._list_dir,
                                                     chunks=self._chunks)
            except ValueError:  # This was needed in cases where a poor internet connection caused read errors
                exists = False
            if exists:
                _log.info("Skipping {} (already exists)".format(i))
                return out, coordinates_index

        # Make a copy to prevent any possibility of memory corruption
        source = Node.from_definition(source.definition)
        _log.debug("Creating output format.")
        # Clip each slice's stop to the array extent along that axis.
        output = dict(
            format="zarr_part",
            format_kwargs=dict(
                part=[[s.start, min(s.stop, self._shape[axis]), s.step]
                      for axis, s in enumerate(coordinates_index)],
                source=self.zarr_file,
                mode="a",
            ),
        )
        _log.debug("Finished creating output format.")

        if source.has_trait("output_format"):
            source.set_trait("output_format", output)
        _log.debug("output: {}, coordinates.shape: {}".format(
            output, coordinates.shape))
        _log.debug("Evaluating node.")

        o, slc = super(ZarrOutputMixin,
                       self).eval_source(coordinates, coordinates_index, out,
                                         i, source)

        # Sources without an output_format trait did not write themselves.
        if not source.has_trait("output_format"):
            o.to_format(output["format"], **output["format_kwargs"])
        return o, slc
Example #5
0
class XELK(ElkTransformer):
    """NetworkX DiGraphs to ELK dictionary structure

    ``source`` is a pair ``(g, tree)``: ``g`` is the graph whose nodes and
    edges are rendered, and ``tree`` is an optional directed hierarchy in which
    a node's predecessors are its parents. Nodes that ``is_hidden`` reports as
    hidden are not rendered; edges touching them are re-routed to the closest
    visible node (per ``map_visible``) and, when ``hoist_hidden_edges`` is True,
    emitted with "slack" CSS classes.
    """

    HIDDEN_ATTR = "hidden"
    hoist_hidden_edges: bool = True

    source = T.Tuple(T.Instance(nx.Graph),
                     T.Instance(nx.DiGraph, allow_none=True))
    layouts = T.Dict()  # keys: networkx nodes {ElkElements: {layout options}}
    css_classes = T.Dict()

    port_scale = T.Int(default_value=5)
    label_key = T.Unicode(default_value="labels")
    port_key = T.Unicode(default_value="ports")

    @T.default("source")
    def _default_source(self):
        return (nx.Graph(), None)

    @T.default("text_sizer")
    def _default_text_sizer(self):
        return ElkTextSizer()

    @T.default("layouts")
    def _default_layouts(self):
        parent_opts = opt.OptionsWidget(
            identifier="parents",
            options=[
                opt.HierarchyHandling(),
            ],
        )
        label_opts = opt.OptionsWidget(
            identifier=ElkLabel,
            options=[opt.NodeLabelPlacement(horizontal="center")])
        node_opts = opt.OptionsWidget(
            identifier=ElkNode,
            options=[
                opt.NodeSizeConstraints(),
            ],
        )

        default = opt.OptionsWidget(
            options=[parent_opts, node_opts, label_opts])
        return {ElkRoot: default.value}

    def node_id(self, node: Hashable) -> str:
        """Get the element id for a node in the main graph for use in elk

        Uses the node's ``id`` attribute when present, else ``str(node)``.

        :param node: Node in main graph
        :type node: Hashable
        :return: Element ID
        :rtype: str
        """
        g, tree = self.source
        if node is ElkRoot:
            return self.ELK_ROOT_ID
        elif node in g:
            return g.nodes.get(node, {}).get("id", f"{node}")
        return f"{node}"

    def port_id(self, node: Hashable, port: Optional[Hashable] = None) -> str:
        """Get a suitable Elk identifier from the node and port

        :param node: Node from the incoming networkx graph
        :type node: Hashable
        :param port: Port identifier, defaults to None
        :type port: Optional[Hashable], optional
        :return: ``"<node id>.<port>"``, or just the node id when no port is
            provided
        :rtype: str
        """
        if port is None:
            return self.node_id(node)
        else:
            return f"{self.node_id(node)}.{port}"

    def edge_id(self, edge: Edge) -> str:
        """Build a default Elk id from an edge's endpoints and ports."""
        # TODO probably will need more sophisticated id generation in future
        return "{}__{}___{}__{}".format(edge.source, edge.source_port,
                                        edge.target, edge.target_port)

    async def transform(self) -> ElkNode:
        """Generate ELK dictionary structure

        :return: Root Elk node
        :rtype: ElkNode
        """
        # TODO decide behavior for nodes that exist in the tree but not g
        g, tree = self.source
        self.clear_registry()
        visible_edges, hidden_edges = self.collect_edges()

        # Process visible networkx nodes into elknodes
        visible_nodes = [
            n for n in g.nodes() if not is_hidden(tree, n, self.HIDDEN_ATTR)
        ]

        # make elknodes then connect their hierarchy
        elknodes: NodeMap = {}
        ports: PortMap = {}
        for node in visible_nodes:
            elknode, node_ports = await self.make_elknode(node)
            ports.update(node_ports)
            elknodes[node] = elknode

        # make top level ElkNode and attach all others as children
        elknodes[ElkRoot] = top = ElkNode(
            id=self.ELK_ROOT_ID,
            children=build_hierarchy(g, tree, elknodes, self.HIDDEN_ATTR),
            layoutOptions=self.get_layout(ElkRoot, "parents"),
        )

        # map of original nodes to the generated elknodes
        for node, elk_node in elknodes.items():
            self.register(elk_node, node)

        elknodes, ports = await self.process_edges(elknodes, ports,
                                                   visible_edges)
        # process edges with one or both original endpoints are hidden
        if self.hoist_hidden_edges:
            elknodes, ports = await self.process_edges(
                elknodes,
                ports,
                hidden_edges,
                edge_style={"slack-edge"},
                port_style={"slack-port"},
            )

        # iterate through the port map and add ports to ElkNodes
        for port in ports.values():
            owner = port.node
            if owner not in elknodes:
                # The owning node is not rendered: skip only this port and keep
                # attaching the remaining ones (a ``break`` here dropped them all).
                # TODO skip generating port to begin with
                continue
            elkport = port.elkport
            elknode = elknodes[owner]
            if elknode.ports is None:
                elknode.ports = []
            layout = self.get_layout(owner, ElkPort)
            elkport.layoutOptions = merge(elkport.layoutOptions, layout)
            elknode.ports += [elkport]

            # map of ports to the generated elkports
            self.register(elkport, port)

        # bulk calculate label sizes
        await size_labels(self.text_sizer, collect_labels([top]))

        return top

    async def make_elknode(self, node) -> Tuple[ElkNode, PortMap]:
        """Build the ElkNode for ``node`` plus the ports declared in its data.

        :param node: Node from the source graph
        :return: The new ElkNode and a map of its declared ports
        :rtype: Tuple[ElkNode, PortMap]
        """
        # merge layout options defined on the node data with default layout
        # options
        node_data = self.get_node_data(node)
        layout = merge(
            node_data.get("layoutOptions", {}),
            self.get_layout(node, ElkNode),
        )
        labels = await self.make_labels(node)

        # update port map with declared ports in the networkx node data
        node_ports = await self.collect_ports(node)

        properties = self.get_properties(node, self.get_css(node,
                                                            ElkNode)) or None

        elk_node = ElkNode(
            id=self.node_id(node),
            labels=labels,
            layoutOptions=layout,
            properties=properties,
            width=node_data.get("width", None),
            height=node_data.get("height", None),
        )
        return elk_node, node_ports

    def get_layout(self, node: Hashable,
                   elk_type: Type[ElkGraphElement]) -> Optional[Dict]:
        """Get the Elk Layout Options appropriate for given networkx node and
        filter by given elk_type

        Nodes without their own entry in ``self.layouts`` fall back to the
        ``ElkRoot`` entry.

        :param node: Networkx node (or ``ElkRoot``)
        :type node: Hashable
        :param elk_type: Elk element type (or other key such as "parents")
            whose options are wanted
        :type elk_type: Type[ElkGraphElement]
        :return: A copy of the matching layout options, or None when empty
        :rtype: Optional[Dict]
        """
        # TODO look at self.source hierarchy and resolve layout with added
        # infomation. until then use root node `None` for layout options
        if node not in self.layouts:
            node = ElkRoot

        type_opts = self.layouts.get(node, {})
        options = {**type_opts.get(elk_type, {})}
        if options:
            return options
        return None

    def get_properties(
        self,
        element: Optional[Union[Hashable, str]],
        dom_classes: Optional[Set[str]] = None,
    ) -> Dict:
        """Get the properties for a graph element

        Properties are gathered from (in increasing precedence order as merged
        below) the node's ``properties`` attribute in the source graph, the
        element's own ``properties`` (``element.data`` or a dict element), and
        a ``cssClasses`` entry built from ``dom_classes``.

        :param element: Networkx node or edge
        :type element: Hashable
        :param dom_classes: Set of base CSS DOM classes to merge, defaults to
            None
        :type dom_classes: Optional[Set[str]]
        :return: Merged properties (empty dict when there are none)
        :rtype: Dict
        """

        g, tree = self.source

        properties = []

        if g and element in g:
            g_props = g.nodes[element].get("properties", {})
            if g_props:
                properties += [g_props]
        if hasattr(element, "data"):
            properties += [element.data.get("properties", {})]
        elif isinstance(element, dict):
            properties += [element.get("properties", {})]

        if dom_classes:
            properties += [{"cssClasses": " ".join(dom_classes)}]

        if not properties:
            return {}
        elif len(properties) == 1:
            return properties[0]

        merged_properties = {}

        for props in properties[::-1]:
            merged_properties = merge(props, merged_properties)

        return merged_properties

    def get_css(
        self,
        node: Hashable,
        elk_type: Type[ElkGraphElement],
        dom_classes: Optional[Set[str]] = None,
    ) -> Set[str]:
        """Get the CSS Classes appropriate for given networkx node given
        elk_type

        :param node: Networkx node
        :type node: Hashable
        :param elk_type: ElkGraphElement to get appropriate css classes
        :type elk_type: Type[ElkGraphElement]
        :param dom_classes: Set of base CSS DOM classes to merge, defaults to
            None
        :type dom_classes: Optional[Set[str]]
        :return: Set of CSS Classes to apply
        :rtype: Set[str]
        """
        typed_css = self.css_classes.get(node, {})
        css_classes = set(typed_css.get(elk_type, []))
        if dom_classes is None:
            return css_classes
        return css_classes | dom_classes

    async def process_edges(
        self,
        nodes: NodeMap,
        ports: PortMap,
        edges: EdgeMap,
        edge_style: Optional[Set[str]] = None,
        port_style: Optional[Set[str]] = None,
    ) -> Tuple[NodeMap, PortMap]:
        """Attach ElkExtendedEdges to their owner nodes, creating missing ports.

        :param nodes: Map of graph nodes to ElkNodes (mutated in place)
        :param ports: Map of port ids to Port records (mutated in place)
        :param edges: Edges grouped by owner node
        :param edge_style: Extra CSS classes for the created edges
        :param port_style: Extra CSS classes for ports created here
        :return: The updated ``(nodes, ports)``
        """
        for owner, edge_list in edges.items():
            edge_css = self.get_css(owner, ElkEdge, edge_style)
            port_css = self.get_css(owner, ElkPort, port_style)
            for edge in edge_list:
                elknode = nodes[owner]
                if elknode.edges is None:
                    elknode.edges = []
                if edge.source_port is not None:
                    port_id = self.port_id(edge.source, edge.source_port)
                    if port_id not in ports:
                        ports[port_id] = await self.make_port(
                            edge.source, edge.source_port, port_css)
                if edge.target_port is not None:
                    port_id = self.port_id(edge.target, edge.target_port)
                    if port_id not in ports:
                        ports[port_id] = await self.make_port(
                            edge.target, edge.target_port, port_css)

                elknode.edges += [await self.make_edge(edge, edge_css)]
        return nodes, ports

    async def make_edge(self,
                        edge: Edge,
                        styles: Optional[Set[str]] = None) -> ElkExtendedEdge:
        """Make the associated Elk edge for the given Edge

        :param edge: Edge object to wrap
        :type edge: Edge
        :param styles: Set of css classes to add to given Elk edge, defaults to None
        :type styles: Optional[Set[str]], optional
        :return: Elk edge
        :rtype: ElkExtendedEdge
        """

        labels = []
        properties = self.get_properties(edge, styles) or None
        label_layout_options = self.get_layout(
            edge.owner, ElkLabel)  # TODO add edgelabel type?
        edge_layout_options = self.get_layout(edge.owner, ElkEdge)

        for i, label in enumerate(edge.data.get(self.label_key, [])):

            if isinstance(label, ElkLabel):
                label = label.to_dict()  # used to create copy of label
            if isinstance(label, dict):
                label = ElkLabel.from_dict(label)
            if isinstance(label, str):
                label = ElkLabel(id=f"{edge.owner}_label_{i}_{label}",
                                 text=label)
            label.layoutOptions = merge(label.layoutOptions,
                                        label_layout_options)
            labels.append(label)
        for label in labels:
            self.register(label, edge)
        elk_edge = ElkExtendedEdge(
            id=edge.data.get("id", self.edge_id(edge)),
            sources=[self.port_id(edge.source, edge.source_port)],
            targets=[self.port_id(edge.target, edge.target_port)],
            properties=properties,
            layoutOptions=merge(edge.data.get("layoutOptions"),
                                edge_layout_options),
            labels=compact(labels),
        )
        self.register(elk_edge, edge)
        return elk_edge

    async def make_port(self,
                        owner: Hashable,
                        port: Hashable,
                        styles: Optional[Set[str]] = None) -> Port:
        """Make the associated elk port for the given owner node and port

        :param owner: Node that owns the port
        :type owner: Hashable
        :param port: Port identifier on the owner node
        :type port: Hashable
        :param styles: Set of css classes to apply to given ElkPort
        :type styles: Optional[Set[str]]
        :return: Port record pairing ``owner`` with the new ElkPort
            (sized ``port_scale`` x ``port_scale``)
        :rtype: Port
        """
        port_id = self.port_id(owner, port)
        properties = self.get_properties(port, styles) or None

        elk_port = ElkPort(
            id=port_id,
            height=self.port_scale,
            width=self.port_scale,
            properties=properties,
            # TODO labels
        )
        return Port(node=owner, elkport=elk_port)

    def get_node_data(self, node: Hashable) -> Dict:
        """Return ``node``'s attribute dict in the source graph ({} if absent)."""
        g, tree = self.source
        return g.nodes.get(node, {})

    async def collect_ports(self, *nodes) -> PortMap:
        """Build Port records for the ports declared in each node's data.

        Declarations under ``self.port_key`` may be ElkPorts (copied), dicts
        (parsed with ``ElkPort.from_dict``) or strings (used as both the port
        name and its label text). Missing widths/heights default to
        ``port_scale``.

        :raises TypeError: if a port declaration has any other type
        :return: Map of port id to Port record
        :rtype: PortMap
        """
        ports: PortMap = {}
        for node in nodes:
            values = self.get_node_data(node).get(self.port_key, [])
            for i, port in enumerate(values):
                if isinstance(port, ElkPort):
                    # generate a fresh copy of the port to prevent mutating original
                    port = port.to_dict()
                elif isinstance(port, str):
                    port_id = self.port_id(node, port)
                    port = {
                        "id": port_id,
                        "labels": [{
                            "text": port,
                            "id": f"{port}_label_{i}",
                        }],
                    }

                # Without this check an unsupported value would reuse the
                # previous iteration's ``elkport`` (or hit an unbound name).
                if not isinstance(port, dict):
                    raise TypeError(
                        f"Unsupported port declaration for node {node!r}: "
                        f"{port!r}")
                elkport = ElkPort.from_dict(port)

                if elkport.width is None:
                    elkport.width = self.port_scale
                if elkport.height is None:
                    elkport.height = self.port_scale
                ports[elkport.id] = Port(node=node, elkport=elkport)
        return ports

    async def make_labels(self, node: Hashable) -> Optional[List[ElkLabel]]:
        """Build the ElkLabels for ``node`` and register them.

        Labels come from the node's ``self.label_key`` attribute, defaulting
        to its ``_id`` attribute or ``str(node)``. Each label receives the
        node's ElkLabel layout options and ElkLabel CSS classes.

        :param node: Node from the source graph
        :return: List of labels
        """
        labels = []
        g = self.source[0]
        data = g.nodes.get(node, {})
        values = data.get(self.label_key, [data.get("_id", f"{node}")])

        properties = {}
        css_classes = self.get_css(node, ElkLabel)
        if css_classes:
            properties["cssClasses"] = " ".join(css_classes)

        if isinstance(values, (str, ElkLabel)):
            values = [values]
        # get node labels
        for i, label in enumerate(values):
            if isinstance(label, str):
                label = ElkLabel(
                    id=f"{label}_label_{i}_{node}",
                    text=label,
                )
            elif isinstance(label, ElkLabel):
                # prevent mutating original label in the node data
                label = label.to_dict()
            if isinstance(label, dict):
                label = ElkLabel.from_dict(label)

            # add css classes and layout options
            label.layoutOptions = merge(label.layoutOptions,
                                        self.get_layout(node, ElkLabel))
            merged_props = merge(label.properties, properties)
            if merged_props is not None:
                merged_props = ElkProperties.from_dict(merged_props)
            label.properties = merged_props

            labels.append(label)
            self.register(label, node)
        return labels

    def collect_edges(self) -> Tuple[EdgeMap, EdgeMap]:
        """Split the source graph's edges into visible and hidden edges.

        Edges whose endpoints are both visible go to the first map. Edges with
        a hidden endpoint are re-routed to the closest visible node(s), with
        the original node folded into a ``(node, port)`` slack port, and go to
        the second map. Re-routed edges whose endpoints collapse onto the same
        visible node are dropped.

        :return: ``(visible, hidden)``, each mapping an owner node (the lowest
            common ancestor of the endpoints in the hierarchy, or ``ElkRoot``
            without one) to its list of Edges
        :rtype: Tuple[EdgeMap, EdgeMap]
        """
        g, tree = self.source
        attr = self.HIDDEN_ATTR

        # both maps index edges by nx.lowest_common_ancestor
        visible: EdgeMap = defaultdict(list)
        hidden: EdgeMap = defaultdict(list)

        closest_visible = map_visible(g, tree, attr)

        @lru_cache()
        def closest_common_visible(nodes: Tuple[Hashable]) -> Hashable:
            if tree is None:
                return ElkRoot
            return lowest_common_ancestor(tree, nodes)

        for source, target, edge_data in g.edges(data=True):
            source_port, target_port = get_ports(edge_data)
            vis_source = closest_visible[source]
            vis_target = closest_visible[target]
            shidden = vis_source != source
            thidden = vis_target != target
            owner = closest_common_visible((vis_source, vis_target))
            if source == target and source == owner:
                if tree is not None and owner in tree:
                    for p in tree.predecessors(owner):
                        # need to make this edge's owner to it's parent
                        owner = p

            if shidden or thidden:
                # create new slack ports if source or target is remapped
                if shidden:
                    source_port = (source, source_port)
                if thidden:
                    target_port = (target, target_port)

                if vis_source != vis_target:
                    hidden[owner].append(
                        Edge(
                            source=vis_source,
                            source_port=source_port,
                            target=vis_target,
                            target_port=target_port,
                            data=edge_data,
                            owner=owner,
                        ))
            else:
                visible[owner].append(
                    Edge(
                        source=source,
                        source_port=source_port,
                        target=target,
                        target_port=target_port,
                        data=edge_data,
                        owner=owner,
                    ))
        return visible, hidden
Example #6
0
class Widget(DOMWidget):
    """Interactive 3D view of a REBOUND simulation for Jupyter notebooks.

    Traits tagged ``sync=True`` are synchronized with the JavaScript view
    ``ReboundView`` provided by the ``rebound`` front-end module.
    """
    _view_name = traitlets.Unicode('ReboundView').tag(sync=True)
    _view_module = traitlets.Unicode('rebound').tag(sync=True)
    # Incremented on construction and on every refresh(); presumably the
    # front end watches it to know when to redraw.
    count = traitlets.Int(0).tag(sync=True)
    t = traitlets.Float().tag(sync=True)
    N = traitlets.Int().tag(sync=True)
    width = traitlets.Float().tag(sync=True)
    height = traitlets.Float().tag(sync=True)
    scale = traitlets.Float().tag(sync=True)
    # Raw byte copies of the C-side display buffers, filled in refresh().
    particle_data = traitlets.Bytes().tag(sync=True)
    orbit_data = traitlets.Bytes().tag(sync=True)
    orientation = traitlets.Tuple().tag(sync=True)
    orbits = traitlets.Int().tag(sync=True)

    def __init__(self,
                 simulation,
                 size=(200, 200),
                 orientation=(0., 0., 0., 1.),
                 scale=None,
                 autorefresh=True,
                 orbits=True):
        """
        Initializes a Widget.

        Widgets provide real-time 3D interactive visualizations for REBOUND simulations
        within Jupyter Notebooks. To use widgets, the ipywidgets package needs to be installed
        and enabled in your Jupyter notebook server.

        Parameters
        ----------
        size : (int, int), optional
            Specify the size of the widget in pixels. The default is 200 times 200 pixels.
        orientation : (float, float, float, float), optional
            Specify the initial orientation of the view. The four floats correspond to the
            x, y, z, and w components of a quaternion. The quaternion will be normalized.
        scale : float, optional
            Set the initial scale of the view. If not set, the widget will determine the
            scale automatically based on current particle positions.
        autorefresh : bool, optional
            The default value is True. The view is updated whenever a particle is added,
            removed and every 100th of a second while a simulation is running. If set
            to False, then the user needs to manually call the refresh() function on the
            widget. This might be useful if performance is an issue.
        orbits : bool, optional
            The default value for this is True and the widget will draw the instantaneous
            orbits of the particles. For simulations in which particles are not on
            Keplerian orbits, the orbits shown will not be accurate.
        """
        self.width, self.height = size
        self.t, self.N = simulation.t, simulation.N
        self.orientation = orientation
        self.autorefresh = autorefresh
        self.orbits = orbits
        # Keep a ctypes pointer so refresh() can hand the simulation to the C library.
        self.simp = pointer(simulation)
        # reb_display_copy_data's return value is used as a "size changed" flag.
        clibrebound.reb_display_copy_data.restype = c_int
        if scale is None:
            self.scale = simulation.display_data.contents.scale
        else:
            self.scale = scale
        self.count += 1

        super(Widget, self).__init__()

    def refresh(self, simp=None, isauto=0):
        """
        Manually refreshes a widget.

        Note that this function can also be called using the wrapper function of
        the Simulation object: sim.refreshWidgets().

        Parameters
        ----------
        simp : ctypes pointer to a Simulation, optional
            Simulation to display. Defaults to the simulation the widget was
            created with.
        isauto : int, optional
            1 when the call comes from an automatic refresh. Such calls are
            ignored when the widget was created with ``autorefresh=False``.
        """
        if simp is None:
            simp = self.simp
        if self.autorefresh == 0 and isauto == 1:
            return
        sim = simp.contents
        size_changed = clibrebound.reb_display_copy_data(simp)
        clibrebound.reb_display_prepare_data(simp, c_int(self.orbits))
        if sim.N > 0:
            # 4 * 7 bytes per particle -- presumably seven 32-bit floats each;
            # confirm against reb_display_prepare_data.
            self.particle_data = (c_char * (4 * 7 * sim.N)).from_address(
                sim.display_data.contents.particle_data).raw
            if self.orbits:
                # One orbit record (4 * 9 bytes) per particle except the first.
                self.orbit_data = (
                    c_char * (4 * 9 * (sim.N - 1))).from_address(
                        sim.display_data.contents.orbit_data).raw
        if size_changed:
            #TODO: Implement better GPU size change
            pass
        self.N = sim.N
        self.t = sim.t
        self.count += 1

    @staticmethod
    def getClientCode():
        """Return ``shader_code + js_code``, the client-side source of the view."""
        return shader_code + js_code
Example #7
0
class RobustScaler(Transformer):
    ''' The RobustScaler removes the median and scales the data according to a
    given percentile range. By default, the scaling is done between the 25th and
    the 75th percentile. Centering and scaling happens independently for each
    feature (column).

    Example:

    >>> import vaex
    >>> df = vaex.from_arrays(x=[2,5,7,2,15], y=[-2,3,0,0,10])
    >>> df
     #    x    y
     0    2   -2
     1    5    3
     2    7    0
     3    2    0
     4   15   10
    >>> scaler = vaex.ml.RobustScaler(features=['x', 'y'])
    >>> scaler.fit_transform(df)
     #    x    y    robust_scaled_x    robust_scaled_y
     0    2   -2       -0.333686             -0.266302
     1    5    3       -0.000596934           0.399453
     2    7    0        0.221462              0
     3    2    0       -0.333686              0
     4   15   10        1.1097                1.33151
    '''
    with_centering = traitlets.CBool(
        default_value=True,
        help='If True, remove the median.').tag(ui='Checkbox')
    with_scaling = traitlets.CBool(
        default_value=True,
        help=
        'If True, scale each feature between the specified percentile range.'
    ).tag(ui='Checkbox')
    percentile_range = traitlets.Tuple(
        default_value=(25, 75),
        help='The percentile range to which to scale each feature to.').tag(
            ui='FloatRangeSlider')
    prefix = traitlets.Unicode(default_value="robust_scaled_",
                               help=help_prefix).tag(ui='Text')
    center_ = traitlets.List(
        traitlets.CFloat(),
        default_value=None,
        help='The median of each feature.').tag(output=True)
    scale_ = traitlets.List(
        traitlets.CFloat(),
        default_value=None,
        help='The percentile range for each feature.').tag(output=True)

    def fit(self, df):
        '''
        Fit RobustScaler to the DataFrame.

        Computes the (approximate) median of each feature into ``center_`` when
        ``with_centering`` is set, and the difference between the upper and
        lower percentiles of ``percentile_range`` into ``scale_`` when
        ``with_scaling`` is set.

        :param df: A vaex DataFrame.
        :raises ValueError: if ``percentile_range`` is not ordered within [0, 100].
        '''

        # check the quantile range
        q_min, q_max = self.percentile_range
        if not 0 <= q_min <= q_max <= 100:
            raise ValueError('Invalid percentile range: %s' %
                             (str(self.percentile_range)))

        if self.with_centering:
            self.center_ = df.percentile_approx(expression=self.features,
                                                percentage=50).tolist()

        if self.with_scaling:
            upper = df.percentile_approx(expression=self.features,
                                         percentage=q_max)
            lower = df.percentile_approx(expression=self.features,
                                         percentage=q_min)
            self.scale_ = (upper - lower).tolist()

    def transform(self, df):
        '''
        Transform a DataFrame with a fitted RobustScaler.

        For every feature a new column ``prefix + feature`` is added, equal to
        ``(x - center_) / scale_`` with each step applied only if enabled.

        :param df: A vaex DataFrame.

        :returns copy: a shallow copy of the DataFrame that includes the scaled features.
        :rtype: DataFrame
        '''

        copy = df.copy()
        for i, feature in enumerate(self.features):
            expr = copy[feature]
            if self.with_centering:
                expr = expr - self.center_[i]
            if self.with_scaling:
                expr = expr / self.scale_[i]
            copy[self.prefix + feature] = expr
        return copy
Example #8
0
class MolViz2DBaseWidget(MessageWidget):
    """
    This is actually the D3.js graphics driver for the 2D base widget.
    It should be refactored with an abstract base class if
    there's a chance of adding another graphics driver.

    Subclasses must implement ``to_graph`` and ``get_atom_index``.
    """
    _view_name = Unicode('MolWidget2DView').tag(sync=True)
    _model_name = Unicode('MolWidget2DModel').tag(sync=True)
    _view_module = Unicode('nbmolviz-js').tag(sync=True)
    _model_module = Unicode('nbmolviz-js').tag(sync=True)

    charge = traitlets.Float().tag(sync=True)
    uuid = traitlets.Unicode().tag(sync=True)
    graph = traitlets.Dict().tag(sync=True)
    clicked_atom_index = traitlets.Int(-1).tag(sync=True)
    clicked_bond_indices = traitlets.Tuple((-1, -1)).tag(sync=True)
    _atom_colors = traitlets.Dict({}).tag(sync=True)
    width = traitlets.Float().tag(sync=True)
    height = traitlets.Float().tag(sync=True)

    def __init__(self, atoms,
                 charge=-150,
                 width=400, height=350,
                 **kwargs):
        """
        :param atoms: a collection of atoms, or an object exposing an ``atoms``
            attribute (that object is then also stored as ``self.entity``)
        :param charge: value synced to the front end through the ``charge`` trait
        :param width: widget width, synced to the front end
        :param height: widget height, synced to the front end
        """
        super(MolViz2DBaseWidget, self).__init__(width=width,
                                                 height=height,
                                                 **kwargs)
        try:
            self.atoms = atoms.atoms
        except AttributeError:
            # A bare collection of atoms was passed.
            self.atoms = atoms
        else:
            # A container of atoms was passed; keep a reference to it.
            self.entity = atoms
        self.width = width
        self.height = height
        self.uuid = 'mol2d'+str(uuid.uuid4())
        self.charge = charge
        self._clicks_enabled = False
        self.graph = self.to_graph(self.atoms)

    def to_graph(self, atoms):
        """Turn a set of atoms into a graph
        Should return a dict of the form
        {nodes:[a1,a2,a3...],
        links:[b1,b2,b3...]}
        where ai = {atom:[atom name],color='black',size=1,index:i}
        and bi = {bond:[order],source:[i1],dest:[i2],
                color/category='black',distance=22.0,strength=1.0}
        You can assign an explicit color with "color" OR
        get automatically assigned unique colors using "category"
        """
        raise NotImplementedError("This method must be implemented by the interface class")

    def set_atom_style(self, atoms=None, fill_color=None, outline_color=None):
        """Set the fill and/or outline color of atoms.

        :param atoms: atoms to restyle; all atoms when None
        :param fill_color: fill color, converted with ``translate_color``
        :param outline_color: outline color, converted with ``translate_color``
        """
        # Materialize as lists: the indices are sent to the front end, and lazy
        # range/map objects are not serializable on Python 3.
        if atoms is None:
            indices = list(range(len(self.atoms)))
        else:
            indices = [self.get_atom_index(atom) for atom in atoms]
        spec = {}
        if fill_color is not None:
            spec['fill'] = translate_color(fill_color, prefix='#')
        if outline_color is not None:
            spec['stroke'] = translate_color(outline_color, prefix='#')
        self.viewer('setAtomStyle', [indices, spec])

    def set_bond_style(self, bonds, color=None, width=None, dash_length=None, opacity=None):
        """
        :param bonds: List of atom pairs, one pair per bond
        :param color: stroke color
        :param width: stroke width in px
        :param dash_length: dash length in px
        :param opacity: stroke opacity
        :raises ValueError: if no style attribute is given
        """
        atom_pairs = [[self.get_atom_index(atom) for atom in pair] for pair in bonds]
        spec = {}
        if width is not None:
            spec['stroke-width'] = str(width)+'px'
        if color is not None:
            spec['stroke'] = color
        if dash_length is not None:
            spec['stroke-dasharray'] = str(dash_length)+'px'
        if opacity is not None:
            spec['opacity'] = opacity
        if not spec:
            raise ValueError('No bond style specified!')
        self.viewer('setBondStyle', [atom_pairs, spec])

    def set_atom_label(self, atom, text=None, text_color=None, size=None, font=None):
        """Set the label text and style for a single atom."""
        atomidx = self.get_atom_index(atom)
        self._change_label('setAtomLabel', atomidx, text, text_color, size, font)

    def set_bond_label(self, bond, text=None, text_color=None, size=None, font=None):
        """Set the label text and style for a bond given as a pair of atoms."""
        bondids = [self.get_atom_index(atom) for atom in bond]
        self._change_label('setBondLabel', bondids, text, text_color, size, font)

    def _change_label(self, driver_function, obj_index, text,
                      text_color, size, font):
        """Build a label style spec and forward it to ``driver_function`` in the viewer.

        :param size: a number (interpreted as points) or a CSS size string
        """
        spec = {}
        if size is not None:
            # Bare numbers become point sizes; strings are passed through as-is.
            if not isinstance(size, str):
                size = str(size)+'pt'
            spec['font-size'] = size
        if text_color is not None:
            spec['fill'] = text_color  # this strangely doesn't always work if you send it a name
        if font is not None:
            spec['font'] = font
        self.viewer(driver_function, [obj_index, text, spec])

    def highlight_atoms(self, atoms):
        """Highlight the given atoms in the view."""
        indices = [self.get_atom_index(atom) for atom in atoms]
        self.viewer('updateHighlightAtoms', [indices])

    def get_atom_index(self, atom):
        """Return the index of *atom* in the drawn graph (subclass responsibility)."""
        raise NotImplementedError("This method must be implemented by the interface class")

    def set_click_callback(self, callback=None, enabled=True):
        """
        :param callback: Callback can have signature (), (trait_name), (trait_name,old), or (trait_name,old,new)
        :type callback: callable
        :param enabled: when False this call currently does nothing
        :return:
        """
        if not enabled:
            return  # TODO: FIX THIS
        assert callable(callback)
        self._clicks_enabled = True
        self.on_trait_change(callback, 'clicked_atom_index')
        self.click_callback = callback

    def set_color(self, color, atoms=None, render=None):
        """Set the fill color of *atoms* (all atoms when None); *render* is unused."""
        self.set_atom_style(fill_color=color, atoms=atoms)

    def set_colors(self, colormap, render=True):
        """
        Args:
         colormap(Mapping[str,List[Atoms]]): mapping of colors to atoms
        """
        # items() works on both Python 2 and 3 (iteritems() is Python-2-only).
        for color, atoms in colormap.items():
            self.set_color(atoms=atoms, color=color)
Example #9
0
class MinMaxScaler(Transformer):
    '''Will scale a set of features to a given range.

    Example:

    >>> import vaex
    >>> df = vaex.from_arrays(x=[2,5,7,2,15], y=[-2,3,0,0,10])
    >>> df
     #    x    y
     0    2   -2
     1    5    3
     2    7    0
     3    2    0
     4   15   10
    >>> scaler = vaex.ml.MinMaxScaler(features=['x', 'y'])
    >>> scaler.fit_transform(df)
     #    x    y    minmax_scaled_x    minmax_scaled_y
     0    2   -2           0                  0
     1    5    3           0.230769           0.416667
     2    7    0           0.384615           0.166667
     3    2    0           0                  0.166667
     4   15   10           1                  1
    '''
    # title = Unicode(default_value='MinMax Scaler', read_only=True).tag(ui='HTML')
    feature_range = traitlets.Tuple(
        default_value=(0, 1),
        help='The range the features are scaled to.').tag(
            ui='FloatRangeSlider')
    prefix = traitlets.Unicode(default_value="minmax_scaled_",
                               help=help_prefix).tag(ui='Text')
    fmax_ = traitlets.List(
        traitlets.CFloat(),
        help='The maximum value of a feature.').tag(output=True)
    fmin_ = traitlets.List(
        traitlets.CFloat(),
        help='The minimum value of a feature.').tag(output=True)

    def fit(self, df):
        '''
        Fit MinMaxScaler to the DataFrame.

        Stores the minimum and maximum of every feature in ``fmin_`` and ``fmax_``.

        :param df: A vaex DataFrame.
        '''

        assert len(
            self.feature_range) == 2, 'feature_range must have 2 elements only'
        # One (min, max) row per feature.
        minmax = df.minmax(self.features)
        self.fmin_ = minmax[:, 0].tolist()
        self.fmax_ = minmax[:, 1].tolist()

    def transform(self, df):
        '''
        Transform a DataFrame with a fitted MinMaxScaler.

        Each feature x is mapped to ``(b - a) * (x - fmin) / (fmax - fmin) + a``
        where ``(a, b)`` is ``feature_range``.

        :param df: A vaex DataFrame.

        :return copy: a shallow copy of the DataFrame that includes the scaled features.
        :rtype: DataFrame
        '''

        copy = df.copy()
        a = self.feature_range[0]
        b = self.feature_range[1]

        for i, feature in enumerate(self.features):
            fmin = self.fmin_[i]
            fmax = self.fmax_[i]
            expr = copy[feature]
            copy[self.prefix + feature] = (b - a) * (expr - fmin) / (fmax - fmin) + a
        return copy
Example #10
0
class MinMaxScaler(Transformer):
    '''Will scale a set of features to a given range.

    Example:

    >>> import vaex
    >>> df = vaex.from_arrays(x=[2,5,7,2,15], y=[-2,3,0,0,10])
    >>> df
     #    x    y
     0    2   -2
     1    5    3
     2    7    0
     3    2    0
     4   15   10
    >>> scaler = vaex.ml.MinMaxScaler(features=['x', 'y'])
    >>> scaler.fit_transform(df)
     #    x    y    minmax_scaled_x    minmax_scaled_y
     0    2   -2           0                  0
     1    5    3           0.230769           0.416667
     2    7    0           0.384615           0.166667
     3    2    0           0                  0.166667
     4   15   10           1                  1
    '''
    snake_name = 'minmax_scaler'
    # title = Unicode(default_value='MinMax Scaler', read_only=True).tag(ui='HTML')
    feature_range = traitlets.Tuple(
        default_value=(0, 1),
        help='The range the features are scaled to.').tag(
            ui='FloatRangeSlider')
    prefix = traitlets.Unicode(default_value="minmax_scaled_",
                               help=help_prefix).tag(ui='Text')
    fmax_ = traitlets.List(
        traitlets.CFloat(),
        help='The maximum value of a feature.').tag(output=True)
    fmin_ = traitlets.List(
        traitlets.CFloat(),
        help='The minimum value of a feature.').tag(output=True)

    def fit(self, df):
        '''
        Fit MinMaxScaler to the DataFrame.

        Schedules one delayed min/max computation per feature, then runs them in
        a single ``df.execute()`` pass, storing the results in ``fmin_``/``fmax_``.

        :param df: A vaex DataFrame.
        '''

        minmax = [df.minmax(feat, delay=True) for feat in self.features]

        @vaex.delayed
        def assign(minmax):
            self.fmin_ = [elem[0] for elem in minmax]
            self.fmax_ = [elem[1] for elem in minmax]

        assign(minmax)
        df.execute()

    def transform(self, df):
        '''
        Transform a DataFrame with a fitted MinMaxScaler.

        Each feature x is mapped to ``(b - a) * (x - fmin) / (fmax - fmin) + a``
        where ``(a, b)`` is ``feature_range``.

        :param df: A vaex DataFrame.

        :return copy: a shallow copy of the DataFrame that includes the scaled features.
        :rtype: DataFrame
        '''

        copy = df.copy()
        a = self.feature_range[0]
        b = self.feature_range[1]

        for i, feature in enumerate(self.features):
            fmin = self.fmin_[i]
            fmax = self.fmax_[i]
            expr = copy[feature]
            copy[self.prefix + feature] = (b - a) * (expr - fmin) / (fmax - fmin) + a
        return copy
Example #11
0
class Widget(DOMWidget):
    """Interactive 3D view of a REBOUND simulation for Jupyter notebooks.

    Traits tagged ``sync=True`` are synchronized with the JavaScript view
    ``ReboundView`` provided by the ``rebound`` front-end module. In addition
    to live display, the widget can save screenshots (see takeScreenshot).
    """
    _view_name = traitlets.Unicode('ReboundView').tag(sync=True)
    _view_module = traitlets.Unicode('rebound').tag(sync=True)
    # Incremented on construction and on every refresh(); presumably the
    # front end watches it to know when to redraw.
    count = traitlets.Int(0).tag(sync=True)
    # Incrementing this triggers the next screenshot (see takeScreenshot).
    screenshotcount = traitlets.Int(0).tag(sync=True)
    t = traitlets.Float().tag(sync=True)
    N = traitlets.Int().tag(sync=True)
    width = traitlets.Float().tag(sync=True)
    height = traitlets.Float().tag(sync=True)
    scale = traitlets.Float().tag(sync=True)
    # Raw byte copies of the C-side display buffers, filled in refresh().
    particle_data = traitlets.CBytes(allow_none=True).tag(sync=True)
    orbit_data = traitlets.CBytes(allow_none=True).tag(sync=True)
    orientation = traitlets.Tuple().tag(sync=True)
    orbits = traitlets.Int().tag(sync=True)
    # Observed by ``savescreenshot`` while a screenshot series is running.
    screenshot = traitlets.Unicode().tag(sync=True)

    def __init__(self,
                 simulation,
                 size=(200, 200),
                 orientation=(0., 0., 0., 1.),
                 scale=None,
                 autorefresh=True,
                 orbits=True):
        """
        Initializes a Widget.

        Widgets provide real-time 3D interactive visualizations for REBOUND simulations
        within Jupyter Notebooks. To use widgets, the ipywidgets package needs to be installed
        and enabled in your Jupyter notebook server.

        Parameters
        ----------
        size : (int, int), optional
            Specify the size of the widget in pixels. The default is 200 times 200 pixels.
        orientation : (float, float, float, float), optional
            Specify the initial orientation of the view. The four floats correspond to the
            x, y, z, and w components of a quaternion. The quaternion will be normalized.
        scale : float, optional
            Set the initial scale of the view. If not set, the widget will determine the
            scale automatically based on current particle positions.
        autorefresh : bool, optional
            The default value is True. The view is updated whenever a particle is added,
            removed and every 100th of a second while a simulation is running. If set
            to False, then the user needs to manually call the refresh() function on the
            widget. This might be useful if performance is an issue.
        orbits : bool, optional
            The default value for this is True and the widget will draw the instantaneous
            orbits of the particles. For simulations in which particles are not on
            Keplerian orbits, the orbits shown will not be accurate.
        """
        self.screenshotcountall = 0
        self.width, self.height = size
        self.t, self.N = simulation.t, simulation.N
        self.orientation = orientation
        self.autorefresh = autorefresh
        self.orbits = orbits
        # Keep a ctypes pointer so refresh() can hand the simulation to the C library.
        self.simp = pointer(simulation)
        # reb_display_copy_data's return value is used as a "size changed" flag.
        clibrebound.reb_display_copy_data.restype = c_int
        if scale is None:
            self.scale = simulation.display_data.contents.scale
        else:
            self.scale = scale
        self.count += 1

        super(Widget, self).__init__()

    def refresh(self, simp=None, isauto=0):
        """
        Manually refreshes a widget.

        Note that this function can also be called using the wrapper function of
        the Simulation object: sim.refreshWidgets().

        Parameters
        ----------
        simp : ctypes pointer to a Simulation, optional
            Simulation to display. Defaults to the simulation the widget was
            created with.
        isauto : int, optional
            1 when the call comes from an automatic refresh. Such calls are
            ignored when the widget was created with ``autorefresh=False``.
        """
        if simp is None:
            simp = self.simp
        if self.autorefresh == 0 and isauto == 1:
            return
        sim = simp.contents
        size_changed = clibrebound.reb_display_copy_data(simp)
        clibrebound.reb_display_prepare_data(simp, c_int(self.orbits))
        if sim.N > 0:
            # 4 * 7 bytes per particle -- presumably seven 32-bit floats each;
            # confirm against reb_display_prepare_data.
            self.particle_data = (c_char * (4 * 7 * sim.N)).from_address(
                sim.display_data.contents.particle_data).raw
            if self.orbits:
                # One orbit record (4 * 9 bytes) per particle except the first.
                self.orbit_data = (
                    c_char * (4 * 9 * (sim.N - 1))).from_address(
                        sim.display_data.contents.orbit_data).raw
        if size_changed:
            #TODO: Implement better GPU size change
            pass
        self.N = sim.N
        self.t = sim.t
        self.count += 1

    def takeScreenshot(self,
                       times=None,
                       prefix="./screenshot",
                       resetCounter=False,
                       archive=None,
                       mode="snapshot"):
        """
        Take one or more screenshots of the widget and save the images to a file.
        The images can be used to create a video.

        This function cannot be called multiple times within one cell.

        Note: this is a new feature and might not work on all systems.
        It was tested on python 2.7.10 and 3.5.2 on MacOSX.

        Parameters
        ----------
        times : (float, list), optional
            If this argument is not given a screenshot of the widget will be made
            as it is (without integrating the simulation). If a float is given, then the
            simulation will be integrated to that time and then a screenshot will
            be taken. If a list of floats is given, the simulation will be integrated
            to each time specified in the array. A separate screenshot for
            each time will be saved.
        prefix : (str), optional
            This string will be part of the output filename for each image.
            Followed by a five digit integer and the suffix .png. By default the
            prefix is './screenshot' which outputs images in the current
            directory with the filenames screenshot00000.png, screenshot00001.png...
            Note that the prefix can include a directory.
        resetCounter : (bool), optional
            Resets the output counter to 0.
        archive : (rebound.SimulationArchive), optional
            Use a REBOUND SimulationArchive. Thus, instead of integrating the
            Simulation from the current time, it will use the SimulationArchive
            to load a snapshot. See examples for usage.
        mode : (string), optional
            Mode to use when querying the SimulationArchive. See SimulationArchive
            documentation for details. By default the value is "snapshot".

        Raises
        ------
        ValueError
            In archive mode, if ``times`` is missing or is not a sequence.

        Examples
        --------

        First, create a simulation and widget. All of the following can go in
        one cell.

        >>> sim = rebound.Simulation()
        >>> sim.add(m=1.)
        >>> sim.add(m=1.e-3,x=1.,vy=1.)
        >>> w = sim.getWidget()
        >>> w

        The widget should show up. To take a screenshot, simply call

        >>> w.takeScreenshot()

        A new file with the name screenshot00000.png will appear in the
        current directory.

        Note that the takeScreenshot command needs to be in a separate cell,
        i.e. after you see the widget.

        You can pass an array of times to the function. This allows you to
        take multiple screenshots, for example to create a movie,

        >>> times = [0,10,100]
        >>> w.takeScreenshot(times)

        """
        self.archive = archive
        if resetCounter:
            self.screenshotcountall = 0
        self.screenshotprefix = prefix
        self.screenshotcount = 0
        self.screenshot = ""
        if archive is None:
            if times is None:
                times = self.simp.contents.t
            try:
                # A sequence of times was given.
                len(times)
            except TypeError:
                # A single time was given.
                times = [times]
            self.times = times
            self.observe(savescreenshot, names="screenshot")
            self.simp.contents.integrate(times[0])
            self.screenshotcount += 1  # triggers first screenshot
        else:
            if times is None:
                raise ValueError("Need times argument for archive mode.")
            try:
                len(times)
            except TypeError:
                raise ValueError("Need a list of times for archive mode.") from None
            self.times = times
            self.mode = mode
            self.observe(savescreenshot, names="screenshot")
            sim = archive.getSimulation(times[0], mode=mode)
            self.refresh(pointer(sim))
            self.screenshotcount += 1  # triggers first screenshot

    @staticmethod
    def getClientCode():
        """Return ``shader_code + js_code``, the client-side source of the view."""
        return shader_code + js_code
Example #12
0
class Unit(traitlets.HasTraits):
    """A combatant on the dungeon grid.

    Units order by reading position ``(y, x)`` and then by hit points. The
    ``opponents`` trait names the class whose ``members`` this unit fights.
    """

    attack_power = traitlets.Integer(default_value=3)
    hit_points = traitlets.Integer(default_value=200)

    location = traitlets.Tuple(traitlets.Integer(), traitlets.Integer())  # y, x

    dead = traitlets.Bool(default_value=False)

    # Class-level list of living units; shared by any subclass that does not
    # define its own ``members``.
    members = []
    opponents = traitlets.Type('__main__.Unit')

    @classmethod
    def append(cls, other):
        """Register *other* in this class's ``members`` list."""
        cls.members.append(other)

    def attack(self, other):
        """Deal ``attack_power`` damage to *other*; remove it from its side if it dies."""
        other.hit_points -= self.attack_power
        if other.hit_points <= 0:
            other.dead = True
            self.opponents.members.remove(other)
            print(self, 'killed', other)

    def distance(self, other):
        """Return the city-block distance between the two units' locations."""
        return cityblock(self.location, other.location)

    @property
    def target(self):
        """The adjacent opponent to attack, or None if none is adjacent.

        Among several adjacent opponents, the one with the fewest hit points is
        chosen, ties broken by (y, x) location.
        """
        adjacent = [
            foe for foe in self.opponents.members
            if self.distance(foe) == 1
        ]
        if not adjacent:
            return None
        return min(adjacent, key=lambda u: (u.hit_points, *u.location))

    def move(self):
        """Step one square toward the closest reachable opponent, if any."""
        # first, block out your buddies
        current_dungeon = DUNGEON.copy()

        allies = np.array([
            friend.location for friend in self.members
            if friend is not self
        ])

        if allies.size:
            # locations are stored as y, x, so:
            current_dungeon[allies[:, 0], allies[:, 1]] = -1

        foe_locations = np.array([
            foe.location
            for foe in self.opponents.members
        ])

        # and now find the costs
        mcp = MCP(current_dungeon, fully_connected=False)
        cum_costs, traceback = mcp.find_costs(
            starts=[self.location],
            find_all_ends=True
        )

        foe_distances = cum_costs[
            foe_locations[:, 0], foe_locations[:, 1]
        ]
        if np.isinf(foe_distances.min()):
            return  # no route available to any foe
        closest_foes = np.arange(len(foe_distances))[foe_distances == foe_distances.min()]
        # Ties between equally distant foes are broken by Unit ordering
        # (reading order, then hit points).
        closest_foe = min(
            self.opponents.members[i] for i in closest_foes
        )

        # now you have one closest foe, reverse the distance calc
        # and move one step close

        mcp = MCP(current_dungeon, fully_connected=False)
        cum_costs, traceback = mcp.find_costs(
            ends=[self.location],
            starts=[closest_foe.location],
            find_all_ends=False
        )

        # Candidate squares are one step closer to the foe than we are and lie
        # within one row/column of our current location.
        target_locations = np.argwhere(cum_costs == foe_distances.min() - 1)
        valid_locations = target_locations[(
                (target_locations >= np.array(self.location) - 1) &
                (target_locations <= np.array(self.location) + 1)
        ).all(axis=1)]

        y, x = (sorted(tuple(coords) for coords in valid_locations))[0]
        print(self, 'moving to', y, x)
        self.location = (int(y), int(x))

    def _order_key(self):
        """Comparison key: reading position (y, x), then hit points."""
        return (*self.location, self.hit_points)

    # NOTE: defining __eq__ without __hash__ leaves Unit instances unhashable.
    def __eq__(self, other):
        if not isinstance(other, Unit):
            return NotImplemented
        return self._order_key() == other._order_key()

    def __lt__(self, other):
        if not isinstance(other, Unit):
            return NotImplemented
        return self._order_key() < other._order_key()

    def __gt__(self, other):
        if not isinstance(other, Unit):
            return NotImplemented
        return self._order_key() > other._order_key()

    def __repr__(self):
        return f'<{self.__class__.__name__} ap{self.attack_power} hp{self.hit_points} loc{self.location}>'

    def __add__(self, other):
        """Sum of both units' hit points."""
        return self.hit_points + other.hit_points

    def __radd__(self, other):
        """Support ``sum(units)`` by adding hit points to a running number."""
        return self.hit_points + other
Example #13
0
class OptimadeQueryFilterWidget(  # pylint: disable=too-many-instance-attributes
        ipw.VBox):
    """Structure search and import widget for OPTIMADE

    Lets the user build an OPTIMADE filter, query the chosen database's
    structures, sort and page through the results, and pick one structure
    from a dropdown. The picked structure is exposed via the ``structure``
    trait.

    NOTE: Only supports offset- and number-pagination at the moment.
    """

    # Structure currently selected in the results dropdown (None when nothing
    # is selected); set by ``_on_structure_select``.
    structure = traitlets.Instance(Structure, allow_none=True)
    # Chosen database as a 2-tuple. Element 1 must expose ``base_url`` to be
    # usable; element 0 is only used in log messages (presumably the display
    # name — confirm with callers). Changes are handled by
    # ``_on_database_select``.
    database = traitlets.Tuple(
        traitlets.Unicode(),
        traitlets.Instance(LinksResourceAttributes, allow_none=True),
    )

    def __init__(
        self,
        result_limit: int = None,
        button_style: Union[ButtonStyle, str] = None,
        embedded: bool = False,
        subparts_order: List[str] = None,
        **kwargs,
    ):
        """Build the filter, results and pagination sub-widgets.

        Parameters:
            result_limit: Number of results per page; a falsy value (None or 0)
                falls back to 25.
            button_style: Style of the "Search" button, either a ``ButtonStyle``
                member or the (case-insensitive) name of one. Defaults to
                ``ButtonStyle.PRIMARY``.
            embedded: Passed to ``FilterTabs`` as
                ``show_large_filters=not embedded``.
            subparts_order: Names of the sub-widget attributes to show as this
                box's children, in order. Defaults to
                ``QueryFilterWidgetOrder.default_order(as_str=True)``.
            **kwargs: Forwarded to ``ipw.VBox.__init__``.

        Raises:
            TypeError: If ``button_style`` is neither a string nor a
                ``ButtonStyle``.
            ValueError: If an entry of ``subparts_order`` is not an attribute
                of this widget.

        An unknown style name makes the ``ButtonStyle[...]`` lookup fail
        (a KeyError if ``ButtonStyle`` is an ``Enum``, as the error message
        suggests).
        """
        self.page_limit = result_limit if result_limit else 25
        if button_style:
            if isinstance(button_style, str):
                button_style = ButtonStyle[button_style.upper()]
            elif isinstance(button_style, ButtonStyle):
                pass
            else:
                raise TypeError(
                    "button_style should be either a string or a ButtonStyle Enum. "
                    f"You passed type {type(button_style)!r}.")
        else:
            button_style = ButtonStyle.PRIMARY

        subparts_order = subparts_order or QueryFilterWidgetOrder.default_order(
            as_str=True)

        # Pagination state used by ``_query`` (page_offset / page_number)
        self.offset = 0
        self.number = 1
        # Set from the first query response's meta by ``retrieve_data``
        self._data_available = None
        # Set to False right before a pagination update that must not trigger
        # another query; ``_get_more_results`` skips once and re-arms it.
        self.__perform_query = True
        # Per-base-URL caches filled by ``_set_intslider_ranges`` / ``_set_version``
        self.__cached_ranges = {}
        self.__cached_versions = {}
        self.database_version = ""

        self.filter_header = ipw.HTML(
            '<h4 style="margin:0px;padding:0px;">Apply filters</h4>')
        self.filters = FilterTabs(show_large_filters=not embedded)
        self.filters.freeze()
        self.filters.on_submit(self.retrieve_data)

        # Disabled until a database is chosen (see _on_database_select)
        self.query_button = ipw.Button(
            description="Search",
            button_style=button_style.value,
            icon="search",
            disabled=True,
            tooltip="Search - No database chosen",
        )
        self.query_button.on_click(self.retrieve_data)

        self.structures_header = ipw.HTML(
            '<h4 style="margin-bottom:0px;padding:0px;">Results</h4>')

        self.sort_selector = SortSelector(disabled=True)
        self.sorting = self.sort_selector.value
        self.sort_selector.observe(self._sort, names="value")

        self.structure_drop = StructureDropdown(disabled=True)
        self.structure_drop.observe(self._on_structure_select, names="value")
        self.error_or_status_messages = ipw.HTML("")

        self.structure_page_chooser = ResultsPageChooser(self.page_limit)
        self.structure_page_chooser.observe(
            self._get_more_results,
            names=["page_link", "page_offset", "page_number"])

        # Must run after all sub-widgets above are created, since it checks
        # them by attribute name. Note hasattr accepts any attribute, not only
        # the sub-widgets.
        for subpart in subparts_order:
            if not hasattr(self, subpart):
                raise ValueError(
                    f"Wrongly specified subpart_order: {subpart!r}. Available subparts "
                    f"(and default order): {QueryFilterWidgetOrder.default_order(as_str=True)}"
                )

        super().__init__(
            children=[getattr(self, _) for _ in subparts_order],
            layout=ipw.Layout(width="auto", height="auto"),
            **kwargs,
        )

    @traitlets.observe("database")
    def _on_database_select(self, _):
        """Load chosen database

        Observer for the ``database`` trait. The structures dropdown is reset
        and the cached ``data_available`` count is dropped.

        If no database is chosen (or it has no ``base_url``), the widget is
        frozen. Otherwise:

        - pagination is reset;
        - slider ranges and the API version for the new base URL are fetched
          while the widget is frozen;
        - the sortable fields are refreshed, and the widget is unfrozen.

        Exceptions raised while fetching ranges/version are logged with their
        traceback rather than propagated.
        """
        self.structure_drop.reset()
        # ``retrieve_data`` only reads ``data_available`` from a response while
        # this is None, so clear it: a new database must not inherit the
        # previous database's count.
        self._data_available = None

        if (self.database[1] is None
                or getattr(self.database[1], "base_url", None) is None):
            self.query_button.tooltip = "Search - No database chosen"
            self.freeze()
        else:
            self.offset = 0
            self.number = 1
            self.structure_page_chooser.silent_reset()
            try:
                self.freeze()

                self.query_button.description = "Updating ..."
                self.query_button.icon = "cog"
                self.query_button.tooltip = "Updating filters ..."

                self._set_intslider_ranges()
                self._set_version()
            except Exception as exc:  # pylint: disable=broad-except
                # ``exc_info=True`` attaches the active traceback to the log
                # record. ``exc.with_traceback()`` cannot be used here: it
                # requires a traceback argument and would itself raise TypeError.
                LOGGER.error(
                    "Exception raised during setting IntSliderRanges: %s",
                    exc,
                    exc_info=True,
                )
            finally:
                self.query_button.description = "Search"
                self.query_button.icon = "search"
                self.query_button.tooltip = "Search"
                self.sort_selector.valid_fields = sorted(
                    get_sortable_fields(self.database[1].base_url))
                self.unfreeze()

    def _on_structure_select(self, change):
        """Mirror the dropdown's selected value into the ``structure`` trait.

        A ``None`` selection clears ``structure`` and moves the dropdown back
        to its first entry (with trait notifications held).
        """
        selected = change["new"]
        if selected is not None:
            self.structure = selected["structure"]
            return

        self.structure = None
        with self.hold_trait_notifications():
            self.structure_drop.index = 0

    def _get_more_results(self, change):
        """Query for more results according to pageing

        Observer for the page chooser's ``page_link``, ``page_offset`` and
        ``page_number`` traits. An offset or number change updates
        ``self.offset`` / ``self.number`` and re-runs the current query. A link
        change requests that link directly. Any message left over from a
        previous query is cleared first, as ``retrieve_data`` does.

        Parameters:
            change: traitlets change dict; its ``"name"`` and ``"new"`` are used.
        """
        if not self.__perform_query:
            # Cleared just before ``update_offset()`` in the page-link branch
            # below, so the resulting offset notification is not queried again.
            self.__perform_query = True
            LOGGER.debug(
                "NOT going to perform query with change: name=%s value=%s",
                change["name"],
                change["new"],
            )
            return

        pageing: Union[int, str] = change["new"]
        LOGGER.debug(
            "Updating results with pageing change: name=%s value=%s",
            change["name"],
            pageing,
        )
        if change["name"] == "page_offset":
            self.offset = pageing
            # None makes ``_query`` build the request from offset/number
            pageing = None
        elif change["name"] == "page_number":
            self.number = pageing
            pageing = None
        else:
            # It is needed to update page_offset, but we do not wish to query again
            with self.hold_trait_notifications():
                self.__perform_query = False
                self.structure_page_chooser.update_offset()

        try:
            # Freeze and disable list of structures in dropdown widget
            # We don't want changes leading to weird things happening prior to the query ending
            self.freeze()

            # Reset any error or status message from a previous query so it
            # does not linger after a successful page change
            if self.error_or_status_messages.value:
                self.error_or_status_messages.value = ""

            # Update button text and icon
            self.query_button.description = "Updating ... "
            self.query_button.icon = "cog"
            self.query_button.tooltip = "Please wait ..."

            # Query database
            response = self._query(pageing)
            msg, _ = handle_errors(response)
            if msg:
                self.error_or_status_messages.value = msg
                return

            # Update list of structures in dropdown widget
            self._update_structures(response["data"])

            # Update pageing
            self.structure_page_chooser.set_pagination_data(
                links_to_page=response.get("links", {}), )

        finally:
            self.query_button.description = "Search"
            self.query_button.icon = "search"
            self.query_button.tooltip = "Search"
            self.unfreeze()

    def _sort(self, change: dict) -> None:
        """Store the newly selected sort field and re-run the query.

        ``retrieve_data`` restarts from the first page.
        """
        self.sorting = change["new"]
        self.retrieve_data({})

    def freeze(self):
        """Disable the search button and every interactive sub-widget."""
        self.query_button.disabled = True
        for sub_widget in (
                self.filters,
                self.structure_drop,
                self.structure_page_chooser,
                self.sort_selector,
        ):
            sub_widget.freeze()

    def unfreeze(self):
        """Re-enable the search button and every interactive sub-widget.

        Sub-widgets keep their current state.
        """
        self.query_button.disabled = False
        for sub_widget in (
                self.filters,
                self.structure_drop,
                self.structure_page_chooser,
                self.sort_selector,
        ):
            sub_widget.unfreeze()

    def reset(self):
        """Reset pagination and all sub-widgets.

        The offset/number counters return to the first page. With trait
        notifications held, the search button's tooltip is set to "no database
        chosen" and each sub-widget is reset.

        NOTE(review): unlike ``__init__``, the search button is left *enabled*
        here despite that tooltip — confirm this is intended.
        """
        self.offset, self.number = 0, 1
        with self.hold_trait_notifications():
            button = self.query_button
            button.disabled = False
            button.tooltip = "Search - No database chosen"
            for sub_widget in (
                    self.filters,
                    self.structure_drop,
                    self.structure_page_chooser,
                    self.sort_selector,
            ):
                sub_widget.reset()

    def _uses_new_structure_features(self) -> bool:
        """Return whether ``self.database_version`` is at least v1.0.0-rc.2.

        ``_query`` uses this to decide whether ``"unknown_positions"`` must
        also be excluded via ``structure_features``.

        Comparison rules:

        - Major.Minor.Patch is compared first.
        - When the base versions are equal, a final release counts as newer,
          and otherwise the pre-release tags are compared.
        """
        threshold = SemanticVersion("1.0.0-rc.2")
        current = SemanticVersion(self.database_version)

        LOGGER.debug("Semantic version: %r", current)

        if current.base_version != threshold.base_version:
            # Different Major.Minor.Patch: the base versions alone decide
            return bool(current.base_version > threshold.base_version)

        # Same Major.Minor.Patch: a non-pre-release is at least as new
        if not current.prerelease:
            return True
        return current.prerelease >= threshold.prerelease

    def _set_version(self):
        """Set ``self.database_version`` from the chosen database's ``/info`` endpoint.

        The ``meta.api_version`` value, with a leading ``"v"`` stripped, is
        cached per base URL, so each database's ``/info`` is queried only once.

        Raises:
            QueryError: If the ``/info`` query reports an error or its response
                lacks ``meta`` or ``meta.api_version``.
        """
        base_url = self.database[1].base_url
        version_cache = self.__cached_versions

        if base_url in version_cache:
            self.database_version = version_cache[base_url]
            return

        info = perform_optimade_query(base_url=base_url, endpoint="/info")
        error_msg, _ = handle_errors(info)
        if error_msg:
            raise QueryError(error_msg)

        if "meta" not in info:
            raise QueryError(
                f"'meta' field not found in /info endpoint for base URL: {base_url}"
            )
        meta = info["meta"]
        if "api_version" not in meta:
            raise QueryError(
                f"'api_version' field not found in 'meta' for base URL: {base_url}"
            )

        api_version = meta["api_version"]
        version_cache[base_url] = (api_version[1:]
                                   if api_version.startswith("v") else api_version)
        LOGGER.debug(
            "Cached version %r for base URL: %r",
            version_cache[base_url],
            base_url,
        )
        self.database_version = version_cache[base_url]

    def _set_intslider_ranges(self):
        """Update IntRangeSlider ranges according to chosen database

        ``nsites`` and ``nelements`` are the candidate fields.
        ``check_entry_properties`` decides which of them the database can sort
        on. Each such field that is not yet cached is queried twice with
        ``page_limit=1``: sorted ascending for the minimum and descending for
        the maximum.

        Ranges are cached per base URL in ``self.__cached_ranges`` and then
        passed to ``self.filters.update_range_filters``.

        Built-in defaults are used:

        - for an extremum, when the database reports no available data or the
          response does not contain the value;
        - for all fields, when no range could be determined at all.

        Raises:
            QueryError: If a range query reports an error.
        """
        defaults = {
            "nsites": {
                "min": 0,
                "max": 10000
            },
            "nelements": {
                "min": 0,
                "max": len(CHEMICAL_SYMBOLS)
            },
        }

        db_base_url = self.database[1].base_url
        cached_ranges = self.__cached_ranges.setdefault(db_base_url, {})

        sortable_fields = check_entry_properties(
            base_url=db_base_url,
            entry_endpoint="structures",
            properties=["nsites", "nelements"],
            checks=["sort"],
        )

        for response_field in sortable_fields:
            if response_field in cached_ranges:
                # Use cached value(s)
                continue

            new_range = {}
            for extremum, sort in [
                ("min", response_field),
                ("max", f"-{response_field}"),
            ]:
                # A single entry, sorted on the field, yields the extremum
                query_params = {
                    "base_url": db_base_url,
                    "page_limit": 1,
                    "response_fields": response_field,
                    "sort": sort,
                }
                LOGGER.debug(
                    "Querying %s to get %s of %s.\nParameters: %r",
                    self.database[0],
                    extremum,
                    response_field,
                    query_params,
                )

                response = perform_optimade_query(**query_params)
                msg, _ = handle_errors(response)
                if msg:
                    raise QueryError(msg)

                value = None
                if response.get("meta", {}).get("data_available", 0):
                    # ``or [{}]`` also covers an empty or null "data" list,
                    # which would otherwise raise on ``[0]``
                    data = response.get("data") or [{}]
                    value = data[0].get("attributes", {}).get(response_field, None)
                if value is None:
                    # Never hand a None extremum to the range sliders
                    value = defaults[response_field][extremum]
                new_range[extremum] = value

            # Cache new values
            LOGGER.debug(
                "Caching newly found range values for %s\nValue: %r",
                db_base_url,
                {response_field: new_range},
            )
            cached_ranges[response_field] = new_range

        if not cached_ranges:
            LOGGER.debug("No values found for %s, storing default values.",
                         db_base_url)
            cached_ranges.update(defaults)

        # Set widget's new extrema
        LOGGER.debug(
            "Updating range extrema for %s\nValues: %r",
            db_base_url,
            cached_ranges,
        )
        self.filters.update_range_filters(cached_ranges)

    def _query(self, link: str = None) -> dict:
        """Query the chosen database for structures and return the response dict.

        Parameters:
            link: A complete query URL, e.g. a pagination link. When given, it
                is normalised with ``ordered_query_url`` and requested as-is;
                the filters, sorting and pagination counters are not used.

        Returns:
            The decoded JSON response. For a ``link`` request, a dict with an
            ``"errors"`` entry is returned instead when the connection fails,
            times out, or the body is not valid JSON.
        """
        if link is None:
            # Always exclude structures with assemblies. For databases older
            # than v1.0.0-rc.2 (see _uses_new_structure_features), also exclude
            # "unknown_positions", to avoid structures with null positions.
            extra_filter = 'NOT structure_features HAS ANY "assemblies"'
            if not self._uses_new_structure_features():
                extra_filter += ',"unknown_positions"'

            user_filter = self.filters.collect_value()
            if user_filter:
                optimade_filter = f"( {user_filter} ) AND ( {extra_filter} )"
            else:
                optimade_filter = extra_filter
            LOGGER.debug("Querying with filter: %s", optimade_filter)

            # OPTIMADE queries
            query_kwargs = {
                "base_url": self.database[1].base_url,
                "filter": optimade_filter,
                "page_limit": self.page_limit,
                "page_offset": self.offset,
                "page_number": self.number,
                "sort": self.sorting,
            }
            LOGGER.debug(
                "Parameters (excluding filter) sent to query util func: %s",
                {name: val
                 for name, val in query_kwargs.items() if name != "filter"},
            )
            return perform_optimade_query(**query_kwargs)

        # A complete link was provided: request it directly
        try:
            link = ordered_query_url(link)
            http_response = SESSION.get(link, timeout=TIMEOUT_SECONDS)
            if http_response.from_cache:
                LOGGER.debug("Request to %s was taken from cache !", link)
            return http_response.json()
        except (
                requests.exceptions.ConnectTimeout,
                requests.exceptions.ConnectionError,
                requests.exceptions.ReadTimeout,
        ) as exc:
            client_msg = "CLIENT: Connection error or timeout."
            failure = repr(exc)
        except JSONDecodeError as exc:
            client_msg = "CLIENT: Could not decode response to JSON."
            failure = repr(exc)

        return {
            "errors": {
                "msg": client_msg,
                "url": link,
                "Exception": failure,
            }
        }

    @staticmethod
    def _check_species_mass(structure: dict) -> dict:
        """Ensure species.mass is using OPTIMADE API v1.0.1 type

        Drops the ``mass`` key of any species in
        ``structure["attributes"]["species"]`` whose value is neither a list
        nor ``None``. The dict is modified in place and also returned.
        """
        all_species = structure.get("attributes", {}).get("species")
        if all_species:
            for species in all_species:
                mass = species.get("mass", None)
                if mass is not None and not isinstance(mass, list):
                    species.pop("mass", None)
        return structure

    def _update_structures(self, data: list):
        """Fill the structures dropdown from the ``data`` list of a query response.

        Each entry becomes an option whose label is ``"<formula> (id=<id>)"``
        and whose value is ``{"structure": <Structure>}``. The formula is the
        first non-None of the descriptive, reduced, anonymous and Hill
        formulas. Note that the anonymous formula is set to None in the raw
        entry just before the ``Structure`` is built.

        Raises:
            BadResource: If an entry has none of these formula fields set.
        """
        formula_fields = [
            "chemical_formula_descriptive",
            "chemical_formula_reduced",
            "chemical_formula_anonymous",
            "chemical_formula_hill",
        ]
        options = []

        for entry in data:
            # XXX: THIS IS TEMPORARY AND SHOULD BE REMOVED ASAP
            entry["attributes"]["chemical_formula_anonymous"] = None

            structure = Structure(self._check_species_mass(entry))

            for field_name in formula_fields:
                formula = getattr(structure.attributes, field_name)
                if formula is not None:
                    break
            else:
                # No formula field yielded a value
                raise BadResource(
                    resource=structure,
                    fields=list(formula_fields),
                    msg="At least one of the following chemical formula fields "
                    "should have a valid value",
                )

            options.append(
                (f"{formula} (id={structure.id})", {"structure": structure}))

        # Update list of structures in dropdown widget
        self.structure_drop.set_options(options)

    def retrieve_data(self, _):
        """Perform query and retrieve data

        Runs a fresh query from the first page, using the current filters and
        sorting. On success it:

        - refills the structures dropdown;
        - hands the pagination data to the page chooser, with
          ``reset_cache=True``.

        The positional argument is ignored. It is whatever the caller passes:
        the button or filter-submit callback, or ``{}`` from ``_sort``.

        Raises:
            QueryError: If the response reports an error (the message is also
                shown in ``error_or_status_messages``). Any other exception is
                wrapped in a ``QueryError`` carrying the formatted traceback.
                In both cases the dropdown and page chooser are reset first.
        """
        self.offset = 0
        self.number = 1
        try:
            # Freeze and disable list of structures in dropdown widget
            # We don't want changes leading to weird things happening prior to the query ending
            self.freeze()

            # Reset the error or status message
            if self.error_or_status_messages.value:
                self.error_or_status_messages.value = ""

            # Update button text and icon
            self.query_button.description = "Querying ... "
            self.query_button.icon = "cog"
            self.query_button.tooltip = "Please wait ..."

            # Query database
            response = self._query()
            msg, _ = handle_errors(response)
            if msg:
                self.error_or_status_messages.value = msg
                raise QueryError(msg)

            # Update list of structures in dropdown widget
            self._update_structures(response["data"])

            # Update pageing
            # ``data_available`` is only read from a response while the cached
            # value is still None (it starts as None in __init__).
            if self._data_available is None:
                self._data_available = response.get("meta", {}).get(
                    "data_available", None)
            # Fall back to the number of returned entries when the response
            # meta has no ``data_returned``.
            data_returned = response.get("meta",
                                         {}).get("data_returned",
                                                 len(response.get("data", [])))
            self.structure_page_chooser.set_pagination_data(
                data_returned=data_returned,
                data_available=self._data_available,
                links_to_page=response.get("links", {}),
                reset_cache=True,
            )

        except QueryError:
            self.structure_drop.reset()
            self.structure_page_chooser.reset()
            raise

        except Exception as exc:
            # Wrap any other failure, embedding the formatted traceback
            self.structure_drop.reset()
            self.structure_page_chooser.reset()
            raise QueryError(
                f"Bad stuff happened: {traceback.format_exc()}") from exc

        finally:
            # Restore the search button and re-enable the widget in all cases
            self.query_button.description = "Search"
            self.query_button.icon = "search"
            self.query_button.tooltip = "Search"
            self.unfreeze()