Ejemplo n.º 1
0
 def test_load_save_probes(self):
     """Check that a channel location survives a probe save/load round trip.

     Loads ``tests/probe_test.prb`` onto ``self.RX``, verifies that the
     ``location`` and ``group`` channel properties are present, writes the
     probe to ``geom.csv`` under ``self.test_dir``, reloads it, and compares
     the location of channel index 10 before and after.
     """
     probed = se.loadProbeFile(self.RX, 'tests/probe_test.prb')
     for required in ('location', 'group'):
         assert required in probed.getChannelPropertyNames()

     # Locations as read from the .prb file.
     positions = []
     for chan in range(self.RX.getNumChannels()):
         positions.append(probed.getChannelProperty(chan, 'location'))

     # Write the geometry to CSV, then read it back onto the same extractor.
     csv_path = Path(self.test_dir) / 'geom.csv'
     se.saveProbeFile(probed, csv_path)
     reloaded = se.loadProbeFile(probed, csv_path)

     position_loaded = []
     for chan in range(reloaded.getNumChannels()):
         position_loaded.append(reloaded.getChannelProperty(chan, 'location'))

     locations_match = np.allclose(positions[10], position_loaded[10])
     self.assertTrue(locations_match)
Ejemplo n.º 2
0
def yass_helper(
        recording,
        output_folder=None,  # Temporary working directory
        probe_file=None,
        file_name=None,
        detect_sign=-1,  # -1 - 1 - 0
        template_width_ms=1,  # yass parameter
        filter=True,
        adjacency_radius=100):
    """Run the ``yass`` command-line sorter on ``recording``.

    The function writes a probe file, a float32 binary dump of the recording
    and a YAML config (rendered from ``config_default.yaml`` located next to
    this module) into the output folder, then invokes ``yass`` on that config.

    Parameters
    ----------
    recording
        Recording extractor; must provide ``getNumChannels()`` and
        ``getSamplingFrequency()``.
    output_folder
        Parent directory for the ``yass`` working folder. When ``None`` a
        ``yass`` folder is created relative to the current directory.
    probe_file
        Path where the probe is saved in ``'yass'`` format. Defaults to
        ``probe.npy`` inside the working folder.
    file_name
        Name of the binary file written inside the working folder
        (default ``'raw.bin'``). The YAML config is ``file_name + '.yaml'``.
    detect_sign
        When positive, the recording is written with reversed polarity
        (``fReversePolarity=True``).
    template_width_ms, filter, adjacency_radius
        Values substituted verbatim into the config template.

    Returns
    -------
    tuple
        ``(sorting, yaml_file)``: the ``yassSortingExtractor`` built from the
        ``tmp`` sub-folder, and the path of the generated YAML config.

    Raises
    ------
    Exception
        If the ``yass`` command exits with a non-zero code.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))

    # make output dir
    if output_folder is None:
        output_folder = 'yass'
    else:
        output_folder = join(output_folder, 'yass')
    output_folder = os.path.abspath(output_folder)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(output_folder, exist_ok=True)

    # save prb file:
    if probe_file is None:
        probe_file = join_abspath_(output_folder, 'probe.npy')
    se.saveProbeFile(recording, probe_file, format='yass')

    # save binary file
    if file_name is None:
        file_name = 'raw.bin'
    bin_file = join_abspath_(output_folder, file_name)
    writeRecording_(recording=recording, save_path=bin_file,
                    fReversePolarity=(detect_sign > 0), dtype=np.float32, scale_factor=1)

    # set up yass config file
    print(source_dir)
    with open(join(source_dir, 'config_default.yaml'), 'r') as f:
        yass_config = f.read()

    # Template field order:
    # root_folder, recordings, geometry, dtype, sampling_rate, n_channels,
    # spatial_radius, spike_size_ms, filter
    n_channels = recording.getNumChannels()
    sampling_rate = recording.getSamplingFrequency()

    yaml_file = join(output_folder, file_name + '.yaml')
    yass_config = yass_config.format(
        output_folder, bin_file, probe_file, 'single', int(sampling_rate), n_channels, adjacency_radius, template_width_ms, filter)
    with open(yaml_file, 'w') as f:
        f.write(yass_config)

    # Echo the rendered config; it is exactly what was just written to disk.
    print('YASS CONFIG:')
    print(yass_config)

    print('Running yass...')
    t_start_proc = time.time()

    cmd = 'yass {}'.format(yaml_file)

    retcode = run_command_and_print_output(cmd)
    if retcode != 0:
        raise Exception('yass returned a non-zero exit code')

    processing_time = time.time() - t_start_proc
    print('Elapsed time: ', processing_time)
    sorting = yassSortingExtractor(join_abspath_(output_folder, 'tmp'))

    return sorting, yaml_file
Ejemplo n.º 3
0
def spyking_circus(
        recording,
        output_folder=None,  # Temporary working directory
        probe_file=None,
        file_name=None,
        detect_sign=-1,  # -1 - 1 - 0
        adjacency_radius=100,  # Channel neighborhood adjacency radius corresponding to geom file
        spike_thresh=6,  # Threshold for detection
        template_width_ms=3,  # Spyking circus parameter
        filter=True,
        merge_spikes=True,
        n_cores=None,
        electrode_dimensions=None,
        whitening_max_elts=1000,  # I believe it relates to subsampling and affects compute time
        clustering_max_elts=10000,  # I believe it relates to subsampling and affects compute time
):
    """Run the ``spyking-circus`` command-line sorter on ``recording``.

    Writes the traces as a float32 ``.npy`` file plus a ``.params`` config
    (rendered from ``config_default.params`` next to this module) into the
    working folder, runs the sorting and the merging steps, and returns the
    resulting sorting extractor.

    Parameters
    ----------
    recording
        Recording extractor; must provide ``getTraces()`` and
        ``getSamplingFrequency()``.
    output_folder
        Parent directory for the ``spyking_circus`` working folder. When
        ``None`` the folder is created relative to the current directory.
    probe_file
        Existing probe file to use. When ``None`` a ``probe.prb`` is written
        in ``'spyking_circus'`` format using ``adjacency_radius`` and
        ``electrode_dimensions`` (both are otherwise unused).
    file_name
        Base name for the data/config files (default ``'recording'``). A
        trailing ``.npy`` extension is stripped.
    detect_sign
        Negative -> ``'negative'``, positive -> ``'positive'``, zero ->
        ``'both'`` in the config.
    spike_thresh, template_width_ms, filter, whitening_max_elts, clustering_max_elts
        Values substituted verbatim into the config template.
    merge_spikes
        Selects the value written to the last template field: ``1e-5`` when
        true, ``0`` otherwise.
    n_cores
        Number of cores passed to ``-c``. Defaults to half of the available
        CPUs, and at least 1.

    Returns
    -------
    The ``se.SpykingCircusSortingExtractor`` for the working files.

    Raises
    ------
    ModuleNotFoundError
        If the ``circus`` package is not installed.
    Exception
        If either ``spyking-circus`` invocation exits with a non-zero code.
    """
    try:
        import circus  # noqa: F401 -- imported only to verify availability
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "\nTo use Spyking-Circus, install spyking-circus: \n\n"
            "\npip install spyking-circus"
            "\nfor ubuntu install openmpi: "
            "\nsudo apt install libopenmpi-dev"
            "\nMore information on Spyking-Circus at: "
            "\nhttps://spyking-circus.readthedocs.io/en/latest/") from err
    source_dir = os.path.dirname(os.path.realpath(__file__))

    if output_folder is None:
        output_folder = 'spyking_circus'
    else:
        output_folder = join(output_folder, 'spyking_circus')
    output_folder = os.path.abspath(output_folder)

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(output_folder, exist_ok=True)

    # save prb file:
    if probe_file is None:
        probe_file = join(output_folder, 'probe.prb')
        se.saveProbeFile(recording,
                         probe_file,
                         format='spyking_circus',
                         radius=adjacency_radius,
                         dimensions=electrode_dimensions)

    # save binary file
    if file_name is None:
        file_name = 'recording'
    elif file_name.endswith('.npy'):
        # Drop the extension (slice, not index): the base name is reused for
        # the '.params' file and '.npy' is re-appended where needed.
        file_name = file_name[:-len('.npy')]
    base_path = join(output_folder, file_name)
    npy_path = base_path + '.npy'
    # np.save appends '.npy' to a path that does not already end with it.
    np.save(base_path, recording.getTraces().astype('float32'))

    if detect_sign < 0:
        detect_sign = 'negative'
    elif detect_sign > 0:
        detect_sign = 'positive'
    else:
        detect_sign = 'both'

    # set up spykingcircus config file
    with open(join(source_dir, 'config_default.params'), 'r') as f:
        circus_config = f.read()
    auto = 1e-5 if merge_spikes else 0
    circus_config = circus_config.format(
        float(recording.getSamplingFrequency()), probe_file, template_width_ms,
        spike_thresh, detect_sign, filter, whitening_max_elts,
        clustering_max_elts, auto)
    with open(base_path + '.params', 'w') as f:
        f.write(circus_config)

    print('Running spyking circus...')
    t_start_proc = time.time()
    if n_cores is None:
        # os.cpu_count() may return None when the count is undeterminable.
        n_cores = max(1, (os.cpu_count() or 1) // 2)

    cmd = 'spyking-circus {} -c {} '.format(npy_path, n_cores)
    cmd_merge = 'spyking-circus {} -m merging -c {} '.format(npy_path, n_cores)
    print(cmd)
    retcode = run_command_and_print_output(cmd)
    if retcode != 0:
        raise Exception('Spyking circus returned a non-zero exit code')
    print(cmd_merge)
    retcode = run_command_and_print_output(cmd_merge)
    if retcode != 0:
        raise Exception('Spyking circus merging returned a non-zero exit code')
    processing_time = time.time() - t_start_proc
    print('Elapsed time: ', processing_time)
    sorting = se.SpykingCircusSortingExtractor(base_path)

    return sorting
Ejemplo n.º 4
0
def _klusta(
        recording,  # The recording extractor
        output_folder=None,
        probe_file=None,
        file_name=None,
        adjacency_radius=None,
        threshold_strong_std_factor=5,
        threshold_weak_std_factor=2,
        detect_sign=-1,
        extract_s_before=16,
        extract_s_after=32,
        n_features_per_channel=3,
        pca_n_waveforms_max=10000,
        num_starting_clusters=50):
    """Run the ``klusta`` command-line sorter on ``recording``.

    Writes a probe file (unless one is given), a binary ``.dat`` dump of the
    recording and a ``config.prm`` (rendered from ``config_default.prm`` next
    to this module) into ``output_folder``, runs ``klusta`` on the config and
    returns an extractor for the resulting ``.kwik`` file.

    Parameters
    ----------
    recording
        Recording extractor; must provide ``getSamplingFrequency()`` and
        ``getNumChannels()``.
    output_folder
        Working directory (``str`` or ``Path``); defaults to ``klusta``.
        Created, including missing parents, if it does not exist.
    probe_file
        Existing probe file to use. When ``None`` a ``probe.prb`` is written
        in ``'klusta'`` format using ``adjacency_radius`` as the radius.
    file_name
        Base name of the data file (``str`` or ``Path``); defaults to
        ``recording``. A ``.dat`` suffix is stripped.
    detect_sign
        Negative -> ``'negative'``, positive -> ``'positive'``, zero ->
        ``'both'`` in the config.
    threshold_strong_std_factor, threshold_weak_std_factor, extract_s_before,
    extract_s_after, n_features_per_channel, pca_n_waveforms_max,
    num_starting_clusters
        Values substituted verbatim into the config template.

    Returns
    -------
    The ``se.KlustaSortingExtractor`` opened on the produced ``.kwik`` file.

    Raises
    ------
    ModuleNotFoundError
        If ``klusta`` or ``klustakwik2`` is not installed.
    Exception
        If the expected ``.kwik`` file is missing after running ``klusta``.
    """
    try:
        import klusta  # noqa: F401 -- imported only to verify availability
        import klustakwik2  # noqa: F401
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "\nTo use Klusta, install klusta and klustakwik2: \n\n"
            "\npip install klusta klustakwik2\n"
            "\nMore information on klusta at: "
            "\nhttps://github.com/kwikteam/phy"
            "\nhttps://github.com/kwikteam/klusta") from err
    source_dir = Path(__file__).parent
    output_folder = Path('klusta') if output_folder is None else Path(output_folder)
    # parents=True so a nested, not-yet-existing output path does not fail.
    output_folder.mkdir(parents=True, exist_ok=True)

    # save prb file:
    if probe_file is None:
        probe_file = output_folder / 'probe.prb'
        se.saveProbeFile(recording,
                         probe_file,
                         format='klusta',
                         radius=adjacency_radius)

    # save binary file
    if file_name is None:
        file_name = Path('recording')
    else:
        # Normalise to Path so both str and Path inputs work, and keep it a
        # Path after stripping '.dat' because '.name' is used further down.
        file_name = Path(file_name)
        if file_name.suffix == '.dat':
            file_name = Path(file_name.stem)
    se.writeBinaryDatFormat(recording, output_folder / file_name)

    if detect_sign < 0:
        detect_sign = 'negative'
    elif detect_sign > 0:
        detect_sign = 'positive'
    else:
        detect_sign = 'both'

    # set up klusta config file
    with (source_dir / 'config_default.prm').open('r') as f:
        klusta_config = f.read()

    klusta_config = klusta_config.format(
        output_folder / file_name, probe_file,
        float(recording.getSamplingFrequency()), recording.getNumChannels(),
        "'float32'", threshold_strong_std_factor, threshold_weak_std_factor,
        "'" + detect_sign + "'", extract_s_before, extract_s_after,
        n_features_per_channel, pca_n_waveforms_max, num_starting_clusters)

    config_path = output_folder / 'config.prm'
    with config_path.open('w') as f:
        f.write(klusta_config)

    print('Running Klusta')
    cmd = 'klusta {} --overwrite'.format(config_path)
    print(cmd)
    _call_command(cmd)

    # No return code is inspected here; success is judged by the output file.
    kwik_file = output_folder / (file_name.name + '.kwik')
    if not kwik_file.is_file():
        raise Exception('Klusta did not run successfully')

    sorting = se.KlustaSortingExtractor(kwik_file)

    return sorting
Ejemplo n.º 5
0
def klusta(
        recording,  # The recording extractor
        output_folder=None,
        probe_file=None,
        file_name=None,
        threshold_strong_std_factor=5,
        threshold_weak_std_factor=2,
        detect_spikes='negative',
        extract_s_before=16,
        extract_s_after=32,
        n_features_per_channel=3,
        pca_n_waveforms_max=10000,
        num_starting_clusters=50):
    """Run the ``klusta`` command-line sorter on ``recording``.

    Writes a probe file (unless one is given), a binary ``.dat`` dump of the
    recording and a ``config.prm`` (rendered from ``config_default.prm`` next
    to this module) into the working folder, runs ``klusta`` on the config
    and returns an extractor for the resulting ``.kwik`` file.

    Parameters
    ----------
    recording
        Recording extractor; must provide ``getSamplingFrequency()`` and
        ``getNumChannels()``.
    output_folder
        Parent directory for the ``klusta`` working folder. When ``None`` the
        folder is created relative to the current directory.
    probe_file
        Existing probe file to use. When ``None`` a ``probe.prb`` is written
        in ``'klusta'`` format inside the working folder.
    file_name
        Base name of the data file (default ``'recording'``). A trailing
        ``.dat`` extension is stripped.
    detect_spikes
        String written, single-quoted, into the config template.
    threshold_strong_std_factor, threshold_weak_std_factor, extract_s_before,
    extract_s_after, n_features_per_channel, pca_n_waveforms_max,
    num_starting_clusters
        Values substituted verbatim into the config template.

    Returns
    -------
    The ``si.KlustaSortingExtractor`` opened on ``<file_name>.kwik``.

    Raises
    ------
    ModuleNotFoundError
        If ``klusta`` or ``klustakwik2`` is not installed.
    Exception
        If the ``klusta`` command exits with a non-zero code.
    """
    try:
        import klusta  # noqa: F401 -- imported only to verify availability
        import klustakwik2  # noqa: F401
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "\nTo use Klusta, install klusta and klustakwik2: \n\n"
            "\npip install klusta klustakwik2\n"
            "\nMore information on klusta at: "
            "\nhttps://github.com/kwikteam/phy"
            "\nhttps://github.com/kwikteam/klusta") from err
    source_dir = os.path.dirname(os.path.realpath(__file__))

    if output_folder is None:
        output_folder = 'klusta'
    else:
        output_folder = join(output_folder, 'klusta')

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(output_folder, exist_ok=True)

    # save prb file:
    if probe_file is None:
        probe_file = join(output_folder, 'probe.prb')
        si.saveProbeFile(recording, probe_file, format='klusta')

    # save binary file
    if file_name is None:
        file_name = 'recording'
    elif file_name.endswith('.dat'):
        # Drop the extension (slice, not index): the base name is reused for
        # the data path in the config and for the '.kwik' output name.
        file_name = file_name[:-len('.dat')]
    data_path = join(output_folder, file_name)
    si.writeBinaryDatFormat(recording, data_path)

    # set up klusta config file
    with open(join(source_dir, 'config_default.prm'), 'r') as f:
        klusta_config = f.read()

    klusta_config = klusta_config.format(
        data_path, probe_file,
        float(recording.getSamplingFrequency()), recording.getNumChannels(),
        "'float32'", threshold_strong_std_factor, threshold_weak_std_factor,
        "'" + detect_spikes + "'", extract_s_before, extract_s_after,
        n_features_per_channel, pca_n_waveforms_max, num_starting_clusters)

    config_path = join(output_folder, 'config.prm')
    with open(config_path, 'w') as f:
        f.write(klusta_config)

    print('Running Klusta')
    t_start_proc = time.time()
    cmd = 'klusta {} --overwrite'.format(config_path)
    print(cmd)
    retcode = run_command_and_print_output(cmd)
    if retcode != 0:
        raise Exception('Klusta returned a non-zero exit code')
    print('Elapsed time: ', time.time() - t_start_proc)

    sorting = si.KlustaSortingExtractor(data_path + '.kwik')

    return sorting