예제 #1
0
def merging_cc(params, nb_cpu, nb_gpu, use_gpu):
    """Merge templates whose normalized overlap reaches ``cc_merge``.

    Overlaps are computed by ``get_overlaps`` into the temporary
    ``<file_out_suff>.overlap-merging.hdf5`` file and loaded back as sparse
    triplets (``over_x``, ``over_y``, ``over_data``, ``over_shape``).

    Each MPI rank fills a strided subset of the rows of the symmetric
    ``nb_temp x nb_temp`` ``distances`` matrix (values divided by
    ``N_e * N_t``). The partial matrices are gathered and summed on rank 0,
    which greedily selects the template pairs to merge. The merge list is
    then broadcast, applied with ``slice_templates`` / ``slice_clusters``,
    and the temporary overlap file is deleted by rank 0.

    Parameters
    ----------
    params :
        Configuration object (``getint``/``getfloat``/``getboolean``/``get``).
    nb_cpu, nb_gpu, use_gpu :
        Forwarded to ``get_overlaps`` and ``load_data_memshared``.

    Returns
    -------
    list
        ``[nb_temp, n_merges]``, where ``nb_temp = nb_templates // 2`` and
        ``n_merges`` is the number of merged pairs.
    """

    def remove(result, distances, cc_merge):
        """Greedily select template pairs to merge.

        Repeatedly takes the pair with the largest value in ``distances``.
        While that value is >= ``cc_merge``, the template of the pair with
        fewer entries in ``clusters_<electrode>`` is dropped: its entries are
        deleted from ``result`` and its row/column from ``distances``.

        Returns
        -------
        to_merge : ndarray, shape (n_merges, 2)
            ``[kept, removed]`` indices in the original numbering of
            ``distances``.
        result : dict-like
            The clusters data with the removed templates' entries deleted.
        """
        to_merge = numpy.zeros((0, 2), dtype=numpy.int32)
        # Original template index of each remaining row. This must be a real
        # list: it is shrunk with ``pop`` below, which a Python 3 ``range``
        # object does not support.
        g_idx = list(range(len(distances)))
        # Work on a private copy whose diagonal can never win, so a template
        # is never paired with itself (possible when cc_merge <= 0).
        distances = numpy.array(distances, copy=True)
        numpy.fill_diagonal(distances, -numpy.inf)

        # Stop once fewer than two templates remain; ``max`` on an empty
        # array would raise.
        while len(distances) > 1:
            dmax = distances.max()
            if not dmax >= cc_merge:
                break
            idx = numpy.where(distances == dmax)
            one_merge = [idx[0][0], idx[1][0]]

            elec_ic1 = result['electrodes'][one_merge[0]]
            elec_ic2 = result['electrodes'][one_merge[1]]
            # Position of each template within its electrode's group.
            # NOTE(review): assumes result['electrodes'] lists templates
            # grouped by electrode -- confirm against the clusters layout.
            nic1 = one_merge[0] - numpy.where(
                result['electrodes'] == elec_ic1)[0][0]
            nic2 = one_merge[1] - numpy.where(
                result['electrodes'] == elec_ic2)[0][0]
            mask1 = result['clusters_' + str(elec_ic1)] > -1
            mask2 = result['clusters_' + str(elec_ic2)] > -1
            tmp1 = numpy.unique(result['clusters_' + str(elec_ic1)][mask1])
            tmp2 = numpy.unique(result['clusters_' + str(elec_ic2)][mask2])
            elements1 = numpy.where(result['clusters_' +
                                           str(elec_ic1)] == tmp1[nic1])[0]
            elements2 = numpy.where(result['clusters_' +
                                           str(elec_ic2)] == tmp2[nic2])[0]

            # Keep the template with more entries; on a tie the first
            # template of the pair is dropped.
            if len(elements1) > len(elements2):
                to_remove, to_keep = one_merge[1], one_merge[0]
                elec, elements = elec_ic2, elements2
            else:
                to_remove, to_keep = one_merge[0], one_merge[1]
                elec, elements = elec_ic1, elements1

            suffix = str(elec)
            result['data_' + suffix] = numpy.delete(result['data_' + suffix],
                                                    elements,
                                                    axis=0)
            for key in ('clusters_', 'times_', 'peaks_'):
                result[key + suffix] = numpy.delete(result[key + suffix],
                                                    elements)
            result['electrodes'] = numpy.delete(result['electrodes'],
                                                to_remove)
            # Deleting the same index on both axes keeps the masked
            # diagonal aligned.
            distances = numpy.delete(distances, to_remove, axis=0)
            distances = numpy.delete(distances, to_remove, axis=1)
            to_merge = numpy.vstack(
                (to_merge, numpy.array([g_idx[to_keep], g_idx[to_remove]])))
            g_idx.pop(to_remove)

        return to_merge, result

    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detection', 'N_t')
    blosc_compress = params.getboolean('data', 'blosc_compress')

    N_tm = load_data(params, 'nb_templates')
    nb_temp = int(N_tm // 2)
    to_merge = []
    cc_merge = params.getfloat('clustering', 'cc_merge')
    norm = N_e * N_t

    result = []
    overlap = get_overlaps(params,
                           extension='-merging',
                           erase=True,
                           normalize=True,
                           maxoverlap=False,
                           verbose=False,
                           half=True,
                           use_gpu=use_gpu,
                           nb_cpu=nb_cpu,
                           nb_gpu=nb_gpu)
    overlap.close()
    filename = params.get('data', 'file_out_suff') + '.overlap-merging.hdf5'

    SHARED_MEMORY = get_shared_memory_flag(params)

    if not SHARED_MEMORY:
        over_x, over_y, over_data, over_shape = load_data(params,
                                                          'overlaps-raw',
                                                          extension='-merging')
    else:
        over_x, over_y, over_data, over_shape = load_data_memshared(
            params,
            'overlaps-raw',
            extension='-merging',
            use_gpu=use_gpu,
            nb_cpu=nb_cpu,
            nb_gpu=nb_gpu)

    distances = numpy.zeros((nb_temp, nb_temp), dtype=numpy.float32)

    # Rows i*nb_temp + j (j > i) of the sparse overlap matrix feed
    # distances[i, j]. This rank handles rows i = rank, rank + size, ...
    # and mirrors each result into the lower triangle.
    for i in numpy.arange(nb_temp - 1)[comm.rank::comm.size]:
        start = i * nb_temp + i + 1
        stop = (i + 1) * nb_temp
        idx = numpy.where((over_x >= start) & (over_x < stop))[0]
        data = numpy.zeros((nb_temp - (i + 1), over_shape[1]),
                           dtype=numpy.float32)
        data[over_x[idx] - start, over_y[idx]] = over_data[idx]
        distances[i, i + 1:] = numpy.max(data, 1) / norm
        distances[i + 1:, i] = distances[i, i + 1:]

    # Each entry is written by exactly one rank, so summing the gathered
    # per-rank matrices on rank 0 reassembles the full matrix.
    distances = gather_array(distances,
                             comm,
                             0,
                             1,
                             'float32',
                             compress=blosc_compress)
    if comm.rank == 0:
        distances = distances.reshape(comm.size, nb_temp, nb_temp)
        distances = numpy.sum(distances, 0)
        result = load_data(params, 'clusters')
        to_merge, result = remove(result, distances, cc_merge)

    to_merge = numpy.array(to_merge)
    to_merge = comm.bcast(to_merge, root=0)

    if len(to_merge) > 0:
        slice_templates(params, to_merge=to_merge)
        slice_clusters(params, result)

    comm.Barrier()

    del result, over_x, over_y, over_data

    if comm.rank == 0:
        os.remove(filename)

    return [nb_temp, len(to_merge)]
예제 #2
0
def merging_cc(params, nb_cpu, nb_gpu, use_gpu):
    """Merge templates whose normalized overlap reaches ``cc_merge``.

    Overlaps are computed by ``get_overlaps`` into the temporary
    ``<file_out_suff>.overlap-merging.hdf5`` file. Rank 0 reads the sparse
    triplets back and builds a symmetric ``nb_temp x nb_temp`` ``distances``
    matrix, where ``distances[i, j]`` (i < j) is the maximum of row
    ``i*nb_temp + j`` of the overlap matrix divided by ``N_e * N_t``.
    Rank 0 then greedily selects the template pairs to merge.

    The merge list is broadcast, applied with ``slice_templates`` /
    ``slice_clusters``, and the temporary overlap file is deleted by rank 0.

    Parameters
    ----------
    params :
        Configuration object (``getint``/``getfloat``/``get``).
    nb_cpu, nb_gpu, use_gpu :
        Forwarded to ``get_overlaps``.

    Returns
    -------
    list
        ``[nb_temp, n_merges]``, where ``nb_temp`` is half the number of
        columns of the templates matrix and ``n_merges`` is the number of
        merged pairs.
    """

    def remove(result, distances, cc_merge):
        """Greedily select template pairs to merge.

        Repeatedly takes the pair with the largest value in ``distances``.
        While that value is >= ``cc_merge``, the template of the pair with
        fewer entries in ``clusters_<electrode>`` is dropped: its entries are
        deleted from ``result`` and its row/column from ``distances``.

        Returns
        -------
        to_merge : ndarray, shape (n_merges, 2)
            ``[kept, removed]`` indices in the original numbering of
            ``distances``.
        result : dict-like
            The clusters data with the removed templates' entries deleted.
        """
        to_merge = numpy.zeros((0, 2), dtype=numpy.int32)
        # Original template index of each remaining row. This must be a real
        # list: it is shrunk with ``pop`` below, which a Python 3 ``range``
        # object does not support.
        g_idx = list(range(len(distances)))
        # Work on a private copy whose diagonal can never win, so a template
        # is never paired with itself (possible when cc_merge <= 0).
        distances = numpy.array(distances, copy=True)
        numpy.fill_diagonal(distances, -numpy.inf)

        # Stop once fewer than two templates remain; ``max`` on an empty
        # array would raise.
        while len(distances) > 1:
            dmax = distances.max()
            if not dmax >= cc_merge:
                break
            idx = numpy.where(distances == dmax)
            one_merge = [idx[0][0], idx[1][0]]

            elec_ic1 = result['electrodes'][one_merge[0]]
            elec_ic2 = result['electrodes'][one_merge[1]]
            # Position of each template within its electrode's group.
            # NOTE(review): assumes result['electrodes'] lists templates
            # grouped by electrode -- confirm against the clusters layout.
            nic1 = one_merge[0] - numpy.where(result['electrodes'] == elec_ic1)[0][0]
            nic2 = one_merge[1] - numpy.where(result['electrodes'] == elec_ic2)[0][0]
            mask1 = result['clusters_' + str(elec_ic1)] > -1
            mask2 = result['clusters_' + str(elec_ic2)] > -1
            tmp1 = numpy.unique(result['clusters_' + str(elec_ic1)][mask1])
            tmp2 = numpy.unique(result['clusters_' + str(elec_ic2)][mask2])
            elements1 = numpy.where(result['clusters_' + str(elec_ic1)] == tmp1[nic1])[0]
            elements2 = numpy.where(result['clusters_' + str(elec_ic2)] == tmp2[nic2])[0]

            # Keep the template with more entries; on a tie the first
            # template of the pair is dropped.
            if len(elements1) > len(elements2):
                to_remove, to_keep = one_merge[1], one_merge[0]
                elec, elements = elec_ic2, elements2
            else:
                to_remove, to_keep = one_merge[0], one_merge[1]
                elec, elements = elec_ic1, elements1

            suffix = str(elec)
            result['data_' + suffix] = numpy.delete(result['data_' + suffix], elements, axis=0)
            for key in ('clusters_', 'times_', 'peaks_'):
                result[key + suffix] = numpy.delete(result[key + suffix], elements)
            result['electrodes'] = numpy.delete(result['electrodes'], to_remove)
            # Deleting the same index on both axes keeps the masked
            # diagonal aligned.
            distances = numpy.delete(distances, to_remove, axis=0)
            distances = numpy.delete(distances, to_remove, axis=1)
            to_merge = numpy.vstack((to_merge, numpy.array([g_idx[to_keep], g_idx[to_remove]])))
            g_idx.pop(to_remove)

        return to_merge, result

    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detection', 'N_t')

    # Only the column count of the templates matrix is needed; do not keep
    # the matrix itself alive for the rest of the function.
    _, N_tm = load_data(params, 'templates').shape
    nb_temp = N_tm // 2
    to_merge = []
    cc_merge = params.getfloat('clustering', 'cc_merge')

    result = []
    overlap = get_overlaps(params, extension='-merging', erase=True, normalize=True,
                           maxoverlap=False, verbose=False, half=True,
                           use_gpu=use_gpu, nb_cpu=nb_cpu, nb_gpu=nb_gpu)
    filename = params.get('data', 'file_out_suff') + '.overlap-merging.hdf5'

    # Only rank 0 reads the overlaps, but every rank must close the file,
    # even if a read fails.
    try:
        if comm.rank == 0:
            over_x = overlap.get('over_x')[:]
            over_y = overlap.get('over_y')[:]
            over_data = overlap.get('over_data')[:]
            over_shape = overlap.get('over_shape')[:]
    finally:
        overlap.close()

    if comm.rank == 0:
        overlap_matrix = scipy.sparse.csr_matrix((over_data, (over_x, over_y)), shape=over_shape)
        result = load_data(params, 'clusters')
        distances = numpy.zeros((nb_temp, nb_temp), dtype=numpy.float32)
        for i in range(nb_temp - 1):
            # Rows i*nb_temp + j (j > i) feed distances[i, j]; the result is
            # mirrored into the lower triangle.
            rows = overlap_matrix[i * nb_temp + i + 1:(i + 1) * nb_temp]
            distances[i, i + 1:] = numpy.max(rows.toarray(), 1)
            distances[i + 1:, i] = distances[i, i + 1:]

        distances /= (N_e * N_t)
        to_merge, result = remove(result, distances, cc_merge)

    to_merge = numpy.array(to_merge)
    to_merge = comm.bcast(to_merge, root=0)

    if len(to_merge) > 0:
        slice_templates(params, to_merge=to_merge)
        slice_clusters(params, result)

    comm.Barrier()

    if comm.rank == 0:
        os.remove(filename)

    return [nb_temp, len(to_merge)]