Example #1
def main(params, nb_cpu, nb_gpu, use_gpu):
    # Part 1: Whitening
    numpy.random.seed(420)
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    spike_width = params.getfloat('detection', 'spike_width')
    matched_filter = params.getboolean('detection', 'matched-filter')
    matched_thresh = params.getfloat('detection', 'matched_thresh')
    fudge = params.getfloat('whitening', 'fudge')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    ignore_spikes = params.getboolean('whitening', 'ignore_spikes')
    chunk_size = detect_memory(params, whitening=True)
    plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    safety_space = params.getboolean('whitening', 'safety_space')
    sort_waveforms = params.getboolean('whitening', 'sort_waveforms')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_1 = int(20 * params.rate // comm.size)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    jitter_range = params.getint('detection', 'jitter_range')
    template_shift_2 = template_shift + jitter_range
    use_hanning = params.getboolean('detection', 'hanning')
    rejection_threshold = params.getfloat('detection', 'rejection_threshold')
    noise_window = params.getint('detection', 'noise_time')
    data_file.open()
    #################################################################

    if use_hanning:
        hanning_filter = numpy.hanning(N_t)

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    nodes_indices = {}
    for elec in numpy.arange(N_e):
        nodes_indices[elec] = inv_nodes[edges[nodes[elec]]]

    if use_gpu:
        import cudamat as cmt
        # Need to properly handle multiple GPUs per MPI node?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # Shuffle the chunks so that the signal is sampled from all over the recording.
    if nb_chunks > comm.size:
        all_chunks = numpy.random.permutation(
            numpy.arange(nb_chunks - 1, dtype=numpy.int32))
    else:
        all_chunks = numpy.random.permutation(
            numpy.arange(nb_chunks, dtype=numpy.int32))

    all_electrodes = numpy.random.permutation(N_e)

    numpy.random.seed(comm.rank)

    for gidx in [all_chunks[comm.rank]]:

        # print "Node", comm.rank, "is analyzing chunk", gidx,  "/", nb_chunks, " ..."
        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   nodes=nodes)
        local_shape = len(local_chunk)

        # print "Node", comm.rank, "computes the median absolute deviations in a random chunk"
        thresholds = numpy.zeros(N_e, dtype=numpy.float32)
        for i in range(N_e):
            u = numpy.median(local_chunk[:, i], 0)
            thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0)
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            bfile = h5py.File(file_out_suff + '.basis.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(bfile, ['thresholds'],
                              {'thresholds': thresholds},
                              compression=hdf5_compress)
            bfile.close()
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        local_borders = (template_shift, local_shape - template_shift)
        found_peaktimes = []

        if ignore_spikes:
            # Extracting the peaks.
            local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]
            for i in range(N_e):
                peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]),
                                                    height=thresholds[i],
                                                    width=spike_width,
                                                    wlen=N_t)[0]
                peaktimes = peaktimes.astype(numpy.uint32)

                # print "Removing the useless borders..."
                idx = (peaktimes >= local_borders[0]) & (peaktimes <
                                                         local_borders[1])
                peaktimes = numpy.compress(idx, peaktimes)

                found_peaktimes.append(peaktimes)
        else:
            for i in range(N_e):
                found_peaktimes.append(numpy.zeros(0, dtype=numpy.uint32))

        all_peaktimes = numpy.concatenate(found_peaktimes)
        local_peaktimes = numpy.unique(all_peaktimes)
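        # Mark samples around every detected peak in a boolean mask (all_times),
        # so that only spike-free periods feed the whitening estimates below.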

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=bool)
            padded_peaks = (local_peaktimes - local_peaktimes[0]).astype(
                numpy.int32)
            min_times = numpy.maximum(padded_peaks - safety_time, 0)
            max_times = numpy.minimum(padded_peaks + safety_time + 1,
                                      diff_times + 1)

            test_extremas = numpy.zeros((N_e, diff_times + 1), dtype=bool)
            for i in range(N_e):
                test_extremas[i, found_peaktimes[i] - local_peaktimes[0]] = True

            argmax_peak = numpy.random.permutation(
                numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            # print "Selection of the peaks with spatio-temporal masks..."
            for idx, peak in zip(argmax_peak, all_idx):

                all_elecs = numpy.where(test_extremas[:, peak - local_peaktimes[0]])[0]
                data = local_chunk[peak, all_elecs]
                elec = all_elecs[numpy.argmax(numpy.abs(data))]
                indices = nodes_indices[elec]
                if safety_space:
                    all_times[indices, min_times[idx]:max_times[idx]] = True
                else:
                    all_times[elec, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=bool)

    if do_temporal_whitening:

        local_res_temp = []
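        # Accumulate covariance matrices of N_t-long spike-free stretches on a
        # subset of electrodes; they are summed across ranks to build the filter.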

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white,
                                                comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = nodes_indices[elec]
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(~all_times_elec)[0]
            bound = len(esubset) - N_t
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res_temp += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res_temp)], dtype=numpy.float32)
        local_res_temp = numpy.array(local_res_temp, dtype=numpy.float32)
        if len(local_res_temp) == 0:
            local_res_temp = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res_temp = numpy.sum(local_res_temp, 0)
        all_res_temp = gather_array(local_res_temp.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if do_spatial_whitening:

        local_res_spac = numpy.zeros((N_e, N_e), dtype=numpy.float32)
        local_silences = []
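        # For each electrode handled by this rank, estimate a whitening matrix
        # from spike-free samples on its neighbouring channels.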

        for elec in numpy.arange(comm.rank, N_e, comm.size):
            indices = nodes_indices[elec]
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(~all_times_elec)[0]
            local_data = local_chunk[esubset][:, indices]
            local_whitening = get_whitening_matrix(
                local_data, fudge=fudge).astype(numpy.float32)
            pos = numpy.where(elec == indices)[0]
            local_res_spac[elec, indices] = local_whitening[pos]
            local_silences += [len(esubset)]

        all_res_spac = gather_array(local_res_spac.ravel(), comm, 0, 1)
        all_silences = gather_array(
            numpy.array(local_silences, dtype=numpy.int32), comm, 0, 1,
            'uint32')

    if comm.rank == 0:

        to_write = {}
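        # Rank 0 merges the per-rank estimates and writes the resulting whitening
        # matrices to the .basis.hdf5 file, falling back to an identity/delta
        # filter if NaNs are found.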

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res_temp = all_res_temp.reshape((nb_silences, N_t**2))
            except Exception:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res_temp = numpy.sum(all_res_temp, 0)
            all_res_temp = all_res_temp.reshape(
                (N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res_temp.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            all_res_spac = all_res_spac.reshape(comm.size, N_e, N_e)
            spatial_whitening = numpy.sum(all_res_spac, 0)
            to_write['spatial'] = spatial_whitening

            if ignore_spikes:
                print_and_log([
                    "Found %gs without spikes to compute the whitening matrix..."
                    % (numpy.mean(all_silences) / params.rate)
                ], 'default', logger)
            else:
                print_and_log([
                    "Found %gs to compute the whitening matrix..." %
                    (numpy.mean(all_silences) / params.rate)
                ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          list(to_write.keys()),
                          to_write,
                          compression=hdf5_compress)
        bfile.close()

    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')
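
        # Re-estimate the per-channel MAD thresholds on the whitened data and
        # replace the values previously stored in the .basis.hdf5 file.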

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = numpy.zeros(N_e, dtype=numpy.float32)
            for i in range(N_e):
                u = numpy.median(local_chunk[:, i], 0)
                thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u),
                                             0)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                  'r+',
                                  libver='earliest')
                bfile.pop('thresholds')
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds},
                                  compression=hdf5_compress)
                bfile.close()
            comm.Barrier()

    # if comm.rank == 0:
    #     if not os.path.exists(plot_path):
    #         os.makedirs(plot_path)
    #     N_elec = min(int(numpy.sqrt(data_file.N_e)), 5)
    #     plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True,
    #                   n_elec=N_elec, save=[plot_path, 'electrodes'])

    # Part 2: Basis
    numpy.random.seed(422)

    SHARED_MEMORY = get_shared_memory_flag(params)
    #################################################################
    file_out = params.get('data', 'file_out')
    alignment = params.getboolean('detection', 'alignment')
    over_factor = params.getint('detection', 'oversampling_factor')
    nb_jitter = params.getint('detection', 'nb_jitter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    nodes, edges = get_nodes_and_edges(params)
    _, positions = get_nodes_and_positions(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    use_barycenter = params.getboolean('detection', 'use_barycenter')
    if matched_filter:
        chunk_size = detect_memory(params, whitening=True)
    else:
        chunk_size = detect_memory(params)
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    smoothing_factor = params.getfloat('detection', 'smoothing_factor')
    if sign_peaks == 'both':
        max_elts_elec *= 2
    nb_elts = int(
        params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)

    weird_thresh = params.get('detection', 'weird_thresh')
    if weird_thresh != '':
        ignore_artefacts = True
        weird_thresh = io.load_data(params, 'weird-thresholds')
    else:
        ignore_artefacts = False

    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    if ignore_dead_times:
        if SHARED_MEMORY:
            all_dead_times, mpi_memory_3 = get_dead_times(params)
        else:
            all_dead_times = get_dead_times(params)
    data_file.open()
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    groups = {}
    for i in range(N_e):
        groups[i] = 0

    # Shuffle the chunks so that the signal is sampled from all over the recording.
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        times_pos = numpy.zeros(nb_elts, dtype=numpy.int32)
        electrodes_pos = numpy.zeros(nb_elts, dtype=numpy.int32)
        extremum_pos = numpy.zeros(nb_elts, dtype=numpy.float32)
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        times_neg = numpy.zeros(nb_elts, dtype=numpy.int32)
        electrodes_neg = numpy.zeros(nb_elts, dtype=numpy.int32)
        extremum_neg = numpy.zeros(nb_elts, dtype=numpy.float32)
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
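
    # These buffers are sized for the worst case (nb_elts waveforms per rank);
    # only the first elt_count_pos / elt_count_neg entries are gathered later.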

    thresholds = io.load_data(params, 'thresholds')
    mads = io.load_data(params, 'mads')
    stds = io.load_data(params, 'stds')

    if alignment:
        cdata = numpy.linspace(-jitter_range, +jitter_range, nb_jitter)
        xdata = numpy.arange(-template_shift_2, template_shift_2 + 1)
        xoff = len(cdata) / 2.0
        snippet_duration = template_shift_2
        m_size = 2 * template_shift_2 + 1
        align_factor = m_size
        local_factors = align_factor * ((smoothing_factor * mads)**2)
    else:
        snippet_duration = template_shift
        xdata = numpy.arange(-template_shift, template_shift + 1)

    if rejection_threshold > 0:
        reject_noise = True
        noise_levels = stds * (2 * noise_window + 1)
    else:
        reject_noise = False

    to_explore = all_chunks[comm.rank::comm.size]

    upper_bounds = max_elts_elec

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for gcount, gidx in enumerate(to_explore):

        if (elt_count_pos + elt_count_neg) < nb_elts:
            # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_borders = (snippet_duration, local_shape - snippet_duration)

            if ignore_dead_times:
                dead_indices = numpy.searchsorted(
                    all_dead_times, [t_offset, t_offset + local_shape])

            # Extracting the peaks.
            all_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

            found_peaktimes = []
            found_peak_amplitudes = []
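            # Detect peaks channel by channel with the requested polarity; artefact
            # peaks, chunk borders and dead times are then pruned from the candidates.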
            for i in range(N_e):
                height = thresholds[i]
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i],
                                                        height=height,
                                                        distance=dist_peaks)[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i],
                                                        height=height,
                                                        distance=dist_peaks)[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(
                        local_chunk[:, i]),
                                                        height=height,
                                                        distance=dist_peaks)[0]
                else:
                    peaktimes = numpy.empty(0, dtype=numpy.uint32)

                if ignore_artefacts:
                    artetimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]),
                                                        height=weird_thresh[i])[0]
                    to_keep = numpy.logical_not(
                        numpy.in1d(peaktimes, artetimes))
                    peaktimes = peaktimes[to_keep]

                idx = (peaktimes >= local_borders[0]) & (peaktimes <
                                                         local_borders[1])
                peaktimes = peaktimes[idx]

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            peaktimes + t_offset,
                            all_dead_times[dead_indices[0]:dead_indices[1]])
                        peaktimes = peaktimes[~is_included]

                peaktimes = peaktimes.astype(numpy.uint32)
                found_peaktimes.append(peaktimes)

                peak_amplitudes = local_chunk[peaktimes, i]
                found_peak_amplitudes.append(peak_amplitudes)

            all_peaktimes = numpy.concatenate(
                found_peaktimes)  # i.e. concatenate once for efficiency
            all_peak_amplitudes = numpy.concatenate(found_peak_amplitudes)
            local_peaktimes, local_indices = numpy.unique(all_peaktimes,
                                                          return_inverse=True)

            if len(local_peaktimes) > 0:

                diff_times = (local_peaktimes[-1] - local_peaktimes[0]) + 1
                all_times = numpy.zeros((N_e, diff_times), dtype=bool)

                padded_peaks = (local_peaktimes - local_peaktimes[0]).astype(
                    numpy.int32)
                min_times = numpy.maximum(padded_peaks - safety_time, 0)
                max_times = numpy.minimum(padded_peaks + safety_time + 1,
                                          diff_times + 1)
                test_extremas = numpy.zeros((N_e, diff_times + 1), dtype=bool)
                for i in range(N_e):
                    test_extremas[i, found_peaktimes[i] - local_peaktimes[0]] = True

                # Consider the peaks by decreasing extremum.
                if sort_waveforms:
                    order = numpy.argsort(-numpy.abs(all_peak_amplitudes))
                    all_idx = numpy.take(all_peaktimes, order)
                    argmax_peak = local_indices[order]
                else:
                    n_times = len(all_peaktimes)
                    shuffling = numpy.random.permutation(numpy.arange(n_times))
                    all_idx = numpy.take(all_peaktimes, shuffling)
                    argmax_peak = local_indices[shuffling]

                # print "Selection of the peaks with spatio-temporal masks..."
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    all_elecs = numpy.where(
                        test_extremas[:, peak - local_peaktimes[0]])[0]
                    data = local_chunk[peak, all_elecs]

                    #target_area = test_extremas[:, min_times[midx]:max_times[midx]].sum(1)
                    #all_elecs = numpy.where(target_area)[0]
                    #data = local_chunk[peak, all_elecs]

                    if sign_peaks == 'negative':
                        if N_e > 1:
                            if use_barycenter:
                                weighed_position = data[:, numpy.newaxis] * positions[all_elecs]
                                barycenter = weighed_position.sum(0) / data.sum()
                                elec = numpy.argmin(numpy.linalg.norm(
                                    barycenter - positions[all_elecs], axis=1))
                            else:
                                elec = numpy.argmin(data)
                        else:
                            elec = 0
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        if N_e > 1:
                            if use_barycenter:
                                weighed_position = data[:, numpy.newaxis] * positions[all_elecs]
                                barycenter = weighed_position.sum(0) / data.sum()
                                elec = numpy.argmax(numpy.linalg.norm(
                                    barycenter - positions[all_elecs], axis=1))
                            else:
                                elec = numpy.argmax(data)
                        else:
                            elec = 0
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if N_e == 1:
                            if data < 0:
                                negative_peak = True
                            elif data > 0:
                                negative_peak = False
                            elec = 0
                        else:
                            if numpy.abs(numpy.max(data)) > numpy.abs(
                                    numpy.min(data)):
                                elec = numpy.argmax(data)
                                negative_peak = False
                            else:
                                elec = numpy.argmin(data)
                                negative_peak = True

                    elec = all_elecs[elec]

                    if groups[elec] < upper_bounds:

                        indices = nodes_indices[elec]
                        myslice = all_times[indices,
                                            min_times[midx]:max_times[midx]]

                        if not myslice.any():
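                            # No previously selected waveform overlaps this peak on its
                            # neighbouring channels: extract the snippet, optionally
                            # reject it as noise, align it, and keep it for the PCA basis.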

                            sub_mat = local_chunk[
                                peak - snippet_duration:peak + snippet_duration + 1, elec]

                            if reject_noise:
                                slice_window = sub_mat[snippet_duration - noise_window:
                                                       snippet_duration + noise_window + 1]
                                value = numpy.linalg.norm(slice_window) / noise_levels[elec]
                                is_noise = value < rejection_threshold
                            else:
                                is_noise = False

                            if not is_noise:

                                extrema = sub_mat[snippet_duration]

                                if alignment:
                                    smoothed = True
                                    try:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata,
                                            sub_mat,
                                            s=local_factors[elec],
                                            k=3)
                                    except Exception:
                                        smoothed = False
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata, sub_mat, k=3, s=0)

                                    if negative_peak:
                                        rmin = (numpy.argmin(f(cdata)) -
                                                xoff) / over_factor
                                    else:
                                        rmin = (numpy.argmax(f(cdata)) -
                                                xoff) / over_factor
                                    ddata = numpy.linspace(
                                        rmin - template_shift,
                                        rmin + template_shift, N_t)

                                    if smoothed:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata,
                                            sub_mat,
                                            s=local_factors[elec],
                                            k=3)
                                    else:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata, sub_mat, s=0, k=3)

                                    sub_mat = f(ddata).astype(numpy.float32)

                                if negative_peak:
                                    times_neg[elt_count_neg] = peak + t_offset
                                    electrodes_neg[elt_count_neg] = elec
                                    extremum_neg[elt_count_neg] = extrema
                                    elts_neg[:, elt_count_neg] = sub_mat
                                    elt_count_neg += 1
                                else:
                                    times_pos[elt_count_pos] = peak + t_offset
                                    electrodes_pos[elt_count_pos] = elec
                                    extremum_pos[elt_count_pos] = extrema
                                    elts_pos[:, elt_count_pos] = sub_mat
                                    elt_count_pos += 1

                                groups[elec] += 1
                                all_times[indices, min_times[midx]:max_times[midx]] = True
                                test_extremas[elec, peak - local_peaktimes[0]] = False

    sys.stderr.flush()

    print_and_log([
        "Node %d has collected %d waveforms" %
        (comm.rank, elt_count_pos + elt_count_neg)
    ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        times_neg = gather_array(times_neg[:elt_count_neg],
                                 comm,
                                 0,
                                 1,
                                 dtype='int32')
        electrodes_neg = gather_array(electrodes_neg[:elt_count_neg],
                                      comm,
                                      0,
                                      1,
                                      dtype='int32')
        extremum_neg = gather_array(extremum_neg[:elt_count_neg], comm, 0, 1)
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        times_pos = gather_array(times_pos[:elt_count_pos],
                                 comm,
                                 0,
                                 1,
                                 dtype='int32')
        electrodes_pos = gather_array(electrodes_pos[:elt_count_pos],
                                      comm,
                                      0,
                                      1,
                                      dtype='int32')
        extremum_pos = gather_array(extremum_pos[:elt_count_pos], comm, 0, 1)
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    nb_waveforms = 0

    if comm.rank == 0:
        # DO PCA on elts and store the basis obtained.

        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

    nb_waveforms = all_gather_array(
        numpy.array([nb_waveforms], dtype=numpy.float32), comm, 0)[0]

    if comm.rank == 0:
        print_and_log([
            "Found %d waveforms over %d requested" %
            (nb_waveforms, int(nb_elts * comm.size))
        ], 'default', logger)

        if nb_waveforms == 0:
            print_and_log(
                ['No waveforms found! Are the data properly loaded?'],
                'error', logger)

    if nb_waveforms == 0:
        sys.exit(0)

    if comm.rank == 0:
        res = {}
        pca = None
        pca_pos = None
        pca_neg = None
        warning_n_t = False
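        # Fit a PCA basis on the collected (optionally Hanning-windowed) waveforms;
        # if none were collected for a polarity, fall back to an identity projection.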
        if sign_peaks in ['negative', 'both']:
            res['times'] = times_neg
            res['electrodes'] = electrodes_neg
            res['extremum'] = extremum_neg
            if len(gdata_neg) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_neg * hanning_filter)
                else:
                    pca.fit(gdata_neg)
                res['proj'] = pca.components_.T.astype(numpy.float32)
                pca_neg = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj'] = numpy.identity(int(output_dim),
                                             dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            # dispersion = numpy.std(gdata_neg, 0) / numpy.median(stds)
            # ratio = numpy.sum(dispersion > 1.1) / float(len(dispersion))
            # if ratio < 0.25:
            #     print_and_log(["Time window N_t in [detection] seems too large!"], 'info', logger)
            #     warning_n_t = True
            # elif ratio == 1:
            #     print_and_log(["Time window N_t in [detection] seems too small!"], 'info', logger)
            #     warning_n_t = True
            idx = numpy.random.permutation(numpy.arange(
                gdata_neg.shape[0]))[:2500]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            res['times_pos'] = times_pos
            res['electrodes_pos'] = electrodes_pos
            res['extremum_pos'] = extremum_pos
            if len(gdata_pos) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_pos * hanning_filter)
                else:
                    pca.fit(gdata_pos)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
                pca_pos = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj_pos'] = numpy.identity(int(output_dim),
                                                 dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            # dispersion = numpy.std(gdata_pos, 0) / numpy.median(stds)
            # ratio = numpy.sum(dispersion > 1.1) / float(len(dispersion))
            # if ratio < 0.25 and not warning_n_t:
            #     print_and_log(["Time window N_t in [detection] seems too large!"], 'info', logger)
            # elif ratio == 1 and not warning_n_t:
            #     print_and_log(["Time window N_t in [detection] seems too small!"], 'info', logger)
            idx = numpy.random.permutation(numpy.arange(
                gdata_pos.shape[0]))[:2500]
            res['waveforms_pos'] = gdata_pos[idx, :]

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          list(res.keys()),
                          res,
                          compression=hdf5_compress)
        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'debug', logger)
        if pca_pos is not None:
            print_and_log([
                "The percentage of variance explained is %s for positive spikes"
                % pca_pos
            ], 'debug', logger)
        if pca_neg is not None:
            print_and_log([
                "The percentage of variance explained is %s for negative spikes"
                % pca_neg
            ], 'debug', logger)

        bfile.close()

    comm.Barrier()

    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')
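
        # Matched-filter thresholds: the whitened, threshold-normalised data is
        # convolved with the reversed, normalised mean waveforms, and new MAD-based
        # thresholds are stored as matched_thresholds(_pos).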

        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_chunk /= thresholds

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.convolve1d(local_chunk,
                                                     waveform_neg,
                                                     axis=0,
                                                     mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in range(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds'],
                                      {'matched_thresholds': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.convolve1d(local_chunk,
                                                     waveform_pos,
                                                     axis=0,
                                                     mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in range(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds_pos'],
                                      {'matched_thresholds_pos': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

    data_file.close()

    if SHARED_MEMORY and ignore_dead_times:
        mpi_memory_3.Free()
Example #2
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    parallel_hdf5 = h5py.get_config().mpi
    user_path = pjoin(os.path.expanduser('~'), 'spyking-circus')
    tasks_list = None

    if not os.path.exists(user_path):
        os.makedirs(user_path)

    try:
        import cudamat as cmt
        cmt.init()
        HAVE_CUDA = True
    except Exception:
        HAVE_CUDA = False

    all_steps = [
        'whitening', 'clustering', 'fitting', 'gathering', 'extracting',
        'filtering', 'converting', 'deconverting', 'benchmarking',
        'merging', 'validating', 'thresholding'
    ]

    config_file = os.path.abspath(pkg_resources.resource_filename('circus', 'config.params'))

    header = get_colored_header()
    header += Fore.GREEN + 'Local CPUs    : ' + Fore.CYAN + str(psutil.cpu_count()) + '\n'
    # header += Fore.GREEN + 'GPU detected  : ' + Fore.CYAN + str(HAVE_CUDA) + '\n'
    header += Fore.GREEN + 'Parallel HDF5 : ' + Fore.CYAN + str(parallel_hdf5) + '\n'

    do_upgrade = ''
    if not SHARED_MEMORY:
        do_upgrade = Fore.WHITE + '   [please consider upgrading MPI]'

    header += Fore.GREEN + 'Shared memory : ' + Fore.CYAN + str(SHARED_MEMORY) + do_upgrade + '\n'
    header += '\n'
    header += Fore.GREEN + "##################################################################"
    header += Fore.RESET

    method_help = '''by default, all steps are performed,
but a subset x,y can be done. Steps are:
 - filtering
 - whitening
 - clustering
 - fitting
 - merging [with or without a GUI for meta merging]
 - (extra) converting [export results to phy format]
 - (extra) thresholding [to get MUA activity only]
 - (extra) deconverting [import results from phy format]
 - (extra) gathering [force collection of results]
 - (extra) extracting [get templates from spike times]
 - (extra) benchmarking [with -o and -t]
 - (extra) validating [to compare performance with GT neurons]'''

    parser = argparse.ArgumentParser(description=header,
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file (or a list of commands if batch mode)')
    parser.add_argument('-i', '--info', help='list the file formats supported by SpyKING CIRCUS', action='store_true')
    parser.add_argument('-m', '--method',
                        default='filtering,whitening,clustering,fitting,merging',
                        help=method_help)
    parser.add_argument('-c', '--cpu', type=int, default=max(1, int(psutil.cpu_count()/2)), help='number of CPU')
    # parser.add_argument('-g', '--gpu', type=int, default=0, help='number of GPU')
    parser.add_argument('-H', '--hostfile', help='hostfile for MPI',
                        default=pjoin(user_path, 'circus.hosts'))
    parser.add_argument('-b', '--batch', help='datafile is a list of commands to launch, in a batch mode',
                        action='store_true')
    parser.add_argument('-p', '--preview', help='GUI to display the first second filtered with thresholds',
                        action='store_true')
    parser.add_argument('-r', '--result', help='GUI to display the results on top of raw data',
                        action='store_true')
    parser.add_argument('-s', '--second', type=int, default=0, help='If preview mode, beginning of the preview [in s]')
    parser.add_argument('-e', '--extension', help='extension to consider for merging, converting and deconverting',
                        default='None')
    parser.add_argument('-o', '--output', help='output file [for generation of synthetic benchmarks]')
    parser.add_argument('-t', '--type', help='benchmark type',
                        choices=['fitting', 'clustering', 'synchrony'])

    if len(argv) == 0:
        parser.print_help()
        sys.exit(0)

    args = parser.parse_args(argv)

    steps = args.method.split(',')
    for step in steps:
        if step not in all_steps:
            print_error(['The method "%s" is not recognized' % step])
            sys.exit(0)

    # To save some typing later
    nb_gpu = 0
    (nb_cpu, hostfile, batch, preview, result, extension, output, benchmark, info, second) = \
        (args.cpu, args.hostfile, args.batch, args.preview, args.result, args.extension, args.output, args.type, args.info, args.second)
    filename = os.path.abspath(args.datafile)
    real_file = filename

    f_next, extens = os.path.splitext(filename)

    if info:
        if args.datafile.lower() in __supported_data_files__:
            filename = 'tmp'
            if len(__supported_data_files__[args.datafile.lower()].extension) > 0:
                filename += __supported_data_files__[args.datafile.lower()].extension[0]

            __supported_data_files__[args.datafile.lower()](filename, {}, is_empty=True)._display_requirements_()
        else:
            print_and_log([
                '',
                'To get info on any particular file format, do:',
                '>> spyking-circus file_format -i',
                ''
            ], 'default')
            print_and_log(list_all_file_format())
        sys.exit(0)

    if extens == '.params':
        print_error(['You should launch the code on the data file!'])
        sys.exit(0)

    file_params = f_next + '.params'
    if not os.path.exists(file_params) and not batch:
        print(Fore.RED + 'The parameter file %s is not present!' % file_params)
        create_params = query_yes_no(Fore.WHITE + "Do you want SpyKING CIRCUS to create a parameter file?")

        if create_params:
            print(Fore.WHITE + "Creating %s" % file_params)
            print(Fore.WHITE + "Fill it properly before launching the code! (see documentation)")
            print_info(['Keep in mind that filtering is performed in place, so please',
                        'be sure to keep a copy of your data elsewhere'])
            shutil.copyfile(config_file, file_params)
        sys.exit(0)
    elif batch:
        tasks_list = filename

    if not batch:
        file_params = f_next + '.params'

        if not os.path.exists(file_params):
            print_and_log(["%s does not exist" % file_params], 'error')
            sys.exit(0)

        import configparser
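
        # Sanitise the parameter file: strip tabs from the file itself, then drop
        # inline '#' comments and add any missing sections in the parsed config.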
        parser = configparser.ConfigParser()
        myfile = open(file_params, 'r')
        lines = myfile.readlines()
        myfile.close()
        myfile = open(file_params, 'w')
        for l in lines:
            myfile.write(l.replace('\t', ''))
        myfile.close()

        parser.read(file_params)

        for section in CircusParser.__all_sections__:
            if parser.has_section(section):
                for (key, value) in parser.items(section):
                    parser.set(section, key, value.split('#')[0].rstrip())
            else:
                parser.add_section(section)

        try:
            use_output_dir = parser.get('data', 'output_dir') != ''
        except Exception:
            use_output_dir = False

        if use_output_dir:
            path = os.path.abspath(os.path.expanduser(parser.get('data', 'output_dir')))
            file_out = os.path.join(path, os.path.basename(f_next))
            if not os.path.exists(file_out):
                os.makedirs(file_out)
        else:
            file_out = f_next


        logfile = file_out + '.log'
        if os.path.exists(logfile):
            os.remove(logfile)

        logger = init_logging(logfile)
        params = CircusParser(filename)
        data_file = params.get_data_file(source=True, has_been_created=False)
        overwrite = params.getboolean('data', 'overwrite')
        file_format = params.get('data', 'file_format')
        if overwrite:
            support_parallel_write = data_file.parallel_write
            is_writable = data_file.is_writable
        else:
            support_parallel_write = __supported_data_files__['raw_binary'].parallel_write
            is_writable = __supported_data_files__['raw_binary'].is_writable

    if preview:
        print_and_log(['Preview mode, showing only seconds [%d-%d] of the recording' % (second, second+1)], 'info', logger)
        tmp_path_loc = os.path.join(os.path.abspath(params.get('data', 'file_out')), 'tmp')

        if not os.path.exists(tmp_path_loc):
            os.makedirs(tmp_path_loc)

        filename = os.path.join(tmp_path_loc, 'preview.dat')
        f_next, extens = os.path.splitext(filename)
        preview_params = f_next + '.params'
        shutil.copyfile(file_params, preview_params)
        steps = ['filtering', 'whitening']
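
        # Preview mode: dump ~1.2 s of data to a temporary raw_binary file with its
        # own parameter file, and run only the filtering and whitening steps on it.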

        chunk_size = int(params.rate)

        data_file.open()
        nb_chunks, _ = data_file.analyze(chunk_size)

        if nb_chunks <= (second + 1):
            print_and_log(['Recording is too short to display seconds [%d-%d]' % (second, second+1)])
            sys.exit(0)
        local_chunk = data_file.get_snippet(int(second*params.rate), int(1.2*chunk_size))
        description = data_file.get_description()
        data_file.close()

        new_params = CircusParser(filename, create_folders=False)

        new_params.write('data', 'chunk_size', '1')
        new_params.write('data', 'file_format', 'raw_binary')
        new_params.write('data', 'data_dtype', 'float32')
        new_params.write('data', 'data_offset', '0')
        new_params.write('data', 'dtype_offset', '0')
        new_params.write('data', 'stream_mode', 'None')
        new_params.write('data', 'overwrite', 'True')
        new_params.write('triggers', 'ignore_times', 'False')
        new_params.write('data', 'sampling_rate', str(params.rate))
        new_params.write('whitening', 'safety_time', '0')
        new_params.write('clustering', 'safety_time', '0')
        new_params.write('whitening', 'chunk_size', '1')
        new_params.write('data', 'preview_path', params.file_params)
        new_params.write('data', 'output_dir', '')

        description['data_dtype'] = 'float32'
        description['dtype_offset'] = 0
        description['data_offset'] = 0
        description['gain'] = 1.
        new_params = CircusParser(filename)
        data_file_out = new_params.get_data_file(is_empty=True, params=description)

        support_parallel_write = data_file_out.parallel_write
        is_writable = data_file_out.is_writable

        data_file_out.allocate(shape=local_chunk.shape, data_dtype=numpy.float32)
        data_file_out.open('r+')
        data_file_out.set_data(0, local_chunk)
        data_file_out.close()

    if tasks_list is not None:
        with open(tasks_list, 'r') as f:
            for line in f:
                if len(line) > 0:
                    subprocess.check_call(['spyking-circus'] + line.replace('\n', '').split(" "))
    else:

        print_and_log(['Config file: %s' % (f_next + '.params')], 'debug', logger)
        print_and_log(['Data file  : %s' % filename], 'debug', logger)

        print(get_colored_header())
        print(Fore.GREEN + "File          : " + Fore.CYAN + real_file)
        if preview:
            print(Fore.GREEN + "Steps         : " + Fore.CYAN + "preview mode")
        elif result:
            print(Fore.GREEN + "Steps         : " + Fore.CYAN + "result mode")
        else:
            print(Fore.GREEN + "Steps         : " + Fore.CYAN + ", ".join(steps))
        # print Fore.GREEN + "GPU detected  : ", Fore.CYAN + str(HAVE_CUDA)
        print(Fore.GREEN + "Number of CPU : " + Fore.CYAN + str(nb_cpu) + "/" + str(psutil.cpu_count()))
        # if HAVE_CUDA:
        #     print Fore.GREEN + "Number of GPU : ", Fore.CYAN + str(nb_gpu)
        print(Fore.GREEN + "Parallel HDF5 : " + Fore.CYAN + str(parallel_hdf5))

        do_upgrade = ''
        use_shared_memory = get_shared_memory_flag(params)
        if not SHARED_MEMORY:
            do_upgrade = Fore.WHITE + '   [please consider upgrading MPI]'

        print(Fore.GREEN + "Shared memory : " + Fore.CYAN + str(use_shared_memory) + do_upgrade)
        print(Fore.GREEN + "Hostfile      : " + Fore.CYAN + hostfile)
        print("")
        print(Fore.GREEN + "##################################################################")
        print("")
        print(Fore.RESET)

        # Launch the subtasks
        subtasks = [('filtering', 'mpirun'),
                    ('whitening', 'mpirun'),
                    ('clustering', 'mpirun'),
                    ('fitting', 'mpirun'),
                    ('extracting', 'mpirun'),
                    ('gathering', 'python'),
                    ('converting', 'mpirun'),
                    ('deconverting', 'mpirun'),
                    ('benchmarking', 'mpirun'),
                    ('merging', 'mpirun'),
                    ('validating', 'mpirun'),
                    ('thresholding', 'mpirun')]
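
        # Each requested step is dispatched either through mpirun (parallel steps)
        # or directly in this process via circus.launch (the 'python' entries).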

        # if HAVE_CUDA and nb_gpu > 0:
        #     use_gpu = 'True'
        # else:
        use_gpu = 'False'

        time = data_stats(params) / 60.0

        if preview:
            params = new_params

        if nb_cpu < psutil.cpu_count():
            if use_gpu != 'True' and not result:
                print_and_log(['Using only %d out of %d local CPUs available (-c to change)' % (nb_cpu, psutil.cpu_count())], 'info', logger)

        if params.getboolean('detection', 'matched-filter') and not params.getboolean('clustering', 'smart_search'):
            print_and_log(['Smart Search should be activated for matched filtering'], 'info', logger)

        if time > 30 and not params.getboolean('clustering', 'smart_search'):
            print_and_log(['Smart Search should be activated for long recordings'], 'info', logger)

        n_edges = get_averaged_n_edges(params)
        if n_edges > 100 and not params.getboolean('clustering', 'compress'):
            print_and_log(['Template compression is highly recommended based on parameters'], 'info', logger)

        if not result:
            for subtask, command in subtasks:
                if subtask in steps:
                    if command == 'python':
                        # Directly call the launcher
                        try:
                            circus.launch(subtask, filename, nb_cpu, nb_gpu, use_gpu)
                        except Exception:
                            print_and_log(['Step "%s" failed!' % subtask], 'error', logger)
                            sys.exit(0)
                    elif command == 'mpirun':
                        # Use mpirun to make the call
                        mpi_args = gather_mpi_arguments(hostfile, params)
                        one_cpu = False

                        if subtask in ['filtering', 'benchmarking'] and not is_writable:
                            if not preview and overwrite:
                                print_and_log(['The file format %s is read only!' % file_format,
                                               'You should set overwrite to False to create a copy of the data.',
                                               'However, note that if you have streams, information on times',
                                               'will be discarded'], 'info', logger)
                                sys.exit(0)

                        if subtask in ['filtering'] and not support_parallel_write and (args.cpu > 1):
                            print_and_log(['No parallel writes for %s: only 1 node used for %s' %(file_format, subtask)], 'info', logger)
                            nb_tasks = str(1)
                            one_cpu = True

                        else:
                            if subtask != 'fitting':
                                nb_tasks = str(args.cpu)
                            else:
                                # if use_gpu == 'True':
                                #     nb_tasks = str(args.gpu)
                                # else:
                                nb_tasks = str(args.cpu)

                        if subtask == 'benchmarking':
                            if (output is None) or (benchmark is None):
                                print_and_log(["To generate synthetic datasets, you must provide output and type"], 'error', logger)
                                sys.exit(0)
                            mpi_args += [
                                '-np', nb_tasks, 'spyking-circus-subtask',
                                subtask, filename, str(nb_cpu), str(nb_gpu),
                                use_gpu, output, benchmark
                            ]
                        elif subtask in ['merging', 'converting']:
                            mpi_args += [
                                '-np', nb_tasks, 'spyking-circus-subtask',
                                subtask, filename, str(nb_cpu), str(nb_gpu),
                                use_gpu, extension
                            ]
                        elif subtask in ['deconverting']:
                            nb_tasks = str(1)
                            nb_cpu = 1
                            mpi_args += [
                                '-np', nb_tasks, 'spyking-circus-subtask', subtask,
                                filename, str(nb_cpu), str(nb_gpu), use_gpu,
                                extension
                            ]
                        else:
                            mpi_args += [
                                '-np', nb_tasks, 'spyking-circus-subtask',
                                subtask, filename, str(nb_cpu), str(nb_gpu),
                                use_gpu, str(one_cpu)
                            ]

                        print_and_log(['Launching task %s' % subtask], 'debug', logger)
                        print_and_log(['Command: %s' % str(mpi_args)], 'debug', logger)

                        try:
                            subprocess.check_call(mpi_args)
                        except subprocess.CalledProcessError as e:
                            print_and_log(['Step "%s" failed for reason %s!' % (subtask, e)], 'error', logger)
                            sys.exit(0)

    if preview or result:
        from circus.shared import gui
        import pylab
        try:
            from PyQt5.QtWidgets import QApplication
        except ImportError:
            from matplotlib.backends import qt_compat
            use_pyside = qt_compat.QT_API == qt_compat.QT_API_PYSIDE
            if use_pyside:
                from PySide.QtGui import QApplication
            else:
                from PyQt4.QtGui import QApplication
        app = QApplication([])
        try:
            pylab.style.use('ggplot')
        except Exception:
            pass

        if preview:
            print_and_log(['Launching the preview GUI...'], 'debug', logger)
            mygui = gui.PreviewGUI(new_params)
            shutil.rmtree(tmp_path_loc)
        elif result:
            data_file = params.get_data_file()
            print_and_log(['Launching the result GUI...'], 'debug', logger)
            mygui = gui.PreviewGUI(params, show_fit=True)
        sys.exit(app.exec_())
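
# --- Usage sketch (editor's addition, not from the original source). The
# tasks_list branch near the top of this launcher reads a plain-text file and
# passes each non-empty line verbatim as command-line arguments to the
# `spyking-circus` command. A hypothetical batch file could therefore look like:
#
#     /data/recording1.raw -c 8
#     /data/recording2.raw -c 8
#
# i.e. one recording per line, followed by whatever options that run should
# use (the file names above are illustrative; `-c` is the CPU-count option
# already mentioned in the log messages of this launcher).
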
def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    dist_peaks = params.getint('detection', 'dist_peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    spike_width = params.getfloat('detection', 'spike_width')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = detect_memory(params)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = list(map(float, tmp_limits))
    amp_auto = params.getboolean('fitting', 'amp_auto')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    #################################################################

    if use_gpu:
        import cudamat as cmt
        # # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
            matched_thresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))
            matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)

    thresholds = io.load_data(params, 'thresholds')

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Extracting MUA activity..."], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (PyCharm code inspection)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.mua-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    electrodes_file = open(file_out_suff + '.elec-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amp-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    to_explore = range(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):
        # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        # # We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift].

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if is_last:
            padding = (-dist_peaks, 0)
        elif is_first:
            padding = (0, dist_peaks)
        else:
            padding = (-dist_peaks, dist_peaks)

        result = {'spiketimes': [], 'amplitudes': [], 'templates': []}

        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   padding,
                                                   nodes=nodes)
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                           temporal_whitening,
                                                           axis=0,
                                                           mode='constant')

        # print "Extracting the peaks..."

        local_peaktimes = [numpy.zeros(0, dtype=numpy.uint32)]
        local_elecs = [numpy.zeros(0, dtype=numpy.uint32)]
        local_amps = [numpy.zeros(0, dtype=numpy.float32)]

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_pos, axis=0, mode='constant')
                for i in range(N_e):
                    peaktimes = scipy.signal.find_peaks(
                        filter_chunk[:, i],
                        height=matched_thresholds_pos[i],
                        width=spike_width,
                        distance=dist_peaks,
                        wlen=N_t)[0]
                    local_peaktimes.append(peaktimes)
                    local_elecs.append(
                        i * numpy.ones(len(peaktimes), dtype='uint32'))
                    local_amps.append(filter_chunk[peaktimes, i])
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_neg, axis=0, mode='constant')
                for i in range(N_e):
                    peaktimes = scipy.signal.find_peaks(
                        filter_chunk[:, i],
                        height=matched_thresholds_neg[i],
                        width=spike_width,
                        distance=dist_peaks,
                        wlen=N_t)[0]
                    local_peaktimes.append(peaktimes)
                    local_elecs.append(
                        i * numpy.ones(len(peaktimes), dtype='uint32'))
                    local_amps.append(filter_chunk[peaktimes, i])
        else:
            for i in range(N_e):
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i],
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        distance=dist_peaks,
                                                        wlen=N_t)[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i],
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        distance=dist_peaks,
                                                        wlen=N_t)[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(
                        local_chunk[:, i]),
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        distance=dist_peaks,
                                                        wlen=N_t)[0]
                local_peaktimes.append(peaktimes)
                local_elecs.append(i *
                                   numpy.ones(len(peaktimes), dtype='uint32'))
                local_amps.append(local_chunk[peaktimes, i])

        local_peaktimes = numpy.concatenate(local_peaktimes)
        local_elecs = numpy.concatenate(local_elecs)
        local_amps = numpy.concatenate(local_amps)

        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(
                all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.in1d(
                    local_peaktimes + g_offset,
                    all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]
                local_elecs = local_elecs[~is_included]
                local_amps = local_amps[~is_included]

        # print "Removing the useless borders..."
        local_borders = (dist_peaks, len_chunk - dist_peaks)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes) + g_offset
        local_elecs = numpy.compress(idx, local_elecs)
        local_amps = numpy.compress(idx, local_amps)

        spiketimes_file.write(local_peaktimes.astype(numpy.uint32).tobytes())
        electrodes_file.write(local_elecs.tobytes())
        amplitudes_file.write(local_amps.tobytes())

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    electrodes_file.flush()
    os.fsync(electrodes_file.fileno())
    electrodes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_mua(comm.size, params, erase=True)

    data_file.close()
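
# --- Usage sketch (editor's addition, not from the original source). The
# per-rank files written above are plain binary dumps, so, given the dtypes
# used by the write calls (uint32 spike times, uint32 electrode ids, float32
# amplitudes), they can be read back with numpy.fromfile. The helper name and
# its arguments are illustrative assumptions.
def read_mua_rank_files(file_out_suff, rank):
    import numpy
    spike_times = numpy.fromfile(file_out_suff + '.mua-%d.data' % rank, dtype=numpy.uint32)
    electrodes = numpy.fromfile(file_out_suff + '.elec-%d.data' % rank, dtype=numpy.uint32)
    amplitudes = numpy.fromfile(file_out_suff + '.amp-%d.data' % rank, dtype=numpy.float32)
    return spike_times, electrodes, amplitudes
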
Example #4
def extract_extra_spikes_(params):
    """Detect spikes from the extracellular traces"""
    
    data_file = params.data_file
    data_file.open()
    dist_peaks     = params.getint('detection', 'dist_peaks')
    spike_thresh   = params.getfloat('detection', 'spike_thresh')
    template_shift = params.getint('detection', 'template_shift')
    alignment      = params.getboolean('detection', 'alignment')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening  = params.getboolean('whitening', 'spatial')
    safety_time  = params.getint('whitening', 'safety_time')
    safety_space = params.getboolean('clustering', 'safety_space')
    chunk_size   = params.getint('data', 'chunk_size')
    # chunk_size = params.getint('whitening', 'chunk_size')
    N_total        = params.nb_channels
    file_out_suff  = params.get('data', 'file_out_suff')
    
    if do_spatial_whitening:
        spatial_whitening  = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    
    #mpi_file = MPI.File()
    #mpi_input = mpi_file.Open(comm, data_filename, MPI.MODE_RDONLY)
    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    nodes, _ = get_nodes_and_edges(params)
    N_elec   = params.getint('data', 'N_e')
        
    extra_medians, extra_mads = extract_extra_thresholds(params)
    
    if comm.rank == 0:
        # Save medians and median absolute deviations to BEER file.
        path = "{}.beer.hdf5".format(file_out_suff)
        beer_file = h5py.File(path, 'a', libver='latest')
        ## Save medians.
        extra_medians_key = "extra_medians"
        if extra_medians_key in beer_file.keys():
            beer_file.pop(extra_medians_key)
        beer_file.create_dataset(extra_medians_key, data=extra_medians)
        ## Save median absolute deviations.
        extra_mads_key = "extra_mads"
        if extra_mads_key in beer_file.keys():
            beer_file.pop(extra_mads_key)
        beer_file.create_dataset(extra_mads_key, data=extra_mads)
        beer_file.close()
    
    def extract_chunk_spikes(gidx, extra_thresh, valley=True):
        """Detect spikes from a chunk of the extracellular traces"""
        
        loc_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes)
        loc_shape = len(loc_chunk)
        
        # Whiten signal.
        if do_spatial_whitening:
            loc_chunk = numpy.dot(loc_chunk, spatial_whitening)
        if do_temporal_whitening:
            loc_chunk = scipy.ndimage.filters.convolve1d(loc_chunk, temporal_whitening,
                                                         axis=0, mode='constant')
        
        ##### TODO: uncomment or remove temporary zone
        # # For each electrode, center traces by removing the medians.
        # extra_medians = numpy.median(loc_chunk, axis=0)
        # loc_chunk = loc_chunk - extra_medians
        ##### end temporary zone
        
        # Preallocation for results.
        peak_times = N_elec * [None]
        peak_channels = N_elec * [None]
        # For each electrode.
        for e in range(N_elec):
            # Extract the peaks of the current chunk.
            threshold = extra_thresh * extra_mads[e]
            peak_times[e] = algo.detect_peaks(loc_chunk[:, e], threshold, valley=valley, mpd=dist_peaks)
            peak_channels[e] = e * numpy.ones(peak_times[e].size, dtype='int')
            
            peak_values = loc_chunk[peak_times[e], e]
            if valley:
                peak_indices = numpy.where(-10.0 * threshold <= peak_values)[0]
            else:
                peak_indices = numpy.where(peak_values <= +10.0 * threshold)[0]
            peak_times[e] = peak_times[e][peak_indices]
            peak_channels[e] = peak_channels[e][peak_indices]
        
        peak_times = numpy.concatenate(peak_times)
        peak_channels = numpy.concatenate(peak_channels)
        # Remove the useless borders.
        if alignment:
            loc_borders = (2 * template_shift, loc_shape - 2 * template_shift)
        else:
            loc_borders = (template_shift, loc_shape - template_shift)
        peak_flags = (loc_borders[0] <= peak_times) & (peak_times < loc_borders[1])
        peak_times = numpy.compress(peak_flags, peak_times)
        peak_channels = numpy.compress(peak_flags, peak_channels)
        # Filter unique peak times.
        loc_peak_times = numpy.unique(peak_times)
        ##### TODO: remove debug zone
        # if gidx < 1:
        #     numpy.save("tmp/loc_peak_times_{}_{}_.npy".format(gidx, int(extra_thresh)), loc_peak_times)
        ##### end debug zone
        n_times = len(loc_peak_times)
        loc_peak_flags = numpy.zeros(n_times, dtype='bool')
        loc_peak_elecs = numpy.zeros(n_times, dtype='int')
        loc_peak_values = numpy.zeros(n_times, dtype='float')
        if 0 < len(loc_peak_times):
            diff_times = loc_peak_times[-1] - loc_peak_times[0]
            all_times = numpy.zeros((N_elec, diff_times + 1), dtype='bool')
            min_times = numpy.maximum(loc_peak_times - loc_peak_times[0] - safety_time, 0)
            max_times = numpy.minimum(loc_peak_times - loc_peak_times[0] + safety_time + 1, diff_times)
            # Shuffle peaks.
            ##### TODO: clean temporary zone
            # argmax_peak = numpy.random.permutation(numpy.arange(n_times))
            if valley:
                for i, loc_peak_time in enumerate(loc_peak_times):
                    loc_peak_values[i] = numpy.amin(loc_chunk[loc_peak_time, :])
                argmax_peak = numpy.argsort(loc_peak_values)
            else:
                for i, loc_peak_time in enumerate(loc_peak_times):
                    loc_peak_values[i] = numpy.amax(loc_chunk[loc_peak_time, :])
                argmax_peak = numpy.argsort(loc_peak_values)
                argmax_peak = argmax_peak[::-1]
            ##### end temporary zone
            all_indices = loc_peak_times[argmax_peak]
            # Select peaks with spatio-temporal masks.
            for peak_index, peak_time in zip(argmax_peak, all_indices):
                # Select electrode showing lowest amplitude.
                if valley:
                    elec = numpy.argmin(loc_chunk[peak_time, :])
                else:
                    elec = numpy.argmax(loc_chunk[peak_time, :])
                _, neighs = get_neighbors(params, chan=elec)
                if safety_space:
                    mslice = all_times[neighs, min_times[peak_index]:max_times[peak_index]]
                else:
                    mslice = all_times[elec, min_times[peak_index]:max_times[peak_index]]
                is_local_min = (elec in peak_channels[peak_times == peak_time])
                if is_local_min and not mslice.any():
                    loc_peak_flags[peak_index] = True
                    loc_peak_elecs[peak_index] = elec
                    if valley:
                        loc_peak_values[peak_index] = - loc_chunk[peak_time, elec]
                    else:
                        loc_peak_values[peak_index] = loc_chunk[peak_time, elec]
                    if safety_space:
                        all_times[neighs, min_times[peak_index]:max_times[peak_index]] = True
                        # all_times[elec, min_times[peak_index]:max_times[peak_index]] = True
                    else:
                        all_times[elec, min_times[peak_index]:max_times[peak_index]] = True
        loc_peak_times = numpy.compress(loc_peak_flags, loc_peak_times)
        loc_peak_elecs = numpy.compress(loc_peak_flags, loc_peak_elecs)
        loc_peak_values = numpy.compress(loc_peak_flags, loc_peak_values)

        ##### TODO: remove debug zone
        # if gidx < 1:
        #     numpy.save("tmp/loc_peak_times_{}_{}.npy".format(gidx, int(extra_thresh)), loc_peak_times)
        #     numpy.save("tmp/loc_peak_elecs_{}_{}.npy".format(gidx, int(extra_thresh)), loc_peak_elecs)
        #     numpy.save("tmp/loc_peak_values_{}_{}.npy".format(gidx, int(extra_thresh)), loc_peak_values)
        #     numpy.save("tmp/loc_chunk_{}_{}.npy".format(gidx, int(extra_thresh)), loc_chunk)
        ##### end debug zone
        
        return loc_peak_times + t_offset, loc_peak_elecs, loc_peak_values
    
    # Distribute chunks over CPUs.
    all_chunks = numpy.arange(nb_chunks)
    loc_all_chunks = all_chunks[comm.rank::comm.size]
    loc_nb_chunks = len(loc_all_chunks)
    
    if comm.rank == 0:
        print_and_log(["Collecting extracellular spikes..."], level='default', logger=logger)
    
    to_explore = range(comm.rank, nb_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)
    
    extra_valley = True
    
    ##### TODO: remove test zone (i.e. plots of extracellular spike times).
    # plot_extracted_extra_spikes(loc_all_chunks, data_len, mpi_input, data_dtype,
    #                             chunk_len, chunk_size, N_total, nodes,
    #                             extra_medians, extra_mads, k, params, safety_space,
    #                             safety_time)
    # sys.exit(1)
    ##### end test zone
    
    # Preallocation for results.
    times = len(loc_all_chunks) * [None]
    channels = len(loc_all_chunks) * [None]
    values = len(loc_all_chunks) * [None]
    
    data_file.open()
    # For each chunk attributed to the current CPU.
    for (count, gidx) in enumerate(to_explore):
        gidx = all_chunks[gidx]
        time, channel, value = extract_chunk_spikes(gidx, spike_thresh, valley=extra_valley)
        times[count] = time
        channels[count] = channel
        values[count] = value
    
    # Concatenate times, channels and values.
    times = numpy.hstack(times)
    channels = numpy.hstack(channels)
    values = numpy.hstack(values)
        
    data_file.close()
    comm.Barrier()
    
    # Gather times, channels and values.
    times    = gather_array(times.astype(numpy.int64), comm, 0, dtype='int64')
    channels = gather_array(channels.astype(numpy.int64), comm, 0, dtype='int64')
    values   = gather_array(values.astype(numpy.float64), comm, 0, dtype='float64')
    
    if comm.rank == 0:
        # Sort times, channels and values according to time.
        idx = numpy.argsort(times)
        times = times[idx]
        channels = channels[idx]
        values = values[idx]
    
        msg = [
            "Total number of extracellular spikes extracted: {}".format(channels.size),
        ] 
        msg2 = [
            "Number of extracellular spikes extracted on channel {}: {}".format(i, channels[channels == i].size) for i in numpy.unique(channels)
        ]
        print_and_log(msg, level='info', logger=logger)
        print_and_log(msg2, level='debug', logger=logger)
    
        path = "{}.beer.hdf5".format(file_out_suff)
        beer_file = h5py.File(path, 'a', libver='latest')
        group_name = "extra_spiketimes"
        if group_name in beer_file.keys():
            beer_file.pop(group_name)
        beer_file.create_group(group_name)
        for i in numpy.arange(0, N_elec):
            mask = (channels == i)
            triggers = times[mask]
            beer_file.create_dataset("{}/elec_{}".format(group_name, i), data=triggers)
        group_name = "extra_spike_values"
        if group_name in beer_file.keys():
            beer_file.pop(group_name)
        beer_file.create_group(group_name)
        for i in numpy.arange(0, N_elec):
            mask = (channels == i)
            data = values[mask]
            beer_file.create_dataset("{}/elec_{}".format(group_name, i), data=data)
        beer_file.close()

    comm.Barrier()
    
    return
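
# --- Usage sketch (editor's addition, not from the original source). The BEER
# file written above keeps one dataset per electrode under the
# "extra_spiketimes" and "extra_spike_values" groups, next to the
# "extra_medians" and "extra_mads" datasets; a minimal h5py read-back, with an
# illustrative helper name, could look like this.
def read_beer_extra_spikes(file_out_suff, n_elec):
    import h5py
    spike_times, spike_values = {}, {}
    with h5py.File("{}.beer.hdf5".format(file_out_suff), 'r') as beer_file:
        extra_medians = beer_file['extra_medians'][:]
        extra_mads = beer_file['extra_mads'][:]
        for i in range(n_elec):
            spike_times[i] = beer_file['extra_spiketimes/elec_{}'.format(i)][:]
            spike_values[i] = beer_file['extra_spike_values/elec_{}'.format(i)][:]
    return extra_medians, extra_mads, spike_times, spike_values
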
Example #5
def ellipsoid_general_to_standard(coefs, verbose=False, logger=None):
    """
    Convert an ellipsoid in general form:
        a_{0}
        + a_{1} x1 + ... + a_{m} xm
        + a_{1, 1} * x1 * x1 + ... + a_{1, m} * x1 * xm
        + ...
        + a_{m, m} xm * xm
        = 0
    To standard form (TODO: check validity):
        (x1 - x1') * phi1(t_{1, 2}, ..., t_{m-1, m})
        + ...
        + (xm - xm') * phim(t_{1, 2}, ..., t_{m-1, m})
    The ellipse has center [x1', ..., xm']^T, semi-axes b1, ... and bm, and
    the angle to the semi-major axis is t.
    """
    # Convert to float.
    coefs = coefs.astype('float')
    K = coefs.size
    # Retrieve the number of dimension (i.e. N).
    # (i.e. solve: 1 + N + (N + 1) * N / 2 = K)
    N = int(- 1.5 + numpy.sqrt(1.5 ** 2.0 - 4.0 * 0.5 * (1.0 - float(K))))
    if verbose:
        msg = [
            "# K",
            "%s" %(K,),
            "# N",
            "%s" %(N,),
        ]
        print_and_log(msg, level='default', logger=logger)
    # Retrieve the matrix representation.
    A = numpy.zeros((N, N))
    k = 0
    for i in range(0, N):
        A[i, i] = coefs[1 + N + k]
        k = k + 1
        for j in range(i + 1, N):
            A[i, j] = coefs[1 + N + k] / 2.0
            A[j, i] = coefs[1 + N + k] / 2.0
            k = k + 1
    b = coefs[1:1+N]
    c = coefs[0]
    # Compute the center of the ellipsoid.
    center = - 0.5 * numpy.dot(numpy.linalg.inv(A), b)
    
    ##### TODO: remove test zone
    if verbose:
        msg = [
            "# Test of symmetry",
            "%s" %(numpy.all(A == A.T),),
        ]
        print_and_log(msg, level='default', logger=logger)
    ##### end test zone
    
    # Each eigenvector of A lies along one of the axes.
    evals, evecs = numpy.linalg.eigh(A)
    
    ##### TODO: remove print zone.
    if verbose:
        msg = [
            "# Semi-axes computation",
            "## det(A)",
            "%s" %(numpy.linalg.det(A),),
            "## evals",
            "%s" %(evals,),
        ]
        print_and_log(msg, level='default', logger=logger)
    ##### end print zone.
    
    # Semi-axes from the reduced canonical equation: after centering, the
    # quadric reads y^T A y = x0^T A x0 - c, so the semi-axis along the i-th
    # eigenvector is sqrt((x0^T A x0 - c) / eval_i).
    rhs = numpy.dot(center, numpy.dot(A, center)) - c
    eaxis = numpy.sqrt(numpy.abs(rhs / evals))
    return center, eaxis, evecs
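
# --- Worked check (editor's addition, not from the original source). For the
# known 2-D ellipse (x - 1)^2 / 4 + (y + 2)^2 / 9 = 1, the general-form pieces
# A (quadratic part), b (linear part) and c (constant) recover the center
# (1, -2) and the semi-axes 3 and 2 via center = -0.5 * A^{-1} b and
# semi-axes = sqrt((center^T A center - c) / eigenvalues).
def _check_ellipse_recovery():
    import numpy
    A = numpy.array([[1.0 / 4.0, 0.0], [0.0, 1.0 / 9.0]])
    b = numpy.array([-1.0 / 2.0, 4.0 / 9.0])
    c = -11.0 / 36.0
    center = -0.5 * numpy.dot(numpy.linalg.inv(A), b)  # -> [1., -2.]
    evals, evecs = numpy.linalg.eigh(A)                # evals -> [1/9, 1/4]
    semi_axes = numpy.sqrt((numpy.dot(center, numpy.dot(A, center)) - c) / evals)  # -> [3., 2.]
    return center, semi_axes, evecs
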
Example #6
    def compute_artefacts(data_file):

        chunk_size     = params.getint('data', 'chunk_size')
        trig_in_ms     = params.getboolean('triggers', 'trig_in_ms')
        artefacts      = numpy.loadtxt(params.get('triggers', 'trig_file'))
        windows        = numpy.loadtxt(params.get('triggers', 'trig_windows'))
        make_plots     = params.get('triggers', 'make_plots')
        plot_path      = os.path.join(params.get('data', 'file_out_suff'), 'plots')

        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if trig_in_ms:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug', logger)
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate*1e-3)
            windows[:, 1]   *= numpy.int64(data_file.sampling_rate*1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'], 'debug', logger)

        artefacts        = artefacts.astype(numpy.int64)
        windows          = windows.astype(numpy.int64)

        nb_stimuli       = len(numpy.unique(artefacts[:, 0]))
        mytest           = nb_stimuli == len(windows)

        if not mytest:
            if comm.rank == 0:
                print_and_log(['Error in the trigger files'], 'error', logger)
            sys.exit(0)

        all_labels   = artefacts[:, 0]
        all_times    = artefacts[:, 1]
        mask         = (all_times >= 0) & (all_times + numpy.max(windows[:,1]) < data_file.t_stop)
        all_times    = numpy.compress(mask, all_times)
        all_labels   = numpy.compress(mask, all_labels)

        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        if comm.rank == 0:
            to_write = ["Computing averaged artefacts from %d stimuli" %(nb_stimuli)]
            print_and_log(to_write, 'default', logger)
            if not os.path.exists(plot_path):
                os.makedirs(plot_path)
            local_labels = get_tqdm_progressbar(local_labels)

        comm.Barrier()
        # First we need to get the average artefacts
        art_dict = {}
        for count, artefact in enumerate(local_labels):
            indices  = numpy.where(all_labels == artefact)[0].astype(numpy.uint32)
            tmp      = numpy.where(windows[:, 0] == artefact)[0]
            tau      = numpy.int64(windows[tmp, 1])
            pspikes  = all_times[indices]
            times    = numpy.sort(numpy.random.permutation(pspikes)[:500])
            if len(numpy.where(numpy.diff(times) < tau)[0]) > 0:
                if comm.rank == 0:
                    print_and_log(['Stimulation times for artefact %d are too close!' %artefact], 'error', logger)
                sys.exit(0)

            art_dict[artefact] = get_artefact(params, times, tau, nodes)
            if make_plots not in ['None', '']:
                save     = [plot_path, '%d.%s' %(artefact, make_plots)]
                plot.view_artefact(art_dict[artefact], save=save)

        sys.stderr.flush()
        return art_dict
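
# --- Usage sketch (editor's addition, not from the original source). From the
# numpy.loadtxt calls above, trig_file and trig_windows are plain-text tables
# with two columns: a stimulus label plus, respectively, a trigger time or a
# window length (in ms when trig_in_ms is set, otherwise in timesteps). A
# purely illustrative pair of such files could be written like this.
def write_example_trigger_files(trig_path, windows_path):
    import numpy
    triggers = numpy.array([[0, 1000], [0, 5000], [1, 9000]])  # label, trigger time
    windows = numpy.array([[0, 200], [1, 300]])                # label, window length
    numpy.savetxt(trig_path, triggers, fmt='%d')
    numpy.savetxt(windows_path, windows, fmt='%d')
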
Example #7
def main(params, nb_cpu, nb_gpu, use_gpu):

    logger         = init_logging(params.logfile)
    logger         = logging.getLogger('circus.filtering')
    #################################################################
    do_filter      = params.getboolean('filtering', 'filter')
    filter_done    = check_if_done(params, 'filter_done', logger)
    artefacts_done = check_if_done(params, 'artefacts_done', logger)
    median_done    = check_if_done(params, 'median_done', logger)
    ground_done    = check_if_done(params, 'ground_done', logger)
    clean_artefact = params.getboolean('triggers', 'clean_artefact')
    remove_median  = params.getboolean('filtering', 'remove_median')
    common_ground  = params.getint('filtering', 'common_ground')
    remove_ground  = common_ground >= 0
    nodes, edges   = get_nodes_and_edges(params)
    #################################################################


    def filter_file(data_file_in, data_file_out, do_filtering, do_remove_median, do_remove_ground):

        try:
            cut_off    = params.getfloat('filtering', 'cut_off')
            cut_off    = [cut_off, 0.95*(params.rate/2.)]
        except Exception:
            cut_off        = params.get('filtering', 'cut_off')
            cut_off        = cut_off.split(',')
            try:
                cut_off[0] = float(cut_off[0])
            except Exception:
                if comm.rank == 0:
                    print_and_log(['First value of cut off must be a valid number'], 'error', logger)
                sys.exit(0)

            cut_off[1] = cut_off[1].replace(' ', '')
            if cut_off[1] == 'auto':
                cut_off[1] = 0.95*(params.rate/2.)
            else:
                try:
                    cut_off[1] = float(cut_off[1])
                except Exception:
                    if comm.rank == 0:
                        print_and_log(['Second value of cut off must be either auto, or a valid number'], 'error', logger)
                    sys.exit(0)

        chunk_size    = params.getint('data', 'chunk_size')
        nb_chunks, _  = data_file_in.analyze(chunk_size)

        b, a          = signal.butter(3, numpy.array(cut_off)/(params.rate/2.), 'pass')
        all_chunks    = numpy.arange(nb_chunks, dtype=numpy.int64)
        to_process    = all_chunks[comm.rank::comm.size]
        loc_nb_chunks = len(to_process)
        N_total       = params.nb_channels
        process_all_channels = numpy.all(nodes == numpy.arange(N_total))

        if comm.rank == 0:
            to_write = []
            if do_filtering:
                to_write += ["Filtering the signal with a Butterworth filter in (%g, %g) Hz" %(cut_off[0],cut_off[1])]
            if do_remove_median:
                to_write += ["Median over all channels is subtracted to each channels"]
            if do_remove_ground:
                to_write += ["Channel %s is used as a reference channel" %common_ground]

            print_and_log(to_write, 'default', logger)

        to_explore = xrange(comm.rank, nb_chunks, comm.size)

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(to_explore)

        for count, gidx in enumerate(to_explore):

            local_chunk, t_offset =  data_file_in.get_data(gidx, chunk_size)

            if do_filtering:
                for i in nodes:    
                    try:
                        local_chunk[:, i] = signal.filtfilt(b, a, local_chunk[:, i])
                    except Exception:
                        pass
                    local_chunk[:, i] -= numpy.median(local_chunk[:, i]) 

            if do_remove_median:
                if not process_all_channels:
                    global_median = numpy.median(numpy.take(local_chunk, nodes, axis=1), 1)
                else:
                    global_median = numpy.median(local_chunk, 1)

                for i in nodes:
                    local_chunk[:, i] -= global_median

            if do_remove_ground:
                for i in nodes:
                    local_chunk[:, i] -= local_chunk[:, common_ground]

            if data_file_in != data_file_out and data_file_in.is_first_chunk(gidx, nb_chunks):
                if data_file_in.is_stream:
                    g_offset = t_offset - numpy.sum(data_file_in._times[:data_file_in._get_streams_index_by_time(t_offset)+1])
                else:
                    g_offset = t_offset - data_file_in.t_start
            else:
                g_offset = t_offset

            data_file_out.set_data(g_offset, local_chunk)

        sys.stderr.flush()
        comm.Barrier()


    def compute_artefacts(data_file):

        chunk_size     = params.getint('data', 'chunk_size')
        trig_in_ms     = params.getboolean('triggers', 'trig_in_ms')
        artefacts      = numpy.loadtxt(params.get('triggers', 'trig_file'))
        windows        = numpy.loadtxt(params.get('triggers', 'trig_windows'))
        make_plots     = params.get('triggers', 'make_plots')
        plot_path      = os.path.join(params.get('data', 'file_out_suff'), 'plots')

        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if trig_in_ms:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug', logger)
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate*1e-3)
            windows[:, 1]   *= numpy.int64(data_file.sampling_rate*1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'], 'debug', logger)

        artefacts        = artefacts.astype(numpy.int64)
        windows          = windows.astype(numpy.int64)

        nb_stimuli       = len(numpy.unique(artefacts[:, 0]))
        mytest           = nb_stimuli == len(windows)

        if not mytest:
            if comm.rank == 0:
                print_and_log(['Error in the trigger files'], 'error', logger)
            sys.exit(0)

        all_labels   = artefacts[:, 0]
        all_times    = artefacts[:, 1]
        mask         = (all_times >= 0) & (all_times + numpy.max(windows[:,1]) < data_file.t_stop)
        all_times    = numpy.compress(mask, all_times)
        all_labels   = numpy.compress(mask, all_labels)

        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        if comm.rank == 0:
            to_write = ["Computing averaged artefacts from %d stimuli" %(nb_stimuli)]
            print_and_log(to_write, 'default', logger)
            if not os.path.exists(plot_path):
                os.makedirs(plot_path)
            local_labels = get_tqdm_progressbar(local_labels)

        comm.Barrier()
        # First we need to get the average artefacts
        art_dict = {}
        for count, artefact in enumerate(local_labels):
            indices  = numpy.where(all_labels == artefact)[0].astype(numpy.uint32)
            tmp      = numpy.where(windows[:, 0] == artefact)[0]
            tau      = numpy.int64(windows[tmp, 1])
            pspikes  = all_times[indices]
            times    = numpy.sort(numpy.random.permutation(pspikes)[:500])
            if len(numpy.where(numpy.diff(times) < tau)[0]) > 0:
                if comm.rank == 0:
                    print_and_log(['Stimulation times for artefact %d are too close!' %artefact], 'error', logger)
                sys.exit(0)

            art_dict[artefact] = get_artefact(params, times, tau, nodes)
            if make_plots not in ['None', '']:
                save     = [plot_path, '%d.%s' %(artefact, make_plots)]
                plot.view_artefact(art_dict[artefact], save=save)

        sys.stderr.flush()
        return art_dict


    def remove_artefacts(data_file, art_dict):

        chunk_size     = params.getint('data', 'chunk_size')
        trig_in_ms     = params.getboolean('triggers', 'trig_in_ms')
        artefacts      = numpy.loadtxt(params.get('triggers', 'trig_file'))
        windows        = numpy.loadtxt(params.get('triggers', 'trig_windows'))
        make_plots     = params.get('triggers', 'make_plots')
        plot_path      = os.path.join(params.get('data', 'file_out_suff'), 'plots')

        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if trig_in_ms:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug', logger)
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate*1e-3)
            windows[:, 1]   *= numpy.int64(data_file.sampling_rate*1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'], 'debug', logger)

        artefacts        = artefacts.astype(numpy.int64)
        windows          = windows.astype(numpy.int64)
        nb_stimuli       = len(numpy.unique(artefacts[:, 0]))
        mytest           = nb_stimuli == len(windows)

        if not mytest:
            if comm.rank == 0:
                print_and_log(['Error in the trigger files'], 'error', logger)
            sys.exit(0)

        all_labels   = artefacts[:, 0]
        all_times    = artefacts[:, 1]
        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        mask       = numpy.in1d(all_labels, local_labels)
        all_times  = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        mask       = (all_times >= 0) & (all_times < data_file.t_stop)
        all_times  = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        if comm.rank == 0:
            to_write = ["Removing artefacts from %d stimuli" %(nb_stimuli)]
            print_and_log(to_write, 'default', logger)
            all_times = get_tqdm_progressbar(all_times)

        comm.Barrier()

        for count, time in enumerate(all_times):

            label = all_labels[count]
            tmp   = numpy.where(windows[:, 0] == label)[0][0]
            tau   = numpy.int64(windows[tmp, 1])

            if (data_file.t_stop - time) < tau:
                tau   = data_file.t_stop - time

            local_chunk   = data_file.get_snippet(time, tau)
            for idx, i in enumerate(nodes):
                local_chunk[:, i] -= art_dict[label][idx, :tau]
            data_file.set_data(time, local_chunk)

        comm.Barrier()
        sys.stderr.flush()


    if comm.rank == 0:
        print_and_log(['Initializing the filtering step...'], 'debug', logger)

    if params.getboolean('data', 'overwrite'):
        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)

        data_file_in  = params.get_data_file()
        data_file_out = data_file_in
    else:
        if comm.rank == 0:
            print_and_log(['Overwrite is set to False, so creating a new datafile...'], 'debug', logger)

        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)

        if os.path.exists(params.get('data', 'data_file_no_overwrite')):
            has_been_created = True
        else:
            has_been_created = False

        if not has_been_created and (filter_done or median_done or artefacts_done):
            if comm.rank == 0:
                print_and_log(['The filtering is marked as done but the file is not present. See the noedits section'], 'error', logger)
            sys.exit(0)

        if not has_been_created:
            data_file_in = params.get_data_file(source=True, has_been_created=has_been_created)
        else:
            data_file_in = params.get_data_file(source=False, has_been_created=has_been_created)

        if comm.rank == 0:
            print_and_log(['Reading the output file and allocating resources...'], 'debug', logger)

        description                 = data_file_in.get_description()
        description['data_dtype']   = 'float32'
        description['dtype_offset'] = 0
        description['data_offset']  = 0

        data_file_out = params.get_data_file(is_empty=not has_been_created, params=description)

        if comm.rank == 0:
            print_and_log(['Allocating space for filtered files...'], 'debug', logger)

        if not has_been_created:
            data_file_out.allocate(shape=data_file_in.shape)

        comm.Barrier()

    if clean_artefact:
        if not (os.path.exists(params.get('triggers', 'trig_file')) and os.path.exists(params.get('triggers', 'trig_windows'))):
            if comm.rank == 0:
                print_and_log(['The trig_file or trig_windows file cannot be found'], 'error', logger)
            sys.exit(0)

    to_write = []

    if do_filter and filter_done:
        do_filter = False
        to_write += ["Filtering has already been done"]
    if remove_median and median_done:
        remove_median = False
        to_write += ["Median over all channels has already been removed"]
    if remove_ground and ground_done:
        remove_ground = False
        to_write += ["Common ground %s has alread been subtracted" %common_ground]

    if comm.rank == 0 and len(to_write) > 0:
        print_and_log(to_write, 'info', logger)

    if params.getboolean('data', 'overwrite'):
        data_file_in.open(mode='r+')
    else:
        data_file_in.open()
        data_file_out.open(mode='r+')

    if do_filter or remove_median or remove_ground:
        if comm.rank == 0:
            if do_filter:
                params.write('noedits', 'filter_done', 'Started')
            if remove_median:
                params.write('noedits', 'median_done', 'Started')
            if remove_ground:
                params.write('noedits', 'ground_done', 'Started')
        filter_file(data_file_in, data_file_out, do_filter, remove_median, remove_ground)

    if comm.rank == 0:
        if do_filter:
            params.write('noedits', 'filter_done', 'True')
        if remove_median:
            params.write('noedits', 'median_done', 'True')
        if remove_ground:
            params.write('noedits', 'ground_done', 'True')

    if clean_artefact and artefacts_done:
        clean_artefact = False
        if comm.rank == 0:
            print_and_log(['Artefacts have already been removed'], 'debug', logger)

    if clean_artefact:
        art_dict   = compute_artefacts(data_file_in)
        if comm.rank == 0:
            params.write('noedits', 'artefacts_done', 'Started')
        remove_artefacts(data_file_out, art_dict)

    if comm.rank == 0:
        if clean_artefact:
            params.write('noedits', 'artefacts_done', 'True')

    data_file_in.close()
    if not params.getboolean('data', 'overwrite'):
        data_file_out.close()

    comm.Barrier()
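
# --- Usage sketch (editor's addition, not from the original source). The
# filter_file helper above designs a 3rd-order Butterworth band-pass with the
# cutoffs normalised by the Nyquist frequency and applies it with filtfilt;
# the same design on a single synthetic trace (arbitrary sampling rate and
# cutoffs, for illustration only) would look like this.
def butterworth_bandpass_example():
    import numpy
    from scipy import signal
    sampling_rate = 20000.0                                   # Hz (illustration)
    cut_off = [300.0, 0.95 * (sampling_rate / 2.0)]           # Hz (illustration)
    b, a = signal.butter(3, numpy.array(cut_off) / (sampling_rate / 2.0), 'bandpass')
    trace = numpy.random.randn(int(sampling_rate))            # 1 s of fake data
    return signal.filtfilt(b, a, trace)
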
Example #8
def slice_templates(params,
                    to_remove=[],
                    to_merge=[],
                    extension='',
                    input_extension=''):
    """Slice templates in HDF5 file.

    Arguments:
        params
        to_remove: list (optional)
            An array of template indices to remove.
            The default value is [].
        to_merge: list | numpy.ndarray (optional)
            An array of pairs of template indices to merge
            (i.e. shape = (nb_merges, 2)).
            The default value is [].
        extension: string (optional)
            The extension to use as output.
            The default value is ''.
        input_extension: string (optional)
            The extension to use as input.
            The default value is ''.
    """

    file_out_suff = params.get('data', 'file_out_suff')

    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')

    if comm.rank == 0:
        print_and_log(['Node 0 is slicing templates'], 'debug', logger)
        old_templates = load_data(params,
                                  'templates',
                                  extension=input_extension)
        old_limits = load_data(params, 'limits', extension=input_extension)
        _, N_tm = old_templates.shape
        norm_templates = load_data(params,
                                   'norm-templates',
                                   extension=input_extension)

        # Determine the template indices to delete.
        to_delete = list(to_remove)  # i.e. copy
        if len(to_merge) > 0:
            for count in range(len(to_merge)):
                remove = to_merge[count][1]
                to_delete += [remove]

        # Determine the indices to keep.
        all_templates = set(numpy.arange(N_tm // 2))
        to_keep = numpy.array(list(all_templates.difference(to_delete)))

        positions = numpy.arange(len(to_keep))

        # Initialize new HDF5 file for templates.
        local_keep = to_keep[positions]
        templates = scipy.sparse.lil_matrix((N_e * N_t, 2 * len(to_keep)),
                                            dtype=numpy.float32)
        hfilename = file_out_suff + '.templates{}.hdf5'.format('-new')
        hfile = h5py.File(hfilename, 'w', libver='earliest')
        norms = hfile.create_dataset('norms',
                                     shape=(2 * len(to_keep), ),
                                     dtype=numpy.float32,
                                     chunks=True)
        limits = hfile.create_dataset('limits',
                                      shape=(len(to_keep), 2),
                                      dtype=numpy.float32,
                                      chunks=True)
        # For each index to keep.
        for count, keep in zip(positions, local_keep):
            # Copy template.
            templates[:, count] = old_templates[:, keep]
            templates[:, count + len(to_keep)] = old_templates[:, keep + N_tm // 2]
            # Copy norm.
            norms[count] = norm_templates[keep]
            norms[count + len(to_keep)] = norm_templates[keep + N_tm // 2]
            # Copy limits.
            if len(to_merge) == 0:
                new_limits = old_limits[keep]
            else:
                subset = numpy.where(to_merge[:, 0] == keep)[0]
                if len(subset) > 0:
                    # Index to keep is involved in merge(s) and limits need to
                    # be updated.
                    idx = numpy.unique(to_merge[subset].flatten())
                    ratios = norm_templates[idx] / norm_templates[keep]
                    new_limits = [
                        numpy.min(ratios * old_limits[idx][:, 0]),
                        numpy.max(ratios * old_limits[idx][:, 1])
                    ]
                else:
                    new_limits = old_limits[keep]
            limits[count] = new_limits

        # Copy templates to file.
        templates = templates.tocoo()
        if hdf5_compress:
            hfile.create_dataset('temp_x',
                                 data=templates.row,
                                 compression='gzip')
            hfile.create_dataset('temp_y',
                                 data=templates.col,
                                 compression='gzip')
            hfile.create_dataset('temp_data',
                                 data=templates.data,
                                 compression='gzip')
        else:
            hfile.create_dataset('temp_x', data=templates.row)
            hfile.create_dataset('temp_y', data=templates.col)
            hfile.create_dataset('temp_data', data=templates.data)
        hfile.create_dataset('temp_shape',
                             data=numpy.array([N_e, N_t, 2 * len(to_keep)],
                                              dtype=numpy.int32))
        hfile.close()

        # Rename output filename.
        temporary_path = hfilename
        output_path = file_out_suff + '.templates{}.hdf5'.format(extension)
        if os.path.exists(output_path):
            os.remove(output_path)
        shutil.move(temporary_path, output_path)
    else:
        to_keep = numpy.array([])

    return to_keep
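
# --- Usage sketch (editor's addition, not from the original source). Based on
# the docstring above, merging template 1 into template 0 while also removing
# template 5 could be driven as follows; the wrapper name and the '-merged'
# extension are illustrative assumptions.
def merge_and_remove_templates(params):
    import numpy
    to_merge = numpy.array([[0, 1]])   # keep template 0, fold template 1 into it
    to_remove = [5]                    # drop template 5 outright
    return slice_templates(params, to_remove=to_remove, to_merge=to_merge,
                           extension='-merged', input_extension='')
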
Example #9
def slice_clusters(params,
                   result,
                   to_remove=[],
                   to_merge=[],
                   extension='',
                   input_extension='',
                   light=False,
                   method='safe'):
    """Slice clusters in HDF5 templates.

    Arguments:
        params
        to_remove: list (optional)
        to_merge: list | numpy.ndarray (optional)
        extension: string (optional)
            The default value is ''.
        input_extension: string (optional)
            The default value is ''.
        light: boolean (optional)
    """

    file_out_suff = params.get('data', 'file_out_suff')
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')

    if comm.rank == 0:

        print_and_log(['Node 0 is slicing clusters'], 'debug', logger)
        old_templates = load_data(params,
                                  'templates',
                                  extension=input_extension)
        _, N_tm = old_templates.shape

        # Determine the template indices to delete.
        to_delete = list(to_remove)
        if len(to_merge) > 0:
            for count in range(len(to_merge)):
                remove = to_merge[count][1]
                to_delete += [remove]

        # Determine the indices to keep.
        all_templates = set(numpy.arange(N_tm // 2))
        to_keep = numpy.array(list(all_templates.difference(to_delete)))

        all_elements = [[] for i in range(N_e)]
        for target in numpy.unique(to_delete):
            elec = result['electrodes'][target]
            nic = target - numpy.where(result['electrodes'] == elec)[0][0]
            mask = result['clusters_' + str(elec)] > -1
            tmp = numpy.unique(result['clusters_' + str(elec)][mask])
            all_elements[elec] += list(
                numpy.where(result['clusters_' + str(elec)] == tmp[nic])[0])

        myfilename = file_out_suff + '.clusters{}.hdf5'.format(input_extension)
        myfile = h5py.File(myfilename, 'r', libver='earliest')

        for elec in range(N_e):
            if not light:
                result['data_' + str(elec)] = numpy.delete(result['data_' +
                                                                  str(elec)],
                                                           all_elements[elec],
                                                           axis=0)
                result['clusters_' + str(elec)] = numpy.delete(
                    result['clusters_' + str(elec)], all_elements[elec])
                result['times_' + str(elec)] = numpy.delete(
                    result['times_' + str(elec)], all_elements[elec])
                result['peaks_' + str(elec)] = numpy.delete(
                    result['peaks_' + str(elec)], all_elements[elec])
            else:
                result['clusters_' + str(elec)] = numpy.delete(
                    result['clusters_' + str(elec)], all_elements[elec])
                data = myfile.get('data_' + str(elec))[:]
                result['data_' + str(elec)] = numpy.delete(data,
                                                           all_elements[elec],
                                                           axis=0)
                data = myfile.get('times_' + str(elec))[:]
                result['times_' + str(elec)] = numpy.delete(
                    data, all_elements[elec])
                data = myfile.get('peaks_' + str(elec))[:]
                result['peaks_' + str(elec)] = numpy.delete(
                    data, all_elements[elec])

        myfile.close()
        if method == 'safe':
            result['electrodes'] = numpy.delete(result['electrodes'],
                                                numpy.unique(to_delete))
        elif method == 'new':
            result['electrodes'] = result['electrodes'][to_keep]
        else:
            raise ValueError("Unexpected method value: {}".format(method))

        cfilename = file_out_suff + '.clusters{}.hdf5'.format('-new')
        cfile = h5py.File(cfilename, 'w', libver='earliest')
        to_write = ['data_', 'clusters_', 'times_', 'peaks_']
        for ielec in xrange(N_e):
            write_datasets(cfile,
                           to_write,
                           result,
                           ielec,
                           compression=hdf5_compress)
        write_datasets(cfile, ['electrodes'], result)
        cfile.close()

        # Rename output file.
        temporary_path = cfilename
        output_path = file_out_suff + '.clusters{}.hdf5'.format(extension)
        if os.path.exists(output_path):
            os.remove(output_path)
        shutil.move(temporary_path, output_path)

    return
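# A minimal usage sketch (hypothetical values), assuming `params` and the
# clustering `result` dictionary come from an earlier step: merge template 3
# into template 2, drop template 7, and write '-merged' cluster files.
#
#     slice_clusters(params, result, to_remove=[7], to_merge=[(2, 3)],
#                    extension='-merged', input_extension='')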
Example #10
def main(params, nb_cpu, nb_gpu, use_gpu):
    # Part 1: Whitening
    numpy.random.seed(420)
    #params         = detect_memory(params)
    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    data_file.open()
    N_e = params.getint('data', 'N_e')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    matched_filter = params.getboolean('detection', 'matched-filter')
    matched_thresh = params.getfloat('detection', 'matched_thresh')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('whitening', 'chunk_size')
    plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    safety_space = params.getboolean('whitening', 'safety_space')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_1 = int(20 * params.rate // comm.size)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    template_shift_2 = 2 * template_shift
    #################################################################

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    all_electrodes = numpy.random.permutation(N_e)

    for gidx in [all_chunks[comm.rank]]:

        #print "Node", comm.rank, "is analyzing chunk", gidx,  "/", nb_chunks, " ..."
        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   nodes=nodes)
        local_shape = len(local_chunk)

        #print "Node", comm.rank, "computes the median absolute deviations in a random chunk"
        thresholds = numpy.zeros(N_e, dtype=numpy.float32)
        for i in xrange(N_e):
            u = numpy.median(local_chunk[:, i], 0)
            thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0)
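        # Each entry of `thresholds` is the median absolute deviation (MAD) of
        # one channel; for Gaussian noise, sigma is roughly 1.4826 * MAD.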
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            bfile = h5py.File(file_out_suff + '.basis.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(bfile, ['thresholds'],
                              {'thresholds': thresholds},
                              compression=hdf5_compress)
            bfile.close()
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        #print "Extracting the peaks..."
        local_peaktimes = numpy.zeros(0, dtype=numpy.uint32)
        for i in xrange(N_e):
            peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                          thresholds[i],
                                          valley=False,
                                          mpd=dist_peaks)
            local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))

        local_peaktimes = numpy.unique(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, local_shape - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool)
            min_times = numpy.maximum(
                local_peaktimes - local_peaktimes[0] - safety_time, 0)
            max_times = numpy.minimum(
                local_peaktimes - local_peaktimes[0] + safety_time + 1,
                diff_times)
            argmax_peak = numpy.random.permutation(
                numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            #print "Selection of the peaks with spatio-temporal masks..."
            for idx, peak in zip(argmax_peak, all_idx):
                elec = numpy.argmax(numpy.abs(local_chunk[peak]))
                indices = numpy.take(inv_nodes, edges[nodes[elec]])
                if safety_space:
                    all_times[indices, min_times[idx]:max_times[idx]] = True
                else:
                    all_times[elec, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool)

    if do_temporal_whitening:

        local_res_temp = []

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white,
                                                comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = numpy.take(inv_nodes, edges[nodes[elec]])
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            bound = len(esubset) - N_t
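            # Greedily collect non-overlapping windows of N_t consecutive silent
            # samples (at most max_silence_2) to estimate the temporal covariance.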
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res_temp += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res_temp)], dtype=numpy.float32)
        local_res_temp = numpy.array(local_res_temp, dtype=numpy.float32)
        if len(local_res_temp) == 0:
            local_res_temp = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res_temp = numpy.sum(local_res_temp, 0)
        all_res_temp = gather_array(local_res_temp.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if do_spatial_whitening:

        local_res_spac = numpy.zeros((N_e, N_e), dtype=numpy.float32)
        local_silences = []

        for elec in numpy.arange(comm.rank, N_e, comm.size):
            indices = numpy.take(inv_nodes, edges[nodes[elec]])
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            local_data = local_chunk[esubset][:, indices]
            local_whitening = get_whitening_matrix(local_data).astype(
                numpy.float32)
            pos = numpy.where(elec == indices)[0]
            local_res_spac[elec, indices] = local_whitening[pos]
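            # Keep only this electrode's row of the local whitening matrix,
            # expressed over its neighbouring channels.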
            local_silences += [len(esubset)]

        all_res_spac = gather_array(local_res_spac.ravel(), comm, 0, 1)
        all_silences = gather_array(
            numpy.array(local_silences, dtype=numpy.int32), comm, 0, 1,
            'uint32')

    if comm.rank == 0:

        to_write = {}

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res_temp = all_res_temp.reshape((nb_silences, N_t**2))
            except Exception:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res_temp = numpy.sum(all_res_temp, 0)
            all_res_temp = all_res_temp.reshape(
                (N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res_temp.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
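            # The central row of the whitening matrix is used as a temporal FIR
            # filter, normalized so that its coefficients sum to one.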
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            all_res_spac = all_res_spac.reshape(comm.size, N_e, N_e)
            spatial_whitening = numpy.sum(all_res_spac, 0)
            to_write['spatial'] = spatial_whitening

            print_and_log([
                "Found %gs without spikes for whitening matrices..." %
                (numpy.mean(all_silences) / params.rate)
            ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          to_write.keys(),
                          to_write,
                          compression=hdf5_compress)
        bfile.close()

    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = numpy.zeros(N_e, dtype=numpy.float32)
            for i in xrange(N_e):
                u = numpy.median(local_chunk[:, i], 0)
                thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u),
                                             0)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                  'r+',
                                  libver='earliest')
                bfile.pop('thresholds')
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds},
                                  compression=hdf5_compress)
                bfile.close()
            comm.Barrier()

    #if comm.rank == 0:
    #if not os.path.exists(plot_path):
    #    os.makedirs(plot_path)
    #N_elec = min(int(numpy.sqrt(data_file.N_e)), 5)
    #plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True,
    #              n_elec=N_elec, save=[plot_path, 'electrodes'])

    # Part 2: Basis
    numpy.random.seed(422)

    #################################################################
    file_out = params.get('data', 'file_out')
    alignment = params.getboolean('detection', 'alignment')
    isolation = params.getboolean('detection', 'isolation')
    over_factor = float(params.getint('detection', 'oversampling_factor'))
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    nodes, edges = get_nodes_and_edges(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('data', 'chunk_size')
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    if sign_peaks == 'both':
        max_elts_elec *= 2
    nb_elts = int(
        params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)

    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    if ignore_dead_times:
        all_dead_times = get_dead_times(params)
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    groups = {}
    for i in xrange(N_e):
        groups[i] = 0

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)

    chunks_to_load = all_chunks[comm.rank::comm.size]

    thresholds = io.load_data(params, 'thresholds')
    mads = io.load_data(params, 'mads')

    if alignment:
        cdata = numpy.linspace(-template_shift, template_shift,
                               int(over_factor * N_t))
        xdata = numpy.arange(-template_shift_2, template_shift_2 + 1)
        xoff = len(cdata) / 2.
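        # cdata is an oversampled time axis and xdata the raw sample axis around
        # a peak; both are used below to re-align waveforms by spline interpolation.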

    if isolation:
        yoff = numpy.array(range(0, N_t // 4) + range(3 * N_t // 4, N_t))
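        # Indices of the first and last quarters of the waveform: 'isolation'
        # only accepts snippets whose flanks stay below the detection threshold.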

    to_explore = xrange(comm.rank, nb_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):

        gidx = all_chunks[gidx]

        if ((elt_count_pos + elt_count_neg) < nb_elts):
            #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            #print "Extracting the peaks..."
            all_peaktimes = numpy.zeros(0, dtype=numpy.uint32)
            all_extremas = numpy.zeros(0, dtype=numpy.uint32)

            for i in xrange(N_e):

                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=True,
                                                  mpd=dist_peaks)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=False,
                                                  mpd=dist_peaks)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                                  thresholds[i],
                                                  valley=False,
                                                  mpd=dist_peaks)
                all_peaktimes = numpy.concatenate((all_peaktimes, peaktimes))
                all_extremas = numpy.concatenate(
                    (all_extremas,
                     i * numpy.ones(len(peaktimes), dtype=numpy.uint32)))

            #print "Removing the useless borders..."
            if alignment:
                local_borders = (template_shift_2,
                                 local_shape - template_shift_2)
            else:
                local_borders = (template_shift, local_shape - template_shift)
            idx = (all_peaktimes >= local_borders[0]) & (all_peaktimes <
                                                         local_borders[1])
            all_peaktimes = numpy.compress(idx, all_peaktimes)
            all_extremas = numpy.compress(idx, all_extremas)

            local_peaktimes = numpy.unique(all_peaktimes)

            if ignore_dead_times:
                indices = numpy.searchsorted(
                    all_dead_times, [t_offset, t_offset + local_shape])
                if indices[0] != indices[1]:
                    local_peaktimes = numpy.array(
                        list(
                            set(local_peaktimes + t_offset).difference(
                                all_dead_times[indices[0]:indices[1]])),
                        dtype=numpy.uint32) - t_offset
                    local_peaktimes = numpy.sort(local_peaktimes)

            if len(local_peaktimes) > 0:

                diff_times = local_peaktimes[-1] - local_peaktimes[0]
                all_times = numpy.zeros((N_e, diff_times + 1),
                                        dtype=numpy.bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    diff_times)

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                all_idx = numpy.take(local_peaktimes, argmax_peak)

                #print "Selection of the peaks with spatio-temporal masks..."
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    if sign_peaks == 'negative':
                        elec = numpy.argmin(local_chunk[peak])
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        elec = numpy.argmax(local_chunk[peak])
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if N_e == 1:
                            if local_chunk[peak] < 0:
                                negative_peak = True
                            elif local_chunk[peak] > 0:
                                negative_peak = False
                            elec = 0
                        else:
                            if numpy.abs(numpy.max(
                                    local_chunk[peak])) > numpy.abs(
                                        numpy.min(local_chunk[peak])):
                                elec = numpy.argmax(local_chunk[peak])
                                negative_peak = False
                            else:
                                elec = numpy.argmin(local_chunk[peak])
                                negative_peak = True

                    indices = numpy.take(inv_nodes, edges[nodes[elec]])
                    myslice = all_times[indices,
                                        min_times[midx]:max_times[midx]]
                    is_local_extrema = elec in all_extremas[all_peaktimes ==
                                                            peak]
                    if is_local_extrema and not myslice.any():
                        upper_bounds = max_elts_elec

                        if groups[elec] < upper_bounds:

                            if not alignment:
                                sub_mat = local_chunk[peak -
                                                      template_shift:peak +
                                                      template_shift + 1, elec]

                            elif alignment:
                                ydata = local_chunk[peak -
                                                    template_shift_2:peak +
                                                    template_shift_2 + 1, elec]
                                #try:
                                #   f = scipy.interpolate.UnivariateSpline(xdata, ydata, s=xdata.size * mads[elec]**2, k=3)
                                #except Exception:
                                f = scipy.interpolate.UnivariateSpline(xdata,
                                                                       ydata,
                                                                       s=0,
                                                                       k=3)
                                if negative_peak:
                                    rmin = (numpy.argmin(f(cdata)) -
                                            xoff) / over_factor
                                else:
                                    rmin = (numpy.argmax(f(cdata)) -
                                            xoff) / over_factor
                                ddata = numpy.linspace(rmin - template_shift,
                                                       rmin + template_shift,
                                                       N_t)

                                sub_mat = f(ddata).astype(numpy.float32)
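                                # The waveform is resampled on N_t points centered
                                # on the interpolated extremum (sub-sample alignment).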

                            if isolation:
                                to_accept = numpy.all(
                                    numpy.max(numpy.abs(sub_mat[yoff])) <=
                                    thresholds[elec])
                            else:
                                to_accept = True

                            if to_accept:
                                if negative_peak:
                                    elts_neg[:, elt_count_neg] = sub_mat
                                else:
                                    elts_pos[:, elt_count_pos] = sub_mat

                                if negative_peak:
                                    elt_count_neg += 1
                                else:
                                    elt_count_pos += 1

                        groups[elec] += 1
                        all_times[indices,
                                  min_times[midx]:max_times[midx]] = True

    sys.stderr.flush()

    if isolation:
        print_and_log([
            "Node %d has collected %d isolated waveforms" %
            (comm.rank, elt_count_pos + elt_count_neg)
        ], 'debug', logger)
    else:
        print_and_log([
            "Node %d has collected %d waveforms" %
            (comm.rank, elt_count_pos + elt_count_neg)
        ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    if comm.rank == 0:
        #DO PCA on elts and store the basis obtained.

        nb_waveforms = 0
        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

        if isolation:
            print_and_log([
                "Found %d isolated waveforms over %d requested" %
                (nb_waveforms, int(nb_elts * comm.size))
            ], 'default', logger)
        else:
            print_and_log([
                "Found %d waveforms over %d requested" %
                (nb_waveforms, int(nb_elts * comm.size))
            ], 'default', logger)

        if nb_waveforms == 0:
            print_and_log(
                ['No waveforms found! Are the data properly loaded??'],
                'error', logger)

        res = {}
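        # 'proj' maps waveforms onto the PCA space and 'rec' maps features back;
        # the median waveform and up to 1000 raw examples are stored alongside.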
        if sign_peaks in ['negative', 'both']:
            if len(gdata_neg) > 0:
                pca = PCA(output_dim)
                pca.fit(gdata_neg)
                res['proj'] = pca.components_.T.astype(numpy.float32)
            else:
                res['proj'] = numpy.identity(int(output_dim),
                                             dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_neg.shape[0]))[:1000]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            if len(gdata_pos) > 0:
                pca = PCA(output_dim)
                pca.fit(gdata_pos)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
            else:
                res['proj_pos'] = numpy.identity(int(output_dim),
                                                 dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_pos.shape[0]))[:1000]
            res['waveforms_pos'] = gdata_pos[idx, :]

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile, res.keys(), res, compression=hdf5_compress)
        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)

        bfile.close()

    comm.Barrier()

    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))
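        # The normalized median waveforms act as matched filters: the signal is
        # convolved with them below and the MAD-based thresholds are recomputed.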

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_neg,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds'],
                                      {'matched_thresholds': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_pos,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds_pos'],
                                      {'matched_thresholds_pos': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

    data_file.close()
Example #11
def main(params, nb_cpu, nb_gpu, use_gpu, extension):

    logger         = init_logging(params.logfile)
    logger         = logging.getLogger('circus.converting')
    data_file      = params.data_file
    file_out_suff  = params.get('data', 'file_out_suff')
    probe          = params.probe
    output_path    = params.get('data', 'file_out_suff') + extension + '.GUI'
    N_e            = params.getint('data', 'N_e')
    N_t            = params.getint('detection', 'N_t')
    erase_all      = params.getboolean('converting', 'erase_all')
    export_pcs     = params.get('converting', 'export_pcs')
    export_all     = params.getboolean('converting', 'export_all')
    if export_all and not params.getboolean('fitting', 'collect_all'):
        if comm.rank == 0:
            print_and_log(['Export unfitted spikes only if [fitting] collect_all is True'], 'error', logger)
        sys.exit(1)

    def generate_mapping(probe):
        p         = {}
        positions = []
        nodes     = []
        for key in probe['channel_groups'].keys():
            p.update(probe['channel_groups'][key]['geometry'])
            nodes     +=  probe['channel_groups'][key]['channels']
            positions += [p[channel] for channel in probe['channel_groups'][key]['channels']]
        idx       = numpy.argsort(nodes)
        positions = numpy.array(positions)[idx]
        return positions

    def get_max_loc_channel(params):
        nodes, edges    = get_nodes_and_edges(params)
        max_loc_channel = 0
        for key in edges.keys():
            if len(edges[key]) > max_loc_channel:
                max_loc_channel = len(edges[key])
        return max_loc_channel

    def write_results(path, params, extension):
        result     = io.get_results(params, extension)
        spikes     = numpy.zeros(0, dtype=numpy.uint64)
        clusters   = numpy.zeros(0, dtype=numpy.uint32)
        amplitudes = numpy.zeros(0, dtype=numpy.double)
        N_tm       = len(result['spiketimes'])
        for key in result['spiketimes'].keys():
            temp_id    = int(key.split('_')[-1])
            data       = result['spiketimes'].pop(key).astype(numpy.uint64)
            spikes     = numpy.concatenate((spikes, data))
            data       = result['amplitudes'].pop(key).astype(numpy.double)
            amplitudes = numpy.concatenate((amplitudes, data[:, 0]))
            clusters   = numpy.concatenate((clusters, temp_id*numpy.ones(len(data), dtype=numpy.uint32)))
        
        if export_all:
            print_and_log(["Last %d templates are unfitted spikes on all electrodes" %N_e], 'info', logger)
            garbage = io.load_data(params, 'garbage', extension)
            for key in garbage['gspikes'].keys():
                elec_id    = int(key.split('_')[-1])
                data       = garbage['gspikes'].pop(key).astype(numpy.uint64)
                spikes     = numpy.concatenate((spikes, data))
                amplitudes = numpy.concatenate((amplitudes, numpy.zeros(len(data))))
                clusters   = numpy.concatenate((clusters, (elec_id + N_tm)*numpy.ones(len(data), dtype=numpy.uint32)))                

        idx = numpy.argsort(spikes)
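        # phy expects spike_times, spike_templates and amplitudes to share the
        # same time-sorted ordering.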
        numpy.save(os.path.join(output_path, 'spike_templates'), clusters[idx])
        numpy.save(os.path.join(output_path, 'spike_times'), spikes[idx])
        numpy.save(os.path.join(output_path, 'amplitudes'), amplitudes[idx])
        return

    def write_templates(path, params, extension):

        max_loc_channel = get_max_loc_channel(params)
        templates       = io.load_data(params, 'templates', extension)
        N_tm            = templates.shape[1]//2
        if export_all:
            to_write    = numpy.zeros((N_tm + N_e, N_t, N_e), dtype=numpy.float32)
            mapping     = numpy.zeros((N_tm + N_e, max_loc_channel), dtype=numpy.int32)            
        else:
            to_write    = numpy.zeros((N_tm, N_t, N_e), dtype=numpy.float32)
            mapping     = numpy.zeros((N_tm, max_loc_channel), dtype=numpy.int32)

        for t in xrange(N_tm):
            tmp  = templates[:, t].toarray().reshape(N_e, N_t).T
            x, y = tmp.nonzero()
            to_write[t, x, y]                = tmp[x, y] 
            nb_loc                           = len(numpy.unique(y))
            mapping[t, numpy.arange(nb_loc)] = numpy.unique(y)

        if export_all:
            for t in xrange(N_tm, N_tm + N_e):
                mapping[t, 0] = N_e

        numpy.save(os.path.join(output_path, 'templates'), to_write.astype(numpy.single))
        numpy.save(os.path.join(output_path, 'templates_ind'), mapping.astype(numpy.double))
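        # templates.npy is written as (n_templates, n_samples, n_channels), with
        # templates_ind.npy giving the channels spanned by each template.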

        if SPARSE_TEMPLATES:

            n_channels_max = 0
            for t in xrange(N_tm):
                data = numpy.sum(numpy.sum(templates[:, t].toarray().reshape(N_e, N_t), 1) != 0) 
                if data > n_channels_max:
                    n_channels_max = data
            
            to_write_sparse    = numpy.zeros((N_tm, N_t, n_channels_max), dtype=numpy.float32)
            mapping_sparse     = numpy.zeros((N_tm, n_channels_max), dtype=numpy.int32)
            for t in xrange(N_tm):
                tmp                              = templates[:, t].toarray().reshape(N_e, N_t).T
                x, y                             = tmp.nonzero()
                nb_loc                           = len(numpy.unique(y))
                all_positions                    = numpy.zeros(len(y), dtype=numpy.int32)
                all_positions[numpy.unique(y)]   = numpy.arange(nb_loc, dtype=numpy.int32)
                pos                              = all_positions[y]
                to_write_sparse[t, x, pos]       = tmp[x, y] 
                mapping_sparse[t, numpy.arange(nb_loc)] = numpy.unique(y)


            numpy.save(os.path.join(output_path, 'sparse_templates'), to_write_sparse.astype(numpy.single))
            numpy.save(os.path.join(output_path, 'sparse_templates_channels'), mapping_sparse.astype(numpy.uint32))

        return N_tm

    def write_pcs(path, params, extension, mode=0):

        spikes          = numpy.load(os.path.join(output_path, 'spike_times.npy'))
        labels          = numpy.load(os.path.join(output_path, 'spike_templates.npy'))
        max_loc_channel = get_max_loc_channel(params)
        nb_features     = params.getint('whitening', 'output_dim')
        nodes, edges    = get_nodes_and_edges(params)
        N_total         = params.getint('data', 'N_total')
        templates       = io.load_data(params, 'templates', extension)
        N_tm            = templates.shape[1]//2
        if export_all:
            nb_templates = N_tm + N_e
        else:
            nb_templates = N_tm

        pc_features_ind = numpy.zeros((nb_templates, max_loc_channel), dtype=numpy.int32)            
        clusters        = io.load_data(params, 'clusters', extension)
        best_elec       = clusters['electrodes']
        if export_all:
            best_elec = numpy.concatenate((best_elec, numpy.arange(N_e)))
        inv_nodes        = numpy.zeros(N_total, dtype=numpy.int32)
        inv_nodes[nodes] = numpy.argsort(nodes)

        for count, elec in enumerate(best_elec):
            nb_loc                = len(edges[nodes[elec]])
            pc_features_ind[count, numpy.arange(nb_loc)] = inv_nodes[edges[nodes[elec]]]

        basis_proj, basis_rec = io.load_data(params, 'basis')

        to_process = numpy.arange(comm.rank, nb_templates, comm.size)

        all_offsets = numpy.zeros(nb_templates, dtype=numpy.int32)
        for target in xrange(nb_templates):
            if mode == 0:
                all_offsets[target] = len(numpy.where(labels == target)[0])
            elif mode == 1:
                all_offsets[target] = min(500, len(numpy.where(labels == target)[0]))

        all_paddings = numpy.concatenate(([0] , numpy.cumsum(all_offsets)))
        total_pcs   = numpy.sum(all_offsets)
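        # all_paddings[t] is the row offset of template t inside the
        # memory-mapped pc_features array created below.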

        pc_file     = os.path.join(output_path, 'pc_features.npy')
        pc_file_ids = os.path.join(output_path, 'pc_feature_spike_ids.npy')

        from numpy.lib.format import open_memmap

        if comm.rank == 0:
            pc_features = open_memmap(pc_file, shape=(total_pcs, nb_features, max_loc_channel), dtype=numpy.float32, mode='w+')
            if mode == 1:
                pc_ids = open_memmap(pc_file_ids, shape=(total_pcs, ), dtype=numpy.int32, mode='w+')

        comm.Barrier()
        pc_features = open_memmap(pc_file, mode='r+')
        if mode == 1:
            pc_ids = open_memmap(pc_file_ids, mode='r+')

        if comm.rank == 0:
            pbar = get_progressbar(len(to_process))

        all_idx = numpy.zeros(0, dtype=numpy.int32)
        for gcount, target in enumerate(to_process):

            count    = all_paddings[target]
            
            if mode == 1:
                idx  = numpy.random.permutation(numpy.where(labels == target)[0])[:500]
                pc_ids[count:count+len(idx)] = idx
            elif mode == 0:
                idx  = numpy.where(labels == target)[0]

            elec     = best_elec[target]
            indices  = inv_nodes[edges[nodes[elec]]]
            labels_i = target*numpy.ones(len(idx))
            times_i  = numpy.take(spikes, idx).astype(numpy.int64)
            sub_data = io.get_stas(params, times_i, labels_i, elec, neighs=indices, nodes=nodes, auto_align=False)
            
            pcs      = numpy.dot(sub_data, basis_proj)
            pcs      = numpy.swapaxes(pcs, 1,2)
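            # Snippets are projected on the PCA basis and reordered to
            # (spike, feature, local channel), the layout used by pc_features.npy.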
            if mode == 0:
                pc_features[idx, :, :len(indices)] = pcs                    
            elif mode == 1:
                pc_features[count:count+len(idx), :, :len(indices)] = pcs

            if comm.rank == 0:
                pbar.update(gcount)

        if comm.rank == 0:
            pbar.finish()

        comm.Barrier()

        if comm.rank == 0:
            numpy.save(os.path.join(output_path, 'pc_feature_ind'), pc_features_ind.astype(numpy.uint32)) #n_templates, n_loc_chan

    do_export = True
    if comm.rank == 0:
        if os.path.exists(output_path):
            if not erase_all:
                do_export = query_yes_no(Fore.WHITE + "Export already made! Do you want to erase everything?", default=None)

            if do_export:
                if os.path.exists(os.path.abspath('.phy')):
                    shutil.rmtree(os.path.abspath('.phy'))
                shutil.rmtree(output_path)
        if do_export == True:
            comm.bcast(numpy.array([1], dtype=numpy.int32), root=0)
        elif do_export == False:
            comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
    else:
        do_export = bool(comm.bcast(numpy.array([0], dtype=numpy.int32), root=0))
    
    comm.Barrier()

    if do_export:

        if comm.rank == 0:
            os.makedirs(output_path)
            print_and_log(["Exporting data for the phy GUI with %d CPUs..." %nb_cpu], 'info', logger)
        
            if params.getboolean('whitening', 'spatial'):
                whitening_mat = io.load_data(params, 'spatial_whitening').astype(numpy.double)
                numpy.save(os.path.join(output_path, 'whitening_mat'), whitening_mat)
                numpy.save(os.path.join(output_path, 'whitening_mat_inv'), numpy.linalg.inv(whitening_mat))
            else:
                numpy.save(os.path.join(output_path, 'whitening_mat'), numpy.eye(N_e))

            numpy.save(os.path.join(output_path, 'channel_positions'), generate_mapping(probe).astype(numpy.double))
            nodes, edges   = get_nodes_and_edges(params)
            numpy.save(os.path.join(output_path, 'channel_map'), nodes.astype(numpy.int32))

            write_results(output_path, params, extension)    
            N_tm = write_templates(output_path, params, extension)
            similarities = h5py.File(file_out_suff + '.templates%s.hdf5' %extension, 'r+', libver='latest').get('maxoverlap')
            norm = N_e*N_t

            if export_all:
                to_write = numpy.zeros((N_tm + N_e, N_tm + N_e), dtype=numpy.single)
                to_write[:N_tm, :N_tm] = (similarities[:N_tm, :N_tm]/norm).astype(numpy.single)
            else:
                to_write = (similarities[:N_tm, :N_tm]/norm).astype(numpy.single)
            numpy.save(os.path.join(output_path, 'similar_templates'), to_write)
        
        comm.Barrier()

        make_pcs = 2
        if comm.rank == 0:

            if export_pcs == 'prompt':
                key = ''
                while key not in ['a', 's', 'n']:
                    print(Fore.WHITE + "Do you want SpyKING CIRCUS to export PCs? (a)ll / (s)ome / (n)o")
                    key = raw_input('')
            else:
                key = export_pcs

            if key == 'a':
                make_pcs = 0
                comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
            elif key == 's':
                make_pcs = 1
                comm.bcast(numpy.array([1], dtype=numpy.int32), root=0)
            elif key == 'n':
                comm.bcast(numpy.array([2], dtype=numpy.int32), root=0)
                if os.path.exists(os.path.join(output_path, 'pc_features.npy')):
                    os.remove(os.path.join(output_path, 'pc_features.npy'))
                if os.path.exists(os.path.join(output_path, 'pc_feature_ind.npy')):
                    os.remove(os.path.join(output_path, 'pc_feature_ind.npy'))
        else:
            make_pcs = comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
            make_pcs = make_pcs[0]

        comm.Barrier()
        if make_pcs < 2:
            write_pcs(output_path, params, extension, make_pcs)
Example #12
def main(argv=None):
    
    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    header += '''Utility to group files within several folders into a single
virtual folder, such that they can be processed together with the
multi-files mode.
If you also want to process .dead or .trig files, in order to later
concatenate artefacts, please use the -d or -t options.
    '''

    parser = argparse.ArgumentParser(description=header,
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('folders', help='text file with the list of folders to consider')
    parser.add_argument('extension', help='file extension to consider within folders')
    
    parser.add_argument('-o', '--output', help='name of the output folder [default is output]', default='output')
    parser.add_argument('-d', '--dead', help='Search for all .dead files', action='store_true')
    parser.add_argument('-t', '--trig', help='Search for all .trig files', action='store_true')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)


    folders_file = os.path.abspath(args.folders)
    output      = os.path.abspath(args.output)
    extension   = args.extension

    filename, ext = os.path.splitext(os.path.basename(folders_file))

    logger = init_logging(filename + '.log')
    logger = logging.getLogger(__name__)

    if not os.path.exists(folders_file):
        print_and_log(['The folder file %s does not exist!' %folders_file], 'error', logger)
        sys.exit(0)

    try:
        folders = []
        myfile = open(folders_file, 'r')
        lines  = myfile.readlines()
        myfile.close()
        for l in lines:
            folders += [os.path.abspath(l.strip())]
    except Exception:
        print_and_log(['Check the syntax of the folder file'], 'error', logger)
        sys.exit(0)        

    do_folders = True

    if os.path.exists(output):
        do_folders = query_yes_no(Fore.WHITE + "Folder %s already exists! Do you want to erase everything?" %output, default=None)
        if not do_folders:
            sys.exit(0)
        else:
            shutil.rmtree(output)

    os.makedirs(output)

    for count, folder in enumerate(folders):
        files = os.listdir(folder)
        for file in files:
            _, ext = os.path.splitext(file)
            ext = ext.strip('.')
            if (ext.lower() == extension.lower()) or (args.dead and ext.lower() == 'dead') or (args.trig and ext.lower()== 'trig'):
                original_file = os.path.join(folder, file)
                linked_file = os.path.join(output, 'sc_{c}_{f}'.format(c=count, f=os.path.basename(original_file)))
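                # Prefix the link name with the folder index so identically named
                # files coming from different folders do not collide.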
                if not os.path.exists(linked_file):
                    os.symlink(original_file, linked_file)
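# A hypothetical invocation sketch: 'folders.txt' lists one recording folder per
# line, and every '.raw' file (plus '.dead' files, because of --dead) would be
# symlinked into a fresh 'output' folder.
#
#     main(['folders.txt', 'raw', '--output', 'output', '--dead'])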
Example #13
def main(params, nb_cpu, nb_gpu, use_gpu):
    numpy.random.seed(426236)
    #params         = detect_memory(params)
    parallel_hdf5 = get_parallel_hdf5_flag(params)
    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.extracting')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detection', 'N_t')
    N_total = params.nb_channels
    template_shift = params.getint('detection', 'template_shift')
    chunk_size = params.getint('data', 'chunk_size')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('extracting', 'safety_time')
    max_elts_temp = params.getint('extracting', 'max_elts')
    output_dim = params.getfloat('extracting', 'output_dim')
    noise_thr = params.getfloat('extracting', 'noise_thr')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    blosc_compress = params.getboolean('data', 'blosc_compress')
    tmp_limits = params.get('fitting',
                            'amp_limits').replace('(',
                                                  '').replace(')',
                                                              '').split(',')
    amp_limits = map(float, tmp_limits)
    elt_count = 0
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if comm.rank == 0:
        print_and_log(["Extracting templates from already found clusters..."],
                      'default', logger)

    thresholds = io.load_data(params, 'thresholds')
    basis_proj, basis_rec = io.load_data(params, 'basis')
    clusters, spiketimes, N_clusters = io.load_data(params, 'spike-cluster')
    inv_clusters = numpy.zeros(clusters.max() + 1, dtype=numpy.int32)
    inv_clusters[numpy.unique(clusters)] = numpy.argsort(
        numpy.unique(clusters))
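    # inv_clusters renumbers the original cluster labels into a contiguous
    # range 0 .. N_clusters - 1.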

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    result = {}
    for i in xrange(N_clusters):
        result['data_tmp_' + str(i)] = numpy.zeros(
            (0, N_e * basis_proj.shape[1]), dtype=numpy.float32)
        result['times_' + str(i)] = numpy.zeros(0, dtype=numpy.int32)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(numpy.arange(nb_chunks))

    nb_templates = numpy.sum(
        comm.rank == numpy.mod(numpy.arange(N_clusters), comm.size))
    nb_elts = max_elts_temp * nb_templates

    to_explore = all_chunks

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gidx in all_chunks:

        if (elt_count < nb_elts):
            #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            #print "Extracting the peaks..."
            idx = numpy.where((spiketimes >= gidx * chunk_size)
                              & (spiketimes < (gidx + 1) * chunk_size))[0]
            local_offset = t_offset
            local_peaktimes = spiketimes[idx] - local_offset
            # Cluster labels must be extracted with the same spike indices, before
            # `idx` is reused as a boolean mask over local_peaktimes.
            local_clusters = inv_clusters[clusters[idx]]

            #print "Removing the useless borders..."
            local_borders = (template_shift, chunk_size - template_shift)
            idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                           local_borders[1])
            local_peaktimes = local_peaktimes[idx]
            local_clusters = local_clusters[idx]

            if len(local_peaktimes) > 0:
                all_times = numpy.zeros(
                    (N_e, local_peaktimes[-1] - local_peaktimes[0] + 1),
                    dtype=numpy.bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    local_peaktimes[-1] - local_peaktimes[0])

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                clusters_id = local_clusters[argmax_peak]
                local_peaktimes = local_peaktimes[argmax_peak]

                #print "Selection of the peaks with spatio-temporal masks..."
                for idx in xrange(len(local_peaktimes)):

                    if elt_count == nb_elts:
                        break

                    temp = clusters_id[idx]

                    if numpy.mod(temp, comm.size) == comm.rank:

                        elec = numpy.argmin(local_chunk[local_peaktimes[idx]])
                        indices = inv_nodes[edges[nodes[elec]]]
                        myslice = all_times[indices,
                                            min_times[idx]:max_times[idx]]
                        peak = local_peaktimes[idx]
                        if not myslice.any():
                            if (len(result['data_tmp_' + str(temp)]) <
                                    max_elts_temp):
                                elt_count += 1
                                sub_mat = local_chunk[peak -
                                                      template_shift:peak +
                                                      template_shift + 1, :]
                                sub_mat = numpy.dot(basis_rec, sub_mat)
                                nx, ny = sub_mat.shape
                                sub_mat = sub_mat.reshape((1, nx * ny))
                                result['data_tmp_' + str(temp)] = numpy.vstack(
                                    (result['data_tmp_' + str(temp)], sub_mat))
                                to_add = numpy.array([peak + local_offset],
                                                     dtype=numpy.int32)
                                result['times_' +
                                       str(temp)] = numpy.concatenate(
                                           (result['times_' + str(temp)],
                                            to_add))
                            all_times[indices,
                                      min_times[idx]:max_times[idx]] = True

    total_nb_elts = 0
    for temp in xrange(N_clusters):
        total_nb_elts += len(result['data_tmp_' + str(temp)])

    gdata = gather_array(numpy.array([total_nb_elts], dtype=numpy.float32),
                         comm, 0)
    if comm.rank == 0:
        print_and_log([
            "Found %d spikes over %d requested" %
            (int(numpy.sum(gdata)), int(nb_elts))
        ], 'default', logger)

    #print "Spikes extracted in", time.time() - t_start, "s"

    comm.Barrier()

    local_nb_clusters = 0
    for temp in xrange(comm.rank, N_clusters, comm.size):
        if len(result['data_tmp_' + str(temp)]) > 0:
            local_nb_clusters += 1

    #print total_nb_clusters, "found in", time.time() - t_start, "s"
    gdata3 = gather_array(
        numpy.array([local_nb_clusters], dtype=numpy.float32), comm, 0)

    comm.Barrier()
    if comm.rank == 0:
        print_and_log(["Extracting the templates..."], 'default', logger)

    total_nb_clusters = int(
        comm.bcast(numpy.array([int(numpy.sum(gdata3))], dtype=numpy.int32),
                   root=0)[0])
    offsets = numpy.zeros(comm.size, dtype=numpy.int32)
    for i in xrange(comm.size - 1):
        offsets[i + 1] = comm.bcast(numpy.array([local_nb_clusters],
                                                dtype=numpy.int32),
                                    root=i)

    if parallel_hdf5:
        node_pad = numpy.sum(offsets[:comm.rank + 1])
        hfile = h5py.File(file_out_suff + '.templates.hdf5',
                          'w',
                          driver='mpio',
                          comm=comm,
                          libver='earliest')
        norms = hfile.create_dataset('norms',
                                     shape=(2 * total_nb_clusters, ),
                                     dtype=numpy.float32,
                                     chunks=True)
        electrodes = hfile.create_dataset('electrodes',
                                          shape=(total_nb_clusters, ),
                                          dtype=numpy.int32,
                                          chunks=True)
        amps_lims = hfile.create_dataset('limits',
                                         shape=(total_nb_clusters, 2),
                                         dtype=numpy.float32,
                                         chunks=True)
        g_count = node_pad
        g_offset = total_nb_clusters
    else:
        node_pad = 0
        hfile = h5py.File(file_out_suff + '.templates-%d.hdf5' % comm.rank,
                          'w',
                          libver='earliest')
        electrodes = hfile.create_dataset('electrodes',
                                          shape=(local_nb_clusters, ),
                                          dtype=numpy.int32,
                                          chunks=True)
        norms = hfile.create_dataset('norms',
                                     shape=(2 * local_nb_clusters, ),
                                     dtype=numpy.float32,
                                     chunks=True)
        amps_lims = hfile.create_dataset('limits',
                                         shape=(local_nb_clusters, 2),
                                         dtype=numpy.float32,
                                         chunks=True)
        g_count = 0
        g_offset = local_nb_clusters

    cfile = h5py.File(file_out_suff + '.clusters-%d.hdf5' % comm.rank,
                      'w',
                      libver='earliest')
    count_templates = node_pad

    temp_x = numpy.zeros(0, dtype=numpy.int32)
    temp_y = numpy.zeros(0, dtype=numpy.int32)
    temp_data = numpy.zeros(0, dtype=numpy.float32)

    to_explore = xrange(comm.rank, N_clusters, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)
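    # Template construction: for each local cluster, the first component is the median
    # of the collected snippets (back-projected through basis_rec), the per-spike
    # amplitudes are the projections onto that component, and the second component is
    # the first PCA axis of the residuals. Both components are stored as sparse
    # (temp_x, temp_y, temp_data) triplets plus their norms and amplitude limits.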

    for temp in to_explore:
        n_data = len(result['data_tmp_' + str(temp)])
        if n_data > 0:
            data = result['data_tmp_' + str(temp)].reshape(
                n_data, basis_proj.shape[1], N_e)
            first_component = numpy.median(data, axis=0)
            tmp_templates = numpy.dot(first_component.T, basis_rec)
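            # NOTE: `tmpidx` and `shift` are not defined in this excerpt. They appear to
            # locate the template peak and its temporal offset, presumably something like
            #     tmpidx = numpy.where(tmp_templates == tmp_templates.min())
            #     shift = template_shift - tmpidx[1][0]
            # (assumed reconstruction; note also that `indices` below still holds its
            # value from the collection loop when `electrodes[g_count]` is assigned).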
            electrodes[g_count] = indices[tmpidx[0][0]]
            indices = inv_nodes[edges[nodes[electrodes[-1]]]]
            templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            if shift > 0:
                templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                templates[indices, :] = tmp_templates

            templates = templates.flatten()
            dx = templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y,
                 count_templates * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, templates[dx]))

            norms[g_count] = numpy.sqrt(
                numpy.sum(templates.flatten()**2) / (N_e * N_t))
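            # Per-spike amplitudes are the projections of each snippet onto the first
            # component; their spread (10x the median absolute deviation) and a
            # noise-derived physical limit define the [amp_min, amp_max] fitting range.
            # The residuals left after removing the fitted first component feed the
            # second (PCA) component below.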

            x, y, z = data.shape
            data_flat = data.reshape(x, y * z)
            first_flat = first_component.reshape(y * z, 1)
            amplitudes = numpy.dot(data_flat, first_flat)
            amplitudes /= numpy.sum(first_flat**2)
            for i in xrange(x):
                data_flat[i, :] -= amplitudes[i] * first_flat[:, 0]

            variations = 10 * numpy.median(
                numpy.abs(amplitudes - numpy.median(amplitudes)))
            physical_limit = noise_thr * (
                -thresholds[indices[tmpidx[0][0]]]) / tmp_templates.min()
            amp_min = max(physical_limit,
                          numpy.median(amplitudes) - variations)
            amp_max = min(amp_limits[1], numpy.median(amplitudes) + variations)
            amps_lims[g_count] = [amp_min, amp_max]

            if len(data_flat) > 1:
                pca = PCA(1)
                res_pca = pca.fit_transform(data_flat.astype(numpy.double))
                second_component = pca.components_.T.astype(
                    numpy.float32).reshape(y, z)
            else:
                second_component = data_flat.reshape(y, z) / numpy.sum(
                    data_flat**2)

            tmp_templates = numpy.dot(second_component.T, basis_rec)
            offset = total_nb_clusters + count_templates
            sub_templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            if shift > 0:
                sub_templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                sub_templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                sub_templates[indices, :] = tmp_templates

            sub_templates = sub_templates.flatten()
            dx = sub_templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y, offset * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, sub_templates[dx]))

            norms[g_count + g_offset] = numpy.sqrt(
                numpy.sum(sub_templates.flatten()**2) / (N_e * N_t))

            count_templates += 1
            g_count += 1

        io.write_datasets(cfile,
                          to_write,
                          result,
                          ielec,
                          compress=hdf5_compress)

    #At the end we should have a templates variable to store.
    cfile.close()
    del result, templates, amps_lims
    comm.Barrier()

    #We need to gather the sparse arrays
    temp_x = gather_array(temp_x, comm, dtype='int32', compress=blosc_compress)
    temp_y = gather_array(temp_y, comm, dtype='int32', compress=blosc_compress)
    temp_data = gather_array(temp_data, comm, compress=blosc_compress)

    if parallel_hdf5:
        if comm.rank == 0:
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in xrange(comm.size)
            ]
            cfile = h5py.File(file_out_suff + '.clusters.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(cfile, ['electrodes'],
                              {'electrodes': electrodes[:]},
                              compress=hdf5_compress)
            for i in xrange(comm.size):
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile,
                                      to_write,
                                      rs[i],
                                      j,
                                      compress=hdf5_compress)
                rs[i].close()
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            cfile.close()
        hfile.close()
    else:
        hfile.close()
        if comm.rank == 0:
            ts = [
                h5py.File(file_out_suff + '.templates-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in xrange(comm.size)
            ]
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in xrange(comm.size)
            ]
            result = {}
            hfile = h5py.File(file_out_suff + '.templates.hdf5',
                              'w',
                              libver='earliest')
            cfile = h5py.File(file_out_suff + '.clusters.hdf5',
                              'w',
                              libver='earliest')
            electrodes = hfile.create_dataset('electrodes',
                                              shape=(total_nb_clusters, ),
                                              dtype=numpy.int32,
                                              chunks=True)
            norms = hfile.create_dataset('norms',
                                         shape=(2 * total_nb_clusters, ),
                                         dtype=numpy.float32,
                                         chunks=True)
            amplitudes = hfile.create_dataset('limits',
                                              shape=(total_nb_clusters, 2),
                                              dtype=numpy.float32,
                                              chunks=True)
            count = 0
            for i in xrange(comm.size):
                loc_temp = ts[i].get('templates')
                # `loc_norms` is not defined in this excerpt; assume it is the per-rank
                # 'norms' dataset written in the loop above.
                loc_norms = ts[i].get('norms')[:]
                middle = loc_temp.shape[2] // 2
                norms[count:count + middle] = loc_norms[:middle]
                # second components go in the upper half, offset by total_nb_clusters
                norms[total_nb_clusters + count:total_nb_clusters + count +
                      middle] = loc_norms[middle:]
                electrodes[count:count + middle] = ts[i].get('electrodes')
                amplitudes[count:count + middle] = ts[i].get('limits')
                count += middle
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile,
                                      to_write,
                                      rs[i],
                                      j,
                                      compress=hdf5_compress)
                ts[i].close()
                rs[i].close()
                os.remove(file_out_suff + '.templates-%d.hdf5' % i)
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            io.write_datasets(cfile, ['electrodes'],
                              {'electrodes': electrodes[:]},
                              compress=hdf5_compress)
            hfile.close()
            cfile.close()

    if comm.rank == 0:
        hfile = h5py.File(file_out_suff + '.templates.hdf5',
                          'r+',
                          libver='earliest')
        hfile.create_dataset('temp_x', data=temp_x)
        hfile.create_dataset('temp_y', data=temp_y)
        hfile.create_dataset('temp_data', data=temp_data)
        hfile.create_dataset('temp_shape',
                             data=numpy.array(
                                 [N_e, N_t, 2 * total_nb_clusters],
                                 dtype=numpy.int32))
        hfile.close()

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Merging similar templates..."], 'default', logger)

    merged1 = algo.merging_cc(params, parallel_hdf5)

    comm.Barrier()
    if remove_mixture:
        if comm.rank == 0:
            print_and_log(["Removing mixtures..."], 'default', logger)
        merged2 = algo.delete_mixtures(params, parallel_hdf5)
    else:
        merged2 = [0, 0]

    if comm.rank == 0:
        print_and_log([
            "Number of global merges    : %d" % merged1[1],
            "Number of mixtures removed : %d" % merged2[1]
        ], 'info', logger)

    comm.Barrier()
    io.get_overlaps(params, erase=True, parallel_hdf5=parallel_hdf5)

    data_file.close()
Example #14
0
def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    logger         = init_logging(params.logfile)
    logger         = logging.getLogger('circus.fitting')
    data_file      = params.data_file
    data_file.open()
    N_e            = params.getint('data', 'N_e')
    N_total        = params.nb_channels
    N_t            = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out       = params.get('data', 'file_out')
    file_out_suff  = params.get('data', 'file_out_suff')
    sign_peaks     = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh   = params.getfloat('detection', 'spike_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening  = params.getboolean('whitening', 'spatial')
    chunk_size     = params.getint('fitting', 'chunk_size')
    gpu_only       = params.getboolean('fitting', 'gpu_only')
    nodes, edges   = get_nodes_and_edges(params)
    tmp_limits     = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits     = list(map(float, tmp_limits))
    amp_auto       = params.getboolean('fitting', 'amp_auto')
    space_explo    = params.getfloat('fitting', 'space_explo')
    nb_chances     = params.getint('fitting', 'nb_chances')
    max_chunk      = params.getfloat('fitting', 'max_chunk')
    noise_thr      = params.getfloat('clustering', 'noise_thr')
    collect_all    = params.getboolean('fitting', 'collect_all')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes         = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes]  = numpy.argsort(nodes)
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank//nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates  = io.load_data_memshared(params, 'templates', normalize=True, transpose=True)
        N_tm, x    = templates.shape
    else:
        templates  = io.load_data(params, 'templates')
        x, N_tm    = templates.shape

    temp_2_shift   = 2*template_shift
    full_gpu       = use_gpu and gpu_only
    n_tm           = N_tm//2
    n_scalar       = N_e*N_t
    last_spikes    = numpy.zeros((n_tm, 1), dtype=numpy.int32)
    temp_window    = numpy.arange(-template_shift, template_shift+1)

    if not amp_auto:
        amp_limits       = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits       = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')
    
    if not SHARED_MEMORY:
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg  = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg))* len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos  = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos))* len(waveform_pos))
            matched_tresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        dead_times = numpy.loadtxt(params.get('triggers', 'dead_file'))
        if len(dead_times.shape) == 1:
            dead_times = dead_times.reshape(1, 2)
        dead_in_ms = params.getboolean('triggers', 'dead_in_ms')
        if dead_in_ms:
            dead_times *= numpy.int64(data_file.sampling_rate*1e-3)
        dead_times = dead_times.astype(numpy.int64)
        all_dead_times = []
        for i in xrange(len(dead_times)):
            all_dead_times += range(dead_times[i, 0], dead_times[i, 1])

    thresholds = io.load_data(params, 'thresholds')


    if collect_all:
        neighbors = {}
        for i in xrange(n_tm):
            tmp  = templates[i, :].toarray().reshape(N_e, N_t) * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, 1) != 0)[0]
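        # neighbors[i] lists the electrodes on which template i has non-zero support;
        # in collect_all mode it is used to mask out detected peaks already explained
        # by a fitted template before collecting "garbage" (unfitted) spikes.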

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string   = ''

    
    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" %(comm.size)
        else:
            info_string = "using %d CPUs" %(comm.size)

    comm.Barrier()

    c_overlap  = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    N_over     = int(numpy.sqrt(over_shape[0]))
    S_over     = over_shape[1]
    ## If the number of overlaps is different from templates, we need to recompute them
    if N_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap  = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over     = int(numpy.sqrt(over_shape[0]))
        S_over     = over_shape[1]

    if SHARED_MEMORY:
        c_overs    = io.load_data_memshared(params, 'overlaps', nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    else:
        c_overlap  = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_x     = c_overlap.get('over_x')[:]
        over_y     = c_overlap.get('over_y')[:]
        over_data  = c_overlap.get('over_data')[:]
        over_shape = c_overlap.get('over_shape')[:]
        c_overlap.close()

        # To be faster, we rearrange the overlaps into a dictionary. This has a cost: twice the memory usage for
        # a short period of time.
        c_overs   = {}
        overlaps  = scipy.sparse.csr_matrix((over_data, (over_x, over_y)), shape=(over_shape[0], over_shape[1]))
        del over_x, over_y, over_data
        
        for i in xrange(N_over):
            c_overs[i] = overlaps[i*N_over:(i+1)*N_over]
        del overlaps

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." %(info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening  = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger)
            for i in xrange(N_over):
                if i in c_overs:
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks          = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' %comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' %comm.rank, 'wb')
    comm.Barrier()
    templates_file  = open(file_out_suff + '.templates-%d.data' %comm.rank, 'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' %comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file  = open(file_out_suff + '.gtemplates-%d.data' %comm.rank, 'wb')
        comm.Barrier()


    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    last_chunk_size = 0

    to_explore = xrange(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)
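    # Fitting loop: for every chunk assigned to this rank, detect peaks, project the
    # surrounding snippets onto all templates (matrix b), then run a greedy
    # template-matching pursuit that accepts a (template, time) pair only if its
    # normalized amplitude falls within the template's limits, and subtracts the
    # fitted contribution from b through the precomputed overlaps.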

    for gcount, gidx in enumerate(to_explore):
        #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        ## We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift]

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last  = data_file.is_last_chunk(gidx, nb_chunks)

        if is_last:
            padding = (-2*template_shift, 0)
        elif is_first:
            padding = (0, 2*template_shift)
        else:
            padding = (-2*template_shift, 2*template_shift)
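        # Chunks are read with an extra 2 * template_shift of padding so spikes close
        # to a chunk border can still be fitted; the first and last chunks are only
        # padded on their inner side.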

        result       = {'spiketimes' : [], 'amplitudes' : [], 'templates' : []}

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)           
        len_chunk             = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        #print "Extracting the peaks..."

        if collect_all:
            all_found_spikes = {}
            for i in xrange(N_e):
                all_found_spikes[i] = []

        local_peaktimes = numpy.zeros(0, dtype=numpy.int32)

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_pos[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_neg[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]), thresholds[i], valley=False)                    
                local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes)) 
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()


            
        local_peaktimes = numpy.unique(local_peaktimes)

        if ignore_dead_times:
            local_peaktimes = numpy.array(list(set(local_peaktimes + t_offset).difference(all_dead_times)), dtype=numpy.int32) - t_offset
            local_peaktimes = numpy.sort(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders   = (template_shift, len_chunk - template_shift)
        idx             = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in xrange(N_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.int32)

                if ignore_dead_times:
                    all_found_spikes[i] = numpy.array(list(set(all_found_spikes[i] + t_offset).difference(all_dead_times)), dtype=numpy.int32) - t_offset
                    all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx                 = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        n_t             = len(local_peaktimes)
        all_indices     = numpy.arange(n_t)
                            

        if full_gpu:
        #   all_indices = cmt.CUDAMatrix(all_indices)
            tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, n_t)), copy_on_host=False)


        if n_t > 0:
            #print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."     

            if collect_all:
                c_local_chunk = local_chunk.copy()

            local_chunk = local_chunk.T.ravel()
            sub_mat     = numpy.zeros((N_e*(2*template_shift+1), n_t), dtype=numpy.float32)

            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate((slice_indices, len_chunk*idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk, slice_indices + idx)

            del local_chunk

            if use_gpu: 
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b       = cmt.sparse_dot(templates, sub_mat)
            else:
                b       = templates.dot(sub_mat)                

            del sub_mat

            local_offset = padding[0] + t_offset
            local_bounds = (temp_2_shift, len_chunk - temp_2_shift)
            all_spikes   = local_peaktimes + local_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            #b           = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            failure     = numpy.zeros(n_t, dtype=numpy.int32)

            if full_gpu:
                mask     = numpy.zeros((2*n_tm, n_t), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                data     = cmt.empty(mask.shape)
                patch_gpu= b.shape[1] == 1
            else:
                mask     = numpy.ones((n_tm, n_t), dtype=numpy.float32)
                sub_b    = b[:n_tm, :]

            min_time     = local_peaktimes.min()
            max_time     = local_peaktimes.max()
            local_len    = max_time - min_time + 1
            min_times    = numpy.maximum(local_peaktimes - min_time - temp_2_shift, 0)
            max_times    = numpy.minimum(local_peaktimes - min_time + temp_2_shift + 1, max_time - min_time)
            max_n_t      = int(space_explo*(max_time-min_time+1)//(2*temp_2_shift + 1))

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, N_e), dtype=numpy.bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in xrange(N_e):
                    c_all_times[all_found_spikes[i], i] = True
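            # Matching pursuit: iterate until every peak is fitted or has failed
            # nb_chances times. Each pass sorts the masked projections, greedily picks a
            # set of temporally well-separated peaks, keeps those whose normalized
            # amplitude lies within the template limits, and updates b for neighbouring
            # peaks via the template overlaps.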
                    
            while (numpy.mean(failure) < nb_chances):

                if full_gpu:
                    gpu_mask    = cmt.CUDAMatrix(mask, copy_on_host=False)
                    b.mult(gpu_mask, data)
                    tmp_mat     = data.max(0)
                    argmax_bi   = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat
                else:
                    data        = sub_b * mask
                    argmax_bi   = numpy.argsort(numpy.max(data, 0))[::-1]

                while (len(argmax_bi) > 0):

                    subset          = []
                    indices         = []
                    all_times       = numpy.zeros(local_len, dtype=numpy.bool)

                    for count, idx in enumerate(argmax_bi):
                        myslice = all_times[min_times[idx]:max_times[idx]]
                        if not myslice.any():
                            subset  += [idx]
                            indices += [count]
                            all_times[min_times[idx]:max_times[idx]] = True
                        if len(subset) > max_n_t:
                            break

                    subset    = numpy.array(subset, dtype=numpy.int32)
                    argmax_bi = numpy.delete(argmax_bi, indices)

                    if full_gpu:
                        b_array = b.asarray()
                        sub_b   = b_array[:n_tm, :]

                    inds_t, inds_temp = subset, numpy.argmax(numpy.take(sub_b, subset, axis=1), 0)

                    if full_gpu:
                        best_amp  = sub_b[inds_temp, inds_t]/n_scalar
                        best_amp2 = b_array[inds_temp + n_tm, inds_t]/n_scalar
                    else:
                        
                        best_amp  = sub_b[inds_temp, inds_t]/n_scalar
                        best_amp2 = b[inds_temp + n_tm, inds_t]/n_scalar

                    mask[inds_temp, inds_t] = 0

                    best_amp_n   = best_amp/numpy.take(norm_templates, inds_temp)
                    best_amp2_n  = best_amp2/numpy.take(norm_templates, inds_temp + n_tm)

                    all_idx      = ((best_amp_n >= amp_limits[inds_temp, 0]) & (best_amp_n <= amp_limits[inds_temp, 1]))
                    to_keep      = numpy.where(all_idx)[0]
                    to_reject    = numpy.where(~all_idx)[0]
                    ts           = numpy.take(local_peaktimes, inds_t[to_keep])
                    good         = (ts >= local_bounds[0]) & (ts < local_bounds[1])

                    # We reduce to only the good times that will be kept
                    #to_keep      = to_keep[good]
                    #ts           = ts[good]
                    
                    if len(ts) > 0:
                        if full_gpu:
                            tmp  = cmt.CUDAMatrix(numpy.ones((len(ts), 1)), copy_on_host=False)
                            tmp3 = cmt.CUDAMatrix(-ts.reshape((len(ts), 1)), copy_on_host=False)
                            tmp  = tmp.dot(tmp_gpu)
                            tmp.add_col_vec(tmp3)
                            condition = cmt.empty(tmp.shape)
                            cmt.abs(tmp, condition).less_than(temp_2_shift + 1)
                            condition = condition.asarray().astype(numpy.bool)
                            tmp       = tmp.asarray().astype(numpy.int32)
                        else:
                            tmp      = numpy.dot(numpy.ones((len(ts), 1), dtype=numpy.int32), local_peaktimes.reshape((1, n_t)))
                            tmp     -= ts.reshape((len(ts), 1))
                            condition = numpy.abs(tmp) <= temp_2_shift
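                        # For each accepted spike, subtract the contribution of its
                        # template (first and second components, scaled by the fitted
                        # amplitudes) from the projections b of all peaks within
                        # temp_2_shift samples, using the precomputed overlap matrices.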

                        for count, keep in enumerate(to_keep):
                            
                            idx_b    = numpy.compress(condition[count, :], all_indices)
                            ytmp     = tmp[count, condition[count, :]] + temp_2_shift
                            
                            indices  = numpy.zeros((S_over, len(ytmp)), dtype=numpy.float32)
                            indices[ytmp, numpy.arange(len(ytmp))] = 1

                            if full_gpu: 
                                indices  = cmt.CUDAMatrix(indices, copy_on_host=False)
                                if patch_gpu:
                                    b_lines  = b.get_col_slice(0, b.shape[0])
                                else:
                                    b_lines  = b.get_col_slice(idx_b[0], idx_b[-1]+1)

                                tmp1 = cmt.sparse_dot(c_overs[inds_temp[keep]], indices, mult=-best_amp[keep])
                                tmp2 = cmt.sparse_dot(c_overs[inds_temp[keep] + n_tm], indices, mult=-best_amp2[keep])
                                b_lines.add(tmp1.add(tmp2))
                                del tmp1, tmp2
                            else:
                                tmp1   = c_overs[inds_temp[keep]].multiply(-best_amp[keep]).dot(indices)
                                tmp2   = c_overs[inds_temp[keep] + n_tm].multiply(-best_amp2[keep]).dot(indices)
                                b[:, idx_b] += tmp1 + tmp2

                            if good[count]:

                                t_spike               = ts[count] + local_offset
                                result['spiketimes'] += [t_spike]
                                result['amplitudes'] += [(best_amp_n[keep], best_amp2_n[keep])]
                                result['templates']  += [inds_temp[keep]]

                    myslice           = numpy.take(inds_t, to_reject)
                    failure[myslice] += 1
                    sub_idx           = (numpy.take(failure, myslice) >= nb_chances)
                    
                    mask[:, numpy.compress(sub_idx, myslice)] = 0


            spikes_to_write     = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write  = numpy.array(result['templates'], dtype=numpy.int32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if collect_all:

                for temp, spike in zip(templates_to_write, spikes_to_write - local_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes       = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times   = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times                

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_tresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx      = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_tresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx      = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_tresholds_neg[bestlecs], matched_tresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx      = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                
                gspikes  = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write     = numpy.array(gspikes + local_offset, dtype=numpy.uint32)
                gtemplates_to_write  = numpy.array(bestlecs, dtype=numpy.int32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())
            

            if full_gpu:
                del gpu_mask, b, data

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()
        
        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()


    comm.Barrier()
    
    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
Example #15
0
    def write_results(path, params, extension):
        result = io.get_results(params, extension)
        spikes = [numpy.zeros(0, dtype=numpy.uint64)]
        clusters = [numpy.zeros(0, dtype=numpy.uint32)]
        amplitudes = [numpy.zeros(0, dtype=numpy.double)]
        N_tm = len(result['spiketimes'])
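        # Build the phy-compatible arrays: per-template spike times, template ids and
        # per-spike amplitudes, with an optional pre-labelling (good / mua / noise)
        # based on the RPV (refractory period violation) rate, template purity, or
        # amplitude statistics.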

        has_purity = test_if_purity(params, extension)
        rpvs = []

        if prelabelling:
            labels = []
            norms = io.load_data(params, 'norm-templates', extension)
            norms = norms[:len(norms) // 2]
            if has_purity:
                purity = io.load_data(params, 'purity', extension)

        for key in result['spiketimes'].keys():
            temp_id = int(key.split('_')[-1])
            myspikes = result['spiketimes'].pop(key).astype(numpy.uint64)
            spikes.append(myspikes)
            myamplitudes = result['amplitudes'].pop(key).astype(numpy.double)
            amplitudes.append(myamplitudes[:, 0])
            clusters.append(temp_id *
                            numpy.ones(len(myamplitudes), dtype=numpy.uint32))
            rpv = get_rpv(myspikes, params.data_file.sampling_rate)
            rpvs += [[temp_id, rpv]]
            if prelabelling:
                if has_purity:
                    if rpv <= rpv_threshold:
                        if purity[temp_id] > 0.75:
                            labels += [[temp_id, 'good']]
                    else:
                        if purity[temp_id] > 0.75:
                            labels += [[temp_id, 'mua']]
                        else:
                            labels += [[temp_id, 'noise']]
                else:
                    median_amp = numpy.median(myamplitudes[:, 0])
                    std_amp = numpy.std(myamplitudes[:, 0])
                    if rpv <= rpv_threshold and numpy.abs(median_amp -
                                                          1) < 0.25:
                        labels += [[temp_id, 'good']]
                    else:
                        if median_amp < 0.5:
                            labels += [[temp_id, 'mua']]
                        elif norms[temp_id] < 0.1:
                            labels += [[temp_id, 'noise']]

        if export_all:
            print_and_log([
                "Last %d templates are unfitted spikes on all electrodes" % N_e
            ], 'info', logger)
            garbage = io.load_data(params, 'garbage', extension)
            for key in garbage['gspikes'].keys():
                elec_id = int(key.split('_')[-1])
                data = garbage['gspikes'].pop(key).astype(numpy.uint64)
                spikes.append(data)
                amplitudes.append(numpy.ones(len(data)))
                clusters.append((elec_id + N_tm) *
                                numpy.ones(len(data), dtype=numpy.uint32))

        if prelabelling:
            f = open(os.path.join(output_path, 'cluster_group.tsv'), 'w')
            f.write('cluster_id\tgroup\n')
            for l in labels:
                f.write('%s\t%s\n' % (l[0], l[1]))
            f.close()

        # f = open(os.path.join(output_path, 'cluster_rpv.tsv'), 'w')
        # f.write('cluster_id\trpv\n')
        # for l in rpvs:
        #     f.write('%s\t%s\n' % (l[0], l[1]))
        # f.close()

        spikes = numpy.concatenate(spikes).astype(numpy.uint64)
        amplitudes = numpy.concatenate(amplitudes).astype(numpy.double)
        clusters = numpy.concatenate(clusters).astype(numpy.uint32)

        idx = numpy.argsort(spikes)
        numpy.save(os.path.join(output_path, 'spike_templates'), clusters[idx])
        numpy.save(os.path.join(output_path, 'spike_times'), spikes[idx])
        numpy.save(os.path.join(output_path, 'amplitudes'), amplitudes[idx])
        return
Example #16
0
    def set_streams(self, stream_mode):

        # We assume that all names are of the form XXXX_channel.ncs
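        # Three stream modes are handled: 'multi-files' (all matching .ncs files in the
        # same folder), 'multi-folders' (matching files gathered from sibling folders)
        # and 'mapping-file' (an explicit list parsed from a mapping file). Each source
        # gets a cumulative _t_start so streams are concatenated on a global time axis.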
        if stream_mode == 'multi-files':
            dirname = os.path.abspath(os.path.dirname(self.file_name))
            fname = os.path.basename(self.file_name)
            fn, ext = os.path.splitext(fname)
            tmp_all_files = os.listdir(dirname)
            tmp_all_files = filter_per_extension(tmp_all_files, ext)
            tmp_all_files.sort(key=natural_keys)
            all_files = filter_name_duplicates(tmp_all_files,
                                               self.params['ncs_pattern'])

            sources = []
            to_write = []
            global_time = 0
            params = self.get_description()

            for fname in all_files:
                params['ncs_pattern'] = '_'.join(fname.split('_')[:-1])
                new_data = type(self)(os.path.join(os.path.abspath(dirname),
                                                   fname), params)
                new_data._t_start = global_time
                global_time += new_data.duration
                sources += [new_data]
                to_write += [
                    "We found the datafile %s with t_start %s and duration %s"
                    % (new_data.file_name, new_data.t_start, new_data.duration)
                ]

            print_and_log(to_write, 'debug', logger)
            return sources

        elif stream_mode == 'multi-folders':
            dirname = os.path.abspath(os.path.dirname(self.file_name))
            upper_dir = os.path.dirname(dirname)
            fname = os.path.basename(self.file_name)

            all_directories = os.listdir(upper_dir)
            all_files = []

            for local_dir in all_directories:
                local_dir = os.path.join(upper_dir, local_dir)
                if os.path.isdir(local_dir):
                    all_local_files = os.listdir(local_dir)
                    for local_file in all_local_files:
                        ncs_file = os.path.join(upper_dir, local_dir,
                                                local_file)
                        is_valid = len(
                            re.findall(
                                ".*_%s_1.ncs" % self.params['ncs_pattern'],
                                ncs_file)) > 0
                        if is_valid and ncs_file not in all_files:
                            all_files += [ncs_file]

            all_files.sort(key=natural_keys)

            sources = []
            to_write = []
            global_time = 0
            params = self.get_description()

            for fname in all_files:
                params['ncs_pattern'] = self.params['ncs_pattern']
                new_data = type(self)(os.path.join(os.path.abspath(dirname),
                                                   fname), params)
                new_data._t_start = global_time
                global_time += new_data.duration
                sources += [new_data]
                to_write += [
                    'We found the datafile %s with t_start %s and duration %s'
                    % (new_data.file_name, new_data.t_start, new_data.duration)
                ]

            print_and_log(to_write, 'debug', logger)
            return sources

        elif stream_mode == 'mapping-file':

            if self.params['mapping_file'] != '':
                all_files = parse_ncs_mapping(self.params['mapping_file'])
            else:
                all_files = []

            sources = []
            to_write = []
            global_time = 0
            params = self.get_description()

            for count, fname in enumerate(all_files):
                dirname = os.path.abspath(os.path.dirname(fname[0]))
                params['idx_mapping'] = count
                new_data = type(self)(fname[0], params)
                new_data._t_start = global_time
                global_time += new_data.duration
                sources += [new_data]
                to_write += [
                    'We found the datafile %s with t_start %s and duration %s'
                    % (new_data.file_name, new_data.t_start, new_data.duration)
                ]

            print_and_log(to_write, 'debug', logger)
            return sources
Example #17
0
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    header += '''Utility to launch the phy GUI and visualize the results.
[data must first be converted with the converting mode]
    '''
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e',
                        '--extension',
                        help='extension to consider for visualization',
                        default='')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)
    filename = os.path.abspath(args.datafile)
    extension = args.extension
    params = CircusParser(filename)
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    if extension != '':
        extension = '-' + extension

    try:
        import traitlets
    except ImportError:
        print_and_log(
            ['The package traitlets required by phy is not installed'],
            'error', logger)
        sys.exit(1)

    try:
        import click
    except ImportError:
        print_and_log(['The package click required by phy is not installed'],
                      'error', logger)
        sys.exit(1)

    try:
        import joblib
    except ImportError:
        print_and_log(['The package joblib required by phy is not installed'],
                      'error', logger)
        sys.exit(1)

    if HAVE_PHYCONTRIB:
        mytest = StrictVersion(
            phycontrib.__version__) >= StrictVersion("1.0.12")
        if not mytest:
            print_and_log(
                ['You need to update phy-contrib to the latest git version'],
                'error', logger)
            sys.exit(1)

        print_and_log([
            'phy-contrib is deprecated, you should upgrade to phy 2.0 and phylib'
        ], 'info', logger)

    if HAVE_PHYLIB:
        try:
            import colorcet
        except ImportError:
            print_and_log(
                ['The package colorcet required by phy is not installed'],
                'error', logger)
            sys.exit(1)

        try:
            import qtconsole
        except ImportError:
            print_and_log(
                ['The package qtconsole required by phy is not installed'],
                'error', logger)
            sys.exit(1)

    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    if 'data_offset' in data_file.params:
        data_offset = data_file.data_offset
    else:
        data_offset = 0

    file_format = data_file.description
    file_out_suff = params.get('data', 'file_out_suff')

    if file_format not in supported_by_phy:
        print_and_log([
            "File format %s is not supported by phy. TraceView disabled" %
            file_format
        ], 'info', logger)

    if numpy.iterable(data_file.gain):
        print_and_log(
            ['Multiple gains are not supported, using a default value of 1'],
            'info', logger)
        gain = 1
    else:
        if data_file.gain != 1:
            print_and_log(
                ["Gain is not supported by phy. Expecting a scaling mismatch"],
                'info', logger)
            gain = data_file.gain

    probe = params.probe
    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'

    if not os.path.exists(output_path):
        print_and_log(
            ['Data should be first exported with the converting method!'],
            'error', logger)
    else:

        print_and_log(["Launching the phy GUI..."], 'info', logger)

        gui_params = {}
        if file_format in supported_by_phy:
            if not params.getboolean('data', 'overwrite'):
                gui_params['dat_path'] = r"%s" % params.get(
                    'data', 'data_file_no_overwrite')
            else:
                if params.get('data', 'stream_mode') == 'multi-files':
                    data_file = params.get_data_file(source=True,
                                                     has_been_created=False)
                    gui_params['dat_path'] = [
                        r"%s" % f for f in data_file.get_file_names()
                    ]
                else:
                    gui_params['dat_path'] = r"%s" % params.get(
                        'data', 'data_file')
        else:
            gui_params['dat_path'] = 'giverandomname.dat'

        data_file.close()
        gui_params['n_channels_dat'] = params.nb_channels
        gui_params['n_features_per_channel'] = 5
        gui_params['dtype'] = data_dtype
        gui_params['offset'] = data_offset
        gui_params['sample_rate'] = params.rate
        gui_params['dir_path'] = output_path
        gui_params['hp_filtered'] = True

        os.chdir(output_path)
        create_app()
        controller = TemplateController(**gui_params)
        gui = controller.create_gui()

        gui.show()
        run_app()
        gui.close()
        del gui
Example #18
0
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    header += '''Utility to concatenate artefacts/dead times before using
stream mode. The code will look for .dead and .trig files, and
concatenate them automatically, taking care of file offsets
    '''
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    # parser.add_argument('-w', '--window', help='text file with artefact window files',
    #                     default=None)

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)
    # if args.window is None:
    #     window_file = None
    # else:
    #     window_file = os.path.abspath(args.window)

    filename = os.path.abspath(args.datafile)
    params = CircusParser(filename)
    dead_in_ms = params.getboolean('triggers', 'dead_in_ms')
    trig_in_ms = params.getboolean('triggers', 'trig_in_ms')

    if os.path.exists(params.logfile):
        os.remove(params.logfile)

    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    if params.get('data', 'stream_mode') == 'multi-files':
        data_file = params.get_data_file(source=True, has_been_created=False)
        all_times_dead = numpy.zeros((0, 2), dtype=numpy.int64)
        all_times_trig = numpy.zeros((0, 2), dtype=numpy.int64)
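        # For every source file of the multi-files stream, read the optional .dead and
        # .trig files, clip or reject entries that exceed the file duration, shift the
        # times by the file's t_start so they become global, and stack them.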

        for f in data_file._sources:
            name, ext = os.path.splitext(f.file_name)
            dead_file = f.file_name.replace(ext, '.dead')
            trig_file = f.file_name.replace(ext, '.trig')

            if os.path.exists(dead_file):
                print_and_log(['Found file %s' % dead_file], 'default', logger)
                times = get_dead_times(dead_file, data_file.sampling_rate,
                                       dead_in_ms)
                if times.max() > f.duration or times.min() < 0:
                    print_and_log([
                        'Dead zones larger than duration for file %s' %
                        f.file_name, '-> Clipping automatically'
                    ], 'error', logger)
                    times = numpy.minimum(times, f.duration)
                    times = numpy.maximum(times, 0)
                times += f.t_start
                all_times_dead = numpy.vstack((all_times_dead, times))

            if os.path.exists(trig_file):
                print_and_log(['Found file %s' % trig_file], 'default', logger)

                times = get_trig_times(trig_file, data_file.sampling_rate,
                                       trig_in_ms)
                if times[:, 1].max() > f.duration or times[:, 1].min() < 0:
                    print_and_log([
                        'Triggers larger than duration for file %s' %
                        f.file_name
                    ], 'error', logger)
                    sys.exit(0)
                times[:, 1] += f.t_start
                all_times_trig = numpy.vstack((all_times_trig, times))

        if len(all_times_dead) > 0:
            output_file = os.path.join(os.path.dirname(filename),
                                       'dead_zones.txt')
            print_and_log(['Saving global artefact file in %s' % output_file],
                          'default', logger)
            if dead_in_ms:
                all_times_dead = all_times_dead.astype(
                    numpy.float32) / data_file.sampling_rate
            numpy.savetxt(output_file, all_times_dead)

        if len(all_times_trig) > 0:
            output_file = os.path.join(os.path.dirname(filename),
                                       'triggers.txt')
            print_and_log(['Saving global artefact file in %s' % output_file],
                          'default', logger)
            if trig_in_ms:
                all_times_trig = all_times_trig.astype(
                    numpy.float32) / data_file.sampling_rate
            numpy.savetxt(output_file, all_times_trig)

    elif params.get('data', 'stream_mode') == 'single-file':
        print_and_log(['Not implemented'], 'error', logger)
        sys.exit(0)
    else:
        print_and_log(
            ['You should select a valid stream_mode such as multi-files'],
            'error', logger)
        sys.exit(0)
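For reference, a minimal, self-contained sketch of the bookkeeping performed above, with made-up file layout and sampling rate: per-file dead windows are converted from ms to samples, clipped to the file duration, shifted by the file's global start time, and stacked into one array before being written out.

import numpy

sampling_rate = 20000.0                                     # assumed value
files = [(0, 200000), (200000, 150000)]                     # hypothetical (t_start, duration) in samples
dead_ms = [numpy.array([[0.0, 5.0]]),                       # hypothetical per-file windows in ms
           numpy.array([[10.0, 12.0]])]

all_dead = numpy.zeros((0, 2), dtype=numpy.int64)
for (t_start, duration), times_ms in zip(files, dead_ms):
    times = (times_ms * sampling_rate * 1e-3).astype(numpy.int64)   # ms -> samples
    times = numpy.clip(times, 0, duration)                           # clip to the file bounds
    times += t_start                                                  # shift to the global timeline
    all_dead = numpy.vstack((all_dead, times))

numpy.savetxt('dead_zones.txt', all_dead)                   # single global artefact file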
Example #19
    def remove_artefacts(data_file, art_dict):

        chunk_size     = params.getint('data', 'chunk_size')
        trig_in_ms     = params.getboolean('triggers', 'trig_in_ms')
        artefacts      = numpy.loadtxt(params.get('triggers', 'trig_file'))
        windows        = numpy.loadtxt(params.get('triggers', 'trig_windows'))
        make_plots     = params.get('triggers', 'make_plots')
        plot_path      = os.path.join(params.get('data', 'file_out_suff'), 'plots')

        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if trig_in_ms:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug', logger)
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate*1e-3)
            windows[:, 1]   *= numpy.int64(data_file.sampling_rate*1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'], 'debug', logger)

        artefacts        = artefacts.astype(numpy.int64)
        windows          = windows.astype(numpy.int64)
        nb_stimuli       = len(numpy.unique(artefacts[:, 0]))
        mytest           = nb_stimuli == len(windows)

        if not mytest:
            if comm.rank == 0:
                print_and_log(['Error in the trigger files'], 'error', logger)
            sys.exit(0)

        all_labels   = artefacts[:, 0]
        all_times    = artefacts[:, 1]
        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        mask       = numpy.in1d(all_labels, local_labels)
        all_times  = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        mask       = (all_times >= 0) & (all_times < data_file.t_stop)
        all_times  = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        if comm.rank == 0:
            to_write = ["Removing artefacts from %d stimuli" %(nb_stimuli)]
            print_and_log(to_write, 'default', logger)
            all_times = get_tqdm_progressbar(all_times)

        comm.Barrier()

        for count, time in enumerate(all_times):

            label = all_labels[count]
            tmp   = numpy.where(windows[:, 0] == label)[0][0]
            tau   = numpy.int64(windows[tmp, 1])

            if (data_file.t_stop - time) < tau:
                tau = data_file.t_stop - time  # shrink the window at the end of the recording

            local_chunk   = data_file.get_snippet(time, tau)
            for idx, i in enumerate(nodes):
                local_chunk[:, i] -= art_dict[label][idx, :tau]
            data_file.set_data(time, local_chunk)

        comm.Barrier()
        sys.stderr.flush()
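A toy sketch of the subtraction step itself, with a hypothetical art_dict mapping each stimulus label to an averaged artefact template of shape (channels, window): at every trigger time the stored template is subtracted from the matching slice of the trace, shrinking the window near the end of the recording.

import numpy

n_channels, window = 4, 50
data = numpy.random.randn(1000, n_channels).astype(numpy.float32)             # fake recording
art_dict = {0: numpy.random.randn(n_channels, window).astype(numpy.float32)}  # label -> averaged artefact
triggers = numpy.array([[0, 100], [0, 980]])                                   # (label, time) pairs

for label, time in triggers:
    tau = min(window, data.shape[0] - time)            # clip the window at the end of the trace
    data[time:time + tau, :] -= art_dict[label][:, :tau].T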
Example #20
 def _check_filename(self, file_name):
     if not os.path.exists(file_name):
         if self.is_master:
             print_and_log(["The file %s can not be found!" % file_name],
                           'error', logger)
         sys.exit(1)
Example #21
    def filter_file(data_file_in, data_file_out, do_filtering, do_remove_median, do_remove_ground):

        try:
            cut_off    = params.getfloat('filtering', 'cut_off')
            cut_off    = [cut_off, 0.95*(params.rate/2.)]
        except Exception:
            cut_off        = params.get('filtering', 'cut_off')
            cut_off        = cut_off.split(',')
            try:
                cut_off[0] = float(cut_off[0])
            except Exception:
                if comm.rank == 0:
                    print_and_log(['First value of cut off must be a valid number'], 'error', logger)
                sys.exit(0)

            cut_off[1] = cut_off[1].replace(' ', '')
            if cut_off[1] == 'auto':
                cut_off[1] = 0.95*(params.rate/2.)
            else:
                try:
                    cut_off[1] = float(cut_off[1])
                except Exception:
                    if comm.rank == 0:
                        print_and_log(['Second value of cut off must be either auto or a valid number'], 'error', logger)
                    sys.exit(0)

        chunk_size    = params.getint('data', 'chunk_size')
        nb_chunks, _  = data_file_in.analyze(chunk_size)

        b, a          = signal.butter(3, np.array(cut_off)/(params.rate/2.), 'pass')
        all_chunks    = numpy.arange(nb_chunks, dtype=numpy.int64)
        to_process    = all_chunks[comm.rank::comm.size]
        loc_nb_chunks = len(to_process)
        N_total       = params.nb_channels
        process_all_channels = numpy.all(nodes == numpy.arange(N_total))

        if comm.rank == 0:
            to_write = []
            if do_filtering:
                to_write += ["Filtering the signal with a Butterworth filter in (%g, %g) Hz" %(cut_off[0],cut_off[1])]
            if do_remove_median:
                to_write += ["Median over all channels is subtracted to each channels"]
            if do_remove_ground:
                to_write += ["Channel %s is used as a reference channel" %common_ground]

            print_and_log(to_write, 'default', logger)

        to_explore = xrange(comm.rank, nb_chunks, comm.size)

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(to_explore)

        for count, gidx in enumerate(to_explore):

            local_chunk, t_offset =  data_file_in.get_data(gidx, chunk_size)

            if do_filtering:
                for i in nodes:    
                    try:
                        local_chunk[:, i] = signal.filtfilt(b, a, local_chunk[:, i])
                    except Exception:
                        pass
                    local_chunk[:, i] -= numpy.median(local_chunk[:, i]) 

            if do_remove_median:
                if not process_all_channels:
                    global_median = numpy.median(numpy.take(local_chunk, nodes, axis=1), 1)
                else:
                    global_median = numpy.median(local_chunk, 1)

                for i in nodes:
                    local_chunk[:, i] -= global_median

            if common_ground > -1:
                for i in nodes:
                    local_chunk[:, i] -= local_chunk[:, common_ground]

            if data_file_in != data_file_out and data_file_in.is_first_chunk(gidx, nb_chunks):
                if data_file_in.is_stream:
                    g_offset = t_offset - numpy.sum(data_file_in._times[:data_file_in._get_streams_index_by_time(t_offset)+1])
                else:
                    g_offset = t_offset - data_file_in.t_start
            else:
                g_offset = t_offset

            data_file_out.set_data(g_offset, local_chunk)

        sys.stderr.flush()
        comm.Barrier()
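A self-contained sketch of the per-chunk processing, with assumed band edges and channel count: a third-order Butterworth band-pass applied forwards and backwards with scipy's filtfilt (zero phase), per-channel median removal, then subtraction of the across-channel median as a common reference.

import numpy
from scipy import signal

rate = 20000.0
cut_off = [300.0, 0.95 * (rate / 2.0)]                        # assumed cut-off values in Hz
b, a = signal.butter(3, numpy.array(cut_off) / (rate / 2.0), 'bandpass')

chunk = numpy.random.randn(20000, 8).astype(numpy.float32)    # fake chunk: samples x channels
for i in range(chunk.shape[1]):
    chunk[:, i] = signal.filtfilt(b, a, chunk[:, i])           # zero-phase band-pass
    chunk[:, i] -= numpy.median(chunk[:, i])                    # remove the per-channel offset

chunk -= numpy.median(chunk, axis=1)[:, numpy.newaxis]          # common median reference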
Example #22
    def __init__(self, file_name, params, is_empty=False, stream_mode=None):
        '''
        The constructor that will create the DataFile object. Note that by default, values are read from the header
        of the file. If not found in the header, they are read from the parameter file. If no values are found, the
        code will trigger an error

        What you need to specify at a generic level (for a given file format)
            - parallel_write  : can the file be safely written in parallel?
            - is_writable     : if the file can be written
            - is_streamable   : if the file format can support streaming data
            - required_fields : what parameter must be specified for the file format, along with the type
            - default_values  : parameters that may have default values if not provided

        What you need to specify at a low level (maybe by getting specific values with _read_from_header)
            - _shape          : the size of the data, should be a tuple (duration in time bins, nb_channels)
            - _t_start        : the time (in time steps) of the recording (0 by default)
        '''

        self.params = {}
        self.params.update(self._params)

        if not is_empty:
            self._check_filename(file_name)

        if stream_mode is not None:
            self.is_stream = True
            if stream_mode not in self.is_streamable:
                if self.is_master:
                    print_and_log([
                        "The file format %s does not support stream mode %s" %
                        (self.description, stream_mode)
                    ], 'error', logger)
                sys.exit(1)
            if is_empty:
                if self.is_master:
                    print_and_log(
                        ["A datafile can not have streams and be empty!"],
                        'error', logger)
                sys.exit(1)
        else:
            self.is_stream = False

        self.file_name = file_name
        self.is_empty = is_empty
        self.stream_mode = stream_mode

        f_next, extension = os.path.splitext(self.file_name)

        self._check_extension(extension)
        self._fill_from_params(params)

        if not self.is_empty:
            #try:
            self._fill_from_header(self._read_from_header())
            #except Exception as ex:
            #    print_and_log(["There is an error in the _read_from_header method of the wrapper\n" + str(ex)], 'error', logger)
        else:
            self._shape = (0, 0)

        if self._shape is None:
            if self.is_master:
                print_and_log([
                    "Shape of the data is not defined. Are you sure of the wrapper?"
                ], 'error', logger)
            sys.exit(1)

        self.params['dtype_offset'] = get_offset(self.data_dtype,
                                                 self.dtype_offset)

        if self.stream_mode:
            self._sources = self.set_streams(self.stream_mode)
            self._times = []
            for source in self._sources:
                self._times += [source.t_start]
            print_and_log([
                'The file is composed of %d streams' % len(self._sources),
                'Times are between %d and %d' %
                (self._sources[0].t_start, self._sources[-1].t_stop)
            ], 'debug', logger)
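Following the docstring above, a hypothetical wrapper for a header-less binary format might only declare the generic flags and fill _shape from the file size. The attribute names come from the docstring; the stub base class, dtype and channel count are assumptions made for this sketch, not the project's actual wrapper API.

import os
import numpy

class DataFile(object):                                # stub standing in for the documented base class
    def __init__(self, file_name):
        self.file_name = file_name

class RawBinarySketch(DataFile):
    description = "raw_binary (sketch)"
    parallel_write = True                               # the file can be written safely in parallel
    is_writable = True
    is_streamable = ['multi-files']                      # stream modes supported by this format
    required_fields = {'data_dtype': str,                # must be given in the parameter file
                       'sampling_rate': float,
                       'nb_channels': int}
    default_values = {'dtype_offset': 'auto', 'gain': 1.0}

    def _read_from_header(self):
        # No real header: derive the duration (in time steps) from the file size,
        # assuming int16 samples and 32 channels for this sketch.
        item_size = numpy.dtype('int16').itemsize
        nb_samples = os.path.getsize(self.file_name) // (32 * item_size)
        self._shape = (nb_samples, 32)
        self._t_start = 0
        return {}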
Example #23
def extract_juxta_spikes_(params):
    '''Detect spikes from the juxtacellular trace'''
    
    file_out_suff  = params.get('data', 'file_out_suff')
    sampling_rate  = params.getint('data', 'sampling_rate')
    dist_peaks     = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    juxta_dtype    = params.get('validating', 'juxta_dtype')
    juxta_thresh   = params.getfloat('validating', 'juxta_thresh')
    juxta_valley   = params.getboolean('validating', 'juxta_valley')
    juxta_spikes   = params.get('validating', 'juxta_spikes')

    juxta_filename = "{}.juxta.dat".format(file_out_suff)
    beer_path = "{}.beer.hdf5".format(file_out_suff)


    if juxta_spikes == '':
            
        # Read juxtacellular trace.
        juxta_data = numpy.fromfile(juxta_filename, dtype=juxta_dtype)
        #juxta_data = juxta_data.astype(numpy.float32)
        # juxta_data = juxta_data - dtype_offset
        juxta_data = numpy.ascontiguousarray(juxta_data)
        
        # Filter juxtacellular trace.
        juxta_data  = highpass(juxta_data, sampling_rate=sampling_rate)
        juxta_data -= numpy.median(juxta_data)

        # Compute median and median absolute deviation.
        juxta_median = numpy.median(juxta_data)
        juxta_ad     = numpy.abs(juxta_data - juxta_median)
        juxta_mad    = numpy.median(juxta_ad, axis=0)
        
        # Save medians and median absolute deviations to BEER file.
        beer_file = h5py.File(beer_path, 'a', libver='latest')
        if "juxta_median" in beer_file.keys():
            beer_file.pop("juxta_median")
        beer_file.create_dataset("juxta_median", data=juxta_median)
        if "juxta_mad" in beer_file.keys():
            beer_file.pop("juxta_mad")
        beer_file.create_dataset("juxta_mad", data=juxta_mad)
        beer_file.close()

        if comm.rank == 0:
            print_and_log(["Extract juxtacellular spikes"], level='debug', logger=logger)
        
        # Detect juxta spike times.
        threshold = juxta_thresh * juxta_mad
        juxta_spike_times = algo.detect_peaks(juxta_data, threshold, valley=juxta_valley, mpd=dist_peaks)

        # Remove juxta spike times in the borders.
        juxta_spike_times = juxta_spike_times[template_shift <= juxta_spike_times]
        juxta_spike_times = juxta_spike_times[juxta_spike_times < juxta_data.size - template_shift]

    else:
        juxta_spike_times = numpy.load(juxta_spikes)
    
    # Save juxta spike times to BEER file.
    beer_file = h5py.File(beer_path, 'a', libver='latest')
    group_name = "juxta_spiketimes"
    if group_name in beer_file.keys():
        beer_file.pop(group_name)
    beer_file.create_group(group_name)
    key = "{}/elec_0".format(group_name)
    beer_file.create_dataset(key, data=juxta_spike_times)
    beer_file.close()
    
    

    # juxta_spike_values = numpy.zeros_like(juxta_spike_times, dtype='float')
    # for i, t in enumerate(juxta_spike_times):
    #     if juxta_valley:
    #         juxta_spike_values[i] = - juxta_data[t]
    #     else:
    #         juxta_spike_values[i] = + juxta_data[t]
    
    if juxta_spikes == '':

        # Find juxta spike values of juxta spike times.
        juxta_spike_values = juxta_data[juxta_spike_times]
        if juxta_valley:
            juxta_spike_values *= -1

        # Save juxta spike values to BEER file.
        beer_file = h5py.File(beer_path, 'a', libver='latest')
        group_name = "juxta_spike_values"
        if group_name in beer_file.keys():
            beer_file.pop(group_name)
        beer_file.create_group(group_name)
        key = "{}/elec_0".format(group_name)
        beer_file.create_dataset(key, data=juxta_spike_values)
        beer_file.close()

    return
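The detection above boils down to a median/MAD threshold; here is a standalone sketch using scipy.signal.find_peaks in place of the project's algo.detect_peaks, with made-up threshold and shift values.

import numpy
from scipy.signal import find_peaks

trace = numpy.random.randn(100000).astype(numpy.float32)    # fake high-pass filtered juxta trace
trace -= numpy.median(trace)

mad = numpy.median(numpy.abs(trace - numpy.median(trace)))   # median absolute deviation
threshold = 6.0 * mad                                         # assumed juxta_thresh of 6

# valley=True in the original corresponds to looking for negative peaks.
spike_times, _ = find_peaks(-trace, height=threshold, distance=50)

# Discard detections too close to the borders, as in the code above.
template_shift = 30
spike_times = spike_times[(spike_times >= template_shift) &
                          (spike_times < trace.size - template_shift)]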
Example #24
def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    #params         = detect_memory(params)
    logger = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    spike_width = params.getfloat('detection', 'spike_width')
    dist_peaks = params.getint('detection', 'dist_peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = detect_memory(params, fitting=True)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting',
                            'amp_limits').replace('(',
                                                  '').replace(')',
                                                              '').split(',')
    tmp_limits = list(map(float, tmp_limits))  # list() so the limits can be indexed below
    amp_auto = params.getboolean('fitting', 'amp_auto')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    debug = params.getboolean('fitting', 'debug')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates = io.load_data_memshared(params,
                                           'templates',
                                           normalize=True,
                                           transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    temp_2_shift = 2 * template_shift
    temp_3_shift = 3 * template_shift
    full_gpu = use_gpu and gpu_only
    n_tm = N_tm // 2
    n_scalar = N_e * N_t

    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = N_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')

    if not SHARED_MEMORY:
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx],
                                   templates.indptr[idx + 1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))
            matched_tresholds_pos = io.load_data(params,
                                                 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)

    thresholds = io.load_data(params, 'thresholds')

    if collect_all:
        neighbors = {}
        for i in xrange(n_tm):
            tmp = templates[i, :].toarray().reshape(N_e,
                                                    N_t) * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, 1) != 0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string = ''

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % (comm.size)
        else:
            info_string = "using %d CPUs" % (comm.size)

    comm.Barrier()

    c_overlap = io.get_overlaps(params,
                                nb_cpu=nb_cpu,
                                nb_gpu=nb_gpu,
                                use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    N_over = int(numpy.sqrt(over_shape[0]))
    S_over = over_shape[1]
    ## If the number of overlaps is different from templates, we need to recompute them
    if N_over != N_tm:
        if comm.rank == 0:
            print_and_log(
                ['Templates have been modified, recomputing the overlaps...'],
                'default', logger)
        c_overlap = io.get_overlaps(params,
                                    erase=True,
                                    nb_cpu=nb_cpu,
                                    nb_gpu=nb_gpu,
                                    use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"],
                          'default', logger)

        sys.exit(0)

    if comm.rank == 0:
        print_and_log([
            "Here comes the SpyKING CIRCUS %s and %d templates..." %
            (info_string, n_tm)
        ], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i],
                                                  copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log([
                    "Not enough memory on GPUs: GPUs are used for projection only"
                ], 'info', logger)
            for i in xrange(N_over):
                if i in c_overs:
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank,
                           'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank,
                           'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank,
                          'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(
            file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(
            file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()

    if debug:
        # Open debug files.
        chunk_nbs_debug_file = open(file_out_suff +
                                    '.chunk_nbs_debug_%d.data' % comm.rank,
                                    mode='wb')
        comm.Barrier()
        iteration_nbs_debug_file = open(
            file_out_suff + '.iteration_nbs_debug_%d.data' % comm.rank,
            mode='wb')
        comm.Barrier()
        peak_nbs_debug_file = open(file_out_suff +
                                   '.peak_nbs_debug_%d.data' % comm.rank,
                                   mode='wb')
        comm.Barrier()
        peak_local_time_steps_debug_file = open(
            file_out_suff + '.peak_local_time_steps_debug_%d.data' % comm.rank,
            mode='wb')
        comm.Barrier()
        peak_time_steps_debug_file = open(
            file_out_suff + '.peak_time_steps_debug_%d.data' % comm.rank,
            mode='wb')
        comm.Barrier()
        peak_scalar_products_debug_file = open(
            file_out_suff + '.peak_scalar_products_debug_%d.data' % comm.rank,
            mode='wb')
        comm.Barrier()
        peak_solved_flags_debug_file = open(
            file_out_suff + '.peak_solved_flags_debug_%d.data' % comm.rank,
            mode='wb')
        comm.Barrier()
        template_nbs_debug_file = open(
            file_out_suff + '.template_nbs_debug_%d.data' % comm.rank,
            mode='wb')
        comm.Barrier()
        success_flags_debug_file = open(
            file_out_suff + '.success_flags_debug_%d.data' % comm.rank,
            mode='wb')
        comm.Barrier()
    else:
        chunk_nbs_debug_file = None
        iteration_nbs_debug_file = None
        peak_nbs_debug_file = None
        peak_local_time_steps_debug_file = None
        peak_time_steps_debug_file = None
        peak_scalar_products_debug_file = None
        peak_solved_flags_debug_file = None
        template_nbs_debug_file = None
        success_flags_debug_file = None

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    last_chunk_size = 0

    to_explore = xrange(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):
        #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        ## We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift]

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if is_last:
            padding = (-temp_3_shift, 0)
        elif is_first:
            padding = (0, temp_3_shift)
        else:
            padding = (-temp_3_shift, temp_3_shift)

        result = {
            'spiketimes': [],
            'amplitudes': [],
            'templates': [],
        }
        result_debug = {
            'chunk_nbs': [],
            'iteration_nbs': [],
            'peak_nbs': [],
            'peak_local_time_steps': [],
            'peak_time_steps': [],
            'peak_scalar_products': [],
            'peak_solved_flags': [],
            'template_nbs': [],
            'success_flags': [],
        }

        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   padding,
                                                   nodes=nodes)
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                           temporal_whitening,
                                                           axis=0,
                                                           mode='constant')

        #print "Extracting the peaks..."

        if collect_all:
            all_found_spikes = {}
            for i in xrange(N_e):
                all_found_spikes[i] = []

        local_peaktimes = numpy.zeros(0, dtype=numpy.uint32)

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = scipy.signal.find_peaks(
                        filter_chunk[:, i], height=matched_tresholds_pos[i])[0]
                    local_peaktimes = numpy.concatenate(
                        (local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = scipy.signal.find_peaks(
                        filter_chunk[:, i], height=matched_tresholds_neg[i])[0]
                    local_peaktimes = numpy.concatenate(
                        (local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(
                        -local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(
                        local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(
                        numpy.abs(local_chunk[:, i]), height=thresholds[i])[0]
                local_peaktimes = numpy.concatenate(
                    (local_peaktimes, peaktimes))
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()

        local_peaktimes = numpy.unique(local_peaktimes)

        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(
                all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.in1d(
                    local_peaktimes + g_offset,
                    all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]
                local_peaktimes = numpy.sort(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in xrange(N_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i],
                                                  dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            all_found_spikes[i] + g_offset,
                            all_dead_times[dead_indices[0]:dead_indices[1]])
                        all_found_spikes[i] = all_found_spikes[i][~is_included]
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (
                    all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        n_t = len(local_peaktimes)

        if full_gpu:
            #   all_indices = cmt.CUDAMatrix(all_indices)
            tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, n_t)),
                                     copy_on_host=False)

        if n_t > 0:
            #print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."

            if collect_all:
                c_local_chunk = local_chunk.copy()

            local_chunk = local_chunk.T.ravel()
            sub_mat = numpy.zeros((size_window, n_t), dtype=numpy.float32)

            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate(
                        (slice_indices, len_chunk * idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk,
                                               slice_indices + idx)

            #snippet_norm = numpy.sum(sub_mat**2, 0)/n_scalar
            #sub_mat /= snippet_norm

            del local_chunk

            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_restriction = (t_offset, t_offset + chunk_size)
            all_spikes = local_peaktimes + g_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            #b           = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            failure = numpy.zeros(n_t, dtype=numpy.int32)

            if full_gpu:
                mask = numpy.zeros((2 * n_tm, n_t), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                data = cmt.empty(mask.shape)
                patch_gpu = b.shape[1] == 1
            else:
                mask = numpy.ones((n_tm, n_t), dtype=bool)
                patch_gpu = None

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, N_e), dtype=bool)
                c_min_times = numpy.maximum(
                    numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(
                    numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in xrange(N_e):
                    c_all_times[all_found_spikes[i], i] = True

            iteration_nb = 0
            while numpy.mean(failure) < nb_chances:

                # Is there a way to update sub_b * mask at the same time?
                if full_gpu:
                    b_array = b.asarray()
                else:
                    b_array = None

                data = b[:n_tm, :] * mask
                peak_scalar_product = data.max()
                best_template_index, peak_index = numpy.unravel_index(
                    data.argmax(), data.shape)
                best_template2_index = best_template_index + n_tm

                if full_gpu:
                    best_amp = b_array[best_template_index,
                                       peak_index] / n_scalar
                    best_amp2 = b_array[best_template2_index,
                                        peak_index] / n_scalar
                else:
                    best_amp = b[best_template_index, peak_index] / n_scalar
                    best_amp2 = b[best_template2_index, peak_index] / n_scalar

                best_amp_n = best_amp / norm_templates[best_template_index]
                best_amp2_n = best_amp2 / norm_templates[best_template2_index]

                # Verify amplitude constraint.
                a_min, a_max = amp_limits[best_template_index, :]

                if (a_min <= best_amp_n) & (best_amp_n <= a_max):
                    # Keep the matching.
                    peak_time_step = local_peaktimes[peak_index]

                    peak_data = (local_peaktimes - peak_time_step).astype(
                        np.int32)
                    is_neighbor = np.where(
                        np.abs(peak_data) <= temp_2_shift)[0]
                    idx_neighbor = peak_data[is_neighbor] + temp_2_shift
                    nb_neighbors = len(is_neighbor)
                    indices = np.zeros((S_over, nb_neighbors), dtype=np.int32)
                    indices[idx_neighbor, np.arange(nb_neighbors)] = 1

                    if full_gpu:
                        indices = cmt.CUDAMatrix(indices, copy_on_host=False)
                        if patch_gpu:
                            b_lines = b.get_col_slice(0, b.shape[0])
                        else:
                            b_lines = b.get_col_slice(is_neighbor[0],
                                                      is_neighbor[-1] + 1)
                        tmp1 = cmt.sparse_dot(c_overs[best_template_index],
                                              indices,
                                              mult=-best_amp)
                        tmp2 = cmt.sparse_dot(c_overs[best_template2_index],
                                              indices,
                                              mult=-best_amp2)
                        b_lines.add(tmp1.add(tmp2))
                        del tmp1, tmp2
                    else:
                        tmp1 = c_overs[best_template_index].multiply(-best_amp)
                        tmp2 = c_overs[best_template2_index].multiply(
                            -best_amp2)
                        b[:, is_neighbor] += (tmp1 + tmp2).dot(indices)

                    # Add matching to the result.
                    t_spike = all_spikes[peak_index]

                    if (t_spike >= local_restriction[0]) and (
                            t_spike < local_restriction[1]):
                        result['spiketimes'] += [t_spike]
                        result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                        result['templates'] += [best_template_index]
                    # Mark current matching as tried.
                    mask[best_template_index, peak_index] = False
                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [
                            local_peaktimes[peak_index]
                        ]
                        result_debug['peak_time_steps'] += [
                            all_spikes[peak_index]
                        ]
                        result_debug['peak_scalar_products'] += [
                            peak_scalar_product
                        ]
                        result_debug['peak_solved_flags'] += [
                            mask[best_template_index, peak_index]
                        ]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [True]
                else:
                    # Reject the matching.
                    # Update failure counter of the peak.
                    failure[peak_index] += 1
                    # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted).
                    if failure[peak_index] == nb_chances:
                        # Mark all the matching associated to the current peak as tried.
                        mask[:, peak_index] = False
                    else:
                        # Mark current matching as tried.
                        mask[best_template_index, peak_index] = False
                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [
                            local_peaktimes[peak_index]
                        ]
                        result_debug['peak_time_steps'] += [
                            all_spikes[peak_index]
                        ]
                        result_debug['peak_scalar_products'] += [
                            peak_scalar_product
                        ]
                        result_debug['peak_solved_flags'] += [
                            mask[best_template_index, peak_index]
                        ]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [False]

                iteration_nb += 1

            spikes_to_write = numpy.array(result['spiketimes'],
                                          dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'],
                                              dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'],
                                             dtype=numpy.uint32)

            spiketimes_file.write(spikes_to_write.tobytes())
            amplitudes_file.write(amplitudes_to_write.tobytes())
            templates_file.write(templates_to_write.tobytes())

            if collect_all:

                for temp, spike in zip(templates_to_write,
                                       spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike],
                                neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes,
                                           axis=0) * c_all_times

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_tresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_tresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(
                            matched_tresholds_neg[bestlecs],
                            matched_tresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset,
                                               dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tobytes())
                garbage_temp_file.write(gtemplates_to_write.tobytes())

            if debug:
                # Write debug data to debug files.
                for field_label, field_dtype, field_file in [
                    ('chunk_nbs', numpy.uint32, chunk_nbs_debug_file),
                    ('iteration_nbs', numpy.uint32, iteration_nbs_debug_file),
                    ('peak_nbs', numpy.uint32, peak_nbs_debug_file),
                    ('peak_local_time_steps', numpy.uint32,
                     peak_local_time_steps_debug_file),
                    ('peak_time_steps', numpy.uint32,
                     peak_time_steps_debug_file),
                    ('peak_scalar_products', numpy.float32,
                     peak_scalar_products_debug_file),
                    ('peak_solved_flags', numpy.float32,
                     peak_solved_flags_debug_file),
                    ('template_nbs', numpy.uint32, template_nbs_debug_file),
                    ('success_flags', bool, success_flags_debug_file),
                ]:
                    field_to_write = numpy.array(result_debug[field_label],
                                                 dtype=field_dtype)
                    field_file.write(field_to_write.tobytes())

            if full_gpu:
                del b, data

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()

        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    if debug:
        # Close debug files.
        for field_file in [
                chunk_nbs_debug_file,
                iteration_nbs_debug_file,
                peak_nbs_debug_file,
                peak_local_time_steps_debug_file,
                peak_time_steps_debug_file,
                peak_scalar_products_debug_file,
                peak_solved_flags_debug_file,
                template_nbs_debug_file,
                success_flags_debug_file,
        ]:
            field_file.flush()
            os.fsync(field_file.fileno())
            field_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
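At its core the fitting loop above is a greedy matching pursuit: project the snippet onto every template, pick the best (template, peak) pair, accept it only if the amplitude falls inside the per-template limits, then subtract the fitted contribution and iterate. The following toy uses a simplified normalization (unit-norm templates and a single snippet), so the constants differ from the code above; it only illustrates the acceptance rule.

import numpy

n_scalar = 5 * 31                                             # flattened snippet size (N_e * N_t) for the toy
rng = numpy.random.default_rng(0)
templates = rng.standard_normal((3, n_scalar)).astype(numpy.float32)
templates /= numpy.sqrt(numpy.sum(templates ** 2, 1))[:, numpy.newaxis]   # unit-norm templates

snippet = 1.2 * templates[1] + 0.01 * rng.standard_normal(n_scalar).astype(numpy.float32)
amp_limits = numpy.array([[0.5, 1.5]] * 3)                    # per-template amplitude limits

accepted = []
residual = snippet.copy()
for _ in range(3):                                            # plays the role of nb_chances
    b = templates.dot(residual)                               # scalar product with every template
    best = int(numpy.argmax(b))
    best_amp = float(b[best])                                 # least-squares amplitude for unit-norm templates
    a_min, a_max = amp_limits[best]
    if a_min <= best_amp <= a_max:
        accepted.append((best, best_amp))
        residual -= best_amp * templates[best]                # peel the fitted template off the residual
    else:
        break

print(accepted)                                               # expect roughly [(1, 1.2)]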
Example #25
def extract_extra_thresholds(params):
    """Compute the mean and the standard deviation for each extracellular channel"""
    
    data_file      = params.data_file
    data_file.open()

    chunk_size = params.getint('data', 'chunk_size')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening  = params.getboolean('whitening', 'spatial')
    N_total = params.nb_channels
    
    if do_spatial_whitening:
        spatial_whitening  = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    
    #mpi_file = MPI.File()
    #mpi_input = mpi_file.Open(comm, data_filename, MPI.MODE_RDONLY)
    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    nodes, _ = get_nodes_and_edges(params)
    N_elec = nodes.size
    
    def weighted_mean(weights, values):
        """Compute a weighted mean for the given values"""
        norm_weights = [float(weight) / float(sum(weights)) for weight in weights]
        weighted_values = [norm_weight * value for (norm_weight, value) in zip(norm_weights, values)]
        weighted_mean = sum(weighted_values)
        return weighted_mean
    
    def extract_median(chunk_size, gidx):
        """Extract the medians from a chunk of extracellular traces"""
        loc_chunk, _ = data_file.get_data(gidx, chunk_size, nodes=nodes)
        # Whiten signal.
        if do_spatial_whitening:
            loc_chunk = numpy.dot(loc_chunk, spatial_whitening)
        if do_temporal_whitening:
            loc_chunk = scipy.ndimage.filters.convolve1d(loc_chunk, temporal_whitening, axis=0, mode='constant')
        median = numpy.median(loc_chunk, axis=0)
        return median
    
    def extract_median_absolute_deviation(chunk_size, gidx, median):
        """Extract the median absolute deviations from a chunk of extracellular traces"""
        loc_chunk, _ = data_file.get_data(gidx, chunk_size, nodes=nodes)
        # Whiten signal.
        if do_spatial_whitening:
            loc_chunk = numpy.dot(loc_chunk, spatial_whitening)
        if do_temporal_whitening:
            loc_chunk = scipy.ndimage.filters.convolve1d(loc_chunk, temporal_whitening, axis=0, mode='constant')
        mad = numpy.median(numpy.abs(loc_chunk - median), axis=0)
        return mad
    
    # Distribute chunks over the CPUs.
    all_chunks = numpy.arange(nb_chunks)
    loc_all_chunks = all_chunks[comm.rank::comm.size]
    loc_nb_chunks = len(loc_all_chunks)
    
    loc_nbs_chunks = comm.gather(loc_nb_chunks, root=0)
    
    if comm.rank == 0:
        print_and_log(["Computing extracellular medians..."],
                         level='default', logger=logger)
    
    to_explore = xrange(comm.rank, nb_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)
    
    medians = numpy.zeros((N_elec, loc_nb_chunks), dtype=numpy.float32)
    
    # For each chunk attributed to the current CPU.
    for count, gidx in enumerate(to_explore):
        gidx = all_chunks[gidx]
        medians[:, count] = extract_median(chunk_size, gidx)
    median = numpy.mean(medians, axis=1)
    
    comm.Barrier()
    
    medians = comm.gather(median, root=0)
    
    if comm.rank == 0:
        median = weighted_mean(loc_nbs_chunks, medians)
        
    # Broadcast medians to each CPU.
    median = comm.bcast(median, root=0)
    
    comm.Barrier()
    
    if comm.rank == 0:
        print_and_log(["Computing extracellular thresholds..."],
                         level='default', logger=logger)
    
    to_explore = xrange(comm.rank, nb_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)
    
    mads = numpy.zeros((N_elec, loc_nb_chunks), dtype=numpy.float32)
    
    # For each chunk attributed to the current CPU.
    for count, gidx in enumerate(to_explore):
        gidx = all_chunks[gidx]
        mads[:, count] = extract_median_absolute_deviation(chunk_size, gidx, median)
    mad = numpy.mean(mads, axis=1)
    
    comm.Barrier()
    
    mads = comm.gather(mad, root=0)
    
    if comm.rank == 0:
        mad = weighted_mean(loc_nbs_chunks, mads)
        
    # Broadcast median absolute deviation to each CPU.
    mad = comm.bcast(mad, root=0)
    
    comm.Barrier()
    data_file.close()
    
    return median, mad
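A tiny sketch of the reduction used above: each rank averages the medians of its own chunks, and the root combines the gathered per-rank values with a mean weighted by how many chunks each rank processed (this is what weighted_mean does).

import numpy

# Hypothetical gathered values: one per-channel average of medians per rank,
# plus the number of chunks each rank actually processed.
per_rank_medians = [numpy.array([0.10, -0.20]),
                    numpy.array([0.30, 0.00]),
                    numpy.array([0.20, 0.10])]
per_rank_counts = [4, 4, 2]

weights = numpy.array(per_rank_counts, dtype=numpy.float64)
weights /= weights.sum()
global_median = sum(w * m for w, m in zip(weights, per_rank_medians))
print(global_median)       # chunk-count-weighted combination of the per-rank medians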
Example #26
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    header += '''Utility to launch the MATLAB GUI and visualize the results
    '''
    parser = argparse.ArgumentParser(description=header,
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e', '--extension', help='extension to consider for visualization',
                        default='')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    extension = args.extension
    params = CircusParser(filename)
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)
    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    gain = data_file.gain
    t_start = data_file.t_start
    file_format = data_file.description

    if file_format not in supported_by_matlab:
        print_and_log(["File format %s is not supported by MATLAB. Waveforms disabled" % file_format], 'info', logger)

    if numpy.iterable(gain):
        print_and_log(['Multiple gains are not supported, using a default value of 1'], 'info', logger)
        gain = 1

    file_out_suff  = params.get('data', 'file_out_suff')
    if 'data_offset' in data_file.params:
        data_offset = data_file.data_offset
    else:
        data_offset = 0
    probe = params.probe
    if extension != '':
        extension = '-' + extension

    def generate_matlab_mapping(probe):
        p = {}
        positions = []
        nodes = []
        for key in list(probe['channel_groups'].keys()):
            p.update(probe['channel_groups'][key]['geometry'])
            nodes += probe['channel_groups'][key]['channels']
            positions += [p[channel] for channel in probe['channel_groups'][key]['channels']]
        idx = numpy.argsort(nodes)
        positions = numpy.array(positions)[idx]

        t = tempfile.NamedTemporaryFile().name + '.hdf5'
        cfile = h5py.File(t, 'w')
        to_write = {
            'positions': positions / 10.0,
            'permutation': numpy.sort(nodes),
            'nb_total': numpy.array([probe['total_nb_channels']])
        }
        write_datasets(cfile, list(to_write.keys()), to_write) 
        cfile.close()
        return t

    mapping = generate_matlab_mapping(probe)

    if not params.getboolean('data', 'overwrite'):
        filename = params.get('data', 'data_file_no_overwrite')
    else:
        filename = params.get('data', 'data_file')

    #apply_patch_for_similarities(params, extension)

    gui_file = pkg_resources.resource_filename('circus', os.path.join('matlab_GUI', 'SortingGUI.m'))
    # Change to the directory of the matlab file
    os.chdir(os.path.abspath(os.path.dirname(gui_file)))

    # Use quotation marks for string arguments
    if file_format not in supported_by_matlab:
        gui_params = [params.rate, os.path.abspath(file_out_suff), '%s.mat' % extension, mapping, 2, t_start]
        is_string = [False, True, True, True, False, False]

    else:

        gui_params = [params.rate, os.path.abspath(file_out_suff), '%s.mat' % extension, mapping, 2, t_start, data_dtype, data_offset, gain, filename]
        is_string = [False, True, True, True, False, False, True, False, False, True]

    arguments = ', '.join([
        "'%s'" % arg if s else "%s" % arg
        for arg, s in zip(gui_params, is_string)
    ])
    matlab_command = 'SortingGUI(%s)' % arguments

    print_and_log(["Launching the MATLAB GUI..."], 'info', logger)
    print_and_log([matlab_command], 'debug', logger)

    if params.getboolean('fitting', 'collect_all'):
        print_and_log(['You cannot view the unfitted spikes with the MATLAB GUI',
                       'Please consider using phy if you would like to see them'], 'info', logger)

    try:
        sys.exit(subprocess.call(['matlab', '-nodesktop', '-nosplash', '-r', matlab_command]))
    except Exception:
        if which('matlab') is not None:
            print_and_log(["Something wrong with MATLAB. Try circus-gui-python instead?"], 'error', logger)
        else:
            print_and_log(["MATLAB can not be found in the path. Please add it to the env variables"], 'error', logger)  
        sys.exit(1)
Example #27
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e',
                        '--extension',
                        help='extension to consider for visualization',
                        default='')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    extension = args.extension
    params = CircusParser(filename)
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    logger = init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    mytest = StrictVersion(phycontrib.__version__) >= StrictVersion("1.0.12")
    if not mytest:
        print_and_log(
            ['You need to update phy-contrib to the latest git version'],
            'error', logger)
        sys.exit(1)

    if not test_patch_for_similarities(params, extension):
        print_and_log(
            ['You should re-export the data because of a fix in 0.6'], 'error',
            logger)
        continue_anyway = query_yes_no(
            Fore.WHITE + "Continue anyway (results may not be fully correct)?",
            default=None)
        if not continue_anyway:
            sys.exit(1)

    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    if 'data_offset' in data_file.params:
        data_offset = data_file.data_offset
    else:
        data_offset = 0

    file_format = data_file.description
    file_out_suff = params.get('data', 'file_out_suff')

    if file_format not in supported_by_phy:
        print_and_log([
            "File format %s is not supported by phy. TraceView disabled" %
            file_format
        ], 'info', logger)

    if numpy.iterable(data_file.gain):
        print_and_log(
            ['Multiple gains are not supported, using a default value of 1'],
            'info', logger)
        gain = 1
    else:
        # Always define gain so it is never referenced before assignment.
        gain = data_file.gain
        if gain != 1:
            print_and_log([
                "Gain of %g is not supported by phy. Expecting a scaling mismatch"
                % gain
            ], 'info', logger)

    probe = params.probe
    if extension != '':
        extension = '-' + extension
    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'

    if not os.path.exists(output_path):
        print_and_log(
            ['Data should be first exported with the converting method!'],
            'error', logger)
    else:

        print_and_log(["Launching the phy GUI..."], 'info', logger)

        gui_params = {}
        if file_format in supported_by_phy:
            if not params.getboolean('data', 'overwrite'):
                gui_params['dat_path'] = params.get('data',
                                                    'data_file_no_overwrite')
            else:
                if params.get('data', 'stream_mode') == 'multi-files':
                    data_file = params.get_data_file(source=True,
                                                     has_been_created=False)
                    gui_params['dat_path'] = ' '.join(
                        data_file.get_file_names())
                else:
                    gui_params['dat_path'] = params.get('data', 'data_file')
        else:
            gui_params['dat_path'] = 'giverandomname.dat'
        gui_params['n_channels_dat'] = params.nb_channels
        gui_params['n_features_per_channel'] = 5
        gui_params['dtype'] = data_dtype
        gui_params['offset'] = data_offset
        gui_params['sample_rate'] = params.rate
        gui_params['dir_path'] = output_path
        gui_params['hp_filtered'] = True

        f = open(os.path.join(output_path, 'params.py'), 'w')
        for key, value in gui_params.items():
            if key in ['dir_path', 'dat_path', 'dtype']:
                f.write('%s = "%s"\n' % (key, value))
            else:
                f.write("%s = %s\n" % (key, value))
        f.close()
        os.chdir(output_path)
        create_app()
        controller = TemplateController(**gui_params)
        gui = controller.create_gui()

        gui.show()
        run_app()
        gui.close()
        del gui
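
A minimal sketch of the params.py file that phy reads, following the key handling above: path- and dtype-valued keys are written as quoted strings, everything else as Python literals (write_phy_params and its arguments are illustrative names, not part of SpyKING CIRCUS).

import os

def write_phy_params(output_path, dat_path, n_channels, dtype, offset, sample_rate):
    gui_params = {
        'dat_path': dat_path,
        'n_channels_dat': n_channels,
        'n_features_per_channel': 5,
        'dtype': dtype,
        'offset': offset,
        'sample_rate': sample_rate,
        'dir_path': output_path,
        'hp_filtered': True,
    }
    with open(os.path.join(output_path, 'params.py'), 'w') as f:
        for key, value in gui_params.items():
            # Quote the keys that phy expects as strings; write the rest as literals.
            if key in ['dir_path', 'dat_path', 'dtype']:
                f.write('%s = "%s"\n' % (key, value))
            else:
                f.write('%s = %s\n' % (key, value))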
Example #28
def main(params, nb_cpu, nb_gpu, use_gpu, extension):

    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.converting')
    data_file = params.data_file
    file_out_suff = params.get('data', 'file_out_suff')
    probe = params.probe
    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'
    N_e = params.getint('data', 'N_e')
    prelabelling = params.getboolean('converting', 'prelabelling')
    N_t = params.getint('detection', 'N_t')
    erase_all = params.getboolean('converting', 'erase_all')
    export_pcs = params.get('converting', 'export_pcs')
    export_all = params.getboolean('converting', 'export_all')
    sparse_export = params.getboolean('converting', 'sparse_export')
    rpv_threshold = params.getfloat('converting', 'rpv_threshold')
    if export_all and not params.getboolean('fitting', 'collect_all'):
        if comm.rank == 0:
            print_and_log([
                'Export unfitted spikes only if [fitting] collect_all is True'
            ], 'error', logger)
        sys.exit(0)

    def generate_mapping(probe):
        p = {}
        positions = []
        nodes = []
        shanks = []
        for key in probe['channel_groups'].keys():
            p.update(probe['channel_groups'][key]['geometry'])
            nodes += probe['channel_groups'][key]['channels']
            positions += [
                p[channel]
                for channel in probe['channel_groups'][key]['channels']
            ]
            shanks += [key] * len(probe['channel_groups'][key]['channels'])
        positions = numpy.array(positions)
        shanks = numpy.array(shanks)
        return positions, shanks

    def get_max_loc_channel(params, extension):
        if test_if_support(params, extension):
            supports = io.load_data(params, 'supports', extension)
            max_loc_channel = numpy.sum(supports, 1).max()
        else:
            nodes, edges = get_nodes_and_edges(params)
            max_loc_channel = 0
            for key in edges.keys():
                if len(edges[key]) > max_loc_channel:
                    max_loc_channel = len(edges[key])
        return max_loc_channel

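    # write_results() exports the fitted spikes in phy's format: spike_times.npy,
    # spike_templates.npy and amplitudes.npy, all sorted by spike time, plus an
    # optional cluster_group.tsv when prelabelling is enabled.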
    def write_results(path, params, extension):
        result = io.get_results(params, extension)
        spikes = [numpy.zeros(0, dtype=numpy.uint64)]
        clusters = [numpy.zeros(0, dtype=numpy.uint32)]
        amplitudes = [numpy.zeros(0, dtype=numpy.double)]
        N_tm = len(result['spiketimes'])

        has_purity = test_if_purity(params, extension)
        rpvs = []

        if prelabelling:
            labels = []
            norms = io.load_data(params, 'norm-templates', extension)
            norms = norms[:len(norms) // 2]
            if has_purity:
                purity = io.load_data(params, 'purity', extension)

        for key in result['spiketimes'].keys():
            temp_id = int(key.split('_')[-1])
            myspikes = result['spiketimes'].pop(key).astype(numpy.uint64)
            spikes.append(myspikes)
            myamplitudes = result['amplitudes'].pop(key).astype(numpy.double)
            amplitudes.append(myamplitudes[:, 0])
            clusters.append(temp_id *
                            numpy.ones(len(myamplitudes), dtype=numpy.uint32))
            rpv = get_rpv(myspikes, params.data_file.sampling_rate)
            rpvs += [[temp_id, rpv]]
            if prelabelling:
                if has_purity:
                    if rpv <= rpv_threshold:
                        if purity[temp_id] > 0.75:
                            labels += [[temp_id, 'good']]
                    else:
                        if purity[temp_id] > 0.75:
                            labels += [[temp_id, 'mua']]
                        else:
                            labels += [[temp_id, 'noise']]
                else:
                    median_amp = numpy.median(myamplitudes[:, 0])
                    std_amp = numpy.std(myamplitudes[:, 0])
                    if rpv <= rpv_threshold and numpy.abs(median_amp -
                                                          1) < 0.25:
                        labels += [[temp_id, 'good']]
                    else:
                        if median_amp < 0.5:
                            labels += [[temp_id, 'mua']]
                        elif norms[temp_id] < 0.1:
                            labels += [[temp_id, 'noise']]

        if export_all:
            print_and_log([
                "Last %d templates are unfitted spikes on all electrodes" % N_e
            ], 'info', logger)
            garbage = io.load_data(params, 'garbage', extension)
            for key in garbage['gspikes'].keys():
                elec_id = int(key.split('_')[-1])
                data = garbage['gspikes'].pop(key).astype(numpy.uint64)
                spikes.append(data)
                amplitudes.append(numpy.ones(len(data)))
                clusters.append((elec_id + N_tm) *
                                numpy.ones(len(data), dtype=numpy.uint32))

        if prelabelling:
            f = open(os.path.join(output_path, 'cluster_group.tsv'), 'w')
            f.write('cluster_id\tgroup\n')
            for l in labels:
                f.write('%s\t%s\n' % (l[0], l[1]))
            f.close()

        # f = open(os.path.join(output_path, 'cluster_rpv.tsv'), 'w')
        # f.write('cluster_id\trpv\n')
        # for l in rpvs:
        #     f.write('%s\t%s\n' % (l[0], l[1]))
        # f.close()

        spikes = numpy.concatenate(spikes).astype(numpy.uint64)
        amplitudes = numpy.concatenate(amplitudes).astype(numpy.double)
        clusters = numpy.concatenate(clusters).astype(numpy.uint32)

        idx = numpy.argsort(spikes)
        numpy.save(os.path.join(output_path, 'spike_templates'), clusters[idx])
        numpy.save(os.path.join(output_path, 'spike_times'), spikes[idx])
        numpy.save(os.path.join(output_path, 'amplitudes'), amplitudes[idx])
        return

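    # write_templates() exports templates.npy with shape
    # (n_templates, N_t, n_channels_max) and, when sparse_export is enabled,
    # template_ind.npy giving the channel indices actually used by each template.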
    def write_templates(path, params, extension):

        max_loc_channel = get_max_loc_channel(params, extension)
        templates = io.load_data(params, 'templates', extension)
        N_tm = templates.shape[1] // 2
        nodes, edges = get_nodes_and_edges(params)

        if sparse_export:
            n_channels_max = 0
            for t in range(N_tm):
                data = numpy.sum(
                    numpy.sum(templates[:, t].toarray().reshape(N_e, N_t), 1)
                    != 0)
                if data > n_channels_max:
                    n_channels_max = data
        else:
            n_channels_max = N_e

        if export_all:
            to_write_sparse = numpy.zeros((N_tm + N_e, N_t, n_channels_max),
                                          dtype=numpy.float32)
            mapping_sparse = -1 * numpy.ones(
                (N_tm + N_e, n_channels_max), dtype=numpy.int32)
        else:
            to_write_sparse = numpy.zeros((N_tm, N_t, n_channels_max),
                                          dtype=numpy.float32)
            mapping_sparse = -1 * numpy.ones(
                (N_tm, n_channels_max), dtype=numpy.int32)

        has_purity = test_if_purity(params, extension)
        if has_purity:
            purity = io.load_data(params, 'purity', extension)
            f = open(os.path.join(output_path, 'cluster_purity.tsv'), 'w')
            f.write('cluster_id\tpurity\n')
            for i in range(N_tm):
                f.write('%d\t%g\n' % (i, purity[i]))
            f.close()

        for t in range(N_tm):
            tmp = templates[:, t].toarray().reshape(N_e, N_t).T
            x, y = tmp.nonzero()
            nb_loc = len(numpy.unique(y))

            if sparse_export:
                all_positions = numpy.zeros(y.max() + 1, dtype=numpy.int32)
                all_positions[numpy.unique(y)] = numpy.arange(
                    nb_loc, dtype=numpy.int32)
                pos = all_positions[y]
                to_write_sparse[t, x, pos] = tmp[x, y]
                mapping_sparse[t, numpy.arange(nb_loc)] = numpy.unique(y)
            else:
                pos = y
                to_write_sparse[t, x, pos] = tmp[x, y]

        if export_all:
            garbage = io.load_data(params, 'garbage', extension)
            for t in range(N_tm, N_tm + N_e):
                elec = t - N_tm
                spikes = garbage['gspikes'].pop('elec_%d' % elec).astype(
                    numpy.uint64)
                spikes = numpy.random.permutation(spikes)[:100]
                mapping_sparse[t, 0] = t - N_tm
                waveform = io.get_stas(params,
                                       times_i=spikes,
                                       labels_i=numpy.ones(len(spikes)),
                                       src=elec,
                                       neighs=[elec],
                                       nodes=nodes,
                                       mean_mode=True)

                nb_loc = 1

                if sparse_export:
                    to_write_sparse[t, :, 0] = waveform
                else:
                    to_write_sparse[t, :, elec] = waveform

        numpy.save(os.path.join(output_path, 'templates'), to_write_sparse)

        if sparse_export:
            numpy.save(os.path.join(output_path, 'template_ind'),
                       mapping_sparse)

        return N_tm

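    # write_pcs() exports the PC features for phy: mode=0 keeps every spike of
    # every template, while mode=1 subsamples at most 500 spikes per template
    # and also writes their indices to pc_feature_spike_ids.npy.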
    def write_pcs(path, params, extension, N_tm, mode=0):

        spikes = numpy.load(os.path.join(output_path, 'spike_times.npy'))
        labels = numpy.load(os.path.join(output_path, 'spike_templates.npy'))
        max_loc_channel = get_max_loc_channel(params, extension)
        nb_features = params.getint('whitening', 'output_dim')
        sign_peaks = params.get('detection', 'peaks')
        nodes, edges = get_nodes_and_edges(params)
        N_total = params.getint('data', 'N_total')
        has_support = test_if_support(params, extension)
        if has_support:
            supports = io.load_data(params, 'supports', extension)
        else:
            inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
            inv_nodes[nodes] = numpy.arange(len(nodes))

        if export_all:
            nb_templates = N_tm + N_e
        else:
            nb_templates = N_tm

        pc_features_ind = numpy.zeros((nb_templates, max_loc_channel),
                                      dtype=numpy.int32)
        best_elec = io.load_data(params, 'electrodes', extension)
        if export_all:
            best_elec = numpy.concatenate((best_elec, numpy.arange(N_e)))

        if has_support:
            for count, support in enumerate(supports):
                nb_loc = numpy.sum(support)
                pc_features_ind[count, numpy.arange(nb_loc)] = numpy.where(support)[0]
        else:
            for count, elec in enumerate(best_elec):
                nb_loc = len(edges[nodes[elec]])
                pc_features_ind[count, numpy.arange(nb_loc)] = inv_nodes[edges[
                    nodes[elec]]]

        if sign_peaks in ['negative', 'both']:
            basis_proj, basis_rec = io.load_data(params, 'basis')
        elif sign_peaks in ['positive']:
            basis_proj, basis_rec = io.load_data(params, 'basis-pos')

        to_process = numpy.arange(comm.rank, nb_templates, comm.size)

        all_offsets = numpy.zeros(nb_templates, dtype=numpy.int32)
        for target in range(nb_templates):
            if mode == 0:
                all_offsets[target] = len(numpy.where(labels == target)[0])
            elif mode == 1:
                all_offsets[target] = min(
                    500, len(numpy.where(labels == target)[0]))

        all_paddings = numpy.concatenate(([0], numpy.cumsum(all_offsets)))
        total_pcs = numpy.sum(all_offsets)

        pc_file = os.path.join(output_path, 'pc_features.npy')
        pc_file_ids = os.path.join(output_path, 'pc_feature_spike_ids.npy')

        from numpy.lib.format import open_memmap

        if comm.rank == 0:
            pc_features = open_memmap(pc_file,
                                      shape=(total_pcs, nb_features,
                                             max_loc_channel),
                                      dtype=numpy.float32,
                                      mode='w+')
            if mode == 1:
                pc_ids = open_memmap(pc_file_ids,
                                     shape=(total_pcs, ),
                                     dtype=numpy.int32,
                                     mode='w+')

        comm.Barrier()
        pc_features = open_memmap(pc_file, mode='r+')
        if mode == 1:
            pc_ids = open_memmap(pc_file_ids, mode='r+')

        to_explore = range(comm.rank, nb_templates, comm.size)

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(params, to_explore)

        all_idx = numpy.zeros(0, dtype=numpy.int32)
        for gcount, target in enumerate(to_explore):

            count = all_paddings[target]

            if mode == 1:
                idx = numpy.random.permutation(
                    numpy.where(labels == target)[0])[:500]
                pc_ids[count:count + len(idx)] = idx
            elif mode == 0:
                idx = numpy.where(labels == target)[0]

            elec = best_elec[target]

            if has_support:
                indices = numpy.where(supports[target])[0]
            else:
                indices = inv_nodes[edges[nodes[elec]]]
            labels_i = target * numpy.ones(len(idx))
            times_i = numpy.take(spikes, idx).astype(numpy.int64)
            sub_data = io.get_stas(params,
                                   times_i,
                                   labels_i,
                                   elec,
                                   neighs=indices,
                                   nodes=nodes,
                                   auto_align=False)

            pcs = numpy.dot(sub_data, basis_proj)
            pcs = numpy.swapaxes(pcs, 1, 2)
            if mode == 0:
                pc_features[idx, :, :len(indices)] = pcs
            elif mode == 1:
                pc_features[count:count + len(idx), :, :len(indices)] = pcs

        comm.Barrier()

        if comm.rank == 0:
            numpy.save(os.path.join(output_path, 'pc_feature_ind'),
                       pc_features_ind.astype(
                           numpy.uint32))  # n_templates, n_loc_chan

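    # Rank 0 decides whether to (re)create the export directory, possibly after
    # prompting the user, and broadcasts the decision so all MPI ranks agree.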
    do_export = True
    if comm.rank == 0:
        if os.path.exists(output_path):
            if not erase_all:
                do_export = query_yes_no(
                    Fore.WHITE +
                    "Export already made! Do you want to erase everything?",
                    default=None)

            if do_export:
                if os.path.exists(os.path.abspath('.phy')):
                    shutil.rmtree(os.path.abspath('.phy'))
                shutil.rmtree(output_path)
        if do_export:
            comm.bcast(numpy.array([1], dtype=numpy.int32), root=0)
        else:
            comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
    else:
        do_export = bool(
            comm.bcast(numpy.array([0], dtype=numpy.int32), root=0))

    comm.Barrier()

    if do_export:

        apply_patch_for_similarities(params, extension)

        if comm.rank == 0:
            os.makedirs(output_path)
            print_and_log(
                ["Exporting data for the phy GUI with %d CPUs..." % nb_cpu],
                'info', logger)

            if params.getboolean('whitening', 'spatial'):
                whitening_mat = io.load_data(
                    params, 'spatial_whitening').astype(numpy.double)
                numpy.save(os.path.join(output_path, 'whitening_mat'),
                           whitening_mat)
                numpy.save(os.path.join(output_path, 'whitening_mat_inv'),
                           numpy.linalg.inv(whitening_mat))
            else:
                numpy.save(os.path.join(output_path, 'whitening_mat'),
                           numpy.eye(N_e))

            positions, shanks = generate_mapping(probe)
            numpy.save(os.path.join(output_path, 'channel_positions'),
                       positions.astype(numpy.double))
            numpy.save(os.path.join(output_path, 'channel_shanks'),
                       shanks.astype(numpy.double))
            nodes, edges = get_nodes_and_edges(params)
            numpy.save(os.path.join(output_path, 'channel_map'),
                       nodes.astype(numpy.int32))

            write_results(output_path, params, extension)

            N_tm = write_templates(output_path, params, extension)

            template_file = h5py.File(file_out_suff +
                                      '.templates%s.hdf5' % extension,
                                      'r',
                                      libver='earliest')
            similarities = template_file.get('maxoverlap')[:]
            template_file.close()
            norm = N_e * N_t

            if export_all:
                to_write = numpy.zeros((N_tm + N_e, N_tm + N_e),
                                       dtype=numpy.single)
                to_write[:N_tm, :N_tm] = (similarities[:N_tm, :N_tm] /
                                          norm).astype(numpy.single)
            else:
                to_write = (similarities[:N_tm, :N_tm] / norm).astype(
                    numpy.single)
            numpy.save(os.path.join(output_path, 'similar_templates'),
                       to_write)

            comm.bcast(numpy.array([N_tm], dtype=numpy.int32), root=0)

        else:
            N_tm = int(comm.bcast(numpy.array([0], dtype=numpy.int32), root=0))

        comm.Barrier()

        make_pcs = 2
        if comm.rank == 0:

            if export_pcs == 'prompt':
                key = ''
                while key not in ['a', 's', 'n']:
                    print(
                        Fore.WHITE +
                        "Do you want SpyKING CIRCUS to export PCs? (a)ll / (s)ome / (n)o"
                    )
                    key = input('')
            else:
                key = export_pcs

            if key == 'a':
                make_pcs = 0
                comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
            elif key == 's':
                make_pcs = 1
                comm.bcast(numpy.array([1], dtype=numpy.int32), root=0)
            elif key == 'n':
                comm.bcast(numpy.array([2], dtype=numpy.int32), root=0)
                if os.path.exists(os.path.join(output_path,
                                               'pc_features.npy')):
                    os.remove(os.path.join(output_path, 'pc_features.npy'))
                if os.path.exists(
                        os.path.join(output_path, 'pc_feature_ind.npy')):
                    os.remove(os.path.join(output_path, 'pc_feature_ind.npy'))
        else:
            make_pcs = comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
            make_pcs = make_pcs[0]

        comm.Barrier()
        if make_pcs < 2:
            write_pcs(output_path, params, extension, N_tm, make_pcs)

        supported_by_phy = ['raw_binary', 'mcs_raw_binary', 'mda']
        file_format = data_file.description
        gui_params = {}

        if file_format in supported_by_phy:
            if not params.getboolean('data', 'overwrite'):
                gui_params['dat_path'] = r"%s" % params.get(
                    'data', 'data_file_no_overwrite')
            else:
                if params.get('data', 'stream_mode') == 'multi-files':
                    data_file = params.get_data_file(source=True,
                                                     has_been_created=False)
                    gui_params['dat_path'] = "["
                    for f in data_file.get_file_names():
                        gui_params['dat_path'] += 'r"%s", ' % f
                    gui_params['dat_path'] += "]"
                else:
                    gui_params['dat_path'] = 'r"%s"' % params.get(
                        'data', 'data_file')
        else:
            gui_params['dat_path'] = 'giverandomname.dat'
        gui_params['n_channels_dat'] = params.nb_channels
        gui_params['n_features_per_channel'] = 5
        gui_params['dtype'] = data_file.data_dtype
        if 'data_offset' in data_file.params.keys():
            gui_params['offset'] = data_file.data_offset
        gui_params['sample_rate'] = params.rate
        gui_params['dir_path'] = output_path
        gui_params['hp_filtered'] = True

        f = open(os.path.join(output_path, 'params.py'), 'w')
        for key, value in gui_params.items():
            if key in ['dir_path', 'dtype']:
                f.write('%s = r"%s"\n' % (key, value))
            else:
                f.write("%s = %s\n" % (key, value))
        f.close()
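
A minimal sketch of the dense-to-sparse conversion performed in write_templates above: only the channels with non-zero samples are kept, and their indices are stored alongside the data, mirroring templates.npy and template_ind.npy (the function name to_sparse_template is illustrative, not part of SpyKING CIRCUS).

import numpy

def to_sparse_template(dense, n_channels_max):
    # dense: (N_t, N_e) template with at least one non-zero sample;
    # assumes the number of used channels does not exceed n_channels_max.
    x, y = dense.nonzero()
    used = numpy.unique(y)
    out = numpy.zeros((dense.shape[0], n_channels_max), dtype=numpy.float32)
    mapping = -1 * numpy.ones(n_channels_max, dtype=numpy.int32)
    # Remap the original channel indices to compact column positions.
    positions = numpy.zeros(y.max() + 1, dtype=numpy.int32)
    positions[used] = numpy.arange(len(used), dtype=numpy.int32)
    out[x, positions[y]] = dense[x, y]
    mapping[:len(used)] = used
    return out, mapping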
Example #29
    def get_data_file(self,
                      is_empty=False,
                      params=None,
                      source=False,
                      has_been_created=True):
        """
        Gets the datafile as described in the param files.

        Parameters
        ----------
        is_empty : bool

        params : dict

        source : bool

        has_been_created : bool
            if the data file was 

        Returns
        -------
        dict   
            A dictionary with the parameters of created data file.

        """

        if params is None:
            params = {}

        for key, value in self.parser._sections['data'].items():
            if key not in params:
                params[key] = value

        data_file = params.pop('data_file')
        stream_mode = self.get('data', 'stream_mode').lower()

        if stream_mode in ['none']:
            stream_mode = None

        if not self.getboolean('data', 'overwrite'):
            # If we do not want to overwrite, we first read the original data file
            # Then, if we do not want to obtain it as a source file, we switch the
            # format to raw_binary and the output file name

            if not source:

                # First we read the original data file, that should not be empty
                print_and_log(
                    ['Reading first the real data file to get the parameters'],
                    'debug', logger)
                tmp = self._create_data_file(data_file, False, params,
                                             stream_mode)

                # Then we change the data_file name
                data_file = self.get('data', 'data_file_no_overwrite')

                if comm.rank == 0:
                    print_and_log([
                        'Forcing the exported data file to be of type raw_binary'
                    ], 'debug', logger)

                # And we force the results to be of type float32, without streams
                params['file_format'] = 'raw_binary'
                params['data_dtype'] = 'float32'
                params['dtype_offset'] = 0
                params['data_offset'] = 0
                params['sampling_rate'] = self.rate
                params['nb_channels'] = self.nb_channels
                params['gain'] = self.gain
                stream_mode = None
                data_file, extension = os.path.splitext(data_file)
                data_file += ".dat"

            else:
                if has_been_created:
                    data_file = self.get('data', 'data_file_no_overwrite')
                    if not os.path.exists(data_file):
                        if comm.rank == 0:
                            lines = [
                                'The overwrite option is only valid if the filtering step is launched before!'
                            ]
                            print_and_log(lines, 'error', logger)
                        sys.exit(0)
                else:
                    if comm.rank == 0:
                        print_and_log([
                            'The copy file has not yet been created! Returning the original file'
                        ], 'debug', logger)

        return self._create_data_file(data_file, is_empty, params, stream_mode)
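
A small sketch of the renaming done above when overwrite is disabled: the exported file name is derived from the configured data_file_no_overwrite entry by swapping its extension for .dat, since the copy is forced to raw_binary (no_overwrite_name is an illustrative helper, not part of SpyKING CIRCUS).

import os

def no_overwrite_name(data_file):
    base, _ = os.path.splitext(data_file)
    return base + '.dat'

# Example: no_overwrite_name('/data/recording.h5') -> '/data/recording.dat'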
Example #30
def slice_clusters(params,
                   result,
                   to_remove=[],
                   to_merge=[],
                   extension='',
                   light=False):

    import h5py, shutil
    file_out_suff = params.get('data', 'file_out_suff')
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')

    if comm.rank == 0:

        print_and_log(['Node 0 is slicing clusters'], 'debug', logger)

        if to_merge != []:
            # Work on a copy so the caller's list (and the mutable default) is not modified in place.
            to_remove = list(to_remove)
            for count in range(len(to_merge)):
                to_remove.append(to_merge[count][1])

        all_elements = [[] for i in range(N_e)]
        for target in numpy.unique(to_remove):
            elec = result['electrodes'][target]
            nic = target - numpy.where(result['electrodes'] == elec)[0][0]
            mask = result['clusters_' + str(elec)] > -1
            tmp = numpy.unique(result['clusters_' + str(elec)][mask])
            all_elements[elec] += list(
                numpy.where(result['clusters_' + str(elec)] == tmp[nic])[0])

        for elec in range(N_e):
            if not light:
                result['data_' + str(elec)] = numpy.delete(result['data_' +
                                                                  str(elec)],
                                                           all_elements[elec],
                                                           axis=0)
                result['clusters_' + str(elec)] = numpy.delete(
                    result['clusters_' + str(elec)], all_elements[elec])
                result['times_' + str(elec)] = numpy.delete(
                    result['times_' + str(elec)], all_elements[elec])
                result['peaks_' + str(elec)] = numpy.delete(
                    result['peaks_' + str(elec)], all_elements[elec])
            else:

                result['clusters_' + str(elec)] = numpy.delete(
                    result['clusters_' + str(elec)], all_elements[elec])
                myfile = h5py.File(file_out_suff + '.clusters.hdf5',
                                   'r',
                                   libver='latest')
                data = myfile.get('data_' + str(elec))[:]
                result['data_' + str(elec)] = numpy.delete(data,
                                                           all_elements[elec],
                                                           axis=0)
                data = myfile.get('times_' + str(elec))[:]
                result['times_' + str(elec)] = numpy.delete(
                    data, all_elements[elec])
                data = myfile.get('peaks_' + str(elec))[:]
                result['peaks_' + str(elec)] = numpy.delete(
                    data, all_elements[elec])
                myfile.close()

        result['electrodes'] = numpy.delete(result['electrodes'],
                                            numpy.unique(to_remove))

        cfile = h5py.File(file_out_suff + '.clusters-new.hdf5',
                          'w',
                          libver='latest')
        to_write = ['data_', 'clusters_', 'times_', 'peaks_']
        for ielec in range(N_e):
            write_datasets(cfile, to_write, result, ielec)

        write_datasets(cfile, ['electrodes'], result)
        cfile.close()
        if os.path.exists(file_out_suff + '.clusters%s.hdf5' % extension):
            os.remove(file_out_suff + '.clusters%s.hdf5' % extension)
        shutil.move(file_out_suff + '.clusters-new.hdf5',
                    file_out_suff + '.clusters%s.hdf5' % extension)

    comm.Barrier()
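
A minimal sketch of the per-electrode deletion performed above: the selected members are removed from all aligned arrays of one electrode with numpy.delete (drop_cluster_members is an illustrative name, not part of SpyKING CIRCUS).

import numpy

def drop_cluster_members(result, elec, indices):
    for field in ('data_', 'clusters_', 'times_', 'peaks_'):
        key = field + str(elec)
        # data_<elec> is 2-D (one waveform per row); the other arrays are 1-D.
        axis = 0 if field == 'data_' else None
        result[key] = numpy.delete(result[key], indices, axis=axis)
    return result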