def create_gui(self, name=None, subtitle=None,
               config_dir=None, add_default_views=True, **kwargs):
    """Build the manual clustering GUI and return it.

    `name` and `config_dir` fall back to `self.gui_name` and
    `self.config_dir` when falsy. Extra keyword arguments are forwarded
    to `GUI`. The correlogram view is always added when
    `add_default_views` is true; the feature, waveform and trace views
    are only added when the matching `all_*` attribute is not None.
    A `create_gui` event is emitted with the GUI before returning.
    """
    resolved_config_dir = config_dir or self.config_dir
    window = GUI(
        name=name or self.gui_name,
        subtitle=subtitle,
        config_dir=resolved_config_dir,
        **kwargs
    )
    window.controller = self

    # The ManualClustering component must be attached before the views.
    self.manual_clustering.attach(window)

    if add_default_views:
        self.add_correlogram_view(window)
        # (data attribute, method adding the view that needs that data)
        optional_views = (
            ('all_features', 'add_feature_view'),
            ('all_waveforms', 'add_waveform_view'),
            ('all_traces', 'add_trace_view'),
        )
        for data_attr, adder_name in optional_views:
            if getattr(self, data_attr) is not None:
                getattr(self, adder_name)(window)

    self.emit('create_gui', window)
    return window
def create_gui(self, name=None, subtitle=None,
               config_dir=None, add_default_views=True, **kwargs):
    """Build the manual clustering GUI and return it.

    `name` and `config_dir` fall back to `self.gui_name` and
    `self.config_dir` when falsy. Extra keyword arguments are forwarded
    to `GUI`. The correlogram view is always added when
    `add_default_views` is true; the feature, waveform and trace views
    are only added when the matching `all_*` attribute is not None.
    A `create_gui` event is emitted with the GUI before returning.
    """
    resolved_config_dir = config_dir or self.config_dir
    window = GUI(
        name=name or self.gui_name,
        subtitle=subtitle,
        config_dir=resolved_config_dir,
        **kwargs
    )
    window.controller = self

    # The ManualClustering component must be attached before the views.
    self.manual_clustering.attach(window)

    if add_default_views:
        self.add_correlogram_view(window)
        # (data attribute, method adding the view that needs that data)
        optional_views = (
            ('all_features', 'add_feature_view'),
            ('all_waveforms', 'add_waveform_view'),
            ('all_traces', 'add_trace_view'),
        )
        for data_attr, adder_name in optional_views:
            if getattr(self, data_attr) is not None:
                getattr(self, adder_name)(window)

    self.emit('create_gui', window)
    return window
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI configured in `tempdir`, then close it."""
    window = GUI(
        position=(200, 100),
        size=(500, 500),
        config_dir=tempdir,
    )
    window.show()
    qtbot.waitForWindowShown(window)

    yield window

    # Teardown: short waits surround the close call.
    qtbot.wait(5)
    window.close()
    del window
    qtbot.wait(5)
def create_trace_gui(obj, **kwargs):
    """Create the Trace GUI.

    Parameters
    ----------

    obj : str or Path
        Path to the raw data file. A path ending in `.py` is read as a
        params file: its first `dat_path` entry becomes the data file and
        its remaining entries are used as keyword arguments.
    sample_rate : float
        The data sampling rate, in Hz.
    n_channels_dat : int
        The number of columns in the raw data file.
    dtype : str
        The NumPy data type of the raw binary file.
    offset : int
        The header offset in bytes.

    Returns
    -------

    gui : GUI
        The GUI with a single TraceView attached.

    """
    # Imported locally so that this function does not depend on the
    # module-level imports of the file.
    from pathlib import Path

    gui_name = 'TraceGUI'

    # Support passing a params.py file.
    if str(obj).endswith('.py'):
        params = get_template_params(str(obj))
        return create_trace_gui(next(iter(params.pop('dat_path'))), **params)

    # Only these keyword arguments are forwarded to the reader.
    kwargs = {
        k: v for k, v in kwargs.items()
        if k in ('sample_rate', 'n_channels_dat', 'dtype', 'offset')
    }
    traces = get_ephys_reader(obj, **kwargs)

    create_app()
    # `obj` may be a plain string, which has no `.resolve()` method, so it
    # is converted to a Path before resolving.
    gui = GUI(
        name=gui_name, subtitle=Path(obj).resolve(), enable_threading=False)
    gui.set_default_actions()

    def _get_traces(interval):
        return Bunch(data=select_traces(
            traces, interval, sample_rate=traces.sample_rate))

    # TODO: load channel information

    view = TraceView(
        traces=_get_traces,
        n_channels=traces.n_channels,
        sample_rate=traces.sample_rate,
        duration=traces.duration,
        enable_threading=False,
    )
    view.attach(gui)

    return gui
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI configured in `tempdir`, then close it.

    `_supervisor._show_box` is replaced by an identity function before the
    GUI is created (mock patch of the show box `exec_`).
    """
    # NOTE: mock patch show box exec_
    _supervisor._show_box = lambda _: _

    window = GUI(
        position=(200, 100),
        size=(500, 500),
        config_dir=tempdir,
    )
    window.show()
    qtbot.waitForWindowShown(window)

    yield window

    # Teardown: short waits surround the close call.
    qtbot.wait(5)
    window.close()
    del window
    qtbot.wait(5)
def create_gui(self):
    """Create the spike sorting GUI.

    The GUI is named after `self.gui_name`, uses the resolved
    `self.dat_path` as subtitle, has threading disabled and no save
    action. Actions, the params widget and the IPython, trace and probe
    views are then created on it, in that order.
    """
    window = GUI(
        name=self.gui_name,
        subtitle=self.dat_path.resolve(),
        enable_threading=False,
    )
    # Must be set before the default actions are created.
    window.has_save_action = False
    window.set_default_actions()

    build_steps = (
        'create_actions',
        'create_params_widget',
        'create_ipython_view',
        'create_trace_view',
        'create_probe_view',
    )
    for step_name in build_steps:
        getattr(self, step_name)(window)

    return window
def create_gui(self, **kwargs):
    """Create the GUI, attach the supervisor and add the views.

    Keyword arguments are forwarded to `GUI`. The waveform, correlogram
    and probe views are always added; the trace, feature, template
    feature and amplitude views are added only when the corresponding
    model attribute is not None. A `gui_ready` event is emitted with the
    GUI before it is returned.
    """
    window = GUI(
        name=self.gui_name,
        subtitle=self.model.dat_path,
        config_dir=self.config_dir,
        **kwargs
    )
    self.supervisor.attach(window)

    model = self.model

    self.add_waveform_view(window)
    if model.traces is not None:
        self.add_trace_view(window)
    if model.features is not None:
        self.add_feature_view(window)
    if model.template_features is not None:
        self.add_template_feature_view(window)
    self.add_correlogram_view(window)
    if model.amplitudes is not None:
        self.add_amplitude_view(window)
    self.add_probe_view(window)

    # Save the memcache when closing the GUI.
    @window.connect_
    def on_close():
        self.context.save_memcache()

    self.emit('gui_ready', window)

    return window
def gui(tempdir, qtbot):
    """Yield a shown 800x600 GUI with default actions, then close it."""
    window = GUI(
        position=(200, 200),
        size=(800, 600),
        config_dir=tempdir,
    )
    window.set_default_actions()
    window.show()
    # A short wait is used here; the window is not registered with qtbot.
    qtbot.wait(1)

    yield window

    qtbot.wait(1)
    window.close()
    qtbot.wait(1)
def test_scatter_view(qtbot, tempdir):
    """Attach a ScatterView fed with random points and select clusters."""
    n = 1000

    def _random_coords(c):
        # The argument is ignored: fresh random points on every call.
        return Bunch(
            x=np.random.randn(n),
            y=np.random.randn(n),
            data_bounds=None,
        )

    view = ScatterView(coords=_random_coords)

    window = GUI(config_dir=tempdir)
    window.show()
    view.attach(window)
    qtbot.addWidget(window)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)

    # qtbot.stop()
    window.close()
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI with default actions, then close it.

    `_supervisor._show_box` is replaced by an identity function before the
    GUI is created (mock patch of the show box `exec_`).
    """
    # NOTE: mock patch show box exec_
    _supervisor._show_box = lambda _: _

    window = GUI(
        position=(200, 100),
        size=(500, 500),
        config_dir=tempdir,
    )
    window.set_default_actions()
    window.show()
    qtbot.waitForWindowShown(window)

    yield window

    # Teardown: short waits surround the close call.
    qtbot.wait(5)
    window.close()
    del window
    qtbot.wait(5)
def test_probe_view(qtbot, tempdir):
    """Attach a ProbeView on 50 staggered channels and select clusters."""
    n = 50

    def best_channels(cluster_id):
        # Same channels for every cluster: 1, 3, 5, 7.
        return range(1, 9, 2)

    view = ProbeView(
        positions=staggered_positions(n),
        best_channels=best_channels,
    )

    window = GUI(config_dir=tempdir)
    window.show()
    view.attach(window)
    qtbot.addWidget(window)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)

    # qtbot.stop()
    window.close()
def test_plot_mpl_1(qtbot):
    """Show an empty matplotlib canvas attached to a GUI, then close it.

    Setting the `PHY_TEST_STOP` environment variable to a non-empty value
    stops qtbot while the canvas is shown.
    """
    window = GUI()
    canvas = PlotCanvasMpl()
    canvas.clear()
    canvas.attach(window)
    canvas.show()
    qtbot.waitForWindowShown(canvas.canvas)

    stop_requested = os.environ.get('PHY_TEST_STOP')
    if stop_requested:
        qtbot.stop()

    canvas.close()
def test_correlogram_view(qtbot, tempdir):
    """Attach a CorrelogramView on artificial data and exercise it."""
    ns = 50

    def get_correlograms(cluster_ids, bin_size, window_size):
        # Bin and window sizes are ignored by the artificial generator.
        return artificial_correlograms(len(cluster_ids), ns)

    view = CorrelogramView(
        correlograms=get_correlograms,
        sample_rate=100.,
    )

    window = GUI(config_dir=tempdir)
    window.show()
    view.attach(window)
    qtbot.addWidget(window)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)

    view.toggle_normalization()
    view.set_bin(1)
    view.set_window(100)

    # qtbot.stop()
    window.close()
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI configured in `tempdir`, then close it.

    `gui_component._show_box` is replaced by an identity function before
    the GUI is created (mock patch of the show box `exec_`).
    """
    # NOTE: mock patch show box exec_
    gui_component._show_box = lambda _: _

    window = GUI(
        position=(200, 100),
        size=(500, 500),
        config_dir=tempdir,
    )
    window.show()
    qtbot.waitForWindowShown(window)

    yield window

    # Teardown: short waits surround the close call.
    qtbot.wait(5)
    window.close()
    del window
    qtbot.wait(5)
def test_scatter_view(qtbot, tempdir):
    """Attach a ScatterView fed with random points and select clusters."""
    n = 1000

    def _random_coords(c):
        # The argument is ignored: fresh random points on every call.
        return Bunch(
            x=np.random.randn(n),
            y=np.random.randn(n),
            data_bounds=None,
        )

    view = ScatterView(coords=_random_coords)

    window = GUI(config_dir=tempdir)
    window.show()
    view.attach(window)
    qtbot.addWidget(window)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)

    # qtbot.stop()
    window.close()
def test_probe_view(qtbot, tempdir):
    """Attach a ProbeView on 50 staggered channels and select clusters."""
    n = 50

    def best_channels(cluster_id):
        # Same channels for every cluster: 1, 3, 5, 7.
        return range(1, 9, 2)

    view = ProbeView(
        positions=staggered_positions(n),
        best_channels=best_channels,
    )

    window = GUI(config_dir=tempdir)
    window.show()
    view.attach(window)
    qtbot.addWidget(window)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)

    # qtbot.stop()
    window.close()
def test_waveform_view(qtbot, tempdir):
    """Exercise selection, toggles, scaling and channel clicks of WaveformView."""
    nc = 5

    def get_waveforms(cluster_id):
        return Bunch(
            data=artificial_waveforms(10, 20, nc),
            channel_ids=np.arange(nc),
            channel_positions=staggered_positions(nc),
        )

    view = WaveformView(waveforms=get_waveforms)

    window = GUI(config_dir=tempdir)
    window.show()
    view.attach(window)
    qtbot.addWidget(window)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)

    # Each toggle is applied twice.
    for _ in range(2):
        view.toggle_waveform_overlap()
    for _ in range(2):
        view.toggle_show_labels()

    # Box scaling: an action followed by its inverse restores the box size.
    size_before = view.boxed.box_size
    view.increase()
    view.decrease()
    ac(view.boxed.box_size, size_before)

    size_before = view.boxed.box_size
    view.widen()
    view.narrow()
    ac(view.boxed.box_size, size_before)

    # Probe scaling: same round trip on the box positions.
    pos_before = view.boxed.box_pos
    view.extend_horizontally()
    view.shrink_horizontally()
    ac(view.boxed.box_pos, pos_before)

    pos_before = view.boxed.box_pos
    view.extend_vertically()
    view.shrink_vertically()
    ac(view.boxed.box_pos, pos_before)

    # The scaling properties can be set directly.
    sx, sy = view.probe_scaling
    view.probe_scaling = (sx, sy * 2)
    ac(view.probe_scaling, (sx, sy * 2))

    sx, sy = view.box_scaling
    view.box_scaling = (sx * 2, sy)
    ac(view.box_scaling, (sx * 2, sy))

    # Simulate channel selection: click while the '2' key is held down.
    clicks = []

    @view.gui.connect_
    def on_channel_click(channel_id=None, button=None, key=None):
        clicks.append((channel_id, button, key))

    view.events.key_press(key=keys.Key('2'))
    view.events.mouse_press(pos=(0., 0.), button=1)
    view.events.key_release(key=keys.Key('2'))

    assert clicks == [(0, 1, 2)]

    # qtbot.stop()
    window.close()
def test_trace_view(tempdir, qtbot):
    """Exercise navigation, scaling and spike clicks of TraceView."""
    nc = 5
    ns = 9
    sr = 1000.
    ch = list(range(nc))
    duration = 1.

    spike_times = np.linspace(0.1, .9, ns)
    spike_clusters = artificial_spike_clusters(ns, nc)
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)

    cs = ColorSelector()
    model = Bunch(
        spike_times=spike_times,
        spike_clusters=spike_clusters,
        sample_rate=sr,
    )
    supervisor = Bunch(cluster_meta={}, selected=[0])

    # The spike waveform iterator must yield at least one item.
    spike_waveforms = _iter_spike_waveforms(
        interval=[0., 1.],
        traces_interval=traces,
        model=model,
        supervisor=supervisor,
        n_samples_waveforms=ns,
        get_best_channels=lambda cluster_id: ch,
        color_selector=cs,
    )
    assert len(list(spike_waveforms))

    def get_traces(interval):
        out = Bunch(
            data=select_traces(traces, interval, sample_rate=sr),
            color=(.75,) * 4,
        )
        first, last = spike_times.searchsorted(interval)
        half_width = 20
        out.waveforms = []
        # One waveform per spike falling in the interval.
        for spike_id in range(first, last):
            t = spike_times[spike_id]
            cluster = spike_clusters[spike_id]
            sample = int(round(t * sr))
            out.waveforms.append(Bunch(
                data=traces[sample - half_width:sample + half_width, :],
                start_time=t - half_width / sr,
                color=cs.get(cluster),
                channel_ids=np.arange(5),
                spike_id=spike_id,
                spike_cluster=cluster,
            ))
        return out

    view = TraceView(
        traces=get_traces,
        n_channels=nc,
        sample_rate=sr,
        duration=duration,
        channel_vertical_order=np.arange(nc)[::-1],
    )

    window = GUI(config_dir=tempdir)
    window.show()
    view.attach(window)
    qtbot.addWidget(window)
    # qtbot.waitForWindowShown(gui)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)

    # ac(v.stacked.box_size, (1., .08181), atol=1e-3)

    # Navigation in time.
    view.set_interval((.375, .625))
    assert view.time == .5

    view.go_to(.25)
    assert view.time == .25

    view.go_to(-.5)
    assert view.time == .125

    view.go_left()
    assert view.time == .125

    view.go_right()
    assert view.time == .175

    # Change interval size.
    view.interval = (.25, .75)
    ac(view.interval, (.25, .75))

    view.widen()
    ac(view.interval, (.125, .875))

    view.narrow()
    ac(view.interval, (.25, .75))

    # Widen the max interval.
    view.set_interval((0, duration))
    view.widen()

    view.toggle_show_labels()
    # v.toggle_show_labels()
    view.go_right()
    assert view.do_show_labels

    # Change channel scaling.
    size_before = view.stacked.box_size
    view.increase()
    view.decrease()
    ac(view.stacked.box_size, size_before, atol=1e-3)

    view.origin = 'upper'
    assert view.origin == 'upper'

    # Simulate spike selection: Control + click.
    clicks = []

    @view.gui.connect_
    def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
        clicks.append((channel_id, spike_id, cluster_id))

    view.events.key_press(key=keys.Key('Control'))
    view.events.mouse_press(
        pos=(400., 200.), button=1, modifiers=(keys.CONTROL,))
    view.events.key_release(key=keys.Key('Control'))

    assert clicks == [(1, 4, 1)]

    # qtbot.stop()
    window.close()
def test_trace_view(tempdir, qtbot):
    """Exercise navigation, scaling and spike clicks of TraceView."""
    nc = 5
    ns = 9
    sr = 1000.
    ch = list(range(nc))
    duration = 1.

    spike_times = np.linspace(0.1, .9, ns)
    spike_clusters = artificial_spike_clusters(ns, nc)
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)

    cs = ColorSelector()
    model = Bunch(
        spike_times=spike_times,
        spike_clusters=spike_clusters,
        sample_rate=sr,
    )
    supervisor = Bunch(cluster_meta={}, selected=[0])

    # The spike waveform iterator must yield at least one item.
    spike_waveforms = _iter_spike_waveforms(
        interval=[0., 1.],
        traces_interval=traces,
        model=model,
        supervisor=supervisor,
        n_samples_waveforms=ns,
        get_best_channels=lambda cluster_id: ch,
        color_selector=cs,
    )
    assert len(list(spike_waveforms))

    def get_traces(interval):
        out = Bunch(
            data=select_traces(traces, interval, sample_rate=sr),
            color=(.75, ) * 4,
        )
        first, last = spike_times.searchsorted(interval)
        half_width = 20
        out.waveforms = []
        # One waveform per spike falling in the interval.
        for spike_id in range(first, last):
            t = spike_times[spike_id]
            cluster = spike_clusters[spike_id]
            sample = int(round(t * sr))
            out.waveforms.append(Bunch(
                data=traces[sample - half_width:sample + half_width, :],
                start_time=t - half_width / sr,
                color=cs.get(cluster),
                channel_ids=np.arange(5),
                spike_id=spike_id,
                spike_cluster=cluster,
            ))
        return out

    view = TraceView(
        traces=get_traces,
        n_channels=nc,
        sample_rate=sr,
        duration=duration,
        channel_vertical_order=np.arange(nc)[::-1],
    )

    window = GUI(config_dir=tempdir)
    window.show()
    view.attach(window)
    qtbot.addWidget(window)
    # qtbot.waitForWindowShown(gui)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)

    # ac(v.stacked.box_size, (1., .08181), atol=1e-3)

    # Navigation in time.
    view.set_interval((.375, .625))
    assert view.time == .5

    view.go_to(.25)
    assert view.time == .25

    view.go_to(-.5)
    assert view.time == .125

    view.go_left()
    assert view.time == .125

    view.go_right()
    assert view.time == .175

    # Change interval size.
    view.interval = (.25, .75)
    ac(view.interval, (.25, .75))

    view.widen()
    ac(view.interval, (.125, .875))

    view.narrow()
    ac(view.interval, (.25, .75))

    # Widen the max interval.
    view.set_interval((0, duration))
    view.widen()

    view.toggle_show_labels()
    # v.toggle_show_labels()
    view.go_right()
    assert view.do_show_labels

    # Change channel scaling.
    size_before = view.stacked.box_size
    view.increase()
    view.decrease()
    ac(view.stacked.box_size, size_before, atol=1e-3)

    view.origin = 'upper'
    assert view.origin == 'upper'

    # Simulate spike selection: Control + click.
    clicks = []

    @view.gui.connect_
    def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
        clicks.append((channel_id, spike_id, cluster_id))

    view.events.key_press(key=keys.Key('Control'))
    view.events.mouse_press(
        pos=(400., 200.), button=1, modifiers=(keys.CONTROL, ))
    view.events.key_release(key=keys.Key('Control'))

    assert clicks == [(1, 4, 1)]

    # qtbot.stop()
    window.close()
def create_trace_gui(dat_path, **kwargs):
    """Create the Trace GUI.

    Parameters
    ----------

    dat_path : str or Path
        Path to the raw data file. A `.py` path is read as a params file:
        its first `dat_path` entry becomes the data file and its remaining
        entries are used as keyword arguments. For a `.cbin` file the
        sample rate and channel count are read from the loaded data and
        the keyword arguments below are not used.
    sample_rate : float
        The data sampling rate, in Hz. Must be strictly positive.
    n_channels_dat : int
        The number of columns in the raw data file.
    dtype : str
        The NumPy data type of the raw binary file.
    offset : int, optional
        Passed to `load_raw_data` as `offset`. Missing or None means 0.
    order : optional
        Passed to `load_raw_data` as `order`. Defaults to None.

    Returns
    -------

    gui : GUI
        The GUI with a single TraceView attached.

    """
    gui_name = 'TraceGUI'

    dat_path = Path(dat_path)

    # Support passing a params.py file.
    if dat_path.suffix == '.py':
        params = get_template_params(str(dat_path))
        return create_trace_gui(next(iter(params.pop('dat_path'))), **params)

    if dat_path.suffix == '.cbin':  # pragma: no cover
        data = load_raw_data(path=dat_path)
        sample_rate = data.sample_rate
        n_channels_dat = data.shape[1]
    else:
        sample_rate = float(kwargs['sample_rate'])
        assert sample_rate > 0.

        n_channels_dat = int(kwargs['n_channels_dat'])
        dtype = np.dtype(kwargs['dtype'])
        # `offset` is optional: an absent key is treated like None
        # instead of raising a KeyError.
        offset = int(kwargs.get('offset') or 0)
        order = kwargs.get('order', None)

        # Memmap the raw data file.
        data = load_raw_data(
            path=dat_path,
            n_channels_dat=n_channels_dat,
            dtype=dtype,
            offset=offset,
            order=order,
        )

    duration = data.shape[0] / sample_rate

    create_app()
    gui = GUI(
        name=gui_name, subtitle=dat_path.resolve(), enable_threading=False)
    gui.set_default_actions()

    def _get_traces(interval):
        return Bunch(
            data=select_traces(data, interval, sample_rate=sample_rate))

    # TODO: load channel information

    view = TraceView(
        traces=_get_traces,
        n_channels=n_channels_dat,
        sample_rate=sample_rate,
        duration=duration,
        enable_threading=False,
    )
    view.attach(gui)

    return gui
def test_feature_view(qtbot, tempdir, n_channels):
    """Exercise selection, channel handling and lasso split of FeatureView."""
    nc = n_channels
    ns = 500
    features = artificial_features(ns, nc, 4)
    spike_clusters = artificial_spike_clusters(ns, 4)
    spike_times = np.linspace(0., 1., ns)
    spc = _spikes_per_cluster(spike_clusters)

    def get_spike_ids(cluster_id):
        # No cluster means all spikes.
        if cluster_id is None:
            return np.arange(ns)
        return spc[cluster_id]

    def get_features(cluster_id=None, channel_ids=None, spike_ids=None,
                     load_all=None):
        # The `spike_ids` argument is always overridden.
        spike_ids = spc[cluster_id] if load_all else get_spike_ids(cluster_id)
        if channel_ids is None:
            channel_ids = np.arange(nc)[::-1]
        return Bunch(
            data=features[spike_ids],
            spike_ids=spike_ids,
            masks=np.random.rand(ns, nc),
            channel_ids=channel_ids,
        )

    def get_time(cluster_id=None, load_all=None):
        return Bunch(
            data=spike_times[get_spike_ids(cluster_id)],
            lim=(0., 1.),
        )

    view = FeatureView(
        features=get_features,
        attributes={'time': get_time},
    )
    view.set_state(GUIState(scaling=None))

    window = GUI(config_dir=tempdir)
    window.show()
    view.attach(window)
    qtbot.addWidget(window)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)
    window.emit('select', [0, 2])
    qtbot.wait(10)

    view.increase()
    view.decrease()

    view.on_channel_click(channel_id=3, button=1, key=2)
    view.clear_channels()
    view.toggle_automatic_channel_selection()

    # Split without selection.
    split_ids = view.on_request_split()
    assert len(split_ids) == 0

    # Draw a lasso with four Control + left clicks.
    lasso_corners = ((10, 10), (10, 100), (100, 100), (100, 10))
    for x, y in lasso_corners:
        qtbot.mouseClick(
            view.native, Qt.LeftButton,
            pos=QPoint(x, y), modifier=Qt.ControlModifier)

    # Split lassoed points.
    split_ids = view.on_request_split()
    assert len(split_ids) > 0

    # qtbot.stop()
    window.close()
def test_feature_view(qtbot, tempdir, n_channels):
    """Exercise selection, channel handling and lasso split of FeatureView."""
    nc = n_channels
    ns = 500
    features = artificial_features(ns, nc, 4)
    spike_clusters = artificial_spike_clusters(ns, 4)
    spike_times = np.linspace(0., 1., ns)
    spc = _spikes_per_cluster(spike_clusters)

    def get_spike_ids(cluster_id):
        # No cluster means all spikes.
        if cluster_id is None:
            return np.arange(ns)
        return spc[cluster_id]

    def get_features(cluster_id=None, channel_ids=None, spike_ids=None,
                     load_all=None):
        # The `spike_ids` argument is always overridden.
        spike_ids = spc[cluster_id] if load_all else get_spike_ids(cluster_id)
        if channel_ids is None:
            channel_ids = np.arange(nc)[::-1]
        return Bunch(
            data=features[spike_ids],
            spike_ids=spike_ids,
            masks=np.random.rand(ns, nc),
            channel_ids=channel_ids,
        )

    def get_time(cluster_id=None, load_all=None):
        return Bunch(
            data=spike_times[get_spike_ids(cluster_id)],
            lim=(0., 1.),
        )

    view = FeatureView(
        features=get_features,
        attributes={'time': get_time},
    )
    view.set_state(GUIState(scaling=None))

    window = GUI(config_dir=tempdir)
    window.show()
    view.attach(window)
    qtbot.addWidget(window)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)
    window.emit('select', [0, 2])
    qtbot.wait(10)

    view.increase()
    view.decrease()

    view.on_channel_click(channel_id=3, button=1, key=2)
    view.clear_channels()
    view.toggle_automatic_channel_selection()

    # Split without selection.
    split_ids = view.on_request_split()
    assert len(split_ids) == 0

    # Draw a lasso with four Control + left clicks.
    lasso_corners = ((10, 10), (10, 100), (100, 100), (100, 10))
    for x, y in lasso_corners:
        qtbot.mouseClick(
            view.native, Qt.LeftButton,
            pos=QPoint(x, y), modifier=Qt.ControlModifier)

    # Split lassoed points.
    split_ids = view.on_request_split()
    assert len(split_ids) > 0

    # qtbot.stop()
    window.close()
@author: chongxi lai """ #%% from MUA import * from Vis import ObjectWidget, view_scatter_3d from phy.plot import View from phy.plot.interact import Grid from hdbscan import HDBSCAN import phy from phy.gui import GUI import seaborn as sns mua = MUA(filename='S:/pcie.bin', nCh=32, fs=25000, numbytes=4) spk = mua.tospk() fet = spk.tofet('pca') #%% gui = GUI(position=(0, 0), size=(600, 400), name='GUI') props = ObjectWidget() gui.add_view(props,position='left', name='params') scatter_view = view_scatter_3d() scatter_view.unfreeze() gui.add_view(scatter_view) spk_view = View('grid') gui.add_view(spk_view)
def MergeRuns(controller=controller, plugin=plugin):
    """Open a second sort (the "slave") and a table for matching its
    clusters to the clusters of the current sort (the "master").

    The user picks the results folder of the slave sort; its `params.py`
    is read and a second `TemplateController` GUI is shown. Mean waveforms
    are computed for both runs, `calc_dists` gives the master-by-slave
    distance matrix, and a "Merge Table" GUI is built with one column per
    master unit and one row per slave cluster. The table state is stored
    on `plugin` (`matchi`, `sortrows`, `m_inds`, `dists`, ...), and the
    'Merge' menu of the table GUI offers actions to change the match,
    merge slave clusters, move low-SNR clusters to noise, add unit labels
    and save the associations.

    Nothing is done if the folder dialog is cancelled.
    """
    path2 = QtGui.QFileDialog.getExistingDirectory(
        None,
        "Select the results folder for the sort to be merged",
        # two folders up from the current phy's path
        op.dirname(op.dirname(controller.model.dir_path)),
        QtGui.QFileDialog.ShowDirsOnly)
    if not path2:
        # The dialog was cancelled: there is no folder to load.
        return

    params_path = op.join(path2, 'params.py')
    params = _read_python(params_path)
    params['dtype'] = np.dtype(params['dtype'])
    params['path'] = path2
    # A dat_path that is not already a real path is taken relative to
    # the results folder.
    if op.realpath(params['dat_path']) != params['dat_path']:
        params['dat_path'] = op.join(path2, params['dat_path'])
    print('Loading {}'.format(path2))
    controller2 = TemplateController(**params)
    gui2 = controller2.create_gui()
    gui2.show()

    # create mean_waveforms for each controller (run)
    print('computing mean waveforms for master run...')
    controller.mean_waveforms = create_mean_waveforms(
        controller, max_waveforms_per_cluster=100)
    print('computing mean waveforms for slave run...')
    controller2.mean_waveforms = create_mean_waveforms(
        controller2, max_waveforms_per_cluster=100)

    groups = {
        c: controller.supervisor.cluster_meta.get('group', c) or 'unsorted'
        for c in controller.supervisor.clustering.cluster_ids
    }
    groups2 = {
        c: controller2.supervisor.cluster_meta.get('group', c) or 'unsorted'
        for c in controller2.supervisor.clustering.cluster_ids
    }

    # Master clusters labelled 'good' (single units) and 'mua'
    # (multi-units), each ordered by best channel.
    su_inds = np.nonzero([
        controller.supervisor.cluster_meta.get('group', c) == 'good'
        for c in controller.supervisor.clustering.cluster_ids
    ])[0]
    mu_inds = np.nonzero([
        controller.supervisor.cluster_meta.get('group', c) == 'mua'
        for c in controller.supervisor.clustering.cluster_ids
    ])[0]
    su_best_channels = np.array([
        controller.get_best_channel(c)
        for c in controller.supervisor.clustering.cluster_ids[su_inds]
    ])
    mu_best_channels = np.array([
        controller.get_best_channel(c)
        for c in controller.supervisor.clustering.cluster_ids[mu_inds]
    ])
    su_order = np.argsort(su_best_channels, kind='mergesort')
    mu_order = np.argsort(mu_best_channels, kind='mergesort')
    m_inds = np.concatenate((su_inds[su_order], mu_inds[mu_order]))

    filename = op.join(controller.model.dir_path, 'cluster_names.ts')
    if not op.exists(filename):
        best_channels = np.concatenate(
            (su_best_channels[su_order], mu_best_channels[mu_order]))
        # Unit type 1 for single units, 2 for multi-units.
        unit_type = np.concatenate(
            (np.ones(len(su_order)), 2 * np.ones(len(mu_order))))
        # Units sharing a channel are numbered 1, 2, ... in order.
        unit_number = np.zeros(len(best_channels))
        for chan in np.unique(best_channels):
            matched_clusts = best_channels == chan
            unit_number[matched_clusts] = np.arange(
                sum(matched_clusts)) + 1
    else:
        print('{} exists, loading'.format(filename))
        unit_types, channels, unit_numbers = load_metadata(
            filename, controller.supervisor.clustering.cluster_ids)
        best_channels = channels[m_inds]
        unit_number = unit_numbers[m_inds]
        unit_type = unit_types[m_inds]
        unit_type_current = np.concatenate(
            (np.ones(len(su_order)), 2 * np.ones(len(mu_order))))
        if ~np.all(unit_type == unit_type_current):
            raise RuntimeError(
                'For the master phy, the unit types saved in "cluster_names.ts"'
                'do not match those save in "cluster_groups.tsv" This likely means work was done on '
                'this phy after merging with a previous master. Not sure how to deal with this!'
            )
        # re-sort to make unit numbers in order
        # assuming unit_type is already sorted (which it should be...)
        nsu = np.sum(unit_type == 1)
        su_order = np.lexsort(
            (unit_number[:nsu], best_channels[:nsu]))
        mu_order = np.lexsort(
            (unit_number[nsu:], best_channels[nsu:]))
        m_inds[:nsu] = m_inds[su_order]
        m_inds[nsu:] = m_inds[mu_order + nsu]
        best_channels = channels[m_inds]
        unit_number = unit_numbers[m_inds]
        unit_type = unit_types[m_inds]

    dists = calc_dists(controller, controller2, m_inds)
    so = np.argsort(dists, 0, kind='mergesort')
    matchi = so[0, :]  # best match index to master for each slave
    sortrows = np.argsort(
        matchi, kind='mergesort')  # sort index for best match

    def handle_item_clicked(item, controller=controller,
                            controller2=controller2, plugin=plugin):
        """Select the clicked master/slave clusters in both GUIs."""
        row = np.array(
            [cell.row() for cell in table.selectedIndexes()])
        column = np.array(
            [cell.column() for cell in table.selectedIndexes()])
        print("Row {} and Column {} was clicked".format(
            row, column))
        print("M {} S {} ".format(
            controller.supervisor.clustering.cluster_ids[
                plugin.m_inds[column]],
            controller2.supervisor.clustering.cluster_ids[
                plugin.sortrows[row]]))
        # Columns with m_inds of -1 (new label) or -2 (noise) have no
        # master cluster to select.
        column = column[~np.in1d(plugin.m_inds[column], (-1, -2))]
        if len(column) == 0:
            # make a deselect function and call it here if feeling fancy
            pass
        else:
            controller.supervisor.select(
                controller.supervisor.clustering.cluster_ids[
                    plugin.m_inds[column]].tolist())
        controller2.supervisor.select(
            controller2.supervisor.clustering.cluster_ids[
                plugin.sortrows[row]].tolist())

    def create_table(controller, controller2, plugin):
        """Fill `plugin.table` from the state stored on `plugin`."""
        plugin.table.setRowCount(len(plugin.matchi))
        plugin.table.setColumnCount(len(plugin.m_inds))
        # set data
        dists_txt = np.round(plugin.dists / plugin.dists.max() * 100)
        normal = plt.Normalize(
            plugin.dists[plugin.dists != -1].min() - 1,
            plugin.dists.max() + 1)
        colors = plt.cm.viridis_r(normal(plugin.dists)) * 255
        for col in range(len(plugin.m_inds)):
            for row in range(len(plugin.matchi)):
                slave = plugin.sortrows[row]
                if plugin.dists[col, slave] < 0:
                    # Negative distances mark cells with no distance.
                    item = QtGui.QTableWidgetItem('N/A')
                    item.setBackground(QtGui.QColor(127, 127, 127))
                else:
                    item = QtGui.QTableWidgetItem('{:.0f}'.format(
                        dists_txt[col, slave]))
                    item.setBackground(
                        QtGui.QColor(
                            colors[col, slave, 0],
                            colors[col, slave, 1],
                            colors[col, slave, 2]))
                # The current match of each slave is written in red.
                if plugin.matchi[slave] == col:
                    item.setForeground(QtGui.QColor(255, 0, 0))
                plugin.table.setItem(row, col, item)

        for col in range(plugin.dists.shape[0]):
            if plugin.m_inds[col] == -1:
                cluster_num = 'None'
            elif plugin.m_inds[col] == -2:
                cluster_num = 'Noise'
            else:
                cluster_num = controller.supervisor.clustering.cluster_ids[
                    plugin.m_inds[col]]
            plugin.table.setHorizontalHeaderItem(
                col,
                QtGui.QTableWidgetItem('{}\n{:.0f}-{:.0f}'.format(
                    cluster_num, plugin.best_channels[col],
                    plugin.unit_number[col])))

        for col in range(plugin.dists.shape[1]):
            c_id = controller2.supervisor.clustering.cluster_ids[
                plugin.sortrows[col]]
            snr = controller2.supervisor.cluster_view._columns[
                'snr']['func'](c_id)
            plugin.table.setVerticalHeaderItem(
                col,
                QtGui.QTableWidgetItem('{:.0f}-{:.1f}'.format(
                    c_id, snr)))
        plugin.table.setEditTriggers(
            QtGui.QAbstractItemView.NoEditTriggers)
        plugin.table.resizeColumnsToContents()
        # NOTE(review): this runs on every rebuild of the table; Qt may
        # add a duplicate connection each time -- confirm.
        plugin.table.itemClicked.connect(handle_item_clicked)

    tablegui = GUI(position=(400, 200), size=(400, 300))
    table = QtGui.QTableWidget()
    table.setWindowTitle("Merge Table")
    plugin.matchi = matchi
    plugin.sortrows = sortrows
    plugin.best_channels = best_channels
    plugin.unit_number = unit_number
    plugin.unit_type = unit_type
    plugin.m_inds = m_inds
    # need to keep a reference otherwise the gui is deleted by garbage
    # collection, leading to a segfault!
    plugin.tablegui = tablegui
    plugin.table = table
    plugin.dists = dists

    create_table(controller, controller2, plugin)
    tablegui.add_view(table)

    actions = Actions(tablegui, name='Merge', menu='Merge')

    @actions.add(menu='Merge', name='Set master for selected slave',
                 shortcut='enter')
    def set_master(plugin=plugin, controller=controller,
                   controller2=controller2):
        """Make the selected cell the master match of its slave row."""
        row = np.array([
            cell.row() for cell in plugin.table.selectedIndexes()
        ])
        column = np.array([
            cell.column() for cell in plugin.table.selectedIndexes()
        ])
        if len(row) == 1:
            print("Row {} and Column {} is selected".format(
                row, column))
            # Reset the previous match first, so that re-selecting the
            # current master leaves it highlighted.
            plugin.table.item(
                row[0],
                plugin.matchi[plugin.sortrows[row[0]]]
            ).setForeground(QtGui.QColor(0, 0, 0))
            plugin.table.item(row[0], column[0]).setForeground(
                QtGui.QColor(255, 0, 0))
            plugin.matchi[plugin.sortrows[row[0]]] = column[0]
            plugin.table.show()
            plugin.tablegui.show()
        else:
            st = 'Only one cell can be selected when setting master'
            print(st)
            plugin.tablegui.status_message = st

    @actions.add(menu='Merge', name='Merge selected slaves', shortcut='m')
    def merge_slaves_by_selection(plugin=plugin, controller=controller,
                                  controller2=controller2):
        """Merge the slave rows selected within a single master column."""
        row = np.array([
            cell.row() for cell in plugin.table.selectedIndexes()
        ])
        column = np.array([
            cell.column() for cell in plugin.table.selectedIndexes()
        ])
        row = np.unique(row)
        column = np.unique(column)
        if len(column) == 1 and len(row) > 1:
            merge_slaves(plugin, controller, controller2, row)
            create_table(controller, controller2, plugin)
            plugin.tablegui.show()
        else:
            if len(column) > 1:
                st = 'Only one master can be selected when merging slaves'
            elif len(row) == 1:
                st = 'At least two slaves must be selected to merge slaves'
            else:
                st = 'Unknown slave merging problem'
            print(st)
            plugin.tablegui.status_message = st

    def merge_slaves_by_array(plugin, controller, controller2,
                              merge_matchis):
        """Merge, for each given master index, all slaves matched to it."""
        for merge_matchi in merge_matchis:
            row = np.where(
                plugin.matchi[plugin.sortrows] == merge_matchi)[0]
            merge_slaves(plugin, controller, controller2, row)
        create_table(controller, controller2, plugin)
        plugin.tablegui.show()

    def merge_slaves(plugin, controller, controller2, row):
        """Merge the slave clusters of table rows `row` and update the
        match, mean waveform and distance arrays accordingly."""
        controller2.supervisor.merge(
            controller2.supervisor.clustering.cluster_ids[
                plugin.sortrows[row]].tolist())
        # The merged cluster inherits the match of the first merged row
        # and is appended at the end of every per-slave array.
        assign_matchi = plugin.matchi[plugin.sortrows[row[0]]]
        plugin.matchi = np.delete(plugin.matchi,
                                  plugin.sortrows[row],
                                  axis=0)
        plugin.matchi = np.append(plugin.matchi, assign_matchi)
        controller2.mean_waveforms = np.delete(
            controller2.mean_waveforms, plugin.sortrows[row], axis=2)
        new_mean_waveforms = create_mean_waveforms(
            controller2,
            max_waveforms_per_cluster=100,
            cluster_ids=controller2.supervisor.clustering.cluster_ids[-1])
        controller2.mean_waveforms = np.append(
            controller2.mean_waveforms, new_mean_waveforms, axis=2)
        plugin.dists = np.delete(plugin.dists,
                                 plugin.sortrows[row],
                                 axis=1)
        plugin.dists = np.append(plugin.dists,
                                 calc_dists(
                                     controller,
                                     controller2,
                                     plugin.m_inds,
                                     s_inds=plugin.dists.shape[1]),
                                 axis=1)
        plugin.sortrows = np.argsort(plugin.matchi, kind='mergesort')

    @actions.add(menu='Merge', name='Move low-snr clusters to noise',
                 shortcut='n')
    def move_low_snr_to_noise(plugin=plugin, controller=controller,
                              controller2=controller2):
        """Match every slave cluster with SNR below 0.5 to a new 'Noise'
        column appended to the table."""
        cluster_ids = controller2.supervisor.clustering.cluster_ids
        snr_func = controller2.supervisor.cluster_view._columns[
            'snr']['func']
        snrs = np.zeros(cluster_ids.shape)
        for i in range(len(cluster_ids)):
            snrs[i] = snr_func(cluster_ids[i])
        thresh = 0.5  # for snr (0.2 was used for amplitude)
        noise_clusts = cluster_ids[snrs < thresh]
        n_ind = []
        for clu in noise_clusts:
            this_ind = np.where(
                controller2.supervisor.clustering.cluster_ids[
                    plugin.sortrows] == clu)[0][0]
            n_ind.append(this_ind)
        # The noise column goes after all existing master columns.
        ind = plugin.m_inds.shape[0]
        plugin.matchi[plugin.sortrows[n_ind]] = ind
        plugin.m_inds = np.insert(plugin.m_inds, ind, -2)
        plugin.best_channels = np.insert(plugin.best_channels, ind,
                                         999)
        plugin.unit_number = np.insert(plugin.unit_number, ind, 0)
        plugin.unit_type = np.insert(plugin.unit_type, ind, 3)
        plugin.dists = np.insert(plugin.dists, ind, -1, axis=0)
        plugin.sortrows = np.argsort(plugin.matchi, kind='mergesort')
        create_table(controller, controller2, plugin)
        tablegui.show()
        st = 'Cluster ids {} moved to noise'.format(noise_clusts)
        print(st)
        plugin.tablegui.status_message = st

    @actions.add(menu='Merge', name='Add new unit label', shortcut='a')
    def add_unit(plugin=plugin, controller=controller,
                 controller2=controller2):
        """Ask for a channel, unit number and unit type and insert a new
        master column that has no master cluster (m_inds == -1)."""
        chan, ok = QtGui.QInputDialog.getText(
            None, 'Adding new unit label:', ' Channel:')
        if not ok:
            return
        try:
            chan = int(chan)
        except (TypeError, ValueError):
            plugin.tablegui.status_message = 'Error inputting channel'
            return
        nums = plugin.unit_number[plugin.best_channels == int(chan)]
        if len(nums) == 0:
            next_unit_num = 1
        else:
            next_unit_num = int(nums.max()) + 1
        dlg = QtGui.QInputDialog(None)
        dlg.setInputMode(QtGui.QInputDialog.TextInput)
        dlg.setTextValue('{}'.format(next_unit_num))
        dlg.setLabelText("Unit number:")
        dlg.resize(300, 300)
        b1 = QtGui.QRadioButton("SU", dlg)
        b2 = QtGui.QRadioButton("MU", dlg)
        b1.move(100, 0)
        b2.move(150, 0)
        b1.setChecked(True)
        ok = dlg.exec_()
        unit_number = dlg.textValue()
        if not ok:
            return
        try:
            unit_number = int(unit_number)
        except (TypeError, ValueError):
            plugin.tablegui.status_message = 'Error inputting unit number'
            return
        if b1.isChecked():
            unit_type = 1
        elif b2.isChecked():
            unit_type = 2
        else:
            # status_message is assigned everywhere else in this
            # function; it is not callable.
            plugin.tablegui.status_message = (
                'Error getting unit type, must have checked either SU or MU'
            )
            return
        # Insert after the last unit of the same type whose channel is
        # not above the requested one, or after the last unit of that
        # type if there is none.
        below_inds = np.logical_and(
            plugin.unit_type == unit_type,
            plugin.best_channels <= chan).nonzero()[0]
        if below_inds.shape[0] == 0:
            below_inds = plugin.unit_type == unit_type
            below_inds = below_inds.nonzero()[0]
        ind = below_inds[-1] + 1
        plugin.m_inds = np.insert(plugin.m_inds, ind, -1)
        # Existing matches at or after the new column shift by one.
        plugin.matchi[plugin.matchi >= ind] += 1
        plugin.best_channels = np.insert(plugin.best_channels, ind,
                                         chan)
        plugin.unit_number = np.insert(plugin.unit_number, ind,
                                       unit_number)
        plugin.unit_type = np.insert(plugin.unit_type, ind, unit_type)
        plugin.dists = np.insert(plugin.dists, ind, -1, axis=0)
        create_table(controller, controller2, plugin)
        tablegui.show()

    @actions.add(menu='Merge', name='Save cluster associations',
                 alias='sca')
    def save_cluster_associations(plugin=plugin, controller=controller,
                                  controller2=controller2):
        """Label the slave clusters after their master match, save both
        sorts and write the `cluster_names.ts` files."""
        un_matchi, counts = np.unique(plugin.matchi,
                                      return_index=False,
                                      return_inverse=False,
                                      return_counts=True)
        # Columns of unit type 3 (noise) may hold several slaves.
        rmi = np.where(plugin.unit_type[un_matchi] == 3)[0]
        if len(rmi) > 0:
            un_matchi = np.delete(un_matchi, rmi)
            counts = np.delete(counts, rmi)
        if np.any(counts > 1):
            msgBox = QtGui.QMessageBox()
            msgBox.setText(
                'There are {} master clusters that are about '
                'to be assigned multiple slave clusters. If this slave will '
                'be used as a master for an addional merge, in most cases '
                'slave clusters that share the same master match should be '
                'merged.'.format(np.sum(counts > 1)))
            msgBox.setInformativeText(
                'Do you want to automatically do these merges?')
            msgBox.setStandardButtons(QtGui.QMessageBox.Yes
                                      | QtGui.QMessageBox.No
                                      | QtGui.QMessageBox.Cancel)
            msgBox.setDefaultButton(QtGui.QMessageBox.Yes)
            ret = msgBox.exec_()
            if ret == QtGui.QMessageBox.Yes:
                merge_slaves_by_array(plugin, controller, controller2,
                                      un_matchi[counts > 1])
                msgBox = QtGui.QMessageBox()
                msgBox.setText(
                    'Merges done. Not saving yet! Check to see that everything is okay and then save.'
                )
                msgBox.exec_()
                return
            elif ret == QtGui.QMessageBox.No:
                pass
            else:
                return

        # assign labels to slave phy's clusters based on master phy's labels
        slave_unit_type = plugin.unit_type[plugin.matchi[plugin.sortrows]]
        good_clusts = controller2.supervisor.clustering.cluster_ids[
            plugin.sortrows[slave_unit_type == 1]].tolist()
        controller2.supervisor.move('good', good_clusts)
        mua_clusts = controller2.supervisor.clustering.cluster_ids[
            plugin.sortrows[slave_unit_type == 2]].tolist()
        controller2.supervisor.move('mua', mua_clusts)
        noise_clusts = controller2.supervisor.clustering.cluster_ids[
            plugin.sortrows[slave_unit_type == 3]].tolist()
        controller2.supervisor.move('noise', noise_clusts)

        # save both master and slave
        controller.supervisor.save()
        controller2.supervisor.save()

        # save associations
        # Columns without a master cluster (-1, -2) are not written to
        # the master file.
        has_master = ~np.in1d(plugin.m_inds, (-1, -2))
        create_tsv(
            op.join(controller.model.dir_path, 'cluster_names.ts'),
            controller.supervisor.clustering.cluster_ids[
                plugin.m_inds[has_master]],
            plugin.unit_type[has_master],
            plugin.best_channels[has_master],
            plugin.unit_number[has_master])
        create_tsv(
            op.join(controller2.model.dir_path, 'cluster_names.ts'),
            controller2.supervisor.clustering.cluster_ids[
                plugin.sortrows],
            plugin.unit_type[plugin.matchi[plugin.sortrows]],
            plugin.best_channels[plugin.matchi[plugin.sortrows]],
            plugin.unit_number[plugin.matchi[plugin.sortrows]])
        with open(
                op.join(controller.model.dir_path, 'Merged_Files.txt'),
                'a') as text_file:
            text_file.write('{} on {}\n'.format(
                controller2.model.dir_path, time.strftime('%c')))
        plugin.tablegui.status_message = 'Saved clusted associations'
        print('Saved clusted associations')

    def create_tsv(filename, cluster_id, unit_type, channel,
                   unit_number):
        """Write one tab-separated row per cluster under a header row."""
        # The csv module wants a binary file on Python 2 and a text file
        # opened with newline='' on Python 3.
        if sys.version_info[0] < 3:
            handle = open(filename, 'wb')
        else:
            handle = open(filename, 'w', newline='')
        with handle as f:
            writer = csv.writer(f, delimiter='\t')
            writer.writerow(
                ['cluster_id', 'unit_type', 'chan', 'unit_number'])
            for i in range(len(cluster_id)):
                writer.writerow([
                    cluster_id[i], unit_type[i], channel[i],
                    unit_number[i]
                ])

    tablegui.show()
# Script excerpt (Python 2: note the print statement). `fet`, `HDBSCAN`,
# `view_scatter_3d`, `View`, `GUI`, `create_app`, `sns` and `np` come from
# earlier in the file, outside this view.

# Cluster the features of one channel.
ch = 26
min_cluster_size = 5
leaf_size = 10
hdbcluster = HDBSCAN(min_cluster_size=min_cluster_size,
                     leaf_size=leaf_size,
                     gen_min_span_tree=True,
                     algorithm='boruvka_kdtree')
clu = hdbcluster.fit_predict(fet[ch])
print 'get clusters', np.unique(clu)

# from phy.gui import GUI, create_app, run_app
create_app()
gui = GUI(position=(400, 200), size=(600, 400))

# 3D scatter of that channel's features, labelled by cluster.
scatter_view = view_scatter_3d()
scatter_view.attach(gui)
scatter_view.set_data(fet[ch], clu)

# A 3 x nclu grid view, where nclu is the number of distinct labels.
nclu = len(np.unique(clu))
view = View(layout='grid', shape=(3, nclu))
gui.add_view(view)

palette = sns.color_palette()
view.clear()
for chNo in range(3):
    for clu_id in np.unique(clu):
        # Negative labels (presumably HDBSCAN's noise label — confirm) are
        # drawn in white; other labels index straight into the palette.
        # NOTE(review): this assumes the palette has at least as many entries
        # as there are non-negative labels — confirm.
        color = palette[clu_id] if clu_id>=0 else np.array([1,1,1])
# NOTE(review): the statements down to ``self.cross.ref_disable()`` are the
# tail of a method whose header is above this chunk (they use ``self``,
# ``event`` and ``modifiers``). The nesting shown is a best reading of the
# collapsed original — confirm against the start of the method.
# NOTE(review): ``modifiers is not ()`` is an identity test against a tuple
# literal; ``modifiers != ()`` or plain truthiness is probably what is meant.
if 1 in event.buttons and modifiers is not ():
    # p1 is not used in the visible part of the method.
    p1 = event.press_event.pos
    p2 = event.last_event.pos
    if modifiers[0].name == 'Shift':
        self.cross.ref_enable(p2)
elif self.cross.cross_state:
    # No press event means the mouse is moving without a button held down.
    if event.press_event is None:
        self.cross.moveto(event.pos)
        # NOTE(review): assumed to sit inside this ``if`` — confirm.
        self.cross.ref_disable()

# Manual demo, run only when the module is executed directly. Python 2 code
# (print statement). The excerpt stops after ``filename``; the rest of the
# script is outside this view.
if __name__ == '__main__':
    from phy.gui import GUI, create_app, run_app
    create_app()
    gui = GUI(position=(0, 0), size=(600, 400), name='GUI')
    ##############################################
    ### Test scatter_view
    from sklearn.preprocessing import normalize
    # One million random 3D points, each row passed through normalize().
    n = 1000000
    fet = np.random.randn(n,3)
    fet = normalize(fet,axis=1)
    print fet.shape
    # Random labels in {0, 1, 2}, one per point, as an (n, 1) array.
    clu = np.random.randint(3,size=(n,1))
    scatter_view = view_scatter_3d()
    scatter_view.attach(gui)
    scatter_view.set_data(fet, clu)
    #############################################################################################
    from Binload import Binload
    ### Set Parameters ###
    filename = 'S:/pcie.bin'
def test_waveform_view(qtbot, tempdir):
    """Exercise WaveformView: selection, toggles, scaling and channel clicks."""
    n_channels = 5

    def get_waveforms(cluster_id):
        # The same artificial data is returned whatever the cluster.
        data = artificial_waveforms(10, 20, n_channels)
        channel_ids = np.arange(n_channels)
        positions = staggered_positions(n_channels)
        return Bunch(
            data=data,
            channel_ids=channel_ids,
            channel_positions=positions,
        )

    view = WaveformView(waveforms=get_waveforms)
    window = GUI(config_dir=tempdir)
    window.show()
    view.attach(window)
    qtbot.addWidget(window)

    # Empty, single and multiple selections.
    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)

    # Each toggle is flipped and then flipped back.
    for toggle in (view.toggle_waveform_overlap, view.toggle_show_labels):
        toggle()
        toggle()

    def _round_trip(attr, forward, backward):
        # Apply an action and its inverse; the boxed attribute must come
        # back to its starting value.
        before = getattr(view.boxed, attr)
        forward()
        backward()
        ac(getattr(view.boxed, attr), before)

    # Box scaling.
    _round_trip('box_size', view.increase, view.decrease)
    _round_trip('box_size', view.widen, view.narrow)

    # Probe scaling.
    _round_trip('box_pos', view.extend_horizontally, view.shrink_horizontally)
    _round_trip('box_pos', view.extend_vertically, view.shrink_vertically)

    # The scaling properties must read back what was assigned.
    probe_x, probe_y = view.probe_scaling
    view.probe_scaling = (probe_x, probe_y * 2)
    ac(view.probe_scaling, (probe_x, probe_y * 2))

    box_x, box_y = view.box_scaling
    view.box_scaling = (box_x * 2, box_y)
    ac(view.box_scaling, (box_x * 2, box_y))

    # Simulate channel selection.
    clicks = []

    @view.gui.connect_
    def on_channel_click(channel_id=None, button=None, key=None):
        clicks.append((channel_id, button, key))

    # Hold "2", click at the origin with button 1, release "2".
    view.events.key_press(key=keys.Key("2"))
    view.events.mouse_press(pos=(0.0, 0.0), button=1)
    view.events.key_release(key=keys.Key("2"))

    assert clicks == [(0, 1, 2)]

    # qtbot.stop()
    window.close()