Beispiel #1
0
 def _set_axis_labels(self,
                      fig,
                      sub,
                      n_regions,
                      region_labels,
                      indices2emphasize,
                      color='k',
                      position='left'):
     """Add a frameless subplot whose y-axis carries the region labels.

     The subplot ``sub`` is added to ``fig`` with ``frameon=False`` and its
     x-axis hidden. Y ticks sit at ``0 .. n_regions - 1`` and are labelled
     with the output of ``generate_region_labels``. The y-axis is inverted
     so that the first region is on top.

     :param fig: figure that receives the new subplot
     :param sub: subplot specification forwarded to ``fig.add_subplot``
     :param n_regions: number of ticks/regions
     :param region_labels: labels forwarded to ``generate_region_labels``
     :param indices2emphasize: tick positions to recolour; only read when
         ``color`` is not ``'k'``
     :param color: emphasis colour; ``'k'`` means no emphasis
     :param position: ``'right'`` moves the ticks and the axis label to the
         right-hand side; any other value leaves the defaults
     """
     tick_positions = range(n_regions)
     axis_labels = generate_region_labels(len(tick_positions), region_labels,
                                          ". ",
                                          self.print_regions_indices,
                                          tick_positions)
     label_ax = fig.add_subplot(sub, frameon=False)
     y_axis = label_ax.yaxis
     if position == 'right':
         # Mirror both the ticks and the axis label onto the right side.
         y_axis.tick_right()
         y_axis.set_label_position("right")
     label_ax.set_yticks(tick_positions)
     label_ax.set_yticklabels(axis_labels, color='k')
     if not color == 'k':
         tick_texts = y_axis.get_ticklabels()
         for emphasized in indices2emphasize:
             tick_texts[emphasized].set_color(color)
         y_axis.set_ticklabels(tick_texts)
     label_ax.invert_yaxis()
     label_ax.axes.get_xaxis().set_visible(False)
    def plot_fit_results(self, ests, samples, model_data, target_data, probabilistic_model=None, info_crit=None,
                         stats=None, pair_plot_params=["tau1", "sigma", "epsilon", "scale", "offset"],
                         region_violin_params=["x0", "PZ", "x1eq", "zeq"],
                         region_labels=[], regions_mode="active", seizure_indices=[],
                         trajectories_plot=True, connectivity_plot=False, skip_samples=0, title_prefix=""):
        """Create the set of diagnostic figures for a model fit.

        The sample arrays are first packed into Timeseries objects with
        ``samples_to_timeseries``. The method then produces:

        - the fitted time series;
        - the region parameter violins, merged and per chain;
        - the scalar parameters and their iterations;
        - optionally, the model-comparison figures (``info_crit``);
        - optionally, the connectivity figure (``connectivity_plot``).

        :param regions_mode: "all" keeps every region. Any other value
            restricts the labels and seizure indices to the active regions.
        :param seizure_indices: region indices to highlight. In "all" mode an
            empty list defaults to the active regions. Otherwise they are
            re-expressed as positions inside the active-region list.
        :param region_labels: region labels. When ``probabilistic_model`` is
            given, they are regenerated with ``generate_region_labels``.
        :param info_crit: if not None, model-comparison figures are added.
        :param connectivity_plot: if True, a connectivity figure is added.
        :return: tuple of all produced figures, in creation order.
        """
        # The list defaults above are only rebound below, never mutated in place.
        sigma = []
        if probabilistic_model is not None:
            n_regions = probabilistic_model.number_of_regions
            region_labels = generate_region_labels(n_regions, region_labels, ". ", True)
            if probabilistic_model.parameters.get("sigma", None) is not None:
                sigma = ["sigma"]
            active_regions = ensure_list(probabilistic_model.active_regions)
        else:
            active_regions = ensure_list(model_data.get("active_regions", []))

        if isequal_string(regions_mode, "all"):
            if len(seizure_indices) == 0:
                seizure_indices = active_regions
        elif len(active_regions) > 0:
            # Re-express the seizure indices as positions within active_regions
            # (list.index raises ValueError for an index that is not active).
            seizure_indices = [active_regions.index(ind) for ind in seizure_indices]
            if len(region_labels) > 0:
                # asarray: a plain Python list of labels cannot be indexed by a list
                region_labels = numpy.asarray(region_labels)[active_regions]

        if len(region_labels) == 0:
            # No labels available: switch on index printing for regions and time series
            self.print_regions_indices = True
            self.print_ts_indices = True

        figs = []

        # Pack fit samples time series into timeseries objects
        # (the x1 mean/std it also returns are not needed here):
        from tvb_fit.tvb_epilepsy.top.scripts.fitting_scripts import samples_to_timeseries
        samples, target_data, _, _ = samples_to_timeseries(samples, model_data, target_data, region_labels)
        figs.append(self.plot_fit_timeseries(target_data, samples, ests, stats, probabilistic_model, "fit_target_data",
                                             ["x1", "z"], ["dWt", "dX1t", "dZt"], sigma, seizure_indices,
                                             skip_samples, trajectories_plot, region_labels, title_prefix))

        # Region parameters: all chains merged (False) and per chain (True)
        figs.append(
            self.plot_fit_region_params(samples, stats, probabilistic_model, region_violin_params, seizure_indices,
                                        region_labels, regions_mode, False, skip_samples, title_prefix))

        figs.append(
            self.plot_fit_region_params(samples, stats, probabilistic_model, region_violin_params, seizure_indices,
                                        region_labels, regions_mode, True, skip_samples, title_prefix))

        figs.append(self.plot_fit_scalar_params(samples, stats, probabilistic_model, pair_plot_params,
                                                skip_samples, title_prefix))

        figs.append(self.plot_fit_scalar_params_iters(samples, pair_plot_params, 0, title_prefix, subplot_shape=None))

        if info_crit is not None:
            figs.append(self.plot_scalar_model_comparison(info_crit, title_prefix))
            figs.append(self.plot_array_model_comparison(info_crit, title_prefix, labels=target_data.space_labels,
                                                         xdata=target_data.time, xlabel="Time"))

        if connectivity_plot:
            figs.append(self.plot_fit_connectivity(ests, stats, probabilistic_model, "MC", region_labels, title_prefix))

        return tuple(figs)
Beispiel #3
0
 def axYticks(labels, nTS, offsets=offset):
     """Place ``nTS`` y ticks spaced by ``offsets`` on the current axes and label them.

     This is a closure: ``offset`` (the default of ``offsets``) and ``self``
     come from an enclosing scope that is not visible here.
     If ``labels`` cannot be applied as tick labels, generic labels from
     ``generate_region_labels`` are applied instead and a warning is logged.
     """
     ax = pyplot.gca()
     # Honour the offsets argument (it used to be ignored in favour of the
     # closure's offset, which is only its default).
     ax.set_yticks((offsets * numpy.arange(nTS)).tolist())
     try:
         ax.set_yticklabels(numpy.asarray(labels).flatten().tolist())
     except (AttributeError, TypeError, ValueError):
         # Fall back to generic labels and actually apply them.
         labels = generate_region_labels(nTS, [], "", True)
         ax.set_yticklabels(labels)
         self.logger.warning(
             "Cannot convert region labels' strings for y axis ticks!")
Beispiel #4
0
 def string_connectivity_disease(self, region_labels=[]):
     """Describe the connectivity-disease entries, one entry per line.

     Each line reads ``"<label A> -> <label B>: <value>"``:

     - the labels are taken from the output of ``generate_region_labels``
       at positions ``w_ind[0]`` and ``w_ind[1]``;
     - ``w_ind`` is the matching entry of ``self.w_indices``;
     - the value is the matching entry of ``self.w_values``.

     Lines are newline-separated with no trailing newline. An empty string
     is returned when there are no entries.
     """
     region_labels = generate_region_labels(self.number_of_regions,
                                            region_labels,
                                            str=". ")
     # join() instead of repeated += plus a final trim: linear, not quadratic
     return "\n".join(
         "%s -> %s: %s" % (region_labels[w_ind[0]], region_labels[w_ind[1]], str(w_val))
         for w_ind, w_val in zip(self.w_indices, self.w_values))
 def _region_parameters_violin_plots(self, samples_all, values=None, lines=None, stats=None,
                                     params=[], skip_samples=0, per_chain_or_run=False,
                                     labels=[], special_idx=None, figure_name="Regions parameters samples",
                                     figsize=FiguresConfig.VERY_LARGE_SIZE):
     """Plot region-wise parameter samples as violin plots, one subplot per parameter.

     :param samples_all: one sample dict per chain/run (wrapped with ensure_list)
     :param values: optional dict param -> reference values drawn as markers
     :param lines: optional dict param -> lines forwarded to plot_vector_violin
     :param stats: forwarded to ``_params_stats_labels``
     :param params: parameter names to plot. Names missing from the first
         sample dict are dropped.
     :param skip_samples: number of initial samples to skip
     :param per_chain_or_run: if False and there are several chains, the chains
         are merged into one figure; otherwise there is one figure per chain
     :param labels: region labels, generated when empty
     :param special_idx: region indices highlighted in red
     :return: tuple with one figure per plotted chain (or one for merged samples)
     """
     if isinstance(values, dict):
         vals_fun = lambda param: values.get(param, numpy.array([]))
     else:
         vals_fun = lambda param: []
     if isinstance(lines, dict):
         lines_fun = lambda param: lines.get(param, numpy.array([]))
     else:
         lines_fun = lambda param: []
     samples_all = ensure_list(samples_all)
     params = [param for param in params if param in samples_all[0].keys()]
     samples = [extract_dict_stringkeys(sample, params, modefun="equal") for sample in samples_all]
     # list(...): dict views are not indexable on Python 3
     first_param_samples = list(samples[0].values())[0]
     if len(labels) == 0:
         labels = generate_region_labels(first_param_samples.shape[-1], labels, numbering=False)
     n_chains = len(samples_all)
     n_samples = first_param_samples.shape[0]
     # A violin needs more than one sample; otherwise plot_vector_violin draws markers
     violin_flag = n_samples > 1
     if not per_chain_or_run and n_chains > 1:
         samples = [merge_samples(samples)]
         plot_samples = lambda s: numpy.concatenate(numpy.split(s[:, skip_samples:].T, n_chains, axis=2),
                                                    axis=1).squeeze().T
         plot_figure_name = lambda ichain: figure_name
     else:
         plot_samples = lambda s: s[skip_samples:]
         plot_figure_name = lambda ichain: figure_name + ": chain " + str(ichain + 1)
     # Only the first parameter's subplot gets the region labels
     params_labels = {}
     for ip, p in enumerate(params):
         params_labels[p] = self._params_stats_labels(p, stats, labels if ip == 0 else "")
     n_params = len(params)
     if n_params > 9:
         warning("Number of subplots in column wise vector-violin-plots cannot be > 9 and it is "
                           + str(n_params) + "!")
     # Three-digit subplot code: 1 row, n_params columns
     subplot_ind = 100 + n_params * 10
     figs = []
     for ichain, chain_sample in enumerate(samples):
         fig = pyplot.figure(plot_figure_name(ichain), figsize=figsize)
         for ip, param in enumerate(params):
             self.plot_vector_violin(plot_samples(chain_sample[param]), vals_fun(param),
                                     lines_fun(param), params_labels[param],
                                     subplot_ind + ip + 1, param, violin_flag=violin_flag,
                                     colormap="YlOrRd", show_y_labels=True,
                                     indices_red=special_idx, sharey=None)
         self._save_figure(pyplot.gcf(), None)
         self._check_show()
         # Collect the figure itself (previously the last subplot's Axes was collected)
         figs.append(fig)
     return tuple(figs)
Beispiel #6
0
 def string_regions_disease(self, region_labels=[]):
     """Describe the diseased regions, one region per line.

     There is one line per index in ``self.regions_disease_indices``, of the
     form ``"<label>: <type>=<value>"``:

     - the type is ``"E"`` when the index is in ``self.e_indices`` and
       ``"x0"`` otherwise;
     - the value is read from ``self.regions_disease`` at that index.

     Lines are newline-separated with no trailing newline.
     """
     region_labels = generate_region_labels(self.number_of_regions,
                                            region_labels,
                                            str=". ")
     disease_values = self.regions_disease
     lines = []
     for iRegion in self.regions_disease_indices:
         hyp_type = "E" if iRegion in self.e_indices else "x0"
         lines.append("%s: %s=%s" % (region_labels[iRegion], hyp_type, str(disease_values[iRegion])))
     # join() instead of repeated += plus a final trim: linear, not quadratic
     return "\n".join(lines)
def get_x1_estimates_from_samples(samples,
                                  model_data,
                                  region_labels=[],
                                  time_unit="ms"):
    """Return the mean and the standard deviation of ``x1`` across all samples.

    :param samples: one sample dict or a list of them, each holding ``"x1"``.
        ``"x1"`` is either a numpy array, which is transposed so that its axes
        read (time, region, sample), or an object exposing ``.squeezed``.
    :param model_data: dict. The optional ``"time"`` entry sets the time
        start/step. The optional ``"active_regions"`` entry is used to subset
        the labels.
    :param region_labels: region labels, reduced to the active regions when
        longer than them
    :param time_unit: time unit of the returned Timeseries
    :return: (x1_mean, x1_std) as Timeseries, with NaNs ignored
    """
    # Normalise first: samples[0] below must index a list, not a single dict.
    samples = ensure_list(samples)
    time = model_data.get("time", False)
    if time is not False:
        time_start = time[0]
        time_step = np.diff(time).mean()
    else:
        time_start = 0
        time_step = 1
    if isinstance(samples[0]["x1"], np.ndarray):
        get_x1 = lambda x1: x1.T
    else:
        get_x1 = lambda x1: x1.squeezed
    (n_times, n_regions, _) = get_x1(samples[0]["x1"]).shape
    active_regions = model_data.get("active_regions",
                                    np.array(range(n_regions)))
    region_labels = generate_region_labels(
        np.maximum(n_regions, len(region_labels)), region_labels, ". ", False)
    if len(region_labels) > len(active_regions):
        region_labels = region_labels[active_regions]
    # Stack all samples along the sample axis in one call instead of
    # re-copying a growing array each iteration; the empty float seed keeps
    # the original dtype promotion.
    x1 = np.concatenate([np.empty((n_times, n_regions, 0))] +
                        [get_x1(sample["x1"]) for sample in samples], axis=2)
    x1_mean = Timeseries(np.nanmean(x1, axis=2).squeeze(), {
        TimeseriesDimensions.SPACE.value: region_labels,
        TimeseriesDimensions.VARIABLES.value: ["x1"]
    },
                         time_start=time_start,
                         time_step=time_step,
                         time_unit=time_unit)
    x1_std = Timeseries(np.nanstd(x1, axis=2).squeeze(), {
        TimeseriesDimensions.SPACE.value: region_labels,
        TimeseriesDimensions.VARIABLES.value: ["x1std"]
    },
                        time_start=time_start,
                        time_step=time_step,
                        time_unit=time_unit)
    return x1_mean, x1_std
Beispiel #8
0
 def plot_vector(self,
                 vector,
                 labels,
                 subplot,
                 title,
                 show_y_labels=True,
                 indices_red=None,
                 sharey=None):
     """Draw ``vector`` as horizontal bars in subplot ``subplot``.

     There is one bar per entry of ``labels`` (``labels.shape[0]`` bars). A
     2D ``vector`` contributes only its first row. Bars and tick labels at
     ``indices_red`` are drawn in red, all others in black. The y-axis is
     inverted only when ``sharey`` is None.

     :return: the axes the bars were drawn into
     """
     ax = pyplot.subplot(subplot, sharey=sharey)
     pyplot.title(title)
     n_bars = labels.shape[0]
     y_positions = numpy.array(range(n_bars), dtype=numpy.int32)
     bar_colors = numpy.repeat(['k'], n_bars)
     highlight = indices_red is not None
     if highlight:
         bar_colors[indices_red] = 'r'
     bar_values = vector if len(vector.shape) == 1 else vector[0, :]
     ax.barh(y_positions, bar_values, color=bar_colors, align='center')
     ax.grid(True, color='grey')
     ax.set_yticks(y_positions)
     if not show_y_labels:
         ax.set_yticklabels([])
     else:
         ax.set_yticklabels(generate_region_labels(n_bars, labels, ". ",
                                                   self.print_regions_indices))
         if highlight:
             tick_texts = ax.yaxis.get_ticklabels()
             for red_position in indices_red:
                 tick_texts[red_position].set_color('r')
             ax.yaxis.set_ticklabels(tick_texts)
     ax.autoscale(tight=True)
     if sharey is None:
         ax.invert_yaxis()
     return ax
def samples_to_timeseries(samples,
                          model_data,
                          target_data=None,
                          region_labels=[]):
    """Wrap fit samples and the target data in Timeseries objects.

    Each sample dict is modified in place:

    - every entry among "x1", "z", "x1_star", "z_star", "dX1t", "dZt",
      "dWt", "dX1t_star", "dZt_star" and "dWt_star" that is present and
      convertible becomes a Timeseries labelled with ``region_labels``;
      the others are left untouched;
    - "fit_target_data" is converted too.

    :param samples: one sample dict or a list of them. Sample arrays are
        transposed so that their axes read (time, space, sample).
    :param model_data: dict. The optional ``"time"`` entry is used when
        ``target_data`` is not a Timeseries. The optional
        ``"active_regions"`` entry is used to subset the labels.
    :param target_data: Timeseries or array of target data
    :param region_labels: region labels
    :return: (samples, target_data, x1 mean, x1 std). The last two are numpy
        arrays reduced over all samples, ignoring NaNs.
    """
    samples = ensure_list(samples)

    if isinstance(target_data, Timeseries):
        time = target_data.time
        n_target_data = target_data.number_of_labels
        target_data_labels = target_data.space_labels
    else:
        time = model_data.get("time", False)
        # Number of target signals (a count, not the array itself). The
        # samples are transposed below, which puts the signal axis second.
        n_target_data = samples[0]["fit_target_data"].T.shape[1]
        target_data_labels = generate_region_labels(n_target_data, [], ". ",
                                                    False)

    if time is not False:
        time_start = time[0]
        time_step = np.diff(time).mean()
    else:
        time_start = 0
        time_step = 1

    if not isinstance(target_data, Timeseries):
        target_data = Timeseries(
            target_data, {
                TimeseriesDimensions.SPACE.value: target_data_labels,
                TimeseriesDimensions.VARIABLES.value: ["target_data"]
            },
            time_start=time_start,
            time_step=time_step)

    (n_times, n_regions, n_samples) = samples[0]["x1"].T.shape
    active_regions = model_data.get("active_regions",
                                    np.array(range(n_regions)))
    region_labels = generate_region_labels(
        np.maximum(n_regions, len(region_labels)), region_labels, ". ", False)
    if len(region_labels) > len(active_regions):
        region_labels = region_labels[active_regions]

    state_vars = ["x1", "z", "x1_star", "z_star", "dX1t", "dZt", "dWt",
                  "dX1t_star", "dZt_star", "dWt_star"]
    x1 = np.empty((n_times, n_regions, 0))
    for sample in samples:
        for x in state_vars:
            try:
                if x == "x1":
                    x1 = np.concatenate([x1, sample[x].T], axis=2)
                sample[x] = Timeseries(
                    np.expand_dims(sample[x].T, 2), {
                        TimeseriesDimensions.SPACE.value: region_labels,
                        TimeseriesDimensions.VARIABLES.value: [x]
                    },
                    time_start=time_start,
                    time_step=time_step,
                    time_unit=target_data.time_unit)
            except Exception:
                # Best effort: variables missing from this sample (KeyError) or
                # not convertible are left as they are.
                pass

        sample["fit_target_data"] = Timeseries(
            np.expand_dims(sample["fit_target_data"].T, 2), {
                TimeseriesDimensions.SPACE.value: target_data_labels,
                TimeseriesDimensions.VARIABLES.value: ["fit_target_data"]
            },
            time_start=time_start,
            time_step=time_step,
            time_unit=target_data.time_unit)

    return samples, target_data, np.nanmean(x1, axis=2).squeeze(), np.nanstd(
        x1, axis=2).squeeze()
Beispiel #10
0
 def _plot_matrix(self,
                  matrix,
                  xlabels,
                  ylabels,
                  subplot=111,
                  title="",
                  show_x_labels=True,
                  show_y_labels=True,
                  x_ticks=numpy.array([]),
                  y_ticks=numpy.array([]),
                  indices_red_x=None,
                  indices_red_y=None,
                  sharex=None,
                  sharey=None,
                  cmap='autumn_r',
                  vmin=None,
                  vmax=None):
     """Show ``matrix`` as an image with labelled, optionally highlighted ticks.

     The rows and columns selected by ``x_ticks``/``y_ticks`` (all of them
     when empty) are drawn transposed with ``imshow``, so the first matrix
     axis runs along x.

     - Shown tick labels come from ``generate_region_labels``.
     - When labels are not shown, ``"<index>."`` strings are used instead.
     - Tick labels whose index is listed in ``indices_red_x``/``indices_red_y``
       are coloured red.
     - A colorbar is appended on the right.

     :return: (image axes, colorbar axes)
     """
     ax = pyplot.subplot(subplot, sharex=sharex, sharey=sharey)
     pyplot.title(title)
     nx, ny = matrix.shape
     indices_red = [indices_red_x, indices_red_y]
     ticks = [x_ticks, y_ticks]
     labels = [xlabels, ylabels]
     nticks = []
     for ii, (n, tick) in enumerate(zip([nx, ny], ticks)):
         if len(tick) == 0:
             ticks[ii] = numpy.array(range(n), dtype=numpy.int32)
         nticks.append(len(ticks[ii]))
     # Pass the colormap straight to imshow. pyplot.set_cmap() returns None,
     # changes the global default and re-colours the figure's current image,
     # which may be a matrix drawn earlier in another subplot.
     img = pyplot.imshow(matrix[ticks[0]][:, ticks[1]].T,
                         cmap=cmap,
                         vmin=vmin,
                         vmax=vmax,
                         interpolation='none')
     pyplot.grid(True, color='black')
     for ii, (xy, tick, ntick, ind_red, show, lbls, rot) in enumerate(
             zip(["x", "y"], ticks, nticks, indices_red,
                 [show_x_labels, show_y_labels], labels, [90, 0])):
         if show:
             labels[ii] = generate_region_labels(len(tick),
                                                 numpy.array(lbls)[tick],
                                                 ". ",
                                                 self.print_regions_indices,
                                                 tick)
             getattr(pyplot, xy + "ticks")(numpy.array(range(ntick)),
                                           labels[ii],
                                           rotation=rot)
         else:
             labels[ii] = numpy.array(["%d." % l for l in tick])
             getattr(pyplot, xy + "ticks")(numpy.array(range(ntick)),
                                           labels[ii])
         if ind_red is not None:
             tck = numpy.asarray(tick).tolist()
             ticklabels = getattr(ax, xy + "axis").get_ticklabels()
             for indr in ind_red:
                 try:
                     ticklabels[tck.index(indr)].set_color('r')
                 except (ValueError, IndexError):
                     # indr is not among the plotted ticks: nothing to highlight
                     pass
             getattr(ax, xy + "axis").set_ticklabels(ticklabels)
     ax.autoscale(tight=True)
     divider = make_axes_locatable(ax)
     cax1 = divider.append_axes("right", size="5%", pad=0.05)
     pyplot.colorbar(img, cax=cax1)
     return ax, cax1
Beispiel #11
0
    def plot_vector_violin(self,
                           dataset,
                           vector=[],
                           lines=[],
                           labels=[],
                           subplot=111,
                           title="",
                           violin_flag=True,
                           colormap="YlOrRd",
                           show_y_labels=True,
                           indices_red=None,
                           sharey=None):
        """Draw one horizontal violin (or row of markers) per column of ``dataset``.

        :param dataset: 2D array. Each column (axis 1) is one violin and the
            rows are samples.
        :param vector: optional per-column reference values, drawn as '*'
            markers when its length equals the number of columns
        :param lines: optional pair ``(x_values, weights)`` of per-column
            arrays, drawn as dashed curves when both have one entry per column
        :param labels: labels forwarded to ``generate_region_labels``
        :param violin_flag: if True draw violins coloured by column mean;
            otherwise draw one diamond marker per sample
        :param colormap: colormap name used to colour the violins
        :param indices_red: column indices highlighted in red
        :param sharey: axes to share y with. The y-axis is inverted only when
            this is None.
        :return: the axes drawn into
        """
        ax = pyplot.subplot(subplot, sharey=sharey)
        pyplot.title(title)
        n_violins = dataset.shape[1]
        y_ticks = numpy.array(range(n_violins), dtype=numpy.int32)
        if indices_red is None:
            indices_red = []
        if violin_flag:
            # Colour each violin by its column mean. The colormap name goes
            # straight to ScalarMappable: pyplot.set_cmap() returns None, and it
            # would also change the global default and re-colour the current image.
            colormap = matplotlib.cm.ScalarMappable(cmap=colormap)
            colormap = colormap.to_rgba(numpy.mean(dataset, axis=0),
                                        alpha=0.75)
            violin_parts = ax.violinplot(dataset,
                                         y_ticks,
                                         vert=False,
                                         widths=0.9,
                                         showmeans=True,
                                         showmedians=True,
                                         showextrema=True)
            violin_parts['cmeans'].set_color("k")
            violin_parts['cmins'].set_color("b")
            violin_parts['cmaxes'].set_color("b")
            violin_parts['cbars'].set_color("b")
            violin_parts['cmedians'].set_color("b")
            for ii in range(len(violin_parts['bodies'])):
                violin_parts['bodies'][ii].set_color(
                    numpy.reshape(colormap[ii], (1, 4)))
                violin_parts['bodies'][ii]._alpha = 0.75
                violin_parts['bodies'][ii]._edgecolors = numpy.reshape(
                    colormap[ii], (1, 4))
                violin_parts['bodies'][ii]._facecolors = numpy.reshape(
                    colormap[ii], (1, 4))
        else:
            colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
            n_samples = dataset.shape[0]
            for ii in range(n_violins):
                for jj in range(n_samples):
                    # Wrap around the colour cycle: indexing modulo n_samples ran
                    # past its end whenever n_samples > len(colorcycle).
                    sample_color = colorcycle[jj % len(colorcycle)]
                    ax.plot(dataset[jj, ii],
                            y_ticks[ii],
                            "D",
                            mfc=sample_color,
                            mec=sample_color,
                            ms=20)
        # indices_red was normalised to a list above, so it is never None here
        colors = numpy.repeat(['k'], n_violins)
        colors[indices_red] = 'r'
        if len(vector) == n_violins:
            for ii in range(n_violins):
                ax.plot(vector[ii],
                        y_ticks[ii],
                        '*',
                        mfc=colors[ii],
                        mec=colors[ii],
                        ms=10)
        if len(lines) == 2 and lines[0].shape[0] == n_violins and lines[
                1].shape[0] == n_violins:
            for ii in range(n_violins):
                # Offset each curve below its tick, scaled by its normalised weights
                yy = (y_ticks[ii] - 0.45*lines[1][ii]/numpy.max(lines[1][ii]))\
                     * numpy.ones(numpy.array(lines[0][ii]).shape)
                ax.plot(lines[0][ii], yy, '--', color=colors[ii])

        ax.grid(True, color='grey')
        ax.set_yticks(y_ticks)
        if show_y_labels:
            region_labels = generate_region_labels(n_violins, labels, ". ",
                                                   self.print_regions_indices)
            ax.set_yticklabels(region_labels)
            labels = ax.yaxis.get_ticklabels()
            for ids in indices_red:
                labels[ids].set_color('r')
            ax.yaxis.set_ticklabels(labels)
        else:
            ax.set_yticklabels([])
        if sharey is None:
            ax.invert_yaxis()
        ax.autoscale()
        return ax
Beispiel #12
0
    def plot_timeseries(self,
                        data_dict,
                        time=None,
                        mode="ts",
                        subplots=None,
                        special_idx=[],
                        subtitles=[],
                        labels=[],
                        offset=0.5,
                        time_units="ms",
                        title='Time series',
                        figure_name=None,
                        figsize=FiguresConfig.LARGE_SIZE):
        """Plot one or more time-series variables in a single figure.

        :param data_dict: mapping of variable name -> data array. The first two
            axes of the first array are read as (n_times, nTS). An optional
            third axis is read as the number of samples.
        :param mode: one of
            "ts" (default);
            "raster" (each array is mean-centred along axis 0 and divided by
            its largest range along that axis; the y ticks become the labels);
            "traj" (handled by ``_trajectories_plot``).
        :param special_idx: series indices drawn with highlight colours
        :param subtitles: column titles; defaults to the variable names
        :param labels: one label set per variable, or one set shared by all
        :param offset: forwarded to ``_timeseries_plot`` in raster mode only
        :return: (figure, list of axes, list of plotted lines)
        """
        n_vars = len(data_dict)
        # Materialise as lists: dict views are neither indexable nor
        # item-assignable on Python 3, and both are needed below.
        var_names = list(data_dict.keys())
        data = list(data_dict.values())
        data_lims = []
        for i_var, d in enumerate(data):
            if isequal_string(mode, "raster"):
                data[i_var] = (d - d.mean(axis=0))
                drange = numpy.max(data[i_var].max(axis=0) - data[i_var].min(axis=0))
                data[i_var] = data[i_var] / drange  # zscore(d, axis=None)
            # NOTE: limits come from the raw data, even in raster mode
            data_lims.append([d.min(), d.max()])
        data_shape = data[0].shape
        n_times, nTS = data_shape[:2]
        if len(data_shape) > 2:
            nSamples = data_shape[2]
        else:
            nSamples = 1
        if special_idx is None:
            special_idx = []
        n_special_idx = len(special_idx)
        if len(subtitles) == 0:
            subtitles = var_names
        if isinstance(labels, list) and len(labels) == n_vars:
            labels = [
                generate_region_labels(nTS, label, ". ", self.print_ts_indices)
                for label in labels
            ]
        else:
            labels = [
                generate_region_labels(nTS, labels, ". ",
                                       self.print_ts_indices)
                for _ in range(n_vars)
            ]
        if isequal_string(mode, "traj"):
            data_fun, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \
            subtitle, subtitle_col, axlabels, axlimits = \
                self._trajectories_plot(n_vars, nTS, nSamples, subplots)
        else:
            if isequal_string(mode, "raster"):
                data_fun, time, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \
                subtitle, subtitle_col, axlabels, axlimits, axYticks = \
                    self._timeseries_plot(time, n_vars, nTS, n_times, time_units, 0, offset, data_lims)

            else:
                data_fun, time, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \
                subtitle, subtitle_col, axlabels, axlimits, axYticks = \
                    self._timeseries_plot(time, n_vars, nTS, n_times, time_units, ensure_list(subplots)[0])
        # Transparency shrinks with the number of samples, floored at 0.1
        alpha_ratio = 1.0 / nSamples
        alphas = numpy.maximum(
            numpy.array([def_alpha] * nTS) * alpha_ratio, 0.1)
        alphas[special_idx] = numpy.maximum(alpha_ratio, 0.1)
        if isequal_string(mode, "traj") and (n_cols * n_rows > 1):
            colors = numpy.zeros((nTS, 4))
            colors[special_idx] = \
                numpy.array([numpy.array([1.0, 0, 0, 1.0]) for _ in range(n_special_idx)]).reshape((n_special_idx, 4))
        else:
            # pyplot.get_cmap exists in all matplotlib versions; cm.get_cmap was removed in 3.9
            cmap = pyplot.get_cmap('jet')
            colors = numpy.array([cmap(0.5 * iTS / nTS) for iTS in range(nTS)])
            colors[special_idx] = \
                numpy.array([cmap(1.0 - 0.25 * iTS / nTS) for iTS in range(n_special_idx)]).reshape((n_special_idx, 4))
        colors[:, 3] = alphas
        lines = []
        # No pyplot.hold(True): holding is the default and hold() was removed in matplotlib 3.0
        pyplot.figure(title, figsize=figsize)
        axes = []
        for icol in range(n_cols):
            if n_rows == 1:
                # If there are no more rows, create axis, and set its limits, labels and possible subtitle
                axes += ensure_list(
                    pyplot.subplot(n_rows,
                                   n_cols,
                                   icol + 1,
                                   projection=projection))
                axlimits(data_lims, time, n_vars, icol)
                axlabels(labels[icol % n_vars], var_names, n_vars, n_rows, 1, 0)
                pyplot.gca().set_title(subtitles[icol])
            for iTS in loopfun(nTS, n_rows, icol):
                if n_rows > 1:
                    # If there are more rows, create axes, and set their limits, labels and possible subtitles
                    axes += ensure_list(
                        pyplot.subplot(n_rows,
                                       n_cols,
                                       iTS + 1,
                                       projection=projection))
                    axlimits(data_lims, time, n_vars, icol)
                    subtitle(labels[icol % n_vars], iTS)
                    axlabels(labels[icol % n_vars], var_names, n_vars, n_rows,
                             (iTS % n_rows) + 1, iTS)
                lines += ensure_list(
                    plot_lines(data_fun(data, time, icol), iTS, colors,
                               labels[icol % n_vars]))
            if isequal_string(
                    mode,
                    "raster"):  # set yticks as labels if this is a raster plot
                axYticks(labels[icol % n_vars], nTS)
                yticklabels = pyplot.gca().yaxis.get_ticklabels()
                self.tick_font_size = numpy.minimum(
                    self.tick_font_size,
                    int(numpy.round(self.tick_font_size * 100.0 / nTS)))
                for iTS in range(nTS):
                    yticklabels[iTS].set_fontsize(self.tick_font_size)
                    if iTS in special_idx:
                        yticklabels[iTS].set_color(colors[iTS, :3].tolist() +
                                                   [1])
                pyplot.gca().yaxis.set_ticklabels(yticklabels)
                pyplot.gca().invert_yaxis()

        if self.config.figures.MOUSE_HOOVER:
            for line in lines:
                self.HighlightingDataCursor(line,
                                            formatter='{label}'.format,
                                            bbox=dict(fc='white'),
                                            arrowprops=dict(
                                                arrowstyle='simple',
                                                fc='white',
                                                alpha=0.5))

        self._save_figure(pyplot.gcf(), figure_name)
        self._check_show()
        return pyplot.gcf(), axes, lines
    def plot_state_space(self, model_config, region_labels=[], special_idx=[],
                         figure_name="", approximations=False):
        """Plot the Epileptor x1-z phase plane for ``model_config``.

        The figure contains:

        - the x1 nullcline;
        - the z nullclines for the ``X0_CR_DEF`` and ``X0_DEF`` excitability
          values;
        - one text marker per region at its (x1eq, zeq) equilibrium, with
          ``special_idx`` regions in red;
        - a quiver plot of the vector field;
        - if ``approximations`` is True, the linear and parabolic local
          approximations of the x1 nullcline.

        Model parameters are replaced by their mean over regions. The figure
        is saved and shown via ``_save_figure``/``_check_show``.

        :return: the created figure
        """
        if model_config.model_name == "Epileptor2D":
            model = "2d"
        else:
            model = "6d"
        add_name = " " + "Epileptor " + model + " z-" + str(numpy.where(model_config.zmode[0], "exp", "lin"))
        figure_name = figure_name + add_name

        region_labels = generate_region_labels(model_config.number_of_regions, region_labels, ". ")

        # Fixed parameters for all regions: each is the mean over regions.
        # getattr() replaces the former exec(): on Python 3, exec cannot rebind
        # function locals, so every parameter silently stayed at 0.0.
        x1eq = model_config.x1eq
        zeq = model_config.zeq
        x0, a, b, d, yc, slope, Iext1, Iext2, s, tau1, tau0, zmode = [
            numpy.mean(getattr(model_config, p))
            for p in ["x0", "a", "b", "d", "yc", "slope", "Iext1", "Iext2", "s", "tau1", "tau0", "zmode"]]

        fig = pyplot.figure(figure_name, figsize=FiguresConfig.SMALL_SIZE)

        # Lines:
        x1 = numpy.linspace(-2.0, 1.0, 100)
        if isequal_string(model, "2d"):
            y1 = yc
        else:
            y1 = calc_eq_y1(x1, yc, d=d)
        # x1 nullcline:
        zX1 = calc_fx1(x1, z=0, y1=y1, Iext1=Iext1, slope=slope, a=a, b=b, d=d, tau1=1.0, x1_neg=True, model=model,
                       x2=0.0)  # yc + Iext1 - x1 ** 3 - 2.0 * x1 ** 2
        x1null, = pyplot.plot(x1, zX1, 'b-', label='x1 nullcline', linewidth=1)
        # No ax.axes.hold(True): holding is the default and hold() was removed in matplotlib 3.0
        ax = pyplot.gca()
        # z nullclines
        # center point (critical equilibrium point) without approximation:
        # zsq0 = yc + Iext1 - x1sq0 ** 3 - 2.0 * x1sq0 ** 2
        x0e = calc_x0_val_to_model_x0(X0_CR_DEF, yc, Iext1, a=a, b=b, d=d, zmode=model_config.zmode)
        x0ne = calc_x0_val_to_model_x0(X0_DEF, yc, Iext1, a=a, b=b, d=d, zmode=model_config.zmode)
        zZe = calc_fz(x1, z=0.0, x0=x0e, tau1=1.0, tau0=1.0, zmode=model_config.zmode)  # for epileptogenic regions
        zZne = calc_fz(x1, z=0.0, x0=x0ne, tau1=1.0, tau0=1.0, zmode=model_config.zmode)  # for non-epileptogenic regions
        zE1null, = pyplot.plot(x1, zZe, 'g-', label='z nullcline at critical point (e_values=1)', linewidth=1)
        zE2null, = pyplot.plot(x1, zZne, 'g--', label='z nullcline for e_values=0', linewidth=1)
        if approximations:
            # The point of the linear approximation (1st order Taylor expansion)
            x1LIN = def_x1lin(X1_DEF, X1EQ_CR_DEF, len(region_labels))
            x1SQ = X1EQ_CR_DEF
            x1lin0 = numpy.mean(x1LIN)
            # The point of the square (parabolic) approximation (2nd order Taylor expansion)
            x1sq0 = numpy.mean(x1SQ)
            # approximations:
            # linear:
            x1lin = numpy.linspace(-5.5 / 3.0, -3.5 / 3, 30)
            # x1 nullcline after linear approximation:
            # yc + Iext1 + 2.0 * x1lin0 ** 3 + 2.0 * x1lin0 ** 2 - \
            # (3.0 * x1lin0 ** 2 + 4.0 * x1lin0) * x1lin  # x1
            zX1lin = calc_fx1_2d_taylor(x1lin, x1lin0, z=0, y1=yc, Iext1=Iext1, slope=slope, a=a, b=b, d=d, tau1=1.0,
                                        x1_neg=None, order=2)  #
            # center point without approximation:
            # zlin0 = yc + Iext1 - x1lin0 ** 3 - 2.0 * x1lin0 ** 2
            # square:
            x1sq = numpy.linspace(-5.0 / 3, -1.0, 30)
            # x1 nullcline after parabolic approximation:
            # + 2.0 * x1sq ** 2 + 16.0 * x1sq / 3.0 + yc + Iext1 + 64.0 / 27.0
            zX1sq = calc_fx1_2d_taylor(x1sq, x1sq0, z=0, y1=yc, Iext1=Iext1, slope=slope, a=a, b=b, d=d, tau1=1.0,
                                       x1_neg=None, order=3, shape=x1sq.shape)
            sq, = pyplot.plot(x1sq, zX1sq, 'm--', label='Parabolic local approximation', linewidth=2)
            lin, = pyplot.plot(x1lin, zX1lin, 'c--', label='Linear local approximation', linewidth=2)
            pyplot.legend(handles=[x1null, zE1null, zE2null, lin, sq])
        else:
            pyplot.legend(handles=[x1null, zE1null, zE2null])

        # Points: ordinary regions in faint black first, special regions in red last
        ii = range(len(region_labels))
        n_special_idx = len(special_idx)
        if n_special_idx > 0:
            ii = numpy.delete(ii, special_idx)
        points = []
        for i in ii:
            point = pyplot.text(x1eq[i], zeq[i], str(i), fontsize=10, color='k', alpha=0.3,
                                label=str(i) + '.' + region_labels[i])
            points.append(point)
        if n_special_idx > 0:
            for i in special_idx:
                point = pyplot.text(x1eq[i], zeq[i], str(i), fontsize=10, color='r', alpha=0.8,
                                    label=str(i) + '.' + region_labels[i])
                points.append(point)

        # Vector field
        X1, Z = numpy.meshgrid(numpy.linspace(-2.0, 1.0, 41), numpy.linspace(0.0, 6.0, 31), indexing='ij')
        if isequal_string(model, "2d"):
            y1 = yc
            x2 = 0.0
        else:
            y1 = calc_eq_y1(X1, yc, d=d)
            x2 = 0.0  # as a simplification for faster computation without important consequences
            # x2 = calc_eq_x2(Iext2, y2eq=None, zeq=X1, x1eq=Z, s=s)[0]
        fx1 = calc_fx1(X1, Z, y1=y1, Iext1=Iext1, slope=slope, a=a, b=b, d=d, tau1=tau1, x1_neg=None,
                       model=model, x2=x2)
        fz = calc_fz(X1, Z, x0=x0, tau1=tau1, tau0=tau0, zmode=zmode)
        C = numpy.abs(fx1) + numpy.abs(fz)
        pyplot.quiver(X1, Z, fx1, fz, C, edgecolor='k', alpha=.5, linewidth=.5)
        pyplot.contour(X1, Z, fx1, 0, colors='b', linestyles="dashed")

        ax.set_title("Epileptor state space at the x1-z phase plane of the" + add_name)
        ax.axes.autoscale(tight=True)
        ax.axes.set_ylim([0.0, 6.0])
        ax.axes.set_xlabel('x1')
        ax.axes.set_ylabel('z')

        # Guard: there are no points when there are no regions
        if self.config.figures.MOUSE_HOOVER and len(points) > 0:
            self.HighlightingDataCursor(points[0], formatter='{label}'.format, bbox=dict(fc='white'),
                                        arrowprops=dict(arrowstyle='simple', fc='white', alpha=0.5))

        if len(fig.get_label()) == 0:
            fig.set_label(figure_name)
        else:
            figure_name = fig.get_label().replace(": ", "_").replace(" ", "_").replace("\t", "_")

        self._save_figure(None, figure_name)
        self._check_show()
        return fig