Пример #1
0
 def export(self, path):
     """Export an mp4 video of the animation to ``path``. Requires FFmpeg.

     Every time index in ``range(self.maxTimeIndex)`` is selected through
     ``self.time_changed`` and grabbed as one frame of the video. Afterwards
     the time index that was current before the export is restored, even
     when writing the video fails part-way through.
     """
     current_time_index = self.currentTimeIndex  # index to jump back to
     ffmpeg_writer = FFMpegFileWriter()
     try:
         with ffmpeg_writer.saving(self.figure, path, dpi=100):
             for t_i in range(self.maxTimeIndex):
                 self.time_changed(t_i)
                 ffmpeg_writer.grab_frame()
     finally:
         # Always jump back so a failed export does not leave the display
         # stranded on whatever frame was being rendered.
         self.time_changed(current_time_index)
Пример #2
0
    def save_video(self,
                   save_path,
                   start_time=0,
                   end_time=50,
                   fps=5,
                   dpi=60,
                   figsize=(5, 8)):
        '''
        Save an animation of the board for times ``start_time`` to ``end_time``.

        Frame ``t`` shows ``self.update_board(time=t)`` for every ``t`` in
        ``np.arange(start_time, end_time)`` (``end_time`` excluded).

        Parameters
        ----------
        save_path : str
            Output path. Its extension selects the format: ``.gif`` is
            written with ImageMagick, ``.mp4`` with FFmpeg.
        start_time, end_time : int
            Half-open range of board times to render.
        fps : int
            Frames per second of the output.
        dpi : int
            Resolution used when rendering the frames.
        figsize : tuple
            Matplotlib figure size in inches.

        Raises
        ------
        Exception
            If the extension of ``save_path`` is neither ``.gif`` nor ``.mp4``.
            This is checked before any figure is created.
        '''
        extension = save_path[-4:].lower()
        if extension not in ('.gif', '.mp4'):
            error_msg = 'ERROR: unknown file type "'\
                        + save_path.split('.')[-1] + '".'
            raise Exception(error_msg)

        def update_frame(t):
            ax.imshow(self.update_board(time=t), interpolation='nearest')
            ax.grid(color='white')
            ax.set_axis_off()

        fig, ax = plt.subplots(figsize=figsize)
        animation = FuncAnimation(fig,
                                  update_frame,
                                  repeat=False,
                                  frames=np.arange(start_time, end_time))

        # Honour the caller's dpi; it used to be hard-coded to 60 here.
        if extension == '.gif':
            animation.save(save_path, dpi=dpi, fps=fps, writer='imagemagick')
        else:
            writer = FFMpegFileWriter(fps=fps)
            animation.save(save_path, dpi=dpi, writer=writer)
Пример #3
0
    def _plot_output(self, type, latpositions, longpositions, data, media_file):
        """Render photoreceptor outputs to a video, one frame every 100 steps.

        Each frame has three panels:

        - the input image (for ``type == 'video'``, the crop of it given by
          ``self.input_config.positions[i]``);
        - the input intensities mapped onto a hemisphere;
        - ``data[i]`` mapped onto a hemisphere.

        Parameters
        ----------
        type : str
            ``'image'`` or ``'video'``; selects how the first panel is drawn.
        latpositions, longpositions :
            Passed through unchanged to ``self.compute_colors``.
        data : sequence
            Per-step outputs, indexed by step; only every 100th is rendered.
        media_file : str
            Output video path. Its stem also prefixes the temporary frames.

        Raises
        ------
        ValueError
            If ``type`` is neither ``'image'`` nor ``'video'``.
        """
        # Validate first so a bad argument leaves no writer or frames behind.
        if type not in ('image', 'video'):
            raise ValueError('Invalid value for media type: {}'
                             ', valid values image, video'.format(type))

        fig = plt.figure(figsize=plt.figaspect(0.3))
        image = self.input_config.image
        impositions = self.input_config.positions
        intensities = self.input_config.intensities

        step = 100

        #TODO keep DRY
        U, V = np.mgrid[0:np.pi/2:complex(0, self.PARALLELS),
                        0:2*np.pi:complex(0, self.MERIDIANS)]
        X = np.cos(V)*np.sin(U)
        Y = np.sin(V)*np.sin(U)
        Z = np.cos(U)

        writer = FFMpegFileWriter(fps=5, codec='mpeg4')
        # saving() performs setup() on entry and finish() on exit. The frame
        # prefix (4th positional setup() argument) is derived from the
        # output name so concurrent runs don't clobber each other's frames.
        with writer.saving(fig, media_file, 80,
                           os.path.splitext(media_file)[0] + '_'):
            writer.frame_format = 'png'
            for i in range(0, len(data), step):
                # plt.hold(False) was removed in Matplotlib 3.0; start every
                # frame from an empty figure instead of stacking new axes.
                fig.clf()

                ax1 = fig.add_subplot(1, 3, 1)
                ax1.set_title('Image')
                if type == 'image':
                    ax1.imshow(image, cmap=cm.Greys_r)
                else:
                    # TODO show red rectangle
                    xind = impositions[i, 0:2]
                    yind = impositions[i, 2:4]
                    ax1.imshow(image[int(xind[0]):int(xind[1]),
                                     int(yind[0]):int(yind[1])],
                               cmap=cm.Greys_r)

                colors = self.compute_colors(U, V, latpositions,
                                             longpositions, intensities[i])
                ax2 = fig.add_subplot(1, 3, 2, projection='3d')
                ax2.plot_surface(X, Y, Z, rstride=1, cstride=1,
                                 facecolors=cm.gray(colors),
                                 antialiased=False, shade=False)
                ax2.set_title('Intensities')

                colors = self.compute_colors(U, V, latpositions,
                                             longpositions, data[i])
                ax3 = fig.add_subplot(1, 3, 3, projection='3d')
                ax3.plot_surface(X, Y, Z, rstride=1, cstride=1,
                                 facecolors=cm.gray(colors),
                                 antialiased=False, shade=False)
                ax3.set_title('Photoreceptor outputs')

                fig.canvas.draw()
                writer.grab_frame()
Пример #4
0
    def test_animate_events(self):
        """Smoke test: an ``animate_events`` animation saves to mp4 via FFmpeg."""
        fps = 10
        # Event labels and the times at which each one appears.
        events = ["hello", "world", "", "!", ""]
        times = [0., 1.0, 1.5, 2.0, 2.5]
        out_path = os.path.join(write_dir, "animate_test.mp4")
        animation = animate_events(events, times, fps)

        from matplotlib.animation import FFMpegFileWriter
        savefig_options = dict(transparent=True)
        mp4_writer = FFMpegFileWriter(fps=fps)
        animation.save(out_path,
                       dpi=300,
                       writer=mp4_writer,
                       savefig_kwargs=savefig_options)
 def output(self, anim, plt, dataset, output):
     """
     Emit the finished animation in the requested form.

     ``output == "mp4"`` saves ``covid_<dataset>.mp4`` with FFmpeg (10 fps,
     bitrate 1800), ``output == "gif"`` saves ``covid_<dataset>.gif`` with
     ImageMagick, and any other value displays it through ``plt.show()``.
     ``dataset`` is lower-cased for the file name.
     """
     logger.info("Writing output...")
     filename = "covid_{}.{}".format(dataset.lower(), output)
     if output not in ("mp4", "gif"):
         # Not a file format we write: show interactively instead.
         plt.show()
         return
     if output == "mp4":
         writer = FFMpegFileWriter(fps=10, bitrate=1800)
     else:
         writer = ImageMagickFileWriter()
     anim.save(filename, writer=writer)
Пример #6
0
def savemp4(images, videofile, step=10):
    '''
    Write every ``step``-th image of ``images`` as a frame of a video file.

    parameters:
        images: a numpy array where each row corresponds to an image
        videofile: file to store video; its stem is also used as the
                   prefix of the temporary PNG frames, so concurrent calls
                   writing different files do not clobber each other
        step: only every ``step``-th image is written and the rest are
              ignored (e.g. if step is 10 and there are 50 images, the
              images at indices 0, 10, 20, 30 and 40 are stored)
    '''
    import os

    from matplotlib import cm
    # Only FFMpegFileWriter is needed. AVConvFileWriter was never used and no
    # longer exists in Matplotlib >= 3.3, where importing it fails.
    from matplotlib.animation import FFMpegFileWriter
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=plt.figaspect(1.0))

    # One colour scale shared by all frames, computed once instead of
    # rescanning the whole array for every frame (empty input: no frames).
    vmin = images.min() if len(images) else None
    vmax = images.max() if len(images) else None

    writer = FFMpegFileWriter(fps=5, codec='mpeg4')
    # saving() runs setup() on entry and finish() on exit (in current
    # Matplotlib also when a frame raises). The 4th positional argument is
    # setup()'s frame_prefix.
    with writer.saving(fig, videofile, 80,
                       os.path.splitext(videofile)[0] + '_'):
        writer.frame_format = 'png'

        ax = fig.add_subplot(111)
        fig.subplots_adjust(left=0, right=1.0)
        for i in range(0, len(images), step):
            # plt.hold(False) was removed in Matplotlib 3.0; clear the axes
            # explicitly so each frame shows only the current image.
            ax.cla()
            ax.imshow(images[i], cmap=cm.Greys_r, vmin=vmin, vmax=vmax)
            fig.canvas.draw()
            writer.grab_frame()
Пример #7
0
def savemp4(images, videofile, step=10):
    '''
    Write every ``step``-th image of ``images`` as a frame of a video file.

    parameters:
        images: a numpy array where each row corresponds to an image
        videofile: file to store video; its stem is also used as the
                   prefix of the temporary PNG frames, so concurrent calls
                   writing different files do not clobber each other
        step: only every ``step``-th image is written and the rest are
              ignored (e.g. if step is 10 and there are 50 images, the
              images at indices 0, 10, 20, 30 and 40 are stored)
    '''
    import os

    from matplotlib import cm
    # Only FFMpegFileWriter is needed. AVConvFileWriter was never used and no
    # longer exists in Matplotlib >= 3.3, where importing it fails.
    from matplotlib.animation import FFMpegFileWriter
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=plt.figaspect(1.0))

    # One colour scale shared by all frames, computed once instead of
    # rescanning the whole array for every frame (empty input: no frames).
    vmin = images.min() if len(images) else None
    vmax = images.max() if len(images) else None

    writer = FFMpegFileWriter(fps=5, codec='mpeg4')
    # saving() runs setup() on entry and finish() on exit (in current
    # Matplotlib also when a frame raises). The 4th positional argument is
    # setup()'s frame_prefix.
    with writer.saving(fig, videofile, 80,
                       os.path.splitext(videofile)[0] + '_'):
        writer.frame_format = 'png'

        ax = fig.add_subplot(111)
        fig.subplots_adjust(left=0, right=1.0)
        for i in range(0, len(images), step):
            # plt.hold(False) was removed in Matplotlib 3.0; clear the axes
            # explicitly so each frame shows only the current image.
            ax.cla()
            ax.imshow(images[i], cmap=cm.Greys_r, vmin=vmin, vmax=vmax)
            fig.canvas.draw()
            writer.grab_frame()
Пример #8
0
class visualizer(object):
    """
    Visualize the output produced by LPU models.

    Example
    -------
    import neurokernel.LPU.utils.visualizer as vis
    V = vis.visualizer()
    config1 = {}
    config1['type'] = 'image'
    config1['shape'] = [32,24]
    config1['clim'] = [-0.6,0.5]
    config2 = config1.copy()
    config2['clim'] = [-0.55,-0.45]
    V.add_LPU('lamina_output.h5', 'lamina.gexf.gz','lamina')
    V.add_plot(config1, 'lamina', 'R1')
    V.add_plot(config1, 'lamina', 'L1')
    V._update_interval = 50
    V.out_filename = 'test.avi'
    V.run()
    """

    def __init__(self):
        self._xlim = [0,1]            # x-limits of waveform plots
        self._ylim = [-1,1]           # y-limits of waveform plots
        self._imlim = [-1, 1]         # colour limits of image plots
        self._update_interval = 50    # time steps advanced per update()
        self._out_file = None         # video file name; falsy -> show window
        self._fps = 5
        self._codec = 'libtheora'
        self._config = OrderedDict()  # LPU id -> list of plot config dicts
        self._rows = 0
        self._cols = 0
        self._figsize = (16,9)
        self._fontsize = 18
        self._t = 1                   # current time index
        self._dt = 1                  # time step used to scale time axes
        self._data = {}               # LPU id -> array indexed [neuron, time]
        self._graph = {}              # LPU id -> graph read from the gexf file
        self._maxt = None             # smallest number of time steps seen
        self._title = None            # figure suptitle

    def add_LPU(self, data_file, gexf_file=None, LPU=None, win=None):
        '''
        Add data associated with a specific LPU to a visualization.

        To add a plot containing neurons from a particular LPU,
        the LPU needs to be added to the visualization using this
        function. Note that outputs from multiple LPUs can
        be visualized using the same visualizer object.

        Parameters
        ----------
        data_file: str
             Location of the h5 file generated by neurokernel
             containing the output of the LPU

        gexf_file: str
            Location of the gexf file describing the LPU.
            If not specified, it will be assumed that the h5 file
            contains input.

        LPU: str
            Name of the LPU. Will be used as identifier to add plots.
            Input data is always stored as ``'input_<LPU>'``. If omitted,
            ``n`` (the number of data sets already added) is used instead:
            ``'input_<n>'`` for input data, otherwise the integer ``n``.

        win: optional
            Index or slice applied to the time axis of the loaded data.
        '''
        if gexf_file:
            if not LPU:
                LPU = len(self._data)
            # Store the graph under the final identifier (it used to be
            # stored before the default was assigned, i.e. under None).
            self._graph[LPU] = nx.read_gexf(gexf_file)
        elif LPU:
            LPU = 'input_' + str(LPU)
        else:
            LPU = 'input_' + str(len(self._data))
        self._data[LPU] = np.transpose(sio.read_array(data_file))
        if win is not None:
            self._data[LPU] = self._data[LPU][:,win]
        if self._maxt:
            self._maxt = min(self._maxt, self._data[LPU].shape[1])
        else:
            self._maxt = self._data[LPU].shape[1]

    def run(self, final_frame_name=None, dpi=300):
        '''
        Start the visualization process.

        Draws the first frame, then calls update() once per
        ``update_interval`` steps while the time index stays within the
        data. If ``final_frame_name`` is given, the last frame is saved
        there at ``dpi``. If ``out_filename`` is set, the video writer is
        closed at the end.
        '''
        self._initialize()
        self._t = self._update_interval+1
        # update() reads time index self._t, which is one more than the loop
        # value; stopping at _maxt - 1 keeps it below _maxt (the old bound
        # allowed self._t == _maxt, one past the end of the data).
        for _ in range(self._update_interval, self._maxt - 1,
                       self._update_interval):
            self.update()
        if final_frame_name is not None:
            self.f.savefig(final_frame_name, dpi=dpi)
        if self.out_filename:
            self.close()

    def _set_wrapper(self, obj, name, value):
        '''
        Best-effort call of ``obj.set_<name>(value)``.

        First tries with bold styling at ``self._fontsize``, then retries
        with the bare value. Missing setters and setters that reject the
        value are silently ignored.
        '''
        func = getattr(obj, 'set_' + name.lower(), None)
        if not func:
            return
        try:
            func(value, fontsize=self._fontsize, weight='bold')
        except Exception:
            try:
                func(value)
            except Exception:
                pass

    @staticmethod
    def _hsv_rgb(X, Y):
        '''Map vector components to RGB: hue from direction, value from magnitude.'''
        V = (X**2 + Y**2)**0.5
        H = (np.arctan2(X,Y)+np.pi)/(2*np.pi)
        S = np.ones_like(V)
        return hsv_to_rgb(np.dstack((H,S,V)))

    def _initialize(self):
        '''
        Create the figure and draw the first frame (time index 0).

        For every configured plot this method:

        - resolves the plot type to its integer code;
        - fills in a default image shape where needed;
        - creates the matplotlib artist and stores it as ``config['handle']``;
        - forwards extra config entries to matching ``set_<key>`` methods.

        Finally it either opens the video writer (``out_filename`` set) or
        shows the figure.

        Raises
        ------
        ValueError
            If a config names an unsupported plot type.
        '''
        # Plot type names and the integer codes update() dispatches on.
        plot_types = {'quiver': 0, 'hsv': 1, 'image': 2,
                      'waveform': 3, 'raster': 4, 'rate': 5}
        # Number of neuron ID arrays required by the types that care.
        required_ids = {0: 2, 1: 2, 2: 1}

        # Count number of plots to create:
        num_plots = sum(len(configs) for configs in self._config.values())

        # Set default grid of plot positions:
        if self._rows*self._cols != num_plots:
            self._cols = int(np.ceil(np.sqrt(num_plots)))
            self._rows = int(np.ceil(num_plots/float(self._cols)))
        self.f, self.axarr = plt.subplots(self._rows, self._cols,
                                          figsize=self._figsize)
        # A 1x1 grid yields a bare Axes; normalise to an array.
        if not isinstance(self.axarr, np.ndarray):
            self.axarr = np.asarray([self.axarr])

        # Remove unused subplots (axarr.flat walks them in row-major order):
        for i in range(num_plots, self._rows*self._cols):
            self.f.delaxes(self.axarr.flat[i])

        cnt = 0
        self.handles = []
        self.types = []
        keywds = ['handle', 'ydata', 'fmt', 'type', 'ids', 'shape']
        for LPU, configs in self._config.items():
            data = self._data[LPU]
            for config in configs:
                ax = self.axarr[np.unravel_index(cnt, self.axarr.shape)]
                cnt += 1
                ids = config['ids']

                if 'type' in config:
                    plot_type = config['type']
                    if plot_type in plot_types:
                        plot_type = plot_types[plot_type]
                    elif plot_type not in plot_types.values():
                        # Integer codes are accepted so that a second call
                        # (configs already converted) does not fail.
                        raise ValueError('Plot type not supported')
                    # Some plot types require specific numbers of
                    # neuron ID arrays:
                    if plot_type in required_ids:
                        assert len(ids) == required_ids[plot_type]
                    config['type'] = plot_type
                else:
                    # Inputs and non-spiking neurons default to images,
                    # spiking neurons to rasters. Node attributes are read
                    # via graph.node[str(id)], as add_plot() does for names.
                    # NOTE(review): add_plot() stores ``id - shift``; with a
                    # non-zero shift this looks up a different node.
                    if str(LPU).startswith('input') or \
                       not self._graph[LPU].node[str(ids[0][0])]['spiking']:
                        config['type'] = 2
                    else:
                        config['type'] = 4

                if config['type'] < 3 and 'shape' not in config:
                    # XXX This can cause problems when the number
                    # of neurons is not equal to
                    # np.prod(config['shape'])
                    num_neurons = len(ids[0])
                    rows = int(np.ceil(np.sqrt(num_neurons)))
                    config['shape'] = [rows,
                                       int(np.ceil(num_neurons/float(rows)))]

                if config['type'] == 0:
                    config['handle'] = ax.quiver(
                        np.reshape(data[ids[0],0], config['shape']),
                        np.reshape(data[ids[1],0], config['shape']))
                elif config['type'] == 1:
                    config['handle'] = ax.imshow(self._hsv_rgb(
                        np.reshape(data[ids[0],0], config['shape']),
                        np.reshape(data[ids[1],0], config['shape'])))
                elif config['type'] == 2:
                    config.setdefault('trans', False)
                    frame = np.reshape(data[ids[0],0], config['shape'])
                    # Same truthiness test as update(), so the first frame
                    # and later frames agree on orientation.
                    if config['trans']:
                        frame = np.transpose(frame)
                    temp = ax.imshow(frame)
                    temp.set_clim(self._imlim)
                    temp.set_cmap(plt.cm.gist_gray)
                    config['handle'] = temp
                elif config['type'] == 3:
                    fmt = config.get('fmt', '')
                    ax.set_xlim(self._xlim)
                    ax.set_ylim(self._ylim)
                    if len(ids[0]) == 1:
                        config['handle'] = ax.plot(
                            [0], [data[ids[0][0],0]], fmt)[0]
                        config['ydata'] = [data[ids[0][0],0]]
                    else:
                        config['handle'] = ax.plot(data[ids[0],0])[0]
                elif config['type'] == 4:
                    config['handle'] = ax
                    ax.vlines(0, 0, 0.01)
                    ax.set_ylim([.5, len(ids[0]) + .5])
                    ax.set_ylabel('Neurons',
                                  fontsize=self._fontsize-1, weight='bold')
                    ax.set_xlabel('Time (s)',
                                  fontsize=self._fontsize-1, weight='bold')
                    ax.set_xlim([0, len(data[ids[0][0],:])*self._dt])
                    ax.set_yticks([])
                    ax.set_xticks([])

                # Forward remaining entries (e.g. 'title', 'clim') to the
                # axes and to the artist; 'rate' plots have no handle.
                for key in config:
                    if key in keywds:
                        continue
                    for target in (ax, config.get('handle')):
                        try:
                            self._set_wrapper(target, key, config[key])
                        except Exception:
                            pass
                if config['type'] < 3:
                    config['handle'].axes.set_xticks([])
                    config['handle'].axes.set_yticks([])

        if self.suptitle is not None:
            self.f.suptitle(self._title, fontsize=self._fontsize+1,
                            x=0.5, y=0.03, weight='bold')

        plt.tight_layout()

        if self.out_filename:
            self.writer = FFMpegFileWriter(fps=self.fps, codec=self.codec)

            # Use the output file to determine the name of the temporary frame
            # files so that two concurrently run visualizations don't clobber
            # each other's frames:
            self.writer.setup(self.f, self.out_filename, dpi=80,
                              frame_prefix=os.path.splitext(self.out_filename)[0]+'_')
            self.writer.frame_format = 'png'
            self.writer.grab_frame()
        else:
            self.f.show()

    def update(self):
        '''
        Draw the frame at time index ``self._t``, then advance it.

        Waveform and raster plots also consume the samples since the
        previous update. The canvas is redrawn, a video frame is grabbed if
        ``out_filename`` is set, and ``self._t`` advances by
        ``update_interval``.
        '''
        dt = self._dt
        t = self._t
        # First sample not yet drawn by waveform/raster plots.
        start = max(0, t - self._update_interval)
        for key, configs in self._config.items():
            data = self._data[key]
            for config in configs:
                plot_type = config['type']
                ids = config['ids']
                if plot_type == 3:
                    if len(ids[0]) == 1:
                        config['ydata'].extend(np.reshape(np.double(
                            data[ids[0], start:t]), (-1,)))
                        config['handle'].set_xdata(dt*np.arange(0, t))
                        config['handle'].set_ydata(np.asarray(config['ydata']))
                    else:
                        config['handle'].set_ydata(data[ids[0], t])
                elif plot_type == 4:
                    for row, neuron in enumerate(ids[0]):
                        # np.where gives offsets relative to ``start``; draw
                        # each spike at its absolute time so the raster is
                        # not mirrored within every update window.
                        for offset in np.where(data[neuron, start:t])[0]:
                            config['handle'].vlines(float(start + offset)*dt,
                                                    row+0.75, row+1.25)
                elif plot_type == 0:
                    # Quiver arrows are updated through set_UVC(); assigning
                    # to the U/V attributes bypasses Quiver's own handling.
                    config['handle'].set_UVC(
                        np.reshape(data[ids[0], t], config['shape']),
                        np.reshape(data[ids[1], t], config['shape']))
                elif plot_type == 1:
                    config['handle'].set_data(self._hsv_rgb(
                        np.reshape(data[ids[0], t], config['shape']),
                        np.reshape(data[ids[1], t], config['shape'])))
                elif plot_type == 2:
                    frame = np.reshape(data[ids[0], t], config['shape'])
                    if config['trans']:
                        frame = np.transpose(frame)
                    config['handle'].set_data(frame)

        self.f.canvas.draw()
        if self.out_filename:
            self.writer.grab_frame()

        self._t += self._update_interval

    def add_plot(self, config_dict, LPU=0, names=[''], shift=0):
        '''
        Register a plot of neurons of ``LPU``.

        ``config_dict`` is copied. If it has no ``'ids'``, the IDs are
        derived either as all rows of an input data set or, for LPUs with a
        graph, as the node ids (minus ``shift``) whose ``'name'`` matches
        each entry of ``names``. A default ``'title'`` is filled in.
        '''
        config = config_dict.copy()
        if not isinstance(names, list):
            names = [names]
        if LPU not in self._config:
            self._config[LPU] = []
        if 'ids' in config:
            # XXX should check whether the specified ids are within range
            self._config[LPU].append(config)
        elif str(LPU).startswith('input'):
            config['ids'] = [range(0, self._data[LPU].shape[0])]
            self._config[LPU].append(config)
        else:
            config['ids'] = {}
            for i, name in enumerate(names):
                config['ids'][i] = []
                for id in range(len(self._graph[LPU].node)):
                    if self._graph[LPU].node[str(id)]['name'] == name:
                        config['ids'][i].append(id-shift)
            self._config[LPU].append(config)
        if 'title' not in config:
            if names[0]:
                config['title'] = "{0} - {1}".format(str(LPU),str(names[0]))
            else:
                if str(LPU).startswith('input_'):
                    config['title'] = LPU.split('_',1)[1] + ' - ' + 'Input'
                else:
                    config['title'] = str(LPU)

    def close(self):
        '''Finalize the video file opened by _initialize().'''
        self.writer.finish()

    @property
    def xlim(self): return self._xlim

    @xlim.setter
    def xlim(self, value):
        self._xlim = value

    @property
    def ylim(self): return self._ylim

    @ylim.setter
    def ylim(self, value):
        self._ylim = value

    @property
    def imlim(self): return self._imlim

    @imlim.setter
    def imlim(self, value):
        self._imlim = value

    @property
    def out_filename(self): return self._out_file

    @out_filename.setter
    def out_filename(self, value):
        assert(isinstance(value, str))
        self._out_file = value

    @property
    def fps(self): return self._fps

    @fps.setter
    def fps(self, value):
        assert(isinstance(value, int))
        self._fps = value

    @property
    def codec(self): return self._codec

    @codec.setter
    def codec(self, value):
        assert(isinstance(value, str))
        self._codec = value

    @property
    def rows(self): return self._rows

    @rows.setter
    def rows(self, value):
        self._rows = value

    @property
    def cols(self): return self._cols

    @cols.setter
    def cols(self, value):
        self._cols = value

    @property
    def dt(self): return self._dt

    @dt.setter
    def dt(self, value):
        self._dt = value

    @property
    def figsize(self): return self._figsize

    @figsize.setter
    def figsize(self, value):
        assert(isinstance(value, tuple) and len(value)==2)
        self._figsize = value

    @property
    def fontsize(self): return self._fontsize

    @fontsize.setter
    def fontsize(self, value):
        self._fontsize = value

    @property
    def suptitle(self): return self._title

    @suptitle.setter
    def suptitle(self, value):
        self._title = value

    @property
    def update_interval(self): return self._update_interval

    @update_interval.setter
    def update_interval(self, value):
        self._update_interval = value
Пример #9
0
    def _initialize(self):
        '''
        Create the figure and draw the first frame (time index 0).

        For every configured plot this method:

        - resolves the plot type to its integer code;
        - fills in a default image shape where needed;
        - creates the matplotlib artist and stores it as ``config['handle']``;
        - forwards extra config entries to matching ``set_<key>`` methods.

        Finally it either opens the video writer (``out_filename`` set) or
        shows the figure.

        Raises
        ------
        ValueError
            If a config names an unsupported plot type.
        '''
        # Plot type names and their integer codes.
        plot_types = {'quiver': 0, 'hsv': 1, 'image': 2,
                      'waveform': 3, 'raster': 4, 'rate': 5}
        # Number of neuron ID arrays required by the types that care.
        required_ids = {0: 2, 1: 2, 2: 1}

        # Count number of plots to create:
        num_plots = sum(len(configs) for configs in self._config.values())

        # Set default grid of plot positions:
        if self._rows*self._cols != num_plots:
            self._cols = int(np.ceil(np.sqrt(num_plots)))
            self._rows = int(np.ceil(num_plots/float(self._cols)))
        self.f, self.axarr = plt.subplots(self._rows, self._cols,
                                          figsize=self._figsize)
        # A 1x1 grid yields a bare Axes; normalise to an array.
        if not isinstance(self.axarr, np.ndarray):
            self.axarr = np.asarray([self.axarr])

        # Remove unused subplots (axarr.flat walks them in row-major order):
        for i in range(num_plots, self._rows*self._cols):
            self.f.delaxes(self.axarr.flat[i])

        cnt = 0
        self.handles = []
        self.types = []
        keywds = ['handle', 'ydata', 'fmt', 'type', 'ids', 'shape']
        for LPU, configs in self._config.items():
            data = self._data[LPU]
            for config in configs:
                ax = self.axarr[np.unravel_index(cnt, self.axarr.shape)]
                cnt += 1
                ids = config['ids']

                if 'type' in config:
                    plot_type = config['type']
                    if plot_type in plot_types:
                        plot_type = plot_types[plot_type]
                    elif plot_type not in plot_types.values():
                        # Integer codes are accepted so that a second call
                        # (configs already converted) does not fail.
                        raise ValueError('Plot type not supported')
                    # Some plot types require specific numbers of
                    # neuron ID arrays:
                    if plot_type in required_ids:
                        assert len(ids) == required_ids[plot_type]
                    config['type'] = plot_type
                else:
                    # Inputs and non-spiking neurons default to images,
                    # spiking neurons to rasters. Node attributes are read
                    # via graph.node[str(id)], the access pattern add_plot()
                    # uses for node names.
                    # NOTE(review): add_plot() stores ``id - shift``; with a
                    # non-zero shift this looks up a different node.
                    if str(LPU).startswith('input') or \
                       not self._graph[LPU].node[str(ids[0][0])]['spiking']:
                        config['type'] = 2
                    else:
                        config['type'] = 4

                if config['type'] < 3 and 'shape' not in config:
                    # XXX This can cause problems when the number
                    # of neurons is not equal to
                    # np.prod(config['shape'])
                    num_neurons = len(ids[0])
                    rows = int(np.ceil(np.sqrt(num_neurons)))
                    config['shape'] = [rows,
                                       int(np.ceil(num_neurons/float(rows)))]

                if config['type'] == 0:
                    config['handle'] = ax.quiver(
                        np.reshape(data[ids[0],0], config['shape']),
                        np.reshape(data[ids[1],0], config['shape']))
                elif config['type'] == 1:
                    X = np.reshape(data[ids[0],0], config['shape'])
                    Y = np.reshape(data[ids[1],0], config['shape'])
                    # Hue encodes direction, value encodes magnitude.
                    V = (X**2 + Y**2)**0.5
                    H = (np.arctan2(X,Y)+np.pi)/(2*np.pi)
                    S = np.ones_like(V)
                    config['handle'] = ax.imshow(hsv_to_rgb(np.dstack((H,S,V))))
                elif config['type'] == 2:
                    config.setdefault('trans', False)
                    frame = np.reshape(data[ids[0],0], config['shape'])
                    # Truthiness, matching how later frame updates test
                    # config['trans'], so all frames share one orientation.
                    if config['trans']:
                        frame = np.transpose(frame)
                    temp = ax.imshow(frame)
                    temp.set_clim(self._imlim)
                    temp.set_cmap(plt.cm.gist_gray)
                    config['handle'] = temp
                elif config['type'] == 3:
                    fmt = config.get('fmt', '')
                    ax.set_xlim(self._xlim)
                    ax.set_ylim(self._ylim)
                    if len(ids[0]) == 1:
                        config['handle'] = ax.plot(
                            [0], [data[ids[0][0],0]], fmt)[0]
                        config['ydata'] = [data[ids[0][0],0]]
                    else:
                        config['handle'] = ax.plot(data[ids[0],0])[0]
                elif config['type'] == 4:
                    config['handle'] = ax
                    ax.vlines(0, 0, 0.01)
                    ax.set_ylim([.5, len(ids[0]) + .5])
                    ax.set_ylabel('Neurons',
                                  fontsize=self._fontsize-1, weight='bold')
                    ax.set_xlabel('Time (s)',
                                  fontsize=self._fontsize-1, weight='bold')
                    ax.set_xlim([0, len(data[ids[0][0],:])*self._dt])
                    ax.set_yticks([])
                    ax.set_xticks([])

                # Forward remaining entries (e.g. 'title', 'clim') to the
                # axes and to the artist; 'rate' plots have no handle.
                for key in config:
                    if key in keywds:
                        continue
                    for target in (ax, config.get('handle')):
                        try:
                            self._set_wrapper(target, key, config[key])
                        except Exception:
                            pass
                if config['type'] < 3:
                    config['handle'].axes.set_xticks([])
                    config['handle'].axes.set_yticks([])

        if self.suptitle is not None:
            self.f.suptitle(self._title, fontsize=self._fontsize+1,
                            x=0.5, y=0.03, weight='bold')

        plt.tight_layout()

        if self.out_filename:
            self.writer = FFMpegFileWriter(fps=self.fps, codec=self.codec)

            # Use the output file to determine the name of the temporary frame
            # files so that two concurrently run visualizations don't clobber
            # each other's frames:
            self.writer.setup(self.f, self.out_filename, dpi=80,
                              frame_prefix=os.path.splitext(self.out_filename)[0]+'_')
            self.writer.frame_format = 'png'
            self.writer.grab_frame()
        else:
            self.f.show()
Пример #10
0
xdata = []  # x values of the points drawn so far
ydata = []  # matching y values
ln = plt.plot([], [], 'r', animated=True)[0]  # the single animated red line
f = 50


def init():
    """Fix the axis limits and reset the line to the accumulated data.

    Returns a 1-tuple holding the line artist, the artists to redraw.
    """
    ax.set_xlim(left=-3, right=3)
    ax.set_ylim(bottom=-0.25, top=2)
    ln.set_data(xdata, ydata)
    artists = (ln,)
    return artists


def update(t):
    """Append the Gaussian sample for frame value ``t`` and refresh the line.

    FuncAnimation calls this with each value from ``frames`` as ``t``.  The
    point ``(t, exp(-t**2))`` is appended to the module-level ``xdata`` /
    ``ydata`` buffers and pushed to the line artist.

    Returns:
        A one-element tuple with the line artist, as required when
        ``blit=True``.
    """
    # Use the frame value passed in; there is no global named `frame`.
    xdata.append(t)
    ydata.append(np.exp(-t**2))
    ln.set_data(xdata, ydata)
    return ln,


# NOTE(review): `fig` and `ax` are not defined in this snippet; presumably
# they are created elsewhere (e.g. plt.subplots()) -- confirm.
# frames=2 drives update() with the values 0 and 1; the animation plays a
# single pass (repeat=False) with blitting enabled.
ani = FuncAnimation(
    fig,
    update,
    init_func=init,
    frames=2,
    interval=2.5,  # delay between frames in ms when shown on screen
    blit=True,
    repeat=False,
)
plt.show()

# Encode the animation through ffmpeg as an H.264 MP4 at 25 fps.
mywriter = FFMpegFileWriter(codec="libx264", fps=25)
ani.save("test.mp4", writer=mywriter)
Пример #11
0
 def animate(self,
             outputType="screen",
             color="random",
             speed=10,
             outDir="out",
             type="line"):
     """Animate the knight's path and send it to the requested output.

     Args:
         outputType: ``"screen"`` shows the animation interactively;
             ``"avi"`` / ``"mp4"`` encode it through ffmpeg (libx264, 1 fps);
             ``"gif"`` saves a single frame drawn by ``self._genAllLines``;
             ``"animgif"`` saves an animated GIF through ImageMagick;
             ``"html"`` writes the page built by ``self._genHtmlFrame``
             around an HTML5 video.  Any other value prints a message and
             returns without writing anything.
         color: Forwarded to ``self._genLine`` on every frame.
         speed: Step between successive frame indices into
             ``self.kn.history``; also forwarded to ``self._genLine``.
         outDir: Directory for output files; created (including any missing
             parent directories) if needed.  Files are named
             ``knightPath-<shape[0]>-<shape[1]>`` from ``self.kn.shape``.
         type: Forwarded to ``self._genLine``.  The name shadows the builtin
             but is kept for keyword-argument compatibility.
     """
     # exist_ok avoids a check-then-create race and makedirs also creates
     # intermediate directories.
     os.makedirs(outDir, exist_ok=True)
     fname = os.path.join(
         outDir, "knightPath-{}-{}".format(self.kn.shape[0],
                                           self.kn.shape[1]))
     knightAnimation = FuncAnimation(self.fig,
                                     self._genLine,
                                     fargs=(speed, color, type),
                                     repeat=False,
                                     frames=range(
                                         -speed,
                                         len(self.kn.history),
                                         speed,
                                     ),
                                     blit=False,
                                     interval=10,
                                     cache_frame_data=False)
     if outputType == "screen":
         plt.show()
     elif outputType in ("avi", "mp4"):
         # Both containers use the same ffmpeg/libx264 settings.
         video_name = fname + "." + outputType
         knightAnimation.save(
             video_name,
             writer=FFMpegFileWriter(fps=1,
                                     bitrate=100000,
                                     extra_args=['-vcodec', 'libx264']),
         )
         label = "video file" if outputType == "avi" else "MP4 video file"
         print(label + " written to", video_name)
     elif outputType == "gif":
         knightPathAnim = FuncAnimation(self.fig,
                                        self._genAllLines,
                                        repeat=False,
                                        frames=range(1),
                                        blit=False,
                                        interval=10,
                                        cache_frame_data=False)
         knightPathAnim.save(fname + '.gif',
                             writer=ImageMagickFileWriter(fps=1))
         print("Image written to", fname + ".gif")
     elif outputType == "animgif":
         knightAnimation.save(fname + '.anim.gif',
                              writer=ImageMagickFileWriter(fps=1))
         print("Animated gif written to", fname + ".anim.gif")
     elif outputType == "html":
         # Render first so a failure does not leave behind an empty file,
         # then write through a context manager so the handle is closed.
         html = self._genHtmlFrame(knightAnimation.to_html5_video(50.0))
         with open(fname + ".html", "w", encoding="utf-8") as html_file:
             html_file.write(html)
         print("HTML file written to", fname + ".html")
     else:
         print("Unknown outputType '" + outputType + "'")
Пример #12
0
    def _initialize(self):
        """Build the figure, draw the first frame of every plot, set up output.

        ``self._config`` maps each LPU name to a list of plot config dicts.
        For every config this method:

        * converts ``config['type']`` from a name to an integer code
          (0 quiver, 1 hsv, 2 image, 3 waveform, 4 raster, 5 rate), or infers
          2 (image) / 4 (raster) when no type is given;
        * for codes 0-2, fills in a default 2-D ``config['shape']``;
        * draws time step 0 of the LPU data and stores the artist (for
          rasters, the Axes) in ``config['handle']``; nothing is drawn for
          code 5;
        * forwards every other config key to ``self._set_wrapper``.

        If ``self.out_filename`` is set, an ``FFMpegFileWriter`` is set up on
        ``self.writer`` and the first frame is grabbed; otherwise the figure
        is shown.

        Raises:
            ValueError: if a config names an unsupported plot type.
            AssertionError: if a quiver/hsv config does not have exactly two
                ID arrays, or an image config exactly one.
        """
        # Plot-type name -> (integer code, required len(config['ids']), or
        # None when any number of ID arrays is accepted).
        plot_types = {
            'quiver': (0, 2),
            'hsv': (1, 2),
            'image': (2, 1),
            'waveform': (3, None),
            'raster': (4, None),
            'rate': (5, None),
        }
        # Config keys consumed here rather than forwarded as setters:
        keywds = ['handle', 'ydata', 'fmt', 'type', 'ids', 'shape']

        # Count number of plots to create:
        num_plots = sum(len(configs) for configs in self._config.values())

        # Use a near-square grid unless rows x cols matches the plot count:
        if self._rows * self._cols != num_plots:
            self._cols = int(np.ceil(np.sqrt(num_plots)))
            self._rows = int(np.ceil(num_plots / float(self._cols)))
        self.f, self.axarr = plt.subplots(self._rows,
                                          self._cols,
                                          figsize=self._figsize)

        # Remove unused subplots:
        for i in range(num_plots, self._rows * self._cols):
            plt.delaxes(self.axarr[np.unravel_index(i,
                                                    (self._rows, self._cols))])
        cnt = 0
        self.handles = []
        self.types = []
        # A 1x1 grid yields a bare Axes rather than an array:
        if not isinstance(self.axarr, np.ndarray):
            self.axarr = np.asarray([self.axarr])
        for LPU, configs in self._config.items():
            for config in configs:
                ind = np.unravel_index(cnt, self.axarr.shape)
                cnt += 1

                if 'type' in config:
                    try:
                        code, num_id_arrays = plot_types[config['type']]
                    except (KeyError, TypeError):
                        raise ValueError('Plot type not supported')
                    # Some plot types require specific numbers of
                    # neuron ID arrays:
                    if num_id_arrays is not None:
                        assert len(config['ids']) == num_id_arrays
                    config['type'] = code
                elif str(LPU).startswith('input') or not self._graph[LPU][
                        str(config['ids'][0])]['spiking']:
                    # NOTE(review): for a NetworkX graph G, G[n] is the
                    # adjacency of n, not its attribute dict -- confirm this
                    # lookup really reaches the 'spiking' node attribute.
                    config['type'] = 2
                else:
                    config['type'] = 4

                if config['type'] < 3 and 'shape' not in config:
                    # XXX This can cause problems when the number
                    # of neurons is not equal to
                    # np.prod(config['shape'])
                    num_neurons = len(config['ids'][0])
                    side = int(np.ceil(np.sqrt(num_neurons)))
                    config['shape'] = [
                        side, int(np.ceil(num_neurons / float(side)))
                    ]

                if config['type'] == 0:
                    config['handle'] = self.axarr[ind].quiver(
                        np.reshape(self._data[LPU][config['ids'][0], 0],
                                   config['shape']),
                        np.reshape(self._data[LPU][config['ids'][1], 0],
                                   config['shape']))
                elif config['type'] == 1:
                    # Vector direction -> hue, magnitude -> value,
                    # full saturation.
                    X = np.reshape(self._data[LPU][config['ids'][0], 0],
                                   config['shape'])
                    Y = np.reshape(self._data[LPU][config['ids'][1], 0],
                                   config['shape'])
                    V = (X**2 + Y**2)**0.5
                    H = (np.arctan2(X, Y) + np.pi) / (2 * np.pi)
                    S = np.ones_like(V)
                    HSV = np.dstack((H, S, V))
                    RGB = hsv_to_rgb(HSV)
                    config['handle'] = self.axarr[ind].imshow(RGB)
                elif config['type'] == 2:
                    # A missing 'trans' is stored as False; only the literal
                    # True enables transposition.
                    to_transpose = config.setdefault('trans', False) is True
                    frame = np.reshape(self._data[LPU][config['ids'][0], 0],
                                       config['shape'])
                    if to_transpose:
                        frame = np.transpose(frame)
                    temp = self.axarr[ind].imshow(frame)
                    temp.set_clim(self._imlim)
                    temp.set_cmap(plt.cm.gist_gray)
                    config['handle'] = temp
                elif config['type'] == 3:
                    fmt = config.get('fmt', '')
                    self.axarr[ind].set_xlim(self._xlim)
                    self.axarr[ind].set_ylim(self._ylim)
                    if len(config['ids'][0]) == 1:
                        config['handle'] = self.axarr[ind].plot(
                            [0], [self._data[LPU][config['ids'][0][0], 0]],
                            fmt)[0]
                        config['ydata'] = [
                            self._data[LPU][config['ids'][0][0], 0]
                        ]
                    else:
                        config['handle'] = self.axarr[ind].plot(
                            self._data[LPU][config['ids'][0], 0])[0]
                elif config['type'] == 4:
                    config['handle'] = self.axarr[ind]
                    config['handle'].vlines(0, 0, 0.01)
                    config['handle'].set_ylim([.5, len(config['ids'][0]) + .5])
                    config['handle'].set_ylabel('Neurons',
                                                fontsize=self._fontsize - 1,
                                                weight='bold')
                    config['handle'].set_xlabel('Time (s)',
                                                fontsize=self._fontsize - 1,
                                                weight='bold')
                    # The x-axis spans the full recording length of the
                    # smallest mapped neuron id, scaled by self._dt.
                    min_id = min(self._id_to_data_idx[LPU].keys())
                    min_idx = self._id_to_data_idx[LPU][min_id]
                    config['handle'].set_xlim(
                        [0, len(self._data[LPU][min_idx, :]) * self._dt])
                    config['handle'].axes.set_yticks([])
                    config['handle'].axes.set_xticks([])

                # Forward the remaining keys (e.g. 'clim') to matching
                # set_<key> methods.  This is best effort, so failures are
                # ignored, but only Exception subclasses are swallowed.
                for key in config:
                    if key not in keywds:
                        try:
                            self._set_wrapper(self.axarr[ind], key,
                                              config[key])
                        except Exception:
                            pass
                        try:
                            self._set_wrapper(config['handle'], key,
                                              config[key])
                        except Exception:
                            pass
                if config['type'] < 3:
                    config['handle'].axes.set_xticks([])
                    config['handle'].axes.set_yticks([])

        # The suptitle is figure-wide, so set it once rather than per LPU.
        if self.suptitle is not None:
            self.f.suptitle(self._title,
                            fontsize=self._fontsize + 1,
                            x=0.5,
                            y=0.03,
                            weight='bold')

        plt.tight_layout()

        if self.out_filename:
            self.writer = FFMpegFileWriter(fps=self.fps, codec=self.codec)

            # Use the output file to determine the name of the temporary frame
            # files so that two concurrently run visualizations don't clobber
            # each other's frames:
            self.writer.setup(
                self.f,
                self.out_filename,
                dpi=80,
                frame_prefix=os.path.splitext(self.out_filename)[0] + '_')
            self.writer.frame_format = 'png'
            self.writer.grab_frame()
        else:
            self.f.show()
Пример #13
0
def main():
    """
    Entry point for rendering the plot.

    Builds every scene and renders them into ``./test.mp4`` at ``FRAME_RATE``
    fps through an ffmpeg file writer (figure size 19.2 x 10.8 in at
    dpi=100, i.e. 1920 x 1080 pixels).
    """
    anim_file_path = Path("./test.mp4")
    figure = plt.figure(figsize=(19.2, 10.8))

    file_writer = FFMpegFileWriter(fps=FRAME_RATE)
    with file_writer.saving(figure, anim_file_path, dpi=100):
        # Scenes are created inside the saving context, as before, so any
        # setup done by the draw_* factories happens after writer setup.
        scenes = _build_scenes()
        _render_scenes(figure, file_writer, scenes)


def _build_scenes() -> "List[Scene]":
    """Create every scene of the animation, sorted by descending ``zorder``."""
    intro_text = Scene(
        0,
        169,
        1,
        draw_text(
            sentence="I've seen things you people wouldn't believe",
            text_pos_list=[16, 44],
            alpha_transitions=60,
            persist_frames=0,
            fade_out_frames=24,
            font_size=48,
            left_offset=0.12,
            bottom_offset=0.0,
        ),
    )
    eye = Scene(
        intro_text.start_frame,
        intro_text.end_frame,
        0,
        draw_eye(axes_dims=[0, 0.22, 1.0, 0.8],
                 persist_frames=24,
                 fade_out_frames=24),
    )
    heatmap = Scene(
        intro_text.end_frame - 47,
        intro_text.end_frame + 145,
        2,
        draw_fire_automata(
            axes_dims=[0.2, 0.35, 0.6, 0.6],
            fade_in_frames=24,
            update_frames=144,
            fade_out_frames=24,
        ),
    )
    gaussian = Scene(
        intro_text.end_frame + 1,
        intro_text.end_frame + 145,
        1,
        draw_gaussian(
            axes_dims=[0.05, 0.1, 0.9, 0.25],
            fade_in_frames=24,
            update_frames=72,
            persist_frames=24,
            fade_out_frames=24,
        ),
    )
    heatmaps_text = Scene(
        145,
        313,
        1,
        draw_text(
            sentence="Heat maps on fire off the shoulder of a Gaussian",
            text_pos_list=[17, 48],
            alpha_transitions=60,
            persist_frames=24,
            fade_out_frames=24,
            font_size=48,
            left_offset=0.08,
            bottom_offset=0.0,
        ),
    )
    learning_curve = Scene(
        heatmaps_text.end_frame + 1,
        heatmaps_text.end_frame + 277,
        1,
        draw_learning_curve(
            topo_axes_dims=[0.01, 0.15, 0.5, 0.8],
            learning_curve_axes_dims=[0.54, 0.15, 0.44, 0.8],
            fade_in_frames=24,
            update_frames=156,
            persist_frames=72,
            fade_out_frames=24,
        ),
    )
    residuals_text = Scene(
        heatmaps_text.end_frame + 1,
        heatmaps_text.end_frame + 277,
        2,
        draw_text(
            sentence=
            "I watched residuals diminish down the arc of ten thousand weights",
            text_pos_list=[10, 41, 65],
            alpha_transitions=60,
            persist_frames=72,
            fade_out_frames=24,
            font_size=40,
            left_offset=0.015,
            bottom_offset=0.0,
        ),
    )
    fade_text_1 = Scene(
        residuals_text.end_frame + 1,
        residuals_text.end_frame + 193,
        2,
        draw_text(
            sentence="All these visuals",
            text_pos_list=[17],
            alpha_transitions=60,
            persist_frames=84,
            fade_out_frames=48,
            font_size=100,
            left_offset=0.2,
            bottom_offset=0.53,
        ),
    )
    fade_text_2 = Scene(
        residuals_text.end_frame + 60,
        residuals_text.end_frame + 193,
        2,
        draw_text(
            sentence="will fade in time",
            text_pos_list=[17],
            alpha_transitions=60,
            persist_frames=24,
            fade_out_frames=48,
            font_size=100,
            left_offset=0.2,
            bottom_offset=0.37,
        ),
    )
    terrain = Scene(
        fade_text_2.end_frame + 49,
        fade_text_2.end_frame + 169,
        1,
        draw_terrain(
            axes_dims=[0.05, 0.2, 0.9, 0.8],
            fade_in_frames=24,
            update_frames=72,
            fade_out_frames=24,
            frame_jiggle=0.01,
        ),
    )
    tears_text = Scene(
        fade_text_2.end_frame + 1,
        fade_text_2.end_frame + 169,
        2,
        draw_text(
            sentence="Like tears in terrain",
            text_pos_list=[11, 21],
            alpha_transitions=48,
            persist_frames=48,
            fade_out_frames=24,
            font_size=60,
            left_offset=0.3,
            bottom_offset=0,
        ),
    )
    pi_text_1 = Scene(
        tears_text.end_frame + 1,
        tears_text.end_frame + 217,
        2,
        draw_text(
            sentence="Time to pi",
            text_pos_list=[4, 10],
            alpha_transitions=72,
            persist_frames=72,
            fade_out_frames=0,
            font_size=80,
            left_offset=0.08,
            bottom_offset=0.8,
        ),
    )
    pi_text_2 = Scene(
        tears_text.end_frame + 169,
        tears_text.end_frame + 217,
        1,
        draw_text(
            sentence="Time to pip install matplotlib",
            text_pos_list=[30],
            alpha_transitions=24,
            persist_frames=100,
            fade_out_frames=0,
            font_size=80,
            left_offset=0.08,
            bottom_offset=0.8,
        ),
    )
    smiley = Scene(
        tears_text.end_frame + 169,
        tears_text.end_frame + 217,
        2,
        draw_smiley(fade_in_frames=24, pos_x=0, pos_y=0),
    )
    active_scenes_list: List[Scene] = [
        intro_text,
        eye,
        heatmap,
        gaussian,
        heatmaps_text,
        learning_curve,
        residuals_text,
        fade_text_1,
        fade_text_2,
        terrain,
        tears_text,
        pi_text_1,
        pi_text_2,
        smiley,
    ]
    active_scenes_list.sort(key=lambda scene: scene.zorder, reverse=True)
    return active_scenes_list


def _render_scenes(figure, file_writer, scenes):
    """Draw frames until every scene has ended, grabbing each non-empty frame.

    ``scenes`` must be sorted by descending ``zorder``.  It is consumed in
    place: a scene is removed once the frame number passes its
    ``end_frame``.  Walking the list back-to-front draws lower-z scenes first,
    so higher-z scenes end up on top.  Each active scene's ``render_frame``
    iterator is advanced once per frame.  Frames in which no scene is active
    are not written to the video.
    """
    for frame_number in itertools.count():
        # Stop before touching the figure once all scenes have finished.
        if not scenes:
            break
        figure.clear()
        render_axes = figure.add_axes([0.0, 0.0, 1.0, 1.0])
        render_axes.axis("off")
        rendered_scene = False
        # Iterate backwards so deleting a finished scene does not shift the
        # indices that are still to be visited.
        for scene_index in reversed(range(len(scenes))):
            scene = scenes[scene_index]
            if frame_number < scene.start_frame:
                continue
            if frame_number > scene.end_frame:
                del scenes[scene_index]
            else:
                render_axes.imshow(next(scene.render_frame))
                rendered_scene = True
        if rendered_scene:
            file_writer.grab_frame(facecolor=BACKGROUND_COLOUR)
Пример #14
0
  def __init__(self, temp_prefix='_tmp', clear_temp=True, *args, **kwargs):
    """Initialise the base ffmpeg file writer and store two extra options.

    temp_prefix -- stored as ``self.temp_prefix`` (presumably the prefix for
                   temporary frame files -- confirm against its users).
    clear_temp  -- stored as ``self.clear_temp`` (presumably whether those
                   temporary frames are deleted afterwards -- confirm).
    Every other positional/keyword argument is passed unchanged to
    ``FFMpegFileWriter.__init__``.
    """
    base_init = FFMpegFileWriter.__init__
    base_init(self, *args, **kwargs)
    # Assigned after the base initialiser so these values take precedence.
    self.temp_prefix, self.clear_temp = temp_prefix, clear_temp
Пример #15
0
def run_simulation(adj_list, node_mappings, verbose=0, visualize=False):
    """
    Function: run_simulation
    ------------------------
    Runs the simulation. Returns a dictionary with the key as the "color"/name,
    and the value as the number of nodes that "color"/name got.

    adj_list: A dictionary representation of the graph adjacencies.
    node_mappings: A dictionary where the key is a name and the value is a list
                   of seed nodes associated with that name.
    verbose: If truthy, print the initial counts and the counts of every
             generation.
    visualize: If True, draw one frame per generation and encode the frames
               with ffmpeg into "<number of names>_Players <YYYYmmdd HHMMSS>.mp4"
               in the current working directory.
    """
    def draw_legend(results):
        # One legend entry per name plus "Unclaimed", each with its current
        # node count.  Uses ax, m and key_colors, which are only set up when
        # visualize is True (the only case in which this is called).
        key_patches = [
            mpatches.Patch(color='lightgray',
                           label='Unclaimed: %d' % results[None])
        ]
        key_patches.extend(
            mpatches.Patch(color=m.to_rgba(key_colors[k]),
                           label="%s: %d" % (str(k), results[k]))
            for k in node_mappings.keys())
        ax.legend(loc='upper left',
                  bbox_to_anchor=(-0.1, 1.1),
                  fancybox=True,
                  handles=key_patches)

    # Stores a mapping of nodes to their color (None means unclaimed).
    node_color = dict.fromkeys(adj_list)
    init(node_mappings, node_color, verbose)

    if visualize:
        print('Preparing graph for visualization...', end='', flush=True)
        # Load and build graph
        G = load_graph(adj_list)
        pos = nx.drawing.layout.spring_layout(G,
                                              k=0.1,
                                              random_state=0,
                                              scale=10)

        # Set up animation writers
        from matplotlib.animation import FFMpegFileWriter
        writer = FFMpegFileWriter(fps=1)
        import time
        filename = (str(len(node_mappings.keys())) + "_Players " +
                    time.strftime("%Y%m%d %H%M%S") + ".mp4")

        # Set up pyplot
        fig, ax = plt.subplots(figsize=(16, 9))
        fig.subplots_adjust(bottom=0.2)
        # Name -> colormap index; None (unclaimed) maps to -1, which is
        # masked per frame and therefore drawn with the "bad" colour.
        key_colors = {key: index
                      for index, key in enumerate(node_mappings.keys())}
        key_colors[None] = -1
        import matplotlib.patches as mpatches
        # Work on a private copy: set_bad would otherwise modify the globally
        # shared tab10 colormap for the rest of the program.
        colormap = deepcopy(cm.tab10)
        colormap.set_bad('lightgray')
        m = cm.ScalarMappable(cmap=colormap,
                              norm=colors.Normalize(0,
                                                    len(node_mappings.keys())))
        writer.setup(fig, filename, dpi=200)
        print('DONE')

    if verbose:
        print('Initial nodes counts minus overlaps:')
        print(get_result(node_mappings.keys(), node_color))
    generation = 1

    # Keep calculating the epidemic until it stops changing. Randomly choose
    # number between 100 and 200 as the stopping point if the epidemic does not
    # converge.
    prev = None
    nodes = adj_list.keys()
    last_iter = randint(100, 200)

    try:
        while not is_stable(generation, last_iter, prev, node_color):
            legends = list(node_mappings.keys())
            legends.append(None)
            results = get_result(legends, node_color)
            if verbose:
                print(results)
            if visualize:
                ax.clear()
                values = np.array([
                    key_colors[node_color.get(node, None)]
                    for node in G.nodes()
                ])
                values = np.ma.masked_where(values < 0, values)
                draw_frame(G, pos, ax, m.to_rgba(values), generation)
                draw_legend(results)
                plt.axis('off')
                writer.grab_frame()
            prev = deepcopy(node_color)
            for node in nodes:
                (changed, color) = update(adj_list, prev, node)
                # Store the node's new color only if it changed.
                if changed:
                    node_color[node] = color
            # NOTE: prev contains the state of the graph of the previous
            # generation, node_color contains the state of the graph at the
            # current generation. You could check these two dicts if you want
            # to see the intermediate steps of the epidemic.
            generation += 1
    finally:
        if visualize:
            # Always finalize the writer, even if the loop raised, so the
            # frames grabbed so far are assembled (matplotlib's
            # MovieWriter.saving context manager behaves the same way).
            writer.finish()
    return get_result(node_mappings.keys(), node_color)
Пример #16
0
class visualizer(object):
    """
    Visualize the output produced by LPU models.

    Examples
    --------
        import neurokernel.LPU.utils.visualizer as vis
        V = vis.visualizer()
        config1 = {}
        config1['type'] = 'image'
        config1['shape'] = [32,24]
        config1['clim'] = [-0.6,0.5]
        config2 = config1.copy()
        config2['clim'] = [-0.55,-0.45]
        V.add_LPU('lamina_output.h5', 'lamina.gexf.gz','lamina')
        V.add_plot(config1, 'lamina', 'R1')
        V.add_plot(config2, 'lamina', 'L1')
        V.update_interval = 50
        V.out_filename = 'test.avi'
        V.run()

        
    """

    def __init__(self):
        self._xlim = [0,1]
        self._ylim = [-1,1]
        self._imlim = [-1, 1]
        self._update_interval = 50
        self._out_file = None
        self._fps = 5
        self._codec = 'libtheora'
        self._config = OrderedDict()
        self._rows = 0
        self._cols = 0
        self._figsize = (16,9)
        self._fontsize = 18
        self._t = 1
        self._dt = 1
        self._data = {}
        self._graph = {}
        self._id_to_data_idx = {}
        self._maxt = None
        self._title = None

    def add_LPU(self, data_file, gexf_file=None, LPU=None, win=None):
        '''
        Add data associated with a specific LPU to a visualization.
        To add a plot containing neurons from a particular LPU,
        the LPU needs to be added to the visualization using this
        function. Note that outputs from multiple LPUs can
        be visualized using the same visualizer object.

        Parameters
        ----------
        data_file: str
             Location of the h5 file generated by neurokernel
             containing the output of the LPU

        gexf_file: str
            Location of the gexf file describing the LPU.
            If not specified, it will be assumed that the h5 file
            contains input.

        LPU: str
            Name of the LPU. Will be used as identifier to add plots.
            If omitted, the current number of added LPUs is used instead.
            For input signals, the name of the LPU will be prepended
            with 'input_'. For example::

                V.add_LPU('vision_in.h5', LPU='vision')

            will create the LPU identifier 'input_vision'.
            Therefore, adding a plot depicting this input can be done by::

                V.add_plot({'type': 'image', 'imlim': [-0.5, 0.5]}, LPU='input_vision')

        win: slice/list
            Can be used to limit the visualization to a specific time window.

        '''
        if gexf_file:
            # Resolve the identifier first so the graph, the id map and the
            # data are all stored under the same key.
            if not LPU:
                LPU = len(self._data)
            self._graph[LPU] = nx.read_gexf(gexf_file)

            # Map the ids of spiking neurons (sorted numerically) to their row
            # index in the output data array.  nodes(data=True) exists in
            # both NetworkX 1.x and 2.x, unlike nodes_iter() (1.x only).
            spiking_ids = sorted(int(n) for n, attrs in
                                 self._graph[LPU].nodes(data=True)
                                 if attrs['spiking'])
            self._id_to_data_idx[LPU] = {
                m: i for i, m in enumerate(spiking_ids)
            }
        elif LPU:
            LPU = 'input_' + str(LPU)
        else:
            LPU = 'input_' + str(len(self._data))

        # Stored transposed so that time runs along axis 1.
        data = np.transpose(sio.read_array(data_file))
        if win is not None:
            data = data[:, win]
        self._data[LPU] = data

        # Keep the shortest recording length so every LPU has data for
        # each visualized time step.
        num_steps = data.shape[1]
        if self._maxt is None:
            self._maxt = num_steps
        else:
            self._maxt = min(self._maxt, num_steps)

    def run(self, final_frame_name=None, dpi=300):
        '''
        Starts the visualization process. If the property out_filename is set,
        the visualization is saved as a video to the disk. If it is not
        specified, the animation will be displayed on screen.
        Please refer to documentation of add_LPU, add_plot and
        the properties of this class on how to configure the visualizer before call this
        method. An example can be found in the class doc string.

        Paramters
        ----------

        final_frame_name: str
            Optional. If specified, the final frame of the animation will be saved
            to disk.

        dpi: int
            Default(300). If final_frame_name is specified, this parameter will control
            the resolution at which the final frame is saved to disk.

        Note:
        -----
        If update_interval is set to 0 or None, it will be replaced by the
        index of the final time step. As a result, the visualizer will only
        generate the final frame.

        '''

        self._initialize()
        if not self._update_interval:
            self._update_interval = self._maxt - 1
        self._t = self._update_interval + 1
        for _ in range(self._update_interval, 
                       self._maxt, self._update_interval):
            self._update()
        if final_frame_name is not None:
            self.f.savefig(final_frame_name, dpi=dpi)
        if self.out_filename:
            self._close()

    def _set_wrapper(self, obj, name, value):
        name = name.lower()
        func = getattr(obj, 'set_'+name, None)
        if func:
            try:
                func(value, fontsize=self._fontsize, weight='bold')
            except:
                try:
                    func(value)
                except:
                    pass

    def _initialize(self):

        # Count number of plots to create:
        num_plots = 0
        for config in self._config.itervalues():
            num_plots += len(config)

        # Set default grid of plot positions:
        if not self._rows*self._cols == num_plots:
            self._cols = int(np.ceil(np.sqrt(num_plots)))
            self._rows = int(np.ceil(num_plots/float(self._cols)))
        self.f, self.axarr = plt.subplots(self._rows, self._cols,
                                          figsize=self._figsize)

        # Remove unused subplots:
        for i in xrange(num_plots, self._rows*self._cols):
            plt.delaxes(self.axarr[np.unravel_index(i, (self._rows, self._cols))])
        cnt = 0
        self.handles = []
        self.types = []
        keywds = ['handle', 'ydata', 'fmt', 'type', 'ids', 'shape'] 
        if not isinstance(self.axarr, np.ndarray):
            self.axarr = np.asarray([self.axarr])
        for LPU, configs in self._config.iteritems():
            for plt_id, config in enumerate(configs):
                ind = np.unravel_index(cnt, self.axarr.shape)
                cnt+=1

                # Some plot types require specific numbers of
                # neuron ID arrays:
                if 'type' in config:
                    if config['type'] == 'quiver':
                        assert len(config['ids'])==2
                        config['type'] = 0
                    elif config['type'] == 'hsv':
                        assert len(config['ids'])==2
                        config['type'] = 1
                    elif config['type'] == 'image':
                        assert len(config['ids'])==1
                        config['type'] = 2
                    elif config['type'] == 'waveform':
                        config['type'] = 3
                    elif config['type'] == 'raster':
                        config['type'] = 4
                    elif config['type'] == 'rate':
                        config['type'] = 5
                    else:
                        raise ValueError('Plot type not supported')
                else:
                    if str(LPU).startswith('input') or not self._graph[LPU][str(config['ids'][0])]['spiking']:
                        config['type'] = 2
                    else:
                        config['type'] = 4

                if config['type'] < 3:
                    if not 'shape' in config:

                        # XXX This can cause problems when the number
                        # of neurons is not equal to
                        # np.prod(config['shape'])
                        num_neurons = len(config['ids'][0])
                        config['shape'] = [int(np.ceil(np.sqrt(num_neurons)))]
                        config['shape'].append(int(np.ceil(num_neurons/float(config['shape'][0]))))

                if config['type'] == 0:
                    config['handle'] = self.axarr[ind].quiver(\
                               np.reshape(self._data[LPU][config['ids'][0],0],config['shape']),\
                               np.reshape(self._data[LPU][config['ids'][1],0],config['shape']))
                elif config['type'] == 1:
                    X = np.reshape(self._data[LPU][config['ids'][0],0],config['shape'])
                    Y = np.reshape(self._data[LPU][config['ids'][1],0],config['shape'])
                    V = (X**2 + Y**2)**0.5
                    H = (np.arctan2(X,Y)+np.pi)/(2*np.pi)
                    S = np.ones_like(V)
                    HSV = np.dstack((H,S,V))
                    RGB = hsv_to_rgb(HSV)
                    config['handle'] = self.axarr[ind].imshow(RGB)
                elif config['type'] == 2:
                    if 'trans' in config:
                        if config['trans'] is True:
                            to_transpose = True
                        else:
                            to_transpose = False
                    else:
                        to_transpose = False
                        config['trans'] = False

                    if to_transpose:
                        temp = self.axarr[ind].imshow(np.transpose(np.reshape(\
                                self._data[LPU][config['ids'][0],0], config['shape'])))
                    else:
                        temp = self.axarr[ind].imshow(np.reshape(\
                                self._data[LPU][config['ids'][0],0], config['shape']))



                    temp.set_clim(self._imlim)
                    temp.set_cmap(plt.cm.gist_gray)
                    config['handle'] = temp
                elif config['type'] == 3:
                    fmt = config['fmt'] if 'fmt' in config else '' 
                    self.axarr[ind].set_xlim(self._xlim)
                    self.axarr[ind].set_ylim(self._ylim)
                    if len(config['ids'][0])==1:
                        config['handle'] = self.axarr[ind].plot([0], \
                                            [self._data[LPU][config['ids'][0][0],0]], fmt)[0]
                        config['ydata'] = [self._data[LPU][config['ids'][0][0],0]]
                    else:
                        config['handle'] = self.axarr[ind].plot(self._data[LPU][config['ids'][0],0])[0]

                elif config['type'] == 4:
                    config['handle'] = self.axarr[ind]
                    config['handle'].vlines(0, 0, 0.01)
                    config['handle'].set_ylim([.5, len(config['ids'][0]) + .5])
                    config['handle'].set_ylabel('Neurons',
                                                fontsize=self._fontsize-1, weight='bold')
                    config['handle'].set_xlabel('Time (s)',fontsize=self._fontsize-1, weight='bold')
                    min_id = min(self._id_to_data_idx[LPU].keys())
                    min_idx = self._id_to_data_idx[LPU][min_id]
                    config['handle'].set_xlim([0,len(self._data[LPU][min_idx,:])*self._dt])
                    config['handle'].axes.set_yticks([])
                    config['handle'].axes.set_xticks([])
                for key in config.iterkeys():
                    if key not in keywds:
                        try:
                            self._set_wrapper(self.axarr[ind],key, config[key])
                        except:
                            pass
                        try:
                            self._set_wrapper(config['handle'],key, config[key])
                        except:
                            pass
                if config['type']<3:
                    config['handle'].axes.set_xticks([])
                    config['handle'].axes.set_yticks([])

            if self.suptitle is not None:
                self.f.suptitle(self._title, fontsize=self._fontsize+1, x=0.5,y=0.03, weight='bold')

        plt.tight_layout()

        if self.out_filename:
            self.writer = FFMpegFileWriter(fps=self.fps, codec=self.codec)

            # Use the output file to determine the name of the temporary frame
            # files so that two concurrently run visualizations don't clobber
            # each other's frames:
            self.writer.setup(self.f, self.out_filename, dpi=80,
                              frame_prefix=os.path.splitext(self.out_filename)[0]+'_')
            self.writer.frame_format = 'png'
            self.writer.grab_frame()
        else:
            self.f.show()

    def _update(self):
        dt = self._dt
        t = self._t
        for key, configs in self._config.iteritems():
            data = self._data[key]
            for config in configs:
                if config['type'] == 3:
                    if len(config['ids'][0])==1:
                        config['ydata'].extend(np.reshape(np.double(\
                                        data[config['ids'][0], \
                                                  max(0,t-self._update_interval):t]),(-1,)))
                        config['handle'].set_xdata(dt*np.arange(0, t))
                        config['handle'].set_ydata(np.asarray(config['ydata']))
                    else:
                        config['handle'].set_ydata(\
                                        data[config['ids'][0], t])

                elif config['type']==4:

                    for j, id in enumerate(config['ids'][0]):

                        # Convert neuron id to index into array of generated outputs:
                        try:
                            idx = self._id_to_data_idx[key][id]
                        except:
                            continue
                        else:
                            for time in np.where(data[idx, max(0,t-self._update_interval):t])[0]:
                                config['handle'].vlines(float(t-time)*self._dt,j+0.75, j+1.25)
                else:
                    if config['type'] == 0:
                        shape = config['shape']
                        ids = config['ids']
                        config['handle'].U = np.reshape(data[ids[0], t],shape)
                        config['handle'].V = np.reshape(data[ids[1], t],shape)
                    elif config['type']==1:
                        shape = config['shape']
                        ids = config['ids']
                        X = np.reshape(data[ids[0], t],shape)
                        Y = np.reshape(data[ids[1], t],shape)
                        V = (X**2 + Y**2)**0.5
                        H = (np.arctan2(X,Y)+np.pi)/(2*np.pi)
                        S = np.ones_like(V)
                        HSV = np.dstack((H,S,V))
                        RGB = hsv_to_rgb(HSV)
                        config['handle'].set_data(RGB)
                    elif config['type'] == 2:
                        ids = config['ids']
                        if config['trans']:
                            config['handle'].set_data(
                                np.transpose(np.reshape(data[ids[0], t], config['shape'
                            ])))
                        else:
                            config['handle'].set_data(
                                np.reshape(data[ids[0], t], config['shape']))
                    
        self.f.canvas.draw()
        if self.out_filename:
            self.writer.grab_frame()

        self._t+=self._update_interval
            
    def add_plot(self, config_dict, LPU, names=[''], shift=0):
        '''
        Add a plot to the visualizer

        Parameters
        ----------
        config_dict: dict
            A dictionary specifying the plot attributes. The attribute
            names should be the keys. The dictionary is copied, so the
            caller's object is not modified.

            The following are the plot attributes that can be specified using
            this dict.

            type - str
                This specifies the type of the plot. Has to be one of
                ['waveform', 'raster', 'image', 'hsv', 'quiver']

            ids - dict with either 1 or 2 entries
                Specifies the neuron ids from the associated LPU.
                The keys should be in [0,1] and the values
                should be a list of ids.
                For example::

                    {'ids':{0:[1,2]}}

                will plot neurons with ids 1 and 2.
                Two entries in the dictionary are needed if the plot is
                of type 'hsv' or 'quiver'.
                For example::

                    {'ids':{0:range(768),1:range(768,1536)},'type':'hsv'}

                can be used to generate an HSV plot where the hue channel is
                controlled by the angle of the vector defined by the membrane
                potentials of the neurons with ids 0-767 and 768-1535 and
                the value will be the magnitude of the same vector.

                This parameter is optional for the following cases::

                    1) The plot is associated with input signals
                       (every row of the input data is plotted).
                    2) The names parameter is specified.

                If the above doesn't hold, this attribute needs to be specified.

            shape - list or tuple with two entries
                This attribute specifies the dimensions for plots of type image,
                hsv or quiver.

            title - str
                Optional. Can be used to control the title of the plot.

            In addition to the above, any parameter supported by matplotlib
            for the particular type of plot can be specified.
            For example - 'imlim','clim','xlim','ylim' etc.

        LPU: str
            The name of the LPU associated to this plot.

        names: list
            Optional. A list of str specifying the neurons
            to plot. Can be used instead of specifying ids in the
            config_dict. The gexf file of the LPU needs to have
            the name attribute in order for this to be used.
            A single str is accepted as well.

        shift: int
            Optional. Subtracted from every id found through `names`.
        '''
        config = config_dict.copy()
        if not isinstance(names, list):
            names = [names]
        if LPU not in self._config:
            self._config[LPU] = []

        if 'ids' in config:
            # XXX should check whether the specified ids are within range
            pass
        elif str(LPU).startswith('input'):
            # Input LPUs: plot every row of the input data.
            config['ids'] = [range(0, self._data[LPU].shape[0])]
        else:
            # Resolve each requested name to the ids of the matching graph
            # nodes (node keys are the stringified ids 0..N-1).
            node = self._graph[LPU].node
            config['ids'] = {}
            for i, name in enumerate(names):
                config['ids'][i] = [nid - shift for nid in range(len(node))
                                    if node[str(nid)]['name'] == name]
        self._config[LPU].append(config)

        if 'title' not in config:
            # Guard against an empty names list before reading names[0].
            if names and names[0]:
                config['title'] = "{0} - {1}".format(str(LPU), str(names[0]))
            elif str(LPU).startswith('input_'):
                config['title'] = str(LPU).split('_', 1)[1] + ' - ' + 'Input'
            else:
                config['title'] = str(LPU)

    def _close(self):
        self.writer.finish()

    @property
    def xlim(self):
        '''
        Get or set the default x-axis limits for waveform plots.

        Can be superseded for individual plots by specifying 'xlim' in the
        config_dict for that plot. Raster plots ignore this value and set
        their x-axis limits from the length of the recorded data.

        See also
        --------
            add_plot
        '''
        return self._xlim

    @xlim.setter
    def xlim(self, value):
        self._xlim = value

    @property
    def ylim(self):
        '''
        Get or set the default y-axis limits for waveform plots.

        Can be superseded for individual plots by specifying 'ylim' in the
        config_dict for that plot. Raster plots ignore this value and set
        their y-axis limits from the number of plotted neurons.

        See also
        --------
            add_plot
        '''
        return self._ylim

    @ylim.setter
    def ylim(self, value):
        self._ylim = value

    @property
    def imlim(self):
        """Color limits given to ``set_clim`` on image plots."""
        return self._imlim

    @imlim.setter
    def imlim(self, value):
        self._imlim = value

    @property
    def out_filename(self):
        """Path of the output video file; must be a ``str`` when assigned."""
        return self._out_file

    @out_filename.setter
    def out_filename(self, value):
        assert isinstance(value, str)
        self._out_file = value

    @property
    def fps(self):
        """Frame rate passed to the movie writer; must be an ``int``."""
        return self._fps

    @fps.setter
    def fps(self, value):
        assert isinstance(value, int)
        self._fps = value

    @property
    def codec(self):
        """Codec name passed to the movie writer; must be a ``str``."""
        return self._codec

    @codec.setter
    def codec(self, value):
        assert isinstance(value, str)
        self._codec = value

    @property
    def rows(self):
        """Number of subplot rows in the figure grid."""
        return self._rows

    @rows.setter
    def rows(self, value):
        self._rows = value

    @property
    def cols(self):
        """Number of subplot columns in the figure grid."""
        return self._cols

    @cols.setter
    def cols(self, value):
        self._cols = value

    @property
    def dt(self):
        """Duration of one simulation time step (scales the time axes)."""
        return self._dt

    @dt.setter
    def dt(self, value):
        self._dt = value

    @property
    def figsize(self):
        """Figure size as a 2-tuple; anything else is rejected."""
        return self._figsize

    @figsize.setter
    def figsize(self, value):
        assert isinstance(value, tuple) and len(value) == 2
        self._figsize = value

    @property
    def fontsize(self):
        """Base font size used for axis labels and the figure title."""
        return self._fontsize

    @fontsize.setter
    def fontsize(self, value):
        self._fontsize = value

    @property
    def suptitle(self):
        """Figure-level title; ``None`` means no title is drawn."""
        return self._title

    @suptitle.setter
    def suptitle(self, value):
        self._title = value

    @property
    def update_interval(self):
        """
        Number of time steps advanced per animation frame.

        As documented for this visualizer, a value of 0 or None means the
        interval is replaced by the index of the final step, so only the
        final frame is produced; this setter itself stores the value as-is.
        """
        return self._update_interval

    @update_interval.setter
    def update_interval(self, value):
        self._update_interval = value
Пример #17
0
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegFileWriter

# numpy and pandas were used (here and in update()) but never imported.
import numpy as np
import pandas as pd

# Load the yield table; the listed columns get explicit dtypes, any other
# columns are left to pandas' inference.
rawdata = pd.read_csv(
    'D:/code/python/workspace/LTEDataVis/src/data/bbkyields1968.csv',
    dtype={
        'year': int,
        'plot': str,
        'grain': float,
        'colour': str
    })
# Distinct years, one animation frame per year.
years = rawdata['year'].unique()
plt.rcdefaults()
fig, ax = plt.subplots()

# Frames are written through FFmpeg to test6.mp4 at 100 dpi.
writer = FFMpegFileWriter()
writer.setup(fig, "test6.mp4", 100)


def init():
    """Animation init hook; intentionally does nothing and returns None."""


def update(i):
    """
    Prepare animation frame ``i``.

    Selects the rows of the module-level ``rawdata`` for the ``i``-th entry
    of ``years`` and orders them by grain yield, largest first. The body
    appears truncated in this example: ``ypos`` is computed but not used.
    """
    current_year = years[i]
    in_year = rawdata[rawdata['year'] == current_year]
    yeardata = in_year.sort_values(by='grain', ascending=False)

    print(yeardata)

    treatments = yeardata['treatment']
    ypos = np.arange(len(treatments))
Пример #18
0
    def _initialize(self):
        """
        Create the figure, one subplot per configured plot, and the video writer.

        For every plot config registered via add_plot, this resolves the
        string 'type' into an integer code (0 quiver, 1 hsv, 2 image,
        3 waveform, 4 raster, 5 rate, 6 dome), creates the initial artist
        from time step 0 of the data and stores it in config['handle'].
        If out_filename is set, a movie writer is created and the first frame
        is grabbed; otherwise the figure is shown.

        Raises
        ------
        ValueError
            If a config has an unsupported 'type' string.
        RuntimeError
            If the requested ffmpeg/avconv executable cannot be found.
        """

        # Count number of plots to create:
        num_plots = 0
        for config in self._config.itervalues():
            num_plots += len(config)

        # Set default grid of plot positions. A user-specified rows x cols
        # grid is kept only if it exactly matches the number of plots;
        # otherwise a near-square grid is chosen.
        if not self._rows * self._cols == num_plots:
            self._cols = int(np.ceil(np.sqrt(num_plots)))
            self._rows = int(np.ceil(num_plots / float(self._cols)))
        self.f, self.axarr = plt.subplots(self._rows,
                                          self._cols,
                                          figsize=self._figsize)

        # Remove unused subplots:
        for i in xrange(num_plots, self._rows * self._cols):
            plt.delaxes(self.axarr[np.unravel_index(i,
                                                    (self._rows, self._cols))])
        cnt = 0
        # NOTE(review): handles/types are reset here but not filled in this method.
        self.handles = []
        self.types = []
        # Config keys that are internal and must not be forwarded to
        # matplotlib via _set_wrapper below.
        keywds = ['handle', 'ydata', 'fmt', 'type', 'ids', 'shape', 'norm']
        # Precompute a hemisphere ("dome") surface grid used by type-6 plots:
        # U is the polar angle in [0, pi/2], V the azimuth in [0, 2*pi].
        # TODO: Irregular grid in U will make the plot better
        U, V = np.mgrid[0:np.pi / 2:complex(0, 60), 0:2 * np.pi:complex(0, 60)]
        X = np.cos(V) * np.sin(U)
        Y = np.sin(V) * np.sin(U)
        Z = np.cos(U)
        self._dome_pos_flat = (X.flatten(), Y.flatten(), Z.flatten())
        self._dome_pos = (X, Y, Z)
        self._dome_arr_shape = X.shape
        # plt.subplots returns a bare Axes for a 1x1 grid; wrap it so it can
        # be indexed like the array returned for larger grids.
        if not isinstance(self.axarr, np.ndarray):
            self.axarr = np.asarray([self.axarr])
        for LPU, configs in self._config.iteritems():
            for plt_id, config in enumerate(configs):
                ind = np.unravel_index(cnt, self.axarr.shape)
                cnt += 1

                # Some plot types require specific numbers of
                # neuron ID arrays:
                if 'type' in config:
                    if config['type'] == 'quiver':
                        assert len(config['ids']) == 2
                        config['type'] = 0
                    elif config['type'] == 'hsv':
                        assert len(config['ids']) == 2
                        config['type'] = 1
                    elif config['type'] == 'image':
                        assert len(config['ids']) == 1
                        config['type'] = 2
                    elif config['type'] == 'waveform':
                        config['type'] = 3
                    elif config['type'] == 'raster':
                        config['type'] = 4
                    elif config['type'] == 'rate':
                        config['type'] = 5
                    elif config['type'] == 'dome':
                        config['type'] = 6
                    else:
                        raise ValueError('Plot type not supported')
                else:
                    # No explicit type: choose image or raster automatically.
                    # NOTE(review): with `and`, every non-input LPU becomes a
                    # raster plot regardless of its 'spiking' attribute, and
                    # only non-spiking input LPUs become images. Another
                    # version of this method in this file uses `or` here
                    # (input, or non-spiking -> image). Confirm which
                    # behavior is intended.
                    if str(LPU).startswith(
                            'input') and not self._graph[LPU].node[str(
                                config['ids'][0][0])]['spiking']:
                        config['type'] = 2
                    else:
                        config['type'] = 4

                if config['type'] < 3:
                    if not 'shape' in config:

                        # XXX This can cause problems when the number
                        # of neurons is not equal to
                        # np.prod(config['shape'])
                        num_neurons = len(config['ids'][0])
                        config['shape'] = [int(np.ceil(np.sqrt(num_neurons)))]
                        config['shape'].append(
                            int(
                                np.ceil(num_neurons /
                                        float(config['shape'][0]))))

                if config['type'] == 0:
                    # Quiver: ids[0] gives the U components, ids[1] the V
                    # components, both taken at time step 0.
                    config['handle'] = self.axarr[ind].quiver(\
                               np.reshape(self._data[LPU][config['ids'][0],0],config['shape']),\
                               np.reshape(self._data[LPU][config['ids'][1],0],config['shape']))
                elif config['type'] == 1:
                    # HSV: hue encodes the angle and value the magnitude of
                    # the (ids[0], ids[1]) vector at time step 0.
                    X = np.reshape(self._data[LPU][config['ids'][0], 0],
                                   config['shape'])
                    Y = np.reshape(self._data[LPU][config['ids'][1], 0],
                                   config['shape'])
                    V = (X**2 + Y**2)**0.5
                    H = (np.arctan2(X, Y) + np.pi) / (2 * np.pi)
                    S = np.ones_like(V)
                    HSV = np.dstack((H, S, V))
                    RGB = hsv_to_rgb(HSV)
                    config['handle'] = self.axarr[ind].imshow(RGB)
                elif config['type'] == 2:
                    # Image: only a literal True in config['trans'] enables
                    # transposition; a missing key is recorded as False so
                    # later updates can read it.
                    if 'trans' in config:
                        if config['trans'] is True:
                            to_transpose = True
                        else:
                            to_transpose = False
                    else:
                        to_transpose = False
                        config['trans'] = False

                    if to_transpose:
                        temp = self.axarr[ind].imshow(np.transpose(np.reshape(\
                                self._data[LPU][config['ids'][0],0], config['shape'])))
                    else:
                        temp = self.axarr[ind].imshow(np.reshape(\
                                self._data[LPU][config['ids'][0],0], config['shape']))

                    temp.set_clim(self._imlim)
                    temp.set_cmap(plt.cm.gist_gray)
                    config['handle'] = temp
                elif config['type'] == 3:
                    # Waveform: uses the global xlim/ylim. A single neuron
                    # starts a trace whose samples are accumulated in
                    # config['ydata']; several neurons are drawn as one line
                    # over their ids.
                    fmt = config['fmt'] if 'fmt' in config else ''
                    self.axarr[ind].set_xlim(self._xlim)
                    self.axarr[ind].set_ylim(self._ylim)
                    if len(config['ids'][0]) == 1:
                        config['handle'] = self.axarr[ind].plot([0], \
                                            [self._data[LPU][config['ids'][0][0],0]], fmt)[0]
                        config['ydata'] = [
                            self._data[LPU][config['ids'][0][0], 0]
                        ]
                    else:
                        config['handle'] = self.axarr[ind].plot(
                            self._data[LPU][config['ids'][0], 0])[0]

                elif config['type'] == 4:
                    # Raster: the handle is the Axes itself. One row per
                    # neuron; the x-axis spans the full recording, derived
                    # from the row of the smallest neuron id.
                    config['handle'] = self.axarr[ind]
                    config['handle'].vlines(0, 0, 0.01)
                    config['handle'].set_ylim([.5, len(config['ids'][0]) + .5])
                    config['handle'].set_ylabel('Neurons',
                                                fontsize=self._fontsize - 1,
                                                weight='bold')
                    config['handle'].set_xlabel('Time (s)',
                                                fontsize=self._fontsize - 1,
                                                weight='bold')
                    min_id = min(self._id_to_data_idx[LPU].keys())
                    min_idx = self._id_to_data_idx[LPU][min_id]
                    config['handle'].set_xlim(
                        [0, len(self._data[LPU][min_idx, :]) * self._dt])
                    config['handle'].axes.set_yticks([])
                    config['handle'].axes.set_xticks([])
                elif config['type'] == 6:
                    # Dome: replace this grid cell with a 3-D axes. cnt has
                    # already been incremented, so it is the 1-based subplot
                    # index. The original 2-D axes is not removed here; only
                    # its ticks are cleared.
                    self.axarr[ind].axes.set_yticks([])
                    self.axarr[ind].axes.set_xticks([])
                    self.axarr[ind] = self.f.add_subplot(self._rows,
                                                         self._cols,
                                                         cnt,
                                                         projection='3d')
                    config['handle'] = self.axarr[ind]
                    config['handle'].axes.set_yticks([])
                    config['handle'].axes.set_xticks([])
                    config['handle'].xaxis.set_ticks([])
                    config['handle'].yaxis.set_ticks([])
                    config['handle'].zaxis.set_ticks([])
                    # Color normalization: default range [-70, 0]; 'auto'
                    # uses the data range, skipping the first 100 steps when
                    # more than 100 are available.
                    if 'norm' not in config.keys():
                        config['norm'] = Normalize(vmin=-70, vmax=0, clip=True)
                    elif config['norm'] == 'auto':
                        if self._data[LPU].shape[1] > 100:
                            config['norm'] = Normalize(
                                vmin=np.min(self._data[LPU][config['ids'][0],
                                                            100:]),
                                vmax=np.max(self._data[LPU][config['ids'][0],
                                                            100:]),
                                clip=True)
                        else:
                            config['norm'] = Normalize(
                                vmin=np.min(
                                    self._data[LPU][config['ids'][0], :]),
                                vmax=np.max(
                                    self._data[LPU][config['ids'][0], :]),
                                clip=True)

                    # Neuron positions on the dome come from the 'lat'/'long'
                    # node attributes; for input LPUs only nodes marked
                    # 'extern' are used.
                    node_dict = self._graph[LPU].node
                    if str(LPU).startswith('input'):
                        latpositions = np.asarray([ node_dict[str(nid)]['lat'] \
                                                    for nid in range(len(node_dict)) \
                                                    if node_dict[str(nid)]['extern'] ])
                        longpositions = np.asarray([ node_dict[str(nid)]['long'] \
                                                     for nid in range(len(node_dict)) \
                                                     if node_dict[str(nid)]['extern'] ])
                    else:
                        latpositions = np.asarray([
                            node_dict[str(nid)]['lat']
                            for nid in config['ids'][0]
                        ])
                        longpositions = np.asarray([
                            node_dict[str(nid)]['long']
                            for nid in config['ids'][0]
                        ])
                    # Same spherical-to-Cartesian mapping as the dome grid.
                    xx = np.cos(longpositions) * np.sin(latpositions)
                    yy = np.sin(longpositions) * np.sin(latpositions)
                    zz = np.cos(latpositions)
                    config['positions'] = (xx, yy, zz)
                    # Interpolate the step-0 values onto the dome grid with
                    # nearest-neighbour lookup, normalize them, and turn them
                    # into opaque gray RGBA face colors.
                    colors = griddata(config['positions'],
                                      self._data[LPU][config['ids'][0],
                                                      0], self._dome_pos_flat,
                                      'nearest').reshape(self._dome_arr_shape)
                    colors = config['norm'](colors).data
                    colors = np.tile(
                        np.reshape(colors, [
                            self._dome_arr_shape[0], self._dome_arr_shape[1], 1
                        ]), [1, 1, 4])
                    colors[:, :, 3] = 1.0
                    config['handle'].plot_surface(self._dome_pos[0],
                                                  self._dome_pos[1],
                                                  self._dome_pos[2],
                                                  rstride=1,
                                                  cstride=1,
                                                  facecolors=colors,
                                                  antialiased=False,
                                                  shade=False)

                # Forward every non-internal config key (e.g. 'title') to
                # both the axes and the plot handle. Failures are ignored on
                # purpose, presumably because a key may apply to only one of
                # the two targets.
                for key in config.iterkeys():
                    if key not in keywds:
                        try:
                            self._set_wrapper(self.axarr[ind], key,
                                              config[key])
                        except:
                            pass
                        try:
                            self._set_wrapper(config['handle'], key,
                                              config[key])
                        except:
                            pass

                if config['type'] < 3:
                    config['handle'].axes.set_xticks([])
                    config['handle'].axes.set_yticks([])

            if self.suptitle is not None:
                self.f.suptitle(self._title,
                                fontsize=self._fontsize + 1,
                                x=0.5,
                                y=0.03,
                                weight='bold')

        plt.tight_layout()

        if self.out_filename:
            # Writer selection: self.FFMpeg is None -> prefer ffmpeg, fall
            # back to avconv; truthy -> ffmpeg only; otherwise -> avconv only.
            # `which` presumably returns a falsy value when the executable
            # configured in rcParams is not found.
            if self.FFMpeg is None:
                if which(matplotlib.rcParams['animation.ffmpeg_path']):
                    self.writer = FFMpegFileWriter(fps=self.fps,
                                                   codec=self.codec)
                elif which(matplotlib.rcParams['animation.avconv_path']):
                    self.writer = AVConvFileWriter(fps=self.fps,
                                                   codec=self.codec)
                else:
                    raise RuntimeError('cannot find ffmpeg or avconv')
            elif self.FFMpeg:
                if which(matplotlib.rcParams['animation.ffmpeg_path']):
                    self.writer = FFMpegFileWriter(fps=self.fps,
                                                   codec=self.codec)
                else:
                    raise RuntimeError('cannot find ffmpeg')
            else:
                if which(matplotlib.rcParams['animation.avconv_path']):
                    self.writer = AVConvFileWriter(fps=self.fps,
                                                   codec=self.codec)
                else:
                    raise RuntimeError('cannot find avconv')

            # Use the output file to determine the name of the temporary frame
            # files so that two concurrently run visualizations don't clobber
            # each other's frames:
            self.writer.setup(
                self.f,
                self.out_filename,
                dpi=80,
                frame_prefix=os.path.splitext(self.out_filename)[0] + '_')
            self.writer.frame_format = 'png'
            self.writer.grab_frame()
        else:
            self.f.show()
Пример #19
0
class visualizer(object):
    """
    Visualize the output produced by LPU models.

    Plots are drawn with matplotlib. If ``out_filename`` is set, every frame
    is also written to a movie file through ``FFMpegFileWriter``; otherwise
    the figure is shown on screen.

    Examples
    --------
        import neurokernel.LPU.utils.visualizer as vis
        V = vis.visualizer()
        config1 = {}
        config1['type'] = 'image'
        config1['shape'] = [32,24]
        config1['clim'] = [-0.6,0.5]
        config2 = config1.copy()
        config2['clim'] = [-0.55,-0.45]
        V.add_LPU('lamina_output.h5', 'lamina.gexf.gz','lamina')
        V.add_plot(config1, 'lamina', 'R1')
        V.add_plot(config2, 'lamina', 'L1')
        V.update_interval = 50
        V.out_filename = 'test.avi'
        V.run()
    """

    # Supported plot-type names and the integer codes stored in
    # config['type'] once a plot has been initialized.
    _PLOT_TYPES = {'quiver': 0, 'hsv': 1, 'image': 2,
                   'waveform': 3, 'raster': 4, 'rate': 5}

    # Plot types that need an exact number of neuron ID arrays.
    _REQUIRED_ID_ARRAYS = {'quiver': 2, 'hsv': 2, 'image': 1}

    # Config keys consumed by the visualizer itself. Every other key is
    # forwarded to a matching matplotlib ``set_<key>`` method.
    _RESERVED_KEYS = ('handle', 'ydata', 'fmt', 'type', 'ids', 'shape',
                      'trans', 'imlim')

    def __init__(self):
        self._xlim = [0, 1]
        self._ylim = [-1, 1]
        self._imlim = [-1, 1]
        self._update_interval = 50
        self._out_file = None
        self._fps = 5
        self._codec = 'libtheora'
        self._config = OrderedDict()
        self._rows = 0
        self._cols = 0
        self._figsize = (16, 9)
        self._fontsize = 18
        self._t = 1
        self._dt = 1
        self._data = {}
        self._graph = {}
        self._id_to_data_idx = {}
        self._maxt = None
        self._title = None

    def add_LPU(self, data_file, gexf_file=None, LPU=None, win=None):
        '''
        Add data associated with a specific LPU to a visualization.

        To add a plot containing neurons from a particular LPU, the LPU
        must first be added to the visualization with this function.
        Outputs from multiple LPUs can be visualized with the same
        visualizer object.

        Parameters
        ----------
        data_file: str
            Location of the h5 file generated by neurokernel containing
            the output of the LPU. The array read from it is stored
            transposed, so that its columns index time steps.

        gexf_file: str
            Location of the gexf file describing the LPU.
            If not specified, the h5 file is assumed to contain input.

        LPU: str
            Name of the LPU. It is used as the identifier when adding plots.
            For input signals, the name of the LPU is prefixed with
            'input_'. For example::

                V.add_LPU('vision_in.h5', LPU='vision')

            creates the LPU identifier 'input_vision'. A plot depicting
            this input can then be added with::

                V.add_plot({'type': 'image', 'imlim': [-0.5, 0.5]},
                           LPU='input_vision')

            If no name is given, one is derived from the number of LPUs
            already added.

        win: slice/list
            Can be used to limit the visualization to a specific time window.
        '''
        if gexf_file:
            if not LPU:
                # Choose the default name first, so that the graph, the id
                # map and the data array are all stored under the same key.
                LPU = len(self._data)
            self._graph[LPU] = nx.read_gexf(gexf_file)

            # Map ids of spiking neurons to rows of the output data array.
            # nodes(data=True) works on both networkx 1.x and 2.x.
            spiking_ids = sorted(int(n) for n, attrs in
                                 self._graph[LPU].nodes(data=True)
                                 if attrs['spiking'])
            self._id_to_data_idx[LPU] = {m: i for i, m in
                                         enumerate(spiking_ids)}
        elif LPU:
            LPU = 'input_' + str(LPU)
        else:
            LPU = 'input_' + str(len(self._data))

        self._data[LPU] = np.transpose(sio.read_array(data_file))
        if win is not None:
            self._data[LPU] = self._data[LPU][:, win]

        # The animation can only run as long as the shortest data set.
        if self._maxt:
            self._maxt = min(self._maxt, self._data[LPU].shape[1])
        else:
            self._maxt = self._data[LPU].shape[1]

    def run(self, final_frame_name=None, dpi=300):
        '''
        Start the visualization process.

        If the property out_filename is set, the visualization is saved to
        disk as a video. Otherwise the animation is displayed on screen.
        Refer to the documentation of add_LPU, add_plot and the properties
        of this class to configure the visualizer before calling this
        method. The class docstring contains an example.

        Parameters
        ----------
        final_frame_name: str
            Optional. If specified, the final frame of the animation is
            saved to disk under this name.

        dpi: int
            Default 300. If final_frame_name is specified, this parameter
            controls the resolution of the saved final frame.

        Note
        ----
        If update_interval is 0 or None, it is replaced by the index of the
        final time step (at least 1). As a result, the visualizer only
        generates the final frame.
        '''
        self._initialize()
        if not self._update_interval:
            # max() keeps the range() step below positive when the data
            # holds a single time step.
            self._update_interval = max(self._maxt - 1, 1)
        self._t = self._update_interval + 1
        for _ in range(self._update_interval, self._maxt,
                       self._update_interval):
            self._update()
        if final_frame_name is not None:
            self.f.savefig(final_frame_name, dpi=dpi)
        if self.out_filename:
            self._close()

    def _set_wrapper(self, obj, name, value):
        '''
        Call ``obj.set_<name>(value)`` if such a setter exists.

        The setter is first called with the visualizer's bold font settings,
        which suits titles and labels. If it rejects those keyword arguments,
        it is called again with the value alone. Failures are ignored, so a
        config key only affects the objects that understand it. ``obj`` may
        be None, in which case nothing happens.
        '''
        try:
            attr = 'set_' + name.lower()
        except AttributeError:
            # Non-string config keys cannot name a setter.
            return
        func = getattr(obj, attr, None)
        if func is None:
            return
        try:
            func(value, fontsize=self._fontsize, weight='bold')
        except Exception:
            try:
                func(value)
            except Exception:
                pass

    def _initialize(self):
        '''
        Create the figure and one subplot per configured plot, and draw the
        first frame of each plot.

        If out_filename is set, the movie writer is also created. Otherwise
        the figure is shown.
        '''
        num_plots = sum(len(configs) for configs in self._config.values())

        # Fall back to a near-square grid unless the user-supplied grid
        # has exactly one cell per plot.
        if self._rows * self._cols != num_plots:
            self._cols = int(np.ceil(np.sqrt(num_plots)))
            self._rows = int(np.ceil(num_plots / float(self._cols)))
        self.f, self.axarr = plt.subplots(self._rows,
                                          self._cols,
                                          figsize=self._figsize)

        # plt.subplots returns a bare Axes object for a 1x1 grid.
        if not isinstance(self.axarr, np.ndarray):
            self.axarr = np.asarray([self.axarr])

        # Remove unused subplots. Flat indexing works for 1-D and 2-D grids.
        for i in range(num_plots, self._rows * self._cols):
            self.f.delaxes(self.axarr.flat[i])

        self.handles = []
        self.types = []
        cnt = 0
        for LPU, configs in self._config.items():
            for config in configs:
                self._init_plot(self.axarr.flat[cnt], LPU, config)
                cnt += 1

        if self.suptitle is not None:
            self.f.suptitle(self._title,
                            fontsize=self._fontsize + 1,
                            x=0.5,
                            y=0.03,
                            weight='bold')

        plt.tight_layout()

        if self.out_filename:
            self._setup_writer()
        else:
            self.f.show()

    def _resolve_plot_type(self, LPU, config):
        '''
        Return the integer plot-type code requested by ``config``.

        Type names are matched case-insensitively. A code that was already
        resolved by an earlier run() is returned unchanged. Without an
        explicit type, input LPUs and non-spiking neurons are drawn as
        images, and spiking neurons as rasters.

        Raises ValueError for an unsupported type. Raises AssertionError if
        the number of ID arrays does not fit the type.
        '''
        if 'type' not in config:
            if str(LPU).startswith('input'):
                return 2
            # NOTE(review): assumes the ids are graph node ids, i.e. the
            # plot was added with shift == 0. Confirm for shifted plots.
            first_id = config['ids'][0][0]
            attrs = self._node_view(self._graph[LPU])[str(first_id)]
            return 2 if not attrs['spiking'] else 4

        kind = config['type']
        if hasattr(kind, 'lower'):
            name = kind.lower()
            if name not in self._PLOT_TYPES:
                raise ValueError('Plot type not supported')
            required = self._REQUIRED_ID_ARRAYS.get(name)
            if required is not None:
                assert len(config['ids']) == required
            return self._PLOT_TYPES[name]
        if kind in self._PLOT_TYPES.values():
            return kind
        raise ValueError('Plot type not supported')

    @staticmethod
    def _node_view(graph):
        '''
        Return a mapping from node key to the node-attribute dict of ``graph``.

        networkx 1.x exposes this mapping as ``graph.node`` (there,
        ``graph.nodes`` is a plain method). networkx 2.x provides it as
        ``graph.nodes``.
        '''
        nodes = graph.nodes
        if not hasattr(nodes, '__getitem__'):
            nodes = graph.node
        return nodes

    def _init_plot(self, ax, LPU, config):
        '''Resolve the plot type of ``config`` and draw its first frame on ``ax``.'''
        config['type'] = self._resolve_plot_type(LPU, config)
        kind = config['type']
        data = self._data[LPU]
        ids = config['ids']

        if kind < 3 and 'shape' not in config:
            # XXX This can cause problems when the number of neurons is
            # not equal to np.prod(config['shape'])
            num_neurons = len(ids[0])
            rows = int(np.ceil(np.sqrt(num_neurons)))
            config['shape'] = [rows,
                               int(np.ceil(num_neurons / float(rows)))]

        if kind == 0:
            config['handle'] = ax.quiver(
                np.reshape(data[ids[0], 0], config['shape']),
                np.reshape(data[ids[1], 0], config['shape']))
        elif kind == 1:
            X = np.reshape(data[ids[0], 0], config['shape'])
            Y = np.reshape(data[ids[1], 0], config['shape'])
            config['handle'] = ax.imshow(self._vectors_to_rgb(X, Y))
        elif kind == 2:
            # Normalize 'trans' so that the first frame and later updates
            # agree on whether to transpose the image.
            config['trans'] = bool(config.get('trans', False))
            image = ax.imshow(self._image_frame(data[ids[0], 0], config))
            # A per-plot 'imlim' overrides the visualizer-wide imlim.
            image.set_clim(config.get('imlim', self._imlim))
            image.set_cmap(plt.cm.gist_gray)
            config['handle'] = image
        elif kind == 3:
            fmt = config.get('fmt', '')
            ax.set_xlim(self._xlim)
            ax.set_ylim(self._ylim)
            if len(ids[0]) == 1:
                # A single neuron is drawn as a trace that grows over time.
                first = data[ids[0][0], 0]
                config['handle'] = ax.plot([0], [first], fmt)[0]
                config['ydata'] = [first]
            else:
                # Several neurons are drawn as one profile across neurons.
                config['handle'] = ax.plot(data[ids[0], 0])[0]
        elif kind == 4:
            config['handle'] = ax
            ax.vlines(0, 0, 0.01)
            ax.set_ylim([.5, len(ids[0]) + .5])
            ax.set_ylabel('Neurons', fontsize=self._fontsize - 1,
                          weight='bold')
            ax.set_xlabel('Time (s)', fontsize=self._fontsize - 1,
                          weight='bold')
            # Every row of the data array spans all recorded time steps.
            ax.set_xlim([0, data.shape[1] * self._dt])
            ax.set_yticks([])
            ax.set_xticks([])
        # NOTE: 'rate' (5) is accepted but currently draws nothing.

        for key, value in config.items():
            if key not in self._RESERVED_KEYS:
                self._set_wrapper(ax, key, value)
                self._set_wrapper(config.get('handle'), key, value)

        if kind < 3:
            config['handle'].axes.set_xticks([])
            config['handle'].axes.set_yticks([])

    def _setup_writer(self):
        '''Create the movie writer for out_filename and record the first frame.'''
        self.writer = FFMpegFileWriter(fps=self.fps, codec=self.codec)

        # Use the output file to determine the name of the temporary frame
        # files so that two concurrently run visualizations don't clobber
        # each other's frames:
        self.writer.setup(
            self.f,
            self.out_filename,
            dpi=80,
            frame_prefix=os.path.splitext(self.out_filename)[0] + '_')
        self.writer.frame_format = 'png'
        self.writer.grab_frame()

    @staticmethod
    def _vectors_to_rgb(X, Y):
        '''
        Encode the vector field (X, Y) as an RGB image.

        The hue is the vector angle, the saturation is 1 and the value is
        the vector magnitude.
        '''
        V = (X**2 + Y**2)**0.5
        H = (np.arctan2(X, Y) + np.pi) / (2 * np.pi)
        S = np.ones_like(V)
        return hsv_to_rgb(np.dstack((H, S, V)))

    @staticmethod
    def _image_frame(values, config):
        '''Reshape ``values`` to config['shape'], transposing if config['trans'].'''
        frame = np.reshape(values, config['shape'])
        return np.transpose(frame) if config['trans'] else frame

    def _frame_index(self, data):
        '''
        Return the column of ``data`` shown by point-in-time plots.

        self._t is the exclusive end of the current update window. On the
        final frame it can equal the number of samples, so it is clamped to
        the last valid column.
        '''
        return min(self._t, data.shape[1] - 1)

    def _raster_times(self, samples, start):
        '''
        Return the times of the non-zero entries of ``samples``.

        ``samples`` is a window of one neuron's output that begins at time
        step ``start``. Times are absolute step indices multiplied by dt.
        '''
        return [float(start + k) * self._dt for k in np.where(samples)[0]]

    def _update(self):
        '''
        Redraw every plot for the window ending at step self._t.

        Grabs a movie frame if out_filename is set, then advances self._t
        by update_interval.
        '''
        dt = self._dt
        t = self._t
        start = max(0, t - self._update_interval)
        for LPU, configs in self._config.items():
            data = self._data[LPU]
            t_idx = self._frame_index(data)
            for config in configs:
                kind = config['type']
                ids = config['ids']
                if kind == 3:
                    if len(ids[0]) == 1:
                        config['ydata'].extend(np.reshape(
                            np.double(data[ids[0], start:t]), (-1,)))
                        config['handle'].set_xdata(dt * np.arange(0, t))
                        config['handle'].set_ydata(
                            np.asarray(config['ydata']))
                    else:
                        config['handle'].set_ydata(data[ids[0], t_idx])
                elif kind == 4:
                    id_map = self._id_to_data_idx.get(LPU, {})
                    for row, nid in enumerate(ids[0]):
                        # Convert neuron id to index into array of outputs:
                        idx = id_map.get(nid)
                        if idx is None:
                            continue
                        for when in self._raster_times(data[idx, start:t],
                                                       start):
                            config['handle'].vlines(when, row + 0.75,
                                                    row + 1.25)
                elif kind == 0:
                    shape = config['shape']
                    config['handle'].set_UVC(
                        np.reshape(data[ids[0], t_idx], shape),
                        np.reshape(data[ids[1], t_idx], shape))
                elif kind == 1:
                    shape = config['shape']
                    X = np.reshape(data[ids[0], t_idx], shape)
                    Y = np.reshape(data[ids[1], t_idx], shape)
                    config['handle'].set_data(self._vectors_to_rgb(X, Y))
                elif kind == 2:
                    config['handle'].set_data(
                        self._image_frame(data[ids[0], t_idx], config))

        self.f.canvas.draw()
        if self.out_filename:
            self.writer.grab_frame()

        self._t += self._update_interval

    def add_plot(self, config_dict, LPU, names=None, shift=0):
        '''
        Add a plot to the visualizer.

        Parameters
        ----------
        config_dict: dict
            A dictionary specifying the plot attributes. The attribute
            names are the keys. The dict is copied, not modified.

            The following plot attributes can be specified in this dict.

            type - str
                The type of the plot (case-insensitive). Has to be one of
                ['waveform', 'raster', 'image', 'hsv', 'quiver', 'rate'].
                NOTE: 'rate' is accepted but currently draws nothing.

            ids - dict with either 1 or 2 entries
                Specifies the neuron ids from the associated LPU.
                The keys should be in [0, 1] and the values should be
                lists of ids. For example::

                    {'ids': {0: [1, 2]}}

                plots neurons with ids 1 and 2.
                Two entries are needed if the plot is of type 'hsv' or
                'quiver'. For example::

                    {'ids': {0: range(768), 1: range(768, 1536)},
                     'type': 'hsv'}

                generates an HSV plot. Its hue is the angle of the vector
                formed by the outputs of neurons [0, 768) and [768, 1536),
                and its value is the magnitude of that vector.

                This parameter is optional in the following cases::

                    1) The plot is associated with input signals.
                    2) The names parameter is specified.

                Otherwise this attribute must be specified.

            shape - list or tuple with two entries
                The dimensions of plots of type image, hsv or quiver.

            trans - bool
                Optional. Transposes image plots.

            imlim - list with two entries
                Optional. Color limits of an image plot. Overrides the
                visualizer-wide imlim property.

            title - str
                Optional. The title of the plot.

            In addition to the above, any attribute that has a matplotlib
            ``set_<name>`` method on the axes or on the plot object can be
            specified, e.g. 'clim', 'xlim', 'ylim'.

        LPU: str
            The name of the LPU associated with this plot.

        names: list
            Optional. A list of str naming the neurons to plot. Can be used
            instead of specifying ids in config_dict. The gexf file of the
            LPU must define the 'name' node attribute for this to work.

        shift: int
            Optional. Subtracted from every node id selected via names.
        '''
        if names is None:
            names = ['']
        elif not isinstance(names, list):
            names = [names]
        config = config_dict.copy()
        self._config.setdefault(LPU, [])
        if 'ids' in config:
            # XXX should check whether the specified ids are within range
            pass
        elif str(LPU).startswith('input'):
            config['ids'] = [range(0, self._data[LPU].shape[0])]
        else:
            # NOTE(review): assumes node keys are '0' .. str(N-1).
            nodes = self._node_view(self._graph[LPU])
            config['ids'] = {}
            for i, name in enumerate(names):
                config['ids'][i] = [nid - shift for nid in range(len(nodes))
                                    if nodes[str(nid)]['name'] == name]
        self._config[LPU].append(config)

        if 'title' not in config:
            if names and names[0]:
                config['title'] = "{0} - {1}".format(str(LPU), str(names[0]))
            elif str(LPU).startswith('input_'):
                config['title'] = LPU.split('_', 1)[1] + ' - ' + 'Input'
            else:
                config['title'] = str(LPU)

    def _close(self):
        '''Finish writing the movie file.'''
        self.writer.finish()

    @property
    def xlim(self):
        '''
        Get or set the limits of the x-axis for all waveform plots.

        Can be superseded for an individual plot by specifying xlim in the
        config_dict of that plot.

        See also
        --------
            add_plot
        '''
        return self._xlim

    @xlim.setter
    def xlim(self, value):
        self._xlim = value

    @property
    def ylim(self):
        '''
        Get or set the limits of the y-axis for all waveform plots.

        Can be superseded for an individual plot by specifying ylim in the
        config_dict of that plot.

        See also
        --------
            add_plot
        '''
        return self._ylim

    @ylim.setter
    def ylim(self, value):
        self._ylim = value

    @property
    def imlim(self):
        '''Get or set the default color limits of image plots.'''
        return self._imlim

    @imlim.setter
    def imlim(self, value):
        self._imlim = value

    @property
    def out_filename(self):
        '''Get or set the movie file name. If unset, plots are shown on screen.'''
        return self._out_file

    @out_filename.setter
    def out_filename(self, value):
        assert (isinstance(value, str))
        self._out_file = value

    @property
    def fps(self):
        '''Get or set the frame rate of the saved movie.'''
        return self._fps

    @fps.setter
    def fps(self, value):
        assert (isinstance(value, int))
        self._fps = value

    @property
    def codec(self):
        '''Get or set the codec passed to the movie writer.'''
        return self._codec

    @codec.setter
    def codec(self, value):
        assert (isinstance(value, str))
        self._codec = value

    @property
    def rows(self):
        '''Get or set the number of subplot rows.'''
        return self._rows

    @rows.setter
    def rows(self, value):
        self._rows = value

    @property
    def cols(self):
        '''Get or set the number of subplot columns.'''
        return self._cols

    @cols.setter
    def cols(self, value):
        self._cols = value

    @property
    def dt(self):
        '''Get or set the duration of one time step, used to scale time axes.'''
        return self._dt

    @dt.setter
    def dt(self, value):
        self._dt = value

    @property
    def figsize(self):
        '''Get or set the figure size as a (width, height) tuple.'''
        return self._figsize

    @figsize.setter
    def figsize(self, value):
        assert (isinstance(value, tuple) and len(value) == 2)
        self._figsize = value

    @property
    def fontsize(self):
        '''Get or set the base font size of titles and labels.'''
        return self._fontsize

    @fontsize.setter
    def fontsize(self, value):
        self._fontsize = value

    @property
    def suptitle(self):
        '''Get or set the title of the whole figure.'''
        return self._title

    @suptitle.setter
    def suptitle(self, value):
        self._title = value

    @property
    def update_interval(self):
        """
        Get or set the update interval (in time steps) of the animation.

        If the value is 0 or None, run() replaces it with the index of the
        final step. As a consequence, only the final frame is generated.
        """
        return self._update_interval

    @update_interval.setter
    def update_interval(self, value):
        self._update_interval = value
Пример #20
0
  def setup(self, fig, outfile, dpi):
    """Prepare FFMpegFileWriter to render *fig* into *outfile* at *dpi*.

    Delegates to FFMpegFileWriter.setup. This writer's own temp_prefix is
    passed as the frame-file prefix, together with its clear_temp flag.
    NOTE(review): clear_temp is only accepted by older matplotlib releases;
    confirm it against the installed version.
    """
    prefix, clear = self.temp_prefix, self.clear_temp
    return FFMpegFileWriter.setup(self, fig, outfile, dpi,
                                  frame_prefix=prefix, clear_temp=clear)