예제 #1
0
 def get_active_wheel_period(wheel,
                             duration_range=(3., 20.),
                             display=False):
     """
     Find a period of wheel movement to use for the wheel motion alignment QC.

     Movement bouts are detected on the acceleration returned by ``wh.velocity_smoothed``
     (not on position), so a candidate is a span where the wheel speeds up and slows down.
     Among the candidates whose duration lies strictly between the two bounds of
     ``duration_range``, the one halfway through the list is chosen.

     :param wheel: A Bunch with ``timestamps`` and ``position`` attributes of wheel data
     :param duration_range: (min, max) duration; candidates must lie strictly within it
     :param display: If true, plot position and acceleration over the selected period
     :return: 2-element array comprising the start and end times of the active period, or
      None when no candidate period is found
     """
     position, times = wh.interpolate_position(wheel.timestamps, wheel.position)
     _, accel = wh.velocity_smoothed(position, 1000)
     onsets, offsets, *_ = wh.movements(times, accel, pos_thresh=.1, make_plots=False)
     bouts = np.c_[onsets, offsets]

     # One duration per bout; shape (n_bouts, 1), hence the 2-tuple from np.where
     durations = np.diff(bouts)
     long_enough = durations > duration_range[0]
     short_enough = durations < duration_range[1]
     candidates, _ = np.where(long_enough & short_enough)
     if candidates.size == 0:
         _log.warning(
             'No period of wheel movement found for motion alignment.')
         return None

     # Pick the candidate in the middle of the list
     chosen = bouts[candidates[candidates.size // 2]]
     if display:
         _, (ax_pos, ax_acc) = plt.subplots(2, 1, sharex='all')
         in_bout = (times > chosen[0]) & (times < chosen[1])
         ax_pos.plot(times[in_bout], position[in_bout])
         ax_acc.plot(times[in_bout], accel[in_bout])
     return chosen
예제 #2
0
    def extract_onsets_for_trial(self):
        """
        Extract the wheel movement onsets and offsets for the current trial.

        The wheel position units ('rad' or 'cm') and encoder decoding type are inferred from
        the median absolute step between consecutive position samples. The default movement
        thresholds (8 and 1.5 encoder samples) are converted into those units before the
        interpolated wheel trace is passed to ``wh.movements``.

        :return: tuple of (onsets, offsets, interpolated timestamps, interpolated positions,
         position units), where position units is either 'rad' or 'cm'
        """
        wheel = self._session_data['wheel']
        trials = self._session_data['trials']
        trial_idx = self.trial_num - 1  # Trial numbers start at 1
        # Check the values and units of wheel position
        res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
        # min change in rad and cm for each decoding type
        # [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
        min_change = np.concatenate(
            [2 * np.pi / res, wh.WHEEL_DIAMETER * np.pi / res])
        pos_diff = np.median(np.abs(np.ediff1d(wheel['position'])))

        # find the theoretical min change closest to the typical observed step
        idx = np.argmin(np.abs(min_change - pos_diff))
        if idx < len(res):
            # Assume values are in radians
            units = 'rad'
            encoding = idx
        else:
            units = 'cm'
            encoding = idx - len(res)
        # Default thresholds are in encoder samples: convert to cm, then to rad if needed
        thresholds = wh.samples_to_cm(np.array([8, 1.5]),
                                      resolution=res[encoding])
        if units == 'rad':
            thresholds = wh.cm_to_rad(thresholds)
        kwargs = {
            'pos_thresh': thresholds[0],
            'pos_thresh_onset': thresholds[1]
        }
        #  kwargs = {'make_plots': True, **kwargs}  # Uncomment for plot

        # Interpolate and get onsets
        pos, t = wh.interpolate_position(wheel['timestamps'],
                                         wheel['position'],
                                         freq=1000)
        # Restrict to the window from the end of the previous trial to the start of the next
        if self.quick_load or not self.trial_num:
            intervals = trials['intervals']
            # The first trial has no previous trial. A negative index would silently wrap
            # around to the last trial and produce an empty window, so leave the start open.
            start = intervals[trial_idx - 1, 1] if trial_idx > 0 else -np.inf
            if trial_idx + 1 < len(intervals):
                end = intervals[trial_idx + 1, 0]
            else:  # We're on the last trial: stop at the end of the current trial
                end = intervals[trial_idx, 1]
            t_mask = np.logical_and(t >= start, t <= end)
        else:
            t_mask = np.ones_like(t, dtype=bool)
        wheel_ts = t[t_mask]
        wheel_pos = pos[t_mask]
        on, off, *_ = wh.movements(wheel_ts, wheel_pos, freq=1000, **kwargs)
        return on, off, wheel_ts, wheel_pos, units
예제 #3
0
 def test_movements(self):
     """Movement detection on the raw test samples reproduces the MATLAB reference output."""
     # These test data are the same as those used in the MATLAB code
     case = self.test_data[0]
     inputs, expected = case[0], case[1]
     on, off, amp, peak_vel = wheel.movements(
         *inputs, freq=1000, pos_thresh=8, pos_thresh_onset=1.5)
     exact_checks = (
         (on, 0, 'Unexpected onsets'),
         (off, 1, 'Unexpected offsets'),
         (amp, 2, 'Unexpected move amps'),
     )
     for actual, column, message in exact_checks:
         self.assertTrue(np.array_equal(actual, expected[column]), msg=message)
     # Peak velocities differ slightly due to the convolution algorithm: compare with tolerance
     self.assertTrue(np.allclose(peak_vel, expected[3], atol=1.e-2),
                     msg='Unexpected peak velocities')
예제 #4
0
 def test_movements_FPGA(self):
     """Movement detection on extracted FPGA wheel data reproduces the MATLAB reference output."""
     # These test data are the same as those used in the MATLAB code; they come from
     # extracted FPGA wheel data
     case = self.test_data[1]
     pos, t = wheel.interpolate_position(*case[0], freq=1000)
     expected = case[1]
     thresholds = wheel.samples_to_cm(np.array([8, 1.5]))
     on, off, amp, peak_vel = wheel.movements(
         t, pos, freq=1000, pos_thresh=thresholds[0], pos_thresh_onset=thresholds[1])
     # (result, column of expected, absolute tolerance, failure message)
     checks = (
         (on, 0, 1.e-5, 'Unexpected onsets'),
         (off, 1, 1.e-5, 'Unexpected offsets'),
         (amp, 2, 1.e-5, 'Unexpected move amps'),
         (peak_vel, 3, 1.e-2, 'Unexpected peak velocities'),
     )
     for actual, column, tolerance, message in checks:
         self.assertTrue(np.allclose(actual, expected[column], atol=tolerance), msg=message)
예제 #5
0
def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Detect wheel movements from rotary encoder timestamps and positions.

    The position units and encoder resolution are inferred from the data with
    ``infer_wheel_units``. The default movement thresholds (8 and 1.5 encoder samples) are
    converted into those units. Movements are detected on the trace returned by
    ``wh.interpolate_position`` (freq=1000).

    :param re_ts: numpy array of rotary encoder timestamps. Either 1D, with one time per
     position sample, or 2D, in which case columns 0 and 1 are (sample index, time) pairs
     that are linearly interpolated to give one time per position sample
    :param re_pos: numpy array of rotary encoder positions
    :param display: bool: show the wheel position and velocity for full session with detected
     movements highlighted
    :return: wheel_moves dict with keys 'intervals' (onset/offset pairs, one row per
     movement), 'peakAmplitude' and 'peakVelocity_times'
    :raises AssertionError: if the inputs or the detected movements fail an integrity check
    """
    if re_ts.ndim != 1:
        _logger.debug('2D wheel timestamps')
        # Ensure 1D array of positions
        re_pos = re_pos.flatten() if re_pos.ndim > 1 else re_pos
        # Linearly interpolate one time per position sample
        sample_idx = np.arange(re_pos.size)
        re_ts = np.interp(sample_idx, re_ts[:, 0], re_ts[:, 1])
    else:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'

    units, res, enc = infer_wheel_units(re_pos)
    _logger.info('Wheel in %s units using %s encoding', units, enc)

    # A check that every position step equals the encoder resolution is deliberately absent:
    # Bpod wheel data violates it.

    # Default thresholds are in encoder samples: convert to cm, then to rad if needed
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res)
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)

    # Interpolate and get onsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(
        t, pos, freq=1000,
        pos_thresh=thresholds[0],
        pos_thresh_onset=thresholds[1],
        make_plots=display)

    # Integrity checks on the detected movements
    assert on.size == off.size, 'onset/offset number mismatch'
    assert (np.all(np.diff(on) > 0) and np.all(np.diff(off) > 0)), \
        'onsets/offsets not strictly increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    return {
        'intervals': np.c_[on, off],
        'peakAmplitude': amp,
        'peakVelocity_times': peak_vel,
    }
예제 #6
0
    def make(self, key):
        """
        Compute and insert wheel movement statistics for a single session.

        Loads the session's ``wheel`` object through ONE and checks its integrity. It then
        infers whether positions are in radians or cm and which encoder decoding type was used,
        and detects movements. Finally it inserts one session-level row into this table and one
        row per movement into ``self.Move``. When the wheel data are in cm, the session-level
        displacement/distance and the movement amplitudes are converted to radians.

        :param key: primary key of the session (used to restrict ``acquisition.Session``)
        :raises ValueError: if the wheel data are inconsistent or movement detection fails
        :raises AssertionError: if the wheel data or detected movements fail an integrity check
        """
        # Load the wheel for this session
        move_key = key.copy()  # unmodified session key, merged into each movement row
        one = ONE()
        eid, ver = (acquisition.Session & key).fetch1('session_uuid',
                                                      'task_protocol')
        logger.info('WheelMoves for session %s, %s', str(eid), ver)

        try:  # Should be able to remove this
            wheel = one.load_object(str(eid), 'wheel')
            all_loaded = \
                all([isinstance(wheel[lab], np.ndarray) for lab in wheel]) and \
                all(k in wheel for k in ('timestamps', 'position'))
            assert all_loaded, 'wheel data missing'
            alf.io.check_dimensions(wheel)
            if len(wheel['timestamps'].shape) == 1:
                assert wheel['timestamps'].size == wheel[
                    'position'].size, 'wheel data dimension mismatch'
                assert np.all(
                    np.diff(wheel['timestamps']) > 0
                ), 'wheel timestamps not monotonically increasing'
            else:
                logger.debug('2D timestamps')
            # Check the values and units of wheel position
            res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
            min_change_rad = 2 * np.pi / res
            min_change_cm = wh.WHEEL_DIAMETER * np.pi / res
            pos_diff = np.abs(np.ediff1d(wheel['position']))
            # A step finer than the smallest possible cm step is taken to mean radians
            if pos_diff.min() < min_change_cm.min():
                # Assume values are in radians
                units = 'rad'
                encoding = np.argmin(np.abs(min_change_rad - pos_diff.min()))
                min_change = min_change_rad[encoding]
            else:
                units = 'cm'
                encoding = np.argmin(np.abs(min_change_cm - pos_diff.min()))
                min_change = min_change_cm[encoding]
            enc_names = {0: '4X', 1: '2X', 2: '1X'}
            logger.info('Wheel in %s units using %s encoding', units,
                        enc_names[int(encoding)])
            # Only enforced for ephys-rig protocols: every step must equal one encoder unit
            if '_iblrig_tasks_ephys' in ver:
                assert np.allclose(pos_diff, min_change,
                                   rtol=1e-05), 'wheel position skips'
        except ValueError:
            logger.exception('Inconsistent wheel data')
            raise
        except Exception as ex:  # includes AssertionError from the integrity checks above
            logger.exception(str(ex))
            raise

        try:
            # Convert the pos threshold defaults from samples to correct unit
            thresholds = wh.samples_to_cm(np.array([8, 1.5]),
                                          resolution=res[encoding])
            if units == 'rad':
                thresholds = wh.cm_to_rad(thresholds)
            kwargs = {
                'pos_thresh': thresholds[0],
                'pos_thresh_onset': thresholds[1]
            }
            #  kwargs = {'make_plots': True, **kwargs}
            # Interpolate and get onsets
            pos, t = wh.interpolate_position(wheel['timestamps'],
                                             wheel['position'],
                                             freq=1000)
            on, off, amp, peak_vel = wh.movements(t, pos, freq=1000, **kwargs)
            assert on.size == off.size, 'onset/offset number mismatch'
            assert np.all(np.diff(on) > 0) and np.all(np.diff(
                off) > 0), 'onsets/offsets not monotonically increasing'
            assert np.all((off - on) > 0), 'not all offsets occur after onset'
        except ValueError:
            logger.exception('Failed to find movements')
            raise
        except AssertionError as ex:
            logger.exception('Wheel integrity check failed: ' + str(ex))
            raise

        key['n_movements'] = on.size  # total number of movements within the session
        # Net displacement between first and last samples. Subtracting the two elements gives a
        # scalar directly; float() on a 1-element array is deprecated in NumPy >= 1.25.
        key['total_displacement'] = float(pos[-1] - pos[0])
        key['total_distance'] = float(np.abs(
            np.diff(pos)).sum())  # total movement of the wheel
        if units == 'cm':  # convert to radians
            key['total_displacement'] = wh.cm_to_rad(key['total_displacement'])
            key['total_distance'] = wh.cm_to_rad(key['total_distance'])
            amp = wh.cm_to_rad(amp)

        self.insert1(key)

        keys = ('move_id', 'movement_onset', 'movement_offset', 'max_velocity',
                'movement_amplitude')
        # Values must follow the order of `keys`: velocity before amplitude.
        # NOTE(review): confirm whether the 4th output of wh.movements is the peak velocity
        # itself or the time of peak velocity before relying on 'max_velocity'.
        moves = [
            {**dict(zip(keys, (i, on[i], off[i], peak_vel[i], amp[i]))), **move_key}
            for i in range(on.size)
        ]

        self.Move.insert(moves)
예제 #7
0
def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Detect wheel movements from rotary encoder timestamps and positions.

    Position units ('rad' or 'cm') and encoder decoding type (X4, X2, X1) are inferred from
    the smallest non-zero absolute step between consecutive position samples. The default
    movement thresholds (8 and 1.5 encoder samples) are converted into those units, and
    movements are detected on the trace returned by ``wh.interpolate_position`` (freq=1000).

    :param re_ts: numpy array of rotary encoder timestamps. 1D arrays must have one time per
     position sample and be strictly increasing; 2D arrays skip these checks
    :param re_pos: numpy array of rotary encoder positions
    :param display: bool: show the wheel position and velocity for full session with detected
     movements highlighted
    :return: wheel_moves dict with keys 'intervals' (onset/offset pairs, one row per
     movement), 'peakAmplitude' and 'peakVelocity_times'
    :raises AssertionError: if the inputs or the detected movements fail an integrity check
    :raises ValueError: if fewer than two position samples are given
    """
    if len(re_ts.shape) == 1:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'
        assert np.all(np.diff(re_ts) > 0
                      ), 'wheel timestamps not monotonically increasing'
    else:
        _logger.debug('2D wheel timestamps')

    # Check the values and units of wheel position
    res = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
    # min change in rad and cm for each decoding type
    # [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
    min_change = np.concatenate(
        [2 * np.pi / res, wh.WHEEL_DIAMETER * np.pi / res])
    steps = np.abs(np.ediff1d(re_pos))
    # Repeated samples give zero steps. These carry no unit information and would otherwise
    # always select whichever theoretical step is smallest, so ignore them whenever a
    # non-zero step exists. With fewer than two samples `steps` is empty and min() raises
    # ValueError, as before.
    nonzero_steps = steps[steps > 0]
    pos_diff = (nonzero_steps if nonzero_steps.size else steps).min()

    # find min change closest to min pos_diff
    idx = np.argmin(np.abs(min_change - pos_diff))
    if idx < len(res):
        # Assume values are in radians
        units = 'rad'
        encoding = idx
    else:
        units = 'cm'
        encoding = idx - len(res)
    enc_names = {0: 'X4', 1: 'X2', 2: 'X1'}
    _logger.info('Wheel in %s units using %s encoding', units,
                 enc_names[int(encoding)])

    # The below assertion is violated by Bpod wheel data
    #  assert np.allclose(pos_diff, min_change, rtol=1e-05), 'wheel position skips'

    # Convert the pos threshold defaults from samples to correct unit
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res[encoding])
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)
    kwargs = {
        'pos_thresh': thresholds[0],
        'pos_thresh_onset': thresholds[1],
        'make_plots': display
    }

    # Interpolate and get onsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(t, pos, freq=1000, **kwargs)
    assert on.size == off.size, 'onset/offset number mismatch'
    assert np.all(np.diff(on) > 0) and np.all(
        np.diff(off) > 0), 'onsets/offsets not monotonically increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    # Put into dict
    wheel_moves = {
        'intervals': np.c_[on, off],
        'peakAmplitude': amp,
        'peakVelocity_times': peak_vel
    }
    return wheel_moves