def update_active_regions_seeg(self,
                               target_data,
                               probabilistic_model,
                               sensors,
                               reset=False):
    """Extend the model's active regions using the SEEG gain matrix.

    For every sensor whose label appears in ``target_data.space_labels``, the
    matching row of ``sensors.gain_matrix`` is passed to
    ``select_greater_values_array_inds`` and the returned indices are added to
    the model's current active regions.

    :param target_data: object exposing ``space_labels``; when falsy, a
        warning is issued and no selection takes place
    :param probabilistic_model: object exposing ``active_regions`` and
        ``update_active_regions``
    :param sensors: object exposing ``gain_matrix`` and
        ``get_sensors_inds_by_sensors_labels``
    :param reset: when True, ``update_active_regions([])`` is called first
    :return: the ``probabilistic_model`` instance that was passed in
    """
    if reset:
        probabilistic_model.update_active_regions([])

    if not target_data:
        warning(
            "Skipping active regions setting by seeg power because no target data were provided!"
        )
        return probabilistic_model

    seeg_inds = sensors.get_sensors_inds_by_sensors_labels(
        target_data.space_labels)
    if len(seeg_inds) == 0:
        warning(
            "Skipping active regions setting by seeg power because no data were assigned to sensors!"
        )
        return probabilistic_model

    # Accumulate on a copy so that the model's own list is never mutated
    # behind its back; the model is updated once, after all the selected
    # sensor rows have been processed.
    active_regions = list(probabilistic_model.active_regions)
    for proj in np.array(sensors.gain_matrix)[seeg_inds]:
        active_regions.extend(select_greater_values_array_inds(proj).tolist())
    probabilistic_model.update_active_regions(active_regions)
    return probabilistic_model
Exemple #2
0
    def _determine_datasets_and_attributes(self, object, datasets_size=None):
        """Split the attributes of an object into datasets, metadata and groups.

        :param object: a dict, or any object supporting ``vars()``
        :param datasets_size: if given, only numpy arrays with exactly this
            ``size`` count as datasets; if None, every non-empty numpy array
            counts as a dataset
        :return: tuple ``(datasets_dict, metadata_dict, groups_keys)``:
            numpy arrays selected as datasets; the remaining numpy arrays plus
            plain scalars and strings; and the keys of every other attribute.
            If decomposition fails, a warning is issued and whatever was
            collected up to that point is returned.
        """
        datasets_dict = {}
        metadata_dict = {}
        groups_keys = []

        # ``long`` exists only on Python 2; on Python 3 ``int`` covers it.
        try:
            scalar_types = (float, int, long, complex, str)
        except NameError:
            scalar_types = (float, int, complex, str)

        try:
            dict_object = object if isinstance(object, dict) else vars(object)
            # ``items()`` is available on both Python 2 and Python 3.
            for key, value in dict_object.items():
                if isinstance(value, numpy.ndarray):
                    if datasets_size is None:
                        is_dataset = value.size > 0
                    else:
                        is_dataset = value.size == datasets_size
                    if is_dataset:
                        datasets_dict[key] = value
                    else:
                        metadata_dict[key] = value
                elif isinstance(value, scalar_types):
                    metadata_dict[key] = value
                else:
                    groups_keys.append(key)
        except Exception:
            # Best effort: report the problem and return what was collected.
            msg = "Failed to decompose group object: " + str(object) + "!"
            try:
                self.logger.info(str(object.__dict__))
            except Exception:
                msg += "\n It has no __dict__ attribute!"
            warning(msg, self.logger)

        return datasets_dict, metadata_dict, groups_keys
Exemple #3
0
 def get_prior_pdf(self, parameter_name):
     """Return the prior pdf of a parameter, or a fallback value.

     ``self.get_prior(parameter_name)`` yields a pair; when its second element
     is a (transformed) probabilistic parameter, the result of its
     ``scipy_method("pdf")`` is returned. Otherwise a warning is issued and
     the first element of the pair is returned unchanged.
     """
     fallback_value, prior = self.get_prior(parameter_name)
     probabilistic_types = (ProbabilisticParameterBase,
                            TransformedProbabilisticParameterBase)
     if not isinstance(prior, probabilistic_types):
         warning("No parameter " + parameter_name +
                 " was found!\nReturning true value instead of pdf!")
         return fallback_value
     return prior.scipy_method("pdf")
Exemple #4
0
def convert_params_names_to_ins(dicts_list,
                                parameter_names=INS_PARAMS_NAMES_DICT):
    """Duplicate entries of the given dicts under alternative key names.

    For every ``(p, p_ins)`` pair of ``parameter_names`` and every dict found
    in ``dicts_list``, the value stored under ``p`` is also stored under
    ``p_ins``. The dicts are modified in place.

    :param dicts_list: a dict, a list of dicts, or a list of lists of dicts
        (each level is normalised with ``ensure_list``)
    :param parameter_names: mapping from existing key to the key to add
    :return: tuple with the top-level elements of ``dicts_list``
    """
    output = []
    for lst in ensure_list(dicts_list):
        for dct in ensure_list(lst):
            # ``items()`` is available on both Python 2 and Python 3.
            for p, p_ins in parameter_names.items():
                try:
                    dct[p_ins] = dct[p]
                except Exception:
                    # Best effort: a missing key only produces a warning.
                    warning("Parameter " + p + " not found in \n" +
                            str(dicts_list))
        output.append(lst)
    return tuple(output)
Exemple #5
0
 def build_model_from_model_config_dict(self, model_config_dict):
     """Build a ModelConfiguration from a dict (or an object's ``__dict__``).

     A fresh ``ModelConfiguration()`` is created and each of its attributes
     is overwritten with the entry of the same name in
     ``model_config_dict``. Entries that are missing or ``None`` leave the
     default in place and produce a warning.

     :param model_config_dict: dict, or object whose ``__dict__`` is used
     :return: the populated ``ModelConfiguration`` instance
     """
     if not isinstance(model_config_dict, dict):
         model_config_dict = model_config_dict.__dict__
     model_configuration = ModelConfiguration()
     # Snapshot the attribute names: this is portable across Python 2/3 and
     # keeps the iteration independent of the setattr calls below.
     for attr in list(model_configuration.__dict__):
         value = model_config_dict.get(attr, None)
         if value is None:
             warning(
                 attr +
                 " not found in the input model configuraiton dictionary!" +
                 "\nLeaving default " + attr + ": " +
                 str(getattr(model_configuration, attr)))
         else:
             setattr(model_configuration, attr, value)
     return model_configuration
 def update_active_regions_x0_values(self,
                                     probabilistic_model,
                                     x0_values,
                                     reset=False):
     """Extend the model's active regions from a vector of x0 values.

     The indices returned by ``select_greater_values_array_inds`` for
     ``x0_values`` and ``self.active_x0_th`` are appended to the model's
     current active regions. With empty ``x0_values`` only a warning is
     issued. When ``reset`` is True, ``update_active_regions([])`` is called
     first.

     :return: the ``probabilistic_model`` instance that was passed in
     """
     if reset:
         probabilistic_model.update_active_regions([])

     if not len(x0_values) > 0:
         warning(
             "Skipping active regions setting by x0 values because no such values were provided!"
         )
         return probabilistic_model

     current_regions = probabilistic_model.active_regions
     selected_regions = select_greater_values_array_inds(
         x0_values, self.active_x0_th).tolist()
     probabilistic_model.update_active_regions(current_regions +
                                               selected_regions)
     return probabilistic_model
Exemple #7
0
 def get_truth(self, parameter_name):
     """Look up the ground truth value of a parameter for synthetic data.

     The lookup order is: ``self.ground_truth``, then the attribute of the
     same name on ``self.model_config`` (with ``model_connectivity`` as a
     substitute for "MC" and "FC"), then the attribute of the same name on
     ``self``. ``np.nan`` itself is used as the "not found" sentinel, hence
     the identity checks (a found value may be an array, for which a
     ``np.isnan`` test would not yield a single boolean).

     :param parameter_name: name of the parameter
     :return: the value found, or ``np.nan`` if nothing was found or the
         target data type is not synthetic
     """
     if self.target_data_type != TARGET_DATA_TYPE.SYNTHETIC.value:
         return np.nan

     truth = self.ground_truth.get(parameter_name, np.nan)
     if truth is np.nan:
         truth = getattr(self.model_config, parameter_name, np.nan)
         # TODO: find a more general solution here...
         # The membership test keeps the "still missing" condition applied
         # to both names; the former ``a and b == "MC" or b == "FC"`` form
         # overrode any value already found for "FC".
         if truth is np.nan and parameter_name in ("MC", "FC"):
             truth = self.model_config.model_connectivity
     if truth is np.nan:
         # TODO: decide if it is a good idea to return this kind of modeler's "truth"...
         truth = getattr(self, parameter_name, np.nan)
     if truth is np.nan:
         warning("Ground truth value for parameter " + parameter_name +
                 " was not found!")
     return truth
 def _set_attributes_from_dict(self, attributes_dict):
     """Copy the entries of a dict (or of an object's ``__dict__``) onto self.

     The keys "model_config", "parameters", "number_of_regions" and
     "number_of_parameters" are never copied. Entries whose value is
     ``None`` are not copied either; a warning is issued for each of them
     and the current attribute is left as it is.

     :param attributes_dict: dict, or object whose ``__dict__`` is used
     :return: the dict that was iterated over
     """
     if not isinstance(attributes_dict, dict):
         attributes_dict = attributes_dict.__dict__
     protected = ("model_config", "parameters", "number_of_regions",
                  "number_of_parameters")
     # ``items()`` is available on both Python 2 and 3; the snapshot keeps
     # the iteration independent of the setattr calls below.
     for attr, value in list(attributes_dict.items()):
         if attr in protected:
             continue
         if value is None:
             warning(attr + " not found in input dictionary!" +
                     "\nLeaving as it is: " + attr + " = " +
                     str(getattr(self, attr)))
         else:
             setattr(self, attr, value)
     return attributes_dict
Exemple #9
0
    def plot_spectral_analysis_raster(self, time, data, time_units="ms", freq=None, spectral_options={},
                                      special_idx=[], title='Spectral Analysis', figure_name=None, labels=[],
                                      figsize=FiguresConfig.VERY_LARGE_SIZE):
        """Plot one time-frequency raster and one spectrum per selected signal.

        Each row of the figure shows, for one signal, an image of the output of
        ``time_spectral_analysis`` (left) and the corresponding spectrum
        (right); a single colorbar is shared by all rows.

        :param time: time vector; its mean step sets the sampling frequency
        :param data: array indexed as ``data[:, signal]``
        :param time_units: "ms"/"msec" imply a 1000.0 scaling of the sampling
            frequency, anything else a scaling of 1.0
        :param freq: forwarded to ``time_spectral_analysis``
        :param spectral_options: dict of options forwarded to
            ``time_spectral_analysis`` ("log_norm", "mode", "nfft", "window",
            "nperseg", "detrend", "noverlap", "f_low", "log_scale")
        :param special_idx: indices of the signals to plot (all if empty)
        :param title: figure title
        :param figure_name: forwarded to ``self._save_figure``
        :param labels: signal labels; replaced by indices when their number
            does not fit
        :param figsize: figure size (overridden for more than 2 signals)
        :return: None when more than 20 signals are selected, otherwise the
            tuple ``(fig, ax, img, line, time, freq, stf, psd)`` where
            ``time``, ``freq``, ``stf`` and ``psd`` are those returned by
            ``time_spectral_analysis``
        """
        nS = data.shape[1]
        n_special_idx = len(special_idx)
        if n_special_idx > 0:
            data = data[:, special_idx]
            nS = data.shape[1]
            if len(labels) > n_special_idx:
                labels = numpy.array([str(ilbl) + ". " + str(labels[ilbl]) for ilbl in special_idx])
            elif len(labels) == n_special_idx:
                labels = numpy.array([str(ilbl) + ". " + str(label) for ilbl, label in zip(special_idx, labels)])
            else:
                labels = numpy.array([str(ilbl) for ilbl in special_idx])
        elif len(labels) != nS:
            labels = numpy.array([str(ilbl) for ilbl in range(nS)])
        if nS > 20:
            warning("It is not possible to plot spectral analysis plots for more than 20 signals!")
            return

        # ``basestring`` exists only on Python 2.
        try:
            string_types = basestring
        except NameError:
            string_types = str
        if not isinstance(time_units, string_types):
            time_units = list(time_units)[0]
        time_units = ensure_string(time_units)
        fs = 1000.0 if time_units in ("ms", "msec") else 1.0
        fs = fs / numpy.mean(numpy.diff(time))

        log_norm = spectral_options.get("log_norm", False)
        mode = spectral_options.get("mode", "psd")
        psd_label = mode
        if log_norm:
            psd_label = "log" + psd_label
        stf, time, freq, psd = time_spectral_analysis(data, fs,
                                                      freq=freq,
                                                      mode=mode,
                                                      nfft=spectral_options.get("nfft"),
                                                      window=spectral_options.get("window", 'hanning'),
                                                      nperseg=spectral_options.get("nperseg", int(numpy.round(fs / 4))),
                                                      detrend=spectral_options.get("detrend", 'constant'),
                                                      noverlap=spectral_options.get("noverlap"),
                                                      f_low=spectral_options.get("f_low", 10.0),
                                                      log_scale=spectral_options.get("log_scale", False))
        # Common color limits for all the rasters.
        min_val = numpy.min(stf.flatten())
        max_val = numpy.max(stf.flatten())
        if nS > 2:
            figsize = FiguresConfig.VERY_LARGE_SIZE
        fig = pyplot.figure(title, figsize=figsize)
        fig.suptitle(title)
        gs = gridspec.GridSpec(nS, 23)
        ax = numpy.empty((nS, 2), dtype="O")
        img = numpy.empty((nS,), dtype="O")
        line = numpy.empty((nS,), dtype="O")
        # Rows are created bottom-up, starting at the last valid row nS - 1
        # (starting at nS indexed one row past the grid), so that the bottom
        # time axis exists before the rows above share it.
        for iS in range(nS - 1, -1, -1):
            if iS < nS - 1:
                ax[iS, 0] = pyplot.subplot(gs[iS, :20], sharex=ax[nS - 1, 0])
            else:
                ax[iS, 0] = pyplot.subplot(gs[iS, :20])
            ax[iS, 1] = pyplot.subplot(gs[iS, 20:22], sharey=ax[iS, 0])
            img[iS] = ax[iS, 0].imshow(numpy.squeeze(stf[:, :, iS]).T, cmap='jet',
                                       interpolation='none',
                                       norm=Normalize(vmin=min_val, vmax=max_val), aspect='auto', origin='lower',
                                       extent=(time.min(), time.max(), freq.min(), freq.max()))
            ax[iS, 0].set_title(labels[iS])
            ax[iS, 0].set_ylabel("Frequency (Hz)")
            line[iS] = ax[iS, 1].plot(psd[:, iS], freq, 'k', label=labels[iS])
            pyplot.setp(ax[iS, 1].get_yticklabels(), visible=False)
            if iS == (nS - 1):
                # Only the bottom row carries x-axis labels.
                ax[iS, 0].set_xlabel("Time (" + time_units + ")")
                ax[iS, 1].set_xlabel(psd_label)
            else:
                pyplot.setp(ax[iS, 0].get_xticklabels(), visible=False)
            pyplot.setp(ax[iS, 1].get_xticklabels(), visible=False)
            ax[iS, 0].autoscale(tight=True)
            ax[iS, 1].autoscale(tight=True)
        # make a color bar
        cax = pyplot.subplot(gs[:, 22])
        pyplot.colorbar(img[0], cax=cax)
        cax.set_title(psd_label)
        self._save_figure(pyplot.gcf(), figure_name)
        self._check_show()
        return fig, ax, img, line, time, freq, stf, psd
 def _region_parameters_violin_plots(
         self,
         samples_all,
         values=None,
         lines=None,
         stats=None,
         params=["x0", "x1_init", "z_init"],
         skip_samples=0,
         per_chain_or_run=False,
         labels=[],
         seizure_indices=None,
         figure_name="Regions parameters samples",
         figsize=FiguresConfig.VERY_LARGE_SIZE):
     """Draw one vector-violin subplot per parameter, per chain/run or merged.

     :param samples_all: a samples dict or a list of them (normalised with
         ``ensure_list``); presumably one dict per chain/run -- TODO confirm
     :param values: optional dict, looked up per parameter and forwarded to
         ``self.plot_vector_violin``
     :param lines: optional dict, looked up per parameter and forwarded to
         ``self.plot_vector_violin``
     :param stats: forwarded to ``self._params_stats_labels``
     :param params: parameter names; only those that are keys of the first
         samples dict are plotted
     :param skip_samples: number of leading samples dropped before plotting
     :param per_chain_or_run: if False and there is more than one samples
         dict, the samples are merged with ``merge_samples`` and drawn in a
         single figure
     :param labels: forwarded to ``generate_region_labels``
     :param seizure_indices: forwarded as ``indices_red``
     :param figure_name: name of the figure (" : chain N" style suffix is
         appended when plotting separate figures)
     :param figsize: figure size
     :return: tuple with one entry per plotted figure; each entry is the
         value returned by the last ``plot_vector_violin`` call of that figure
     """
     # Per-parameter accessors that fall back to "nothing to draw" when the
     # corresponding argument is not a dict.
     if isinstance(values, dict):
         vals_fun = lambda param: values.get(param, numpy.array([]))
     else:
         vals_fun = lambda param: []
     if isinstance(lines, dict):
         lines_fun = lambda param: lines.get(param, numpy.array([]))
     else:
         lines_fun = lambda param: []
     samples_all = ensure_list(samples_all)
     params = [param for param in params if param in samples_all[0].keys()]
     samples = []
     for sample in samples_all:
         samples.append(
             extract_dict_stringkeys(sample, params, modefun="equal"))
     # NOTE(review): ``.values()[0]`` requires a list-returning ``values()``
     # (Python 2 dicts); the last axis of the first entry gives the number of
     # labels to generate.
     labels = generate_region_labels(samples[0].values()[0].shape[-1],
                                     labels)
     n_chains = len(samples_all)
     n_samples = samples[0].values()[0].shape[0]
     # Violins are requested only when there is more than one sample.
     if n_samples > 1:
         violin_flag = True
     else:
         violin_flag = False
     if not per_chain_or_run and n_chains > 1:
         # Merged mode: a single figure is drawn from the merged samples.
         samples = [merge_samples(samples)]
         plot_samples = lambda s: numpy.concatenate(numpy.split(
             s[:, skip_samples:].T, n_chains, axis=2),
                                                    axis=1).squeeze().T
         plot_figure_name = lambda ichain: figure_name
     else:
         # Separate mode: one figure per entry of ``samples``.
         plot_samples = lambda s: s[skip_samples:]
         plot_figure_name = lambda ichain: figure_name + ": chain " + str(
             ichain + 1)
     # Only the first parameter's labels include the region labels.
     params_labels = {}
     for ip, p in enumerate(params):
         if ip == 0:
             params_labels[p] = self._params_stats_labels(p, stats, labels)
         else:
             params_labels[p] = self._params_stats_labels(p, stats, "")
     n_params = len(params)
     if n_params > 9:
         warning(
             "Number of subplots in column wise vector-violin-plots cannot be > 9 and it is "
             + str(n_params) + "!")
     # Three-digit subplot code: 1 row, n_params columns; the column index is
     # added per parameter below.
     subplot_ind = 100 + n_params * 10
     figs = []
     for ichain, chain_sample in enumerate(samples):
         pyplot.figure(plot_figure_name(ichain), figsize=figsize)
         for ip, param in enumerate(params):
             fig = self.plot_vector_violin(plot_samples(
                 chain_sample[param]),
                                           vals_fun(param),
                                           lines_fun(param),
                                           params_labels[param],
                                           subplot_ind + ip + 1,
                                           param,
                                           violin_flag=violin_flag,
                                           colormap="YlOrRd",
                                           show_y_labels=True,
                                           indices_red=seizure_indices,
                                           sharey=None)
         # The current pyplot figure is the one opened for this chain above.
         self._save_figure(pyplot.gcf(), None)
         self._check_show()
         figs.append(fig)
     return tuple(figs)
    def plot_array_model_comparison(self,
                                    model_comps,
                                    title_prefix="",
                                    metrics=["loos", "ks"],
                                    labels=[],
                                    xdata=None,
                                    xlabel="",
                                    figsize=FiguresConfig.VERY_LARGE_SIZE,
                                    figure_name=None):
        """Plot array-valued model comparison metrics, one figure per metric.

        Every figure is a grid with one row per entry of the second axis of
        the metric arrays and one column per model. Each chain/run is drawn
        as a set of markers together with a horizontal line at its mean.

        :param model_comps: dict metric name -> array, or metric name -> dict
            of model name -> array
        :param title_prefix: optional prefix of the figure titles
        :param metrics: metric names to plot; those missing from
            ``model_comps`` are skipped
        :param labels: row labels; only validated against the number of rows
            (NOTE(review): they are not drawn anywhere in this method)
        :param xdata: optional array of x values (flattened); defaults to
            ``0..n-1`` for each metric
        :param xlabel: x-axis label of every subplot
        :param figsize: figure size
        :param figure_name: forwarded to ``self._save_figure``
        :return: ``(tuple of figures, tuple of axes arrays)``; ``None`` if all
            the values of some model are infinite for some metric
        """
        def arrange_chains_or_runs(metric_data):
            # Give every model array a leading chains/runs axis.
            return [
                model if model.ndim > 2 else numpy.expand_dims(model, axis=0)
                for model in metric_data
            ]

        colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
        n_colors = len(colorcycle)
        metrics = [
            metric for metric in metrics if metric in model_comps.keys()
        ]
        figs = []
        axs = []
        for metric in metrics:
            if isinstance(model_comps[metric], dict):
                # Multiple models as a list of np.arrays of chains x data
                # (explicit lists: dict views are not indexable on Python 3)
                metric_data = list(model_comps[metric].values())
                model_names = list(model_comps[metric].keys())
            else:
                # Single models as a one element list of one np.array of chains x data
                metric_data = [model_comps[metric]]
                model_names = [""]
            metric_data = arrange_chains_or_runs(metric_data)
            n_models = len(metric_data)
            for jj in range(n_models):
                # Necessary because ks gets infinite sometimes...
                temp = metric_data[jj] == numpy.inf
                if numpy.all(temp):
                    warning("All values are inf for metric " + metric +
                            " of model " + model_names[jj] + "!\n")
                    return
                elif numpy.any(temp):
                    warning(
                        "Inf values found for metric " + metric +
                        " of model " + model_names[jj] + "!\n" +
                        "Substituting them with the maximum non-infite value!")
                    metric_data[jj][temp] = metric_data[jj][~temp].max()
            n_subplots = metric_data[0].shape[1]
            n_labels = len(labels)
            if n_labels != n_subplots:
                if n_labels != 0:
                    warning("Ignoring labels because their number (" +
                            str(n_labels) +
                            ") is not equal to the number of row subplots (" +
                            str(n_subplots) + ")!")
                labels = [str(ii + 1) for ii in range(n_subplots)]
            # Kept per metric so that a default x vector built for one metric
            # is not reused for another metric of a different length.
            if xdata is None:
                metric_xdata = numpy.arange(metric_data[-1].shape[-1])
            else:
                metric_xdata = xdata.flatten()
            # Extend the x range by 10% on both sides for the mean lines/text.
            x_span = metric_xdata[-1] - metric_xdata[0]
            xdata0 = numpy.concatenate([
                numpy.reshape(metric_xdata[0] - 0.1 * x_span, (1, )),
                metric_xdata
            ])
            xdata1 = metric_xdata[-1] + 0.1 * x_span
            if len(title_prefix) > 0:
                title = title_prefix + ": " + metric
            else:
                title = metric
            fig = pyplot.figure(title, figsize=figsize)
            fig.suptitle(title)
            fig.set_label(title)
            gs = gridspec.GridSpec(n_subplots, n_models)
            axes = numpy.empty((n_subplots, n_models), dtype="O")
            # Rows are created bottom-up so that the bottom row exists before
            # the rows above it share its x-axis.
            for ii in range(n_subplots - 1, -1, -1):
                for jj in range(n_models):
                    shared = {}
                    if ii < n_subplots - 1:
                        shared["sharex"] = axes[n_subplots - 1, jj]
                    if jj > 0:
                        shared["sharey"] = axes[ii, 0]
                    axes[ii, jj] = pyplot.subplot(gs[ii, jj], **shared)
                    n_chains_or_runs = metric_data[jj].shape[0]
                    for kk in range(n_chains_or_runs):
                        c = colorcycle[kk % n_colors]
                        axes[ii, jj].plot(metric_xdata,
                                          metric_data[jj][kk][ii, :],
                                          label="chain/run" + str(kk + 1),
                                          marker="o",
                                          markersize=1,
                                          markeredgecolor=c,
                                          markerfacecolor=None,
                                          linestyle="None")
                        if n_chains_or_runs > 1:
                            axes[ii, jj].legend()
                        m = numpy.nanmean(metric_data[jj][kk][ii, :])
                        axes[ii, jj].plot(xdata0,
                                          m * numpy.ones(xdata0.shape),
                                          color=c,
                                          linewidth=1)
                        axes[ii, jj].text(xdata0[0],
                                          1.1 * m,
                                          'mean=%0.2f' % m,
                                          ha='center',
                                          va='bottom',
                                          color=c)
                    axes[ii, jj].set_xlabel(xlabel)
                    if ii == 0:
                        # Columns correspond to models, hence the jj index.
                        axes[ii, jj].set_title(model_names[jj])
                if ii == n_subplots - 1:
                    axes[ii, 0].autoscale()  # tight=True
                    axes[ii, 0].set_xlim([xdata0[0], xdata1])  # tight=True
            # fig.tight_layout()
            self._save_figure(fig, figure_name)
            self._check_show()
            figs.append(fig)
            axs.append(axes)
        return tuple(figs), tuple(axs)
Exemple #12
0
 def get_parameter(self, parameter_name):
     """Return ``self.parameters[parameter_name]`` if present.

     When the lookup yields ``None`` (missing key or stored ``None``), a
     warning is issued and ``None`` is returned.
     """
     found = self.parameters.get(parameter_name)
     if found is not None:
         return found
     warning("Ground truth value for parameter " + parameter_name +
             " was not found!")
     return None
Exemple #13
0
    def compute_information_criteria(self, samples, nparams=None, nsamples=None, ndata=None, parameters=[],
                                     skip_samples=0, merge_chains_or_runs_flag=False, log_like_str='log_likelihood'):

        """
        Compute information criteria (maxlike, dic, aicc, aic, bic, waic, psis-loo) from samples.

        :param samples: a dictionary of stan outputs or a list of dictionaries for multiple runs/chains
        :param nparams: number of model parameters, it can be inferred from parameters if None
        :param nsamples: number of samples, it can be inferred from loglikelihood if None
        :param ndata: number of data points, it can be inferred from loglikelihood if None
        :param parameters: a list of parameter names, necessary for dic metric computations and in case nparams is None,
                           as well as for aicc, aic and bic computation; if empty, the sample keys containing "_star"
                           are used
        :param skip_samples: number of leading samples to drop from every sampled quantity
        :param merge_chains_or_runs_flag: logical flag for merging separate chains/runs, default is False
        :param log_like_str: the name of the log likelihood output of stan, default 'log_likelihood'
        :return: a dict of metrics for a single chain/run, otherwise the output of
                 list_of_dicts_to_dicts_of_ndarrays for the list of per chain/run dicts
        """

        import sys
        # Make the external model comparison package importable, without growing sys.path on every call
        model_comparison_path = self.config.generic.MODEL_COMPARISON_PATH
        if model_comparison_path not in sys.path:
            sys.path.insert(0, model_comparison_path)
        from information_criteria.ComputeIC import maxlike, aicc, aic, bic, dic, waic
        from information_criteria.ComputePSIS import psisloo

        samples = ensure_list(samples)
        if merge_chains_or_runs_flag and len(samples) > 1:
            samples = ensure_list(merge_samples(samples, skip_samples, flatten=True))
            # The leading samples have already been skipped by the merge
            skip_samples = 0

        # NOTE: nsamples, ndata, nparams and parameters, once inferred from the first chain/run,
        # are reused (and only checked) for the following ones.
        results = []
        for sample in samples:
            log_likelihood = -1 * sample[log_like_str][skip_samples:]
            log_lik_shape = log_likelihood.shape
            # Shape of the data dimensions, used to reshape the pointwise psis-loo outputs back
            if len(log_lik_shape) > 1:
                target_shape = log_lik_shape[1:]
            else:
                target_shape = (1,)
            if nsamples is None:
                nsamples = log_lik_shape[0]
            elif nsamples != log_likelihood.shape[0]:
                warning("nsamples (" + str(nsamples) +
                        ") is not equal to likelihood.shape[0] (" + str(log_lik_shape[0]) + ")!")

            # Flatten all data dimensions: the result is always 2D (samples x data points)
            log_likelihood = np.reshape(log_likelihood, (log_lik_shape[0], -1))
            ndata_real = np.maximum(log_likelihood.shape[1], 1)
            if ndata is None:
                ndata = ndata_real
            elif ndata != ndata_real:
                warning("ndata (" + str(ndata) + ") is not equal to likelihood.shape[1] (" + str(ndata_real) + ")!")

            result = maxlike(log_likelihood)

            if len(parameters) == 0:
                parameters = [param for param in sample.keys() if param.find("_star") >= 0]
            if len(parameters) > 0:
                nparams_real = 0
                zscore_params = []
                for p in parameters:
                    pval = sample[p][skip_samples:]
                    pzscore = np.array((pval - np.mean(pval, axis=0)) / np.std(pval, axis=0))
                    if len(pzscore.shape) > 2:
                        pzscore = np.reshape(pzscore, (pzscore.shape[0], -1))
                    zscore_params.append(pzscore)
                    if len(pzscore.shape) > 1:
                        nparams_real += np.maximum(pzscore.shape[1], 1)
                    else:
                        nparams_real += 1
                if nparams is None:
                    nparams = nparams_real
                elif nparams != nparams_real:
                    warning("nparams (" + str(nparams) +
                            ") is not equal to number of parameters included in the dic computation (" +
                            str(nparams_real) + ")!")
                # TODO: find out how to reduce dic to 1 value, from 1 value per parameter. mean(.) for the moment:
                result['dic'] = np.mean(dic(log_likelihood, zscore_params))
            else:
                warning("Parameters' names' list is empty and we found no _star parameters! No computation of dic!")

            if nparams is not None:
                result['aicc'] = aicc(log_likelihood, nparams, ndata)
                result['aic'] = aic(log_likelihood, nparams)
                result['bic'] = bic(log_likelihood, nparams, ndata)
            else:
                warning("Unknown number of parameters! No computation of aic, aaic, bic!")

            result.update(waic(log_likelihood))

            if nsamples > 1:
                result.update(psisloo(log_likelihood))
                result["loos"] = np.reshape(result["loos"], target_shape)
                result["ks"] = np.reshape(result["ks"], target_shape)
            else:
                result.pop('p_waic', None)

            # Turn every metric into an array of at least one element
            for metric, value in list(result.items()):
                result[metric] = value * np.ones(1,)

            results.append(result)

        if len(results) == 1:
            return results[0]
        else:
            return list_of_dicts_to_dicts_of_ndarrays(results)
Exemple #14
0
    def plot_bars(self, data, ax=None, fig=None, title="", group_names=[], legend_prefix="",
                  figsize=FiguresConfig.VERY_LARGE_SIZE):
        """
        Draw a grouped bar plot, with the value of every bar written next to it.

        :param data: a 1D array (a single group), a 2D array (groups x elements), or a list/tuple of sequences
                     (one per group), which are padded with nan up to the length of the longest one
        :param ax: axes to draw on; used only if fig is also given (defaults to the current axes in that case)
        :param fig: an existing figure; if None, a new figure and axes are created, and the figure is saved/shown
        :param title: axes title
        :param group_names: tick labels of the groups; ignored with a warning if their number is wrong
        :param legend_prefix: prefix of the legend entries, which are numbered from 1
        :param figsize: size of the figure, if one is created
        :return: fig, ax
        """

        def barlabel(ax, rects, positions):
            """
            Attach a text label on each bar displaying its height
            """
            for rect, pos in zip(rects, positions):
                height = rect.get_height()
                if pos < 0:
                    y = -height
                    pos = 0.75 * pos
                else:
                    y = height
                    pos = 0.25 * pos
                ax.text(rect.get_x() + rect.get_width() / 2., pos, '%0.2f' % y,
                        color="k", ha='center', va='bottom', rotation=90)

        if fig is None:
            fig, ax = pyplot.subplots(1, 1, figsize=figsize)
            show_and_save = True
        else:
            show_and_save = False
            if ax is None:
                ax = pyplot.gca()
        if isinstance(data, (list, tuple)):  # If, there are many groups, data is a list:
            # Fill in with nan in case that not all groups have the same number of elements
            try:
                from itertools import izip_longest as zip_longest  # Python 2
            except ImportError:
                from itertools import zip_longest  # Python 3
            data = numpy.array(list(zip_longest(*ensure_list(data), fillvalue=numpy.nan))).T
        elif data.ndim == 1:  # This is the case where there is only one group...
            data = numpy.expand_dims(data, axis=1).T
        n_groups, n_elements = data.shape
        # nan-aware extrema: the nan padding above must not turn the label positions into nan
        posmax = numpy.nanmax(data)
        negmax = numpy.nanmin(data)
        n_groups_names = len(group_names)
        if n_groups_names != n_groups:
            if n_groups_names != 0:
                warning("Ignoring group_names because their number (" + str(n_groups_names) +
                        ") is not equal to the number of groups (" + str(n_groups) + ")!")
            group_names = n_groups * [""]
        colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
        n_colors = len(colorcycle)
        x_inds = numpy.arange(n_groups)
        width = 0.9 / n_elements
        elements = []
        for iE in range(n_elements):
            # One bar per group for the iE-th element, shifted by its index within the group
            elements.append(ax.bar(x_inds + iE*width, data[:, iE], width, color=colorcycle[iE % n_colors]))
            positions = [negmax if d < 0 else posmax for d in data[:, iE]]
            barlabel(ax, elements[-1], positions)
        if n_elements > 1:
            legend = [legend_prefix+str(ii) for ii in range(1, n_elements+1)]
            ax.legend(tuple([element[0] for element in elements]), tuple(legend))
        ax.set_xticks(x_inds + n_elements*width/2)
        ax.set_xticklabels(tuple(group_names))
        ax.set_title(title)
        ax.autoscale()  # tight=True
        ax.set_xlim([-1.05*width, n_groups*1.05])
        if show_and_save:
            fig.tight_layout()
            self._save_figure(fig)
            self._check_show()
        return fig, ax
Exemple #15
0
    def run_lsa(self, disease_hypothesis, model_configuration):
        """
        Run LSA: eigen-decompose the Jacobian and derive a propagation strength per region.

        The eigenvalues/eigenvectors (real parts) of the Jacobian computed by self._compute_jacobian are sorted
        (in reverse order for the "2D" method) and stored on self. The propagation strength is the absolute value
        of the (optionally weighted) sum of the leading eigenvectors, optionally normalized by its maximum.
        NOTE: when self.lsa_method is "auto", it is overwritten on self by the method that gets selected.

        :param disease_hypothesis: hypothesis providing number_of_regions, regions_disease_indices and name
        :param model_configuration: configuration providing x1eq, e_values and x0_values
        :return: the LSA hypothesis built by HypothesisBuilder.build_lsa_hypothesis()
        """

        if self.lsa_method == "auto":
            if numpy.any(model_configuration.x1eq > X1EQ_CR_DEF):
                self.lsa_method = "2D"
            else:
                self.lsa_method = "1D"

        if self.lsa_method == "2D" and numpy.all(model_configuration.x1eq <= X1EQ_CR_DEF):
            warning("LSA with the '2D' method (on the 2D Epileptor model) will not produce interpretable results when"
                    " the equilibrium point of the system is not supercritical (unstable)!")

        jacobian = self._compute_jacobian(model_configuration)

        # Perform eigenvalue decomposition
        eigen_values, eigen_vectors = numpy.linalg.eig(jacobian)
        eigen_values = numpy.real(eigen_values)
        eigen_vectors = numpy.real(eigen_vectors)
        sorted_indices = numpy.argsort(eigen_values, kind='mergesort')
        if self.lsa_method == "2D":
            sorted_indices = sorted_indices[::-1]
        self.eigen_vectors = eigen_vectors[:, sorted_indices]
        self.eigen_values = eigen_values[sorted_indices]

        self._ensure_eigen_vectors_number(self.eigen_values[:disease_hypothesis.number_of_regions],
                                          model_configuration.e_values, model_configuration.x0_values,
                                          disease_hypothesis.regions_disease_indices)

        if self.eigen_vectors_number == disease_hypothesis.number_of_regions:
            # Calculate the propagation strength index by summing all eigenvectors
            lsa_propagation_strength = numpy.abs(numpy.sum(self.eigen_vectors, axis=1))

        else:
            # Calculate the propagation strength index by summing the first n eigenvectors (minimum 1)
            n_vectors = max(self.eigen_vectors_number, 1)
            if self.weighted_eigenvector_sum:
                lsa_propagation_strength = \
                    numpy.abs(weighted_vector_sum(numpy.array(self.eigen_values[:n_vectors]),
                                                  numpy.array(self.eigen_vectors[:, :n_vectors]),
                                                  normalize=True))
            else:
                lsa_propagation_strength = \
                    numpy.abs(numpy.sum(self.eigen_vectors[:, :n_vectors], axis=1))

        if self.lsa_method == "2D":
            # The vector holds two entries per region: combine them as a Euclidean norm,
            # then move to a log scale shifted to start from 0.
            lsa_propagation_strength = numpy.sqrt(lsa_propagation_strength[:disease_hypothesis.number_of_regions]**2 +
                                                  lsa_propagation_strength[disease_hypothesis.number_of_regions:]**2)
            lsa_propagation_strength = numpy.log10(lsa_propagation_strength)
            lsa_propagation_strength -= lsa_propagation_strength.min()

        if self.normalize_propagation_strength:
            # Normalize by the maximum
            lsa_propagation_strength /= numpy.max(lsa_propagation_strength)

        # TODO: revisit whether the elbow should be bounded from below by self.eigen_vectors_number
        propagation_strength_elbow = self.get_curve_elbow_point(lsa_propagation_strength)
        # Indices of the regions with the highest propagation strength, up to the elbow point
        propagation_indices = lsa_propagation_strength.argsort()[-propagation_strength_elbow:]

        hypothesis_builder = HypothesisBuilder(disease_hypothesis.number_of_regions). \
                                set_attributes_based_on_hypothesis(disease_hypothesis). \
                                    set_name(disease_hypothesis.name + "_LSA"). \
                                        set_lsa_propagation(propagation_indices, lsa_propagation_strength)

        return hypothesis_builder.build_lsa_hypothesis()