def test_mask_out_artifacts():
    
    # Create noisy array
    samplerate = int(48e3)
    duration = 30 # seconds
    n_samples = samplerate*duration
    noise_amplitude = 5
    noise = noise_amplitude*np.random.normal(0,1,n_samples)
    standard_dev = np.std(noise)
    
    # add three artifacts
    n_artifacts = 3
    artifacts = np.zeros_like(noise)
    artifact_duration = int(0.2*samplerate) # samples
    artifact_signal = np.zeros((n_artifacts, artifact_duration))

    for i in np.arange(n_artifacts):                   
        artifact_signal[i, :] = noise_amplitude*np.random.normal(0,6,artifact_duration)

    artifact_indices = np.tile(np.arange(artifact_duration), (n_artifacts, 1))

    artifact_shift = np.array([int(n_samples*0.10), int(n_samples*0.20), int(n_samples*0.70)])

    artifact_indices += artifact_shift.reshape((-1,1))

    for i, indices in enumerate(artifact_indices):
        artifacts[indices] = artifact_signal[i,:]

    signal = noise + artifacts

    timeseries = 'test_mask.mda'
    timeseries_out = 'masked.mda' 
    
    # write as mda
    mdaio.writemda32(signal.reshape((1,-1)), timeseries)
    
    # run the mask artefacts
    mask_out_artifacts(timeseries=timeseries, timeseries_out=timeseries_out, threshold=6, chunk_size=2000, 
                       num_write_chunks=150)
    
    # check that they are gone 
    read_data = mdaio.readmda(timeseries).reshape((-1,1))
    masked_data = mdaio.readmda(timeseries_out).reshape((-1,1))

    indices_masked = sum(masked_data[artifact_indices,0].flatten() == 0)
    total_indices_to_mask = len(artifact_indices.flatten())
    masked = indices_masked == total_indices_to_mask
    
    os.remove(timeseries)
    os.remove(timeseries_out)
    
    if masked:
        print('Artifacts 100% masked')
        return True
    else:
        print('Artifacts %.2f%% masked' % (100*(indices_masked/total_indices_to_mask)))
        return False
def compare_ground_truth(*, firings, firings_true, json_out, opts={}):
    Ft = mdaio.readmda(mlp.realizeFile(firings_true))
    F = mdaio.readmda(mlp.realizeFile(firings))
    times1 = Ft[1, :]
    labels1 = Ft[2, :]
    times2 = F[1, :]
    labels2 = F[2, :]
    out = compare_ground_truth_helper(times1, labels1, times2, labels2)
    with open(json_out, 'w') as outfile:
        json.dump(out, outfile, indent=4)
Example #3
def test_convert_array(dtype='int32', shape=[12, 3, 7]):
    X = np.array(np.random.normal(0, 1, shape), dtype=dtype)
    np.save('test_convert1.npy', X)
    convert_array(input='test_convert1.npy',
                  output='test_convert2.mda')  # npy -> mda
    convert_array(input='test_convert2.mda',
                  output='test_convert3.npy')  # mda -> npy
    convert_array(input='test_convert3.npy',
                  output='test_convert4.dat')  # npy -> dat
    convert_array(input='test_convert4.dat',
                  output='test_convert5.npy',
                  dtype=dtype,
                  dimensions=','.join(str(entry)
                                      for entry in X.shape))  # dat -> npy
    convert_array(input='test_convert5.npy',
                  output='test_convert6.mda')  # npy -> mda
    convert_array(input='test_convert6.mda',
                  output='test_convert7.dat')  # mda -> dat
    convert_array(input='test_convert7.dat',
                  output='test_convert8.mda',
                  dtype=dtype,
                  dimensions=','.join(str(entry)
                                      for entry in X.shape))  # dat -> mda
    Y = mdaio.readmda('test_convert8.mda')
    print(np.max(np.abs(X - Y)), Y.dtype)
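
# The test above leaves its intermediate test_convert*.npy/.mda/.dat files on
# disk; a small cleanup sketch (file names taken from the calls above):
def _cleanup_test_convert_files():
    import os
    for i in range(1, 9):
        for ext in ('npy', 'mda', 'dat'):
            fname = 'test_convert{}.{}'.format(i, ext)
            if os.path.exists(fname):
                os.remove(fname)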
Example #4
def convert_firings(*, firings, params, res_fname, clu_fname, tetrode):
    """
    Export firings (either .mda or .npy) to .res & .clu format.

    Parameters
    ----------
    firings : INPUT
        Path of firings mda (or npy) file (RxL) where R>=3 and L is the number of events. Second row are timestamps, third row are integer labels.    
     
    params : INPUT
        params_json including channel map for tetrode.
   
    res_fname : OUTPUT
        Path of output spike times .res file.

    clu_fname : OUTPUT
        Path of output spike ids .clu file.

    tetrode : int
        Which tetrode to write out firings for.

    """

    F = mdaio.readmda(firings)
    L = F.shape[1]
    whch = F[0, :].ravel()
    times = F[1, :].ravel()
    labels = F[2, :].ravel().astype(int)
    K = np.max(labels)

    with open(params, 'r') as fp:
        params_obj = json.load(fp)

    tetmap = []
    pre = 1
    for i, x in enumerate(params_obj['tetrodes']):
        tetmap.append(list(range(pre, pre + len(x))))
        pre += len(x)

    which_tet = np.array([np.where(w == tetmap)[0][0] + 1 for w in whch])
    inds_k = np.where(which_tet == int(tetrode))[0]

    ###### Write .res file ######
    res = times[inds_k]
    np.savetxt(res_fname, res, fmt='%d')

    labels_tet = labels[inds_k]
    u_labels_tet = np.unique(labels_tet)
    K = len(u_labels_tet)
    u_new_labels = np.argsort(u_labels_tet).argsort() + 2
    new_labels = [
        u_new_labels[(l == np.array(u_labels_tet)).argmax()]
        for l in labels_tet
    ]
    np.savetxt(clu_fname,
               np.concatenate((np.array([K + 1]), new_labels)),
               fmt='%d')

    return True
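
# Usage sketch for convert_firings (hypothetical file names; assumes the
# params json carries the 'tetrodes' channel map used above). Writes spike
# times (.res) and cluster ids (.clu) for a single tetrode:
def _example_convert_firings():
    convert_firings(firings='firings.curated.mda',
                    params='params.json',
                    res_fname='tetrode3.res',
                    clu_fname='tetrode3.clu',
                    tetrode=3)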
def validate_sorting_results(*,dataset_dir,sorting_output_dir,output_dir):
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
        
    compare_ground_truth(
        firings=sorting_output_dir+'/firings.mda',
        firings_true=dataset_dir+'/firings_true.mda',
        json_out=output_dir+'/compare_ground_truth.json',
    )
    
    compute_templates(
        timeseries=dataset_dir+'/raw.mda',
        firings=dataset_dir+'/firings_true.mda',
        templates_out=output_dir+'/templates_true.mda.prv'
    )
    
    mlp.runPipeline()
    
    templates_true=mdaio.readmda(mlp.realizeFile(output_dir+'/templates_true.mda'))
    amplitudes_true=np.max(np.max(np.abs(templates_true),axis=1),axis=0)
    accuracies=get_accuracies(output_dir+'/compare_ground_truth.json')
    return dict(
        accuracies=accuracies,
        amplitudes_true=amplitudes_true
    )
Example #6
def compute_templates_helper(*, timeseries, firings, clip_size=100):
    X = mdaio.DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()
    F = mdaio.readmda(firings)
    L = F.shape[1]
    T = clip_size
    Tmid = int(np.floor((T + 1) / 2) - 1)
    times = F[1, :].ravel()
    labels = F[2, :].ravel().astype(int)
    K = np.max(labels)

    sums = np.zeros((M, T, K), dtype='float64')
    counts = np.zeros(K)

    for k in range(1, K + 1):
        inds_k = np.where(labels == k)[0]
        #TODO: subsample
        for ind_k in inds_k:
            t0 = int(times[ind_k])
            if (clip_size <= t0) and (t0 < N - clip_size):
                clip0 = X.readChunk(i1=0, N1=M, i2=t0 - Tmid, N2=T)
                sums[:, :, k - 1] += clip0
                counts[k - 1] += 1
    templates = np.zeros((M, T, K))
    for k in range(K):
        templates[:, :, k] = sums[:, :, k] / counts[k]
    return templates
Example #7
def load_marks(marks_path, spikes_df):
    marks = mdaio.readmda(marks_path)
    if marks.shape[-1] == spikes_df.shape[0]:
        ch_cols = [f'ch{c:>02d}' for c in np.arange(marks.shape[0])]
        marks_df = pd.DataFrame(marks[:,0,:].squeeze().T, columns=ch_cols, \
                                index=spikes_df.index)
        return marks_df
    return
Example #8
def test_synthesize_random_firings():
    K = 10
    synthesize_random_firings(K=K, firings_out='tmp.firings.mda')
    firings = mdaio.readmda('tmp.firings.mda')
    labels = firings[2, :]
    assert (max(labels) == K)
    assert (firings.shape[0] == 3)
    return True
 def __init__(self, firings_fname):
     OutputExtractor.__init__(self)
     print('Downloading file if needed: ' + firings_fname)
     self._firings_path = mlp.realizeFile(firings_fname)
     print('Done.')
     self._firings = mdaio.readmda(self._firings_path)
     self._times = self._firings[1, :]
     self._labels = self._firings[2, :]
     self._num_units = np.max(self._labels)
Example #10
def apply_label_map(*, firings, label_map, firings_out):
    """
    Apply a label map to a given firings, including masking and merging

    Parameters
    ----------
    firings : INPUT
        Path of input firings mda file
    label_map : INPUT
        Path of input label map mda file [base 1, mapping to zero removes from firings]
    firings_out : OUTPUT
        ...
    """
    firings = mdaio.readmda(firings)
    label_map = mdaio.readmda(label_map)
    label_map = np.reshape(label_map, (-1, 2))
    label_map = label_map[np.argsort(label_map[:,
                                               0])]  # Assure input is sorted

    #Propagate merge pairs to lowest label number
    for idx, label in enumerate(label_map[:, 1]):
        # jfm changed on 12/8/17 because isin() is not in older versions of numpy. :)
        #label_map[np.isin(label_map[:,0],label),0] = label_map[idx,0] # Input should be sorted
        label_map[np.where(label_map[:, 0] == label)[0],
                  0] = label_map[idx, 0]  # Input should be sorted

    #Apply label map
    for label_pair in range(label_map.shape[0]):
        # jfm changed on 12/8/17 because isin() is not in older versions of numpy. :)
        #firings[2, np.isin(firings[2, :], label_map[label_pair, 1])] = label_map[label_pair,0]
        firings[2, np.where(
            firings[2, :] == label_map[label_pair,
                                       1])[0]] = label_map[label_pair, 0]

    #Mask out all labels mapped to zero
    firings = firings[:, firings[2, :] != 0]

    #Write remapped firings
    return mdaio.writemda64(firings, firings_out)
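
# Usage sketch for apply_label_map (hypothetical file names). The label map is
# an Mx2 array: column 0 is the new label, column 1 is the old label, and a
# new label of 0 drops those events entirely:
def _example_apply_label_map():
    import numpy as np
    from mountainlab_pytools import mdaio
    label_map = np.array([[1, 4],   # merge cluster 4 into cluster 1
                          [0, 7]])  # discard cluster 7
    mdaio.writemda64(label_map, 'label_map.mda')
    apply_label_map(firings='firings.mda',
                    label_map='label_map.mda',
                    firings_out='firings.remapped.mda')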
Example #11
def load_clips(clips_path, spikes_df):
    # load clips
    clips = mdaio.readmda(clips_path)
    if clips.shape[-1] == spikes_df.shape[0]:
        clips_2d_list = [
            clips[int(primary_chan - 1), :, i]
            for i, primary_chan in enumerate(spikes_df.primary_chan.values)
        ]
        clips_2d = np.array(clips_2d_list)
        cl_cols = [f'{c:>03d}' for c in np.arange(clips_2d.shape[1])]
        clips_df = pd.DataFrame(clips_2d, columns=cl_cols, \
                                index=spikes_df.index)
        return clips_df
    return
 def initialize(self):
     print('Downloading timeseries (if needed): {}'.format(
         self._timeseries))
     if self._timeseries is not None:
         timeseries_path = mlp.realizeFile(self._timeseries)
     print('Downloading firings (if needed): {}'.format(self._firings))
     firings_path = mlp.realizeFile(self._firings)
     if self._geom is not None:
         print('Downloading geom (if needed): {}'.format(self._geom))
         geom_path = mlp.realizeFile(self._geom)
         self._G = np.genfromtxt(geom_path, delimiter=',').T
         self._G = np.flip(self._G, axis=0)
     else:
         self._G = None
     print('Reading arrays into memory...')
     if self._timeseries is not None:
         self._X = mdaio.readmda(timeseries_path)
     else:
         self._X = None
     self._F = mdaio.readmda(firings_path)
     self._times = self._F[1, :]
     self._labels = self._F[2, :]
     self._K = int(self._labels.max())
Example #13
    def get_result_from_folder(output_folder):

        # overwrite the SorterBase.get_result
        from mountainlab_pytools import mdaio

        result_fname = Path(output_folder) / 'firings.mda'

        assert result_fname.exists(), 'Result file does not exist: {}'.format(
            str(result_fname))

        firings = mdaio.readmda(str(result_fname))
        sorting = se.NumpySortingExtractor()
        sorting.set_times_labels(firings[1, :], firings[2, :])
        return sorting
Example #14
def test_compute_templates():
    M, N, K, T, L = 5, 1000, 6, 50, 100
    X = np.random.rand(M, N)
    mdaio.writemda32(X, 'tmp.mda')
    F = np.zeros((3, L))
    F[1, :] = 1 + np.random.randint(N, size=(1, L))
    F[2, :] = 1 + np.random.randint(K, size=(1, L))
    mdaio.writemda64(F, 'tmp2.mda')
    ret = compute_templates(timeseries='tmp.mda',
                            firings='tmp2.mda',
                            templates_out='tmp3.mda',
                            clip_size=T)
    assert (ret)
    templates0 = mdaio.readmda('tmp3.mda')
    assert (templates0.shape == (M, T, K))
    return True
def make_toes(fp):
    d = mdaio.readmda(os.path.join(fp, 'mountainout/firings.curated.mda'))
    alignfile = glob.glob(os.path.join(fp, '*.align'))
    align = pd.read_csv(alignfile[0])
    clusters = np.unique(d[2])
    Fs = 30000
    for cluster in clusters:
        spikes = d[1, d[2] == cluster]
        channel = channel_map(d[0, d[2] == cluster][0])
        starts = np.asarray(align.total_start, dtype=int)
        #starts = np.asarray(align.total_pulse, dtype = int)
        stops = np.asarray(align.total_stop, dtype=int)
        #stops = np.asarray(align.total_pulse+(1.5*Fs), dtype = int)
        offsets = np.asarray(align.total_pulse, dtype=int)
        indices = [(spikes >= start) & (spikes <= stop)
                   for start, stop in zip(starts, stops)]
        spiketrains = [
            spikes[ind] - offset for ind, offset in zip(indices, offsets)
        ]
        dirname = os.path.join(fp, "ch%d_c%d" % (channel, cluster))
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        for stim in align['stim'].unique():
            f = [spiketrains[rec] for rec in align[align['stim'] == stim].rec]
            # put together an informative name for the file
            song = list(align.song[align.stim == stim].unique())[0]
            cond = list(align.condition[align.stim == stim].unique())[0]
            name = "%s_%s" % (song, cond)
            toefile = os.path.join(dirname, "%s.toe_lis" % name)
            with open(toefile, "wt") as ftl:
                tl.write(ftl, np.asarray(f) / Fs * 1000)


#        for song in align['song'].unique():
#            f = [spiketrains[rec] for rec in align[align['song']==song][align['condition']=='no-scene'].rec]
#            name = song+'_no-scene'
#            toefile = os.path.join(dirname, "%s.toe_lis" % name)
#            with open(toefile, "wt") as ftl:
#                tl.write(ftl, np.asarray(f)/Fs*1000)
#            f = [spiketrains[rec] for rec in align[align['song']==song][align['condition'].str.contains('scene63')].rec]
#            name = song+'_scene63'
#            toefile = os.path.join(dirname, "%s.toe_lis" % name)
#            with open(toefile, "wt") as ftl:
#                tl.write(ftl, np.asarray(f)/Fs*1000)

        auditory_plot(dirname)
        cluster_info(dirname, fp, cluster)
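
# Usage sketch for make_toes (hypothetical path): fp is a recording directory
# containing mountainout/firings.curated.mda plus a *.align CSV with the
# total_start/total_stop/total_pulse/stim/song/condition columns used above:
#
# make_toes('/data/bird1/site02')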
Example #16
def compute_templates_helper(*, timeseries, firings, clip_size=100):
    X = DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()
    F = mdaio.readmda(firings)
    L = F.shape[1]
    T = clip_size
    times = F[1, :]
    labels = F[2, :].astype(int)
    K = np.max(labels)
    compute_templates._sums = np.zeros((M, T, K))
    compute_templates._counts = np.zeros(K)

    def _kernel(chunk, info):
        inds = np.where((info.t1 <= times) & (times <= info.t2))[0]
        times0 = (times[inds] - info.t1 + info.t1a).astype(np.int32)
        labels0 = labels[inds]

        clips0 = np.zeros((M, clip_size, len(inds)),
                          dtype=np.float32,
                          order='F')
        cpp.extract_clips(clips0, chunk, times0, clip_size)

        for k in range(1, K + 1):
            inds_kk = np.where(labels0 == k)[0]
            compute_templates._sums[:, :, k - 1] += np.sum(clips0[:, :, inds_kk], axis=2)
            compute_templates._counts[k - 1] += len(inds_kk)
        return True

    TCR = TimeseriesChunkReader(chunk_size_mb=40, overlap_size=clip_size * 2)
    if not TCR.run(timeseries, _kernel):
        return None
    templates = np.zeros((M, T, K))
    for k in range(1, K + 1):
        if compute_templates._counts[k - 1]:
            templates[:, :, k - 1] = compute_templates._sums[:, :, k - 1] / compute_templates._counts[k - 1]
    return templates
def loadSpikeTimestamps(data_file=None):
    """
    Load spike timestamps, extracted from the rec file with exportmda
    returns: Timestamps for spike data
    """

    if data_file is None:
        ts_file = QtHelperUtils.get_open_file_name(\
                message="Select Timestamp MDA Files", file_format="LFP Data (.dat)")
    else:
        ts_file = data_file

    # Now that we have the timestamp file, we have to read it
    try:
        timestamps = mdaio.readmda(ts_file)
    except (FileNotFoundError, IOError) as err:
        print(err)
        print('Unable to read TIMESTAMPS file.')
        timestamps = None
    finally:
        return timestamps
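
# Usage sketch (hypothetical path): pass the timestamp file directly to skip
# the file-selection dialog:
#
# timestamps = loadSpikeTimestamps('raw.timestamps.mda')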
Example #18
def load_firings(firings_path,
                 samples_offset,
                 ep_samp_start_offset,
                 animal,
                 date,
                 ntrode,
                 config_path,
                 fs=30000):
    # load firings
    firings = mdaio.readmda(firings_path)
    #create spikes_df with firing times with inter-epoch gap
    spike_cols = ['animal', 'day', 'epoch', 'ntrode', 'cluster', \
                  'timedelta', 'sampleindex']
    spikes_df = pd.DataFrame(columns=spike_cols)
    spikes_sampleindex = samples_offset[firings[1, :].astype(int)]
    spikes_timedelta = pd.TimedeltaIndex(spikes_sampleindex / fs,
                                         unit='s',
                                         name='time')
    spikes_df['sampleindex'] = spikes_sampleindex
    spikes_df['timedelta'] = spikes_timedelta
    spikes_df['cluster'] = firings[2, :].astype(int)
    spikes_df['primary_chan'] = firings[0, :].astype(int)
    #add animal, day, epoch cols:
    spikes_df['animal'] = animal
    spikes_df['day'] = core.convert_dates_to_days(animal, date, config_path)
    spikes_df['ntrode'] = ntrode
    spikes_df['epoch'] = 0
    ep_samp_start_timedelta = pd.TimedeltaIndex([ep_s/fs  for ep_s in \
                                                 ep_samp_start_offset], \
                                                unit='s', name='time')
    for epn, ep_samp_td in enumerate(ep_samp_start_timedelta):
        spikes_df.loc[spikes_df['timedelta'] >= \
                                    ep_samp_td, 'epoch'] = epn+1
    spikes_df.set_index(['animal', 'day', 'epoch', 'timedelta', \
                                       'ntrode', 'cluster'], inplace=True)
    return spikes_df
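
# Usage sketch tying load_firings together with load_marks/load_clips above
# (hypothetical paths; samples_offset and ep_samp_start_offset come from the
# recording's timestamp bookkeeping, outside this snippet):
def _example_load_spike_tables(samples_offset, ep_samp_start_offset):
    spikes_df = load_firings('firings.mda', samples_offset,
                             ep_samp_start_offset, 'rat1', '20180101', 4,
                             'config.json', fs=30000)
    marks_df = load_marks('marks.mda', spikes_df)
    clips_df = load_clips('clips.mda', spikes_df)
    return spikes_df, marks_df, clips_df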
Example #19
def ironclust(*,
    recording, # Recording object
    tmpdir, # Temporary working directory
    detect_sign=-1, # Polarity of the spikes, -1, 0, or 1
    adjacency_radius=-1, # Channel neighborhood adjacency radius corresponding to geom file
    detect_threshold=5, # Threshold for detection
    merge_thresh=.98, # Cluster merging threshold 0..1
    freq_min=300, # Lower frequency limit for band-pass filter
    freq_max=6000, # Upper frequency limit for band-pass filter
    pc_per_chan=3, # Number of pc per channel
    prm_template_name, # Name of the template file
    ironclust_src=None
):      
    if ironclust_src is None:
        ironclust_src=os.getenv('IRONCLUST_SRC',None)
    if not ironclust_src:
        raise Exception('You must either set the IRONCLUST_SRC environment variable, or pass the ironclust_src parameter')
    source_dir=os.path.dirname(os.path.realpath(__file__))

    dataset_dir=tmpdir+'/ironclust_dataset'
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    si.MdaRecordingExtractor.writeRecording(recording=recording,save_path=dataset_dir)
        
    samplerate=recording.getSamplingFrequency()

    print('Reading timeseries header...')
    HH=mdaio.readmda_header(dataset_dir+'/raw.mda')
    num_channels=HH.dims[0]
    num_timepoints=HH.dims[1]
    duration_minutes=num_timepoints/samplerate/60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(num_channels,num_timepoints,duration_minutes))

    print('Creating .params file...')
    txt=''
    txt+='samplerate={}\n'.format(samplerate)
    txt+='detect_sign={}\n'.format(detect_sign)
    txt+='adjacency_radius={}\n'.format(adjacency_radius)
    txt+='detect_threshold={}\n'.format(detect_threshold)
    txt+='merge_thresh={}\n'.format(merge_thresh)
    txt+='freq_min={}\n'.format(freq_min)
    txt+='freq_max={}\n'.format(freq_max)    
    txt+='pc_per_chan={}\n'.format(pc_per_chan)
    txt+='prm_template_name={}\n'.format(prm_template_name)
    _write_text_file(dataset_dir+'/argfile.txt',txt)
        
    print('Running IronClust...')
    cmd_path = "addpath('{}', '{}/matlab', '{}/mdaio');".format(ironclust_src, ironclust_src, ironclust_src)
    #"p_ironclust('$(tempdir)','$timeseries$','$geom$','$prm$','$firings_true$','$firings_out$','$(argfile)');"
    cmd_call = "p_ironclust('{}', '{}', '{}', '', '', '{}', '{}');"\
        .format(tmpdir, dataset_dir+'/raw.mda', dataset_dir+'/geom.csv', tmpdir+'/firings.mda', dataset_dir+'/argfile.txt')
    cmd='matlab -nosplash -nodisplay -r "{} {} quit;"'.format(cmd_path, cmd_call)
    print(cmd)
    retcode=_run_command_and_print_output(cmd)

    if retcode != 0:
        raise Exception('IronClust returned a non-zero exit code')

    # parse output
    result_fname=tmpdir+'/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: '+ result_fname)
    
    firings=mdaio.readmda(result_fname)
    sorting=si.NumpySortingExtractor()
    sorting.setTimesLabels(firings[1,:],firings[2,:])
    return sorting
Example #20
 def read_data(self, dataset_id, file_id):
     return readmda(self.directories[dataset_id][file_id])
Example #21
def export_mountain_cells():
    
    print('expecting call from channel level')
    init_path = os.getcwd()
    current_path = os.getcwd()
    path_split = str.split(current_path,'/')
    if path_split[-2] == 'mountains':
        channel_path = path_split[-1]
        os.chdir('output')
    else:
        os.chdir('../../../mountains/')
        os.chdir(path_split[-1])
        channel_path = path_split[-1]
        os.chdir('output')

    original = mdaio.readmda('firings.mda')
    os.chdir('..')
    start_time = {}
    with open('start_indices.csv') as csvfile:
        csvreader = csv.reader(csvfile,delimiter = ',')
        linecount = 0
        for line in csvreader:
            if linecount == 0:
                temp = []
                for term in line[1:]:
                    temp.append(int(term.strip(" '][")))
            else:
                temp = []
                for term in line[1:]:
                    temp.append(term.strip(" ']["))

            start_time[line[0]] = temp        
            linecount += 1
    
    current_path = os.getcwd()
    ch = current_path.partition('channel')[1] + current_path.partition('channel')[2]
    total_ind = len(original[1])
    marker = 0
    count = 0
    
    for i in range(total_ind):
        if count == len(start_time.get('2'))-1:
            exporting_arr = [original[1][marker:total_ind],original[2][marker:total_ind]]
            location = os.getcwd()
            temp = location.split('/')
            upper_folder1 = location[0:len(location)-len(temp[len(temp)-1])-1]
            os.chdir(upper_folder1)
            temp1 = upper_folder1.split('/')
            upper_folder2 = upper_folder1[0:len(upper_folder1)-len(temp1[len(temp1)-1])-1]
            os.chdir(upper_folder2)
            session_path = upper_folder2 + '/' + start_time.get('2')[count]
            os.chdir(session_path)
            full_list = []
            for name in os.listdir("."):
                if os.path.isdir(name):
                    full_list.append(name)
            for folder in full_list:
                array_path = session_path + '/' + folder
                os.chdir(array_path)
                if os.path.isdir(channel_path):
                    os.chdir(array_path)
                    split_into_cells_intra_session(ch, exporting_arr, start_time.get('1')[count])
                    os.chdir(location)
            break
        if original[1][i] >= start_time.get('1')[count+1]:          
            exporting_arr = [original[1][marker:i-1],original[2][marker:i-1]]
            location = os.getcwd()
            temp = location.split('/')
            upper_folder1 = location[0:len(location)-len(temp[len(temp)-1])-1]
            os.chdir(upper_folder1)
            temp1 = upper_folder1.split('/')
            upper_folder2 = upper_folder1[0:len(upper_folder1)-len(temp1[len(temp1)-1])-1]
            os.chdir(upper_folder2)
            session_path = upper_folder2 + '/' + start_time.get('2')[count]
            os.chdir(session_path)
            full_list = []
            for name in os.listdir("."):
                if os.path.isdir(name):
                    full_list.append(name)
            for folder in full_list:
                array_path = session_path + '/' + folder
                os.chdir(array_path)
                if os.path.isdir(channel_path):
                    os.chdir(array_path)
                    split_into_cells_intra_session(ch, exporting_arr, start_time.get('1')[count])
            marker = i
            count += 1
            os.chdir(location)
            
    os.chdir(init_path)
Example #22
def convert_array(*,
                  input,
                  output,
                  format='',
                  format_out='',
                  dimensions='',
                  dtype='',
                  dtype_out='',
                  channels=''):
    """
    Convert a multi-dimensional array between various formats ('.mda', '.npy', '.dat') based on the file extensions of the input/output files

    Parameters
    ----------
    input : INPUT
        Path of input array file (can be repeated for concatenation).
    output : OUTPUT
        Path of the output array file.
        
    format : string
        The format for the input array (mda, npy, dat), or determined from the file extension if empty
    format_out : string
        The format for the output array (mda, npy, dat), or determined from the file extension if empty
    dimensions : string
        Comma-separated list of dimensions (shape). If empty, it is auto-determined, if possible, by the input array. If second dim is -1 then it will be extrapolated from file size / first dim.
    dtype : string
        The data format for the input array. Choices: int8, int16, int32, uint16, uint32, float32, float64 (possibly float16 in the future).
    dtype_out : string
        The data format for the output array. If empty, the dtype for the input array is used.
    channels : string
        Comma-separated list of channels to keep in output. Zero-based indexing. Only works for .dat to .mda conversions.
    """
    if isinstance(input, (list, )):
        multifile = True
        inputs = input
        input = inputs[0]
    else:
        multifile = False

    format_in = format
    if not format_in:
        format_in = determine_file_format(file_extension(input), dimensions)
    if not format_out:
        format_out = determine_file_format(file_extension(output), dimensions)
    print('Input/output formats: {}/{}'.format(format_in, format_out))
    ext_in = file_extension(input)

    dims = None

    if (format_in == 'mda') and (dtype == ''):
        header = mdaio.readmda_header(input)
        dtype = header.dt
        dims = header.dims

    if (format_in == 'npy') and (dtype == ''):
        A = np.load(input, mmap_mode='r')
        dtype = npy_dtype_to_string(A.dtype)
        dims = A.shape
        A = 0

    if dimensions:
        dims2 = [int(entry) for entry in dimensions.split(',')]
        if dims:
            if len(dims) != len(dims2):
                raise Exception(
                    'Inconsistent number of dimensions for input array')
            if not np.all(np.array(dims) == np.array(dims2)):
                raise Exception('Inconsistent dimensions for input array')
        dims = dims2

    if not dtype_out:
        dtype_out = dtype

    if not dtype:
        raise Exception('Unable to determine datatype for input array')

    if not dtype_out:
        raise Exception('Unable to determine datatype for output array')

    if (dims[1] == -1) and (dims[0] > 0):
        if ((dtype) and (format_in == 'dat')):
            bits = int(
                dtype[-2:]
            )  # number of bits per entry of dtype, TODO: make this smarter
            if not multifile:
                filebytes = os.stat(input).st_size  # bytes in input file
            else:
                dims1 = np.copy(dims)
                filebytes1 = os.stat(input).st_size  # bytes in input file
                entries1 = int(filebytes1 / (int(bits / 8)))
                dims1[1] = int(entries1 / dims1[0])
                filebytes = sum([os.stat(inp).st_size for inp in inputs])
            entries = int(filebytes / (int(bits / 8)))  # entries in input file
            dims[1] = int(entries / dims[0])  # calculated second dimension
            if DEBUG:
                print(bits)
                print(filebytes)
                print(int(filebytes / (int(bits / 8))))
                print(dims)
        else:
            raise Exception('Could not infer dimensions')

    if not dims:
        raise Exception('Unable to determine dimensions for input array')

    if not channels:
        channels = range(0, dims[0])
    else:
        channels = np.array([int(entry) for entry in channels.split(',')])

    if DEBUG:
        print(channels)

    print('Using dtype={}, dtype_out={}, dimensions={}'.format(
        dtype, dtype_out, ','.join(str(item) for item in dims)))
    if (format_in == format_out) and ((dtype == dtype_out) or
                                      (dtype_out == '')):
        if multifile and (format_in == 'dat'):
            print('Concatenating Files')
            with open(output, "wb") as outfile:
                for input_file in inputs:
                    with open(input_file, "rb") as inpt:
                        outfile.write(inpt.read())
        elif not multifile:
            print('Simply copying file...')
            shutil.copyfile(input, output)
            print('Done.')
        return True

    if format_out == 'dat' and not multifile:
        if format_in == 'mda':
            H = mdaio.readmda_header(input)
            copy_raw_file_data(input,
                               output,
                               start_byte=H.header_size,
                               num_entries=np.product(dims),
                               dtype=dtype,
                               dtype_out=dtype_out)
            return True
        elif format_in == 'npy':
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = np.load(input, mmap_mode='r').astype(dtype=dtype_out,
                                                     order='F',
                                                     copy=False)
            A = A.ravel(order='F')
            A.tofile(output)
            # The following was problematic because of row-major ordering, i think
            #header_size=determine_npy_header_size(input)
            #copy_raw_file_data(input,output,start_byte=header_size,num_entries=np.product(dims),dtype=dtype,dtype_out=dtype_out)
            return True
        elif format_in == 'dat':
            raise Exception('This case not yet implemented.')
        else:
            raise Exception('Unexpected case.')

    elif (format_out == 'mda') or (format_out == 'npy'):
        if format_in == 'dat' and multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            print(channels)  #DEBUG
            A = np.fromfile(inputs[0], dtype=dtype, count=np.product(dims))
            A = A.reshape(tuple(dims1), order='F')
            A = A[channels, :]
            for inputn in inputs[1:]:
                An = np.fromfile(inputn, dtype=dtype, count=np.product(dims))
                dimsN = np.copy(dims1)
                dimsN[1] = An.size / dims1[0]
                An = An.reshape(tuple(dimsN), order='F')
                An = An[channels, :]
                print(A.shape)
                print(An.shape)
                A = np.concatenate((A, An), axis=1)
            print(A.shape)  #DEBUG
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        elif format_in == 'dat' and not multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = np.fromfile(input, dtype=dtype, count=np.product(dims))
            A = A.reshape(tuple(dims), order='F')
            A = A[channels, :]
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        elif format_in == 'mda' and not multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = mdaio.readmda(input)
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        elif format_in == 'mda' and multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = mdaio.readmda(inputs[0])
            A = A[channels, :]
            for inputn in inputs[1:]:
                An = mdaio.readmda(inputn)
                An = An[channels, :]
                A = np.concatenate((A, An), axis=0)
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        elif format_in == 'npy' and not multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = np.load(input, mmap_mode='r').astype(dtype=dtype,
                                                     order='F',
                                                     copy=False)
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        elif format_in == 'npy' and multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = np.load(inputs[0], mmap_mode='r').astype(dtype=dtype,
                                                         order='F',
                                                         copy=False)
            A = A[channels, :]
            for inputn in inputs[1:]:
                An = np.load(inputn, mmap_mode='r').astype(dtype=dtype,
                                                           order='F',
                                                           copy=False)
                An = An[channels, :]
                A = np.concatenate((A, An), axis=0)
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        else:
            raise Exception('Unexpected case.')
    else:
        raise Exception('Unexpected output format: {}'.format(format_out))

    raise Exception('Unexpected error.')
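
# Usage sketch for the channels option (hypothetical file names): keep only the
# first four channels while converting a raw int16 .dat file to .mda. Passing
# -1 as the second dimension lets it be inferred from the file size:
#
# convert_array(input='raw.dat',
#               output='raw_ch0-3.mda',
#               dtype='int16',
#               dimensions='32,-1',
#               channels='0,1,2,3')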
def separateSpikesInEpochs(data_dir=None,
                           firings_file='firings.curated.mda',
                           timestamp_files=None,
                           write_separated_spikes=True):
    """
    Takes curated spikes from MountainSort and combines this information with spike timestamps to create separate curated spikes for each epoch

    :firings_file: Curated firings file
    :timestamp_files: Spike timestamps file list
    :write_separated_spikes: If the separated spikes should be written back to the data directory.
    :returns: List of spikes for each epoch
    """

    if data_dir is None:
        # Get the firings file
        data_dir = QtHelperUtils.get_directory(
            message="Select Curated firings location")

    separated_tetrodes = []
    curated_firings = []
    merged_curated_firings = []
    tetrode_list = os.listdir(data_dir)
    for tt_dir in tetrode_list:
        try:
            if firings_file in os.listdir(data_dir + '/' + tt_dir):
                curated_firings.append([])
                separated_tetrodes.append(tt_dir)
                firings_file_location = '/'.join(
                    [data_dir, tt_dir, firings_file])
                merged_curated_firings.append(
                    mdaio.readmda(firings_file_location))
                print(MODULE_IDENTIFIER +
                      'Read merged firings file for tetrode %s!' % tt_dir)
            else:
                print(MODULE_IDENTIFIER +
                      'Merged firings %s not found for tetrode %s!' %
                      (firings_file, tt_dir))
        except (FileNotFoundError, IOError) as err:
            print(MODULE_IDENTIFIER +
                  'Unable to read merged firings file for tetrode %s!' %
                  tt_dir)
            print(err)

    gui_root = Tk()
    gui_root.withdraw()
    timestamp_headers = []
    if timestamp_files is None:
        # Read all the timestamp files
        timestamp_files = filedialog.askopenfilenames(initialdir=DEFAULT_SEARCH_PATH, \
                title="Select all timestamp files", \
                filetypes=(("Timestamps", ".mda"), ("All Files", "*.*")))
    gui_root.destroy()

    for ts_file in timestamp_files:
        timestamp_headers.append(mdaio.readmda_header(ts_file))

    # Now that we have both the timestamp headers and the timestamp files, we
    # can separate spikes out.  It is important here for the timestamp files to
    # be in the same order as the curated firings as that is the only way for
    # us to tell that the firings are being split up correctly.

    print(MODULE_IDENTIFIER + 'Looking at spike timestamps in order')
    print(timestamp_files)

    # First splice up curated spikes into indiviual epochs
    for tt_idx, tt_firings in enumerate(merged_curated_firings):
        for ep_idx, ts_header in enumerate(timestamp_headers):
            if tt_firings is None:
                curated_firings[tt_idx].append(None)
                continue

            n_data_points = ts_header.dims[0]
            print(MODULE_IDENTIFIER + 'Epoch ' + str(ep_idx) + ': ' +
                  str(n_data_points) + ' samples.')
            last_spike_from_epoch = np.searchsorted(
                tt_firings[1], n_data_points, side='left') - 1

            # If there are no spikes in this epoch, there might still be some in future epochs!
            if last_spike_from_epoch < 0:
                tt_firings[1] = tt_firings[1] - float(n_data_points)
                curated_firings[tt_idx].append(None)
                continue

            last_spike_sample_number = tt_firings[1][last_spike_from_epoch]
            print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] + ': First spike ' + str(tt_firings[1][0])\
                    + ', Last spike ' + str(last_spike_sample_number))
            epoch_spikes = tt_firings[:, :last_spike_from_epoch]
            curated_firings[tt_idx].append(epoch_spikes)

            if last_spike_from_epoch < (len(tt_firings[1]) - 1):
                # Slice the merged curated firing to only have the remaining spikes
                tt_firings = tt_firings[:, last_spike_from_epoch + 1:]
                print(MODULE_IDENTIFIER +
                      'Trimming curated spikes. Aggregate sample start ' +
                      str(tt_firings[1][0]))
                tt_firings[1] = tt_firings[1] - float(n_data_points)
                print(MODULE_IDENTIFIER + 'Sample number trimmed to ' +
                      str(tt_firings[1][0]))
            else:
                print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] +
                      ', Reached end of curated firings at Epoch ' +
                      str(ep_idx))
                tt_firings = None

    print(MODULE_IDENTIFIER +
          'Spikes separated in epochs. Substituting timestamps!')
    # For each epoch replace the sample numbers with the corresponding
    # timestamps. We are going through multiple revisions for this so that we
    # only have to load one timestamp file at a time
    for ep_idx, ts_file in enumerate(timestamp_files):
        epoch_timestamps = mdaio.readmda(ts_file)
        print(MODULE_IDENTIFIER + 'Epoch ' + str(ep_idx))
        for tt_idx, tt_curated_firings in enumerate(curated_firings):
            if tt_curated_firings[ep_idx] is None:
                continue
            # Need to use the original array because changes to tt_curated_firings do not get copied back
            curated_firings[tt_idx][ep_idx][1] = epoch_timestamps[np.array(
                tt_curated_firings[ep_idx][1], dtype=int)]
            print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] + ': Samples (' + \
                    str(tt_curated_firings[ep_idx][1][0]) + ', ' + str(tt_curated_firings[ep_idx][1][-1]), ')')

    if write_separated_spikes:
        try:
            for tt_idx, tet in enumerate(separated_tetrodes):
                for ep_idx in range(len(timestamp_files)):
                    if curated_firings[tt_idx][ep_idx] is not None:
                        ep_firings_file_name = data_dir + '/' + tet + '/firings-' + \
                                str(ep_idx+1) + '.curated.mda'
                        mdaio.writemda64(curated_firings[tt_idx][ep_idx],
                                         ep_firings_file_name)
        except OSError as exception:
            if exception.errno != errno.EEXIST:
                print(MODULE_IDENTIFIER +
                      'Unable to write timestamped firings!')
                print(exception)

    return curated_firings
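
# Usage sketch for separateSpikesInEpochs (hypothetical paths): split curated
# firings into per-epoch files using the per-epoch timestamp .mda files:
#
# per_epoch_firings = separateSpikesInEpochs(
#     data_dir='/data/rat1/preprocessed',
#     firings_file='firings.curated.mda',
#     timestamp_files=['epoch1.timestamps.mda', 'epoch2.timestamps.mda'],
#     write_separated_spikes=True)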
Example #24
def compare_ground_truth(*, firings_true, firings, json_out, max_dt=20):
    """
    compare a sorting (firings) with ground truth (firings_true)

    Parameters
    ----------
    firings_true : INPUT
        Path of true firings file (RxL) R>=3, L=#evts
    firings : INPUT
        Path of sorted firings file (RxL) R>=3, L=#evts
    json_out : OUTPUT
        Path of the output file containing the results of the comparison
        
    max_dt : int
        Tolerance for matching events (in timepoints)
        
    """

    print('Reading arrays...')
    F = mdaio.readmda(firings)
    Ft = mdaio.readmda(firings_true)
    print('Initializing data...')
    L = F.shape[1]
    Lt = Ft.shape[1]
    times = F[1, :]
    times_true = Ft[1, :]
    labels = F[2, :].astype(np.int32)
    labels_true = Ft[2, :].astype(np.int32)
    F = 0  # free memory? Alex: depends if python decides to call garbage collection
    Ft = 0  # free memory?
    N = np.maximum(np.max(times), np.max(times_true))
    K = np.max(labels)
    Kt = np.max(labels_true)

    # todo: subsample in first pass to get the best

    # First we split into segments
    print('Splitting into segments')
    segment_size = max_dt
    num_segments = int(np.ceil(N / segment_size))
    N = num_segments * segment_size
    segments = []
    # occupancy: K x num_segments
    occupancy, counts = create_occupancy_array(times,
                                               labels,
                                               segment_size,
                                               num_segments,
                                               K,
                                               spread=False,
                                               multiplicity=True)
    # occupancy_true: Kt x num_segments
    occupancy_true, counts_true = create_occupancy_array(times_true,
                                                         labels_true,
                                                         segment_size,
                                                         num_segments,
                                                         Kt,
                                                         spread=True,
                                                         multiplicity=False)

    # Note: we spread the occupancy_true but not the occupancy
    # Note: we count the occupancy with multiplicity but not occupancy_true

    print('Computing pairwise counts and accuracies...')
    pairwise_counts = occupancy_true @ occupancy.transpose()  # Kt x K
    pairwise_accuracies = np.zeros((Kt, K))
    for k1 in range(1, Kt + 1):
        for k2 in range(1, K + 1):  # jfm fixed +1 bug on 8/29/18
            numer = pairwise_counts[k1 - 1, k2 - 1]
            denom = counts_true[k1 - 1] + counts[k2 - 1] - numer
            if denom > 0:
                pairwise_accuracies[k1 - 1, k2 - 1] = numer / denom

    print('Preparing output...')
    ret = {"true_units": {}}
    for k1 in range(1, Kt + 1):
        k2_match = int(1 + np.argmax(pairwise_accuracies[k1 - 1, :].ravel()))
        # todo: compute accuracy more precisely here
        num_matches = int(pairwise_counts[k1 - 1, k2_match - 1])
        num_false_positives = int(counts[k2_match - 1] - num_matches)
        num_false_negatives = int(counts_true[k1 - 1] - num_matches)
        unit = {
            "best_match": k2_match,
            "accuracy": pairwise_accuracies[k1 - 1, k2_match - 1],
            "num_matches": num_matches,
            "num_false_positives": num_false_positives,
            "num_false_negatives": num_false_negatives
        }
        ret['true_units'][k1] = unit

    print('Writing output...')
    str = json.dumps(ret, indent=4)
    with open(json_out, 'w') as out:
        out.write(str)
    print('Done.')

    return True
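
# Usage sketch for compare_ground_truth (hypothetical paths): run the
# comparison, then read back the per-unit matches from the JSON it writes:
def _example_compare_ground_truth():
    import json
    compare_ground_truth(firings_true='firings_true.mda',
                         firings='firings.mda',
                         json_out='compare.json',
                         max_dt=20)
    with open('compare.json') as f:
        results = json.load(f)
    for k, unit in results['true_units'].items():
        print(k, unit['best_match'], unit['accuracy'])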
import sys

import numpy as np
from mountainlab_pytools import mdaio
from scipy.interpolate import interp1d

if len(sys.argv) < 2:
    print('Not enough input. I need the tetrode number #t (folder tet+#t is where to find the files firings.mda and raw_filt.mda).')
    exit()

tet= sys.argv[1]
fold='tet'+tet

fpar = open(fold + '/' + 'tet' + tet + '.par.' + tet)
nchan = int(fpar.readline().split()[1])
print('Number of channels:',nchan)

a = mdaio.readmda(fold+'/raw_filt.mda')
b = mdaio.readmda(fold+'/firings.mda')

nC, L = a.shape # number of channels, length of recording

res = b[1,:].astype(int)
clu = b[2,:].astype(int)

print(res[-1])

c = np.zeros([len(res),nchan,32])  # was 32
ce = np.zeros([len(res),nchan,33])

for ir, r in enumerate(res):
    if r > 15 and r < L-18:
        c[ir,:,:] = a[:,r-15:r+17] # was 17
Example #26
 def __init__(self, path):
     self._timeseries_path = path
     self._timeseries = mdaio.readmda(path)
def compute_cluster_metrics(*,timeseries='',firings,metrics_out,clip_size=100,samplerate=0,
        refrac_msec=1):
    """
    Compute cluster metrics for a spike sorting output

    Parameters
    ----------
    firings : INPUT
        Path of firings mda file (RxL) where R>=3 and L is the number of events. Second row are timestamps, third row are integer cluster labels.
    timeseries : INPUT
        Optional path of timeseries mda file (MxN) which could be raw or preprocessed   
        
    metrics_out : OUTPUT
        Path of output json file containing the metrics.
        
    clip_size : int
        (Optional) clip size, aka snippet size (used when computing the templates, or average waveforms)
    samplerate : float
        Optional sample rate in Hz

    refrac_msec : float
        (Optional) Define interval (in ms) as refractory period. If 0 then don't compute the refractory period metric.
    """    
    print('Reading firings...')
    F=mdaio.readmda(firings)

    print('Initializing...')
    R=F.shape[0]
    L=F.shape[1]
    assert(R>=3)
    times=F[1,:]
    labels=F[2,:].astype(int)
    K=np.max(labels)
    N=0
    if timeseries:
        X=mdaio.DiskReadMda(timeseries)
        N=X.N2()

    if (samplerate>0) and (N>0):
        duration_sec=N/samplerate
    else:
        duration_sec=0

    clusters=[]
    for k in range(1,K+1):
        inds_k=np.where(labels==k)[0]
        metrics_k={
            "num_events":len(inds_k)
        }
        if duration_sec:
            metrics_k['firing_rate']=len(inds_k)/duration_sec
        cluster={
            "label":k,
            "metrics":metrics_k
        }
        clusters.append(cluster)

    if timeseries:
        print('Computing templates...')
        templates=compute_templates_helper(timeseries=timeseries,firings=firings,clip_size=clip_size)
        for k in range(1,K+1):
            template_k=templates[:,:,k-1]
            # subtract mean on each channel (todo: vectorize this operation)
            # Alex: like this?
            # template_k[m,:] -= np.mean(template_k[m,:])
            for m in range(templates.shape[0]):
                template_k[m,:]=template_k[m,:]-np.mean(template_k[m,:])
            peak_amplitude=np.max(np.abs(template_k))
            clusters[k-1]['metrics']['peak_amplitude']=peak_amplitude
        ## todo: subtract template means, compute peak amplitudes

    if refrac_msec > 0:
        print('Computing Refractory Period Violations')
        msec_string = '{0:.5f}'.format(refrac_msec).rstrip('0').rstrip('.')
        rr_string = 'refractory_violation_{}msec'.format(msec_string)
        max_dt_msec=50
        bin_size_msec=refrac_msec
        auto_cors = compute_cross_correlograms_helper(firings=firings,mode='autocorrelograms',samplerate=samplerate,max_dt_msec=max_dt_msec,bin_size_msec=bin_size_msec)
        mid_ind  = np.floor(max_dt_msec/bin_size_msec).astype(int)
        mid_inds = np.arange(mid_ind-2,mid_ind+2)
        for k0,auto_cor_obj in enumerate(auto_cors['correlograms']):
            auto = auto_cor_obj['bin_counts']
            k    = auto_cor_obj['k']
            peak = np.percentile(auto, 75)
            mid  = np.mean(auto[mid_inds])
            rr   = safe_div(mid,peak)
            clusters[k-1]['metrics'][rr_string] = rr

    ret={
        "clusters":clusters
    }

    print('Writing output...')
    str=json.dumps(ret,indent=4)
    with open(metrics_out, 'w') as out:
        out.write(str)
    print('Done.')
    return True
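
# Usage sketch for compute_cluster_metrics (hypothetical paths): compute
# per-cluster event counts, firing rates, peak template amplitudes and
# refractory-period violations in one pass:
#
# compute_cluster_metrics(timeseries='raw.mda',
#                         firings='firings.mda',
#                         metrics_out='cluster_metrics.json',
#                         clip_size=100,
#                         samplerate=30000,
#                         refrac_msec=1)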
Example #28
bandpass_filter.version = '0.1'

if __name__ == "__main__":

    samplerate = int(3e4)
    freq_min = 250
    freq_max = 6000
    data_dir = '../ephys_preprocessing/'

    raw_data_ch1 = np.asarray(
        sio.loadmat(os.path.join(data_dir, 'raw_data_ch1.mat'))['data'])
    mdaio.writemda32(raw_data_ch1, os.path.join(data_dir, 'raw_data_ch1.mda'))
    timeseries = os.path.join(data_dir, 'raw_data_ch1.mda')
    timeseries_out = os.path.join(data_dir, 'filtered_raw_data_ch1.mda')
    bandpass_filter(timeseries, timeseries_out, samplerate, freq_min, freq_max)
    filtered_data = mdaio.readmda(
        os.path.join(data_dir, 'filtered_raw_data_ch1.mda'))

    detrended_data_ch1 = np.asarray(
        sio.loadmat(os.path.join(data_dir, 'detrended_data_ch1.mat'))['copy'])
    mdaio.writemda32(detrended_data_ch1,
                     os.path.join(data_dir, 'detrended_data_ch1.mda'))
    timeseries = os.path.join(data_dir, 'detrended_data_ch1.mda')
    timeseries_out = os.path.join(data_dir, 'filtered_detrended_data_ch1.mda')
    bandpass_filter(timeseries, timeseries_out, samplerate, freq_min, freq_max)
    filtered_data_detrended = mdaio.readmda(
        os.path.join(data_dir, 'filtered_detrended_data_ch1.mda'))

    plt.plot(raw_data_ch1[0, :], label='raw')
    plt.plot(filtered_data_detrended[0, :], label='filtered_detrended')

    # plt.plot(detrended_data_ch1[0, :], label='raw_detrended')
Example #29
def read_mda_timestamps(file):
    return readmda(file)
def loadClusteredData(data_location=None,
                      firings_file='firings.curated.mda',
                      helper_file='hand_curated.mv2',
                      time_limits=None):
    """
    Load up clustered data and pool it.

    :data_location: Directory which has clustered data from all the tetrodes.
    :time_limits: (Floating Point) Real time limits within which the data should be extracted.
    :returns: Spike data, separated into containers for individual units.
    """

    if data_location is None:
        # Get the location using a file dialog
        data_location = QtHelperUtils.get_directory("Select data location.")

    clustered_spikes = []
    tt_cl_to_unique_cluster_id = {}
    unique_cluster_id = 0
    tetrode_list = os.listdir(data_location)
    for tt_dir in tetrode_list:
        tt_dir_path = os.path.join(data_location, tt_dir)
        if not os.path.isdir(tt_dir_path):
            continue

        if firings_file in os.listdir(tt_dir_path):
            tt_idx = tt_dir.split('nt')[1]
            firings_file_path = os.path.join(tt_dir_path, firings_file)
            curation_file_path = os.path.join(tt_dir_path, helper_file)
            try:
                # Read the firings file
                firing_data = mdaio.readmda(firings_file_path)
            except Exception as err:
                print('Tetrode ' + tt_dir + 'Unable to read firings file!')
                print(err)
                continue

            try:
                # Read the curation file for info on spike clusters
                with open(curation_file_path, 'r') as f:
                    curation_file = json.load(f)
                # Get all cluster IDs. This includes noise, mua, everything!
                cluster_ids = [
                    int(cl)
                    for cl in curation_file['cluster_attributes'].keys()
                ]
            except (FileNotFoundError, IOError) as err:
                print('Tetrode ' + tt_dir + 'Unable to read curation file.')
                continue

            # Read off spikes for individual clusters and assign unique cluster IDs to them
            if time_limits is not None:
                firing_times = (firing_data[1] -
                                firing_data[1][0]) / SPIKE_SAMPLING_RATE
                time_limit_start_idx = np.searchsorted(firing_times,
                                                       time_limits[0],
                                                       side='left')
                time_limit_finish_idx = np.searchsorted(firing_times,
                                                        time_limits[1],
                                                        side='right')
                firing_data = firing_data[:, time_limit_start_idx:
                                          time_limit_finish_idx]
            firing_clusters = firing_data[2]

            n_clusters = 0
            for unit_id in cluster_ids:
                unit_spikes = firing_data[1][firing_clusters == unit_id]
                if len(unit_spikes) > 0:
                    tt_cl_to_unique_cluster_id[(tt_idx,
                                                unit_id)] = unique_cluster_id
                    n_clusters += 1
                    unique_cluster_id += 1
                    clustered_spikes.append(unit_spikes)
                else:
                    clustered_spikes.append(None)

            n_spikes = len(firing_data[1])
            print('Tetrode %s loaded %d spikes from %d clusters.' %
                  (tt_dir, n_spikes, n_clusters))
        else:
            print('Tetrode ' + tt_dir + ': Firings file not found!')

    return clustered_spikes
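
# Usage sketch for loadClusteredData (hypothetical path): pool curated units
# from every tetrode directory and count spikes per unit:
#
# units = loadClusteredData(data_location='/data/rat1/clustered',
#                           firings_file='firings.curated.mda',
#                           helper_file='hand_curated.mv2')
# for unit_id, spike_times in enumerate(units):
#     if spike_times is not None:
#         print(unit_id, len(spike_times))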