Example 1
    def load_behavior_timing(self, key):
        log.info('Loading behavior frametimes')
        # -- find number of recording depths
        pipe = (fuse.Activity() & key).fetch('pipe')
        assert len(
            np.unique(pipe)) == 1, 'Selection is from different pipelines'
        pipe = dj.create_virtual_module(pipe[0], 'pipeline_' + pipe[0])
        k = dict(key)
        k.pop('field', None)
        ndepth = len(dj.U('z') & (pipe.ScanInfo.Field() & k))
        return (stimulus.BehaviorSync()
                & key).fetch1('frame_times').squeeze()[0::ndepth]
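The trailing `[0::ndepth]` keeps one behavior timestamp per imaged volume, assuming the behavior clock records one frame time per depth slice. A minimal standalone sketch with made-up depth count and timestamps of what that slicing does:

import numpy as np

# Synthetic per-slice frame times for a scan with 3 recording depths:
# the scanner cycles depth 0, depth 1, depth 2, depth 0, ...
ndepth = 3
frame_times = np.arange(0.0, 1.2, 0.1)  # one timestamp per depth slice

# Keep one timestamp per volume (the first depth of every cycle),
# mirroring frame_times.squeeze()[0::ndepth] above.
volume_times = frame_times[0::ndepth]
print(volume_times)  # [0.  0.3 0.6 0.9]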
Example 2
    def make(self, key):
        log.info(80 * '-')
        log.info('Processing key ' + pformat(dict(key)))

        # get original frame
        frame = self.load_frame(key)

        # preprocess the frame
        frame = process_frame(key, frame)

        # --- insert the preprocessed frame
        self.insert1(dict(key, frame=frame))
Example 3
def process_frame(preproc_key, frame):
    """
    Helper function that preprocesses a frame
    """
    import cv2
    imgsize = (Preprocessing() & preproc_key).fetch1(
        'col', 'row')  # target size of movie frames
    log.info('Downsampling frame')
    if not frame.shape[0] / imgsize[1] == frame.shape[1] / imgsize[0]:
        log.warning('Image size would change aspect ratio.')

    return cv2.resize(frame, imgsize,
                      interpolation=cv2.INTER_AREA).astype(np.float32)
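The helper relies on OpenCV's convention that `cv2.resize` takes its target size as `(width, height)`, i.e. `(col, row)`, while NumPy reports array shapes as `(rows, cols)`; that is what the aspect-ratio check compares. A minimal standalone sketch with a synthetic frame and a hard-coded target size standing in for the `Preprocessing` fetch:

import numpy as np
import cv2

# Synthetic 144x256 (rows x cols) frame standing in for a stimulus image.
frame = np.random.rand(144, 256).astype(np.float32)

# Target size in (col, row) order, exactly what cv2.resize expects as dsize.
imgsize = (64, 36)

if frame.shape[0] / imgsize[1] != frame.shape[1] / imgsize[0]:
    print('Image size would change aspect ratio.')

small = cv2.resize(frame, imgsize, interpolation=cv2.INTER_AREA).astype(np.float32)
print(small.shape)  # (36, 64): NumPy reports (rows, cols)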
Example 4
    def fetch_data(self, key, key_order=None):
        assert len(self
                   & key) == 1, 'Key must refer to exactly one multi dataset'
        ret = OrderedDict()
        log.info('Fetching data for ' + repr(key))
        for mkey in (self.Member() & key).fetch(
                dj.key,
                order_by=
                'animal_id ASC, session ASC, scan_idx ASC, preproc_id ASC'):
            name = (self.Member() & mkey).fetch1('name')
            include_behavior = bool(Eye().proj() * Treadmill().proj() & mkey)
            data_names = ['images', 'responses'] if not include_behavior \
                else ['images',
                      'behavior',
                      'pupil_center',
                      'responses']
            log.info('Data will be ({})'.format(','.join(data_names)))

            h5filename = InputResponse().get_filename(mkey)
            log.info('Loading dataset {} --> {}'.format(name, h5filename))
            ret[name] = datasets.StaticImageSet(h5filename, *data_names)
        if key_order is not None:
            log.info(
                'Reordering datasets according to given key order {}'.format(
                    ', '.join(key_order)))
            ret = OrderedDict([(k, ret[k]) for k in key_order])
        return ret
Example 5
    def add_transforms(key, datasets, exclude=None):
        warnings.warn('You are using an outdated `add_transforms` kept for backward compatibility. Do not use this in new networks.')
        if exclude is not None:
            log.info('Excluding "{}" from normalization'.format(
                '", "'.join(exclude)))
        for k, dataset in datasets.items():
            transforms = []
            if key['normalize']:
                transforms.append(Normalizer(
                    dataset, stats_source=key['stats_source'],
                    buggy=True, normalize_per_image=True, exclude=exclude))
            transforms.append(ToTensor())
            dataset.transforms = transforms

        return datasets
Example 6
    def load_eye_traces(self, key):
        #r, center = (pupil.FittedPupil.Ellipse() & key).fetch('major_r', 'center', order_by='frame_id ASC')
        r, center = (pupil.FittedPupil.Circle() & key).fetch(
            'radius', 'center', order_by='frame_id')
        detectedFrames = ~np.isnan(r)
        xy = np.full((len(r), 2), np.nan)
        xy[detectedFrames, :] = np.vstack(center[detectedFrames])
        xy = np.vstack(map(partial(fill_nans, preserve_gap=3), xy.T))
        if np.any(np.isnan(xy)):
            log.info('Keeping some nans in the pupil location trace')
        pupil_radius = fill_nans(r.squeeze(), preserve_gap=3)
        if np.any(np.isnan(pupil_radius)):
            log.info('Keeping some nans in the pupil radius trace')

        eye_time = (pupil.Eye() & key).fetch1('eye_time').squeeze()
        return pupil_radius, xy, eye_time
Example 7
    def add_transforms(key, datasets, exclude=None):
        if exclude is not None:
            log.info('Excluding "{}" from normalization'.format(
                '", "'.join(exclude)))
        for k, dataset in datasets.items():
            transforms = []

            if key.get('normalize', True):
                transforms.append(Normalizer(
                    dataset,
                    stats_source=key.get('stats_source', 'all'),
                    normalize_per_image=key.get('normalize_per_image', False),
                    exclude=exclude))
            transforms.append(ToTensor())
            dataset.transforms = transforms

        return datasets
Example 8
    def stimulus_onset(flip_times, duration):
        n_ft = np.unique([ft.size for ft in flip_times])
        assert len(n_ft) == 1, 'Found inconsistent number of fliptimes'
        n_ft = int(n_ft)
        log.info('Found {} flip times per trial'.format(n_ft))

        assert n_ft in (2, 3), 'Cannot deal with {} flip times'.format(n_ft)

        stimulus_onset = np.vstack(flip_times)  # columns correspond to  clear flip, onset flip
        ft = stimulus_onset[np.argsort(stimulus_onset[:, 0])]
        if n_ft == 2:
            assert np.median(ft[1:, 0] - ft[:-1, 1]) < duration + 0.05, 'stimulus duration off by more than 50ms'
        else:
            assert np.median(ft[:, 2] - ft[:, 1]) < duration + 0.05, 'stimulus duration off by more than 50ms'
        stimulus_onset = stimulus_onset[:, 1]

        return stimulus_onset
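A minimal standalone sketch of the onset extraction with synthetic flip times (three flips per trial; the 33 ms gap and 0.5 s duration are made up, and the clear/onset/offset layout is an assumption), showing which column is returned and what the duration check compares:

import numpy as np

duration = 0.5  # hypothetical presentation duration in seconds

# Three flips per trial: clear flip, onset flip, offset flip (assumed layout).
flip_times = [np.array([t, t + 0.033, t + 0.033 + duration])
              for t in np.arange(0.0, 5.0, 1.0)]

stacked = np.vstack(flip_times)
ft = stacked[np.argsort(stacked[:, 0])]  # trials sorted by their first flip
assert np.median(ft[:, 2] - ft[:, 1]) < duration + 0.05
onsets = stacked[:, 1]                   # the onset flip of each trial
print(onsets)  # [0.033 1.033 2.033 3.033 4.033]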
Example 9
    def get_trace_spline(self, key, sampling_period):
        traces, frame_times, trace_keys = self.load_traces_and_frametimes(key)
        log.info('Loaded {} traces'.format(len(traces)))

        log.info('Generating lowpass filters to {}Hz'.format(1 /
                                                             sampling_period))
        h_trace = self.get_filter(sampling_period,
                                  np.median(np.diff(frame_times)),
                                  'hamming',
                                  warning=False)
        # low pass filter
        trace_spline = SplineCurve(
            frame_times,
            [np.convolve(trace, h_trace, mode='same') for trace in traces],
            k=1,
            ext=1)
        return trace_spline, trace_keys, frame_times.min(), frame_times.max()
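A minimal sketch of the kind of low-pass filtering applied before the spline fit, assuming the kernel is a normalized Hamming window spanning roughly one target sampling period (the real kernel comes from `get_filter`, whose exact shape is not shown here; the rates below are made up):

import numpy as np

frame_dt = 1 / 15          # hypothetical median frame interval in seconds
sampling_period = 0.5      # hypothetical target sampling period in seconds

# Normalized Hamming window covering about one sampling period (assumption).
n_taps = 2 * int(round(sampling_period / frame_dt)) + 1
h_trace = np.hamming(n_taps)
h_trace /= h_trace.sum()

trace = np.random.rand(1000)
lowpassed = np.convolve(trace, h_trace, mode='same')
print(lowpassed.shape)     # (1000,) - same length as the input trace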
Example 10
        def load_data(self,
                      key,
                      tier=None,
                      batch_size=1,
                      key_order=None,
                      stimulus_types=None,
                      Sampler=None):
            from .stats import Oracle
            datasets, loaders = super().load_data(
                key,
                tier=tier,
                batch_size=batch_size,
                key_order=key_order,
                stimulus_types=stimulus_types,
                Sampler=Sampler)
            for rok, dataset in datasets.items():
                member_key = (StaticMultiDataset.Member() & key
                              & dict(name=rok)).fetch1(dj.key)

                okey = dict(key, **member_key)
                okey['data_hash'] = okey.pop('oracle_source')
                units, pearson = (Oracle.UnitScores() & okey).fetch(
                    'unit_id', 'pearson')
                assert len(
                    pearson
                ) > 0, 'You forgot to populate oracle for data_hash="{}"'.format(
                    key['oracle_source'])
                assert len(units) == len(
                    dataset.neurons.unit_ids), 'Number of neurons has changed'
                assert np.all(units == dataset.neurons.unit_ids
                              ), 'order of neurons has changed'

                low, high = np.percentile(
                    pearson, [key['percent_low'], key['percent_high']])
                selection = (pearson >= low) & (pearson <= high)
                log.info(
                    'Subsampling to {} neurons above {:.2f} and below {:.2f} oracle'
                    .format(selection.sum(), low, high))
                dataset.transforms.insert(-1,
                                          Subsample(np.where(selection)[0]))

                assert np.all(dataset.neurons.unit_ids ==
                              units[selection]), 'Units are inconsistent'
            return datasets, loaders
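A minimal standalone sketch of the percentile-band selection performed above, with synthetic oracle correlations and made-up `percent_low`/`percent_high` values:

import numpy as np

pearson = np.random.rand(500)        # synthetic per-unit oracle correlations
percent_low, percent_high = 25, 75   # hypothetical config values

low, high = np.percentile(pearson, [percent_low, percent_high])
selection = (pearson >= low) & (pearson <= high)
print(selection.sum(), low, high)    # roughly half of the units are kept
keep_idx = np.where(selection)[0]    # indices handed to the Subsample transform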
Example 11
    def load_data(self,
                  key,
                  tier=None,
                  batch_size=1,
                  key_order=None,
                  stimulus_types=None,
                  Sampler=None,
                  **kwargs):
        log.info('Ignoring input arguments: "' + '", "'.join(kwargs.keys()) +
                 '" when creating datasets')
        exclude = key.pop('exclude').split(',')
        stimulus_types = key.pop('stimulus_type')
        datasets, loaders = super().load_data(
            key,
            tier,
            batch_size,
            key_order,
            exclude_from_normalization=exclude,
            stimulus_types=stimulus_types,
            Sampler=Sampler)

        log.info('Subsampling to layer {} and area(s) "{}"'.format(
            key['layer'],
            key.get('brain_area') or key['brain_areas']))
        # iterate over a copy of the items since keys may be deleted below
        for readout_key, dataset in list(datasets.items()):
            layers = dataset.neurons.layer
            areas = dataset.neurons.area

            layer_idx = (layers == key['layer'])
            desired_areas = ([
                key['brain_area'],
            ] if 'brain_area' in key else (common_configs.BrainAreas.BrainArea
                                           & key).fetch('brain_area'))
            area_idx = np.stack([areas == da
                                 for da in desired_areas]).any(axis=0)
            idx = np.where(layer_idx & area_idx)[0]
            if len(idx) == 0:
                log.warning('Empty set of neurons. Deleting this key')
                del datasets[readout_key]
                del loaders[readout_key]
            else:
                dataset.transforms.insert(-1, Subsample(idx))
        return datasets, loaders
Example 12
    def get_constraint(dataset, stimulus_type, tier=None):
        """
        Find subentries of dataset that match the given `stimulus_type` specification and `tier` specification.
        `stimulus_type` is of the format `stimulus.Frame|~stimulus.Monet|...`. This function returns a boolean array
        suitable for boolean indexing to obtain only entries with data types and tiers matching the
        specified condition.
        """
        constraint = np.zeros(len(dataset.types), dtype=bool)
        for const in map(lambda s: s.strip(), stimulus_type.split('|')):
            if const.startswith('~'):
                log.info('Using all trials except those from {}'.format(const[1:]))
                tmp = (dataset.types != const[1:])
            else:
                log.info('Using all trials from {}'.format(const))
                tmp = (dataset.types == const)
            constraint = constraint | tmp
        if tier is not None:
            constraint = constraint & (dataset.tiers == tier)
        return constraint
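A minimal standalone sketch of the constraint logic on a mock dataset (the `types`/`tiers` arrays below are made up) showing how a `stimulus.Frame|stimulus.TrippyFrame` specification plus a tier turns into a boolean mask:

import numpy as np
from types import SimpleNamespace

# Mock dataset exposing the two attributes get_constraint reads.
dataset = SimpleNamespace(
    types=np.array(['stimulus.Frame', 'stimulus.Monet',
                    'stimulus.Frame', 'stimulus.TrippyFrame']),
    tiers=np.array(['train', 'train', 'test', 'train']))

constraint = np.zeros(len(dataset.types), dtype=bool)
for const in map(str.strip, 'stimulus.Frame|stimulus.TrippyFrame'.split('|')):
    if const.startswith('~'):
        constraint |= dataset.types != const[1:]
    else:
        constraint |= dataset.types == const
constraint &= dataset.tiers == 'train'
print(constraint)  # [ True False False  True]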
Example 13
    def load_data(self,
                  key,
                  tier=None,
                  batch_size=1,
                  key_order=None,
                  exclude_from_normalization=None,
                  stimulus_types=None,
                  Sampler=None):
        log.info('Loading {} dataset with tier={}'.format(
            self._stimulus_type, tier))
        datasets = StaticMultiDataset().fetch_data(key, key_order=key_order)
        for k, dat in datasets.items():
            if 'stats_source' in key:
                log.info(
                    'Adding stats_source "{stats_source}" to dataset'.format(
                        **key))
                dat.stats_source = key['stats_source']

        log.info('Using statistics source ' + key['stats_source'])

        datasets = self.add_transforms(key,
                                       datasets,
                                       exclude=exclude_from_normalization)

        loaders = self.get_loaders(datasets, tier, batch_size, stimulus_types,
                                   Sampler)
        return datasets, loaders
Example 14
    def make(self, scan_key):
        log.info('Populating\n' + pformat(scan_key))
        v, treadmill_time = self.load_treadmill_velocity(scan_key)
        frame_times = self.load_frame_times(scan_key)
        behavior_clock = self.load_behavior_timing(scan_key)

        if len(frame_times) != len(behavior_clock):
            assert abs(len(frame_times) - len(behavior_clock)) < 2, 'Difference bigger than 2 time points'
            l = min(len(frame_times), len(behavior_clock))
            log.warning('Frametimes and stimulus.BehaviorSync differ in length! Shortening it.')
            frame_times = frame_times[:l]
            behavior_clock = behavior_clock[:l]

        fr2beh = NaNSpline(frame_times, behavior_clock, k=1, ext=3)
        duration, offset = map(float, (Preprocessing() & scan_key).fetch1('duration', 'offset'))
        sample_point = offset + duration / 2

        log.info('Downsampling treadmill signal to {}Hz'.format(1 / duration))

        h_tread = self.get_filter(duration, np.nanmedian(np.diff(treadmill_time)), 'hamming', warning=True)
        treadmill_spline = NaNSpline(treadmill_time, np.abs(np.convolve(v, h_tread, mode='same')), k=1, ext=0)

        flip_times = (InputResponse.Input * Frame * stimulus.Trial & scan_key).fetch('flip_times',
                                                                                     order_by='row_id ASC')

        flip_times = [ft.squeeze() for ft in flip_times]

        # If no Frames are present, skip this scan
        if len(flip_times) == 0:
            log.warning('No static frames were present to be processed for {}'.format(scan_key))
            return

        stimulus_onset = InputResponse.stimulus_onset(flip_times, duration)
        tm = treadmill_spline(fr2beh(stimulus_onset + sample_point))
        valid = ~np.isnan(tm)
        if not np.all(valid):
            log.warning('Found {} NaN trials. Setting to -1'.format((~valid).sum()))
            tm[~valid] = -1

        self.insert1(dict(scan_key, treadmill=tm, valid=valid))
Example 15
    def make(self, scan_key):
        self.insert1(scan_key)
        # integration window size for responses
        duration, offset = map(float, (Preprocessing() & scan_key).fetch1('duration', 'offset'))
        sample_point = offset + duration / 2

        log.info('Sampling neural responses at {}s intervals'.format(duration))

        trace_spline, trace_keys, ftmin, ftmax = self.get_trace_spline(scan_key, duration)
        # exclude trials marked in ExcludedTrial
        log.info('Excluding {} trials based on ExcludedTrial'.format(len(ExcludedTrial() & scan_key)))
        flip_times, trial_keys = (Frame * (stimulus.Trial - ExcludedTrial) & scan_key).fetch('flip_times', dj.key,
                                                                           order_by='condition_hash')
        flip_times = [ft.squeeze() for ft in flip_times]

        # If no Frames are present, skip this scan
        if len(flip_times) == 0:
            log.warning('No static frames were present to be processed for {}'.format(scan_key))
            return

        valid = np.array([ft.min() >= ftmin and ft.max() <= ftmax for ft in flip_times], dtype=bool)
        if not np.all(valid):
            log.warning('Dropping {} trials with dropped frames or flips outside the recording interval'.format(
                (~valid).sum()))

        stimulus_onset = self.stimulus_onset(flip_times, duration)
        log.info('Sampling {} responses {}s after stimulus onset'.format(valid.sum(), sample_point))
        R = trace_spline(stimulus_onset[valid] + sample_point, log=True).T

        self.ResponseBlock.insert1(dict(scan_key, responses=R))
        self.ResponseKeys.insert([dict(scan_key, **trace_key, col_id=i) for i, trace_key in enumerate(trace_keys)])
        self.Input.insert([dict(scan_key, **trial_key, row_id=i)
                           for i, trial_key in enumerate(compress(trial_keys, valid))])
Example 16
        def stats(self, condition_hashes, images, responses, tiers):

            key = self.fetch1()
            # check if the method is eligible for condition_hashes requested
            assert ((
                stimulus.Condition
                & "condition_hash in {}".format(tuple(condition_hashes))
            ).fetch("stimulus_type") == "stimulus.Frame").all(
            ), "StatsConfig.NeuroStaticNoBehFrame is only implemented for stimulus.Frame"
            image_classes = (stimulus.Condition * stimulus.Frame
                             & "condition_hash in {}".format(
                                 tuple(condition_hashes))).fetch("image_class")
            assert set(image_classes) <= set(
                FF_CLASSES
            ), "StatsConfig.NeuroStaticNoBehFrame is only implemented for full-field stimulus"

            # reshape inputs
            images = np.stack(images)
            if len(images.shape) == 3:
                log.info("Adding channel dimension")
                images = images[:, None, ...]
            elif len(images.shape) == 4:
                images = images.transpose(0, 3, 1, 2)

            # compute stats
            if key["stats_tier"] in ("train", "validation", "test"):
                ix = tiers == key["stats_tier"]
            elif key["stats_tier"] == "all":
                ix = np.ones_like(tiers, dtype=bool)
            else:
                raise NotImplementedError(
                    "stats_tier must be one of train, validation, test, all")

            response_statistics = self.run_stats_resp(responses, ix, axis=0)
            input_statistics = self.run_stats_input(
                images, ix, per_input=key["stats_per_input"])
            statistics = dict(images=input_statistics,
                              responses=response_statistics)
            return statistics
Example 17
    def get_loaders(self, datasets, tier, batch_size, stimulus_types, Sampler):
        """

        Args:
            datasets: a dictionary of H5ArrayDataSets
            tier: tier of data to be loaded. Can be 'train', 'validation', 'test', or None
            batch_size: size of a batch to be returned by the data loader
            stimulus_types: stimulus type specification like 'stimulus.Frame|~stimulus.Monet'
            Sampler: sampler to be placed on the data loader. If None, defaults to a sampler chosen based on the tier

        Returns:
            A dictionary of DataLoaders, keyed to match the input dictionary of datasets
        """

        # if Sampler not given, use a default one specified for each tier
        if Sampler is None:
            Sampler = self.get_sampler_class(tier)

        # if only a single stimulus_types string was given, apply to all datasets
        if not isinstance(stimulus_types, list):
            log.info('Using {} as stimulus type for all datasets'.format(
                stimulus_types))
            stimulus_types = len(datasets) * [stimulus_types]

        log.info('Stimulus sources: "{}"'.format('","'.join(stimulus_types)))

        loaders = OrderedDict()
        constraints = [
            self.get_constraint(dataset, stimulus_type,
                                tier=tier) for dataset, stimulus_type in zip(
                                    datasets.values(), stimulus_types)
        ]

        for (k, dataset), stimulus_type, constraint in zip(
                datasets.items(), stimulus_types, constraints):
            log.info(
                'Selecting trials from {} and tier={} for dataset {}'.format(
                    stimulus_type, tier, k))
            ix = np.where(constraint)[0]
            log.info('Found {} active trials'.format(constraint.sum()))
            if Sampler is BalancedSubsetSampler:
                sampler = Sampler(ix, dataset.types, mode='longest')
            else:
                sampler = Sampler(ix)
            loaders[k] = DataLoader(dataset,
                                    sampler=sampler,
                                    batch_size=batch_size)
            self.log_loader(loaders[k])
        return loaders
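A minimal standalone sketch of building one loader over a constrained subset of trials, using a plain `TensorDataset` and `SubsetRandomSampler` in place of the project's H5 datasets and tier-specific samplers (shapes and the constraint below are made up):

import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Toy dataset standing in for an H5ArrayDataSet.
images = torch.randn(100, 1, 36, 64)
responses = torch.randn(100, 50)
dataset = TensorDataset(images, responses)

# Boolean constraint over trials (e.g. from get_constraint); here the first 20.
constraint = np.zeros(100, dtype=bool)
constraint[:20] = True
ix = np.where(constraint)[0]

loader = DataLoader(dataset, sampler=SubsetRandomSampler(ix), batch_size=4)
print(len(loader.sampler))                                    # 20 sampled trials
print(int(np.ceil(len(loader.sampler) / loader.batch_size)))  # 5 batches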
Example 18
        def load_data(self,
                      key,
                      tier=None,
                      batch_size=1,
                      Sampler=None,
                      t_first=False,
                      cuda=False):
            from .stats import BootstrapOracleTTest
            assert tier in [None, 'train', 'validation', 'test']
            datasets, loaders = super().load_data(key,
                                                  tier=tier,
                                                  batch_size=batch_size,
                                                  Sampler=Sampler,
                                                  cuda=cuda)
            for rok, dataset in datasets.items():
                member_key = (StaticMultiDataset.Member() & key
                              & dict(name=rok)).fetch1(dj.key)
                all_units, all_pvals = (BootstrapOracleTTest.UnitPValue
                                        & member_key).fetch(
                                            'unit_id', 'unit_p_value')
                assert len(all_pvals) > 0, \
                    'You forgot to populate BootstrapOracleTTest for group_id={}'.format(
                    member_key['group_id'])
                units_mask = np.isin(all_units, dataset.neurons.unit_ids)
                units, pvals = all_units[units_mask], all_pvals[units_mask]
                assert np.all(units == dataset.neurons.unit_ids
                              ), 'order of neurons has changed'
                pval_thresh = np.power(10, float(key['p_val_power']))
                selection = pvals < pval_thresh
                log.info(
                    'Subsampling to {} neurons with BootstrapOracleTTest p-val < {:.0E}'
                    .format(selection.sum(), pval_thresh))
                dataset.transforms.insert(-1,
                                          Subsample(np.where(selection)[0]))

                assert np.all(dataset.neurons.unit_ids ==
                              units[selection]), 'Units are inconsistent'
            return datasets, loaders
Example 19
    def log_loader(loader):
        """
        A helper function that, given an instance of DataLoader, prints a log detailing its configuration
        """
        log.info('Loader sampler is {}'.format(
            loader.sampler.__class__.__name__))
        log.info('Number of samples in the loader will be {}'.format(
            len(loader.sampler)))
        log.info('Number of batches in the loader will be {}'.format(
            int(np.ceil(len(loader.sampler) / loader.batch_size))))
Example 20
    def make(self, scan_key):
        scan_key = {**scan_key, 'tracking_method': 2}
        log.info('Populating '+ pformat(scan_key))
        radius, xy, eye_time = self.load_eye_traces(scan_key)
        frame_times = self.load_frame_times(scan_key)
        behavior_clock = self.load_behavior_timing(scan_key)

        if len(frame_times) != len(behavior_clock):
            assert abs(len(frame_times) - len(behavior_clock)) < 2, 'Difference bigger than 2 time points'
            l = min(len(frame_times), len(behavior_clock))
            log.info('Frametimes and stimulus.BehaviorSync differ in length! Shortening it.')
            frame_times = frame_times[:l]
            behavior_clock = behavior_clock[:l]

        fr2beh = NaNSpline(frame_times, behavior_clock, k=1, ext=3)

        duration, offset = map(float, (Preprocessing() & scan_key).fetch1('duration', 'offset'))
        sample_point = offset + duration / 2

        log.info('Downsampling eye signal to {}Hz'.format(1 / duration))
        deye = np.nanmedian(np.diff(eye_time))
        h_eye = self.get_filter(duration, deye, 'hamming', warning=True)
        h_deye = self.get_filter(duration, deye, 'dhamming', warning=True)
        pupil_spline = NaNSpline(eye_time,
                                 np.convolve(radius, h_eye, mode='same'), k=1, ext=0)

        dpupil_spline = NaNSpline(eye_time,
                                  np.convolve(radius, h_deye, mode='same'), k=1, ext=0)
        center_spline = SplineCurve(eye_time,
                                    np.vstack([np.convolve(coord, h_eye, mode='same') for coord in xy]),
                                    k=1, ext=0)

        flip_times = (InputResponse.Input * Frame * stimulus.Trial & scan_key).fetch('flip_times',
                                                                                     order_by='row_id ASC')

        flip_times = [ft.squeeze() for ft in flip_times]

        # If no Frames are present, skip this scan
        if len(flip_times) == 0:
            log.warning('No static frames were present to be processed for {}'.format(scan_key))
            return

        stimulus_onset = InputResponse.stimulus_onset(flip_times, duration)
        t = fr2beh(stimulus_onset + sample_point)
        pupil = pupil_spline(t)
        dpupil = dpupil_spline(t)
        center = center_spline(t)
        valid = ~np.isnan(pupil + dpupil + center.sum(axis=0))
        if not np.all(valid):
            log.warning('Found {} NaN trials. Setting to -1'.format((~valid).sum()))
            pupil[~valid] = -1
            dpupil[~valid] = -1
            center[:, ~valid] = -1

        self.insert1(dict(scan_key, pupil=pupil, dpupil=dpupil, center=center, valid=valid))
Example 21
    def fetch_data(self, key, key_order=None):
        ret = OrderedDict()
        log.info("Fetching data for " + repr(key))
        for mkey in (self.Member() * DatasetInputResponse & key).fetch(
                dj.key,
                order_by=
                "animal_id ASC, session ASC, scan_idx ASC, preproc_id ASC"):
            name = (self.Member() & mkey).fetch1("name")
            data_names = DatasetConfig().part_table(mkey).data_names
            log.info("Data will be ({})".format(",".join(data_names)))

            h5filename = DatasetConfig().part_table(mkey).get_filename()
            log.info("Loading dataset {} --> {}".format(name, h5filename))
            ret[name] = datasets.StaticImageSet(h5filename, *data_names)
        if key_order is not None:
            log.info(
                "Reordering datasets according to given key order {}".format(
                    ", ".join(key_order)))
            ret = OrderedDict([(k, ret[k]) for k in key_order])
        return ret
Example 22
    def load_data(self, key, cuda=False, oracle=False, **kwargs):
        data_key = self.data_key(key)
        Data = getattr(self, data_key.pop('data_type'))
        datasets, loaders = Data().load_data(data_key, **kwargs)

        if oracle:
            log.info('Placing oracle data samplers')
            for readout_key, loader in loaders.items():
                ix = loader.sampler.indices
                # types = np.unique(datasets[readout_key].types[ix])
                # if len(types) == 1 and types[0] == 'stimulus.Frame':
                #     condition_hashes = datasets[readout_key].info.frame_image_id
                # elif len(types) == 2 and types[0] in ('stimulus.MonetFrame',  'stimulus.TrippyFrame'):
                #     condition_hashes = datasets[readout_key].condition_hashes
                # elif len(types) == 1 and types[0] == 'stimulus.ColorFrameProjector':
                #     condition_hashes = datasets[readout_key].info.colorframeprojector_image_id
                # else:
                #     raise ValueError('Do not recognize types={}'.format(*types))
                condition_hashes = datasets[readout_key].condition_hashes
                log.info('Replacing ' + loader.sampler.__class__.__name__ +
                         ' with RepeatsBatchSampler')
                Loader = loader.__class__
                loaders[readout_key] = Loader(
                    loader.dataset,
                    batch_sampler=RepeatsBatchSampler(condition_hashes,
                                                      subset_index=ix))

                removed = []
                keep = []
                for tr in datasets[readout_key].transforms:
                    if isinstance(tr, (Subsample, ToTensor)):
                        keep.append(tr)
                    else:
                        removed.append(tr.__class__.__name__)
                datasets[readout_key].transforms = keep
                if len(removed) > 0:
                    log.warning(
                        'Removed the following transforms: "{}"'.format(
                            '", "'.join(removed)))

        log.info('Setting cuda={}'.format(cuda))
        for dat in datasets.values():
            for tr in dat.transforms:
                if isinstance(tr, ToTensor):
                    tr.cuda = cuda

        return datasets, loaders
Example 23
    def make(self, key):
        log.info(80 * '-')
        log.info('Processing ' + pformat(key))
        # count the number of distinct conditions presented for each one of three stimulus types:
        # "stimulus.Frame","stimulus.MonetFrame", "stimulus.TrippyFrame"
        conditions = dj.U('stimulus_type').aggr(stimulus.Condition() & (stimulus.Trial() & key),
                                                count='count(*)') \
                     & 'stimulus_type in ("stimulus.Frame", "stimulus.MonetFrame", "stimulus.TrippyFrame", "stimulus.ColorFrameProjector")'
        for cond in conditions.fetch(as_dict=True):
            # hack for compatibility with previous datasets
            if cond['stimulus_type'] in [
                    'stimulus.Frame', 'stimulus.ColorFrameProjector'
            ]:
                frame_table = (stimulus.Frame
                               if cond['stimulus_type'] == 'stimulus.Frame'
                               else stimulus.ColorFrameProjector)

                # deal with ImageNet frames first
                log.info('Inserting assignment from ImageNetSplit')
                targets = StaticScan * frame_table * ImageNetSplit & (
                    stimulus.Trial & key) & IMAGE_CLASSES
                print('Inserting {} imagenet conditions!'.format(len(targets)))
                self.insert(targets, ignore_extra_fields=True)

                # deal with MEI images, assigning tier "train" for all of them
                assignment = (
                    frame_table &
                    'image_class in ("cnn_mei", "lin_rf", "multi_cnn_mei", "multi_lin_rf")'
                ).proj(tier='"train"')
                self.insert(StaticScan * frame_table * assignment &
                            (stimulus.Trial & key),
                            ignore_extra_fields=True)

                # make sure that all frames were assigned
                remaining = (stimulus.Trial * frame_table & key) - self
                assert len(
                    remaining) == 0, 'There are still unprocessed Frames'
                continue

            log.info('Checking condition {stimulus_type} (n={count})'.format(
                **cond))
            frames = (stimulus.Condition() * StaticScan() & key & cond).aggr(
                stimulus.Trial(), repeats="count(*)", test='count(*) > 4')
            self.check_train_test_split(frames, cond)

            m = len(frames)
            m_test = m_val = len(frames & 'test > 0') or max(m * 0.075, 1)
            log.info('Minimum test and validation set size will be {}'.format(
                m_test))
            log.info('Processing test conditions')

            # insert repeats as test trials
            self.insert((frames & dict(test=1)).proj(tier='"test"'),
                        ignore_extra_fields=True)
            self.fill_up('test', frames, cond, key, m_test)

            log.info('Processing validation conditions')
            self.fill_up('validation', frames, cond, key, m_val)

            log.info('Processing training conditions')
            self.fill_up('train', frames, cond, key, m - m_test - m_val)
Example 24
    def make(self, key):
        # Check new preprocessing
        if key['preproc_id'] < 4:
            raise ValueError(
                'Deprecated preprocessing, use preproc_id >= 4 or downgrade '
                'code to access previous preprocessings.')

        # Get all traces for this scan
        log.info('Getting traces...')
        traces, unit_ids, trace_times = get_traces(key)

        # Get trial times for frames in Scan.Frame (excluding bad trials)
        log.info('Getting onset and offset times for each image...')
        trials_rel = stimulus.Trial * Frame - ExcludedTrial & key
        flip_times, trial_ids, cond_hashes = trials_rel.fetch(
            'flip_times',
            'trial_idx',
            'condition_hash',
            order_by='condition_hash',
            squeeze=True)

        # Find start and duration of image frames
        image_onset = np.stack([get_image_onset(ft)
                                for ft in flip_times])  # start of image
        image_duration = float(
            (Preprocessing & key).fetch1('duration')
        )  # np.stack([ft[2] for ft in flip_times]) - image_onset

        # Add a shift to the onset times to account for the time it takes for the image to
        # travel from the retina to V1
        image_onset += float((Preprocessing & key).fetch1('offset'))
        # Wiskott, L. How does our visual system achieve shift and size invariance?. Problems in Systems Neuroscience, 2003.

        # Sample responses (trace by trace) with a rectangular window
        log.info('Sampling responses...')
        image_resps = np.stack([
            trapezoid_integration(tt, t, image_onset, image_onset +
                                  image_duration) / image_duration
            for tt, t in zip(trace_times, traces)
        ],
                               axis=-1)

        # Insert
        log.info('Inserting...')
        self.insert1(key)
        self.ResponseBlock.insert1({
            **key, 'responses':
            image_resps.astype(np.float32)
        })
        self.Input.insert([{
            **key, 'trial_idx': trial_idx,
            'condition_hash': cond_hash,
            'row_id': i
        } for i, (trial_idx,
                  cond_hash) in enumerate(zip(trial_ids, cond_hashes))])
        self.ResponseKeys.insert([{
            **key, 'unit_id':
            unit_id,
            'col_id':
            i,
            'field': (fuse.Activity.Trace & key & {
                'unit_id': unit_id
            }).fetch1('field'),
            'channel': (fuse.Activity.Trace & key & {
                'unit_id': unit_id
            }).fetch1('channel')
        } for i, unit_id in enumerate(unit_ids)])
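A minimal standalone sketch of averaging a trace over the image-presentation window, analogous to what `trapezoid_integration(...) / image_duration` computes above (the trace, sampling rate, onset and duration are made up, and `np.trapz` over the in-window samples stands in for the project's helper):

import numpy as np

tt = np.arange(0, 10, 1 / 15)       # hypothetical trace time stamps at 15 Hz
t = np.sin(tt)                       # hypothetical trace values

onset, image_duration = 2.0, 0.5     # image onset and presentation duration

# Average response during [onset, onset + image_duration].
window = (tt >= onset) & (tt <= onset + image_duration)
avg_resp = np.trapz(t[window], tt[window]) / image_duration
print(avg_resp)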
Example 25
    def compute_data(self, key=None):
        key = self.fetch1() if key is None else key
        dynamic_scan = (DvScanInfo
                        & {
                            **key,
                            "animal_id": key["animal_id"],
                            "session": key["dynamic_session"],
                            "scan_idx": key["dynamic_scan_idx"],
                        }).fetch1("KEY")
        static_scan = (StaticScan()
                       & {
                           **key,
                           "animal_id": key["animal_id"],
                           "session": key["static_session"],
                           "scan_idx": key["static_scan_idx"],
                       }).fetch1("KEY")
        log.info("Fetching images")
        trial_idx, condition_hashes, images, types = (
            InputConfig().part_table(key).input(static_scan))
        log.info("Fetching responses")
        responses = (DvScanInfo & key).responses(
            trial_idx=trial_idx,
            condition_hashes=condition_hashes,
        )
        dynamic_unit_keys = (DvScanInfo & key).unit_keys()
        log.info("Fetching tiers")
        tiers = TierConfig().part_table(key).tier(static_scan,
                                                  condition_hashes)
        log.info("Fetching layer information")
        layer = LayerConfig().part_table(key).layer(dynamic_unit_keys)
        log.info("Fetching area information")
        area = AreaConfig().part_table(key).area(dynamic_unit_keys)
        log.info("Computing stats")
        statistics = (StatsConfig().part_table(key).stats(
            condition_hashes, images, responses, tiers))
        neurons = dict(
            unit_ids=np.array([k["unit_id"] for k in dynamic_unit_keys
                               ]).astype(np.uint16),
            animal_ids=np.array([
                k["animal_id"] for k in dynamic_unit_keys
            ]).astype(np.uint16),
            sessions=np.array([k["session"] for k in dynamic_unit_keys
                               ]).astype(np.uint8),
            scan_idx=np.array([k["scan_idx"] for k in dynamic_unit_keys
                               ]).astype(np.uint8),
            layer=layer.astype("S"),
            area=area.astype("S"),
        )
        return dict(
            images=images,
            responses=responses,
            types=types.astype("S"),
            condition_hashes=condition_hashes.astype("S"),
            trial_idx=trial_idx.astype(np.uint32),
            neurons=neurons,
            tiers=tiers.astype("S"),
            statistics=statistics,
        )
Example 26
    def compute_data(self, key):
        key = dict((self & key).fetch1(dj.key), **key)
        log.info('Computing dataset for\n' + pformat(key, indent=20))

        # meso or reso?
        pipe = (fuse.ScanDone() * StaticScan() & key).fetch1('pipe')
        pipe = dj.create_virtual_module(pipe, 'pipeline_' + pipe)

        # get data relation
        include_behavior = bool(Eye.proj() * Treadmill.proj() & key)

        assert include_behavior, 'Behavior data is missing!'

        # make sure that including areas and layers does not decrease number of neurons
        assert len(pipe.ScanSet.UnitInfo() * experiment.Layer() * anatomy.AreaMembership() * anatomy.LayerMembership() & key) == \
               len(pipe.ScanSet.UnitInfo() & key), "AreaMembership decreases number of neurons"

        responses = (self.ResponseBlock & key).fetch1('responses')
        trials = Frame() * ConditionTier() * self.Input() * stimulus.Condition(
        ).proj('stimulus_type') & key
        hashes, trial_idxs, tiers, types, images = trials.fetch(
            'condition_hash',
            'trial_idx',
            'tier',
            'stimulus_type',
            'frame',
            order_by='row_id')
        images = np.stack(images)
        if len(images.shape) == 3:
            log.info('Adding channel dimension')
            images = images[:, None, ...]
        elif len(images.shape) == 4:
            images = images.transpose(0, 3, 1, 2)
        hashes = hashes.astype(str)
        types = types.astype(str)

        # gamma correction
        if (Preprocessing & key).fetch1('gamma'):
            log.info('Gamma correcting images.')
            from staticnet_analyses import multi_mei

            if len(multi_mei.ClosestCalibration & key) == 0:
                raise ValueError('No ClosestMonitorCalibration for this scan.')
            f, f_inv = (multi_mei.ClosestCalibration & key).get_fs()
            images = f(images)

        # --- extract information for each trial
        extra_info = pd.DataFrame({
            'condition_hash': hashes,
            'trial_idx': trial_idxs
        })
        dfs = OrderedDict()

        # add information about each stimulus
        for t in map(lambda x: x.split('.')[1], np.unique(types)):
            stim = getattr(stimulus, t)
            rel = stim() * stimulus.Trial() & key
            df = pd.DataFrame(rel.proj(*rel.heading.non_blobs).fetch())
            dfs[t] = df

        on = [
            'animal_id', 'condition_hash', 'scan_idx', 'session', 'trial_idx'
        ]
        for t, df in dfs.items():
            mapping = {
                c: (t.lower() + '_' + c)
                for c in set(df.columns) - set(on)
            }
            dfs[t] = df.rename(index=str, columns=mapping)
        df = list(dfs.values())[0]
        for d in list(dfs.values())[1:]:
            df = df.merge(d, how='outer', on=on)
        extra_info = extra_info.merge(df, on=['condition_hash', 'trial_idx'
                                              ])  # align rows to existing data
        assert len(extra_info) == len(
            trial_idxs), 'Extra information changes in length'
        assert np.all(
            extra_info['condition_hash'] == hashes), 'Hash order changed'
        assert np.all(
            extra_info['trial_idx'] == trial_idxs), 'Trial idx order changed'
        row_info = {}

        for k in extra_info.columns:
            dt = extra_info[k].dtype
            if isinstance(extra_info[k][0], str):
                row_info[k] = np.array(extra_info[k], dtype='S')
            elif dt == np.dtype('O') or dt == np.dtype('<M8[ns]'):
                row_info[k] = np.array(list(map(repr, extra_info[k])),
                                       dtype='S')
            else:
                row_info[k] = np.array(extra_info[k])

        # extract behavior
        if include_behavior:
            pupil, dpupil, pupil_center, valid_eye = (Eye & key).fetch1(
                'pupil', 'dpupil', 'center', 'valid')
            pupil_center = pupil_center.T
            treadmill, valid_treadmill = (Treadmill & key).fetch1(
                'treadmill', 'valid')
            valid = valid_eye & valid_treadmill
            if np.any(~valid):
                log.warning('Found {} invalid trials. Reducing data.'.format(
                    (~valid).sum()))
                hashes = hashes[valid]
                images = images[valid]
                responses = responses[valid]
                trial_idxs = trial_idxs[valid]
                tiers = tiers[valid]
                types = types[valid]
                pupil = pupil[valid]
                dpupil = dpupil[valid]
                pupil_center = pupil_center[valid]
                treadmill = treadmill[valid]
                for k in row_info:
                    row_info[k] = row_info[k][valid]
            behavior = np.c_[pupil, dpupil, treadmill]

        areas, layers, animal_ids, sessions, scan_idxs, unit_ids = (
            self.ResponseKeys * anatomy.AreaMembership *
            anatomy.LayerMembership & key).fetch('brain_area',
                                                 'layer',
                                                 'animal_id',
                                                 'session',
                                                 'scan_idx',
                                                 'unit_id',
                                                 order_by='col_id ASC')

        assert len(np.unique(unit_ids)) == len(unit_ids), \
            'unit ids are not unique, do you have more than one preprocessing method?'

        neurons = dict(unit_ids=unit_ids.astype(np.uint16),
                       animal_ids=animal_ids.astype(np.uint16),
                       sessions=sessions.astype(np.uint8),
                       scan_idx=scan_idxs.astype(np.uint8),
                       layer=layers.astype('S'),
                       area=areas.astype('S'))

        def run_stats(selector, types, ix, axis=None):
            ret = {}
            for t in np.unique(types):
                if not np.any(ix & (types == t)):
                    continue
                data = selector(ix & (types == t))

                ret[t] = dict(mean=data.mean(axis=axis).astype(np.float32),
                              std=data.std(axis=axis,
                                           ddof=1).astype(np.float32),
                              min=data.min(axis=axis).astype(np.float32),
                              max=data.max(axis=axis).astype(np.float32),
                              median=np.median(data,
                                               axis=axis).astype(np.float32))
            data = selector(ix)
            ret['all'] = dict(mean=data.mean(axis=axis).astype(np.float32),
                              std=data.std(axis=axis,
                                           ddof=1).astype(np.float32),
                              min=data.min(axis=axis).astype(np.float32),
                              max=data.max(axis=axis).astype(np.float32),
                              median=np.median(data,
                                               axis=axis).astype(np.float32))
            return ret

        # --- compute statistics
        log.info('Computing statistics on training dataset')
        response_statistics = run_stats(lambda ix: responses[ix],
                                        types,
                                        tiers == 'train',
                                        axis=0)

        input_statistics = run_stats(lambda ix: images[ix], types,
                                     tiers == 'train')

        statistics = dict(images=input_statistics,
                          responses=response_statistics)

        if include_behavior:
            # ---- include statistics
            behavior_statistics = run_stats(lambda ix: behavior[ix],
                                            types,
                                            tiers == 'train',
                                            axis=0)
            eye_statistics = run_stats(lambda ix: pupil_center[ix],
                                       types,
                                       tiers == 'train',
                                       axis=0)

            statistics['behavior'] = behavior_statistics
            statistics['pupil_center'] = eye_statistics

        retval = dict(images=images,
                      responses=responses,
                      types=types.astype('S'),
                      condition_hashes=hashes.astype('S'),
                      trial_idx=trial_idxs.astype(np.uint32),
                      neurons=neurons,
                      item_info=row_info,
                      tiers=tiers.astype('S'),
                      statistics=statistics)
        if include_behavior:
            retval['behavior'] = behavior
            retval['pupil_center'] = pupil_center

        return retval