import os
import logging
import argparse

import numpy as np

import yass
from yass.preprocessing import Preprocessor
from yass.mainprocess import Mainprocessor
from yass.deconvolute import Deconvolution


def main():
    # the output filename is no longer read from the configuration file,
    # so it is hardcoded here
    # TODO: let the user specify the name through a command-line option
    output_file = 'spike_train.csv'

    # configure logging module to get useful information
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser(description='Run YASS.')
    parser.add_argument('config', type=str,
                        help='Path to configuration file')

    args = parser.parse_args()

    cfg = yass.Config.from_yaml(args.config)

    pp = Preprocessor(cfg)
    score, spike_index_clear, spike_index_collision = pp.process()

    mp = Mainprocessor(cfg, score, spike_index_clear, spike_index_collision)
    spikeTrain_clear, spike_index_collision = mp.mainProcess()

    dc = Deconvolution(cfg, np.transpose(mp.templates, [1, 0, 2]),
                       spike_index_collision)
    spikeTrain_col = dc.fullMPMU()

    # merge clear and collision spike trains, then sort by spike time
    spikeTrain = np.concatenate((spikeTrain_col, spikeTrain_clear))
    idx_sort = np.argsort(spikeTrain[:, 0])
    spikeTrain = spikeTrain[idx_sort]

    # remove duplicates: for each template, keep only spikes that are more
    # than one sample apart from the previous spike of the same template
    idx_keep = np.zeros(spikeTrain.shape[0], dtype=bool)
    for k in range(mp.templates.shape[2]):
        idx_c = np.where(spikeTrain[:, 1] == k)[0]
        keep_c = np.concatenate(([True], np.diff(spikeTrain[idx_c, 0]) > 1))
        idx_keep[idx_c[keep_c]] = True
    spikeTrain = spikeTrain[idx_keep]

    path_to_file = os.path.join(cfg.data.root_folder, output_file)

    np.savetxt(path_to_file, spikeTrain, fmt='%i, %i')
    logger.info('Done, spike train saved in: {}'.format(path_to_file))

    
    # set yass configuration parameters
    # set_config(args.config)
    # CONFIG = read_config()

    # run preprocessor
    # score, spike_index_clear, spike_index_collision = preprocess.run()

    # run processor
    # spike_train_clear, templates, spike_index_collision = process.run(
    #     score, spike_index_clear, spike_index_collision)

    # run deconvolution
    # spike_train = deconvolute.run(spike_train_clear, templates,
    #                               spike_index_collision)

    # path_to_file = os.path.join(cfg.data.root_folder, output_file)
    # np.savetxt(path_to_file, spike_train, fmt='%i, %i')
    # logger.info('Done, spike train saved in: {}'.format(path_to_file))
def getCleanSpikeTrain(config):
    """
        Threshold detection for extracting clean templates from the raw recording if groundtruth is not available

        Parameters:
        -----------
        config: configuration object
            configuration object containing the parameters for making the training data.
            
        Returns:
        -----------
        spikeTrain: np.array
            [number of spikes, 2] first column corresponds to spike time; second column corresponds to cluster id.
 
    """

    # use threshold detection and skip whitening and deconvolution
    config.detctionMethod = 'threshold'
    config.doWhitening = 0
    config.doDeconv = 0

    pp = Preprocessor(config)
    score, clr_idx, spt = pp.process()
    mp = Mainprocessor(config, score, clr_idx, spt)
    mp.mainProcess()

    return mp.spikeTrain
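
# Usage sketch for getCleanSpikeTrain (the config path below is hypothetical;
# it assumes a YAML file readable by yass.Config.from_yaml, as in the other
# examples on this page):
#
# config = yass.Config.from_yaml('tests/config_threshold.yaml')
# spike_train = getCleanSpikeTrain(config)
# spike_train[:, 0]  # spike times
# spike_train[:, 1]  # cluster ids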
Example #3
def test_process_1k(path_to_config_1k):
    cfg = yass.Config.from_yaml(path_to_config_1k)

    pp = Preprocessor(cfg)
    score, clr_idx, spt = pp.process()

    mp = Mainprocessor(cfg, score, clr_idx, spt)
    spike_train, spt_left = mp.mainProcess()
Example #4
def test_new_process(path_to_config):
    cfg = yass.Config.from_yaml(path_to_config)

    pp = Preprocessor(cfg)
    score, clr_idx, spt = pp.process()

    set_config(path_to_config)

    (spike_train_clear, templates,
     spike_index_collision) = process.run(score, clr_idx, spt)
def test_deconvolute(path_to_config):
    cfg = yass.Config.from_yaml(path_to_config)

    pp = Preprocessor(cfg)
    score, clr_idx, spt = pp.process()

    mp = Mainprocessor(cfg, score, clr_idx, spt)
    spike_train, spt_left = mp.mainProcess()

    dc = Deconvolution(cfg, np.transpose(mp.templates, [1, 0, 2]), spt_left)
    dc.fullMPMU()
    def getBigTemplates(self, R):
        """
            Gets clean templates with large temporal radius

            Parameters:
            -----------
            R: int
                length of the templates to be returned.

            Returns:
            -----------
            templates: np.array
                [number of templates, temporal length, number of channels]
                returned templates.

        """

        pp = Preprocessor(self.config)
        return pp.getTemplates(self.spikeTrain, R)
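
# Usage sketch for getBigTemplates: `obj` stands for an instance of the
# surrounding class, with its config set and spikeTrain populated (e.g. via
# getCleanSpikeTrain above); the radius value is illustrative.
#
# big_templates = obj.getBigTemplates(R=100)
# big_templates.shape  # [number of templates, temporal length, n channels]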
def test_can_preprocess_data_with_nnet(path_to_nn_config):
    cfg = yass.Config.from_yaml(path_to_nn_config)
    pp = Preprocessor(cfg)
    score, clr_idx, spt = pp.process()
def test_can_preprocess_data_1k(path_to_config_1k):
    cfg = yass.Config.from_yaml(path_to_config_1k)
    pp = Preprocessor(cfg)
    score, clr_idx, spt = pp.process()
# Note: we are working on an improved version of Deconvolution; this old
# pipeline will be removed in the near future. However, migrating to the new
# pipeline requires minimal code changes (a sketch of the new pipeline
# follows the example below).


import numpy as np

import yass
from yass.preprocessing import Preprocessor
from yass.mainprocess import Mainprocessor
from yass.deconvolute import Deconvolution

cfg = yass.Config.from_yaml('tests/config_nnet.yaml')

pp = Preprocessor(cfg)
score, spike_index_clear, spike_index_collision = pp.process()

mp = Mainprocessor(cfg, score, spike_index_clear, spike_index_collision)
spike_train_clear, spike_index_collision = mp.mainProcess()

dc = Deconvolution(cfg, np.transpose(mp.templates, [1, 0, 2]),
                   spike_index_collision)
spike_train = dc.fullMPMU()

print(spike_train)
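
# For comparison, a sketch of the same run on the new pipeline, reconstructed
# from the commented-out code in main() above; module names and signatures
# are taken from those comments and may differ in the released version.
#
# set_config('tests/config_nnet.yaml')
# score, spike_index_clear, spike_index_collision = preprocess.run()
# spike_train_clear, templates, spike_index_collision = process.run(
#     score, spike_index_clear, spike_index_collision)
# spike_train = deconvolute.run(spike_train_clear, templates,
#                               spike_index_collision)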
    def determineNoiseCov(self, temporal_size, D):
        """
            Determines the spatial and temporal covariance of the noise

            Parameters:
            -----------
            temporal_size: int
                row size of the temporal covariance matrix.
            D: int
                number of steps away from the main channel. D=1 uses only the
                main channel; D=2 uses the main channel and its neighboring
                channels; D=3 uses the main channel, the neighboring channels,
                and the neighbors of the neighboring channels.

            Returns:
            -----------
            spatial_SIG: np.array
                [d, d] spatial covariance matrix of the noise, where d
                depends on D.

            temporal_SIG: np.array
                [temporal_size, temporal_size] temporal covariance matrix of
                the noise.

        """

        pp = Preprocessor(self.config)

        batch_size = self.config.batch_size
        BUFF = self.config.BUFF
        nBatches = self.config.nBatches
        nPortion = self.config.nPortion
        residual = self.config.residual

        R = pp.config.spikeSize
        # load the batch in the middle of the recording
        pp.openFile()
        i = int(np.ceil(nBatches / 2))

        self.logger.debug('Loading batch {}...'.format(i))

        # reading data
        if nBatches == 1:
            rec = pp.load(0, batch_size)

        elif i == 0:
            rec = pp.load(i * batch_size, batch_size + BUFF)

        elif i < nBatches - 1:
            rec = pp.load(i * batch_size - BUFF, batch_size + 2 * BUFF)

        elif residual == 0:
            rec = pp.load(i * batch_size - BUFF, batch_size + BUFF)

        else:
            rec = pp.load(i * batch_size - BUFF, residual + BUFF)

        neighChanBig = n_steps_neigh_channels(self.config.neighChannels, D)
        c_ref = np.argmax(np.sum(neighChanBig, 0))
        ch_idx = np.where(neighChanBig[c_ref])[0]
        ch_idx, temp = order_channels_by_distance(c_ref, ch_idx,
                                                  self.config.geom)
        rec = rec[:, ch_idx]

        # filter recording
        if pp.config.preprocess.filter == 1:
            rec = butterworth(rec, self.config.filter.low_pass_freq,
                              self.config.filter.high_factor,
                              self.config.filter.order,
                              self.config.recordings.sampling_rate)

        # standardize recording: estimate each channel's standard deviation
        # with the MAD estimator on a window around the middle of the batch
        small_t = np.min(
            (int(pp.config.recordings.sampling_rate * 5), 6000000))
        mid_T = int(np.ceil(rec.shape[0] / 2))
        rec_temp = rec[np.arange(mid_T - small_t, mid_T + small_t)]
        sd = np.median(np.abs(rec_temp), 0) / 0.6745
        rec = np.divide(rec, sd)

        pp.closeFile()

        T, C = rec.shape
        idxNoise = np.zeros((T, C))

        for c in range(C):
            # mask out samples within R of any threshold crossing so that
            # only noise segments remain on this channel
            idx_temp = np.where(rec[:, c] > 3)[0]
            for j in range(-R, R + 1):
                idx_temp2 = idx_temp + j
                idx_temp2 = idx_temp2[np.logical_and(idx_temp2 >= 0,
                                                     idx_temp2 < T)]
                rec[idx_temp2, c] = np.nan
            # non-NaN entries are the remaining noise samples
            idxNoise_temp = (rec[:, c] == rec[:, c])
            rec[:, c] = rec[:, c] / np.nanstd(rec[:, c])

            rec[~idxNoise_temp, c] = 0
            idxNoise[idxNoise_temp, c] = 1

        # spatial covariance over noise samples, normalizing each channel
        # pair by its number of co-occurring noise samples
        spatial_cov = np.divide(np.matmul(rec.T, rec),
                                np.matmul(idxNoise.T, idxNoise))

        w, v = np.linalg.eig(spatial_cov)

        # matrix square root of the covariance, and its inverse for whitening
        spatial_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)
        spatial_whitener = np.matmul(np.matmul(v, np.diag(1 / np.sqrt(w))),
                                     v.T)
        rec = np.matmul(rec, spatial_whitener)

        # sample 1000 single-channel windows that contain only noise (after
        # spatial whitening) to estimate the temporal covariance
        noise_wf = np.zeros((1000, temporal_size))
        count = 0
        while count < 1000:
            tt = np.random.randint(T - temporal_size)
            cc = np.random.randint(C)
            temp = rec[tt:(tt + temporal_size), cc]
            temp_idxnoise = idxNoise[tt:(tt + temporal_size), cc]
            if np.sum(temp_idxnoise == 0) == 0:
                noise_wf[count] = temp
                count += 1

        w, v = np.linalg.eig(np.cov(noise_wf.T))

        temporal_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

        return spatial_SIG, temporal_SIG
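
# Usage sketch: spatial_SIG and temporal_SIG are matrix square roots of the
# estimated covariances (V diag(sqrt(w)) V.T above), so they can be used to
# color white noise, e.g. when generating synthetic training data. `obj`
# stands for an instance of the surrounding class; the sizes are illustrative.
#
# spatial_SIG, temporal_SIG = obj.determineNoiseCov(temporal_size=40, D=2)
# white = np.random.randn(40, spatial_SIG.shape[0])
# correlated_noise = temporal_SIG @ white @ spatial_SIG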