Code Example #1
class ScipyPoint(Interpolator):
    """Scipy Point Interpolation

    Attributes
    ----------
    {interpolator_attributes}
    """

    methods_supported = ["nearest"]
    method = tl.Unicode(default_value="nearest")
    dims_supported = ["lat", "lon"]

    # TODO: implement these parameters for the method 'nearest'
    spatial_tolerance = tl.Float(default_value=np.inf)
    time_tolerance = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)])

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_interpolate(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_interpolate}
        """

        # TODO: make this so we don't need to specify lat and lon together
        # or at least throw a warning
        if (
            "lat" in udims
            and "lon" in udims
            and not self._dim_in(["lat", "lon"], source_coordinates)
            and self._dim_in(["lat", "lon"], source_coordinates, unstacked=True)
            and self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True)
        ):
            return ("lat", "lon")

        # otherwise return no supported dims
        return ()

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, output_data):
        """
        {interpolator_interpolate}
        """

        order = "lat_lon" if "lat_lon" in source_coordinates.dims else "lon_lat"

        # calculate tolerance
        if isinstance(eval_coordinates["lat"], UniformCoordinates1d):
            dlat = eval_coordinates["lat"].step
        else:
            dlat = (eval_coordinates["lat"].bounds[1] - eval_coordinates["lat"].bounds[0]) / (
                eval_coordinates["lat"].size - 1
            )

        if isinstance(eval_coordinates["lon"], UniformCoordinates1d):
            dlon = eval_coordinates["lon"].step
        else:
            dlon = (eval_coordinates["lon"].bounds[1] - eval_coordinates["lon"].bounds[0]) / (
                eval_coordinates["lon"].size - 1
            )

        # use 8x the diagonal grid spacing of the eval coordinates as the tolerance
        tol = np.linalg.norm([dlat, dlon]) * 8

        if self._dim_in(["lat", "lon"], eval_coordinates):
            pts = np.stack([source_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1)
            if order == "lat_lon":
                pts = pts[:, ::-1]
            pts = KDTree(pts)
            lon, lat = np.meshgrid(eval_coordinates["lon"].coordinates, eval_coordinates["lat"].coordinates)
            dist, ind = pts.query(np.stack((lon.ravel(), lat.ravel()), axis=1), distance_upper_bound=tol)
            mask = ind == source_data[order].size
            ind[mask] = 0  # This is a hack to make the select on the next line work
            # (the masked values are set to NaN on the following line)
            vals = source_data[{order: ind}]
            vals[mask] = np.nan
            # make sure 'lat_lon' or 'lon_lat' is the first dimension
            dims = [dim for dim in source_data.dims if dim != order]
            vals = vals.transpose(order, *dims).data
            shape = vals.shape
            coords = [eval_coordinates["lat"].coordinates, eval_coordinates["lon"].coordinates]
            coords += [source_coordinates[d].coordinates for d in dims]
            vals = vals.reshape(eval_coordinates["lat"].size, eval_coordinates["lon"].size, *shape[1:])
            vals = UnitsDataArray(vals, coords=coords, dims=["lat", "lon"] + dims)
            # and transpose back to the destination order
            output_data.data[:] = vals.transpose(*output_data.dims).data[:]

            return output_data

        elif self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True):
            dst_order = "lat_lon" if "lat_lon" in eval_coordinates.dims else "lon_lat"
            src_stacked = np.stack(
                [source_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1
            )
            new_stacked = np.stack(
                [eval_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1
            )
            pts = KDTree(src_stacked)
            dist, ind = pts.query(new_stacked, distance_upper_bound=tol)
            mask = ind == source_data[order].size
            ind[mask] = 0
            vals = source_data[{order: ind}]
            vals[{order: mask}] = np.nan
            dims = list(output_data.dims)
            dims[dims.index(dst_order)] = order
            output_data.data[:] = vals.transpose(*dims).data[:]

            return output_data
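
A minimal, self-contained sketch of the KDTree lookup used above: scipy's
KDTree.query with distance_upper_bound returns an index equal to the number of
source points for queries with no neighbor inside the tolerance, which is why
the code masks those entries and overwrites them with NaN.

import numpy as np
from scipy.spatial import cKDTree as KDTree

src_pts = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])  # (lon, lat) pairs
src_vals = np.array([10.0, 20.0, 30.0])

tree = KDTree(src_pts)
query_pts = np.array([[0.1, 0.1], [5.0, 5.0]])  # second point has no match
dist, ind = tree.query(query_pts, distance_upper_bound=0.5)

mask = ind == len(src_pts)  # unmatched queries come back with index == n
ind[mask] = 0               # same hack as above: make the indexing valid
vals = src_vals[ind]
vals[mask] = np.nan         # then overwrite the placeholder values
print(vals)                 # [10. nan]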
Code Example #2
class GroupByTransformer(Transformer):
    '''The GroupByTransformer creates aggregations via the groupby operation, which are
    joined to a DataFrame. This is useful for creating aggregate features.

    Example:

    >>> import vaex
    >>> import vaex.ml
    >>> df_train = vaex.from_arrays(x=['dog', 'dog', 'dog', 'cat', 'cat'], y=[2, 3, 4, 10, 20])
    >>> df_test = vaex.from_arrays(x=['dog', 'cat', 'dog', 'mouse'], y=[5, 5, 5, 5])
    >>> group_trans = vaex.ml.GroupByTransformer(by='x', agg={'mean_y': vaex.agg.mean('y')}, rsuffix='_agg')
    >>> group_trans.fit_transform(df_train)
      #  x      y  x_agg      mean_y
      0  dog    2  dog             3
      1  dog    3  dog             3
      2  dog    4  dog             3
      3  cat   10  cat            15
      4  cat   20  cat            15
    >>> group_trans.transform(df_test)
      #  x        y  x_agg    mean_y
      0  dog      5  dog      3.0
      1  cat      5  cat      15.0
      2  dog      5  dog      3.0
      3  mouse    5  --       --
    '''

    snake_name = 'groupby_transformer'
    by = traitlets.Unicode(allow_none=False,
                           help='The feature on which to do the grouping.')
    agg = traitlets.Dict(
        help=
        'Dict where the keys are feature names and the values are vaex.agg objects.'
    )
    rprefix = traitlets.Unicode(
        default_value='',
        help=
        'Prefix for the names of the aggregate features in case of a collision.'
    )
    rsuffix = traitlets.Unicode(
        default_value='',
        help=
        'Suffix for the names of the aggregate features in case of a collision.'
    )
    df_group_ = traitlets.Instance(klass=vaex.dataframe.DataFrame,
                                   allow_none=True)

    def fit(self, df):
        '''
        Fit GroupByTransformer to the DataFrame.

        :param df: A vaex DataFrame.
        '''

        if not self.agg:
            raise ValueError(
                'You have to specify a dict for the `agg` keyword.')
        if len(self.by) == 0:
            raise ValueError('Please specify a value for the `by` keyword.')
        self.df_group_ = df.groupby(by=self.by, agg=self.agg)

    def transform(self, df):
        '''
        Transform a DataFrame with a fitted GroupByTransformer.

        :param df: A vaex DataFrame.

        :returns copy: a shallow copy of the DataFrame that includes the aggregated features.
        :rtype: DataFrame
        '''

        df = df.copy()
        # We effectively want to do a join, but since the grouped DataFrame is not
        # part of the state, a join would not be state-transferable; instead we
        # implement the lookup with map:
        # df = df.join(other=self.df_group_, on=self.by, how='left', rprefix=self.rprefix, rsuffix=self.rsuffix)
        key_values = self.df_group_[self.by].tolist()
        for name in self.df_group_.get_column_names():
            if name == self.by:
                continue  # we don't need to include the column we group/join on
            mapper = dict(zip(key_values, self.df_group_[name].values))
            join_name = name
            if join_name in df:
                join_name = self.rprefix + join_name + self.rsuffix
            df[join_name] = df[self.by].map(mapper, allow_missing=True)
        return df
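
A short usage sketch of the fit/transform split, mirroring the docstring
example above: fit computes the per-group aggregates, and transform maps them
onto any DataFrame, including groups never seen during fitting.

import vaex
import vaex.ml

df_train = vaex.from_arrays(x=['dog', 'dog', 'cat'], y=[2, 4, 10])
trans = vaex.ml.GroupByTransformer(by='x', agg={'mean_y': vaex.agg.mean('y')})
trans.fit(df_train)              # builds df_group_ via groupby
df_new = vaex.from_arrays(x=['cat', 'mouse'], y=[0, 0])
print(trans.transform(df_new))   # 'mouse' gets a missing mean_y via map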
Code Example #3
class LayerPicker(ipywidgets.HBox):
    """
    Widget to pick a WorkflowsLayer from a map

    In subclasses, set `_attr` to the trait name on WorkflowsLayer that you want mirrored into
    the `value` trait of this class.

    Attributes
    ----------
    value: ImageCollection, None
        The parametrized ImageCollection of the currently-selected layer.
    """

    value = traitlets.Instance(klass=ImageCollection,
                               allow_none=True,
                               read_only=True)
    _attr = "value"

    def __init__(
        self,
        map=None,
        default_layer: Optional[WorkflowsLayer] = None,
        hide_deps_of: Optional[Proxytype] = None,
        **kwargs,
    ):
        """
        Construct a LayerPicker widget for a map.

        Parameters
        ----------
        map: ipyleaflet.Map
            The map instance to pick from. Defaults to `wf.map`.
        default_layer: WorkflowsLayer
            The layer instance to have selected by default
        hide_deps_of: Proxytype
            Hide any layers from the dropdown that have this object in their ``.params``.
            Mainly used by the Picker parameter widget to hide its own layer from the dropdown,
            avoiding graft cycles.
        """
        super().__init__(**kwargs)
        if map is None:
            # use wf.map as default
            from . import map

        # awkwardly handle MapApp without circularly importing it for an isinstance check
        try:
            map = map.map
        except AttributeError:
            pass

        self._map = map
        self._hide_deps_of = hide_deps_of
        self._dropdown = ipywidgets.Dropdown(equals=operator.is_)

        type_ = type(self).value.klass
        if default_layer is not None:
            if not isinstance(default_layer, WorkflowsLayer):
                raise TypeError(
                    f"Default values for an {type(self).__name__} can only be WorkflowsLayer instances "
                    f"(the layer object returned by `.visualize`), not {default_layer!r}."
                    "Also note that this default value won't be synced when publishing."
                )
            value = getattr(default_layer, self._attr)
            if not isinstance(value, type_):
                raise TypeError(
                    f"Expected a default layer visualizing an {type_.__name__}, not an {type(value).__name__}. "
                    "Pick a different layer, or pick a different type for this widget, or remove a "
                    "reduction operation (like `.mosaic()`, `.mean('images')`) from the code that "
                    f"produces the layer {default_layer.name!r}."
                )
            self.set_trait("value", value)
            default_layer.observe(self._picked_layer_value_changes, self._attr)
        self._picked_layer = default_layer

        map.observe(self._update_options, "layers")
        self._dropdown.observe(self._layer_picked, "value")

        self.children = [self._dropdown]
        self._setting_options = False

        self._update_options({})

    def _update_options(self, change):
        type_ = type(self).value.klass
        options = [(lyr.name, lyr) for lyr in reversed(self._map.layers)
                   if isinstance(lyr, WorkflowsLayer)
                   and isinstance(getattr(lyr, self._attr), type_) and all(
                       p is not self._hide_deps_of
                       for p in lyr.xyz_obj.params)]

        # when changing options, ipywidgets always just picks the first option.
        # this is infuriatingly difficult to work around, so we set our own flag to ignore
        # changes while this is happening.
        self._setting_options = True
        self._dropdown.options = options
        self._setting_options = False

        try:
            self._dropdown.value = self._picked_layer
        except traitlets.TraitError:
            # the previously-picked layer doesn't exist anymore;
            # we'd rather just have no value in that case
            self._dropdown.value = None
            self._picked_layer = None
            self.set_trait("value", None)

    def _layer_picked(self, change):
        new_layer = change["new"]
        if self._setting_options or new_layer is self._picked_layer:
            return

        if self._picked_layer is not None:
            self._picked_layer.unobserve(self._picked_layer_value_changes,
                                         self._attr)

        if new_layer is None:
            self._picked_layer = None
            self.set_trait("value", None)
        else:
            new_layer.observe(self._picked_layer_value_changes, self._attr)
            self._picked_layer = new_layer
            self.set_trait("value", getattr(new_layer, self._attr))

    def unlink(self):
        self._map.unobserve("layers", self._update_options)
        self._dropdown.unobserve("value", self._layer_picked)
        if self._picked_layer is not None:
            self._picked_layer.unobserve(self._picked_layer_value_changes,
                                         self._attr)
        self._picked_layer = None

    def _picked_layer_value_changes(self, change):
        self.set_trait("value", change["new"])

    def _ipython_display_(self):
        super()._ipython_display_()
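
A hedged sketch of the subclassing pattern described in the docstring: set
_attr to the WorkflowsLayer trait to mirror, and redeclare value with a
matching Instance trait (the base class reads type(self).value.klass). The
"reduction" trait name here is an assumption for illustration.

class ReductionPicker(LayerPicker):
    # hypothetical: mirror the layer's `reduction` string instead of its
    # ImageCollection; kept as an Instance trait so `.klass` still works
    value = traitlets.Instance(klass=str, allow_none=True, read_only=True)
    _attr = "reduction"

# picker = ReductionPicker()  # defaults to picking from wf.map
# picker.observe(lambda change: print(change["new"]), "value")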
Code Example #4
class InspectorRowGenerator(traitlets.HasTraits):
    """
    Controller class that manages the name and pixel values widgets for one layer.

    Not a widget itself, but just exposes `name_label` and `values_labels`
    for the `PixelInspector` to add into its table.

    Listens for changes to the layer (XYZ object or parameters) or the marker (location),
    and updates the widgets in `values_labels` appropriately, by calling `inspect` in a
    separate thread to pull pixel values.
    """

    _value_layout = {"width": "initial", "margin": "0 2px", "height": "1.6em"}

    name_label = traitlets.Instance(
        ipywidgets.Label,
        kw={"layout": dict(_value_layout, grid_column="1")},
        read_only=True,
        allow_none=False,
    )
    values_labels = traitlets.List(read_only=True, allow_none=False)

    def __init__(self, layer, marker, n_bands):
        self.marker = marker
        self.layer = layer
        self._updating = False
        self._cache = cachetools.LRUCache(64)

        self.name_label.value = layer.name
        # TODO make names bold. it's frustratingly difficult (unsupported) with ipywidgets:
        # https://github.com/jupyter-widgets/ipywidgets/issues/577
        self._name_link = ipywidgets.jslink((layer, "name"),
                                            (self.name_label, "value"))
        self.set_trait(
            "values_labels",
            [
                ipywidgets.Label(value="",
                                 layout=dict(self._value_layout,
                                             grid_column=str(2 + i)))
                for i in range(n_bands)
            ],
        )

        marker.observe(self.recalculate, "geoctx", type="change")
        layer.observe(
            self.recalculate,
            ["image_value", "visible"],
            type="change",
        )

        self._viz_links = [
            traitlets.dlink(
                (layer, "visible"),
                (label.layout, "display"),
                lambda v: "" if v else "none",
            ) for label in [self.name_label] + self.values_labels
        ]

        if marker.opacity == 1:
            # there's already a point to sample; eagerly recalculate now
            self.recalculate()

    def unlink(self):
        # NOTE(gabe): the traitlets docs say name=All (the default) should work,
        # but careful reading of the source shows it's a no-op.
        # we must explicitly unobserve for exactly the names and types we observed for.
        self.marker.unobserve(self.recalculate, "geoctx", type="change")
        self.layer.unobserve(
            self.recalculate,
            ["image_value", "visible"],
            type="change",
        )
        self._name_link.unlink()
        for viz_link in self._viz_links:
            viz_link.unlink()

    def recalculate(self, *args, **kwargs):
        if self._updating or not self.layer.visible or self.marker.opacity == 0:
            return

        xy_3857 = self.marker.xy_3857

        # try to make a cache key from the marker location, XYZ ID, reduction, and parameters.
        # if the parameters are unhashable (probably because they contain grafts),
        # we'll consider it a cache miss and go fetch.
        try:
            params_key = frozenset(self.layer.parameters.to_dict().items())
        except TypeError:
            cache_key = None
        else:
            cache_key = (
                xy_3857,
                self.layer.xyz_obj.id,
                self.layer.reduction,
                params_key,
            )

        if cache_key:
            try:
                value_list = self._cache[cache_key]
            except KeyError:
                value_list = None
        else:
            value_list = None

        image = self.layer.image_value
        if image is None:
            value_list = ["❓"]

        if value_list:
            self.set_values(value_list)
        else:
            self.set_updating()
            # NOTE(gabe): I don't trust traitlets or ipywidgets to be thread-safe,
            # so we pull all values out of traits here and pass them in to the thread directly
            ctx = self.marker.geoctx
            thread = threading.Thread(
                target=self._fetch_and_set_thread,
                args=(image, xy_3857, ctx, cache_key),
                daemon=True,
            )
            thread.start()

    def _fetch_and_set_thread(self, image, xy_3857, ctx, cache_key):
        proxy_value_list = image.value_at(*xy_3857).values()

        try:
            value_list = proxy_value_list.inspect(ctx)
        except JobTimeoutError:
            value_list = ["⏱"]
        except Exception:
            value_list = ["💥"]
        else:
            if len(value_list) == 0:
                # empty Image
                value_list = [np.ma.masked]
            if cache_key is not None:  # unhashable parameters: don't cache
                self._cache[cache_key] = value_list

        self.set_values(value_list)

    def set_values(self, new_values_list):
        for i, value in enumerate(new_values_list):
            if isinstance(value, str):
                pass
            elif value is np.ma.masked:
                new_values_list[i] = "∅"
            else:
                new_values_list[i] = "{:.6g}".format(value)

        for i, label in enumerate(self.values_labels):
            try:
                label.value = new_values_list[i]
            except IndexError:
                label.value = ""

        self._updating = False

    def set_updating(self):
        self._updating = True

        for i, label in enumerate(self.values_labels):
            if label.value != "" or i == 0:
                label.value = "..."
Code Example #5
File: RobotClass.py  Project: BrownKnight/robotics
class Robot(SingletonConfigurable):
    
    front_left_motor = traitlets.Instance(Motor)
    front_right_motor = traitlets.Instance(Motor)
    back_left_motor = traitlets.Instance(Motor)
    back_right_motor = traitlets.Instance(Motor)

    # config
    front_left_motor_channel = traitlets.Integer(default_value=1).tag(config=True)
    front_left_motor_alpha = traitlets.Float(default_value=1.0).tag(config=True)
    front_right_motor_channel = traitlets.Integer(default_value=2).tag(config=True)
    front_right_motor_alpha = traitlets.Float(default_value=1.0).tag(config=True)
    back_left_motor_channel = traitlets.Integer(default_value=3).tag(config=True)
    back_left_motor_alpha = traitlets.Float(default_value=1.0).tag(config=True)
    back_right_motor_channel = traitlets.Integer(default_value=4).tag(config=True)
    back_right_motor_alpha = traitlets.Float(default_value=1.0).tag(config=True)
    
    def __init__(self, *args, **kwargs):
        super(Robot, self).__init__(*args, **kwargs)
        self.left_motor_driver = PCA9685(0x41, debug=False)
        self.right_motor_driver = PCA9685(0x40, debug=False)
        self.front_left_motor = Motor(self.left_motor_driver, channel=self.front_left_motor_channel, alpha=self.front_left_motor_alpha)
        self.front_right_motor = Motor(self.right_motor_driver, channel=self.front_right_motor_channel, alpha=self.front_right_motor_alpha)
        self.back_left_motor = Motor(self.left_motor_driver, channel=self.back_left_motor_channel, alpha=self.back_left_motor_alpha)
        self.back_right_motor = Motor(self.right_motor_driver, channel=self.back_right_motor_channel, alpha=self.back_right_motor_alpha)
        
    def set_motors(self, front_left_speed, front_right_speed, back_left_speed, back_right_speed):
        self.front_left_motor.value = front_left_speed
        self.front_right_motor.value = front_right_speed
        self.back_left_motor.value = back_left_speed
        self.back_right_motor.value = back_right_speed
        
    def forward(self, speed=1.0):
        self.set_motors(speed, speed, speed, speed)

    def backward(self, speed=1.0):
        self.set_motors(-speed, -speed, -speed, -speed)

    def left(self, speed=1.0):
        # tank-style turn: left side reversed, right side forward
        self.set_motors(-speed, speed, -speed, speed)

    def right(self, speed=1.0):
        self.set_motors(speed, -speed, speed, -speed)

    def stop(self):
        self.set_motors(0, 0, 0, 0)

    def forward_left(self, speed=1.0):
        # curve left by slowing the left side down
        self.set_motors(speed / 3, speed, speed / 3, speed)

    def forward_right(self, speed=1.0):
        self.set_motors(speed, speed / 3, speed, speed / 3)

    def backward_left(self, speed=1.0):
        self.set_motors(-speed / 3, -speed, -speed / 3, -speed)

    def backward_right(self, speed=1.0):
        self.set_motors(-speed, -speed / 3, -speed, -speed / 3)
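
A hedged usage sketch; it assumes the PCA9685 driver boards at 0x40/0x41 from
__init__ are actually attached, and instance() comes from SingletonConfigurable.

import time

robot = Robot.instance()
robot.forward(speed=0.5)
time.sleep(1.0)
robot.left(speed=0.3)   # tank-style turn: left side reversed
time.sleep(0.5)
robot.stop()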
Code Example #6
class DownloadChooser(ipw.HBox):
    """Download chooser for structure download

    To be able to have the download button work no matter the widget's final environment,
    (as long as it supports JavaScript), the very helpful insight from the following page is used:
    https://stackoverflow.com/questions/2906582/how-to-create-an-html-button-that-acts-like-a-link
    """

    chosen_format = traitlets.Tuple(traitlets.Unicode(), traitlets.Dict())
    structure = traitlets.Instance(Structure, allow_none=True)

    _formats = [
        (
            "Crystallographic Information File v1.0 (.cif)",
            {
                "ext": ".cif",
                "adapter_format": "cif"
            },
        ),
        ("Protein Data Bank (.pdb)", {
            "ext": ".pdb",
            "adapter_format": "pdb"
        }),
        (
            "Crystallographic Information File v1.0 [via ASE] (.cif)",
            {
                "ext": ".cif",
                "adapter_format": "ase",
                "final_format": "cif"
            },
        ),
        (
            "Protein Data Bank [via ASE] (.pdb)",
            {
                "ext": ".pdb",
                "adapter_format": "ase",
                "final_format": "proteindatabank"
            },
        ),
        (
            "XMol XYZ File [via ASE] (.xyz)",
            {
                "ext": ".xyz",
                "adapter_format": "ase",
                "final_format": "xyz"
            },
        ),
        (
            "XCrySDen Structure File [via ASE] (.xsf)",
            {
                "ext": ".xsf",
                "adapter_format": "ase",
                "final_format": "xsf"
            },
        ),
        (
            "WIEN2k Structure File [via ASE] (.struct)",
            {
                "ext": ".struct",
                "adapter_format": "ase",
                "final_format": "struct"
            },
        ),
        (
            "VASP POSCAR File [via ASE]",
            {
                "ext": "",
                "adapter_format": "ase",
                "final_format": "vasp"
            },
        ),
        (
            "Quantum ESPRESSO File [via ASE] (.in)",
            {
                "ext": ".in",
                "adapter_format": "ase",
                "final_format": "espresso-in"
            },
        ),
        # Not yet implemented:
        # (
        #     "Protein Data Bank, macromolecular CIF v1.1 (PDBx/mmCIF) (.cif)",
        #     {"ext": "cif", "adapter_format": "pdbx_mmcif"},
        # ),
    ]
    _download_button_format = """
<input type="button" class="jupyter-widgets jupyter-button widget-button" value="Download" title="Download structure" style="width:auto;" {disabled}
onclick="var link = document.createElement('a');
link.href = 'data:charset={encoding};base64,{data}';
link.download = '{filename}';
document.body.appendChild(link);
link.click();
document.body.removeChild(link);" />
"""

    def __init__(self, **kwargs):
        self.dropdown = ipw.Dropdown(options=("Select a format", {}),
                                     width="auto")
        self.download_button = ipw.HTML(
            self._download_button_format.format(disabled="disabled",
                                                encoding="",
                                                data="",
                                                filename=""))

        self.children = (self.dropdown, self.download_button)
        super().__init__(children=self.children, layout={"width": "auto"})
        self.reset()

        self.dropdown.observe(self._update_download_button, names="value")

    @traitlets.observe("structure")
    def _on_change_structure(self, change: dict):
        """Update widget when a new structure is chosen"""
        if change["new"] is None:
            self.reset()
        else:
            self._update_options()
            self.unfreeze()

    def _update_options(self):
        """Update options according to chosen structure"""
        if "disorder" in self.structure.structure_features:
            # Disordered structures are not usable with ASE
            options = sorted(
                option for option in self._formats
                if option[1].get("adapter_format", "") != "ase"
            )
        else:
            options = sorted(self._formats)
        options.insert(0, ("Select a format", {}))
        self.dropdown.options = options

    def _update_download_button(self, change: dict):
        """Update Download button with correct onclick value

        The whole parsing process from `Structure` to desired format, is wrapped in a try/except,
        which is further wrapped in a `warnings.catch_warnings()`.
        This is in order to be able to log any warnings that might be thrown by the adapter in
        `optimade-python-tools` and/or any related exceptions.
        """
        desired_format = change["new"]
        if not desired_format:
            self.download_button.value = self._download_button_format.format(
                disabled="disabled", encoding="", data="", filename="")
            return

        with warnings.catch_warnings():
            warnings.filterwarnings("error")

            try:
                output = getattr(self.structure,
                                 f"as_{desired_format['adapter_format']}")

                if desired_format["adapter_format"] in (
                        "ase",
                        "pymatgen",
                        "aiida_structuredata",
                ):
                    # output is not a file, but a proxy Python class
                    func = getattr(
                        self, f"_get_via_{desired_format['adapter_format']}")
                    output = func(
                        output, desired_format=desired_format["final_format"])
                encoding = "utf-8"

                # Specifically for CIF: v1.x CIF needs to be in "latin-1" formatting
                if desired_format["ext"] == ".cif":
                    encoding = "latin-1"

                filename = (
                    f"optimade_structure_{self.structure.id}{desired_format['ext']}"
                )

                if isinstance(output, str):
                    output = output.encode(encoding)
                data = base64.b64encode(output).decode()
            except Warning as warn:
                self.download_button.value = self._download_button_format.format(
                    disabled="disabled", encoding="", data="", filename="")
                warnings.warn(OptimadeClientWarning(warn))
            except Exception as exc:
                self.download_button.value = self._download_button_format.format(
                    disabled="disabled", encoding="", data="", filename="")
                if isinstance(exc, exceptions.OptimadeClientError):
                    raise exc
                # Else wrap the exception to make sure to log it.
                raise exceptions.OptimadeClientError(exc)
            else:
                self.download_button.value = self._download_button_format.format(
                    disabled="",
                    encoding=encoding,
                    data=data,
                    filename=filename)

    @staticmethod
    def _get_via_pymatgen(
        structure_molecule: Union[pymatgenStructure, pymatgenMolecule],
        desired_format: str,
    ) -> str:
        """Use pymatgen.[Structure,Molecule].to() method"""
        molecule_only_formats = ["xyz", "pdb"]
        structure_only_formats = ["xsf", "cif"]
        if desired_format in molecule_only_formats and not isinstance(
                structure_molecule, pymatgenMolecule):
            raise exceptions.WrongPymatgenType(
                f"Converting to '{desired_format}' format is only possible with a pymatgen."
                f"Molecule, instead got {type(structure_molecule)}")
        if desired_format in structure_only_formats and not isinstance(
                structure_molecule, pymatgenStructure):
            raise exceptions.WrongPymatgenType(
                f"Converting to '{desired_format}' format is only possible with a pymatgen."
                f"Structure, instead got {type(structure_molecule)}.")

        return structure_molecule.to(fmt=desired_format)

    @staticmethod
    def _get_via_ase(atoms: aseAtoms,
                     desired_format: str) -> Union[str, bytes]:
        """Use ase.Atoms.write() method"""
        with tempfile.NamedTemporaryFile(mode="w+b") as temp_file:
            atoms.write(temp_file.name, format=desired_format)
            res = temp_file.read()
        return res

    def freeze(self):
        """Disable widget"""
        for widget in self.children:
            widget.disabled = True

    def unfreeze(self):
        """Activate widget (in its current state)"""
        for widget in self.children:
            widget.disabled = False

    def reset(self):
        """Reset widget"""
        self.dropdown.index = 0
        self.freeze()
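
A hedged usage sketch for a notebook cell; Structure is the
optimade-python-tools adapter class imported by this module, and
some_structure is a hypothetical instance of it.

from IPython.display import display

chooser = DownloadChooser()
display(chooser)
# chooser.structure = some_structure  # enables the format dropdown
# chooser.structure = None            # resets and disables the widget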
Code Example #7
class ProviderImplementationSummary(ipw.GridspecLayout):
    """Summary/description of chosen provider and their database"""

    provider = traitlets.Instance(LinksResourceAttributes, allow_none=True)
    database = traitlets.Instance(LinksResourceAttributes, allow_none=True)

    text_style = "margin:0px;padding-top:6px;padding-bottom:4px;padding-left:4px;padding-right:4px;"

    def __init__(self, **kwargs):
        self.provider_summary = ipw.HTML()
        provider_section = ipw.VBox(
            children=[self.provider_summary],
            layout=ipw.Layout(width="auto", height="auto"),
        )

        self.database_summary = ipw.HTML()
        database_section = ipw.VBox(
            children=[self.database_summary],
            layout=ipw.Layout(width="auto", height="auto"),
        )

        super().__init__(
            n_rows=1,
            n_columns=31,
            layout={
                "border": "solid 0.5px darkgrey",
                "margin": "0px 0px 0px 0px",
                "padding": "0px 0px 10px 0px",
            },
            **kwargs,
        )
        self[:, :15] = provider_section
        self[:, 16:] = database_section

        self.observe(self._on_provider_change, names="provider")
        self.observe(self._on_database_change, names="database")

    def _on_provider_change(self, change: dict):
        """Update provider summary, since self.provider has been changed"""
        LOGGER.debug("Provider changed in summary. New value: %r",
                     change["new"])
        self.database_summary.value = ""
        if not change["new"] or change["new"] is None:
            self.provider_summary.value = ""
        else:
            self._update_provider()

    def _on_database_change(self, change):
        """Update database summary, since self.database has been changed"""
        LOGGER.debug("Database changed in summary. New value: %r",
                     change["new"])
        if not change["new"] or change["new"] is None:
            self.database_summary.value = ""
        else:
            self._update_database()

    def _update_provider(self):
        """Update provider summary"""
        html_text = f"""<strong style="line-height:1;{self.text_style}">{getattr(self.provider, 'name', 'Provider')}</strong>
<p style="line-height:1.2;{self.text_style}">{getattr(self.provider, 'description', '')}</p>"""
        self.provider_summary.value = html_text

    def _update_database(self):
        """Update database summary"""
        html_text = f"""<strong style="line-height:1;{self.text_style}">{getattr(self.database, 'name', 'Database')}</strong>
<p style="line-height:1.2;{self.text_style}">{getattr(self.database, 'description', '')}</p>"""
        self.database_summary.value = html_text

    def freeze(self):
        """Disable widget"""

    def unfreeze(self):
        """Activate widget (in its current state)"""

    def reset(self):
        """Reset widget"""
        self.provider = None
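
A hedged usage sketch; LinksResourceAttributes instances would normally come
from OPTIMADE /links responses, so the attribute objects below are assumed.

summary = ProviderImplementationSummary()
# display(summary)                        # in a notebook cell
# summary.provider = provider_attributes  # renders name + description
# summary.database = database_attributes
# summary.reset()                         # clears both panels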
Code Example #8
class QueryConstructor(W.HBox):
    """TODO
    - way better templating and more efficient formatting
    - replace individual observers with larger observer
    - move build_query to standalone function
    """

    convert_arrow = T.Instance(W.Image)
    query_input = T.Instance(W.VBox)
    formatted_query = T.Instance(QueryColorizer,
                                 kw=dict(layout=W.Layout(max_height="260px")))

    # traits from children
    namespaces = T.Unicode()
    query_type = T.Unicode(default_value="SELECT")
    query_line = T.Unicode(allow_none=True)
    query_body = T.Unicode()
    query = T.Unicode()

    log = W.Output()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.query_input = QueryInput()

        # Inherit traits with links TODO easier way?
        T.link((self.query_input.namespaces, "namespaces"),
               (self, "namespaces"))
        T.link((self.query_input.header, "dropdown_value"),
               (self, "query_type"))
        T.link((self.query_input.header, "header_value"), (self, "query_line"))
        T.link((self.query_input.body.body, "value"), (self, "query_body"))
        T.dlink((self, "query"), (self.formatted_query, "query"))

        self.children = tuple([self.query_input, self.formatted_query])

    @log.capture()
    def build_query(self):
        # get values TODO improve
        namespaces = self.namespaces
        query_type = self.query_type
        query_line = self.query_line
        query_body = self.query_body or self.query_input.body.body.placeholder

        # update query_body
        query_body = "\t\n".join(query_body.split(
            "\n"))  # TODO this isn't actually formatting properly

        header_str = ""
        # TODO move these to module vars
        if query_type in {"SELECT", "SELECT DISTINCT"}:
            if query_line == "":
                query_line = "*"
            header_str = f"{query_type} {query_line}"
        elif query_type == "ASK":
            header_str = query_type
        elif query_type == "CONSTRUCT":
            if query_line == "":
                query_line = "{?s ?p ?o}"
            header_str = f"{query_type} {query_line}"
        else:
            with self.log:
                raise ValueError(f"Unexpected query type: {query_type}")

        return query_template.format(
            namespaces,
            header_str,
            query_body,
        )

    @T.observe(
        "namespaces",
        "query_type",
        "query_line",
        "query_body",
    )
    def update_query(self, change):
        self.query = self.build_query()
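
A hedged usage sketch; QueryInput, QueryColorizer, and query_template come
from the same module, and trait changes rebuild query through the observer.

qc = QueryConstructor()
qc.query_type = "CONSTRUCT"    # empty query_line falls back to "{?s ?p ?o}"
qc.query_body = "?s ?p ?o ."
print(qc.query)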
Code Example #9
File: model.py  Project: sharadMalmanchi/vaex
class DataArray(_HasState):
    class Status(enum.Enum):
        MISSING_LIMITS = 1
        STAGED_CALCULATING_LIMITS = 3
        CALCULATING_LIMITS = 4
        CALCULATED_LIMITS = 5
        NEEDS_CALCULATING_GRID = 6
        STAGED_CALCULATING_GRID = 7
        CALCULATING_GRID = 8
        CALCULATED_GRID = 9
        READY = 10
        EXCEPTION = 11

    status = traitlets.UseEnum(Status, Status.MISSING_LIMITS)
    status_text = traitlets.Unicode('Initializing')
    exception = traitlets.Any(None)
    df = traitlets.Instance(vaex.dataframe.DataFrame)
    axes = traitlets.List(traitlets.Instance(Axis), [])
    grid = traitlets.Instance(xarray.DataArray, allow_none=True)
    grid_sliced = traitlets.Instance(xarray.DataArray, allow_none=True)
    shape = traitlets.CInt(64)
    selections = traitlets.List(
        traitlets.Union([traitlets.Bool(),
                         traitlets.Unicode(allow_none=True)]), [None])

    def __init__(self, **kwargs):
        super(DataArray, self).__init__(**kwargs)
        self.signal_slice = vaex.events.Signal()
        self.signal_regrid = vaex.events.Signal()
        self.signal_grid_progress = vaex.events.Signal()
        self.observe(lambda change: self.signal_regrid.emit(), 'selections')
        self._on_axis_status_change()

        # keep a set of axis that need new limits
        self._dirty_axes = set()
        for axis in self.axes:
            assert axis.df is self.df, "axes should have the same dataframe"
            traitlets.link((self, 'shape'), (axis, 'shape_default'))
            axis.observe(self._on_axis_status_change, 'status')
            axis.observe(lambda _: self.signal_slice.emit(self), ['slice'])

            def on_change_min_max(change):
                if change.owner.status == Axis.Status.READY:
                    # this indicates a user changed the min/max
                    self.status = DataArray.Status.NEEDS_CALCULATING_GRID

            axis.observe(on_change_min_max, ['min', 'max'])

        self._on_axis_status_change()
        self.df.signal_selection_changed.connect(self._on_change_selection)

    def _on_change_selection(self, df, name):
        # TODO: check if the selection applies to us
        self.status = DataArray.Status.NEEDS_CALCULATING_GRID

    async def _allow_state_change_cancel(self):
        self._allow_state_change.release()

    def _on_axis_status_change(self, change=None):
        missing_limits = [
            axis for axis in self.axes if axis.status == Axis.Status.NO_LIMITS
        ]
        staged_calculating_limits = [
            axis for axis in self.axes
            if axis.status == Axis.Status.STAGED_CALCULATING_LIMITS
        ]
        calculating_limits = [
            axis for axis in self.axes
            if axis.status == Axis.Status.CALCULATING_LIMITS
        ]
        calculated_limits = [
            axis for axis in self.axes
            if axis.status == Axis.Status.CALCULATED_LIMITS
        ]

        def names(axes):
            return ", ".join([str(axis.expression) for axis in axes])

        if staged_calculating_limits:
            self.status = DataArray.Status.STAGED_CALCULATING_LIMITS
            self.status_text = 'Staged limit computation for {}'.format(
                names(staged_calculating_limits))
        elif missing_limits:
            self.status = DataArray.Status.MISSING_LIMITS
            self.status_text = 'Missing limits for {}'.format(
                names(missing_limits))
        elif calculating_limits:
            self.status = DataArray.Status.CALCULATING_LIMITS
            self.status_text = 'Computing limits for {}'.format(
                names(calculating_limits))
        elif calculated_limits:
            self.status = DataArray.Status.CALCULATED_LIMITS
            self.status_text = 'Computed limits for {}'.format(
                names(calculated_limits))
        else:
            assert all(
                [axis.status == Axis.Status.READY for axis in self.axes])
            self.status = DataArray.Status.NEEDS_CALCULATING_GRID

    @traitlets.observe('status')
    def _on_change_status(self, change):
        if self.status == DataArray.Status.EXCEPTION:
            self.status_text = f'Exception: {self.exception}'
        elif self.status == DataArray.Status.NEEDS_CALCULATING_GRID:
            self.status_text = 'Grid needs to be calculated'
        elif self.status == DataArray.Status.STAGED_CALCULATING_GRID:
            self.status_text = 'Staged grid computation'
        elif self.status == DataArray.Status.CALCULATING_GRID:
            self.status_text = 'Calculating grid'
        elif self.status == DataArray.Status.CALCULATED_GRID:
            self.status_text = 'Calculated grid'
        elif self.status == DataArray.Status.READY:
            self.status_text = 'Ready'
        # GridCalculator can change the status
        # self._update_grid()
        # self.status_text = 'Computing limits for {}'.format(names(missing_limits))

    @property
    def has_missing_limits(self):
        return any([axis.has_missing_limit for axis in self.axes])

    def on_progress_grid(self, f):
        return all(self.signal_grid_progress.emit(f))
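
A hedged usage sketch; Axis is the companion model class in this module, and
its exact constructor arguments are assumptions for illustration.

import vaex

df = vaex.example()
x = Axis(df=df, expression='x')
model = DataArray(df=df, axes=[x], shape=128)
model.observe(lambda change: print(change['new']), 'status_text')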
Code Example #10
class NXBase(VisualizerBase):
    """
    The visualization class for the NXLayouts. Used by the datashader visualizations.

    :param _nx_layout: the desired networkx layout function to be used.
    :param _layouts: a dictionary mapping labels of known layouts to networkx functions.

    Notes:
      - some networkx layouts do not work without custom node/edge data or
        graph_layout_params; these are listed in NOT_HANDLED and skipped by
        default, but can still be set explicitly
    """

    NOT_HANDLED = [
        "bipartite_layout",
        "multipartite_layout",
        "rescale_layout",
    ]

    _layouts = T.Dict()

    _nx_layout = T.Instance(types.FunctionType)

    @T.default("graph_layout_options")
    def _make_default_options(self):
        return tuple(self._layouts)

    @T.default("_layouts")
    def _make_default_layouts(self):
        """these are leniently loaded, as the exact set of algorithms depends
        heavily on the version of networkx installed
        """
        layouts = {}
        for layout_key in nx_layout.__all__:
            if layout_key in self.NOT_HANDLED:
                continue
            try:
                layout = getattr(nx_layout, layout_key)
                label = _make_nx_layout_label(layout_key)
                layouts[label] = layout
            except Exception as err:
                self.log.warning(
                    "Expected to be able to load from networkx: %s\n%s",
                    layout_key, err)
        return layouts

    @T.default("graph_layout")
    def _make_default_layout(self):
        return self.graph_layout_options[0]

    @T.default("_nx_layout")
    def set_default_nx_layout(self):
        return self._layouts[self.graph_layout]

    @T.observe("graph_layout")
    def _update_graph_layout(self, change):
        if change.new is None:
            self._nx_layout = sorted(self._layouts.items())[0][1]
            return

        if change.new in self._layouts:
            self._nx_layout = self._layouts[self.graph_layout]
            return

        try:
            self._nx_layout = getattr(nx_layout, change.new)
        except Exception as err:
            self.log.warning("Could not load from networkx: %s\n%s",
                             change.new, err)
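
A hedged sketch of the layout-resolution behavior; NXBase is meant to be used
through a concrete datashader visualizer subclass, so this only illustrates
the traits defined above.

# viz = SomeConcreteNXVisualizer()      # hypothetical subclass of NXBase
# print(viz.graph_layout_options)       # labels derived from nx_layout.__all__
# viz.graph_layout = "circular_layout"  # unknown label: resolved via getattr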
Code Example #11
class XELK(ElkTransformer):
    """NetworkX DiGraphs to ELK dictionary structure"""

    HIDDEN_ATTR = "hidden"
    hoist_hidden_edges: bool = True

    source = T.Tuple(T.Instance(nx.Graph),
                     T.Instance(nx.DiGraph, allow_none=True))
    layouts = T.Dict()  # keys: networkx nodes {ElkElements: {layout options}}
    css_classes = T.Dict()

    port_scale = T.Int(default_value=5)
    label_key = T.Unicode(default_value="labels")
    port_key = T.Unicode(default_value="ports")

    @T.default("source")
    def _default_source(self):
        return (nx.Graph(), None)

    @T.default("text_sizer")
    def _default_text_sizer(self):
        return ElkTextSizer()

    @T.default("layouts")
    def _default_layouts(self):
        parent_opts = opt.OptionsWidget(
            identifier="parents",
            options=[
                opt.HierarchyHandling(),
            ],
        )
        label_opts = opt.OptionsWidget(
            identifier=ElkLabel,
            options=[opt.NodeLabelPlacement(horizontal="center")])
        node_opts = opt.OptionsWidget(
            identifier=ElkNode,
            options=[
                opt.NodeSizeConstraints(),
            ],
        )

        default = opt.OptionsWidget(
            options=[parent_opts, node_opts, label_opts])
        return {ElkRoot: default.value}

    def node_id(self, node: Hashable) -> str:
        """Get the element id for a node in the main graph for use in elk

        :param node: Node in main  graph
        :type node: Hashable
        :return: Element ID
        :rtype: str
        """
        g, tree = self.source
        if node is ElkRoot:
            return self.ELK_ROOT_ID
        elif node in g:
            return g.nodes.get(node, {}).get("id", f"{node}")
        return f"{node}"

    def port_id(self, node: Hashable, port: Optional[Hashable] = None) -> str:
        """Get a suitable Elk identifier from the node and port

        :param node: Node from the incoming networkx graph
        :type node: Hashable
        :param port: Port identifier, defaults to None
        :type port: Optional[Hashable], optional
        :return: Element ID; if no port is given, the ID refers to the node itself
        :rtype: str
        """
        if port is None:
            return self.node_id(node)
        else:
            return f"{self.node_id(node)}.{port}"

    def edge_id(self, edge: Edge):
        # TODO probably will need more sophisticated id generation in future
        return "{}__{}___{}__{}".format(edge.source, edge.source_port,
                                        edge.target, edge.target_port)

    async def transform(self) -> ElkNode:
        """Generate ELK dictionary structure
        :return: Root Elk node
        :rtype: ElkNode
        """
        # TODO decide behavior for nodes that exist in the tree but not g
        g, tree = self.source
        self.clear_registry()
        visible_edges, hidden_edges = self.collect_edges()

        # Process visible networkx nodes into elknodes
        visible_nodes = [
            n for n in g.nodes() if not is_hidden(tree, n, self.HIDDEN_ATTR)
        ]

        # make elknodes then connect their hierarchy
        elknodes: NodeMap = {}
        ports: PortMap = {}
        for node in visible_nodes:
            elknode, node_ports = await self.make_elknode(node)
            for key, value in node_ports.items():
                ports[key] = value
            elknodes[node] = elknode

        # make top level ElkNode and attach all others as children
        elknodes[ElkRoot] = top = ElkNode(
            id=self.ELK_ROOT_ID,
            children=build_hierarchy(g, tree, elknodes, self.HIDDEN_ATTR),
            layoutOptions=self.get_layout(ElkRoot, "parents"),
        )

        # map of original nodes to the generated elknodes
        for node, elk_node in elknodes.items():
            self.register(elk_node, node)

        elknodes, ports = await self.process_edges(elknodes, ports,
                                                   visible_edges)
        # process edges with one or both original endpoints are hidden
        if self.hoist_hidden_edges:
            elknodes, ports = await self.process_edges(
                elknodes,
                ports,
                hidden_edges,
                edge_style={"slack-edge"},
                port_style={"slack-port"},
            )

        # iterate through the port map and add ports to ElkNodes
        for port_id, port in ports.items():
            owner = port.node
            elkport = port.elkport
            if owner not in elknodes:
                # TODO skip generating port to begin with
                continue  # owner node isn't visible; skip this port
            elknode = elknodes[owner]
            if elknode.ports is None:
                elknode.ports = []
            layout = self.get_layout(owner, ElkPort)
            elkport.layoutOptions = merge(elkport.layoutOptions, layout)
            elknode.ports += [elkport]

            # map of ports to the generated elkports
            self.register(elkport, port)

        # bulk calculate label sizes
        await size_labels(self.text_sizer, collect_labels([top]))

        return top

    async def make_elknode(self, node) -> Tuple[ElkNode, PortMap]:
        # merge layout options defined on the node data with default layout
        # options
        node_data = self.get_node_data(node)
        layout = merge(
            node_data.get("layoutOptions", {}),
            self.get_layout(node, ElkNode),
        )
        labels = await self.make_labels(node)

        # update port map with declared ports in the networkx node data
        node_ports = await self.collect_ports(node)

        properties = self.get_properties(node, self.get_css(node,
                                                            ElkNode)) or None

        elk_node = ElkNode(
            id=self.node_id(node),
            labels=labels,
            layoutOptions=layout,
            properties=properties,
            width=node_data.get("width", None),
            height=node_data.get("height", None),
        )
        return elk_node, node_ports

    def get_layout(self, node: Hashable,
                   elk_type: Type[ElkGraphElement]) -> Optional[Dict]:
        """Get the Elk Layout Options appropriate for given networkx node and
        filter by given elk_type

        :param node: Networkx node
        :type node: Hashable
        :param elk_type: ElkGraphElement type used to filter the layout options
        :type elk_type: Type[ElkGraphElement]
        :return: Layout options for the given element type, if any
        :rtype: Optional[Dict]
        """
        # TODO look at self.source hierarchy and resolve layout with added
        # infomation. until then use root node `None` for layout options
        if node not in self.layouts:
            node = ElkRoot

        type_opts = self.layouts.get(node, {})
        options = {**type_opts.get(elk_type, {})}
        if options:
            return options

    def get_properties(
        self,
        element: Optional[Union[Hashable, str]],
        dom_classes: Optional[Set[str]] = None,
    ) -> Dict:
        """Get the properties for a graph element

        :param element: Networkx node or edge
        :type element: Optional[Union[Hashable, str]]
        :param dom_classes: Set of base CSS DOM classes to merge, defaults to None
        :type dom_classes: Optional[Set[str]], optional
        :return: Merged properties for the element
        :rtype: Dict
        """

        g, tree = self.source

        properties = []

        if g and element in g:
            g_props = g.nodes[element].get("properties", {})
            if g_props:
                properties += [g_props]
        if hasattr(element, "data"):
            properties += [element.data.get("properties", {})]
        elif isinstance(element, dict):
            properties += [element.get("properties", {})]

        if dom_classes:
            properties += [{"cssClasses": " ".join(dom_classes)}]

        if not properties:
            return {}
        elif len(properties) == 1:
            return properties[0]

        merged_properties = {}

        for props in properties[::-1]:
            merged_properties = merge(props, merged_properties)

        return merged_properties

    def get_css(
        self,
        node: Hashable,
        elk_type: Type[ElkGraphElement],
        dom_classes: Set[str] = None,
    ) -> Set[str]:
        """Get the CSS Classes appropriate for given networkx node given
        elk_type

        :param node: Networkx node
        :type node: Hashable
        :param elk_type: ElkGraphElement to get appropriate css classes
        :type elk_type: Type[ElkGraphElement]
        :param dom_classes: Set of base CSS DOM classes to merge, defaults to
        Set[str]=None
        :type dom_classes: [type], optional
        :return: Set of CSS Classes to apply
        :rtype: Set[str]
        """
        typed_css = self.css_classes.get(node, {})
        css_classes = set(typed_css.get(elk_type, []))
        if dom_classes is None:
            return css_classes
        return css_classes | dom_classes

    async def process_edges(
        self,
        nodes: NodeMap,
        ports: PortMap,
        edges: EdgeMap,
        edge_style: Set[str] = None,
        port_style: Set[str] = None,
    ) -> Tuple[NodeMap, PortMap]:
        for owner, edge_list in edges.items():
            edge_css = self.get_css(owner, ElkEdge, edge_style)
            port_css = self.get_css(owner, ElkPort, port_style)
            for edge in edge_list:
                elknode = nodes[owner]
                if elknode.edges is None:
                    elknode.edges = []
                if edge.source_port is not None:
                    port_id = self.port_id(edge.source, edge.source_port)
                    if port_id not in ports:
                        ports[port_id] = await self.make_port(
                            edge.source, edge.source_port, port_css)
                if edge.target_port is not None:
                    port_id = self.port_id(edge.target, edge.target_port)
                    if port_id not in ports:
                        ports[port_id] = await self.make_port(
                            edge.target, edge.target_port, port_css)

                elknode.edges += [await self.make_edge(edge, edge_css)]
        return nodes, ports

    async def make_edge(self,
                        edge: Edge,
                        styles: Optional[Set[str]] = None) -> ElkExtendedEdge:
        """Make the associated Elk edge for the given Edge

        :param edge: Edge object to wrap
        :type edge: Edge
        :param styles: Set of css classes to add to given Elk edge, defaults to None
        :type styles: Optional[Set[str]], optional
        :return: Elk edge
        :rtype: ElkExtendedEdge
        """

        labels = []
        properties = self.get_properties(edge, styles) or None
        label_layout_options = self.get_layout(
            edge.owner, ElkLabel)  # TODO add edgelabel type?
        edge_layout_options = self.get_layout(edge.owner, ElkEdge)

        for i, label in enumerate(edge.data.get(self.label_key, [])):

            if isinstance(label, ElkLabel):
                label = label.to_dict()  # used to create copy of label
            if isinstance(label, dict):
                label = ElkLabel.from_dict(label)
            if isinstance(label, str):
                label = ElkLabel(id=f"{edge.owner}_label_{i}_{label}",
                                 text=label)
            label.layoutOptions = merge(label.layoutOptions,
                                        label_layout_options)
            labels.append(label)
        for label in labels:
            self.register(label, edge)
        elk_edge = ElkExtendedEdge(
            id=edge.data.get("id", self.edge_id(edge)),
            sources=[self.port_id(edge.source, edge.source_port)],
            targets=[self.port_id(edge.target, edge.target_port)],
            properties=properties,
            layoutOptions=merge(edge.data.get("layoutOptions"),
                                edge_layout_options),
            labels=compact(labels),
        )
        self.register(elk_edge, edge)
        return elk_edge

    async def make_port(self,
                        owner: Hashable,
                        port: Hashable,
                        styles: Optional[Set[str]] = None) -> Port:
        """Make the associated elk port for the given owner node and port

        :param owner: [description]
        :type owner: Hashable
        :param port: [description]
        :type port: Hashable
        :param styles: list of css classes to apply to given ElkPort
        :type styles: List[str]
        :return: [description]
        :rtype: ElkPort
        """
        port_id = self.port_id(owner, port)
        properties = self.get_properties(port, styles) or None

        elk_port = ElkPort(
            id=port_id,
            height=self.port_scale,
            width=self.port_scale,
            properties=properties,
            # TODO labels
        )
        return Port(node=owner, elkport=elk_port)

    def get_node_data(self, node: Hashable) -> Dict:
        g, _ = self.source  # the hierarchy tree is not needed here
        return g.nodes.get(node, {})

    async def collect_ports(self, *nodes) -> PortMap:
        ports: PortMap = {}
        for node in nodes:
            values = self.get_node_data(node).get(self.port_key, [])
            for i, port in enumerate(values):
                if isinstance(port, ElkPort):
                    # generate a fresh copy of the port to prevent mutating original
                    port = port.to_dict()
                elif isinstance(port, str):
                    port_id = self.port_id(node, port)
                    port = {
                        "id": port_id,
                        "labels": [{
                            "text": port,
                            "id": f"{port}_label_{i}",
                        }],
                    }

                # whatever shape the port came in as, end up with an ElkPort
                elkport = ElkPort.from_dict(port) if isinstance(port, dict) else port

                if elkport.width is None:
                    elkport.width = self.port_scale
                if elkport.height is None:
                    elkport.height = self.port_scale
                ports[elkport.id] = Port(node=node, elkport=elkport)
        return ports

    async def make_labels(self, node: Hashable) -> Optional[List[ElkLabel]]:
        labels = []
        g = self.source[0]
        data = g.nodes.get(node, {})
        values = data.get(self.label_key, [data.get("_id", f"{node}")])

        properties = {}
        css_classes = self.get_css(node, ElkLabel)
        if css_classes:
            properties["cssClasses"] = " ".join(css_classes)

        if isinstance(values, (str, ElkLabel)):
            values = [values]
        # get node labels
        for i, label in enumerate(values):
            if isinstance(label, str):
                label = ElkLabel(
                    id=f"{label}_label_{i}_{node}",
                    text=label,
                )
            elif isinstance(label, ElkLabel):
                # prevent mutating original label in the node data
                label = label.to_dict()
            if isinstance(label, dict):
                label = ElkLabel.from_dict(label)

            # add css classes and layout options
            label.layoutOptions = merge(label.layoutOptions,
                                        self.get_layout(node, ElkLabel))
            merged_props = merge(label.properties, properties)
            if merged_props is not None:
                merged_props = ElkProperties.from_dict(merged_props)
            label.properties = merged_props

            labels.append(label)
            self.register(label, node)
        return labels

    def collect_edges(self) -> Tuple[EdgeMap, EdgeMap]:
        """Collect the graph's edges, bucketed by their closest visible owner.

        :return: mappings of visible and hidden edges, each indexed by the
            lowest common ancestor of the edge's endpoints
        :rtype: Tuple[
            Dict[Hashable, List[ElkExtendedEdge]],
            Dict[Hashable, List[ElkExtendedEdge]]
        ]
        """
        g, tree = self.source
        attr = self.HIDDEN_ATTR

        visible: EdgeMap = defaultdict(
            list)  # will index edges by nx.lowest_common_ancestor
        hidden: EdgeMap = defaultdict(
            list)  # will index edges by nx.lowest_common_ancestor

        closest_visible = map_visible(g, tree, attr)

        @lru_cache()
        def closest_common_visible(nodes: Tuple[Hashable]) -> Hashable:
            if tree is None:
                return ElkRoot
            result = lowest_common_ancestor(tree, nodes)
            return result

        for source, target, edge_data in g.edges(data=True):
            source_port, target_port = get_ports(edge_data)
            vis_source = closest_visible[source]
            vis_target = closest_visible[target]
            shidden = vis_source != source
            thidden = vis_target != target
            owner = closest_common_visible((vis_source, vis_target))
            if source == target and source == owner:
                if owner in tree:
                    for p in tree.predecessors(owner):
                        # reassign this edge's owner to its parent
                        owner = p

            if shidden or thidden:
                # create new slack ports if source or target is remapped
                if vis_source != source:
                    source_port = (source, source_port)
                if vis_target != target:
                    target_port = (target, target_port)

                if vis_source != vis_target:
                    hidden[owner].append(
                        Edge(
                            source=vis_source,
                            source_port=source_port,
                            target=vis_target,
                            target_port=target_port,
                            data=edge_data,
                            owner=owner,
                        ))
            else:
                visible[owner].append(
                    Edge(
                        source=source,
                        source_port=source_port,
                        target=target,
                        target_port=target_port,
                        data=edge_data,
                        owner=owner,
                    ))
        return visible, hidden
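
The bucketing above keys every edge on the lowest common ancestor of its (possibly remapped) endpoints in the hierarchy tree. A minimal standalone sketch of that owner lookup, using networkx directly (the toy tree and node names are illustrative, not part of the widget code):

import networkx as nx

# Toy hierarchy: root contains a and b; a contains a1 and a2.
tree = nx.DiGraph([("root", "a"), ("root", "b"), ("a", "a1"), ("a", "a2")])

# The deepest node containing both endpoints owns the edge between them.
print(nx.lowest_common_ancestor(tree, "a1", "a2"))  # a
print(nx.lowest_common_ancestor(tree, "a1", "b"))   # root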
Code example #12
class Compositor(Node):
    """Compositor
    
    Attributes
    ----------
    cache_native_coordinates : Bool
        Default is True. If native_coordinates are requested by the user, it may take a long time to calculate if the
        Compositor points to many sources. The result is relatively small and is cached by default. Caching may not be
        desired if the data sources change or are updated.
    interpolation : str, dict, optional
        {interpolation}
    is_source_coordinates_complete : Bool
        Default is False. The source_coordinates do not have to completely describe the source. For example, the source
        coordinates could include the year-month-day of the source, but the actual source also has hour-minute-second
        information. In that case, source_coordinates is incomplete. This flag is used to automatically construct
        native_coordinates.
    n_threads : int
        Default is 10 -- used when threaded is True.
        NASA data servers seem to have a hard limit of 10 simultaneous requests, which determined the default value.
    shared_coordinates : :class:`podpac.Coordinates`, optional
        Coordinates that are shared amongst all of the composited sources
    source : str
        The source is used for a unique name to cache composited products.
    source_coordinates : :class:`podpac.Coordinates`, optional
        Coordinates that make each source unique. This is used for subsetting which sources to evaluate based on the
        user-requested coordinates. It is an optimization.
    sources : :class:`np.ndarray`
        An array of sources. This is a numpy array as opposed to a list so that boolean indexing may be used to
        subselect the nodes that will be evaluated.
    threaded : bool, optional
        Default is False.
        When threaded is False, the compositor stops evaluating sources once the output is completely filled.
        When threaded is True, the compositor must evaluate every source.
        The result is the same, but note that because of this, threaded=False could be faster than threaded=True,
        especially if n_threads is low. For example, threaded with n_threads=1 could be much slower than non-threaded
        if the output is completely filled after the first few sources.
    
    Notes
    -----
    Developers of new Compositor nodes need to implement the `composite` method.
    """
    shared_coordinates = tl.Instance(Coordinates, allow_none=True)
    source_coordinates = tl.Instance(Coordinates, allow_none=True)
    is_source_coordinates_complete = tl.Bool(
        False,
        help=("This allows some optimizations but assumes that a node's "
              "native_coordinates=source_coordinate + shared_coordinate "
              "IN THAT ORDER"))

    source = tl.Unicode().tag(attr=True)
    sources = ArrayTrait(ndim=1)
    cache_native_coordinates = tl.Bool(True)

    interpolation = interpolation_trait(default_value=None)

    threaded = tl.Bool(False)
    n_threads = tl.Int(10)

    @tl.default('source')
    def _source_default(self):
        return '_'.join(str(s) for s in self.sources[:3])

    @tl.default('source_coordinates')
    def _source_coordinates_default(self):
        return self.get_source_coordinates()

    def get_source_coordinates(self):
        """
        Returns the coordinates describing each source.
        This may be implemented by derived classes, and is an optimization that allows evaluating subsets of the sources.
        
        Returns
        -------
        :class:`podpac.Coordinates`
            Coordinates describing each source.
        """
        return None

    @tl.default('shared_coordinates')
    def _shared_coordinates_default(self):
        return self.get_shared_coordinates()

    def get_shared_coordinates(self):
        """Coordinates shared by each source.
        
        Raises
        ------
        NotImplementedError
            Derived classes that use shared coordinates must implement this method.
        """
        raise NotImplementedError()

    def select_sources(self, coordinates):
        """Downselect compositor sources based on requested coordinates.
        
        This is used during the :meth:`eval` process as an optimization
        when :attr:`source_coordinates` are not pre-defined.
        
        Parameters
        ----------
        coordinates : :class:`podpac.Coordinates`
            Coordinates to evaluate at compositor sources
        
        Returns
        -------
        :class:`np.ndarray`
            Array of downselected sources
        """

        # if source coordinates are defined, use intersect
        if self.source_coordinates is not None:
            # intersecting sources only
            try:
                _, I = self.source_coordinates.intersect(coordinates,
                                                         outer=True,
                                                         return_indices=True)

            except Exception:  # likely non-monotonic coordinates
                _, I = self.source_coordinates.intersect(coordinates,
                                                         outer=False,
                                                         return_indices=True)

            src_subset = self.sources[I]

        # no downselection possible; use all compositor sources
        else:
            src_subset = self.sources

        return src_subset

    def composite(self, outputs, result=None):
        """Implements the rules for compositing multiple sources together.
        
        Parameters
        ----------
        outputs : list
            A list of outputs that need to be composited together
        result : UnitsDataArray, optional
            An optional pre-filled array may be supplied, otherwise the output will be allocated.
        
        Raises
        ------
        NotImplementedError
        """
        raise NotImplementedError()

    def iteroutputs(self, coordinates):
        """Summary
        
        Parameters
        ----------
        coordinates : :class:`podpac.Coordinates`
            Coordinates to evaluate at compositor sources
        
        Yields
        ------
        :class:`podpac.core.units.UnitsDataArray`
            Output from source node eval method
        """
        # downselect sources based on coordinates
        src_subset = self.select_sources(coordinates)

        if len(src_subset) == 0:
            yield self.create_output_array(coordinates)
            return

        # Set the interpolation properties for sources
        if self.interpolation is not None:
            for s in src_subset.ravel():
                if trait_is_defined(self, 'interpolation'):
                    s.interpolation = self.interpolation

        # Optimization: if coordinates complete and source coords is 1D,
        # set native_coordinates unless they are set already
        # WARNING: this assumes
        #              native_coords = source_coords + shared_coordinates
        #         NOT  native_coords = shared_coords + source_coords
        if self.is_source_coordinates_complete and self.source_coordinates.ndim == 1:
            coords_subset = list(
                self.source_coordinates.intersect(
                    coordinates, outer=True).coords.values())[0]
            coords_dim = list(self.source_coordinates.dims)[0]
            for s, c in zip(src_subset, coords_subset):
                nc = merge_dims([
                    Coordinates(np.atleast_1d(c), dims=[coords_dim]),
                    self.shared_coordinates
                ])

                if not trait_is_defined(s, 'native_coordinates'):
                    s.native_coordinates = nc

        if self.threaded:
            # TODO pool of pre-allocated scratch space
            # TODO: docstring?
            def f(src):
                return src.eval(coordinates)

            pool = ThreadPool(processes=self.n_threads)
            results = [pool.apply_async(f, [src]) for src in src_subset]

            for src, res in zip(src_subset, results):
                yield res.get()
                #src._output = None # free up memory

        else:
            output = None  # scratch space
            for src in src_subset:
                output = src.eval(coordinates, output)
                yield output
                #output[:] = np.nan

    @node_eval
    @common_doc(COMMON_COMPOSITOR_DOC)
    def eval(self, coordinates, output=None):
        """Evaluates this nodes using the supplied coordinates. 

        Parameters
        ----------
        coordinates : :class:`podpac.Coordinates`
            {requested_coordinates}
        output : podpac.UnitsDataArray, optional
            {eval_output}
            
        Returns
        -------
        {eval_return}
        """

        self._requested_coordinates = coordinates

        outputs = self.iteroutputs(coordinates)
        output = self.composite(outputs, output)
        return output

    def find_coordinates(self):
        """
        Get the available native coordinates for the Node.

        Returns
        -------
        coords_list : list
            list of available coordinates (Coordinate objects)
        """

        raise NotImplementedError("TODO")

    @property
    @common_doc(COMMON_COMPOSITOR_DOC)
    def base_definition(self):
        """Base node defintion for Compositor nodes. 
        
        Returns
        -------
        {definition_return}
        """
        d = super(Compositor, self).base_definition
        d['sources'] = self.sources
        d['interpolation'] = self.interpolation
        return d
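
As the Notes section says, subclasses supply the compositing rule by implementing `composite`. A hedged sketch of an ordered, first-non-NaN compositor built on that contract (the subclass name and fill rule are illustrative, not a podpac-provided class):

import numpy as np

class FirstValueCompositor(Compositor):
    """Illustrative subclass: keep the first non-NaN value seen for each cell."""

    def composite(self, outputs, result=None):
        for output in outputs:
            if result is None:
                result = output.copy()
            else:
                nan_mask = np.isnan(result.data)
                result.data[nan_mask] = output.data[nan_mask]
        return result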
Code example #13
File: repo_git.py Project: deathbeds/wxyz
class Git(Repo):
    """A git repository widget"""

    # pylint: disable=protected-access,too-many-instance-attributes
    _git = T.Instance(G.Repo, allow_none=True)
    _ref_watcher = T.Instance(Watcher, allow_none=True)

    def _initialize_watcher(self):
        """watch key folders in git"""
        self._watcher = Watcher(Path(self._git.git_dir) / "refs",
                                _watcher_cls=_GitRefWatcher)

        def _schedule(change=None):
            IOLoop.current().add_callback(self._on_ref_change, change)

        self._watcher.observe(_schedule, "changes")

        _schedule()

    async def _on_ref_change(self, _change=None):
        """recalculate key values when files in .git/refs folder change"""
        self._update_heads()
        self._update_head_history()
        for remote in self.remotes.values():
            await remote._update_heads()

    @property
    def _remote_cls(self):
        return GitRemote

    @T.observe("working_dir")
    def _on_path(self, change):
        """handle when the working directory changes"""
        if change.new:
            self._git = G.Repo.init(change.new)
            ignore = Path(change.new) / ".gitignore"
            if not ignore.exists():
                ignore.write_text(".ipynb_checkpoints/")
                self.commit("initial commit")

            self._initialize_watcher()
            self._update_head_history()

    @T.default("head")
    def _default_head(self):
        """get current head"""
        return self._git.active_branch.name

    @T.observe("head")
    def _on_head_changed(self, change):
        """react to the symbolic head name changing"""
        if change.new:
            self._update_head_history()

    def _update_head_history(self):
        """build a structure of history"""
        # pylint: disable=broad-except
        try:
            head = [h for h in self._git.heads if h.name == self.head][0]
            self.head_hash = head.commit.hexsha
            self.head_history = [{
                "commit": str(c.newhexsha),
                "timestamp": c.time[0],
                "message": c.message,
                "author": {
                    "name": c.actor.name,
                    "email": c.actor.email
                },
            } for c in head.log()[::-1]]
        except Exception as err:
            self.log.warning("Git head update error, ignoring: %s",
                             err,
                             exc_info=True)
            self.head_history = []

    def _update_heads(self):
        """refresh the heads"""
        self.heads = {
            head.name: head.commit.hexsha
            for head in self._git.heads
        }
        self._update_head_history()

    def _on_watch_changes(self, *changes):
        """overload of the base method to handle changes"""
        self.dirty = self._git.is_dirty()
        if self._watcher:
            for change in self._watcher.changes:
                for tracker in self._trackers:
                    tracked_path = Path(self._git.working_dir) / change["path"]
                    if tracker.path.resolve() == tracked_path.resolve():
                        tracker._on_file_change(None)
        return [
            dict(a_path=diff.a_path,
                 b_path=diff.b_path,
                 change_type=diff.change_type)
            for diff in self._git.index.diff(None)
        ] + [
            dict(a_path=None, b_path=ut, change_type="U")
            for ut in self._git.untracked_files
        ]

    def stage(self, path):
        """stage a single path to the index"""
        self._git.index.add(path)

    def unstage(self, path):
        """remove a path from the index"""
        self._git.index.remove(path)

    def commit(self, message):
        """create a commit"""
        self._git.index.commit(message)
        self._on_watching(None)

    def revert(self, ref):
        """restore to a committish"""
        self._git.head.commit = ref
        self._git.head.reset(index=True, working_tree=True)

    def branch(self, name, ref="HEAD"):
        """create and checkout a new branch"""
        self._git.create_head(name, ref)
        self.checkout(name)

    def checkout(self, name):
        """checkout a named reference"""
        head = [h for h in self._git.heads if h.name == name][0]
        head.checkout()
        self.head = head.name
        self._git.head.reset(index=True, working_tree=True)

    def merge(self, ref):
        """create a merge commit on the active branch with the given ref"""
        active = self._git.active_branch
        active_commit = self._git.active_branch.commit
        active_name = active.name
        merge_base = self._git.merge_base(active, ref)
        ref_commit = self._git.commit(ref)
        self._git.index.merge_tree(ref_commit, base=merge_base)
        merge_commit = self._git.index.commit(
            f"Merged {ref} into {active_name}",
            parent_commits=(active_commit, ref_commit),
        )
        self.log.error("MERGE %s", merge_commit)
        self._git.active_branch.reference = merge_commit
        active.checkout()
        self._git.head.reset(index=True, working_tree=True)

    def _update_remotes(self):
        """fetch some remotes"""
        remotes = {}
        for remote in self._git.remotes:
            remotes[remote.name] = self._remote_cls(name=remote.name,
                                                    url=remote.url,
                                                    _remote=remote)
        self.remotes = remotes
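
A hedged usage sketch for the widget above, assuming a running Jupyter kernel and that `working_dir` is a trait inherited from the Repo base class:

from pathlib import Path

repo = Git(working_dir="/tmp/demo-repo")  # also writes a .gitignore and makes an initial commit
Path("/tmp/demo-repo", "notes.txt").write_text("hello")
repo.stage("notes.txt")    # add the file to the index
repo.commit("add notes")   # commit and refresh watchers
repo.branch("feature")     # create and check out 'feature' from HEAD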
Code example #14
class WorkflowsLayer(ipyleaflet.TileLayer):
    """
    Subclass of ``ipyleaflet.TileLayer`` for displaying a Workflows `~.geospatial.Image`.

    Attributes
    ----------
    image: ~.geospatial.Image
        The `~.geospatial.Image` to use
    parameters: ParameterSet
        Parameters to use while computing; modify attributes under ``.parameters``
        (like ``layer.parameters.foo = "bar"``) to cause the layer to recompute
        and update under those new parameters.
    xyz_obj: ~.models.XYZ
        Read-only: The `XYZ` object this layer is displaying.
    session_id: str
        Read-only: Unique ID that error logs will be stored under, generated automatically.
    checkerboard: bool, default True
        Whether to display a checkerboarded background for missing or masked data.
    colormap: str, optional, default None
        Name of the colormap to use.
        If set, `image` must have 1 band.
    r_min: float, optional, default None
        Min value for scaling the red band. Along with r_max,
        controls scaling when a colormap is enabled.
    r_max: float, optional, default None
        Max value for scaling the red band. Along with r_min, controls scaling
        when a colormap is enabled.
    g_min: float, optional, default None
        Min value for scaling the green band.
    g_max: float, optional, default None
        Max value for scaling the green band.
    b_min: float, optional, default None
        Min value for scaling the blue band.
    b_max: float, optional, default None
        Max value for scaling the blue band.
    error_output: ipywidgets.Output, optional, default None
        If set, write unique errors from tile computation to this output area
        from a background thread. Setting to None stops the listener thread.

    Example
    -------
    >>> import descarteslabs.workflows as wf
    >>> wf.map # doctest: +SKIP
    >>> # ^ display interactive map
    >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1").pick_bands("red")
    >>> masked_img = img.mask(img > wf.parameter("threshold", wf.Float))
    >>> layer = masked_img.visualize("sample", colormap="viridis", threshold=0.07) # doctest: +SKIP
    >>> layer.colormap = "plasma" # doctest: +SKIP
    >>> # ^ change colormap (this will update the layer on the map)
    >>> layer.parameters.threshold = 0.13 # doctest: +SKIP
    >>> # ^ adjust parameters (this also updates the layer)
    >>> layer.set_scales((0.01, 0.3)) # doctest: +SKIP
    >>> # ^ adjust scaling (this also updates the layer)
    """

    attribution = traitlets.Unicode("Descartes Labs").tag(sync=True, o=True)
    min_zoom = traitlets.Int(5).tag(sync=True, o=True)
    url = traitlets.Unicode(read_only=True).tag(sync=True)

    image = traitlets.Instance(Image)
    parameters = traitlets.Instance(parameters.ParameterSet, allow_none=True)
    xyz_obj = traitlets.Instance(XYZ, read_only=True)
    session_id = traitlets.Unicode(read_only=True)

    checkerboard = traitlets.Bool(True)
    colormap = traitlets.Unicode(None, allow_none=True)

    r_min = ScaleFloat(None, allow_none=True)
    r_max = ScaleFloat(None, allow_none=True)
    g_min = ScaleFloat(None, allow_none=True)
    g_max = ScaleFloat(None, allow_none=True)
    b_min = ScaleFloat(None, allow_none=True)
    b_max = ScaleFloat(None, allow_none=True)

    error_output = traitlets.Instance(widgets.Output, allow_none=True)
    autoscale_progress = traitlets.Instance(ClearableOutput)

    def __init__(self, image, *args, **kwargs):
        params = kwargs.pop("parameters", {})
        super(WorkflowsLayer, self).__init__(*args, **kwargs)

        with self.hold_trait_notifications():
            self.image = image
            self.set_trait("session_id", uuid.uuid4().hex)
            self.set_trait(
                "autoscale_progress",
                ClearableOutput(
                    widgets.Output(),
                    layout=widgets.Layout(max_height="10rem", flex="1 0 auto"),
                ),
            )
            self.set_parameters(**params)

        self._error_listener = None
        self._known_errors = set()
        self._known_errors_lock = threading.Lock()

    def make_url(self):
        """
        Generate the URL for this layer.

        This is called automatically as the attributes (`image`, `colormap`, scales, etc.) are changed.

        Example
        -------
        >>> import descarteslabs.workflows as wf
        >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP
        >>> img = img.pick_bands("red blue green") # doctest: +SKIP
        >>> layer = img.visualize("sample") # doctest: +SKIP
        >>> layer.make_url() # doctest: +SKIP
        'https://workflows.descarteslabs.com/master/xyz/9ec70d0e99db7f50c856c774809ae454ffd8475816e05c5c/{z}/{x}/{y}.png?session_id=xxx&checkerboard=true'
        """
        if not self.visible:
            # workaround for the fact that Leaflet still loads tiles from inactive layers,
            # which is expensive computation users don't want
            return ""

        if self.colormap is not None:
            scales = [[self.r_min, self.r_max]]
        else:
            scales = [
                [self.r_min, self.r_max],
                [self.g_min, self.g_max],
                [self.b_min, self.b_max],
            ]

        scales = [scale for scale in scales if scale != [None, None]]

        parameters = self.parameters.to_dict()

        return self.xyz_obj.url(session_id=self.session_id,
                                colormap=self.colormap,
                                scales=scales,
                                checkerboard=self.checkerboard,
                                **parameters)

    @traitlets.observe("image")
    def _update_xyz(self, change):
        old, new = change["old"], change["new"]
        if old is new:
            # traitlets decides whether a value changed via an == check, but == on an
            # Image returns another Image (which is truthy), so traitlets would always
            # consider it changed; compare identity instead.
            return

        xyz = XYZ.build(new, name=self.name)
        xyz.save()
        self.set_trait("xyz_obj", xyz)

    @traitlets.observe(
        "visible",
        "checkerboard",
        "colormap",
        "r_min",
        "r_max",
        "g_min",
        "g_max",
        "b_min",
        "b_max",
        "xyz_obj",
        "session_id",
        "parameters",
    )
    @traitlets.observe("parameters", type="delete")
    def _update_url(self, change):
        try:
            self.set_trait("url", self.make_url())
        except ValueError as e:
            if "Invalid scales passed" not in str(e):
                raise e

    @traitlets.observe("parameters", type="delete")
    def _update_url_on_param_delete(self, change):
        # stacking a second @observe on _update_url does not work in traitlets,
        # so the handler is repeated here for the delete event
        try:
            self.set_trait("url", self.make_url())
        except ValueError as e:
            if "Invalid scales passed" not in str(e):
                raise e

    @traitlets.observe("xyz_obj", "session_id")
    def _update_error_logger(self, change):
        if self.error_output is None:
            return

        # Remove old errors for the layer
        self.forget_errors()
        new_errors = []
        for error in self.error_output.outputs:
            if not error["text"].startswith(self.name + ": "):
                new_errors.append(error)
        self.error_output.outputs = tuple(new_errors)

        if self._error_listener is not None:
            self._error_listener.stop(timeout=1)

        listener = self.xyz_obj.error_listener()
        listener.add_callback(self._log_errors_callback)
        listener.listen(self.session_id,
                        datetime.datetime.now(datetime.timezone.utc))

        self._error_listener = listener

    def _stop_error_logger(self):
        if self._error_listener is not None:
            self._error_listener.stop(timeout=1)
            self._error_listener = None

    @traitlets.observe("error_output")
    def _toggle_error_listener_if_output(self, change):
        if change["new"] is None:
            self._stop_error_logger()
        else:
            if self._error_listener is None:
                self._update_error_logger({})

    def _log_errors_callback(self, msg):
        message = msg.message

        with self._known_errors_lock:
            if message in self._known_errors:
                return
            else:
                self._known_errors.add(message)

        error = "{}: {}\n".format(self.name, message)
        self.error_output.append_stdout(error)

    def __del__(self):
        self._stop_error_logger()
        super(WorkflowsLayer, self).__del__()

    def forget_errors(self):
        """
        Clear the set of known errors, so they are re-displayed if they occur again

        Example
        -------
        >>> import descarteslabs.workflows as wf
        >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP
        >>> wf.map # doctest: +SKIP
        >>> layer = img.visualize("sample visualization") # doctest: +SKIP
        >>> # ^ will show an error for attempting to visualize more than 3 bands
        >>> layer.forget_errors() # doctest: +SKIP
        >>> wf.map.zoom = 10 # doctest: +SKIP
        >>> # ^ attempting to load more tiles from img will cause the same error to appear
        """
        with self._known_errors_lock:
            self._known_errors.clear()

    def set_scales(self, scales, new_colormap=False):
        """
        Update the scales for this layer by giving a list of scales

        Parameters
        ----------
        scales: list of lists, default None
            The scaling to apply to each band in the `Image`.

            If `Image` contains 3 bands, ``scales`` must be a list like ``[(0, 1), (0, 1), (-1, 1)]``.

            If `Image` contains 1 band, ``scales`` must be a list like ``[(0, 1)]``,
            or just ``(0, 1)`` for convenience

            If None, each 256x256 tile will be scaled independently
            based on the min and max values of its data.
        new_colormap: str, None, or False, optional, default False
            A new colormap to set at the same time, or False to use the current colormap.

        Example
        -------
        >>> import descarteslabs.workflows as wf
        >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP
        >>> img = img.pick_bands("red") # doctest: +SKIP
        >>> layer = img.visualize("sample visualization", colormap="viridis") # doctest: +SKIP
        >>> layer.set_scales((0.08, 0.3), new_colormap="plasma") # doctest: +SKIP
        >>> # ^ optionally set new colormap
        """
        colormap = self.colormap if new_colormap is False else new_colormap

        if scales is not None:
            scales = XYZ._validate_scales(scales)

            scales_len = 1 if colormap is not None else 3
            if len(scales) != scales_len:
                msg = "Expected {} scales, but got {}.".format(
                    scales_len, len(scales))
                if len(scales) in (1, 2):
                    msg += " If displaying a 1-band Image, use a colormap."
                elif colormap:
                    msg += " Colormaps cannot be used with multi-band images."

                raise ValueError(msg)

            with self.hold_trait_notifications():
                if colormap is None:
                    self.r_min = scales[0][0]
                    self.r_max = scales[0][1]
                    self.g_min = scales[1][0]
                    self.g_max = scales[1][1]
                    self.b_min = scales[2][0]
                    self.b_max = scales[2][1]
                else:
                    self.r_min = scales[0][0]
                    self.r_max = scales[0][1]
                if new_colormap is not False:
                    self.colormap = new_colormap
        else:
            # scales is None
            with self.hold_trait_notifications():
                if colormap is None:
                    self.r_min = None
                    self.r_max = None
                    self.g_min = None
                    self.g_max = None
                    self.b_min = None
                    self.b_max = None
                else:
                    self.r_min = None
                    self.r_max = None
                if new_colormap is not False:
                    self.colormap = new_colormap

    def set_parameters(self, **params):
        """
        Set new parameters for this `WorkflowsLayer`.

        In typical cases, you update parameters by assigning to `parameters`
        (like ``layer.parameters.threshold = 6.6``).

        Instead, use this function when you need to change the *names or types*
        of parameters available on the `WorkflowsLayer`. (Users shouldn't need to
        do this, as `~.Image.visualize` handles it for you, but custom widget developers
        may need to use this method when they change the `image` field on a `WorkflowsLayer`.)

        If a value is an ipywidgets Widget, it will be linked to that parameter
        (via its ``"value"`` attribute). If a parameter was previously set with
        a widget, and a different widget instance (or non-widget) is passed
        for its new value, the old widget is automatically unlinked.
        If the same widget instance is passed as is already linked, no change occurs.

        Parameters
        ----------
        params: JSON-serializable value, Proxytype, or ipywidgets.Widget
            Parameter names mapped to new values. Values can be Python types,
            `Proxytype` instances, or ``ipywidgets.Widget`` instances.

        Example
        -------

        >>> import descarteslabs.workflows as wf
        >>> from ipywidgets import FloatSlider
        >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP
        >>> img = img.pick_bands("red") # doctest: +SKIP
        >>> masked_img = img.mask(img > wf.parameter("threshold", wf.Float)) # doctest: +SKIP
        >>> layer = masked_img.tile_layer("sample", colormap="plasma", threshold=0.07) # doctest: +SKIP
        >>> scaled_img = img * wf.parameter("scale", wf.Float) + wf.parameter("offset", wf.Float) # doctest: +SKIP
        >>> with layer.hold_trait_notifications(): # doctest: +SKIP
        ...     layer.image = scaled_img # doctest: +SKIP
        ...     layer.set_parameters(scale=FloatSlider(min=0, max=10, value=2), offset=2.5) # doctest: +SKIP
        >>> # ^ re-use the same layer instance for a new Image with different parameters
        """
        param_set = self.parameters
        if param_set is None:
            param_set = self.parameters = parameters.ParameterSet(
                self, "parameters")

        with self.hold_trait_notifications():
            param_set.update(**params)

    def _ipython_display_(self):
        param_set = self.parameters
        if param_set:
            widget = param_set.widget
            if widget and len(widget.children) > 0:
                widget._ipython_display_()
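
For reference, `make_url` drops any band whose bounds are both unset, which is how a partially scaled image still yields a valid tile URL; a minimal illustration of that filter:

scales = [[0.0, 0.3], [None, None], [None, None]]
scales = [scale for scale in scales if scale != [None, None]]
print(scales)  # [[0.0, 0.3]]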
Code example #15
File: volume.py Project: esizikova/ipyvolume
class Figure(widgets.DOMWidget):
    """Widget class representing a volume (rendering) using three.js"""
    _view_name = Unicode('FigureView').tag(sync=True)
    _view_module = Unicode('ipyvolume').tag(sync=True)
    _model_name = Unicode('FigureModel').tag(sync=True)
    _model_module = Unicode('ipyvolume').tag(sync=True)
    _view_module_version = Unicode(semver_range_frontend).tag(sync=True)
    _model_module_version = Unicode(semver_range_frontend).tag(sync=True)

    volume_data = Array(default_value=None,
                        allow_none=True).tag(sync=True,
                                             **array_cube_png_serialization)
    eye_separation = traitlets.CFloat(6.4).tag(sync=True)
    data_min = traitlets.CFloat().tag(sync=True)
    data_max = traitlets.CFloat().tag(sync=True)
    tf = traitlets.Instance(TransferFunction, allow_none=True).tag(
        sync=True, **ipywidgets.widget_serialization)
    anglex = traitlets.Float(0.0).tag(sync=True)
    angley = traitlets.Float(0.0).tag(sync=True)
    anglez = traitlets.Float(0.0).tag(sync=True)
    angle_order = Unicode(default_value="XYZ").tag(sync=True)

    scatters = traitlets.List(traitlets.Instance(Scatter), [],
                              allow_none=False).tag(
                                  sync=True, **ipywidgets.widget_serialization)
    meshes = traitlets.List(traitlets.Instance(Mesh), [],
                            allow_none=False).tag(
                                sync=True, **ipywidgets.widget_serialization)

    animation = traitlets.Float(1000.0).tag(sync=True)
    animation_exponent = traitlets.Float(.5).tag(sync=True)

    ambient_coefficient = traitlets.Float(0.5).tag(sync=True)
    diffuse_coefficient = traitlets.Float(0.8).tag(sync=True)
    specular_coefficient = traitlets.Float(0.5).tag(sync=True)
    specular_exponent = traitlets.Float(5).tag(sync=True)
    stereo = traitlets.Bool(False).tag(sync=True)
    screen_capture_enabled = traitlets.Bool(False).tag(sync=True)
    screen_capture_mime_type = traitlets.Unicode(
        default_value='image/png').tag(sync=True)
    screen_capture_data = traitlets.Unicode(default_value=None,
                                            allow_none=True).tag(sync=True)
    fullscreen = traitlets.Bool(False).tag(sync=True)

    camera_control = traitlets.Unicode(default_value='trackball').tag(
        sync=True)

    width = traitlets.CInt(500).tag(sync=True)
    height = traitlets.CInt(400).tag(sync=True)
    downscale = traitlets.CInt(1).tag(sync=True)
    show = traitlets.Unicode("Volume").tag(sync=True)  # for debugging

    xlim = traitlets.List(traitlets.CFloat,
                          default_value=[0, 1],
                          minlen=2,
                          maxlen=2).tag(sync=True)
    ylim = traitlets.List(traitlets.CFloat,
                          default_value=[0, 1],
                          minlen=2,
                          maxlen=2).tag(sync=True)
    zlim = traitlets.List(traitlets.CFloat,
                          default_value=[0, 1],
                          minlen=2,
                          maxlen=2).tag(sync=True)

    xlabel = traitlets.Unicode("x").tag(sync=True)
    ylabel = traitlets.Unicode("y").tag(sync=True)
    zlabel = traitlets.Unicode("z").tag(sync=True)

    style = traitlets.Dict(default_value=ipyvolume.style.default).tag(
        sync=True)

    #xlim = traitlets.Tuple(traitlets.CFloat(0), traitlets.CFloat(1)).tag(sync=True)
    #ylim = traitlets.Tuple(traitlets.CFloat(0), traitlets.CFloat(1)).tag(sync=True)
    #zlim = traitlets.Tuple(traitlets.CFloat(0), traitlets.CFloat(1)).tag(sync=True)

    def __init__(self, **kwargs):
        super(Figure, self).__init__(**kwargs)
        self._screenshot_handlers = widgets.CallbackDispatcher()
        self.on_msg(self._handle_custom_msg)

    def screenshot(self):
        self.send({'msg': 'screenshot'})

    def on_screenshot(self, callback, remove=False):
        self._screenshot_handlers.register_callback(callback, remove=remove)

    def _handle_custom_msg(self, content, buffers):
        print("msg", content)
        if content.get('event', '') == 'screenshot':
            self._screenshot_handlers(content['data'])
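
A hedged sketch of the screenshot round trip, assuming a live frontend connected to the kernel (the callback body is illustrative):

fig = Figure()

def on_shot(data):
    # `data` is the payload the frontend sends back with the screenshot event
    print("received screenshot payload of length", len(data))

fig.on_screenshot(on_shot)
fig.screenshot()  # asks the frontend to capture and send back the image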
Code example #16
File: model.py Project: sharadMalmanchi/vaex
class GridCalculator(_HasState):
    '''A grid is responsible for scheduling the grid calculations and possible slicing'''
    class Status(enum.Enum):
        VOID = 1
        STAGED_CALCULATION = 3
        CALCULATING = 4
        READY = 9

    status = traitlets.UseEnum(Status, Status.VOID)
    df = traitlets.Instance(vaex.dataframe.DataFrame)
    models = traitlets.List(traitlets.Instance(DataArray))
    calculation = traitlets.Any(None, allow_none=True)
    _debug = traitlets.Bool(False)

    def __init__(self, df, models):
        super().__init__(df=df, models=[])
        self._callbacks_regrid = []
        self._callbacks_slice = []
        for model in models:
            self.model_add(model)
        self._testing_exeception_regrid = False  # used for testing, to throw an exception
        self._testing_exeception_reslice = False  # used for testing, to throw an exception

    # def model_remove(self, model, regrid=True):
    #     index = self.models.index(model)
    #     del self.models[index]
    #     del self._callbacks_regrid[index]
    #     del self._callbacks_slice[index]

    def model_add(self, model):
        self.models = self.models + [model]
        if model.status == DataArray.Status.NEEDS_CALCULATING_GRID:
            if self.calculation is not None:
                self._cancel_computation()
            self.computation()

        def on_status_changed(change):
            if change.owner.status == DataArray.Status.NEEDS_CALCULATING_GRID:
                if self.calculation is not None:
                    self._cancel_computation()
                self.computation()

        model.observe(on_status_changed, 'status')
        # TODO: if we listen to the same axis twice it will trigger twice
        for axis in model.axes:
            axis.observe(lambda change: self.reslice(), 'slice')
        # self._callbacks_regrid.append(model.signal_regrid.connect(self.on_regrid))
        # self._callbacks_slice.append(model.signal_slice.connect(self.reslice))
        assert model.df == self.df

    # @vaex.jupyter.debounced(delay_seconds=0.05, reentrant=False)
    # def reslice_debounced(self):
    #     self.reslice()

    def reslice(self, source_model=None):
        if self._testing_exeception_reslice:
            raise RuntimeError("test:reslice")
        coords = []
        selections = self.models[0].selections
        selections = [
            k for k in selections if k is None or self.df.has_selection(k)
        ]
        for model in self.models:
            subgrid = self.grid
            subgrid_sliced = self.grid
            axis_index = 1
            has_slice = False
            dims = ["selection"]
            coords = [selections.copy()]
            mins = []
            maxs = []
            for other_model in self.models:
                if other_model == model:  # simply skip these axes
                    # for expression, shape, limit, slice_index in other_model.bin_parameters():
                    for axis in other_model.axes:
                        axis_index += 1
                        dims.append(str(axis.expression))
                        coords.append(axis.centers)
                        mins.append(axis.min)
                        maxs.append(axis.max)
                else:
                    # for expression, shape, limit, slice_index in other_model.bin_parameters():
                    for axis in other_model.axes:
                        if axis.slice is not None:
                            subgrid_sliced = subgrid_sliced.__getitem__(
                                tuple([slice(None)] * axis_index +
                                      [axis.slice])).copy()
                            subgrid = np.sum(subgrid, axis=axis_index)
                            has_slice = True
                        else:
                            subgrid_sliced = np.sum(subgrid_sliced,
                                                    axis=axis_index)
                            subgrid = np.sum(subgrid, axis=axis_index)
            grid = xarray.DataArray(subgrid, dims=dims, coords=coords)
            for i, (vmin, vmax) in enumerate(zip(mins, maxs)):
                # +1 to skip the selection axis
                grid.coords[dims[i + 1]].attrs['min'] = vmin
                grid.coords[dims[i + 1]].attrs['max'] = vmax
            model.grid = grid
            if has_slice:
                model.grid_sliced = xarray.DataArray(subgrid_sliced)
            else:
                model.grid_sliced = None

    def _regrid_error(self, e):
        try:
            self._error(e)
            for model in self.models:
                model._error(e)
            for model in self.models:
                model.exception = e
                model.status = vaex.jupyter.model.DataArray.Status.EXCEPTION
        except Exception as e2:
            print(e2)

    def on_regrid(self, ignore=None):
        self.regrid()

    @vaex.jupyter.debounced(delay_seconds=0.5,
                            reentrant=False,
                            on_error=_regrid_error)
    async def computation(self):
        try:
            logger.debug('Starting grid computation')
            # vaex.utils.print_stack_trace()
            if self._testing_exeception_regrid:
                raise RuntimeError("test:regrid")
            if not self.models:
                return
            binby = []
            shapes = []
            limits = []
            selections = self.models[0].selections
            for model in self.models:
                if model.selections != selections:
                    raise ValueError(
                        'Selections for all models should be the same')
                for axis in model.axes:
                    binby.append(axis.expression)
                    limits.append([axis.min, axis.max])
                    shapes.append(axis.shape or axis.shape_default)
            selections = [
                k for k in selections if k is None or self.df.has_selection(k)
            ]

            self._continue_calculation = True
            logger.debug('Setting up grid computation...')
            self.calculation = self.df.count(binby=binby,
                                             shape=shapes,
                                             limits=limits,
                                             selection=selections,
                                             progress=self.progress,
                                             delay=True)

            logger.debug('Setting up grid computation done tasks=%r',
                         self.df.executor.tasks)

            logger.debug('Schedule debounced execute')
            self.df.widget.execute_debounced()
            # keep a local reference to this, since awaits (which trigger the execution
            # AND the reset of this future) may replace it
            execute_prehook_future = self.df.widget.execute_debounced.pre_hook_future

            async with contextlib.AsyncExitStack() as stack:
                for model in self.models:
                    await stack.enter_async_context(
                        model._state_change_to(
                            DataArray.Status.STAGED_CALCULATING_GRID))
            async with contextlib.AsyncExitStack() as stack:
                for model in self.models:
                    await stack.enter_async_context(
                        model._state_change_to(
                            DataArray.Status.CALCULATING_GRID))
                await execute_prehook_future
            async with contextlib.AsyncExitStack() as stack:
                for model in self.models:
                    await stack.enter_async_context(
                        model._state_change_to(
                            DataArray.Status.CALCULATED_GRID))
                # first assign to local
                grid = await self.calculation
                # indicate we are done with the calculation
                self.calculation = None
                # raise asyncio.CancelledError("User abort")
            async with contextlib.AsyncExitStack() as stack:
                for model in self.models:
                    await stack.enter_async_context(
                        model._state_change_to(DataArray.Status.READY))
                self.grid = grid
                self.reslice()
        except vaex.execution.UserAbort:
            pass  # a user changed the limits or expressions
        except asyncio.CancelledError:
            pass  # cancelled...

    def _cancel_computation(self):
        logger.debug('Cancelling grid computation')
        self._continue_calculation = False

    def progress(self, f):
        return self._continue_calculation and all(
            [model.on_progress_grid(f) for model in self.models])
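
In `reslice`, each model keeps its own axes while every other model's axes are summed away (axis 0 is the shared selection axis). A small numpy-only illustration of that reduction, with made-up shapes:

import numpy as np

grid = np.arange(2 * 4 * 3).reshape(2, 4, 3)  # (selection, model A axis, model B axis)
sub_a = grid.sum(axis=2)  # model A's view: collapse model B's axis -> shape (2, 4)
sub_b = grid.sum(axis=1)  # model B's view: collapse model A's axis -> shape (2, 3)
print(sub_a.shape, sub_b.shape)  # (2, 4) (2, 3)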
Code example #17
class StructureViewer(ipw.VBox):
    """NGL structure viewer including download button"""

    structure = traitlets.Instance(Structure, allow_none=True)

    def __init__(self, **kwargs):
        self._current_view = None

        self.viewer = nglview.NGLWidget()
        self.viewer.camera = "orthographic"
        self.viewer.stage.set_parameters(mouse_preset="pymol")
        self.viewer_box = ipw.Box(
            children=(self.viewer, ),
            layout={
                "width": "auto",
                "height": "auto",
                "border": "solid 0.5px darkgrey",
                "margin": "0px",
                "padding": "0.5px",
            },
        )

        self.download = DownloadChooser(**kwargs)

        super().__init__(
            children=(self.viewer_box, self.download),
            layout={
                "width": "auto",
                "height": "auto",
                "margin": "0px 0px 0px 0px",
                "padding": "0px 0px 10px 0px",
            },
        )

        self.observe(self._on_change_structure, names="structure")

        traitlets.dlink((self, "structure"), (self.download, "structure"))

    def _on_change_structure(self, change):
        """Update viewer for new structure"""
        self.reset()
        self._current_view = self.viewer.add_structure(
            nglview.TextStructure(change["new"].as_pdb))
        self.viewer.add_representation("ball+stick", aspectRatio=4)
        self.viewer.add_representation("unitcell")

    def freeze(self):
        """Disable widget"""
        self.download.freeze()

    def unfreeze(self):
        """Activate widget (in its current state)"""
        self.download.unfreeze()

    def reset(self):
        """Reset widget"""
        self.download.reset()
        if self._current_view is not None:
            self.viewer.clear()
            self.viewer.remove_component(self._current_view)
            self._current_view = None
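
A hedged usage sketch: assigning to the `structure` trait re-renders the NGL view and, through the dlink, forwards the same structure to the download chooser (`my_structure` stands in for a real Structure instance and is left commented out):

viewer = StructureViewer()
# viewer.structure = my_structure  # re-renders the view and updates the download widget
viewer.freeze()    # disable the download button, e.g. while a query is running
viewer.unfreeze()  # re-enable it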
Code example #18
File: model.py Project: sharadMalmanchi/vaex
class Axis(_HasState):
    class Status(enum.Enum):
        """
        State transitions
        NO_LIMITS -> STAGED_CALCULATING_LIMITS -> CALCULATING_LIMITS -> CALCULATED_LIMITS -> READY

        when expression changes:
            STAGED_CALCULATING_LIMITS: 
                calculation.cancel()
                ->NO_LIMITS
            CALCULATING_LIMITS: 
                calculation.cancel()
                ->NO_LIMITS

        when min/max changes:
            STAGED_CALCULATING_LIMITS: 
                calculation.cancel()
                ->NO_LIMITS
            CALCULATING_LIMITS: 
                calculation.cancel()
                ->NO_LIMITS
        """
        NO_LIMITS = 1
        STAGED_CALCULATING_LIMITS = 2
        CALCULATING_LIMITS = 3
        CALCULATED_LIMITS = 4
        READY = 5
        EXCEPTION = 6
        ABORTED = 7

    status = traitlets.UseEnum(Status, Status.NO_LIMITS)
    df = traitlets.Instance(vaex.dataframe.DataFrame)
    expression = Expression()
    slice = traitlets.CInt(None, allow_none=True)
    min = traitlets.CFloat(None, allow_none=True)
    max = traitlets.CFloat(None, allow_none=True)
    centers = traitlets.Any()
    shape = traitlets.CInt(None, allow_none=True)
    shape_default = traitlets.CInt(64)
    calculation = traitlets.Any(None, allow_none=True)
    exception = traitlets.Any(None, allow_none=True)
    _status_change_delay = traitlets.Float(0)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if self.min is not None and self.max is not None:
            self.status = Axis.Status.READY
            self._calculate_centers()
        else:
            self.computation()
        self.observe(self.on_change_expression, 'expression')
        self.observe(self.on_change_shape, 'shape')
        self.observe(self.on_change_shape_default, 'shape_default')

    def __repr__(self):
        def myrepr(value, key):
            if isinstance(value, vaex.expression.Expression):
                return str(value)
            return value

        args = ', '.join('{}={}'.format(key, myrepr(getattr(self, key), key))
                         for key in self.traits().keys() if key != 'df')
        return '{}({})'.format(self.__class__.__name__, args)

    @property
    def has_missing_limit(self):
        # return not self.df.is_category(self.expression) and (self.min is None or self.max is None)
        return (self.min is None or self.max is None)

    def on_change_expression(self, change):
        self.min = None
        self.max = None
        self.status = Axis.Status.NO_LIMITS
        if self.calculation is not None:
            self._cancel_computation()
        self.computation()

    def on_change_shape(self, change):
        if self.min is not None and self.max is not None:
            self._calculate_centers()

    def on_change_shape_default(self, change):
        if self.min is not None and self.max is not None:
            self._calculate_centers()

    def _cancel_computation(self):
        self._continue_calculation = False

    @traitlets.observe('min', 'max')
    def on_change_limits(self, change):
        if self.min is not None and self.max is not None:
            self._calculate_centers()
        if self.status == Axis.Status.NO_LIMITS:
            if self.min is not None and self.max is not None:
                self.status = Axis.Status.READY
        elif self.status == Axis.Status.READY:
            if self.min is None or self.max is None:
                self.status = Axis.Status.NO_LIMITS
            else:
                # in this case, grids may want to be computed
                # this happens when a user change min/max
                pass
        else:
            if self.calculation is not None:
                self._cancel_computation()
                if self.min is not None and self.max is not None:
                    self.status = Axis.Status.READY
                else:
                    self.status = Axis.Status.NO_LIMITS
            else:
                # in this case we've set min/max after the calculation
                assert self.min is not None or self.max is not None

    @vaex.jupyter.debounced(delay_seconds=0.1,
                            reentrant=False,
                            on_error=_HasState._error)
    async def computation(self):
        categorical = self.df.is_category(self.expression)
        if categorical:
            N = self.df.category_count(self.expression)
            self.min, self.max = -0.5, N - 0.5
            # centers = np.arange(N)
            # self.shape = N
            self._calculate_centers()
            self.status = Axis.Status.READY
        else:
            try:

                self._continue_calculation = True
                self.calculation = self.df.minmax(self.expression,
                                                  delay=True,
                                                  progress=self._progress)
                self.df.widget.execute_debounced()
                # keep a local reference to this future, since awaits below (which
                # trigger the execution AND a reset of this future) may replace it
                execute_prehook_future = self.df.widget.execute_debounced.pre_hook_future
                async with self._state_change_to(
                        Axis.Status.STAGED_CALCULATING_LIMITS):
                    pass
                async with self._state_change_to(
                        Axis.Status.CALCULATING_LIMITS):
                    await execute_prehook_future
                async with self._state_change_to(
                        Axis.Status.CALCULATED_LIMITS):
                    vmin, vmax = await self.calculation
                # indicate we are done with the calculation
                self.calculation = None
                if not self._continue_calculation:
                    assert self.status == Axis.Status.READY
                async with self._state_change_to(Axis.Status.READY):
                    self.min, self.max = vmin, vmax
                    self._calculate_centers()
            except vaex.execution.UserAbort:
                # expression or min/max change, we don't have to take action
                assert self.status in [
                    Axis.Status.NO_LIMITS, Axis.Status.READY
                ]
            except asyncio.CancelledError:
                pass

    def _progress(self, f):
        # we use the progress callback to cancel the calculation
        return self._continue_calculation

    def _calculate_centers(self):
        categorical = self.df.is_category(self.expression)
        if categorical:
            N = self.df.category_count(self.expression)
            centers = np.arange(N)
            self.shape = N
        else:
            centers = self.df.bin_centers(self.expression,
                                          [self.min, self.max],
                                          shape=self.shape
                                          or self.shape_default)
        self.centers = centers
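
A minimal usage sketch of this Axis model, for context. This is a sketch under assumptions: it presumes a Jupyter notebook with vaex's jupyter extras installed, and uses `vaex.example()` only as a convenient demo DataFrame.

# Sketch only: exercising the Axis state machine (assumes vaex + jupyter extras).
import vaex
import vaex.jupyter.model

df = vaex.example()

# With explicit limits, the axis is READY immediately and bin centers are
# computed synchronously in __init__.
axis = vaex.jupyter.model.Axis(df=df, expression=df.x, min=-10, max=10)
assert axis.status == vaex.jupyter.model.Axis.Status.READY

# Without limits, the debounced `computation` coroutine is scheduled and the
# status should walk NO_LIMITS -> STAGED_CALCULATING_LIMITS ->
# CALCULATING_LIMITS -> CALCULATED_LIMITS -> READY.
axis_auto = vaex.jupyter.model.Axis(df=df, expression=df.y)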
Code example #19
class ProviderImplementationChooser(  # pylint: disable=too-many-instance-attributes
        ipw.VBox):
    """List all OPTIMADE providers and their implementations"""

    provider = traitlets.Instance(LinksResourceAttributes, allow_none=True)
    database = traitlets.Tuple(
        traitlets.Unicode(),
        traitlets.Instance(LinksResourceAttributes, allow_none=True),
        default_value=("", None),
    )

    HINT = {"provider": "Select a provider", "child_dbs": "Select a database"}
    INITIAL_CHILD_DBS = [("", (("No provider chosen", None), ))]

    def __init__(
        self,
        child_db_limit: int = None,
        disable_providers: List[str] = None,
        skip_providers: List[str] = None,
        skip_databases: Dict[str, List[str]] = None,
        provider_database_groupings: Dict[str, Dict[str, List[str]]] = None,
        **kwargs,
    ):
        self.child_db_limit = (child_db_limit if child_db_limit
                               and child_db_limit > 0 else 10)
        self.skip_child_dbs = skip_databases or {}
        self.child_db_groupings = (provider_database_groupings
                                   if provider_database_groupings is not None
                                   else PROVIDER_DATABASE_GROUPINGS)
        self.offset = 0
        self.number = 1
        self.__perform_query = True
        self.__cached_child_dbs = {}

        self.debug = bool(os.environ.get("OPTIMADE_CLIENT_DEBUG", None))

        providers, invalid_providers = get_list_of_valid_providers(
            disable_providers=disable_providers, skip_providers=skip_providers)
        providers.insert(0, (self.HINT["provider"], {}))
        if self.debug:
            from optimade_client.utils import VERSION_PARTS

            local_provider = LinksResourceAttributes(
                **{
                    "name": "Local server",
                    "description": "Local server, running aiida-optimade",
                    "base_url": f"http://localhost:5000{VERSION_PARTS[0][0]}",
                    "homepage": "https://example.org",
                    "link_type": "external",
                })
            providers.insert(1, ("Local server", local_provider))

        self.providers = DropdownExtended(
            options=providers,
            disabled_options=invalid_providers,
            layout=ipw.Layout(width="auto"),
        )

        self.show_child_dbs = ipw.Layout(width="auto", display="none")
        self.child_dbs = DropdownExtended(grouping=self.INITIAL_CHILD_DBS,
                                          layout=self.show_child_dbs)
        self.page_chooser = ResultsPageChooser(page_limit=self.child_db_limit,
                                               layout=self.show_child_dbs)

        self.providers.observe(self._observe_providers, names="value")
        self.child_dbs.observe(self._observe_child_dbs, names="value")
        self.page_chooser.observe(
            self._get_more_child_dbs,
            names=["page_link", "page_offset", "page_number"])
        self.error_or_status_messages = ipw.HTML("")

        super().__init__(
            children=(
                self.providers,
                self.child_dbs,
                self.page_chooser,
                self.error_or_status_messages,
            ),
            layout=ipw.Layout(width="auto"),
            **kwargs,
        )

    def freeze(self):
        """Disable widget"""
        self.providers.disabled = True
        self.show_child_dbs.display = "none"
        self.page_chooser.freeze()

    def unfreeze(self):
        """Activate widget (in its current state)"""
        self.providers.disabled = False
        self.show_child_dbs.display = None
        self.page_chooser.unfreeze()

    def reset(self):
        """Reset widget"""
        self.page_chooser.reset()
        self.offset = 0
        self.number = 1

        self.providers.disabled = False
        self.providers.index = 0

        self.show_child_dbs.display = "none"
        self.child_dbs.grouping = self.INITIAL_CHILD_DBS

    def _observe_providers(self, change: dict):
        """Update child database dropdown upon changing provider"""
        value = change["new"]
        self.show_child_dbs.display = "none"
        self.provider = value
        if value is None or not value:
            self.show_child_dbs.display = "none"
            self.child_dbs.grouping = self.INITIAL_CHILD_DBS
            self.providers.index = 0
            self.child_dbs.index = 0
        else:
            self._initialize_child_dbs()
            if sum([len(_[1]) for _ in self.child_dbs.grouping]) <= 2:
                # The provider either has 0 or 1 implementations,
                # or we failed to retrieve any implementations.
                # Automatically choose the single implementation (if present);
                # otherwise keep the dropdown hidden.
                self.show_child_dbs.display = "none"
                try:
                    self.child_dbs.index = 1
                    LOGGER.debug("Changed child_dbs index. New child_dbs: %s",
                                 self.child_dbs)
                except IndexError:
                    pass
            else:
                self.show_child_dbs.display = None

    def _observe_child_dbs(self, change: dict):
        """Update database traitlet with base URL for chosen child database"""
        value = change["new"]
        if value is None or not value:
            self.database = "", None
        else:
            self.database = self.child_dbs.label.strip(), self.child_dbs.value

    @staticmethod
    def _remove_current_dropdown_option(dropdown: ipw.Dropdown) -> tuple:
        """Remove the current option from a Dropdown widget and return updated options

        Since Dropdown.options is a tuple there is a need to go through a list.
        """
        list_of_options = list(dropdown.options)
        list_of_options.pop(dropdown.index)
        return tuple(list_of_options)

    def _initialize_child_dbs(self):
        """New provider chosen; initialize child DB dropdown"""
        self.offset = 0
        self.number = 1
        try:
            # Freeze and disable the list of structures in the dropdown widget.
            # We don't want changes to trigger anything before the query finishes.
            self.freeze()

            # Reset the error or status message
            if self.error_or_status_messages.value:
                self.error_or_status_messages.value = ""

            if self.provider.base_url in self.__cached_child_dbs:
                cache = self.__cached_child_dbs[self.provider.base_url]

                LOGGER.debug(
                    "Initializing child DBs for %s. Using cached info:\n%r",
                    self.provider.name,
                    cache,
                )

                self._set_child_dbs(cache["child_dbs"])
                data_returned = cache["data_returned"]
                data_available = cache["data_available"]
                links = cache["links"]
            else:
                LOGGER.debug("Initializing child DBs for %s.",
                             self.provider.name)

                # Query database and get child_dbs
                child_dbs, links, data_returned, data_available = self._query()

                while True:
                    # Update list of structures in dropdown widget
                    exclude_child_dbs, final_child_dbs = self._update_child_dbs(
                        data=child_dbs,
                        skip_dbs=self.skip_child_dbs.get(
                            self.provider.name, []),
                    )

                    LOGGER.debug("Exclude child DBs: %r", exclude_child_dbs)
                    data_returned -= len(exclude_child_dbs)
                    data_available -= len(exclude_child_dbs)
                    if exclude_child_dbs and data_returned:
                        child_dbs, links, data_returned, _ = self._query(
                            exclude_ids=exclude_child_dbs)
                    else:
                        break
                self._set_child_dbs(final_child_dbs)

                # Cache initial child_dbs and related information
                self.__cached_child_dbs[self.provider.base_url] = {
                    "child_dbs": final_child_dbs,
                    "data_returned": data_returned,
                    "data_available": data_available,
                    "links": links,
                }

                LOGGER.debug(
                    "Found the following, which has now been cached:\n%r",
                    self.__cached_child_dbs[self.provider.base_url],
                )

            # Update paging
            self.page_chooser.set_pagination_data(
                data_returned=data_returned,
                data_available=data_available,
                links_to_page=links,
                reset_cache=True,
            )

        except QueryError as exc:
            LOGGER.debug(
                "Trying to initalize child DBs. QueryError caught: %r", exc)
            if exc.remove_target:
                LOGGER.debug(
                    "Remove target: %r. Will remove target at %r: %r",
                    exc.remove_target,
                    self.providers.index,
                    self.providers.value,
                )
                self.providers.options = self._remove_current_dropdown_option(
                    self.providers)
                self.reset()
            else:
                LOGGER.debug(
                    "Remove target: %r. Will NOT remove target at %r: %r",
                    exc.remove_target,
                    self.providers.index,
                    self.providers.value,
                )
                self.show_child_dbs.display = "none"
                self.child_dbs.grouping = self.INITIAL_CHILD_DBS

        else:
            self.unfreeze()

    def _set_child_dbs(
        self,
        data: List[Tuple[str, List[Tuple[str, LinksResourceAttributes]]]],
    ) -> None:
        """Update the child_dbs options with `data`"""
        first_choice = (self.HINT["child_dbs"]
                        if data else "No valid implementations found")
        new_data = list(data)
        new_data.insert(0, ("", [(first_choice, None)]))
        self.child_dbs.grouping = new_data

    def _update_child_dbs(
        self,
        data: List[dict],
        skip_dbs: List[str] = None
    ) -> Tuple[List[str], List[List[Union[str, List[Tuple[
            str, LinksResourceAttributes]]]]], ]:
        """Update child DB dropdown from response data"""
        child_dbs = ({
            "": []
        } if self.providers.label not in self.child_db_groupings else deepcopy(
            self.child_db_groupings[self.providers.label]))
        exclude_dbs = []
        skip_dbs = skip_dbs or []

        for entry in data:
            child_db = update_old_links_resources(entry)
            if child_db is None:
                continue

            # Skip if desired by user
            if child_db.id in skip_dbs:
                exclude_dbs.append(child_db.id)
                continue

            attributes = child_db.attributes

            # Skip if not a 'child' link_type database
            if attributes.link_type != LinkType.CHILD:
                LOGGER.debug(
                    "Skip %s: Links resource not a %r link_type, instead: %r",
                    attributes.name,
                    LinkType.CHILD,
                    attributes.link_type,
                )
                continue

            # Skip if there is no base_url
            if attributes.base_url is None:
                LOGGER.debug(
                    "Skip %s: Base URL found to be None for child DB: %r",
                    attributes.name,
                    child_db,
                )
                exclude_dbs.append(child_db.id)
                continue

            versioned_base_url = get_versioned_base_url(attributes.base_url)
            if versioned_base_url:
                attributes.base_url = versioned_base_url
            else:
                # Not a valid/supported child DB: skip
                LOGGER.debug(
                    "Skip %s: Could not determine versioned base URL for child DB: %r",
                    attributes.name,
                    child_db,
                )
                exclude_dbs.append(child_db.id)
                continue

            if self.providers.label in self.child_db_groupings:
                for group, ids in self.child_db_groupings[
                        self.providers.label].items():
                    if child_db.id in ids:
                        index = child_dbs[group].index(child_db.id)
                        child_dbs[group][index] = (attributes.name, attributes)
                        break
                else:
                    if "" in child_dbs:
                        child_dbs[""].append((attributes.name, attributes))
                    else:
                        child_dbs[""] = [(attributes.name, attributes)]
            else:
                child_dbs[""].append((attributes.name, attributes))

        if self.providers.label in self.child_db_groupings:
            for group, ids in tuple(child_dbs.items()):
                child_dbs[group] = [_ for _ in ids if isinstance(_, tuple)]
        child_dbs = list(child_dbs.items())

        LOGGER.debug("Final updated child_dbs: %s", child_dbs)

        return exclude_dbs, child_dbs

    def _get_more_child_dbs(self, change):
        """Query for more child DBs according to page_offset"""
        if self.providers.value is None:
            # This may be called if a provider is suddenly removed (bad provider)
            return

        if not self.__perform_query:
            self.__perform_query = True
            LOGGER.debug(
                "Will not perform query with pageing: name=%s value=%s",
                change["name"],
                change["new"],
            )
            return

        pageing: Union[int, str] = change["new"]
        LOGGER.debug(
            "Detected change in page_chooser: name=%s value=%s",
            change["name"],
            pageing,
        )
        if change["name"] == "page_offset":
            LOGGER.debug(
                "Got offset %d to retrieve more child DBs from %r",
                pageing,
                self.providers.value,
            )
            self.offset = pageing
            pageing = None
        elif change["name"] == "page_number":
            LOGGER.debug(
                "Got number %d to retrieve more child DBs from %r",
                pageing,
                self.providers.value,
            )
            self.number = pageing
            pageing = None
        else:
            LOGGER.debug(
                "Got link %r to retrieve more child DBs from %r",
                pageing,
                self.providers.value,
            )
            # We need to update page_offset, but we do not wish to query again
            with self.hold_trait_notifications():
                self.__perform_query = False
                self.page_chooser.update_offset()

        try:
            # Freeze and disable both dropdown widgets.
            # We don't want changes to trigger anything before the query finishes.
            self.freeze()

            # Query index meta-database
            LOGGER.debug("Querying for more child DBs using pageing: %r",
                         pageing)
            child_dbs, links, _, _ = self._query(pageing)

            data_returned = self.page_chooser.data_returned
            while True:
                # Update list of child DBs in dropdown widget
                exclude_child_dbs, final_child_dbs = self._update_child_dbs(
                    child_dbs)

                data_returned -= len(exclude_child_dbs)
                if exclude_child_dbs and data_returned:
                    child_dbs, links, data_returned, _ = self._query(
                        link=pageing, exclude_ids=exclude_child_dbs)
                else:
                    break
            self._set_child_dbs(final_child_dbs)

            # Update paging
            self.page_chooser.set_pagination_data(data_returned=data_returned,
                                                  links_to_page=links)

        except QueryError as exc:
            LOGGER.debug(
                "Trying to retrieve more child DBs (new page). QueryError caught: %r",
                exc,
            )
            if exc.remove_target:
                LOGGER.debug(
                    "Remove target: %r. Will remove target at %r: %r",
                    exc.remove_target,
                    self.providers.index,
                    self.providers.value,
                )
                self.providers.options = self._remove_current_dropdown_option(
                    self.providers)
                self.reset()
            else:
                LOGGER.debug(
                    "Remove target: %r. Will NOT remove target at %r: %r",
                    exc.remove_target,
                    self.providers.index,
                    self.providers.value,
                )
                self.show_child_dbs.display = "none"
                self.child_dbs.grouping = self.INITIAL_CHILD_DBS

        else:
            self.unfreeze()

    def _query(  # pylint: disable=too-many-locals,too-many-branches,too-many-statements
            self,
            link: str = None,
            exclude_ids: List[str] = None) -> Tuple[List[dict], dict, int,
                                                    int]:
        """Query helper function"""
        # If a complete link is provided, use it straight up
        if link is not None:
            try:
                if exclude_ids:
                    filter_value = " AND ".join(
                        [f'NOT id="{id_}"' for id_ in exclude_ids])

                    parsed_url = urllib.parse.urlparse(link)
                    queries = urllib.parse.parse_qs(parsed_url.query)
                    # Since parse_qs wraps all values in a list,
                    # this extracts the values from the list(s).
                    queries = {key: value[0] for key, value in queries.items()}

                    if "filter" in queries:
                        queries[
                            "filter"] = f"( {queries['filter']} ) AND ( {filter_value} )"
                    else:
                        queries["filter"] = filter_value

                    parsed_query = urllib.parse.urlencode(queries)

                    link = (
                        f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}"
                        f"?{parsed_query}")

                link = ordered_query_url(link)
                response = SESSION.get(link, timeout=TIMEOUT_SECONDS)
                if response.from_cache:
                    LOGGER.debug("Request to %s was taken from cache !", link)
                response = response.json()
            except (
                    requests.exceptions.ConnectTimeout,
                    requests.exceptions.ConnectionError,
                    requests.exceptions.ReadTimeout,
            ) as exc:
                response = {
                    "errors": {
                        "msg": "CLIENT: Connection error or timeout.",
                        "url": link,
                        "Exception": repr(exc),
                    }
                }
            except json.JSONDecodeError as exc:
                response = {
                    "errors": {
                        "msg": "CLIENT: Could not decode response to JSON.",
                        "url": link,
                        "Exception": repr(exc),
                    }
                }
        else:
            filter_ = '( link_type="child" OR type="child" )'
            if exclude_ids:
                filter_ += (
                    " AND ( " +
                    " AND ".join([f'NOT id="{id_}"'
                                  for id_ in exclude_ids]) + " )")

            response = perform_optimade_query(
                filter=filter_,
                base_url=self.provider.base_url,
                endpoint="/links",
                page_limit=self.child_db_limit,
                page_offset=self.offset,
                page_number=self.number,
            )
        msg, http_errors = handle_errors(response)
        if msg:
            if 404 in http_errors:
                # If /links is not found, move on
                pass
            else:
                self.error_or_status_messages.value = msg
                raise QueryError(msg=msg, remove_target=True)

        # Check implementation API version
        msg = validate_api_version(response.get("meta",
                                                {}).get("api_version", ""),
                                   raise_on_fail=False)
        if msg:
            self.error_or_status_messages.value = (
                f"{msg}<br>The provider has been removed.")
            raise QueryError(msg=msg, remove_target=True)

        LOGGER.debug(
            "Manually remove `exclude_ids` if filters are not supported")
        child_db_data = {
            impl.get("id", "N/A"): impl
            for impl in response.get("data", [])
        }
        if exclude_ids:
            for links_id in exclude_ids:
                if links_id in list(child_db_data.keys()):
                    child_db_data.pop(links_id)
            LOGGER.debug("child_db_data after popping: %r", child_db_data)
            response["data"] = list(child_db_data.values())
            if "meta" in response:
                if "data_available" in response["meta"]:
                    old_data_available = response["meta"].get(
                        "data_available", 0)
                    if len(response["data"]) > old_data_available:
                        LOGGER.debug("raising OptimadeClientError")
                        raise OptimadeClientError(
                            f"Reported data_available ({old_data_available}) is smaller than "
                            f"curated list of responses ({len(response['data'])}).",
                        )
                response["meta"]["data_available"] = len(response["data"])
            else:
                raise OptimadeClientError(
                    "'meta' not found in response. Bad response")

        LOGGER.debug(
            "Attempt for %r (in /links): Found implementations (names+base_url only):\n%s",
            self.provider.name,
            [
                f"(id: {name}; base_url: {base_url}) " for name, base_url in [(
                    impl.get("id", "N/A"),
                    impl.get("attributes", {}).get("base_url", "N/A"),
                ) for impl in response.get("data", [])]
            ],
        )
        # Return all implementations of link_type "child"
        implementations = [
            implementation for implementation in response.get("data", [])
            if (implementation.get("attributes", {}).get("link_type", "") ==
                "child" or implementation.get("type", "") == "child")
        ]
        LOGGER.debug(
            "After curating for implementations which are of 'link_type' = 'child' or 'type' == "
            "'child' (old style):\n%s",
            [
                f"(id: {name}; base_url: {base_url}) " for name, base_url in [(
                    impl.get("id", "N/A"),
                    impl.get("attributes", {}).get("base_url", "N/A"),
                ) for impl in implementations]
            ],
        )

        # Get links, data_returned, and data_available
        links = response.get("links", {})
        data_returned = response.get("meta", {}).get("data_returned",
                                                     len(implementations))
        if data_returned > 0 and not implementations:
            # Most probably dealing with pre-v1.0.0-rc.2 implementations
            data_returned = 0
        data_available = response.get("meta", {}).get("data_available", 0)

        return implementations, links, data_returned, data_available
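
A usage sketch, hedged: it assumes a Jupyter notebook where optimade-client's widgets render and a reachable provider list; the "exmpl" provider id below is only an illustration.

# Sketch only: wiring the chooser into a notebook.
from IPython.display import display

chooser = ProviderImplementationChooser(
    child_db_limit=5,
    skip_providers=["exmpl"],  # hypothetical provider id to skip
)

def on_database(change):
    name, attributes = change["new"]
    if attributes is not None:
        print(f"Selected {name}: {attributes.base_url}")

chooser.observe(on_database, names="database")
display(chooser)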
Code example #20
class ModifyCoordinates(Algorithm):
    """
    Base class for nodes that modify the requested coordinates before evaluation.
    
    Attributes
    ----------
    source : podpac.Node
        Source node that will be evaluated with the modified coordinates.
    coordinates_source : podpac.Node
        Node that supplies the available coordinates when necessary, optional. The source node is used by default.
    lat, lon, time, alt : List
        Modification parameters for given dimension. Varies by node.
    """

    source = tl.Instance(Node)
    coordinates_source = tl.Instance(Node)
    lat = tl.List().tag(attr=True)
    lon = tl.List().tag(attr=True)
    time = tl.List().tag(attr=True)
    alt = tl.List().tag(attr=True)

    _modified_coordinates = tl.Instance(Coordinates, allow_none=True)

    @tl.default('coordinates_source')
    def _default_coordinates_source(self):
        return self.source

    @common_doc(COMMON_DOC)
    def eval(self, coordinates, output=None):
        """Evaluates this nodes using the supplied coordinates.

        Parameters
        ----------
        coordinates : podpac.Coordinates
            {requested_coordinates}
        output : podpac.UnitsDataArray, optional
            {eval_output}
            
        Returns
        -------
        {eval_return}
        
        Notes
        -------
        The input coordinates are modified and then passed to the base class implementation of eval.
        """

        self._requested_coordinates = coordinates
        self.outputs = {}
        self._modified_coordinates = Coordinates([
            self.get_modified_coordinates1d(coordinates, dim)
            for dim in coordinates.dims
        ])

        for dim in self._modified_coordinates.udims:
            if self._modified_coordinates[dim].size == 0:
                raise ValueError(
                    "Modified coordinates do not intersect with source data (dim '%s')"
                    % dim)

        self.outputs['source'] = self.source.eval(self._modified_coordinates,
                                                  output=output)

        if output is None:
            output = self.outputs['source']
        else:
            output[:] = self.outputs['source']

        if settings['DEBUG']:
            self._output = output
        return output
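
Subclasses are expected to supply `get_modified_coordinates1d`. Below is an illustrative sketch only: podpac ships real subclasses (e.g. ExpandCoordinates), but the `pad` attribute and the padding logic here are hypothetical, not podpac's actual implementation; `clinspace` is assumed to be podpac's uniform-coordinates helper.

# Hypothetical subclass: pad the lat/lon bounds by a fixed margin.
from podpac.coordinates import clinspace  # assumed import path

class PadCoordinates(ModifyCoordinates):
    pad = tl.Float(1.0).tag(attr=True)  # margin in coordinate units (illustrative)

    def get_modified_coordinates1d(self, coords, dim):
        c = coords[dim]
        if dim not in ["lat", "lon"]:
            return c  # leave time/alt untouched
        lo, hi = c.bounds
        return clinspace(lo - self.pad, hi + self.pad, c.size, name=dim)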
Code example #21
class PolarCoordinates(StackedCoordinates):
    """
    Parameterized spatial coordinates defined by a center, radius coordinates, and theta coordinates.

    Attributes
    ----------
    center
    radius
    theta
    """

    center = ArrayTrait(shape=(2, ), dtype=float, read_only=True)
    radius = tl.Instance(Coordinates1d, read_only=True)
    theta = tl.Instance(Coordinates1d, read_only=True)
    dims = tl.Tuple(tl.Unicode(), tl.Unicode(), read_only=True)

    def __init__(self, center, radius, theta=None, theta_size=None, dims=None):

        # radius
        if not isinstance(radius, Coordinates1d):
            radius = ArrayCoordinates1d(radius)

        # theta
        if theta is not None and theta_size is not None:
            raise TypeError(
                "PolarCoordinates expected theta or theta_size, not both.")
        if theta is None and theta_size is None:
            raise TypeError("PolarCoordinates requires theta or theta_size.")

        if theta_size is not None:
            theta = UniformCoordinates1d(start=0,
                                         stop=2 * np.pi,
                                         size=theta_size + 1)[:-1]
        elif not isinstance(theta, Coordinates1d):
            theta = ArrayCoordinates1d(theta)

        self.set_trait("center", center)
        self.set_trait("radius", radius)
        self.set_trait("theta", theta)
        if dims is not None:
            self.set_trait("dims", dims)

    @tl.validate("dims")
    def _validate_dims(self, d):
        val = d["value"]
        for dim in val:
            if dim not in ["lat", "lon"]:
                raise ValueError(
                    "PolarCoordinates dims must be 'lat' or 'lon', not '%s'" %
                    dim)
        if val[0] == val[1]:
            raise ValueError("Duplicate dimension '%s'" % val[0])
        return val

    @tl.validate("radius")
    def _validate_radius(self, d):
        val = d["value"]
        if np.any(val.coordinates <= 0):
            raise ValueError("PolarCoordinates radius must all be positive")
        return val

    def _set_name(self, value):
        self._set_dims(value.split("_"))

    def _set_dims(self, dims):
        self.set_trait("dims", dims)

    # ------------------------------------------------------------------------------------------------------------------
    # Alternate Constructors
    # ------------------------------------------------------------------------------------------------------------------

    @classmethod
    def from_definition(cls, d):
        if "center" not in d:
            raise ValueError(
                'PolarCoordinates definition requires "center" property')
        if "radius" not in d:
            raise ValueError(
                'PolarCoordinates definition requires "radius" property')
        if "theta" not in d and "theta_size" not in d:
            raise ValueError(
                'PolarCoordinates definition requires "theta" or "theta_size" property'
            )
        if "dims" not in d:
            raise ValueError(
                'PolarCoordinates definition requires "dims" property')

        # center
        center = d["center"]

        # radius
        if isinstance(d["radius"], list):
            radius = ArrayCoordinates1d(d["radius"])
        elif "values" in d["radius"]:
            radius = ArrayCoordinates1d.from_definition(d["radius"])
        elif "start" in d["radius"] and "stop" in d["radius"] and (
                "step" in d["radius"] or "size" in d["radius"]):
            radius = UniformCoordinates1d.from_definition(d["radius"])
        else:
            raise ValueError(
                "Could not parse radius coordinates definition with keys %s" %
                d.keys())

        # theta
        if "theta" not in d:
            theta = None
        elif isinstance(d["theta"], list):
            theta = ArrayCoordinates1d(d["theta"])
        elif "values" in d["theta"]:
            theta = ArrayCoordinates1d.from_definition(d["theta"])
        elif "start" in d["theta"] and "stop" in d["theta"] and (
                "step" in d["theta"] or "size" in d["theta"]):
            theta = UniformCoordinates1d.from_definition(d["theta"])
        else:
            raise ValueError(
                "Could not parse theta coordinates definition with keys %s" %
                d.keys())

        kwargs = {
            k: v
            for k, v in d.items() if k not in ["center", "radius", "theta"]
        }
        return PolarCoordinates(center, radius, theta, **kwargs)

    # ------------------------------------------------------------------------------------------------------------------
    # standard methods
    # ------------------------------------------------------------------------------------------------------------------

    def __repr__(self):
        return "%s(%s): center%s, shape%s" % (
            self.__class__.__name__, self.dims, self.center, self.shape)

    def __eq__(self, other):
        if not isinstance(other, PolarCoordinates):
            return False

        if not np.allclose(self.center, other.center):
            return False

        if self.radius != other.radius:
            return False

        if self.theta != other.theta:
            return False

        return True

    def __getitem__(self, index):
        if isinstance(index, slice):
            index = index, slice(None)

        if isinstance(index, tuple) and isinstance(
                index[0], slice) and isinstance(index[1], slice):
            return PolarCoordinates(self.center,
                                    self.radius[index[0]],
                                    self.theta[index[1]],
                                    dims=self.dims)
        else:
            # convert to raw StackedCoordinates (which creates the _coords attribute that the indexing requires)
            return StackedCoordinates(self.coordinates,
                                      dims=self.dims).__getitem__(index)

    # ------------------------------------------------------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------------------------------------------------------

    @property
    def _coords(self):
        raise RuntimeError("PolarCoordinates do not have a _coords attribute.")

    @property
    def ndim(self):
        return 2

    @property
    def shape(self):
        return self.radius.size, self.theta.size

    @property
    def xdims(self):
        return ("r", "t")

    @property
    def coordinates(self):
        r, theta = np.meshgrid(self.radius.coordinates, self.theta.coordinates)
        lat = r * np.sin(theta) + self.center[0]
        lon = r * np.cos(theta) + self.center[1]
        return lat.T, lon.T

    @property
    def definition(self):
        d = OrderedDict()
        d["dims"] = self.dims
        d["center"] = self.center
        d["radius"] = self.radius.definition
        d["theta"] = self.theta.definition
        return d

    @property
    def full_definition(self):
        return self.definition

    # ------------------------------------------------------------------------------------------------------------------
    # Methods
    # ------------------------------------------------------------------------------------------------------------------

    def copy(self):
        return PolarCoordinates(self.center,
                                self.radius,
                                self.theta,
                                dims=self.dims)
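
A small usage sketch with arbitrary numbers, following the constructor semantics shown above:

# Sketch: a polar grid of 3 radii x 4 angles centered at (lat, lon) = (0, 0).
pc = PolarCoordinates(center=[0, 0], radius=[1, 2, 3], theta_size=4,
                      dims=("lat", "lon"))
print(pc.shape)          # (3, 4): (radius.size, theta.size)
lat, lon = pc.coordinates
# theta_size=4 samples theta at 0, pi/2, pi, 3*pi/2 (endpoint excluded), so the
# innermost ring (r=1) passes through (0, 1), (1, 0), (0, -1), (-1, 0).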
Code example #22
class Lambda(Node):
    """A `Node` wrapper to evaluate source on AWS Lambda function

    Attributes
    ----------
    AWS_ACCESS_KEY_ID : string
        access key id from AWS credentials
    AWS_SECRET_ACCESS_KEY : string
        access key value from AWS credentials
    AWS_REGION_NAME : string
        name of the AWS region
    source: Node
        node to be evaluated
    source_output: Output
        how to output the evaluated results of `source`
    attrs: dict
        additional attributes passed on to the Lambda definition of the base node
    """

    AWS_ACCESS_KEY_ID = tl.Unicode(
        allow_none=False, help="Access key ID from AWS for S3 bucket.")

    @tl.default('AWS_ACCESS_KEY_ID')
    def _AWS_ACCESS_KEY_ID_default(self):
        return settings['AWS_ACCESS_KEY_ID']

    AWS_SECRET_ACCESS_KEY = tl.Unicode(
        allow_none=False, help="Access key value from AWS for S3 bucket.")

    @tl.default('AWS_SECRET_ACCESS_KEY')
    def _AWS_SECRET_ACCESS_KEY_default(self):
        return settings['AWS_SECRET_ACCESS_KEY']

    AWS_REGION_NAME = tl.Unicode(allow_none=False,
                                 help="Region name of AWS S3 bucket.")

    @tl.default('AWS_REGION_NAME')
    def _AWS_REGION_NAME_default(self):
        return settings['AWS_REGION_NAME']

    source = tl.Instance(Node,
                         allow_none=False,
                         help="Node to evaluate in a Lambda function.")

    source_output = tl.Instance(Output,
                                allow_none=False,
                                help="Image output information.")

    attrs = tl.Dict()

    @tl.default('source_output')
    def _source_output_default(self):
        return FileOutput(node=self.source,
                          name=self.source.__class__.__name__)

    s3_bucket_name = tl.Unicode(allow_none=False,
                                help="Name of AWS s3 bucket.")

    @tl.default('s3_bucket_name')
    def _s3_bucket_name_default(self):
        return settings['S3_BUCKET_NAME']

    s3_json_folder = tl.Unicode(allow_none=False,
                                help="S3 folder to put JSON in.")

    @tl.default('s3_json_folder')
    def _s3_json_folder_default(self):
        return settings['S3_JSON_FOLDER']

    s3_output_folder = tl.Unicode(allow_none=False,
                                  help="S3 folder to put output in.")

    @tl.default('s3_output_folder')
    def _s3_output_folder_default(self):
        return settings['S3_OUTPUT_FOLDER']

    @property
    def definition(self):
        """
        The definition of this manager is the aggregation of the source node
        and source output.
        """
        d = OrderedDict()
        d['pipeline'] = self.source.definition
        if self.attrs:
            out_node = next(reversed(d['pipeline']['nodes'].keys()))
            d['pipeline']['nodes'][out_node]['attrs'].update(self.attrs)
        d['pipeline']['output'] = self.source_output.definition
        return d

    @common_doc(COMMON_DOC)
    def eval(self, coordinates, output=None):
        """
        Evaluate the source node on the AWS Lambda Function at the given coordinates
        """
        d = self.definition
        d['coordinates'] = json.loads(coordinates.json)
        filename = '%s%s_%s_%s.%s' % (
            self.s3_json_folder, self.source_output.name, self.source.hash,
            coordinates.hash, 'json')
        s3 = boto3.client('s3')
        s3.put_object(Body=(bytes(
            json.dumps(d, indent=4, cls=JSONEncoder).encode('UTF-8'))),
                      Bucket=self.s3_bucket_name,
                      Key=filename)

        waiter = s3.get_waiter('object_exists')
        filename = '%s%s_%s_%s.%s' % (
            self.s3_output_folder, self.source_output.name, self.source.hash,
            coordinates.hash, self.source_output.format)
        waiter.wait(Bucket=self.s3_bucket_name, Key=filename)
        # After waiting, load the pickle file like this:
        resource = boto3.resource('s3')
        with BytesIO() as data:
            # Get the bucket and file name programmatically - see above...
            resource.Bucket(self.s3_bucket_name).download_fileobj(
                filename, data)
            data.seek(0)  # move back to the beginning after writing
            self._output = cPickle.load(data)
        return self._output
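
Roughly how this node is driven, as a sketch under assumptions: valid AWS credentials and bucket/folder keys in podpac's `settings`, an AWS Lambda subscribed to the JSON folder of the bucket, and `coords` being any podpac Coordinates; `some_podpac_node` is a placeholder.

# Sketch only: the S3-based request/response round trip.
node = Lambda(source=some_podpac_node)  # any podpac Node
result = node.eval(coords)              # uploads the pipeline JSON to S3, then
                                        # blocks until the Lambda writes output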
Code example #23
class PixelInspector(ipywidgets.GridBox):
    """
    Display pixel values when clicking on the map.

    Whenever you click on the map, it fetches the pixel values at that
    location for all active Workflows layers and displays them in a table
    overlaid on the map. It also shows a marker on the map indicating
    the last position clicked. As layers change, or are added or removed,
    the table keeps fetching pixel values for the new layers at the last-clicked
    point (the marker's current position).

    For performance, the inspector does not use full-resolution data, but rather
    whatever resolution (zoom level) the map is currently displaying.
    Therefore, it's possible that values for the same point would come back slightly
    different at different zoom levels. (Note that the resampling method used is
    whatever the input `~.geospatial.Image` or `~.geospatial.ImageCollection`
    was constructed with.)

    To unlink from the map, call `~.unlink`.

    Example
    -------
    >>> import descarteslabs.workflows as wf
    >>> my_map = wf.interactive.Map()

    >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1")
    >>> img.pick_bands("red").visualize("Red", colormap="Reds", map=my_map)  # doctest: +SKIP
    >>> img.pick_bands("green").visualize("Green", colormap="Greens", map=my_map)  # doctest: +SKIP

    >>> inspector = wf.interactive.PixelInspector(my_map)
    >>> my_map  # doctest: +SKIP
    >>> # ^ display the map
    >>> # click on the map; a table will pop up showing pixel values for the Red and Green layers
    >>> inspector.unlink()
    >>> # table and marker disappear; click again and nothing happens
    """

    # NOTE(gabe): we use the marker's opacity as a crude global on-off switch;
    # all event listeners check that opacity == 1 before doing actual work.
    marker = traitlets.Instance(
        CircleMarkerWithXYZGeoContext,
        kw=dict(opacity=0, radius=5, weight=1, name="Inspected pixel marker"),
        read_only=True,
    )
    n_bands = traitlets.Int(3, read_only=True, allow_none=False)

    def __init__(self, map, position="topright", layout=None):
        """
        Construct a PixelInspector and attach it to a map.

        Parameters
        ----------
        map: ipyleaflet.Map, workflows.interactive.MapApp
            The map to attach to
        position: str, optional, default "topright"
            Where on the map to display the values table
        layout: ipywidgets.Layout, optional
            Layout for the values table. Defaults to
            ``Layout(max_height="350px", overflow="scroll", padding="4px")``
        """
        if layout is None:
            layout = ipywidgets.Layout(max_height="350px",
                                       overflow="scroll",
                                       padding="4px")
        super().__init__([], layout=layout)
        self.layout.grid_template_columns = "min-content " * (1 + self.n_bands)

        # awkwardly handle MapApp without circularly importing it for an isinstance check
        try:
            sub_map = map.map
        except AttributeError:
            pass
        else:
            if isinstance(sub_map, ipyleaflet.Map):
                map = sub_map

        self._map = map
        self.marker.map = map
        self._inspector_rows_by_layer_id = weakref.WeakValueDictionary()

        self._layers_changed({"old": [], "new": map.layers})
        # initialize with the current layers on the map, if any

        self._control = ipyleaflet.WidgetControl(widget=self,
                                                 position=position)
        map.add_control(self._control)
        map.observe(self._layers_changed, names=["layers"], type="change")
        map.on_interaction(self._handle_click)

        self._orig_cursor = map.default_style.cursor
        map.default_style.cursor = "crosshair"

    def unlink(self):
        "Stop listening for click events or layer updates and remove the table from the map"
        self._map.on_interaction(self._handle_click, remove=True)
        self._map.unobserve(self._layers_changed, "layers", type="change")
        for inspector_row in tuple(self._inspector_rows_by_layer_id.values()):
            # ^ take a tuple first, since unlinking should remove all references to `inspector_row`,
            # which would then pop it from the dict, causing mutation of the dict while we iterate over it
            inspector_row.unlink()
        self._map.default_style.cursor = self._orig_cursor
        try:
            self._map.remove_control(self._control)
        except ipyleaflet.ControlException:
            pass
        try:
            self._map.remove_layer(self.marker)
        except ipyleaflet.LayerException:
            pass
        self.marker.opacity = 0  # be extra sure no more inspects will run
        self.children = []

    def _layers_changed(self, change):
        new_layers = change["new"]

        inspector_rows = []
        for layer in reversed(new_layers):
            if isinstance(layer, WorkflowsLayer):
                try:
                    inspector_row = self._inspector_rows_by_layer_id[
                        layer.model_id]
                except KeyError:
                    inspector_row = InspectorRowGenerator(
                        layer, self.marker, self.n_bands)
                    self._inspector_rows_by_layer_id[
                        layer.model_id] = inspector_row

                inspector_rows.append(inspector_row)

        new_children = []
        for inspector_row in inspector_rows:
            new_children.append(inspector_row.name_label)
            new_children.extend(inspector_row.values_labels)

        self.children = new_children

    def _handle_click(self, **kwargs):
        with self._map.output_log:
            if kwargs.get("type") != "click":
                return
            try:
                lat, lon = kwargs["coordinates"]
            except KeyError:
                return

            self.marker.opacity = 1
            self.marker.location = (lat, lon)

            # in case it accidentally got deleted with `clear_layers`
            if self.marker not in self._map.layers:
                self._map.add_layer(self.marker)
Code example #24
class ImageSelector(ipywidgets.HBox):
    """UI element to select an image sequence and frame number

    Given a list of file names, images, or image sequences, this allows the
    user to chose one of entry and also to select a frame number. The selected
    image is available via the :py:attr:`output` traitlet.
    """
    images = traitlets.Union([traitlets.Dict(), traitlets.List()])
    """Images or sequences to select from. Image sequences can be passed as
    3D :py:class:`numpy.ndarray`, as lists of 2D arrays, or as paths to image
    files, which will be opened using :py:mod:`pims`. Single images are
    represented as 2D arrays.

    This attribute can be a list of ``(key, img)`` tuples where ``key`` is the
    name to display and ``img`` an image (sequence), a dict mapping ``key`` to
    ``img`` (which will be converted to a list of tuples), or a plain list of
    images (sequences).
    """
    output = traitlets.Instance(np.ndarray, allow_none=True)
    """2D array representing the currently selected frame."""
    index = traitlets.Int(allow_none=True)
    """Currently selected index (w.r.t. :py:attr:`images`)"""
    def __init__(self, images: Union[Sequence, Dict] = [], **kwargs):
        """Parameters
        ----------
        images
            List of image (sequences) to populate :py:attr:`images`.
        **kwargs
            Passed to parent ``__init__``.
        """
        self._file_sel = ipywidgets.Dropdown(description="image")
        self._file_sel.observe(self._file_changed, "value")
        self._frame_sel = ipywidgets.BoundedIntText(description="frame",
                                                    min=0,
                                                    max=0)
        self._frame_sel.observe(self._frame_changed, "value")

        self._prev_button = ipywidgets.Button(icon="arrow-left")
        self._prev_button.on_click(self._prev_button_clicked)
        self._next_button = ipywidgets.Button(icon="arrow-right")
        self._next_button.on_click(self._next_button_clicked)
        for b in self._prev_button, self._next_button:
            b.layout.width = "auto"
            b.disabled = True
        self.show_file_buttons = False

        traitlets.link((self._file_sel, "index"), (self, "index"))

        super().__init__([
            self._prev_button, self._file_sel, self._next_button,
            self._frame_sel
        ], **kwargs)

        self._cur_image = None
        self._cur_image_opened = False
        self._frame_changed_lock = threading.Lock()

        self.images = images

    @traitlets.validate("images")
    def _make_images_list(self, proposal):
        """Validator for the :py:attr:`images` traitlet

        Turns dictionaries into lists of tuples.
        """
        images = proposal["value"]
        if len(images) == 0:
            return []
        if isinstance(images, dict):
            return list(images.items())
        return images

    @traitlets.observe("images")
    def _set_file_options(self, change=None):
        """Set the options for the sequence selection dropdown element"""
        if len(self.images) == 0:
            self._file_sel.options = []
            return

        n_figures = int(math.log10(len(self.images)))
        generic_key_pattern = "<{{:0{}}}>".format(n_figures)

        opts = []
        for n, img in enumerate(self.images):
            if isinstance(img, tuple):
                opts.append(img[0])
                continue
            if isinstance(img, str):
                img = Path(img)
            if isinstance(img, Path):
                opts.append("{} ({})".format(img.name, str(img.parent)))
                continue
            opts.append(generic_key_pattern.format(n))
        self._file_sel.options = opts

    def _file_changed(self, change=None):
        """Call-back upon change of the currently selected sequence"""
        if self._cur_image_opened:
            self._cur_image.close()
            self._cur_image_opened = False

        if self._file_sel.value is None:
            # No file selected
            with self._frame_changed_lock:
                self._frame_sel.max = 0
            self._cur_image = None
            self.output = None
            self._prev_button.disabled = True
            self._next_button.disabled = True
            return

        self._prev_button.disabled = self._file_sel.index <= 0
        self._next_button.disabled = (self._file_sel.index >=
                                      len(self.images) - 1)

        img = self.images[self._file_sel.index]
        if isinstance(img, tuple):
            # TODO: What if there is a tuple of images instead of (key, value)?
            img = img[1]

        if isinstance(img, np.ndarray) and img.ndim == 2:
            # Single image
            img = img[None, ...]
        elif isinstance(img, (str, Path)):
            # Open…
            img = pims.open(str(img))
            self._cur_image_opened = True

        self._cur_image = img

        with self._frame_changed_lock:
            # Disable potential update at this point. Will be explicitly
            # updated below.
            self._frame_sel.max = len(img) - 1

        self._frame_changed()

    def _frame_changed(self, change=None):
        """Call-back upon change of the currently selected frame number"""
        if self._frame_changed_lock.locked():
            return
        self.output = self._cur_image[self._frame_sel.value]

    def _prev_button_clicked(self, button=None):
        self._file_sel.index -= 1

    def _next_button_clicked(self, button=None):
        self._file_sel.index += 1

    @property
    def show_file_buttons(self) -> bool:
        """Whether to show "back" and "forward" buttons for file selection"""
        return self._prev_button.layout.display is None

    @show_file_buttons.setter
    def show_file_buttons(self, s):
        for b in self._prev_button, self._next_button:
            b.layout.display = None if s else "none"
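
A usage sketch (assuming a Jupyter notebook; the arrays are synthetic stand-ins for real image data):

# Sketch: two synthetic entries, a 5-frame sequence and a single image.
import numpy as np

seq = np.random.random((5, 64, 64))   # 3D array = image sequence
single = np.zeros((64, 64))           # 2D array = single image
sel = ImageSelector({"noise": seq, "blank": single})
sel.observe(lambda change: print(change["new"].shape), "output")
sel  # display in the notebook; `output` updates as the image/frame changes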
Code example #25
class WCSBase(DataSource):
    """
    Access data from a WCS source.

    Attributes
    ----------
    source : str
        WCS server url
    layer : str
        layer name (required)
    version : str
        WCS version, passed through to all requests (default '1.0.0')
    format : str
        Data format, passed through to the GetCoverage requests (default 'geotiff')
    crs : str
        coordinate reference system, passed through to the GetCoverage requests (default 'EPSG:4326')
    interpolation : str
        Interpolation, passed through to the GetCoverage requests.
    max_size : int
        maximum request size, optional.
        If provided, the coordinates will be tiled into multiple requests.
    """

    source = tl.Unicode().tag(attr=True)
    layer = tl.Unicode().tag(attr=True)
    version = tl.Unicode(default_value="1.0.0").tag(attr=True)
    interpolation = tl.Unicode(default_value=None,
                               allow_none=True).tag(attr=True)

    format = tl.CaselessStrEnum(["geotiff", "geotiff_byte"],
                                default_value="geotiff")
    crs = tl.Unicode(default_value="EPSG:4326")
    max_size = tl.Long(default_value=None, allow_none=True)

    _repr_keys = ["source", "layer"]

    _requested_coordinates = tl.Instance(Coordinates, allow_none=True)
    _evaluated_coordinates = tl.Instance(Coordinates)

    @cached_property
    def client(self):
        return owslib_wcs.WebCoverageService(self.source, version=self.version)

    def get_coordinates(self):
        """
        Get the full WCS grid.
        """

        metadata = self.client.contents[self.layer]

        # TODO select correct boundingbox by crs

        # coordinates
        w, s, e, n = metadata.boundingBoxWGS84
        low = metadata.grid.lowlimits
        high = metadata.grid.highlimits
        xsize = int(high[0]) - int(low[0])
        ysize = int(high[1]) - int(low[1])

        coords = []
        coords.append(UniformCoordinates1d(s, n, size=ysize, name="lat"))
        coords.append(UniformCoordinates1d(w, e, size=xsize, name="lon"))

        if metadata.timepositions:
            coords.append(
                ArrayCoordinates1d(metadata.timepositions, name="time"))

        if metadata.timelimits:
            raise NotImplementedError("TODO")

        return Coordinates(coords, crs=self.crs)

    def _eval(self, coordinates, output=None, _selector=None):
        """Evaluates this node using the supplied coordinates.

        This method intercepts the DataSource._eval method in order to use the requested coordinates directly when
        they are a uniform grid.

        Parameters
        ----------
        coordinates : :class:`podpac.Coordinates`
            {requested_coordinates}

            An exception is raised if the requested coordinates are missing dimensions in the DataSource.
            Extra dimensions in the requested coordinates are dropped.
        output : :class:`podpac.UnitsDataArray`, optional
            {eval_output}
        _selector: callable(coordinates, request_coordinates)
            {eval_selector}

        Returns
        -------
        {eval_return}

        Raises
        ------
        ValueError
            Cannot evaluate these coordinates
        """

        # DataSource._eval also transforms the crs, but we do it here as well so that we select the correct case below
        if self.coordinates.crs.lower() != coordinates.crs.lower():
            coordinates = coordinates.transform(self.coordinates.crs)

        # for a uniform grid, use the requested coordinates (the WCS server will interpolate)
        if (("lat" in coordinates.dims and "lon" in coordinates.dims) and
            (coordinates["lat"].is_uniform or coordinates["lat"].size == 1) and
            (coordinates["lon"].is_uniform or coordinates["lon"].size == 1)):

            def selector(rsc, coordinates, index_type=None):
                return coordinates, None

            return super()._eval(coordinates,
                                 output=output,
                                 _selector=selector)

        # for uniform stacked, unstack to use the requested coordinates (the WCS server will interpolate)
        if (("lat" in coordinates.udims and coordinates.is_stacked("lat")) and
            ("lon" in coordinates.udims and coordinates.is_stacked("lon")) and
            (coordinates["lat"].is_uniform or coordinates["lat"].size == 1) and
            (coordinates["lon"].is_uniform or coordinates["lon"].size == 1)):

            def selector(rsc, coordinates, index_type=None):
                unstacked = coordinates.unstack()
                unstacked = unstacked.drop(
                    "alt", ignore_missing=True)  # if lat_lon_alt
                return unstacked, None

            udata = super()._eval(coordinates, output=None, _selector=selector)
            data = udata.data.diagonal()  # get just the stacked data
            if output is None:
                output = self.create_output_array(coordinates, data=data)
            else:
                output.data[:] = data
            return output

        # otherwise, pass-through (podpac will select and interpolate)
        return super()._eval(coordinates, output=output, _selector=_selector)

    def _get_data(self, coordinates, coordinates_index):
        """{get_data}"""

        # transpose the coordinates to match the response data
        if "time" in coordinates:
            coordinates = coordinates.transpose("time", "lat", "lon")
        else:
            coordinates = coordinates.transpose("lat", "lon")

        # determine the chunk size (if applicable)
        if self.max_size is not None:
            shape = []
            s = 1
            for n in coordinates.shape:
                r = self.max_size // s
                if r == 0:
                    shape.append(1)
                elif r < n:
                    shape.append(r)
                else:
                    shape.append(n)
                s *= n
            shape = tuple(shape)
        else:
            shape = coordinates.shape

        # request each chunk and composite the data
        output = self.create_output_array(coordinates)
        for chunk, slc in coordinates.iterchunks(shape, return_slices=True):
            output[slc] = self._get_chunk(chunk)

        return output

    def _get_chunk(self, coordinates):
        if coordinates["lon"].size == 1:
            w = coordinates["lon"].coordinates[0]
            e = coordinates["lon"].coordinates[0]
        else:
            w = coordinates["lon"].start - coordinates["lon"].step / 2.0
            e = coordinates["lon"].stop + coordinates["lon"].step / 2.0

        if coordinates["lat"].size == 1:
            s = coordinates["lat"].coordinates[0]
            n = coordinates["lat"].coordinates[0]
        else:
            s = coordinates["lat"].start - coordinates["lat"].step / 2.0
            n = coordinates["lat"].stop + coordinates["lat"].step / 2.0

        width = coordinates["lon"].size
        height = coordinates["lat"].size

        kwargs = {}

        if "time" in coordinates:
            kwargs["time"] = coordinates["time"].coordinates.astype(
                str).tolist()

        if isinstance(self.interpolation, str):
            kwargs["interpolation"] = self.interpolation

        logger.info(
            "WCS GetCoverage (source=%s, layer=%s, bbox=%s, shape=%s)" %
            (self.source, self.layer, (w, n, e, s), (width, height)))

        response = self.client.getCoverage(identifier=self.layer,
                                           bbox=(w, n, e, s),
                                           width=width,
                                           height=height,
                                           crs=self.crs,
                                           format=self.format,
                                           version=self.version,
                                           **kwargs)
        content = response.read()

        # check for errors
        xml = bs4.BeautifulSoup(content, "lxml")
        error = xml.find("serviceexception")
        if error:
            raise WCSError(error.text)

        if "time" in coordinates and coordinates["time"].size > 1:
            # this should be easy to do, I'm just not sure how the data comes back.
            # is each time in a different band?
            raise NotImplementedError("TODO")

        # get data using rasterio; read inside the context managers, since the
        # in-memory file (and the dataset opened from it) is closed on exit
        with rasterio.MemoryFile() as mf:
            mf.write(content)
            with mf.open(driver="GTiff") as dataset:
                data = dataset.read(1).astype(float)

        return data

    @classmethod
    def get_layers(cls, source=None):
        if source is None:
            source = cls.source
        client = owslib_wcs.WebCoverageService(source)
        return list(client.contents)
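A hedged usage sketch for the node above, assuming a standard podpac environment; the server URL is a placeholder, not a real endpoint:

import podpac

# list the layers advertised by a (hypothetical) WCS server
layers = WCSBase.get_layers(source="https://example.com/wcs")

# build the node; for uniformly gridded requests, _eval above lets the
# WCS server do the interpolation rather than podpac
node = WCSBase(source="https://example.com/wcs", layer=layers[0], max_size=1000000)
coords = podpac.Coordinates(
    [podpac.clinspace(40, 39, 256, "lat"), podpac.clinspace(-105, -104, 256, "lon")]
)
output = node.eval(coords)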
Code Example #26
File: datashader_vis.py Project: zwelz3/ipyradiant
class DatashaderVisualizer(NXBase):
    """
    A class for visualizing an RDF graph with datashader.

    :param graph: an rdflib.graph.Graph object to visualize.
    :param tooltip: takes either 'nodes' or 'edges', and sets the hover tool.
    :param sparql: a SPARQL query to run against the rdflib.graph.Graph object.
    """

    output = T.Instance(W.Output)
    tooltip = T.Unicode(default_value="nodes")
    tooltip_dict = T.Dict()
    node_tooltips = T.List()
    edge_tooltips = T.List()
    sparql = T.Unicode()

    @T.default("output")
    def _make_default_output(self):
        return W.Output()

    @T.default("edge_tooltips")
    def _make_edge_tooltip(self):
        return [
            ("Source", "@start"),
            ("Target", "@end"),
        ]

    @T.default("node_tooltips")
    def _make_node_tooltip(self):
        return [
            ("ID", "@index"),
        ]

    @T.default("tooltip")
    def _make_tooltip(self):
        return "nodes"

    @T.default("tooltip_dict")
    def _make_tooltip_dict(self):
        return {
            "nodes": HoverTool(tooltips=self.node_tooltips),
            "edges": HoverTool(tooltips=self.edge_tooltips),
        }

    @T.default("sparql")
    def _make_sparql(self):
        return """
            CONSTRUCT {
                ?s ?p ?o .
            }
            WHERE {
                ?s ?p ?o .
                FILTER (!isLiteral(?o))
                FILTER (!isLiteral(?s))
            }
            LIMIT 300
        """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.children = [self.output]

    def display_datashader_vis(self, p):
        self.output.clear_output()
        with self.output:
            IPython.display.display(p)

    def strip_and_produce_rdf_graph(self, rdf_graph: Graph):
        """
        A function that takes in an rdflib.graph.Graph object
        and transforms it into a datashader holoviews graph.
        Also performs the sparql query on the graph that can be set
        via the 'sparql' parameter
        """

        sparql = self.sparql
        qres = rdf_graph.query(sparql)
        uri_graph = Graph()
        for row in qres:
            uri_graph.add(row)

        new_netx = rdflib_to_networkx_graph(uri_graph)
        original = hv.Graph.from_networkx(new_netx, self._nx_layout,
                                          **self.graph_layout_params)
        output_graph = bundle_graph(original)
        return output_graph

    def set_options(self, output_graph):
        return output_graph.options(
            frame_width=1000,
            frame_height=1000,
            xaxis=None,
            yaxis=None,
            tools=[self.tooltip_dict[self.tooltip], "tap", "box_select"],
            inspection_policy=self.tooltip,
            node_color=self.node_color,
            edge_color=self.edge_color,
        )

    def tap_stream_subscriber(self, x, y):
        nodes_data = self.output_graph.nodes.data
        t = 0.01
        values = nodes_data[nodes_data.x.between(x - t, x + t,
                                                 True)][nodes_data.y.between(
                                                     y - t, y + t, True)]
        self.selected_nodes = tuple([URIRef(_) for _ in list(values["index"])])

    def box_stream_subscriber(self, **kwargs):
        bounds = kwargs["bounds"]
        nodes_data = self.output_graph.nodes.data
        values = nodes_data[nodes_data.x.between(bounds[0], bounds[2],
                                                 True)][nodes_data.y.between(
                                                     bounds[1], bounds[3],
                                                     True)]
        self.selected_nodes = tuple([URIRef(_) for _ in list(values["index"])])

    @T.observe("_nx_layout", "sparql", "graph", "graph_layout_params")
    def changed_layout(self, change):
        if self.graph is None:
            self.output_graph = None
            self.display_datashader_vis(self.output_graph)
        elif len(self.graph) == 0:
            self.output_graph = None
            self.display_datashader_vis("Cannot display blank graph.")
        else:
            self.output_graph = self.strip_and_produce_rdf_graph(self.graph)
            self.tap_selection_stream = streams.Tap(source=self.output_graph)
            self.tap_selection_stream.add_subscriber(
                self.tap_stream_subscriber)
            self.box_selection_stream = streams.BoundsXY(
                source=self.output_graph)
            self.box_selection_stream.add_subscriber(
                self.box_stream_subscriber)
            self.final_graph = self.set_options(self.output_graph)
            self.display_datashader_vis(self.final_graph)
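A hedged usage sketch, assuming the ipyradiant dependencies (rdflib, holoviews, datashader, bokeh) are installed and that the graph trait observed above is defined on NXBase; the Turtle file is a placeholder:

from rdflib import Graph

rdf_graph = Graph().parse("example.ttl", format="turtle")  # placeholder file

viz = DatashaderVisualizer(graph=rdf_graph, tooltip="nodes")
viz  # display in a notebook cell; reassigning viz.sparql re-renders via changed_layout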
Code Example #27
class Foo(jst.JSONHasTraits):
    _additional_traits = False
    x = T.Integer()
    y = T.Instance(Bar)
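For illustration, a sketch of how Foo might be exercised. Neither Bar nor the jst import appears in the snippet, so both are assumptions here; jst is plausibly altair's jstraitlets module, where JSONHasTraits and _additional_traits are defined:

import traitlets as T
from altair.schema import jstraitlets as jst  # assumed source of 'jst'

class Bar(jst.JSONHasTraits):  # assumed shape of Bar
    name = T.Unicode()

# _additional_traits = False declares that properties beyond x and y are not allowed
foo = Foo(x=1, y=Bar(name="bar"))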
Code Example #28
class ZarrOutputMixin(tl.HasTraits):
    """
    This class assumes that the node has an 'output_format' attribute
    (currently the "Lambda" Node and the "Process" Node).

    Attributes
    -----------
    zarr_file: str
        Path to the output zarr file that collects all of the computed results. This can reside on S3.
    dataset: ZarrGroup
        A handle to the zarr group pointing to the output file
    fill_output: bool, optional
        Default is False (unlike parent class). If True, will collect the output data and return it as an xarray.
    init_file_mode: str, optional
        Default is 'a'. Mode used for initializing the zarr file.
    zarr_chunks: dict
        Size of the chunks in the zarr file for each dimension
    zarr_shape: dict, optional
        Default is {coordinates.dims: coordinates.shape}, where coordinates is the Coordinates object used in the
        eval call. This does not need to be specified unless the Node modifies the input coordinates (as part of a
        Reduce operation, for example). The result can be incorrect and requires care/checking by the user.
    zarr_coordinates: podpac.Coordinates, optional
        Default is None. If the node modifies the shape of the input coordinates, this allows users to set the
        coordinates in the output zarr file. This can be incorrect and requires care by the user.
    skip_existing: bool
        Default is True. If True, this will check whether the results already exist and, if so, will not submit a
        job for that particular coordinate evaluation. This assumes self.chunks == self.zarr_chunks
    list_dir: bool, optional
        Default is False. If skip_existing is True, existing files are checked by default with an 'exists' call.
        If list_dir is True, then at the first opportunity a "list_dir" is performed on the directory and the results
        are cached.
    """

    zarr_file = tl.Unicode().tag(attr=True)
    dataset = tl.Any()
    zarr_node = NodeTrait()
    zarr_data_key = tl.Union([tl.Unicode(), tl.List()])
    fill_output = tl.Bool(False)
    init_file_mode = tl.Unicode("a").tag(attr=True)
    zarr_chunks = tl.Dict(default_value=None, allow_none=True).tag(attr=True)
    zarr_shape = tl.Dict(allow_none=True, default_value=None).tag(attr=True)
    zarr_coordinates = tl.Instance(Coordinates,
                                   allow_none=True,
                                   default_value=None).tag(attr=True)
    zarr_dtype = tl.Unicode("f4")
    skip_existing = tl.Bool(True).tag(attr=True)
    list_dir = tl.Bool(False)
    _list_dir = tl.List(allow_none=True, default_value=[])
    _shape = tl.Tuple()
    _chunks = tl.List()
    aws_client_kwargs = tl.Dict()
    aws_config_kwargs = tl.Dict()

    def eval(self, coordinates, **kwargs):
        output = kwargs.get("output")
        if self.zarr_shape is None:
            self._shape = coordinates.shape
        else:
            self._shape = tuple(self.zarr_shape.values())

        # initialize zarr file
        if self.zarr_chunks is None:
            chunks = [self.chunks[d] for d in coordinates]
        else:
            chunks = [self.zarr_chunks[d] for d in coordinates]
        self._chunks = chunks
        zf, data_key, zn = self.initialize_zarr_array(self._shape, chunks)
        self.dataset = zf
        self.zarr_data_key = data_key
        self.zarr_node = zn
        zn.keys  # access the keys property for its initialization side effect

        # eval
        _log.debug("Starting parallel eval.")
        missing_dims = [
            d for d in coordinates.dims if d not in self.chunks.keys()
        ]
        if self.zarr_coordinates is not None:
            missing_dims = missing_dims + [
                d for d in self.zarr_coordinates.dims if d not in missing_dims
            ]
            set_coords = merge_dims(
                [coordinates.drop(missing_dims), self.zarr_coordinates])
        else:
            set_coords = coordinates.drop(missing_dims)
        set_coords = set_coords.transpose(*coordinates.dims)  # transpose returns a new object rather than modifying in place

        self.set_zarr_coordinates(set_coords, data_key)
        if self.list_dir:
            dk = data_key
            if isinstance(dk, list):
                dk = dk[0]
            self._list_dir = self.zarr_node.list_dir(dk)

        output = super(ZarrOutputMixin, self).eval(coordinates, output=output)

        # fill in the coordinates; this is guaranteed to be correct even if the user made a mistake above
        if output is not None:
            self.set_zarr_coordinates(Coordinates.from_xarray(output.coords),
                                      data_key)
        else:
            return zf

        return output

    def set_zarr_coordinates(self, coordinates, data_key):
        # Fill in metadata
        for dk in data_key:
            self.dataset[dk].attrs["_ARRAY_DIMENSIONS"] = coordinates.dims
        for d in coordinates.dims:
            # TODO ADD UNITS AND TIME DECODING INFORMATION
            self.dataset.create_dataset(d,
                                        shape=coordinates[d].size,
                                        overwrite=True)
            self.dataset[d][:] = coordinates[d].coordinates

    def initialize_zarr_array(self, shape, chunks):
        _log.debug("Creating Zarr file.")
        zn = Zarr(source=self.zarr_file,
                  file_mode=self.init_file_mode,
                  aws_client_kwargs=self.aws_client_kwargs)
        if self.source.output or getattr(self.source, "data_key", None):
            data_key = self.source.output
            if data_key is None:
                data_key = self.source.data_key
            if not isinstance(data_key, list):
                data_key = [data_key]
            elif self.source.outputs:  # If someone restricted the outputs for this node, we need to know
                data_key = [dk for dk in data_key if dk in self.source.outputs]
        elif self.source.outputs:
            data_key = self.source.outputs
        else:
            data_key = ["data"]

        zf = zarr.open(zn._get_store(), mode=self.init_file_mode)

        # Initialize the output zarr arrays
        for dk in data_key:
            try:
                zf.create_dataset(
                    dk,
                    shape=shape,
                    chunks=chunks,
                    fill_value=np.nan,
                    dtype=self.zarr_dtype,
                    overwrite=not self.skip_existing,
                )
            except ValueError:
                pass  # Dataset already exists

        # Recompute any cached properties
        zn = Zarr(source=self.zarr_file,
                  file_mode=self.init_file_mode,
                  aws_client_kwargs=self.aws_client_kwargs)
        return zf, data_key, zn

    def eval_source(self, coordinates, coordinates_index, out, i, source=None):
        if source is None:
            source = self.source

        if self.skip_existing:  # This section allows previously computed chunks to be skipped
            dk = self.zarr_data_key
            if isinstance(dk, list):
                dk = dk[0]
            try:
                exists = self.zarr_node.chunk_exists(coordinates_index,
                                                     data_key=dk,
                                                     list_dir=self._list_dir,
                                                     chunks=self._chunks)
            except ValueError:  # This was needed in cases where a poor internet connection caused read errors
                exists = False
            if exists:
                _log.info("Skipping {} (already exists)".format(i))
                return out, coordinates_index

        # Make a copy to prevent any possibility of memory corruption
        source = Node.from_definition(source.definition)
        _log.debug("Creating output format.")
        output = dict(
            format="zarr_part",
            format_kwargs=dict(
                part=[[s.start, min(s.stop, self._shape[i]), s.step]
                      for i, s in enumerate(coordinates_index)],
                source=self.zarr_file,
                mode="a",
            ),
        )
        _log.debug("Finished creating output format.")

        if source.has_trait("output_format"):
            source.set_trait("output_format", output)
        _log.debug("output: {}, coordinates.shape: {}".format(
            output, coordinates.shape))
        _log.debug("Evaluating node.")

        o, slc = super(ZarrOutputMixin,
                       self).eval_source(coordinates, coordinates_index, out,
                                         i, source)

        if not source.has_trait("output_format"):
            o.to_format(output["format"], **output["format_kwargs"])
        return o, slc
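A hedged composition sketch for the mixin above. 'Parallel' stands in for whatever base node supplies the .chunks, .source, and eval_source members used by the mixin (podpac's parallel managers); the import path, chunk sizes, and S3 path are all assumptions:

from podpac.core.managers.parallel import Parallel  # assumed import path

class ParallelToZarr(ZarrOutputMixin, Parallel):
    """Parallel evaluation whose per-chunk results are written to one zarr file."""

node = ParallelToZarr(
    source=my_node,                       # any podpac Node to evaluate (assumed defined)
    chunks={"lat": 128, "lon": 128},      # evaluation chunk sizes (trait assumed on Parallel)
    zarr_file="s3://bucket/output.zarr",  # placeholder output path
    skip_existing=True,                   # re-runs skip chunks that already exist
)
output = node.eval(my_coordinates)        # my_coordinates: a podpac.Coordinates (assumed defined)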
Code Example #29
class StackedCoordinates(BaseCoordinates):
    """
    Stacked coordinates.

    StackedCoordinates contain coordinates from two or more different dimensions that are stacked together to form a
    list of points (rather than a grid). The underlying coordinates values are :class:`Coordinates1d` objects of equal
    size. The name for the stacked coordinates combines the underlying dimensions with underscores, e.g. ``'lat_lon'``
    or ``'lat_lon_time'``.

    When creating :class:`Coordinates`, podpac automatically detects StackedCoordinates. The following Coordinates
    contain 3 stacked lat-lon coordinates and 2 time coordinates in a 3 x 2 grid::

        >>> lat = [0, 1, 2]
        >>> lon = [10, 20, 30]
        >>> time = ['2018-01-01', '2018-01-02']
        >>> podpac.Coordinates([[lat, lon], time], dims=['lat_lon', 'time'])
        Coordinates
            lat_lon[lat]: ArrayCoordinates1d(lat): Bounds[0.0, 2.0], N[3]
            lat_lon[lon]: ArrayCoordinates1d(lon): Bounds[10.0, 30.0], N[3]
            time: ArrayCoordinates1d(time): Bounds[2018-01-01, 2018-01-02], N[2]

    For convenience, you can also create uniformly-spaced stacked coordinates using :class:`clinspace`::

        >>> lat_lon = podpac.clinspace((0, 10), (2, 30), 3)
        >>> time = ['2018-01-01', '2018-01-02']
        >>> podpac.Coordinates([lat_lon, time], dims=['lat_lon', 'time'])
        Coordinates
            lat_lon[lat]: ArrayCoordinates1d(lat): Bounds[0.0, 2.0], N[3]
            lat_lon[lon]: ArrayCoordinates1d(lon): Bounds[10.0, 30.0], N[3]
            time: ArrayCoordinates1d(time): Bounds[2018-01-01, 2018-01-02], N[2]

    Parameters
    ----------
    dims : tuple
        Tuple of dimension names.
    name : str
        Stacked dimension name.
    coords : dict-like
        xarray coordinates (container of coordinate arrays)
    coordinates : pandas.MultiIndex
        MultiIndex of stacked coordinates values.

    """

    _coords = tl.List(trait=tl.Instance(Coordinates1d), read_only=True)

    def __init__(self, coords, name=None, dims=None):
        """
        Initialize a multidimensional coords object.

        Parameters
        ----------
        coords : list, :class:`StackedCoordinates`
            Coordinate values in a list, or a StackedCoordinates object to copy.

        See Also
        --------
        clinspace, crange
        """

        if not isinstance(coords, (list, tuple)):
            raise TypeError("Unrecognized coords type '%s'" % type(coords))

        if len(coords) < 2:
            raise ValueError(
                "Stacked coords must have at least 2 coords, got %d" %
                len(coords))

        # coerce
        coords = tuple(
            c if isinstance(c, Coordinates1d) else ArrayCoordinates1d(c)
            for c in coords)

        # set coords
        self.set_trait("_coords", coords)

        # propagate properties
        if dims is not None and name is not None:
            raise TypeError(
                "StackedCoordinates expected 'dims' or 'name', not both")
        if dims is not None:
            self._set_dims(dims)
        if name is not None:
            self._set_name(name)

        # finalize
        super(StackedCoordinates, self).__init__()

    @tl.validate("_coords")
    def _validate_coords(self, d):
        val = d["value"]

        # check sizes
        shape = val[0].shape
        for c in val[1:]:
            if c.shape != shape:
                raise ValueError("Shape mismatch in stacked coords %s != %s" %
                                 (c.shape, shape))

        # check dims
        dims = [c.name for c in val]
        for i, dim in enumerate(dims):
            if dim is not None and dim in dims[:i]:
                raise ValueError("Duplicate dimension '%s' in stacked coords" %
                                 dim)

        return val

    def _set_name(self, value):
        dims = value.split("_")

        # check length
        if len(dims) != len(self._coords):
            raise ValueError(
                "Invalid name '%s' for StackedCoordinates with length %d" %
                (value, len(self._coords)))

        self._set_dims(dims)

    def _set_dims(self, dims):
        # check size
        if len(dims) != len(self._coords):
            raise ValueError(
                "Invalid dims '%s' for StackedCoordinates with length %d" %
                (dims, len(self._coords)))

        for i, dim in enumerate(dims):
            if dim is not None and dim in dims[:i]:
                raise ValueError("Duplicate dimension '%s' in dims" % dim)

        # set names, checking for duplicates
        for i, (c, dim) in enumerate(zip(self._coords, dims)):
            if dim is None:
                continue
            c._set_name(dim)

    # ------------------------------------------------------------------------------------------------------------------
    # Alternate constructors
    # ------------------------------------------------------------------------------------------------------------------

    @classmethod
    def from_xarray(cls, x, **kwargs):
        """
        Create 1d Coordinates from named xarray coordinates.

        Arguments
        ---------
        x : xarray.DataArray
            Nade DataArray of the coordinate values

        Returns
        -------
        :class:`ArrayCoordinates1d`
            1d coordinates
        """

        dims = x.dims[0].split("_")
        cs = [x[dim].data for dim in dims]
        return cls(cs, dims=dims, **kwargs)

    @classmethod
    def from_definition(cls, d):
        """
        Create StackedCoordinates from a stacked coordinates definition.

        Arguments
        ---------
        d : list
            stacked coordinates definition

        Returns
        -------
        :class:`StackedCoordinates`
            stacked coordinates object

        See Also
        --------
        definition
        """

        coords = []
        for elem in d:
            if "start" in elem and "stop" in elem and ("step" in elem
                                                       or "size" in elem):
                c = UniformCoordinates1d.from_definition(elem)
            elif "values" in elem:
                c = ArrayCoordinates1d.from_definition(elem)
            else:
                raise ValueError(
                    "Could not parse coordinates definition with keys %s" %
                    elem.keys())

            coords.append(c)

        return cls(coords)

    # ------------------------------------------------------------------------------------------------------------------
    # standard methods, list-like
    # ------------------------------------------------------------------------------------------------------------------

    def __repr__(self):
        rep = str(self.__class__.__name__)
        for c in self._coords:
            rep += "\n\t%s[%s]: %s" % (self.name, c.name or "?", c)
        return rep

    def __iter__(self):
        return iter(self._coords)

    def __len__(self):
        return len(self._coords)

    def __getitem__(self, index):
        if isinstance(index, string_types):
            if index not in self.dims:
                raise KeyError("Dimension '%s' not found in dims %s" %
                               (index, self.dims))

            return self._coords[self.dims.index(index)]

        else:
            return StackedCoordinates([c[index] for c in self._coords])

    def __setitem__(self, dim, c):
        if dim not in self.dims:
            raise KeyError(
                "Cannot set dimension '%s' in StackedCoordinates %s" %
                (dim, self.dims))

        # try to cast to ArrayCoordinates1d
        if not isinstance(c, Coordinates1d):
            c = ArrayCoordinates1d(c)

        if c.name is None:
            c.name = dim

        # replace the element of the coords list
        idx = self.dims.index(dim)
        coords = list(self._coords)
        coords[idx] = c

        # set (and check) new coords list
        self.set_trait("_coords", coords)

    def __contains__(self, item):
        try:
            item = np.array([make_coord_value(value) for value in item])
        except Exception:
            return False

        if len(item) != len(self._coords):
            return False

        if any(val not in c for val, c in zip(item, self._coords)):
            return False

        return (self.flatten().coordinates == item).all(axis=1).any()

    def __eq__(self, other):
        if not isinstance(other, StackedCoordinates):
            return False

        # shortcuts
        if self.dims != other.dims:
            return False

        if self.shape != other.shape:
            return False

        # full check of underlying coordinates
        if self._coords != other._coords:
            return False

        return True

    # ------------------------------------------------------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------------------------------------------------------

    @property
    def dims(self):
        """:tuple: Tuple of dimension names."""
        return tuple(c.name for c in self._coords)

    @property
    def ndim(self):
        """:int: coordinates array ndim."""
        return self._coords[0].ndim

    @property
    def name(self):
        """:str: Stacked dimension name. Stacked dimension names are the individual `dims` joined by an underscore."""

        if any(self.dims):
            return "_".join(dim or "?" for dim in self.dims)

    @property
    def size(self):
        """:int: Number of stacked coordinates. """
        return self._coords[0].size

    @property
    def shape(self):
        """:tuple: Shape of the stacked coordinates."""
        return self._coords[0].shape

    @property
    def bounds(self):
        """:dict: Dictionary of (low, high) coordinates bounds in each dimension"""
        if None in self.dims:
            raise ValueError(
                "Cannot get bounds for StackedCoordinates with un-named dimensions"
            )
        return {dim: self[dim].bounds for dim in self.udims}

    @property
    def coordinates(self):
        dtypes = [c.dtype for c in self._coords]
        if len(set(dtypes)) == 1:
            dtype = dtypes[0]
        else:
            dtype = object
        return np.dstack([c.coordinates.astype(dtype)
                          for c in self._coords]).squeeze()

    @property
    def xcoords(self):
        """:dict-like: xarray coordinates (container of coordinate arrays)"""
        if None in self.dims:
            raise ValueError(
                "Cannot get xcoords for StackedCoordinates with un-named dimensions"
            )

        if self.ndim == 1:
            # use a multi-index so that we can use DataArray.sel easily
            coords = pd.MultiIndex.from_arrays(
                [np.array(c.coordinates) for c in self._coords],
                names=self.dims)
            xcoords = {self.name: coords}
        else:
            # fall-back for shaped coordinates
            xcoords = {
                c.name: (self.xdims, c.coordinates)
                for c in self._coords
            }
        return xcoords

    @property
    def definition(self):
        """:list: Serializable stacked coordinates definition. """

        return [c.definition for c in self._coords]

    @property
    def full_definition(self):
        """:list: Serializable stacked coordinates definition, containing all properties. For internal use."""

        return [c.full_definition for c in self._coords]

    # -----------------------------------------------------------------------------------------------------------------
    # Methods
    # -----------------------------------------------------------------------------------------------------------------

    def copy(self):
        """
        Make a copy of the stacked coordinates.

        Returns
        -------
        :class:`StackedCoordinates`
            Copy of the stacked coordinates.
        """

        return StackedCoordinates(self._coords)

    def unique(self, return_index=False):
        """
        Remove duplicate stacked coordinate values.

        Arguments
        ---------
        return_index : bool, optional
            If True, return index for the unique coordinates in addition to the coordinates. Default False.

        Returns
        -------
        unique : :class:`StackedCoordinates`
            New StackedCoordinates object with unique, sorted, flattened coordinate values.
        unique_index : array of int
            Index of the unique coordinates in the flattened coordinates, only if ``return_index`` is True.
        """

        flat = self.flatten()
        a, I = np.unique(flat.coordinates, axis=0, return_index=True)
        if return_index:
            return flat[I], I
        else:
            return flat[I]

    def get_area_bounds(self, boundary):
        """Get coordinate area bounds, including boundary information, for each unstacked dimension.

        Arguments
        ---------
        boundary : dict
            dictionary of boundary offsets for each unstacked dimension. Point dimensions can be omitted.

        Returns
        -------
        area_bounds : dict
            Dictionary of (low, high) coordinates area_bounds in each unstacked dimension
        """

        if None in self.dims:
            raise ValueError(
                "Cannot get area_bounds for StackedCoordinates with un-named dimensions"
            )
        return {
            dim: self[dim].get_area_bounds(boundary.get(dim))
            for dim in self.dims
        }

    def select(self, bounds, outer=False, return_index=False):
        """
        Get the coordinate values that are within the given bounds in all dimensions.

        *Note: you should not generally need to call this method directly.*

        Parameters
        ----------
        bounds : dict
            dictionary of dim -> (low, high) selection bounds
        outer : bool, optional
            If True, do *outer* selections. Default False.
        return_index : bool, optional
            If True, return index for the selections in addition to coordinates. Default False.

        Returns
        -------
        selection : :class:`StackedCoordinates`
            StackedCoordinates object consisting of the selection in all dimensions.
        selection_index : slice, boolean array
            Slice or index for the selected coordinates, only if ``return_index`` is True.
        """

        # logical AND of the selection in each dimension
        indices = [
            c.select(bounds, outer=outer, return_index=True)[1]
            for c in self._coords
        ]
        index = self._and_indices(indices)

        if return_index:
            return self[index], index
        else:
            return self[index]

    def _and_indices(self, indices):
        if all(isinstance(index, slice) for index in indices):
            index = slice(max(index.start or 0 for index in indices),
                          min(index.stop or self.size for index in indices))

            # for consistency
            if index.start == 0 and index.stop == self.size:
                index = slice(None, None)

        else:
            # convert any slices to boolean array
            for i, index in enumerate(indices):
                if isinstance(index, slice):
                    indices[i] = np.zeros(self.shape, dtype=bool)
                    indices[i][index] = True

            # logical and
            index = np.logical_and.reduce(indices)

            # for consistency
            if np.all(index):
                index = slice(None, None)

        return index

    def _transform(self, transformer):
        coords = [c.copy() for c in self._coords]

        if "lat" in self.dims and "lon" in self.dims and "alt" in self.dims:
            ilat = self.dims.index("lat")
            ilon = self.dims.index("lon")
            ialt = self.dims.index("alt")

            lat = coords[ilat]
            lon = coords[ilon]
            alt = coords[ialt]
            tlon, tlat, talt = transformer.transform(lon.coordinates,
                                                     lat.coordinates,
                                                     alt.coordinates)

            coords[ilat] = ArrayCoordinates1d(tlat, "lat").simplify()
            coords[ilon] = ArrayCoordinates1d(tlon, "lon").simplify()
            coords[ialt] = ArrayCoordinates1d(talt, "alt").simplify()

        elif "lat" in self.dims and "lon" in self.dims:
            ilat = self.dims.index("lat")
            ilon = self.dims.index("lon")

            lat = coords[ilat]
            lon = coords[ilon]
            tlon, tlat = transformer.transform(lon.coordinates,
                                               lat.coordinates)

            if (self.ndim == 2
                    and all(np.allclose(a, tlat[:, 0]) for a in tlat.T)
                    and all(np.allclose(a, tlon[0]) for a in tlon)):
                coords[ilat] = ArrayCoordinates1d(tlat[:, 0],
                                                  name="lat").simplify()
                coords[ilon] = ArrayCoordinates1d(tlon[0],
                                                  name="lon").simplify()
                return coords

            coords[ilat] = ArrayCoordinates1d(tlat, "lat").simplify()
            coords[ilon] = ArrayCoordinates1d(tlon, "lon").simplify()

        elif "alt" in self.dims:
            ialt = self.dims.index("alt")

            alt = coords[ialt]
            _, _, talt = transformer.transform(np.zeros(self.size),
                                               np.zeros(self.size),
                                               alt.coordinates)

            coords[ialt] = ArrayCoordinates1d(talt, "alt").simplify()

        return StackedCoordinates(coords)

    def transpose(self, *dims, **kwargs):
        """
        Transpose (re-order) the dimensions of the StackedCoordinates.

        Parameters
        ----------
        dim_1, dim_2, ... : str, optional
            Reorder dims to this order. By default, reverse the dims.
        in_place : boolean, optional
            If True, transpose the dimensions in-place.
            Otherwise (default), return a new, transposed Coordinates object.

        Returns
        -------
        transposed : :class:`StackedCoordinates`
            The transposed StackedCoordinates object.
        """

        in_place = kwargs.get("in_place", False)

        if len(dims) == 0:
            dims = list(self.dims[::-1])

        if set(dims) != set(self.dims):
            raise ValueError(
                "Invalid transpose dimensions, input %s does match any dims in %s"
                % (dims, self.dims))

        coordinates = [self._coords[self.dims.index(dim)] for dim in dims]

        if in_place:
            self.set_trait("_coords", coordinates)
            return self
        else:
            return StackedCoordinates(coordinates)

    def flatten(self):
        return StackedCoordinates([c.flatten() for c in self._coords])

    def reshape(self, newshape):
        return StackedCoordinates([c.reshape(newshape) for c in self._coords])

    def issubset(self, other):
        """Report whether other coordinates contains these coordinates.

        Arguments
        ---------
        other : Coordinates, StackedCoordinates
            Other coordinates to check

        Returns
        -------
        issubset : bool
            True if these coordinates are a subset of the other coordinates.
        """

        from podpac.core.coordinates import Coordinates

        if not isinstance(other, (Coordinates, StackedCoordinates)):
            raise TypeError(
                "StackedCoordinates issubset expected Coordinates or StackedCoordinates, not '%s'"
                % type(other))

        if isinstance(other, StackedCoordinates):
            if set(self.dims) != set(other.dims):
                return False

            mine = self.flatten().coordinates
            other = other.flatten().transpose(*self.dims).coordinates
            return set(map(tuple, mine)).issubset(map(tuple, other))

        elif isinstance(other, Coordinates):
            if not all(dim in other.udims for dim in self.dims):
                return False

            acs = []
            ocs = []
            for coords in other.values():
                dims = [dim for dim in coords.dims if dim in self.dims]

                if len(dims) == 0:
                    continue

                elif len(dims) == 1:
                    acs.append(self[dims[0]])
                    if isinstance(coords, Coordinates1d):
                        ocs.append(coords)
                    elif isinstance(coords, StackedCoordinates):
                        ocs.append(coords[dims[0]])

                elif len(dims) > 1:
                    acs.append(StackedCoordinates([self[dim] for dim in dims]))
                    if isinstance(coords, StackedCoordinates):
                        ocs.append(
                            StackedCoordinates([coords[dim] for dim in dims]))

            return all(a.issubset(o) for a, o in zip(acs, ocs))
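A short usage sketch for the class above; the import path is the usual podpac one (an assumption here), and the values are illustrative:

from podpac.core.coordinates import ArrayCoordinates1d, StackedCoordinates

lat_lon = StackedCoordinates(
    [ArrayCoordinates1d([0, 1, 1, 2], name="lat"),
     ArrayCoordinates1d([10, 20, 20, 30], name="lon")]
)

print(lat_lon.name)   # 'lat_lon'
print(lat_lon.shape)  # (4,)

# drop the duplicated (1, 20) point
unique = lat_lon.unique()

# dict of dim -> (low, high) bounds, per the select docstring above
subset, index = lat_lon.select({"lat": (0, 1), "lon": (10, 30)}, return_index=True)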
Code Example #30
class ScipyGrid(ScipyPoint):
    """Scipy Interpolation

    Attributes
    ----------
    {interpolator_attributes}
    """

    methods_supported = ["nearest", "bilinear", "cubic_spline", "spline_2", "spline_3", "spline_4"]
    method = tl.Unicode(default_value="nearest")

    # TODO: implement these parameters for the method 'nearest'
    spatial_tolerance = tl.Float(default_value=np.inf)
    time_tolerance = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)], default_value=None)

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_interpolate(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_interpolate}
        """

        # TODO: make this so we don't need to specify lat and lon together
        # or at least throw a warning
        if (
            "lat" in udims
            and "lon" in udims
            and self._dim_in(["lat", "lon"], source_coordinates)
            and self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True)
        ):

            return ["lat", "lon"]

        # otherwise return no supported dims
        return tuple()

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, output_data):
        """
        {interpolator_interpolate}
        """

        if self._dim_in(["lat", "lon"], eval_coordinates):
            return self._interpolate_irregular_grid(
                udims, source_coordinates, source_data, eval_coordinates, output_data, grid=True
            )

        elif self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True):
            eval_coordinates_us = eval_coordinates.unstack()
            return self._interpolate_irregular_grid(
                udims, source_coordinates, source_data, eval_coordinates_us, output_data, grid=False
            )

    def _interpolate_irregular_grid(
        self, udims, source_coordinates, source_data, eval_coordinates, output_data, grid=True
    ):

        if len(source_data.dims) > 2:
            keep_dims = ["lat", "lon"]
            return self._loop_helper(
                self._interpolate_irregular_grid,
                keep_dims,
                udims,
                source_coordinates,
                source_data,
                eval_coordinates,
                output_data,
                grid=grid,
            )

        s = []
        if source_coordinates["lat"].is_descending:
            lat = source_coordinates["lat"].coordinates[::-1]
            s.append(slice(None, None, -1))
        else:
            lat = source_coordinates["lat"].coordinates
            s.append(slice(None, None))
        if source_coordinates["lon"].is_descending:
            lon = source_coordinates["lon"].coordinates[::-1]
            s.append(slice(None, None, -1))
        else:
            lon = source_coordinates["lon"].coordinates
            s.append(slice(None, None))

        data = source_data.data[tuple(s)]

        # remove nan's
        I, J = np.isfinite(lat), np.isfinite(lon)
        coords_i = lat[I], lon[J]
        coords_i_dst = [eval_coordinates["lon"].coordinates, eval_coordinates["lat"].coordinates]

        # Swap order in case datasource uses lon,lat ordering instead of lat,lon
        if source_coordinates.dims.index("lat") > source_coordinates.dims.index("lon"):
            I, J = J, I
            coords_i = coords_i[::-1]
            coords_i_dst = coords_i_dst[::-1]
        data = data[I, :][:, J]

        if self.method in ["bilinear", "nearest"]:
            f = RegularGridInterpolator(
                coords_i, data, method=self.method.replace("bi", ""), bounds_error=False, fill_value=np.nan
            )
            if grid:
                x, y = np.meshgrid(*coords_i_dst)
            else:
                x, y = coords_i_dst
            output_data.data[:] = f((y.ravel(), x.ravel())).reshape(output_data.shape)

        # TODO: which methods are associated with 'spline'?
        elif "spline" in self.method:
            if self.method == "cubic_spline":
                order = 3
            else:
                # TODO: make this a parameter
                order = int(self.method.split("_")[-1])

            f = RectBivariateSpline(coords_i[0], coords_i[1], data, kx=max(1, order), ky=max(1, order))
            output_data.data[:] = f(coords_i_dst[1], coords_i_dst[0], grid=grid).reshape(output_data.shape)

        return output_data
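Finally, a standalone sketch of the scipy call pattern that _interpolate_irregular_grid relies on, independent of podpac; the grid values are illustrative:

import numpy as np
from scipy.interpolate import RegularGridInterpolator

# source grid: coordinates must be ascending, which is why the code above
# flips descending lat/lon (and the data) before building the interpolator
lat = np.linspace(0.0, 10.0, 11)
lon = np.linspace(0.0, 20.0, 21)
data = np.random.rand(lat.size, lon.size)

f = RegularGridInterpolator((lat, lon), data, method="linear",
                            bounds_error=False, fill_value=np.nan)

# evaluate on a new grid: meshgrid the targets, then pass (lat, lon) pairs
new_lat = np.linspace(0.5, 9.5, 5)
new_lon = np.linspace(0.5, 19.5, 7)
x, y = np.meshgrid(new_lon, new_lat)
vals = f((y.ravel(), x.ravel())).reshape(new_lat.size, new_lon.size)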