def _wheel_move_during_closed_loop(re_ts, re_pos, data, wheel_gain=None, tol=1, **_):
    """
    Check that the wheel moves by approximately 35 degrees during the closed-loop period
    on trials where a feedback (error sound or valve) is delivered.

    Metric: M = abs(w_resp - w_t0) - threshold_displacement, where w_resp = position at response
        time, w_t0 = position at go cue time, threshold_displacement = displacement required to
        move 35 visual degrees
    Criterion: displacement < tol visual degree
    Units: radians of wheel rotation (the criterion tol is expressed in visual degrees)

    :param re_ts: extracted wheel timestamps in seconds
    :param re_pos: extracted wheel positions in radians
    :param data: a dict with the keys (goCueTrigger_times, response_times, feedback_times,
        position, choice, intervals)
    :param wheel_gain: the 'STIM_GAIN' task setting
    :param tol: the criterion in visual degrees
    :return: tuple (metric, passed) of per-trial arrays; both are NaN on no-go trials
        (choice == 0). Returns (None, None) when wheel_gain is None.
    """
    if wheel_gain is None:
        _log.warning("No wheel_gain input in function call, returning None")
        return None, None

    # Get tuple of wheel times and positions over each trial's closed-loop period
    traces = traces_by_trial(re_ts, re_pos,
                             start=data["goCueTrigger_times"],
                             end=data["response_times"])

    # Trials with no wheel samples in the window keep a displacement of 0
    metric = np.zeros_like(data["feedback_times"])
    # For each trial find the absolute displacement
    for i, (t, pos) in enumerate(traces):
        if pos.size != 0:
            # Find the position of the preceding sample and subtract it. Clamp at 0: if the
            # trace starts on the very first recorded sample there is no preceding one, and
            # index -1 would wrap round to the last sample of the session.
            idx = max(np.abs(re_ts - t[0]).argmin() - 1, 0)
            origin = re_pos[idx]
            metric[i] = np.abs(pos - origin).max()

    # Load wheel_gain and thresholds for each trial
    wheel_gain = np.array([wheel_gain] * len(data["position"]))
    thresh = data["position"]
    # abs displacement, s, in mm required to move 35 visual degrees
    s_mm = np.abs(thresh / wheel_gain)  # don't care about direction
    criterion = cm_to_rad(s_mm * 1e-1)  # convert abs displacement to radians (wheel pos is in rad)
    metric = metric - criterion  # difference should be close to 0
    rad_per_deg = cm_to_rad(1 / wheel_gain * 1e-1)
    passed = (np.abs(metric) < rad_per_deg * tol).astype(float)  # less than tol visual degrees off
    no_go = data["choice"] == 0
    metric[no_go] = passed[no_go] = np.nan  # except no-go trials
    assert data["intervals"].shape[0] == len(metric) == len(passed)
    return metric, passed
def load_wheel_move_during_closed_loop(trial_data, wheel_data, wheel_gain):
    """
    Wheel should move a sufficient amount during the closed-loop period

    Variable name: wheel_move_during_closed_loop
    Metric: abs(w_resp - w_t0) - threshold_displacement, where w_resp = position at response
        time, w_t0 = position at go cue time, threshold_displacement = displacement required to
        move 35 visual degrees
    Criterion: displacement < 1 visual degree for 99% of non-NoGo trials

    :param trial_data: dict with the keys goCueTrigger_times, response_times, feedback_times,
        position, choice and intervals_0
    :param wheel_data: dict with the wheel timestamps ('re_ts') and positions ('re_pos')
    :param wheel_gain: the 'STIM_GAIN' task setting
    :return: tuple (metric, passed) of per-trial arrays, NaN on no-go trials (choice == 0);
        metric is also NaN on trials with no wheel samples. A single None is returned when
        wheel_gain is None.
    """
    if wheel_gain is None:
        log.warning("No wheel_gain input in function call, returning None")
        # NOTE(review): _wheel_move_during_closed_loop returns (None, None) in this case; the
        # single None is kept for backward compatibility but cannot be unpacked by callers
        return None

    re_ts = wheel_data["re_ts"]
    re_pos = wheel_data["re_pos"]
    # Get tuple of wheel times and positions over each trial's closed-loop period
    traces = traces_by_trial(
        re_ts,
        re_pos,
        start=trial_data["goCueTrigger_times"],
        end=trial_data["response_times"],
    )

    metric = np.zeros_like(trial_data["feedback_times"])
    # For each trial find the absolute displacement
    for i, (t, pos) in enumerate(traces):
        if pos.size == 0:
            metric[i] = np.nan
            continue
        # Find the position of the sample strictly preceding the trace and subtract it
        # (`<=` would select the trace's own first sample). Fall back to the first recorded
        # sample when nothing precedes the trace.
        preceding = re_pos[re_ts < t[0]]
        origin = preceding[-1] if preceding.size else re_pos[0]
        metric[i] = np.abs(pos - origin).max()

    # Load wheel_gain and thresholds for each trial
    wheel_gain = np.array([wheel_gain] * len(trial_data["position"]))
    thresh = trial_data["position"]
    # abs displacement, s, in mm required to move 35 visual degrees
    s_mm = np.abs(thresh / wheel_gain)  # don't care about direction
    criterion = cm_to_rad(s_mm * 1e-1)  # convert abs displacement to radians (wheel pos is in rad)
    metric = metric - criterion  # difference should be close to 0
    rad_per_deg = cm_to_rad(1 / wheel_gain * 1e-1)
    # builtin float: the np.float alias was removed in NumPy 1.24
    passed = (np.abs(metric) < rad_per_deg).astype(float)  # less than 1 visual degree off
    no_go = trial_data["choice"] == 0
    metric[no_go] = np.nan  # except no-go trials
    passed[no_go] = np.nan  # except no-go trials
    assert len(trial_data["intervals_0"]) == len(metric) == len(passed)
    return metric, passed
def extract_onsets_for_trial(self):
    """
    Extracts the movement onsets and offsets for the current trial

    Position units (rad or cm) and encoding are inferred from the median absolute position
    step. The movement thresholds (8 and 1.5 encoder samples) are converted to those units.
    The wheel is interpolated at 1 kHz and restricted to a window around the current trial
    when `self.quick_load` is set (or trial_num is falsy).

    :return: tuple of onsets, offsets on, interpolated timestamps, interpolated positions,
        and position units
    """
    wheel = self._session_data['wheel']
    trials = self._session_data['trials']
    trial_idx = self.trial_num - 1  # Trials num starts at 1

    # Check the values and units of wheel position
    res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
    # min change in rad and cm for each decoding type
    # [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
    min_change = np.concatenate([2 * np.pi / res, wh.WHEEL_DIAMETER * np.pi / res])
    pos_diff = np.median(np.abs(np.ediff1d(wheel['position'])))

    # find the min change closest to the median pos_diff
    idx = np.argmin(np.abs(min_change - pos_diff))
    if idx < len(res):
        # Assume values are in radians
        units = 'rad'
        encoding = idx
    else:
        units = 'cm'
        encoding = idx - len(res)
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res[encoding])
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)
    kwargs = {'pos_thresh': thresholds[0], 'pos_thresh_onset': thresholds[1]}
    # kwargs = {'make_plots': True, **kwargs}  # Uncomment for plot

    # Interpolate and get onsets
    pos, t = wh.interpolate_position(wheel['timestamps'], wheel['position'], freq=1000)
    # Get the positions and times between our trial start and the next trial start
    if self.quick_load or not self.trial_num:
        intervals = trials['intervals']
        # Window starts at the end of the previous trial. The first trial has no predecessor:
        # intervals[-1] would wrap round to the last trial of the session and give an empty mask.
        window_start = -np.inf if trial_idx == 0 else intervals[trial_idx - 1, 1]
        try:
            # ... and ends at the beginning of the next trial
            window_end = intervals[trial_idx + 1, 0]
        except IndexError:  # We're on the last trial: end at the end of the current one
            window_end = intervals[trial_idx, 1]
        t_mask = np.logical_and(t >= window_start, t <= window_end)
    else:
        t_mask = np.ones_like(t, dtype=bool)
    wheel_ts = t[t_mask]
    wheel_pos = pos[t_mask]
    on, off, *_ = wh.movements(wheel_ts, wheel_pos, freq=1000, **kwargs)
    return on, off, wheel_ts, wheel_pos, units
def test_extract_wheel_moves(self):
    """Check movement extraction on X2-encoded wheel data, first in cm then in rad."""
    test_data = self.test_data[1]
    # Wrangle data into expected form
    re_ts = test_data[0][0]
    re_pos = test_data[0][1]

    logger = logging.getLogger('ibllib')
    with self.assertLogs(logger, level='INFO') as cm:
        wheel_moves = extract_wheel_moves(re_ts, re_pos)
        self.assertEqual(['INFO:ibllib:Wheel in cm units using X2 encoding'], cm.output)

    n = 56  # expected number of movements
    self.assertTupleEqual(wheel_moves['intervals'].shape, (n, 2),
                          'failed to return the correct number of intervals')
    self.assertEqual(wheel_moves['peakAmplitude'].size, n)
    self.assertEqual(wheel_moves['peakVelocity_times'].size, n)

    # Check the first 3 intervals.
    # assert_allclose raises on mismatch, so the failure message must be passed as err_msg;
    # a message given to a wrapping assertIsNone would never be displayed.
    ints = np.array(
        [[24.78462599, 25.22562599],
         [29.58762599, 31.15062599],
         [31.64262599, 31.81662599]])
    np.testing.assert_allclose(wheel_moves['intervals'][:3], ints,
                               err_msg='unexpected intervals')

    # Check amplitudes
    expected = [0.50255486, -1.70103154, 1.00740789]
    np.testing.assert_allclose(wheel_moves['peakAmplitude'][-3:], expected,
                               err_msg='unexpected amplitudes')

    # Check peak velocities
    expected = [175.13662599, 176.65762599, 178.57262599]
    np.testing.assert_allclose(wheel_moves['peakVelocity_times'][-3:], expected,
                               err_msg='peak times')

    # Test extraction in rad
    re_pos = wh.cm_to_rad(re_pos)
    with self.assertLogs(logger, level='INFO') as cm:
        wheel_moves = ephys_fpga.extract_wheel_moves(re_ts, re_pos)
        self.assertEqual(['INFO:ibllib:Wheel in rad units using X2 encoding'], cm.output)

    # Check the first 3 intervals. As position thresholds are adjusted by units and
    # encoding, we should expect the intervals to be identical to above
    np.testing.assert_allclose(wheel_moves['intervals'][:3], ints,
                               err_msg='unexpected intervals')
def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Detect wheel movements from rotary encoder timestamps and positions.

    Position units and encoding are inferred with infer_wheel_units. The default movement
    thresholds (8 and 1.5 encoder samples) are converted to those units. The trace is
    resampled at 1 kHz and passed to wh.movements.

    :param re_ts: numpy array of rotary encoder timestamps; either 1D (one per position) or
        2D, in which case column 0 is used as sample index and column 1 as time for linear
        interpolation of one timestamp per position sample
    :param re_pos: numpy array of rotary encoder positions (flattened when re_ts is 2D)
    :param display: bool: show the wheel position and velocity for full session with detected
        movements highlighted
    :return: dict with 'intervals' (N x 2 array of onset/offset times), 'peakAmplitude' and
        'peakVelocity_times'
    :raises AssertionError: on mismatched input sizes or inconsistent detected movements
    """
    if len(re_ts.shape) != 1:
        _logger.debug('2D wheel timestamps')
        if len(re_pos.shape) > 1:
            re_pos = re_pos.flatten()  # positions must be 1D to interpolate against
        sample_index = np.arange(re_pos.size)
        re_ts = np.interp(sample_index, re_ts[:, 0], re_ts[:, 1])
    else:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'

    units, res, enc = infer_wheel_units(re_pos)
    _logger.info('Wheel in %s units using %s encoding', units, enc)

    # NB: no strict "no position skips" check here; Bpod wheel data would violate it

    # Default thresholds are expressed in encoder samples: convert them to the data's units
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res)
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)

    # Resample the trace at 1 kHz, then detect movements on it
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    onsets, offsets, amplitudes, peak_vel_times = wh.movements(
        t, pos, freq=1000,
        pos_thresh=thresholds[0],
        pos_thresh_onset=thresholds[1],
        make_plots=display)

    assert onsets.size == offsets.size, 'onset/offset number mismatch'
    assert np.all(np.diff(onsets) > 0) and np.all(np.diff(offsets) > 0), \
        'onsets/offsets not strictly increasing'
    assert np.all((offsets - onsets) > 0), 'not all offsets occur after onset'

    return {
        'intervals': np.c_[onsets, offsets],
        'peakAmplitude': amplitudes,
        'peakVelocity_times': peak_vel_times,
    }
def load_fake_wheel_data(trial_data, wheel_gain=4):
    """
    Build a synthetic wheel trace for a set of trials.

    For every trial, random wheel samples are generated from trial start to stim-on trigger.
    On no-go trials (choice == 0), random movement then fills go cue to trial end. Otherwise,
    a recorded wheel fragment is placed so that its first sample beyond the threshold
    displacement lands on the response time, signed by the choice. It is followed by random
    movement until trial end. Fragments are stitched so that consecutive positions differ by
    `resolution` (the mean step of the recorded fragment).

    :param trial_data: dict of per-trial arrays with the keys choice, stimOnTrigger_times,
        intervals, goCue_times and response_times
    :param wheel_gain: the 'STIM_GAIN' task setting used to compute the threshold displacement
    :return: dict with wheel_timestamps, wheel_position and firstMovement_times arrays
    """
    # Load a wheel fragment: a numpy array of the form [timestamps, positions], for a wheel
    # movement during one trial. Wheel is X1 bpod RE in radians.
    wh_path = Path(__file__).parent.joinpath('..', 'fixtures', 'qc').resolve()
    wheel_frag = np.load(wh_path.joinpath('wheel.npy'))
    resolution = np.mean(np.abs(np.diff(wheel_frag[:, 1])))  # pos diff between samples
    # abs displacement, s, in mm required to move 35 visual degrees
    POS_THRESH = 35
    s_mm = np.abs(POS_THRESH / wheel_gain)  # don't care about direction
    # convert abs displacement to radians (wheel pos is in rad)
    pos_thresh = cm_to_rad(s_mm * 1e-1)
    # index of threshold cross
    pos_thresh_idx = np.argmax(np.abs(wheel_frag[:, 1]) > pos_thresh)

    def qt_wheel_fill(start, end, t_step=0.001, p_step=None):
        """Random walk of -1/0/+1 steps of size p_step, sampled every t_step in [start, end).

        Time points whose random step is 0 are dropped.
        """
        if p_step is None:
            p_step = 2 * np.pi / 1024
        t = np.arange(start, end, t_step)
        p = np.random.randint(-1, 2, len(t))
        t = t[p != 0]
        p = p[p != 0].cumsum() * p_step
        return t, p

    wheel_data = []  # List generated of wheel data fragments
    movement_times = []  # List of generated first movement times
    # (time, position) of the last stitched sample. It is kept in the closure rather than as a
    # function attribute so that it exists even before the first fragment has been added.
    last_samp = (0, 0)

    def add_frag(t, p):
        """Add wheel data fragments to list, adjusting positions to be within one sample of
        one another. Empty fragments are ignored; p is modified in place."""
        nonlocal last_samp
        if p.size == 0:  # e.g. a random fill in which every step was 0
            return
        p += last_samp[1]
        if np.abs(p[0] - last_samp[1]) == 0:
            p += resolution
        wheel_data.append((t, p))
        last_samp = (t[-1], p[-1])

    for i in range(len(trial_data['choice'])):
        # Iterate over trials generating wheel samples for the necessary periods
        # trial start to stim on; should be below quiescence threshold
        stimOn_trig = trial_data['stimOnTrigger_times'][i]
        trial_start = trial_data['intervals'][i, 0]
        t, p = qt_wheel_fill(trial_start, stimOn_trig, .5, resolution)
        add_frag(t, p)  # may be empty: no movement during quiescence is possible

        # stim on to trial end
        trial_end = trial_data['intervals'][i, 1]
        if trial_data['choice'][i] == 0:
            # Add random wheel movements for duration of trial
            goCue = trial_data['goCue_times'][i]
            t, p = qt_wheel_fill(goCue, trial_end, .1, resolution)
            add_frag(t, p)
            movement_times.append(t[0])
        else:
            # Align wheel fragment with response time
            response_time = trial_data['response_times'][i]
            t = wheel_frag[:, 0] + response_time - wheel_frag[pos_thresh_idx, 0]
            p = np.abs(wheel_frag[:, 1]) * trial_data['choice'][i]
            assert t[0] > last_samp[0]
            movement_times.append(t[1])
            add_frag(t, p)
            # Fill in random movements between end of response and trial end
            t, p = qt_wheel_fill(t[-1] + 0.01, trial_end, p_step=resolution)
            add_frag(t, p)

    # Stitch wheel fragments and assert no skips
    wheel_data = np.concatenate(list(map(np.column_stack, wheel_data)))
    assert np.all(np.diff(wheel_data[:, 0]) > 0), "timestamps don't strictly increase"
    np.testing.assert_allclose(np.abs(np.diff(wheel_data[:, 1])), resolution)
    assert len(movement_times) == trial_data['intervals'].shape[0]

    return {
        'wheel_timestamps': wheel_data[:, 0],
        'wheel_position': wheel_data[:, 1],
        'firstMovement_times': np.array(movement_times)
    }
def make(self, key):
    """
    Detect wheel movements for one session and insert them.

    Loads the session's wheel object through ONE and checks its integrity. Infers position
    units (rad or cm) and encoding from the smallest position step, then detects movements on
    a 1 kHz interpolated trace. Inserts the session summary into this table and each movement
    into ``self.Move``. Total displacement, total distance and amplitudes are converted to
    radians when the wheel is in cm.

    :param key: primary key of the session to populate
    :raises: any exception from loading, checking or movement detection is logged and
        re-raised
    """
    # Load the wheel for this session
    move_key = key.copy()
    one = ONE()
    eid, ver = (acquisition.Session & key).fetch1('session_uuid', 'task_protocol')
    logger.info('WheelMoves for session %s, %s', str(eid), ver)

    try:  # Should be able to remove this
        wheel = one.load_object(str(eid), 'wheel')
        all_loaded = \
            all([isinstance(wheel[lab], np.ndarray) for lab in wheel]) and \
            all(k in wheel for k in ('timestamps', 'position'))
        assert all_loaded, 'wheel data missing'

        alf.io.check_dimensions(wheel)
        if len(wheel['timestamps'].shape) == 1:
            assert wheel['timestamps'].size == wheel['position'].size, \
                'wheel data dimension mismatch'
            assert np.all(np.diff(wheel['timestamps']) > 0), \
                'wheel timestamps not monotonically increasing'
        else:
            logger.debug('2D timestamps')

        # Check the values and units of wheel position
        res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
        min_change_rad = 2 * np.pi / res
        min_change_cm = wh.WHEEL_DIAMETER * np.pi / res
        pos_diff = np.abs(np.ediff1d(wheel['position']))
        if pos_diff.min() < min_change_cm.min():
            # Assume values are in radians
            units = 'rad'
            encoding = np.argmin(np.abs(min_change_rad - pos_diff.min()))
            min_change = min_change_rad[encoding]
        else:
            units = 'cm'
            encoding = np.argmin(np.abs(min_change_cm - pos_diff.min()))
            min_change = min_change_cm[encoding]
        enc_names = {0: '4X', 1: '2X', 2: '1X'}
        logger.info('Wheel in %s units using %s encoding', units, enc_names[int(encoding)])
        if '_iblrig_tasks_ephys' in ver:
            assert np.allclose(pos_diff, min_change, rtol=1e-05), 'wheel position skips'
    except ValueError:
        logger.exception('Inconsistent wheel data')
        raise
    except Exception as ex:  # includes AssertionError; log the message and propagate
        logger.exception(str(ex))
        raise

    try:
        # Convert the pos threshold defaults from samples to correct unit
        thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res[encoding])
        if units == 'rad':
            thresholds = wh.cm_to_rad(thresholds)
        kwargs = {'pos_thresh': thresholds[0], 'pos_thresh_onset': thresholds[1]}
        # kwargs = {'make_plots': True, **kwargs}
        # Interpolate and get onsets
        pos, t = wh.interpolate_position(wheel['timestamps'], wheel['position'], freq=1000)
        on, off, amp, peak_vel = wh.movements(t, pos, freq=1000, **kwargs)
        assert on.size == off.size, 'onset/offset number mismatch'
        assert np.all(np.diff(on) > 0) and np.all(np.diff(off) > 0), \
            'onsets/offsets not monotonically increasing'
        assert np.all((off - on) > 0), 'not all offsets occur after onset'
    except ValueError:
        logger.exception('Failed to find movements')
        raise
    except AssertionError as ex:
        logger.exception('Wheel integrity check failed: ' + str(ex))
        raise

    key['n_movements'] = on.size  # total number of movements within the session
    key['total_displacement'] = float(pos[-1] - pos[0])  # total displacement of the wheel during session
    key['total_distance'] = float(np.abs(np.diff(pos)).sum())  # total movement of the wheel
    if units == 'cm':  # convert to radians (`==`, not identity comparison with a literal)
        key['total_displacement'] = wh.cm_to_rad(key['total_displacement'])
        key['total_distance'] = wh.cm_to_rad(key['total_distance'])
        amp = wh.cm_to_rad(amp)

    self.insert1(key)

    keys = ('move_id', 'movement_onset', 'movement_offset', 'max_velocity', 'movement_amplitude')
    # Values are given in the same order as `keys`: peak velocity (time) then amplitude.
    # NOTE(review): max_velocity receives the peak-velocity *time*, as in the newer make()
    moves = [dict(zip(keys, (i, on[i], off[i], peak_vel[i], amp[i])), **move_key)
             for i in range(on.size)]
    self.Move.insert(moves)
def make(self, key, one=None):
    """
    Detect wheel movements and direction changes for one session and insert them.

    Loads the session's wheel object through ONE and extracts movements with
    extract_wheel_moves. Direction changes are sign changes of the smoothed velocity on a
    1 kHz resampled trace that fall strictly within a movement. Inserts the session summary
    into this table, movements into ``self.Move`` and direction changes into
    ``self.DirectionChange``. Total displacement, total distance and amplitudes are converted
    to radians when the wheel positions are in cm.

    :param key: primary key of the session to populate
    :param one: optional ONE instance; a new one is created when None
    :raises: any exception while loading the wheel or extracting movements is logged and
        re-raised
    """
    # Load the wheel for this session
    move_key = key.copy()
    change_key = move_key.copy()
    one = one or ONE()
    eid, ver = (acquisition.Session & key).fetch1('session_uuid', 'task_protocol')
    logger.info('WheelMoves for session %s, %s', str(eid), ver)

    try:  # Should be able to remove this
        wheel = one.load_object(str(eid), 'wheel')
        all_loaded = \
            all([isinstance(wheel[lab], np.ndarray) for lab in wheel]) and \
            all(k in wheel for k in ('timestamps', 'position'))
        assert all_loaded, 'wheel data missing'

        # If times and timestamps present, drop times
        if {'times', 'timestamps'}.issubset(wheel):
            wheel.pop('times')
        wheel_moves = extract_wheel_moves(wheel.timestamps, wheel.position)
    except ValueError:
        logger.exception('Failed to find movements')
        raise
    except Exception as ex:  # includes AssertionError; log the message and propagate
        logger.exception(str(ex))
        raise

    # Bring timestamps/positions to 1D arrays, as extract_wheel_moves does internally
    fs = 1000
    re_ts, re_pos = wheel.timestamps, wheel.position
    if len(re_ts.shape) != 1:
        logger.info('2D wheel timestamps')
        if len(re_pos.shape) > 1:  # Ensure 1D array of positions
            re_pos = re_pos.flatten()
        # Linearly interpolate the times
        x = np.arange(re_pos.size)
        re_ts = np.interp(x, re_ts[:, 0], re_ts[:, 1])

    # Get the units of the position data. Amplitudes are converted *before* the Move entries
    # are built so that cm amplitudes are not inserted unconverted.
    units, *_ = infer_wheel_units(re_pos)
    on_off = wheel_moves['intervals']
    amp = wheel_moves['peakAmplitude']
    vel_t = wheel_moves['peakVelocity_times']
    if units == 'cm':
        amp = wh.cm_to_rad(amp)

    # Build list of table entries
    keys = ('move_id', 'movement_onset', 'movement_offset', 'max_velocity', 'movement_amplitude')
    moves = [dict(zip(keys, (i, on, off, vel_t[i], amp[i])), **move_key)
             for i, (on, off) in enumerate(on_off)]

    # Calculate direction changes: sign changes of the smoothed velocity.
    # interpolate_position takes (timestamps, positions) and returns (positions, times).
    pos, ts = wh.interpolate_position(re_ts, re_pos, freq=fs)
    vel, _ = wh.velocity_smoothed(pos, fs)
    change_mask = np.insert(np.diff(np.sign(vel)) != 0, 0, 0)
    change_times = ts[change_mask]
    changes = []
    for i, (on, off) in enumerate(on_off.reshape(-1, 2)):
        in_move = change_times[(change_times > on) & (change_times < off)]
        changes.extend(
            dict(change_key, move_id=i, change_id=j, change_time=t)
            for j, t in enumerate(in_move)
        )

    key['n_movements'] = on_off.shape[0]  # total number of movements within the session
    key['total_displacement'] = float(re_pos[-1] - re_pos[0])  # total displacement of the wheel during session
    key['total_distance'] = float(np.abs(np.diff(re_pos)).sum())  # total movement of the wheel
    key['n_direction_changes'] = int(np.count_nonzero(change_mask))  # total number of direction changes
    if units == 'cm':  # convert to radians
        key['total_displacement'] = wh.cm_to_rad(key['total_displacement'])
        key['total_distance'] = wh.cm_to_rad(key['total_distance'])

    # Insert the keys in order
    self.insert1(key)
    self.Move.insert(moves)
    self.DirectionChange.insert(changes)
def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Detect wheel movements from rotary encoder timestamps and positions.

    The position units (rad or cm) and encoding (X4, X2, X1) are inferred from the smallest
    non-zero position step. The default movement thresholds (8 and 1.5 encoder samples) are
    converted to those units. The trace is interpolated at 1 kHz and passed to wh.movements.

    :param re_ts: numpy array of rotary encoder timestamps; either 1D (one per position) or
        2D, in which case column 0 is used as sample index and column 1 as time for linear
        interpolation of one timestamp per position sample
    :param re_pos: numpy array of rotary encoder positions (flattened when re_ts is 2D)
    :param display: bool: show the wheel position and velocity for full session with detected
        movements highlighted
    :return: wheel_moves dictionary with 'intervals' (N x 2 onset/offset times),
        'peakAmplitude' and 'peakVelocity_times'
    :raises AssertionError: on inconsistent inputs or detected movements
    """
    if len(re_ts.shape) == 1:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'
        assert np.all(np.diff(re_ts) > 0), 'wheel timestamps not monotonically increasing'
    else:
        _logger.debug('2D wheel timestamps')
        # Convert to one timestamp per position sample before interpolating the trace
        if len(re_pos.shape) > 1:  # Ensure 1D array of positions
            re_pos = re_pos.flatten()
        # Linearly interpolate the times
        x = np.arange(re_pos.size)
        re_ts = np.interp(x, re_ts[:, 0], re_ts[:, 1])

    # Check the values and units of wheel position
    res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
    # min change in rad and cm for each decoding type
    # [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
    min_change = np.concatenate([2 * np.pi / res, wh.WHEEL_DIAMETER * np.pi / res])
    pos_diff = np.abs(np.ediff1d(re_pos))
    # Ignore repeated samples: a zero step says nothing about the encoder resolution and would
    # always pull the estimate to the smallest candidate step
    nonzero = pos_diff[pos_diff > 0]
    pos_diff = (nonzero if nonzero.size else pos_diff).min()
    # find min change closest to min pos_diff
    idx = np.argmin(np.abs(min_change - pos_diff))
    if idx < len(res):
        # Assume values are in radians
        units = 'rad'
        encoding = idx
    else:
        units = 'cm'
        encoding = idx - len(res)
    enc_names = {0: 'X4', 1: 'X2', 2: 'X1'}
    _logger.info('Wheel in %s units using %s encoding', units, enc_names[int(encoding)])

    # The below assertion is violated by Bpod wheel data
    #  assert np.allclose(pos_diff, min_change, rtol=1e-05), 'wheel position skips'

    # Convert the pos threshold defaults from samples to correct unit
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res[encoding])
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)
    kwargs = {
        'pos_thresh': thresholds[0],
        'pos_thresh_onset': thresholds[1],
        'make_plots': display
    }

    # Interpolate and get onsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(t, pos, freq=1000, **kwargs)
    assert on.size == off.size, 'onset/offset number mismatch'
    assert np.all(np.diff(on) > 0) and np.all(np.diff(off) > 0), \
        'onsets/offsets not monotonically increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    # Put into dict
    wheel_moves = {
        'intervals': np.c_[on, off],
        'peakAmplitude': amp,
        'peakVelocity_times': peak_vel
    }
    return wheel_moves