def plot_fit_results(self, ests, samples, model_data, target_data, probabilistic_model=None, info_crit=None,
                         stats=None, pair_plot_params=["tau1", "sigma", "epsilon", "scale", "offset"],
                         region_violin_params=["x0", "PZ", "x1eq", "zeq"],
                         region_labels=[], regions_mode="active", seizure_indices=[],
                         trajectories_plot=True, connectivity_plot=False, skip_samples=0, title_prefix=""):
        """Draw every figure summarising a fit and return them as a tuple.

        The figures are produced, in order, by plot_fit_timeseries, plot_fit_region_params
        (merged, then per chain/run), plot_fit_scalar_params, plot_fit_scalar_params_iters and,
        optionally, the model comparison plots (when info_crit is given) and
        plot_fit_connectivity (when connectivity_plot is True).

        :param ests: estimates, forwarded to the time series and connectivity plots
        :param samples: fit samples; repacked by samples_to_timeseries before plotting
        :param model_data: dict-like; "active_regions" is read from it when no probabilistic_model is given
        :param target_data: target data; repacked by samples_to_timeseries
        :param probabilistic_model: optional; provides number_of_regions, parameters and active_regions
        :param info_crit: optional information criteria; enables the model comparison figures
        :param stats: optional statistics forwarded to the plotting methods
        :param regions_mode: "all", or anything else to work on active regions only
        :param seizure_indices: indices of regions to highlight
        :return: tuple with the return value of each plotting call
        """
        sigma = []
        if probabilistic_model is not None:
            n_regions = probabilistic_model.number_of_regions
            region_labels = generate_region_labels(n_regions, region_labels, ". ", True)
            if probabilistic_model.parameters.get("sigma", None) is not None:
                sigma = ["sigma"]
            active_regions = ensure_list(probabilistic_model.active_regions)
        else:
            active_regions = ensure_list(model_data.get("active_regions", []))

        if isequal_string(regions_mode, "all"):
            if len(seizure_indices) == 0:
                seizure_indices = active_regions
        else:
            if len(active_regions) > 0:
                # Re-express the seizure indices as positions inside the active regions list
                seizure_indices = [active_regions.index(ind) for ind in seizure_indices]
                if len(region_labels) > 0:
                    # NOTE(review): fancy indexing assumes region_labels is an array here -- confirm
                    region_labels = region_labels[active_regions]

        if len(region_labels) == 0:
            self.print_regions_indices = True
            self.print_ts_indices = True

        figs = []

        # Pack fit samples time series into timeseries objects:
        from tvb_fit.tvb_epilepsy.top.scripts.fitting_scripts import samples_to_timeseries
        samples, target_data, x1prior, x1eps = samples_to_timeseries(samples, model_data, target_data, region_labels)
        figs.append(self.plot_fit_timeseries(target_data, samples, ests, stats, probabilistic_model, "fit_target_data",
                                             ["x1", "z"], ["dWt", "dX1t", "dZt"], sigma, seizure_indices,
                                             skip_samples, trajectories_plot, region_labels, title_prefix))

        # Region parameters: first with chains/runs merged, then one figure per chain/run
        figs.append(
            self.plot_fit_region_params(samples, stats, probabilistic_model, region_violin_params, seizure_indices,
                                        region_labels, regions_mode, False, skip_samples, title_prefix))

        figs.append(
            self.plot_fit_region_params(samples, stats, probabilistic_model, region_violin_params, seizure_indices,
                                        region_labels, regions_mode, True, skip_samples, title_prefix))

        figs.append(self.plot_fit_scalar_params(samples, stats, probabilistic_model, pair_plot_params,
                                                skip_samples, title_prefix))

        # The per-iteration plot is always drawn without skipping samples (skip_samples=0)
        figs.append(self.plot_fit_scalar_params_iters(samples, pair_plot_params, 0, title_prefix, subplot_shape=None))

        if info_crit is not None:
            figs.append(self.plot_scalar_model_comparison(info_crit, title_prefix))
            figs.append(self.plot_array_model_comparison(info_crit, title_prefix, labels=target_data.space_labels,
                                                         xdata=target_data.time, xlabel="Time"))

        if connectivity_plot:
            # plot_fit_connectivity takes (ests, samples, stats, probabilistic_model, ...):
            # samples must be passed, otherwise every following argument lands in the wrong slot.
            figs.append(self.plot_fit_connectivity(ests, samples, stats, probabilistic_model, "MC",
                                                   region_labels, title_prefix))

        return tuple(figs)
Example #2
0
 def _sort_disease_indices_values(self, disease_dict):
     """Flatten a {indices: values} dict into index-sorted indices and values.

     Each key and each value is passed through ensure_list. A value of length 1 is
     repeated to match a key with several indices; any other length mismatch is
     reported through raise_value_error.

     :param disease_dict: mapping from index (or indices) to value (or values)
     :return: (indices as a list, values as a numpy array), both sorted by index;
              ([], []) when there are no indices
     """
     indices = []
     values = []
     for key, value in disease_dict.items():
         value = ensure_list(value)
         key = ensure_list(key)
         n = len(key)
         if n > 0:
             indices += key
             if len(value) == n:
                 values += value
             elif len(value) == 1 and n > 1:
                 values += value * n
             else:
                 raise_value_error("Length of disease indices " + str(n) +
                                   " and values " + str(len(value)) +
                                   " do not match!")
     if len(indices) > 0:
         if isinstance(indices[0], tuple):
             # Indices are (row, col) pairs. ravel_multi_index wants one index array
             # per dimension, so transpose the list of pairs before flattening.
             rows_cols = tuple(np.array(indices).T)
             arg_sort = np.ravel_multi_index(
                 rows_cols, (self.number_of_regions,
                             self.number_of_regions)).argsort()
         else:
             arg_sort = np.argsort(indices)
         return np.array(indices)[arg_sort].tolist(), np.array(
             values)[arg_sort]
     else:
         return [], []
 def plot_fit_connectivity(self, ests, samples, stats=None, probabilistic_model=None, model_conn_str="MC",
                           region_labels=[], title_prefix=""):
     """Plot the fitted model connectivity of every chain, next to its prior when available.

     One figure is created, saved and optionally shown per (est, sample) pair.

     :param ests: estimate dict(s); est[model_conn_str] is the connectivity that gets plotted
     :param samples: sample(s); only paired with ests to determine the number of chains
     :param stats: optional dict of statistics; the mean of each stats[key][model_conn_str]
                   is appended to the subplot title
     :param probabilistic_model: optional; supplies the prior through get_prior(model_conn_str)
     :param model_conn_str: key of the connectivity parameter
     :param region_labels: labels forwarded to plot_regions2regions
     :param title_prefix: optional prefix of the figure names
     :return: None when there is nothing to plot, the single plot_regions2regions result for
              one chain (as before), or a tuple with one result per chain otherwise
     """
     # plot connectivity
     if len(title_prefix) > 0:
         title_prefix = title_prefix + "_"
     if probabilistic_model is not None:
         MC_prior = probabilistic_model.get_prior(model_conn_str)
         MC_subplot = 122
     else:
         MC_prior = None
         MC_subplot = 111
     # Explicit test: the truth value of an array-valued prior is ambiguous
     plot_prior = MC_prior is not None and MC_prior is not False

     # The title does not depend on the chain, so build it once
     MC_title = "Posterior Model  Connectivity"
     if isinstance(stats, dict):
         MC_title = MC_title + ": "
         for skey, sval in stats.items():
             MC_title = MC_title + skey + "_mean=" + str(sval[model_conn_str].mean()) + ", "
         MC_title = MC_title[:-2]

     figs = []
     for id_est, (est, _) in enumerate(zip(ensure_list(ests), ensure_list(samples))):
         conn_figure_name = title_prefix + "chain" + str(id_est + 1) + ": Model Connectivity"
         pyplot.figure(conn_figure_name, FiguresConfig.VERY_LARGE_SIZE)
         if plot_prior:
             self.plot_regions2regions(MC_prior, region_labels, 121,
                                       "Prior Model Connectivity")
         figs.append(self.plot_regions2regions(est[model_conn_str], region_labels, MC_subplot, MC_title))
         self._save_figure(pyplot.gcf(), conn_figure_name)
         self._check_show()
     # Previously the function returned from inside the loop, so only chain 1 was ever plotted
     if len(figs) == 0:
         return None
     if len(figs) == 1:
         return figs[0]
     return tuple(figs)
 def plot_fit_region_params(self, samples, stats=None, probabilistic_model=None,
                            params=["x0", "PZ", "x1eq", "zeq"], special_idx=[], region_labels=[],
                            regions_mode="all", per_chain_or_run=False, skip_samples=0, title_prefix=""):
     """Violin plots of region-wise parameters and, when applicable, a K-vs-x0 pair plot.

     The pair plot is only drawn when chains are merged (per_chain_or_run is False),
     "x0" is among params and the first sample's x0 has fewer than 10 columns.

     :return: the violin plot result alone, or the tuple (violin, pair plot)
     """
     prefix = title_prefix + " " if len(title_prefix) > 0 else title_prefix
     title_pair_plot = prefix + "Global coupling vs x0 pair plot"
     title_violin_plot = prefix + "Regions parameters samples"
     # We assume in this function that regions_inds run for all regions for the statistical model,
     # and either among all or only among active regions for samples, ests and stats, depending on regions_mode
     samples = ensure_list(samples)
     priors = {}
     truth = {}
     if probabilistic_model is not None:
         if regions_mode == "active":
             regions_inds = ensure_list(probabilistic_model.active_regions)
         else:
             regions_inds = range(probabilistic_model.number_of_regions)
         # Vector of ones used to broadcast each prior/truth value over all regions
         ones_vec = numpy.ones((probabilistic_model.number_of_regions, ))
         for param in params:
             pdf = ensure_list(probabilistic_model.get_prior_pdf(param))
             for pos, component in enumerate(pdf):
                 pdf[pos] = ((numpy.array(component) * ones_vec).T[regions_inds])
             priors[param] = (pdf[0].squeeze(), pdf[1].squeeze())
             truth[param] = ((probabilistic_model.get_truth(param) * ones_vec)[regions_inds]).squeeze()

     # plot region-wise parameters
     f1 = self._region_parameters_violin_plots(samples, truth, priors, stats, params, skip_samples,
                                               per_chain_or_run=per_chain_or_run, labels=region_labels,
                                               special_idx=special_idx, figure_name=title_violin_plot)

     # Nothing more to draw unless the pair plot conditions hold
     if per_chain_or_run or "x0" not in params or samples[0]["x0"].shape[1] >= 10:
         return f1

     if samples[0].get("K", None) is None:
         pair_params = []
         pair_samples = [{} for _ in range(len(samples))]
     else:
         # plot K-x0 parameters in pair plots
         pair_params = ["K"]
         pair_samples = [{"K": chain["K"]} for chain in samples]
         if probabilistic_model is not None:
             k_pdf = probabilistic_model.get_prior_pdf("K")
             # TODO: a better hack than the following for when pdf returns p_mean, nan and p_mean is not a scalar
             if k_pdf[1] is numpy.nan:
                 k_pdf = list(k_pdf)
                 k_pdf[0] = numpy.mean(k_pdf[0])
                 k_pdf = tuple(k_pdf)
             priors["K"] = k_pdf
             truth["K"] = probabilistic_model.get_truth("K")

     for inode, label in enumerate(region_labels):
         x0_name = "x0[" + label + "]"
         pair_params.append(x0_name)
         for ichain, chain in enumerate(samples):
             pair_samples[ichain][x0_name] = chain["x0"][:, inode]
             if probabilistic_model is not None:
                 priors[x0_name] = (priors["x0"][0][inode], priors["x0"][1][inode])
                 truth[x0_name] = truth["x0"][inode]

     f2 = self._parameters_pair_plots(pair_samples, pair_params, None, priors, truth,
                                      skip_samples, title=title_pair_plot)
     return f1, f2
Example #5
0
def convert_params_names_to_ins(dicts_list,
                                parameter_names=INS_PARAMS_NAMES_DICT):
    """Duplicate parameters under their alternative names, in place.

    For every dict found in dicts_list (each element may itself be a dict or a
    list of dicts) and every (name, new_name) pair of parameter_names, the value
    stored under name is also stored under new_name. The original keys are kept.
    A parameter that cannot be copied only triggers a warning.

    :param dicts_list: dict, list of dicts, or list of lists of dicts; modified in place
    :param parameter_names: mapping from existing key to the key to add
    :return: tuple of the (modified) top-level elements of dicts_list
    """
    output = []
    for lst in ensure_list(dicts_list):
        for dct in ensure_list(lst):
            for p, p_ins in parameter_names.items():
                try:
                    dct[p_ins] = dct[p]
                except Exception:
                    # Best effort: report and carry on. Unlike a bare except, this lets
                    # KeyboardInterrupt and SystemExit propagate.
                    warning("Parameter " + p + " not found in \n" +
                            str(dicts_list))
        output.append(lst)
    return tuple(output)
 def __init__(self, input="EpileptorDP", connectivity=None, K_unscaled=np.array([K_UNSCALED_DEF]),
              x0_values=X0_DEF, e_values=E_DEF, x1eq_mode="optimize", **kwargs):
     """Initialise from a TVB simulator, a TVB model, or a model name.

     :param input: a Simulator (model, connectivity and initial conditions are taken from it),
                   a Model (its parameters are copied), or a string with the model name
     :param connectivity: a Connectivity, a TVBConnectivity or a numpy array of weights;
                          when given it overrides the connectivity taken from a Simulator
     :param K_unscaled: global coupling; either a single value or one value per region
     :param x0_values: value(s) broadcast to one entry per region
     :param e_values: value(s) broadcast to one entry per region
     :param x1eq_mode: stored as is in self.x1eq_mode
     :param kwargs: optional overrides for any of the EPILEPTOR_PARAMS
     """
     if isinstance(input, Simulator):
         # TODO: make this more specific once we clarify the model configuration representation compared to simTVB
         self.model_name = input.model._ui_name
         self.set_params_from_tvb_model(input.model)
         self.connectivity = normalize_weights(input.connectivity.weights)
         # self.coupling = input.coupling
         self.initial_conditions = np.squeeze(input.initial_conditions)  # initial conditions in a reduced form
         # self.noise = input.integrator.noise
         # self.monitor = ensure_list(input.monitors)[0]
     else:
         if isinstance(input, Model):
             self.model_name = input._ui_name
             self.set_params_from_tvb_model(input)
         elif isinstance(input, basestring):
             self.model_name = input
         else:
             # The message has a %s placeholder that was previously never filled in
             raise_value_error("Input (%s) is not a TVB simulator, an epileptor model, "
                               "\nor a string of an epileptor model!" % str(input))
     if isinstance(connectivity, Connectivity):
         self.connectivity = connectivity.normalized_weights
     elif isinstance(connectivity, TVBConnectivity):
         self.connectivity = normalize_weights(connectivity.weights)
     elif isinstance(connectivity, np.ndarray):
         self.connectivity = normalize_weights(connectivity)
     else:
         # A Simulator input has already provided the connectivity, so only warn otherwise
         if not(isinstance(input, Simulator)):
             warning("Input connectivity (%s) is not a virtual patient connectivity, a TVB connectivity, "
                     "\nor a numpy.array!" % str(connectivity))
     self.x0_values = x0_values * np.ones((self.number_of_regions,), dtype=np.float32)
     self.x1eq_mode = x1eq_mode
     if len(ensure_list(K_unscaled)) == 1:
         K_unscaled = np.array(K_unscaled) * np.ones((self.number_of_regions,), dtype=np.float32)
     elif len(ensure_list(K_unscaled)) == self.number_of_regions:
         K_unscaled = np.array(K_unscaled)
     else:
         self.logger.warning(
             "The length of input global coupling K_unscaled is neither 1 nor equal to the number of regions!" +
             "\nSetting model_configuration_builder.K_unscaled = K_UNSCALED_DEF for all regions")
         # Actually apply the fallback announced in the warning above
         K_unscaled = K_UNSCALED_DEF * np.ones((self.number_of_regions,), dtype=np.float32)
     self.set_K_unscaled(K_unscaled)
     # Keyword arguments override the current value of each epileptor parameter
     for pname in EPILEPTOR_PARAMS:
         self.set_parameter(pname, kwargs.get(pname, getattr(self, pname)))
     self.e_values = e_values * np.ones((self.number_of_regions,), dtype=np.float32)
     self.x0cr = 0.0
     self.rx0 = 0.0
     self._compute_critical_x0_scaling()
Example #7
0
 def read_output_samples(self, output_filepath, **kwargs):
     """Parse the csv output file(s) matching output_filepath.

     Every ".csv" in the path is replaced by "*" and the resulting pattern is handed
     to parse_csv. The optional keyword "merge_outputs" (default False) is forwarded
     as parse_csv's merge argument.

     :return: the single parsed result when there is exactly one, otherwise the list
     """
     file_pattern = output_filepath.replace(".csv", "*")
     merge_flag = kwargs.pop("merge_outputs", False)
     parsed = ensure_list(parse_csv(file_pattern, merge=merge_flag))
     return parsed[0] if len(parsed) == 1 else parsed
 def plot_fit_scalar_params_iters(self, samples_all, params=[], skip_samples=0, title_prefix="",
                                  subplot_shape=None, figure_name=None, figsize=FiguresConfig.LARGE_SIZE):
     """Plot the samples of scalar parameters against iteration.

     Only the parameters that exist in the first sample are plotted. Several
     chains/runs are merged with merge_samples before plotting.

     :param samples_all: one sample dict or a list of them (one per chain/run)
     :param params: names of the parameters to plot
     :param skip_samples: forwarded to self.plots as skip
     :param subplot_shape: (rows, cols); by default two columns (one for a single parameter)
     :return: whatever self.plots returns
     """
     if len(title_prefix) > 0:
         title_prefix = title_prefix + ": "
     title = title_prefix + "Parameters samples per iteration"
     samples_all = ensure_list(samples_all)
     params = [param for param in params if param in samples_all[0].keys()]
     samples = []
     for sample in samples_all:
         samples.append(extract_dict_stringkeys(sample, params, modefun="equal"))
     if len(samples) > 1:
         samples = merge_samples(samples)
     else:
         samples = samples[0]
     if subplot_shape is None:
         n_params = len(samples)
         # subplot_shape = self.rect_subplot_shape(n_params, mode="col")
         if n_params > 1:
             subplot_shape = (int(numpy.ceil(1.0*n_params/2)), 2)
         else:
             subplot_shape = (1, 1)
     # dict views cannot be indexed on Python 3, so materialise the keys first
     first_key = list(samples.keys())[0]
     n_chains_or_runs = samples[first_key].shape[0]
     # The legend is attached to the first parameter's subplot only
     legend = {first_key: ["chain/run " + str(ii+1) for ii in range(n_chains_or_runs)]}
     return self.plots(samples, shape=subplot_shape, transpose=True, skip=skip_samples, xlabels={}, xscales={},
                       yscales={}, title=title, lgnd=legend, figure_name=figure_name, figsize=figsize)
Example #9
0
 def plot_HMC(self,
              samples,
              skip_samples=0,
              title='HMC NUTS trace',
              figure_name=None,
              figsize=FiguresConfig.LARGE_SIZE):
     """Plot the traces of the sampler diagnostics found in the samples.

     The entries are selected with extract_dict_stringkeys(sample, "__", modefun="find").
     Several chains/runs are merged with merge_samples before plotting.

     :param samples: one sample dict or a list of them (one per chain/run)
     :param skip_samples: forwarded to self.plots as skip
     :return: whatever self.plots returns
     """
     nuts = []
     for sample in ensure_list(samples):
         nuts.append(extract_dict_stringkeys(sample, "__", modefun="find"))
     if len(nuts) > 1:
         nuts = merge_samples(nuts)
     else:
         nuts = nuts[0]
     # dict views cannot be indexed on Python 3, so materialise the keys first
     first_key = list(nuts.keys())[0]
     n_chains_or_runs = nuts[first_key].shape[0]
     # The legend is attached to the first diagnostic's subplot only
     legend = {
         first_key:
         ["chain/run " + str(ii + 1) for ii in range(n_chains_or_runs)]
     }
     return self.plots(nuts,
                       shape=(4, 2),
                       transpose=True,
                       skip=skip_samples,
                       xlabels={},
                       xscales={},
                       yscales={"stepsize__": "log"},
                       lgnd=legend,
                       title=title,
                       figure_name=figure_name,
                       figsize=figsize)
Example #10
0
    def read_sensors(self, filename, root_folder, s_type, atlas=""):
        """Read sensors (and optionally their gain matrix) from file(s) under root_folder.

        :param filename: locations file name, or a list [locations file, gain matrix file]
        :param root_folder: folder that contains the locations file
        :param s_type: sensors type; its value is used to build the sensors name
        :param atlas: sub-folder of root_folder holding the gain matrix file
        :return: a Sensors instance, or None (after a warning) when the locations file is missing
        """
        filename = ensure_list(filename)

        # Name = type prefix + file names stripped of ".txt" and of the type string itself
        locations_file = filename[0]
        gain_file = filename[1] if len(filename) > 1 else ""
        stripped = (locations_file + gain_file).replace(".txt", "").replace(s_type.value, "")
        name = s_type.value + stripped

        path = os.path.join(root_folder, filename[0])
        if not os.path.isfile(path):
            self.logger.warning("\nNo Sensor file found at path " + path + "!")
            return None

        if s_type == Sensors.TYPE_EEG:
            tvb_sensors = sensors.SensorsEEG.from_file(path)
        elif s_type == Sensors.TYPE_MEG:
            tvb_sensors = sensors.SensorsMEG.from_file(path)
        else:
            tvb_sensors = sensors.SensorsInternal.from_file(path)

        if len(filename) > 1:
            gain_path = os.path.join(root_folder, atlas, filename[1])
            gain_matrix = self.read_gain_matrix(gain_path, s_type, atlas)
        else:
            gain_matrix = np.array([])

        return Sensors(tvb_sensors.labels,
                       tvb_sensors.locations,
                       orientations=tvb_sensors.orientations,
                       gain_matrix=gain_matrix,
                       s_type=s_type,
                       name=name)
 def plot_fit_region_params(self, samples, stats=None, probabilistic_model=None,
                            params=[], special_idx=[], region_labels=[],
                            regions_mode="all", per_chain_or_run=False, skip_samples=0, title_prefix=""):
     """Violin plots of region-wise parameters plus one pair plot per parameter.

     A pair plot is drawn for a parameter only when chains are merged
     (per_chain_or_run is False) and the first sample of that parameter has
     fewer than 10 columns.

     :param samples: one sample dict or a list of them (one per chain/run)
     :param stats: optional statistics forwarded to the violin plots
     :param probabilistic_model: optional; supplies priors (get_prior_pdf) and truth (get_truth)
     :param params: names of the region-wise parameters to plot
     :param special_idx: region indices forwarded to the violin plots
     :param region_labels: one label per plotted region
     :param regions_mode: "active" to restrict priors/truth to the active regions
     :return: (violin plot result, list of pair plot results)
     """
     if len(title_prefix) > 0:
         title_prefix = title_prefix + " "
     title_pair_plot = title_prefix + "Regions parameters pair plot"
     title_violin_plot = title_prefix + "Regions parameters samples"
     # We assume in this function that regions_inds run for all regions for the statistical model,
     # and either among all or only among active regions for samples, ests and stats, depending on regions_mode
     samples = ensure_list(samples)
     priors = {}
     truth = {}
     if probabilistic_model is not None:
         if regions_mode == "active":
             regions_inds = ensure_list(probabilistic_model.active_regions)
         else:
             regions_inds = range(probabilistic_model.number_of_regions)
         # Vector of ones used to broadcast each prior/truth value over all regions
         I = numpy.ones((probabilistic_model.number_of_regions, )) #, 1
         for param in params:
             pdf = ensure_list(probabilistic_model.get_prior_pdf(param))
             for ip, p in enumerate(pdf):
                 pdf[ip] = ((numpy.array(p) * I).T[regions_inds])
             priors.update({param: (pdf[0].squeeze(), pdf[1].squeeze())})
             truth.update({param: ((probabilistic_model.get_truth(param) * I)[regions_inds]).squeeze()}) #[:, 0]
     # plot region-wise parameters
     f1 = self._region_parameters_violin_plots(samples, truth, priors, stats, params, skip_samples,
                                               per_chain_or_run=per_chain_or_run, labels=region_labels,
                                               special_idx=special_idx, figure_name=title_violin_plot)
     f2 = []
     for p in params:
         if not(per_chain_or_run) and samples[0][p].shape[1] < 10:
             p_title_pair_plot = p + " " + title_pair_plot
             p_pair_plot_params = []
             p_pair_plot_samples = [{} for _ in range(len(samples))]
             for inode, label in enumerate(region_labels):
                 # Key the entry by the current parameter: a fixed "x0[...]" name would make
                 # different parameters overwrite each other's priors and truth
                 temp_name = p + "[" + label + "]"
                 p_pair_plot_params.append(temp_name)
                 for ichain, s in enumerate(samples):
                     # Fill the per-chain dict (the loop index itself was indexed before)
                     p_pair_plot_samples[ichain].update({temp_name: s[p][:, inode]})
                 if probabilistic_model is not None:
                     priors.update({temp_name: (priors[p][0][inode], priors[p][1][inode])})
                     truth.update({temp_name: truth[p][inode]})
             f2.append(self._parameters_pair_plots(p_pair_plot_samples, p_pair_plot_params, None, priors, truth,
                                                   skip_samples, title=p_title_pair_plot))
     return f1, f2
 def _region_parameters_violin_plots(self, samples_all, values=None, lines=None, stats=None,
                                     params=[], skip_samples=0, per_chain_or_run=False,
                                     labels=[], special_idx=None, figure_name="Regions parameters samples",
                                     figsize=FiguresConfig.VERY_LARGE_SIZE):
     """Draw column-wise violin plots of region-wise parameters, one subplot per parameter.

     Only the parameters that exist in the first sample are plotted. With several
     chains and per_chain_or_run False, the chains are merged into a single figure;
     otherwise one figure per chain is produced.

     :param samples_all: one sample dict or a list of them (one per chain/run)
     :param values: optional dict param -> values, forwarded to plot_vector_violin
     :param lines: optional dict param -> lines, forwarded to plot_vector_violin
     :param stats: optional statistics used to build the labels
     :param labels: region labels; generated from the data when empty
     :param special_idx: indices forwarded as indices_red
     :return: tuple with the last plot_vector_violin result of each figure
     """
     if isinstance(values, dict):
         vals_fun = lambda param: values.get(param, numpy.array([]))
     else:
         vals_fun = lambda param: []
     if isinstance(lines, dict):
         lines_fun = lambda param: lines.get(param, numpy.array([]))
     else:
         lines_fun = lambda param: []
     samples_all = ensure_list(samples_all)
     params = [param for param in params if param in samples_all[0].keys()]
     samples = []
     for sample in samples_all:
         samples.append(extract_dict_stringkeys(sample, params, modefun="equal"))
     # dict views cannot be indexed on Python 3, so materialise the values first
     first_param_samples = list(samples[0].values())[0]
     if len(labels) == 0:
         labels = generate_region_labels(first_param_samples.shape[-1], labels, numbering=False)
     n_chains = len(samples_all)
     n_samples = first_param_samples.shape[0]
     # A violin needs more than one sample
     violin_flag = n_samples > 1
     if not per_chain_or_run and n_chains > 1:
         samples = [merge_samples(samples)]
         plot_samples = lambda s: numpy.concatenate(numpy.split(s[:, skip_samples:].T, n_chains, axis=2),
                                                    axis=1).squeeze().T
         plot_figure_name = lambda ichain: figure_name
     else:
         plot_samples = lambda s: s[skip_samples:]
         plot_figure_name = lambda ichain: figure_name + ": chain " + str(ichain + 1)
     # Only the first subplot gets the region labels
     params_labels = {}
     for ip, p in enumerate(params):
         if ip == 0:
             params_labels[p] = self._params_stats_labels(p, stats, labels)
         else:
             params_labels[p] = self._params_stats_labels(p, stats, "")
     n_params = len(params)
     if n_params > 9:
         warning("Number of subplots in column wise vector-violin-plots cannot be > 9 and it is "
                           + str(n_params) + "!")
     # Three-digit subplot code: 1 row, n_params columns, position added below
     subplot_ind = 100 + n_params * 10
     figs = []
     for ichain, chain_sample in enumerate(samples):
         pyplot.figure(plot_figure_name(ichain), figsize=figsize)
         for ip, param in enumerate(params):
             fig = self.plot_vector_violin(plot_samples(chain_sample[param]), vals_fun(param),
                                           lines_fun(param), params_labels[param],
                                           subplot_ind + ip + 1, param, violin_flag=violin_flag,
                                           colormap="YlOrRd", show_y_labels=True,
                                           indices_red=special_idx, sharey=None)
         self._save_figure(pyplot.gcf(), None)
         self._check_show()
         figs.append(fig)
     return tuple(figs)
Example #13
0
    def read_simulator_model_group(self, h5_file, model, group):
        """Copy every dataset and attribute of an h5 group onto model.

        Datasets named "variables_of_interest" or "state_variables" are passed
        through ensure_list before being set; all others are set as read.

        :param h5_file: mapping-like file object indexed by group name
        :param model: object that receives the values through setattr
        :param group: name of the group to read
        :return: the same model object
        """
        list_valued = ["variables_of_interest", "state_variables"]
        h5_group = h5_file[group]

        for dataset_name in h5_group.keys():
            content = h5_group[dataset_name][()]
            if dataset_name in list_valued:
                content = ensure_list(content)
            setattr(model, dataset_name, content)

        group_attrs = h5_group.attrs
        for attr_name in group_attrs.keys():
            setattr(model, attr_name, group_attrs[attr_name])

        return model
Example #14
0
 def check_number_of_inputs(nmodels, input, input_str):
     """Return input as a list with one entry per model.

     A single input is repeated nmodels times (the same object is referenced by
     every entry). Any other length different from nmodels is reported through
     raise_value_error.

     :param nmodels: expected number of entries
     :param input: a single item or a list of items
     :param input_str: name of the input, used in the error message
     :return: list of inputs
     """
     input = ensure_list(input)
     ninput = len(input)
     if ninput != nmodels:
         if ninput == 1:
             # Build a new list rather than extending in place with `*=`, which would
             # also grow the caller's list if ensure_list returned that very object
             input = input * nmodels
         else:
             raise_value_error(
                 "The size of input " + input_str + " (" + str(ninput) +
                 ") is neither equal to the number of models (" +
                 str(nmodels) + ") nor equal to 1!")
     return input
 def concatenate_in_time(self, timeseries_list):
     """Concatenate timeseries along the first (time) axis of their data.

     NOTE: the first timeseries of the list is extended in place and returned.

     :param timeseries_list: a timeseries or a list of them, all with the same time_step
     :return: the first timeseries, holding the concatenated data
     :raises: through raise_value_error when a timeseries has a different time_step
     """
     timeseries_list = ensure_list(timeseries_list)
     out_timeseries = timeseries_list[0]
     # start=1 so that the reported position refers to the full input list
     for position, timeseries in enumerate(timeseries_list[1:], 1):
         if out_timeseries.time_step == timeseries.time_step:
             out_timeseries.data = np.concatenate(
                 [out_timeseries.data, timeseries.data], axis=0)
         else:
             # Report the offending timeseries' own time step (the list has none)
             raise_value_error("Timeseries concatenation in time failed!\n"
                               "Timeseries %d have a different time step (%s) than the ones before(%s)!" \
                               % (position, str(timeseries.time_step), str(out_timeseries.time_step)))
     return out_timeseries
 def set_normalize(self, values):
     """Store the normalization values and return self.

     Nothing is stored for an empty input. A single value is preceded by the
     minimum of self.diseased_regions_values. More than two values are reported
     through raise_value_error.
     """
     values = ensure_list(values)
     count = len(values)
     if count > 2:
         raise_value_error(
             "Invalid disease hypothesis normalization values!: " +
             str(values) + "\nThey cannot be more than 2!")
     elif count > 0:
         if count == 1:
             # Assuming normalization only to a maximum value, keeping the existing minimum one
             values = [numpy.min(self.diseased_regions_values)] + values
         self.normalize_values = values
     return self
 def plot_head(self, head):
     """Plot the connectivity, its statistics and every sensors set of a head.

     :return: tuple of the two connectivity plot results followed by one
              (figure, ax, cax) triple per sensors set
     """
     output = [self._plot_connectivity(head.connectivity),
               self._plot_connectivity_stats(head.connectivity)]
     count = 1
     for s_type in SensorTypes:
         sensors = getattr(head, "sensors" + s_type.value)
         if not isinstance(sensors, (list, Sensors)):
             continue
         for sensor_set in ensure_list(sensors):
             # The running count is threaded through successive _plot_sensors calls
             count, figure, ax, cax = self._plot_sensors(
                 sensor_set, head.connectivity.region_labels, count)
             output.append((figure, ax, cax))
     return tuple(output)
Example #18
0
 def get_sensors_by_index(self, s_type=SensorTypes.TYPE_SEEG, sensor_ids=0):
     """Select sensors of the given type by their position.

     :param s_type: sensors type passed to self.get_sensors
     :param sensor_ids: position, or positions, of the sensors to return
     :return: None when there are no sensors or no match, the single matching
              sensors object, or a list of the matching ones
     """
     sensors = self.get_sensors(s_type)
     if sensors is None:
         return None
     candidates = ensure_list(sensors.values())
     selected = [candidates[position]
                 for position, _ in enumerate(candidates)
                 if np.in1d(position, sensor_ids)]
     if len(selected) == 0:
         return None
     if len(selected) == 1:
         return selected[0]
     return selected
Example #19
0
def merge_samples(samples, skip_samples=0, flatten=False):
    """Merge a list of sample dicts into one dict of arrays.

    A single sample is returned untouched (skip_samples and flatten are then
    ignored). Otherwise the samples are stacked with
    list_of_dicts_to_dicts_of_ndarrays, one-dimensional entries are turned into
    columns and, when flatten is True, the first two axes are collapsed after
    dropping the first skip_samples entries along the second axis.
    """
    samples = ensure_list(samples)
    if len(samples) <= 1:
        return samples[0]

    merged = list_of_dicts_to_dicts_of_ndarrays(samples)
    for name in merged.keys():
        if len(merged[name].shape) == 1:
            # Multiplying by a (1, 1) array of ones and transposing gives a column
            merged[name] = (merged[name] * np.ones((1, 1))).T
        if not flatten:
            continue
        shape = merged[name].shape
        kept = merged[name][:, skip_samples:]
        if len(shape) > 2:
            merged[name] = np.reshape(kept, tuple((-1, ) + shape[2:]))
        else:
            merged[name] = kept.flatten()

    return merged
def reconfigure_model_with_fit_estimates(head,
                                         model_configuration,
                                         probabilistic_model,
                                         estimates,
                                         base_path,
                                         writer=None,
                                         plotter=None):
    """Build a model configuration from fit estimates.

    NOTE(review): the function returns from inside the loop, so only the first
    estimate is ever used and None is returned for an empty list. This is kept as
    is to preserve the return type; confirm whether all estimates should be handled.

    :param head: provides connectivity.number_of_regions and connectivity.region_labels
    :param model_configuration: configuration used for defaults (K, tau1, tau0, connectivity, x0)
    :param probabilistic_model: provides active_regions
    :param estimates: one estimate dict or a list of them; "x0" is required,
                      "K", "tau1", "tau0" and "MC" are optional
    :param base_path: base path of the files written by writer
    :param writer: optional; writes the fit hypothesis and model configuration
    :param plotter: optional; plots the state space of the fit model configuration
    :return: the model configuration built from the first estimate
    """
    # -------------------------- Reconfigure model after fitting:---------------------------------------------------
    for id_est, est in enumerate(ensure_list(estimates)):
        K = est.get("K", np.mean(model_configuration.K))
        tau1 = est.get("tau1", np.mean(model_configuration.tau1))
        tau0 = est.get("tau0", np.mean(model_configuration.tau0))
        fit_conn = est.get("MC", model_configuration.connectivity)
        fit_model_configuration_builder = \
            ModelConfigurationBuilder(model_configuration.model_name, fit_conn,
                                      K_unscaled=K * model_configuration.number_of_regions). \
                set_parameter("tau1", tau1).set_parameter("tau0", tau0)
        # Work on a copy so that the input model configuration's x0 is not overwritten
        x0 = np.array(model_configuration.x0)
        x0[probabilistic_model.active_regions] = est["x0"]
        x0_values_fit = fit_model_configuration_builder._compute_x0_values_from_x0_model(
            x0)
        hyp_fit = HypothesisBuilder().set_nr_of_regions(head.connectivity.number_of_regions). \
            set_name('fit' + str(id_est + 1)). \
            set_x0_hypothesis(list(probabilistic_model.active_regions),
                              x0_values_fit[probabilistic_model.active_regions]). \
            build_hypothesis()

        model_configuration_fit = \
            fit_model_configuration_builder.build_model_from_hypothesis(hyp_fit)

        if writer:
            writer.write_hypothesis(hyp_fit, path("fit_Hypothesis", base_path))
            writer.write_model_configuration(
                model_configuration_fit, path("fit_ModelConfig", base_path))

        # Plot nullclines and equilibria of model configuration
        if plotter:
            plotter.plot_state_space(
                model_configuration_fit,
                region_labels=head.connectivity.region_labels,
                special_idx=probabilistic_model.active_regions,
                figure_name="fit_Nullclines and equilibria")
        return model_configuration_fit
Example #21
0
 def update_active_regions(self,
                           probabilistic_model,
                           e_values=[],
                           x0_values=[],
                           lsa_propagation_strengths=[],
                           reset=False):
     """Apply every configured selection method to update the model's active regions.

     The methods listed in self.active_regions_selection_methods are applied in
     order; "E", "x0" and "LSA" (compared with isequal_string) are recognised and
     anything else is ignored.

     :param reset: when True the active regions are emptied first
     :return: the probabilistic model returned by the last applied method
     """
     if reset:
         probabilistic_model.update_active_regions([])

     # (method label, name of the updating method, values it works on), in test order
     updaters = (
         ("E", "update_active_regions_e_values", e_values),
         ("x0", "update_active_regions_x0_values", x0_values),
         ("LSA", "update_active_regions_lsa", lsa_propagation_strengths),
     )
     for method in ensure_list(self.active_regions_selection_methods):
         for label, updater_name, method_values in updaters:
             if isequal_string(method, label):
                 probabilistic_model = getattr(self, updater_name)(
                     probabilistic_model, method_values, reset=False)
                 break
     return probabilistic_model
def get_x1_estimates_from_samples(samples,
                                  model_data,
                                  region_labels=[],
                                  time_unit="ms"):
    """Compute the mean and standard deviation of x1 over all samples.

    The "x1" entries of all chains/runs are stacked along the last axis and
    reduced with ``np.nanmean`` / ``np.nanstd`` over that axis.

    :param samples: one samples dict or a list of them (one per chain/run),
                    each holding an "x1" entry that is either a numpy array
                    or an object exposing ``.squeezed``
    :param model_data: dict; its optional "time" entry gives the time vector
                       and its optional "active_regions" entry the indices
                       used to select the region labels
    :param region_labels: labels of the regions; completed by
                          ``generate_region_labels``
    :param time_unit: time unit passed to the returned Timeseries
    :return: tuple ``(x1_mean, x1_std)`` of Timeseries objects
    """
    # Normalise first: a single samples dict must behave like a one-element
    # list, otherwise samples[0] below would be a key lookup on the dict.
    samples = ensure_list(samples)

    time = model_data.get("time", False)
    if time is not False:
        time_start = time[0]
        time_step = np.diff(time).mean()
    else:
        time_start = 0
        time_step = 1

    if isinstance(samples[0]["x1"], np.ndarray):
        # Raw arrays are transposed so that they unpack as
        # (n_times, n_regions, n_samples) below.
        def get_x1(x1):
            return x1.T
    else:
        # presumably a Timeseries-like object — confirm against callers
        def get_x1(x1):
            return x1.squeezed

    (n_times, n_regions, n_samples) = get_x1(samples[0]["x1"]).shape
    active_regions = model_data.get("active_regions",
                                    np.array(range(n_regions)))
    region_labels = generate_region_labels(
        np.maximum(n_regions, len(region_labels)), region_labels, ". ", False)
    if len(region_labels) > len(active_regions):
        region_labels = region_labels[active_regions]

    # Stack all chains/runs along the samples axis with a single concatenate;
    # the empty seed keeps the result well defined and of float dtype.
    x1 = np.concatenate(
        [np.empty((n_times, n_regions, 0))] +
        [get_x1(sample["x1"]) for sample in samples],
        axis=2)

    x1_mean = Timeseries(np.nanmean(x1, axis=2).squeeze(), {
        TimeseriesDimensions.SPACE.value: region_labels,
        TimeseriesDimensions.VARIABLES.value: ["x1"]
    },
                         time_start=time_start,
                         time_step=time_step,
                         time_unit=time_unit)
    x1_std = Timeseries(np.nanstd(x1, axis=2).squeeze(), {
        TimeseriesDimensions.SPACE.value: region_labels,
        TimeseriesDimensions.VARIABLES.value: ["x1std"]
    },
                        time_start=time_start,
                        time_step=time_step,
                        time_unit=time_unit)
    return x1_mean, x1_std
    def _parameters_pair_plots(self, samples_all, params=[],
                               stats=None, priors={}, truth={}, skip_samples=0, title='Parameters samples',
                               figure_name=None, figsize=FiguresConfig.VERY_LARGE_SIZE):
        """Prepare parameter samples and delegate to ``self.pair_plots``.

        :param samples_all: one samples dict or a list of them (chains/runs)
        :param params: names of the parameters to plot; names missing from
                       the first samples dict are dropped
        :param stats: statistics forwarded to ``_params_stats_subtitles``
        :param priors: dict of per-parameter priors for the diagonal plots
        :param truth: dict of per-parameter true values for the diagonal plots
        :param skip_samples: forwarded to ``pair_plots``
        :param title: figure title
        :param figure_name: figure name forwarded to ``pair_plots``
        :param figsize: figure size forwarded to ``pair_plots``
        :return: whatever ``self.pair_plots`` returns
        """
        # NOTE(review): subtitles are computed from the *unfiltered* params
        # list, before missing parameters are dropped below — confirm that
        # pair_plots copes with a possible length mismatch.
        subtitles = list(self._params_stats_subtitles(params, stats))
        samples_all = ensure_list(samples_all)
        params = [param for param in params if param in samples_all[0].keys()]
        samples = [extract_dict_stringkeys(sample, params, modefun="equal")
                   for sample in samples_all]
        if len(samples) > 1:
            samples = merge_samples(samples)
        else:
            samples = samples[0]
            # dict.values() is a non-subscriptable view on Python 3, so it is
            # materialised before indexing (works on Python 2 as well).
            n_samples = list(samples.values())[0].shape[0]
            # A single chain/run is reshaped to a leading axis of length 1.
            for p_key, p_val in list(samples.items()):
                samples[p_key] = numpy.reshape(p_val, (1, n_samples))
        # For every parameter: [prior, truth], with () standing for "absent".
        diagonal_plots = {p_key: [priors.get(p_key, ()), truth.get(p_key, ())]
                          for p_key in params}

        return self.pair_plots(samples, params, diagonal_plots, True, skip_samples,
                               title, "chain/run ", subtitles, figure_name, figsize)
Example #24
0
 def compute_estimates_from_samples(self, samples):
     """Summarise every sampled parameter of every chain/run.

     For each parameter ``p`` the result holds ``p`` (mean), ``p_low`` and
     ``p_high`` (the two elements of the min/max pair) and ``p_std`` (square
     root of the variance), all squeezed. The statistics are elements 1 to 3
     of the value returned by ``describe`` (presumably
     ``scipy.stats.describe`` — confirm against the file's imports).
     Parameters that cannot be summarised are stored unchanged under ``p``.

     :param samples: one samples dict or a list of them (one per chain/run)
     :return: a single dict of estimates if there is only one chain/run,
              otherwise a list of such dicts; each dict went through
              ``sort_dict``
     """
     ests = []
     for chain_or_run_samples in ensure_list(samples):
         est = {}
         for pkey, pval in chain_or_run_samples.items():
             try:
                 minmax, mean, variance = describe(pval)[1:4]
                 summary = [
                     (pkey + "_low", minmax[0]),
                     (pkey, mean),
                     (pkey + "_std", np.sqrt(variance)),
                     (pkey + "_high", minmax[1]),
                 ]
                 summary = [(skey, np.squeeze(sval)) for skey, sval in summary]
             except Exception:
                 # Best effort: values that cannot be described (e.g.
                 # non-numeric entries) are passed through untouched.
                 # Unlike a bare "except:", this lets KeyboardInterrupt and
                 # SystemExit propagate.
                 est[pkey] = pval
             else:
                 # Stored only once everything succeeded, so a failure never
                 # leaves a partial set of "_low"/"_std" keys behind.
                 est.update(summary)
         ests.append(sort_dict(est))
     if len(ests) == 1:
         return ests[0]
     return ests
Example #25
0
    def read_head(
        self,
        root_folder,
        name='',
        atlas="default",
        connectivity_file="connectivity.zip",
        cortical_surface_file="surface_cort.zip",
        subcortical_surface_file="surface_subcort.zip",
        cortical_region_mapping_file="region_mapping_cort.txt",
        subcortical_region_mapping_file="region_mapping_subcort.txt",
        eeg_sensors_files=[("eeg_brainstorm_65.txt",
                            "gain_matrix_eeg_65_surface_16k.npy")],
        meg_sensors_files=[("meg_brainstorm_276.txt",
                            "gain_matrix_meg_276_surface_16k.npy")],
        seeg_sensors_files=[("seeg_xyz.txt", "seeg_dipole_gain.txt"),
                            ("seeg_xyz.txt", "seeg_distance_gain.txt"),
                            ("seeg_xyz.txt", "seeg_regions_distance_gain.txt"),
                            ("seeg_588.txt",
                             "gain_matrix_seeg_588_surface_16k.npy")],
        vm_file="aparc+aseg.nii.gz",
        t1_file="T1.nii.gz",
    ):
        """Read all the parts of a head from ``root_folder`` and build a Head.

        Connectivity, region mappings and volume mapping are read from the
        ``atlas`` sub-folder; surfaces and the T1 are read from
        ``root_folder`` directly.

        :param root_folder: folder that contains the head files
        :param name: name of the head; the atlas name is used if empty
        :param atlas: atlas sub-folder name
        :param connectivity_file: connectivity file name
        :param cortical_surface_file: cortical surface file name
        :param subcortical_surface_file: subcortical surface file name
        :param cortical_region_mapping_file: cortical region mapping file name
        :param subcortical_region_mapping_file: subcortical region mapping
                                                file name
        :param eeg_sensors_files: EEG sensor file entries, each passed as-is
                                  to ``self.read_sensors``
        :param meg_sensors_files: MEG sensor file entries
        :param seeg_sensors_files: SEEG sensor file entries
        :param vm_file: volume mapping file name
        :param t1_file: T1 file name
        :return: a Head instance
        """

        def read_sensors_of_type(sensors_files, sensors_type):
            # Read one group of sensors into a dict keyed by sensor name.
            sensors_dict = OrderedDict()
            for s_files in ensure_list(sensors_files):
                sensors = self.read_sensors(s_files, root_folder,
                                            sensors_type, atlas)
                sensors_dict[sensors.name] = sensors
            return sensors_dict

        conn = self.read_connectivity(
            os.path.join(root_folder, atlas, connectivity_file))
        cort_srf = self.read_cortical_surface(
            os.path.join(root_folder, cortical_surface_file))
        # NOTE(review): the subcortical surface is also read with
        # read_cortical_surface — confirm that this is intended.
        subcort_srf = self.read_cortical_surface(
            os.path.join(root_folder, subcortical_surface_file))
        cort_rm = self.read_region_mapping(
            os.path.join(root_folder, atlas, cortical_region_mapping_file))
        subcort_rm = self.read_region_mapping(
            os.path.join(root_folder, atlas, subcortical_region_mapping_file))
        vm = self.read_volume_mapping(os.path.join(root_folder, atlas,
                                                   vm_file))
        t1 = self.read_t1(os.path.join(root_folder, t1_file))

        # Each sensor type goes into its own dict. Previously EEG and MEG
        # sensors were stored in the SEEG dict, which left sensorsEEG and
        # sensorsMEG always empty.
        sensorsSEEG = read_sensors_of_type(seeg_sensors_files,
                                           Sensors.TYPE_SEEG)
        sensorsEEG = read_sensors_of_type(eeg_sensors_files, Sensors.TYPE_EEG)
        sensorsMEG = read_sensors_of_type(meg_sensors_files, Sensors.TYPE_MEG)

        if len(name) == 0:
            name = atlas
        return Head(conn,
                    cort_srf,
                    subcort_srf,
                    cort_rm,
                    subcort_rm,
                    vm,
                    t1,
                    name,
                    sensorsSEEG=sensorsSEEG,
                    sensorsEEG=sensorsEEG,
                    sensorsMEG=sensorsMEG)
def samples_to_timeseries(samples,
                          model_data,
                          target_data=None,
                          region_labels=[]):
    """Wrap the time-dependent entries of fit samples into Timeseries.

    The samples dicts are modified in place: their state/noise entries that
    can be converted and their "fit_target_data" entry are replaced by
    Timeseries objects.

    :param samples: one samples dict or a list of them (one per chain/run)
    :param model_data: dict; optional "time" and "active_regions" entries are
                       used when present
    :param target_data: a Timeseries, or raw data that gets wrapped into one
    :param region_labels: labels of the regions; completed by
                          ``generate_region_labels``
    :return: tuple ``(samples, target_data, x1_mean, x1_std)`` where the last
             two are the nan-mean and nan-std of x1 over all samples
    """
    samples = ensure_list(samples)

    if isinstance(target_data, Timeseries):
        time = target_data.time
        n_target_data = target_data.number_of_labels
        target_data_labels = target_data.space_labels
    else:
        time = model_data.get("time", False)
        # Number of target signals = size of the space axis, i.e. axis 1 of
        # the transposed array (the same layout that is used for x1 and for
        # the "fit_target_data" Timeseries built below). Previously the whole
        # array was passed on as if it were a count.
        n_target_data = samples[0]["fit_target_data"].T.shape[1]
        target_data_labels = generate_region_labels(n_target_data, [], ". ",
                                                    False)

    if time is not False:
        time_start = time[0]
        time_step = np.diff(time).mean()
    else:
        time_start = 0
        time_step = 1

    if not isinstance(target_data, Timeseries):
        target_data = Timeseries(
            target_data, {
                TimeseriesDimensions.SPACE.value: target_data_labels,
                TimeseriesDimensions.VARIABLES.value: ["target_data"]
            },
            time_start=time_start,
            time_step=time_step)

    (n_times, n_regions, n_samples) = samples[0]["x1"].T.shape
    active_regions = model_data.get("active_regions",
                                    np.array(range(n_regions)))
    region_labels = generate_region_labels(
        np.maximum(n_regions, len(region_labels)), region_labels, ". ", False)
    if len(region_labels) > len(active_regions):
        region_labels = region_labels[active_regions]

    state_names = [
        "x1", "z", "x1_star", "z_star", "dX1t", "dZt", "dWt", "dX1t_star",
        "dZt_star", "dWt_star"
    ]
    x1 = np.empty((n_times, n_regions, 0))
    for sample in samples:
        for x in state_names:
            try:
                if x == "x1":
                    # Pool x1 of all chains/runs along the samples axis.
                    x1 = np.concatenate([x1, sample[x].T], axis=2)
                sample[x] = Timeseries(
                    np.expand_dims(sample[x].T, 2), {
                        TimeseriesDimensions.SPACE.value: region_labels,
                        TimeseriesDimensions.VARIABLES.value: [x]
                    },
                    time_start=time_start,
                    time_step=time_step,
                    time_unit=target_data.time_unit)
            except Exception:
                # Best effort: entries that are missing or cannot be
                # converted are left as they are. Unlike a bare "except:",
                # KeyboardInterrupt and SystemExit still propagate.
                continue

        sample["fit_target_data"] = Timeseries(
            np.expand_dims(sample["fit_target_data"].T, 2), {
                TimeseriesDimensions.SPACE.value: target_data_labels,
                TimeseriesDimensions.VARIABLES.value: ["fit_target_data"]
            },
            time_start=time_start,
            time_step=time_step)

    return samples, target_data, np.nanmean(x1, axis=2).squeeze(), np.nanstd(
        x1, axis=2).squeeze()
def get_target_timeseries(probabilistic_model,
                          head,
                          hypothesis,
                          times_on,
                          time_length,
                          sensors_lbls,
                          sensor_id,
                          observation_model,
                          sim_target_file,
                          empirical_target_file,
                          sim_source_type="paper",
                          downsampling=1,
                          empirical_files=[],
                          config=Config(),
                          plotter=None):
    """Get the target signals for fitting, either empirical or simulated.

    If any empirical files are given they are loaded through
    ``set_multiple_empirical_data`` and the number of seizures is stored on
    the model. Otherwise the target is obtained through
    ``set_simulated_target_data``; the model is then flagged as synthetic
    with a single seizure.

    :return: tuple ``(signals, probabilistic_model, simulator)``; the
             simulator is None for empirical data
    """
    use_log_power = \
        observation_model == OBSERVATION_MODELS.SEEG_LOGPOWER.value
    empirical_files = ensure_list(empirical_files)
    times_on = ensure_list(times_on)
    seizure_length = compute_seizure_length(probabilistic_model.tau0)
    seizure_length = int(np.ceil(seizure_length / downsampling))

    if len(empirical_files) > 0:
        # ---------- Empirical data (preprocess edf if necessary) ----------
        if use_log_power:
            preprocessing = ["spectrogram", "log"]
        else:
            preprocessing = ["hpf", "mean_center", "abs-envelope",
                             "convolve", "decimate", "baseline"]
        empirical = set_multiple_empirical_data(
            empirical_files, empirical_target_file, head, sensors_lbls,
            sensor_id, seizure_length, times_on, time_length,
            label_strip_fun=lambda s: s.split("POL ")[-1],
            preprocessing=preprocessing,
            plotter=plotter, title_prefix="")
        signals, probabilistic_model.number_of_seizures = empirical
        return signals, probabilistic_model, None

    # ------------ Simulated data (simulate if necessary) ------------
    probabilistic_model.number_of_seizures = 1
    probabilistic_model.target_data_type = Target_Data_Type.SYNTHETIC.value

    # Preprocessing steps depend on the simulation source and on whether the
    # observation model is the SEEG log-power one.
    if sim_source_type == "paper":
        log_steps, plain_steps = \
            ["spectrogram", "log"], ["convolve", "decimate", "baseline"]
    else:
        log_steps, plain_steps = ["log"], ["decimate", "baseline"]
    preprocessing = log_steps if use_log_power else plain_steps

    # Rescaling is switched off (False) when the largest x1eq of the model
    # configuration exceeds the rescaling value.
    x1eq_rescale_value = -1.225
    if np.max(probabilistic_model.model_config.x1eq) > x1eq_rescale_value:
        rescale_x1eq = False
    else:
        rescale_x1eq = x1eq_rescale_value

    time_window = [times_on[0], times_on[0] + time_length]
    signals, simulator = set_simulated_target_data(
        sim_target_file, head, hypothesis, probabilistic_model, sensor_id,
        rescale_x1eq=rescale_x1eq, sim_type=sim_source_type,
        times_on_off=time_window,
        seizure_length=seizure_length,
        # Maybe change some of those for Epileptor 6D simulations:
        bipolar=False, preprocessing=preprocessing,
        plotter=plotter, config=config, title_prefix="")
    return signals, probabilistic_model, simulator
def set_multiple_empirical_data(empirical_files,
                                ts_file,
                                head,
                                sensors_lbls,
                                sensor_id=0,
                                seizure_length=SEIZURE_LENGTH,
                                times_on=[],
                                time_length=25600,
                                time_units="ms",
                                label_strip_fun=None,
                                preprocessing=TARGET_DATA_PREPROCESSING,
                                low_hpf=LOW_HPF,
                                high_hpf=HIGH_HPF,
                                low_lpf=LOW_LPF,
                                high_lpf=HIGH_LPF,
                                bipolar=BIPOLAR,
                                win_len_ratio=WIN_LEN_RATIO,
                                plotter=None,
                                title_prefix=""):
    """Load one empirical signal per seizure file and join them in time.

    Each file is paired with the entry of ``times_on`` at the same position
    and processed by ``set_empirical_data`` for the window
    ``[time_on, time_on + time_length]``. The per-seizure target file name is
    ``ts_file`` (without ".h5") plus "_" plus the seizure file's base name.

    :param empirical_files: one file path or a list of them
    :param ts_file: base name of the target time series file(s)
    :param times_on: onset time of the window of each seizure
    :param time_length: length of each window
    :param plotter: optional plotter; when given, the joined signals are
                    plotted and the figures are moved to the
                    "fitData_EmpiricalSEEG" folder
    :return: tuple ``(signals, n_seizures)``, n_seizures being the number of
             empirical files

    The remaining parameters are forwarded unchanged to
    ``set_empirical_data``.
    """
    empirical_files = ensure_list(empirical_files)
    n_seizures = len(empirical_files)
    times_on = ensure_list(times_on)
    signals = []
    ts_filename = ts_file.split(".h5")[0]
    # NOTE(review): zip stops at the shorter of the two lists, so a missing
    # onset time silently drops a seizure file — confirm this is acceptable.
    for empirical_file, time_on in zip(empirical_files, times_on):
        seizure_name = os.path.basename(empirical_file).split(".")[0]
        signals.append(
            set_empirical_data(empirical_file,
                               "_".join([ts_filename, seizure_name]) + ".h5",
                               head, sensors_lbls, sensor_id, seizure_length,
                               [time_on, time_on + time_length], time_units,
                               label_strip_fun, preprocessing, low_hpf,
                               high_hpf, low_lpf, high_lpf, bipolar,
                               win_len_ratio, plotter, title_prefix))
    if n_seizures > 1:
        signals = TimeseriesService().concatenate_in_time(signals)
    else:
        signals = signals[0]
    if plotter:
        title_prefix = title_prefix + "MultiseizureEmpiricalSEEG"
        plotter.plot_raster({"ObservationRaster": signals.squeezed},
                            signals.time,
                            time_units=signals.time_unit,
                            special_idx=[],
                            offset=0.1,
                            title='Multiseizure Observation Raster Plot',
                            figure_name=title_prefix + 'ObservationRasterPlot',
                            labels=signals.space_labels)
        plotter.plot_timeseries({"Observation": signals.squeezed},
                                signals.time,
                                time_units=signals.time_unit,
                                special_idx=[],
                                title='Observation Time Series',
                                figure_name=title_prefix +
                                'ObservationTimeSeries',
                                labels=signals.space_labels)
        # The figures folder is only known through the plotter's config, so
        # the move must stay under the plotter guard: with the default
        # plotter=None it used to raise AttributeError.
        move_overwrite_files_to_folder_with_wildcard(
            os.path.join(plotter.config.out.FOLDER_FIGURES,
                         "fitData_EmpiricalSEEG"),
            os.path.join(plotter.config.out.FOLDER_FIGURES,
                         title_prefix.replace(" ", "_")) + "*")
    return signals, n_seizures
Example #29
0
 def set_attributes(self, attributes_names, attribute_values):
     """Set several attributes on this object at once.

     Names and values are paired by position after both went through
     ``ensure_list``; pairing stops at the shorter of the two.

     :param attributes_names: one attribute name or a list of them
     :param attribute_values: one value or a list of values
     :return: self, so that calls can be chained
     """
     names = ensure_list(attributes_names)
     values = ensure_list(attribute_values)
     for name, value in zip(names, values):
         setattr(self, name, value)
     return self
Example #30
0
    def plot_bars(self,
                  data,
                  ax=None,
                  fig=None,
                  title="",
                  group_names=[],
                  legend_prefix="",
                  figsize=FiguresConfig.VERY_LARGE_SIZE):
        """Draw a grouped bar plot with a value label on every bar.

        :param data: a 2D array (groups x elements), a 1D array (treated as a
                     single group) or a list/tuple of sequences (one per
                     group, padded with nan to equal length)
        :param ax: axes to draw on; only used when ``fig`` is given,
                   defaulting to the current axes
        :param fig: existing figure; when None a new figure and axes are
                    created, and the figure is saved/shown at the end
        :param title: axes title
        :param group_names: tick labels, one per group; ignored (with a
                            warning if non-empty) when their number does not
                            match the number of groups
        :param legend_prefix: prefix of the legend entries, which are
                              numbered from 1
        :param figsize: size of a newly created figure
        :return: tuple ``(fig, ax)``
        """
        def barlabel(ax, rects, positions):
            """
            Attach a text label on each bar displaying its height
            """
            for rect, pos in zip(rects, positions):
                height = rect.get_height()
                if pos < 0:
                    y = -height
                    pos = 0.75 * pos
                else:
                    y = height
                    pos = 0.25 * pos
                ax.text(rect.get_x() + rect.get_width() / 2.,
                        pos,
                        '%0.2f' % y,
                        color="k",
                        ha='center',
                        va='bottom',
                        rotation=90)

        if fig is None:
            fig, ax = pyplot.subplots(1, 1, figsize=figsize)
            show_and_save = True
        else:
            show_and_save = False
            if ax is None:
                ax = pyplot.gca()
        if isinstance(data, (list, tuple)):
            # Many groups given as a list: fill in with nan in case that not
            # all groups have the same number of elements.
            try:
                from itertools import izip_longest as zip_longest  # Python 2
            except ImportError:
                from itertools import zip_longest  # Python 3
            data = numpy.array(
                list(zip_longest(*ensure_list(data), fillvalue=numpy.nan))).T
        elif data.ndim == 1:  # This is the case where there is only one group...
            data = numpy.expand_dims(data, axis=1).T
        n_groups, n_elements = data.shape
        posmax = numpy.nanmax(data)
        # NOTE(review): -(-data) is just data, so negmax equals posmax and
        # the "pos < 0" branch of barlabel is only reached when all the data
        # are negative. Possibly -numpy.nanmax(-data) was meant — left
        # unchanged until the intended label placement is confirmed.
        negmax = numpy.nanmax(-(-data))
        n_groups_names = len(group_names)
        if n_groups_names != n_groups:
            if n_groups_names != 0:
                warning("Ignoring group_names because their number (" +
                        str(n_groups_names) +
                        ") is not equal to the number of groups (" +
                        str(n_groups) + ")!")
            group_names = n_groups * [""]
        colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
        n_colors = len(colorcycle)
        x_inds = numpy.arange(n_groups)
        # All the bars of one group share 0.9 of the unit spacing.
        width = 0.9 / n_elements
        elements = []
        for iE in range(n_elements):
            elements.append(
                ax.bar(x_inds + iE * width,
                       data[:, iE],
                       width,
                       color=colorcycle[iE % n_colors]))
            positions = numpy.array(
                [negmax if d < 0 else posmax for d in data[:, iE]])
            # Labels of bars without a finite reference height go to y = 0.
            positions[numpy.logical_or(numpy.isnan(positions),
                                       numpy.isinf(
                                           numpy.abs(positions)))] = 0.0
            barlabel(ax, elements[-1], positions)
        if n_elements > 1:
            legend = [
                legend_prefix + str(ii) for ii in range(1, n_elements + 1)
            ]
            ax.legend(tuple([element[0] for element in elements]),
                      tuple(legend))
        ax.set_xticks(x_inds + n_elements * width / 2)
        ax.set_xticklabels(tuple(group_names))
        ax.set_title(title)
        ax.autoscale()  # tight=True
        ax.set_xlim([-1.05 * width, n_groups * 1.05])
        if show_and_save:
            fig.tight_layout()
            self._save_figure(fig)
            self._check_show()
        return fig, ax