Beispiel #1
0
    def _initialize(self):
        """Build the figure, draw the initial frame of every plot, set up output.

        ``self._config`` maps each LPU name to a list of per-plot config
        dicts.  One subplot is created per config on a ``self._rows`` x
        ``self._cols`` grid; if that grid does not have exactly one cell per
        plot, a near-square grid is computed instead (and stored back on
        ``self``).  Every plot is drawn from column 0 of ``self._data[LPU]``.
        Finally, if ``self.out_filename`` is set, a movie writer is created in
        ``self.writer`` and the first frame is grabbed; otherwise the figure
        is shown.

        Each config dict is modified in place:

        * ``'type'`` is replaced by an integer code (0 quiver, 1 hsv,
          2 image, 3 waveform, 4 raster, 5 rate, 6 dome).  When no type is
          given, input LPUs whose first neuron is non-spiking default to
          image (2); everything else defaults to raster (4).
        * ``'handle'`` is set to the created artist (quiver, image, line) or,
          for raster/dome plots, to the axes.  No handle is created for
          type 5 ('rate').
        * Depending on the type, ``'shape'``, ``'trans'``, ``'ydata'``,
          ``'norm'`` and ``'positions'`` are filled in as well.

        Raises:
            AssertionError: if a quiver/hsv/image config does not provide
                exactly 2/2/1 neuron-ID arrays in ``'ids'``.
            ValueError: if a config names an unsupported plot type.
            RuntimeError: if the required movie encoder (ffmpeg and/or
                avconv) cannot be found.
        """
        # One subplot per plot config, across all LPUs.
        num_plots = sum(len(configs) for configs in self._config.itervalues())

        # Fall back to a near-square grid when the configured grid does not
        # have exactly one cell per plot.
        if self._rows * self._cols != num_plots:
            self._cols = int(np.ceil(np.sqrt(num_plots)))
            self._rows = int(np.ceil(num_plots / float(self._cols)))
        self.f, self.axarr = plt.subplots(self._rows,
                                          self._cols,
                                          figsize=self._figsize)

        # Remove the grid cells left over after the last plot.
        for i in xrange(num_plots, self._rows * self._cols):
            plt.delaxes(self.axarr[np.unravel_index(i,
                                                    (self._rows, self._cols))])
        cnt = 0
        self.handles = []
        self.types = []
        # Config keys consumed by this method; any other key is forwarded to
        # the axes/handle via self._set_wrapper further down.
        keywds = ['handle', 'ydata', 'fmt', 'type', 'ids', 'shape', 'norm']
        # Integer code stored in config['type'] for each plot-type name.
        plot_codes = {'quiver': 0, 'hsv': 1, 'image': 2, 'waveform': 3,
                      'raster': 4, 'rate': 5, 'dome': 6}
        # Plot types that require an exact number of neuron-ID arrays.
        required_id_arrays = {'quiver': 2, 'hsv': 2, 'image': 1}

        # Precomputed 60 x 60 grid over a hemisphere (unit radius) used as the
        # rendering surface for 'dome' plots.
        # TODO: Irregular grid in U will make the plot better
        U, V = np.mgrid[0:np.pi / 2:complex(0, 60), 0:2 * np.pi:complex(0, 60)]
        X = np.cos(V) * np.sin(U)
        Y = np.sin(V) * np.sin(U)
        Z = np.cos(U)
        self._dome_pos_flat = (X.flatten(), Y.flatten(), Z.flatten())
        self._dome_pos = (X, Y, Z)
        self._dome_arr_shape = X.shape

        # plt.subplots returns a bare Axes (not an array) for a 1x1 grid.
        if not isinstance(self.axarr, np.ndarray):
            self.axarr = np.asarray([self.axarr])

        for LPU, configs in self._config.iteritems():
            for config in configs:
                ind = np.unravel_index(cnt, self.axarr.shape)
                cnt += 1

                # Resolve the plot type to its integer code.
                # NOTE(review): config['type'] is overwritten in place, so a
                # second call on the same config hits the ValueError below.
                if 'type' in config:
                    plot_type = config['type']
                    if plot_type not in plot_codes:
                        raise ValueError('Plot type not supported')
                    if plot_type in required_id_arrays:
                        assert (len(config['ids']) ==
                                required_id_arrays[plot_type])
                    config['type'] = plot_codes[plot_type]
                else:
                    # Kept as one short-circuiting expression: the node
                    # lookup is only valid (and only done) for input LPUs.
                    if (str(LPU).startswith('input') and not self._graph[LPU]
                            .node[str(config['ids'][0][0])]['spiking']):
                        config['type'] = 2
                    else:
                        config['type'] = 4

                # Grid-shaped plots (quiver/hsv/image) need a 2-D shape.
                if config['type'] < 3 and 'shape' not in config:
                    # XXX This can cause problems when the number
                    # of neurons is not equal to
                    # np.prod(config['shape'])
                    num_neurons = len(config['ids'][0])
                    n_rows = int(np.ceil(np.sqrt(num_neurons)))
                    config['shape'] = [
                        n_rows,
                        int(np.ceil(num_neurons / float(n_rows))),
                    ]

                if config['type'] == 0:
                    # quiver: ids[0] / ids[1] give the two vector components.
                    config['handle'] = self.axarr[ind].quiver(
                        np.reshape(self._data[LPU][config['ids'][0], 0],
                                   config['shape']),
                        np.reshape(self._data[LPU][config['ids'][1], 0],
                                   config['shape']))
                elif config['type'] == 1:
                    # hsv: hue from the vector angle, value from its
                    # magnitude, full saturation.
                    X = np.reshape(self._data[LPU][config['ids'][0], 0],
                                   config['shape'])
                    Y = np.reshape(self._data[LPU][config['ids'][1], 0],
                                   config['shape'])
                    V = (X**2 + Y**2)**0.5
                    H = (np.arctan2(X, Y) + np.pi) / (2 * np.pi)
                    S = np.ones_like(V)
                    HSV = np.dstack((H, S, V))
                    RGB = hsv_to_rgb(HSV)
                    config['handle'] = self.axarr[ind].imshow(RGB)
                elif config['type'] == 2:
                    # image: a missing 'trans' is recorded as False; only the
                    # literal True enables transposition.
                    config.setdefault('trans', False)
                    image = np.reshape(self._data[LPU][config['ids'][0], 0],
                                       config['shape'])
                    if config['trans'] is True:
                        image = np.transpose(image)
                    temp = self.axarr[ind].imshow(image)
                    temp.set_clim(self._imlim)
                    temp.set_cmap(plt.cm.gist_gray)
                    config['handle'] = temp
                elif config['type'] == 3:
                    # waveform: a single neuron keeps its own y-history in
                    # config['ydata']; several neurons are drawn as one line
                    # over the neuron index.
                    fmt = config['fmt'] if 'fmt' in config else ''
                    self.axarr[ind].set_xlim(self._xlim)
                    self.axarr[ind].set_ylim(self._ylim)
                    if len(config['ids'][0]) == 1:
                        first_value = self._data[LPU][config['ids'][0][0], 0]
                        config['handle'] = self.axarr[ind].plot(
                            [0], [first_value], fmt)[0]
                        config['ydata'] = [first_value]
                    else:
                        config['handle'] = self.axarr[ind].plot(
                            self._data[LPU][config['ids'][0], 0])[0]
                elif config['type'] == 4:
                    # raster: the handle is the axes itself; the x-range spans
                    # the full recording (number of samples * dt).
                    config['handle'] = self.axarr[ind]
                    config['handle'].vlines(0, 0, 0.01)
                    config['handle'].set_ylim([.5, len(config['ids'][0]) + .5])
                    config['handle'].set_ylabel('Neurons',
                                                fontsize=self._fontsize - 1,
                                                weight='bold')
                    config['handle'].set_xlabel('Time (s)',
                                                fontsize=self._fontsize - 1,
                                                weight='bold')
                    min_id = min(self._id_to_data_idx[LPU].keys())
                    min_idx = self._id_to_data_idx[LPU][min_id]
                    config['handle'].set_xlim(
                        [0, len(self._data[LPU][min_idx, :]) * self._dt])
                    config['handle'].axes.set_yticks([])
                    config['handle'].axes.set_xticks([])
                elif config['type'] == 6:
                    # dome: a 3-D axes is added at the same grid position
                    # (add_subplot is 1-based, hence the already-incremented
                    # cnt); the original 2-D axes only has its ticks removed.
                    self.axarr[ind].axes.set_yticks([])
                    self.axarr[ind].axes.set_xticks([])
                    self.axarr[ind] = self.f.add_subplot(self._rows,
                                                         self._cols,
                                                         cnt,
                                                         projection='3d')
                    config['handle'] = self.axarr[ind]
                    config['handle'].axes.set_yticks([])
                    config['handle'].axes.set_xticks([])
                    config['handle'].xaxis.set_ticks([])
                    config['handle'].yaxis.set_ticks([])
                    config['handle'].zaxis.set_ticks([])
                    if 'norm' not in config:
                        config['norm'] = Normalize(vmin=-70, vmax=0, clip=True)
                    elif config['norm'] == 'auto':
                        # Derive the colour range from the data, skipping the
                        # first 100 samples when there are more than 100.
                        first = 100 if self._data[LPU].shape[1] > 100 else 0
                        values = self._data[LPU][config['ids'][0], first:]
                        config['norm'] = Normalize(vmin=np.min(values),
                                                   vmax=np.max(values),
                                                   clip=True)

                    node_dict = self._graph[LPU].node
                    if str(LPU).startswith('input'):
                        # Input LPUs: every node flagged 'extern'; assumes
                        # node ids are the strings '0'..'N-1' — TODO confirm.
                        node_ids = [str(nid) for nid in range(len(node_dict))
                                    if node_dict[str(nid)]['extern']]
                    else:
                        node_ids = [str(nid) for nid in config['ids'][0]]
                    latpositions = np.asarray(
                        [node_dict[nid]['lat'] for nid in node_ids])
                    longpositions = np.asarray(
                        [node_dict[nid]['long'] for nid in node_ids])

                    # Spherical (lat, long) -> cartesian on the unit sphere.
                    xx = np.cos(longpositions) * np.sin(latpositions)
                    yy = np.sin(longpositions) * np.sin(latpositions)
                    zz = np.cos(latpositions)
                    config['positions'] = (xx, yy, zz)

                    # Nearest-neighbour resample onto the dome grid, then
                    # greyscale RGBA with alpha forced to 1.
                    colors = griddata(config['positions'],
                                      self._data[LPU][config['ids'][0], 0],
                                      self._dome_pos_flat,
                                      'nearest').reshape(self._dome_arr_shape)
                    colors = config['norm'](colors).data
                    colors = np.tile(
                        np.reshape(colors, [
                            self._dome_arr_shape[0], self._dome_arr_shape[1], 1
                        ]), [1, 1, 4])
                    colors[:, :, 3] = 1.0
                    config['handle'].plot_surface(self._dome_pos[0],
                                                  self._dome_pos[1],
                                                  self._dome_pos[2],
                                                  rstride=1,
                                                  cstride=1,
                                                  facecolors=colors,
                                                  antialiased=False,
                                                  shade=False)
                # NOTE(review): type 5 ('rate') has no drawing branch here and
                # gets no 'handle'.

                # Forward every non-reserved config entry to both the axes
                # and the handle.  A key need not apply to both (and 'rate'
                # plots have no handle), so failures are deliberately
                # ignored — but without swallowing KeyboardInterrupt or
                # SystemExit.
                for key in config:
                    if key in keywds:
                        continue
                    try:
                        self._set_wrapper(self.axarr[ind], key, config[key])
                    except Exception:
                        pass
                    try:
                        self._set_wrapper(config['handle'], key, config[key])
                    except Exception:
                        pass

                if config['type'] < 3:
                    config['handle'].axes.set_xticks([])
                    config['handle'].axes.set_yticks([])

        # Figure-wide title, set once (it does not depend on the LPU).
        # NOTE(review): the guard tests self.suptitle but the text comes from
        # self._title — confirm both attributes are intended.
        if self.suptitle is not None:
            self.f.suptitle(self._title,
                            fontsize=self._fontsize + 1,
                            x=0.5,
                            y=0.03,
                            weight='bold')

        plt.tight_layout()

        if self.out_filename:
            # self.FFMpeg: None -> prefer ffmpeg, fall back to avconv;
            # truthy -> ffmpeg only; falsy (not None) -> avconv only.
            if self.FFMpeg is None:
                if which(matplotlib.rcParams['animation.ffmpeg_path']):
                    self.writer = FFMpegFileWriter(fps=self.fps,
                                                   codec=self.codec)
                elif which(matplotlib.rcParams['animation.avconv_path']):
                    self.writer = AVConvFileWriter(fps=self.fps,
                                                   codec=self.codec)
                else:
                    raise RuntimeError('cannot find ffmpeg or avconv')
            elif self.FFMpeg:
                if which(matplotlib.rcParams['animation.ffmpeg_path']):
                    self.writer = FFMpegFileWriter(fps=self.fps,
                                                   codec=self.codec)
                else:
                    raise RuntimeError('cannot find ffmpeg')
            else:
                if which(matplotlib.rcParams['animation.avconv_path']):
                    self.writer = AVConvFileWriter(fps=self.fps,
                                                   codec=self.codec)
                else:
                    raise RuntimeError('cannot find avconv')

            # Use the output file to determine the name of the temporary frame
            # files so that two concurrently run visualizations don't clobber
            # each other's frames:
            self.writer.setup(
                self.f,
                self.out_filename,
                dpi=80,
                frame_prefix=os.path.splitext(self.out_filename)[0] + '_')
            self.writer.frame_format = 'png'
            self.writer.grab_frame()
        else:
            self.f.show()
Beispiel #2
0
    def generate_video(self, data, coordinates, rng, videofile, step=100,
                       fps=5):
        """Render ``data`` on a cylindrical screen and encode it as a video.

        Every ``step``-th frame of ``data`` is resampled (nearest neighbour)
        from its sample positions onto a 60 x 60 grid covering a cylinder of
        radius ``self._radius``. Each resampled frame is drawn as a greyscale
        3-D surface and appended to ``videofile`` via matplotlib's
        ``AVConvFileWriter`` (mpeg4 codec).

        Args:
            data: sequence of frames; ``data[i].flatten()`` must line up
                element-wise with the flattened coordinate arrays.
            coordinates: ``(zpositions, thetapositions)`` arrays giving the
                height and the angle (radians) of every sample.
            rng: ``(vmin, vmax)`` used to normalise, and clip, the values.
            videofile: output path. Temporary frame files are written next
                to it, named with its stem as prefix, so concurrent runs
                don't clash.
            step: render only every ``step``-th frame of ``data``
                (default 100).
            fps: frame rate of the produced video (default 5).
        """
        from scipy.interpolate import griddata
        # Needed (in older matplotlib) to register the '3d' projection.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        from matplotlib.animation import AVConvFileWriter
        from matplotlib.colors import Normalize

        radius = self._radius

        # Sample positions: cylindrical (z, theta) -> cartesian.
        zpositions, thetapositions = coordinates
        x = radius * np.cos(thetapositions).flatten()
        y = radius * np.sin(thetapositions).flatten()
        z = zpositions.flatten()

        # 60 x 60 screen grid spanning the sampled height range and the full
        # circle, also converted to cartesian coordinates.
        Z, Theta = np.mgrid[z.min():z.max():complex(0, 60),
                            0:2 * np.pi:complex(0, 60)]
        X = radius * np.cos(Theta)
        Y = radius * np.sin(Theta)
        grid_points = (X.flatten(), Y.flatten(), Z.flatten())

        fig = plt.figure(figsize=plt.figaspect(0.8), dpi=80)
        writer = AVConvFileWriter(fps=fps, codec='mpeg4')
        writer.setup(fig,
                     videofile,
                     dpi=80,
                     frame_prefix=os.path.splitext(videofile)[0] + '_')
        writer.frame_format = 'png'

        ax = fig.add_subplot(111, projection='3d')
        norm = Normalize(vmin=rng[0], vmax=rng[1], clip=True)

        for i in range(0, len(data), step):
            colors = griddata((x, y, z), data[i].flatten(), grid_points,
                              'nearest').reshape(X.shape)
            # Normalise, then replicate the grey level into all four RGBA
            # channels (equal R, G, B = greyscale) and force alpha to 1.
            colors = norm(colors).data
            colors = np.tile(colors[:, :, np.newaxis], [1, 1, 4])
            colors[:, :, 3] = 1.0

            # ax.clear() also resets the title and tick settings, so they
            # must be re-applied for every frame.
            ax.clear()
            ax.set_title('Input')
            ax.xaxis.set_ticks([])
            ax.yaxis.set_ticks([])
            ax.zaxis.set_ticks([])
            ax.plot_surface(X,
                            Y,
                            Z,
                            rstride=1,
                            cstride=1,
                            facecolors=colors,
                            antialiased=False,
                            shade=False)
            fig.canvas.draw()
            writer.grab_frame()
        writer.finish()