예제 #1
0
    def __init__(
        self,
        spike_times: np.array,
        pos_times: np.array,
        spk_clusters: np.array,
        x: np.array,
        y: np.array,
        tracker_params=None,
    ):
        """
        Store the spike data and build the position / spike helper objects.

        Parameters
        ----------
        spike_times - 1d np.array
        pos_times - 1d np.array
        spk_clusters - 1d np.array
        x and y - 1d np.array
        tracker_params - dict or None - from the PosTracker as created in
            OEKiloPhy.Settings.parse. None (the default) is treated as
            an empty dict

        NB All timestamps should be given in sub-millisecond accurate
             seconds and pos_xy in cms
        """
        # Build a fresh dict on every call: a literal {} default would be
        # one object shared by every instance created without this argument
        if tracker_params is None:
            tracker_params = {}
        self.spike_times = spike_times
        self.pos_times = pos_times
        self.spk_clusters = spk_clusters
        # There can be more spikes than pos samples in terms of sampling as
        # the open-ephys buffer probably needs to finish writing and the
        # camera has already stopped, so cut off any cluster indices and
        # spike times that fall after the last position timestamp
        idx_to_keep = self.spike_times < self.pos_times[-1]
        self.spike_times = self.spike_times[idx_to_keep]
        self.spk_clusters = self.spk_clusters[idx_to_keep]
        self._pos_sample_rate = 30
        self._spk_sample_rate = 3e4
        self._pos_samples_for_spike = None
        self._min_runlength = 0.4  # in seconds
        self.posCalcs = PosCalcsGeneric(x,
                                        y,
                                        230,
                                        cm=True,
                                        jumpmax=100,
                                        tracker_params=tracker_params)
        # NOTE(review): the spike helper receives the untrimmed spike_times /
        # spk_clusters arguments, not the trimmed self.spike_times /
        # self.spk_clusters - confirm this is intended
        self.spikeCalcs = SpikeCalcsGeneric(spike_times)
        self.spikeCalcs.spk_clusters = spk_clusters
        self.posCalcs.postprocesspos(tracker_params)
        xy = self.posCalcs.xy
        hdir = self.posCalcs.dir
        self.posCalcs.calcSpeed(xy)
        self._xy = xy
        self._hdir = hdir
        self._speed = self.posCalcs.speed
        # TEMPORARY FOR POWER SPECTRUM STUFF
        self.smthKernelWidth = 2
        self.smthKernelSigma = 0.1875
        self.sn2Width = 2
        self.thetaRange = [7, 11]
        self.xmax = 11
예제 #2
0
	def exportPos(self, ppm=300, jumpmax=100, as_text=False):
		"""
		Export the position data as plain text or as an Axona .pos file.

		Parameters
		----------
		ppm : int
			Pixels per metre; passed to PosCalcsGeneric and written into the
			'pixels_per_metre' header field
		jumpmax : int
			Passed to PosCalcsGeneric as its jumpmax argument
		as_text : bool
			If True, self.OE_data.xy is written with np.savetxt to
			<axona_root_name>.txt and the method returns without writing a
			.pos file. NB the text export uses the raw xy values, not the
			post-processed ones
		"""
		#
		# Step 1) Deal with the position data first:
		#
		# Grab the settings of the pos tracker and do some post-processing on the position
		# data (discard jumpy data, do some smoothing etc)
		self.settings.parsePos()
		tracker_params = self.settings.tracker_params
		posProcessor = PosCalcsGeneric(self.OE_data.xy[:,0], self.OE_data.xy[:,1], ppm, True, jumpmax)
		print("Post-processing position data...")
		xy, _ = posProcessor.postprocesspos(tracker_params)
		xy = xy.T
		if as_text is True:
			print("Beginning export of position data to text format...")
			pos_file_name = self.axona_root_name + ".txt"
			np.savetxt(pos_file_name, self.OE_data.xy, fmt='%1.u')
			print("Completed export of position data")
			return
		# Do the upsampling of both xy and the timestamps
		print("Beginning export of position data to Axona format...")
		axona_pos_data = self.convertPosData(xy, self.OE_data.xyTS)
		# make sure pos data length is same as duration * num_samples
		# (50 matches the 'sample_rate' header value written below)
		duration = int(self.last_pos_ts - self.first_pos_ts)
		axona_pos_data = axona_pos_data[0:duration * 50]
		# Create an empty header for the pos data
		pos_header = self.AxonaData.getEmptyHeader("pos")
		# Every header key that contains one of these fragments receives the
		# matching tracker border (substring match, as before)
		border_for_fragment = {
			'min_x': 'LeftBorder',
			'min_y': 'TopBorder',
			'max_x': 'RightBorder',
			'max_y': 'BottomBorder',
		}
		for key in pos_header.keys():
			for fragment, border in border_for_fragment.items():
				if fragment in key:
					pos_header[key] = str(tracker_params[border])
		pos_header['duration'] = str(duration)
		# Rest of this stuff probably won't change so should be defaulted in the loaded file
		# (see axonaIO.py)
		pos_header['num_colours'] = '4'
		pos_header['sw_version'] = '1.2.2.1'
		pos_header['timebase'] = '50 hz'
		pos_header['sample_rate'] = '50.0 hz'
		pos_header['pos_format'] = 't,x1,y1,x2,y2,numpix1,numpix2'
		pos_header['bytes_per_coord'] = '2'
		pos_header['EEG_samples_per_position'] = '5'
		pos_header['bytes_per_timestamp'] = '4'
		pos_header['pixels_per_metre'] = str(ppm)
		pos_header['num_pos_samples'] = str(len(axona_pos_data))
		pos_header['bearing_colour_1'] = '210'
		pos_header['bearing_colour_2'] = '30'
		pos_header['bearing_colour_3'] = '0'
		pos_header['bearing_colour_4'] = '0'

		self.writePos2AxonaFormat(pos_header, axona_pos_data)
		print("Exported position data to Axona format")
예제 #3
0
    def loadPos(self, *args, **kwargs):
        """
        Load the position data and timestamps and post-process them.

        Reads the recording start time from self.sync_message_file (when
        set), then loads data_array.npy / timestamps.npy from
        self.path2PosData (when set) and stores xyTS, pos_sample_rate,
        orig_x, orig_y, PosCalcs, xy, dir and speed on the instance.
        self.recording_start_time is always set before returning.

        Raises
        ------
        ValueError
            If self.pos_data_type is neither "PosTracker" nor
            "TrackingPlugin"
        """
        # Only sub-class that doesn't use this is OpenEphysNWB
        # which needs updating
        # TODO: Update / overhaul OpenEphysNWB
        # Load the start time from the sync_messages file
        recording_start_time = 0
        if self.sync_message_file is not None:
            with open(self.sync_message_file, "r") as f:
                sync_lines = f.read().split("\n")
            marker = "start time: "
            for line in sync_lines:
                if "subProcessor: 0" in line and marker in line:
                    # Take everything after the marker and drop an optional
                    # "@<rate>" suffix; slicing off the last character
                    # instead would truncate a value with no suffix
                    start_val = line.split(marker, 1)[1]
                    recording_start_time = float(start_val.split("@")[0])
        if self.path2PosData is not None:
            pos_data_type = getattr(self, "pos_data_type", "PosTracker")
            data_file = os.path.join(self.path2PosData, "data_array.npy")
            if pos_data_type == "PosTracker":
                print("Loading PosTracker data...")
                pos_data = np.load(data_file)
            elif pos_data_type == "TrackingPlugin":
                print("Loading Tracking Plugin data...")
                pos_data = loadTrackingPluginData(data_file)
            else:
                # Fail early with a clear message rather than hitting an
                # unbound pos_data further down
                raise ValueError(
                    f"Unknown pos_data_type: {pos_data_type!r}")
            pos_ts = np.load(os.path.join(self.path2PosData, "timestamps.npy"))
            pos_ts = np.ravel(pos_ts)
            pos_timebase = getattr(self, "pos_timebase", 3e4)
            sample_rate = np.floor(1 / np.mean(np.diff(pos_ts) / pos_timebase))
            self.xyTS = pos_ts - recording_start_time
            self.xyTS = self.xyTS / pos_timebase  # convert to seconds
            if self.sync_message_file is not None:
                recording_start_time = self.xyTS[0]
            self.pos_sample_rate = sample_rate
            self.orig_x = pos_data[:, 0]
            self.orig_y = pos_data[:, 1]

            P = PosCalcsGeneric(
                pos_data[:, 0],
                pos_data[:, 1],
                cm=True,
                ppm=self.ppm,
                jumpmax=self.jumpmax,
            )
            P.postprocesspos({"SampleRate": sample_rate})
            setattr(self, "PosCalcs", P)
            self.xy = P.xy
            self.dir = P.dir
            self.speed = P.speed
        else:
            warnings.warn("Could not find the pos data. \
                Make sure there is a pos_data folder with data_array.npy \
                and timestamps.npy in")
        self.recording_start_time = recording_start_time
예제 #4
0
    def exportPos(self, ppm=300, jumpmax=100, as_text=False):
        """
        Export the position data as plain text or as an Axona .pos file.

        Parameters
        ----------
        ppm : int
            Pixels per metre; passed to PosCalcsGeneric and written into
            the 'pixels_per_metre' header field
        jumpmax : int
            Passed to PosCalcsGeneric as its jumpmax argument
        as_text : bool
            If True, self.OE_data.xy is written with np.savetxt to
            <axona_root_name>.txt and the method returns without writing
            a .pos file. NB the text export uses the raw xy values, not
            the post-processed ones
        """
        #
        # Step 1) Deal with the position data first:
        #
        # Grab the settings of the pos tracker and do some post-processing
        # on the position
        # data (discard jumpy data, do some smoothing etc)
        self.settings.parse()
        tracker_params = self.settings.tracker_params
        posProcessor = PosCalcsGeneric(self.OE_data.xy[:, 0],
                                       self.OE_data.xy[:, 1],
                                       ppm, True, jumpmax)
        print("Post-processing position data...")
        # This entry is written into the shared settings dict, so it stays
        # set after this call
        tracker_params["AxonaBadValue"] = 1023
        posProcessor.postprocesspos(tracker_params)
        xy = posProcessor.xy.T
        if as_text is True:
            print("Beginning export of position data to text format...")
            pos_file_name = self.axona_root_name + ".txt"
            np.savetxt(pos_file_name, self.OE_data.xy, fmt="%1.u")
            print("Completed export of position data")
            return
        # Do the upsampling of both xy and the timestamps
        print("Beginning export of position data to Axona format...")
        axona_pos_data = self.convertPosData(xy, self.OE_data.xyTS)
        # make sure pos data length is same as duration * num_samples
        duration = int(self.last_pos_ts - self.first_pos_ts)
        axona_pos_data = axona_pos_data[0:duration * 50]
        # Create an empty header for the pos data
        from ephysiopy.dacq2py.axona_headers import PosHeader

        pos_header = PosHeader()
        # NOTE(review): only "min_x" lacks the leading "." that the other
        # three border keys carry - confirm against PosHeader which
        # spelling is the real key
        pos_header.pos["min_x"] = str(tracker_params["LeftBorder"])
        pos_header.pos[".min_y"] = str(tracker_params["TopBorder"])
        pos_header.pos[".max_x"] = str(tracker_params["RightBorder"])
        pos_header.pos[".max_y"] = str(tracker_params["BottomBorder"])
        pos_header.common["duration"] = str(duration)
        pos_header.pos["pixels_per_metre"] = str(ppm)
        pos_header.pos["num_pos_samples"] = str(len(axona_pos_data))

        self.writePos2AxonaFormat(pos_header, axona_pos_data)
        print("Exported position data to Axona format")
예제 #5
0
def basic_PosCalcs(basic_xy):
    """
    Build a PosCalcsGeneric instance from the supplied xy data
    (described upstream as a random walk).

    The first two entries of basic_xy are used as the x and y
    coordinates, with a fixed 300 pixels per metre.
    """
    pixels_per_metre = 300
    x_coords, y_coords = basic_xy[0], basic_xy[1]
    return PosCalcsGeneric(x_coords, y_coords, pixels_per_metre)
예제 #6
0
	def plotMapsOneAtATime(self, plot_type='map', **kwargs):
		"""
		Build a MapCalcsGeneric object for the trial, store it on
		self.mapiter and iterate over it once.

		Parameters
		----------
		plot_type : str or list
			The kind of plot to produce.  Valid strings include:
			* 'map' - just ratemap plotted
			* 'path' - just spikes on path
			* 'both' - both of the above
			* 'all' - both spikes on path, ratemap & SAC plotted
		kwargs :
		* 'ppm' - Integer denoting pixels per metre where lower values = more bins in ratemap / SAC
		  (defaults to 400)
		* 'clusters' - int or list of ints describing which clusters to plot
		* 'save_grid_summary_location' - bool; if True the dictionary returned from gridcell.SAC.getMeasures is saved for each cluster

		All kwargs are also forwarded to self.__loaddata__ (when the
		position data still has to be loaded) and to MapCalcsGeneric.
		"""

		if self.kilodata is None:
			self.loadKilo()
		ppm = kwargs.get('ppm', 400)
		from ephysiopy.common.ephys_generic import PosCalcsGeneric, MapCalcsGeneric
		if self.xy is None:
			self.__loaddata__(**kwargs)
		posProcessor = PosCalcsGeneric(self.xy[:,0], self.xy[:,1], ppm, jumpmax=self.jumpmax)
		self.__loadSettings__()
		xy, hdir = posProcessor.postprocesspos(self.settings.tracker_params)
		self.hdir = hdir
		# spike times are divided by 3e4 and offset by the recording start time
		spk_times = (self.kilodata.spk_times.T / 3e4) + self.recording_start_time
		mapiter = MapCalcsGeneric(xy, np.squeeze(hdir), posProcessor.speed, self.xyTS, spk_times, plot_type, **kwargs)
		if 'clusters' in kwargs:
			requested = kwargs['clusters']
			if isinstance(requested, int):
				# wrap a single cluster id so both cases go through the same call
				requested = [requested]
			# only keep requested clusters that are also marked good
			mapiter.good_clusters = np.intersect1d(requested, self.kilodata.good_clusters)
		else:
			mapiter.good_clusters = self.kilodata.good_clusters
		mapiter.spk_clusters = self.kilodata.spk_clusters
		self.mapiter = mapiter
		# Iterating is what drives MapCalcsGeneric; one blank line is printed per item
		for _ in mapiter:
			print("")
예제 #7
0
	def prepareMaps(self, **kwargs):
		"""Initialises a MapCalcsGeneric object by providing it with positional and
		spiking data.

		I don't like the name of this method but it is useful to be able to separate
		out the preparation of the MapCalcsGeneric object as there are two major uses;
		actually plotting the maps and/ or extracting data from them without plotting

		Parameters
		----------
		kwargs :
		* 'ppm' - pixels per metre handed to PosCalcsGeneric (defaults to 400)
		* 'plot_type' - handed to MapCalcsGeneric (defaults to 'map')
		* 'cluster' - int or list of ints; restricts the clusters to those that
		  are also in self.kilodata.good_clusters

		Returns
		-------
		The MapCalcsGeneric instance, which is also stored as self.mapiter
		"""
		if self.kilodata is None:
			self.loadKilo()
		ppm = kwargs.get('ppm', 400)
		from ephysiopy.common.ephys_generic import PosCalcsGeneric, MapCalcsGeneric
		if self.xy is None:
			self.__loaddata__(**kwargs)
		posProcessor = PosCalcsGeneric(self.xy[:,0], self.xy[:,1], ppm, jumpmax=self.jumpmax)
		self.__loadSettings__()
		xy, hdir = posProcessor.postprocesspos(self.settings.tracker_params)
		self.hdir = hdir
		# spike times are divided by 3e4 and offset by the recording start time
		spk_times = (self.kilodata.spk_times.T / 3e4) + self.recording_start_time
		plot_type = kwargs.get('plot_type', 'map')
		# plot_type is passed positionally below, so it must not also travel
		# inside **kwargs: that raises "got multiple values" when the
		# receiving parameter is named plot_type
		# NOTE(review): assumes MapCalcsGeneric does not separately read
		# kwargs['plot_type'] - confirm against its signature
		map_kwargs = {key: val for key, val in kwargs.items() if key != 'plot_type'}
		mapiter = MapCalcsGeneric(xy, np.squeeze(hdir), posProcessor.speed, self.xyTS, spk_times, plot_type, **map_kwargs)
		if 'cluster' in kwargs:
			requested = kwargs['cluster']
			if isinstance(requested, int):
				# wrap a single cluster id so both cases go through the same call
				requested = [requested]
			# only keep requested clusters that are also marked good
			mapiter.good_clusters = np.intersect1d(requested, self.kilodata.good_clusters)
		else:
			mapiter.good_clusters = self.kilodata.good_clusters
		mapiter.spk_clusters = self.kilodata.spk_clusters
		self.mapiter = mapiter
		return mapiter
예제 #8
0
    def load(self, *args, **kwargs):
        """
        Set the pixels-per-metre value and, if not yet present, load and
        post-process the position data.

        Minimally, there should be at least a .set file
        Other files (.eeg, .pos, .stm, .1, .2 etc) are essentially optional

        A "ppm" keyword argument overrides the value read from the
        .set file.
        """
        if self.settings is not None:
            print("Loaded .set file")
        # The .set file supplies the default ppm; a "ppm" kwarg wins if given
        set_file_ppm = int(self.settings["tracker_pixels_per_metre"])
        self.ppm = kwargs.get("ppm", set_file_ppm)

        # ------------------------------------
        # ------------- Pos data -------------
        # ------------------------------------
        if self.xy is None:
            try:
                axona_pos = axonaIO.Pos(self.common_name)
                pos_calcs = PosCalcsGeneric(
                    axona_pos.led_pos[:, 0],
                    axona_pos.led_pos[:, 1],
                    cm=True,
                    ppm=self.ppm,
                )
                pos_calcs.postprocesspos(
                    tracker_params={"AxonaBadValue": 1023})
                self.xy = pos_calcs.xy
                # timestamps are stored relative to the first sample
                self.xyTS = axona_pos.ts - axona_pos.ts[0]
                self.dir = pos_calcs.dir
                self.speed = pos_calcs.speed
                self.pos_sample_rate = axona_pos.getHeaderVal(
                    axona_pos.header, "sample_rate")

                print("Loaded .pos file")
            except IOError:
                print("Couldn't load the pos data")
예제 #9
0
	def plotPos(self, jumpmax=None, show=True, **kwargs):
		"""
		Plots x vs y position for the current trial

		Parameters
		----------
		jumpmax : int
			The max amount the LED is allowed to instantaneously move
			(defaults to self.jumpmax)
		show : bool
			Whether to plot the pos into a figure window or not (default True)
		kwargs :
		* 'saveas' - file name; if present the plot is saved with plt.savefig

		All kwargs are also forwarded to self.__loaddata__ when the
		position data still has to be loaded.

		Returns
		----------
		xy : array_like
			positional data following post-processing. When show is True
			a tuple (ax, xy) is returned instead, where ax is the axes
			the path was drawn into
		"""
		if jumpmax is None:
			jumpmax = self.jumpmax
		import matplotlib.pylab as plt
		from ephysiopy.common.ephys_generic import PosCalcsGeneric

		self.__loadSettings__()
		if self.xy is None:
			self.__loaddata__(**kwargs)
		posProcessor = PosCalcsGeneric(self.xy[:,0], self.xy[:,1], ppm=300, cm=True, jumpmax=jumpmax)
		xy, hdir = posProcessor.postprocesspos(self.settings.tracker_params)
		self.hdir = hdir
		save_requested = 'saveas' in kwargs
		ax = None
		if save_requested or show:
			# Draw once only: invert_yaxis() toggles, so plotting and inverting
			# separately for "saveas" and "show" cancelled the inversion (and
			# drew the path twice) when both were requested
			plt.plot(xy[0], xy[1])
			ax = plt.gca()
			ax.invert_yaxis()
		if save_requested:
			plt.savefig(kwargs['saveas'])
		if show:
			return ax, xy
		return xy
예제 #10
0
 def load_pos_data(self,
                   pname: Path,
                   ppm: int = 300,
                   jumpmax: int = 100) -> None:
     """
     Load and post-process the position data unless self.PosCalcs is
     already set.

     On success the resulting PosCalcsGeneric object is stored as
     self.PosCalcs; an IOError while loading is reported with a
     printed message instead of being raised.

     NOTE(review): the pname, ppm and jumpmax arguments are not used -
     the file is taken from self.pname and the scale from self.ppm.
     Confirm whether the arguments were meant to take precedence.
     """
     if self.PosCalcs is None:
         try:
             AxonaPos = Pos(self.pname)
             P = PosCalcsGeneric(
                 AxonaPos.led_pos[0, :],
                 AxonaPos.led_pos[1, :],
                 cm=True,
                 ppm=self.ppm,
             )
             # Timestamps come from the loaded instance; the Pos class
             # itself was referenced here before
             P.xyTS = AxonaPos.ts
             P.sample_rate = AxonaPos.getHeaderVal(AxonaPos.header,
                                                   "sample_rate")
             P.postprocesspos()
             print("Loaded pos data")
             self.PosCalcs = P
         except IOError:
             print("Couldn't load the pos data")
예제 #11
0
class CosineDirectionalTuning(object):
    """
    Produces output to do with Welday et al (2011) like analysis
    of rhythmic firing a la oscialltory interference model
    """
    def __init__(
        self,
        spike_times: np.array,
        pos_times: np.array,
        spk_clusters: np.array,
        x: np.array,
        y: np.array,
        tracker_params=None,
    ):
        """
        Store the spike data and build the position / spike helper objects.

        Parameters
        ----------
        spike_times - 1d np.array
        pos_times - 1d np.array
        spk_clusters - 1d np.array
        x and y - 1d np.array
        tracker_params - dict or None - from the PosTracker as created in
            OEKiloPhy.Settings.parse. None (the default) is treated as
            an empty dict

        NB All timestamps should be given in sub-millisecond accurate
             seconds and pos_xy in cms
        """
        # Build a fresh dict on every call: a literal {} default would be
        # one object shared by every instance created without this argument
        if tracker_params is None:
            tracker_params = {}
        self.spike_times = spike_times
        self.pos_times = pos_times
        self.spk_clusters = spk_clusters
        # There can be more spikes than pos samples in terms of sampling as
        # the open-ephys buffer probably needs to finish writing and the
        # camera has already stopped, so cut off any cluster indices and
        # spike times that fall after the last position timestamp
        idx_to_keep = self.spike_times < self.pos_times[-1]
        self.spike_times = self.spike_times[idx_to_keep]
        self.spk_clusters = self.spk_clusters[idx_to_keep]
        self._pos_sample_rate = 30
        self._spk_sample_rate = 3e4
        self._pos_samples_for_spike = None
        self._min_runlength = 0.4  # in seconds
        self.posCalcs = PosCalcsGeneric(x,
                                        y,
                                        230,
                                        cm=True,
                                        jumpmax=100,
                                        tracker_params=tracker_params)
        # NOTE(review): the spike helper receives the untrimmed spike_times /
        # spk_clusters arguments, not the trimmed self.spike_times /
        # self.spk_clusters - confirm this is intended
        self.spikeCalcs = SpikeCalcsGeneric(spike_times)
        self.spikeCalcs.spk_clusters = spk_clusters
        self.posCalcs.postprocesspos(tracker_params)
        xy = self.posCalcs.xy
        hdir = self.posCalcs.dir
        self.posCalcs.calcSpeed(xy)
        self._xy = xy
        self._hdir = hdir
        self._speed = self.posCalcs.speed
        # TEMPORARY FOR POWER SPECTRUM STUFF
        self.smthKernelWidth = 2
        self.smthKernelSigma = 0.1875
        self.sn2Width = 2
        self.thetaRange = [7, 11]
        self.xmax = 11

    @property
    def spk_sample_rate(self):
        """Spike sampling rate held in _spk_sample_rate."""
        return self._spk_sample_rate

    @spk_sample_rate.setter
    def spk_sample_rate(self, rate):
        self._spk_sample_rate = rate

    @property
    def pos_sample_rate(self):
        """Position sampling rate held in _pos_sample_rate."""
        return self._pos_sample_rate

    @pos_sample_rate.setter
    def pos_sample_rate(self, rate):
        self._pos_sample_rate = rate

    @property
    def min_runlength(self):
        """Minimum run length held in _min_runlength."""
        return self._min_runlength

    @min_runlength.setter
    def min_runlength(self, duration):
        self._min_runlength = duration

    @property
    def xy(self):
        """Position data held in _xy."""
        return self._xy

    @xy.setter
    def xy(self, positions):
        self._xy = positions

    @property
    def hdir(self):
        """Direction data held in _hdir."""
        return self._hdir

    @hdir.setter
    def hdir(self, directions):
        self._hdir = directions

    @property
    def speed(self):
        """Speed data held in _speed."""
        return self._speed

    @speed.setter
    def speed(self, speeds):
        self._speed = speeds

    @property
    def pos_samples_for_spike(self):
        """Per-spike position sample indices held in _pos_samples_for_spike."""
        return self._pos_samples_for_spike

    @pos_samples_for_spike.setter
    def pos_samples_for_spike(self, indices):
        self._pos_samples_for_spike = indices

    def _rolling_window(self, a: np.array, window: int):
        """
        Totally nabbed from SO:
        https://stackoverflow.com/questions/6811183/rolling-window-for-1d-arrays-in-numpy
        """
        shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
        strides = a.strides + (a.strides[-1], )
        return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

    def getPosIndices(self):
        self.pos_samples_for_spike = np.floor(self.spike_times *
                                              self.pos_sample_rate).astype(int)

    def getClusterPosIndices(self, cluster: int) -> np.array:
        if self.pos_samples_for_spike is None:
            self.getPosIndices()
        cluster_pos_indices = self.pos_samples_for_spike[self.spk_clusters ==
                                                         cluster]
        cluster_pos_indices[cluster_pos_indices >= len(self.pos_times)] = (
            len(self.pos_times) - 1)
        return cluster_pos_indices

    def getClusterSpikeTimes(self, cluster: int):
        ts = self.spike_times[self.spk_clusters == cluster]
        if self.pos_samples_for_spike is None:
            self.getPosIndices()
        return ts

    def getDirectionalBinPerPosition(self, binwidth: int):
        """
        Direction is in degrees as that what is created by me in some of the
        other bits of this package.

        Parameters
        ----------
        binwidth : int - binsizethe bin width in degrees

        Outputs
        -------
        A digitization of which directional bin each position sample belongs to
        """

        bins = np.arange(0, 360, binwidth)
        return np.digitize(self.hdir, bins)

    def getDirectionalBinForCluster(self, cluster: int):
        b = self.getDirectionalBinPerPosition(45)
        cluster_pos = self.getClusterPosIndices(cluster)
        # idx_to_keep = cluster_pos < len(self.pos_times)
        # cluster_pos = cluster_pos[idx_to_keep]
        return b[cluster_pos]

    def getRunsOfMinLength(self):
        """
        Identifies runs of at least self.min_runlength seconds long,
        which at 30Hz pos sampling rate equals 12 samples, and
        returns the start and end indices at which
        the run was occurred and the directional bin that run belongs to

        Returns
        -------
        np.array - the start and end indices into position samples of the run
                          and the directional bin to which it belongs
        """

        b = self.getDirectionalBinPerPosition(45)
        # nabbed from SO
        from itertools import groupby

        grouped_runs = [(k, sum(1 for i in g)) for k, g in groupby(b)]
        grouped_runs = np.array(grouped_runs)
        run_start_indices = np.cumsum(grouped_runs[:, 1]) - grouped_runs[:, 1]
        min_len_in_samples = int(self.pos_sample_rate * self.min_runlength)
        min_len_runs_mask = grouped_runs[:, 1] >= min_len_in_samples
        ret = np.array([
            run_start_indices[min_len_runs_mask],
            grouped_runs[min_len_runs_mask, 1]
        ]).T
        # ret contains run length as last column
        ret = np.insert(ret, 1, np.sum(ret, 1), 1)
        ret = np.insert(ret, 2, grouped_runs[min_len_runs_mask, 0], 1)
        return ret[:, 0:3]

    def speedFilterRuns(self, runs: np.array, minspeed=5.0):
        """
        Given the runs identified in getRunsOfMinLength, filter for speed
        and return runs that meet the min speed criteria

        The function goes over the runs with a moving window of length equal
        to self.min_runlength in samples and sees if any of those segments
        meets the speed criteria and splits them out into separate runs if true

        NB For now this means the same spikes might get included in the
        autocorrelation procedure later as the
        moving window will use overlapping periods - can be modified later

        Parameters
        ----------
        runs - 3 x nRuns np.array generated from getRunsOfMinLength
        minspeed - float - min running speed in cm/s for an epoch (minimum
                                        epoch length defined previously
                            in getRunsOfMinLength as minlength, usually 0.4s)

        Returns
        -------
        3 x nRuns np.array - A modified version of the "runs" input variable
        """
        minlength_in_samples = int(self.pos_sample_rate * self.min_runlength)
        run_list = runs.tolist()
        all_speed = np.array(self.speed)
        for start_idx, end_idx, dir_bin in run_list:
            this_runs_speed = all_speed[start_idx:end_idx]
            this_runs_runs = self._rolling_window(this_runs_speed,
                                                  minlength_in_samples)
            run_mask = np.all(this_runs_runs > minspeed, 1)
            if np.any(run_mask):
                print("got one")

    """
    def testing(self, cluster: int):
        ts = self.getClusterSpikeTimes(cluster)
        pos_idx = self.getClusterPosIndices(cluster)

        dir_bins = self.getDirectionalBinPerPosition(45)
        cluster_dir_bins = dir_bins[pos_idx.astype(int)]

        from scipy.signal import periodogram, boxcar, filtfilt

        acorrs = []
        max_freqs = []
        max_idx = []
        isis = []

        acorr_range = np.array([-500, 500])
        for i in range(1, 9):
            this_bin_indices = cluster_dir_bins == i
            this_ts = ts[this_bin_indices]  # in seconds still so * 1000 for ms
            y = self.spikeCalcs.xcorr(this_ts*1000, Trange=acorr_range)
            isis.append(y)
            corr, acorr_bins = np.histogram(
                y[y != 0], bins=501, range=acorr_range)
            freqs, power = periodogram(corr, fs=200, return_onesided=True)
            # Smooth the power over +/- 1Hz
            b = boxcar(3)
            h = filtfilt(b, 3, power)
            # Square the amplitude first
            sqd_amp = h ** 2
            # Then find the mean power in the +/-1Hz band either side of that
            theta_band_max_idx = np.nonzero(
                sqd_amp == np.max(
                    sqd_amp[np.logical_and(freqs > 6, freqs < 11)]))[0][0]
            max_freq = freqs[theta_band_max_idx]
            acorrs.append(corr)
            max_freqs.append(max_freq)
            max_idx.append(theta_band_max_idx)
        return isis, acorrs, max_freqs, max_idx, acorr_bins

    def plotXCorrsByDirection(self, cluster: int):
        acorr_range = np.array([-500, 500])
        # plot_range = np.array([-400,400])
        nbins = 501
        isis, acorrs, max_freqs, max_idx, acorr_bins = self.testing(cluster)
        bin_labels = np.arange(0, 360, 45)
        fig, axs = plt.subplots(8)
        pts = []
        for i, a in enumerate(isis):
            axs[i].hist(
                a[a != 0], bins=nbins, range=acorr_range,
                color='k', histtype='stepfilled')
            # find the max of the first positive peak
            corr, _ = np.histogram(a[a != 0], bins=nbins, range=acorr_range)
            axs[i].set_xlim(acorr_range)
            axs[i].set_ylabel(str(bin_labels[i]))
            axs[i].set_yticklabels('')
            if i < 7:
                axs[i].set_xticklabels('')
            axs[i].spines['right'].set_visible(False)
            axs[i].spines['top'].set_visible(False)
            axs[i].spines['left'].set_visible(False)
        plt.show()
        return pts
    """

    def intrinsic_freq_autoCorr(
        self,
        spkTimes=None,
        posMask=None,
        maxFreq=25,
        acBinSize=0.002,
        acWindow=0.5,
        plot=True,
        **kwargs,
    ):
        """
        Compute the power spectrum of a spike-train autocorrelogram that
        is built only from the stretches of the recording selected by
        posMask.

        This is taken and adapted from ephysiopy.common.eegcalcs.EEGCalcs

        Parameters
        ----------
        spkTimes - np.array of times in seconds of the cells firing
        posMask - boolean array where True is stuff to keep; its length
                            multiplied by the number of autocorrelogram
                            bins per position sample sets the extent of
                            the spike histogram
                            (presumably one entry per position sample -
                            TODO confirm)
        maxFreq - the maximum frequency to do the power spectrum out to
        acBinSize - the bin size of the autocorrelogram in seconds
        acWindow - the range of the autocorr in seconds
        plot - if True the per-chunk autocorrelograms are drawn with
                            imshow into the current matplotlib axes
        kwargs - forwarded to self.power_spectrum

        Returns
        -------
        The value returned by self.power_spectrum, updated with the key
        "meanNormdAc" (the mean-subtracted summed autocorrelogram)

        NB Make sure all times are in seconds
        NB spkTimes and posMask default to None but are used
        unconditionally, so both have to be supplied
        NB plot is not passed on to self.power_spectrum, which is
        therefore called with its own default for that argument
        """
        # number of autocorrelogram bins that span one position sample
        acBinsPerPos = 1.0 / self.pos_sample_rate / acBinSize
        acWindowSizeBins = np.round(acWindow / acBinSize)
        # despite the name these values are used as histogram bin edges
        binCentres = np.arange(0.5, len(posMask) * acBinsPerPos) * acBinSize
        spkTrHist, _ = np.histogram(spkTimes, bins=binCentres)

        # split the single histogram into individual chunks at every
        # point where posMask changes value
        splitIdx = np.nonzero(np.diff(posMask.astype(int)))[0] + 1
        splitMask = np.split(posMask, splitIdx)
        splitSpkHist = np.split(spkTrHist,
                                (splitIdx * acBinsPerPos).astype(int))
        # keep a chunk only if its mask is all True, it holds more than
        # two spikes and it is longer than twice the autocorr window
        histChunks = []
        for i in range(len(splitSpkHist)):
            if np.all(splitMask[i]):
                if np.sum(splitSpkHist[i]) > 2:
                    if len(splitSpkHist[i]) > int(acWindowSizeBins) * 2:
                        histChunks.append(splitSpkHist[i])
        # one column per retained chunk, one row per lag (lag 0 included)
        autoCorrGrid = np.zeros((int(acWindowSizeBins) + 1, len(histChunks)))
        chunkLens = []
        from scipy import signal

        print(f"num chunks = {len(histChunks)}")
        for i in range(len(histChunks)):
            lenThisChunk = len(histChunks[i])
            chunkLens.append(lenThisChunk)
            # place the chunk in the middle of a zero buffer twice its
            # length, then convolve with the reversed chunk
            tmp = np.zeros(lenThisChunk * 2)
            tmp[lenThisChunk // 2:lenThisChunk // 2 +
                lenThisChunk] = histChunks[i]
            tmp2 = signal.fftconvolve(tmp, histChunks[i][::-1],
                                      mode="valid")  # the autocorrelation
            autoCorrGrid[:,
                         i] = (tmp2[lenThisChunk // 2:lenThisChunk // 2 +
                                    int(acWindowSizeBins) + 1] / acBinsPerPos)

        # sum over chunks, normalise by the total number of bins used,
        # then drop the first row (lag 0) and subtract the mean
        totalLen = np.sum(chunkLens)
        autoCorrSum = np.nansum(autoCorrGrid, 1) / totalLen
        meanNormdAc = autoCorrSum[1::] - np.nanmean(autoCorrSum[1::])
        # return meanNormdAc
        out = self.power_spectrum(
            eeg=meanNormdAc,
            binWidthSecs=acBinSize,
            maxFreq=maxFreq,
            pad2pow=16,
            **kwargs,
        )
        # relies on power_spectrum returning a dict-like object that has
        # a "Power" entry (read in the plotting branch below)
        out.update({"meanNormdAc": meanNormdAc})
        if plot:
            fig = plt.gcf()
            ax = fig.gca()
            # remember the current limits so they can be restored after
            # imshow has been called
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()
            ax.imshow(
                autoCorrGrid,
                extent=[
                    maxFreq * 0.6,
                    maxFreq,
                    np.max(out["Power"]) * 0.6,
                    ax.get_ylim()[1],
                ],
            )
            ax.set_ylim(ylim)
            ax.set_xlim(xlim)
        return out

    def power_spectrum(
        self,
        eeg,
        plot=True,
        binWidthSecs=None,
        maxFreq=25,
        pad2pow=None,
        ymax=None,
        **kwargs,
    ):
        """
        Compute a smoothed one-sided power spectrum and theta-band metrics.

        Method used by eeg_power_spectra and intrinsic_freq_autoCorr.
        The signal passed in must be mean normalised already.

        Parameters
        ----------
        eeg - 1d sequence - the (mean normalised) signal to transform
        plot - bool - if True draw the raw and smoothed spectra, the band
            limits and the peak on a new matplotlib figure
        binWidthSecs - float - the sample interval of eeg in seconds.
            Required; the Nyquist limit is 1 / binWidthSecs / 2
        maxFreq - float - the spectrum is cropped to frequencies <= maxFreq
        pad2pow - int - the FFT length is 2 ** pad2pow. If None the next
            power of two >= len(eeg) is used
        ymax - accepted for backward compatibility; not currently applied
            to the plot (the y-limit call is disabled)
        **kwargs - accepted and ignored

        Returns
        -------
        dict with keys:
            maxFreq - frequency of the smoothed peak inside self.thetaRange
            Power - the smoothed power spectrum
            Freqs - the frequencies matching Power / Power_raw
            s2n - mean smoothed power within +/- self.sn2Width / 2 of the
                peak divided by the mean smoothed power elsewhere
            Power_raw - the unsmoothed power spectrum
            k, kernelLen, kernelSig - the smoothing kernel, its length in
                bins and its sigma in bins
            binsPerHz - number of spectrum bins per Hz

        Raises
        ------
        ValueError - if binWidthSecs is missing or not positive, if eeg is
            empty, or if no frequency bin lies inside self.thetaRange

        Notes
        -----
        Reads self.smthKernelWidth, self.smthKernelSigma, self.thetaRange
        and self.sn2Width; stores self.freqs, self.power_sm and
        self.spectrumMaskPeak as a side effect.
        """
        # signal.gaussian was a deprecated alias that newer SciPy releases
        # no longer provide; the window lives in scipy.signal.windows
        from scipy.signal import fftconvolve
        from scipy.signal.windows import gaussian

        if binWidthSecs is None or binWidthSecs <= 0:
            raise ValueError(
                "binWidthSecs must be a positive sample interval in seconds")
        origLen = len(eeg)
        if origLen == 0:
            raise ValueError("eeg must contain at least one sample")

        # Get raw power spectrum
        nqLim = 1 / binWidthSecs / 2.0
        if pad2pow is None:
            # default to the next power of two that holds the whole signal
            pad2pow = int(np.ceil(np.log2(origLen)))
        fftLen = int(np.power(2, pad2pow))
        fftHalfLen = fftLen // 2 + 1

        fftRes = np.fft.fft(eeg, fftLen)
        # get power density from fft and discard second half of spectrum
        power = (np.power(np.abs(fftRes), 2) / origLen)[:fftHalfLen]
        # fold the discarded negative frequencies back in: every bin except
        # DC (first) and Nyquist (last) has a mirror image
        power[1:-1] = power[1:-1] * 2

        # calculate freqs and crop spectrum to requested range
        freqs = nqLim * np.linspace(0, 1, fftHalfLen)
        freqs = freqs[freqs <= maxFreq]
        power = power[0:len(freqs)]

        # smooth spectrum using gaussian kernel
        binsPerHz = (fftHalfLen - 1) / nqLim
        # window length must be a whole number of bins
        kernelLen = int(np.round(self.smthKernelWidth * binsPerHz))
        kernelSig = self.smthKernelSigma * binsPerHz
        k = gaussian(kernelLen, kernelSig) / (kernelLen / 2 / 2)
        power_sm = fftconvolve(power, k[::-1], mode="same")

        # calculate some metrics
        # find max in theta band
        spectrumMaskBand = np.logical_and(freqs > self.thetaRange[0],
                                          freqs < self.thetaRange[1])
        if not np.any(spectrumMaskBand):
            raise ValueError(
                "No frequency bins fall inside thetaRange; check "
                "binWidthSecs, maxFreq and pad2pow")
        bandPower = power_sm[spectrumMaskBand]
        maxBinInBand = np.argmax(bandPower)
        bandMaxPower = bandPower[maxBinInBand]
        freqAtBandMaxPower = freqs[spectrumMaskBand][maxBinInBand]

        # find power in small window around peak and divide by power in rest
        # of spectrum to get snr
        spectrumMaskPeak = np.logical_and(
            freqs > freqAtBandMaxPower - self.sn2Width / 2,
            freqs < freqAtBandMaxPower + self.sn2Width / 2,
        )
        s2n = np.nanmean(power_sm[spectrumMaskPeak]) / np.nanmean(
            power_sm[~spectrumMaskPeak])
        self.freqs = freqs
        self.power_sm = power_sm
        self.spectrumMaskPeak = spectrumMaskPeak
        if plot:
            fig = plt.figure()
            ax = fig.add_subplot(111)
            # raw spectrum in light grey underneath the smoothed one
            ax.plot(freqs, power, c=[0.9, 0.9, 0.9])
            ax.plot(freqs, power_sm, "k", lw=2)
            ax.axvline(self.thetaRange[0], c="b", ls="--")
            ax.axvline(self.thetaRange[1], c="b", ls="--")
            ax.stem([freqAtBandMaxPower], [bandMaxPower], linefmt="r")
            # shade the window used as "signal" in the s2n calculation
            ax.fill_between(
                freqs,
                0,
                power_sm,
                where=spectrumMaskPeak,
                color="r",
                alpha=0.25,
                zorder=25,
            )
            ax.set_xlabel("Frequency (Hz)")
            ax.set_ylabel("Power density (W/Hz)")
        out_dict = {
            "maxFreq": freqAtBandMaxPower,
            "Power": power_sm,
            "Freqs": freqs,
            "s2n": s2n,
            "Power_raw": power,
            "k": k,
            "kernelLen": kernelLen,
            "kernelSig": kernelSig,
            "binsPerHz": binsPerHz,
        }
        return out_dict
예제 #12
0
    def __init__(
        self,
        lfp_sig: np.array,
        lfp_fs: int,
        xy: np.array,
        spike_ts: np.array,
        pos_ts: np.array,
        pp_config: dict = phase_precession_config,
    ):
        """
        Set up the LFP, position and spike data used by this object.

        Parameters
        ----------
        lfp_sig - np.array - the LFP signal; stored as self.eeg and passed
            to LFPOscillations
        lfp_fs - int - passed to LFPOscillations together with lfp_sig
            (presumably the LFP sample rate - TODO confirm)
        xy - np.array - position data, indexed here as xy[0, :] and
            xy[1, :] which are handed to PosCalcsGeneric as x and y
        spike_ts - np.array - spike times; converted to position sample
            indices via self.getSpikePosIndices
        pos_ts - np.array - position timestamps, stored as self._pos_ts
        pp_config - dict - every key/value pair is copied onto the
            instance as an attribute

        NB self.ppm, self.cms_per_bin and self.field_smoothing_kernel_len
        are read below but not assigned in this method, so they are
        presumably supplied by pp_config - verify against the config dict
        """

        # Copy every configuration entry onto the instance as an attribute
        [setattr(self, k, pp_config[k]) for k in pp_config.keys()]

        self._pos_ts = pos_ts

        # Create a dict to hold the stats values
        stats_dict = {
            "values": None,
            "pha": None,
            "slope": None,
            "intercept": None,
            "cor": None,
            "p": None,
            "cor_boot": None,
            "p_shuffled": None,
            "ci": None,
            "reg": None,
        }
        # Create a dict of regressors to hold stat values
        # for each regressor
        from collections import defaultdict

        self.regressors = {}
        # any missing key gets its own shallow copy of stats_dict
        self.regressors = defaultdict(lambda: stats_dict.copy(),
                                      self.regressors)
        regressor_keys = [
            "spk_numWithinRun",
            "pos_exptdRate_cum",
            "pos_instFR",
            "pos_timeInRun",
            "pos_d_cum",
            "pos_d_meanDir",
            "pos_d_currentdir",
            "spk_thetaBatchLabelInRun",
        ]
        # looking each key up is what creates its entry in the defaultdict
        [self.regressors[k] for k in regressor_keys]
        # each of the regressors in regressor_keys is now a key with a
        # value that is a copy of stats_dict

        # These are assigned after the pp_config attributes above, so they
        # replace any config entries with the same names
        self.k = 1000
        self.alpha = 0.05
        self.hyp = 0
        self.conf = True

        # Process the EEG data a bit...
        self.eeg = lfp_sig
        L = LFPOscillations(lfp_sig, lfp_fs)
        # NOTE(review): [6, 12] looks like a frequency band in Hz and 2 a
        # filter setting - confirm against LFPOscillations.getFreqPhase
        filt_sig, phase, _, _ = L.getFreqPhase(lfp_sig, [6, 12], 2)
        self.filteredEEG = filt_sig
        self.phase = phase
        self.phaseAdj = None

        # ... and the position data
        P = PosCalcsGeneric(
            xy[0, :],
            xy[1, :],
            ppm=self.ppm,
            cm=True,
        )
        P.postprocesspos(tracker_params={"AxonaBadValue": 1023})
        # ... do the ratemap creation here once
        R = RateMap(P.xy, P.dir, P.speed)
        R.cmsPerBin = self.cms_per_bin
        R.smooth_sz = self.field_smoothing_kernel_len
        R.ppm = self.ppm
        spk_times_in_pos_samples = self.getSpikePosIndices(spike_ts)
        # Count the spikes that fall in each position sample.
        # NOTE(review): self.pos_ts is read here but only self._pos_ts is
        # assigned above - presumably a property defined elsewhere on the
        # class; confirm
        spk_weights = np.bincount(spk_times_in_pos_samples,
                                  minlength=len(self.pos_ts))
        self.spk_times_in_pos_samples = spk_times_in_pos_samples
        self.spk_weights = spk_weights
        self.RateMap = R  # this will be used a fair bit below

        self.spike_ts = spike_ts