Code Example #1
 def update_active_regions_target_data(self,
                                       target_data,
                                       probabilistic_model,
                                       sensors,
                                       reset=False):
     if reset:
         probabilistic_model.update_active_regions([])
     signals_inds = []
     if target_data:
         active_regions = probabilistic_model.active_regions.tolist()
         gain_matrix = np.array(sensors.gain_matrix)
         signals_inds = sensors.get_sensors_inds_by_sensors_labels(
             target_data.space_labels)
         if len(signals_inds) != 0:
             gain_matrix = gain_matrix[signals_inds]
             for proj in gain_matrix:
                 active_regions += select_greater_values_array_inds(
                     proj, self.gain_matrix_th,
                     self.n_signals_per_roi).tolist()
             active_regions = self.exclude_regions(active_regions)
             probabilistic_model.update_active_regions(active_regions)
         else:
             warning(
                 "Skipping active regions setting by seeg power because no data were assigned to sensors!"
             )
     else:
         warning(
             "Skipping active regions setting by seeg power because no target data were provided!"
         )
     if len(signals_inds) > 0:
         # Restrict the gain matrix to the selected sensors and the active regions
         probabilistic_model.gain_matrix = sensors.gain_matrix[
             signals_inds][:, probabilistic_model.active_regions]
     return probabilistic_model
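The thresholding step relies on the library's select_greater_values_array_inds helper. For orientation, a minimal self-contained sketch of that selection logic (assumed semantics, not the library's actual implementation):

    import numpy as np

    def select_greater_inds(values, threshold, n_inds=None):
        # indices of all entries above the threshold
        inds = np.where(values > threshold)[0]
        if n_inds is not None and len(inds) > n_inds:
            # keep only the n_inds largest of them
            inds = inds[np.argsort(values[inds])[::-1][:n_inds]]
        return inds

    proj = np.array([0.1, 0.9, 0.3, 0.7])
    print(select_greater_inds(proj, 0.2, n_inds=2))  # [1 3]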
Code Example #2
 def print_not_equal_message(attr, field1, field2, logger):
     # logger.error("\n\nValueError: Original and read object field "+ attr + " not equal!")
     # raise_value_error("\n\nOriginal and read object field " + attr + " not equal!")
     warning(
         "Original and read object field " + attr + " not equal!" +
         "\nOriginal field:\n" + str(field1) + "\nRead object field:\n" +
         str(field2), logger)
Code Example #3
 def get_truth(self, parameter_name):
     truth = self.ground_truth.get(parameter_name, np.nan)
     if truth is np.nan and self.target_data_type == Target_Data_Type.SYNTHETIC.value:
         truth = getattr(self.model_config, parameter_name, np.nan)
     if truth is np.nan:
         warning("Ground truth value for parameter " + parameter_name +
                 " was not found!")
     return truth
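The `truth is np.nan` identity test works because np.nan is a single module-level float object: dict.get(..., np.nan) hands back that very object, so `is` reliably detects the "missing" sentinel, whereas `==` never matches nan. A minimal demonstration (note that a nan computed elsewhere is a different object and would slip past this check):

    import numpy as np

    ground_truth = {"tau1": 0.5}
    value = ground_truth.get("K", np.nan)  # sentinel default
    print(value is np.nan)  # True: the very same object came back
    print(value == np.nan)  # False: nan never compares equal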
Code Example #4
 def _region_parameters_violin_plots(self, samples_all, values=None, lines=None, stats=None,
                                     params=[], skip_samples=0, per_chain_or_run=False,
                                     labels=[], special_idx=None, figure_name="Regions parameters samples",
                                     figsize=FiguresConfig.VERY_LARGE_SIZE):
     if isinstance(values, dict):
         vals_fun = lambda param: values.get(param, numpy.array([]))
     else:
         vals_fun = lambda param: []
     if isinstance(lines, dict):
         lines_fun = lambda param: lines.get(param, numpy.array([]))
     else:
         lines_fun = lambda param: []
     samples_all = ensure_list(samples_all)
     params = [param for param in params if param in samples_all[0].keys()]
     samples = []
     for sample in samples_all:
         samples.append(extract_dict_stringkeys(sample, params, modefun="equal"))
     if len(labels) == 0:
         labels = generate_region_labels(samples[0].values()[0].shape[-1], labels, numbering=False)
     n_chains = len(samples_all)
     n_samples = samples[0].values()[0].shape[0]
     if n_samples > 1:
         violin_flag = True
     else:
         violin_flag = False
     if not per_chain_or_run and n_chains > 1:
         samples = [merge_samples(samples)]
         plot_samples = lambda s: numpy.concatenate(numpy.split(s[:, skip_samples:].T, n_chains, axis=2),
                                                    axis=1).squeeze().T
         plot_figure_name = lambda ichain: figure_name
     else:
         plot_samples = lambda s: s[skip_samples:]
         plot_figure_name = lambda ichain: figure_name + ": chain " + str(ichain + 1)
     params_labels = {}
     for ip, p in enumerate(params):
         if ip == 0:
             params_labels[p] = self._params_stats_labels(p, stats, labels)
         else:
             params_labels[p] = self._params_stats_labels(p, stats, "")
     n_params = len(params)
     if n_params > 9:
         warning("Number of subplots in column wise vector-violin-plots cannot be > 9 and it is "
                           + str(n_params) + "!")
     subplot_ind = 100 + n_params * 10
     figs = []
     for ichain, chain_sample in enumerate(samples):
         pyplot.figure(plot_figure_name(ichain), figsize=figsize)
         for ip, param in enumerate(params):
             fig = self.plot_vector_violin(plot_samples(chain_sample[param]), vals_fun(param),
                                           lines_fun(param), params_labels[param],
                                           subplot_ind + ip + 1, param, violin_flag=violin_flag,
                                           colormap="YlOrRd", show_y_labels=True,
                                           indices_red=special_idx, sharey=None)
         self._save_figure(pyplot.gcf(), None)
         self._check_show()
         figs.append(fig)
     return tuple(figs)
Code Example #5
 def get_prior_pdf(self, parameter_name):
     parameter_mean, parameter = self.get_prior(parameter_name)
     if isinstance(parameter, (ProbabilisticParameterBase,
                               TransformedProbabilisticParameterBase)):
         return parameter.scipy_method("pdf")
     else:
         warning("No probabilistic parameter " + parameter_name +
                 " was found!"
                 "\nReturning prior value, if available, instead of pdf!")
         return parameter_mean, np.nan
Code Example #6
 def read_output(self):
     samples = self.read_output_samples(self.output_filepath)
     est = self.compute_estimates_from_samples(samples)
     if os.path.isfile(self.summary_filepath):
         try:
             summary = parse_csv_in_cols(self.summary_filepath)
         except:
             summary = None
             warning("Reading stan summary failed!")
     else:
         summary = None
     return est, samples, summary
Code Example #7
def convert_params_names_to_ins(dicts_list,
                                parameter_names=INS_PARAMS_NAMES_DICT):
    output = []
    for lst in ensure_list(dicts_list):
        for dct in ensure_list(lst):
            for p, p_ins in parameter_names.items():
                try:
                    dct[p_ins] = dct[p]
                except:
                    warning("Parameter " + p + " not found in \n" +
                            str(dicts_list))
        output.append(lst)
    return tuple(output)
Code Example #8
 def build_model_config_from_model_config(self, model_config):
     if not isinstance(model_config, dict):
         model_config_dict = model_config.__dict__
     else:
         model_config_dict = model_config
     model_configuration = ModelConfiguration()
     for attr, value in model_configuration.__dict__.items():
         value = model_config_dict.get(attr, None)
         if value is None:
             warning(attr + " not found in the input model configuration dictionary!" +
                     "\nLeaving default " + attr + ": " + str(getattr(model_configuration, attr)))
         else:
             setattr(model_configuration, attr, value)
     return model_configuration
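The same copy-with-defaults pattern in isolation, with a stand-in class instead of the library's ModelConfiguration (a minimal sketch, not tvb-fit code):

    import warnings

    class Config(object):  # stand-in for ModelConfiguration
        def __init__(self):
            self.x0_values = None
            self.K = 1.0

    def build_from_dict(config_dict):
        config = Config()
        for attr in vars(config):
            value = config_dict.get(attr)
            if value is None:
                warnings.warn("%s not found; keeping default %s"
                              % (attr, getattr(config, attr)))
            else:
                setattr(config, attr, value)
        return config

    config = build_from_dict({"K": 10.0})  # warns about x0_values, sets K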
Code Example #9
 def get_prior(self, parameter_name):
     parameter = self.get_parameter(parameter_name)
     if parameter is None:
         warning("No probabilistic prior for parameter " + parameter_name +
                 " was found!")
         # TODO: decide if it is a good idea to return this kind of modeler's fixed "prior"...:
         pmean = getattr(self, parameter_name, np.nan)
         if pmean is np.nan:
             pmean = getattr(self.model_config, parameter_name, np.nan)
         if pmean is np.nan:
             warning("No prior value for parameter " + parameter_name +
                     " was found!")
         return pmean, parameter
     else:
         return parameter.mean, parameter
Code Example #10
 def __init__(self, input="EpileptorDP", connectivity=None, K_unscaled=np.array([K_UNSCALED_DEF]),
              x0_values=X0_DEF, e_values=E_DEF, x1eq_mode="optimize", **kwargs):
     if isinstance(input, Simulator):
         # TODO: make this more specific once we clarify the model configuration representation compared to simTVB
         self.model_name = input.model._ui_name
         self.set_params_from_tvb_model(input.model)
         self.connectivity = normalize_weights(input.connectivity.weights)
         # self.coupling = input.coupling
         self.initial_conditions = np.squeeze(input.initial_conditions)  # initial conditions in a reduced form
         # self.noise = input.integrator.noise
         # self.monitor = ensure_list(input.monitors)[0]
     else:
         if isinstance(input, Model):
             self.model_name = input._ui_name
             self.set_params_from_tvb_model(input)
         elif isinstance(input, basestring):
             self.model_name = input
         else:
             raise_value_error("Input (%s) is not a TVB simulator, an epileptor model, "
                               "\nor a string of an epileptor model!")
     if isinstance(connectivity, Connectivity):
         self.connectivity = connectivity.normalized_weights
     elif isinstance(connectivity, TVBConnectivity):
         self.connectivity = normalize_weights(connectivity.weights)
     elif isinstance(connectivity, np.ndarray):
         self.connectivity = normalize_weights(connectivity)
     else:
         if not(isinstance(input, Simulator)):
             warning("Input connectivity (%s) is not a virtual patient connectivity, a TVB connectivity, "
                     "\nor a numpy.array!" % str(connectivity))
     self.x0_values = x0_values * np.ones((self.number_of_regions,), dtype=np.float32)
     self.x1eq_mode = x1eq_mode
     if len(ensure_list(K_unscaled)) == 1:
         K_unscaled = np.array(K_unscaled) * np.ones((self.number_of_regions,), dtype=np.float32)
     elif len(ensure_list(K_unscaled)) == self.number_of_regions:
         K_unscaled = np.array(K_unscaled)
     else:
         self.logger.warning(
             "The length of input global coupling K_unscaled is neither 1 nor equal to the number of regions!" +
             "\nSetting model_configuration_builder.K_unscaled = K_UNSCALED_DEF for all regions")
         K_unscaled = K_UNSCALED_DEF * np.ones((self.number_of_regions,), dtype=np.float32)
     self.set_K_unscaled(K_unscaled)
     for pname in EPILEPTOR_PARAMS:
         self.set_parameter(pname, kwargs.get(pname, getattr(self, pname)))
     # Update K_unscaled
     self.e_values = e_values * np.ones((self.number_of_regions,), dtype=np.float32)
     self.x0cr = 0.0
     self.rx0 = 0.0
     self._compute_critical_x0_scaling()
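normalize_weights comes from the library; a common convention for connectivity normalization (an assumption here, not necessarily tvb-fit's exact rule) is to zero the diagonal and scale to a high percentile:

    import numpy as np

    def normalize_weights(weights, percentile=95):
        w = np.array(weights, dtype=float)
        np.fill_diagonal(w, 0.0)  # drop self-connections
        scale = np.percentile(w, percentile)
        return w / scale if scale > 0 else w

    w = normalize_weights(np.random.rand(4, 4))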
Code Example #11
 def _set_attributes_from_dict(self, attributes_dict):
     if not isinstance(attributes_dict, dict):
         attributes_dict = attributes_dict.__dict__
     for attr, value in attributes_dict.items():
         if attr not in ["model_config", "parameters", "number_of_regions",
                         "number_of_parameters"]:
             if value is None:
                 warning(attr + " not found in input dictionary!" +
                         "\nLeaving as it is: " + attr + " = " +
                         str(getattr(self, attr)))
             else:
                 setattr(self, attr, value)
     return attributes_dict
Code Example #12
 def update_active_regions_x0_values(self,
                                     probabilistic_model,
                                     x0_values,
                                     reset=False):
     active_regions = probabilistic_model.active_regions.tolist()
     if reset:
         active_regions = []
     if len(x0_values) > 0:
         active_regions += select_greater_values_array_inds(
             x0_values, self.active_x0_th).tolist()
         active_regions = self.exclude_regions(active_regions)
         probabilistic_model.update_active_regions(active_regions)
     else:
         warning(
             "Skipping active regions setting by x0 values because no such values were provided!"
         )
     return probabilistic_model
Code Example #13
    def _write_dicts_at_location(self, datasets_dict, metadata_dict, location):
        for key, value in datasets_dict.items():
            try:
                location.create_dataset(key, data=value)
            except:
                warning(
                    "Failed to write to %s dataset %s %s:\n%s !" %
                    (str(location), value.__class__, key, str(value)),
                    self.logger)

        for key, value in metadata_dict.items():
            try:
                location.attrs.create(key, value)
            except:
                warning(
                    "Failed to write to %s attribute %s %s:\n%s !" %
                    (str(location), value.__class__, key, str(value)),
                    self.logger)
        return location
Code Example #14
    def _determine_datasets_and_attributes(self, object, datasets_size=None):
        datasets_dict = {}
        metadata_dict = {}
        groups_keys = []

        try:
            if isinstance(object, dict):
                dict_object = object
            else:
                dict_object = vars(object)
            for key, value in dict_object.items():
                if isinstance(value, numpy.ndarray):
                    # if value.size == 1:
                    #     metadata_dict.update({key: value})
                    # else:
                    datasets_dict.update({key: value})
                    # if datasets_size is not None and value.size == datasets_size:
                    #     datasets_dict.update({key: value})
                    # else:
                    #     if datasets_size is None and value.size > 0:
                    #         datasets_dict.update({key: value})
                    #     else:
                    #         metadata_dict.update({key: value})
                # TODO: check how this works! Be careful not to include lists and tuples if possible in tvb_fit classes!
                elif isinstance(value, (list, tuple)):
                    warning(
                        "Writing %s %s to h5 file as a numpy array dataset !" %
                        (value.__class__, key), self.logger)
                    datasets_dict.update({key: numpy.array(value)})
                else:
                    if is_numeric(value) or isinstance(value, str):
                        metadata_dict.update({key: value})
                    elif not (callable(value)):
                        groups_keys.append(key)
        except:
            msg = "Failed to decompose group object: " + str(object) + "!"
            try:
                self.logger.info(str(object.__dict__))
            except:
                msg += "\n It has no __dict__ attribute!"
            warning(msg, self.logger)

        return datasets_dict, metadata_dict, groups_keys
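The create_dataset and attrs.create calls in the previous example are the h5py API, so this decomposition maps directly onto an HDF5 layout. A minimal round trip with the same dataset-versus-attribute split (file name and keys are illustrative):

    import numpy as np
    import h5py

    datasets_dict = {"x0_values": np.arange(4.0)}
    metadata_dict = {"model_name": "EpileptorDP", "K": 10.0}

    with h5py.File("example.h5", "w") as f:
        for key, value in datasets_dict.items():
            f.create_dataset(key, data=value)  # numpy arrays -> datasets
        for key, value in metadata_dict.items():
            f.attrs.create(key, value)  # scalars and strings -> attributes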
Code Example #15
    def plot_bars(self,
                  data,
                  ax=None,
                  fig=None,
                  title="",
                  group_names=[],
                  legend_prefix="",
                  figsize=FiguresConfig.VERY_LARGE_SIZE):
        def barlabel(ax, rects, positions):
            """
            Attach a text label on each bar displaying its height
            """
            for rect, pos in zip(rects, positions):
                height = rect.get_height()
                if pos < 0:
                    y = -height
                    pos = 0.75 * pos
                else:
                    y = height
                    pos = 0.25 * pos
                ax.text(rect.get_x() + rect.get_width() / 2.,
                        pos,
                        '%0.2f' % y,
                        color="k",
                        ha='center',
                        va='bottom',
                        rotation=90)

        if fig is None:
            fig, ax = pyplot.subplots(1, 1, figsize=figsize)
            show_and_save = True
        else:
            show_and_save = False
            if ax is None:
                ax = pyplot.gca()
        if isinstance(data, (list, tuple)):  # if there are many groups, data is a list
            # Fill in with nan in case that not all groups have the same number of elements
            from itertools import izip_longest
            data = numpy.array(
                list(izip_longest(*ensure_list(data), fillvalue=numpy.nan))).T
        elif data.ndim == 1:  # This is the case where there is only one group...
            data = numpy.expand_dims(data, axis=1).T
        n_groups, n_elements = data.shape
        posmax = numpy.nanmax(data)
        negmax = numpy.nanmin(data)  # the most negative value, for labels below the axis
        n_groups_names = len(group_names)
        if n_groups_names != n_groups:
            if n_groups_names != 0:
                warning("Ignoring group_names because their number (" +
                        str(n_groups_names) +
                        ") is not equal to the number of groups (" +
                        str(n_groups) + ")!")
            group_names = n_groups * [""]
        colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
        n_colors = len(colorcycle)
        x_inds = numpy.arange(n_groups)
        width = 0.9 / n_elements
        elements = []
        for iE in range(n_elements):
            elements.append(
                ax.bar(x_inds + iE * width,
                       data[:, iE],
                       width,
                       color=colorcycle[iE % n_colors]))
            positions = numpy.array(
                [negmax if d < 0 else posmax for d in data[:, iE]])
            positions[numpy.logical_or(numpy.isnan(positions),
                                       numpy.isinf(
                                           numpy.abs(positions)))] = 0.0
            barlabel(ax, elements[-1], positions)
        if n_elements > 1:
            legend = [
                legend_prefix + str(ii) for ii in range(1, n_elements + 1)
            ]
            ax.legend(tuple([element[0] for element in elements]),
                      tuple(legend))
        ax.set_xticks(x_inds + n_elements * width / 2)
        ax.set_xticklabels(tuple(group_names))
        ax.set_title(title)
        ax.autoscale()  # tight=True
        ax.set_xlim([-1.05 * width, n_groups * 1.05])
        if show_and_save:
            fig.tight_layout()
            self._save_figure(fig)
            self._check_show()
        return fig, ax
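The izip_longest trick above pads unequal-length groups with NaN so they stack into one rectangular array (izip_longest is Python 2; on Python 3 it is itertools.zip_longest). In isolation:

    import numpy
    from itertools import zip_longest  # izip_longest on Python 2

    groups = [[1.0, 2.0, 3.0], [4.0, 5.0]]
    data = numpy.array(list(zip_longest(*groups, fillvalue=numpy.nan))).T
    print(data)
    # [[ 1.  2.  3.]
    #  [ 4.  5. nan]]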
Code Example #16
    def run_lsa(self, disease_hypothesis, model_configuration):

        if self.lsa_method == "auto":
            if numpy.any(model_configuration.x1eq > X1EQ_CR_DEF):
                self.lsa_method = "2D"
            else:
                self.lsa_method = "1D"

        if self.lsa_method == "2D" and numpy.all(
                model_configuration.x1eq <= X1EQ_CR_DEF):
            warning(
                "LSA with the '2D' method (on the 2D Epileptor model) will not produce interpretable results when"
                " the equilibrium point of the system is not supercritical (unstable)!"
            )

        jacobian = self._compute_jacobian(model_configuration)

        # Perform eigenvalue decomposition
        eigen_values, eigen_vectors = numpy.linalg.eig(jacobian)
        eigen_values = numpy.real(eigen_values)
        eigen_vectors = numpy.real(eigen_vectors)
        sorted_indices = numpy.argsort(eigen_values, kind='mergesort')
        if self.lsa_method == "2D":
            sorted_indices = sorted_indices[::-1]
        self.eigen_vectors = eigen_vectors[:, sorted_indices]
        self.eigen_values = eigen_values[sorted_indices]

        self._ensure_eigen_vectors_number(
            self.eigen_values[:disease_hypothesis.number_of_regions],
            model_configuration.e_values, model_configuration.x0_values,
            disease_hypothesis.regions_disease_indices)

        if self.eigen_vectors_number == disease_hypothesis.number_of_regions:
            # Calculate the propagation strength index by summing all eigenvectors
            lsa_propagation_strength = numpy.abs(
                numpy.sum(self.eigen_vectors, axis=1))

        else:
            # Calculate the propagation strength index by summing the first n eigenvectors (minimum 1)
            n_eigenvectors = max(self.eigen_vectors_number, 1)
            if self.weighted_eigenvector_sum:
                lsa_propagation_strength = \
                    numpy.abs(weighted_vector_sum(numpy.array(self.eigen_values[:n_eigenvectors]),
                                                  numpy.array(self.eigen_vectors[:, :n_eigenvectors]),
                                                  normalize=True))
            else:
                lsa_propagation_strength = \
                    numpy.abs(numpy.sum(self.eigen_vectors[:, :n_eigenvectors], axis=1))

        if self.lsa_method == "2D":
            # lsa_propagation_strength = lsa_propagation_strength[:disease_hypothesis.number_of_regions]
            # or
            # lsa_propagation_strength = numpy.where(lsa_propagation_strength[:disease_hypothesis.number_of_regions] >=
            #                                        lsa_propagation_strength[disease_hypothesis.number_of_regions:],
            #                                        lsa_propagation_strength[:disease_hypothesis.number_of_regions],
            #                                        lsa_propagation_strength[disease_hypothesis.number_of_regions:])
            # or
            lsa_propagation_strength = numpy.sqrt(
                lsa_propagation_strength[:disease_hypothesis.number_of_regions]
                **2 +
                lsa_propagation_strength[disease_hypothesis.number_of_regions:]
                **2)
            lsa_propagation_strength = numpy.log10(lsa_propagation_strength)
            lsa_propagation_strength -= lsa_propagation_strength.min()

        if self.normalize_propagation_strength:
            # Normalize by the maximum
            lsa_propagation_strength /= numpy.max(lsa_propagation_strength)

        # # TODO: this has to be corrected
        # if self.eigen_vectors_number < 0.2 * disease_hypothesis.number_of_regions:
        #     propagation_strength_elbow = numpy.max([self.get_curve_elbow_point(lsa_propagation_strength),
        #                                     self.eigen_vectors_number])
        # else:
        propagation_strength_elbow = self.get_curve_elbow_point(
            lsa_propagation_strength)
        propagation_indices = lsa_propagation_strength.argsort(
        )[-propagation_strength_elbow:]

        hypothesis_builder = HypothesisBuilder(disease_hypothesis.number_of_regions). \
                                set_attributes_based_on_hypothesis(disease_hypothesis). \
                                    set_name(disease_hypothesis.name + "_LSA"). \
                                        set_lsa_propagation(propagation_indices, lsa_propagation_strength)

        return hypothesis_builder.build_lsa_hypothesis()
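The eigen-decomposition core of run_lsa, reduced to a self-contained numpy sketch (the Jacobian here is a random stand-in, not a model Jacobian):

    import numpy

    jacobian = numpy.random.randn(5, 5)
    eigen_values, eigen_vectors = numpy.linalg.eig(jacobian)
    eigen_values = numpy.real(eigen_values)
    eigen_vectors = numpy.real(eigen_vectors)
    order = numpy.argsort(eigen_values, kind='mergesort')
    eigen_values = eigen_values[order]
    eigen_vectors = eigen_vectors[:, order]

    n = 2  # number of leading eigenvectors to sum (minimum 1)
    lsa_propagation_strength = numpy.abs(numpy.sum(eigen_vectors[:, :n], axis=1))
    lsa_propagation_strength /= lsa_propagation_strength.max()  # normalize by the maximum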
Code Example #17
def assert_equal_objects(obj1, obj2, attributes_dict=None, logger=None):
    def print_not_equal_message(attr, field1, field2, logger):
        # logger.error("\n\nValueError: Original and read object field "+ attr + " not equal!")
        # raise_value_error("\n\nOriginal and read object field " + attr + " not equal!")
        warning(
            "Original and read object field " + attr + " not equal!" +
            "\nOriginal field:\n" + str(field1) + "\nRead object field:\n" +
            str(field2), logger)

    if isinstance(obj1, dict):
        get_field1 = lambda obj, key: obj[key]
        if not (isinstance(attributes_dict, dict)):
            attributes_dict = dict()
            for key in obj1.keys():
                attributes_dict.update({key: key})
    elif isinstance(obj1, (list, tuple)):
        get_field1 = lambda obj, key: get_list_or_tuple_item_safely(obj, key)
        indices = range(len(obj1))
        attributes_dict = dict(zip([str(ind) for ind in indices], indices))
    else:
        get_field1 = lambda obj, attribute: getattr(obj, attribute)
        if not (isinstance(attributes_dict, dict)):
            attributes_dict = dict()
            for key in obj1.__dict__.keys():
                attributes_dict.update({key: key})
    if isinstance(obj2, dict):
        get_field2 = lambda obj, key: obj.get(key, None)
    elif isinstance(obj2, (list, tuple)):
        get_field2 = lambda obj, key: get_list_or_tuple_item_safely(obj, key)
    else:
        get_field2 = lambda obj, attribute: getattr(obj, attribute, None)

    equal = True
    for attribute in attributes_dict:
        # print attributes_dict[attribute]
        field1 = get_field1(obj1, attributes_dict[attribute])
        field2 = get_field2(obj2, attributes_dict[attribute])
        try:
            # TODO: a better hack for the stupid case of an ndarray of a string, such as model.zmode or pmode
            # For non numeric types
            if isinstance(field1, basestring) or isinstance(field1, list) or isinstance(field1, dict) \
                    or (isinstance(field1, np.ndarray) and field1.dtype.kind in 'OSU'):
                if np.any(field1 != field2):
                    print_not_equal_message(attributes_dict[attribute], field1,
                                            field2, logger)
                    equal = False
            # For numeric numpy arrays:
            elif isinstance(field1,
                            np.ndarray) and not field1.dtype.kind in 'OSU':
                # TODO: handle better accuracy differences, empty matrices and complex numbers...
                if field1.shape != field2.shape:
                    print_not_equal_message(attributes_dict[attribute], field1,
                                            field2, logger)
                    equal = False
                elif np.any(np.abs(np.float32(field1) - np.float32(field2)) > 0):
                    print_not_equal_message(attributes_dict[attribute], field1,
                                            field2, logger)
                    equal = False
            # For numeric scalar types
            elif is_numeric(field1):
                if np.abs(np.float32(field1) - np.float32(field2)) > 0:
                    print_not_equal_message(attributes_dict[attribute], field1,
                                            field2, logger)
                    equal = False
            else:
                equal = assert_equal_objects(field1, field2, logger=logger)
        except:
            try:
                warning(
                    "Comparing str(objects) for field " +
                    str(attributes_dict[attribute]) +
                    " because there was an error!", logger)
                if np.any(str(field1) != str(field2)):
                    print_not_equal_message(attributes_dict[attribute], field1,
                                            field2, logger)
                    equal = False
            except:
                raise_value_error(
                    "ValueError: Something went wrong when trying to compare "
                    + str(attributes_dict[attribute]) + " !", logger)

    return equal
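Why the numeric comparisons above need np.abs: a signed difference only flags field1 > field2 and silently passes field1 < field2:

    import numpy as np

    a = np.float32([1.0, 2.0])
    b = np.float32([1.0, 3.0])
    print(np.any(a - b > 0))           # False: misses the mismatch
    print(np.any(np.abs(a - b) > 0))   # True: catches it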
Code Example #18
    def compute_information_criteria(self,
                                     samples,
                                     nparams=None,
                                     nsamples=None,
                                     ndata=None,
                                     parameters=[],
                                     skip_samples=0,
                                     merge_chains_or_runs_flag=False,
                                     log_like_str='log_likelihood'):
        """

        :param samples: a dictionary of stan outputs or a list of dictionaries for multiple runs/chains
        :param nparams: number of model parameters, it can be inferred from parameters if None
        :param nsamples: number of samples, it can be inferred from loglikelihood if None
        :param ndata: number of data points, it can be inferred from loglikelihood if None
        :param parameters: a list of parameter names, necessary for dic metric computations and in case nparams is None,
                           as well as for aicc, aic and bic computation
        :param merge_chains_or_runs_flag: logical flag for merging seperate chains/runs, default is True
        :param log_like_str: the name of the log likelihood output of stan, default ''log_likelihood
        :return:
        """

        import sys
        sys.path.insert(0, self.config.generic.MODEL_COMPARISON_PATH)
        from information_criteria.ComputeIC import maxlike, aicc, aic, bic, dic, waic
        from information_criteria.ComputePSIS import psisloo

        # if self.fitmethod.find("opt") >= 0:
        #     warning("No model comparison can be computed for optimization method!")
        #     return None

        samples = ensure_list(samples)
        if merge_chains_or_runs_flag and len(samples) > 1:
            samples = ensure_list(
                merge_samples(samples, skip_samples, flatten=True))
            skip_samples = 0

        results = []
        for sample in samples:
            log_likelihood = -1 * sample[log_like_str][skip_samples:]
            log_lik_shape = log_likelihood.shape
            if len(log_lik_shape) > 1:
                target_shape = log_lik_shape[1:]
            else:
                target_shape = (1, )
            if nsamples is None:
                nsamples = log_lik_shape[0]
            elif nsamples != log_likelihood.shape[0]:
                warning("nsamples (" + str(nsamples) +
                        ") is not equal to likelihood.shape[0] (" +
                        str(log_lik_shape[0]) + ")!")

            log_likelihood = np.reshape(log_likelihood, (log_lik_shape[0], -1))
            if log_likelihood.ndim > 1:
                ndata_real = np.maximum(log_likelihood.shape[1], 1)
            else:
                ndata_real = 1
            if ndata is None:
                ndata = ndata_real
            elif ndata != ndata_real:
                warning("ndata (" + str(ndata) +
                        ") is not equal to likelihood.shape[1] (" +
                        str(ndata_real) + ")!")

            result = maxlike(log_likelihood)

            if len(parameters) == 0:
                parameters = [
                    param for param in sample.keys()
                    if param.find("_star") >= 0
                ]
            if len(parameters) > 0:
                nparams_real = 0
                zscore_params = []
                for p in parameters:
                    pval = sample[p][skip_samples:]
                    pzscore = np.array(
                        (pval - np.mean(pval, axis=0)) / np.std(pval, axis=0))
                    if len(pzscore.shape) > 2:
                        pzscore = np.reshape(pzscore, (pzscore.shape[0], -1))
                    zscore_params.append(pzscore)
                    if len(pzscore.shape) > 1:
                        nparams_real += np.maximum(pzscore.shape[1], 1)
                    else:
                        nparams_real += 1
                if nparams is None:
                    nparams = nparams_real
                elif nparams != nparams_real:
                    warning(
                        "nparams (" + str(nparams) +
                        ") is not equal to number of parameters included in the dic computation ("
                        + str(nparams_real) + ")!")
                # TODO: find out how to reduce dic to 1 value, from 1 value per parameter. mean(.) for the moment:
                result['dic'] = np.mean(dic(log_likelihood, zscore_params))
            else:
                warning(
                    "Parameters' names' list is empty and we found no _star parameters! No computation of dic!"
                )

            if nparams is not None:
                result['aicc'] = aicc(log_likelihood, nparams, ndata)
                result['aic'] = aic(log_likelihood, nparams)
                result['bic'] = bic(log_likelihood, nparams, ndata)
            else:
                warning(
                    "Unknown number of parameters! No computation of aicc, aic, bic!"
                )

            result.update(waic(log_likelihood))

            if nsamples > 1:
                result.update(psisloo(log_likelihood))
                result["loos"] = np.reshape(result["loos"], target_shape)
                result["ks"] = np.reshape(result["ks"], target_shape)
            else:
                result.pop('p_waic', None)

            for metric, value in result.items():
                result[metric] = value * np.ones(1, )

            results.append(result)

        if len(results) == 1:
            return results[0]
        else:
            return list_of_dicts_to_dicts_of_ndarrays(results)
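The criteria themselves come from the external information_criteria package. For orientation only, the textbook definitions (an assumption; the package's exact conventions may differ), given the maximum log likelihood over posterior samples:

    import numpy as np

    def aic(max_log_like, nparams):
        return 2.0 * nparams - 2.0 * max_log_like

    def bic(max_log_like, nparams, ndata):
        return nparams * np.log(ndata) - 2.0 * max_log_like

    log_likelihood = np.random.randn(100) - 50.0  # (samples,) stand-in
    print(aic(log_likelihood.max(), nparams=6))
    print(bic(log_likelihood.max(), nparams=6, ndata=200))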
Code Example #19
 def get_parameter(self, parameter_name):
     parameter = self.parameters.get(parameter_name, None)
     if parameter is None:
         warning("Ground truth value for parameter " + parameter_name +
                 " was not found!")
     return parameter
Code Example #20
    def plot_array_model_comparison(self, model_comps, title_prefix="", metrics=["loos", "ks"], labels=[], xdata=None,
                                    xlabel="", figsize=FiguresConfig.VERY_LARGE_SIZE[::-1], figure_name=None):

        def arrange_chains_or_runs(metric_data):
            n_chains_or_runs = 1
            for imodel, model in enumerate(metric_data):
                if model.ndim > 2:
                    if model.shape[0] > n_chains_or_runs:
                        n_chains_or_runs = model.shape[0]
                else:
                    metric_data[imodel] = numpy.expand_dims(model, axis=0)
            return metric_data

        colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
        n_colors = len(colorcycle)
        metrics = [metric for metric in metrics if metric in model_comps.keys()]
        figs=[]
        axs = []
        for metric in metrics:
            if isinstance(model_comps[metric], dict):
                # Multiple models as a list of np.arrays of chains x data
                metric_data = model_comps[metric].values()
                model_names = model_comps[metric].keys()
            else:
                # Single models as a one element list of one np.array of chains x data
                metric_data = [model_comps[metric]]
                model_names = [""]
            metric_data = arrange_chains_or_runs(metric_data)
            n_models = len(metric_data)
            for jj in range(n_models):
                # Necessary because ks gets infinite sometimes...
                temp = metric_data[jj] == numpy.inf
                if numpy.all(temp):
                    warning("All values are inf for metric " + metric + " of model " + model_names[ii] + "!\n")
                    return
                elif numpy.any(temp):
                    warning("Inf values found for metric " + metric + " of model " + model_names[ii] + "!\n" +
                            "Substituting them with the maximum non-infite value!")
                    metric_data[jj][temp] = metric_data[jj][~temp].max()
            n_subplots = metric_data[0].shape[1]
            n_labels = len(labels)
            if n_labels != n_subplots:
                if n_labels != 0:
                    warning("Ignoring labels because their number (" + str(n_labels) +
                            ") is not equal to the number of row subplots (" + str(n_subplots) + ")!")
                labels = [str(ii + 1) for ii in range(n_subplots)]
            if xdata is None:
                xdata = numpy.arange(metric_data[jj].shape[-1])
            else:
                xdata = xdata.flatten()
            xdata0 = numpy.concatenate([numpy.reshape(xdata[0] - 0.1*(xdata[-1]-xdata[0]), (1,)), xdata])
            xdata1 = xdata[-1] + 0.1 * (xdata[-1] - xdata[0])
            if len(title_prefix) > 0:
                title = title_prefix + ": " + metric
            else:
                title = metric
            figs.append(pyplot.figure(title, figsize=figsize))
            figs[-1].suptitle(title)
            figs[-1].set_label(title)
            gs = gridspec.GridSpec(n_subplots, n_models)
            axes = numpy.empty((n_subplots, n_models), dtype="O")
            for ii in range(n_subplots-1,-1, -1):
                for jj in range(n_models):
                    if ii < n_subplots-1:
                        if jj > 0:
                            axes[ii, jj] = pyplot.subplot(gs[ii, jj], sharex=axes[n_subplots-1, jj], sharey=axes[ii, 0])
                        else:
                            axes[ii, jj] = pyplot.subplot(gs[ii, jj], sharex=axes[n_subplots-1, jj])
                    else:
                        if jj > 0:
                            axes[ii, jj] = pyplot.subplot(gs[ii, jj], sharey=axes[ii, 0])
                        else:
                            axes[ii, jj] = pyplot.subplot(gs[ii, jj])
                    n_chains_or_runs = metric_data[jj].shape[0]
                    for kk in range(n_chains_or_runs):
                        c = colorcycle[kk % n_colors]
                        axes[ii, jj].plot(xdata, metric_data[jj][kk][ii, :],  label="chain/run " + str(kk + 1),
                                          marker="o", markersize=5, markeredgecolor=c, markerfacecolor=c,
                                          linestyle="-", linewidth=1)
                        m = numpy.nanmean(metric_data[jj][kk][ii, :])
                        axes[ii, jj].plot(xdata0, m * numpy.ones(xdata0.shape), color=c, linewidth=1)
                        axes[ii, jj].text(xdata0[0], 1.1 * m, 'mean=%0.2f' % m, ha='center', va='bottom', color=c)
                    axes[ii, jj].set_xlabel(xlabel)
                    if ii == 0:
                        axes[ii, jj].set_title(model_names[jj])
                        if n_chains_or_runs > 1:
                            axes[ii, jj].legend()
                if ii == n_subplots-1:
                    axes[ii, 0].autoscale()  # tight=True
                    axes[ii, 0].set_xlim([xdata0[0], xdata1])  # tight=True
            # figs[-1].tight_layout()
            self._save_figure(figs[-1], figure_name)
            self._check_show()
            axs.append(axes)
        return tuple(figs), tuple(axs)
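The inf-substitution step in isolation: replace infinite entries with the largest finite value so the plots stay scaled (a minimal numpy sketch):

    import numpy

    ks = numpy.array([0.2, numpy.inf, 0.7])
    temp = ks == numpy.inf
    if numpy.any(temp) and not numpy.all(temp):
        ks[temp] = ks[~temp].max()
    print(ks)  # [0.2 0.7 0.7]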
Code Example #21
def run_fitting(probabilistic_model,
                stan_model_name,
                model_data,
                target_data,
                config,
                head=None,
                seizure_indices=[],
                pair_plot_params=[
                    "tau1", "tau0", "K", "sigma", "epsilon", "scale", "offset"
                ],
                region_violin_params=["x0", "PZ", "x1eq", "zeq"],
                fit_flag=True,
                test_flag=False,
                base_path="",
                fitmethod="sample",
                n_chains_or_runs=2,
                output_samples=200,
                num_warmup=100,
                min_samples_per_chain=200,
                max_depth=15,
                delta=0.95,
                iter=500000,
                tol_rel_obj=1e-6,
                debug=1,
                simulate=0,
                step_prefix='',
                writer=None,
                plotter=None,
                **kwargs):

    # ------------------------------Stan model and service--------------------------------------
    model_code_path = os.path.join(config.generic.PROBLSTC_MODELS_PATH,
                                   stan_model_name + ".stan")
    stan_interface = CmdStanInterface(model_name=stan_model_name,
                                      model_dir=base_path,
                                      model_code_path=model_code_path,
                                      fitmethod=fitmethod,
                                      config=config)
    stan_interface.set_or_compile_model()
    stan_interface.model_data_path = os.path.join(base_path, "ModelData.h5")

    # -------------------------- Fit and get estimates: ------------------------------------------------------------
    n_chains_or_runs = np.where(test_flag, 2, n_chains_or_runs)
    output_samples = np.where(
        test_flag, 20,
        max(int(np.round(output_samples * 1.0 / n_chains_or_runs)),
            min_samples_per_chain))

    # Sampling (HMC)
    num_samples = output_samples
    num_warmup = np.where(test_flag, 30, num_warmup)
    max_depth = np.where(test_flag, 7, max_depth)
    delta = np.where(test_flag, 0.8, delta)
    # ADVI or optimization:
    iter = np.where(test_flag, 1000, iter)
    if fitmethod.find("sampl") >= 0:
        skip_samples = num_warmup
    else:
        skip_samples = 0
    prob_model_name = probabilistic_model.name.split(".")[0]
    if fit_flag:
        estimates, samples, summary = stan_interface.fit(
            debug=debug,
            simulate=simulate,
            model_data=model_data,
            n_chains_or_runs=n_chains_or_runs,
            refresh=1,
            iter=iter,
            tol_rel_obj=tol_rel_obj,
            output_samples=output_samples,
            num_warmup=num_warmup,
            num_samples=num_samples,
            max_depth=max_depth,
            delta=delta,
            save_warmup=1,
            plot_warmup=1,
            output_path=base_path,
            **kwargs)
        # TODO: check if write_dictionary is enough for estimates, samples, summary and info_crit
        if writer:
            writer.write_list_of_dictionaries(
                estimates, path(prob_model_name + "_FitEst", base_path))
            writer.write_list_of_dictionaries(
                samples, path(prob_model_name + "_FitSamples", base_path))
            if summary is not None:
                writer.write_dictionary(
                    summary, path(prob_model_name + "_FitSummary", base_path))
    else:
        stan_interface.set_output_files(base_path=base_path, update=True)
        estimates, samples, summary = stan_interface.read_output()

    # Model comparison:
    # scale_signal, offset_signal, time_scale, epsilon, sigma -> 5 (+ K = 6)
    # x0[active] -> probabilistic_model.model.number_of_active_regions
    # x1init[active], zinit[active] -> 2 * probabilistic_model.number_of_active_regions
    # dZt[active, t] -> probabilistic_model.number_of_active_regions * (probabilistic_model.time_length-1)
    number_of_total_params = \
        5 + probabilistic_model.number_of_active_regions * (3 + (probabilistic_model.time_length - 1))
    info_crit = \
        stan_interface.compute_information_criteria(samples, number_of_total_params, skip_samples=skip_samples,
                                                    # parameters=["amplitude_star", "offset_star", "epsilon_star",
                                                    #                  "sigma_star", "time_scale_star", "x0_star",
                                                    #                  "x_init_star", "z_init_star", "z_eta_star"],
                                                    merge_chains_or_runs_flag=False)

    if writer:
        writer.write_dictionary(info_crit,
                                path(prob_model_name + "_InfoCrit", base_path))

    Rhat = stan_interface.get_Rhat(summary)
    # Interface backwards with INS stan models
    # from tvb_fit.service.model_inversion.vep_stan_dict_builder import convert_params_names_from_ins
    # estimates, samples, Rhat, model_data = \
    #     convert_params_names_from_ins([estimates, samples, Rhat, model_data])
    if fitmethod.find("opt") < 0 and Rhat is not None:
        stats = {"Rhat": Rhat}
    else:
        stats = None

    if plotter:
        # -------------------------- Plot fitting results: ------------------------------------------------------------
        try:
            if fitmethod.find("sampl") >= 0:
                plotter.plot_HMC(samples,
                                 figure_name=step_prefix + prob_model_name +
                                 " HMC NUTS trace")

            plotter.plot_fit_results(
                estimates,
                samples,
                model_data,
                target_data,
                probabilistic_model,
                info_crit,
                stats=stats,
                seizure_indices=seizure_indices,
                pair_plot_params=pair_plot_params,  #
                region_violin_params=region_violin_params,
                region_labels=head.connectivity.region_labels,
                skip_samples=skip_samples,
                title_prefix=step_prefix + prob_model_name)
        except:
            warning("Fitting plotting failed for step %s" % step_prefix)

    return estimates, samples, summary, info_crit
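The per-chain sample split near the top of run_fitting, in isolation: the requested total is divided across chains but never drops below min_samples_per_chain:

    import numpy as np

    output_samples, n_chains_or_runs, min_samples_per_chain = 200, 2, 200
    per_chain = max(int(np.round(output_samples * 1.0 / n_chains_or_runs)),
                    min_samples_per_chain)
    print(per_chain)  # 200: the floor wins over 200 / 2 = 100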
Code Example #22
    def plot_spectral_analysis_raster(self,
                                      time,
                                      data,
                                      time_units="ms",
                                      freq=None,
                                      spectral_options={},
                                      special_idx=[],
                                      labels=[],
                                      title='Spectral Analysis',
                                      figure_name=None,
                                      figsize=FiguresConfig.VERY_LARGE_SIZE):
        nS = data.shape[1]
        n_special_idx = len(special_idx)
        if n_special_idx > 0:
            data = data[:, special_idx]
            nS = data.shape[1]
            if len(labels) > n_special_idx:
                labels = numpy.array([
                    str(ilbl) + ". " + str(labels[ilbl])
                    for ilbl in special_idx
                ])
            elif len(labels) == n_special_idx:
                labels = numpy.array([
                    str(ilbl) + ". " + str(label)
                    for ilbl, label in zip(special_idx, labels)
                ])
            else:
                labels = numpy.array([str(ilbl) for ilbl in special_idx])
        else:
            if len(labels) != nS:
                labels = numpy.array([str(ilbl) for ilbl in range(nS)])
        if nS > 20:
            warning(
                "It is not possible to plot spectral analysis plots for more than 20 signals!"
            )
            return
        if not isinstance(time_units, basestring):
            time_units = list(time_units)[0]
        time_units = ensure_string(time_units)
        if time_units in ("ms", "msec"):
            fs = 1000.0
        else:
            fs = 1.0
        fs = fs / numpy.mean(numpy.diff(time))
        log_norm = spectral_options.get("log_norm", False)
        mode = spectral_options.get("mode", "psd")
        psd_label = mode
        if log_norm:
            psd_label = "log" + psd_label
        stf, time, freq, psd = time_spectral_analysis(
            data,
            fs,
            freq=freq,
            mode=mode,
            nfft=spectral_options.get("nfft"),
            window=spectral_options.get("window", 'hanning'),
            nperseg=spectral_options.get("nperseg", int(numpy.round(fs / 4))),
            detrend=spectral_options.get("detrend", 'constant'),
            noverlap=spectral_options.get("noverlap"),
            f_low=spectral_options.get("f_low", 10.0),
            log_scale=spectral_options.get("log_scale", False))
        min_val = numpy.min(stf.flatten())
        max_val = numpy.max(stf.flatten())
        if nS > 2:
            figsize = FiguresConfig.VERY_LARGE_SIZE
        fig = pyplot.figure(title, figsize=figsize)
        fig.suptitle(title)
        gs = gridspec.GridSpec(nS, 23)
        ax = numpy.empty((nS, 2), dtype="O")
        img = numpy.empty((nS, ), dtype="O")
        line = numpy.empty((nS, ), dtype="O")
        for iS in range(nS - 1, -1, -1):
            if iS < nS - 1:
                # upper rows share the axes of the bottom row, which is created first
                ax[iS, 0] = pyplot.subplot(gs[iS, :20], sharex=ax[nS - 1, 0])
                ax[iS, 1] = pyplot.subplot(gs[iS, 20:22],
                                           sharex=ax[nS - 1, 1],
                                           sharey=ax[iS, 0])
            else:
                ax[iS, 0] = pyplot.subplot(gs[iS, :20])
                ax[iS, 1] = pyplot.subplot(gs[iS, 20:22], sharey=ax[iS, 0])
            img[iS] = ax[iS, 0].imshow(numpy.squeeze(stf[:, :, iS]).T,
                                       cmap=pyplot.set_cmap('jet'),
                                       interpolation='none',
                                       norm=Normalize(vmin=min_val,
                                                      vmax=max_val),
                                       aspect='auto',
                                       origin='lower',
                                       extent=(time.min(), time.max(),
                                               freq.min(), freq.max()))
            # img[iS].clim(min_val, max_val)
            ax[iS, 0].set_title(labels[iS])
            ax[iS, 0].set_ylabel("Frequency (Hz)")
            line[iS] = ax[iS, 1].plot(psd[:, iS], freq, 'k', label=labels[iS])
            pyplot.setp(ax[iS, 1].get_yticklabels(), visible=False)
            # ax[iS, 1].yaxis.tick_right()
            # ax[iS, 1].yaxis.set_ticks_position('both')
            if iS == (nS - 1):
                ax[iS, 0].set_xlabel("Time (" + time_units + ")")

                ax[iS, 1].set_xlabel(psd_label)
            else:
                pyplot.setp(ax[iS, 0].get_xticklabels(), visible=False)
            pyplot.setp(ax[iS, 1].get_xticklabels(), visible=False)
            ax[iS, 0].autoscale(tight=True)
            ax[iS, 1].autoscale(tight=True)
        # make a color bar
        cax = pyplot.subplot(gs[:, 22])
        pyplot.colorbar(img[0], cax=cax)  # fraction=0.046, pad=0.04 # fraction=0.15, shrink=1.0
        cax.set_title(psd_label)
        self._save_figure(pyplot.gcf(), figure_name)
        self._check_show()
        return fig, ax, img, line, time, freq, stf, psd
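The sampling-rate logic in isolation: with time stamps in ms, fs comes out in Hz and can feed any spectral estimator. A minimal sketch with scipy.signal.welch (the sine signal is a stand-in):

    import numpy as np
    from scipy.signal import welch

    time = np.arange(0.0, 1000.0, 1.0)  # ms, i.e. 1 kHz sampling
    fs = 1000.0 / np.mean(np.diff(time))  # 1000.0 Hz
    x = np.sin(2 * np.pi * 12.0 * time / 1000.0)  # 12 Hz sine
    freq, psd = welch(x, fs=fs, nperseg=int(np.round(fs / 4)))
    print(freq[np.argmax(psd)])  # 12.0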