示例#1
0
def ontype(event):
    """Keyboard handler for the interactive sample/evolution figure.

    Keys:
      * ``enter`` -- take the sample currently plotted in white, pass it
        through ``evolve(sample, maxfunc, breed_crossover)`` and redraw the
        points in place.  Does nothing if no sample has been plotted yet.
      * ``i``     -- draw a new uniform random sample in [0, 1]^2 whose size
        is taken from the ``slider`` widget, and redraw the background image
        ``maxfunc`` evaluated on a 500x500 grid.

    Relies on names defined at module level elsewhere in the script:
    ``plt``, ``np``, ``slider``, ``evolve``, ``maxfunc`` and
    ``breed_crossover``.

    :param event: matplotlib key-press event; only ``event.key`` is read.
    """
    if event.key == 'enter':
        # The sample is plotted in white (color '1'); locate that artist so
        # its data can be read and then updated in place.
        for child in plt.gca().get_children():
            if hasattr(child, 'get_color') and child.get_color() == '1':
                # get_data() gives (xdata, ydata); transpose to (n, 2).
                sample = np.array(child.get_data()).T
                break
        else:
            # No sample plotted yet (e.g. 'i' was never pressed): previously
            # this fell through with `sample` unbound.
            return
        sample = evolve(sample, maxfunc, breed_crossover)
        child.set_data(sample[:, 0], sample[:, 1])

    if event.key == 'i':
        # Slider values are floats; numpy needs an integer sample size.
        n_points = int(slider.val)
        sample = np.random.uniform(size=(n_points, 2), low=0, high=1)
        bkg = maxfunc(np.mgrid[0:1:500j, 0:1:500j])  # background image
        plt.cla()
        plt.plot(sample[:, 0], sample[:, 1], 'o', color='1', mec='1')
        # `cm.spectral` was removed from matplotlib; `nipy_spectral` is the
        # replacement colormap.
        plt.imshow(bkg, extent=[0, 1, 0, 1], aspect='auto',
                   cmap=plt.cm.nipy_spectral)
        plt.gca().set_autoscale_on(False)

    plt.draw()


# Script entry point.  This guard must live at module level: nested inside
# `ontype` it never ran at start-up and would rebuild the widgets on every
# key press.
if __name__ == "__main__":
    # Axis for the population-size slider (`axisbg` was renamed `facecolor`).
    axpop = plt.axes([0.20, 0.05, 0.65, 0.03], facecolor='lightgoldenrodyellow')
    slider = plt.Slider(axpop, 'Population', 10, 2000, valinit=500)
    ax1 = plt.axes([0.1, 0.15, 0.85, 0.77])  # main axis to plot the results

    # Route key presses to `ontype`, then show the window.
    plt.gcf().canvas.mpl_connect('key_press_event', ontype)
    plt.show()

                            
示例#2
0
def makeSlider():
    """Create the time slider (and the axes holding it) unless it already exists.

    Works on the module-level ``plot`` object: ``plot.sliderAxes`` and
    ``plot.slider`` are created only while they still compare equal to ``[]``
    (presumably their "not yet created" placeholder -- confirm with caller).
    The slider spans ``data.t[0]`` .. ``data.t[-1]``, starts at ``data.t[0]``,
    and calls ``updatePlots`` whenever its value changes.
    """
    global data, plot

    if plot.sliderAxes == []:
        # `axisbg` was removed in matplotlib 2.2; `facecolor` replaces it.
        plot.sliderAxes = P.axes([0.2, 0.05, 0.6, 0.03],
                                 facecolor='lightgoldenrodyellow')
    if plot.slider == []:
        plot.slider = P.Slider(plot.sliderAxes, 'time', data.t[0], data.t[-1],
                               valinit=data.t[0])
        # Redraw the plots whenever the slider moves.
        plot.slider.on_changed(updatePlots)
        P.draw()
示例#3
0
ax2.set_title('Y(K)')
ax2.set_xlabel('capital')
ax2.set_ylabel('yield')
ax2.grid(True)
# ax2.set_yscale('log')
ax2.legend()

# One thin slider axis per parameter in `axNames`, stacked upwards from the
# bottom of the figure in steps of `deltaY`.
axes = []
leftSpace = 0.2
startY = 0.05
deltaY = 0.04
axcolor = '#E4AC9A'
for _ in axNames:
    axes.append(plt.axes([leftSpace, startY, 0.65, 0.03], facecolor=axcolor))
    startY += deltaY

# Create the sliders, keyed by parameter name.  Each one spans the
# (left, right) limits from `axLimits`, starts at the matching `axInitVal`
# entry, and calls `update` when moved.  `plt.Slider` is used so that the
# block relies only on the `plt` alias it already uses above.
sliderColor = '#755A57'
sliders = {}
for ax, axName, (leftB, rightB), initVal in zip(axes, axNames, axLimits,
                                                axInitVal):
    sliders[axName] = plt.Slider(ax,
                                 axName,
                                 leftB,
                                 rightB,
                                 valinit=initVal,
                                 color=sliderColor)
    sliders[axName].on_changed(update)

plt.show()
示例#4
0
def animate_interactive(data,
                        t=None,
                        dim_order=(0, 1, 2),
                        fps=10.0,
                        title=None,
                        xlabel='x',
                        ylabel='y',
                        font_size=24,
                        color_bar=0,
                        colorbar_label=None,
                        sloppy=True,
                        fancy=False,
                        range_min=None,
                        range_max=None,
                        extent=[-1, 1, -1, 1],
                        shade=False,
                        azdeg=0,
                        altdeg=65,
                        arrowsX=None,
                        arrowsY=None,
                        arrows_resX=10,
                        arrows_resY=10,
                        arrows_pivot='mid',
                        arrows_width=0.002,
                        arrows_scale=5,
                        arrows_color='black',
                        plot_arrows_grid=False,
                        movie_file=None,
                        bitrate=1800,
                        keep_images=False,
                        figsize=(8, 7),
                        dpi=300,
                        **kwimshow):
    """
    Assemble a 2D animation from a 3D array.

    call signature::

    animate_interactive(data, t=None, dim_order=(0, 1, 2),
                        fps=10.0, title=None, xlabel='x', ylabel='y',
                        font_size=24, color_bar=0, colorbar_label=None,
                        sloppy=True, fancy=False,
                        range_min=None, range_max=None, extent=[-1, 1, -1, 1],
                        shade=False, azdeg=0, altdeg=65,
                        arrowsX=None, arrowsY=None, arrows_resX=10, arrows_resY=10,
                        arrows_pivot='mid', arrows_width=0.002, arrows_scale=5,
                        arrows_color='black', plot_arrows_grid=False,
                        movie_file=None, bitrate=1800, keep_images=False,
                        figsize=(8, 7), dpi=300,
                        **kwimshow)

    Assemble a 2D animation from a 3D array. *data* has to be a 3D array of
    shape [nt, nx, ny] whose time index has the same length as *t*.
    A 4D array is also accepted. Its axes beyond those listed in *dim_order*
    (e.g. a trailing colour axis) keep their position.
    The time index of *data* as well as its x and y indices can be changed
    via *dim_order*.

    Keyword arguments:

    *t*:
      Time of each frame. Defaults to the frame indices 0, 1, ..., nt-1.

    *dim_order*:
      Ordering of the dimensions in the data array (t, x, y).

    *fps*:
      Frames per second of the animation. Must be strictly positive.

    *title*:
      Title of the plot.

    *xlabel*:
      Label of the x-axis.

    *ylabel*:
      Label of the y-axis.

    *font_size*:
      Font size of the title, x and y label.
      The size of the x- and y-ticks is 0.5*font_size and the colorbar ticks'
      font size is 0.5*font_size.

    *color_bar*: [ 0 | 1 ]
      Determines how the colorbar changes:
      (0 - no change; 1 - adapt to the extreme values of each frame).

    *colorbar_label*:
      Label of the color bar.

    *sloppy*: [ True | False ]
      If True the update of the plot lags one frame behind. This speeds up the
      plotting.

    *fancy*: [ True | False ]
      Use fancy font style.

    *range_min*, *range_max*:
      Range of the colortable. If None, the minimum/maximum of *data* is used.

    *extent*: [ None | (left, right, bottom, top) ]
      Limits for the axes (domain).

    *shade*: [ False | True ]
      If True plot a shaded relief instead of the usual colormap.
      Note that with this option cmap has to be specified like
      cmap = plt.cm.hot instead of cmap = 'hot'. Shading cannot
      be used with color_bar = 0; in that case color_bar is set to 1.

    *azdeg*, *altdeg*:
      Azimuth and altitude of the light source for the shading.

    *arrowsX*:
      Data containing the x-component of the arrows.

    *arrowsY*:
      Data containing the y-component of the arrows.

    *arrows_resX*, *arrows_resY*:
      Plot every arrows_resX-th / arrows_resY-th arrow in x and y.

    *arrows_pivot*: [ 'tail' | 'mid' | 'middle' | 'tip' ]
      The part of the arrow that is used as pivot point.

    *arrows_width*:
      Width of the arrows.

    *arrows_scale*:
      Scaling of the arrows.

    *arrows_color*:
      Color of the arrows.

    *plot_arrows_grid*: [ False | True ]
      If 'True' the grid where the arrows are aligned to is shown.

    *movie_file*: [ None | string ]
      The movie file where the animation should be saved to.
      If 'None' no movie file is written. Requires 'ffmpeg' to be installed.

    *bitrate*:
      Bitrate of the movie file. Set to higher value for higher quality.

    *keep_images*: [ False | True ]
      If 'True' the images for the movie creation are not deleted.

    *figsize*:
      Size of the figure in inches.

    *dpi*:
      Dots per inch of the frame.

    **kwimshow:
      Remaining arguments are identical to those of pylab.imshow. Refer to that help.

    Returns:
      (button_reverse, button_forward) in interactive mode. Keep a reference
      to them so the widgets stay responsive. None after writing a movie, and
      -1 if the input is invalid.
    """

    import time

    import numpy as np

    try:
        import thread
    except ImportError:
        import _thread as thread

    # We need to define these variables as globals, as they are being used
    # by various functions.

    global time_step, time_slider, pause
    global fig, axes, image, colorbar, arrows, manager, n_times, movie_files
    global rgb, plot_arrows

    if title is None:
        title = ''

    def plot_frame():
        """
        Plot the frame with index ``time_step``.
        """

        global time_step, axes, colorbar, arrows, manager, rgb

        # The time stamp goes into the title only for movie frames; in
        # interactive mode the slider shows it.
        if movie_file is not None:
            axes.set_title(title + r'$\quad$' +
                           r'$t={0:.4e}$'.format(t[time_step]),
                           fontsize=font_size)

        # Update the image data.
        if not shade:
            image.set_data(data[time_step, :, :])
        else:
            image.set_data(rgb[time_step, :, :, :])

        # color_bar == 1: rescale to this frame's extremes.  Setting the
        # limits on the mappable also refreshes the attached colorbar
        # (Colorbar.set_clim / draw_all are deprecated or removed in recent
        # matplotlib releases).
        if color_bar == 1:
            image.set_clim(vmin=data[time_step, :, :].min(),
                           vmax=data[time_step, :, :].max())

        # Update the arrows data.
        if plot_arrows:
            arrows.set_UVC(U=arrowsX[time_step, ::arrows_resX, ::arrows_resY],
                           V=arrowsY[time_step, ::arrows_resX, ::arrows_resY])

        if not sloppy or movie_file is not None:
            manager.canvas.draw()

    def play(thread_name):
        """
        Advance through the frames from ``time_step`` until the last frame
        or until ``pause`` is set.

        With *movie_file* set, every frame is rendered and saved as a PNG.
        Otherwise the time slider is stepped at roughly *fps* frames per
        second. *thread_name* is not used.
        """

        global time_step, time_slider, pause, fig, axes, n_times, movie_files

        frame_period = 1.0 / fps  # fps > 0 is checked before play() runs
        pause = False
        while time_step < n_times and not pause:
            # Write the image files for the movie.
            if movie_file is not None:
                plot_frame()
                frame_name = '{0}{1:06}.png'.format(movie_file, time_step)
                fig.savefig(frame_name, dpi=dpi)
                movie_files.append(frame_name)
            else:
                # time.clock() no longer exists (removed in Python 3.8);
                # perf_counter is the replacement.  Sleep for the rest of
                # the frame period instead of busy-waiting.
                frame_start = time.perf_counter()
                time_slider.set_val(t[time_step])
                remaining = frame_period - (time.perf_counter() - frame_start)
                if remaining > 0:
                    time.sleep(remaining)
            time_step += 1
        time_step -= 1

    def play_thread(event):
        """
        Call the play function as a separate thread (for GUI).
        """

        global pause

        if pause:
            try:
                thread.start_new_thread(play, ("play_thread", ))
            except thread.error:
                print("Error: unable to start play thread.")

    def pausing(event):
        """Request the playback loop to stop after the current frame."""
        global pause

        pause = True

    def reverse(event):
        """Step one frame back (not before the first frame)."""
        global time_step, time_slider

        time_step = max(time_step - 1, 0)
        # Plot the frame and update the time slider.
        time_slider.set_val(t[time_step])

    def forward(event):
        """Step one frame forward (not past the last frame)."""
        global time_step, time_slider

        time_step = min(time_step + 1, len(t) - 1)
        # Plot the frame and update the time slider.
        time_slider.set_val(t[time_step])

    pause = True
    plot_arrows = False

    # Check if the data has the right dimensions.
    if data.ndim != 3 and data.ndim != 4:
        print("Error: data dimensions are invalid: {0} instead of 3.".format(
            data.ndim))
        return -1

    # Transpose the data according to dim_order.  Axes not named in
    # dim_order (e.g. the trailing axis of 4D data) keep their position;
    # otherwise np.transpose rejects 4D data with a 3-element dim_order.
    full_order = tuple(dim_order) + tuple(range(len(dim_order), data.ndim))
    data = np.transpose(data, full_order)

    # Check if arrows should be plotted.
    if arrowsX is not None and arrowsY is not None:
        if isinstance(arrowsX, np.ndarray) and isinstance(arrowsY, np.ndarray):
            if arrowsX.ndim == 3:
                arrowsX = np.transpose(arrowsX, dim_order)
            if arrowsY.ndim == 3:
                arrowsY = np.transpose(arrowsY, dim_order)
                # Check if the dimensions of the arrow arrays match each other.
                if arrowsX.shape != arrowsY.shape:
                    print(
                        "Error: dimensions of arrowX do not match with dimensions of arrowY."
                    )
                    return -1
                plot_arrows = True
            else:
                print("Warning: arrowsY is not a 3D array; arrows are not plotted.")
        else:
            print("Warning: arrowsX and/or arrowsY are of invalid type.")

    # Default time axis: the frame indices.
    if t is None:
        t = np.arange(data.shape[0], dtype=float)

    # Check if time array has the right length.
    n_times = len(t)
    if n_times != data.shape[0]:
        print(
            "Error: length of time array does not match length of data array.")
        return -1
    if plot_arrows:
        if (n_times != arrowsX.shape[0]) or (n_times != arrowsY.shape[0]):
            print(
                "error: length of time array does not match length of arrows array."
            )
            return -1

    # fps must be strictly positive: play() waits 1/fps seconds per frame.
    if fps <= 0:
        print("Error: fps is not positive, fps = {0}.".format(fps))
        return -1

    # Import the plotting library only once the input has been validated.
    import pylab as plt

    # Determine the size of the data array.
    nX = data.shape[1]
    nY = data.shape[2]

    # Determine the colour range; an explicit 0 is a valid limit, so only
    # None falls back to the data extremes.
    if range_min is None:
        range_min = np.min(data)
    if range_max is None:
        range_max = np.max(data)

    # Setup the plot.
    if fancy:
        plt.rc('text', usetex=True)
        plt.rc('font', family='arial')
    else:
        plt.rc('text', usetex=False)
        plt.rc('font', family='sans')
    fig = plt.figure(figsize=figsize)
    if movie_file is not None:
        # Movie frames have no widgets below the image.
        axes = plt.axes([0.15, 0.1, .70, .85])
    else:
        # Leave room at the bottom for the buttons and the time slider.
        axes = plt.axes([0.1, 0.3, .80, .65])

    # Set up canvas of the plot.
    axes.set_title(title, fontsize=font_size)
    axes.set_xlabel(xlabel, fontsize=font_size)
    axes.set_ylabel(ylabel, fontsize=font_size)
    plt.xticks(fontsize=0.5 * font_size)
    plt.yticks(fontsize=0.5 * font_size)
    if shade:
        plane = np.zeros([nX, nY, 3])
    else:
        plane = np.zeros([nX, nY])

    # Apply shading: precompute the shaded image of every frame.
    if shade:
        from matplotlib.colors import LightSource

        ls = LightSource(azdeg=azdeg, altdeg=altdeg)
        # Shading can only be used with color_bar=1 at the moment.
        if color_bar == 0:
            color_bar = 1
        # Check if colormap is set, if not set it to 'copper'.
        if 'cmap' not in kwimshow:
            kwimshow['cmap'] = plt.cm.copper
        rgb = np.array([ls.shade(data[i, :, :], kwimshow['cmap'])
                        for i in range(data.shape[0])])

    # Calibrate the displayed colors for the data range.
    image = axes.imshow(plane,
                        vmin=range_min,
                        vmax=range_max,
                        origin='lower',
                        extent=extent,
                        **kwimshow)
    colorbar = fig.colorbar(image)
    # Older matplotlib versions render a None label as the text 'None'.
    if colorbar_label is not None:
        colorbar.set_label(colorbar_label, fontsize=font_size, labelpad=10)

    # Change the font size of the colorbar's ytickslabels.
    cbytick_obj = plt.getp(colorbar.ax.axes, 'yticklabels')
    plt.setp(cbytick_obj, fontsize=0.5 * font_size)

    # Plot the arrows.
    if plot_arrows:
        # Prepare the mesh grid where the arrows will be drawn.
        arrow_grid = np.meshgrid(
            np.arange(
                extent[0], extent[1],
                float(extent[1] - extent[0]) * arrows_resX /
                (data.shape[2] - 1)),
            np.arange(
                extent[2], extent[3],
                float(extent[3] - extent[2]) * arrows_resY /
                (data.shape[1] - 1)))
        arrows = axes.quiver(arrow_grid[0],
                             arrow_grid[1],
                             arrowsX[0, ::arrows_resX, ::arrows_resY],
                             arrowsY[0, ::arrows_resX, ::arrows_resY],
                             units='width',
                             pivot=arrows_pivot,
                             width=arrows_width,
                             scale=arrows_scale,
                             color=arrows_color)
        # Plot the grid for the arrows.
        if plot_arrows_grid:
            axes.plot(arrow_grid[0], arrow_grid[1], 'k.')

    # For real-time image display.
    if not sloppy or movie_file is not None:
        manager = plt.get_current_fig_manager()
        manager.show()

    time_step = 0
    if movie_file is not None:
        import os
        import subprocess

        movie_files = []

        # Start the animation.
        play('no_thread')

        # Write the movie file.  An argument list (no shell) keeps file names
        # with spaces or shell metacharacters intact.
        ffmpeg_command = ['ffmpeg', '-r', str(fps),
                          '-i', '{0}%6d.png'.format(movie_file),
                          '-vcodec', 'mpeg4', '-b:v', str(bitrate),
                          '-q:v', '0', '{0}.avi'.format(movie_file)]
        try:
            subprocess.call(ffmpeg_command)
        except OSError:
            print("Error: could not run ffmpeg; is it installed?")
        # Clean up the image files.
        if not keep_images:
            print("Cleaning up files.")
            for fname in movie_files:
                os.remove(fname)
        # No widgets exist in movie mode (previously this path crashed with
        # an UnboundLocalError on the final return).
        buttons = None
    else:
        # Set up the gui.
        plt.ion()
        plt.subplots_adjust(bottom=0.2)

        # Play/pause buttons (play_thread / pausing) are currently disabled;
        # only single-step navigation is offered.
        axes_reverse = plt.axes([0.1, 0.05, 0.3, 0.05])
        button_reverse = plt.Button(axes_reverse,
                                    'reverse',
                                    color='lightgoldenrodyellow',
                                    hovercolor='0.975')
        button_reverse.on_clicked(reverse)

        axes_forward = plt.axes([0.5, 0.05, 0.3, 0.05])
        button_forward = plt.Button(axes_forward,
                                    'forward',
                                    color='lightgoldenrodyellow',
                                    hovercolor='0.975')
        button_forward.on_clicked(forward)

        # Create the time slider.
        time_slider_axes = plt.axes([0.2, 0.12, 0.6, 0.03],
                                    facecolor='lightgoldenrodyellow')
        time_slider = plt.Slider(time_slider_axes,
                                 'time',
                                 t[0],
                                 t[-1],
                                 valinit=t[0])

        t_values = np.asarray(t)

        def update(val):
            """Show the frame whose time is closest to the slider value."""
            global time_step
            # argmin picks the lowest index on ties, like the former manual
            # search.  Unlike that search, values at or below t[0] now select
            # frame 0 instead of keeping a stale time_step.
            time_step = int(np.argmin(np.abs(t_values - val)))
            plot_frame()

        time_slider.on_changed(update)

        plt.show()
        buttons = (button_reverse, button_forward)

    print("done")

    return buttons
def animate_interactive(data,
                        t=[],
                        dimOrder=(0, 1, 2),
                        fps=10.0,
                        title='',
                        xlabel='x',
                        ylabel='y',
                        fontsize=24,
                        cBar=0,
                        sloppy=True,
                        rangeMin=[],
                        rangeMax=[],
                        extent=[-1, 1, -1, 1],
                        shade=False,
                        azdeg=0,
                        altdeg=65,
                        arrowsX=np.array(0),
                        arrowsY=np.array(0),
                        arrowsRes=10,
                        arrowsPivot='mid',
                        arrowsWidth=0.002,
                        arrowsScale=5,
                        arrowsColor='black',
                        plotArrowsGrid=False,
                        movieFile='',
                        bitrate=1800,
                        keepImages=False,
                        figsize=(8, 7),
                        dpi=None,
                        **kwimshow):
    """
    Assemble a 2D animation from a 3D array.

    call signature::
    
      animate_interactive(data, t = [], dimOrder = (0,1,2),
                        fps = 10.0, title = '', xlabel = 'x', ylabel = 'y',
                        fontsize = 24, cBar = 0, sloppy = True,
                        rangeMin = [], rangeMax = [], extent = [-1,1,-1,1],
                        shade = False, azdeg = 0, altdeg = 65,
                        arrowsX = np.array(0), arrowsY = np.array(0), arrowsRes = 10,
                        arrowsPivot = 'mid', arrowsWidth = 0.002, arrowsScale = 5,
                        arrowsColor = 'black', plotArrowsGrid = False,
                        movieFile = '', bitrate = 1800, keepImages = False,
                        figsize = (8, 7), dpi = None,
                        **kwimshow)
    
    Assemble a 2D animation from a 3D array. *data* has to be a 3D array who's
    time index has the same dimension as *t*. The time index of *data* as well
    as its x and y indices can be changed via *dimOrder*.
    
    Keyword arguments:
    
      *dimOrder*: [ (i,j,k) ]
        Ordering of the dimensions in the data array (t,x,y).
        
     *fps*:
       Frames per second of the animation.
       
     *title*:
       Title of the plot.
       
     *xlabel*:
       Label of the x-axis.
       
     *ylabel*:
       Label of the y-axis.
       
     *fontsize*:
       Font size of the title, x and y label.
       The size of the x- and y-ticks is 0.7*fontsize and the colorbar ticks's
       font size is 0.5*fontsize.
       
     *cBar*: [ 0 | 1 | 2 ]
       Determines how the colorbar changes:
       (0 - no cahnge; 1 - keep extreme values constant; 2 - change extreme values).
     
     *sloppy*: [ True | False ]
       If True the update of the plot lags one frame behind. This speeds up the
       plotting.
     
     *rangeMin*, *rangeMax*:
       Range of the colortable.
       
     *extent*: [ None | scalars (left, right, bottom, top) ]
       Data limits for the axes. The default assigns zero-based row,
       column indices to the *x*, *y* centers of the pixels.
       
     *shade*: [ False | True ]
       If True plot a shaded relief plot instead of the usual colormap.
       Note that with this option cmap has to be specified like
       cmap = plt.cm.hot instead of cmap = 'hot'. Shading cannot
       be used with the cBar = 0 option.
     
     *azdeg*, *altdeg*:
       Azimuth and altitude of the light source for the shading.
       
     *arrowsX*:
       Data containing the x-component of the arrows.
       
     *arrowsy*:
       Data containing the y-component of the arrows.
       
     *arrowsRes*:
       Plot every arrowRes arrow.
       
     *arrowsPivot*: [ 'tail' | 'middle' | 'tip' ]
       The part of the arrow that is at the grid point; the arrow rotates
       about this point.
       
     *arrowsWidth*:
       Width of the arrows.
       
     *arrowsScale*:
       Scaling of the arrows.
       
     *arrowsColor*:
       Color of the arrows.
       
     *plotArrowsGrid*: [ False | True ]
       If 'True' the grid where the arrows are aligned to is shown.
     
     *movieFile*: [ None | string ]
       The movie file where the animation should be saved to.
       If 'None' no movie file is written. Requires 'mencoder' to be installed.
     
     *bitrate*:
       Bitrate of the movie file. Set to higher value for higher quality.
       
     *keepImages*: [ False | True ]
       If 'True' the images for the movie creation are not deleted.
     
     *figsize*:
       Size of the figure in inches.
      
     *dpi*:
       Dots per inch of the frame.
     
     **kwimshow:
       Remaining arguments are identical to those of pylab.imshow. Refer to that help.
    """

    global tStep, sliderTime, pause

    # plot the current frame
    def plotFrame():
        global tStep, sliderTime

        if movieFile:
            ax.set_title(title + r'$\quad$' + r'$t={0}$'.format(t[tStep]),
                         fontsize=fontsize)

        if shade == False:
            image.set_data(data[tStep, :, :])
        else:
            image.set_data(rgb[tStep, :, :, :])

        if (cBar == 0):
            pass
        if (cBar == 1):
            colorbar.set_clim(vmin=data[tStep, :, :].min(),
                              vmax=data[tStep, :, :].max())
        if (cBar == 2):
            colorbar.set_clim(vmin=data[tStep, :, :].min(),
                              vmax=data[tStep, :, :].max())
            colorbar.update_bruteforce(data[tStep, :, :])

        if plotArrows:
            arrows.set_UVC(U=arrowsX[tStep, ::arrowsRes, ::arrowsRes],
                           V=arrowsY[tStep, ::arrowsRes, ::arrowsRes])

        if (sloppy == False) or (movieFile):
            manager.canvas.draw()

    # play the movie
    def play(threadName):
        global tStep, sliderTime, pause

        pause = False
        while (tStep < nT) & (pause == False):
            # write the image files for the movie
            if movieFile:
                plotFrame()
                frameName = movieFile + '%06d.png' % tStep
                fig.savefig(frameName, dpi=dpi)
                movieFiles.append(frameName)
            else:
                start = time.clock()
                # time slider
                sliderTime.set_val(t[tStep])
                # wait for the next frame (fps)
                while (time.clock() - start < 1.0 / fps):
                    pass  # do nothing
            tStep += 1
        tStep -= 1

    # call the play function as a separate thread (for GUI)
    def play_thread(event):
        global pause

        if pause == True:
            try:
                thread.start_new_thread(play, ("playThread", ))
            except:
                print "Error: unable to start play thread"

    def pausing(event):
        global pause

        pause = True

    def reverse(event):
        global tStep, sliderTime

        tStep -= 1
        if tStep < 0:
            tStep = 0
        # plot the frame and update the time slider
        sliderTime.set_val(t[tStep])

    def forward(event):
        global tStep, sliderTime

        tStep += 1
        if tStep > len(t) - 1:
            tStep = len(t) - 1
        # plot the frame and update the time slider
        sliderTime.set_val(t[tStep])

    pause = True
    plotArrows = False

    # check if the data has the right dimensions
    if (len(data.shape) != 3 and len(data.shape) != 4):
        print 'error: data dimensions are invalid: {0} instead of 3'.format(
            len(data.shape))
        return -1

    # transpose the data according to dimOrder
    unOrdered = data
    data = np.transpose(unOrdered, dimOrder)
    unOrdered = []

    # check if arrows should be plotted
    if len(arrowsX.shape) == 3:
        # transpose the data according to dimOrder
        unOrdered = arrowsX
        arrowsX = np.transpose(unOrdered, dimOrder)
        unOrdered = []
        if len(arrowsY.shape) == 3:
            # transpose the data according to dimOrder
            unOrdered = arrowsY
            arrowsY = np.transpose(unOrdered, dimOrder)
            unOrdered = []

            # check if the dimensions of the arrow arrays match each other
            if ((len(arrowsX[:, 0, 0]) != len(arrowsY[:, 0, 0]))
                    or (len(arrowsX[0, :, 0]) != len(arrowsY[0, :, 0]))
                    or (len(arrowsX[0, 0, :]) != len(arrowsY[0, 0, :]))):
                print 'error: dimensions of arrowX do not match with dimensions of arrowY'
                return -1
            else:
                plotArrows = True

    # check if time array has the right length
    nT = len(t)
    if (nT != len(data[:, 0, 0])):
        print 'error: length of time array doesn\'t match length of data array'
        return -1
        if plotArrows:
            if (nT != len(arrowX[:, 0, 0]) or nT != len(arrowX[:, 0, 0])):
                print 'error: length of time array doesn\'t match length of arrows array'
                return -1

    # check if fps is positive
    if (fps < 0.0):
        print 'error: fps is not positive, fps = {0}'.format(fps)
        return -1

    # determine the size of the array
    nX = len(data[0, :, 0])
    nY = len(data[0, 0, :])

    # determine the minimum and maximum values of the data set
    if not (rangeMin):
        rangeMin = np.min(data)
    if not (rangeMax):
        rangeMax = np.max(data)

    # setup the plot
    if movieFile:
        width = figsize[0]
        height = figsize[1]
        plt.rc("figure.subplot", bottom=0.15)
        plt.rc("figure.subplot", top=0.95)
        plt.rc("figure.subplot", right=0.95)
        plt.rc("figure.subplot", left=0.15)
        fig = plt.figure(figsize=figsize)
        ax = plt.axes([0.1, 0.1, .90, .85])
    else:
        width = figsize[0]
        height = figsize[1]
        plt.rc("figure.subplot", bottom=0.05)
        plt.rc("figure.subplot", top=0.95)
        plt.rc("figure.subplot", right=0.95)
        plt.rc("figure.subplot", left=0.15)
        fig = plt.figure(figsize=figsize)
        ax = plt.axes([0.1, 0.25, .85, .70])

    ax.set_title(title, fontsize=fontsize)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    plt.xticks(fontsize=0.7 * fontsize)
    plt.yticks(fontsize=0.7 * fontsize)
    if shade:
        plane = np.zeros((nX, nY, 3))
    else:
        plane = np.zeros((nX, nY))

    # apply shading if True
    if shade:
        ls = LightSource(azdeg=azdeg, altdeg=altdeg)
        rgb = []
        # shading can be only used with cBar = 1 or cBar = 2 at the moment
        if cBar == 0:
            cBar = 1
        # check if colormap is set, if not set it to 'copper'
        if kwimshow.has_key('cmap') == False:
            kwimshow['cmap'] = plt.cm.copper
        for i in range(len(data[:, 0, 0])):
            tmp = ls.shade(data[i, :, :], kwimshow['cmap'])
            rgb.append(tmp.tolist())
        rgb = np.array(rgb)
        tmp = []

    # calibrate the displayed colors for the data range
    image = ax.imshow(plane,
                      vmin=rangeMin,
                      vmax=rangeMax,
                      origin='lower',
                      extent=extent,
                      **kwimshow)
    colorbar = fig.colorbar(image)
    # change the font size of the colorbar's ytickslabels
    cbytick_obj = plt.getp(colorbar.ax.axes, 'yticklabels')
    plt.setp(cbytick_obj, fontsize=0.5 * fontsize)

    # plot the arrows
    # TODO: add some more options
    if plotArrows:
        # prepare the mash grid where the arrows will be drawn
        arrowGridX, arrowGridY = np.meshgrid(
            np.arange(
                extent[0], extent[1],
                float(extent[1] - extent[0]) * arrowsRes / len(data[0, :, 0])),
            np.arange(
                extent[2], extent[3],
                float(extent[3] - extent[2]) * arrowsRes / len(data[0, 0, :])))
        arrows = ax.quiver(arrowGridX,
                           arrowGridY,
                           arrowsX[0, ::arrowsRes, ::arrowsRes],
                           arrowsY[0, ::arrowsRes, ::arrowsRes],
                           units='width',
                           pivot=arrowsPivot,
                           width=arrowsWidth,
                           scale=arrowsScale,
                           color=arrowsColor)
        # plot the grid for the arrows
        if plotArrowsGrid == True:
            ax.plot(arrowGridX, arrowGridY, 'k.')

    # for real-time image display
    if (sloppy == False) or (movieFile):
        manager = plt.get_current_fig_manager()
        manager.show()

    tStep = 0
    if movieFile:
        movieFiles = []
        # start the animation
        play('noThread')

        # write the movie file
        mencodeCommand = "mencoder 'mf://" + movieFile + "*.png' -mf type=png:fps=" + np.str(
            fps) + " -ovc lavc -lavcopts vcodec=mpeg4:vhq:vbitrate=" + np.str(
                bitrate) + " -ffourcc MP4S -oac copy -o " + movieFile + ".mpg"
        os.system(mencodeCommand)
        # clean up the image files
        if (keepImages == False):
            print 'cleaning up files'
            for fname in movieFiles:
                os.remove(fname)

    else:
        # set up the gui
        plt.ion()

        axPlay = plt.axes([0.1, 0.05, 0.15, 0.05],
                          axisbg='lightgoldenrodyellow')
        buttonPlay = plt.Button(axPlay,
                                'play',
                                color='lightgoldenrodyellow',
                                hovercolor='0.975')
        buttonPlay.on_clicked(play_thread)
        axPause = plt.axes([0.3, 0.05, 0.15, 0.05],
                           axisbg='lightgoldenrodyellow')
        buttonPause = plt.Button(axPause,
                                 'pause',
                                 color='lightgoldenrodyellow',
                                 hovercolor='0.975')
        buttonPause.on_clicked(pausing)

        axReverse = plt.axes([0.5, 0.05, 0.15, 0.05],
                             axisbg='lightgoldenrodyellow')
        buttonReverse = plt.Button(axReverse,
                                   'reverse',
                                   color='lightgoldenrodyellow',
                                   hovercolor='0.975')
        buttonReverse.on_clicked(reverse)
        axForward = plt.axes([0.7, 0.05, 0.15, 0.05],
                             axisbg='lightgoldenrodyellow')
        buttonForward = plt.Button(axForward,
                                   'forward',
                                   color='lightgoldenrodyellow',
                                   hovercolor='0.975')
        buttonForward.on_clicked(forward)

        # create the time slider
        fig.subplots_adjust(bottom=0.2)
        sliderTimeAxes = plt.axes([0.2, 0.12, 0.6, 0.03],
                                  axisbg='lightgoldenrodyellow')
        sliderTime = plt.Slider(sliderTimeAxes,
                                'time',
                                t[0],
                                t[-1],
                                valinit=0.0)

        def update(val):
            global tStep
            # find the closest time step to the slider time value
            for i in range(len(t)):
                if t[i] < sliderTime.val:
                    tStep = i
            if (tStep != len(t) - 1):
                if (t[tStep + 1] - sliderTime.val) < (sliderTime.val -
                                                      t[tStep]):
                    tStep += 1
            plotFrame()

        sliderTime.on_changed(update)

        plt.show()

    print 'done'
示例#6
0
    def displayResults2(self):
        """Show the results in an interactive window with a time slider.

        ``self.results`` is a data cube indexed ``[cell, time, gene]``.  One
        line per gene is drawn across the cells for the time step selected
        with the slider (initially the last time step).  A check box per gene
        toggles the visibility of that gene's line, and the "Reset" button
        moves the slider back to its initial position.

        geneNames have to be unique!  The check-box callback identifies a
        gene by its label text, so duplicate names cannot be told apart.
        """
        import numpy as np
        import pylab
        from matplotlib.widgets import Slider, Button, CheckButtons

        # results is a data cube [cell, time, gene]
        miTime = 0
        maTime = self.results.shape[1] - 1
        # CheckButtons reports the clicked label as text, so keep the names as
        # strings; the same list is used for the labels and the reverse lookup.
        geneLabels = [str(name) for name in self.geneNames]

        pylab.subplots_adjust(left=0.25, bottom=0.25)
        selection = np.ones(len(geneLabels), dtype=bool)  # select all on start
        # plotting a 2-D array returns one Line2D per column, i.e. per gene
        plots = pylab.plot(self.results[:, maTime, :])

        axcolor = 'lightgoldenrodyellow'
        sliderAx = pylab.axes([0.15, 0.1, 0.65, 0.03])
        slider = Slider(ax=sliderAx,
                        label='time slider',
                        valmin=miTime,
                        valmax=maTime,
                        valinit=maTime)

        def update(val):
            # Slider values are floats; numpy only accepts integer indices.
            tIndex = int(round(slider.val))
            for i, line in enumerate(plots):
                line.set_ydata(self.results[:, tIndex, i])
                line.set_visible(selection[i])
            pylab.draw()

        slider.on_changed(update)

        resetax = pylab.axes([0.8, 0.025, 0.1, 0.04])
        button = Button(resetax, 'Reset', color=axcolor, hovercolor='0.975')

        def reset(event):
            # Slider.reset() restores valinit and fires on_changed -> update.
            slider.reset()

        button.on_clicked(reset)

        rax = pylab.axes([0.025, 0.5, 0.15, 0.15], facecolor=axcolor)
        checker = CheckButtons(rax, geneLabels, actives=list(selection))

        def selector(label):
            # Map the clicked label back to its gene column and toggle it.
            geneNr = geneLabels.index(label)
            selection[geneNr] = not selection[geneNr]
            update(slider.val)

        checker.on_clicked(selector)

        pylab.show()
示例#7
0
def plot_3d_array(data,
                  axis=0,
                  title='3d',
                  cmap='gray',
                  interpolation='nearest',
                  vmin=None,
                  vmax=None,
                  **kwargs):
    '''
    Plot one 2-D slice of a 3-D array with a slider that selects the slice.

    The slice in the middle of ``axis`` is shown first.  The slider label
    always shows the index of the slice that is actually displayed.

    input:
        - data: 3d numpy array containing the data
        - axis: axis that should be changeable by the slider (negative
          values count from the end, as in numpy)
        - title: title of the plot
        - cmap, interpolation, **kwargs: passed on to ``imshow``
        - vmin, vmax: colour limits; default to the minimum / maximum of the
          whole array, so every slice shares the same colour scale

    output:
        - ax: the axes holding the image.  The slider widget is stored on it
          as ``ax.frame_slider``.  Matplotlib keeps only weak references to
          widget callbacks, so the widget must stay referenced to keep working.

    raises:
        - ValueError if ``axis`` is not a valid axis of ``data``

    author: Mathias Marschner
    added: 30.10.2013
    '''
    from matplotlib.widgets import Slider

    if not -data.ndim <= axis < data.ndim:
        raise ValueError('axis %r is out of range for an array with %d '
                         'dimensions' % (axis, data.ndim))

    nFrames = data.shape[axis]
    startFrame = nFrames // 2  # integer index under Python 2 and 3

    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.title(title)

    if vmin is None:
        vmin = data.min()
    if vmax is None:
        vmax = data.max()

    cax = ax.imshow(np.take(data, startFrame, axis=axis),
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    interpolation=interpolation,
                    **kwargs)
    fig.colorbar(cax)

    sliderAx = fig.add_axes([0.1, 0.01, 0.8, 0.03],
                            facecolor='lightgoldenrodyellow')
    sframe = Slider(sliderAx,
                    '',
                    0,
                    nFrames - 1,
                    valinit=startFrame,
                    closedmin=True,
                    closedmax=True,
                    valfmt='%d')

    def update(val):
        # Slider values are floats; round to the nearest valid frame and use
        # an integer, since numpy rejects float indices.
        frame = int(np.around(np.clip(sframe.val, 0, nFrames - 1)))
        cax.set_data(np.take(data, frame, axis=axis))
        # valfmt '%d' truncates the float value; show the frame actually drawn.
        sframe.valtext.set_text('%d' % frame)
        fig.canvas.draw_idle()

    sframe.on_changed(update)
    # keep the widget alive after this function returns (see docstring)
    ax.frame_slider = sframe
    return ax