Ejemplo n.º 1
0
 def get_active_wheel_period(wheel,
                             duration_range=(3., 20.),
                             display=False):
     """
     Find a stretch of wheel activity (acceleration followed by deceleration) to use
     for the wheel motion alignment QC.

     :param wheel: A Bunch holding the wheel `timestamps` and `position` arrays
     :param duration_range: (min, max) duration; a candidate period must be strictly
         longer than the first value and strictly shorter than the second
     :param display: When True, plot position and acceleration over the chosen period
     :return: 2-element array with the start and end times of the chosen period, or
         None when no candidate period satisfies the duration constraints
     """
     pos, ts = wh.interpolate_position(wheel.timestamps, wheel.position)
     _, acc = wh.velocity_smoothed(pos, 1000)
     # Movement detection is deliberately run on the acceleration trace
     onsets, offsets, *_ = wh.movements(ts, acc, pos_thresh=.1, make_plots=False)
     edges = np.c_[onsets, offsets]
     durations = np.diff(edges)
     in_range = (durations > duration_range[0]) & (durations < duration_range[1])
     candidates, _ = np.where(in_range)
     if candidates.size == 0:
         _log.warning(
             'No period of wheel movement found for motion alignment.')
         return None
     # Take the candidate sitting in the middle of the candidate list
     chosen = edges[candidates[candidates.size // 2]]
     if display:
         _, (ax0, ax1) = plt.subplots(2, 1, sharex='all')
         window = np.logical_and(ts > chosen[0], ts < chosen[1])
         ax0.plot(ts[window], pos[window])
         ax1.plot(ts[window], acc[window])
     return chosen
Ejemplo n.º 2
0
    def test_direction_changes(self):
        """Check direction_changes on the first fixture: output count and first entries."""
        fixture = self.test_data[0]
        t, pos = fixture[0]
        on, off, *_ = fixture[1]
        vel, _ = wheel.velocity_smoothed(pos, 1000)
        intervals = np.c_[on, off]
        times, indices = wheel.direction_changes(t, vel, intervals)

        expected_count = 14
        self.assertTrue(len(times) == len(indices) == expected_count,
                        'incorrect number of arrays returned')
        # The first array of each output must match the reference values
        first_times = [21.86593334, 22.12693334, 22.20193334, 22.66093334]
        first_indices = [21809, 22070, 22145, 22604]
        np.testing.assert_allclose(times[0], first_times)
        np.testing.assert_array_equal(indices[0], first_indices)
Ejemplo n.º 3
0
# Randomly select the trials to plot. Sample without replacement so no trial is
# plotted twice; replacement is only allowed when more panels are requested than
# there are trials (otherwise np.random.choice would raise).
n_total = trial_data['choice'].size
trial_ids = np.random.choice(n_total, size=n_trials, replace=n_trials > n_total)
# squeeze=False + ravel() keeps `axs` a 1D array of Axes even when n_trials == 1
fig, axs = plt.subplots(1, n_trials, figsize=(8.5, 2.5), squeeze=False)
axs = axs.ravel()

# Go cue and response times of the selected trials (drawn as vertical lines below)
goCues = trial_data['goCue_times'][trial_ids]
responses = trial_data['response_times'][trial_ids]

# Trial intervals used to cut the wheel traces
starts = trial_data['intervals'][trial_ids, 0]
ends = trial_data['intervals'][trial_ids, 1]
# Cut up the wheel vectors
Fs = 1000
pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs)
vel, acc = wh.velocity_smoothed(pos, Fs)

traces = wh.traces_by_trial(t, pos, start=starts, end=ends)
zipped = zip(traces, axs, goCues, responses, trial_ids)

for (trace, ax, go, resp, n) in zipped:
    ax.plot(trace[0], trace[1], 'k-')
    ax.axvline(x=go, color='g', label='go cue', linestyle=':')
    ax.axvline(x=resp, color='r', label='threshold', linestyle=':')
    ax.set_title('Trial #%s' % n)

    # Turn off tick labels
    ax.set_yticklabels([])
    ax.set_xticklabels([])

# Lay out the figure once the panels (and their titles) exist
plt.tight_layout()

# Add labels to first
Ejemplo n.º 4
0
    def make(self, key, one=None):
        """
        Extract wheel movements and direction changes for one session and insert them.

        Loads the session's wheel object via ONE, detects movements with
        ``extract_wheel_moves``, finds velocity sign changes inside each movement and
        inserts one summary row into this table, one ``Move`` row per movement and one
        ``DirectionChange`` row per direction change.

        :param key: session key; the summary fields (n_movements, total_displacement,
            total_distance, n_direction_changes) are written into this dict in place
            before it is inserted
        :param one: optional ONE instance; ``ONE()`` is created when omitted
        :raises AssertionError: if the wheel object is incomplete or not all numpy arrays
        :raises ValueError: propagated (after logging) if movement extraction fails
        """
        # Load the wheel for this session
        move_key = key.copy()
        change_key = move_key.copy()
        one = one or ONE()
        eid, ver = (acquisition.Session & key).fetch1('session_uuid', 'task_protocol')
        logger.info('WheelMoves for session %s, %s', str(eid), ver)

        try:  # Should be able to remove this
            wheel = one.load_object(str(eid), 'wheel')
            all_loaded = (
                all(isinstance(wheel[lab], np.ndarray) for lab in wheel)
                and all(k in wheel for k in ('timestamps', 'position')))
            if not all_loaded:
                # Explicit raise rather than `assert` so the check survives `python -O`;
                # AssertionError is kept because that is what callers saw before.
                raise AssertionError('wheel data missing')

            # If times and timestamps present, drop times
            if {'times', 'timestamps'}.issubset(wheel):
                wheel.pop('times')
            wheel_moves = extract_wheel_moves(wheel.timestamps, wheel.position)
        except ValueError:
            logger.exception('Failed to find movements')
            raise
        except Exception as ex:  # includes the AssertionError raised above
            logger.exception(str(ex))
            raise

        # Get the units of the position data
        units, *_ = infer_wheel_units(wheel.position)

        # Build list of table entries
        keys = ('move_id', 'movement_onset', 'movement_offset', 'max_velocity', 'movement_amplitude')
        on_off, amp, vel_t = wheel_moves.values()  # Unpack into short vars
        if units == 'cm':
            # Convert to radians BEFORE building the Move rows; converting afterwards
            # left movement_amplitude stored in cm.
            amp = wh.cm_to_rad(amp)
            wheel_moves['peakAmplitude'] = amp
        moves = [dict(zip(keys, (i, on, off, vel_t[i], amp[i])), **move_key)
                 for i, (on, off) in enumerate(on_off)]

        # Calculate direction changes
        Fs = 1000
        re_ts, re_pos = wheel.timestamps, wheel.position
        if re_ts.ndim != 1:
            logger.info('2D wheel timestamps')
            if re_pos.ndim > 1:  # Ensure 1D array of positions
                re_pos = re_pos.flatten()
            # Linearly interpolate the times
            x = np.arange(re_pos.size)
            re_ts = np.interp(x, re_ts[:, 0], re_ts[:, 1])

        # interpolate_position takes (timestamps, positions) and returns
        # (positions, times) -- same order as the other calls to it in this file
        pos, ts = wh.interpolate_position(re_ts, re_pos, freq=Fs)
        vel, _ = wh.velocity_smoothed(pos, Fs)
        # True at each sample whose velocity sign differs from the previous sample
        change_mask = np.insert(np.diff(np.sign(vel)) != 0, 0, 0)
        # Times of all sign changes, computed once instead of masking `ts` per movement
        change_times = ts[change_mask]

        changes = []
        for i, (on, off) in enumerate(on_off.reshape(-1, 2)):
            in_move = change_times[(change_times > on) & (change_times < off)]
            changes.extend(
                dict(change_key, move_id=i, change_id=j, change_time=t)
                for j, t in enumerate(in_move)
            )

        key['n_movements'] = wheel_moves['intervals'].shape[0]  # total number of movements within the session
        # total displacement of the wheel during session (last sample minus first)
        key['total_displacement'] = float(wheel.position[-1] - wheel.position[0])
        key['total_distance'] = float(np.abs(np.diff(wheel.position)).sum())  # total movement of the wheel
        key['n_direction_changes'] = int(np.count_nonzero(change_mask))  # total number of direction changes
        if units == 'cm':  # convert to radians
            key['total_displacement'] = wh.cm_to_rad(key['total_displacement'])
            key['total_distance'] = wh.cm_to_rad(key['total_distance'])

        # Insert the keys in order
        self.insert1(key)
        self.Move.insert(moves)
        self.DirectionChange.insert(changes)
Ejemplo n.º 5
0
    def align_motion(self,
                     period=(-np.inf, np.inf),
                     side='left',
                     sd_thresh=10,
                     display=False):
        """
        Estimate the frame offset between a camera and the wheel by cross-correlating
        the video motion energy with the normalized wheel speed over ``period``.

        :param period: (start, end) time window; camera frames and wheel samples with
            times inside it (inclusive) are used
        :param side: camera label used to index ``self.data['camera_times']``,
            ``self.video_paths`` and ``self.roi``
        :param sd_thresh: threshold on the motion-energy s.d. values; only used for the
            plot when ``display`` is True
        :param display: if True, plot the motion energy and the shifted wheel speed
        :return: (dt_i, c, df) -- frame offset, peak cross-correlation value and motion
            energy trace; (None, None, None) if no video frames could be read
        :raises ValueError: if no camera frames fall within ``period``
        """
        # Get data samples within period
        wheel = self.data['wheel']
        self.alignment.label = side
        self.alignment.to_mask = lambda ts: np.logical_and(
            ts >= period[0], ts <= period[1])
        camera_times = self.data['camera_times'][side]
        cam_mask = self.alignment.to_mask(camera_times)
        frame_numbers, = np.where(cam_mask)

        if frame_numbers.size == 0:
            raise ValueError('No frames during given period')

        # Motion Energy
        camera_path = self.video_paths[side]
        roi = (*[slice(*r) for r in self.roi[side]], 0)
        try:
            # TODO Add function arg to make grayscale
            self.alignment.frames = \
                vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi)
            # Explicit check instead of `assert`, which is stripped under `python -O`
            frames_loaded = (self.alignment.frames is not None
                             and self.alignment.frames.size != 0)
        except AssertionError:  # still treat loader assertion failures as a failed open
            frames_loaded = False
        if not frames_loaded:
            self.log.error('Failed to open video')
            return None, None, None
        self.alignment.df, stDev = video.motion_energy(self.alignment.frames,
                                                       2)
        self.alignment.period = period  # For plotting

        # Calculate rotary encoder velocity trace
        x = camera_times[cam_mask]
        Fs = 1000
        pos, t = wh.interpolate_position(wheel.timestamps,
                                         wheel.position,
                                         freq=Fs)
        v, _ = wh.velocity_smoothed(pos, Fs)
        interp_mask = self.alignment.to_mask(t)
        # Wheel times inside the period; hoisted so the masked copy is built once
        # rather than once per camera frame
        t_period = t[interp_mask]
        # Convert to normalized speed: index of the wheel sample nearest each camera
        # frame (duplicates removed), then min-max scaling of the absolute velocity
        xs = np.unique([find_nearest(t_period, ts) for ts in x])
        vs = np.abs(v[interp_mask][xs])
        # NOTE(review): yields NaN if the speed is constant over the period (0 / 0)
        vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs))

        # FIXME This can be used as a goodness of fit measure
        USE_CV2 = False
        if USE_CV2:
            # convert from numpy format to openCV format
            dfCV = np.float32(self.alignment.df.reshape((-1, 1)))
            reCV = np.float32(vs.reshape((-1, 1)))

            # perform cross correlation
            resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED)

            # convert result back to numpy array
            xcorr = np.asarray(resultCv)
        else:
            xcorr = signal.correlate(self.alignment.df, vs)

        # Cross correlate wheel speed trace with the motion energy
        CORRECTION = 2
        self.alignment.c = np.max(xcorr)
        self.alignment.xcorr = np.argmax(xcorr)
        self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION
        self.log.info(
            f'{side} camera, adjusted by {self.alignment.dt_i} frames')

        if display:
            # Plot the motion energy
            fig, ax = plt.subplots(2, 1, sharex='all')
            y = np.pad(self.alignment.df, 1, 'edge')
            ax[0].plot(x, y, '-x', label='wheel motion energy')
            thresh = stDev > sd_thresh
            ax[0].vlines(x[np.array(
                np.pad(thresh, 1, 'constant', constant_values=False))],
                         0,
                         1,
                         linewidth=0.5,
                         linestyle=':',
                         label=f'>{sd_thresh} s.d. diff')
            ax[1].plot(t_period, np.abs(v[interp_mask]))

            # Plot other stuff
            dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]])
            fps = 1 / np.diff(camera_times).mean()
            ax[0].plot(t_period[xs] - dt,
                       vs,
                       'r-x',
                       label='velocity (shifted)')
            ax[0].set_title('normalized motion energy, %s camera, %.0f fps' %
                            (side, fps))
            ax[0].set_ylabel('rate of change (a.u.)')
            ax[0].legend()
            ax[1].set_ylabel('wheel speed (rad / s)')
            ax[1].set_xlabel('Time (s)')

            title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s'
            fig.suptitle(title, fontsize=16)
            fig.set_size_inches(19.2, 9.89)

        return self.alignment.dt_i, self.alignment.c, self.alignment.df