Пример #1
0
    def _save_graph_properties(self):
        """Store the required matplotlib properties of each line artist.

        For each of ``line2D``, ``line2Dfit``, ``line2Dresiduals`` and
        ``line2Dsld_profile`` that is truthy (e.g. not ``None``), every
        property name in the private ``__requiredgraphproperties`` iterable
        is read with ``artist.getp`` and written to
        ``self.graph_properties['<attr>_properties'][<name>]``.
        """
        def snapshot(line, slot):
            # Copy each required property of ``line`` into its slot dict.
            for prop in self.__requiredgraphproperties:
                self.graph_properties[slot][prop] = artist.getp(line, prop)

        for attr_name in ('line2D', 'line2Dfit', 'line2Dresiduals',
                          'line2Dsld_profile'):
            line = getattr(self, attr_name)
            if line:
                snapshot(line, attr_name + '_properties')
Пример #2
0
 def onExport(self, event):
     """Export the data of the first selected artist to the shell.

     Walks ``self._artists`` in order and stops at the first artist that is
     in the canvas selection. The entries of ``canvas.selection`` are called
     to obtain the artists (presumably weak references). That artist's x/y
     data are handed to ``self._export_shell`` under the name ``fig_val``.
     Nothing happens if no artist is selected.
     """
     from matplotlib.artist import getp

     canvas = event.GetEventObject()
     selected = [ref() for ref in canvas.selection]
     for art in self._artists:
         if art in selected:
             fig_val = {"xdata": getp(art, "xdata"),
                        "ydata": getp(art, "ydata")}
             text = '#Exporting data as fig_val[\'xdata\'], fig_val[\'ydata\']\"'
             self._export_shell(fig_val, 'fig_val', text)
             break
Пример #3
0
 def onExport(self, event):
     """Send the data of the first selected artist to the shell.

     The first entry of ``self._artists`` found in the canvas selection is
     used. Selection entries are called to obtain the artists (presumably
     weak references). Its x/y data are passed to ``self._export_shell`` as
     ``fig_val``. If no artist is selected, nothing is printed or exported.
     """
     from matplotlib.artist import getp

     canvas = event.GetEventObject()
     selected = [ref() for ref in canvas.selection]
     for chosen in self._artists:
         if chosen in selected:
             break
     else:
         # no selected artist: nothing to export
         return

     print("Exporting Data to Shell")
     fig_val = {"xdata": getp(chosen, "xdata"),
                "ydata": getp(chosen, "ydata")}
     text = '#Exporting data as fig_val[\'xdata\'], fig_val[\'ydata\']\"'
     self._export_shell(fig_val, 'fig_val', text)
Пример #4
0
    def onExport(self, event):
        """Export the data of the first selected artist to the shell.

        Walks ``self._artists`` in order and stops at the first artist that
        is in the canvas selection. The entries of ``canvas.selection`` are
        called to obtain the artists (presumably weak references). That
        artist's x/y data are handed to ``self._export_shell`` under the
        name ``fig_val``. Nothing happens if no artist is selected.
        """
        from matplotlib.artist import getp

        canvas = event.GetEventObject()
        selected = [ref() for ref in canvas.selection]
        for art in self._artists:
            if art in selected:
                fig_val = {
                    "xdata": getp(art, "xdata"),
                    "ydata": getp(art, "ydata")
                }
                text = '#Exporting data as fig_val[\'xdata\'], fig_val[\'ydata\']\"'
                self._export_shell(fig_val, 'fig_val', text)
                break
Пример #5
0
 def onExport(self, event):
     """Push the x/y data of the first selected artist into the shell.

     Looks for the first member of ``self._artists`` contained in the canvas
     selection. Selection entries are called to obtain the artists
     (presumably weak references). When found, a status line is printed and
     the data are exported via ``self._export_shell`` as ``fig_val``.
     Otherwise the method returns silently.
     """
     from matplotlib.artist import getp

     missing = object()
     selected = [entry() for entry in event.GetEventObject().selection]
     target = next((art for art in self._artists if art in selected), missing)
     if target is missing:
         return

     print("Exporting Data to Shell")
     xdata = getp(target, "xdata")
     ydata = getp(target, "ydata")
     fig_val = {"xdata": xdata, "ydata": ydata}
     text = '#Exporting data as fig_val[\'xdata\'], fig_val[\'ydata\']\"'
     self._export_shell(fig_val, 'fig_val', text)
Пример #6
0
    def onExport(self, event):
        """Export the data of the first selected artist to the app shell.

        The first entry of ``self._artists`` found in the canvas selection
        has its x/y data written to the shell via ``self.write2shell`` as
        ``fig_val``. Selection entries are called to obtain the artists
        (presumably weak references). A comment line describing the export
        is then written to the shell with stdout redirected. If no artist is
        selected, nothing is exported and no message is written.
        """
        from matplotlib.artist import getp

        shell = self.get_root_parent().app.shell
        canvas = event.GetEventObject()
        selected = [ref() for ref in canvas.selection]
        for art in self._artists:
            if art in selected:
                print("Exporting Data to Shell")
                fig_val = {"xdata": getp(art, "xdata"),
                           "ydata": getp(art, "ydata")}
                self.write2shell(fig_val, "fig_val")
                break
        else:
            # Nothing was exported, so don't claim that fig_val exists.
            return

        shell.redirectStdout(True)
        try:
            text = '#Exporting data as fig_val[\'xdata\'], fig_val[\'ydata\']\"'
            shell.writeOut(text)
        finally:
            # Always restore stdout, even if writing to the shell fails.
            shell.redirectStdout(False)
Пример #7
0
    def save_graph_properties(self):
        """Record the required graph properties of each plotted artist.

        For each of ``ax_data``, ``ax_fit``, ``ax_residuals`` and
        ``ax_sld_profile`` that is not ``None``, every name in the
        module-level ``_requiredgraphproperties`` is read with
        ``artist.getp``. The value is stored in the matching
        ``self['<name>_properties']`` mapping.
        """
        targets = (
            ("ax_data", "data_properties"),
            ("ax_fit", "fit_properties"),
            ("ax_residuals", "residuals_properties"),
            ("ax_sld_profile", "sld_profile_properties"),
        )
        for attr_name, store_key in targets:
            source = getattr(self, attr_name)
            if source is None:
                continue
            for key in _requiredgraphproperties:
                self[store_key][key] = artist.getp(source, key)
Пример #8
0
    def onExport(self, event):
        """Export the data of the first selected artist to the app shell.

        The first entry of ``self._artists`` found in the canvas selection
        has its x/y data written to the shell via ``self.write2shell`` as
        ``fig_val``. Selection entries are called to obtain the artists
        (presumably weak references). A comment line describing the export
        is then written to the shell with stdout redirected. If no artist is
        selected, nothing is exported and no message is written.
        """
        from matplotlib.artist import getp

        shell = self.get_root_parent().app.shell
        canvas = event.GetEventObject()
        selected = [ref() for ref in canvas.selection]
        for art in self._artists:
            if art in selected:
                print("Exporting Data to Shell")
                fig_val = {
                    "xdata": getp(art, "xdata"),
                    "ydata": getp(art, "ydata")
                }
                self.write2shell(fig_val, "fig_val")
                break
        else:
            # Nothing was exported, so don't claim that fig_val exists.
            return

        shell.redirectStdout(True)
        try:
            text = '#Exporting data as fig_val[\'xdata\'], fig_val[\'ydata\']\"'
            shell.writeOut(text)
        finally:
            # Always restore stdout, even if writing to the shell fails.
            shell.redirectStdout(False)
Пример #9
0
    def save_graph_properties(self):
        """Snapshot the styling of every line that currently exists.

        Lines that are falsy (e.g. ``None``) are skipped. For the others,
        each name in the module-level ``_requiredgraphproperties`` is
        fetched with ``artist.getp`` and stored in
        ``self['<line attr>_properties']``.
        """
        for line_attr in ('line2D', 'line2Dfit', 'line2Dresiduals',
                          'line2Dsld_profile'):
            line = getattr(self, line_attr)
            if not line:
                continue
            store_key = line_attr + '_properties'
            for key in _requiredgraphproperties:
                self[store_key][key] = artist.getp(line, key)
Пример #10
0
class SigMng(object):
    '''
    Wraps an ordered map of 'signals' for easy data/vector mgmt, playback and
    visualization.

    Entries map a file path to ``None`` until the signal is first accessed
    (``sigmng[key]``). At that point the wav file is loaded into a numpy
    array normalized by its bit depth. A single mpl figure is reused, and
    multiple signals are drawn as vertically stacked subplots. Attribute
    lookups that fail on this object are delegated to the underlying signal
    map.

    The main intent of this class is to avoid needless figure clutter, memory
    consumption, and boilerplate lines found in sig proc scripts.

    You can use this class interactively (IPython) as well as programmatically.
    Examples:
        >>> s = SigMng()
        >>> s.find_wavs('path/to/audio/files')
        ...
        found 4 files
        >>> s.show_corpus
        ...
    '''

    def __init__(self, *args):
        # ordered path -> signal map; values stay None until first access
        self._signals = Lict()
        # register any file names that were passed initially (loaded lazily)
        for path in args:
            self._signals[path] = None

        # mpl stateful objs
        self._fig = None
        self._mng = None
        self._axes_cache = Lict()  # signal key -> axes it was last drawn on
        self._arts = []

        # to be updated by external thread
        self._cur_sample = 0
        self._cur_sig = None

        # animation state
        self._frames_gen = None
        self._realtime_artists = []
        self._cursor = None
        # mpl stops an animation once its FuncAnimation object is garbage
        # collected, so the running one is kept here
        self._anim = None

        # animation settings
        # determines the resolution for animated features (i.e. window size)
        self._fps = 15

        # FIXME: make scr_dim impl more pythonic!
        self.w, self.h = scr_dim()
        # get the garbage collector rolling...
        gc.enable()

    def __getattr__(self, attr):
        '''Delegate unknown attributes to the underlying signal map.

        Only called after normal attribute lookup has failed.
        '''
        try:
            return self.__dict__[attr]
        except KeyError:
            pass
        try:
            signals = self.__dict__['_signals']
        except KeyError:
            # not initialized yet (e.g. copy/unpickling): reading
            # self._signals here would recurse forever
            raise AttributeError(attr) from None
        return getattr(signals, attr)

    def __contains__(self, key):
        '''Membership test against the signal map.

        Defined explicitly so ``key in sigmng`` does not fall back to
        iterating ``__getitem__(0), __getitem__(1), ...`` (which would load
        every signal).
        '''
        return key in self._signals

    def __setitem__(self, *args):
        return self._signals.__setitem__(*args)

    def __getitem__(self, key):
        '''Return the signal for ``key``, loading it on first access.'''
        if not isinstance(self._signals[key], np.ndarray):
            # key is registered but its signal has not been loaded yet
            self._load_sig(key)
        return self._signals[key]

    # TODO: move to Signal class
    def _load_sig(self, path):
        '''Load the wav file for ``path`` (or entry index) into the map.

        Sets ``self.Fs`` (sample rate) and ``self.bd`` (bit depth) from the
        file and stores the signal divided by ``2**(bd - 1) - 1``.
        Returns the path key.

        Raises KeyError if ``path`` is not registered, and Exception if the
        file cannot be read.
        '''
        if isinstance(path, int):
            path = self._signals._get_key(path)
        if path not in self._signals:
            raise KeyError("no entry in the signal set for key '" + str(path) +
                           "'")
        # TODO: loading should be completed by 'Signal' class (i.e. format specific)
        try:
            print("loading wave file : ", os.path.basename(path))
            # read audio data and params
            sig, self.Fs, self.bd = wav2np(path)
            amax = 2**(self.bd - 1) - 1
            sig = sig / amax
            self._signals[path] = sig
            print("INFO |->", len(sig), "samples =",
                  len(sig) / self.Fs, "seconds @ ", self.Fs, " Hz")
            return path
        except Exception as err:
            raise Exception(
                "Failed to load wave file!\nEnsure that the wave file exists and is in LPCM format"
            ) from err

    def _prettify(self):
        '''pretty the figure in sensible ways'''
        # tighten up the margins
        self._fig.tight_layout(pad=1.03)

    def _sig2axis(self, sig_key=None):
        '''Return the rendered axis showing the signal ``sig_key``, or None.'''
        # FIXME: in Signal class we should assign the axes
        # on which a signal is drawn to avoid this hack?
        if not self._fig:
            return None
        sig = self[sig_key]
        for ax in self._axes_cache.values():
            for line in ax.get_lines():
                if np.array_equal(line.get_ydata(), sig):
                    return ax
        return None

    def getp(self, key, *args):
        '''Query mpl properties of the axes that ``key`` was plotted on.

        Extra positional args (e.g. a property name) are forwarded to
        ``matplotlib.artist.getp``.
        '''
        from matplotlib.artist import getp as artist_getp
        return artist_getp(self._axes_cache[key], *args)

    def _init_anim(self):
        '''
        Animation init hook: return the figure's axes as the baseline artists
        that won't change during animation (blitting).
        '''
        return tuple(self._fig.get_axes())

    def _do_fanim(self, fig=None):
        '''Run the function based animation once and return it.'''
        if not fig:
            fig = self._fig
        # NOTE(review): _set_cursor is not defined in this class; presumably
        # a module-level function elsewhere in the file.
        anim = animation.FuncAnimation(
            fig,
            _set_cursor,
            frames=self._frames_gen,
            init_func=self._init_anim,
            interval=1000 / self._fps,
            fargs=self._arts,
            blit=True,
            repeat=False)
        self._anim = anim
        return anim

    def sound(self, key, **kwargs):
        '''JUST play sound'''
        sig = self[key]
        # NOTE(review): 8e8 looks like it should be the sample rate (self.Fs?)
        # -- confirm against sound4python's signature.
        sound4python(sig, 8e8, **kwargs)

    def play(self, key):
        '''play sound + do mpl animation with a playback cursor'''
        self._cur_sig = key
        ax = self._sig2axis(key)
        if not ax:
            ax = self.plot(key)
        self._cursor = cursor(ax, 10)
        # set the frame iterator
        self._frames_gen = ift()
        self._do_fanim()

    def _audio_time_gen(self):
        '''Yield the audio time (s) for cursor placement on each frame.

        Starts at 0 and advances by ``1 / self._fps`` until the duration of
        the current signal (``len(signal) / self.Fs``) is exceeded.
        '''
        time_step = 1 / self._fps
        self._audio_time = 0  # this can be locked out
        total_time = len(self[self._cur_sig]) / self.Fs
        while self._audio_time <= total_time:
            yield self._audio_time
            self._audio_time += time_step

    def _show_corpus(self):
        '''pretty print the internal path list'''
        # TODO: show the vectors in the last column
        try:
            print_table(map(os.path.basename, self._signals.keys()))
        except Exception as err:
            raise ValueError(
                "no signal entries exist yet!?...add some first") from err

    # convenience attrs
    figure = property(lambda self: self._fig)
    mng = property(lambda self: self._mng)
    flist = property(lambda self: [f for f in self.keys()])
    show_corpus = property(lambda self: self._show_corpus())

    def get(self, key):
        '''Return the signal for ``key`` (loading it if needed).'''
        return self.__getitem__(key)

    def kill_mpl(self):
        '''Destroy the figure manager window (if any) and forget the figure.'''
        if self._mng is not None:
            self._mng.destroy()
        self._fig = None

    def close(self):
        '''Forget the current figure.

        NOTE(review): the figure is only dereferenced, not closed via
        plt.close.
        '''
        if self._fig:
            self._fig = None  # FIXME: is this necessary?

    def clear(self):
        '''Remove all signal entries.'''
        # FIXME: this only clears the map; it does not force the memory of
        # large arrays to be released.
        self._signals.clear()

    def fullscreen(self):
        '''convenience func to fullscreen if using a mpl gui fe'''
        if self._mng:
            self._mng.full_screen_toggle()
        else:
            print("no figure handle exists?")

    def add_path(self, p):
        '''
        Add a data file path to the SigMng.
        Can take a single path (str, bytes or os.PathLike) or an iterable of
        paths. Paths that are already registered are reported and skipped.

        Raises ValueError if a path does not exist.
        '''
        if not isinstance(p, (str, bytes, os.PathLike)):
            for path in p:
                self.add_path(path)
            return
        if os.path.exists(p):
            if p not in self:
                self[p] = None
            else:
                print(os.path.basename(p),
                      "is already in our path db, see grapher.SigPack.show()")
        else:
            raise ValueError("path string not valid?!")

    def plot(self, *args, **kwargs):
        '''
        Plot signals given as ints (entry indices), ranges of ints, or paths.
        Meant to be used as an interactive interface.

        Returns a single axis when one signal was plotted, otherwise a list
        of axes (empty if nothing was plotted).
        '''
        axes = [axis for axis, lines in self.itr_plot(args, **kwargs)]
        self._prettify()
        if len(axes) == 1:
            return axes[0]
        return axes

    def itr_plot(self, items, **kwargs):
        '''Lazily plot ``items`` and yield ``(axis, lines)`` per signal.

        ``items`` may contain path strings (registered if unknown), ints
        (entry indices) or ranges of ints. This is the programmatic
        interface to ``_plot``.
        '''
        paths = []
        for item in items:
            if isinstance(item, str):
                # register unseen paths; known paths are plotted as well
                if item not in self:
                    self.add_path(item)
                paths.append(item)
            elif isinstance(item, range):
                paths.extend(self._get_key(i) for i in item)
            elif isinstance(item, int):
                paths.append(self._get_key(item))  # delegate to the signal map

        yield from self._plot(paths, **kwargs)

    def _plot(self,
              keys_itr,
              start_time=0,
              time_on_x=True,
              singlefig=True,
              title=None):
        '''Plot generator - uses 'makes sense' figure / axes settings.

        keys_itr -> iterable over names in self.keys(). Each signal gets its
        own stacked subplot, titled with ``title`` or, when ``title`` is
        None, the file's basename. Yields ``(axis, lines)`` per signal.
        '''
        # FIXME: there is still a massive memory issue when making multiple
        # plot calls and I can't seem to manage it using the oo-interface or
        # pyplot (at least not without closing the figure all the time...lame)
        keys = keys_itr if isinstance(keys_itr, list) else list(keys_itr)

        if not singlefig or not (self._fig and self._mng.window):
            # pyplot is used to create the figure + manager (the pylab and
            # raw oo-api variants caused memory headaches)
            self._fig = plt.figure()
            self._mng = plt.get_current_fig_manager()
            self._mng.set_window_title('visig')
        else:
            # reuse the existing window
            self.figure.clear()
            gc.collect()

        # use half the screen height when only one signal is shown
        h = self.h // 2 if len(keys) < 2 else self.h
        self._mng.resize(self.w, h)

        font_style = {'size': 'small'}

        self._axes_cache.clear()
        for icount, key in enumerate(keys):
            # always set 'curr_sig' to last plotted
            self._cur_sig = key
            sig = self[key]
            slen = len(sig)

            # sample n is at start_time + n / Fs
            # NOTE(review): self.Fs is from the most recently loaded file;
            # mixed sample rates would get a wrong time axis.
            t = start_time + np.arange(slen) / self.Fs
            ax = self._fig.add_subplot(len(keys), 1, icount + 1)

            # maintain the key map to our figure's axes
            self._axes_cache[key] = ax

            lines = ax.plot(t, sig, figure=self._fig)
            ax.set_xlabel('Time (s)', fontdict=font_style)

            ax_title = os.path.basename(key) if title is None else title
            ax.set_title(ax_title, fontdict=font_style)

            ax.figure.canvas.draw()
            yield (ax, lines)

    def find_wavs(self, sdir):
        '''Register (and reset to unloaded) every ``*.wav`` file in ``sdir``.

        Prints each file found and the number of files found.
        '''
        found = 0
        for path in glob.iglob("{}/*.wav".format(sdir)):
            self[path] = None
            print("found file : ", path)
            found += 1
        print("found", found, "files")
Пример #11
0
 def GetSplineData(self, x=None):
     """Return a list of ``(xdata, ydata)`` pairs, one per artist.

     Pairs follow the order of ``self._artists``. The ``x`` argument is not
     used.
     """
     from matplotlib.artist import getp
     pairs = []
     for art in self._artists:
         xs = getp(art, "xdata")
         ys = getp(art, "ydata")
         pairs.append((xs, ys))
     return pairs
Пример #12
0
def OnMouseDrag(event):
    """Mouse-drag handler for the image canvas.

    Does nothing unless a drag is in progress (``MousePress`` holds
    ``(x0, y0, xpress, ypress)``) and the pointer is inside an axes.

    * ``SelectedTrack == -100``: the zoom gizmo is selected. The pointer
      position inside ``z_container`` is mapped to a new view centre
      (clamped with ``numpy.clip``). ``axis`` limits and ``img_offset`` are
      updated and the gizmo redrawn.
    * ``SelectedTrack == -1`` or no ``TrackedStars``: nothing to move.
    * Otherwise the selected track's rectangle and text are moved by the
      drag offset. They are found in ``axis.artists`` / ``axis.texts`` via
      their ``label`` attribute. The track's position for ``CurrentFile``
      and ``currPos`` are updated, and its trail polygon (if present) is
      set to ``trackedPos`` entries ``CurrentFile - 4 .. CurrentFile``.
    """
    global MousePress, img_limits, axis, img_offset
    if MousePress is None or event.inaxes is None:
        return

    x0, y0, xpress, ypress = MousePress
    dx = event.xdata - xpress
    dy = event.ydata - ypress

    # Check whether the zoom controls are selected
    if SelectedTrack == -100:
        if z_container is not None:
            w, h = getp(z_container, "width"), getp(z_container, "height")
            xy = getp(z_container, "xy")

            xrange = 0.5 * (img_limits[0][1] - img_limits[0][0])
            # y limits are subtracted in reverse order (inverted image y axis?)
            yrange = 0.5 * (img_limits[1][0] - img_limits[1][1])

            scale = 1. / zoom_factor
            xcenter = 2 * (event.xdata - xy[0]) * xrange / w
            ycenter = 2 * (event.ydata - xy[1]) * yrange / h

            xcenter = numpy.clip(xcenter, xrange * scale,
                                 img_limits[0][1] - xrange * scale)
            ycenter = numpy.clip(ycenter, yrange * scale,
                                 img_limits[1][0] - yrange * scale)
            img_offset = xcenter, ycenter
            axis.set_xlim([xcenter - xrange * scale, xcenter + xrange * scale])
            axis.set_ylim([ycenter + yrange * scale, ycenter - yrange * scale])
            UpdateZoomGizmo(scale, xrange, yrange)
            canvas.draw_idle()

        return  # Stop the function here

    # Fail conditions
    if SelectedTrack == -1 or len(TrackedStars) == 0:
        return

    tag = str(SelectedTrack)
    rect = next((obj for obj in axis.artists if obj.label == "Rect" + tag),
                None)
    text_art = next((obj for obj in axis.texts if obj.label == "Text" + tag),
                    None)
    if rect is not None and text_art is not None:
        track = TrackedStars[SelectedTrack]
        radius = track.star.radius
        rect.set_x(x0 + dx)
        rect.set_y(y0 + dy)
        text_art.set_x(x0 + dx + radius)
        text_art.set_y(y0 - radius + 6 + dy)
        pos = track.trackedPos[CurrentFile]
        pos[1] = int(y0 + dy + radius)
        pos[0] = int(x0 + dx + radius)
        track.currPos = list(reversed(pos))

    # The trail polygon is optional: skip it instead of raising
    # StopIteration out of the event handler.
    poly = next((obj for obj in axis.artists if obj.label == "Poly" + tag),
                None)
    if poly is not None:
        track = TrackedStars[SelectedTrack]
        first = max(CurrentFile - 4, 0)
        poly.set_xy(track.trackedPos[first:CurrentFile + 1])
    canvas.draw_idle()
Пример #13
0
def OnMouseDrag(event):
    """Mouse-drag handler for the image canvas.

    Does nothing unless a drag is in progress (``MousePress`` holds
    ``(x0, y0, xpress, ypress)``) and the pointer is inside an axes.

    * ``SelectedStar == -100``: the zoom gizmo is selected. The pointer
      position inside ``z_container`` is mapped to a new view centre
      (clamped with ``numpy.clip``). ``axis`` limits and ``img_offset`` are
      updated and the gizmo redrawn.
    * ``SelectedStar == -1`` or no ``Stars``: nothing to move.
    * Otherwise the selected star's rectangle, bound box (if present) and
      text are moved by the drag offset. They are found in ``axis.artists``
      / ``axis.texts`` via their ``label`` attribute. The star's
      ``location`` is updated, the canvas redrawn, and ``drag_displacement``
      accumulates the absolute pointer travel as
      ``(last_x, last_y, total_dx, total_dy)``.
    """
    global MousePress, Stars, drag_displacement
    global img_limits, axis, img_offset
    if MousePress is None or event.inaxes is None:
        return
    x0, y0, xpress, ypress = MousePress
    dx = event.xdata - xpress
    dy = event.ydata - ypress

    # Check whether the zoom controls are selected
    if SelectedStar == -100:
        if z_container is not None:
            w, h = getp(z_container, "width"), getp(z_container, "height")
            xy = getp(z_container, "xy")

            xrange = 0.5 * (img_limits[0][1] - img_limits[0][0])
            # y limits are subtracted in reverse order (inverted image y axis?)
            yrange = 0.5 * (img_limits[1][0] - img_limits[1][1])

            scale = 1. / zoom_factor
            xcenter = 2 * (event.xdata - xy[0]) * xrange / w
            ycenter = 2 * (event.ydata - xy[1]) * yrange / h

            xcenter = numpy.clip(xcenter, xrange * scale,
                                 img_limits[0][1] - xrange * scale)
            ycenter = numpy.clip(ycenter, yrange * scale,
                                 img_limits[1][0] - yrange * scale)
            img_offset = xcenter, ycenter
            axis.set_xlim([xcenter - xrange * scale, xcenter + xrange * scale])
            axis.set_ylim([ycenter + yrange * scale, ycenter - yrange * scale])
            UpdateZoomGizmo(scale, xrange, yrange)
            canvas.draw_idle()

        return  # Stop the function here

    # Fail conditions
    if SelectedStar == -1 or len(Stars) == 0:
        return

    tag = str(SelectedStar)
    rect = next((obj for obj in axis.artists if obj.label == "Rect" + tag),
                None)
    bound = next((obj for obj in axis.artists if obj.label == "Bound" + tag),
                 None)
    text_art = next((obj for obj in axis.texts if obj.label == "Text" + tag),
                    None)
    if rect is not None and text_art is not None:
        star = Stars[SelectedStar]
        rect.set_x(x0 + dx + star.bounds - star.radius)
        rect.set_y(y0 + dy + star.bounds - star.radius)
        # the bound artist may be absent; don't index into an empty result
        if bound is not None:
            bound.set_x(x0 + dx)
            bound.set_y(y0 + dy)
        text_art.set_x(x0 + dx + star.bounds)
        text_art.set_y(y0 - 6 + dy)
        star.location = (int(y0 + dy + star.bounds),
                         int(x0 + dx + star.bounds))
    canvas.draw_idle()

    sx = drag_displacement[2] + abs(event.xdata - drag_displacement[0])
    sy = drag_displacement[3] + abs(event.ydata - drag_displacement[1])
    drag_displacement = event.xdata, event.ydata, sx, sy
Пример #14
0
 def setAxes(self, pAxes):
     """Attach ``pAxes`` and load each registered property from it.

     For every name in ``self.props``, the current value is read from the
     axes with ``getp`` and passed to that entry's ``set`` method.
     """
     # lazy %-args: the message is only formatted if DEBUG is enabled
     logging.debug("%s.setAxes invoked", self.__class__)
     self.mAxes = pAxes
     for prop, holder in self.props.items():
         # fetch once and reuse for both the log line and the setter
         value = getp(self.mAxes, prop)
         logging.debug("\tsetting prop %s to value %r", prop, value)
         holder.set(value)
Пример #15
0
 def GetSplineData(self, x=None):
     """Collect ``(xdata, ydata)`` for each artist in ``self._artists``.

     The ``x`` argument is ignored.
     """
     from matplotlib.artist import getp

     def xy_of(art):
         # x first, then y, matching the tuple order returned
         return getp(art, "xdata"), getp(art, "ydata")

     return list(map(xy_of, self._artists))
Пример #16
0
    def draw_fig(self, **kwargs):
        """Build the figure layout: main panel plus optional slice profiles.

        The layout depends on which slice indices are set:

        * neither ``hslice_idx`` nor ``vslice_idx``: main panel ``self.ax0``;
        * ``hslice_idx`` only: ``self.axh`` (row profile) above ``self.ax0``,
          sharing its x axis;
        * ``vslice_idx`` only: ``self.axv`` (column profile) right of
          ``self.ax0``, sharing its y axis;
        * both: the two profiles around ``self.ax0`` on a 2x2 grid.

        ``self.main_panel(**kwargs)`` draws the main panel. ``self.text`` is
        written near its lower-left corner. If ``self.cbar`` is set, a
        horizontal colorbar for ``self.im`` is inset at the upper left.

        Keyword args:
            hslice_opts / vslice_opts: dicts overriding the default line
                style (``ls``, ``color``, ``lw``) of the horizontal /
                vertical slice. All kwargs are also forwarded to
                ``self.main_panel``.
        """
        slice_opts = {"ls": "-", "color": "firebrick", "lw": 0.5}  # defaults
        hslice_opts = slice_opts.copy()
        vslice_opts = slice_opts.copy()
        hslice_opts.update(kwargs.get("hslice_opts", {}))
        vslice_opts.update(kwargs.get("vslice_opts", {}))

        has_hslice = self.hslice_idx is not None
        has_vslice = self.vslice_idx is not None

        if not has_hslice and not has_vslice:
            gs = GridSpec(1, 1, height_ratios=[1], width_ratios=[1])
            self.ax0 = self.fig.add_subplot(gs[0])
            self.main_panel(**kwargs)

        elif has_hslice and not has_vslice:
            gs = GridSpec(2, 1, height_ratios=[1, 3], width_ratios=[1])
            self.ax0 = self.fig.add_subplot(gs[1, 0])
            self.axh = self.fig.add_subplot(gs[0, 0], sharex=self.ax0)
            self.main_panel(**kwargs)
            self._draw_hslice_profile(hslice_opts)
            self.fig.subplots_adjust(hspace=0.03)

        elif has_vslice and not has_hslice:
            gs = GridSpec(1, 2, height_ratios=[1], width_ratios=[3, 1])
            self.ax0 = self.fig.add_subplot(gs[0, 0])
            self.axv = self.fig.add_subplot(gs[0, 1], sharey=self.ax0)
            self.main_panel(**kwargs)
            self._draw_vslice_profile(vslice_opts)
            self.fig.subplots_adjust(wspace=0.03)

        else:
            gs = GridSpec(2, 2, height_ratios=[1, 3], width_ratios=[3, 1])
            self.ax0 = self.fig.add_subplot(gs[1, 0])
            self.axh = self.fig.add_subplot(gs[0, 0], sharex=self.ax0)
            self.axv = self.fig.add_subplot(gs[1, 1], sharey=self.ax0)
            self.main_panel(**kwargs)
            self._draw_hslice_profile(hslice_opts)
            self._draw_vslice_profile(vslice_opts)
            self.fig.subplots_adjust(wspace=0.03, hspace=0.03)

        self.ax0.text(0.02,
                      0.02,
                      self.text,
                      transform=self.ax0.transAxes,
                      color="firebrick")

        if self.cbar:
            cax = inset_axes(self.ax0, width="70%", height="3%", loc=2)
            cbar = self.fig.colorbar(self.im, cax=cax,
                                     orientation="horizontal")
            # the colorbar is labelled by a text on ax0 (not cbar.set_label)
            self.ax0.text(
                0.74,
                0.97,
                self.label["z"],
                transform=self.ax0.transAxes,
                color="firebrick",
            )
            cbar.ax.tick_params(color="firebrick", width=1.5, labelsize=8)
            cbxtick_obj = getp(cbar.ax.axes, "xticklabels")
            setp(cbxtick_obj, color="firebrick")

    def _draw_hslice_profile(self, hslice_opts):
        """Mark row ``self.hslice_idx`` on ``self.ax0`` and plot it on ``self.axh``."""
        row = self.hslice_idx
        v_value = self.v_axis[row]
        self.ax0.axhline(y=v_value, **hslice_opts)

        # Place the label a few samples in from the first column and a few
        # rows past the slice; clamp so slices near the end don't raise.
        n_h, n_v = len(self.h_axis), len(self.v_axis)
        if row < 0:
            row += n_v
        self.ax0.annotate(
            "{:.1f}".format(v_value),
            xy=(self.h_axis[min(3, n_h - 1)],
                self.v_axis[min(row + 3, n_v - 1)]),
            xycoords="data",
            color=hslice_opts["color"],
        )

        self.axh.set_xmargin(0)  # otherwise ax0 may have white margins
        self.axh.set_ylabel(self.label["z"])
        self.axh.plot(self.h_axis, self.data[self.hslice_idx, :],
                      **hslice_opts)
        self.axh.set_ylim(self.vmin, self.vmax)
        self.axh.xaxis.set_visible(False)  # x is shared with ax0
        # "Despine" the profile
        for sp in ("top", "bottom", "right"):
            self.axh.spines[sp].set_visible(False)

    def _draw_vslice_profile(self, vslice_opts):
        """Mark column ``self.vslice_idx`` on ``self.ax0`` and plot it on ``self.axv``."""
        col = self.vslice_idx
        h_value = self.h_axis[col]
        self.ax0.axvline(x=h_value, **vslice_opts)

        # Place the label 40 samples left of the slice and 40 rows from the
        # end of v_axis. Clamp so small indices don't wrap to the far edge
        # via negative indexing, and short axes don't raise.
        n_h, n_v = len(self.h_axis), len(self.v_axis)
        if col < 0:
            col += n_h
        self.ax0.annotate(
            "{:.1f}".format(h_value),
            xy=(self.h_axis[max(col - 40, 0)],
                self.v_axis[-min(40, n_v)]),
            xycoords="data",
            color=vslice_opts["color"],
            rotation="vertical",
        )

        self.axv.set_ymargin(0)
        self.axv.set_xlabel(self.label["z"])
        self.axv.plot(self.data[:, self.vslice_idx], self.v_axis,
                      **vslice_opts)
        self.axv.set_xlim(self.vmin, self.vmax)
        self.axv.yaxis.set_visible(False)  # y is shared with ax0
        # "Despine" the profile
        for sp in ("top", "left", "right"):
            self.axv.spines[sp].set_visible(False)
Пример #17
0
    def _draw_fig(self, **kwargs):
        """Lay out the figure and draw the main panel plus optional slice profiles.

        The main panel (``self.ax0``) is always created and filled by
        ``self._main_panel``. When ``self.hslice_idx`` is set, a profile panel
        ``self.axh`` sharing the x axis is placed above it. When
        ``self.vslice_idx`` is set, a profile panel ``self.axv`` sharing the y
        axis is placed to its right. Each active slice is marked on the main
        panel with a line and a "{:.1f}" value label.

        Keyword arguments
        -----------------
        hslice_opts, vslice_opts : dict, optional
            Line-style overrides merged onto the default slice style
            ``{"ls": "-", "color": "#ff7f0e", "lw": 1.5}``.
        All keyword arguments (including the two above) are forwarded to
        ``self._main_panel``.
        """
        base_style = {"ls": "-", "color": "#ff7f0e", "lw": 1.5}
        h_style = base_style.copy()
        h_style.update(kwargs.get("hslice_opts", {}))
        v_style = base_style.copy()
        v_style.update(kwargs.get("vslice_opts", {}))

        has_h = self.hslice_idx is not None
        has_v = self.vslice_idx is not None

        # --- grid layout: the main panel sits bottom-left, profiles above/right
        if has_h and has_v:
            grid = GridSpec(2, 2, height_ratios=[1, 3], width_ratios=[3, 1])
            self.ax0 = self.fig.add_subplot(grid[1, 0])
            self.axh = self.fig.add_subplot(grid[0, 0], sharex=self.ax0)
            self.axv = self.fig.add_subplot(grid[1, 1], sharey=self.ax0)
        elif has_h:
            grid = GridSpec(2, 1, height_ratios=[1, 3], width_ratios=[1])
            self.ax0 = self.fig.add_subplot(grid[1, 0])
            self.axh = self.fig.add_subplot(grid[0, 0], sharex=self.ax0)
        elif has_v:
            grid = GridSpec(1, 2, height_ratios=[1], width_ratios=[3, 1])
            self.ax0 = self.fig.add_subplot(grid[0, 0])
            self.axv = self.fig.add_subplot(grid[0, 1], sharey=self.ax0)
        else:
            grid = GridSpec(1, 1, height_ratios=[1], width_ratios=[1])
            self.ax0 = self.fig.add_subplot(grid[0])

        self._main_panel(**kwargs)

        # --- slice markers on the main panel
        if has_h:
            self.ax0.axhline(y=self.v_axis[self.hslice_idx], **h_style)
        if has_v:
            self.ax0.axvline(x=self.h_axis[self.vslice_idx], **v_style)

        # --- value labels, placed alongside the tick labels of the main panel
        if has_h:
            y_label_trans = transforms.blended_transform_factory(
                self.ax0.get_yticklabels()[0].get_transform(),
                self.ax0.transData)
            h_pos = self.v_axis[self.hslice_idx]
            self.ax0.text(0, h_pos, "{:.1f}".format(h_pos),
                          color=h_style["color"], transform=y_label_trans,
                          ha="right", va="center")
        if has_v:
            x_label_trans = transforms.blended_transform_factory(
                self.ax0.transData,
                self.ax0.get_xticklabels()[0].get_transform())
            v_pos = self.h_axis[self.vslice_idx]
            self.ax0.text(v_pos, 0, "{:.1f}".format(v_pos),
                          color=v_style["color"], transform=x_label_trans,
                          ha="center", va="top")

        # --- profile curves
        if has_h:
            # zero margin, otherwise ax0 (shared x) may get white margins
            self.axh.set_xmargin(0)
            self.axh.set_ylabel(self.label["z"])
            self.axh.plot(self.h_axis, self.data[self.hslice_idx, :],
                          **h_style)
            self.axh.set_ylim(self.vmin, self.vmax)
        if has_v:
            self.axv.set_ymargin(0)
            self.axv.set_xlabel(self.label["z"])
            self.axv.plot(self.data[:, self.vslice_idx], self.v_axis,
                          **v_style)
            self.axv.set_xlim(self.vmin, self.vmax)

        # --- hide the axis each profile shares with ax0 and "despine" it
        if has_h:
            self.axh.xaxis.set_visible(False)
            for side in ("top", "bottom", "right"):
                self.axh.spines[side].set_visible(False)
        if has_v:
            self.axv.yaxis.set_visible(False)
            for side in ("top", "left", "right"):
                self.axv.spines[side].set_visible(False)

        # --- tighten the gaps only between panels that actually exist
        spacing = {}
        if has_v:
            spacing["wspace"] = 0.03
        if has_h:
            spacing["hspace"] = 0.03
        if spacing:
            self.fig.subplots_adjust(**spacing)

        self.ax0.text(0.02, 0.02, self.text,
                      transform=self.ax0.transAxes, color="#ff7f0e")

        if not self.cbar:
            return

        # Inset horizontal colorbar at loc=9 (upper centre), ticks on top.
        cbar_ax = inset_axes(self.ax0, width="70%", height="3%", loc=9)
        colorbar = self.fig.colorbar(self.im, cax=cbar_ax,
                                     orientation="horizontal")
        colorbar.set_label(self.label["z"], color="#ff7f0e")
        colorbar.ax.xaxis.set_ticks_position("top")
        colorbar.ax.xaxis.set_label_position("top")
        colorbar.ax.tick_params(color="#ff7f0e", width=1.5, labelsize=8)
        setp(getp(colorbar.ax.axes, "xticklabels"), color="#ff7f0e")
Пример #18
0
def DaVincify(ax, mag=1.0,
              f1=50, f2=0.01, f3=15,
              bgcolor='#F2EDDC',
              xaxis_loc=None,
              yaxis_loc=None,
              xaxis_arrow='+',
              yaxis_arrow='+',
              ax_extend=0.1,
              expand_axes=False):
    """Make an axes look hand-drawn in a "Da Vinci sketch" style (in place).

    The function draws its own x/y axis lines, with optional arrow heads,
    and faint grid lines with jittered ends at every visible tick. Every
    line on the axes (those new ones included) is passed to
    ``DaVinci_line``, and the returned shading and foreground lines are
    added back. The axis labels and title are re-drawn as free text.
    All texts get the ``davinci.ttf`` font with a halo in ``bgcolor``,
    and the legend (if any) is restyled. Other plot elements are not
    modified.

    Parameters
    ----------
    ax : Axes instance
        the axes to be modified.
    mag : float
        distortion magnitude.  Not used by this function (presumably kept for
        parity with the xkcd variant).
    f1, f2, f3 : int, float, int
        filtering parameters.  Not used by this function (presumably kept for
        parity with the xkcd variant).
    bgcolor : str
        colour used for the text halo, the legend frame and, when
        ``expand_axes`` is True, the figure background.
    xaxis_loc, yaxis_loc : float
        The data coordinates at which to draw the x and y axes.  If not
        specified, they are drawn along the bottom and left of the
        original data limits.
    xaxis_arrow, yaxis_arrow : str
        where to draw arrows on the x/y axes.  Options are '+', '-', '+-', or ''
    ax_extend : float
        How far (fractionally) to extend the drawn axes beyond the original
        axes limits
    expand_axes : bool
        if True, then expand axes to fill the figure (useful if there is only
        a single axes in the figure)

    Returns
    -------
    ax : the same Axes instance, modified in place.
    """
    # Axes aspect ratio (height / width in display units), used to scale the
    # arrow heads.
    ext = ax.get_window_extent().extents
    aspect = (ext[3] - ext[1]) / (ext[2] - ext[0])

    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    xspan = xlim[1] - xlim[0]
    yspan = ylim[1] - ylim[0]

    xax_lim = (xlim[0] - ax_extend * xspan,
               xlim[1] + ax_extend * xspan)
    yax_lim = (ylim[0] - ax_extend * yspan,
               ylim[1] + ax_extend * yspan)

    if xaxis_loc is None:
        xaxis_loc = ylim[0]

    if yaxis_loc is None:
        yaxis_loc = xlim[0]

    # Draw axes
    xaxis = pl.Line2D([xax_lim[0], xax_lim[1]], [xaxis_loc, xaxis_loc],
                      linestyle='-', color="#362A1C", linewidth=0.7)
    yaxis = pl.Line2D([yaxis_loc, yaxis_loc], [yax_lim[0], yax_lim[1]],
                      linestyle='-', color="#362A1C", linewidth=0.7)

    # Label axes: re-draw the axis labels as free text at the axis ends.
    ax.text(xax_lim[1], xaxis_loc - 0.05 * yspan, ax.get_xlabel(),
            fontsize=14, ha='right', va='top')
    ax.text(yaxis_loc - 0.05 * xspan, yax_lim[1], ax.get_ylabel(),
            fontsize=14, ha='right', va='top')
    ax.set_xlabel('')
    ax.set_ylabel('')

    # Add title
    ax.text(0.5 * (xax_lim[1] + xax_lim[0]), yax_lim[1] * 1.08,
            ax.get_title(),
            ha='center', va='bottom', fontsize=26)
    ax.set_title('')

    # Faint grid lines at every tick inside the original limits, each end
    # randomly overshooting by up to 5% of the span.
    grids = []
    for yt in ax.get_yticks():
        if ylim[0] <= yt <= ylim[1]:
            grids.append(pl.Line2D(
                [xlim[0] - 0.05 * random() * xspan,
                 xlim[1] + 0.05 * random() * xspan],
                [yt, yt],
                linestyle='-', color="#111111", linewidth=0.1))

    for xt in ax.get_xticks():
        if xlim[0] <= xt <= xlim[1]:
            grids.append(pl.Line2D(
                [xt, xt],
                [ylim[0] - 0.05 * random() * yspan,
                 ylim[1] + 0.05 * random() * yspan],
                linestyle='-', color="#111111", linewidth=0.1))

    # Draw arrow-heads at the end of axes lines
    arr1 = 0.03 * np.array([-1, 0, -1])
    arr2 = 0.02 * np.array([-1, 0, 1])

    x, y = xaxis.get_data()
    if '+' in str(xaxis_arrow):
        ax.plot(x[-1] + arr1 * xspan * aspect,
                y[-1] + arr2 * yspan,
                color="#362A1C", lw=1)
    if '-' in str(xaxis_arrow):
        ax.plot(x[0] - arr1 * xspan * aspect,
                y[0] - arr2 * yspan,
                color="#362A1C", lw=1)

    x, y = yaxis.get_data()
    if '+' in str(yaxis_arrow):
        ax.plot(x[-1] + arr2 * xspan * aspect,
                y[-1] + arr1 * yspan,
                color="#362A1C", lw=1)
    if '-' in str(yaxis_arrow):
        ax.plot(x[0] - arr2 * xspan * aspect,
                y[0] - arr1 * yspan,
                color="#362A1C", lw=1)

    # Detach every line currently on the axes (including the arrow heads just
    # plotted) so it can be replaced by its hand-drawn version.
    Nlines = len(ax.lines)
    lines = grids + [xaxis, yaxis] + [ax.lines.pop(0) for _ in range(Nlines)]

    fg = []
    shade = []
    for line in lines:
        l = DaVinci_line(line, "b", xlim=xlim, ylim=ylim)
        fg.extend(l[0])
        shade.extend(l[1])

    # Shading first so the foreground strokes are drawn on top of it.
    for l in shade:
        ax.add_line(l)
    for l in fg:
        ax.add_line(l)

    # Change all the fonts to the Da Vinci font, with a halo in the
    # background colour so text stays legible over lines.
    prop = fm.FontProperties(fname='davinci.ttf', size=34)
    for text in ax.texts:
        text.set_fontproperties(prop)
        text.set_path_effects([patheffects.withStroke(linewidth=5,
                                                      foreground=bgcolor)])

    # modify legend
    leg = ax.get_legend()
    if leg is not None:
        leg.draggable(True)
        frame = leg.get_frame()
        frame.set_boxstyle('round')
        frame.set_lw(0)
        frame.set_facecolor(bgcolor)

        for child in leg.get_children():
            if isinstance(child, pl.Line2D):
                x, y = child.get_data()
                child.set_data(xkcd_line(x, y, mag=10, f1=100, f2=0.001))
                child.set_linewidth(2 * child.get_linewidth())
            if isinstance(child, pl.Text):
                child.set_fontproperties(prop)

    # Set the axis limits: the drawn axes plus a further 10% padding.
    ax.set_xlim(xax_lim[0] - 0.1 * xspan,
                xax_lim[1] + 0.1 * xspan)
    ax.set_ylim(yax_lim[0] - 0.1 * yspan,
                yax_lim[1] + 0.1 * yspan)

    if expand_axes:
        ax.figure.set_facecolor(bgcolor)
        ax.set_axis_off()
        ax.set_position([0.03, 0.03, 0.92, 0.92])

    return ax
    
    
# np.random.seed(0)

# ax = pl.axes()

# x = np.linspace(-0.01, 1.1, 50)
# def cannon(x, angle):
    # return x / np.tan(angle) - 0.5 * x**2/(np.sin(angle)**2)
    
# ax.plot(x, cannon(x,0.3)*2, lw=1)
# ax.plot(x, cannon(x,0.45)*2, lw=1)
# ax.plot(x, cannon(x,0.6)*2, lw=1)
# ax.plot(x, cannon(x,0.75)*2,'k', lw=1)
# ax.plot(x, cannon(x,0.9)*2,'b', lw=1)

# #ax.text(0.9, 0.4, "embarrassment")

# ax.set_title('Study of the trajectory of a cannon ball')
# ax.set_xlabel('Distance')
# ax.set_ylabel('Height')

# ax.legend(loc='upper right', fancybox = True)

# ax.set_xlim(-0.05, 1.2)
# ax.set_ylim(-.1, 1.0)
# pl.draw()

# #DaVincify the axes -- this operates in-place
# DaVincify(ax, xaxis_loc=0.0, yaxis_loc=0.0,
        # xaxis_arrow='+', yaxis_arrow='+',
        # expand_axes=True)
# pl.show()