Ejemplo n.º 1
0
    def edit_clustering_parameters(self, shell=False):
        '''Open a user prompt for editing the clustering parameters

        Updates ``self.clust_params`` in place only if the user submits a
        non-empty result; a cancelled prompt leaves the parameters untouched.

        Parameters
        ----------
        shell : bool (optional)
            True if you want command-line interface, False for GUI (default)
        '''
        filler = userIO.dictIO(self.clust_params, shell=shell)
        new_params = filler.fill_dict()
        if new_params:
            self.clust_params = new_params
Ejemplo n.º 2
0
    def edit_psth_parameters(self, shell=False):
        '''Open a user prompt for editing the PSTH parameters

        Updates ``self.psth_params`` in place only if the user submits a
        non-empty result; a cancelled prompt leaves the parameters untouched.

        Parameters
        ----------
        shell : bool (optional)
            True if you want command-line interface, False for GUI (default)
        '''
        filler = userIO.dictIO(self.psth_params, shell=shell)
        new_params = filler.fill_dict('Edit params for making PSTHs\n'
                                      'All times are in ms')
        if new_params:
            self.psth_params = new_params
    def __init__(self, exp_dir=None, shell=False):
        '''Setup for analysis across recording sessions

        Collects the recording directories under exp_dir, asks the user for
        their ordering, and builds an electrode mapping marking which
        electrodes are confirmed to be in gustatory cortex (GC).

        Parameters
        ----------
        exp_dir : str (optional)
            path to directory containing all recording directories
            if None (default) is passed then a popup to choose file
            will come up
        shell : bool (optional)
            True to use command-line interface for user input
            False (default) for GUI
        '''
        if exp_dir is None:
            exp_dir = eg.diropenbox('Select Experiment Directory',
                                    'Experiment Directory')
            # User cancelled the directory picker; leave object uninitialized
            if exp_dir is None or exp_dir == '':
                return

        # Every immediate subdirectory is a candidate recording directory
        fd = [os.path.join(exp_dir, x) for x in os.listdir(exp_dir)]
        file_dirs = [x for x in fd if os.path.isdir(x)]
        order_dict = dict.fromkeys(file_dirs, 0)
        tmp = userIO.dictIO(order_dict, shell=shell)
        # NOTE(review): the '%i' in this prompt is never interpolated with the
        # directory count, so it is shown literally to the user
        order_dict = tmp.fill_dict(prompt=('Set order of recordings (1-%i)\n'
                                           'Leave blank to delete directory'
                                           ' from list'))
        if order_dict is None:
            return

        # Drop blank/zero entries, then sort remaining dirs by assigned order
        file_dirs = [k for k, v in order_dict.items()
                     if v is not None and v != 0]
        file_dirs = sorted(file_dirs, key=order_dict.get)
        # Entries are already absolute paths under exp_dir, so this join is
        # effectively a no-op kept for safety
        file_dirs = [os.path.join(exp_dir, x) for x in file_dirs]
        # Strip trailing slashes so os.path.basename behaves as expected later
        file_dirs = [x[:-1] if x.endswith('/') else x
                     for x in file_dirs]
        self.recording_dirs = file_dirs
        self.experiment_dir = exp_dir
        self.shell = shell

        # Use the first recording's electrode map as the experiment-wide map
        dat = dataset.load_dataset(file_dirs[0])
        em = dat.electrode_mapping.copy()
        # BUGFIX: corrected 'eletrodes' typo in the user-facing prompt
        ingc = userIO.select_from_list('Select all electrodes confirmed in GC',
                                       em['Electrode'],
                                       multi_select=True, shell=shell)
        # BUGFIX: select_from_list returns None on cancel; treat as "none in GC"
        # instead of crashing in map(int, ...)
        if ingc is None:
            ingc = []
        ingc = list(map(int, ingc))
        em['Area'] = np.where(em['Electrode'].isin(ingc), 'GC', 'Other')
        self.electrode_mapping = em
        self.save_file = os.path.join(exp_dir, '%s_experiment.p'
                                      % os.path.basename(exp_dir))
Ejemplo n.º 4
0
 def edit_spike_array_parameters(self, shell=False):
     '''Open a user prompt for editing the spike array parameters

     Converts the 'dig_ins_to_use' and 'laser_channels' entries to lists of
     ints (dropping empty strings) before storing on the object.

     Parameters
     ----------
     shell : bool (optional)
         True for command-line interface, False (default) for GUI

     Returns
     -------
     dict or None
         the updated parameters, or None if the user cancelled
     '''
     filler = userIO.dictIO(self.spike_array_params, shell=shell)
     tmp = filler.fill_dict(prompt=('Input desired parameters'
                                    ' (Times are in ms'))
     if tmp is None:
         return None

     tmp['dig_ins_to_use'] = [int(v) for v in tmp['dig_ins_to_use']
                              if v != '']
     tmp['laser_channels'] = [int(v) for v in tmp['laser_channels']
                              if v != '']
     self.spike_array_params = tmp
     return tmp
def get_palatability_ranks(dig_in_mapping, shell=True):
    '''Queries user for palatability rankings for digital inputs (tastants) and
    adds a 'palatability_rank' column to a copy of dig_in_mapping

    Parameters
    ----------
    dig_in_mapping: pandas.DataFrame,
        DataFrame with at least columns 'dig_in' and 'name', for mapping
        digital input channel number to a str name
    shell : bool (optional)
        True (default) for command-line interface, False for GUI

    Returns
    -------
    pandas.DataFrame
        copy of dig_in_mapping with a 'palatability_rank' column added;
        if the user cancels the prompt the copy is returned without the
        new column
    '''
    dim = dig_in_mapping.copy()
    tmp = dict.fromkeys(dim['name'], 0)
    filler = userIO.dictIO(tmp, shell=shell)
    tmp = filler.fill_dict('Rank Palatability\n1 for the lowest\n'
                           'Leave blank to exclude from palatability analysis')
    # BUGFIX: fill_dict returns None on cancel; Series.map(None) would raise,
    # so return the untouched copy instead
    if tmp is None:
        return dim
    dim['palatability_rank'] = dim['name'].map(tmp)
    return dim
    def _order_dirs(self, shell=None):
        '''Set order of recording directories via user prompt

        Asks the user to assign an order number to each directory in
        ``self.recording_dirs``; directories left blank or set to 0 are
        removed from the list. Updates ``self.recording_dirs`` in place.
        Returns early (no change) if the user cancels the prompt.
        '''
        if shell is None:
            shell = self.shell
        # Strip trailing slashes so os.path.basename/dirname split cleanly
        self.recording_dirs = [x[:-1] if x.endswith('/') else x
                               for x in self.recording_dirs]
        # Map basename -> parent dir so prompts show short names but full
        # paths can be rebuilt afterwards
        top_dirs = {os.path.basename(x): os.path.dirname(x)
                    for x in self.recording_dirs}
        file_dirs = list(top_dirs.keys())
        order_dict = dict.fromkeys(file_dirs, 0)
        tmp = userIO.dictIO(order_dict, shell=shell)
        # NOTE(review): the '%i' here is never interpolated with the
        # directory count and is displayed literally
        order_dict = tmp.fill_dict(prompt=('Set order of recordings (1-%i)\n'
                                           'Leave blank to delete directory'
                                           ' from list'))
        if order_dict is None:
            return

        # Keep only entries the user gave a non-zero order, sorted by order
        file_dirs = [k for k, v in order_dict.items()
                     if v is not None and v != 0]
        file_dirs = sorted(file_dirs, key=order_dict.get)
        # Rebuild full paths from the basename -> parent mapping
        file_dirs = [os.path.join(top_dirs.get(x), x) for x in file_dirs]
        self.recording_dirs = file_dirs
Ejemplo n.º 7
0
def get_cell_types(cluster_names, shell=True):
    '''Queries user to identify cluster as multiunit vs single-unit, regular vs
    fast spiking

    Parameters
    ----------
    cluster_names : list of str
        names of the clusters to classify, one prompt entry per cluster
    shell : bool (optional),
        True if command-line interface desired, False for GUI (default)

    Returns
    -------
    list of dict or None
        one dict per cluster (in cluster_names order) with keys
        'single_unit', 'regular_spiking', 'fast_spiking' and values 0 or 1;
        None if the user cancelled the prompt
    '''
    # Per-cluster template of the questions asked
    template = {
        'Single Unit': False,
        'Regular Spiking': False,
        'Fast Spiking': False
    }
    query = userIO.dictIO({name: template.copy() for name in cluster_names},
                          shell=shell)
    ans = query.fill_dict()
    if ans is None:
        return None
    # Convert the boolean answers to 0/1 flags, one dict per cluster
    return [{'single_unit': int(ans[name]['Single Unit']),
             'regular_spiking': int(ans[name]['Regular Spiking']),
             'fast_spiking': int(ans[name]['Fast Spiking'])}
            for name in cluster_names]
Ejemplo n.º 8
0
    def initParams(self, data_quality='clean', emg_port=None,
                   emg_channels=None, shell=False, dig_in_names=None,
                   dig_out_names=None,
                   spike_array_params=None,
                   psth_params=None,
                   confirm_all=False):
        '''
        Initializes basic default analysis parameters that can be customized
        before running processing methods
        Can provide data_quality as 'clean' or 'noisy' to preset some
        parameters that are useful for the different types. Best practice is to
        run as clean (default) and to re-run as noisy if you notice that a lot
        of electrodes are cutoff early

        Parameters
        ----------
        data_quality : {'clean', 'noisy'} (optional)
            key into dio.params.data_params presets
        emg_port : str (optional)
            amplifier port carrying EMG channels; if None and not shell, the
            user is asked via GUI whether an EMG exists
        emg_channels : list (optional)
            channels on emg_port used for EMG
        shell : bool (optional)
            True for command-line interface, False (default) for GUI
        dig_in_names : list of str (optional)
            names for digital inputs; queried from user if None
        dig_out_names : list of str (optional)
            names for digital outputs; queried from user if None
        spike_array_params : dict (optional)
            spike array parameters; defaults from dio.params if None
        psth_params : dict (optional)
            PSTH parameters; defaults from dio.params if None
        confirm_all : bool (optional)
            True to skip the confirmation/edit prompts at the end

        Raises
        ------
        ValueError
            if the user does not name every digital input or output
        '''

        # Get parameters from info.rhd
        file_dir = self.data_dir
        rec_info = dio.rawIO.read_rec_info(file_dir, shell)
        ports = rec_info.pop('ports')
        channels = rec_info.pop('channels')
        sampling_rate = rec_info['amplifier_sampling_rate']
        self.rec_info = rec_info
        self.sampling_rate = sampling_rate

        # Get default parameters for blech_clust; deepcopy so edits do not
        # mutate the shared module-level defaults
        clustering_params = deepcopy(dio.params.clustering_params)
        data_params = deepcopy(dio.params.data_params[data_quality])
        bandpass_params = deepcopy(dio.params.bandpass_params)
        spike_snapshot = deepcopy(dio.params.spike_snapshot)
        if spike_array_params is None:
            spike_array_params = deepcopy(dio.params.spike_array_params)
        if psth_params is None:
            psth_params = deepcopy(dio.params.psth_params)

        # Ask for emg port & channels (GUI only; shell users must pass them in)
        if emg_port is None and not shell:
            q = eg.ynbox('Do you have an EMG?', 'EMG')
            if q:
                emg_port = userIO.select_from_list('Select EMG Port:',
                                                   ports, 'EMG Port',
                                                   shell=shell)
                emg_channels = userIO.select_from_list(
                    'Select EMG Channels:',
                    [y for x, y in
                     zip(ports, channels)
                     if x == emg_port],
                    title='EMG Channels',
                    multi_select=True, shell=shell)

        elif emg_port is None and shell:
            print('\nNo EMG port given.\n')

        electrode_mapping, emg_mapping = dio.params.flatten_channels(
            ports,
            channels,
            emg_port=emg_port,
            emg_channels=emg_channels)
        self.electrode_mapping = electrode_mapping
        self.emg_mapping = emg_mapping

        # Get digital input names and spike array parameters
        if rec_info.get('dig_in'):
            if dig_in_names is None:
                dig_in_names = dict.fromkeys(['dig_in_%i' % x
                                              for x in rec_info['dig_in']])
                name_filler = userIO.dictIO(dig_in_names, shell=shell)
                dig_in_names = name_filler.fill_dict('Enter names for '
                                                     'digital inputs:')
                if dig_in_names is None or \
                   any([x is None for x in dig_in_names.values()]):
                    raise ValueError('Must name all dig_ins')

                dig_in_names = list(dig_in_names.values())

            if spike_array_params['laser_channels'] is None:
                laser_dict = dict.fromkeys(dig_in_names, False)
                laser_filler = userIO.dictIO(laser_dict, shell=shell)
                laser_dict = laser_filler.fill_dict('Select any lasers:')
                if laser_dict is None:
                    # User cancelled: assume no lasers
                    laser_channels = []
                    # BUGFIX: keep laser_dict a dict so the .get() lookup
                    # when building the dig_in mapping below cannot crash
                    laser_dict = {}
                else:
                    laser_channels = [i for i, v
                                      in zip(rec_info['dig_in'],
                                             laser_dict.values()) if v]

                spike_array_params['laser_channels'] = laser_channels
            else:
                # BUGFIX: laser_channels was never assigned on this branch,
                # causing a NameError below when dig_ins_to_use is None
                laser_channels = spike_array_params['laser_channels']
                laser_dict = dict.fromkeys(dig_in_names, False)
                for lc in laser_channels:
                    # NOTE(review): indexes dig_in_names by channel number;
                    # assumes dig_in channels are numbered 0..n-1 -- confirm
                    laser_dict[dig_in_names[lc]] = True


            if spike_array_params['dig_ins_to_use'] is None:
                # Default to every non-laser dig_in, then let the user prune
                di = [x for x in rec_info['dig_in']
                      if x not in laser_channels]
                dn = [dig_in_names[x] for x in di]
                spike_dig_dict = dict.fromkeys(dn, True)
                filler = userIO.dictIO(spike_dig_dict, shell=shell)
                spike_dig_dict = filler.fill_dict('Select digital inputs '
                                                  'to use for making spike'
                                                  ' arrays:')
                if spike_dig_dict is None:
                    spike_dig_ins = []
                else:
                    spike_dig_ins = [x for x, y in
                                     zip(di, spike_dig_dict.values())
                                     if y]

                spike_array_params['dig_ins_to_use'] = spike_dig_ins


            dim = pd.DataFrame([(x, y) for x, y in zip(rec_info['dig_in'],
                                                       dig_in_names)],
                               columns=['dig_in', 'name'])
            dim['laser'] = dim['name'].apply(lambda x: laser_dict.get(x))
            self.dig_in_mapping = dim.copy()

        # Get digital output names
        if rec_info.get('dig_out'):
            if dig_out_names is None:
                dig_out_names = dict.fromkeys(['dig_out_%i' % x
                                              for x in rec_info['dig_out']])
                name_filler = userIO.dictIO(dig_out_names, shell=shell)
                dig_out_names = name_filler.fill_dict('Enter names for '
                                                      'digital outputs:')
                if dig_out_names is None or \
                   any([x is None for x in dig_out_names.values()]):
                    raise ValueError('Must name all dig_outs')

                dig_out_names = list(dig_out_names.values())

            self.dig_out_mapping = pd.DataFrame([(x, y) for x, y in
                                                 zip(rec_info['dig_out'],
                                                     dig_out_names)],
                                                columns=['dig_out', 'name'])

        # Store clustering parameters
        self.clust_params = {'file_dir': file_dir,
                             'data_quality': data_quality,
                             'sampling_rate': sampling_rate,
                             'clustering_params': clustering_params,
                             'data_params': data_params,
                             'bandpass_params': bandpass_params,
                             'spike_snapshot': spike_snapshot}

        # Store and confirm spike array parameters
        spike_array_params['sampling_rate'] = sampling_rate
        self.spike_array_params = spike_array_params
        self.psth_params = psth_params
        if not confirm_all:
            prompt = ('\n----------\nSpike Array Parameters\n----------\n'
                      + dp.print_dict(spike_array_params) +
                      '\nAre these parameters good?')
            q_idx = userIO.ask_user(prompt, ('Yes', 'Edit'), shell=shell)
            if q_idx == 1:
                self.edit_spike_array_parameters(shell=shell)

            # Edit and store psth parameters
            prompt = ('\n----------\nPSTH Parameters\n----------\n'
                      + dp.print_dict(psth_params) +
                      '\nAre these parameters good?')
            q_idx = userIO.ask_user(prompt, ('Yes', 'Edit'), shell=shell)
            if q_idx == 1:
                self.edit_psth_parameters(shell=shell)

        self.save()
Ejemplo n.º 9
0
def edit_clusters(clusters, fs, shell=False):
    '''Handles editing of a cluster group until a user has a single cluster
    they are satisfied with

    Parameters
    ----------
    clusters : list of dict
        list of dictionaries each defining clusters of spikes
    fs : float
        sampling rate in Hz
    shell : bool
        set True if command-line interface is desires, False for GUI (default)

    Returns
    -------
    dict
        dict representing the resulting cluster from manipulations, None if
        aborted
    '''
    # Work on a copy so the caller's cluster dicts are never mutated
    clusters = deepcopy(clusters)
    quit_flag = False
    while not quit_flag:
        if len(clusters) == 1:
            # One cluster, ask if they want to keep or split
            idx = userIO.ask_user('What would you like to do with %s?' %
                                  clusters[0]['Cluster Name'],
                                  choices=[
                                      'Split', 'Waveform View', 'Raster View',
                                      'Keep', 'Abort'
                                  ],
                                  shell=shell)
            if idx == 1:
                # Waveform View: show the plot and re-prompt
                fig = plot_cluster(clusters[0])
                plt.show()
                continue
            elif idx == 2:
                # Raster View: show the plot and re-prompt
                plot_raster(clusters)
                plt.show()
                continue
            elif idx == 3:
                # Keep: done, return the single cluster dict
                return clusters[0]
            elif idx == 4:
                # Abort
                return None

            # Fall-through (idx == 0): split the cluster, keeping the
            # original around so the split can be undone
            old_clust = clusters[0]
            clusters = split_cluster(clusters[0], fs, shell=shell)
            if clusters is None or clusters == []:
                return None
            # Keep figure references alive while the user inspects them
            figs = []
            for i, c in enumerate(clusters):
                tmp_fig = plot_cluster(c, i)
                figs.append(tmp_fig)

            f2, ax2 = plot_pca_view(clusters)
            plt.show()
            query = {'Clusters to keep (indices)': []}
            query = userIO.dictIO(query, shell=shell)
            ans = query.fill_dict()
            # NOTE(review): if fill_dict returns None or an empty list this
            # indexing raises -- confirm fill_dict always returns a
            # non-empty list here
            if ans['Clusters to keep (indices)'][0] == '':
                # Blank answer: undo the split
                print('Reset to before split')
                clusters = [old_clust]
                continue

            # Keep only the user-selected sub-clusters and loop again
            ans = [int(x) for x in ans['Clusters to keep (indices)']]
            new_clusters = [clusters[x] for x in ans]
            del clusters
            clusters = new_clusters
            del new_clusters
        else:
            # Multiple clusters: offer merge instead of split
            idx = userIO.ask_user(
                ('You have %i clusters. '
                 'What would you like to do?') % len(clusters),
                choices=[
                    'Merge', 'Waveform View', 'Raster View', 'Keep', 'Abort'
                ],
                shell=shell)
            if idx == 0:
                # Automatically merge multiple clusters
                cluster = merge_clusters(clusters, fs)
                print('%i clusters merged into %s' %
                      (len(clusters), cluster['Cluster Name']))
                clusters = [cluster]
            elif idx == 1:
                # Waveform View for every cluster, then re-prompt
                for c in clusters:
                    plot_cluster(c)

                plt.show()
                continue
            elif idx == 2:
                # Raster View, then re-prompt
                plot_raster(clusters)
                plt.show()
                continue
            elif idx == 3:
                # Keep: returns the LIST of clusters (callers handle both
                # dict and list returns)
                return clusters
            elif idx == 4:
                # Abort
                return None
Ejemplo n.º 10
0
def split_cluster(cluster, fs, params=None, shell=True):
    '''Use GMM to re-cluster a single cluster into a number of sub clusters

    Parameters
    ----------
    cluster : dict, cluster metrics and data
    fs : float, sampling rate in Hz
        used to recompute ISI and violation metrics for each sub-cluster
    params : dict (optional), parameters for reclustering
        with keys: 'Number of Clusters', 'Max Number of Iterations',
                    'Convergence Criterion', 'GMM random restarts'
        queried from the user if None
    shell : bool , optional, whether to use shell or GUI input (default=True)

    Returns
    -------
    clusters : list of dicts
        resultant clusters from split; empty list if the GMM did not
        converge, None if the user aborted parameter entry
    '''
    if params is None:
        clustering_params = {
            'Number of Clusters': 2,
            'Max Number of Iterations': 1000,
            'Convergence Criterion': 0.00001,
            'GMM random restarts': 10
        }
        params_filler = userIO.dictIO(clustering_params, shell=shell)
        params = params_filler.fill_dict()

    # fill_dict returns None on user cancel
    if params is None:
        return None

    # Values may come back as strings from the UI; coerce explicitly
    n_clusters = int(params['Number of Clusters'])
    n_iter = int(params['Max Number of Iterations'])
    thresh = float(params['Convergence Criterion'])
    n_restarts = int(params['GMM random restarts'])

    g = GaussianMixture(n_components=n_clusters,
                        covariance_type='full',
                        tol=thresh,
                        max_iter=n_iter,
                        n_init=n_restarts)
    g.fit(cluster['data'])

    out_clusters = []
    if g.converged_:
        spike_waveforms = cluster['spike_waveforms']
        spike_times = cluster['spike_times']
        data = cluster['data']
        predictions = g.predict(data)
        for c in range(n_clusters):
            clust_idx = np.where(predictions == c)[0]
            tmp_clust = deepcopy(cluster)
            # Sub-clusters are suffixed 'A', 'B', ... after the parent id
            # (clearer equivalent of bytes([b'A'[0]+c]).decode('utf-8'))
            clust_id = str(cluster['cluster_id']) + chr(ord('A') + c)
            clust_name = 'E%iS%i_cluster%s' % \
                (cluster['electrode'], cluster['solution'], clust_id)
            clust_waveforms = spike_waveforms[clust_idx]
            clust_times = spike_times[clust_idx]
            clust_data = data[clust_idx, :]
            # Append a record of this manipulation to the cluster's log
            clust_log = cluster['manipulations'] + \
                '\nSplit %s with parameters: ' \
                '\n%s\nCluster %i from split results. Named %s' \
                % (cluster['Cluster Name'], dp.print_dict(params),
                   c, clust_name)
            tmp_clust['Cluster Name'] = clust_name
            tmp_clust['cluster_id'] = clust_id
            tmp_clust['data'] = clust_data
            tmp_clust['spike_times'] = clust_times
            tmp_clust['spike_waveforms'] = clust_waveforms

            # Recompute ISI metrics for the reduced spike set
            tmp_isi, tmp_violations1, tmp_violations2 = \
                get_ISI_and_violations(clust_times, fs)
            tmp_clust['ISI'] = tmp_isi
            tmp_clust['1ms_violations'] = tmp_violations1
            tmp_clust['2ms_violations'] = tmp_violations2
            out_clusters.append(deepcopy(tmp_clust))

    return out_clusters
Ejemplo n.º 11
0
def sort_units(file_dir, fs, shell=False):
    '''Allows user to sort clustered units

    Repeatedly prompts for an electrode/solution/cluster selection, lets the
    user optionally edit (split/merge) the clusters, queries cell types and
    labels each resulting unit in the HDF5 store until the user quits.

    Parameters
    ----------
    file_dir : str, path to recording directory
    fs : float, sampling rate in Hz
    shell : bool
        True for command-line interface, False for GUI (default)
    '''
    hf5_name = h5io.get_h5_filename(file_dir)
    hf5_file = os.path.join(file_dir, hf5_name)
    sorting_log = hf5_file.replace('.h5', '_sorting.log')
    metrics_dir = os.path.join(file_dir, 'sorted_unit_metrics')
    if not os.path.exists(metrics_dir):
        os.mkdir(metrics_dir)
    quit_flag = False

    # Start loop to label a cluster
    clust_info = {
        'Electrode': 0,
        'Clusters in solution': 7,
        'Cluster Numbers': [],
        'Edit Clusters': False
    }
    clust_query = userIO.dictIO(clust_info, shell=shell)
    print(('Beginning spike sorting for: \n\t%s\n'
           'Sorting Log written to: \n\t%s') % (hf5_file, sorting_log))
    # BUGFIX: added missing space so the message reads "in order" not "inorder"
    print(('To select multiple clusters, simply type in the numbers of each\n'
           'cluster comma-separated, however, you MUST check Edit Clusters in '
           'order\nto split clusters, merges will be assumed'
           ' from multiple cluster selection.'))
    print(('\nIf using GUI input, hit cancel at any point to abort sorting'
           ' of current cluster.\nIf using shell interface,'
           ' type abort at any time.\n'))
    while not quit_flag:
        clust_info = clust_query.fill_dict()
        # Cancelled prompt ends the sorting session
        if clust_info is None:
            quit_flag = True
            break
        # Load every selected cluster from the chosen solution
        clusters = []
        for c in clust_info['Cluster Numbers']:
            tmp = get_cluster_data(file_dir, clust_info['Electrode'],
                                   clust_info['Clusters in solution'], int(c),
                                   fs)
            clusters.append(tmp)
        if len(clusters) == 0:
            quit_flag = True
            break
        if clust_info['Edit Clusters']:
            # edit_clusters may return a single dict, a list, or None (abort)
            clusters = edit_clusters(clusters, fs, shell)
            if isinstance(clusters, dict):
                clusters = [clusters]
            if clusters is None:
                quit_flag = True
                break

        cell_types = get_cell_types([x['Cluster Name'] for x in clusters],
                                    shell)
        if cell_types is None:
            quit_flag = True
            break
        else:
            # Merge the cell-type flags into each cluster dict
            for x, y in zip(clusters, cell_types):
                x.update(y)
        for unit in clusters:
            label_single_unit(hf5_file, unit, fs, sorting_log, metrics_dir)

        q = input('Do you wish to continue sorting units? (y/n) >> ')
        if q == 'n':
            quit_flag = True