def make_catalogue_figure():
    dataio = DataIO(dirname=dirname)
    catalogue = dataio.load_catalogue(chan_grp=0)
    clusters = catalogue['clusters']
    geometry = dataio.get_geometry(chan_grp=0)

    fig, ax = plt.subplots()
    ax.set_title('Catalogue has 4 templates')
    for i in range(clusters.size):
        color = clusters[i]['color']
        color = int32_to_rgba(color, mode='float')
        waveforms = catalogue['centers0'][i:i + 1]
        plot_waveforms_with_geometry(waveforms, channels, geometry,
                                     ax=ax, ratioY=3, deltaX=50, margin=50,
                                     color=color, linewidth=3, alpha=1,
                                     show_amplitude=True, ratio_mad=8)

    fig.savefig('../img/peeler_templates_for_animation.png')
def apply_peeler():
    dataio = DataIO(dirname=dirname)
    catalogue = dataio.load_catalogue(chan_grp=0)

    peeler = Peeler(dataio)
    peeler.change_params(catalogue=catalogue, chunksize=1024)
    peeler.run(progressbar=True)
def preprocess_array(array_idx, output_dir, total_duration):
    dataio = DataIO(dirname=output_dir, ch_grp=array_idx)

    fullchain_kargs = {
        'duration': total_duration,
        'preprocessor': {
            'highpass_freq': 250.,
            'lowpass_freq': 3000.,
            'smooth_size': 0,
            'common_ref_removal': True,
            'chunksize': 32768,
            'lostfront_chunksize': 0,
            'signalpreprocessor_engine': 'numpy',
        }
    }

    cc = CatalogueConstructor(dataio=dataio, chan_grp=array_idx)

    p = {}
    p.update(fullchain_kargs['preprocessor'])
    cc.set_preprocessor_params(**p)

    # TODO offer noise estimation duration somewhere
    noise_duration = min(10., fullchain_kargs['duration'],
                         dataio.get_segment_length(seg_num=0) / dataio.sample_rate * .99)
    # ~ print('noise_duration', noise_duration)

    t1 = time.perf_counter()
    cc.estimate_signals_noise(seg_num=0, duration=noise_duration)
    t2 = time.perf_counter()
    print('estimate_signals_noise', t2 - t1)

    t1 = time.perf_counter()
    cc.run_signalprocessor(duration=fullchain_kargs['duration'], detect_peak=False)
    t2 = time.perf_counter()
    print('run_signalprocessor', t2 - t1)
def test_good_events():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=2)
    #~ peak_pos = peak_pos[:100]
    #~ print(peak_pos)

    waveforms = extract_peak_waveforms(sigs, peak_pos, peak_pos, -30, 50)
    keep = good_events(waveforms, upper_thr=5., lower_thr=-5.)
    #~ print(keep)

    goods_wf = waveforms[keep]
    bads_wf = waveforms[~keep]

    fig, ax = pyplot.subplots()
    #~ goods_wf.transpose().plot(ax=ax, color='g', lw=.3)
    bads_wf.transpose().plot(ax=ax, color='r', lw=.3)

    med = waveforms.median(axis=0)
    mad = np.median(np.abs(waveforms - med), axis=0) * 1.4826
    limit1 = med + 5 * mad
    limit2 = med - 5 * mad
    med.plot(ax=ax, color='m', lw=2)
    limit1.plot(ax=ax, color='m')
    limit2.plot(ax=ax, color='m')
def test_clustering():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    # peak
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=2)

    # waveforms
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold=1.1)
    short_wf = waveformextractor.get_ajusted_waveforms()
    print(short_wf.shape)

    # clustering
    clustering = Clustering(short_wf)

    # PCA
    features = clustering.project(method='pca', n_components=5)
    clustering.plot_projection(plot_density=True)

    # K-means
    labels = clustering.find_clusters(7)
    clustering.plot_projection(plot_density=True)

    # make catalogue
    catalogue = clustering.construct_catalogue()
    clustering.plot_derivatives()
    clustering.plot_catalogue()

    clustering.merge_cluster(1, 2)
    clustering.split_cluster(1, 2)
def test_rectify_signals():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)
    rectified_sigs = rectify_signals(normalize_signals(sigs), threshold=-4)

    fig, ax = pyplot.subplots()
    rectified_sigs[3.14:3.22].plot(ax=ax)
    ax.set_ylim(-20, 10)
def open_PeelerWindow(dirname, chan_grp):
    dataio = DataIO(dirname=dirname)
    initial_catalogue = dataio.load_catalogue(chan_grp=chan_grp)

    app = pg.mkQApp()
    win = PeelerWindow(dataio=dataio, catalogue=initial_catalogue)
    win.show()
    app.exec_()
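A minimal sketch of how open_PeelerWindow might be invoked once a Peeler run has produced spikes; the directory name and channel group below are placeholders for illustration, not values taken from the scripts above.

if __name__ == '__main__':
    # Hypothetical sorting-output directory and channel group; the GUI needs a
    # Qt event loop, which open_PeelerWindow starts via pg.mkQApp()/app.exec_().
    open_PeelerWindow(dirname='tridesclous_output', chan_grp=0)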
def test_peakdetector():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    peakdetector = PeakDetector(sigs, seg_num=0)

    peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=2)
    print(peakdetector.peak_pos.size)

    peakdetector.detect_peaks(threshold=-5, peak_sign='-', n_span=5)
    print(peakdetector.peak_pos.size)
def test_DataIO_probes():
    # initialize dataio
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)

    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)

    probe_filename = 'A4x8-5mm-100-400-413-A32.prb'
    dataio.download_probe(probe_filename)
    dataio.download_probe('A4x8-5mm-100-400-413-A32')

    #~ print(dataio.channel_groups)
    #~ print(dataio.channels)
    #~ print(dataio.info['probe_filename'])

    assert dataio.nb_channel(0) == 8
    assert probe_filename == dataio.info['probe_filename']

    dataio = DataIO(dirname='test_DataIO')
    print(dataio)
def test_peeler():
    dataio = DataIO(dirname='datatest')
    #~ dataio = DataIO(dirname='datatest_neo')
    sigs = dataio.get_signals(seg_num=0)

    # peak
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=5)

    # waveforms
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold=1.1)
    #~ print(limit_left, limit_right)
    short_wf = waveformextractor.get_ajusted_waveforms()
    #~ print(short_wf.shape)

    # clustering
    clustering = Clustering(short_wf)
    features = clustering.project(method='pca', n_components=4)
    clustering.find_clusters(8, order_clusters=True)
    catalogue = clustering.construct_catalogue()
    #~ clustering.plot_catalogue(sameax=False)
    #~ clustering.plot_catalogue(sameax=True)

    #~ clustering.merge_cluster(1, 2)
    catalogue = clustering.construct_catalogue()
    clustering.plot_catalogue(sameax=False)
    #~ clustering.plot_catalogue(sameax=True)

    # peeler
    signals = peakdetector.normed_sigs
    peeler = Peeler(signals, catalogue, limit_left, limit_right,
                    threshold=-5., peak_sign='-', n_span=5)
    prediction0, residuals0 = peeler.peel()
    prediction1, residuals1 = peeler.peel()

    spiketrains = peeler.get_spiketrains()
    print(spiketrains)

    fig, axs = pyplot.subplots(nrows=6, sharex=True)  #, sharey=True)
    axs[0].plot(signals)
    axs[1].plot(prediction0)
    axs[2].plot(residuals0)
    axs[3].plot(prediction1)
    axs[4].plot(residuals1)
    for i in range(5):
        axs[i].set_ylim(-25, 10)
    peeler.plot_spiketrains(ax=axs[5])
def test_DataIO():
    # initialize dataio
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)

    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)

    #~ dataio.set_channels(range(4))
    dataio.set_manual_channel_group(range(14))

    for seg_num in range(dataio.nb_segment):
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num, chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
            #~ print(seg_num, i_stop, sigs_chunk.shape)

    # reopen existing
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)
    #~ exit()

    for seg_num in range(dataio.nb_segment):
        #~ print('seg_num', seg_num)
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num, chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
def test_extract_noise_waveforms():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=2)

    waveforms = extract_noise_waveforms(sigs, peak_pos, -30, 50, size=500)
    print(waveforms.shape)

    fig, ax = pyplot.subplots()
    waveforms.median(axis=0).plot(ax=ax)
def run_peeler(dirname, chan_grp):
    dataio = DataIO(dirname=dirname, ch_grp=chan_grp)
    initial_catalogue = dataio.load_catalogue(chan_grp=chan_grp)

    peeler = Peeler(dataio)
    peeler.change_params(catalogue=initial_catalogue, chunksize=32768,
                         use_sparse_template=False, sparse_threshold_mad=1.5,
                         use_opencl_with_sparse=False)

    t1 = time.perf_counter()
    peeler.run()
    t2 = time.perf_counter()
    print('peeler.run', t2 - t1)
def test_filter():
    dataio = DataIO(dirname='datatest_neo')
    sigs = dataio.get_signals(seg_num=0, signal_type='unfiltered')

    filter1 = SignalFilter(sigs, highpass_freq=300.)
    filtered_sigs = filter1.get_filtered_data()

    filter2 = SignalFilter(sigs, highpass_freq=300., box_smooth=3)
    filtered_sigs2 = filter2.get_filtered_data()

    fig, axs = pyplot.subplots(nrows=3, sharex=True)
    sigs[0:.5].plot(ax=axs[0])
    filtered_sigs[0:.5].plot(ax=axs[1])
    filtered_sigs2[0:.5].plot(ax=axs[2])
def run_merge_clusters(subject, recording_date, ch_grp, threshold):
    data_dir = os.path.join(cfg['intan_data_dir'], subject, recording_date)
    print(data_dir)

    if os.path.exists(data_dir):
        # Compute total duration (want to use all data for clustering)
        data_file_names = []
        for x in os.listdir(data_dir):
            if os.path.splitext(x)[1] == '.raw':
                data_file_names.append(os.path.join(data_dir, x))

        data_file_times = []
        for idx, evt_file in enumerate(data_file_names):
            fname = os.path.split(evt_file)[-1]
            fparts = fname.split('_')
            filedate = datetime.strptime('%s %s' % (fparts[4], fparts[5].split('.')[0]),
                                         '%y%m%d %H%M%S')
            data_file_times.append(filedate)
        data_file_names = [x for _, x in sorted(zip(data_file_times, data_file_names))]

        if os.path.exists(data_dir) and len(data_file_names) > 0:
            output_dir = os.path.join(cfg['single_unit_spike_sorting_dir'], subject, recording_date)

            ## Setup DataIO
            dataio = DataIO(dirname=output_dir)
            #dataio.set_data_source(type='RawData', filenames=data_file_names, dtype='float32', sample_rate=30000, total_channel=192)

            merge_clusters(dataio, ch_grp, output_dir, threshold)
def preprocess_signals_and_peaks(dirname):
    dataio = DataIO(dirname=dirname)
    catalogueconstructor = CatalogueConstructor(dataio=dataio)

    catalogueconstructor.set_preprocessor_params(
        chunksize=1024,
        # signal preprocessor
        highpass_freq=250,
        lowpass_freq=5000,
        smooth_size=0,
        common_ref_removal=False,
        lostfront_chunksize=0,
        signalpreprocessor_engine='opencl',
        # peak detector
        peakdetector_engine='opencl',
        peak_sign='-',
        relative_threshold=5,
        peak_span=0.0002)

    t1 = time.perf_counter()
    catalogueconstructor.estimate_signals_noise(seg_num=0, duration=10)
    t2 = time.perf_counter()
    print('estimate_signals_noise', t2 - t1)
    print(catalogueconstructor.signals_medians)
    print(catalogueconstructor.signals_mads)

    t1 = time.perf_counter()
    catalogueconstructor.run_signalprocessor(duration=300)
    t2 = time.perf_counter()
    print('run_signalprocessor', t2 - t1)
    print(catalogueconstructor)
def run_peeler(dirname):
    dataio = DataIO(dirname=dirname)
    initial_catalogue = dataio.load_catalogue(chan_grp=0)

    peeler = Peeler(dataio)
    peeler.change_params(catalogue=initial_catalogue)

    t1 = time.perf_counter()
    peeler.run()
    t2 = time.perf_counter()
    print('peeler.run', t2 - t1)

    print()
    for seg_num in range(dataio.nb_segment):
        spikes = dataio.get_spikes(seg_num)
        print('seg_num', seg_num, 'nb_spikes', spikes.size)
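A hedged sketch of how these helpers could be chained into one run, assuming they live in the same script and that the placeholder directory below already holds a DataIO with a configured data source; the cataloguing steps between preprocessing and peeling (waveform extraction, clustering, catalogue export) still have to be run in between, so the ordering here is an assumption, not something the scripts state.

if __name__ == '__main__':
    dirname = 'tridesclous_example'   # placeholder output directory
    preprocess_signals_and_peaks(dirname)
    # ... waveform extraction, clustering and make_catalogue_for_peeler go here ...
    run_peeler(dirname)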
def test_dataio_with_neo():
    if os.path.exists('datatest_neo/data.h5'):
        os.remove('datatest_neo/data.h5')
    dataio = DataIO(dirname='datatest_neo')

    import neo
    import quantities as pq
    filenames = ['Tem06c06.IOT', 'Tem06c07.IOT', 'Tem06c08.IOT']
    for filename in filenames:
        blocks = neo.RawBinarySignalIO(filename).read(sampling_rate=10.*pq.kHz,
                                                      t_start=0.*pq.S, unit=pq.V,
                                                      nbchannel=16, bytesoffset=0,
                                                      dtype='int16',
                                                      rangemin=-10, rangemax=10)
        channel_indexes = np.arange(14)
        dataio.append_signals_from_neo(blocks, channel_indexes=channel_indexes,
                                       signal_type='unfiltered')

    print(dataio.summary(level=1))
def test_extract_peak_waveforms():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=2)
    peak_index = sigs.index[peak_pos]

    waveforms = extract_peak_waveforms(sigs, peak_pos, peak_index, -30, 50)
    print(waveforms.shape)
    fig, ax = pyplot.subplots()
    waveforms.median(axis=0).plot(ax=ax)

    normed_waveforms = extract_peak_waveforms(peakdetector.normed_sigs, peak_pos, peak_index, -15, 50)
    fig, ax = pyplot.subplots()
    normed_waveforms.median(axis=0).plot(ax=ax)
def open_cataloguewindow(dirname):
    dataio = DataIO(dirname=dirname)
    catalogueconstructor = CatalogueConstructor(dataio=dataio)

    app = pg.mkQApp()
    win = CatalogueWindow(catalogueconstructor)
    win.show()
    app.exec_()
def test_waveform_extractor():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    peakdetector = PeakDetector(sigs)
    peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=5)

    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold=1.1)
    print(limit_left, limit_right)

    long_wf = waveformextractor.long_waveforms
    short_wf = waveformextractor.get_ajusted_waveforms()
    assert long_wf.shape[1] > short_wf.shape[1]

    waveformextractor.plot_good_limit()
def test_peeler():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    # peak
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=5)

    # waveforms
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold=1.1)
    #~ print(limit_left, limit_right)
    short_wf = waveformextractor.get_ajusted_waveforms(margin=2)
    #~ print(short_wf.shape)

    # clustering
    clustering = Clustering(short_wf)
    features = clustering.project(method='pca', n_components=5)
    clustering.find_clusters(7)
    catalogue = clustering.construct_catalogue()
    clustering.plot_catalogue()

    # peeler
    signals = peakdetector.normed_sigs
    peeler = Peeler(signals, catalogue, limit_left, limit_right,
                    threshold=-4, peak_sign='-', n_span=5)
    prediction0, residuals0 = peeler.peel()
    prediction1, residuals1 = peeler.peel()

    fig, axs = pyplot.subplots(nrows=6, sharex=True)  #, sharey=True)
    axs[0].plot(signals)
    axs[1].plot(prediction0)
    axs[2].plot(residuals0)
    axs[3].plot(prediction1)
    axs[4].plot(residuals1)

    colors = sns.color_palette('husl', len(catalogue))
    spiketrains = peeler.get_spiketrains()
    i = 0
    for k, pos in spiketrains.items():
        axs[5].plot(pos, np.ones(pos.size) * k, ls='None', marker='|',
                    markeredgecolor=colors[i], markersize=10, markeredgewidth=2)
        i += 1
    axs[5].set_ylim(0, len(catalogue))
def test_dataio():
    if os.path.exists('datatest/data.h5'):
        os.remove('datatest/data.h5')
    dataio = DataIO(dirname='datatest')
    #~ print(data)

    # data from locust
    sigs_by_trials, sampling_rate, ch_names = download_locust(trial_names=['trial_01', 'trial_02', 'trial_03'])
    for seg_num in range(3):
        sigs = sigs_by_trials[seg_num]
        dataio.append_signals_from_numpy(sigs, seg_num=seg_num, t_start=0. + 5 * seg_num,
                                         sampling_rate=sampling_rate, signal_type='filtered',
                                         channels=ch_names)

    #~ print(data)
    #~ print(data.segments)
    #~ print(data.store)
    print(dataio.summary(level=0))
    print(dataio.summary(level=1))
def plot_interpolation():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    # peak
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=5)

    # waveforms
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold=1.1)
    #~ print(limit_left, limit_right)
    short_wf = waveformextractor.get_ajusted_waveforms(margin=2)
    #~ print(short_wf.shape)

    # clustering
    clustering = Clustering(short_wf)
    features = clustering.project(method='pca', n_components=5)
    clustering.find_clusters(7)
    catalogue = clustering.construct_catalogue()

    k = list(catalogue.keys())[1]
    w0 = catalogue[k]['center']
    w1 = catalogue[k]['centerD']
    w2 = catalogue[k]['centerDD']

    fig, ax = pyplot.subplots()
    t = np.arange(w0.size)
    colors = sns.color_palette('husl', 12)

    all_pred = []
    jitters = np.arange(-.5, .5, .1)
    for i, jitter in enumerate(jitters):
        pred = w0 + jitter * w1 + jitter**2 / 2. * w2
        all_pred.append(pred)
        ax.plot(t + jitter, pred, marker='o', label=str(jitter),
                color=colors[i], linestyle='None')
    ax.plot(t, w0, marker='*', markersize=4, label='w0', lw=1, color='k')

    all_pred = np.array(all_pred)
    interpolated = all_pred.transpose().flatten()
    t2 = np.arange(interpolated.size) / all_pred.shape[0] + jitters[0]
    ax.plot(t2, interpolated, label='interp', lw=1, color='m')
def test_dataio_catalogue():
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
    dataio = DataIO(dirname='test_DataIO')

    catalogue = {}
    catalogue['chan_grp'] = 0
    catalogue['centers0'] = np.ones((300, 12, 50))
    catalogue['n_left'] = -15
    catalogue['params_signalpreprocessor'] = {'highpass_freq': 300.}

    dataio.save_catalogue(catalogue, name='test')

    c2 = dataio.load_catalogue(name='test', chan_grp=0)
    print(c2)
    assert c2['n_left'] == -15
    assert np.all(c2['centers0'] == 1)
    assert c2['params_signalpreprocessor']['highpass_freq'] == 300.
def test_detect_peak_method_span():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)
    normed_sigs = normalize_signals(sigs)
    rectified_sigs = rectify_signals(normed_sigs, threshold=-4)

    peaks_pos = detect_peak_method_span(rectified_sigs, peak_sign='-', n_span=5)
    peaks_index = sigs.index[peaks_pos]

    fig, ax = pyplot.subplots()
    chunk = rectified_sigs[3.14:3.22]
    chunk.plot(ax=ax)
    peaks_value = rectified_sigs.loc[peaks_index]
    peaks_value[3.14:3.22].plot(marker='o', linestyle='None', ax=ax, color='k')
    ax.set_ylim(-20, 10)

    fig, ax = pyplot.subplots()
    chunk = normed_sigs[3.14:3.22]
    chunk.plot(ax=ax)
    peaks_value = normed_sigs.loc[peaks_index]
    peaks_value[3.14:3.22].plot(marker='o', linestyle='None', ax=ax, color='k')
    ax.set_ylim(-20, 10)
def test_find_good_limits():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=5)
    peak_index = sigs.index[peak_pos]

    normed_waveforms = extract_peak_waveforms(peakdetector.normed_sigs, peak_pos, peak_index, -25, 50)
    normed_med = normed_waveforms.median(axis=0)
    normed_mad = np.median(np.abs(normed_waveforms - normed_med), axis=0) * 1.4826
    normed_mad = normed_mad.reshape(4, -1)

    fig, ax = pyplot.subplots()
    ax.plot(normed_mad.transpose())
    ax.axhline(1.1)

    l1, l2 = find_good_limits(normed_mad)
    print(l1, l2)
    ax.axvline(l1)
    ax.axvline(l2)
def make_pca_collision_figure():
    dataio = DataIO(dirname=dirname)
    cc = CatalogueConstructor(dataio=dataio)
    clusters = cc.clusters

    #~ plot_features_scatter_2d(cc, labels=None, nb_max=500)
    #~ plot_features_scatter_2d

    fig, ax = plt.subplots()
    ax.set_title('Collision problem')
    ax.set_aspect('equal')

    features = cc.some_features
    labels = cc.all_peaks[cc.some_peaks_index]['cluster_label']
    for k in [0, 1, 2, 3]:
        color = clusters[clusters['cluster_label'] == k]['color'][0]
        color = int32_to_rgba(color, mode='float')

        keep = labels == k
        feat = features[keep]
        print(np.unique(labels))
        ax.plot(feat[:, 0], feat[:, 1], ls='None', marker='o',
                color=color, markersize=3, alpha=.5)

    ax.set_xlim(-40, 40)
    ax.set_ylim(-40, 40)
    ax.set_xlabel('pca0')
    ax.set_ylabel('pca1')
    ax.annotate('Collision', xy=(17.6, -16.4), xytext=(30, -30),
                arrowprops=dict(facecolor='black', shrink=0.05))

    #~ fig.savefig('../img/collision_proble_pca.png')
def clean_and_save_catalogue(dirname):
    dataio = DataIO(dirname=dirname)
    catalogueconstructor = CatalogueConstructor(dataio=dataio)

    catalogueconstructor.trash_small_cluster(n=5)

    # order cluster by waveforms rms
    catalogueconstructor.order_clusters(by='waveforms_rms')

    # put label 0 to trash
    mask = catalogueconstructor.all_peaks['cluster_label'] == 0
    catalogueconstructor.all_peaks['cluster_label'][mask] = -1
    catalogueconstructor.on_new_cluster()

    # save the catalogue
    catalogueconstructor.make_catalogue_for_peeler()
def extract_waveforms_pca_cluster(dirname):
    dataio = DataIO(dirname=dirname)
    catalogueconstructor = CatalogueConstructor(dataio=dataio)
    print(catalogueconstructor)

    t1 = time.perf_counter()
    # ~ catalogueconstructor.extract_some_waveforms(n_left=-35, n_right=150, nb_max=10000, align_waveform=True, subsample_ratio=20)
    catalogueconstructor.extract_some_waveforms(n_left=-20, n_right=30, nb_max=20000, align_waveform=False)
    t2 = time.perf_counter()
    print('extract_some_waveforms', t2 - t1)
    # ~ print(catalogueconstructor.some_waveforms.shape)
    print(catalogueconstructor)

    # ~ t1 = time.perf_counter()
    # ~ n_left, n_right = catalogueconstructor.find_good_limits(mad_threshold=1.1,)
    # ~ t2 = time.perf_counter()
    # ~ print('n_left', n_left, 'n_right', n_right)
    # ~ print(catalogueconstructor.some_waveforms.shape)
    print(catalogueconstructor)
    # ~ print(catalogueconstructor.all_peaks)
    # ~ exit()

    t1 = time.perf_counter()
    catalogueconstructor.clean_waveforms(alien_value_threshold=100.)
    t2 = time.perf_counter()
    print('clean_waveforms', t2 - t1)

    # extract_some_noise
    t1 = time.perf_counter()
    catalogueconstructor.extract_some_noise(nb_snippet=300)
    t2 = time.perf_counter()
    print('extract_some_noise', t2 - t1)

    t1 = time.perf_counter()
    catalogueconstructor.project(method='pca_by_channel', n_components_by_channel=3)
    # ~ catalogueconstructor.project(method='tsne', n_components=2, perplexity=40., init='pca')
    t2 = time.perf_counter()
    print('project', t2 - t1)
    print(catalogueconstructor)

    t1 = time.perf_counter()
    catalogueconstructor.find_clusters(method='gmm', n_clusters=3*192)
    t2 = time.perf_counter()
    print('find_clusters', t2 - t1)
    print(catalogueconstructor)
def make_catalogue():
    if os.path.exists(dirname):
        shutil.rmtree(dirname)

    dataio = DataIO(dirname=dirname)
    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)
    dataio.add_one_channel_group(channels=channels)

    cc = CatalogueConstructor(dataio=dataio)
    params = {
        'duration': 300.,
        'preprocessor': {
            'highpass_freq': 300.,
            'chunksize': 1024,
            'lostfront_chunksize': 100,
        },
        'peak_detector': {
            'peak_sign': '-',
            'relative_threshold': 7.,
            'peak_span': 0.0005,
            #~ 'peak_span': 0.000,
        },
        'extract_waveforms': {
            'n_left': -25,
            'n_right': 40,
            'nb_max': 10000,
        },
        'clean_waveforms': {
            'alien_value_threshold': 60.,
        },
        'noise_snippet': {
            'nb_snippet': 300,
        },
        'feature_method': 'global_pca',
        'feature_kargs': {'n_components': 20},
        'cluster_method': 'kmeans',
        'cluster_kargs': {'n_clusters': 5},
        'clean_cluster': False,
        'clean_cluster_kargs': {},
    }

    apply_all_catalogue_steps(cc, params, verbose=True)

    cc.order_clusters(by='waveforms_rms')
    cc.move_cluster_to_trash(4)
    cc.make_catalogue_for_peeler()
def preprocess_data(subject, recording_date, data_files):
    output_dir = os.path.join(cfg['single_unit_spike_sorting_dir'], subject, recording_date, 'preprocess')
    if os.path.exists(output_dir):
        # remove if it already exists
        shutil.rmtree(output_dir)

    ## Setup DataIO
    dataio = DataIO(dirname=output_dir)
    dataio.set_data_source(type='Intan', filenames=[x['fname'] for x in data_files])

    # Setup channel groups
    arrays_recorded = []
    grp_idx = 0
    for array_idx in range(len(cfg['arrays'])):
        first_chan = ''
        if array_idx == 0:
            first_chan = 'A-000'
        elif array_idx == 1:
            first_chan = 'A-032'
        elif array_idx == 2:
            first_chan = 'B-000'
        elif array_idx == 3:
            first_chan = 'B-032'
        elif array_idx == 4:
            first_chan = 'C-000'
        elif array_idx == 5:
            first_chan = 'C-032'

        found = False
        for i in range(len(dataio.datasource.sig_channels)):
            if dataio.datasource.sig_channels[i][0] == first_chan:
                found = True
                break

        chan_range = []
        if found:
            chan_range = range(grp_idx * cfg['n_channels_per_array'],
                               (grp_idx + 1) * cfg['n_channels_per_array'])
            grp_idx = grp_idx + 1
            arrays_recorded.append(array_idx)
        dataio.add_one_channel_group(channels=chan_range, chan_grp=array_idx)

    print(dataio)

    total_duration = np.sum([x['duration'] for x in data_files])
    for array_idx in arrays_recorded:
        print(array_idx)
        preprocess_array(array_idx, output_dir, total_duration)
def getSortedTimes(dirName, chanGroup):
    dataio = DataIO(dirname=dirName)
    dataio.load_catalogue(chan_grp=chanGroup)
    catalogueconstructor = CatalogueConstructor(dataio=dataio)
    sample_rate = dataio.sample_rate

    unitTimes = np.empty(dataio.nb_segment, dtype=object)
    for j in range(dataio.nb_segment):
        idd = {}
        times = {}
        try:
            # List of all cluster labels
            cluster_ids = np.array([i for i in catalogueconstructor.cluster_labels])
            # List of all detected peaks by cluster ID
            clusters = np.array([i[1] for i in dataio.get_spikes(j)])
            spike_times = np.array([i[0] for i in dataio.get_spikes(j)])
        except Exception:
            cluster_ids = np.array([])
            clusters = np.array([])
            spike_times = np.array([])

        for i in cluster_ids:
            idd[i] = np.argwhere(clusters == i)
        for i in cluster_ids:
            times[i] = spike_times[idd[i]] / sample_rate

        mx = np.max([times[i].size for i in times.keys()])
        for i in times.keys():
            times[i].resize(mx + 1, 1)

        timesArray = np.array([times[i] for i in times.keys()])
        timesArray = np.roll(timesArray, 1)
        timesArray[:, 0, :] = np.array(list(times.keys())).reshape(timesArray.shape[0], 1)
        timesArray = np.transpose(timesArray)
        unitTimes[j] = timesArray[0]

    return unitTimes
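A brief usage sketch for getSortedTimes, assuming a sorted output directory exists; the directory name and channel group are placeholders for illustration.

if __name__ == '__main__':
    # Placeholder path and channel group.
    unit_times = getSortedTimes('sorted_output', 0)
    # unit_times[seg] is a 2-D array with one column per cluster: the first row
    # holds the cluster labels and the remaining rows hold spike times in seconds.
    print(unit_times[0].shape)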
def export_spikes(dirname, array_idx, chan_grp):
    print('Exporting ch %d' % chan_grp)
    data = {
        'array': [],
        'electrode': [],
        'cell': [],
        'segment': [],
        'time': []
    }
    array = cfg['arrays'][array_idx]

    dataio = DataIO(dirname=dirname, ch_grp=chan_grp)
    catalogue = dataio.load_catalogue(chan_grp=chan_grp)
    dataio._open_processed_data(ch_grp=chan_grp)

    clusters = catalogue['clusters']
    for seg_num in range(dataio.nb_segment):
        spikes = dataio.get_spikes(seg_num=seg_num, chan_grp=chan_grp)

        spike_labels = spikes['cluster_label'].copy()
        for l in clusters:
            mask = spike_labels == l['cluster_label']
            spike_labels[mask] = l['cell_label']

        spike_indexes = spikes['index']
        for (index, label) in zip(spike_indexes, spike_labels):
            if label >= 0:
                data['array'].append(array)
                data['electrode'].append(chan_grp)
                data['cell'].append(label)
                data['segment'].append(seg_num)
                data['time'].append(index)

        dataio.flush_processed_signals(seg_num=seg_num, chan_grp=chan_grp)

    df = pd.DataFrame(data, columns=['array', 'electrode', 'cell', 'segment', 'time'])
    df.to_csv(os.path.join(dirname, '%s_%d_spikes.csv' % (array, chan_grp)), index=False)
def initialize_catalogueconstructor(dirname, filenames):
    # create a DataIO
    if os.path.exists(dirname):
        # remove if it already exists
        shutil.rmtree(dirname)
    dataio = DataIO(dirname=dirname)

    # The dataset contains 192 channels: we use them all
    #dataio.set_channel_groups({'channels':{'channels':[0, 1, 2, 3]}})

    # feed DataIO
    dataio.set_data_source(type='Intan', filenames=filenames, channel_indexes=list(range(192)))
    #dataio.set_probe_file('/home/bonaiuto/Projects/tool_learning/recordings/rhd2000/betta/default.prb')
    dataio.add_one_channel_group(channels=range(192), chan_grp=0)
    print(dataio)
def generate_spike_sorting_report(subject, recording_date):
    data_dir = os.path.join(cfg['single_unit_spike_sorting_dir'], subject, recording_date)
    if os.path.exists(data_dir):
        channel_results = []

        for array_idx in range(len(cfg['arrays'])):
            array = cfg['arrays'][array_idx]
            print(array)
            array_data_dir = os.path.join(data_dir, 'array_%d' % array_idx)
            if os.path.exists(array_data_dir):
                export_path = os.path.join(array_data_dir, 'figures')
                if not os.path.exists(export_path):
                    os.makedirs(export_path)

                for chan_grp in range(cfg['n_channels_per_array']):
                    print(chan_grp)
                    dataio = DataIO(array_data_dir, ch_grp=chan_grp)
                    dataio.datasource.bit_to_microVolt = 0.195
                    catalogueconstructor = CatalogueConstructor(dataio=dataio, chan_grp=chan_grp)
                    catalogueconstructor.refresh_colors()
                    catalogue = dataio.load_catalogue(chan_grp=chan_grp)

                    channel_result = {
                        'array': array,
                        'channel': chan_grp,
                        'init_waveforms': '',
                        'clean_waveforms': '',
                        'noise': '',
                        'init_clusters': '',
                        'merge_clusters': [],
                        'final_clusters': [],
                        'all_clusters': ''
                    }

                    clusters = catalogue['clusters']
                    cluster_labels = clusters['cluster_label']
                    cell_labels = clusters['cell_label']

                    channel_result['init_waveforms'] = os.path.join(
                        'array_%d' % array_idx, 'figures', 'chan_%d_init_waveforms.png' % chan_grp)
                    channel_result['clean_waveforms'] = os.path.join(
                        'array_%d' % array_idx, 'figures', 'chan_%d_clean_waveforms.png' % chan_grp)
                    channel_result['noise'] = os.path.join(
                        'array_%d' % array_idx, 'figures', 'chan_%d_noise.png' % chan_grp)
                    channel_result['init_clusters'] = os.path.join(
                        'array_%d' % array_idx, 'figures', 'chan_%d_init_clusters.png' % chan_grp)

                    merge_files = glob.glob(os.path.join(export_path, 'chan_%d_merge_*.png' % chan_grp))
                    for merge_file in merge_files:
                        [path, file] = os.path.split(merge_file)
                        channel_result['merge_clusters'].append(
                            os.path.join('array_%d' % array_idx, 'figures', file))

                    for cluster_label in cluster_labels:
                        fig = plot_cluster_summary(dataio, catalogue, chan_grp, cluster_label)
                        fname = 'chan_%d_cluster_%d.png' % (chan_grp, cluster_label)
                        fig.savefig(os.path.join(export_path, fname))
                        fig.clf()
                        plt.close()
                        channel_result['final_clusters'].append(
                            os.path.join('array_%d' % array_idx, 'figures', fname))

                    fig = plot_clusters_summary(dataio, catalogueconstructor, chan_grp)
                    fname = 'chan_%d_clusters.png' % chan_grp
                    fig.savefig(os.path.join(export_path, fname))
                    fig.clf()
                    plt.close()
                    channel_result['all_clusters'] = os.path.join(
                        'array_%d' % array_idx, 'figures', fname)

                    channel_results.append(channel_result)

        env = Environment(loader=FileSystemLoader(cfg['template_dir']))
        template = env.get_template('spike_sorting_results_template.html')
        template_output = template.render(subject=subject, recording_date=recording_date,
                                          channel_results=channel_results)
        out_filename = os.path.join(data_dir, 'spike_sorting_report.html')
        with open(out_filename, 'w') as fh:
            fh.write(template_output)

        copyfile(os.path.join(cfg['template_dir'], 'style.css'),
                 os.path.join(data_dir, 'style.css'))
def test_derivative_signals():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)
    deriv_sigs = derivative_signals(sigs)
    deriv_sigs[3.14:3.22].plot()
def test_normalize_signals():
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)
    normed_sigs = normalize_signals(sigs)
    normed_sigs[3.14:3.22].plot()
def test_DataIO():
    # initialize dataio
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)

    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)

    # with geometry
    channels = list(range(14))
    channel_groups = {0: {'channels': range(14),
                          'geometry': {c: [0, i] for i, c in enumerate(channels)}}}
    dataio.set_channel_groups(channel_groups)

    # with no geometry
    channel_groups = {0: {'channels': range(4)}}
    dataio.set_channel_groups(channel_groups)

    # add one group
    dataio.add_one_channel_group(channels=range(4, 8), chan_grp=5)

    channel_groups = {0: {'channels': range(14)}}
    dataio.set_channel_groups(channel_groups)

    for seg_num in range(dataio.nb_segment):
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num, chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
            #~ print(seg_num, i_stop, sigs_chunk.shape)

    # reopen existing
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)
    #~ exit()

    for seg_num in range(dataio.nb_segment):
        #~ print('seg_num', seg_num)
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num, chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
def test_DataIO_probes():
    # initialize dataio
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)

    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)

    probe_filename = 'neuronexus/A4x8-5mm-100-400-413-A32.prb'
    dataio.download_probe(probe_filename)
    dataio.download_probe('neuronexus/A4x8-5mm-100-400-413-A32')

    #~ print(dataio.channel_groups)
    #~ print(dataio.channels)
    #~ print(dataio.info['probe_filename'])

    assert dataio.nb_channel(0) == 8
    assert probe_filename.split('/')[-1] == dataio.info['probe_filename']

    dataio = DataIO(dirname='test_DataIO')
    print(dataio)
def make_animation():
    """
    Good example between 1.272 and 1.302 because of a collision.
    """
    dataio = DataIO(dirname=dirname)
    catalogue = dataio.load_catalogue(chan_grp=0)
    clusters = catalogue['clusters']
    sr = dataio.sample_rate

    # also a good one at 11.356 - 11.366
    t1, t2 = 1.272, 1.295
    i1, i2 = int(t1 * sr), int(t2 * sr)

    spikes = dataio.get_spikes()
    spike_times = spikes['index'] / sr
    keep = (spike_times >= t1) & (spike_times <= t2)
    spikes = spikes[keep]
    print(spikes)

    sigs = dataio.get_signals_chunk(i_start=i1, i_stop=i2, signal_type='processed')
    sigs = sigs.copy()
    times = np.arange(sigs.shape[0]) / dataio.sample_rate

    def plot_spread_sigs(sigs, ax, ratioY=0.02, **kargs):
        # spread signals
        sigs2 = sigs * ratioY
        sigs2 += np.arange(0, len(channels))[np.newaxis, :]
        ax.plot(times, sigs2, **kargs)
        ax.set_ylim(-0.5, len(channels) - .5)
        ax.set_xticks([])
        ax.set_yticks([])

    residuals = sigs.copy()

    local_spikes = spikes.copy()
    local_spikes['index'] -= i1

    #~ fig, ax = plt.subplots()
    #~ plot_spread_sigs(sigs, ax, color='k')

    num_fig = 0

    fig_pred, ax_predictions = plt.subplots()
    ax_predictions.set_title('All detected templates from catalogue')

    fig, ax = plt.subplots()
    plot_spread_sigs(residuals, ax, color='k', lw=2)
    ax.set_title('Initial filtered signals with spikes')
    fig.savefig('../img/peeler_animation_sigs.png')
    fig.savefig('png/fig{}.png'.format(num_fig))
    num_fig += 1

    for i in range(local_spikes.size):
        label = local_spikes['cluster_label'][i]
        color = clusters[clusters['cluster_label'] == label]['color'][0]
        color = int32_to_rgba(color, mode='float')

        pred = make_prediction_signals(local_spikes[i:i + 1], 'float32',
                                       (i2 - i1, len(channels)), catalogue)

        fig, ax = plt.subplots()
        plot_spread_sigs(residuals, ax, color='k', lw=2)
        plot_spread_sigs(pred, ax, color=color, lw=1.5)
        ax.set_title('Detected spike label {}'.format(label))
        fig.savefig('png/fig{}.png'.format(num_fig))
        num_fig += 1

        residuals -= pred

        plot_spread_sigs(pred, ax_predictions, color=color, lw=1.5)

        fig, ax = plt.subplots()
        plot_spread_sigs(residuals, ax, color='k', lw=2)
        plot_spread_sigs(pred, ax, color=color, lw=1, ls='--')
        ax.set_title('New residual after subtraction')
        fig.savefig('png/fig{}.png'.format(num_fig))
        num_fig += 1

    fig_pred.savefig('png/fig{}.png'.format(num_fig))
    num_fig += 1
def export_spikes(dirname):
    dataio = DataIO(dirname=dirname)
    dataio.export_spikes(dirname, formats='csv')
    dataio.export_spikes(dirname, formats='mat')
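A plausible ordering for the animation-example helpers above, assuming they share the module-level dirname and channels globals; the sequence (build the catalogue, run the Peeler, then render the figures and export spikes) is an assumption about how the script is meant to be driven, not something stated in the code itself.

if __name__ == '__main__':
    make_catalogue()           # build and save the catalogue for the Peeler
    apply_peeler()             # run template matching on the processed signals
    make_catalogue_figure()    # template figure
    make_animation()           # frame-by-frame peeling figures
    export_spikes(dirname)     # dump spikes to csv and mat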
def compute_array_catalogue(array_idx, preprocess_dir, subject, recording_date, data_files,
                            cluster_merge_threshold):
    # If data exists for this array
    if os.path.exists(os.path.join(preprocess_dir, 'channel_group_%d' % array_idx, 'catalogue_constructor')):
        output_dir = os.path.join(cfg['single_unit_spike_sorting_dir'], subject, recording_date,
                                  'array_%d' % array_idx)
        if os.path.exists(output_dir):
            # remove if it already exists
            shutil.rmtree(output_dir)

        # Compute total duration (want to use all data for clustering)
        data_file_names = []
        for seg in range(len(data_files)):
            data_file_names.append(os.path.join(preprocess_dir, 'channel_group_%d' % array_idx,
                                                'segment_%d' % seg, 'processed_signals.raw'))

        dataio = DataIO(dirname=output_dir)
        dataio.set_data_source(type='RawData', filenames=data_file_names, dtype='float32',
                               sample_rate=cfg['intan_srate'],
                               total_channel=cfg['n_channels_per_array'])
        dataio.datasource.bit_to_microVolt = 0.195
        for ch_grp in range(cfg['n_channels_per_array']):
            dataio.add_one_channel_group(channels=[ch_grp], chan_grp=ch_grp)

        total_duration = np.sum([x['duration'] for x in data_files])

        figure_out_dir = os.path.join(output_dir, 'figures')
        os.mkdir(figure_out_dir)

        for ch_grp in range(cfg['n_channels_per_array']):
            print(ch_grp)
            cc = CatalogueConstructor(dataio=DataIO(dirname=output_dir, ch_grp=ch_grp), chan_grp=ch_grp)

            fullchain_kargs = {
                'duration': total_duration,
                'preprocessor': {
                    'highpass_freq': None,
                    'lowpass_freq': None,
                    'smooth_size': 0,
                    'common_ref_removal': False,
                    'chunksize': 32768,
                    'lostfront_chunksize': 0,
                    'signalpreprocessor_engine': 'numpy',
                },
                'peak_detector': {
                    'peakdetector_engine': 'numpy',
                    'peak_sign': '-',
                    'relative_threshold': 2.,
                    'peak_span': 0.0002,
                },
                'noise_snippet': {
                    'nb_snippet': 300,
                },
                'extract_waveforms': {
                    'n_left': -20,
                    'n_right': 30,
                    'mode': 'all',
                    'nb_max': 2000000,
                    'align_waveform': False,
                },
                'clean_waveforms': {
                    'alien_value_threshold': 100.,
                },
            }
            feat_method = 'pca_by_channel'
            feat_kargs = {'n_components_by_channel': 5}
            clust_method = 'sawchaincut'
            clust_kargs = {
                'max_loop': 1000,
                'nb_min': 20,
                'break_nb_remain': 30,
                'kde_bandwith': 0.01,
                'auto_merge_threshold': 2.,
                'print_debug': False
                # 'max_loop': 1000,
                # 'nb_min': 20,
                # 'break_nb_remain': 30,
                # 'kde_bandwith': 0.01,
                # 'auto_merge_threshold': cluster_merge_threshold,
                # 'print_debug': False
            }

            p = {}
            p.update(fullchain_kargs['preprocessor'])
            p.update(fullchain_kargs['peak_detector'])
            cc.set_preprocessor_params(**p)

            noise_duration = min(10., fullchain_kargs['duration'],
                                 dataio.get_segment_length(seg_num=0) / dataio.sample_rate * .99)
            # ~ print('noise_duration', noise_duration)

            t1 = time.perf_counter()
            cc.estimate_signals_noise(seg_num=0, duration=noise_duration)
            t2 = time.perf_counter()
            print('estimate_signals_noise', t2 - t1)

            t1 = time.perf_counter()
            cc.run_signalprocessor(duration=fullchain_kargs['duration'])
            t2 = time.perf_counter()
            print('run_signalprocessor', t2 - t1)

            t1 = time.perf_counter()
            cc.extract_some_waveforms(**fullchain_kargs['extract_waveforms'])
            t2 = time.perf_counter()
            print('extract_some_waveforms', t2 - t1)

            fname = 'chan_%d_init_waveforms.png' % ch_grp
            fig = plot_waveforms(np.squeeze(cc.some_waveforms).T)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            t1 = time.perf_counter()
            # ~ duration = d['duration'] if d['limit_duration'] else None
            # ~ d['clean_waveforms']
            cc.clean_waveforms(**fullchain_kargs['clean_waveforms'])
            t2 = time.perf_counter()
            print('clean_waveforms', t2 - t1)

            fname = 'chan_%d_clean_waveforms.png' % ch_grp
            fig = plot_waveforms(np.squeeze(cc.some_waveforms).T)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            # ~ t1 = time.perf_counter()
            # ~ n_left, n_right = cc.find_good_limits(mad_threshold=1.1,)
            # ~ t2 = time.perf_counter()
            # ~ print('find_good_limits', t2-t1)

            t1 = time.perf_counter()
            cc.extract_some_noise(**fullchain_kargs['noise_snippet'])
            t2 = time.perf_counter()
            print('extract_some_noise', t2 - t1)

            # Plot noise
            fname = 'chan_%d_noise.png' % ch_grp
            fig = plot_noise(cc)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            t1 = time.perf_counter()
            cc.extract_some_features(method=feat_method, **feat_kargs)
            t2 = time.perf_counter()
            print('project', t2 - t1)

            t1 = time.perf_counter()
            cc.find_clusters(method=clust_method, **clust_kargs)
            t2 = time.perf_counter()
            print('find_clusters', t2 - t1)

            # Remove empty clusters
            cc.trash_small_cluster(n=0)

            if cc.centroids_median is None:
                cc.compute_all_centroid()

            # order cluster by waveforms rms
            cc.order_clusters(by='waveforms_rms')

            fname = 'chan_%d_init_clusters.png' % ch_grp
            cluster_labels = cc.clusters['cluster_label']
            fig = plot_cluster_waveforms(cc, cluster_labels)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            # save the catalogue
            cc.make_catalogue_for_peeler()

            gc.collect()
def run_compare_catalogues(subject, date, similarity_threshold=0.7):
    bin_min = 0
    bin_max = 100
    bin_size = 1.
    bins = np.arange(bin_min, bin_max, bin_size)

    if os.path.exists(os.path.join(cfg['single_unit_spike_sorting_dir'], subject, date)):
        # Get previous days to compare to
        sorted_dates, sorted_files = read_previous_sorts(subject)
        new_date = datetime.strptime(date, '%d.%m.%y')

        # Results for report
        channel_results = []

        for array_idx in range(len(cfg['arrays'])):
            new_output_dir = os.path.join(cfg['single_unit_spike_sorting_dir'], subject, date,
                                          'array_%d' % array_idx)
            if os.path.exists(new_output_dir):
                # Create directory for plots
                plot_output_dir = os.path.join(new_output_dir, 'figures', 'catalogue_comparison')
                if not os.path.exists(plot_output_dir):
                    os.mkdir(plot_output_dir)

                for ch_grp in range(cfg['n_channels_per_array']):
                    channel_result = {
                        'array': cfg['arrays'][array_idx],
                        'channel': ch_grp,
                        'merged': [],
                        'unmerged': [],
                        'final': ''
                    }

                    # load catalogue for this channel
                    new_dataio = DataIO(dirname=new_output_dir, ch_grp=ch_grp)
                    catalogueconstructor = CatalogueConstructor(dataio=new_dataio, chan_grp=ch_grp)

                    # Waveform time range
                    time_range = range(
                        catalogueconstructor.info['waveform_extractor_params']['n_left'],
                        catalogueconstructor.info['waveform_extractor_params']['n_right'])

                    # refresh
                    if catalogueconstructor.centroids_median is None:
                        catalogueconstructor.compute_all_centroid()
                    catalogueconstructor.refresh_colors()

                    # cell labels and cluster waveforms for this day
                    nn_idx = np.where(catalogueconstructor.clusters['cell_label'] > -1)[0]

                    # If there are any clusters for this day
                    if len(nn_idx) > 0:
                        new_cluster_labels, new_cell_labels, new_wfs, new_wfs_stds, new_isis = get_cluster_info(
                            catalogueconstructor, new_dataio, ch_grp, bins, nn_idx)

                        # Load cell labels and waveforms for all previous days
                        all_old_cluster_labels = []
                        all_old_cell_labels = []
                        all_old_wfs = np.zeros((0, 50))
                        all_old_wfs_stds = np.zeros((0, 50))
                        all_old_isis = np.zeros((0, len(bins) - 1))

                        recent_sorted_files = copy.copy(sorted_files)
                        recent_sorted_dates = copy.copy(sorted_dates)
                        if len(recent_sorted_files) > 20:
                            recent_sorted_files = recent_sorted_files[-20:]
                            recent_sorted_dates = recent_sorted_dates[-20:]

                        for old_date, old_file in zip(recent_sorted_dates, recent_sorted_files):
                            if old_date < new_date:
                                print('loading %s' % old_date)
                                old_output_dir = os.path.join(cfg['single_unit_spike_sorting_dir'],
                                                              subject, old_file,
                                                              'array_%d' % array_idx)
                                if os.path.exists(old_output_dir):
                                    old_dataio = DataIO(dirname=old_output_dir, ch_grp=ch_grp,
                                                        reload_data_source=False)
                                    old_catalogueconstructor = CatalogueConstructor(
                                        dataio=old_dataio, chan_grp=ch_grp,
                                        load_persistent_arrays=False)
                                    old_catalogueconstructor.arrays.load_if_exists('clusters')
                                    old_catalogueconstructor.arrays.load_if_exists('centroids_median')
                                    old_catalogueconstructor.arrays.load_if_exists('centroids_std')
                                    if old_catalogueconstructor.centroids_median is None:
                                        old_catalogueconstructor.compute_all_centroid()

                                    old_cell_labels = old_catalogueconstructor.clusters['cell_label']
                                    to_include = np.where(
                                        np.bitwise_and(
                                            np.isin(old_cell_labels, all_old_cell_labels) == False,
                                            old_cell_labels > -1))[0]
                                    if len(to_include):
                                        old_cluster_labels, old_cell_labels, old_wfs, old_wfs_stds, old_isis = get_cluster_info(
                                            old_catalogueconstructor, old_dataio, ch_grp, bins, to_include)

                                        all_old_wfs = np.concatenate((all_old_wfs, copy.deepcopy(old_wfs)))
                                        all_old_wfs_stds = np.concatenate((all_old_wfs_stds, copy.deepcopy(old_wfs_stds)))
                                        all_old_cell_labels.extend(copy.deepcopy(old_cell_labels))
                                        all_old_cluster_labels.extend(copy.deepcopy(old_cluster_labels))
                                        all_old_isis = np.concatenate((all_old_isis, copy.deepcopy(old_isis)))

                                    #old_dataio.arrays.detach_array('clusters', mmap_close=True)
                                    #old_dataio.arrays.detach_array('centroids_median', mmap_close=True)
                                    #old_dataio.arrays.detach_array('centroids_std', mmap_close=True)

                        if len(all_old_cell_labels):
                            # Compute cluster similarity
                            wfs = np.concatenate((new_wfs, all_old_wfs))
                            cluster_similarity = metrics.cosine_similarity_with_max(wfs)
                            new_old_cluster_similarity = cluster_similarity[0:new_wfs.shape[0], new_wfs.shape[0]:]

                            # Plot cluster similarity
                            fname = 'chan_%d_similarity.png' % ch_grp
                            plot_new_old_cluster_similarity(all_old_cluster_labels, all_old_cell_labels,
                                                            new_cluster_labels, new_cell_labels,
                                                            new_old_cluster_similarity,
                                                            plot_output_dir, fname)
                            channel_result['similarity'] = os.path.join(
                                'array_%d' % array_idx, 'figures', 'catalogue_comparison', fname)

                            # Go through each cluster in current day
                            for new_cluster_idx in range(new_wfs.shape[0]):
                                # Find most similar cluster from previous days
                                most_similar = np.argmax(new_old_cluster_similarity[new_cluster_idx, :])

                                # Merge if similarity greater than threshold
                                similarity = new_old_cluster_similarity[new_cluster_idx, most_similar]
                                if similarity >= similarity_threshold:
                                    print('relabeling unit %d-%d as unit %d-%d' %
                                          (ch_grp, new_cell_labels[new_cluster_idx],
                                           ch_grp, all_old_cell_labels[most_similar]))

                                    fname = 'chan_%d_merge_%d-%d.png' % (ch_grp,
                                                                         new_cell_labels[new_cluster_idx],
                                                                         all_old_cell_labels[most_similar])
                                    plot_cluster_merge(all_old_cell_labels[most_similar],
                                                       all_old_cluster_labels[most_similar],
                                                       all_old_isis[most_similar, :],
                                                       all_old_wfs[most_similar, :],
                                                       all_old_wfs_stds[most_similar, :],
                                                       new_cell_labels[new_cluster_idx],
                                                       new_cluster_labels[new_cluster_idx],
                                                       new_isis[new_cluster_idx, :],
                                                       new_wfs[new_cluster_idx, :],
                                                       new_wfs_stds[new_cluster_idx, :],
                                                       similarity, time_range, bins,
                                                       plot_output_dir, fname)
                                    channel_result['merged'].append(
                                        os.path.join('array_%d' % array_idx, 'figures',
                                                     'catalogue_comparison', fname))

                                    new_cell_labels[new_cluster_idx] = all_old_cell_labels[most_similar]

                                # Otherwise, add new cluster
                                else:
                                    new_label = np.max(all_old_cell_labels) + 1
                                    print('adding new unit %d-%d' % (ch_grp, new_label))
                                    all_old_cell_labels.append(new_label)
                                    new_cell_labels[new_cluster_idx] = new_label

                                    fname = 'chan_%d_nonmerge_%d.png' % (ch_grp, new_cell_labels[new_cluster_idx])
                                    plot_cluster_new(all_old_cluster_labels, all_old_cell_labels,
                                                     all_old_isis, all_old_wfs, all_old_wfs_stds,
                                                     new_cluster_labels[new_cluster_idx],
                                                     new_cell_labels[new_cluster_idx],
                                                     new_isis[new_cluster_idx, :],
                                                     new_wfs[new_cluster_idx, :],
                                                     new_wfs_stds[new_cluster_idx, :],
                                                     new_old_cluster_similarity[new_cluster_idx:new_cluster_idx + 1, :],
                                                     time_range, bins, plot_output_dir, fname)
                                    channel_result['unmerged'].append(
                                        os.path.join('array_%d' % array_idx, 'figures',
                                                     'catalogue_comparison', fname))

                            catalogueconstructor.clusters['cell_label'][nn_idx] = new_cell_labels

                        # Merge clusters with same labels
                        cluster_removed = True
                        while cluster_removed:
                            for cell_label in catalogueconstructor.clusters['cell_label']:
                                cluster_idx = np.where(
                                    catalogueconstructor.clusters['cell_label'] == cell_label)[0]
                                if len(cluster_idx) > 1:
                                    cluster_label1 = catalogueconstructor.cluster_labels[cluster_idx[0]]
                                    cluster_label2 = catalogueconstructor.cluster_labels[cluster_idx[1]]
                                    print('auto_merge', cluster_label2, 'with', cluster_label1)
                                    mask = catalogueconstructor.all_peaks['cluster_label'] == cluster_label2
                                    catalogueconstructor.all_peaks['cluster_label'][mask] = cluster_label1
                                    catalogueconstructor.remove_one_cluster(cluster_label2)
                                    break
                            else:
                                cluster_removed = False

                        catalogueconstructor.make_catalogue_for_peeler()

                        fig = plot_clusters_summary(new_dataio, catalogueconstructor, ch_grp)
                        fname = 'chan_%d_final_clusters.png' % ch_grp
                        fig.savefig(os.path.join(plot_output_dir, fname))
                        fig.clf()
                        plt.close()
                        channel_result['final'] = os.path.join(
                            'array_%d' % array_idx, 'figures', 'catalogue_comparison', fname)

                    channel_results.append(channel_result)

        env = Environment(loader=FileSystemLoader(cfg['template_dir']))
        template = env.get_template('spike_sorting_merge_template.html')
        template_output = template.render(subject=subject, recording_date=date,
                                          channel_results=channel_results)
        out_filename = os.path.join(cfg['single_unit_spike_sorting_dir'], subject, date,
                                    'spike_sorting_merge_report.html')
        with open(out_filename, 'w') as fh:
            fh.write(template_output)

        copyfile(os.path.join(cfg['template_dir'], 'style.css'),
                 os.path.join(cfg['single_unit_spike_sorting_dir'], subject, date, 'style.css'))