コード例 #1
0
ファイル: animations.py プロジェクト: MalloryDazza/NNGT
    def __init__(self,
                 source,
                 network,
                 resolution=1,
                 start=0.,
                 timewindow=None,
                 trace=5.,
                 show_spikes=False,
                 sort_neurons=None,
                 decimate_connections=False,
                 interval=50,
                 repeat=True,
                 active_size=None,
                 **kwargs):
        '''
        Generate a SubplotAnimation instance to plot a network activity.

        Parameters
        ----------
        source : tuple
            NEST gid of the ``spike_detector``(s) which recorded the network.
        network : :class:`~nngt.SpatialNetwork`
            Network embedded in space to plot the activity of the neurons in
            space.
        resolution : double, optional (default: 1)
            Time resolution of the animation, i.e. the time step between two
            frames (same time unit as the recorded spike times).
        start : double, optional (default: 0.)
            Time at which the animation starts.
        timewindow : double, optional (default: None)
            Time window which will be shown for the spikes and self.second.
            Defaults to, and is capped at, the duration of the animation.
        trace : double, optional (default: 5.)
            Interval of time (ms) over which the data is overlayed in red.
        show_spikes : bool, optional (default: False)
            Whether a spike trajectory should be displayed on the network.
        sort_neurons : str or list, optional (default: None)
            Sort neurons using a topological property ("in-degree",
            "out-degree", "total-degree" or "betweenness"), an activity-related
            property ("firing_rate", 'B2') or a user-defined list of sorted
            neuron ids. Sorting is performed by increasing value of the
            `sort_neurons` property from bottom to top inside each group.
        decimate_connections : bool, optional (default: False)
            Currently unused.
        interval : int, optional (default: 50)
            Delay between frames (ms), passed to
            :class:`matplotlib.animation.FuncAnimation`.
        repeat : bool, optional (default: True)
            Whether the animation restarts once all frames have been shown
            (passed to :class:`matplotlib.animation.FuncAnimation`).
        active_size : double, optional (default: None)
            Marker size of the active (red) neurons; defaults to the size of
            the inactive neurons plus 2.
        **kwargs : dict, optional (default: {})
            Optional arguments such as 'make_rate' (default: True) or
            'chunksize' (sets matplotlib's global 'agg.path.chunksize',
            default: 10000), or all arguments for the
            :func:`nngt.plot.draw_network`.
        '''
        import matplotlib.pyplot as plt
        import nest
        from nngt.simulation.nest_activity import _get_data

        # keep only a weak reference so the animation does not keep the
        # network alive on its own
        self.network = weakref.ref(network)

        # end of the recording: use the max rather than the last row so the
        # value is right even if the records are not sorted by time
        self.simtime = np.max(_get_data(source)[:, 1])
        self.times = np.arange(start, self.simtime + resolution, resolution)

        self.num_frames = len(self.times)
        self.start = start
        self.duration = self.simtime - start
        self.trace = trace
        self.show_spikes = show_spikes
        if timewindow is None:
            self.timewindow = self.duration
        else:
            self.timewindow = min(timewindow, self.duration)

        # init _SpikeAnimator parent class (create figure and right axes)
        self.kwargs = kwargs
        cs = kwargs.get('chunksize', 10000)
        # NOTE: this is a global matplotlib setting, not a per-figure one
        mpl.rcParams['agg.path.chunksize'] = cs
        if 'make_rate' not in kwargs:
            kwargs['make_rate'] = True
        super(AnimationNetwork, self).__init__(source,
                                               sort_neurons=sort_neurons,
                                               network=network,
                                               **kwargs)

        # left half of the figure is used for the network itself
        self.env = plt.subplot2grid((2, 4), (0, 0), rowspan=2, colspan=2)

        # Data and axis for network representation: scale the neuron marker
        # with the available area (in px) per neuron
        bbox = self.env.get_window_extent().transformed(
            self.fig.dpi_scale_trans.inverted())
        area_px = bbox.width * bbox.height * self.fig.dpi**2
        n_size = max(2,
                     0.5 * np.sqrt(area_px / self.num_neurons))  # neuron size
        if active_size is None:
            active_size = n_size + 2
        pos = network.get_positions()  # positions of the neurons
        self.x = pos[:, 0]
        self.y = pos[:, 1]

        # neurons
        self.line_neurons = Line2D([], [],
                                   ls='None',
                                   marker='o',
                                   color='black',
                                   ms=n_size,
                                   mew=0)
        self.line_neurons_a = Line2D([], [],
                                     ls='None',
                                     marker='o',
                                     color='red',
                                     ms=active_size,
                                     mew=0)
        self.lines_env = [self.line_neurons, self.line_neurons_a]
        xlim = (_min_axis(self.x.min()), _max_axis(self.x.max()))
        self.set_axis(self.env,
                      xlabel='Network',
                      ylabel='',
                      lines=self.lines_env,
                      xdata=self.x,
                      ydata=self.y,
                      xlim=xlim)
        # spike trajectory
        if show_spikes:
            self.line_st_a = Line2D([], [], color='red', linewidth=1)
            self.line_st_e = Line2D([], [],
                                    color='red',
                                    marker='d',
                                    ms=2,
                                    markeredgecolor='r')
            self.lines_env.extend((self.line_st_a, self.line_st_e))
        # remove the axes and grid from env
        self.env.set_xticks([])
        self.env.set_yticks([])
        self.env.set_xticklabels([])
        self.env.set_yticklabels([])
        # explicit False: ``grid(None)`` would *toggle* the grid instead
        self.env.grid(False)

        plt.tight_layout()

        anim.FuncAnimation.__init__(self,
                                    self.fig,
                                    self._draw,
                                    self._gen_data,
                                    repeat=repeat,
                                    interval=interval,
                                    blit=True)
コード例 #2
0
ファイル: animations.py プロジェクト: MalloryDazza/NNGT
    def __init__(self,
                 source,
                 sort_neurons=None,
                 network=None,
                 grid=(2, 4),
                 pos_raster=(0, 2),
                 span_raster=(1, 2),
                 pos_rate=(1, 2),
                 span_rate=(1, 2),
                 make_rate=True,
                 **kwargs):
        '''
        Generate a SubplotAnimation instance to plot a network activity.

        Parameters
        ----------
        source : NEST gid tuple or str
            NEST gid of the `spike_detector`(s) which recorded the network or
            path to a file containing the recorded spikes.
        sort_neurons : str or list, optional (default: None)
            Property used to sort the neurons on the raster plot. Requires
            `network`; a warning is issued and no sorting is done otherwise.
        network : network object, optional (default: None)
            Network whose activity is plotted. Without it, the number of
            neurons is inferred from the range of the senders' ids.
        grid : tuple of ints, optional (default: (2, 4))
            Shape of the subplot grid.
        pos_raster, span_raster : tuples, optional
            Position and (rowspan, colspan) of the raster plot in `grid`.
        pos_rate, span_rate : tuples, optional
            Position and (rowspan, colspan) of the rate plot in `grid`.
        make_rate : bool, optional (default: True)
            Whether to compute and plot the total firing rate.
        **kwargs : dict
            "figsize" (default: (8, 6)) and "dpi" (default: 75) are used to
            create the figure.

        Raises
        ------
        RuntimeError
            If no spike was recorded at or after ``self.times[0]``.

        Note
        ----
        Calling class is supposed to have defined `self.times`, `self.start`,
        `self.duration`, `self.trace`, and `self.timewindow`.
        '''
        import matplotlib.pyplot as plt
        import nest
        from nngt.simulation.nest_activity import _get_data

        # organization
        self.grid = grid
        self.has_rate = make_rate

        # get data (column 0: senders, column 1: spike times)
        data_s = _get_data(source)
        # boolean mask instead of slicing from the first matching index: it
        # does not assume time-sorted records, and it cannot be fooled like
        # ``np.any`` on an index array (which is False when the index is 0)
        in_window = data_s[:, 1] >= self.times[0]

        if not np.any(in_window):
            raise RuntimeError("No spikes between {} and {}.".format(
                self.start, self.times[-1]))

        self.spikes = data_s[in_window, 1]
        self.senders = data_s[in_window, 0].astype(int)

        # sorting
        if sort_neurons is not None:
            if network is not None:
                sorted_neurons = _sort_neurons(sort_neurons,
                                               self.senders,
                                               network,
                                               data=data_s)
                self.senders = sorted_neurons[self.senders]
            else:
                warnings.warn("Could not sort neurons because no "
                              "`network` was provided.")

        # y-range taken *after* sorting so that it matches the plotted ids
        self._ymax = np.max(self.senders)
        self._ymin = np.min(self.senders)

        if network is None:
            # ids span [ymin, ymax] inclusively
            self.num_neurons = int(self._ymax - self._ymin) + 1
        else:
            self.num_neurons = network.node_nb()

        # NOTE: replaces any absolute end time stored by the calling class
        # with the time span covered by `self.times`
        self.simtime = self.times[-1] - self.times[0]

        # generate the spike-rate
        if make_rate:
            self.firing_rate, _ = total_firing_rate(network,
                                                    data=data_s,
                                                    resolution=self.times)

        # figure/canvas: pause/resume and step by step interactions
        self.fig = plt.figure(figsize=kwargs.get("figsize", (8, 6)),
                              dpi=kwargs.get("dpi", 75))
        self.pause = False
        self.pause_after = False
        self.event = None
        self.increment = 1
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.fig.canvas.mpl_connect('key_press_event', self.on_keyboard_press)
        self.fig.canvas.mpl_connect('key_release_event',
                                    self.on_keyboard_release)

        # Axes for spikes and spike-rate/other representations
        self.spks = plt.subplot2grid(grid,
                                     pos_raster,
                                     rowspan=span_raster[0],
                                     colspan=span_raster[1])
        self.second = plt.subplot2grid(grid,
                                       pos_rate,
                                       rowspan=span_rate[0],
                                       colspan=span_rate[1],
                                       sharex=self.spks)

        # lines
        self.line_spks_ = Line2D([], [],
                                 ls='None',
                                 marker='o',
                                 color='black',
                                 ms=2,
                                 mew=0)
        self.line_spks_a = Line2D([], [],
                                  ls='None',
                                  marker='o',
                                  color='red',
                                  ms=2,
                                  mew=0)
        self.line_second_ = Line2D([], [], color='black')
        self.line_second_a = Line2D([], [], color='red', linewidth=2)
        self.line_second_e = Line2D([], [],
                                    color='red',
                                    marker='o',
                                    markeredgecolor='r')

        # Spikes raster plot
        kw_args = {}
        if self.timewindow != self.duration:
            # `self.simtime` is a span at this point, not an end time, so the
            # upper bound comes from the last animation time instead
            kw_args['xlim'] = (self.start,
                               min(self.times[-1],
                                   self.timewindow + self.start))
        ylim = (self._ymin, self._ymax)
        self.lines_raster = [self.line_spks_, self.line_spks_a]
        self.set_axis(self.spks,
                      xlabel='Time (ms)',
                      ylabel='Neuron',
                      lines=self.lines_raster,
                      ylim=ylim,
                      set_xticks=True,
                      **kw_args)
        self.lines_second = [
            self.line_second_, self.line_second_a, self.line_second_e
        ]

        # Rate plot
        if make_rate:
            self.set_axis(self.second,
                          xlabel='Time (ms)',
                          ylabel='Rate (Hz)',
                          lines=self.lines_second,
                          ydata=self.firing_rate,
                          **kw_args)