Exemplo n.º 1
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Fit the stored templates to the data, chunk by chunk, and write the result.

    Every MPI rank handles the chunks ``rank, rank + size, ...``. For each
    chunk it detects threshold crossings, projects the surrounding windows
    onto the templates and greedily accepts the best template/peak matches
    whose normalised amplitude falls inside the per-template limits. After an
    accepted match, the overlap contribution of the matched template pair is
    subtracted from the scalar products of the neighbouring peaks.

    Args:
        params: configuration object. It is read through ``get``, ``getint``,
            ``getfloat`` and ``getboolean`` and must expose ``logfile``,
            ``data_file`` and ``nb_channels``.
        nb_cpu: number of CPUs, forwarded to ``io.get_overlaps`` and used to
            pick the GPU id.
        nb_gpu: number of GPUs, forwarded to ``io.get_overlaps`` and used to
            pick the GPU id.
        use_gpu: when true, ``cudamat`` is imported and used for the
            projections (and for the overlap updates if ``gpu_only`` is set
            and the overlaps fit on the device).

    Side effects:
        Writes ``<file_out_suff>.spiketimes-<rank>.data``,
        ``.amplitudes-<rank>.data`` and ``.templates-<rank>.data`` (plus
        ``.gspiketimes-<rank>.data`` / ``.gtemplates-<rank>.data`` when
        ``collect_all`` is set), then rank 0 calls ``io.collect_data``.
        Calls ``sys.exit(0)`` when no template is available.
    """

    #################################################################
    #params         = detect_memory(params)
    logger = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    data_file.open()
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('fitting', 'chunk_size')
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    # 'amp_limits' is stored as a "(low, high)" string; strip the parentheses
    # and parse the two bounds.
    tmp_limits = params.get('fitting',
                            'amp_limits').replace('(',
                                                  '').replace(')',
                                                              '').split(',')
    # Materialise a list: the bounds are indexed below, which an iterator
    # returned by map() would not support.
    tmp_limits = [float(value) for value in tmp_limits]
    amp_auto = params.getboolean('fitting', 'amp_auto')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    # The shared-memory loader returns the templates already normalised and
    # transposed; otherwise both steps are done by hand further down.
    if SHARED_MEMORY:
        templates = io.load_data_memshared(params,
                                           'templates',
                                           normalize=True,
                                           transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    temp_2_shift = 2 * template_shift
    full_gpu = use_gpu and gpu_only
    # Templates are stored in pairs: rows [0, n_tm) are matched against the
    # data, row i + n_tm is the second component associated with template i.
    n_tm = N_tm // 2
    n_scalar = N_e * N_t

    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = N_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')

    if not SHARED_MEMORY:
        # Divide every column by its norm, working directly on the sparse
        # storage (indptr/data), then transpose so that templates are rows.
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx],
                                   templates.indptr[idx + 1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))
            matched_tresholds_pos = io.load_data(params,
                                                 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)

    thresholds = io.load_data(params, 'thresholds')

    if collect_all:
        # For every template, the electrodes on which it is non-zero.
        neighbors = {}
        for i in xrange(n_tm):
            tmp = templates[i, :].toarray().reshape(N_e,
                                                    N_t) * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, 1) != 0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string = ''

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % (comm.size)
        else:
            info_string = "using %d CPUs" % (comm.size)

    comm.Barrier()

    c_overlap = io.get_overlaps(params,
                                nb_cpu=nb_cpu,
                                nb_gpu=nb_gpu,
                                use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    N_over = int(numpy.sqrt(over_shape[0]))
    S_over = over_shape[1]
    ## If the number of overlaps is different from templates, we need to recompute them
    if N_over != N_tm:
        if comm.rank == 0:
            print_and_log(
                ['Templates have been modified, recomputing the overlaps...'],
                'default', logger)
        c_overlap = io.get_overlaps(params,
                                    erase=True,
                                    nb_cpu=nb_cpu,
                                    nb_gpu=nb_gpu,
                                    use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"],
                          'default', logger)

        sys.exit(0)

    if comm.rank == 0:
        print_and_log([
            "Here comes the SpyKING CIRCUS %s and %d templates..." %
            (info_string, n_tm)
        ], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i],
                                                  copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log([
                    "Not enough memory on GPUs: GPUs are used for projection only"
                ], 'info', logger)
            for i in xrange(N_over):
                if i in c_overs:
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    # One set of raw binary output files per rank; rank 0 merges them at the
    # end through io.collect_data.
    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank,
                           'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank,
                           'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank,
                          'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(
            file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(
            file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    last_chunk_size = 0

    to_explore = xrange(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):
        #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        ## We need to deal with the borders by taking chunks of size [0, chunck_size+template_shift]

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if is_last:
            padding = (-temp_2_shift, 0)
        elif is_first:
            padding = (0, temp_2_shift)
        else:
            padding = (-temp_2_shift, temp_2_shift)

        result = {'spiketimes': [], 'amplitudes': [], 'templates': []}

        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   padding,
                                                   nodes=nodes)
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                           temporal_whitening,
                                                           axis=0,
                                                           mode='constant')

        #print "Extracting the peaks..."

        if collect_all:
            all_found_spikes = {}
            for i in xrange(N_e):
                all_found_spikes[i] = []

        local_peaktimes = numpy.zeros(0, dtype=numpy.uint32)

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i],
                                                  matched_tresholds_pos[i])
                    local_peaktimes = numpy.concatenate(
                        (local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i],
                                                  matched_tresholds_neg[i])
                    local_peaktimes = numpy.concatenate(
                        (local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                                  thresholds[i],
                                                  valley=False)
                local_peaktimes = numpy.concatenate(
                    (local_peaktimes, peaktimes))
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()

        local_peaktimes = numpy.unique(local_peaktimes)

        # Offset converting chunk-local sample indices into global ones.
        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(
                all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                # Drop the peaks whose global time is listed as dead.
                local_peaktimes = numpy.array(list(
                    set(local_peaktimes + g_offset).difference(
                        all_dead_times[dead_indices[0]:dead_indices[1]])),
                                              dtype=numpy.uint32) - g_offset
                local_peaktimes = numpy.sort(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in xrange(N_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i],
                                                  dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        all_found_spikes[i] = numpy.array(
                            list(
                                set(all_found_spikes[i] + g_offset).difference(
                                    all_dead_times[
                                        dead_indices[0]:dead_indices[1]])),
                            dtype=numpy.uint32) - g_offset
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (
                    all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        n_t = len(local_peaktimes)

        if n_t > 0:
            #print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."

            if collect_all:
                c_local_chunk = local_chunk.copy()

            # Flatten electrode by electrode so that a window around a peak
            # can be gathered with a single numpy.take.
            local_chunk = local_chunk.T.ravel()
            sub_mat = numpy.zeros((size_window, n_t), dtype=numpy.float32)

            # The gather indices only depend on the chunk length; rebuild them
            # only when that length changes.
            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate(
                        (slice_indices, len_chunk * idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk,
                                               slice_indices + idx)

            del local_chunk

            # b[i, j]: scalar product between template row i and the window
            # around peak j.
            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_restriction = (t_offset, t_offset + chunk_size)
            all_spikes = local_peaktimes + g_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            #b           = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            failure = numpy.zeros(n_t, dtype=numpy.int32)

            if full_gpu:
                mask = numpy.zeros((2 * n_tm, n_t), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                data = cmt.empty(mask.shape)
                patch_gpu = b.shape[1] == 1
            else:
                mask = numpy.ones((n_tm, n_t), dtype=numpy.float32)
                sub_b = b[:n_tm, :]

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, N_e), dtype=bool)
                c_min_times = numpy.maximum(
                    numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(
                    numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in xrange(N_e):
                    c_all_times[all_found_spikes[i], i] = True

            # Iterate until every peak has used up its nb_chances failures.
            while (numpy.mean(failure) < nb_chances):

                if full_gpu:
                    gpu_mask = cmt.CUDAMatrix(mask, copy_on_host=False)
                    b.mult(gpu_mask, data)
                    tmp_mat = data.max(0)
                    argmax_bi = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat
                else:
                    data = sub_b * mask
                    argmax_bi = numpy.argsort(numpy.max(data, 0))[::-1]

                # Visit the peaks from the best to the worst masked match.
                for peak_index in argmax_bi:

                    if full_gpu:
                        b_array = b.asarray()
                        sub_b = b_array[:n_tm, :]

                    peak_scalar_products = numpy.take(sub_b,
                                                      peak_index,
                                                      axis=1)
                    best_template_index = numpy.argmax(peak_scalar_products,
                                                       axis=0)
                    best_template2_index = best_template_index + n_tm

                    if full_gpu:
                        best_amp = sub_b[best_template_index,
                                         peak_index] / n_scalar
                        # Second component lives in row i + n_tm, exactly as
                        # in the CPU branch below.
                        best_amp2 = b_array[best_template2_index,
                                            peak_index] / n_scalar
                    else:
                        best_amp = sub_b[best_template_index,
                                         peak_index] / n_scalar
                        best_amp2 = b[best_template2_index,
                                      peak_index] / n_scalar

                    best_amp_n = best_amp / norm_templates[best_template_index]
                    best_amp2_n = best_amp2 / norm_templates[
                        best_template2_index]

                    # Verify amplitude constraint.
                    a_min = amp_limits[best_template_index, 0]
                    a_max = amp_limits[best_template_index, 1]

                    if (a_min <= best_amp_n) & (best_amp_n <= a_max):
                        # Keep the matching.
                        peak_time_step = local_peaktimes[peak_index]

                        # Signed lag of every peak relative to the matched
                        # one. Kept in its own name so the GPU target buffer
                        # `data` is not overwritten.
                        time_lags = (local_peaktimes - peak_time_step).astype(
                            numpy.int32)
                        is_neighbor = numpy.where(
                            numpy.abs(time_lags) <= temp_2_shift)[0]
                        idx_neighbor = time_lags[is_neighbor] + temp_2_shift
                        nb_neighbors = len(is_neighbor)
                        # One-hot selector: column k picks the overlap lag of
                        # the k-th neighbouring peak.
                        indices = numpy.zeros((S_over, nb_neighbors),
                                              dtype=numpy.int32)
                        indices[idx_neighbor, numpy.arange(nb_neighbors)] = 1

                        if full_gpu:
                            indices = cmt.CUDAMatrix(indices,
                                                     copy_on_host=False)
                            if patch_gpu:
                                b_lines = b.get_col_slice(0, b.shape[0])
                            else:
                                b_lines = b.get_col_slice(
                                    is_neighbor[0], is_neighbor[-1] + 1)

                            # best_amp / best_amp2 are scalars: use them
                            # directly as the multiplier.
                            tmp1 = cmt.sparse_dot(c_overs[best_template_index],
                                                  indices,
                                                  mult=-best_amp)
                            tmp2 = cmt.sparse_dot(
                                c_overs[best_template2_index],
                                indices,
                                mult=-best_amp2)
                            b_lines.add(tmp1.add(tmp2))
                            del tmp1, tmp2
                        else:
                            tmp1 = c_overs[best_template_index].multiply(
                                -best_amp)
                            tmp2 = c_overs[best_template2_index].multiply(
                                -best_amp2)
                            b[:, is_neighbor] += (tmp1 + tmp2).dot(indices)
                        # Add matching to the result.
                        t_spike = all_spikes[peak_index]
                        if (t_spike >= local_restriction[0]) and (
                                t_spike < local_restriction[1]):
                            #print "Accept spikes", t_spike, local_restriction, type(t_spike), t_spike > local_restriction[0], t_spike < local_restriction[1]
                            result['spiketimes'] += [t_spike]
                            result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                            result['templates'] += [best_template_index]
                        # Mark current matching as tried.
                        mask[best_template_index, peak_index] = 0
                    else:
                        # Reject the matching.
                        # Update failure counter of the peak.
                        failure[peak_index] += 1
                        # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted).
                        if failure[peak_index] == nb_chances:
                            mask[:, peak_index] = 0
                        else:
                            mask[best_template_index, peak_index] = 0

            spikes_to_write = numpy.array(result['spiketimes'],
                                          dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'],
                                              dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'],
                                             dtype=numpy.uint32)

            spiketimes_file.write(spikes_to_write.tobytes())
            amplitudes_file.write(amplitudes_to_write.tobytes())
            templates_file.write(templates_to_write.tobytes())

            if collect_all:

                # Clear the detections explained by a fitted spike (within
                # template_shift samples, on the template's electrodes); what
                # remains is written out as unfitted ("garbage") spikes.
                for temp, spike in zip(templates_to_write,
                                       spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike],
                                neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes,
                                           axis=0) * c_all_times

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_tresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_tresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(
                            matched_tresholds_neg[bestlecs],
                            matched_tresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset,
                                               dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tobytes())
                garbage_temp_file.write(gtemplates_to_write.tobytes())

            if full_gpu:
                del gpu_mask, b, data

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()

        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    comm.Barrier()

    # Every rank has flushed its files: rank 0 merges them.
    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
Exemplo n.º 2
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    # Part 1: Whitening
    numpy.random.seed(420)

    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    data_file.open()
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    file_out = params.get('data', 'file_out')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    matched_filter = params.getboolean('detection', 'matched-filter')
    matched_thresh = params.getfloat('detection', 'matched_thresh')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('whitening', 'chunk_size')
    plot_path = os.path.join(params.get('data', 'data_file_noext'), 'plots')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_1 = int(20 * params.rate // comm.size)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    all_electrodes = numpy.random.permutation(N_e)

    for gidx in [all_chunks[comm.rank]]:

        #print "Node", comm.rank, "is analyzing chunk", gidx,  "/", nb_chunks, " ..."
        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   nodes=nodes)
        local_shape = len(local_chunk)

        #print "Node", comm.rank, "computes the median absolute deviations in a random chunk"
        thresholds = numpy.zeros(N_e, dtype=numpy.float32)
        for i in xrange(N_e):
            u = numpy.median(local_chunk[:, i], 0)
            thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0)
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            bfile = h5py.File(file_out_suff + '.basis.hdf5',
                              'w',
                              libver='latest')
            io.write_datasets(bfile, ['thresholds'],
                              {'thresholds': thresholds})
            bfile.close()
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        #print "Extracting the peaks..."
        local_peaktimes = numpy.zeros(0, dtype=numpy.int32)
        for i in xrange(N_e):
            peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                          thresholds[i],
                                          valley=False,
                                          mpd=dist_peaks)
            local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))

        local_peaktimes = numpy.unique(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, local_shape - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool)
            min_times = numpy.maximum(
                local_peaktimes - local_peaktimes[0] - safety_time, 0)
            max_times = numpy.minimum(
                local_peaktimes - local_peaktimes[0] + safety_time + 1,
                diff_times)
            argmax_peak = numpy.random.permutation(
                numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            #print "Selection of the peaks with spatio-temporal masks..."
            for idx, peak in zip(argmax_peak, all_idx):
                elec = numpy.argmax(numpy.abs(local_chunk[peak]))
                indices = numpy.take(inv_nodes, edges[nodes[elec]])
                all_times[indices, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool)

    all_times_Ne = numpy.any(all_times, 0)
    subset = numpy.where(all_times_Ne == False)[0]
    all_silences = []

    if do_spatial_whitening:
        local_silences = numpy.take(local_chunk, subset,
                                    axis=0)[:max_silence_1]
        all_silences = gather_array(local_silences, comm, 0, 1)

    local_res = []

    if do_temporal_whitening:

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white,
                                                comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = numpy.take(inv_nodes, edges[nodes[elec]])
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            bound = len(esubset) - N_t
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res)], dtype=numpy.float32)
        local_res = numpy.array(local_res, dtype=numpy.float32)
        if len(local_res) == 0:
            local_res = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res = numpy.sum(local_res, 0)
        all_res = gather_array(local_res.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if comm.rank == 0:

        to_write = {}

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res = all_res.reshape((nb_silences, N_t**2))
            except Exception:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res = numpy.sum(all_res, 0)
            all_res = all_res.reshape((N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            if len(all_silences) / params.rate == 0:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            spatial_whitening = get_whitening_matrix(
                all_silences.astype(numpy.double)).astype(numpy.float32)
            to_write['spatial'] = spatial_whitening
            print_and_log([
                "Found %gs without spikes for whitening matrices..." %
                (len(all_silences) / params.rate)
            ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='latest')
        io.write_datasets(bfile, to_write.keys(), to_write)
        bfile.close()

    del all_silences
    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = numpy.zeros(N_e, dtype=numpy.float32)
            for i in xrange(N_e):
                u = numpy.median(local_chunk[:, i], 0)
                thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u),
                                             0)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                  'r+',
                                  libver='latest')
                bfile.pop('thresholds')
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds})
                bfile.close()
            comm.Barrier()

    #if comm.rank == 0:
    #if not os.path.exists(plot_path):
    #    os.makedirs(plot_path)
    #N_elec = min(int(numpy.sqrt(data_file.N_e)), 5)
    #plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True,
    #              n_elec=N_elec, save=[plot_path, 'electrodes'])

    # Part 2: Basis
    numpy.random.seed(422)

    #################################################################
    file_out = params.get('data', 'file_out')
    alignment = params.getboolean('detection', 'alignment')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    nodes, edges = get_nodes_and_edges(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('data', 'chunk_size')
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    nb_elts = int(
        params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    if sign_peaks == 'both':
        max_elts_elec *= 2
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    groups = {}
    for i in xrange(N_e):
        groups[i] = 0

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)

    chunks_to_load = all_chunks[comm.rank::comm.size]

    thresholds = io.load_data(params, 'thresholds')

    if comm.rank == 0:
        pbar = get_progressbar(nb_elts)

    if alignment:
        cdata = numpy.linspace(-template_shift, template_shift, 5 * N_t)
        xdata = numpy.arange(-2 * template_shift, 2 * template_shift + 1)

    for gcount, gidx in enumerate(chunks_to_load):

        if ((elt_count_pos + elt_count_neg) < nb_elts):
            #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            #print "Extracting the peaks..."
            all_peaktimes = numpy.zeros(0, dtype=numpy.int32)
            all_extremas = numpy.zeros(0, dtype=numpy.int32)

            for i in xrange(N_e):

                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=True,
                                                  mpd=dist_peaks)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=False,
                                                  mpd=dist_peaks)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                                  thresholds[i],
                                                  valley=False,
                                                  mpd=dist_peaks)
                all_peaktimes = numpy.concatenate((all_peaktimes, peaktimes))
                all_extremas = numpy.concatenate(
                    (all_extremas,
                     i * numpy.ones(len(peaktimes), dtype=numpy.int32)))

            #print "Removing the useless borders..."
            if alignment:
                local_borders = (2 * template_shift,
                                 local_shape - 2 * template_shift)
            else:
                local_borders = (template_shift, local_shape - template_shift)
            idx = (all_peaktimes >= local_borders[0]) & (all_peaktimes <
                                                         local_borders[1])
            all_peaktimes = numpy.compress(idx, all_peaktimes)
            all_extremas = numpy.compress(idx, all_extremas)

            local_peaktimes = numpy.unique(all_peaktimes)

            if len(local_peaktimes) > 0:

                diff_times = local_peaktimes[-1] - local_peaktimes[0]
                all_times = numpy.zeros((N_e, diff_times + 1),
                                        dtype=numpy.bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    diff_times)

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                all_idx = numpy.take(local_peaktimes, argmax_peak)

                #print "Selection of the peaks with spatio-temporal masks..."
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    if sign_peaks == 'negative':
                        elec = numpy.argmin(local_chunk[peak])
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        elec = numpy.argmax(local_chunk[peak])
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if numpy.abs(numpy.max(local_chunk[peak])) > numpy.abs(
                                numpy.min(local_chunk[peak])):
                            elec = numpy.argmax(local_chunk[peak])
                            negative_peak = False
                        else:
                            elec = numpy.argmin(local_chunk[peak])
                            negative_peak = True

                    indices = numpy.take(inv_nodes, edges[nodes[elec]])
                    myslice = all_times[indices,
                                        min_times[midx]:max_times[midx]]
                    is_local_extrema = elec in all_extremas[all_peaktimes ==
                                                            peak]
                    if is_local_extrema and not myslice.any():
                        upper_bounds = max_elts_elec

                        if groups[elec] < upper_bounds:

                            if negative_peak:
                                elts_neg[:, elt_count_neg] = local_chunk[
                                    peak - template_shift:peak +
                                    template_shift + 1, elec]
                            else:
                                elts_pos[:, elt_count_pos] = local_chunk[
                                    peak - template_shift:peak +
                                    template_shift + 1, elec]
                            if alignment:
                                ydata = local_chunk[peak -
                                                    2 * template_shift:peak +
                                                    2 * template_shift + 1,
                                                    elec]
                                f = scipy.interpolate.UnivariateSpline(xdata,
                                                                       ydata,
                                                                       s=0)
                                if negative_peak:
                                    rmin = (numpy.argmin(f(cdata)) -
                                            len(cdata) / 2.) / 5.
                                else:
                                    rmin = (numpy.argmax(f(cdata)) -
                                            len(cdata) / 2.) / 5.
                                ddata = numpy.linspace(rmin - template_shift,
                                                       rmin + template_shift,
                                                       N_t)

                                if negative_peak:
                                    elts_neg[:,
                                             elt_count_neg] = f(ddata).astype(
                                                 numpy.float32)
                                else:
                                    elts_pos[:,
                                             elt_count_pos] = f(ddata).astype(
                                                 numpy.float32)

                            if negative_peak:
                                elt_count_neg += 1
                            else:
                                elt_count_pos += 1

                        groups[elec] += 1
                        all_times[indices,
                                  min_times[midx]:max_times[midx]] = True

            if comm.rank == 0:
                pbar.update(elt_count_pos + elt_count_neg)

        if comm.rank == 0:
            if (elt_count_pos + elt_count_neg <
                (gcount + 1) * max_elts_elec // len(chunks_to_load)):
                pbar.update(
                    (gcount + 1) * max_elts_elec // len(chunks_to_load))

    if comm.rank == 0:
        pbar.finish()

    print_and_log([
        "Node %d has collected %d waveforms" %
        (comm.rank, elt_count_pos + elt_count_neg)
    ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    if comm.rank == 0:
        #DO PCA on elts and store the basis obtained.

        nb_waveforms = 0
        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

        print_and_log([
            "Found %d waveforms over %d requested" %
            (nb_waveforms, int(nb_elts * comm.size))
        ], 'default', logger)
        pca = PCA(output_dim, copy=False)
        res = {}
        if sign_peaks in ['negative', 'both']:
            if len(gdata_neg) > 0:
                res_pca = pca.fit_transform(gdata_neg.astype(
                    numpy.double)).astype(numpy.float32)
                res['proj'] = pca.components_.T.astype(numpy.float32)
            else:
                res['proj'] = numpy.identity(N_t, dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_neg.shape[0]))[:1000]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            if len(gdata_pos) > 0:
                res_pca = pca.fit_transform(gdata_pos.astype(
                    numpy.double)).astype(numpy.float32)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
            else:
                res['proj_pos'] = numpy.identity(N_t, dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_pos.shape[0]))[:1000]
            res['waveforms_pos'] = gdata_pos[idx, :]

        bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='latest')
        io.write_datasets(bfile, res.keys(), res)
        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)

        bfile.close()

    comm.Barrier()

    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_neg,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out + '.basis.hdf5',
                                      'r+',
                                      libver='latest')
                    io.write_datasets(bfile, ['matched_thresholds'],
                                      {'matched_thresholds': thresholds})
                    bfile.close()
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_pos,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out + '.basis.hdf5',
                                      'r+',
                                      libver='latest')
                    io.write_datasets(bfile, ['matched_thresholds_pos'],
                                      {'matched_thresholds_pos': thresholds})
                    bfile.close()
                comm.Barrier()

    data_file.close()
Exemplo n.º 3
0
def _load_matched_filter(params, waveform_key, thresholds_key):
    """Load a time-reversed, normalised matched-filter kernel and its thresholds.

    :param params: project parameter object handed to ``io.load_data``.
    :param waveform_key: ``io.load_data`` key of the waveform.
    :param thresholds_key: ``io.load_data`` key of the matching thresholds.
    :return: ``(waveform, thresholds)`` tuple.
    """
    waveform = io.load_data(params, waveform_key)[::-1]
    waveform /= (numpy.abs(numpy.sum(waveform)) * len(waveform))
    thresholds = io.load_data(params, thresholds_key)
    return waveform, thresholds


def _collect_channel_peaks(detection_chunk, amplitude_chunk, heights,
                           n_channels, spike_width, dist_peaks, wlen):
    """Detect peaks on each of the first ``n_channels`` columns of a chunk.

    :param detection_chunk: array indexed as ``[sample, channel]`` on which
        ``scipy.signal.find_peaks`` is run, one column at a time.
    :param amplitude_chunk: array with the same indexing from which the value
        at every detected peak is read.
    :param heights: per-channel ``height`` argument for ``find_peaks``.
    :param n_channels: number of columns to scan.
    :param spike_width: ``width`` argument forwarded to ``find_peaks``.
    :param dist_peaks: ``distance`` argument forwarded to ``find_peaks``.
    :param wlen: ``wlen`` argument forwarded to ``find_peaks``.
    :return: three lists holding one array per channel: peak sample indices,
        the channel index repeated once per peak (``uint32``) and the peak
        amplitudes.
    """
    all_peaktimes = []
    all_elecs = []
    all_amps = []
    for i in range(n_channels):
        peaktimes = scipy.signal.find_peaks(detection_chunk[:, i],
                                            height=heights[i],
                                            width=spike_width,
                                            distance=dist_peaks,
                                            wlen=wlen)[0]
        all_peaktimes.append(peaktimes)
        all_elecs.append(i * numpy.ones(len(peaktimes), dtype='uint32'))
        all_amps.append(amplitude_chunk[peaktimes, i])
    return all_peaktimes, all_elecs, all_amps


def main(params, nb_cpu, nb_gpu, use_gpu):
    """Extract MUA (threshold-crossing) events and write them to disk.

    Every MPI rank walks over the chunks ``comm.rank, comm.rank + comm.size,
    ...``, optionally whitens each chunk, detects peaks on every electrode
    (either on the raw signal or on matched-filtered signals) and appends the
    peak times, electrode indices and amplitudes to three per-rank binary
    files. Once all ranks are done, rank 0 calls ``io.collect_mua``.

    :param params: project parameter object; provides ``data_file``,
        ``logfile`` and the configuration options read below.
    :param nb_cpu: number of CPUs, only used to pick the CUDA device.
    :param nb_gpu: number of GPUs, only used to pick the CUDA device.
    :param use_gpu: if true, the spatial whitening product is computed with
        ``cudamat``.
    :raises ValueError: if matched filtering is disabled and the configured
        peak sign is not ``'negative'``, ``'positive'`` or ``'both'``.
    """

    #################################################################
    _ = init_logging(params.logfile)
    # Result is not needed here; the call itself is kept as in the original.
    _ = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detection', 'N_t')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    dist_peaks = params.getint('detection', 'dist_peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_width = params.getfloat('detection', 'spike_width')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = detect_memory(params)
    nodes, _ = get_nodes_and_edges(params)
    max_chunk = params.getfloat('fitting', 'max_chunk')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    data_file.open()
    #################################################################

    if use_gpu:
        import cudamat as cmt
        # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    negative_filter = None
    positive_filter = None
    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            negative_filter = _load_matched_filter(params, 'waveform',
                                                   'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            positive_filter = _load_matched_filter(params, 'waveform-pos',
                                                   'matched-thresholds-pos')
    # Detection runs with the positive filter first, then the negative one,
    # which fixes the order in which events are written to the output files.
    matched_filters = [
        entry for entry in (positive_filter, negative_filter)
        if entry is not None
    ]

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)

    thresholds = io.load_data(params, 'thresholds')

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Extracting MUA activity..."], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (PyCharm code inspection)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.mua-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    electrodes_file = open(file_out_suff + '.elec-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amp-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    to_explore = range(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gidx in to_explore:
        # Chunks are read with dist_peaks extra samples on each side that has
        # a neighbouring chunk; those margins are trimmed again further down.
        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if is_first and is_last:
            # A single chunk has no neighbour on either side. Falling through
            # to the is_last case would request a negative left padding,
            # i.e. (presumably) samples located before the start of the data.
            padding = (0, 0)
        elif is_last:
            padding = (-dist_peaks, 0)
        elif is_first:
            padding = (0, dist_peaks)
        else:
            padding = (-dist_peaks, dist_peaks)

        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   padding,
                                                   nodes=nodes)
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.convolve1d(local_chunk,
                                                   temporal_whitening,
                                                   axis=0,
                                                   mode='constant')

        # Empty seed arrays keep numpy.concatenate valid (and the dtypes
        # defined) even when nothing is appended below.
        local_peaktimes = [numpy.zeros(0, dtype=numpy.uint32)]
        local_elecs = [numpy.zeros(0, dtype=numpy.uint32)]
        local_amps = [numpy.zeros(0, dtype=numpy.float32)]

        if matched_filter:
            for waveform, matched_thresholds in matched_filters:
                filter_chunk = scipy.ndimage.convolve1d(local_chunk,
                                                        waveform,
                                                        axis=0,
                                                        mode='constant')
                # Peaks and amplitudes are both taken on the filtered signal.
                times, elecs, amps = _collect_channel_peaks(
                    filter_chunk, filter_chunk, matched_thresholds, N_e,
                    spike_width, dist_peaks, N_t)
                local_peaktimes.extend(times)
                local_elecs.extend(elecs)
                local_amps.extend(amps)
        else:
            if sign_peaks == 'negative':
                detection_chunk = -local_chunk
            elif sign_peaks == 'positive':
                detection_chunk = local_chunk
            elif sign_peaks == 'both':
                detection_chunk = numpy.abs(local_chunk)
            else:
                raise ValueError(
                    "Unsupported value for detection/peaks: %r" %
                    (sign_peaks, ))
            # Peaks are searched on the sign-adjusted signal, but amplitudes
            # are read from the (whitened) chunk itself.
            times, elecs, amps = _collect_channel_peaks(
                detection_chunk, local_chunk, thresholds, N_e, spike_width,
                dist_peaks, N_t)
            local_peaktimes.extend(times)
            local_elecs.extend(elecs)
            local_amps.extend(amps)

        local_peaktimes = numpy.concatenate(local_peaktimes)
        local_elecs = numpy.concatenate(local_elecs)
        local_amps = numpy.concatenate(local_amps)

        # Offset that is added to chunk-local indices to build the values
        # compared with the dead times and written to disk.
        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(
                all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.isin(
                    local_peaktimes + g_offset,
                    all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]
                local_elecs = local_elecs[~is_included]
                local_amps = local_amps[~is_included]

        # Drop the peaks lying within dist_peaks samples of either chunk edge.
        local_borders = (dist_peaks, len_chunk - dist_peaks)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes) + g_offset
        local_elecs = numpy.compress(idx, local_elecs)
        local_amps = numpy.compress(idx, local_amps)

        # tobytes() is the supported spelling of the deprecated tostring().
        spiketimes_file.write(local_peaktimes.astype(numpy.uint32).tobytes())
        electrodes_file.write(local_elecs.tobytes())
        amplitudes_file.write(local_amps.tobytes())

    sys.stderr.flush()

    # Force the per-rank files onto disk before rank 0 collects them.
    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    electrodes_file.flush()
    os.fsync(electrodes_file.fileno())
    electrodes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_mua(comm.size, params, erase=True)

    data_file.close()
Exemplo n.º 4
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Fit the templates to the recording chunk by chunk and write the matches to disk.

    Every MPI rank handles the chunks ``range(comm.rank, processed_chunks, comm.size)``.
    For each chunk the function:

    1. loads the data and applies the optional spatial / temporal whitening;
    2. detects peak times on every channel (matched filter or plain thresholds);
    3. computes ``b``, the product of the template matrix with the snippets
       extracted around every peak;
    4. runs a greedy loop that repeatedly takes the largest entry among the first
       ``n_tm`` rows of ``b``, converts it into an amplitude, and either accepts the
       match (amplitude inside ``amp_limits``; the overlap contribution is then
       subtracted from the neighbouring peaks) or records a failure for that peak;
    5. appends the accepted spike times, amplitudes and template indices to the
       per-rank ``*.spiketimes-<rank>.data``, ``*.amplitudes-<rank>.data`` and
       ``*.templates-<rank>.data`` files (plus the "garbage" files when
       ``collect_all`` is set, and the debug files when ``debug`` is set).

    Once all ranks are done, rank 0 calls ``io.collect_data(comm.size, params, erase=True)``.

    Parameters
    ----------
    params
        Project parameter object. This function uses its ``get`` / ``getint`` /
        ``getfloat`` / ``getboolean`` accessors and its ``logfile``, ``data_file`` and
        ``nb_channels`` attributes.
    nb_cpu, nb_gpu
        Used to choose the GPU device id and forwarded to ``io.get_overlaps``.
    use_gpu
        When true, ``cudamat`` is imported and used for the projections.

    Notes
    -----
    Calls ``sys.exit(0)`` when no template is available (``n_tm == 0``).
    """

    #################################################################
    _ = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    n_e = params.getint('data', 'N_e')
    n_total = params.nb_channels
    n_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    ratio_thresh = params.getfloat('fitting', 'ratio_thresh')
    two_components = params.getboolean('fitting', 'two_components')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    templates_normalization = params.getboolean('clustering', 'templates_normalization')  # TODO test, switch, test!
    chunk_size = detect_memory(params, fitting=True)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    # 'amp_limits' is stored as a string such as "(a, b)": strip the parentheses and parse both bounds.
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = [float(v) for v in tmp_limits]
    amp_auto = params.getboolean('fitting', 'amp_auto')
    auto_nb_chances = params.getboolean('fitting', 'auto_nb_chances')
    if auto_nb_chances:
        # Derive the number of chances from the stored 'nb_chances' data: a percentile,
        # clipped to the interval [1, max_nb_chances].
        nb_chances = io.load_data(params, 'nb_chances')
        max_nb_chances = params.getint('fitting', 'max_nb_chances')
        percent_nb_chances = params.getfloat('fitting', 'percent_nb_chances')
        total_nb_chances = max(1, numpy.nanpercentile(nb_chances, percent_nb_chances))
        total_nb_chances = min(total_nb_chances, max_nb_chances)
        if comm.rank == 0:
            print_and_log(['nb_chances set automatically to %g' %total_nb_chances], 'debug', logger)
    else:
        total_nb_chances = params.getfloat('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    collect_all = params.getboolean('fitting', 'collect_all')
    min_second_component = params.getfloat('fitting', 'min_second_component')
    debug = params.getboolean('fitting', 'debug')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    # Map from channel id to its position in `nodes` (not used further in this function).
    inv_nodes = numpy.zeros(n_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    #################################################################

    if use_gpu:
        import cudamat as cmt
        # # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        # The shared-memory loader already normalizes (if requested) and transposes the templates.
        templates, _ = io.load_data_memshared(params, 'templates', normalize=templates_normalization, transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    temp_2_shift = 2 * template_shift
    temp_3_shift = 3 * template_shift
    full_gpu = use_gpu and gpu_only
    # Half of the stored templates: rows [n_tm, 2 * n_tm) hold the "second component" of each template.
    n_tm = N_tm // 2
    n_scalar = n_e * n_t

    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = n_e * (2 * template_shift + 1)

    if not amp_auto:
        # Same fixed amplitude bounds for every template.
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')
    if not templates_normalization:
        # Only needed (and only defined) when the templates are left un-normalized.
        norm_templates_2 = (norm_templates ** 2.0) * n_scalar

    if not SHARED_MEMORY:
        # Normalize templates (if necessary).
        if templates_normalization:
            for idx in range(templates.shape[1]):
                myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1])
                templates.data[myslice] /= norm_templates[idx]
        # Transpose templates.
        templates = templates.T

    waveform_neg = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_neg = None  # default assignment (for PyCharm code inspection)
    waveform_pos = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_pos = None  # default assignment (for PyCharm code inspection)
    if matched_filter:
        # The waveforms are time-reversed so that the convolution below acts as a matched filter.
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_thresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)
    else:
        all_dead_times = None  # default assignment (for PyCharm code inspection)

    thresholds = io.get_accurate_thresholds(params, ratio_thresh)

    # For every template, the channels on which it is non-zero (only needed to collect unfitted spikes).
    neighbors = {}
    if collect_all:
        for i in range(0, n_tm):
            tmp = templates[i, :].toarray().reshape(n_e, n_t)
            if templates_normalization:
                tmp = tmp * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, axis=1) != 0.0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string = ''

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % comm.size
        else:
            info_string = "using %d CPUs" % comm.size

    comm.Barrier()

    c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    n_over = int(numpy.sqrt(over_shape[0]))
    s_over = over_shape[1]
    # # If the number of overlaps is different from templates, we need to recompute them.
    if n_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        n_over = int(numpy.sqrt(over_shape[0]))
        s_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs, _ = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"], 'default', logger)

        sys.exit(0)

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (for PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (for PyCharm code inspection)

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in range(n_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False)
        except Exception:
            # Best effort: fall back to using the GPU for the projection only.
            if comm.rank == 0:
                print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger)
            for i in range(n_over):
                if i in c_overs:
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    # One set of output files per MPI rank; they are merged by rank 0 at the very end.
    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()
    else:
        garbage_times_file = None  # default assignment (for PyCharm code inspection)
        garbage_temp_file = None  # default assignment (for PyCharm code inspection)

    if debug:
        # Open debug files.
        chunk_nbs_debug_file = open(file_out_suff + '.chunk_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        iteration_nbs_debug_file = open(file_out_suff + '.iteration_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_nbs_debug_file = open(file_out_suff + '.peak_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_local_time_steps_debug_file = open(
            file_out_suff + '.peak_local_time_steps_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_time_steps_debug_file = open(file_out_suff + '.peak_time_steps_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_scalar_products_debug_file = open(
            file_out_suff + '.peak_scalar_products_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_solved_flags_debug_file = open(file_out_suff + '.peak_solved_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        template_nbs_debug_file = open(file_out_suff + '.template_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        success_flags_debug_file = open(file_out_suff + '.success_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
    else:
        chunk_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        iteration_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_local_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_scalar_products_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_solved_flags_debug_file = None  # default assignment (for PyCharm code inspection)
        template_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        success_flags_debug_file = None  # default assignment (for PyCharm code inspection)

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    # Chunks are distributed round-robin over the MPI ranks.
    to_explore = range(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for gidx in to_explore:
        # # We need to deal with the borders by taking chunks of size [0, chunck_size + template_shift].

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        # Pad with 3 * template_shift samples on every side that has a neighbouring chunk.
        if not (is_first and is_last):
            if is_last:
                padding = (-temp_3_shift, 0)
            elif is_first:
                padding = (0, temp_3_shift)
            else:
                padding = (-temp_3_shift, temp_3_shift)
        else:
            padding = (0, 0)

        result = {
            'spiketimes': [],
            'amplitudes': [],
            'templates': [],
        }
        result_debug = {
            'chunk_nbs': [],
            'iteration_nbs': [],
            'peak_nbs': [],
            'peak_local_time_steps': [],
            'peak_time_steps': [],
            'peak_scalar_products': [],
            'peak_solved_flags': [],
            'template_nbs': [],
            'success_flags': [],
        }

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # Extracting peaks.

        all_found_spikes = {}
        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = []

        # Seed with an empty array so that the concatenation below works even without any peak.
        local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_pos[i])[0]
                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_neg[i])[0]
                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)
        else:
            for i in range(n_e):
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i])[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)
                local_peaktimes.append(peaktimes)
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)

        # A time step detected on several channels is considered only once.
        local_peaktimes = numpy.unique(local_peaktimes)

        # Offset converting a chunk-local time step into a global one (accounts for the left padding).
        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.in1d(local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]
                local_peaktimes = numpy.sort(local_peaktimes)
        else:
            dead_indices = None  # default assignment (for PyCharm code inspection)

        # Remove the peaks too close to the chunk borders: a full window cannot be extracted around them.
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            # Apply the same dead-time and border filtering to the per-channel peak lists.
            for i in range(n_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            all_found_spikes[i] + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]
                        )
                        all_found_spikes[i] = all_found_spikes[i][~is_included]
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        nb_local_peak_times = len(local_peaktimes)

        if full_gpu:
            # The result is discarded; the call is kept as in the original implementation.
            _ = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False)

        if nb_local_peak_times > 0:

            if collect_all:
                # Keep a copy: `local_chunk` is deleted below but is needed to collect unfitted spikes.
                c_local_chunk = local_chunk.copy()
            else:
                c_local_chunk = None  # default assignment (for PyCharm code inspection)

            # Extract a window of 2 * template_shift + 1 samples around every peak and
            # flatten it to one column per peak (size_window rows).
            sub_mat = local_chunk[local_peaktimes[:, None] + temp_window]
            sub_mat = sub_mat.transpose(2, 1, 0).reshape(size_window, nb_local_peak_times)

            del local_chunk

            # b[i, j]: product of template row i with the snippet around peak j.
            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            # Only spikes falling inside the un-padded chunk are written by this iteration.
            local_restriction = (t_offset, t_offset + chunk_size)
            all_spikes = local_peaktimes + g_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            # b = b.transpose()
            # NOTE(review): in full_gpu mode `b` stays a cudamat object while the loop below
            # indexes it like a NumPy array - confirm this path is supported.
            if use_gpu and not full_gpu:
                b = b.asarray()

            # Number of rejected matches per peak.
            failure = numpy.zeros(nb_local_peak_times, dtype=numpy.int32)

            if full_gpu:
                mask = numpy.zeros((2 * n_tm, nb_local_peak_times), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                _ = cmt.empty(mask.shape)
                patch_gpu = b.shape[1] == 1
            else:
                patch_gpu = None

            if collect_all:
                # `numpy.bool` was removed in NumPy 1.24; `numpy.bool_` is the portable spelling.
                c_all_times = numpy.zeros((len_chunk, n_e), dtype=numpy.bool_)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in range(n_e):
                    c_all_times[all_found_spikes[i], i] = True
            else:
                c_all_times = None  # default assignment (for PyCharm code inspection)
                c_min_times = None  # default assignment (for PyCharm code inspection)
                c_max_times = None  # default assignment (for PyCharm code inspection)

            iteration_nb = 0
            numerous_argmax = False
            nb_argmax = n_tm
            best_indices = numpy.zeros(0, dtype=numpy.int32)

            # Only the first n_tm rows (first components) are searched for the best match.
            data = b[:n_tm, :]
            flatten_data = data.ravel()

            # Greedy matching: stop once the mean number of failures per peak reaches the allowed number.
            while numpy.mean(failure) < total_nb_chances:

                # Is there a way to update sub_b * mask at the same time?
                if full_gpu:
                    b_array = b.asarray()
                else:
                    b_array = None

                if numerous_argmax:
                    # After a rejection, walk through a cached list of the largest entries.
                    if len(best_indices) == 0:
                        best_indices = largest_indices(flatten_data, nb_argmax)
                    best_template_index, peak_index = numpy.unravel_index(best_indices[0], data.shape)
                else:
                    best_template_index, peak_index = numpy.unravel_index(data.argmax(), data.shape)

                peak_scalar_product = data[best_template_index, peak_index]
                best_template2_index = best_template_index + n_tm

                if templates_normalization:
                    if full_gpu:
                        best_amp = b_array[best_template_index, peak_index] / n_scalar
                        best_amp2 = b_array[best_template2_index, peak_index] / n_scalar
                    else:
                        best_amp = b[best_template_index, peak_index] / n_scalar
                        if two_components:
                            best_amp2 = b[best_template2_index, peak_index] / n_scalar
                        else:
                            best_amp2 = 0.0
                    best_amp_n = best_amp / norm_templates[best_template_index]
                    best_amp2_n = best_amp2 / norm_templates[best_template2_index]
                else:
                    if full_gpu:
                        best_amp = b_array[best_template_index, peak_index]
                        best_amp = best_amp / norm_templates_2[best_template_index]
                        # TODO is `best_amp` value correct?
                        best_amp2 = b_array[best_template2_index, peak_index]
                        best_amp2 = best_amp2 / norm_templates_2[best_template2_index]
                        # TODO is `best_amp2` value correct?
                    else:
                        best_amp = b[best_template_index, peak_index]
                        best_amp = best_amp / norm_templates_2[best_template_index]
                        # TODO is `best_amp` value correct?
                        if two_components:
                            best_amp2 = b[best_template2_index, peak_index]
                            best_amp2 = best_amp2 / norm_templates_2[best_template2_index]
                            # TODO is `best_amp2` value correct?
                        else:
                            best_amp2 = 0.0

                    best_amp_n = best_amp
                    best_amp2_n = best_amp2

                # Verify amplitude constraint.
                a_min, a_max = amp_limits[best_template_index, :]

                if (a_min <= best_amp_n) & (best_amp_n <= a_max):
                    # Keep the matching.
                    peak_time_step = local_peaktimes[peak_index]

                    # Peaks within 2 * template_shift of the fitted one are affected by the subtraction.
                    peak_data = (local_peaktimes - peak_time_step).astype(numpy.int32)
                    is_neighbor = numpy.where(numpy.abs(peak_data) <= temp_2_shift)[0]
                    idx_neighbor = peak_data[is_neighbor] + temp_2_shift
                    nb_neighbors = len(is_neighbor)
                    # One-hot selector: column k selects the time lag of the k-th neighbouring peak.
                    indices = numpy.zeros((s_over, nb_neighbors), dtype=numpy.int32)
                    indices[idx_neighbor, numpy.arange(nb_neighbors)] = 1

                    if full_gpu:
                        indices = cmt.CUDAMatrix(indices, copy_on_host=False)
                        if patch_gpu:
                            b_lines = b.get_col_slice(0, b.shape[0])
                        else:
                            b_lines = b.get_col_slice(is_neighbor[0], is_neighbor[-1]+1)
                        tmp1 = cmt.sparse_dot(c_overs[best_template_index], indices, mult=-best_amp)
                        tmp2 = cmt.sparse_dot(c_overs[best_template2_index], indices, mult=-best_amp2)
                        b_lines.add(tmp1.add(tmp2))
                        del tmp1, tmp2
                    else:
                        tmp1 = c_overs[best_template_index].multiply(-best_amp)
                        # The second component is only subtracted when its amplitude is large enough.
                        if numpy.abs(best_amp2) > min_second_component:
                            tmp1 += c_overs[best_template2_index].multiply(-best_amp2)
                        b[:, is_neighbor] += tmp1.dot(indices)

                    numerous_argmax = False

                    # Add matching to the result.
                    t_spike = all_spikes[peak_index]

                    if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]):
                        result['spiketimes'] += [t_spike]
                        result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                        result['templates'] += [best_template_index]
                    # Mark current matching as tried.
                    b[best_template_index, peak_index] = -numpy.inf
                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                        result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                        result_debug['peak_scalar_products'] += [peak_scalar_product]
                        result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [True]
                else:
                    # Reject the matching.
                    numerous_argmax = True
                    # Update failure counter of the peak.
                    failure[peak_index] += 1
                    # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted).
                    if failure[peak_index] >= total_nb_chances:
                        # Mark all the matching associated to the current peak as tried.
                        b[:, peak_index] = -numpy.inf
                        index = numpy.arange(n_tm) * nb_local_peak_times + peak_index
                    else:
                        # Mark current matching as tried.
                        b[best_template_index, peak_index] = -numpy.inf
                        index = best_template_index * nb_local_peak_times + peak_index

                    # Drop the entries just marked as tried from the cached candidates
                    # (`numerous_argmax` is always true at this point).
                    best_indices = best_indices[~numpy.in1d(best_indices, index)]

                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                        result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                        result_debug['peak_scalar_products'] += [peak_scalar_product]
                        result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [False]

                iteration_nb += 1

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32)

            # `tobytes` replaces the deprecated `tostring` alias (same raw bytes).
            spiketimes_file.write(spikes_to_write.tobytes())
            amplitudes_file.write(amplitudes_to_write.tobytes())
            templates_file.write(templates_to_write.tobytes())

            if collect_all:

                # Clear the detected peaks explained by a fitted spike: same time window
                # and channels on which the fitted template is non-zero.
                for temp, spike in zip(templates_to_write, spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                # Remaining time steps still have at least one unexplained peak.
                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_thresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_thresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_thresholds_neg[bestlecs], matched_thresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tobytes())
                garbage_temp_file.write(gtemplates_to_write.tobytes())

            if debug:
                # Write debug data to debug files.
                for field_label, field_dtype, field_file in [
                    ('chunk_nbs', numpy.uint32, chunk_nbs_debug_file),
                    ('iteration_nbs', numpy.uint32, iteration_nbs_debug_file),
                    ('peak_nbs', numpy.uint32, peak_nbs_debug_file),
                    ('peak_local_time_steps', numpy.uint32, peak_local_time_steps_debug_file),
                    ('peak_time_steps', numpy.uint32, peak_time_steps_debug_file),
                    ('peak_scalar_products', numpy.float32, peak_scalar_products_debug_file),
                    ('peak_solved_flags', numpy.float32, peak_solved_flags_debug_file),
                    ('template_nbs', numpy.uint32, template_nbs_debug_file),
                    ('success_flags', numpy.bool_, success_flags_debug_file),
                ]:
                    field_to_write = numpy.array(result_debug[field_label], dtype=field_dtype)
                    field_file.write(field_to_write.tobytes())

            if full_gpu:
                del b, data

    sys.stderr.flush()

    # Flush every per-rank output to disk before rank 0 merges them.
    # The close order matches the original implementation.
    output_files = [spiketimes_file, amplitudes_file, templates_file]
    if collect_all:
        output_files += [garbage_temp_file, garbage_times_file]
    if debug:
        output_files += [
            chunk_nbs_debug_file,
            iteration_nbs_debug_file,
            peak_nbs_debug_file,
            peak_local_time_steps_debug_file,
            peak_time_steps_debug_file,
            peak_scalar_products_debug_file,
            peak_solved_flags_debug_file,
            template_nbs_debug_file,
            success_flags_debug_file,
        ]
    for output_file in output_files:
        output_file.flush()
        os.fsync(output_file.fileno())
        output_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
Exemplo n.º 5
0
def main(filename, params, nb_cpu, nb_gpu, use_gpu):
    """Fit templates to the recording, chunk by chunk, and write the spikes found.

    Each MPI rank handles the chunks ``comm.rank, comm.rank + comm.size, ...``.
    For every chunk the function (optionally) whitens the data, detects peaks,
    projects the snippets around the peaks onto the normalized templates
    (matrix ``b``), and then greedily accepts (template, peak) pairs whose
    normalized amplitude falls inside ``amp_limits``. For every accepted pair
    the scaled overlap of that template with all templates is subtracted from
    the columns of ``b`` belonging to neighbouring peaks.

    Accepted spikes are appended to three rank-local binary files
    (``<file_out_suff>.spiketimes-<rank>.data``, ``.amplitudes-<rank>.data``,
    ``.templates-<rank>.data``). After a barrier, rank 0 calls
    ``io.collect_data`` on them.

    Parameters
    ----------
    filename :
        Not referenced anywhere in this function body.
    params :
        Configuration object queried through ``get``/``getint``/``getfloat``/
        ``getboolean`` and forwarded to the ``io`` helpers.
    nb_cpu, nb_gpu :
        Used to pick the CUDA device id and forwarded to the overlap loaders.
    use_gpu :
        When true, ``cudamat`` is imported and used for the projections (and,
        together with the ``gpu_only`` option, for the whole fitting loop).

    Returns
    -------
    None. All results are written to disk.

    Notes
    -----
    Relies on the module-level names ``comm``, ``io``, ``algo``, ``MPI``,
    ``numpy``, ``scipy``, ``os`` and ``get_progressbar``.
    """

    # Probe for MPI shared-memory windows by allocating (and immediately
    # freeing) a tiny one; fall back to per-rank copies if unsupported.
    try:
        SHARED_MEMORY = True
        MPI.Win.Allocate_shared(1, 1, MPI.INFO_NULL, MPI.COMM_SELF).Free()
    except NotImplementedError:
        SHARED_MEMORY = False

    #################################################################
    sampling_rate = params.getint('data', 'sampling_rate')
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('data', 'N_t')
    N_total = params.getint('data', 'N_total')
    template_shift = params.getint('data', 'template_shift')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    # The 'chunk' option is multiplied by the sampling rate to get a chunk
    # size as an integer count.
    chunk_size = int(params.getfloat('fitting', 'chunk') * sampling_rate)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = io.get_nodes_and_edges(params)
    # 'amp_limits' is stored as a string such as "(a, b)": strip the
    # parentheses and split on the comma.
    tmp_limits = params.get('fitting',
                            'amp_limits').replace('(',
                                                  '').replace(')',
                                                              '').split(',')
    # NOTE(review): tmp_limits is indexed below (tmp_limits[0]), which relies
    # on map() returning a list, i.e. Python 2 semantics.
    tmp_limits = map(float, tmp_limits)
    amp_auto = params.getboolean('fitting', 'amp_auto')
    space_explo = params.getfloat('fitting', 'space_explo')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    # The shared-memory loader is asked for normalized, transposed templates,
    # so its shape is (N_tm, x). The plain loader returns (x, N_tm); that
    # matrix is normalized and transposed by hand further below.
    if SHARED_MEMORY:
        templates = io.load_data_memshared(params,
                                           comm,
                                           'templates',
                                           normalize=True,
                                           transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    # N_e and N_t are re-read here, and template_shift is recomputed from N_t,
    # overriding the value read from params above.
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('data', 'N_t')
    template_shift = int((N_t - 1) // 2)
    temp_2_shift = 2 * template_shift
    full_gpu = use_gpu and gpu_only
    # Only the first half of the template rows is searched over; rows in the
    # second half are addressed as `index + n_tm` alongside their first-half
    # partner (see best_amp2 below).
    n_tm = N_tm // 2
    n_scalar = N_e * N_t
    last_spikes = numpy.zeros((n_tm, 1), dtype=numpy.int32)
    temp_window = numpy.arange(-template_shift, template_shift + 1)

    # Amplitude acceptance bounds per template: either the fixed pair from
    # the configuration, or the 'limits' array loaded through io.load_data.
    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')

    if not SHARED_MEMORY:
        # Divide the stored values of each template by its norm, using indptr
        # to locate them, then transpose so that templates are rows.
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx],
                                   templates.indptr[idx + 1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates)

    info_string = ''

    if matched_filter:
        # Load the matched-filter waveform(s) and their thresholds, and scale
        # each waveform by |sum(waveform)| * len(waveform).
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))
            matched_tresholds_pos = io.load_data(params,
                                                 'matched-thresholds-pos')

    # info_string is only filled in on rank 0, which is the only rank that
    # prints it below.
    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % (comm.size)
        else:
            info_string = "using %d CPUs" % (comm.size)

    comm.Barrier()

    thresholds = io.load_data(params, 'thresholds')

    # Load the template overlaps. In both branches c_overs ends up indexable
    # by template number, N_over is sqrt(over_shape[0]) and S_over is the
    # number of columns of the overlap matrix.
    if SHARED_MEMORY:
        c_overs = io.load_data_memshared(params,
                                         comm,
                                         'overlaps',
                                         nb_cpu=nb_cpu,
                                         nb_gpu=nb_gpu,
                                         use_gpu=use_gpu)
        c_overlap = io.get_overlaps(comm,
                                    params,
                                    nb_cpu=nb_cpu,
                                    nb_gpu=nb_gpu,
                                    use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]
    else:
        c_overlap = io.get_overlaps(comm,
                                    params,
                                    nb_cpu=nb_cpu,
                                    nb_gpu=nb_gpu,
                                    use_gpu=use_gpu)
        over_x = c_overlap.get('over_x')[:]
        over_y = c_overlap.get('over_y')[:]
        over_data = c_overlap.get('over_data')[:]
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]
        c_overlap.close()

        # To be faster, we rearrange the overlaps into a dictionary. This has a cost: twice the memory usage for
        # a short period of time
        c_overs = {}
        overlaps = scipy.sparse.csr_matrix(
            (over_data, (over_x, over_y)),
            shape=(over_shape[0], over_shape[1]))
        del over_x, over_y, over_data

        # c_overs[i] holds the N_over consecutive rows belonging to template i.
        for i in xrange(N_over):
            c_overs[i] = overlaps[i * N_over:(i + 1) * N_over]

        del overlaps

    comm.Barrier()

    if comm.rank == 0:
        io.print_and_log([
            "Here comes the SpyKING CIRCUS %s and %d templates..." %
            (info_string, n_tm)
        ], 'default', params)
        io.purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i])
        except Exception:
            # Any failure is treated as "out of GPU memory": drop the overlap
            # entries and continue with the GPU used for projections only.
            # NOTE(review): after this c_overs has no entries left for the
            # N_over keys, yet it is still indexed in the fitting loop.
            if comm.rank == 0:
                io.print_and_log([
                    "Not enough memory on GPUs: GPUs are used for projection only"
                ], 'info', params)
            for i in xrange(N_over):
                if c_overs.has_key(i):
                    del c_overs[i]
            full_gpu = False

    borders, nb_chunks, chunk_len, last_chunk_len = io.analyze_data(
        params, chunk_size)
    # Cap the number of chunks processed at the 'max_chunk' option.
    nb_chunks = int(min(nb_chunks, max_chunk))

    # The progress bar exists on rank 0 only; every later use is guarded by
    # the same rank test.
    if comm.rank == 0:
        pbar = get_progressbar(int(nb_chunks // comm.size))

    # One set of raw binary output files per rank.
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank,
                           'wb')
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank,
                           'wb')
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank,
                          'wb')

    comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    # Length of the previous chunk for which slice_indices was built; used to
    # rebuild that index table only when the chunk length changes.
    last_chunk_size = 0

    # Round-robin distribution of chunks over the MPI ranks.
    for gcount, gidx in enumerate(xrange(comm.rank, nb_chunks, comm.size)):
        #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        ## We need to deal with the borders by taking chunks of size [0, chunk_size+template_shift]
        # The last chunk gets no trailing padding and the first chunk no
        # leading padding.
        # NOTE(review): when nb_chunks == 1 the single chunk matches the first
        # branch and therefore gets leading padding - confirm this is intended.
        if gidx == (nb_chunks - 1):
            padding = (-2 * borders, 0)
        elif gidx == 0:
            padding = (0, 2 * borders)
        else:
            padding = (-2 * borders, 2 * borders)

        result = {'spiketimes': [], 'amplitudes': [], 'templates': []}

        local_chunk, local_shape = io.load_chunk(params,
                                                 gidx,
                                                 chunk_len,
                                                 chunk_size,
                                                 padding,
                                                 nodes=nodes)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                           temporal_whitening,
                                                           axis=0,
                                                           mode='constant')

        #print "Extracting the peaks..."
        local_peaktimes = numpy.zeros(0, dtype=numpy.int32)

        # Peak detection, one column of the chunk at a time. With the matched
        # filter the chunk is first convolved with the waveform and compared
        # against the matched thresholds; otherwise the data itself is
        # thresholded.
        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i],
                                                  matched_tresholds_pos[i])
                    local_peaktimes = numpy.concatenate(
                        (local_peaktimes, peaktimes))
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i],
                                                  matched_tresholds_neg[i])
                    local_peaktimes = numpy.concatenate(
                        (local_peaktimes, peaktimes))
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                                  thresholds[i],
                                                  valley=False)
                local_peaktimes = numpy.concatenate(
                    (local_peaktimes, peaktimes))

        # Merge the peak times of all columns; numpy.unique also sorts them.
        local_peaktimes = numpy.unique(local_peaktimes)

        #print "Removing the useless borders..."
        # Keep only peaks for which a full window of +/- template_shift
        # samples fits inside the chunk.
        local_borders = (template_shift, local_shape - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)
        n_t = len(local_peaktimes)
        len_chunk = local_chunk.shape[0]
        all_indices = numpy.arange(n_t)

        if full_gpu:
            #   all_indices = cmt.CUDAMatrix(all_indices)
            tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, n_t)),
                                     copy_on_host=False)

        if n_t > 0:
            #print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."

            # Flatten the chunk so that each column occupies len_chunk
            # consecutive entries, then gather one snippet per peak into a
            # column of sub_mat.
            local_chunk = local_chunk.T.ravel()
            sub_mat = numpy.zeros((N_e * (2 * template_shift + 1), n_t),
                                  dtype=numpy.float32)

            # slice_indices holds the flat offsets of a snippet centred at
            # time 0; it only depends on len_chunk, so it is rebuilt only when
            # that length changes between chunks.
            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate(
                        (slice_indices, len_chunk * idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk,
                                               slice_indices + idx)

            del local_chunk

            # b[i, j] is the scalar product of template row i with the
            # snippet around peak j.
            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            # Offset added to chunk-local peak times before they are stored.
            # Note the precedence: only padding[0] is divided by N_total.
            # NOTE(review): presumably padding is expressed in
            # samples * channels - confirm against io.load_chunk.
            local_offset = gidx * chunk_size + padding[0] // N_total
            # Spikes closer than temp_2_shift to a chunk edge are still
            # subtracted but not recorded (see `good` below).
            local_bounds = (temp_2_shift, local_shape - temp_2_shift)
            all_spikes = local_peaktimes + local_offset
            penalty = numpy.ones((n_tm, n_t), dtype=numpy.float32)

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            #b           = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            # Number of rejected fitting attempts per peak.
            failure = numpy.zeros(n_t, dtype=numpy.int32)

            # mask is multiplied element-wise with the scalar products before
            # peaks are ranked; entries set to 0 exclude a (template, peak)
            # pair, or a whole peak column, from that ranking.
            if full_gpu:
                mask = cmt.CUDAMatrix(penalty, copy_on_host=False)
                data = cmt.empty(mask.shape)
                cm_zeros = cmt.CUDAMatrix(numpy.zeros(mask.shape),
                                          copy_on_host=False)
                # b has a single column when there is exactly one peak.
                patch_gpu = b.shape[1] == 1
            else:
                mask = penalty
                # Basic slice: sub_b is a view on the first n_tm rows of b, so
                # in-place updates of b below are visible through sub_b.
                sub_b = b[:n_tm, :]

            # Per-peak time window [min_times, max_times), relative to the
            # first peak, used below to select peaks that do not interfere.
            min_time = local_peaktimes.min()
            max_time = local_peaktimes.max()
            local_len = max_time - min_time + 1
            min_times = numpy.maximum(
                local_peaktimes - min_time - temp_2_shift, 0)
            max_times = numpy.minimum(
                local_peaktimes - min_time + temp_2_shift + 1,
                max_time - min_time)
            # Upper bound on how many peaks are handled in one batch, scaled
            # by the 'space_explo' option.
            max_n_t = int(space_explo * (max_time - min_time + 1) //
                          (2 * temp_2_shift + 1))

            # Iterate until the mean failure count over all peaks reaches
            # nb_chances.
            while (numpy.mean(failure) < nb_chances):

                # Rank the peaks by their best masked scalar product, in
                # descending order.
                if full_gpu:
                    sub_b = b.get_row_slice(0, n_tm)
                    sub_b.mult(mask, data)
                    tmp_mat = data.max(0)
                    argmax_bi = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat, sub_b
                else:
                    data = sub_b * mask
                    argmax_bi = numpy.argsort(numpy.max(data, 0))[::-1]

                while (len(argmax_bi) > 0):

                    # Greedily pick, in ranking order, peaks whose time
                    # windows do not overlap any window already picked in
                    # this batch; stop once more than max_n_t are picked.
                    subset = []
                    indices = []
                    all_times = numpy.zeros(local_len, dtype=numpy.bool)

                    for count, idx in enumerate(argmax_bi):
                        myslice = all_times[min_times[idx]:max_times[idx]]
                        if not myslice.any():
                            subset += [idx]
                            indices += [count]
                            all_times[min_times[idx]:max_times[idx]] = True
                        if len(subset) > max_n_t:
                            break

                    # Remove the picked peaks from the pending ranking.
                    subset = numpy.array(subset, dtype=numpy.int32)
                    argmax_bi = numpy.delete(argmax_bi, indices)

                    # For each picked peak (inds_t), choose the template row
                    # (inds_temp) with the largest scalar product. Note that
                    # this argmax is taken over sub_b, not over the masked
                    # `data` used for the ranking above.
                    if full_gpu:
                        sub_b = b.get_row_slice(0, n_tm)
                        tmp_mat = sub_b.argmax(0)
                        inds_t, inds_temp = subset, tmp_mat.asarray()[
                            0, :][subset].astype(numpy.int32)
                        del tmp_mat
                    else:
                        inds_t, inds_temp = subset, numpy.argmax(
                            numpy.take(sub_b, subset, axis=1), 0)

                    # Read the amplitude for the chosen template row and for
                    # its partner row (index + n_tm), and zero the mask entry
                    # of every tried (template, peak) pair.
                    if full_gpu:
                        best_amp = sub_b.asarray()[inds_temp,
                                                   inds_t] / n_scalar
                        best_amp2 = b.asarray()[inds_temp + n_tm,
                                                inds_t] / n_scalar
                        sub_mask = numpy.ones((sub_b.shape),
                                              dtype=numpy.float32)
                        sub_mask[inds_temp, inds_t] = 0
                        sub_mask = cmt.CUDAMatrix(sub_mask, copy_on_host=False)
                        mask.mult(sub_mask)
                        del sub_mask
                    else:
                        mask[inds_temp, inds_t] = 0
                        best_amp = sub_b[inds_temp, inds_t] / n_scalar
                        best_amp2 = b[inds_temp + n_tm, inds_t] / n_scalar

                    # Amplitudes relative to the template norms; only the
                    # first one is tested against amp_limits.
                    best_amp_n = best_amp / numpy.take(norm_templates,
                                                       inds_temp)
                    best_amp2_n = best_amp2 / numpy.take(
                        norm_templates, inds_temp + n_tm)

                    # to_keep / to_reject are positions within this batch
                    # (i.e. indices into inds_t / inds_temp).
                    all_idx = ((best_amp_n >= amp_limits[inds_temp, 0]) &
                               (best_amp_n <= amp_limits[inds_temp, 1]))
                    to_keep = numpy.where(all_idx == True)[0]
                    to_reject = numpy.where(all_idx == False)[0]
                    ts = numpy.take(local_peaktimes, inds_t[to_keep])
                    good = (ts >= local_bounds[0]) & (ts < local_bounds[1])

                    # We reduce to only the good times that will be kept
                    #to_keep      = to_keep[good]
                    #ts           = ts[good]

                    if len(ts) > 0:
                        # tmp[k, j] = local_peaktimes[j] - ts[k] is the lag of
                        # every peak relative to the k-th accepted spike;
                        # condition marks the lags within +/- temp_2_shift.
                        if full_gpu:
                            tmp = cmt.CUDAMatrix(numpy.ones((len(ts), 1)),
                                                 copy_on_host=False)
                            tmp3 = cmt.CUDAMatrix(-ts.reshape((len(ts), 1)),
                                                  copy_on_host=False)
                            tmp = tmp.dot(tmp_gpu)
                            tmp.add_col_vec(tmp3)
                            condition = cmt.empty(tmp.shape)
                            cmt.abs(tmp, condition).less_than(temp_2_shift + 1)
                            condition = condition.asarray().astype(numpy.bool)
                            tmp = tmp.asarray().astype(numpy.int32)
                        else:
                            tmp = numpy.dot(
                                numpy.ones((len(ts), 1), dtype=numpy.int32),
                                local_peaktimes.reshape((1, n_t)))
                            tmp -= ts.reshape((len(ts), 1))
                            condition = numpy.abs(tmp) <= temp_2_shift

                        for count, keep in enumerate(to_keep):

                            # Columns of b (peaks) affected by this spike, and
                            # their lags shifted to be non-negative so they
                            # can index the overlap columns.
                            idx_b = numpy.compress(condition[count, :],
                                                   all_indices)
                            ytmp = tmp[count,
                                       condition[count, :]] + temp_2_shift

                            # One-hot selector: column j picks overlap column
                            # ytmp[j]. (This rebinds `indices`, which is no
                            # longer needed as the batch position list.)
                            indices = numpy.zeros((S_over, len(ytmp)),
                                                  dtype=numpy.float32)
                            indices[ytmp, numpy.arange(len(ytmp))] = 1

                            # Subtract amplitude * overlap, for the chosen
                            # template and for its partner (index + n_tm),
                            # from the affected columns of b.
                            if full_gpu:
                                indices = cmt.CUDAMatrix(indices,
                                                         copy_on_host=False)
                                # The column slice assumes idx_b is one
                                # contiguous run, which follows from
                                # local_peaktimes being sorted.
                                # NOTE(review): the single-column case slices
                                # up to b.shape[0] rather than b.shape[1] -
                                # confirm against cudamat's get_col_slice.
                                if patch_gpu:
                                    b_lines = b.get_col_slice(0, b.shape[0])
                                else:
                                    b_lines = b.get_col_slice(
                                        idx_b[0], idx_b[-1] + 1)

                                tmp1 = cmt.sparse_dot(c_overs[inds_temp[keep]],
                                                      indices,
                                                      mult=-best_amp[keep])
                                tmp2 = cmt.sparse_dot(c_overs[inds_temp[keep] +
                                                              n_tm],
                                                      indices,
                                                      mult=-best_amp2[keep])
                                b_lines.add(tmp1)
                                b_lines.add(tmp2)
                                del tmp1, tmp2
                            else:
                                tmp1 = c_overs[inds_temp[keep]].multiply(
                                    -best_amp[keep]).dot(indices)
                                tmp2 = c_overs[inds_temp[keep] +
                                               n_tm].multiply(-best_amp2[keep]
                                                              ).dot(indices)
                                b[:, idx_b] += tmp1 + tmp2

                            # Only spikes away from the chunk edges are
                            # recorded, with their time shifted by
                            # local_offset.
                            if good[count]:

                                t_spike = ts[count] + local_offset
                                result['spiketimes'] += [t_spike]
                                result['amplitudes'] += [(best_amp_n[keep],
                                                          best_amp2_n[keep])]
                                result['templates'] += [inds_temp[keep]]

                    # Count a failure for each rejected peak; once a peak has
                    # failed nb_chances times its whole mask column is zeroed.
                    myslice = numpy.take(inds_t, to_reject)
                    failure[myslice] += 1
                    sub_idx = (numpy.take(failure, myslice) >= nb_chances)
                    if full_gpu:
                        N = numpy.sum(sub_idx)
                        if N > 0:
                            cu_slice = cmt.CUDAMatrix(numpy.compress(
                                sub_idx, myslice).reshape(1, N),
                                                      copy_on_host=False)
                            mask.set_selected_columns(cu_slice, cm_zeros)
                            del cu_slice
                    else:
                        mask[:, numpy.compress(sub_idx, myslice)] = 0

                    if full_gpu:
                        del sub_b

            # Append this chunk's results to the rank-local files as raw
            # bytes: uint32 spike times, float32 amplitude pairs, int32
            # template indices.
            spikes_to_write = numpy.array(result['spiketimes'],
                                          dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'],
                                              dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'],
                                             dtype=numpy.int32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if full_gpu:
                del mask, b, cm_zeros, data

        if comm.rank == 0:
            pbar.update(gcount)

    # Flush and fsync each output file before closing it; rank 0 collects
    # the files after the barrier below.
    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    # Wait until every rank has finished writing.
    comm.Barrier()

    if comm.rank == 0:
        pbar.finish()

    # Rank 0 gathers the per-rank files (erase=True is passed to the helper).
    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)
Exemplo n.º 6
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    numpy.random.seed(426236)
    #params         = detect_memory(params)
    parallel_hdf5 = get_parallel_hdf5_flag(params)
    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.extracting')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detecton', 'N_t')
    N_total = params.nb_channels
    template_shift = params.getint('detection', 'template_shift')
    chunk_size = params.getint('data', 'chunk_size')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('extracting', 'safety_time')
    max_elts_temp = params.getint('extracting', 'max_elts')
    output_dim = params.getfloat('extracting', 'output_dim')
    noise_thr = params.getfloat('extracting', 'noise_thr')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    blosc_compress = params.getboolean('data', 'blosc_compress')
    tmp_limits = params.get('fitting',
                            'amp_limits').replace('(',
                                                  '').replace(')',
                                                              '').split(',')
    amp_limits = map(float, tmp_limits)
    elt_count = 0
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if comm.rank == 0:
        print_and_log(["Extracting templates from already found clusters..."],
                      'default', logger)

    thresholds = io.load_data(params, 'thresholds')
    basis_proj, basis_rec = io.load_data(params, 'basis')
    clusters, spiketimes, N_clusters = io.load_data(params, 'spike-cluster')
    inv_clusters = numpy.zeros(clusters.max() + 1, dtype=numpy.int32)
    inv_clusters[numpy.unique(clusters)] = numpy.argsort(
        numpy.unique(clusters))

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    result = {}
    for i in xrange(N_clusters):
        result['data_tmp_' + str(i)] = numpy.zeros(
            (0, N_e * basis_proj.shape[1]), dtype=numpy.float32)
        result['times_' + str(i)] = numpy.zeros(0, dtype=numpy.int32)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(numpy.arange(nb_chunks))

    nb_templates = numpy.sum(
        comm.rank == numpy.mod(numpy.arange(N_clusters), comm.size))
    nb_elts = max_elts_temp * nb_templates

    to_explore = all_chunks

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gidx in all_chunks:

        if (elt_count < nb_elts):
            #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            #print "Extracting the peaks..."
            idx = numpy.where((spiketimes >= gidx * chunk_size)
                              & (spiketimes < (gidx + 1) * chunk_size))[0]
            local_offset = t_offset
            local_peaktimes = spiketimes[idx] - local_offset

            #print "Removing the useless borders..."
            local_borders = (template_shift, chunk_size - template_shift)
            idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                           local_borders[1])
            local_peaktimes = local_peaktimes[idx]
            local_clusters = inv_clusters[clusters[idx]]

            if len(local_peaktimes) > 0:
                all_times = numpy.zeros(
                    (N_e, local_peaktimes[-1] - local_peaktimes[0] + 1),
                    dtype=numpy.bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    local_peaktimes[-1] - local_peaktimes[0])

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                clusters_id = local_clusters[argmax_peak]
                local_peaktimes = local_peaktimes[argmax_peak]

                #print "Selection of the peaks with spatio-temporal masks..."
                for idx in xrange(len(local_peaktimes)):

                    if elt_count == nb_elts:
                        break

                    temp = clusters_id[idx]

                    if numpy.mod(temp, comm.size) == comm.rank:

                        elec = numpy.argmin(local_chunk[local_peaktimes[idx]])
                        indices = inv_nodes[edges[nodes[elec]]]
                        myslice = all_times[indices,
                                            min_times[idx]:max_times[idx]]
                        peak = local_peaktimes[idx]
                        if not myslice.any():
                            if (len(result['data_tmp_' + str(temp)]) <
                                    max_elts_temp):
                                elt_count += 1
                                sub_mat = local_chunk[peak -
                                                      template_shift:peak +
                                                      template_shift + 1, :]
                                sub_mat = numpy.dot(basis_rec, sub_mat)
                                nx, ny = sub_mat.shape
                                sub_mat = sub_mat.reshape((1, nx * ny))
                                result['data_tmp_' + str(temp)] = numpy.vstack(
                                    (result['data_tmp_' + str(temp)], sub_mat))
                                to_add = numpy.array([peak + local_offset],
                                                     dtype=numpy.int32)
                                result['times_' +
                                       str(temp)] = numpy.concatenate(
                                           (result['times_' + str(temp)],
                                            to_add))
                            all_times[indices,
                                      min_times[idx]:max_times[idx]] = True

    total_nb_elts = 0
    for temp in xrange(N_clusters):
        total_nb_elts += len(result['data_tmp_' + str(temp)])

    gdata = gather_array(numpy.array([total_nb_elts], dtype=numpy.float32),
                         comm, 0)
    if comm.rank == 0:
        print_and_log([
            "Found %d spikes over %d requested" %
            (int(numpy.sum(gdata)), int(nb_elts))
        ], 'default', logger)

    #print "Spikes extracted in", time.time() - t_start, "s"

    comm.Barrier()

    local_nb_clusters = 0
    for temp in xrange(comm.rank, N_clusters, comm.size):
        if len(result['data_tmp_' + str(temp)]) > 0:
            local_nb_clusters += 1

    #print total_nb_clusters, "found in", time.time() - t_start, "s"
    gdata3 = gather_array(
        numpy.array([local_nb_clusters], dtype=numpy.float32), comm, 0)

    comm.Barrier()
    if comm.rank == 0:
        print_and_log(["Extracting the templates..."], 'default', logger)

    total_nb_clusters = int(
        comm.bcast(numpy.array([int(numpy.sum(gdata3))], dtype=numpy.int32),
                   root=0)[0])
    offsets = numpy.zeros(comm.size, dtype=numpy.int32)
    for i in xrange(comm.size - 1):
        offsets[i + 1] = comm.bcast(numpy.array([local_nb_clusters],
                                                dtype=numpy.int32),
                                    root=i)

    if parallel_hdf5:
        node_pad = numpy.sum(offsets[:comm.rank + 1])
        hfile = h5py.File(file_out_suff + '.templates.hdf5',
                          'w',
                          driver='mpio',
                          comm=comm,
                          libver='earliest')
        norms = hfile.create_dataset('norms',
                                     shape=(2 * total_nb_clusters, ),
                                     dtype=numpy.float32,
                                     chunks=True)
        electrodes = hfile.create_dataset('electrodes',
                                          shape=(total_nb_clusters, ),
                                          dtype=numpy.int32,
                                          chunks=True)
        amps_lims = hfile.create_dataset('limits',
                                         shape=(total_nb_clusters, 2),
                                         dtype=numpy.float32,
                                         chunks=True)
        g_count = node_pad
        g_offset = total_nb_clusters
    else:
        node_pad = 0
        hfile = h5py.File(file_out_suff + '.templates-%d.hdf5' % comm.rank,
                          'w',
                          libver='earliest')
        electrodes = hfile.create_dataset('electrodes',
                                          shape=(local_nb_clusters, ),
                                          dtype=numpy.int32,
                                          chunks=True)
        norms = hfile.create_dataset('norms',
                                     shape=(2 * local_nb_clusters, ),
                                     dtype=numpy.float32,
                                     chunks=True)
        amps_lims = hfile.create_dataset('limits',
                                         shape=(local_nb_clusters, 2),
                                         dtype=numpy.float32,
                                         chunks=True)
        g_count = 0
        g_offset = local_nb_clusters

    cfile = h5py.File(file_out_suff + '.clusters-%d.hdf5' % comm.rank,
                      'w',
                      libver='earliest')
    count_templates = node_pad

    temp_x = numpy.zeros(0, dtype=numpy.int32)
    temp_y = numpy.zeros(0, dtype=numpy.int32)
    temp_data = numpy.zeros(0, dtype=numpy.float32)

    to_explore = xrange(comm.rank, N_clusters, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for temp in to_explore:
        n_data = len(result['data_tmp_' + str(temp)])
        if n_data > 0:
            data = result['data_tmp_' + str(temp)].reshape(
                n_data, basis_proj.shape[1], N_e)
            first_component = numpy.median(data, axis=0)
            tmp_templates = numpy.dot(first_component.T, basis_rec)
            electrodes[g_count] = indices[tmpidx[0][0]]
            indices = inv_nodes[edges[nodes[electrodes[-1]]]]
            templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            if shift > 0:
                templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                templates[indices, :] = tmp_templates

            templates = templates.flatten()
            dx = templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y,
                 count_templates * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, templates[dx]))

            norms[g_count] = numpy.sqrt(
                numpy.sum(templates.flatten()**2) / (N_e * N_t))

            x, y, z = data.shape
            data_flat = data.reshape(x, y * z)
            first_flat = first_component.reshape(y * z, 1)
            amplitudes = numpy.dot(data_flat, first_flat)
            amplitudes /= numpy.sum(first_flat**2)
            for i in xrange(x):
                data_flat[i, :] -= amplitudes[i] * first_flat[:, 0]

            variations = 10 * numpy.median(
                numpy.abs(amplitudes - numpy.median(amplitudes)))
            physical_limit = noise_thr * (
                -thresholds[indices[tmpidx[0][0]]]) / tmp_templates.min()
            amp_min = max(physical_limit,
                          numpy.median(amplitudes) - variations)
            amp_max = min(amp_limits[1], numpy.median(amplitudes) + variations)
            amps_lims[g_count] = [amp_min, amp_max]

            if len(data_flat) > 1:
                pca = PCA(1)
                res_pca = pca.fit_transform(data_flat.astype(numpy.double))
                second_component = pca.components_.T.astype(
                    numpy.float32).reshape(y, z)
            else:
                second_component = data_flat.reshape(y, z) / numpy.sum(
                    data_flat**2)

            tmp_templates = numpy.dot(second_component.T, basis_rec)
            offset = total_nb_clusters + count_templates
            sub_templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            if shift > 0:
                sub_templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                sub_templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                sub_templates[indices, :] = tmp_templates

            sub_templates = sub_templates.flatten()
            dx = sub_templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y, offset * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, sub_templates[dx]))

            norms[g_count + g_offset] = numpy.sqrt(
                numpy.sum(sub_templates.flatten()**2) / (N_e * N_t))

            count_templates += 1
            g_count += 1

        io.write_datasets(cfile,
                          to_write,
                          result,
                          ielec,
                          compress=hdf5_compress)

    #At the end we should have a templates variable to store.
    cfile.close()
    del result, templates, amps_lims
    comm.Barrier()

    #We need to gather the sparse arrays
    temp_x = gather_array(temp_x, comm, dtype='int32', compress=blosc_compress)
    temp_y = gather_array(temp_y, comm, dtype='int32', compress=blosc_compress)
    temp_data = gather_array(temp_data, comm, compress=blosc_compress)

    if parallel_hdf5:
        if comm.rank == 0:
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in xrange(comm.size)
            ]
            cfile = h5py.File(file_out_suff + '.clusters.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(cfile, ['electrodes'],
                              {'electrodes': electrodes[:]},
                              compress=hdf5_compress)
            for i in xrange(comm.size):
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile,
                                      to_write,
                                      rs[i],
                                      j,
                                      compress=hdf5_compress)
                rs[i].close()
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            cfile.close()
        hfile.close()
    else:
        hfile.close()
        if comm.rank == 0:
            ts = [
                h5py.File(file_out_suff + '.templates-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in xrange(comm.size)
            ]
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in xrange(comm.size)
            ]
            result = {}
            hfile = h5py.File(file_out_suff + '.templates.hdf5',
                              'w',
                              libver='earliest')
            cfile = h5py.File(file_out_suff + '.clusters.hdf5',
                              'w',
                              libver='earliest')
            electrodes = hfile.create_dataset('electrodes',
                                              shape=(total_nb_clusters, ),
                                              dtype=numpy.int32,
                                              chunks=True)
            norms = hfile.create_dataset('norms',
                                         shape=(2 * total_nb_clusters, ),
                                         dtype=numpy.float32,
                                         chunks=True)
            amplitudes = hfile.create_dataset('limits',
                                              shape=(total_nb_clusters, 2),
                                              dtype=numpy.float32,
                                              chunks=True)
            count = 0
            for i in xrange(comm.size):
                loc_temp = ts[i].get('templates')
                middle = loc_temp.shape[2] // 2
                norms[count:count + middle] = loc_norms[:middle]
                norms[n_clusters + count:n_clusters + count +
                      middle] = loc_norms[middle:]
                electrodes[count:count + middle] = ts[i].get('electrodes')
                amplitudes[count:count + middle] = ts[i].get('limits')
                count += middle
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile,
                                      to_write,
                                      rs[i],
                                      j,
                                      compress=hdf5_compress)
                ts[i].close()
                rs[i].close()
                os.remove(file_out_suff + '.templates-%d.hdf5' % i)
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            io.write_datasets(cfile, ['electrodes'],
                              {'electrodes': electrodes[:]},
                              compress=hdf5_compress)
            hfile.close()
            cfile.close()

    if comm.rank == 0:
        hfile = h5py.File(file_out_suff + '.templates.hdf5',
                          'r+',
                          libver='earliest')
        hfile.create_dataset('temp_x', data=temp_x)
        hfile.create_dataset('temp_y', data=temp_y)
        hfile.create_dataset('temp_data', data=temp_data)
        hfile.create_dataset('temp_shape',
                             data=numpy.array(
                                 [N_e, N_t, 2 * total_nb_clusters],
                                 dtype=numpy.int32))
        hfile.close()

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Merging similar templates..."], 'default', logger)

    merged1 = algo.merging_cc(params, parallel_hdf5)

    comm.Barrier()
    if remove_mixture:
        if comm.rank == 0:
            print_and_log(["Removing mixtures..."], 'default', logger)
        merged2 = algo.delete_mixtures(params, parallel_hdf5)
    else:
        merged2 = [0, 0]

    if comm.rank == 0:
        print_and_log([
            "Number of global merges    : %d" % merged1[1],
            "Number of mixtures removed : %d" % merged2[1]
        ], 'info', logger)

    comm.Barrier()
    io.get_overlaps(params, erase=True, parallel_hdf5=parallel_hdf5)

    data_file.close()
Exemplo n.º 7
0
def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    logger         = init_logging(params.logfile)
    logger         = logging.getLogger('circus.fitting')
    data_file      = params.data_file
    data_file.open()
    N_e            = params.getint('data', 'N_e')
    N_total        = params.nb_channels
    N_t            = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out       = params.get('data', 'file_out')
    file_out_suff  = params.get('data', 'file_out_suff')
    sign_peaks     = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh   = params.getfloat('detection', 'spike_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening  = params.getboolean('whitening', 'spatial')
    chunk_size     = params.getint('fitting', 'chunk_size')
    gpu_only       = params.getboolean('fitting', 'gpu_only')
    nodes, edges   = get_nodes_and_edges(params)
    tmp_limits     = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits     = map(float, tmp_limits)
    amp_auto       = params.getboolean('fitting', 'amp_auto')
    space_explo    = params.getfloat('fitting', 'space_explo')
    nb_chances     = params.getint('fitting', 'nb_chances')
    max_chunk      = params.getfloat('fitting', 'max_chunk')
    noise_thr      = params.getfloat('clustering', 'noise_thr')
    collect_all    = params.getboolean('fitting', 'collect_all')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes         = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes]  = numpy.argsort(nodes)
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank//nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates  = io.load_data_memshared(params, 'templates', normalize=True, transpose=True)
        N_tm, x    = templates.shape
    else:
        templates  = io.load_data(params, 'templates')
        x, N_tm    = templates.shape

    temp_2_shift   = 2*template_shift
    full_gpu       = use_gpu and gpu_only
    n_tm           = N_tm//2
    n_scalar       = N_e*N_t
    last_spikes    = numpy.zeros((n_tm, 1), dtype=numpy.int32)
    temp_window    = numpy.arange(-template_shift, template_shift+1)

    if not amp_auto:
        amp_limits       = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits       = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')
    
    if not SHARED_MEMORY:
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg  = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg))* len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos  = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos))* len(waveform_pos))
            matched_tresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        dead_times = numpy.loadtxt(params.get('triggers', 'dead_file'))
        if len(dead_times.shape) == 1:
            dead_times = dead_times.reshape(1, 2)
        dead_in_ms = params.getboolean('triggers', 'dead_in_ms')
        if dead_in_ms:
            dead_times *= numpy.int64(data_file.sampling_rate*1e-3)
        dead_times = dead_times.astype(numpy.int64)
        all_dead_times = []
        for i in xrange(len(dead_times)):
            all_dead_times += range(dead_times[i, 0], dead_times[i, 1])

    thresholds = io.load_data(params, 'thresholds')


    if collect_all:
        neighbors = {}
        for i in xrange(n_tm):
            tmp  = templates[i, :].toarray().reshape(N_e, N_t) * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, 1) != 0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string   = ''

    
    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" %(comm.size)
        else:
            info_string = "using %d CPUs" %(comm.size)

    comm.Barrier()

    c_overlap  = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    N_over     = int(numpy.sqrt(over_shape[0]))
    S_over     = over_shape[1]
    ## If the number of overlaps is different from templates, we need to recompute them
    if N_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap  = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over     = int(numpy.sqrt(over_shape[0]))
        S_over     = over_shape[1]

    if SHARED_MEMORY:
        c_overs    = io.load_data_memshared(params, 'overlaps', nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    else:
        c_overlap  = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_x     = c_overlap.get('over_x')[:]
        over_y     = c_overlap.get('over_y')[:]
        over_data  = c_overlap.get('over_data')[:]
        over_shape = c_overlap.get('over_shape')[:]
        c_overlap.close()

        # To be faster, we rearrange the overlaps into a dictionary. This has a cost: twice the memory usage for
        # a short period of time
        c_overs   = {}
        overlaps  = scipy.sparse.csr_matrix((over_data, (over_x, over_y)), shape=(over_shape[0], over_shape[1]))
        del over_x, over_y, over_data
        
        for i in xrange(N_over):
            c_overs[i] = overlaps[i*N_over:(i+1)*N_over]
        del overlaps

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." %(info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening  = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger)
            for i in xrange(N_over):
                if c_overs.has_key(i):
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks          = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' %comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' %comm.rank, 'wb')
    comm.Barrier()
    templates_file  = open(file_out_suff + '.templates-%d.data' %comm.rank, 'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' %comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file  = open(file_out_suff + '.gtemplates-%d.data' %comm.rank, 'wb')
        comm.Barrier()


    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    last_chunk_size = 0

    to_explore = xrange(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):
        #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        ## We need to deal with the borders by taking chunks of size [0, chunk_size+template_shift]

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last  = data_file.is_last_chunk(gidx, nb_chunks)

        if is_last:
            padding = (-2*template_shift, 0)
        elif is_first:
            padding = (0, 2*template_shift)
        else:
            padding = (-2*template_shift, 2*template_shift)

        result       = {'spiketimes' : [], 'amplitudes' : [], 'templates' : []}

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)           
        len_chunk             = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        #print "Extracting the peaks..."

        if collect_all:
            all_found_spikes = {}
            for i in xrange(N_e):
                all_found_spikes[i] = []

        local_peaktimes = numpy.zeros(0, dtype=numpy.int32)

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_pos[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_neg[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]), thresholds[i], valley=False)                    
                local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes)) 
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()


            
        local_peaktimes = numpy.unique(local_peaktimes)

        if ignore_dead_times:
            local_peaktimes = numpy.array(list(set(local_peaktimes + t_offset).difference(all_dead_times)), dtype=numpy.int32) - t_offset
            local_peaktimes = numpy.sort(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders   = (template_shift, len_chunk - template_shift)
        idx             = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in xrange(N_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.int32)

                if ignore_dead_times:
                    all_found_spikes[i] = numpy.array(list(set(all_found_spikes[i] + t_offset).difference(all_dead_times)), dtype=numpy.int32) - t_offset
                    all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx                 = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        n_t             = len(local_peaktimes)
        all_indices     = numpy.arange(n_t)
                            

        if full_gpu:
        #   all_indices = cmt.CUDAMatrix(all_indices)
            tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, n_t)), copy_on_host=False)


        if n_t > 0:
            #print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."     

            if collect_all:
                c_local_chunk = local_chunk.copy()

            local_chunk = local_chunk.T.ravel()
            sub_mat     = numpy.zeros((N_e*(2*template_shift+1), n_t), dtype=numpy.float32)

            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate((slice_indices, len_chunk*idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk, slice_indices + idx)

            del local_chunk

            if use_gpu: 
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b       = cmt.sparse_dot(templates, sub_mat)
            else:
                b       = templates.dot(sub_mat)                

            del sub_mat

            local_offset = padding[0] + t_offset
            local_bounds = (temp_2_shift, len_chunk - temp_2_shift)
            all_spikes   = local_peaktimes + local_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            #b           = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            failure     = numpy.zeros(n_t, dtype=numpy.int32)

            if full_gpu:
                mask     = numpy.zeros((2*n_tm, n_t), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                data     = cmt.empty(mask.shape)
                patch_gpu= b.shape[1] == 1
            else:
                mask     = numpy.ones((n_tm, n_t), dtype=numpy.float32)
                sub_b    = b[:n_tm, :]

            min_time     = local_peaktimes.min()
            max_time     = local_peaktimes.max()
            local_len    = max_time - min_time + 1
            min_times    = numpy.maximum(local_peaktimes - min_time - temp_2_shift, 0)
            max_times    = numpy.minimum(local_peaktimes - min_time + temp_2_shift + 1, max_time - min_time)
            max_n_t      = int(space_explo*(max_time-min_time+1)//(2*temp_2_shift + 1))

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, N_e), dtype=numpy.bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in xrange(N_e):
                    c_all_times[all_found_spikes[i], i] = True
                    
            while (numpy.mean(failure) < nb_chances):

                if full_gpu:
                    gpu_mask    = cmt.CUDAMatrix(mask, copy_on_host=False)
                    b.mult(gpu_mask, data)
                    tmp_mat     = data.max(0)
                    argmax_bi   = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat
                else:
                    data        = sub_b * mask
                    argmax_bi   = numpy.argsort(numpy.max(data, 0))[::-1]

                while (len(argmax_bi) > 0):

                    subset          = []
                    indices         = []
                    all_times       = numpy.zeros(local_len, dtype=numpy.bool)

                    for count, idx in enumerate(argmax_bi):
                        myslice = all_times[min_times[idx]:max_times[idx]]
                        if not myslice.any():
                            subset  += [idx]
                            indices += [count]
                            all_times[min_times[idx]:max_times[idx]] = True
                        if len(subset) > max_n_t:
                            break

                    subset    = numpy.array(subset, dtype=numpy.int32)
                    argmax_bi = numpy.delete(argmax_bi, indices)

                    if full_gpu:
                        b_array = b.asarray()
                        sub_b   = b_array[:n_tm, :]

                    inds_t, inds_temp = subset, numpy.argmax(numpy.take(sub_b, subset, axis=1), 0)

                    if full_gpu:
                        best_amp  = sub_b[inds_temp, inds_t]/n_scalar
                        best_amp2 = b_array[inds_temp + n_tm, inds_t]/n_scalar
                    else:
                        
                        best_amp  = sub_b[inds_temp, inds_t]/n_scalar
                        best_amp2 = b[inds_temp + n_tm, inds_t]/n_scalar

                    mask[inds_temp, inds_t] = 0

                    best_amp_n   = best_amp/numpy.take(norm_templates, inds_temp)
                    best_amp2_n  = best_amp2/numpy.take(norm_templates, inds_temp + n_tm)

                    all_idx      = ((best_amp_n >= amp_limits[inds_temp, 0]) & (best_amp_n <= amp_limits[inds_temp, 1]))
                    to_keep      = numpy.where(all_idx == True)[0]
                    to_reject    = numpy.where(all_idx == False)[0]
                    ts           = numpy.take(local_peaktimes, inds_t[to_keep])
                    good         = (ts >= local_bounds[0]) & (ts < local_bounds[1])

                    # We reduce to only the good times that will be kept
                    #to_keep      = to_keep[good]
                    #ts           = ts[good]
                    
                    if len(ts) > 0:
                        if full_gpu:
                            tmp  = cmt.CUDAMatrix(numpy.ones((len(ts), 1)), copy_on_host=False)
                            tmp3 = cmt.CUDAMatrix(-ts.reshape((len(ts), 1)), copy_on_host=False)
                            tmp  = tmp.dot(tmp_gpu)
                            tmp.add_col_vec(tmp3)
                            condition = cmt.empty(tmp.shape)
                            cmt.abs(tmp, condition).less_than(temp_2_shift + 1)
                            condition = condition.asarray().astype(numpy.bool)
                            tmp       = tmp.asarray().astype(numpy.int32)
                        else:
                            tmp      = numpy.dot(numpy.ones((len(ts), 1), dtype=numpy.int32), local_peaktimes.reshape((1, n_t)))
                            tmp     -= ts.reshape((len(ts), 1))
                            condition = numpy.abs(tmp) <= temp_2_shift

                        for count, keep in enumerate(to_keep):
                            
                            idx_b    = numpy.compress(condition[count, :], all_indices)
                            ytmp     = tmp[count, condition[count, :]] + temp_2_shift
                            
                            indices  = numpy.zeros((S_over, len(ytmp)), dtype=numpy.float32)
                            indices[ytmp, numpy.arange(len(ytmp))] = 1

                            if full_gpu: 
                                indices  = cmt.CUDAMatrix(indices, copy_on_host=False)
                                if patch_gpu:
                                    b_lines  = b.get_col_slice(0, b.shape[0])
                                else:
                                    b_lines  = b.get_col_slice(idx_b[0], idx_b[-1]+1)

                                tmp1 = cmt.sparse_dot(c_overs[inds_temp[keep]], indices, mult=-best_amp[keep])
                                tmp2 = cmt.sparse_dot(c_overs[inds_temp[keep] + n_tm], indices, mult=-best_amp2[keep])
                                b_lines.add(tmp1.add(tmp2))
                                del tmp1, tmp2
                            else:
                                tmp1   = c_overs[inds_temp[keep]].multiply(-best_amp[keep]).dot(indices)
                                tmp2   = c_overs[inds_temp[keep] + n_tm].multiply(-best_amp2[keep]).dot(indices)
                                b[:, idx_b] += tmp1 + tmp2

                            if good[count]:

                                t_spike               = ts[count] + local_offset
                                result['spiketimes'] += [t_spike]
                                result['amplitudes'] += [(best_amp_n[keep], best_amp2_n[keep])]
                                result['templates']  += [inds_temp[keep]]

                    myslice           = numpy.take(inds_t, to_reject)
                    failure[myslice] += 1
                    sub_idx           = (numpy.take(failure, myslice) >= nb_chances)
                    
                    mask[:, numpy.compress(sub_idx, myslice)] = 0


            spikes_to_write     = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write  = numpy.array(result['templates'], dtype=numpy.int32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if collect_all:

                for temp, spike in zip(templates_to_write, spikes_to_write - local_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes       = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times   = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times                

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_tresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx      = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_tresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx      = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_tresholds_neg[bestlecs], matched_tresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx      = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                
                gspikes  = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write     = numpy.array(gspikes + local_offset, dtype=numpy.uint32)
                gtemplates_to_write  = numpy.array(bestlecs, dtype=numpy.int32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())
            

            if full_gpu:
                del gpu_mask, b, data

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()
        
        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()


    comm.Barrier()
    
    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
Exemplo n.º 8
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    # Part 1: Whitening
    numpy.random.seed(420)
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    spike_width = params.getfloat('detection', 'spike_width')
    matched_filter = params.getboolean('detection', 'matched-filter')
    matched_thresh = params.getfloat('detection', 'matched_thresh')
    fudge = params.getfloat('whitening', 'fudge')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    ignore_spikes = params.getboolean('whitening', 'ignore_spikes')
    chunk_size = detect_memory(params, whitening=True)
    plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    safety_space = params.getboolean('whitening', 'safety_space')
    sort_waveforms = params.getboolean('whitening', 'sort_waveforms')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_1 = int(20 * params.rate // comm.size)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    jitter_range = params.getint('detection', 'jitter_range')
    template_shift_2 = template_shift + jitter_range
    use_hanning = params.getboolean('detection', 'hanning')
    rejection_threshold = params.getfloat('detection', 'rejection_threshold')
    noise_window = params.getint('detection', 'noise_time')
    data_file.open()
    #################################################################

    if use_hanning:
        hanning_filter = numpy.hanning(N_t)

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    nodes_indices = {}
    for elec in numpy.arange(N_e):
        nodes_indices[elec] = inv_nodes[edges[nodes[elec]]]

    if use_gpu:
        import cudamat as cmt
        # TODO: properly handle multiple GPUs per MPI node?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # Shuffle the chunk indices so that the analyzed signals are drawn from all over the recording.
    if nb_chunks > comm.size:
        all_chunks = numpy.random.permutation(
            numpy.arange(nb_chunks - 1, dtype=numpy.int32))
    else:
        all_chunks = numpy.random.permutation(
            numpy.arange(nb_chunks, dtype=numpy.int32))

    all_electrodes = numpy.random.permutation(N_e)

    numpy.random.seed(comm.rank)

    for gidx in [all_chunks[comm.rank]]:

        # print "Node", comm.rank, "is analyzing chunk", gidx,  "/", nb_chunks, " ..."
        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   nodes=nodes)
        local_shape = len(local_chunk)

        # print "Node", comm.rank, "computes the median absolute deviations in a random chunk"
        thresholds = numpy.zeros(N_e, dtype=numpy.float32)
        for i in range(N_e):
            u = numpy.median(local_chunk[:, i], 0)
            thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0)
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            bfile = h5py.File(file_out_suff + '.basis.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(bfile, ['thresholds'],
                              {'thresholds': thresholds},
                              compression=hdf5_compress)
            bfile.close()
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        local_borders = (template_shift, local_shape - template_shift)
        found_peaktimes = []

        if ignore_spikes:
            # Extracting the peaks.
            local_peaktimes = [np.empty(0, dtype=numpy.uint32)]
            for i in range(N_e):
                peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:,
                                                                          i]),
                                                    height=thresholds[i],
                                                    width=spike_width,
                                                    wlen=N_t)[0]
                peaktimes = peaktimes.astype(numpy.uint32)

                # print "Removing the useless borders..."
                idx = (peaktimes >= local_borders[0]) & (peaktimes <
                                                         local_borders[1])
                peaktimes = numpy.compress(idx, peaktimes)

                found_peaktimes.append(peaktimes)
        else:
            for i in range(N_e):
                found_peaktimes.append(numpy.zeros(0, dtype=numpy.uint32))

        all_peaktimes = numpy.concatenate(found_peaktimes)
        local_peaktimes = numpy.unique(all_peaktimes)

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool)
            padded_peaks = (local_peaktimes - local_peaktimes[0]).astype(
                numpy.int32)
            min_times = numpy.maximum(padded_peaks - safety_time, 0)
            max_times = numpy.minimum(padded_peaks + safety_time + 1,
                                      diff_times + 1)

            test_extremas = numpy.zeros((N_e, diff_times + 1),
                                        dtype=numpy.bool)
            for i in range(N_e):
                test_extremas[i,
                              found_peaktimes[i] - local_peaktimes[0]] = True

            argmax_peak = numpy.random.permutation(
                numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            # print "Selection of the peaks with spatio-temporal masks..."
            for idx, peak in zip(argmax_peak, all_idx):

                all_elecs = numpy.where(test_extremas[:, peak -
                                                      local_peaktimes[0]])[0]
                data = local_chunk[peak, all_elecs]
                elec = all_elecs[numpy.argmax(numpy.abs(data))]
                indices = nodes_indices[elec]
                if safety_space:
                    all_times[indices, min_times[idx]:max_times[idx]] = True
                else:
                    all_times[elec, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool)

    if do_temporal_whitening:

        local_res_temp = []

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white,
                                                comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = nodes_indices[elec]
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            bound = len(esubset) - N_t
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res_temp += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res_temp)], dtype=numpy.float32)
        local_res_temp = numpy.array(local_res_temp, dtype=numpy.float32)
        if len(local_res_temp) == 0:
            local_res_temp = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res_temp = numpy.sum(local_res_temp, 0)
        all_res_temp = gather_array(local_res_temp.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if do_spatial_whitening:

        local_res_spac = numpy.zeros((N_e, N_e), dtype=numpy.float32)
        local_silences = []

        for elec in numpy.arange(comm.rank, N_e, comm.size):
            indices = nodes_indices[elec]
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            local_data = local_chunk[esubset][:, indices]
            local_whitening = get_whitening_matrix(
                local_data, fudge=fudge).astype(numpy.float32)
            pos = numpy.where(elec == indices)[0]
            local_res_spac[elec, indices] = local_whitening[pos]
            local_silences += [len(esubset)]

        all_res_spac = gather_array(local_res_spac.ravel(), comm, 0, 1)
        all_silences = gather_array(
            numpy.array(local_silences, dtype=numpy.int32), comm, 0, 1,
            'uint32')

    if comm.rank == 0:

        to_write = {}

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res_temp = all_res_temp.reshape((nb_silences, N_t**2))
            except Exception:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res_temp = numpy.sum(all_res_temp, 0)
            all_res_temp = all_res_temp.reshape(
                (N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res_temp.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            all_res_spac = all_res_spac.reshape(comm.size, N_e, N_e)
            spatial_whitening = numpy.sum(all_res_spac, 0)
            to_write['spatial'] = spatial_whitening

            if ignore_spikes:
                print_and_log([
                    "Found %gs without spikes to compute the whitening matrix..."
                    % (numpy.mean(all_silences) / params.rate)
                ], 'default', logger)
            else:
                print_and_log([
                    "Found %gs to compute the whitening matrix..." %
                    (numpy.mean(all_silences) / params.rate)
                ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          list(to_write.keys()),
                          to_write,
                          compression=hdf5_compress)
        bfile.close()

    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = numpy.zeros(N_e, dtype=numpy.float32)
            for i in range(N_e):
                u = numpy.median(local_chunk[:, i], 0)
                thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u),
                                             0)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                  'r+',
                                  libver='earliest')
                bfile.pop('thresholds')
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds},
                                  compression=hdf5_compress)
                bfile.close()
            comm.Barrier()

    # if comm.rank == 0:
    #     if not os.path.exists(plot_path):
    #         os.makedirs(plot_path)
    #     N_elec = min(int(numpy.sqrt(data_file.N_e)), 5)
    #     plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True,
    #                   n_elec=N_elec, save=[plot_path, 'electrodes'])

    # Part 2: Basis
    numpy.random.seed(422)

    SHARED_MEMORY = get_shared_memory_flag(params)
    #################################################################
    file_out = params.get('data', 'file_out')
    alignment = params.getboolean('detection', 'alignment')
    over_factor = params.getint('detection', 'oversampling_factor')
    nb_jitter = params.getint('detection', 'nb_jitter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    nodes, edges = get_nodes_and_edges(params)
    _, positions = get_nodes_and_positions(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    use_barycenter = params.getboolean('detection', 'use_barycenter')
    if matched_filter:
        chunk_size = detect_memory(params, whitening=True)
    else:
        chunk_size = detect_memory(params)
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    smoothing_factor = params.getfloat('detection', 'smoothing_factor')
    if sign_peaks == 'both':
        max_elts_elec *= 2
    nb_elts = int(
        params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)

    weird_thresh = params.get('detection', 'weird_thresh')
    if weird_thresh != '':
        ignore_artefacts = True
        weird_thresh = io.load_data(params, 'weird-thresholds')
    else:
        ignore_artefacts = False

    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    if ignore_dead_times:
        if SHARED_MEMORY:
            all_dead_times, mpi_memory_3 = get_dead_times(params)
        else:
            all_dead_times = get_dead_times(params)
    data_file.open()
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    groups = {}
    for i in range(N_e):
        groups[i] = 0

    # Shuffle the chunk indices so that the analyzed signals are drawn from all over the recording.
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        times_pos = numpy.zeros(nb_elts, dtype=numpy.int32)
        electrodes_pos = numpy.zeros(nb_elts, dtype=numpy.int32)
        extremum_pos = numpy.zeros(nb_elts, dtype=numpy.float32)
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        times_neg = numpy.zeros(nb_elts, dtype=numpy.int32)
        electrodes_neg = numpy.zeros(nb_elts, dtype=numpy.int32)
        extremum_neg = numpy.zeros(nb_elts, dtype=numpy.float32)
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)

    thresholds = io.load_data(params, 'thresholds')
    mads = io.load_data(params, 'mads')
    stds = io.load_data(params, 'stds')

    if alignment:
        cdata = numpy.linspace(-jitter_range, +jitter_range, nb_jitter)
        xdata = numpy.arange(-template_shift_2, template_shift_2 + 1)
        xoff = len(cdata) / 2.0
        snippet_duration = template_shift_2
        m_size = 2 * template_shift_2 + 1
        align_factor = m_size
        local_factors = align_factor * ((smoothing_factor * mads)**2)
    else:
        snippet_duration = template_shift
        xdata = numpy.arange(-template_shift, template_shift + 1)

    if rejection_threshold > 0:
        reject_noise = True
        noise_levels = stds * (2 * noise_window + 1)
    else:
        reject_noise = False

    to_explore = all_chunks[comm.rank::comm.size]

    upper_bounds = max_elts_elec

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for gcount, gidx in enumerate(to_explore):

        if (elt_count_pos + elt_count_neg) < nb_elts:
            # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_borders = (snippet_duration, local_shape - snippet_duration)

            if ignore_dead_times:
                dead_indices = numpy.searchsorted(
                    all_dead_times, [t_offset, t_offset + local_shape])

            # Extracting the peaks.
            all_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

            found_peaktimes = []
            found_peak_amplitudes = []
            for i in range(N_e):
                height = thresholds[i]
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i],
                                                        height=height,
                                                        distance=dist_peaks)[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i],
                                                        height=height,
                                                        distance=dist_peaks)[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(
                        local_chunk[:, i]),
                                                        height=height,
                                                        distance=dist_peaks)[0]
                else:
                    peaktimes = numpy.empty(0, dtype=numpy.uint32)

                if ignore_artefacts:
                    artetimes = scipy.signal.find_peaks(
                        numpy.abs(local_chunk[:,
                                              i]), height=weird_thresh[i])[0]
                    to_keep = numpy.logical_not(
                        numpy.in1d(peaktimes, artetimes))
                    peaktimes = peaktimes[to_keep]

                idx = (peaktimes >= local_borders[0]) & (peaktimes <
                                                         local_borders[1])
                peaktimes = peaktimes[idx]

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            peaktimes + t_offset,
                            all_dead_times[dead_indices[0]:dead_indices[1]])
                        peaktimes = peaktimes[~is_included]

                peaktimes = peaktimes.astype(numpy.uint32)
                found_peaktimes.append(peaktimes)

                peak_amplitudes = local_chunk[peaktimes, i]
                found_peak_amplitudes.append(peak_amplitudes)

            all_peaktimes = numpy.concatenate(
                found_peaktimes)  # i.e. concatenate once for efficiency
            all_peak_amplitudes = numpy.concatenate(found_peak_amplitudes)
            local_peaktimes, local_indices = numpy.unique(all_peaktimes,
                                                          return_inverse=True)

            if len(local_peaktimes) > 0:

                diff_times = (local_peaktimes[-1] - local_peaktimes[0]) + 1
                all_times = numpy.zeros((N_e, diff_times), dtype=numpy.bool)

                padded_peaks = (local_peaktimes - local_peaktimes[0]).astype(
                    numpy.int32)
                min_times = numpy.maximum(padded_peaks - safety_time, 0)
                max_times = numpy.minimum(padded_peaks + safety_time + 1,
                                          diff_times + 1)
                test_extremas = numpy.zeros((N_e, diff_times + 1),
                                            dtype=numpy.bool)
                for i in range(N_e):
                    test_extremas[i, found_peaktimes[i] -
                                  local_peaktimes[0]] = True

                # Consider the peaks by decreasing extremum.
                if sort_waveforms:
                    order = numpy.argsort(-np.abs(all_peak_amplitudes))
                    all_idx = numpy.take(all_peaktimes, order)
                    argmax_peak = local_indices[order]
                else:
                    n_times = len(all_peaktimes)
                    shuffling = numpy.random.permutation(numpy.arange(n_times))
                    all_idx = numpy.take(all_peaktimes, shuffling)
                    argmax_peak = local_indices[shuffling]

                # print "Selection of the peaks with spatio-temporal masks..."
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    all_elecs = numpy.where(
                        test_extremas[:, peak - local_peaktimes[0]])[0]
                    data = local_chunk[peak, all_elecs]

                    #target_area = test_extremas[:, min_times[midx]:max_times[midx]].sum(1)
                    #all_elecs = numpy.where(target_area)[0]
                    #data = local_chunk[peak, all_elecs]

                    if sign_peaks == 'negative':
                        if N_e > 1:
                            if use_barycenter:
                                weighed_position = data[:, numpy.
                                                        newaxis] * positions[
                                                            all_elecs]
                                barycenter = weighed_position.sum(
                                    0) / data.sum()
                                elec = numpy.argmin(
                                    numpy.linalg.norm(barycenter -
                                                      positions[all_elecs],
                                                      axis=1))
                            else:
                                elec = numpy.argmin(data)
                        else:
                            elec = 0
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        if N_e > 1:
                            if use_barycenter:
                                weighed_position = data[:, numpy.
                                                        newaxis] * positions[
                                                            all_elecs]
                                barycenter = weighed_position.sum(
                                    0) / data.sum()
                                elec = numpy.argmax(
                                    numpy.linalg.norm(barycenter -
                                                      positions[all_elecs],
                                                      axis=1))
                            else:
                                elec = numpy.argmax(data)
                        else:
                            elec = 0
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if N_e == 1:
                            if data < 0:
                                negative_peak = True
                            elif data > 0:
                                negative_peak = False
                            elec = 0
                        else:
                            if numpy.abs(numpy.max(data)) > numpy.abs(
                                    numpy.min(data)):
                                elec = numpy.argmax(data)
                                negative_peak = False
                            else:
                                elec = numpy.argmin(data)
                                negative_peak = True

                    elec = all_elecs[elec]

                    if groups[elec] < upper_bounds:

                        indices = nodes_indices[elec]
                        myslice = all_times[indices,
                                            min_times[midx]:max_times[midx]]

                        if not myslice.any():

                            sub_mat = local_chunk[peak -
                                                  snippet_duration:peak +
                                                  snippet_duration + 1, elec]

                            if reject_noise:
                                slice_window = sub_mat[
                                    snippet_duration -
                                    noise_window:snippet_duration +
                                    noise_window + 1]
                                value = numpy.linalg.norm(
                                    slice_window) / noise_levels[elec]
                                is_noise = value < rejection_threshold
                            else:
                                is_noise = False

                            if not is_noise:

                                extrema = sub_mat[snippet_duration]

                                if alignment:
                                    smoothed = True
                                    try:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata,
                                            sub_mat,
                                            s=local_factors[elec],
                                            k=3)
                                    except Exception:
                                        smoothed = False
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata, sub_mat, k=3, s=0)

                                    if negative_peak:
                                        rmin = (numpy.argmin(f(cdata)) -
                                                xoff) / over_factor
                                    else:
                                        rmin = (numpy.argmax(f(cdata)) -
                                                xoff) / over_factor
                                    ddata = numpy.linspace(
                                        rmin - template_shift,
                                        rmin + template_shift, N_t)

                                    if smoothed:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata,
                                            sub_mat,
                                            s=local_factors[elec],
                                            k=3)
                                    else:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata, sub_mat, s=0, k=3)

                                    sub_mat = f(ddata).astype(numpy.float32)

                                if negative_peak:
                                    times_neg[elt_count_neg] = peak + t_offset
                                    electrodes_neg[elt_count_neg] = elec
                                    extremum_neg[elt_count_neg] = extrema
                                    elts_neg[:, elt_count_neg] = sub_mat
                                    elt_count_neg += 1
                                else:
                                    times_pos[elt_count_pos] = peak + t_offset
                                    electrodes_pos[elt_count_pos] = elec
                                    extremum_pos[elt_count_pos] = extrema
                                    elts_pos[:, elt_count_pos] = sub_mat
                                    elt_count_pos += 1

                                groups[elec] += 1
                                all_times[
                                    indices,
                                    min_times[midx]:max_times[midx]] = True
                                test_extremas[elec, peak -
                                              local_peaktimes[0]] = False

    sys.stderr.flush()

    print_and_log([
        "Node %d has collected %d waveforms" %
        (comm.rank, elt_count_pos + elt_count_neg)
    ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        times_neg = gather_array(times_neg[:elt_count_neg],
                                 comm,
                                 0,
                                 1,
                                 dtype='int32')
        electrodes_neg = gather_array(electrodes_neg[:elt_count_neg],
                                      comm,
                                      0,
                                      1,
                                      dtype='int32')
        extremum_neg = gather_array(extremum_neg[:elt_count_neg], comm, 0, 1)
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        times_pos = gather_array(times_pos[:elt_count_pos],
                                 comm,
                                 0,
                                 1,
                                 dtype='int32')
        electrodes_pos = gather_array(electrodes_pos[:elt_count_pos],
                                      comm,
                                      0,
                                      1,
                                      dtype='int32')
        extremum_pos = gather_array(extremum_pos[:elt_count_pos], comm, 0, 1)
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    nb_waveforms = 0

    if comm.rank == 0:
        # Count the gathered waveforms; the PCA basis itself is fitted on them further below.

        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

    nb_waveforms = all_gather_array(
        numpy.array([nb_waveforms], dtype=numpy.float32), comm, 0)[0]

    if comm.rank == 0:
        print_and_log([
            "Found %d waveforms over %d requested" %
            (nb_waveforms, int(nb_elts * comm.size))
        ], 'default', logger)

        if nb_waveforms == 0:
            print_and_log(
                ['No waveforms found! Are the data properly loaded??'],
                'error', logger)

    if nb_waveforms == 0:
        sys.exit(0)

    if comm.rank == 0:
        res = {}
        pca = None
        pca_pos = None
        pca_neg = None
        warning_n_t = False
        if sign_peaks in ['negative', 'both']:
            res['times'] = times_neg
            res['electrodes'] = electrodes_neg
            res['extremum'] = extremum_neg
            if len(gdata_neg) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_neg * hanning_filter)
                else:
                    pca.fit(gdata_neg)
                res['proj'] = pca.components_.T.astype(numpy.float32)
                pca_neg = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj'] = numpy.identity(int(output_dim),
                                             dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            # dispersion = numpy.std(gdata_neg, 0) / numpy.median(stds)
            # ratio = numpy.sum(dispersion > 1.1) / float(len(dispersion))
            # if ratio < 0.25:
            #     print_and_log(["Time window N_t in [detection] seems too large!"], 'info', logger)
            #     warning_n_t = True
            # elif ratio == 1:
            #     print_and_log(["Time window N_t in [detection] seems too small!"], 'info', logger)
            #     warning_n_t = True
            idx = numpy.random.permutation(numpy.arange(
                gdata_neg.shape[0]))[:2500]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            res['times_pos'] = times_pos
            res['electrodes_pos'] = electrodes_pos
            res['extremum_pos'] = extremum_pos
            if len(gdata_pos) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_pos * hanning_filter)
                else:
                    pca.fit(gdata_pos)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
                pca_pos = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj_pos'] = numpy.identity(int(output_dim),
                                                 dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            # dispersion = numpy.std(gdata_pos, 0) / numpy.median(stds)
            # ratio = numpy.sum(dispersion > 1.1) / float(len(dispersion))
            # if ratio < 0.25 and not warning_n_t:
            #     print_and_log(["Time window N_t in [detection] seems too large!"], 'info', logger)
            # elif ratio == 1 and not warning_n_t:
            #     print_and_log(["Time window N_t in [detection] seems too small!"], 'info', logger)
            idx = numpy.random.permutation(numpy.arange(
                gdata_pos.shape[0]))[:2500]
            res['waveforms_pos'] = gdata_pos[idx, :]

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          list(res.keys()),
                          res,
                          compression=hdf5_compress)
        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'debug', logger)
        if pca_pos is not None:
            print_and_log([
                "The percentage of variance explained is %s for positive spikes"
                % pca_pos
            ], 'debug', logger)
        if pca_neg is not None:
            print_and_log([
                "The percentage of variance explained is %s for negative spikes"
                % pca_neg
            ], 'debug', logger)

        bfile.close()

    comm.Barrier()

    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_chunk /= thresholds

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_neg,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in range(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds'],
                                      {'matched_thresholds': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_pos,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in range(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds_pos'],
                                      {'matched_thresholds_pos': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

    data_file.close()

    if SHARED_MEMORY and ignore_dead_times:
        mpi_memory_3.Free()
Exemplo n.º 9
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    # Part 1: Whitening
    numpy.random.seed(420)
    #params         = detect_memory(params)
    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    spike_width = params.getfloat('detection', 'spike_width')
    matched_filter = params.getboolean('detection', 'matched-filter')
    matched_thresh = params.getfloat('detection', 'matched_thresh')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = detect_memory(params, whitening=True)
    plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    safety_space = params.getboolean('whitening', 'safety_space')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_1 = int(20 * params.rate // comm.size)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    jitter_range = params.getint('detection', 'jitter_range')
    template_shift_2 = template_shift + jitter_range
    use_hanning = params.getboolean('detection', 'hanning')
    data_file.open()
    #################################################################

    if use_hanning:
        hanning_filter = numpy.hanning(N_t)

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # Shuffle the chunk indices so that each rank draws its signal from a random part of the recording.
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    all_electrodes = numpy.random.permutation(N_e)

    for gidx in [all_chunks[comm.rank]]:

        #print "Node", comm.rank, "is analyzing chunk", gidx,  "/", nb_chunks, " ..."
        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   nodes=nodes)
        local_shape = len(local_chunk)

        #print "Node", comm.rank, "computes the median absolute deviations in a random chunk"
        thresholds = numpy.zeros(N_e, dtype=numpy.float32)
        for i in xrange(N_e):
            u = numpy.median(local_chunk[:, i], 0)
            thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0)
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            bfile = h5py.File(file_out_suff + '.basis.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(bfile, ['thresholds'],
                              {'thresholds': thresholds},
                              compression=hdf5_compress)
            bfile.close()
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        #print "Extracting the peaks..."
        local_peaktimes = numpy.zeros(0, dtype=numpy.uint32)
        for i in xrange(N_e):
            peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]),
                                                height=thresholds[i],
                                                width=spike_width,
                                                wlen=N_t)[0]
            local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))

        local_peaktimes = numpy.unique(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, local_shape - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool)
            min_times = numpy.maximum(
                local_peaktimes - local_peaktimes[0] - safety_time, 0)
            max_times = numpy.minimum(
                local_peaktimes - local_peaktimes[0] + safety_time + 1,
                diff_times)
            argmax_peak = numpy.random.permutation(
                numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            #print "Selection of the peaks with spatio-temporal masks..."
            for idx, peak in zip(argmax_peak, all_idx):
                elec = numpy.argmax(numpy.abs(local_chunk[peak]))
                indices = numpy.take(inv_nodes, edges[nodes[elec]])
                if safety_space:
                    all_times[indices, min_times[idx]:max_times[idx]] = True
                else:
                    all_times[elec, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool)

    if do_temporal_whitening:

        local_res_temp = []

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white,
                                                comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = numpy.take(inv_nodes, edges[nodes[elec]])
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            bound = len(esubset) - N_t
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res_temp += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res_temp)], dtype=numpy.float32)
        local_res_temp = numpy.array(local_res_temp, dtype=numpy.float32)
        if len(local_res_temp) == 0:
            local_res_temp = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res_temp = numpy.sum(local_res_temp, 0)
        all_res_temp = gather_array(local_res_temp.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if do_spatial_whitening:

        local_res_spac = numpy.zeros((N_e, N_e), dtype=numpy.float32)
        local_silences = []

        for elec in numpy.arange(comm.rank, N_e, comm.size):
            indices = numpy.take(inv_nodes, edges[nodes[elec]])
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            local_data = local_chunk[esubset][:, indices]
            local_whitening = get_whitening_matrix(local_data).astype(
                numpy.float32)
            pos = numpy.where(elec == indices)[0]
            local_res_spac[elec, indices] = local_whitening[pos]
            local_silences += [len(esubset)]

        all_res_spac = gather_array(local_res_spac.ravel(), comm, 0, 1)
        all_silences = gather_array(
            numpy.array(local_silences, dtype=numpy.int32), comm, 0, 1,
            'uint32')

    if comm.rank == 0:

        to_write = {}

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res_temp = all_res_temp.reshape((nb_silences, N_t**2))
            except Exception:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res_temp = numpy.sum(all_res_temp, 0)
            all_res_temp = all_res_temp.reshape(
                (N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res_temp.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            all_res_spac = all_res_spac.reshape(comm.size, N_e, N_e)
            spatial_whitening = numpy.sum(all_res_spac, 0)
            to_write['spatial'] = spatial_whitening

            print_and_log([
                "Found %gs without spikes for whitening matrices..." %
                (numpy.mean(all_silences) / params.rate)
            ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          to_write.keys(),
                          to_write,
                          compression=hdf5_compress)
        bfile.close()

    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = numpy.zeros(N_e, dtype=numpy.float32)
            for i in xrange(N_e):
                u = numpy.median(local_chunk[:, i], 0)
                thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u),
                                             0)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                  'r+',
                                  libver='earliest')
                bfile.pop('thresholds')
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds},
                                  compression=hdf5_compress)
                bfile.close()
            comm.Barrier()

    #if comm.rank == 0:
    #if not os.path.exists(plot_path):
    #    os.makedirs(plot_path)
    #N_elec = min(int(numpy.sqrt(data_file.N_e)), 5)
    #plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True,
    #              n_elec=N_elec, save=[plot_path, 'electrodes'])

    # Part 2: Basis
    numpy.random.seed(422)

    #################################################################
    file_out = params.get('data', 'file_out')
    alignment = params.getboolean('detection', 'alignment')
    smoothing = params.getboolean('detection', 'smoothing')
    isolation = params.getboolean('detection', 'isolation')
    over_factor = params.getint('detection', 'oversampling_factor')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    nodes, edges = get_nodes_and_edges(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    if matched_filter:
        chunk_size = detect_memory(params, whitening=True)
    else:
        chunk_size = detect_memory(params)
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    smoothing_factor = params.getfloat(
        'detection', 'smoothing_factor') * (1. / spike_thresh)**2
    if sign_peaks == 'both':
        max_elts_elec *= 2
    nb_elts = int(
        params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)

    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    if ignore_dead_times:
        all_dead_times = get_dead_times(params)
    data_file.open()
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    groups = {}
    for i in xrange(N_e):
        groups[i] = 0

    # Shuffle the chunk order so that waveforms are collected from all over the recording
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)

    chunks_to_load = all_chunks[comm.rank::comm.size]

    thresholds = io.load_data(params, 'thresholds')
    mads = io.load_data(params, 'mads')

    if alignment:
        cdata = numpy.linspace(-jitter_range, jitter_range,
                               int(over_factor * 2 * jitter_range))
        xdata = numpy.arange(-template_shift_2, template_shift_2 + 1)
        xoff = len(cdata) / 2.

    if isolation:
        yoff = numpy.array(range(0, N_t // 4) + range(3 * N_t // 4, N_t))

    to_explore = xrange(comm.rank, nb_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):

        gidx = all_chunks[gidx]

        if ((elt_count_pos + elt_count_neg) < nb_elts):
            #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            #print "Extracting the peaks..."
            all_peaktimes = numpy.zeros(0, dtype=numpy.uint32)
            all_extremas = numpy.zeros(0, dtype=numpy.uint32)

            for i in xrange(N_e):

                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i],
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        distance=dist_peaks,
                                                        wlen=N_t)[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i],
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        wlen=N_t)[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(
                        local_chunk[:, i]),
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        wlen=N_t)[0]
                all_peaktimes = numpy.concatenate((all_peaktimes, peaktimes))
                all_extremas = numpy.concatenate(
                    (all_extremas,
                     i * numpy.ones(len(peaktimes), dtype=numpy.uint32)))

            #print "Removing the useless borders..."
            if alignment:
                local_borders = (template_shift_2,
                                 local_shape - template_shift_2)
            else:
                local_borders = (template_shift, local_shape - template_shift)
            idx = (all_peaktimes >= local_borders[0]) & (all_peaktimes <
                                                         local_borders[1])
            all_peaktimes = numpy.compress(idx, all_peaktimes)
            all_extremas = numpy.compress(idx, all_extremas)

            local_peaktimes = numpy.unique(all_peaktimes)

            if ignore_dead_times:
                dead_indices = numpy.searchsorted(
                    all_dead_times, [t_offset, t_offset + local_shape])
                if dead_indices[0] != dead_indices[1]:
                    is_included = numpy.in1d(
                        local_peaktimes + t_offset,
                        all_dead_times[dead_indices[0]:dead_indices[1]])
                    local_peaktimes = local_peaktimes[~is_included]
                    local_peaktimes = numpy.sort(local_peaktimes)

            if len(local_peaktimes) > 0:

                diff_times = local_peaktimes[-1] - local_peaktimes[0]
                all_times = numpy.zeros((N_e, diff_times + 1),
                                        dtype=numpy.bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    diff_times)

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                all_idx = numpy.take(local_peaktimes, argmax_peak)

                #print "Selection of the peaks with spatio-temporal masks..."
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    if sign_peaks == 'negative':
                        elec = numpy.argmin(local_chunk[peak])
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        elec = numpy.argmax(local_chunk[peak])
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if N_e == 1:
                            if local_chunk[peak] < 0:
                                negative_peak = True
                            elif local_chunk[peak] > 0:
                                negative_peak = False
                            elec = 0
                        else:
                            if numpy.abs(numpy.max(
                                    local_chunk[peak])) > numpy.abs(
                                        numpy.min(local_chunk[peak])):
                                elec = numpy.argmax(local_chunk[peak])
                                negative_peak = False
                            else:
                                elec = numpy.argmin(local_chunk[peak])
                                negative_peak = True

                    indices = numpy.take(inv_nodes, edges[nodes[elec]])
                    myslice = all_times[indices,
                                        min_times[midx]:max_times[midx]]
                    is_local_extrema = elec in all_extremas[all_peaktimes ==
                                                            peak]
                    if is_local_extrema and not myslice.any():
                        upper_bounds = max_elts_elec

                        if groups[elec] < upper_bounds:

                            if not alignment:
                                sub_mat = local_chunk[peak -
                                                      template_shift:peak +
                                                      template_shift + 1, elec]

                            elif alignment:
                                ydata = local_chunk[peak -
                                                    template_shift_2:peak +
                                                    template_shift_2 + 1, elec]

                                if smoothing:
                                    factor = smoothing_factor * xdata.size
                                    f = scipy.interpolate.UnivariateSpline(
                                        xdata, ydata, s=factor, k=3)
                                else:
                                    f = scipy.interpolate.UnivariateSpline(
                                        xdata, ydata, k=3, s=0)
                                if negative_peak:
                                    rmin = (numpy.argmin(f(cdata)) -
                                            xoff) / over_factor
                                else:
                                    rmin = (numpy.argmax(f(cdata)) -
                                            xoff) / over_factor
                                ddata = numpy.linspace(rmin - template_shift,
                                                       rmin + template_shift,
                                                       N_t)
                                sub_mat = f(ddata).astype(numpy.float32)

                            if alignment:
                                if negative_peak:
                                    if numpy.min(sub_mat) >= -thresholds[elec]:
                                        to_accept = False
                                else:
                                    if numpy.max(sub_mat) <= thresholds[elec]:
                                        to_accept = False

                            if isolation:
                                to_accept = numpy.all(
                                    numpy.max(numpy.abs(sub_mat[yoff])) <=
                                    thresholds[elec])
                            else:
                                to_accept = True

                            if to_accept:
                                if negative_peak:
                                    elts_neg[:, elt_count_neg] = sub_mat
                                else:
                                    elts_pos[:, elt_count_pos] = sub_mat

                                if negative_peak:
                                    elt_count_neg += 1
                                else:
                                    elt_count_pos += 1

                        groups[elec] += 1
                        all_times[indices,
                                  min_times[midx]:max_times[midx]] = True

    sys.stderr.flush()

    if isolation:
        print_and_log([
            "Node %d has collected %d isolated waveforms" %
            (comm.rank, elt_count_pos + elt_count_neg)
        ], 'debug', logger)
    else:
        print_and_log([
            "Node %d has collected %d waveforms" %
            (comm.rank, elt_count_pos + elt_count_neg)
        ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    nb_waveforms = 0

    if comm.rank == 0:
        #DO PCA on elts and store the basis obtained.

        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

    nb_waveforms = all_gather_array(
        numpy.array([nb_waveforms], dtype=numpy.float32), comm, 0)[0]

    if comm.rank == 0:
        if isolation:
            print_and_log([
                "Found %d isolated waveforms over %d requested" %
                (nb_waveforms, int(nb_elts * comm.size))
            ], 'default', logger)
        else:
            print_and_log([
                "Found %d waveforms over %d requested" %
                (nb_waveforms, int(nb_elts * comm.size))
            ], 'default', logger)

        if nb_waveforms == 0:
            print_and_log(
                ['No waveforms found! Are the data properly loaded??'],
                'error', logger)

    if nb_waveforms == 0:
        sys.exit(0)

    if comm.rank == 0:
        res = {}
        pca = None
        pca_pos = None
        pca_neg = None
        if sign_peaks in ['negative', 'both']:
            if len(gdata_neg) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_neg * hanning_filter)
                else:
                    pca.fit(gdata_neg)
                res['proj'] = pca.components_.T.astype(numpy.float32)
                pca_neg = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj'] = numpy.identity(int(output_dim),
                                             dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_neg.shape[0]))[:1000]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            if len(gdata_pos) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_pos * hanning_filter)
                else:
                    pca.fit(gdata_pos)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
                pca_pos = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj_pos'] = numpy.identity(int(output_dim),
                                                 dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_pos.shape[0]))[:1000]
            res['waveforms_pos'] = gdata_pos[idx, :]

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile, res.keys(), res, compression=hdf5_compress)
        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'debug', logger)
        if pca_pos is not None:
            print_and_log([
                "The percentage of variance explained is %s for positive spikes"
                % pca_pos
            ], 'debug', logger)
        if pca_neg is not None:
            print_and_log([
                "The percentage of variance explained is %s for negative spikes"
                % pca_neg
            ], 'debug', logger)

        bfile.close()

    comm.Barrier()

    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_chunk /= thresholds

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_neg,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds'],
                                      {'matched_thresholds': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_pos,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds_pos'],
                                      {'matched_thresholds_pos': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

    data_file.close()