Example #1
 def get_active_wheel_period(wheel,
                             duration_range=(3., 20.),
                             display=False):
     """
     Attempts to find a period of movement where the wheel accelerates and decelerates for
     the wheel motion alignment QC.
     :param wheel: A Bunch of wheel timestamps and position data
     :param duration_range: The candidates must be within min/max duration range
     :param display: If true, plot the selected wheel movement
     :return: 2-element array comprising the start and end times of the active period
     """
     pos, ts = wh.interpolate_position(wheel.timestamps, wheel.position)
     v, acc = wh.velocity_smoothed(pos, 1000)
     on, off, *_ = wh.movements(ts, acc, pos_thresh=.1, make_plots=False)
     edges = np.c_[on, off]
     indices, _ = np.where(
         np.logical_and(
             np.diff(edges) > duration_range[0],
             np.diff(edges) < duration_range[1]))
     if len(indices) == 0:
         _log.warning(
             'No period of wheel movement found for motion alignment.')
         return None
     # Pick movement somewhere in the middle
     i = indices[int(indices.size / 2)]
     if display:
         _, (ax0, ax1) = plt.subplots(2, 1, sharex='all')
         mask = np.logical_and(ts > edges[i][0], ts < edges[i][1])
         ax0.plot(ts[mask], pos[mask])
         ax1.plot(ts[mask], acc[mask])
     return edges[i]
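A minimal usage sketch for the function above, assuming a wheel Bunch with timestamps and position arrays has already been loaded (e.g. via ONE); the variable names are illustrative:

# Hedged usage sketch: `wheel` is assumed to be a Bunch of wheel timestamps
# and position data, as the docstring above describes
period = get_active_wheel_period(wheel, duration_range=(3., 20.), display=True)
if period is not None:
    print('Active wheel period: %.2f s to %.2f s' % (period[0], period[1]))
else:
    print('no period of wheel movement found')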
Example #2
    def extract_onsets_for_trial(self):
        """
        Extracts the movement onsets and offsets for the current trial
        :return: tuple of onsets, offsets, interpolated timestamps, interpolated positions,
        and position units
        """
        wheel = self._session_data['wheel']
        trials = self._session_data['trials']
        trial_idx = self.trial_num - 1  # Trial numbers start at 1
        # Check the values and units of wheel position
        res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
        # min change in rad and cm for each decoding type
        # [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
        min_change = np.concatenate(
            [2 * np.pi / res, wh.WHEEL_DIAMETER * np.pi / res])
        pos_diff = np.median(np.abs(np.ediff1d(wheel['position'])))

        # find min change closest to min pos_diff
        idx = np.argmin(np.abs(min_change - pos_diff))
        if idx < len(res):
            # Assume values are in radians
            units = 'rad'
            encoding = idx
        else:
            units = 'cm'
            encoding = idx - len(res)
        thresholds = wh.samples_to_cm(np.array([8, 1.5]),
                                      resolution=res[encoding])
        if units == 'rad':
            thresholds = wh.cm_to_rad(thresholds)
        kwargs = {
            'pos_thresh': thresholds[0],
            'pos_thresh_onset': thresholds[1]
        }
        #  kwargs = {'make_plots': True, **kwargs}  # Uncomment for plot

        # Interpolate and get onsets
        pos, t = wh.interpolate_position(wheel['timestamps'],
                                         wheel['position'],
                                         freq=1000)
        # Get the positions and times between our trial start and the next trial start
        if self.quick_load or not self.trial_num:
            try:
                # End of previous trial to beginning of next
                t_mask = np.logical_and(
                    t >= trials['intervals'][trial_idx - 1, 1],
                    t <= trials['intervals'][trial_idx + 1, 0])
            except IndexError:  # We're on the last trial
                # End of previous trial to end of current
                t_mask = np.logical_and(
                    t >= trials['intervals'][trial_idx - 1, 1],
                    t <= trials['intervals'][trial_idx, 1])
        else:
            t_mask = np.ones_like(t, dtype=bool)
        wheel_ts = t[t_mask]
        wheel_pos = pos[t_mask]
        on, off, *_ = wh.movements(wheel_ts, wheel_pos, freq=1000, **kwargs)
        return on, off, wheel_ts, wheel_pos, units
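The encoding inference above compares the median absolute position step against the smallest step each unit/encoding combination could produce. A self-contained sketch of that arithmetic on synthetic data, assuming a 1024-count encoder resolution in place of wh.ENC_RES:

import numpy as np

ENC_RES = 1024  # assumed resolution; use wh.ENC_RES in practice
res = np.array([ENC_RES, ENC_RES / 2, ENC_RES / 4])
min_change = 2 * np.pi / res  # minimum step in radians for 4X, 2X, 1X decoding

# A position trace quantized at the 4X radian step should match index 0
pos = np.cumsum(np.random.choice([-1, 1], 100)) * (2 * np.pi / ENC_RES)
pos_diff = np.median(np.abs(np.ediff1d(pos)))
encoding = np.argmin(np.abs(min_change - pos_diff))
assert encoding == 0  # i.e. radians, 4X encoding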
Example #3
 def test_movements_FPGA(self):
     # These test data are the same as those used in the MATLAB code. Test data are from
     # extracted FPGA wheel data.
     pos, t = wheel.interpolate_position(*self.test_data[1][0], freq=1000)
     expected = self.test_data[1][1]
     thresholds = wheel.samples_to_cm(np.array([8, 1.5]))
     on, off, amp, peak_vel = wheel.movements(
         t, pos, freq=1000, pos_thresh=thresholds[0], pos_thresh_onset=thresholds[1])
     self.assertTrue(np.allclose(on, expected[0], atol=1.e-5), msg='Unexpected onsets')
     self.assertTrue(np.allclose(off, expected[1], atol=1.e-5), msg='Unexpected offsets')
     self.assertTrue(np.allclose(amp, expected[2], atol=1.e-5), msg='Unexpected move amps')
     self.assertTrue(np.allclose(peak_vel, expected[3], atol=1.e-2),
                     msg='Unexpected peak velocities')
Example #4
def plot_wheel_position(wheel_position, wheel_time, trials_df):
    """
    Plots wheel position across trials, colored by which side was chosen

    :param wheel_position: np.array, interpolated wheel position
    :param wheel_time: np.array, interpolated wheel timestamps
    :param trials_df: pd.DataFrame with columns 'stimOn_times' (stimulus onset time for
    each trial) and 'choice' (which side was chosen)
    :returns: matplotlib.axis
    """
    # Interpolate wheel data
    wheel_position, wheel_time = bbox_wheel.interpolate_position(
        wheel_time, wheel_position, freq=1 / T_BIN)
    # Create a window around the stimulus onset
    start_window, end_window = plt_window(trials_df['stimOn_times'])
    # Translating the time window into an index window
    start_idx = insert_idx(wheel_time, start_window)
    end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')
    # Getting the wheel position for each window, normalize to first value of each window
    trials_df['wheel_position'] = [
        wheel_position[start_idx[w]:end_idx[w]] - wheel_position[start_idx[w]]
        for w in range(len(start_idx))
    ]
    # Plotting
    times = np.arange(len(
        trials_df['wheel_position'].iloc[0])) * T_BIN + WINDOW_LAG
    for side, label, color in zip([-1, 1], ['right', 'left'],
                                  ['darkred', '#1f77b4']):
        side_df = trials_df[trials_df['choice'] == side]
        for idx in side_df.index:
            plt.plot(times,
                     side_df.loc[idx, 'wheel_position'],
                     c=color,
                     alpha=0.5,
                     linewidth=0.05)
        plt.plot(times,
                 side_df['wheel_position'].mean(),
                 c=color,
                 linewidth=2,
                 label=f'{label} turn')

    plt.axvline(x=0, linestyle='--', c='k', label='stimOn')
    plt.axhline(y=-0.26, linestyle='--', c='g', label='reward')
    plt.axhline(y=0.26, linestyle='--', c='g', label='reward')
    plt.ylim([-0.27, 0.27])
    plt.xlabel('time [sec]')
    plt.ylabel('wheel position diff to first value [rad]')
    plt.legend(loc='center right')
    plt.title('Wheel position trial avg\n(and individual trials)')
    plt.tight_layout()

    return plt.gca()
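plt_window and insert_idx are helpers from the surrounding module and are not shown here. A plausible minimal stand-in for the index lookup, assuming evenly spaced interpolated timestamps, is np.searchsorted:

# Hedged sketch: np.searchsorted as a stand-in for the insert_idx helper,
# which is assumed to map window start times onto timestamp indices
import numpy as np

T_BIN = 0.02  # assumed bin size in seconds
WINDOW_LEN = 2.0  # assumed window length in seconds
wheel_time = np.arange(0, 100, T_BIN)
start_window = np.array([1.0, 5.5, 42.0])  # illustrative window start times
start_idx = np.searchsorted(wheel_time, start_window)
end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')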
Example #5
def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Extract wheel movements from rotary encoder timestamps and positions
    :param re_ts: numpy array of rotary encoder timestamps
    :param re_pos: numpy array of rotary encoder positions
    :param display: bool, show the wheel position and velocity for the full session with
    detected movements highlighted
    :return: wheel_moves dictionary
    """
    if len(re_ts.shape) == 1:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'
    else:
        _logger.debug('2D wheel timestamps')
        if len(re_pos.shape) > 1:  # Ensure 1D array of positions
            re_pos = re_pos.flatten()
        # Linearly interpolate the times
        x = np.arange(re_pos.size)
        re_ts = np.interp(x, re_ts[:, 0], re_ts[:, 1])

    units, res, enc = infer_wheel_units(re_pos)
    _logger.info('Wheel in %s units using %s encoding', units, enc)

    # The below assertion is violated by Bpod wheel data
    #  assert np.allclose(pos_diff, min_change, rtol=1e-05), 'wheel position skips'

    # Convert the pos threshold defaults from samples to correct unit
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res)
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)
    kwargs = {
        'pos_thresh': thresholds[0],
        'pos_thresh_onset': thresholds[1],
        'make_plots': display
    }

    # Interpolate and get onsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(t, pos, freq=1000, **kwargs)
    assert on.size == off.size, 'onset/offset number mismatch'
    assert np.all(np.diff(on) > 0) and np.all(
        np.diff(off) > 0), 'onsets/offsets not strictly increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    # Put into dict
    wheel_moves = {
        'intervals': np.c_[on, off],
        'peakAmplitude': amp,
        'peakVelocity_times': peak_vel
    }
    return wheel_moves
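A hedged usage sketch for the function above, assuming an ONE instance one and a session eid are already available, as in the other examples:

wheel = one.load_object(eid, 'wheel')
wheel_moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])
print('%i movements detected' % wheel_moves['intervals'].shape[0])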
Example #6
def Viewer(eid, video_type, trial_range, save_video=True, eye_zoom=False):
    '''
    eid: session id, e.g. '3663d82b-f197-4e8b-b299-7b803a155b84'
    video_type: one of 'left', 'right', 'body'
    trial_range: first and last trial number of range to be shown, e.g. [5,7]
    save_video: if True, the video is displayed and saved in a local folder
    eye_zoom: if True, zoom in on the eye region

    Example usage to view and save labeled video with wheel angle:
    Viewer('3663d82b-f197-4e8b-b299-7b803a155b84', 'left', [5,7])
    3D example: 'cb2ad999-a6cb-42ff-bf71-1774c57e5308', [5,7]
    '''

    save_vids_here = '/home/mic/'
    if not save_vids_here.endswith('/'):
        raise ValueError('save_vids_here must end with a slash')

    one = ONE()
    dataset_types = [
        'camera.times', 'wheel.position', 'wheel.timestamps',
        'trials.intervals', 'camera.dlc'
    ]

    a = one.list(eid, 'dataset-types')

    assert all([i in a for i in dataset_types
                ]), 'For this eid, not all data available'

    D = one.load(eid, dataset_types=dataset_types, dclass_output=True)
    alf_path = Path(D.local_path[0]).parent.parent / 'alf'

    # Download a single video
    video_data = alf_path.parent / 'raw_video_data'
    download_raw_video(eid, cameras=[video_type])
    video_path = list(video_data.rglob('_iblrig_%sCamera.raw.*' %
                                       video_type))[0]
    print(video_path)

    # that gives cam time stamps and DLC output (change to alf_path eventually)

    cam = alf.io.load_object(alf_path,
                             '%sCamera' % video_type,
                             namespace='ibl')

    # just to read in times for newer data (which has DLC results in pqt format)
    # cam = alf.io.load_object(alf_path, '_ibl_%sCamera' % video_type)

    # set where to read and save video and get video info
    cap = cv2.VideoCapture(video_path.as_uri())
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    size = (int(cap.get(3)), int(cap.get(4)))

    assert length < len(cam['times']), '#frames > #stamps'
    print(eid, ', ', video_type, ', fps:', fps, ', #frames:', length,
          ', #stamps:', len(cam['times']), ', #frames - #stamps = ',
          length - len(cam['times']))

    # pick trial range for which to display stuff
    trials = alf.io.load_object(alf_path, 'trials', namespace='ibl')
    num_trials = len(trials['intervals'])
    if trial_range[-1] > num_trials - 1:
        print('There are only %s trials' % num_trials)

    frame_start = find_nearest(cam['times'],
                               [trials['intervals'][trial_range[0]][0]])
    frame_stop = find_nearest(cam['times'],
                              [trials['intervals'][trial_range[-1]][1]])
    '''
    wheel related stuff
    '''

    wheel = alf.io.load_object(alf_path, 'wheel', namespace='ibl')
    import brainbox.behavior.wheel as wh
    try:
        pos, t = wh.interpolate_position(wheel['timestamps'],
                                         wheel['position'],
                                         freq=1000)
    except BaseException:
        pos, t = wh.interpolate_position(wheel['times'],
                                         wheel['position'],
                                         freq=1000)

    w_start = find_nearest(t, trials['intervals'][trial_range[0]][0])
    w_stop = find_nearest(t, trials['intervals'][trial_range[-1]][1])

    # confine to interval
    pos_int = pos[w_start:w_stop]
    t_int = t[w_start:w_stop]

    # alignment of cam stamps and interpolated wheel stamps
    wheel_pos = []
    kk = 0
    for wt in cam['times'][frame_start:frame_stop]:
        wheel_pos.append(pos_int[find_nearest(t_int, wt)])
        kk += 1
        if kk % 3000 == 0:
            print('iteration', kk)
    '''
    DLC related stuff
    '''
    Times = cam['times'][frame_start:frame_stop]
    del cam['times']

    # some exception handling for inconsistent data formats
    try:
        dlc_name = '_ibl_%sCamera.dlc.pqt' % video_type
        dlc_path = alf_path / dlc_name
        cam = pd.read_parquet(dlc_path, engine="fastparquet")
        print('it is pqt')
    except BaseException:
        raw_vid_path = alf_path.parent / 'raw_video_data'
        cam = alf.io.load_object(raw_vid_path,
                                 '%sCamera' % video_type,
                                 namespace='ibl')

    points = np.unique(['_'.join(x.split('_')[:-1]) for x in cam.keys()])
    if len(points) == 1:
        cam = cam['dlc']
        points = np.unique(['_'.join(x.split('_')[:-1]) for x in cam.keys()])

    if video_type != 'body':
        d = list(points)
        d.remove('tube_top')
        d.remove('tube_bottom')
        points = np.array(d)

    # Set values to nan if likelihood is too low # for pqt: .to_numpy()
    XYs = {}
    for point in points:
        x = np.ma.masked_where(cam[point + '_likelihood'] < 0.9,
                               cam[point + '_x'])
        x = x.filled(np.nan)
        y = np.ma.masked_where(cam[point + '_likelihood'] < 0.9,
                               cam[point + '_y'])
        y = y.filled(np.nan)
        XYs[point] = np.array(
            [x[frame_start:frame_stop], y[frame_start:frame_stop]])

    # Just for 3D testing
    # return XYs

    # Zoom at eye
    if eye_zoom:
        pivot = np.nanmean(XYs['pupil_top_r'], axis=1)
        x0 = int(pivot[0]) - 33
        x1 = int(pivot[0]) + 33
        y0 = int(pivot[1]) - 28
        y1 = int(pivot[1]) + 38
        size = (66, 66)
        dot_s = 1  # [px] for painting DLC dots

    else:
        x0 = 0
        x1 = size[0]
        y0 = 0
        y1 = size[1]
        if video_type == 'left':
            dot_s = 10  # [px] for painting DLC dots
        else:
            dot_s = 5

    if save_video:
        loc = save_vids_here + '%s_trials_%s_%s_%s.mp4' % (
            eid, trial_range[0], trial_range[-1], video_type)
        out = cv2.VideoWriter(loc, cv2.VideoWriter_fourcc(*'mp4v'), fps,
                              size)  # put , 0 if grey scale

    # writing stuff on frames
    font = cv2.FONT_HERSHEY_SIMPLEX

    if video_type == 'left':
        bottomLeftCornerOfText = (20, 1000)
        fontScale = 4
    else:
        bottomLeftCornerOfText = (10, 500)
        fontScale = 2

    lineType = 2

    # assign a color to each DLC point using the Spectral colormap
    # (uncomment the line in the loop below to make all points red)
    cmap = matplotlib.cm.get_cmap('Spectral')
    CR = np.arange(len(points)) / len(points)

    block = np.ones((2 * dot_s, 2 * dot_s, 3))

    # set start frame
    cap.set(1, frame_start)

    k = 0
    while (cap.isOpened()):
        ret, frame = cap.read()
        gray = frame

        # print wheel angle
        fontColor = (255, 255, 255)
        Angle = round(wheel_pos[k], 2)
        Time = round(Times[k], 3)
        cv2.putText(gray, 'Wheel angle: ' + str(Angle), bottomLeftCornerOfText,
                    font, fontScale / 2, fontColor, lineType)

        a, b = bottomLeftCornerOfText
        bottomLeftCornerOfText0 = (int(a * 10 + b / 2), b)
        cv2.putText(gray, '  time: ' + str(Time), bottomLeftCornerOfText0,
                    font, fontScale / 2, fontColor, lineType)

        # print DLC dots
        ll = 0
        for point in points:

            # Put point color legend
            fontColor = (np.array([cmap(CR[ll])]) * 255)[0][:3]
            a, b = bottomLeftCornerOfText
            if video_type == 'right':
                bottomLeftCornerOfText2 = (a, a * 2 * (1 + ll))
            else:
                bottomLeftCornerOfText2 = (b, a * 2 * (1 + ll))
            fontScale2 = fontScale / 4
            cv2.putText(gray, point, bottomLeftCornerOfText2, font, fontScale2,
                        fontColor, lineType)

            X0 = XYs[point][0][k]
            Y0 = XYs[point][1][k]
            # transform for opencv?
            X = Y0
            Y = X0

            if not np.isnan(X) and not np.isnan(Y):
                col = (np.array([cmap(CR[ll])]) * 255)[0][:3]
                # col = np.array([0, 0, 255]) # all points red
                X = X.astype(int)
                Y = Y.astype(int)
                gray[X - dot_s:X + dot_s, Y - dot_s:Y + dot_s] = block * col
            ll += 1

        gray = gray[y0:y1, x0:x1]
        if save_video:
            out.write(gray)
        cv2.imshow('frame', gray)
        cv2.waitKey(1)
        k += 1
        if k == (frame_stop - frame_start) - 1:
            break

    if save_video:
        out.release()
    cap.release()
    cv2.destroyAllWindows()
Example #7
n_trials = 3  # Number of trials to plot
# Randomly select the trials to plot
trial_ids = np.random.randint(trial_data['choice'].size, size=n_trials)
fig, axs = plt.subplots(1, n_trials, figsize=(8.5, 2.5))
plt.tight_layout()

# Plot go cue and response times
goCues = trial_data['goCue_times'][trial_ids]
responses = trial_data['response_times'][trial_ids]

# Plot traces between trial intervals
starts = trial_data['intervals'][trial_ids, 0]
ends = trial_data['intervals'][trial_ids, 1]
# Cut up the wheel vectors
Fs = 1000
pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs)
vel, acc = wh.velocity_smoothed(pos, Fs)

traces = wh.traces_by_trial(t, pos, start=starts, end=ends)
zipped = zip(traces, axs, goCues, responses, trial_ids)

for (trace, ax, go, resp, n) in zipped:
    ax.plot(trace[0], trace[1], 'k-')
    ax.axvline(x=go, color='g', label='go cue', linestyle=':')
    ax.axvline(x=resp, color='r', label='threshold', linestyle=':')
    ax.set_title('Trial #%s' % n)

    # Turn off tick labels
    ax.set_yticklabels([])
    ax.set_xticklabels([])
Example #8
    def make(self, key):
        # Load the wheel for this session
        move_key = key.copy()
        one = ONE()
        eid, ver = (acquisition.Session & key).fetch1('session_uuid',
                                                      'task_protocol')
        logger.info('WheelMoves for session %s, %s', str(eid), ver)

        try:  # Should be able to remove this
            wheel = one.load_object(str(eid), 'wheel')
            all_loaded = \
                all([isinstance(wheel[lab], np.ndarray) for lab in wheel]) and \
                all(k in wheel for k in ('timestamps', 'position'))
            assert all_loaded, 'wheel data missing'
            alf.io.check_dimensions(wheel)
            if len(wheel['timestamps'].shape) == 1:
                assert wheel['timestamps'].size == wheel[
                    'position'].size, 'wheel data dimension mismatch'
                assert np.all(
                    np.diff(wheel['timestamps']) > 0
                ), 'wheel timestamps not monotonically increasing'
            else:
                logger.debug('2D timestamps')
            # Check the values and units of wheel position
            res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
            min_change_rad = 2 * np.pi / res
            min_change_cm = wh.WHEEL_DIAMETER * np.pi / res
            pos_diff = np.abs(np.ediff1d(wheel['position']))
            if pos_diff.min() < min_change_cm.min():
                # Assume values are in radians
                units = 'rad'
                encoding = np.argmin(np.abs(min_change_rad - pos_diff.min()))
                min_change = min_change_rad[encoding]
            else:
                units = 'cm'
                encoding = np.argmin(np.abs(min_change_cm - pos_diff.min()))
                min_change = min_change_cm[encoding]
            enc_names = {0: '4X', 1: '2X', 2: '1X'}
            logger.info('Wheel in %s units using %s encoding', units,
                        enc_names[int(encoding)])
            if '_iblrig_tasks_ephys' in ver:
                assert np.allclose(pos_diff, min_change,
                                   rtol=1e-05), 'wheel position skips'
        except ValueError:
            logger.exception('Inconsistent wheel data')
            raise
        except AssertionError as ex:
            logger.exception(str(ex))
            raise
        except Exception as ex:
            logger.exception(str(ex))
            raise

        try:
            # Convert the pos threshold defaults from samples to correct unit
            thresholds = wh.samples_to_cm(np.array([8, 1.5]),
                                          resolution=res[encoding])
            if units == 'rad':
                thresholds = wh.cm_to_rad(thresholds)
            kwargs = {
                'pos_thresh': thresholds[0],
                'pos_thresh_onset': thresholds[1]
            }
            #  kwargs = {'make_plots': True, **kwargs}
            # Interpolate and get onsets
            pos, t = wh.interpolate_position(wheel['timestamps'],
                                             wheel['position'],
                                             freq=1000)
            on, off, amp, peak_vel = wh.movements(t, pos, freq=1000, **kwargs)
            assert on.size == off.size, 'onset/offset number mismatch'
            assert np.all(np.diff(on) > 0) and np.all(np.diff(
                off) > 0), 'onsets/offsets not monotonically increasing'
            assert np.all((off - on) > 0), 'not all offsets occur after onset'
        except ValueError:
            logger.exception('Failed to find movements')
            raise
        except AssertionError as ex:
            logger.exception('Wheel integrity check failed: ' + str(ex))
            raise

        key['n_movements'] = on.size  # total number of movements within the session
        key['total_displacement'] = float(np.diff(
            pos[[0, -1]]))  # total displacement of the wheel during session
        key['total_distance'] = float(np.abs(
            np.diff(pos)).sum())  # total movement of the wheel
        if units == 'cm':  # convert to radians
            key['total_displacement'] = wh.cm_to_rad(key['total_displacement'])
            key['total_distance'] = wh.cm_to_rad(key['total_distance'])
            amp = wh.cm_to_rad(amp)

        self.insert1(key)

        keys = ('move_id', 'movement_onset', 'movement_offset', 'max_velocity',
                'movement_amplitude')
        moves = [
            dict(zip(keys, (i, on[i], off[i], amp[i], peak_vel[i])))
            for i in np.arange(on.size)
        ]
        [x.update(move_key) for x in moves]

        self.Move.insert(moves)
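The threshold conversion used above can be checked in isolation: the defaults are expressed in encoder samples, converted to cm, and then to radians when the position data are in radians. A minimal sketch, assuming brainbox.behavior.wheel is importable as wh:

import numpy as np
import brainbox.behavior.wheel as wh

res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
thresholds_cm = wh.samples_to_cm(np.array([8, 1.5]), resolution=res[0])  # samples -> cm
thresholds_rad = wh.cm_to_rad(thresholds_cm)  # cm -> rad, for data in radians
print(thresholds_cm, thresholds_rad)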
Example #9
    def make(self, key, one=None):
        # Load the wheel for this session
        move_key = key.copy()
        change_key = move_key.copy()
        one = one or ONE()
        eid, ver = (acquisition.Session & key).fetch1('session_uuid', 'task_protocol')
        logger.info('WheelMoves for session %s, %s', str(eid), ver)

        try:  # Should be able to remove this
            wheel = one.load_object(str(eid), 'wheel')
            all_loaded = \
                all([isinstance(wheel[lab], np.ndarray) for lab in wheel]) and \
                all(k in wheel for k in ('timestamps', 'position'))
            assert all_loaded, 'wheel data missing'

            # If times and timestamps present, drop times
            if {'times', 'timestamps'}.issubset(wheel):
                wheel.pop('times')
            wheel_moves = extract_wheel_moves(wheel.timestamps, wheel.position)
        except ValueError:
            logger.exception('Failed to find movements')
            raise
        except AssertionError as ex:
            logger.exception(str(ex))
            raise
        except Exception as ex:
            logger.exception(str(ex))
            raise

        # Build list of table entries
        keys = ('move_id', 'movement_onset', 'movement_offset', 'max_velocity', 'movement_amplitude')
        on_off, amp, vel_t = wheel_moves.values()  # Unpack into short vars
        moves = [dict(zip(keys, (i, on, off, vel_t[i], amp[i])), **move_key)
                 for i, (on, off) in enumerate(on_off)]

        # Calculate direction changes
        Fs = 1000
        re_ts, re_pos = wheel.timestamps, wheel.position
        if len(re_ts.shape) != 1:
            logger.info('2D wheel timestamps')
            if len(re_pos.shape) > 1:  # Ensure 1D array of positions
                re_pos = re_pos.flatten()
            # Linearly interpolate the times
            x = np.arange(re_pos.size)
            re_ts = np.interp(x, re_ts[:, 0], re_ts[:, 1])

        pos, ts = wh.interpolate_position(re_ts, re_pos, freq=Fs)
        vel, _ = wh.velocity_smoothed(pos, Fs)
        change_mask = np.insert(np.diff(np.sign(vel)) != 0, 0, 0)

        changes = []
        for i, (on, off) in enumerate(on_off.reshape(-1, 2)):
            mask = np.logical_and(ts > on, ts < off)
            ind = np.logical_and(mask, change_mask)
            changes.extend(
                dict(change_key, move_id=i, change_id=j, change_time=t) for j, t in enumerate(ts[ind])
            )

        # Get the units of the position data
        units, *_ = infer_wheel_units(wheel.position)
        key['n_movements'] = wheel_moves['intervals'].shape[0]  # total number of movements within the session
        key['total_displacement'] = float(np.diff(wheel.position[[0, -1]]))  # total displacement of the wheel during session
        key['total_distance'] = float(np.abs(np.diff(wheel.position)).sum())  # total movement of the wheel
        key['n_direction_changes'] = sum(change_mask)  # total number of direction changes
        if units == 'cm':  # convert to radians
            key['total_displacement'] = wh.cm_to_rad(key['total_displacement'])
            key['total_distance'] = wh.cm_to_rad(key['total_distance'])
            wheel_moves['peakAmplitude'] = wh.cm_to_rad(wheel_moves['peakAmplitude'])

        # Insert the keys in order
        self.insert1(key)
        self.Move.insert(moves)
        self.DirectionChange.insert(changes)
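The direction-change detection above hinges on sign flips of the smoothed velocity. A self-contained sketch of the same mask on a synthetic trace:

import numpy as np

vel = np.array([0.5, 0.2, -0.1, -0.4, 0.3, 0.6, -0.2])
# A sign flip between consecutive samples marks a direction change; the
# leading insert keeps the mask the same length as the velocity trace
change_mask = np.insert(np.diff(np.sign(vel)) != 0, 0, 0)
print(np.where(change_mask)[0])  # [2 4 6]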
Example #10
    def __init__(self,
                 eid=None,
                 trial=None,
                 camera='left',
                 dlc_features=None,
                 quick_load=True,
                 t_win=3,
                 one=None,
                 start=True):
        """
        Plot the wheel trace alongside the video frames.  Below is list of key bindings:
        :key n: plot movements of next trial
        :key p: plot movements of previous trial
        :key r: plot movements of a random trial
        :key t: prompt for a trial number to plot
        :key l: toggle between legend for wheel and trial events
        :key space: pause/play frames
        :key left: move to previous frame
        :key right: move to next frame

        :param eid: uuid of experiment session to load
        :param trial: the trial id to plot
        :param camera: the camera position to load, options: 'left' (default), 'right', 'body'
        :param dlc_features: tuple of DLC features to overlay onto the frames
        :param quick_load: when True, movement onset detection is performed on individual
        trials instead of the entire session
        :param t_win: the window in seconds over which to plot the wheel trace
        :param one: an ONE instance; if None, a new one is instantiated
        :param start: if False, the Viewer must be started by calling the `run` method
        :return: Viewer object
        """
        self._logger = logging.getLogger('ibllib')

        self.t_win = t_win  # Time window of wheel plot
        self.one = one or ONE()
        self.quick_load = quick_load

        # Input validation
        if camera not in ['left', 'right', 'body']:
            raise ValueError(
                "camera must be one of 'left', 'right', or 'body'")

        # If None, randomly pick a session to load
        if not eid:
            self._logger.info('Finding random session')
            eids = self.find_sessions(dlc=dlc_features is not None)
            eid = random.choice(eids)
            ref = eid2ref(eid, as_dict=False, one=self.one)
            self._logger.info('using session %s (%s)', eid, ref)
        elif not is_uuid_string(eid):
            raise ValueError(f'"{eid}" is not a valid session uuid')

        # Store complete session data: trials, timestamps, etc.
        ref = eid2ref(eid, one=self.one, parse=False)
        self._session_data = {'eid': eid, 'ref': ref, 'dlc': None}
        self._plot_data = {
        }  # Holds data specific to current plot, namely data for single trial

        # Download the DLC data if required
        if dlc_features:
            self._session_data['dlc'] = self.get_dlc(dlc_features,
                                                     camera=camera)

        # These are for the dict returned by ONE
        trial_data = self.get_trial_data('ONE')
        total_trials = trial_data['intervals'].shape[0]
        trial = random.randint(0, total_trials) if not trial else trial
        self._session_data['total_trials'] = total_trials
        self._session_data['trials'] = trial_data

        # Check for local first movement times
        first_moves = self.one.path_from_eid(
            eid) / 'alf' / '_ibl_trials.firstMovement_times.npy'
        if first_moves.exists() and 'firstMovement_times' not in trial_data:
            # Load file if exists locally
            self._session_data['trials']['firstMovement_times'] = np.load(
                first_moves)

        # Download the raw video for left camera only
        self.video_path, = self.download_raw_video(camera)
        cam_ts = self.one.load(self._session_data['eid'], ['camera.times'],
                               dclass_output=True)
        cam_ts, = [
            ts for ts, url in zip(cam_ts.data, cam_ts.url) if camera in url
        ]
        # _, cam_ts, _ = one.load(eid, ['camera.times'])  # leftCamera is in the middle of the list
        Fs = 1 / np.diff(
            cam_ts).mean()  # Approx. frequency of camera timestamps
        # Verify video frames and timestamps agree
        _, fps, count = get_video_frames_preload(self.video_path, [])

        if count != cam_ts.size:
            assert count <= cam_ts.size, 'fewer camera timestamps than frames'
            msg = 'number of timestamps does not match number video file frames: '
            self._logger.warning(msg + '%i more timestamps than frames',
                                 cam_ts.size - count)

        assert Fs - fps < 1, 'camera timestamps do not match reported frame rate'
        self._logger.info("Frame rate = %.0fHz", fps)
        # cam_ts = cam_ts[-count:]  # Remove extraneous timestamps
        self._session_data['camera_ts'] = cam_ts

        # Load wheel data
        self._session_data['wheel'] = self.one.load_object(
            self._session_data['eid'], 'wheel')
        if 'firstMovement_times' in self._session_data['trials']:
            pos, t = wh.interpolate_position(
                self._session_data['wheel']['timestamps'],
                self._session_data['wheel']['position'],
                freq=1000)

        # Plot the first frame in the upper subplot
        fig, axes = plt.subplots(nrows=2)
        fig.canvas.mpl_disconnect(
            fig.canvas.manager.key_press_handler_id)  # Disable defaults
        fig.canvas.mpl_connect(
            'key_press_event',
            self.process_key)  # Connect our own key press fn

        self._plot_data['figure'] = fig
        self._plot_data['axes'] = axes
        self._trial_num = trial

        self.anim = animation.FuncAnimation(fig,
                                            self.animate,
                                            init_func=self.init_plot,
                                            frames=cycle(range(60)),
                                            interval=20,
                                            blit=False,
                                            repeat=True,
                                            cache_frame_data=False)
        self.anim.running = False
        self.trial_num = trial  # Set trial and prepare plot/frame data
        if start:
            self.run()
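A hedged usage sketch, assuming the constructor above belongs to a Viewer class (as its docstring's ':return: Viewer object' suggests): with no eid or trial given, it picks a random session and trial itself.

# Illustrative only; assumes the Viewer class and its dependencies are importable
v = Viewer(eid=None, trial=None, camera='left', quick_load=True)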
Example #11
def plot_wheel_position(eid):
    '''
    Illustrate the wheel position next to the distance plot
    '''
    T_BIN = 0.02
    rt = 2
    st = -0.5

    d = constant_reaction_time(eid, rt, st)

    one = ONE()
    wheel = one.load_object(eid, 'wheel')
    pos, t = wh.interpolate_position(wheel.timestamps,
                                     wheel.position,
                                     freq=1 / T_BIN)
    whe_left = []
    whe_right = []

    for i in d:

        start_idx = find_nearest(t, d[i][0])
        end_idx = start_idx + int(d[i][1] / T_BIN)

        wheel_pos = pos[start_idx:end_idx]
        if len(wheel_pos) == 1:
            print(i, [start_idx, end_idx])

        wheel_pos = wheel_pos - wheel_pos[0]

        if d[i][4] == -1:
            whe_left.append(wheel_pos)
        if d[i][4] == 1:
            whe_right.append(wheel_pos)

    xs = np.arange(len(whe_left[0])) * T_BIN
    times = np.concatenate([
        -1 * np.array(list(reversed(xs[:int(len(xs) * abs(st / rt))]))),
        np.array(xs[:int(len(xs) * (1 - abs(st / rt)))])
    ])

    for i in range(len(whe_left)):
        plt.plot(times, whe_left[i], c='#1f77b4', alpha=0.5, linewidth=0.05)
    for i in range(len(whe_right)):
        plt.plot(times, whe_right[i], c='darkred', alpha=0.5, linewidth=0.05)

    plt.plot(times,
             np.mean(whe_left, axis=0),
             c='#1f77b4',
             linewidth=2,
             label='left')
    plt.plot(times,
             np.mean(whe_right, axis=0),
             c='darkred',
             linewidth=2,
             label='right')

    plt.axhline(y=0.26, linestyle='--', c='k')
    plt.axhline(y=-0.26, linestyle='--', c='k', label='reward boundary')
    plt.axvline(x=0, linestyle='--', c='g', label='stimOn')
    axes = plt.gca()
    #axes.set_xlim([0,rt])
    axes.set_ylim([-0.27, 0.27])
    plt.xlabel('time [sec]')
    plt.ylabel('wheel position [rad]')
    plt.legend(loc='lower right')
    plt.title('wheel positions colored by choice')
    plt.tight_layout()
Example #12
def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Extract wheel movements from rotary encoder timestamps and positions
    :param re_ts: numpy array of rotary encoder timestamps
    :param re_pos: numpy array of rotary encoder positions
    :param display: bool, show the wheel position and velocity for the full session with
    detected movements highlighted
    :return: wheel_moves dictionary
    """
    if len(re_ts.shape) == 1:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'
        assert np.all(np.diff(re_ts) > 0
                      ), 'wheel timestamps not monotonically increasing'
    else:
        _logger.debug('2D wheel timestamps')

    # Check the values and units of wheel position
    res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
    # min change in rad and cm for each decoding type
    # [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
    min_change = np.concatenate(
        [2 * np.pi / res, wh.WHEEL_DIAMETER * np.pi / res])
    pos_diff = np.abs(np.ediff1d(re_pos)).min()

    # find min change closest to min pos_diff
    idx = np.argmin(np.abs(min_change - pos_diff))
    if idx < len(res):
        # Assume values are in radians
        units = 'rad'
        encoding = idx
    else:
        units = 'cm'
        encoding = idx - len(res)
    enc_names = {0: 'X4', 1: 'X2', 2: 'X1'}
    _logger.info('Wheel in %s units using %s encoding', units,
                 enc_names[int(encoding)])

    # The below assertion is violated by Bpod wheel data
    #  assert np.allclose(pos_diff, min_change, rtol=1e-05), 'wheel position skips'

    # Convert the pos threshold defaults from samples to correct unit
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res[encoding])
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)
    kwargs = {
        'pos_thresh': thresholds[0],
        'pos_thresh_onset': thresholds[1],
        'make_plots': display
    }

    # Interpolate and get onsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(t, pos, freq=1000, **kwargs)
    assert on.size == off.size, 'onset/offset number mismatch'
    assert np.all(np.diff(on) > 0) and np.all(
        np.diff(off) > 0), 'onsets/offsets not monotonically increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    # Put into dict
    wheel_moves = {
        'intervals': np.c_[on, off],
        'peakAmplitude': amp,
        'peakVelocity_times': peak_vel
    }
    return wheel_moves
Example #13
    def align_motion(self,
                     period=(-np.inf, np.inf),
                     side='left',
                     sd_thresh=10,
                     display=False):
        # Get data samples within period
        wheel = self.data['wheel']
        self.alignment.label = side
        self.alignment.to_mask = lambda ts: np.logical_and(
            ts >= period[0], ts <= period[1])
        camera_times = self.data['camera_times'][side]
        cam_mask = self.alignment.to_mask(camera_times)
        frame_numbers, = np.where(cam_mask)

        if frame_numbers.size == 0:
            raise ValueError('No frames during given period')

        # Motion Energy
        camera_path = self.video_paths[side]
        roi = (*[slice(*r) for r in self.roi[side]], 0)
        try:
            # TODO Add function arg to make grayscale
            self.alignment.frames = \
                vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi)
            assert self.alignment.frames.size != 0
        except AssertionError:
            self.log.error('Failed to open video')
            return None, None, None
        self.alignment.df, stDev = video.motion_energy(self.alignment.frames,
                                                       2)
        self.alignment.period = period  # For plotting

        # Calculate rotary encoder velocity trace
        x = camera_times[cam_mask]
        Fs = 1000
        pos, t = wh.interpolate_position(wheel.timestamps,
                                         wheel.position,
                                         freq=Fs)
        v, _ = wh.velocity_smoothed(pos, Fs)
        interp_mask = self.alignment.to_mask(t)
        # Convert to normalized speed
        xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x])
        vs = np.abs(v[interp_mask][xs])
        vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs))

        # FIXME This can be used as a goodness of fit measure
        USE_CV2 = False
        if USE_CV2:
            # convert from numpy format to openCV format
            dfCV = np.float32(self.alignment.df.reshape((-1, 1)))
            reCV = np.float32(vs.reshape((-1, 1)))

            # perform cross correlation
            resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED)

            # convert result back to numpy array
            xcorr = np.asarray(resultCv)
        else:
            xcorr = signal.correlate(self.alignment.df, vs)

        # Take the peak of the cross-correlation between motion energy and wheel speed
        CORRECTION = 2
        self.alignment.c = max(xcorr)
        self.alignment.xcorr = np.argmax(xcorr)
        self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION
        self.log.info(
            f'{side} camera, adjusted by {self.alignment.dt_i} frames')

        if display:
            # Plot the motion energy
            fig, ax = plt.subplots(2, 1, sharex='all')
            y = np.pad(self.alignment.df, 1, 'edge')
            ax[0].plot(x, y, '-x', label='wheel motion energy')
            thresh = stDev > sd_thresh
            ax[0].vlines(x[np.array(
                np.pad(thresh, 1, 'constant', constant_values=False))],
                         0,
                         1,
                         linewidth=0.5,
                         linestyle=':',
                         label=f'>{sd_thresh} s.d. diff')
            ax[1].plot(t[interp_mask], np.abs(v[interp_mask]))

            # Plot other stuff
            dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]])
            fps = 1 / np.diff(camera_times).mean()
            ax[0].plot(t[interp_mask][xs] - dt,
                       vs,
                       'r-x',
                       label='velocity (shifted)')
            ax[0].set_title('normalized motion energy, %s camera, %.0f fps' %
                            (side, fps))
            ax[0].set_ylabel('rate of change (a.u.)')
            ax[0].legend()
            ax[1].set_ylabel('wheel speed (rad / s)')
            ax[1].set_xlabel('Time (s)')

            title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s'
            fig.suptitle(title, fontsize=16)
            fig.set_size_inches(19.2, 9.89)

        return self.alignment.dt_i, self.alignment.c, self.alignment.df
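The lag estimate above comes from the argmax of the cross-correlation between the motion energy and the wheel speed trace. A self-contained sketch of the same arithmetic on synthetic signals (the pipeline-specific CORRECTION offset is omitted):

import numpy as np
from scipy import signal

rng = np.random.default_rng(0)
b = rng.standard_normal(500)  # stands in for the wheel speed trace
a = np.roll(b, 12)            # stands in for motion energy, lagging b by 12 samples
xcorr = signal.correlate(a, b)  # 'full' mode by default
lag = np.argmax(xcorr) - (b.size - 1)
print(lag)  # 12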