class ParameterSpacePoint(SimpleParameterSpacePoint):
    def __init__(self,
                 n_grc_dend,
                 connectivity_rule,
                 input_spatial_correlation_scale,
                 active_mf_fraction,
                 gaba_scale,
                 dta,
                 inh_cond_scaling,
                 exc_cond_scaling,
                 modulation_frequency,
                 stim_rate_mu,
                 stim_rate_sigma,
                 noise_rate_mu,
                 noise_rate_sigma,
                 n_stim_patterns,
                 n_trials,
                 sim_duration,
                 ana_duration,
                 training_size,
                 multineuron_metric_mixing,
                 linkage_method,
                 tau,
                 dt):
        #--analysis-specific coordinates
        self.ana_duration = ana_duration
        self.training_size = int(round(training_size))
        self.multineuron_metric_mixing = multineuron_metric_mixing
        self.linkage_method = int(round(linkage_method))
        self.tau = tau
        self.dt = dt
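        # linkage_method selects the decoding strategy used in
        # run_analysis: 0 -> hierarchical clustering with Ward linkage,
        # 1 -> k-means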
        self.linkage_method_string = ['ward', 'kmeans'][self.linkage_method]
        super(ParameterSpacePoint, self).__init__(n_grc_dend,
                                                  connectivity_rule,
                                                  input_spatial_correlation_scale,
                                                  active_mf_fraction,
                                                  gaba_scale,
                                                  dta,
                                                  inh_cond_scaling,
                                                  exc_cond_scaling,
                                                  modulation_frequency,
                                                  stim_rate_mu,
                                                  stim_rate_sigma,
                                                  noise_rate_mu,
                                                  noise_rate_sigma,
                                                  n_stim_patterns,
                                                  n_trials,
                                                  sim_duration)
        #--useful quantities
        self.sim_transient_time = self.sim_duration - self.ana_duration
        #--archive objects
        self.spikes_arch = SpikesArchive(self)
        self.results_arch = ResultsArchive(self)
    def __repr__(self):
        simple = self.simple_representation()
        simple_args = simple.split('(')[1].split(')')[0]
        return('ParameterSpacePoint({0},{1},{2},{3},{4},{5},{6})'.format(simple_args, self.ana_duration, self.training_size, self.multineuron_metric_mixing, self.linkage_method, self.tau, self.dt))
    def __str__(self):
        analysis_specific_repr = " |@| adur: {0} | train: {1} | mix: {2} | link: {3} | tau: {4} | dt: {5}".format(self.ana_duration, self.training_size, self.multineuron_metric_mixing, self.linkage_method_string, self.tau, self.dt)
        return super(ParameterSpacePoint, self).__str__() + analysis_specific_repr
    def simple_representation(self):
        """Describe the point as if it were a SimpleParameterSpacePoint."""
        return super(ParameterSpacePoint, self).__repr__()
    def simple_representation_without_commas(self):
        return super(ParameterSpacePoint, self).representation_without_commas()
    def representation_without_commas(self):
        # Sanitised version of the point representation, with commas
        # replaced by '+' signs (e.g. 'ParameterSpacePoint(4,...)'
        # becomes 'ParameterSpacePoint(4+...)'). This is needed to work
        # around a known bug in Legion's version of JSV, which chokes
        # on script arguments containing commas.
        return self.__repr__().replace(',', '+')

    def is_spike_archive_compatible(self, path):
        """Check if the archive at the given path, if present, is suitable
        for providing the simulation data necessary to perform the
        analysis specified by this data point. Simulation duration and
        number of stimulus patterns need to be greater in the archive
        than in the analysis settings, while the number of trials that
        can be extracted from the archive can depend (via time
        slicing) on the length of the original simulations compared to
        the length of the analysis duration and the transient time (ie
        sim_duration-ana_duration) we are asking for.

        """
        # parse the simulation duration from the filename. Note that
        # str.rstrip('.hdf5') would also strip trailing digits of the
        # duration that happen to be in the '.hdf5' character set, so
        # the suffix is removed with partition instead.
        path_sdur = float(path.partition('sdur')[2].partition('.hdf5')[0])
        path_n_trials = float(path.rpartition('_t')[2].partition('_sdur')[0]) * (1 + max(0, (path_sdur - self.sim_duration)//(self.ana_duration + self.SIM_DECORRELATION_TIME)))
        path_spn = float(path.rpartition('_t')[0].rpartition('sp')[2])
        sdur_c = path_sdur >= self.sim_duration
        n_trials_c = path_n_trials >= self.n_trials
        spn_c = path_spn >= self.n_stim_patterns
        return all([sdur_c, n_trials_c, spn_c])

    def get_cell_positions(self):
        cell_positions = {'MFs':np.zeros(shape=(self.n_mf, 3)),
                          'GrCs':np.zeros(shape=(self.n_grc, 3))}
        for node in self.network_graph.nodes():
            cell, group_name = self.nC_cell_index_from_graph_node(node)
            node_attrs = self.network_graph.node[node]
            for k, coord in enumerate(['x', 'y', 'z']):
                cell_positions[group_name][cell, k] = node_attrs[coord]
        return cell_positions

    #-------------------
    # Simulation methods
    #-------------------
    
    #-------------------
    # Compression methods
    #-------------------
    def run_compression(self):
        pass
    #-------------------
    # Analysis methods
    #-------------------
    def run_analysis(self):
        if self.results_arch.load():
            # we have the results already (loaded in memory or on the disk)
            pass
        else:
            # check if the spikes archive to analyse is actually present on disk
            if not os.path.isfile(self.spike_archive_path):
                raise Exception("Spike archive {} not found! aborting analysis.".format(self.spike_archive_path))
            # we actually need to calculate them
            print("Analysing for: {0} from spike archive: {1}".format(self, self.spike_archive_path))
            n_obs = self.n_stim_patterns * self.n_trials
            # range of cluster counts to analyse; min and max are
            # currently both equal to n_stim_patterns, so the decoding
            # loop below runs for a single cluster count
            min_clusts_analysed = int(round(self.n_stim_patterns * 1.0))
            max_clusts_analysed = int(round(self.n_stim_patterns * 1.0))
            clusts_step = max(int(round(self.n_stim_patterns * 0.05)), 1)
            # Choose training and testing sets: trials are picked at
            # random, but every stimulus pattern is represented by the
            # same number of trials in both sets. Observations are
            # ordered by stimulus pattern.
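            # Observation indices follow the convention
            # index = trial + n_trials * stim_pattern, so (for example,
            # with n_trials=4 and training_size=2) pattern 0 contributes
            # two random indices from {0, 1, 2, 3}, pattern 1 two from
            # {4, 5, 6, 7}, and so on.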
            n_tr_obs_per_sp = self.training_size
            n_ts_obs_per_sp = self.n_trials - n_tr_obs_per_sp
            train_idxs = list(itertools.chain(*([x+self.n_trials*sp for x in random.sample(range(self.n_trials), n_tr_obs_per_sp)] for sp in range(self.n_stim_patterns))))
            test_idxs = [x for x in range(n_obs) if x not in train_idxs]
            n_tr_obs = len(train_idxs)
            n_ts_obs = len(test_idxs)
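            # stimulus-side quantities for pyentropy's
            # SortedDiscreteSystem (used below): Ym is the size of the
            # stimulus alphabet, Ny the number of test trials available
            # for each stimulus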
            Ym = self.n_stim_patterns
            Ny = np.array([n_ts_obs_per_sp for each in range(self.n_stim_patterns)])
            Xn = 1 # the output is effectively one-dimensional
            # initialize data structures for storage of results
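            # (each MI array is indexed by n_clusts-1 and sized n_obs,
            # so it can accommodate any cluster count up to the total
            # number of observations)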
            ts_decoded_mi_plugin = np.zeros(n_obs)
            ts_decoded_mi_qe = np.zeros(n_obs)
            ts_decoded_mi_pt = np.zeros(n_obs)
            ts_decoded_mi_nsb = np.zeros(n_obs)

            # compute mutual information by using direct clustering on training data (REMOVED)
            # --note: fcluster doesn't work in the border case with n_clusts=n_obs, as it never returns the trivial clustering. Cluster number 0 is never present in a clustering.
            print('counting spikes in input and output spike trains')
            i_level_array = self.spikes_arch.get_spike_counts(cell_type='mf')
            o_level_array = self.spikes_arch.get_spike_counts(cell_type='grc')
            print('computing mean input and output spike counts')
            i_mean_count = i_level_array.mean()
            o_mean_count = o_level_array.mean()
            print('computing input and output sparsity')
            i_sparseness_hoyer = hoyer_sparseness(i_level_array)
            i_sparseness_activity = activity_sparseness(i_level_array)
            i_sparseness_vinje = vinje_sparseness(i_level_array)
            o_sparseness_hoyer = hoyer_sparseness(o_level_array)
            o_sparseness_activity = activity_sparseness(o_level_array)
            o_sparseness_vinje = vinje_sparseness(o_level_array)
            print('input sparseness: hoyer {:.2f}, vinje {:.2f}, activity {:.2f}'.format(i_sparseness_hoyer, i_sparseness_vinje, i_sparseness_activity))
            print('output sparseness: hoyer {:.2f}, vinje {:.2f}, activity {:.2f}'.format(o_sparseness_hoyer, o_sparseness_vinje, o_sparseness_activity))
            if self.linkage_method_string == 'kmeans':
                spike_counts = o_level_array
                # divide spike count data in training and testing set
                tr_spike_counts = np.array([spike_counts[o] for o in train_idxs])
                ts_spike_counts = np.array([spike_counts[o] for o in test_idxs])
                for n_clusts in range(min_clusts_analysed, max_clusts_analysed+1, clusts_step):
                    clustering = KMeans(n_clusters=n_clusts)
                    print('performing k-means clustering on training set (training the decoder) for k='+str(n_clusts))
                    clustering.fit(tr_spike_counts)
                    print('using the decoder trained with k-means clustering to classify data points in testing set')
                    decoded_output = clustering.predict(ts_spike_counts)
                    # calculate MI
                    print('calculating MI')
                    Xm = n_clusts
                    X_dims = (Xn, Xm)
                    X = decoded_output
                    s = pe.SortedDiscreteSystem(X, X_dims, Ym, Ny)
                    s.calculate_entropies(method='plugin', calc=['HX', 'HXY'])
                    ts_decoded_mi_plugin[n_clusts-1] = s.I()
                    s.calculate_entropies(method='qe', sampling='naive', calc=['HX', 'HXY'], qe_method='plugin')
                    ts_decoded_mi_qe[n_clusts-1] = s.I()
                    s.calculate_entropies(method='pt', sampling='naive', calc=['HX', 'HXY'])
                    ts_decoded_mi_pt[n_clusts-1] = s.I()
                    s.calculate_entropies(method='nsb', sampling='naive', calc=['HX', 'HXY'])
                    ts_decoded_mi_nsb[n_clusts-1] = s.I()            
            else:
                # pymuvr is only needed for the hierarchical clustering
                # branch, so it is imported lazily here
                import pymuvr
                spikes = self.spikes_arch.get_spikes(cell_type='grc')
                self.spikes_arch.load_attrs()
                tr_spikes = [spikes[o] for o in train_idxs]
                ts_spikes = [spikes[o] for o in test_idxs]

                # compute multineuron distance between each pair of training observations
                print('calculating distances between training observations')
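                # pymuvr computes multi-unit van Rossum-type distances;
                # each observation is expected to be a list of per-cell
                # spike time lists, and square_distance_matrix returns
                # the full matrix of pairwise distances between training
                # observations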
                tr_distances = pymuvr.square_distance_matrix(tr_spikes,
                                                             self.multineuron_metric_mixing,
                                                             self.tau)
                # cluster training data
                print('clustering training data')
                tr_tree = linkage(tr_distances, method=self.linkage_method_string)

                # train the decoder and use it to calculate mi on the testing dataset
                print("training the decoder and using it to calculate mi on test data")

                tr_distances_square = np.square(tr_distances)

                for n_clusts in range(min_clusts_analysed, max_clusts_analysed+1):
                    # iterate over the number of clusters and, step by
                    # step, train the decoder and use it to calculate mi
                    tr_clustering = fcluster(tr_tree, t=n_clusts, criterion='maxclust')
                    out_alphabet = []
                    for c in range(1,n_clusts+1):
                        # every cluster is represented in the output
                        # alphabet by the element which minimizes the sum
                        # of intra-cluster square distances
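                        # (i.e. the element acting as the cluster medoid
                        # under the multi-unit distance)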
                        obs_in_c = [ob for ob in range(n_tr_obs) if tr_clustering[ob]==c]
                        sum_of_intracluster_square_distances = tr_distances_square[obs_in_c,:][:,obs_in_c].sum(axis=1)
                        out_alphabet.append(tr_spikes[np.argmin(sum_of_intracluster_square_distances)])
                    distances = pymuvr.distance_matrix(ts_spikes,
                                                       out_alphabet,
                                                       self.multineuron_metric_mixing,
                                                       self.tau)
                    # each observation in the testing set is decoded by
                    # assigning it to the cluster whose representative
                    # element it's closest to
                    decoded_output = distances.argmin(axis=1)
                    # calculate MI
                    Xm = n_clusts
                    X_dims = (Xn, Xm)
                    X = decoded_output
                    s = pe.SortedDiscreteSystem(X, X_dims, Ym, Ny)
                    s.calculate_entropies(method='qe', sampling='naive', calc=['HX', 'HXY'], qe_method='plugin')
                    ts_decoded_mi_qe[n_clusts-1] = s.I()
                    s.calculate_entropies(method='pt', sampling='naive', calc=['HX', 'HXY'])
                    ts_decoded_mi_pt[n_clusts-1] = s.I()
                    s.calculate_entropies(method='nsb', sampling='naive', calc=['HX', 'HXY'])
                    ts_decoded_mi_nsb[n_clusts-1] = s.I()            
                    if n_clusts == self.n_stim_patterns:
                        px_at_same_size_point = s.PX 
                # save linkage tree to results archive (only if
                # performing hierarchical clustering)
                self.results_arch.update_result('tr_linkage', data=tr_tree)

            # save analysis results in the archive
            print('updating results archive')
            self.results_arch.update_result('ts_decoded_mi_plugin', data=ts_decoded_mi_plugin)
            self.results_arch.update_result('ts_decoded_mi_qe', data=ts_decoded_mi_qe)
            self.results_arch.update_result('ts_decoded_mi_pt', data=ts_decoded_mi_pt)
            self.results_arch.update_result('ts_decoded_mi_nsb', data=ts_decoded_mi_nsb)

            self.results_arch.update_result('i_mean_count', data=i_mean_count)
            self.results_arch.update_result('o_mean_count', data=o_mean_count)
            self.results_arch.update_result('i_sparseness_hoyer', data=i_sparseness_hoyer)
            self.results_arch.update_result('i_sparseness_activity', data=i_sparseness_activity)
            self.results_arch.update_result('i_sparseness_vinje', data=i_sparseness_vinje)
            self.results_arch.update_result('o_sparseness_hoyer', data=o_sparseness_hoyer)
            self.results_arch.update_result('o_sparseness_activity', data=o_sparseness_activity)
            self.results_arch.update_result('o_sparseness_vinje', data=o_sparseness_vinje)
            # update attributes
            self.results_arch.load()