def partition_responses(rmap_bins=36,
                        table='cell_information',
                        quality='None',
                        smooth=True,
                        just_bins=False):
    from scanr.cluster import PrincipalCellCriteria, get_min_quality_criterion, AND
    from scanr.data import get_node
    from scanr.session import SessionData
    import numpy as np

    pbins = np.array(
        [1.1, 0.5, 0.1, 0.05, 0.03, 0.02, 0.01, 0.005, 0.002, 0.0])
    ibins = np.linspace(0, 6, 13)
    if just_bins:
        return pbins, ibins
    N_pbins, N_ibins = len(pbins) - 1, len(ibins) - 1

    R = [[[] for j in xrange(N_ibins)] for i in xrange(N_pbins)]
    cell_table = get_node('/physiology', table)

    cell_criteria = AND(PrincipalCellCriteria,
                        get_min_quality_criterion(quality))

    for cell in cell_table.where('(area=="CA3")|(area=="CA1")'):

        session = SessionData.get((cell['rat'], cell['day'], cell['session']),
                                  load_clusters=False)
        cluster = session.cluster_data(cell['tc'])

        if not (cell['N_running'] > 30 and cell_criteria.filter(cluster)):
            continue

        pix = (cell['p_value'] <= pbins).nonzero()[0]
        iix = (cell['I'] >= ibins).nonzero()[0]

        if not len(pix) or not (0 <= pix[-1] < N_pbins):
            continue
        if not len(iix) or not (0 <= iix[-1] < N_ibins):
            continue

        pix = pix[-1]
        iix = iix[-1]

        R[pix][iix].append(
            session.get_cluster_ratemap(cluster,
                                        bins=rmap_bins,
                                        smoothing=smooth,
                                        blur_width=360. / rmap_bins,
                                        exclude_off_track=True,
                                        exclude=session.scan_and_pause_list))

        print '...added %s to p-value bin %.4f, info bin %.2f...' % (
            cell['tc'], pbins[pix], ibins[iix])

    for i, row in enumerate(R):
        for j, rmap in enumerate(row):
            R[i][j] = r = np.asarray(rmap)
            if not len(rmap):
                continue
            R[i][j] = r[np.argsort(np.argmax(r, axis=1))]

    return R
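A hedged usage sketch: the return value is an (N_pbins x N_ibins) grid of ratemap stacks, each stack row-sorted by peak position. This assumes the scanr package and its data store are available; 'fair' is a hypothetical quality level mirroring min_quality values used in the later examples.

pbins, ibins = partition_responses(just_bins=True)  # bin edges only, no data walk
R = partition_responses(rmap_bins=36, quality='fair')
for i in xrange(len(pbins) - 1):
    for j in xrange(len(ibins) - 1):
        stack = R[i][j]  # (N_cells, rmap_bins) array, rows sorted by peak bin
        if len(stack):
            print '%d responses at p <= %.4f, I >= %.2f' % (
                len(stack), pbins[i], ibins[j])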
Example #2
    def collect_data(self,
                     session,
                     min_quality='fair',
                     alpha_color=False,
                     tetrodes=None,
                     plot_track=True):
        """Generate per-lap plots for each cluster of spikes on a trajectory
        """
        if isinstance(session, tuple) and len(session) == 3:
            from scanr.session import SessionData
            session = SessionData(rds=session)

        traj = session.trajectory
        ts, x, y = traj.ts, traj.x, traj.y

        self.out("Quality filter: at least %s" % min_quality)
        criteria = AND(PrincipalCellCriteria,
                       get_min_quality_criterion(min_quality))

        if tetrodes is not None:
            criteria = AND(criteria,
                           get_tetrode_restriction_criterion(tetrodes))

        for tc, data, lap, ax in self.get_plot(session, criteria):

            start, end = lap
            t, xy = time_slice_sample(ts, np.c_[x, y], start=start, end=end)
            lap_x, lap_y = xy.T

            t, xy = time_slice_sample(data.spikes,
                                      np.c_[data.x, data.y],
                                      start=start,
                                      end=end)
            spike_x, spike_y = xy.T

            ax.plot(lap_x, lap_y, **TRAJ_FMT)

            if len(spike_x):
                if alpha_color:
                    alpha = xy_to_deg_vec(spike_x, spike_y)
                    ax.scatter(spike_x,
                               spike_y,
                               c=alpha,
                               vmin=0,
                               vmax=360,
                               **SPIKE_FMT)
                else:
                    ax.scatter(spike_x, spike_y, **SPIKE_FMT)

            if plot_track:
                plot_track_underlay(ax, lw=0.5, ls='dotted')

            ax.axis('equal')
            ax.set_axis_off()
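A minimal call sketch, assuming an instance `fig` of the enclosing analysis class; the (rat, day, session) triple and tetrode numbers are hypothetical:

fig.collect_data((72, 5, 1),          # hypothetical (rat, day, session) triple
                 min_quality='good',  # tighten the isolation-quality filter
                 alpha_color=True,    # color spikes by track angle (0-360 deg)
                 tetrodes=[3, 7])     # restrict to these tetrode numbers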
Example #3
    def run(self,
            test='place',
            place_field='pass',
            min_quality='fair',
            **kwds):
        """Compute I_pos and I_spike across all criterion place cells in CA3/CA1

        Keyword arguments:
        place_field -- 'pass', 'fail', or 'all' to restrict responses based on place
            field criterion test results
        test -- 'place', 'skaggs', or 'olypher' to use either the full place field test or
            one of the component tests for the cell filtering
        min_quality -- isolation quality threshold for filtering cells

        Remaining keywords are passed to TetrodeSelect.

        Returns (I_pos, I_spike) tuple of arrays for selected cell clusters.
        """
        self.out = CPrint(prefix='ScatterInfo')
        area_query = '(area=="CA3")|(area=="CA1")'

        # Metadata for the plot title
        self.place_field = place_field
        self.test = test
        self.quality = min_quality
        if place_field == 'all':
            self.test = 'place'

        if test == 'place':
            SpatialTest = SpatialInformationCriteria
        elif test == 'skaggs':
            SpatialTest = SkaggsCriteria
        elif test == 'olypher':
            SpatialTest = OlypherCriteria
        else:
            raise ValueError('bad test value: %s' % test)

        MinQuality = get_min_quality_criterion(min_quality)
        CellCriteria = AND(PrincipalCellCriteria, SpikeCountCriteria,
                           MinQuality)
        if place_field == 'pass':
            CellCriteria = AND(CellCriteria, SpatialTest)
        elif place_field == 'fail':
            CellCriteria = AND(CellCriteria, NOT(SpatialTest))
        elif place_field != 'all':
            raise ValueError('bad place_field value: %s' % place_field)

        I = []
        for dataset in TetrodeSelect.datasets(area_query):
            rat, day = dataset
            Criteria = AND(
                CellCriteria,
                TetrodeSelect.criterion(dataset, area_query, **kwds))

            for maze in get_maze_list(*dataset):
                data = SessionData.get((rat, day, maze))

                for tc in data.get_clusters(request=Criteria):
                    cluster = data.cluster_data(tc)
                    I.append((cluster.I_pos, cluster.I_spike))

        self.I = I = np.array(I).T
        self.out('%d cell-sessions counted.' % I.shape[1])
        return I[0], I[1]
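The returned arrays plot directly as a scatter of positional vs. per-spike information; a sketch assuming an instance `ana` of the enclosing class and matplotlib (labels are interpretive, not from the source):

import matplotlib.pyplot as plt

I_pos, I_spike = ana.run(test='place', place_field='pass', min_quality='fair')
plt.scatter(I_pos, I_spike, s=4)
plt.xlabel('I_pos')
plt.ylabel('I_spike')
plt.title('%d cell-sessions' % I_pos.size)
plt.show()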
Example #4
    def collect_data(self,
                     test='place',
                     place_field='pass',
                     min_quality='fair',
                     allow_ambiguous=True):
        """Tally place fields across areas

        Keyword arguments similar to info_scores.InfoScoreData. Remaining
        keywords are passed to TetrodeSelect.
        """
        # Metadata for determining valid fields
        self.results['test'] = test
        self.results['place_field'] = place_field
        self.results['min_quality'] = min_quality
        self.results['allow_ambiguous'] = allow_ambiguous
        if place_field == 'all':
            self.results['test'] = 'place'

        # Construct place cell selection criteria based on keyword arguments
        if test == 'place':
            SpatialTest = SpatialInformationCriteria
        elif test == 'skaggs':
            SpatialTest = SkaggsCriteria
        elif test == 'olypher':
            SpatialTest = OlypherCriteria
        else:
            raise ValueError('bad test value: %s' % test)
        MinQuality = get_min_quality_criterion(min_quality)
        CellCriteria = AND(PrincipalCellCriteria, SpikeCountCriteria,
                           MinQuality)
        if place_field == 'pass':
            CellCriteria = AND(CellCriteria, SpatialTest)
        elif place_field == 'fail':
            CellCriteria = AND(CellCriteria, NOT(SpatialTest))
        elif place_field != 'all':
            raise ValueError('bad place_field value: %s' % place_field)

        # Walk the tree and count place fields
        N = {}
        N_cells = {}
        N_sessions = {}
        sessions = set()
        tetrodes = get_node('/metadata', 'tetrodes')
        for area in AREAS.keys():
            for subdiv in (['all'] + AREAS[area]):
                self.out('Walking datasets for %s %s...' % (area, subdiv))
                key = '%s_%s' % (area, subdiv)
                N[key] = 0
                N_cells[key] = 0
                N_sessions[key] = 0

                area_query = 'area=="%s"' % area
                if subdiv != 'all':
                    area_query = '(%s)&(subdiv=="%s")' % (area_query, subdiv)

                for dataset in TetrodeSelect.datasets(
                        area_query, allow_ambiguous=allow_ambiguous):
                    Criteria = AND(
                        CellCriteria,
                        TetrodeSelect.criterion(
                            dataset,
                            area_query,
                            allow_ambiguous=allow_ambiguous))
                    dataset_cells = set()

                    for maze in get_maze_list(*dataset):
                        rds = dataset + (maze, )
                        data = SessionData.get(rds)
                        sessions.add(rds)
                        place_cell_clusters = data.get_clusters(
                            request=Criteria)
                        N[key] += len(place_cell_clusters)
                        dataset_cells.update(place_cell_clusters)
                        N_sessions[key] += 1

                    N_cells[key] += len(dataset_cells)

        self.out.timestamp = False
        self.results['N'] = N
        self.out('Total number of sessions = %d' % len(sessions))
        for key in sorted(N.keys()):
            self.out('N_cells[%s] = %d cells' % (key, N_cells[key]))
            self.out('N_sessions[%s] = %d sessions' % (key, N_sessions[key]))
            self.out('N_cell_sessions[%s] = %d cell-sessions' % (key, N[key]))

        # Good-bye
        self.out('All done!')
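Only the per-key tallies in N are stored on the results; a hypothetical read-back, where `ana` is an instance that has run collect_data:

N = ana.results['N']            # cell-session counts per (area, subdiv)
key = '%s_%s' % ('CA1', 'all')  # keys are '<area>_<subdiv>'; 'all' pools subdivisions
print '%s: %d cell-sessions' % (key, N[key])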
Example #5
    def collect_data(self, min_quality='fair', shuffle_samples=200,
        shuffle_retry=20):
        """Collate and summarize statistics of scan vs. non-scan activity for
        MEC and LEC clusters

        Keyword arguments:
        min_quality -- if quality_filter is False, then this is the threshold for
            cluster isolation quality used to choose cells
        shuffle_samples -- number of randomized samples for empirical p-values
        """
        self.results['quality'] = min_quality
        Quality = get_min_quality_criterion(min_quality)
        self.out('Cluster quality: at least %s'%min_quality)

        scan_table = get_node('/behavior', 'scans')
        scan_firing_data = dict(MEC=[], LEC=[])

        LEC_datasets = self._get_datasets("area=='LEC'")
        MEC_datasets = self._get_datasets("area=='MEC'")
        area_list = ['LEC']*len(LEC_datasets) + ['MEC']*len(MEC_datasets)
        dataset_list = LEC_datasets+MEC_datasets

        # Set up spreadsheet output
        fn = os.path.join(self.datadir, 'entorhinal_scan_firing.csv')
        cols = [('dataset', 's'), ('maze', 'd'), ('type', 's'),
                ('parameter', 'd'), ('area', 's'), ('cluster', 's'),
                ('N_scans', 'd'), ('scan_rate_mean', 'f'), ('scan_rate_sd', 'f'),
                ('scan_rate_overall', 'f'), ('nonscan_rate_overall', 'f'),
                ('scan_nonscan_ratio', 'f') ]
        spreadsheet = DataSpreadsheet(fn, cols)
        record = spreadsheet.get_record()
        self.out('Record string: ' + spreadsheet.get_record_string())

        for area, dataset in zip(area_list, dataset_list):
            dataset_str = 'rat%03d-%02d'%dataset
            rat, day = dataset
            self.out('Analyzing %s for area %s.'%(dataset_str, area))
            area_query = "area=='%s'"%area

            record['dataset'] = dataset_str
            record['area'] = area

            for maze in data.session_list(rat, day):
                rds = rat, day, maze
                record['maze'] = maze

                Tetrodes = get_tetrode_restriction_criterion(
                    self._get_valid_tetrodes(dataset, area_query))
                Criteria = AND(PrincipalCellCriteria, Quality, Tetrodes)

                session_data = session.SessionData(rds=rds,
                    cluster_criteria=Criteria)
                total_time = session_data.duration
                record['type'] = session_data.data_group._v_attrs['type']
                record['parameter'] = session_data.data_group._v_attrs['parameter']

                scan_list = [tuple(scan['tlim']) for scan in
                    scan_table.where(session_data.session_query)]
                scan_list.sort()
                if len(scan_list) == 0:
                    continue
                record['N_scans'] = len(scan_list)

                for tc in session_data.get_clusters():
                    record['cluster'] = tc
                    cl_data = session_data.clusts[tc]
                    ts_spikes = cl_data.spikes

                    # Cell-session statistics
                    total_spikes = ts_spikes.size
                    session_firing_rate = total_spikes / total_time

                    # Initialize per-scan accumulators
                    scan_counts = []
                    scan_durations = []
                    scan_rates = []
                    scan_pvals = []
                    scan_norm = []

                    for start, end in scan_list:
                        scan_spikes = time.time_slice_sample(ts_spikes,
                            start=start, end=end)

                        scan_counts.append(scan_spikes.size)
                        this_scan_duration = time.elapsed(start, end)
                        scan_durations.append(this_scan_duration)
                        firing_rate = scan_spikes.size / this_scan_duration
                        scan_rates.append(firing_rate)

                        # Randomized firing-distributions for one-sided p-values
                        delta_ts = end - start
                        shuffle = np.empty((shuffle_samples,), 'd')
                        for i in xrange(shuffle_samples):
                            c = 0
                            while c < shuffle_retry:
                                # Get random time segment of same length in session
                                rand_start = long(session_data.start +
                                    plt.randint(session_data.end -
                                        session_data.start - delta_ts))
                                rand_end = long(rand_start + delta_ts)

                                # Only accept if not colliding with another scan...
                                hit = False
                                for s, e in scan_list:
                                    if ((rand_start <= s <= rand_end) or
                                        (rand_start <= e <= rand_end) or
                                        (s < rand_start and e > rand_start)):
                                        hit = True
                                        break
                                if hit:
                                    c += 1
                                else:
                                    break
                            rand_spikes = time.time_slice_sample(ts_spikes,
                                start=rand_start, end=rand_end)
                            shuffle[i] = rand_spikes.size / this_scan_duration
                        p_val = (1+(shuffle > firing_rate).sum()) / float(1+shuffle_samples)
                        scan_pvals.append(p_val)
                        scan_norm.append(firing_rate / session_firing_rate)

                    # Overall scan firing rate
                    overall_scan_rate = np.sum(scan_counts) / np.sum(scan_durations)

                    # Finish spreadsheet entry for this cluster
                    record['scan_rate_mean'] = np.mean(scan_rates)
                    record['scan_rate_sd'] = np.std(scan_rates)
                    record['scan_rate_overall'] = overall_scan_rate
                    record['nonscan_rate_overall'] = \
                        (total_spikes-np.sum(scan_counts)) / (total_time-np.sum(scan_durations))
                    record['scan_nonscan_ratio'] = 0.0
                    if record['nonscan_rate_overall']:
                        record['scan_nonscan_ratio'] = overall_scan_rate / record['nonscan_rate_overall']
                    spreadsheet.write_record(record)

                    # Create the final record for this cell-session
                    scan_row = (  np.mean(scan_rates),
                                  np.median(scan_rates),
                                  overall_scan_rate,
                                  np.mean(scan_norm),
                                  np.median(scan_norm),
                                  record['scan_nonscan_ratio'],
                                  np.mean(scan_pvals),
                                  np.median(scan_pvals)   )

                    # Store the record in an area-specific list
                    scan_firing_data[area].append(scan_row)

        # Save data as numpy record arrays
        firing_data = [ ('mean_rate', float), ('median_rate', float), ('overall_rate', float),
                        ('mean_norm', float), ('median_norm', float), ('overall_norm', float),
                        ('mean_pval', float), ('median_pval', float)  ]
        self.results['LEC'] = np.rec.fromrecords(scan_firing_data['LEC'],
            dtype=firing_data)
        self.results['MEC'] = np.rec.fromrecords(scan_firing_data['MEC'],
            dtype=firing_data)

        spreadsheet.close()
        self.out('All done!')
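The last line of the shuffle loop is the standard add-one empirical p-value; a self-contained restatement, assuming only NumPy:

import numpy as np

def empirical_pvalue(observed, null_samples):
    """One-sided p-value with add-one smoothing, as in the shuffle loop above.

    p = (1 + #{null > observed}) / (1 + N), so p is never exactly zero.
    """
    null_samples = np.asarray(null_samples)
    return (1 + (null_samples > observed).sum()) / float(1 + len(null_samples))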
Example #6
    def collect_data(self, tetrode_query="area=='CA1'", scan_time="max",
        scan_type=None, min_quality="fair", shuffle=False, **kwds):
        """Perform scan-firing xcorrs on selected tetrodes

        Arguments:
        tetrode_query -- query string against the /metadata/tetrodes table
        scan_time -- time point of scan to lock onto (must be "start", "end",
            "mid", "max")
        min_quality -- minimum cluster isolation quality (string name or
            ClusterQuality value) for inclusion of data
        shuffle -- whether to randomly shuffle scan times

        Keywords are passed to the spike.xcorr function.
        """
        assert scan_time in ("start", "end", "mid", "max"), \
            "bad scan_time: %s"%scan_time

        ttable = get_node('/metadata', 'tetrodes')
        stable = get_node('/behavior', 'scans')

        units = 'ms' if kwds.get('use_millis') else 's'
        overall_psth = None
        rat_psth = {}

        Quality = get_min_quality_criterion(min_quality)
        self.out("Quality filter: at least %s"%min_quality)

        for ratday, ax in self.get_plot(self._get_datasets(tetrode_query)):
            valid_tetrodes = self._get_valid_tetrodes(ratday, tetrode_query)
            Tetrodes = get_tetrode_restriction_criterion(valid_tetrodes)
            rat, day = ratday
            psth = {}

            ClusterFilter = AND(Tetrodes, Quality, PrincipalCellCriteria)

            for maze in session_list(rat, day, exclude_timing_issue=True):
                data = session.SessionData(rds=(rat, day, maze))
                clusters = data.get_clusters(ClusterFilter)

                if shuffle:
                    scan_times = data.random_timestamp_array(
                        size=len(stable.getWhereList(data.session_query)))
                else:
                    query = data.session_query
                    if scan_type is not None:
                        query += "&(type==\'%s\')"%scan_type.upper()
                    scan_times = [s[scan_time] for s in stable.where(query)]
                t_scan = data.to_time(scan_times)
                for cl in clusters:
                    t_spikes = data.to_time(data.get_spike_train(cl))
                    C, lags = xcorr(t_scan, t_spikes, **kwds)
                    if cl not in psth:
                        psth[cl] = np.zeros_like(C)
                    psth[cl] += C

            if not psth:
                ax.set_axis_off()
                continue

            # Initialize aggregators for overall psth and per-rat psth
            if overall_psth is None:
                overall_psth = np.zeros_like(C)
            if rat not in rat_psth:
                rat_psth[rat] = np.zeros_like(C)

            drawn = False
            fmt = dict(lw=0.5, c='b')
            max_corr = max([max(corr) for corr in psth.values()])
            for cl in psth:
                overall_psth += psth[cl]
                rat_psth[rat] += psth[cl]
                psth[cl] /= max_corr
                plot_correlogram((psth[cl], lags), is_corr=True, ax=ax,
                    plot_type="lines", fmt=fmt, zero_line=not drawn)
                drawn = True
            ax.set_yticks([])
            if self.lastrow:
                ax.set_xlabel('Spike Lag (%s)'%units)
            else:
                ax.set_xticks([])

        # For plotting overall and per-rat cross-correlograms
        self.results['lags'] = lags
        self.results['C_overall'] = overall_psth
        self.results['C_rat'] = rat_psth
        self.results['query'] = tetrode_query
        self.results['scan_time'] = scan_time
        self.results['shuffled'] = shuffle
        self.results['units'] = units
        self.results['scan_type'] = scan_type

        self.out('All done!')
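Before plotting, each cluster's accumulated PSTH is divided by the maximum correlation across all clusters in the dataset, putting every trace on a shared [0, 1] scale; a toy restatement with hypothetical values:

import numpy as np

psth = {'t1c1': np.array([1., 4., 2.]),   # hypothetical per-cluster correlograms
        't2c3': np.array([0., 8., 3.])}
max_corr = max(corr.max() for corr in psth.values())
normed = {cl: corr / max_corr for cl, corr in psth.items()}
# normed['t2c3'] peaks at 1.0; every other trace is scaled relative to it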
Example #7
    def collect_data(self, area_query='(area=="CA3")|(area=="CA1")'):
        """Collect firing rate data across scans
        """
        datasets = TetrodeSelect.datasets(area_query)
        tetrode_table = get_node('/metadata', 'tetrodes')
        scan_table = get_node('/behavior', 'scans')

        epochs = (  'rate', 'running_rate', 'pause_rate', 'scan_rate',
                    'interior_rate', 'exterior_rate',
                    'outbound_rate', 'inbound_rate',
                    'ext_out_rate', 'ext_in_rate',
                    'int_out_rate', 'int_in_rate' )

        spreadsheet = DataSpreadsheet(
            os.path.join(self.datadir, 'scan_firing_rates.csv'),
            [   ('dataset', 's'), ('rat', 'd'), ('day', 'd'),
                ('area', 's'), ('area_sub', 's'), ('cell', 's') ] +
            [(n, 'f') for n in epochs])
        self.out('Record string: %s'%spreadsheet.get_record_string())
        record = spreadsheet.get_record()

        # Index labels for the scan data
        PRE, START, MAX, END = 0, 1, 2, 3

        for dataset in datasets:

            rat, day = dataset
            dataset_str = 'rat%03d-%02d'%dataset
            self.out('Calculating scan firing rates for %s...'%dataset_str)

            # Set dataset info
            record['dataset'] = dataset_str
            record['rat'] = rat
            record['day'] = day

            # Cell accumulators
            collated_cells = []
            N = {}
            T = {}

            def increment(tc, which, count, duration):
                N[tc][which] += count
                T[tc][which] += duration

            for maze in get_maze_list(rat, day):
                rds = rat, day, maze
                data = SessionData.get(rds)
                traj = data.trajectory

                def occupancy(traj_occupied):
                    return \
                        data.duration * (np.sum(traj_occupied) / float(traj.N))

                Criteria = AND(PlaceCellCriteria,
                    TetrodeSelect.criterion(dataset, area_query))

                for tc in data.get_clusters(Criteria):
                    cluster = data.cluster_data(tc)

                    if tc not in collated_cells:
                        collated_cells.append(tc)
                        N[tc] = { k: 0 for k in epochs }
                        T[tc] = { k: 0.0 for k in epochs }

                    spikes = cluster.spikes

                    increment(tc, 'rate', cluster.N, data.duration)
                    increment(tc, 'running_rate',
                        data.velocity_filter(spikes).sum(),
                        occupancy(data.velocity_filter(traj.ts)))

                    increment(tc, 'scan_rate',
                        np.sum(select_from(spikes, data.scan_list)),
                        occupancy(select_from(traj.ts, data.scan_list)))

                    increment(tc, 'pause_rate',
                        np.sum(select_from(spikes, data.pause_list)),
                        occupancy(select_from(traj.ts, data.pause_list)))

                    ext_scan_list = np.array(
                        [(rec['prepause'], rec['start'], rec['max'], rec['end'])
                            for rec in scan_table.where(data.session_query +
                                '&(type=="%s")'%EXTERIOR)])

                    int_scan_list = np.array(
                        [(rec['prepause'], rec['start'], rec['max'], rec['end'])
                            for rec in scan_table.where(data.session_query +
                                '&(type=="%s")'%INTERIOR)])

                    both_scan_list = np.array(
                        [(rec['prepause'], rec['start'], rec['max'], rec['end'])
                            for rec in scan_table.where(data.session_query +
                                '&(type!="%s")'%AMBIG)])

                    if ext_scan_list.shape[0]:
                        increment(tc, 'exterior_rate',
                            np.sum(select_from(spikes, ext_scan_list[:,(START,END)])),
                            occupancy(select_from(traj.ts, ext_scan_list[:,(START,END)])))

                        increment(tc, 'ext_out_rate',
                            np.sum(select_from(spikes, ext_scan_list[:,(START,MAX)])),
                            occupancy(select_from(traj.ts, ext_scan_list[:,(START,MAX)])))

                        increment(tc, 'ext_in_rate',
                            np.sum(select_from(spikes, ext_scan_list[:,(MAX,END)])),
                            occupancy(select_from(traj.ts, ext_scan_list[:,(MAX,END)])))

                    if int_scan_list.shape[0]:
                        increment(tc, 'interior_rate',
                            np.sum(select_from(spikes, int_scan_list[:,(START,END)])),
                            occupancy(select_from(traj.ts, int_scan_list[:,(START,END)])))

                        increment(tc, 'int_out_rate',
                            np.sum(select_from(spikes, int_scan_list[:,(START,MAX)])),
                            occupancy(select_from(traj.ts, int_scan_list[:,(START,MAX)])))

                        increment(tc, 'int_in_rate',
                            np.sum(select_from(spikes, int_scan_list[:,(MAX,END)])),
                            occupancy(select_from(traj.ts, int_scan_list[:,(MAX,END)])))

                    if both_scan_list.shape[0]:
                        increment(tc, 'outbound_rate',
                            np.sum(select_from(spikes, both_scan_list[:,(START,MAX)])),
                            occupancy(select_from(traj.ts, both_scan_list[:,(START,MAX)])))

                        increment(tc, 'inbound_rate',
                            np.sum(select_from(spikes, both_scan_list[:,(MAX,END)])),
                            occupancy(select_from(traj.ts, both_scan_list[:,(MAX,END)])))

            def firing_rate(tc, k):
                if T[tc][k]:
                    return N[tc][k] / T[tc][k]
                return 0.0

            self.out('Writing out spreadsheet records...')
            for tc in collated_cells:
                self.out.printf('.')

                tt, cl = parse_cell_name(tc)
                tetrode = get_unique_row(tetrode_table,
                    '(rat==%d)&(day==%d)&(tt==%d)'%(rat, day, tt))

                record['area'] = tetrode['area']
                record['area_sub'] = tetrode['subdiv']
                record['cell'] = tc

                record.update({ k: firing_rate(tc, k) for k in epochs })

                spreadsheet.write_record(record)
            self.out.printf('\n')

        # Finish up
        spreadsheet.close()
        self.out('All done!')
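The increment/firing_rate pair computes pooled rates: spike counts and occupancy durations are summed per epoch across a dataset's sessions, with a single division at the end (not a mean of per-session rates). A minimal restatement with hypothetical tallies:

N, T = 0, 0.0
for count, duration in [(12, 60.0), (30, 120.0)]:  # (spikes, seconds) per session
    N += count
    T += duration
rate = N / T if T else 0.0  # 42 spikes / 180 s = 0.233 spikes/s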
Example #8
    def collect_data(self, min_quality='fair', exclude_bad=False, ROI=None,
        exclude_zero_scans=False, only_m1=False, use_ec_table=False):
        """Collect normalized firing-rate information across region, scan type,
        and scan phase, for all valid clusters

        Arguments:
        min_quality -- minimum cluster isolation quality (string name or
            ClusterQuality value) for inclusion of data
        """
        self.results['quality'] = min_quality
        Quality = get_min_quality_criterion(min_quality)
        self.out('Cluster quality: at least %s'%min_quality)

        tetrode_table = get_node('/metadata', 'tetrodes')
        scan_table = get_node('/behavior', 'scans')

        # Only entorhinal if using the EC clusters table
        if use_ec_table:
            ec_table = get_node('/metadata', 'ec_clusters')
            ROI_areas = ['LEC', 'MEC']
        elif ROI:
            ROI_areas = ROI
        else:
            ROI_areas = spike.PrimaryAreas

        # Initialize main cluster data accumulator
        def new_arrays():
            return dict(INT=dict(   inbound=[], inbound_rate=[],
                                    inbound_rate_norm=[], inbound_diff=[],
                                    outbound=[], outbound_rate=[],
                                    outbound_rate_norm=[], outbound_diff=[]),
                        EXT=dict(   inbound=[], inbound_rate=[],
                                    inbound_rate_norm=[], inbound_diff=[],
                                    outbound=[], outbound_rate=[],
                                    outbound_rate_norm=[], outbound_diff=[]))
        cluster_data = { area: new_arrays() for area in ROI_areas }
        self.results['cluster_data'] = cluster_data

        scan_rate_data = { area: [] for area in ROI_areas }
        self.results['scan_rate_data'] = scan_rate_data

        scan_nonscan_diff = { area: [] for area in ROI_areas }
        self.results['scan_nonscan_diff'] = scan_nonscan_diff

        type_by_phase = [[x, y] for x in ('INT', 'EXT')
            for y in ('inbound', 'outbound')]

        cell_counts = { area: 0 for area in ROI_areas }
        self.results['cell_counts'] = cell_counts

        Criteria = AND(PrincipalCellCriteria, Quality)

        spreadsheet = DataSpreadsheet(
            os.path.join(self.datadir, 'scan_firing_rates.csv'),
            [   ('rat', 'd'), ('day', 'd'), ('area', 's'), ('cell', 's'),
                ('rate', 'f'), ('scan_rate', 'f'), ('nonscan_rate', 'f'),
                ('scan_nonscan_diff', 'f'), ('INT_outbound_diff', 'f'),
                ('INT_inbound_diff', 'f'), ('EXT_outbound_diff', 'f'),
                ('EXT_inbound_diff', 'f') ])
        self.out('Record string: %s'%spreadsheet.get_record_string())
        record = spreadsheet.get_record()

        def count_spikes(ts, start, end):
            return len(time.time_slice_sample(ts, start=start, end=end))

        for dataset in meta.walk_days():
            rat, day = dataset
            record['rat'], record['day'] = rat, day
            self.out('Analyzing dataset rat%03d-%02d.'%dataset)

            dataset_duration = 0.0
            dataset_spikes = {}

            N = dict(   INT=dict(inbound={}, outbound={}),
                        EXT=dict(inbound={}, outbound={})   )

            duration = dict(   INT=dict(inbound={}, outbound={}),
                               EXT=dict(inbound={}, outbound={})   )

            collated = []

            for maze in meta.get_maze_list(rat, day):
                if only_m1 and maze != 1:
                    continue

                rds = rat, day, maze
                grp = get_group(rds=rds)
                if exclude_bad:
                    attrs = grp._v_attrs
                    if attrs['HD_missing'] or attrs['timing_jumps']:
                        self.out('...bad dataset, skipping...')
                        continue

                if use_ec_table:
                    request = ['t%dc%d'%(rec['tt'], rec['cluster']) for
                        rec in ec_table.where(
                            '(rat==%d)&(day==%d)&(session==%d)'%rds)]
                    if request:
                        self.out('Found EC clusters: %s'%str(request)[1:-1])
                    else:
                        self.out('...no EC table clusters, skipping...')
                        continue
                else:
                    request = Criteria

                session_data = session.SessionData(rds=rds)
                dataset_duration += session_data.duration
                scan_list = \
                    [(scan['type'], scan['start'], scan['max'], scan['end'])
                        for scan in scan_table.where(session_data.session_query)
                        if scan['type'] != 'AMB']

                # Collate the spike counts, accumulating cells across sessions
                for tc in session_data.get_clusters(request):
                    cl_data = session_data.clusts[tc]
                    area = spike.get_tetrode_area(rat, day, cl_data.tt)
                    if area not in ROI_areas:
                        continue

                    # Initialize accumulators if first occurrence of this cell
                    if (tc, area) not in collated:
                        collated.append((tc, area))
                        dataset_spikes[tc] = 0
                        for t, p in type_by_phase:
                            N[t][p][tc] = 0
                            duration[t][p][tc] = 0.0

                    t_spikes = cl_data.spikes
                    dataset_spikes[tc] += t_spikes.size

                    for scan_type, start, scan_max, end in scan_list:
                        if exclude_zero_scans and (count_spikes(t_spikes, start,
                            end) == 0):
                            continue

                        N[scan_type]['outbound'][tc] += \
                            count_spikes(t_spikes, start, scan_max)
                        N[scan_type]['inbound'][tc] += \
                            count_spikes(t_spikes, scan_max, end)
                        duration[scan_type]['outbound'][tc] += \
                            time.elapsed(start, scan_max)
                        duration[scan_type]['inbound'][tc] += \
                            time.elapsed(scan_max, end)

            self.out('Computing firing rates for %d cells...'%len(collated))
            for tc, area in collated:
                N_total = float(sum([N[t][p][tc] for t, p in type_by_phase]))
                duration_total = sum([duration[t][p][tc] for t, p in type_by_phase])
                if not duration_total:
                    continue

                record['cell'] = tc
                record['area'] = area

                scan_rate = N_total / duration_total
                scan_rate_data[area].append(scan_rate)
                record['scan_rate'] = scan_rate

                overall_rate = dataset_spikes[tc] / dataset_duration
                overall_nonscan_rate = \
                    (dataset_spikes[tc] - N_total) / (dataset_duration - duration_total)
                record['rate'] = overall_rate
                record['nonscan_rate'] = overall_nonscan_rate

                scan_nonscan_diff[area].append((scan_rate - overall_nonscan_rate) /
                    (scan_rate + overall_nonscan_rate))
                record['scan_nonscan_diff'] = scan_nonscan_diff[area][-1]

                cell_counts[area] += 1
                c_dict = cluster_data[area]
                for t, p in type_by_phase:
                    key = '%s_%s_diff'%(t, p)
                    record[key] = -99
                    if not (N_total and duration[t][p][tc]):
                        continue
                    c_dict[t][p].append(N[t][p][tc]/N_total)
                    this_scan_rate = N[t][p][tc]/duration[t][p][tc]
                    c_dict[t][p+'_rate'].append(this_scan_rate)
                    if this_scan_rate + scan_rate != 0:
                        c_dict[t][p+'_rate_norm'].append(
                            (this_scan_rate-scan_rate) /
                                (this_scan_rate+scan_rate))
                    if this_scan_rate + overall_nonscan_rate != 0:
                        c_dict[t][p+'_diff'].append(
                            (this_scan_rate-overall_nonscan_rate) /
                                (this_scan_rate+overall_nonscan_rate))
                        record[key] = c_dict[t][p+'_diff'][-1]
                spreadsheet.write_record(record)

        # Convert data to arrays
        for area in ROI_areas:
            self.out('Total cell count for %s: %d cells'%(area,
                cell_counts[area]))
            scan_rate_data[area] = np.array(scan_rate_data[area])
            scan_nonscan_diff[area] = np.array(scan_nonscan_diff[area])
            for scan_type in 'INT', 'EXT':
                c_dict = cluster_data[area][scan_type]
                for k in c_dict:
                    c_dict[k] = np.array(c_dict[k])

        spreadsheet.close()
        self.out("All done!")
Example #9
    def create_scan_cell_table(self, scan_phase='scan'):
        """For every scan–cell pair, compute the relative index of cell firing that
        occurred during the scan and previous cell firing on the track
        """
        scan_table_description = {
            'id': tb.UInt32Col(pos=1),
            'scan_id': tb.UInt16Col(pos=2),
            'rat': tb.UInt16Col(pos=3),
            'day': tb.UInt16Col(pos=4),
            'session': tb.UInt16Col(pos=5),
            'session_start_angle': tb.FloatCol(pos=6),
            'session_end_angle': tb.FloatCol(pos=7),
            'tc': tb.StringCol(itemsize=8, pos=8),
            'type': tb.StringCol(itemsize=4, pos=9),
            'expt_type': tb.StringCol(itemsize=4, pos=10),
            'area': tb.StringCol(itemsize=4, pos=11),
            'subdiv': tb.StringCol(itemsize=4, pos=12),
            'duration': tb.FloatCol(pos=13),
            'magnitude': tb.FloatCol(pos=14),
            'angle': tb.FloatCol(pos=15)
        }

        def add_scan_index_column_descriptors(descr):
            pos = 16
            for name in ScanIndex.AllNames:
                descr[name] = tb.FloatCol(pos=pos)
                pos += 1

        add_scan_index_column_descriptors(scan_table_description)

        data_file = self.get_data_file(mode='a')
        scan_cell_table = create_table(data_file,
                                       '/',
                                       'scan_cell_info',
                                       scan_table_description,
                                       title='Metadata for Scan-Cell Pairs')
        scan_cell_table._v_attrs['scan_phase'] = scan_phase
        row = scan_cell_table.row
        row_id = 0

        scans_table = get_node('/behavior', 'scans')
        sessions_table = get_node('/metadata', 'sessions')
        tetrodes_table = get_node('/metadata', 'tetrodes')

        cornu_ammonis_query = '(area=="CA1")|(area=="CA3")'
        hippocampal_datasets = unique_datasets('/metadata',
                                               'tetrodes',
                                               condn=cornu_ammonis_query)

        quality_place_cells = AND(get_min_quality_criterion(self.min_quality),
                                  PlaceCellCriteria)

        index = ScanIndex(scan_phase=scan_phase)

        for dataset in hippocampal_datasets:
            dataset_query = '(rat==%d)&(day==%d)' % dataset

            hippocampal_tetrodes = unique_values(
                tetrodes_table,
                column='tt',
                condn='(%s)&(%s)' % (dataset_query, cornu_ammonis_query))
            cluster_criteria = AND(
                quality_place_cells,
                get_tetrode_restriction_criterion(hippocampal_tetrodes))

            for maze in get_maze_list(*dataset):
                rds = dataset + (maze, )
                session = SessionData(rds=rds)
                place_cells = session.get_clusters(cluster_criteria)
                session_start_angle = np.median(
                    session.trajectory.alpha_unwrapped[:5])
                session_end_angle = np.median(
                    session.trajectory.alpha_unwrapped[-5:])

                self.out('Computing scan index for %s...' %
                         session.data_group._v_pathname)

                for scan in scans_table.where(session.session_query):
                    self.out.printf('|', color='cyan')

                    for cell in place_cells:
                        cluster = session.cluster_data(cell)

                        tt, cl = parse_cell_name(cluster.name)
                        tetrode = get_unique_row(
                            tetrodes_table, '(rat==%d)&(day==%d)&(tt==%d)' %
                            (rds[0], rds[1], tt))

                        row['id'] = row_id
                        row['scan_id'] = scan['id']
                        row['rat'], row['day'], row['session'] = rds
                        row['session_start_angle'] = session_start_angle
                        row['session_end_angle'] = session_end_angle
                        row['tc'] = cluster.name
                        row['type'] = session.attrs['type']
                        row['expt_type'] = get_unique_row(
                            sessions_table, session.session_query)['expt_type']
                        row['area'] = tetrode['area']
                        row['subdiv'] = tetrode['area'] + tetrode['subdiv'][:1]
                        row['angle'] = session.F_('alpha_unwrapped')(
                            session.T_(scan['start']))
                        row['duration'] = scan['duration']
                        row['magnitude'] = scan['magnitude']

                        for index_name in ScanIndex.AllNames:
                            row[index_name] = index.compute(
                                index_name, session, scan, cluster)

                        self.out.printf('.', color='green')
                        row_id += 1
                        row.append()

                        if row_id % 100 == 0:
                            scan_cell_table.flush()

                self.out.printf('\n')

        scan_cell_table.flush()
        self.out('Finished creating %s.' % scan_cell_table._v_pathname)
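A hedged read-back sketch against the table this method builds, using the same PyTables 2.x API seen elsewhere in these examples (where/getWhereList); `data_file` is assumed to be the handle returned by get_data_file():

table = data_file.getNode('/', 'scan_cell_info')
for pair in table.where('(area=="CA1")&(duration>1.0)'):
    print pair['tc'], pair['scan_id'], pair['angle']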