Example #1
 def set_shank(self, shank):
     """Change the current shank and read the corresponding tables."""
     if shank not in self.shanks:
         warn("Shank {0:d} is not in the list of shanks: {1:s}".format(
             shank, str(self.shanks)))
     self.shank = shank
     self.shank_path = '/shanks/shank{0:d}'.format(self.shank)
Example #2
 def __getattr__(self, key):
     # Do not override if key is an attribute of this class.
     if key.startswith('_'):
         try:
             return self.__dict__[key]
         # Accept nodewrapper._method if _method is a method of the PyTables
         # Node object.
         except KeyError:
             return getattr(self._node, key)
     try:
         # Return the wrapped node if the child is a group.
         attr = getattr(self._node, key)
         if isinstance(attr, tb.Group):
             return NodeWrapper(attr)
         else:
             return attr
     # Otherwise, fall back to the node's HDF5 attribute.
     except Exception:
         try:
             return self._node._f_getAttr(key)
         except AttributeError:
             # NOTE: old format
             if key == 'n_features_per_channel':
                 return self._node._f_getAttr('nfeatures_per_channel')
             warn(("{key} needs to be an attribute of "
                  "{node}").format(key=key, node=self._node._v_name))
             return None
Example #3
 def __getattr__(self, key):
     # Do not override if key is an attribute of this class.
     if key.startswith('_'):
         try:
             return self.__dict__[key]
         # Accept nodewrapper._method if _method is a method of the PyTables
         # Node object.
         except KeyError:
             return getattr(self._node, key)
     try:
         # Return the wrapped node if the child is a group.
         attr = getattr(self._node, key)
         if isinstance(attr, tb.Group):
             return NodeWrapper(attr)
         else:
             return attr
     # Otherwise, fall back to the node's HDF5 attribute.
     except Exception:
         try:
             return self._node._f_getAttr(key)
         except AttributeError:
             # NOTE: old format
             if key == 'n_features_per_channel':
                 return self._node._f_getAttr('nfeatures_per_channel')
             warn(("{key} needs to be an attribute of "
                   "{node}").format(key=key, node=self._node._v_name))
             return None
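
The two `__getattr__` wrappers above delegate unknown attribute lookups to a wrapped PyTables node: child groups come back wrapped in another NodeWrapper, and anything that is not a child node falls through to the node's HDF5 attributes via `_f_getAttr`. A minimal usage sketch, assuming a `NodeWrapper(node)` constructor that stores the node in `self._node` and the PyTables 2.x camelCase API used above; the file name, group, and attribute are made up for illustration:

import tables as tb

# Build a tiny HDF5 file to wrap (illustrative names only).
f = tb.openFile('demo.h5', 'w')
f.createGroup('/', 'shanks')
f.setNodeAttr('/shanks', 'n_features_per_channel', 3)
f.close()

f = tb.openFile('demo.h5', 'r')
root = NodeWrapper(f.root)
shanks = root.shanks                  # child group -> returned as another NodeWrapper
print(shanks.n_features_per_channel)  # not a child node -> read from the HDF5 attributes: 3
f.close()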
Example #4
 def read_res(self):
     try:
         self.spiketimes_res = read_res(self.filename_res, self.freq)
         self.spiketimes_res = pd.Series(self.spiketimes_res,
                                         dtype=np.float32)
     except IOError:
         warn("The RES file is missing.")
Example #5
    def compute_cluster_statistics(self, spikes_in_clusters):
        """Compute the statistics of all clusters."""

        nspikes, ndims = self.features.shape
        nclusters = len(spikes_in_clusters)
        LogP = np.zeros((nspikes, nclusters))
        stats = {}

        for c in spikes_in_clusters:
            # "my" refers to "my cluster"
            myspikes = spikes_in_clusters[c]
            myfeatures = np.take(self.y, myspikes, axis=0).astype(np.float64)
            nmyspikes = len(myfeatures)
            mymasks = np.take(self.masks, myspikes, axis=0)
            mymean = np.mean(myfeatures, axis=0).reshape((1, -1))
            # Boolean vector of size (nchannels,): which channels are unmasked?
            unmask = ((mymasks>0).sum(axis=0) > self.unmask_threshold)
            mask = ~unmask
            nunmask = np.sum(unmask)
            if nmyspikes <= 1 or nunmask == 0:
                mymean = np.zeros((1, myfeatures.shape[1]))
                covmat = 1e-3 * np.eye(nunmask)  # optim: nactivefeatures
                stats[c] = (mymean, covmat,
                            (1e-3)**ndims, nmyspikes,
                            np.zeros(ndims, dtype=np.bool)  # unmask
                            )
                continue

            # optimization: covmat only for submatrix of active features
            covmat = np.cov(myfeatures[:, unmask], rowvar=0) # stats for cluster c

            # Variational Bayesian approximation
            priorpoint = 1
            covmat *= (nmyspikes - 1)  # get rid of the normalization factor
            covmat += self.D[unmask, unmask] * priorpoint  # D = np.diag(sigma2.ravel())
            covmat /= (nmyspikes + priorpoint - 1)

            # the eta just for the current cluster
            etac = np.take(self.eta, myspikes, axis=0)
            # optimization: etac just for active features
            etac = etac[:, unmask]
            d = np.mean(etac, axis=0)

            # Handle nmasked == 0
            d[np.isnan(d)] = 0

            # add diagonal
            covmat += np.diag(d)

            # Compute the det of the covmat
            _sign, logdet = np.linalg.slogdet(covmat)
            if _sign < 0:
                warn("The correlation matrix of cluster %d has a negative determinant (whaaat??)" % c)

            stats[int(c)] = (mymean, covmat, logdet, nmyspikes, unmask)

        self.stats.update(stats)
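
The main optimization in `compute_cluster_statistics` above is to restrict the covariance computation to the "active" features: those unmasked in more than `unmask_threshold` spikes of the cluster. A self-contained NumPy sketch of that selection, with toy numbers (the threshold value is an assumption):

import numpy as np

# Toy data: 6 spikes, 4 features; only the first two features are ever unmasked.
rng = np.random.RandomState(0)
features = rng.randn(6, 4)
masks = np.zeros((6, 4), dtype=np.float32)
masks[:, 0] = 1.0
masks[:5, 1] = 1.0
unmask_threshold = 2  # plays the role of self.unmask_threshold above

# Features unmasked in more than `unmask_threshold` spikes are kept active.
unmask = (masks > 0).sum(axis=0) > unmask_threshold
print(unmask)                                # [ True  True False False]

# Covariance restricted to the active-feature submatrix.
covmat = np.cov(features[:, unmask], rowvar=0)
print(covmat.shape)                          # (2, 2)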
Example #6
 def __getattr__(self, key):
     try:
         return self.__dict__[key]
     except KeyError:
         try:
             return self._node._f_getAttr(key)
         except AttributeError:
             warn(("{key} needs to be an attribute of "
                  "{node}").format(key=key, node=self._node._v_name))
             return None
Example #7
 def __getattr__(self, key):
     try:
         return self.__dict__[key]
     except KeyError:
         try:
             return self._node._f_getAttr(key)
         except AttributeError:
             warn(("{key} needs to be an attribute of "
                   "{node}").format(key=key, node=self._node._v_name))
             return None
Example #8
 def read_waveforms(self):
     try:
         self.waveforms = read_waveforms(self.filename_spk, self.nsamples,
                                         self.nchannels)
         info("Successfully loaded {0:s}".format(self.filename_spk))
     except IOError:
         warn("The SPK file is missing.")
         self.waveforms = np.zeros((self.nspikes, self.nsamples, 
             self.nchannels))
     # Convert to Pandas.
     self.waveforms = pd.Panel(self.waveforms, dtype=np.float32)
Example #9
def _read_traces(files, dtype=None, n_channels=None):
    kwd_path = None
    dat_path = None
    kwik = files['kwik']

    recordings = kwik.root.recordings
    traces = []
    # opened_files = []
    for recording in recordings:
        # Is there a path specified to a .raw.kwd file which exists in
        # [KWIK]/recordings/[X]/raw? If so, open it.
        raw = recording.raw
        if 'hdf5_path' in raw._v_attrs:
            kwd_path = raw._v_attrs.hdf5_path[:-8]
            kwd = files['raw.kwd']
            if kwd is None:
                debug("%s not found, trying same basename in KWIK dir" %
                      kwd_path)
            else:
                debug("Loading traces: %s" % kwd_path)
                traces.append(
                    kwd.root.recordings._f_getChild(str(
                        recording._v_name)).data)
                # opened_files.append(kwd)
                continue
        # Is there a path specified to a .dat file which exists?
        if 'dat_path' in raw._v_attrs:
            dtype = kwik.root.application_data.spikedetekt._v_attrs.dtype[0]
            if dtype:
                dtype = np.dtype(dtype)

            n_channels = kwik.root.application_data.spikedetekt._v_attrs. \
                n_channels
            if n_channels:
                n_channels = int(n_channels)

            assert dtype is not None
            assert n_channels
            dat_path = raw._v_attrs.dat_path
            if not op.exists(dat_path):
                debug("%s not found, trying same basename in KWIK dir" %
                      dat_path)
            else:
                debug("Loading traces: %s" % dat_path)
                dat = _dat_to_traces(dat_path,
                                     dtype=dtype,
                                     n_channels=n_channels)
                traces.append(dat)
                # opened_files.append(dat)
                continue

    if not traces:
        warn("No traces found: the waveforms won't be available.")
    return _concatenate_virtual_arrays(traces)
Example #10
 def read_waveforms(self):
     try:
         self.waveforms = read_waveforms(self.filename_spk, self.nsamples,
                                         self.nchannels)
         info("Successfully loaded {0:s}".format(self.filename_spk))
     except IOError:
         warn("The SPK file is missing.")
         self.waveforms = np.zeros((self.nspikes, self.nsamples, 
             self.nchannels))
     # Convert to Pandas.
     self.waveforms = pd.Panel(self.waveforms, dtype=np.float32)
Example #11
 def read_masks(self):
     try:
         self.masks, self.masks_full = read_masks(self.filename_mask,
                                                  self.fetdim)
         info("Successfully loaded {0:s}".format(self.filename_mask))
     except IOError:
         warn("The MASKS/FMASKS file is missing.")
         # Default masks if the MASK/FMASK file is not available.
         self.masks = np.ones((self.nspikes, self.nchannels))
         self.masks_full = np.ones(self.features.shape)
     self.masks = pd.DataFrame(self.masks)
     self.masks_full = pd.DataFrame(self.masks_full)
Example #12
 def read_masks(self):
     try:
         self.masks, self.masks_full = read_masks(self.filename_mask,
                                                  self.fetdim)
         info("Successfully loaded {0:s}".format(self.filename_mask))
     except IOError:
         warn("The MASKS/FMASKS file is missing.")
         # Default masks if the MASK/FMASK file is not available.
         self.masks = np.ones((self.nspikes, self.nchannels))
         self.masks_full = np.ones(self.features.shape)
     self.masks = pd.DataFrame(self.masks)
     self.masks_full = pd.DataFrame(self.masks_full)
Example #13
 def remove_group(self, group):
     """Remove an empty group. Raise an error if the group is not empty."""
     groupidx = group.groupidx()
     # check that the group is empty
     if self.get_channels_in_group(groupidx):
         raise ValueError("group %d is not empty, unable to delete it" % \
                 groupidx)
     groups = [g for g in self.get_groups() if g.groupidx() == groupidx]
     if groups:
         group = groups[0]
         self.remove_node(group)
     else:
         log.warn("Group %d does not exist0" % groupidx)
Example #14
 def remove_group(self, group):
     """Remove an empty group. Raise an error if the group is not empty."""
     groupidx = group.groupidx()
     # check that the group is empty
     if self.get_channels_in_group(groupidx):
         raise ValueError("group %d is not empty, unable to delete it" % \
                 groupidx)
     groups = [g for g in self.get_groups() if g.groupidx() == groupidx]
     if groups:
         group = groups[0]
         self.remove_node(group)
     else:
         log.warn("Group %d does not exist0" % groupidx)
Example #15
def _read_traces(files, dtype=None, n_channels=None):
    kwd_path = None
    dat_path = None
    kwik = files['kwik']

    recordings = kwik.root.recordings
    traces = []
    # opened_files = []
    for recording in recordings:
        # Is there a path specified to a .raw.kwd file which exists in
        # [KWIK]/recordings/[X]/raw? If so, open it.
        raw = recording.raw
        if 'hdf5_path' in raw._v_attrs:
            kwd_path = raw._v_attrs.hdf5_path[:-8]
            kwd = files['raw.kwd']
            if kwd is None:
                debug("%s not found, trying same basename in KWIK dir" %
                      kwd_path)
            else:
                debug("Loading traces: %s" % kwd_path)
                traces.append(kwd.root.recordings._f_getChild(str(recording._v_name)).data)
                # opened_files.append(kwd)
                continue
        # Is there a path specified to a .dat file which exists?
        if 'dat_path' in raw._v_attrs:
            dtype = kwik.root.application_data.spikedetekt._v_attrs.dtype[0]
            if dtype:
                dtype = np.dtype(dtype)

            n_channels = kwik.root.application_data.spikedetekt._v_attrs. \
                n_channels
            if n_channels:
                n_channels = int(n_channels)

            assert dtype is not None
            assert n_channels
            dat_path = raw._v_attrs.dat_path
            if not op.exists(dat_path):
                debug("%s not found, trying same basename in KWIK dir" %
                      dat_path)
            else:
                debug("Loading traces: %s" % dat_path)
                dat = _dat_to_traces(dat_path, dtype=dtype,
                                     n_channels=n_channels)
                traces.append(dat)
                # opened_files.append(dat)
                continue

    if not traces:
        warn("No traces found: the waveforms won't be available.")
    return _concatenate_virtual_arrays(traces)
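
`_dat_to_traces` is not shown in these snippets. A helper like it typically memory-maps a flat binary .dat file of channel-interleaved samples into an (n_samples, n_channels) array; a minimal sketch under that assumption (name and behaviour inferred, not taken from the source):

import numpy as np

def _dat_to_traces(dat_path, dtype=None, n_channels=None):
    """Sketch: map a headerless, channel-interleaved binary file to (n_samples, n_channels)."""
    dtype = np.dtype(dtype if dtype is not None else np.int16)
    data = np.memmap(dat_path, dtype=dtype, mode='r')
    assert data.size % n_channels == 0, "file size is not a multiple of n_channels"
    return data.reshape((-1, n_channels))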
Example #16
    def set_data(self, cluster_groups=None, similarity_matrix=None):
        """Update the data."""
        
        if cluster_groups is not None:
            self.clusters_unique = get_array(get_indices(cluster_groups))
            self.cluster_groups = get_array(cluster_groups)
            
        if (similarity_matrix is not None and similarity_matrix.size > 0):

            if len(get_array(cluster_groups)) != similarity_matrix.shape[0]:
                log.warn(("Cannot update the wizard: cluster_groups "
                    "has {0:d} elements whereas the similarity matrix has {1:d}.").format(
                        len(get_array(cluster_groups)), similarity_matrix.shape[0]))
                return

            self.matrix = similarity_matrix
            self.quality = np.diag(self.matrix)
Example #17
 def read_clusters(self):
     try:
         # Try reading the ACLU file, or fall back to the CLU file.
         if os.path.exists(self.filename_aclu):
             self.clusters = read_clusters(self.filename_aclu)
             info("Successfully loaded {0:s}".format(self.filename_aclu))
         else:
             self.clusters = read_clusters(self.filename_clu)
             info("Successfully loaded {0:s}".format(self.filename_clu))
     except IOError:
         warn("The CLU file is missing.")
         # Default clusters if the CLU file is not available.
         self.clusters = np.zeros(self.nspikes, dtype=np.int32)
     # Convert to Pandas.
     self.clusters = pd.Series(self.clusters, dtype=np.int32)
     
     # Count clusters.
     self._update_data()
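
`read_clusters` itself is not listed here. In the standard Klusters .clu format, the first line gives the number of clusters and each following line gives the cluster label of one spike; a hedged sketch of such a reader (the real loader may handle ACLU files or extra metadata differently):

import numpy as np

def read_clusters(filename):
    """Sketch of a .clu reader: skip the leading cluster-count line,
    return one int32 label per spike."""
    labels = np.loadtxt(filename, dtype=np.int32)
    return labels[1:]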
Example #18
    def _consistency_check(self):
        exp = self.experiment
        chgrp = self.shank

        cg = exp.channel_groups[chgrp]
        clusters = cg.clusters.main.keys()
        clusters_unique = np.unique(cg.spikes.clusters.main[:])

        # Find missing clusters in the kwik file.
        missing = sorted(set(clusters_unique)-set(clusters))

        # Add all missing clusters with a default color and "Unsorted" cluster group (group #3).
        for idx in missing:
            warn("Consistency check: adding cluster %d in the kwik file" % idx)
            add_cluster(exp._files, channel_group_id='%d' % chgrp,
                        id=idx,
                        clustering='main',
                        cluster_group=3)
Example #19
 def read_clusters(self):
     try:
         # Try reading the ACLU file, or fall back to the CLU file.
         if os.path.exists(self.filename_aclu):
             self.clusters = read_clusters(self.filename_aclu)
             info("Successfully loaded {0:s}".format(self.filename_aclu))
         else:
             self.clusters = read_clusters(self.filename_clu)
             info("Successfully loaded {0:s}".format(self.filename_clu))
     except IOError:
         warn("The CLU file is missing.")
         # Default clusters if the CLU file is not available.
         self.clusters = np.zeros(self.nspikes, dtype=np.int32)
     # Convert to Pandas.
     self.clusters = pd.Series(self.clusters, dtype=np.int32)
     
     # Count clusters.
     self._update_data()
Example #20
    def set_data(self, cluster_groups=None, similarity_matrix=None):
        """Update the data."""

        if cluster_groups is not None:
            self.clusters_unique = get_array(get_indices(cluster_groups))
            self.cluster_groups = get_array(cluster_groups)

        if (similarity_matrix is not None and similarity_matrix.size > 0):

            if len(get_array(cluster_groups)) != similarity_matrix.shape[0]:
                log.warn((
                    "Cannot update the wizard: cluster_groups "
                    "has {0:d} elements whereas the similarity matrix has {1:d}."
                ).format(len(get_array(cluster_groups)),
                         similarity_matrix.shape[0]))
                return

            self.matrix = similarity_matrix
            self.quality = np.diag(self.matrix)
Example #21
    def _consistency_check(self):
        exp = self.experiment
        chgrp = self.shank

        cg = exp.channel_groups[chgrp]
        clusters = cg.clusters.main.keys()
        clusters_unique = np.unique(cg.spikes.clusters.main[:])

        # Find missing clusters in the kwik file.
        missing = sorted(set(clusters_unique) - set(clusters))

        # Add all missing clusters with a default color and "Unsorted" cluster group (group #3).
        for idx in missing:
            warn("Consistency check: adding cluster %d in the kwik file" % idx)
            add_cluster(exp._files,
                        channel_group_id='%d' % chgrp,
                        id=idx,
                        clustering='main',
                        cluster_group=3)
Example #22
    def set_shank(self, shank):
        """Change the current shank and read the corresponding tables."""
        if shank not in self.shanks:
            warn("Shank {0:d} is not in the list of shanks: {1:s}".format(
                shank, str(self.shanks)))
            return
        self.shank = shank

        # CONSISTENCY CHECK
        # self._consistency_check()

        self.nchannels = len(
            self.experiment.channel_groups[self.shank].channels)

        clusters = self.experiment.channel_groups[
            self.shank].spikes.clusters.main[:]
        self.clusters = pd.Series(clusters, dtype=np.int32)
        self.nspikes = len(self.clusters)

        self.features = self.experiment.channel_groups[
            self.shank].spikes.features
        self.masks = self.experiment.channel_groups[self.shank].spikes.masks
        self.waveforms = self.experiment.channel_groups[
            self.shank].spikes.waveforms_filtered

        if self.features is not None:
            nfet = self.features.shape[1]
            self.nextrafet = (nfet - self.nchannels * self.fetdim)
        else:
            self.nextrafet = 0

        # Load concatenated time samples: those are the time samples +
        # the start time of the corresponding recordings.
        spiketimes = self.experiment.channel_groups[
            self.shank].spikes.concatenated_time_samples[:] * (1. / self.freq)
        self.spiketimes = pd.Series(spiketimes, dtype=np.float64)
        self.duration = spiketimes[-1]

        self._update_data()

        self.read_clusters()
Example #23
    def set_shank(self, shank):
        """Change the current shank and read the corresponding tables."""
        if shank not in self.shanks:
            warn("Shank {0:d} is not in the list of shanks: {1:s}".format(
                shank, str(self.shanks)))
            return
        self.shank = shank

        # CONSISTENCY CHECK
        # self._consistency_check()

        self.nchannels = len(self.experiment.channel_groups[self.shank].channels)

        clusters = self.experiment.channel_groups[self.shank].spikes.clusters.main[:]
        self.clusters = pd.Series(clusters, dtype=np.int32)
        self.nspikes = len(self.clusters)

        self.features = self.experiment.channel_groups[self.shank].spikes.features
        self.masks = self.experiment.channel_groups[self.shank].spikes.masks
        self.waveforms = self.experiment.channel_groups[self.shank].spikes.waveforms_filtered

        if self.features is not None:
            nfet = self.features.shape[1]
            self.nextrafet = (nfet - self.nchannels * self.fetdim)
        else:
            self.nextrafet = 0

        # Load concatenated time samples: those are the time samples +
        # the start time of the corresponding recordings.
        spiketimes = self.experiment.channel_groups[self.shank].spikes.concatenated_time_samples[:] * (1. / self.freq)
        self.spiketimes = pd.Series(spiketimes, dtype=np.float64)
        self.duration = spiketimes[-1]

        self._update_data()

        self.read_clusters()
Example #24
def compute_statistics(Fet1, Fet2, spikes_in_clusters, masks):
    """Return Gaussian statistics about each cluster."""

    nPoints = Fet1.shape[0] #size(Fet1, 1)
    nDims = Fet1.shape[1] #size(Fet1, 2)
    # nclusters = Clu2.max() #max(Clu2)
    nclusters = len(spikes_in_clusters)

    # Default masks.
    if masks is None:
        masks = np.ones((nPoints, nDims), dtype=np.float32)

    # precompute the mean and variances of the masked points for each feature
    # contains 1 when the corresponding point is masked
    masked = np.zeros_like(masks)
    masked[masks == 0] = 1
    nmasked = np.sum(masked, axis=0)
    nu = np.sum(Fet2 * masked, axis=0) / nmasked
    # Handle nmasked == 0.
    nu[np.isnan(nu)] = 0
    nu = nu.reshape((1, -1))
    sigma2 = np.sum(((Fet2 - nu) * masked) ** 2, axis=0) / nmasked
    sigma2[np.isnan(sigma2)] = 0
    sigma2 = sigma2.reshape((1, -1))
    D = np.diag(sigma2.ravel())
    # expected features
    y = Fet1 * masks + (1 - masks) * nu
    z = masks * Fet1**2 + (1 - masks) * (nu ** 2 + sigma2)
    eta = z - y ** 2

    LogP = np.zeros((nPoints, nclusters))

    stats = {}

    for c in spikes_in_clusters:
        # MyPoints = np.nonzero(Clu2==c)[0]
        MyPoints = spikes_in_clusters[c]
        # MyFet2 = Fet2[MyPoints, :]
        # now, take the modified features here
        # MyFet2 = y[MyPoints, :]
        MyFet2 = np.take(y, MyPoints, axis=0).astype(np.float64)
        MyMasks = np.take(masks, MyPoints, axis=0)
        # if len(MyPoints) > nDims:
        # LogProp = np.log(len(MyPoints) / float(nPoints)) # log of the proportion in cluster c
        Mean = np.mean(MyFet2, axis=0).reshape((1, -1))

        if len(MyPoints) <= 1:
            CovMat = 1e-3*np.eye(nDims)
            stats[c] = (Mean, CovMat, 1e3*np.eye(nDims),
                (1e-3)**nDims, len(MyPoints), np.zeros(nDims, dtype=np.bool))
            continue


        CovMat = np.cov(MyFet2, rowvar=0) # stats for cluster c

        # Variational Bayesian approximation
        priorPoint = 1
        CovMat *= (len(MyFet2) - 1)  # get rid of the normalization factor
        CovMat += D * priorPoint  # D = np.diag(sigma2.ravel())
        CovMat /= (len(MyFet2) + priorPoint - 1)


        # HACK: avoid instability issues, kind of works
        # CovMat += np.diag(1e-0 * np.ones(nDims))

        # now, add the diagonal modification to the covariance matrix
        # the eta just for the current cluster
        etac = np.take(eta, MyPoints, axis=0)
        d = np.mean(etac, axis=0)

        # Handle nmasked == 0
        d[np.isnan(d)] = 0

        # add diagonal
        CovMat += np.diag(d)
        # We don't compute that explicitly anymore: we solve Ax=b instead
        # CovMatinv = np.linalg.inv(CovMat)
        CovMatinv = None

        # WARNING: this is numerically unstable
        # LogDet = np.log(np.linalg.det(CovMat))

        _sign, LogDet = np.linalg.slogdet(CovMat)
        if _sign < 0:
            warn("The correlation matrix of cluster %d has a negative determinant (whaaat??)" % c)

        # Fraction of spikes in the cluster for which each feature is unmasked.
        unmask = (MyMasks > 0).mean(axis=0)

        stats[c] = (Mean, CovMat, CovMatinv, LogDet, len(MyPoints), unmask)

    return stats
Example #25
    def compute_cluster_statistics(self, spikes_in_clusters):
        """Compute the statistics of all clusters."""

        nspikes, ndims = self.features.shape
        nclusters = len(spikes_in_clusters)
        LogP = np.zeros((nspikes, nclusters))
        stats = {}

        for c in spikes_in_clusters:
            # "my" refers to "my cluster"
            myspikes = spikes_in_clusters[c]
            myfeatures = np.take(self.y, myspikes, axis=0).astype(np.float64)
            nmyspikes = len(myfeatures)
            mymasks = np.take(self.masks, myspikes, axis=0)
            mymean = np.mean(myfeatures, axis=0).reshape((1, -1))
            # Boolean vector of size (nchannels,): which channels are unmasked?
            unmask = ((mymasks > 0).sum(axis=0) > self.unmask_threshold)
            mask = ~unmask
            nunmask = np.sum(unmask)
            if nmyspikes <= 1 or nunmask == 0:
                mymean = np.zeros((1, myfeatures.shape[1]))
                covmat = 1e-3 * np.eye(nunmask)  # optim: nactivefeatures
                stats[c] = (
                    mymean,
                    covmat,
                    (1e-3)**ndims,
                    nmyspikes,
                    np.zeros(ndims, dtype=np.bool)  # unmask
                )
                continue

            # optimization: covmat only for submatrix of active features
            covmat = np.cov(myfeatures[:, unmask],
                            rowvar=0)  # stats for cluster c

            # Variational Bayesian approximation
            priorpoint = 1
            covmat *= (nmyspikes - 1)  # get rid of the normalization factor
            covmat += self.D[
                unmask, unmask] * priorpoint  # D = np.diag(sigma2.ravel())
            covmat /= (nmyspikes + priorpoint - 1)

            # the eta just for the current cluster
            etac = np.take(self.eta, myspikes, axis=0)
            # optimization: etac just for active features
            etac = etac[:, unmask]
            d = np.mean(etac, axis=0)

            # Handle nmasked == 0
            d[np.isnan(d)] = 0

            # add diagonal
            covmat += np.diag(d)

            # Compute the det of the covmat
            _sign, logdet = np.linalg.slogdet(covmat)
            if _sign < 0:
                warn("The covariance matrix of cluster %d has a negative "
                     "determinant." % c)

            stats[int(c)] = (mymean, covmat, logdet, nmyspikes, unmask)

        self.stats.update(stats)
Example #26
def compute_statistics(Fet1, Fet2, spikes_in_clusters, masks):
    """Return Gaussian statistics about each cluster."""

    nPoints = Fet1.shape[0]  #size(Fet1, 1)
    nDims = Fet1.shape[1]  #size(Fet1, 2)
    # nclusters = Clu2.max() #max(Clu2)
    nclusters = len(spikes_in_clusters)

    # Default masks.
    if masks is None:
        masks = np.ones((nPoints, nDims), dtype=np.float32)

    # precompute the mean and variances of the masked points for each feature
    # contains 1 when the corresponding point is masked
    masked = np.zeros_like(masks)
    masked[masks == 0] = 1
    nmasked = np.sum(masked, axis=0)
    nu = np.sum(Fet2 * masked, axis=0) / nmasked
    # Handle nmasked == 0.
    nu[np.isnan(nu)] = 0
    nu = nu.reshape((1, -1))
    sigma2 = np.sum(((Fet2 - nu) * masked)**2, axis=0) / nmasked
    sigma2[np.isnan(sigma2)] = 0
    sigma2 = sigma2.reshape((1, -1))
    D = np.diag(sigma2.ravel())
    # expected features
    y = Fet1 * masks + (1 - masks) * nu
    z = masks * Fet1**2 + (1 - masks) * (nu**2 + sigma2)
    eta = z - y**2

    LogP = np.zeros((nPoints, nclusters))

    stats = {}

    for c in spikes_in_clusters:
        # MyPoints = np.nonzero(Clu2==c)[0]
        MyPoints = spikes_in_clusters[c]
        # MyFet2 = Fet2[MyPoints, :]
        # now, take the modified features here
        # MyFet2 = y[MyPoints, :]
        MyFet2 = np.take(y, MyPoints, axis=0).astype(np.float64)
        MyMasks = np.take(masks, MyPoints, axis=0)
        # if len(MyPoints) > nDims:
        # LogProp = np.log(len(MyPoints) / float(nPoints)) # log of the proportion in cluster c
        Mean = np.mean(MyFet2, axis=0).reshape((1, -1))

        if len(MyPoints) <= 1:
            CovMat = 1e-3 * np.eye(nDims)
            stats[c] = (Mean, CovMat, 1e3 * np.eye(nDims), (1e-3)**nDims,
                        len(MyPoints), np.zeros(nDims, dtype=np.bool))
            continue

        CovMat = np.cov(MyFet2, rowvar=0)  # stats for cluster c

        # Variational Bayesian approximation
        priorPoint = 1
        CovMat *= (len(MyFet2) - 1)  # get rid of the normalization factor
        CovMat += D * priorPoint  # D = np.diag(sigma2.ravel())
        CovMat /= (len(MyFet2) + priorPoint - 1)

        # HACK: avoid instability issues, kind of works
        # CovMat += np.diag(1e-0 * np.ones(nDims))

        # now, add the diagonal modification to the covariance matrix
        # the eta just for the current cluster
        etac = np.take(eta, MyPoints, axis=0)
        d = np.mean(etac, axis=0)

        # Handle nmasked == 0
        d[np.isnan(d)] = 0

        # add diagonal
        CovMat += np.diag(d)
        # We don't compute that explicitly anymore: we solve Ax=b instead
        # CovMatinv = np.linalg.inv(CovMat)
        CovMatinv = None

        # WARNING: this is numerically unstable
        # LogDet = np.log(np.linalg.det(CovMat))

        _sign, LogDet = np.linalg.slogdet(CovMat)
        if _sign < 0:
            warn("The covariance matrix of cluster %d has a negative "
                 "determinant." % c)

        # Fraction of spikes in the cluster for which each feature is unmasked.
        unmask = (MyMasks > 0).mean(axis=0)

        stats[c] = (Mean, CovMat, CovMatinv, LogDet, len(MyPoints), unmask)

    return stats
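
These statistics routines compute the log-determinant with `np.linalg.slogdet` rather than the commented-out `np.log(np.linalg.det(...))`: for high-dimensional clusters the determinant itself under- or overflows in float64 even when the matrix is perfectly well conditioned. A small standalone illustration:

import numpy as np

ndims = 400
cov = 0.1 * np.eye(ndims)              # well conditioned, but det = 0.1**400 underflows

print(np.linalg.det(cov))              # 0.0 -> np.log of this would be -inf
sign, logdet = np.linalg.slogdet(cov)
print(sign, logdet)                    # 1.0, 400 * log(0.1) ~= -921.0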
Example #27
 def read_fil(self):
     try:
         self.fil = read_dat(self.filename_fil, self.nchannels)
     except IOError:
         warn("The FIL file is missing.")
Example #28
 def read_dat(self):
     try:
         self.dat = read_dat(self.filename_dat, self.nchannels)
     except IOError:
         warn("The DAT file is missing.")
Example #29
 def read_res(self):
     try:
         self.spiketimes_res = read_res(self.filename_res, self.freq)
         self.spiketimes_res = pd.Series(self.spiketimes_res, dtype=np.float32)
     except IOError:
         warn("The RES file is missing.")
Example #30
 def read_fil(self):
     try:
         self.fil = read_dat(self.filename_fil, self.nchannels)
     except IOError:
         warn("The FIL file is missing.")
Example #31
 def read_dat(self):
     try:
         self.dat = read_dat(self.filename_dat, self.nchannels)
     except IOError:
         warn("The DAT file is missing.")