class ScipyPoint(Interpolator): """Scipy Point Interpolation Attributes ---------- {interpolator_attributes} """ methods_supported = ["nearest"] method = tl.Unicode(default_value="nearest") dims_supported = ["lat", "lon"] # TODO: implement these parameters for the method 'nearest' spatial_tolerance = tl.Float(default_value=np.inf) time_tolerance = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)]) @common_doc(COMMON_INTERPOLATOR_DOCS) def can_interpolate(self, udims, source_coordinates, eval_coordinates): """ {interpolator_can_interpolate} """ # TODO: make this so we don't need to specify lat and lon together # or at least throw a warning if ( "lat" in udims and "lon" in udims and not self._dim_in(["lat", "lon"], source_coordinates) and self._dim_in(["lat", "lon"], source_coordinates, unstacked=True) and self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True) ): return tuple(["lat", "lon"]) # otherwise return no supported dims return tuple() @common_doc(COMMON_INTERPOLATOR_DOCS) def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, output_data): """ {interpolator_interpolate} """ order = "lat_lon" if "lat_lon" in source_coordinates.dims else "lon_lat" # calculate tolerance if isinstance(eval_coordinates["lat"], UniformCoordinates1d): dlat = eval_coordinates["lat"].step else: dlat = (eval_coordinates["lat"].bounds[1] - eval_coordinates["lat"].bounds[0]) / ( eval_coordinates["lat"].size - 1 ) if isinstance(eval_coordinates["lon"], UniformCoordinates1d): dlon = eval_coordinates["lon"].step else: dlon = (eval_coordinates["lon"].bounds[1] - eval_coordinates["lon"].bounds[0]) / ( eval_coordinates["lon"].size - 1 ) tol = np.linalg.norm([dlat, dlon]) * 8 if self._dim_in(["lat", "lon"], eval_coordinates): pts = np.stack([source_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1) if order == "lat_lon": pts = pts[:, ::-1] pts = KDTree(pts) lon, lat = np.meshgrid(eval_coordinates["lon"].coordinates, eval_coordinates["lat"].coordinates) dist, ind = pts.query(np.stack((lon.ravel(), lat.ravel()), axis=1), distance_upper_bound=tol) mask = ind == source_data[order].size ind[mask] = 0 # This is a hack to make the select on the next line work # (the masked values are set to NaN on the following line) vals = source_data[{order: ind}] vals[mask] = np.nan # make sure 'lat_lon' or 'lon_lat' is the first dimension dims = [dim for dim in source_data.dims if dim != order] vals = vals.transpose(order, *dims).data shape = vals.shape coords = [eval_coordinates["lat"].coordinates, eval_coordinates["lon"].coordinates] coords += [source_coordinates[d].coordinates for d in dims] vals = vals.reshape(eval_coordinates["lat"].size, eval_coordinates["lon"].size, *shape[1:]) vals = UnitsDataArray(vals, coords=coords, dims=["lat", "lon"] + dims) # and transpose back to the destination order output_data.data[:] = vals.transpose(*output_data.dims).data[:] return output_data elif self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True): dst_order = "lat_lon" if "lat_lon" in eval_coordinates.dims else "lon_lat" src_stacked = np.stack( [source_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1 ) new_stacked = np.stack( [eval_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1 ) pts = KDTree(src_stacked) dist, ind = pts.query(new_stacked, distance_upper_bound=tol) mask = ind == source_data[order].size ind[mask] = 0 vals = source_data[{order: ind}] vals[{order: mask}] = np.nan dims = 
list(output_data.dims) dims[dims.index(dst_order)] = order output_data.data[:] = vals.transpose(*dims).data[:] return output_data
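# A standalone sketch of the KDTree trick used in `interpolate` above: scipy's
# `query(..., distance_upper_bound=tol)` reports misses as index == n (the
# number of source points), which is why the code masks `ind == size` and
# fills those positions with NaN. Point values and tolerance here are made up.
def _example_kdtree_tolerance():
    import numpy as np
    from scipy.spatial import cKDTree

    src = np.array([[0.0, 0.0], [1.0, 1.0]])     # source (lon, lat) points
    data = np.array([10.0, 20.0])                # one value per source point
    tree = cKDTree(src)
    query = np.array([[0.1, 0.1], [5.0, 5.0]])   # second point has no neighbor
    dist, ind = tree.query(query, distance_upper_bound=0.5)
    mask = ind == src.shape[0]                   # misses get index n
    vals = data[np.where(mask, 0, ind)]          # same "set index to 0" hack as above
    vals[mask] = np.nan                          # then NaN-out the misses
    return vals                                  # -> [10.0, nan]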
class GroupByTransformer(Transformer): '''The GroupByTransformer creates aggregations via the groupby operation, which are joined to a DataFrame. This is useful for creating aggregate features. Example: >>> import vaex >>> import vaex.ml >>> df_train = vaex.from_arrays(x=['dog', 'dog', 'dog', 'cat', 'cat'], y=[2, 3, 4, 10, 20]) >>> df_test = vaex.from_arrays(x=['dog', 'cat', 'dog', 'mouse'], y=[5, 5, 5, 5]) >>> group_trans = vaex.ml.GroupByTransformer(by='x', agg={'mean_y': vaex.agg.mean('y')}, rsuffix='_agg') >>> group_trans.fit_transform(df_train) # x y x_agg mean_y 0 dog 2 dog 3 1 dog 3 dog 3 2 dog 4 dog 3 3 cat 10 cat 15 4 cat 20 cat 15 >>> group_trans.transform(df_test) # x y x_agg mean_y 0 dog 5 dog 3.0 1 cat 5 cat 15.0 2 dog 5 dog 3.0 3 mouse 5 -- -- ''' snake_name = 'groupby_transformer' by = traitlets.Unicode(allow_none=False, help='The feature on which to do the grouping.') agg = traitlets.Dict( help= 'Dict where the keys are feature names and the values are vaex.agg objects.' ) rprefix = traitlets.Unicode( default_value='', help= 'Prefix for the names of the aggregate features in case of a collision.' ) rsuffix = traitlets.Unicode( default_value='', help= 'Suffix for the names of the aggregate features in case of a collision.' ) df_group_ = traitlets.Instance(klass=vaex.dataframe.DataFrame, allow_none=True) def fit(self, df): ''' Fit GroupByTransformer to the DataFrame. :param df: A vaex DataFrame. ''' if not self.agg: raise ValueError( 'You have to specify a dict for the `agg` keyword.') if len(self.by) == 0: raise ValueError('Please specify a value for the `by` keyword.') self.df_group_ = df.groupby(by=self.by, agg=self.agg) def transform(self, df): ''' Transform a DataFrame with a fitted GroupByTransformer. :param df: A vaex DataFrame. :returns copy: a shallow copy of the DataFrame that includes the aggregated features. :rtype: DataFrame ''' df = df.copy() # We effectively want to do a join, but since that is not part of the state, it will not be state # transferrable, instead we implement this with map # df = df.join(other=self.df_group_, on=self.by, how='left', rprefix=self.rprefix, rsuffix=self.rsuffix) key_values = self.df_group_[self.by].tolist() for name in self.df_group_.get_column_names(): if name == self.by: continue # we don't need to include the column we group/join on mapper = dict(zip(key_values, self.df_group_[name].values)) join_name = name if join_name in df: join_name = self.rprefix + join_name + self.rsuffix df[join_name] = df[self.by].map(mapper, allow_missing=True) return df
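# A hedged sketch of why `transform` uses `map` instead of `join` (see the
# comment in `transform`): the fitted mapper is carried in the DataFrame state,
# so it survives a `state_get`/`state_set` round-trip. Assumes vaex is
# installed; the column names are illustrative.
def _example_groupby_state_transfer():
    import vaex
    import vaex.ml

    df_train = vaex.from_arrays(x=['dog', 'dog', 'cat'], y=[2, 4, 10])
    df_test = vaex.from_arrays(x=['cat', 'dog'], y=[0, 0])
    trans = vaex.ml.GroupByTransformer(by='x', agg={'mean_y': vaex.agg.mean('y')})
    df_train = trans.fit_transform(df_train)
    df_test.state_set(df_train.state_get())  # aggregate columns come along
    return df_test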
class LayerPicker(ipywidgets.HBox): """ Widget to pick a WorkflowsLayer from a map In subclasses, set `_attr` to the trait name on WorkflowsLayer that you want mirrored into the `value` trait of this class. Attributes ---------- value: ImageCollection, None The parametrized ImageCollection of the currently-selected layer. """ value = traitlets.Instance(klass=ImageCollection, allow_none=True, read_only=True) _attr = "value" def __init__( self, map=None, default_layer: Optional[WorkflowsLayer] = None, hide_deps_of: Optional[Proxytype] = None, **kwargs, ): """ Construct a LayerPicker widget for a map. Parameters ---------- map: ipyleaflet.Map The map instance to pick from. Defaults to `wf.map`. default_layer: WorkflowsLayer The layer instance to have selected by default. hide_deps_of: Proxytype Hide any layers from the dropdown that have this object in their ``.params``. Mainly used by the Picker parameter widget to hide its own layer from the dropdown, avoiding graft cycles. """ super().__init__(**kwargs) if map is None: # use wf.map as default from . import map # awkwardly handle MapApp without circularly importing it for an isinstance check try: map = map.map except AttributeError: pass self._map = map self._hide_deps_of = hide_deps_of self._dropdown = ipywidgets.Dropdown(equals=operator.is_) type_ = type(self).value.klass if default_layer is not None: if not isinstance(default_layer, WorkflowsLayer): raise TypeError( f"Default values for an {type(self).__name__} can only be WorkflowsLayer instances " f"(the layer object returned by `.visualize`), not {default_layer!r}. " "Also note that this default value won't be synced when publishing." ) value = getattr(default_layer, self._attr) if not isinstance(value, type_): raise TypeError( f"Expected a default layer visualizing an {type_.__name__}, not an {type(value).__name__}. " "Pick a different layer, or pick a different type for this widget, or remove a " "reduction operation (like `.mosaic()`, `.mean('images')`) from the code that " f"produces the layer {default_layer.name!r}.") self.set_trait("value", value) default_layer.observe(self._picked_layer_value_changes, self._attr) self._picked_layer = default_layer map.observe(self._update_options, "layers") self._dropdown.observe(self._layer_picked, "value") self.children = [self._dropdown] self._setting_options = False self._update_options({}) def _update_options(self, change): type_ = type(self).value.klass options = [(lyr.name, lyr) for lyr in reversed(self._map.layers) if isinstance(lyr, WorkflowsLayer) and isinstance(getattr(lyr, self._attr), type_) and all( p is not self._hide_deps_of for p in lyr.xyz_obj.params)] # when changing options, ipywidgets always just picks the first option. # this is infuriatingly difficult to work around, so we set our own flag to ignore # changes while this is happening.
self._setting_options = True self._dropdown.options = options self._setting_options = False try: self._dropdown.value = self._picked_layer except traitlets.TraitError: # the previously-picked layer doesn't exist anymore; # we'd rather just have no value in that case self._dropdown.value = None self._picked_layer = None self.set_trait("value", None) def _layer_picked(self, change): new_layer = change["new"] if self._setting_options or new_layer is self._picked_layer: return if self._picked_layer is not None: self._picked_layer.unobserve(self._picked_layer_value_changes, self._attr) if new_layer is None: self._picked_layer = None self.set_trait("value", None) else: new_layer.observe(self._picked_layer_value_changes, self._attr) self._picked_layer = new_layer self.set_trait("value", getattr(new_layer, self._attr)) def unlink(self): # traitlets expects the handler first: unobserve(handler, names) self._map.unobserve(self._update_options, "layers") self._dropdown.unobserve(self._layer_picked, "value") if self._picked_layer is not None: self._picked_layer.unobserve(self._picked_layer_value_changes, self._attr) self._picked_layer = None def _picked_layer_value_changes(self, change): self.set_trait("value", change["new"]) def _ipython_display_(self): super()._ipython_display_()
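# A hedged usage sketch for LayerPicker: assumes the interactive
# `descarteslabs.workflows` map environment and at least one layer created
# with `.visualize`, as in the WorkflowsLayer example elsewhere in this module.
def _example_layer_picker():
    import descarteslabs.workflows as wf

    picker = LayerPicker(map=wf.map)  # dropdown listing WorkflowsLayers
    picker.observe(lambda change: print("picked:", change["new"]), "value")
    return picker                     # display in a notebook cell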
class InspectorRowGenerator(traitlets.HasTraits): """ Controller class that manages the name and pixel values widgets for one layer. Not a widget itself, but just exposes `name_label` and `values_labels` for the `PixelInspector` to add into its table. Listens for changes to the layer (XYZ object or parameters) or the marker (location), and updates the widgets in `values_labels` appropriately, by calling `inspect` in a separate thread to pull pixel values. """ _value_layout = {"width": "initial", "margin": "0 2px", "height": "1.6em"} name_label = traitlets.Instance( ipywidgets.Label, kw={"layout": dict(_value_layout, grid_column="1")}, read_only=True, allow_none=False, ) values_labels = traitlets.List(read_only=True, allow_none=False) def __init__(self, layer, marker, n_bands): self.marker = marker self.layer = layer self._updating = False self._cache = cachetools.LRUCache(64) self.name_label.value = layer.name # TODO make names bold. it's frustratingly difficult (unsupported) with ipywidgets: # https://github.com/jupyter-widgets/ipywidgets/issues/577 self._name_link = ipywidgets.jslink((layer, "name"), (self.name_label, "value")) self.set_trait( "values_labels", [ ipywidgets.Label(value="", layout=dict(self._value_layout, grid_column=str(2 + i))) for i in range(n_bands) ], ) marker.observe(self.recalculate, "geoctx", type="change") layer.observe( self.recalculate, ["image_value", "visible"], type="change", ) self._viz_links = [ traitlets.dlink( (layer, "visible"), (label.layout, "display"), lambda v: "" if v else "none", ) for label in [self.name_label] + self.values_labels ] if marker.opacity == 1: # there's already a point to sample; eagerly recalculate now self.recalculate() def unlink(self): # NOTE(gabe): the traitlets docs say name=All (the default) should work, # but careful reading of the source shows it's a no-op. # we must explicitly unobserve for exactly the names and types we observed for. self.marker.unobserve(self.recalculate, "geoctx", type="change") self.layer.unobserve( self.recalculate, ["image_value", "visible"], type="change", ) self._name_link.unlink() for viz_link in self._viz_links: viz_link.unlink() def recalculate(self, *args, **kwargs): if self._updating or not self.layer.visible or self.marker.opacity == 0: return xy_3857 = self.marker.xy_3857 # try to make a cache key from the marker location, XYZ ID, reduction, and parameters. # if the parameters are unhashable (probably because they contain grafts), # we'll consider it a cache miss and go fetch.
try: params_key = frozenset(self.layer.parameters.to_dict().items()) except TypeError: cache_key = None else: cache_key = ( xy_3857, self.layer.xyz_obj.id, self.layer.reduction, params_key, ) if cache_key: try: value_list = self._cache[cache_key] except KeyError: value_list = None else: value_list = None image = self.layer.image_value if image is None: value_list = ["❓"] if value_list: self.set_values(value_list) else: self.set_updating() # NOTE(gabe): I don't trust traitlets or ipywidgets to be thread-safe, # so we pull all values out of traits here and pass them in to the thread directly ctx = self.marker.geoctx thread = threading.Thread( target=self._fetch_and_set_thread, args=(image, xy_3857, ctx, cache_key), daemon=True, ) thread.start() def _fetch_and_set_thread(self, image, xy_3857, ctx, cache_key): proxy_value_list = image.value_at(*xy_3857).values() try: value_list = proxy_value_list.inspect(ctx) except JobTimeoutError: value_list = ["⏱"] except Exception: value_list = ["💥"] else: if len(value_list) == 0: # empty Image value_list = [np.ma.masked] self._cache[cache_key] = value_list self.set_values(value_list) def set_values(self, new_values_list): for i, value in enumerate(new_values_list): if isinstance(value, str): pass elif value is np.ma.masked: new_values_list[i] = "∅" else: new_values_list[i] = "{:.6g}".format(value) for i, label in enumerate(self.values_labels): try: label.value = new_values_list[i] except IndexError: label.value = "" self._updating = False def set_updating(self): self._updating = True for i, label in enumerate(self.values_labels): if label.value != "" or i == 0: label.value = "..."
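# A standalone sketch of the cache-key pattern described in the comment above:
# hashable parameters become part of an LRU key, while unhashable ones (such as
# parameters containing grafts) raise TypeError and are treated as a cache miss.
def _example_cache_key(xy_3857, xyz_id, reduction, params: dict):
    try:
        params_key = frozenset(params.items())
    except TypeError:
        return None                   # unhashable parameters: always refetch
    return (xy_3857, xyz_id, reduction, params_key)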
class Robot(SingletonConfigurable): front_left_motor = traitlets.Instance(Motor) front_right_motor = traitlets.Instance(Motor) back_left_motor = traitlets.Instance(Motor) back_right_motor = traitlets.Instance(Motor) # config front_left_motor_channel = traitlets.Integer(default_value=1).tag(config=True) front_left_motor_alpha = traitlets.Float(default_value=1.0).tag(config=True) front_right_motor_channel = traitlets.Integer(default_value=2).tag(config=True) front_right_motor_alpha = traitlets.Float(default_value=1.0).tag(config=True) back_left_motor_channel = traitlets.Integer(default_value=3).tag(config=True) back_left_motor_alpha = traitlets.Float(default_value=1.0).tag(config=True) back_right_motor_channel = traitlets.Integer(default_value=4).tag(config=True) back_right_motor_alpha = traitlets.Float(default_value=1.0).tag(config=True) def __init__(self, *args, **kwargs): super(Robot, self).__init__(*args, **kwargs) self.left_motor_driver = PCA9685(0x41, debug=False) self.right_motor_driver = PCA9685(0x40, debug=False) self.front_left_motor = Motor(self.left_motor_driver, channel=self.front_left_motor_channel, alpha=self.front_left_motor_alpha) self.front_right_motor = Motor(self.right_motor_driver, channel=self.front_right_motor_channel, alpha=self.front_right_motor_alpha) self.back_left_motor = Motor(self.left_motor_driver, channel=self.back_left_motor_channel, alpha=self.back_left_motor_alpha) self.back_right_motor = Motor(self.right_motor_driver, channel=self.back_right_motor_channel, alpha=self.back_right_motor_alpha) def set_motors(self, front_left_speed, front_right_speed, back_left_speed, back_right_speed): self.front_left_motor.value = front_left_speed self.front_right_motor.value = front_right_speed self.back_left_motor.value = back_left_speed self.back_right_motor.value = back_right_speed def forward(self, speed=1.0): self.front_left_motor.value = speed self.front_right_motor.value = speed self.back_left_motor.value = speed self.back_right_motor.value = speed def backward(self, speed=1.0): self.front_left_motor.value = -speed self.front_right_motor.value = -speed self.back_left_motor.value = -speed self.back_right_motor.value = -speed def left(self, speed=1.0): self.front_left_motor.value = -speed self.front_right_motor.value = speed self.back_left_motor.value = -speed self.back_right_motor.value = speed def right(self, speed=1.0): self.front_left_motor.value = speed self.front_right_motor.value = -speed self.back_left_motor.value = speed self.back_right_motor.value = -speed def stop(self): self.front_left_motor.value = 0 self.front_right_motor.value = 0 self.back_left_motor.value = 0 self.back_right_motor.value = 0 def forward_left(self, speed=1.0): self.front_left_motor.value = speed / 3 self.front_right_motor.value = speed self.back_left_motor.value = speed / 3 self.back_right_motor.value = speed def forward_right(self, speed=1.0): self.front_left_motor.value = speed self.front_right_motor.value = speed / 3 self.back_left_motor.value = speed self.back_right_motor.value = speed / 3 def backward_left(self, speed=1.0): self.front_left_motor.value = -speed / 3 self.front_right_motor.value = -speed self.back_left_motor.value = -speed / 3 self.back_right_motor.value = -speed def backward_right(self, speed=1.0): self.front_left_motor.value = -speed self.front_right_motor.value = -speed / 3 self.back_left_motor.value = -speed self.back_right_motor.value = -speed / 3
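# A minimal usage sketch for the four-motor Robot above, assuming the two
# PCA9685 driver boards are reachable at I2C addresses 0x40 and 0x41 as wired
# in `__init__`:
def _example_robot_drive():
    import time

    robot = Robot.instance()        # SingletonConfigurable entry point
    robot.forward(speed=0.5)
    time.sleep(1.0)
    robot.forward_left(speed=0.5)   # curve left: left wheels at 1/3 speed
    time.sleep(1.0)
    robot.stop()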
class DownloadChooser(ipw.HBox): """Download chooser for structure download To be able to have the download button work no matter the widget's final environment, (as long as it supports JavaScript), the very helpful insight from the following page is used: https://stackoverflow.com/questions/2906582/how-to-create-an-html-button-that-acts-like-a-link """ chosen_format = traitlets.Tuple(traitlets.Unicode(), traitlets.Dict()) structure = traitlets.Instance(Structure, allow_none=True) _formats = [ ( "Crystallographic Information File v1.0 (.cif)", { "ext": ".cif", "adapter_format": "cif" }, ), ("Protein Data Bank (.pdb)", { "ext": ".pdb", "adapter_format": "pdb" }), ( "Crystallographic Information File v1.0 [via ASE] (.cif)", { "ext": ".cif", "adapter_format": "ase", "final_format": "cif" }, ), ( "Protein Data Bank [via ASE] (.pdb)", { "ext": ".pdb", "adapter_format": "ase", "final_format": "proteindatabank" }, ), ( "XMol XYZ File [via ASE] (.xyz)", { "ext": ".xyz", "adapter_format": "ase", "final_format": "xyz" }, ), ( "XCrySDen Structure File [via ASE] (.xsf)", { "ext": ".xsf", "adapter_format": "ase", "final_format": "xsf" }, ), ( "WIEN2k Structure File [via ASE] (.struct)", { "ext": ".struct", "adapter_format": "ase", "final_format": "struct" }, ), ( "VASP POSCAR File [via ASE]", { "ext": "", "adapter_format": "ase", "final_format": "vasp" }, ), ( "Quantum ESPRESSO File [via ASE] (.in)", { "ext": ".in", "adapter_format": "ase", "final_format": "espresso-in" }, ), # Not yet implemented: # ( # "Protein Data Bank, macromolecular CIF v1.1 (PDBx/mmCIF) (.cif)", # {"ext": "cif", "adapter_format": "pdbx_mmcif"}, # ), ] _download_button_format = """ <input type="button" class="jupyter-widgets jupyter-button widget-button" value="Download" title="Download structure" style="width:auto;" {disabled} onclick="var link = document.createElement('a'); link.href = 'data:charset={encoding};base64,{data}'; link.download = '{filename}'; document.body.appendChild(link); link.click(); document.body.removeChild(link);" /> """ def __init__(self, **kwargs): self.dropdown = ipw.Dropdown(options=("Select a format", {}), width="auto") self.download_button = ipw.HTML( self._download_button_format.format(disabled="disabled", encoding="", data="", filename="")) self.children = (self.dropdown, self.download_button) super().__init__(children=self.children, layout={"width": "auto"}) self.reset() self.dropdown.observe(self._update_download_button, names="value") @traitlets.observe("structure") def _on_change_structure(self, change: dict): """Update widget when a new structure is chosen""" if change["new"] is None: self.reset() else: self._update_options() self.unfreeze() def _update_options(self): """Update options according to chosen structure""" # Disordered structures not usable with ASE if "disorder" in self.structure.structure_features: options = sorted([ option for option in self._formats if option[1].get("adapter_format", "") != "ase" ]) options.insert(0, ("Select a format", {})) else: options = sorted(self._formats) options.insert(0, ("Select a format", {})) self.dropdown.options = options def _update_download_button(self, change: dict): """Update Download button with correct onclick value The whole parsing process from `Structure` to desired format, is wrapped in a try/except, which is further wrapped in a `warnings.catch_warnings()`. This is in order to be able to log any warnings that might be thrown by the adapter in `optimade-python-tools` and/or any related exceptions. 
""" desired_format = change["new"] if not desired_format or desired_format is None: self.download_button.value = self._download_button_format.format( disabled="disabled", encoding="", data="", filename="") return with warnings.catch_warnings(): warnings.filterwarnings("error") try: output = getattr(self.structure, f"as_{desired_format['adapter_format']}") if desired_format["adapter_format"] in ( "ase", "pymatgen", "aiida_structuredata", ): # output is not a file, but a proxy Python class func = getattr( self, f"_get_via_{desired_format['adapter_format']}") output = func( output, desired_format=desired_format["final_format"]) encoding = "utf-8" # Specifically for CIF: v1.x CIF needs to be in "latin-1" formatting if desired_format["ext"] == ".cif": encoding = "latin-1" filename = ( f"optimade_structure_{self.structure.id}{desired_format['ext']}" ) if isinstance(output, str): output = output.encode(encoding) data = base64.b64encode(output).decode() except Warning as warn: self.download_button.value = self._download_button_format.format( disabled="disabled", encoding="", data="", filename="") warnings.warn(OptimadeClientWarning(warn)) except Exception as exc: self.download_button.value = self._download_button_format.format( disabled="disabled", encoding="", data="", filename="") if isinstance(exc, exceptions.OptimadeClientError): raise exc # Else wrap the exception to make sure to log it. raise exceptions.OptimadeClientError(exc) else: self.download_button.value = self._download_button_format.format( disabled="", encoding=encoding, data=data, filename=filename) @staticmethod def _get_via_pymatgen( structure_molecule: Union[pymatgenStructure, pymatgenMolecule], desired_format: str, ) -> str: """Use pymatgen.[Structure,Molecule].to() method""" molecule_only_formats = ["xyz", "pdb"] structure_only_formats = ["xsf", "cif"] if desired_format in molecule_only_formats and not isinstance( structure_molecule, pymatgenMolecule): raise exceptions.WrongPymatgenType( f"Converting to '{desired_format}' format is only possible with a pymatgen." f"Molecule, instead got {type(structure_molecule)}") if desired_format in structure_only_formats and not isinstance( structure_molecule, pymatgenStructure): raise exceptions.WrongPymatgenType( f"Converting to '{desired_format}' format is only possible with a pymatgen." f"Structure, instead got {type(structure_molecule)}.") return structure_molecule.to(fmt=desired_format) @staticmethod def _get_via_ase(atoms: aseAtoms, desired_format: str) -> Union[str, bytes]: """Use ase.Atoms.write() method""" with tempfile.NamedTemporaryFile(mode="w+b") as temp_file: atoms.write(temp_file.name, format=desired_format) res = temp_file.read() return res def freeze(self): """Disable widget""" for widget in self.children: widget.disabled = True def unfreeze(self): """Activate widget (in its current state)""" for widget in self.children: widget.disabled = False def reset(self): """Reset widget""" self.dropdown.index = 0 self.freeze()
class ProviderImplementationSummary(ipw.GridspecLayout): """Summary/description of chosen provider and their database""" provider = traitlets.Instance(LinksResourceAttributes, allow_none=True) database = traitlets.Instance(LinksResourceAttributes, allow_none=True) text_style = "margin:0px;padding-top:6px;padding-bottom:4px;padding-left:4px;padding-right:4px;" def __init__(self, **kwargs): self.provider_summary = ipw.HTML() provider_section = ipw.VBox( children=[self.provider_summary], layout=ipw.Layout(width="auto", height="auto"), ) self.database_summary = ipw.HTML() database_section = ipw.VBox( children=[self.database_summary], layout=ipw.Layout(width="auto", height="auto"), ) super().__init__( n_rows=1, n_columns=31, layout={ "border": "solid 0.5px darkgrey", "margin": "0px 0px 0px 0px", "padding": "0px 0px 10px 0px", }, **kwargs, ) self[:, :15] = provider_section self[:, 16:] = database_section self.observe(self._on_provider_change, names="provider") self.observe(self._on_database_change, names="database") def _on_provider_change(self, change: dict): """Update provider summary, since self.provider has been changed""" LOGGER.debug("Provider changed in summary. New value: %r", change["new"]) self.database_summary.value = "" if not change["new"] or change["new"] is None: self.provider_summary.value = "" else: self._update_provider() def _on_database_change(self, change): """Update database summary, since self.database has been changed""" LOGGER.debug("Database changed in summary. New value: %r", change["new"]) if not change["new"] or change["new"] is None: self.database_summary.value = "" else: self._update_database() def _update_provider(self): """Update provider summary""" html_text = f"""<strong style="line-height:1;{self.text_style}">{getattr(self.provider, 'name', 'Provider')}</strong> <p style="line-height:1.2;{self.text_style}">{getattr(self.provider, 'description', '')}</p>""" self.provider_summary.value = html_text def _update_database(self): """Update database summary""" html_text = f"""<strong style="line-height:1;{self.text_style}">{getattr(self.database, 'name', 'Database')}</strong> <p style="line-height:1.2;{self.text_style}">{getattr(self.database, 'description', '')}</p>""" self.database_summary.value = html_text def freeze(self): """Disable widget""" def unfreeze(self): """Activate widget (in its current state)""" def reset(self): """Reset widget""" self.provider = None
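# A hedged usage sketch for ProviderImplementationSummary, assuming `provider`
# is a LinksResourceAttributes object parsed from an OPTIMADE `/links` entry:
def _example_provider_summary(provider):
    summary = ProviderImplementationSummary()
    summary.provider = provider     # observer renders the provider HTML
    return summary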
class QueryConstructor(W.HBox): """TODO - way better templating and more efficient formatting - replace individual observers with larger observer - move build_query to standalone function """ convert_arrow = T.Instance(W.Image) query_input = T.Instance(W.VBox) formatted_query = T.Instance(QueryColorizer, kw=dict(layout=W.Layout(max_height="260px"))) # traits from children namespaces = T.Unicode() query_type = T.Unicode(default_value="SELECT") query_line = T.Unicode(allow_none=True) query_body = T.Unicode() query = T.Unicode() log = W.Output() def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.query_input = QueryInput() # Inherit traits with links TODO easier way? T.link((self.query_input.namespaces, "namespaces"), (self, "namespaces")) T.link((self.query_input.header, "dropdown_value"), (self, "query_type")) T.link((self.query_input.header, "header_value"), (self, "query_line")) T.link((self.query_input.body.body, "value"), (self, "query_body")) T.dlink((self, "query"), (self.formatted_query, "query")) self.children = tuple([self.query_input, self.formatted_query]) @log.capture() def build_query(self): # get values TODO improve namespaces = self.namespaces query_type = self.query_type query_line = self.query_line query_body = self.query_body or self.query_input.body.body.placeholder # indent each line of the query body query_body = "\n".join("\t" + line for line in query_body.split("\n")) header_str = "" # TODO move these to module vars if query_type in {"SELECT", "SELECT DISTINCT"}: if query_line == "": query_line = "*" header_str = f"{query_type} {query_line}" elif query_type == "ASK": header_str = query_type elif query_type == "CONSTRUCT": if query_line == "": query_line = "{?s ?p ?o}" header_str = f"{query_type} {query_line}" else: with self.log: raise ValueError(f"Unexpected query type: {query_type}") return query_template.format( namespaces, header_str, query_body, ) @T.observe( "namespaces", "query_type", "query_line", "query_body", ) def update_query(self, change): self.query = self.build_query()
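# A hedged usage sketch for QueryConstructor, assuming the surrounding
# QueryInput/QueryColorizer widgets and `query_template` are importable from
# this package:
def _example_query_constructor():
    qc = QueryConstructor()
    qc.query_type = "SELECT"
    qc.query_line = "?s ?p ?o"
    qc.query_body = "?s ?p ?o ."
    return qc.query     # assembled SPARQL, kept fresh by update_query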
class DataArray(_HasState): class Status(enum.Enum): MISSING_LIMITS = 1 STAGED_CALCULATING_LIMITS = 3 CALCULATING_LIMITS = 4 CALCULATED_LIMITS = 5 NEEDS_CALCULATING_GRID = 6 STAGED_CALCULATING_GRID = 7 CALCULATING_GRID = 8 CALCULATED_GRID = 9 READY = 10 EXCEPTION = 11 status = traitlets.UseEnum(Status, Status.MISSING_LIMITS) status_text = traitlets.Unicode('Initializing') exception = traitlets.Any(None) df = traitlets.Instance(vaex.dataframe.DataFrame) axes = traitlets.List(traitlets.Instance(Axis), []) grid = traitlets.Instance(xarray.DataArray, allow_none=True) grid_sliced = traitlets.Instance(xarray.DataArray, allow_none=True) shape = traitlets.CInt(64) selections = traitlets.List( traitlets.Union([traitlets.Bool(), traitlets.Unicode(allow_none=True)]), [None]) def __init__(self, **kwargs): super(DataArray, self).__init__(**kwargs) self.signal_slice = vaex.events.Signal() self.signal_regrid = vaex.events.Signal() self.signal_grid_progress = vaex.events.Signal() self.observe(lambda change: self.signal_regrid.emit(), 'selections') self._on_axis_status_change() # keep a set of axis that need new limits self._dirty_axes = set() for axis in self.axes: assert axis.df is self.df, "axes should have the same dataframe" traitlets.link((self, 'shape'), (axis, 'shape_default')) axis.observe(self._on_axis_status_change, 'status') axis.observe(lambda _: self.signal_slice.emit(self), ['slice']) def on_change_min_max(change): if change.owner.status == Axis.Status.READY: # this indicates a user changed the min/max self.status = DataArray.Status.NEEDS_CALCULATING_GRID axis.observe(on_change_min_max, ['min', 'max']) self._on_axis_status_change() self.df.signal_selection_changed.connect(self._on_change_selection) def _on_change_selection(self, df, name): # TODO: check if the selection applies to us self.status = DataArray.Status.NEEDS_CALCULATING_GRID async def _allow_state_change_cancel(self): self._allow_state_change.release() def _on_axis_status_change(self, change=None): missing_limits = [ axis for axis in self.axes if axis.status == Axis.Status.NO_LIMITS ] staged_calculating_limits = [ axis for axis in self.axes if axis.status == Axis.Status.STAGED_CALCULATING_LIMITS ] calculating_limits = [ axis for axis in self.axes if axis.status == Axis.Status.CALCULATING_LIMITS ] calculated_limits = [ axis for axis in self.axes if axis.status == Axis.Status.CALCULATED_LIMITS ] def names(axes): return ", ".join([str(axis.expression) for axis in axes]) if staged_calculating_limits: self.status = DataArray.Status.STAGED_CALCULATING_LIMITS self.status_text = 'Staged limit computation for {}'.format( names(staged_calculating_limits)) elif missing_limits: self.status = DataArray.Status.MISSING_LIMITS self.status_text = 'Missing limits for {}'.format( names(missing_limits)) elif calculating_limits: self.status = DataArray.Status.CALCULATING_LIMITS self.status_text = 'Computing limits for {}'.format( names(calculating_limits)) elif calculated_limits: self.status = DataArray.Status.CALCULATED_LIMITS self.status_text = 'Computed limits for {}'.format( names(calculated_limits)) else: assert all( [axis.status == Axis.Status.READY for axis in self.axes]) self.status = DataArray.Status.NEEDS_CALCULATING_GRID @traitlets.observe('status') def _on_change_status(self, change): if self.status == DataArray.Status.EXCEPTION: self.status_text = f'Exception: {self.exception}' elif self.status ==
DataArray.Status.STAGED_CALCULATING_GRID: self.status_text = 'Staged grid computation' elif self.status == DataArray.Status.CALCULATING_GRID: self.status_text = 'Calculating grid' elif self.status == DataArray.Status.CALCULATED_GRID: self.status_text = 'Calculated grid' elif self.status == DataArray.Status.READY: self.status_text = 'Ready' # GridCalculator can change the status @property def has_missing_limits(self): return any([axis.has_missing_limit for axis in self.axes]) def on_progress_grid(self, f): return all(self.signal_grid_progress.emit(f))
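# A hedged sketch of watching the DataArray status machine above, assuming a
# vaex DataFrame and Axis instances as defined elsewhere in this module:
def _example_dataarray_status(df, axis_x, axis_y):
    data_array = DataArray(df=df, axes=[axis_x, axis_y])
    data_array.observe(
        lambda change: print(change["new"], "-", data_array.status_text),
        "status",
    )
    return data_array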
class NXBase(VisualizerBase): """ The visualization class for the NXLayouts. Used by the datashader visualizations. :param _nx_layout: the desired networkx layout function to be used. :param _layouts: a dictionary mapping labels of known layouts to networkx functions. Notes: - not all networkx layouts work without custom node/edge data or graph_layout_params and are NOT_HANDLED by default, but can be set explicitly """ NOT_HANDLED = [ "bipartite_layout", "multipartite_layout", "rescale_layout", ] _layouts = T.Dict() _nx_layout = T.Instance(types.FunctionType) @T.default("graph_layout_options") def _make_default_options(self): return tuple(self._layouts) @T.default("_layouts") def _make_default_layouts(self): """these are leniently loaded, as the exact set of algorithms depends heavily on the version of networkx installed """ layouts = {} for layout_key in nx_layout.__all__: if layout_key in self.NOT_HANDLED: continue try: layout = getattr(nx_layout, layout_key) label = _make_nx_layout_label(layout_key) layouts[label] = layout except Exception as err: self.log.warning( "Expected to be able to load from networkx: %s\n%s", layout_key, err) return layouts @T.default("graph_layout") def _make_default_layout(self): return self.graph_layout_options[0] @T.default("_nx_layout") def set_default_nx_layout(self): return self._layouts[self.graph_layout] @T.observe("graph_layout") def _update_graph_layout(self, change): if change.new is None: self._nx_layout = sorted(self._layouts.items())[0][1] return if change.new in self._layouts: self._nx_layout = self._layouts[self.graph_layout] return try: self._nx_layout = getattr(nx_layout, change.new) except Exception as err: self.log.warning("Could not load from networkx: %s\n%s", change.new, err)
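# A hedged sketch of choosing a layout at runtime: the label set depends on the
# installed networkx version, so inspect `graph_layout_options` first. Assumes
# NXBase (or a concrete subclass) can be instantiated directly.
def _example_pick_nx_layout():
    viz = NXBase()
    print(viz.graph_layout_options)          # e.g. ("Circular", "Spring", ...)
    viz.graph_layout = viz.graph_layout_options[0]
    return viz._nx_layout                    # the resolved networkx function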
class XELK(ElkTransformer): """NetworkX DiGraphs to ELK dictionary structure""" HIDDEN_ATTR = "hidden" hoist_hidden_edges: bool = True source = T.Tuple(T.Instance(nx.Graph), T.Instance(nx.DiGraph, allow_none=True)) layouts = T.Dict() # keys: networkx nodes {ElkElements: {layout options}} css_classes = T.Dict() port_scale = T.Int(default_value=5) label_key = T.Unicode(default_value="labels") port_key = T.Unicode(default_value="ports") @T.default("source") def _default_source(self): return (nx.Graph(), None) @T.default("text_sizer") def _default_text_sizer(self): return ElkTextSizer() @T.default("layouts") def _default_layouts(self): parent_opts = opt.OptionsWidget( identifier="parents", options=[ opt.HierarchyHandling(), ], ) label_opts = opt.OptionsWidget( identifier=ElkLabel, options=[opt.NodeLabelPlacement(horizontal="center")]) node_opts = opt.OptionsWidget( identifier=ElkNode, options=[ opt.NodeSizeConstraints(), ], ) default = opt.OptionsWidget( options=[parent_opts, node_opts, label_opts]) return {ElkRoot: default.value} def node_id(self, node: Hashable) -> str: """Get the element id for a node in the main graph for use in elk :param node: Node in main graph :type node: Hashable :return: Element ID :rtype: str """ g, tree = self.source if node is ElkRoot: return self.ELK_ROOT_ID elif node in g: return g.nodes.get(node, {}).get("id", f"{node}") return f"{node}" def port_id(self, node: Hashable, port: Optional[Hashable] = None) -> str: """Get a suitable Elk identifier from the node and port :param node: Node from the incoming networkx graph :type node: Hashable :param port: Port identifier, defaults to None :type port: Optional[Hashable], optional :return: If no port is provided will refer to the node :rtype: str """ if port is None: return self.node_id(node) else: return f"{self.node_id(node)}.{port}" def edge_id(self, edge: Edge): # TODO probably will need more sophisticated id generation in future return "{}__{}___{}__{}".format(edge.source, edge.source_port, edge.target, edge.target_port) async def transform(self) -> ElkNode: """Generate ELK dictionary structure :return: Root Elk node :rtype: ElkNode """ # TODO decide behavior for nodes that exist in the tree but not g g, tree = self.source self.clear_registry() visible_edges, hidden_edges = self.collect_edges() # Process visible networkx nodes into elknodes visible_nodes = [ n for n in g.nodes() if not is_hidden(tree, n, self.HIDDEN_ATTR) ] # make elknodes then connect their hierarchy elknodes: NodeMap = {} ports: PortMap = {} for node in visible_nodes: elknode, node_ports = await self.make_elknode(node) for key, value in node_ports.items(): ports[key] = value elknodes[node] = elknode # make top level ElkNode and attach all others as children elknodes[ElkRoot] = top = ElkNode( id=self.ELK_ROOT_ID, children=build_hierarchy(g, tree, elknodes, self.HIDDEN_ATTR), layoutOptions=self.get_layout(ElkRoot, "parents"), ) # map of original nodes to the generated elknodes for node, elk_node in elknodes.items(): self.register(elk_node, node) elknodes, ports = await self.process_edges(elknodes, ports, visible_edges) # process edges where one or both original endpoints are hidden if self.hoist_hidden_edges: elknodes, ports = await self.process_edges( elknodes, ports, hidden_edges, edge_style={"slack-edge"}, port_style={"slack-port"}, ) # iterate through the port map and add ports to ElkNodes for port_id, port in ports.items(): owner = port.node elkport = port.elkport if owner not in elknodes: # TODO skip generating port to begin with continue
elknode = elknodes[owner] if elknode.ports is None: elknode.ports = [] layout = self.get_layout(owner, ElkPort) elkport.layoutOptions = merge(elkport.layoutOptions, layout) elknode.ports += [elkport] # map of ports to the generated elkports self.register(elkport, port) # bulk calculate label sizes await size_labels(self.text_sizer, collect_labels([top])) return top async def make_elknode(self, node) -> Tuple[ElkNode, PortMap]: # merge layout options defined on the node data with default layout # options node_data = self.get_node_data(node) layout = merge( node_data.get("layoutOptions", {}), self.get_layout(node, ElkNode), ) labels = await self.make_labels(node) # update port map with declared ports in the networkx node data node_ports = await self.collect_ports(node) properties = self.get_properties(node, self.get_css(node, ElkNode)) or None elk_node = ElkNode( id=self.node_id(node), labels=labels, layoutOptions=layout, properties=properties, width=node_data.get("width", None), height=node_data.get("height", None), ) return elk_node, node_ports def get_layout(self, node: Hashable, elk_type: Type[ElkGraphElement]) -> Optional[Dict]: """Get the Elk Layout Options appropriate for given networkx node and filter by given elk_type :param node: [description] :type node: Hashable :param elk_type: [description] :type elk_type: Type[ElkGraphElement] :return: [description] :rtype: [type] """ # TODO look at self.source hierarchy and resolve layout with added # infomation. until then use root node `None` for layout options if node not in self.layouts: node = ElkRoot type_opts = self.layouts.get(node, {}) options = {**type_opts.get(elk_type, {})} if options: return options def get_properties( self, element: Optional[Union[Hashable, str]], dom_classes: Optional[Set[str]] = None, ) -> Dict: """Get the properties for a graph element :param element: Networkx node or edge :type node: Hashable :param dom_classes: Set of base CSS DOM classes to merge, defaults to Set[str]=None :type dom_classes: [type], optional :return: Set of CSS Classes to apply :rtype: Set[str] """ g, tree = self.source properties = [] if g and element in g: g_props = g.nodes[element].get("properties", {}) if g_props: properties += [g_props] if hasattr(element, "data"): properties += [element.data.get("properties", {})] elif isinstance(element, dict): properties += [element.get("properties", {})] if dom_classes: properties += [{"cssClasses": " ".join(dom_classes)}] if not properties: return {} elif len(properties) == 1: return properties[0] merged_properties = {} for props in properties[::-1]: merged_properties = merge(props, merged_properties) return merged_properties def get_css( self, node: Hashable, elk_type: Type[ElkGraphElement], dom_classes: Set[str] = None, ) -> Set[str]: """Get the CSS Classes appropriate for given networkx node given elk_type :param node: Networkx node :type node: Hashable :param elk_type: ElkGraphElement to get appropriate css classes :type elk_type: Type[ElkGraphElement] :param dom_classes: Set of base CSS DOM classes to merge, defaults to Set[str]=None :type dom_classes: [type], optional :return: Set of CSS Classes to apply :rtype: Set[str] """ typed_css = self.css_classes.get(node, {}) css_classes = set(typed_css.get(elk_type, [])) if dom_classes is None: return css_classes return css_classes | dom_classes async def process_edges( self, nodes: NodeMap, ports: PortMap, edges: EdgeMap, edge_style: Set[str] = None, port_style: Set[str] = None, ) -> Tuple[NodeMap, PortMap]: for owner, edge_list in edges.items(): 
edge_css = self.get_css(owner, ElkEdge, edge_style) port_css = self.get_css(owner, ElkPort, port_style) for edge in edge_list: elknode = nodes[owner] if elknode.edges is None: elknode.edges = [] if edge.source_port is not None: port_id = self.port_id(edge.source, edge.source_port) if port_id not in ports: ports[port_id] = await self.make_port( edge.source, edge.source_port, port_css) if edge.target_port is not None: port_id = self.port_id(edge.target, edge.target_port) if port_id not in ports: ports[port_id] = await self.make_port( edge.target, edge.target_port, port_css) elknode.edges += [await self.make_edge(edge, edge_css)] return nodes, ports async def make_edge(self, edge: Edge, styles: Optional[Set[str]] = None) -> ElkExtendedEdge: """Make the associated Elk edge for the given Edge :param edge: Edge object to wrap :type edge: Edge :param styles: List of css classes to add to given Elk edge, defaults to None :type styles: Optional[List[str]], optional :return: Elk edge :rtype: ElkExtendedEdge """ labels = [] properties = self.get_properties(edge, styles) or None label_layout_options = self.get_layout( edge.owner, ElkLabel) # TODO add edgelabel type? edge_layout_options = self.get_layout(edge.owner, ElkEdge) for i, label in enumerate(edge.data.get(self.label_key, [])): if isinstance(label, ElkLabel): label = label.to_dict() # used to create copy of label if isinstance(label, dict): label = ElkLabel.from_dict(label) if isinstance(label, str): label = ElkLabel(id=f"{edge.owner}_label_{i}_{label}", text=label) label.layoutOptions = merge(label.layoutOptions, label_layout_options) labels.append(label) for label in labels: self.register(label, edge) elk_edge = ElkExtendedEdge( id=edge.data.get("id", self.edge_id(edge)), sources=[self.port_id(edge.source, edge.source_port)], targets=[self.port_id(edge.target, edge.target_port)], properties=properties, layoutOptions=merge(edge.data.get("layoutOptions"), edge_layout_options), labels=compact(labels), ) self.register(elk_edge, edge) return elk_edge async def make_port(self, owner: Hashable, port: Hashable, styles: Optional[Set[str]] = None) -> Port: """Make the associated elk port for the given owner node and port :param owner: [description] :type owner: Hashable :param port: [description] :type port: Hashable :param styles: list of css classes to apply to given ElkPort :type styles: List[str] :return: [description] :rtype: ElkPort """ port_id = self.port_id(owner, port) properties = self.get_properties(port, styles) or None elk_port = ElkPort( id=port_id, height=self.port_scale, width=self.port_scale, properties=properties, # TODO labels ) return Port(node=owner, elkport=elk_port) def get_node_data(self, node: Hashable) -> Dict: g, tree = self.source return g.nodes.get(node, {}) async def collect_ports(self, *nodes) -> PortMap: ports: PortMap = {} for node in nodes: values = self.get_node_data(node).get(self.port_key, []) for i, port in enumerate(values): if isinstance(port, ElkPort): # generate a fresh copy of the port to prevent mutating original port = port.to_dict() elif isinstance(port, str): port_id = self.port_id(node, port) port = { "id": port_id, "labels": [{ "text": port, "id": f"{port}_label_{i}", }], } if isinstance(port, dict): elkport = ElkPort.from_dict(port) if elkport.width is None: elkport.width = self.port_scale if elkport.height is None: elkport.height = self.port_scale ports[elkport.id] = Port(node=node, elkport=elkport) return ports async def make_labels(self, node: Hashable) -> Optional[List[ElkLabel]]: labels = [] g = 
self.source[0] data = g.nodes.get(node, {}) values = data.get(self.label_key, [data.get("_id", f"{node}")]) properties = {} css_classes = self.get_css(node, ElkLabel) if css_classes: properties["cssClasses"] = " ".join(css_classes) if isinstance(values, (str, ElkLabel)): values = [values] # get node labels for i, label in enumerate(values): if isinstance(label, str): label = ElkLabel( id=f"{label}_label_{i}_{node}", text=label, ) elif isinstance(label, ElkLabel): # prevent mutating original label in the node data label = label.to_dict() if isinstance(label, dict): label = ElkLabel.from_dict(label) # add css classes and layout options label.layoutOptions = merge(label.layoutOptions, self.get_layout(node, ElkLabel)) merged_props = merge(label.properties, properties) if merged_props is not None: merged_props = ElkProperties.from_dict(merged_props) label.properties = merged_props labels.append(label) self.register(label, node) return labels def collect_edges(self) -> Tuple[EdgeMap, EdgeMap]: """[summary] :return: [description] :rtype: Tuple[ Dict[Hashable, List[ElkExtendedEdge]], Dict[Hashable, List[ElkExtendedEdge]] ] """ visible: EdgeMap = defaultdict( list) # will index edges by nx.lowest_common_ancestor hidden: EdgeMap = defaultdict( list) # will index edges by nx.lowest_common_ancestor g, tree = self.source attr = self.HIDDEN_ATTR hidden: List[Tuple[List, List]] = [] visible: EdgeMap = defaultdict( list) # will index edges by nx.lowest_common_ancestor hidden: EdgeMap = defaultdict( list) # will index edges by nx.lowest_common_ancestor closest_visible = map_visible(g, tree, attr) @lru_cache() def closest_common_visible(nodes: Tuple[Hashable]) -> Hashable: if tree is None: return ElkRoot result = lowest_common_ancestor(tree, nodes) return result for source, target, edge_data in g.edges(data=True): source_port, target_port = get_ports(edge_data) vis_source = closest_visible[source] vis_target = closest_visible[target] shidden = vis_source != source thidden = vis_target != target owner = closest_common_visible((vis_source, vis_target)) if source == target and source == owner: if owner in tree: for p in tree.predecessors(owner): # need to make this edge's owner to it's parent owner = p if shidden or thidden: # create new slack ports if source or target is remapped if vis_source != source: source_port = (source, source_port) if vis_target != target: target_port = (target, target_port) if vis_source != vis_target: hidden[owner].append( Edge( source=vis_source, source_port=source_port, target=vis_target, target_port=target_port, data=edge_data, owner=owner, )) else: visible[owner].append( Edge( source=source, source_port=source_port, target=target, target_port=target_port, data=edge_data, owner=owner, )) return visible, hidden
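# A hedged usage sketch for XELK, assuming ipyelk's async transform contract
# (await inside a notebook cell or an event loop):
async def _example_xelk_transform():
    import networkx as nx

    g = nx.Graph()
    g.add_edge("a", "b")
    xelk = XELK(source=(g, None))   # no hierarchy tree
    root = await xelk.transform()   # ElkNode ready for rendering
    return root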
class Compositor(Node): """Compositor Attributes ---------- cache_native_coordinates : Bool Default is True. If native_coordinates are requested by the user, it may take a long time to calculate if the Compositor points to many sources. The result is relatively small and is cached by default. Caching may not be desired if the data source changes or is updated. interpolation : str, dict, optional {interpolation} is_source_coordinates_complete : Bool Default is False. The source_coordinates do not have to completely describe the source. For example, the source coordinates could include the year-month-day of the source, but the actual source also has hour-minute-second information. In that case, source_coordinates is incomplete. This flag is used to automatically construct native_coordinates. n_threads : int Default is 10 -- used when threaded is True. NASA data servers seem to have a hard limit of 10 simultaneous requests, which determined the default value. shared_coordinates : :class:`podpac.Coordinates`, optional Coordinates that are shared amongst all of the composited sources source : str The source is used for a unique name to cache composited products. source_coordinates : :class:`podpac.Coordinates`, optional Coordinates that make each source unique. This is used for subsetting which sources to evaluate based on the user-requested coordinates. It is an optimization. sources : :class:`np.ndarray` An array of sources. This is a numpy array as opposed to a list so that boolean indexing may be used to subselect the nodes that will be evaluated. threaded : bool, optional Default is False. When threaded is False, the compositor stops evaluating sources once the output is completely filled. When threaded is True, the compositor must evaluate every source. The result is the same, but note that because of this, threaded=False could be faster than threaded=True, especially if n_threads is low. For example, threaded with n_threads=1 could be much slower than non-threaded if the output is completely filled after the first few sources. Notes ----- Developers of new Compositor nodes need to implement the `composite` method. """ shared_coordinates = tl.Instance(Coordinates, allow_none=True) source_coordinates = tl.Instance(Coordinates, allow_none=True) is_source_coordinates_complete = tl.Bool( False, help=("This allows some optimizations but assumes that a node's " "native_coordinates=source_coordinate + shared_coordinate " "IN THAT ORDER")) source = tl.Unicode().tag(attr=True) sources = ArrayTrait(ndim=1) cache_native_coordinates = tl.Bool(True) interpolation = interpolation_trait(default_value=None) threaded = tl.Bool(False) n_threads = tl.Int(10) @tl.default('source') def _source_default(self): source = [] for s in self.sources[:3]: source.append(str(s)) return '_'.join(source) @tl.default('source_coordinates') def _source_coordinates_default(self): return self.get_source_coordinates() def get_source_coordinates(self): """ Returns the coordinates describing each source. This may be implemented by derived classes, and is an optimization that allows evaluating subsets of the sources. Returns ------- :class:`podpac.Coordinates` Coordinates describing each source. """ return None @tl.default('shared_coordinates') def _shared_coordinates_default(self): return self.get_shared_coordinates() def get_shared_coordinates(self): """Coordinates shared by each source.
Raises ------ NotImplementedError Description """ raise NotImplementedError() def select_sources(self, coordinates): """Downselect compositor sources based on requested coordinates. This is used during the :meth:`eval` process as an optimization when :attr:`source_coordinates` are not pre-defined. Parameters ---------- coordinates : :class:`podpac.Coordinates` Coordinates to evaluate at compositor sources Returns ------- :class:`np.ndarray` Array of downselected sources """ # if source coordinates are defined, use intersect if self.source_coordinates is not None: # intersecting sources only try: _, I = self.source_coordinates.intersect(coordinates, outer=True, return_indices=True) except: # Likely non-monotonic coordinates _, I = self.source_coordinates.intersect(coordinates, outer=False, return_indices=True) src_subset = self.sources[I] # no downselection possible - get all sources compositor else: src_subset = self.sources return src_subset def composite(self, outputs, result=None): """Implements the rules for compositing multiple sources together. Parameters ---------- outputs : list A list of outputs that need to be composited together result : UnitDataArray, optional An optional pre-filled array may be supplied, otherwise the output will be allocated. Raises ------ NotImplementedError """ raise NotImplementedError() def iteroutputs(self, coordinates): """Summary Parameters ---------- coordinates : :class:`podpac.Coordinates` Coordinates to evaluate at compositor sources Yields ------ :class:`podpac.core.units.UnitsDataArray` Output from source node eval method """ # downselect sources based on coordinates src_subset = self.select_sources(coordinates) if len(src_subset) == 0: yield self.create_output_array(coordinates) return # Set the interpolation properties for sources if self.interpolation is not None: for s in src_subset.ravel(): if trait_is_defined(self, 'interpolation'): s.interpolation = self.interpolation # Optimization: if coordinates complete and source coords is 1D, # set native_coordinates unless they are set already # WARNING: this assumes # native_coords = source_coords + shared_coordinates # NOT native_coords = shared_coords + source_coords if self.is_source_coordinates_complete and self.source_coordinates.ndim == 1: coords_subset = list( self.source_coordinates.intersect( coordinates, outer=True).coords.values())[0] coords_dim = list(self.source_coordinates.dims)[0] for s, c in zip(src_subset, coords_subset): nc = merge_dims([ Coordinates(np.atleast_1d(c), dims=[coords_dim]), self.shared_coordinates ]) if trait_is_defined(s, 'native_coordinates') is False: s.native_coordinates = nc if self.threaded: # TODO pool of pre-allocated scratch space # TODO: docstring? def f(src): return src.eval(coordinates) pool = ThreadPool(processes=self.n_threads) results = [pool.apply_async(f, [src]) for src in src_subset] for src, res in zip(src_subset, results): yield res.get() #src._output = None # free up memory else: output = None # scratch space for src in src_subset: output = src.eval(coordinates, output) yield output #output[:] = np.nan @node_eval @common_doc(COMMON_COMPOSITOR_DOC) def eval(self, coordinates, output=None): """Evaluates this nodes using the supplied coordinates. 
Parameters ---------- coordinates : :class:`podpac.Coordinates` {requested_coordinates} output : podpac.UnitsDataArray, optional {eval_output} Returns ------- {eval_return} """ self._requested_coordinates = coordinates outputs = self.iteroutputs(coordinates) output = self.composite(outputs, output) return output def find_coordinates(self): """ Get the available native coordinates for the Node. Returns ------- coords_list : list list of available coordinates (Coordinate objects) """ raise NotImplementedError("TODO") @property @common_doc(COMMON_COMPOSITOR_DOC) def base_definition(self): """Base node defintion for Compositor nodes. Returns ------- {definition_return} """ d = super(Compositor, self).base_definition d['sources'] = self.sources d['interpolation'] = self.interpolation return d
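# A hedged sketch of a minimal Compositor subclass: developers only need to
# implement `composite`. This first-valid mosaic is illustrative, not podpac's
# built-in OrderedCompositor.
class FirstValidCompositor(Compositor):
    def composite(self, outputs, result=None):
        import numpy as np

        for output in outputs:
            if result is None:
                result = output.copy()
            else:
                empty = np.isnan(result.data)
                result.data[empty] = output.data[empty]
            if not np.isnan(result.data).any():
                break   # output completely filled; stop early
        return result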
class Git(Repo): """A git repository widget""" # pylint: disable=protected-access,too-many-instance-attributes _git = T.Instance(G.Repo, allow_none=True) _ref_watcher = T.Instance(Watcher, allow_none=True) def _initialize_watcher(self): """watch key folders in git""" self._watcher = Watcher(Path(self._git.git_dir) / "refs", _watcher_cls=_GitRefWatcher) def _schedule(change=None): IOLoop.current().add_callback(self._on_ref_change, change) self._watcher.observe(_schedule, "changes") _schedule() async def _on_ref_change(self, _change=None): """recalculate key values when files in .git/refs folder change""" self._update_heads() self._update_head_history() for remote in self.remotes.values(): await remote._update_heads() @property def _remote_cls(self): return GitRemote @T.observe("working_dir") def _on_path(self, change): """handle when the working directory changes""" if change.new: self._git = G.Repo.init(change.new) ignore = Path(change.new) / ".gitignore" if not ignore.exists(): ignore.write_text(".ipynb_checkpoints/") self.commit("initial commit") self._initialize_watcher() self._update_head_history() @T.default("head") def _default_head(self): """get current head""" return self._git.active_branch.name @T.observe("head") def _on_head_changed(self, change): """react to the symbolic head name changing""" if change.new: self._update_head_history() def _update_head_history(self): """build a structure of history""" # pylint: disable=broad-except try: head = [h for h in self._git.heads if h.name == self.head][0] self.head_hash = head.commit.hexsha self.head_history = [{ "commit": str(c.newhexsha), "timestamp": c.time[0], "message": c.message, "author": { "name": c.actor.name, "email": c.actor.email }, } for c in head.log()[::-1]] except Exception as err: self.log.warn("Git head update error, ignoring: %s", err, exc_info=True) self.head_history = [] def _update_heads(self): """refresh the heads""" self.heads = { head.name: head.commit.hexsha for head in self._git.heads } self._update_head_history() def _on_watch_changes(self, *changes): """overload of the base method to handle changes""" self.dirty = self._git.is_dirty() if self._watcher: for change in self._watcher.changes: for tracker in self._trackers: tracked_path = Path(self._git.working_dir) / change["path"] if tracker.path.resolve() == tracked_path.resolve(): tracker._on_file_change(None) return [ dict(a_path=diff.a_path, b_path=diff.b_path, change_type=diff.change_type) for diff in self._git.index.diff(None) ] + [ dict(a_path=None, b_path=ut, change_type="U") for ut in self._git.untracked_files ] def stage(self, path): """stage a single path to the index""" self._git.index.add(path) def unstage(self, path): """remove a path from the index""" self._git.index.remove(path) def commit(self, message): """create a commit""" self._git.index.commit(message) self._on_watching(None) def revert(self, ref): """restore to a committish""" self._git.head.commit = ref self._git.head.reset(index=True, working_tree=True) def branch(self, name, ref="HEAD"): """create and checkout a new branch""" self._git.create_head(name, ref) self.checkout(name) def checkout(self, name): """checkout a named reference""" head = [h for h in self._git.heads if h.name == name][0] head.checkout() self.head = head.name self._git.head.reset(index=True, working_tree=True) def merge(self, ref): """create a merge commit on the active branch with the given ref""" active = self._git.active_branch active_commit = self._git.active_branch.commit active_name = active.name merge_base = 
self._git.merge_base(active, ref)
        ref_commit = self._git.commit(ref)
        self._git.index.merge_tree(ref_commit, base=merge_base)
        merge_commit = self._git.index.commit(
            f"Merged {ref} into {active_name}",
            parent_commits=(active_commit, ref_commit),
        )
        self.log.debug("merge commit %s", merge_commit)
        self._git.active_branch.reference = merge_commit
        active.checkout()
        self._git.head.reset(index=True, working_tree=True)

    def _update_remotes(self):
        """rebuild the remotes mapping from the underlying repo"""
        remotes = {}
        for remote in self._git.remotes:
            remotes[remote.name] = self._remote_cls(name=remote.name, url=remote.url, _remote=remote)
        self.remotes = remotes
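# Taken together, these methods form a small porcelain over GitPython. A
# hypothetical driving session (assuming `working_dir` is the trait on the
# Repo base class that the _on_path observer above reacts to):
from pathlib import Path

repo = Git()
repo.working_dir = "/tmp/demo-repo"  # triggers init, .gitignore, initial commit

(Path("/tmp/demo-repo") / "notes.txt").write_text("hello")
repo.stage("notes.txt")    # git add notes.txt
repo.commit("add notes")   # commit whatever is staged

repo.branch("feature")     # create and check out 'feature' from HEAD
repo.checkout("master")    # back to the original branch
repo.merge("feature")      # merge commit onto the active branch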
class WorkflowsLayer(ipyleaflet.TileLayer): """ Subclass of ``ipyleaflet.TileLayer`` for displaying a Workflows `~.geospatial.Image`. Attributes ---------- image: ~.geospatial.Image The `~.geospatial.Image` to use parameters: ParameterSet Parameters to use while computing; modify attributes under ``.parameters`` (like ``layer.parameters.foo = "bar"``) to cause the layer to recompute and update under those new parameters. xyz_obj: ~.models.XYZ Read-only: The `XYZ` object this layer is displaying. session_id: str Read-only: Unique ID that error logs will be stored under, generated automatically. checkerboard: bool, default True Whether to display a checkerboarded background for missing or masked data. colormap: str, optional, default None Name of the colormap to use. If set, `image` must have 1 band. r_min: float, optional, default None Min value for scaling the red band. Along with r_max, controls scaling when a colormap is enabled. r_max: float, optional, default None Max value for scaling the red band. Along with r_min, controls scaling when a colormap is enabled. g_min: float, optional, default None Min value for scaling the green band. g_max: float, optional, default None Max value for scaling the green band. b_min: float, optional, default None Min value for scaling the blue band. b_max: float, optional, default None Max value for scaling the blue band. error_output: ipywidgets.Output, optional, default None If set, write unique errors from tiles computation to this output area from a background thread. Setting to None stops the listener thread. Example ------- >>> import descarteslabs.workflows as wf >>> wf.map # doctest: +SKIP >>> # ^ display interactive map >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1").pick_bands("red") >>> masked_img = img.mask(img > wf.parameter("threshold", wf.Float)) >>> layer = masked_img.visualize("sample", colormap="viridis", threshold=0.07) # doctest: +SKIP >>> layer.colormap = "plasma" # doctest: +SKIP >>> # ^ change colormap (this will update the layer on the map) >>> layer.parameters.threshold = 0.13 # doctest: +SKIP >>> # ^ adjust parameters (this also updates the layer) >>> layer.set_scales((0.01, 0.3)) # doctest: +SKIP >>> # ^ adjust scaling (this also updates the layer) """ attribution = traitlets.Unicode("Descartes Labs").tag(sync=True, o=True) min_zoom = traitlets.Int(5).tag(sync=True, o=True) url = traitlets.Unicode(read_only=True).tag(sync=True) image = traitlets.Instance(Image) parameters = traitlets.Instance(parameters.ParameterSet, allow_none=True) xyz_obj = traitlets.Instance(XYZ, read_only=True) session_id = traitlets.Unicode(read_only=True) checkerboard = traitlets.Bool(True) colormap = traitlets.Unicode(None, allow_none=True) r_min = ScaleFloat(None, allow_none=True) r_max = ScaleFloat(None, allow_none=True) g_min = ScaleFloat(None, allow_none=True) g_max = ScaleFloat(None, allow_none=True) b_min = ScaleFloat(None, allow_none=True) b_max = ScaleFloat(None, allow_none=True) error_output = traitlets.Instance(widgets.Output, allow_none=True) autoscale_progress = traitlets.Instance(ClearableOutput) def __init__(self, image, *args, **kwargs): params = kwargs.pop("parameters", {}) super(WorkflowsLayer, self).__init__(*args, **kwargs) with self.hold_trait_notifications(): self.image = image self.set_trait("session_id", uuid.uuid4().hex) self.set_trait( "autoscale_progress", ClearableOutput( widgets.Output(), layout=widgets.Layout(max_height="10rem", flex="1 0 auto"), ), ) self.set_parameters(**params) 
self._error_listener = None self._known_errors = set() self._known_errors_lock = threading.Lock() def make_url(self): """ Generate the URL for this layer. This is called automatically as the attributes (`image`, `colormap`, scales, etc.) are changed. Example ------- >>> import descarteslabs.workflows as wf >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP >>> img = img.pick_bands("red blue green") # doctest: +SKIP >>> layer = img.visualize("sample") # doctest: +SKIP >>> layer.make_url() # doctest: +SKIP 'https://workflows.descarteslabs.com/master/xyz/9ec70d0e99db7f50c856c774809ae454ffd8475816e05c5c/{z}/{x}/{y}.png?session_id=xxx&checkerboard=true' """ if not self.visible: # workaround for the fact that Leaflet still loads tiles from inactive layers, # which is expensive computation users don't want return "" if self.colormap is not None: scales = [[self.r_min, self.r_max]] else: scales = [ [self.r_min, self.r_max], [self.g_min, self.g_max], [self.b_min, self.b_max], ] scales = [scale for scale in scales if scale != [None, None]] parameters = self.parameters.to_dict() return self.xyz_obj.url(session_id=self.session_id, colormap=self.colormap, scales=scales, checkerboard=self.checkerboard, **parameters) @traitlets.observe("image") def _update_xyz(self, change): old, new = change["old"], change["new"] if old is new: # traitlets does an == check between the old and new value to decide if it's changed, # which for an Image, returns another Image, which it considers changed. return xyz = XYZ.build(new, name=self.name) xyz.save() self.set_trait("xyz_obj", xyz) @traitlets.observe( "visible", "checkerboard", "colormap", "r_min", "r_max", "g_min", "g_max", "b_min", "b_max", "xyz_obj", "session_id", "parameters", ) @traitlets.observe("parameters", type="delete") def _update_url(self, change): try: self.set_trait("url", self.make_url()) except ValueError as e: if "Invalid scales passed" not in str(e): raise e @traitlets.observe("parameters", type="delete") def _update_url_on_param_delete(self, change): # traitlets is dumb and decorator stacking doesn't work so we have to repeat this try: self.set_trait("url", self.make_url()) except ValueError as e: if "Invalid scales passed" not in str(e): raise e @traitlets.observe("xyz_obj", "session_id") def _update_error_logger(self, change): if self.error_output is None: return # Remove old errors for the layer self.forget_errors() new_errors = [] for error in self.error_output.outputs: if not error["text"].startswith(self.name + ": "): new_errors.append(error) self.error_output.outputs = tuple(new_errors) if self._error_listener is not None: self._error_listener.stop(timeout=1) listener = self.xyz_obj.error_listener() listener.add_callback(self._log_errors_callback) listener.listen(self.session_id, datetime.datetime.now(datetime.timezone.utc)) self._error_listener = listener def _stop_error_logger(self): if self._error_listener is not None: self._error_listener.stop(timeout=1) self._error_listener = None @traitlets.observe("error_output") def _toggle_error_listener_if_output(self, change): if change["new"] is None: self._stop_error_logger() else: if self._error_listener is None: self._update_error_logger({}) def _log_errors_callback(self, msg): message = msg.message with self._known_errors_lock: if message in self._known_errors: return else: self._known_errors.add(message) error = "{}: {}\n".format(self.name, message) self.error_output.append_stdout(error) def __del__(self): self._stop_error_logger() 
super(WorkflowsLayer, self).__del__() def forget_errors(self): """ Clear the set of known errors, so they are re-displayed if they occur again Example ------- >>> import descarteslabs.workflows as wf >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP >>> wf.map # doctest: +SKIP >>> layer = img.visualize("sample visualization") # doctest: +SKIP >>> # ^ will show an error for attempting to visualize more than 3 bands >>> layer.forget_errors() # doctest: +SKIP >>> wf.map.zoom = 10 # doctest: +SKIP >>> # ^ attempting to load more tiles from img will cause the same error to appear """ with self._known_errors_lock: self._known_errors.clear() def set_scales(self, scales, new_colormap=False): """ Update the scales for this layer by giving a list of scales Parameters ---------- scales: list of lists, default None The scaling to apply to each band in the `Image`. If `Image` contains 3 bands, ``scales`` must be a list like ``[(0, 1), (0, 1), (-1, 1)]``. If `Image` contains 1 band, ``scales`` must be a list like ``[(0, 1)]``, or just ``(0, 1)`` for convenience If None, each 256x256 tile will be scaled independently based on the min and max values of its data. new_colormap: str, None, or False, optional, default False A new colormap to set at the same time, or False to use the current colormap. Example ------- >>> import descarteslabs.workflows as wf >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP >>> img = img.pick_bands("red") # doctest: +SKIP >>> layer = img.visualize("sample visualization", colormap="viridis") # doctest: +SKIP >>> layer.set_scales((0.08, 0.3), new_colormap="plasma") # doctest: +SKIP >>> # ^ optionally set new colormap """ colormap = self.colormap if new_colormap is False else new_colormap if scales is not None: scales = XYZ._validate_scales(scales) scales_len = 1 if colormap is not None else 3 if len(scales) != scales_len: msg = "Expected {} scales, but got {}.".format( scales_len, len(scales)) if len(scales) in (1, 2): msg += " If displaying a 1-band Image, use a colormap." elif colormap: msg += " Colormaps cannot be used with multi-band images." raise ValueError(msg) with self.hold_trait_notifications(): if colormap is None: self.r_min = scales[0][0] self.r_max = scales[0][1] self.g_min = scales[1][0] self.g_max = scales[1][1] self.b_min = scales[2][0] self.b_max = scales[2][1] else: self.r_min = scales[0][0] self.r_max = scales[0][1] if new_colormap is not False: self.colormap = new_colormap else: # scales is None with self.hold_trait_notifications(): if colormap is None: self.r_min = None self.r_max = None self.g_min = None self.g_max = None self.b_min = None self.b_max = None else: self.r_min = None self.r_max = None if new_colormap is not False: self.colormap = new_colormap def set_parameters(self, **params): """ Set new parameters for this `WorkflowsLayer`. In typical cases, you update parameters by assigning to `parameters` (like ``layer.parameters.threshold = 6.6``). Instead, use this function when you need to change the *names or types* of parameters available on the `WorkflowsLayer`. (Users shouldn't need to do this, as `~.Image.visualize` handles it for you, but custom widget developers may need to use this method when they change the `image` field on a `WorkflowsLayer`.) If a value is an ipywidgets Widget, it will be linked to that parameter (via its ``"value"`` attribute). 
If a parameter was previously set with a widget, and a different widget instance (or non-widget) is passed for its new value, the old widget is automatically unlinked. If the same widget instance is passed as is already linked, no change occurs. Parameters ---------- params: JSON-serializable value, Proxytype, or ipywidgets.Widget Paramter names to new values. Values can be Python types, `Proxytype` instances, or ``ipywidgets.Widget`` instances. Example ------- >>> import descarteslabs.workflows as wf >>> from ipywidgets import FloatSlider >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP >>> img = img.pick_bands("red") # doctest: +SKIP >>> masked_img = img.mask(img > wf.parameter("threshold", wf.Float)) # doctest: +SKIP >>> layer = masked_img.tile_layer("sample", colormap="plasma", threshold=0.07) # doctest: +SKIP >>> scaled_img = img * wf.parameter("scale", wf.Float) + wf.parameter("offset", wf.Float) # doctest: +SKIP >>> with layer.hold_trait_notifications(): # doctest: +SKIP ... layer.image = scaled_img # doctest: +SKIP ... layer.set_parameters(scale=FloatSlider(min=0, max=10, value=2), offset=2.5) # doctest: +SKIP >>> # ^ re-use the same layer instance for a new Image with different parameters """ param_set = self.parameters if param_set is None: param_set = self.parameters = parameters.ParameterSet( self, "parameters") with self.hold_trait_notifications(): param_set.update(**params) def _ipython_display_(self): param_set = self.parameters if param_set: widget = param_set.widget if widget and len(widget.children) > 0: widget._ipython_display_()
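# The widget linking that set_parameters() describes boils down to
# traitlets.link on the widget's `value` trait. A standalone illustration of
# that mechanism (not ParameterSet's actual code):
import traitlets
import ipywidgets

class Params(traitlets.HasTraits):
    threshold = traitlets.Float(0.07)

params = Params()
slider = ipywidgets.FloatSlider(min=0, max=1, value=params.threshold)
traitlets.link((slider, "value"), (params, "threshold"))

slider.value = 0.13
assert params.threshold == 0.13  # moving the slider updates the parameter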
class Figure(widgets.DOMWidget):
    """Widget class representing a volume (rendering) using three.js"""
    _view_name = Unicode('FigureView').tag(sync=True)
    _view_module = Unicode('ipyvolume').tag(sync=True)
    _model_name = Unicode('FigureModel').tag(sync=True)
    _model_module = Unicode('ipyvolume').tag(sync=True)
    _view_module_version = Unicode(semver_range_frontend).tag(sync=True)
    _model_module_version = Unicode(semver_range_frontend).tag(sync=True)

    volume_data = Array(default_value=None, allow_none=True).tag(sync=True, **array_cube_png_serialization)
    eye_separation = traitlets.CFloat(6.4).tag(sync=True)
    data_min = traitlets.CFloat().tag(sync=True)
    data_max = traitlets.CFloat().tag(sync=True)
    tf = traitlets.Instance(TransferFunction, allow_none=True).tag(sync=True, **ipywidgets.widget_serialization)
    anglex = traitlets.Float(0.0).tag(sync=True)
    angley = traitlets.Float(0.0).tag(sync=True)
    anglez = traitlets.Float(0.0).tag(sync=True)
    angle_order = Unicode(default_value="XYZ").tag(sync=True)
    scatters = traitlets.List(traitlets.Instance(Scatter), [], allow_none=False).tag(sync=True, **ipywidgets.widget_serialization)
    meshes = traitlets.List(traitlets.Instance(Mesh), [], allow_none=False).tag(sync=True, **ipywidgets.widget_serialization)
    animation = traitlets.Float(1000.0).tag(sync=True)
    animation_exponent = traitlets.Float(.5).tag(sync=True)
    ambient_coefficient = traitlets.Float(0.5).tag(sync=True)
    diffuse_coefficient = traitlets.Float(0.8).tag(sync=True)
    specular_coefficient = traitlets.Float(0.5).tag(sync=True)
    specular_exponent = traitlets.Float(5).tag(sync=True)
    stereo = traitlets.Bool(False).tag(sync=True)
    screen_capture_enabled = traitlets.Bool(False).tag(sync=True)
    screen_capture_mime_type = traitlets.Unicode(default_value='image/png').tag(sync=True)
    screen_capture_data = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True)
    fullscreen = traitlets.Bool(False).tag(sync=True)
    camera_control = traitlets.Unicode(default_value='trackball').tag(sync=True)
    width = traitlets.CInt(500).tag(sync=True)
    height = traitlets.CInt(400).tag(sync=True)
    downscale = traitlets.CInt(1).tag(sync=True)
    show = traitlets.Unicode("Volume").tag(sync=True)  # for debugging
    xlim = traitlets.List(traitlets.CFloat, default_value=[0, 1], minlen=2, maxlen=2).tag(sync=True)
    ylim = traitlets.List(traitlets.CFloat, default_value=[0, 1], minlen=2, maxlen=2).tag(sync=True)
    zlim = traitlets.List(traitlets.CFloat, default_value=[0, 1], minlen=2, maxlen=2).tag(sync=True)
    xlabel = traitlets.Unicode("x").tag(sync=True)
    ylabel = traitlets.Unicode("y").tag(sync=True)
    zlabel = traitlets.Unicode("z").tag(sync=True)
    style = traitlets.Dict(default_value=ipyvolume.style.default).tag(sync=True)

    def __init__(self, **kwargs):
        super(Figure, self).__init__(**kwargs)
        self._screenshot_handlers = widgets.CallbackDispatcher()
        self.on_msg(self._handle_custom_msg)

    def screenshot(self):
        self.send({'msg': 'screenshot'})

    def on_screenshot(self, callback, remove=False):
        self._screenshot_handlers.register_callback(callback, remove=remove)

    def _handle_custom_msg(self, content, buffers):
        if content.get('event', '') == 'screenshot':
            self._screenshot_handlers(content['data'])
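# Round trip of the screenshot machinery defined above: register a handler,
# then ask the frontend for a capture. The handler receives whatever payload
# the frontend sends under content['data'].
fig = Figure()

def save_capture(data):
    # hypothetical handler; inspect or persist the capture payload here
    print("got screenshot payload of length", len(data))

fig.on_screenshot(save_capture)
fig.screenshot()  # sends {'msg': 'screenshot'} to the frontend widget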
class GridCalculator(_HasState): '''A grid is responsible for scheduling the grid calculations and possible slicing''' class Status(enum.Enum): VOID = 1 STAGED_CALCULATION = 3 CALCULATING = 4 READY = 9 status = traitlets.UseEnum(Status, Status.VOID) df = traitlets.Instance(vaex.dataframe.DataFrame) models = traitlets.List(traitlets.Instance(DataArray)) calculation = traitlets.Any(None, allow_none=True) _debug = traitlets.Bool(False) def __init__(self, df, models): super().__init__(df=df, models=[]) self._callbacks_regrid = [] self._callbacks_slice = [] for model in models: self.model_add(model) self._testing_exeception_regrid = False # used for testing, to throw an exception self._testing_exeception_reslice = False # used for testing, to throw an exception # def model_remove(self, model, regrid=True): # index = self.models.index(model) # del self.models[index] # del self._callbacks_regrid[index] # del self._callbacks_slice[index] def model_add(self, model): self.models = self.models + [model] if model.status == DataArray.Status.NEEDS_CALCULATING_GRID: if self.calculation is not None: self._cancel_computation() self.computation() def on_status_changed(change): if change.owner.status == DataArray.Status.NEEDS_CALCULATING_GRID: if self.calculation is not None: self._cancel_computation() self.computation() model.observe(on_status_changed, 'status') # TODO: if we listen to the same axis twice it will trigger twice for axis in model.axes: axis.observe(lambda change: self.reslice(), 'slice') # self._callbacks_regrid.append(model.signal_regrid.connect(self.on_regrid)) # self._callbacks_slice.append(model.signal_slice.connect(self.reslice)) assert model.df == self.df # @vaex.jupyter.debounced(delay_seconds=0.05, reentrant=False) # def reslice_debounced(self): # self.reslice() def reslice(self, source_model=None): if self._testing_exeception_reslice: raise RuntimeError("test:reslice") coords = [] selections = self.models[0].selections selections = [ k for k in selections if k is None or self.df.has_selection(k) ] for model in self.models: subgrid = self.grid subgrid_sliced = self.grid axis_index = 1 has_slice = False dims = ["selection"] coords = [selections.copy()] mins = [] maxs = [] for other_model in self.models: if other_model == model: # simply skip these axes # for expression, shape, limit, slice_index in other_model.bin_parameters(): for axis in other_model.axes: axis_index += 1 dims.append(str(axis.expression)) coords.append(axis.centers) mins.append(axis.min) maxs.append(axis.max) else: # for expression, shape, limit, slice_index in other_model.bin_parameters(): for axis in other_model.axes: if axis.slice is not None: subgrid_sliced = subgrid_sliced.__getitem__( tuple([slice(None)] * axis_index + [axis.slice])).copy() subgrid = np.sum(subgrid, axis=axis_index) has_slice = True else: subgrid_sliced = np.sum(subgrid_sliced, axis=axis_index) subgrid = np.sum(subgrid, axis=axis_index) grid = xarray.DataArray(subgrid, dims=dims, coords=coords) for i, (vmin, vmax) in enumerate(zip(mins, maxs)): # +1 to skip the selection axis grid.coords[dims[i + 1]].attrs['min'] = vmin grid.coords[dims[i + 1]].attrs['max'] = vmax model.grid = grid if has_slice: model.grid_sliced = xarray.DataArray(subgrid_sliced) else: model.grid_sliced = None def _regrid_error(self, e): try: self._error(e) for model in self.models: model._error(e) for model in self.models: model.exception = e model.status = vaex.jupyter.model.DataArray.Status.EXCEPTION except Exception as e2: print(e2) def on_regrid(self, ignore=None): 
self.regrid() @vaex.jupyter.debounced(delay_seconds=0.5, reentrant=False, on_error=_regrid_error) async def computation(self): try: logger.debug('Starting grid computation') # vaex.utils.print_stack_trace() if self._testing_exeception_regrid: raise RuntimeError("test:regrid") if not self.models: return binby = [] shapes = [] limits = [] selections = self.models[0].selections for model in self.models: if model.selections != selections: raise ValueError( 'Selections for all models should be the same') for axis in model.axes: binby.append(axis.expression) limits.append([axis.min, axis.max]) shapes.append(axis.shape or axis.shape_default) selections = [ k for k in selections if k is None or self.df.has_selection(k) ] self._continue_calculation = True logger.debug('Setting up grid computation...') self.calculation = self.df.count(binby=binby, shape=shapes, limits=limits, selection=selections, progress=self.progress, delay=True) logger.debug('Setting up grid computation done tasks=%r', self.df.executor.tasks) logger.debug('Schedule debounced execute') self.df.widget.execute_debounced() # keep a nearly reference to this, since awaits (which trigger the execution, AND reset of this future) may change it this execute_prehook_future = self.df.widget.execute_debounced.pre_hook_future async with contextlib.AsyncExitStack() as stack: for model in self.models: await stack.enter_async_context( model._state_change_to( DataArray.Status.STAGED_CALCULATING_GRID)) async with contextlib.AsyncExitStack() as stack: for model in self.models: await stack.enter_async_context( model._state_change_to( DataArray.Status.CALCULATING_GRID)) await execute_prehook_future async with contextlib.AsyncExitStack() as stack: for model in self.models: await stack.enter_async_context( model._state_change_to( DataArray.Status.CALCULATED_GRID)) # first assign to local grid = await self.calculation # indicate we are done with the calculation self.calculation = None # raise asyncio.CancelledError("User abort") async with contextlib.AsyncExitStack() as stack: for model in self.models: await stack.enter_async_context( model._state_change_to(DataArray.Status.READY)) self.grid = grid self.reslice() except vaex.execution.UserAbort: pass # a user changed the limits or expressions except asyncio.CancelledError: pass # cancelled... def _cancel_computation(self): logger.debug('Cancelling grid computation') self._continue_calculation = False def progress(self, f): return self._continue_calculation and all( [model.on_progress_grid(f) for model in self.models])
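# The @vaex.jupyter.debounced decorator coalesces bursts of recompute requests
# into a single delayed run. A toy sketch of that idea -- not vaex's
# implementation, which additionally handles reentrancy and error hooks:
import asyncio

def debounced(delay_seconds):
    def wrapper(coro_fn):
        task = None

        async def delayed(*args, **kwargs):
            await asyncio.sleep(delay_seconds)
            await coro_fn(*args, **kwargs)

        def schedule(*args, **kwargs):
            nonlocal task
            if task is not None and not task.done():
                task.cancel()  # a newer request supersedes the pending one
            task = asyncio.ensure_future(delayed(*args, **kwargs))
            return task

        return schedule
    return wrapper

@debounced(0.5)
async def recompute():
    print("recomputing once, after the burst settles")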
class StructureViewer(ipw.VBox): """NGL structure viewer including download button""" structure = traitlets.Instance(Structure, allow_none=True) def __init__(self, **kwargs): self._current_view = None self.viewer = nglview.NGLWidget() self.viewer.camera = "orthographic" self.viewer.stage.set_parameters(mouse_preset="pymol") self.viewer_box = ipw.Box( children=(self.viewer, ), layout={ "width": "auto", "height": "auto", "border": "solid 0.5px darkgrey", "margin": "0px", "padding": "0.5px", }, ) self.download = DownloadChooser(**kwargs) super().__init__( children=(self.viewer_box, self.download), layout={ "width": "auto", "height": "auto", "margin": "0px 0px 0px 0px", "padding": "0px 0px 10px 0px", }, ) self.observe(self._on_change_structure, names="structure") traitlets.dlink((self, "structure"), (self.download, "structure")) def _on_change_structure(self, change): """Update viewer for new structure""" self.reset() self._current_view = self.viewer.add_structure( nglview.TextStructure(change["new"].as_pdb)) self.viewer.add_representation("ball+stick", aspectRatio=4) self.viewer.add_representation("unitcell") def freeze(self): """Disable widget""" self.download.freeze() def unfreeze(self): """Activate widget (in its current state)""" self.download.unfreeze() def reset(self): """Reset widget""" self.download.reset() if self._current_view is not None: self.viewer.clear() self.viewer.remove_component(self._current_view) self._current_view = None
class Axis(_HasState): class Status(enum.Enum): """ State transitions NO_LIMITS -> STAGED_CALCULATING_LIMITS -> CALCULATING_LIMITS -> CALCULATED_LIMITS -> READY when expression changes: STAGED_CALCULATING_LIMITS: calculation.cancel() ->NO_LIMITS CALCULATING_LIMITS: calculation.cancel() ->NO_LIMITS when min/max changes: STAGED_CALCULATING_LIMITS: calculation.cancel() ->NO_LIMITS CALCULATING_LIMITS: calculation.cancel() ->NO_LIMITS """ NO_LIMITS = 1 STAGED_CALCULATING_LIMITS = 2 CALCULATING_LIMITS = 3 CALCULATED_LIMITS = 4 READY = 5 EXCEPTION = 6 ABORTED = 7 status = traitlets.UseEnum(Status, Status.NO_LIMITS) df = traitlets.Instance(vaex.dataframe.DataFrame) expression = Expression() slice = traitlets.CInt(None, allow_none=True) min = traitlets.CFloat(None, allow_none=True) max = traitlets.CFloat(None, allow_none=True) centers = traitlets.Any() shape = traitlets.CInt(None, allow_none=True) shape_default = traitlets.CInt(64) calculation = traitlets.Any(None, allow_none=True) exception = traitlets.Any(None, allow_none=True) _status_change_delay = traitlets.Float(0) def __init__(self, **kwargs): super().__init__(**kwargs) if self.min is not None and self.max is not None: self.status = Axis.Status.READY self._calculate_centers() else: self.computation() self.observe(self.on_change_expression, 'expression') self.observe(self.on_change_shape, 'shape') self.observe(self.on_change_shape_default, 'shape_default') def __repr__(self): def myrepr(value, key): if isinstance(value, vaex.expression.Expression): return str(value) return value args = ', '.join('{}={}'.format(key, myrepr(getattr(self, key), key)) for key in self.traits().keys() if key != 'df') return '{}({})'.format(self.__class__.__name__, args) @property def has_missing_limit(self): # return not self.df.is_category(self.expression) and (self.min is None or self.max is None) return (self.min is None or self.max is None) def on_change_expression(self, change): self.min = None self.max = None self.status = Axis.Status.NO_LIMITS if self.calculation is not None: self._cancel_computation() self.computation() def on_change_shape(self, change): if self.min is not None and self.max is not None: self._calculate_centers() def on_change_shape_default(self, change): if self.min is not None and self.max is not None: self._calculate_centers() def _cancel_computation(self): self._continue_calculation = False @traitlets.observe('min', 'max') def on_change_limits(self, change): if self.min is not None and self.max is not None: self._calculate_centers() if self.status == Axis.Status.NO_LIMITS: if self.min is not None and self.max is not None: self.status = Axis.Status.READY elif self.status == Axis.Status.READY: if self.min is None or self.max is None: self.status = Axis.Status.NO_LIMITS else: # in this case, grids may want to be computed # this happens when a user change min/max pass else: if self.calculation is not None: self._cancel_computation() if self.min is not None and self.max is not None: self.status = Axis.Status.READY else: self.status = Axis.Status.NO_LIMITS else: # in this case we've set min/max after the calculation assert self.min is not None or self.max is not None @vaex.jupyter.debounced(delay_seconds=0.1, reentrant=False, on_error=_HasState._error) async def computation(self): categorical = self.df.is_category(self.expression) if categorical: N = self.df.category_count(self.expression) self.min, self.max = -0.5, N - 0.5 # centers = np.arange(N) # self.shape = N self._calculate_centers() self.status = Axis.Status.READY else: try: 
                self._continue_calculation = True
                self.calculation = self.df.minmax(self.expression, delay=True, progress=self._progress)
                self.df.widget.execute_debounced()
                # keep a local reference to this, since awaits (which trigger
                # the execution AND reset this future) may change it
                execute_prehook_future = self.df.widget.execute_debounced.pre_hook_future
                async with self._state_change_to(Axis.Status.STAGED_CALCULATING_LIMITS):
                    pass
                async with self._state_change_to(Axis.Status.CALCULATING_LIMITS):
                    await execute_prehook_future
                async with self._state_change_to(Axis.Status.CALCULATED_LIMITS):
                    vmin, vmax = await self.calculation
                    # indicate we are done with the calculation
                    self.calculation = None
                    if not self._continue_calculation:
                        assert self.status == Axis.Status.READY
                async with self._state_change_to(Axis.Status.READY):
                    self.min, self.max = vmin, vmax
                    self._calculate_centers()
            except vaex.execution.UserAbort:
                # expression or min/max change, we don't have to take action
                assert self.status in [Axis.Status.NO_LIMITS, Axis.Status.READY]
            except asyncio.CancelledError:
                pass

    def _progress(self, f):
        # we use the progress callback to cancel the calculation
        return self._continue_calculation

    def _calculate_centers(self):
        categorical = self.df.is_category(self.expression)
        if categorical:
            N = self.df.category_count(self.expression)
            centers = np.arange(N)
            self.shape = N
        else:
            centers = self.df.bin_centers(self.expression, [self.min, self.max], shape=self.shape or self.shape_default)
        self.centers = centers
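# _calculate_centers ultimately reduces to midpoints of equal-width bins over
# [min, max]. The arithmetic in isolation (an illustrative helper, not vaex's
# bin_centers):
import numpy as np

def bin_centers(vmin, vmax, shape):
    edges = np.linspace(vmin, vmax, shape + 1)
    return (edges[:-1] + edges[1:]) / 2

# bin_centers(0.0, 4.0, 4) -> array([0.5, 1.5, 2.5, 3.5])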
class ProviderImplementationChooser( # pylint: disable=too-many-instance-attributes ipw.VBox): """List all OPTIMADE providers and their implementations""" provider = traitlets.Instance(LinksResourceAttributes, allow_none=True) database = traitlets.Tuple( traitlets.Unicode(), traitlets.Instance(LinksResourceAttributes, allow_none=True), default_value=("", None), ) HINT = {"provider": "Select a provider", "child_dbs": "Select a database"} INITIAL_CHILD_DBS = [("", (("No provider chosen", None), ))] def __init__( self, child_db_limit: int = None, disable_providers: List[str] = None, skip_providers: List[str] = None, skip_databases: Dict[str, List[str]] = None, provider_database_groupings: Dict[str, Dict[str, List[str]]] = None, **kwargs, ): self.child_db_limit = (child_db_limit if child_db_limit and child_db_limit > 0 else 10) self.skip_child_dbs = skip_databases or {} self.child_db_groupings = (provider_database_groupings if provider_database_groupings is not None else PROVIDER_DATABASE_GROUPINGS) self.offset = 0 self.number = 1 self.__perform_query = True self.__cached_child_dbs = {} self.debug = bool(os.environ.get("OPTIMADE_CLIENT_DEBUG", None)) providers = [] providers, invalid_providers = get_list_of_valid_providers( disable_providers=disable_providers, skip_providers=skip_providers) providers.insert(0, (self.HINT["provider"], {})) if self.debug: from optimade_client.utils import VERSION_PARTS local_provider = LinksResourceAttributes( **{ "name": "Local server", "description": "Local server, running aiida-optimade", "base_url": f"http://localhost:5000{VERSION_PARTS[0][0]}", "homepage": "https://example.org", "link_type": "external", }) providers.insert(1, ("Local server", local_provider)) self.providers = DropdownExtended( options=providers, disabled_options=invalid_providers, layout=ipw.Layout(width="auto"), ) self.show_child_dbs = ipw.Layout(width="auto", display="none") self.child_dbs = DropdownExtended(grouping=self.INITIAL_CHILD_DBS, layout=self.show_child_dbs) self.page_chooser = ResultsPageChooser(page_limit=self.child_db_limit, layout=self.show_child_dbs) self.providers.observe(self._observe_providers, names="value") self.child_dbs.observe(self._observe_child_dbs, names="value") self.page_chooser.observe( self._get_more_child_dbs, names=["page_link", "page_offset", "page_number"]) self.error_or_status_messages = ipw.HTML("") super().__init__( children=( self.providers, self.child_dbs, self.page_chooser, self.error_or_status_messages, ), layout=ipw.Layout(width="auto"), **kwargs, ) def freeze(self): """Disable widget""" self.providers.disabled = True self.show_child_dbs.display = "none" self.page_chooser.freeze() def unfreeze(self): """Activate widget (in its current state)""" self.providers.disabled = False self.show_child_dbs.display = None self.page_chooser.unfreeze() def reset(self): """Reset widget""" self.page_chooser.reset() self.offset = 0 self.number = 1 self.providers.disabled = False self.providers.index = 0 self.show_child_dbs.display = "none" self.child_dbs.grouping = self.INITIAL_CHILD_DBS def _observe_providers(self, change: dict): """Update child database dropdown upon changing provider""" value = change["new"] self.show_child_dbs.display = "none" self.provider = value if value is None or not value: self.show_child_dbs.display = "none" self.child_dbs.grouping = self.INITIAL_CHILD_DBS self.providers.index = 0 self.child_dbs.index = 0 else: self._initialize_child_dbs() if sum([len(_[1]) for _ in self.child_dbs.grouping]) <= 2: # The provider either has 0 or 1 
implementations # or we have failed to retrieve any implementations. # Automatically choose the 1 implementation (if there), # while otherwise keeping the dropdown disabled. self.show_child_dbs.display = "none" try: self.child_dbs.index = 1 LOGGER.debug("Changed child_dbs index. New child_dbs: %s", self.child_dbs) except IndexError: pass else: self.show_child_dbs.display = None def _observe_child_dbs(self, change: dict): """Update database traitlet with base URL for chosen child database""" value = change["new"] if value is None or not value: self.database = "", None else: self.database = self.child_dbs.label.strip(), self.child_dbs.value @staticmethod def _remove_current_dropdown_option(dropdown: ipw.Dropdown) -> tuple: """Remove the current option from a Dropdown widget and return updated options Since Dropdown.options is a tuple there is a need to go through a list. """ list_of_options = list(dropdown.options) list_of_options.pop(dropdown.index) return tuple(list_of_options) def _initialize_child_dbs(self): """New provider chosen; initialize child DB dropdown""" self.offset = 0 self.number = 1 try: # Freeze and disable list of structures in dropdown widget # We don't want changes leading to weird things happening prior to the query ending self.freeze() # Reset the error or status message if self.error_or_status_messages.value: self.error_or_status_messages.value = "" if self.provider.base_url in self.__cached_child_dbs: cache = self.__cached_child_dbs[self.provider.base_url] LOGGER.debug( "Initializing child DBs for %s. Using cached info:\n%r", self.provider.name, cache, ) self._set_child_dbs(cache["child_dbs"]) data_returned = cache["data_returned"] data_available = cache["data_available"] links = cache["links"] else: LOGGER.debug("Initializing child DBs for %s.", self.provider.name) # Query database and get child_dbs child_dbs, links, data_returned, data_available = self._query() while True: # Update list of structures in dropdown widget exclude_child_dbs, final_child_dbs = self._update_child_dbs( data=child_dbs, skip_dbs=self.skip_child_dbs.get( self.provider.name, []), ) LOGGER.debug("Exclude child DBs: %r", exclude_child_dbs) data_returned -= len(exclude_child_dbs) data_available -= len(exclude_child_dbs) if exclude_child_dbs and data_returned: child_dbs, links, data_returned, _ = self._query( exclude_ids=exclude_child_dbs) else: break self._set_child_dbs(final_child_dbs) # Cache initial child_dbs and related information self.__cached_child_dbs[self.provider.base_url] = { "child_dbs": final_child_dbs, "data_returned": data_returned, "data_available": data_available, "links": links, } LOGGER.debug( "Found the following, which has now been cached:\n%r", self.__cached_child_dbs[self.provider.base_url], ) # Update pageing self.page_chooser.set_pagination_data( data_returned=data_returned, data_available=data_available, links_to_page=links, reset_cache=True, ) except QueryError as exc: LOGGER.debug( "Trying to initalize child DBs. QueryError caught: %r", exc) if exc.remove_target: LOGGER.debug( "Remove target: %r. Will remove target at %r: %r", exc.remove_target, self.providers.index, self.providers.value, ) self.providers.options = self._remove_current_dropdown_option( self.providers) self.reset() else: LOGGER.debug( "Remove target: %r. 
Will NOT remove target at %r: %r", exc.remove_target, self.providers.index, self.providers.value, ) self.show_child_dbs.display = "none" self.child_dbs.grouping = self.INITIAL_CHILD_DBS else: self.unfreeze() def _set_child_dbs( self, data: List[Tuple[str, List[Tuple[str, LinksResourceAttributes]]]], ) -> None: """Update the child_dbs options with `data`""" first_choice = (self.HINT["child_dbs"] if data else "No valid implementations found") new_data = list(data) new_data.insert(0, ("", [(first_choice, None)])) self.child_dbs.grouping = new_data def _update_child_dbs( self, data: List[dict], skip_dbs: List[str] = None ) -> Tuple[List[str], List[List[Union[str, List[Tuple[ str, LinksResourceAttributes]]]]], ]: """Update child DB dropdown from response data""" child_dbs = ({ "": [] } if self.providers.label not in self.child_db_groupings else deepcopy( self.child_db_groupings[self.providers.label])) exclude_dbs = [] skip_dbs = skip_dbs or [] for entry in data: child_db = update_old_links_resources(entry) if child_db is None: continue # Skip if desired by user if child_db.id in skip_dbs: exclude_dbs.append(child_db.id) continue attributes = child_db.attributes # Skip if not a 'child' link_type database if attributes.link_type != LinkType.CHILD: LOGGER.debug( "Skip %s: Links resource not a %r link_type, instead: %r", attributes.name, LinkType.CHILD, attributes.link_type, ) continue # Skip if there is no base_url if attributes.base_url is None: LOGGER.debug( "Skip %s: Base URL found to be None for child DB: %r", attributes.name, child_db, ) exclude_dbs.append(child_db.id) continue versioned_base_url = get_versioned_base_url(attributes.base_url) if versioned_base_url: attributes.base_url = versioned_base_url else: # Not a valid/supported child DB: skip LOGGER.debug( "Skip %s: Could not determine versioned base URL for child DB: %r", attributes.name, child_db, ) exclude_dbs.append(child_db.id) continue if self.providers.label in self.child_db_groupings: for group, ids in self.child_db_groupings[ self.providers.label].items(): if child_db.id in ids: index = child_dbs[group].index(child_db.id) child_dbs[group][index] = (attributes.name, attributes) break else: if "" in child_dbs: child_dbs[""].append((attributes.name, attributes)) else: child_dbs[""] = [(attributes.name, attributes)] else: child_dbs[""].append((attributes.name, attributes)) if self.providers.label in self.child_db_groupings: for group, ids in tuple(child_dbs.items()): child_dbs[group] = [_ for _ in ids if isinstance(_, tuple)] child_dbs = list(child_dbs.items()) LOGGER.debug("Final updated child_dbs: %s", child_dbs) return exclude_dbs, child_dbs def _get_more_child_dbs(self, change): """Query for more child DBs according to page_offset""" if self.providers.value is None: # This may be called if a provider is suddenly removed (bad provider) return if not self.__perform_query: self.__perform_query = True LOGGER.debug( "Will not perform query with pageing: name=%s value=%s", change["name"], change["new"], ) return pageing: Union[int, str] = change["new"] LOGGER.debug( "Detected change in page_chooser: name=%s value=%s", change["name"], pageing, ) if change["name"] == "page_offset": LOGGER.debug( "Got offset %d to retrieve more child DBs from %r", pageing, self.providers.value, ) self.offset = pageing pageing = None elif change["name"] == "page_number": LOGGER.debug( "Got number %d to retrieve more child DBs from %r", pageing, self.providers.value, ) self.number = pageing pageing = None else: LOGGER.debug( "Got link %r to retrieve more 
child DBs from %r", pageing, self.providers.value, ) # It is needed to update page_offset, but we do not wish to query again with self.hold_trait_notifications(): self.__perform_query = False self.page_chooser.update_offset() try: # Freeze and disable both dropdown widgets # We don't want changes leading to weird things happening prior to the query ending self.freeze() # Query index meta-database LOGGER.debug("Querying for more child DBs using pageing: %r", pageing) child_dbs, links, _, _ = self._query(pageing) data_returned = self.page_chooser.data_returned while True: # Update list of child DBs in dropdown widget exclude_child_dbs, final_child_dbs = self._update_child_dbs( child_dbs) data_returned -= len(exclude_child_dbs) if exclude_child_dbs and data_returned: child_dbs, links, data_returned, _ = self._query( link=pageing, exclude_ids=exclude_child_dbs) else: break self._set_child_dbs(final_child_dbs) # Update pageing self.page_chooser.set_pagination_data(data_returned=data_returned, links_to_page=links) except QueryError as exc: LOGGER.debug( "Trying to retrieve more child DBs (new page). QueryError caught: %r", exc, ) if exc.remove_target: LOGGER.debug( "Remove target: %r. Will remove target at %r: %r", exc.remove_target, self.providers.index, self.providers.value, ) self.providers.options = self._remove_current_dropdown_option( self.providers) self.reset() else: LOGGER.debug( "Remove target: %r. Will NOT remove target at %r: %r", exc.remove_target, self.providers.index, self.providers.value, ) self.show_child_dbs.display = "none" self.child_dbs.grouping = self.INITIAL_CHILD_DBS else: self.unfreeze() def _query( # pylint: disable=too-many-locals,too-many-branches,too-many-statements self, link: str = None, exclude_ids: List[str] = None) -> Tuple[List[dict], dict, int, int]: """Query helper function""" # If a complete link is provided, use it straight up if link is not None: try: if exclude_ids: filter_value = " AND ".join( [f'NOT id="{id_}"' for id_ in exclude_ids]) parsed_url = urllib.parse.urlparse(link) queries = urllib.parse.parse_qs(parsed_url.query) # Since parse_qs wraps all values in a list, # this extracts the values from the list(s). 
queries = {key: value[0] for key, value in queries.items()} if "filter" in queries: queries[ "filter"] = f"( {queries['filter']} ) AND ( {filter_value} )" else: queries["filter"] = filter_value parsed_query = urllib.parse.urlencode(queries) link = ( f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}" f"?{parsed_query}") link = ordered_query_url(link) response = SESSION.get(link, timeout=TIMEOUT_SECONDS) if response.from_cache: LOGGER.debug("Request to %s was taken from cache !", link) response = response.json() except ( requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout, ) as exc: response = { "errors": { "msg": "CLIENT: Connection error or timeout.", "url": link, "Exception": repr(exc), } } except json.JSONDecodeError as exc: response = { "errors": { "msg": "CLIENT: Could not decode response to JSON.", "url": link, "Exception": repr(exc), } } else: filter_ = '( link_type="child" OR type="child" )' if exclude_ids: filter_ += ( " AND ( " + " AND ".join([f'NOT id="{id_}"' for id_ in exclude_ids]) + " )") response = perform_optimade_query( filter=filter_, base_url=self.provider.base_url, endpoint="/links", page_limit=self.child_db_limit, page_offset=self.offset, page_number=self.number, ) msg, http_errors = handle_errors(response) if msg: if 404 in http_errors: # If /links not found move on pass else: self.error_or_status_messages.value = msg raise QueryError(msg=msg, remove_target=True) # Check implementation API version msg = validate_api_version(response.get("meta", {}).get("api_version", ""), raise_on_fail=False) if msg: self.error_or_status_messages.value = ( f"{msg}<br>The provider has been removed.") raise QueryError(msg=msg, remove_target=True) LOGGER.debug( "Manually remove `exclude_ids` if filters are not supported") child_db_data = { impl.get("id", "N/A"): impl for impl in response.get("data", []) } if exclude_ids: for links_id in exclude_ids: if links_id in list(child_db_data.keys()): child_db_data.pop(links_id) LOGGER.debug("child_db_data after popping: %r", child_db_data) response["data"] = list(child_db_data.values()) if "meta" in response: if "data_available" in response["meta"]: old_data_available = response["meta"].get( "data_available", 0) if len(response["data"]) > old_data_available: LOGGER.debug("raising OptimadeClientError") raise OptimadeClientError( f"Reported data_available ({old_data_available}) is smaller than " f"curated list of responses ({len(response['data'])}).", ) response["meta"]["data_available"] = len(response["data"]) else: raise OptimadeClientError( "'meta' not found in response. 
Bad response") LOGGER.debug( "Attempt for %r (in /links): Found implementations (names+base_url only):\n%s", self.provider.name, [ f"(id: {name}; base_url: {base_url}) " for name, base_url in [( impl.get("id", "N/A"), impl.get("attributes", {}).get("base_url", "N/A"), ) for impl in response.get("data", [])] ], ) # Return all implementations of link_type "child" implementations = [ implementation for implementation in response.get("data", []) if (implementation.get("attributes", {}).get("link_type", "") == "child" or implementation.get("type", "") == "child") ] LOGGER.debug( "After curating for implementations which are of 'link_type' = 'child' or 'type' == " "'child' (old style):\n%s", [ f"(id: {name}; base_url: {base_url}) " for name, base_url in [( impl.get("id", "N/A"), impl.get("attributes", {}).get("base_url", "N/A"), ) for impl in implementations] ], ) # Get links, data_returned, and data_available links = response.get("links", {}) data_returned = response.get("meta", {}).get("data_returned", len(implementations)) if data_returned > 0 and not implementations: # Most probably dealing with pre-v1.0.0-rc.2 implementations data_returned = 0 data_available = response.get("meta", {}).get("data_available", 0) return implementations, links, data_returned, data_available
class ModifyCoordinates(Algorithm):
    """
    Base class for nodes that modify the requested coordinates before evaluation.

    Attributes
    ----------
    source : podpac.Node
        Source node that will be evaluated with the modified coordinates.
    coordinates_source : podpac.Node
        Node that supplies the available coordinates when necessary, optional. The source node is used by default.
    lat, lon, time, alt : List
        Modification parameters for given dimension. Varies by node.
    """

    source = tl.Instance(Node)
    coordinates_source = tl.Instance(Node)
    lat = tl.List().tag(attr=True)
    lon = tl.List().tag(attr=True)
    time = tl.List().tag(attr=True)
    alt = tl.List().tag(attr=True)
    _modified_coordinates = tl.Instance(Coordinates, allow_none=True)

    @tl.default('coordinates_source')
    def _default_coordinates_source(self):
        return self.source

    @common_doc(COMMON_DOC)
    def eval(self, coordinates, output=None):
        """Evaluates this node using the supplied coordinates.

        Parameters
        ----------
        coordinates : podpac.Coordinates
            {requested_coordinates}
        output : podpac.UnitsDataArray, optional
            {eval_output}

        Returns
        -------
        {eval_return}

        Notes
        -----
        The input coordinates are modified and then passed to the base class implementation of eval.
        """
        self._requested_coordinates = coordinates
        self.outputs = {}
        self._modified_coordinates = Coordinates(
            [self.get_modified_coordinates1d(coordinates, dim) for dim in coordinates.dims]
        )

        for dim in self._modified_coordinates.udims:
            if self._modified_coordinates[dim].size == 0:
                raise ValueError("Modified coordinates do not intersect with source data (dim '%s')" % dim)

        self.outputs['source'] = self.source.eval(self._modified_coordinates, output=output)

        if output is None:
            output = self.outputs['source']
        else:
            output[:] = self.outputs['source']

        if settings['DEBUG']:
            self._output = output
        return output
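# Subclasses only have to supply get_modified_coordinates1d; eval() then
# rewrites the request and delegates. A plain-Python sketch of that
# modify-then-delegate pattern -- the classes here are illustrative
# stand-ins, not podpac's API:
class PadTime:
    def __init__(self, source, pad):
        self.source = source
        self.pad = pad

    def eval(self, coordinates):
        lo, hi = coordinates["time"]
        modified = dict(coordinates, time=(lo - self.pad, hi + self.pad))
        return self.source.eval(modified)  # delegate with the widened window

# With a source whose eval() echoes its input:
# PadTime(source, pad=1).eval({"time": (10, 20)}) -> {"time": (9, 21)}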
class PolarCoordinates(StackedCoordinates): """ Parameterized spatial coordinates defined by a center, radius coordinates, and theta coordinates. Attributes ---------- center radius theta """ center = ArrayTrait(shape=(2, ), dtype=float, read_only=True) radius = tl.Instance(Coordinates1d, read_only=True) theta = tl.Instance(Coordinates1d, read_only=True) dims = tl.Tuple(tl.Unicode(), tl.Unicode(), read_only=True) def __init__(self, center, radius, theta=None, theta_size=None, dims=None): # radius if not isinstance(radius, Coordinates1d): radius = ArrayCoordinates1d(radius) # theta if theta is not None and theta_size is not None: raise TypeError( "PolarCoordinates expected theta or theta_size, not both.") if theta is None and theta_size is None: raise TypeError("PolarCoordinates requires theta or theta_size.") if theta_size is not None: theta = UniformCoordinates1d(start=0, stop=2 * np.pi, size=theta_size + 1)[:-1] elif not isinstance(theta, Coordinates1d): theta = ArrayCoordinates1d(theta) self.set_trait("center", center) self.set_trait("radius", radius) self.set_trait("theta", theta) if dims is not None: self.set_trait("dims", dims) @tl.validate("dims") def _validate_dims(self, d): val = d["value"] for dim in val: if dim not in ["lat", "lon"]: raise ValueError( "PolarCoordinates dims must be 'lat' or 'lon', not '%s'" % dim) if val[0] == val[1]: raise ValueError("Duplicate dimension '%s'" % val[0]) return val @tl.validate("radius") def _validate_radius(self, d): val = d["value"] if np.any(val.coordinates <= 0): raise ValueError("PolarCoordinates radius must all be positive") return val def _set_name(self, value): self._set_dims(value.split("_")) def _set_dims(self, dims): self.set_trait("dims", dims) # ------------------------------------------------------------------------------------------------------------------ # Alternate Constructors # ------------------------------------------------------------------------------------------------------------------ @classmethod def from_definition(cls, d): if "center" not in d: raise ValueError( 'PolarCoordinates definition requires "center" property') if "radius" not in d: raise ValueError( 'PolarCoordinates definition requires "radius" property') if "theta" not in d and "theta_size" not in d: raise ValueError( 'PolarCoordinates definition requires "theta" or "theta_size" property' ) if "dims" not in d: raise ValueError( 'PolarCoordinates definition requires "dims" property') # center center = d["center"] # radius if isinstance(d["radius"], list): radius = ArrayCoordinates1d(d["radius"]) elif "values" in d["radius"]: radius = ArrayCoordinates1d.from_definition(d["radius"]) elif "start" in d["radius"] and "stop" in d["radius"] and ( "step" in d["radius"] or "size" in d["radius"]): radius = UniformCoordinates1d.from_definition(d["radius"]) else: raise ValueError( "Could not parse radius coordinates definition with keys %s" % d.keys()) # theta if "theta" not in d: theta = None elif isinstance(d["theta"], list): theta = ArrayCoordinates1d(d["theta"]) elif "values" in d["theta"]: theta = ArrayCoordinates1d.from_definition(d["theta"]) elif "start" in d["theta"] and "stop" in d["theta"] and ( "step" in d["theta"] or "size" in d["theta"]): theta = UniformCoordinates1d.from_definition(d["theta"]) else: raise ValueError( "Could not parse theta coordinates definition with keys %s" % d.keys()) kwargs = { k: v for k, v in d.items() if k not in ["center", "radius", "theta"] } return PolarCoordinates(center, radius, theta, **kwargs) # 
------------------------------------------------------------------------------------------------------------------ # standard methods # ------------------------------------------------------------------------------------------------------------------ def __repr__(self): return "%s(%s): center%s, shape%s" % ( self.__class__.__name__, self.dims, self.center, self.shape) def __eq__(self, other): if not isinstance(other, PolarCoordinates): return False if not np.allclose(self.center, other.center): return False if self.radius != other.radius: return False if self.theta != other.theta: return False return True def __getitem__(self, index): if isinstance(index, slice): index = index, slice(None) if isinstance(index, tuple) and isinstance( index[0], slice) and isinstance(index[1], slice): return PolarCoordinates(self.center, self.radius[index[0]], self.theta[index[1]], dims=self.dims) else: # convert to raw StackedCoordinates (which creates the _coords attribute that the indexing requires) return StackedCoordinates(self.coordinates, dims=self.dims).__getitem__(index) # ------------------------------------------------------------------------------------------------------------------ # Properties # ------------------------------------------------------------------------------------------------------------------ @property def _coords(self): raise RuntimeError("PolarCoordinates do not have a _coords attribute.") @property def ndim(self): return 2 @property def shape(self): return self.radius.size, self.theta.size @property def xdims(self): return ("r", "t") @property def coordinates(self): r, theta = np.meshgrid(self.radius.coordinates, self.theta.coordinates) lat = r * np.sin(theta) + self.center[0] lon = r * np.cos(theta) + self.center[1] return lat.T, lon.T @property def definition(self): d = OrderedDict() d["dims"] = self.dims d["center"] = self.center d["radius"] = self.radius.definition d["theta"] = self.theta.definition return d @property def full_definition(self): return self.definition # ------------------------------------------------------------------------------------------------------------------ # Methods # ------------------------------------------------------------------------------------------------------------------ def copy(self): return PolarCoordinates(self.center, self.radius, self.theta, dims=self.dims)
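# Usage of the class above, with made-up numbers: five radii and eight evenly
# spaced angles (theta_size adds one sample and drops the duplicate 2*pi).
import numpy as np

pc = PolarCoordinates(center=[40.0, -105.0], radius=np.arange(1, 6),
                      theta_size=8, dims=("lat", "lon"))
# pc.shape       -> (5, 8)
# pc.coordinates -> 2D lat/lon arrays from r*sin(theta) + lat0, r*cos(theta) + lon0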
class Lambda(Node):
    """A `Node` wrapper to evaluate source on an AWS Lambda function

    Attributes
    ----------
    AWS_ACCESS_KEY_ID : string
        access key id from AWS credentials
    AWS_SECRET_ACCESS_KEY : string
        access key value from AWS credentials
    AWS_REGION_NAME : string
        name of the AWS region
    source : Node
        node to be evaluated
    source_output : Output
        how to output the evaluated results of `source`
    attrs : dict
        additional attributes passed on to the Lambda definition of the base node
    """

    AWS_ACCESS_KEY_ID = tl.Unicode(allow_none=False, help="Access key ID from AWS for S3 bucket.")

    @tl.default('AWS_ACCESS_KEY_ID')
    def _AWS_ACCESS_KEY_ID_default(self):
        return settings['AWS_ACCESS_KEY_ID']

    AWS_SECRET_ACCESS_KEY = tl.Unicode(allow_none=False, help="Access key value from AWS for S3 bucket.")

    @tl.default('AWS_SECRET_ACCESS_KEY')
    def _AWS_SECRET_ACCESS_KEY_default(self):
        return settings['AWS_SECRET_ACCESS_KEY']

    AWS_REGION_NAME = tl.Unicode(allow_none=False, help="Region name of AWS S3 bucket.")

    @tl.default('AWS_REGION_NAME')
    def _AWS_REGION_NAME_default(self):
        return settings['AWS_REGION_NAME']

    source = tl.Instance(Node, allow_none=False, help="Node to evaluate in a Lambda function.")
    source_output = tl.Instance(Output, allow_none=False, help="Image output information.")
    attrs = tl.Dict()

    @tl.default('source_output')
    def _source_output_default(self):
        return FileOutput(node=self.source, name=self.source.__class__.__name__)

    s3_bucket_name = tl.Unicode(allow_none=False, help="Name of AWS s3 bucket.")

    @tl.default('s3_bucket_name')
    def _s3_bucket_name_default(self):
        return settings['S3_BUCKET_NAME']

    s3_json_folder = tl.Unicode(allow_none=False, help="S3 folder to put JSON in.")

    @tl.default('s3_json_folder')
    def _s3_json_folder_default(self):
        return settings['S3_JSON_FOLDER']

    s3_output_folder = tl.Unicode(allow_none=False, help="S3 folder to put output in.")

    @tl.default('s3_output_folder')
    def _s3_output_folder_default(self):
        return settings['S3_OUTPUT_FOLDER']

    @property
    def definition(self):
        """
        The definition of this manager is the aggregation of the source node and source output.
        """
        d = OrderedDict()
        d['pipeline'] = self.source.definition
        if self.attrs:
            out_node = next(reversed(d['pipeline']['nodes'].keys()))
            d['pipeline']['nodes'][out_node]['attrs'].update(self.attrs)
        d['pipeline']['output'] = self.source_output.definition
        return d

    @common_doc(COMMON_DOC)
    def eval(self, coordinates, output=None):
        """
        Evaluate the source node on the AWS Lambda function at the given coordinates.
        """
        d = self.definition
        d['coordinates'] = json.loads(coordinates.json)
        filename = '%s%s_%s_%s.%s' % (
            self.s3_json_folder, self.source_output.name, self.source.hash, coordinates.hash, 'json')
        s3 = boto3.client('s3')
        s3.put_object(
            Body=(bytes(json.dumps(d, indent=4, cls=JSONEncoder).encode('UTF-8'))),
            Bucket=self.s3_bucket_name,
            Key=filename)
        waiter = s3.get_waiter('object_exists')
        filename = '%s%s_%s_%s.%s' % (
            self.s3_output_folder, self.source_output.name, self.source.hash, coordinates.hash,
            self.source_output.format)
        waiter.wait(Bucket=self.s3_bucket_name, Key=filename)

        # After waiting, load the pickle file like this:
        resource = boto3.resource('s3')
        with BytesIO() as data:
            # Get the bucket and file name programmatically - see above...
            resource.Bucket(self.s3_bucket_name).download_fileobj(filename, data)
            data.seek(0)  # move back to the beginning after writing
            self._output = cPickle.load(data)
        return self._output
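# The request/response rendezvous is purely by S3 key convention: eval()
# writes a JSON job under one prefix, then waits for the matching output
# object under another. The naming scheme in isolation (made-up hashes):
s3_json_folder, s3_output_folder = "json/", "output/"
name, node_hash, coords_hash, fmt = "SinCoords", "ab12", "cd34", "pkl"

job_key = '%s%s_%s_%s.%s' % (s3_json_folder, name, node_hash, coords_hash, 'json')
out_key = '%s%s_%s_%s.%s' % (s3_output_folder, name, node_hash, coords_hash, fmt)
# job_key -> 'json/SinCoords_ab12_cd34.json'
# out_key -> 'output/SinCoords_ab12_cd34.pkl'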
class PixelInspector(ipywidgets.GridBox):
    """
    Display pixel values when clicking on the map.

    Whenever you click on the map, it fetches the pixel values at that location for all
    active Workflows layers and displays them in a table overlaid on the map. It also shows
    a marker on the map indicating the last position clicked. As layers change, or are added
    or removed, the table keeps fetching pixel values for the new layers at the last-clicked
    point (the marker's current position).

    For performance, the inspector does not use full-resolution data, but rather whatever
    resolution (zoom level) the map is currently displaying. Therefore, it's possible that
    values for the same point would come back slightly different at different zoom levels.
    (Note that the resampling method used is whatever the input `~.geospatial.Image` or
    `~.geospatial.ImageCollection` was constructed with.)

    To unlink from the map, call `~.unlink`.

    Example
    -------
    >>> import descarteslabs.workflows as wf
    >>> my_map = wf.interactive.Map()
    >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1")
    >>> img.pick_bands("red").visualize("Red", colormap="Reds", map=my_map)  # doctest: +SKIP
    >>> img.pick_bands("green").visualize("Green", colormap="Greens", map=my_map)  # doctest: +SKIP
    >>> inspector = wf.interactive.PixelInspector(my_map)
    >>> my_map  # doctest: +SKIP
    >>> # ^ display the map
    >>> # click on the map; a table will pop up showing pixel values for the Red and Green layers
    >>> inspector.unlink()
    >>> # table and marker disappear; click again and nothing happens
    """

    # NOTE(gabe): we use the marker's opacity as a crude global on-off switch;
    # all event listeners check that opacity == 1 before doing actual work.
    marker = traitlets.Instance(
        CircleMarkerWithXYZGeoContext,
        kw=dict(opacity=0, radius=5, weight=1, name="Inspected pixel marker"),
        read_only=True,
    )
    n_bands = traitlets.Int(3, read_only=True, allow_none=False)

    def __init__(self, map, position="topright", layout=None):
        """
        Construct a PixelInspector and attach it to a map.

        Parameters
        ----------
        map: ipyleaflet.Map, workflows.interactive.MapApp
            The map to attach to
        position: str, optional, default "topright"
            Where on the map to display the values table
        layout: ipywidgets.Layout, optional
            Layout for the values table.
            Defaults to ``Layout(max_height="350px", overflow="scroll", padding="4px")``
        """
        if layout is None:
            layout = ipywidgets.Layout(max_height="350px", overflow="scroll", padding="4px")
        super().__init__([], layout=layout)
        self.layout.grid_template_columns = "min-content " * (1 + self.n_bands)

        # awkwardly handle MapApp without circularly importing it for an isinstance check
        try:
            sub_map = map.map
        except AttributeError:
            pass
        else:
            if isinstance(sub_map, ipyleaflet.Map):
                map = sub_map

        self._map = map
        self.marker.map = map
        self._inspector_rows_by_layer_id = weakref.WeakValueDictionary()
        # initialize with the current layers on the map, if any
        self._layers_changed({"old": [], "new": map.layers})

        self._control = ipyleaflet.WidgetControl(widget=self, position=position)
        map.add_control(self._control)
        map.observe(self._layers_changed, names=["layers"], type="change")
        map.on_interaction(self._handle_click)
        self._orig_cursor = map.default_style.cursor
        map.default_style.cursor = "crosshair"

    def unlink(self):
        "Stop listening for click events or layer updates and remove the table from the map"
        self._map.on_interaction(self._handle_click, remove=True)
        self._map.unobserve(self._layers_changed, "layers", type="change")
        for inspector_row in tuple(self._inspector_rows_by_layer_id.values()):
            # ^ take a tuple first, since unlinking should remove all references to `inspector_row`,
            # which would then pop it from the dict, causing mutation of the dict while we iterate over it
            inspector_row.unlink()
        self._map.default_style.cursor = self._orig_cursor
        try:
            self._map.remove_control(self._control)
        except ipyleaflet.ControlException:
            pass
        try:
            self._map.remove_layer(self.marker)
        except ipyleaflet.LayerException:
            pass
        self.marker.opacity = 0  # be extra sure no more inspects will run
        self.children = []

    def _layers_changed(self, change):
        new_layers = change["new"]
        inspector_rows = []
        for layer in reversed(new_layers):
            if isinstance(layer, WorkflowsLayer):
                try:
                    inspector_row = self._inspector_rows_by_layer_id[layer.model_id]
                except KeyError:
                    inspector_row = InspectorRowGenerator(layer, self.marker, self.n_bands)
                    self._inspector_rows_by_layer_id[layer.model_id] = inspector_row
                inspector_rows.append(inspector_row)

        new_children = []
        for inspector_row in inspector_rows:
            new_children.append(inspector_row.name_label)
            new_children.extend(inspector_row.values_labels)
        self.children = new_children

    def _handle_click(self, **kwargs):
        with self._map.output_log:
            if kwargs.get("type") != "click":
                return
            try:
                lat, lon = kwargs["coordinates"]
            except KeyError:
                return
            self.marker.opacity = 1
            self.marker.location = (lat, lon)
            # in case it accidentally got deleted with `clear_layers`
            if self.marker not in self._map.layers:
                self._map.add_layer(self.marker)
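# --- Usage sketch (hypothetical values; variant of the docstring example) ----
# Attach the inspector with a custom position and table layout, then detach.
import ipywidgets
import descarteslabs.workflows as wf

my_map = wf.interactive.Map()
inspector = wf.interactive.PixelInspector(
    my_map,
    position="bottomleft",
    layout=ipywidgets.Layout(max_height="200px", overflow="scroll", padding="2px"),
)
# ... click around the displayed map; when done, remove the table and marker:
inspector.unlink()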
class ImageSelector(ipywidgets.HBox):
    """UI element to select an image sequence and frame number

    Given a list of file names, images, or image sequences, this allows the
    user to choose one entry and also to select a frame number. The selected
    image is available via the :py:attr:`output` traitlet.
    """

    images = traitlets.Union([traitlets.Dict(), traitlets.List()])
    """Images or sequences to select from.

    Image sequences can be passed as 3D :py:class:`numpy.ndarray`, as lists
    of 2D arrays, or as paths to image files, which will be opened using
    :py:mod:`pims`. Single images are represented as 2D arrays.

    This attribute can be a list of ``(key, img)`` tuples where ``key`` is
    the name to display and value an image (sequence), a dict mapping
    ``key`` to ``img`` (which will be converted to a list of tuples), or a
    plain list of images (sequences).
    """

    output = traitlets.Instance(np.ndarray, allow_none=True)
    """2D array representing the currently selected frame."""

    index = traitlets.Int(allow_none=True)
    """Currently selected index (w.r.t. :py:attr:`images`)"""

    def __init__(self, images: Union[Sequence, Dict] = [], **kwargs):
        """Parameters
        ----------
        images
            List of images (sequences) to populate :py:attr:`images`.
        **kwargs
            Passed to parent ``__init__``.
        """
        self._file_sel = ipywidgets.Dropdown(description="image")
        self._file_sel.observe(self._file_changed, "value")
        self._frame_sel = ipywidgets.BoundedIntText(description="frame", min=0, max=0)
        self._frame_sel.observe(self._frame_changed, "value")
        self._prev_button = ipywidgets.Button(icon="arrow-left")
        self._prev_button.on_click(self._prev_button_clicked)
        self._next_button = ipywidgets.Button(icon="arrow-right")
        self._next_button.on_click(self._next_button_clicked)

        for b in self._prev_button, self._next_button:
            b.layout.width = "auto"
            b.disabled = True
        self.show_file_buttons = False

        traitlets.link((self._file_sel, "index"), (self, "index"))

        super().__init__(
            [self._prev_button, self._file_sel, self._next_button, self._frame_sel], **kwargs)

        self._cur_image = None
        self._cur_image_opened = False
        self._frame_changed_lock = threading.Lock()

        self.images = images

    @traitlets.validate("images")
    def _make_images_list(self, proposal):
        """Validator for the :py:attr:`images` traitlet

        Turns dictionaries into lists of tuples.
        """
        images = proposal["value"]
        if len(images) == 0:
            return []
        if isinstance(images, dict):
            return list(images.items())
        return images

    @traitlets.observe("images")
    def _set_file_options(self, change=None):
        """Set the options for the sequence selection dropdown element"""
        if len(self.images) == 0:
            self._file_sel.options = []
            return

        n_figures = int(math.log10(len(self.images)))
        generic_key_pattern = "<{{:0{}}}>".format(n_figures)

        opts = []
        for n, img in enumerate(self.images):
            if isinstance(img, tuple):
                opts.append(img[0])
                continue
            if isinstance(img, str):
                img = Path(img)
            if isinstance(img, Path):
                opts.append("{} ({})".format(img.name, str(img.parent)))
                continue
            opts.append(generic_key_pattern.format(n))
        self._file_sel.options = opts

    def _file_changed(self, change=None):
        """Call-back upon change of the currently selected sequence"""
        if self._cur_image_opened:
            self._cur_image.close()
            self._cur_image_opened = False

        if self._file_sel.value is None:
            # No file selected
            with self._frame_changed_lock:
                self._frame_sel.max = 0
            self._cur_image = None
            self.output = None
            self._prev_button.disabled = True
            self._next_button.disabled = True
            return

        self._prev_button.disabled = self._file_sel.index <= 0
        self._next_button.disabled = self._file_sel.index >= len(self.images) - 1

        img = self.images[self._file_sel.index]
        if isinstance(img, tuple):
            # TODO: What if there is a tuple of images instead of (key, value)?
            img = img[1]
        if isinstance(img, np.ndarray) and img.ndim == 2:
            # Single image
            img = img[None, ...]
        elif isinstance(img, (str, Path)):
            # Open via pims
            img = pims.open(str(img))
            self._cur_image_opened = True

        self._cur_image = img

        with self._frame_changed_lock:
            # Disable potential update at this point. Will be explicitly
            # updated below.
            self._frame_sel.max = len(img) - 1

        self._frame_changed()

    def _frame_changed(self, change=None):
        """Call-back upon change of the currently selected frame number"""
        if self._frame_changed_lock.locked():
            return
        self.output = self._cur_image[self._frame_sel.value]

    def _prev_button_clicked(self, button=None):
        self._file_sel.index -= 1

    def _next_button_clicked(self, button=None):
        self._file_sel.index += 1

    @property
    def show_file_buttons(self) -> bool:
        """Whether to show "back" and "forward" buttons for file selection"""
        return self._prev_button.layout.display is None

    @show_file_buttons.setter
    def show_file_buttons(self, s):
        for b in self._prev_button, self._next_button:
            b.layout.display = None if s else "none"
class WCSBase(DataSource):
    """
    Access data from a WCS source.

    Attributes
    ----------
    source : str
        WCS server url
    layer : str
        layer name (required)
    version : str
        WCS version, passed through to all requests (default '1.0.0')
    format : str
        Data format, passed through to the GetCoverage requests (default 'geotiff')
    crs : str
        coordinate reference system, passed through to the GetCoverage requests (default 'EPSG:4326')
    interpolation : str
        Interpolation, passed through to the GetCoverage requests.
    max_size : int
        maximum request size, optional. If provided, the coordinates will be tiled into multiple requests.
    """

    source = tl.Unicode().tag(attr=True)
    layer = tl.Unicode().tag(attr=True)
    version = tl.Unicode(default_value="1.0.0").tag(attr=True)
    interpolation = tl.Unicode(default_value=None, allow_none=True).tag(attr=True)

    format = tl.CaselessStrEnum(["geotiff", "geotiff_byte"], default_value="geotiff")
    crs = tl.Unicode(default_value="EPSG:4326")
    max_size = tl.Long(default_value=None, allow_none=True)

    _repr_keys = ["source", "layer"]

    _requested_coordinates = tl.Instance(Coordinates, allow_none=True)
    _evaluated_coordinates = tl.Instance(Coordinates)

    @cached_property
    def client(self):
        return owslib_wcs.WebCoverageService(self.source, version=self.version)

    def get_coordinates(self):
        """
        Get the full WCS grid.
        """
        metadata = self.client.contents[self.layer]

        # TODO select correct boundingbox by crs

        # coordinates
        w, s, e, n = metadata.boundingBoxWGS84
        low = metadata.grid.lowlimits
        high = metadata.grid.highlimits
        xsize = int(high[0]) - int(low[0])
        ysize = int(high[1]) - int(low[1])

        coords = []
        coords.append(UniformCoordinates1d(s, n, size=ysize, name="lat"))
        coords.append(UniformCoordinates1d(w, e, size=xsize, name="lon"))

        if metadata.timepositions:
            coords.append(ArrayCoordinates1d(metadata.timepositions, name="time"))

        if metadata.timelimits:
            raise NotImplementedError("TODO")

        return Coordinates(coords, crs=self.crs)

    def _eval(self, coordinates, output=None, _selector=None):
        """Evaluates this node using the supplied coordinates.

        This method intercepts the DataSource._eval method in order to use the
        requested coordinates directly when they are a uniform grid.

        Parameters
        ----------
        coordinates : :class:`podpac.Coordinates`
            {requested_coordinates}
            An exception is raised if the requested coordinates are missing dimensions in the
            DataSource. Extra dimensions in the requested coordinates are dropped.
        output : :class:`podpac.UnitsDataArray`, optional
            {eval_output}
        _selector: callable(coordinates, request_coordinates)
            {eval_selector}

        Returns
        -------
        {eval_return}

        Raises
        ------
        ValueError
            Cannot evaluate these coordinates
        """
        # the datasource does this, but we need to do it here to select the correct case below
        if self.coordinates.crs.lower() != coordinates.crs.lower():
            coordinates = coordinates.transform(self.coordinates.crs)

        # for a uniform grid, use the requested coordinates (the WCS server will interpolate)
        if (
            ("lat" in coordinates.dims and "lon" in coordinates.dims)
            and (coordinates["lat"].is_uniform or coordinates["lat"].size == 1)
            and (coordinates["lon"].is_uniform or coordinates["lon"].size == 1)
        ):

            def selector(rsc, coordinates, index_type=None):
                return coordinates, None

            return super()._eval(coordinates, output=output, _selector=selector)

        # for uniform stacked, unstack to use the requested coordinates (the WCS server will interpolate)
        if (
            ("lat" in coordinates.udims and coordinates.is_stacked("lat"))
            and ("lon" in coordinates.udims and coordinates.is_stacked("lon"))
            and (coordinates["lat"].is_uniform or coordinates["lat"].size == 1)
            and (coordinates["lon"].is_uniform or coordinates["lon"].size == 1)
        ):

            def selector(rsc, coordinates, index_type=None):
                unstacked = coordinates.unstack()
                unstacked = unstacked.drop("alt", ignore_missing=True)  # if lat_lon_alt
                return unstacked, None

            udata = super()._eval(coordinates, output=None, _selector=selector)
            data = udata.data.diagonal()  # get just the stacked data
            if output is None:
                output = self.create_output_array(coordinates, data=data)
            else:
                output.data[:] = data
            return output

        # otherwise, pass-through (podpac will select and interpolate)
        return super()._eval(coordinates, output=output, _selector=_selector)

    def _get_data(self, coordinates, coordinates_index):
        """{get_data}"""
        # transpose the coordinates to match the response data
        if "time" in coordinates:
            coordinates = coordinates.transpose("time", "lat", "lon")
        else:
            coordinates = coordinates.transpose("lat", "lon")

        # determine the chunk size (if applicable)
        if self.max_size is not None:
            shape = []
            s = 1
            for n in coordinates.shape:
                r = self.max_size // s
                if r == 0:
                    shape.append(1)
                elif r < n:
                    shape.append(r)
                else:
                    shape.append(n)
                s *= n
            shape = tuple(shape)
        else:
            shape = coordinates.shape

        # request each chunk and composite the data
        output = self.create_output_array(coordinates)
        for i, (chunk, slc) in enumerate(coordinates.iterchunks(shape, return_slices=True)):
            output[slc] = self._get_chunk(chunk)

        return output

    def _get_chunk(self, coordinates):
        if coordinates["lon"].size == 1:
            w = coordinates["lon"].coordinates[0]
            e = coordinates["lon"].coordinates[0]
        else:
            w = coordinates["lon"].start - coordinates["lon"].step / 2.0
            e = coordinates["lon"].stop + coordinates["lon"].step / 2.0

        if coordinates["lat"].size == 1:
            s = coordinates["lat"].coordinates[0]
            n = coordinates["lat"].coordinates[0]
        else:
            s = coordinates["lat"].start - coordinates["lat"].step / 2.0
            n = coordinates["lat"].stop + coordinates["lat"].step / 2.0

        width = coordinates["lon"].size
        height = coordinates["lat"].size

        kwargs = {}

        if "time" in coordinates:
            kwargs["time"] = coordinates["time"].coordinates.astype(str).tolist()

        if isinstance(self.interpolation, str):
            kwargs["interpolation"] = self.interpolation

        logger.info(
            "WCS GetCoverage (source=%s, layer=%s, bbox=%s, shape=%s)"
            % (self.source, self.layer, (w, n, e, s), (width, height))
        )

        response = self.client.getCoverage(
            identifier=self.layer,
            bbox=(w, n, e, s),
            width=width,
            height=height,
            crs=self.crs,
            format=self.format,
            version=self.version,
            **kwargs
        )
        content = response.read()

        # check for errors
        xml = bs4.BeautifulSoup(content, "lxml")
        error = xml.find("serviceexception")
        if error:
            raise WCSError(error.text)

        # get data using rasterio
        with rasterio.MemoryFile() as mf:
            mf.write(content)
            dataset = mf.open(driver="GTiff")

        if "time" in coordinates and coordinates["time"].size > 1:
            # this should be easy to do, I'm just not sure how the data comes back.
            # is each time in a different band?
            raise NotImplementedError("TODO")

        data = dataset.read(1).astype(float)

        return data

    @classmethod
    def get_layers(cls, source=None):
        if source is None:
            source = cls.source
        client = owslib_wcs.WebCoverageService(source)
        return list(client.contents)
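# --- Sketch: the greedy chunking rule in `_get_data` --------------------------
# Standalone mirror of the max_size loop above (not part of the class): each
# dimension is filled until the chunk would exceed `max_size` points.
def chunk_shape(coord_shape, max_size):
    shape = []
    s = 1
    for n in coord_shape:
        r = max_size // s  # how many elements of this dimension still fit
        shape.append(1 if r == 0 else min(r, n))
        s *= n
    return tuple(shape)

# e.g. a (time, lat, lon) grid of shape (4, 1000, 1000) with max_size=250000
# is tiled into chunks of shape (4, 1000, 62): 248000 points per request.
print(chunk_shape((4, 1000, 1000), 250000))  # -> (4, 1000, 62)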
class DatashaderVisualizer(NXBase):
    """
    A class for visualizing an RDF graph with datashader

    :param graph: an rdflib.graph.Graph object to visualize.
    :param tooltip: takes either 'nodes' or 'edges', and sets the hover tool.
    :param sparql: a query you'd like to perform on the rdflib.graph.Graph object.
    """

    output = T.Instance(W.Output)
    tooltip = T.Unicode(default_value="nodes")
    tooltip_dict = T.Dict()
    node_tooltips = T.List()
    edge_tooltips = T.List()
    sparql = T.Unicode()

    @T.default("output")
    def _make_default_output(self):
        return W.Output()

    @T.default("edge_tooltips")
    def _make_edge_tooltip(self):
        return [
            ("Source", "@start"),
            ("Target", "@end"),
        ]

    @T.default("node_tooltips")
    def _make_node_tooltip(self):
        return [
            ("ID", "@index"),
        ]

    @T.default("tooltip")
    def _make_tooltip(self):
        return "nodes"

    @T.default("tooltip_dict")
    def _make_tooltip_dict(self):
        return {
            "nodes": HoverTool(tooltips=self.node_tooltips),
            "edges": HoverTool(tooltips=self.edge_tooltips),
        }

    @T.default("sparql")
    def _make_sparql(self):
        return """
            CONSTRUCT {
                ?s ?p ?o .
            }
            WHERE {
                ?s ?p ?o .
                FILTER (!isLiteral(?o))
                FILTER (!isLiteral(?s))
            }
            LIMIT 300
        """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.children = [self.output]

    def display_datashader_vis(self, p):
        self.output.clear_output()
        with self.output:
            IPython.display.display(p)

    def strip_and_produce_rdf_graph(self, rdf_graph: Graph):
        """
        A function that takes in an rdflib.graph.Graph object and transforms it
        into a datashader holoviews graph. Also performs the sparql query on the
        graph, which can be set via the 'sparql' parameter.
        """
        sparql = self.sparql
        qres = rdf_graph.query(sparql)
        uri_graph = Graph()
        for row in qres:
            uri_graph.add(row)

        new_netx = rdflib_to_networkx_graph(uri_graph)
        original = hv.Graph.from_networkx(new_netx, self._nx_layout, **self.graph_layout_params)
        output_graph = bundle_graph(original)
        return output_graph

    def set_options(self, output_graph):
        return output_graph.options(
            frame_width=1000,
            frame_height=1000,
            xaxis=None,
            yaxis=None,
            tools=[self.tooltip_dict[self.tooltip], "tap", "box_select"],
            inspection_policy=self.tooltip,
            node_color=self.node_color,
            edge_color=self.edge_color,
        )

    def tap_stream_subscriber(self, x, y):
        nodes_data = self.output_graph.nodes.data
        t = 0.01
        values = nodes_data[nodes_data.x.between(x - t, x + t, True)][
            nodes_data.y.between(y - t, y + t, True)
        ]
        self.selected_nodes = tuple([URIRef(_) for _ in list(values["index"])])

    def box_stream_subscriber(self, **kwargs):
        bounds = kwargs["bounds"]
        nodes_data = self.output_graph.nodes.data
        values = nodes_data[nodes_data.x.between(bounds[0], bounds[2], True)][
            nodes_data.y.between(bounds[1], bounds[3], True)
        ]
        self.selected_nodes = tuple([URIRef(_) for _ in list(values["index"])])

    @T.observe("_nx_layout", "sparql", "graph", "graph_layout_params")
    def changed_layout(self, change):
        if self.graph is None:
            self.output_graph = None
            self.display_datashader_vis(self.output_graph)
        elif len(self.graph) == 0:
            self.output_graph = None
            self.display_datashader_vis("Cannot display blank graph.")
        else:
            self.output_graph = self.strip_and_produce_rdf_graph(self.graph)
            self.tap_selection_stream = streams.Tap(source=self.output_graph)
            self.tap_selection_stream.add_subscriber(self.tap_stream_subscriber)
            self.box_selection_stream = streams.BoundsXY(source=self.output_graph)
            self.box_selection_stream.add_subscriber(self.box_stream_subscriber)
            self.final_graph = self.set_options(self.output_graph)
            self.display_datashader_vis(self.final_graph)
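# --- Usage sketch (hypothetical URL) ------------------------------------------
# Assumes the NXBase parent supplies the `graph`, `_nx_layout`,
# `graph_layout_params`, `node_color`, and `edge_color` traits used above.
from rdflib import Graph

g = Graph()
g.parse("http://www.w3.org/People/Berners-Lee/card")  # any RDF source works

viz = DatashaderVisualizer(tooltip="edges")  # hover over edges instead of nodes
viz.graph = g                                # triggers `changed_layout` and renders
viz                                          # display in a Jupyter notebook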
class Foo(jst.JSONHasTraits):
    _additional_traits = False
    x = T.Integer()
    y = T.Instance(Bar)
class ZarrOutputMixin(tl.HasTraits):
    """
    This class assumes that the node has an 'output_format' attribute
    (currently the "Lambda" Node and the "Process" Node).

    Attributes
    ----------
    zarr_file: str
        Path to the output zarr file that collects all of the computed results. This can reside on S3.
    dataset: ZarrGroup
        A handle to the zarr group pointing to the output file
    fill_output: bool, optional
        Default is False (unlike parent class). If True, will collect the output data and return it as an xarray.
    init_file_mode: str, optional
        Default is 'a'. Mode used for initializing the zarr file.
    zarr_chunks: dict
        Size of the chunks in the zarr file for each dimension
    zarr_shape: dict, optional
        Default is ``{coordinates.dims: coordinates.shape}``, using the coordinates of the eval call. This does
        not need to be specified unless the Node modifies the input coordinates (as part of a Reduce operation,
        for example). The result can be incorrect and requires care/checking by the user.
    zarr_coordinates: podpac.Coordinates, optional
        Default is None. If the node modifies the shape of the input coordinates, this allows users to set the
        coordinates in the output zarr file. This can be incorrect and requires care by the user.
    skip_existing: bool
        Default is True. If True, this will check to see if the results already exist, and if so, it will not
        submit a job for that particular coordinate evaluation. This assumes self.chunks == self.zarr_chunks.
    list_dir: bool, optional
        Default is False. If skip_existing is True, by default existing files are checked by asking for an
        'exists' call. If list_dir is True, then at the first opportunity a "list_dir" is performed on the
        directory and the results are cached.
    """

    zarr_file = tl.Unicode().tag(attr=True)
    dataset = tl.Any()
    zarr_node = NodeTrait()
    zarr_data_key = tl.Union([tl.Unicode(), tl.List()])
    fill_output = tl.Bool(False)
    init_file_mode = tl.Unicode("a").tag(attr=True)
    zarr_chunks = tl.Dict(default_value=None, allow_none=True).tag(attr=True)
    zarr_shape = tl.Dict(allow_none=True, default_value=None).tag(attr=True)
    zarr_coordinates = tl.Instance(Coordinates, allow_none=True, default_value=None).tag(attr=True)
    zarr_dtype = tl.Unicode("f4")
    skip_existing = tl.Bool(True).tag(attr=True)
    list_dir = tl.Bool(False)
    _list_dir = tl.List(allow_none=True, default_value=[])
    _shape = tl.Tuple()
    _chunks = tl.List()
    aws_client_kwargs = tl.Dict()
    aws_config_kwargs = tl.Dict()

    def eval(self, coordinates, **kwargs):
        output = kwargs.get("output")
        if self.zarr_shape is None:
            self._shape = coordinates.shape
        else:
            self._shape = tuple(self.zarr_shape.values())

        # initialize zarr file
        if self.zarr_chunks is None:
            chunks = [self.chunks[d] for d in coordinates]
        else:
            chunks = [self.zarr_chunks[d] for d in coordinates]
        self._chunks = chunks
        zf, data_key, zn = self.initialize_zarr_array(self._shape, chunks)
        self.dataset = zf
        self.zarr_data_key = data_key
        self.zarr_node = zn
        zn.keys  # touch the property so the node is fully initialized

        # eval
        _log.debug("Starting parallel eval.")
        missing_dims = [d for d in coordinates.dims if d not in self.chunks.keys()]
        if self.zarr_coordinates is not None:
            missing_dims = missing_dims + [d for d in self.zarr_coordinates.dims if d not in missing_dims]
            set_coords = merge_dims([coordinates.drop(missing_dims), self.zarr_coordinates])
        else:
            set_coords = coordinates.drop(missing_dims)
        set_coords.transpose(*coordinates.dims)

        self.set_zarr_coordinates(set_coords, data_key)
        if self.list_dir:
            dk = data_key
            if isinstance(dk, list):
                dk = dk[0]
            self._list_dir = self.zarr_node.list_dir(dk)

        output = super(ZarrOutputMixin, self).eval(coordinates, output=output)

        # fill in the coordinates; this is guaranteed to be correct even if the user messed up
        if output is not None:
            self.set_zarr_coordinates(Coordinates.from_xarray(output.coords), data_key)
        else:
            return zf

        return output

    def set_zarr_coordinates(self, coordinates, data_key):
        # Fill in metadata
        for dk in data_key:
            self.dataset[dk].attrs["_ARRAY_DIMENSIONS"] = coordinates.dims
        for d in coordinates.dims:
            # TODO ADD UNITS AND TIME DECODING INFORMATION
            self.dataset.create_dataset(d, shape=coordinates[d].size, overwrite=True)
            self.dataset[d][:] = coordinates[d].coordinates

    def initialize_zarr_array(self, shape, chunks):
        _log.debug("Creating Zarr file.")
        zn = Zarr(source=self.zarr_file, file_mode=self.init_file_mode, aws_client_kwargs=self.aws_client_kwargs)
        if self.source.output or getattr(self.source, "data_key", None):
            data_key = self.source.output
            if data_key is None:
                data_key = self.source.data_key
            if not isinstance(data_key, list):
                data_key = [data_key]
            elif self.source.outputs:
                # If someone restricted the outputs for this node, we need to know
                data_key = [dk for dk in data_key if dk in self.source.outputs]
        elif self.source.outputs:
            data_key = self.source.outputs
        else:
            data_key = ["data"]

        zf = zarr.open(zn._get_store(), mode=self.init_file_mode)

        # Initialize the output zarr arrays
        for dk in data_key:
            try:
                arr = zf.create_dataset(
                    dk,
                    shape=shape,
                    chunks=chunks,
                    fill_value=np.nan,
                    dtype=self.zarr_dtype,
                    overwrite=not self.skip_existing,
                )
            except ValueError:
                pass  # Dataset already exists

        # Recompute any cached properties
        zn = Zarr(source=self.zarr_file, file_mode=self.init_file_mode, aws_client_kwargs=self.aws_client_kwargs)
        return zf, data_key, zn

    def eval_source(self, coordinates, coordinates_index, out, i, source=None):
        if source is None:
            source = self.source

        if self.skip_existing:  # This section allows previously computed chunks to be skipped
            dk = self.zarr_data_key
            if isinstance(dk, list):
                dk = dk[0]
            try:
                exists = self.zarr_node.chunk_exists(
                    coordinates_index, data_key=dk, list_dir=self._list_dir, chunks=self._chunks
                )
            except ValueError:
                # This was needed in cases where a poor internet connection caused read errors
                exists = False
            if exists:
                _log.info("Skipping {} (already exists)".format(i))
                return out, coordinates_index

        # Make a copy to prevent any possibility of memory corruption
        source = Node.from_definition(source.definition)

        _log.debug("Creating output format.")
        output = dict(
            format="zarr_part",
            format_kwargs=dict(
                part=[[s.start, min(s.stop, self._shape[i]), s.step] for i, s in enumerate(coordinates_index)],
                source=self.zarr_file,
                mode="a",
            ),
        )
        _log.debug("Finished creating output format.")

        if source.has_trait("output_format"):
            source.set_trait("output_format", output)
        _log.debug("output: {}, coordinates.shape: {}".format(output, coordinates.shape))
        _log.debug("Evaluating node.")

        o, slc = super(ZarrOutputMixin, self).eval_source(coordinates, coordinates_index, out, i, source)

        if not source.has_trait("output_format"):
            o.to_format(output["format"], **output["format_kwargs"])
        return o, slc
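# --- Sketch: the zarr layout produced by `initialize_zarr_array` --------------
# Standalone illustration using the zarr package directly (hypothetical file
# name, dims, and sizes; not part of the mixin).
import numpy as np
import zarr

zf = zarr.open("output.zarr", mode="a")  # init_file_mode default
arr = zf.create_dataset(
    "data",             # default data_key when the source declares none
    shape=(365, 1800),  # one entry per (time, lon) coordinate
    chunks=(10, 1800),  # one chunk per parallel evaluation
    fill_value=np.nan,
    dtype="f4",         # zarr_dtype default
    overwrite=False,    # i.e. overwrite=not skip_existing
)
# xarray's zarr convention, as written by `set_zarr_coordinates`:
arr.attrs["_ARRAY_DIMENSIONS"] = ["time", "lon"]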
class StackedCoordinates(BaseCoordinates):
    """
    Stacked coordinates.

    StackedCoordinates contain coordinates from two or more different dimensions that are stacked together to form
    a list of points (rather than a grid). The underlying coordinates values are :class:`Coordinates1d` objects of
    equal size. The name for the stacked coordinates combines the underlying dimensions with underscores, e.g.
    ``'lat_lon'`` or ``'lat_lon_time'``.

    When creating :class:`Coordinates`, podpac automatically detects StackedCoordinates. The following Coordinates
    contain 3 stacked lat-lon coordinates and 2 time coordinates in a 3 x 2 grid::

        >>> lat = [0, 1, 2]
        >>> lon = [10, 20, 30]
        >>> time = ['2018-01-01', '2018-01-02']
        >>> podpac.Coordinates([[lat, lon], time], dims=['lat_lon', 'time'])
        Coordinates
            lat_lon[lat]: ArrayCoordinates1d(lat): Bounds[0.0, 2.0], N[3]
            lat_lon[lon]: ArrayCoordinates1d(lon): Bounds[10.0, 30.0], N[3]
            time: ArrayCoordinates1d(time): Bounds[2018-01-01, 2018-01-02], N[2]

    For convenience, you can also create uniformly-spaced stacked coordinates using :class:`clinspace`::

        >>> lat_lon = podpac.clinspace((0, 10), (2, 30), 3)
        >>> time = ['2018-01-01', '2018-01-02']
        >>> podpac.Coordinates([lat_lon, time], dims=['lat_lon', 'time'])
        Coordinates
            lat_lon[lat]: ArrayCoordinates1d(lat): Bounds[0.0, 2.0], N[3]
            lat_lon[lon]: ArrayCoordinates1d(lon): Bounds[10.0, 30.0], N[3]
            time: ArrayCoordinates1d(time): Bounds[2018-01-01, 2018-01-02], N[2]

    Parameters
    ----------
    dims : tuple
        Tuple of dimension names.
    name : str
        Stacked dimension name.
    coords : dict-like
        xarray coordinates (container of coordinate arrays)
    coordinates : pandas.MultiIndex
        MultiIndex of stacked coordinates values.
    """

    _coords = tl.List(trait=tl.Instance(Coordinates1d), read_only=True)

    def __init__(self, coords, name=None, dims=None):
        """
        Initialize a multidimensional coords object.

        Parameters
        ----------
        coords : list, :class:`StackedCoordinates`
            Coordinate values in a list, or a StackedCoordinates object to copy.

        See Also
        --------
        clinspace, crange
        """
        if not isinstance(coords, (list, tuple)):
            raise TypeError("Unrecognized coords type '%s'" % type(coords))

        if len(coords) < 2:
            raise ValueError("Stacked coords must have at least 2 coords, got %d" % len(coords))

        # coerce
        coords = tuple(c if isinstance(c, Coordinates1d) else ArrayCoordinates1d(c) for c in coords)

        # set coords
        self.set_trait("_coords", coords)

        # propagate properties
        if dims is not None and name is not None:
            raise TypeError("StackedCoordinates expected 'dims' or 'name', not both")
        if dims is not None:
            self._set_dims(dims)
        if name is not None:
            self._set_name(name)

        # finalize
        super(StackedCoordinates, self).__init__()

    @tl.validate("_coords")
    def _validate_coords(self, d):
        val = d["value"]

        # check sizes
        shape = val[0].shape
        for c in val[1:]:
            if c.shape != shape:
                raise ValueError("Shape mismatch in stacked coords %s != %s" % (c.shape, shape))

        # check dims
        dims = [c.name for c in val]
        for i, dim in enumerate(dims):
            if dim is not None and dim in dims[:i]:
                raise ValueError("Duplicate dimension '%s' in stacked coords" % dim)

        return val

    def _set_name(self, value):
        dims = value.split("_")

        # check length
        if len(dims) != len(self._coords):
            raise ValueError("Invalid name '%s' for StackedCoordinates with length %d" % (value, len(self._coords)))

        self._set_dims(dims)

    def _set_dims(self, dims):
        # check size
        if len(dims) != len(self._coords):
            raise ValueError("Invalid dims '%s' for StackedCoordinates with length %d" % (dims, len(self._coords)))

        for i, dim in enumerate(dims):
            if dim is not None and dim in dims[:i]:
                raise ValueError("Duplicate dimension '%s' in dims" % dim)

        # set names, checking for duplicates
        for i, (c, dim) in enumerate(zip(self._coords, dims)):
            if dim is None:
                continue
            c._set_name(dim)

    # ------------------------------------------------------------------------------------------------------------
    # Alternate constructors
    # ------------------------------------------------------------------------------------------------------------

    @classmethod
    def from_xarray(cls, x, **kwargs):
        """
        Create StackedCoordinates from named xarray coordinates.

        Arguments
        ---------
        x : xarray.DataArray
            Named DataArray of the coordinate values

        Returns
        -------
        :class:`StackedCoordinates`
            stacked coordinates
        """
        dims = x.dims[0].split("_")
        cs = [x[dim].data for dim in dims]
        return cls(cs, dims=dims, **kwargs)

    @classmethod
    def from_definition(cls, d):
        """
        Create StackedCoordinates from a stacked coordinates definition.

        Arguments
        ---------
        d : list
            stacked coordinates definition

        Returns
        -------
        :class:`StackedCoordinates`
            stacked coordinates object

        See Also
        --------
        definition
        """
        coords = []
        for elem in d:
            if "start" in elem and "stop" in elem and ("step" in elem or "size" in elem):
                c = UniformCoordinates1d.from_definition(elem)
            elif "values" in elem:
                c = ArrayCoordinates1d.from_definition(elem)
            else:
                raise ValueError("Could not parse coordinates definition with keys %s" % elem.keys())

            coords.append(c)

        return cls(coords)

    # ------------------------------------------------------------------------------------------------------------
    # standard methods, list-like
    # ------------------------------------------------------------------------------------------------------------

    def __repr__(self):
        rep = str(self.__class__.__name__)
        for c in self._coords:
            rep += "\n\t%s[%s]: %s" % (self.name, c.name or "?", c)
        return rep

    def __iter__(self):
        return iter(self._coords)

    def __len__(self):
        return len(self._coords)

    def __getitem__(self, index):
        if isinstance(index, string_types):
            if index not in self.dims:
                raise KeyError("Dimension '%s' not found in dims %s" % (index, self.dims))

            return self._coords[self.dims.index(index)]

        else:
            return StackedCoordinates([c[index] for c in self._coords])

    def __setitem__(self, dim, c):
        if dim not in self.dims:
            raise KeyError("Cannot set dimension '%s' in StackedCoordinates %s" % (dim, self.dims))

        # try to cast to ArrayCoordinates1d
        if not isinstance(c, Coordinates1d):
            c = ArrayCoordinates1d(c)

        if c.name is None:
            c.name = dim

        # replace the element of the coords list
        idx = self.dims.index(dim)
        coords = list(self._coords)
        coords[idx] = c

        # set (and check) new coords list
        self.set_trait("_coords", coords)

    def __contains__(self, item):
        try:
            item = np.array([make_coord_value(value) for value in item])
        except Exception:
            return False

        if len(item) != len(self._coords):
            return False

        if any(val not in c for val, c in zip(item, self._coords)):
            return False

        return (self.flatten().coordinates == item).all(axis=1).any()

    def __eq__(self, other):
        if not isinstance(other, StackedCoordinates):
            return False

        # shortcuts
        if self.dims != other.dims:
            return False

        if self.shape != other.shape:
            return False

        # full check of underlying coordinates
        if self._coords != other._coords:
            return False

        return True

    # ------------------------------------------------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------------------------------------------------

    @property
    def dims(self):
        """:tuple: Tuple of dimension names."""
        return tuple(c.name for c in self._coords)

    @property
    def ndim(self):
        """:int: coordinates array ndim."""
        return self._coords[0].ndim

    @property
    def name(self):
        """:str: Stacked dimension name.

        Stacked dimension names are the individual `dims` joined by an underscore.
        """
        if any(self.dims):
            return "_".join(dim or "?" for dim in self.dims)

    @property
    def size(self):
        """:int: Number of stacked coordinates."""
        return self._coords[0].size

    @property
    def shape(self):
        """:tuple: Shape of the stacked coordinates."""
        return self._coords[0].shape

    @property
    def bounds(self):
        """:dict: Dictionary of (low, high) coordinates bounds in each dimension"""
        if None in self.dims:
            raise ValueError("Cannot get bounds for StackedCoordinates with un-named dimensions")
        return {dim: self[dim].bounds for dim in self.udims}

    @property
    def coordinates(self):
        dtypes = [c.dtype for c in self._coords]
        if len(set(dtypes)) == 1:
            dtype = dtypes[0]
        else:
            dtype = object
        return np.dstack([c.coordinates.astype(dtype) for c in self._coords]).squeeze()

    @property
    def xcoords(self):
        """:dict-like: xarray coordinates (container of coordinate arrays)"""
        if None in self.dims:
            raise ValueError("Cannot get xcoords for StackedCoordinates with un-named dimensions")

        if self.ndim == 1:
            # use a multi-index so that we can use DataArray.sel easily
            coords = pd.MultiIndex.from_arrays([np.array(c.coordinates) for c in self._coords], names=self.dims)
            xcoords = {self.name: coords}
        else:
            # fall-back for shaped coordinates
            xcoords = {c.name: (self.xdims, c.coordinates) for c in self._coords}
        return xcoords

    @property
    def definition(self):
        """:list: Serializable stacked coordinates definition."""
        return [c.definition for c in self._coords]

    @property
    def full_definition(self):
        """:list: Serializable stacked coordinates definition, containing all properties. For internal use."""
        return [c.full_definition for c in self._coords]

    # ------------------------------------------------------------------------------------------------------------
    # Methods
    # ------------------------------------------------------------------------------------------------------------

    def copy(self):
        """
        Make a copy of the stacked coordinates.

        Returns
        -------
        :class:`StackedCoordinates`
            Copy of the stacked coordinates.
        """
        return StackedCoordinates(self._coords)

    def unique(self, return_index=False):
        """
        Remove duplicate stacked coordinate values.

        Arguments
        ---------
        return_index : bool, optional
            If True, return index for the unique coordinates in addition to the coordinates. Default False.

        Returns
        -------
        unique : :class:`StackedCoordinates`
            New StackedCoordinates object with unique, sorted, flattened coordinate values.
        unique_index : array of indices
            Index of the unique coordinates, only if ``return_index`` is True.
        """
        flat = self.flatten()
        a, I = np.unique(flat.coordinates, axis=0, return_index=True)
        if return_index:
            return flat[I], I
        else:
            return flat[I]

    def get_area_bounds(self, boundary):
        """Get coordinate area bounds, including boundary information, for each unstacked dimension.

        Arguments
        ---------
        boundary : dict
            dictionary of boundary offsets for each unstacked dimension. Point dimensions can be omitted.

        Returns
        -------
        area_bounds : dict
            Dictionary of (low, high) coordinates area_bounds in each unstacked dimension
        """
        if None in self.dims:
            raise ValueError("Cannot get area_bounds for StackedCoordinates with un-named dimensions")
        return {dim: self[dim].get_area_bounds(boundary.get(dim)) for dim in self.dims}

    def select(self, bounds, outer=False, return_index=False):
        """
        Get the coordinate values that are within the given bounds in all dimensions.

        *Note: you should not generally need to call this method directly.*

        Parameters
        ----------
        bounds : dict
            dictionary of dim -> (low, high) selection bounds
        outer : bool, optional
            If True, do *outer* selections. Default False.
        return_index : bool, optional
            If True, return index for the selections in addition to coordinates. Default False.

        Returns
        -------
        selection : :class:`StackedCoordinates`
            StackedCoordinates object consisting of the selection in all dimensions.
        selection_index : slice, boolean array
            Slice or index for the selected coordinates, only if ``return_index`` is True.
        """
        # logical AND of the selection in each dimension
        indices = [c.select(bounds, outer=outer, return_index=True)[1] for c in self._coords]
        index = self._and_indices(indices)

        if return_index:
            return self[index], index
        else:
            return self[index]

    def _and_indices(self, indices):
        if all(isinstance(index, slice) for index in indices):
            index = slice(
                max(index.start or 0 for index in indices), min(index.stop or self.size for index in indices)
            )

            # for consistency
            if index.start == 0 and index.stop == self.size:
                index = slice(None, None)

        else:
            # convert any slices to boolean array
            for i, index in enumerate(indices):
                if isinstance(index, slice):
                    indices[i] = np.zeros(self.shape, dtype=bool)
                    indices[i][index] = True

            # logical and
            index = np.logical_and.reduce(indices)

            # for consistency
            if np.all(index):
                index = slice(None, None)

        return index

    def _transform(self, transformer):
        coords = [c.copy() for c in self._coords]

        if "lat" in self.dims and "lon" in self.dims and "alt" in self.dims:
            ilat = self.dims.index("lat")
            ilon = self.dims.index("lon")
            ialt = self.dims.index("alt")

            lat = coords[ilat]
            lon = coords[ilon]
            alt = coords[ialt]
            tlon, tlat, talt = transformer.transform(lon.coordinates, lat.coordinates, alt.coordinates)

            coords[ilat] = ArrayCoordinates1d(tlat, "lat").simplify()
            coords[ilon] = ArrayCoordinates1d(tlon, "lon").simplify()
            coords[ialt] = ArrayCoordinates1d(talt, "alt").simplify()

        elif "lat" in self.dims and "lon" in self.dims:
            ilat = self.dims.index("lat")
            ilon = self.dims.index("lon")

            lat = coords[ilat]
            lon = coords[ilon]
            tlon, tlat = transformer.transform(lon.coordinates, lat.coordinates)

            if (
                self.ndim == 2
                and all(np.allclose(a, tlat[:, 0]) for a in tlat.T)
                and all(np.allclose(a, tlon[0]) for a in tlon)
            ):
                coords[ilat] = ArrayCoordinates1d(tlat[:, 0], name="lat").simplify()
                coords[ilon] = ArrayCoordinates1d(tlon[0], name="lon").simplify()
                return coords

            coords[ilat] = ArrayCoordinates1d(tlat, "lat").simplify()
            coords[ilon] = ArrayCoordinates1d(tlon, "lon").simplify()

        elif "alt" in self.dims:
            ialt = self.dims.index("alt")
            alt = coords[ialt]
            _, _, talt = transformer.transform(np.zeros(self.size), np.zeros(self.size), alt.coordinates)
            coords[ialt] = ArrayCoordinates1d(talt, "alt").simplify()

        return StackedCoordinates(coords)

    def transpose(self, *dims, **kwargs):
        """
        Transpose (re-order) the dimensions of the StackedCoordinates.

        Parameters
        ----------
        dim_1, dim_2, ... : str, optional
            Reorder dims to this order. By default, reverse the dims.
        in_place : boolean, optional
            If True, transpose the dimensions in-place.
            Otherwise (default), return a new, transposed Coordinates object.

        Returns
        -------
        transposed : :class:`StackedCoordinates`
            The transposed StackedCoordinates object.
        """
        in_place = kwargs.get("in_place", False)

        if len(dims) == 0:
            dims = list(self.dims[::-1])

        if set(dims) != set(self.dims):
            raise ValueError("Invalid transpose dimensions, input %s does not match dims %s" % (dims, self.dims))

        coordinates = [self._coords[self.dims.index(dim)] for dim in dims]

        if in_place:
            self.set_trait("_coords", coordinates)
            return self
        else:
            return StackedCoordinates(coordinates)

    def flatten(self):
        return StackedCoordinates([c.flatten() for c in self._coords])

    def reshape(self, newshape):
        return StackedCoordinates([c.reshape(newshape) for c in self._coords])

    def issubset(self, other):
        """Report whether other coordinates contains these coordinates.

        Arguments
        ---------
        other : Coordinates, StackedCoordinates
            Other coordinates to check

        Returns
        -------
        issubset : bool
            True if these coordinates are a subset of the other coordinates.
        """
        from podpac.core.coordinates import Coordinates

        if not isinstance(other, (Coordinates, StackedCoordinates)):
            raise TypeError(
                "StackedCoordinates issubset expected Coordinates or StackedCoordinates, not '%s'" % type(other)
            )

        if isinstance(other, StackedCoordinates):
            if set(self.dims) != set(other.dims):
                return False

            mine = self.flatten().coordinates
            other = other.flatten().transpose(*self.dims).coordinates
            return set(map(tuple, mine)).issubset(map(tuple, other))

        elif isinstance(other, Coordinates):
            if not all(dim in other.udims for dim in self.dims):
                return False

            acs = []
            ocs = []
            for coords in other.values():
                dims = [dim for dim in coords.dims if dim in self.dims]
                if len(dims) == 0:
                    continue

                elif len(dims) == 1:
                    acs.append(self[dims[0]])
                    if isinstance(coords, Coordinates1d):
                        ocs.append(coords)
                    elif isinstance(coords, StackedCoordinates):
                        ocs.append(coords[dims[0]])

                elif len(dims) > 1:
                    acs.append(StackedCoordinates([self[dim] for dim in dims]))
                    if isinstance(coords, StackedCoordinates):
                        ocs.append(StackedCoordinates([coords[dim] for dim in dims]))

            return all(a.issubset(o) for a, o in zip(acs, ocs))
class ScipyGrid(ScipyPoint):
    """Scipy Grid Interpolation

    Attributes
    ----------
    {interpolator_attributes}
    """

    methods_supported = ["nearest", "bilinear", "cubic_spline", "spline_2", "spline_3", "spline_4"]
    method = tl.Unicode(default_value="nearest")

    # TODO: implement these parameters for the method 'nearest'
    spatial_tolerance = tl.Float(default_value=np.inf)
    time_tolerance = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)], default_value=None)

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_interpolate(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_interpolate}
        """
        # TODO: make this so we don't need to specify lat and lon together
        # or at least throw a warning
        if (
            "lat" in udims
            and "lon" in udims
            and self._dim_in(["lat", "lon"], source_coordinates)
            and self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True)
        ):
            return ["lat", "lon"]

        # otherwise return no supported dims
        return tuple()

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, output_data):
        """
        {interpolator_interpolate}
        """
        if self._dim_in(["lat", "lon"], eval_coordinates):
            return self._interpolate_irregular_grid(
                udims, source_coordinates, source_data, eval_coordinates, output_data, grid=True
            )

        elif self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True):
            eval_coordinates_us = eval_coordinates.unstack()
            return self._interpolate_irregular_grid(
                udims, source_coordinates, source_data, eval_coordinates_us, output_data, grid=False
            )

    def _interpolate_irregular_grid(
        self, udims, source_coordinates, source_data, eval_coordinates, output_data, grid=True
    ):
        if len(source_data.dims) > 2:
            keep_dims = ["lat", "lon"]
            return self._loop_helper(
                self._interpolate_irregular_grid,
                keep_dims,
                udims,
                source_coordinates,
                source_data,
                eval_coordinates,
                output_data,
                grid=grid,
            )

        s = []
        if source_coordinates["lat"].is_descending:
            lat = source_coordinates["lat"].coordinates[::-1]
            s.append(slice(None, None, -1))
        else:
            lat = source_coordinates["lat"].coordinates
            s.append(slice(None, None))
        if source_coordinates["lon"].is_descending:
            lon = source_coordinates["lon"].coordinates[::-1]
            s.append(slice(None, None, -1))
        else:
            lon = source_coordinates["lon"].coordinates
            s.append(slice(None, None))

        data = source_data.data[tuple(s)]

        # remove nan's
        I, J = np.isfinite(lat), np.isfinite(lon)
        coords_i = lat[I], lon[J]
        coords_i_dst = [eval_coordinates["lon"].coordinates, eval_coordinates["lat"].coordinates]

        # swap order in case the datasource uses lon,lat ordering instead of lat,lon
        if source_coordinates.dims.index("lat") > source_coordinates.dims.index("lon"):
            I, J = J, I
            coords_i = coords_i[::-1]
            coords_i_dst = coords_i_dst[::-1]
        data = data[I, :][:, J]

        if self.method in ["bilinear", "nearest"]:
            f = RegularGridInterpolator(
                coords_i, data, method=self.method.replace("bi", ""), bounds_error=False, fill_value=np.nan
            )
            if grid:
                x, y = np.meshgrid(*coords_i_dst)
            else:
                x, y = coords_i_dst
            output_data.data[:] = f((y.ravel(), x.ravel())).reshape(output_data.shape)

        # TODO: what methods is 'spline' associated with?
        elif "spline" in self.method:
            if self.method == "cubic_spline":
                order = 3
            else:
                # TODO: make this a parameter
                order = int(self.method.split("_")[-1])

            f = RectBivariateSpline(coords_i[0], coords_i[1], data, kx=max(1, order), ky=max(1, order))
            output_data.data[:] = f(coords_i_dst[1], coords_i_dst[0], grid=grid).reshape(output_data.shape)

        return output_data
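# --- Sketch: the bilinear path in `_interpolate_irregular_grid` ---------------
# Standalone scipy illustration of how the gridded source is resampled onto the
# evaluation grid (synthetic data; mirrors the RegularGridInterpolator call above).
import numpy as np
from scipy.interpolate import RegularGridInterpolator

src_lat = np.linspace(0.0, 1.0, 5)  # ascending, nan-free source coordinates
src_lon = np.linspace(0.0, 1.0, 4)
data = np.random.random((5, 4))     # source data in (lat, lon) order

f = RegularGridInterpolator(
    (src_lat, src_lon), data, method="linear", bounds_error=False, fill_value=np.nan
)

dst_lon = np.linspace(0.0, 1.0, 10)
dst_lat = np.linspace(0.0, 1.0, 8)
x, y = np.meshgrid(dst_lon, dst_lat)             # the grid=True branch
vals = f((y.ravel(), x.ravel())).reshape(8, 10)  # (lat, lon) output grid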