def update_active_regions_target_data(self, target_data, probabilistic_model, sensors, reset=False):
    """Add active regions selected from the gain matrix rows of the target data sensors.

    The sensors whose labels are in ``target_data.space_labels`` are looked up, and for each
    of their gain matrix rows the indices returned by
    ``select_greater_values_array_inds(row, self.gain_matrix_th, self.n_signals_per_roi)``
    are appended to the model's active regions. The result is filtered by
    ``self.exclude_regions`` and stored in the model. Finally the model's gain matrix is set
    to the rows of the selected sensors and the columns of the active regions.

    :param target_data: object exposing ``space_labels``; when falsy only a warning is issued
    :param probabilistic_model: object with ``active_regions``, ``update_active_regions()``
                                and ``gain_matrix``; it is modified in place
    :param sensors: object with ``gain_matrix`` and ``get_sensors_inds_by_sensors_labels()``
    :param reset: when True the model's active regions are emptied first
    :return: probabilistic_model
    """
    if reset:
        probabilistic_model.update_active_regions([])
    if not target_data:
        # There are no sensor indices to slice the gain matrix with in this case. The
        # original code fell through to the gain matrix assignment and raised a NameError.
        warning("Skipping active regions setting by seeg power because no target data were provided!")
        return probabilistic_model
    active_regions = probabilistic_model.active_regions.tolist()
    gain_matrix = np.array(sensors.gain_matrix)
    signals_inds = sensors.get_sensors_inds_by_sensors_labels(target_data.space_labels)
    if len(signals_inds) != 0:
        for proj in gain_matrix[signals_inds]:
            active_regions += select_greater_values_array_inds(proj, self.gain_matrix_th,
                                                               self.n_signals_per_roi).tolist()
        active_regions = self.exclude_regions(active_regions)
        probabilistic_model.update_active_regions(active_regions)
    else:
        warning("Skipping active regions setting by seeg power because no data were assigned to sensors!")
    probabilistic_model.gain_matrix = sensors.gain_matrix[signals_inds][:, probabilistic_model.active_regions]
    return probabilistic_model
def print_not_equal_message(attr, field1, field2, logger):
    """Issue a warning that a field differs between an original and a read object.

    :param attr: name (or index) of the compared field; converted with str(), so that
                 integer indices do not break the message concatenation
    :param field1: field value of the original object
    :param field2: field value of the read object
    :param logger: logger passed on to warning()
    """
    warning("Original and read object field " + str(attr) + " not equal!" +
            "\nOriginal field:\n" + str(field1) +
            "\nRead object field:\n" + str(field2), logger)
def get_truth(self, parameter_name):
    """Return the ground truth value of a parameter, or np.nan if none is available.

    The value is looked up in ``self.ground_truth``. If it is missing and the target data
    are synthetic, the attribute of the same name of ``self.model_config`` is used instead.
    A warning is issued when nothing is found.

    :param parameter_name: name of the parameter
    :return: the ground truth value, or np.nan
    """
    not_found = np.nan  # used as an identity sentinel, exactly like np.nan itself
    value = self.ground_truth.get(parameter_name, not_found)
    if value is not_found and self.target_data_type == Target_Data_Type.SYNTHETIC.value:
        value = getattr(self.model_config, parameter_name, not_found)
    if value is not_found:
        warning("Ground truth value for parameter " + parameter_name + " was not found!")
    return value
def _region_parameters_violin_plots(self, samples_all, values=None, lines=None, stats=None, params=[],
                                    skip_samples=0, per_chain_or_run=False, labels=[], special_idx=None,
                                    figure_name="Regions parameters samples",
                                    figsize=FiguresConfig.VERY_LARGE_SIZE):
    """Plot the samples of region-wise parameters with self.plot_vector_violin().

    One figure is made per chain/run, or a single figure for all of them merged when
    ``per_chain_or_run`` is False and there is more than one chain/run. Every figure gets
    one subplot column per parameter.

    :param samples_all: a samples dictionary or a list of them (one per chain/run)
    :param values: optional dict, parameter name -> values passed to plot_vector_violin()
    :param lines: optional dict, parameter name -> lines passed to plot_vector_violin()
    :param stats: statistics passed to self._params_stats_labels()
    :param params: names of the parameters to plot; names missing from the first samples
                   dictionary are dropped
    :param skip_samples: number of leading samples to leave out
    :param per_chain_or_run: if True, plot each chain/run in its own figure
    :param labels: region labels; generated with generate_region_labels() when empty
    :param special_idx: indices passed as ``indices_red`` to plot_vector_violin()
    :param figure_name: base name of the figure(s)
    :param figsize: figure size
    :return: tuple with the value returned by the last plot_vector_violin() call of each figure
    """

    def vals_fun(param):
        if isinstance(values, dict):
            return values.get(param, numpy.array([]))
        return []

    def lines_fun(param):
        if isinstance(lines, dict):
            return lines.get(param, numpy.array([]))
        return []

    samples_all = ensure_list(samples_all)
    params = [param for param in params if param in samples_all[0].keys()]
    samples = [extract_dict_stringkeys(sample, params, modefun="equal") for sample in samples_all]
    # list(...) keeps the indexing valid also where dict.values() returns a view
    first_samples = list(samples[0].values())[0]
    if len(labels) == 0:
        labels = generate_region_labels(first_samples.shape[-1], labels, numbering=False)
    n_chains = len(samples_all)
    n_samples = first_samples.shape[0]
    violin_flag = n_samples > 1

    if not per_chain_or_run and n_chains > 1:
        samples = [merge_samples(samples)]

        def plot_samples(s):
            return numpy.concatenate(numpy.split(s[:, skip_samples:].T, n_chains, axis=2), axis=1).squeeze().T

        def plot_figure_name(ichain):
            return figure_name
    else:
        def plot_samples(s):
            return s[skip_samples:]

        def plot_figure_name(ichain):
            return figure_name + ": chain " + str(ichain + 1)

    # Only the first parameter column gets the region labels
    params_labels = {}
    for ip, p in enumerate(params):
        if ip == 0:
            params_labels[p] = self._params_stats_labels(p, stats, labels)
        else:
            params_labels[p] = self._params_stats_labels(p, stats, "")

    n_params = len(params)
    if n_params > 9:
        warning("Number of subplots in column wise vector-violin-plots cannot be > 9 and it is "
                + str(n_params) + "!")
    # Three digit subplot code: 1 row, n_params columns; the subplot index is added below
    subplot_ind = 100 + n_params * 10

    figs = []
    for ichain, chain_sample in enumerate(samples):
        pyplot.figure(plot_figure_name(ichain), figsize=figsize)
        for ip, param in enumerate(params):
            fig = self.plot_vector_violin(plot_samples(chain_sample[param]), vals_fun(param),
                                          lines_fun(param), params_labels[param],
                                          subplot_ind + ip + 1, param, violin_flag=violin_flag,
                                          colormap="YlOrRd", show_y_labels=True,
                                          indices_red=special_idx, sharey=None)
        self._save_figure(pyplot.gcf(), None)
        self._check_show()
        figs.append(fig)
    return tuple(figs)
def get_prior_pdf(self, parameter_name):
    """Return the prior pdf of a probabilistic parameter.

    :param parameter_name: name of the parameter
    :return: ``parameter.scipy_method("pdf")`` if self.get_prior() finds a (transformed)
             probabilistic parameter; otherwise a warning is issued and the tuple
             (prior value, np.nan) is returned
    """
    parameter_mean, parameter = self.get_prior(parameter_name)
    probabilistic_types = (ProbabilisticParameterBase, TransformedProbabilisticParameterBase)
    if not isinstance(parameter, probabilistic_types):
        warning("No probabilistic parameter " + parameter_name + " was found!"
                "\nReturning prior value, if available, instead of pdf!")
        return parameter_mean, np.nan
    return parameter.scipy_method("pdf")
def read_output(self):
    """Read the samples, compute estimates from them and read the summary, if there is one.

    :return: tuple (estimates, samples, summary); summary is None when
             ``self.summary_filepath`` is not a file or cannot be parsed
    """
    samples = self.read_output_samples(self.output_filepath)
    est = self.compute_estimates_from_samples(samples)
    summary = None
    if os.path.isfile(self.summary_filepath):
        try:
            summary = parse_csv_in_cols(self.summary_filepath)
        except Exception:
            # Best effort: an unreadable summary must not invalidate the samples.
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
            warning("Reading stan summary failed!")
    return est, samples, summary
def convert_params_names_to_ins(dicts_list, parameter_names=INS_PARAMS_NAMES_DICT):
    """Store parameters of the given dictionaries also under their INS names.

    Every element of ``dicts_list`` (after ensure_list()) is itself passed through
    ensure_list(), and for each resulting dictionary the value found under a key ``p`` of
    ``parameter_names`` is additionally stored under ``parameter_names[p]``.
    The dictionaries are modified in place.

    :param dicts_list: dictionaries, or lists of dictionaries, to convert
    :param parameter_names: dict mapping parameter names to their INS names
    :return: tuple of the (modified) elements of dicts_list
    """
    output = []
    for lst in ensure_list(dicts_list):
        for dct in ensure_list(lst):
            for p, p_ins in parameter_names.items():
                try:
                    dct[p_ins] = dct[p]
                except Exception:
                    # Best effort: a parameter that is missing (or an element that is not a
                    # dictionary) only triggers a warning. Exception instead of a bare
                    # except, so that KeyboardInterrupt/SystemExit are not swallowed.
                    warning("Parameter " + p + " not found in \n" + str(dicts_list))
        output.append(lst)
    return tuple(output)
def build_model_config_from_model_config(self, model_config):
    """Build a ModelConfiguration from another model configuration or from a dictionary.

    Every attribute of a freshly constructed ModelConfiguration is overwritten with the
    value of the same name found in the input, if that value is not None. Otherwise the
    default is kept and a warning is issued.

    :param model_config: a dictionary, or an object whose __dict__ is used
    :return: the new ModelConfiguration instance
    """
    if isinstance(model_config, dict):
        source = model_config
    else:
        source = model_config.__dict__
    model_configuration = ModelConfiguration()
    for attr, _ in model_configuration.__dict__.items():
        new_value = source.get(attr, None)
        if new_value is not None:
            setattr(model_configuration, attr, new_value)
        else:
            warning(attr + " not found in the input model configuration dictionary!" +
                    "\nLeaving default " + attr + ": " + str(getattr(model_configuration, attr)))
    return model_configuration
def get_prior(self, parameter_name):
    """Return the prior mean of a parameter together with the parameter itself.

    :param parameter_name: name of the parameter
    :return: (parameter.mean, parameter) if self.get_parameter() finds the parameter;
             otherwise (value, None), where value is the attribute of that name of self,
             or else of self.model_config, or else np.nan
    """
    parameter = self.get_parameter(parameter_name)
    if parameter is not None:
        return parameter.mean, parameter
    warning("No probabilistic prior for parameter " + parameter_name + " was found!")
    # TODO: decide if it is a good idea to return this kind of modeler's fixed "prior"...:
    not_found = np.nan  # identity sentinel for "no such attribute"
    pmean = getattr(self, parameter_name, not_found)
    if pmean is not_found:
        pmean = getattr(self.model_config, parameter_name, not_found)
    if pmean is not_found:
        warning("No prior value for parameter " + parameter_name + " was found!")
    return pmean, parameter
def __init__(self, input="EpileptorDP", connectivity=None, K_unscaled=np.array([K_UNSCALED_DEF]),
             x0_values=X0_DEF, e_values=E_DEF, x1eq_mode="optimize", **kwargs):
    """Set up the builder from a simulator, a model or a model name, plus a connectivity.

    :param input: a Simulator, a Model, or the name of a model as a string
    :param connectivity: a Connectivity, a TVBConnectivity or a numpy array of weights;
                         with a Simulator input the simulator's connectivity is used unless
                         one of these is given
    :param K_unscaled: global coupling, either a single value or one value per region
    :param x0_values: x0 value(s), broadcast to one value per region
    :param e_values: e value(s), broadcast to one value per region
    :param x1eq_mode: stored as self.x1eq_mode
    :param kwargs: optional values for the parameters named in EPILEPTOR_PARAMS
    """
    if isinstance(input, Simulator):
        # TODO: make this more specific once we clarify the model configuration representation compared to simTVB
        self.model_name = input.model._ui_name
        self.set_params_from_tvb_model(input.model)
        self.connectivity = normalize_weights(input.connectivity.weights)
        # self.coupling = input.coupling
        self.initial_conditions = np.squeeze(input.initial_conditions)  # initial conditions in a reduced form
        # self.noise = input.integrator.noise
        # self.monitor = ensure_list(input.monitors)[0]
    elif isinstance(input, Model):
        self.model_name = input._ui_name
        self.set_params_from_tvb_model(input)
    elif isinstance(input, basestring):
        self.model_name = input
    else:
        # The %s placeholder was never filled in the original message
        raise_value_error("Input (%s) is not a TVB simulator, an epileptor model, "
                          "\nor a string of an epileptor model!" % str(input))

    if isinstance(connectivity, Connectivity):
        self.connectivity = connectivity.normalized_weights
    elif isinstance(connectivity, TVBConnectivity):
        self.connectivity = normalize_weights(connectivity.weights)
    elif isinstance(connectivity, np.ndarray):
        self.connectivity = normalize_weights(connectivity)
    elif not isinstance(input, Simulator):
        warning("Input connectivity (%s) is not a virtual patient connectivity, a TVB connectivity, "
                "\nor a numpy.array!" % str(connectivity))

    self.x0_values = x0_values * np.ones((self.number_of_regions,), dtype=np.float32)
    self.x1eq_mode = x1eq_mode

    n_K = len(ensure_list(K_unscaled))
    if n_K == 1:
        K_unscaled = np.array(K_unscaled) * np.ones((self.number_of_regions,), dtype=np.float32)
    elif n_K == self.number_of_regions:
        K_unscaled = np.array(K_unscaled)
    else:
        self.logger.warning(
            "The length of input global coupling K_unscaled is neither 1 nor equal to the number of regions!"
            + "\nSetting model_configuration_builder.K_unscaled = K_UNSCALED_DEF for all regions")
        # Actually apply the fallback announced above (the original kept the invalid input)
        K_unscaled = K_UNSCALED_DEF * np.ones((self.number_of_regions,), dtype=np.float32)
    self.set_K_unscaled(K_unscaled)

    for pname in EPILEPTOR_PARAMS:
        self.set_parameter(pname, kwargs.get(pname, getattr(self, pname)))  # Update K_unscaled

    self.e_values = e_values * np.ones((self.number_of_regions,), dtype=np.float32)
    self.x0cr = 0.0
    self.rx0 = 0.0
    self._compute_critical_x0_scaling()
def _set_attributes_from_dict(self, attributes_dict): if not isinstance(attributes_dict, dict): attributes_dict = attributes_dict.__dict__ for attr, value in attributes_dict.items(): if not attr in [ "model_config", "parameters", "number_of_regions", "number_of_parameters" ]: value = attributes_dict.get(attr, None) if value is None: warning(attr + " not found in input dictionary!" + "\nLeaving as it is: " + attr + " = " + str(getattr(self, attr))) if value is not None: setattr(self, attr, value) return attributes_dict
def update_active_regions_x0_values(self, probabilistic_model, x0_values, reset=False):
    """Add to the model's active regions those selected from the given x0 values.

    The indices returned by ``select_greater_values_array_inds(x0_values, self.active_x0_th)``
    are appended to the current active regions (or to an empty list if ``reset``), the
    result is filtered by ``self.exclude_regions`` and stored in the model.

    :param probabilistic_model: object with ``active_regions`` and ``update_active_regions()``
    :param x0_values: x0 values; when empty only a warning is issued
    :param reset: when True the current active regions are discarded
    :return: probabilistic_model
    """
    current = probabilistic_model.active_regions.tolist()
    selected = [] if reset else current
    if len(x0_values) == 0:
        warning("Skipping active regions setting by x0 values because no such values were provided!")
        return probabilistic_model
    selected += select_greater_values_array_inds(x0_values, self.active_x0_th).tolist()
    probabilistic_model.update_active_regions(self.exclude_regions(selected))
    return probabilistic_model
def _write_dicts_at_location(self, datasets_dict, metadata_dict, location): for key, value in datasets_dict.items(): try: location.create_dataset(key, data=value) except: warning( "Failed to write to %s dataset %s %s:\n%s !" % (str(location), value.__class__, key, str(value)), self.logger) for key, value in metadata_dict.items(): try: location.attrs.create(key, value) except: warning( "Failed to write to %s attribute %s %s:\n%s !" % (str(location), value.__class__, key, str(value)), self.logger) return location
def _determine_datasets_and_attributes(self, object, datasets_size=None): datasets_dict = {} metadata_dict = {} groups_keys = [] try: if isinstance(object, dict): dict_object = object else: dict_object = vars(object) for key, value in dict_object.items(): if isinstance(value, numpy.ndarray): # if value.size == 1: # metadata_dict.update({key: value}) # else: datasets_dict.update({key: value}) # if datasets_size is not None and value.size == datasets_size: # datasets_dict.update({key: value}) # else: # if datasets_size is None and value.size > 0: # datasets_dict.update({key: value}) # else: # metadata_dict.update({key: value}) # TODO: check how this works! Be carefull not to include lists and tuples if possible in tvb_fit classes! elif isinstance(value, (list, tuple)): warning( "Writing %s %s to h5 file as a numpy array dataset !" % (value.__class__, key), self.logger) datasets_dict.update({key: numpy.array(value)}) else: if is_numeric(value) or isinstance(value, str): metadata_dict.update({key: value}) elif not (callable(value)): groups_keys.append(key) except: msg = "Failed to decompose group object: " + str(object) + "!" try: self.logger.info(str(object.__dict__)) except: msg += "\n It has no __dict__ attribute!" warning(msg, self.logger) return datasets_dict, metadata_dict, groups_keys
def plot_bars(self, data, ax=None, fig=None, title="", group_names=[], legend_prefix="",
              figsize=FiguresConfig.VERY_LARGE_SIZE):
    """Plot grouped bars, with the value of every bar printed as a rotated text label.

    :param data: a list/tuple with one sequence per group (padded with nan to equal
                 lengths), or a numpy array; a 1D array is treated as a single group
    :param ax: axes to draw on; only used together with ``fig``
    :param fig: existing figure; if None a new figure and axes are created, and the figure
                is saved and shown at the end
    :param title: axes title
    :param group_names: one tick label per group; ignored (with a warning) if their number
                        does not match the number of groups
    :param legend_prefix: prefix of the legend entries, which are numbered from 1
    :param figsize: size of a newly created figure
    :return: tuple (fig, ax)
    """
    try:
        from itertools import izip_longest as zip_longest  # Python 2
    except ImportError:
        from itertools import zip_longest  # Python 3

    def barlabel(ax, rects, positions):
        """
        Attach a text label on each bar displaying its height
        """
        for rect, pos in zip(rects, positions):
            height = rect.get_height()
            if pos < 0:
                y = -height
                pos = 0.75 * pos
            else:
                y = height
                pos = 0.25 * pos
            ax.text(rect.get_x() + rect.get_width() / 2., pos, '%0.2f' % y,
                    color="k", ha='center', va='bottom', rotation=90)

    if fig is None:
        fig, ax = pyplot.subplots(1, 1, figsize=figsize)
        show_and_save = True
    else:
        show_and_save = False
        if ax is None:
            ax = pyplot.gca()

    if isinstance(data, (list, tuple)):  # If, there are many groups, data is a list:
        # Fill in with nan in case that not all groups have the same number of elements
        data = numpy.array(list(zip_longest(*ensure_list(data), fillvalue=numpy.nan))).T
    elif data.ndim == 1:  # This is the case where there is only one group...
        data = numpy.expand_dims(data, axis=1).T
    n_groups, n_elements = data.shape

    posmax = numpy.nanmax(data)
    # NOTE(review): -(-data) is just data, so negmax always equals posmax; presumably the
    # most negative value was intended here - confirm before changing the label positions.
    negmax = numpy.nanmax(-(-data))

    n_groups_names = len(group_names)
    if n_groups_names != n_groups:
        if n_groups_names != 0:
            warning("Ignoring group_names because their number (" + str(n_groups_names) +
                    ") is not equal to the number of groups (" + str(n_groups) + ")!")
        group_names = n_groups * [""]

    colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
    n_colors = len(colorcycle)
    x_inds = numpy.arange(n_groups)
    width = 0.9 / n_elements
    elements = []
    for iE in range(n_elements):
        elements.append(ax.bar(x_inds + iE * width, data[:, iE], width, color=colorcycle[iE % n_colors]))
        positions = numpy.array([negmax if d < 0 else posmax for d in data[:, iE]])
        # No usable reference height: put the label at 0
        positions[numpy.logical_or(numpy.isnan(positions), numpy.isinf(numpy.abs(positions)))] = 0.0
        barlabel(ax, elements[-1], positions)

    if n_elements > 1:
        legend = [legend_prefix + str(ii) for ii in range(1, n_elements + 1)]
        ax.legend(tuple([element[0] for element in elements]), tuple(legend))

    ax.set_xticks(x_inds + n_elements * width / 2)
    ax.set_xticklabels(tuple(group_names))
    ax.set_title(title)
    ax.autoscale()  # tight=True
    ax.set_xlim([-1.05 * width, n_groups * 1.05])

    if show_and_save:
        fig.tight_layout()
        self._save_figure(fig)
        self._check_show()

    return fig, ax
def run_lsa(self, disease_hypothesis, model_configuration):
    """Run the linear stability analysis and build the corresponding LSA hypothesis.

    The jacobian returned by self._compute_jacobian() is eigen-decomposed, the real parts
    of the eigenvalues/eigenvectors are sorted (ascending, descending for the "2D" method)
    and stored in self.eigen_values / self.eigen_vectors. The propagation strength is the
    absolute (optionally eigenvalue-weighted) sum of the selected eigenvectors.

    Side effect: if self.lsa_method is "auto" it is replaced by "1D" or "2D".

    :param disease_hypothesis: hypothesis providing number_of_regions,
                               regions_disease_indices and name
    :param model_configuration: configuration providing x1eq, e_values and x0_values
    :return: the result of HypothesisBuilder.build_lsa_hypothesis()
    """
    if self.lsa_method == "auto":
        if numpy.any(model_configuration.x1eq > X1EQ_CR_DEF):
            self.lsa_method = "2D"
        else:
            self.lsa_method = "1D"

    if self.lsa_method == "2D" and numpy.all(model_configuration.x1eq <= X1EQ_CR_DEF):
        warning("LSA with the '2D' method (on the 2D Epileptor model) will not produce interpretable results when"
                " the equilibrium point of the system is not supercritical (unstable)!")

    jacobian = self._compute_jacobian(model_configuration)

    # Perform eigenvalue decomposition
    eigen_values, eigen_vectors = numpy.linalg.eig(jacobian)
    eigen_values = numpy.real(eigen_values)
    eigen_vectors = numpy.real(eigen_vectors)
    sorted_indices = numpy.argsort(eigen_values, kind='mergesort')
    if self.lsa_method == "2D":
        sorted_indices = sorted_indices[::-1]
    self.eigen_vectors = eigen_vectors[:, sorted_indices]
    self.eigen_values = eigen_values[sorted_indices]

    self._ensure_eigen_vectors_number(self.eigen_values[:disease_hypothesis.number_of_regions],
                                      model_configuration.e_values, model_configuration.x0_values,
                                      disease_hypothesis.regions_disease_indices)

    if self.eigen_vectors_number == disease_hypothesis.number_of_regions:
        # Calculate the propagation strength index by summing all eigenvectors
        lsa_propagation_strength = numpy.abs(numpy.sum(self.eigen_vectors, axis=1))
    else:
        # Calculate the propagation strength index by summing the first n eigenvectors (minimum 1).
        # The original stored this maximum in an unrelated variable and never used it,
        # so the "minimum 1" was not applied to the slices below.
        n_vectors = max(self.eigen_vectors_number, 1)
        if self.weighted_eigenvector_sum:
            lsa_propagation_strength = \
                numpy.abs(weighted_vector_sum(numpy.array(self.eigen_values[:n_vectors]),
                                              numpy.array(self.eigen_vectors[:, :n_vectors]),
                                              normalize=True))
        else:
            lsa_propagation_strength = numpy.abs(numpy.sum(self.eigen_vectors[:, :n_vectors], axis=1))

    if self.lsa_method == "2D":
        # Combine the two halves of the vector (split at number_of_regions) into one value
        # per region. Alternatives that were considered: keep only the first half, or take
        # the element-wise maximum of the two halves.
        n_regions = disease_hypothesis.number_of_regions
        lsa_propagation_strength = numpy.sqrt(lsa_propagation_strength[:n_regions] ** 2 +
                                              lsa_propagation_strength[n_regions:] ** 2)
        # NOTE(review): the flattened source does not show whether the next two statements
        # belong to this branch; they are kept inside it here - confirm.
        lsa_propagation_strength = numpy.log10(lsa_propagation_strength)
        lsa_propagation_strength -= lsa_propagation_strength.min()

    if self.normalize_propagation_strength:
        # Normalize by the maximum
        lsa_propagation_strength /= numpy.max(lsa_propagation_strength)

    # TODO: this has to be corrected: for eigen_vectors_number < 0.2 * number_of_regions the elbow
    # was meant to be the maximum of the curve elbow point and eigen_vectors_number
    propagation_strength_elbow = self.get_curve_elbow_point(lsa_propagation_strength)
    propagation_indices = lsa_propagation_strength.argsort()[-propagation_strength_elbow:]

    hypothesis_builder = HypothesisBuilder(disease_hypothesis.number_of_regions). \
        set_attributes_based_on_hypothesis(disease_hypothesis). \
        set_name(disease_hypothesis.name + "_LSA"). \
        set_lsa_propagation(propagation_indices, lsa_propagation_strength)

    return hypothesis_builder.build_lsa_hypothesis()
def assert_equal_objects(obj1, obj2, attributes_dict=None, logger=None):
    """Compare two objects field by field, warning about every field that differs.

    Strings, lists, dicts and non numeric arrays are compared with ``!=``; numeric arrays
    and scalars are compared after conversion to float32; any other field is compared
    recursively. If a comparison raises, the str() representations are compared instead.

    :param obj1: the original object: a dict, a list/tuple, or any object with __dict__
    :param obj2: the object to compare to; missing fields are read as None
    :param attributes_dict: optional dict whose values are the keys/attribute names to
                            compare; by default all keys/attributes of obj1. It is always
                            replaced by the indices when obj1 is a list or tuple.
    :param logger: logger passed on to warning() and raise_value_error()
    :return: True if no difference was found, False otherwise
    """
    try:
        string_types = basestring  # Python 2
    except NameError:
        string_types = str  # Python 3

    def print_not_equal_message(attr, field1, field2, logger):
        # str(attr): attr is an integer index when lists or tuples are compared
        warning("Original and read object field " + str(attr) + " not equal!" +
                "\nOriginal field:\n" + str(field1) +
                "\nRead object field:\n" + str(field2), logger)

    if isinstance(obj1, dict):
        get_field1 = lambda obj, key: obj[key]
        if not isinstance(attributes_dict, dict):
            attributes_dict = dict((key, key) for key in obj1.keys())
    elif isinstance(obj1, (list, tuple)):
        get_field1 = lambda obj, key: get_list_or_tuple_item_safely(obj, key)
        indices = range(len(obj1))
        attributes_dict = dict(zip([str(ind) for ind in indices], indices))
    else:
        get_field1 = lambda obj, attribute: getattr(obj, attribute)
        if not isinstance(attributes_dict, dict):
            attributes_dict = dict((key, key) for key in obj1.__dict__.keys())

    if isinstance(obj2, dict):
        get_field2 = lambda obj, key: obj.get(key, None)
    elif isinstance(obj2, (list, tuple)):
        get_field2 = lambda obj, key: get_list_or_tuple_item_safely(obj, key)
    else:
        get_field2 = lambda obj, attribute: getattr(obj, attribute, None)

    equal = True
    for attribute in attributes_dict:
        name = attributes_dict[attribute]
        field1 = get_field1(obj1, name)
        field2 = get_field2(obj2, name)
        try:
            is_array = isinstance(field1, np.ndarray)
            # TODO: a better hack for the stupid case of an ndarray of a string, such as model.zmode or pmode
            # For non numeric types
            if isinstance(field1, (string_types, list, dict)) or (is_array and field1.dtype.kind in 'OSU'):
                if np.any(field1 != field2):
                    print_not_equal_message(name, field1, field2, logger)
                    equal = False
            # For numeric numpy arrays:
            elif is_array:
                # TODO: handle better accuracy differences, empty matrices and complex numbers...
                if field1.shape != field2.shape:
                    print_not_equal_message(name, field1, field2, logger)
                    equal = False
                # abs(): the original only detected field1 > field2
                elif np.any(np.abs(np.float32(field1) - np.float32(field2)) > 0):
                    print_not_equal_message(name, field1, field2, logger)
                    equal = False
            # For numeric scalar types
            elif is_numeric(field1):
                if np.abs(np.float32(field1) - np.float32(field2)) > 0:
                    print_not_equal_message(name, field1, field2, logger)
                    equal = False
            else:
                # Do not let an equal nested field reset a difference that was found before
                if not assert_equal_objects(field1, field2, logger=logger):
                    equal = False
        except Exception:
            try:
                warning("Comparing str(objects) for field "
                        + str(name) + " because there was an error!", logger)
                if np.any(str(field1) != str(field2)):
                    print_not_equal_message(name, field1, field2, logger)
                    equal = False
            except Exception:
                raise_value_error("ValueError: Something went wrong when trying to compare "
                                  + str(name) + " !", logger)

    return equal
def compute_information_criteria(self, samples, nparams=None, nsamples=None, ndata=None, parameters=[],
                                 skip_samples=0, merge_chains_or_runs_flag=False, log_like_str='log_likelihood'):
    """
    Compute maximum likelihood, dic, aicc, aic, bic, waic and (for more than one sample) psis-loo
    with the external information_criteria package found at config.generic.MODEL_COMPARISON_PATH.

    :param samples: a dictionary of stan outputs or a list of dictionaries for multiple runs/chains
    :param nparams: number of model parameters, it can be inferred from parameters if None
    :param nsamples: number of samples, it can be inferred from loglikelihood if None
    :param ndata: number of data points, it can be inferred from loglikelihood if None
    :param parameters: a list of parameter names, necessary for dic metric computations and in case nparams is None,
                       as well as for aicc, aic and bic computation; if empty, the "_star" outputs are used
    :param skip_samples: number of leading samples to skip
    :param merge_chains_or_runs_flag: logical flag for merging separate chains/runs, default is False
    :param log_like_str: the name of the log likelihood output of stan, default 'log_likelihood'
    :return: a dictionary of metrics for a single chain/run, otherwise the result of
             list_of_dicts_to_dicts_of_ndarrays() for the dictionaries of all chains/runs
    """
    import sys
    ic_path = self.config.generic.MODEL_COMPARISON_PATH
    if ic_path not in sys.path:
        # Avoid growing sys.path by one entry on every call
        sys.path.insert(0, ic_path)
    from information_criteria.ComputeIC import maxlike, aicc, aic, bic, dic, waic
    from information_criteria.ComputePSIS import psisloo

    samples = ensure_list(samples)
    if merge_chains_or_runs_flag and len(samples) > 1:
        samples = ensure_list(merge_samples(samples, skip_samples, flatten=True))
        skip_samples = 0

    results = []
    for sample in samples:
        log_likelihood = -1 * sample[log_like_str][skip_samples:]
        log_lik_shape = log_likelihood.shape
        if len(log_lik_shape) > 1:
            target_shape = log_lik_shape[1:]
        else:
            target_shape = (1,)
        if nsamples is None:
            nsamples = log_lik_shape[0]
        elif nsamples != log_likelihood.shape[0]:
            warning("nsamples (" + str(nsamples) +
                    ") is not equal to likelihood.shape[0] (" + str(log_lik_shape[0]) + ")!")

        log_likelihood = np.reshape(log_likelihood, (log_lik_shape[0], -1))
        # The original compared the shape tuple itself with 1, which is always True in
        # Python 2 and a TypeError in Python 3.
        if log_likelihood.ndim > 1:
            ndata_real = np.maximum(log_likelihood.shape[1], 1)
        else:
            ndata_real = 1
        if ndata is None:
            ndata = ndata_real
        elif ndata != ndata_real:
            warning("ndata (" + str(ndata) + ") is not equal to likelihood.shape[1] (" + str(ndata_real) + ")!")

        result = maxlike(log_likelihood)

        if len(parameters) == 0:
            parameters = [param for param in sample.keys() if param.find("_star") >= 0]

        if len(parameters) > 0:
            nparams_real = 0
            zscore_params = []
            for p in parameters:
                pval = sample[p][skip_samples:]
                pzscore = np.array((pval - np.mean(pval, axis=0)) / np.std(pval, axis=0))
                if len(pzscore.shape) > 2:
                    pzscore = np.reshape(pzscore, (pzscore.shape[0], -1))
                zscore_params.append(pzscore)
                if len(pzscore.shape) > 1:
                    nparams_real += np.maximum(pzscore.shape[1], 1)
                else:
                    nparams_real += 1
            if nparams is None:
                nparams = nparams_real
            elif nparams != nparams_real:
                warning("nparams (" + str(nparams) +
                        ") is not equal to number of parameters included in the dic computation (" +
                        str(nparams_real) + ")!")
            # TODO: find out how to reduce dic to 1 value, from 1 value per parameter. mean(.) for the moment:
            result['dic'] = np.mean(dic(log_likelihood, zscore_params))
        else:
            warning("Parameters' names' list is empty and we found no _star parameters! No computation of dic!")

        if nparams is not None:
            result['aicc'] = aicc(log_likelihood, nparams, ndata)
            result['aic'] = aic(log_likelihood, nparams)
            result['bic'] = bic(log_likelihood, nparams, ndata)
        else:
            warning("Unknown number of parameters! No computation of aic, aaic, bic!")

        result.update(waic(log_likelihood))

        if nsamples > 1:
            result.update(psisloo(log_likelihood))
            result["loos"] = np.reshape(result["loos"], target_shape)
            result["ks"] = np.reshape(result["ks"], target_shape)
        else:
            result.pop('p_waic', None)

        # Make every metric (at least) a one element array
        for metric, value in list(result.items()):
            result[metric] = value * np.ones(1, )

        results.append(result)

    if len(results) == 1:
        return results[0]
    else:
        return list_of_dicts_to_dicts_of_ndarrays(results)
def get_parameter(self, parameter_name):
    """Look a parameter up in self.parameters.

    :param parameter_name: name of the parameter
    :return: the parameter, or None (after issuing a warning) if there is no such entry
             or the entry is None
    """
    found = self.parameters.get(parameter_name)
    if found is None:
        warning("Ground truth value for parameter " + parameter_name + " was not found!")
    return found
def plot_array_model_comparison(self, model_comps, title_prefix="", metrics=["loos", "ks"], labels=[],
                                xdata=None, xlabel="", figsize=FiguresConfig.VERY_LARGE_SIZE[::-1],
                                figure_name=None):
    """Plot array valued model comparison metrics: one figure per metric, one column per model.

    Every figure has one subplot row per entry along axis 1 of the (chains/runs x rows x data)
    metric arrays, and every chain/run is drawn as a line together with its nan-mean.

    :param model_comps: dict of metric name -> array, or -> dict of model name -> array
    :param title_prefix: prefix of the figure titles
    :param metrics: names of the metrics to plot; those missing from model_comps are skipped
    :param labels: row labels; only checked against the number of rows
    :param xdata: x axis values, by default the indices of the last array axis
    :param xlabel: x axis label
    :param figsize: figure size
    :param figure_name: name passed to self._save_figure()
    :return: tuple (figures, arrays of axes), or None if all values of a metric of some
             model are inf
    """

    def arrange_chains_or_runs(metric_data):
        # Give the arrays that lack it a leading chains/runs axis
        for imodel, model in enumerate(metric_data):
            if model.ndim <= 2:
                metric_data[imodel] = numpy.expand_dims(model, axis=0)
        return metric_data

    colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
    n_colors = len(colorcycle)
    metrics = [metric for metric in metrics if metric in model_comps.keys()]
    figs = []
    axs = []
    for metric in metrics:
        if isinstance(model_comps[metric], dict):
            # Multiple models as a list of np.arrays of chains x data
            # (real lists, because they are indexed and assigned to below)
            metric_data = list(model_comps[metric].values())
            model_names = list(model_comps[metric].keys())
        else:
            # Single models as a one element list of one np.array of chains x data
            metric_data = [model_comps[metric]]
            model_names = [""]
        metric_data = arrange_chains_or_runs(metric_data)
        n_models = len(metric_data)
        for jj in range(n_models):
            # Necessary because ks gets infinite sometimes...
            temp = metric_data[jj] == numpy.inf
            # The model is indexed with jj; the original used an index that is not
            # defined at this point.
            if numpy.all(temp):
                warning("All values are inf for metric " + metric + " of model " + model_names[jj] + "!\n")
                return
            elif numpy.any(temp):
                warning("Inf values found for metric " + metric + " of model " + model_names[jj] + "!\n" +
                        "Substituting them with the maximum non-infite value!")
                metric_data[jj][temp] = metric_data[jj][~temp].max()
        n_subplots = metric_data[0].shape[1]
        n_labels = len(labels)
        if n_labels != n_subplots:
            if n_labels != 0:
                warning("Ignoring labels because their number (" + str(n_labels) +
                        ") is not equal to the number of row subplots (" + str(n_subplots) + ")!")
            # NOTE(review): labels are validated here but never drawn anywhere below - confirm intent
            labels = [str(irow + 1) for irow in range(n_subplots)]
        if xdata is None:
            xdata = numpy.arange(metric_data[-1].shape[-1])
        else:
            xdata = xdata.flatten()
        # x range extended by 10% on both sides, for the mean lines and their text
        xdata0 = numpy.concatenate([numpy.reshape(xdata[0] - 0.1 * (xdata[-1] - xdata[0]), (1,)), xdata])
        xdata1 = xdata[-1] + 0.1 * (xdata[-1] - xdata[0])
        if len(title_prefix) > 0:
            title = title_prefix + ": " + metric
        else:
            title = metric
        figs.append(pyplot.figure(title, figsize=figsize))
        figs[-1].suptitle(title)
        figs[-1].set_label(title)
        gs = gridspec.GridSpec(n_subplots, n_models)
        axes = numpy.empty((n_subplots, n_models), dtype="O")
        # Bottom row first, so that the rows above can share its x axis
        for ii in range(n_subplots - 1, -1, -1):
            for jj in range(n_models):
                share = {}
                if ii < n_subplots - 1:
                    # The original condition (ii > n_subplots - 1) could never be true
                    share["sharex"] = axes[n_subplots - 1, jj]
                if jj > 0:
                    share["sharey"] = axes[ii, 0]
                axes[ii, jj] = pyplot.subplot(gs[ii, jj], **share)
                n_chains_or_runs = metric_data[jj].shape[0]
                for kk in range(n_chains_or_runs):
                    c = colorcycle[kk % n_colors]
                    axes[ii, jj].plot(xdata, metric_data[jj][kk][ii, :], label="chain/run " + str(kk + 1),
                                      marker="o", markersize=5, markeredgecolor=c, markerfacecolor=c,
                                      linestyle="-", linewidth=1)
                    m = numpy.nanmean(metric_data[jj][kk][ii, :])
                    axes[ii, jj].plot(xdata0, m * numpy.ones(xdata0.shape), color=c, linewidth=1)
                    axes[ii, jj].text(xdata0[0], 1.1 * m, 'mean=%0.2f' % m, ha='center', va='bottom', color=c)
                axes[ii, jj].set_xlabel(xlabel)
                if ii == 0:
                    # Columns are models, hence the jj index (the original used ii)
                    axes[ii, jj].set_title(model_names[jj])
                if n_chains_or_runs > 1:
                    axes[ii, jj].legend()
            if ii == n_subplots - 1:
                axes[ii, 0].autoscale()  # tight=True
                axes[ii, 0].set_xlim([xdata0[0], xdata1])  # tight=True
        # figs[-1].tight_layout()
        self._save_figure(figs[-1], figure_name)
        self._check_show()
        axs.append(axes)
    return tuple(figs), tuple(axs)
def run_fitting(probabilistic_model, stan_model_name, model_data, target_data, config, head=None,
                seizure_indices=[], pair_plot_params=["tau1", "tau0", "K", "sigma", "epsilon", "scale", "offset"],
                region_violin_params=["x0", "PZ", "x1eq", "zeq"], fit_flag=True, test_flag=False, base_path="",
                fitmethod="sample", n_chains_or_runs=2, output_samples=200, num_warmup=100,
                min_samples_per_chain=200, max_depth=15, delta=0.95, iter=500000, tol_rel_obj=1e-6,
                debug=1, simulate=0, step_prefix='', writer=None, plotter=None, **kwargs):
    """Fit a stan model with CmdStanInterface (or read an existing output), compute the
    information criteria and optionally write and plot the results.

    :param probabilistic_model: model providing name, number_of_active_regions and time_length
    :param stan_model_name: name of the stan model; its code is expected at
                            config.generic.PROBLSTC_MODELS_PATH/<stan_model_name>.stan
    :param model_data: model data passed to the fit and to the plotter
    :param target_data: target data passed to the plotter
    :param config: configuration object
    :param head: object providing connectivity.region_labels for plotting
    :param seizure_indices: passed to plotter.plot_fit_results()
    :param pair_plot_params: passed to plotter.plot_fit_results()
    :param region_violin_params: currently not passed on to the plotter
    :param fit_flag: if True fit, otherwise read the output files found under base_path
    :param test_flag: if True use small test values for the number of chains, samples,
                      warmup samples, max_depth, delta and iter
    :param base_path: folder for the model, its data and its outputs
    :param fitmethod: fitting method; samples are skipped only if it contains "sampl"
    :param n_chains_or_runs: number of chains or runs
    :param output_samples: total number of output samples, divided among chains/runs
    :param num_warmup: number of warmup samples
    :param min_samples_per_chain: lower bound for the samples per chain/run
    :param max_depth: passed to the fit
    :param delta: passed to the fit
    :param iter: passed to the fit
    :param tol_rel_obj: passed to the fit
    :param debug: passed to the fit
    :param simulate: passed to the fit
    :param step_prefix: prefix of figure names and titles
    :param writer: optional writer of estimates, samples, summary and information criteria
    :param plotter: optional plotter of the results
    :param kwargs: further arguments of stan_interface.fit()
    :return: tuple (estimates, samples, summary, info_crit)
    """
    # ------------------------------Stan model and service--------------------------------------
    model_code_path = os.path.join(config.generic.PROBLSTC_MODELS_PATH, stan_model_name + ".stan")
    stan_interface = CmdStanInterface(model_name=stan_model_name, model_dir=base_path,
                                      model_code_path=model_code_path, fitmethod=fitmethod, config=config)
    stan_interface.set_or_compile_model()
    stan_interface.model_data_path = os.path.join(base_path, "ModelData.h5")

    # -------------------------- Fit and get estimates: ------------------------------------------------------------
    # NOTE: numpy.where turns all of these settings into 0-d numpy arrays
    n_chains_or_runs = np.where(test_flag, 2, n_chains_or_runs)
    output_samples = np.where(test_flag, 20, max(int(np.round(output_samples * 1.0 / n_chains_or_runs)),
                                                 min_samples_per_chain))
    # Sampling (HMC)
    num_samples = output_samples
    num_warmup = np.where(test_flag, 30, num_warmup)
    max_depth = np.where(test_flag, 7, max_depth)
    delta = np.where(test_flag, 0.8, delta)
    # ADVI or optimization:
    iter = np.where(test_flag, 1000, iter)
    if fitmethod.find("sampl") >= 0:
        skip_samples = num_warmup
    else:
        skip_samples = 0
    prob_model_name = probabilistic_model.name.split(".")[0]
    if fit_flag:
        estimates, samples, summary = stan_interface.fit(debug=debug, simulate=simulate, model_data=model_data,
                                                         n_chains_or_runs=n_chains_or_runs, refresh=1,
                                                         iter=iter, tol_rel_obj=tol_rel_obj,
                                                         output_samples=output_samples,
                                                         num_warmup=num_warmup, num_samples=num_samples,
                                                         max_depth=max_depth, delta=delta,
                                                         save_warmup=1, plot_warmup=1,
                                                         output_path=base_path, **kwargs)
        # TODO: check if write_dictionary is enough for estimates, samples, summary and info_crit
        if writer:
            writer.write_list_of_dictionaries(estimates, path(prob_model_name + "_FitEst", base_path))
            writer.write_list_of_dictionaries(samples, path(prob_model_name + "_FitSamples", base_path))
            if summary is not None:
                writer.write_dictionary(summary, path(prob_model_name + "_FitSummary", base_path))
    else:
        stan_interface.set_output_files(base_path=base_path, update=True)
        estimates, samples, summary = stan_interface.read_output()

    # Model comparison:
    # scale_signal, offset_signal, time_scale, epsilon, sigma -> 5 (+ K = 6)
    # x0[active] -> probabilistic_model.model.number_of_active_regions
    # x1init[active], zinit[active] -> 2 * probabilistic_model.number_of_active_regions
    # dZt[active, t] -> probabilistic_model.number_of_active_regions * (probabilistic_model.time_length-1)
    number_of_total_params = \
        5 + probabilistic_model.number_of_active_regions * (3 + (probabilistic_model.time_length - 1))
    info_crit = \
        stan_interface.compute_information_criteria(samples, number_of_total_params, skip_samples=skip_samples,
                                                    # parameters=["amplitude_star", "offset_star", "epsilon_star",
                                                    #            "sigma_star", "time_scale_star", "x0_star",
                                                    #            "x_init_star", "z_init_star", "z_eta_star"],
                                                    merge_chains_or_runs_flag=False)

    if writer:
        writer.write_dictionary(info_crit, path(prob_model_name + "_InfoCrit", base_path))

    Rhat = stan_interface.get_Rhat(summary)
    # Interface backwards with INS stan models
    # from tvb_fit.service.model_inversion.vep_stan_dict_builder import convert_params_names_from_ins
    # estimates, samples, Rhat, model_data = \
    #     convert_params_names_from_ins([estimates, samples, Rhat, model_data])
    if fitmethod.find("opt") < 0 and Rhat is not None:
        stats = {"Rhat": Rhat}
    else:
        stats = None

    if plotter:
        # -------------------------- Plot fitting results: ------------------------------------------------------------
        try:
            if fitmethod.find("sampl") >= 0:
                plotter.plot_HMC(samples, figure_name=step_prefix + prob_model_name + " HMC NUTS trace")

            plotter.plot_fit_results(estimates, samples, model_data, target_data, probabilistic_model, info_crit,
                                     stats=stats, seizure_indices=seizure_indices,
                                     pair_plot_params=pair_plot_params,
                                     # region_violin_params=region_violin_params,
                                     region_labels=head.connectivity.region_labels,
                                     skip_samples=skip_samples, title_prefix=step_prefix + prob_model_name)
        except Exception as error:
            # Plotting is best effort: the fit results are still returned. Exception (not
            # a bare except) keeps KeyboardInterrupt/SystemExit propagating, and the cause
            # is reported instead of being dropped.
            warning("Fitting plotting failed for step %s:\n%s" % (step_prefix, str(error)))

    return estimates, samples, summary, info_crit
def plot_spectral_analysis_raster(self, time, data, time_units="ms", freq=None, spectral_options={},
                                  special_idx=[], labels=[], title='Spectral Analysis', figure_name=None,
                                  figsize=FiguresConfig.VERY_LARGE_SIZE):
    """Plot, for every signal, the output of time_spectral_analysis() as an image next to a
    spectrum line plot, one signal per row.

    :param time: time vector of the data
    :param data: array with one signal per column
    :param time_units: time units; "ms"/"msec" set the base sampling frequency to 1000.0, anything else to 1.0
    :param freq: passed to time_spectral_analysis()
    :param spectral_options: dict with optional entries log_norm, mode, nfft, window, nperseg,
                             detrend, noverlap, f_low and log_scale
    :param special_idx: indices of the signals (columns) to plot, by default all of them
    :param labels: labels of the signals, by default their indices
    :param title: figure name and title
    :param figure_name: name passed to self._save_figure()
    :param figsize: figure size; replaced by FiguresConfig.VERY_LARGE_SIZE for more than 2 signals
    :return: tuple (fig, ax, img, line, time, freq, stf, psd), or None (after a warning)
             for more than 20 signals
    """
    nS = data.shape[1]
    n_special_idx = len(special_idx)
    if n_special_idx > 0:
        data = data[:, special_idx]
        nS = data.shape[1]
        if len(labels) > n_special_idx:
            labels = numpy.array([str(ilbl) + ". " + str(labels[ilbl]) for ilbl in special_idx])
        elif len(labels) == n_special_idx:
            labels = numpy.array([str(ilbl) + ". " + str(label) for ilbl, label in zip(special_idx, labels)])
        else:
            labels = numpy.array([str(ilbl) for ilbl in special_idx])
    else:
        if len(labels) != nS:
            labels = numpy.array([str(ilbl) for ilbl in range(nS)])
    if nS > 20:
        warning("It is not possible to plot spectral analysis plots for more than 20 signals!")
        return

    if not isinstance(time_units, basestring):
        time_units = list(time_units)[0]
    time_units = ensure_string(time_units)
    if time_units in ("ms", "msec"):
        fs = 1000.0
    else:
        fs = 1.0
    fs = fs / numpy.mean(numpy.diff(time))

    log_norm = spectral_options.get("log_norm", False)
    mode = spectral_options.get("mode", "psd")
    psd_label = mode
    if log_norm:
        psd_label = "log" + psd_label
    stf, time, freq, psd = time_spectral_analysis(data, fs, freq=freq, mode=mode,
                                                  nfft=spectral_options.get("nfft"),
                                                  window=spectral_options.get("window", 'hanning'),
                                                  nperseg=spectral_options.get("nperseg", int(numpy.round(fs / 4))),
                                                  detrend=spectral_options.get("detrend", 'constant'),
                                                  noverlap=spectral_options.get("noverlap"),
                                                  f_low=spectral_options.get("f_low", 10.0),
                                                  log_scale=spectral_options.get("log_scale", False))
    # Common color limits for the images of all signals
    min_val = numpy.min(stf.flatten())
    max_val = numpy.max(stf.flatten())
    if nS > 2:
        figsize = FiguresConfig.VERY_LARGE_SIZE
    fig = pyplot.figure(title, figsize=figsize)
    fig.suptitle(title)
    gs = gridspec.GridSpec(nS, 23)
    ax = numpy.empty((nS, 2), dtype="O")
    img = numpy.empty((nS,), dtype="O")
    line = numpy.empty((nS,), dtype="O")
    # Bottom row first, so that the rows above can share its x axes. The original started
    # at row index nS (one past the last row) and passed the axes being created as sharex.
    for iS in range(nS - 1, -1, -1):
        if iS < nS - 1:
            ax[iS, 0] = pyplot.subplot(gs[iS, :20], sharex=ax[nS - 1, 0])
            ax[iS, 1] = pyplot.subplot(gs[iS, 20:22], sharex=ax[nS - 1, 1], sharey=ax[iS, 0])
        else:
            ax[iS, 0] = pyplot.subplot(gs[iS, :20])
            ax[iS, 1] = pyplot.subplot(gs[iS, 20:22], sharey=ax[iS, 0])
        # NOTE(review): pyplot.set_cmap() is called for its side effect and the value it
        # returns is what gets passed as cmap - confirm this is intended.
        img[iS] = ax[iS, 0].imshow(numpy.squeeze(stf[:, :, iS]).T, cmap=pyplot.set_cmap('jet'),
                                   interpolation='none',
                                   norm=Normalize(vmin=min_val, vmax=max_val), aspect='auto', origin='lower',
                                   extent=(time.min(), time.max(), freq.min(), freq.max()))
        # img[iS].clim(min_val, max_val)
        ax[iS, 0].set_title(labels[iS])
        ax[iS, 0].set_ylabel("Frequency (Hz)")
        line[iS] = ax[iS, 1].plot(psd[:, iS], freq, 'k', label=labels[iS])
        pyplot.setp(ax[iS, 1].get_yticklabels(), visible=False)
        # ax[iS, 1].yaxis.tick_right()
        # ax[iS, 1].yaxis.set_ticks_position('both')
        if iS == (nS - 1):
            ax[iS, 0].set_xlabel("Time (" + time_units + ")")
            ax[iS, 1].set_xlabel(psd_label)
        else:
            pyplot.setp(ax[iS, 0].get_xticklabels(), visible=False)
            pyplot.setp(ax[iS, 1].get_xticklabels(), visible=False)
        ax[iS, 0].autoscale(tight=True)
        ax[iS, 1].autoscale(tight=True)
    # make a color bar (on the axes created here, instead of requesting the subplot twice)
    cax = pyplot.subplot(gs[:, 22])
    pyplot.colorbar(img[0], cax=cax)  # fraction=0.046, pad=0.04) #fraction=0.15, shrink=1.0
    cax.set_title(psd_label)
    self._save_figure(pyplot.gcf(), figure_name)
    self._check_show()
    return fig, ax, img, line, time, freq, stf, psd