Esempio n. 1
0
def sensitivity_analysis_pse_from_hypothesis(n_samples,
                                             hypothesis,
                                             connectivity_matrix,
                                             region_labels,
                                             method="sobol",
                                             half_range=0.1,
                                             global_coupling=[],
                                             healthy_regions_parameters=[],
                                             save_services=False,
                                             config=Config(),
                                             model_config_kwargs={},
                                             **kwargs):
    """Run LSA for a hypothesis, then a PSE-based sensitivity analysis on it.

    The hypothesis is first passed to ``start_lsa_run`` (together with
    ``connectivity_matrix``, ``config`` and ``model_config_kwargs``); the
    resulting LSA hypothesis and services are then handed over to
    ``sensitivity_analysis_pse_from_lsa_hypothesis`` along with all remaining
    arguments, which are forwarded unchanged.

    :param n_samples: forwarded to the sensitivity analysis
    :param hypothesis: hypothesis to analyse; its ``name`` is logged
    :param connectivity_matrix: forwarded to both the LSA run and the analysis
    :param region_labels: forwarded to the sensitivity analysis
    :param method: sensitivity analysis method name (default ``"sobol"``)
    :param half_range: forwarded to the sensitivity analysis
    :param global_coupling: forwarded to the sensitivity analysis
    :param healthy_regions_parameters: forwarded to the sensitivity analysis
    :param save_services: forwarded to the sensitivity analysis
    :param config: configuration; ``config.out.FOLDER_LOGS`` is used for logging
    :param model_config_kwargs: extra keyword arguments for ``start_lsa_run``
    :param kwargs: extra keyword arguments for the sensitivity analysis
    :return: tuple ``(model_configuration_builder, model_configuration,
             lsa_service, lsa_hypothesis, results, pse_results)``
    """
    log = initialize_logger(__name__, config.out.FOLDER_LOGS)
    log.info("Running hypothesis: " + hypothesis.name)

    # Step 1: linear stability analysis of the given hypothesis.
    lsa_run_outputs = start_lsa_run(hypothesis, connectivity_matrix, config,
                                    **model_config_kwargs)
    builder, model_config, lsa_srv, lsa_hyp = lsa_run_outputs

    # Step 2: sensitivity analysis built on top of the LSA hypothesis.
    sa_results, pse_out = sensitivity_analysis_pse_from_lsa_hypothesis(
        n_samples,
        lsa_hyp,
        connectivity_matrix,
        builder,
        lsa_srv,
        region_labels,
        method,
        half_range,
        global_coupling,
        healthy_regions_parameters,
        save_services,
        config,
        **kwargs)

    return builder, model_config, lsa_srv, lsa_hyp, sa_results, pse_out
Esempio n. 2
0
def main_fit_sim_hyplsa(ep_name="ep_l_frontal_complex",
                        stats_model_name="vep_sde",
                        EMPIRICAL="",
                        times_on_off=[],
                        sensors_lbls=[],
                        sensors_inds=[],
                        fitmethod="optimizing",
                        stan_service="CmdStan",
                        config=Config(),
                        **kwargs):
    """Fit a Stan statistical model to empirical or simulated target signals.

    Workflow (only the first hypothesis returned by
    ``from_head_to_hypotheses`` is processed):

    1. Build the Stan service and set/compile the model found at
       ``<config.generic.STATS_MODELS_PATH>/<stats_model_name>.stan``.
    2. Read (if the result files exist) or compute the model configuration
       and the LSA hypothesis.
    3. Read (if the result files exist) or build the model inversion service,
       the statistical model and the Stan model data. Target signals come
       from the ``EMPIRICAL`` file when it exists, otherwise from a simulation.
    4. Fit, write estimates/samples/summary and plot the fit results.

    :param ep_name: passed to ``from_head_to_hypotheses``
    :param stats_model_name: base name of the ``.stan`` model file; names
           containing ``"vep-fe-rev"`` switch to the INS model interface
    :param EMPIRICAL: path of an empirical data file; when it is not an
           existing file, simulated data are used instead
    :param times_on_off: passed to ``prepare_seeg_observable``
    :param sensors_lbls: sensor labels; when empty they are taken from the
           head sensors using ``sensors_inds``
    :param sensors_inds: sensor indices, also used as manual signal selection
    :param fitmethod: fit method passed to the Stan service
    :param stan_service: ``"CmdStan"`` selects ``CmdStanService``; any other
           value selects ``PyStanService``
    :param config: configuration object (paths of models, results, logs)
    :param kwargs: forwarded to ``stan_service.fit``
    """
    # ------------------------------Stan model and service--------------------------------------
    # Compile or load model:
    # model_code_path = os.path.join(STATS_MODELS_PATH, stats_model_name + ".stan")
    model_code_path = os.path.join(config.generic.STATS_MODELS_PATH,
                                   stats_model_name + ".stan")
    if isequal_string(stan_service, "CmdStan"):
        stan_service = CmdStanService(model_name=stats_model_name,
                                      model=None,
                                      model_code=None,
                                      model_code_path=model_code_path,
                                      fitmethod=fitmethod,
                                      random_seed=12345,
                                      init="random",
                                      config=config)
    else:
        stan_service = PyStanService(model_name=stats_model_name,
                                     model=None,
                                     model_code=None,
                                     model_code_path=model_code_path,
                                     fitmethod=fitmethod,
                                     random_seed=12345,
                                     init="random",
                                     config=config)
    stan_service.set_or_compile_model()

    # -------------------------------Reading model_data and hypotheses--------------------------------------------------
    head, hypos = from_head_to_hypotheses(ep_name, config, plot_head=False)

    for hyp in hypos[:1]:

        # --------------------------Model configuration and LSA-----------------------------------
        model_config_file = os.path.join(config.out.FOLDER_RES,
                                         hyp.name + "_ModelConfig.h5")
        hyp_file = os.path.join(config.out.FOLDER_RES, hyp.name + "_LSA.h5")
        if os.path.isfile(hyp_file) and os.path.isfile(model_config_file):
            # Reuse previously written results.
            model_configuration = reader.read_model_configuration(
                model_config_file)
            lsa_hypothesis = reader.read_hypothesis(hyp_file)
        else:
            model_configuration, lsa_hypothesis, model_configuration_builder, lsa_service = \
                from_hypothesis_to_model_config_lsa(hyp, head, eigen_vectors_number=None, weighted_eigenvector_sum=True,
                                                    config=config, K=K_DEF)

        dynamical_model = "EpileptorDP2D"

        # -------------------------- Get model_data and observation signals: -------------------------------------------
        model_inversion_file = os.path.join(
            config.out.FOLDER_RES, hyp.name + "_ModelInversionService.h5")
        stats_model_file = os.path.join(config.out.FOLDER_RES,
                                        hyp.name + "_StatsModel.h5")
        model_data_file = os.path.join(config.out.FOLDER_RES,
                                       hyp.name + "_ModelData.h5")
        if os.path.isfile(model_inversion_file) and os.path.isfile(stats_model_file) \
                and os.path.isfile(model_data_file):
            # Reuse previously written inversion service, model and data.
            model_inversion = reader.read_model_inversions_service(
                model_inversion_file)
            statistical_model = reader.read_generic(stats_model_file)
            model_data = stan_service.load_model_data_from_file(
                model_data_path=model_data_file)
        else:
            model_inversion = SDEModelInversionService(
                model_configuration,
                lsa_hypothesis,
                head,
                dynamical_model,
                x1eq_max=-1.0,
                priors_mode="uninformative")
            # observation_expression="lfp"
            statistical_model = model_inversion.generate_statistical_model(
                x1eq_max=-1.0, observation_model="seeg_logpower")
            statistical_model = model_inversion.update_active_regions(
                statistical_model,
                methods=["e_values", "LSA"],
                active_regions_th=0.1,
                reset=True)
            plotter.plot_statistical_model(statistical_model,
                                           "Statistical Model")
            cut_signals_tails = (6, 6)
            n_electrodes = 8
            sensors_per_electrode = 2
            if os.path.isfile(EMPIRICAL):
                # ---------------------------------------Get empirical data-------------------------------------------
                target_data_type = "empirical"
                statistical_model.observation_model = "seeg_logpower"
                decimate = 2
                ts_file = os.path.join(config.out.FOLDER_RES,
                                       hyp.name + "_ts_empirical.mat")
                try:
                    # Try the cached, already prepared time series first.
                    vois_ts_dict = loadmat(ts_file)
                    time = vois_ts_dict["time"].flatten()
                    sensors_inds = np.array(
                        vois_ts_dict["sensors_inds"]).flatten().tolist()
                    sensors_lbls = np.array(
                        vois_ts_dict["sensors_lbls"]).flatten().tolist()
                    vois_ts_dict.update({
                        "time": time,
                        "sensors_inds": sensors_inds,
                        "sensors_lbls": sensors_lbls
                    })
                    savemat(ts_file, vois_ts_dict)
                except Exception:
                    # Best effort: any failure to load/parse the cached file
                    # falls back to preparing the signals from EMPIRICAL.
                    # (Narrowed from a bare except so that KeyboardInterrupt
                    # and SystemExit are no longer swallowed.)
                    if len(sensors_lbls) == 0:
                        # Fixed: the result used to be bound to the misspelled
                        # name `sensor_lbls` and was therefore never used.
                        sensors_lbls = head.get_sensors_id(
                        ).labels[sensors_inds]
                    signals, time, fs = prepare_seeg_observable(EMPIRICAL,
                                                                times_on_off,
                                                                sensors_lbls,
                                                                plot_flag=True,
                                                                log_flag=True)
                    if len(
                            sensors_inds
                    ) > 1:  # get_bipolar_channels(sensors_inds, sensors_lbls)
                        sensors_inds, sensors_lbls = head.get_sensors_id(
                        ).get_bipolar_sensors(sensors_inds=sensors_inds)
                    # Sort sensors by index and scatter their signals into an
                    # array with one column per sensor of the inversion service.
                    inds = np.argsort(sensors_inds)
                    sensors_inds = np.array(
                        sensors_inds)[inds].flatten().tolist()
                    sensors_lbls = np.array(
                        sensors_lbls)[inds].flatten().tolist()
                    all_signals = np.zeros(
                        (signals.shape[0],
                         len(model_inversion.sensors_labels)))
                    all_signals[:, sensors_inds] = signals[:, inds]
                    signals = all_signals
                    del all_signals
                    vois_ts_dict = {
                        "time": time.flatten(),
                        "signals": signals,
                        "sensors_inds": sensors_inds,
                        "sensors_lbls": sensors_lbls
                    }
                    savemat(ts_file, vois_ts_dict)
                model_inversion.sensors_labels[
                    vois_ts_dict["sensors_inds"]] = sensors_lbls
                manual_selection = sensors_inds
            else:
                # -------------------------- Get simulated data (simulate if necessary) -------------------------------
                target_data_type = "simulated"
                statistical_model.observation_model = "seeg_logpower"  # "seeg_logpower" # "lfp_power"
                decimate = 1
                ts_file = os.path.join(config.out.FOLDER_RES,
                                       hyp.name + "_ts.h5")
                vois_ts_dict = \
                    from_model_configuration_to_simulation(model_configuration, head, lsa_hypothesis,
                                                           sim_type="fitting", dynamical_model=dynamical_model,
                                                           ts_file=ts_file, plot_flag=True, config=config)
                # if len(sensors_inds) > 1:  # get_bipolar_channels(sensors_inds, sensors_lbls)
                #     sensors_inds, sensors_lbls = head.get_sensors_id().get_bipolar_sensors(sensors_inds=sensors_inds)
                if statistical_model.observation_model.find("seeg") >= 0:
                    manual_selection = sensors_inds
                else:
                    if stats_model_name.find("vep-fe-rev") >= 0:
                        manual_selection = statistical_model.active_regions
                    else:
                        manual_selection = []
                # -------------------------- Select and set observation signals -----------------------------------
            signals, time, statistical_model, vois_ts_dict = \
                model_inversion.set_target_data_and_time(target_data_type, vois_ts_dict, statistical_model,
                                                         decimate=decimate, cut_signals_tails=cut_signals_tails,
                                                         manual_selection=manual_selection,
                                                         auto_selection="correlation-power", # auto_selection=False,
                                                         n_electrodes=n_electrodes,
                                                         sensors_per_electrode=sensors_per_electrode,
                                                         group_electrodes=True, normalization="baseline-amplitude",
                                                         )
            # if len(model_inversion.signals_inds) < head.get_sensors_id().number_of_sensors:
            #     statistical_model = \
            #             model_inversion.update_active_regions_seeg(statistical_model)
            if model_inversion.data_type == "lfp":
                labels = model_inversion.region_labels
            else:
                labels = model_inversion.sensors_labels
            if vois_ts_dict.get("signals", None) is not None:
                # Rescale the full set of signals to [0, 1] for the raster plot.
                vois_ts_dict["signals"] -= vois_ts_dict["signals"].min()
                vois_ts_dict["signals"] /= vois_ts_dict["signals"].max()
                plotter.plot_raster(
                    {'Target Signals': vois_ts_dict["signals"]},
                    vois_ts_dict["time"].flatten(),
                    time_units="ms",
                    title=hyp.name + ' Target Signals raster',
                    special_idx=model_inversion.signals_inds,
                    offset=0.1,
                    labels=labels)
            plotter.plot_timeseries(
                {'Target Signals': signals},
                time,
                time_units="ms",
                title=hyp.name + ' Target Signals',
                labels=labels[model_inversion.signals_inds])
            writer.write_model_inversion_service(
                model_inversion,
                os.path.join(config.out.FOLDER_RES,
                             hyp.name + "_ModelInversionService.h5"))
            writer.write_generic(statistical_model, config.out.FOLDER_RES,
                                 hyp.name + "_StatsModel.h5")
            # try:
            #     model_data = stan_service.load_model_data_from_file()
            # except:
            model_data = build_stan_model_dict(statistical_model, signals,
                                               model_inversion)
            writer.write_dictionary(
                model_data,
                os.path.join(config.out.FOLDER_RES,
                             hyp.name + "_ModelData.h5"))
        if os.path.isfile(EMPIRICAL):
            # No ground truth is available for empirical data.
            simulation_values = None
        else:
            simulation_values = {
                "x0": model_configuration.x0,
                "x1eq": model_configuration.x1EQ,
                "x1init": model_configuration.x1EQ,
                "zinit": model_configuration.zEQ
            }
        # Stupid code to interface with INS stan model
        if stats_model_name.find("vep-fe-rev") >= 0:
            model_data = build_stan_model_dict_to_interface_ins(
                model_data, statistical_model, model_inversion)
            x1_str = "x"
            input_signals_str = "seeg_log_power"
            signals_str = "mu_seeg_log_power"
            dX1t_str = "x_eta"
            dZt_str = "z_eta"
            sig_str = "sigma"
            k_str = "k"
            pair_plot_params = [
                "time_scale", "k", "sigma", "epsilon", "amplitude", "offset"
            ]
            region_violin_params = ["x0", "x_init", "z_init"]
            if simulation_values is not None:
                # Restrict the ground-truth values to the active regions and
                # expose them under the parameter names of the INS model.
                simulation_values.update({
                    "x0":
                    simulation_values["x0"][statistical_model.active_regions]
                })
                simulation_values.update({
                    "x_init":
                    simulation_values["x1init"][
                        statistical_model.active_regions]
                })
                simulation_values.update({
                    "z_init":
                    simulation_values["zinit"][
                        statistical_model.active_regions]
                })
            connectivity_plot = False
            # Accepts (and ignores) an estimate so that both branches expose
            # the same call signature; model_configuration is bound eagerly to
            # avoid late binding on the loop variable.
            estMC = lambda est=None, mc=model_configuration: mc.model_connectivity
            region_mode = "active"
        else:
            x1_str = "x1"
            input_signals_str = "signals"
            signals_str = "fit_signals"
            dX1t_str = "dX1t"  # "x1_dWt"
            dZt_str = "dZt"  # "z_dWt"
            sig_str = "sig"
            k_str = "K"
            pair_plot_params = [
                "tau1", "tau0", "K", "sig_init", "sig", "eps", "scale_signal",
                "offset_signal"
            ]
            region_violin_params = ["x0", "x1eq", "x1init", "zinit"]
            connectivity_plot = False
            estMC = lambda est: est["MC"]
            region_mode = "all"
        # -------------------------- Fit and get estimates: ------------------------------------------------------------
        ests, samples, summary = stan_service.fit(debug=0,
                                                  simulate=0,
                                                  model_data=model_data,
                                                  merge_outputs=False,
                                                  chains=1,
                                                  refresh=1,
                                                  num_warmup=5,
                                                  num_samples=5,
                                                  max_depth=7,
                                                  delta=0.8,
                                                  **kwargs)
        writer.write_generic(ests, config.out.FOLDER_RES,
                             hyp.name + "_fit_est.h5")
        writer.write_generic(samples, config.out.FOLDER_RES,
                             hyp.name + "_fit_samples.h5")
        # Fixed: R_hat used to be left undefined (NameError in the plot call
        # below) whenever the fit returned no summary or a non-dict summary.
        R_hat = None
        if summary is not None:
            writer.write_generic(summary, config.out.FOLDER_RES,
                                 hyp.name + "_fit_summary.h5")
            if isinstance(summary, dict):
                R_hat = summary.get("R_hat", None)
                if R_hat is not None:
                    R_hat = {"R_hat": R_hat}
        ests = ensure_list(ests)
        plotter.plot_fit_results(
            model_inversion,
            ests,
            samples,
            statistical_model,
            model_data[input_signals_str],
            R_hat,
            model_data["time"],
            simulation_values,
            region_mode,
            seizure_indices=lsa_hypothesis.get_regions_disease_indices(),
            x1_str=x1_str,
            signals_str=signals_str,
            sig_str=sig_str,
            dX1t_str=dX1t_str,
            dZt_str=dZt_str,
            trajectories_plot=True,
            connectivity_plot=connectivity_plot,
            pair_plot_params=pair_plot_params,
            region_violin_params=region_violin_params)
        # -------------------------- Reconfigure model after fitting:---------------------------------------------------
        # for id_est, est in enumerate(ensure_list(ests)):
        #     fit_model_configuration_builder = \
        #         ModelConfigurationBuilder(hyp.number_of_regions, K=est[k_str] * hyp.number_of_regions)
        #     x0_values_fit = \
        #         fit_model_configuration_builder._compute_x0_values_from_x0_model(est['x0'])
        #     hyp_fit = HypothesisBuilder().set_nr_of_regions(head.connectivity.number_of_regions).set_name(
        #         'fit' + str(id_est) + "_" + hyp.name)._build_excitability_hypothesis(x0_values_fit, range(
        #         model_configuration.number_of_regions))
        #     model_configuration_fit = fit_model_configuration_builder.build_model_from_hypothesis(hyp_fit,  # est["MC"]
        #                                                                                           estMC(est))
        #     writer.write_model_configuration(model_configuration_fit,
        #                                      os.path.join(config.out.FOLDER_RES, hyp_fit.name + "_ModelConfig.h5"))
        #
        #     # Plot nullclines and equilibria of model configuration
        #     plotter.plot_state_space(model_configuration_fit,
        #                              region_labels=model_inversion.region_labels,
        #                              special_idx=statistical_model.active_regions,
        #                              model="6d", zmode="lin",
        #                              figure_name=hyp_fit.name + "_Nullclines and equilibria")
        logger.info("Done!")
Esempio n. 3
0
from tvb_epilepsy.io.h5_reader import H5Reader
from tvb_epilepsy.plot.plotter import Plotter
from tvb_epilepsy.service.hypothesis_builder import HypothesisBuilder
from tvb_epilepsy.service.model_configuration_builder import ModelConfigurationBuilder
from tvb_epilepsy.service.model_inversion.sde_model_inversion_service import SDEModelInversionService
from tvb_epilepsy.service.model_inversion.stan.cmdstan_service import CmdStanService
from tvb_epilepsy.service.model_inversion.stan.pystan_service import PyStanService
from tvb_epilepsy.service.model_inversion.vep_stan_dict_builder import build_stan_model_dict, \
    build_stan_model_dict_to_interface_ins
from tvb_epilepsy.top.scripts.hypothesis_scripts import from_head_to_hypotheses, from_hypothesis_to_model_config_lsa
from tvb_epilepsy.top.scripts.simulation_scripts import from_model_configuration_to_simulation
from tvb_epilepsy.top.scripts.seeg_data_scripts import prepare_seeg_observable

# Output base folder of this script, located under the user's home directory.
output = os.path.join(
    os.path.expanduser("~"),
    "Dropbox", "Work", "VBtech", "VEP", "results", "fit",
)

# Configuration rooted at the output folder, and a logger writing to its logs folder.
config = Config(output_base=output)
logger = initialize_logger(__name__, config.out.FOLDER_LOGS)

# Module-level helpers shared by the functions below: plotting and H5 reading/writing.
plotter = Plotter(config)
writer = H5Writer()
reader = H5Reader()


def main_fit_sim_hyplsa(ep_name="ep_l_frontal_complex",
                        stats_model_name="vep_sde",
                        EMPIRICAL="",
                        times_on_off=[],
                        sensors_lbls=[],
                        sensors_inds=[],
                        fitmethod="optimizing",
Esempio n. 4
0
def sensitivity_analysis_pse_from_lsa_hypothesis(n_samples,
                                                 lsa_hypothesis,
                                                 connectivity_matrix,
                                                 model_configuration_builder,
                                                 lsa_service,
                                                 region_labels,
                                                 method="sobol",
                                                 half_range=0.1,
                                                 global_coupling=[],
                                                 healthy_regions_parameters=[],
                                                 save_services=False,
                                                 config=Config(),
                                                 **kwargs):
    """Run a sensitivity analysis of LSA propagation strength via a PSE.

    Input parameters of the analysis are built from the hypothesis
    (``x0_values``, ``e_values``, ``w_values``) and from ``global_coupling``.
    Samples are generated with ``SalibSamplingService``, run through
    ``LSAPSEService``, and the "lsa_propagation_strengths" outputs are then
    analysed by ``SensitivityAnalysisService``.

    :param n_samples: requested number of samples for the sampling service
    :param lsa_hypothesis: LSA hypothesis providing values and indices
    :param connectivity_matrix: connectivity passed to the PSE run; its shape
           is used to convert linear connectivity indices to coordinates
    :param model_configuration_builder: passed to the PSE run
    :param lsa_service: passed to the PSE run
    :param region_labels: labels used to name the input parameters
    :param method: one of ``METHODS`` (case-insensitive)
    :param half_range: half width of each sampled interval around the
           hypothesis value
    :param global_coupling: iterable of dicts with optional keys "indices"
           (default: all regions), "low" (default 0.0), "high" (default 2.0)
    :param healthy_regions_parameters: iterable of dicts with optional keys
           "indices" (default: non-disease regions) and "name"
           (default "x0_values") for jittered, non-analysed parameters
    :param save_services: when true, write the PSE and sensitivity analysis
           services under ``config.out.FOLDER_RES``
    :param config: configuration object (log and result folders)
    :param kwargs: "random_seed", "loc", "scale", "calc_second_order" and
           "conf_level" are read here; all kwargs are also forwarded to the
           sample generation and to the sensitivity analysis run
    :return: tuple ``(results, pse_results)``
    :raises: whatever ``raise_value_error`` raises for an unknown ``method``
    """
    logger = initialize_logger(__name__, config.out.FOLDER_LOGS)
    method = method.lower()
    if method not in METHODS:
        raise_value_error("Method " + str(method) +
                          " is not one of the available methods " +
                          str(METHODS) + " !")
    # Each analysis method requires a matching sampling scheme.
    if method in ("delta", "dgsm"):
        sampler_name = "latin"
    elif method == "sobol":
        sampler_name = "saltelli"
    elif method == "fast":
        sampler_name = "fast_sampler"
    else:
        sampler_name = method
    # A list (not a range) so that comparisons and np.delete behave the same
    # on Python 2 and Python 3.
    all_regions_indices = list(range(lsa_hypothesis.number_of_regions))
    disease_indices = lsa_hypothesis.regions_disease_indices
    healthy_indices = np.delete(all_regions_indices, disease_indices).tolist()
    pse_params = {"path": [], "indices": [], "name": [], "low": [], "high": []}
    # First build from the hypothesis the input parameters of the sensitivity analysis.
    # These can be either originating from excitability, epileptogenicity or connectivity hypotheses,
    # or they can relate to the global coupling scaling (parameter K of the model configuration)
    for ii, x0_value in enumerate(lsa_hypothesis.x0_values):
        pse_params["indices"].append([ii])
        pse_params["path"].append("hypothesis.x0_values")
        pse_params["name"].append(
            str(region_labels[lsa_hypothesis.x0_indices[ii]]) +
            " Excitability")
        pse_params["low"].append(x0_value - half_range)
        pse_params["high"].append(
            np.min([MAX_DISEASE_VALUE, x0_value + half_range]))
    for ii, e_value in enumerate(lsa_hypothesis.e_values):
        pse_params["indices"].append([ii])
        pse_params["path"].append("hypothesis.e_values")
        pse_params["name"].append(
            str(region_labels[lsa_hypothesis.e_indices[ii]]) +
            " Epileptogenicity")
        pse_params["low"].append(e_value - half_range)
        pse_params["high"].append(
            np.min([MAX_DISEASE_VALUE, e_value + half_range]))
    for ii, w_value in enumerate(lsa_hypothesis.w_values):
        pse_params["indices"].append([ii])
        pse_params["path"].append("hypothesis.w_values")
        inds = linear_index_to_coordinate_tuples(lsa_hypothesis.w_indices[ii],
                                                 connectivity_matrix.shape)
        if len(inds) == 1:
            # Fixed: both endpoints used to be labelled with inds[0][0].
            # NOTE(review): assumes each coordinate tuple is (row, column)
            # of the connectivity matrix - confirm against the helper.
            pse_params["name"].append(
                str(region_labels[inds[0][0]]) + "-" +
                str(region_labels[inds[0][1]]) + " Connectivity")
        else:
            # Fixed: a stray comma turned `+ "]"` into a unary plus on a
            # string, which raised TypeError.
            pse_params["name"].append("Connectivity[" + str(inds) + "]")
        # Fixed: the bounds used to be appended only in the else branch,
        # leaving "low"/"high" shorter than the other lists.
        pse_params["low"].append(np.max([w_value - half_range, 0.0]))
        pse_params["high"].append(w_value + half_range)
    for val in global_coupling:
        pse_params["path"].append("model.configuration.service.K_unscaled")
        inds = val.get("indices", all_regions_indices)
        # Fixed: `inds == range(...)` is always False for a list on Python 3;
        # compare element-wise instead.
        if np.array_equal(inds, all_regions_indices):
            pse_params["name"].append("Global coupling")
        else:
            pse_params["name"].append("Afferent coupling[" + str(inds) + "]")
        pse_params["indices"].append(inds)
        pse_params["low"].append(val.get("low", 0.0))
        pse_params["high"].append(val.get("high", 2.0))
    # Now generate samples suitable for sensitivity analysis
    salib_sampler = SalibSamplingService(
        n_samples=n_samples,
        sampler=sampler_name,
        random_seed=kwargs.get("random_seed", None))
    input_samples = salib_sampler.generate_samples(low=pse_params["low"],
                                                   high=pse_params["high"],
                                                   **kwargs)
    # The sampling scheme decides the actual number of samples generated.
    n_samples = input_samples.shape[1]
    pse_params["samples"] = [
        np.array(value) for value in input_samples.tolist()
    ]
    pse_params_list = dicts_of_lists_to_lists_of_dicts(pse_params)
    # Add a random jitter to the healthy regions if required...:
    jitter_sampler = ProbabilisticSamplingService(
        n_samples=n_samples, random_seed=kwargs.get("random_seed", None))
    for val in healthy_regions_parameters:
        inds = val.get("indices", healthy_indices)
        name = val.get("name", "x0_values")
        n_params = len(inds)
        samples = jitter_sampler.generate_samples(
            parameter=(
                kwargs.get("loc", 0.0),  # loc
                kwargs.get("scale", 2 * half_range)),  # scale
            probability_distribution="uniform",
            low=0.0,
            shape=(n_params, ))
        for ii in range(n_params):
            pse_params_list.append({
                "path": "model_configuration_builder." + name,
                "samples": samples[ii],
                "indices": [inds[ii]],
                "name": name
            })
    # Now run pse service to generate output samples:
    pse = LSAPSEService(hypothesis=lsa_hypothesis, params_pse=pse_params_list)
    pse_results, execution_status = pse.run_pse(connectivity_matrix, False,
                                                model_configuration_builder,
                                                lsa_service)
    pse_results = list_of_dicts_to_dicts_of_ndarrays(pse_results)
    # Now prepare inputs and outputs and run the sensitivity analysis:
    # NOTE!: Without the jittered healthy regions which we don't want to include into the sensitivity analysis!
    inputs = dicts_of_lists_to_lists_of_dicts(pse_params)
    outputs = [{
        "names": ["LSA Propagation Strength"],
        "values": pse_results["lsa_propagation_strengths"]
    }]
    sensitivity_analysis_service = SensitivityAnalysisService(
        inputs,
        outputs,
        method=method,
        calc_second_order=kwargs.get("calc_second_order", True),
        conf_level=kwargs.get("conf_level", 0.95))
    results = sensitivity_analysis_service.run(**kwargs)
    if save_services:
        logger.info(repr(pse))
        writer = H5Writer()
        writer.write_pse_service(
            pse,
            os.path.join(config.out.FOLDER_RES,
                         method + "_test_pse_service.h5"))
        logger.info(repr(sensitivity_analysis_service))
        writer.write_sensitivity_analysis_service(
            sensitivity_analysis_service,
            os.path.join(config.out.FOLDER_RES,
                         method + "_test_sa_service.h5"))
    return results, pse_results