Example 1
def get_max_loc_channel(params, extension):
    if test_if_support(params, extension):
        # With stored supports, the widest template support gives the count.
        supports = io.load_data(params, 'supports', extension)
        max_loc_channel = numpy.sum(supports, 1).max()
    else:
        # Otherwise, fall back to the largest electrode neighbourhood.
        nodes, edges = get_nodes_and_edges(params)
        max_loc_channel = 0
        for key in list(edges.keys()):
            if len(edges[key]) > max_loc_channel:
                max_loc_channel = len(edges[key])
    return max_loc_channel
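
As a minimal standalone sketch of the support-based branch (the toy supports matrix below is invented for illustration), the widest template support yields the channel count:

import numpy

# Hypothetical boolean support matrix: one row per template, one column per
# electrode; True marks the channels a template spans.
supports = numpy.array([
    [True,  True,  False, False],
    [False, True,  True,  True ],
    [True,  False, False, False],
])

# Same reduction as the first branch of get_max_loc_channel above:
max_loc_channel = numpy.sum(supports, 1).max()
print(max_loc_channel)  # -> 3: the second template spans three channels
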
Example 2
def delete_mixtures(params, nb_cpu, nb_gpu, use_gpu):

    data_file      = params.data_file
    N_e            = params.getint('data', 'N_e')
    N_total        = params.nb_channels
    N_t            = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    cc_merge       = params.getfloat('clustering', 'cc_merge')
    mixtures       = []
    to_remove      = []

    filename         = params.get('data', 'file_out_suff') + '.overlap-mixtures.hdf5'
    norm_templates   = load_data(params, 'norm-templates')
    best_elec        = load_data(params, 'electrodes')
    limits           = load_data(params, 'limits')
    nodes, edges     = get_nodes_and_edges(params)
    inv_nodes        = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    decimation       = params.getboolean('clustering', 'decimation')
    has_support      = test_if_support(params, '')

    overlap = get_overlaps(params, extension='-mixtures', erase=True, normalize=True, maxoverlap=False, verbose=False, half=True, use_gpu=use_gpu, nb_cpu=nb_cpu, nb_gpu=nb_gpu, decimation=decimation)
    overlap.close()

    SHARED_MEMORY = get_shared_memory_flag(params)

    if SHARED_MEMORY:
        c_overs    = load_data_memshared(params, 'overlaps', extension='-mixtures', use_gpu=use_gpu, nb_cpu=nb_cpu, nb_gpu=nb_gpu)
    else:
        c_overs    = load_data(params, 'overlaps', extension='-mixtures')

    if SHARED_MEMORY:
        templates  = load_data_memshared(params, 'templates', normalize=True)
    else:
        templates  = load_data(params, 'templates')

    x,        N_tm = templates.shape
    nb_temp        = int(N_tm//2)
    merged         = [nb_temp, 0]

    if has_support:
        supports = load_data(params, 'supports')
    else:
        supports = {}
        for t in range(N_e):
            elecs = numpy.take(inv_nodes, edges[nodes[t]])
            supports[t] = elecs

    # distances[i, j] stores the lag maximizing the cross-correlation between
    # templates i and j; overlap_0[i] is template i's overlap value at zero lag.
    overlap_0 = numpy.zeros(nb_temp, dtype=numpy.float32)
    distances = numpy.zeros((nb_temp, nb_temp), dtype=numpy.int32)

    for i in range(nb_temp - 1):
        data = c_overs[i].toarray()
        distances[i, i+1:] = numpy.argmax(data[i+1:, :], 1)
        distances[i+1:, i] = distances[i, i+1:]
        overlap_0[i] = data[i, N_t - 1]

    all_temp    = numpy.arange(comm.rank, nb_temp, comm.size)
    sorted_temp = numpy.argsort(norm_templates[:nb_temp])[::-1]
    M           = numpy.zeros((2, 2), dtype=numpy.float32)
    V           = numpy.zeros((2, 1), dtype=numpy.float32)
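    # sorted_temp visits templates from largest to smallest norm; M and V are
    # scratch buffers for the 2x2 amplitude fits below.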

    to_explore = range(comm.rank, nb_temp, comm.size)
    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for count, k in enumerate(to_explore):

        k             = sorted_temp[k]
        overlap_k     = c_overs[k]
        if has_support:
            electrodes = numpy.where(supports[k])[0]
            all_idx    = [numpy.any(numpy.in1d(numpy.where(supports[t])[0], electrodes)) for t in range(nb_temp)]
        else:
            electrodes = numpy.take(inv_nodes, edges[nodes[best_elec[k]]])
            all_idx    = [numpy.any(numpy.in1d(supports[best_elec[t]], electrodes)) for t in range(nb_temp)]
        all_idx       = numpy.arange(nb_temp)[all_idx]
        been_found    = False
        t_k           = None

        for l, i in enumerate(all_idx):
            t_i = None
            if not been_found:
                overlap_i = c_overs[i]
                M[0, 0]   = overlap_0[i]
                V[0, 0]   = overlap_k[i, distances[k, i]]
                for j in all_idx[l+1:]:
                    t_j = None
                    M[1, 1]  = overlap_0[j]
                    M[1, 0]  = overlap_i[j, distances[k, i] - distances[k, j]]
                    M[0, 1]  = M[1, 0]
                    V[1, 0]  = overlap_k[j, distances[k, j]]
                    try:
                        # Solve M @ [a1, a2] = V: the amplitudes that best
                        # express template k as a1 * t_i + a2 * t_j.
                        [a1, a2] = numpy.dot(scipy.linalg.inv(M), V)
                    except Exception:
                        [a1, a2] = [0, 0]
                    a1_lim   = limits[i]
                    a2_lim   = limits[j]
                    is_a1    = (a1_lim[0] <= a1) and (a1 <= a1_lim[1])
                    is_a2    = (a2_lim[0] <= a2) and (a2 <= a2_lim[1])
                    if is_a1 and is_a2:
                        if t_k is None:
                            t_k = templates[:, k].toarray().ravel()
                        if t_i is None:
                            t_i = templates[:, i].toarray().ravel()
                        if t_j is None:
                            t_j = templates[:, j].toarray().ravel()
                        new_template = (a1*t_i + a2*t_j)
                        similarity   = numpy.corrcoef(t_k, new_template)[0, 1]
                        local_overlap = numpy.corrcoef(t_i, t_j)[0, 1]
                        if similarity > cc_merge and local_overlap < cc_merge:
                            if k not in mixtures:
                                mixtures  += [k]
                                been_found = True
                                break

    sys.stderr.flush()
    to_remove = numpy.unique(numpy.array(mixtures, dtype=numpy.int32))
    to_remove = all_gather_array(to_remove, comm, 0, dtype='int32')

    if len(to_remove) > 0 and comm.rank == 0:
        result = load_data(params, 'clusters')
        slice_templates(params, to_remove)
        slice_clusters(params, result, to_remove=to_remove)

    comm.Barrier()

    del c_overs

    if comm.rank == 0:
        os.remove(filename)

    return [nb_temp, len(to_remove)]
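
The heart of delete_mixtures is the per-pair amplitude fit: template k is flagged as a mixture when the amplitudes (a1, a2) solving the 2x2 overlap system fall inside both templates' limits and the reconstruction a1*t_i + a2*t_j correlates with t_k above cc_merge. A minimal sketch of that test, with toy templates invented for illustration and plain inner products standing in for the precomputed cross-correlation tables used above:

import numpy
import scipy.linalg

# Toy templates: t_k is an exact mixture of t_i and t_j by construction.
t_i = numpy.array([1.0, 0.0, 2.0, 0.0])
t_j = numpy.array([0.0, 1.0, 0.0, 2.0])
t_k = 0.5 * t_i + 0.8 * t_j

# M holds the pairwise overlaps of t_i and t_j, V their overlaps with t_k,
# mirroring the M and V buffers filled in the loop above.
M = numpy.array([[numpy.dot(t_i, t_i), numpy.dot(t_i, t_j)],
                 [numpy.dot(t_j, t_i), numpy.dot(t_j, t_j)]])
V = numpy.array([[numpy.dot(t_k, t_i)],
                 [numpy.dot(t_k, t_j)]])
[a1], [a2] = numpy.dot(scipy.linalg.inv(M), V)

new_template = a1 * t_i + a2 * t_j
similarity = numpy.corrcoef(t_k, new_template)[0, 1]
print(a1, a2, similarity)  # -> 0.5 0.8 1.0, so t_k would be flagged
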
Example 3
    def write_pcs(path, params, extension, N_tm, mode=0):

        spikes = numpy.load(os.path.join(output_path, 'spike_times.npy'))
        labels = numpy.load(os.path.join(output_path, 'spike_templates.npy'))
        max_loc_channel = get_max_loc_channel(params, extension)
        nb_features = params.getint('whitening', 'output_dim')
        sign_peaks = params.get('detection', 'peaks')
        nodes, edges = get_nodes_and_edges(params)
        N_total = params.getint('data', 'N_total')
        has_support = test_if_support(params, extension)
        if has_support:
            supports = io.load_data(params, 'supports', extension)
        else:
            inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
            inv_nodes[nodes] = numpy.arange(len(nodes))

        if export_all:
            nb_templates = N_tm + N_e
        else:
            nb_templates = N_tm

        pc_features_ind = numpy.zeros((nb_templates, max_loc_channel),
                                      dtype=numpy.int32)
        best_elec = io.load_data(params, 'electrodes', extension)
        if export_all:
            best_elec = numpy.concatenate((best_elec, numpy.arange(N_e)))

        if has_support:
            for count, support in enumerate(supports):
                nb_loc = numpy.sum(support)
                pc_features_ind[count, numpy.arange(nb_loc)] = numpy.where(support)[0]
        else:
            for count, elec in enumerate(best_elec):
                nb_loc = len(edges[nodes[elec]])
                pc_features_ind[count, numpy.arange(nb_loc)] = inv_nodes[edges[
                    nodes[elec]]]

        if sign_peaks in ['negative', 'both']:
            basis_proj, basis_rec = io.load_data(params, 'basis')
        elif sign_peaks in ['positive']:
            basis_proj, basis_rec = io.load_data(params, 'basis-pos')

        to_process = numpy.arange(comm.rank, nb_templates, comm.size)

        all_offsets = numpy.zeros(nb_templates, dtype=numpy.int32)
        for target in range(nb_templates):
            if mode == 0:
                all_offsets[target] = len(numpy.where(labels == target)[0])
            elif mode == 1:
                all_offsets[target] = min(
                    500, len(numpy.where(labels == target)[0]))

        all_paddings = numpy.concatenate(([0], numpy.cumsum(all_offsets)))
        total_pcs = numpy.sum(all_offsets)

        pc_file = os.path.join(output_path, 'pc_features.npy')
        pc_file_ids = os.path.join(output_path, 'pc_feature_spike_ids.npy')

        from numpy.lib.format import open_memmap

        if comm.rank == 0:
            pc_features = open_memmap(pc_file,
                                      shape=(total_pcs, nb_features,
                                             max_loc_channel),
                                      dtype=numpy.float32,
                                      mode='w+')
            if mode == 1:
                pc_ids = open_memmap(pc_file_ids,
                                     shape=(total_pcs, ),
                                     dtype=numpy.int32,
                                     mode='w+')

        comm.Barrier()
        pc_features = open_memmap(pc_file, mode='r+')
        if mode == 1:
            pc_ids = open_memmap(pc_file_ids, mode='r+')

        to_explore = list(range(comm.rank, nb_templates, comm.size))

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(params, to_explore)

        all_idx = numpy.zeros(0, dtype=numpy.int32)
        for gcount, target in enumerate(to_explore):

            count = all_paddings[target]

            if mode == 1:
                idx = numpy.random.permutation(
                    numpy.where(labels == target)[0])[:500]
                pc_ids[count:count + len(idx)] = idx
            elif mode == 0:
                idx = numpy.where(labels == target)[0]

            elec = best_elec[target]

            if has_support:
                if target >= len(supports):
                    # Extra per-electrode templates appended when export_all is
                    # set have no stored support; map them to their electrode.
                    indices = [target - N_tm]
                else:
                    indices = numpy.where(supports[target])[0]
            else:
                indices = inv_nodes[edges[nodes[elec]]]
            labels_i = target * numpy.ones(len(idx))
            times_i = numpy.take(spikes, idx).astype(numpy.int64)
            sub_data = io.get_stas(params,
                                   times_i,
                                   labels_i,
                                   elec,
                                   neighs=indices,
                                   nodes=nodes,
                                   auto_align=False)

            pcs = numpy.dot(sub_data, basis_proj)
            pcs = numpy.swapaxes(pcs, 1, 2)
            if mode == 0:
                pc_features[idx, :, :len(indices)] = pcs
            elif mode == 1:
                pc_features[count:count + len(idx), :, :len(indices)] = pcs

        comm.Barrier()

        if comm.rank == 0:
            numpy.save(os.path.join(output_path, 'pc_feature_ind'),
                       pc_features_ind.astype(
                           numpy.uint32))  # n_templates, n_loc_chan
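
The memmap dance above is worth noting: rank 0 allocates pc_features.npy once with open_memmap(..., mode='w+'), a barrier synchronizes the processes, and then every rank reopens the same file with mode='r+' and fills only its own templates' slices. A minimal single-process sketch of the same pattern (file name and shapes are placeholders):

import numpy
from numpy.lib.format import open_memmap

pc_file = 'pc_features.npy'

# Creation step (done only by rank 0 in write_pcs): allocate the full array
# on disk; the shape stands in for (total_pcs, nb_features, max_loc_channel).
pc_features = open_memmap(pc_file, shape=(10, 3, 4), dtype=numpy.float32, mode='w+')
del pc_features  # drop the creation handle; write_pcs simply rebinds the name

# Fill step (done by every rank): reopen the same file and write a slice.
pc_features = open_memmap(pc_file, mode='r+')
pc_features[0:2, :, :2] = 1.0
pc_features.flush()
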
Example 4
def slice_templates(params, to_remove=[], to_merge=[], extension='', input_extension=''):
    """Slice templates in HDF5 file.

    Arguments:
        params
        to_remove: list (optional)
            An array of template indices to remove.
            The default value is [].
        to_merge: list | numpy.ndarray (optional)
            An array of pairs of template indices to merge
            (i.e. shape = (nb_merges, 2)).
            The default value is [].
        extension: string (optional)
            The extension to use as output.
            The default value is ''.
        input_extension: string (optional)
            The extension to use as input.
            The default value is ''.
    """

    file_out_suff  = params.get('data', 'file_out_suff')

    data_file      = params.data_file
    N_e            = params.getint('data', 'N_e')
    N_total        = params.nb_channels
    hdf5_compress  = params.getboolean('data', 'hdf5_compress')
    N_t            = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    has_support    = test_if_support(params, input_extension)

    if comm.rank == 0:
        print_and_log(['Node 0 is slicing templates'], 'debug', logger)
        old_templates  = load_data(params, 'templates', extension=input_extension)
        old_limits     = load_data(params, 'limits', extension=input_extension)
        if has_support:
            old_supports   = load_data(params, 'supports', extension=input_extension)
        _, N_tm        = old_templates.shape
        norm_templates = load_data(params, 'norm-templates', extension=input_extension)

        # Determine the template indices to delete.
        to_delete = list(to_remove)  # i.e. copy
        if len(to_merge) > 0:
            for count in range(len(to_merge)):
                remove = to_merge[count][1]
                to_delete += [remove]

        # Determine the indices to keep.
        all_templates = set(numpy.arange(N_tm // 2))
        to_keep = numpy.array(list(all_templates.difference(to_delete)))

        positions = numpy.arange(len(to_keep))

        # Initialize new HDF5 file for templates.
        local_keep = to_keep[positions]
        templates  = scipy.sparse.lil_matrix((N_e*N_t, 2*len(to_keep)), dtype=numpy.float32)
        hfilename  = file_out_suff + '.templates{}.hdf5'.format('-new')
        hfile      = h5py.File(hfilename, 'w', libver='earliest')
        norms      = hfile.create_dataset('norms', shape=(2*len(to_keep), ), dtype=numpy.float32, chunks=True)
        limits     = hfile.create_dataset('limits', shape=(len(to_keep), 2), dtype=numpy.float32, chunks=True)
        if has_support:
            supports   = hfile.create_dataset('supports', shape=(len(to_keep), N_e), dtype=numpy.bool_, chunks=True)
        # For each index to keep.
        for count, keep in zip(positions, local_keep):
            # Copy template.
            templates[:, count]                = old_templates[:, keep]
            templates[:, count + len(to_keep)] = old_templates[:, keep + N_tm//2]
            # Copy norm.
            norms[count]                       = norm_templates[keep]
            norms[count + len(to_keep)]        = norm_templates[keep + N_tm//2]
            if has_support:
                supports[count]                = old_supports[keep]
            # Copy limits.
            if len(to_merge) == 0:
                new_limits = old_limits[keep]
            else:
                subset     = numpy.where(to_merge[:, 0] == keep)[0]
                if len(subset) > 0:
                    # pylab.subplot(211)
                    # pylab.plot(templates[:, count].toarray().flatten())
                    # ymin, ymax = pylab.ylim()
                    # pylab.subplot(212)
                    # for i in to_merge[subset]:
                    #     pylab.plot(old_templates[:, i[1]].toarray().flatten())
                    # pylab.ylim(ymin, ymax)
                    # pylab.savefig('merge_%d.png' %count)
                    # pylab.close()
                    # Index to keep is involved in merge(s) and limits need to
                    # be updated.
                    idx        = numpy.unique(to_merge[subset].flatten())
                    ratios     = norm_templates[idx] / norm_templates[keep]
                    new_limits = [
                        numpy.min(ratios * old_limits[idx][:, 0]),
                        numpy.max(ratios * old_limits[idx][:, 1])
                    ]
                else:
                    new_limits = old_limits[keep]
            limits[count]  = new_limits

        # Copy templates to file.
        templates = templates.tocoo()
        if hdf5_compress:
            hfile.create_dataset('temp_x', data=templates.row, compression='gzip')
            hfile.create_dataset('temp_y', data=templates.col, compression='gzip')
            hfile.create_dataset('temp_data', data=templates.data, compression='gzip')
        else:
            hfile.create_dataset('temp_x', data=templates.row)
            hfile.create_dataset('temp_y', data=templates.col)
            hfile.create_dataset('temp_data', data=templates.data)
        hfile.create_dataset('temp_shape', data=numpy.array([N_e, N_t, 2*len(to_keep)], dtype=numpy.int32))
        hfile.close()

        # Rename output filename.
        temporary_path = hfilename
        output_path = file_out_suff + '.templates{}.hdf5'.format(extension)
        if os.path.exists(output_path):
            os.remove(output_path)
        shutil.move(temporary_path, output_path)
    else:
        to_keep = numpy.array([])

    return to_keep
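
A hypothetical call, assuming a loaded params object and an MPI context like the other examples: drop templates 3 and 7, and merge template 5 into template 2 (to_merge rows are (kept, deleted) pairs), writing the result under an invented '-pruned' extension:

import numpy

to_remove = [3, 7]
to_merge  = numpy.array([[2, 5]])  # shape (nb_merges, 2)
kept = slice_templates(params, to_remove=to_remove, to_merge=to_merge,
                       extension='-pruned', input_extension='')
# On rank 0, kept holds the surviving template indices; other ranks
# return an empty array.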