Example #1
def partition_responses(rmap_bins=36,
                        table='cell_information',
                        quality='None',
                        smooth=True,
                        just_bins=False):
    import numpy as np
    from scanr.cluster import PrincipalCellCriteria, get_min_quality_criterion, AND
    from scanr.data import get_node
    from scanr.session import SessionData

    pbins = np.array(
        [1.1, 0.5, 0.1, 0.05, 0.03, 0.02, 0.01, 0.005, 0.002, 0.0])
    ibins = np.linspace(0, 6, 13)
    if just_bins:
        return pbins, ibins
    N_pbins, N_ibins = len(pbins) - 1, len(ibins) - 1

    R = [[[] for j in xrange(N_ibins)] for i in xrange(N_pbins)]
    cell_table = get_node('/physiology', table)

    cell_criteria = AND(PrincipalCellCriteria,
                        get_min_quality_criterion(quality))

    for cell in cell_table.where('(area=="CA3")|(area=="CA1")'):

        session = SessionData.get((cell['rat'], cell['day'], cell['session']),
                                  load_clusters=False)
        cluster = session.cluster_data(cell['tc'])

        if not (cell['N_running'] > 30 and cell_criteria.filter(cluster)):
            continue

        pix = (cell['p_value'] <= pbins).nonzero()[0]
        iix = (cell['I'] >= ibins).nonzero()[0]

        if not len(pix) or not (0 <= pix[-1] < N_pbins):
            continue
        if not len(iix) or not (0 <= iix[-1] < N_ibins):
            continue

        pix = pix[-1]
        iix = iix[-1]

        R[pix][iix].append(
            session.get_cluster_ratemap(cluster,
                                        bins=rmap_bins,
                                        smoothing=smooth,
                                        blur_width=360. / rmap_bins,
                                        exclude_off_track=True,
                                        exclude=session.scan_and_pause_list))

        print '...added %s to p-value bin %.4f, info bin %.2f...' % (
            cell['tc'], pbins[pix], ibins[iix])

    for i, row in enumerate(R):
        for j, rmap in enumerate(row):
            R[i][j] = r = np.asarray(rmap)
            if not len(rmap):
                continue
            R[i][j] = r[np.argsort(np.argmax(r, axis=1))]

    return R
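
A minimal NumPy sketch (toy values, no scanr dependencies) of how the descending p-value edges and ascending information edges above map a single cell onto a (pix, iix) slot of the R grid; the p_value and I numbers are hypothetical:

import numpy as np

# same bin edges as partition_responses() above
pbins = np.array([1.1, 0.5, 0.1, 0.05, 0.03, 0.02, 0.01, 0.005, 0.002, 0.0])
ibins = np.linspace(0, 6, 13)

p_value, I = 0.004, 1.7                      # hypothetical cell statistics
pix = (p_value <= pbins).nonzero()[0][-1]    # last edge the p-value clears -> 7
iix = (I >= ibins).nonzero()[0][-1]          # last edge the info score clears -> 3
# the cell's ratemap would be appended to R[7][3]
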
Example #2
    def collect_data(self,
                     session,
                     min_quality='fair',
                     alpha_color=False,
                     tetrodes=None,
                     plot_track=True):
        """Generate per-lap plots for each cluster of spikes on a trajectory
        """
        if type(session) is tuple and len(session) == 3:
            from scanr.session import SessionData
            session = SessionData(rds=session)

        traj = session.trajectory
        ts, x, y = traj.ts, traj.x, traj.y

        self.out("Quality filter: at least %s" % min_quality)
        criteria = AND(PrincipalCellCriteria,
                       get_min_quality_criterion(min_quality))

        if tetrodes is not None:
            criteria = AND(criteria,
                           get_tetrode_restriction_criterion(tetrodes))

        for tc, data, lap, ax in self.get_plot(session, criteria):

            start, end = lap
            t, xy = time_slice_sample(ts, np.c_[x, y], start=start, end=end)
            lap_x, lap_y = xy.T

            t, xy = time_slice_sample(data.spikes,
                                      np.c_[data.x, data.y],
                                      start=start,
                                      end=end)
            spike_x, spike_y = xy.T

            ax.plot(lap_x, lap_y, **TRAJ_FMT)

            if len(spike_x):
                if alpha_color:
                    alpha = xy_to_deg_vec(spike_x, spike_y)
                    ax.scatter(spike_x,
                               spike_y,
                               c=alpha,
                               vmin=0,
                               vmax=360,
                               **SPIKE_FMT)
                else:
                    ax.scatter(spike_x, spike_y, **SPIKE_FMT)

            if plot_track:
                plot_track_underlay(ax, lw=0.5, ls='dotted')

            ax.axis('equal')
            ax.set_axis_off()
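
The lap and spike slicing above both go through time_slice_sample; a rough stand-alone equivalent (an assumption about that helper's behaviour, not its actual implementation) is simple boolean masking on the timestamp array:

import numpy as np

def time_slice(ts, values, start, end):
    # keep samples whose timestamps fall inside [start, end]
    keep = (ts >= start) & (ts <= end)
    return ts[keep], values[keep]

# e.g. per-lap slicing of a trajectory sampled at times `ts`:
# lap_t, lap_xy = time_slice(ts, np.c_[x, y], lap_start, lap_end)
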
Example #3
    def run(self,
            test='place',
            place_field='pass',
            min_quality='fair',
            **kwds):
        """Compute I_pos and I_spike across all criterion place cells in CA3/CA1

        Keyword arguments:
        test -- 'place', 'skaggs', or 'olypher' to use the full place-field test
            or one of its component tests for cell filtering
        place_field -- 'pass', 'fail', or 'all' to restrict responses based on
            place field criterion test results
        min_quality -- isolation quality threshold for filtering cells

        Remaining keywords are passed to TetrodeSelect.

        Returns (I_pos, I_spike) tuple of arrays for selected cell clusters.
        """
        self.out = CPrint(prefix='ScatterInfo')
        area_query = '(area=="CA3")|(area=="CA1")'

        # Metadata for the plot title
        self.place_field = place_field
        self.test = test
        self.quality = min_quality
        if place_field == 'all':
            self.test = 'place'

        if test == 'place':
            SpatialTest = SpatialInformationCriteria
        elif test == 'skaggs':
            SpatialTest = SkaggsCriteria
        elif test == 'olypher':
            SpatialTest = OlypherCriteria
        else:
            raise ValueError('bad test value: %s' % test)

        MinQuality = get_min_quality_criterion(min_quality)
        CellCriteria = AND(PrincipalCellCriteria, SpikeCountCriteria,
                           MinQuality)
        if place_field == 'pass':
            CellCriteria = AND(CellCriteria, SpatialTest)
        elif place_field == 'fail':
            CellCriteria = AND(CellCriteria, NOT(SpatialTest))
        elif place_field != 'all':
            raise ValueError('bad place_field value: %s' % place_field)

        I = []
        for dataset in TetrodeSelect.datasets(area_query):
            rat, day = dataset
            Criteria = AND(
                CellCriteria,
                TetrodeSelect.criterion(dataset, area_query, **kwds))

            for maze in get_maze_list(*dataset):
                data = SessionData.get((rat, day, maze))

                for tc in data.get_clusters(request=Criteria):
                    cluster = data.cluster_data(tc)
                    I.append((cluster.I_pos, cluster.I_spike))

        self.I = I = np.array(I).T
        self.out('%d cell-sessions counted.' % I.shape[1])
        return I[0], I[1]
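
A hypothetical driver snippet for run() above (the analysis object name `ana` is assumed); the two returned arrays plot naturally as a scatter of the paired information scores:

import matplotlib.pyplot as plt

I_pos, I_spike = ana.run(test='place', place_field='pass', min_quality='fair')
plt.scatter(I_pos, I_spike, s=4)
plt.xlabel('I_pos')
plt.ylabel('I_spike')
plt.show()
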
Example #4
    def collect_data(self,
                     min_quality='fair',
                     exclude_table_names=('novel_fields',
                                          'potentiation_0.5_tol_0.2')):
        """Excluding place fields from the tables listed in exclude_table_names
        under /physiology, categorize whether scans within the field fire or
        not based on the normalized distance through the field where the
        scan occurs. This should illustrate any field-based scan firing bias
        that may explain the observed bias in the scan-field cross-correlation
        data in scanr.ana.predictive.TrackAngleCorrelation.
        """
        area_query = '(area=="CA1")|(area=="CA3")'

        # Check modulation event table
        self.results['exclude_table_names'] = exclude_table_names
        exclude_tables = map(lambda t: get_node('/physiology', t),
                             exclude_table_names)
        sessions_table = get_node('/metadata', 'sessions')
        scan_table = get_node('/behavior', 'scans')

        # Place-field tables and iterator
        data_file = self.open_data_file()
        scan_spike_table = data_file.createTable(
            '/', 'field_scans', ScanDescr, title='In-field Scan Spiking Data')
        row = scan_spike_table.row

        # Quality criterion
        Quality = get_min_quality_criterion(min_quality)

        self.out('Gathering place field scanning data...')
        for dataset in TetrodeSelect.datasets(area_query,
                                              allow_ambiguous=True):
            rat, day = dataset

            Tetrodes = TetrodeSelect.criterion(dataset,
                                               area_query,
                                               allow_ambiguous=True)
            Criteria = AND(Quality, Tetrodes, PlaceCellCriteria)

            for session in sessions_table.where('(rat==%d)&(day==%d)' %
                                                dataset):
                rds = rat, day, session['session']

                # Set cluster criteria and load session data
                session_data = SessionData.get(rds)
                session_data.cluster_criteria = Criteria

                # Get timing of scan start, max, and end
                scan_timing = session_data.T_(
                    np.array([
                        (rec['start'], rec['max'], rec['end'])
                        for rec in scan_table.where(session_data.session_query)
                    ]))
                scan_magnitude = np.array([
                    rec['magnitude']
                    for rec in scan_table.where(session_data.session_query)
                ])

                if not scan_timing.size:
                    continue

                self.out.printf('Scanning: ', color='lightgray')
                for tc in session_data.get_clusters():

                    # Check for any events for this cell, skip if found
                    skip_unstable = False
                    for table in exclude_tables:
                        found = table.getWhereList(session_data.session_query +
                                                   '&(tc=="%s")' % tc)
                        if len(found):
                            skip_unstable = True
                            break
                    if skip_unstable:
                        self.out.printf(u'\u25a0', color='red')
                        continue

                    # Get pooled ratemap and discard weak place fields
                    ratemap_kwds = dict(bins=RATEMAP_BINS,
                                        blur_width=360 / RATEMAP_BINS)
                    ratemap_kwds.update(session_data.running_filter())
                    R_pooled = session_data.get_cluster_ratemap(
                        tc, **ratemap_kwds)
                    if R_pooled.max() < MIN_FIELD_RATE:
                        self.out.printf(u'\u25a1', color='red')
                        continue

                    # Mark pooled field and discard small place fields
                    field = mark_max_field(R_pooled, floor=0.1, kill_on=2)
                    start, end = field_extent(field)
                    wrapped = start > end
                    field_size = (360 - start + end) if wrapped else (end - start)
                    if field_size < MIN_FIELD_SIZE:
                        self.out.printf(u'\u25a1', color='red')
                        continue

                    # Output indication that we are processing a place field
                    self.out.printf(u'\u25a1', color='green')

                    # Cut laps opposite COM, get spike trains, spike angles
                    cut_laps_opposite_field(session_data, tc, R=R_pooled)
                    cdata = session_data.cluster_data(tc)
                    run_ix = session_data.filter_tracking_data(
                        cdata.spikes,
                        cdata.x,
                        cdata.y,
                        boolean_index=True,
                        **session_data.running_filter())
                    t_all_spikes = session_data.T_(cdata.spikes)
                    t_run_spikes = t_all_spikes[run_ix]
                    alpha_run_spikes = xy_to_deg_vec(cdata.x[run_ix],
                                                     cdata.y[run_ix])
                    combine = np.logical_or if wrapped else np.logical_and
                    in_field = combine(alpha_run_spikes >= start,
                                       alpha_run_spikes <= end)

                    for i in xrange(1, session_data.N_laps - 1):

                        # this loop skips first and last laps to avoid problems with finding
                        # complete traversals on incomplete laps

                        lap_interval = [
                            session_data.T_(session_data.laps[i:i + 2])
                        ]

                        # Find traversal spikes on this lap, ignore if smaller than threshold
                        in_lap = select_from(t_run_spikes, lap_interval)
                        in_traversal = np.logical_and(in_lap, in_field)
                        if in_traversal.sum() < MIN_TRAVERSAL_SPIKES:
                            continue
                        alpha_traversal_spikes = alpha_run_spikes[in_traversal]
                        start_traversal, end_traversal = alpha_traversal_spikes[
                            -1], alpha_traversal_spikes[0]
                        wrapped_traversal = start_traversal > end_traversal
                        if wrapped_traversal:
                            traversal_size = 360 - start_traversal + end_traversal
                        else:
                            traversal_size = end_traversal - start_traversal
                        if traversal_size < MIN_TRAVERSAL_SIZE:
                            continue

                        strength = in_traversal.sum() / t_run_spikes[
                            in_traversal].ptp()  # rough firing rate

                        # Indices of scans on this lap meeting the minimum magnitude threshold
                        lap_scan_ix = np.logical_and(
                            select_from(scan_timing[:, 0], lap_interval),
                            scan_magnitude >= MIN_SCAN_MAGNITUDE).nonzero()[0]

                        for scan_ix in lap_scan_ix:
                            scan = session_data.F_('alpha_unwrapped')(
                                scan_timing[scan_ix, 0]) % 360

                            # Compute field traversal-normalized scan locations for wrapped and
                            # not-wrapped linear fields; skip non-field scans
                            if wrapped_traversal:
                                if scan >= start_traversal:
                                    norm_dist_traversal = (
                                        scan -
                                        start_traversal) / traversal_size
                                elif scan <= end_traversal:
                                    norm_dist_traversal = (
                                        360 - start_traversal +
                                        scan) / traversal_size
                                else:
                                    continue
                            else:
                                if start_traversal <= scan <= end_traversal:
                                    norm_dist_traversal = (
                                        scan -
                                        start_traversal) / traversal_size
                                else:
                                    continue

                            # ...and for the pooled field
                            if wrapped:
                                if scan >= start:
                                    norm_dist_field = (scan -
                                                       start) / field_size
                                elif scan <= end:
                                    norm_dist_field = (360 - start +
                                                       scan) / field_size
                            else:
                                norm_dist_field = (scan - start) / field_size

                            # Convert to running direction (CCW -> CW)
                            norm_dist_field = 1 - norm_dist_field
                            norm_dist_traversal = 1 - norm_dist_traversal

                            # Count the number of scan spikes
                            N_out_spikes = select_from(
                                t_all_spikes,
                                [scan_timing[scan_ix, :2]]).sum()
                            N_in_spikes = select_from(
                                t_all_spikes,
                                [scan_timing[scan_ix, 1:]]).sum()
                            N_spikes = select_from(
                                t_all_spikes, [scan_timing[scan_ix,
                                                           (0, 2)]]).sum()

                            # Add row to field-scan table
                            row['rat'] = rat
                            row['day'] = day
                            row['session'] = session['session']
                            row['tc'] = tc
                            row['scan'] = scan_ix + 1
                            row['field_distance'] = norm_dist_field
                            row['traversal_distance'] = norm_dist_traversal
                            row['strength'] = strength
                            row['field_size'] = field_size
                            row['traversal_size'] = traversal_size
                            row['out_spikes'] = N_out_spikes
                            row['in_spikes'] = N_in_spikes
                            row['spikes'] = N_spikes
                            row.append()

                self.out.printf('\n')
            scan_spike_table.flush()
        self.out('All done!')
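
The wrap-around arithmetic used above for fields that straddle 0/360 degrees can be isolated into a small stand-alone helper (toy sketch, plain Python, no scanr dependencies):

def norm_dist_through_field(scan, start, end):
    """Position of `scan` in [0, 1] through the field (start, end), all in
    degrees on a circular track, or None if the scan lies outside the field."""
    wrapped = start > end                                 # field straddles 0/360
    size = (360 - start + end) if wrapped else (end - start)
    if wrapped:
        if scan >= start:
            dist = scan - start
        elif scan <= end:
            dist = 360 - start + scan
        else:
            return None
    elif start <= scan <= end:
        dist = scan - start
    else:
        return None
    return dist / float(size)

# e.g. a field spanning 300 deg to 40 deg (wrapped, 100 deg wide):
# norm_dist_through_field(350, 300, 40) -> 0.5
# norm_dist_through_field(20, 300, 40)  -> 0.8
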
Example #5
    def collect_data(self,
                     test='place',
                     place_field='pass',
                     min_quality='fair',
                     allow_ambiguous=True):
        """Tally place fields across areas

        Keyword arguments are similar to info_scores.InfoScoreData;
        allow_ambiguous is passed through to TetrodeSelect.
        """
        # Metadata for determining valid fields
        self.results['test'] = test
        self.results['place_field'] = place_field
        self.results['min_quality'] = min_quality
        self.results['allow_ambiguous'] = allow_ambiguous
        if place_field == 'all':
            self.test = 'place'

        # Construct place cell selection criteria based on keyword arguments
        if test == 'place':
            SpatialTest = SpatialInformationCriteria
        elif test == 'skaggs':
            SpatialTest = SkaggsCriteria
        elif test == 'olypher':
            SpatialTest = OlypherCriteria
        else:
            raise ValueError('bad test value: %s' % test)
        MinQuality = get_min_quality_criterion(min_quality)
        CellCriteria = AND(PrincipalCellCriteria, SpikeCountCriteria,
                           MinQuality)
        if place_field == 'pass':
            CellCriteria = AND(CellCriteria, SpatialTest)
        elif place_field == 'fail':
            CellCriteria = AND(CellCriteria, NOT(SpatialTest))
        elif place_field != 'all':
            raise ValueError('bad place_field value: %s' % place_field)

        # Walk the tree and count place fields
        N = {}
        N_cells = {}
        N_sessions = {}
        sessions = set()
        tetrodes = get_node('/metadata', 'tetrodes')
        for area in AREAS.keys():
            for subdiv in (['all'] + AREAS[area]):
                self.out('Walking datasets for %s %s...' % (area, subdiv))
                key = '%s_%s' % (area, subdiv)
                N[key] = 0
                N_cells[key] = 0
                N_sessions[key] = 0

                area_query = 'area=="%s"' % area
                if subdiv != 'all':
                    area_query = '(%s)&(subdiv=="%s")' % (area_query, subdiv)

                for dataset in TetrodeSelect.datasets(
                        area_query, allow_ambiguous=allow_ambiguous):
                    Criteria = AND(
                        CellCriteria,
                        TetrodeSelect.criterion(
                            dataset,
                            area_query,
                            allow_ambiguous=allow_ambiguous))
                    dataset_cells = set()

                    for maze in get_maze_list(*dataset):
                        rds = dataset + (maze, )
                        data = SessionData.get(rds)
                        sessions.add(rds)
                        place_cell_clusters = data.get_clusters(
                            request=Criteria)
                        N[key] += len(place_cell_clusters)
                        dataset_cells.update(place_cell_clusters)
                        N_sessions[key] += 1

                    N_cells[key] += len(dataset_cells)

        self.out.timestamp = False
        self.results['N'] = N
        self.out('Total number of sessions = %d' % len(sessions))
        for key in sorted(N.keys()):
            self.out('N_cells[%s] = %d cells' % (key, N_cells[key]))
            self.out('N_sessions[%s] = %d sessions' % (key, N_sessions[key]))
            self.out('N_cell_sessions[%s] = %d cell-sessions' % (key, N[key]))

        # Good-bye
        self.out('All done!')
Example #6
    def collect_data(self, min_quality='fair', shuffle_samples=200,
        shuffle_retry=20):
        """Collate and summarize statistics of scan vs. non-scan activity for
        MEC and LEC clusters

        Keyword arguments:
        min_quality -- threshold for cluster isolation quality used to choose cells
        shuffle_samples -- number of randomized samples for empirical p-values
        shuffle_retry -- maximum attempts at drawing a random time segment that
            does not collide with a scan
        """
        self.results['quality'] = min_quality
        Quality = get_min_quality_criterion(min_quality)
        self.out('Cluster quality: at least %s'%min_quality)

        scan_table = get_node('/behavior', 'scans')
        scan_firing_data = dict(MEC=[], LEC=[])

        LEC_datasets = self._get_datasets("area=='LEC'")
        MEC_datasets = self._get_datasets("area=='MEC'")
        area_list = ['LEC']*len(LEC_datasets) + ['MEC']*len(MEC_datasets)
        dataset_list = LEC_datasets+MEC_datasets

        # Set up spreadsheet output
        fn = os.path.join(self.datadir, 'entorhinal_scan_firing.csv')
        cols = [('dataset', 's'), ('maze', 'd'), ('type', 's'),
                ('parameter', 'd'), ('area', 's'), ('cluster', 's'),
                ('N_scans', 'd'), ('scan_rate_mean', 'f'), ('scan_rate_sd', 'f'),
                ('scan_rate_overall', 'f'), ('nonscan_rate_overall', 'f'),
                ('scan_nonscan_ratio', 'f') ]
        spreadsheet = DataSpreadsheet(fn, cols)
        record = spreadsheet.get_record()
        self.out('Record string: ' + spreadsheet.get_record_string())

        for area, dataset in zip(area_list, dataset_list):
            dataset_str = 'rat%03d-%02d'%dataset
            rat, day = dataset
            self.out('Analyzing %s for area %s.'%(dataset_str, area))
            area_query = "area=='%s'"%area

            record['dataset'] = dataset_str
            record['area'] = area

            for maze in data.session_list(rat, day):
                rds = rat, day, maze
                record['maze'] = maze

                Tetrodes = get_tetrode_restriction_criterion(
                    self._get_valid_tetrodes(dataset, area_query))
                Criteria = AND(PrincipalCellCriteria, Quality, Tetrodes)

                session_data = session.SessionData(rds=rds,
                    cluster_criteria=Criteria)
                total_time = session_data.duration
                record['type'] = session_data.data_group._v_attrs['type']
                record['parameter'] = session_data.data_group._v_attrs['parameter']

                scan_list = [tuple(scan['tlim']) for scan in
                    scan_table.where(session_data.session_query)]
                scan_list.sort()
                if len(scan_list) == 0:
                    continue
                record['N_scans'] = len(scan_list)

                for tc in session_data.get_clusters():
                    record['cluster'] = tc
                    cl_data = session_data.clusts[tc]
                    ts_spikes = cl_data.spikes

                    # Cell-session statistics
                    total_spikes = ts_spikes.size
                    session_firing_rate = total_spikes / total_time

                    # Initialize per-scan accumulators
                    scan_counts = []
                    scan_durations = []
                    scan_rates = []
                    scan_pvals = []
                    scan_norm = []

                    for start, end in scan_list:
                        scan_spikes = time.time_slice_sample(ts_spikes,
                            start=start, end=end)

                        scan_counts.append(scan_spikes.size)
                        this_scan_duration = time.elapsed(start, end)
                        scan_durations.append(this_scan_duration)
                        firing_rate = scan_spikes.size / this_scan_duration
                        scan_rates.append(firing_rate)

                        # Randomized firing-distributions for one-sided p-values
                        delta_ts = end - start
                        shuffle = np.empty((shuffle_samples,), 'd')
                        for i in xrange(shuffle_samples):
                            c = 0
                            while c < shuffle_retry:
                                # Get random time segment of same length in session
                                rand_start = long(session_data.start +
                                    plt.randint(session_data.end -
                                        session_data.start - delta_ts))
                                rand_end = long(rand_start + delta_ts)

                                # Only accept if not colliding with another scan...
                                hit = False
                                for s, e in scan_list:
                                    if ((rand_start <= s <= rand_end) or
                                        (rand_start <= e <= rand_end) or
                                        (s < rand_start and e > rand_start)):
                                        hit = True
                                        break
                                if hit:
                                    c += 1
                                else:
                                    break
                            rand_spikes = time.time_slice_sample(ts_spikes,
                                start=rand_start, end=rand_end)
                            shuffle[i] = rand_spikes.size / this_scan_duration
                        p_val = (1+(shuffle > firing_rate).sum()) / float(1+shuffle_samples)
                        scan_pvals.append(p_val)
                        scan_norm.append(firing_rate / session_firing_rate)

                    # Overall scan firing rate
                    overall_scan_rate = np.sum(scan_counts) / np.sum(scan_durations)

                    # Finish spreadsheet entry for this cluster
                    record['scan_rate_mean'] = np.mean(scan_rates)
                    record['scan_rate_sd'] = np.std(scan_rates)
                    record['scan_rate_overall'] = overall_scan_rate
                    record['nonscan_rate_overall'] = \
                        (total_spikes-np.sum(scan_counts)) / (total_time-np.sum(scan_durations))
                    record['scan_nonscan_ratio'] = 0.0
                    if record['nonscan_rate_overall']:
                        record['scan_nonscan_ratio'] = overall_scan_rate / record['nonscan_rate_overall']
                    spreadsheet.write_record(record)

                    # Create the final record for this cell-session
                    scan_row = (  np.mean(scan_rates),
                                  np.median(scan_rates),
                                  overall_scan_rate,
                                  np.mean(scan_norm),
                                  np.median(scan_norm),
                                  record['scan_nonscan_ratio'],
                                  np.mean(scan_pvals),
                                  np.median(scan_pvals)   )

                    # Store the record in an area-specific list
                    scan_firing_data[area].append(scan_row)

        # Save data as numpy record arrays
        firing_data = [ ('mean_rate', float), ('median_rate', float), ('overall_rate', float),
                        ('mean_norm', float), ('median_norm', float), ('overall_norm', float),
                        ('mean_pval', float), ('median_pval', float)  ]
        self.results['LEC'] = np.rec.fromrecords(scan_firing_data['LEC'],
            dtype=firing_data)
        self.results['MEC'] = np.rec.fromrecords(scan_firing_data['MEC'],
            dtype=firing_data)

        spreadsheet.close()
        self.out('All done!')
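
The one-sided empirical p-value computed inside the shuffle loop above is the standard add-one estimator; in isolation (toy numbers, NumPy only) it reduces to:

import numpy as np

shuffle = np.array([0.5, 1.2, 3.0, 0.8, 2.1])   # rates from random segments
firing_rate = 1.5                                # observed rate during the scan
# fraction of shuffled rates exceeding the observation, with +1 smoothing
p_val = (1 + (shuffle > firing_rate).sum()) / float(1 + shuffle.size)
# -> (1 + 2) / 6 = 0.5
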
Example #7
    def collect_data(self, tetrode_query="area=='CA1'", scan_time="max",
        scan_type=None, min_quality="fair", shuffle=False, **kwds):
        """Perform scan-firing xcorrs on selected tetrodes

        Arguments:
        tetrode_query -- query string against the /metadata/tetrodes table
        scan_time -- time point of scan to lock onto (must be "start", "end",
            "mid", "max")
        scan_type -- optional scan type (e.g. "INT" or "EXT") to restrict to
        min_quality -- minimum cluster isolation quality (string name or
            ClusterQuality value) for inclusion of data
        shuffle -- whether to randomly shuffle scan times

        Keywords are passed to the spike.xcorr function.
        """
        assert scan_time in ("start", "end", "mid", "max"), \
            "bad scan_time: %s"%scan_time

        ttable = get_node('/metadata', 'tetrodes')
        stable = get_node('/behavior', 'scans')

        units = 'ms' if kwds.get('use_millis') else 's'
        overall_psth = None
        rat_psth = {}

        Quality = get_min_quality_criterion(min_quality)
        self.out("Quality filter: at least %s"%min_quality)

        for ratday, ax in self.get_plot(self._get_datasets(tetrode_query)):
            valid_tetrodes = self._get_valid_tetrodes(ratday, tetrode_query)
            Tetrodes = get_tetrode_restriction_criterion(valid_tetrodes)
            rat, day = ratday
            psth = {}

            ClusterFilter = AND(Tetrodes, Quality, PrincipalCellCriteria)

            for maze in session_list(rat, day, exclude_timing_issue=True):
                data = session.SessionData(rds=(rat, day, maze))
                clusters = data.get_clusters(ClusterFilter)

                if shuffle:
                    scan_times = data.random_timestamp_array(
                        size=len(stable.getWhereList(data.session_query)))
                else:
                    query = data.session_query
                    if scan_type is not None:
                        query += "&(type==\'%s\')"%scan_type.upper()
                    scan_times = [s[scan_time] for s in stable.where(query)]
                t_scan = data.to_time(scan_times)
                for cl in clusters:
                    t_spikes = data.to_time(data.get_spike_train(cl))
                    C, lags = xcorr(t_scan, t_spikes, **kwds)
                    if cl not in psth:
                        psth[cl] = np.zeros_like(C)
                    psth[cl] += C

            if not psth:
                ax.set_axis_off()
                continue

            # Initialize aggregators for overall psth and per-rat psth
            if overall_psth is None:
                overall_psth = np.zeros_like(C)
            if rat not in rat_psth:
                rat_psth[rat] = np.zeros_like(C)

            drawn = False
            fmt = dict(lw=0.5, c='b')
            max_corr = max([max(corr) for corr in psth.values()])
            for cl in psth:
                overall_psth += psth[cl]
                rat_psth[rat] += psth[cl]
                psth[cl] /= max_corr
                plot_correlogram((psth[cl], lags), is_corr=True, ax=ax,
                    plot_type="lines", fmt=fmt, zero_line=not drawn)
                drawn = True
            ax.set_yticks([])
            if self.lastrow:
                ax.set_xlabel('Spike Lag (%s)'%units)
            else:
                ax.set_xticks([])

        # For plotting overall and per-rat cross-correlograms
        self.results['lags'] = lags
        self.results['C_overall'] = overall_psth
        self.results['C_rat'] = rat_psth
        self.results['query'] = tetrode_query
        self.results['scan_time'] = scan_time
        self.results['shuffled'] = shuffle
        self.results['units'] = units
        self.results['scan_type'] = scan_type

        self.out('All done!')
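
The per-cluster PSTH accumulated above comes from spike.xcorr; as a rough sketch of what such an event-triggered cross-correlogram computes (an assumption about xcorr's behaviour, not its actual implementation), binned spike lags around each scan time look like:

import numpy as np

def xcorr_sketch(t_events, t_spikes, maxlag=2.0, bins=40):
    # histogram of (spike time - event time) lags within +/- maxlag seconds
    lags = np.concatenate([t_spikes - t for t in t_events])
    lags = lags[np.abs(lags) <= maxlag]
    edges = np.linspace(-maxlag, maxlag, bins + 1)
    C, _ = np.histogram(lags, bins=edges)
    centers = 0.5 * (edges[:-1] + edges[1:])
    return C, centers
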
Example #8
    def collect_data(self, min_quality='fair', exclude_bad=False, ROI=None,
        exclude_zero_scans=False, only_m1=False, use_ec_table=False):
        """Collect normalized firing-rate information across region, scan type,
        and scan phase, for all valid clusters

        Arguments:
        min_quality -- minimum cluster isolation quality (string name or
            ClusterQuality value) for inclusion of data
        exclude_bad -- skip sessions flagged with missing HD or timing jumps
        ROI -- optional list of areas to restrict the analysis to
        exclude_zero_scans -- ignore scans during which the cell fired no spikes
        only_m1 -- restrict the analysis to the first maze of each day
        use_ec_table -- take clusters from the /metadata/ec_clusters table
        """
        self.results['quality'] = min_quality
        Quality = get_min_quality_criterion(min_quality)
        self.out('Cluster quality: at least %s'%min_quality)

        tetrode_table = get_node('/metadata', 'tetrodes')
        scan_table = get_node('/behavior', 'scans')

        # Only entorhinal if using the EC clusters table
        if use_ec_table:
            ec_table = get_node('/metadata', 'ec_clusters')
            ROI_areas = ['LEC', 'MEC']
        elif ROI:
            ROI_areas = ROI
        else:
            ROI_areas = spike.PrimaryAreas

        # Initialize main cluster data accumulator
        def new_arrays():
            return dict(INT=dict(   inbound=[], inbound_rate=[],
                                    inbound_rate_norm=[], inbound_diff=[],
                                    outbound=[], outbound_rate=[],
                                    outbound_rate_norm=[], outbound_diff=[]),
                        EXT=dict(   inbound=[], inbound_rate=[],
                                    inbound_rate_norm=[], inbound_diff=[],
                                    outbound=[], outbound_rate=[],
                                    outbound_rate_norm=[], outbound_diff=[]))
        cluster_data = { area: new_arrays() for area in ROI_areas }
        self.results['cluster_data'] = cluster_data

        scan_rate_data = { area: [] for area in ROI_areas }
        self.results['scan_rate_data'] = scan_rate_data

        scan_nonscan_diff = { area: [] for area in ROI_areas }
        self.results['scan_nonscan_diff'] = scan_nonscan_diff

        type_by_phase = [[x, y] for x in ('INT', 'EXT')
            for y in ('inbound', 'outbound')]

        cell_counts = { area: 0 for area in ROI_areas }
        self.results['cell_counts'] = cell_counts

        Criteria = AND(PrincipalCellCriteria, Quality)

        spreadsheet = DataSpreadsheet(
            os.path.join(self.datadir, 'scan_firing_rates.csv'),
            [   ('rat', 'd'), ('day', 'd'), ('area', 's'), ('cell', 's'),
                ('rate', 'f'), ('scan_rate', 'f'), ('nonscan_rate', 'f'),
                ('scan_nonscan_diff', 'f'), ('INT_outbound_diff', 'f'),
                ('INT_inbound_diff', 'f'), ('EXT_outbound_diff', 'f'),
                ('EXT_inbound_diff', 'f') ])
        self.out('Record string: %s'%spreadsheet.get_record_string())
        record = spreadsheet.get_record()

        def count_spikes(ts, start, end):
            return len(time.time_slice_sample(ts, start=start, end=end))

        for dataset in meta.walk_days():
            rat, day = dataset
            record['rat'], record['day'] = rat, day
            self.out('Analyzing dataset rat%03d-%02d.'%dataset)

            dataset_duration = 0.0
            dataset_spikes = {}

            N = dict(   INT=dict(inbound={}, outbound={}),
                        EXT=dict(inbound={}, outbound={})   )

            duration = dict(   INT=dict(inbound={}, outbound={}),
                               EXT=dict(inbound={}, outbound={})   )

            collated = []

            for maze in meta.get_maze_list(rat, day):
                if only_m1 and maze != 1:
                    continue

                rds = rat, day, maze
                grp = get_group(rds=rds)
                if exclude_bad:
                    attrs = grp._v_attrs
                    if attrs['HD_missing'] or attrs['timing_jumps']:
                        self.out('...bad dataset, skipping...')
                        continue

                if use_ec_table:
                    request = ['t%dc%d'%(rec['tt'], rec['cluster']) for
                        rec in ec_table.where(
                            '(rat==%d)&(day==%d)&(session==%d)'%rds)]
                    if request:
                        self.out('Found EC clusters: %s'%str(request)[1:-1])
                    else:
                        self.out('...no EC table clusters, skipping...')
                        continue
                else:
                    request = Criteria

                session_data = session.SessionData(rds=rds)
                dataset_duration += session_data.duration
                scan_list = \
                    [(scan['type'], scan['start'], scan['max'], scan['end'])
                        for scan in scan_table.where(session_data.session_query)
                        if scan['type'] != 'AMB']

                # Collate the spike counts, accumulating cells across sessions
                for tc in session_data.get_clusters(request):
                    cl_data = session_data.clusts[tc]
                    area = spike.get_tetrode_area(rat, day, cl_data.tt)
                    if area not in ROI_areas:
                        continue

                    # Initialize accumulators on first occurrence of this cell
                    if (tc, area) not in collated:
                        collated.append((tc, area))
                        dataset_spikes[tc] = 0
                        for t, p in type_by_phase:
                            N[t][p][tc] = 0
                            duration[t][p][tc] = 0.0

                    t_spikes = cl_data.spikes
                    dataset_spikes[tc] += t_spikes.size

                    for scan_type, start, scan_max, end in scan_list:
                        if exclude_zero_scans and (count_spikes(t_spikes, start,
                            end) == 0):
                            continue

                        N[scan_type]['outbound'][tc] += \
                            count_spikes(t_spikes, start, scan_max)
                        N[scan_type]['inbound'][tc] += \
                            count_spikes(t_spikes, scan_max, end)
                        duration[scan_type]['outbound'][tc] += \
                            time.elapsed(start, scan_max)
                        duration[scan_type]['inbound'][tc] += \
                            time.elapsed(scan_max, end)

            self.out('Computing firing rates for %d cells...'%len(collated))
            for tc, area in collated:
                N_total = float(sum([N[t][p][tc] for t, p in type_by_phase]))
                duration_total = sum([duration[t][p][tc] for t, p in type_by_phase])
                if not duration_total:
                    continue

                record['cell'] = tc
                record['area'] = area

                scan_rate = N_total / duration_total
                scan_rate_data[area].append(scan_rate)
                record['scan_rate'] = scan_rate

                overall_rate = dataset_spikes[tc] / dataset_duration
                overall_nonscan_rate = \
                    (dataset_spikes[tc] - N_total) / (dataset_duration - duration_total)
                record['rate'] = overall_rate
                record['nonscan_rate'] = overall_nonscan_rate

                scan_nonscan_diff[area].append((scan_rate - overall_nonscan_rate) /
                    (scan_rate + overall_nonscan_rate))
                record['scan_nonscan_diff'] = scan_nonscan_diff[area][-1]

                cell_counts[area] += 1
                c_dict = cluster_data[area]
                for t, p in type_by_phase:
                    key = '%s_%s_diff'%(t, p)
                    record[key] = -99
                    if not (N_total and duration[t][p][tc]):
                        continue
                    c_dict[t][p].append(N[t][p][tc]/N_total)
                    this_scan_rate = N[t][p][tc]/duration[t][p][tc]
                    c_dict[t][p+'_rate'].append(this_scan_rate)
                    if this_scan_rate + scan_rate != 0:
                        c_dict[t][p+'_rate_norm'].append(
                            (this_scan_rate-scan_rate) /
                                (this_scan_rate+scan_rate))
                    if this_scan_rate + overall_nonscan_rate != 0:
                        c_dict[t][p+'_diff'].append(
                            (this_scan_rate-overall_nonscan_rate) /
                                (this_scan_rate+overall_nonscan_rate))
                        record[key] = c_dict[t][p+'_diff'][-1]
                spreadsheet.write_record(record)

        # Convert data to arrays
        for area in ROI_areas:
            self.out('Total cell count for %s: %d cells'%(area,
                cell_counts[area]))
            scan_rate_data[area] = np.array(scan_rate_data[area])
            scan_nonscan_diff[area] = np.array(scan_nonscan_diff[area])
            for scan_type in 'INT', 'EXT':
                c_dict = cluster_data[area][scan_type]
                for k in c_dict:
                    c_dict[k] = np.array(c_dict[k])

        spreadsheet.close()
        self.out("All done!")
Example #9
    def collect_data(self,
                     min_quality='fair',
                     table_name='potentiation',
                     bins=Config['placefield']['default_bins'],
                     min_rate=Config['placefield']['min_peak_rate']):
        """Merge event information from a /physiology/<event_table> into a new
        table with rows for every place field in the database, allowing for
        proper bootstrapping of the fractional prevalence of potentiation
        events.
        """
        area_query = '(area=="CA1")|(area=="CA3")'
        bin_width = 360.0 / bins

        # Check modulation event table
        self.results['table_name'] = table_name
        sessions_table = get_node('/metadata', 'sessions')
        mod_table = get_node('/physiology', table_name)
        tetrode_table = get_node('/metadata', 'tetrodes')
        self.out('Using %s with %d rows.' %
                 (mod_table._v_pathname, mod_table.nrows))

        # Place-field tables and iterator
        data_file = self.open_data_file()
        field_table = data_file.createTable('/',
                                            'place_fields',
                                            FieldDescr,
                                            title='Place Field Event Data')
        row = field_table.row

        # Quality criterion
        Quality = get_min_quality_criterion(min_quality)

        self.out('Gathering place field data...')
        for dataset in TetrodeSelect.datasets(area_query,
                                              allow_ambiguous=True):
            rat, day = dataset

            Tetrodes = TetrodeSelect.criterion(dataset,
                                               area_query,
                                               allow_ambiguous=True)
            Criteria = AND(Quality, Tetrodes, PlaceCellCriteria)

            for session in sessions_table.where('(rat==%d)&(day==%d)' %
                                                dataset):
                rds = rat, day, session['session']

                # Set cluster criteria and load session data
                session_data = SessionData.get(rds)
                session_data.cluster_criteria = Criteria

                self.out.printf('Scanning: ', color='lightgray')
                for tc in session_data.get_clusters():

                    tt, cl = parse_cell_name(tc)
                    area = get_unique_row(
                        tetrode_table, '(rat==%d)&(day==%d)&(tt==%d)' %
                        (rat, day, tt))['area']

                    # Create firing-rate map to find fields for accurate field count
                    filter_kwds = dict(
                        velocity_filter=True,
                        exclude_off_track=True,
                        exclude=session_data.extended_scan_and_pause_list)
                    ratemap_kwds = dict(bins=bins, blur_width=bin_width)
                    ratemap_kwds.update(filter_kwds)

                    R_full = session_data.get_cluster_ratemap(
                        tc, **ratemap_kwds)
                    if R_full.max() < min_rate:
                        continue

                    for f, field in enumerate(mark_all_fields(R_full)):

                        # Check for any events for this field
                        field_num = f + 1
                        events = mod_table.getWhereList(
                            session_data.session_query +
                            '&(tc=="%s")&(fieldnum==%d)' % (tc, field_num))

                        if len(events):
                            self.out.printf(u'\u25a1', color='green')
                        else:
                            self.out.printf(u'\u25a1', color='red')

                        # Add row to place-field table
                        row['rat'] = rat
                        row['day'] = day
                        row['session'] = session['session']
                        row['tc'] = tc
                        row['num'] = field_num
                        row['type'] = session['type']
                        row['area'] = area
                        row['expt_type'] = session['expt_type']
                        row['events'] = len(events)
                        row.append()

                self.out.printf('\n')
            field_table.flush()
        self.out('All done!')
Example #10
    def create_scan_cell_table(self, scan_phase='scan'):
        """For every scan–cell pair, compute the relative index of cell firing that
        occurred during the scan and previous cell firing on the track
        """
        scan_table_description = {
            'id': tb.UInt32Col(pos=1),
            'scan_id': tb.UInt16Col(pos=2),
            'rat': tb.UInt16Col(pos=3),
            'day': tb.UInt16Col(pos=4),
            'session': tb.UInt16Col(pos=5),
            'session_start_angle': tb.FloatCol(pos=6),
            'session_end_angle': tb.FloatCol(pos=7),
            'tc': tb.StringCol(itemsize=8, pos=8),
            'type': tb.StringCol(itemsize=4, pos=9),
            'expt_type': tb.StringCol(itemsize=4, pos=10),
            'area': tb.StringCol(itemsize=4, pos=11),
            'subdiv': tb.StringCol(itemsize=4, pos=12),
            'duration': tb.FloatCol(pos=13),
            'magnitude': tb.FloatCol(pos=14),
            'angle': tb.FloatCol(pos=15)
        }

        def add_scan_index_column_descriptors(descr):
            pos = 16
            for name in ScanIndex.AllNames:
                descr[name] = tb.FloatCol(pos=pos)
                pos += 1

        add_scan_index_column_descriptors(scan_table_description)

        data_file = self.get_data_file(mode='a')
        scan_cell_table = create_table(data_file,
                                       '/',
                                       'scan_cell_info',
                                       scan_table_description,
                                       title='Metadata for Scan-Cell Pairs')
        scan_cell_table._v_attrs['scan_phase'] = scan_phase
        row = scan_cell_table.row
        row_id = 0

        scans_table = get_node('/behavior', 'scans')
        sessions_table = get_node('/metadata', 'sessions')
        tetrodes_table = get_node('/metadata', 'tetrodes')

        cornu_ammonis_query = '(area=="CA1")|(area=="CA3")'
        hippocampal_datasets = unique_datasets('/metadata',
                                               'tetrodes',
                                               condn=cornu_ammonis_query)

        quality_place_cells = AND(get_min_quality_criterion(self.min_quality),
                                  PlaceCellCriteria)

        index = ScanIndex(scan_phase=scan_phase)

        for dataset in hippocampal_datasets:
            dataset_query = '(rat==%d)&(day==%d)' % dataset

            hippocampal_tetrodes = unique_values(
                tetrodes_table,
                column='tt',
                condn='(%s)&(%s)' % (dataset_query, cornu_ammonis_query))
            cluster_criteria = AND(
                quality_place_cells,
                get_tetrode_restriction_criterion(hippocampal_tetrodes))

            for maze in get_maze_list(*dataset):
                rds = dataset + (maze, )
                session = SessionData(rds=rds)
                place_cells = session.get_clusters(cluster_criteria)
                session_start_angle = np.median(
                    session.trajectory.alpha_unwrapped[:5])
                session_end_angle = np.median(
                    session.trajectory.alpha_unwrapped[-5:])

                self.out('Computing scan index for %s...' %
                         session.data_group._v_pathname)

                for scan in scans_table.where(session.session_query):
                    self.out.printf('|', color='cyan')

                    for cell in place_cells:
                        cluster = session.cluster_data(cell)

                        tt, cl = parse_cell_name(cluster.name)
                        tetrode = get_unique_row(
                            tetrodes_table, '(rat==%d)&(day==%d)&(tt==%d)' %
                            (rds[0], rds[1], tt))

                        row['id'] = row_id
                        row['scan_id'] = scan['id']
                        row['rat'], row['day'], row['session'] = rds
                        row['session_start_angle'] = session_start_angle
                        row['session_end_angle'] = session_end_angle
                        row['tc'] = cluster.name
                        row['type'] = session.attrs['type']
                        row['expt_type'] = get_unique_row(
                            sessions_table, session.session_query)['expt_type']
                        row['area'] = tetrode['area']
                        row['subdiv'] = tetrode['area'] + tetrode['subdiv'][:1]
                        row['angle'] = session.F_('alpha_unwrapped')(
                            session.T_(scan['start']))
                        row['duration'] = scan['duration']
                        row['magnitude'] = scan['magnitude']

                        for index_name in ScanIndex.AllNames:
                            row[index_name] = index.compute(
                                index_name, session, scan, cluster)

                        self.out.printf('.', color='green')
                        row_id += 1
                        row.append()

                        if row_id % 100 == 0:
                            scan_cell_table.flush()

                self.out.printf('\n')

        scan_cell_table.flush()
        self.out('Finished creating %s.' % scan_cell_table._v_pathname)
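
For reference, the table-building pattern above (column descriptors, the bound row object, append, flush) maps onto plain PyTables roughly as follows; create_table in the example is presumably a thin wrapper around File.create_table, and the file name here exists only for the sketch:

import tables as tb

descr = {'id': tb.UInt32Col(pos=1),
         'tc': tb.StringCol(itemsize=8, pos=2),
         'duration': tb.FloatCol(pos=3)}

with tb.open_file('scan_cells_demo.h5', mode='w') as h5:
    table = h5.create_table('/', 'scan_cell_info', descr,
                            title='Metadata for Scan-Cell Pairs')
    row = table.row
    for i, (tc, duration) in enumerate([('t1c2', 0.45), ('t3c1', 0.80)]):
        row['id'], row['tc'], row['duration'] = i, tc, duration
        row.append()        # buffer the record in memory
    table.flush()           # write buffered rows to disk
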