예제 #1
0
    def __init__(self, data, layer_artist_container=None):
        """
        Initialise the client and its image-display bookkeeping.

        :param data: forwarded unchanged to ``VizClient.__init__``

        :param layer_artist_container:
           Container that holds the layer artists. A new
           ``LayerArtistContainer`` is created if not provided
        """
        VizClient.__init__(self, data)

        if layer_artist_container is None:
            layer_artist_container = LayerArtistContainer()
        self.artists = layer_artist_container

        # Slice through the ND cube. For example, ('y', 'x', 2) means the
        # current data slice is [:, :, 2] and axis=0 is vertical on the plot.
        self._slice = None

        # Recipe for extracting the downsampled/cropped 2D image to plot:
        # (ComponentID, slice, slice, ...)
        self._view = None

        # The cropped/downsampled image itself, i.e.
        # self._image == self.display_data[self._view]
        self._image = None

        # When set, this is rendered in place of self._image
        self._override_image = None

        # Cache mapping attributes -> normalization settings
        self._norm_cache = {}
예제 #2
0
파일: client.py 프로젝트: rguter/glue
    def __init__(self,
                 data=None,
                 figure=None,
                 axes=None,
                 layer_artist_container=None):
        """
        Build a ScatterClient.

        :param data: :class:`~glue.core.data.DataCollection` to use

        :param figure:
           The matplotlib figure instance to draw to. One is created if
           not provided

        :param axes:
           The matplotlib axes instance to use. Created if necessary

        :param layer_artist_container:
           Container that holds the layer artists. A new
           ``LayerArtistContainer`` is created if not provided
        """
        Client.__init__(self, data=data)

        # init_mpl hands back a (figure, axes) pair; only the axes are kept
        _, mpl_axes = init_mpl(figure, axes)

        if layer_artist_container is None:
            layer_artist_container = LayerArtistContainer()
        self.artists = layer_artist_container
        self.axes = mpl_axes

        self._layer_updated = False  # debugging
        # Whether an x / y attribute has been assigned yet
        self._xset = self._yset = False

        self._connect()
        self._set_limits()
예제 #3
0
    def __init__(self, data=None, figure=None, axes=None,
                 layer_artist_container=None, axes_factory=None):
        """
        Set up the matplotlib axes and the layer artist container.

        :param data: passed to the parent class initializer
        :param figure: passed to ``init_mpl``
        :param axes: passed to ``init_mpl``
        :param layer_artist_container:
           Container that holds the layer artists. A new
           ``LayerArtistContainer`` is created if not provided
        :param axes_factory:
           Passed to ``init_mpl``. Defaults to ``self.create_axes``
        """
        super(GenericMplClient, self).__init__(data=data)

        factory = self.create_axes if axes_factory is None else axes_factory
        figure, self.axes = init_mpl(figure, axes, axes_factory=factory)

        if layer_artist_container is None:
            layer_artist_container = LayerArtistContainer()
        self.artists = layer_artist_container

        self._connect()
예제 #4
0
파일: client.py 프로젝트: bmorris3/glue
    def __init__(self, data=None, figure=None, axes=None,
                 layer_artist_container=None):
        """
        Build a ScatterClient.

        :param data: :class:`~glue.core.data.DataCollection` to use

        :param figure:
           The matplotlib figure instance to draw to. One is created if
           not provided

        :param axes:
           The matplotlib axes instance to use. Created if necessary

        :param layer_artist_container:
           Container that holds the layer artists. A new
           ``LayerArtistContainer`` is created if not provided
        """
        Client.__init__(self, data=data)

        # init_mpl hands back a (figure, axes) pair; only the axes are kept
        _, mpl_axes = init_mpl(figure, axes)

        if layer_artist_container is None:
            layer_artist_container = LayerArtistContainer()
        self.artists = layer_artist_container
        self.axes = mpl_axes

        self._layer_updated = False  # debugging
        # Whether an x / y attribute has been assigned yet
        self._xset = self._yset = False

        self._connect()
        self._set_limits()
예제 #5
0
파일: client.py 프로젝트: astrofrog/glue
    def __init__(self, data, layer_artist_container=None):
        """
        Initialise the client and its image-display bookkeeping.

        :param data: forwarded unchanged to ``VizClient.__init__``

        :param layer_artist_container:
           Container that holds the layer artists. A new
           ``LayerArtistContainer`` is created if not provided
        """
        VizClient.__init__(self, data)

        if layer_artist_container is None:
            layer_artist_container = LayerArtistContainer()
        self.artists = layer_artist_container

        # Slice through the ND cube. For example, ('y', 'x', 2) means the
        # current data slice is [:, :, 2] and axis=0 is vertical on the plot.
        self._slice = None

        # Recipe for extracting the downsampled/cropped 2D image to plot:
        # (ComponentID, slice, slice, ...)
        self._view = None

        # The cropped/downsampled image itself, i.e.
        # self._image == self.display_data[self._view]
        self._image = None

        # When set, this is rendered in place of self._image
        self._override_image = None

        # Cache mapping attributes -> normalization settings
        self._norm_cache = {}
예제 #6
0
    def __init__(self, data, figure, layer_artist_container=None):
        """
        Create a new HistogramClient object

        :param data: passed to the parent class initializer

        :param figure:
           Passed to ``init_mpl`` as the figure to draw to

        :param layer_artist_container:
           Container that holds the layer artists. A new
           ``LayerArtistContainer`` is created only if this is None
        """
        super(HistogramClient, self).__init__(data)

        # Compare against None explicitly. A caller-supplied container that
        # is merely empty may be falsy (if it defines __len__), and an `or`
        # fallback would then silently swap it for a private one.
        if layer_artist_container is None:
            layer_artist_container = LayerArtistContainer()
        self._artists = layer_artist_container

        self._figure, self._axes = init_mpl(figure=figure, axes=None)
        self._component = None
        self._saved_nbins = None
        self._xlim_cache = {}
        self._xlog_cache = {}
        self._sync_enabled = True
        self._xlog_curr = False
예제 #7
0
파일: viz_client.py 프로젝트: bmorris3/glue
    def __init__(self, data=None, figure=None, axes=None,
                 layer_artist_container=None, axes_factory=None):
        """
        Set up the matplotlib axes and the layer artist container.

        :param data: passed to the parent class initializer
        :param figure: passed to ``init_mpl``
        :param axes: passed to ``init_mpl``
        :param layer_artist_container:
           Container that holds the layer artists. A new
           ``LayerArtistContainer`` is created if not provided
        :param axes_factory:
           Passed to ``init_mpl``. Defaults to ``self.create_axes``
        """
        super(GenericMplClient, self).__init__(data=data)

        factory = self.create_axes if axes_factory is None else axes_factory
        figure, self.axes = init_mpl(figure, axes, axes_factory=factory)

        if layer_artist_container is None:
            layer_artist_container = LayerArtistContainer()
        self.artists = layer_artist_container

        self._connect()
예제 #8
0
class GenericMplClient(Client):

    """
    This client base class handles the logic of adding, removing,
    and updating layers.

    Subsets are auto-added and removed with datasets.
    New subsets are auto-added iff the data has already been added

    Subclasses must implement :meth:`new_layer_artist`, :meth:`apply_roi`
    and :meth:`_update_layer`.
    """

    def __init__(self, data=None, figure=None, axes=None,
                 layer_artist_container=None, axes_factory=None):
        """
        :param data: passed to the parent class initializer
        :param figure: passed to ``init_mpl``
        :param axes: passed to ``init_mpl``
        :param layer_artist_container:
           Container that holds the layer artists. A new
           ``LayerArtistContainer`` is created if not provided
        :param axes_factory:
           Passed to ``init_mpl``. Defaults to :meth:`create_axes`
        """
        super(GenericMplClient, self).__init__(data=data)
        if axes_factory is None:
            axes_factory = self.create_axes
        figure, self.axes = init_mpl(figure, axes, axes_factory=axes_factory)
        self.artists = layer_artist_container
        if self.artists is None:
            self.artists = LayerArtistContainer()

        self._connect()

    def create_axes(self, figure):
        """Default axes factory: a single subplot filling ``figure``."""
        return figure.add_subplot(1, 1, 1)

    def _connect(self):
        """Hook called at the end of ``__init__``; does nothing here."""
        pass

    @property
    def collect(self):
        # a better name
        return self.data

    def _redraw(self):
        """Redraw the canvas of the figure that owns the axes."""
        self.axes.figure.canvas.draw()

    def new_layer_artist(self, layer):
        """Create the artist for ``layer``. Must be overridden."""
        raise NotImplementedError

    def apply_roi(self, roi):
        """Must be overridden."""
        raise NotImplementedError

    def _update_layer(self, layer):
        """Refresh the artist(s) of ``layer``. Must be overridden."""
        raise NotImplementedError

    def add_layer(self, layer):
        """
        Add a new Data or Subset layer to the plot.

        Returns the created layer artist, the first existing artist if the
        layer was already present, or None if the layer's data is not part
        of the data collection.

        :param layer: The layer to add
        :type layer: :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`
        """
        if layer.data not in self.collect:
            return

        if layer in self.artists:
            return self.artists[layer][0]

        result = self.new_layer_artist(layer)
        self.artists.append(result)
        self._update_layer(layer)

        # Make sure the parent data and all of its subsets are present too.
        # Layers that already have artists return early above, which is
        # what terminates the recursion.
        self.add_layer(layer.data)
        for s in layer.data.subsets:
            self.add_layer(s)

        if layer.data is layer:  # Added Data object. Relimit view
            self.axes.autoscale_view(True, True, True)

        return result

    def remove_layer(self, layer):
        """
        Remove ``layer`` from the plot and redraw. When ``layer`` is a Data
        object its subsets are removed as well. Unknown layers are ignored.
        """
        if layer not in self.artists:
            return

        self.artists.pop(layer)
        if isinstance(layer, Data):
            for subset in layer.subsets:
                self.remove_layer(subset)

        self._redraw()

    def set_visible(self, layer, state):
        """
        Toggle a layer's visibility

        This base implementation does nothing.

        :param layer: which layer to modify
        :param state: True or False
        """

    def _update_all(self):
        """Call :meth:`_update_layer` for every layer in the container."""
        for layer in self.artists.layers:
            self._update_layer(layer)

    def __contains__(self, layer):
        return layer in self.artists

    # Hub message handling
    def _add_subset(self, message):
        self.add_layer(message.sender)

    def _remove_subset(self, message):
        self.remove_layer(message.sender)

    def _update_subset(self, message):
        self._update_layer(message.sender)

    def _update_data(self, message):
        self._update_layer(message.sender)

    def _remove_data(self, message):
        self.remove_layer(message.data)

    def register_to_hub(self, hub):
        """
        Register with ``hub`` and additionally subscribe to settings changes
        that touch the background or foreground color.
        """
        super(GenericMplClient, self).register_to_hub(hub)

        def is_appearance_settings(msg):
            return ('BACKGROUND_COLOR' in msg.settings or
                    'FOREGROUND_COLOR' in msg.settings)

        hub.subscribe(self, SettingsChangeMessage,
                      self._update_appearance_from_settings,
                      filter=is_appearance_settings)

    def _update_appearance_from_settings(self, message):
        update_appearance_from_settings(self.axes)
        self._redraw()

    def restore_layers(self, layers, context):
        """
        Re-generate plot layers from a glue-serialized list

        Each entry of ``layers`` is a dict; its ``'_type'`` key is removed
        (the dict is modified in place) and the remaining values are
        resolved through ``context.object``. Entries whose layer cannot be
        added (``add_layer`` returned None) are skipped.
        """
        for l in layers:
            l.pop('_type')
            props = dict((k, context.object(v)) for k, v in l.items())
            layer = self.add_layer(props['layer'])
            if layer is None:
                # add_layer declines layers whose data is not in the
                # collection, so there is no artist to configure
                continue
            layer.properties = props
예제 #9
0
파일: client.py 프로젝트: rguter/glue
class ScatterClient(Client):
    """
    A client class that uses matplotlib to visualize tables as scatter plots.
    """
    xmin = CallbackProperty(0)
    xmax = CallbackProperty(1)
    ymin = CallbackProperty(0)
    ymax = CallbackProperty(1)
    ylog = CallbackProperty(False)
    xlog = CallbackProperty(False)
    yflip = CallbackProperty(False)
    xflip = CallbackProperty(False)
    xatt = CallbackProperty()
    yatt = CallbackProperty()
    jitter = CallbackProperty()

    def __init__(self,
                 data=None,
                 figure=None,
                 axes=None,
                 layer_artist_container=None):
        """
        Create a new ScatterClient object

        :param data: :class:`~glue.core.data.DataCollection` to use

        :param figure:
           Which matplotlib figure instance to draw to. One will be created if
           not provided

        :param axes:
           Which matplotlib axes instance to use. Will be created if necessary

        :param layer_artist_container:
           Container that holds the layer artists. A new
           ``LayerArtistContainer`` is created if not provided
        """
        Client.__init__(self, data=data)
        figure, axes = init_mpl(figure, axes)
        self.artists = layer_artist_container
        if self.artists is None:
            self.artists = LayerArtistContainer()

        self._layer_updated = False  # debugging
        self._xset = False
        self._yset = False
        self.axes = axes

        self._connect()
        self._set_limits()

    def is_layer_present(self, layer):
        """ True if layer is plotted """
        return layer in self.artists

    def get_layer_order(self, layer):
        """If layer exists as a single artist, return its zorder.
        Otherwise, return None"""
        artists = self.artists[layer]
        if len(artists) == 1:
            return artists[0].zorder
        else:
            return None

    @property
    def layer_count(self):
        return len(self.artists)

    def _connect(self):
        """Wire the callback properties and the matplotlib draw event to
        their handlers."""
        add_callback(self, 'xlog', self._set_xlog)
        add_callback(self, 'ylog', self._set_ylog)

        add_callback(self, 'xflip', self._set_limits)
        add_callback(self, 'yflip', self._set_limits)
        add_callback(self, 'xmin', self._set_limits)
        add_callback(self, 'xmax', self._set_limits)
        add_callback(self, 'ymin', self._set_limits)
        add_callback(self, 'ymax', self._set_limits)
        add_callback(self, 'xatt', partial(self._set_xydata, 'x'))
        add_callback(self, 'yatt', partial(self._set_xydata, 'y'))
        add_callback(self, 'jitter', self._jitter)
        # Keep the properties in sync with the axes after every draw
        self.axes.figure.canvas.mpl_connect('draw_event',
                                            lambda x: self._pull_properties())

    def _set_limits(self, *args):
        """Push xmin/xmax/ymin/ymax (honoring xflip/yflip) onto the axes and
        redraw if the limits changed. Positional arguments are ignored."""

        xlim = min(self.xmin, self.xmax), max(self.xmin, self.xmax)
        if self.xflip:
            xlim = xlim[::-1]
        ylim = min(self.ymin, self.ymax), max(self.ymin, self.ymax)
        if self.yflip:
            ylim = ylim[::-1]

        xold = self.axes.get_xlim()
        yold = self.axes.get_ylim()
        self.axes.set_xlim(xlim)
        self.axes.set_ylim(ylim)
        if xlim != xold or ylim != yold:
            self._redraw()

    def plottable_attributes(self, layer, show_hidden=False):
        """Return the components of ``layer.data`` that are numeric or
        categorical.

        :param layer: layer whose data is inspected
        :param show_hidden:
           If True, consider ``data.components`` instead of
           ``data.visible_components``
        """
        data = layer.data
        comp = data.components if show_hidden else data.visible_components
        return [
            c for c in comp if data.get_component(c).numeric
            or data.get_component(c).categorical
        ]

    def add_layer(self, layer):
        """ Adds a new visual layer to a client, to display either a dataset
        or a subset. Updates both the client data structure and the
        plot.

        Returns the created layer artist, or the first existing artist if
        the layer is already present.

        :param layer: the layer to add
        :type layer: :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`

        :raises TypeError: if the layer's data is not in the data collection
        """
        if layer.data not in self.data:
            raise TypeError("Layer not in data collection")
        if layer in self.artists:
            return self.artists[layer][0]

        result = ScatterLayerArtist(layer, self.axes)
        self.artists.append(result)
        self._update_layer(layer)
        self._ensure_subsets_added(layer)
        return result

    def _ensure_subsets_added(self, layer):
        """If ``layer`` is a Data object, add each of its subsets."""
        if not isinstance(layer, Data):
            return
        for subset in layer.subsets:
            self.add_layer(subset)

    def _visible_limits(self, axis):
        """Return the min-max visible data boundaries for given axis"""
        return visible_limits(self.artists, axis)

    def _snap_xlim(self):
        """
        Reset the plotted x rng to show all the data
        """
        is_log = self.xlog
        rng = self._visible_limits(0)
        if rng is None:
            return
        rng = relim(rng[0], rng[1], is_log)
        if self.xflip:
            rng = rng[::-1]
        self.axes.set_xlim(rng)
        self._pull_properties()

    def _snap_ylim(self):
        """
        Reset the plotted y rng to show all the data
        """
        is_log = self.ylog

        rng = self._visible_limits(1)
        if rng is None:
            return
        rng = relim(rng[0], rng[1], is_log)

        if self.yflip:
            rng = rng[::-1]
        self.axes.set_ylim(rng)
        self._pull_properties()

    def snap(self):
        """Rescale axes to fit the data"""
        self._snap_xlim()
        self._snap_ylim()
        self._redraw()

    def set_visible(self, layer, state):
        """ Toggle a layer's visibility

        Layers that are not plotted are ignored.

        :param layer: which layer to modify
        :type layer: :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`

        :param state: True to show. False to hide
        :type state: boolean
        """
        if layer not in self.artists:
            return
        for a in self.artists[layer]:
            a.visible = state
        self._redraw()

    def is_visible(self, layer):
        """True if ``layer`` is plotted and any of its artists is visible."""
        if layer not in self.artists:
            return False
        return any(a.visible for a in self.artists[layer])

    def _set_xydata(self, coord, attribute, snap=True):
        """ Redefine which components get assigned to the x/y axes

        :param coord: 'x' or 'y'
           Which axis to reassign
        :param attribute:
           Which attribute of the data to use.
        :type attribute: core.data.ComponentID
        :param snap:
           If True, will rescale x/y axes to fit the data
        :type snap: bool

        :raises TypeError:
           if ``coord`` is not 'x' or 'y', or ``attribute`` is not a
           ComponentID
        """

        if coord not in ('x', 'y'):
            raise TypeError("coord must be one of x,y")
        if not isinstance(attribute, ComponentID):
            raise TypeError("attribute must be a ComponentID")

        # update coordinates of data and subsets
        if coord == 'x':
            new_add = not self._xset
            self.xatt = attribute
            self._xset = self.xatt is not None
        elif coord == 'y':
            new_add = not self._yset
            self.yatt = attribute
            self._yset = self.yatt is not None

        # update plots
        for layer in self.artists.layers:
            self._update_layer(layer)

        # The first time an axis gets an attribute, rescale the other axis
        # as well
        if coord == 'x' and snap:
            self._snap_xlim()
            if new_add:
                self._snap_ylim()
        elif coord == 'y' and snap:
            self._snap_ylim()
            if new_add:
                self._snap_xlim()

        self._update_axis_labels()
        self._pull_properties()
        self._redraw()

    def _process_categorical_roi(self, roi):
        """ Build an ``AndState`` from the x and y extents of a rectangular
        ROI.

        For each axis the first available data component decides the kind
        of state: a ``CategoricalROISubsetState`` for categorical
        components, a ``RangeSubsetState`` otherwise. If no component is
        available for an axis, None is used for that axis.

        :raises AssertionError: if ``roi`` is not a RectangularROI
        """

        if not isinstance(roi, RectangularROI):
            raise AssertionError

        subsets = []
        axes = [('x', roi.xmin, roi.xmax), ('y', roi.ymin, roi.ymax)]

        for coord, lo, hi in axes:
            comp = list(self._get_data_components(coord))
            if comp:
                if comp[0].categorical:
                    subset = CategoricalROISubsetState.from_range(
                        comp[0], self._get_attribute(coord), lo, hi)
                else:
                    subset = RangeSubsetState(lo, hi,
                                              self._get_attribute(coord))
            else:
                subset = None
            subsets.append(subset)
        return AndState(*subsets)

    def apply_roi(self, roi):
        # every editable subset is updated
        # using specified ROI

        for x_comp, y_comp in zip(self._get_data_components('x'),
                                  self._get_data_components('y')):
            subset_state = x_comp.subset_from_roi(self.xatt,
                                                  roi,
                                                  other_comp=y_comp,
                                                  other_att=self.yatt,
                                                  coord='x')
            mode = EditSubsetMode()
            # The first visible dataset, if any, is used as the focus
            visible = [d for d in self._data if self.is_visible(d)]
            focus = visible[0] if len(visible) > 0 else None
            mode.update(self._data, subset_state, focus_data=focus)

    def _set_xlog(self, state):
        """ Set the x axis scaling

        :param state:
            True for a logarithmic x axis, False for a linear one
        :type state: boolean
        """
        mode = 'log' if state else 'linear'
        lim = self.axes.get_xlim()
        self.axes.set_xscale(mode)

        # Rescale if switching to log with negative bounds
        if state and min(lim) <= 0:
            self._snap_xlim()

        self._redraw()

    def _set_ylog(self, state):
        """ Set the y axis scaling

        :param state:
            True for a logarithmic y axis, False for a linear one
        :type state: boolean
        """
        mode = 'log' if state else 'linear'
        lim = self.axes.get_ylim()
        self.axes.set_yscale(mode)
        # Rescale if switching to log with negative bounds
        if state and min(lim) <= 0:
            self._snap_ylim()

        self._redraw()

    def _remove_data(self, message):
        """Process DataCollectionDeleteMessage"""
        for s in message.data.subsets:
            self.delete_layer(s)
        self.delete_layer(message.data)

    def _remove_subset(self, message):
        self.delete_layer(message.subset)

    def delete_layer(self, layer):
        """Remove ``layer`` from the plot and redraw. Layers that are not
        plotted are ignored."""
        if layer not in self.artists:
            return
        self.artists.pop(layer)
        self._redraw()
        assert not self.is_layer_present(layer)

    def _update_data(self, message):
        data = message.sender
        self._update_layer(data)

    def _numerical_data_changed(self, message):
        """Force an update of the sending dataset and all of its subsets."""
        data = message.sender
        self._update_layer(data, force=True)
        for s in data.subsets:
            self._update_layer(s, force=True)

    def _redraw(self):
        """Redraw the canvas of the figure that owns the axes."""
        self.axes.figure.canvas.draw()

    def _jitter(self, *args):
        """Apply the current ``jitter`` setting to the x and y components
        of every dataset. Datasets that lack the attribute, and components
        that do not implement jitter, are skipped. Positional arguments are
        ignored."""

        for attribute in [self.xatt, self.yatt]:
            if attribute is not None:
                for data in self.data:
                    try:
                        comp = data.get_component(attribute)
                        comp.jitter(method=self.jitter)
                    except (IncompatibleAttribute, NotImplementedError):
                        continue

    def _update_axis_labels(self, *args):
        """Set the axis labels from xatt/yatt and refresh the ticks of each
        axis that has an attribute. Positional arguments are ignored."""
        self.axes.set_xlabel(self.xatt)
        self.axes.set_ylabel(self.yatt)
        if self.xatt is not None:
            update_ticks(self.axes, 'x', list(self._get_data_components('x')),
                         self.xlog)

        if self.yatt is not None:
            update_ticks(self.axes, 'y', list(self._get_data_components('y')),
                         self.ylog)

    def _add_subset(self, message):
        subset = message.sender
        # only add subset if data layer present
        if subset.data not in self.artists:
            return
        # Broadcasting is switched off while the layer is being added
        subset.do_broadcast(False)
        self.add_layer(subset)
        subset.do_broadcast(True)

    def add_data(self, data):
        """Add ``data`` and each of its subsets as layers. Returns the
        layer artist of ``data``."""
        result = self.add_layer(data)
        for subset in data.subsets:
            self.add_layer(subset)
        return result

    @property
    def data(self):
        """The data objects in the scatter plot"""
        return list(self._data)

    def _get_attribute(self, coord):
        """Return xatt for 'x' and yatt for 'y'.

        :raises TypeError: for any other ``coord``
        """
        if coord == 'x':
            return self.xatt
        elif coord == 'y':
            return self.yatt
        else:
            raise TypeError('coord must be x or y')

    def _get_data_components(self, coord):
        """ Yield, for each dataset, the component plotted on the given
        axis ('x' or 'y'). Datasets that lack the attribute are skipped.
        """

        attribute = self._get_attribute(coord)

        for data in self._data:
            try:
                yield data.get_component(attribute)
            except IncompatibleAttribute:
                pass

    def _check_categorical(self, attribute):
        """ A simple function to figure out if an attribute is categorical.
        :param attribute: a core.Data.ComponentID
        :return: True iff the attribute represents a CategoricalComponent
        """

        for data in self._data:
            try:
                comp = data.get_component(attribute)
                if comp.categorical:
                    return True
            except IncompatibleAttribute:
                pass
        return False

    def _update_subset(self, message):
        self._update_layer(message.sender)

    def restore_layers(self, layers, context):
        """ Re-generate a list of plot layers from a glue-serialized list

        Each entry of ``layers`` is a dict; its ``'_type'`` key is removed
        (the dict is modified in place) and the remaining values are
        resolved through ``context.object``.

        :raises ValueError: if an entry is not a ScatterLayerArtist
        """
        for l in layers:
            cls = lookup_class_with_patches(l.pop('_type'))
            if cls != ScatterLayerArtist:
                raise ValueError("Scatter client cannot restore layer of type "
                                 "%s" % cls)
            props = dict((k, context.object(v)) for k, v in l.items())
            layer = self.add_layer(props['layer'])
            layer.properties = props

    def _update_layer(self, layer, force=False):
        """ Update both the style and data for the requested layer

        Does nothing until both xatt and yatt are set, or if the layer is
        not plotted.

        :param layer: the layer to update
        :param force:
           If True, call ``force_update`` on each artist instead of
           ``update``
        """
        if self.xatt is None or self.yatt is None:
            return

        if layer not in self.artists:
            return

        self._layer_updated = True
        for art in self.artists[layer]:
            art.xatt = self.xatt
            art.yatt = self.yatt
            if force:
                art.force_update()
            else:
                art.update()
        self._redraw()

    def _pull_properties(self):
        """Copy the limits, flip state and log scaling of the axes back
        into the callback properties."""
        xlim = self.axes.get_xlim()
        ylim = self.axes.get_ylim()
        xsc = self.axes.get_xscale()
        ysc = self.axes.get_yscale()

        # An axis counts as flipped when its limits are in descending order
        xflip = (xlim[1] < xlim[0])
        yflip = (ylim[1] < ylim[0])

        with delay_callback(self, 'xmin', 'xmax', 'xflip', 'xlog'):
            self.xmin = min(xlim)
            self.xmax = max(xlim)
            self.xflip = xflip
            self.xlog = (xsc == 'log')

        with delay_callback(self, 'ymin', 'ymax', 'yflip', 'ylog'):
            self.ymin = min(ylim)
            self.ymax = max(ylim)
            self.yflip = yflip
            self.ylog = (ysc == 'log')

    def _on_component_replace(self, msg):
        """Point xatt/yatt at the new component when the one they refer to
        has been replaced."""
        old = msg.old
        new = msg.new

        if self.xatt is old:
            self.xatt = new
        if self.yatt is old:
            self.yatt = new

    def register_to_hub(self, hub):
        """Register with ``hub`` and additionally subscribe to
        ComponentReplacedMessage."""
        super(ScatterClient, self).register_to_hub(hub)
        hub.subscribe(self, ComponentReplacedMessage,
                      self._on_component_replace)
예제 #10
0
파일: viz_client.py 프로젝트: bmorris3/glue
class GenericMplClient(Client):

    """
    This client base class handles the logic of adding, removing,
    and updating layers.

    Subsets are auto-added and removed with datasets.
    New subsets are auto-added iff the data has already been added

    Subclasses must implement :meth:`new_layer_artist`, :meth:`apply_roi`
    and :meth:`_update_layer`.
    """

    def __init__(self, data=None, figure=None, axes=None,
                 layer_artist_container=None, axes_factory=None):
        """
        :param data: passed to the parent class initializer
        :param figure: passed to ``init_mpl``
        :param axes: passed to ``init_mpl``
        :param layer_artist_container:
           Container that holds the layer artists. A new
           ``LayerArtistContainer`` is created if not provided
        :param axes_factory:
           Passed to ``init_mpl``. Defaults to :meth:`create_axes`
        """
        super(GenericMplClient, self).__init__(data=data)
        if axes_factory is None:
            axes_factory = self.create_axes
        figure, self.axes = init_mpl(figure, axes, axes_factory=axes_factory)
        self.artists = layer_artist_container
        if self.artists is None:
            self.artists = LayerArtistContainer()

        self._connect()

    def create_axes(self, figure):
        """Default axes factory: a single subplot filling ``figure``."""
        return figure.add_subplot(1, 1, 1)

    def _connect(self):
        """Hook called at the end of ``__init__``; does nothing here."""
        pass

    @property
    def collect(self):
        # a better name
        return self.data

    def _redraw(self):
        """Redraw the canvas of the figure that owns the axes."""
        self.axes.figure.canvas.draw()

    def new_layer_artist(self, layer):
        """Create the artist for ``layer``. Must be overridden."""
        raise NotImplementedError

    def apply_roi(self, roi):
        """Must be overridden."""
        raise NotImplementedError

    def _update_layer(self, layer):
        """Refresh the artist(s) of ``layer``. Must be overridden."""
        raise NotImplementedError

    def add_layer(self, layer):
        """
        Add a new Data or Subset layer to the plot.

        Returns the created layer artist, the first existing artist if the
        layer was already present, or None if the layer's data is not part
        of the data collection.

        :param layer: The layer to add
        :type layer: :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`
        """
        if layer.data not in self.collect:
            return

        if layer in self.artists:
            return self.artists[layer][0]

        result = self.new_layer_artist(layer)
        self.artists.append(result)
        self._update_layer(layer)

        # Make sure the parent data and all of its subsets are present too.
        # Layers that already have artists return early above, which is
        # what terminates the recursion.
        self.add_layer(layer.data)
        for s in layer.data.subsets:
            self.add_layer(s)

        if layer.data is layer:  # Added Data object. Relimit view
            self.axes.autoscale_view(True, True, True)

        return result

    def remove_layer(self, layer):
        """
        Remove ``layer`` from the plot and redraw. When ``layer`` is a Data
        object its subsets are removed as well. Unknown layers are ignored.
        """
        if layer not in self.artists:
            return

        self.artists.pop(layer)
        if isinstance(layer, Data):
            for subset in layer.subsets:
                self.remove_layer(subset)

        self._redraw()

    def set_visible(self, layer, state):
        """
        Toggle a layer's visibility

        This base implementation does nothing.

        :param layer: which layer to modify
        :param state: True or False
        """

    def _update_all(self):
        """Call :meth:`_update_layer` for every layer in the container."""
        for layer in self.artists.layers:
            self._update_layer(layer)

    def __contains__(self, layer):
        return layer in self.artists

    # Hub message handling
    def _add_subset(self, message):
        self.add_layer(message.sender)

    def _remove_subset(self, message):
        self.remove_layer(message.sender)

    def _update_subset(self, message):
        self._update_layer(message.sender)

    def _update_data(self, message):
        self._update_layer(message.sender)

    def _remove_data(self, message):
        self.remove_layer(message.data)

    def restore_layers(self, layers, context):
        """
        Re-generate plot layers from a glue-serialized list

        Each entry of ``layers`` is a dict; its ``'_type'`` key is removed
        (the dict is modified in place) and the remaining values are
        resolved through ``context.object``. Entries whose layer cannot be
        added (``add_layer`` returned None) are skipped.
        """
        for l in layers:
            l.pop('_type')
            props = dict((k, context.object(v)) for k, v in l.items())
            layer = self.add_layer(props['layer'])
            if layer is None:
                # add_layer declines layers whose data is not in the
                # collection, so there is no artist to configure
                continue
            layer.properties = props
예제 #11
0
파일: client.py 프로젝트: bmorris3/glue
class ScatterClient(Client):

    """
    A client class that uses matplotlib to visualize tables as scatter plots.

    The callback properties below hold the plot state: axis limits
    (``xmin``/``xmax``/``ymin``/``ymax``), log scaling (``xlog``/``ylog``),
    axis flipping (``xflip``/``yflip``), the plotted attributes
    (``xatt``/``yatt``) and the jitter method handed to components
    (``jitter``).
    """
    xmin = CallbackProperty(0)
    xmax = CallbackProperty(1)
    ymin = CallbackProperty(0)
    ymax = CallbackProperty(1)
    ylog = CallbackProperty(False)
    xlog = CallbackProperty(False)
    yflip = CallbackProperty(False)
    xflip = CallbackProperty(False)
    xatt = CallbackProperty()
    yatt = CallbackProperty()
    jitter = CallbackProperty()

    def __init__(self, data=None, figure=None, axes=None,
                 layer_artist_container=None):
        """
        Create a new ScatterClient object

        :param data: :class:`~glue.core.data.DataCollection` to use

        :param figure:
           Which matplotlib figure instance to draw to. One will be created if
           not provided

        :param axes:
           Which matplotlib axes instance to use. Will be created if necessary

        :param layer_artist_container:
           Container used to hold the layer artists. A new
           ``LayerArtistContainer`` is created if not provided
        """
        Client.__init__(self, data=data)
        figure, axes = init_mpl(figure, axes)
        self.artists = layer_artist_container
        if self.artists is None:
            self.artists = LayerArtistContainer()

        self._layer_updated = False  # debugging
        # whether an x / y attribute has been assigned yet
        self._xset = False
        self._yset = False
        self.axes = axes

        self._connect()
        self._set_limits()

    def is_layer_present(self, layer):
        """ True if layer is plotted """
        return layer in self.artists

    def get_layer_order(self, layer):
        """If layer exists as a single artist, return its zorder.
        Otherwise, return None"""
        artists = self.artists[layer]
        if len(artists) == 1:
            return artists[0].zorder
        else:
            return None

    @property
    def layer_count(self):
        """The length of the artist container"""
        return len(self.artists)

    def _connect(self):
        """
        Attach the callback properties to their handlers, and re-read the
        axes state into the properties on every canvas draw event
        """
        add_callback(self, 'xlog', self._set_xlog)
        add_callback(self, 'ylog', self._set_ylog)

        add_callback(self, 'xflip', self._set_limits)
        add_callback(self, 'yflip', self._set_limits)
        add_callback(self, 'xmin', self._set_limits)
        add_callback(self, 'xmax', self._set_limits)
        add_callback(self, 'ymin', self._set_limits)
        add_callback(self, 'ymax', self._set_limits)
        add_callback(self, 'xatt', partial(self._set_xydata, 'x'))
        add_callback(self, 'yatt', partial(self._set_xydata, 'y'))
        add_callback(self, 'jitter', self._jitter)
        self.axes.figure.canvas.mpl_connect('draw_event',
                                            lambda x: self._pull_properties())

    def _set_limits(self, *args):
        """
        Apply the min/max/flip properties to the axes limits

        The canvas is redrawn only if either limit differs from the value
        the axes reported beforehand. Positional arguments are ignored; they
        exist so the method can be used as a property callback.
        """

        xlim = min(self.xmin, self.xmax), max(self.xmin, self.xmax)
        if self.xflip:
            xlim = xlim[::-1]
        ylim = min(self.ymin, self.ymax), max(self.ymin, self.ymax)
        if self.yflip:
            ylim = ylim[::-1]

        xold = self.axes.get_xlim()
        yold = self.axes.get_ylim()
        self.axes.set_xlim(xlim)
        self.axes.set_ylim(ylim)
        if xlim != xold or ylim != yold:
            self._redraw()

    def plottable_attributes(self, layer, show_hidden=False):
        """
        List the component IDs of ``layer.data`` that are numeric or
        categorical

        :param layer: object whose ``data`` attribute is inspected
        :param show_hidden: if True search ``data.components``, otherwise
           only ``data.visible_components``
        """
        data = layer.data
        comp = data.components if show_hidden else data.visible_components
        return [c for c in comp if
                data.get_component(c).numeric
                or data.get_component(c).categorical]

    def add_layer(self, layer):
        """ Adds a new visual layer to a client, to display either a dataset
        or a subset. Updates both the client data structure and the
        plot.

        Returns the created layer artist, or the first existing artist if
        the layer is already present. Raises TypeError if the layer's data
        is not in this client's data.

        :param layer: the layer to add
        :type layer: :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`
        """
        if layer.data not in self.data:
            raise TypeError("Layer not in data collection")
        if layer in self.artists:
            return self.artists[layer][0]

        result = ScatterLayerArtist(layer, self.axes)
        self.artists.append(result)
        self._update_layer(layer)
        self._ensure_subsets_added(layer)
        return result

    def _ensure_subsets_added(self, layer):
        """Add a layer for each subset of ``layer`` if it is a Data object"""
        if not isinstance(layer, Data):
            return
        for subset in layer.subsets:
            self.add_layer(subset)

    def _visible_limits(self, axis):
        """Return the min-max visible data boundaries for given axis"""
        return visible_limits(self.artists, axis)

    def _snap_xlim(self):
        """
        Reset the plotted x rng to show all the data
        """
        is_log = self.xlog
        rng = self._visible_limits(0)
        if rng is None:
            return
        rng = relim(rng[0], rng[1], is_log)
        if self.xflip:
            rng = rng[::-1]
        self.axes.set_xlim(rng)
        self._pull_properties()

    def _snap_ylim(self):
        """
        Reset the plotted y rng to show all the data
        """
        # The former placeholder ``[np.infty, -np.infty]`` was never read,
        # and ``np.infty`` no longer exists in NumPy >= 2.0, so it is gone.
        is_log = self.ylog

        rng = self._visible_limits(1)
        if rng is None:
            return
        rng = relim(rng[0], rng[1], is_log)

        if self.yflip:
            rng = rng[::-1]
        self.axes.set_ylim(rng)
        self._pull_properties()

    def snap(self):
        """Rescale axes to fit the data"""
        self._snap_xlim()
        self._snap_ylim()
        self._redraw()

    def set_visible(self, layer, state):
        """ Toggle a layer's visibility

        Layers that are not plotted are ignored.

        :param layer: which layer to modify
        :type layer: :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`

        :param state: True to show. False to hide
        :type state: boolean
        """
        if layer not in self.artists:
            return
        for a in self.artists[layer]:
            a.visible = state
        self._redraw()

    def is_visible(self, layer):
        """True if the layer is plotted and any of its artists is visible"""
        if layer not in self.artists:
            return False
        return any(a.visible for a in self.artists[layer])

    def _set_xydata(self, coord, attribute, snap=True):
        """ Redefine which components get assigned to the x/y axes

        :param coord: 'x' or 'y'
           Which axis to reassign
        :param attribute:
           Which attribute of the data to use.
        :type attribute: core.data.ComponentID
        :param snap:
           If True, will rescale x/y axes to fit the data
        :type snap: bool

        Raises TypeError if ``coord`` is not 'x' or 'y', or if ``attribute``
        is not a ComponentID.
        """

        if coord not in ('x', 'y'):
            raise TypeError("coord must be one of x,y")
        if not isinstance(attribute, ComponentID):
            raise TypeError("attribute must be a ComponentID")

        # update coordinates of data and subsets
        # new_add is True the first time an attribute is set on this axis
        if coord == 'x':
            new_add = not self._xset
            self.xatt = attribute
            self._xset = self.xatt is not None
        elif coord == 'y':
            new_add = not self._yset
            self.yatt = attribute
            self._yset = self.yatt is not None

        # update plots
        for layer in self.artists.layers:
            self._update_layer(layer)

        # on the first assignment, the other axis is rescaled too
        if coord == 'x' and snap:
            self._snap_xlim()
            if new_add:
                self._snap_ylim()
        elif coord == 'y' and snap:
            self._snap_ylim()
            if new_add:
                self._snap_xlim()

        self._update_axis_labels()
        self._pull_properties()
        self._redraw()

    def apply_roi(self, roi):
        """
        Build a subset state from ``roi`` for each pair of x/y components
        and hand it to ``EditSubsetMode``

        The focus data passed along is the first visible dataset, or None
        if no dataset is visible.

        :param roi: region of interest, forwarded to the x component's
           ``subset_from_roi``
        """
        # every editable subset is updated
        # using specified ROI

        for x_comp, y_comp in zip(self._get_data_components('x'),
                                  self._get_data_components('y')):
            subset_state = x_comp.subset_from_roi(self.xatt, roi,
                                                  other_comp=y_comp,
                                                  other_att=self.yatt,
                                                  coord='x')
            mode = EditSubsetMode()
            visible = [d for d in self._data if self.is_visible(d)]
            focus = visible[0] if len(visible) > 0 else None
            mode.update(self._data, subset_state, focus_data=focus)

    def _set_xlog(self, state):
        """ Set the x axis scaling

        :param state:
            If true the x axis uses 'log' scaling, otherwise 'linear'
        :type state: bool
        """
        mode = 'log' if state else 'linear'
        lim = self.axes.get_xlim()
        self.axes.set_xscale(mode)

        # Rescale if switching to log with negative bounds
        if state and min(lim) <= 0:
            self._snap_xlim()

        self._redraw()

    def _set_ylog(self, state):
        """ Set the y axis scaling

        :param state:
            If true the y axis uses 'log' scaling, otherwise 'linear'
        :type state: bool
        """
        mode = 'log' if state else 'linear'
        lim = self.axes.get_ylim()
        self.axes.set_yscale(mode)
        # Rescale if switching to log with negative bounds
        if state and min(lim) <= 0:
            self._snap_ylim()

        self._redraw()

    def _remove_data(self, message):
        """Process DataCollectionDeleteMessage"""
        for s in message.data.subsets:
            self.delete_layer(s)
        self.delete_layer(message.data)

    def _remove_subset(self, message):
        """Hub callback: delete the layer for ``message.subset``"""
        self.delete_layer(message.subset)

    def delete_layer(self, layer):
        """Remove a layer's artists and redraw. Unknown layers are ignored"""
        if layer not in self.artists:
            return
        self.artists.pop(layer)
        self._redraw()
        assert not self.is_layer_present(layer)

    def _update_data(self, message):
        """Hub callback: refresh the layer for the sender of ``message``"""
        data = message.sender
        self._update_layer(data)

    def _numerical_data_changed(self, message):
        """Force an update of the sending data's layer and of its subsets"""
        data = message.sender
        self._update_layer(data, force=True)
        for s in data.subsets:
            self._update_layer(s, force=True)

    def _redraw(self):
        """Redraw the canvas of the figure that owns the axes"""
        self.axes.figure.canvas.draw()

    def _jitter(self, *args):
        """
        Call ``jitter(method=self.jitter)`` on the x and y components of
        every dataset

        Datasets whose component lookup raises IncompatibleAttribute, or
        whose component raises NotImplementedError, are skipped.
        """

        for attribute in [self.xatt, self.yatt]:
            if attribute is not None:
                for data in self.data:
                    try:
                        comp = data.get_component(attribute)
                        comp.jitter(method=self.jitter)
                    except (IncompatibleAttribute, NotImplementedError):
                        continue

    def _update_axis_labels(self, *args):
        """Set the axis labels to xatt/yatt and refresh the tick marks"""
        self.axes.set_xlabel(self.xatt)
        self.axes.set_ylabel(self.yatt)
        if self.xatt is not None:
            update_ticks(self.axes, 'x',
                         list(self._get_data_components('x')),
                         self.xlog)

        if self.yatt is not None:
            update_ticks(self.axes, 'y',
                         list(self._get_data_components('y')),
                         self.ylog)

    def _add_subset(self, message):
        """
        Hub callback: add a layer for the sending subset

        Broadcasting on the subset is switched off while the layer is added
        and switched back on afterwards.
        """
        subset = message.sender
        # only add subset if data layer present
        if subset.data not in self.artists:
            return
        subset.do_broadcast(False)
        self.add_layer(subset)
        subset.do_broadcast(True)

    def add_data(self, data):
        """
        Add a layer for ``data`` and for each of its subsets

        Returns the value of ``add_layer(data)``.
        """
        result = self.add_layer(data)
        for subset in data.subsets:
            self.add_layer(subset)
        return result

    @property
    def data(self):
        """The data objects in the scatter plot"""
        return list(self._data)

    def _get_attribute(self, coord):
        """Return xatt for 'x' and yatt for 'y'; raise TypeError otherwise"""
        if coord == 'x':
            return self.xatt
        elif coord == 'y':
            return self.yatt
        else:
            raise TypeError('coord must be x or y')

    def _get_data_components(self, coord):
        """ Returns the components for each dataset for x and y axes.

        This is a generator. Datasets whose lookup raises
        IncompatibleAttribute are skipped.
        """

        attribute = self._get_attribute(coord)

        for data in self._data:
            try:
                yield data.get_component(attribute)
            except IncompatibleAttribute:
                pass

    def _check_categorical(self, attribute):
        """ A simple function to figure out if an attribute is categorical.
        :param attribute: a core.Data.ComponentID
        :return: True iff the attribute represents a CategoricalComponent
        """

        for data in self._data:
            try:
                comp = data.get_component(attribute)
                if comp.categorical:
                    return True
            except IncompatibleAttribute:
                pass
        return False

    def _update_subset(self, message):
        """Hub callback: refresh the layer for the sender of ``message``"""
        self._update_layer(message.sender)

    def restore_layers(self, layers, context):
        """ Re-generate a list of plot layers from a glue-serialized list

        The ``'_type'`` entry is removed from each dict in place. Raises
        ValueError if it does not resolve to ScatterLayerArtist.
        """
        for l in layers:
            cls = lookup_class_with_patches(l.pop('_type'))
            if cls != ScatterLayerArtist:
                raise ValueError("Scatter client cannot restore layer of type "
                                 "%s" % cls)
            props = dict((k, context.object(v)) for k, v in l.items())
            layer = self.add_layer(props['layer'])
            layer.properties = props

    def _update_layer(self, layer, force=False):
        """ Update both the style and data for the requested layer

        Does nothing until both xatt and yatt are set, or if the layer is
        not plotted.

        :param force: if True call each artist's ``force_update`` instead of
           ``update``
        """
        if self.xatt is None or self.yatt is None:
            return

        if layer not in self.artists:
            return

        self._layer_updated = True
        for art in self.artists[layer]:
            art.xatt = self.xatt
            art.yatt = self.yatt
            if force:
                art.force_update()
            else:
                art.update()
        self._redraw()

    def _pull_properties(self):
        """
        Copy the axes limits and scales into the callback properties

        The properties of each axis are assigned inside one
        ``delay_callback`` block.
        """
        xlim = self.axes.get_xlim()
        ylim = self.axes.get_ylim()
        xsc = self.axes.get_xscale()
        ysc = self.axes.get_yscale()

        # an axis counts as flipped when its upper limit is below its lower
        xflip = (xlim[1] < xlim[0])
        yflip = (ylim[1] < ylim[0])

        with delay_callback(self, 'xmin', 'xmax', 'xflip', 'xlog'):
            self.xmin = min(xlim)
            self.xmax = max(xlim)
            self.xflip = xflip
            self.xlog = (xsc == 'log')

        with delay_callback(self, 'ymin', 'ymax', 'yflip', 'ylog'):
            self.ymin = min(ylim)
            self.ymax = max(ylim)
            self.yflip = yflip
            self.ylog = (ysc == 'log')

    def _on_component_replace(self, msg):
        """Swap xatt/yatt to ``msg.new`` where they are ``msg.old``"""
        old = msg.old
        new = msg.new

        if self.xatt is old:
            self.xatt = new
        if self.yatt is old:
            self.yatt = new

    def register_to_hub(self, hub):
        """Register with the hub and subscribe to ComponentReplacedMessage"""
        super(ScatterClient, self).register_to_hub(hub)
        hub.subscribe(self, ComponentReplacedMessage, self._on_component_replace)
예제 #12
0
파일: client.py 프로젝트: astrofrog/glue
class ImageClient(VizClient):

    """
    Client that displays a 2D slice of a dataset as an image.

    ``display_data`` and ``display_attribute`` hold the dataset and
    component currently shown; ``display_aspect`` is passed to the image
    artists on update. The ``_new_*_layer``, ``_update_axis_labels`` and
    ``_redraw`` hooks near the end of the class are meant to be overridden
    by subclasses.
    """

    display_data = CallbackProperty(None)
    display_attribute = CallbackProperty(None)
    display_aspect = CallbackProperty('equal')

    def __init__(self, data, layer_artist_container=None):
        """
        :param data: passed unchanged to ``VizClient.__init__``
        :param layer_artist_container: container used to hold the layer
           artists. A new ``LayerArtistContainer`` is created if None
        """

        VizClient.__init__(self, data)

        self.artists = layer_artist_container
        if self.artists is None:
            self.artists = LayerArtistContainer()

        # slice through ND cube
        # ('y', 'x', 2)
        # means current data slice is [:, :, 2], and axis=0 is vertical on plot
        self._slice = None

        # how to extract a downsampled/cropped 2D image to plot
        # (ComponentID, slice, slice, ...)
        self._view = None

        # cropped/downsampled image
        # self._image == self.display_data[self._view]
        self._image = None

        # if this is set, render this instead of self._image
        self._override_image = None

        # maps attributes -> normalization settings
        self._norm_cache = {}

    def point_details(self, x, y):
        """
        Describe the point at slice coordinate (x, y)

        Returns a dict with keys ``pix``, ``world``, ``labels`` and
        ``value``. Without display data, ``value`` is NaN and pix/world are
        both ``(x, y)``. ``value`` is None when the pixel position falls
        outside the data shape.
        """
        if self.display_data is None:
            return dict(labels=['x=%s' % x, 'y=%s' % y],
                        pix=(x, y), world=(x, y), value=np.nan)

        data = self.display_data
        pix = self._pixel_coords(x, y)
        labels = self.coordinate_labels(pix)
        world = data.coords.pixel2world(*pix[::-1])
        world = world[::-1]  # reverse for numpy convention

        # build a one-element view per axis; give up on the first axis
        # whose index is out of bounds
        view = []
        for p, s in zip(pix, data.shape):
            p = int(p)
            if not (0 <= p < s):
                value = None
                break
            view.append(slice(p, p + 1))
        else:
            if self._override_image is None:
                value = self.display_data[self.display_attribute, view]
            else:
                value = self._override_image[int(y), int(x)]

            value = value.ravel()[0]

        return dict(pix=pix, world=world, labels=labels, value=value)

    def coordinate_labels(self, pix):
        """
        Return human-readable labels for a position in pixel coords

        Parameters
        ----------
        pix : tuple of int
            Pixel coordinates of point in the data. Note that pix describes a
            position in the *data*, not necessarily the image display.

        Returns
        -------
        list
            A list of strings for each coordinate axis, of the form
            ``axis_label_name=world_coordinate_value``
        """
        data = self.display_data
        if data is None:
            return []

        world = data.coords.pixel2world(*pix[::-1])
        world = world[::-1]   # reverse for numpy convention
        labels = ['%s=%s' % (data.get_world_component_id(i).label, w)
                  for i, w in enumerate(world)]
        return labels

    @callback_property
    def slice(self):
        """
        Returns a tuple describing the current slice through the data

        The tuple has length equal to the dimensionality of the display
        data. Each entry is either:

        * 'x' if the dimension is mapped to the X image axis
        * 'y' if the dimension is mapped to the Y image axis
        * a number, indicating which fixed slice the dimension is restricted to

        An empty tuple is returned when there is no display data.
        """
        if self._slice is not None:
            return self._slice

        if self.display_data is None:
            return tuple()
        # default: the last two axes are shown, all others fixed at index 0
        ndim = self.display_data.ndim
        if ndim == 1:
            self._slice = ('x',)
        elif ndim == 2:
            self._slice = ('y', 'x')
        else:
            self._slice = (0,) * (ndim - 2) + ('y', 'x')

        return self._slice

    @slice.setter
    @defer_draw
    def slice(self, value):
        """
        Change the slice and refresh the plots

        Nothing happens if ``value`` equals the current slice or is an
        empty tuple.
        """
        if self.slice == tuple(value):
            return

        if value == tuple():
            return

        # the view is re-limited only if the x or y axis moved
        relim = value.index('x') != self._slice.index('x') or \
            value.index('y') != self._slice.index('y')

        self._slice = tuple(value)
        self._clear_override()
        self._update_axis_labels()
        self._update_data_plot(relim=relim)
        self._update_subset_plots()
        self._update_scatter_plots()
        self._redraw()

    @property
    def is_3D(self):
        """
        Returns True if the display data has 3 dimensions
        """
        if not self.display_data:
            return False
        return len(self.display_data.shape) == 3

    @property
    def slice_ind(self):
        """
        For 3D data, returns the pixel index of the current slice.
        Otherwise, returns `None`.
        """
        if self.is_3D:
            for s in self.slice:
                if s not in ['x', 'y']:
                    return s
        return None

    @property
    def image(self):
        """The value of ``self._image``"""
        return self._image

    @requires_data
    def override_image(self, image):
        """
        Temporarily override the current slice view with another image (i.e.,
        an aggregate).
        """
        self._override_image = image
        for a in self.artists[self.display_data]:
            if isinstance(a, ImageLayerBase):
                a.override_image(image)
        self._update_data_plot()
        self._redraw()

    def _clear_override(self):
        """Forget the override image, here and on each image artist"""
        self._override_image = None
        for a in self.artists[self.display_data]:
            if isinstance(a, ImageLayerBase):
                a.clear_override()

    @slice_ind.setter
    @defer_draw
    def slice_ind(self, value):
        """
        Set the fixed index of the current slice

        Raises IndexError if the display data is not 3D.
        """
        if self.is_3D:
            slc = [s if s in ['x', 'y'] else value for s in self.slice]
            self.slice = slc
            self._update_data_plot()
            self._update_subset_plots()
            self._update_scatter_plots()
            self._redraw()
        else:
            raise IndexError("Can only set slice_ind for 3D images")

    def can_image_data(self, data):
        """True if ``data`` has more than one dimension"""
        return data.ndim > 1

    def _ensure_data_present(self, data):
        """Add a layer for ``data`` unless it already has artists"""
        if data not in self.artists:
            self.add_layer(data)

    @defer_draw
    def set_data(self, data, attribute=None):
        """
        Make ``data`` the displayed dataset

        Does nothing if ``can_image_data(data)`` is false. The slice is
        reset, and ``attribute`` falls back to ``_default_component(data)``
        when falsy.
        """
        if not self.can_image_data(data):
            return

        self._ensure_data_present(data)
        self._slice = None

        attribute = attribute or _default_component(data)

        self.display_data = data
        self.display_attribute = attribute
        self._update_axis_labels()
        self._update_data_plot(relim=True)
        self._update_subset_plots()
        self._update_scatter_plots()
        self._redraw()

    def set_attribute(self, attribute):
        """
        Switch the displayed component

        The norm of the outgoing attribute is cached, and a cached norm for
        the new attribute is restored if there is one. Raises
        IncompatibleAttribute if there is no display data or it lacks
        ``attribute``.
        """
        if not self.display_data or \
                attribute not in self.display_data.component_ids():
            raise IncompatibleAttribute(
                "Attribute not in data's attributes: %s" % attribute)
        if self.display_attribute is not None:
            self._norm_cache[self.display_attribute] = self.get_norm()

        self.display_attribute = attribute

        if attribute in self._norm_cache:
            self.set_norm(norm=self._norm_cache[attribute])
        else:
            self.clear_norm()

        self._update_data_plot()
        self._redraw()

    def _redraw(self):
        """
        Re-render the screen.
        """
        pass

    @requires_data
    @defer_draw
    def set_norm(self, **kwargs):
        """Forward ``kwargs`` to ``set_norm`` on each display-data artist"""
        for a in self.artists[self.display_data]:
            a.set_norm(**kwargs)
        self._update_data_plot()
        self._redraw()

    @requires_data
    def clear_norm(self):
        """Call ``clear_norm`` on each artist of the display data"""
        for a in self.artists[self.display_data]:
            a.clear_norm()

    @requires_data
    def get_norm(self):
        """Return the ``norm`` of the first artist of the display data"""
        a = self.artists[self.display_data][0]
        return a.norm

    @requires_data
    @defer_draw
    def set_cmap(self, cmap):
        """Assign ``cmap`` to each display-data artist and redraw it"""
        for a in self.artists[self.display_data]:
            a.cmap = cmap
            a.redraw()

    def _build_view(self):
        """
        Return ``(display_attribute, ...)`` followed by the current slice,
        with the 'x' and 'y' entries replaced by full slices
        """
        att = self.display_attribute
        x, y = np.s_[:], np.s_[:]
        slc = list(self.slice)
        slc[slc.index('x')] = x
        slc[slc.index('y')] = y
        return (att,) + tuple(slc)

    @requires_data
    def _numerical_data_changed(self, message):
        """Force a refresh of the image and of the sending data's subsets"""
        data = message.sender
        self._update_data_plot(force=True)
        self._update_scatter_layer(data)

        for s in data.subsets:
            self._update_subset_single(s, force=True)

        self._redraw()

    @requires_data
    def _update_data_plot(self, relim=False, force=False):
        """
        Re-sync the main image and its subsets.

        :param relim: if True call ``relim`` first
        :param force: if True use ``force_update`` instead of ``update`` in
           the second pass over the display data's artists
        """

        if relim:
            self.relim()

        view = self._build_view()
        self._image = self.display_data[view]
        # transposed when the x axis comes before the y axis in the slice
        transpose = self.slice.index('x') < self.slice.index('y')

        self._view = view
        # iterate over a copy, since artists are removed inside the loop
        for a in list(self.artists):
            if (not isinstance(a, ScatterLayerBase) and
                    a.layer.data is not self.display_data):
                self.artists.remove(a)
            else:
                if isinstance(a, ImageLayerArtist):
                    a.update(view, transpose, aspect=self.display_aspect)
                else:
                    a.update(view, transpose)
        # the display data's artists are updated a second time here, this
        # time honouring ``force``
        for a in self.artists[self.display_data]:
            meth = a.update if not force else a.force_update
            if isinstance(a, ImageLayerArtist):
                meth(view, transpose=transpose, aspect=self.display_aspect)
            else:
                meth(view, transpose=transpose)

    def _update_subset_single(self, s, redraw=False, force=False):
        """
        Update the location and visual properties of each point in a single
        subset.

        Parameters
        ----------
        s: `~glue.core.subset.Subset`
            The subset to refresh.
        redraw : bool
            If True, call ``_redraw`` after updating.
        force : bool
            If True, use each artist's ``force_update`` instead of ``update``.
        """
        logging.getLogger(__name__).debug("update subset single: %s", s)

        if s not in self.artists:
            return

        self._update_scatter_layer(s)

        if s.data is not self.display_data:
            return

        view = self._build_view()
        transpose = self.slice.index('x') < self.slice.index('y')
        for a in self.artists[s]:
            meth = a.update if not force else a.force_update
            if isinstance(a, SubsetImageLayerArtist):
                meth(view, transpose=transpose, aspect=self.display_aspect)
            else:
                meth(view, transpose=transpose)

        if redraw:
            self._redraw()

    @property
    def _slice_ori(self):
        """
        For 3D data, the position in ``slice`` of the entry that is neither
        'x' nor 'y'. Otherwise None.
        """
        if not self.is_3D:
            return None
        for i, s in enumerate(self.slice):
            if s not in ['x', 'y']:
                return i

    @requires_data
    @defer_draw
    def apply_roi(self, roi):
        """
        Turn ``roi`` into a polygonal RoiSubsetState on the plotted pixel
        attributes and hand it to ``EditSubsetMode``
        """

        subset_state = RoiSubsetState()
        xroi, yroi = roi.to_polygon()
        x, y = self._get_plot_attributes()
        subset_state.xatt = x
        subset_state.yatt = y
        subset_state.roi = PolygonalROI(xroi, yroi)
        mode = EditSubsetMode()
        mode.update(self.data, subset_state, focus_data=self.display_data)

    def _remove_subset(self, message):
        """Hub callback: delete the layer for the sender of ``message``"""
        self.delete_layer(message.sender)

    def delete_layer(self, layer):
        """
        Clear and remove a layer's artists, then redraw

        Subsets of a Data layer are deleted too. If ``layer`` was the
        display data, another image artist's data takes over; if there is
        none, all remaining layers are deleted and the display is reset.
        Layers without artists are ignored.
        """
        if layer not in self.artists:
            return
        for a in self.artists.pop(layer):
            a.clear()

        if isinstance(layer, Data):
            for subset in layer.subsets:
                self.delete_layer(subset)

        if layer is self.display_data:
            # NOTE(review): ``.data`` is read from the artist itself here,
            # while the else-branch goes through ``artist.layer`` - confirm
            for candidate in self.artists:
                if isinstance(candidate, ImageLayerArtist):
                    self.display_data = candidate.data
                    break
            else:
                for artist in self.artists:
                    self.delete_layer(artist.layer)
                self.display_data = None
                self.display_attribute = None

        self._redraw()

    def _remove_data(self, message):
        """Hub callback: delete ``message.data`` and each of its subsets"""
        self.delete_layer(message.data)
        for s in message.data.subsets:
            self.delete_layer(s)

    def init_layer(self, layer):
        """Add ``layer``, unless it is a subset of a non-displayed dataset"""
        # only auto-add subsets if they are of the main image
        if isinstance(layer, Subset) and layer.data is not self.display_data:
            return
        self.add_layer(layer)

    def rgb_mode(self, enable=None):
        """
        Query whether RGB mode is enabled, or toggle RGB mode.

        Parameters
        ----------
        enable : bool or None
            If `True` or `False`, explicitly enable/disable RGB mode.
            If `None`, check if RGB mode is enabled

        Returns
        -------
        LayerArtist or None
            If RGB mode is enabled, returns an ``RGBImageLayerBase``.
            If ``enable`` is `False`, return the new ``ImageLayerArtist``
        """
        # XXX need to better handle case where two RGBImageLayerArtists
        #    are created

        if enable is None:
            for a in self.artists:
                if isinstance(a, RGBImageLayerBase):
                    return a
            return None

        result = None
        layer = self.display_data
        if enable:
            a = self._new_rgb_layer(layer)
            if a is None:
                return

            # all three channels start out on the displayed attribute
            a.r = a.g = a.b = self.display_attribute

            with self.artists.ignore_empty():
                self.artists.pop(layer)
                self.artists.append(a)
            result = a
        else:
            with self.artists.ignore_empty():
                for artist in list(self.artists):
                    if isinstance(artist, RGBImageLayerBase):
                        self.artists.remove(artist)
                result = self.add_layer(layer)

        self._update_data_plot()
        self._redraw()
        return result

    def _update_aspect(self):
        """Re-limit and refresh the image, then redraw"""
        self._update_data_plot(relim=True)
        self._redraw()

    def add_layer(self, layer):
        """
        Add a Data or Subset layer and return its artist

        An already-present layer returns its first artist. Data that cannot
        be imaged is scatter-plotted if 1D; otherwise a warning is logged
        and None is returned. Raises TypeError if the layer's data is not in
        ``self.data`` or the layer is neither Data nor Subset.
        """
        if layer in self.artists:
            return self.artists[layer][0]

        if layer.data not in self.data:
            raise TypeError("Data not managed by client's data collection")

        if not self.can_image_data(layer.data):
            # if data is 1D, try to scatter plot
            if len(layer.data.shape) == 1:
                return self.add_scatter_layer(layer)
            logging.getLogger(__name__).warning(
                "Cannot visualize %s. Aborting", layer.label)
            return

        if isinstance(layer, Data):
            result = self._new_image_layer(layer)
            self.artists.append(result)
            for s in layer.subsets:
                self.add_layer(s)
            self.set_data(layer)
        elif isinstance(layer, Subset):
            result = self._new_subset_image_layer(layer)
            self.artists.append(result)
            self._update_subset_single(layer)
        else:
            raise TypeError("Unrecognized layer type: %s" % type(layer))

        return result

    def add_scatter_layer(self, layer):
        """
        Add a scatter artist for ``layer`` and return it

        Returns None if the layer already has artists.
        """
        logging.getLogger(
            __name__).debug('Adding scatter layer for %s', layer)
        if layer in self.artists:
            logging.getLogger(__name__).debug('Layer already present')
            return

        result = self._new_scatter_layer(layer)
        self.artists.append(result)
        self._update_scatter_layer(layer)
        return result

    def _update_scatter_plots(self):
        """Call ``_update_scatter_layer`` for every layer in the container"""
        for layer in self.artists.layers:
            self._update_scatter_layer(layer)

    @requires_data
    def _update_scatter_layer(self, layer, force=False):
        """
        Re-sync the scatter artists of ``layer`` with the current slice

        Artists that are not ScatterLayerBase instances are skipped. For 3D
        display data, ``emphasis`` is set to a selection on the slice axis
        around ``slice_ind``; otherwise it is None.

        :param force: if True call ``force_update`` instead of ``update``
        """

        if layer not in self.artists:
            return

        xatt, yatt = self._get_plot_attributes()
        need_redraw = False

        for a in self.artists[layer]:
            if not isinstance(a, ScatterLayerBase):
                continue
            need_redraw = True
            a.xatt = xatt
            a.yatt = yatt
            if self.is_3D:
                zatt = self.display_data.get_pixel_component_id(
                    self._slice_ori)
                subset = (
                    zatt > self.slice_ind) & (zatt <= self.slice_ind + 1)
                a.emphasis = subset
            else:
                a.emphasis = None
            if force:
                a.force_update()
            else:
                a.update()
            a.redraw()

        if need_redraw:
            self._redraw()

    @requires_data
    def _get_plot_attributes(self):
        """Return the pixel component IDs plotted on the x and y axes"""
        x, y = _slice_axis(self.display_data.shape, self.slice)
        ids = self.display_data.pixel_component_ids
        return ids[x], ids[y]

    def _pixel_coords(self, x, y):
        """
        From a slice coordinate (x,y), return the full (possibly >2D) numpy
        index into the full data.

        .. note:: The inputs to this function are the reverse of numpy
                  convention (horizontal axis first, then vertical)

        Returns
        -------
        coords : list
            A copy of the current slice with its 'x' and 'y' entries
            replaced by ``x`` and ``y``
        """
        result = list(self.slice)
        result[result.index('x')] = x
        result[result.index('y')] = y
        return result

    def is_visible(self, layer):
        """True if every artist of ``layer`` is visible"""
        return all(a.visible for a in self.artists[layer])

    def set_visible(self, layer, state):
        """Set ``visible`` to ``state`` on every artist of ``layer``"""
        for a in self.artists[layer]:
            a.visible = state

    def set_slice_ori(self, ori):
        """
        Pick which axis of a 3D cube is sliced, at index 0

        Raises IndexError if the display data is not 3D, and ValueError if
        ``ori`` is not 0, 1 or 2.
        """
        if not self.is_3D:
            raise IndexError("Can only set slice_ori for 3D images")
        if ori == 0:
            self.slice = (0, 'y', 'x')
        elif ori == 1:
            self.slice = ('y', 0, 'x')
        elif ori == 2:
            self.slice = ('y', 'x', 0)
        else:
            raise ValueError("Orientation must be 0, 1, or 2")

    def restore_layers(self, layers, context):
        """
        Restore a list of glue-serialized layer dicts.

        The ``'_type'`` entry is removed from each dict in place. Every
        value except ``'stretch'`` is resolved through ``context.object``.
        Raises ValueError for a layer class that is not a scatter, RGB,
        image or subset-image layer.
        """
        for layer in layers:
            c = lookup_class_with_patches(layer.pop('_type'))
            props = dict((k, v if k == 'stretch' else context.object(v))
                         for k, v in layer.items())
            l = props['layer']
            if issubclass(c, ScatterLayerBase):
                l = self.add_scatter_layer(l)
            elif issubclass(c, RGBImageLayerBase):
                r = props.pop('r')
                g = props.pop('g')
                b = props.pop('b')
                self.display_data = l
                self.display_attribute = r
                l = self.rgb_mode(True)
                l.r = r
                l.g = g
                l.b = b
            elif issubclass(c, (ImageLayerBase, SubsetImageLayerBase)):
                if isinstance(l, Data):
                    self.set_data(l)
                l = self.add_layer(l)
            else:
                # report the unsupported class, not the layer object
                raise ValueError("Cannot restore layer of type %s" % c)
            l.properties = props

    def _on_component_replace(self, msg):
        """Swap display_attribute to ``msg.new`` if it is ``msg.old``"""
        if self.display_attribute is msg.old:
            self.display_attribute = msg.new

    def register_to_hub(self, hub):
        """Register with the hub and subscribe to ComponentReplacedMessage"""
        super(ImageClient, self).register_to_hub(hub)
        hub.subscribe(self,
                      ComponentReplacedMessage,
                      self._on_component_replace)

    # subclasses should override the following methods as appropriate
    def _new_rgb_layer(self, layer):
        """
        Construct and return an RGBImageLayerBase for the given layer

        Parameters
        ----------
        layer : :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`
            Which object to visualize
        """
        raise NotImplementedError()

    def _new_subset_image_layer(self, layer):
        """
        Construct and return a SubsetImageLayerArtist for the given layer

        Parameters
        ----------
        layer : :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`
            Which object to visualize
        """
        raise NotImplementedError()

    def _new_image_layer(self, layer):
        """
        Construct and return an ImageLayerArtist for the given layer

        Parameters
        ----------
        layer : :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`
            Which object to visualize
        """
        raise NotImplementedError()

    def _new_scatter_layer(self, layer):
        """
        Construct and return a ScatterLayerArtist for the given layer

        Parameters
        ----------
        layer : :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`
            Which object to visualize
        """
        raise NotImplementedError()

    def _update_axis_labels(self):
        """
        Sync the displays for labels on X/Y axes, because the data or slice has
        changed
        """
        raise NotImplementedError()

    def relim(self):
        """
        Reset view window to the default pan/zoom setting.
        """
        pass

    def show_crosshairs(self, x, y):
        """Hook for showing crosshairs at (x, y); does nothing here"""
        pass

    def clear_crosshairs(self):
        """Hook for removing crosshairs; does nothing here"""
        pass
예제 #13
0
class ImageClient(VizClient):

    display_data = CallbackProperty(None)
    display_attribute = CallbackProperty(None)
    display_aspect = CallbackProperty('equal')

    def __init__(self, data, layer_artist_container=None):
        """
        Initialise the client's bookkeeping state.

        Parameters
        ----------
        data
            Passed through unchanged to ``VizClient.__init__``.
        layer_artist_container : optional
            Container holding the layer artists. A fresh
            ``LayerArtistContainer`` is created when this is None.
        """
        VizClient.__init__(self, data)

        if layer_artist_container is None:
            layer_artist_container = LayerArtistContainer()
        self.artists = layer_artist_container

        # Current slice through the ND cube. For example ('y', 'x', 2) means
        # the displayed slice is [:, :, 2] and axis 0 is vertical on the plot.
        self._slice = None

        # Recipe for extracting the (downsampled/cropped) 2D image to plot:
        # (ComponentID, slice, slice, ...)
        self._view = None

        # The cropped/downsampled image itself, i.e.
        # self.display_data[self._view]
        self._image = None

        # When set, this image is rendered in place of self._image
        self._override_image = None

        # attribute -> normalization settings
        self._norm_cache = {}

    def point_details(self, x, y):
        """
        Describe the data found at image position (x, y).

        Parameters
        ----------
        x, y
            Position in the displayed slice (horizontal, then vertical).

        Returns
        -------
        dict
            With the keys ``pix``, ``world``, ``labels`` and ``value``.
            Without display data, ``pix`` and ``world`` are both ``(x, y)``
            and ``value`` is NaN. With display data, ``value`` is None when
            the position falls outside the data.
        """
        data = self.display_data
        if data is None:
            return dict(labels=['x=%s' % x, 'y=%s' % y],
                        pix=(x, y),
                        world=(x, y),
                        value=np.nan)

        pix = self._pixel_coords(x, y)
        labels = self.coordinate_labels(pix)
        # pixel2world is fed the reversed coordinates; its result is flipped
        # back to numpy ordering
        world = data.coords.pixel2world(*pix[::-1])[::-1]

        # build a one-element view around the pixel, stopping at the first
        # coordinate that lies outside the data
        view = []
        inside = True
        for coord, extent in zip(pix, data.shape):
            index = int(coord)
            if index < 0 or index >= extent:
                inside = False
                break
            view.append(slice(index, index + 1))

        if not inside:
            value = None
        elif self._override_image is not None:
            value = self._override_image[int(y), int(x)].ravel()[0]
        else:
            value = data[self.display_attribute, view].ravel()[0]

        return dict(pix=pix, world=world, labels=labels, value=value)

    def coordinate_labels(self, pix):
        """
        Build human-readable labels for a position given in pixel coords

        Parameters
        ----------
        pix : tuple of int
            Pixel coordinates of a point in the data. These refer to a
            position in the *data*, not necessarily the image display.

        Returns
        -------
        list
            One string per coordinate axis, formatted as
            ``axis_label_name=world_coordinate_value``. Empty when there is
            no display data.
        """
        data = self.display_data
        if data is None:
            return []

        # pixel2world is fed the reversed coordinates; its result is flipped
        # back to numpy ordering
        world = data.coords.pixel2world(*pix[::-1])[::-1]

        labels = []
        for axis, coordinate in enumerate(world):
            name = data.get_world_component_id(axis).label
            labels.append('%s=%s' % (name, coordinate))
        return labels

    @callback_property
    def slice(self):
        """
        The current slice through the data, as a tuple

        The tuple has one entry per dimension of the display data. Each
        entry is one of:

        * 'x' -- the dimension is mapped to the X image axis
        * 'y' -- the dimension is mapped to the Y image axis
        * a number -- the fixed index the dimension is restricted to

        When no slice has been stored yet, a default is derived from the
        dimensionality of the display data and stored. Without display
        data an empty tuple is returned and nothing is stored.
        """
        if self._slice is not None:
            return self._slice

        data = self.display_data
        if data is None:
            return ()

        ndim = data.ndim
        low_dim_defaults = {1: ('x', ), 2: ('y', 'x')}
        # every other dimensionality: leading axes fixed at index 0, the
        # last two axes shown
        fallback = (0, ) * (ndim - 2) + ('y', 'x')
        self._slice = low_dim_defaults.get(ndim, fallback)
        return self._slice

    @slice.setter
    @defer_draw
    def slice(self, value):
        """
        Change the slice through the data and refresh all dependent plots.

        Nothing happens when ``value`` equals the current slice or is
        empty. Otherwise the slice is stored, any override image is
        cleared, and axis labels, data plot, subset plots and scatter plots
        are updated before redrawing.

        Parameters
        ----------
        value : iterable
            The new slice, with the same entries as described on the
            ``slice`` getter. Any iterable is accepted; it is converted to
            a tuple exactly once.
        """
        # Normalise up front so that lists, iterators and empty sequences
        # are all handled the same way as tuples.
        value = tuple(value)

        if self.slice == value:
            return

        if not value:
            return

        # the view limits only need resetting when the axes mapped to the
        # image x/y change position
        relim = value.index('x') != self._slice.index('x') or \
            value.index('y') != self._slice.index('y')

        self._slice = value
        self._clear_override()
        self._update_axis_labels()
        self._update_data_plot(relim=relim)
        self._update_subset_plots()
        self._update_scatter_plots()
        self._redraw()

    @property
    def is_3D(self):
        """
        Whether display data is set (truthy) and its shape has exactly
        three entries
        """
        data = self.display_data
        return bool(data) and len(data.shape) == 3

    @property
    def slice_ind(self):
        """
        The first entry of the current slice that is neither 'x' nor 'y',
        for 3D data. `None` when the data is not 3D or no such entry exists.
        """
        if not self.is_3D:
            return None
        fixed = (entry for entry in self.slice if entry not in ['x', 'y'])
        return next(fixed, None)

    @property
    def image(self):
        """
        The image stored in ``self._image`` (read-only accessor)
        """
        return self._image

    @requires_data
    def override_image(self, image):
        """
        Show ``image`` (e.g. an aggregate) in place of the current slice
        view until the override is cleared.

        The image is remembered on the client and handed to every
        ``ImageLayerBase`` artist of the display data; afterwards the data
        plot is updated and the display redrawn.
        """
        self._override_image = image
        for artist in self.artists[self.display_data]:
            if not isinstance(artist, ImageLayerBase):
                continue
            artist.override_image(image)
        self._update_data_plot()
        self._redraw()

    def _clear_override(self):
        """
        Drop the override image, both on the client and on every
        ``ImageLayerBase`` artist of the display data.
        """
        self._override_image = None
        for artist in self.artists[self.display_data]:
            if not isinstance(artist, ImageLayerBase):
                continue
            artist.clear_override()

    @slice_ind.setter
    @defer_draw
    def slice_ind(self, value):
        """
        Move to slice index ``value`` of a 3D image.

        Every entry of the current slice that is not 'x' or 'y' is replaced
        by ``value``; the result is assigned to ``self.slice`` and the
        plots are refreshed.

        Raises
        ------
        IndexError
            If the display data is not 3D.
        """
        if not self.is_3D:
            raise IndexError("Can only set slice_ind for 3D images")

        self.slice = [value if entry not in ['x', 'y'] else entry
                      for entry in self.slice]
        # explicit refresh, performed regardless of what the assignment to
        # self.slice above already did
        self._update_data_plot()
        self._update_subset_plots()
        self._update_scatter_plots()
        self._redraw()

    def can_image_data(self, data):
        """
        Return True if ``data`` has more than one dimension
        (``data.ndim > 1``)
        """
        return data.ndim > 1

    def _ensure_data_present(self, data):
        """
        Add ``data`` as a layer unless it is already in ``self.artists``
        """
        if data not in self.artists:
            self.add_layer(data)

    @defer_draw
    def set_data(self, data, attribute=None):
        """
        Make ``data`` the displayed dataset and refresh all plots.

        Returns without doing anything when ``can_image_data(data)`` is
        false.

        Parameters
        ----------
        data
            The dataset to display. A layer is added for it first if it is
            not yet among the artists.
        attribute : optional
            The attribute to display. When omitted (or otherwise falsy),
            the result of ``_default_component(data)`` is used.
        """
        if not self.can_image_data(data):
            return

        self._ensure_data_present(data)
        # forget the stored slice; the ``slice`` property derives a default
        # again the next time it is read
        self._slice = None

        attribute = attribute or _default_component(data)

        self.display_data = data
        self.display_attribute = attribute
        self._update_axis_labels()
        self._update_data_plot(relim=True)
        self._update_subset_plots()
        self._update_scatter_plots()
        self._redraw()

    def set_attribute(self, attribute):
        """
        Switch the displayed attribute of the current display data.

        The normalization of the attribute shown so far is stored in the
        norm cache. If the cache holds a normalization for the new
        attribute it is applied, otherwise the normalization is cleared.
        Finally the data plot is updated and the display redrawn.

        Raises
        ------
        IncompatibleAttribute
            If there is no (truthy) display data or ``attribute`` is not
            among its component ids.
        """
        data = self.display_data
        if not data or attribute not in data.component_ids():
            raise IncompatibleAttribute(
                "Attribute not in data's attributes: %s" % attribute)

        previous = self.display_attribute
        if previous is not None:
            self._norm_cache[previous] = self.get_norm()

        self.display_attribute = attribute

        if attribute not in self._norm_cache:
            self.clear_norm()
        else:
            self.set_norm(norm=self._norm_cache[attribute])

        self._update_data_plot()
        self._redraw()

    def _redraw(self):
        """
        Re-render the screen.

        The implementation in this class is a no-op.
        """
        pass

    @requires_data
    @defer_draw
    def set_norm(self, **kwargs):
        """
        Forward the keyword arguments to ``set_norm`` of every artist of
        the display data, then update the data plot and redraw.
        """
        for artist in self.artists[self.display_data]:
            artist.set_norm(**kwargs)
        self._update_data_plot()
        self._redraw()

    @requires_data
    def clear_norm(self):
        """
        Call ``clear_norm`` on every artist of the display data.
        """
        for artist in self.artists[self.display_data]:
            artist.clear_norm()

    @requires_data
    def get_norm(self):
        """
        Return the ``norm`` of the first artist of the display data.
        """
        return self.artists[self.display_data][0].norm

    @requires_data
    @defer_draw
    def set_cmap(self, cmap):
        """
        Assign ``cmap`` to every artist of the display data, redrawing
        each artist right after the assignment.
        """
        for artist in self.artists[self.display_data]:
            artist.cmap = cmap
            artist.redraw()

    def _build_view(self):
        att = self.display_attribute
        shp = self.display_data.shape
        x, y = np.s_[:], np.s_[:]
        slc = list(self.slice)
        slc[slc.index('x')] = x
        slc[slc.index('y')] = y
        return (att, ) + tuple(slc)

    @requires_data
    def _numerical_data_changed(self, message):
        """
        Force a refresh of everything that depends on the sender of
        ``message``: the data plot, the sender's scatter layer and each of
        the sender's subsets. Redraws afterwards.
        """
        changed = message.sender
        self._update_data_plot(force=True)
        self._update_scatter_layer(changed)

        for subset in changed.subsets:
            self._update_subset_single(subset, force=True)

        self._redraw()

    @requires_data
    def _update_data_plot(self, relim=False, force=False):
        """
        Re-sync the main image and its subsets.

        Parameters
        ----------
        relim : bool
            If true, ``self.relim()`` is called before anything else.
        force : bool
            If true, the artists of the display data are refreshed through
            ``force_update`` instead of ``update`` (second loop below).
        """

        if relim:
            self.relim()

        view = self._build_view()
        self._image = self.display_data[view]
        # true when the 'x' entry precedes the 'y' entry in the slice
        transpose = self.slice.index('x') < self.slice.index('y')

        self._view = view
        # iterate over a copy: artists may be removed inside the loop
        for a in list(self.artists):
            # non-scatter artists whose layer belongs to other data are
            # dropped; everything else is updated with the new view
            if (not isinstance(a, ScatterLayerBase)
                    and a.layer.data is not self.display_data):
                self.artists.remove(a)
            else:
                if isinstance(a, ImageLayerArtist):
                    a.update(view, transpose, aspect=self.display_aspect)
                else:
                    a.update(view, transpose)
        # second pass over the display data's own artists; this is the only
        # place where ``force`` is taken into account
        for a in self.artists[self.display_data]:
            meth = a.update if not force else a.force_update
            if isinstance(a, ImageLayerArtist):
                meth(view, transpose=transpose, aspect=self.display_aspect)
            else:
                meth(view, transpose=transpose)

    def _update_subset_single(self, s, redraw=False, force=False):
        """
        Refresh the artists of a single subset.

        Nothing happens when the subset has no artists. Its scatter layer
        is always updated; the remaining artists are only updated when the
        subset belongs to the display data.

        Parameters
        ----------
        s: `~glue.core.subset.Subset`
            The subset to refresh.
        redraw : bool
            If true, redraw once the artists have been updated.
        force : bool
            If true, use the artists' ``force_update`` instead of
            ``update``.
        """
        log = logging.getLogger(__name__)
        log.debug("update subset single: %s", s)

        if s not in self.artists:
            return

        self._update_scatter_layer(s)

        if s.data is not self.display_data:
            return

        view = self._build_view()
        current = self.slice
        # true when the 'x' entry precedes the 'y' entry in the slice
        transpose = current.index('x') < current.index('y')

        for artist in self.artists[s]:
            refresh = artist.force_update if force else artist.update
            extra = {}
            if isinstance(artist, SubsetImageLayerArtist):
                extra['aspect'] = self.display_aspect
            refresh(view, transpose=transpose, **extra)

        if redraw:
            self._redraw()

    @property
    def _slice_ori(self):
        if not self.is_3D:
            return None
        for i, s in enumerate(self.slice):
            if s not in ['x', 'y']:
                return i

    @requires_data
    @defer_draw
    def apply_roi(self, roi):
        """
        Turn ``roi`` into a subset state and apply it through the edit
        subset mode.

        The ROI is converted to a polygon and attached, together with the
        current plot attributes, to a new ``RoiSubsetState``. That state is
        passed to ``EditSubsetMode().update`` with the display data as
        focus data.
        """
        xs, ys = roi.to_polygon()
        xatt, yatt = self._get_plot_attributes()

        state = RoiSubsetState()
        state.xatt = xatt
        state.yatt = yatt
        state.roi = PolygonalROI(xs, ys)

        EditSubsetMode().update(self.data, state,
                                focus_data=self.display_data)

    def _remove_subset(self, message):
        """
        Delete the layer of the object that sent ``message``
        """
        self.delete_layer(message.sender)

    def delete_layer(self, layer):
        """
        Remove ``layer`` and its artists from the client.

        Nothing happens when the layer has no artists. Otherwise its
        artists are popped from the container and cleared. For a ``Data``
        layer, the layers of all its subsets are deleted as well.

        If the deleted layer was the display data, the data of the first
        remaining ``ImageLayerArtist`` becomes the new display data. When
        there is none, every remaining layer is deleted and both the
        display data and the display attribute are reset to None.

        The display is redrawn at the end.

        Parameters
        ----------
        layer
            The layer to remove.
        """
        if layer not in self.artists:
            return
        for a in self.artists.pop(layer):
            a.clear()

        if isinstance(layer, Data):
            for subset in layer.subsets:
                self.delete_layer(subset)

        if layer is self.display_data:
            for artist in self.artists:
                if isinstance(artist, ImageLayerArtist):
                    # NOTE(review): other code in this class reaches the
                    # data through ``artist.layer.data``; confirm that
                    # artists expose ``.data`` directly.
                    self.display_data = artist.data
                    break
            else:
                # Iterate over a snapshot: the recursive call pops entries
                # from self.artists, which must not happen to a container
                # that is being iterated.
                for artist in list(self.artists):
                    self.delete_layer(artist.layer)
                self.display_data = None
                self.display_attribute = None

        self._redraw()

    def _remove_data(self, message):
        """
        Delete the layer of ``message.data`` and the layers of all of its
        subsets
        """
        self.delete_layer(message.data)
        # subsets whose layer is already gone are ignored by delete_layer
        for s in message.data.subsets:
            self.delete_layer(s)

    def init_layer(self, layer):
        """
        Add ``layer`` to the client, except when it is a subset of data
        other than the display data (only subsets of the main image are
        added automatically).
        """
        foreign_subset = (isinstance(layer, Subset)
                          and layer.data is not self.display_data)
        if not foreign_subset:
            self.add_layer(layer)

    def rgb_mode(self, enable=None):
        """
        Query or toggle RGB mode.

        Parameters
        ----------
        enable : bool or None
            `None` only queries the current state. A true value switches
            RGB mode on, a false value switches it off.

        Returns
        -------
        LayerArtist or None
            When querying: the first ``RGBImageLayerBase`` artist, or
            `None` if there is none.
            When enabling: the newly created RGB artist, or `None` if
            ``_new_rgb_layer`` returned `None` (nothing is changed then).
            When disabling: whatever ``add_layer`` returns for the display
            data.
        """
        # XXX need to better handle case where two RGBImageLayerArtists
        #    are created

        if enable is None:
            rgb_artists = (a for a in self.artists
                           if isinstance(a, RGBImageLayerBase))
            return next(rgb_artists, None)

        layer = self.display_data

        if enable:
            rgb = self._new_rgb_layer(layer)
            if rgb is None:
                return

            # all three channels start out on the displayed attribute
            rgb.r = rgb.g = rgb.b = self.display_attribute

            # swap the layer's current artists for the RGB artist
            with self.artists.ignore_empty():
                self.artists.pop(layer)
                self.artists.append(rgb)
            result = rgb
        else:
            with self.artists.ignore_empty():
                stale = [a for a in self.artists
                         if isinstance(a, RGBImageLayerBase)]
                for artist in stale:
                    self.artists.remove(artist)
                result = self.add_layer(layer)

        self._update_data_plot()
        self._redraw()
        return result

    def _update_aspect(self):
        """
        Update the data plot with the view limits reset, then redraw
        """
        self._update_data_plot(relim=True)
        self._redraw()

    def add_layer(self, layer):
        """
        Add a layer to the client and return its artist.

        * A layer that already has artists is not added again; its first
          artist is returned.
        * Data that cannot be imaged is handed to ``add_scatter_layer``
          if it is 1D; otherwise a warning is logged and None is returned.
        * For ``Data``, an image artist is created, the subsets are added
          too and the data becomes the display data.
        * For a ``Subset``, a subset image artist is created and updated.

        Raises
        ------
        TypeError
            If the layer's data is not part of the client's data
            collection, or the layer is neither ``Data`` nor ``Subset``.
        """
        if layer in self.artists:
            return self.artists[layer][0]

        data = layer.data
        if data not in self.data:
            raise TypeError("Data not managed by client's data collection")

        if not self.can_image_data(data):
            # if data is 1D, try to scatter plot
            if len(data.shape) == 1:
                return self.add_scatter_layer(layer)
            logging.getLogger(__name__).warning(
                "Cannot visualize %s. Aborting", layer.label)
            return

        if not isinstance(layer, (Data, Subset)):
            raise TypeError("Unrecognized layer type: %s" % type(layer))

        if isinstance(layer, Data):
            artist = self._new_image_layer(layer)
            self.artists.append(artist)
            for subset in layer.subsets:
                self.add_layer(subset)
            self.set_data(layer)
        else:
            artist = self._new_subset_image_layer(layer)
            self.artists.append(artist)
            self._update_subset_single(layer)

        return artist

    def add_scatter_layer(self, layer):
        """
        Create a scatter artist for ``layer`` and add it to the client.

        Parameters
        ----------
        layer
            The layer to plot.

        Returns
        -------
        The new artist, as created by ``_new_scatter_layer``, or None when
        the layer already has artists (nothing is added in that case).
        """
        logger = logging.getLogger(__name__)
        # Pass the layer as a logging argument instead of formatting it
        # eagerly: formatting is then skipped when debug output is off and
        # cannot fail for tuple-like layers.
        logger.debug('Adding scatter layer for %s', layer)
        if layer in self.artists:
            logger.debug('Layer already present')
            return

        result = self._new_scatter_layer(layer)
        self.artists.append(result)
        self._update_scatter_layer(layer)
        return result

    def _update_scatter_plots(self):
        """
        Update the scatter layer of every layer in ``self.artists.layers``
        """
        for layer in self.artists.layers:
            self._update_scatter_layer(layer)

    @requires_data
    def _update_scatter_layer(self, layer, force=False):
        """
        Refresh the scatter artists of ``layer``.

        Each ``ScatterLayerBase`` artist of the layer receives the current
        plot attributes. For 3D data its emphasis is set to the points
        whose pixel coordinate along the sliced axis is greater than the
        slice index and at most the slice index plus one; otherwise the
        emphasis is cleared. The artist is then updated and redrawn. If at
        least one artist was handled, the display is redrawn as well.

        Parameters
        ----------
        layer
            The layer whose scatter artists are refreshed. Layers without
            artists are ignored.
        force : bool
            If true, use the artists' ``force_update`` instead of
            ``update``.
        """
        if layer not in self.artists:
            return

        xatt, yatt = self._get_plot_attributes()
        touched = False

        for artist in self.artists[layer]:
            if not isinstance(artist, ScatterLayerBase):
                continue

            touched = True
            artist.xatt = xatt
            artist.yatt = yatt

            if self.is_3D:
                zatt = self.display_data.get_pixel_component_id(
                    self._slice_ori)
                index = self.slice_ind
                artist.emphasis = (zatt > index) & (zatt <= index + 1)
            else:
                artist.emphasis = None

            if force:
                artist.force_update()
            else:
                artist.update()
            artist.redraw()

        if touched:
            self._redraw()

    @requires_data
    def _get_plot_attributes(self):
        """
        Return the pixel component ids of the display data for the two
        axes that ``_slice_axis`` reports for the current slice, as an
        ``(x, y)`` pair.
        """
        data = self.display_data
        xaxis, yaxis = _slice_axis(data.shape, self.slice)
        pixel_ids = data.pixel_component_ids
        return pixel_ids[xaxis], pixel_ids[yaxis]

    def _pixel_coords(self, x, y):
        """
        From a slice coordinate (x,y), return the full (possibly >2D) numpy
        index into the full data.

        .. note:: The inputs to this function are the reverse of numpy
                  convention (horizontal axis first, then vertical)

        Returns
        -------
        coords : tuple
            Either a tuple of (x,y) or (x,y,z)
        """
        result = list(self.slice)
        result[result.index('x')] = x
        result[result.index('y')] = y
        return result

    def is_visible(self, layer):
        """
        Return True if every artist of ``layer`` is visible (which is
        also the case when the layer has no artists)
        """
        return all(a.visible for a in self.artists[layer])

    def set_visible(self, layer, state):
        """
        Set the ``visible`` attribute of every artist of ``layer`` to
        ``state``
        """
        for a in self.artists[layer]:
            a.visible = state

    def set_slice_ori(self, ori):
        """
        Select the axis of a 3D image that is sliced through.

        The slice is reset so that axis ``ori`` is fixed at index 0 and
        the two remaining axes are displayed as 'y' and 'x', in that
        order.

        Raises
        ------
        IndexError
            If the display data is not 3D.
        ValueError
            If ``ori`` is not 0, 1 or 2.
        """
        if not self.is_3D:
            raise IndexError("Can only set slice_ori for 3D images")

        layouts = ((0, 'y', 'x'), ('y', 0, 'x'), ('y', 'x', 0))
        for axis, layout in enumerate(layouts):
            if ori == axis:
                self.slice = layout
                return

        raise ValueError("Orientation must be 0, 1, or 2")

    def restore_layers(self, layers, context):
        """
        Restore a list of glue-serialized layer dicts.

        Parameters
        ----------
        layers : iterable of dict
            One serialized record per layer. Each record carries a
            ``'_type'`` entry naming the artist class and a ``'layer'``
            entry. The records are not modified.
        context
            Object whose ``object`` method resolves the serialized values.
            Every value except the one stored under ``'stretch'`` is
            passed through it.

        Raises
        ------
        ValueError
            If a record names an artist class that is not one of the
            supported scatter, RGB or image layer types.
        """
        for record in layers:
            # work on a copy so that the caller's record keeps its '_type'
            # entry and can be restored again
            record = dict(record)
            artist_cls = lookup_class_with_patches(record.pop('_type'))
            props = {k: (v if k == 'stretch' else context.object(v))
                     for k, v in record.items()}
            layer = props['layer']

            if issubclass(artist_cls, ScatterLayerBase):
                artist = self.add_scatter_layer(layer)
            elif issubclass(artist_cls, RGBImageLayerBase):
                r = props.pop('r')
                g = props.pop('g')
                b = props.pop('b')
                self.display_data = layer
                self.display_attribute = r
                artist = self.rgb_mode(True)
                # rgb_mode puts all channels on the display attribute;
                # assign the restored channels afterwards
                artist.r = r
                artist.g = g
                artist.b = b
            elif issubclass(artist_cls, (ImageLayerBase,
                                         SubsetImageLayerBase)):
                if isinstance(layer, Data):
                    self.set_data(layer)
                artist = self.add_layer(layer)
            else:
                # report the unsupported class, which is what the message
                # announces
                raise ValueError(
                    "Cannot restore layer of type %s" % artist_cls)

            artist.properties = props

    def _on_component_replace(self, msg):
        """
        If the displayed attribute is the component being replaced
        (``msg.old``), switch to its replacement (``msg.new``)
        """
        if self.display_attribute is msg.old:
            self.display_attribute = msg.new

    def register_to_hub(self, hub):
        """
        Register this client with ``hub``.

        After the parent class registration, the client additionally
        subscribes to ``ComponentReplacedMessage`` so that
        ``_on_component_replace`` is invoked for it.
        """
        super(ImageClient, self).register_to_hub(hub)
        hub.subscribe(self, ComponentReplacedMessage,
                      self._on_component_replace)

    # subclasses should override the following methods as appropriate
    def _new_rgb_layer(self, layer):
        """
        Construct and return an RGBImageLayerBase for the given layer

        This base implementation always raises ``NotImplementedError``;
        subclasses are expected to override it.

        Parameters
        ----------
        layer : :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`
            Which object to visualize
        """
        raise NotImplementedError()

    def _new_subset_image_layer(self, layer):
        """
        Construct and return a SubsetImageLayerArtist for the given layer

        This base implementation always raises ``NotImplementedError``;
        subclasses are expected to override it.

        Parameters
        ----------
        layer : :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`
            Which object to visualize
        """
        raise NotImplementedError()

    def _new_image_layer(self, layer):
        """
        Construct and return an ImageLayerArtist for the given layer

        This base implementation always raises ``NotImplementedError``;
        subclasses are expected to override it.

        Parameters
        ----------
        layer : :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`
            Which object to visualize
        """
        raise NotImplementedError()

    def _new_scatter_layer(self, layer):
        """
        Construct and return a ScatterLayerArtist for the given layer

        This base implementation always raises ``NotImplementedError``;
        subclasses are expected to override it.

        Parameters
        ----------
        layer : :class:`~glue.core.data.Data` or :class:`~glue.core.subset.Subset`
            Which object to visualize
        """
        raise NotImplementedError()

    def _update_axis_labels(self):
        """
        Sync the displays for labels on X/Y axes, because the data or slice has
        changed

        This base implementation always raises ``NotImplementedError``;
        subclasses are expected to override it.
        """
        raise NotImplementedError()

    def relim(self):
        """
        Reset view window to the default pan/zoom setting.

        The implementation in this class is a no-op.
        """
        pass

    def show_crosshairs(self, x, y):
        """
        No-op in this class.

        NOTE(review): presumably a hook that subclasses override to draw
        crosshairs at position (x, y) -- confirm against the subclasses.
        """
        pass

    def clear_crosshairs(self):
        """
        No-op in this class.

        NOTE(review): presumably a hook that subclasses override to remove
        crosshairs drawn by ``show_crosshairs`` -- confirm.
        """
        pass