Ejemplo n.º 1
0
class MatplotlibDataViewer(DataViewer):
    """Data viewer whose central widget is a Matplotlib canvas.

    The viewer keeps ``self.state`` and the Matplotlib axes in sync: axis
    limits, log scaling, aspect, axis labels and tick-label sizes are pushed
    to the axes whenever the corresponding state attribute changes, and limit
    changes made on the Matplotlib side are written back to the state.
    """

    _state_cls = MatplotlibDataViewerState

    tools = ['mpl:home', 'mpl:pan', 'mpl:zoom']
    subtools = {'save': ['mpl:save']}

    def __init__(self, session, parent=None, wcs=None, state=None):
        """Build the canvas, the 'Computing' overlay and all state callbacks.

        ``session``, ``parent`` and ``state`` are forwarded to the base class
        initializer; ``wcs`` is forwarded to ``init_mpl`` when the axes are
        created.
        """

        super(MatplotlibDataViewer, self).__init__(session,
                                                   parent=parent,
                                                   state=state)

        # Use MplWidget to set up a Matplotlib canvas inside the Qt window
        self.mpl_widget = MplWidget()
        self.setCentralWidget(self.mpl_widget)

        # TODO: shouldn't have to do this
        self.central_widget = self.mpl_widget

        self.figure, self._axes = init_mpl(self.mpl_widget.canvas.fig, wcs=wcs)

        # Keep the axes frame above everything else, including the overlay
        for spine in self._axes.spines.values():
            spine.set_zorder(ZORDER_MAX)

        # Semi-opaque rectangle covering the whole axes (axes coordinates),
        # shown only while a layer is computing.
        self.loading_rectangle = Rectangle((0, 0),
                                           1,
                                           1,
                                           color='0.9',
                                           alpha=0.9,
                                           zorder=ZORDER_MAX - 1,
                                           transform=self.axes.transAxes)
        self.loading_rectangle.set_visible(False)
        self.axes.add_patch(self.loading_rectangle)

        # Text drawn just above the rectangle; dots are appended to it by
        # _update_computation.
        self.loading_text = self.axes.text(
            0.4,
            0.5,
            'Computing',
            color='k',
            zorder=self.loading_rectangle.get_zorder() + 1,
            ha='left',
            va='center',
            transform=self.axes.transAxes)
        self.loading_text.set_visible(False)

        self.state.add_callback('aspect', self.update_aspect)

        self.update_aspect()

        self.state.add_callback('x_min', self.limits_to_mpl)
        self.state.add_callback('x_max', self.limits_to_mpl)
        self.state.add_callback('y_min', self.limits_to_mpl)
        self.state.add_callback('y_max', self.limits_to_mpl)

        self.limits_to_mpl()

        self.state.add_callback('x_log', self.update_x_log, priority=1000)
        self.state.add_callback('y_log', self.update_y_log, priority=1000)

        # Apply the initial scale of *both* axes. Only the x scale used to be
        # applied here, so an initial state with y_log=True was not reflected
        # on the axes until y_log changed.
        self.update_x_log()
        self.update_y_log()

        self.axes.callbacks.connect('xlim_changed', self.limits_from_mpl)
        self.axes.callbacks.connect('ylim_changed', self.limits_from_mpl)

        self.axes.set_autoscale_on(False)

        self.state.add_callback('x_axislabel', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_weight', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_size', self.update_x_axislabel)

        self.state.add_callback('y_axislabel', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_weight', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_size', self.update_y_axislabel)

        self.state.add_callback('x_ticklabel_size', self.update_x_ticklabel)
        self.state.add_callback('y_ticklabel_size', self.update_y_ticklabel)

        self.update_x_axislabel()
        self.update_y_axislabel()
        self.update_x_ticklabel()
        self.update_y_ticklabel()

        self.central_widget.resize(600, 400)
        self.resize(self.central_widget.size())

        # Timer that re-runs _update_computation at the interval set below
        # (QTimer intervals are in milliseconds) while it is active.
        self._monitor_computation = QTimer()
        self._monitor_computation.setInterval(500)
        self._monitor_computation.timeout.connect(self._update_computation)

    def _update_computation(self, message=None):
        """Show, animate or hide the 'Computing' overlay.

        Called with a ``ComputationStartedMessage`` to start monitoring, and
        with no message on each timer tick.
        """

        # If we get a ComputationStartedMessage and the timer isn't currently
        # active, then we start the timer but we then return straight away.
        # This is to avoid showing the 'Computing' message straight away in the
        # case of reasonably fast operations.
        if isinstance(message, ComputationStartedMessage):
            if not self._monitor_computation.isActive():
                self._monitor_computation.start()
            return

        if any(layer_artist.is_computing for layer_artist in self.layers):
            self.loading_rectangle.set_visible(True)
            # Cycle 'Computing' -> 'Computing.' -> ... -> 'Computing...'
            text = self.loading_text.get_text()
            if text.count('.') > 2:
                text = 'Computing'
            else:
                text += '.'
            self.loading_text.set_text(text)
            self.loading_text.set_visible(True)
            self.redraw()
            return

        self.loading_rectangle.set_visible(False)
        self.loading_text.set_visible(False)
        self.redraw()

        # If we get here, the computation has stopped so we can stop the timer
        self._monitor_computation.stop()

    def add_data(self, *args, **kwargs):
        """Delegate to the base class ``add_data`` and return its result."""
        return super(MatplotlibDataViewer, self).add_data(*args, **kwargs)

    def add_subset(self, *args, **kwargs):
        """Delegate to the base class ``add_subset`` and return its result."""
        return super(MatplotlibDataViewer, self).add_subset(*args, **kwargs)

    def update_x_axislabel(self, *event):
        """Apply the x axis label text, weight and size from the state."""
        self.axes.set_xlabel(self.state.x_axislabel,
                             weight=self.state.x_axislabel_weight,
                             size=self.state.x_axislabel_size)
        self.redraw()

    def update_y_axislabel(self, *event):
        """Apply the y axis label text, weight and size from the state."""
        self.axes.set_ylabel(self.state.y_axislabel,
                             weight=self.state.y_axislabel_weight,
                             size=self.state.y_axislabel_size)
        self.redraw()

    def update_x_ticklabel(self, *event):
        """Apply the x tick-label size to the ticks and the offset text."""
        self.axes.tick_params(axis='x', labelsize=self.state.x_ticklabel_size)
        self.axes.xaxis.get_offset_text().set_fontsize(
            self.state.x_ticklabel_size)
        self.redraw()

    def update_y_ticklabel(self, *event):
        """Apply the y tick-label size to the ticks and the offset text."""
        self.axes.tick_params(axis='y', labelsize=self.state.y_ticklabel_size)
        self.axes.yaxis.get_offset_text().set_fontsize(
            self.state.y_ticklabel_size)
        self.redraw()

    def redraw(self):
        """Redraw the figure canvas immediately."""
        self.figure.canvas.draw()

    def update_x_log(self, *args):
        """Set the x scale to 'log' or 'linear' according to the state."""
        self.axes.set_xscale('log' if self.state.x_log else 'linear')
        self.redraw()

    def update_y_log(self, *args):
        """Set the y scale to 'log' or 'linear' according to the state."""
        self.axes.set_yscale('log' if self.state.y_log else 'linear')
        self.redraw()

    def update_aspect(self, aspect=None):
        """Apply ``state.aspect`` to the axes (the argument is ignored)."""
        self.axes.set_aspect(self.state.aspect, adjustable='datalim')

    def _limits_from_axes(self):
        """Return ``(x_min, x_max, y_min, y_max)`` as currently set on the axes.

        For an axis whose state minimum is currently a ``np.datetime64``, the
        Matplotlib limits are converted with ``mpl_to_datetime64`` so the
        state keeps its datetime type.
        """
        x_min, x_max = self.axes.get_xlim()
        if isinstance(self.state.x_min, np.datetime64):
            x_min, x_max = mpl_to_datetime64(x_min), mpl_to_datetime64(x_max)

        y_min, y_max = self.axes.get_ylim()
        if isinstance(self.state.y_min, np.datetime64):
            y_min, y_max = mpl_to_datetime64(y_min), mpl_to_datetime64(y_max)

        return x_min, x_max, y_min, y_max

    @avoid_circular
    def limits_from_mpl(self, *args):
        """Copy the current Matplotlib axis limits into the state."""

        x_min, x_max, y_min, y_max = self._limits_from_axes()

        # Update all four limits as one change
        with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):
            self.state.x_min, self.state.x_max = x_min, x_max
            self.state.y_min, self.state.y_max = y_min, y_max

    @avoid_circular
    def limits_to_mpl(self, *args):
        """Push the state limits to the Matplotlib axes and redraw.

        An axis is only updated when both of its limits are set. On a log
        axis, non-positive limits are replaced: ``(0.1, 1)`` when the maximum
        is non-positive, otherwise the minimum becomes a tenth of the maximum.
        """

        if self.state.x_min is not None and self.state.x_max is not None:
            x_min, x_max = self.state.x_min, self.state.x_max
            if self.state.x_log:
                if self.state.x_max <= 0:
                    x_min, x_max = 0.1, 1
                elif self.state.x_min <= 0:
                    x_min = x_max / 10
            self.axes.set_xlim(x_min, x_max)

        if self.state.y_min is not None and self.state.y_max is not None:
            y_min, y_max = self.state.y_min, self.state.y_max
            if self.state.y_log:
                if self.state.y_max <= 0:
                    y_min, y_max = 0.1, 1
                elif self.state.y_min <= 0:
                    y_min = y_max / 10
            self.axes.set_ylim(y_min, y_max)

        if self.state.aspect == 'equal':

            # FIXME: for a reason I don't quite understand, dataLim doesn't
            # get updated immediately here, which means that there are then
            # issues in the first draw of the image (the limits are such that
            # only part of the image is shown). We just set dataLim manually
            # to avoid this issue.
            self.axes.dataLim.intervalx = self.axes.get_xlim()
            self.axes.dataLim.intervaly = self.axes.get_ylim()

            # We then force the aspect to be computed straight away
            self.axes.apply_aspect()

            # And propagate any changes back to the state since we have the
            # @avoid_circular decorator (limits_from_mpl would be skipped
            # while we are inside limits_to_mpl). The shared helper keeps
            # datetime64 limits as datetime64 instead of raw floats.
            x_min, x_max, y_min, y_max = self._limits_from_axes()
            with delay_callback(self.state, 'x_min', 'x_max', 'y_min',
                                'y_max'):
                self.state.x_min, self.state.x_max = x_min, x_max
                self.state.y_min, self.state.y_max = y_min, y_max

        self.axes.figure.canvas.draw()

    # TODO: shouldn't need this!
    @property
    def axes(self):
        """The Matplotlib axes created in ``__init__``."""
        return self._axes

    def _update_appearance_from_settings(self, message=None):
        """Re-apply the appearance settings to the axes and redraw."""
        update_appearance_from_settings(self.axes)
        self.redraw()

    def get_layer_artist(self, cls, layer=None, layer_state=None):
        """Instantiate ``cls`` for this viewer's axes and state."""
        return cls(self.axes, self.state, layer=layer, layer_state=layer_state)

    def apply_roi(self, roi, use_current=False):
        """ This method must be implemented by subclasses """
        raise NotImplementedError

    def _script_header(self):
        """Return ``(import lines, header code)`` for an exported script."""
        state_dict = self.state.as_dict()
        return ['import matplotlib.pyplot as plt'
                ], SCRIPT_HEADER.format(**state_dict)

    def _script_footer(self):
        """Return ``(import lines, footer code)`` for an exported script."""
        state_dict = self.state.as_dict()
        # The footer template takes the scale names rather than the booleans
        state_dict['x_log_str'] = 'log' if self.state.x_log else 'linear'
        state_dict['y_log_str'] = 'log' if self.state.y_log else 'linear'
        return [], SCRIPT_FOOTER.format(**state_dict)
Ejemplo n.º 2
0
class MatplotlibViewerMixin(object):
    """Mixin that keeps a viewer state and a Matplotlib axes in sync.

    The host class is expected to provide ``self.axes``, ``self.figure`` and
    ``self.state``; none of them is defined here.
    """

    def setup_callbacks(self):
        """Create the 'Computing' overlay and register all state callbacks."""

        # Keep the axes frame above everything else, including the overlay
        for spine in self.axes.spines.values():
            spine.set_zorder(ZORDER_MAX)

        # Semi-opaque rectangle covering the whole axes (axes coordinates);
        # hidden by default.
        self.loading_rectangle = Rectangle((0, 0), 1, 1, color='0.9', alpha=0.9,
                                           zorder=ZORDER_MAX - 1, transform=self.axes.transAxes)
        self.loading_rectangle.set_visible(False)
        self.axes.add_patch(self.loading_rectangle)

        # Text drawn just above the rectangle; hidden by default.
        self.loading_text = self.axes.text(0.4, 0.5, 'Computing', color='k',
                                           zorder=self.loading_rectangle.get_zorder() + 1,
                                           ha='left', va='center',
                                           transform=self.axes.transAxes)
        self.loading_text.set_visible(False)

        self.state.add_callback('aspect', self.update_aspect)

        self.update_aspect()

        self.state.add_callback('x_min', self.limits_to_mpl)
        self.state.add_callback('x_max', self.limits_to_mpl)
        self.state.add_callback('y_min', self.limits_to_mpl)
        self.state.add_callback('y_max', self.limits_to_mpl)

        self.limits_to_mpl()

        self.state.add_callback('x_log', self.update_x_log, priority=1000)
        self.state.add_callback('y_log', self.update_y_log, priority=1000)

        # Apply the initial scale of *both* axes. Only the x scale used to be
        # applied here, so an initial state with y_log=True was not reflected
        # on the axes until y_log changed.
        self.update_x_log()
        self.update_y_log()

        self.axes.callbacks.connect('xlim_changed', self.limits_from_mpl)
        self.axes.callbacks.connect('ylim_changed', self.limits_from_mpl)

        self.axes.set_autoscale_on(False)

        self.state.add_callback('x_axislabel', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_weight', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_size', self.update_x_axislabel)

        self.state.add_callback('y_axislabel', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_weight', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_size', self.update_y_axislabel)

        self.state.add_callback('x_ticklabel_size', self.update_x_ticklabel)
        self.state.add_callback('y_ticklabel_size', self.update_y_ticklabel)

        self.update_x_axislabel()
        self.update_y_axislabel()
        self.update_x_ticklabel()
        self.update_y_ticklabel()

    def _update_computation(self, message=None):
        """No-op in the mixin."""
        pass

    def update_x_axislabel(self, *event):
        """Apply the x axis label text, weight and size from the state."""
        self.axes.set_xlabel(self.state.x_axislabel,
                             weight=self.state.x_axislabel_weight,
                             size=self.state.x_axislabel_size)
        self.redraw()

    def update_y_axislabel(self, *event):
        """Apply the y axis label text, weight and size from the state."""
        self.axes.set_ylabel(self.state.y_axislabel,
                             weight=self.state.y_axislabel_weight,
                             size=self.state.y_axislabel_size)
        self.redraw()

    def update_x_ticklabel(self, *event):
        """Apply the x tick-label size to the ticks and the offset text."""
        self.axes.tick_params(axis='x', labelsize=self.state.x_ticklabel_size)
        self.axes.xaxis.get_offset_text().set_fontsize(self.state.x_ticklabel_size)
        self.redraw()

    def update_y_ticklabel(self, *event):
        """Apply the y tick-label size to the ticks and the offset text."""
        self.axes.tick_params(axis='y', labelsize=self.state.y_ticklabel_size)
        self.axes.yaxis.get_offset_text().set_fontsize(self.state.y_ticklabel_size)
        self.redraw()

    def redraw(self):
        """Redraw the figure canvas immediately."""
        self.figure.canvas.draw()

    def update_x_log(self, *args):
        """Set the x scale to 'log' or 'linear' according to the state."""
        self.axes.set_xscale('log' if self.state.x_log else 'linear')
        self.redraw()

    def update_y_log(self, *args):
        """Set the y scale to 'log' or 'linear' according to the state."""
        self.axes.set_yscale('log' if self.state.y_log else 'linear')
        self.redraw()

    def update_aspect(self, aspect=None):
        """Apply ``state.aspect`` to the axes (the argument is ignored)."""
        self.axes.set_aspect(self.state.aspect, adjustable='datalim')

    def _limits_from_axes(self):
        """Return ``(x_min, x_max, y_min, y_max)`` as currently set on the axes.

        For an axis whose state minimum is currently a ``np.datetime64``, the
        Matplotlib limits are converted with ``mpl_to_datetime64`` so the
        state keeps its datetime type.
        """
        x_min, x_max = self.axes.get_xlim()
        if isinstance(self.state.x_min, np.datetime64):
            x_min, x_max = mpl_to_datetime64(x_min), mpl_to_datetime64(x_max)

        y_min, y_max = self.axes.get_ylim()
        if isinstance(self.state.y_min, np.datetime64):
            y_min, y_max = mpl_to_datetime64(y_min), mpl_to_datetime64(y_max)

        return x_min, x_max, y_min, y_max

    @avoid_circular
    def limits_from_mpl(self, *args):
        """Copy the current Matplotlib axis limits into the state."""

        x_min, x_max, y_min, y_max = self._limits_from_axes()

        # Update all four limits as one change
        with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):
            self.state.x_min, self.state.x_max = x_min, x_max
            self.state.y_min, self.state.y_max = y_min, y_max

    @avoid_circular
    def limits_to_mpl(self, *args):
        """Push the state limits to the Matplotlib axes and redraw.

        An axis is only updated when both of its limits are set. On a log
        axis, non-positive limits are replaced: ``(0.1, 1)`` when the maximum
        is non-positive, otherwise the minimum becomes a tenth of the maximum.
        """

        if self.state.x_min is not None and self.state.x_max is not None:
            x_min, x_max = self.state.x_min, self.state.x_max
            if self.state.x_log:
                if self.state.x_max <= 0:
                    x_min, x_max = 0.1, 1
                elif self.state.x_min <= 0:
                    x_min = x_max / 10
            self.axes.set_xlim(x_min, x_max)

        if self.state.y_min is not None and self.state.y_max is not None:
            y_min, y_max = self.state.y_min, self.state.y_max
            if self.state.y_log:
                if self.state.y_max <= 0:
                    y_min, y_max = 0.1, 1
                elif self.state.y_min <= 0:
                    y_min = y_max / 10
            self.axes.set_ylim(y_min, y_max)

        if self.state.aspect == 'equal':

            # FIXME: for a reason I don't quite understand, dataLim doesn't
            # get updated immediately here, which means that there are then
            # issues in the first draw of the image (the limits are such that
            # only part of the image is shown). We just set dataLim manually
            # to avoid this issue.
            self.axes.dataLim.intervalx = self.axes.get_xlim()
            self.axes.dataLim.intervaly = self.axes.get_ylim()

            # We then force the aspect to be computed straight away
            self.axes.apply_aspect()

            # And propagate any changes back to the state since we have the
            # @avoid_circular decorator (limits_from_mpl would be skipped
            # while we are inside limits_to_mpl). The shared helper keeps
            # datetime64 limits as datetime64 instead of raw floats.
            x_min, x_max, y_min, y_max = self._limits_from_axes()
            with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):
                self.state.x_min, self.state.x_max = x_min, x_max
                self.state.y_min, self.state.y_max = y_min, y_max

        self.axes.figure.canvas.draw()

    def _update_appearance_from_settings(self, message=None):
        """Re-apply the appearance settings to the axes and redraw."""
        update_appearance_from_settings(self.axes)
        self.redraw()

    def get_layer_artist(self, cls, layer=None, layer_state=None):
        """Instantiate ``cls`` for this viewer's axes and state."""
        return cls(self.axes, self.state, layer=layer, layer_state=layer_state)

    def apply_roi(self, roi, override_mode=None):
        """ This method must be implemented by subclasses """
        raise NotImplementedError

    def _script_header(self):
        """Return ``(import lines, header code)`` for an exported script."""
        state_dict = self.state.as_dict()
        return ['import matplotlib.pyplot as plt'], SCRIPT_HEADER.format(**state_dict)

    def _script_footer(self):
        """Return ``(import lines, footer code)`` for an exported script."""
        state_dict = self.state.as_dict()
        # The footer template takes the scale names rather than the booleans
        state_dict['x_log_str'] = 'log' if self.state.x_log else 'linear'
        state_dict['y_log_str'] = 'log' if self.state.y_log else 'linear'
        return [], SCRIPT_FOOTER.format(**state_dict)
Ejemplo n.º 3
0
class MatplotlibViewerMixin(object):
    """Mixin that keeps a viewer state and a Matplotlib axes in sync.

    The host class is expected to provide ``self.axes``, ``self.figure`` and
    ``self.state``; none of them is defined here.
    """

    def setup_callbacks(self):
        """Create the 'Computing' overlay and register all callbacks."""

        # Keep the axes frame above everything else, including the overlay
        for spine in self.axes.spines.values():
            spine.set_zorder(ZORDER_MAX)

        # Semi-opaque rectangle covering the whole axes (axes coordinates);
        # hidden by default.
        self.loading_rectangle = Rectangle((0, 0),
                                           1,
                                           1,
                                           color='0.9',
                                           alpha=0.9,
                                           zorder=ZORDER_MAX - 1,
                                           transform=self.axes.transAxes)
        self.loading_rectangle.set_visible(False)
        self.axes.add_patch(self.loading_rectangle)

        # Text drawn just above the rectangle; hidden by default.
        self.loading_text = self.axes.text(
            0.4,
            0.5,
            'Computing',
            color='k',
            zorder=self.loading_rectangle.get_zorder() + 1,
            ha='left',
            va='center',
            transform=self.axes.transAxes)
        self.loading_text.set_visible(False)

        self.state.add_callback('x_min', self.limits_to_mpl)
        self.state.add_callback('x_max', self.limits_to_mpl)
        self.state.add_callback('y_min', self.limits_to_mpl)
        self.state.add_callback('y_max', self.limits_to_mpl)

        # limits_to_mpl does nothing unless all four limits are set, so when
        # any of them is missing, initialise the state from the axes instead.
        if self._any_limit_unset():
            self.limits_from_mpl()
        else:
            self.limits_to_mpl()

        self.state.add_callback('x_log', self.update_x_log, priority=1000)
        self.state.add_callback('y_log', self.update_y_log, priority=1000)

        # Apply the initial scale of *both* axes. Only the x scale used to be
        # applied here, so an initial state with y_log=True was not reflected
        # on the axes until y_log changed.
        self.update_x_log()
        self.update_y_log()

        self.axes.callbacks.connect('xlim_changed', self.limits_from_mpl)
        self.axes.callbacks.connect('ylim_changed', self.limits_from_mpl)
        self.figure.canvas.mpl_connect('resize_event', self._on_resize)

        self._on_resize()

        self.axes.set_autoscale_on(False)

        self.state.add_callback('x_axislabel', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_weight', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_size', self.update_x_axislabel)

        self.state.add_callback('y_axislabel', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_weight', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_size', self.update_y_axislabel)

        self.state.add_callback('x_ticklabel_size', self.update_x_ticklabel)
        self.state.add_callback('y_ticklabel_size', self.update_y_ticklabel)

        self.update_x_axislabel()
        self.update_y_axislabel()
        self.update_x_ticklabel()
        self.update_y_ticklabel()

    def _any_limit_unset(self):
        """Return True if at least one of the four state limits is None.

        This replaces ``(x_min or x_max or y_min or y_max) is None``, which
        only inspected the first truthy limit (or the last limit when all
        were falsy) and therefore missed most partially-set states.
        """
        limits = (self.state.x_min, self.state.x_max,
                  self.state.y_min, self.state.y_max)
        return any(limit is None for limit in limits)

    def _update_computation(self, message=None):
        """No-op in the mixin."""
        pass

    def update_x_axislabel(self, *event):
        """Apply the x axis label text, weight and size from the state."""
        self.axes.set_xlabel(self.state.x_axislabel,
                             weight=self.state.x_axislabel_weight,
                             size=self.state.x_axislabel_size)
        self.redraw()

    def update_y_axislabel(self, *event):
        """Apply the y axis label text, weight and size from the state."""
        self.axes.set_ylabel(self.state.y_axislabel,
                             weight=self.state.y_axislabel_weight,
                             size=self.state.y_axislabel_size)
        self.redraw()

    def update_x_ticklabel(self, *event):
        """Apply the x tick-label size to the ticks and the offset text."""
        self.axes.tick_params(axis='x', labelsize=self.state.x_ticklabel_size)
        self.axes.xaxis.get_offset_text().set_fontsize(
            self.state.x_ticklabel_size)
        self.redraw()

    def update_y_ticklabel(self, *event):
        """Apply the y tick-label size to the ticks and the offset text."""
        self.axes.tick_params(axis='y', labelsize=self.state.y_ticklabel_size)
        self.axes.yaxis.get_offset_text().set_fontsize(
            self.state.y_ticklabel_size)
        self.redraw()

    def redraw(self):
        """Request a redraw of the figure canvas via ``draw_idle``."""
        self.figure.canvas.draw_idle()

    def update_x_log(self, *args):
        """Set the x scale to 'log' or 'linear' according to the state."""
        self.axes.set_xscale('log' if self.state.x_log else 'linear')
        self.redraw()

    def update_y_log(self, *args):
        """Set the y scale to 'log' or 'linear' according to the state."""
        self.axes.set_yscale('log' if self.state.y_log else 'linear')
        self.redraw()

    def limits_from_mpl(self, *args, **kwargs):
        """Copy the current Matplotlib axis limits into the state.

        Does nothing while ``limits_to_mpl`` is itself setting the limits.
        For an axis whose state minimum is a ``np.datetime64``, the limits
        are converted with ``mpl_to_datetime64``.
        """

        if getattr(self, '_skip_limits_from_mpl', False):
            return

        if isinstance(self.state.x_min, np.datetime64):
            x_min, x_max = [mpl_to_datetime64(x) for x in self.axes.get_xlim()]
        else:
            x_min, x_max = self.axes.get_xlim()

        if isinstance(self.state.y_min, np.datetime64):
            y_min, y_max = [mpl_to_datetime64(y) for y in self.axes.get_ylim()]
        else:
            y_min, y_max = self.axes.get_ylim()

        # Update all four limits as one change
        with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):

            self.state.x_min = x_min
            self.state.x_max = x_max
            self.state.y_min = y_min
            self.state.y_max = y_max

    def limits_to_mpl(self, *args):
        """Push the state limits to the Matplotlib axes.

        Returns without touching the axes unless all four limits are set. On
        a log axis, non-positive limits are replaced: ``(0.1, 1)`` when the
        maximum is non-positive, otherwise the minimum becomes a tenth of the
        maximum.
        """

        if (self.state.x_min is None or self.state.x_max is None
                or self.state.y_min is None or self.state.y_max is None):
            return

        x_min, x_max = self.state.x_min, self.state.x_max

        if self.state.x_log:
            if self.state.x_max <= 0:
                x_min, x_max = 0.1, 1
            elif self.state.x_min <= 0:
                x_min = x_max / 10

        y_min, y_max = self.state.y_min, self.state.y_max
        if self.state.y_log:
            if self.state.y_max <= 0:
                y_min, y_max = 0.1, 1
            elif self.state.y_min <= 0:
                y_min = y_max / 10

        # Since we deal with aspect ratio internally, there are no conditions
        # under which we would want to immediately change the state once the
        # limits are set in Matplotlib, so we avoid this by setting the
        # _skip_limits_from_mpl attribute which is then recognized inside
        # limits_from_mpl
        self._skip_limits_from_mpl = True
        try:
            self.axes.set_xlim(x_min, x_max)
            self.axes.set_ylim(y_min, y_max)
            self.axes.figure.canvas.draw_idle()
        finally:
            self._skip_limits_from_mpl = False

    @property
    def axes_ratio(self):
        """Height/width ratio of the axes, accounting for the figure size."""

        # Get figure aspect ratio
        width, height = self.figure.get_size_inches()
        fig_aspect = height / width

        # Get axes aspect ratio
        l, b, w, h = self.axes.get_position().bounds
        axes_ratio = fig_aspect * (h / w)

        return axes_ratio

    def _on_resize(self, *args):
        """Pass the current axes ratio to the state."""
        self.state._set_axes_aspect_ratio(self.axes_ratio)

    def _update_appearance_from_settings(self, message=None):
        """Re-apply the appearance settings to the axes and redraw."""
        update_appearance_from_settings(self.axes)
        self.redraw()

    def get_layer_artist(self, cls, layer=None, layer_state=None):
        """Instantiate ``cls`` for this viewer's axes and state."""
        return cls(self.axes, self.state, layer=layer, layer_state=layer_state)

    def apply_roi(self, roi, override_mode=None):
        """ This method must be implemented by subclasses """
        raise NotImplementedError

    def _script_header(self):
        """Return ``(import lines, header code)`` for an exported script."""
        state_dict = self.state.as_dict()
        return [
            'import matplotlib', "matplotlib.use('Agg')",
            'import matplotlib.pyplot as plt'
        ], SCRIPT_HEADER.format(**state_dict)

    def _script_footer(self):
        """Return ``(import lines, footer code)`` for an exported script."""
        state_dict = self.state.as_dict()
        # The footer template takes the scale names rather than the booleans
        state_dict['x_log_str'] = 'log' if self.state.x_log else 'linear'
        state_dict['y_log_str'] = 'log' if self.state.y_log else 'linear'
        return [], SCRIPT_FOOTER.format(**state_dict)
Ejemplo n.º 4
0
class MatplotlibViewerMixin(object):
    """Mixin that keeps a viewer state and a Matplotlib axes in sync.

    The host class is expected to provide ``self.axes``, ``self.figure``,
    ``self.state`` and ``self._layer_artist_container``; none of them is
    defined here.
    """

    def setup_callbacks(self):
        """Create the 'Computing' overlay and register all callbacks."""

        # Keep the axes frame above everything else, including the overlay
        for spine in self.axes.spines.values():
            spine.set_zorder(ZORDER_MAX)

        # Semi-opaque rectangle covering the whole axes (axes coordinates);
        # hidden by default.
        self.loading_rectangle = Rectangle((0, 0),
                                           1,
                                           1,
                                           color='0.9',
                                           alpha=0.9,
                                           zorder=ZORDER_MAX - 1,
                                           transform=self.axes.transAxes)
        self.loading_rectangle.set_visible(False)
        self.axes.add_patch(self.loading_rectangle)

        # Text drawn just above the rectangle; hidden by default.
        self.loading_text = self.axes.text(
            0.4,
            0.5,
            'Computing',
            color='k',
            zorder=self.loading_rectangle.get_zorder() + 1,
            ha='left',
            va='center',
            transform=self.axes.transAxes)
        self.loading_text.set_visible(False)

        self.state.add_callback('x_min', self.limits_to_mpl)
        self.state.add_callback('x_max', self.limits_to_mpl)
        self.state.add_callback('y_min', self.limits_to_mpl)
        self.state.add_callback('y_max', self.limits_to_mpl)

        # limits_to_mpl does nothing unless all four limits are set, so when
        # any of them is missing, initialise the state from the axes instead.
        if self._any_limit_unset():
            self.limits_from_mpl()
        else:
            self.limits_to_mpl()

        self.state.add_callback('x_log', self.update_x_log, priority=1000)
        self.state.add_callback('y_log', self.update_y_log, priority=1000)

        # Apply the initial scale of *both* axes. Only the x scale used to be
        # applied here, so an initial state with y_log=True was not reflected
        # on the axes until y_log changed.
        self.update_x_log()
        self.update_y_log()

        self.axes.callbacks.connect('xlim_changed', self.limits_from_mpl)
        self.axes.callbacks.connect('ylim_changed', self.limits_from_mpl)
        self.figure.canvas.mpl_connect('resize_event', self._on_resize)

        self._on_resize()

        self.axes.set_autoscale_on(False)

        self.state.add_callback('x_axislabel', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_weight', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_size', self.update_x_axislabel)

        self.state.add_callback('y_axislabel', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_weight', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_size', self.update_y_axislabel)

        self.state.add_callback('x_ticklabel_size', self.update_x_ticklabel)
        self.state.add_callback('y_ticklabel_size', self.update_y_ticklabel)

        # Legend properties registered with draw_legend rebuild the legend;
        # those registered with update_legend only restyle the existing one.
        self.state.legend.add_callback('visible', self.draw_legend)
        self.state.legend.add_callback('location', self.draw_legend)
        self.state.legend.add_callback('alpha', self.update_legend)
        self.state.legend.add_callback('title', self.draw_legend)
        self.state.legend.add_callback('fontsize', self.draw_legend)
        self.state.legend.add_callback('frame_color', self.update_legend)
        self.state.legend.add_callback('show_edge', self.update_legend)
        self.state.legend.add_callback('text_color', self.update_legend)

        self.update_x_axislabel()
        self.update_y_axislabel()
        self.update_x_ticklabel()
        self.update_y_ticklabel()
        self.draw_legend()

    def _any_limit_unset(self):
        """Return True if at least one of the four state limits is None.

        This replaces ``(x_min or x_max or y_min or y_max) is None``, which
        only inspected the first truthy limit (or the last limit when all
        were falsy) and therefore missed most partially-set states.
        """
        limits = (self.state.x_min, self.state.x_max,
                  self.state.y_min, self.state.y_max)
        return any(limit is None for limit in limits)

    def _update_computation(self, message=None):
        """No-op in the mixin."""
        pass

    def update_x_axislabel(self, *event):
        """Apply the x axis label text, weight and size from the state."""
        self.axes.set_xlabel(self.state.x_axislabel,
                             weight=self.state.x_axislabel_weight,
                             size=self.state.x_axislabel_size)
        self.redraw()

    def update_y_axislabel(self, *event):
        """Apply the y axis label text, weight and size from the state."""
        self.axes.set_ylabel(self.state.y_axislabel,
                             weight=self.state.y_axislabel_weight,
                             size=self.state.y_axislabel_size)
        self.redraw()

    def update_x_ticklabel(self, *event):
        """Apply the x tick-label size to the ticks and the offset text."""
        self.axes.tick_params(axis='x', labelsize=self.state.x_ticklabel_size)
        self.axes.xaxis.get_offset_text().set_fontsize(
            self.state.x_ticklabel_size)
        self.redraw()

    def update_y_ticklabel(self, *event):
        """Apply the y tick-label size to the ticks and the offset text."""
        self.axes.tick_params(axis='y', labelsize=self.state.y_ticklabel_size)
        self.axes.yaxis.get_offset_text().set_fontsize(
            self.state.y_ticklabel_size)
        self.redraw()

    def get_handles_legend(self):
        """Collect the handles and labels from each layer artist.

        Returns ``(handles, labels, handler_dict)``. Layer artists whose
        handle is None are skipped; ``handler_dict`` maps a handle to its
        handler for the artists that supply one.
        """
        handles = []
        labels = []
        handler_dict = {}
        for layer_artist in self._layer_artist_container:
            handle, label, handler = layer_artist.get_handle_legend()
            if handle is not None:
                handles.append(handle)
                labels.append(label)
                if handler is not None:
                    handler_dict[handle] = handler
        return handles, labels, handler_dict

    def _update_legend_visual(self, legend):
        """Update the legend colors and opacity. No redraw."""
        msetp(legend.get_title(), color=self.state.legend.text_color)
        msetp(legend.get_texts(), color=self.state.legend.text_color)
        msetp(legend.get_frame(),
              alpha=self.state.legend.alpha,
              facecolor=self.state.legend.frame_color,
              edgecolor=self.state.legend.edge_color)

    def update_legend(self, *args):
        """Update the legend colors and opacity."""
        legend = self.axes.get_legend()
        if legend is not None:
            self._update_legend_visual(legend)
        self.redraw()

    def draw_legend(self, *args):
        """Rebuild the legend if it is visible, otherwise remove it."""
        if self.state.legend.visible:
            handles, labels, handler_map = self.get_handles_legend()
            kwargs = dict(loc=self.state.legend.mpl_location,
                          title=self.state.legend.title,
                          title_fontsize=self.state.legend.fontsize,
                          fontsize=self.state.legend.fontsize)
            # get_handles_legend always returns a dict, so the former
            # ``is not None`` check could never be false; only pass the map
            # along when it actually contains custom handlers.
            if handler_map:
                kwargs["handler_map"] = handler_map
            legend = self.axes.legend(handles, labels, **kwargs)
            self._update_legend_visual(legend)
            legend.set_draggable(self.state.legend.draggable)
        else:
            legend = self.axes.get_legend()
            if legend is not None:
                legend.remove()
        self.redraw()

    def redraw(self):
        """Request a redraw of the figure canvas via ``draw_idle``."""
        self.figure.canvas.draw_idle()

    def update_x_log(self, *args):
        """Set the x scale to 'log' or 'linear' according to the state."""
        self.axes.set_xscale('log' if self.state.x_log else 'linear')
        self.redraw()

    def update_y_log(self, *args):
        """Set the y scale to 'log' or 'linear' according to the state."""
        self.axes.set_yscale('log' if self.state.y_log else 'linear')
        self.redraw()

    def limits_from_mpl(self, *args, **kwargs):
        """Copy the current Matplotlib axis limits into the state.

        Does nothing while ``limits_to_mpl`` is itself setting the limits.
        For an axis whose state minimum is a ``np.datetime64``, the limits
        are converted with ``mpl_to_datetime64``.
        """

        if getattr(self, '_skip_limits_from_mpl', False):
            return

        if isinstance(self.state.x_min, np.datetime64):
            x_min, x_max = [mpl_to_datetime64(x) for x in self.axes.get_xlim()]
        else:
            x_min, x_max = self.axes.get_xlim()

        if isinstance(self.state.y_min, np.datetime64):
            y_min, y_max = [mpl_to_datetime64(y) for y in self.axes.get_ylim()]
        else:
            y_min, y_max = self.axes.get_ylim()

        # Update all four limits as one change
        with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):

            self.state.x_min = x_min
            self.state.x_max = x_max
            self.state.y_min = y_min
            self.state.y_max = y_max

    def limits_to_mpl(self, *args):
        """Push the state limits to the Matplotlib axes.

        Returns without touching the axes unless all four limits are set. On
        a log axis, non-positive limits are replaced: ``(0.1, 1)`` when the
        maximum is non-positive, otherwise the minimum becomes a tenth of the
        maximum.
        """

        if (self.state.x_min is None or self.state.x_max is None
                or self.state.y_min is None or self.state.y_max is None):
            return

        x_min, x_max = self.state.x_min, self.state.x_max

        if self.state.x_log:
            if self.state.x_max <= 0:
                x_min, x_max = 0.1, 1
            elif self.state.x_min <= 0:
                x_min = x_max / 10

        y_min, y_max = self.state.y_min, self.state.y_max
        if self.state.y_log:
            if self.state.y_max <= 0:
                y_min, y_max = 0.1, 1
            elif self.state.y_min <= 0:
                y_min = y_max / 10

        # Since we deal with aspect ratio internally, there are no conditions
        # under which we would want to immediately change the state once the
        # limits are set in Matplotlib, so we avoid this by setting the
        # _skip_limits_from_mpl attribute which is then recognized inside
        # limits_from_mpl
        self._skip_limits_from_mpl = True
        try:
            self.axes.set_xlim(x_min, x_max)
            self.axes.set_ylim(y_min, y_max)
            self.axes.figure.canvas.draw_idle()
        finally:
            self._skip_limits_from_mpl = False

    @property
    def axes_ratio(self):
        """Height/width ratio of the axes, accounting for the figure size."""

        # Get figure aspect ratio
        width, height = self.figure.get_size_inches()
        fig_aspect = height / width

        # Get axes aspect ratio
        l, b, w, h = self.axes.get_position().bounds
        axes_ratio = fig_aspect * (h / w)

        return axes_ratio

    def _on_resize(self, *args):
        """Pass the current axes ratio to the state."""
        self.state._set_axes_aspect_ratio(self.axes_ratio)

    def _update_appearance_from_settings(self, message=None):
        """Re-apply the appearance settings to the axes and redraw."""
        update_appearance_from_settings(self.axes)
        self.redraw()

    def get_layer_artist(self, cls, layer=None, layer_state=None):
        """Instantiate ``cls`` for this viewer's axes and state."""
        return cls(self.axes, self.state, layer=layer, layer_state=layer_state)

    def apply_roi(self, roi, override_mode=None):
        """ This method must be implemented by subclasses """
        raise NotImplementedError

    def _script_header(self):
        """Return ``(import lines, header code)`` for an exported script."""
        state_dict = self.state.as_dict()
        return [
            'import matplotlib', "matplotlib.use('Agg')",
            "# matplotlib.use('qt5Agg')", 'import matplotlib.pyplot as plt'
        ], SCRIPT_HEADER.format(**state_dict)

    def _script_footer(self):
        """Return ``(import lines, footer code)`` for an exported script."""
        state_dict = self.state.as_dict()
        # The footer template takes the scale names rather than the booleans
        state_dict['x_log_str'] = 'log' if self.state.x_log else 'linear'
        state_dict['y_log_str'] = 'log' if self.state.y_log else 'linear'
        return [], SCRIPT_FOOTER.format(**state_dict)

    def _script_legend(self):
        """Return ``(import lines, legend code)`` for an exported script.

        The legend code is emitted commented out when the legend is hidden.
        """
        state_dict = self.state.legend.as_dict()
        state_dict['location'] = self.state.legend.mpl_location
        state_dict['edge_color'] = self.state.legend.edge_color
        legend_str = SCRIPT_LEGEND.format(**state_dict)
        if not self.state.legend.visible:
            legend_str = indent(legend_str, "# ")
        return [], legend_str
Ejemplo n.º 5
0
class MatplotlibDataViewer(DataViewer):
    """Data viewer that draws into a Matplotlib canvas.

    The viewer keeps ``self.state`` and the Matplotlib axes in step: limits,
    log scaling, aspect, axis labels and tick-label sizes are pushed to the
    axes when the matching state properties change, and limit changes that
    originate in Matplotlib are copied back into the state.
    """

    _state_cls = MatplotlibDataViewerState

    tools = ['mpl:home', 'mpl:pan', 'mpl:zoom']
    subtools = {'save': ['mpl:save']}

    def __init__(self, session, parent=None, wcs=None, state=None):
        """Build the canvas and axes and connect all state/axes callbacks.

        Parameters
        ----------
        session
            Forwarded unchanged to the ``DataViewer`` constructor.
        parent : optional
            Forwarded unchanged to the ``DataViewer`` constructor.
        wcs : optional
            Forwarded to ``init_mpl`` when the figure and axes are created.
        state : optional
            Forwarded unchanged to the ``DataViewer`` constructor.
        """

        super(MatplotlibDataViewer, self).__init__(session, parent=parent, state=state)

        # Use MplWidget to set up a Matplotlib canvas inside the Qt window
        self.mpl_widget = MplWidget()
        self.setCentralWidget(self.mpl_widget)

        # TODO: shouldn't have to do this
        self.central_widget = self.mpl_widget

        self.figure, self._axes = init_mpl(self.mpl_widget.canvas.fig, wcs=wcs)

        # Draw the axes frame above everything else in the axes
        for spine in self._axes.spines.values():
            spine.set_zorder(ZORDER_MAX)

        # Overlay covering the whole axes (axes coordinates 0..1), kept
        # hidden until _update_computation finds a layer that is computing.
        self.loading_rectangle = Rectangle((0, 0), 1, 1, color='0.9', alpha=0.9,
                                           zorder=ZORDER_MAX - 1, transform=self.axes.transAxes)
        self.loading_rectangle.set_visible(False)
        self.axes.add_patch(self.loading_rectangle)

        # Message drawn just above the overlay rectangle
        self.loading_text = self.axes.text(0.4, 0.5, 'Computing', color='k',
                                           zorder=self.loading_rectangle.get_zorder() + 1,
                                           ha='left', va='center',
                                           transform=self.axes.transAxes)
        self.loading_text.set_visible(False)

        self.state.add_callback('aspect', self.update_aspect)

        self.update_aspect()

        self.state.add_callback('x_min', self.limits_to_mpl)
        self.state.add_callback('x_max', self.limits_to_mpl)
        self.state.add_callback('y_min', self.limits_to_mpl)
        self.state.add_callback('y_max', self.limits_to_mpl)

        self.limits_to_mpl()

        self.state.add_callback('x_log', self.update_x_log, priority=1000)
        self.state.add_callback('y_log', self.update_y_log, priority=1000)

        # Apply the initial scale of BOTH axes. Without the y call, a state
        # that starts out with y_log=True would leave the y axis linear until
        # y_log next changes.
        self.update_x_log()
        self.update_y_log()

        self.axes.callbacks.connect('xlim_changed', self.limits_from_mpl)
        self.axes.callbacks.connect('ylim_changed', self.limits_from_mpl)

        self.axes.set_autoscale_on(False)

        self.state.add_callback('x_axislabel', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_weight', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_size', self.update_x_axislabel)

        self.state.add_callback('y_axislabel', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_weight', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_size', self.update_y_axislabel)

        self.state.add_callback('x_ticklabel_size', self.update_x_ticklabel)
        self.state.add_callback('y_ticklabel_size', self.update_y_ticklabel)

        self.update_x_axislabel()
        self.update_y_axislabel()
        self.update_x_ticklabel()
        self.update_y_ticklabel()

        self.central_widget.resize(600, 400)
        self.resize(self.central_widget.size())

        # Timer that polls the layers every 500 ms while a computation runs
        self._monitor_computation = QTimer()
        self._monitor_computation.setInterval(500)
        self._monitor_computation.timeout.connect(self._update_computation)

    def _update_computation(self, message=None):
        """Show, animate or hide the 'Computing' overlay.

        Called with a ``ComputationStartedMessage`` to start the polling
        timer, and without one on every timer tick.
        """

        # If we get a ComputationStartedMessage and the timer isn't currently
        # active, then we start the timer but we then return straight away.
        # This is to avoid showing the 'Computing' message straight away in the
        # case of reasonably fast operations.
        if isinstance(message, ComputationStartedMessage):
            if not self._monitor_computation.isActive():
                self._monitor_computation.start()
            return

        for layer_artist in self.layers:
            if layer_artist.is_computing:
                self.loading_rectangle.set_visible(True)
                # Animate the message: add one dot per tick and start over
                # once there are three.
                text = self.loading_text.get_text()
                if text.count('.') > 2:
                    text = 'Computing'
                else:
                    text += '.'
                self.loading_text.set_text(text)
                self.loading_text.set_visible(True)
                self.redraw()
                return

        self.loading_rectangle.set_visible(False)
        self.loading_text.set_visible(False)
        self.redraw()

        # If we get here, the computation has stopped so we can stop the timer
        self._monitor_computation.stop()

    def add_data(self, *args, **kwargs):
        """Forward to ``DataViewer.add_data`` and return its result."""
        return super(MatplotlibDataViewer, self).add_data(*args, **kwargs)

    def add_subset(self, *args, **kwargs):
        """Forward to ``DataViewer.add_subset`` and return its result."""
        return super(MatplotlibDataViewer, self).add_subset(*args, **kwargs)

    def update_x_axislabel(self, *event):
        """Apply the state's x label text, weight and size, then redraw."""
        self.axes.set_xlabel(self.state.x_axislabel,
                             weight=self.state.x_axislabel_weight,
                             size=self.state.x_axislabel_size)
        self.redraw()

    def update_y_axislabel(self, *event):
        """Apply the state's y label text, weight and size, then redraw."""
        self.axes.set_ylabel(self.state.y_axislabel,
                             weight=self.state.y_axislabel_weight,
                             size=self.state.y_axislabel_size)
        self.redraw()

    def update_x_ticklabel(self, *event):
        """Apply the state's x tick-label size (offset text too), then redraw."""
        self.axes.tick_params(axis='x', labelsize=self.state.x_ticklabel_size)
        self.axes.xaxis.get_offset_text().set_fontsize(self.state.x_ticklabel_size)
        self.redraw()

    def update_y_ticklabel(self, *event):
        """Apply the state's y tick-label size (offset text too), then redraw."""
        self.axes.tick_params(axis='y', labelsize=self.state.y_ticklabel_size)
        self.axes.yaxis.get_offset_text().set_fontsize(self.state.y_ticklabel_size)
        self.redraw()

    def redraw(self):
        """Redraw the figure canvas immediately."""
        self.figure.canvas.draw()

    def update_x_log(self, *args):
        """Set the x axis scale to 'log' or 'linear' from ``state.x_log``."""
        self.axes.set_xscale('log' if self.state.x_log else 'linear')
        self.redraw()

    def update_y_log(self, *args):
        """Set the y axis scale to 'log' or 'linear' from ``state.y_log``."""
        self.axes.set_yscale('log' if self.state.y_log else 'linear')
        self.redraw()

    def update_aspect(self, aspect=None):
        """Apply ``state.aspect`` to the axes; the ``aspect`` argument is unused."""
        self.axes.set_aspect(self.state.aspect, adjustable='datalim')

    @avoid_circular
    def limits_from_mpl(self, *args):
        """Copy the current Matplotlib axis limits into the state.

        When the state currently holds ``np.datetime64`` limits for an axis,
        the Matplotlib values for that axis are converted with
        ``mpl_to_datetime64`` first.
        """

        # Hold back the four limit callbacks until all values are updated
        with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):

            if isinstance(self.state.x_min, np.datetime64):
                x_min, x_max = [mpl_to_datetime64(x) for x in self.axes.get_xlim()]
            else:
                x_min, x_max = self.axes.get_xlim()

            self.state.x_min, self.state.x_max = x_min, x_max

            if isinstance(self.state.y_min, np.datetime64):
                y_min, y_max = [mpl_to_datetime64(y) for y in self.axes.get_ylim()]
            else:
                y_min, y_max = self.axes.get_ylim()

            self.state.y_min, self.state.y_max = y_min, y_max

    @avoid_circular
    def limits_to_mpl(self, *args):
        """Push the state limits to the Matplotlib axes and redraw.

        An axis is only updated when both of its limits are set. On a log
        axis, non-positive limits are replaced: ``(0.1, 1)`` if the upper
        limit is <= 0, otherwise a lower limit of ``max / 10``.
        """

        if self.state.x_min is not None and self.state.x_max is not None:
            x_min, x_max = self.state.x_min, self.state.x_max
            if self.state.x_log:
                if self.state.x_max <= 0:
                    x_min, x_max = 0.1, 1
                elif self.state.x_min <= 0:
                    x_min = x_max / 10
            self.axes.set_xlim(x_min, x_max)

        if self.state.y_min is not None and self.state.y_max is not None:
            y_min, y_max = self.state.y_min, self.state.y_max
            if self.state.y_log:
                if self.state.y_max <= 0:
                    y_min, y_max = 0.1, 1
                elif self.state.y_min <= 0:
                    y_min = y_max / 10
            self.axes.set_ylim(y_min, y_max)

        if self.state.aspect == 'equal':

            # FIXME: for a reason I don't quite understand, dataLim doesn't
            # get updated immediately here, which means that there are then
            # issues in the first draw of the image (the limits are such that
            # only part of the image is shown). We just set dataLim manually
            # to avoid this issue.
            self.axes.dataLim.intervalx = self.axes.get_xlim()
            self.axes.dataLim.intervaly = self.axes.get_ylim()

            # We then force the aspect to be computed straight away
            self.axes.apply_aspect()

            # And propagate any changes back to the state since we have the
            # @avoid_circular decorator
            with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):
                # TODO: fix case with datetime64 here
                self.state.x_min, self.state.x_max = self.axes.get_xlim()
                self.state.y_min, self.state.y_max = self.axes.get_ylim()

        self.axes.figure.canvas.draw()

    # TODO: shouldn't need this!
    @property
    def axes(self):
        """The Matplotlib axes created in ``__init__``."""
        return self._axes

    def _update_appearance_from_settings(self, message=None):
        """Re-apply the appearance settings to the axes and redraw."""
        update_appearance_from_settings(self.axes)
        self.redraw()

    def get_layer_artist(self, cls, layer=None, layer_state=None):
        """Return ``cls`` instantiated with this viewer's axes and state."""
        return cls(self.axes, self.state, layer=layer, layer_state=layer_state)

    def apply_roi(self, roi, use_current=False):
        """ This method must be implemented by subclasses """
        raise NotImplementedError

    def _script_header(self):
        """Return ``(import_lines, header_text)`` for an exported script."""
        state_dict = self.state.as_dict()
        return ['import matplotlib.pyplot as plt'], SCRIPT_HEADER.format(**state_dict)

    def _script_footer(self):
        """Return ``(import_lines, footer_text)`` for an exported script.

        ``x_log_str`` / ``y_log_str`` are added to the state dictionary so
        the footer template can spell out the scale as 'log' or 'linear'.
        """
        state_dict = self.state.as_dict()
        state_dict['x_log_str'] = 'log' if self.state.x_log else 'linear'
        state_dict['y_log_str'] = 'log' if self.state.y_log else 'linear'
        return [], SCRIPT_FOOTER.format(**state_dict)
Ejemplo n.º 6
0
class MatplotlibViewerMixin(object):
    """Mixin that keeps a viewer's ``state`` and Matplotlib ``axes`` in sync.

    The class using this mixin is expected to provide ``self.axes``,
    ``self.figure`` and ``self.state`` before ``setup_callbacks`` is called.
    """

    def setup_callbacks(self):
        """Create the 'Computing' overlay and connect all state/axes callbacks."""

        # Draw the axes frame above everything else in the axes
        for spine in self.axes.spines.values():
            spine.set_zorder(ZORDER_MAX)

        # Overlay covering the whole axes (axes coordinates 0..1), created
        # hidden.
        self.loading_rectangle = Rectangle((0, 0), 1, 1, color='0.9', alpha=0.9,
                                           zorder=ZORDER_MAX - 1, transform=self.axes.transAxes)
        self.loading_rectangle.set_visible(False)
        self.axes.add_patch(self.loading_rectangle)

        # Message drawn just above the overlay rectangle
        self.loading_text = self.axes.text(0.4, 0.5, 'Computing', color='k',
                                           zorder=self.loading_rectangle.get_zorder() + 1,
                                           ha='left', va='center',
                                           transform=self.axes.transAxes)
        self.loading_text.set_visible(False)

        self.state.add_callback('x_min', self.limits_to_mpl)
        self.state.add_callback('x_max', self.limits_to_mpl)
        self.state.add_callback('y_min', self.limits_to_mpl)
        self.state.add_callback('y_max', self.limits_to_mpl)

        # If the state does not define all four limits yet, initialise it
        # from whatever Matplotlib currently shows; otherwise the state wins.
        # Each limit is tested for None explicitly: chaining the values with
        # ``or`` would treat a limit of 0 as missing and would let any truthy
        # limit hide the None ones.
        if self._any_limit_unset():
            self.limits_from_mpl()
        else:
            self.limits_to_mpl()

        self.state.add_callback('x_log', self.update_x_log, priority=1000)
        self.state.add_callback('y_log', self.update_y_log, priority=1000)

        # Apply the initial scale of BOTH axes. Without the y call, a state
        # that starts out with y_log=True would leave the y axis linear until
        # y_log next changes.
        self.update_x_log()
        self.update_y_log()

        self.axes.callbacks.connect('xlim_changed', self.limits_from_mpl)
        self.axes.callbacks.connect('ylim_changed', self.limits_from_mpl)
        self.figure.canvas.mpl_connect('resize_event', self._on_resize)

        self._on_resize()

        self.axes.set_autoscale_on(False)

        self.state.add_callback('x_axislabel', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_weight', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_size', self.update_x_axislabel)

        self.state.add_callback('y_axislabel', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_weight', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_size', self.update_y_axislabel)

        self.state.add_callback('x_ticklabel_size', self.update_x_ticklabel)
        self.state.add_callback('y_ticklabel_size', self.update_y_ticklabel)

        self.update_x_axislabel()
        self.update_y_axislabel()
        self.update_x_ticklabel()
        self.update_y_ticklabel()

    def _any_limit_unset(self):
        """Return True if at least one of the four state limits is None."""
        limits = (self.state.x_min, self.state.x_max,
                  self.state.y_min, self.state.y_max)
        return any(value is None for value in limits)

    def _update_computation(self, message=None):
        """Hook for reporting computation progress; does nothing here."""
        pass

    def update_x_axislabel(self, *event):
        """Apply the state's x label text, weight and size, then redraw."""
        self.axes.set_xlabel(self.state.x_axislabel,
                             weight=self.state.x_axislabel_weight,
                             size=self.state.x_axislabel_size)
        self.redraw()

    def update_y_axislabel(self, *event):
        """Apply the state's y label text, weight and size, then redraw."""
        self.axes.set_ylabel(self.state.y_axislabel,
                             weight=self.state.y_axislabel_weight,
                             size=self.state.y_axislabel_size)
        self.redraw()

    def update_x_ticklabel(self, *event):
        """Apply the state's x tick-label size (offset text too), then redraw."""
        self.axes.tick_params(axis='x', labelsize=self.state.x_ticklabel_size)
        self.axes.xaxis.get_offset_text().set_fontsize(self.state.x_ticklabel_size)
        self.redraw()

    def update_y_ticklabel(self, *event):
        """Apply the state's y tick-label size (offset text too), then redraw."""
        self.axes.tick_params(axis='y', labelsize=self.state.y_ticklabel_size)
        self.axes.yaxis.get_offset_text().set_fontsize(self.state.y_ticklabel_size)
        self.redraw()

    def redraw(self):
        """Request a deferred redraw of the figure canvas (``draw_idle``)."""
        self.figure.canvas.draw_idle()

    def update_x_log(self, *args):
        """Set the x axis scale to 'log' or 'linear' from ``state.x_log``."""
        self.axes.set_xscale('log' if self.state.x_log else 'linear')
        self.redraw()

    def update_y_log(self, *args):
        """Set the y axis scale to 'log' or 'linear' from ``state.y_log``."""
        self.axes.set_yscale('log' if self.state.y_log else 'linear')
        self.redraw()

    def limits_from_mpl(self, *args, **kwargs):
        """Copy the current Matplotlib axis limits into the state.

        Does nothing while ``_skip_limits_from_mpl`` is set, which is the
        case while ``limits_to_mpl`` is writing limits to the axes. When the
        state holds ``np.datetime64`` limits for an axis, the Matplotlib
        values for that axis are converted with ``mpl_to_datetime64`` first.
        """

        if getattr(self, '_skip_limits_from_mpl', False):
            return

        if isinstance(self.state.x_min, np.datetime64):
            x_min, x_max = [mpl_to_datetime64(x) for x in self.axes.get_xlim()]
        else:
            x_min, x_max = self.axes.get_xlim()

        if isinstance(self.state.y_min, np.datetime64):
            y_min, y_max = [mpl_to_datetime64(y) for y in self.axes.get_ylim()]
        else:
            y_min, y_max = self.axes.get_ylim()

        # Hold back the four limit callbacks until all values are updated
        with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):

            self.state.x_min = x_min
            self.state.x_max = x_max
            self.state.y_min = y_min
            self.state.y_max = y_max

    def limits_to_mpl(self, *args):
        """Push the state limits to the Matplotlib axes and request a redraw.

        Nothing happens unless all four limits are set. On a log axis,
        non-positive limits are replaced: ``(0.1, 1)`` if the upper limit is
        <= 0, otherwise a lower limit of ``max / 10``.
        """

        if self._any_limit_unset():
            return

        x_min, x_max = self.state.x_min, self.state.x_max

        if self.state.x_log:
            if self.state.x_max <= 0:
                x_min, x_max = 0.1, 1
            elif self.state.x_min <= 0:
                x_min = x_max / 10

        y_min, y_max = self.state.y_min, self.state.y_max
        if self.state.y_log:
            if self.state.y_max <= 0:
                y_min, y_max = 0.1, 1
            elif self.state.y_min <= 0:
                y_min = y_max / 10

        # Since we deal with aspect ratio internally, there are no conditions
        # under which we would want to immediately change the state once the
        # limits are set in Matplotlib, so we avoid this by setting the
        # _skip_limits_from_mpl attribute which is then recognized inside
        # limits_from_mpl
        self._skip_limits_from_mpl = True
        try:
            self.axes.set_xlim(x_min, x_max)
            self.axes.set_ylim(y_min, y_max)
            self.axes.figure.canvas.draw_idle()
        finally:
            self._skip_limits_from_mpl = False

    @property
    def axes_ratio(self):
        """Height/width ratio of the axes box.

        Computed as the figure's height/width ratio (from
        ``figure.get_size_inches()``) multiplied by the height/width ratio
        of the axes position bounds.
        """

        # Get figure aspect ratio
        width, height = self.figure.get_size_inches()
        fig_aspect = height / width

        # Get axes aspect ratio (only the box size is needed, not its origin)
        _, _, w, h = self.axes.get_position().bounds
        axes_ratio = fig_aspect * (h / w)

        return axes_ratio

    def _on_resize(self, *args):
        """Hand the current ``axes_ratio`` to the state; arguments are ignored."""
        self.state._set_axes_aspect_ratio(self.axes_ratio)

    def _update_appearance_from_settings(self, message=None):
        """Re-apply the appearance settings to the axes and redraw."""
        update_appearance_from_settings(self.axes)
        self.redraw()

    def get_layer_artist(self, cls, layer=None, layer_state=None):
        """Return ``cls`` instantiated with this viewer's axes and state."""
        return cls(self.axes, self.state, layer=layer, layer_state=layer_state)

    def apply_roi(self, roi, override_mode=None):
        """ This method must be implemented by subclasses """
        raise NotImplementedError

    def _script_header(self):
        """Return ``(import_lines, header_text)`` for an exported script.

        The import lines select the 'Agg' backend before importing pyplot.
        """
        state_dict = self.state.as_dict()
        imports = ['import matplotlib',
                   "matplotlib.use('Agg')",
                   'import matplotlib.pyplot as plt']
        return imports, SCRIPT_HEADER.format(**state_dict)

    def _script_footer(self):
        """Return ``(import_lines, footer_text)`` for an exported script.

        ``x_log_str`` / ``y_log_str`` are added to the state dictionary so
        the footer template can spell out the scale as 'log' or 'linear'.
        """
        state_dict = self.state.as_dict()
        state_dict['x_log_str'] = 'log' if self.state.x_log else 'linear'
        state_dict['y_log_str'] = 'log' if self.state.y_log else 'linear'
        return [], SCRIPT_FOOTER.format(**state_dict)
Ejemplo n.º 7
0
class MatplotlibViewerMixin(object):
    """Mixin that keeps a viewer's ``state`` and Matplotlib ``axes`` in sync.

    The class using this mixin is expected to provide ``self.axes``,
    ``self.figure`` and ``self.state`` before ``setup_callbacks`` is called.
    """

    def setup_callbacks(self):
        """Create the 'Computing' overlay and connect all state/axes callbacks."""

        # Draw the axes frame above everything else in the axes
        for spine in self.axes.spines.values():
            spine.set_zorder(ZORDER_MAX)

        # Overlay covering the whole axes (axes coordinates 0..1), created
        # hidden.
        self.loading_rectangle = Rectangle((0, 0),
                                           1,
                                           1,
                                           color='0.9',
                                           alpha=0.9,
                                           zorder=ZORDER_MAX - 1,
                                           transform=self.axes.transAxes)
        self.loading_rectangle.set_visible(False)
        self.axes.add_patch(self.loading_rectangle)

        # Message drawn just above the overlay rectangle
        self.loading_text = self.axes.text(
            0.4,
            0.5,
            'Computing',
            color='k',
            zorder=self.loading_rectangle.get_zorder() + 1,
            ha='left',
            va='center',
            transform=self.axes.transAxes)
        self.loading_text.set_visible(False)

        self.state.add_callback('aspect', self.update_aspect)

        self.update_aspect()

        self.state.add_callback('x_min', self.limits_to_mpl)
        self.state.add_callback('x_max', self.limits_to_mpl)
        self.state.add_callback('y_min', self.limits_to_mpl)
        self.state.add_callback('y_max', self.limits_to_mpl)

        self.limits_to_mpl()

        self.state.add_callback('x_log', self.update_x_log, priority=1000)
        self.state.add_callback('y_log', self.update_y_log, priority=1000)

        # Apply the initial scale of BOTH axes. Without the y call, a state
        # that starts out with y_log=True would leave the y axis linear until
        # y_log next changes.
        self.update_x_log()
        self.update_y_log()

        self.axes.callbacks.connect('xlim_changed', self.limits_from_mpl)
        self.axes.callbacks.connect('ylim_changed', self.limits_from_mpl)

        self.axes.set_autoscale_on(False)

        self.state.add_callback('x_axislabel', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_weight', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_size', self.update_x_axislabel)

        self.state.add_callback('y_axislabel', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_weight', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_size', self.update_y_axislabel)

        self.state.add_callback('x_ticklabel_size', self.update_x_ticklabel)
        self.state.add_callback('y_ticklabel_size', self.update_y_ticklabel)

        self.update_x_axislabel()
        self.update_y_axislabel()
        self.update_x_ticklabel()
        self.update_y_ticklabel()

    def _update_computation(self, message=None):
        """Hook for reporting computation progress; does nothing here."""
        pass

    def update_x_axislabel(self, *event):
        """Apply the state's x label text, weight and size, then redraw."""
        self.axes.set_xlabel(self.state.x_axislabel,
                             weight=self.state.x_axislabel_weight,
                             size=self.state.x_axislabel_size)
        self.redraw()

    def update_y_axislabel(self, *event):
        """Apply the state's y label text, weight and size, then redraw."""
        self.axes.set_ylabel(self.state.y_axislabel,
                             weight=self.state.y_axislabel_weight,
                             size=self.state.y_axislabel_size)
        self.redraw()

    def update_x_ticklabel(self, *event):
        """Apply the state's x tick-label size (offset text too), then redraw."""
        self.axes.tick_params(axis='x', labelsize=self.state.x_ticklabel_size)
        self.axes.xaxis.get_offset_text().set_fontsize(
            self.state.x_ticklabel_size)
        self.redraw()

    def update_y_ticklabel(self, *event):
        """Apply the state's y tick-label size (offset text too), then redraw."""
        self.axes.tick_params(axis='y', labelsize=self.state.y_ticklabel_size)
        self.axes.yaxis.get_offset_text().set_fontsize(
            self.state.y_ticklabel_size)
        self.redraw()

    def redraw(self):
        """Redraw the figure canvas immediately."""
        self.figure.canvas.draw()

    def update_x_log(self, *args):
        """Set the x axis scale to 'log' or 'linear' from ``state.x_log``."""
        self.axes.set_xscale('log' if self.state.x_log else 'linear')
        self.redraw()

    def update_y_log(self, *args):
        """Set the y axis scale to 'log' or 'linear' from ``state.y_log``."""
        self.axes.set_yscale('log' if self.state.y_log else 'linear')
        self.redraw()

    def update_aspect(self, aspect=None):
        """Apply ``state.aspect`` to the axes; the ``aspect`` argument is unused."""
        self.axes.set_aspect(self.state.aspect, adjustable='datalim')

    @avoid_circular
    def limits_from_mpl(self, *args):
        """Copy the current Matplotlib axis limits into the state.

        When the state currently holds ``np.datetime64`` limits for an axis,
        the Matplotlib values for that axis are converted with
        ``mpl_to_datetime64`` first.
        """

        # Hold back the four limit callbacks until all values are updated
        with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):

            if isinstance(self.state.x_min, np.datetime64):
                x_min, x_max = [
                    mpl_to_datetime64(x) for x in self.axes.get_xlim()
                ]
            else:
                x_min, x_max = self.axes.get_xlim()

            self.state.x_min, self.state.x_max = x_min, x_max

            if isinstance(self.state.y_min, np.datetime64):
                y_min, y_max = [
                    mpl_to_datetime64(y) for y in self.axes.get_ylim()
                ]
            else:
                y_min, y_max = self.axes.get_ylim()

            self.state.y_min, self.state.y_max = y_min, y_max

    @avoid_circular
    def limits_to_mpl(self, *args):
        """Push the state limits to the Matplotlib axes and redraw.

        An axis is only updated when both of its limits are set. On a log
        axis, non-positive limits are replaced: ``(0.1, 1)`` if the upper
        limit is <= 0, otherwise a lower limit of ``max / 10``.
        """

        if self.state.x_min is not None and self.state.x_max is not None:
            x_min, x_max = self.state.x_min, self.state.x_max
            if self.state.x_log:
                if self.state.x_max <= 0:
                    x_min, x_max = 0.1, 1
                elif self.state.x_min <= 0:
                    x_min = x_max / 10
            self.axes.set_xlim(x_min, x_max)

        if self.state.y_min is not None and self.state.y_max is not None:
            y_min, y_max = self.state.y_min, self.state.y_max
            if self.state.y_log:
                if self.state.y_max <= 0:
                    y_min, y_max = 0.1, 1
                elif self.state.y_min <= 0:
                    y_min = y_max / 10
            self.axes.set_ylim(y_min, y_max)

        if self.state.aspect == 'equal':

            # FIXME: for a reason I don't quite understand, dataLim doesn't
            # get updated immediately here, which means that there are then
            # issues in the first draw of the image (the limits are such that
            # only part of the image is shown). We just set dataLim manually
            # to avoid this issue.
            self.axes.dataLim.intervalx = self.axes.get_xlim()
            self.axes.dataLim.intervaly = self.axes.get_ylim()

            # We then force the aspect to be computed straight away
            self.axes.apply_aspect()

            # And propagate any changes back to the state since we have the
            # @avoid_circular decorator
            with delay_callback(self.state, 'x_min', 'x_max', 'y_min',
                                'y_max'):
                # TODO: fix case with datetime64 here
                self.state.x_min, self.state.x_max = self.axes.get_xlim()
                self.state.y_min, self.state.y_max = self.axes.get_ylim()

        self.axes.figure.canvas.draw()

    def _update_appearance_from_settings(self, message=None):
        """Re-apply the appearance settings to the axes and redraw."""
        update_appearance_from_settings(self.axes)
        self.redraw()

    def get_layer_artist(self, cls, layer=None, layer_state=None):
        """Return ``cls`` instantiated with this viewer's axes and state."""
        return cls(self.axes, self.state, layer=layer, layer_state=layer_state)

    def apply_roi(self, roi, override_mode=None):
        """ This method must be implemented by subclasses """
        raise NotImplementedError

    def _script_header(self):
        """Return ``(import_lines, header_text)`` for an exported script."""
        state_dict = self.state.as_dict()
        return ['import matplotlib.pyplot as plt'
                ], SCRIPT_HEADER.format(**state_dict)

    def _script_footer(self):
        """Return ``(import_lines, footer_text)`` for an exported script.

        ``x_log_str`` / ``y_log_str`` are added to the state dictionary so
        the footer template can spell out the scale as 'log' or 'linear'.
        """
        state_dict = self.state.as_dict()
        state_dict['x_log_str'] = 'log' if self.state.x_log else 'linear'
        state_dict['y_log_str'] = 'log' if self.state.y_log else 'linear'
        return [], SCRIPT_FOOTER.format(**state_dict)