Exemple #1
0
 def get_max_loc_channel(params):
     """Return the length of the longest entry in the probe's edges mapping.

     The mapping is the second value returned by
     ``get_nodes_and_edges(params)``. The result is 0 when it has no entries.
     """
     _, neighbourhoods = get_nodes_and_edges(params)
     sizes = [len(neighbourhoods[channel]) for channel in neighbourhoods.keys()]
     # Seed with 0 so that an empty mapping yields 0 instead of raising.
     return max([0] + sizes)
def plot_probe(ax, params, channel_ids=None):
    """Add one translucent gray disc per selected channel to ``ax``.

    Parameters
    ----------
    ax
        Target axes; must provide ``set_aspect`` and ``add_collection``
        (presumably a matplotlib ``Axes`` -- confirm with callers).
    params
        Configuration object exposing ``getint`` and ``probe``.
    channel_ids
        Channels to draw. Defaults to ``0 .. N_e - 1`` where ``N_e`` is read
        from the ``data`` section of ``params``.

    Returns
    -------
    None
    """
    nb_channels = params.getint('data', 'N_e')

    probe = params.probe
    nodes, _ = get_nodes_and_edges(params)

    # Only channel group 1 of the probe is used for the electrode positions.
    geometry = probe['channel_groups'][1]['geometry']
    positions = np.array([geometry[key] for key in list(geometry.keys())])

    # Median gap between the distinct coordinates found along each axis.
    x_pitch = np.median(np.diff(np.unique(positions[:, 0])))  # horizontal
    y_pitch = np.median(np.diff(np.unique(positions[:, 1])))  # vertical

    if channel_ids is None:
        channel_ids = np.arange(0, nb_channels)

    # Disc diameter equals the smaller of the two pitches.
    disc_style = {
        'radius': (min(x_pitch, y_pitch) / 2.0),
        'color': 'tab:gray',
        'alpha': 0.5,
    }
    discs = [
        mpatches.Circle(positions[nodes[channel_id]], **disc_style)
        for channel_id in channel_ids
    ]
    collection = mcollections.PatchCollection(discs, match_original=True)

    ax.set_aspect('equal')
    ax.add_collection(collection)
Exemple #3
0
    def write_templates(path, params, extension):
        """Save the templates as a stacked array and return how many there are.

        Every template is reshaped to ``(N_t, channels)`` and written into a
        single float32 array saved as ``templates`` under ``output_path``.
        With ``sparse_export`` only the channels holding non-zero samples are
        kept, and their indices are saved as ``template_ind`` (padded with -1).
        With ``export_all``, ``N_e`` extra rows are appended, one per
        electrode, filled from the 'garbage' spikes of that electrode.

        ``sparse_export``, ``export_all``, ``N_e``, ``N_t`` and ``output_path``
        are read from the enclosing scope.

        Parameters
        ----------
        path
            Accepted for interface compatibility; not used by this function.
        params
            Configuration object forwarded to the ``io`` helpers.
        extension
            Suffix forwarded to ``io.load_data``.

        Returns
        -------
        int
            ``N_tm``, half the number of columns of the loaded templates.
        """
        templates = io.load_data(params, 'templates', extension)
        N_tm = templates.shape[1] // 2
        nodes, _ = get_nodes_and_edges(params)

        if sparse_export:
            # Widest template, counted in channels with at least one non-zero
            # summed value.
            n_channels_max = 0
            for t in xrange(N_tm):
                data = numpy.sum(numpy.sum(templates[:, t].toarray().reshape(N_e, N_t), 1) != 0)
                if data > n_channels_max:
                    n_channels_max = data
        else:
            n_channels_max = N_e

        # One extra row per electrode when the garbage waveforms are exported.
        nb_rows = N_tm + N_e if export_all else N_tm
        to_write_sparse = numpy.zeros((nb_rows, N_t, n_channels_max), dtype=numpy.float32)
        mapping_sparse = -1 * numpy.ones((nb_rows, n_channels_max), dtype=numpy.int32)

        for t in xrange(N_tm):
            tmp = templates[:, t].toarray().reshape(N_e, N_t).T
            x, y = tmp.nonzero()

            if sparse_export:
                if len(y) == 0:
                    # All-zero template: y.max() would raise on an empty
                    # array. Leave the row at zero and its mapping at -1.
                    continue
                channels = numpy.unique(y)
                nb_loc = len(channels)
                # Map each original channel index to its compact column.
                all_positions = numpy.zeros(y.max() + 1, dtype=numpy.int32)
                all_positions[channels] = numpy.arange(nb_loc, dtype=numpy.int32)
                pos = all_positions[y]
                to_write_sparse[t, x, pos] = tmp[x, y]
                mapping_sparse[t, :nb_loc] = channels
            else:
                to_write_sparse[t, x, y] = tmp[x, y]

        if export_all:
            garbage = io.load_data(params, 'garbage', extension)
            for t in xrange(N_tm, N_tm + N_e):
                elec = t - N_tm
                spikes = garbage['gspikes'].pop('elec_%d' % elec).astype(numpy.uint64)
                # Use at most 100 randomly chosen spikes per electrode.
                spikes = numpy.random.permutation(spikes)[:100]
                mapping_sparse[t, 0] = elec
                # ``numpy`` (not ``np``) to match the rest of this function.
                waveform = io.get_stas(params, times_i=spikes, labels_i=numpy.ones(len(spikes)), src=elec, neighs=[elec], nodes=nodes, mean_mode=True)

                if sparse_export:
                    to_write_sparse[t, :, 0] = waveform
                else:
                    to_write_sparse[t, :, elec] = waveform

        numpy.save(os.path.join(output_path, 'templates'), to_write_sparse)

        if sparse_export:
            numpy.save(os.path.join(output_path, 'template_ind'), mapping_sparse)

        return N_tm
Exemple #4
0
 def get_max_loc_channel(params, extension):
     """Return the largest number of channels attached to a single entry.

     When ``test_if_support(params, extension)`` is true, the value is the
     maximum of the row sums (axis 1) of the loaded 'supports' data.
     Otherwise it is the length of the longest entry of the edges mapping
     from ``get_nodes_and_edges(params)``, or 0 if that mapping is empty.
     """
     if test_if_support(params, extension):
         supports = io.load_data(params, 'supports', extension)
         return numpy.sum(supports, 1).max()

     _, edges = get_nodes_and_edges(params)
     sizes = [len(edges[key]) for key in list(edges.keys())]
     # Seed with 0 so that an empty mapping yields 0 instead of raising.
     return max([0] + sizes)
def get_neighbors(params, chan=None):
    """Return the probe nodes and the selected channels as indices into them.

    Parameters
    ----------
    params
        Configuration object; ``N_total`` is read from its ``data`` section.
    chan
        Index into ``nodes`` of the channel of interest. With ``None`` every
        node is selected; otherwise only the entries of
        ``edges[nodes[chan]]`` are.

    Returns
    -------
    tuple
        ``(nodes, chans)`` where ``chans`` holds, for each selected channel,
        its position within ``nodes``.
    """
    n_total = params.getint('data', 'N_total')
    nodes, edges = get_nodes_and_edges(params, validating=True)

    # Lookup table: raw channel id -> position of that channel in ``nodes``.
    inv_nodes = numpy.zeros(n_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))

    selected = nodes if chan is None else edges[nodes[chan]]
    return nodes, inv_nodes[selected]
def load_snippets(time_step_ids, params):
    """Read one window of ``N_t`` samples around each requested time step.

    The window starts ``(N_t - 1) // 2`` samples before the time step. When
    enabled in the ``whitening`` section of ``params``, spatial whitening
    (matrix product) and temporal whitening (1-D convolution along time) are
    applied to every window.

    Parameters
    ----------
    time_step_ids
        Iterable of time step indices.
    params
        Configuration object exposing ``getint``, ``getboolean`` and
        ``get_data_file``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(time_step_ids), N_t, N_e)``.
    """
    nb_channels = params.getint('data', 'N_e')
    nb_time_steps = params.getint('detection', 'N_t')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    spatial_whitening = load_data(
        params, 'spatial_whitening') if do_spatial_whitening else None
    temporal_whitening = load_data(
        params, 'temporal_whitening') if do_temporal_whitening else None

    data_file = params.get_data_file()
    chunk_size = nb_time_steps
    nodes, _ = get_nodes_and_edges(params)

    data_file.open()
    try:
        data_file.analyze(chunk_size)  # i.e. count chunks in sources

        snippets = []
        padding = (0, nb_time_steps - 1)
        for time_step_id in time_step_ids:
            t_start = time_step_id - int(nb_time_steps - 1) // 2
            idx = data_file.get_idx(t_start, chunk_size)
            data, t_offset = data_file.get_data(idx,
                                                chunk_size,
                                                padding=padding,
                                                nodes=nodes)
            # Position of the window start inside the (padded) chunk.
            first = (t_start - t_offset) % chunk_size
            data = data[first:first + nb_time_steps, :]
            if do_spatial_whitening:
                data = np.dot(data, spatial_whitening)
            if do_temporal_whitening:
                # ``ndimage.filters`` is a deprecated alias namespace; call
                # the function from ``ndimage`` directly.
                data = sp.ndimage.convolve1d(data,
                                             temporal_whitening,
                                             axis=0,
                                             mode='constant')
            snippets.append(data)
    finally:
        # Always release the data file, even if reading a window failed.
        data_file.close()

    nb_snippets = len(snippets)
    snippets = np.array(snippets)
    snippets = np.reshape(snippets, (nb_snippets, nb_time_steps, nb_channels))

    return snippets
def plot_template(ax,
                  template,
                  params,
                  color='black',
                  vmin=None,
                  vmax=None,
                  label=None,
                  limits='auto'):
    """Draw a template as one small trace per channel, laid out on the probe.

    Each channel with at least one non-zero sample is plotted at the position
    of its electrode. The time axis is squeezed into 80% of the horizontal
    electrode pitch and the amplitude range ``vmax - vmin`` into 80% of the
    vertical pitch.

    Parameters
    ----------
    ax
        Target axes; must provide ``set_aspect`` and ``plot``.
    template
        2-D array indexed as ``template[:, channel_id]``.
    params
        Configuration object exposing ``getint`` and ``probe``.
    color
        Line colour forwarded to ``ax.plot``.
    vmin, vmax
        Amplitude range used for vertical scaling. When omitted they default
        to the template minimum / maximum, extended to include 0.
    label
        Legend label, attached to the first plotted trace only.
    limits
        Forwarded to ``set_limits``.

    Returns
    -------
    None
    """
    nb_channels = params.getint('data', 'N_e')
    nb_time_steps = params.getint('detection', 'N_t')

    probe = params.probe
    nodes, _ = get_nodes_and_edges(params)

    # Only channel group 1 of the probe is used for the electrode positions.
    geometry = probe['channel_groups'][1]['geometry']
    positions = np.array([geometry[key] for key in list(geometry.keys())])

    # Default range spans the data and always contains 0 (the convention
    # plot_snippets uses). The former |max| - |min| collapsed to 0 for a
    # symmetric template and produced an infinite scaling factor.
    vmin = min(np.min(template), 0.0) if vmin is None else vmin
    vmax = max(np.max(template), 0.0) if vmax is None else vmax
    dx = np.median(np.diff(np.unique(
        positions[:, 0])))  # horizontal inter-electrode distance
    dy = np.median(np.diff(np.unique(
        positions[:, 1])))  # vertical inter-electrode distance
    x_scaling = 0.8 * dx / 1.0
    amplitude = np.abs(vmax - vmin)
    # A zero-width range (e.g. all-zero template) gets flat traces rather
    # than a division by zero.
    y_scaling = 0.8 * dy / amplitude if amplitude > 0 else 0.0

    ax.set_aspect('equal')
    for channel_id in range(0, nb_channels):
        if not np.any(template[:, channel_id] != 0.0):
            continue  # silent channel: nothing to draw
        x_c, y_c = positions[nodes[channel_id]]
        x = x_scaling * np.linspace(-0.5, +0.5, num=nb_time_steps) + x_c
        y = y_scaling * template[:, channel_id] + y_c
        ax.plot(x, y, color=color, label=label)
        label = None  # i.e. label first plot only
    set_limits(ax, limits, positions)

    return
def plot_snippets(ax,
                  snippets,
                  params,
                  color='black',
                  vmin=None,
                  vmax=None,
                  limits='auto'):
    """Overlay every snippet, channel by channel, at the electrode positions.

    For each of the ``N_e`` channels, all snippets are drawn on top of each
    other around that channel's electrode. The time axis takes 80% of the
    horizontal electrode pitch and the range ``vmax - vmin`` takes 80% of the
    vertical pitch.

    Parameters
    ----------
    ax
        Target axes; must provide ``set_aspect`` and ``plot``.
    snippets
        Iterable of 2-D arrays indexed as ``snippet[:, channel_id]``.
    params
        Configuration object exposing ``getint`` and ``probe``.
    color
        Line colour forwarded to ``ax.plot``.
    vmin, vmax
        Amplitude range for vertical scaling; by default the minimum and
        maximum over all snippets, extended to include 0.
    limits
        Forwarded to ``set_limits``.

    Returns
    -------
    None
    """
    nb_channels = params.getint('data', 'N_e')
    nb_time_steps = params.getint('detection', 'N_t')

    probe = params.probe
    nodes, _ = get_nodes_and_edges(params)

    # Only channel group 1 of the probe is used for the electrode positions.
    geometry = probe['channel_groups'][1]['geometry']
    positions = np.array([geometry[key] for key in list(geometry.keys())])

    if vmin is None:
        vmin = min(np.min(snippets), 0.0)
    if vmax is None:
        vmax = max(np.max(snippets), 0.0)

    # Median gap between the distinct coordinates found along each axis.
    x_pitch = np.median(np.diff(np.unique(positions[:, 0])))  # horizontal
    y_pitch = np.median(np.diff(np.unique(positions[:, 1])))  # vertical
    x_scaling = 0.8 * x_pitch / 1.0
    y_scaling = 0.8 * y_pitch / np.abs(vmax - vmin)

    ax.set_aspect('equal')
    # Normalised time axis shared by all traces, centred on 0.
    unit_time = np.linspace(-0.5, +0.5, num=nb_time_steps)
    for channel_id in range(0, nb_channels):
        x_c, y_c = positions[nodes[channel_id]]
        x = x_scaling * unit_time + x_c
        for snippet in snippets:
            ax.plot(x, y_scaling * snippet[:, channel_id] + y_c, color=color)
    set_limits(ax, limits, positions)
Exemple #9
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    # Part 1: Whitening
    numpy.random.seed(420)
    #params         = detect_memory(params)
    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    spike_width = params.getfloat('detection', 'spike_width')
    matched_filter = params.getboolean('detection', 'matched-filter')
    matched_thresh = params.getfloat('detection', 'matched_thresh')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = detect_memory(params, whitening=True)
    plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    safety_space = params.getboolean('whitening', 'safety_space')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_1 = int(20 * params.rate // comm.size)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    jitter_range = params.getint('detection', 'jitter_range')
    template_shift_2 = template_shift + jitter_range
    use_hanning = params.getboolean('detection', 'hanning')
    data_file.open()
    #################################################################

    if use_hanning:
        hanning_filter = numpy.hanning(N_t)

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    all_electrodes = numpy.random.permutation(N_e)

    for gidx in [all_chunks[comm.rank]]:

        #print "Node", comm.rank, "is analyzing chunk", gidx,  "/", nb_chunks, " ..."
        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   nodes=nodes)
        local_shape = len(local_chunk)

        #print "Node", comm.rank, "computes the median absolute deviations in a random chunk"
        thresholds = numpy.zeros(N_e, dtype=numpy.float32)
        for i in xrange(N_e):
            u = numpy.median(local_chunk[:, i], 0)
            thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0)
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            bfile = h5py.File(file_out_suff + '.basis.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(bfile, ['thresholds'],
                              {'thresholds': thresholds},
                              compression=hdf5_compress)
            bfile.close()
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        #print "Extracting the peaks..."
        local_peaktimes = numpy.zeros(0, dtype=numpy.uint32)
        for i in xrange(N_e):
            peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]),
                                                height=thresholds[i],
                                                width=spike_width,
                                                wlen=N_t)[0]
            local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))

        local_peaktimes = numpy.unique(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, local_shape - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool)
            min_times = numpy.maximum(
                local_peaktimes - local_peaktimes[0] - safety_time, 0)
            max_times = numpy.minimum(
                local_peaktimes - local_peaktimes[0] + safety_time + 1,
                diff_times)
            argmax_peak = numpy.random.permutation(
                numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            #print "Selection of the peaks with spatio-temporal masks..."
            for idx, peak in zip(argmax_peak, all_idx):
                elec = numpy.argmax(numpy.abs(local_chunk[peak]))
                indices = numpy.take(inv_nodes, edges[nodes[elec]])
                if safety_space:
                    all_times[indices, min_times[idx]:max_times[idx]] = True
                else:
                    all_times[elec, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool)

    if do_temporal_whitening:

        local_res_temp = []

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white,
                                                comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = numpy.take(inv_nodes, edges[nodes[elec]])
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            bound = len(esubset) - N_t
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res_temp += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res_temp)], dtype=numpy.float32)
        local_res_temp = numpy.array(local_res_temp, dtype=numpy.float32)
        if len(local_res_temp) == 0:
            local_res_temp = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res_temp = numpy.sum(local_res_temp, 0)
        all_res_temp = gather_array(local_res_temp.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if do_spatial_whitening:

        local_res_spac = numpy.zeros((N_e, N_e), dtype=numpy.float32)
        local_silences = []

        for elec in numpy.arange(comm.rank, N_e, comm.size):
            indices = numpy.take(inv_nodes, edges[nodes[elec]])
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            local_data = local_chunk[esubset][:, indices]
            local_whitening = get_whitening_matrix(local_data).astype(
                numpy.float32)
            pos = numpy.where(elec == indices)[0]
            local_res_spac[elec, indices] = local_whitening[pos]
            local_silences += [len(esubset)]

        all_res_spac = gather_array(local_res_spac.ravel(), comm, 0, 1)
        all_silences = gather_array(
            numpy.array(local_silences, dtype=numpy.int32), comm, 0, 1,
            'uint32')

    if comm.rank == 0:

        to_write = {}

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res_temp = all_res_temp.reshape((nb_silences, N_t**2))
            except Exception:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res_temp = numpy.sum(all_res_temp, 0)
            all_res_temp = all_res_temp.reshape(
                (N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res_temp.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            all_res_spac = all_res_spac.reshape(comm.size, N_e, N_e)
            spatial_whitening = numpy.sum(all_res_spac, 0)
            to_write['spatial'] = spatial_whitening

            print_and_log([
                "Found %gs without spikes for whitening matrices..." %
                (numpy.mean(all_silences) / params.rate)
            ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          to_write.keys(),
                          to_write,
                          compression=hdf5_compress)
        bfile.close()

    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = numpy.zeros(N_e, dtype=numpy.float32)
            for i in xrange(N_e):
                u = numpy.median(local_chunk[:, i], 0)
                thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u),
                                             0)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                  'r+',
                                  libver='earliest')
                bfile.pop('thresholds')
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds},
                                  compression=hdf5_compress)
                bfile.close()
            comm.Barrier()

    #if comm.rank == 0:
    #if not os.path.exists(plot_path):
    #    os.makedirs(plot_path)
    #N_elec = min(int(numpy.sqrt(data_file.N_e)), 5)
    #plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True,
    #              n_elec=N_elec, save=[plot_path, 'electrodes'])

    # Part 2: Basis
    numpy.random.seed(422)

    #################################################################
    file_out = params.get('data', 'file_out')
    alignment = params.getboolean('detection', 'alignment')
    smoothing = params.getboolean('detection', 'smoothing')
    isolation = params.getboolean('detection', 'isolation')
    over_factor = params.getint('detection', 'oversampling_factor')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    nodes, edges = get_nodes_and_edges(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    if matched_filter:
        chunk_size = detect_memory(params, whitening=True)
    else:
        chunk_size = detect_memory(params)
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    smoothing_factor = params.getfloat(
        'detection', 'smoothing_factor') * (1. / spike_thresh)**2
    if sign_peaks == 'both':
        max_elts_elec *= 2
    nb_elts = int(
        params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)

    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    if ignore_dead_times:
        all_dead_times = get_dead_times(params)
    data_file.open()
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    groups = {}
    for i in xrange(N_e):
        groups[i] = 0

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)

    chunks_to_load = all_chunks[comm.rank::comm.size]

    thresholds = io.load_data(params, 'thresholds')
    mads = io.load_data(params, 'mads')

    if alignment:
        cdata = numpy.linspace(-jitter_range, jitter_range,
                               int(over_factor * 2 * jitter_range))
        xdata = numpy.arange(-template_shift_2, template_shift_2 + 1)
        xoff = len(cdata) / 2.

    if isolation:
        yoff = numpy.array(range(0, N_t // 4) + range(3 * N_t // 4, N_t))

    to_explore = xrange(comm.rank, nb_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):

        gidx = all_chunks[gidx]

        if ((elt_count_pos + elt_count_neg) < nb_elts):
            #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            #print "Extracting the peaks..."
            all_peaktimes = numpy.zeros(0, dtype=numpy.uint32)
            all_extremas = numpy.zeros(0, dtype=numpy.uint32)

            for i in xrange(N_e):

                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i],
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        distance=dist_peaks,
                                                        wlen=N_t)[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i],
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        wlen=N_t)[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(
                        local_chunk[:, i]),
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        wlen=N_t)[0]
                all_peaktimes = numpy.concatenate((all_peaktimes, peaktimes))
                all_extremas = numpy.concatenate(
                    (all_extremas,
                     i * numpy.ones(len(peaktimes), dtype=numpy.uint32)))

            #print "Removing the useless borders..."
            if alignment:
                local_borders = (template_shift_2,
                                 local_shape - template_shift_2)
            else:
                local_borders = (template_shift, local_shape - template_shift)
            idx = (all_peaktimes >= local_borders[0]) & (all_peaktimes <
                                                         local_borders[1])
            all_peaktimes = numpy.compress(idx, all_peaktimes)
            all_extremas = numpy.compress(idx, all_extremas)

            local_peaktimes = numpy.unique(all_peaktimes)

            if ignore_dead_times:
                dead_indices = numpy.searchsorted(
                    all_dead_times, [t_offset, t_offset + local_shape])
                if dead_indices[0] != dead_indices[1]:
                    is_included = numpy.in1d(
                        local_peaktimes + t_offset,
                        all_dead_times[dead_indices[0]:dead_indices[1]])
                    local_peaktimes = local_peaktimes[~is_included]
                    local_peaktimes = numpy.sort(local_peaktimes)

            if len(local_peaktimes) > 0:

                diff_times = local_peaktimes[-1] - local_peaktimes[0]
                all_times = numpy.zeros((N_e, diff_times + 1),
                                        dtype=numpy.bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    diff_times)

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                all_idx = numpy.take(local_peaktimes, argmax_peak)

                #print "Selection of the peaks with spatio-temporal masks..."
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    if sign_peaks == 'negative':
                        elec = numpy.argmin(local_chunk[peak])
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        elec = numpy.argmax(local_chunk[peak])
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if N_e == 1:
                            if local_chunk[peak] < 0:
                                negative_peak = True
                            elif local_chunk[peak] > 0:
                                negative_peak = False
                            elec = 0
                        else:
                            if numpy.abs(numpy.max(
                                    local_chunk[peak])) > numpy.abs(
                                        numpy.min(local_chunk[peak])):
                                elec = numpy.argmax(local_chunk[peak])
                                negative_peak = False
                            else:
                                elec = numpy.argmin(local_chunk[peak])
                                negative_peak = True

                    indices = numpy.take(inv_nodes, edges[nodes[elec]])
                    myslice = all_times[indices,
                                        min_times[midx]:max_times[midx]]
                    is_local_extrema = elec in all_extremas[all_peaktimes ==
                                                            peak]
                    if is_local_extrema and not myslice.any():
                        upper_bounds = max_elts_elec

                        if groups[elec] < upper_bounds:

                            if not alignment:
                                sub_mat = local_chunk[peak -
                                                      template_shift:peak +
                                                      template_shift + 1, elec]

                            elif alignment:
                                ydata = local_chunk[peak -
                                                    template_shift_2:peak +
                                                    template_shift_2 + 1, elec]

                                if smoothing:
                                    factor = smoothing_factor * xdata.size
                                    f = scipy.interpolate.UnivariateSpline(
                                        xdata, ydata, s=factor, k=3)
                                else:
                                    f = scipy.interpolate.UnivariateSpline(
                                        xdata, ydata, k=3, s=0)
                                if negative_peak:
                                    rmin = (numpy.argmin(f(cdata)) -
                                            xoff) / over_factor
                                else:
                                    rmin = (numpy.argmax(f(cdata)) -
                                            xoff) / over_factor
                                ddata = numpy.linspace(rmin - template_shift,
                                                       rmin + template_shift,
                                                       N_t)
                                sub_mat = f(ddata).astype(numpy.float32)

                            if alignment:
                                if negative_peak:
                                    if numpy.min(sub_mat) >= -thresholds[elec]:
                                        to_accept = False
                                else:
                                    if numpy.max(sub_mat) <= thresholds[elec]:
                                        to_accept = False

                            if isolation:
                                to_accept = numpy.all(
                                    numpy.max(numpy.abs(sub_mat[yoff])) <=
                                    thresholds[elec])
                            else:
                                to_accept = True

                            if to_accept:
                                if negative_peak:
                                    elts_neg[:, elt_count_neg] = sub_mat
                                else:
                                    elts_pos[:, elt_count_pos] = sub_mat

                                if negative_peak:
                                    elt_count_neg += 1
                                else:
                                    elt_count_pos += 1

                        groups[elec] += 1
                        all_times[indices,
                                  min_times[midx]:max_times[midx]] = True

    sys.stderr.flush()

    if isolation:
        print_and_log([
            "Node %d has collected %d isolated waveforms" %
            (comm.rank, elt_count_pos + elt_count_neg)
        ], 'debug', logger)
    else:
        print_and_log([
            "Node %d has collected %d waveforms" %
            (comm.rank, elt_count_pos + elt_count_neg)
        ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    nb_waveforms = 0

    if comm.rank == 0:
        #DO PCA on elts and store the basis obtained.

        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

    nb_waveforms = all_gather_array(
        numpy.array([nb_waveforms], dtype=numpy.float32), comm, 0)[0]

    if comm.rank == 0:
        if isolation:
            print_and_log([
                "Found %d isolated waveforms over %d requested" %
                (nb_waveforms, int(nb_elts * comm.size))
            ], 'default', logger)
        else:
            print_and_log([
                "Found %d waveforms over %d requested" %
                (nb_waveforms, int(nb_elts * comm.size))
            ], 'default', logger)

        if nb_waveforms == 0:
            print_and_log(
                ['No waveforms found! Are the data properly loaded??'],
                'error', logger)

    if nb_waveforms == 0:
        sys.exit(0)

    if comm.rank == 0:
        res = {}
        pca = None
        pca_pos = None
        pca_neg = None
        if sign_peaks in ['negative', 'both']:
            if len(gdata_neg) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_neg * hanning_filter)
                else:
                    pca.fit(gdata_neg)
                res['proj'] = pca.components_.T.astype(numpy.float32)
                pca_neg = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj'] = numpy.identity(int(output_dim),
                                             dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_neg.shape[0]))[:1000]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            if len(gdata_pos) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_pos * hanning_filter)
                else:
                    pca.fit(gdata_pos)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
                pca_pos = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj_pos'] = numpy.identity(int(output_dim),
                                                 dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_pos.shape[0]))[:1000]
            res['waveforms_pos'] = gdata_pos[idx, :]

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile, res.keys(), res, compression=hdf5_compress)
        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'debug', logger)
        if pca_pos is not None:
            print_and_log([
                "The percentage of variance explained is %s for positive spikes"
                % pca_pos
            ], 'debug', logger)
        if pca_neg is not None:
            print_and_log([
                "The percentage of variance explained is %s for negative spikes"
                % pca_neg
            ], 'debug', logger)

        bfile.close()

    comm.Barrier()

    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_chunk /= thresholds

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_neg,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds'],
                                      {'matched_thresholds': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_pos,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds_pos'],
                                      {'matched_thresholds_pos': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

    data_file.close()
Exemple #10
0
    def write_pcs(path, params, extension, N_tm, mode=0):
        """Project spike waveforms on the PCA basis and write them to disk.

        The work is split across MPI ranks: rank ``r`` handles templates
        ``r, r + comm.size, ...`` and every rank writes its share into the
        same memory-mapped ``pc_features.npy`` file.

        ``output_path``, ``export_all``, ``N_e``, ``comm`` and ``io`` are not
        defined in this block; presumably they come from the enclosing scope.

        Parameters
        ----------
        path : unused in this function (files go to ``output_path``).
        params : configuration object, queried with ``get``/``getint`` and
            passed on to the ``io`` helpers.
        extension : suffix forwarded to ``io.load_data`` for the clusters.
        N_tm : number of fitted templates.
        mode : 0 exports every spike (rows indexed by spike index);
            1 exports at most 500 randomly chosen spikes per template
            (rows packed template after template) and also writes
            ``pc_feature_spike_ids.npy``.
        """

        spikes          = numpy.load(os.path.join(output_path, 'spike_times.npy'))
        labels          = numpy.load(os.path.join(output_path, 'spike_templates.npy'))
        max_loc_channel = get_max_loc_channel(params)
        nb_features     = params.getint('whitening', 'output_dim')
        sign_peaks      = params.get('detection', 'peaks')
        nodes, edges    = get_nodes_and_edges(params)
        N_total         = params.getint('data', 'N_total')

        # With export_all, one extra pseudo-template per electrode is appended.
        if export_all:
            nb_templates = N_tm + N_e
        else:
            nb_templates = N_tm

        pc_features_ind = numpy.zeros((nb_templates, max_loc_channel), dtype=numpy.int32)            
        clusters        = io.load_data(params, 'clusters', extension)
        best_elec       = clusters['electrodes']
        if export_all:
            best_elec = numpy.concatenate((best_elec, numpy.arange(N_e)))
        # inv_nodes maps a raw channel id to its rank among the sorted nodes.
        inv_nodes        = numpy.zeros(N_total, dtype=numpy.int32)
        inv_nodes[nodes] = numpy.argsort(nodes)

        # For each template, record the channels listed in edges for its
        # electrode (left-aligned, remaining columns stay 0).
        for count, elec in enumerate(best_elec):
            nb_loc                = len(edges[nodes[elec]])
            pc_features_ind[count, numpy.arange(nb_loc)] = inv_nodes[edges[nodes[elec]]]

        # NOTE(review): basis_proj stays undefined if sign_peaks is none of
        # 'negative', 'both', 'positive'; basis_rec is never used below.
        if sign_peaks in ['negative', 'both']:
            basis_proj, basis_rec = io.load_data(params, 'basis')
        elif sign_peaks in ['positive']:
            basis_proj, basis_rec = io.load_data(params, 'basis-pos')

        # NOTE(review): to_process is never used afterwards.
        to_process = numpy.arange(comm.rank, nb_templates, comm.size)

        # Number of rows each template contributes to the output file.
        all_offsets = numpy.zeros(nb_templates, dtype=numpy.int32)
        for target in xrange(nb_templates):
            if mode == 0:
                all_offsets[target] = len(numpy.where(labels == target)[0])
            elif mode == 1:
                all_offsets[target] = min(500, len(numpy.where(labels == target)[0]))

        # all_paddings[t] is the first row of template t when rows are packed.
        all_paddings = numpy.concatenate(([0] , numpy.cumsum(all_offsets)))
        total_pcs   = numpy.sum(all_offsets)

        pc_file     = os.path.join(output_path, 'pc_features.npy')
        pc_file_ids = os.path.join(output_path, 'pc_feature_spike_ids.npy')

        from numpy.lib.format import open_memmap

        # Rank 0 creates the files with their final shape ...
        if comm.rank == 0:
            pc_features = open_memmap(pc_file, shape=(total_pcs, nb_features, max_loc_channel), dtype=numpy.float32, mode='w+')
            if mode == 1:
                pc_ids = open_memmap(pc_file_ids, shape=(total_pcs, ), dtype=numpy.int32, mode='w+')

        # ... and, once they exist, every rank re-opens them for in-place writes.
        comm.Barrier()
        pc_features = open_memmap(pc_file, mode='r+')
        if mode == 1:
            pc_ids = open_memmap(pc_file_ids, mode='r+')

        to_explore = xrange(comm.rank, nb_templates, comm.size)

        # Only rank 0 wraps its iterator with a progress bar.
        if comm.rank == 0:
          to_explore = get_tqdm_progressbar(to_explore)

        # NOTE(review): all_idx and gcount are never used.
        all_idx = numpy.zeros(0, dtype=numpy.int32)
        for gcount, target in enumerate(to_explore):

            count    = all_paddings[target]
            
            if mode == 1:
                # Random subset of at most 500 spikes; remember which ones.
                idx  = numpy.random.permutation(numpy.where(labels == target)[0])[:500]
                pc_ids[count:count+len(idx)] = idx
            elif mode == 0:
                idx  = numpy.where(labels == target)[0]

            elec     = best_elec[target]
            indices  = inv_nodes[edges[nodes[elec]]]
            labels_i = target*numpy.ones(len(idx))
            times_i  = numpy.take(spikes, idx).astype(numpy.int64)
            sub_data = io.get_stas(params, times_i, labels_i, elec, neighs=indices, nodes=nodes, auto_align=False)
            
            # Project on the basis, then swap the last two axes to match the
            # (row, feature, channel) layout of pc_features.
            pcs      = numpy.dot(sub_data, basis_proj)
            pcs      = numpy.swapaxes(pcs, 1,2)
            if mode == 0:
                pc_features[idx, :, :len(indices)] = pcs                    
            elif mode == 1:
                pc_features[count:count+len(idx), :, :len(indices)] = pcs

        # Wait for every rank to finish writing before saving the channel index.
        comm.Barrier()

        if comm.rank == 0:
            numpy.save(os.path.join(output_path, 'pc_feature_ind'), pc_features_ind.astype(numpy.uint32)) #n_templates, n_loc_chan
Exemple #11
0
def main(params, nb_cpu, nb_gpu, use_gpu, extension):
    """Export the sorting results into ``<file_out_suff><extension>.GUI``.

    Parameters
    ----------
    params : configuration object (``get``/``getint``/``getboolean``) that
        also exposes ``logfile``, ``data_file`` and ``probe``.
    nb_cpu : number of CPUs, only reported in the log message.
    nb_gpu, use_gpu : accepted but not used in this function.
    extension : suffix inserted in the names of the files read and in the
        name of the output directory.

    The process exits if ``[converting] export_all`` is set while
    ``[fitting] collect_all`` is not.
    """

    # NOTE(review): the value returned by init_logging is immediately
    # replaced by the module logger; the call is presumably kept for its
    # side effects.
    logger         = init_logging(params.logfile)
    logger         = logging.getLogger('circus.converting')
    data_file      = params.data_file
    file_out_suff  = params.get('data', 'file_out_suff')
    probe          = params.probe
    output_path    = params.get('data', 'file_out_suff') + extension + '.GUI'
    N_e            = params.getint('data', 'N_e')
    N_t            = params.getint('detection', 'N_t')
    erase_all      = params.getboolean('converting', 'erase_all')
    export_pcs     = params.get('converting', 'export_pcs')
    export_all     = params.getboolean('converting', 'export_all')
    sparse_export  = params.getboolean('converting', 'sparse_export')
    # Unfitted spikes can only be exported if they were collected while fitting.
    if export_all and not params.getboolean('fitting', 'collect_all'):
        if comm.rank == 0:
            print_and_log(['Export unfitted spikes only if [fitting] collect_all is True'], 'error', logger)
        sys.exit(0)

    def generate_mapping(probe):
        """Return the channel positions ordered by increasing channel id.

        ``probe['channel_groups']`` maps group keys to dicts holding a
        ``'channels'`` sequence and a ``'geometry'`` mapping from channel id
        to position. The positions of all listed channels are gathered and
        returned as one array sorted by channel id.
        """
        geometry = {}
        channel_ids = []
        coords = []
        for group in probe['channel_groups'].values():
            # Geometries accumulate across groups before the lookup.
            geometry.update(group['geometry'])
            channel_ids.extend(group['channels'])
            coords.extend(geometry[channel] for channel in group['channels'])
        order = numpy.argsort(channel_ids)
        return numpy.array(coords)[order]

    def get_max_loc_channel(params):
        """Return the length of the longest entry in the probe's edges.

        ``edges`` is the mapping returned by ``get_nodes_and_edges``; the
        result is 0 when it is empty.
        """
        _, edges = get_nodes_and_edges(params)
        sizes = [len(neighbours) for neighbours in edges.values()]
        return max(sizes) if sizes else 0

    def write_results(path, params, extension):
        """Save spike times, template labels and amplitudes as .npy files.

        Reads the fitting results through ``io.get_results`` and writes
        ``spike_templates``, ``spike_times`` and ``amplitudes`` (all sorted
        by spike time) into ``output_path``. When ``export_all`` is set, the
        unfitted spikes of each electrode are appended with amplitude 0 and
        label ``N_tm + electrode``.

        Parameters
        ----------
        path : unused; files are written to the enclosing ``output_path``.
        params : configuration object forwarded to the ``io`` helpers.
        extension : suffix forwarded to the ``io`` helpers.
        """
        result = io.get_results(params, extension)
        N_tm = len(result['spiketimes'])

        # Collect the pieces and concatenate once at the end. The typed empty
        # seeds keep the dtypes (and make concatenation valid) when there is
        # nothing to export.
        spike_parts = [numpy.zeros(0, dtype=numpy.uint64)]
        cluster_parts = [numpy.zeros(0, dtype=numpy.uint32)]
        amplitude_parts = [numpy.zeros(0, dtype=numpy.double)]

        # Iterate over a snapshot of the keys: entries are popped inside the
        # loop (to release memory), which is an error on a live key view.
        for key in list(result['spiketimes'].keys()):
            temp_id = int(key.split('_')[-1])
            data = result['spiketimes'].pop(key).astype(numpy.uint64)
            spike_parts.append(data)
            data = result['amplitudes'].pop(key).astype(numpy.double)
            amplitude_parts.append(data[:, 0])
            cluster_parts.append(temp_id*numpy.ones(len(data), dtype=numpy.uint32))

        if export_all:
            print_and_log(["Last %d templates are unfitted spikes on all electrodes" %N_e], 'info', logger)
            garbage = io.load_data(params, 'garbage', extension)
            for key in list(garbage['gspikes'].keys()):
                elec_id = int(key.split('_')[-1])
                data = garbage['gspikes'].pop(key).astype(numpy.uint64)
                spike_parts.append(data)
                amplitude_parts.append(numpy.zeros(len(data)))
                cluster_parts.append((elec_id + N_tm)*numpy.ones(len(data), dtype=numpy.uint32))

        spikes = numpy.concatenate(spike_parts)
        clusters = numpy.concatenate(cluster_parts)
        amplitudes = numpy.concatenate(amplitude_parts)

        idx = numpy.argsort(spikes)
        numpy.save(os.path.join(output_path, 'spike_templates'), clusters[idx])
        numpy.save(os.path.join(output_path, 'spike_times'), spikes[idx])
        numpy.save(os.path.join(output_path, 'amplitudes'), amplitudes[idx])
        return

    def write_templates(path, params, extension):
        """Save the templates (and, in sparse mode, their channel map).

        Loads the templates through ``io.load_data`` and writes
        ``templates.npy`` with shape ``(nb_templates, N_t, n_channels)`` into
        ``output_path``. Only the first half of the template columns is
        exported (presumably the second half holds secondary components;
        confirm against the template file layout). With ``sparse_export``,
        each template only stores the channels on which it is non-zero and
        ``template_ind.npy`` gives, per template, the channel of every
        stored column (-1 for unused columns). With ``export_all``, ``N_e``
        extra all-zero templates are appended, one per electrode.

        Parameters
        ----------
        path : unused; files are written to the enclosing ``output_path``.
        params : configuration object forwarded to ``io.load_data``.
        extension : suffix forwarded to ``io.load_data``.

        Returns
        -------
        The number of exported fitted templates, ``N_tm``.
        """
        templates = io.load_data(params, 'templates', extension)
        N_tm = templates.shape[1]//2

        if sparse_export:
            n_channels_max = 0
            for t in xrange(N_tm):
                dense = templates[:, t].toarray().reshape(N_e, N_t)
                # Count a channel as soon as it has one non-zero sample, i.e.
                # exactly the channels the fill loop below will store (a
                # sum over time could cancel out and under-count).
                nb_active = int(numpy.sum(numpy.any(dense != 0, axis=1)))
                n_channels_max = max(n_channels_max, nb_active)
        else:
            n_channels_max = N_e

        nb_rows = N_tm + N_e if export_all else N_tm
        to_write_sparse = numpy.zeros((nb_rows, N_t, n_channels_max), dtype=numpy.float32)
        mapping_sparse = numpy.full((nb_rows, n_channels_max), -1, dtype=numpy.int32)

        for t in xrange(N_tm):
            tmp = templates[:, t].toarray().reshape(N_e, N_t).T
            x, y = tmp.nonzero()

            if sparse_export:
                if len(y) == 0:
                    # Empty template: nothing to store, its map stays at -1.
                    continue
                # channels: sorted active channels; pos: column of each sample.
                channels, pos = numpy.unique(y, return_inverse=True)
                to_write_sparse[t, x, pos] = tmp[x, y]
                mapping_sparse[t, :len(channels)] = channels
            else:
                to_write_sparse[t, x, y] = tmp[x, y]

        if export_all:
            # Each pseudo-template of unfitted spikes is tied to its electrode.
            for t in xrange(N_tm, N_tm + N_e):
                mapping_sparse[t, 0] = t - N_tm

        numpy.save(os.path.join(output_path, 'templates'), to_write_sparse)

        if sparse_export:
            numpy.save(os.path.join(output_path, 'template_ind'), mapping_sparse)

        return N_tm

    def write_pcs(path, params, extension, N_tm, mode=0):
        """Project spike waveforms on the PCA basis and write them to disk.

        The templates are distributed over the MPI ranks (rank ``r`` handles
        templates ``r, r + comm.size, ...``); every rank writes its share
        into the same memory-mapped ``pc_features.npy`` in ``output_path``.
        Rank 0 finally saves ``pc_feature_ind.npy``, the channels attached
        to each template.

        Parameters
        ----------
        path : unused; files are written to the enclosing ``output_path``.
        params : configuration object, also forwarded to the ``io`` helpers.
        extension : suffix forwarded to ``io.load_data`` for the clusters.
        N_tm : number of fitted templates.
        mode : 0 exports every spike (rows indexed by spike index);
            1 exports at most 500 randomly chosen spikes per template
            (rows packed template after template) and also writes
            ``pc_feature_spike_ids.npy`` with the chosen spike indices.
        """
        from numpy.lib.format import open_memmap

        spikes = numpy.load(os.path.join(output_path, 'spike_times.npy'))
        labels = numpy.load(os.path.join(output_path, 'spike_templates.npy'))
        max_loc_channel = get_max_loc_channel(params)
        nb_features = params.getint('whitening', 'output_dim')
        sign_peaks = params.get('detection', 'peaks')
        nodes, edges = get_nodes_and_edges(params)
        N_total = params.getint('data', 'N_total')

        # With export_all, one extra pseudo-template per electrode is appended.
        nb_templates = N_tm + N_e if export_all else N_tm

        pc_features_ind = numpy.zeros((nb_templates, max_loc_channel), dtype=numpy.int32)
        clusters = io.load_data(params, 'clusters', extension)
        best_elec = clusters['electrodes']
        if export_all:
            best_elec = numpy.concatenate((best_elec, numpy.arange(N_e)))
        # inv_nodes maps a raw channel id to its rank among the sorted nodes.
        inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
        inv_nodes[nodes] = numpy.argsort(nodes)

        for row, elec in enumerate(best_elec):
            neighbours = inv_nodes[edges[nodes[elec]]]
            pc_features_ind[row, :len(neighbours)] = neighbours

        if sign_peaks in ['negative', 'both']:
            basis_proj, _ = io.load_data(params, 'basis')
        elif sign_peaks in ['positive']:
            basis_proj, _ = io.load_data(params, 'basis-pos')

        # Spikes per template in a single pass over the labels, instead of
        # one full scan of the labels per template.
        spike_counts = numpy.bincount(labels.astype(numpy.int64),
                                      minlength=max(nb_templates, 1))[:nb_templates]
        if mode == 0:
            all_offsets = spike_counts.astype(numpy.int32)
        elif mode == 1:
            all_offsets = numpy.minimum(spike_counts, 500).astype(numpy.int32)
        else:
            all_offsets = numpy.zeros(nb_templates, dtype=numpy.int32)

        # all_paddings[t] is the first row of template t when rows are packed.
        all_paddings = numpy.concatenate(([0], numpy.cumsum(all_offsets)))
        total_pcs = numpy.sum(all_offsets)

        pc_file = os.path.join(output_path, 'pc_features.npy')
        pc_file_ids = os.path.join(output_path, 'pc_feature_spike_ids.npy')

        # Rank 0 creates the files with their final shape ...
        if comm.rank == 0:
            pc_features = open_memmap(pc_file, shape=(total_pcs, nb_features, max_loc_channel), dtype=numpy.float32, mode='w+')
            if mode == 1:
                pc_ids = open_memmap(pc_file_ids, shape=(total_pcs, ), dtype=numpy.int32, mode='w+')

        # ... and, once they exist, every rank re-opens them for in-place writes.
        comm.Barrier()
        pc_features = open_memmap(pc_file, mode='r+')
        if mode == 1:
            pc_ids = open_memmap(pc_file_ids, mode='r+')

        to_explore = xrange(comm.rank, nb_templates, comm.size)

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(to_explore)

        for target in to_explore:

            first_row = all_paddings[target]
            idx = numpy.where(labels == target)[0]

            if mode == 1:
                # Random subset of at most 500 spikes; remember which ones.
                idx = numpy.random.permutation(idx)[:500]
                pc_ids[first_row:first_row + len(idx)] = idx

            elec = best_elec[target]
            indices = inv_nodes[edges[nodes[elec]]]
            labels_i = target*numpy.ones(len(idx))
            times_i = numpy.take(spikes, idx).astype(numpy.int64)
            sub_data = io.get_stas(params, times_i, labels_i, elec, neighs=indices, nodes=nodes, auto_align=False)

            # Project on the basis, then swap the last two axes to match the
            # (row, feature, channel) layout of pc_features.
            pcs = numpy.dot(sub_data, basis_proj)
            pcs = numpy.swapaxes(pcs, 1, 2)
            if mode == 0:
                pc_features[idx, :, :len(indices)] = pcs
            elif mode == 1:
                pc_features[first_row:first_row + len(idx), :, :len(indices)] = pcs

        # Push this rank's writes to disk before anyone moves on.
        pc_features.flush()
        if mode == 1:
            pc_ids.flush()

        comm.Barrier()

        if comm.rank == 0:
            numpy.save(os.path.join(output_path, 'pc_feature_ind'), pc_features_ind.astype(numpy.uint32)) #n_templates, n_loc_chan

    # Rank 0 alone inspects the file system (and, if needed, asks the user)
    # to decide whether the export should run.
    do_export = True
    if comm.rank == 0:
        if os.path.exists(output_path):
            if not erase_all:
                do_export = query_yes_no(Fore.WHITE + "Export already made! Do you want to erase everything?", default=None)

            if do_export:
                if os.path.exists(os.path.abspath('.phy')):
                    shutil.rmtree(os.path.abspath('.phy'))
                shutil.rmtree(output_path)

    # Every rank enters the same broadcast unconditionally, so rank 0 can
    # never skip it and leave the others waiting. The value passed by the
    # non-root ranks is ignored.
    export_flag = comm.bcast(numpy.array([int(bool(do_export))], dtype=numpy.int32), root=0)
    do_export = bool(export_flag[0])

    comm.Barrier()

    if do_export:

        N_tm = 0
        if comm.rank == 0:
            os.makedirs(output_path)
            print_and_log(["Exporting data for the phy GUI with %d CPUs..." %nb_cpu], 'info', logger)

            if params.getboolean('whitening', 'spatial'):
                whitening_mat = io.load_data(params, 'spatial_whitening').astype(numpy.double)
                numpy.save(os.path.join(output_path, 'whitening_mat'), whitening_mat)
                numpy.save(os.path.join(output_path, 'whitening_mat_inv'), numpy.linalg.inv(whitening_mat))
            else:
                numpy.save(os.path.join(output_path, 'whitening_mat'), numpy.eye(N_e))

            numpy.save(os.path.join(output_path, 'channel_positions'), generate_mapping(probe).astype(numpy.double))
            nodes, edges = get_nodes_and_edges(params)
            numpy.save(os.path.join(output_path, 'channel_map'), nodes.astype(numpy.int32))

            write_results(output_path, params, extension)

            N_tm = write_templates(output_path, params, extension)

            apply_patch_for_similarities(params, extension)

            template_file = h5py.File(file_out_suff + '.templates%s.hdf5' %extension, 'r', libver='earliest')
            similarities = template_file.get('maxoverlap')[:]
            template_file.close()
            norm = N_e*N_t

            if export_all:
                # The pseudo-templates of unfitted spikes get zero similarity.
                to_write = numpy.zeros((N_tm + N_e, N_tm + N_e), dtype=numpy.single)
                to_write[:N_tm, :N_tm] = (similarities[:N_tm, :N_tm]/norm).astype(numpy.single)
            else:
                to_write = (similarities[:N_tm, :N_tm]/norm).astype(numpy.single)
            numpy.save(os.path.join(output_path, 'similar_templates'), to_write)

        # Share the number of templates found by rank 0 with every rank.
        N_tm = int(comm.bcast(numpy.array([N_tm], dtype=numpy.int32), root=0)[0])

        comm.Barrier()

        # make_pcs: 0 = all PCs, 1 = a subset, 2 = none.
        make_pcs = 2
        if comm.rank == 0:

            if export_pcs == 'prompt':
                key = ''
                while key not in ['a', 's', 'n']:
                    print(Fore.WHITE + "Do you want SpyKING CIRCUS to export PCs? (a)ll / (s)ome / (n)o")
                    key = raw_input('')
            else:
                key = export_pcs

            if key == 'a':
                make_pcs = 0
            elif key == 's':
                make_pcs = 1
            elif key == 'n':
                for stale in ['pc_features.npy', 'pc_feature_ind.npy']:
                    stale_path = os.path.join(output_path, stale)
                    if os.path.exists(stale_path):
                        os.remove(stale_path)
            else:
                # Unknown setting: fall back to "no PCs" rather than leaving
                # the other ranks blocked on a broadcast that never comes.
                print_and_log(["Unknown value %s for [converting] export_pcs: PCs are not exported" %key], 'info', logger)

        # Single broadcast executed by every rank, whatever the choice was.
        make_pcs = int(comm.bcast(numpy.array([make_pcs], dtype=numpy.int32), root=0)[0])

        comm.Barrier()
        if make_pcs < 2:
            write_pcs(output_path, params, extension, N_tm, make_pcs)
Exemple #12
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Run the filtering step on the recording described by ``params``.

    Depending on the configuration, this band-pass filters every channel
    listed in ``nodes``, subtracts the median computed across those channels,
    and subtracts averaged stimulation artefacts. Steps already flagged as
    done in the ``noedits`` section are skipped; rank 0 updates those flags
    once a step has been performed.

    Parameters
    ----------
    params
        Project parameter object (read through ``get``/``getint``/
        ``getfloat``/``getboolean``, written through ``write``).
    nb_cpu, nb_gpu, use_gpu
        Not used in this function.
    """

    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.filtering')
    #################################################################
    do_filter = params.getboolean('filtering', 'filter')
    filter_done = params.getboolean('noedits', 'filter_done')
    artefacts_done = params.getboolean('noedits', 'artefacts_done')
    median_done = params.getboolean('noedits', 'median_done')
    clean_artefact = params.getboolean('triggers', 'clean_artefact')
    remove_median = params.getboolean('filtering', 'remove_median')
    overwrite = params.getboolean('data', 'overwrite')
    nodes, edges = get_nodes_and_edges(params)

    #################################################################

    def filter_file(data_file_in, data_file_out, do_filtering,
                    do_remove_median):
        """Filter and/or median-correct every chunk assigned to this rank.

        Chunks are distributed round-robin over the MPI ranks; each processed
        chunk is read from ``data_file_in`` and written to ``data_file_out``.
        """

        # 'cut_off' is either a single number (low cut, the high cut then
        # defaults to 95% of Nyquist) or a "low, high" pair where high may
        # be the literal 'auto'.
        try:
            cut_off = params.getfloat('filtering', 'cut_off')
            cut_off = [cut_off, 0.95 * (params.rate / 2.)]
        except Exception:
            cut_off = params.get('filtering', 'cut_off')
            cut_off = cut_off.split(',')
            try:
                cut_off[0] = float(cut_off[0])
            except Exception:
                if comm.rank == 0:
                    print_and_log(
                        ['First value of cut off must be a valid number'],
                        'error', logger)
                sys.exit(0)

            cut_off[1] = cut_off[1].replace(' ', '')
            if cut_off[1] == 'auto':
                cut_off[1] = 0.95 * (params.rate / 2.)
            else:
                try:
                    cut_off[1] = float(cut_off[1])
                except Exception:
                    if comm.rank == 0:
                        print_and_log([
                            'Second value of cut off must either auto, or a valid a number'
                        ], 'error', logger)
                    sys.exit(0)

        chunk_size = params.getint('data', 'chunk_size')
        nb_chunks, _ = data_file_in.analyze(chunk_size)

        # Third-order Butterworth band-pass; the cut-offs are normalised by
        # the Nyquist frequency.
        b, a = signal.butter(3, numpy.array(cut_off) / (params.rate / 2.),
                             'pass')
        N_total = params.nb_channels

        if comm.rank == 0:
            to_write = []
            if do_filtering:
                to_write += [
                    "Filtering the signal with a Butterworth filter in (%g, %g) Hz"
                    % (cut_off[0], cut_off[1])
                ]
            if do_remove_median:
                to_write += [
                    "Median over all channels is substracted to each channels"
                ]

            print_and_log(to_write, 'default', logger)

        to_explore = xrange(comm.rank, nb_chunks, comm.size)

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(to_explore)

        for gidx in to_explore:

            local_chunk, t_offset = data_file_in.get_data(gidx, chunk_size)

            if do_filtering:
                for i in nodes:
                    try:
                        local_chunk[:, i] = signal.filtfilt(
                            b, a, local_chunk[:, i])
                    except Exception:
                        # Best effort: a channel that cannot be filtered is
                        # left unfiltered (presumably chunks too short for
                        # filtfilt's padding -- TODO confirm).
                        pass
                    # Re-centre EVERY filtered channel; this statement used
                    # to sit outside the loop and only hit the last channel.
                    local_chunk[:, i] -= numpy.median(local_chunk[:, i])

            if do_remove_median:
                # array_equal also copes with nodes and arange(N_total)
                # having different lengths.
                if not numpy.array_equal(nodes, numpy.arange(N_total)):
                    global_median = numpy.median(
                        numpy.take(local_chunk, nodes, axis=1), 1)
                else:
                    global_median = numpy.median(local_chunk, 1)
                for i in nodes:
                    local_chunk[:, i] -= global_median

            if data_file_in != data_file_out and data_file_in.is_first_chunk(
                    gidx, nb_chunks):
                if data_file_in.is_stream:
                    g_offset = t_offset - numpy.sum(
                        data_file_in.
                        _times[:data_file_in.
                               _get_streams_index_by_time(t_offset) + 1])
                else:
                    g_offset = t_offset - data_file_in.t_start
            else:
                g_offset = t_offset

            data_file_out.set_data(g_offset, local_chunk)

        comm.Barrier()

    def load_triggers(data_file):
        """Load the trigger times and windows shared by both artefact passes.

        Returns ``(artefacts, windows, nb_stimuli)`` where column 0 holds the
        stimulus label and column 1 the time/window length. Values given in
        ms are scaled as floats so that computing and removing the artefacts
        see exactly the same sample indices. Exits if the number of distinct
        labels differs from the number of windows.
        """
        trig_in_ms = params.getboolean('triggers', 'trig_in_ms')
        artefacts = numpy.loadtxt(params.get('triggers', 'trig_file'))
        windows = numpy.loadtxt(params.get('triggers', 'trig_windows'))

        # A single-row file is loaded as a 1-D array.
        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if trig_in_ms:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug',
                              logger)
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
            windows[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'],
                              'debug', logger)
            artefacts = artefacts.astype(numpy.int64)
            windows = windows.astype(numpy.int64)

        nb_stimuli = len(numpy.unique(artefacts[:, 0]))

        if nb_stimuli != len(windows):
            if comm.rank == 0:
                print_and_log(['Error in the trigger files'], 'error', logger)
            sys.exit(0)

        return artefacts, windows, nb_stimuli

    def compute_artefacts(data_file):
        """Return ``{label: averaged artefact}`` for this rank's stimuli."""

        artefacts, windows, nb_stimuli = load_triggers(data_file)
        make_plots = params.get('triggers', 'make_plots')
        plot_path = os.path.join(params.get('data', 'data_file_noext'),
                                 'plots')

        all_labels = artefacts[:, 0].astype(numpy.int32)
        all_times = artefacts[:, 1].astype(numpy.int32)

        # Keep only triggers whose (longest) window fits in the recording.
        mask = (all_times >= 0) & (all_times + numpy.max(windows[:, 1]) <
                                   data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        # Stimulus labels are distributed round-robin over the ranks.
        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        if comm.rank == 0:
            to_write = [
                "Computing averaged artefacts from %d stimuli" % (nb_stimuli)
            ]
            print_and_log(to_write, 'default', logger)
            if not os.path.exists(plot_path):
                os.makedirs(plot_path)
            local_labels = get_tqdm_progressbar(local_labels)

        comm.Barrier()
        # First we need to get the average artefacts
        art_dict = {}
        for artefact in local_labels:
            indices = numpy.where(all_labels == artefact)[0].astype(
                numpy.int32)
            tmp = numpy.where(windows[:, 0] == artefact)[0]
            tau = numpy.int64(windows[tmp, 1])
            pspikes = all_times[indices]
            # Average over a random subset of at most 500 occurrences.
            times = numpy.sort(numpy.random.permutation(pspikes)[:500])
            if numpy.any(numpy.diff(times) < tau):
                if comm.rank == 0:
                    print_and_log([
                        'Stimulation times for artefact %d are too close!' %
                        artefact
                    ], 'error', logger)
                sys.exit(0)

            art_dict[artefact] = get_artefact(params, times, tau, nodes)
            if make_plots not in ['None', '']:
                save = [plot_path, '%d.%s' % (artefact, make_plots)]
                plot.view_artefact(art_dict[artefact], save=save)

        return art_dict

    def remove_artefacts(data_file, art_dict):
        """Subtract the averaged artefacts from ``data_file`` in place."""

        artefacts, windows, nb_stimuli = load_triggers(data_file)

        all_labels = artefacts[:, 0].astype(numpy.int32)
        all_times = artefacts[:, 1].astype(numpy.int32)
        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        # Restrict to the stimuli handled by this rank ...
        mask = numpy.in1d(all_labels, local_labels)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        # ... and to the triggers that start inside the recording.
        mask = (all_times >= 0) & (all_times < data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        if comm.rank == 0:
            to_write = ["Removing artefacts from %d stimuli" % (nb_stimuli)]
            print_and_log(to_write, 'default', logger)
            all_times = get_tqdm_progressbar(all_times)

        comm.Barrier()

        for count, t_start in enumerate(all_times):

            label = all_labels[count]
            tmp = numpy.where(windows[:, 0] == label)[0][0]
            tau = numpy.int64(windows[tmp, 1])

            # Clip the window when it would run past the end of the
            # recording (the original referenced an undefined 'max_offset').
            if (data_file.t_stop - t_start) < tau:
                tau = data_file.t_stop - t_start

            local_chunk = data_file.get_snippet(t_start, tau)

            for idx, i in enumerate(nodes):
                local_chunk[:, i] -= art_dict[label][idx, :tau]
            data_file.set_data(t_start, local_chunk)

        comm.Barrier()

    if comm.rank == 0:
        print_and_log(['Initializing the filtering step...'], 'debug', logger)

    if overwrite:
        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)

        data_file_in = params.get_data_file()
        data_file_out = data_file_in
    else:
        if comm.rank == 0:
            print_and_log(
                ['Overwrite is set to False, so creating a new datafile...'],
                'debug', logger)

        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)

        has_been_created = os.path.exists(
            params.get('data', 'data_file_no_overwrite'))

        if not has_been_created and (filter_done or median_done
                                     or artefacts_done):
            if comm.rank == 0:
                print_and_log([
                    'The filtering is done but file not present. See no_edits section'
                ], 'error', logger)
            sys.exit(0)

        # Read from the source file on the first run, from the already
        # created copy afterwards.
        data_file_in = params.get_data_file(
            source=not has_been_created, has_been_created=has_been_created)

        if comm.rank == 0:
            print_and_log(
                ['Reading the output file and allocating ressources...'],
                'debug', logger)

        description = data_file_in.get_description()
        description['data_dtype'] = 'float32'
        description['dtype_offset'] = 0
        description['data_offset'] = 0

        data_file_out = params.get_data_file(is_empty=not has_been_created,
                                             params=description)

        if not has_been_created:
            data_file_out.allocate(shape=data_file_in.shape)

        comm.Barrier()

    if clean_artefact:
        if not (os.path.exists(params.get('triggers', 'trig_file'))
                and os.path.exists(params.get('triggers', 'trig_windows'))):
            if comm.rank == 0:
                print_and_log(
                    ['trig_file or trig_windows file can not be found'],
                    'error', logger)
            sys.exit(0)

    to_write = []

    if do_filter and filter_done:
        do_filter = False
        to_write += ["Filtering has already been done"]
    if remove_median and median_done:
        remove_median = False
        to_write += [
            "Median over all channels has already been substracted to each channels"
        ]

    if comm.rank == 0 and len(to_write) > 0:
        print_and_log(to_write, 'info', logger)

    if overwrite:
        data_file_in.open(mode='r+')
    else:
        data_file_in.open()
        data_file_out.open(mode='r+')

    if do_filter or remove_median:
        filter_file(data_file_in, data_file_out, do_filter, remove_median)

    if comm.rank == 0:
        if do_filter:
            params.write('noedits', 'filter_done', 'True')
        if remove_median:
            params.write('noedits', 'median_done', 'True')

    if clean_artefact and artefacts_done:
        clean_artefact = False
        if comm.rank == 0:
            print_and_log(['Artefacts have already been removed'], 'debug',
                          logger)

    if clean_artefact:
        art_dict = compute_artefacts(data_file_in)
        remove_artefacts(data_file_out, art_dict)

    if comm.rank == 0:
        if clean_artefact:
            params.write('noedits', 'artefacts_done', 'True')

    data_file_in.close()
    if not overwrite:
        data_file_out.close()

    comm.Barrier()
def main(params, nb_cpu, nb_gpu, use_gpu, file_name, benchmark, sim_same_elec):
    """
    Useful tool to create synthetic datasets for benchmarking.
    
    Arguments
    ---------
    benchmark : {'fitting', 'clustering', 'synchrony', 'pca-validation', 'smart-search', 'drifts'}
        
    """
    if sim_same_elec is None:
        sim_same_elec = 0.8

    logger         = init_logging(params.logfile)
    logger         = logging.getLogger('circus.benchmarking')

    numpy.random.seed(265)
    file_name      = os.path.abspath(file_name)
    data_path      = os.path.dirname(file_name)
    data_suff, ext = os.path.splitext(os.path.basename(file_name))
    file_out, ext  = os.path.splitext(file_name)

    if ext == '':
        ext = '.dat'
        file_name += ext
    
    if ext != '.dat':
        if comm.rank == 0:
            print_and_log(['Benchmarking produces raw files: select a .dat extension'], 'error', logger)
        sys.exit(0)

    if benchmark not in ['fitting', 'clustering', 'synchrony', 'smart-search', 'drifts']:
        if comm.rank == 0:
            print_and_log(['Benchmark need to be in [fitting, clustering, synchrony, smart-search, drifts]'], 'error', logger)
        sys.exit(0)

    # The extension `.p` or `.pkl` or `.pickle` seems more appropriate than `.pic`.
    # see: http://stackoverflow.com/questions/4530111/python-saving-objects-and-using-pickle-extension-of-filename
    # see: https://wiki.python.org/moin/UsingPickle
    def write_benchmark(filename, benchmark, cells, rates, amplitudes, sampling, probe, trends=None):
        """Save benchmark parameters in a file to remember them."""
        import cPickle
        to_write = {'benchmark' : benchmark}
        to_write['cells']      = cells
        to_write['rates']      = rates
        to_write['probe']      = probe
        to_write['amplitudes'] = amplitudes
        to_write['sampling']   = sampling
        if benchmark == 'drifts':
            to_write['drifts'] = trends
        cPickle.dump(to_write, open(filename + '.pic', 'w'))

    # Retrieve some key parameters.
    templates = io.load_data(params, 'templates')
    N_tm = templates.shape[1] // 2
    trends          = None

    # Normalize some variables.
    if benchmark == 'fitting':
        nb_insert       = 25
        n_cells         = numpy.random.random_integers(0, N_tm - 1, nb_insert)
        rate            = nb_insert * [10]
        amplitude       = numpy.linspace(0.5, 5, nb_insert)
    if benchmark == 'clustering':
        n_point         = 5
        n_cells         = numpy.random.random_integers(0, N_tm - 1, n_point ** 2)
        x, y            = numpy.mgrid[0:n_point, 0:n_point]
        rate            = numpy.linspace(0.5, 20, n_point)[x.flatten()]
        amplitude       = numpy.linspace(0.5, 5, n_point)[y.flatten()]
    if benchmark == 'synchrony':
        nb_insert       = 5
        corrcoef        = 0.2
        n_cells         = nb_insert * [numpy.random.random_integers(0, N_tm - 1, 1)[0]]
        rate            = 10. / corrcoef
        amplitude       = 2
    if benchmark == 'pca-validation':
        nb_insert       = 10
        n_cells         = numpy.random.random_integers(0, N_tm - 1, nb_insert)
        rate_min        = 0.5
        rate_max        = 20.0
        rate            = rate_min + (rate_max - rate_min) * numpy.random.random_sample(nb_insert)
        amplitude_min   = 0.5
        amplitude_max   = 5.0
        amplitude       = amplitude_min + (amplitude_max - amplitude_min) * numpy.random.random_sample(nb_insert)
    if benchmark == 'smart-search':
        nb_insert       = 10
        n_cells         = nb_insert*[numpy.random.random_integers(0, templates.shape[1]//2-1, 1)[0]]
        rate            = 1 + 5*numpy.arange(nb_insert)
        amplitude       = 2
    if benchmark == 'drifts':
        n_point         = 5
        n_cells         = numpy.random.random_integers(0, templates.shape[1]//2-1, n_point**2)
        x, y            = numpy.mgrid[0:n_point,0:n_point]
        rate            = 5*numpy.ones(n_point)[x.flatten()]
        amplitude       = numpy.linspace(0.5, 5, n_point)[y.flatten()]
        trends          = numpy.random.randn(n_point**2)

    # Delete the output directory tree if this output directory exists.
    if comm.rank == 0:
        if os.path.exists(file_out):
            shutil.rmtree(file_out)

    # Check and normalize some variables.
    if n_cells is None:
        n_cells    = 1
        cells      = [numpy.random.permutation(numpy.arange(n_cells))[0]]
    elif not numpy.iterable(n_cells):
        cells      = [n_cells]
        n_cells    = 1
    else:
        cells      = n_cells
        n_cells    = len(cells)

    if numpy.iterable(rate):
        assert len(rate) == len(cells), "Should have the same number of rates and cells"
    else:
        rate = [rate] * len(cells)

    if numpy.iterable(amplitude):
        assert len(amplitude) == len(cells), "Should have the same number of amplitudes and cells"
    else:
        amplitude = [amplitude] * len(cells)

    # Retrieve some additional key parameters.
    #params           = detect_memory(params)
    data_file        = params.get_data_file(source=True)
    N_e              = params.getint('data', 'N_e')
    N_total          = params.nb_channels
    hdf5_compress    = params.getboolean('data', 'hdf5_compress')
    nodes, edges     = get_nodes_and_edges(params)
    N_t              = params.getint('detection', 'N_t')
    inv_nodes        = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening  = params.getboolean('whitening', 'spatial')
    N_tm_init             = templates.shape[1]//2
    thresholds            = io.load_data(params, 'thresholds')
    limits                = io.load_data(params, 'limits')
    best_elecs            = io.load_data(params, 'electrodes')
    norms                 = io.load_data(params, 'norm-templates')

    # Create output directory if it does not exist.
    if comm.rank == 0:
        if not os.path.exists(file_out):
            os.makedirs(file_out)

    # Save benchmark parameters in a file to remember them.
    if comm.rank == 0:
        write_benchmark(file_out, benchmark, cells, rate, amplitude,
                        params.rate, params.get('data', 'mapping'), trends)

    # Synchronize all the threads/processes.
    comm.Barrier()

    if do_spatial_whitening:
        spatial_whitening  = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    # Retrieve some additional key parameters.
    chunk_size     = params.getint('data', 'chunk_size')
    scalings       = []
    
    params.set('data', 'data_file', file_name)

    data_file_out = params.get_data_file(is_empty=True)
    data_file_out.allocate(shape=data_file.shape)

    # Synchronize all the threads/processes.
    comm.Barrier()

    # For each wanted synthesized cell insert a generated template in the set of
    # existing template.
    for gcount, cell_id in enumerate(cells):
        best_elec   = best_elecs[cell_id]
        indices     = inv_nodes[edges[nodes[best_elec]]]
        count       = 0
        new_indices = []
        all_elecs   = numpy.random.permutation(numpy.arange(N_e))
        reference   = templates[:, cell_id].toarray().reshape(N_e, N_t)
        # Initialize the similarity (i.e. default value).
        similarity = 1.0
        # Find the first eligible template for the wanted synthesized cell.
        while len(new_indices) != len(indices) or (similarity > sim_same_elec): 
            similarity  = 0
            if count == len(all_elecs):
                if comm.rank == 0:
                    print_and_log(["No electrode to move template %d (max similarity is %g)" %(cell_id, similarity)], 'error', logger)
                sys.exit(0)
            else:
                # Get the next shuffled electrode.
                n_elec = all_elecs[count]

                if benchmark not in ['synchrony', 'smart-search']:
                    # Process if the shuffled electrode and the nearest electrode
                    # to the synthesized cell are not identical.
                    local_test = n_elec != best_elec
                else:
                    # Process if the shuffled electrode and the nearest electrode
                    # to the synthesized cell are identical.
                    local_test = n_elec == best_elec

                if local_test:
                    # Shuffle the neighboring electrodes whithout modifying
                    # the nearest electrode to the synthesized cell.
                    new_indices = inv_nodes[edges[nodes[n_elec]]]
                    idx = numpy.where(new_indices != best_elec)[0]
                    new_indices[idx] = numpy.random.permutation(new_indices[idx])

                    if len(new_indices) == len(indices):
                        # Shuffle the templates on the neighboring electrodes.
                        new_temp = numpy.zeros(reference.shape,
                                               dtype=numpy.float32)
                        new_temp[new_indices, :] = reference[indices, :]
                        # Compute the scaling factor which normalize the
                        # shuffled template.
                        gmin = new_temp.min()
                        data = numpy.where(new_temp == gmin)
                        scaling = -thresholds[data[0][0]]/gmin
                        for i in xrange(templates.shape[1]//2):
                            match = templates[:, i].toarray().reshape(N_e, N_t)
                            d = numpy.corrcoef(match.flatten(),
                                               scaling * new_temp.flatten())[0, 1]
                            if d > similarity:
                                similarity = d
                else:
                    new_indices = []
            # Go to the next shuffled electrode.
            count += 1

        #if comm.rank == 0:
        #    print "Template", cell_id, "is shuffled from electrode", best_elec, "to", n_elec, "(max similarity is %g)" %similarity

        N_tm           = templates.shape[1]//2
        to_insert      = numpy.zeros(reference.shape, dtype=numpy.float32)
        to_insert[new_indices] = scaling*amplitude[gcount]*templates[:, cell_id].toarray().reshape(N_e, N_t)[indices]
        to_insert2     = numpy.zeros(reference.shape, dtype=numpy.float32)
        to_insert2[new_indices] = scaling*amplitude[gcount]*templates[:, cell_id + N_tm].toarray().reshape(N_e, N_t)[indices]

        ## Insert the selected template.
        
        # Retrieve the number of existing templates in the dataset.
        N_tm           = templates.shape[1]//2

        # Generate the template of the synthesized cell from the selected
        # template, the target amplitude and the rescaling (i.e. threshold of
        # the target electrode).
        to_insert = numpy.zeros(reference.shape, dtype=numpy.float32)
        to_insert[new_indices] = scaling * amplitude[gcount] * templates[:, cell_id].toarray().reshape(N_e, N_t)[indices]
        to_insert = to_insert.flatten()
        to_insert2 = numpy.zeros(reference.shape, dtype=numpy.float32)
        to_insert2[new_indices] = scaling * amplitude[gcount] * templates[:, cell_id + N_tm].toarray().reshape(N_e, N_t)[indices]
        to_insert2 = to_insert2.flatten()

        # Compute the norm of the generated template.
        mynorm     = numpy.sqrt(numpy.sum(to_insert ** 2) / (N_e * N_t))
        mynorm2    = numpy.sqrt(numpy.sum(to_insert2 ** 2) / (N_e * N_t))

        # Insert the limits of the generated template.
        limits     = numpy.vstack((limits, limits[cell_id]))
        # Insert the best electrode of the generated template.
        best_elecs = numpy.concatenate((best_elecs, [n_elec]))

        # Insert the norm of the generated template (i.e. central component and
        # orthogonal component).
        norms      = numpy.insert(norms, N_tm, mynorm)
        norms      = numpy.insert(norms, 2 * N_tm + 1, mynorm2)
        # Insert the scaling of the generated template.
        scalings  += [scaling]

        # Retrieve the data about the existing templates.
        templates = templates.tocoo()
        xdata     = templates.row
        ydata     = templates.col
        zdata     = templates.data

        # Shift by one the orthogonal components of the existing templates.
        idx       = numpy.where(ydata >= N_tm)[0]
        ydata[idx] += 1

        # Insert the central component of the selected template.
        dx    = to_insert.nonzero()[0].astype(numpy.int32)
        xdata = numpy.concatenate((xdata, dx))
        ydata = numpy.concatenate((ydata, N_tm * numpy.ones(len(dx), dtype=numpy.int32)))
        zdata = numpy.concatenate((zdata, to_insert[dx]))

        # Insert the orthogonal component of the selected template.
        dx    = to_insert2.nonzero()[0].astype(numpy.int32)
        xdata = numpy.concatenate((xdata, dx))
        ydata = numpy.concatenate((ydata, (2 * N_tm + 1) * numpy.ones(len(dx), dtype=numpy.int32)))
        zdata = numpy.concatenate((zdata, to_insert2[dx]))

        # Recontruct the matrix of templates.
        templates = scipy.sparse.csc_matrix((zdata, (xdata, ydata)), shape=(N_e * N_t, 2 * (N_tm + 1)))

    # Remove all the expired data.
    if benchmark == 'pca-validation':
        # Remove all the expired data.
        N_tm_init = 0
        N_tm = templates.shape[1] / 2

        limits = limits[N_tm - nb_insert:, :]
        best_elecs = best_elecs[N_tm - nb_insert:]
        norms = numpy.concatenate((norms[N_tm-nb_insert:N_tm], norms[2*N_tm-nb_insert:2*N_tm]))
        scalings = scalings
        
        templates = templates.tocoo()
        xdata = templates.row
        ydata = templates.col
        zdata = templates.data
        
        idx_cen = numpy.logical_and(N_tm - nb_insert <= ydata, ydata < N_tm)
        idx_cen = numpy.where(idx_cen)[0]
        idx_ort = numpy.logical_and(2 * N_tm - nb_insert <= ydata, ydata < 2 * N_tm)
        idx_ort = numpy.where(idx_ort)[0]
        ydata[idx_cen] = ydata[idx_cen] - (N_tm - nb_insert)
        ydata[idx_ort] = ydata[idx_ort] - 2 * (N_tm - nb_insert)
        idx = numpy.concatenate((idx_cen, idx_ort))
        xdata = xdata[idx]
        ydata = ydata[idx]
        zdata = zdata[idx]
        templates = scipy.sparse.csc_matrix((zdata, (xdata, ydata)), shape=(N_e * N_t, 2 * nb_insert))
        
    # Retrieve the information about the organisation of the chunks of data.
    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # Display informations about the generated benchmark.
    if comm.rank == 0:
        print_and_log(["Generating benchmark data [%s] with %d cells" %(benchmark, n_cells)], 'info', logger)
        purge(file_out, '.data')


    template_shift = params.getint('detection', 'template_shift')
    all_chunks     = numpy.arange(nb_chunks)
    to_process     = all_chunks[numpy.arange(comm.rank, nb_chunks, comm.size)]
    loc_nb_chunks  = len(to_process)
    numpy.random.seed(comm.rank)

    to_explore = xrange(comm.rank, nb_chunks, comm.size)

    # Initialize the progress bar about the generation of the benchmark.
    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    # Open the file for collective I/O.
    #g = myfile.Open(comm, file_name, MPI.MODE_RDWR)
    #g.Set_view(data_offset, data_mpi, data_mpi)
    data_file_out.open(mode='r+')

    # Open the thread/process' files to collect the results.
    spiketimes_filename = os.path.join(file_out, data_suff + '.spiketimes-%d.data' %comm.rank)
    spiketimes_file = open(spiketimes_filename, 'wb')
    amplitude_filename = os.path.join(file_out, data_suff + '.amplitudes-%d.data' %comm.rank)
    amplitudes_file = open(amplitude_filename, 'wb')
    templates_filename = os.path.join(file_out, data_suff + '.templates-%d.data' %comm.rank)
    templates_file = open(templates_filename, 'wb')
    real_amps_filename = os.path.join(file_out, data_suff + '.real_amps-%d.data' %comm.rank)
    real_amps_file = open(real_amps_filename, 'wb')
    voltages_filename = os.path.join(file_out, data_suff + '.voltages-%d.data' %comm.rank)
    voltages_file = open(voltages_filename, 'wb')

    # For each chunk of data associate to the current thread/process generate
    # the new chunk of data (i.e. with considering the added synthesized cells).
    for count, gidx in enumerate(to_explore):

        #if (last_chunk_len > 0) and (gidx == (nb_chunks - 1)):
        #    chunk_len  = last_chunk_len
        #    chunk_size = last_chunk_len // N_total

        result         = {'spiketimes' : [], 'amplitudes' : [], 
                          'templates' : [], 'real_amps' : [],
                          'voltages' : []}
        offset         = gidx * chunk_size
        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes)

        if benchmark == 'pca-validation':
            # Clear the current data chunk.
            local_chunk = numpy.zeros(local_chunk.shape, dtype=local_chunk.dtype)

        # Handle whitening if necessary.
        if do_spatial_whitening:
            local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                           temporal_whitening,
                                                           axis=0,
                                                           mode='constant')

        if benchmark is 'synchrony':
            # Generate some spike indices (i.e. times) at the given rate for
            # 'synchrony' mode. Each synthesized cell will use a subset of these
            # spike times.
            mips = numpy.random.rand(chunk_size) < rate[0] / float(params.rate)

        # For each synthesized cell, generate its spike indices (i.e. times) and
        # add them to the dataset.
        for idx in xrange(len(cells)):
            if benchmark is 'synchrony':
                # Choose a subset of the spike indices generated before. The
                # size of this subset is parameterized by the target correlation
                # coefficients.
                sidx       = numpy.where(mips == True)[0]
                spikes     = numpy.zeros(chunk_size, dtype=numpy.bool)
                spikes[sidx[numpy.random.rand(len(sidx)) < corrcoef]] = True
            else:
                # Generate some spike indices at the given rate.
                spikes     = numpy.random.rand(chunk_size) < rate[idx] / float(params.rate)
            if benchmark == 'drifts':
                amplitudes = numpy.ones(len(spikes)) + trends[idx]*((spikes + offset)/(5*60*float(params.rate)))
            else:
                amplitudes = numpy.ones(len(spikes))
            # Padding with `False` to avoid the insertion of partial spikes at
            # the edges of the signal.
            spikes[:N_t]   = False
            spikes[-N_t:]  = False
            # Find the indices of the spike samples.
            spikes         = numpy.where(spikes == True)[0]
            n_template     = N_tm_init + idx
            loc_template   = templates[:, n_template].toarray().reshape(N_e, N_t)
            first_flat     = loc_template.T.flatten()
            norm_flat      = numpy.sum(first_flat ** 2)
            # For each index (i.e. spike sample location) add the spike to the
            # chunk of data.
            refractory     = int(5 * 1e-3 * params.rate)         
            t_last         = - refractory
            for scount, spike in enumerate(spikes):
                if (spike - t_last) > refractory:
                    local_chunk[spike-template_shift:spike+template_shift+1, :] += amplitudes[scount]*loc_template.T
                    amp        = numpy.dot(local_chunk[spike-template_shift:spike+template_shift+1, :].flatten(), first_flat)
                    amp       /= norm_flat
                    result['real_amps']  += [amp]
                    result['spiketimes'] += [spike + offset]
                    result['amplitudes'] += [(amplitudes[scount], 0)]
                    result['templates']  += [n_template]
                    result['voltages']   += [local_chunk[spike, best_elecs[idx]]]
                    t_last                = spike

        # Write the results into the thread/process' files.
        spikes_to_write     = numpy.array(result['spiketimes'], dtype=numpy.uint32)
        amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
        templates_to_write  = numpy.array(result['templates'], dtype=numpy.int32)
        real_amps_to_write  = numpy.array(result['real_amps'], dtype=numpy.float32)
        voltages_to_write   = numpy.array(result['voltages'], dtype=numpy.float32)

        spiketimes_file.write(spikes_to_write.tostring())   
        amplitudes_file.write(amplitudes_to_write.tostring())
        templates_file.write(templates_to_write.tostring())
        real_amps_file.write(real_amps_to_write.tostring())
        voltages_file.write(voltages_to_write.tostring())

        #print count, 'spikes inserted...'
        #new_chunk    = numpy.zeros((chunk_size, N_total), dtype=numpy.float32)
        #new_chunk[:, nodes] = local_chunk

        # Overwrite the new chunk of data using explicit offset. 
        #new_chunk   = new_chunk.flatten()
        #g.Write_at(gidx * chunk_len, new_chunk)
        data_file_out.set_data(offset, local_chunk)

        # Update the progress bar about the generation of the benchmark.
        
    # Close the thread/process' files.
    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    real_amps_file.flush()
    os.fsync(real_amps_file.fileno())
    real_amps_file.close()

    voltages_file.flush()
    os.fsync(voltages_file.fileno())
    voltages_file.close()


    # Close the file for collective I/O.
    data_file_out.close()
    data_file.close()

    
    # Synchronize all the threads/processes.
    comm.Barrier()

    
    ## Finally, perform all the administrative tasks
    ## (i.e. files and folders management).

    file_params = file_out + '.params'

    if comm.rank == 0:
        # Create `injected` directory if it does not exist
        result_path = os.path.join(file_out, 'injected') 
        if not os.path.exists(result_path):
            os.makedirs(result_path)

        # Copy initial configuration file from `<dataset1>.params` to `<dataset2>.params`.
        shutil.copy2(params.get('data', 'data_file_noext') + '.params', file_params)
        new_params = CircusParser(file_name)
        # Copy initial basis file from `<dataset1>/<dataset1>.basis.hdf5` to
        # `<dataset2>/injected/<dataset2>.basis.hdf5`.
        shutil.copy2(params.get('data', 'file_out') + '.basis.hdf5',
                     os.path.join(result_path, data_suff + '.basis.hdf5'))


        # Save templates into `<dataset>/<dataset>.templates.hdf5`.
        mydata = h5py.File(os.path.join(file_out, data_suff + '.templates.hdf5'), 'w')
        templates = templates.tocoo()
        if hdf5_compress:
            mydata.create_dataset('temp_x', data=templates.row, compression='gzip')
            mydata.create_dataset('temp_y', data=templates.col, compression='gzip')
            mydata.create_dataset('temp_data', data=templates.data, compression='gzip')
        else:
            mydata.create_dataset('temp_x', data=templates.row)
            mydata.create_dataset('temp_y', data=templates.col)
            mydata.create_dataset('temp_data', data=templates.data)
        mydata.create_dataset('temp_shape', data=numpy.array([N_e, N_t, templates.shape[1]],
                                                             dtype=numpy.int32))
        mydata.create_dataset('limits', data=limits)
        mydata.create_dataset('norms', data=norms)
        mydata.close()

        # Save electrodes into `<dataset>/<dataset>.clusters.hdf5`.
        mydata = h5py.File(os.path.join(file_out, data_suff + '.clusters.hdf5'), 'w')
        mydata.create_dataset('electrodes', data=best_elecs)
        mydata.close()

    comm.Barrier()
    if comm.rank == 0:
        # Gather data from all threads/processes.
        f_next, extension = os.path.splitext(file_name)
        file_out_bis = os.path.join(f_next, os.path.basename(f_next))
        #new_params.set('data', 'file_out', file_out_bis) # Output file without suffix
        #new_params.set('data', 'file_out_suff', file_out_bis  + params.get('data', 'suffix'))
    
        new_params.get_data_file()
        io.collect_data(comm.size, new_params, erase=True, with_real_amps=True, with_voltages=True, benchmark=True)
        # Change some flags in the configuration file.
        new_params.write('whitening', 'temporal', 'False') # Disable temporal filtering
        new_params.write('whitening', 'spatial', 'False') # Disable spatial filtering
        new_params.write('data', 'data_dtype', 'float32') # Set type of the data to float32
        new_params.write('data', 'dtype_offset', 'auto') # Set padding for data to auto
        # Move results from `<dataset>/<dataset>.result.hdf5` to
        # `<dataset>/injected/<dataset>.result.hdf5`.
        
        shutil.move(os.path.join(file_out, data_suff + '.result.hdf5'), os.path.join(result_path, data_suff + '.result.hdf5'))
                
        # Save scalings into `<dataset>/injected/<dataset>.scalings.npy`.
        numpy.save(os.path.join(result_path, data_suff + '.scalings'), scalings)

        file_name_noext, ext = os.path.splitext(file_name)

        # Copy basis from `<dataset>/injected/<dataset>.basis.hdf5` to
        # `<dataset>/<dataset>.basis.hdf5`.
        shutil.copy2(os.path.join(result_path, data_suff + '.basis.hdf5'),
                     os.path.join(file_out, data_suff + '.basis.hdf5'))

        if benchmark not in ['fitting', 'synchrony']:
            # Move templates from `<dataset>/<dataset>.templates.hdf5` to
            # `<dataset>/injected/<dataset>.templates.hdf5`.
            shutil.move(os.path.join(file_out, data_suff + '.templates.hdf5'),
                        os.path.join(result_path, data_suff + '.templates.hdf5'))
def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    n_e = params.getint('data', 'N_e')
    n_total = params.nb_channels
    n_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    # file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    # spike_thresh = params.getfloat('detection', 'spike_thresh')
    ratio_thresh = params.getfloat('fitting', 'ratio_thresh')
    two_components = params.getboolean('fitting', 'two_components')
    # spike_width = params.getfloat('detection', 'spike_width')
    # dist_peaks = params.getint('detection', 'dist_peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    templates_normalization = params.getboolean('clustering', 'templates_normalization')  # TODO test, switch, test!
    chunk_size = detect_memory(params, fitting=True)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = [float(v) for v in tmp_limits]
    amp_auto = params.getboolean('fitting', 'amp_auto')
    auto_nb_chances = params.getboolean('fitting', 'auto_nb_chances')
    if auto_nb_chances:
        nb_chances = io.load_data(params, 'nb_chances')
        max_nb_chances = params.getint('fitting', 'max_nb_chances')
        percent_nb_chances = params.getfloat('fitting', 'percent_nb_chances')
        total_nb_chances = max(1, numpy.nanpercentile(nb_chances, percent_nb_chances))
        total_nb_chances = min(total_nb_chances, max_nb_chances)
        if comm.rank == 0:
            print_and_log(['nb_chances set automatically to %g' %total_nb_chances], 'debug', logger)
    else:
        total_nb_chances = params.getfloat('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    # noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    min_second_component = params.getfloat('fitting', 'min_second_component')
    debug = params.getboolean('fitting', 'debug')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(n_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    #################################################################

    if use_gpu:
        import cudamat as cmt
        # # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates, _ = io.load_data_memshared(params, 'templates', normalize=templates_normalization, transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    temp_2_shift = 2 * template_shift
    temp_3_shift = 3 * template_shift
    full_gpu = use_gpu and gpu_only
    n_tm = N_tm // 2
    n_scalar = n_e * n_t

    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = n_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')
    if not templates_normalization:
        norm_templates_2 = (norm_templates ** 2.0) * n_scalar

    if not SHARED_MEMORY:
        # Normalize templates (if necessary).
        if templates_normalization:
            for idx in range(templates.shape[1]):
                myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1])
                templates.data[myslice] /= norm_templates[idx]
        # Transpose templates.
        templates = templates.T

    waveform_neg = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_neg = None  # default assignment (for PyCharm code inspection)
    waveform_pos = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_pos = None  # default assignment (for PyCharm code inspection)
    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_thresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)
    else:
        all_dead_times = None  # default assignment (for PyCharm code inspection)

    thresholds = io.get_accurate_thresholds(params, ratio_thresh)

    neighbors = {}
    if collect_all:
        for i in range(0, n_tm):
            tmp = templates[i, :].toarray().reshape(n_e, n_t)
            if templates_normalization:
                tmp = tmp * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, axis=1) != 0.0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string = ''

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % comm.size
        else:
            info_string = "using %d CPUs" % comm.size

    comm.Barrier()

    c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    n_over = int(numpy.sqrt(over_shape[0]))
    s_over = over_shape[1]
    # # If the number of overlaps is different from templates, we need to recompute them.
    if n_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        n_over = int(numpy.sqrt(over_shape[0]))
        s_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs, _ = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"], 'default', logger)

        sys.exit(0)

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (for PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (for PyCharm code inspection)

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in range(n_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger)
            for i in range(n_over):
                if i in c_overs:
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()
    else:
        garbage_times_file = None  # default assignment (for PyCharm code inspection)
        garbage_temp_file = None  # default assignment (for PyCharm code inspection)

    if debug:
        # Open debug files.
        chunk_nbs_debug_file = open(file_out_suff + '.chunk_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        iteration_nbs_debug_file = open(file_out_suff + '.iteration_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_nbs_debug_file = open(file_out_suff + '.peak_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_local_time_steps_debug_file = open(
            file_out_suff + '.peak_local_time_steps_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_time_steps_debug_file = open(file_out_suff + '.peak_time_steps_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_scalar_products_debug_file = open(
            file_out_suff + '.peak_scalar_products_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_solved_flags_debug_file = open(file_out_suff + '.peak_solved_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        template_nbs_debug_file = open(file_out_suff + '.template_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        success_flags_debug_file = open(file_out_suff + '.success_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
    else:
        chunk_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        iteration_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_local_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_scalar_products_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_solved_flags_debug_file = None  # default assignment (for PyCharm code inspection)
        template_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        success_flags_debug_file = None  # default assignment (for PyCharm code inspection)

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    last_chunk_size = 0
    slice_indices = numpy.zeros(0, dtype=numpy.int32)

    to_explore = range(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for gcount, gidx in enumerate(to_explore):
        # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        # # We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift].

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if not (is_first and is_last):
            if is_last:
                padding = (-temp_3_shift, 0)
            elif is_first:
                padding = (0, temp_3_shift)
            else:
                padding = (-temp_3_shift, temp_3_shift)
        else:
            padding = (0, 0)

        result = {
            'spiketimes': [],
            'amplitudes': [],
            'templates': [],
        }
        result_debug = {
            'chunk_nbs': [],
            'iteration_nbs': [],
            'peak_nbs': [],
            'peak_local_time_steps': [],
            'peak_time_steps': [],
            'peak_scalar_products': [],
            'peak_solved_flags': [],
            'template_nbs': [],
            'success_flags': [],
        }

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)           
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # Extracting peaks.

        all_found_spikes = {}
        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = []

        local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_pos[i])[0]
                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_neg[i])[0]
                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)
        else:
            for i in range(n_e):
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i])[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)
                local_peaktimes.append(peaktimes)
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)

        local_peaktimes = numpy.unique(local_peaktimes)

        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.in1d(local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]
                local_peaktimes = numpy.sort(local_peaktimes)
        else:
            dead_indices = None  # default assignment (for PyCharm code inspection)

        # print "Removing the useless borders..."
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            all_found_spikes[i] + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]
                        )
                        all_found_spikes[i] = all_found_spikes[i][~is_included]
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        nb_local_peak_times = len(local_peaktimes)

        if full_gpu:
            # all_indices = cmt.CUDAMatrix(all_indices)
            # tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False)
            _ = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False)

        if nb_local_peak_times > 0:
            # print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."

            if collect_all:
                c_local_chunk = local_chunk.copy()
            else:
                c_local_chunk = None  # default assignment (for PyCharm code inspection)

            sub_mat = local_chunk[local_peaktimes[:, None] + temp_window]
            sub_mat = sub_mat.transpose(2, 1, 0).reshape(size_window, nb_local_peak_times)

            del local_chunk

            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_restriction = (t_offset, t_offset + chunk_size)
            all_spikes = local_peaktimes + g_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            # b = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()           

            failure = numpy.zeros(nb_local_peak_times, dtype=numpy.int32)

            if full_gpu:
                mask = numpy.zeros((2 * n_tm, nb_local_peak_times), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                # data = cmt.empty(mask.shape)
                _ = cmt.empty(mask.shape)
                patch_gpu = b.shape[1] == 1
            else:
                patch_gpu = None

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, n_e), dtype=numpy.bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in range(n_e):
                    c_all_times[all_found_spikes[i], i] = True
            else:
                c_all_times = None  # default assignment (for PyCharm code inspection)
                c_min_times = None  # default assignment (for PyCharm code inspection)
                c_max_times = None  # default assignment (for PyCharm code inspection)

            iteration_nb = 0
            local_max = 0
            numerous_argmax = False
            nb_argmax = n_tm
            best_indices = numpy.zeros(0, dtype=numpy.int32)

            data = b[:n_tm, :]
            flatten_data = data.ravel()

            while numpy.mean(failure) < total_nb_chances:

                # Is there a way to update sub_b * mask at the same time?
                if full_gpu:
                    b_array = b.asarray()
                else:
                    b_array = None

                if numerous_argmax:
                    if len(best_indices) == 0:
                        best_indices = largest_indices(flatten_data, nb_argmax)
                    best_template_index, peak_index = numpy.unravel_index(best_indices[0], data.shape)
                else:
                    best_template_index, peak_index = numpy.unravel_index(data.argmax(), data.shape)

                peak_scalar_product = data[best_template_index, peak_index]
                best_template2_index = best_template_index + n_tm

                if templates_normalization:
                    if full_gpu:
                        best_amp = b_array[best_template_index, peak_index] / n_scalar
                        best_amp2 = b_array[best_template2_index, peak_index] / n_scalar
                    else:
                        best_amp = b[best_template_index, peak_index] / n_scalar
                        if two_components:
                            best_amp2 = b[best_template2_index, peak_index] / n_scalar
                        else:
                            best_amp2 = 0.0
                    best_amp_n = best_amp / norm_templates[best_template_index]
                    best_amp2_n = best_amp2 / norm_templates[best_template2_index]
                else:
                    if full_gpu:
                        best_amp = b_array[best_template_index, peak_index]
                        best_amp = best_amp / norm_templates_2[best_template_index]
                        # TODO is `best_amp` value correct?
                        best_amp2 = b_array[best_template2_index, peak_index]
                        best_amp2 = best_amp2 / norm_templates_2[best_template2_index]
                        # TODO is `best_amp2` value correct?
                    else:
                        best_amp = b[best_template_index, peak_index]
                        best_amp = best_amp / norm_templates_2[best_template_index]
                        # TODO is `best_amp` value correct?
                        if two_components:
                            best_amp2 = b[best_template2_index, peak_index]
                            best_amp2 = best_amp2 / norm_templates_2[best_template2_index]
                            # TODO is `best_amp2` value correct?
                        else:
                            best_amp2 = 0.0

                    best_amp_n = best_amp
                    best_amp2_n = best_amp2

                # Verify amplitude constraint.
                a_min, a_max = amp_limits[best_template_index, :]

                if (a_min <= best_amp_n) & (best_amp_n <= a_max):
                    # Keep the matching.
                    peak_time_step = local_peaktimes[peak_index]

                    peak_data = (local_peaktimes - peak_time_step).astype(np.int32)
                    is_neighbor = np.where(np.abs(peak_data) <= temp_2_shift)[0]
                    idx_neighbor = peak_data[is_neighbor] + temp_2_shift
                    nb_neighbors = len(is_neighbor)
                    indices = np.zeros((s_over, nb_neighbors), dtype=np.int32)
                    indices[idx_neighbor, np.arange(nb_neighbors)] = 1

                    if full_gpu:
                        indices = cmt.CUDAMatrix(indices, copy_on_host=False)
                        if patch_gpu:
                            b_lines = b.get_col_slice(0, b.shape[0])
                        else:
                            b_lines = b.get_col_slice(is_neighbor[0], is_neighbor[-1]+1)
                        tmp1 = cmt.sparse_dot(c_overs[best_template_index], indices, mult=-best_amp)
                        tmp2 = cmt.sparse_dot(c_overs[best_template2_index], indices, mult=-best_amp2)
                        b_lines.add(tmp1.add(tmp2))
                        del tmp1, tmp2
                    else:
                        tmp1 = c_overs[best_template_index].multiply(-best_amp)
                        if numpy.abs(best_amp2) > min_second_component:
                            tmp1 += c_overs[best_template2_index].multiply(-best_amp2)
                        b[:, is_neighbor] += tmp1.dot(indices)

                    numerous_argmax = False

                    # Add matching to the result.
                    t_spike = all_spikes[peak_index]

                    if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]):
                        result['spiketimes'] += [t_spike]
                        result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                        result['templates'] += [best_template_index]
                    # Mark current matching as tried.
                    b[best_template_index, peak_index] = -numpy.inf
                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                        result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                        result_debug['peak_scalar_products'] += [peak_scalar_product]
                        result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [True]
                else:
                    # Reject the matching.
                    numerous_argmax = True
                    # Update failure counter of the peak.
                    failure[peak_index] += 1
                    # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted).
                    if failure[peak_index] >= total_nb_chances:
                        # Mark all the matching associated to the current peak as tried.
                        b[:, peak_index] = -numpy.inf
                        index = numpy.arange(n_tm) * nb_local_peak_times + peak_index
                    else:
                        # Mark current matching as tried.
                        b[best_template_index, peak_index] = -numpy.inf
                        index = best_template_index * nb_local_peak_times + peak_index

                    if numerous_argmax:
                        best_indices = best_indices[~numpy.in1d(best_indices, index)]

                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                        result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                        result_debug['peak_scalar_products'] += [peak_scalar_product]
                        result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [False]

                iteration_nb += 1

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if collect_all:

                for temp, spike in zip(templates_to_write, spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times                

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_thresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_thresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_thresholds_neg[bestlecs], matched_thresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())

            if debug:
                # Write debug data to debug files.
                for field_label, field_dtype, field_file in [
                    ('chunk_nbs', numpy.uint32, chunk_nbs_debug_file),
                    ('iteration_nbs', numpy.uint32, iteration_nbs_debug_file),
                    ('peak_nbs', numpy.uint32, peak_nbs_debug_file),
                    ('peak_local_time_steps', numpy.uint32, peak_local_time_steps_debug_file),
                    ('peak_time_steps', numpy.uint32, peak_time_steps_debug_file),
                    ('peak_scalar_products', numpy.float32, peak_scalar_products_debug_file),
                    ('peak_solved_flags', numpy.float32, peak_solved_flags_debug_file),
                    ('template_nbs', numpy.uint32, template_nbs_debug_file),
                    ('success_flags', numpy.bool, success_flags_debug_file),
                ]:
                    field_to_write = numpy.array(result_debug[field_label], dtype=field_dtype)
                    field_file.write(field_to_write.tostring())

            if full_gpu:
                del b, data

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()
        
        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    if debug:
        # Close debug files.
        for field_file in [
            chunk_nbs_debug_file,
            iteration_nbs_debug_file,
            peak_nbs_debug_file,
            peak_local_time_steps_debug_file,
            peak_time_steps_debug_file,
            peak_scalar_products_debug_file,
            peak_solved_flags_debug_file,
            template_nbs_debug_file,
            success_flags_debug_file,
        ]:
            field_file.flush()
            os.fsync(field_file.fileno())
            field_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    n_e = params.getint('data', 'N_e')
    n_total = params.nb_channels
    n_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    # file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    # spike_thresh = params.getfloat('detection', 'spike_thresh')
    ratio_thresh = params.getfloat('fitting', 'ratio_thresh')
    two_components = params.getboolean('fitting', 'two_components')
    sparse_threshold = params.getfloat('fitting', 'sparse_thresh')
    # spike_width = params.getfloat('detection', 'spike_width')
    # dist_peaks = params.getint('detection', 'dist_peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    templates_normalization = params.getboolean('clustering', 'templates_normalization')  # TODO test, switch, test!
    chunk_size = detect_memory(params, fitting=True)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = [float(v) for v in tmp_limits]
    amp_auto = params.getboolean('fitting', 'amp_auto')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    # noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    min_second_component = params.getfloat('fitting', 'min_second_component')
    debug = params.getboolean('fitting', 'debug')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(n_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    supports = io.load_data(params, 'supports')
    low_channels_thr = params.getint('detection', 'low_channels_thr')
    median_channels = numpy.median(numpy.sum(supports, 1))
    fixed_amplitudes = params.getboolean('clustering', 'fixed_amplitudes')

    weird_thresh = params.get('detection', 'weird_thresh')
    if weird_thresh != '':
        ignore_artefacts = True
        weird_thresh = io.load_data(params, 'weird-thresholds')
    else:
        ignore_artefacts = False

    if not fixed_amplitudes:
        nb_amp_bins = params.getint('clustering', 'nb_amp_bins')
        splits = np.linspace(0, params.data_file.duration, nb_amp_bins)
        interpolated_times = np.zeros(len(splits) - 1, dtype=numpy.float32)
        for count in range(0, len(splits) - 1):
            interpolated_times[count] = (splits[count] + splits[count + 1])/2
        interpolated_times = numpy.concatenate(([0], interpolated_times, [params.data_file.duration]))
        nb_amp_times = len(splits) + 1

    mse_error = params.getboolean('fitting', 'mse_error')
    if mse_error:
        stds = io.load_data(params, 'stds')
        stds_norm = numpy.linalg.norm(stds)
    # if median_channels < low_channels_thr:
    #     normalization = False
    #     if comm.rank == 0:
    #         print_and_log(['Templates defined on few channels (%g), turning off normalization' %median_channels], 'debug', logger)

    #################################################################

    if SHARED_MEMORY:
        templates, mpi_memory_1 = io.load_data_memshared(params, 'templates', normalize=templates_normalization, transpose=True, sparse_threshold=sparse_threshold)
        N_tm, x = templates.shape
        is_sparse = not isinstance(templates, numpy.ndarray)
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape
        if N_tm > 0:
            sparsity = templates.nnz / (x * N_tm)
            is_sparse = sparsity < sparse_threshold
        else:
            is_sparse = True
        if not is_sparse:
            if comm.rank == 0:
                print_and_log(['Templates sparsity is low (%g): densified to speedup the algorithm' %sparsity], 'debug', logger)
            templates = templates.toarray()

    temp_2_shift = 2 * template_shift
    temp_3_shift = 3 * template_shift
    n_tm = N_tm // 2
    n_scalar = n_e * n_t

    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = n_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')
    sub_norm_templates = n_scalar * norm_templates[:n_tm]
    if not templates_normalization:
        norm_templates_2 = (norm_templates ** 2.0) * n_scalar
        sub_norm_templates_2 = norm_templates_2[:n_tm]

    if not SHARED_MEMORY:
        # Normalize templates (if necessary).
        if templates_normalization:
            if is_sparse:
                for idx in range(templates.shape[1]):
                    myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1])
                    templates.data[myslice] /= norm_templates[idx]
            else:
                for idx in range(templates.shape[1]):
                    templates[:, idx] /= norm_templates[idx]
        # Transpose templates.
        templates = templates.T

    maxoverlap = io.load_data(params, 'maxoverlap')/n_scalar
    similar = np.where(maxoverlap > 0.5)

    idx = similar[0] < similar[1]
    similar = similar[0][idx], similar[1][idx]
    nb_mixtures = len(similar[0])

    waveform_neg = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_neg = None  # default assignment (for PyCharm code inspection)
    waveform_pos = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_pos = None  # default assignment (for PyCharm code inspection)
    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_thresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        if SHARED_MEMORY:
            all_dead_times, mpi_memory_3 = get_dead_times(params)
        else:
            all_dead_times = get_dead_times(params)
    else:
        all_dead_times = None  # default assignment (for PyCharm code inspection)

    thresholds = io.get_accurate_thresholds(params, ratio_thresh)

    neighbors = {}
    if collect_all:
        is_sparse = not isinstance(templates, numpy.ndarray)
        for i in range(0, n_tm):
            if is_sparse:
                tmp = templates[i, :].toarray().reshape(n_e, n_t)
            else:
                tmp = templates[i].reshape(n_e, n_t)
            if templates_normalization:
                tmp = tmp * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, axis=1) != 0.0)[0]

    #N_tm, x = templates.shape
    #sparsity_factor = templates.nnz / (N_tm * x)
    #if sparsity_factor > sparse_threshold:
    #    if comm.rank == 0:
    #        print_and_log(['Templates are not sparse enough, we densify them for'], 'default', logger)
    #    templates = templates.toarray()

    info_string = ''

    if comm.rank == 0:
        info_string = "using %d CPUs" % comm.size

    comm.Barrier()

    c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    n_over = int(numpy.sqrt(over_shape[0]))
    s_over = over_shape[1]
    s_center = s_over // 2
    # # If the number of overlaps is different from templates, we need to recompute them.
    if n_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        n_over = int(numpy.sqrt(over_shape[0]))
        s_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs, mpi_memory_2 = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"], 'default', logger)

        sys.exit(0)

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (for PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (for PyCharm code inspection)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if ignore_artefacts:
        comm.Barrier()
        arte_spiketimes_file = open(file_out_suff + '.times-%d.sata' % comm.rank, 'wb')
        comm.Barrier()
        arte_electrodes_file = open(file_out_suff + '.elec-%d.sata' % comm.rank, 'wb')
        comm.Barrier()
        arte_amplitudes_file = open(file_out_suff + '.amp-%d.sata' % comm.rank, 'wb')
        comm.Barrier()

    if mse_error:
        mse_file = open(file_out_suff + '.mses-%d.data' % comm.rank, 'wb')
        comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()
    else:
        garbage_times_file = None  # default assignment (for PyCharm code inspection)
        garbage_temp_file = None  # default assignment (for PyCharm code inspection)

    if debug:
        # Open debug files.
        chunk_nbs_debug_file = open(file_out_suff + '.chunk_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        iteration_nbs_debug_file = open(file_out_suff + '.iteration_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_nbs_debug_file = open(file_out_suff + '.peak_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_local_time_steps_debug_file = open(
            file_out_suff + '.peak_local_time_steps_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_time_steps_debug_file = open(file_out_suff + '.peak_time_steps_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_scalar_products_debug_file = open(
            file_out_suff + '.peak_scalar_products_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_solved_flags_debug_file = open(file_out_suff + '.peak_solved_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        template_nbs_debug_file = open(file_out_suff + '.template_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        success_flags_debug_file = open(file_out_suff + '.success_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
    else:
        chunk_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        iteration_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_local_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_scalar_products_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_solved_flags_debug_file = None  # default assignment (for PyCharm code inspection)
        template_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        success_flags_debug_file = None  # default assignment (for PyCharm code inspection)

    last_chunk_size = 0
    slice_indices = numpy.zeros(0, dtype=numpy.int32)

    to_explore = list(range(comm.rank, processed_chunks, comm.size))

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    if fixed_amplitudes:
        min_scalar_products = amp_limits[:,0][:, numpy.newaxis]
        max_scalar_products = amp_limits[:,1][:, numpy.newaxis]

        if templates_normalization:
            min_sps = min_scalar_products * sub_norm_templates[:, numpy.newaxis]
            max_sps = max_scalar_products * sub_norm_templates[:, numpy.newaxis]
        else:
            min_sps = min_scalar_products * sub_norm_templates_2[:, numpy.newaxis]
            max_sps = max_scalar_products * sub_norm_templates_2[:, numpy.newaxis]

    for gcount, gidx in enumerate(to_explore):
        # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        # We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift].

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if not (is_first and is_last):
            if is_last:
                padding = (-temp_3_shift, 0)
            elif is_first:
                padding = (0, temp_3_shift)
            else:
                padding = (-temp_3_shift, temp_3_shift)
        else:
            padding = (0, 0)

        result = {
            'spiketimes': [],
            'amplitudes': [],
            'templates': [],
        }

        if mse_error:
            mse_fit = {
            'spiketimes': [],
            'amplitudes': [],
            'templates': [],
        }

        result_debug = {
            'chunk_nbs': [],
            'iteration_nbs': [],
            'peak_nbs': [],
            'peak_local_time_steps': [],
            'peak_time_steps': [],
            'peak_scalar_products': [],
            'peak_solved_flags': [],
            'template_nbs': [],
            'success_flags': [],
        }

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)           
        len_chunk = len(local_chunk)
        if is_last:
            my_chunk_size = last_chunk_size
        else:
            my_chunk_size = chunk_size

        if do_spatial_whitening:
            local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # Extracting peaks.

        all_found_spikes = {}
        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = []

        local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

        if ignore_artefacts:
            artefacts_peaktimes = [numpy.zeros(0, dtype=numpy.uint32)]
            artefacts_elecs = [numpy.zeros(0, dtype=numpy.uint32)]
            artefacts_amps = [numpy.zeros(0, dtype=numpy.float32)]    

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_pos[i])[0]

                    if ignore_artefacts:
                        artetimes = scipy.signal.find_peaks(numpy.abs(filter_chunk[:, i]), height=weird_thresh[i])[0]
                        to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes))
                        peaktimes = peaktimes[to_keep]
                        artefacts_peaktimes.append(artetimes)
                        artefacts_elecs.append(i*numpy.ones(len(artetimes), dtype='uint32'))
                        artefacts_amps.append(local_chunk[artetimes, i])

                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_neg[i])[0]

                    if ignore_artefacts:
                        artetimes = scipy.signal.find_peaks(numpy.abs(filter_chunk[:, i]), height=weird_thresh[i])[0]
                        to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes))
                        peaktimes = peaktimes[to_keep]
                        artefacts_peaktimes.append(artetimes)
                        artefacts_elecs.append(i*numpy.ones(len(artetimes), dtype='uint32'))
                        artefacts_amps.append(local_chunk[artetimes, i])

                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)

            if ignore_artefacts:
                artefacts_peaktimes = numpy.concatenate(artefacts_peaktimes)
                artefacts_elecs = numpy.concatenate(artefacts_elecs)
                artefacts_amps = numpy.concatenate(artefacts_amps)
        else:
            for i in range(n_e):
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i])[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)

                if ignore_artefacts:
                    artetimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=weird_thresh[i])[0]
                    to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes))
                    peaktimes = peaktimes[to_keep]
                    artefacts_peaktimes.append(artetimes)
                    artefacts_elecs.append(i*numpy.ones(len(artetimes), dtype='uint32'))
                    artefacts_amps.append(local_chunk[artetimes, i])

                local_peaktimes.append(peaktimes)
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)

            if ignore_artefacts:
                artefacts_peaktimes = numpy.concatenate(artefacts_peaktimes)
                artefacts_elecs = numpy.concatenate(artefacts_elecs)
                artefacts_amps = numpy.concatenate(artefacts_amps)

        local_peaktimes = numpy.unique(local_peaktimes)
        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + my_chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.in1d(local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]

                if ignore_artefacts:
                    is_included = numpy.in1d(artefacts_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]])
                    artefacts_peaktimes = artefacts_peaktimes[~is_included]
                    artefacts_elecs = artefacts_elecs[~is_included]
                    artefacts_amps = artefacts_amps[~is_included]  

                local_peaktimes = numpy.sort(local_peaktimes)
        else:
            dead_indices = None  # default assignment (for PyCharm code inspection)

        # print "Removing the useless borders..."
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if ignore_artefacts:
            artefacts_peaktimes = artefacts_peaktimes + g_offset
            idx = (artefacts_peaktimes >= t_offset) & (artefacts_peaktimes < t_offset + my_chunk_size)
            artefacts_peaktimes = numpy.compress(idx, artefacts_peaktimes)
            artefacts_elecs = numpy.compress(idx, artefacts_elecs)
            artefacts_amps = numpy.compress(idx, artefacts_amps)

        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            all_found_spikes[i] + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]
                        )
                        all_found_spikes[i] = all_found_spikes[i][~is_included]
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        nb_local_peak_times = len(local_peaktimes)

        if nb_local_peak_times > 0:
            # print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."

            if collect_all or mse_error:
                c_local_chunk = local_chunk.copy()
            else:
                c_local_chunk = None  # default assignment (for PyCharm code inspection)

            sub_mat = local_chunk[local_peaktimes[:, None] + temp_window]
            sub_mat = sub_mat.transpose(2, 1, 0).reshape(size_window, nb_local_peak_times)

            del local_chunk

            b = templates.dot(sub_mat)
    
            del sub_mat

            local_restriction = (t_offset, t_offset + my_chunk_size)
            all_spikes = local_peaktimes + g_offset

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, n_e), dtype=numpy.bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in range(n_e):
                    c_all_times[all_found_spikes[i], i] = True
            else:
                c_all_times = None  # default assignment (for PyCharm code inspection)
                c_min_times = None  # default assignment (for PyCharm code inspection)
                c_max_times = None  # default assignment (for PyCharm code inspection)

            iteration_nb = 0
            data = b[:n_tm, :]

            if not fixed_amplitudes:
                amp_index = numpy.searchsorted(splits, local_restriction[0], 'right')
                scaling = 1/(splits[amp_index] - splits[amp_index - 1])
                min_scalar_products = amp_limits[:, amp_index, 0] + (amp_limits[:, amp_index, 0] - amp_limits[:, amp_index+1, 0])*scaling
                max_scalar_products = amp_limits[:, amp_index, 1] + (amp_limits[:, amp_index, 1] - amp_limits[:, amp_index+1, 0])*scaling
                    
                min_scalar_products = min_scalar_products[:, numpy.newaxis]
                max_scalar_products = max_scalar_products[:, numpy.newaxis]

                if templates_normalization:
                    min_sps = min_scalar_products * sub_norm_templates[:, numpy.newaxis]
                    max_sps = max_scalar_products * sub_norm_templates[:, numpy.newaxis]
                else:
                    min_sps = min_scalar_products * sub_norm_templates_2[:, numpy.newaxis]
                    max_sps = max_scalar_products * sub_norm_templates_2[:, numpy.newaxis]

            while True:

                is_valid = (data > min_sps)*(data < max_sps)
                valid_indices = numpy.where(is_valid)

                if len(valid_indices[0]) == 0:
                    break

                best_amplitude_idx = data[is_valid].argmax()

                best_template_index, peak_index = valid_indices[0][best_amplitude_idx], valid_indices[1][best_amplitude_idx]
                peak_scalar_product = data[is_valid][best_amplitude_idx]

                best_template2_index = best_template_index + n_tm
                if templates_normalization:
                    best_amp = b[best_template_index, peak_index] / n_scalar
                    best_amp_n = best_amp / norm_templates[best_template_index]
                    if two_components:
                        best_amp2 = b[best_template2_index, peak_index] / n_scalar
                        best_amp2_n = best_amp2 /  norm_templates[best_template2_index]
                    else:
                        best_amp2 = 0
                        best_amp2_n = 0
                else:
                    best_amp = b[best_template_index, peak_index] / norm_templates_2[best_template_index]
                    best_amp_n = best_amp
                    if two_components:     
                        best_amp2 = b[best_template2_index, peak_index] / norm_templates_2[best_template2_index]
                        best_amp2_n = best_amp2
                    else:
                        best_amp2 = 0
                        best_amp2_n = 0

                peak_time_step = local_peaktimes[peak_index]

                peak_data = (local_peaktimes - peak_time_step).astype(np.int32)
                is_neighbor = np.abs(peak_data) <= temp_2_shift
                idx_neighbor = peak_data[is_neighbor] + temp_2_shift

                tmp1 = c_overs[best_template_index].multiply(-best_amp)
                if numpy.abs(best_amp2_n) > min_second_component:
                    tmp1 += c_overs[best_template2_index].multiply(-best_amp2)

                to_add = tmp1.toarray()[:, idx_neighbor]
                b[:, is_neighbor] += to_add

                b[best_template_index, peak_index] = -numpy.inf

                # Add matching to the result.
                t_spike = all_spikes[peak_index]

                if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]):
                    result['spiketimes'] += [t_spike]
                    result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                    result['templates'] += [best_template_index]
                elif mse_error:
                    mse_fit['spiketimes'] += [t_spike]
                    mse_fit['amplitudes'] += [(best_amp_n, best_amp2_n)]
                    mse_fit['templates'] += [best_template_index]

                # Save debug data.
                if debug:
                    result_debug['chunk_nbs'] += [gidx]
                    result_debug['iteration_nbs'] += [iteration_nb]
                    result_debug['peak_nbs'] += [peak_index]
                    result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                    result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                    result_debug['peak_scalar_products'] += [peak_scalar_product]
                    result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                    result_debug['template_nbs'] += [best_template_index]
                    result_debug['success_flags'] += [True]

                iteration_nb += 1

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if ignore_artefacts:
                arte_spiketimes_file.write(artefacts_peaktimes.astype(numpy.uint32).tostring())
                arte_electrodes_file.write(artefacts_elecs.tostring())
                arte_amplitudes_file.write(artefacts_amps.tostring())

            if mse_error:
                curve = numpy.zeros((len_chunk, n_e), dtype=numpy.float32)
                for spike, temp_id, amplitude in zip(result['spiketimes'], result['templates'], result['amplitudes']):
                    spike = spike - t_offset - padding[0]
                    if is_sparse:
                        tmp1 = templates[temp_id].toarray().reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].toarray().reshape(n_e, n_t)
                    else:
                        tmp1 = templates[temp_id].reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].reshape(n_e, n_t)

                    curve[spike - template_shift:spike + template_shift + 1, :] += (amplitude[0] * tmp1 + amplitude[1] * tmp2).T

                for spike, temp_id, amplitude in zip(mse_fit['spiketimes'], mse_fit['templates'], mse_fit['amplitudes']):
                    spike = spike - t_offset + padding[0]
                    if is_sparse:
                        tmp1 = templates[temp_id].toarray().reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].toarray().reshape(n_e, n_t)
                    else:
                        tmp1 = templates[temp_id].reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].reshape(n_e, n_t)
                    try:
                        curve[int(spike) - template_shift:int(spike) + template_shift + 1, :] += (amplitude[0] * tmp1 + amplitude[1] * tmp2).T
                    except Exception:
                        pass
                mse = numpy.linalg.norm((curve - c_local_chunk)[-padding[0]:-padding[1]])
                nb_points = len(curve) - (padding[1] - padding[0])
                mse_ratio = mse/(numpy.sqrt(nb_points)*stds_norm)
                mse_to_write = numpy.array([g_offset, mse_ratio], dtype=numpy.float32)
                mse_file.write(mse_to_write.tostring())

            if collect_all:

                for temp, spike in zip(templates_to_write, spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times                

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_thresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_thresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_thresholds_neg[bestlecs], matched_thresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())

            if debug:
                # Write debug data to debug files.
                for field_label, field_dtype, field_file in [
                    ('chunk_nbs', numpy.uint32, chunk_nbs_debug_file),
                    ('iteration_nbs', numpy.uint32, iteration_nbs_debug_file),
                    ('peak_nbs', numpy.uint32, peak_nbs_debug_file),
                    ('peak_local_time_steps', numpy.uint32, peak_local_time_steps_debug_file),
                    ('peak_time_steps', numpy.uint32, peak_time_steps_debug_file),
                    ('peak_scalar_products', numpy.float32, peak_scalar_products_debug_file),
                    ('peak_solved_flags', numpy.float32, peak_solved_flags_debug_file),
                    ('template_nbs', numpy.uint32, template_nbs_debug_file),
                    ('success_flags', numpy.bool, success_flags_debug_file),
                ]:
                    field_to_write = numpy.array(result_debug[field_label], dtype=field_dtype)
                    field_file.write(field_to_write.tostring())

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()
        
        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    if mse_error:
        mse_file.flush()
        os.fsync(mse_file.fileno())
        mse_file.close()

    if ignore_artefacts:
        arte_spiketimes_file.flush()
        os.fsync(arte_spiketimes_file.fileno())
        arte_spiketimes_file.close()

        arte_electrodes_file.flush()
        os.fsync(arte_electrodes_file.fileno())
        arte_electrodes_file.close()

        arte_amplitudes_file.flush()
        os.fsync(arte_amplitudes_file.fileno())
        arte_amplitudes_file.close()

    if debug:
        # Close debug files.
        for field_file in [
            chunk_nbs_debug_file,
            iteration_nbs_debug_file,
            peak_nbs_debug_file,
            peak_local_time_steps_debug_file,
            peak_time_steps_debug_file,
            peak_scalar_products_debug_file,
            peak_solved_flags_debug_file,
            template_nbs_debug_file,
            success_flags_debug_file,
        ]:
            field_file.flush()
            os.fsync(field_file.fileno())
            field_file.close()

    comm.Barrier()

    if SHARED_MEMORY:
        for memory in mpi_memory_1 + mpi_memory_2:
            memory.Free()
        if ignore_dead_times:
            mpi_memory_3.Free()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

        if ignore_artefacts:
            io.collect_artefacts(comm.size, params, erase=True)

    data_file.close()
Exemple #16
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    """
    Run the filtering step.

    Depending on the parameters, the data are band-pass filtered, the median
    over channels and/or a common ground channel are subtracted, saturated
    samples are recorded and zeroed, and stimulation artefacts are removed.
    The ``*_done`` flags of the ``noedits`` section are checked first so that
    steps already completed are skipped, and updated once a step has run.

    Parameters
    ----------
    params :
        Parameter object of the run (configuration values, probe layout and
        access to the data files).
    nb_cpu, nb_gpu, use_gpu :
        Not referenced by this function.
    """

    # The value returned by init_logging was immediately overwritten in the
    # original code; only the named logger below is used.
    init_logging(params.logfile)
    logger = logging.getLogger('circus.filtering')
    #################################################################
    do_filter = params.getboolean('filtering', 'filter')
    filter_done = check_if_done(params, 'filter_done', logger)
    artefacts_done = check_if_done(params, 'artefacts_done', logger)
    median_done = check_if_done(params, 'median_done', logger)
    ground_done = check_if_done(params, 'ground_done', logger)
    file_out_suff = params.get('data', 'file_out_suff')
    clean_artefact = params.getboolean('triggers', 'clean_artefact')
    remove_median = params.getboolean('filtering', 'remove_median')
    sat_value = params.get('filtering', 'sat_value')
    sat_threshold = params.get('filtering', 'sat_threshold')
    # An empty sat_value disables the saturation detection.
    flag_saturation = sat_value != ''
    if flag_saturation:
        sat_value = float(sat_value)

    common_ground = params.common_ground
    remove_ground = len(common_ground) > 0
    nodes, edges = get_nodes_and_edges(params)
    N_total = params.nb_channels
    # inv_nodes maps a raw channel index to its position inside `nodes`.
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    #################################################################

    def filter_file(data_file_in, data_file_out, do_filtering, do_remove_median, do_remove_ground):
        """
        Process the data chunk by chunk: band-pass Butterworth filtering and
        optional subtraction of the median and/or of the common ground.

        When saturation detection is enabled, the saturated samples of each
        chunk are written to per-rank files and set to zero in the output.

        Parameters
        ----------

        data_file_in :
            Data file the chunks are read from.

        data_file_out :
            Data file the processed chunks are written to (can be the same
            object as `data_file_in`).

        do_filtering : bool
            If True, filter each chunk and subtract its per-channel median.

        do_remove_median : bool
            If True, subtract from every time step the median over channels
            (per channel group when the probe has several of them).

        do_remove_ground : bool
            If True, subtract the common ground channel (per channel group
            when the probe has several of them).
        """

        try:
            cut_off = params.getfloat('filtering', 'cut_off', check=False)
            cut_off = [cut_off, 0.95 * (params.rate / 2.0)]  # Nyquist
        except Exception:
            # Not a single float: expect "low, high" where high can be 'auto'.
            cut_off = params.get('filtering', 'cut_off', check=False)
            cut_off = cut_off.split(',')
            try:
                cut_off[0] = float(cut_off[0])
            except Exception:
                if comm.rank == 0:
                    print_and_log(['First value of cut off must be a valid number'], 'error', logger)
                sys.exit(0)

            cut_off[1] = cut_off[1].replace(' ', '')
            if cut_off[1] == 'auto':
                cut_off[1] = 0.95 * (params.rate / 2.0)
            else:
                try:
                    cut_off[1] = float(cut_off[1])
                except Exception:
                    if comm.rank == 0:
                        print_and_log(['Second value of cut off must either auto, or a valid a number'], 'error', logger)
                    sys.exit(0)

        # Reject an unknown threshold mode up front; previously it only
        # surfaced as an unbound `indices` variable inside the chunk loop.
        if flag_saturation and sat_threshold not in ('negative', 'positive', 'both'):
            if comm.rank == 0:
                print_and_log(['sat_threshold must be either negative, positive or both'], 'error', logger)
            sys.exit(0)

        chunk_size = detect_memory(params, filtering=True)
        butter_order = params.getint('filtering', 'butter_order')
        nb_chunks, _ = data_file_in.analyze(chunk_size)

        # Cut-off frequencies are normalised by the Nyquist frequency.
        b, a = signal.butter(butter_order, numpy.array(cut_off)/(params.rate/2.), 'pass')
        N_total = params.nb_channels
        nb_shanks = len(params.probe['channel_groups'])

        if nb_shanks > 1:
            shank_channels = {}
            for i in list(params.probe['channel_groups'].keys()):
                shank_channels[i] = numpy.array(params.probe['channel_groups'][i]['channels'], dtype=numpy.int32)
        else:
            channel_group = list(params.probe['channel_groups'].keys())[0]

        process_all_channels = numpy.all(nodes == numpy.arange(N_total))
        # Number of samples of padding added around a chunk (0.1 s of data).
        duration = int(0.1*params.rate)

        if comm.rank == 0:
            to_write = []
            if do_filtering:
                to_write += ["Filtering with a Butterworth filter (order %d) in [%g, %g] Hz" % (butter_order, cut_off[0], cut_off[1])]
            if do_remove_median:
                to_write += ["Median over all channels is subtracted to each channels"]
            if do_remove_ground:
                to_write += ["Channels %s are used as reference channels in respective shanks" % common_ground]

            print_and_log(to_write, 'default', logger)

        # Chunks are distributed round-robin over the MPI ranks.
        to_explore = list(range(comm.rank, nb_chunks, comm.size))

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(params, to_explore)

        if data_file_in == data_file_out:
            data_file_in.open(mode='r+')
        else:
            data_file_in.open(mode='r')
            data_file_out.open(mode='r+')

        if flag_saturation:
            comm.Barrier()
            saturation_times = open(file_out_suff + '.times-%d.data' % comm.rank, 'wb')
            saturation_channels = open(file_out_suff + '.channels-%d.data' % comm.rank, 'wb')
            saturation_values = open(file_out_suff + '.values-%d.data' % comm.rank, 'wb')

            if data_file_in.data_dtype in ['float32', numpy.float32, 'float64', numpy.float64]:
                max_value = numpy.finfo(data_file_in.data_dtype).max
            else:
                max_value = numpy.iinfo(data_file_in.data_dtype).max - data_file_in.dtype_offset

            saturation = sat_value * max_value

        for gidx in to_explore:

            is_first = data_file_in.is_first_chunk(gidx, nb_chunks)
            is_last = data_file_in.is_last_chunk(gidx, nb_chunks)

            if not (is_first and is_last):
                if is_first:
                    padding = (0, duration)
                elif is_last:
                    padding = (-duration, 0)
                else:
                    padding = (-duration, duration)
            else:
                padding = (0, 0)

            # Number of padded rows before / after the samples of the chunk.
            pad_before = numpy.abs(padding[0])
            pad_after = numpy.abs(padding[1])

            local_chunk, t_offset = data_file_in.get_data(gidx, chunk_size, padding)

            len_chunk = len(local_chunk)

            if flag_saturation:

                if sat_threshold == 'negative':
                    indices = numpy.where(local_chunk <= saturation * data_file_in.gain)
                elif sat_threshold == 'positive':
                    indices = numpy.where(local_chunk >= saturation * data_file_in.gain)
                else:
                    # 'both' (the value was validated before the loop).
                    indices = numpy.where(numpy.abs(local_chunk) >= saturation * data_file_in.gain)

                if not process_all_channels:
                    to_keep = numpy.in1d(indices[1], nodes)
                    channels = inv_nodes[indices[1][to_keep]]
                    times = indices[0][to_keep]
                else:
                    channels = indices[1]
                    times = indices[0]
                # NOTE(review): when not all channels are processed, `channels`
                # holds positions inside `nodes` but is used below to index
                # columns of local_chunk -- confirm this is intended.

                # Only keep the hits outside of the padding. The lower bound
                # must be the absolute padding (padding[0] is <= 0), as in the
                # slice applied to local_chunk below.
                to_keep = (times >= pad_before) & (times < (len_chunk - pad_after))
                sub_times = times[to_keep]
                sub_channels = channels[to_keep]

                # The un-padded chunk is the one written at t_offset, so row
                # indices of the padded chunk are shifted back by pad_before.
                saturation_times.write((sub_times - pad_before + t_offset).astype(numpy.uint32).tobytes())
                saturation_channels.write(sub_channels.astype(numpy.uint32).tobytes())
                saturation_values.write(local_chunk[sub_times, sub_channels].tobytes())

            if do_filtering:
                if not process_all_channels:
                    local_chunk[:, nodes] = signal.filtfilt(b, a, local_chunk[:, nodes], axis=0)
                    local_chunk[:, nodes] -= numpy.median(local_chunk[:, nodes], 0)
                else:
                    local_chunk = signal.filtfilt(b, a, local_chunk, axis=0)
                    local_chunk -= numpy.median(local_chunk, 0)

            if flag_saturation:
                # Saturated samples are zeroed after the filtering.
                local_chunk[times, channels] = 0

            # Drop the padding before writing the chunk back.
            local_chunk = local_chunk[pad_before:len_chunk - pad_after]

            if do_remove_median:
                if nb_shanks == 1:
                    if not process_all_channels:
                        global_median = numpy.median(numpy.take(local_chunk, nodes, axis=1), 1)
                    else:
                        global_median = numpy.median(local_chunk, 1)
                    local_chunk -= global_median[:, numpy.newaxis]
                else:
                    for i in list(params.probe['channel_groups'].keys()):
                        global_median = numpy.median(numpy.take(local_chunk, shank_channels[i], axis=1), 1)
                        local_chunk[:, shank_channels[i]] -= global_median[:, numpy.newaxis]

            if do_remove_ground:
                if nb_shanks == 1:
                    ground = local_chunk[:, common_ground[channel_group]]
                    local_chunk -= ground[:, numpy.newaxis]
                else:
                    for i in list(params.probe['channel_groups'].keys()):
                        ground = local_chunk[:, common_ground[i]]
                        local_chunk[:, shank_channels[i]] -= ground[:, numpy.newaxis]

            if data_file_in != data_file_out:
                if data_file_in.is_stream:
                    g_offset = t_offset
                else:
                    g_offset = t_offset - data_file_in.t_start
            else:
                g_offset = t_offset

            data_file_out.set_data(g_offset, local_chunk)

        sys.stderr.flush()

        comm.Barrier()

        if flag_saturation:
            saturation_times.flush()
            os.fsync(saturation_times.fileno())
            saturation_times.close()

            saturation_values.flush()
            os.fsync(saturation_values.fileno())
            saturation_values.close()

            saturation_channels.flush()
            os.fsync(saturation_channels.fileno())
            saturation_channels.close()

            # We need to gather the results into a single file

            if comm.rank == 0:
                collect_saturation(comm.size, params, erase=True)

        sys.stderr.flush()
        comm.Barrier()

    def compute_artefacts(data_file):
        """
        Compute the artefacts based on the [triggers] section of the params file.

        Parameters
        ----------
        data_file :
            Data file; its `sampling_rate` and `t_stop` attributes are read.

        Return
        ------
        dict
            For every artefact label handled by this rank, the value returned
            by `get_artefact` for (at most 500) of its stimulation times.
        """

        trig_in_ms = params.getboolean('triggers', 'trig_in_ms')
        artefacts = numpy.loadtxt(params.get('triggers', 'trig_file'), comments=['#', '//'])
        windows = numpy.loadtxt(params.get('triggers', 'trig_windows'), comments=['#', '//'])
        make_plots = params.get('triggers', 'make_plots')
        plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')

        # A file with a single row is loaded as a 1D array.
        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if trig_in_ms:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug', logger)
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate*1e-3)
            windows[:, 1] *= numpy.int64(data_file.sampling_rate*1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'], 'debug', logger)

        artefacts = artefacts.astype(numpy.int64)
        windows = windows.astype(numpy.int64)

        nb_stimuli = len(numpy.unique(artefacts[:, 0]))
        mytest = numpy.all(numpy.in1d(numpy.unique(artefacts[:, 0]), numpy.unique(windows[:, 0])))

        if not mytest:
            if comm.rank == 0:
                print_and_log(['Error in the trigger file: not all artefacts are defined'], 'error', logger)
            sys.exit(0)

        all_labels = artefacts[:, 0]
        all_times = artefacts[:, 1]
        # Only keep triggers whose longest window fits inside the recording.
        mask = (all_times >= 0) & (all_times + numpy.max(windows[:, 1]) < data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        # Labels are distributed round-robin over the MPI ranks.
        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        if comm.rank == 0:
            to_write = ["Computing averaged artefacts from %d stimuli" % nb_stimuli]
            print_and_log(to_write, 'default', logger)
            if not os.path.exists(plot_path):
                os.makedirs(plot_path)
            local_labels = get_tqdm_progressbar(params, local_labels)

        comm.Barrier()
        # First we need to get the average artefacts
        art_dict = {}
        for artefact in local_labels:
            indices = numpy.where(all_labels == artefact)[0].astype(numpy.uint32)
            tmp = numpy.where(windows[:, 0] == artefact)[0]
            tau = numpy.int64(windows[tmp, 1])
            pspikes = all_times[indices]
            # At most 500 randomly chosen occurrences are used.
            times = numpy.sort(numpy.random.permutation(pspikes)[:500])
            if len(numpy.where(numpy.diff(times) < tau)[0]) > 0:
                if comm.rank == 0:
                    print_and_log(['Stimulation times for artefact %d are too close!' % artefact], 'error', logger)
                sys.exit(0)

            art_dict[artefact] = get_artefact(params, times, tau, nodes)
            if make_plots not in ['None', '']:
                save = [plot_path, '%d.%s' % (artefact, make_plots)]
                plot.view_artefact(art_dict[artefact], save=save)

        sys.stderr.flush()
        return art_dict

    def remove_artefacts(data_file, art_dict):
        """
        Subtract the artefacts from the data at the times given in the
        [triggers] section of the params file.

        Parameters
        ----------

        data_file :
            Data file modified in place through `get_snippet` / `set_data`.

        art_dict : dict
            For every artefact label, a 2D array indexed as [channel, time]
            that is subtracted from the data (as built by `compute_artefacts`).
        """

        trig_in_ms = params.getboolean('triggers', 'trig_in_ms')
        artefacts = numpy.loadtxt(params.get('triggers', 'trig_file'), comments=['#', '//'])
        windows = numpy.loadtxt(params.get('triggers', 'trig_windows'), comments=['#', '//'])

        # A file with a single row is loaded as a 1D array.
        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if trig_in_ms:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug', logger)
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
            windows[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'], 'debug', logger)

        artefacts = artefacts.astype(numpy.int64)
        windows = windows.astype(numpy.int64)
        nb_stimuli = len(numpy.unique(artefacts[:, 0]))
        mytest = numpy.all(numpy.in1d(numpy.unique(artefacts[:, 0]), numpy.unique(windows[:, 0])))

        if not mytest:
            if comm.rank == 0:
                print_and_log(['Error in the trigger files: not all artefacts are defined'], 'error', logger)
            sys.exit(0)

        all_labels = artefacts[:, 0]
        all_times = artefacts[:, 1]
        # NOTE(review): compute_artefacts distributes the labels after masking
        # the times, here it is done before; if a label is entirely masked
        # there, the per-rank assignment may differ -- confirm art_dict always
        # holds the labels used below.
        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        mask = numpy.in1d(all_labels, local_labels)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        mask = (all_times >= 0) & (all_times < data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        if comm.rank == 0:
            to_write = ["Removing artefacts from %d stimuli" % nb_stimuli]
            print_and_log(to_write, 'default', logger)
            all_times = get_tqdm_progressbar(params, all_times)

        comm.Barrier()

        for count, time in enumerate(all_times):

            label = all_labels[count]
            tmp = numpy.where(windows[:, 0] == label)[0][0]
            tau = numpy.int64(windows[tmp, 1])

            # Truncate the window when it would run past the end of the
            # recording (the original code used an undefined `max_offset`).
            if (data_file.t_stop - time) < tau:
                tau = data_file.t_stop - time

            local_chunk = data_file.get_snippet(time, tau)
            for idx, i in enumerate(nodes):
                local_chunk[:, i] -= art_dict[label][idx, :tau]
            data_file.set_data(time, local_chunk)

        comm.Barrier()
        sys.stderr.flush()

    if comm.rank == 0:
        print_and_log(['Initializing the filtering step...'], 'debug', logger)

    if params.getboolean('data', 'overwrite'):
        # The input file is filtered in place.
        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)

        data_file_in = params.get_data_file()
        data_file_out = data_file_in
    else:
        if comm.rank == 0:
            print_and_log(['Overwrite is set to False, so creating a new datafile...'], 'debug', logger)
            print_and_log(['Reading the input file...'], 'debug', logger)

        has_been_created = os.path.exists(params.get('data', 'data_file_no_overwrite'))

        if not has_been_created and (filter_done or median_done or artefacts_done):
            if comm.rank == 0:
                print_and_log(['The filtering is done but file not present. See no_edits section'], 'error', logger)
            sys.exit(0)

        # Read from the source file only while the output does not exist yet.
        data_file_in = params.get_data_file(source=not has_been_created, has_been_created=has_been_created)

        if comm.rank == 0:
            print_and_log(['Reading the output file and allocating ressources...'], 'debug', logger)

        # The output file is described as float32 without offsets.
        description = data_file_in.get_description()
        description['data_dtype'] = 'float32'
        description['dtype_offset'] = 0
        description['data_offset'] = 0

        comm.Barrier()
        data_file_out = params.get_data_file(is_empty=not has_been_created, params=description)

        if comm.rank == 0:
            print_and_log(['Allocating space for filtered files...'], 'debug', logger)

        if not has_been_created:
            data_file_out.allocate(shape=data_file_in.shape)

        comm.Barrier()

    if clean_artefact:
        if not (os.path.exists(params.get('triggers', 'trig_file')) and os.path.exists(params.get('triggers', 'trig_windows'))):
            if comm.rank == 0:
                print_and_log(['trig_file or trig_windows file can not be found'], 'error', logger)
            sys.exit(0)

    to_write = []

    # Steps flagged as done in the noedits section are not run again.
    if do_filter and filter_done:
        do_filter = False
        to_write += ["Filtering has already been done"]
    if remove_median and median_done:
        remove_median = False
        to_write += ["Median over all channels has already been removed"]
    if remove_ground and ground_done:
        remove_ground = False
        to_write += ["Common ground %s has alread been subtracted" % common_ground]

    if comm.rank == 0 and len(to_write) > 0:
        print_and_log(to_write, 'info', logger)

    if params.getboolean('data', 'overwrite'):
        data_file_in.open(mode='r+')
    else:
        data_file_in.open(mode='r')
        data_file_out.open(mode='r+')

    if do_filter or remove_median or remove_ground:
        # Flag the steps as 'Started' so that an interrupted run is detectable.
        if comm.rank == 0:
            if do_filter:
                params.write('noedits', 'filter_done', 'Started')
            if remove_median:
                params.write('noedits', 'median_done', 'Started')
            if remove_ground:
                params.write('noedits', 'ground_done', 'Started')
        filter_file(data_file_in, data_file_out, do_filter, remove_median, remove_ground)

    if comm.rank == 0:
        if do_filter:
            params.write('noedits', 'filter_done', 'True')
        if remove_median:
            params.write('noedits', 'median_done', 'True')
        if remove_ground:
            params.write('noedits', 'ground_done', 'True')

    if clean_artefact and artefacts_done:
        clean_artefact = False
        if comm.rank == 0:
            print_and_log(['Artefacts have already been removed'], 'debug', logger)

    if clean_artefact:
        art_dict = compute_artefacts(data_file_in)
        if comm.rank == 0:
            params.write('noedits', 'artefacts_done', 'Started')
        remove_artefacts(data_file_out, art_dict)

    if comm.rank == 0:
        if clean_artefact:
            params.write('noedits', 'artefacts_done', 'True')

    data_file_in.close()
    if not params.getboolean('data', 'overwrite'):
        data_file_out.close()

    comm.Barrier()
def extract_extra_spikes_(params):
    """Detect spikes from the extracellular traces.

    Every MPI rank scans its share of the data chunks
    (``range(comm.rank, nb_chunks, comm.size)``), thresholds each electrode
    at ``spike_thresh`` times its median absolute deviation and keeps the
    peaks that pass a spatio-temporal exclusion mask. The results are
    gathered on rank 0, which stores them, together with the medians and
    MADs returned by ``extract_extra_thresholds``, in
    ``<file_out_suff>.beer.hdf5`` under ``extra_medians``, ``extra_mads``,
    ``extra_spiketimes/elec_<i>`` and ``extra_spike_values/elec_<i>``.

    Parameters
    ----------
    params
        Project configuration object. The function reads ``data_file`` and
        the ``detection``, ``whitening``, ``clustering`` and ``data``
        sections through ``get``/``getint``/``getfloat``/``getboolean``.

    Returns
    -------
    None
    """

    data_file = params.data_file
    data_file.open()
    dist_peaks = params.getint('detection', 'dist_peaks')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    template_shift = params.getint('detection', 'template_shift')
    alignment = params.getboolean('detection', 'alignment')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    safety_time = params.getint('whitening', 'safety_time')
    safety_space = params.getboolean('clustering', 'safety_space')
    chunk_size = params.getint('data', 'chunk_size')
    file_out_suff = params.get('data', 'file_out_suff')
    beer_path = "{}.beer.hdf5".format(file_out_suff)

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    nodes, _ = get_nodes_and_edges(params)
    N_elec = params.getint('data', 'N_e')

    extra_medians, extra_mads = extract_extra_thresholds(params)

    if comm.rank == 0:
        # Save medians and median absolute deviations to BEER file,
        # replacing any dataset left over from a previous run.
        with h5py.File(beer_path, 'a', libver='earliest') as beer_file:
            for key, data in (("extra_medians", extra_medians),
                              ("extra_mads", extra_mads)):
                if key in beer_file:
                    del beer_file[key]
                beer_file.create_dataset(key, data=data)

    def extract_chunk_spikes(gidx, extra_thresh, valley=True):
        """Detect spikes from a chunk of the extracellular traces.

        Parameters
        ----------
        gidx
            Index of the chunk, forwarded to ``data_file.get_data``.
        extra_thresh
            Multiplier applied to each electrode's MAD to get its
            detection threshold.
        valley
            If true, detect negative deflections (peaks of the negated
            trace); otherwise detect positive peaks.

        Returns
        -------
        tuple
            ``(times, electrodes, values)`` for the retained peaks, where
            ``times`` already includes the chunk offset ``t_offset``.
        """

        loc_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes)
        loc_shape = len(loc_chunk)

        # Whiten signal.
        if do_spatial_whitening:
            loc_chunk = numpy.dot(loc_chunk, spatial_whitening)
        if do_temporal_whitening:
            loc_chunk = scipy.ndimage.filters.convolve1d(loc_chunk,
                                                         temporal_whitening,
                                                         axis=0,
                                                         mode='constant')

        # TODO decide whether the traces should be centered here by
        # removing the per-electrode medians of the chunk.

        # Pre-allocation for results.
        peak_times = N_elec * [None]
        peak_channels = N_elec * [None]
        # For each electrode.
        for e in range(N_elec):
            # Extract the peaks of the current chunk. In valley mode the
            # trace is negated so that troughs become peaks.
            threshold = extra_thresh * extra_mads[e]
            trace = -loc_chunk[:, e] if valley else loc_chunk[:, e]
            peak_times[e] = scipy.signal.find_peaks(trace,
                                                    height=threshold,
                                                    distance=dist_peaks)[0]
            peak_channels[e] = e * numpy.ones(peak_times[e].size,
                                              dtype='int64')

            # Discard peaks whose amplitude exceeds ten times the threshold.
            peak_values = loc_chunk[peak_times[e], e]
            if valley:
                peak_indices = numpy.where(-10.0 * threshold <= peak_values)[0]
            else:
                peak_indices = numpy.where(peak_values <= +10.0 * threshold)[0]
            peak_times[e] = peak_times[e][peak_indices]
            peak_channels[e] = peak_channels[e][peak_indices]

        peak_times = numpy.concatenate(peak_times)
        peak_channels = numpy.concatenate(peak_channels)
        # Remove the useless borders.
        if alignment:
            loc_borders = (2 * template_shift, loc_shape - 2 * template_shift)
        else:
            loc_borders = (template_shift, loc_shape - template_shift)
        peak_flags = (loc_borders[0] <= peak_times) & (peak_times <
                                                       loc_borders[1])
        peak_times = numpy.compress(peak_flags, peak_times)
        peak_channels = numpy.compress(peak_flags, peak_channels)
        # Filter unique peak times.
        loc_peak_times = numpy.unique(peak_times)
        n_times = len(loc_peak_times)
        loc_peak_flags = numpy.zeros(n_times, dtype='bool')
        loc_peak_elecs = numpy.zeros(n_times, dtype='int64')
        loc_peak_values = numpy.zeros(n_times, dtype='float32')

        if n_times > 0:
            diff_times = loc_peak_times[-1] - loc_peak_times[0]
            # Occupancy mask: True where a retained peak already blocks
            # an electrode at a given (relative) time step.
            all_times = numpy.zeros((N_elec, diff_times + 1), dtype='bool')
            min_times = numpy.maximum(
                loc_peak_times - loc_peak_times[0] - safety_time, 0)
            max_times = numpy.minimum(
                loc_peak_times - loc_peak_times[0] + safety_time + 1,
                diff_times)
            # Visit the peaks from the strongest to the weakest: most
            # negative first in valley mode, most positive first otherwise.
            if valley:
                loc_peak_values[:] = numpy.amin(loc_chunk[loc_peak_times, :],
                                                axis=1)
                argmax_peak = numpy.argsort(loc_peak_values)
            else:
                loc_peak_values[:] = numpy.amax(loc_chunk[loc_peak_times, :],
                                                axis=1)
                argmax_peak = numpy.argsort(loc_peak_values)[::-1]
            all_indices = loc_peak_times[argmax_peak]
            # Select peaks with spatio-temporal masks.
            for peak_index, peak_time in zip(argmax_peak, all_indices):
                # Select the electrode showing the extreme amplitude.
                if valley:
                    elec = numpy.argmin(loc_chunk[peak_time, :])
                else:
                    elec = numpy.argmax(loc_chunk[peak_time, :])
                _, neighs = get_neighbors(params, chan=elec)
                window = slice(min_times[peak_index], max_times[peak_index])
                masked = neighs if safety_space else elec
                mslice = all_times[masked, window]
                # Only keep the peak if it was actually detected on the
                # electrode carrying the extreme value at that time.
                is_local_min = (elec in peak_channels[peak_times == peak_time])
                if is_local_min and not mslice.any():
                    loc_peak_flags[peak_index] = True
                    loc_peak_elecs[peak_index] = elec
                    if valley:
                        loc_peak_values[peak_index] = -loc_chunk[peak_time,
                                                                 elec]
                    else:
                        loc_peak_values[peak_index] = loc_chunk[peak_time,
                                                                elec]
                    all_times[masked, window] = True
        loc_peak_times = numpy.compress(loc_peak_flags, loc_peak_times)
        loc_peak_elecs = numpy.compress(loc_peak_flags, loc_peak_elecs)
        loc_peak_values = numpy.compress(loc_peak_flags, loc_peak_values)

        return loc_peak_times + t_offset, loc_peak_elecs, loc_peak_values

    comm.Barrier()

    # Distribute chunks over CPUs.
    all_chunks = numpy.arange(nb_chunks)
    loc_all_chunks = all_chunks[comm.rank::comm.size]

    if comm.rank == 0:
        print_and_log(["Collecting extracellular spikes..."],
                      level='default',
                      logger=logger)

    to_explore = list(range(comm.rank, nb_chunks, comm.size))

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    extra_valley = True

    # Pre-allocation for results.
    times = len(loc_all_chunks) * [None]
    channels = len(loc_all_chunks) * [None]
    values = len(loc_all_chunks) * [None]

    # NOTE(review): the data file was already opened at the top of this
    # function; confirm that opening it a second time is harmless.
    data_file.open()
    # For each chunk attributed to the current CPU.
    for (count, gidx) in enumerate(to_explore):
        gidx = all_chunks[gidx]
        time, channel, value = extract_chunk_spikes(gidx,
                                                    spike_thresh,
                                                    valley=extra_valley)
        times[count] = time
        channels[count] = channel
        values[count] = value

    sys.stderr.flush()
    # Concatenate times, channels and values. A rank that received no
    # chunk contributes empty arrays (numpy.hstack rejects an empty list).
    if times:
        times = numpy.hstack(times)
        channels = numpy.hstack(channels)
        values = numpy.hstack(values)
    else:
        times = numpy.zeros(0, dtype='int64')
        channels = numpy.zeros(0, dtype='int64')
        values = numpy.zeros(0, dtype='float32')

    data_file.close()
    comm.Barrier()

    # Gather times, channels and values.
    times = gather_array(times.astype(numpy.int64), comm, 0, dtype='int64')
    channels = gather_array(channels.astype(numpy.int64),
                            comm,
                            0,
                            dtype='int64')
    values = gather_array(values.astype(numpy.float32),
                          comm,
                          0,
                          dtype='float32')

    if comm.rank == 0:
        # Sort times, channels and values according to time.
        idx = numpy.argsort(times)
        times = times[idx]
        channels = channels[idx]
        values = values[idx]

        msg = [
            "Total number of extracellular spikes extracted: {}".format(
                channels.size),
        ]
        msg2 = [
            "Number of extracellular spikes extracted on channel {}: {}".
            format(i, channels[channels == i].size)
            for i in numpy.arange(N_elec)
        ]
        print_and_log(msg, level='info', logger=logger)
        print_and_log(msg2, level='debug', logger=logger)

        # Store one dataset per electrode in each group, replacing any
        # group left over from a previous run.
        with h5py.File(beer_path, 'a', libver='earliest') as beer_file:
            for group_name, gathered in (("extra_spiketimes", times),
                                         ("extra_spike_values", values)):
                if group_name in beer_file:
                    del beer_file[group_name]
                beer_file.create_group(group_name)
                for i in numpy.arange(0, N_elec):
                    mask = (channels == i)
                    beer_file.create_dataset(
                        "{}/elec_{}".format(group_name, i),
                        data=gathered[mask])

    comm.Barrier()

    return
Exemple #18
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Extract templates from already found clusters.

    The function collects, for every cluster handled by this MPI rank
    (``cluster % comm.size == comm.rank``), up to ``max_elts`` waveform
    snippets around the known spike times, projects them with the loaded
    basis, derives a first component (median) and a second component (PCA
    of the residuals) per cluster, and writes the resulting sparse
    templates, norms, electrodes and amplitude limits to
    ``<file_out_suff>.templates.hdf5``. It then runs the merging steps of
    ``algo`` and recomputes the overlaps.

    Parameters
    ----------
    params
        Project configuration object (``data_file``, ``logfile``,
        ``nb_channels`` and the ``get*`` accessors are read).
    nb_cpu, nb_gpu
        Only used to choose the GPU device id when ``use_gpu`` is true.
    use_gpu
        If true, ``cudamat`` is imported and used for spatial whitening.

    Returns
    -------
    None

    Notes
    -----
    NOTE(review): ``tmpidx``, ``shift``, ``to_write``, ``ielec``,
    ``loc_norms``, ``n_clusters`` and ``remove_mixture`` are never
    assigned inside this function. They are presumably expected at module
    level; verify before relying on this code path.
    """
    numpy.random.seed(426236)
    #params         = detect_memory(params)
    parallel_hdf5 = get_parallel_hdf5_flag(params)
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.extracting')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    # The section is 'detection' (it was misspelled 'detecton').
    N_t = params.getint('detection', 'N_t')
    N_total = params.nb_channels
    template_shift = params.getint('detection', 'template_shift')
    chunk_size = params.getint('data', 'chunk_size')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('extracting', 'safety_time')
    max_elts_temp = params.getint('extracting', 'max_elts')
    output_dim = params.getfloat('extracting', 'output_dim')
    noise_thr = params.getfloat('extracting', 'noise_thr')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    blosc_compress = params.getboolean('data', 'blosc_compress')
    tmp_limits = params.get('fitting', 'amp_limits')
    tmp_limits = tmp_limits.replace('(', '').replace(')', '').split(',')
    # Materialise the values: amp_limits[1] is indexed below, which a lazy
    # Python 3 ``map`` object does not support.
    amp_limits = list(map(float, tmp_limits))
    elt_count = 0
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if comm.rank == 0:
        print_and_log(["Extracting templates from already found clusters..."],
                      'default', logger)

    thresholds = io.load_data(params, 'thresholds')
    basis_proj, basis_rec = io.load_data(params, 'basis')
    clusters, spiketimes, N_clusters = io.load_data(params, 'spike-cluster')
    # Map the original cluster labels onto consecutive indices.
    inv_clusters = numpy.zeros(clusters.max() + 1, dtype=numpy.int32)
    inv_clusters[numpy.unique(clusters)] = numpy.argsort(
        numpy.unique(clusters))

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    # Per-cluster accumulators for the projected snippets and their times.
    result = {}
    for i in range(N_clusters):
        result['data_tmp_' + str(i)] = numpy.zeros(
            (0, N_e * basis_proj.shape[1]), dtype=numpy.float32)
        result['times_' + str(i)] = numpy.zeros(0, dtype=numpy.int32)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(numpy.arange(nb_chunks))

    nb_templates = numpy.sum(
        comm.rank == numpy.mod(numpy.arange(N_clusters), comm.size))
    nb_elts = max_elts_temp * nb_templates

    to_explore = all_chunks

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    # Iterate over ``to_explore`` so that the progress bar created on
    # rank 0 is actually driven by the loop.
    for gidx in to_explore:

        if (elt_count < nb_elts):
            #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            #print "Extracting the peaks..."
            chunk_idx = numpy.where((spiketimes >= gidx * chunk_size)
                                    & (spiketimes <
                                       (gidx + 1) * chunk_size))[0]
            local_offset = t_offset
            local_peaktimes = spiketimes[chunk_idx] - local_offset

            #print "Removing the useless borders..."
            local_borders = (template_shift, chunk_size - template_shift)
            keep = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                            local_borders[1])
            local_peaktimes = local_peaktimes[keep]
            # ``keep`` is a mask over the spikes of this chunk only, so it
            # must be applied to the chunk's cluster labels rather than to
            # the full ``clusters`` array.
            local_clusters = inv_clusters[clusters[chunk_idx][keep]]

            if len(local_peaktimes) > 0:
                all_times = numpy.zeros(
                    (N_e, local_peaktimes[-1] - local_peaktimes[0] + 1),
                    dtype=bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    local_peaktimes[-1] - local_peaktimes[0])

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                clusters_id = local_clusters[argmax_peak]
                local_peaktimes = local_peaktimes[argmax_peak]

                #print "Selection of the peaks with spatio-temporal masks..."
                for idx in range(len(local_peaktimes)):

                    if elt_count == nb_elts:
                        break

                    temp = clusters_id[idx]

                    # Each rank only collects the clusters assigned to it.
                    if numpy.mod(temp, comm.size) == comm.rank:

                        elec = numpy.argmin(local_chunk[local_peaktimes[idx]])
                        indices = inv_nodes[edges[nodes[elec]]]
                        myslice = all_times[indices,
                                            min_times[idx]:max_times[idx]]
                        peak = local_peaktimes[idx]
                        if not myslice.any():
                            if (len(result['data_tmp_' + str(temp)]) <
                                    max_elts_temp):
                                elt_count += 1
                                sub_mat = local_chunk[peak -
                                                      template_shift:peak +
                                                      template_shift + 1, :]
                                sub_mat = numpy.dot(basis_rec, sub_mat)
                                nx, ny = sub_mat.shape
                                sub_mat = sub_mat.reshape((1, nx * ny))
                                result['data_tmp_' + str(temp)] = numpy.vstack(
                                    (result['data_tmp_' + str(temp)], sub_mat))
                                to_add = numpy.array([peak + local_offset],
                                                     dtype=numpy.int32)
                                result['times_' +
                                       str(temp)] = numpy.concatenate(
                                           (result['times_' + str(temp)],
                                            to_add))
                            all_times[indices,
                                      min_times[idx]:max_times[idx]] = True

    total_nb_elts = 0
    for temp in range(N_clusters):
        total_nb_elts += len(result['data_tmp_' + str(temp)])

    gdata = gather_array(numpy.array([total_nb_elts], dtype=numpy.float32),
                         comm, 0)
    if comm.rank == 0:
        print_and_log([
            "Found %d spikes over %d requested" %
            (int(numpy.sum(gdata)), int(nb_elts))
        ], 'default', logger)

    #print "Spikes extracted in", time.time() - t_start, "s"

    comm.Barrier()

    local_nb_clusters = 0
    for temp in range(comm.rank, N_clusters, comm.size):
        if len(result['data_tmp_' + str(temp)]) > 0:
            local_nb_clusters += 1

    #print total_nb_clusters, "found in", time.time() - t_start, "s"
    gdata3 = gather_array(
        numpy.array([local_nb_clusters], dtype=numpy.float32), comm, 0)

    comm.Barrier()
    if comm.rank == 0:
        print_and_log(["Extracting the templates..."], 'default', logger)

    total_nb_clusters = int(
        comm.bcast(numpy.array([int(numpy.sum(gdata3))], dtype=numpy.int32),
                   root=0)[0])
    # offsets[i + 1] receives the number of non-empty clusters of rank i.
    offsets = numpy.zeros(comm.size, dtype=numpy.int32)
    for i in range(comm.size - 1):
        offsets[i + 1] = comm.bcast(numpy.array([local_nb_clusters],
                                                dtype=numpy.int32),
                                    root=i)

    if parallel_hdf5:
        node_pad = numpy.sum(offsets[:comm.rank + 1])
        hfile = h5py.File(file_out_suff + '.templates.hdf5',
                          'w',
                          driver='mpio',
                          comm=comm,
                          libver='earliest')
        norms = hfile.create_dataset('norms',
                                     shape=(2 * total_nb_clusters, ),
                                     dtype=numpy.float32,
                                     chunks=True)
        electrodes = hfile.create_dataset('electrodes',
                                          shape=(total_nb_clusters, ),
                                          dtype=numpy.int32,
                                          chunks=True)
        amps_lims = hfile.create_dataset('limits',
                                         shape=(total_nb_clusters, 2),
                                         dtype=numpy.float32,
                                         chunks=True)
        g_count = node_pad
        g_offset = total_nb_clusters
    else:
        node_pad = 0
        hfile = h5py.File(file_out_suff + '.templates-%d.hdf5' % comm.rank,
                          'w',
                          libver='earliest')
        electrodes = hfile.create_dataset('electrodes',
                                          shape=(local_nb_clusters, ),
                                          dtype=numpy.int32,
                                          chunks=True)
        norms = hfile.create_dataset('norms',
                                     shape=(2 * local_nb_clusters, ),
                                     dtype=numpy.float32,
                                     chunks=True)
        amps_lims = hfile.create_dataset('limits',
                                         shape=(local_nb_clusters, 2),
                                         dtype=numpy.float32,
                                         chunks=True)
        g_count = 0
        g_offset = local_nb_clusters

    cfile = h5py.File(file_out_suff + '.clusters-%d.hdf5' % comm.rank,
                      'w',
                      libver='earliest')
    count_templates = node_pad

    # Sparse (row, column, value) triplets describing all templates.
    temp_x = numpy.zeros(0, dtype=numpy.int32)
    temp_y = numpy.zeros(0, dtype=numpy.int32)
    temp_data = numpy.zeros(0, dtype=numpy.float32)

    to_explore = range(comm.rank, N_clusters, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for temp in to_explore:
        n_data = len(result['data_tmp_' + str(temp)])
        if n_data > 0:
            data = result['data_tmp_' + str(temp)].reshape(
                n_data, basis_proj.shape[1], N_e)
            first_component = numpy.median(data, axis=0)
            tmp_templates = numpy.dot(first_component.T, basis_rec)
            # NOTE(review): ``tmpidx`` is never assigned in this function
            # and ``indices`` still holds whatever a previous iteration
            # left in it; confirm the intended electrode selection.
            electrodes[g_count] = indices[tmpidx[0][0]]
            indices = inv_nodes[edges[nodes[electrodes[-1]]]]
            templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            # NOTE(review): ``shift`` is never assigned in this function.
            if shift > 0:
                templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                templates[indices, :] = tmp_templates

            templates = templates.flatten()
            dx = templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y,
                 count_templates * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, templates[dx]))

            norms[g_count] = numpy.sqrt(
                numpy.sum(templates.flatten()**2) / (N_e * N_t))

            # Project every snippet on the first component and remove that
            # contribution, leaving the residuals in ``data_flat``.
            x, y, z = data.shape
            data_flat = data.reshape(x, y * z)
            first_flat = first_component.reshape(y * z, 1)
            amplitudes = numpy.dot(data_flat, first_flat)
            amplitudes /= numpy.sum(first_flat**2)
            for i in range(x):
                data_flat[i, :] -= amplitudes[i] * first_flat[:, 0]

            variations = 10 * numpy.median(
                numpy.abs(amplitudes - numpy.median(amplitudes)))
            physical_limit = noise_thr * (
                -thresholds[indices[tmpidx[0][0]]]) / tmp_templates.min()
            amp_min = max(physical_limit,
                          numpy.median(amplitudes) - variations)
            amp_max = min(amp_limits[1], numpy.median(amplitudes) + variations)
            amps_lims[g_count] = [amp_min, amp_max]

            # Second component: first principal axis of the residuals, or
            # the normalised residual itself when there is a single snippet.
            if len(data_flat) > 1:
                pca = PCA(1)
                res_pca = pca.fit_transform(data_flat.astype(numpy.double))
                second_component = pca.components_.T.astype(
                    numpy.float32).reshape(y, z)
            else:
                second_component = data_flat.reshape(y, z) / numpy.sum(
                    data_flat**2)

            tmp_templates = numpy.dot(second_component.T, basis_rec)
            offset = total_nb_clusters + count_templates
            sub_templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            if shift > 0:
                sub_templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                sub_templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                sub_templates[indices, :] = tmp_templates

            sub_templates = sub_templates.flatten()
            dx = sub_templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y, offset * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, sub_templates[dx]))

            norms[g_count + g_offset] = numpy.sqrt(
                numpy.sum(sub_templates.flatten()**2) / (N_e * N_t))

            count_templates += 1
            g_count += 1

        # NOTE(review): ``to_write`` and ``ielec`` are never assigned in
        # this function.
        io.write_datasets(cfile,
                          to_write,
                          result,
                          ielec,
                          compress=hdf5_compress)

    #At the end we should have a templates variable to store.
    cfile.close()
    del result, templates, amps_lims
    comm.Barrier()

    #We need to gather the sparse arrays
    temp_x = gather_array(temp_x, comm, dtype='int32', compress=blosc_compress)
    temp_y = gather_array(temp_y, comm, dtype='int32', compress=blosc_compress)
    temp_data = gather_array(temp_data, comm, compress=blosc_compress)

    if parallel_hdf5:
        if comm.rank == 0:
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in range(comm.size)
            ]
            cfile = h5py.File(file_out_suff + '.clusters.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(cfile, ['electrodes'],
                              {'electrodes': electrodes[:]},
                              compress=hdf5_compress)
            for i in range(comm.size):
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile,
                                      to_write,
                                      rs[i],
                                      j,
                                      compress=hdf5_compress)
                rs[i].close()
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            cfile.close()
        hfile.close()
    else:
        hfile.close()
        if comm.rank == 0:
            # Merge the per-rank template and cluster files into one.
            ts = [
                h5py.File(file_out_suff + '.templates-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in range(comm.size)
            ]
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in range(comm.size)
            ]
            result = {}
            hfile = h5py.File(file_out_suff + '.templates.hdf5',
                              'w',
                              libver='earliest')
            cfile = h5py.File(file_out_suff + '.clusters.hdf5',
                              'w',
                              libver='earliest')
            electrodes = hfile.create_dataset('electrodes',
                                              shape=(total_nb_clusters, ),
                                              dtype=numpy.int32,
                                              chunks=True)
            norms = hfile.create_dataset('norms',
                                         shape=(2 * total_nb_clusters, ),
                                         dtype=numpy.float32,
                                         chunks=True)
            amplitudes = hfile.create_dataset('limits',
                                              shape=(total_nb_clusters, 2),
                                              dtype=numpy.float32,
                                              chunks=True)
            count = 0
            for i in range(comm.size):
                loc_temp = ts[i].get('templates')
                middle = loc_temp.shape[2] // 2
                # NOTE(review): ``loc_norms`` and ``n_clusters`` are never
                # assigned in this function.
                norms[count:count + middle] = loc_norms[:middle]
                norms[n_clusters + count:n_clusters + count +
                      middle] = loc_norms[middle:]
                electrodes[count:count + middle] = ts[i].get('electrodes')
                amplitudes[count:count + middle] = ts[i].get('limits')
                count += middle
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile,
                                      to_write,
                                      rs[i],
                                      j,
                                      compress=hdf5_compress)
                ts[i].close()
                rs[i].close()
                os.remove(file_out_suff + '.templates-%d.hdf5' % i)
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            io.write_datasets(cfile, ['electrodes'],
                              {'electrodes': electrodes[:]},
                              compress=hdf5_compress)
            hfile.close()
            cfile.close()

    if comm.rank == 0:
        # Append the gathered sparse template description.
        hfile = h5py.File(file_out_suff + '.templates.hdf5',
                          'r+',
                          libver='earliest')
        hfile.create_dataset('temp_x', data=temp_x)
        hfile.create_dataset('temp_y', data=temp_y)
        hfile.create_dataset('temp_data', data=temp_data)
        hfile.create_dataset('temp_shape',
                             data=numpy.array(
                                 [N_e, N_t, 2 * total_nb_clusters],
                                 dtype=numpy.int32))
        hfile.close()

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Merging similar templates..."], 'default', logger)

    merged1 = algo.merging_cc(params, parallel_hdf5)

    comm.Barrier()
    # NOTE(review): ``remove_mixture`` is never assigned in this function.
    if remove_mixture:
        if comm.rank == 0:
            print_and_log(["Removing mixtures..."], 'default', logger)
        merged2 = algo.delete_mixtures(params, parallel_hdf5)
    else:
        merged2 = [0, 0]

    if comm.rank == 0:
        print_and_log([
            "Number of global merges    : %d" % merged1[1],
            "Number of mixtures removed : %d" % merged2[1]
        ], 'info', logger)

    comm.Barrier()
    io.get_overlaps(params, erase=True, parallel_hdf5=parallel_hdf5)

    data_file.close()
Exemple #19
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    # Part 1: Whitening
    numpy.random.seed(420)
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    spike_width = params.getfloat('detection', 'spike_width')
    matched_filter = params.getboolean('detection', 'matched-filter')
    matched_thresh = params.getfloat('detection', 'matched_thresh')
    fudge = params.getfloat('whitening', 'fudge')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    ignore_spikes = params.getboolean('whitening', 'ignore_spikes')
    chunk_size = detect_memory(params, whitening=True)
    plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    safety_space = params.getboolean('whitening', 'safety_space')
    sort_waveforms = params.getboolean('whitening', 'sort_waveforms')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_1 = int(20 * params.rate // comm.size)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    jitter_range = params.getint('detection', 'jitter_range')
    template_shift_2 = template_shift + jitter_range
    use_hanning = params.getboolean('detection', 'hanning')
    rejection_threshold = params.getfloat('detection', 'rejection_threshold')
    noise_window = params.getint('detection', 'noise_time')
    data_file.open()
    #################################################################

    if use_hanning:
        hanning_filter = numpy.hanning(N_t)

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    nodes_indices = {}
    for elec in numpy.arange(N_e):
        nodes_indices[elec] = inv_nodes[edges[nodes[elec]]]

    if use_gpu:
        import cudamat as cmt
        # # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # I guess this is more relevant, to take signals from all over the recordings.
    if nb_chunks > comm.size:
        all_chunks = numpy.random.permutation(
            numpy.arange(nb_chunks - 1, dtype=numpy.int32))
    else:
        all_chunks = numpy.random.permutation(
            numpy.arange(nb_chunks, dtype=numpy.int32))

    all_electrodes = numpy.random.permutation(N_e)

    numpy.random.seed(comm.rank)

    for gidx in [all_chunks[comm.rank]]:

        # print "Node", comm.rank, "is analyzing chunk", gidx,  "/", nb_chunks, " ..."
        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   nodes=nodes)
        local_shape = len(local_chunk)

        # print "Node", comm.rank, "computes the median absolute deviations in a random chunk"
        thresholds = numpy.zeros(N_e, dtype=numpy.float32)
        for i in range(N_e):
            u = numpy.median(local_chunk[:, i], 0)
            thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0)
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            bfile = h5py.File(file_out_suff + '.basis.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(bfile, ['thresholds'],
                              {'thresholds': thresholds},
                              compression=hdf5_compress)
            bfile.close()
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        local_borders = (template_shift, local_shape - template_shift)
        found_peaktimes = []

        if ignore_spikes:
            # Extracting the peaks.
            local_peaktimes = [np.empty(0, dtype=numpy.uint32)]
            for i in range(N_e):
                peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:,
                                                                          i]),
                                                    height=thresholds[i],
                                                    width=spike_width,
                                                    wlen=N_t)[0]
                peaktimes = peaktimes.astype(numpy.uint32)

                # print "Removing the useless borders..."
                idx = (peaktimes >= local_borders[0]) & (peaktimes <
                                                         local_borders[1])
                peaktimes = numpy.compress(idx, peaktimes)

                found_peaktimes.append(peaktimes)
        else:
            for i in range(N_e):
                found_peaktimes.append(numpy.zeros(0, dtype=numpy.uint32))

        all_peaktimes = numpy.concatenate(found_peaktimes)
        local_peaktimes = numpy.unique(all_peaktimes)

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool)
            padded_peaks = (local_peaktimes - local_peaktimes[0]).astype(
                numpy.int32)
            min_times = numpy.maximum(padded_peaks - safety_time, 0)
            max_times = numpy.minimum(padded_peaks + safety_time + 1,
                                      diff_times + 1)

            test_extremas = numpy.zeros((N_e, diff_times + 1),
                                        dtype=numpy.bool)
            for i in range(N_e):
                test_extremas[i,
                              found_peaktimes[i] - local_peaktimes[0]] = True

            argmax_peak = numpy.random.permutation(
                numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            # print "Selection of the peaks with spatio-temporal masks..."
            for idx, peak in zip(argmax_peak, all_idx):

                all_elecs = numpy.where(test_extremas[:, peak -
                                                      local_peaktimes[0]])[0]
                data = local_chunk[peak, all_elecs]
                elec = all_elecs[numpy.argmax(numpy.abs(data))]
                indices = nodes_indices[elec]
                if safety_space:
                    all_times[indices, min_times[idx]:max_times[idx]] = True
                else:
                    all_times[elec, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool)

    if do_temporal_whitening:

        local_res_temp = []

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white,
                                                comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = nodes_indices[elec]
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            bound = len(esubset) - N_t
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res_temp += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res_temp)], dtype=numpy.float32)
        local_res_temp = numpy.array(local_res_temp, dtype=numpy.float32)
        if len(local_res_temp) == 0:
            local_res_temp = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res_temp = numpy.sum(local_res_temp, 0)
        all_res_temp = gather_array(local_res_temp.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if do_spatial_whitening:

        local_res_spac = numpy.zeros((N_e, N_e), dtype=numpy.float32)
        local_silences = []

        for elec in numpy.arange(comm.rank, N_e, comm.size):
            indices = nodes_indices[elec]
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            local_data = local_chunk[esubset][:, indices]
            local_whitening = get_whitening_matrix(
                local_data, fudge=fudge).astype(numpy.float32)
            pos = numpy.where(elec == indices)[0]
            local_res_spac[elec, indices] = local_whitening[pos]
            local_silences += [len(esubset)]

        all_res_spac = gather_array(local_res_spac.ravel(), comm, 0, 1)
        all_silences = gather_array(
            numpy.array(local_silences, dtype=numpy.int32), comm, 0, 1,
            'uint32')

    if comm.rank == 0:

        to_write = {}

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res_temp = all_res_temp.reshape((nb_silences, N_t**2))
            except Exception:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res_temp = numpy.sum(all_res_temp, 0)
            all_res_temp = all_res_temp.reshape(
                (N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res_temp.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            all_res_spac = all_res_spac.reshape(comm.size, N_e, N_e)
            spatial_whitening = numpy.sum(all_res_spac, 0)
            to_write['spatial'] = spatial_whitening

            if ignore_spikes:
                print_and_log([
                    "Found %gs without spikes to compute the whitening matrix..."
                    % (numpy.mean(all_silences) / params.rate)
                ], 'default', logger)
            else:
                print_and_log([
                    "Found %gs to compute the whitening matrix..." %
                    (numpy.mean(all_silences) / params.rate)
                ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          list(to_write.keys()),
                          to_write,
                          compression=hdf5_compress)
        bfile.close()

    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = numpy.zeros(N_e, dtype=numpy.float32)
            for i in range(N_e):
                u = numpy.median(local_chunk[:, i], 0)
                thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u),
                                             0)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                  'r+',
                                  libver='earliest')
                bfile.pop('thresholds')
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds},
                                  compression=hdf5_compress)
                bfile.close()
            comm.Barrier()

    # if comm.rank == 0:
    #     if not os.path.exists(plot_path):
    #         os.makedirs(plot_path)
    #     N_elec = min(int(numpy.sqrt(data_file.N_e)), 5)
    #     plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True,
    #                   n_elec=N_elec, save=[plot_path, 'electrodes'])

    # Part 2: Basis
    numpy.random.seed(422)

    SHARED_MEMORY = get_shared_memory_flag(params)
    #################################################################
    file_out = params.get('data', 'file_out')
    alignment = params.getboolean('detection', 'alignment')
    over_factor = params.getint('detection', 'oversampling_factor')
    nb_jitter = params.getint('detection', 'nb_jitter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    nodes, edges = get_nodes_and_edges(params)
    _, positions = get_nodes_and_positions(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    use_barycenter = params.getboolean('detection', 'use_barycenter')
    if matched_filter:
        chunk_size = detect_memory(params, whitening=True)
    else:
        chunk_size = detect_memory(params)
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    smoothing_factor = params.getfloat('detection', 'smoothing_factor')
    if sign_peaks == 'both':
        max_elts_elec *= 2
    nb_elts = int(
        params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)

    weird_thresh = params.get('detection', 'weird_thresh')
    if weird_thresh != '':
        ignore_artefacts = True
        weird_thresh = io.load_data(params, 'weird-thresholds')
    else:
        ignore_artefacts = False

    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    if ignore_dead_times:
        if SHARED_MEMORY:
            all_dead_times, mpi_memory_3 = get_dead_times(params)
        else:
            all_dead_times = get_dead_times(params)
    data_file.open()
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    groups = {}
    for i in range(N_e):
        groups[i] = 0

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        times_pos = numpy.zeros(nb_elts, dtype=numpy.int32)
        electrodes_pos = numpy.zeros(nb_elts, dtype=numpy.int32)
        extremum_pos = numpy.zeros(nb_elts, dtype=numpy.float32)
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        times_neg = numpy.zeros(nb_elts, dtype=numpy.int32)
        electrodes_neg = numpy.zeros(nb_elts, dtype=numpy.int32)
        extremum_neg = numpy.zeros(nb_elts, dtype=numpy.float32)
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)

    thresholds = io.load_data(params, 'thresholds')
    mads = io.load_data(params, 'mads')
    stds = io.load_data(params, 'stds')

    if alignment:
        cdata = numpy.linspace(-jitter_range, +jitter_range, nb_jitter)
        xdata = numpy.arange(-template_shift_2, template_shift_2 + 1)
        xoff = len(cdata) / 2.0
        snippet_duration = template_shift_2
        m_size = 2 * template_shift_2 + 1
        align_factor = m_size
        local_factors = align_factor * ((smoothing_factor * mads)**2)
    else:
        snippet_duration = template_shift
        xdata = numpy.arange(-template_shift, template_shift + 1)

    if rejection_threshold > 0:
        reject_noise = True
        noise_levels = stds * (2 * noise_window + 1)
    else:
        reject_noise = False

    to_explore = all_chunks[comm.rank::comm.size]

    upper_bounds = max_elts_elec

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for gcount, gidx in enumerate(to_explore):

        if (elt_count_pos + elt_count_neg) < nb_elts:
            # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_borders = (snippet_duration, local_shape - snippet_duration)

            if ignore_dead_times:
                dead_indices = numpy.searchsorted(
                    all_dead_times, [t_offset, t_offset + local_shape])

            # Extracting the peaks.
            all_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

            found_peaktimes = []
            found_peak_amplitudes = []
            for i in range(N_e):
                height = thresholds[i]
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i],
                                                        height=height,
                                                        distance=dist_peaks)[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i],
                                                        height=height,
                                                        distance=dist_peaks)[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(
                        local_chunk[:, i]),
                                                        height=height,
                                                        distance=dist_peaks)[0]
                else:
                    peaktimes = numpy.empty(0, dtype=numpy.uint32)

                if ignore_artefacts:
                    artetimes = scipy.signal.find_peaks(
                        numpy.abs(local_chunk[:,
                                              i]), height=weird_thresh[i])[0]
                    to_keep = numpy.logical_not(
                        numpy.in1d(peaktimes, artetimes))
                    peaktimes = peaktimes[to_keep]

                idx = (peaktimes >= local_borders[0]) & (peaktimes <
                                                         local_borders[1])
                peaktimes = peaktimes[idx]

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            peaktimes + t_offset,
                            all_dead_times[dead_indices[0]:dead_indices[1]])
                        peaktimes = peaktimes[~is_included]

                peaktimes = peaktimes.astype(numpy.uint32)
                found_peaktimes.append(peaktimes)

                peak_amplitudes = local_chunk[peaktimes, i]
                found_peak_amplitudes.append(peak_amplitudes)

            all_peaktimes = numpy.concatenate(
                found_peaktimes)  # i.e. concatenate once for efficiency
            all_peak_amplitudes = numpy.concatenate(found_peak_amplitudes)
            local_peaktimes, local_indices = numpy.unique(all_peaktimes,
                                                          return_inverse=True)

            if len(local_peaktimes) > 0:

                diff_times = (local_peaktimes[-1] - local_peaktimes[0]) + 1
                all_times = numpy.zeros((N_e, diff_times), dtype=numpy.bool)

                padded_peaks = (local_peaktimes - local_peaktimes[0]).astype(
                    numpy.int32)
                min_times = numpy.maximum(padded_peaks - safety_time, 0)
                max_times = numpy.minimum(padded_peaks + safety_time + 1,
                                          diff_times + 1)
                test_extremas = numpy.zeros((N_e, diff_times + 1),
                                            dtype=numpy.bool)
                for i in range(N_e):
                    test_extremas[i, found_peaktimes[i] -
                                  local_peaktimes[0]] = True

                # Consider the peaks by decreasing extremum.
                if sort_waveforms:
                    order = numpy.argsort(-np.abs(all_peak_amplitudes))
                    all_idx = numpy.take(all_peaktimes, order)
                    argmax_peak = local_indices[order]
                else:
                    n_times = len(all_peaktimes)
                    shuffling = numpy.random.permutation(numpy.arange(n_times))
                    all_idx = numpy.take(all_peaktimes, shuffling)
                    argmax_peak = local_indices[shuffling]

                # print "Selection of the peaks with spatio-temporal masks..."
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    all_elecs = numpy.where(
                        test_extremas[:, peak - local_peaktimes[0]])[0]
                    data = local_chunk[peak, all_elecs]

                    #target_area = test_extremas[:, min_times[midx]:max_times[midx]].sum(1)
                    #all_elecs = numpy.where(target_area)[0]
                    #data = local_chunk[peak, all_elecs]

                    if sign_peaks == 'negative':
                        if N_e > 1:
                            if use_barycenter:
                                weighed_position = data[:, numpy.
                                                        newaxis] * positions[
                                                            all_elecs]
                                barycenter = weighed_position.sum(
                                    0) / data.sum()
                                elec = numpy.argmin(
                                    numpy.linalg.norm(barycenter -
                                                      positions[all_elecs],
                                                      axis=1))
                            else:
                                elec = numpy.argmin(data)
                        else:
                            elec = 0
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        if N_e > 1:
                            if use_barycenter:
                                weighed_position = data[:, numpy.
                                                        newaxis] * positions[
                                                            all_elecs]
                                barycenter = weighed_position.sum(
                                    0) / data.sum()
                                elec = numpy.argmax(
                                    numpy.linalg.norm(barycenter -
                                                      positions[all_elecs],
                                                      axis=1))
                            else:
                                elec = numpy.argmax(data)
                        else:
                            elec = 0
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if N_e == 1:
                            if data < 0:
                                negative_peak = True
                            elif data > 0:
                                negative_peak = False
                            elec = 0
                        else:
                            if numpy.abs(numpy.max(data)) > numpy.abs(
                                    numpy.min(data)):
                                elec = numpy.argmax(data)
                                negative_peak = False
                            else:
                                elec = numpy.argmin(data)
                                negative_peak = True

                    elec = all_elecs[elec]

                    if groups[elec] < upper_bounds:

                        indices = nodes_indices[elec]
                        myslice = all_times[indices,
                                            min_times[midx]:max_times[midx]]

                        if not myslice.any():

                            sub_mat = local_chunk[peak -
                                                  snippet_duration:peak +
                                                  snippet_duration + 1, elec]

                            if reject_noise:
                                slice_window = sub_mat[
                                    snippet_duration -
                                    noise_window:snippet_duration +
                                    noise_window + 1]
                                value = numpy.linalg.norm(
                                    slice_window) / noise_levels[elec]
                                is_noise = value < rejection_threshold
                            else:
                                is_noise = False

                            if not is_noise:

                                extrema = sub_mat[snippet_duration]

                                if alignment:
                                    smoothed = True
                                    try:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata,
                                            sub_mat,
                                            s=local_factors[elec],
                                            k=3)
                                    except Exception:
                                        smoothed = False
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata, sub_mat, k=3, s=0)

                                    if negative_peak:
                                        rmin = (numpy.argmin(f(cdata)) -
                                                xoff) / over_factor
                                    else:
                                        rmin = (numpy.argmax(f(cdata)) -
                                                xoff) / over_factor
                                    ddata = numpy.linspace(
                                        rmin - template_shift,
                                        rmin + template_shift, N_t)

                                    if smoothed:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata,
                                            sub_mat,
                                            s=local_factors[elec],
                                            k=3)
                                    else:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata, sub_mat, s=0, k=3)

                                    sub_mat = f(ddata).astype(numpy.float32)

                                if negative_peak:
                                    times_neg[elt_count_neg] = peak + t_offset
                                    electrodes_neg[elt_count_neg] = elec
                                    extremum_neg[elt_count_neg] = extrema
                                    elts_neg[:, elt_count_neg] = sub_mat
                                    elt_count_neg += 1
                                else:
                                    times_pos[elt_count_pos] = peak + t_offset
                                    electrodes_pos[elt_count_pos] = elec
                                    extremum_pos[elt_count_pos] = extrema
                                    elts_pos[:, elt_count_pos] = sub_mat
                                    elt_count_pos += 1

                                groups[elec] += 1
                                all_times[
                                    indices,
                                    min_times[midx]:max_times[midx]] = True
                                test_extremas[elec, peak -
                                              local_peaktimes[0]] = False

    sys.stderr.flush()

    print_and_log([
        "Node %d has collected %d waveforms" %
        (comm.rank, elt_count_pos + elt_count_neg)
    ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        times_neg = gather_array(times_neg[:elt_count_neg],
                                 comm,
                                 0,
                                 1,
                                 dtype='int32')
        electrodes_neg = gather_array(electrodes_neg[:elt_count_neg],
                                      comm,
                                      0,
                                      1,
                                      dtype='int32')
        extremum_neg = gather_array(extremum_neg[:elt_count_neg], comm, 0, 1)
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        times_pos = gather_array(times_pos[:elt_count_pos],
                                 comm,
                                 0,
                                 1,
                                 dtype='int32')
        electrodes_pos = gather_array(electrodes_pos[:elt_count_pos],
                                      comm,
                                      0,
                                      1,
                                      dtype='int32')
        extremum_pos = gather_array(extremum_pos[:elt_count_pos], comm, 0, 1)
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    nb_waveforms = 0

    if comm.rank == 0:
        # DO PCA on elts and store the basis obtained.

        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

    nb_waveforms = all_gather_array(
        numpy.array([nb_waveforms], dtype=numpy.float32), comm, 0)[0]

    if comm.rank == 0:
        print_and_log([
            "Found %d waveforms over %d requested" %
            (nb_waveforms, int(nb_elts * comm.size))
        ], 'default', logger)

        if nb_waveforms == 0:
            print_and_log(
                ['No waveforms found! Are the data properly loaded??'],
                'error', logger)

    if nb_waveforms == 0:
        sys.exit(0)

    if comm.rank == 0:
        res = {}
        pca = None
        pca_pos = None
        pca_neg = None
        warning_n_t = False
        if sign_peaks in ['negative', 'both']:
            res['times'] = times_neg
            res['electrodes'] = electrodes_neg
            res['extremum'] = extremum_neg
            if len(gdata_neg) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_neg * hanning_filter)
                else:
                    pca.fit(gdata_neg)
                res['proj'] = pca.components_.T.astype(numpy.float32)
                pca_neg = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj'] = numpy.identity(int(output_dim),
                                             dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            # dispersion = numpy.std(gdata_neg, 0) / numpy.median(stds)
            # ratio = numpy.sum(dispersion > 1.1) / float(len(dispersion))
            # if ratio < 0.25:
            #     print_and_log(["Time window N_t in [detection] seems too large!"], 'info', logger)
            #     warning_n_t = True
            # elif ratio == 1:
            #     print_and_log(["Time window N_t in [detection] seems too small!"], 'info', logger)
            #     warning_n_t = True
            idx = numpy.random.permutation(numpy.arange(
                gdata_neg.shape[0]))[:2500]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            res['times_pos'] = times_pos
            res['electrodes_pos'] = electrodes_pos
            res['extremum_pos'] = extremum_pos
            if len(gdata_pos) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_pos * hanning_filter)
                else:
                    pca.fit(gdata_pos)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
                pca_pos = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj_pos'] = numpy.identity(int(output_dim),
                                                 dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            # dispersion = numpy.std(gdata_pos, 0) / numpy.median(stds)
            # ratio = numpy.sum(dispersion > 1.1) / float(len(dispersion))
            # if ratio < 0.25 and not warning_n_t:
            #     print_and_log(["Time window N_t in [detection] seems too large!"], 'info', logger)
            # elif ratio == 1 and not warning_n_t:
            #     print_and_log(["Time window N_t in [detection] seems too small!"], 'info', logger)
            idx = numpy.random.permutation(numpy.arange(
                gdata_pos.shape[0]))[:2500]
            res['waveforms_pos'] = gdata_pos[idx, :]

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          list(res.keys()),
                          res,
                          compression=hdf5_compress)
        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'debug', logger)
        if pca_pos is not None:
            print_and_log([
                "The percentage of variance explained is %s for positive spikes"
                % pca_pos
            ], 'debug', logger)
        if pca_neg is not None:
            print_and_log([
                "The percentage of variance explained is %s for negative spikes"
                % pca_neg
            ], 'debug', logger)

        bfile.close()

    comm.Barrier()

    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_chunk /= thresholds

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_neg,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in range(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds'],
                                      {'matched_thresholds': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_pos,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in range(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds_pos'],
                                      {'matched_thresholds_pos': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

    data_file.close()

    if SHARED_MEMORY and ignore_dead_times:
        mpi_memory_3.Free()
Exemple #20
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Run the filtering step of the pipeline.

    Depending on the configuration, the recording is filtered, the median
    over all channels is subtracted, and stimulation artefacts are removed.
    Steps already flagged as done in the ``noedits`` section are skipped.

    ``nb_cpu``, ``nb_gpu`` and ``use_gpu`` are accepted for interface
    compatibility but are not used by this function.
    """
    # init_logging() is called for its side effect; the logger actually used
    # below is the module-specific one.
    init_logging(params.logfile)
    logger = logging.getLogger('circus.filtering')

    # ---- configuration flags ---------------------------------------
    # What the user asked for ('filtering' / 'triggers') and what has
    # already been applied to the data ('noedits').
    do_filter = params.getboolean('filtering', 'filter')
    filter_done = params.getboolean('noedits', 'filter_done')
    artefacts_done = params.getboolean('noedits', 'artefacts_done')
    median_done = params.getboolean('noedits', 'median_done')
    clean_artefact = params.getboolean('triggers', 'clean_artefact')
    remove_median = params.getboolean('filtering', 'remove_median')

    # Channel indices to process; shared by the nested helpers below.
    nodes, edges = get_nodes_and_edges(params)

    # ---- helpers ---------------------------------------------------

    def filter_file(data_file_in, data_file_out, do_filtering,
                    do_remove_median):
        """Filter and/or median-reference a recording, chunk by chunk.

        Each MPI rank processes the chunks ``comm.rank, comm.rank + comm.size,
        ...`` read from ``data_file_in`` and writes the result to
        ``data_file_out`` (which may be the same object).

        Parameters
        ----------
        data_file_in :
            Source data file; ``analyze``, ``get_data``, ``is_first_chunk``,
            ``is_stream`` and ``t_start`` are used.
        data_file_out :
            Destination data file; ``set_data`` is called once per chunk.
        do_filtering : bool
            Filter every channel listed in ``nodes`` and re-centre it on its
            own median.
        do_remove_median : bool
            Subtract, at each time step, the median over the ``nodes``
            channels from each of those channels.

        Notes
        -----
        Exits the process with status 1 when the ``cut_off`` option cannot be
        parsed.
        """
        # 'cut_off' is either a single number (low cut, the high cut then
        # defaults to 95% of Nyquist) or a "low, high" pair where high may be
        # the literal 'auto'.
        try:
            cut_off = params.getfloat('filtering', 'cut_off')
            cut_off = [cut_off, 0.95 * (params.rate / 2.)]
        except Exception:
            cut_off = params.get('filtering', 'cut_off')
            cut_off = cut_off.split(',')
            try:
                cut_off[0] = float(cut_off[0])
            except Exception:
                if comm.rank == 0:
                    print_and_log(
                        ['First value of cut off must be a valid number'],
                        'error', logger)
                sys.exit(1)

            cut_off[1] = cut_off[1].replace(' ', '')
            if cut_off[1] == 'auto':
                cut_off[1] = 0.95 * (params.rate / 2.)
            else:
                try:
                    cut_off[1] = float(cut_off[1])
                except Exception:
                    if comm.rank == 0:
                        print_and_log([
                            'Second value of cut off must either auto, or a valid a number'
                        ], 'error', logger)
                    sys.exit(1)

        chunk_size = params.getint('data', 'chunk_size')
        nb_chunks, _ = data_file_in.analyze(chunk_size)

        # Filter selection is hard-coded: Butterworth, else lowess, else the
        # wavelet-based WMLDR when both flags are False.
        do_butter = False  # mmyros
        do_lowess = False  # otherwise wavelet if both lowess and butter are false
        if do_butter:
            b, a = signal.butter(3,
                                 np.array(cut_off) / (params.rate / 2.),
                                 'pass')
        N_total = params.nb_channels

        if comm.rank == 0:
            to_write = []
            if do_filtering:
                # NOTE(review): the message mentions a Butterworth filter even
                # when do_butter is False and another filter is applied.
                to_write += [
                    "Filtering the signal with a Butterworth filter in (%g, %g) Hz"
                    % (cut_off[0], cut_off[1])
                ]
            if do_remove_median:
                to_write += [
                    "Median over all channels is substracted to each channels"
                ]

            print_and_log(to_write, 'default', logger)

        # Round-robin distribution of the chunks over the MPI ranks.
        to_explore = range(comm.rank, nb_chunks, comm.size)

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(to_explore)

        # Loop-invariant: are we working on every channel of the file?
        # array_equal also copes with nodes and arange having different sizes.
        all_channels_used = numpy.array_equal(nodes, numpy.arange(N_total))

        for gidx in to_explore:

            local_chunk, t_offset = data_file_in.get_data(gidx, chunk_size)

            if do_filtering:
                for i in nodes:
                    if do_butter:
                        local_chunk[:, i] = signal.filtfilt(
                            b, a, local_chunk[:, i])
                    elif do_lowess:
                        local_chunk[:, i] = lowess(local_chunk[:, i])
                    else:
                        local_chunk[:, i] = WMLDR(local_chunk[:, i])

                    # Re-centre every filtered channel on its own median
                    # (this must happen inside the per-channel loop).
                    local_chunk[:, i] -= numpy.median(local_chunk[:, i])

            if do_remove_median:
                if not all_channels_used:
                    global_median = numpy.median(
                        numpy.take(local_chunk, nodes, axis=1), 1)
                else:
                    global_median = numpy.median(local_chunk, 1)
                for i in nodes:
                    local_chunk[:, i] -= global_median

            # When writing to a different file, the offset of the first chunk
            # is shifted by the start time of the source (or of the preceding
            # streams); otherwise the chunk is written back where it was read.
            if data_file_in != data_file_out and data_file_in.is_first_chunk(
                    gidx, nb_chunks):
                if data_file_in.is_stream:
                    last_stream = data_file_in._get_streams_index_by_time(
                        t_offset)
                    g_offset = t_offset - numpy.sum(
                        data_file_in._times[:last_stream + 1])
                else:
                    g_offset = t_offset - data_file_in.t_start
            else:
                g_offset = t_offset

            data_file_out.set_data(g_offset, local_chunk)

        comm.Barrier()

    def lowess(data, frac=0.01, delta=67.61):
        """Local regression (Lowess) filter for spike extraction.

        A lowess fit is computed for every row (channel) of ``data`` and
        subtracted from it, leaving the residual.

        dependency: pip install cylowess

        Parameters
        ----------
        data : array-like
            One channel (1-D) or several channels as rows (2-D).
        frac : float
            Passed to ``cylowess.lowess`` as ``frac``.
        delta : float
            Passed to ``cylowess.lowess`` as ``delta``; the default is the
            value that used to be hard-coded here.

        Returns
        -------
        The 2-D array of residuals (``data`` minus the lowess fit), one row
        per channel.
        """
        import cylowess

        data = np.atleast_2d(data)
        # The x axis is simply the sample index; it is the same for every
        # channel, so build it once.
        xdata = np.arange(data.shape[1], dtype='float')

        for chani in range(len(data)):
            fit = cylowess.lowess(np.asarray(data[chani], dtype='float'),
                                  xdata,
                                  frac=frac,
                                  it=0,
                                  delta=delta)[:, 1]
            data[chani] = data[chani] - fit
        return data

    def WMLDR(data, wname="db4", maxlevel=6, mode='sym'):
        """Wavelet multi-level decomposition and reconstruction (WMLDR).

        Function by Martin Spacek from https://github.com/spyke/spyke

        Every row (channel) of ``data`` is decomposed with the ``wname``
        wavelet (Daubechies(4) by default) down to ``maxlevel`` levels, the
        approximation coefficients are discarded, and the signal is rebuilt
        from the remaining detail coefficients, which yields highpass data.
        See Wiltschko2008. The rows are overwritten with the result. The
        effective cutoff frequency is:

        fc = (sampfreq / 2) / 2**maxlevel                     (Wiltschko2008)

        i.e. 195 Hz for a 25 kHz sampling rate and 234 Hz for 30 kHz when
        maxlevel is 6.

        TODO: only highpass data is returned for now. Returning both low and
        highpass data (without modifying the input) would probably be better.
        The Discussion in Wiltschko2008 suggests this approach cannot be used
        to extract the LFP, although subtracting the highpass data from the
        raw data looks like a way to obtain the lowpass data.

        Signal extension modes (from the PyWavelets docs), used to minimise
        edge effects; PyWavelets' default is 'sym':

        zpd - zero-padding: the signal is extended with zero samples:
        ... 0  0 | x1 x2 ... xn | 0  0 ...

        cpd - constant-padding: border values are replicated:
        ... x1 x1 | x1 x2 ... xn | xn xn ...

        sym - symmetric-padding: the signal is extended by mirroring samples:
        ... x2 x1 | x1 x2 ... xn | xn xn-1 ...

        ppd - periodic-padding: the signal is treated as periodic:
        ... xn-1 xn | x1 x2 ... xn | x1 x2 ...

        sp1 - smooth-padding: the signal is extended according to the first
        derivatives calculated on the edges (straight line)

        The DWT computed with these modes is slightly redundant but ensures
        perfect reconstruction. The periodization mode gives the smallest
        possible number of coefficients:

        per - periodization: like periodic-padding but with the smallest
        possible number of decomposition coefficients. The IDWT must be
        performed with the same mode.

        Returns the 2-D array of filtered channels (one row per channel).
        """
        import pywt

        channels = np.atleast_2d(data)
        # Reconstructed signals always seem to come back with an even number
        # of points, so one trailing sample is dropped when the input length
        # is odd.
        surplus = channels.shape[1] % 2

        for row, trace in enumerate(channels):
            coeffs = pywt.wavedec(trace, wname, mode=mode, level=maxlevel)
            # Dropping the approximation coefficients removes the lowpass
            # part of the signal.
            coeffs[0] = None
            rebuilt = pywt.waverec(coeffs, wname, mode=mode)
            channels[row] = rebuilt[:len(rebuilt) - surplus]

        return channels

    def compute_artefacts(data_file):
        """Compute the artefact of every stimulus handled by this rank.

        The trigger file is read as rows of ``(label, time)`` and the windows
        file as rows of ``(label, window length)``. The distinct labels are
        distributed round-robin over the MPI ranks; for each local label at
        most 500 randomly chosen trigger times are handed to
        ``get_artefact``.

        Parameters
        ----------
        data_file :
            Data file the artefacts are computed from; ``sampling_rate`` and
            ``t_stop`` are used here.

        Returns
        -------
        dict
            Maps each local stimulus label to the value returned by
            ``get_artefact``.

        Notes
        -----
        Exits the process with status 1 when the trigger files are
        inconsistent or when two selected triggers of the same label are
        closer than the window length.
        """
        trig_in_ms = params.getboolean('triggers', 'trig_in_ms')
        artefacts = numpy.loadtxt(params.get('triggers', 'trig_file'))
        windows = numpy.loadtxt(params.get('triggers', 'trig_windows'))
        make_plots = params.get('triggers', 'make_plots')
        plot_path = os.path.join(params.get('data', 'data_file_noext'),
                                 'plots')

        # A file with a single row is loaded as a 1-D array.
        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if trig_in_ms:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug',
                              logger)
            # Convert ms to time steps.
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
            windows[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'],
                              'debug', logger)
            artefacts = artefacts.astype(numpy.int64)
            windows = windows.astype(numpy.int64)

        # Every stimulus label must have exactly one window definition.
        nb_stimuli = len(numpy.unique(artefacts[:, 0]))
        mytest = nb_stimuli == len(windows)

        if not mytest:
            if comm.rank == 0:
                print_and_log(['Error in the trigger files'], 'error', logger)
            sys.exit(1)

        all_labels = artefacts[:, 0].astype(numpy.int32)
        # 64-bit times: sample indices of long recordings do not fit in
        # int32 and would wrap to negative values (then be masked out below).
        all_times = artefacts[:, 1].astype(numpy.int64)

        # Keep only triggers whose largest possible window fits in the data.
        mask = (all_times >= 0) & (all_times + numpy.max(windows[:, 1]) <
                                   data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        if comm.rank == 0:
            to_write = [
                "Computing averaged artefacts from %d stimuli" % (nb_stimuli)
            ]
            print_and_log(to_write, 'default', logger)
            if not os.path.exists(plot_path):
                os.makedirs(plot_path)
            local_labels = get_tqdm_progressbar(local_labels)

        comm.Barrier()
        # First we need to get the average artefacts
        art_dict = {}
        for artefact in local_labels:
            indices = numpy.where(all_labels == artefact)[0].astype(
                numpy.int32)
            tmp = numpy.where(windows[:, 0] == artefact)[0]
            tau = numpy.int64(windows[tmp, 1])
            pspikes = all_times[indices]
            # Random subset of at most 500 triggers, back in temporal order.
            times = numpy.sort(numpy.random.permutation(pspikes)[:500])
            if len(numpy.where(numpy.diff(times) < tau)[0]) > 0:
                if comm.rank == 0:
                    print_and_log([
                        'Stimulation times for artefact %d are too close!' %
                        artefact
                    ], 'error', logger)
                sys.exit(1)

            art_dict[artefact] = get_artefact(params, times, tau, nodes)
            if make_plots not in ['None', '']:
                save = [plot_path, '%d.%s' % (artefact, make_plots)]
                plot.view_artefact(art_dict[artefact], save=save)

        return art_dict

    def remove_artefacts(data_file, art_dict):
        """Subtract the artefacts in ``art_dict`` from ``data_file``.

        The trigger file is read as rows of ``(label, time)`` and the windows
        file as rows of ``(label, window length)``. Each MPI rank handles the
        triggers of its own round-robin share of the labels: for every
        trigger, a snippet of the window length is read, the artefact of that
        label is subtracted on each channel of ``nodes``, and the snippet is
        written back.

        Parameters
        ----------
        data_file :
            Data file to correct in place; ``sampling_rate``, ``t_stop``,
            ``get_snippet`` and ``set_data`` are used.
        art_dict : dict
            Maps a stimulus label to its artefact, indexed as
            ``[node index, time]``.

        Notes
        -----
        Exits the process with status 1 when the trigger files are
        inconsistent.
        """
        trig_in_ms = params.getboolean('triggers', 'trig_in_ms')
        # Keep the values as floats until the unit conversion is done, so
        # that fractional ms values are scaled before being truncated.
        artefacts = numpy.loadtxt(params.get('triggers', 'trig_file'))
        windows = numpy.loadtxt(params.get('triggers', 'trig_windows'))

        # A file with a single row is loaded as a 1-D array.
        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if trig_in_ms:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug',
                              logger)
            # Convert ms to time steps.
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
            windows[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'],
                              'debug', logger)
        artefacts = artefacts.astype(numpy.int64)
        windows = windows.astype(numpy.int64)

        # Every stimulus label must have exactly one window definition.
        nb_stimuli = len(numpy.unique(artefacts[:, 0]))
        mytest = nb_stimuli == len(windows)

        if not mytest:
            if comm.rank == 0:
                print_and_log(['Error in the trigger files'], 'error', logger)
            sys.exit(1)

        all_labels = artefacts[:, 0].astype(numpy.int32)
        # 64-bit times: sample indices of long recordings do not fit in int32.
        all_times = artefacts[:, 1]
        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        # Keep the triggers of the labels owned by this rank ...
        mask = numpy.in1d(all_labels, local_labels)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        # ... that fall inside the recording.
        mask = (all_times >= 0) & (all_times < data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        if comm.rank == 0:
            to_write = ["Removing artefacts from %d stimuli" % (nb_stimuli)]
            print_and_log(to_write, 'default', logger)
            all_times = get_tqdm_progressbar(all_times)

        comm.Barrier()

        for count, trig_time in enumerate(all_times):

            label = all_labels[count]
            tmp = numpy.where(windows[:, 0] == label)[0][0]
            tau = numpy.int64(windows[tmp, 1])

            # Shorten the window when it would run past the end of the data.
            if (data_file.t_stop - trig_time) < tau:
                tau = data_file.t_stop - trig_time

            local_chunk = data_file.get_snippet(trig_time, tau)

            for idx, i in enumerate(nodes):
                local_chunk[:, i] -= art_dict[label][idx, :tau]
            data_file.set_data(trig_time, local_chunk)

        comm.Barrier()

    # ---- driver ----------------------------------------------------
    if comm.rank == 0:
        print_and_log(['Initializing the filtering step...'], 'debug', logger)

    # True: the input file is modified in place. False: results go to a
    # freshly allocated float32 copy.
    overwrite = params.getboolean('data', 'overwrite')

    if overwrite:
        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)

        data_file_in = params.get_data_file()
        data_file_out = data_file_in
    else:
        if comm.rank == 0:
            print_and_log(
                ['Overwrite is set to False, so creating a new datafile...'],
                'debug', logger)
            print_and_log(['Reading the input file...'], 'debug', logger)

        data_file_in = params.get_data_file(source=True,
                                            has_been_created=False)

        if comm.rank == 0:
            print_and_log(
                ['Reading the output file and allocating ressources...'],
                'debug', logger)

        # The output file shares the description of the input, except that
        # samples are stored as float32 without any offset.
        description = data_file_in.get_description()
        description['data_dtype'] = 'float32'
        description['dtype_offset'] = 0
        description['data_offset'] = 0

        data_file_out = params.get_data_file(is_empty=True, params=description)

        data_file_out.allocate(shape=data_file_in.shape)
        # NOTE: a leftover merge-conflict section used to reassign
        # data_file_in._params from a `tmp_params` name that is never defined
        # in this function (NameError); the conflict is resolved in favour of
        # the empty upstream side.

    # Artefact removal needs both trigger files.
    if clean_artefact:
        if not (os.path.exists(params.get('triggers', 'trig_file'))
                and os.path.exists(params.get('triggers', 'trig_windows'))):
            if comm.rank == 0:
                print_and_log(
                    ['trig_file or trig_windows file can not be found'],
                    'error', logger)
            sys.exit(1)

    # Skip the steps that have already been applied to the data.
    to_write = []

    if do_filter and filter_done:
        do_filter = False
        to_write += ["Filtering has already been done"]
    if remove_median and median_done:
        remove_median = False
        to_write += [
            "Median over all channels has already been substracted to each channels"
        ]

    if comm.rank == 0:
        print_and_log(to_write, 'debug', logger)

    if overwrite:
        data_file_in.open(mode='r+')
    else:
        data_file_in.open()
        data_file_out.open(mode='r+')

    if do_filter or remove_median:
        filter_file(data_file_in, data_file_out, do_filter, remove_median)

    # Record what has been done so that a later run does not redo it.
    if comm.rank == 0:
        if do_filter:
            params.write('noedits', 'filter_done', 'True')
        if remove_median:
            params.write('noedits', 'median_done', 'True')

    if clean_artefact and artefacts_done:
        clean_artefact = False
        if comm.rank == 0:
            print_and_log(['Artefacts have already been removed'], 'debug',
                          logger)

    if clean_artefact:
        art_dict = compute_artefacts(data_file_in)
        remove_artefacts(data_file_out, art_dict)

    if comm.rank == 0:
        if clean_artefact:
            params.write('noedits', 'artefacts_done', 'True')

    data_file_in.close()
    if not overwrite:
        data_file_out.close()

    comm.Barrier()
def extract_extra_thresholds(params):
    """Estimate a per-channel median and median absolute deviation (MAD).

    The data are scanned twice, chunk by chunk. Each chunk is optionally
    whitened (spatially and/or temporally, according to the 'whitening'
    section of `params`) and reduced to one value per channel: its median in
    the first pass, its MAD around the broadcast median in the second pass.
    Every MPI rank averages the values of the chunks it owns; rank 0 then
    combines the per-rank averages, weighted by the number of chunks each
    rank processed, and broadcasts the result to all ranks.

    Both outputs are therefore means over chunks of per-chunk statistics,
    not exact statistics over the whole recording.

    Parameters
    ----------
    params
        Project parameter object; must expose `data_file`, `getint` and
        `getboolean`, and be accepted by `io.load_data`,
        `get_nodes_and_edges` and `get_tqdm_progressbar`.

    Returns
    -------
    median, mad
        Two arrays with one entry per node returned by
        `get_nodes_and_edges(params)`, identical on every rank.

    Raises
    ------
    ZeroDivisionError
        On rank 0 when the data file contains no chunk at all.
    """

    data_file = params.data_file
    data_file.open()

    chunk_size = params.getint('data', 'chunk_size')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    nb_chunks, _ = data_file.analyze(chunk_size)
    nodes, _ = get_nodes_and_edges(params)
    N_elec = nodes.size

    # Chunks are distributed round-robin: rank r owns r, r + size, ...
    loc_nb_chunks = len(range(comm.rank, nb_chunks, comm.size))
    # Rank 0 needs the per-rank chunk counts to weight the per-rank averages.
    loc_nbs_chunks = comm.gather(loc_nb_chunks, root=0)

    def weighted_mean(weights, values):
        """Compute a weighted mean for the given values.

        Entries with a zero weight are skipped: a rank without any chunk has
        no meaningful local average, and `0 * NaN` would otherwise poison
        the global result.
        """
        total = float(sum(weights))
        if total == 0:
            raise ZeroDivisionError("weighted_mean: the weights sum to zero")
        return sum(
            (float(weight) / total) * value
            for (weight, value) in zip(weights, values)
            if weight > 0
        )

    def load_whitened_chunk(gidx):
        """Read chunk `gidx` and apply the configured whitening steps."""
        loc_chunk, _ = data_file.get_data(gidx, chunk_size, nodes=nodes)
        if do_spatial_whitening:
            loc_chunk = numpy.dot(loc_chunk, spatial_whitening)
        if do_temporal_whitening:
            loc_chunk = scipy.ndimage.filters.convolve1d(loc_chunk,
                                                         temporal_whitening,
                                                         axis=0,
                                                         mode='constant')
        return loc_chunk

    def average_over_chunks(message, chunk_statistic):
        """Average `chunk_statistic(chunk)` over all chunks of all ranks."""
        if comm.rank == 0:
            print_and_log([message], level='default', logger=logger)

        to_explore = list(range(comm.rank, nb_chunks, comm.size))
        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(params, to_explore)

        loc_values = numpy.zeros((N_elec, loc_nb_chunks), dtype=numpy.float32)
        # For each chunk attributed to the current CPU.
        for count, gidx in enumerate(to_explore):
            loc_values[:, count] = chunk_statistic(load_whitened_chunk(gidx))
        sys.stderr.flush()

        if loc_nb_chunks > 0:
            value = numpy.mean(loc_values, axis=1)
        else:
            # Placeholder only: this rank carries a zero weight on rank 0.
            value = numpy.zeros(N_elec, dtype=numpy.float32)

        comm.Barrier()
        all_values = comm.gather(value, root=0)
        if comm.rank == 0:
            value = weighted_mean(loc_nbs_chunks, all_values)

        # Every rank ends up with the value computed on rank 0.
        value = comm.bcast(value, root=0)
        comm.Barrier()
        return value

    median = average_over_chunks(
        "Computing extracellular medians...",
        lambda chunk: numpy.median(chunk, axis=0))

    mad = average_over_chunks(
        "Computing extracellular thresholds...",
        lambda chunk: numpy.median(numpy.abs(chunk - median), axis=0))

    data_file.close()

    return median, mad
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Detect threshold crossings on every channel and dump them to disk.

    Each MPI rank processes the chunks `rank, rank + size, ...` of the data
    file. A chunk is optionally whitened (spatially, then temporally), peaks
    are detected per channel with `scipy.signal.find_peaks` (on the raw
    trace, or on the trace convolved with the stored waveform when
    'matched-filter' is enabled), peaks falling on dead times or within
    `dist_peaks` samples of the chunk borders are discarded, and the
    remaining peak times, channel indices and amplitudes are appended to
    three per-rank files named `<file_out_suff>.mua-<rank>.data`,
    `<file_out_suff>.elec-<rank>.data` and `<file_out_suff>.amp-<rank>.data`.
    Once all ranks are done, rank 0 calls `io.collect_mua`.

    Parameters
    ----------
    params
        Project parameter object (configuration access plus `data_file`).
    nb_cpu, nb_gpu
        Only used to pick the CUDA device when `use_gpu` is true.
    use_gpu
        When true, the spatial whitening product is computed with cudamat.

    Raises
    ------
    ValueError
        If matched filtering is disabled and the configured peak sign is not
        one of 'negative', 'positive' or 'both'.
    """

    #################################################################
    _ = init_logging(params.logfile)
    # The returned flag is not used here; the call is kept in case it has
    # side effects (NOTE(review): confirm, and drop it if it is pure).
    get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detection', 'N_t')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    dist_peaks = params.getint('detection', 'dist_peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_width = params.getfloat('detection', 'spike_width')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = detect_memory(params)
    nodes, edges = get_nodes_and_edges(params)
    max_chunk = params.getfloat('fitting', 'max_chunk')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    data_file.open()
    #################################################################

    if use_gpu:
        import cudamat as cmt
        # # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    # (time-reversed, normalised waveform, per-channel thresholds) pairs, in
    # the order in which they are applied to every chunk: positive first.
    matched_filters = []
    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
            matched_filters.append((waveform_neg, matched_tresholds_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))
            matched_tresholds_pos = io.load_data(params,
                                                 'matched-thresholds-pos')
            matched_filters.insert(0, (waveform_pos, matched_tresholds_pos))

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)

    thresholds = io.load_data(params, 'thresholds')

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Extracting MUA activity..."], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (PyCharm code inspection)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.mua-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    electrodes_file = open(file_out_suff + '.elec-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amp-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    def find_channel_peaks(signal, height):
        """Return the sample indices of the peaks of a single-channel trace."""
        return scipy.signal.find_peaks(signal,
                                       height=height,
                                       width=spike_width,
                                       distance=dist_peaks,
                                       wlen=N_t)[0]

    to_explore = range(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    try:
        for gidx in to_explore:
            # Chunks are read with an overlap of dist_peaks samples on the
            # sides where a neighbouring chunk exists.
            is_first = data_file.is_first_chunk(gidx, nb_chunks)
            is_last = data_file.is_last_chunk(gidx, nb_chunks)

            if is_last:
                padding = (-dist_peaks, 0)
            elif is_first:
                padding = (0, dist_peaks)
            else:
                padding = (-dist_peaks, dist_peaks)

            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       padding,
                                                       nodes=nodes)
            len_chunk = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            # Seed the lists with empty typed arrays so that concatenation
            # works even when no peak is found.
            local_peaktimes = [numpy.zeros(0, dtype=numpy.uint32)]
            local_elecs = [numpy.zeros(0, dtype=numpy.uint32)]
            local_amps = [numpy.zeros(0, dtype=numpy.float32)]

            if matched_filter:
                for waveform, matched_thresholds in matched_filters:
                    filter_chunk = scipy.ndimage.filters.convolve1d(
                        local_chunk, waveform, axis=0, mode='constant')
                    for i in range(N_e):
                        peaktimes = find_channel_peaks(filter_chunk[:, i],
                                                       matched_thresholds[i])
                        local_peaktimes.append(peaktimes)
                        local_elecs.append(
                            numpy.full(len(peaktimes), i, dtype=numpy.uint32))
                        local_amps.append(filter_chunk[peaktimes, i])
            else:
                for i in range(N_e):
                    if sign_peaks == 'negative':
                        signal = -local_chunk[:, i]
                    elif sign_peaks == 'positive':
                        signal = local_chunk[:, i]
                    elif sign_peaks == 'both':
                        signal = numpy.abs(local_chunk[:, i])
                    else:
                        raise ValueError(
                            "Unsupported value for detection/peaks: %r" %
                            (sign_peaks, ))
                    peaktimes = find_channel_peaks(signal, thresholds[i])
                    local_peaktimes.append(peaktimes)
                    local_elecs.append(
                        numpy.full(len(peaktimes), i, dtype=numpy.uint32))
                    # Amplitudes are read from the signed trace.
                    local_amps.append(local_chunk[peaktimes, i])

            local_peaktimes = numpy.concatenate(local_peaktimes)
            local_elecs = numpy.concatenate(local_elecs)
            local_amps = numpy.concatenate(local_amps)

            # Offset converting chunk-local sample indices to global ones.
            g_offset = t_offset + padding[0]

            if ignore_dead_times:
                dead_indices = numpy.searchsorted(
                    all_dead_times, [t_offset, t_offset + chunk_size])
                if dead_indices[0] != dead_indices[1]:
                    is_included = numpy.in1d(
                        local_peaktimes + g_offset,
                        all_dead_times[dead_indices[0]:dead_indices[1]])
                    local_peaktimes = local_peaktimes[~is_included]
                    local_elecs = local_elecs[~is_included]
                    local_amps = local_amps[~is_included]

            # Drop the peaks lying within dist_peaks samples of either end
            # of the chunk that was read.
            local_borders = (dist_peaks, len_chunk - dist_peaks)
            idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                           local_borders[1])
            local_peaktimes = numpy.compress(idx, local_peaktimes) + g_offset
            local_elecs = numpy.compress(idx, local_elecs)
            local_amps = numpy.compress(idx, local_amps)

            # tobytes() yields the same bytes as the removed tostring().
            spiketimes_file.write(
                local_peaktimes.astype(numpy.uint32).tobytes())
            electrodes_file.write(local_elecs.tobytes())
            amplitudes_file.write(local_amps.tobytes())

        sys.stderr.flush()
    finally:
        # Always release the per-rank output files, even if a chunk failed.
        for output_file in (spiketimes_file, electrodes_file,
                            amplitudes_file):
            output_file.flush()
            os.fsync(output_file.fileno())
            output_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_mua(comm.size, params, erase=True)

    data_file.close()
Exemple #23
0
def delete_mixtures(params, nb_cpu, nb_gpu, use_gpu):
    """Remove templates that look like a weighted sum of two other templates.

    For every template `k` (visited by decreasing 'norm-templates' value and
    spread over the MPI ranks), all pairs `(i, j)` of templates whose best
    electrode lies in the neighbourhood of `k`'s best electrode are tried:
    the amplitudes `a1`, `a2` are obtained by solving a 2x2 linear system
    built from the stored overlaps, and if both fall within the amplitude
    limits of `i` and `j` and the correlation between `k` and
    `a1 * T_i + a2 * T_j` exceeds `cc_merge`, `k` is flagged as a mixture.
    Flagged templates are gathered over all ranks and removed with
    `slice_templates` / `slice_clusters`.

    Parameters
    ----------
    params
        Project parameter object.
    nb_cpu, nb_gpu, use_gpu
        Forwarded unchanged to `get_overlaps`.

    Returns
    -------
    list
        `[nb_temp, nb_removed]`, where `nb_temp` is half the number of
        columns of the stored templates matrix.
    """

    templates = load_data(params, 'templates')

    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    cc_merge = params.getfloat('clustering', 'cc_merge')
    _, N_tm = templates.shape
    nb_temp = N_tm // 2
    mixtures = []

    overlap = get_overlaps(params, extension='-mixtures', erase=True, normalize=False, maxoverlap=False, verbose=False, half=True, use_gpu=use_gpu, nb_cpu=nb_cpu, nb_gpu=nb_gpu)
    filename = params.get('data', 'file_out_suff') + '.overlap-mixtures.hdf5'

    norm_templates = load_data(params, 'norm-templates')
    result = load_data(params, 'clusters')
    best_elec = load_data(params, 'electrodes')
    limits = load_data(params, 'limits')
    nodes, edges = get_nodes_and_edges(params)
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)

    # Column index of the largest overlap for every pair of templates. It is
    # used below as a matrix index, hence the integer dtype.
    distances = numpy.zeros((nb_temp, nb_temp), dtype=numpy.int32)

    over_x = overlap.get('over_x')[:]
    over_y = overlap.get('over_y')[:]
    over_data = overlap.get('over_data')[:]
    over_shape = overlap.get('over_shape')[:]
    overlap.close()

    overlap = scipy.sparse.csr_matrix((over_data, (over_x, over_y)), shape=over_shape)

    # Rows i*nb_temp ... (i+1)*nb_temp - 1 hold the overlaps of template i
    # with every template; only the upper triangle is scanned and mirrored.
    for i in range(nb_temp - 1):
        distances[i, i+1:] = numpy.argmax(overlap[i*nb_temp+i+1:(i+1)*nb_temp].toarray(), 1)
        distances[i+1:, i] = distances[i, i+1:]

    overlap_0 = overlap[:, N_t].toarray().reshape(nb_temp, nb_temp)

    sorted_temp = numpy.argsort(norm_templates[:nb_temp])[::-1]
    M = numpy.zeros((2, 2), dtype=numpy.float32)
    V = numpy.zeros((2, 1), dtype=numpy.float32)

    # Split the sorted templates over the ranks exactly once. (Striding both
    # the sorted list and the index range left most templates unexamined as
    # soon as more than one rank was used.)
    to_explore = range(comm.rank, len(sorted_temp), comm.size)
    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for k in to_explore:

        k = sorted_temp[k]
        electrodes = numpy.take(inv_nodes, edges[nodes[best_elec[k]]])
        overlap_k = overlap[k*nb_temp:(k+1)*nb_temp].tolil()
        is_in_area = numpy.in1d(best_elec, electrodes)
        all_idx = numpy.arange(len(best_elec))[is_in_area]
        been_found = False

        for position, i in enumerate(all_idx):
            overlap_i = overlap[i*nb_temp:(i+1)*nb_temp].tolil()
            M[0, 0] = overlap_0[i, i]
            V[0, 0] = overlap_k[i, distances[k, i]]
            # Pair i with the candidates that come after it in all_idx; the
            # slice must use the position of i, not the template id itself.
            for j in all_idx[position + 1:]:
                M[1, 1] = overlap_0[j, j]
                M[1, 0] = overlap_i[j, distances[k, i] - distances[k, j]]
                M[0, 1] = M[1, 0]
                V[1, 0] = overlap_k[j, distances[k, j]]
                try:
                    [a1, a2] = numpy.dot(scipy.linalg.inv(M), V)
                except (scipy.linalg.LinAlgError, ValueError):
                    # Singular or non-finite system: no valid decomposition.
                    [a1, a2] = [0, 0]
                a1_lim = limits[i]
                a2_lim = limits[j]
                is_a1 = (a1_lim[0] <= a1) and (a1 <= a1_lim[1])
                is_a2 = (a2_lim[0] <= a2) and (a2 <= a2_lim[1])
                if is_a1 and is_a2:
                    new_template = (a1*templates[:, i].toarray() + a2*templates[:, j].toarray()).ravel()
                    similarity = numpy.corrcoef(templates[:, k].toarray().ravel(), new_template)[0, 1]
                    if similarity > cc_merge:
                        mixtures.append(k)
                        been_found = True
                        break
            if been_found:
                # One explaining pair is enough to flag template k.
                break

    to_remove = numpy.unique(numpy.array(mixtures, dtype=numpy.int32))
    to_remove = all_gather_array(to_remove, comm, 0, dtype='int32')

    if len(to_remove) > 0:
        slice_templates(params, to_remove)
        slice_clusters(params, result, to_remove=to_remove)

    comm.Barrier()

    # The '-mixtures' overlap file is only a temporary product of this step.
    if comm.rank == 0:
        os.remove(filename)

    return [nb_temp, len(to_remove)]
Exemple #24
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    # Part 1: Whitening
    numpy.random.seed(420)

    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    data_file.open()
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    file_out = params.get('data', 'file_out')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    matched_filter = params.getboolean('detection', 'matched-filter')
    matched_thresh = params.getfloat('detection', 'matched_thresh')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('whitening', 'chunk_size')
    plot_path = os.path.join(params.get('data', 'data_file_noext'), 'plots')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_1 = int(20 * params.rate // comm.size)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    all_electrodes = numpy.random.permutation(N_e)

    for gidx in [all_chunks[comm.rank]]:

        #print "Node", comm.rank, "is analyzing chunk", gidx,  "/", nb_chunks, " ..."
        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   nodes=nodes)
        local_shape = len(local_chunk)

        #print "Node", comm.rank, "computes the median absolute deviations in a random chunk"
        thresholds = numpy.zeros(N_e, dtype=numpy.float32)
        for i in xrange(N_e):
            u = numpy.median(local_chunk[:, i], 0)
            thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0)
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            bfile = h5py.File(file_out_suff + '.basis.hdf5',
                              'w',
                              libver='latest')
            io.write_datasets(bfile, ['thresholds'],
                              {'thresholds': thresholds})
            bfile.close()
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        #print "Extracting the peaks..."
        local_peaktimes = numpy.zeros(0, dtype=numpy.int32)
        for i in xrange(N_e):
            peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                          thresholds[i],
                                          valley=False,
                                          mpd=dist_peaks)
            local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))

        local_peaktimes = numpy.unique(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, local_shape - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool)
            min_times = numpy.maximum(
                local_peaktimes - local_peaktimes[0] - safety_time, 0)
            max_times = numpy.minimum(
                local_peaktimes - local_peaktimes[0] + safety_time + 1,
                diff_times)
            argmax_peak = numpy.random.permutation(
                numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            #print "Selection of the peaks with spatio-temporal masks..."
            for idx, peak in zip(argmax_peak, all_idx):
                elec = numpy.argmax(numpy.abs(local_chunk[peak]))
                indices = numpy.take(inv_nodes, edges[nodes[elec]])
                all_times[indices, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool)

    all_times_Ne = numpy.any(all_times, 0)
    subset = numpy.where(all_times_Ne == False)[0]
    all_silences = []

    if do_spatial_whitening:
        local_silences = numpy.take(local_chunk, subset,
                                    axis=0)[:max_silence_1]
        all_silences = gather_array(local_silences, comm, 0, 1)

    local_res = []

    if do_temporal_whitening:

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white,
                                                comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = numpy.take(inv_nodes, edges[nodes[elec]])
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            bound = len(esubset) - N_t
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res)], dtype=numpy.float32)
        local_res = numpy.array(local_res, dtype=numpy.float32)
        if len(local_res) == 0:
            local_res = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res = numpy.sum(local_res, 0)
        all_res = gather_array(local_res.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if comm.rank == 0:

        to_write = {}

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res = all_res.reshape((nb_silences, N_t**2))
            except Exception:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res = numpy.sum(all_res, 0)
            all_res = all_res.reshape((N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            if len(all_silences) / params.rate == 0:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            spatial_whitening = get_whitening_matrix(
                all_silences.astype(numpy.double)).astype(numpy.float32)
            to_write['spatial'] = spatial_whitening
            print_and_log([
                "Found %gs without spikes for whitening matrices..." %
                (len(all_silences) / params.rate)
            ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='latest')
        io.write_datasets(bfile, to_write.keys(), to_write)
        bfile.close()

    del all_silences
    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = numpy.zeros(N_e, dtype=numpy.float32)
            for i in xrange(N_e):
                u = numpy.median(local_chunk[:, i], 0)
                thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u),
                                             0)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                  'r+',
                                  libver='latest')
                bfile.pop('thresholds')
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds})
                bfile.close()
            comm.Barrier()

    #if comm.rank == 0:
    #if not os.path.exists(plot_path):
    #    os.makedirs(plot_path)
    #N_elec = min(int(numpy.sqrt(data_file.N_e)), 5)
    #plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True,
    #              n_elec=N_elec, save=[plot_path, 'electrodes'])

    # Part 2: Basis
    numpy.random.seed(422)

    #################################################################
    file_out = params.get('data', 'file_out')
    alignment = params.getboolean('detection', 'alignment')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    nodes, edges = get_nodes_and_edges(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('data', 'chunk_size')
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    nb_elts = int(
        params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    if sign_peaks == 'both':
        max_elts_elec *= 2
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    groups = {}
    for i in xrange(N_e):
        groups[i] = 0

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)

    chunks_to_load = all_chunks[comm.rank::comm.size]

    thresholds = io.load_data(params, 'thresholds')

    if comm.rank == 0:
        pbar = get_progressbar(nb_elts)

    if alignment:
        cdata = numpy.linspace(-template_shift, template_shift, 5 * N_t)
        xdata = numpy.arange(-2 * template_shift, 2 * template_shift + 1)

    for gcount, gidx in enumerate(chunks_to_load):

        if ((elt_count_pos + elt_count_neg) < nb_elts):
            #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            #print "Extracting the peaks..."
            all_peaktimes = numpy.zeros(0, dtype=numpy.int32)
            all_extremas = numpy.zeros(0, dtype=numpy.int32)

            for i in xrange(N_e):

                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=True,
                                                  mpd=dist_peaks)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=False,
                                                  mpd=dist_peaks)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                                  thresholds[i],
                                                  valley=False,
                                                  mpd=dist_peaks)
                all_peaktimes = numpy.concatenate((all_peaktimes, peaktimes))
                all_extremas = numpy.concatenate(
                    (all_extremas,
                     i * numpy.ones(len(peaktimes), dtype=numpy.int32)))

            #print "Removing the useless borders..."
            if alignment:
                local_borders = (2 * template_shift,
                                 local_shape - 2 * template_shift)
            else:
                local_borders = (template_shift, local_shape - template_shift)
            idx = (all_peaktimes >= local_borders[0]) & (all_peaktimes <
                                                         local_borders[1])
            all_peaktimes = numpy.compress(idx, all_peaktimes)
            all_extremas = numpy.compress(idx, all_extremas)

            local_peaktimes = numpy.unique(all_peaktimes)

            if len(local_peaktimes) > 0:

                diff_times = local_peaktimes[-1] - local_peaktimes[0]
                all_times = numpy.zeros((N_e, diff_times + 1),
                                        dtype=numpy.bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    diff_times)

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                all_idx = numpy.take(local_peaktimes, argmax_peak)

                #print "Selection of the peaks with spatio-temporal masks..."
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    if sign_peaks == 'negative':
                        elec = numpy.argmin(local_chunk[peak])
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        elec = numpy.argmax(local_chunk[peak])
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if numpy.abs(numpy.max(local_chunk[peak])) > numpy.abs(
                                numpy.min(local_chunk[peak])):
                            elec = numpy.argmax(local_chunk[peak])
                            negative_peak = False
                        else:
                            elec = numpy.argmin(local_chunk[peak])
                            negative_peak = True

                    indices = numpy.take(inv_nodes, edges[nodes[elec]])
                    myslice = all_times[indices,
                                        min_times[midx]:max_times[midx]]
                    is_local_extrema = elec in all_extremas[all_peaktimes ==
                                                            peak]
                    if is_local_extrema and not myslice.any():
                        upper_bounds = max_elts_elec

                        if groups[elec] < upper_bounds:

                            if negative_peak:
                                elts_neg[:, elt_count_neg] = local_chunk[
                                    peak - template_shift:peak +
                                    template_shift + 1, elec]
                            else:
                                elts_pos[:, elt_count_pos] = local_chunk[
                                    peak - template_shift:peak +
                                    template_shift + 1, elec]
                            if alignment:
                                ydata = local_chunk[peak -
                                                    2 * template_shift:peak +
                                                    2 * template_shift + 1,
                                                    elec]
                                f = scipy.interpolate.UnivariateSpline(xdata,
                                                                       ydata,
                                                                       s=0)
                                if negative_peak:
                                    rmin = (numpy.argmin(f(cdata)) -
                                            len(cdata) / 2.) / 5.
                                else:
                                    rmin = (numpy.argmax(f(cdata)) -
                                            len(cdata) / 2.) / 5.
                                ddata = numpy.linspace(rmin - template_shift,
                                                       rmin + template_shift,
                                                       N_t)

                                if negative_peak:
                                    elts_neg[:,
                                             elt_count_neg] = f(ddata).astype(
                                                 numpy.float32)
                                else:
                                    elts_pos[:,
                                             elt_count_pos] = f(ddata).astype(
                                                 numpy.float32)

                            if negative_peak:
                                elt_count_neg += 1
                            else:
                                elt_count_pos += 1

                        groups[elec] += 1
                        all_times[indices,
                                  min_times[midx]:max_times[midx]] = True

            if comm.rank == 0:
                pbar.update(elt_count_pos + elt_count_neg)

        if comm.rank == 0:
            if (elt_count_pos + elt_count_neg <
                (gcount + 1) * max_elts_elec // len(chunks_to_load)):
                pbar.update(
                    (gcount + 1) * max_elts_elec // len(chunks_to_load))

    if comm.rank == 0:
        pbar.finish()

    print_and_log([
        "Node %d has collected %d waveforms" %
        (comm.rank, elt_count_pos + elt_count_neg)
    ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    if comm.rank == 0:
        #DO PCA on elts and store the basis obtained.

        nb_waveforms = 0
        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

        print_and_log([
            "Found %d waveforms over %d requested" %
            (nb_waveforms, int(nb_elts * comm.size))
        ], 'default', logger)
        pca = PCA(output_dim, copy=False)
        res = {}
        if sign_peaks in ['negative', 'both']:
            if len(gdata_neg) > 0:
                res_pca = pca.fit_transform(gdata_neg.astype(
                    numpy.double)).astype(numpy.float32)
                res['proj'] = pca.components_.T.astype(numpy.float32)
            else:
                res['proj'] = numpy.identity(N_t, dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_neg.shape[0]))[:1000]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            if len(gdata_pos) > 0:
                res_pca = pca.fit_transform(gdata_pos.astype(
                    numpy.double)).astype(numpy.float32)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
            else:
                res['proj_pos'] = numpy.identity(N_t, dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_pos.shape[0]))[:1000]
            res['waveforms_pos'] = gdata_pos[idx, :]

        bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='latest')
        io.write_datasets(bfile, res.keys(), res)
        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)

        bfile.close()

    comm.Barrier()

    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_neg,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out + '.basis.hdf5',
                                      'r+',
                                      libver='latest')
                    io.write_datasets(bfile, ['matched_thresholds'],
                                      {'matched_thresholds': thresholds})
                    bfile.close()
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_pos,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out + '.basis.hdf5',
                                      'r+',
                                      libver='latest')
                    io.write_datasets(bfile, ['matched_thresholds_pos'],
                                      {'matched_thresholds_pos': thresholds})
                    bfile.close()
                comm.Barrier()

    data_file.close()
Exemple #25
0
def main(params, nb_cpu, nb_gpu, use_gpu, extension):

    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.converting')
    data_file = params.data_file
    file_out_suff = params.get('data', 'file_out_suff')
    probe = params.probe
    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'
    N_e = params.getint('data', 'N_e')
    prelabelling = params.getboolean('converting', 'prelabelling')
    N_t = params.getint('detection', 'N_t')
    erase_all = params.getboolean('converting', 'erase_all')
    export_pcs = params.get('converting', 'export_pcs')
    export_all = params.getboolean('converting', 'export_all')
    sparse_export = params.getboolean('converting', 'sparse_export')
    rpv_threshold = params.getfloat('converting', 'rpv_threshold')
    if export_all and not params.getboolean('fitting', 'collect_all'):
        if comm.rank == 0:
            print_and_log([
                'Export unfitted spikes only if [fitting] collect_all is True'
            ], 'error', logger)
        sys.exit(0)

    def generate_mapping(probe):
        """Collect channel coordinates and shank labels from a probe layout.

        Walks ``probe['channel_groups']`` in iteration order and, for every
        channel listed under a group's ``'channels'``, looks up its
        coordinates in the geometry entries gathered so far.

        Parameters
        ----------
        probe : dict
            Must provide ``probe['channel_groups'][key]`` with a
            ``'geometry'`` mapping (channel -> coordinates) and a
            ``'channels'`` sequence.

        Returns
        -------
        positions : numpy.ndarray
            Coordinates of every listed channel, in group then channel order.
        shanks : numpy.ndarray
            Key of the owning channel group, one entry per row of
            ``positions``.

        Raises
        ------
        KeyError
            If a listed channel has no geometry entry in its own group or in
            any group visited before it.
        """
        geometry = {}
        positions = []
        shanks = []
        for key, group in probe['channel_groups'].items():
            # Geometry is deliberately accumulated across groups, so a channel
            # may also resolve against coordinates declared by an earlier
            # group.
            geometry.update(group['geometry'])
            channels = group['channels']
            positions.extend(geometry[channel] for channel in channels)
            shanks.extend([key] * len(channels))
        return numpy.array(positions), numpy.array(shanks)

    def get_max_loc_channel(params, extension):
        """Return the largest number of channels attached to one template.

        When ``test_if_support`` reports that supports are available, the
        value is the largest row sum (sum over axis 1) of the ``'supports'``
        data.  Otherwise it is the length of the longest entry of the
        ``edges`` mapping returned by ``get_nodes_and_edges`` (0 if that
        mapping is empty).
        """
        if not test_if_support(params, extension):
            _, neighbourhoods = get_nodes_and_edges(params)
            # Seed with 0 so that an empty mapping still yields 0.
            sizes = [0]
            sizes.extend(
                len(neighbourhoods[node])
                for node in list(neighbourhoods.keys()))
            return max(sizes)

        support_mask = io.load_data(params, 'supports', extension)
        channels_per_template = numpy.sum(support_mask, 1)
        return channels_per_template.max()

    def write_results(path, params, extension):
        """Export spike times, template ids and amplitudes to ``output_path``.

        Reads the fitting results through ``io.get_results`` and saves, sorted
        by spike time, ``spike_templates``, ``spike_times`` and ``amplitudes``
        with ``numpy.save``.  With ``export_all`` the unfitted spikes of every
        electrode are appended as extra clusters numbered
        ``N_tm + electrode`` with amplitude 1.

        With ``prelabelling`` a ``cluster_group.tsv`` file is written too:

        * purity available: ``good`` if ``rpv <= rpv_threshold`` and purity
          is above 0.75; when ``rpv`` exceeds the threshold, ``mua`` if the
          purity is above 0.75 and ``noise`` otherwise;
        * no purity: ``good`` if ``rpv <= rpv_threshold`` and the median
          amplitude is within 0.25 of 1; otherwise ``mua`` if the median
          amplitude is below 0.5, else ``noise`` if the template norm is
          below 0.1.

        Templates matching none of these rules get no label.

        Parameters
        ----------
        path
            Unused; every file goes to the enclosing ``output_path``.
        params
            Parameter object forwarded to the ``io`` helpers; its
            ``data_file.sampling_rate`` is handed to ``get_rpv``.
        extension
            Forwarded unchanged to the ``io`` helpers.
        """
        result = io.get_results(params, extension)
        # Seed each list with an empty, correctly typed array so that
        # numpy.concatenate below also works when there are no spikes.
        spikes = [numpy.zeros(0, dtype=numpy.uint64)]
        clusters = [numpy.zeros(0, dtype=numpy.uint32)]
        amplitudes = [numpy.zeros(0, dtype=numpy.double)]
        N_tm = len(result['spiketimes'])

        has_purity = test_if_purity(params, extension)
        # Only consumed by the disabled cluster_rpv.tsv export further down.
        rpvs = []

        if prelabelling:
            labels = []
            norms = io.load_data(params, 'norm-templates', extension)
            norms = norms[:len(norms) // 2]
            if has_purity:
                purity = io.load_data(params, 'purity', extension)

        for key in list(result['spiketimes'].keys()):
            temp_id = int(key.split('_')[-1])
            # pop() releases each entry of the results as soon as it is used.
            myspikes = result['spiketimes'].pop(key).astype(numpy.uint64)
            spikes.append(myspikes)
            myamplitudes = result['amplitudes'].pop(key).astype(numpy.double)
            amplitudes.append(myamplitudes[:, 0])
            clusters.append(temp_id *
                            numpy.ones(len(myamplitudes), dtype=numpy.uint32))
            rpv = get_rpv(myspikes, params.data_file.sampling_rate)
            rpvs.append([temp_id, rpv])

            if not prelabelling:
                continue

            if has_purity:
                if rpv <= rpv_threshold:
                    if purity[temp_id] > 0.75:
                        labels.append([temp_id, 'good'])
                elif purity[temp_id] > 0.75:
                    labels.append([temp_id, 'mua'])
                else:
                    labels.append([temp_id, 'noise'])
            else:
                median_amp = numpy.median(myamplitudes[:, 0])
                if rpv <= rpv_threshold and numpy.abs(median_amp - 1) < 0.25:
                    labels.append([temp_id, 'good'])
                elif median_amp < 0.5:
                    labels.append([temp_id, 'mua'])
                elif norms[temp_id] < 0.1:
                    labels.append([temp_id, 'noise'])

        if export_all:
            print_and_log([
                "Last %d templates are unfitted spikes on all electrodes" % N_e
            ], 'info', logger)
            garbage = io.load_data(params, 'garbage', extension)
            for key in list(garbage['gspikes'].keys()):
                elec_id = int(key.split('_')[-1])
                data = garbage['gspikes'].pop(key).astype(numpy.uint64)
                spikes.append(data)
                amplitudes.append(numpy.ones(len(data)))
                clusters.append((elec_id + N_tm) *
                                numpy.ones(len(data), dtype=numpy.uint32))

        if prelabelling:
            # The context manager closes the file even if a write fails.
            group_file = os.path.join(output_path, 'cluster_group.tsv')
            with open(group_file, 'w') as f:
                f.write('cluster_id\tgroup\n')
                for label in labels:
                    f.write('%s\t%s\n' % (label[0], label[1]))

        # Disabled export of the refractory period violations:
        # f = open(os.path.join(output_path, 'cluster_rpv.tsv'), 'w')
        # f.write('cluster_id\trpv\n')
        # for l in rpvs:
        #     f.write('%s\t%s\n' % (l[0], l[1]))
        # f.close()

        spikes = numpy.concatenate(spikes).astype(numpy.uint64)
        amplitudes = numpy.concatenate(amplitudes).astype(numpy.double)
        clusters = numpy.concatenate(clusters).astype(numpy.uint32)

        # All three arrays share the same ordering: increasing spike time.
        idx = numpy.argsort(spikes)
        numpy.save(os.path.join(output_path, 'spike_templates'), clusters[idx])
        numpy.save(os.path.join(output_path, 'spike_times'), spikes[idx])
        numpy.save(os.path.join(output_path, 'amplitudes'), amplitudes[idx])
        return

    def write_templates(path, params, extension):
        """Export the templates as a ``templates`` array in ``output_path``.

        Loads ``'templates'`` through ``io.load_data``; each of the first
        ``N_tm`` columns is reshaped to ``(N_e, N_t)``, transposed and copied
        into an array of shape ``(n_templates, N_t, n_channels_max)``.

        * ``sparse_export``: only channels holding a non-zero sample are
          kept, packed to the left, and the channel ids are saved as
          ``template_ind`` (padded with -1).
        * otherwise every template uses all ``N_e`` channels.
        * ``export_all``: ``N_e`` extra templates are appended, one per
          electrode, built by ``io.get_stas`` from at most 100 randomly
          chosen unfitted spikes of that electrode.
        * if ``test_if_purity`` is true, ``cluster_purity.tsv`` is written
          with one purity value per template.

        Parameters
        ----------
        path
            Unused; every file goes to the enclosing ``output_path``.
        params, extension
            Forwarded to the ``io`` helpers.

        Returns
        -------
        int
            ``N_tm``, half the number of columns of the loaded templates.
        """
        templates = io.load_data(params, 'templates', extension)
        N_tm = templates.shape[1] // 2
        nodes, _ = get_nodes_and_edges(params)

        if sparse_export:
            n_channels_max = 0
            for t in range(N_tm):
                dense = templates[:, t].toarray().reshape(N_e, N_t)
                # Count channels with at least one non-zero sample.  This is
                # the same criterion as the fill loop below, so a channel
                # whose samples cancel to a zero sum is not under-counted.
                n_active = int(
                    numpy.count_nonzero((dense != 0).any(axis=1)))
                n_channels_max = max(n_channels_max, n_active)
            if export_all:
                # Every unfitted-spike template writes into channel slot 0.
                n_channels_max = max(n_channels_max, 1)
        else:
            n_channels_max = N_e

        nb_templates = N_tm + N_e if export_all else N_tm
        to_write_sparse = numpy.zeros((nb_templates, N_t, n_channels_max),
                                      dtype=numpy.float32)
        mapping_sparse = -1 * numpy.ones((nb_templates, n_channels_max),
                                         dtype=numpy.int32)

        if test_if_purity(params, extension):
            purity = io.load_data(params, 'purity', extension)
            # The context manager closes the file even if a write fails.
            purity_file = os.path.join(output_path, 'cluster_purity.tsv')
            with open(purity_file, 'w') as f:
                f.write('cluster_id\tpurity\n')
                for i in range(N_tm):
                    f.write('%d\t%g\n' % (i, purity[i]))

        for t in range(N_tm):
            tmp = templates[:, t].toarray().reshape(N_e, N_t).T
            x, y = tmp.nonzero()
            if len(y) == 0:
                # All-zero template: nothing to copy, and y.max() below
                # would raise on an empty array.
                continue

            if sparse_export:
                channels = numpy.unique(y)
                nb_loc = len(channels)
                # Lookup table: original channel id -> packed column index.
                all_positions = numpy.zeros(y.max() + 1, dtype=numpy.int32)
                all_positions[channels] = numpy.arange(nb_loc,
                                                       dtype=numpy.int32)
                to_write_sparse[t, x, all_positions[y]] = tmp[x, y]
                mapping_sparse[t, numpy.arange(nb_loc)] = channels
            else:
                to_write_sparse[t, x, y] = tmp[x, y]

        if export_all:
            garbage = io.load_data(params, 'garbage', extension)
            for t in range(N_tm, N_tm + N_e):
                elec = t - N_tm
                spikes = garbage['gspikes'].pop('elec_%d' % elec).astype(
                    numpy.int64)
                spikes = numpy.random.permutation(spikes)[:100]
                mapping_sparse[t, 0] = elec
                waveform = io.get_stas(params,
                                       times_i=spikes,
                                       labels_i=numpy.ones(len(spikes)),
                                       src=elec,
                                       neighs=[elec],
                                       nodes=nodes,
                                       mean_mode=True)

                if sparse_export:
                    to_write_sparse[t, :, 0] = waveform
                else:
                    to_write_sparse[t, :, elec] = waveform

        numpy.save(os.path.join(output_path, 'templates'), to_write_sparse)

        if sparse_export:
            numpy.save(os.path.join(output_path, 'template_ind'),
                       mapping_sparse)

        return N_tm

    def write_pcs(path, params, extension, N_tm, mode=0):
        """Project per-template snippets on the PCA basis into ``pc_features.npy``.

        For every template handled by this rank, the data returned by
        ``io.get_stas`` for the selected spikes is multiplied by the
        projection matrix of the loaded basis and stored in a memory-mapped
        ``pc_features.npy`` file located in the enclosing ``output_path``.
        Rank 0 additionally saves ``pc_feature_ind`` (channel indices per
        template) at the end.

        Parameters
        ----------
        path
            Unused; every file is read from / written to ``output_path``.
        params
            Parameter object, queried here for ``whitening/output_dim``,
            ``detection/peaks`` and ``data/N_total`` and forwarded to the
            ``io`` helpers.
        extension
            Forwarded unchanged to the ``io`` helpers.
        N_tm
            Number of templates; with ``export_all`` the ``N_e`` per-electrode
            templates are numbered after them.
        mode
            ``0``: features are computed for all the spikes of a template and
            stored at the rows given by the spike indices.
            ``1``: at most 500 randomly chosen spikes per template are used,
            stored contiguously, and their spike indices are written to
            ``pc_feature_spike_ids.npy``.
        """

        spikes = numpy.load(os.path.join(output_path, 'spike_times.npy'))
        labels = numpy.load(os.path.join(output_path, 'spike_templates.npy'))
        max_loc_channel = get_max_loc_channel(params, extension)
        nb_features = params.getint('whitening', 'output_dim')
        sign_peaks = params.get('detection', 'peaks')
        nodes, edges = get_nodes_and_edges(params)
        N_total = params.getint('data', 'N_total')
        has_support = test_if_support(params, extension)
        if has_support:
            supports = io.load_data(params, 'supports', extension)
        else:
            # Lookup table from a node id to its position inside `nodes`.
            inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
            inv_nodes[nodes] = numpy.arange(len(nodes))

        if export_all:
            nb_templates = N_tm + N_e
        else:
            nb_templates = N_tm

        # Channel indices attached to each template; unused slots stay at 0.
        pc_features_ind = numpy.zeros((nb_templates, max_loc_channel),
                                      dtype=numpy.int32)
        best_elec = io.load_data(params, 'electrodes', extension)
        if export_all:
            # The extra per-electrode templates use their own electrode.
            best_elec = numpy.concatenate((best_elec, numpy.arange(N_e)))

        if has_support:
            for count, support in enumerate(supports):
                nb_loc = numpy.sum(support)
                pc_features_ind[count, numpy.arange(nb_loc)] = numpy.where(
                    support == True)[0]
        else:
            for count, elec in enumerate(best_elec):
                nb_loc = len(edges[nodes[elec]])
                pc_features_ind[count, numpy.arange(nb_loc)] = inv_nodes[edges[
                    nodes[elec]]]

        # NOTE(review): no basis is loaded (basis_proj stays undefined) if
        # sign_peaks is not one of 'negative', 'both' or 'positive'.
        if sign_peaks in ['negative', 'both']:
            basis_proj, basis_rec = io.load_data(params, 'basis')
        elif sign_peaks in ['positive']:
            basis_proj, basis_rec = io.load_data(params, 'basis-pos')

        # NOTE(review): to_process is never read below.
        to_process = numpy.arange(comm.rank, nb_templates, comm.size)

        # Number of spikes kept for each template (capped at 500 in mode 1).
        all_offsets = numpy.zeros(nb_templates, dtype=numpy.int32)
        for target in range(nb_templates):
            if mode == 0:
                all_offsets[target] = len(numpy.where(labels == target)[0])
            elif mode == 1:
                all_offsets[target] = min(
                    500, len(numpy.where(labels == target)[0]))

        # all_paddings[target] is the first row of that template's block.
        all_paddings = numpy.concatenate(([0], numpy.cumsum(all_offsets)))
        total_pcs = numpy.sum(all_offsets)

        pc_file = os.path.join(output_path, 'pc_features.npy')
        pc_file_ids = os.path.join(output_path, 'pc_feature_spike_ids.npy')

        from numpy.lib.format import open_memmap

        # Rank 0 creates the output files with their final shape ...
        if comm.rank == 0:
            pc_features = open_memmap(pc_file,
                                      shape=(total_pcs, nb_features,
                                             max_loc_channel),
                                      dtype=numpy.float32,
                                      mode='w+')
            if mode == 1:
                pc_ids = open_memmap(pc_file_ids,
                                     shape=(total_pcs, ),
                                     dtype=numpy.int32,
                                     mode='w+')

        # ... and once they exist, every rank reopens them for in-place
        # writing.
        comm.Barrier()
        pc_features = open_memmap(pc_file, mode='r+')
        if mode == 1:
            pc_ids = open_memmap(pc_file_ids, mode='r+')

        # Templates are distributed over the ranks in a round-robin fashion.
        to_explore = list(range(comm.rank, nb_templates, comm.size))

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(params, to_explore)

        # NOTE(review): all_idx and gcount are never read below.
        all_idx = numpy.zeros(0, dtype=numpy.int32)
        for gcount, target in enumerate(to_explore):

            count = all_paddings[target]

            if mode == 1:
                idx = numpy.random.permutation(
                    numpy.where(labels == target)[0])[:500]
                pc_ids[count:count + len(idx)] = idx
            elif mode == 0:
                idx = numpy.where(labels == target)[0]

            elec = best_elec[target]

            if has_support:
                if target >= len(supports):
                    # Template beyond the stored supports: a single channel,
                    # indexed relative to N_tm.
                    indices = [target - N_tm]
                else:
                    indices = numpy.where(supports[target])[0]
            else:
                indices = inv_nodes[edges[nodes[elec]]]
            labels_i = target * numpy.ones(len(idx))
            times_i = numpy.take(spikes, idx).astype(numpy.int64)
            sub_data = io.get_stas(params,
                                   times_i,
                                   labels_i,
                                   elec,
                                   neighs=indices,
                                   nodes=nodes,
                                   auto_align=False)

            # Project on the basis, then swap the last two axes to match the
            # (row, feature, channel) layout of pc_features.
            pcs = numpy.dot(sub_data, basis_proj)
            pcs = numpy.swapaxes(pcs, 1, 2)
            if mode == 0:
                # Rows are addressed by spike index.
                pc_features[idx, :, :len(indices)] = pcs
            elif mode == 1:
                # Rows are the contiguous block reserved for this template.
                pc_features[count:count + len(idx), :, :len(indices)] = pcs

        comm.Barrier()

        if comm.rank == 0:
            numpy.save(os.path.join(output_path, 'pc_feature_ind'),
                       pc_features_ind.astype(
                           numpy.uint32))  # n_templates, n_loc_chan

    do_export = True
    if comm.rank == 0:
        if os.path.exists(output_path):
            if not erase_all:
                do_export = query_yes_no(
                    Fore.WHITE +
                    "Export already made! Do you want to erase everything?",
                    default=None)

            if do_export:
                if os.path.exists(os.path.abspath('.phy')):
                    shutil.rmtree(os.path.abspath('.phy'))
                shutil.rmtree(output_path)
        if do_export:
            comm.bcast(numpy.array([1], dtype=numpy.int32), root=0)
        else:
            comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
    else:
        do_export = bool(
            comm.bcast(numpy.array([0], dtype=numpy.int32), root=0))

    comm.Barrier()

    if do_export:

        #apply_patch_for_similarities(params, extension)

        if comm.rank == 0:
            os.makedirs(output_path)
            print_and_log(
                ["Exporting data for the phy GUI with %d CPUs..." % nb_cpu],
                'info', logger)

            if params.getboolean('whitening', 'spatial'):
                whitening_mat = io.load_data(
                    params, 'spatial_whitening').astype(numpy.double)
                numpy.save(os.path.join(output_path, 'whitening_mat'),
                           whitening_mat)
                numpy.save(os.path.join(output_path, 'whitening_mat_inv'),
                           numpy.linalg.inv(whitening_mat))
            else:
                numpy.save(os.path.join(output_path, 'whitening_mat'),
                           numpy.eye(N_e))

            positions, shanks = generate_mapping(probe)
            numpy.save(os.path.join(output_path, 'channel_positions'),
                       positions.astype(numpy.double))
            numpy.save(os.path.join(output_path, 'channel_shanks'),
                       shanks.astype(numpy.double))
            nodes, edges = get_nodes_and_edges(params)
            numpy.save(os.path.join(output_path, 'channel_map'),
                       nodes.astype(numpy.int32))

            write_results(output_path, params, extension)

            N_tm = write_templates(output_path, params, extension)

            template_file = h5py.File(file_out_suff +
                                      '.templates%s.hdf5' % extension,
                                      'r',
                                      libver='earliest')
            similarities = template_file.get('maxoverlap')[:]
            template_file.close()
            norm = N_e * N_t

            if export_all:
                to_write = numpy.zeros((N_tm + N_e, N_tm + N_e),
                                       dtype=numpy.single)
                to_write[:N_tm, :N_tm] = (similarities[:N_tm, :N_tm] /
                                          norm).astype(numpy.single)
            else:
                to_write = (similarities[:N_tm, :N_tm] / norm).astype(
                    numpy.single)
            numpy.save(os.path.join(output_path, 'similar_templates'),
                       to_write)

            comm.bcast(numpy.array([N_tm], dtype=numpy.int32), root=0)

        else:
            N_tm = int(comm.bcast(numpy.array([0], dtype=numpy.int32), root=0))

        comm.Barrier()

        make_pcs = 2
        if comm.rank == 0:

            if export_pcs == 'prompt':
                key = ''
                while key not in ['a', 's', 'n']:
                    print((
                        Fore.WHITE +
                        "Do you want SpyKING CIRCUS to export PCs? (a)ll / (s)ome / (n)o"
                    ))
                    key = input('')
            else:
                key = export_pcs

            if key == 'a':
                make_pcs = 0
                comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
            elif key == 's':
                make_pcs = 1
                comm.bcast(numpy.array([1], dtype=numpy.int32), root=0)
            elif key == 'n':
                comm.bcast(numpy.array([2], dtype=numpy.int32), root=0)
                if os.path.exists(os.path.join(output_path,
                                               'pc_features.npy')):
                    os.remove(os.path.join(output_path, 'pc_features.npy'))
                if os.path.exists(
                        os.path.join(output_path, 'pc_feature_ind.npy')):
                    os.remove(os.path.join(output_path, 'pc_feature_ind.npy'))
        else:
            make_pcs = comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
            make_pcs = make_pcs[0]

        comm.Barrier()
        if make_pcs < 2:
            write_pcs(output_path, params, extension, N_tm, make_pcs)

        supported_by_phy = ['raw_binary', 'mcs_raw_binary', 'mda']
        file_format = data_file.description
        gui_params = {}

        if file_format in supported_by_phy:
            if not params.getboolean('data', 'overwrite'):
                gui_params['dat_path'] = r"%s" % params.get(
                    'data', 'data_file_no_overwrite')
            else:
                if params.get('data', 'stream_mode') == 'multi-files':
                    data_file = params.get_data_file(source=True,
                                                     has_been_created=False)
                    gui_params['dat_path'] = "["
                    for f in data_file.get_file_names():
                        gui_params['dat_path'] += 'r"%s", ' % f
                    gui_params['dat_path'] += "]"
                else:
                    gui_params['dat_path'] = 'r"%s"' % params.get(
                        'data', 'data_file')
        else:
            gui_params['dat_path'] = 'giverandomname.dat'
        gui_params['n_channels_dat'] = params.nb_channels
        gui_params['n_features_per_channel'] = 5
        gui_params['dtype'] = data_file.data_dtype
        if 'data_offset' in list(data_file.params.keys()):
            gui_params['offset'] = data_file.data_offset
        gui_params['sample_rate'] = params.rate
        gui_params['dir_path'] = output_path
        gui_params['hp_filtered'] = True

        f = open(os.path.join(output_path, 'params.py'), 'w')
        for key, value in list(gui_params.items()):
            if key in ['dir_path', 'dtype']:
                f.write('%s = r"%s"\n' % (key, value))
            else:
                f.write("%s = %s\n" % (key, value))
        f.close()
Exemple #26
0
def delete_mixtures(params, nb_cpu, nb_gpu, use_gpu):
    """Remove templates that are explained by the sum of two other templates.

    For every template ``k`` handled by this MPI rank, pairs of templates
    ``(i, j)`` whose best electrode lies in the neighbourhood of ``k``'s best
    electrode are tested: a 2x2 linear system built from the template
    overlaps yields amplitudes ``(a1, a2)``.  If both amplitudes fall inside
    the amplitude limits of ``i`` and ``j``, the correlation between ``k`` and
    ``a1 * t_i + a2 * t_j`` exceeds ``cc_mixtures`` and ``i`` and ``j``
    themselves are correlated below ``cc_mixtures``, ``k`` is flagged as a
    mixture.  Flagged templates from all ranks are gathered, and rank 0
    slices them out of the templates and clusters.

    Parameters
    ----------
    params
        Parsed configuration object; read through ``get``/``getint``/
        ``getfloat`` and the ``nb_channels`` attribute.
    nb_cpu, nb_gpu, use_gpu
        Forwarded unchanged to ``get_overlaps`` and ``load_data_memshared``.

    Returns
    -------
    list
        ``[nb_temp, len(to_remove)]``: the number of templates considered and
        the number of templates removed.
    """

    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    cc_merge = params.getfloat('clustering', 'cc_mixtures')
    mixtures = []

    filename = params.get('data', 'file_out_suff') + '.overlap-mixtures.hdf5'
    norm_templates = load_data(params, 'norm-templates')
    best_elec = load_data(params, 'electrodes')
    limits = load_data(params, 'limits')
    nodes, edges = get_nodes_and_edges(params)
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)

    # Compute the '-mixtures' overlaps; the returned handle is not used
    # directly, the data is reloaded just below.
    overlap = get_overlaps(params,
                           extension='-mixtures',
                           erase=True,
                           normalize=False,
                           maxoverlap=False,
                           verbose=False,
                           half=True,
                           use_gpu=use_gpu,
                           nb_cpu=nb_cpu,
                           nb_gpu=nb_gpu)
    overlap.close()

    SHARED_MEMORY = get_shared_memory_flag(params)

    # Same load order as before (overlaps first, then templates).
    if SHARED_MEMORY:
        c_overs = load_data_memshared(params,
                                      'overlaps',
                                      extension='-mixtures',
                                      use_gpu=use_gpu,
                                      nb_cpu=nb_cpu,
                                      nb_gpu=nb_gpu)
        templates = load_data_memshared(params, 'templates', normalize=False)
    else:
        c_overs = load_data(params, 'overlaps', extension='-mixtures')
        templates = load_data(params, 'templates')

    # Only the first half of the template columns is considered here.
    nb_temp = int(templates.shape[1] // 2)

    overlap_0 = numpy.zeros(nb_temp, dtype=numpy.float32)
    distances = numpy.zeros((nb_temp, nb_temp), dtype=numpy.int32)

    # distances[i, j]: column index at which the overlap between templates i
    # and j is maximal (mirrored to keep the matrix symmetric);
    # overlap_0[i]: overlap of template i with itself at column N_t.
    # NOTE(review): the loop stops at nb_temp - 2, so overlap_0[nb_temp - 1]
    # keeps its initial 0 -- confirm this is intended.
    for i in range(nb_temp - 1):
        data = c_overs[i].toarray()
        distances[i, i + 1:] = numpy.argmax(data[i + 1:, :], 1)
        distances[i + 1:, i] = distances[i, i + 1:]
        overlap_0[i] = data[i, N_t]

    # Templates sorted by decreasing norm; each rank keeps every
    # comm.size-th entry starting at its own rank.
    sorted_temp = numpy.argsort(
        norm_templates[:nb_temp])[::-1][comm.rank::comm.size]
    M = numpy.zeros((2, 2), dtype=numpy.float32)
    V = numpy.zeros((2, 1), dtype=numpy.float32)

    # sorted_temp is already this rank's share, so walk all of it.  Striding
    # it a second time by (rank, size) would skip most templates as soon as
    # more than one process is used.
    to_explore = range(len(sorted_temp))
    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for position in to_explore:

        k = sorted_temp[position]
        electrodes = numpy.take(inv_nodes, edges[nodes[best_elec[k]]])
        overlap_k = c_overs[k]
        is_in_area = numpy.in1d(best_elec, electrodes)
        all_idx = numpy.arange(len(best_elec))[is_in_area]
        been_found = False
        t_k = None

        for i in all_idx:
            t_i = None
            overlap_i = c_overs[i]
            M[0, 0] = overlap_0[i]
            V[0, 0] = overlap_k[i, distances[k, i]]
            # NOTE(review): all_idx is sliced by *position* using the
            # template id i; this only enumerates "all candidates after i"
            # when all_idx equals arange(n) -- confirm the intent.
            for j in all_idx[i + 1:]:
                t_j = None
                M[1, 1] = overlap_0[j]
                M[1, 0] = overlap_i[j, distances[k, i] - distances[k, j]]
                M[0, 1] = M[1, 0]
                V[1, 0] = overlap_k[j, distances[k, j]]
                try:
                    [a1, a2] = numpy.dot(scipy.linalg.inv(M), V)
                except (numpy.linalg.LinAlgError, ValueError):
                    # Singular or non-finite system: no valid amplitudes.
                    [a1, a2] = [0, 0]
                a1_lim = limits[i]
                a2_lim = limits[j]
                is_a1 = (a1_lim[0] <= a1) and (a1 <= a1_lim[1])
                is_a2 = (a2_lim[0] <= a2) and (a2 <= a2_lim[1])
                if is_a1 and is_a2:
                    # Dense templates are materialised lazily, only once a
                    # pair passes the amplitude test.
                    if t_k is None:
                        t_k = templates[:, k].toarray().ravel()
                    if t_i is None:
                        t_i = templates[:, i].toarray().ravel()
                    if t_j is None:
                        t_j = templates[:, j].toarray().ravel()
                    new_template = (a1 * t_i + a2 * t_j)
                    similarity = numpy.corrcoef(t_k, new_template)[0, 1]
                    local_overlap = numpy.corrcoef(t_i, t_j)[0, 1]
                    if similarity > cc_merge and local_overlap < cc_merge:
                        # Each k is visited once per rank (argsort indices
                        # are unique), so it cannot already be in the list.
                        mixtures.append(k)
                        been_found = True
                        break
            if been_found:
                # One explaining pair is enough for template k.
                break
    sys.stderr.flush()
    to_remove = numpy.unique(numpy.array(mixtures, dtype=numpy.int32))
    to_remove = all_gather_array(to_remove, comm, 0, dtype='int32')

    # Only rank 0 rewrites the templates and clusters on disk.
    if len(to_remove) > 0 and comm.rank == 0:
        result = load_data(params, 'clusters')
        slice_templates(params, to_remove)
        slice_clusters(params, result, to_remove=to_remove)

    comm.Barrier()

    del c_overs

    # The temporary '-mixtures' overlap file is no longer needed.
    if comm.rank == 0:
        os.remove(filename)

    return [nb_temp, len(to_remove)]
# Select the spike times that the clustering assigned to the requested
# template (identified by its electrode and its local cluster label).
electrode = clusters_data['electrodes'][args.template_id]
local_cluster = clusters_data['local_clusters'][args.template_id]
assert electrode.shape == local_cluster.shape, (electrode.shape,
                                                local_cluster.shape)
times = clusters_data['times'][electrode]
clusters = clusters_data['clusters'][electrode]
assert times.shape == clusters.shape, (times.shape, clusters.shape)
selection = (clusters == local_cluster)
times = times[selection]
clusters = clusters[selection]
# Keep at most nb_snippets spikes, in increasing index order.
# NOTE(review): np.random.choice samples with replacement by default, so the
# same spike can be drawn more than once -- confirm this is intended.
if times.size > nb_snippets:
    indices = np.random.choice(times.size, size=nb_snippets)
    indices = np.sort(indices)
    times = times[indices]
    clusters = clusters[indices]
nodes, _ = get_nodes_and_edges(params)
# The builtin int replaces the np.int alias, which was removed in NumPy 1.24
# (same resulting dtype).
inv_nodes = np.zeros(nb_electrodes, dtype=int)
inv_nodes[nodes] = np.arange(len(nodes))
indices = inv_nodes[nodes]
snippets = get_stas(params, times, clusters, electrode, indices, nodes=nodes)
snippets = np.transpose(snippets, axes=(0, 2, 1))
# NOTE(review): this assertion expects exactly nb_snippets snippets; it looks
# like it would fail when fewer spikes than nb_snippets were selected above --
# confirm against get_stas.
assert snippets.shape == (nb_snippets, nb_time_steps, nb_channels), \
    (snippets.shape, nb_snippets, nb_time_steps, nb_channels)

# Load template.
template = load_template(args.template_id, params, extension='')
assert template.shape == (nb_time_steps, nb_channels)

# Compute the scalar products.
# Flatten snippets and template so each snippet is one row vector and the
# template one column vector.
snippets_ = np.reshape(snippets, (nb_snippets, nb_channels * nb_time_steps))
template_ = np.reshape(template, (nb_channels * nb_time_steps, 1))