Example #1
def main(anim):
    rec_dirs, anim_dir = get_rec_dirs(anim)
    for fd in rec_dirs:
        pre_process(fd)

    clustering(anim_dir)

    ## Sort spikes manually, 1 electrode and 1 recording at a time

    for fd in rec_dirs:
        dat = blechpy.load_dataset(fd)
        # Sort each electrode
        # root, ssg = dat.sort_spikes(electrode) 
        dat.cleanup_lowSpiking_units(min_spikes=100)
        dat.units_similarity(shell=True)

    ## Check unit similarity and delete units as necessary
    # Then
    for fd in rec_dirs:
        dat = blechpy.load_dataset(fd)
        dat.make_unit_plots()
        dat.make_unit_arrays()
        dat.make_psth_arrays()

    exp = blechpy.load_experiment(anim_dir)  # experiment object for the animal
    exp.detect_held_units()
    exp.save()
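
# get_rec_dirs is not defined in these examples. A minimal sketch of what
# it presumably does, assuming recordings live in per-animal folders named
# '<anim>_*' beneath a data root (both the root path and the naming
# pattern are assumptions):
import glob
import os

def get_rec_dirs(anim, data_root='/data'):
    anim_dir = os.path.join(data_root, anim)
    rec_dirs = sorted(glob.glob(os.path.join(anim_dir, '%s_*' % anim)))
    return rec_dirs, anim_dir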
Example #2
def clustering(exp):
    # TODO: update this; I now cluster 1 recording at a time
    if isinstance(exp, str):
        fd = exp
        exp = blechpy.load_experiment(fd)
        if exp is None:
            exp = blechpy.experiment(fd)

    dat = blechpy.load_dataset(exp.recording_dirs[0])
    clustering_params = dat.clustering_params.copy()
    clustering_params['clustering_params']['Max Number of Clusters'] = 15
    exp.cluster_spikes(custom_params=clustering_params, umap=True)
    for fd in exp.recording_dirs:
        dat = blechpy.load_dataset(fd)
        dat.cleanup_clustering()
Example #3
    def __init__(self, dat, n_states, save_dir=None, params=None):
        if isinstance(dat, str):
            dat = blechpy.load_dataset(dat)
            if dat is None:
                raise FileNotFoundError('No dataset.p file found in given directory')

        if save_dir is None:
            save_dir = os.path.join(dat.root_dir,
                                    '%s_analysis' % dat.data_name)

        self._dataset = dat
        self.root_dir = dat.root_dir
        self.save_dir = save_dir
        self.n_states = n_states

        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        plot_dir = os.path.join(save_dir, '%i_states' % n_states, 'plots')
        if not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

        self._plot_dir = plot_dir

        self._files = {'hmm_data': os.path.join(save_dir, 'hmm_data.hdf5'),
                       'params' : os.path.join(save_dir, 'hmm_params.json')}
        self._file_check()
        self.update_params(params)

        dig_in_map = dat.dig_in_mapping.query('spike_array == True and exclude == False')
        self._dig_ins = dig_in_map.set_index('name')
        self._fitted_models = {}
        self._setup_hdf5()
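
# Hypothetical usage; this __init__ is excerpted from a class whose name
# is not shown here (HmmAnalysis is a stand-in). Either a dataset
# directory or a loaded blechpy dataset works for `dat`:
# analysis = HmmAnalysis('/data/RN5/RN5_4taste_200101_120000', n_states=3)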
Example #4
def set_electrode_areas(proj, el_in_gc=None):
    if el_in_gc is None:
        el_in_gc = {}

    exp_info = proj._exp_info
    for i, row in exp_info.iterrows():
        name = row['exp_name']
        if name not in el_in_gc:
            continue

        exp = blechpy.load_experiment(row['exp_dir'])
        ingc = el_in_gc[name]
        if ingc == 'right':
            el = np.arange(8, 24)
        elif ingc == 'left':
            el = np.concatenate([np.arange(0, 8), np.arange(24, 32)])
        elif ingc == 'none':
            el = np.arange(0, 32)
        else:
            el = None

        for rec in exp.recording_dirs:
            dat = load_dataset(rec)
            print('Fixing %s...' % dat.data_name)
            em = dat.electrode_mapping
            em['area'] = 'GC'
            if el is not None:
                em.loc[em['Channel'].isin(el), 'area'] = 'STR'

            h5io.write_electrode_map_to_h5(dat.h5_file, em)
            dat.save()

    return
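
# Example call with hypothetical animal names: the value per animal picks
# the set of channels `el` to relabel STR (every channel starts as GC, and
# 'none' relabels all 32 channels to STR).
# set_electrode_areas(proj, el_in_gc={'RN5': 'right', 'RN10': 'left'})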
Example #5
def init_dat(dat):
    if isinstance(dat, str):
        fd = dat
        dat = blechpy.load_dataset(fd)
        if dat is None:
            data_name = os.path.basename(fd).split('_')
            _ = data_name.pop(-1)
            _ = data_name.pop(-1)
            data_name = '_'.join(data_name)
            dat = blechpy.dataset(file_dir=fd, data_name=data_name, shell=True)
            dat.save()

    rec_dir = dat.root_dir
    rec_type = os.path.basename(rec_dir).split('_')[1]

    status = dat.process_status
    if not status['initialize parameters']:
        params = init_params.copy()
        if rec_type == '4taste':
            dig_in_names = ['Water', 'Quinine', 'NaCl', 'Citric Acid']
        else:
            dig_in_names = ['Saccharin']

        params['dig_in_names'] = dig_in_names
        dat.initParams(**params)
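
# Naming convention assumed above: the directory basename ends in
# '_<date>_<time>', which gets stripped to form data_name, and the second
# underscore field is the recording type. For a hypothetical path:
#   init_dat('/data/RN5/RN5_4taste_200101_120000')
# data_name becomes 'RN5_4taste' and rec_type '4taste'.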
Example #6
def run_hmms(rec_dirs, constraint_func=None):
    base_params = {'unit_type': 'single', 'dt': 0.001,
                   'max_iter': 200, 'n_repeats': 50, 'time_start': -250,
                   'time_end': 2000, 'n_states': 3, 'area': 'GC',
                   'hmm_class': 'PoissonHMM', 'threshold': 1e-10,
                   'notes': 'sequential - low thresh'}

    params = [{'n_states': 2}, {'n_states': 3}, {'time_start': -200, 'n_states': 4}]

    for rec_dir in rec_dirs:
        units = phmm.query_units(rec_dir, 'single', area='GC')
        if len(units) < 2:
            continue
        handler = phmm.HmmHandler(rec_dir)
        dat = load_dataset(rec_dir)

        for i, row in dat.dig_in_mapping.iterrows():
            if row['laser']:
                continue

            name = row['name']
            ch = row['channel']
            on_trials, off_trials = get_laser_trials(rec_dir, ch)

            for new_params in params:
                for note, trials in [(' - all_trials', None),
                                     (' - on_trials', on_trials),
                                     (' - off_trials', off_trials)]:
                    p = base_params.copy()
                    p.update(new_params)
                    p['taste'] = name
                    p['channel'] = ch
                    p['notes'] += note
                    if trials is not None:
                        p['trial_nums'] = trials

                    handler.add_params(p)

        dataname = os.path.basename(rec_dir)
        print('Fitting %s' % dataname)
        if callable(constraint_func):
            print('Fitting Constraint: %s' % constraint_func.__name__)
        else:
            print('Fitting Constraint: %s' % str(constraint_func))

        handler.run(constraint_func=constraint_func)
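
# Hypothetical invocation; the recording path is a placeholder, and
# constraint_func may be None for unconstrained fits (its expected
# signature is whatever phmm.HmmHandler.run accepts, not shown here):
# run_hmms(['/data/RN5/RN5_4taste_200101_120000'], constraint_func=None)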
Example #7
def post_process(dat):
    if isinstance(dat, str):
        fd = dat
        dat = blechpy.load_dataset(fd)
        if dat is None:
            raise FileNotFoundError('Dataset for %s not found.' % fd)

    dat.cleanup_lowSpiking_units(min_spikes=100)
    dat.units_similarity(shell=True)
    dat.make_unit_plots()
    dat.make_unit_arrays()
    dat.make_psth_arrays()
Example #8
    def __init__(self, dat, params=None, save_dir=None):
        '''Takes a blechpy dataset object and fits HMMs for each tastant

        Parameters
        ----------
        dat: blechpy.dataset
        params: dict or list of dicts
            each dict must have fields:
                time_window: list of int, time window to cut around stimuli in ms
                convergence_thresh: float
                max_iter: int
                n_repeats: int
                unit_type: str, {'single', 'pyramidal', 'interneuron', 'all'}
                bin_size: float, time bin in seconds for the spike array used in fitting
                n_states: int, predicted number of states to fit
        '''
        if isinstance(params, dict):
            params = [params]

        if isinstance(dat, str):
            dat = blechpy.load_dataset(dat)
            if dat is None:
                raise FileNotFoundError('No dataset.p file found in given directory')

        if save_dir is None:
            save_dir = os.path.join(dat.root_dir,
                                    '%s_analysis' % dat.data_name)

        self._dataset = dat
        self.root_dir = dat.root_dir
        self.save_dir = save_dir
        self.h5_file = os.path.join(save_dir, '%s_HMM_Analysis.hdf5' % dat.data_name)
        dim = dat.dig_in_mapping.query('exclude == False')
        tastes = dim['name'].tolist()
        if params is None:
            # Load previously saved params and fitted models
            self.load_data()
        else:
            self.init_params(params)
            self.params = params

        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        self.plot_dir = os.path.join(save_dir, 'HMM_Plots')
        if not os.path.isdir(self.plot_dir):
            os.makedirs(self.plot_dir)

        self._setup_hdf5()
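
# A params dict with the fields the docstring requires; the values here
# are illustrative assumptions, not recommended settings:
example_params = {'time_window': [-250, 2000],   # ms around stimulus
                  'convergence_thresh': 1e-4,
                  'max_iter': 1000,
                  'n_repeats': 25,
                  'unit_type': 'single',
                  'bin_size': 0.01,              # seconds
                  'n_states': 3}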
Example #9
def fix_palatability(proj, pal_map=None):
    '''Goes through all datasets in project and fixes palatability rankings
    '''
    if pal_map is None:
        pal_map = PAL_MAP

    exp_dirs = proj._exp_info.exp_dir.to_list()

    for exp_dir in tqdm(exp_dirs):
        exp = blechpy.load_experiment(exp_dir)
        for rd in exp.recording_dirs:
            dat = load_dataset(rd)
            dim = dat.dig_in_mapping
            dim['palatability_rank'] = dim['name'].map(pal_map)
            h5io.write_digital_map_to_h5(dat.h5_file, dim, 'in')
            dat.save()
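
# PAL_MAP is defined elsewhere; a hypothetical mapping from tastant name
# to palatability rank (higher = more palatable) might look like:
# PAL_MAP = {'Quinine': 1, 'Citric Acid': 2, 'Water': 3, 'NaCl': 4,
#            'Saccharin': 5}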
Example #10
def get_all_units(proj):
    # Columns:
    #   - exp_name, exp_group, rec_name, rec_group, rec_dir,
    #   - unit_name, unit_num, electrode, area, single_unit,
    #   - regular_spiking, fast_spiking
    all_units = pd.DataFrame(columns=[
        'exp_name', 'exp_group', 'rec_name', 'rec_group', 'rec_dir',
        'unit_name', 'unit_num', 'electrode', 'area', 'single_unit',
        'regular_spiking', 'fast_spiking'
    ])
    for i, row in proj._exp_info.iterrows():
        exp_name = row['exp_name']
        exp_group = row['exp_group']
        exp_dir = row['exp_dir']
        exp = blechpy.load_experiment(exp_dir)
        for rec_name, rec_dir in exp.rec_labels.items():
            if 'preCTA' in rec_name:
                rec_group = 'preCTA'
            elif 'postCTA' in rec_name:
                rec_group = 'postCTA'
            elif 'Train' in rec_name:
                rec_group = 'ctaTrain'
            elif 'Test' in rec_name:
                rec_group = 'ctaTest'
            else:
                # TODO: Make more elegant, ask for input
                raise ValueError('Rec %s does not fit into a group' % rec_name)

            dat = load_dataset(rec_dir)
            units = dat.get_unit_table().copy()
            units['exp_name'] = exp_name
            units['exp_group'] = exp_group
            units['rec_name'] = rec_name
            units['rec_group'] = rec_group
            units['rec_dir'] = rec_dir

            em = dat.electrode_mapping.copy().set_index('Electrode')
            units['area'] = units['electrode'].map(em['area'])
            units = units[all_units.columns]
            all_units = pd.concat([all_units, units], ignore_index=True)

    return all_units
Example #11
def pre_process(dat, dead_ch=None):
    if dead_ch is None:
        dead_ch = []

    if isinstance(dat, str):
        fd = dat
        dat = blechpy.load_dataset(fd)

    if dat is None:
        raise ValueError('No Dataset found')

    status = dat.process_status
    if not status['extract_data']:
        dat.extract_data()

    if not status['create_trial_list']:
        dat.create_trial_list()

    if not status['mark_dead_channels']:
        dat.mark_dead_channels(dead_ch)

    if not status['common_average_reference']:
        dat.common_average_reference()

    if not status['spike_detection']:
        dat.detect_spikes()
Example #12
def get_pca_data(rec,
                 units,
                 bin_size,
                 step=None,
                 t_start=None,
                 t_end=None,
                 baseline_win=None):
    '''Gets spike data, turns it into binned firing rate traces, and
    organizes them into a format for PCA (trials*time X units)

    Parameters
    ----------
    rec : str, path to recording directory
    units : list of str, names of units to include
    bin_size : int, width of firing rate bins in ms
    step : int (optional), step between bins in ms; defaults to bin_size
    t_start : int (optional), start of analysis window in ms; defaults to
        the start of the spike array
    t_end : int (optional), end of analysis window in ms; defaults to the
        end of the spike array
    baseline_win : currently unused

    Returns
    -------
    fr_time : np.ndarray, time vector for the binned firing rates
    fr_out : np.ndarray, baseline-subtracted firing rates arranged as
        (trials*time) x units
    fr_lbls : np.ndarray, rows of [taste, trial, time] labeling fr_out

    Raises
    ------
    ValueError
        if the time vectors of different trials do not match
    '''
    if step is None:
        step = bin_size

    st, sa = h5io.get_spike_data(rec, units)
    if t_start is None:
        t_start = st[0]

    if t_end is None:
        t_end = st[-1]

    spikes = []
    labels = []
    dim = load_dataset(rec).dig_in_mapping.set_index('channel')
    if isinstance(sa, dict):
        for k, v in sa.items():
            ch = int(k.split('_')[-1])
            tst = dim.loc[ch, 'name']
            l = [(tst, i) for i in range(v.shape[0])]
            if len(v.shape) == 2:
                tmp = np.expand_dims(v, 1)
            else:
                tmp = v

            labels.extend(l)
            spikes.append(tmp)

    else:
        if len(sa.shape) == 2:
            tmp = np.expand_dims(sa, 1)
        else:
            tmp = sa

        spikes.append(tmp)
        tst = dim.loc[0, 'name']
        l = [(tst, i) for i in range(sa.shape[0])]
        labels.extend(l)

    spikes = np.vstack(spikes).astype('float64')
    b_idx = np.where(st < 0)[0]
    baseline_fr = np.sum(spikes[:, :, b_idx], axis=-1) / (len(b_idx) / 1000)
    baseline_fr = np.mean(baseline_fr, axis=0)
    t_idx = np.where((st >= t_start) & (st <= t_end))[0]
    spikes = spikes[:, :, t_idx]
    st = st[t_idx]

    fr_lbls = []
    fr_arr = []
    fr_time = None
    for trial_i, (trial, lbl) in enumerate(zip(spikes, labels)):
        fr_t, fr = sas.get_binned_firing_rate(st, trial, bin_size, step)
        # fr is units x time
        fr = fr.T  # now its time x units
        fr = fr - baseline_fr  # subtract baseline firing rate
        l = [(*lbl, t) for t in fr_t]
        fr_arr.append(fr)
        fr_lbls.extend(l)
        if fr_time is None:
            fr_time = fr_t
        elif not np.array_equal(fr_t, fr_time):
            raise ValueError('Time vectors do not match')

    # So now fr_lbls is [taste, trial, time]
    fr_out = np.vstack(fr_arr)
    fr_lbls = np.array(fr_lbls)
    return fr_time, fr_out, fr_lbls
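
# Shape check for the (trials*time) x units layout returned above: each
# trial contributes a (time x units) block and np.vstack stacks them, so
# with hypothetical sizes of 4 trials, 10 time bins and 5 units:
import numpy as np
blocks = [np.zeros((10, 5)) for _ in range(4)]  # one block per trial
stacked = np.vstack(blocks)
assert stacked.shape == (4 * 10, 5)  # (trials*time, units)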
Example #13
def apply_pca_analysis(df, params):
    '''Applies PCA and MDS analyses to units held across a pair of recordings.

    df is the held_units dataframe grouped by exp_name, exp_group and
    time_group; it only contains units held over preCTA or postCTA, no
    units held from ctaTrain to ctaTest.

    Parameters
    ----------
    df : pd.DataFrame, one group of the held_units dataframe; must
        contain columns rec1, rec2, unit1 and unit2
    params : dict, must contain a 'pca' dict with keys win_size,
        step_size, time_win and smoothing_win

    Returns
    -------
    pd.DataFrame with columns taste, trial, time, n_cells, PC1, PC2,
    MDS1, MDS2, dQ_v_dN_fullMDS and dQ_v_dN_rawRates, or None if there
    are too few units

    Raises
    ------
    ValueError
        if df spans more than one recording directory per column
    '''
    bin_size = params['pca']['win_size']
    bin_step = params['pca']['step_size']
    time_start = params['pca']['time_win'][0]
    time_end = params['pca']['time_win'][1]
    smoothing = params['pca']['smoothing_win']
    n_cells = len(df)

    rd1 = df['rec1'].unique()
    rd2 = df['rec2'].unique()
    if len(rd1) > 1 or len(rd2) > 1:
        raise ValueError('Too many recording directories')

    rd1 = rd1[0]
    rd2 = rd2[0]
    units1 = list(df['unit1'].unique())
    units2 = list(df['unit2'].unique())
    dim1 = load_dataset(rd1).dig_in_mapping.set_index('channel')
    dim2 = load_dataset(rd2).dig_in_mapping.set_index('channel')
    if n_cells < 2:
        # No point if only 1 unit
        exp_name = os.path.basename(rd1).split('_')
        print('%s - %s: Not enough units for PCA analysis' %
              (exp_name[0], exp_name[-3]))
        return

    time, sa = h5io.get_spike_data(rd1, units1)
    fr_t, fr, fr_lbls = get_pca_data(rd1,
                                     units1,
                                     bin_size,
                                     step=bin_step,
                                     t_start=time_start,
                                     t_end=time_end)
    rates = fr
    labels = fr_lbls
    time = fr_t
    # Again with rec2
    fr_t, fr, fr_lbls = get_pca_data(rd2,
                                     units2,
                                     bin_size,
                                     step=bin_step,
                                     t_start=time_start,
                                     t_end=time_end)
    rates = np.vstack([rates, fr])
    labels = np.vstack([labels, fr_lbls])
    # So now rates is tastes*trial*times X units

    # Do PCA on all data, put in (trials*time)xcells 2D matrix
    # pca = MDS(n_components=2)
    pca = PCA(n_components=2)
    pc_values = pca.fit_transform(rates)
    mds = MDS(n_components=2)
    md_values = mds.fit_transform(rates)

    out_df = pd.DataFrame(labels, columns=['taste', 'trial', 'time'])
    out_df['n_cells'] = n_cells
    out_df[['PC1', 'PC2']] = pd.DataFrame(pc_values)
    out_df[['MDS1', 'MDS2']] = pd.DataFrame(md_values)

    # Compute the MDS distance metric using the full dimensional solution
    # For each point computes distance to mean Quinine / distance to mean NaCl
    mds = MDS(n_components=rates.shape[1])
    mds_values = mds.fit_transform(rates)
    n_idx = np.where(labels[:, 0] == 'NaCl')[0]
    q_idx = np.where(labels[:, 0] == 'Quinine')[0]
    q_mean = np.mean(mds_values[q_idx, :], axis=0)
    n_mean = np.mean(mds_values[n_idx, :], axis=0)
    dist_metric = [
        euclidean(x, q_mean) / euclidean(x, n_mean) for x in mds_values
    ]
    assert len(dist_metric) == rates.shape[0], 'computed distances over wrong axis'
    out_df['dQ_v_dN_fullMDS'] = pd.DataFrame(dist_metric)

    # Do it again with raw rates
    q_mean = np.mean(rates[q_idx, :], axis=0)
    n_mean = np.mean(rates[n_idx, :], axis=0)
    raw_metric = [euclidean(x, q_mean) / euclidean(x, n_mean) for x in rates]
    assert len(raw_metric) == rates.shape[0], 'computed distances over wrong axis'
    out_df['dQ_v_dN_rawRates'] = pd.DataFrame(raw_metric)

    return out_df
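
# Hedged usage sketch based on the docstring: held_df stands in for the
# held-units dataframe described above (the grouping keys come from the
# docstring; other column names are assumptions):
# pca_df = held_df.groupby(['exp_name', 'exp_group', 'time_group']).apply(
#     lambda g: apply_pca_analysis(g, params))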
Example #14
def exp_name(rd):
    dat = load_dataset(rd)
    en = dat.data_name.split('_')[0]
    if en == 'RN5b':
        en = 'RN5'
    return en