def get_active_wheel_period(wheel, duration_range=(3., 20.), display=False):
    """
    Find a period in which the wheel accelerates and decelerates, for the wheel motion
    alignment QC.

    Movement detection is run on the smoothed acceleration of the interpolated wheel
    position (a 1000 Hz sampling rate is passed to the smoother). Among the detected
    periods whose duration lies strictly inside ``duration_range``, the middle one is
    returned.

    :param wheel: A Bunch of wheel ``timestamps`` and ``position`` data
    :param duration_range: (min, max) duration in seconds; candidates must lie strictly
        between the two values
    :param display: If true, plot position and acceleration over the selected period
    :return: 2-element array comprising the start and end times of the active period,
        or None when no candidate is found
    """
    position, times = wh.interpolate_position(wheel.timestamps, wheel.position)
    _, acceleration = wh.velocity_smoothed(position, 1000)
    # NB: the detector is fed the acceleration trace, not the position trace
    onsets, offsets, *_ = wh.movements(times, acceleration, pos_thresh=.1, make_plots=False)
    periods = np.c_[onsets, offsets]

    durations = np.diff(periods)
    in_range = (durations > duration_range[0]) & (durations < duration_range[1])
    candidates, _ = np.where(in_range)
    if candidates.size == 0:
        _log.warning('No period of wheel movement found for motion alignment.')
        return None

    # Take the candidate sitting in the middle of the list
    chosen = periods[candidates[candidates.size // 2]]

    if display:
        _, (ax_pos, ax_acc) = plt.subplots(2, 1, sharex='all')
        in_period = (times > chosen[0]) & (times < chosen[1])
        ax_pos.plot(times[in_period], position[in_period])
        ax_acc.plot(times[in_period], acceleration[in_period])

    return chosen
def extract_onsets_for_trial(self):
    """
    Extract the movement onsets and offsets for the current trial.

    The wheel position units (rad or cm) and encoding are chosen by matching the median
    absolute position step against the expected minimum step of each unit/encoding.
    With ``quick_load`` set (or a falsy ``trial_num``), only the wheel trace from the end
    of the previous trial to the start of the next one is searched. For the first trial
    the window starts at the beginning of the recording. For the last trial it ends at
    the end of the current trial.

    :return: tuple of onsets, offsets, interpolated timestamps, interpolated positions,
        and position units ('rad' or 'cm')
    """
    wheel = self._session_data['wheel']
    trials = self._session_data['trials']
    trial_idx = self.trial_num - 1  # Trials num starts at 1
    # Check the values and units of wheel position
    res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
    # min change in rad and cm for each decoding type
    # [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
    min_change = np.concatenate([2 * np.pi / res, wh.WHEEL_DIAMETER * np.pi / res])
    pos_diff = np.median(np.abs(np.ediff1d(wheel['position'])))

    # find the min change closest to the median position step
    idx = np.argmin(np.abs(min_change - pos_diff))
    if idx < len(res):
        # Assume values are in radians
        units = 'rad'
        encoding = idx
    else:
        units = 'cm'
        encoding = idx - len(res)
    # Default thresholds are given in encoder samples; convert to the data's units
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res[encoding])
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)
    kwargs = {'pos_thresh': thresholds[0], 'pos_thresh_onset': thresholds[1]}
    # kwargs = {'make_plots': True, **kwargs}  # Uncomment for plot

    # Interpolate and get onsets
    pos, t = wh.interpolate_position(wheel['timestamps'], wheel['position'], freq=1000)
    # Get the positions and times between the previous trial end and the next trial start
    if self.quick_load or not self.trial_num:
        intervals = trials['intervals']
        # There is no previous trial for the first one; a negative index would silently
        # wrap around to the last trial and produce an empty window.
        t_start = intervals[trial_idx - 1, 1] if trial_idx > 0 else -np.inf
        if trial_idx + 1 < intervals.shape[0]:
            # End of previous trial to beginning of next
            t_end = intervals[trial_idx + 1, 0]
        else:
            # We're on the last trial: end of previous trial to end of current
            t_end = intervals[trial_idx, 1]
        t_mask = np.logical_and(t >= t_start, t <= t_end)
    else:
        t_mask = np.ones_like(t, dtype=bool)
    wheel_ts = t[t_mask]
    wheel_pos = pos[t_mask]
    on, off, *_ = wh.movements(wheel_ts, wheel_pos, freq=1000, **kwargs)
    return on, off, wheel_ts, wheel_pos, units
def test_movements_FPGA(self):
    """
    Compare wheel.movements output against reference values on extracted FPGA wheel data.

    The reference data are the same as those used by the MATLAB code. ``self.test_data[1]``
    holds the positional arguments for ``interpolate_position`` followed by the expected
    onsets, offsets, amplitudes and peak velocities.
    """
    raw_args = self.test_data[1][0]
    expected = self.test_data[1][1]
    pos, t = wheel.interpolate_position(*raw_args, freq=1000)
    thresholds = wheel.samples_to_cm(np.array([8, 1.5]))
    on, off, amp, peak_vel = wheel.movements(
        t, pos, freq=1000, pos_thresh=thresholds[0], pos_thresh_onset=thresholds[1])

    # (actual, absolute tolerance, failure message), compared in order with expected[i]
    checks = (
        (on, 1.e-5, 'Unexpected onsets'),
        (off, 1.e-5, 'Unexpected offsets'),
        (amp, 1.e-5, 'Unexpected move amps'),
        (peak_vel, 1.e-2, 'Unexpected peak velocities'),
    )
    for i, (actual, atol, message) in enumerate(checks):
        self.assertTrue(np.allclose(actual, expected[i], atol=atol), msg=message)
def plot_wheel_position(wheel_position, wheel_time, trials_df):
    """
    Plot wheel position across trials, coloured by which side was chosen.

    The wheel trace is resampled at 1 / T_BIN, cut into one window per trial around
    stimulus onset, and zeroed on the window's first sample. Individual trials are drawn
    faintly and the per-side mean in bold.

    Note: this adds (or overwrites) a 'wheel_position' column on ``trials_df``.

    :param wheel_position: np.array, wheel position
    :param wheel_time: np.array, wheel timestamps
    :param trials_df: pd.DataFrame, with columns 'stimOn_times' (time of stimulus onset
        for each trial) and 'choice' (-1 plotted as 'right turn', 1 as 'left turn')
    :returns: matplotlib.axis
    """
    # Interpolate wheel data
    wheel_position, wheel_time = bbox_wheel.interpolate_position(
        wheel_time, wheel_position, freq=1 / T_BIN)
    # Create a window around the stimulus onset
    start_window, _ = plt_window(trials_df['stimOn_times'])
    # Translating the time window into an index window
    start_idx = insert_idx(wheel_time, start_window)
    end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')
    # Getting the wheel position for each window, normalize to first value of each window
    trials_df['wheel_position'] = [
        wheel_position[start_idx[w]:end_idx[w]] - wheel_position[start_idx[w]]
        for w in range(len(start_idx))
    ]
    # Plotting
    times = np.arange(len(trials_df['wheel_position'].iloc[0])) * T_BIN + WINDOW_LAG
    for side, label, color in zip([-1, 1], ['right', 'left'], ['darkred', '#1f77b4']):
        side_df = trials_df[trials_df['choice'] == side]
        for idx in side_df.index:
            plt.plot(times, side_df.loc[idx, 'wheel_position'],
                     c=color, alpha=0.5, linewidth=0.05)
        if side_df.empty:
            # No trials for this choice: a mean would be a scalar NaN that cannot be
            # plotted against ``times``.
            continue
        mean_trace = np.vstack(side_df['wheel_position'].tolist()).mean(axis=0)
        plt.plot(times, mean_trace, c=color, linewidth=2, label=f'{label} turn')

    plt.axvline(x=0, linestyle='--', c='k', label='stimOn')
    # Label only one of the two threshold lines so the legend lists 'reward' once
    plt.axhline(y=-0.26, linestyle='--', c='g', label='reward')
    plt.axhline(y=0.26, linestyle='--', c='g')
    plt.ylim([-0.27, 0.27])
    plt.xlabel('time [sec]')
    plt.ylabel('wheel position diff to first value [rad]')
    plt.legend(loc='center right')
    plt.title('Wheel position trial avg\n(and individual trials)')
    plt.tight_layout()

    return plt.gca()
def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Detect wheel movements from rotary encoder timestamps and positions.

    Position units and encoding are inferred with ``infer_wheel_units``. The default
    detection thresholds of 8 and 1.5 encoder samples are converted to those units
    before running ``wh.movements`` on the trace interpolated at 1000 Hz.

    :param re_ts: numpy array of rotary encoder timestamps; a 2D array is treated as
        (sample index, time) pairs and linearly interpolated to one time per position
    :param re_pos: numpy array of rotary encoder positions
    :param display: bool: show the wheel position and velocity for full session with
        detected movements highlighted
    :return: wheel_moves dictionary with keys 'intervals' (N x 2 onset/offset times),
        'peakAmplitude' and 'peakVelocity_times'
    :raises AssertionError: on mismatched inputs or inconsistent detected movements
    """
    if re_ts.ndim != 1:
        _logger.debug('2D wheel timestamps')
        if re_pos.ndim > 1:  # Positions must be a flat vector
            re_pos = re_pos.flatten()
        # Linearly interpolate one timestamp per position sample
        sample_idx = np.arange(re_pos.size)
        re_ts = np.interp(sample_idx, re_ts[:, 0], re_ts[:, 1])
    else:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'

    units, resolution, encoding = infer_wheel_units(re_pos)
    _logger.info('Wheel in %s units using %s encoding', units, encoding)

    # Default thresholds are in encoder samples; express them in the data's units
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=resolution)
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)

    # Interpolate and get onsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(
        t, pos, freq=1000,
        pos_thresh=thresholds[0], pos_thresh_onset=thresholds[1], make_plots=display)

    assert on.size == off.size, 'onset/offset number mismatch'
    increasing = np.all(np.diff(on) > 0) and np.all(np.diff(off) > 0)
    assert increasing, 'onsets/offsets not strictly increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    return {
        'intervals': np.c_[on, off],
        'peakAmplitude': amp,
        'peakVelocity_times': peak_vel,
    }
def Viewer(eid, video_type, trial_range, save_video=True, eye_zoom=False):
    '''
    Play (and optionally save) a labelled video with wheel angle, time and DLC points.

    eid: session id, e.g. '3663d82b-f197-4e8b-b299-7b803a155b84'
    video_type: one of 'left', 'right', 'body'
    trial_range: first and last trial number of range to be shown, e.g. [5,7]
    save_video: video is displayed and saved in local folder
    eye_zoom: if True, crop to a 66 x 66 px region around the mean 'pupil_top_r' position

    Example usage to view and save labeled video with wheel angle:
    Viewer('3663d82b-f197-4e8b-b299-7b803a155b84', 'left', [5,7])
    3D example: 'cb2ad999-a6cb-42ff-bf71-1774c57e5308', [5,7]

    Returns an error string if save_vids_here is malformed, None otherwise.
    '''
    save_vids_here = '/home/mic/'
    if save_vids_here[-1] != '/':
        return 'Last character of save_vids_here must be slash'

    one = ONE()
    dataset_types = [
        'camera.times', 'wheel.position', 'wheel.timestamps', 'trials.intervals',
        'camera.dlc'
    ]
    a = one.list(eid, 'dataset-types')
    assert all([i in a for i in dataset_types]), 'For this eid, not all data available'

    D = one.load(eid, dataset_types=dataset_types, dclass_output=True)
    alf_path = Path(D.local_path[0]).parent.parent / 'alf'

    # Download a single video
    video_data = alf_path.parent / 'raw_video_data'
    download_raw_video(eid, cameras=[video_type])
    video_path = list(video_data.rglob('_iblrig_%sCamera.raw.*' % video_type))[0]
    print(video_path)

    # that gives cam time stamps and DLC output (change to alf_path eventually)
    cam = alf.io.load_object(alf_path, '%sCamera' % video_type, namespace='ibl')
    # just to read in times for newer data (which has DLC results in pqt format
    # cam = alf.io.load_object(alf_path, '_ibl_%sCamera' % video_type)

    # set where to read and save video and get video info
    cap = cv2.VideoCapture(video_path.as_uri())
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    size = (int(cap.get(3)), int(cap.get(4)))

    # Equal counts are the ideal case; only more frames than stamps is an error
    assert length <= len(cam['times']), '#frames > #stamps'
    print(eid, ', ', video_type, ', fsp:', fps, ', #frames:', length, ', #stamps:',
          len(cam['times']), ', #frames - #stamps = ', length - len(cam['times']))

    # pick trial range for which to display stuff
    trials = alf.io.load_object(alf_path, 'trials', namespace='ibl')
    num_trials = len(trials['intervals'])
    if trial_range[-1] > num_trials - 1:
        print('There are only %s trials' % num_trials)
        # Continuing would raise IndexError when indexing trials['intervals'] below
        cap.release()
        return

    frame_start = find_nearest(cam['times'], [trials['intervals'][trial_range[0]][0]])
    frame_stop = find_nearest(cam['times'], [trials['intervals'][trial_range[-1]][1]])

    '''
    wheel related stuff
    '''
    wheel = alf.io.load_object(alf_path, 'wheel', namespace='ibl')
    import brainbox.behavior.wheel as wh
    try:
        pos, t = wh.interpolate_position(wheel['timestamps'], wheel['position'], freq=1000)
    except BaseException:
        pos, t = wh.interpolate_position(wheel['times'], wheel['position'], freq=1000)

    w_start = find_nearest(t, trials['intervals'][trial_range[0]][0])
    w_stop = find_nearest(t, trials['intervals'][trial_range[-1]][1])

    # confine to interval
    pos_int = pos[w_start:w_stop]
    t_int = t[w_start:w_stop]

    # alignment of cam stamps and interpolated wheel stamps
    wheel_pos = []
    kk = 0
    for wt in cam['times'][frame_start:frame_stop]:
        wheel_pos.append(pos_int[find_nearest(t_int, wt)])
        kk += 1
        if kk % 3000 == 0:
            print('iteration', kk)

    '''
    DLC related stuff
    '''
    Times = cam['times'][frame_start:frame_stop]
    del cam['times']

    # some exception for inconsistent data formats
    try:
        dlc_name = '_ibl_%sCamera.dlc.pqt' % video_type
        dlc_path = alf_path / dlc_name
        cam = pd.read_parquet(dlc_path, engine="fastparquet")
        print('it is pqt')
    except BaseException:
        raw_vid_path = alf_path.parent / 'raw_video_data'
        cam = alf.io.load_object(raw_vid_path, '%sCamera' % video_type, namespace='ibl')

    # Point names are the column names with their last '_'-suffix removed
    points = np.unique(['_'.join(x.split('_')[:-1]) for x in cam.keys()])
    if len(points) == 1:
        cam = cam['dlc']
        points = np.unique(['_'.join(x.split('_')[:-1]) for x in cam.keys()])

    if video_type != 'body':
        d = list(points)
        d.remove('tube_top')
        d.remove('tube_bottom')
        points = np.array(d)

    # Set values to nan if likelihood is too low # for pqt: .to_numpy()
    XYs = {}
    for point in points:
        x = np.ma.masked_where(cam[point + '_likelihood'] < 0.9, cam[point + '_x'])
        x = x.filled(np.nan)
        y = np.ma.masked_where(cam[point + '_likelihood'] < 0.9, cam[point + '_y'])
        y = y.filled(np.nan)
        XYs[point] = np.array([x[frame_start:frame_stop], y[frame_start:frame_stop]])

    # Just for 3D testing
    # return XYs

    # Zoom at eye
    if eye_zoom:
        pivot = np.nanmean(XYs['pupil_top_r'], axis=1)
        x0 = int(pivot[0]) - 33
        x1 = int(pivot[0]) + 33
        y0 = int(pivot[1]) - 28
        y1 = int(pivot[1]) + 38
        size = (66, 66)
        dot_s = 1  # [px] for painting DLC dots
    else:
        x0 = 0
        x1 = size[0]
        y0 = 0
        y1 = size[1]
        if video_type == 'left':
            dot_s = 10  # [px] for painting DLC dots
        else:
            dot_s = 5

    if save_video:
        loc = save_vids_here + '%s_trials_%s_%s_%s.mp4' % (
            eid, trial_range[0], trial_range[-1], video_type)
        out = cv2.VideoWriter(loc, cv2.VideoWriter_fourcc(*'mp4v'), fps,
                              size)  # put , 0 if grey scale

    # writing stuff on frames
    font = cv2.FONT_HERSHEY_SIMPLEX
    if video_type == 'left':
        bottomLeftCornerOfText = (20, 1000)
        fontScale = 4
    else:
        bottomLeftCornerOfText = (10, 500)
        fontScale = 2
    lineType = 2

    # assign a color to each DLC point (now: all points red)
    cmap = matplotlib.cm.get_cmap('Spectral')
    CR = np.arange(len(points)) / len(points)

    block = np.ones((2 * dot_s, 2 * dot_s, 3))

    # set start frame
    cap.set(1, frame_start)

    k = 0
    while (cap.isOpened()):
        ret, frame = cap.read()
        if not ret:
            # End of stream or read failure: frame is None and cannot be drawn on
            break
        gray = frame

        # print wheel angle
        fontColor = (255, 255, 255)
        Angle = round(wheel_pos[k], 2)
        Time = round(Times[k], 3)
        cv2.putText(gray, 'Wheel angle: ' + str(Angle), bottomLeftCornerOfText, font,
                    fontScale / 2, fontColor, lineType)

        a, b = bottomLeftCornerOfText
        bottomLeftCornerOfText0 = (int(a * 10 + b / 2), b)
        cv2.putText(gray, ' time: ' + str(Time), bottomLeftCornerOfText0, font,
                    fontScale / 2, fontColor, lineType)

        # print DLC dots
        ll = 0
        for point in points:
            # Put point color legend
            fontColor = (np.array([cmap(CR[ll])]) * 255)[0][:3]
            a, b = bottomLeftCornerOfText
            if video_type == 'right':
                bottomLeftCornerOfText2 = (a, a * 2 * (1 + ll))
            else:
                bottomLeftCornerOfText2 = (b, a * 2 * (1 + ll))
            fontScale2 = fontScale / 4
            cv2.putText(gray, point, bottomLeftCornerOfText2, font, fontScale2, fontColor,
                        lineType)

            X0 = XYs[point][0][k]
            Y0 = XYs[point][1][k]
            # transform for opencv? (array rows are image y, columns are image x)
            X = Y0
            Y = X0

            if not np.isnan(X) and not np.isnan(Y):
                col = (np.array([cmap(CR[ll])]) * 255)[0][:3]
                # col = np.array([0, 0, 255]) # all points red
                X = X.astype(int)
                Y = Y.astype(int)
                gray[X - dot_s:X + dot_s, Y - dot_s:Y + dot_s] = block * col
            ll += 1

        gray = gray[y0:y1, x0:x1]
        if save_video:
            out.write(gray)
        cv2.imshow('frame', gray)
        cv2.waitKey(1)
        k += 1
        if k == (frame_stop - frame_start) - 1:
            break

    if save_video:
        out.release()
    cap.release()
    cv2.destroyAllWindows()
n_trials = 3  # Number of trials to plot
# Randomly select distinct trials to plot (sampling without replacement avoids
# drawing the same trial in two panels)
trial_ids = np.random.choice(trial_data['choice'].size, size=n_trials, replace=False)
fig, axs = plt.subplots(1, n_trials, figsize=(8.5, 2.5))
axs = np.atleast_1d(axs)  # a single subplot is returned as a bare Axes, not an array

# Go cue and response times of the selected trials
goCues = trial_data['goCue_times'][trial_ids]
responses = trial_data['response_times'][trial_ids]
# Trial intervals delimiting each trace
starts = trial_data['intervals'][trial_ids, 0]
ends = trial_data['intervals'][trial_ids, 1]
# Cut up the wheel vectors
Fs = 1000
pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs)
vel, acc = wh.velocity_smoothed(pos, Fs)
traces = wh.traces_by_trial(t, pos, start=starts, end=ends)
zipped = zip(traces, axs, goCues, responses, trial_ids)

for (trace, ax, go, resp, n) in zipped:
    ax.plot(trace[0], trace[1], 'k-')
    ax.axvline(x=go, color='g', label='go cue', linestyle=':')
    ax.axvline(x=resp, color='r', label='threshold', linestyle=':')
    ax.set_title('Trial #%s' % n)
    # Turn off tick labels
    ax.set_yticklabels([])
    ax.set_xticklabels([])

# Lay out once everything is drawn so titles are taken into account
plt.tight_layout()
def make(self, key):
    """
    Populate wheel-movement entries for one session.

    Loads the session's wheel object via ONE and validates it. Infers units (rad/cm) and
    encoding from the smallest position step, then detects movements with ``wh.movements``
    on the 1000 Hz interpolated trace. Inserts a summary row into this table and one row
    per movement into ``self.Move``. Totals and amplitudes are converted to radians when
    the data are in cm.

    :param key: primary key of the session to populate
    :raises ValueError, AssertionError: on inconsistent wheel data or failed detection
        (logged, then re-raised)
    """
    # Load the wheel for this session
    move_key = key.copy()
    one = ONE()
    eid, ver = (acquisition.Session & key).fetch1('session_uuid', 'task_protocol')
    logger.info('WheelMoves for session %s, %s', str(eid), ver)
    try:  # Should be able to remove this
        wheel = one.load_object(str(eid), 'wheel')
        all_loaded = \
            all([isinstance(wheel[lab], np.ndarray) for lab in wheel]) and \
            all(k in wheel for k in ('timestamps', 'position'))
        assert all_loaded, 'wheel data missing'
        alf.io.check_dimensions(wheel)
        if len(wheel['timestamps'].shape) == 1:
            assert wheel['timestamps'].size == wheel['position'].size, \
                'wheel data dimension mismatch'
            assert np.all(np.diff(wheel['timestamps']) > 0), \
                'wheel timestamps not monotonically increasing'
        else:
            logger.debug('2D timestamps')
        # Check the values and units of wheel position
        res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
        min_change_rad = 2 * np.pi / res
        min_change_cm = wh.WHEEL_DIAMETER * np.pi / res
        pos_diff = np.abs(np.ediff1d(wheel['position']))
        if pos_diff.min() < min_change_cm.min():
            # Assume values are in radians
            units = 'rad'
            encoding = np.argmin(np.abs(min_change_rad - pos_diff.min()))
            min_change = min_change_rad[encoding]
        else:
            units = 'cm'
            encoding = np.argmin(np.abs(min_change_cm - pos_diff.min()))
            min_change = min_change_cm[encoding]
        enc_names = {0: '4X', 1: '2X', 2: '1X'}
        logger.info('Wheel in %s units using %s encoding', units, enc_names[int(encoding)])
        # task_protocol may be missing (None) for some sessions
        if ver and '_iblrig_tasks_ephys' in ver:
            assert np.allclose(pos_diff, min_change, rtol=1e-05), 'wheel position skips'
    except ValueError:
        logger.exception('Inconsistent wheel data')
        raise
    except AssertionError as ex:
        logger.exception(str(ex))
        raise
    except Exception as ex:
        logger.exception(str(ex))
        raise

    try:
        # Convert the pos threshold defaults from samples to correct unit
        thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res[encoding])
        if units == 'rad':
            thresholds = wh.cm_to_rad(thresholds)
        kwargs = {'pos_thresh': thresholds[0], 'pos_thresh_onset': thresholds[1]}
        # kwargs = {'make_plots': True, **kwargs}
        # Interpolate and get onsets
        pos, t = wh.interpolate_position(wheel['timestamps'], wheel['position'], freq=1000)
        on, off, amp, peak_vel = wh.movements(t, pos, freq=1000, **kwargs)
        assert on.size == off.size, 'onset/offset number mismatch'
        assert np.all(np.diff(on) > 0) and np.all(np.diff(off) > 0), \
            'onsets/offsets not monotonically increasing'
        assert np.all((off - on) > 0), 'not all offsets occur after onset'
    except ValueError:
        logger.exception('Failed to find movements')
        raise
    except AssertionError as ex:
        logger.exception('Wheel integrity check failed: ' + str(ex))
        raise

    key['n_movements'] = on.size  # total number of movements within the session
    # total displacement of the wheel during session
    key['total_displacement'] = float(pos[-1] - pos[0])
    key['total_distance'] = float(np.abs(np.diff(pos)).sum())  # total movement of the wheel
    # String equality, not identity: `is` on str literals is unreliable
    if units == 'cm':  # convert to radians
        key['total_displacement'] = wh.cm_to_rad(key['total_displacement'])
        key['total_distance'] = wh.cm_to_rad(key['total_distance'])
        amp = wh.cm_to_rad(amp)

    self.insert1(key)

    keys = ('move_id', 'movement_onset', 'movement_offset', 'max_velocity',
            'movement_amplitude')
    # max_velocity <- peak velocity value, movement_amplitude <- amplitude
    # (the two were previously swapped)
    moves = [
        dict(zip(keys, (i, on_i, off_i, vel_i, amp_i)), **move_key)
        for i, (on_i, off_i, vel_i, amp_i) in enumerate(zip(on, off, peak_vel, amp))
    ]
    self.Move.insert(moves)
def make(self, key, one=None):
    """
    Populate wheel-movement and direction-change entries for one session.

    Loads the session's wheel object via ONE and detects movements with
    ``extract_wheel_moves``. Finds velocity sign changes inside each movement on the
    1000 Hz interpolated trace. Inserts a summary row into this table plus rows into
    ``self.Move`` and ``self.DirectionChange``. Distances and amplitudes are converted to
    radians when the data are in cm.

    :param key: primary key of the session to populate
    :param one: optional ONE instance; a new one is created when None
    :raises ValueError, AssertionError: on inconsistent wheel data or failed detection
        (logged, then re-raised)
    """
    # Load the wheel for this session
    move_key = key.copy()
    change_key = move_key.copy()
    one = one or ONE()
    eid, ver = (acquisition.Session & key).fetch1('session_uuid', 'task_protocol')
    logger.info('WheelMoves for session %s, %s', str(eid), ver)
    try:  # Should be able to remove this
        wheel = one.load_object(str(eid), 'wheel')
        all_loaded = \
            all([isinstance(wheel[lab], np.ndarray) for lab in wheel]) and \
            all(k in wheel for k in ('timestamps', 'position'))
        assert all_loaded, 'wheel data missing'

        # If times and timestamps present, drop times
        if {'times', 'timestamps'}.issubset(wheel):
            wheel.pop('times')
        wheel_moves = extract_wheel_moves(wheel.timestamps, wheel.position)
    except ValueError:
        logger.exception('Failed to find movements')
        raise
    except AssertionError as ex:
        logger.exception(str(ex))
        raise
    except Exception as ex:
        logger.exception(str(ex))
        raise

    # Bring timestamps and positions to 1D; 2D timestamps are (sample index, time) pairs
    re_ts, re_pos = wheel.timestamps, wheel.position
    if len(re_ts.shape) != 1:
        logger.info('2D wheel timestamps')
        if len(re_pos.shape) > 1:  # Ensure 1D array of positions
            re_pos = re_pos.flatten()
        # Linearly interpolate the times
        x = np.arange(re_pos.size)
        re_ts = np.interp(x, re_ts[:, 0], re_ts[:, 1])

    # Get the units of the position data. Amplitudes must be converted BEFORE the Move
    # rows are built, otherwise the conversion never reaches the table.
    units, *_ = infer_wheel_units(re_pos)
    on_off = wheel_moves['intervals']
    amp = wheel_moves['peakAmplitude']
    vel_t = wheel_moves['peakVelocity_times']
    if units == 'cm':  # convert to radians
        amp = wh.cm_to_rad(amp)

    # Build list of table entries
    keys = ('move_id', 'movement_onset', 'movement_offset', 'max_velocity',
            'movement_amplitude')
    moves = [dict(zip(keys, (i, on, off, vel_t[i], amp[i])), **move_key)
             for i, (on, off) in enumerate(on_off)]

    # Calculate direction changes (note argument order: timestamps first, then positions)
    Fs = 1000
    pos, ts = wh.interpolate_position(re_ts, re_pos, freq=Fs)
    vel, _ = wh.velocity_smoothed(pos, Fs)
    change_mask = np.insert(np.diff(np.sign(vel)) != 0, 0, 0)

    changes = []
    for i, (on, off) in enumerate(on_off.reshape(-1, 2)):
        mask = np.logical_and(ts > on, ts < off)
        ind = np.logical_and(mask, change_mask)
        changes.extend(
            dict(change_key, move_id=i, change_id=j, change_time=t)
            for j, t in enumerate(ts[ind])
        )

    key['n_movements'] = on_off.shape[0]  # total number of movements within the session
    # total displacement of the wheel during session
    key['total_displacement'] = float(re_pos[-1] - re_pos[0])
    key['total_distance'] = float(np.abs(np.diff(re_pos)).sum())  # total movement of the wheel
    key['n_direction_changes'] = int(np.count_nonzero(change_mask))  # total direction changes
    if units == 'cm':  # convert to radians
        key['total_displacement'] = wh.cm_to_rad(key['total_displacement'])
        key['total_distance'] = wh.cm_to_rad(key['total_distance'])

    # Insert the keys in order
    self.insert1(key)
    self.Move.insert(moves)
    self.DirectionChange.insert(changes)
def __init__(self, eid=None, trial=None, camera='left', dlc_features=None, quick_load=True,
             t_win=3, one=None, start=True):
    """
    Plot the wheel trace alongside the video frames.

    Below is list of key bindings:
    :key n: plot movements of next trial
    :key p: plot movements of previous trial
    :key r: plot movements of a random trial
    :key t: prompt for a trial number to plot
    :key l: toggle between legend for wheel and trial events
    :key space: pause/play frames
    :key left: move to previous frame
    :key right: move to next frame

    :param eid: uuid of experiment session to load; a random session is chosen when None
    :param trial: the trial number to plot (starting at 1); random when None
    :param camera: the camera position to load, options: 'left' (default), 'right', 'body'
    :param dlc_features: tuple of dlc features overlay onto frames
    :param quick_load: when true, move onset detection is performed on individual trials
        instead of entire session
    :param t_win: the window in seconds over which to plot the wheel trace
    :param one: optional ONE instance; a new one is created when None
    :param start: if False, the Viewer must be started by calling the `run` method
    :return: Viewer object
    :raises ValueError: on an invalid camera name or malformed eid
    """
    self._logger = logging.getLogger('ibllib')
    self.t_win = t_win  # Time window of wheel plot
    self.one = one or ONE()
    self.quick_load = quick_load

    # Input validation
    if camera not in ['left', 'right', 'body']:
        raise ValueError("camera must be one of 'left', 'right', or 'body'")

    # If None, randomly pick a session to load
    if not eid:
        self._logger.info('Finding random session')
        eids = self.find_sessions(dlc=dlc_features is not None)
        eid = random.choice(eids)
        ref = eid2ref(eid, as_dict=False, one=self.one)
        self._logger.info('using session %s (%s)', eid, ref)
    elif not is_uuid_string(eid):
        raise ValueError(f'"{eid}" is not a valid session uuid')

    # Store complete session data: trials, timestamps, etc.
    ref = eid2ref(eid, one=self.one, parse=False)
    self._session_data = {'eid': eid, 'ref': ref, 'dlc': None}
    self._plot_data = {}  # Holds data specific to current plot, namely data for single trial

    # Download the DLC data if required
    if dlc_features:
        self._session_data['dlc'] = self.get_dlc(dlc_features, camera=camera)

    # These are for the dict returned by ONE
    trial_data = self.get_trial_data('ONE')
    total_trials = trial_data['intervals'].shape[0]
    # Trial numbers start at 1; randint is inclusive at both ends
    trial = random.randint(1, total_trials) if not trial else trial
    self._session_data['total_trials'] = total_trials
    self._session_data['trials'] = trial_data

    # Check for local first movement times
    first_moves = self.one.path_from_eid(eid) / 'alf' / '_ibl_trials.firstMovement_times.npy'
    if first_moves.exists() and 'firstMovement_times' not in trial_data:
        # Load file if exists locally
        self._session_data['trials']['firstMovement_times'] = np.load(first_moves)

    # Download the raw video for the selected camera
    self.video_path, = self.download_raw_video(camera)
    cam_ts = self.one.load(self._session_data['eid'], ['camera.times'], dclass_output=True)
    cam_ts, = [ts for ts, url in zip(cam_ts.data, cam_ts.url) if camera in url]
    Fs = 1 / np.diff(cam_ts).mean()  # Approx. frequency of camera timestamps
    # Verify video frames and timestamps agree
    _, fps, count = get_video_frames_preload(self.video_path, [])

    if count != cam_ts.size:
        assert count <= cam_ts.size, 'fewer camera timestamps than frames'
        msg = 'number of timestamps does not match number video file frames: '
        self._logger.warning(msg + '%i more timestamps than frames', cam_ts.size - count)

    # NOTE(review): one-sided check - a timestamp rate far below fps still passes
    assert Fs - fps < 1, 'camera timestamps do not match reported frame rate'
    self._logger.info("Frame rate = %.0fHz", fps)
    # cam_ts = cam_ts[-count:]  # Remove extraneous timestamps
    self._session_data['camera_ts'] = cam_ts

    # Load wheel data
    self._session_data['wheel'] = self.one.load_object(self._session_data['eid'], 'wheel')

    # Plot the first frame in the upper subplot
    fig, axes = plt.subplots(nrows=2)
    fig.canvas.mpl_disconnect(fig.canvas.manager.key_press_handler_id)  # Disable defaults
    fig.canvas.mpl_connect('key_press_event', self.process_key)  # Connect our own key press fn

    self._plot_data['figure'] = fig
    self._plot_data['axes'] = axes
    self._trial_num = trial

    self.anim = animation.FuncAnimation(fig, self.animate, init_func=self.init_plot,
                                        frames=cycle(range(60)), interval=20, blit=False,
                                        repeat=True, cache_frame_data=False)
    self.anim.running = False
    self.trial_num = trial  # Set trial and prepare plot/frame data
    if start:
        self.run()
def plot_wheel_position(eid):
    '''
    Illustrate wheel position next to distance plot.

    Wheel windows come from ``constant_reaction_time(eid, rt, st)``. Element 0 of each
    entry is used as the window start time and element 1 as its duration. Element 4 is
    the choice (-1 plotted as 'left', 1 as 'right'). Each trace is zeroed on its first
    sample. Individual traces are drawn faintly and the per-choice mean in bold.

    :param eid: session uuid
    '''
    T_BIN = 0.02
    rt = 2
    st = -0.5

    d = constant_reaction_time(eid, rt, st)

    one = ONE()
    wheel = one.load_object(eid, 'wheel')
    pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=1 / T_BIN)

    whe_left = []
    whe_right = []
    for i in d:
        start_idx = find_nearest(t, d[i][0])
        end_idx = start_idx + int(d[i][1] / T_BIN)
        wheel_pos = pos[start_idx:end_idx]
        if len(wheel_pos) == 1:
            print(i, [start_idx, end_idx])
        wheel_pos = wheel_pos - wheel_pos[0]
        if d[i][4] == -1:
            whe_left.append(wheel_pos)
        if d[i][4] == 1:
            whe_right.append(wheel_pos)

    traces = whe_left or whe_right
    if not traces:
        print('No left or right choice trials to plot for %s' % eid)
        return

    # Time axis: sample ``split`` sits at t = 0 (same anchoring as before), but the axis
    # is now strictly linear. Previously t = 0 appeared twice and the axis length could
    # differ from the trace length through int truncation.
    n_samples = len(traces[0])
    split = int(n_samples * abs(st / rt))
    times = (np.arange(n_samples) - split) * T_BIN

    for trace in whe_left:
        plt.plot(times, trace, c='#1f77b4', alpha=0.5, linewidth=0.05)
    for trace in whe_right:
        plt.plot(times, trace, c='darkred', alpha=0.5, linewidth=0.05)

    if whe_left:
        plt.plot(times, np.mean(whe_left, axis=0), c='#1f77b4', linewidth=2, label='left')
    if whe_right:
        plt.plot(times, np.mean(whe_right, axis=0), c='darkred', linewidth=2, label='right')

    plt.axhline(y=0.26, linestyle='--', c='k')
    plt.axhline(y=-0.26, linestyle='--', c='k', label='reward boundary')
    plt.axvline(x=0, linestyle='--', c='g', label='stimOn')
    axes = plt.gca()
    axes.set_ylim([-0.27, 0.27])
    plt.xlabel('time [sec]')
    plt.ylabel('wheel position [rad]')
    plt.legend(loc='lower right')
    plt.title('wheel positions colored by choice')
    plt.tight_layout()
def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Detect wheel movements from rotary encoder timestamps and positions.

    The position units (rad or cm) and encoding (X4/X2/X1) are inferred by matching the
    smallest absolute position step against the expected minimum step of each option.
    The default detection thresholds of 8 and 1.5 encoder samples are converted to those
    units before running ``wh.movements`` on the trace interpolated at 1000 Hz.

    :param re_ts: numpy array of rotary encoder timestamps; a 2D array is treated as
        (sample index, time) pairs and linearly interpolated to one time per position
    :param re_pos: numpy array of rotary encoder positions
    :param display: bool: show the wheel position and velocity for full session with
        detected movements highlighted
    :return: wheel_moves dictionary with keys 'intervals' (N x 2 onset/offset times),
        'peakAmplitude' and 'peakVelocity_times'
    :raises AssertionError: on mismatched inputs or inconsistent detected movements
    """
    if len(re_ts.shape) == 1:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'
        assert np.all(np.diff(re_ts) > 0), 'wheel timestamps not monotonically increasing'
    else:
        _logger.debug('2D wheel timestamps')
        # 2D timestamps must be reduced to one time per position sample before
        # interpolation; otherwise they would be passed on unchanged.
        if len(re_pos.shape) > 1:  # Ensure 1D array of positions
            re_pos = re_pos.flatten()
        # Linearly interpolate the times
        x = np.arange(re_pos.size)
        re_ts = np.interp(x, re_ts[:, 0], re_ts[:, 1])

    # Check the values and units of wheel position
    res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
    # min change in rad and cm for each decoding type
    # [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
    min_change = np.concatenate([2 * np.pi / res, wh.WHEEL_DIAMETER * np.pi / res])
    pos_diff = np.abs(np.ediff1d(re_pos)).min()

    # find min change closest to min pos_diff
    idx = np.argmin(np.abs(min_change - pos_diff))
    if idx < len(res):
        # Assume values are in radians
        units = 'rad'
        encoding = idx
    else:
        units = 'cm'
        encoding = idx - len(res)
    enc_names = {0: 'X4', 1: 'X2', 2: 'X1'}
    _logger.info('Wheel in %s units using %s encoding', units, enc_names[int(encoding)])

    # The below assertion is violated by Bpod wheel data
    # assert np.allclose(pos_diff, min_change, rtol=1e-05), 'wheel position skips'

    # Convert the pos threshold defaults from samples to correct unit
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res[encoding])
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)
    kwargs = {
        'pos_thresh': thresholds[0],
        'pos_thresh_onset': thresholds[1],
        'make_plots': display
    }

    # Interpolate and get onsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(t, pos, freq=1000, **kwargs)
    assert on.size == off.size, 'onset/offset number mismatch'
    assert np.all(np.diff(on) > 0) and np.all(np.diff(off) > 0), \
        'onsets/offsets not monotonically increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    # Put into dict
    wheel_moves = {
        'intervals': np.c_[on, off],
        'peakAmplitude': amp,
        'peakVelocity_times': peak_vel
    }
    return wheel_moves
def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, display=False):
    """
    Estimate the frame offset between a camera and the wheel.

    The motion energy of the camera ROI is cross-correlated with the normalized wheel
    speed, both sampled at the camera timestamps within ``period``.

    :param period: (start, end) time window, inclusive, used for the alignment
    :param side: camera label used to index camera times, video paths and ROIs
    :param sd_thresh: threshold on the motion-energy s.d. marked in the display plot
    :param display: if True, plot motion energy and shifted wheel speed
    :return: (dt_i, c, df): the offset in frames, the maximum of the cross-correlation
        and the motion energy trace; (None, None, None) if no video frames could be read
    :raises ValueError: if no camera frames fall within ``period``
    """
    # Get data samples within period
    wheel = self.data['wheel']
    self.alignment.label = side
    self.alignment.to_mask = lambda ts: np.logical_and(ts >= period[0], ts <= period[1])
    camera_times = self.data['camera_times'][side]
    cam_mask = self.alignment.to_mask(camera_times)
    frame_numbers, = np.where(cam_mask)

    if frame_numbers.size == 0:
        raise ValueError('No frames during given period')

    # Motion Energy
    camera_path = self.video_paths[side]
    roi = (*[slice(*r) for r in self.roi[side]], 0)
    try:
        # TODO Add function arg to make grayscale
        frames = vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi)
    except AssertionError:
        # NOTE(review): kept from the original in case the loader asserts on failure
        frames = None
    else:
        self.alignment.frames = frames
    # Explicit check rather than `assert`, which is stripped when running with -O
    if frames is None or frames.size == 0:
        self.log.error('Failed to open video')
        return None, None, None
    self.alignment.df, stDev = video.motion_energy(self.alignment.frames, 2)
    self.alignment.period = period  # For plotting

    # Calculate rotary encoder velocity trace
    x = camera_times[cam_mask]
    Fs = 1000
    pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs)
    v, _ = wh.velocity_smoothed(pos, Fs)
    interp_mask = self.alignment.to_mask(t)
    # Convert to normalized speed
    xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x])
    vs = np.abs(v[interp_mask][xs])
    speed_range = np.max(vs) - np.min(vs)
    if speed_range > 0:
        vs = (vs - np.min(vs)) / speed_range
    else:
        # Constant speed: min-max normalization would divide by zero and yield NaNs
        self.log.warning('Wheel speed is constant over the period; normalized speed set to 0')
        vs = np.zeros_like(vs)

    # FIXME This can be used as a goodness of fit measure
    USE_CV2 = False
    if USE_CV2:
        # convert from numpy format to openCV format
        dfCV = np.float32(self.alignment.df.reshape((-1, 1)))
        reCV = np.float32(vs.reshape((-1, 1)))

        # perform cross correlation
        resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED)

        # convert result back to numpy array
        xcorr = np.asarray(resultCv)
    else:
        xcorr = signal.correlate(self.alignment.df, vs)

    # Cross correlate wheel speed trace with the motion energy
    CORRECTION = 2
    self.alignment.c = max(xcorr)
    self.alignment.xcorr = np.argmax(xcorr)
    self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION
    self.log.info(f'{side} camera, adjusted by {self.alignment.dt_i} frames')

    if display:
        # Plot the motion energy
        fig, ax = plt.subplots(2, 1, sharex='all')
        y = np.pad(self.alignment.df, 1, 'edge')
        ax[0].plot(x, y, '-x', label='wheel motion energy')
        thresh = stDev > sd_thresh
        ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1,
                     linewidth=0.5, linestyle=':', label=f'>{sd_thresh} s.d. diff')
        ax[1].plot(t[interp_mask], np.abs(v[interp_mask]))

        # Plot other stuff
        dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]])
        fps = 1 / np.diff(camera_times).mean()
        ax[0].plot(t[interp_mask][xs] - dt, vs, 'r-x', label='velocity (shifted)')
        ax[0].set_title('normalized motion energy, %s camera, %.0f fps' % (side, fps))
        ax[0].set_ylabel('rate of change (a.u.)')
        ax[0].legend()
        ax[1].set_ylabel('wheel speed (rad / s)')
        ax[1].set_xlabel('Time (s)')

        title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s'
        fig.suptitle(title, fontsize=16)
        fig.set_size_inches(19.2, 9.89)

    return self.alignment.dt_i, self.alignment.c, self.alignment.df