예제 #1
0
    def write_results(path, params, extension):
        """Write spike times, template ids and amplitudes as ``.npy`` files.

        For every entry of ``result['spiketimes']`` (keys ending in
        ``_<template id>``) the spike times, the first amplitude column and a
        matching array of template ids are collected.  When ``export_all`` is
        set, the unfitted ("garbage") spikes of each electrode are appended
        with amplitude 1 and id ``electrode id + number of templates``.  All
        three arrays are then sorted by spike time and saved in
        ``output_path`` as ``spike_templates``, ``spike_times`` and
        ``amplitudes``.

        When ``prelabelling`` is set, a ``cluster_group.tsv`` file is also
        written, tagging templates as ``good``, ``mua`` or ``noise`` from their
        refractory-period violations (``get_rpv``) combined with either the
        stored purity or the median amplitude and template norm.  Templates
        matching none of the rules get no line in that file.

        Args:
            path: Unused; files are written to ``output_path`` taken from the
                enclosing scope.
            params: Parameter object handed to the ``io`` helpers; its
                ``data_file.sampling_rate`` is passed to ``get_rpv``.
            extension: Suffix forwarded to the ``io`` helpers.
        """
        result = io.get_results(params, extension)
        # Seed each list with an empty, correctly typed array so that the
        # final concatenation works even when there is nothing to export.
        spikes = [numpy.zeros(0, dtype=numpy.uint64)]
        clusters = [numpy.zeros(0, dtype=numpy.uint32)]
        amplitudes = [numpy.zeros(0, dtype=numpy.double)]
        N_tm = len(result['spiketimes'])

        has_purity = test_if_purity(params, extension)

        if prelabelling:
            labels = []
            norms = io.load_data(params, 'norm-templates', extension)
            # Only the first half of the norms is indexed by template id below.
            norms = norms[:len(norms) // 2]
            if has_purity:
                purity = io.load_data(params, 'purity', extension)

        # Entries are popped while iterating (hence the key snapshot) so each
        # per-template array is dropped from `result` once it is collected.
        for key in list(result['spiketimes'].keys()):
            temp_id = int(key.split('_')[-1])
            myspikes = result['spiketimes'].pop(key).astype(numpy.uint64)
            myamplitudes = result['amplitudes'].pop(key).astype(numpy.double)

            spikes.append(myspikes)
            amplitudes.append(myamplitudes[:, 0])
            clusters.append(
                numpy.full(len(myamplitudes), temp_id, dtype=numpy.uint32))

            rpv = get_rpv(myspikes, params.data_file.sampling_rate)

            if not prelabelling:
                continue

            label = None
            if has_purity:
                is_pure = purity[temp_id] > 0.75
                if rpv <= rpv_threshold:
                    # Few violations but low purity: deliberately unlabelled.
                    if is_pure:
                        label = 'good'
                elif is_pure:
                    label = 'mua'
                else:
                    label = 'noise'
            else:
                median_amp = numpy.median(myamplitudes[:, 0])
                if rpv <= rpv_threshold and numpy.abs(median_amp - 1) < 0.25:
                    label = 'good'
                elif median_amp < 0.5:
                    label = 'mua'
                elif norms[temp_id] < 0.1:
                    label = 'noise'

            if label is not None:
                labels.append([temp_id, label])

        if export_all:
            print_and_log([
                "Last %d templates are unfitted spikes on all electrodes" % N_e
            ], 'info', logger)
            garbage = io.load_data(params, 'garbage', extension)
            for key in list(garbage['gspikes'].keys()):
                elec_id = int(key.split('_')[-1])
                data = garbage['gspikes'].pop(key).astype(numpy.uint64)
                spikes.append(data)
                amplitudes.append(numpy.ones(len(data)))
                # Garbage ids are placed after the real template ids.
                clusters.append(
                    numpy.full(len(data), elec_id + N_tm, dtype=numpy.uint32))

        if prelabelling:
            # Context manager guarantees the file is closed even if a write
            # fails part-way through.
            with open(os.path.join(output_path, 'cluster_group.tsv'), 'w') as f:
                f.write('cluster_id\tgroup\n')
                for cluster_id, group in labels:
                    f.write('%s\t%s\n' % (cluster_id, group))

        spikes = numpy.concatenate(spikes).astype(numpy.uint64)
        amplitudes = numpy.concatenate(amplitudes).astype(numpy.double)
        clusters = numpy.concatenate(clusters).astype(numpy.uint32)

        # Save the three arrays in a common, time-sorted order.
        idx = numpy.argsort(spikes)
        numpy.save(os.path.join(output_path, 'spike_templates'), clusters[idx])
        numpy.save(os.path.join(output_path, 'spike_times'), spikes[idx])
        numpy.save(os.path.join(output_path, 'amplitudes'), amplitudes[idx])
예제 #2
0
    def write_templates(path, params, extension):
        """Write the templates as a dense ``templates.npy`` array.

        The saved array has shape ``(n_templates, N_t, n_channels)``.  With
        ``sparse_export`` only the channels holding non-zero samples are kept
        for each template (packed to the left) and the matching channel
        indices are saved as ``template_ind.npy``, padded with ``-1``.
        Otherwise all ``N_e`` channels are written.

        When ``export_all`` is set, ``N_e`` extra rows are appended, one per
        electrode, filled with the waveform returned by ``io.get_stas`` for up
        to 100 randomly chosen unfitted spikes of that electrode.

        If purity data is available, ``cluster_purity.tsv`` is written too.

        Args:
            path: Unused; files are written to ``output_path`` taken from the
                enclosing scope.
            params: Parameter object handed to the ``io`` helpers.
            extension: Suffix forwarded to the ``io`` helpers.

        Returns:
            The number of templates, i.e. half the number of columns of the
            loaded template matrix.
        """
        # NOTE(review): the result is not used in this function; the call is
        # kept in case it has side effects -- confirm and drop if it has none.
        max_loc_channel = get_max_loc_channel(params, extension)
        templates = io.load_data(params, 'templates', extension)
        # Only the first half of the columns is exported.
        N_tm = templates.shape[1] // 2
        nodes, _ = get_nodes_and_edges(params)

        if sparse_export:
            # Size the channel axis with the same criterion used to fill it
            # below (any non-zero sample), so a channel whose samples happen to
            # sum to zero cannot overflow the allocated width.
            n_channels_max = 0
            for t in range(N_tm):
                dense = templates[:, t].toarray().reshape(N_e, N_t)
                n_active = int(
                    numpy.count_nonzero(numpy.any(dense != 0, axis=1)))
                n_channels_max = max(n_channels_max, n_active)
        else:
            n_channels_max = N_e

        n_rows = N_tm + N_e if export_all else N_tm
        to_write_sparse = numpy.zeros((n_rows, N_t, n_channels_max),
                                      dtype=numpy.float32)
        mapping_sparse = numpy.full((n_rows, n_channels_max), -1,
                                    dtype=numpy.int32)

        has_purity = test_if_purity(params, extension)
        if has_purity:
            purity = io.load_data(params, 'purity', extension)
            # Context manager guarantees the file is closed even if a write
            # fails part-way through.
            purity_path = os.path.join(output_path, 'cluster_purity.tsv')
            with open(purity_path, 'w') as f:
                f.write('cluster_id\tpurity\n')
                for i in range(N_tm):
                    f.write('%d\t%g\n' % (i, purity[i]))

        for t in range(N_tm):
            # (N_t, N_e) view of the template: x indexes time, y channels.
            tmp = templates[:, t].toarray().reshape(N_e, N_t).T
            x, y = tmp.nonzero()

            if sparse_export:
                # `channels` are the sorted channels in use; `pos` maps every
                # non-zero sample to the rank of its channel.  This also copes
                # with an all-zero template, which is left as zeros / -1.
                channels, pos = numpy.unique(y, return_inverse=True)
                to_write_sparse[t, x, pos] = tmp[x, y]
                mapping_sparse[t, :len(channels)] = channels
            else:
                to_write_sparse[t, x, y] = tmp[x, y]

        if export_all:
            garbage = io.load_data(params, 'garbage', extension)
            for elec in range(N_e):
                t = N_tm + elec
                spikes = garbage['gspikes'].pop('elec_%d' % elec).astype(
                    numpy.int64)
                # Random subset of at most 100 spikes.
                spikes = numpy.random.permutation(spikes)[:100]
                mapping_sparse[t, 0] = elec
                waveform = io.get_stas(params,
                                       times_i=spikes,
                                       labels_i=numpy.ones(len(spikes)),
                                       src=elec,
                                       neighs=[elec],
                                       nodes=nodes,
                                       mean_mode=True)

                # Sparse layout packs the single channel in slot 0; the dense
                # layout stores it at the electrode's own index.
                column = 0 if sparse_export else elec
                to_write_sparse[t, :, column] = waveform

        numpy.save(os.path.join(output_path, 'templates'), to_write_sparse)

        if sparse_export:
            numpy.save(os.path.join(output_path, 'template_ind'),
                       mapping_sparse)

        return N_tm