예제 #1
0
def convert_firings(*, firings, firings_out):
    """
    Export firings (either .mda or .npy) to .res & .clu format.

    The .res file holds one spike time per line. The .clu file holds the
    number of clusters on its first line, followed by one cluster label per
    spike (in the same order as the .res file).

    Parameters
    ----------
    firings : INPUT
        Path of firings mda (or npy) file (RxL) where R>=3 and L is the number of events. Second row are timestamps, third row are integer labels.
        An already-loaded RxL numpy array is also accepted.

    firings_out : OUTPUT
        Path of output filebase. Actual output files will be $(firings_out).res & $(firings_out).clu.

    Returns
    -------
    bool
        True on success.
    """
    if isinstance(firings, np.ndarray):
        F = firings
    else:
        F = mdaio.readmda(firings)
    times = F[1, :].ravel()  # Spike times
    labels = F[2, :].ravel().astype(int)  # Cluster IDs
    # Number of clusters; an empty firings array has no clusters.
    K = int(np.max(labels)) if labels.size else 0

    res_fname = firings_out + '.res'
    clu_fname = firings_out + '.clu'
    np.savetxt(res_fname, times, fmt='%d')
    # .clu layout: cluster count first, then one label per spike.
    np.savetxt(clu_fname, np.concatenate((np.array([K]), labels)), fmt='%d')

    return True
def compute_cross_correlograms_helper(*,
                                      firings,
                                      mode='autocorrelograms',
                                      samplerate=30000,
                                      max_dt_msec=50,
                                      bin_size_msec=2):
    """
    Compute one autocorrelogram per cluster label 1..K (K = largest label).

    Parameters
    ----------
    firings : str or array
        Path of a firings .mda file, or the RxL array itself (R >= 3). Row 1
        holds event times (timepoints), row 2 integer labels.
    mode : str
        Only 'autocorrelograms' is supported (checked with an assert).
    samplerate : float
        Sample rate in Hz, used to convert between msec and timepoints.
    max_dt_msec, bin_size_msec : float
        Window half-width and bin size in msec. They are converted to
        timepoints before calling compute_autocorrelogram (defined elsewhere).

    Returns
    -------
    dict
        {"correlograms": [{"k", "bin_edges", "bin_counts"}, ...]}. bin_edges
        is divided by samplerate and multiplied by 1000. That converts it to
        msec, presuming compute_autocorrelogram returns edges in timepoints.
    """
    F = mdaio.readmda(firings) if type(firings) == str else firings
    num_rows, _num_events = np.shape(F)
    assert (num_rows >= 3)
    assert (mode == 'autocorrelograms')

    max_dt_tp = max_dt_msec / 1000 * samplerate
    bin_size_tp = bin_size_msec / 1000 * samplerate
    event_times = F[1, :]
    event_labels = F[2, :].astype(int)
    num_clusters = np.max(event_labels)

    def _autocorrelogram_for(label):
        # Select this cluster's events and histogram their pairwise lags.
        members = np.where(event_labels == label)[0]
        counts, edges = compute_autocorrelogram(
            event_times[members], max_dt_tp=max_dt_tp, bin_size_tp=bin_size_tp)
        return {
            "k": label,
            "bin_edges": edges / samplerate * 1000,
            "bin_counts": counts
        }

    return {
        "correlograms":
        [_autocorrelogram_for(label) for label in range(1, num_clusters + 1)]
    }
예제 #3
0
def compute_templates_helper(*, timeseries, firings, clip_size=100):
    """
    Compute the average clip (template) of every cluster label 1..K.

    Parameters
    ----------
    timeseries : str or numpy.ndarray
        Path of an MxN timeseries .mda file (read lazily via
        mdaio.DiskReadMda), or an already-loaded MxN array.
    firings : str or numpy.ndarray
        Path of an RxL firings .mda file, or the array itself. Row 1 holds
        event times and row 2 integer cluster labels.
    clip_size : int
        Number of timepoints T in each clip.

    Returns
    -------
    numpy.ndarray
        float64 array of shape (M, T, K), where K is the largest label.
        Clusters with no usable events get an all-zero template instead of
        NaNs from a 0/0 division.
    """
    T = clip_size
    if isinstance(timeseries, np.ndarray):
        M, N = timeseries.shape

        def read_clip(start):
            return timeseries[:, start:start + T]
    else:
        X = mdaio.DiskReadMda(timeseries)
        M, N = X.N1(), X.N2()

        def read_clip(start):
            return X.readChunk(i1=0, N1=M, i2=start, N2=T)

    if isinstance(firings, np.ndarray):
        F = firings
    else:
        F = mdaio.readmda(firings)
    # Position of the event time within its clip.
    Tmid = int(np.floor((T + 1) / 2) - 1)
    times = F[1, :].ravel()
    labels = F[2, :].ravel().astype(int)
    K = max(int(np.max(labels)), 0) if labels.size else 0

    sums = np.zeros((M, T, K), dtype='float64')
    counts = np.zeros(K)

    for k in range(1, K + 1):
        # TODO: subsample
        for ind_k in np.where(labels == k)[0]:
            t0 = int(times[ind_k])
            # Conservative bounds check: skip events within clip_size of either end.
            if (clip_size <= t0) and (t0 < N - clip_size):
                sums[:, :, k - 1] += read_clip(t0 - Tmid)
                counts[k - 1] += 1

    templates = np.zeros((M, T, K))
    nonempty = counts > 0
    templates[:, :, nonempty] = sums[:, :, nonempty] / counts[nonempty]
    return templates
예제 #4
0
def test_convert_array(dtype='int32', shape=(12, 3, 7)):
    """
    Round-trip a random array through every .npy/.mda/.dat conversion path
    of convert_array and check that the final array equals the original.

    All intermediate files go into a temporary directory that is removed
    afterwards, so nothing is left in the working directory.

    Parameters
    ----------
    dtype : string
        Data type of the test array. It is also passed explicitly for the
        headerless .dat inputs.
    shape : sequence of int
        Shape of the test array.
    """
    import os
    import tempfile

    X = np.array(np.random.normal(0, 1, shape), dtype=dtype)
    # .dat files have no header, so their shape must be passed explicitly.
    dims = ','.join(str(entry) for entry in X.shape)
    with tempfile.TemporaryDirectory() as tmpdir:

        def path(name):
            return os.path.join(tmpdir, name)

        np.save(path('test_convert1.npy'), X)
        convert_array(input=path('test_convert1.npy'),
                      output=path('test_convert2.mda'))  # npy -> mda
        convert_array(input=path('test_convert2.mda'),
                      output=path('test_convert3.npy'))  # mda -> npy
        convert_array(input=path('test_convert3.npy'),
                      output=path('test_convert4.dat'))  # npy -> dat
        convert_array(input=path('test_convert4.dat'),
                      output=path('test_convert5.npy'),
                      dtype=dtype,
                      dimensions=dims)  # dat -> npy
        convert_array(input=path('test_convert5.npy'),
                      output=path('test_convert6.mda'))  # npy -> mda
        convert_array(input=path('test_convert6.mda'),
                      output=path('test_convert7.dat'))  # mda -> dat
        convert_array(input=path('test_convert7.dat'),
                      output=path('test_convert8.mda'),
                      dtype=dtype,
                      dimensions=dims)  # dat -> mda
        Y = mdaio.readmda(path('test_convert8.mda'))
        assert Y.shape == X.shape, 'Array shape changed during conversions'
        print(np.max(np.abs(X - Y)), Y.dtype)
        assert np.array_equal(X, Y), 'Array values changed during conversions'
예제 #5
0
def test_compute_templates():
    """
    Smoke-test compute_templates on random data.

    The written templates array must have shape (M, T, K). Intermediate
    files live in a temporary directory that is removed afterwards.
    """
    import os
    import tempfile

    M, N, K, T, L = 5, 1000, 6, 50, 100
    X = np.random.rand(M, N)
    F = np.zeros((3, L))
    F[1, :] = 1 + np.random.randint(N, size=L)
    F[2, :] = 1 + np.random.randint(K, size=L)
    # Make sure every label 1..K occurs, so the largest label (and hence
    # the number of templates) is exactly K.
    F[2, :K] = np.arange(1, K + 1)
    with tempfile.TemporaryDirectory() as tmpdir:
        timeseries_path = os.path.join(tmpdir, 'tmp.mda')
        firings_path = os.path.join(tmpdir, 'tmp2.mda')
        templates_path = os.path.join(tmpdir, 'tmp3.mda')
        mdaio.writemda32(X, timeseries_path)
        mdaio.writemda64(F, firings_path)
        ret = compute_templates(timeseries=timeseries_path,
                                firings=firings_path,
                                templates_out=templates_path,
                                clip_size=T)
        assert ret
        templates0 = mdaio.readmda(templates_path)
        assert templates0.shape == (M, T, K)
    return True
예제 #6
0
def copy_mda(input, output):
    """
    Copy .mda file from input to output, keeping the input's data type.

    Parameters
    ----------
    input : INPUT
        Path of mda file to read.
    output : OUTPUT
        Path of mda file to write.

    Returns
    -------
    bool
        True on success.
    """
    X = mdaio.readmda(input)
    # Write with the dtype recorded in the input header so the copy is
    # lossless. A fixed 'int16' would truncate or overflow float and wide
    # integer data.
    dtype = mdaio.readmda_header(input).dt
    mdaio.writemda(X, output, dtype=dtype)
    return True
예제 #7
0
def synthesize_timeseries(*,
                          firings='',
                          waveforms='',
                          timeseries_out,
                          noise_level=1,
                          samplerate=30000,
                          duration=60,
                          waveform_upsamplefac,
                          amplitudes_row=0):
    """
    Synthesize an electrophysiology timeseries from a set of ground-truth firing events and waveforms.

    Parameters
    ----------
    firings : INPUT
        (Optional) The path of firing events file in .mda format, or the array itself. RxL where R>=3 and L is the number of events. Second row is the timestamps, third row is the integer labels (1-based).
    waveforms : INPUT
        (Optional) The path of (possibly upsampled) waveforms file in .mda format, or the array itself. Mx(T*waveform_upsample_factor)xK, where M is the number of channels, T is the clip size, and K is the number of units. Label k uses waveform k-1 along the last axis.

    timeseries_out : OUTPUT
        The output path for the new timeseries. MxN. If empty, the array is returned instead of written.

    noise_level : double
        (Optional) Standard deviation of the simulated background noise added to the timeseries
    samplerate : double
        (Optional) Sample rate for the synthetic dataset in Hz
    duration : double
        (Optional) Duration of the synthetic dataset in seconds. The number of timepoints will be duration*samplerate. If that is 0, the length is the last event time plus one clip.
    waveform_upsamplefac : int
        (Optional) The upsampling factor corresponding to the input waveforms. (avoids digitization artifacts)
    amplitudes_row : int
        (Optional) If positive, this is the (1-based) row in the firings arrays where the amplitude scale factors are found. Otherwise, use all 1's
    """
    num_timepoints = np.int64(samplerate * duration)
    waveform_upsamplefac = int(waveform_upsamplefac)

    if isinstance(waveforms, str):
        if waveforms:
            W = mdaio.readmda(waveforms)
        else:
            W = np.zeros((4, 100 * waveform_upsamplefac, 0))
    else:
        W = waveforms

    if isinstance(firings, str):
        if firings:
            F = mdaio.readmda(firings)
        else:
            F = np.zeros((3, 0))
    else:
        F = firings

    times = F[1, :]
    labels = F[2, :].astype('int')

    M, TT = W.shape[0], W.shape[1]
    T = int(TT / waveform_upsamplefac)
    Tmid = int(np.ceil((T + 1) / 2 - 1))

    N = num_timepoints
    if N == 0:
        if times.size == 0:
            N = T
        else:
            # Room for the last event plus one full clip. It must be an
            # integer because np.random.randn rejects float sizes.
            N = int(np.ceil(np.max(times))) + T

    X = np.random.randn(M, N) * noise_level

    for j in range(times.size):
        t0 = times[j]
        amp0 = F[amplitudes_row - 1, j] if amplitudes_row > 0 else 1
        # Labels are 1-based, waveform slots along axis 2 are 0-based.
        waveform0 = W[:, :, labels[j] - 1]
        # Choose the upsampled phase matching the fractional event time.
        frac_offset = int(np.floor((t0 - np.floor(t0)) * waveform_upsamplefac))
        tstart = np.int64(np.floor(t0)) - Tmid
        if (0 <= tstart) and (tstart + T <= N):
            clip = waveform0[:, frac_offset::waveform_upsamplefac][:, :T]
            X[:, tstart:tstart + T] += clip * amp0

    if timeseries_out:
        return mdaio.writemda32(X, timeseries_out)
    else:
        return X
예제 #8
0
def _write_mda_or_npy(A, output, format_out, dtype_out):
    """Write A to output as .mda when format_out == 'mda', otherwise as .npy, using dtype_out."""
    if format_out == 'mda':
        mdaio.writemda(A, output, dtype=dtype_out)
    else:
        mdaio.writenpy(A, output, dtype=dtype_out)


def _read_dat_channels(path, dtype, dims, channels):
    """Read a headerless .dat file as a Fortran-order array of shape dims and keep the rows in channels."""
    A = np.fromfile(path, dtype=dtype, count=np.prod(dims))
    A = A.reshape(tuple(dims), order='F')
    return A[channels, :]


def _load_npy_as(path, dtype):
    """Memory-map a .npy file and return it cast to dtype in Fortran order."""
    return np.load(path, mmap_mode='r').astype(dtype=dtype,
                                                order='F',
                                                copy=False)


def convert_array(*,
                  input,
                  output,
                  format='',
                  format_out='',
                  dimensions='',
                  dtype='',
                  dtype_out='',
                  channels=''):
    """
    Convert a multi-dimensional array between various formats ('.mda', '.npy', '.dat') based on the file extensions of the input/output files

    Parameters
    ----------
    input : INPUT
        Path of input array file (can be repeated for concatenation; a list or tuple of paths).
    output : OUTPUT
        Path of the output array file.

    format : string
        The format for the input array (mda, npy, dat), or determined from the file extension if empty
    format_out : string
        The format for the output array (mda, npy, dat), or determined from the file extension if empty
    dimensions : string
        Comma-separated list of dimensions (shape). If empty, it is auto-determined, if possible, by the input array. If second dim is -1 then it will be extrapolated from file size / first dim.
    dtype : string
        The data format for the input array. Choices: int8, int16, int32, uint16, uint32, float32, float64 (possibly float16 in the future).
    dtype_out : string
        The data format for the output array. If empty, the dtype for the input array is used.
    channels : string
        Comma-separated list of channels to keep in output. Zero-based indexing. Applied to .dat inputs and to every file of a multi-file input.

    Returns
    -------
    bool
        True on success. Unsupported combinations raise Exception.
    """
    if isinstance(input, (list, tuple)):
        multifile = True
        inputs = list(input)
        input = inputs[0]
    else:
        multifile = False
        inputs = [input]

    format_in = format
    if not format_in:
        format_in = determine_file_format(file_extension(input), dimensions)
    if not format_out:
        format_out = determine_file_format(file_extension(output), dimensions)
    print('Input/output formats: {}/{}'.format(format_in, format_out))

    dims = None

    if (format_in == 'mda') and (dtype == ''):
        header = mdaio.readmda_header(input)
        dtype = header.dt
        dims = header.dims

    if (format_in == 'npy') and (dtype == ''):
        A = np.load(input, mmap_mode='r')
        dtype = npy_dtype_to_string(A.dtype)
        dims = A.shape
        del A  # release the memory map

    if dimensions:
        dims2 = [int(entry) for entry in dimensions.split(',')]
        if dims:
            if len(dims) != len(dims2):
                raise Exception(
                    'Inconsistent number of dimensions for input array')
            if not np.all(np.array(dims) == np.array(dims2)):
                raise Exception('Inconsistent dimensions for input array')
        dims = dims2

    if not dtype_out:
        dtype_out = dtype

    # If dtype is set, dtype_out is set too (it defaults to dtype above).
    if not dtype:
        raise Exception('Unable to determine datatype for input array')

    # Check this before indexing into dims; dims is None when nothing could be inferred.
    if not dims:
        raise Exception('Unable to determine dimensions for input array')
    dims = list(dims)  # mutable copy (the header/npy shape may be a tuple)

    if (len(dims) > 1) and (dims[1] == -1) and (dims[0] > 0):
        if (format_in == 'dat') and not multifile:
            itemsize = np.dtype(dtype).itemsize  # bytes per entry
            filebytes = os.stat(input).st_size  # bytes in input file
            entries = filebytes // itemsize  # entries in input file
            dims[1] = entries // dims[0]  # calculated second dimension
            if DEBUG:
                print(itemsize)
                print(filebytes)
                print(entries)
                print(dims)
        else:
            raise Exception('Could not infer dimensions')

    if not channels:
        channels = range(0, dims[0])
    else:
        channels = np.array([int(entry) for entry in channels.split(',')])

    if DEBUG:
        print(channels)

    print('Using dtype={}, dtype_out={}, dimensions={}'.format(
        dtype, dtype_out, ','.join(str(item) for item in dims)))

    if (format_in == format_out) and ((dtype == dtype_out) or
                                      (dtype_out == '')):
        if multifile and (format_in == 'dat'):
            # Raw .dat files can be byte-concatenated directly.
            print('Concatenating Files')
            with open(output, "wb") as outfile:
                for input_file in inputs:
                    with open(input_file, "rb") as inpt:
                        shutil.copyfileobj(inpt, outfile)
            return True
        if not multifile:
            print('Simply copying file...')
            shutil.copyfile(input, output)
            print('Done.')
            return True
        # Several .mda/.npy inputs cannot be byte-concatenated. Fall through
        # to the load-and-concatenate paths below.

    memory_warning = 'Warning: loading entire array into memory. This should be avoided in the future.'

    if format_out == 'dat' and not multifile:
        if format_in == 'mda':
            H = mdaio.readmda_header(input)
            copy_raw_file_data(input,
                               output,
                               start_byte=H.header_size,
                               num_entries=np.prod(dims),
                               dtype=dtype,
                               dtype_out=dtype_out)
            return True
        if format_in == 'npy':
            print(memory_warning)
            # Write column-major to match the Fortran-order reads of .dat files.
            A = _load_npy_as(input, dtype_out)
            A.ravel(order='F').tofile(output)
            return True
        if format_in == 'dat':
            raise Exception('This case not yet implemented.')
        raise Exception('Unexpected case.')

    if format_out in ('mda', 'npy'):
        if format_in not in ('dat', 'mda', 'npy'):
            raise Exception('Unexpected case.')
        print(memory_warning)
        if format_in == 'dat':
            parts = [
                _read_dat_channels(path, dtype, dims, channels)
                for path in inputs
            ]
        elif format_in == 'mda':
            parts = [mdaio.readmda(path) for path in inputs]
        else:
            parts = [_load_npy_as(path, dtype) for path in inputs]
        if multifile and format_in != 'dat':
            # Channel selection applies to every file when concatenating.
            parts = [part[channels, :] for part in parts]
        A = parts[0] if len(parts) == 1 else np.concatenate(parts, axis=0)
        _write_mda_or_npy(A, output, format_out, dtype_out)
        return True

    raise Exception('Unexpected output format: {}'.format(format_out))
예제 #9
0
def convert_clips_fet_spk(*,
                          timeseries,
                          firings,
                          waveforms_out,
                          ntro_nchannels,
                          clip_size=32,
                          nPC=4,
                          DEBUG=True):
    """
    Export per-n-trode spike clips (.spk.n) and PCA features (.fet.n).

    For every n-trode n (numbered from 1) this writes two files:
      * $(waveforms_out).spk.n : int16 clips (data scaled by 100) of shape
        (channels of the n-trode, clip_size, events on the n-trode), written
        with ndarray.tofile.
      * $(waveforms_out).fet.n : one text row per event. Each row holds nPC
        principal components per channel (scaled by 1000 and truncated to
        integers), followed by the event time.

    An event belongs to the n-trode that contains its channel (firings row 0).

    Parameters
    ----------
    timeseries : INPUT
        Path of timeseries mda file (MxN) from which to draw the event clips (snippets). M is number of channels, N is number of timepoints.
    firings : INPUT
        Path of firings mda file (RxL) where R>=3 and L is the number of events. First row is the channel (assumed 1-based), second row are timestamps, third row are integer labels.

    ntro_nchannels : INPUT
        Comma-separated number of channels of each n-trode. Channels are assigned consecutively, e.g. "4,4" gives channels 1-4 and 5-8.

    waveforms_out : OUTPUT
        Base path of the .spk.n and .fet.n output files.

    clip_size : int
        (Optional) clip size, aka snippet size, number of timepoints in a single clip. Events closer than clip_size to either end of the recording keep an all-zero clip.

    nPC : int
        (Optional) Number of principal components *per channel* for .fet files. N-trodes with fewer than nPC events get all-zero features.

    DEBUG : bool
        (Optional) Verbose output for debugging purposes.
    """
    X = mdaio.DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()
    F = mdaio.readmda(firings)
    T = clip_size
    Tmid = int(np.floor((T + 1) / 2) - 1)
    event_chans = F[0, :].ravel()
    times = F[1, :].ravel()

    # tetmap[n-1] holds the 1-based channel numbers of n-trode n.
    tetmap = []
    chan_to_tet = {}
    next_chan = 1
    for tro, nch in enumerate(ntro_nchannels.split(','), start=1):
        chans = np.arange(next_chan, next_chan + int(nch))
        tetmap.append(chans)
        for ch in chans:
            chan_to_tet[int(ch)] = tro
        next_chan += int(nch)

    try:
        which_tet = np.array([chan_to_tet[int(ch)] for ch in event_chans],
                             dtype=int)
    except KeyError as err:
        raise ValueError(
            'Event on channel {} is not covered by ntro_nchannels'.format(
                err.args[0])) from err

    print("Starting:")
    for tro, chans in enumerate(tetmap, start=1):
        inds_k = np.where(which_tet == tro)[0]
        num_events = len(inds_k)

        if DEBUG:
            print("Tetrode: " + str(tro))
            print("Chans: " + str(chans))
            print("Create Waveforms Array: " + str(len(chans)) + "," + str(T) +
                  "," + str(num_events))

        waveforms = np.zeros((len(chans), T, num_events), dtype='int16')
        for i, ind_k in enumerate(inds_k):  # for each spike
            t0 = int(times[ind_k])
            if (clip_size <= t0) and (t0 < N - clip_size):
                clip0 = X.readChunk(i1=0, N1=M, i2=t0 - Tmid, N2=T)
                # chans are 1-based channel numbers; clip rows are 0-based.
                clip0 = clip0[chans - 1, :] * 100
                waveforms[:, :, i] = clip0.astype(dtype='int16')

        fname = waveforms_out + '.spk.' + str(tro)
        if DEBUG:
            print("Writing Waveforms to File: " + fname)
        waveforms.tofile(fname, format='')

        if DEBUG:
            print("Calculating Feature Array")
        num_features = len(chans) * nPC
        fet = np.zeros((num_events, num_features + 1))
        if num_events >= nPC:
            for c in range(len(chans)):
                pca = decomposition.PCA(n_components=nPC)
                x_std = StandardScaler().fit_transform(
                    np.transpose(waveforms[c, :, :]).astype(dtype='float64'))
                fpos = c * nPC
                fet[:, fpos:fpos + nPC] = pca.fit_transform(x_std)
        else:
            # PCA needs at least nPC samples.
            print("Warning: fewer events than nPC on n-trode " + str(tro) +
                  "; features left at zero")

        # Scale only the PC features to keep precision once written as
        # integers. Event times stay in timepoints.
        fet[:, :num_features] *= 1000
        fet[:, num_features] = times[inds_k]
        fet = fet.astype(dtype='int64')
        fname = waveforms_out + '.fet.' + str(tro)
        if DEBUG:
            print("Writing Features to File: " + fname)
        np.savetxt(fname, fet, fmt='%d')

    return True
예제 #10
0
def compare_ground_truth(*, firings_true, firings, json_out, max_dt=20):
    """
    Compare a sorting (firings) with ground truth (firings_true).

    Events are binned into segments of max_dt timepoints via
    create_occupancy_array. For every pair (true unit k1, sorted unit k2) the
    number of shared segments and an accuracy
    matches / (n_true + n_sorted - matches) are computed. Each true unit is
    then reported with the sorted unit of highest accuracy.

    Parameters
    ----------
    firings_true : INPUT
        Path of true firings file (RxL) R>=3, L=#evts
    firings : INPUT
        Path of sorted firings file (RxL) R>=3, L=#evts
    json_out : OUTPUT
        Path of the output file containing the results of the comparison

    max_dt : int
        Tolerance for matching events (in timepoints)

    Returns
    -------
    bool
        True on success.
    """

    print('Reading arrays...')
    F = mdaio.readmda(firings)
    Ft = mdaio.readmda(firings_true)
    print('Initializing data...')
    times = F[1, :]
    times_true = Ft[1, :]
    labels = F[2, :].astype(np.int32)
    labels_true = Ft[2, :].astype(np.int32)
    # Drop the names; times/times_true are views, so the buffers stay alive.
    del F, Ft
    N = np.maximum(np.max(times), np.max(times_true))
    K = np.max(labels)
    Kt = np.max(labels_true)

    # TODO: subsample in first pass to get the best

    print('Splitting into segments')
    segment_size = max_dt
    num_segments = int(np.ceil(N / segment_size))
    # occupancy: K x num_segments
    occupancy, counts = create_occupancy_array(times,
                                               labels,
                                               segment_size,
                                               num_segments,
                                               K,
                                               spread=False,
                                               multiplicity=True)
    # occupancy_true: Kt x num_segments
    occupancy_true, counts_true = create_occupancy_array(times_true,
                                                         labels_true,
                                                         segment_size,
                                                         num_segments,
                                                         Kt,
                                                         spread=True,
                                                         multiplicity=False)

    # Note: we spread the occupancy_true but not the occupancy
    # Note: we count the occupancy with multiplicity but not occupancy_true

    print('Computing pairwise counts and accuracies...')
    pairwise_counts = occupancy_true @ occupancy.transpose()  # Kt x K
    pairwise_accuracies = np.zeros((Kt, K))
    for k1 in range(1, Kt + 1):
        for k2 in range(1, K + 1):  # include the last sorted unit, K
            numer = pairwise_counts[k1 - 1, k2 - 1]
            denom = counts_true[k1 - 1] + counts[k2 - 1] - numer
            if denom > 0:
                pairwise_accuracies[k1 - 1, k2 - 1] = numer / denom

    print('Preparing output...')
    ret = {"true_units": {}}
    for k1 in range(1, Kt + 1):
        k2_match = int(1 + np.argmax(pairwise_accuracies[k1 - 1, :].ravel()))
        # TODO: compute accuracy more precisely here
        num_matches = int(pairwise_counts[k1 - 1, k2_match - 1])
        num_false_positives = int(counts[k2_match - 1] - num_matches)
        num_false_negatives = int(counts_true[k1 - 1] - num_matches)
        unit = {
            "best_match": k2_match,
            "accuracy": pairwise_accuracies[k1 - 1, k2_match - 1],
            "num_matches": num_matches,
            "num_false_positives": num_false_positives,
            "num_false_negatives": num_false_negatives
        }
        ret['true_units'][k1] = unit

    print('Writing output...')
    with open(json_out, 'w') as out:
        json.dump(ret, out, indent=4)
    print('Done.')

    return True
예제 #11
0
 def __init__(self, path):
     """Remember path and eagerly load the array stored there via mdaio.readmda."""
     source_path = path
     self._timeseries_path = source_path
     self._timeseries = mdaio.readmda(source_path)
예제 #12
0
def compute_cluster_metrics(*,
                            timeseries='',
                            firings,
                            metrics_out,
                            clip_size=100,
                            samplerate=0,
                            refrac_msec=1):
    """
    Compute cluster metrics for a spike sorting output

    Per cluster label k = 1..K the metrics are:
      * num_events, always;
      * firing_rate, when timeseries is given and samplerate > 0;
      * peak_amplitude, when timeseries is given (max |template| after
        subtracting each channel's mean);
      * refractory_violation_<refrac_msec>msec, when refrac_msec > 0 and
        samplerate > 0.

    Parameters
    ----------
    firings : INPUT
        Path of firings mda file (RxL) where R>=3 and L is the number of events. Second row are timestamps, third row are integer cluster labels.
    timeseries : INPUT
        Optional path of timeseries mda file (MxN) which could be raw or preprocessed

    metrics_out : OUTPUT
        Path of output json file containing the metrics.

    clip_size : int
        (Optional) clip size, aka snippet size (used when computing the templates, or average waveforms)
    samplerate : float
        Optional sample rate in Hz

    refrac_msec : float
        (Optional) Define interval (in ms) as refractory period. If 0 then don't compute the refractory period metric. The metric is also skipped when samplerate is not positive.
    """
    print('Reading firings...')
    F = mdaio.readmda(firings)

    print('Initializing...')
    R = F.shape[0]
    assert (R >= 3)
    labels = F[2, :].astype(int)  # np.int was removed from NumPy
    # Largest label; an empty firings array has no clusters.
    K = int(np.max(labels)) if labels.size else 0
    N = 0
    if timeseries:
        X = mdaio.DiskReadMda(timeseries)
        N = X.N2()

    if (samplerate > 0) and (N > 0):
        duration_sec = N / samplerate
    else:
        duration_sec = 0

    clusters = []
    for k in range(1, K + 1):
        num_events = int(np.count_nonzero(labels == k))
        metrics_k = {"num_events": num_events}
        if duration_sec:
            metrics_k['firing_rate'] = num_events / duration_sec
        clusters.append({"label": k, "metrics": metrics_k})

    if timeseries and K > 0:
        print('Computing templates...')
        templates = compute_templates_helper(timeseries=timeseries,
                                             firings=firings,
                                             clip_size=clip_size)
        for k in range(1, K + 1):
            template_k = templates[:, :, k - 1]
            # Subtract each channel's mean before measuring the peak.
            centered = template_k - np.mean(template_k, axis=1, keepdims=True)
            clusters[k - 1]['metrics']['peak_amplitude'] = float(
                np.max(np.abs(centered)))

    if (refrac_msec > 0) and (samplerate <= 0):
        # Converting msec to timepoints needs a positive sample rate.
        print('Skipping refractory period metric: samplerate must be positive')
    elif (refrac_msec > 0) and (K > 0):
        print('Computing Refractory Period Violations')
        msec_string = '{0:.5f}'.format(refrac_msec).rstrip('0').rstrip('.')
        rr_string = 'refractory_violation_{}msec'.format(msec_string)
        max_dt_msec = 50
        bin_size_msec = refrac_msec
        auto_cors = compute_cross_correlograms_helper(
            firings=firings,
            mode='autocorrelograms',
            samplerate=samplerate,
            max_dt_msec=max_dt_msec,
            bin_size_msec=bin_size_msec)
        # The four bins around zero lag (assumes bins are laid out
        # symmetrically from -max_dt to +max_dt).
        mid_ind = np.floor(max_dt_msec / bin_size_msec).astype(int)
        mid_inds = np.arange(mid_ind - 2, mid_ind + 2)
        for auto_cor_obj in auto_cors['correlograms']:
            auto = auto_cor_obj['bin_counts']
            k = auto_cor_obj['k']
            peak = np.percentile(auto, 75)
            mid = np.mean(auto[mid_inds])
            rr = safe_div(mid, peak)
            clusters[k - 1]['metrics'][rr_string] = rr

    ret = {"clusters": clusters}

    print('Writing output...')
    with open(metrics_out, 'w') as out:
        json.dump(ret, out, indent=4)
    print('Done.')
    return True
예제 #13
0
def synthesize_drifting_timeseries(*,
                                   firings,
                                   waveforms,
                                   timeseries_out=None,
                                   noise_level=1,
                                   samplerate=30000,
                                   duration=60,
                                   waveform_upsamplefac=1,
                                   amplitudes_row=0,
                                   num_interp_nodes=2):
    """
    Synthesize an electrophysiology timeseries with (linear for now) drift.

    Each event is split into num_interp_nodes copies. Copy j of an event from
    unit u gets label (u-1)*num_interp_nodes + j + 1. Its amplitude is the
    event amplitude times time_basis_func(j, num_interp_nodes, t), where t is
    the event time divided by duration*samplerate. The expanded events are
    then passed to synthesize_timeseries, so the waveforms are blended across
    the recording.

    Parameters
    ----------
    firings : INPUT
        Path of the firing events .mda file, or the RxL array itself (R>=3,
        L = number of events). Second row is the timestamps, third row is the
        integer labels (1-based).
    waveforms : INPUT
        Path of the (possibly upsampled) waveforms .mda file, or the array
        itself: Mx(T*waveform_upsamplefac)x(num_interp_nodes*K). Each unit
        owns a contiguous group of num_interp_nodes waveforms.

    timeseries_out : OUTPUT
        The output path for the new timeseries. MxN

    noise_level : double
        (Optional) Standard deviation of the simulated background noise added to the timeseries
    samplerate : double
        (Optional) Sample rate for the synthetic dataset in Hz
    duration : double
        (Optional) Duration of the synthetic dataset in seconds. The number of timepoints will be duration*samplerate
    waveform_upsamplefac : int
        (Optional) The upsampling factor corresponding to the input waveforms. (avoids digitization artifacts)
    amplitudes_row : int
        (Optional) If positive, the (1-based) row of firings holding the amplitude scale factors. Otherwise a row of 1's is appended and used.
    num_interp_nodes : int
        (Optional) For drift, the number of timepoints where we specify the waveform (Default 2)

    Returns
    -------
    Whatever synthesize_timeseries returns for the expanded events.
    """
    F = mdaio.readmda(firings) if type(firings) == str else firings

    if amplitudes_row == 0:
        # No amplitude row given: append a row of ones and point at it.
        ones_row = np.ones((1, F.shape[1]))
        F = np.concatenate((F, ones_row))
        amplitudes_row = F.shape[0]

    amp_index = amplitudes_row - 1
    event_times = F[1, :]
    rel_times = event_times / (duration * samplerate)  # in [0, 1]
    unit_labels = F[2, :]
    event_amps = F[amp_index, :]

    # Repeat every event (column) num_interp_nodes times side by side.
    F = np.kron(F, [1] * num_interp_nodes)

    for node in range(num_interp_nodes):
        copies = slice(node, None, num_interp_nodes)
        F[amp_index, copies] = event_amps * time_basis_func(
            node, num_interp_nodes, rel_times)
        # Labels are 1-based: unit u maps to waveforms (u-1)*n+1 .. u*n.
        F[2, copies] = (unit_labels - 1) * num_interp_nodes + node + 1

    return synthesize_timeseries(firings=F,
                                 waveforms=waveforms,
                                 timeseries_out=timeseries_out,
                                 noise_level=noise_level,
                                 samplerate=samplerate,
                                 duration=duration,
                                 waveform_upsamplefac=waveform_upsamplefac,
                                 amplitudes_row=amplitudes_row)
예제 #14
0
def loadMdaFile(obj):
    """Read the .mda array at the path that getFilePath resolves from obj."""
    mda_path = getFilePath(obj)
    return mdaio.readmda(mda_path)