Example #1
def merging_cc(params, nb_cpu, nb_gpu, use_gpu):
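    # Greedy agglomeration helper: while the largest pairwise template
    # similarity still exceeds cc_merge, merge that pair, keeping the template
    # with more spikes and removing the other one's spikes from the clustering
    # result.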
    def remove(result, distances, cc_merge):
        do_merge = True
        to_merge = numpy.zeros((0, 2), dtype=numpy.int32)
        g_idx = list(range(len(distances)))
        while do_merge:
            dmax = distances.max()
            idx = numpy.where(distances == dmax)
            one_merge = [idx[0][0], idx[1][0]]
            do_merge = dmax >= cc_merge

            if do_merge:

                elec_ic1 = result['electrodes'][one_merge[0]]
                elec_ic2 = result['electrodes'][one_merge[1]]
                nic1 = one_merge[0] - numpy.where(
                    result['electrodes'] == elec_ic1)[0][0]
                nic2 = one_merge[1] - numpy.where(
                    result['electrodes'] == elec_ic2)[0][0]
                mask1 = result['clusters_' + str(elec_ic1)] > -1
                mask2 = result['clusters_' + str(elec_ic2)] > -1
                tmp1 = numpy.unique(result['clusters_' + str(elec_ic1)][mask1])
                tmp2 = numpy.unique(result['clusters_' + str(elec_ic2)][mask2])
                elements1 = numpy.where(result['clusters_' +
                                               str(elec_ic1)] == tmp1[nic1])[0]
                elements2 = numpy.where(result['clusters_' +
                                               str(elec_ic2)] == tmp2[nic2])[0]

                if len(elements1) > len(elements2):
                    to_remove = one_merge[1]
                    to_keep = one_merge[0]
                    elec = elec_ic2
                    elements = elements2
                else:
                    to_remove = one_merge[0]
                    to_keep = one_merge[1]
                    elec = elec_ic1
                    elements = elements1

                result['data_' + str(elec)] = numpy.delete(result['data_' +
                                                                  str(elec)],
                                                           elements,
                                                           axis=0)
                result['clusters_' + str(elec)] = numpy.delete(
                    result['clusters_' + str(elec)], elements)
                result['times_' + str(elec)] = numpy.delete(
                    result['times_' + str(elec)], elements)
                result['peaks_' + str(elec)] = numpy.delete(
                    result['peaks_' + str(elec)], elements)
                result['electrodes'] = numpy.delete(result['electrodes'],
                                                    to_remove)
                distances = numpy.delete(distances, to_remove, axis=0)
                distances = numpy.delete(distances, to_remove, axis=1)
                to_merge = numpy.vstack(
                    (to_merge, numpy.array([g_idx[to_keep],
                                            g_idx[to_remove]])))
                g_idx.pop(to_remove)

        return to_merge, result

    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    blosc_compress = params.getboolean('data', 'blosc_compress')

    N_tm = load_data(params, 'nb_templates')
    nb_temp = int(N_tm // 2)
    to_merge = []
    cc_merge = params.getfloat('clustering', 'cc_merge')
    norm = N_e * N_t

    result = []
    overlap = get_overlaps(params,
                           extension='-merging',
                           erase=True,
                           normalize=True,
                           maxoverlap=False,
                           verbose=False,
                           half=True,
                           use_gpu=use_gpu,
                           nb_cpu=nb_cpu,
                           nb_gpu=nb_gpu)
    overlap.close()
    filename = params.get('data', 'file_out_suff') + '.overlap-merging.hdf5'

    SHARED_MEMORY = get_shared_memory_flag(params)

    if not SHARED_MEMORY:
        over_x, over_y, over_data, over_shape = load_data(params,
                                                          'overlaps-raw',
                                                          extension='-merging')
    else:
        over_x, over_y, over_data, over_shape = load_data_memshared(
            params,
            'overlaps-raw',
            extension='-merging',
            use_gpu=use_gpu,
            nb_cpu=nb_cpu,
            nb_gpu=nb_gpu)

    #sub_comm, is_local = get_local_ring(True)

    #if is_local:

    distances = numpy.zeros((nb_temp, nb_temp), dtype=numpy.float32)

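    # Each MPI rank fills a strided subset of rows of the template-to-template
    # similarity matrix (maximum overlap, normalized by N_e * N_t); the partial
    # matrices are then gathered and summed on rank 0.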
    to_explore = numpy.arange(nb_temp - 1)[comm.rank::comm.size]

    for i in to_explore:

        idx = numpy.where((over_x >= i * nb_temp + i + 1)
                          & (over_x < ((i + 1) * nb_temp)))[0]
        local_x = over_x[idx] - (i * nb_temp + i + 1)
        data = numpy.zeros((nb_temp - (i + 1), over_shape[1]),
                           dtype=numpy.float32)
        data[local_x, over_y[idx]] = over_data[idx]
        distances[i, i + 1:] = numpy.max(data, 1) / norm
        distances[i + 1:, i] = distances[i, i + 1:]

    # Now we need to sync everything across nodes
    distances = gather_array(distances,
                             comm,
                             0,
                             1,
                             'float32',
                             compress=blosc_compress)
    if comm.rank == 0:
        distances = distances.reshape(comm.size, nb_temp, nb_temp)
        distances = numpy.sum(distances, 0)

    #sub_comm.Barrier()
    #sub_comm.Free()

    if comm.rank == 0:
        result = load_data(params, 'clusters')
        to_merge, result = remove(result, distances, cc_merge)

    to_merge = numpy.array(to_merge)
    to_merge = comm.bcast(to_merge, root=0)

    if len(to_merge) > 0:
        slice_templates(params, to_merge=to_merge)
        slice_clusters(params, result)

    comm.Barrier()

    del result, over_x, over_y, over_data

    if comm.rank == 0:
        os.remove(filename)

    return [nb_temp, len(to_merge)]
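
A minimal usage sketch (not part of the library code above): it assumes params is a SpyKING CIRCUS parameter object (for instance one built by circus.shared.parser.CircusParser), and that the enclosing module already provides the MPI communicator comm and the I/O helpers (load_data, get_overlaps, slice_templates, ...) used inside merging_cc; the resource counts are purely illustrative.

# Hypothetical driver code, with illustrative resource settings.
nb_cpu, nb_gpu, use_gpu = 4, 0, False
nb_temp_before, nb_merges = merging_cc(params, nb_cpu, nb_gpu, use_gpu)
if comm.rank == 0:
    print('merged %d template pairs out of %d templates' % (nb_merges, nb_temp_before))
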
Example #2
def delete_mixtures(params, nb_cpu, nb_gpu, use_gpu):

    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    cc_merge = params.getfloat('clustering', 'cc_mixtures')
    mixtures = []
    to_remove = []

    filename = params.get('data', 'file_out_suff') + '.overlap-mixtures.hdf5'
    norm_templates = load_data(params, 'norm-templates')
    best_elec = load_data(params, 'electrodes')
    limits = load_data(params, 'limits')
    nodes, edges = get_nodes_and_edges(params)
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)

    overlap = get_overlaps(params,
                           extension='-mixtures',
                           erase=True,
                           normalize=False,
                           maxoverlap=False,
                           verbose=False,
                           half=True,
                           use_gpu=use_gpu,
                           nb_cpu=nb_cpu,
                           nb_gpu=nb_gpu)
    overlap.close()

    SHARED_MEMORY = get_shared_memory_flag(params)

    if SHARED_MEMORY:
        c_overs = load_data_memshared(params,
                                      'overlaps',
                                      extension='-mixtures',
                                      use_gpu=use_gpu,
                                      nb_cpu=nb_cpu,
                                      nb_gpu=nb_gpu)
    else:
        c_overs = load_data(params, 'overlaps', extension='-mixtures')

    if SHARED_MEMORY:
        templates = load_data_memshared(params, 'templates', normalize=False)
    else:
        templates = load_data(params, 'templates')

    x, N_tm = templates.shape
    nb_temp = int(N_tm // 2)
    merged = [nb_temp, 0]

    overlap_0 = numpy.zeros(nb_temp, dtype=numpy.float32)
    distances = numpy.zeros((nb_temp, nb_temp), dtype=numpy.int32)

    for i in range(nb_temp - 1):
        data = c_overs[i].toarray()
        distances[i, i + 1:] = numpy.argmax(data[i + 1:, :], 1)
        distances[i + 1:, i] = distances[i, i + 1:]
        overlap_0[i] = data[i, N_t]

    all_temp = numpy.arange(comm.rank, nb_temp, comm.size)
    sorted_temp = numpy.argsort(
        norm_templates[:nb_temp])[::-1][comm.rank::comm.size]
    M = numpy.zeros((2, 2), dtype=numpy.float32)
    V = numpy.zeros((2, 1), dtype=numpy.float32)

    to_explore = range(comm.rank, len(sorted_temp), comm.size)
    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

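    # For each template k (processed from largest to smallest norm), test whether
    # it can be explained as a linear combination a1 * t_i + a2 * t_j of two
    # templates whose preferred electrodes lie within k's electrode neighbourhood.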
    for count, k in enumerate(to_explore):

        k = sorted_temp[k]
        electrodes = numpy.take(inv_nodes, edges[nodes[best_elec[k]]])
        overlap_k = c_overs[k]
        is_in_area = numpy.in1d(best_elec, electrodes)
        all_idx = numpy.arange(len(best_elec))[is_in_area]
        been_found = False
        t_k = None

        for i in all_idx:
            t_i = None
            if not been_found:
                overlap_i = c_overs[i]
                M[0, 0] = overlap_0[i]
                V[0, 0] = overlap_k[i, distances[k, i]]
                for j in all_idx[i + 1:]:
                    t_j = None
                    M[1, 1] = overlap_0[j]
                    M[1, 0] = overlap_i[j, distances[k, i] - distances[k, j]]
                    M[0, 1] = M[1, 0]
                    V[1, 0] = overlap_k[j, distances[k, j]]
                    try:
                        [a1, a2] = numpy.dot(scipy.linalg.inv(M), V)
                    except Exception:
                        [a1, a2] = [0, 0]
                    a1_lim = limits[i]
                    a2_lim = limits[j]
                    is_a1 = (a1_lim[0] <= a1) and (a1 <= a1_lim[1])
                    is_a2 = (a2_lim[0] <= a2) and (a2 <= a2_lim[1])
                    if is_a1 and is_a2:
                        if t_k is None:
                            t_k = templates[:, k].toarray().ravel()
                        if t_i is None:
                            t_i = templates[:, i].toarray().ravel()
                        if t_j is None:
                            t_j = templates[:, j].toarray().ravel()
                        new_template = (a1 * t_i + a2 * t_j)
                        similarity = numpy.corrcoef(t_k, new_template)[0, 1]
                        local_overlap = numpy.corrcoef(t_i, t_j)[0, 1]
                        if similarity > cc_merge and local_overlap < cc_merge:
                            if k not in mixtures:
                                mixtures += [k]
                                been_found = True
                                #print "Template", k, 'is sum of (%d, %g) and (%d,%g)' %(i, a1, j, a2)
                                break
    sys.stderr.flush()
    #print mixtures
    to_remove = numpy.unique(numpy.array(mixtures, dtype=numpy.int32))
    to_remove = all_gather_array(to_remove, comm, 0, dtype='int32')

    if len(to_remove) > 0 and comm.rank == 0:
        result = load_data(params, 'clusters')
        slice_templates(params, to_remove)
        slice_clusters(params, result, to_remove=to_remove)

    comm.Barrier()

    del c_overs

    if comm.rank == 0:
        os.remove(filename)

    return [nb_temp, len(to_remove)]
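
The call pattern for delete_mixtures is the same; a hedged sketch under the same assumptions (SpyKING CIRCUS-style params object, module-level comm):

# Hypothetical driver code; returns the template count and the number of mixtures removed.
nb_temp_before, nb_removed = delete_mixtures(params, nb_cpu=4, nb_gpu=0, use_gpu=False)
if comm.rank == 0:
    print('removed %d mixture templates out of %d' % (nb_removed, nb_temp_before))
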
Example #3
def delete_mixtures(comm, params, nb_cpu, nb_gpu, use_gpu):
        
    templates      = load_data(params, 'templates')
    N_e            = params.getint('data', 'N_e')
    N_t            = params.getint('data', 'N_t')
    cc_merge       = params.getfloat('clustering', 'cc_merge')
    x,        N_tm = templates.shape
    nb_temp        = N_tm//2
    merged         = [nb_temp, 0]
    mixtures       = []
    to_remove      = []

    overlap  = get_overlaps(comm, params, extension='-mixtures', erase=True, normalize=False, maxoverlap=False, verbose=False, half=True, use_gpu=use_gpu, nb_cpu=nb_cpu, nb_gpu=nb_gpu)
    filename = params.get('data', 'file_out_suff') + '.overlap-mixtures.hdf5'
    result   = []
    
    norm_templates   = load_data(params, 'norm-templates')
    templates        = load_data(params, 'templates')
    result           = load_data(params, 'clusters')
    best_elec        = load_data(params, 'electrodes')
    limits           = load_data(params, 'limits')
    N_total          = params.getint('data', 'N_total')
    nodes, edges     = get_nodes_and_edges(params)
    inv_nodes        = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)

    distances = numpy.zeros((nb_temp, nb_temp), dtype=numpy.float32)

    over_x     = overlap.get('over_x')[:]
    over_y     = overlap.get('over_y')[:]
    over_data  = overlap.get('over_data')[:]
    over_shape = overlap.get('over_shape')[:]
    overlap.close()

    overlap    = scipy.sparse.csr_matrix((over_data, (over_x, over_y)), shape=over_shape)

    for i in xrange(nb_temp-1):
        distances[i, i+1:] = numpy.argmax(overlap[i*nb_temp+i+1:(i+1)*nb_temp].toarray(), 1)
        distances[i+1:, i] = distances[i, i+1:]

    all_temp  = numpy.arange(comm.rank, nb_temp, comm.size)
    overlap_0 = overlap[:, N_t].toarray().reshape(nb_temp, nb_temp)
    if comm.rank == 0:
        pbar = get_progressbar(size=len(all_temp)).start()

    sorted_temp    = numpy.argsort(norm_templates[:nb_temp])[::-1][comm.rank::comm.size]
    M              = numpy.zeros((2, 2), dtype=numpy.float32)
    V              = numpy.zeros((2, 1), dtype=numpy.float32)

    for count, k in enumerate(sorted_temp):

        electrodes    = numpy.take(inv_nodes, edges[nodes[best_elec[k]]])
        overlap_k     = overlap[k*nb_temp:(k+1)*nb_temp].tolil()
        is_in_area    = numpy.in1d(best_elec, electrodes)
        all_idx       = numpy.arange(len(best_elec))[is_in_area]
        been_found    = False

        for i in all_idx:
            if not been_found:
                overlap_i = overlap[i*nb_temp:(i+1)*nb_temp].tolil()
                M[0, 0]   = overlap_0[i, i]
                V[0, 0]   = overlap_k[i, distances[k, i]]
                for j in all_idx[i+1:]:
                    M[1, 1]  = overlap_0[j, j]
                    M[1, 0]  = overlap_i[j, distances[k, i] - distances[k, j]]
                    M[0, 1]  = M[1, 0]
                    V[1, 0]  = overlap_k[j, distances[k, j]]
                    try:
                        [a1, a2] = numpy.dot(scipy.linalg.inv(M), V)
                    except Exception:
                        [a1, a2] = [0, 0]
                    a1_lim   = limits[i]
                    a2_lim   = limits[j]
                    is_a1    = (a1_lim[0] <= a1) and (a1 <= a1_lim[1])
                    is_a2    = (a2_lim[0] <= a2) and (a2 <= a2_lim[1])
                    if is_a1 and is_a2:
                        new_template = (a1*templates[:, i].toarray() + a2*templates[:, j].toarray()).ravel()
                        similarity   = numpy.corrcoef(templates[:, k].toarray().ravel(), new_template)[0, 1]
                        if similarity > cc_merge:
                            if k not in mixtures:
                                mixtures  += [k]
                                been_found = True 
                                break
                                #print "Template", k, 'is sum of (%d, %g) and (%d,%g)' %(i, a1, j, a2)

        if comm.rank == 0:
            pbar.update(count)

    if comm.rank == 0:
        pbar.finish()
    
    #print mixtures
    to_remove = numpy.unique(numpy.array(mixtures, dtype=numpy.int32))    
    to_remove = all_gather_array(to_remove, comm, 0, dtype='int32')
    
    if len(to_remove) > 0:
        slice_templates(comm, params, to_remove)
        slice_clusters(comm, params, result, to_remove=to_remove)

    if comm.rank == 0:
        os.remove(filename)

    return [nb_temp, len(to_remove)]
Example #4
def merging_cc(comm, params, nb_cpu, nb_gpu, use_gpu):


    def remove(result, distances, cc_merge):
        do_merge  = True
        to_merge  = numpy.zeros((0, 2), dtype=numpy.int32)
        g_idx     = range(len(distances))
        while do_merge:
            dmax      = distances.max()
            idx       = numpy.where(distances == dmax)
            one_merge = [idx[0][0], idx[1][0]]
            do_merge  = dmax >= cc_merge

            if do_merge:

                elec_ic1  = result['electrodes'][one_merge[0]]
                elec_ic2  = result['electrodes'][one_merge[1]]
                nic1      = one_merge[0] - numpy.where(result['electrodes'] == elec_ic1)[0][0]
                nic2      = one_merge[1] - numpy.where(result['electrodes'] == elec_ic2)[0][0]
                mask1     = result['clusters_' + str(elec_ic1)] > -1
                mask2     = result['clusters_' + str(elec_ic2)] > -1
                tmp1      = numpy.unique(result['clusters_' + str(elec_ic1)][mask1])
                tmp2      = numpy.unique(result['clusters_' + str(elec_ic2)][mask2])
                elements1 = numpy.where(result['clusters_' + str(elec_ic1)] == tmp1[nic1])[0]
                elements2 = numpy.where(result['clusters_' + str(elec_ic2)] == tmp2[nic2])[0]

                if len(elements1) > len(elements2):
                    to_remove = one_merge[1]
                    to_keep   = one_merge[0]
                    elec      = elec_ic2
                    elements  = elements2
                else:
                    to_remove = one_merge[0]
                    to_keep   = one_merge[1]
                    elec      = elec_ic1
                    elements  = elements1

                result['data_' + str(elec)]     = numpy.delete(result['data_' + str(elec)], elements, axis=0)
                result['clusters_' + str(elec)] = numpy.delete(result['clusters_' + str(elec)], elements) 
                result['times_' + str(elec)]    = numpy.delete(result['times_' + str(elec)], elements)
                result['peaks_' + str(elec)]    = numpy.delete(result['peaks_' + str(elec)], elements)
                result['electrodes']            = numpy.delete(result['electrodes'], to_remove)
                distances                       = numpy.delete(distances, to_remove, axis=0)
                distances                       = numpy.delete(distances, to_remove, axis=1)
                to_merge                        = numpy.vstack((to_merge, numpy.array([g_idx[to_keep], g_idx[to_remove]])))
                g_idx.pop(to_remove)

        return to_merge, result
            
    templates      = load_data(params, 'templates')
    N_e            = params.getint('data', 'N_e')
    N_t            = params.getint('data', 'N_t')
    x,        N_tm = templates.shape
    nb_temp        = N_tm//2
    to_merge       = []
    cc_merge       = params.getfloat('clustering', 'cc_merge')
        
    result   = []
    overlap  = get_overlaps(comm, params, extension='-merging', erase=True, normalize=True, maxoverlap=False, verbose=False, half=True, use_gpu=use_gpu, nb_cpu=nb_cpu, nb_gpu=nb_gpu)
    filename = params.get('data', 'file_out_suff') + '.overlap-merging.hdf5'

    if comm.rank > 0:
        overlap.file.close()
    else:
        over_x     = overlap.get('over_x')[:]
        over_y     = overlap.get('over_y')[:]
        over_data  = overlap.get('over_data')[:]
        over_shape = overlap.get('over_shape')[:]
        overlap.close()

        overlap   = scipy.sparse.csr_matrix((over_data, (over_x, over_y)), shape=over_shape)
        result    = load_data(params, 'clusters')
        distances = numpy.zeros((nb_temp, nb_temp), dtype=numpy.float32)
        for i in xrange(nb_temp-1):
            distances[i, i+1:] = numpy.max(overlap[i*nb_temp+i+1:(i+1)*nb_temp].toarray(), 1)
            distances[i+1:, i] = distances[i, i+1:]

        distances /= (N_e*N_t)
        to_merge, result = remove(result, distances, cc_merge)       

    to_merge = numpy.array(to_merge)
    to_merge = comm.bcast(to_merge, root=0)
    
    if len(to_merge) > 0:
        slice_templates(comm, params, to_merge=to_merge)
        slice_clusters(comm, params, result)

    if comm.rank == 0:
        os.remove(filename)

    return [nb_temp, len(to_merge)]
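
This older variant takes the MPI communicator explicitly instead of relying on a module-level one; a hedged one-line sketch under the same assumptions as above:

# Hypothetical driver code for the legacy signature.
nb_temp_before, nb_merges = merging_cc(comm, params, nb_cpu=4, nb_gpu=0, use_gpu=False)
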
Example #5
def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    n_e = params.getint('data', 'N_e')
    n_total = params.nb_channels
    n_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    # file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    # spike_thresh = params.getfloat('detection', 'spike_thresh')
    ratio_thresh = params.getfloat('fitting', 'ratio_thresh')
    two_components = params.getboolean('fitting', 'two_components')
    # spike_width = params.getfloat('detection', 'spike_width')
    # dist_peaks = params.getint('detection', 'dist_peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    templates_normalization = params.getboolean('clustering', 'templates_normalization')  # TODO test, switch, test!
    chunk_size = detect_memory(params, fitting=True)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = [float(v) for v in tmp_limits]
    amp_auto = params.getboolean('fitting', 'amp_auto')
    auto_nb_chances = params.getboolean('fitting', 'auto_nb_chances')
    if auto_nb_chances:
        nb_chances = io.load_data(params, 'nb_chances')
        max_nb_chances = params.getint('fitting', 'max_nb_chances')
        percent_nb_chances = params.getfloat('fitting', 'percent_nb_chances')
        total_nb_chances = max(1, numpy.nanpercentile(nb_chances, percent_nb_chances))
        total_nb_chances = min(total_nb_chances, max_nb_chances)
        if comm.rank == 0:
            print_and_log(['nb_chances set automatically to %g' %total_nb_chances], 'debug', logger)
    else:
        total_nb_chances = params.getfloat('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    # noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    min_second_component = params.getfloat('fitting', 'min_second_component')
    debug = params.getboolean('fitting', 'debug')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(n_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    #################################################################

    if use_gpu:
        import cudamat as cmt
        # # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates, _ = io.load_data_memshared(params, 'templates', normalize=templates_normalization, transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    temp_2_shift = 2 * template_shift
    temp_3_shift = 3 * template_shift
    full_gpu = use_gpu and gpu_only
    n_tm = N_tm // 2
    n_scalar = n_e * n_t

    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = n_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')
    if not templates_normalization:
        norm_templates_2 = (norm_templates ** 2.0) * n_scalar

    if not SHARED_MEMORY:
        # Normalize templates (if necessary).
        if templates_normalization:
            for idx in range(templates.shape[1]):
                myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1])
                templates.data[myslice] /= norm_templates[idx]
        # Transpose templates.
        templates = templates.T

    waveform_neg = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_neg = None  # default assignment (for PyCharm code inspection)
    waveform_pos = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_pos = None  # default assignment (for PyCharm code inspection)
    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_thresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)
    else:
        all_dead_times = None  # default assignment (for PyCharm code inspection)

    thresholds = io.get_accurate_thresholds(params, ratio_thresh)

    neighbors = {}
    if collect_all:
        for i in range(0, n_tm):
            tmp = templates[i, :].toarray().reshape(n_e, n_t)
            if templates_normalization:
                tmp = tmp * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, axis=1) != 0.0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string = ''

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % comm.size
        else:
            info_string = "using %d CPUs" % comm.size

    comm.Barrier()

    c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    n_over = int(numpy.sqrt(over_shape[0]))
    s_over = over_shape[1]
    # # If the number of overlaps is different from templates, we need to recompute them.
    if n_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        n_over = int(numpy.sqrt(over_shape[0]))
        s_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs, _ = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"], 'default', logger)

        sys.exit(0)

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (for PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (for PyCharm code inspection)

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in range(n_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger)
            for i in range(n_over):
                if i in c_overs:
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()
    else:
        garbage_times_file = None  # default assignment (for PyCharm code inspection)
        garbage_temp_file = None  # default assignment (for PyCharm code inspection)

    if debug:
        # Open debug files.
        chunk_nbs_debug_file = open(file_out_suff + '.chunk_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        iteration_nbs_debug_file = open(file_out_suff + '.iteration_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_nbs_debug_file = open(file_out_suff + '.peak_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_local_time_steps_debug_file = open(
            file_out_suff + '.peak_local_time_steps_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_time_steps_debug_file = open(file_out_suff + '.peak_time_steps_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_scalar_products_debug_file = open(
            file_out_suff + '.peak_scalar_products_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_solved_flags_debug_file = open(file_out_suff + '.peak_solved_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        template_nbs_debug_file = open(file_out_suff + '.template_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        success_flags_debug_file = open(file_out_suff + '.success_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
    else:
        chunk_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        iteration_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_local_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_scalar_products_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_solved_flags_debug_file = None  # default assignment (for PyCharm code inspection)
        template_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        success_flags_debug_file = None  # default assignment (for PyCharm code inspection)

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    last_chunk_size = 0
    slice_indices = numpy.zeros(0, dtype=numpy.int32)

    to_explore = range(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for gcount, gidx in enumerate(to_explore):
        # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        # # We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift].

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if not (is_first and is_last):
            if is_last:
                padding = (-temp_3_shift, 0)
            elif is_first:
                padding = (0, temp_3_shift)
            else:
                padding = (-temp_3_shift, temp_3_shift)
        else:
            padding = (0, 0)

        result = {
            'spiketimes': [],
            'amplitudes': [],
            'templates': [],
        }
        result_debug = {
            'chunk_nbs': [],
            'iteration_nbs': [],
            'peak_nbs': [],
            'peak_local_time_steps': [],
            'peak_time_steps': [],
            'peak_scalar_products': [],
            'peak_solved_flags': [],
            'template_nbs': [],
            'success_flags': [],
        }

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)           
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # Extracting peaks.

        all_found_spikes = {}
        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = []

        local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_pos[i])[0]
                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_neg[i])[0]
                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)
        else:
            for i in range(n_e):
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i])[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)
                local_peaktimes.append(peaktimes)
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)

        local_peaktimes = numpy.unique(local_peaktimes)

        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.in1d(local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]
                local_peaktimes = numpy.sort(local_peaktimes)
        else:
            dead_indices = None  # default assignment (for PyCharm code inspection)

        # print "Removing the useless borders..."
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            all_found_spikes[i] + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]
                        )
                        all_found_spikes[i] = all_found_spikes[i][~is_included]
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        nb_local_peak_times = len(local_peaktimes)

        if full_gpu:
            # all_indices = cmt.CUDAMatrix(all_indices)
            # tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False)
            _ = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False)

        if nb_local_peak_times > 0:
            # print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."

            if collect_all:
                c_local_chunk = local_chunk.copy()
            else:
                c_local_chunk = None  # default assignment (for PyCharm code inspection)

            sub_mat = local_chunk[local_peaktimes[:, None] + temp_window]
            sub_mat = sub_mat.transpose(2, 1, 0).reshape(size_window, nb_local_peak_times)

            del local_chunk

            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_restriction = (t_offset, t_offset + chunk_size)
            all_spikes = local_peaktimes + g_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            # b = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()           

            failure = numpy.zeros(nb_local_peak_times, dtype=numpy.int32)

            if full_gpu:
                mask = numpy.zeros((2 * n_tm, nb_local_peak_times), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                # data = cmt.empty(mask.shape)
                _ = cmt.empty(mask.shape)
                patch_gpu = b.shape[1] == 1
            else:
                patch_gpu = None

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, n_e), dtype=bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in range(n_e):
                    c_all_times[all_found_spikes[i], i] = True
            else:
                c_all_times = None  # default assignment (for PyCharm code inspection)
                c_min_times = None  # default assignment (for PyCharm code inspection)
                c_max_times = None  # default assignment (for PyCharm code inspection)

            iteration_nb = 0
            local_max = 0
            numerous_argmax = False
            nb_argmax = n_tm
            best_indices = numpy.zeros(0, dtype=numpy.int32)

            data = b[:n_tm, :]
            flatten_data = data.ravel()

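            # Greedy template matching: repeatedly take the (template, peak) pair
            # with the largest scalar product; accept it if the normalized
            # amplitude lies within its limits and subtract the fitted template's
            # overlaps from neighbouring peaks, otherwise count a failure for
            # that peak (up to total_nb_chances failures per peak).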
            while numpy.mean(failure) < total_nb_chances:

                # Is there a way to update sub_b * mask at the same time?
                if full_gpu:
                    b_array = b.asarray()
                else:
                    b_array = None

                if numerous_argmax:
                    if len(best_indices) == 0:
                        best_indices = largest_indices(flatten_data, nb_argmax)
                    best_template_index, peak_index = numpy.unravel_index(best_indices[0], data.shape)
                else:
                    best_template_index, peak_index = numpy.unravel_index(data.argmax(), data.shape)

                peak_scalar_product = data[best_template_index, peak_index]
                best_template2_index = best_template_index + n_tm

                if templates_normalization:
                    if full_gpu:
                        best_amp = b_array[best_template_index, peak_index] / n_scalar
                        best_amp2 = b_array[best_template2_index, peak_index] / n_scalar
                    else:
                        best_amp = b[best_template_index, peak_index] / n_scalar
                        if two_components:
                            best_amp2 = b[best_template2_index, peak_index] / n_scalar
                        else:
                            best_amp2 = 0.0
                    best_amp_n = best_amp / norm_templates[best_template_index]
                    best_amp2_n = best_amp2 / norm_templates[best_template2_index]
                else:
                    if full_gpu:
                        best_amp = b_array[best_template_index, peak_index]
                        best_amp = best_amp / norm_templates_2[best_template_index]
                        # TODO is `best_amp` value correct?
                        best_amp2 = b_array[best_template2_index, peak_index]
                        best_amp2 = best_amp2 / norm_templates_2[best_template2_index]
                        # TODO is `best_amp2` value correct?
                    else:
                        best_amp = b[best_template_index, peak_index]
                        best_amp = best_amp / norm_templates_2[best_template_index]
                        # TODO is `best_amp` value correct?
                        if two_components:
                            best_amp2 = b[best_template2_index, peak_index]
                            best_amp2 = best_amp2 / norm_templates_2[best_template2_index]
                            # TODO is `best_amp2` value correct?
                        else:
                            best_amp2 = 0.0

                    best_amp_n = best_amp
                    best_amp2_n = best_amp2

                # Verify amplitude constraint.
                a_min, a_max = amp_limits[best_template_index, :]

                if (a_min <= best_amp_n) & (best_amp_n <= a_max):
                    # Keep the matching.
                    peak_time_step = local_peaktimes[peak_index]

                    peak_data = (local_peaktimes - peak_time_step).astype(numpy.int32)
                    is_neighbor = numpy.where(numpy.abs(peak_data) <= temp_2_shift)[0]
                    idx_neighbor = peak_data[is_neighbor] + temp_2_shift
                    nb_neighbors = len(is_neighbor)
                    indices = numpy.zeros((s_over, nb_neighbors), dtype=numpy.int32)
                    indices[idx_neighbor, numpy.arange(nb_neighbors)] = 1

                    if full_gpu:
                        indices = cmt.CUDAMatrix(indices, copy_on_host=False)
                        if patch_gpu:
                            b_lines = b.get_col_slice(0, b.shape[0])
                        else:
                            b_lines = b.get_col_slice(is_neighbor[0], is_neighbor[-1]+1)
                        tmp1 = cmt.sparse_dot(c_overs[best_template_index], indices, mult=-best_amp)
                        tmp2 = cmt.sparse_dot(c_overs[best_template2_index], indices, mult=-best_amp2)
                        b_lines.add(tmp1.add(tmp2))
                        del tmp1, tmp2
                    else:
                        tmp1 = c_overs[best_template_index].multiply(-best_amp)
                        if numpy.abs(best_amp2) > min_second_component:
                            tmp1 += c_overs[best_template2_index].multiply(-best_amp2)
                        b[:, is_neighbor] += tmp1.dot(indices)

                    numerous_argmax = False

                    # Add matching to the result.
                    t_spike = all_spikes[peak_index]

                    if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]):
                        result['spiketimes'] += [t_spike]
                        result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                        result['templates'] += [best_template_index]
                    # Mark current matching as tried.
                    b[best_template_index, peak_index] = -numpy.inf
                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                        result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                        result_debug['peak_scalar_products'] += [peak_scalar_product]
                        result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [True]
                else:
                    # Reject the matching.
                    numerous_argmax = True
                    # Update failure counter of the peak.
                    failure[peak_index] += 1
                    # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted).
                    if failure[peak_index] >= total_nb_chances:
                        # Mark all the matching associated to the current peak as tried.
                        b[:, peak_index] = -numpy.inf
                        index = numpy.arange(n_tm) * nb_local_peak_times + peak_index
                    else:
                        # Mark current matching as tried.
                        b[best_template_index, peak_index] = -numpy.inf
                        index = best_template_index * nb_local_peak_times + peak_index

                    if numerous_argmax:
                        best_indices = best_indices[~numpy.in1d(best_indices, index)]

                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                        result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                        result_debug['peak_scalar_products'] += [peak_scalar_product]
                        result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [False]

                iteration_nb += 1

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

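            # With collect_all, peaks left unexplained by the accepted templates
            # are kept as "garbage" spikes: times already covered by a fitted
            # template (on its neighbouring electrodes) are masked out, and the
            # remaining peaks are assigned to their strongest electrode.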
            if collect_all:

                for temp, spike in zip(templates_to_write, spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times                

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_thresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_thresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_thresholds_neg[bestlecs], matched_thresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())

            if debug:
                # Write debug data to debug files.
                for field_label, field_dtype, field_file in [
                    ('chunk_nbs', numpy.uint32, chunk_nbs_debug_file),
                    ('iteration_nbs', numpy.uint32, iteration_nbs_debug_file),
                    ('peak_nbs', numpy.uint32, peak_nbs_debug_file),
                    ('peak_local_time_steps', numpy.uint32, peak_local_time_steps_debug_file),
                    ('peak_time_steps', numpy.uint32, peak_time_steps_debug_file),
                    ('peak_scalar_products', numpy.float32, peak_scalar_products_debug_file),
                    ('peak_solved_flags', numpy.float32, peak_solved_flags_debug_file),
                    ('template_nbs', numpy.uint32, template_nbs_debug_file),
                    ('success_flags', bool, success_flags_debug_file),
                ]:
                    field_to_write = numpy.array(result_debug[field_label], dtype=field_dtype)
                    field_file.write(field_to_write.tostring())

            if full_gpu:
                del b, data

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()
        
        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    if debug:
        # Close debug files.
        for field_file in [
            chunk_nbs_debug_file,
            iteration_nbs_debug_file,
            peak_nbs_debug_file,
            peak_local_time_steps_debug_file,
            peak_time_steps_debug_file,
            peak_scalar_products_debug_file,
            peak_solved_flags_debug_file,
            template_nbs_debug_file,
            success_flags_debug_file,
        ]:
            field_file.flush()
            os.fsync(field_file.fileno())
            field_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
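
A hedged sketch of how this fitting entry point might be invoked; it assumes the usual SpyKING CIRCUS bootstrap (a parameter object exposing data_file, logfile, the [fitting] section, etc.) has already been performed by the caller:

# Hypothetical invocation under MPI (e.g. mpirun -np 4 ...). Every rank calls
# main() and writes its own .spiketimes-<rank>.data, .amplitudes-<rank>.data and
# .templates-<rank>.data files, which rank 0 merges at the end via io.collect_data().
main(params, nb_cpu=4, nb_gpu=0, use_gpu=False)
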
Example #6
def main(params, nb_cpu, nb_gpu, use_gpu):
    numpy.random.seed(426236)
    # params = detect_memory(params)
    parallel_hdf5 = get_parallel_hdf5_flag(params)
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.extracting')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detection', 'N_t')
    N_total = params.nb_channels
    template_shift = params.getint('detection', 'template_shift')
    chunk_size = detect_memory(params)
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('extracting', 'safety_time')
    max_elts_temp = params.getint('extracting', 'max_elts')
    output_dim = params.getfloat('extracting', 'output_dim')
    noise_thr = params.getfloat('extracting', 'noise_thr')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    blosc_compress = params.getboolean('data', 'blosc_compress')
    tmp_limits = params.get('fitting',
                            'amp_limits').replace('(',
                                                  '').replace(')',
                                                              '').split(',')
    amp_limits = [float(v) for v in tmp_limits]
    elt_count = 0
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    #################################################################

    if comm.rank == 0:
        print_and_log(["Extracting templates from already found clusters..."],
                      'default', logger)

    thresholds = io.load_data(params, 'thresholds')
    basis_proj, basis_rec = io.load_data(params, 'basis')
    clusters, spiketimes, N_clusters = io.load_data(params, 'spike-cluster')
    inv_clusters = numpy.zeros(clusters.max() + 1, dtype=numpy.int32)
    inv_clusters[numpy.unique(clusters)] = numpy.argsort(
        numpy.unique(clusters))

    if use_gpu:
        import cudamat as cmt
        # # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (PyCharm code inspection)

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

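    # Per-cluster accumulators: 'data_tmp_<i>' collects the projected waveforms
    # (one row per accepted peak) and 'times_<i>' the corresponding spike times.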
    result = {}
    for i in range(N_clusters):
        result['data_tmp_' + str(i)] = numpy.zeros(
            (0, N_e * basis_proj.shape[1]), dtype=numpy.float32)
        result['times_' + str(i)] = numpy.zeros(0, dtype=numpy.int32)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # I guess this is more relevant, to take signals from all over the recordings.
    all_chunks = numpy.random.permutation(numpy.arange(nb_chunks))

    nb_templates = numpy.sum(
        comm.rank == numpy.mod(numpy.arange(N_clusters), comm.size))
    nb_elts = max_elts_temp * nb_templates

    to_explore = all_chunks

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for gidx in to_explore:

        if elt_count < nb_elts:
            # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            # print "Extracting the peaks..."
            idx = numpy.where((spiketimes >= gidx * chunk_size)
                              & (spiketimes < (gidx + 1) * chunk_size))[0]
            local_offset = t_offset
            local_peaktimes = spiketimes[idx] - local_offset

            # print "Removing the useless borders..."
            local_borders = (template_shift, chunk_size - template_shift)
            idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                           local_borders[1])
            local_peaktimes = local_peaktimes[idx]
            local_clusters = inv_clusters[clusters[idx]]

            if len(local_peaktimes) > 0:
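                # all_times is an (electrode, time) exclusion mask: once a snippet is taken around
                # a peak, a +/- safety_time window on the neighbouring electrodes is marked as used
                # so that overlapping events are not extracted twice.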
                all_times = numpy.zeros(
                    (N_e, local_peaktimes[-1] - local_peaktimes[0] + 1),
                    dtype=bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    local_peaktimes[-1] - local_peaktimes[0])

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                clusters_id = local_clusters[argmax_peak]
                local_peaktimes = local_peaktimes[argmax_peak]

                # print "Selection of the peaks with spatio-temporal masks..."
                for idx in range(len(local_peaktimes)):

                    if elt_count == nb_elts:
                        break

                    temp = clusters_id[idx]

                    if numpy.mod(temp, comm.size) == comm.rank:

                        elec = numpy.argmin(local_chunk[local_peaktimes[idx]])
                        indices = inv_nodes[edges[nodes[elec]]]
                        myslice = all_times[indices,
                                            min_times[idx]:max_times[idx]]
                        peak = local_peaktimes[idx]
                        if not myslice.any():
                            if len(result['data_tmp_' +
                                          str(temp)]) < max_elts_temp:
                                elt_count += 1
                                sub_mat = local_chunk[peak -
                                                      template_shift:peak +
                                                      template_shift + 1, :]
                                sub_mat = numpy.dot(basis_rec, sub_mat)
                                nx, ny = sub_mat.shape
                                sub_mat = sub_mat.reshape((1, nx * ny))
                                result['data_tmp_' + str(temp)] = numpy.vstack(
                                    (result['data_tmp_' + str(temp)], sub_mat))
                                to_add = numpy.array([peak + local_offset],
                                                     dtype=numpy.int32)
                                result['times_' +
                                       str(temp)] = numpy.concatenate(
                                           (result['times_' + str(temp)],
                                            to_add))
                            all_times[indices,
                                      min_times[idx]:max_times[idx]] = True

    total_nb_elts = 0
    for temp in range(N_clusters):
        total_nb_elts += len(result['data_tmp_' + str(temp)])

    gdata = gather_array(numpy.array([total_nb_elts], dtype=numpy.float32),
                         comm, 0)
    if comm.rank == 0:
        print_and_log([
            "Found %d spikes over %d requested" %
            (int(numpy.sum(gdata)), int(nb_elts))
        ], 'default', logger)

    # print "Spikes extracted in", time.time() - t_start, "s"

    comm.Barrier()

    local_nb_clusters = 0
    for temp in range(comm.rank, N_clusters, comm.size):
        if len(result['data_tmp_' + str(temp)]) > 0:
            local_nb_clusters += 1

    # print total_nb_clusters, "found in", time.time() - t_start, "s"
    gdata3 = gather_array(
        numpy.array([local_nb_clusters], dtype=numpy.float32), comm, 0)

    comm.Barrier()
    if comm.rank == 0:
        print_and_log(["Extracting the templates..."], 'default', logger)

    total_nb_clusters = int(
        comm.bcast(numpy.array([int(numpy.sum(gdata3))], dtype=numpy.int32),
                   root=0)[0])
    offsets = numpy.zeros(comm.size, dtype=numpy.int32)
    for i in range(comm.size - 1):
        offsets[i + 1] = comm.bcast(numpy.array([local_nb_clusters],
                                                dtype=numpy.int32),
                                    root=i)

    if parallel_hdf5:
        node_pad = numpy.sum(offsets[:comm.rank + 1])
        hfile = h5py.File(file_out_suff + '.templates.hdf5',
                          'w',
                          driver='mpio',
                          comm=comm,
                          libver='earliest')
        norms = hfile.create_dataset('norms',
                                     shape=(2 * total_nb_clusters, ),
                                     dtype=numpy.float32,
                                     chunks=True)
        electrodes = hfile.create_dataset('electrodes',
                                          shape=(total_nb_clusters, ),
                                          dtype=numpy.int32,
                                          chunks=True)
        amps_lims = hfile.create_dataset('limits',
                                         shape=(total_nb_clusters, 2),
                                         dtype=numpy.float32,
                                         chunks=True)
        g_count = node_pad
        g_offset = total_nb_clusters
    else:
        node_pad = 0
        hfile = h5py.File(file_out_suff + '.templates-%d.hdf5' % comm.rank,
                          'w',
                          libver='earliest')
        electrodes = hfile.create_dataset('electrodes',
                                          shape=(local_nb_clusters, ),
                                          dtype=numpy.int32,
                                          chunks=True)
        norms = hfile.create_dataset('norms',
                                     shape=(2 * local_nb_clusters, ),
                                     dtype=numpy.float32,
                                     chunks=True)
        amps_lims = hfile.create_dataset('limits',
                                         shape=(local_nb_clusters, 2),
                                         dtype=numpy.float32,
                                         chunks=True)
        g_count = 0
        g_offset = local_nb_clusters

    cfile = h5py.File(file_out_suff + '.clusters-%d.hdf5' % comm.rank,
                      'w',
                      libver='earliest')
    count_templates = node_pad

    temp_x = numpy.zeros(0, dtype=numpy.int32)
    temp_y = numpy.zeros(0, dtype=numpy.int32)
    temp_data = numpy.zeros(0, dtype=numpy.float32)

    to_explore = range(comm.rank, N_clusters, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for temp in to_explore:
        n_data = len(result['data_tmp_' + str(temp)])
        if n_data > 0:
            data = result['data_tmp_' + str(temp)].reshape(
                n_data, basis_proj.shape[1], N_e)
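            # The first component of the template is the point-wise median of the collected
            # snippets, mapped back to the temporal domain through the reconstruction basis.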
            first_component = numpy.median(data, axis=0)
            tmp_templates = numpy.dot(first_component.T, basis_rec)
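            # NOTE: 'tmpidx' (location of the template minimum) and 'shift' (temporal alignment
            # offset) are assumed to be computed in lines omitted from this excerpt, and
            # 'electrodes[-1]' presumably refers to the entry written just above.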
            electrodes[g_count] = indices[tmpidx[0][0]]
            indices = inv_nodes[edges[nodes[electrodes[-1]]]]
            templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            if shift > 0:
                templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                templates[indices, :] = tmp_templates

            templates = templates.flatten()
            dx = templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y,
                 count_templates * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, templates[dx]))

            norms[g_count] = numpy.sqrt(
                numpy.sum(templates.flatten()**2) / (N_e * N_t))

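            # Amplitude of each snippet = least-squares projection onto the first component; the
            # fitted contribution is then subtracted so that the residuals feed the second component.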
            x, y, z = data.shape
            data_flat = data.reshape(x, y * z)
            first_flat = first_component.reshape(y * z, 1)
            amplitudes = numpy.dot(data_flat, first_flat)
            amplitudes /= numpy.sum(first_flat**2)
            for i in range(x):
                data_flat[i, :] -= amplitudes[i] * first_flat[:, 0]

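            # Amplitude limits: median +/- 10 * MAD, floored by a physical limit derived from the
            # detection threshold on the peak electrode and capped by the user-defined upper bound.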
            variations = 10 * numpy.median(
                numpy.abs(amplitudes - numpy.median(amplitudes)))
            physical_limit = noise_thr * (
                -thresholds[indices[tmpidx[0][0]]]) / tmp_templates.min()
            amp_min = max(physical_limit,
                          numpy.median(amplitudes) - variations)
            amp_max = min(amp_limits[1], numpy.median(amplitudes) + variations)
            amps_lims[g_count] = [amp_min, amp_max]

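            # The second component is the first principal component of the residuals (or the
            # normalized residual itself when only a single snippet is available).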
            if len(data_flat) > 1:
                pca = PCA(1)
                res_pca = pca.fit_transform(data_flat.astype(numpy.double))
                second_component = pca.components_.T.astype(
                    numpy.float32).reshape(y, z)
            else:
                second_component = data_flat.reshape(y, z) / numpy.sum(
                    data_flat**2)

            tmp_templates = numpy.dot(second_component.T, basis_rec)
            offset = total_nb_clusters + count_templates
            sub_templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            if shift > 0:
                sub_templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                sub_templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                sub_templates[indices, :] = tmp_templates

            sub_templates = sub_templates.flatten()
            dx = sub_templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y, offset * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, sub_templates[dx]))

            norms[g_count + g_offset] = numpy.sqrt(
                numpy.sum(sub_templates.flatten()**2) / (N_e * N_t))

            count_templates += 1
            g_count += 1

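        # NOTE: 'to_write' (list of dataset names) and 'ielec' (electrode index) are assumed to be
        # defined in parts of the original script not shown in this excerpt.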
        io.write_datasets(cfile,
                          to_write,
                          result,
                          ielec,
                          compress=hdf5_compress)

    # At the end we should have a templates variable to store.
    cfile.close()
    del result, templates, amps_lims
    comm.Barrier()

    # We need to gather the sparse arrays.
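    # temp_x / temp_y / temp_data are the COO triplets (row index, column index, value) of the
    # flattened template matrix of shape (N_e * N_t, 2 * total_nb_clusters).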
    temp_x = gather_array(temp_x, comm, dtype='int32', compress=blosc_compress)
    temp_y = gather_array(temp_y, comm, dtype='int32', compress=blosc_compress)
    temp_data = gather_array(temp_data, comm, compress=blosc_compress)

    if parallel_hdf5:
        if comm.rank == 0:
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in range(comm.size)
            ]
            cfile = h5py.File(file_out_suff + '.clusters.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(cfile, ['electrodes'],
                              {'electrodes': electrodes[:]},
                              compress=hdf5_compress)
            for i in range(comm.size):
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile,
                                      to_write,
                                      rs[i],
                                      j,
                                      compress=hdf5_compress)
                rs[i].close()
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            cfile.close()
        hfile.close()
    else:
        hfile.close()
        if comm.rank == 0:
            ts = [
                h5py.File(file_out_suff + '.templates-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in range(comm.size)
            ]
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in range(comm.size)
            ]
            result = {}
            hfile = h5py.File(file_out_suff + '.templates.hdf5',
                              'w',
                              libver='earliest')
            cfile = h5py.File(file_out_suff + '.clusters.hdf5',
                              'w',
                              libver='earliest')
            electrodes = hfile.create_dataset('electrodes',
                                              shape=(total_nb_clusters, ),
                                              dtype=numpy.int32,
                                              chunks=True)
            norms = hfile.create_dataset('norms',
                                         shape=(2 * total_nb_clusters, ),
                                         dtype=numpy.float32,
                                         chunks=True)
            amplitudes = hfile.create_dataset('limits',
                                              shape=(total_nb_clusters, 2),
                                              dtype=numpy.float32,
                                              chunks=True)
            count = 0
            for i in range(comm.size):
                loc_temp = ts[i].get('templates')
                middle = loc_temp.shape[2] // 2
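                # NOTE: 'loc_norms' (presumably ts[i].get('norms')) and 'n_clusters' (presumably
                # total_nb_clusters) are assumed to be defined in lines omitted from this excerpt.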
                norms[count:count + middle] = loc_norms[:middle]
                norms[n_clusters + count:n_clusters + count +
                      middle] = loc_norms[middle:]
                electrodes[count:count + middle] = ts[i].get('electrodes')
                amplitudes[count:count + middle] = ts[i].get('limits')
                count += middle
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile,
                                      to_write,
                                      rs[i],
                                      j,
                                      compress=hdf5_compress)
                ts[i].close()
                rs[i].close()
                os.remove(file_out_suff + '.templates-%d.hdf5' % i)
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            io.write_datasets(cfile, ['electrodes'],
                              {'electrodes': electrodes[:]},
                              compress=hdf5_compress)
            hfile.close()
            cfile.close()

    if comm.rank == 0:
        hfile = h5py.File(file_out_suff + '.templates.hdf5',
                          'r+',
                          libver='earliest')
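        # Store the gathered sparse template matrix (COO triplets) together with its dense shape.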
        hfile.create_dataset('temp_x', data=temp_x)
        hfile.create_dataset('temp_y', data=temp_y)
        hfile.create_dataset('temp_data', data=temp_data)
        hfile.create_dataset('temp_shape',
                             data=numpy.array(
                                 [N_e, N_t, 2 * total_nb_clusters],
                                 dtype=numpy.int32))
        hfile.close()

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Merging similar templates..."], 'default', logger)

    merged1 = algo.merging_cc(params, parallel_hdf5)

    comm.Barrier()
    if remove_mixture:
        if comm.rank == 0:
            print_and_log(["Removing mixtures..."], 'default', logger)
        merged2 = algo.delete_mixtures(params, parallel_hdf5)
    else:
        merged2 = [0, 0]

    if comm.rank == 0:
        lines = [
            "Number of global merges    : %d" % merged1[1],
            "Number of mixtures removed : %d" % merged2[1],
        ]
        print_and_log(lines, 'info', logger)

    comm.Barrier()
    io.get_overlaps(params, erase=True, parallel_hdf5=parallel_hdf5)

    data_file.close()
Ejemplo n.º 7
0
def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    n_e = params.getint('data', 'N_e')
    n_total = params.nb_channels
    n_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    # file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    # spike_thresh = params.getfloat('detection', 'spike_thresh')
    ratio_thresh = params.getfloat('fitting', 'ratio_thresh')
    two_components = params.getboolean('fitting', 'two_components')
    sparse_threshold = params.getfloat('fitting', 'sparse_thresh')
    # spike_width = params.getfloat('detection', 'spike_width')
    # dist_peaks = params.getint('detection', 'dist_peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    templates_normalization = params.getboolean('clustering', 'templates_normalization')  # TODO test, switch, test!
    chunk_size = detect_memory(params, fitting=True)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = [float(v) for v in tmp_limits]
    amp_auto = params.getboolean('fitting', 'amp_auto')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    # noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    min_second_component = params.getfloat('fitting', 'min_second_component')
    debug = params.getboolean('fitting', 'debug')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(n_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    supports = io.load_data(params, 'supports')
    low_channels_thr = params.getint('detection', 'low_channels_thr')
    median_channels = numpy.median(numpy.sum(supports, 1))
    fixed_amplitudes = params.getboolean('clustering', 'fixed_amplitudes')

    weird_thresh = params.get('detection', 'weird_thresh')
    if weird_thresh != '':
        ignore_artefacts = True
        weird_thresh = io.load_data(params, 'weird-thresholds')
    else:
        ignore_artefacts = False

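    # When amplitudes are allowed to drift over time, the recording is split into time bins and the
    # amplitude limits are later interpolated at the bin centres during fitting.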
    if not fixed_amplitudes:
        nb_amp_bins = params.getint('clustering', 'nb_amp_bins')
        splits = numpy.linspace(0, params.data_file.duration, nb_amp_bins)
        interpolated_times = numpy.zeros(len(splits) - 1, dtype=numpy.float32)
        for count in range(0, len(splits) - 1):
            interpolated_times[count] = (splits[count] + splits[count + 1]) / 2
        interpolated_times = numpy.concatenate(([0], interpolated_times, [params.data_file.duration]))
        nb_amp_times = len(splits) + 1

    mse_error = params.getboolean('fitting', 'mse_error')
    if mse_error:
        stds = io.load_data(params, 'stds')
        stds_norm = numpy.linalg.norm(stds)
    # if median_channels < low_channels_thr:
    #     normalization = False
    #     if comm.rank == 0:
    #         print_and_log(['Templates defined on few channels (%g), turning off normalization' %median_channels], 'debug', logger)

    #################################################################

    if SHARED_MEMORY:
        templates, mpi_memory_1 = io.load_data_memshared(params, 'templates', normalize=templates_normalization, transpose=True, sparse_threshold=sparse_threshold)
        N_tm, x = templates.shape
        is_sparse = not isinstance(templates, numpy.ndarray)
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape
        if N_tm > 0:
            sparsity = templates.nnz / (x * N_tm)
            is_sparse = sparsity < sparse_threshold
        else:
            is_sparse = True
        if not is_sparse:
            if comm.rank == 0:
                print_and_log(['Templates sparsity is low (%g): densified to speed up the algorithm' % sparsity], 'debug', logger)
            templates = templates.toarray()

    temp_2_shift = 2 * template_shift
    temp_3_shift = 3 * template_shift
    n_tm = N_tm // 2
    n_scalar = n_e * n_t

    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = n_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')
    sub_norm_templates = n_scalar * norm_templates[:n_tm]
    if not templates_normalization:
        norm_templates_2 = (norm_templates ** 2.0) * n_scalar
        sub_norm_templates_2 = norm_templates_2[:n_tm]

    if not SHARED_MEMORY:
        # Normalize templates (if necessary).
        if templates_normalization:
            if is_sparse:
                for idx in range(templates.shape[1]):
                    myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1])
                    templates.data[myslice] /= norm_templates[idx]
            else:
                for idx in range(templates.shape[1]):
                    templates[:, idx] /= norm_templates[idx]
        # Transpose templates.
        templates = templates.T

    maxoverlap = io.load_data(params, 'maxoverlap') / n_scalar
    similar = numpy.where(maxoverlap > 0.5)

    idx = similar[0] < similar[1]
    similar = similar[0][idx], similar[1][idx]
    nb_mixtures = len(similar[0])

    waveform_neg = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_neg = None  # default assignment (for PyCharm code inspection)
    waveform_pos = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_pos = None  # default assignment (for PyCharm code inspection)
    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_thresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        if SHARED_MEMORY:
            all_dead_times, mpi_memory_3 = get_dead_times(params)
        else:
            all_dead_times = get_dead_times(params)
    else:
        all_dead_times = None  # default assignment (for PyCharm code inspection)

    thresholds = io.get_accurate_thresholds(params, ratio_thresh)

    neighbors = {}
    if collect_all:
        is_sparse = not isinstance(templates, numpy.ndarray)
        for i in range(0, n_tm):
            if is_sparse:
                tmp = templates[i, :].toarray().reshape(n_e, n_t)
            else:
                tmp = templates[i].reshape(n_e, n_t)
            if templates_normalization:
                tmp = tmp * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, axis=1) != 0.0)[0]

    #N_tm, x = templates.shape
    #sparsity_factor = templates.nnz / (N_tm * x)
    #if sparsity_factor > sparse_threshold:
    #    if comm.rank == 0:
    #        print_and_log(['Templates are not sparse enough, we densify them for'], 'default', logger)
    #    templates = templates.toarray()

    info_string = ''

    if comm.rank == 0:
        info_string = "using %d CPUs" % comm.size

    comm.Barrier()

    c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    n_over = int(numpy.sqrt(over_shape[0]))
    s_over = over_shape[1]
    s_center = s_over // 2
    # # If the number of overlaps is different from templates, we need to recompute them.
    if n_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        n_over = int(numpy.sqrt(over_shape[0]))
        s_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs, mpi_memory_2 = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"], 'default', logger)

        sys.exit(0)

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (for PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (for PyCharm code inspection)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if ignore_artefacts:
        comm.Barrier()
        arte_spiketimes_file = open(file_out_suff + '.times-%d.sata' % comm.rank, 'wb')
        comm.Barrier()
        arte_electrodes_file = open(file_out_suff + '.elec-%d.sata' % comm.rank, 'wb')
        comm.Barrier()
        arte_amplitudes_file = open(file_out_suff + '.amp-%d.sata' % comm.rank, 'wb')
        comm.Barrier()

    if mse_error:
        mse_file = open(file_out_suff + '.mses-%d.data' % comm.rank, 'wb')
        comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()
    else:
        garbage_times_file = None  # default assignment (for PyCharm code inspection)
        garbage_temp_file = None  # default assignment (for PyCharm code inspection)

    if debug:
        # Open debug files.
        chunk_nbs_debug_file = open(file_out_suff + '.chunk_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        iteration_nbs_debug_file = open(file_out_suff + '.iteration_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_nbs_debug_file = open(file_out_suff + '.peak_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_local_time_steps_debug_file = open(
            file_out_suff + '.peak_local_time_steps_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_time_steps_debug_file = open(file_out_suff + '.peak_time_steps_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_scalar_products_debug_file = open(
            file_out_suff + '.peak_scalar_products_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_solved_flags_debug_file = open(file_out_suff + '.peak_solved_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        template_nbs_debug_file = open(file_out_suff + '.template_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        success_flags_debug_file = open(file_out_suff + '.success_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
    else:
        chunk_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        iteration_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_local_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_scalar_products_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_solved_flags_debug_file = None  # default assignment (for PyCharm code inspection)
        template_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        success_flags_debug_file = None  # default assignment (for PyCharm code inspection)

    last_chunk_size = 0
    slice_indices = numpy.zeros(0, dtype=numpy.int32)

    to_explore = list(range(comm.rank, processed_chunks, comm.size))

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

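    # With fixed amplitude limits, convert them once into bounds on the raw scalar products
    # b = templates . snippet, so the validity test in the fitting loop needs no normalization.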
    if fixed_amplitudes:
        min_scalar_products = amp_limits[:,0][:, numpy.newaxis]
        max_scalar_products = amp_limits[:,1][:, numpy.newaxis]

        if templates_normalization:
            min_sps = min_scalar_products * sub_norm_templates[:, numpy.newaxis]
            max_sps = max_scalar_products * sub_norm_templates[:, numpy.newaxis]
        else:
            min_sps = min_scalar_products * sub_norm_templates_2[:, numpy.newaxis]
            max_sps = max_scalar_products * sub_norm_templates_2[:, numpy.newaxis]

    for gcount, gidx in enumerate(to_explore):
        # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        # # We need to deal with the borders by taking chunks of size [0, chunck_size + template_shift].

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if not (is_first and is_last):
            if is_last:
                padding = (-temp_3_shift, 0)
            elif is_first:
                padding = (0, temp_3_shift)
            else:
                padding = (-temp_3_shift, temp_3_shift)
        else:
            padding = (0, 0)

        result = {
            'spiketimes': [],
            'amplitudes': [],
            'templates': [],
        }

        if mse_error:
            mse_fit = {
                'spiketimes': [],
                'amplitudes': [],
                'templates': [],
            }

        result_debug = {
            'chunk_nbs': [],
            'iteration_nbs': [],
            'peak_nbs': [],
            'peak_local_time_steps': [],
            'peak_time_steps': [],
            'peak_scalar_products': [],
            'peak_solved_flags': [],
            'template_nbs': [],
            'success_flags': [],
        }

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)           
        len_chunk = len(local_chunk)
        if is_last:
            my_chunk_size = last_chunk_len  # the final chunk is shorter (see data_file.analyze above)
        else:
            my_chunk_size = chunk_size

        if do_spatial_whitening:
            local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # Extracting peaks.

        all_found_spikes = {}
        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = []

        local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

        if ignore_artefacts:
            artefacts_peaktimes = [numpy.zeros(0, dtype=numpy.uint32)]
            artefacts_elecs = [numpy.zeros(0, dtype=numpy.uint32)]
            artefacts_amps = [numpy.zeros(0, dtype=numpy.float32)]    

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_pos[i])[0]

                    if ignore_artefacts:
                        artetimes = scipy.signal.find_peaks(numpy.abs(filter_chunk[:, i]), height=weird_thresh[i])[0]
                        to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes))
                        peaktimes = peaktimes[to_keep]
                        artefacts_peaktimes.append(artetimes)
                        artefacts_elecs.append(i*numpy.ones(len(artetimes), dtype='uint32'))
                        artefacts_amps.append(local_chunk[artetimes, i])

                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_neg[i])[0]

                    if ignore_artefacts:
                        artetimes = scipy.signal.find_peaks(numpy.abs(filter_chunk[:, i]), height=weird_thresh[i])[0]
                        to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes))
                        peaktimes = peaktimes[to_keep]
                        artefacts_peaktimes.append(artetimes)
                        artefacts_elecs.append(i*numpy.ones(len(artetimes), dtype='uint32'))
                        artefacts_amps.append(local_chunk[artetimes, i])

                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)

            if ignore_artefacts:
                artefacts_peaktimes = numpy.concatenate(artefacts_peaktimes)
                artefacts_elecs = numpy.concatenate(artefacts_elecs)
                artefacts_amps = numpy.concatenate(artefacts_amps)
        else:
            for i in range(n_e):
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i])[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)

                if ignore_artefacts:
                    artetimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=weird_thresh[i])[0]
                    to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes))
                    peaktimes = peaktimes[to_keep]
                    artefacts_peaktimes.append(artetimes)
                    artefacts_elecs.append(i*numpy.ones(len(artetimes), dtype='uint32'))
                    artefacts_amps.append(local_chunk[artetimes, i])

                local_peaktimes.append(peaktimes)
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)

            if ignore_artefacts:
                artefacts_peaktimes = numpy.concatenate(artefacts_peaktimes)
                artefacts_elecs = numpy.concatenate(artefacts_elecs)
                artefacts_amps = numpy.concatenate(artefacts_amps)

        local_peaktimes = numpy.unique(local_peaktimes)
        g_offset = t_offset + padding[0]

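        # Discard peaks (and artefact events) that fall inside user-defined dead periods.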
        if ignore_dead_times:
            dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + my_chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.in1d(local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]

                if ignore_artefacts:
                    is_included = numpy.in1d(artefacts_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]])
                    artefacts_peaktimes = artefacts_peaktimes[~is_included]
                    artefacts_elecs = artefacts_elecs[~is_included]
                    artefacts_amps = artefacts_amps[~is_included]  

                local_peaktimes = numpy.sort(local_peaktimes)
        else:
            dead_indices = None  # default assignment (for PyCharm code inspection)

        # print "Removing the useless borders..."
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if ignore_artefacts:
            artefacts_peaktimes = artefacts_peaktimes + g_offset
            idx = (artefacts_peaktimes >= t_offset) & (artefacts_peaktimes < t_offset + my_chunk_size)
            artefacts_peaktimes = numpy.compress(idx, artefacts_peaktimes)
            artefacts_elecs = numpy.compress(idx, artefacts_elecs)
            artefacts_amps = numpy.compress(idx, artefacts_amps)

        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            all_found_spikes[i] + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]
                        )
                        all_found_spikes[i] = all_found_spikes[i][~is_included]
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        nb_local_peak_times = len(local_peaktimes)

        if nb_local_peak_times > 0:
            # print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."

            if collect_all or mse_error:
                c_local_chunk = local_chunk.copy()
            else:
                c_local_chunk = None  # default assignment (for PyCharm code inspection)

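            # Build a (size_window, nb_peaks) matrix of waveform snippets, one column per detected
            # peak, and compute all template/snippet scalar products with a single dot product.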
            sub_mat = local_chunk[local_peaktimes[:, None] + temp_window]
            sub_mat = sub_mat.transpose(2, 1, 0).reshape(size_window, nb_local_peak_times)

            del local_chunk

            b = templates.dot(sub_mat)
    
            del sub_mat

            local_restriction = (t_offset, t_offset + my_chunk_size)
            all_spikes = local_peaktimes + g_offset

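            # collect_all: keep a boolean (time, electrode) map of every detected peak so that
            # spikes left unexplained by the fitted templates can be written out as garbage events.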
            if collect_all:
                c_all_times = numpy.zeros((len_chunk, n_e), dtype=bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in range(n_e):
                    c_all_times[all_found_spikes[i], i] = True
            else:
                c_all_times = None  # default assignment (for PyCharm code inspection)
                c_min_times = None  # default assignment (for PyCharm code inspection)
                c_max_times = None  # default assignment (for PyCharm code inspection)

            iteration_nb = 0
            data = b[:n_tm, :]

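            # With time-varying amplitude limits, interpolate the per-template bounds at the time
            # of the current chunk before converting them into scalar-product bounds.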
            if not fixed_amplitudes:
                amp_index = numpy.searchsorted(splits, local_restriction[0], 'right')
                scaling = 1/(splits[amp_index] - splits[amp_index - 1])
                min_scalar_products = amp_limits[:, amp_index, 0] + (amp_limits[:, amp_index, 0] - amp_limits[:, amp_index+1, 0])*scaling
                max_scalar_products = amp_limits[:, amp_index, 1] + (amp_limits[:, amp_index, 1] - amp_limits[:, amp_index+1, 1])*scaling
                    
                min_scalar_products = min_scalar_products[:, numpy.newaxis]
                max_scalar_products = max_scalar_products[:, numpy.newaxis]

                if templates_normalization:
                    min_sps = min_scalar_products * sub_norm_templates[:, numpy.newaxis]
                    max_sps = max_scalar_products * sub_norm_templates[:, numpy.newaxis]
                else:
                    min_sps = min_scalar_products * sub_norm_templates_2[:, numpy.newaxis]
                    max_sps = max_scalar_products * sub_norm_templates_2[:, numpy.newaxis]

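            # Greedy template matching (matching pursuit): repeatedly select the (template, peak)
            # pair whose scalar product is maximal among those within the amplitude bounds,
            # subtract its contribution from the scalar products of neighbouring peaks using the
            # precomputed overlaps, and mark the solved pair with -inf so it is never picked again.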
            while True:

                is_valid = (data > min_sps)*(data < max_sps)
                valid_indices = numpy.where(is_valid)

                if len(valid_indices[0]) == 0:
                    break

                best_amplitude_idx = data[is_valid].argmax()

                best_template_index, peak_index = valid_indices[0][best_amplitude_idx], valid_indices[1][best_amplitude_idx]
                peak_scalar_product = data[is_valid][best_amplitude_idx]

                best_template2_index = best_template_index + n_tm
                if templates_normalization:
                    best_amp = b[best_template_index, peak_index] / n_scalar
                    best_amp_n = best_amp / norm_templates[best_template_index]
                    if two_components:
                        best_amp2 = b[best_template2_index, peak_index] / n_scalar
                        best_amp2_n = best_amp2 / norm_templates[best_template2_index]
                    else:
                        best_amp2 = 0
                        best_amp2_n = 0
                else:
                    best_amp = b[best_template_index, peak_index] / norm_templates_2[best_template_index]
                    best_amp_n = best_amp
                    if two_components:     
                        best_amp2 = b[best_template2_index, peak_index] / norm_templates_2[best_template2_index]
                        best_amp2_n = best_amp2
                    else:
                        best_amp2 = 0
                        best_amp2_n = 0

                peak_time_step = local_peaktimes[peak_index]

                peak_data = (local_peaktimes - peak_time_step).astype(numpy.int32)
                is_neighbor = numpy.abs(peak_data) <= temp_2_shift
                idx_neighbor = peak_data[is_neighbor] + temp_2_shift

                tmp1 = c_overs[best_template_index].multiply(-best_amp)
                if numpy.abs(best_amp2_n) > min_second_component:
                    tmp1 += c_overs[best_template2_index].multiply(-best_amp2)

                to_add = tmp1.toarray()[:, idx_neighbor]
                b[:, is_neighbor] += to_add

                b[best_template_index, peak_index] = -numpy.inf

                # Add matching to the result.
                t_spike = all_spikes[peak_index]

                if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]):
                    result['spiketimes'] += [t_spike]
                    result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                    result['templates'] += [best_template_index]
                elif mse_error:
                    mse_fit['spiketimes'] += [t_spike]
                    mse_fit['amplitudes'] += [(best_amp_n, best_amp2_n)]
                    mse_fit['templates'] += [best_template_index]

                # Save debug data.
                if debug:
                    result_debug['chunk_nbs'] += [gidx]
                    result_debug['iteration_nbs'] += [iteration_nb]
                    result_debug['peak_nbs'] += [peak_index]
                    result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                    result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                    result_debug['peak_scalar_products'] += [peak_scalar_product]
                    result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                    result_debug['template_nbs'] += [best_template_index]
                    result_debug['success_flags'] += [True]

                iteration_nb += 1

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32)

            spiketimes_file.write(spikes_to_write.tobytes())
            amplitudes_file.write(amplitudes_to_write.tobytes())
            templates_file.write(templates_to_write.tobytes())

            if ignore_artefacts:
                arte_spiketimes_file.write(artefacts_peaktimes.astype(numpy.uint32).tobytes())
                arte_electrodes_file.write(artefacts_elecs.tobytes())
                arte_amplitudes_file.write(artefacts_amps.tobytes())

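            # Reconstruct the fitted signal from the accepted templates and compare it with the raw
            # chunk to obtain a normalized mean-squared error for this chunk.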
            if mse_error:
                curve = numpy.zeros((len_chunk, n_e), dtype=numpy.float32)
                for spike, temp_id, amplitude in zip(result['spiketimes'], result['templates'], result['amplitudes']):
                    spike = spike - t_offset - padding[0]
                    if is_sparse:
                        tmp1 = templates[temp_id].toarray().reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].toarray().reshape(n_e, n_t)
                    else:
                        tmp1 = templates[temp_id].reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].reshape(n_e, n_t)

                    curve[spike - template_shift:spike + template_shift + 1, :] += (amplitude[0] * tmp1 + amplitude[1] * tmp2).T

                for spike, temp_id, amplitude in zip(mse_fit['spiketimes'], mse_fit['templates'], mse_fit['amplitudes']):
                    spike = spike - t_offset + padding[0]
                    if is_sparse:
                        tmp1 = templates[temp_id].toarray().reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].toarray().reshape(n_e, n_t)
                    else:
                        tmp1 = templates[temp_id].reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].reshape(n_e, n_t)
                    try:
                        curve[int(spike) - template_shift:int(spike) + template_shift + 1, :] += (amplitude[0] * tmp1 + amplitude[1] * tmp2).T
                    except Exception:
                        pass
                mse = numpy.linalg.norm((curve - c_local_chunk)[-padding[0]:-padding[1]])
                nb_points = len(curve) - (padding[1] - padding[0])
                mse_ratio = mse/(numpy.sqrt(nb_points)*stds_norm)
                mse_to_write = numpy.array([g_offset, mse_ratio], dtype=numpy.float32)
                mse_file.write(mse_to_write.tobytes())

            if collect_all:

                for temp, spike in zip(templates_to_write, spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times                

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_thresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_thresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_thresholds_neg[bestlecs], matched_thresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tobytes())
                garbage_temp_file.write(gtemplates_to_write.tobytes())

            if debug:
                # Write debug data to debug files.
                for field_label, field_dtype, field_file in [
                    ('chunk_nbs', numpy.uint32, chunk_nbs_debug_file),
                    ('iteration_nbs', numpy.uint32, iteration_nbs_debug_file),
                    ('peak_nbs', numpy.uint32, peak_nbs_debug_file),
                    ('peak_local_time_steps', numpy.uint32, peak_local_time_steps_debug_file),
                    ('peak_time_steps', numpy.uint32, peak_time_steps_debug_file),
                    ('peak_scalar_products', numpy.float32, peak_scalar_products_debug_file),
                    ('peak_solved_flags', numpy.float32, peak_solved_flags_debug_file),
                    ('template_nbs', numpy.uint32, template_nbs_debug_file),
                    ('success_flags', bool, success_flags_debug_file),
                ]:
                    field_to_write = numpy.array(result_debug[field_label], dtype=field_dtype)
                    field_file.write(field_to_write.tostring())

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()
        
        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    if mse_error:
        mse_file.flush()
        os.fsync(mse_file.fileno())
        mse_file.close()

    if ignore_artefacts:
        arte_spiketimes_file.flush()
        os.fsync(arte_spiketimes_file.fileno())
        arte_spiketimes_file.close()

        arte_electrodes_file.flush()
        os.fsync(arte_electrodes_file.fileno())
        arte_electrodes_file.close()

        arte_amplitudes_file.flush()
        os.fsync(arte_amplitudes_file.fileno())
        arte_amplitudes_file.close()

    if debug:
        # Close debug files.
        for field_file in [
            chunk_nbs_debug_file,
            iteration_nbs_debug_file,
            peak_nbs_debug_file,
            peak_local_time_steps_debug_file,
            peak_time_steps_debug_file,
            peak_scalar_products_debug_file,
            peak_solved_flags_debug_file,
            template_nbs_debug_file,
            success_flags_debug_file,
        ]:
            field_file.flush()
            os.fsync(field_file.fileno())
            field_file.close()

    comm.Barrier()

    if SHARED_MEMORY:
        for memory in mpi_memory_1 + mpi_memory_2:
            memory.Free()
        if ignore_dead_times:
            mpi_memory_3.Free()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

        if ignore_artefacts:
            io.collect_artefacts(comm.size, params, erase=True)

    data_file.close()