def main(anim): rec_dirs, anim_dir = get_rec_dirs(anim) for fd in rec_dirs: pre_process(fd) clustering(anim_dir) ## Sort spike manually 1 electrode and 1 recording at a time for fd in rec_dirs: dat = blechpy.load_dataset(fd) # Sort each electrode # root, ssg = dat.sort_spikes(electrode) dat.cleanup_lowSpiking_units(min_spikes=100) dat.units_similarity(shell=True) ## Check units similarity and delete necessary units # Then for fd in rec_dirs: dat = blechpy.load_dataset(fd) dat.make_unit_plots() dat.make_unit_arrays() dat.make_psth_arrays() exp.detect_held_units() exp.save()
def clustering(exp): # TODO Actually I now cluster 1 recording at a time if isinstance(exp, str): fd = exp exp = blechpy.load_experiment(fd) if exp is None: exp = blechpy.experiment(fd) dat = blechpy.load_dataset(exp.recording_dirs[0]) clustering_params = dat.clustering_params.copy() clustering_params['clustering_params']['Max Number of Clusters'] = 15 exp.cluster_spikes(custom_params=clustering_params, umap=True) for fd in exp.recording_dirs: dat = blechpy.load_dataset(fd) dat.cleanup_clustering()
def __init__(self, dat, n_states, save_dir=None, params=None): if isinstance(dat, str): dat = blechpy.load_dataset(dat) if dat is None: raise FileNotFoundError('No dataset.p file found given directory') if save_dir is None: save_dir = os.path.join(dat.root_dir, '%s_analysis' % dat.data_name) self._dataset = dat self.root_dir = dat.root_dir self.save_dir = save_dir self.n_states = n_states if not os.path.isdir(save_dir): os.makedirs(save_dir) plot_dir = os.path.join(save_dir, '%i_states' % n_states, 'plots') if not os.path.isdir(plot_dir): os.makedirs(plot_dir) self._plot_dir = plot_dir self._files = {'hmm_data': os.path.join(save_dir, 'hmm_data.hdf5'), 'params' : os.path.join(save_dir, 'hmm_params.json')} file_check = self._file_check() self.update_params(params) dig_in_map = dat.dig_in_mapping.query('spike_array == True and exclude == False') self._dig_ins = dig_in_map.set_index('name') self._fitted_models = {} self._setup_hdf5()
def set_electrode_areas(proj, el_in_gc={}): exp_info = proj._exp_info for i, row in exp_info.iterrows(): name = row['exp_name'] if name not in el_in_gc.keys(): continue exp = blechpy.load_experiment(row['exp_dir']) ingc = el_in_gc[name] if ingc is 'right': el = np.arange(8, 24) elif ingc is 'left': el = np.concatenate([np.arange(0, 8), np.arange(24, 32)]) elif ingc is 'none': el = np.arange(0, 32) else: el = None for rec in exp.recording_dirs: dat = load_dataset(rec) print('Fixing %s...' % dat.data_name) em = dat.electrode_mapping em['area'] = 'GC' if el is not None: em.loc[em['Channel'].isin(el), 'area'] = 'STR' h5io.write_electrode_map_to_h5(dat.h5_file, em) dat.save() return
def init_dat(dat): if isinstance(dat, str): fd = dat dat = blechpy.load_dataset(fd) if dat is None: data_name = os.path.basename(fd).split('_') _ = data_name.pop(-1) _ = data_name.pop(-1) data_name = '_'.join(data_name) dat = blechpy.dataset(file_dir=fd, data_name=data_name, shell=True) dat.save() rec_dir = dat.root_dir rec_type = os.path.basename(rec_dir).split('_')[1] status = dat.process_status if not status['initialize parameters']: params = init_params.copy() if rec_type == '4taste': dig_in_names = ['Water', 'Quinine', 'NaCl', 'Citric Acid'] else: dig_in_names= ['Saccharin'] params['dig_in_names'] = dig_in_names dat.initParams(**params)
def run_hmms(rec_dirs, constraint_func=None): base_params = {'unit_type': 'single', 'dt': 0.001, 'max_iter': 200, 'n_repeats': 50, 'time_start': -250, 'time_end': 2000, 'n_states': 3, 'area': 'GC', 'hmm_class': 'PoissonHMM', 'threshold':1e-10, 'notes': 'sequential - low thresh'} params = [{'n_states': 2}, {'n_states': 3}, {'time_start': -200, 'n_states': 4}] for rec_dir in rec_dirs: units = phmm.query_units(rec_dir, 'single', area='GC') if len(units) < 2: continue handler = phmm.HmmHandler(rec_dir) dat = load_dataset(rec_dir) for i, row in dat.dig_in_mapping.iterrows(): if row['laser']: continue name = row['name'] ch = row['channel'] on_trials, off_trials = get_laser_trials(rec_dir, ch) for new_params in params: p = base_params.copy() p.update(new_params) p['taste'] = name p['channel'] ch p['notes'] += ' - all_trials' handler.add_params(p) p = base_params.copy() p.update(new_params) p['taste'] = name p['channel'] ch p['trial_nums'] = on_trials p['notes'] += ' - on_trials' handler.add_params(p) p = base_params.copy() p.update(new_params) p['taste'] = name p['channel'] ch p['trial_nums'] = off_trials p['notes'] += ' - off_trials' handler.add_params(p) dataname = os.path.basename(rec_dir) print('Fitting %s' % os.path.basename(rec_dir)) if type(constraint) == 'function': print('Fitting Constraint: %s' % constraint.__name__) else: print('Fitting Constraint: %s' % str(constraint)) handler.run(constraint_func=constraint)
def post_process(dat): if isinstance(dat, str): fd = dat dat = blechpy.load_dataset(fd) if dat is None: raise FileNotFoundError('Dataset for %s not found.' % fd) dat.cleanup_lowSpiking_units(min_spikes=100) dat.units_similarity(shell=True) dat.make_unit_plots() dat.make_unit_arrays() dat.make_psth_arrays()
def __init__(self, dat, params=None, save_dir=None): '''Takes a blechpy dataset object and fits HMMs for each tastant Parameters ---------- dat: blechpy.dataset params: dict or list of dicts each dict must have fields: time_window: list of int, time window to cut around stimuli in ms convergence_thresh: float max_iter: int n_repeats: int unit_type: str, {'single', 'pyramidal', 'interneuron', 'all'} bin_size: time bin for spike array when fitting in seconds n_states: predicted number of states to fit ''' if isinstance(params, dict): params = [params] if isinstance(dat, str): dat = blechpy.load_dataset(dat) if dat is None: raise FileNotFoundError('No dataset.p file found given directory') if save_dir is None: save_dir = os.path.join(dat.root_dir, '%s_analysis' % dat.data_name) self._dataset = dat self.root_dir = dat.root_dir self.save_dir = save_dir self.h5_file = os.path.join(save_dir, '%s_HMM_Analysis.hdf5' % dat.data_name) dim = dat.dig_in_mapping.query('exclude==False') tastes = dim['name'].tolist() if params is None: # Load params and fitted models self.load_data() else: self.init_params(params) self.params = params if not os.path.isdir(save_dir): os.makedirs(save_dir) self.plot_dir = os.path.join(save_dir, 'HMM_Plots') if not os.path.isdir(self.plot_dir): os.makedirs(self.plot_dir) self._setup_hdf5()
def fix_palatability(proj, pal_map=None): '''Goes through all datasets in project and fixes palatability rankings ''' if pal_map is None: pal_map = PAL_MAP exp_dirs = proj._exp_info.exp_dir.to_list() for exp_dir in tqdm(exp_dirs): exp = blechpy.load_experiment(exp_dir) for rd in exp.recording_dirs: dat = load_dataset(rd) dat.dig_in_mapping[ 'palatability_rank'] = dat.dig_in_mapping.name.map(pal_map) h5io.write_digital_map_to_h5(dat.h5_file, dat.dig_in_mapping, 'in') dat.save()
def get_all_units(proj): # Columns: # - exp_name, exp_group, rec_name, rec_group, rec_dir, unit_num, # - electrode, area, single, unit_type all_units = pd.DataFrame(columns=[ 'exp_name', 'exp_group', 'rec_name', 'rec_group', 'rec_dir', 'unit_name', 'unit_num', 'electrode', 'area', 'single_unit', 'regular_spiking', 'fast_spiking' ]) for i, row in proj._exp_info.iterrows(): exp_name = row['exp_name'] exp_group = row['exp_group'] exp_dir = row['exp_dir'] exp = blechpy.load_experiment(exp_dir) for rec_name, rec_dir in exp.rec_labels.items(): if 'preCTA' in rec_name: rec_group = 'preCTA' elif 'postCTA' in rec_name: rec_group = 'postCTA' elif 'Train' in rec_name: rec_group = 'ctaTrain' elif 'Test' in rec_name: rec_group = 'ctaTest' else: # TODO: Make more elegant, ask for input raise ValueError('Rec %s does not fit into a group' % rec_name) dat = load_dataset(rec_dir) units = dat.get_unit_table().copy() units['exp_name'] = exp_name units['exp_group'] = exp_group units['rec_name'] = rec_name units['rec_group'] = rec_group units['rec_dir'] = rec_dir em = dat.electrode_mapping.copy().set_index('Electrode') units['area'] = units['electrode'].map(em['area']) units = units[all_units.columns] all_units = all_units.append(units).reset_index(drop=True) return all_units
def pre_process(dat, dead_ch=[]): if isinstance(dat, str): fd = dat dat = blechpy.load_dataset(fd) if dat is None: raise ValueError('No Dataset found') status = dat.process_status if not status['extract_data']: dat.extract_data() if not status['create_trial_list']: dat.create_trial_list() if not status['mark_dead_channels']: dat.mark_dead_channels(dead_ch) if not status['common_average_reference']: dat.common_average_reference() if not status['spike_detection']: dat.detect_spikes()
def get_pca_data(rec, units, bin_size, step=None, t_start=None, t_end=None, baseline_win=None): '''Get spike data, turns it into binned firing rate traces and then organizes them into a format for PCA (trials*time X units) Parameters ---------- Returns ------- Raises ------ ''' if step is None: step = bin_size st, sa = h5io.get_spike_data(rec, units) if t_start is None: t_start = st[0] if t_end is None: t_end = st[-1] spikes = [] labels = [] dim = load_dataset(rec).dig_in_mapping.set_index('channel') if isinstance(sa, dict): for k, v in sa.items(): ch = int(k.split('_')[-1]) tst = dim.loc[ch, 'name'] l = [(tst, i) for i in range(v.shape[0])] if len(v.shape) == 2: tmp = np.expand_dims(v, 1) else: tmp = v labels.extend(l) spikes.append(tmp) else: if len(sa.shape) == 2: tmp = np.exapnd_dims(sa, 1) else: tmp = sa spikes.append(tmp) tst = dim.loc[0, 'name'] l = [(tst, i) for i in range(sa.shape[0])] labels.extend(l) spikes = np.vstack(spikes).astype('float64') b_idx = np.where(st < 0)[0] baseline_fr = np.sum(spikes[:, :, b_idx], axis=-1) / (len(b_idx) / 1000) baseline_fr = np.mean(baseline_fr, axis=0) t_idx = np.where((st >= t_start) & (st <= t_end))[0] spikes = spikes[:, :, t_idx] st = st[t_idx] fr_lbls = [] fr_arr = [] fr_time = None for trial_i, (trial, lbl) in enumerate(zip(spikes, labels)): fr_t, fr = sas.get_binned_firing_rate(st, trial, bin_size, step) # fr is units x time fr = fr.T # now its time x units fr = fr - baseline_fr # subtract baseline firing rate l = [(*lbl, t) for t in fr_t] fr_arr.append(fr) fr_lbls.extend(l) if fr_time is None: fr_time = fr_t elif not np.array_equal(fr_t, fr_time): raise ValueError('Time Vectors dont match') # So now fr_lbls is [taste, trial, time] fr_out = np.vstack(fr_arr) fr_lbls = np.array(fr_lbls) return fr_time, fr_out, fr_lbls
def apply_pca_analysis(df, params): '''df is held_units dataframe grouped by exp_name, exp_group, time_group only contains units held over preCTA or postCTA, no units held from ctaTrain to ctaTest Parameters ---------- Returns ------- Raises ------ ''' bin_size = params['pca']['win_size'] bin_step = params['pca']['step_size'] time_start = params['pca']['time_win'][0] time_end = params['pca']['time_win'][1] smoothing = params['pca']['smoothing_win'] n_cells = len(df) rd1 = df['rec1'].unique() rd2 = df['rec2'].unique() if len(rd1) > 1 or len(rd2) > 1: raise ValueError('Too many recording directories') rd1 = rd1[0] rd2 = rd2[0] units1 = list(df['unit1'].unique()) units2 = list(df['unit2'].unique()) dim1 = load_dataset(rd1).dig_in_mapping.set_index('channel') dim2 = load_dataset(rd2).dig_in_mapping.set_index('channel') if n_cells < 2: # No point if only 1 unit exp_name = os.path.basename(rd1).split('_') print('%s - %s: Not enough units for PCA analysis' % (exp_name[0], exp_name[-3])) return time, sa = h5io.get_spike_data(rd1, units1) fr_t, fr, fr_lbls = get_pca_data(rd1, units1, bin_size, step=bin_step, t_start=time_start, t_end=time_end) rates = fr labels = fr_lbls time = fr_t # Again with rec2 fr_t, fr, fr_lbls = get_pca_data(rd2, units2, bin_size, step=bin_step, t_start=time_start, t_end=time_end) rates = np.vstack([rates, fr]) labels = np.vstack([labels, fr_lbls]) # So now rates is tastes*trial*times X units # Do PCA on all data, put in (trials*time)xcells 2D matrix # pca = MDS(n_components=2) pca = PCA(n_components=2) pc_values = pca.fit_transform(rates) mds = MDS(n_components=2) md_values = mds.fit_transform(rates) out_df = pd.DataFrame(labels, columns=['taste', 'trial', 'time']) out_df['n_cells'] = n_cells out_df[['PC1', 'PC2']] = pd.DataFrame(pc_values) out_df[['MDS1', 'MDS2']] = pd.DataFrame(md_values) # Compute the MDS distance metric using the full dimensional solution # For each point computes distance to mean Quinine / distance to mean NaCl mds = MDS(n_components=rates.shape[1]) mds_values = mds.fit_transform(rates) n_idx = np.where(labels[:, 0] == 'NaCl')[0] q_idx = np.where(labels[:, 0] == 'Quinine')[0] q_mean = np.mean(mds_values[q_idx, :], axis=0) n_mean = np.mean(mds_values[n_idx, :], axis=0) dist_metric = [ euclidean(x, q_mean) / euclidean(x, n_mean) for x in mds_values ] assert len( dist_metric) == rates.shape[0], 'computed distances over wrong axis' out_df['dQ_v_dN_fullMDS'] = pd.DataFrame(dist_metric) # Do it again with raw rates q_mean = np.mean(rates[q_idx, :], axis=0) n_mean = np.mean(rates[n_idx, :], axis=0) raw_metric = [euclidean(x, q_mean) / euclidean(x, n_mean) for x in rates] assert len( raw_metric) == rates.shape[0], 'computed distances over wrong axis' out_df['dQ_v_dN_rawRates'] = pd.DataFrame(raw_metric) return out_df
def exp_name(rd): dat = load_dataset(rd) en = dat.data_name.split('_')[0] if en == 'RN5b': en = 'RN5' return en