Example #1
def confirm_parameter_dict(params, prompt, shell=False):
    '''Shows the user a dictionary and asks them to confirm that the values
    are correct. If not, they have the option to edit the dict.

    Parameters
    ----------
    params: dict
        values in dict can be int, float, str, bool, list, dict or None
    prompt: str
        prompt to show user
    shell : bool (optional)
        True to use command line interface
        False (default) for GUI

    Returns
    -------
    dict
       lists are returned as lists of str, so other types must be cast
       manually by the user
    '''
    prompt = ('----------\n%s\n----------\n%s\nAre these parameters good?' %
              (prompt, dp.print_dict(params)))
    q = ask_user(prompt, choices=['Yes', 'Edit', 'Cancel'], shell=shell)
    if q == 2:
        return None
    elif q == 0:
        return params
    else:
        new_params = fill_dict(params, 'Enter new values:', shell=shell)
        return new_params
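
A brief usage sketch, assuming confirm_parameter_dict and its helpers (ask_user, fill_dict, dp.print_dict) are in scope as above; the parameter values are hypothetical:

# Hypothetical parameter dict for illustration
params = {'window_size': 250, 'window_step': 25, 'tastes': ['Water', 'NaCl']}
confirmed = confirm_parameter_dict(params, 'Analysis Parameters', shell=True)
if confirmed is None:
    print('Cancelled by user')       # the 'Cancel' choice returns None
else:
    # lists are returned as lists of str, so cast values as needed
    print('Using parameters:', confirmed)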
Example #2
def read_rec_info(file_dir, shell=False):
    '''Reads the info.rhd file to get relevant parameters.

    Parameters
    ----------
    file_dir : str, path to recording directory
    shell : bool (optional)
        True for command-line interface, False (default) for GUI

    Returns
    -------
    dict, necessary analysis info from info.rhd
        fields: amplifier_sampling_rate, dig_in_sampling_rate, notch_filter,
                ports (list, corresponds to channels), channels (list)

    Raises
    ------
    FileNotFoundError : if info.rhd is not in file_dir
    '''
    info_file = os.path.join(file_dir, 'info.rhd')
    if not os.path.isfile(info_file):
        raise FileNotFoundError('info.rhd file not found in %s' % file_dir)
    out = {}
    print('Reading info.rhd file...')
    info = load_intan_rhd_format.read_data(info_file)

    freq_params = info['frequency_parameters']
    notch_freq = freq_params['notch_filter_frequency']
    amp_fs = freq_params['amplifier_sample_rate']
    dig_in_fs = freq_params['board_dig_in_sample_rate']
    out = {
        'amplifier_sampling_rate': amp_fs,
        'dig_in_sampling_rate': dig_in_fs,
        'notch_filter': notch_freq
    }

    amp_ch = info['amplifier_channels']
    ports = [x['port_prefix'] for x in amp_ch]
    channels = [x['native_order'] for x in amp_ch]

    out['ports'] = ports
    out['channels'] = channels
    out['num_channels'] = len(channels)

    if info.get('board_dig_in_channels'):
        dig_in = info['board_dig_in_channels']
        din = [x['native_order'] for x in dig_in]
        out['dig_in'] = din

    if info.get('board_dig_out_channels'):
        dig_out = info['board_dig_out_channels']
        dout = [x['native_order'] for x in dig_out]
        out['dig_out'] = dout

    out['file_type'] = get_recording_filetype(file_dir, shell)

    print('\nRecording Info\n--------------\n')
    print(dp.print_dict(out))
    return out
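
A hedged example of calling read_rec_info; the recording directory is hypothetical, and load_intan_rhd_format, dp and get_recording_filetype must be available as above:

rec_dir = '/data/my_recording'     # hypothetical directory containing info.rhd
try:
    rec_info = read_rec_info(rec_dir, shell=True)
except FileNotFoundError as e:
    print(e)
else:
    print(rec_info['amplifier_sampling_rate'], 'Hz')
    print('%i channels on ports %s' % (rec_info['num_channels'],
                                       sorted(set(rec_info['ports']))))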
Example #3
    def make_unit_arrays(self, shell=False):
        params = self.spike_array_params
        query = ('\n----------\nParameters for Spike Array Creation'
                 '\n----------\ntimes in ms\n%s\nWould you like to'
                 ' continue with these parameters?') % dp.print_dict(params)
        q_idx = userIO.ask_user(query, choices=('Continue', 'Abort', 'Edit'),
                                shell=shell)
        if q_idx == 1:
            return
        elif q_idx == 2:
            params = self.edit_spike_array_parameters(shell=shell)
            if params is None:
                return

        print('Generating unit arrays with parameters:\n----------')
        print(dp.print_dict(params, tabs=1))
        ss.make_spike_arrays(self.h5_file, params)
        self.process_status['make_unit_arrays'] = True
        self.save()
Example #4
def write_params(file_name, params):
    '''Writes parameters into a file for use by blech_process.py

    Parameters
    ----------
    file_name : str, path to .params file to write params in
    params : dict, dictionary of parameters with keys:
                   clustering_params, data_params,
                   bandpass_params, spike_snapshot, sampling_rate
    '''
    if not file_name.endswith('.params'):
        file_name += '.params'
    print('File: ' + file_name)
    dp.print_dict(params)
    with open(file_name, 'w') as f:
        for c in clust_param_order:
            print(params['clustering_params'][c], file=f)
        for c in data_param_order:
            print(params['data_params'][c], file=f)
        for c in band_param_order:
            print(params['bandpass_params'][c], file=f)
        for c in spike_snap_order:
            print(params['spike_snapshot'][c], file=f)
        print(params['sampling_rate'], file=f)
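
A hedged sketch of calling write_params; the inner dictionaries must contain the keys named in the module-level clust_param_order, data_param_order, band_param_order and spike_snap_order sequences (not shown here), so placeholder values are used:

# Placeholder values keyed by the module's *_order lists (values hypothetical)
params = {
    'clustering_params': dict.fromkeys(clust_param_order, 0),
    'data_params': dict.fromkeys(data_param_order, 0),
    'bandpass_params': dict.fromkeys(band_param_order, 0),
    'spike_snapshot': dict.fromkeys(spike_snap_order, 0),
    'sampling_rate': 30000,
}
write_params('/data/my_recording/electrode_1', params)  # '.params' is appended automatically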
Example #5
def palatability_identity_calculations(rec_dir, pal_ranks=None,
                                       unit_type=None, params=None,
                                       shell=False):
    dat = dataset.load_dataset(rec_dir)
    dim = dat.dig_in_mapping
    if pal_ranks is None:
        dim = get_palatability_ranks(dim, shell=shell)
    elif 'palatability_rank' in dim.columns:
        pass
    else:
        dim['palatability_rank'] = dim['name'].map(pal_ranks)

    dim = dim.dropna(subset=['palatability_rank'])
    dim = dim.reset_index(drop=True)
    num_tastes = len(dim)
    taste_names = dim.name.to_list()

    trial_list = dat.dig_in_trials.copy()
    trial_list = trial_list[trial_list['name'].isin(taste_names)]
    num_trials = trial_list.groupby('channel').count()['name'].unique()
    if len(num_trials) > 1:
        raise ValueError('Unequal number of trials across the tastes to be used')
    else:
        num_trials = num_trials[0]

    dim['num_trials'] = num_trials

    # Get which units to use
    unit_table = h5io.get_unit_table(rec_dir)
    unit_types = ['Single', 'Multi', 'All', 'Custom']
    if unit_type is None:
        q = userIO.ask_user('Which units do you want to use for taste '
                            'discrimination and palatability analysis?',
                            choices=unit_types,
                            shell=shell)
        unit_type = unit_types[q]

    if unit_type == 'Single':
        chosen_units = unit_table.loc[unit_table['single_unit'],
                                      'unit_num'].to_list()
    elif unit_type == 'Multi':
        chosen_units = unit_table.loc[unit_table['single_unit'] == False,
                                      'unit_num'].to_list()
    elif unit_type == 'All':
        chosen_units = unit_table['unit_num'].to_list()
    else:
        selection = userIO.select_from_list('Select units to use:',
                                            unit_table['unit_num'],
                                            'Select Units',
                                            multi_select=True)
        chosen_units = list(map(int, selection))

    num_units = len(chosen_units)
    unit_table = unit_table.loc[chosen_units]

    # Enter Parameters
    if params is None or params.keys() != default_pal_id_params.keys():
        params = {'window_size': 250, 'window_step': 25,
                  'num_comparison_bins': 5, 'comparison_bin_size': 250,
                  'discrim_p': 0.01, 'pal_deduce_start_time': 700,
                  'pal_deduce_end_time': 1200}
        params = userIO.confirm_parameter_dict(params,
                                               ('Palatability/Identity '
                                                'Calculation Parameters'
                                                '\nTimes in ms'), shell=shell)

    win_size = params['window_size']
    win_step = params['window_step']
    print('Running palatability/identity calculations with parameters:\n%s' %
          dp.print_dict(params))

    with tables.open_file(dat.h5_file, 'r+') as hf5:
        trains_dig_in = hf5.list_nodes('/spike_trains')
        time = trains_dig_in[0].array_time[:]
        bin_times = np.arange(time[0], time[-1] - win_size + win_step,
                             win_step)
        num_bins = len(bin_times)

        palatability = np.empty((num_bins, num_units, num_tastes*num_trials),
                                dtype=int)
        identity = np.empty((num_bins, num_units, num_tastes*num_trials),
                            dtype=int)
        unscaled_response = np.empty((num_bins, num_units, num_tastes*num_trials),
                                     dtype=np.dtype('float64'))
        response = np.empty((num_bins, num_units, num_tastes*num_trials),
                            dtype=np.dtype('float64'))
        laser = np.empty((num_bins, num_units, num_tastes*num_trials, 2),
                         dtype=float)

        # Fill arrays with data
        print('Filling data arrays...')
        onesies = np.ones((num_bins, num_units, num_trials))
        for i, row in dim.iterrows():
            idx = range(num_trials*i, num_trials*(i+1))
            palatability[:, :, idx] = row.palatability_rank * onesies
            identity[:, :, idx] = row.dig_in * onesies
            for j, u in enumerate(chosen_units):
                for k,t in enumerate(bin_times):
                    t_idx = np.where((time >= t) & (time <= t+win_size))[0]
                    unscaled_response[k, j, idx] = \
                            np.mean(trains_dig_in[i].spike_array[:, u, t_idx],
                                    axis=1)
                    try:
                        laser[k, j, idx] = \
                            np.vstack((trains_dig_in[i].laser_durations[:],
                                       trains_dig_in[i].laser_onset_lag[:])).T
                    except Exception:
                        laser[k, j, idx] = np.zeros((num_trials, 2))

        # Scaling was not done, so:
        response = unscaled_response.copy()

        # Make ancillary_analysis node and put in arrays
        if '/ancillary_analysis' in hf5:
            hf5.remove_node('/ancillary_analysis', recursive=True)

        hf5.create_group('/', 'ancillary_analysis')
        hf5.create_array('/ancillary_analysis', 'palatability', palatability)
        hf5.create_array('/ancillary_analysis', 'identity', identity)
        hf5.create_array('/ancillary_analysis', 'laser', laser)
        hf5.create_array('/ancillary_analysis', 'scaled_neural_response',
                         response)
        hf5.create_array('/ancillary_analysis', 'window_params',
                         np.array([win_size, win_step]))
        hf5.create_array('/ancillary_analysis', 'bin_times', bin_times)
        hf5.create_array('/ancillary_analysis', 'unscaled_neural_response',
                         unscaled_response)

        # for backwards compatibility
        hf5.create_array('/ancillary_analysis', 'params',
                        np.array([win_size, win_step]))
        hf5.create_array('/ancillary_analysis', 'pre_stim', np.array(time[0]))
        hf5.flush()

        # Get unique laser (duration, lag) combinations
        print('Organizing trial data...')
        unique_lasers = np.vstack(list({tuple(row) for row in laser[0, 0, :, :]}))
        unique_lasers = unique_lasers[unique_lasers[:, 1].argsort(), :]
        num_conditions = unique_lasers.shape[0]
        trials = []
        for row in unique_lasers:
            tmp_trials = [j for j in range(num_trials * num_tastes)
                          if np.array_equal(laser[0, 0, j, :], row)]
            trials.append(tmp_trials)

        trials_per_condition = [len(x) for x in trials]
        if not all(x == trials_per_condition[0] for x in trials_per_condition):
            raise ValueError('Different number of trials for each laser condition')

        trials_per_condition = int(trials_per_condition[0] / num_tastes)  #assumes same number of trials per taste per condition
        print('Detected:\n    %i tastes\n    %i laser conditions\n'
              '    %i trials per condition per taste' %
              (num_tastes, num_conditions, trials_per_condition))
        trials = np.array(trials)

        # Store laser conditions and indices of trials per condition in trial x
        # taste space
        hf5.create_array('/ancillary_analysis', 'trials', trials)
        hf5.create_array('/ancillary_analysis', 'laser_combination_d_l',
                         unique_lasers)
        hf5.flush()

        # Taste Similarity Calculation
        neural_response_laser = np.empty((num_conditions, num_bins,
                                          num_tastes, num_units,
                                          trials_per_condition),
                                         dtype=np.dtype('float64'))
        taste_cosine_similarity = np.empty((num_conditions, num_bins,
                                            num_tastes, num_tastes),
                                           dtype=np.dtype('float64'))
        taste_euclidean_distance = np.empty((num_conditions, num_bins,
                                             num_tastes, num_tastes),
                                            dtype=np.dtype('float64'))

        # Re-format neural responses from bin x unit x (trial*taste) to
        # laser_condition x bin x taste x unit x trial
        print('Reformatting data arrays...')
        for i, trial in enumerate(trials):
            for j, _ in enumerate(bin_times):
                for k, _ in dim.iterrows():
                    idx = np.where((trial >= num_trials*k) &
                                   (trial < num_trials*(k+1)))[0]
                    neural_response_laser[i, j, k, :, :] = \
                            response[j, :, trial[idx]].T

        # Compute taste cosine similarity and euclidean distances
        print('Computing taste cosine similarity and euclidean distances...')
        for i, _ in enumerate(trials):
            for j, _ in enumerate(bin_times):
                for k, _ in dim.iterrows():
                    for l, _ in dim.iterrows():
                        taste_cosine_similarity[i, j, k, l] = \
                                np.mean(cosine_similarity(
                                    neural_response_laser[i, j, k, :, :].T,
                                    neural_response_laser[i, j, l, :, :].T))
                        taste_euclidean_distance[i, j, k, l] = \
                                np.mean(cdist(
                                    neural_response_laser[i, j, k, :, :].T,
                                    neural_response_laser[i, j, l, :, :].T,
                                    metric='euclidean'))

        hf5.create_array('/ancillary_analysis', 'taste_cosine_similarity',
                         taste_cosine_similarity)
        hf5.create_array('/ancillary_analysis', 'taste_euclidean_distance',
                         taste_euclidean_distance)
        hf5.flush()

        # Taste Responsiveness calculations
        bin_params = [params['num_comparison_bins'],
                      params['comparison_bin_size']]
        discrim_p = params['discrim_p']

        responsive_neurons = []
        discriminating_neurons = []
        taste_responsiveness = np.zeros((bin_params[0], num_units, 2))
        new_bin_times = np.arange(0, np.prod(bin_params), bin_params[1])
        baseline = np.where(bin_times < 0)[0]
        print('Computing taste responsiveness and taste discrimination...')
        for i, t in enumerate(new_bin_times):
            places = np.where((bin_times >= t) &
                              (bin_times <= t+bin_params[1]))[0]
            for j, u in enumerate(chosen_units):
                # Check taste responsiveness
                f, p = f_oneway(np.mean(response[places, j, :], axis=0),
                                np.mean(response[baseline, j, :], axis=0))
                if np.isnan(f):
                    f = 0.0
                    p = 1.0

                if p <= discrim_p and u not in responsive_neurons:
                    responsive_neurons.append(u)
                    taste_responsiveness[i, j, 0] = 1

                # Check taste discrimination
                taste_idx = [np.arange(num_trials*k, num_trials*(k+1))
                             for k in range(num_tastes)]
                taste_responses = [np.mean(response[places, j, :][:, k], axis=0)
                                   for k in taste_idx]
                f, p = f_oneway(*taste_responses)
                if np.isnan(f):
                    f = 0.0
                    p = 1.0

                if p <= discrim_p and u not in discriminating_neurons:
                    discriminating_neurons.append(u)

        responsive_neurons = np.sort(responsive_neurons)
        discriminating_neurons = np.sort(discriminating_neurons)

        # Write taste responsive and taste discriminating units to text file
        save_file = os.path.join(rec_dir, 'discriminative_responsive_neurons.txt')
        with open(save_file, 'w') as f:
            print('Taste discriminative neurons', file=f)
            for u in discriminating_neurons:
                print(u, file=f)

            print('Taste responsive neurons', file=f)
            for u in responsive_neurons:
                print(u, file=f)

        hf5.create_array('/ancillary_analysis', 'taste_disciminating_neurons',
                         discriminating_neurons)
        hf5.create_array('/ancillary_analysis', 'taste_responsive_neurons',
                         responsive_neurons)
        hf5.create_array('/ancillary_analysis', 'taste_responsiveness',
                         taste_responsiveness)
        hf5.flush()

        # Get time course of taste discriminability
        print('Getting taste discrimination time course...')
        p_discrim = np.empty((num_conditions, num_bins, num_tastes, num_tastes,
                              num_units), dtype=np.dtype('float64'))
        for i in range(num_conditions):
            for j, t in enumerate(bin_times):
                for k in range(num_tastes):
                    for l in range(num_tastes):
                        for m in range(num_units):
                            _, p = ttest_ind(neural_response_laser[i, j, k, m, :],
                                             neural_response_laser[i, j, l, m, :],
                                             equal_var=False)
                            if np.isnan(p):
                                p = 1.0

                            p_discrim[i, j, k, l, m] = p

        hf5.create_array('/ancillary_analysis', 'p_discriminability',
                          p_discrim)
        hf5.flush()

        # Palatability Rank Order calculation (if > 2 tastes)
        t_start = params['pal_deduce_start_time']
        t_end = params['pal_deduce_end_time']
        if num_tastes > 2:
            print('Deducing palatability rank order...')
            palatability_rank_order_deduction(rec_dir, neural_response_laser,
                                              unique_lasers,
                                              bin_times, [t_start, t_end])

        # Palatability calculation
        r_spearman = np.zeros((num_conditions, num_bins, num_units))
        p_spearman = np.ones((num_conditions, num_bins, num_units))
        r_pearson = np.zeros((num_conditions, num_bins, num_units))
        p_pearson = np.ones((num_conditions, num_bins, num_units))
        f_identity = np.ones((num_conditions, num_bins, num_units))
        p_identity = np.ones((num_conditions, num_bins, num_units))
        lda_palatability = np.zeros((num_conditions, num_bins))
        lda_identity = np.zeros((num_conditions, num_bins))
        r_isotonic = np.zeros((num_conditions, num_bins, num_units))
        id_pal_regress = np.zeros((num_conditions, num_bins, num_units, 2))
        pairwise_identity = np.zeros((num_conditions, num_bins, num_tastes, num_tastes))
        print('Computing palatability metrics...')

        for i, t in enumerate(trials):
            for j in range(num_bins):
                for k in range(num_units):
                    ranks = rankdata(response[j, k, t])
                    r_spearman[i, j, k], p_spearman[i, j, k] = \
                            spearmanr(ranks, palatability[j, k, t])
                    r_pearson[i, j, k], p_pearson[i, j, k] = \
                            pearsonr(response[j, k, t], palatability[j, k, t])
                    if np.isnan(r_spearman[i, j, k]):
                        r_spearman[i, j, k] = 0.0
                        p_spearman[i, j, k] = 1.0

                    if np.isnan(r_pearson[i, j, k]):
                        r_pearson[i, j, k] = 0.0
                        p_pearson[i, j, k] = 1.0

                    # Isotonic regression of firing against palatability
                    model = IsotonicRegression(increasing='auto')
                    model.fit(palatability[j, k, t], response[j, k, t])
                    r_isotonic[i, j, k] = model.score(palatability[j, k, t],
                                                      response[j, k, t])

                    # Multiple Regression of firing rate against palatability and identity
                    # Regress palatability on identity
                    tmp_id = identity[j, k, t].reshape(-1, 1)
                    tmp_pal = palatability[j, k, t].reshape(-1, 1)
                    tmp_resp = response[j, k, t].reshape(-1, 1)
                    model_pi = LinearRegression()
                    model_pi.fit(tmp_id, tmp_pal)
                    pi_residuals = tmp_pal - model_pi.predict(tmp_id)

                    # Regress identity on palatability
                    model_ip = LinearRegression()
                    model_ip.fit(tmp_pal, tmp_id)
                    ip_residuals = tmp_id - model_ip.predict(tmp_pal)

                    # Regress firing on identity
                    model_fi = LinearRegression()
                    model_fi.fit(tmp_id, tmp_resp)
                    fi_residuals = tmp_resp - model_fi.predict(tmp_id)

                    # Regress firing on palatability
                    model_fp = LinearRegression()
                    model_fp.fit(tmp_pal, tmp_resp)
                    fp_residuals = tmp_resp - model_fp.predict(tmp_pal)

                    # Get partial correlation coefficient of response with identity
                    idp_reg0, p = pearsonr(fp_residuals.ravel(), ip_residuals.ravel())
                    if np.isnan(idp_reg0):
                        idp_reg0 = 0.0

                    idp_reg1, p = pearsonr(fi_residuals.ravel(), pi_residuals.ravel())
                    if np.isnan(idp_reg1):
                        idp_reg1 = 0.0

                    id_pal_regress[i, j, k, 0] = idp_reg0
                    id_pal_regress[i, j, k, 1] = idp_reg1

                    # Identity Calculation
                    samples = []
                    for _, row in dim.iterrows():
                        taste = row.dig_in
                        samples.append([trial for trial in t
                                        if identity[j, k, trial] == taste])

                    tmp_resp = [response[j, k, sample] for sample in samples]
                    f_identity[i, j, k], p_identity[i, j, k] = f_oneway(*tmp_resp)
                    if np.isnan(f_identity[i, j, k]):
                        f_identity[i, j, k] = 0.0
                        p_identity[i, j, k] = 1.0


                # Linear Discriminant analysis for palatability
                X = response[j, :, t]
                Y = palatability[j, 0, t]
                test_results = []
                c_validator = LeavePOut(1)
                for train, test in c_validator.split(X, Y):
                    model = LDA()
                    model.fit(X[train, :], Y[train])
                    tmp = np.mean(model.predict(X[test]) == Y[test])
                    test_results.append(tmp)

                lda_palatability[i, j] = np.mean(test_results)

                # Linear Discriminant analysis for identity
                Y = identity[j, 0, t]
                test_results = []
                c_validator = LeavePOut(1)
                for train, test in c_validator.split(X, Y):
                    model = LDA()
                    model.fit(X[train, :], Y[train])
                    tmp = np.mean(model.predict(X[test]) == Y[test])
                    test_results.append(tmp)

                lda_identity[i, j] = np.mean(test_results)

                # Pairwise Identity Calculation
                for _, r1 in dim.iterrows():
                    for _, r2 in dim.iterrows():
                        t1 = r1.dig_in
                        t2 = r2.dig_in
                        tmp_trials = np.where((identity[j, 0, :] == t1) |
                                              (identity[j, 0, :] == t2))[0]
                        idx = [trial for trial in t if trial in tmp_trials]
                        X = response[j, :, idx]
                        Y = identity[j, 0, idx]
                        test_results = []
                        c_validator = StratifiedShuffleSplit(n_splits=10,
                                                             test_size=0.25,
                                                             random_state=0)
                        for train, test in c_validator.split(X, Y):
                            model = GaussianNB()
                            model.fit(X[train, :], Y[train])
                            tmp_score = model.score(X[test, :], Y[test])
                            test_results.append(tmp_score)

                        pairwise_identity[i, j, t1, t2] = np.mean(test_results)

        hf5.create_array('/ancillary_analysis', 'r_pearson', r_pearson)
        hf5.create_array('/ancillary_analysis', 'r_spearman', r_spearman)
        hf5.create_array('/ancillary_analysis', 'p_pearson', p_pearson)
        hf5.create_array('/ancillary_analysis', 'p_spearman', p_spearman)
        hf5.create_array('/ancillary_analysis', 'lda_palatability', lda_palatability)
        hf5.create_array('/ancillary_analysis', 'lda_identity', lda_identity)
        hf5.create_array('/ancillary_analysis', 'r_isotonic', r_isotonic)
        hf5.create_array('/ancillary_analysis', 'id_pal_regress', id_pal_regress)
        hf5.create_array('/ancillary_analysis', 'f_identity', f_identity)
        hf5.create_array('/ancillary_analysis', 'p_identity', p_identity)
        hf5.create_array('/ancillary_analysis', 'pairwise_NB_identity', pairwise_identity)
        hf5.flush()
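
A hedged usage sketch; the recording directory and the name-to-rank mapping are hypothetical, and omitting params falls back to the confirmed defaults shown above:

rec_dir = '/data/my_recording'                                    # hypothetical path
pal_ranks = {'Water': 1, 'Quinine': 2, 'NaCl': 3, 'Saccharin': 4} # name -> palatability rank
palatability_identity_calculations(rec_dir, pal_ranks=pal_ranks,
                                   unit_type='Single', shell=True)
# results land under /ancillary_analysis in the recording's HDF5 store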
Example #6
    def blech_clust_run(self, data_quality=None, accept_params=False,
                        shell=False):
        '''
        Write clustering parameters to file and
        Run blech_process on each electrode using GNU parallel

        Parameters
        ----------
        data_quality : {'clean', 'noisy', None (default)}
            set if you want to change the data quality parameters for cutoff
            and spike detection before running clustering. These parameters are
            automatically set as "clean" during initial parameter setup
        accept_params : bool, False (default)
            set to True in order to skip popup confirmation of parameters when
            running
        shell : bool, False (default)
            set to True to use the command-line interface instead of the GUI
        '''
        if data_quality:
            tmp = deepcopy(dio.params.data_params.get(data_quality))
            if tmp:
                self.clust_params['data_params'] = tmp
                self.clust_params['data_quality'] = data_quality
            else:
                raise ValueError('%s is not a valid data_quality preset. Must '
                                 'be "clean" or "noisy" or None.' % data_quality)

        # Check if they are OK with the parameters that will be used
        if not accept_params:
            if not shell:
                q = eg.ynbox(dp.print_dict(self.clust_params)
                             + '\n Are these parameters OK?',
                             'Check Extraction and Clustering Parameters')
            else:
                q = input(dp.print_dict(self.clust_params)
                          + '\n Are these parameters OK? (y/n):  ')
                if q == 'y':
                    q = True
                else:
                    q = False

            if not q:
                return

        print('\nRunning Blech Clust\n-------------------')
        print('Parameters\n%s' % dp.print_dict(self.clust_params))

        # Write parameters into .params file
        self.param_file = os.path.join(self.data_dir, self.data_name+'.params')
        dio.params.write_params(self.param_file, self.clust_params)

        # Create folders for saving things within recording dir
        data_dir = self.data_dir
        directories = ['spike_waveforms', 'spike_times',
                       'clustering_results',
                       'Plots', 'memory_monitor_clustering']
        for d in directories:
            tmp_dir = os.path.join(data_dir, d)
            if os.path.exists(tmp_dir):
                shutil.rmtree(tmp_dir)

            os.mkdir(tmp_dir)

        # Set file for clustering log
        self.clustering_log = os.path.join(data_dir, 'results.log')
        if os.path.exists(self.clustering_log):
            os.remove(self.clustering_log)

        process_path = os.path.realpath(__file__)
        process_path = os.path.join(os.path.dirname(process_path),
                                    'blech_process.py')
        em = self.electrode_mapping
        if 'dead' in em.columns:
            electrodes = em.Electrode[em['dead'] == False].tolist()
        else:
            electrodes = em.Electrode.tolist()

        my_env = os.environ
        my_env['OMP_NUM_THREADS'] = '1'  # possibly not necessary
        cpu_count = int(multiprocessing.cpu_count())-1
        process_call = ['parallel', '-k', '-j', str(cpu_count), '--noswap',
                        '--load', '100%', '--progress', '--memfree', '4G',
                        '--retry-failed', '--joblog', self.clustering_log,
                        'python', process_path, '{1}', self.data_dir, ':::']
        process_call.extend([str(x) for x in electrodes])
        subprocess.call(process_call, env=my_env)
        self.process_status['blech_clust_run'] = True
        self.save()
        print('Clustering Complete\n------------------')
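
A short usage sketch, assuming the dataset object was created and initialized elsewhere (dataset.load_dataset is used the same way in the palatability example above; the directory is hypothetical):

dat = dataset.load_dataset('/data/my_recording')   # hypothetical recording directory
dat.blech_clust_run(data_quality='noisy',          # switch to the 'noisy' presets
                    accept_params=True,            # skip the confirmation prompt
                    shell=True)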
Example #7
    def initParams(self, data_quality='clean', emg_port=None,
                   emg_channels=None, shell=False, dig_in_names=None,
                   dig_out_names=None,
                   spike_array_params=None,
                   psth_params=None,
                   confirm_all=False):
        '''
        Initializes basic default analysis parameters that can be customized
        before running processing methods.
        data_quality can be 'clean' or 'noisy' to preset parameters useful for
        the different recording types. Best practice is to run as 'clean'
        (default) and to re-run as 'noisy' if you notice that a lot of
        electrodes are cut off early.
        '''

        # Get parameters from info.rhd
        file_dir = self.data_dir
        rec_info = dio.rawIO.read_rec_info(file_dir, shell)
        ports = rec_info.pop('ports')
        channels = rec_info.pop('channels')
        sampling_rate = rec_info['amplifier_sampling_rate']
        self.rec_info = rec_info
        self.sampling_rate = sampling_rate

        # Get default parameters for blech_clust
        clustering_params = deepcopy(dio.params.clustering_params)
        data_params = deepcopy(dio.params.data_params[data_quality])
        bandpass_params = deepcopy(dio.params.bandpass_params)
        spike_snapshot = deepcopy(dio.params.spike_snapshot)
        if spike_array_params is None:
            spike_array_params = deepcopy(dio.params.spike_array_params)
        if psth_params is None:
            psth_params = deepcopy(dio.params.psth_params)

        # Ask for emg port & channels
        if emg_port is None and not shell:
            q = eg.ynbox('Do you have an EMG?', 'EMG')
            if q:
                emg_port = userIO.select_from_list('Select EMG Port:',
                                                   ports, 'EMG Port',
                                                   shell=shell)
                emg_channels = userIO.select_from_list(
                    'Select EMG Channels:',
                    [y for x, y in
                     zip(ports, channels)
                     if x == emg_port],
                    title='EMG Channels',
                    multi_select=True, shell=shell)

        elif emg_port is None and shell:
            print('\nNo EMG port given.\n')

        electrode_mapping, emg_mapping = dio.params.flatten_channels(
            ports,
            channels,
            emg_port=emg_port,
            emg_channels=emg_channels)
        self.electrode_mapping = electrode_mapping
        self.emg_mapping = emg_mapping

        # Get digital input names and spike array parameters
        if rec_info.get('dig_in'):
            if dig_in_names is None:
                dig_in_names = dict.fromkeys(['dig_in_%i' % x
                                              for x in rec_info['dig_in']])
                name_filler = userIO.dictIO(dig_in_names, shell=shell)
                dig_in_names = name_filler.fill_dict('Enter names for '
                                                     'digital inputs:')
                if dig_in_names is None or \
                   any([x is None for x in dig_in_names.values()]):
                    raise ValueError('Must name all dig_ins')

                dig_in_names = list(dig_in_names.values())

            if spike_array_params['laser_channels'] is None:
                laser_dict = dict.fromkeys(dig_in_names, False)
                laser_filler = userIO.dictIO(laser_dict, shell=shell)
                laser_dict = laser_filler.fill_dict('Select any lasers:')
                if laser_dict is None:
                    laser_channels = []
                else:
                    laser_channels = [i for i, v
                                      in zip(rec_info['dig_in'],
                                             laser_dict.values()) if v]

                spike_array_params['laser_channels'] = laser_channels
            else:
                laser_dict = dict.fromkeys(dig_in_names, False)
                for lc in spike_array_params['laser_channels']:
                    laser_dict[dig_in_names[lc]] = True


            if spike_array_params['dig_ins_to_use'] is None:
                di = [x for x in rec_info['dig_in']
                      if x not in spike_array_params['laser_channels']]
                dn = [dig_in_names[x] for x in di]
                spike_dig_dict = dict.fromkeys(dn, True)
                filler = userIO.dictIO(spike_dig_dict, shell=shell)
                spike_dig_dict = filler.fill_dict('Select digital inputs '
                                                  'to use for making spike'
                                                  ' arrays:')
                if spike_dig_dict is None:
                    spike_dig_ins = []
                else:
                    spike_dig_ins = [x for x, y in
                                     zip(di, spike_dig_dict.values())
                                     if y]

                spike_array_params['dig_ins_to_use'] = spike_dig_ins


            dim = pd.DataFrame([(x, y) for x, y in zip(rec_info['dig_in'],
                                                       dig_in_names)],
                               columns=['dig_in', 'name'])
            dim['laser'] = dim['name'].apply(lambda x: laser_dict.get(x))
            self.dig_in_mapping = dim.copy()

        # Get digital output names
        if rec_info.get('dig_out'):
            if dig_out_names is None:
                dig_out_names = dict.fromkeys(['dig_out_%i' % x
                                              for x in rec_info['dig_out']])
                name_filler = userIO.dictIO(dig_out_names, shell=shell)
                dig_out_names = name_filler.fill_dict('Enter names for '
                                                      'digital outputs:')
                if dig_out_names is None or \
                   any([x is None for x in dig_out_names.values()]):
                    raise ValueError('Must name all dig_outs')

                dig_out_names = list(dig_out_names.values())

            self.dig_out_mapping = pd.DataFrame([(x, y) for x, y in
                                                 zip(rec_info['dig_out'],
                                                     dig_out_names)],
                                                columns=['dig_out', 'name'])

        # Store clustering parameters
        self.clust_params = {'file_dir': file_dir,
                             'data_quality': data_quality,
                             'sampling_rate': sampling_rate,
                             'clustering_params': clustering_params,
                             'data_params': data_params,
                             'bandpass_params': bandpass_params,
                             'spike_snapshot': spike_snapshot}

        # Store and confirm spike array parameters
        spike_array_params['sampling_rate'] = sampling_rate
        self.spike_array_params = spike_array_params
        self.psth_params = psth_params
        if not confirm_all:
            prompt = ('\n----------\nSpike Array Parameters\n----------\n'
                      + dp.print_dict(spike_array_params) +
                      '\nAre these parameters good?')
            q_idx = userIO.ask_user(prompt, ('Yes', 'Edit'), shell=shell)
            if q_idx == 1:
                self.edit_spike_array_parameters(shell=shell)

            # Edit and store psth parameters
            prompt = ('\n----------\nPSTH Parameters\n----------\n'
                      + dp.print_dict(psth_params) +
                      '\nAre these parameters good?')
            q_idx = userIO.ask_user(prompt, ('Yes', 'Edit'), shell=shell)
            if q_idx == 1:
                self.edit_psth_parameters(shell=shell)

        self.save()
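
A hedged sketch of calling initParams on a freshly created dataset; the EMG port/channels and digital input names below are hypothetical:

dat = dataset.load_dataset('/data/my_recording')   # assumption: dataset created beforehand
dat.initParams(data_quality='clean',
               emg_port='A',                        # hypothetical EMG port
               emg_channels=[0, 1],                 # hypothetical EMG channels
               dig_in_names=['Water', 'NaCl', 'Quinine', 'Saccharin'],
               shell=True)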
Example #8
    def __str__(self):
        '''Put all information about dataset in string format

        Returns
        -------
        str : representation of dataset object
        '''
        out = [self.data_name]
        out.append('Data directory:  '+self.data_dir)
        out.append('Object creation date: '
                   + self.dataset_creation_date.strftime('%m/%d/%y'))
        out.append('Dataset Save File: ' + self.save_file)

        if hasattr(self, 'raw_h5_file'):
            out.append('Deleted Raw h5 file: '+self.raw_h5_file)
            out.append('h5 File: '+self.h5_file)
            out.append('')

        out.append('--------------------')
        out.append('Processing Status')
        out.append('--------------------')
        out.append(dp.print_dict(self.process_status))
        out.append('')

        if not hasattr(self, 'rec_info'):
            return '\n'.join(out)

        info = self.rec_info

        out.append('--------------------')
        out.append('Recording Info')
        out.append('--------------------')
        out.append(dp.print_dict(self.rec_info))
        out.append('')

        out.append('--------------------')
        out.append('Electrodes')
        out.append('--------------------')
        out.append(dp.print_dataframe(self.electrode_mapping))
        out.append('')

        if hasattr(self, 'car_electrodes'):
            out.append('--------------------')
            out.append('CAR Groups')
            out.append('--------------------')
            headers = ['Group %i' % x for x in range(len(self.car_electrodes))]
            out.append(dp.print_list_table(self.car_electrodes, headers))
            out.append('')

        if not self.emg_mapping.empty:
            out.append('--------------------')
            out.append('EMG')
            out.append('--------------------')
            out.append(dp.print_dataframe(self.emg_mapping))
            out.append('')

        if info.get('dig_in'):
            out.append('--------------------')
            out.append('Digital Input')
            out.append('--------------------')
            out.append(dp.print_dataframe(self.dig_in_mapping))
            out.append('')

        if info.get('dig_out'):
            out.append('--------------------')
            out.append('Digital Output')
            out.append('--------------------')
            out.append(dp.print_dataframe(self.dig_out_mapping))
            out.append('')

        out.append('--------------------')
        out.append('Clustering Parameters')
        out.append('--------------------')
        out.append(dp.print_dict(self.clust_params))
        out.append('')

        out.append('--------------------')
        out.append('Spike Array Parameters')
        out.append('--------------------')
        out.append(dp.print_dict(self.spike_array_params))
        out.append('')

        out.append('--------------------')
        out.append('PSTH Parameters')
        out.append('--------------------')
        out.append(dp.print_dict(self.psth_params))
        out.append('')

        return '\n'.join(out)
Example #9
def split_cluster(cluster, fs, params=None, shell=True):
    '''Use GMM to re-cluster a single cluster into a number of sub clusters

    Parameters
    ----------
    cluster : dict, cluster metrics and data
    fs : float, sampling rate in Hz
    params : dict (optional), parameters for reclustering
        with keys: 'Number of Clusters', 'Max Number of Iterations',
                   'Convergence Criterion', 'GMM random restarts'
    shell : bool, optional, whether to use shell or GUI input (default=True)

    Returns
    -------
    clusters : list of dicts
        resultant clusters from split
    '''
    if params is None:
        clustering_params = {
            'Number of Clusters': 2,
            'Max Number of Iterations': 1000,
            'Convergence Criterion': 0.00001,
            'GMM random restarts': 10
        }
        params_filler = userIO.dictIO(clustering_params, shell=shell)
        params = params_filler.fill_dict()

    if params is None:
        return None

    n_clusters = int(params['Number of Clusters'])
    n_iter = int(params['Max Number of Iterations'])
    thresh = float(params['Convergence Criterion'])
    n_restarts = int(params['GMM random restarts'])

    g = GaussianMixture(n_components=n_clusters,
                        covariance_type='full',
                        tol=thresh,
                        max_iter=n_iter,
                        n_init=n_restarts)
    g.fit(cluster['data'])

    out_clusters = []
    if g.converged_:
        spike_waveforms = cluster['spike_waveforms']
        spike_times = cluster['spike_times']
        data = cluster['data']
        predictions = g.predict(data)
        for c in range(n_clusters):
            clust_idx = np.where(predictions == c)[0]
            tmp_clust = deepcopy(cluster)
            clust_id = str(cluster['cluster_id']) + chr(ord('A') + c)
            clust_name = 'E%iS%i_cluster%s' % \
                (cluster['electrode'], cluster['solution'], clust_id)
            clust_waveforms = spike_waveforms[clust_idx]
            clust_times = spike_times[clust_idx]
            clust_data = data[clust_idx, :]
            clust_log = cluster['manipulations'] + \
                '\nSplit %s with parameters: ' \
                '\n%s\nCluster %i from split results. Named %s' \
                % (cluster['Cluster Name'], dp.print_dict(params),
                   c, clust_name)
            tmp_clust['Cluster Name'] = clust_name
            tmp_clust['cluster_id'] = clust_id
            tmp_clust['data'] = clust_data
            tmp_clust['spike_times'] = clust_times
            tmp_clust['spike_waveforms'] = clust_waveforms

            tmp_isi, tmp_violations1, tmp_violations2 = \
                get_ISI_and_violations(clust_times, fs)
            tmp_clust['ISI'] = tmp_isi
            tmp_clust['1ms_violations'] = tmp_violations1
            tmp_clust['2ms_violations'] = tmp_violations2
            out_clusters.append(deepcopy(tmp_clust))

    return out_clusters
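
A hedged usage sketch; the cluster dict must already hold the fields read above ('data', 'spike_waveforms', 'spike_times', 'cluster_id', 'electrode', 'solution', 'Cluster Name', 'manipulations'), and the sampling rate shown is hypothetical:

split_params = {'Number of Clusters': 3,
                'Max Number of Iterations': 1000,
                'Convergence Criterion': 1e-5,
                'GMM random restarts': 10}
sub_clusters = split_cluster(cluster, fs=30000, params=split_params, shell=True)
if not sub_clusters:
    print('GMM did not converge; no sub-clusters returned')
for c in sub_clusters:
    print(c['Cluster Name'], len(c['spike_times']), 'spikes')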
Example #10
def label_single_unit(hf5_file,
                      cluster,
                      fs,
                      sorting_log=None,
                      metrics_dir=None):
    '''Adds a sorted unit to the hdf5 store: adds unit info to the
    unit_descriptor table, spike times and waveforms to sorted_units, and
    saves metrics for the unit into the sorted_unit_metrics folder

    Parameters
    ----------
    hf5_file : str, full path to .h5 file
    cluster : dict
        cluster metrics and data; must contain 'electrode', 'single_unit',
        'regular_spiking', 'fast_spiking', 'spike_waveforms', 'spike_times',
        'data' and 'Cluster Name'
    fs : float, sampling rate in Hz
    sorting_log : str (optional)
        path to sorting log file; defaults to <h5 file name>_sorting.log
    metrics_dir : str (optional)
        directory in which to save unit metrics; defaults to
        sorted_unit_metrics/<unit_name> in the recording directory
    '''
    if sorting_log is None:
        sorting_log = hf5_file.replace('.h5', '_sorting.log')
    if metrics_dir is None:
        file_dir = os.path.dirname(hf5_file)
        metrics_dir = os.path.join(file_dir, 'sorted_unit_metrics')

    unit_name = get_next_unit_name(hf5_file)
    metrics_dir = os.path.join(metrics_dir, unit_name)
    if not os.path.exists(metrics_dir):
        os.mkdir(metrics_dir)

    with open(sorting_log, 'a+') as log:
        print('%s sorted on %s' %
              (unit_name, dt.datetime.today().strftime('%m/%d/%y %H:%M')),
              file=log)
        print('Cluster info: \n----------', file=log)
        print_clust = deepcopy(cluster)
        # Get rid of data arrays in output cluster
        for k, v in cluster.items():
            if isinstance(v, np.ndarray):
                print_clust.pop(k)
        print(dp.print_dict(print_clust), file=log)
        print('Saving metrics to %s' % metrics_dir, file=log)
        print('--------------', file=log)

    with tables.open_file(hf5_file, 'r+') as hf5:
        table = hf5.root.unit_descriptor
        unit_descrip = table.row
        unit_descrip['electrode_number'] = int(cluster['electrode'])
        unit_descrip['single_unit'] = int(cluster['single_unit'])
        unit_descrip['regular_spiking'] = int(cluster['regular_spiking'])
        unit_descrip['fast_spiking'] = int(cluster['fast_spiking'])

        hf5.create_group('/sorted_units', unit_name, title=unit_name)
        waveforms = hf5.create_array('/sorted_units/%s' % unit_name,
                                     'waveforms', cluster['spike_waveforms'])
        times = hf5.create_array('/sorted_units/%s' % unit_name, 'times',
                                 cluster['spike_times'])
        unit_descrip.append()
        table.flush()
        hf5.flush()

    # Save metrics for sorted unit
    energy = cluster['data'][:, 0]
    amplitudes = cluster['data'][:, 1]
    pca_slices = cluster['data'][:, 2:]

    np.save(os.path.join(metrics_dir, 'spike_times.npy'),
            cluster['spike_times'])
    np.save(os.path.join(metrics_dir, 'spike_waveforms.npy'),
            cluster['spike_waveforms'])
    np.save(os.path.join(metrics_dir, 'energy.npy'), energy)
    np.save(os.path.join(metrics_dir, 'amplitudes.npy'), amplitudes)
    np.save(os.path.join(metrics_dir, 'pca_slices.npy'), pca_slices)
    clust_info_file = os.path.join(metrics_dir, 'cluster.info')
    with open(clust_info_file, 'a+') as log:
        print('%s sorted on %s' %
              (unit_name, dt.datetime.today().strftime('%m/%d/%y %H:%M')),
              file=log)
        print('Cluster info: \n----------', file=log)
        print(dp.print_dict(print_clust), file=log)
        print('Saved metrics to %s' % metrics_dir, file=log)
        print('--------------', file=log)

    print('Added %s to hdf5 store as %s' %
          (cluster['Cluster Name'], unit_name))
    print('Saved metrics to %s' % metrics_dir)
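
A hedged sketch of labeling a sorted cluster; the HDF5 path is hypothetical and the cluster dict fields ('electrode', 'single_unit', 'regular_spiking', 'fast_spiking', 'spike_waveforms', 'spike_times', 'data', 'Cluster Name') are assumed to be filled in by the sorting steps above:

h5_file = '/data/my_recording/my_recording.h5'   # hypothetical HDF5 store
cluster['single_unit'] = 1                       # mark as a single unit
cluster['regular_spiking'] = 1                   # regular-spiking pyramidal cell
cluster['fast_spiking'] = 0
label_single_unit(h5_file, cluster, fs=30000)
# metrics are saved to sorted_unit_metrics/<unit_name> next to the h5 file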