Exemplo n.º 1
0
class PCA(Transformer):
    '''Transform a set of features using a Principal Component Analysis.

    Example:

    >>> import vaex
    >>> df = vaex.from_arrays(x=[2,5,7,2,15], y=[-2,3,0,0,10])
    >>> df
     #   x   y
     0   2   -2
     1   5   3
     2   7   0
     3   2   0
     4   15  10
    >>> pca = vaex.ml.PCA(n_components=2, features=['x', 'y'])
    >>> pca.fit_transform(df)
     #    x    y       PCA_0      PCA_1
     0    2   -2    5.92532    0.413011
     1    5    3    0.380494  -1.39112
     2    7    0    0.840049   2.18502
     3    2    0    4.61287   -1.09612
     4   15   10  -11.7587    -0.110794

    '''
    # title = traitlets.Unicode(default_value='PCA', read_only=True).tag(ui='HTML')
    n_components = traitlets.Int(
        help=
        'Number of components to retain. If None, all the components will be retained.'
    ).tag(ui='IntText')
    prefix = traitlets.Unicode(default_value="PCA_", help=help_prefix)
    progress = traitlets.CBool(
        default_value=False,
        help='If True, display a progressbar of the PCA fitting process.').tag(
            ui='Checkbox')
    eigen_vectors_ = traitlets.List(
        traitlets.List(traitlets.CFloat()),
        help='The eigen vectors corresponding to each feature').tag(
            output=True)
    eigen_values_ = traitlets.List(
        traitlets.CFloat(),
        help='The eigen values that correspond to each feature.').tag(
            output=True)
    means_ = traitlets.List(traitlets.CFloat(),
                            help='The mean of each feature').tag(output=True)

    @traitlets.default('n_components')
    def get_n_components_default(self):
        # By default keep as many components as there are input features.
        return len(self.features)

    def fit(self, df):
        '''Fit the PCA model to the DataFrame.

        Computes the covariance matrix of ``features``, diagonalises it with
        ``np.linalg.eigh`` and stores the eigenvalues and eigenvectors sorted
        by decreasing eigenvalue, together with the mean of each feature.

        :param df: A vaex DataFrame.
        '''
        self.n_components = self.n_components or len(self.features)
        assert self.n_components >= 2, 'At least two features are required.'
        assert self.n_components <= len(
            self.features), 'Can not have more components than features.'
        C = df.cov(self.features, progress=self.progress)
        eigen_values, eigen_vectors = np.linalg.eigh(C)
        # eigh returns eigenvalues in ascending order; reverse so that the
        # first component explains the most variance.
        indices = np.argsort(eigen_values)[::-1]
        self.means_ = df.mean(self.features, progress=self.progress).tolist()
        # Eigenvectors are the columns of the eigh output.
        self.eigen_vectors_ = eigen_vectors[:, indices].tolist()
        self.eigen_values_ = eigen_values[indices].tolist()

    def transform(self, df, n_components=None):
        '''Apply the PCA transformation to the DataFrame.

        Each component ``i`` is added as a virtual column named
        ``prefix + str(i + offset)``, where ``offset`` is the smallest
        non-negative integer for which none of the new column names already
        exist in the DataFrame (so existing columns are never overwritten).

        :param df: A vaex DataFrame.
        :param n_components: The number of PCA components to retain.
            Defaults to ``self.n_components``.

        :return copy: A shallow copy of the DataFrame that includes the PCA components.
        :rtype: DataFrame
        '''
        n_components = n_components or self.n_components
        copy = df.copy()
        eigen_vectors = np.array(self.eigen_vectors_)

        # Look the existing names up once, then search for an offset where
        # *every* output name is free (checking only the first name could
        # clobber e.g. an existing PCA_1 when PCA_0 happens to be free).
        existing = set(copy.get_column_names(virtual=True, strings=True))
        name_prefix_offset = 0
        while any(self.prefix + str(i + name_prefix_offset) in existing
                  for i in range(n_components)):
            name_prefix_offset += 1

        # Center each feature on the mean learned during fit.
        expressions = [
            copy[feature] - mean
            for feature, mean in zip(self.features, self.means_)
        ]
        for i in range(n_components):
            name = self.prefix + str(i + name_prefix_offset)
            copy[name] = dot_product(expressions, eigen_vectors[:, i])
        return copy
Exemplo n.º 2
0
class ImageLayerStateWidget(v.VuetifyTemplate):
    """Vuetify widget exposing the display options of a glue image layer.

    The attribute, stretch and percentile dropdowns are wired to the glue
    layer state through ``link_glue_choices``; ``color_mode`` is wired to the
    layer's viewer state through ``link_glue``. When the layer state has a
    ``contour_visible`` attribute, the contour levels are also mirrored into
    the editable comma-separated text field ``c_levels_txt``.
    """
    template = load_template('layer_image.vue', __file__)

    glue_state = GlueState().tag(sync=True)

    # TODO: expose toggle to turn on image and/or contour

    attribute_items = traitlets.List().tag(sync=True)
    attribute_selected = traitlets.Int(allow_none=True).tag(sync=True)

    stretch_items = traitlets.List().tag(sync=True)
    stretch_selected = traitlets.Int(allow_none=True).tag(sync=True)

    percentile_items = traitlets.List().tag(sync=True)
    percentile_selected = traitlets.Int(allow_none=True).tag(sync=True)

    colormap_items = traitlets.List().tag(sync=True)
    color_mode = traitlets.Unicode().tag(sync=True)

    c_levels_txt = traitlets.Unicode().tag(sync=True)
    # Plain (non-synced) flag: True while user-typed text is being pushed
    # into glue, so the glue -> text callback does not rewrite the field.
    c_levels_txt_editing = False
    c_levels_error = traitlets.Unicode().tag(sync=True)

    has_contour = traitlets.Bool().tag(sync=True)

    def __init__(self, layer_state):
        super().__init__()

        self.layer_state = layer_state
        self.glue_state = layer_state
        self.has_contour = hasattr(layer_state, "contour_visible")

        for choice_name in ('attribute', 'stretch', 'percentile'):
            link_glue_choices(self, layer_state, choice_name)

        self.colormap_items = [
            {'text': member[0], 'value': member[1].name}
            for member in colormaps.members
        ]

        link_glue(self, 'color_mode', layer_state.viewer_state)

        if not self.has_contour:
            return

        # Only the glue -> text direction is handled here; text -> glue is
        # done by _on_change_c_levels_txt.
        def sync_levels_text(*_unused):
            if self.c_levels_txt_editing:
                return
            self.c_levels_txt = ", ".join(
                '%g' % level for level in self.glue_state.levels)

        self.glue_state.add_callback('levels', sync_levels_text)

    @traitlets.observe('c_levels_txt')
    def _on_change_c_levels_txt(self, change):
        """Parse the edited level text and, in Custom mode, push it to glue.

        Parse failures are reported through ``c_levels_error``; a successful
        parse clears it.
        """
        try:
            self.c_levels_txt_editing = True
            try:
                levels = [float(token.strip())
                          for token in change['new'].split(',')]
            except Exception as exc:
                self.c_levels_error = str(exc)
            else:
                if self.glue_state.level_mode == "Custom":
                    self.glue_state.levels = levels
                self.c_levels_error = ''
        finally:
            self.c_levels_txt_editing = False

    def vue_set_colormap(self, data):
        """Set the layer colormap to the member whose ``name`` equals ``data``.

        If no member matches, the layer's ``cmap`` is set to None.
        """
        matching = (member[1] for member in colormaps.members
                    if member[1].name == data)
        self.layer_state.cmap = next(matching, None)
Exemplo n.º 3
0
class WorkflowsLayer(ipyleaflet.TileLayer):
    """
    Subclass of ``ipyleaflet.TileLayer`` for displaying a Workflows `~.geospatial.Image`.

    Attributes
    ----------
    image: ~.geospatial.Image
        The `~.geospatial.Image` to use
    parameters: ParameterSet
        Parameters to use while computing; modify attributes under ``.parameters``
        (like ``layer.parameters.foo = "bar"``) to cause the layer to recompute
        and update under those new parameters.
    xyz_obj: ~.models.XYZ
        Read-only: The `XYZ` object this layer is displaying.
    session_id: str
        Read-only: Unique ID that error logs will be stored under, generated automatically.
    checkerboard: bool, default True
        Whether to display a checkerboarded background for missing or masked data.
    colormap: str, optional, default None
        Name of the colormap to use.
        If set, `image` must have 1 band.
    r_min: float, optional, default None
        Min value for scaling the red band. Along with r_max,
        controls scaling when a colormap is enabled.
    r_max: float, optional, default None
        Max value for scaling the red band. Along with r_min, controls scaling
        when a colormap is enabled.
    g_min: float, optional, default None
        Min value for scaling the green band.
    g_max: float, optional, default None
        Max value for scaling the green band.
    b_min: float, optional, default None
        Min value for scaling the blue band.
    b_max: float, optional, default None
        Max value for scaling the blue band.
    error_output: ipywidgets.Output, optional, default None
        If set, write unique errors from tiles computation to this output area
        from a background thread. Setting to None stops the listener thread.

    Example
    -------
    >>> import descarteslabs.workflows as wf
    >>> wf.map # doctest: +SKIP
    >>> # ^ display interactive map
    >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1").pick_bands("red")
    >>> masked_img = img.mask(img > wf.parameter("threshold", wf.Float))
    >>> layer = masked_img.visualize("sample", colormap="viridis", threshold=0.07) # doctest: +SKIP
    >>> layer.colormap = "plasma" # doctest: +SKIP
    >>> # ^ change colormap (this will update the layer on the map)
    >>> layer.parameters.threshold = 0.13 # doctest: +SKIP
    >>> # ^ adjust parameters (this also updates the layer)
    >>> layer.set_scales((0.01, 0.3)) # doctest: +SKIP
    >>> # ^ adjust scaling (this also updates the layer)
    """

    attribution = traitlets.Unicode("Descartes Labs").tag(sync=True, o=True)
    min_zoom = traitlets.Int(5).tag(sync=True, o=True)
    url = traitlets.Unicode(read_only=True).tag(sync=True)

    image = traitlets.Instance(Image)
    parameters = traitlets.Instance(parameters.ParameterSet, allow_none=True)
    xyz_obj = traitlets.Instance(XYZ, read_only=True)
    session_id = traitlets.Unicode(read_only=True)

    checkerboard = traitlets.Bool(True)
    colormap = traitlets.Unicode(None, allow_none=True)

    r_min = ScaleFloat(None, allow_none=True)
    r_max = ScaleFloat(None, allow_none=True)
    g_min = ScaleFloat(None, allow_none=True)
    g_max = ScaleFloat(None, allow_none=True)
    b_min = ScaleFloat(None, allow_none=True)
    b_max = ScaleFloat(None, allow_none=True)

    error_output = traitlets.Instance(widgets.Output, allow_none=True)
    autoscale_progress = traitlets.Instance(ClearableOutput)

    def __init__(self, image, *args, **kwargs):
        params = kwargs.pop("parameters", {})

        # Error-tracking state must exist before any trait observer can run:
        # traits passed through ``kwargs`` (e.g. ``error_output``) and the
        # deferred notifications below trigger handlers that read these.
        self._error_listener = None
        self._known_errors = set()
        self._known_errors_lock = threading.Lock()

        super(WorkflowsLayer, self).__init__(*args, **kwargs)

        with self.hold_trait_notifications():
            self.image = image
            self.set_trait("session_id", uuid.uuid4().hex)
            self.set_trait(
                "autoscale_progress",
                ClearableOutput(
                    widgets.Output(),
                    layout=widgets.Layout(max_height="10rem", flex="1 0 auto"),
                ),
            )
            self.set_parameters(**params)

    def make_url(self):
        """
        Generate the URL for this layer.

        This is called automatically as the attributes (`image`, `colormap`, scales, etc.) are changed.

        Returns an empty string while the layer is not visible, or while no
        `XYZ` object has been built yet.

        Example
        -------
        >>> import descarteslabs.workflows as wf
        >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP
        >>> img = img.pick_bands("red blue green") # doctest: +SKIP
        >>> layer = img.visualize("sample") # doctest: +SKIP
        >>> layer.make_url() # doctest: +SKIP
        'https://workflows.descarteslabs.com/master/xyz/9ec70d0e99db7f50c856c774809ae454ffd8475816e05c5c/{z}/{x}/{y}.png?session_id=xxx&checkerboard=true'
        """
        if not self.visible:
            # workaround for the fact that Leaflet still loads tiles from inactive layers,
            # which is expensive computation users don't want
            return ""

        if self.xyz_obj is None:
            # Nothing to display yet (e.g. a trait fired before `image` was set).
            return ""

        if self.colormap is not None:
            scales = [[self.r_min, self.r_max]]
        else:
            scales = [
                [self.r_min, self.r_max],
                [self.g_min, self.g_max],
                [self.b_min, self.b_max],
            ]

        # Bands with no bounds at all are left out entirely.
        scales = [scale for scale in scales if scale != [None, None]]

        param_values = (
            self.parameters.to_dict() if self.parameters is not None else {}
        )

        return self.xyz_obj.url(session_id=self.session_id,
                                colormap=self.colormap,
                                scales=scales,
                                checkerboard=self.checkerboard,
                                **param_values)

    @traitlets.observe("image")
    def _update_xyz(self, change):
        old, new = change["old"], change["new"]
        if old is new:
            # traitlets does an == check between the old and new value to decide if it's changed,
            # which for an Image, returns another Image, which it considers changed.
            return

        xyz = XYZ.build(new, name=self.name)
        xyz.save()
        self.set_trait("xyz_obj", xyz)

    def _refresh_url(self):
        """Recompute and set ``url``, tolerating "Invalid scales passed" errors.

        Any other ``ValueError`` is re-raised.
        """
        try:
            self.set_trait("url", self.make_url())
        except ValueError as e:
            # NOTE(review): presumably scales can be transiently invalid while
            # bounds are edited one at a time; the URL is simply left unchanged.
            if "Invalid scales passed" not in str(e):
                raise

    @traitlets.observe(
        "visible",
        "checkerboard",
        "colormap",
        "r_min",
        "r_max",
        "g_min",
        "g_max",
        "b_min",
        "b_max",
        "xyz_obj",
        "session_id",
        "parameters",
    )
    def _update_url(self, change):
        self._refresh_url()

    @traitlets.observe("parameters", type="delete")
    def _update_url_on_param_delete(self, change):
        # traitlets only registers the outermost ``observe`` decorator on a
        # method, so the "delete" event needs its own handler.
        self._refresh_url()

    @traitlets.observe("xyz_obj", "session_id")
    def _update_error_logger(self, change):
        """(Re)start the background error listener for the current XYZ/session.

        Does nothing if there is no ``error_output`` or no ``xyz_obj`` yet.
        Old errors for this layer are removed from ``error_output`` first.
        """
        if self.error_output is None or self.xyz_obj is None:
            return

        # Remove old errors for the layer
        self.forget_errors()
        prefix = self.name + ": "
        # Non-stream outputs (e.g. display data) carry no "text" key; keep them.
        self.error_output.outputs = tuple(
            output for output in self.error_output.outputs
            if not output.get("text", "").startswith(prefix)
        )

        self._stop_error_logger()

        listener = self.xyz_obj.error_listener()
        listener.add_callback(self._log_errors_callback)
        listener.listen(self.session_id,
                        datetime.datetime.now(datetime.timezone.utc))

        self._error_listener = listener

    def _stop_error_logger(self):
        if self._error_listener is not None:
            self._error_listener.stop(timeout=1)
            self._error_listener = None

    @traitlets.observe("error_output")
    def _toggle_error_listener_if_output(self, change):
        if change["new"] is None:
            self._stop_error_logger()
        elif self._error_listener is None:
            self._update_error_logger({})

    def _log_errors_callback(self, msg):
        """Append ``msg.message`` to ``error_output`` unless it was already shown."""
        message = msg.message

        with self._known_errors_lock:
            if message in self._known_errors:
                return
            self._known_errors.add(message)

        # The output may have been detached (set to None) since the listener
        # was started, so read it once and check it.
        output = self.error_output
        if output is not None:
            output.append_stdout("{}: {}\n".format(self.name, message))

    def __del__(self):
        self._stop_error_logger()
        super(WorkflowsLayer, self).__del__()

    def forget_errors(self):
        """
        Clear the set of known errors, so they are re-displayed if they occur again

        Example
        -------
        >>> import descarteslabs.workflows as wf
        >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP
        >>> wf.map # doctest: +SKIP
        >>> layer = img.visualize("sample visualization") # doctest: +SKIP
        >>> # ^ will show an error for attempting to visualize more than 3 bands
        >>> layer.forget_errors() # doctest: +SKIP
        >>> wf.map.zoom = 10 # doctest: +SKIP
        >>> # ^ attempting to load more tiles from img will cause the same error to appear
        """
        with self._known_errors_lock:
            self._known_errors.clear()

    def set_scales(self, scales, new_colormap=False):
        """
        Update the scales for this layer by giving a list of scales

        Parameters
        ----------
        scales: list of lists, default None
            The scaling to apply to each band in the `Image`.

            If `Image` contains 3 bands, ``scales`` must be a list like ``[(0, 1), (0, 1), (-1, 1)]``.

            If `Image` contains 1 band, ``scales`` must be a list like ``[(0, 1)]``,
            or just ``(0, 1)`` for convenience

            If None, each 256x256 tile will be scaled independently
            based on the min and max values of its data.
        new_colormap: str, None, or False, optional, default False
            A new colormap to set at the same time, or False to use the current colormap.

        Raises
        ------
        ValueError
            If the number of scales doesn't match the colormap setting
            (1 scale with a colormap, 3 without).

        Example
        -------
        >>> import descarteslabs.workflows as wf
        >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP
        >>> img = img.pick_bands("red") # doctest: +SKIP
        >>> layer = img.visualize("sample visualization", colormap="viridis") # doctest: +SKIP
        >>> layer.set_scales((0.08, 0.3), new_colormap="plasma") # doctest: +SKIP
        >>> # ^ optionally set new colormap
        """
        colormap = self.colormap if new_colormap is False else new_colormap
        scales_len = 1 if colormap is not None else 3

        if scales is None:
            # Clear the bounds so each tile is scaled independently.
            scales = [[None, None]] * scales_len
        else:
            scales = XYZ._validate_scales(scales)

            if len(scales) != scales_len:
                msg = "Expected {} scales, but got {}.".format(
                    scales_len, len(scales))
                if len(scales) in (1, 2):
                    msg += " If displaying a 1-band Image, use a colormap."
                elif colormap:
                    msg += " Colormaps cannot be used with multi-band images."

                raise ValueError(msg)

        with self.hold_trait_notifications():
            self.r_min = scales[0][0]
            self.r_max = scales[0][1]
            if colormap is None:
                self.g_min = scales[1][0]
                self.g_max = scales[1][1]
                self.b_min = scales[2][0]
                self.b_max = scales[2][1]
            if new_colormap is not False:
                self.colormap = new_colormap

    def set_parameters(self, **params):
        """
        Set new parameters for this `WorkflowsLayer`.

        In typical cases, you update parameters by assigning to `parameters`
        (like ``layer.parameters.threshold = 6.6``).

        Instead, use this function when you need to change the *names or types*
        of parameters available on the `WorkflowsLayer`. (Users shouldn't need to
        do this, as `~.Image.visualize` handles it for you, but custom widget developers
        may need to use this method when they change the `image` field on a `WorkflowsLayer`.)

        If a value is an ipywidgets Widget, it will be linked to that parameter
        (via its ``"value"`` attribute). If a parameter was previously set with
        a widget, and a different widget instance (or non-widget) is passed
        for its new value, the old widget is automatically unlinked.
        If the same widget instance is passed as is already linked, no change occurs.

        Parameters
        ----------
        params: JSON-serializable value, Proxytype, or ipywidgets.Widget
            Parameter names to new values. Values can be Python types,
            `Proxytype` instances, or ``ipywidgets.Widget`` instances.

        Example
        -------

        >>> import descarteslabs.workflows as wf
        >>> from ipywidgets import FloatSlider
        >>> img = wf.Image.from_id("landsat:LC08:PRE:TOAR:meta_LC80330352016022_v1") # doctest: +SKIP
        >>> img = img.pick_bands("red") # doctest: +SKIP
        >>> masked_img = img.mask(img > wf.parameter("threshold", wf.Float)) # doctest: +SKIP
        >>> layer = masked_img.tile_layer("sample", colormap="plasma", threshold=0.07) # doctest: +SKIP
        >>> scaled_img = img * wf.parameter("scale", wf.Float) + wf.parameter("offset", wf.Float) # doctest: +SKIP
        >>> with layer.hold_trait_notifications(): # doctest: +SKIP
        ...     layer.image = scaled_img # doctest: +SKIP
        ...     layer.set_parameters(scale=FloatSlider(min=0, max=10, value=2), offset=2.5) # doctest: +SKIP
        >>> # ^ re-use the same layer instance for a new Image with different parameters
        """
        param_set = self.parameters
        if param_set is None:
            param_set = self.parameters = parameters.ParameterSet(
                self, "parameters")

        with self.hold_trait_notifications():
            param_set.update(**params)

    def _ipython_display_(self):
        # Display only the parameter widget, and only when it has children.
        param_set = self.parameters
        if param_set:
            widget = param_set.widget
            if widget and len(widget.children) > 0:
                widget._ipython_display_()
Exemplo n.º 4
0
class KBinsDiscretizer(Transformer):
    '''Bin continuous features into discrete bins.

    A strategy to encode continuous features into discrete bins. The transformed
    columns contain the bin label each sample falls into. In a way this
    transformer Label/Ordinal encodes continuous features.

    Example:

    >>> import vaex
    >>> import vaex.ml
    >>> df = vaex.from_arrays(x=[0, 2.5, 5, 7.5, 10, 12.5, 15])
    >>> bin_trans = vaex.ml.KBinsDiscretizer(features=['x'], n_bins=3, strategy='uniform')
    >>> bin_trans.fit_transform(df)
      #     x    binned_x
      0   0             0
      1   2.5           0
      2   5             1
      3   7.5           1
      4  10             2
      5  12.5           2
      6  15             2
    '''
    snake_name = 'kbins_discretizer'
    n_bins = traitlets.Int(allow_none=False,
                           default_value=5,
                           help='Number of bins. Must be greater than 1.')
    strategy = traitlets.Enum(
        values=['uniform', 'quantile', 'kmeans'],
        default_value='uniform',
        help='Strategy used to define the widths of the bins.')
    prefix = traitlets.Unicode(default_value='binned_', help=help_prefix)
    epsilon = traitlets.Float(
        default_value=1e-8,
        allow_none=False,
        help=
        'Tiny value added to the bin edges ensuring samples close to the bin edges are binned correcly.'
    )
    n_bins_ = traitlets.Dict(help='Number of bins per feature.').tag(
        output=True)
    bin_edges_ = traitlets.Dict(
        help='The bin edges for each binned feature').tag(output=True)

    def fit(self, df):
        '''
        Fit KBinsDiscretizer to the DataFrame.

        Computes the bin edges of every feature according to ``strategy`` and
        stores them in ``bin_edges_``. Edges closer than 1e-8 to the next edge
        are dropped (with a warning), and the number of bins actually kept per
        feature is stored in ``n_bins_``.

        :param df: A vaex DataFrame.
        '''

        # We need at least two bins to do the transformations
        assert self.n_bins > 1, 'Kwarg `n_bins` must be greater than 1.'

        # Find the extent of the features: schedule all min/max computations
        # lazily and evaluate them together with a single df.execute().
        minmax = []
        minmax_promise = [df.minmax(feat, delay=True) for feat in self.features]

        @vaex.delayed
        def assign(minmax_promise):
            minmax.extend(minmax_promise)

        assign(minmax_promise)
        df.execute()

        # warning: everything is cast to float, which is unavoidable due to the addition of self.epsilon
        minmax = np.array(minmax)
        minmax[:, 1] = minmax[:, 1] + self.epsilon

        # Determine the bin edges and number of bins depending on the strategy per feature
        if self.strategy == 'uniform':
            bin_edges = {
                feat: np.linspace(minmax[i, 0], minmax[i, 1], self.n_bins + 1)
                for i, feat in enumerate(self.features)
            }

        elif self.strategy == 'quantile':
            percentiles = np.linspace(0, 100, self.n_bins + 1)
            bin_edges = df.percentile_approx(self.features,
                                             percentage=percentiles)
            bin_edges = {
                feat: edges
                for feat, edges in zip(self.features, bin_edges)
            }

        else:
            from .cluster import KMeans

            bin_edges = {}
            for i, feat in enumerate(self.features):

                # Deterministic initialization with uniform spacing
                uniform_edges = np.linspace(minmax[i, 0], minmax[i, 1],
                                            self.n_bins + 1)
                centers_init = ((uniform_edges[1:] + uniform_edges[:-1]) *
                                0.5).tolist()
                centers_init = [[elem] for elem in centers_init]

                # KMeans strategy
                km = KMeans(n_clusters=self.n_bins,
                            init=centers_init,
                            n_init=1,
                            features=[feat])
                km.fit(df)
                # Get and sort the centres of the kmeans clusters
                centers = np.sort(np.array(km.cluster_centers).flatten())
                # Put the bin edges half way between each center (ignoring the outermost edges)
                be = (centers[1:] + centers[:-1]) * 0.5
                # The outermost edges are defined by the min/max of each feature
                # Quickly build a numpy array by concat individual values (min/max) and arrays (be)
                bin_edges[feat] = np.r_[minmax[i, 0], be, minmax[i, 1]]

        # Remove bins whose width are too small (i.e., <= 1e-8)
        n_bins = {}  # number of bins per features that are actually used
        for feat in self.features:
            # Appending inf guarantees the last edge is always kept.
            mask = np.diff(bin_edges[feat], append=np.inf) > 1e-8
            be = bin_edges[feat][mask]
            if len(be) - 1 != self.n_bins:
                warnings.warn(
                    f'Bins whose width are too small (i.e., <= 1e-8) in {feat} are removed. '
                    f'Consider decreasing the number of bins.')
                bin_edges[feat] = be
            n_bins[feat] = len(be) - 1

        self.bin_edges_ = bin_edges
        self.n_bins_ = n_bins

    def transform(self, df):
        '''
        Transform a DataFrame with a fitted KBinsDiscretizer.

        For every feature a column ``prefix + feature`` is added holding the
        0-based bin index of each sample.

        :param df: A vaex DataFrame.

        :returns copy: a shallow copy of the DataFrame that includes the binned features.
        :rtype: DataFrame
        '''

        df = df.copy()

        for feat in self.features:
            name = self.prefix + feat
            # Samples outside the bin range are added to the closest bin
            df[name] = (df[feat].digitize(self.bin_edges_[feat]) - 1).clip(
                0, self.n_bins_[feat] - 1)

        return df
Exemplo n.º 5
0
class Widget(DOMWidget):
    """Jupyter widget showing an interactive 3D view of a REBOUND simulation.

    Particle and orbit data are copied out of the C library as raw bytes into
    the synced ``particle_data`` / ``orbit_data`` traits. ``count`` is
    incremented after each update (presumably so the front-end notices new
    data — confirm against the JS side).
    """
    _view_name = traitlets.Unicode('ReboundView').tag(sync=True)
    _view_module = traitlets.Unicode('rebound').tag(sync=True)
    count = traitlets.Int(0).tag(sync=True)
    screenshotcount = traitlets.Int(0).tag(sync=True)
    t = traitlets.Float().tag(sync=True)
    N = traitlets.Int().tag(sync=True)
    width = traitlets.Float().tag(sync=True)
    height = traitlets.Float().tag(sync=True)
    scale = traitlets.Float().tag(sync=True)
    particle_data = traitlets.CBytes(allow_none=True).tag(sync=True)
    orbit_data = traitlets.CBytes(allow_none=True).tag(sync=True)
    orientation = traitlets.Tuple().tag(sync=True)
    orbits = traitlets.Int().tag(sync=True)
    screenshot = traitlets.Unicode().tag(sync=True)

    def __init__(self,
                 simulation,
                 size=(200, 200),
                 orientation=(0., 0., 0., 1.),
                 scale=None,
                 autorefresh=True,
                 orbits=True):
        """
        Initializes a Widget.

        Widgets provide real-time 3D interactive visualizations for REBOUND simulations
        within Jupyter Notebooks. To use widgets, the ipywidgets package needs to be installed
        and enabled in your Jupyter notebook server.

        Parameters
        ----------
        simulation : rebound.Simulation
            The simulation to visualize.
        size : (int, int), optional
            Specify the size of the widget in pixels. The default is 200 times 200 pixels.
        orientation : (float, float, float, float), optional
            Specify the initial orientation of the view. The four floats correspond to the
            x, y, z, and w components of a quaternion. The quaternion will be normalized.
        scale : float, optional
            Set the initial scale of the view. If not set, the widget will determine the
            scale automatically based on current particle positions.
        autorefresh : bool, optional
            The default value is True. The view is updated whenever a particle is added,
            removed and every 100th of a second while a simulation is running. If set
            to False, then the user needs to manually call the refresh() function on the
            widget. This might be useful if performance is an issue.
        orbits : bool, optional
            The default value for this is True and the widget will draw the instantaneous
            orbits of the particles. For simulations in which particles are not on
            Keplerian orbits, the orbits shown will not be accurate.
        """
        self.screenshotcountall = 0
        self.width, self.height = size
        self.t, self.N = simulation.t, simulation.N
        self.orientation = orientation
        self.autorefresh = autorefresh
        self.orbits = orbits
        self.simp = pointer(simulation)
        clibrebound.reb_display_copy_data.restype = c_int
        if scale is None:
            self.scale = simulation.display_data.contents.scale
        else:
            self.scale = scale
        self.count += 1

        super(Widget, self).__init__()

    def refresh(self, simp=None, isauto=0):
        """
        Manually refreshes a widget.

        Note that this function can also be called using the wrapper function of
        the Simulation object: sim.refreshWidgets().

        Parameters
        ----------
        simp : ctypes pointer to a Simulation, optional
            Simulation to read from. Defaults to the one the widget was created with.
        isauto : int, optional
            1 for automatic refreshes, which are skipped when ``autorefresh``
            is disabled; 0 (default) for manual refreshes.
        """

        if simp is None:
            simp = self.simp
        if self.autorefresh == 0 and isauto == 1:
            return
        sim = simp.contents
        size_changed = clibrebound.reb_display_copy_data(simp)
        clibrebound.reb_display_prepare_data(simp, c_int(self.orbits))
        if sim.N > 0:
            # Raw buffers: 4 * 7 bytes per particle, and 4 * 9 bytes for each
            # of the N - 1 orbits.
            self.particle_data = (c_char * (4 * 7 * sim.N)).from_address(
                sim.display_data.contents.particle_data).raw
            if self.orbits:
                self.orbit_data = (
                    c_char * (4 * 9 * (sim.N - 1))).from_address(
                        sim.display_data.contents.orbit_data).raw
        if size_changed:
            # TODO: Implement better GPU size change
            pass
        self.N = sim.N
        self.t = sim.t
        self.count += 1

    def takeScreenshot(self,
                       times=None,
                       prefix="./screenshot",
                       resetCounter=False,
                       archive=None,
                       mode="snapshot"):
        """
        Take one or more screenshots of the widget and save the images to a file.
        The images can be used to create a video.

        This function cannot be called multiple times within one cell.

        Note: this is a new feature and might not work on all systems.
        It was tested on python 2.7.10 and 3.5.2 on MacOSX.

        Parameters
        ----------
        times : (float, list), optional
            If this argument is not given a screenshot of the widget will be made
            as it is (without integrating the simulation). If a float is given, then the
            simulation will be integrated to that time and then a screenshot will
            be taken. If a list of floats is given, the simulation will be integrated
            to each time specified in the array. A separate screenshot for
            each time will be saved.
        prefix : (str), optional
            This string will be part of the output filename for each image.
            Followed by a five digit integer and the suffix .png. By default the
            prefix is './screenshot' which outputs images in the current
            directory with the filenames screenshot00000.png, screenshot00001.png...
            Note that the prefix can include a directory.
        resetCounter : (bool), optional
            Resets the output counter to 0.
        archive : (rebound.SimulationArchive), optional
            Use a REBOUND SimulationArchive. Thus, instead of integrating the
            Simulation from the current time, it will use the SimulationArchive
            to load a snapshot. See examples for usage.
        mode : (string), optional
            Mode to use when querying the SimulationArchive. See SimulationArchive
            documentation for details. By default the value is "snapshot".

        Raises
        ------
        ValueError
            In archive mode, if ``times`` is missing or is not a sequence.

        Examples
        --------

        First, create a simulation and widget. All of the following can go in
        one cell.

        >>> sim = rebound.Simulation()
        >>> sim.add(m=1.)
        >>> sim.add(m=1.e-3,x=1.,vy=1.)
        >>> w = sim.getWidget()
        >>> w

        The widget should show up. To take a screenshot, simply call

        >>> w.takeScreenshot()

        A new file with the name screenshot00000.png will appear in the
        current directory.

        Note that the takeScreenshot command needs to be in a separate cell,
        i.e. after you see the widget.

        You can pass an array of times to the function. This allows you to
        take multiple screenshots, for example to create a movie,

        >>> times = [0,10,100]
        >>> w.takeScreenshot(times)

        """
        self.archive = archive
        if resetCounter:
            self.screenshotcountall = 0
        self.screenshotprefix = prefix
        self.screenshotcount = 0
        self.screenshot = ""
        if archive is None:
            if times is None:
                times = self.simp.contents.t
            try:
                # List
                len(times)
            except TypeError:
                # Float:
                times = [times]
            self.times = times
            self.observe(savescreenshot, names="screenshot")
            self.simp.contents.integrate(times[0])
            self.screenshotcount += 1  # triggers first screenshot
        else:
            if times is None:
                raise ValueError("Need times argument for archive mode.")
            try:
                len(times)
            except TypeError:
                raise ValueError(
                    "Need a list of times for archive mode.") from None
            self.times = times
            self.mode = mode
            self.observe(savescreenshot, names="screenshot")
            sim = archive.getSimulation(times[0], mode=mode)
            self.refresh(pointer(sim))
            self.screenshotcount += 1  # triggers first screenshot

    @staticmethod
    def getClientCode():
        # Front-end code: shader sources followed by the JS view implementation.
        return shader_code + js_code
Exemplo n.º 6
0
class ReflectModelView(HasTraits):
    """
    An ipywidgets viewport of a `refnx.reflect.ReflectModel`.

    Parameters
    ----------
    reflect_model: refnx.reflect.ReflectModel

    Notes
    -----
    Use the `model_box` property to view/modify the ReflectModel parameters.
    Use the `limits_box` property to view the limits for the varying
    parameters.
    Observe the `view_changed` traitlet to determine when widget values are
    changed.
    Observe the `view_redraw` traitlet to determine when a complete redraw
    of the view is required (because the number of widgets has changed for
    example).

    """

    # traitlet to say when params were last altered (a `time.time()` stamp;
    # the default value is evaluated once, when the class is defined)
    view_changed = traitlets.Float(time.time())

    # traitlet to ask when a redraw of the GUI is requested.
    # e.g. the number of layers has changed, or there are a
    # different number of fitted parameters requiring
    # different limit widgets.
    view_redraw = traitlets.Float(time.time())

    # number of varying parameters in the model
    num_varying = traitlets.Int(0)

    def __init__(self, reflect_model):
        super(ReflectModelView, self).__init__()

        self.model = reflect_model
        self.structure_view = StructureView(self.model.structure)
        # widget the main slider is currently linked to (None = not linked)
        self.last_selected_param = None
        self.param_widgets_link = {}

        # The first slab's thickness/roughness widgets and the last slab's
        # thickness widgets are not user-editable.
        slab_views = self.structure_view.slab_views
        slab_views[0].w_thick.disabled = True
        slab_views[0].c_thick.disabled = True
        slab_views[0].w_rough.disabled = True
        slab_views[0].c_rough.disabled = True
        slab_views[-1].w_thick.disabled = True
        slab_views[-1].c_thick.disabled = True

        # got to listen to all the slab views
        for slab_view in slab_views:
            slab_view.observe(self._on_slab_params_modified,
                              names=['view_changed'])

        # if you'd like to change the number of layers (the first and last
        # slabs are not counted as layers)
        self.w_layers = widgets.BoundedIntText(
            description='Number of layers',
            value=len(slab_views) - 2, min=0, max=1000,
            style={'description_width': '120px'},
            continuous_update=False)

        self.w_layers.observe(self._on_change_layers, names=['value'])

        # where you're going to add/remove layers
        # varying layers is a flag to say if you're currently in the process
        # of adding/removing layers
        self._varying_layers = False
        self._location = None
        self.ok_button = None
        self.cancel_button = None

        # associated with ReflectModel
        p = reflect_model.scale
        self.w_scale = widgets.FloatText(value=p.value,
                                         description='scale', step=0.01,
                                         style={'description_width': '120px'})
        self.c_scale = widgets.Checkbox(value=p.vary)
        self.scale_low_limit = widgets.FloatText(value=p.bounds.lb, step=0.01)
        self.scale_hi_limit = widgets.FloatText(value=p.bounds.ub, step=0.01)

        p = reflect_model.bkg
        self.w_bkg = widgets.FloatText(value=p.value,
                                       description='background', step=1e-7,
                                       style={'description_width': '120px'})
        self.c_bkg = widgets.Checkbox(value=p.vary)
        self.bkg_low_limit = widgets.FloatText(value=p.bounds.lb, step=1e-8)
        self.bkg_hi_limit = widgets.FloatText(value=p.bounds.ub, step=1e-7)

        p = reflect_model.dq
        self.w_dq = widgets.BoundedFloatText(value=p.value,
                                             description='dq/q', step=0.1,
                                             min=0, max=20.)
        self.c_dq = widgets.Checkbox(value=p.vary)
        self.dq_low_limit = widgets.BoundedFloatText(value=p.bounds.lb,
                                                     min=0, max=20,
                                                     step=0.1)
        self.dq_hi_limit = widgets.BoundedFloatText(value=p.bounds.ub,
                                                    min=0, max=20,
                                                    step=0.1)

        self.c_scale.style.description_width = '0px'
        self.c_bkg.style.description_width = '0px'
        self.c_dq.style.description_width = '0px'
        self.do_fit_button = widgets.Button(description='Do Fit')
        self.to_code_button = widgets.Button(description='To code')

        widget_list = [self.w_scale, self.c_scale, self.w_bkg,
                       self.c_bkg, self.w_dq, self.c_dq]

        limits_widgets_list = [self.scale_low_limit, self.scale_hi_limit,
                               self.bkg_low_limit, self.bkg_hi_limit,
                               self.dq_low_limit, self.dq_hi_limit]

        for widget in widget_list:
            widget.observe(self._on_model_params_modified, names=['value'])

        for widget in limits_widgets_list:
            widget.observe(self._on_model_limits_modified, names=['value'])

        # button to create default limits
        self.default_limits_button = widgets.Button(
            description='Set default limits')
        self.default_limits_button.on_click(self.default_limits)

        # widgets for easy model change: a slider that is linked to whichever
        # value widget was last edited, flanked by editable min/max boxes
        self.model_slider = widgets.FloatSlider()
        self.model_slider.layout = widgets.Layout(width='100%')
        self.model_slider_link = None
        self.model_slider_min = widgets.FloatText()
        self.model_slider_min.layout = widgets.Layout(width='10%')
        self.model_slider_max = widgets.FloatText()
        self.model_slider_max.layout = widgets.Layout(width='10%')
        self.model_slider_max.observe(self._on_slider_limits_modified,
                                      names=['value'])
        self.model_slider_min.observe(self._on_slider_limits_modified,
                                      names=['value'])

        self.num_varying = len(self.model.parameters.varying_parameters())

        self._link_param_widgets()

    def _on_model_params_modified(self, change):
        """
        Called when a ReflectModel (scale, bkg, dq) value or vary-checkbox
        widget is changed.

        A value change is copied into the parameter, the main slider is
        linked to that value widget and `view_changed` is updated. A checkbox
        change sets `vary` on the parameter, recounts the varying parameters
        and requests a full redraw through `view_redraw`.
        """
        d = self.param_widgets_link
        owner = change['owner']

        for par in [self.model.scale, self.model.bkg, self.model.dq]:
            wids = d[id(par)]
            if owner not in wids:
                continue

            loc = wids.index(owner)
            if loc == 0:
                par.value = wids[0].value

                # this captures when the user starts modifying a different
                # parameter
                self._possibly_link_slider(owner)

                self.view_changed = time.time()
            elif loc == 1:
                # you are changing the number of varying parameters
                par.vary = wids[1].value

                # set the number of varying parameters, then rebuild the limit
                # widgets by requesting a redraw of the box.
                # The slider is deliberately not linked here: the owner is a
                # Checkbox, and linking the float slider to a boolean value
                # would push floats back into the Checkbox.
                self.num_varying = (
                    len(self.model.parameters.varying_parameters()))

                self.view_redraw = time.time()
            return

    def _on_model_limits_modified(self, change):
        """
        When a limit widget is changed, update corresponding limits in the
        underlying ReflectModel.
        """
        d = self.param_widgets_link
        for par in [self.model.scale, self.model.bkg, self.model.dq]:
            idx = id(par)
            wids = d[idx]

            if change['owner'] in wids:
                loc = wids.index(change['owner'])
                if loc == 2:
                    par.bounds.lb = wids[2].value
                    break
                elif loc == 3:
                    par.bounds.ub = wids[3].value
                    break

    def default_limits(self, change):
        """
        Makes default limits for the parameters being varied.

        Each varying parameter gets bounds
        ``[min(0, 2 * value), max(0, 2 * value)]``; the widgets are then
        refreshed. `change` is whatever `Button.on_click` passes and is unused.
        """
        varying_parameters = self.model.parameters.varying_parameters()

        for par in varying_parameters:
            par.bounds.lb = min(0, 2 * par.value)
            par.bounds.ub = max(0, 2 * par.value)

        self.refresh()

    def _on_slab_params_modified(self, change):
        """
        Called when slab parameters are varied.

        If the slab view reports a Checkbox as the widget being varied, the
        number of fitted parameters changed, so `num_varying` is recounted and
        a redraw is requested. Otherwise the slider is linked to that widget
        and `view_changed` is updated.
        """
        # this captures when the user starts modifying a different parameter
        if isinstance(change['owner'].param_being_varied, widgets.Checkbox):
            # you are changing the number of fitted parameters

            # set the number of varying parameters
            self.num_varying = len(self.model.parameters.varying_parameters())

            # need to rebuild the limit widgets, achieved by redrawing box
            self.view_redraw = time.time()
        else:
            self._possibly_link_slider(change['owner'].param_being_varied)
            self.view_changed = time.time()

    def _possibly_link_slider(self, change_owner):
        """
        When a ReflectModel value is changed link a slider widget to the
        parameter that is being varied.

        Does nothing if `change_owner` is already the linked widget. Otherwise
        the slider range is reset to ``[min(0, 2 * value), max(0, 2 * value)]``
        and the slider's value is linked to ``change_owner.value``.
        """
        if change_owner is self.last_selected_param:
            return

        self.last_selected_param = change_owner
        if self.model_slider_link is not None:
            self.model_slider_link.unlink()
            self.model_slider_link = None

        # Update the slider range *before* linking. Linking copies the
        # widget's value into the slider, and the bounded slider clamps it to
        # its current [min, max]. If the old range were still in place, the
        # slider could show a clamped value that disagrees with the parameter.
        value = change_owner.value
        self.model_slider_max.value = max(0, 2. * value)
        self.model_slider_min.value = min(0, 2. * value)

        self.model_slider_link = widgets.link(
            (change_owner, 'value'),
            (self.model_slider, 'value'))

    def _on_slider_limits_modified(self, change):
        """
        Callback when adjusting the min, max widgets for the main slider
        widget. The slider step is set to 1/200 of the new range.
        """
        self.model_slider.max = self.model_slider_max.value
        self.model_slider.min = self.model_slider_min.value
        self.model_slider.step = (self.model_slider.max -
                                  self.model_slider.min) / 200.

    def _on_change_layers(self, change):
        """
        Called when the number-of-layers widget changes.

        Builds a location selector plus OK/Cancel buttons so the user can
        choose where layers are inserted (count increased) or removed (count
        decreased). Disables the layer, fit and to-code widgets until OK or
        Cancel is pressed, and requests a redraw so the new widgets appear.
        """
        self.ok_button = widgets.Button(description="OK")
        if change['new'] > change['old']:
            description = 'Insert before which layer?'
            min_loc = 1
            max_loc = len(self.model.structure) - 2 + 1
            self.ok_button.on_click(self._increase_layers)
        elif change['new'] < change['old']:
            min_loc = 1
            max_loc = (len(self.model.structure) - 2 -
                       (change['old'] - change['new']) + 1)
            description = 'Remove from which layer?'
            self.ok_button.on_click(self._decrease_layers)
        else:
            return
        self._varying_layers = True
        self.w_layers.disabled = True
        self.do_fit_button.disabled = True
        self.to_code_button.disabled = True

        self._location = widgets.BoundedIntText(
            value=min_loc,
            description=description,
            min=min_loc, max=max_loc,
            style={'description_width': 'initial'})
        self.cancel_button = widgets.Button(description="Cancel")
        self.cancel_button.on_click(self._cancel_layers)
        self.view_redraw = time.time()

    def _increase_layers(self, b):
        """
        OK-button callback when adding layers.

        Inserts ``w_layers.value - (len(structure) - 2)`` new ``Slab(0, 0, 3)``
        objects (and matching SlabViews) at the selected location, calls
        `rename_params` on the structure, recounts the varying parameters and
        requests a redraw.
        """
        self.w_layers.disabled = False
        self.do_fit_button.disabled = False
        self.to_code_button.disabled = False

        how_many = self.w_layers.value - (len(self.model.structure) - 2)
        loc = self._location.value

        for i in range(how_many):
            slab = Slab(0, 0, 3)
            slab.thick.bounds = (0, 2 * slab.thick.value)
            slab.sld.real.bounds = (0, 2 * slab.sld.real.value)
            slab.sld.imag.bounds = (0, 2 * slab.sld.imag.value)
            slab.rough.bounds = (0, 2 * slab.rough.value)

            slab_view = SlabView(slab)
            self.model.structure.insert(loc, slab)
            self.structure_view.slab_views.insert(loc, slab_view)
            # listen only to `view_changed`, exactly as in __init__
            slab_view.observe(self._on_slab_params_modified,
                              names=['view_changed'])

        rename_params(self.model.structure)
        self._varying_layers = False

        # set the number of varying parameters
        self.num_varying = len(self.model.parameters.varying_parameters())

        self.view_redraw = time.time()

    def _decrease_layers(self, b):
        """
        OK-button callback when removing layers.

        Pops ``len(structure) - 2 - w_layers.value`` slabs (and their views,
        whose observers are removed) starting at the selected location, calls
        `rename_params` on the structure, recounts the varying parameters and
        requests a redraw.
        """
        self.w_layers.disabled = False
        self.do_fit_button.disabled = False
        self.to_code_button.disabled = False

        loc = self._location.value
        how_many = len(self.model.structure) - 2 - self.w_layers.value
        for i in range(how_many):
            self.model.structure.pop(loc)
            slab_view = self.structure_view.slab_views.pop(loc)
            slab_view.unobserve_all()

        rename_params(self.model.structure)
        self._varying_layers = False

        # set the number of varying parameters
        self.num_varying = len(self.model.parameters.varying_parameters())

        self.view_redraw = time.time()

    def _link_param_widgets(self):
        """
        Creates a dictionary ``{id(parameter): widgets_tuple}`` for the
        ReflectModel parameters (scale, bkg, dq).

        Each tuple is ``(value, vary_checkbox, lower_limit, upper_limit)``.
        """
        self.param_widgets_link = {}
        d = self.param_widgets_link
        model = self.model

        d[id(model.scale)] = (self.w_scale, self.c_scale,
                              self.scale_low_limit, self.scale_hi_limit)
        d[id(model.bkg)] = (self.w_bkg, self.c_bkg,
                            self.bkg_low_limit, self.bkg_hi_limit)
        d[id(model.dq)] = (self.w_dq, self.c_dq,
                           self.dq_low_limit, self.dq_hi_limit)

    def _cancel_layers(self, b):
        """
        Cancel-button callback: restore the layer count widget to the
        structure's current number of layers and re-enable the controls.
        """
        # disable the change layers widget to prevent recursion
        self.w_layers.unobserve(self._on_change_layers, names='value')
        self.w_layers.value = len(self.model.structure) - 2
        self.w_layers.observe(self._on_change_layers, names='value')
        self.w_layers.disabled = False
        self.do_fit_button.disabled = False
        self.to_code_button.disabled = False

        self._varying_layers = False
        self.view_redraw = time.time()

    def refresh(self):
        """
        Updates the widget values from the underlying `ReflectModel`.
        """
        for par in [self.model.scale, self.model.bkg, self.model.dq]:
            wid = self.param_widgets_link[id(par)]
            wid[0].value = par.value
            wid[1].value = par.vary
            wid[2].value = par.bounds.lb
            wid[3].value = par.bounds.ub

        slab_views = self.structure_view.slab_views

        for slab_view in slab_views:
            slab_view.refresh()

    @property
    def model_box(self):
        """
        `ipywidgets.Vbox` displaying model relevant widgets.

        While layers are being added/removed, the location selector and the
        OK/Cancel buttons are included as an extra row.
        """
        output = [self.w_layers,
                  widgets.HBox([self.w_scale, self.c_scale,
                                self.w_dq, self.c_dq]),
                  widgets.HBox([self.w_bkg, self.c_bkg]),
                  self.structure_view.box,
                  widgets.HBox([self.model_slider_min,
                                self.model_slider,
                                self.model_slider_max])]

        if self._varying_layers:
            output.append(widgets.HBox([self._location,
                                        self.ok_button,
                                        self.cancel_button]))

        output.append(widgets.HBox([self.do_fit_button, self.to_code_button]))

        return widgets.VBox(output)

    @property
    def limits_box(self):
        """
        `ipywidgets.VBox` with the default-limits button followed by one row
        per varying parameter: a disabled name box, then the lower-limit,
        value and upper-limit widgets.
        """
        varying_pars = self.model.parameters.varying_parameters()
        hboxes = [self.default_limits_button]

        # merge the model-level and every slab's {id(param): widgets} maps
        d = {}
        d.update(self.param_widgets_link)

        slab_views = self.structure_view.slab_views
        for slab_view in slab_views:
            d.update(slab_view.param_widgets_link)

        for par in varying_pars:
            name = widgets.Text(par.name)
            name.disabled = True

            val, _, ll, ul = d[id(par)]

            hbox = widgets.HBox([name, ll, val, ul])
            hboxes.append(hbox)

        return widgets.VBox(hboxes)
Exemplo n.º 7
0
 class MyClass(tl.HasTraits):
     """Example HasTraits class with a single ``TupleTrait`` attribute ``t``.

     ``t`` is declared with ``trait=tl.Int()``; presumably every tuple
     element must validate as an int -- confirm against ``TupleTrait``'s
     definition (not shown here).
     """
     t = TupleTrait(trait=tl.Int())
Exemplo n.º 8
0
class Style(tl.HasTraits):
    """Display style for a data node: name, units, color limits and colormap.

    Attributes
    ----------
    name : str
        data name
    units : str or None
        data units
    clim : list
        [low, high], color map limits
    colormap : str
        matplotlib colormap name
    cmap : matplotlib.colors.Colormap
        matplotlib colormap property (read-only; derived from ``colormap``,
        else ``enumeration_colors``, else "viridis")
    enumeration_colors : dict
        data colors (replaces colormap/cmap)
    enumeration_legend : dict
        data legend, should correspond with enumeration_colors
    default_enumeration_legend : str
        legend entry used for integer values missing from enumeration_legend
    default_enumeration_color : object
        color used for integer values missing from enumeration_colors
    """
    def __init__(self, node=None, *args, **kwargs):
        # Seed name/units from the node *before* the HasTraits constructor
        # runs, so explicit ``name=`` / ``units=`` keyword arguments win.
        if node:
            self.name = node.__class__.__name__
            self.units = node.units
        super(Style, self).__init__(*args, **kwargs)

    name = tl.Unicode()
    units = tl.Unicode(allow_none=True, default_value="")
    clim = tl.List(default_value=[None, None])
    colormap = tl.Unicode(allow_none=True, default_value=None)
    enumeration_legend = tl.Dict(key_trait=tl.Int(),
                                 value_trait=tl.Unicode(),
                                 default_value=None,
                                 allow_none=True)
    enumeration_colors = tl.Dict(key_trait=tl.Int(),
                                 default_value=None,
                                 allow_none=True)
    default_enumeration_legend = tl.Unicode(
        default_value=DEFAULT_ENUMERATION_LEGEND)
    default_enumeration_color = tl.Any(default_value=DEFAULT_ENUMERATION_COLOR)

    @tl.validate("colormap")
    def _validate_colormap(self, d):
        # Look the name up so an unknown colormap name fails here.
        # NOTE(review): matplotlib.cm.get_cmap is deprecated since matplotlib
        # 3.7 and removed in 3.9 -- confirm the supported matplotlib version.
        if isinstance(d["value"], six.string_types):
            matplotlib.cm.get_cmap(d["value"])
        if d["value"] and self.enumeration_colors:
            raise TypeError(
                "Style can have a colormap or enumeration_colors, but not both"
            )
        return d["value"]

    @tl.validate("enumeration_colors")
    def _validate_enumeration_colors(self, d):
        enum_colors = d["value"]
        if enum_colors and self.colormap:
            raise TypeError(
                "Style can have a colormap or enumeration_colors, but not both"
            )
        return enum_colors

    @tl.validate("enumeration_legend")
    def _validate_enumeration_legend(self, d):
        # validate against enumeration_colors: the legend needs colors and
        # must have exactly the same keys
        enum_legend = d["value"]
        if not self.enumeration_colors:
            raise TypeError(
                "Style enumeration_legend requires enumeration_colors")
        if set(enum_legend) != set(self.enumeration_colors):
            raise ValueError(
                "Style enumeration_legend keys must match enumeration_colors keys"
            )
        return enum_legend

    @property
    def full_enumeration_colors(self):
        """ Convert enumeration_colors into a tuple suitable for matplotlib ListedColormap.

        One entry per integer from the smallest to the largest key; gaps are
        filled with ``default_enumeration_color``.
        """
        return tuple([
            self.enumeration_colors.get(value, self.default_enumeration_color)
            for value in range(min(self.enumeration_colors),
                               max(self.enumeration_colors) + 1)
        ])

    @property
    def full_enumeration_legend(self):
        """ Convert enumeration_legend into a tuple suitable for matplotlib.

        One entry per integer from the smallest to the largest key; gaps are
        filled with ``default_enumeration_legend``.
        """
        return tuple([
            self.enumeration_legend.get(value, self.default_enumeration_legend)
            for value in range(min(self.enumeration_legend),
                               max(self.enumeration_legend) + 1)
        ])

    @property
    def cmap(self):
        """Colormap: ``colormap`` if set, else a ListedColormap built from
        ``enumeration_colors``, else matplotlib's "viridis"."""
        if self.colormap:
            return matplotlib.cm.get_cmap(self.colormap)
        elif self.enumeration_colors:
            return ListedColormap(self.full_enumeration_colors)
        else:
            return matplotlib.cm.get_cmap("viridis")

    @property
    def json(self):
        """JSON-serialized style definition

        The `json` can be used to create new styles.

        See Also
        ----------
        from_json
        """

        return json.dumps(self.definition,
                          separators=(",", ":"),
                          cls=JSONEncoder)

    @staticmethod
    def _collect_definition(source):
        """Build the ordered definition mapping from attributes of *source*.

        Only entries that are set (truthy) or differ from their defaults are
        included. *source* is either a Style instance or the Style class.
        """
        d = OrderedDict()
        if source.name:
            d["name"] = source.name
        if source.units:
            d["units"] = source.units
        if source.colormap:
            d["colormap"] = source.colormap
        if source.enumeration_legend:
            d["enumeration_legend"] = source.enumeration_legend
        if source.enumeration_colors:
            d["enumeration_colors"] = source.enumeration_colors
        if source.default_enumeration_legend != DEFAULT_ENUMERATION_LEGEND:
            d["default_enumeration_legend"] = source.default_enumeration_legend
        if source.default_enumeration_color != DEFAULT_ENUMERATION_COLOR:
            d["default_enumeration_color"] = source.default_enumeration_color
        if source.clim != [None, None]:
            d["clim"] = source.clim
        return d

    @classmethod
    def get_style_ui(cls):
        """
        Attempting to expose style units to get_ui_spec(). This will grab defaults in general.
        BUT this will not set defaults for each particular node.

        NOTE(review): this is a classmethod, so ``cls.name`` etc. resolve to
        the traitlets trait descriptors, not to values. The returned dict
        therefore maps every key to its trait object -- confirm this is what
        get_ui_spec() expects.
        """
        return cls._collect_definition(cls)

    @property
    def definition(self):
        """Ordered dict of the style settings that are set or non-default."""
        return self._collect_definition(self)

    @classmethod
    def from_definition(cls, d):
        """Create a Style from a definition mapping (see ``definition``).

        The keys of ``enumeration_colors`` / ``enumeration_legend`` are
        converted to ``int``, because JSON object keys are always strings.
        The caller's mapping is not modified.
        """
        d = dict(d)  # copy: do not mutate the caller's definition
        for key in ("enumeration_colors", "enumeration_legend"):
            if d.get(key) is not None:
                d[key] = {int(k): v for k, v in d[key].items()}
        return cls(**d)

    @classmethod
    def from_json(cls, s):
        """Create podpac Style from a style JSON definition.

        Parameters
        -----------
        s : str
            JSON definition

        Returns
        --------
        Style
            podpac Style object
        """

        d = json.loads(s)
        return cls.from_definition(d)

    def __eq__(self, other):
        # Two styles are equal when their JSON serializations match.
        if not isinstance(other, Style):
            return False

        return self.json == other.json