示例#1
0
 def _set_axis_labels(self,
                      fig,
                      sub,
                      n_regions,
                      region_labels,
                      indices2emphasize,
                      color='k',
                      position='left'):
     """Add a frameless subplot to ``fig`` that only carries region tick labels.

     One y tick is created per region (``0 .. n_regions - 1``) and labelled with
     the output of ``generate_region_labels``. All labels start out black; when
     ``color`` differs from ``'k'`` the labels at ``indices2emphasize`` are
     recoloured with it. The y axis is inverted and the x axis is hidden.

     :param fig: figure exposing ``add_subplot``
     :param sub: subplot specification forwarded to ``fig.add_subplot``
     :param n_regions: number of regions, i.e. of y ticks
     :param region_labels: raw labels handed to ``generate_region_labels``
     :param indices2emphasize: tick indices whose labels get ``color``
     :param color: emphasis colour; ``'k'`` disables emphasis
     :param position: ``'right'`` moves ticks and label position to the right
     """
     tick_positions = range(n_regions)
     tick_texts = generate_region_labels(len(tick_positions), region_labels,
                                         ". ",
                                         self.print_regions_indices,
                                         tick_positions)

     label_axes = fig.add_subplot(sub, frameon=False)
     y_axis = label_axes.yaxis
     if position == 'right':
         y_axis.tick_right()
         y_axis.set_label_position("right")

     label_axes.set_yticks(tick_positions)
     label_axes.set_yticklabels(tick_texts, color='k')

     if not color == 'k':
         # Recolour only the emphasized entries, then hand the label objects back.
         tick_label_objs = y_axis.get_ticklabels()
         for emphasized in indices2emphasize:
             tick_label_objs[emphasized].set_color(color)
         y_axis.set_ticklabels(tick_label_objs)

     label_axes.invert_yaxis()
     label_axes.axes.get_xaxis().set_visible(False)
示例#2
0
 def axYticks(labels, nTS, offsets=offset):
     """Place one y tick per time series on the current axes and label them.

     Ticks are set at ``offsets * [0, 1, ..., nTS - 1]``. ``labels`` is expected
     to support ``flatten().tolist()``; if converting or applying it fails, a
     warning is logged and labels generated by ``generate_region_labels`` are
     applied instead.

     :param labels: tick labels, one per time series
     :param nTS: number of time series (ticks)
     :param offsets: vertical spacing between consecutive ticks; defaults to
                     the enclosing scope's ``offset``
     """
     axis = pyplot.gca()
     # Honour the ``offsets`` argument (the default is still the outer ``offset``).
     axis.set_yticks((offsets * numpy.arange(nTS)).tolist())
     try:
         axis.set_yticklabels(labels.flatten().tolist())
     except Exception:
         self.logger.warning("Cannot convert region labels' strings for y axis ticks!")
         # Actually apply the fallback labels instead of discarding them.
         labels = generate_region_labels(nTS, [], "", True)
         axis.set_yticklabels(labels)
示例#3
0
 def plot_vector(self,
                 vector,
                 labels,
                 subplot,
                 title,
                 show_y_labels=True,
                 indices_red=None,
                 sharey=None):
     """Draw ``vector`` as a horizontal bar plot with one bar per label.

     The number of bars is ``labels.shape[0]``. When ``vector`` has more than
     one dimension only its first row is plotted. Bars (and, if shown, tick
     labels) at ``indices_red`` are drawn in red, everything else in black.
     The y axis is inverted only when no ``sharey`` axes is given.

     :param vector: values to plot (object exposing ``shape``)
     :param labels: labels, one per bar (object exposing ``shape``)
     :param subplot: subplot specification forwarded to ``pyplot.subplot``
     :param title: subplot title
     :param show_y_labels: if False, y tick labels are blanked
     :param indices_red: optional indices to highlight in red
     :param sharey: optional axes to share the y axis with
     :return: the axes the bars were drawn on
     """
     ax = pyplot.subplot(subplot, sharey=sharey)
     pyplot.title(title)

     n_bars = labels.shape[0]
     positions = numpy.arange(n_bars, dtype=numpy.int32)

     highlight = indices_red is not None
     bar_colors = numpy.full(n_bars, 'k')
     if highlight:
         bar_colors[indices_red] = 'r'

     bar_values = vector if len(vector.shape) == 1 else vector[0, :]
     ax.barh(positions, bar_values, color=bar_colors, align='center')

     ax.grid(True, color='grey')
     ax.set_yticks(positions)
     if not show_y_labels:
         ax.set_yticklabels([])
     else:
         ax.set_yticklabels(
             generate_region_labels(n_bars, labels, ". ",
                                    self.print_regions_indices))
         if highlight:
             tick_label_objs = ax.yaxis.get_ticklabels()
             for red_idx in indices_red:
                 tick_label_objs[red_idx].set_color('r')
             ax.yaxis.set_ticklabels(tick_label_objs)

     ax.autoscale(tight=True)
     if sharey is None:
         ax.invert_yaxis()
     return ax
示例#4
0
 def _plot_matrix(self,
                  matrix,
                  xlabels,
                  ylabels,
                  subplot=111,
                  title="",
                  show_x_labels=True,
                  show_y_labels=True,
                  x_ticks=numpy.array([]),
                  y_ticks=numpy.array([]),
                  indices_red_x=None,
                  indices_red_y=None,
                  sharex=None,
                  sharey=None,
                  cmap='autumn_r',
                  vmin=None,
                  vmax=None):
     """Show (a selection of) a 2D ``matrix`` as an image with a colorbar.

     ``x_ticks`` / ``y_ticks`` select the rows / columns of ``matrix`` that are
     displayed; an empty selection means "all". The selected sub-matrix is
     transposed before being passed to ``imshow``. Tick labels are either the
     output of ``generate_region_labels`` (when shown) or just ``"<index>."``.
     Labels whose index appears in ``indices_red_x`` / ``indices_red_y`` are
     coloured red; indices that are not among the displayed ticks are skipped.

     :param matrix: 2D array to display
     :param xlabels: labels along the first matrix dimension
     :param ylabels: labels along the second matrix dimension
     :param subplot: subplot specification forwarded to ``pyplot.subplot``
     :param title: subplot title
     :param show_x_labels: whether to show full labels on the x axis
     :param show_y_labels: whether to show full labels on the y axis
     :param x_ticks: indices of the first dimension to display (empty = all)
     :param y_ticks: indices of the second dimension to display (empty = all)
     :param indices_red_x: optional indices whose x labels are coloured red
     :param indices_red_y: optional indices whose y labels are coloured red
     :param sharex: optional axes to share the x axis with
     :param sharey: optional axes to share the y axis with
     :param cmap: colormap for the image
     :param vmin: lower colour limit forwarded to ``imshow``
     :param vmax: upper colour limit forwarded to ``imshow``
     :return: tuple ``(ax, cax1)`` of the image axes and the colorbar axes
     """
     ax = pyplot.subplot(subplot, sharex=sharex, sharey=sharey)
     pyplot.title(title)
     nx, ny = matrix.shape

     # Normalise tick selections to arrays so that list input works too
     # (``.tolist()`` is needed below); an empty selection means all indices.
     ticks = []
     for n, tick in zip([nx, ny], [x_ticks, y_ticks]):
         tick = numpy.asarray(tick)
         if len(tick) == 0:
             tick = numpy.array(range(n), dtype=numpy.int32)
         ticks.append(tick)
     nticks = [len(tick) for tick in ticks]

     # pyplot.set_cmap returns None; it is kept only for its effect on the
     # default/current colormap. The colormap is passed to imshow explicitly.
     pyplot.set_cmap(cmap)
     img = pyplot.imshow(matrix[ticks[0]][:, ticks[1]].T,
                         cmap=cmap,
                         vmin=vmin,
                         vmax=vmax,
                         interpolation='none')
     pyplot.grid(True, color='black')

     per_axis = zip(["x", "y"], ticks, nticks,
                    [indices_red_x, indices_red_y],
                    [show_x_labels, show_y_labels],
                    [xlabels, ylabels], [90, 0])
     for xy, tick, ntick, ind_red, show, lbls, rot in per_axis:
         set_ticks = getattr(pyplot, xy + "ticks")  # pyplot.xticks / pyplot.yticks
         positions = numpy.array(range(ntick))
         if show:
             tick_labels = generate_region_labels(len(tick),
                                                  numpy.array(lbls)[tick],
                                                  ". ",
                                                  self.print_regions_indices,
                                                  tick)
             set_ticks(positions, tick_labels, rotation=rot)
         else:
             tick_labels = numpy.array(["%d." % l for l in tick])
             set_ticks(positions, tick_labels)

         if ind_red is not None:
             shown = tick.tolist()
             axis = getattr(ax, xy + "axis")
             ticklabels = axis.get_ticklabels()
             for indr in ind_red:
                 try:
                     ticklabels[shown.index(indr)].set_color('r')
                 except (ValueError, IndexError):
                     # The index to highlight is not among the displayed ticks.
                     continue
             axis.set_ticklabels(ticklabels)

     ax.autoscale(tight=True)
     divider = make_axes_locatable(ax)
     cax1 = divider.append_axes("right", size="5%", pad=0.05)
     pyplot.colorbar(img, cax=cax1)
     return ax, cax1
示例#5
0
    def plot_vector_violin(self,
                           dataset,
                           vector=[],
                           lines=[],
                           labels=[],
                           subplot=111,
                           title="",
                           violin_flag=True,
                           colormap="YlOrRd",
                           show_y_labels=True,
                           indices_red=None,
                           sharey=None):
        """Plot the columns of ``dataset`` as horizontal violins or as markers.

        ``dataset`` is indexed as ``[sample, violin]``. With ``violin_flag`` a
        violin is drawn per column and coloured by the column mean through
        ``colormap``; otherwise every sample is drawn as a diamond marker using
        the rcParams colour cycle. Optionally a point per violin (``vector``)
        and a dashed curve per violin (``lines``) are overlaid; both are only
        drawn when their length matches the number of violins.

        :param dataset: 2D array, ``dataset.shape[1]`` is the number of violins
        :param vector: optional values, one per violin, drawn as ``*`` markers
        :param lines: optional pair ``(x, y)`` of per-violin sequences; ``y`` is
                      scaled by its maximum and drawn below the violin position
        :param labels: raw labels handed to ``generate_region_labels``
        :param subplot: subplot specification forwarded to ``pyplot.subplot``
        :param title: subplot title
        :param violin_flag: draw violins (True) or individual samples (False)
        :param colormap: colormap used to colour the violins
        :param show_y_labels: if False, y tick labels are blanked
        :param indices_red: optional violin indices highlighted in red
        :param sharey: optional axes to share the y axis with
        :return: the axes that were drawn on
        """
        ax = pyplot.subplot(subplot, sharey=sharey)
        pyplot.title(title)
        n_violins = dataset.shape[1]
        y_ticks = numpy.array(range(n_violins), dtype=numpy.int32)
        if indices_red is None:
            indices_red = []
        # Tick labels only need recolouring if something is actually highlighted.
        has_red = numpy.size(indices_red) > 0

        if violin_flag:
            # pyplot.set_cmap returns None; keep the call for its effect on the
            # default colormap but pass the colormap to the mappable explicitly.
            pyplot.set_cmap(colormap)
            mappable = matplotlib.cm.ScalarMappable(cmap=colormap)
            body_colors = mappable.to_rgba(numpy.mean(dataset, axis=0),
                                           alpha=0.75)
            violin_parts = ax.violinplot(dataset,
                                         y_ticks,
                                         vert=False,
                                         widths=0.9,
                                         showmeans=True,
                                         showmedians=True,
                                         showextrema=True)
            violin_parts['cmeans'].set_color("k")
            for part in ('cmins', 'cmaxes', 'cbars', 'cmedians'):
                violin_parts[part].set_color("b")
            for body, rgba in zip(violin_parts['bodies'], body_colors):
                rgba = numpy.reshape(rgba, (1, 4))
                body.set_color(rgba)
                # NOTE(review): private attributes are written directly, as in
                # the original code — presumably to force the 0.75 alpha.
                body._alpha = 0.75
                body._edgecolors = rgba
                body._facecolors = rgba
        else:
            colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
            n_colors = len(colorcycle)
            n_samples = dataset.shape[0]
            for ii in range(n_violins):
                for jj in range(n_samples):
                    # Wrap around the colour cycle; indexing with
                    # ``jj % n_samples`` never wrapped and raised IndexError
                    # for more samples than colours.
                    sample_color = colorcycle[jj % n_colors]
                    ax.plot(dataset[jj, ii],
                            y_ticks[ii],
                            "D",
                            mfc=sample_color,
                            mec=sample_color,
                            ms=20)

        colors = numpy.repeat(['k'], n_violins)
        colors[indices_red] = 'r'

        if len(vector) == n_violins:
            for ii in range(n_violins):
                ax.plot(vector[ii],
                        y_ticks[ii],
                        '*',
                        mfc=colors[ii],
                        mec=colors[ii],
                        ms=10)
        if len(lines) == 2 and lines[0].shape[0] == n_violins and lines[
                1].shape[0] == n_violins:
            for ii in range(n_violins):
                yy = (y_ticks[ii] - 0.45*lines[1][ii]/numpy.max(lines[1][ii]))\
                     * numpy.ones(numpy.array(lines[0][ii]).shape)
                ax.plot(lines[0][ii], yy, '--', color=colors[ii])

        ax.grid(True, color='grey')
        ax.set_yticks(y_ticks)
        if show_y_labels:
            region_labels = generate_region_labels(n_violins, labels, ". ",
                                                   self.print_regions_indices)
            ax.set_yticklabels(region_labels)
            if has_red:
                tick_label_objs = ax.yaxis.get_ticklabels()
                for ids in indices_red:
                    tick_label_objs[ids].set_color('r')
                ax.yaxis.set_ticklabels(tick_label_objs)
        else:
            ax.set_yticklabels([])
        if sharey is None:
            ax.invert_yaxis()
        ax.autoscale()
        return ax
    def plot_ts(self,
                data,
                time=None,
                var_labels=[],
                mode="ts",
                subplots=None,
                special_idx=[],
                subtitles=[],
                labels=[],
                offset=0.5,
                time_unit="ms",
                title='Time series',
                figure_name=None,
                figsize=None):
        """Plot one or more variables as time series, raster or trajectories.

        ``data`` may be a dict (keys become ``var_labels``), a numpy array, or
        a list/tuple with one array per variable. ``mode`` selects the layout:
        ``"traj"`` uses ``self._trajectories_plot``, anything else uses
        ``self._ts_plot``; in ``"raster"`` mode every variable is additionally
        mean-centred and divided by its largest per-series range, and the
        series labels are shown as y tick labels.

        The raster tick font size is derived from ``self.tick_font_size``
        (reduced when there are more than 100 series) without modifying the
        instance attribute.

        :param data: dict, numpy.ndarray, list or tuple holding the time series
        :param time: time vector forwarded to ``self._ts_plot``
        :param var_labels: names of the variables (overridden for dict input)
        :param mode: ``"ts"``, ``"raster"`` or ``"traj"``
        :param subplots: subplot layout hint forwarded to the layout helpers
        :param special_idx: indices of series drawn with emphasis colours
        :param subtitles: subplot titles; defaults to ``var_labels``
        :param labels: series labels, either one set or one set per variable
        :param offset: spacing between series in raster mode
        :param time_unit: time unit forwarded to ``self._ts_plot``
        :param title: figure title / identifier
        :param figure_name: name forwarded to ``self._save_figure``
        :param figsize: figure size; falls back to the configured large size
        :return: tuple ``(figure, axes, lines)``
        """
        if not isinstance(figsize, (list, tuple)):
            figsize = self.config.figures.LARGE_SIZE
        if isinstance(data, dict):
            # Materialise the views: they are indexed and item-assigned below,
            # which dict views do not support on Python 3.
            var_labels = list(data.keys())
            data = list(data.values())
        elif isinstance(data, numpy.ndarray):
            if len(data.shape) < 3:
                if len(data.shape) < 2:
                    data = numpy.expand_dims(data, 1)
                data = numpy.expand_dims(data, 2)
                data = [data]
            else:
                # Assuming a structure of Time X Space X Variables X Samples
                data = [
                    data[:, :, iv].squeeze() for iv in range(data.shape[2])
                ]
        elif isinstance(data, (list, tuple)):
            data = ensure_list(data)
        else:
            raise_value_error("Input timeseries: %s \n"
                              "is not on of one of the following types: "
                              "[numpy.ndarray, dict, list, tuple]" % str(data))

        is_raster = isequal_string(mode, "raster")
        is_traj = isequal_string(mode, "traj")

        n_vars = len(data)
        data_lims = []
        for ivar, d in enumerate(data):
            if is_raster:
                centred = d - d.mean(axis=0)
                drange = numpy.max(centred.max(axis=0) - centred.min(axis=0))
                data[ivar] = centred / drange  # zscore(d, axis=None)
            # NOTE(review): limits come from the un-normalised series even in
            # raster mode, as in the original code — confirm this is intended.
            data_lims.append([d.min(), d.max()])

        data_shape = data[0].shape
        if len(data_shape) == 1:
            n_times = data_shape[0]
            nTS = 1
            for iV in range(n_vars):
                data[iV] = data[iV][:, numpy.newaxis]
        else:
            n_times, nTS = data_shape[:2]
        if len(data_shape) > 2:
            nSamples = data_shape[2]
        else:
            nSamples = 1

        if special_idx is None:
            special_idx = []
        n_special_idx = len(special_idx)
        if len(subtitles) == 0:
            subtitles = var_labels
        if isinstance(labels, list) and len(labels) == n_vars:
            # One set of labels per variable.
            labels = [
                generate_region_labels(nTS, label, ". ", self.print_ts_indices)
                for label in labels
            ]
        else:
            # A single set of labels shared by all variables.
            labels = [
                generate_region_labels(nTS, labels, ". ",
                                       self.print_ts_indices)
                for _ in range(n_vars)
            ]

        if is_traj:
            data_fun, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \
            subtitle, subtitle_col, axlabels, axlimits = \
                self._trajectories_plot(n_vars, nTS, nSamples, subplots)
        else:
            if is_raster:
                data_fun, time, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \
                subtitle, subtitle_col, axlabels, axlimits, axYticks = \
                    self._ts_plot(time, n_vars, nTS, n_times, time_unit, 0, offset, data_lims)

            else:
                data_fun, time, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \
                subtitle, subtitle_col, axlabels, axlimits, axYticks = \
                    self._ts_plot(time, n_vars, nTS, n_times, time_unit, ensure_list(subplots)[0])

        # Line transparency scales with the number of samples, floored at 0.1.
        alpha_ratio = 1.0 / nSamples
        alphas = numpy.maximum(
            numpy.array([def_alpha] * nTS) * alpha_ratio, 0.1)
        alphas[special_idx] = numpy.maximum(alpha_ratio, 0.1)
        if is_traj and (n_cols * n_rows > 1):
            colors = numpy.zeros((nTS, 4))
            colors[special_idx] = \
                numpy.array([numpy.array([1.0, 0, 0, 1.0]) for _ in range(n_special_idx)]).reshape((n_special_idx, 4))
        else:
            # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was
            # deprecated and later removed from Matplotlib.
            cmap = pyplot.get_cmap('jet')
            colors = numpy.array([cmap(0.5 * iTS / nTS) for iTS in range(nTS)])
            colors[special_idx] = \
                numpy.array([cmap(1.0 - 0.25 * iTS / nTS) for iTS in range(n_special_idx)]).reshape((n_special_idx, 4))
        colors[:, 3] = alphas

        if is_raster:
            # Computed once from the configured size and kept local, so the
            # size neither shrinks per column nor persistently across calls.
            raster_tick_font_size = numpy.minimum(
                self.tick_font_size,
                int(numpy.round(self.tick_font_size * 100.0 / nTS)))

        lines = []
        pyplot.figure(title, figsize=figsize)
        axes = []
        for icol in range(n_cols):
            if n_rows == 1:
                # If there are no more rows, create axis, and set its limits, labels and possible subtitle
                axes += ensure_list(
                    pyplot.subplot(n_rows,
                                   n_cols,
                                   icol + 1,
                                   projection=projection))
                axlimits(data_lims, time, n_vars, icol)
                axlabels(labels[icol % n_vars], var_labels, n_vars, n_rows, 1,
                         0)
                pyplot.gca().set_title(subtitles[icol])
            for iTS in loopfun(nTS, n_rows, icol):
                if n_rows > 1:
                    # If there are more rows, create axes, and set their limits, labels and possible subtitles
                    axes += ensure_list(
                        pyplot.subplot(n_rows,
                                       n_cols,
                                       iTS + 1,
                                       projection=projection))
                    axlimits(data_lims, time, n_vars, icol)
                    subtitle(labels[icol % n_vars], iTS)
                    axlabels(labels[icol % n_vars], var_labels, n_vars, n_rows,
                             (iTS % n_rows) + 1, iTS)
                lines += ensure_list(
                    plot_lines(data_fun(data, time, icol), iTS, colors,
                               labels[icol % n_vars]))
            if is_raster:  # set yticks as labels if this is a raster plot
                axYticks(labels[icol % n_vars], nTS)
                yticklabels = pyplot.gca().yaxis.get_ticklabels()
                for iTS in range(nTS):
                    yticklabels[iTS].set_fontsize(raster_tick_font_size)
                    if iTS in special_idx:
                        # Same colour as the line, but fully opaque.
                        yticklabels[iTS].set_color(colors[iTS, :3].tolist() +
                                                   [1])
                pyplot.gca().yaxis.set_ticklabels(yticklabels)
                pyplot.gca().invert_yaxis()

        if self.config.figures.MOUSE_HOOVER:
            for line in lines:
                self.HighlightingDataCursor(line,
                                            formatter='{label}'.format,
                                            bbox=dict(fc='white'),
                                            arrowprops=dict(
                                                arrowstyle='simple',
                                                fc='white',
                                                alpha=0.5))

        self._save_figure(pyplot.gcf(), figure_name)
        self._check_show()
        return pyplot.gcf(), axes, lines