Example 1
    def get_spike_waveforms(self, units=None):
        # phylib is only needed here, so import it lazily
        from phylib.io.model import load_model

        waveforms = []

        if units is None:
            units = self.select_units()

        for rec_num, recording in enumerate(self.files):
            paramspy = self.processed / f'sorted_{rec_num}' / 'params.py'
            if not paramspy.exists():
                raise PixelsError(f"{self.name}: params.py not found")
            model = load_model(paramspy)
            rec_forms = {}

            for unit in units[rec_num]:
                # get the waveforms from only the best channel
                spike_ids = model.get_cluster_spikes(unit)
                best_chan = model.get_cluster_channels(unit)[0]
                u_waveforms = model.get_waveforms(spike_ids, [best_chan])
                if u_waveforms is None:
                    raise PixelsError(
                        f"{self.name}: unit {unit} - waveforms not read")
                rec_forms[unit] = pd.DataFrame(np.squeeze(u_waveforms).T)
            waveforms.append(pd.concat(rec_forms, axis=1))

        df = pd.concat(waveforms,
                       axis=1,
                       keys=range(len(self.files)),
                       names=['rec_num', 'unit', 'spike'])
        # Convert sample indexes to milliseconds. Note this uses rec_num from the
        # last loop iteration, so it assumes all recordings share a sampling rate.
        rate = 1000 / int(self.spike_meta[rec_num]['imSampRate'])
        df.index = df.index * rate
        return df
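
A hedged usage sketch for the method above; the `session` object and the unit and recording indices are assumptions, not shown in the source:

# Hypothetical usage: fetch waveforms for the default unit selection.
waveforms = session.get_spike_waveforms()
# Columns are a (rec_num, unit, spike) MultiIndex; the index is time in ms.
mean_wf = waveforms[0][3].mean(axis=1)  # mean waveform of unit 3, recording 0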
Example 2
    def process_motion_tracking(self, config, create_labelled_video=False):
        """
        Run DeepLabCut motion tracking on behavioural videos.
        """
        # DeepLabCut is a heavy import, so only load it when needed
        import deeplabcut  # pylint: disable=import-error

        self.extract_videos()

        config = Path(config).expanduser()
        if not config.exists():
            raise PixelsError(f"Config at {config} not found.")

        for recording in self.files:
            if 'camera_data' in recording:
                video = self.interim / recording['camera_data'].with_suffix(
                    '.avi')
                if not video.exists():
                    raise PixelsError(
                        f"Path {video} should exist but doesn't... discuss.")

                deeplabcut.analyze_videos(config, [video])
                deeplabcut.plot_trajectories(config, [video])
                if create_labelled_video:
                    deeplabcut.create_labeled_video(config, [video])
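
A hedged usage sketch; the DeepLabCut project config path is a placeholder:

# Hypothetical usage: analyse this session's videos with an existing DLC project.
session.process_motion_tracking(
    "~/dlc-projects/my-project/config.yaml",  # placeholder config path
    create_labelled_video=True,
)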
Example 3
    def __init__(self, mouse_ids, behaviour, data_dir, meta_dir=None):
        if not isinstance(mouse_ids, (list, tuple, set)):
            mouse_ids = [mouse_ids]

        self.behaviour = behaviour
        self.mouse_ids = mouse_ids

        self.data_dir = Path(data_dir).expanduser()
        if not self.data_dir.exists():
            raise PixelsError(f"Directory not found: {data_dir}")

        if meta_dir:
            self.meta_dir = Path(meta_dir).expanduser()
            if not self.meta_dir.exists():
                raise PixelsError(f"Directory not found: {meta_dir}")
        else:
            self.meta_dir = None

        self.raw = self.data_dir / 'raw'
        self.processed = self.data_dir / 'processed'
        self.interim = self.data_dir / 'interim'

        self.sessions = []

        for session in ioutils.get_sessions(mouse_ids, self.data_dir,
                                            self.meta_dir):
            self.sessions.append(
                behaviour(
                    session['name'],
                    metadata=session['metadata'],
                    data_dir=session['data_dir'],
                ))
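
A minimal construction sketch, assuming the class is called Experiment and that Reach is a Behaviour subclass; both names and all paths are placeholders:

# Hypothetical usage: Reach stands in for any Behaviour subclass.
exp = Experiment(
    ["C57_123"],                   # one or more mouse IDs
    Reach,                         # Behaviour subclass used to build sessions
    "~/pixels-data",               # must already exist
    meta_dir="~/pixels-metadata",  # optional folder of training metadata JSONs
)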
Example 4
    def process_motion_index(self):
        """
        Extract motion indexes from videos using already drawn ROIs.
        """

        ses_rois = {}

        # First collect all ROIs to catch errors early
        for i, recording in enumerate(self.files):
            if 'camera_data' in recording:
                roi_file = self.processed / f"motion_index_ROIs_{i}.pickle"
                if not roi_file.exists():
                    raise PixelsError(self.name +
                                      ": ROIs not drawn for motion index.")

                # Also check videos are available
                video = self.interim / recording['camera_data'].with_suffix(
                    '.avi')
                if not video.exists():
                    raise PixelsError(
                        self.name + ": AVI video not found in interim folder.")

                with roi_file.open('rb') as fd:
                    ses_rois[i] = pickle.load(fd)

        # Then do the extraction
        for rec_num, recording in enumerate(self.files):
            if 'camera_data' in recording:

                # Get MIs
                # Not yet implemented: the extraction code below is unreachable
                raise NotImplementedError
                video = self.interim / recording['camera_data'].with_suffix(
                    '.avi')
                rec_rois = ses_rois[rec_num]
                rec_mi = signal.motion_index(video.as_posix(), rec_rois,
                                             self.sample_rate)

                # TODO: Use timestamps for real alignment
                # Get initial timestamp of behavioural data
                #behavioural_data = ioutils.read_tdms(self.find_file(recording['behaviour']))
                #for key in behavioural_data.keys():
                #    if key.startswith("/'t0'/"):
                #        t0 = behavioural_data[key][0]
                #        break
                #metadata = ioutils.read_tdms(self.find_file(recording['camera_meta']))
                #timestamps = ioutils.tdms_parse_timestamps(metadata)

                np.save(self.processed / f'motion_index_{rec_num}.npy', rec_mi)
Example 5
    def _get_spike_times(self):
        """
        Returns the sorted spike times.
        """
        saved = self._spike_times_data
        if saved[0] is None:
            for rec_num, recording in enumerate(self.files):
                times = self.processed / f'sorted_{rec_num}' / 'spike_times.npy'
                clust = self.processed / f'sorted_{rec_num}' / 'spike_clusters.npy'

                try:
                    times = np.load(times)
                    clust = np.load(clust)
                except FileNotFoundError:
                    msg = ": Can't load spike times that haven't been extracted!"
                    raise PixelsError(self.name + msg)

                times = np.squeeze(times)
                clust = np.squeeze(clust)
                by_clust = {}

                for c in np.unique(clust):
                    by_clust[c] = pd.Series(
                        times[clust == c]).drop_duplicates()
                saved[rec_num] = pd.concat(by_clust, axis=1, names=['unit'])
        return saved
Example 6
    def _get_processed_data(self, attr, key):
        """
        Used by the following get_X methods to load processed data.

        Parameters
        ----------
        attr : str
            The self attribute that stores the data.

        key : str
            The key for the files in each recording of self.files that contain this
            data.

        """
        saved = getattr(self, attr)
        if saved[0] is None:
            for rec_num, recording in enumerate(self.files):
                if key in recording:
                    file_path = self.processed / recording[key]
                    if file_path.exists():
                        if file_path.suffix == '.npy':
                            saved[rec_num] = np.load(file_path)
                        elif file_path.suffix == '.h5':
                            saved[rec_num] = ioutils.read_hdf5(file_path)
                    else:
                        msg = f"Could not find {attr[1:]} for recording {rec_num}."
                        msg += f"\nFile should be at: {file_path}"
                        raise PixelsError(msg)
        return saved
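
To make the docstring concrete, a hypothetical get_X wrapper; the attribute name is an assumption, though the 'motion_index' key matches the file keys built by get_data_files below:

    def get_motion_index_data(self):
        """
        Hypothetical wrapper: load each recording's motion index via the
        shared helper above.
        """
        return self._get_processed_data("_motion_index_data", "motion_index")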
Example 7
    def get_probe_depth(self):
        """
        Load probe depth in um from file if it has been recorded.
        """
        depth_file = self.processed / 'depth.txt'
        if not depth_file.exists():
            msg = f": Can't load probe depth: please add it in um to processed/{self.name}/depth.txt"
            raise PixelsError(self.name + msg)
        with depth_file.open() as fd:
            return [float(line) for line in fd.readlines()]
Example 8
    def get_cluster_info(self):
        cluster_info = []

        for rec_num, recording in enumerate(self.files):
            info_file = self.processed / f'sorted_{rec_num}' / 'cluster_info.tsv'
            try:
                info = pd.read_csv(info_file, sep='\t')
            except FileNotFoundError:
                msg = ": Can't load cluster info. Did you sort this session yet?"
                raise PixelsError(self.name + msg)

            cluster_info.append(info)

        return cluster_info
Example 9
    def draw_motion_index_rois(self, num_rois=1):
        """
        Draw motion index ROIs using EasyROI. If ROIs already exist, skip.

        Parameters
        ----------
        num_rois : int
            The number of ROIs to draw interactively. Default: 1

        """
        # Only needed for this method
        import cv2
        import EasyROI

        roi_helper = EasyROI.EasyROI(verbose=False)

        for i, recording in enumerate(self.files):
            if 'camera_data' in recording:
                roi_file = self.processed / f"motion_index_ROIs_{i}.pickle"
                if roi_file.exists():
                    continue

                # Load frame from video
                video = self.interim / recording['camera_data'].with_suffix(
                    '.avi')
                if not video.exists():
                    raise PixelsError(
                        self.name +
                        ": AVI video not found, run `extract_videos`")

                duration = ioutils.get_video_dimensions(video.as_posix())[2]
                frame = ioutils.load_video_frame(video.as_posix(),
                                                 duration // 4)

                # Interactively draw ROI
                roi = roi_helper.draw_polygon(frame, num_rois)
                cv2.destroyAllWindows()  # Needed otherwise EasyROI errors

                # Save a copy of the frame with ROIs to PNG file
                png = self.processed / f'motion_index_ROIs_{i}.png'
                copy = EasyROI.visualize_polygon(frame, roi, color=(255, 0, 0))
                plt.imsave(png, copy, cmap='gray')

                # Save ROI to file
                with roi_file.open('wb') as fd:
                    pickle.dump(roi['roi'], fd)
Example 10
    def sort_spikes(self):
        """
        Run kilosort spike sorting on raw spike data.
        """
        for rec_num, recording in enumerate(self.files):
            print(
                f">>>>> Spike sorting recording {rec_num + 1} of {len(self.files)}"
            )

            output = self.processed / f'sorted_{rec_num}'
            data_file = self.find_file(recording['spike_data'])
            try:
                # Use a new name to avoid shadowing the loop variable 'recording'
                extractor = se.SpikeGLXRecordingExtractor(file_path=data_file)
            except ValueError as e:
                raise PixelsError(
                    f"Did the raw data get fully copied to interim? Full error: {e}"
                )

            print("> Running kilosort")
            ss.run_kilosort3(recording=extractor, output_folder=output)
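
The se and ss modules are not defined in this snippet; one plausible reading, offered here as an assumption, is the legacy SpikeInterface packages:

# Presumed imports for the snippet above - an assumption, not shown in the source.
import spikeextractors as se  # SpikeGLXRecordingExtractor(file_path=...)
import spikesorters as ss     # run_kilosort3(recording=..., output_folder=...)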
Example 11
def save_ndarray_as_video(video, path, frame_rate):
    """
    Save a numpy.ndarray as video file.

    Parameters
    ----------
    video : numpy.ndarray
        Video data to save to file. Its dimensions should be (duration, height, width)
        and its data should be of uint8 type. The file extension determines the
        resulting file type.

    path : string / pathlib.Path object
        File to which the video will be saved.

    frame_rate : int
        The frame rate of the output video.

    """
    _, height, width = video.shape
    path = Path(path)

    process = (ffmpeg.input(
        'pipe:',
        format='rawvideo',
        pix_fmt='rgb24',
        s=f'{width}x{height}',
        r=frame_rate).output(
            path.as_posix(),
            pix_fmt='yuv420p',
            r=frame_rate,
            crf=0,
            vcodec='libx264').overwrite_output().run_async(pipe_stdin=True))

    for frame in video:
        # Duplicate the greyscale frame across three channels to match rgb24 input
        process.stdin.write(
            np.stack([frame, frame, frame], axis=2).astype(np.uint8).tobytes())

    process.stdin.close()
    process.wait()
    if not path.exists():
        raise PixelsError(f"Video creation failed: {path}")
Example 12
    def align_trials(self,
                     label,
                     event,
                     data='spike_times',
                     raw=False,
                     duration=1,
                     sigma=None,
                     units=None):
        """
        Get trials aligned to an event. This finds all instances of label in the
        action labels - these are the start times of the trials - then finds the
        first instance of event on or after each trial's start, cuts out a period
        around each of these events covering all units, rearranges the data into a
        MultiIndex DataFrame, and returns it.

        Parameters
        ----------
        label : int
            An action label value to specify which trial types are desired.

        event : int
            An event type value to specify which event to align the trials to.

        data : str, optional
            The data type to align.

        raw : bool, optional
            Whether to get raw, unprocessed data instead of processed and downsampled
            data. Defaults to False.

        duration : int/float, optional
            The length of time in seconds desired in the output. Default is 1 second.

        sigma : int, optional
            Time in milliseconds of the sigma of the Gaussian kernel used when aligning
            firing rates. Default: 50 ms.

        units : list of lists of ints, optional
            The output from self.select_units, used to only apply this method to a
            selection of units.

        """
        data = data.lower()

        data_options = [
            'behavioural',  # Channels from behaviour TDMS file
            'spike',  # Raw/downsampled channels from probe (AP)
            'spike_times',  # List of spike times per unit
            'spike_rate',  # Spike rate signals from convolved spike times
            'lfp',  # Raw/downsampled channels from probe (LFP)
            'motion_index',  # Motion indexes per ROI from the video
        ]
        if data not in data_options:
            raise PixelsError(
                f"align_trials: 'data' should be one of: {data_options}")

        if data in ("spike_times", "spike_rate"):
            print(f"Aligning {data} to trials.")
            # we let a dedicated function handle aligning spike times
            return self._get_aligned_spike_times(label,
                                                 event,
                                                 duration,
                                                 rate=data == "spike_rate",
                                                 sigma=sigma,
                                                 units=units)

        action_labels = self.get_action_labels()

        if raw:
            print(f"Aligning raw {data} data to trials.")
            getter = getattr(self, f"get_{data}_data_raw", None)
            if not getter:
                raise PixelsError(
                    f"align_trials: {data} doesn't have a 'raw' option.")
            values, sample_rate = getter()

        else:
            print(f"Aligning {data} data to trials.")
            values = getattr(self, f"get_{data}_data")()
            sample_rate = self.sample_rate

        if not values or values[0] is None:
            raise PixelsError(f"align_trials: Could not get {data} data.")

        rec_trials = []
        # The logic here is that the action labels will always have a sample rate of
        # self.sample_rate, whereas our data here may differ. 'scan_duration' is used
        # to scan the action labels, so always give it 10 seconds to scan, then 'half'
        # is used to index data.
        scan_duration = self.sample_rate * 10
        half = (sample_rate * duration) // 2

        for rec_num in range(len(self.files)):
            if values[rec_num] is None:
                # This means that each recording is using the same piece of data for
                # this data type, e.g. all recordings using motion indexes from a single
                # video
                break

            trials = []
            actions = action_labels[rec_num][:, 0]
            events = action_labels[rec_num][:, 1]
            trial_starts = np.where(np.bitwise_and(actions, label))[0]

            for start in trial_starts:
                centre = np.where(
                    np.bitwise_and(events[start:start + scan_duration],
                                   event))[0]
                if len(centre) == 0:
                    raise PixelsError('Action labels probably miscalculated')
                centre = start + centre[0]
                centre = int(centre * sample_rate / self.sample_rate)
                trial = values[rec_num][centre - half + 1:centre + half + 1]

                if isinstance(trial, np.ndarray):
                    trial = pd.DataFrame(trial)
                trials.append(trial.reset_index(drop=True))

            if not trials:
                raise PixelsError(
                    "Seems the action-event combo you asked for doesn't occur")

            rec_trials.append(
                pd.concat(trials,
                          axis=1,
                          keys=range(len(trials)),
                          names=['trial', 'unit']))

        ses_trials = pd.concat(rec_trials,
                               axis=1,
                               copy=False,
                               keys=range(len(rec_trials)),
                               names=["rec_num", "trial", "unit"])
        ses_trials = ses_trials.sort_index(level=1, axis=1)
        ses_trials = ses_trials.reorder_levels(["rec_num", "unit", "trial"],
                                               axis=1)

        points = ses_trials.shape[0]
        start = (-duration / 2) + (duration / points)
        timepoints = np.linspace(start, duration / 2, points)
        ses_trials['time'] = pd.Series(timepoints, index=ses_trials.index)
        ses_trials = ses_trials.set_index('time')
        return ses_trials
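
A hedged usage sketch; the session object and the ActionLabels/Events constants are experiment-specific assumptions:

# Hypothetical usage: the label and event values are placeholders.
aligned = session.align_trials(
    ActionLabels.rewarded_push,  # placeholder trial-type label
    Events.back_sensor_open,     # placeholder event to align to
    data="spike_rate",
    duration=4,
    sigma=100,
)
# Columns: (rec_num, unit, trial) MultiIndex; index: time in seconds around the event.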
Example 13
    def _get_aligned_spike_times(self,
                                 label,
                                 event,
                                 duration,
                                 rate=False,
                                 sigma=None,
                                 units=None):
        """
        Returns spike times for each unit within a given time window around an event.
        align_trials delegates to this function; use align_trials, not this method,
        to get aligned data in scripts.
        """
        action_labels = self.get_action_labels()
        spikes = self._get_spike_times()

        if units is None:
            units = self.select_units()

        if rate:
            # pad ends with 1 second extra to remove edge effects from convolution
            duration += 2

        scan_duration = self.sample_rate * 5
        half = int((self.sample_rate * duration) / 2)
        trials = []

        for rec_num in range(len(self.files)):
            actions = action_labels[rec_num][:, 0]
            events = action_labels[rec_num][:, 1]
            trial_starts = np.where(np.bitwise_and(actions, label))[0]

            rec_spikes = spikes[rec_num]
            rec_spikes = rec_spikes[units[rec_num]]
            rec_trials = []

            # Convert to ms (self.sample_rate)
            f = int(self.spike_meta[rec_num]['imSampRate']) / self.sample_rate
            rec_spikes = rec_spikes / f

            # Account for lag, in case the ephys recording was started before the
            # behaviour
            lag_start, _ = self._lag[rec_num]
            if lag_start < 0:
                rec_spikes = rec_spikes + lag_start

            for i, start in enumerate(trial_starts):
                centre = np.where(
                    np.bitwise_and(events[start:start + scan_duration],
                                   event))[0]
                if len(centre) == 0:
                    raise PixelsError('Action labels probably miscalculated')
                centre = start + centre[0]

                trial = rec_spikes[centre - half < rec_spikes]
                trial = trial[trial <= centre + half]
                trial = trial - centre
                tdf = []

                for unit in trial:
                    u_times = trial[unit].values
                    u_times = u_times[~np.isnan(u_times)]
                    u_times = np.unique(u_times)  # remove double-counted spikes
                    udf = pd.DataFrame({int(unit): u_times})
                    tdf.append(udf)

                if tdf:
                    tdfc = pd.concat(tdf, axis=1)
                    if rate:
                        tdfc = signal.convolve(tdfc, duration * 1000, sigma)
                    rec_trials.append(tdfc)

            rec_df = pd.concat(rec_trials, axis=1, keys=range(len(rec_trials)))
            trials.append(rec_df)

        trials = pd.concat(trials,
                           axis=1,
                           keys=range(len(trials)),
                           names=["rec_num", "trial", "unit"])
        trials = trials.reorder_levels(["rec_num", "unit", "trial"], axis=1)
        trials = trials.sort_index(level=0, axis=1)

        if rate:
            # Set index to seconds and remove the padding 1 sec at each end
            points = trials.shape[0]
            start = (-duration / 2) + (duration / points)
            timepoints = np.linspace(start, duration / 2, points)
            trials['time'] = pd.Series(timepoints, index=trials.index)
            trials = trials.set_index('time')
            trials = trials.iloc[self.sample_rate:-self.sample_rate]

        return trials
Example 14
def get_sessions(mouse_ids, data_dir, meta_dir):
    """
    Get a list of recording sessions for the specified mice, excluding those whose
    metadata contain '"exclude": true'.

    Parameters
    ----------
    mouse_ids : list of strs
        List of mouse IDs.

    data_dir : pathlib.Path
        The path to the folder containing data for all sessions. This is searched for
        available sessions.

    meta_dir : pathlib.Path or None
        If not None, the path to the folder containing training metadata JSON files. If
        None, no metadata is collected.

    Returns
    -------
    list of dicts : Dictionaries containing the values that can be used to create new
        Behaviour subclass instances.

    """
    if not isinstance(mouse_ids, (list, tuple, set)):
        mouse_ids = [mouse_ids]
    sessions = []
    raw_dir = data_dir / 'raw'

    for mouse in mouse_ids:
        mouse_sessions = list(raw_dir.glob(f'*{mouse}*'))

        if not mouse_sessions:
            print(f'Found no sessions for: {mouse}')
            continue

        if not meta_dir:
            # Do not collect metadata
            for session in mouse_sessions:
                sessions.append(
                    dict(
                        name=session.stem,
                        metadata=None,
                        data_dir=data_dir,
                    ))
            continue

        meta_file = meta_dir / (mouse + '.json')
        with meta_file.open() as fd:
            mouse_meta = json.load(fd)
        session_dates = [
            datetime.datetime.strptime(s.stem[0:6], '%y%m%d')
            for s in mouse_sessions
        ]

        if len(session_dates) != len(set(session_dates)):
            raise PixelsError(f"{mouse}: Data folder dates must be unique.")

        s = 0
        for i, session in enumerate(mouse_meta):
            try:
                meta_date = datetime.datetime.strptime(session['date'],
                                                       '%Y-%m-%d')
            except (KeyError, TypeError):
                raise PixelsError(
                    f"{mouse} session #{i}: 'date' not found in JSON.")

            for index, ses_date in enumerate(session_dates):
                if ses_date == meta_date and not session.get('exclude', False):
                    s += 1
                    sessions.append(
                        dict(
                            name=mouse_sessions[index].stem,
                            metadata=session,
                            data_dir=data_dir,
                        ))

        if s == 0:
            print(
                f'No session dates match between folders and metadata for: {mouse}'
            )

    return sessions
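
A hedged usage sketch; the mouse ID and data directory are placeholders:

from pathlib import Path

# With meta_dir=None, sessions are returned without metadata.
sessions = get_sessions(["C57_123"], Path("~/data").expanduser(), None)
for ses in sessions:
    print(ses["name"], ses["data_dir"])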
Example 15
def get_data_files(data_dir, session_name):
    """
    Get the file names of raw data for a session.

    Parameters
    ----------
    data_dir : pathlib.Path
        The directory containing the data.

    session_name : str
        The name of the session for which to get file names.

    Returns
    -------
    A list of dicts, where each dict corresponds to one recording. The dict will contain
    these keys to identify data files:

        - spike_data
        - spike_meta
        - lfp_data
        - lfp_meta
        - behaviour
        - camera_data
        - camera_meta

    """
    if session_name != data_dir.stem:
        data_dir = list(data_dir.glob(f'{session_name}*'))[0]
    files = []

    spike_data = sorted(
        glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.bin*'))
    spike_meta = sorted(
        glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.meta*'))
    lfp_data = sorted(
        glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].lf.bin*'))
    lfp_meta = sorted(
        glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].lf.meta*'))
    behaviour = sorted(glob.glob(f'{data_dir}/[0-9a-zA-Z_-]*([0-9]).tdms*'))

    camera = sorted(glob.glob(f'{data_dir}/[0-9a-zA-Z_-]*([0-9])-*.tdms*'))
    camera_data = []
    camera_meta = []
    for match in camera:
        if 'meta' in match:
            camera_meta.append(match)
        else:
            camera_data.append(match)

    if not spike_data:
        raise PixelsError(f"{session_name}: could not find raw AP data file.")
    if not spike_meta:
        raise PixelsError(
            f"{session_name}: could not find raw AP metadata file.")
    if not lfp_data:
        raise PixelsError(f"{session_name}: could not find raw LFP data file.")
    if not lfp_meta:
        raise PixelsError(
            f"{session_name}: could not find raw LFP metadata file.")

    for num, spike_recording in enumerate(spike_data):
        recording = {}
        recording['spike_data'] = original_name(spike_recording)
        recording['spike_meta'] = original_name(spike_meta[num])
        recording['lfp_data'] = original_name(lfp_data[num])
        recording['lfp_meta'] = original_name(lfp_meta[num])
        if behaviour:
            if len(behaviour) == len(spike_data):
                recording['behaviour'] = original_name(behaviour[num])
            else:
                recording['behaviour'] = original_name(behaviour[0])
            recording['behaviour_processed'] = recording[
                'behaviour'].with_name(recording['behaviour'].stem +
                                       '_processed.h5')
        else:
            recording['behaviour'] = None
            recording['behaviour_processed'] = None
        if len(camera_data) > num:
            recording['camera_data'] = original_name(camera_data[num])
            recording['camera_meta'] = original_name(camera_meta[num])
            recording['motion_index'] = Path(f'motion_index_{num}.npy')
        recording['action_labels'] = Path(f'action_labels_{num}.npy')
        recording['spike_processed'] = recording['spike_data'].with_name(
            recording['spike_data'].stem + '_processed.h5')
        recording['spike_rate_processed'] = Path(f'spike_rate_{num}.h5')
        recording['lfp_processed'] = recording['lfp_data'].with_name(
            recording['lfp_data'].stem + '_processed.h5')
        files.append(recording)

    return files
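
A hedged usage sketch; the directory layout and session name are placeholders:

from pathlib import Path

# The session folder name encodes the date as YYMMDD, as assumed by get_sessions.
session_dir = Path("~/data/raw/210101_C57_123").expanduser()
files = get_data_files(session_dir, "210101_C57_123")
print(files[0]["spike_data"], files[0]["behaviour"])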