예제 #1
0
def jrclust_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        params=None,  # Optional dict of dataset params (e.g. scale_factor)
        **kwargs):
    """Run JRCLUST on *recording* inside *tmpdir* and return the sorting.

    Writes the recording as an MDA dataset (raw.mda, geom.csv, params.json),
    generates an argfile.txt from *kwargs*/*params*, invokes JRCLUST through
    MATLAB, and parses the resulting firings.mda into a NumpySortingExtractor.

    Raises:
        Exception: if JRCLUST cannot be installed, MATLAB exits non-zero, or
            no firings.mda output file is produced.
    """
    # Normalize here instead of using a mutable default argument, which would
    # be shared across calls.
    if params is None:
        params = {}

    jrclust_path = os.environ.get('JRCLUST_PATH_DEV', None)
    if jrclust_path:
        print('Using jrclust from JRCLUST_PATH_DEV directory: {}'.format(
            jrclust_path))
    else:
        try:
            print('Auto-installing jrclust.')
            jrclust_path = install_jrclust(
                repo='https://github.com/JaneliaSciComp/JRCLUST.git',
                commit='68ffb3ef064f97aca7043b7faac49c34a58997d9')
        except Exception:  # narrowed from bare except: don't swallow KeyboardInterrupt/SystemExit
            traceback.print_exc()
            raise Exception(
                'Problem installing jrclust. You can set the JRCLUST_PATH_DEV to force to use a particular path.'
            )
    print('Using jrclust from: {}'.format(jrclust_path))

    dataset_dir = os.path.join(tmpdir, 'jrclust_dataset')
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(recording=recording,
                                            save_path=dataset_dir,
                                            params=params,
                                            _preserve_dtype=True)

    samplerate = recording.get_sampling_frequency()

    print('Reading timeseries header...')
    raw_mda = os.path.join(dataset_dir, 'raw.mda')
    HH = mdaio.readmda_header(raw_mda)
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.
          format(num_channels, num_timepoints, duration_minutes))

    # The argfile is a simple key=value file consumed by p_jrclust on the
    # MATLAB side; note JRCLUST's parameter name 'bitScaling' maps from our
    # 'scale_factor'.
    print('Creating argfile.txt...')
    txt = ''
    for key0, val0 in kwargs.items():
        txt += '{}={}\n'.format(key0, val0)
    if 'scale_factor' in params:
        txt += 'bitScaling={}\n'.format(params["scale_factor"])
    txt += 'sampleRate={}\n'.format(samplerate)
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    # Build the MATLAB driver script; quit(1) signals failure back to the
    # shell so we can detect it via the exit code.
    source_path = os.path.dirname(os.path.realpath(__file__))
    print('Running jrclust in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
        addpath('{jrclust_path}', '{source_path}', '{source_path}/mdaio');
        try
            p_jrclust('{tmpdir}', '{dataset_dir}/raw.mda', '{dataset_dir}/geom.csv', '{tmpdir}/firings.mda', '{dataset_dir}/argfile.txt');
        catch
            fprintf('----------------------------------------');
            fprintf(lasterr());
            quit(1);
        end
        quit(0);
    '''
    cmd = cmd.format(jrclust_path=jrclust_path,
                     tmpdir=tmpdir,
                     dataset_dir=dataset_dir,
                     source_path=source_path)

    matlab_cmd = mlpr.ShellScript(cmd,
                                  script_path=tmpdir + '/run_jrclust.m',
                                  keep_temp_files=True)
    matlab_cmd.write()

    shell_cmd = '''
        #!/bin/bash
        cd {tmpdir}
        matlab -nosplash -nodisplay -r run_jrclust
    '''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(shell_cmd,
                                 script_path=tmpdir + '/run_jrclust.sh',
                                 keep_temp_files=True)
    # script_path was already supplied above; call write() with no argument
    # for consistency with matlab_cmd.write().
    shell_cmd.write()
    shell_cmd.start()

    retcode = shell_cmd.wait()

    if retcode != 0:
        raise Exception('jrclust returned a non-zero exit code')

    # Parse output: firings.mda rows are (channel, time, label); rows 1 and 2
    # carry the event times and unit labels.
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
예제 #2
0
def kilosort2_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        detect_sign=-1,  # Polarity of the spikes, -1, 0, or 1
        adjacency_radius=-1,  # Channel neighborhood adjacency radius corresponding to geom file
        detect_threshold=6,  # Threshold for detection
        merge_thresh=.98,  # Cluster merging threshold 0..1
        freq_min=150,  # Lower frequency limit for band-pass filter
        freq_max=6000,  # Upper frequency limit for band-pass filter
        pc_per_chan=3,  # number of PC per chan
        minFR=1 / 50):  # minimum firing rate used by KiloSort2
    """Run KiloSort2 on *recording* inside *tmpdir* and return the sorting.

    Writes the recording as an MDA dataset (raw.mda, geom.csv, params.json),
    generates an argfile.txt from the keyword parameters, invokes KiloSort2
    through MATLAB, and parses the resulting firings.mda into a
    NumpySortingExtractor.

    Raises:
        Exception: if KiloSort2 cannot be installed, MATLAB exits non-zero,
            or no firings.mda output file is produced.
    """
    # TODO: do not require ks2 to depend on irc -- rather, put all necessary
    # .m code in the spikeforest repo.

    kilosort2_path = os.environ.get('KILOSORT2_PATH_DEV', None)
    if kilosort2_path:
        print('Using kilosort2 from KILOSORT2_PATH_DEV directory: {}'.format(
            kilosort2_path))
    else:
        try:
            print('Auto-installing kilosort2.')
            kilosort2_path = KiloSort2.install()
        except Exception:  # narrowed from bare except: don't swallow KeyboardInterrupt/SystemExit
            traceback.print_exc()
            raise Exception(
                'Problem installing kilosort2. You can set the KILOSORT2_PATH_DEV to force to use a particular path.'
            )
    print('Using kilosort2 from: {}'.format(kilosort2_path))

    source_dir = os.path.dirname(os.path.realpath(__file__))

    dataset_dir = tmpdir + '/kilosort2_dataset'
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(recording=recording,
                                            save_path=dataset_dir,
                                            _preserve_dtype=True)

    samplerate = recording.get_sampling_frequency()

    print('Reading timeseries header...')
    HH = mdaio.readmda_header(dataset_dir + '/raw.mda')
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.
          format(num_channels, num_timepoints, duration_minutes))

    # The argfile is a simple key=value file consumed by p_kilosort2 on the
    # MATLAB side.
    print('Creating argfile.txt file...')
    txt = ''
    txt += 'samplerate={}\n'.format(samplerate)
    txt += 'detect_sign={}\n'.format(detect_sign)
    txt += 'adjacency_radius={}\n'.format(adjacency_radius)
    txt += 'detect_threshold={}\n'.format(detect_threshold)
    txt += 'merge_thresh={}\n'.format(merge_thresh)
    txt += 'freq_min={}\n'.format(freq_min)
    txt += 'freq_max={}\n'.format(freq_max)
    txt += 'pc_per_chan={}\n'.format(pc_per_chan)
    txt += 'minFR={}\n'.format(minFR)
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    # Build the MATLAB driver script; quit(1) signals failure back to the
    # shell so we can detect it via the exit code.
    print('Running Kilosort2 in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
        addpath('{source_dir}');
        addpath('{source_dir}/mdaio')
        try
            p_kilosort2('{ksort}', '{tmpdir}', '{raw}', '{geom}', '{firings}', '{arg}');
        catch
            quit(1);
        end
        quit(0);
        '''
    cmd = cmd.format(source_dir=source_dir,
                     ksort=kilosort2_path,
                     tmpdir=tmpdir,
                     raw=dataset_dir + '/raw.mda',
                     geom=dataset_dir + '/geom.csv',
                     firings=tmpdir + '/firings.mda',
                     arg=dataset_dir + '/argfile.txt')
    matlab_cmd = mlpr.ShellScript(cmd,
                                  script_path=tmpdir + '/run_kilosort2.m',
                                  keep_temp_files=True)
    matlab_cmd.write()
    shell_cmd = '''
        #!/bin/bash
        cd {tmpdir}
        echo '=====================' `date` '====================='
        matlab -nosplash -nodisplay -r run_kilosort2
    '''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(shell_cmd,
                                 script_path=tmpdir + '/run_kilosort2.sh',
                                 keep_temp_files=True)
    # script_path was already supplied above; call write() with no argument
    # for consistency with matlab_cmd.write().
    shell_cmd.write()
    shell_cmd.start()
    retcode = shell_cmd.wait()

    if retcode != 0:
        raise Exception('kilosort2 returned a non-zero exit code')

    # Parse output: firings.mda rows 1 and 2 carry the event times and unit
    # labels.
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
예제 #3
0
def ironclust_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        params=None,  # Optional dict of dataset params (e.g. scale_factor)
        ironclust_path,  # Path to an IronClust checkout
        **kwargs):
    """Run IronClust on *recording* inside *tmpdir* and return the sorting.

    Writes the recording as an MDA dataset (raw.mda, geom.csv, params.json),
    generates an argfile.txt from *kwargs*/*params*, invokes IronClust
    through MATLAB, and parses the resulting firings.mda into a
    NumpySortingExtractor.

    Raises:
        Exception: if MATLAB exits non-zero or no firings.mda output file is
            produced.
    """
    # Normalize here instead of using a mutable default argument, which would
    # be shared across calls.
    if params is None:
        params = {}

    source_dir = os.path.dirname(os.path.realpath(__file__))

    dataset_dir = tmpdir + '/ironclust_dataset'
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(
        recording=recording, save_path=dataset_dir, params=params, _preserve_dtype=True)

    samplerate = recording.get_sampling_frequency()

    print('Reading timeseries header...')
    HH = mdaio.readmda_header(dataset_dir + '/raw.mda')
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(
        num_channels, num_timepoints, duration_minutes))

    # The argfile is a simple key=value file consumed by p_ironclust on the
    # MATLAB side.
    print('Creating argfile.txt...')
    txt = ''
    for key0, val0 in kwargs.items():
        txt += '{}={}\n'.format(key0, val0)
    txt += 'samplerate={}\n'.format(samplerate)
    if 'scale_factor' in params:
        txt += 'scale_factor={}\n'.format(params["scale_factor"])
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    # Build the MATLAB driver script; quit(1) signals failure back to the
    # shell so we can detect it via the exit code.
    print('Running ironclust in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
        addpath('{source_dir}');
        addpath('{ironclust_path}', '{ironclust_path}/matlab', '{ironclust_path}/matlab/mdaio');
        try
            p_ironclust('{tmpdir}', '{dataset_dir}/raw.mda', '{dataset_dir}/geom.csv', '', '', '{tmpdir}/firings.mda', '{dataset_dir}/argfile.txt');
        catch
            fprintf('----------------------------------------');
            fprintf(lasterr());
            quit(1);
        end
        quit(0);
    '''
    cmd = cmd.format(ironclust_path=ironclust_path, tmpdir=tmpdir, dataset_dir=dataset_dir, source_dir=source_dir)

    matlab_cmd = mlpr.ShellScript(cmd, script_path=tmpdir + '/run_ironclust.m', keep_temp_files=True)
    matlab_cmd.write()

    shell_cmd = '''
        #!/bin/bash
        cd {tmpdir}
        matlab -nosplash -nodisplay -r run_ironclust
    '''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(shell_cmd, script_path=tmpdir + '/run_ironclust.sh', keep_temp_files=True)
    # script_path was already supplied above; call write() with no argument
    # for consistency with matlab_cmd.write().
    shell_cmd.write()
    shell_cmd.start()

    retcode = shell_cmd.wait()

    if retcode != 0:
        raise Exception('ironclust returned a non-zero exit code')

    # Parse output: firings.mda rows 1 and 2 carry the event times and unit
    # labels.
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
예제 #4
0
from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
from mountaintools import client as mt

# Configure to download from the public spikeforest kachery node
mt.configDownloadFrom('spikeforest.public')

# Load an example tetrode recording with its ground truth.
# You can also substitute any of the other available recordings.
# The sha1dir:// URI is a content-addressed directory handle; presumably it
# is resolved through the kachery node configured above -- TODO confirm.
recdir = 'sha1dir://fb52d510d2543634e247e0d2d1d4390be9ed9e20.synth_magland/datasets_noise10_K10_C4/001_synth'

print('loading recording...')
# download=True: NOTE(review) looks like this fetches the remote dataset
# before reading it -- confirm against SFMdaRecordingExtractor.
recording = SFMdaRecordingExtractor(dataset_directory=recdir, download=True)
# The ground-truth sorting ships alongside the recording as firings_true.mda.
sorting_true = SFMdaSortingExtractor(firings_file=recdir + '/firings_true.mda')
예제 #5
0
import os
import shutil

from spikeforest import example_datasets
from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor

# Build a small toy recording together with its ground-truth sorting.
recording, sorting_true = example_datasets.toy_example1()

recdir = 'toy_example1'

# Start from a clean slate: drop any leftover output of a previous run.
if os.path.exists(recdir):
    shutil.rmtree(recdir)

print('Preparing toy recording...')
firings_true_path = recdir + '/firings_true.mda'
SFMdaRecordingExtractor.write_recording(recording=recording, save_path=recdir)
SFMdaSortingExtractor.write_sorting(sorting=sorting_true, save_path=firings_true_path)
예제 #6
0
# from mountaintools import client as ca
from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor, example_datasets
import os
from spikeforestsorters import YASS
from spikeforest import spikewidgets as sw

# %% Set up a scratch directory and load the YASS example dataset.
tmpdir = 'yass_test1'
if not os.path.isdir(tmpdir):
    os.mkdir(tmpdir)
rx, sx = example_datasets.yass_example(set_id=1)

# %% Write the recording and its ground-truth sorting to disk.
recording_path = tmpdir + '/recording'
firings_true = recording_path + '/firings_true.mda'
SFMdaRecordingExtractor.write_recording(recording=rx, save_path=recording_path)
SFMdaSortingExtractor.write_sorting(sorting=sx, save_path=firings_true)

# Run YASS on the recording we just wrote.
YASS.execute(recording_dir=tmpdir + '/recording',
             firings_out=tmpdir + '/firings_out.mda',
             detect_sign=-1,
             adjacency_radius=50,
             _container=None,
             _force_run=True,
             _keep_temp_files=True)
firings_out = tmpdir + '/firings_out.mda'
assert os.path.exists(firings_out)

# %% Report where the outputs landed.
print(f'recording: {recording_path}')
print(f'firings_out: {firings_out}')