def sensitivity_analysis_pse_from_hypothesis(n_samples, hypothesis, connectivity_matrix, region_labels,
                                             method="sobol", half_range=0.1, global_coupling=[],
                                             healthy_regions_parameters=[], save_services=False,
                                             config=Config(), model_config_kwargs={}, **kwargs):
    """Run LSA for a hypothesis and then the sensitivity-analysis PSE on the LSA result.

    The hypothesis is first passed to ``start_lsa_run``; the LSA hypothesis it
    yields is then handed to ``sensitivity_analysis_pse_from_lsa_hypothesis``
    together with the remaining arguments.

    :param n_samples: number of samples, forwarded to the sensitivity analysis
    :param hypothesis: disease hypothesis to run LSA on (its ``name`` is logged)
    :param connectivity_matrix: connectivity matrix used by LSA and by the PSE
    :param region_labels: region labels, forwarded to the sensitivity analysis
    :param method: sensitivity analysis method name
    :param half_range: half width of the range explored around each parameter
    :param global_coupling: global coupling parameter specifications (forwarded)
    :param healthy_regions_parameters: healthy regions' parameter specifications (forwarded)
    :param save_services: forwarded flag controlling whether services are written to disk
    :param config: configuration object (provides the logs folder here)
    :param model_config_kwargs: extra keyword arguments for ``start_lsa_run``
    :param kwargs: extra keyword arguments for the sensitivity analysis
    :return: tuple ``(model_configuration_builder, model_configuration, lsa_service,
             lsa_hypothesis, results, pse_results)``
    """
    logger = initialize_logger(__name__, config.out.FOLDER_LOGS)

    # LSA has to be computed for this hypothesis before the sensitivity analysis:
    logger.info("Running hypothesis: " + hypothesis.name)
    lsa_run_outputs = start_lsa_run(hypothesis, connectivity_matrix, config, **model_config_kwargs)
    model_configuration_builder, model_configuration, lsa_service, lsa_hypothesis = lsa_run_outputs

    results, pse_results = sensitivity_analysis_pse_from_lsa_hypothesis(
        n_samples, lsa_hypothesis, connectivity_matrix, model_configuration_builder, lsa_service, region_labels,
        method=method, half_range=half_range, global_coupling=global_coupling,
        healthy_regions_parameters=healthy_regions_parameters, save_services=save_services, config=config,
        **kwargs)

    return (model_configuration_builder, model_configuration, lsa_service, lsa_hypothesis,
            results, pse_results)
def main_fit_sim_hyplsa(ep_name="ep_l_frontal_complex", stats_model_name="vep_sde", EMPIRICAL="",
                        times_on_off=[], sensors_lbls=[], sensors_inds=[], fitmethod="optimizing",
                        stan_service="CmdStan", config=Config(), **kwargs):
    """Fit a Stan statistical model to empirical or simulated target signals.

    For the first hypothesis returned by ``from_head_to_hypotheses`` the function:

    1. builds the Stan service and compiles or loads the model,
    2. reads or computes the model configuration and the LSA hypothesis,
    3. reads or builds the model inversion service, statistical model and
       Stan model data (from empirical data when ``EMPIRICAL`` is an existing
       file, otherwise from a simulation),
    4. runs the fit, writes estimates/samples/summary and plots the results.

    :param ep_name: name passed to ``from_head_to_hypotheses``
    :param stats_model_name: name of the Stan model (``<name>.stan`` is looked up in
                             ``config.generic.STATS_MODELS_PATH``)
    :param EMPIRICAL: path of the empirical data file; if it is not an existing file
                      simulated data are used instead
    :param times_on_off: passed to ``prepare_seeg_observable`` for empirical data
    :param sensors_lbls: sensor labels; when empty they are taken from the head's
                         sensors at ``sensors_inds``
    :param sensors_inds: sensor indices used as the manual selection of signals
    :param fitmethod: fit method passed to the Stan service
    :param stan_service: "CmdStan" selects ``CmdStanService``, anything else ``PyStanService``
    :param config: configuration object providing input/output folders
    :param kwargs: extra keyword arguments forwarded to ``stan_service.fit``
    """
    # ------------------------------Stan model and service--------------------------------------
    # Compile or load model:
    model_code_path = os.path.join(config.generic.STATS_MODELS_PATH, stats_model_name + ".stan")
    if isequal_string(stan_service, "CmdStan"):
        stan_service = CmdStanService(model_name=stats_model_name, model=None, model_code=None,
                                      model_code_path=model_code_path, fitmethod=fitmethod,
                                      random_seed=12345, init="random", config=config)
    else:
        stan_service = PyStanService(model_name=stats_model_name, model=None, model_code=None,
                                     model_code_path=model_code_path, fitmethod=fitmethod,
                                     random_seed=12345, init="random", config=config)
    stan_service.set_or_compile_model()

    # -------------------------------Reading model_data and hypotheses------------------------------------------
    head, hypos = from_head_to_hypotheses(ep_name, config, plot_head=False)

    for hyp in hypos[:1]:

        # --------------------------Model configuration and LSA-----------------------------------
        model_config_file = os.path.join(config.out.FOLDER_RES, hyp.name + "_ModelConfig.h5")
        hyp_file = os.path.join(config.out.FOLDER_RES, hyp.name + "_LSA.h5")
        if os.path.isfile(hyp_file) and os.path.isfile(model_config_file):
            # Reuse the results of a previous run:
            model_configuration = reader.read_model_configuration(model_config_file)
            lsa_hypothesis = reader.read_hypothesis(hyp_file)
        else:
            model_configuration, lsa_hypothesis, model_configuration_builder, lsa_service = \
                from_hypothesis_to_model_config_lsa(hyp, head, eigen_vectors_number=None,
                                                    weighted_eigenvector_sum=True, config=config, K=K_DEF)

        dynamical_model = "EpileptorDP2D"

        # --------------------------  Get model_data and observation signals: ----------------------------------
        model_inversion_file = os.path.join(config.out.FOLDER_RES, hyp.name + "_ModelInversionService.h5")
        stats_model_file = os.path.join(config.out.FOLDER_RES, hyp.name + "_StatsModel.h5")
        model_data_file = os.path.join(config.out.FOLDER_RES, hyp.name + "_ModelData.h5")
        if os.path.isfile(model_inversion_file) and os.path.isfile(stats_model_file) \
                and os.path.isfile(model_data_file):
            # Everything needed for the fit has already been written by a previous run:
            model_inversion = reader.read_model_inversions_service(model_inversion_file)
            statistical_model = reader.read_generic(stats_model_file)
            model_data = stan_service.load_model_data_from_file(model_data_path=model_data_file)
        else:
            model_inversion = SDEModelInversionService(model_configuration, lsa_hypothesis, head,
                                                       dynamical_model, x1eq_max=-1.0,
                                                       priors_mode="uninformative")
            # observation_expression="lfp"
            statistical_model = model_inversion.generate_statistical_model(x1eq_max=-1.0,
                                                                           observation_model="seeg_logpower")
            statistical_model = model_inversion.update_active_regions(statistical_model,
                                                                      methods=["e_values", "LSA"],
                                                                      active_regions_th=0.1, reset=True)
            plotter.plot_statistical_model(statistical_model, "Statistical Model")
            cut_signals_tails = (6, 6)
            n_electrodes = 8
            sensors_per_electrode = 2
            if os.path.isfile(EMPIRICAL):
                # ---------------------------------------Get empirical data-------------------------------------
                target_data_type = "empirical"
                statistical_model.observation_model = "seeg_logpower"
                decimate = 2
                ts_file = os.path.join(config.out.FOLDER_RES, hyp.name + "_ts_empirical.mat")
                try:
                    # Prefer the already prepared time series, if they can be loaded:
                    vois_ts_dict = loadmat(ts_file)
                    time = vois_ts_dict["time"].flatten()
                    sensors_inds = np.array(vois_ts_dict["sensors_inds"]).flatten().tolist()
                    sensors_lbls = np.array(vois_ts_dict["sensors_lbls"]).flatten().tolist()
                    vois_ts_dict.update({"time": time,
                                         "sensors_inds": sensors_inds,
                                         "sensors_lbls": sensors_lbls})
                    savemat(ts_file, vois_ts_dict)
                except Exception:
                    # Best-effort fallback: any failure to load the cached file means that the
                    # time series have to be prepared from the empirical recording.
                    # (Exception instead of a bare except, so that KeyboardInterrupt/SystemExit propagate.)
                    if len(sensors_lbls) == 0:
                        # Default labels are those of the selected sensors.
                        # (The original assigned them to a misspelled, unused name.)
                        sensors_lbls = head.get_sensors_id().labels[sensors_inds]
                    signals, time, fs = prepare_seeg_observable(EMPIRICAL, times_on_off, sensors_lbls,
                                                                plot_flag=True, log_flag=True)
                    if len(sensors_inds) > 1:  # get_bipolar_channels(sensors_inds, sensors_lbls)
                        sensors_inds, sensors_lbls = \
                            head.get_sensors_id().get_bipolar_sensors(sensors_inds=sensors_inds)
                    # Sort sensors by index and scatter their signals into an array
                    # that has one column per sensor label of the model inversion service:
                    inds = np.argsort(sensors_inds)
                    sensors_inds = np.array(sensors_inds)[inds].flatten().tolist()
                    sensors_lbls = np.array(sensors_lbls)[inds].flatten().tolist()
                    all_signals = np.zeros((signals.shape[0], len(model_inversion.sensors_labels)))
                    all_signals[:, sensors_inds] = signals[:, inds]
                    signals = all_signals
                    del all_signals
                    vois_ts_dict = {"time": time.flatten(),
                                    "signals": signals,
                                    "sensors_inds": sensors_inds,
                                    "sensors_lbls": sensors_lbls}
                    savemat(ts_file, vois_ts_dict)
                model_inversion.sensors_labels[vois_ts_dict["sensors_inds"]] = sensors_lbls
                manual_selection = sensors_inds
            else:
                # --------------------------  Get simulated data (simulate if necessary) -----------------------
                target_data_type = "simulated"
                statistical_model.observation_model = "seeg_logpower"  # "seeg_logpower"  # "lfp_power"
                decimate = 1
                ts_file = os.path.join(config.out.FOLDER_RES, hyp.name + "_ts.h5")
                vois_ts_dict = \
                    from_model_configuration_to_simulation(model_configuration, head, lsa_hypothesis,
                                                           sim_type="fitting", dynamical_model=dynamical_model,
                                                           ts_file=ts_file, plot_flag=True, config=config)
                # if len(sensors_inds) > 1:  # get_bipolar_channels(sensors_inds, sensors_lbls)
                #     sensors_inds, sensors_lbls = \
                #         head.get_sensors_id().get_bipolar_sensors(sensors_inds=sensors_inds)
                if statistical_model.observation_model.find("seeg") >= 0:
                    manual_selection = sensors_inds
                else:
                    if stats_model_name.find("vep-fe-rev") >= 0:
                        manual_selection = statistical_model.active_regions
                    else:
                        manual_selection = []

            # -------------------------- Select and set observation signals -----------------------------------
            signals, time, statistical_model, vois_ts_dict = \
                model_inversion.set_target_data_and_time(target_data_type, vois_ts_dict, statistical_model,
                                                         decimate=decimate,
                                                         cut_signals_tails=cut_signals_tails,
                                                         manual_selection=manual_selection,
                                                         auto_selection="correlation-power",
                                                         # auto_selection=False,
                                                         n_electrodes=n_electrodes,
                                                         sensors_per_electrode=sensors_per_electrode,
                                                         group_electrodes=True,
                                                         normalization="baseline-amplitude",
                                                         )
            # if len(model_inversion.signals_inds) < head.get_sensors_id().number_of_sensors:
            #     statistical_model = \
            #         model_inversion.update_active_regions_seeg(statistical_model)

            if model_inversion.data_type == "lfp":
                labels = model_inversion.region_labels
            else:
                labels = model_inversion.sensors_labels
            if vois_ts_dict.get("signals", None) is not None:
                # Rescale all signals (in place) for the raster plot:
                vois_ts_dict["signals"] -= vois_ts_dict["signals"].min()
                vois_ts_dict["signals"] /= vois_ts_dict["signals"].max()
                plotter.plot_raster({'Target Signals': vois_ts_dict["signals"]},
                                    vois_ts_dict["time"].flatten(), time_units="ms",
                                    title=hyp.name + ' Target Signals raster',
                                    special_idx=model_inversion.signals_inds, offset=0.1, labels=labels)
            plotter.plot_timeseries({'Target Signals': signals}, time, time_units="ms",
                                    title=hyp.name + ' Target Signals',
                                    labels=labels[model_inversion.signals_inds])
            writer.write_model_inversion_service(
                model_inversion, os.path.join(config.out.FOLDER_RES, hyp.name + "_ModelInversionService.h5"))
            writer.write_generic(statistical_model, config.out.FOLDER_RES, hyp.name + "_StatsModel.h5")
            # try:
            #     model_data = stan_service.load_model_data_from_file()
            # except:
            model_data = build_stan_model_dict(statistical_model, signals, model_inversion)
            writer.write_dictionary(model_data,
                                    os.path.join(config.out.FOLDER_RES, hyp.name + "_ModelData.h5"))

        # Ground truth values are only available when the target data are simulated:
        if os.path.isfile(EMPIRICAL):
            simulation_values = None
        else:
            simulation_values = {"x0": model_configuration.x0,
                                 "x1eq": model_configuration.x1EQ,
                                 "x1init": model_configuration.x1EQ,
                                 "zinit": model_configuration.zEQ}

        # Stupid code to interface with INS stan model
        # NOTE: k_str and estMC are referenced only by the commented-out
        # "Reconfigure model after fitting" section further below.
        if stats_model_name.find("vep-fe-rev") >= 0:
            model_data = build_stan_model_dict_to_interface_ins(model_data, statistical_model, model_inversion)
            x1_str = "x"
            input_signals_str = "seeg_log_power"
            signals_str = "mu_seeg_log_power"
            dX1t_str = "x_eta"
            dZt_str = "z_eta"
            sig_str = "sigma"
            k_str = "k"
            pair_plot_params = ["time_scale", "k", "sigma", "epsilon", "amplitude", "offset"]
            region_violin_params = ["x0", "x_init", "z_init"]
            if simulation_values is not None:
                # Keep only the active regions' values, under the names used by this model:
                simulation_values.update(
                    {"x0": simulation_values["x0"][statistical_model.active_regions]})
                simulation_values.update(
                    {"x_init": simulation_values["x1init"][statistical_model.active_regions]})
                simulation_values.update(
                    {"z_init": simulation_values["zinit"][statistical_model.active_regions]})
            connectivity_plot = False
            estMC = lambda: model_configuration.model_connectivity
            region_mode = "active"
        else:
            x1_str = "x1"
            input_signals_str = "signals"
            signals_str = "fit_signals"
            dX1t_str = "dX1t"  # "x1_dWt"
            dZt_str = "dZt"  # "z_dWt"
            sig_str = "sig"
            k_str = "K"
            pair_plot_params = ["tau1", "tau0", "K", "sig_init", "sig", "eps", "scale_signal", "offset_signal"]
            region_violin_params = ["x0", "x1eq", "x1init", "zinit"]
            connectivity_plot = False
            estMC = lambda est: est["MC"]
            region_mode = "all"

        # -------------------------- Fit and get estimates: ----------------------------------------------------
        ests, samples, summary = stan_service.fit(debug=0, simulate=0, model_data=model_data,
                                                  merge_outputs=False, chains=1, refresh=1,
                                                  num_warmup=5, num_samples=5, max_depth=7, delta=0.8,
                                                  **kwargs)
        writer.write_generic(ests, config.out.FOLDER_RES, hyp.name + "_fit_est.h5")
        writer.write_generic(samples, config.out.FOLDER_RES, hyp.name + "_fit_samples.h5")
        # R_hat must be bound even when there is no (dict) summary,
        # because it is always passed to the plotter below:
        R_hat = None
        if summary is not None:
            writer.write_generic(summary, config.out.FOLDER_RES, hyp.name + "_fit_summary.h5")
            if isinstance(summary, dict):
                R_hat = summary.get("R_hat", None)
                if R_hat is not None:
                    R_hat = {"R_hat": R_hat}
        ests = ensure_list(ests)
        plotter.plot_fit_results(model_inversion, ests, samples, statistical_model,
                                 model_data[input_signals_str], R_hat, model_data["time"],
                                 simulation_values, region_mode,
                                 seizure_indices=lsa_hypothesis.get_regions_disease_indices(),
                                 x1_str=x1_str, signals_str=signals_str, sig_str=sig_str,
                                 dX1t_str=dX1t_str, dZt_str=dZt_str, trajectories_plot=True,
                                 connectivity_plot=connectivity_plot, pair_plot_params=pair_plot_params,
                                 region_violin_params=region_violin_params)

        # -------------------------- Reconfigure model after fitting:-------------------------------------------
        # for id_est, est in enumerate(ensure_list(ests)):
        #     fit_model_configuration_builder = \
        #         ModelConfigurationBuilder(hyp.number_of_regions, K=est[k_str] * hyp.number_of_regions)
        #     x0_values_fit = \
        #         fit_model_configuration_builder._compute_x0_values_from_x0_model(est['x0'])
        #     hyp_fit = HypothesisBuilder().set_nr_of_regions(head.connectivity.number_of_regions).set_name(
        #         'fit' + str(id_est) + "_" + hyp.name)._build_excitability_hypothesis(x0_values_fit, range(
        #         model_configuration.number_of_regions))
        #     model_configuration_fit = fit_model_configuration_builder.build_model_from_hypothesis(hyp_fit,
        #                                                                                  # est["MC"]
        #                                                                                  estMC(est))
        #     writer.write_model_configuration(model_configuration_fit,
        #                                      os.path.join(config.out.FOLDER_RES, hyp_fit.name + "_ModelConfig.h5"))
        #
        #     # Plot nullclines and equilibria of model configuration
        #     plotter.plot_state_space(model_configuration_fit,
        #                              region_labels=model_inversion.region_labels,
        #                              special_idx=statistical_model.active_regions,
        #                              model="6d", zmode="lin",
        #                              figure_name=hyp_fit.name + "_Nullclines and equilibria")
        logger.info("Done!")
from tvb_epilepsy.io.h5_reader import H5Reader from tvb_epilepsy.plot.plotter import Plotter from tvb_epilepsy.service.hypothesis_builder import HypothesisBuilder from tvb_epilepsy.service.model_configuration_builder import ModelConfigurationBuilder from tvb_epilepsy.service.model_inversion.sde_model_inversion_service import SDEModelInversionService from tvb_epilepsy.service.model_inversion.stan.cmdstan_service import CmdStanService from tvb_epilepsy.service.model_inversion.stan.pystan_service import PyStanService from tvb_epilepsy.service.model_inversion.vep_stan_dict_builder import build_stan_model_dict, \ build_stan_model_dict_to_interface_ins from tvb_epilepsy.top.scripts.hypothesis_scripts import from_head_to_hypotheses, from_hypothesis_to_model_config_lsa from tvb_epilepsy.top.scripts.simulation_scripts import from_model_configuration_to_simulation from tvb_epilepsy.top.scripts.seeg_data_scripts import prepare_seeg_observable output = os.path.join(os.path.expanduser("~"), 'Dropbox', 'Work', 'VBtech', 'VEP', "results", "fit") config = Config(output_base=output) logger = initialize_logger(__name__, config.out.FOLDER_LOGS) reader = H5Reader() writer = H5Writer() plotter = Plotter(config) def main_fit_sim_hyplsa(ep_name="ep_l_frontal_complex", stats_model_name="vep_sde", EMPIRICAL="", times_on_off=[], sensors_lbls=[], sensors_inds=[], fitmethod="optimizing",
def sensitivity_analysis_pse_from_lsa_hypothesis(n_samples, lsa_hypothesis, connectivity_matrix,
                                                 model_configuration_builder, lsa_service, region_labels,
                                                 method="sobol", half_range=0.1, global_coupling=[],
                                                 healthy_regions_parameters=[], save_services=False,
                                                 config=Config(), **kwargs):
    """Run a parameter search exploration (PSE) of LSA and a sensitivity analysis on its output.

    Inputs of the sensitivity analysis are the x0, E and w values of the LSA
    hypothesis (each explored within +/- ``half_range``) plus the optional
    global coupling parameters. Healthy regions' parameters may additionally be
    jittered during the PSE, but they are not inputs of the sensitivity analysis.
    The analysed output is the LSA propagation strength.

    :param n_samples: number of samples requested from the sampling service
    :param lsa_hypothesis: hypothesis providing x0/e/w values and indices
    :param connectivity_matrix: connectivity matrix passed to the PSE
    :param model_configuration_builder: builder passed to the PSE
    :param lsa_service: LSA service passed to the PSE
    :param region_labels: labels used for naming the input parameters
    :param method: sensitivity analysis method, one of ``METHODS`` (case insensitive)
    :param half_range: half width of the explored range around each hypothesis value
    :param global_coupling: list of dicts with optional keys "indices", "low", "high"
    :param healthy_regions_parameters: list of dicts with optional keys "indices", "name"
    :param save_services: if True, the PSE and sensitivity analysis services are logged
                          and written to ``config.out.FOLDER_RES``
    :param config: configuration object providing the output folders
    :param kwargs: optional "random_seed", "loc", "scale", "calc_second_order",
                   "conf_level"; all kwargs are also forwarded to the samplers' and
                   the sensitivity analysis service's calls
    :return: tuple ``(results, pse_results)``
    :raises: whatever ``raise_value_error`` raises for a method not in ``METHODS``
             (presumably ValueError)
    """
    logger = initialize_logger(__name__, config.out.FOLDER_LOGS)

    method = method.lower()
    if method not in METHODS:
        raise_value_error("Method " + str(method) + " is not one of the available methods " + str(METHODS) + " !")
    # Each method comes with its own sampling scheme:
    if method in ("delta", "dgsm"):
        sampler_name = "latin"
    elif method == "sobol":
        sampler_name = "saltelli"
    elif method == "fast":
        sampler_name = "fast_sampler"
    else:
        sampler_name = method

    all_regions_indices = range(lsa_hypothesis.number_of_regions)
    disease_indices = lsa_hypothesis.regions_disease_indices
    healthy_indices = np.delete(all_regions_indices, disease_indices).tolist()

    pse_params = {"path": [], "indices": [], "name": [], "low": [], "high": []}

    # First build from the hypothesis the input parameters of the sensitivity analysis.
    # These can be either originating from excitability, epileptogenicity or connectivity hypotheses,
    # or they can relate to the global coupling scaling (parameter K of the model configuration)
    for ii, x0_value in enumerate(lsa_hypothesis.x0_values):
        pse_params["indices"].append([ii])
        pse_params["path"].append("hypothesis.x0_values")
        pse_params["name"].append(str(region_labels[lsa_hypothesis.x0_indices[ii]]) + " Excitability")
        pse_params["low"].append(x0_value - half_range)
        pse_params["high"].append(np.min([MAX_DISEASE_VALUE, x0_value + half_range]))

    for ii, e_value in enumerate(lsa_hypothesis.e_values):
        pse_params["indices"].append([ii])
        pse_params["path"].append("hypothesis.e_values")
        pse_params["name"].append(str(region_labels[lsa_hypothesis.e_indices[ii]]) + " Epileptogenicity")
        pse_params["low"].append(e_value - half_range)
        pse_params["high"].append(np.min([MAX_DISEASE_VALUE, e_value + half_range]))

    for ii, w_value in enumerate(lsa_hypothesis.w_values):
        pse_params["indices"].append([ii])
        pse_params["path"].append("hypothesis.w_values")
        inds = linear_index_to_coordinate_tuples(lsa_hypothesis.w_indices[ii], connectivity_matrix.shape)
        if len(inds) == 1:
            # A single connection is named after the labels of both of its regions.
            # NOTE(review): assumes each tuple is a (row, column) pair of the connectivity matrix.
            name = str(region_labels[inds[0][0]]) + "-" + str(region_labels[inds[0][1]]) + " Connectivity"
        else:
            name = "Connectivity[" + str(inds) + "]"
        pse_params["name"].append(name)
        # Connectivity weights are not allowed to become negative:
        pse_params["low"].append(np.max([w_value - half_range, 0.0]))
        pse_params["high"].append(w_value + half_range)

    for val in global_coupling:
        pse_params["path"].append("model.configuration.service.K_unscaled")
        inds = val.get("indices", all_regions_indices)
        # array_equal compares any sequence types and never fails on mismatched lengths:
        if np.array_equal(inds, all_regions_indices):
            pse_params["name"].append("Global coupling")
        else:
            pse_params["name"].append("Afferent coupling[" + str(inds) + "]")
        pse_params["indices"].append(inds)
        pse_params["low"].append(val.get("low", 0.0))
        pse_params["high"].append(val.get("high", 2.0))

    # Now generate samples suitable for sensitivity analysis
    sa_sampler = SalibSamplingService(n_samples=n_samples, sampler=sampler_name,
                                      random_seed=kwargs.get("random_seed", None))
    input_samples = sa_sampler.generate_samples(low=pse_params["low"], high=pse_params["high"], **kwargs)
    # The sampling scheme determines the number of samples actually generated:
    n_samples = input_samples.shape[1]
    pse_params.update({"samples": [np.array(value) for value in input_samples.tolist()]})

    pse_params_list = dicts_of_lists_to_lists_of_dicts(pse_params)

    # Add a random jitter to the healthy regions if required...:
    jitter_sampler = ProbabilisticSamplingService(n_samples=n_samples,
                                                  random_seed=kwargs.get("random_seed", None))
    for val in healthy_regions_parameters:
        inds = val.get("indices", healthy_indices)
        name = val.get("name", "x0_values")
        samples = jitter_sampler.generate_samples(parameter=(kwargs.get("loc", 0.0),  # loc
                                                             kwargs.get("scale", 2 * half_range)),  # scale
                                                  probability_distribution="uniform", low=0.0,
                                                  shape=(len(inds),))
        for ii, ind in enumerate(inds):
            pse_params_list.append({"path": "model_configuration_builder." + name,
                                    "samples": samples[ii],
                                    "indices": [ind],
                                    "name": name})

    # Now run pse service to generate output samples:
    pse = LSAPSEService(hypothesis=lsa_hypothesis, params_pse=pse_params_list)
    pse_results, execution_status = pse.run_pse(connectivity_matrix, False, model_configuration_builder,
                                                lsa_service)
    pse_results = list_of_dicts_to_dicts_of_ndarrays(pse_results)

    # Now prepare inputs and outputs and run the sensitivity analysis:
    # NOTE!: Without the jittered healthy regions which we don't want to include into the sensitivity analysis!
    inputs = dicts_of_lists_to_lists_of_dicts(pse_params)
    outputs = [{"names": ["LSA Propagation Strength"],
                "values": pse_results["lsa_propagation_strengths"]}]
    sensitivity_analysis_service = SensitivityAnalysisService(
        inputs, outputs, method=method,
        calc_second_order=kwargs.get("calc_second_order", True),
        conf_level=kwargs.get("conf_level", 0.95))

    results = sensitivity_analysis_service.run(**kwargs)

    if save_services:
        logger.info(pse.__repr__())
        writer = H5Writer()
        writer.write_pse_service(pse, os.path.join(config.out.FOLDER_RES, method + "_test_pse_service.h5"))

        logger.info(sensitivity_analysis_service.__repr__())
        writer.write_sensitivity_analysis_service(
            sensitivity_analysis_service, os.path.join(config.out.FOLDER_RES, method + "_test_sa_service.h5"))

    return results, pse_results