Пример #1
0
    def test_write_dat_file(self):
        """Round-trip ``se.writeBinaryDatFormat`` through ``np.memmap``.

        Checks the on-disk layout for ``time_axis=0`` (read back as
        samples x channels, then transposed) and ``time_axis=1`` (read back as
        channels x samples), with and without ``chunksize``, against
        ``self.RX.getTraces()``. Finally checks that ``time_axis=1`` combined
        with a ``chunksize`` raises.
        """
        nb_sample = self.RX.getNumFrames()
        nb_chan = self.RX.getNumChannels()
        dat_path = self.test_dir + 'rec.dat'

        def read_back(shape):
            # Pass the path rather than an ``open()`` handle: np.memmap then
            # opens and closes the file itself, so no descriptor is leaked.
            # Copying into a plain array drops the mapping immediately, so the
            # next write is free to overwrite the file (an open mapping can
            # block that on Windows).
            return np.array(np.memmap(dat_path, dtype='float32', mode='r', shape=shape))

        # time_axis=0 chunksize=None
        se.writeBinaryDatFormat(self.RX, dat_path, time_axis=0, dtype='float32', chunksize=None)
        assert np.allclose(read_back((nb_sample, nb_chan)).T, self.RX.getTraces())

        # time_axis=1 chunksize=None
        se.writeBinaryDatFormat(self.RX, dat_path, time_axis=1, dtype='float32', chunksize=None)
        assert np.allclose(read_back((nb_chan, nb_sample)), self.RX.getTraces())

        # time_axis=0 chunksize=99
        se.writeBinaryDatFormat(self.RX, dat_path, time_axis=0, dtype='float32', chunksize=99)
        assert np.allclose(read_back((nb_sample, nb_chan)).T, self.RX.getTraces())

        # time_axis=1 combined with a chunksize is not supported and must raise
        with self.assertRaises(Exception):
            se.writeBinaryDatFormat(self.RX, dat_path, time_axis=1, dtype='float32', chunksize=99)
Пример #2
0
def kilosort(recording,
             kilosort_path=None,
             npy_matlab_path=None,
             output_folder=None,
             useGPU=False,
             probe_file=None,
             file_name=None,
             spike_thresh=4,
             electrode_dimensions=None):
    """Run KiloSort (through MATLAB) on ``recording`` and return the result.

    Writes the recording as an int16 binary file plus three MATLAB scripts
    rendered from the ``kilosort_master.txt``, ``kilosort_config.txt`` and
    ``kilosort_channelmap.txt`` templates next to this module. It then runs
    the master script with the ``matlab`` command line and loads the output
    folder with ``se.KiloSortSortingExtractor``.

    Parameters
    ----------
    recording : RecordingExtractor
        Recording to sort. It must expose a ``'location'`` channel property,
        either directly or after loading ``probe_file``.
    kilosort_path, npy_matlab_path : str or None
        Installation folders of KiloSort and npy-matlab. When ``None`` they
        are read from the ``KILOSORT_PATH`` / ``NPY_MATLAB_PATH`` environment
        variables.
    output_folder : str or None
        Parent folder. Results go to ``<output_folder>/kilosort``, or to
        ``./kilosort`` when ``None``.
    useGPU : bool
        Written into the master script as ``1`` / ``0``.
    probe_file : str or None
        Optional probe file loaded with ``se.loadProbeFile``.
    file_name : str or None
        Base name of the binary file (default ``'recording'``). A trailing
        ``.dat`` is stripped.
    spike_thresh : float
        Detection threshold written into the config script.
    electrode_dimensions : sequence of two ints or None
        Columns of the channel locations used as x / y (default: 0 and 1).

    Returns
    -------
    se.KiloSortSortingExtractor
        Extractor reading the KiloSort output folder.

    Raises
    ------
    ModuleNotFoundError
        If KiloSort or npy-matlab cannot be located.
    Exception
        If channel locations are missing or ``electrode_dimensions`` does not
        have length 2.
    """
    if kilosort_path is None:
        kilosort_path = os.getenv('KILOSORT_PATH', None)
    if npy_matlab_path is None:
        npy_matlab_path = os.getenv('NPY_MATLAB_PATH', None)
    # Test for None first: os.path.join(None, ...) would raise a TypeError
    # instead of the installation hint below.
    if kilosort_path is None or npy_matlab_path is None \
            or not os.path.isfile(join(kilosort_path, 'preprocessData.m')) \
            or not os.path.isfile(join(npy_matlab_path, 'readNPY.m')):
        raise ModuleNotFoundError(
            "\nTo use KiloSort, install KiloSort and npy-matlab from sources: \n\n"
            "\ngit clone https://github.com/cortex-lab/KiloSort\n"
            "\ngit clone https://github.com/kwikteam/npy-matlab\n"
            "and provide the installation path with the 'kilosort_path' and "
            "'npy_matlab_path' arguments or by setting the KILOSORT_PATH and NPY_MATLAB_PATH "
            "environment variables.\n\n"
            "\nMore information on KiloSort at: "
            "\nhttps://github.com/cortex-lab/KiloSort")

    source_dir = os.path.dirname(os.path.realpath(__file__))

    if output_folder is None:
        output_folder = os.path.abspath('kilosort')
    else:
        output_folder = os.path.abspath(join(output_folder, 'kilosort'))
    os.makedirs(output_folder, exist_ok=True)

    kilosort_path = os.path.abspath(kilosort_path)
    npy_matlab_path = os.path.abspath(npy_matlab_path)

    if probe_file is not None:
        # NOTE(review): depending on the spikeextractors version,
        # loadProbeFile returns a new extractor carrying the probe
        # information. Use it when one is returned; otherwise keep
        # `recording` as before.
        loaded = se.loadProbeFile(recording, probe_file)
        if loaded is not None:
            recording = loaded

    # save binary file
    if file_name is None:
        file_name = 'recording'
    elif file_name.endswith('.dat'):
        # Strip only the extension.
        file_name = file_name[:-len('.dat')]
    se.writeBinaryDatFormat(recording,
                            join(output_folder, file_name),
                            dtype='int16')

    # set up kilosort config files and run kilosort on data
    with open(join(source_dir, 'kilosort_master.txt'), 'r') as f:
        kilosort_master = f.readlines()
    with open(join(source_dir, 'kilosort_config.txt'), 'r') as f:
        kilosort_config = f.readlines()
    with open(join(source_dir, 'kilosort_channelmap.txt'), 'r') as f:
        kilosort_channelmap = f.readlines()

    nchan = recording.getNumChannels()
    dat_file = file_name + '.dat'
    kilo_thresh = spike_thresh
    # Number of templates: 8 per channel, with the channel count rounded down
    # to a multiple of 32. Fall back to 8 * nchan for fewer than 32 channels.
    Nfilt = (nchan // 32) * 32 * 8
    if Nfilt == 0:
        Nfilt = nchan * 8
    nsamples = 128 * 1024 + 64
    ug = 1 if useGPU else 0

    abs_channel = os.path.abspath(join(output_folder, 'kilosort_channelmap.m'))
    abs_config = os.path.abspath(join(output_folder, 'kilosort_config.m'))

    kilosort_master = ''.join(kilosort_master).format(ug, kilosort_path,
                                                      npy_matlab_path,
                                                      output_folder,
                                                      abs_channel, abs_config)
    kilosort_config = ''.join(kilosort_config).format(
        nchan, nchan, recording.getSamplingFrequency(), dat_file, Nfilt,
        nsamples, kilo_thresh)

    if 'location' not in recording.getChannelPropertyNames():
        raise Exception(
            "'location' information is needed. Provide a probe information with a 'probe_file'"
        )
    if electrode_dimensions is None:
        x_dim, y_dim = 0, 1
    elif len(electrode_dimensions) == 2:
        x_dim, y_dim = electrode_dimensions
    else:
        raise Exception("Electrode dimension should be a list of len 2")
    # Iterate the actual channel ids: they are not necessarily 0..nchan-1.
    positions = np.array([
        recording.getChannelProperty(chan, 'location')
        for chan in recording.getChannelIds()
    ])
    # .tolist() gives plain Python numbers, so the MATLAB source reads
    # "[0.0, 1.0]" rather than NumPy scalar reprs such as "np.float64(0.0)".
    kilosort_channelmap = ''.join(kilosort_channelmap).format(
        nchan, positions[:, x_dim].tolist(), positions[:, y_dim].tolist(),
        'ones(1, Nchannels)', recording.getSamplingFrequency())

    for fname, value in zip(
            ['kilosort_master.m', 'kilosort_config.m', 'kilosort_channelmap.m'],
            [kilosort_master, kilosort_config, kilosort_channelmap]):
        with open(join(output_folder, fname), 'w') as f:
            f.writelines(value)

    # start sorting with kilosort
    print('Running KiloSort')
    t_start_proc = time.time()
    cmd = 'matlab -nosplash -nodisplay -r "run {}; quit;"'.format(
        join(output_folder, 'kilosort_master.m'))
    print(cmd)
    # NOTE(review): the return code of call_command is not checked here.
    call_command(cmd)
    print('Elapsed time: ', time.time() - t_start_proc)

    sorting = se.KiloSortSortingExtractor(join(output_folder))
    return sorting
Пример #3
0
def _klusta(
        recording,  # The recording extractor
        output_folder=None,
        probe_file=None,
        file_name=None,
        adjacency_radius=None,
        threshold_strong_std_factor=5,
        threshold_weak_std_factor=2,
        detect_sign=-1,
        extract_s_before=16,
        extract_s_after=32,
        n_features_per_channel=3,
        pca_n_waveforms_max=10000,
        num_starting_clusters=50):
    """Run Klusta on ``recording`` and return a ``se.KlustaSortingExtractor``.

    Writes the recording as a binary file plus a ``config.prm`` rendered from
    ``config_default.prm`` (next to this module). It then runs
    ``klusta <config.prm> --overwrite`` and loads the resulting ``.kwik`` file.

    Parameters
    ----------
    output_folder : str, Path or None
        Working folder (default ``./klusta``). It is created if missing.
    probe_file : str, Path or None
        Klusta ``.prb`` file. When ``None``, one is written to
        ``<output_folder>/probe.prb`` with ``se.saveProbeFile``, using
        ``adjacency_radius``.
    file_name : str, Path or None
        Base name of the binary file (default ``'recording'``). A trailing
        ``.dat`` is stripped.
    detect_sign : number
        Negative -> ``'negative'``, positive -> ``'positive'``, 0 -> ``'both'``.
    threshold_strong_std_factor, threshold_weak_std_factor, extract_s_before,
    extract_s_after, n_features_per_channel, pca_n_waveforms_max,
    num_starting_clusters
        Written verbatim into the Klusta parameter file.

    Raises
    ------
    ModuleNotFoundError
        If ``klusta`` or ``klustakwik2`` is not importable.
    Exception
        If no ``.kwik`` file is produced.
    """
    try:
        # Imported only to check that the packages are installed.
        import klusta  # noqa: F401
        import klustakwik2  # noqa: F401
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "\nTo use Klusta, install klusta and klustakwik2: \n\n"
            "\npip install klusta klustakwik2\n"
            "\nMore information on klusta at: "
            "\nhttps://github.com/kwikteam/phy"
            "\nhttps://github.com/kwikteam/klusta") from err
    source_dir = Path(__file__).parent
    output_folder = Path('klusta') if output_folder is None else Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    # save prb file:
    if probe_file is None:
        probe_file = output_folder / 'probe.prb'
        se.saveProbeFile(recording,
                         probe_file,
                         format='klusta',
                         radius=adjacency_radius)

    # save binary file. Normalise to Path so both str and Path inputs work and
    # `file_name.name` below is always available; strip only a '.dat' suffix.
    file_name = Path('recording') if file_name is None else Path(file_name)
    if file_name.suffix == '.dat':
        file_name = file_name.with_suffix('')
    se.writeBinaryDatFormat(recording, output_folder / file_name)

    if detect_sign < 0:
        detect_sign = 'negative'
    elif detect_sign > 0:
        detect_sign = 'positive'
    else:
        detect_sign = 'both'

    # set up klusta config file
    with (source_dir / 'config_default.prm').open('r') as f:
        klusta_config = f.readlines()

    klusta_config = ''.join(klusta_config).format(
        output_folder / file_name, probe_file,
        float(recording.getSamplingFrequency()), recording.getNumChannels(),
        "'float32'", threshold_strong_std_factor, threshold_weak_std_factor,
        "'" + detect_sign + "'", extract_s_before, extract_s_after,
        n_features_per_channel, pca_n_waveforms_max, num_starting_clusters)

    with (output_folder / 'config.prm').open('w') as f:
        f.writelines(klusta_config)

    print('Running Klusta')
    cmd = 'klusta {} --overwrite'.format(output_folder / 'config.prm')
    print(cmd)
    _call_command(cmd)
    kwik_file = output_folder / (file_name.name + '.kwik')
    if not kwik_file.is_file():
        raise Exception('Klusta did not run successfully')

    sorting = se.KlustaSortingExtractor(kwik_file)

    return sorting
Пример #4
0
def klusta(
        recording,  # The recording extractor
        output_folder=None,
        probe_file=None,
        file_name=None,
        threshold_strong_std_factor=5,
        threshold_weak_std_factor=2,
        detect_spikes='negative',
        extract_s_before=16,
        extract_s_after=32,
        n_features_per_channel=3,
        pca_n_waveforms_max=10000,
        num_starting_clusters=50):
    """Run Klusta on ``recording`` and return a ``si.KlustaSortingExtractor``.

    Writes the recording as a binary file plus a ``config.prm`` rendered from
    ``config_default.prm`` (next to this module). It then runs
    ``klusta <config.prm> --overwrite`` and loads the resulting ``.kwik`` file.

    Parameters
    ----------
    output_folder : str or None
        Parent folder. Work happens in ``<output_folder>/klusta``, or in
        ``./klusta`` when ``None``.
    probe_file : str or None
        Klusta ``.prb`` file. When ``None``, one is written to
        ``<work folder>/probe.prb`` with ``si.saveProbeFile``.
    file_name : str or None
        Base name of the binary file (default ``'recording'``). A trailing
        ``.dat`` is stripped.
    detect_spikes : str
        Inserted quoted into the parameter file (e.g. ``'negative'``).

    Raises
    ------
    ModuleNotFoundError
        If ``klusta`` or ``klustakwik2`` is not importable.
    Exception
        If the ``klusta`` command exits with a non-zero code.
    """
    try:
        # Imported only to check that the packages are installed.
        import klusta  # noqa: F401
        import klustakwik2  # noqa: F401
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "\nTo use Klusta, install klusta and klustakwik2: \n\n"
            "\npip install klusta klustakwik2\n"
            "\nMore information on klusta at: "
            "\nhttps://github.com/kwikteam/phy"
            "\nhttps://github.com/kwikteam/klusta") from err
    source_dir = os.path.dirname(os.path.realpath(__file__))

    if output_folder is None:
        output_folder = 'klusta'
    else:
        output_folder = join(output_folder, 'klusta')
    os.makedirs(output_folder, exist_ok=True)

    # save prb file:
    if probe_file is None:
        probe_file = join(output_folder, 'probe.prb')
        si.saveProbeFile(recording, probe_file, format='klusta')

    # save binary file
    if file_name is None:
        file_name = 'recording'
    elif file_name.endswith('.dat'):
        # Strip only the extension.
        file_name = file_name[:-len('.dat')]
    si.writeBinaryDatFormat(recording, join(output_folder, file_name))

    # set up klusta config file
    with open(join(source_dir, 'config_default.prm'), 'r') as f:
        klusta_config = f.readlines()

    klusta_config = ''.join(klusta_config).format(
        join(output_folder, file_name), probe_file,
        float(recording.getSamplingFrequency()), recording.getNumChannels(),
        "'float32'", threshold_strong_std_factor, threshold_weak_std_factor,
        "'" + detect_spikes + "'", extract_s_before, extract_s_after,
        n_features_per_channel, pca_n_waveforms_max, num_starting_clusters)

    with open(join(output_folder, 'config.prm'), 'w') as f:
        f.writelines(klusta_config)

    print('Running Klusta')
    t_start_proc = time.time()
    cmd = 'klusta {} --overwrite'.format(join(output_folder, 'config.prm'))
    print(cmd)
    retcode = run_command_and_print_output(cmd)
    if retcode != 0:
        raise Exception('Klusta returned a non-zero exit code')
    print('Elapsed time: ', time.time() - t_start_proc)

    sorting = si.KlustaSortingExtractor(
        join(output_folder, file_name + '.kwik'))

    return sorting
Пример #5
0
def _kilosort_install_path(value, env_var):
    """Return ``value`` as a Path, falling back to environment variable ``env_var``.

    ``None`` and the literal string ``'None'`` both mean "not given". A
    surrounding pair of double quotes in the environment value is removed.
    Returns ``None`` when neither the argument nor the variable is set.
    """
    if value is None or value == 'None':
        value = os.getenv(env_var)
        if value is None:
            return None
        if value.startswith('"'):
            value = value[1:-1]
    return Path(value)


def _kilosort(recording,
              output_folder=None,
              kilosort_path=None,
              npy_matlab_path=None,
              useGPU=False,
              probe_file=None,
              file_name=None,
              detect_threshold=4,
              electrode_dimensions=None):
    """Run KiloSort (through MATLAB) on ``recording`` and return the result.

    Writes the recording as an int16 binary file plus three MATLAB scripts
    rendered from the ``kilosort_*.txt`` templates next to this module. It
    runs the master script with ``matlab -r`` and loads ``output_folder``
    with ``se.KiloSortSortingExtractor``.

    Parameters
    ----------
    output_folder : str, Path or None
        Working folder (default ``./kilosort``). It is created if missing and
        made absolute.
    kilosort_path, npy_matlab_path : str, Path or None
        Installation folders. When ``None`` (or the string ``'None'``) they
        are read from ``KILOSORT_PATH`` / ``NPY_MATLAB_PATH``. ``readNPY.m``
        is looked up in ``<npy_matlab_path>/npy-matlab``.
    useGPU : bool
        Written into the master script as ``1`` / ``0``.
    probe_file : str or None
        Optional probe file; ``recording`` is replaced by the result of
        ``se.loadProbeFile``.
    file_name : str, Path or None
        Base name of the binary file (default ``'recording'``). A trailing
        ``.dat`` is stripped.
    detect_threshold : float
        Detection threshold written into the config script.
    electrode_dimensions : sequence of two ints or None
        Columns of the channel locations used as x / y (default: 0 and 1).

    Raises
    ------
    ModuleNotFoundError
        If KiloSort or npy-matlab cannot be located.
    Exception
        If channel locations are missing, ``electrode_dimensions`` does not
        have length 2, or no ``spike_times.npy`` is produced.
    """
    kilosort_path = _kilosort_install_path(kilosort_path, 'KILOSORT_PATH')
    npy_matlab_path = _kilosort_install_path(npy_matlab_path, 'NPY_MATLAB_PATH')
    if kilosort_path is None or npy_matlab_path is None \
            or not (kilosort_path / 'preprocessData.m').is_file() \
            or not (npy_matlab_path / 'npy-matlab' / 'readNPY.m').is_file():
        raise ModuleNotFoundError(
            "\nTo use KiloSort, install KiloSort and npy-matlab from sources: \n\n"
            "\ngit clone https://github.com/cortex-lab/KiloSort\n"
            "\ngit clone https://github.com/kwikteam/npy-matlab\n"
            "and provide the installation path with the 'kilosort_path' and "
            "'npy_matlab_path' arguments or by setting the KILOSORT_PATH and NPY_MATLAB_PATH "
            "environment variables.\n\n"
            "\nMore information on KiloSort at: "
            "\nhttps://github.com/cortex-lab/KiloSort")
    source_dir = Path(__file__).parent
    output_folder = Path('kilosort') if output_folder is None else Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)
    output_folder = output_folder.absolute()

    if probe_file is not None:
        recording = se.loadProbeFile(recording, probe_file)

    # save binary file. Normalise to Path so both str and Path inputs work and
    # `file_name.name` below is always available; strip only a '.dat' suffix.
    file_name = Path('recording') if file_name is None else Path(file_name)
    if file_name.suffix == '.dat':
        file_name = file_name.with_suffix('')
    se.writeBinaryDatFormat(recording,
                            output_folder / file_name,
                            dtype='int16')

    # set up kilosort config files and run kilosort on data
    with (source_dir / 'kilosort_master.txt').open('r') as f:
        kilosort_master = f.readlines()
    with (source_dir / 'kilosort_config.txt').open('r') as f:
        kilosort_config = f.readlines()
    with (source_dir / 'kilosort_channelmap.txt').open('r') as f:
        kilosort_channelmap = f.readlines()

    nchan = recording.getNumChannels()
    dat_file = (output_folder / (file_name.name + '.dat')).absolute()
    kilo_thresh = detect_threshold
    # Number of templates: 8 per channel, with the channel count rounded down
    # to a multiple of 32. Fall back to 8 * nchan for fewer than 32 channels.
    Nfilt = (nchan // 32) * 32 * 8
    if Nfilt == 0:
        Nfilt = nchan * 8
    nsamples = 128 * 1024 + 64
    ug = 1 if useGPU else 0

    abs_channel = (output_folder / 'kilosort_channelmap.m').absolute()
    abs_config = (output_folder / 'kilosort_config.m').absolute()
    kilosort_path = kilosort_path.absolute()
    npy_matlab_path = npy_matlab_path.absolute() / 'npy-matlab'

    kilosort_master = ''.join(kilosort_master).format(ug, kilosort_path,
                                                      npy_matlab_path,
                                                      output_folder,
                                                      abs_channel, abs_config)
    kilosort_config = ''.join(kilosort_config).format(
        nchan, nchan, recording.getSamplingFrequency(), dat_file, Nfilt,
        nsamples, kilo_thresh)

    if 'location' not in recording.getChannelPropertyNames():
        raise Exception(
            "'location' information is needed. Provide a probe information with a 'probe_file'"
        )
    if electrode_dimensions is None:
        x_dim, y_dim = 0, 1
    elif len(electrode_dimensions) == 2:
        x_dim, y_dim = electrode_dimensions
    else:
        raise Exception("Electrode dimension should be a list of len 2")
    positions = np.array([
        recording.getChannelProperty(chan, 'location')
        for chan in recording.getChannelIds()
    ])
    # .tolist() gives plain Python numbers, so the MATLAB source reads
    # "[0.0, 1.0]" rather than NumPy scalar reprs such as "np.float64(0.0)".
    kilosort_channelmap = ''.join(kilosort_channelmap).format(
        nchan, positions[:, x_dim].tolist(), positions[:, y_dim].tolist(),
        'ones(1, Nchannels)', recording.getSamplingFrequency())

    for fname, value in zip(
            ['kilosort_master.m', 'kilosort_config.m', 'kilosort_channelmap.m'],
            [kilosort_master, kilosort_config, kilosort_channelmap]):
        with (output_folder / fname).open('w') as f:
            f.writelines(value)

    # start sorting with kilosort
    print('Running KiloSort')
    master_script = output_folder / 'kilosort_master.m'
    print("matlab -nosplash -nodisplay -r 'run {}; quit;'".format(master_script))
    cmd_list = ['matlab', '-nosplash', '-nodisplay']
    # sys.platform is 'win32' on Windows (never just 'win'). There, MATLAB
    # needs -wait so that the call blocks until the script has finished.
    if sys.platform.startswith('win'):
        cmd_list.append('-wait')
    cmd_list += ['-r', 'run {}; quit;'.format(master_script)]
    _run_command_and_print_output_split(cmd_list)
    if not (output_folder / 'spike_times.npy').is_file():
        raise Exception('KiloSort did not run successfully')
    sorting = se.KiloSortSortingExtractor(output_folder)
    return sorting
Пример #6
0
def exportToPhy(recording,
                sorting,
                output_folder,
                nPCchan=3,
                nPC=5,
                filter=False,
                electrode_dimensions=None,
                max_num_waveforms=np.inf):
    """Export ``recording`` and ``sorting`` to a folder for ``phy template-gui``.

    Writes ``recording.dat`` (int16), ``params.py`` and the ``.npy`` arrays
    (spike times, templates, PC features, channel map/positions, whitening
    matrices, ...) into ``output_folder``.

    Parameters
    ----------
    recording : se.RecordingExtractor
    sorting : se.SortingExtractor
    output_folder : str
        Destination folder. It is created if missing.
    nPCchan : int
        Number of entries kept along axis 1 of the PC scores.
    nPC : int
        Number of PCA components; capped at the number of channels.
    filter : bool
        If True, band-pass the recording (300-6000 Hz) before exporting. The
        templates, PC features and ``recording.dat`` are then all computed
        from the filtered data, and ``hp_filtered`` is set accordingly.
    electrode_dimensions : sequence of ints or None
        Columns of the channel locations to export as positions.
    max_num_waveforms : int or float
        Passed to ``Analyzer.computePCAscores``.

    Raises
    ------
    AttributeError
        If ``recording`` / ``sorting`` are not extractor instances.
    """
    if not isinstance(recording, se.RecordingExtractor) or not isinstance(
            sorting, se.SortingExtractor):
        raise AttributeError(
            "'recording' must be a RecordingExtractor and 'sorting' a SortingExtractor")
    output_folder = os.path.abspath(output_folder)
    os.makedirs(output_folder, exist_ok=True)

    if filter:
        recording = bandpass_filter(recording, freq_min=300, freq_max=6000)
    # Build the analyzer on the recording that is actually exported, so the
    # templates and PC features match the data phy reads from recording.dat.
    analyzer = Analyzer(recording, sorting)

    # save dat file
    dat_path = join(output_folder, 'recording.dat')
    se.writeBinaryDatFormat(recording, dat_path, dtype='int16')

    # write params.py. phy executes this file as Python, so use repr() to
    # produce a valid string literal even with backslashes (Windows paths).
    with open(join(output_folder, 'params.py'), 'w') as f:
        f.write('dat_path = ' + repr(dat_path) + '\n')
        f.write('n_channels_dat = ' + str(recording.getNumChannels()) + '\n')
        f.write("dtype = 'int16'\n")
        f.write('offset = 0\n')
        f.write('sample_rate = ' + str(recording.getSamplingFrequency()) +
                '\n')
        f.write('hp_filtered = ' + str(bool(filter)) + '\n')

    # pc_features.npy - [nSpikes, nFeaturesPerChannel, nPCFeatures] single
    if nPC > recording.getNumChannels():
        nPC = recording.getNumChannels()
        print("Changed number of PC to number of channels: ", nPC)
    pc_scores = analyzer.computePCAscores(n_comp=nPC,
                                          elec=True,
                                          max_num_waveforms=max_num_waveforms)

    # spike_times.npy and spike_clusters.npy: gather per unit, then concatenate
    # once (avoids the quadratic cost of growing arrays inside the loop).
    unit_ids = sorting.getUnitIds()
    times_per_unit = []
    clusters_per_unit = []
    pcs_per_unit = []
    for unit_index, unit_id in enumerate(unit_ids):
        train = np.asarray(sorting.getUnitSpikeTrain(unit_id))
        times_per_unit.append(train)
        clusters_per_unit.append(np.full(len(train), unit_index))
        pcs_per_unit.append(np.asarray(pc_scores[unit_index]))
    spike_times = np.concatenate(times_per_unit)
    spike_clusters = np.concatenate(clusters_per_unit)
    pc_features = np.vstack(pcs_per_unit)

    # Order every per-spike array by spike time; a stable sort keeps
    # coincident spikes in unit order.
    sorting_idxs = np.argsort(spike_times, kind='stable')
    spike_times = spike_times[sorting_idxs, np.newaxis]
    spike_clusters = spike_clusters[sorting_idxs, np.newaxis]
    pc_features = pc_features[sorting_idxs, :nPCchan, :]

    # amplitudes.npy (no amplitude information: all ones)
    amplitudes = np.ones((len(spike_times), 1))

    # channel_map.npy
    channel_ids = recording.getChannelIds()
    channel_map = np.array(channel_ids)

    # channel_positions.npy
    if 'location' in recording.getChannelPropertyNames():
        # Iterate the actual channel ids, consistent with channel_map above.
        positions = np.array([
            recording.getChannelProperty(chan, 'location')
            for chan in channel_ids
        ])
        if electrode_dimensions is not None:
            positions = positions[:, electrode_dimensions]
    else:
        print("'location' property is not available and it will be linear.")
        positions = np.zeros((recording.getNumChannels(), 2))
        positions[:, 1] = np.arange(recording.getNumChannels())

    # pc_feature_ind.npy - [nTemplates, nPCFeatures] uint32
    pc_feature_ind = np.tile(np.arange(nPC), (len(unit_ids), 1))

    # similar_templates.npy - [nTemplates, nTemplates] single
    templates = analyzer.getUnitTemplate()
    similar_templates = _computeTemplatesSimilarity(templates)

    # templates.npy: swap the last two axes of the analyzer templates
    templates = np.array(templates).swapaxes(1, 2)

    # template_ind.npy
    templates_ind = np.tile(np.arange(recording.getNumChannels()),
                            (len(unit_ids), 1))

    # spike_templates.npy - [nSpikes, ] uint32; one template per unit
    spike_templates = spike_clusters

    # whitening_mat.npy - [nChannels, nChannels] double
    # whitening_mat_inv.npy - [nChannels, nChannels] double
    whitening_mat, whitening_mat_inv = _computeWhiteningAndInverse(recording)

    np.save(join(output_folder, 'amplitudes.npy'), amplitudes)
    np.save(join(output_folder, 'spike_times.npy'), spike_times.astype(int))
    # NOTE(review): spike_clusters.npy is deliberately not written;
    # presumably phy falls back to spike_templates.npy. Confirm.
    np.save(join(output_folder, 'spike_templates.npy'),
            spike_templates.astype(int))
    np.save(join(output_folder, 'pc_features.npy'), pc_features)
    np.save(join(output_folder, 'pc_feature_ind.npy'),
            pc_feature_ind.astype(int))
    np.save(join(output_folder, 'templates.npy'), templates)
    np.save(join(output_folder, 'templates_ind.npy'),
            templates_ind.astype(int))
    np.save(join(output_folder, 'similar_templates.npy'), similar_templates)
    np.save(join(output_folder, 'channel_map.npy'), channel_map.astype(int))
    np.save(join(output_folder, 'channel_positions.npy'), positions)
    np.save(join(output_folder, 'whitening_mat.npy'), whitening_mat)
    np.save(join(output_folder, 'whitening_mat_inv.npy'), whitening_mat_inv)
    print('Saved phy format to: ', output_folder)
    print('Run:\n\nphy template-gui ', join(output_folder, 'params.py'))