Example #1
    def load_single_unit(self, *args, **kwargs):
        """
        Call the NeuroChaT NSpike.load method.

        Returns
        -------
        dict
            The keys of this dictionary are saved as attributes
            in simuran.spatial.Spatial.load()

        """
        fname, clust_name = args
        if clust_name is not None:
            self.single_unit = NSpike()
            self.single_unit.load(fname, self.load_params["system"])
            waveforms = deepcopy(self.single_unit.get_waveform())
            for chan, val in waveforms.items():
                waveforms[chan] = val * u.uV
            return {
                "underlying": self.single_unit,
                "timestamps": self.single_unit.get_timestamp() * u.s,
                "unit_tags": self.single_unit.get_unit_tags(),
                "waveforms": waveforms,
                "date": self.single_unit.get_date(),
                "time": self.single_unit.get_time(),
                "available_units": self.single_unit.get_unit_list(),
                # "units_to_use": self.single_unit.get_unit_list(),
            }
        else:
            return None
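
A hedged usage sketch for the loader method above (NCLoader and the
"Axona" system string follow Example #9; the file names are illustrative
only, with fname the tetrode file and clust_name the matching cut file):

from simuran.loaders.nc_loader import NCLoader

loader = NCLoader()
loader.load_params["system"] = "Axona"
result = loader.load_single_unit("recording.2", "recording_2.cut")
if result is not None:
    print(result["available_units"])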
Example #2
    def __init__(self, **kwargs):
        """
        Create an NClust object.

        Parameters
        ----------
        **kwargs : keyword arguments
            spike : NSpike
                If an NSpike object is passed directly, it is stored.
                Otherwise, if spike is not an NSpike instance or is not
                passed at all, self.spike = NSpike(**kwargs) is created.

        Returns
        -------
        None

        """
        spike = kwargs.get('spike', None)
        self.wavetime = []
        self.UPSAMPLED = False
        self.ALLIGNED = False
        self.NULL_CHAN_REMOVED = False

        if isinstance(spike, NSpike):
            self.spike = spike
        else:
            self.spike = NSpike(**kwargs)
        super().__init__(**kwargs)
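
A minimal construction sketch for the two paths above (assumes NeuroChaT
is installed; the module path neurochat.nc_clust is an assumption):

from neurochat.nc_clust import NClust
from neurochat.nc_spike import NSpike

# Path 1: an existing NSpike object is stored directly.
spike = NSpike(name='C0')
nclust = NClust(spike=spike)

# Path 2: no NSpike passed, so NClust builds one from the kwargs.
nclust2 = NClust(name='C1')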
Example #3
    def __init__(self):
        """
        Attributes
        ----------
        spatial : NSpatial
            Spatial data object
        spike : NSpike
            Spike data object
        lfp : Nlfp
            LFP data object
        hdf : NHdf
            Object for manipulating HDF5 file
        data_format : str
            Recording system or format of the data file

        """

        super().__init__()
        self.spike = NSpike(name='C0')
        self.spatial = NSpatial(name='S0')
        self.lfp = NLfp(name='L0')
        self.data_format = 'Axona'
        self._results = oDict()
        self.hdf = Nhdf()

        self.__type = 'data'
Example #4
    def __init__(self):
        """See NData class description."""
        super().__init__()
        self.spike = NSpike(name='C0')
        self.spatial = NSpatial(name='S0')
        self.lfp = NLfp(name='L0')
        self.data_format = 'Axona'
        self._results = oDict()
        self.hdf = Nhdf()

        self.__type = 'data'
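
A short sketch wiring an NData object to files, using the setters shown
in Example #10 (file names are illustrative; "Axona" is the default
format):

ndata = NData()
ndata.set_spike_file("session.2")
ndata.set_spatial_file("session_2.txt")
ndata.set_lfp_file("session.eeg")
ndata.set_data_format("Axona")
ndata.load()
ndata.set_unit_no(1)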
Example #5
    def __init__(self, **kwargs):
        """
        Attributes
        ----------
        spike : NSpike
            An object of NSpike() class or its subclass
            
        """
        
        spike = kwargs.get('spike', None)
        self.wavetime = []
        self.UPSAMPLED = False
        self.ALLIGNED = False
        self.NULL_CHAN_REMOVED = False

        if isinstance(spike, NSpike):
            self.spike = spike
        else:
            self.spike = NSpike(**kwargs)
        super().__init__(**kwargs)
Example #6
    output_dict = {}
    for key in keys:
        output_dict[key] = np.nanmean(results[key])
    return output_dict, p_down_data


if __name__ == "__main__":
    """Some examples for testing the code on for correctness."""

    # Set up the recordings
    spatial = NSpatial()
    fname = r"D:\SubRet_recordings_imaging\muscimol_data\CanCSR8_muscimol\05102018\s3_after_smallsq\05102018_CanCSR8_smallsq_10_3_3.txt"
    spatial.set_filename(fname)
    spatial.load()

    spike = NSpike()
    fname = r"D:\SubRet_recordings_imaging\muscimol_data\CanCSR8_muscimol\05102018\s3_after_smallsq\05102018_CanCSR8_smallsq_10_3.3"
    spike.set_filename(fname)
    spike.load()
    spike.set_unit_no(1)

    spatial2 = NSpatial()
    fname = r"D:\SubRet_recordings_imaging\muscimol_data\CanCSR8_muscimol\05102018\s4_big sq\05102018_CanCSR8_bigsq_10_4_3.txt"
    spatial2.set_filename(fname)
    spatial2.load()

    spike2 = NSpike()
    fname = r"D:\SubRet_recordings_imaging\muscimol_data\CanCSR8_muscimol\05102018\s4_big sq\05102018_CanCSR8_bigsq_10_4.3"
    spike2.set_filename(fname)
    spike2.load()
    spike2.set_unit_no(6)
Example #7
class NClust(NBase):
    """
    This class facilitates clustering-related operations.

    Although no clustering algorithm is implemented in this class,
    it can be subclassed to create such algorithms.

    Many of the functions in this class are delegated to the spike attr.

    Attributes
    ----------
    spike : NSpike
        An object of NSpike() class.

    """
    def __init__(self, **kwargs):
        """
        Create an NClust object.

        Parameters
        ----------
        **kwargs : keyword arguments
            spike : NSpike
                If an NSpike object is passed directly, it is stored.
                Otherwise, if spike is not an NSpike instance or is not
                passed at all, self.spike = NSpike(**kwargs) is created.

        Returns
        -------
        None

        """
        spike = kwargs.get('spike', None)
        self.wavetime = []
        self.UPSAMPLED = False
        self.ALLIGNED = False
        self.NULL_CHAN_REMOVED = False

        if isinstance(spike, NSpike):
            self.spike = spike
        else:
            self.spike = NSpike(**kwargs)
        super().__init__(**kwargs)

    def get_unit_tags(self):
        """
        Return tags of the spiking waveforms from clustering.

        Parameters
        ----------
        None

        Returns
        -------
        ndarray
            Tags of the spike-waveforms, based on the cluster number

        """
        return self.spike.get_unit_tags()

    def set_unit_tags(self, new_tags=None):
        """
        Set the tags of the spiking waveforms from clustering.

        Parameters
        ----------
        new_tags : ndarray
            Array that contains the tags for spike-waveforms
            which is based on the cluster number.

        Returns
        -------
        None

        """
        self.spike.set_unit_tags(new_tags)

    def get_unit_list(self):
        """
        Return the list of units in a spike dataset.

        Parameters
        ----------
        None

        Returns
        -------
        list
            List of units

        """
        return self.spike.get_unit_list()

    def _set_unit_list(self):
        """
        Set the unit list.

        Delegates to NSpike._set_unit_list()

        Parameters
        ----------
        None

        Returns
        -------
        None

        See also
        --------
        nc_spike.NSpike()._set_unit_list

        """
        self.spike._set_unit_list()

    def get_timestamp(self, unit_no=None):
        """
        Return the timestamps of the spike-waveforms of specified unit.

        Parameters
        ----------
        unit_no : int
            Unit whose timestamps are to be returned

        Returns
        -------
        ndarray
            Timestamps of the spiking waveforms

        """
        return self.spike.get_timestamp(unit_no=unit_no)

    def get_unit_spikes_count(self, unit_no=None):
        """
        Return the total number of spikes in a specified unit.

        Parameters
        ----------
        unit_no : int
            Unit whose count is returned

        Returns
        -------
        int
            Total number of spikes in the unit

        """
        return self.spike.get_unit_spikes_count(unit_no=unit_no)

    def get_waveform(self):
        """
        Return the waveforms in the spike dataset.

        Parameters
        ----------
        None

        Returns
        -------
        dict
            Each key represents one channel of the electrode group.
            Each value represents the waveforms of the spikes
            in a matrix form (no_samples x no_spikes)

        """
        return self.spike.get_waveform()

    def _set_waveform(self, spike_waves=[]):
        """
        Set the waveforms of the spike dataset.

        Parameters
        ----------
        spike_waves : dict
            Each key represents one channel of the electrode group.
            Each value represents the waveforms of the spikes
            in a matrix form (no_samples x no_spikes)

        Returns
        -------
        None

        """
        self.spike._set_waveform(spike_waves=spike_waves)

    def get_unit_waves(self, unit_no=None):
        """
        Return spike waveforms of a specific unit.

        Parameters
        ----------
        unit_no : int
            Unit whose waveforms are returned

        Returns
        -------
        dict
            Spike waveforms in each channel of the electrode group

        """
        return self.spike.get_unit_waves(unit_no=unit_no)

    # For multi-unit analysis,
    # {'SpikeName': cell_no} pairs should be used as function input

    def load(self, filename=None, system=None):
        """
        Load spike dataset from the file.

        Parameters
        ----------
        filename: str
            Name of the spike file
        system : str
            Name of the recording format or system.

        Returns
        -------
        None

        See Also
        --------
        nc_spike.NSpike().load()

        """
        self.spike.load(filename=filename, system=system)

    def add_spike(self, spike=None, **kwargs):
        """
        Add new spike node to current NSpike() object.

        Parameters
        ----------
        spike : NSpike
            NSpike object. If None, a new object is created

        Returns
        -------
        NSpike
            A new NSpike() object

        """
        return self.spike.add_spike(spike=spike, **kwargs)

    def load_spike(self, names=None):
        """
        Load datasets of the spike nodes.

        The name of each node is used for obtaining the filenames.

        Parameters
        ----------
        names : list of str
            Names of the nodes to load.
            If None, current NSpike() object is loaded

        Returns
        -------
        None

        """
        self.spike.load_spike(names=names)

    def wave_property(self):
        """
        Calculate different waveform properties for currently set unit.

        Delegates to NSpike().wave_property()

        Parameters
        ----------
        None

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        NSpike().wave_property()

        """
        return self.spike.wave_property()

    def isi(self, bins='auto', bound=None, density=False):
        """
        Calculate the ISI histogram of the spike train.

        Delegates to NSpike().isi()

        Parameters
        ----------
        bins : str or int
            Number of ISI histogram bins. If 'auto', NumPy default is used
        bound : int
            Length of the ISI histogram in msec
        density : bool
            If true, normalized histogram is calculated

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        NSpike().isi()

        """
        return self.spike.isi(bins=bins, bound=bound, density=density)

    def isi_corr(self, **kwargs):
        """
        Analysis of ISI autocorrelation histogram.

        Delegates to NSpike().isi_auto_corr()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spike.NSpike().isi_corr

        """
        return self.spike.isi_corr(**kwargs)

    def psth(self, event_stamp, **kwargs):
        """
        Calculate peri-stimulus time histogram (PSTH).

        Delegates to NSpike().psth()

        Parameters
        ----------
        event_stamp : ndarray
            Event timestamps

        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spike.NSpike().psth()

        """
        return self.spike.psth(event_stamp, **kwargs)

    def burst(self, burst_thresh=5, ibi_thresh=50):
        """
        Burst analysis of spike-train.

        Delegates to NSpike().burst()

        Parameters
        ----------
        burst_thresh : int
            Minimum ISI between consecutive spikes in a burst

        ibi_thresh : int
            Minimum inter-burst interval between two bursting groups of spikes

        Returns
        -------
        None

        See also
        --------
        nc_spike.NSpike().burst

        """
        self.spike.burst(burst_thresh=burst_thresh, ibi_thresh=ibi_thresh)

    def get_total_spikes(self):
        """
        Return total number of spikes in the recording.

        Parameters
        ----------
        None

        Returns
        -------
        int
            Total number of spikes

        """
        return self.spike.get_total_spikes()

    def get_total_channels(self):
        """
        Return total number of electrode channels in the spike data file.

        Parameters
        ----------
        None

        Returns
        -------
        int
            Total number of electrode channels

        """
        return self.spike.get_total_channels()

    def get_channel_ids(self):
        """
        Return the identities of individual channels.

        Parameters
        ----------
        None

        Returns
        -------
        list
            Identities of individual channels

        """
        return self.spike.get_channel_ids()

    def get_timebase(self):
        """
        Return the timebase for spike event timestamps.

        Parameters
        ----------
        None

        Returns
        -------
        int
            Timebase for spike event timestamps

        """
        return self.spike.get_timebase()

    def get_sampling_rate(self):
        """
        Return the sampling rate of spike waveforms.

        Parameters
        ----------
        None

        Returns
        -------
        int
            Sampling rate for spike waveforms

        """
        return self.spike.get_sampling_rate()

    def get_samples_per_spike(self):
        """
        Return the number of samples in each spike waveform.

        Parameters
        ----------
        None

        Returns
        -------
        int
            Number of samples per spike waveform

        """
        return self.spike.get_samples_per_spike()

    def get_wave_timestamp(self):
        """
        Return the time between successive samples of the spike-waveforms.

        Parameters
        ----------
        None

        Returns
        -------
        float
            Time between waveform samples, in microseconds

        """
        # convert the sampling rate to samples per microsecond
        # so that the returned period is in microseconds
        fs = self.spike.get_sampling_rate() / 10**6
        return 1 / fs
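    # Worked example (illustrative numbers): at a 48 kHz sampling rate,
    # fs = 48000 / 10**6 = 0.048 samples per microsecond, so this method
    # returns 1 / 0.048, i.e. about 20.83 microseconds between samples.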

    def save_to_hdf5(self):
        """
        Store NSpike() object to HDF5 file.

        Delegates to NSPike().save_to_hdf5()

        Parameters
        ----------
        None

        Returns
        -------
        None

        See also
        --------
        nc_hdf.Nhdf().save_spike()

        """
        self.spike.save_to_hdf5()

    def get_feat(self, npc=2):
        """
        Return the spike-waveform features.

        Parameters
        ----------
        npc : int
            Number of principal components in each channel.

        Returns
        -------
        feat : ndarray
            Matrix of size (number_spike X number_features)

        """
        if not self.NULL_CHAN_REMOVED:
            self.remove_null_chan()
        if not self.ALLIGNED:
            self.align_wave_peak()

        trough, trough_loc = self.get_min_wave_chan()
        peak, peak_chan, peak_loc = self.get_max_wave_chan()
        pc = self.get_wave_pc(npc=npc)
        shape = (self.get_total_spikes(), 1)
        feat = np.concatenate((peak.reshape(shape), trough.reshape(shape), pc),
                              axis=1)

        return feat

    def get_feat_by_unit(self, unit_no=None):
        """
        Return the spike-waveform features for a particular unit.

        Parameters
        ----------
        unit_no : int
            Unit of interest

        Returns
        -------
        feat : ndarray
            Matrix of size (number_spike X number_features)

        """
        if unit_no in self.get_unit_list():
            feat = self.get_feat()
            return feat[self.get_unit_tags() == unit_no, :]
        else:
            logging.error('Specified unit does not exist in the spike dataset')

    def get_wave_peaks(self):
        """
        Return the peaks of the spike-waveforms.

        Parameters
        ----------
        None

        Returns
        -------
        (peak, peak_loc) : (ndarray, ndarray)
            peak:
                Spike waveform peaks in all the electrode channels
                Shape is (num_waves X num_channels)
            peak_loc :
                Index of peak locations

        """
        wave = self.get_waveform()
        peak = np.zeros((self.get_total_spikes(), len(wave.keys())))
        peak_loc = np.zeros((self.get_total_spikes(), len(wave.keys())),
                            dtype=int)
        for i, key in enumerate(wave.keys()):
            peak[:, i] = np.amax(wave[key], axis=1)
            peak_loc[:, i] = np.argmax(wave[key], axis=1)

        return peak, peak_loc

    def get_max_wave_chan(self):
        """
        Return the maximum of waveform peaks among the electrode groups.

        Parameters
        ----------
        None

        Returns
        -------
        (max_wave_val, max_wave_chan, peak_loc) : (ndarray, ndarray, ndarray)
        max_wave_val : ndarray
            Maximum value of the peaks of the waveforms
        max_wave_chan : ndarray
            Channel of the electrode group where a spike waveform is strongest
        peak_loc : ndarray
            Peak location in the channel with strongest waveform

        """
        peak, peak_loc = self.get_wave_peaks()
        max_wave_chan = np.argmax(peak, axis=1)
        max_wave_val = np.amax(peak, axis=1)
        return (max_wave_val, max_wave_chan, peak_loc[np.arange(len(peak_loc)),
                                                      max_wave_chan])

    def align_wave_peak(self, reach=300, factor=2):
        """
        Align the waves by their peaks.

        Parameters
        ----------
        reach : int
            Maximum allowed time-shift in microseconds
        factor : int
            Resampling factor

        Returns
        -------
        None

        """
        if not self.UPSAMPLED:
            self.resample_wave(factor=factor)
        if not self.ALLIGNED:
            # maximum 300microsecond allowed for shift
            shift = round(reach / self.get_wave_timestamp())
            # NC waves are stored in waves['ch1'], waves['ch2'] etc. ways
            wave = self.get_waveform()
            maxInd = shift + self.get_max_wave_chan()[2]
            shift_ind = int(np.median(maxInd)) - maxInd
            shift_ind[np.abs(shift_ind) > shift] = 0
            stacked_chan = np.empty(
                (self.get_total_spikes(), self.get_samples_per_spike(),
                 self.get_total_channels()))
            keys = []
            i = 0
            for key, val in wave.items():
                stacked_chan[:, :, i] = val
                keys.append(key)
                i += 1

            stacked_chan = np.lib.pad(stacked_chan, [(0, 0), (shift, shift),
                                                     (0, 0)], 'edge')

            stacked_chan = np.array([
                np.roll(stacked_chan[i, :, :], shift_ind[i],
                        axis=0)[shift:shift + self.get_samples_per_spike()]
                for i in np.arange(shift_ind.size)
            ])

            for i, key in enumerate(keys):
                wave[key] = stacked_chan[:, :, i]
            self._set_waveform(wave)
            self.ALLIGNED = True
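    # Note on the alignment above: each spike is shifted so that its peak
    # lands on the median peak index across spikes; shifts larger than
    # `reach` microseconds (converted to samples) are suppressed, and
    # edge-padding keeps the waveform length unchanged.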

    def get_wave_min(self):
        """
        Return the minimum values of the spike-waveforms.

        Parameters
        ----------
        None

        Returns
        -------
        (min_w, min_loc) : (ndarray, ndarray)
            min_w : ndarray
                Minimum value of the waveforms
            min_loc : ndarray
                Index of minimum value

        """
        wave = self.get_waveform()
        min_w = np.zeros((self.get_total_spikes(), len(wave.keys())))
        min_loc = np.zeros((self.get_total_spikes(), len(wave.keys())))
        for i, key in enumerate(wave.keys()):
            min_w[:, i] = np.amin(wave[key], axis=1)
            min_loc[:, i] = np.argmin(wave[key], axis=1)

        return min_w, min_loc

    def get_min_wave_chan(self):
        """
        Return the waveform minima at the channel with the strongest peak.

        Parameters
        ----------
        None

        Returns
        -------
        (min_val, min_index) : (ndarray, ndarray)
            min_val : ndarray
                Minimum value of the waveform at channels with peak value
            min_index : ndarray
                Index of minimum values

        """
        max_wave_chan = self.get_max_wave_chan()[1]
        trough, trough_loc = self.get_wave_min()
        return (trough[np.arange(len(max_wave_chan)), max_wave_chan],
                trough_loc[np.arange(len(max_wave_chan)), max_wave_chan])

    def get_wave_pc(self, npc=2):
        """
        Return the principal components of the waveforms.

        Parameters
        ----------
        npc : int
            Number of principal components from waveforms of each channel

        Returns
        -------
        pc : ndarray
            Principal components (num_waves X npc*num_channels)

        """
        wave = self.get_waveform()
        pc = np.array([])
        for key, w in wave.items():
            pca = PCA(n_components=5)
            w_new = pca.fit_transform(w)
            pc_var = pca.explained_variance_ratio_

            if npc and npc < w_new.shape[1]:
                w_new = w_new[:, :npc]
            else:
                w_new = w_new[:, 0:(
                    find(np.cumsum(pc_var) >= 0.95, 1, 'first')[0] + 1)]
            if not len(pc):
                pc = w_new
            else:
                pc = np.append(pc, w_new, axis=1)
        return pc
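    # Note: with the default npc=2 and four tetrode channels, the returned
    # matrix is (num_waves x 8); when npc does not truncate the PCA output,
    # enough components are kept to explain 95% of the variance instead.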

    def get_wavetime(self):
        """
        Return the timestamps of the waveforms, not the spiking-event.

        Parameters
        ----------
        None

        Returns
        -------
        ndarray
            Timestamps of the spike-waveforms, in microseconds

        """
        # calculate the wavetime from the sampling rate and number of sample
        # returns in microsecond
        nsamp = self.spike.get_samples_per_spike()
        timestamp = self.get_wave_timestamp()
        return np.arange(0, (nsamp) * timestamp, timestamp)

    def resample_wavetime(self, factor=2):
        """
        Resample the timestamps of spike-waveforms.

        Parameters
        ----------
        factor : int
            Resampling factor

        Returns
        -------
        ndarray
            Resampled timestamps

        """
        wavetime = self.get_wavetime()
        timestamp = self.get_wave_timestamp()

        return np.arange(0, wavetime[-1], timestamp / factor)

    def resample_wave(self, factor=2):
        """
        Resample spike waveforms using spline interpolation.

        Parameters
        ----------
        factor : int
            Resampling factor

        Returns
        -------
        wave : dict
            Upsampled waveforms
        uptime : ndarray
            Upsampled wave timestamps

        """
        # resample wave using spline interpolation using the resampled_time
        if not self.UPSAMPLED:
            wavetime = self.get_wavetime()
            uptime = self.resample_wavetime(factor=factor)
            wave = self.get_waveform()
            for key, w in wave.items():
                f = sc.interpolate.interp1d(wavetime,
                                            w,
                                            axis=1,
                                            kind='quadratic')
                wave[key] = f(uptime)

            self.spike._set_sampling_rate(self.get_sampling_rate() * factor)
            self.spike._set_samples_per_spike(uptime.size)
            self.UPSAMPLED = True

            return wave, uptime

        else:
            logging.warning('You can only upsample once. ' +
                            'Please reload the data from the source file ' +
                            'to change the sampling factor.')

    def get_wave_energy(self):
        """
        Energy of the spike waveforms.

        This is measured as the summation of the square of samples.

        Parameters
        ----------
        None

        Returns
        -------
        energy : ndarray
            Energy of spikes (num_spike X num_channels)

        """
        wave = self.get_waveform()
        energy = np.zeros((self.get_total_spikes(), len(wave.keys())))
        for i, key in enumerate(wave.keys()):
            # convert the energy from microvolt^2 to millivolt^2
            energy[:, i] = (np.sum(np.square(wave[key]), 1) / 10**6)
        return energy

    def get_max_energy_chan(self):
        """
        Return, for each spike, the channel with maximum waveform energy.

        Parameters
        ----------
        None

        Returns
        -------
        ndarray
            Index of the channel with maximum energy for each spike

        """
        energy = self.get_wave_energy()
        return np.argmax(energy, axis=1)

    def remove_null_chan(self):
        """
        Remove the channel from the electrode group that has no spike in it.

        Parameters
        ----------
        None

        Returns
        -------
        off_chan : list
            Keys of the channels that were removed

        """
        # simply detect in which channel everything is zero,
        # which means it's a reference channel or nothing is recorded here
        wave = self.get_waveform()
        off_chan = []
        for key, w in wave.items():
            if np.abs(w).sum() == 0:
                off_chan.append(key)
        if off_chan:
            for key in off_chan:
                del wave[key]
            self._set_waveform(wave)
            self.NULL_CHAN_REMOVED = True

        return off_chan

    def cluster_separation(self, unit_no=0):
        """
        Measure the separation of a specific unit from other clusters.

        This is performed quantitatively using the following:
        1. Bhattacharyya coefficient
        2. Hellinger distance

        Parameters
        ----------
        unit_no : int
            Unit of interest.
            If 0, pairwise comparisons of all units are returned.

        Returns
        -------
        (bc, dh) : (ndarray, ndarray)
            bc : ndarray
                Bhattacharyya coefficient
            dh : ndarray
                Hellinger distance

        """
        # if unit_no==0 all units, matrix output for pairwise comparison,
        # else maximum BC for the specified unit
        feat = self.get_feat()
        unit_list = self.get_unit_list()
        n_units = len(unit_list)

        if unit_no == 0:
            bc = np.zeros([n_units, n_units])
            dh = np.zeros([n_units, n_units])
            for c1 in np.arange(n_units):
                for c2 in np.arange(n_units):
                    X1 = feat[self.get_unit_tags() == unit_list[c1], :]
                    X2 = feat[self.get_unit_tags() == unit_list[c2], :]
                    bc[c1, c2] = bhatt(X1, X2)[0]
                    dh[c1, c2] = hellinger(X1, X2)
            return bc, dh

        else:
            bc = np.zeros(n_units)
            dh = np.zeros(n_units)
            X1 = feat[self.get_unit_tags() == unit_no, :]
            for c2 in np.arange(n_units):
                if unit_list[c2] == unit_no:
                    # placeholder for the self-comparison (filtered below)
                    bc[c2] = 0
                    dh[c2] = 1
                else:
                    X2 = feat[self.get_unit_tags() == unit_list[c2], :]
                    bc[c2] = bhatt(X1, X2)[0]
                    dh[c2] = hellinger(X1, X2)
            # exclude the self-comparison from the returned arrays
            idx = find(np.array(unit_list) != unit_no)

            return bc[idx], dh[idx]
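    # Interpretation note: the Bhattacharyya coefficient lies in [0, 1],
    # with 1 meaning complete overlap of the two feature distributions
    # (poor separation); the Hellinger distance also lies in [0, 1], with
    # 0 meaning identical distributions, so well-separated clusters show
    # low bc and high dh.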

    def cluster_similarity(self, nclust=None, unit_1=None, unit_2=None):
        """
        Measure the similarity or distance of units in a cluster.

        This is performed on one unit
        in a spike dataset to cluster of another unit in another dataset.

        This is performed quantitatively using the following:
        1. Bhattacharyya coefficient
        2. Hellinger distance

        Parameters
        ----------
        nclust : NClust
            NClust object whose unit is under comparison
        unit_1 : int
            Unit of the current NClust object
        unit_2 : int
            Unit of the other NClust object under comparison

        Returns
        -------
        (bc, dh) : (ndarray, ndarray)
            bc : ndarray
                Bhattacharyya coefficient
            dh : ndarray
                Hellinger distance

        """
        # default to None when the comparison cannot be made
        bc, dh = None, None
        if isinstance(nclust, NClust):
            if ((unit_1 in self.get_unit_list())
                    and (unit_2 in nclust.get_unit_list())):
                X1 = self.get_feat_by_unit(unit_no=unit_1)
                X2 = nclust.get_feat_by_unit(unit_no=unit_2)
                bc = bhatt(X1, X2)[0]
                dh = hellinger(X1, X2)
        return bc, dh
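
A hedged end-to-end sketch for the class above (the file name, system
string and unit number are illustrative):

nclust = NClust()
nclust.load(filename="recording.1", system="Axona")
feat = nclust.get_feat(npc=2)  # peak, trough and PCA features per spike
bc, dh = nclust.cluster_separation(unit_no=1)
print(feat.shape, bc, dh)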

Example #8
if __name__ == "__main__":
    location = r"C:\Users\smartin5\Recordings\Raw\2min\CS1_18_02_open_2_bin_shuff.bin"
    location = r"G:\Ham\A10_CAR-SA2\CAR-SA2_20200109_PreBox\CAR-SA2_2020-01-09_PreBox_shuff.bin"
    # location = r"F:\CAR-SA4_20200301_PreBox\CAR-SA4_2020-03-01_PreBox.bin"
    tetrode = 12
    channels = [4 * (tetrode - 1) + i for i in range(4)]
    times = []
    n_times = 10
    # spike_names = "CAR-SA4_2020-03-01_PreBox_12_c1_times.txt"
    # with open(os.path.join(os.path.dirname(location), spike_names), "r") as f:
    #     for i in range(n_times):
    #         line = f.readline()
    #         time = float(line[:-1].strip())
    #         times.append(time)
    # This is a temp
    from neurochat.nc_spike import NSpike
    import neurochat.nc_plot as nc_plot
    ns = NSpike()
    ns.load(
        os.path.join(os.path.dirname(location),
                     "CAR-SA2_2020-01-09_PreBox.12"), "Axona")
    ns.set_unit_no(unit_no=1)
    times = []
    for i in range(n_times):
        t = ns.get_unit_stamp()[i]
        times.append(t)
    nc_plot.wave_property(ns.wave_property())
    plt.savefig("wave.png")
    read_shuff_bin(location, channels, times, fname="test12_1_ac_new2.png")
Example #9
def test_nc_recording_loading(delete=False):
    from neurochat.nc_lfp import NLfp
    from neurochat.nc_spike import NSpike
    from neurochat.nc_spatial import NSpatial
    from simuran.loaders.nc_loader import NCLoader

    main_test_dir = os.path.join(main_dir, "tests", "resources", "temp",
                                 "axona")
    os.makedirs(main_test_dir, exist_ok=True)

    axona_files = fetch_axona_data()

    # Load using SIMURAN auto detection.
    ex = Recording(
        param_file=os.path.join(main_dir, "tests", "resources", "params",
                                "axona_test.py"),
        base_file=main_test_dir,
        load=False,
    )
    ex.signals[0].load()
    ex.units[0].load()
    ex.units[0].underlying.set_unit_no(1)
    ex.spatial.load()

    # Load using NeuroChaT
    lfp = NLfp()
    lfp.set_filename(
        os.path.join(main_test_dir, "010416b-LS3-50Hz10.V5.ms.eeg"))
    lfp.load(system="Axona")

    unit = NSpike()
    unit.set_filename(os.path.join(main_test_dir,
                                   "010416b-LS3-50Hz10.V5.ms.2"))
    unit.load(system="Axona")
    unit.set_unit_no(1)

    spatial = NSpatial()
    spatial.set_filename(
        os.path.join(main_test_dir, "010416b-LS3-50Hz10.V5.ms_2.txt"))
    spatial.load(system="Axona")

    assert np.all(ex.signals[0].underlying.get_samples() == lfp.get_samples())
    assert np.all(
        ex.units[0].underlying.get_unit_stamp() == unit.get_unit_stamp())
    assert np.all(
        ex.units[0].underlying.get_unit_tags() == unit.get_unit_tags())
    assert np.all(ex.spatial.underlying.get_pos_x() == spatial.get_pos_x())

    ncl = NCLoader()
    ncl.load_params["system"] = "Axona"
    loc = os.path.join(main_dir, "tests", "resources", "temp", "axona")
    file_locs, _ = ncl.auto_fname_extraction(
        loc,
        verbose=False,
        unit_groups=[
            2,
        ],
        sig_channels=[
            1,
        ],
    )
    clust_locs = [
        os.path.basename(f) for f in file_locs["Clusters"] if f is not None
    ]
    assert "010416b-LS3-50Hz10.V5.ms_2.cut" in clust_locs

    if delete:
        for f in axona_files:
            os.remove(f)
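
This check is written as a pytest-style test; a minimal stand-alone run
(assuming the surrounding module defines main_dir, Recording and
fetch_axona_data) would be:

if __name__ == "__main__":
    test_nc_recording_loading(delete=True)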
Example #10
class NData():
    """
    The NData object is composed of data objects (NSpike(), NSpatial(),
    NLfp(), and Nhdf()) and is built upon the composite structural
    object pattern.

    This data class is the main data element in NeuroChaT; it delegates
    analysis and other operations to the respective objects.

    """
    def __init__(self):
        """
        Attributes
        ----------
        spatial : NSpatial
            Spatial data object
        spike : NSpike
            Spike data object
        lfp : Nlfp
            LFP data object
        hdf : NHdf
            Object for manipulating HDF5 file
        data_format : str
            Recording system or format of the data file

        """

        super().__init__()
        self.spike = NSpike(name='C0')
        self.spatial = NSpatial(name='S0')
        self.lfp = NLfp(name='L0')
        self.data_format = 'Axona'
        self._results = oDict()
        self.hdf = Nhdf()

        self.__type = 'data'

    def subsample(self, sample_range):
        """
        Subsample each constituent data object to a given time range.

        Parameters
        ----------
        sample_range : tuple
            Start and end times in seconds to extract

        Returns
        -------
        NData
            Subsampled version of the initial NData object

        """
        ndata = NData()
        if self.lfp.get_duration() != 0:
            ndata.lfp = self.lfp.subsample(sample_range)
        if self.spike.get_duration() != 0:
            ndata.spike = self.spike.subsample(sample_range)
        if self.spatial.get_duration() != 0:
            ndata.spatial = self.spatial.subsample(sample_range)

        return ndata
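    # Example (illustrative): keep only the first 60 seconds of a session
    # with `first_minute = ndata.subsample((0, 60))`.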

    def get_type(self):
        """
        Returns the type of the object. For NData, this is always `data`

        Parameters
        ----------
        None

        Returns
        -------
        str

        """
        return self.__type

    def get_results(self, spaces_to_underscores=False):
        """
        Returns the parametric results of the analyses

        Parameters
        ----------
        spaces_to_underscores : bool
            If True, spaces in the result keys are replaced with underscores

        Returns
        -------
        OrderedDict

        """
        if spaces_to_underscores:
            results = {
                x.replace(' ', '_'): v
                for x, v in self._results.items()
            }
            return results
        return self._results

    def update_results(self, results):
        """
        Adds new parametric results of the analyses

        Parameters
        ----------
        results : OrderedDict
            New analyses results (parametric)

        Returns
        -------
        None

        """

        self._results.update(results)

    def reset_results(self):
        """
        Reset the NData results to an empty OrderedDict

        Parameters
        ----------
        None

        Returns
        -------
        None

        """

        self._results = oDict()
        # self.spike.reset_results()
        # self.spatial.reset_results()
        # self.lfp.reset_results()

    def get_data_format(self):
        """
        Returns the recording system or data format

        Parameters
        ----------
        None

        Returns
        -------
        str

        """
        return self.data_format

    def set_data_format(self, data_format=None):
        """
        Sets the recording system or data format

        Parameters
        ----------
        data_format : str
            Recording system or format of the data

        Returns
        -------
        None

        """

        if data_format is None:
            data_format = self.get_data_format()
        self.data_format = data_format
        self.spike.set_system(data_format)
        self.spatial.set_system(data_format)
        self.lfp.set_system(data_format)

    def load(self):
        """
        Loads the data from the filenames in each constituent object,
        i.e., spatial, spike and LFP

        Parameters
        ----------
        None

        Returns
        -------
        None

        """
        self.load_spike()
        self.load_spatial()
        self.load_lfp()

    def save_to_hdf5(self):
        """
        Stores the spatial, spike and LFP datasets to HDF5 file 

        Parameters
        ----------
        None

        Returns
        -------
        None

        """

        try:
            self.hdf.save_object(obj=self.spike)
        except Exception:
            logging.warning(
                'Error in exporting NSpike data from NData object to the hdf5 file!'
            )

        try:
            self.hdf.save_object(obj=self.spatial)
        except Exception:
            logging.warning(
                'Error in exporting NSpatial data from NData object to the hdf5 file!'
            )

        try:
            self.hdf.save_object(obj=self.lfp)
        except Exception:
            logging.warning(
                'Error in exporting NLfp data from NData object to the hdf5 file!'
            )

    def set_unit_no(self, unit_no):
        """
        Sets the unit number of the spike dataset to analyse

        Parameters
        ----------
        unit_no : int
            Unit or cell number to analyse

        Returns
        -------
        None

        """

        self.spike.set_unit_no(unit_no)

    def set_spike_name(self, name='C0'):
        """
        Sets the name of the spike dataset

        Parameters
        ----------
        name : str
            Name of the spike dataset

        Returns
        -------
        None

        """

        self.spike.set_name(name)

    def set_spike_file(self, filename):
        """
        Sets the filename of the spike dataset

        Parameters
        ----------
        filename : str
            Full file directory of the spike dataset

        Returns
        -------
        None

        """

        self.spike.set_filename(filename)

    def get_spike_file(self):
        """
        Gets the filename of the spike dataset

        Parameters
        ----------
        None

        Returns
        -------
        str
            Filename of the spike dataset
        """

        return self.spike.get_filename()

    def load_spike(self):
        """
        Loads spike dataset from the file to NSpike() object

        Parameters
        ----------
        None

        Returns
        -------
        None

        """

        self.spike.load()

    def set_spatial_file(self, filename):
        """
        Sets the filename of the spatial dataset

        Parameters
        ----------
        filename : str
            Full file directory of the spike dataset

        Returns
        -------
        None

        """
        self.spatial.set_filename(filename)

    def get_spatial_file(self):
        """
        Gets the filename of the spatial dataset

        Parameters
        ----------
        None

        Returns
        -------
        str
            Filename of the spatial dataset

        """
        return self.spatial.get_filename()

    def set_spatial_name(self, name):
        """
        Sets the name of the spatial dataset

        Parameters
        ----------
        name : str
            Name of the spatial dataset

        Returns
        -------
        None

        """

        self.spatial.set_name(name)

    def load_spatial(self):
        """
        Loads spatial dataset from the file to NSpatial() object

        Parameters
        ----------
        filename : str
            Full file directory of the spike dataset

        Returns
        -------
        None

        """
        self.spatial.load()

    def set_lfp_file(self, filename):
        """
        Sets the filename of the LFP dataset

        Parameters
        ----------
        filename : str
            Full file directory of the spike dataset

        Returns
        -------
        None

        """
        self.lfp.set_filename(filename)

    def get_lfp_file(self):
        """
        Gets the filename of the LFP dataset

        Parameters
        ----------
        None

        Returns
        -------
        str
            Filename of the LFP dataset
        """

        return self.lfp.get_filename()

    def set_lfp_name(self, name):
        """
        Sets the name of the NLfp() object

        Parameters
        ----------
        name : str
            Name of the LFP dataset

        Returns
        -------
        None

        """

        self.lfp.set_name(name)

    def load_lfp(self):
        """
        Loads LFP dataset to NLfp() object

        Parameters
        ----------
        None

        Returns
        -------
        None

        """

        self.lfp.load()

    # Forwarding to analysis
    def wave_property(self):
        """
        Analysis of waveform characteristics of the spikes of a unit

        Delegates to NSpike().wave_property()

        Parameters
        ----------
        None        

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spike.NSpike().wave_property

        """

        gdata = self.spike.wave_property()
        self.update_results(self.spike.get_results())

        return gdata

    def isi(self,
            bins='auto',
            bound=None,
            density=False,
            refractory_threshold=2):
        """
        Analysis of ISI histogram

        Delegates to NSpike().isi()

        Parameters
        ----------
        bins : str or int
            Number of ISI histogram bins. If 'auto', NumPy default is used

        bound : int
            Length of the ISI histogram in msec
        density : bool
            If true, normalized histogram is calculated
        refractory_threshold : int
            Length of the refractory period in msec

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spike.NSpike().isi

        """
        gdata = self.spike.isi(bins, bound, density, refractory_threshold)
        self.update_results(self.spike.get_results())

        return gdata

    def isi_auto_corr(self, spike=None, **kwargs):
        """
        Analysis of ISI autocorrelation histogram

        Delegates to NSpike().isi_corr()

        Parameters
        ----------
        spike : NSpike()
            If specified, it calculates the cross-correlation.

        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spike.NSpike().isi_corr, nc_spike.NSpike().psth

        """

        gdata = self.spike.isi_corr(spike, **kwargs)

        return gdata

    def burst(self, burst_thresh=5, ibi_thresh=50):
        """
        Burst analysis of spike-train

        Delegates to NSpike().burst()

        Parameters
        ----------
        burst_thresh : int
            Minimum ISI between consecutive spikes in a burst

        ibi_thresh : int
            Minimum inter-burst interval between two bursting groups of spikes

        Returns
        -------
        None

        See also
        --------
        nc_spike.NSpike().burst

        """

        self.spike.burst(burst_thresh, ibi_thresh=ibi_thresh)
        self.update_results(self.spike.get_results())

    def theta_index(self, **kwargs):
        """
        Calculates theta-modulation of spike-train ISI autocorrelation histogram.

        Delegates to NSpike().theta_index()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spike.NSpike().theta_index()

        """

        gdata = self.spike.theta_index(**kwargs)
        self.update_results(self.spike.get_results())

        return gdata

    def theta_skip_index(self, **kwargs):
        """
        Calculates theta-skipping of spike-train ISI autocorrelation histogram.

        Delegates to NSpike().theta_skip_index()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spike.NSpike().theta_skip_index()

        """

        gdata = self.spike.theta_skip_index(**kwargs)
        self.update_results(self.spike.get_results())

        return gdata

    def bandpower_ratio(self, first_band, second_band, win_sec, **kwargs):
        """
        Calculate the ratio in power between two bandpass filtered signals.

        Delegates to NLfp.bandpower_ratio()
        Suggested bands are [5, 11] Hz and [1.5, 4] Hz.

        Parameters
        ----------
        first_band, second_band : list of float
            Frequency bands to compare
        win_sec : float
            Window length in seconds
        **kwargs
            Keyword arguments

        See also
        --------
        nc_lfp.NLfp.bandpower_ratio()
        """

        self.lfp.bandpower_ratio(first_band, second_band, win_sec, **kwargs)
        self.update_results(self.lfp.get_results())

    def spectrum(self, **kwargs):
        """
        Analyses frequency spectrum of the LFP signal

        Delegates to NLfp().spectrum()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_lfp.NLfp().spectrum()

        """

        gdata = self.lfp.spectrum(**kwargs)

        return gdata

    def phase_dist(self, **kwargs):
        """
        Analysis of spike to LFP phase distribution

        Delegates to NLfp().phase_dist()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_lfp.NLfp().phase_dist()

        """

        gdata = self.lfp.phase_dist(self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.lfp.get_results())

        return gdata

    def phase_at_spikes(self, **kwargs):
        """
        Analysis of spike to LFP phase distribution

        Delegates to NLfp().phase_dist()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        phases, times, positions

        See also
        --------
        nc_lfp.NLfp().phase_at_events()

        """
        key = "keep_zero_idx"
        out_data = {}
        if key not in kwargs:
            kwargs[key] = True
        should_filter = kwargs.get("should_filter", True)

        ftimes = self.spike.get_unit_stamp()
        phases = self.lfp.phase_at_events(ftimes, **kwargs)
        _, positions, directions = self.get_event_loc(ftimes, **kwargs)

        if should_filter:
            place_data = self.place(**kwargs)
            boundary = place_data["placeBoundary"]
            co_ords = place_data["indicesInPlaceField"]
            largest_group = place_data["largestPlaceGroup"]

            out_data["good_place"] = (largest_group != 0)
            out_data["phases"] = phases[co_ords]
            out_data["times"] = ftimes[co_ords]
            out_data["directions"] = directions[co_ords]
            out_data["positions"] = [
                positions[0][co_ords], positions[1][co_ords]
            ]
            out_data["boundary"] = boundary

        else:
            out_data["phases"] = phases
            out_data["times"] = ftimes
            out_data["positions"] = positions
            out_data["directions"] = directions

        self.update_results(self.get_results())
        return out_data

    def plv(self, **kwargs):
        """
        Calculates phase-locking value of the spike train to underlying LFP signal.

        Delegates to NLfp().plv()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_lfp.NLfp().plv()

        """

        gdata = self.lfp.plv(self.spike.get_unit_stamp(), **kwargs)

        return gdata

    # def sfc(self, **kwargs):
    # """
    # Calculates spike-field coherence of spike train with underlying LFP signal.

    # Delegates to NLfp().sfc()

    # Parameters
    # ----------
    # **kwargs
    #     Keyword arguments

    # Returns
    # -------
    # dict
    #     Graphical data of the analysis

    # See also
    # --------
    # nc_lfp.NLfp().sfc()

    # """

    # gdata = self.lfp.plv(self.spike.get_unit_stamp(), **kwargs)

    # return gdata

    def event_trig_average(self, **kwargs):
        """
        Averaging event-triggered LFP signals

        Delegates to NLfp().event_trig_average()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_lfp.NLfp().event_trig_average()

        """

        gdata = self.lfp.event_trig_average(self.spike.get_unit_stamp(),
                                            **kwargs)

        return gdata

    def spike_lfp_causality(self, **kwargs):
        """
        Analyses spike to underlying LFP causality

        Delegates to NLfp().spike_lfp_causality()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_lfp.NLfp().spike_lfp_causality()

        """

        gdata = self.lfp.spike_lfp_causality(self.spike.get_unit_stamp(),
                                             **kwargs)
        self.update_results(self.lfp.get_results())

        return gdata

    def speed(self, **kwargs):
        """
        Analysis of unit correlation with running speed

        Delegates to NSpatial().speed()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().speed()

        """

        gdata = self.spatial.speed(self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def angular_velocity(self, **kwargs):
        """
        Analysis of unit correlation to angular head velocity (AHV) of the animal

        Delegates to NSpatial().angular_velocity()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().angular_velocity()

        """

        gdata = self.spatial.angular_velocity(self.spike.get_unit_stamp(),
                                              **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def place(self, **kwargs):
        """
        Analysis of place cell firing characteristics

        Delegates to NSpatial().place()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().place()

        """

        gdata = self.spatial.place(self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    # Created by Sean Martin: 13/02/2019
    def place_field_centroid_zscore(self, **kwargs):
        """
        Calculates a simple centroid of the place field

        Delegates to NSpatial().place_field()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        ndarray
            Centroid of the place field

        See also
        --------
        nc_spatial.NSpatial().place_field()

        """

        gdata = self.spatial.place_field_centroid_zscore(
            self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def loc_time_lapse(self, **kwargs):
        """
        Time-lapse firing properties of the unit with respect to location

        Delegates to NSpatial().loc_time_lapse()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().loc_time_lapse()

        """
        gdata = self.spatial.loc_time_lapse(self.spike.get_unit_stamp(),
                                            **kwargs)

        return gdata

    def loc_shuffle(self, **kwargs):
        """
        Shuffling analysis of the unit to see if the locational firing
        specificity is by chance or actually correlated to the location
        of the animal

        Delegates to NSpatial().loc_shuffle()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().loc_shuffle()

        """

        gdata = self.spatial.loc_shuffle(self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def loc_shift(self, shift_ind=np.arange(-10, 11), **kwargs):
        """
        Analysis of firing specificity of the unit with respect to the
        animal's location, to observe whether it represents a past
        location of the animal or anticipates a future location.

        Delegates to NSpatial().loc_shift()

        Parameters
        ----------
        shift_ind : ndarray
            Index of spatial resolution shift for the spike event time.
            Shift -1 implies a shift to the past by 1 spatial time
            resolution, and +2 implies a shift to the future by 2 spatial
            time resolutions.
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().loc_shift()

        """

        gdata = self.spatial.loc_shift(self.spike.get_unit_stamp(),
                                       shift_ind=shift_ind,
                                       **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def loc_auto_corr(self, **kwargs):
        """
        Calculates the two-dimensional autocorrelation of the firing map,
        i.e. the map of the animal's firing rate with respect to its
        location

        Delegates to NSpatial().loc_auto_corr()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().loc_auto_corr()

        """
        gdata = self.spatial.loc_auto_corr(self.spike.get_unit_stamp(),
                                           **kwargs)

        return gdata

    def loc_rot_corr(self, **kwargs):
        """
        Calculates the rotational correlation of the unit's locational firing rate
        with respect to the animal's location, also called the firing map

        Delegates to NSpatial().loc_rot_corr()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().loc_rot_corr()

        """

        gdata = self.spatial.loc_rot_corr(self.spike.get_unit_stamp(),
                                          **kwargs)

        return gdata

    def hd_rate(self, **kwargs):
        """
        Analysis of the firing characteristics of a unit with respect to the
        animal's head-direction

        Delegates to NSpatial().hd_rate()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().hd_rate()

        """

        gdata = self.spatial.hd_rate(self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def hd_rate_ccw(self, **kwargs):
        """
        Analysis of the firing characteristics of a unit with respect to the animal's
        head-direction, split into clockwise and counterclockwise directions

        Delegates to NSpatial().hd_rate_ccw()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().hd_rate_ccw()

        """

        gdata = self.spatial.hd_rate_ccw(self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def hd_time_lapse(self):
        """
        Time-lapse firing properties of the unit with respect to the head-direction
        of the animal

        Delegates to NSpatial().hd_time_lapse()

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().hd_time_lapse()

        """

        gdata = self.spatial.hd_time_lapse(self.spike.get_unit_stamp())

        return gdata

    def hd_shuffle(self, **kwargs):
        """
        Shuffling analysis of the unit to see if the head-directional firing
        specificity is by chance or actually correlated with the head-direction
        of the animal

        Delegates to NSpatial().hd_shuffle()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().hd_shuffle()

        """

        gdata = self.spatial.hd_shuffle(self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def hd_shift(self, shift_ind=np.arange(-10, 11), **kwargs):
        """
        Analysis of the firing specificity of the unit with respect to the animal's
        head direction, to observe whether it represents a past direction or
        anticipates a future direction.

        Delegates to NSpatial().hd_shift()

        Parameters
        ----------
        shift_ind : ndarray
            Index of spatial resolution shift for the spike event time. Shift -1
            implies a shift to the past by 1 spatial time resolution, and +2 implies
            a shift to the future by 2 spatial time resolutions.
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().hd_shift()

        """

        gdata = self.spatial.hd_shift(self.spike.get_unit_stamp(),
                                      shift_ind=shift_ind, **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def border(self, **kwargs):
        """
        Analysis of the firing characteristics of a unit with respect to the
        environmental border

        Delegates to NSpatial().border()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().border()

        """

        gdata = self.spatial.border(self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def gradient(self, **kwargs):
        """
        Analysis of a gradient cell, a unit whose firing rate gradually increases
        as the animal traverses from the border to the center of the environment

        Delegates to NSpatial().gradient()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().gradient()

        """

        gdata = self.spatial.gradient(self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def grid(self, **kwargs):
        """
        Analysis of grid cells, characterised by the formation of a grid-like
        pattern of high activity in the firing-rate map

        Delegates to NSpatial().grid()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().grid()

        """

        gdata = self.spatial.grid(self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def multiple_regression(self, **kwargs):
        """
        Multiple-regression analysis where the firing rate for each variable, namely
        location, head-direction, speed, AHV, and distance from border, is used
        to regress the instantaneous firing rate of the unit.

        Delegates to NSpatial().multiple_regression()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        dict
            Graphical data of the analysis

        See also
        --------
        nc_spatial.NSpatial().multiple_regression()

        """
        gdata = self.spatial.multiple_regression(self.spike.get_unit_stamp(),
                                                 **kwargs)
        self.update_results(self.spatial.get_results())

        return gdata

    def interdependence(self, **kwargs):
        """
        Interdependence analysis where the firing rate of each variable is predicted
        from another variable and the distributive ratio is measured between the
        predicted firing rate and the calculated firing rate.

        Delegates to NSpatial().interdependence()

        Parameters
        ----------
        **kwargs
            Keyword arguments

        Returns
        -------
        None

        See also
        --------
        nc_spatial.NSpatial().interdependence()

        """

        self.spatial.interdependence(self.spike.get_unit_stamp(), **kwargs)
        self.update_results(self.spatial.get_results())

    def __getattr__(self, arg):
        """
        Sets precedence for delegation with NSpike() > NLfp() > NSpatial()

        Parameters
        ----------
        arg : str
            Name of the function or attribute to look for

        """

        if hasattr(self.spike, arg):
            return getattr(self.spike, arg)
        elif hasattr(self.lfp, arg):
            return getattr(self.lfp, arg)
        elif hasattr(self.spatial, arg):
            return getattr(self.spatial, arg)
        else:
            logging.warning(
                'No ' + arg +
                ' method or attribute in NeuroData or in composing data class')
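
# A minimal usage sketch (hedged; `ndata` is assumed to be an instance of this
# class, and the getter names come from the composed NeuroChaT objects).
# Attribute lookup falls through in the order spike > lfp > spatial, so spike
# attributes shadow the others:
#
#     ndata.get_unit_stamp()   # resolved on self.spike (NSpike)
#     ndata.get_samples()      # resolved on self.lfp (NLfp)
#     ndata.get_pos_x()        # resolved on self.spatial (NSpatial)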
Example #11
class NCLoader(BaseLoader):
    """Load data compatible with the NeuroChaT package."""

    def __init__(self, load_params=None):
        """Call super class initialize."""
        # Avoid a shared mutable default argument.
        super().__init__(
            load_params=load_params if load_params is not None else {})

    def load_signal(self, *args, **kwargs):
        """
        Call the NeuroChaT NLfp.load method.

        Returns
        -------
        dict
            The keys of this dictionary are saved as attributes
            in simuran.signal.BaseSignal.load()
        """
        self.signal = NLfp()
        self.signal.load(*args, self.load_params["system"])
        return {
            "underlying": self.signal,
            "timestamps": self.signal.get_timestamp() * u.s,
            "samples": self.signal.get_samples() * u.mV,
            "date": self.signal.get_date(),
            "time": self.signal.get_time(),
            "channel": self.signal.get_channel_id(),
        }

    def load_spatial(self, *args, **kwargs):
        """
        Call the NeuroChaT NSpatial.load method.

        Returns
        -------
        dict
            The keys of this dictionary are saved as attributes
            in simuran.single_unit.SingleUnit.load()
        """
        self.spatial = NSpatial()
        self.spatial.load(*args, self.load_params["system"])
        return {
            "underlying": self.spatial,
            "date": self.spatial.get_date(),
            "time": self.spatial.get_time(),
            "speed": self.spatial.get_speed() * (u.cm / u.s),
            "position": (
                self.spatial.get_pos_x() * u.cm,
                self.spatial.get_pos_y() * u.cm,
            ),
            "direction": self.spatial.get_direction() * u.deg,
        }

    def load_single_unit(self, *args, **kwargs):
        """
        Call the NeuroChaT NSpike.load method.

        Returns
        -------
        dict
            The keys of this dictionary are saved as attributes
            in simuran.spatial.Spatial.load()

        """
        fname, clust_name = args
        if clust_name is not None:
            self.single_unit = NSpike()
            self.single_unit.load(fname, self.load_params["system"])
            waveforms = deepcopy(self.single_unit.get_waveform())
            for chan, val in waveforms.items():
                waveforms[chan] = val * u.uV
            return {
                "underlying": self.single_unit,
                "timestamps": self.single_unit.get_timestamp() * u.s,
                "unit_tags": self.single_unit.get_unit_tags(),
                "waveforms": waveforms,
                "date": self.single_unit.get_date(),
                "time": self.single_unit.get_time(),
                "available_units": self.single_unit.get_unit_list(),
                # "units_to_use": self.single_unit.get_unit_list(),
            }
        else:
            return None

    def auto_fname_extraction(self, base, **kwargs):
        """
        Extract all filenames relevant to the recording from base.

        Parameters
        ----------
        base : str
            Where to start looking from.
            For Axona, this should be a .set file,
            or a directory containing exactly one .set file

        Returns
        -------
        fnames : dict
            A dictionary listing the filenames involved in loading.
        base : str
            The base file name, in Axona this is a .set file.

        TODO
        ----
        Expand to support nwb and neuralynx as well as Axona.

        """
        # Currently only implemented for Axona systems
        error_on_missing = self.load_params.get("enforce_data", True)

        if self.load_params["system"] == "Axona":

            # Find the set file if a directory is passed
            if os.path.isdir(base):
                set_files = get_all_files_in_dir(base, ext="set")
                if len(set_files) == 0:
                    print("WARNING: No set files found in {}, skipping".format(
                        base))
                    return None, None
                elif len(set_files) > 1:
                    raise ValueError(
                        "Found more than one set file, found {}".format(
                            len(set_files)))
                base = set_files[0]
            elif not os.path.isfile(base):
                raise ValueError("{} is not a file or directory".format(base))

            joined_params = {**self.load_params, **kwargs}
            cluster_extension = joined_params.get("cluster_extension", ".cut")
            clu_extension = joined_params.get("clu_extension", ".clu.X")
            pos_extension = joined_params.get("pos_extension", ".pos")
            # The LFP extension is .eeg or .egf
            lfp_extension = joined_params.get("lfp_extension", ".eeg")
            stm_extension = joined_params.get("stm_extension", ".stm")
            tet_groups = joined_params.get("unit_groups", None)
            channels = joined_params.get("sig_channels", None)

            filename = os.path.splitext(base)[0]
            base_filename = os.path.splitext(os.path.basename(base))[0]

            # Extract the tetrode and cluster data
            spike_names_all = []
            cluster_names_all = []
            if tet_groups is None:
                tet_groups = [
                    x for x in range(0, 64)
                    if os.path.exists(filename + "." + str(x))
                ]
            if channels is None:
                channels = [
                    x for x in range(2, 256)
                    if os.path.exists(filename + lfp_extension + str(x))
                ]
                if os.path.exists(filename + lfp_extension):
                    channels = [1] + channels
            for tetrode in tet_groups:
                spike_name = filename + "." + str(tetrode)
                if not os.path.isfile(spike_name):
                    e_msg = "Axona data is not available for {}".format(
                        spike_name)
                    if error_on_missing:
                        raise ValueError(e_msg)
                    else:
                        logging.warning(e_msg)
                        return None, base

                spike_names_all.append(spike_name)

                cut_name = filename + "_" + str(tetrode) + cluster_extension
                clu_name = filename + clu_extension[:-1] + str(tetrode)
                if os.path.isfile(cut_name):
                    cluster_name = cut_name
                elif os.path.isfile(clu_name):
                    cluster_name = clu_name
                else:
                    cluster_name = None
                cluster_names_all.append(cluster_name)

            # Extract the positional data
            output_list = [None, None]
            for i, ext in enumerate([pos_extension, stm_extension]):
                for fname in get_all_files_in_dir(
                        os.path.dirname(base),
                        ext=ext,
                        return_absolute=False,
                        case_sensitive_ext=True,
                ):
                    if ext == ".txt":
                        if fname[:len(base_filename) +
                                 1] == base_filename + "_":
                            name = os.path.join(os.path.dirname(base), fname)
                            output_list[i] = name
                            break
                    else:
                        if fname[:len(base_filename)] == base_filename:
                            name = os.path.join(os.path.dirname(base), fname)
                            output_list[i] = name
                            break
            spatial_name, stim_name = output_list

            base_sig_name = filename + lfp_extension
            signal_names = []
            for c in channels:
                if c != 1:
                    if os.path.exists(base_sig_name + str(c)):
                        signal_names.append(base_sig_name + str(c))
                    else:
                        e_msg = "{} does not exist".format(base_sig_name +
                                                           str(c))
                        if error_on_missing:
                            raise ValueError(e_msg)
                        else:
                            logging.warning(e_msg)
                            return None, base
                else:
                    if os.path.exists(base_sig_name):
                        signal_names.append(base_sig_name)
                    else:
                        e_msg = "{} does not exist".format(base_sig_name)
                        if error_on_missing:
                            raise ValueError(e_msg)
                        else:
                            logging.warning(e_msg)
                            return None, base

            file_locs = {
                "Spike": spike_names_all,
                "Clusters": cluster_names_all,
                "Spatial": spatial_name,
                "Signal": signal_names,
                "Stimulation": stim_name,
            }
            return file_locs, base
        else:
            raise ValueError(
                "auto_fname_extraction only implemented for Axona")

    def index_files(self, folder, **kwargs):
        """Find all available neurochat files in the given folder"""
        if self.load_params["system"] == "Axona":
            set_files = []
            root_folders = []
            times = []
            durations = []
            print("Finding all .set files...")
            files = get_all_files_in_dir(
                folder,
                ext=".set",
                recursive=True,
                return_absolute=True,
                case_sensitive_ext=True,
            )
            print(f"Found {len(set_files)} set files")

            for fname in tqdm(files, desc="Processing files"):
                set_files.append(os.path.basename(fname))
                root_folders.append(os.path.normpath(os.path.dirname(fname)))
                with open(fname) as f:
                    f.readline()
                    t = f.readline()[-9:-2]
                    try:
                        int(t[:2])  # check that the slice looks like a time
                        times.append(t)
                        f.readline()
                        f.readline()
                        durations.append(f.readline()[-11:-8])
                    except ValueError:
                        if len(times) != len(set_files):
                            times.append(np.nan)
                        if len(durations) != len(set_files):
                            durations.append(np.nan)

            headers = ["filename", "folder", "time", "duration"]
            in_list = [set_files, root_folders, times, durations]
            results_df = list_to_df(in_list, transpose=True, headers=headers)
            return results_df
        else:
            raise ValueError(
                "index_files only implemented for Axona")
Example #12
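# NOTE: this scraped example begins mid-function. A plausible opening for
# extract_sorting_info, inferred only from the call in load_spike_phy and the
# return statement below (hedged, not from the source), might be:
#
#     def extract_sorting_info(sorting):
#         """Collect timestamps, unit tags, and waveforms from a sorting."""
#         timestamps, unit_tags = ..., ...  # assembled from `sorting`
#         waveforms = {}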
    unit_ids = sorting.get_unit_ids()
    for u in unit_ids:
        waveforms[str(u)] = sorting.get_unit_spike_features(u, "waveforms")

    return timestamps, unit_tags, waveforms


def load_spike_phy(self, folder_name):
    """Appended to NSpike class, loads spikes from phy."""
    print("loading Phy sorting information from {}".format(folder_name))
    sorting = load_phy(folder_name)
    timestamps, unit_tags, waveforms = extract_sorting_info(sorting)

    self._set_duration(timestamps.max())
    self._set_timestamp(timestamps)
    self.set_unit_tags(unit_tags)

    # TODO note that waveforms do not follow NC convention
    # It is just a way to store them for the moment.
    self._set_waveform(waveforms)


NSpike.load_spike_phy = load_spike_phy

if __name__ == "__main__":
    folder = r"D:\Ham_Data\Batch_3\A13_CAR-SA5\CAR-SA5_20200212\phy_klusta"
    nspike = NSpike()
    nspike.load_spike_phy(folder)
    print(nspike.get_unit_list())
    print(nspike.get_timestamp(13))
Example #13
def get_spike_times(spike_file, unit_num):
    """Load an Axona spike file and return spike times for the given unit."""
    spike = NSpike()
    spike.set_filename(spike_file)
    spike.set_system("Axona")
    spike.load()
    spike.set_unit_no(unit_num)
    spike_times = spike.get_unit_stamp()
    return spike_times
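
# Usage sketch (hedged; the tetrode file name and unit number are
# illustrative, not from the source):
#
#     spike_times = get_spike_times("recording.2", unit_num=3)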
Example #14
def read_hdf(hdf_path, verbose=False, group=3):
    """
    Read the NWB at hdf_path into NeuroChaT.

    Parameters
    ----------
    hdf_path : str
        Path to the hdf5 file
    verbose : bool, optional
        Defaults to False, indicates whether to print information.
    group : int, optional
        Defaults to 3, indicates the group in the hdf5 file to use.

    Returns
    -------
    NSpike
        The loaded NSpike object.

    """
    if verbose:
        from skm_pyutils.py_print import print_h5
        print_h5(hdf_path)

    # NeuroChaT convention: '+' separates the HDF5 filename from the path
    # inside the file.
    spike_file = hdf_path + "+/processing/Shank/" + str(group)
    spike = NSpike()
    spike.set_system("NWB")
    spike.set_filename(spike_file)
    spike.load()
    unit_no = spike.get_unit_list()[0]
    spike.set_unit_no(unit_no)

    if verbose:
        print(spike)

    return spike
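
# Usage sketch (hedged; the file path is illustrative, not from the source):
#
#     spike = read_hdf("recording.nwb", verbose=True, group=3)
#     print(spike.get_unit_list())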