def edit_clustering_parameters(self, shell=False):
    '''Open a user interface for editing the stored clustering parameters

    Parameters
    ----------
    shell : bool (optional)
        True for a command-line interface, False (default) for a GUI
    '''
    editor = userIO.dictIO(self.clust_params, shell=shell)
    updated = editor.fill_dict()
    # Keep the existing parameters when the user cancels or clears the form
    if updated:
        self.clust_params = updated
def edit_psth_parameters(self, shell=False):
    '''Open a user interface for editing the stored PSTH parameters

    Parameters
    ----------
    shell : bool (optional)
        True for a command-line interface, False (default) for a GUI
    '''
    editor = userIO.dictIO(self.psth_params, shell=shell)
    updated = editor.fill_dict('Edit params for making PSTHs\n'
                               'All times are in ms')
    # Keep the existing parameters when the user cancels or clears the form
    if updated:
        self.psth_params = updated
def __init__(self, exp_dir=None, shell=False):
    '''Setup for analysis across recording sessions

    Parameters
    ----------
    exp_dir : str (optional)
        path to directory containing all recording directories
        if None (default) is passed then a popup to choose file
        will come up
    shell : bool (optional)
        True to use command-line interface for user input
        False (default) for GUI
    '''
    if exp_dir is None:
        # No directory given: ask via GUI popup; abort construction on cancel
        exp_dir = eg.diropenbox('Select Experiment Directory',
                                'Experiment Directory')
        if exp_dir is None or exp_dir == '':
            # NOTE(review): returning from __init__ leaves a partially
            # initialized object (no attributes set)
            return

    # Collect only the subdirectories of the experiment directory
    fd = [os.path.join(exp_dir, x) for x in os.listdir(exp_dir)]
    file_dirs = [x for x in fd if os.path.isdir(x)]

    # Ask the user to rank each recording directory; 0/blank drops it
    order_dict = dict.fromkeys(file_dirs, 0)
    tmp = userIO.dictIO(order_dict, shell=shell)
    # NOTE(review): the '%i' placeholder below is never substituted, so the
    # literal '%i' is shown to the user — likely meant % len(file_dirs)
    order_dict = tmp.fill_dict(prompt=('Set order of recordings (1-%i)\n'
                                       'Leave blank to delete directory'
                                       ' from list'))
    if order_dict is None:
        # User aborted ordering; leave object partially initialized
        return

    # Keep only ranked entries and sort by the assigned rank
    file_dirs = [k for k, v in order_dict.items()
                 if v is not None and v != 0]
    file_dirs = sorted(file_dirs, key=order_dict.get)
    # os.path.join with an absolute second argument returns the second
    # argument unchanged, so re-joining the already-absolute paths is a no-op
    file_dirs = [os.path.join(exp_dir, x) for x in file_dirs]
    # Normalize away trailing slashes so basename/dirname behave later
    file_dirs = [x[:-1] if x.endswith('/') else x for x in file_dirs]

    self.recording_dirs = file_dirs
    self.experiment_dir = exp_dir
    self.shell = shell

    # Use the first recording's electrode map as the experiment-wide map
    dat = dataset.load_dataset(file_dirs[0])
    em = dat.electrode_mapping.copy()
    # Ask which electrodes were histologically confirmed in gustatory cortex
    ingc = userIO.select_from_list('Select all eletrodes confirmed in GC',
                                   em['Electrode'],
                                   multi_select=True,
                                   shell=shell)
    # NOTE(review): if the user cancels, select_from_list presumably returns
    # None and map(int, None) will raise TypeError — confirm upstream
    ingc = list(map(int, ingc))
    em['Area'] = np.where(em['Electrode'].isin(ingc), 'GC', 'Other')
    self.electrode_mapping = em
    # Pickle file used to persist this experiment object
    self.save_file = os.path.join(exp_dir,
                                  '%s_experiment.p'
                                  % os.path.basename(exp_dir))
def edit_spike_array_parameters(self, shell=False):
    '''Allows user interface for editing spike array parameters

    Parameters
    ----------
    shell : bool (optional)
        True for a command-line interface, False (default) for a GUI

    Returns
    -------
    dict or None
        the updated spike array parameters, or None if the user aborted
    '''
    params = self.spike_array_params
    param_filler = userIO.dictIO(params, shell=shell)
    # FIX: prompt previously read '(Times are in ms' with no closing paren
    new_params = param_filler.fill_dict(prompt=('Input desired parameters'
                                                ' (Times are in ms)'))
    if new_params is None:
        # User aborted; leave stored parameters untouched
        return None

    # Channel lists come back as strings from the form; drop blanks and
    # coerce the rest to ints before storing
    new_params['dig_ins_to_use'] = [int(x) for x in
                                    new_params['dig_ins_to_use']
                                    if x != '']
    new_params['laser_channels'] = [int(x) for x in
                                    new_params['laser_channels']
                                    if x != '']
    self.spike_array_params = new_params
    return new_params
def get_palatability_ranks(dig_in_mapping, shell=True):
    '''Queries user for palatability rankings for digital inputs (tastants)
    and adds a 'palatability_rank' column to a copy of dig_in_mapping

    Parameters
    ----------
    dig_in_mapping : pandas.DataFrame
        DataFrame with at least columns 'dig_in' and 'name', for mapping
        digital input channel number to a str name
    shell : bool (optional)
        True (default) for a command-line interface, False for a GUI

    Returns
    -------
    pandas.DataFrame or None
        copy of dig_in_mapping with added 'palatability_rank' column,
        or None if the user aborted the dialog
    '''
    dim = dig_in_mapping.copy()
    tmp = dict.fromkeys(dim['name'], 0)
    filler = userIO.dictIO(tmp, shell=shell)
    tmp = filler.fill_dict('Rank Palatability\n1 for the lowest\n'
                           'Leave blank to exclude from palatability analysis')
    # FIX: fill_dict returns None when the user cancels; previously this
    # fell through to Series.map(None) and raised TypeError
    if tmp is None:
        return None
    dim['palatability_rank'] = dim['name'].map(tmp)
    return dim
def _order_dirs(self, shell=None):
    '''Query the user to re-order (or drop) the recording directories'''
    if shell is None:
        shell = self.shell

    # Normalize away trailing slashes so basename/dirname work as expected
    trimmed = []
    for d in self.recording_dirs:
        trimmed.append(d[:-1] if d.endswith('/') else d)
    self.recording_dirs = trimmed

    # Map each directory's basename back to its parent directory
    parents = {os.path.basename(d): os.path.dirname(d)
               for d in self.recording_dirs}
    ranking = dict.fromkeys(parents, 0)
    filler = userIO.dictIO(ranking, shell=shell)
    ranking = filler.fill_dict(prompt=('Set order of recordings (1-%i)\n'
                                       'Leave blank to delete directory'
                                       ' from list'))
    if ranking is None:
        # User aborted: keep the current ordering
        return

    # Drop unranked entries, sort the rest by rank, and rebuild full paths
    keep = [name for name, rank in ranking.items()
            if rank is not None and rank != 0]
    keep.sort(key=ranking.get)
    self.recording_dirs = [os.path.join(parents.get(name), name)
                           for name in keep]
def get_cell_types(cluster_names, shell=True):
    '''Queries user to identify each cluster as multiunit vs single-unit,
    regular vs fast spiking

    Parameters
    ----------
    cluster_names : list of str
        names of the clusters being classified
    shell : bool (optional)
        True (default) if command-line interface desired, False for GUI

    Returns
    -------
    list of dict or None
        one dict per cluster with keys 'single_unit', 'regular_spiking'
        and 'fast_spiking' (values 0 or 1), or None if the user aborted
    '''
    base = {
        'Single Unit': False,
        'Regular Spiking': False,
        'Fast Spiking': False
    }
    # One copy of the checkbox set per cluster
    form = {name: base.copy() for name in cluster_names}
    filler = userIO.dictIO(form, shell=shell)
    ans = filler.fill_dict()
    if ans is None:
        return None

    out = []
    for name in cluster_names:
        out.append({
            'single_unit': int(ans[name]['Single Unit']),
            'regular_spiking': int(ans[name]['Regular Spiking']),
            'fast_spiking': int(ans[name]['Fast Spiking'])
        })
    return out
def initParams(self, data_quality='clean', emg_port=None,
               emg_channels=None, shell=False,
               dig_in_names=None, dig_out_names=None,
               spike_array_params=None, psth_params=None,
               confirm_all=False):
    '''
    Initializes basic default analysis parameters that can be customized
    before running processing methods
    Can provide data_quality as 'clean' or 'noisy' to preset some
    parameters that are useful for the different types. Best practice is to
    run as clean (default) and to re-run as noisy if you notice that a lot
    of electrodes are cutoff early
    '''
    # Get parameters from info.rhd
    file_dir = self.data_dir
    rec_info = dio.rawIO.read_rec_info(file_dir, shell)
    ports = rec_info.pop('ports')
    channels = rec_info.pop('channels')
    sampling_rate = rec_info['amplifier_sampling_rate']
    self.rec_info = rec_info
    self.sampling_rate = sampling_rate

    # Get default parameters for blech_clust
    clustering_params = deepcopy(dio.params.clustering_params)
    data_params = deepcopy(dio.params.data_params[data_quality])
    bandpass_params = deepcopy(dio.params.bandpass_params)
    spike_snapshot = deepcopy(dio.params.spike_snapshot)
    if spike_array_params is None:
        spike_array_params = deepcopy(dio.params.spike_array_params)
    if psth_params is None:
        psth_params = deepcopy(dio.params.psth_params)

    # Ask for emg port & channels (GUI only; shell users must pass them in)
    if emg_port is None and not shell:
        q = eg.ynbox('Do you have an EMG?', 'EMG')
        if q:
            emg_port = userIO.select_from_list('Select EMG Port:',
                                               ports, 'EMG Port',
                                               shell=shell)
            emg_channels = userIO.select_from_list(
                'Select EMG Channels:',
                [y for x, y in zip(ports, channels) if x == emg_port],
                title='EMG Channels',
                multi_select=True, shell=shell)
    elif emg_port is None and shell:
        print('\nNo EMG port given.\n')

    electrode_mapping, emg_mapping = dio.params.flatten_channels(
        ports, channels, emg_port=emg_port, emg_channels=emg_channels)
    self.electrode_mapping = electrode_mapping
    self.emg_mapping = emg_mapping

    # Get digital input names and spike array parameters
    if rec_info.get('dig_in'):
        if dig_in_names is None:
            dig_in_names = dict.fromkeys(['dig_in_%i' % x
                                          for x in rec_info['dig_in']])
            name_filler = userIO.dictIO(dig_in_names, shell=shell)
            dig_in_names = name_filler.fill_dict('Enter names for '
                                                 'digital inputs:')
            if dig_in_names is None or \
                    any([x is None for x in dig_in_names.values()]):
                raise ValueError('Must name all dig_ins')

            dig_in_names = list(dig_in_names.values())

        if spike_array_params['laser_channels'] is None:
            # Ask which dig_ins are lasers rather than tastants
            laser_dict = dict.fromkeys(dig_in_names, False)
            laser_filler = userIO.dictIO(laser_dict, shell=shell)
            laser_dict = laser_filler.fill_dict('Select any lasers:')
            if laser_dict is None:
                laser_channels = []
            else:
                laser_channels = [i for i, v in
                                  zip(rec_info['dig_in'],
                                      laser_dict.values())
                                  if v]

            spike_array_params['laser_channels'] = laser_channels
        else:
            # BUG FIX: laser_channels was never bound on this branch, so
            # the dig_ins_to_use prompt below raised NameError whenever
            # laser channels were pre-supplied
            laser_channels = spike_array_params['laser_channels']
            laser_dict = dict.fromkeys(dig_in_names, False)
            # NOTE(review): this indexes dig_in_names by channel number,
            # assuming dig_in channel numbers are 0-based and contiguous
            for lc in laser_channels:
                laser_dict[dig_in_names[lc]] = True

        if spike_array_params['dig_ins_to_use'] is None:
            # Offer every non-laser dig_in as a candidate for spike arrays
            di = [x for x in rec_info['dig_in'] if x not in laser_channels]
            dn = [dig_in_names[x] for x in di]
            spike_dig_dict = dict.fromkeys(dn, True)
            filler = userIO.dictIO(spike_dig_dict, shell=shell)
            spike_dig_dict = filler.fill_dict('Select digital inputs '
                                              'to use for making spike'
                                              ' arrays:')
            if spike_dig_dict is None:
                spike_dig_ins = []
            else:
                spike_dig_ins = [x for x, y in
                                 zip(di, spike_dig_dict.values())
                                 if y]

            spike_array_params['dig_ins_to_use'] = spike_dig_ins

        dim = pd.DataFrame([(x, y) for x, y in
                            zip(rec_info['dig_in'], dig_in_names)],
                           columns=['dig_in', 'name'])
        dim['laser'] = dim['name'].apply(lambda x: laser_dict.get(x))
        self.dig_in_mapping = dim.copy()

    # Get digital output names
    if rec_info.get('dig_out'):
        if dig_out_names is None:
            dig_out_names = dict.fromkeys(['dig_out_%i' % x
                                           for x in rec_info['dig_out']])
            name_filler = userIO.dictIO(dig_out_names, shell=shell)
            dig_out_names = name_filler.fill_dict('Enter names for '
                                                  'digital outputs:')
            if dig_out_names is None or \
                    any([x is None for x in dig_out_names.values()]):
                raise ValueError('Must name all dig_outs')

            dig_out_names = list(dig_out_names.values())

        self.dig_out_mapping = pd.DataFrame(
            [(x, y) for x, y in zip(rec_info['dig_out'], dig_out_names)],
            columns=['dig_out', 'name'])

    # Store clustering parameters
    self.clust_params = {'file_dir': file_dir,
                         'data_quality': data_quality,
                         'sampling_rate': sampling_rate,
                         'clustering_params': clustering_params,
                         'data_params': data_params,
                         'bandpass_params': bandpass_params,
                         'spike_snapshot': spike_snapshot}

    # Store and confirm spike array parameters
    spike_array_params['sampling_rate'] = sampling_rate
    self.spike_array_params = spike_array_params
    self.psth_params = psth_params
    if not confirm_all:
        prompt = ('\n----------\nSpike Array Parameters\n----------\n' +
                  dp.print_dict(spike_array_params) +
                  '\nAre these parameters good?')
        q_idx = userIO.ask_user(prompt, ('Yes', 'Edit'), shell=shell)
        if q_idx == 1:
            self.edit_spike_array_parameters(shell=shell)

        # Edit and store psth parameters
        prompt = ('\n----------\nPSTH Parameters\n----------\n' +
                  dp.print_dict(psth_params) +
                  '\nAre these parameters good?')
        q_idx = userIO.ask_user(prompt, ('Yes', 'Edit'), shell=shell)
        if q_idx == 1:
            self.edit_psth_parameters(shell=shell)

    self.save()
def edit_clusters(clusters, fs, shell=False):
    '''Handles editing of a cluster group until a user has a single cluster
    they are satisfied with

    Parameters
    ----------
    clusters : list of dict
        list of dictionaries each defining clusters of spikes
    fs : float
        sampling rate in Hz
    shell : bool
        set True if command-line interface is desires, False for GUI
        (default)

    Returns
    -------
    dict
        dict representing the resulting cluster from manipulations, None
        if aborted
    '''
    # Work on a copy so the caller's cluster dicts are never mutated
    clusters = deepcopy(clusters)
    quit_flag = False
    while not quit_flag:
        if len(clusters) == 1:
            # One cluster, ask if they want to keep or split
            idx = userIO.ask_user('What would you like to do with %s?' %
                                  clusters[0]['Cluster Name'],
                                  choices=[
                                      'Split',
                                      'Waveform View',
                                      'Raster View',
                                      'Keep',
                                      'Abort'
                                  ],
                                  shell=shell)
            if idx == 1:
                # Show waveform plot, then return to the menu
                fig = plot_cluster(clusters[0])
                plt.show()
                continue
            elif idx == 2:
                # Show raster plot, then return to the menu
                plot_raster(clusters)
                plt.show()
                continue
            elif idx == 3:
                # Keep: return the single cluster dict
                return clusters[0]
            elif idx == 4:
                # Abort
                return None

            # Falls through here only for idx == 0 (Split)
            old_clust = clusters[0]
            clusters = split_cluster(clusters[0], fs, shell=shell)
            if clusters is None or clusters == []:
                return None

            # Plot each sub-cluster and a PCA overview so the user can
            # decide which pieces to keep
            figs = []
            for i, c in enumerate(clusters):
                tmp_fig = plot_cluster(c, i)
                figs.append(tmp_fig)

            f2, ax2 = plot_pca_view(clusters)
            plt.show()
            query = {'Clusters to keep (indices)': []}
            query = userIO.dictIO(query, shell=shell)
            ans = query.fill_dict()
            # NOTE(review): if the user cancels here fill_dict presumably
            # returns None and the subscript below raises — confirm
            if ans['Clusters to keep (indices)'][0] == '':
                # Blank answer: discard the split and restore the original
                print('Reset to before split')
                clusters = [old_clust]
                continue

            ans = [int(x) for x in ans['Clusters to keep (indices)']]
            new_clusters = [clusters[x] for x in ans]
            del clusters
            clusters = new_clusters
            del new_clusters
        else:
            # Multiple clusters: offer merge instead of split
            idx = userIO.ask_user(('You have %i clusters. '
                                   'What would you like to do?')
                                  % len(clusters),
                                  choices=[
                                      'Merge',
                                      'Waveform View',
                                      'Raster View',
                                      'Keep',
                                      'Abort'
                                  ],
                                  shell=shell)
            if idx == 0:
                # Automatically merge multiple clusters
                cluster = merge_clusters(clusters, fs)
                print('%i clusters merged into %s' %
                      (len(clusters), cluster['Cluster Name']))
                clusters = [cluster]
            elif idx == 1:
                # Show waveform plot for each cluster
                for c in clusters:
                    plot_cluster(c)

                plt.show()
                continue
            elif idx == 2:
                plot_raster(clusters)
                plt.show()
                continue
            elif idx == 3:
                # Keep: note this returns the LIST of clusters, unlike the
                # single-cluster branch above which returns one dict
                return clusters
            elif idx == 4:
                return None
def split_cluster(cluster, fs, params=None, shell=True):
    '''Use GMM to re-cluster a single cluster into a number of sub clusters

    Parameters
    ----------
    cluster : dict, cluster metrics and data
    fs : float
        sampling rate in Hz, passed to get_ISI_and_violations
    params : dict (optional), parameters for reclustering with keys:
        'Number of Clusters', 'Max Number of Iterations',
        'Convergence Criterion', 'GMM random restarts'
    shell : bool , optional, whether to use shell or GUI input
        (defualt=True)

    Returns
    -------
    clusters : list of dicts
        resultant clusters from split; empty list if the GMM did not
        converge, None if the user aborted the parameter dialog
    '''
    if params is None:
        # No parameters supplied: ask the user, starting from defaults
        clustering_params = {'Number of Clusters': 2,
                             'Max Number of Iterations': 1000,
                             'Convergence Criterion': 0.00001,
                             'GMM random restarts': 10}
        params_filler = userIO.dictIO(clustering_params, shell=shell)
        params = params_filler.fill_dict()
        if params is None:
            # User cancelled the dialog
            return None

    n_clusters = int(params['Number of Clusters'])
    n_iter = int(params['Max Number of Iterations'])
    thresh = float(params['Convergence Criterion'])
    n_restarts = int(params['GMM random restarts'])

    g = GaussianMixture(n_components=n_clusters,
                        covariance_type='full',
                        tol=thresh,
                        max_iter=n_iter,
                        n_init=n_restarts)
    g.fit(cluster['data'])

    out_clusters = []
    if g.converged_:
        spike_waveforms = cluster['spike_waveforms']
        spike_times = cluster['spike_times']
        data = cluster['data']
        predictions = g.predict(data)
        for c in range(n_clusters):
            # Rows of the original cluster assigned to sub-cluster c
            clust_idx = np.where(predictions == c)[0]
            tmp_clust = deepcopy(cluster)
            # Sub-cluster id = parent id + letter suffix ('A', 'B', ...)
            clust_id = str(cluster['cluster_id']) + \
                bytes([b'A'[0]+c]).decode('utf-8')
            clust_name = 'E%iS%i_cluster%s' % \
                (cluster['electrode'], cluster['solution'], clust_id)
            clust_waveforms = spike_waveforms[clust_idx]
            clust_times = spike_times[clust_idx]
            clust_data = data[clust_idx, :]
            # NOTE(review): clust_log is built but never stored — likely
            # intended to be assigned to tmp_clust['manipulations']
            clust_log = cluster['manipulations'] + \
                '\nSplit %s with parameters: ' \
                '\n%s\nCluster %i from split results. Named %s' \
                % (cluster['Cluster Name'], dp.print_dict(params),
                   c, clust_name)
            tmp_clust['Cluster Name'] = clust_name
            tmp_clust['cluster_id'] = clust_id
            tmp_clust['data'] = clust_data
            tmp_clust['spike_times'] = clust_times
            tmp_clust['spike_waveforms'] = clust_waveforms

            # Recompute ISI metrics for the new, smaller spike set
            tmp_isi, tmp_violations1, tmp_violations2 = \
                get_ISI_and_violations(clust_times, fs)
            tmp_clust['ISI'] = tmp_isi
            tmp_clust['1ms_violations'] = tmp_violations1
            tmp_clust['2ms_violations'] = tmp_violations2
            out_clusters.append(deepcopy(tmp_clust))

    return out_clusters
def sort_units(file_dir, fs, shell=False):
    '''Allows user to sort clustered units

    Parameters
    ----------
    file_dir : str, path to recording directory
    fs : float, sampling rate in Hz
    shell : bool
        True for command-line interface, False for GUI (default)
    '''
    # Locate the HDF5 store and derive companion log/metrics locations
    hf5_name = h5io.get_h5_filename(file_dir)
    hf5_file = os.path.join(file_dir, hf5_name)
    sorting_log = hf5_file.replace('.h5', '_sorting.log')
    metrics_dir = os.path.join(file_dir, 'sorted_unit_metrics')
    if not os.path.exists(metrics_dir):
        os.mkdir(metrics_dir)

    quit_flag = False

    # Start loop to label a cluster
    clust_info = {'Electrode': 0,
                  'Clusters in solution': 7,
                  'Cluster Numbers': [],
                  'Edit Clusters': False}
    clust_query = userIO.dictIO(clust_info, shell=shell)
    print(('Beginning spike sorting for: \n\t%s\n'
           'Sorting Log written to: \n\t%s') % (hf5_file, sorting_log))
    print(('To select multiple clusters, simply type in the numbers of each\n'
           'cluster comma-separated, however, you MUST check Edit Clusters in'
           'order\nto split clusters, merges will be assumed'
           ' from multiple cluster selection.'))
    print(('\nIf using GUI input, hit cancel at any point to abort sorting'
           ' of current cluster.\nIf using shell interface,'
           ' type abort at any time.\n'))
    while not quit_flag:
        # Ask which electrode/solution/cluster(s) to label next
        clust_info = clust_query.fill_dict()
        if clust_info is None:
            # User aborted the form
            quit_flag = True
            break

        # Load each selected cluster's data from disk
        clusters = []
        for c in clust_info['Cluster Numbers']:
            tmp = get_cluster_data(file_dir, clust_info['Electrode'],
                                   clust_info['Clusters in solution'],
                                   int(c), fs)
            clusters.append(tmp)

        if len(clusters) == 0:
            # Nothing selected: treat as done
            quit_flag = True
            break

        if clust_info['Edit Clusters']:
            # Interactive split/merge; may return a single dict, a list,
            # or None on abort
            clusters = edit_clusters(clusters, fs, shell)
            if isinstance(clusters, dict):
                clusters = [clusters]

            if clusters is None:
                quit_flag = True
                break

        # Classify each remaining cluster (single unit, RS/FS)
        cell_types = get_cell_types([x['Cluster Name'] for x in clusters],
                                    shell)
        if cell_types is None:
            quit_flag = True
            break
        else:
            for x, y in zip(clusters, cell_types):
                x.update(y)

        # Write each labeled unit into the HDF5 store and the sorting log
        for unit in clusters:
            label_single_unit(hf5_file, unit, fs, sorting_log, metrics_dir)

        q = input('Do you wish to continue sorting units? (y/n) >> ')
        if q == 'n':
            quit_flag = True