Exemplo n.º 1
0
class DlcQC(base.QC):
    """Compute QC metrics on the DLC traces of a single camera.

    Every method of this class whose name starts with ``check_`` is collected by :meth:`run`
    and treated as one metric.  A check returns either an outcome string (e.g. 'PASS', 'FAIL',
    'NOT_SET') or a tuple whose first element is the outcome string.
    """

    # Empirical bounding boxes, per camera label, that the mean DLC position must fall into
    # (see `check_mean_in_bbox`)
    bbox = {
        'body': {
            'xrange': range(201, 500),
            'yrange': range(81, 330)
        },
        'left': {
            'xrange': range(301, 700),
            'yrange': range(181, 470)
        },
        'right': {
            'xrange': range(301, 600),
            'yrange': range(110, 275)
        },
    }

    # Dataset name patterns that must be available locally for each camera label
    # (see `_ensure_required_data`)
    dstypes = {
        'left': ['_ibl_leftCamera.dlc.*', '_ibl_leftCamera.times.*', '_ibl_leftCamera.features.*'],
        'right': ['_ibl_rightCamera.dlc.*', '_ibl_rightCamera.times.*', '_ibl_rightCamera.features.*'],
        'body': ['_ibl_bodyCamera.dlc.*', '_ibl_bodyCamera.times.*'],
    }

    def __init__(self, session_path_or_eid, side, **kwargs):
        """
        :param session_path_or_eid: A session eid or path
        :param side: The camera to run QC on; used as a key into `bbox` and `dstypes`
        :param download_data: Optional keyword; whether missing data may be downloaded.
         Defaults to True when an eid (rather than a session path) is given
        :param log: A logging.Logger instance, if None the 'ibllib' logger is used
        :param one: An ONE instance for fetching and setting the QC on Alyx
        """
        # Make sure the type of camera is chosen
        self.side = side
        # When an eid is provided, we will download the required data by default (if necessary)
        download_data = not is_session_path(session_path_or_eid)
        self.download_data = kwargs.pop('download_data', download_data)
        super().__init__(session_path_or_eid, **kwargs)
        self.data = Bunch()

        # QC outcomes map
        self.metrics = None

    def load_data(self, download_data: bool = None) -> None:
        """Load the data required by the checks into `self.data`.

        Data keys:
            - camera_times (float array): camera frame timestamps of the `{side}Camera` object
            - dlc_coords (dict): keys are the points traced by dlc, items are x-y coordinates of
                                 these points over time, those with likelihood <0.9 set to NaN
            - pupilDiameter_raw, pupilDiameter_smooth: pupil diameter features (left and right
                                 cameras only)

        :param download_data: if not None, overrides the `download_data` attribute, which
         controls whether missing data may be downloaded via ONE.
        """
        if download_data is not None:
            self.download_data = download_data
        # The presence of the datasets is only verified when an online ONE instance is available
        if self.one and not self.one.offline:
            self._ensure_required_data()

        alf_path = self.session_path / 'alf'

        # Load times
        self.data['camera_times'] = alfio.load_object(alf_path, f'{self.side}Camera')['times']
        # Load dlc traces
        dlc_df = alfio.load_object(alf_path, f'{self.side}Camera', namespace='ibl')['dlc']
        # Column names are '<target>_x', '<target>_y', '<target>_likelihood'; strip the suffix
        targets = np.unique(['_'.join(col.split('_')[:-1]) for col in dlc_df.columns])
        # Set values to nan if likelihood is too low
        dlc_coords = {}
        for t in targets:
            idx = dlc_df.loc[dlc_df[f'{t}_likelihood'] < 0.9].index
            dlc_df.loc[idx, [f'{t}_x', f'{t}_y']] = np.nan
            # Row 0 holds the x values, row 1 the y values
            dlc_coords[t] = np.array((dlc_df[f'{t}_x'], dlc_df[f'{t}_y']))
        self.data['dlc_coords'] = dlc_coords

        # load pupil diameters
        if self.side in ['left', 'right']:
            features = alfio.load_object(alf_path, f'{self.side}Camera', namespace='ibl')['features']
            self.data['pupilDiameter_raw'] = features['pupilDiameter_raw']
            self.data['pupilDiameter_smooth'] = features['pupilDiameter_smooth']

    def _ensure_required_data(self):
        """
        Ensures the datasets required for QC are local.  If the download_data attribute is True,
        any missing data are downloaded.  If all the data are not present locally at the end of
        it an exception is raised.

        :raises AssertionError: a dataset is missing and downloading is disabled, no ONE
         instance is available, or the download failed
        """
        for ds in self.dstypes[self.side]:
            # Nothing to do when a matching file already exists somewhere in the session folder
            if next(self.session_path.rglob(ds), None):
                continue
            if self.download_data is not True:
                raise AssertionError(f'Dataset {ds} not found locally and download_data is False')
            # Raised explicitly (rather than with `assert`) so the check survives `python -O`
            if self.one is None:
                raise AssertionError('ONE required to download data')
            try:
                self.one.load_dataset(self.eid, ds, download_only=True)
            except ALFObjectNotFound as ex:
                # Keep the original cause attached for easier debugging
                raise AssertionError(
                    f'Dataset {ds} not found locally and failed to download') from ex

    def run(self, update: bool = False, **kwargs) -> (str, dict):
        """
        Run DLC QC checks and return outcome
        :param update: if True, updates the session QC fields on Alyx
        :param download_data: if True, downloads any missing data if required
         (forwarded to `load_data`)
        :returns: overall outcome as a str, a dict of checks and their outcomes
        """
        _log.info(f'Running DLC QC for {self.side} camera, session {self.eid}')
        namespace = f'dlc{self.side.capitalize()}'
        # Only load when no data have been loaded yet (also true for an empty `self.data`)
        if all(x is None for x in self.data.values()):
            self.load_data(**kwargs)

        def is_metric(x):
            return isfunction(x) and x.__name__.startswith('check_')

        checks = getmembers(DlcQC, is_metric)
        # Metric keys are '_<namespace>_<check name without the "check_" prefix>'
        self.metrics = {f'_{namespace}_' + k[6:]: fn(self) for k, fn in checks}

        # A check returns either an outcome string or a tuple starting with the outcome string
        values = [x if isinstance(x, str) else x[0] for x in self.metrics.values()]
        # The overall outcome is the one with the highest criteria code
        code = max(base.CRITERIA[x] for x in values)
        outcome = next(k for k, v in base.CRITERIA.items() if v == code)

        if update:
            extended = {
                k: None if v is None or v == 'NOT_SET'
                else base.CRITERIA[v] < 3 if isinstance(v, str)
                else (base.CRITERIA[v[0]] < 3, *v[1:])  # Convert first value to bool if array
                for k, v in self.metrics.items()
            }
            self.update_extended_qc(extended)
            self.update(outcome, namespace)
        return outcome, self.metrics

    def check_time_trace_length_match(self):
        """
        Check that the length of the DLC traces is the same length as the video.

        :returns: 'FAIL' if any trace length differs from the number of camera times, else 'PASS'
        """
        n_times = self.data['camera_times'].shape[0]
        for target, coords in self.data['dlc_coords'].items():
            if n_times != coords.shape[1]:
                _log.warning(f'{self.side}Camera length of camera.times does not match '
                             f'length of camera.dlc {target}')
                return 'FAIL'
        return 'PASS'

    def check_trace_all_nan(self):
        """
        Check that none of the dlc traces, except for the 'tube' traces, are all NaN.

        :returns: 'FAIL' if the x or the y values of any non-tube trace are all NaN, else 'PASS'
        """
        for target, coords in self.data['dlc_coords'].items():
            if 'tube' in target:
                continue
            if np.isnan(coords[0]).all() or np.isnan(coords[1]).all():
                _log.warning(f'{self.side}Camera dlc trace {target} all NaN')
                return 'FAIL'
        return 'PASS'

    def check_mean_in_bbox(self):
        """
        Empirical bounding boxes around average dlc points, averaged across time and points;
        sessions with points out of this box were often faulty in terms of raw videos

        :returns: 'PASS' if the integer part of the mean x and y position lies within the
         bounding box of this camera, else 'FAIL'.  A mean that cannot be computed (no traces,
         or every trace entirely NaN) is reported as 'FAIL'.
        """
        dlc_coords = self.data['dlc_coords']
        with warnings.catch_warnings():
            # nanmean warns on all-NaN (or empty) input; the resulting NaN is handled below
            warnings.simplefilter("ignore", category=RuntimeWarning)
            x_mean = np.nanmean([np.nanmean(xy[0]) for xy in dlc_coords.values()])
            y_mean = np.nanmean([np.nanmean(xy[1]) for xy in dlc_coords.values()])

        # int(nan) raises ValueError, so a non-finite mean must be caught before the range test
        if not (np.isfinite(x_mean) and np.isfinite(y_mean)):
            _log.warning(f'{self.side}Camera mean dlc position could not be computed')
            return 'FAIL'

        xrange = self.bbox[self.side]['xrange']
        yrange = self.bbox[self.side]['yrange']
        if int(x_mean) not in xrange or int(y_mean) not in yrange:
            return 'FAIL'
        return 'PASS'

    def check_pupil_blocked(self):
        """
        Check if pupil diameter is nan for more than 90 % of the frames
        (might be blocked by a whisker)
        Check if standard deviation is above a threshold, found for bad sessions

        :returns: 'NOT_SET' for the body camera; 'FAIL' if either condition above is met,
         else 'PASS'
        """
        if self.side == 'body':
            return 'NOT_SET'

        if np.mean(np.isnan(self.data['pupilDiameter_raw'])) > 0.9:
            _log.warning(f'{self.eid}, {self.side}Camera, pupil diameter too often NaN')
            return 'FAIL'

        # The standard deviation threshold depends on the camera
        thr = 5 if self.side == 'right' else 10
        if np.nanstd(self.data['pupilDiameter_raw']) > thr:
            _log.warning(f'{self.eid}, {self.side}Camera, pupil diameter too unstable')
            return 'FAIL'

        return 'PASS'

    def check_lick_detection(self):
        """
        Check if both of the two tongue edge points are less than 10 % NaN, indicating that
        wrong points are detected (spout edge, mouth edge)

        :returns: 'NOT_SET' for the body camera; 'FAIL' if both tongue traces are NaN in fewer
         than 10 % of frames, else 'PASS'
        """
        if self.side == 'body':
            return 'NOT_SET'
        dlc_coords = self.data['dlc_coords']
        # Only the x values (row 0) are inspected
        nan_l = np.mean(np.isnan(dlc_coords['tongue_end_l'][0]))
        nan_r = np.mean(np.isnan(dlc_coords['tongue_end_r'][0]))
        if (nan_l < 0.1) and (nan_r < 0.1):
            return 'FAIL'
        return 'PASS'

    def check_pupil_diameter_snr(self):
        """
        Check the signal to noise ratio between the smoothed and the raw pupil diameter.

        The SNR is the variance of the smoothed trace divided by the variance of the difference
        between the smoothed and the raw trace, computed over the frames where both are not NaN.

        :returns: 'NOT_SET' for the body camera, when the pupil data are not loaded, or when
         the SNR cannot be computed (no frame where both traces are valid, or a 0 / 0 ratio);
         otherwise a tuple of ('FAIL' if the SNR is below the camera's threshold else 'PASS',
         SNR rounded to 3 decimals)
        """
        if self.side == 'body':
            return 'NOT_SET'
        thresh = 5 if self.side == 'right' else 10
        if 'pupilDiameter_raw' not in self.data.keys() or 'pupilDiameter_smooth' not in self.data.keys():
            return 'NOT_SET'
        raw = self.data['pupilDiameter_raw']
        smooth = self.data['pupilDiameter_smooth']
        # compute signal to noise ratio between raw and smooth dia
        good_idxs = np.where(~np.isnan(smooth) & ~np.isnan(raw))[0]
        if good_idxs.size == 0:
            # Without any valid sample the ratio would be NaN, which compares False with the
            # threshold and would be reported as a pass
            return 'NOT_SET'
        snr = (np.var(smooth[good_idxs]) /
               (np.var(smooth[good_idxs] - raw[good_idxs])))
        if np.isnan(snr):
            return 'NOT_SET'
        if snr < thresh:
            return 'FAIL', float(round(snr, 3))
        return 'PASS', float(round(snr, 3))
Exemplo n.º 2
0
class MotionAlignment:
    """Align camera frames to wheel movement.

    The motion energy of a region of interest in the video is cross-correlated with the wheel
    speed to estimate by how many frames the two are shifted (see `align_motion`).
    """

    # Region of interest per camera label: two (start, stop) pairs, each turned into a slice
    # and applied as the frame mask in `align_motion`
    roi = {
        'left': ((800, 1020), (233, 1096)),
        'right': ((426, 510), (104, 545)),
        'body': ((402, 481), (31, 103))
    }

    def __init__(self,
                 eid=None,
                 one=None,
                 log=logging.getLogger('ibllib'),
                 **kwargs):
        """
        :param eid: A session eid
        :param one: An ONE instance; a new one is instantiated if None
        :param log: A logging.Logger instance
        :param session_path: Optional keyword; the local session path.  If not given it is
         resolved from the eid
        """
        self.one = one or ONE()
        self.eid = eid
        self.session_path = kwargs.pop('session_path',
                                       None) or self.one.eid2path(eid)
        self.ref = self.one.dict2ref(self.one.path2ref(self.session_path))
        self.log = log
        # NB: the loaded data live in `self.data` (see `load_data`); these attributes are kept
        # for backward compatibility only
        self.trials = self.wheel = self.camera_times = None
        raw_cam_path = self.session_path.joinpath('raw_video_data')
        camera_path = list(raw_cam_path.glob('_iblrig_*Camera.raw.*'))
        # Map of camera label -> raw video file path
        self.video_paths = {vidio.label_from_path(x): x for x in camera_path}
        self.data = Bunch()
        self.alignment = Bunch()

    def align_all_trials(self, side='all'):
        """Align all wheel motion for all trials

        `align_motion` is called once per row of the trials 'intervals' array, with that row
        as the period (assumes each row is a (start, stop) time pair - TODO confirm).
        The results of each call are not collected: `self.alignment` holds those of the last
        aligned trial only.

        :param side: A camera label, an iterable of camera labels, or 'all' for every camera
         with a video file
        :raises ValueError: no video file was found for the given camera
        """
        # The trials are stored in `self.data` by `load_data`
        if self.data.get('trials') is None:
            self.load_data()
        if side == 'all':
            side = list(self.video_paths.keys())
        if not isinstance(side, str):
            # Iterate over sides, then stop: the code below expects a single label
            for s in side:
                self.align_all_trials(s)
            return
        if side not in self.video_paths:
            raise ValueError(f'{side} camera video file not found')
        # Align each trial sequentially
        for period in self.data['trials']['intervals']:
            self.align_motion(period, side=side, display=False)

    @staticmethod
    def set_roi(video_path):
        """Manually set the ROIs for a given set of videos

        Shows the first frame of the video together with an interactive rectangle selector and
        blocks until the figure is closed.

        TODO Improve docstring
        TODO A method for setting ROIs by label

        :param video_path: Path of the video file
        :return: two integer arrays spanning the x and the y extent of the selected rectangle
        """
        frame = vidio.get_video_frame(str(video_path), 0)

        def line_select_callback(eclick, erelease):
            """
            Callback for line selection.

            *eclick* and *erelease* are the press and release events.
            """
            x1, y1 = eclick.xdata, eclick.ydata
            x2, y2 = erelease.xdata, erelease.ydata
            print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))
            return np.array([[x1, x2], [y1, y2]])

        plt.imshow(frame)
        roi = RectangleSelector(
            plt.gca(),
            line_select_callback,
            drawtype='box',
            useblit=True,
            button=[1, 3],  # don't use middle button
            minspanx=5,
            minspany=5,
            spancoords='pixels',
            interactive=True)
        plt.show()
        # Take the first two x corners and the first and last y corners of the selection
        ((x1, x2, *_), (y1, *_, y2)) = roi.corners
        col = np.arange(round(x1), round(x2), dtype=int)
        row = np.arange(round(y1), round(y2), dtype=int)
        return col, row

    def load_data(self, download=False):
        """
        Load wheel, trial and camera timestamp data into `self.data`

        Sets the keys 'wheel', 'trials' and 'camera_times' (a dict of camera label ->
        timestamps).  Nothing is returned.

        :param download: if True the data are loaded via ONE, otherwise they are read from
         the 'alf' folder of the session path
        """
        if download:
            self.data.wheel = self.one.load_object(self.eid, 'wheel')
            self.data.trials = self.one.load_object(self.eid, 'trials')
            cam = self.one.load(self.eid, ['camera.times'], dclass_output=True)
            self.data.camera_times = {
                vidio.label_from_path(url): ts
                for ts, url in zip(cam.data, cam.url)
            }
        else:
            alf_path = self.session_path / 'alf'
            self.data.wheel = alfio.load_object(alf_path,
                                                'wheel',
                                                short_keys=True)
            self.data.trials = alfio.load_object(alf_path, 'trials')
            self.data.camera_times = {
                vidio.label_from_path(x): alfio.load_file_content(x)
                for x in alf_path.glob('*Camera.times*')
            }
        assert all(x is not None for x in self.data.values())

    def _set_eid_or_path(self, session_path_or_eid):
        """Parse a given eID or session path
        If a session UUID is given, resolves and stores the local path and vice versa
        :param session_path_or_eid: A session eid or path
        :raises ValueError: the input is neither a uuid string nor a session path
        """
        self.eid = None
        if is_uuid_string(str(session_path_or_eid)):
            self.eid = session_path_or_eid
            # Try to set session_path if data is found locally
            self.session_path = self.one.eid2path(self.eid)
        elif is_session_path(session_path_or_eid):
            self.session_path = Path(session_path_or_eid)
            if self.one is not None:
                self.eid = self.one.path2eid(self.session_path)
                if not self.eid:
                    self.log.warning(
                        'Failed to determine eID from session path')
        else:
            self.log.error(
                'Cannot run alignment: an experiment uuid or session path is required'
            )
            raise ValueError("'session' must be a valid session path or uuid")

    def align_motion(self,
                     period=(-np.inf, np.inf),
                     side='left',
                     sd_thresh=10,
                     display=False):
        """
        Cross-correlate the video motion energy with the wheel speed within a period.

        The intermediate results are stored in `self.alignment` (label, to_mask, frames, df,
        period, c, xcorr, dt_i).

        :param period: (start, stop) times; samples with start <= t <= stop are used
        :param side: The camera label
        :param sd_thresh: Threshold applied to the second output of `video.motion_energy`;
         only used for the plot
        :param display: if True, the motion energy and the wheel speed are plotted
        :return: the frame offset (dt_i), the maximum of the cross-correlation and the motion
         energy; three Nones if the video frames could not be loaded
        :raises ValueError: no camera timestamp falls within the period
        """
        # Get data samples within period
        wheel = self.data['wheel']
        self.alignment.label = side
        self.alignment.to_mask = lambda ts: np.logical_and(
            ts >= period[0], ts <= period[1])
        camera_times = self.data['camera_times'][side]
        cam_mask = self.alignment.to_mask(camera_times)
        frame_numbers, = np.where(cam_mask)

        if frame_numbers.size == 0:
            raise ValueError('No frames during given period')

        # Motion Energy
        camera_path = self.video_paths[side]
        # Two slices from the camera's ROI plus index 0 on the last axis
        roi = (*[slice(*r) for r in self.roi[side]], 0)
        try:
            # TODO Add function arg to make grayscale
            self.alignment.frames = \
                vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi)
            assert self.alignment.frames.size != 0
        except AssertionError:
            self.log.error('Failed to open video')
            return None, None, None
        self.alignment.df, stDev = video.motion_energy(self.alignment.frames,
                                                       2)
        self.alignment.period = period  # For plotting

        # Calculate rotary encoder velocity trace
        x = camera_times[cam_mask]
        Fs = 1000
        pos, t = wh.interpolate_position(wheel.timestamps,
                                         wheel.position,
                                         freq=Fs)
        v, _ = wh.velocity_smoothed(pos, Fs)
        interp_mask = self.alignment.to_mask(t)
        # Convert to normalized speed
        xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x])
        vs = np.abs(v[interp_mask][xs])
        vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs))

        # FIXME This can be used as a goodness of fit measure
        USE_CV2 = False
        if USE_CV2:
            # convert from numpy format to openCV format
            dfCV = np.float32(self.alignment.df.reshape((-1, 1)))
            reCV = np.float32(vs.reshape((-1, 1)))

            # perform cross correlation
            resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED)

            # convert result back to numpy array
            xcorr = np.asarray(resultCv)
        else:
            xcorr = signal.correlate(self.alignment.df, vs)

        # Cross correlate wheel speed trace with the motion energy
        # NOTE(review): the origin of the constant 2 is not visible here - TODO confirm
        CORRECTION = 2
        self.alignment.c = max(xcorr)
        self.alignment.xcorr = np.argmax(xcorr)
        self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION
        self.log.info(
            f'{side} camera, adjusted by {self.alignment.dt_i} frames')

        if display:
            # Plot the motion energy
            fig, ax = plt.subplots(2, 1, sharex='all')
            # Pad by one sample at each end before plotting against the camera times
            y = np.pad(self.alignment.df, 1, 'edge')
            ax[0].plot(x, y, '-x', label='wheel motion energy')
            thresh = stDev > sd_thresh
            ax[0].vlines(x[np.array(
                np.pad(thresh, 1, 'constant', constant_values=False))],
                         0,
                         1,
                         linewidth=0.5,
                         linestyle=':',
                         label=f'>{sd_thresh} s.d. diff')
            ax[1].plot(t[interp_mask], np.abs(v[interp_mask]))

            # Plot other stuff
            dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]])
            fps = 1 / np.diff(camera_times).mean()
            ax[0].plot(t[interp_mask][xs] - dt,
                       vs,
                       'r-x',
                       label='velocity (shifted)')
            ax[0].set_title('normalized motion energy, %s camera, %.0f fps' %
                            (side, fps))
            ax[0].set_ylabel('rate of change (a.u.)')
            ax[0].legend()
            ax[1].set_ylabel('wheel speed (rad / s)')
            ax[1].set_xlabel('Time (s)')

            title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s'
            fig.suptitle(title, fontsize=16)
            fig.set_size_inches(19.2, 9.89)

        return self.alignment.dt_i, self.alignment.c, self.alignment.df

    def plot_alignment(self, energy=True, save=False):
        """
        Animate the aligned video frames above the wheel position trace.

        Requires `align_motion` to have been run.  In the interactive figure, space toggles
        the animation and the left / right arrow keys step by one frame.

        :param energy: if True, `self.alignment['frames']` is replaced by the output of
         `video.frame_diffs` before plotting
        :param save: if truthy the animation is written to an mp4 file instead of being shown;
         a str or Path is used as the folder to save into
        """
        if not self.alignment:
            self.log.error('No alignment data, run `align_motion` first')
            return
        # Change backend based on save flag
        backend = matplotlib.get_backend().lower()
        if (save and backend != 'agg') or (not save and backend == 'agg'):
            new_backend = 'Agg' if save else 'Qt5Agg'
            self.log.warning('Switching backend from %s to %s', backend,
                             new_backend)
            matplotlib.use(new_backend)
        from matplotlib import pyplot as plt

        # Main animated plots
        fig, axes = plt.subplots(nrows=2)
        title = f'{self.ref}'  # ', from {period[0]:.1f}s - {period[1]:.1f}s'
        fig.suptitle(title, fontsize=16)

        wheel = self.data['wheel']
        wheel_mask = self.alignment['to_mask'](wheel.timestamps)
        ts = self.data['camera_times'][self.alignment['label']]
        frame_numbers, = np.where(self.alignment['to_mask'](ts))
        if energy:
            self.alignment['frames'] = video.frame_diffs(
                self.alignment['frames'], 2)
            # Drop the first and last frame number to go with the frame differences
            frame_numbers = frame_numbers[1:-1]
        # State shared between the callbacks below
        data = {'frame_ids': frame_numbers}

        def init_plot():
            """
            Plot the wheel data for the current trial
            :return: None
            """
            data['im'] = axes[0].imshow(self.alignment['frames'][0])
            axes[0].axis('off')
            axes[0].set_title(f'adjusted by {self.alignment["dt_i"]} frames')

            # Plot the wheel position
            ax = axes[1]
            ax.clear()
            ax.plot(wheel.timestamps[wheel_mask], wheel.position[wheel_mask],
                    '-x')

            ts_0 = frame_numbers[0]
            data['idx_0'] = ts_0 - self.alignment['dt_i']
            ts_0 = ts[ts_0 + self.alignment['dt_i']]
            data['ln'] = ax.axvline(x=ts_0, color='k')
            ax.set_xlim([ts_0 - (3 / 2), ts_0 + (3 / 2)])
            data['frame_num'] = 0
            mkr = find_nearest(wheel.timestamps[wheel_mask], ts_0)

            data['marker'], = ax.plot(wheel.timestamps[wheel_mask][mkr],
                                      wheel.position[wheel_mask][mkr], 'r-x')
            ax.set_ylabel('Wheel position (rad))')
            ax.set_xlabel('Time (s))')
            return

        def animate(i):
            """
            Callback for figure animation.  Sets image data for current frame and moves pointer
            along axis
            :param i: a negative value steps one frame back, any other value one frame
             forward; the frame number wraps around at either end
            :return: the image, the vertical line and the marker artists
            """
            if i < 0:
                data['frame_num'] -= 1
                if data['frame_num'] < 0:
                    data['frame_num'] = len(self.alignment['frames']) - 1
            else:
                data['frame_num'] += 1
                if data['frame_num'] >= len(self.alignment['frames']):
                    data['frame_num'] = 0
            i = data[
                'frame_num']  # NB: This is index for current trial's frame list

            frame = self.alignment['frames'][i]
            t_x = ts[data['idx_0'] + i]
            data['ln'].set_xdata([t_x, t_x])
            axes[1].set_xlim([t_x - (3 / 2), t_x + (3 / 2)])
            data['im'].set_data(frame)

            mkr = find_nearest(wheel.timestamps[wheel_mask], t_x)
            data['marker'].set_data(wheel.timestamps[wheel_mask][mkr],
                                    wheel.position[wheel_mask][mkr])

            return data['im'], data['ln'], data['marker']

        anim = animation.FuncAnimation(fig,
                                       animate,
                                       init_func=init_plot,
                                       frames=(range(len(self.alignment.df))
                                               if save else cycle(range(60))),
                                       interval=20,
                                       blit=False,
                                       repeat=not save,
                                       cache_frame_data=False)
        anim.running = False

        def process_key(event):
            """
            Callback for key presses.
            :param event: a figure key_press_event
            :return: None
            """
            if event.key.isspace():
                if anim.running:
                    anim.event_source.stop()
                else:
                    anim.event_source.start()
                # Logical (not bitwise) negation keeps the flag a real bool
                anim.running = not anim.running
            elif event.key == 'right':
                if anim.running:
                    anim.event_source.stop()
                    anim.running = False
                animate(1)
                fig.canvas.draw()
            elif event.key == 'left':
                if anim.running:
                    anim.event_source.stop()
                    anim.running = False
                animate(-1)
                fig.canvas.draw()

        fig.canvas.mpl_connect('key_press_event', process_key)

        # init_plot()
        # while True:
        #     animate(0)
        if save:
            filename = '%s_%c.mp4' % (self.ref, self.alignment['label'][0])
            if isinstance(save, (str, Path)):
                filename = Path(save).joinpath(filename)
            self.log.info(f'Saving to {filename}')
            # Set up formatting for the movie files
            Writer = animation.writers['ffmpeg']
            writer = Writer(fps=24,
                            metadata=dict(artist='Miles Wells'),
                            bitrate=1800)
            anim.save(str(filename), writer=writer)
        else:
            plt.show()