コード例 #1
0
ファイル: tridesclous.py プロジェクト: timsainb/spikesorters
    def _run(self, recording, output_folder):
        """Build a catalogue and run the Peeler for every channel group.

        Parameters
        ----------
        recording :
            Passed through ``recover_recording``; the result is not used
            any further inside this method.
        output_folder :
            Directory handed (as ``str``) to ``tdc.DataIO``.
        """
        recording = recover_recording(recording)
        tdc_dataio = tdc.DataIO(dirname=str(output_folder))

        # Copy so that the edits below never touch self.params.
        params = dict(self.params)
        # Dropped so they are not forwarded to make_nested_tdc_params().
        del params["chunk_mb"], params["n_jobs_bin"]

        clean_catalogue_gui = params.pop('clean_catalogue_gui')
        # make catalogue
        chan_grps = list(tdc_dataio.channel_groups.keys())
        for chan_grp in chan_grps:

            # parameters can change depending the group
            catalogue_nested_params = make_nested_tdc_params(
                tdc_dataio, chan_grp, **params)

            if self.verbose:
                print('catalogue_nested_params')
                pprint(catalogue_nested_params)

            peeler_params = tdc.get_auto_params_for_peelers(
                tdc_dataio, chan_grp)
            if self.verbose:
                print('peeler_params')
                pprint(peeler_params)

            cc = tdc.CatalogueConstructor(dataio=tdc_dataio, chan_grp=chan_grp)
            tdc.apply_all_catalogue_steps(cc,
                                          catalogue_nested_params,
                                          verbose=self.verbose)

            if clean_catalogue_gui:
                # Imported lazily: only needed when the GUI is requested.
                import pyqtgraph as pg
                app = pg.mkQApp()
                win = tdc.CatalogueWindow(cc)
                win.show()
                # Blocks until the Qt event loop returns.
                app.exec_()

            if self.verbose:
                print(cc)

            # Only for tridesclous older than 1.6.0 is the catalogue built
            # explicitly here.
            # NOTE(review): distutils is deprecated (PEP 632) and absent from
            # Python 3.12+; no replacement is imported in the visible code.
            if distutils.version.LooseVersion(tdc.__version__) < '1.6.0':
                print('You should upgrade tridesclous')
                t0 = time.perf_counter()
                cc.make_catalogue_for_peeler()
                if self.verbose:
                    t1 = time.perf_counter()
                    print('make_catalogue_for_peeler', t1 - t0)

            # apply Peeler (template matching)
            initial_catalogue = tdc_dataio.load_catalogue(chan_grp=chan_grp)
            peeler = tdc.Peeler(tdc_dataio)
            peeler.change_params(catalogue=initial_catalogue, **peeler_params)
            t0 = time.perf_counter()
            peeler.run(duration=None, progressbar=False)
            if self.verbose:
                t1 = time.perf_counter()
                # Label fixed: the timed call is peeler.run (was 'peeler.tun').
                print('peeler.run', t1 - t0)
コード例 #2
0
    def _run(self, recording, output_folder):
        """Build a catalogue and run the Peeler for every channel group.

        Parameters
        ----------
        recording :
            Only queried for its channel count (``get_num_channels()``).
        output_folder :
            Directory handed (as ``str``) to ``tdc.DataIO``.
        """
        nb_chan = recording.get_num_channels()
        tdc_dataio = tdc.DataIO(dirname=str(output_folder))

        # Copy so that popping below never touches self.params.
        params = dict(self.params)
        clean_catalogue_gui = params.pop('clean_catalogue_gui')
        # make catalogue
        chan_grps = list(tdc_dataio.channel_groups.keys())
        for chan_grp in chan_grps:

            # parameters can change depending the group
            catalogue_nested_params = make_nested_tdc_params(tdc_dataio, chan_grp, **params)

            if self.verbose:
                print('catalogue_nested_params')
                pprint(catalogue_nested_params)

            peeler_params = tdc.get_auto_params_for_peelers(tdc_dataio, chan_grp)
            if self.verbose:
                print('peeler_params')
                pprint(peeler_params)

            # check params and OpenCL when many channels
            # NOTE(review): the condition tests 'use_sparse_template' while the
            # message talks about OpenCL -- confirm the wording is intended.
            if nb_chan > 64 and not peeler_params['use_sparse_template']:
                print('OpenCL is not available processing will be slow, try install it')

            cc = tdc.CatalogueConstructor(dataio=tdc_dataio, chan_grp=chan_grp)
            tdc.apply_all_catalogue_steps(cc, catalogue_nested_params, verbose=self.verbose)

            if clean_catalogue_gui:
                # Imported lazily: only needed when the GUI is requested.
                import pyqtgraph as pg
                app = pg.mkQApp()
                win = tdc.CatalogueWindow(cc)
                win.show()
                # Blocks until the Qt event loop returns.
                app.exec_()

            if self.verbose:
                print(cc)

            t0 = time.perf_counter()
            cc.make_catalogue_for_peeler()
            if self.verbose:
                t1 = time.perf_counter()
                print('make_catalogue_for_peeler', t1-t0)

            # apply Peeler (template matching)
            initial_catalogue = tdc_dataio.load_catalogue(chan_grp=chan_grp)
            peeler = tdc.Peeler(tdc_dataio)
            peeler.change_params(catalogue=initial_catalogue, **peeler_params)
            t0 = time.perf_counter()
            peeler.run(duration=None, progressbar=False)
            if self.verbose:
                t1 = time.perf_counter()
                # Label fixed: the timed call is peeler.run (was 'peeler.tun').
                print('peeler.run', t1-t0)
コード例 #3
0
def find_clusters_and_show():
    """Run ``find_clusters`` (method='gmm', n_clusters=15) and show the result.

    Works on the ``catalogueconstructor`` object defined outside this
    function, opens a ``tdc.CatalogueWindow`` on it and returns only once
    the Qt application's ``exec_()`` call returns.
    """
    # Alternative clustering call left by the original author:
    #~ catalogueconstructor.find_clusters(method='kmeans', n_clusters=12)
    clustering_options = dict(method='gmm', n_clusters=15)
    catalogueconstructor.find_clusters(**clustering_options)

    qt_app = pg.mkQApp()
    window = tdc.CatalogueWindow(catalogueconstructor)
    window.show()
    qt_app.exec_()
コード例 #4
0
    def open_tridesclous_gui(self, channel_rel, label):
        """Sort one channel and let the user inspect it in the catalogue GUI.

        Nothing happens when ``run_spikesorter_on_channel`` returns ``None``.
        Otherwise ``self.plot_window`` is withdrawn while the catalogue
        window's Qt event loop runs, and restored afterwards.
        """
        catalogue = self.run_spikesorter_on_channel(channel_rel, label)
        if catalogue is not None:
            qt_app = pg.mkQApp()
            catalogue_window = tdc.CatalogueWindow(catalogue)

            # Hide our own window for as long as the catalogue GUI is open.
            self.plot_window.withdraw()
            catalogue_window.show()
            qt_app.exec_()
            self.plot_window.deiconify()

            # The user may have edited the clusters, so drop any plots for
            # this label; they are re-created on demand.
            self.remove_plots(label)
コード例 #5
0
ファイル: tdc.py プロジェクト: zenmar/tridesclous
def main():
    """Command-line entry point: parse arguments and dispatch on the command.

    The positional ``command`` defaults to ``'mainwin'`` and must be a member
    of the file-level ``comand_list``. ``-d/--dirname`` defaults to the current
    working directory and ``-c/--chan_grp`` to ``0``. ``-p/--parameters`` is
    parsed but not used by any branch below.

    On an unknown command, or when a window command is given a directory that
    is not an initialized dataset, a message is printed and the process exits
    with status 1.
    """
    argv = sys.argv[1:]

    parser = argparse.ArgumentParser(description='tridesclous')
    parser.add_argument('command',
                        help='command in [{}]'.format(txt_command_list),
                        default='mainwin',
                        nargs='?')

    parser.add_argument('-d',
                        '--dirname',
                        help='working directory',
                        default=None)
    parser.add_argument('-c',
                        '--chan_grp',
                        type=int,
                        help='channel group index',
                        default=0)
    parser.add_argument('-p',
                        '--parameters',
                        help='JSON parameter file',
                        default=None)

    args = parser.parse_args(argv)

    command = args.command
    if command not in comand_list:
        print('command should be in [{}]'.format(txt_command_list))
        # sys.exit rather than the site-provided exit() helper, and a
        # non-zero status so callers can detect the failure.
        sys.exit(1)

    dirname = args.dirname
    if dirname is None:
        dirname = os.getcwd()

    # Both window commands need an already initialized dataset directory.
    if command in ['cataloguewin', 'peelerwin']:
        if not tdc.DataIO.check_initialized(dirname):
            print('{} is not initialized'.format(dirname))
            sys.exit(1)
        dataio = tdc.DataIO(dirname=dirname)
        print(dataio)

    if command == 'mainwin':
        open_mainwindow()

    elif command == 'makecatalogue':
        # Accepted but does nothing here.
        pass

    elif command == 'runpeeler':
        # Accepted but does nothing here.
        pass

    elif command == 'cataloguewin':
        catalogueconstructor = tdc.CatalogueConstructor(dataio=dataio,
                                                        chan_grp=args.chan_grp)
        app = pg.mkQApp()
        win = tdc.CatalogueWindow(catalogueconstructor)
        win.show()
        app.exec_()

    elif command == 'peelerwin':
        initial_catalogue = dataio.load_catalogue(chan_grp=args.chan_grp)
        app = pg.mkQApp()
        win = tdc.PeelerWindow(dataio=dataio, catalogue=initial_catalogue)
        win.show()
        app.exec_()

    elif command == 'init':
        app = pg.mkQApp()
        win = tdc.InitializeDatasetWindow()
        win.show()
        app.exec_()
コード例 #6
0
def test_split_to_find_residual_minimize():
    """Experimental clustering by repeated one-dimensional histogram cuts.

    Reads waveforms and parameters from ``cc`` (not defined in this
    function). Every waveform starts with cluster label 0. For up to 1000
    iterations the waveforms carrying the current label ``k`` are selected,
    one dimension is chosen (negative median, largest MAD) and the values on
    that dimension are cut in two by ``dirty_cut``; waveforms above the cut
    are moved to label ``k + 1``. ``k`` advances when a cluster is judged
    finished (too few members, low MAD, or nothing moved by the cut).

    Afterwards the median and MAD of every resulting cluster are plotted,
    the labels are written back to ``cc.all_peaks['label']`` and a
    ``tdc.CatalogueWindow`` is opened.

    Relies on ``np``, ``plt``, ``pg``, ``tdc``, ``compute_median_mad``,
    ``cc`` and ``catalogueconstructor`` being available from outside.
    """

    import scipy.signal
    # NOTE(review): diptest is only referenced in commented-out code below.
    import diptest

    # Fetched but not used further down (only in commented-out code).
    labels = cc.all_peaks['label'][cc.some_peaks_index]
    #~ keep = labels>=0
    #~ labels = labels[keep]
    #~ features = cc.some_features[keep]
    waveforms = cc.some_waveforms

    n_left = cc.info['params_waveformextractor']['n_left']
    n_right = cc.info['params_waveformextractor']['n_right']
    # Number of samples per waveform on one channel.
    width = n_right - n_left

    # One label per waveform; everything starts in cluster 0.
    cluster_labels = np.zeros(waveforms.shape[0], dtype='int64')

    # Histogram bin edges used by dirty_cut: -30 up to (not including) 0,
    # step 0.1.
    bins = np.arange(-30, 0, 0.1)

    def dirty_cut(x, bins):
        """Split ``x`` in two at a local minimum of its smoothed histogram.

        Returns ``(labels, lim)`` where ``labels`` is 1 for values strictly
        above ``lim`` and 0 otherwise. When the smoothed histogram has no
        local minimum, ``lim`` is 0.
        """
        labels = np.zeros(x.size, dtype='int64')
        count, bins = np.histogram(x, bins=bins)
        #~ kernel = scipy.signal.get_window(10

        #~ kernel = scipy.signal.gaussian(51, 10)
        # NOTE(review): newer SciPy releases expose this window as
        # scipy.signal.windows.gaussian -- confirm against the installed
        # version.
        kernel = scipy.signal.gaussian(51, 5)
        #~ kernel = scipy.signal.gaussian(31, 10)
        # Normalise so that the kernel sums to 1.
        kernel /= np.sum(kernel)

        #~ fig, ax = plt.subplots()
        #~ ax.plot(kernel)
        #~ plt.show()

        #~ count[count==1]=0
        count_smooth = np.convolve(count, kernel, mode='same')

        # Local minima: strictly below the left neighbour and not above the
        # right neighbour.
        # NOTE(review): the indexes are relative to count_smooth[1:-1], so
        # they are shifted by one with respect to count_smooth/bins --
        # confirm this offset is intended.
        local_min_indexes, = np.nonzero(
            (count_smooth[1:-1] < count_smooth[:-2])
            & (count_smooth[1:-1] <= count_smooth[2:]))

        if local_min_indexes.size == 0:
            lim = 0
        else:

            # For every candidate minimum, count how many values fall in
            # bins whose left edge is <= the candidate limit.
            n_on_left = []
            for ind in local_min_indexes:
                lim = bins[ind]
                n = np.sum(count[bins[:-1] <= lim])
                n_on_left.append(n)
                #~ print('lim', lim, n)

                #~ if n>30:
                #~ break
                #~ else:
                #~ lim = None
            n_on_left = np.array(n_on_left)
            print('n_on_left', n_on_left, 'local_min_indexes',
                  local_min_indexes, x.size)
            # Keep the minimum whose left-hand count is closest to half of
            # the values.
            p = np.argmin(np.abs(n_on_left - x.size // 2))
            print('p', p)
            lim = bins[local_min_indexes[p]]

        #~ lim = bins[local_min[0]]
        #~ print(local_min, min(x), lim)

        #~ if x.size==3296:
        #~ fig, ax = plt.subplots()
        #~ ax.plot(bins[:-1], count, color='b')
        #~ ax.plot(bins[:-1], count_smooth, color='g')
        #~ if lim is not None:
        #~ ax.axvline(lim)
        #~ plt.show()

        # NOTE(review): lim is always assigned above (0 or a bin edge), so
        # this branch looks unreachable as the code stands.
        if lim is None:
            return None, None

        labels[x > lim] = 1

        return labels, lim

    # Flatten each waveform to one row.
    # assumes waveforms is (n_spikes, width, n_channels), so that after the
    # swap each channel occupies `width` consecutive columns -- TODO confirm
    wf = waveforms.swapaxes(1, 2).reshape(waveforms.shape[0], -1)

    # k is the cluster label currently being split; dim_visited lists the
    # columns of wf already ruled out for this k.
    k = 0
    dim_visited = []
    for i in range(1000):

        # Waveforms not yet settled into a label below k.
        left_over = np.sum(cluster_labels >= k)
        print()
        print('i', i, 'k', k, 'left_over', left_over)

        sel = cluster_labels == k

        # After the first iteration, stop once fewer than 30 waveforms
        # remain; the current cluster gets label -k.
        if i != 0 and left_over < 30:  # or k==40:
            cluster_labels[sel] = -k
            print('BREAK left_over<30')
            break

        wf_sel = wf[sel]
        n_with_label = wf_sel.shape[0]
        print('n_with_label', n_with_label)

        # Fewer than 30 members: relabel to -k and move on to the next k.
        if wf_sel.shape[0] < 30:
            print('too few')
            cluster_labels[sel] = -k
            k += 1
            dim_visited = []
            continue

        med, mad = compute_median_mad(wf_sel)

        # MAD below 1.6 on every dimension: keep the cluster as is.
        if np.all(mad < 1.6):
            print('mad<1.6')
            k += 1
            dim_visited = []
            continue

        # Fewer than 100 members: keep the cluster as is (np.all is applied
        # to a single boolean here).
        if np.all(wf_sel.shape[0] < 100):
            print('Too small cluster')
            k += 1
            dim_visited = []
            continue

        #~ weight = mad-1
        #~ feat = wf_sel * weight
        #~ feat = wf_sel[:, np.argmax(mad), None]
        #~ print(feat.shape)
        #~ print(feat)

        #~ while True:

        # Candidate dimensions: negative median and not already visited.
        possible_dim, = np.nonzero(med < 0)
        possible_dim = possible_dim[~np.in1d(possible_dim, dim_visited)]
        if len(possible_dim) == 0:
            print('BREAK len(possible_dim)==0')
            #~ dim = None
            break
        # Among the candidates, take the one with the largest MAD.
        dim = possible_dim[np.argmax(mad[possible_dim])]
        print('dim', dim)
        #~ dim_visited.append(dim)
        #~ print('dim', dim, 'dim_visited',dim_visited)
        #~ feat = wf_sel[:, dim]

        #~ dip_values = np.zeros(possible_dim.size)
        #~ for j, dim in enumerate(possible_dim):
        #~ print('j', j)
        #~ dip, f = diptest.dip(wf_sel[:, dim], full_output=True, x_is_sorted=False)
        #~ dip_values[j] = dip

        #~ dip_values[j] = dip

        #~ dim = possible_dim[np.argmin(dip_values)]
        #~ dip = dip_values[np.argmin(dip_values)]

        #~ dip, f = diptest.dip(wf_sel[:, dim], full_output=True, x_is_sorted=False)
        #~ dip, pval = diptest.diptest(wf_sel[:, dim])

        #~ print('dim', dim, 'dip', dip, 'pval', pval)

        #~ fig, axs = plt.subplots(nrows=2, sharex=True)
        #~ axs[0].fill_between(np.arange(med.size), med-mad, med+mad, alpha=.5)
        #~ axs[0].plot(np.arange(med.size), med)
        #~ axs[0].set_title(str(wf_sel.shape[0]))
        #~ axs[1].plot(mad)
        #~ axs[1].axvline(dim, color='r')
        #~ plt.show()

        #~ if dip<0.01:
        #~ break

        feat = wf_sel[:, dim]

        labels, lim = dirty_cut(feat, bins)
        # Only taken if dirty_cut returns (None, None): rule out the whole
        # block of `width` columns containing dim and retry.
        if labels is None:
            channel_index = dim // width
            #~ dim_visited.append(dim)
            dim_visited.extend(
                range(channel_index * width, (channel_index + 1) * width))
            print('loop', dim_visited)
            continue

        print(feat[labels == 0].size, feat[labels == 1].size)

        #~ fig, ax = plt.subplots()
        #~ count, bins = np.histogram(feat, bins=bins)
        #~ count0, bins = np.histogram(feat[labels==0], bins=bins)
        #~ count1, bins = np.histogram(feat[labels==1], bins=bins)
        #~ ax.axvline(lim)
        #~ ax.plot(bins[:-1], count, color='b')
        #~ ax.plot(bins[:-1], count0,  color='r')
        #~ ax.plot(bins[:-1], count1, color='g')
        #~ ax.set_title('dim {} dip {:.4f} pval {:.4f}'.format(dim, dip, pval))
        #~ plt.show()

        #~ if pval>0.1:
        #~ print('BREAK pval>0.05')
        #~ break

        # Positions (in the full array) of the waveforms currently in k.
        ind, = np.nonzero(sel)

        #~ med0, mad0 = compute_median_mad(feat[labels==0])
        #~ med1, mad1 = compute_median_mad(feat[labels==1])
        #~ if np.abs(med0)>np.abs(med1):
        #~ if mad0<mad1:
        #~ cluster_labels[ind[labels==1]] += 1
        #~ else:
        #~ cluster_labels[ind[labels==0]] += 1

        #~ if dip>0.05 and pval<0.01:

        print('nb1', np.sum(labels == 1), 'nb0', np.sum(labels == 0))

        # Nothing at or below the cut: rule out the whole block of `width`
        # columns containing dim and retry with another dimension.
        if np.sum(labels == 0) == 0:
            channel_index = dim // width
            #~ dim_visited.append(dim)
            dim_visited.extend(
                range(channel_index * width, (channel_index + 1) * width))
            print('nb0==0', dim_visited)

            continue

        #~ cluster_labels[cluster_labels>k] += 1#TODO think about this!!!
        # Waveforms above the cut move from label k to label k + 1.
        cluster_labels[ind[labels == 1]] += 1

        # Nothing was moved, so cluster k is final: go on to the next label.
        if np.sum(labels == 1) == 0:
            k += 1
            dim_visited = []

        #~ med0, mad0 = compute_median_mad(feat[labels==0])

    # Plot the median (top) and MAD (bottom) of every resulting cluster.
    fig, axs = plt.subplots(nrows=2)
    for k in np.unique(cluster_labels):
        sel = cluster_labels == k
        wf_sel = wf[sel]

        med, mad = compute_median_mad(wf_sel)
        axs[0].plot(med, label=str(k))
        axs[1].plot(mad)

    plt.show()

    # Write the new labels back for the sampled peaks.
    cc.all_peaks['label'][cc.some_peaks_index] = cluster_labels

    app = pg.mkQApp()
    # NOTE(review): the window is opened on `catalogueconstructor` while the
    # rest of this function uses `cc` -- confirm both names refer to the same
    # object.
    win = tdc.CatalogueWindow(catalogueconstructor)
    win.show()

    app.exec_()