def __init__(
    self,
    spike_times: np.array,
    pos_times: np.array,
    spk_clusters: np.array,
    x: np.array,
    y: np.array,
    tracker_params=None,
):
    """
    Parameters
    ----------
    spike_times - 1d np.array
    pos_times - 1d np.array
    spk_clusters - 1d np.array
    x and y - 1d np.array
    tracker_params - dict - from the PosTracker as created in
        OEKiloPhy.Settings.parse; defaults to a new empty dict

    NB All timestamps should be given in sub-millisecond accurate seconds
    and pos_xy in cms
    """
    if tracker_params is None:
        # A fresh dict per call: the previous ``{}`` default was one object
        # shared by every instance.
        tracker_params = {}
    self.spike_times = spike_times
    self.pos_times = pos_times
    self.spk_clusters = spk_clusters
    # There can be more spikes than pos samples in terms of sampling as the
    # open-ephys buffer probably needs to finish writing and the camera has
    # already stopped, so cut off any cluster indices and spike times that
    # exceed the length of the pos indices
    idx_to_keep = self.spike_times < self.pos_times[-1]
    self.spike_times = self.spike_times[idx_to_keep]
    self.spk_clusters = self.spk_clusters[idx_to_keep]
    self._pos_sample_rate = 30
    self._spk_sample_rate = 3e4
    self._pos_samples_for_spike = None
    self._min_runlength = 0.4  # in seconds
    self.posCalcs = PosCalcsGeneric(
        x, y, 230, cm=True, jumpmax=100, tracker_params=tracker_params)
    # NOTE(review): SpikeCalcsGeneric receives the *untrimmed* spike times and
    # clusters, unlike self.spike_times / self.spk_clusters above - confirm
    self.spikeCalcs = SpikeCalcsGeneric(spike_times)
    self.spikeCalcs.spk_clusters = spk_clusters
    self.posCalcs.postprocesspos(tracker_params)
    xy = self.posCalcs.xy
    hdir = self.posCalcs.dir
    self.posCalcs.calcSpeed(xy)
    self._xy = xy
    self._hdir = hdir
    self._speed = self.posCalcs.speed
    # TEMPORARY FOR POWER SPECTRUM STUFF
    self.smthKernelWidth = 2
    self.smthKernelSigma = 0.1875
    self.sn2Width = 2
    self.thetaRange = [7, 11]
    self.xmax = 11
def exportPos(self, ppm=300, jumpmax=100, as_text=False):
    """
    Post-process the trial's position data and export it, either as plain
    text or as an Axona .pos file.

    Parameters
    ----------
    ppm : int
        Pixels per metre; passed to PosCalcsGeneric and written into the
        .pos header.
    jumpmax : int
        Passed to PosCalcsGeneric for its jump filtering.
    as_text : bool
        If exactly ``True``, write ``self.OE_data.xy`` to
        ``<axona_root_name>.txt`` and return without writing a .pos file.
        NOTE(review): the text export writes the raw ``OE_data.xy``, not the
        post-processed positions - confirm this is intended.
    """
    # Step 1) Deal with the position data first: grab the settings of the
    # pos tracker and do some post-processing on the position data
    # (discard jumpy data, do some smoothing etc)
    # NOTE(review): ``settings`` is never used below - everything reads
    # ``self.settings``. The call is kept in case constructing it has side
    # effects; confirm whether ``self.settings = ...`` was intended.
    settings = OESettings.Settings(os.path.join(self.dirname, 'settings.xml'))
    self.settings.parsePos()
    posProcessor = PosCalcsGeneric(
        self.OE_data.xy[:, 0], self.OE_data.xy[:, 1], ppm, True, jumpmax)
    print("Post-processing position data...")
    xy, _ = posProcessor.postprocesspos(self.settings.tracker_params)
    xy = xy.T
    if as_text is True:
        print("Beginning export of position data to text format...")
        pos_file_name = self.axona_root_name + ".txt"
        np.savetxt(pos_file_name, self.OE_data.xy, fmt='%1.u')
        print("Completed export of position data")
        return
    # Do the upsampling of both xy and the timestamps
    print("Beginning export of position data to Axona format...")
    axona_pos_data = self.convertPosData(xy, self.OE_data.xyTS)
    # Whole seconds of recording; the header below declares a 50 Hz rate
    duration = int(self.last_pos_ts - self.first_pos_ts)
    # make sure pos data length is same as duration * num_samples
    axona_pos_data = axona_pos_data[0:duration * 50]
    # Create an empty header for the pos data and fill in the arena borders
    # from the tracker settings
    pos_header = self.AxonaData.getEmptyHeader("pos")
    tracker_params = self.settings.tracker_params
    border_params = (
        ('min_x', 'LeftBorder'),
        ('min_y', 'TopBorder'),
        ('max_x', 'RightBorder'),
        ('max_y', 'BottomBorder'),
    )
    for key in pos_header.keys():
        for fragment, param in border_params:
            if fragment in key:
                pos_header[key] = str(tracker_params[param])
    pos_header['duration'] = str(duration)
    # Rest of this stuff probably won't change so should be defaulted in the
    # loaded file (see axonaIO.py)
    pos_header['num_colours'] = '4'
    pos_header['sw_version'] = '1.2.2.1'
    pos_header['timebase'] = '50 hz'
    pos_header['sample_rate'] = '50.0 hz'
    pos_header['pos_format'] = 't,x1,y1,x2,y2,numpix1,numpix2'
    pos_header['bytes_per_coord'] = '2'
    pos_header['EEG_samples_per_position'] = '5'
    pos_header['bytes_per_timestamp'] = '4'
    pos_header['pixels_per_metre'] = str(ppm)
    pos_header['num_pos_samples'] = str(len(axona_pos_data))
    pos_header['bearing_colour_1'] = '210'
    pos_header['bearing_colour_2'] = '30'
    pos_header['bearing_colour_3'] = '0'
    pos_header['bearing_colour_4'] = '0'
    self.writePos2AxonaFormat(pos_header, axona_pos_data)
    print("Exported position data to Axona format")
def loadPos(self, *args, **kwargs):
    """
    Load the recording start time and the position data for this trial.

    The start time is parsed from ``self.sync_message_file`` (when set): on
    the line containing ``subProcessor: 0`` the value after
    ``start time: `` and before ``@`` is used. Position data are read from
    ``data_array.npy`` and ``timestamps.npy`` in ``self.path2PosData``,
    post-processed with PosCalcsGeneric and stored on the instance as
    ``xy``, ``dir``, ``speed``, ``xyTS`` (seconds), ``pos_sample_rate``,
    ``orig_x``, ``orig_y`` and ``PosCalcs``. ``self.recording_start_time``
    is always set.

    Only sub-class that doesn't use this is OpenEphysNWB which needs
    updating.
    TODO: Update / overhaul OpenEphysNWB

    Raises
    ------
    ValueError
        If ``self.pos_data_type`` is neither "PosTracker" nor
        "TrackingPlugin".
    """
    # Load the start time from the sync_messages file
    recording_start_time = 0
    if self.sync_message_file is not None:
        with open(self.sync_message_file, "r") as f:
            sync_lines = f.read().split("\n")
        for line in sync_lines:
            if "subProcessor: 0" in line:
                idx = line.find("start time: ")
                # drop the line's final character, keep the part before "@"
                start_val = line[idx + len("start time: "):-1]
                recording_start_time = float(start_val.split("@")[0])
    if self.path2PosData is not None:
        pos_data_type = getattr(self, "pos_data_type", "PosTracker")
        data_file = os.path.join(self.path2PosData, "data_array.npy")
        if pos_data_type == "PosTracker":
            print("Loading PosTracker data...")
            pos_data = np.load(data_file)
        elif pos_data_type == "TrackingPlugin":
            print("Loading Tracking Plugin data...")
            pos_data = loadTrackingPluginData(data_file)
        else:
            # previously fell through and failed later with UnboundLocalError
            raise ValueError(
                f"Unrecognised pos_data_type {pos_data_type!r}; expected "
                "'PosTracker' or 'TrackingPlugin'")
        pos_ts = np.load(os.path.join(self.path2PosData, "timestamps.npy"))
        pos_ts = np.ravel(pos_ts)
        # presumably timestamps are sample counts at pos_timebase Hz
        pos_timebase = getattr(self, "pos_timebase", 3e4)
        sample_rate = np.floor(1 / np.mean(np.diff(pos_ts) / pos_timebase))
        self.xyTS = (pos_ts - recording_start_time) / pos_timebase  # seconds
        if self.sync_message_file is not None:
            # NOTE(review): from here on recording_start_time is the first
            # pos timestamp in seconds rather than the sync value - confirm
            recording_start_time = self.xyTS[0]
        self.pos_sample_rate = sample_rate
        self.orig_x = pos_data[:, 0]
        self.orig_y = pos_data[:, 1]
        P = PosCalcsGeneric(
            pos_data[:, 0],
            pos_data[:, 1],
            cm=True,
            ppm=self.ppm,
            jumpmax=self.jumpmax,
        )
        P.postprocesspos({"SampleRate": sample_rate})
        setattr(self, "PosCalcs", P)
        self.xy = P.xy
        self.dir = P.dir
        self.speed = P.speed
    else:
        warnings.warn(
            "Could not find the pos data. Make sure there is a pos_data "
            "folder with data_array.npy and timestamps.npy in it")
    self.recording_start_time = recording_start_time
def exportPos(self, ppm=300, jumpmax=100, as_text=False):
    """
    Post-process the trial's position data and export it, either as plain
    text or as an Axona .pos file.

    Parameters
    ----------
    ppm : int
        Pixels per metre; passed to PosCalcsGeneric and written into the
        .pos header.
    jumpmax : int
        Passed to PosCalcsGeneric for its jump filtering.
    as_text : bool
        If exactly ``True``, write ``self.OE_data.xy`` to
        ``<axona_root_name>.txt`` and return without writing a .pos file.
        NOTE(review): the text export writes the raw ``OE_data.xy``, not the
        post-processed positions - confirm this is intended.

    Side effects
    ------------
    Sets ``self.settings.tracker_params["AxonaBadValue"] = 1023``.
    """
    # Step 1) Deal with the position data first: grab the settings of the
    # pos tracker and do some post-processing on the position data
    # (discard jumpy data, do some smoothing etc)
    self.settings.parse()
    posProcessor = PosCalcsGeneric(
        self.OE_data.xy[:, 0], self.OE_data.xy[:, 1], ppm, True, jumpmax)
    print("Post-processing position data...")
    self.settings.tracker_params["AxonaBadValue"] = 1023
    posProcessor.postprocesspos(self.settings.tracker_params)
    xy = posProcessor.xy.T
    if as_text is True:
        print("Beginning export of position data to text format...")
        pos_file_name = self.axona_root_name + ".txt"
        np.savetxt(pos_file_name, self.OE_data.xy, fmt="%1.u")
        print("Completed export of position data")
        return
    # Do the upsampling of both xy and the timestamps
    print("Beginning export of position data to Axona format...")
    axona_pos_data = self.convertPosData(xy, self.OE_data.xyTS)
    duration = int(self.last_pos_ts - self.first_pos_ts)
    # make sure pos data length is same as duration * num_samples (50 Hz)
    axona_pos_data = axona_pos_data[0:duration * 50]
    # Create an empty header for the pos data
    from ephysiopy.dacq2py.axona_headers import PosHeader

    pos_header = PosHeader()
    tracker_params = self.settings.tracker_params
    # The y/max keys previously had a stray leading "." (".min_y" etc.),
    # which added bogus entries instead of setting the real ones. Assumes
    # PosHeader.pos uses undotted keys, as the "min_x" line always did.
    pos_header.pos["min_x"] = str(tracker_params["LeftBorder"])
    pos_header.pos["min_y"] = str(tracker_params["TopBorder"])
    pos_header.pos["max_x"] = str(tracker_params["RightBorder"])
    pos_header.pos["max_y"] = str(tracker_params["BottomBorder"])
    pos_header.common["duration"] = str(duration)
    pos_header.pos["pixels_per_metre"] = str(ppm)
    pos_header.pos["num_pos_samples"] = str(len(axona_pos_data))
    self.writePos2AxonaFormat(pos_header, axona_pos_data)
    print("Exported position data to Axona format")
def basic_PosCalcs(basic_xy):
    """
    Build a PosCalcsGeneric instance from the ``basic_xy`` data.

    ``basic_xy[0]`` supplies the x coordinates and ``basic_xy[1]`` the y
    coordinates (described as random-walk data); a fixed scale of 300
    pixels per metre is used.
    """
    pixels_per_metre = 300
    return PosCalcsGeneric(basic_xy[0], basic_xy[1], pixels_per_metre)
def plotMapsOneAtATime(self, plot_type='map', **kwargs):
    """
    Build a MapCalcsGeneric object for the requested clusters and iterate
    over it, storing it as ``self.mapiter``.

    Parameters
    ----------
    plot_type : str or list
        The kind of plot to produce. Valid strings include:
        * 'map' - just ratemap plotted
        * 'path' - just spikes on path
        * 'both' - both of the above
        * 'all' - both spikes on path, ratemap & SAC plotted
    kwargs :
        * 'ppm' - Integer denoting pixels per metre where lower values =
          more bins in ratemap / SAC (default 400)
        * 'clusters' - int or list of ints describing which clusters to
          plot; intersected with the kilosort "good" clusters
        * 'save_grid_summary_location' - bool; if True the dictionary
          returned from gridcell.SAC.getMeasures is saved for each cluster
        All kwargs are also passed to ``__loaddata__`` (when ``self.xy`` is
        not loaded yet) and to MapCalcsGeneric.
    """
    if self.kilodata is None:
        self.loadKilo()
    ppm = kwargs.get('ppm', 400)
    from ephysiopy.common.ephys_generic import PosCalcsGeneric, MapCalcsGeneric
    if self.xy is None:
        self.__loaddata__(**kwargs)
    posProcessor = PosCalcsGeneric(
        self.xy[:, 0], self.xy[:, 1], ppm, jumpmax=self.jumpmax)
    self.__loadSettings__()
    xy, hdir = posProcessor.postprocesspos(self.settings.tracker_params)
    self.hdir = hdir
    # spike times divided by 3e4 (presumably a 30 kHz sample clock) and
    # offset by the recording start time
    spk_times = (self.kilodata.spk_times.T / 3e4) + self.recording_start_time
    mapiter = MapCalcsGeneric(
        xy, np.squeeze(hdir), posProcessor.speed, self.xyTS, spk_times,
        plot_type, **kwargs)
    if 'clusters' in kwargs:
        # atleast_1d treats a single id and a list of ids uniformly
        mapiter.good_clusters = np.intersect1d(
            np.atleast_1d(kwargs['clusters']), self.kilodata.good_clusters)
    else:
        mapiter.good_clusters = self.kilodata.good_clusters
    mapiter.spk_clusters = self.kilodata.spk_clusters
    self.mapiter = mapiter
    # Iterating mapiter presumably drives the per-cluster plotting; a blank
    # line is printed after each cluster.
    for _ in mapiter:
        print("")
def prepareMaps(self, **kwargs):
    """Initialises a MapCalcsGeneric object by providing it with positional
    and spiking data.

    I don't like the name of this method but it is useful to be able to
    separate out the preparation of the MapCalcsGeneric object as there are
    two major uses; actually plotting the maps and/ or extracting data from
    them without plotting

    Keyword arguments
    -----------------
    ppm : int
        Pixels per metre (default 400).
    plot_type : str or list
        Passed to MapCalcsGeneric (default 'map'). It is removed from the
        remaining kwargs so it is not passed twice.
    cluster / clusters : int or list of ints
        Clusters to keep, intersected with the kilosort "good" clusters.
        ``cluster`` takes precedence; ``clusters`` is accepted for
        consistency with plotMapsOneAtATime.
    Everything else is passed to ``__loaddata__`` (when ``self.xy`` is not
    loaded yet) and to MapCalcsGeneric.

    Returns
    -------
    The MapCalcsGeneric instance (also stored as ``self.mapiter``).
    """
    if self.kilodata is None:
        self.loadKilo()
    ppm = kwargs.get('ppm', 400)
    from ephysiopy.common.ephys_generic import PosCalcsGeneric, MapCalcsGeneric
    if self.xy is None:
        self.__loaddata__(**kwargs)
    posProcessor = PosCalcsGeneric(
        self.xy[:, 0], self.xy[:, 1], ppm, jumpmax=self.jumpmax)
    self.__loadSettings__()
    xy, hdir = posProcessor.postprocesspos(self.settings.tracker_params)
    self.hdir = hdir
    # spike times divided by 3e4 (presumably a 30 kHz sample clock) and
    # offset by the recording start time
    spk_times = (self.kilodata.spk_times.T / 3e4) + self.recording_start_time
    # pop, not get: plot_type is passed positionally below, so leaving it in
    # kwargs made MapCalcsGeneric receive it twice (TypeError)
    plot_type = kwargs.pop('plot_type', 'map')
    mapiter = MapCalcsGeneric(
        xy, np.squeeze(hdir), posProcessor.speed, self.xyTS, spk_times,
        plot_type, **kwargs)
    for key in ('cluster', 'clusters'):
        if key in kwargs:
            # atleast_1d treats a single id and a list of ids uniformly
            mapiter.good_clusters = np.intersect1d(
                np.atleast_1d(kwargs[key]), self.kilodata.good_clusters)
            break
    else:
        mapiter.good_clusters = self.kilodata.good_clusters
    mapiter.spk_clusters = self.kilodata.spk_clusters
    self.mapiter = mapiter
    return mapiter
def load(self, *args, **kwargs):
    """
    Load the trial data.

    Minimally, there should be at least a .set file (``self.settings``).
    Other files (.eeg, .pos, .stm, .1, .2 etc) are essentially optional.

    ``self.ppm`` is taken from the .set file's ``tracker_pixels_per_metre``
    entry unless a ``ppm`` keyword argument overrides it. If ``self.xy`` is
    not yet loaded, the .pos file is read and post-processed; on IOError a
    message is printed instead.
    """
    if self.settings is not None:
        print("Loaded .set file")
    set_file_ppm = int(self.settings["tracker_pixels_per_metre"])
    self.ppm = kwargs.get("ppm", set_file_ppm)

    # ------------- Pos data -------------
    if self.xy is None:
        try:
            axona_pos = axonaIO.Pos(self.common_name)
            pos_calcs = PosCalcsGeneric(
                axona_pos.led_pos[:, 0],
                axona_pos.led_pos[:, 1],
                cm=True,
                ppm=self.ppm,
            )
            pos_calcs.postprocesspos(tracker_params={"AxonaBadValue": 1023})
            self.xy = pos_calcs.xy
            # timestamps re-referenced to the first pos sample
            self.xyTS = axona_pos.ts - axona_pos.ts[0]
            self.dir = pos_calcs.dir
            self.speed = pos_calcs.speed
            self.pos_sample_rate = axona_pos.getHeaderVal(
                axona_pos.header, "sample_rate")
            print("Loaded .pos file")
        except IOError:
            print("Couldn't load the pos data")
def plotPos(self, jumpmax=None, show=True, **kwargs):
    """
    Plots x vs y position for the current trial

    Parameters
    ----------
    jumpmax : int
        The max amount the LED is allowed to instantaneously move; defaults
        to ``self.jumpmax``
    show : bool
        Whether to plot the pos into the current figure and return its axes
        (default True)
    kwargs :
        * 'saveas' - if given, the plot is saved there with ``plt.savefig``
        All kwargs are also passed to ``__loaddata__`` when ``self.xy`` has
        not been loaded yet.

    Returns
    ----------
    (ax, xy) when show is True, otherwise xy
        ax is the matplotlib axes plotted into; xy is the positional data
        following post-processing
    """
    if jumpmax is None:
        jumpmax = self.jumpmax
    import matplotlib.pylab as plt
    from ephysiopy.common.ephys_generic import PosCalcsGeneric
    self.__loadSettings__()
    if self.xy is None:
        self.__loaddata__(**kwargs)
    posProcessor = PosCalcsGeneric(
        self.xy[:, 0], self.xy[:, 1], ppm=300, cm=True, jumpmax=jumpmax)
    xy, hdir = posProcessor.postprocesspos(self.settings.tracker_params)
    self.hdir = hdir
    save = 'saveas' in kwargs
    if save or show:
        # Plot and flip y exactly once: invert_yaxis() toggles the current
        # orientation, so the old code (plotting/inverting separately for
        # saving and for showing) un-inverted the axes when both were set.
        plt.plot(xy[0], xy[1])
        plt.gca().invert_yaxis()
    if save:
        plt.savefig(kwargs['saveas'])
    if show:
        return plt.gca(), xy
    return xy
def load_pos_data(self, pname: Path, ppm: int = 300, jumpmax: int = 100) -> None:
    """
    Load the Axona position data for this trial into ``self.PosCalcs``.

    Does nothing if ``self.PosCalcs`` is already set. If loading raises an
    IOError a message is printed and ``self.PosCalcs`` is left unchanged.

    NOTE(review): the ``pname``, ``ppm`` and ``jumpmax`` arguments are
    currently ignored - the file comes from ``self.pname`` and the scale
    from ``self.ppm``. Confirm which is intended before wiring them through.
    """
    if self.PosCalcs is not None:
        return
    try:
        AxonaPos = Pos(self.pname)
        P = PosCalcsGeneric(
            AxonaPos.led_pos[0, :],
            AxonaPos.led_pos[1, :],
            cm=True,
            ppm=self.ppm,
        )
        # timestamps of the loaded file (previously read from the Pos class)
        P.xyTS = AxonaPos.ts
        P.sample_rate = AxonaPos.getHeaderVal(AxonaPos.header, "sample_rate")
        P.postprocesspos()
        print("Loaded pos data")
        self.PosCalcs = P
    except IOError:
        print("Couldn't load the pos data")
class CosineDirectionalTuning(object):
    """
    Produces output to do with Welday et al (2011) like analysis
    of rhythmic firing a la oscillatory interference model
    """

    def __init__(
        self,
        spike_times: np.array,
        pos_times: np.array,
        spk_clusters: np.array,
        x: np.array,
        y: np.array,
        tracker_params=None,
    ):
        """
        Parameters
        ----------
        spike_times - 1d np.array
        pos_times - 1d np.array
        spk_clusters - 1d np.array
        x and y - 1d np.array
        tracker_params - dict - from the PosTracker as created in
            OEKiloPhy.Settings.parse; defaults to a new empty dict

        NB All timestamps should be given in sub-millisecond accurate
        seconds and pos_xy in cms
        """
        if tracker_params is None:
            # A fresh dict per call: the previous ``{}`` default was one
            # object shared by every instance.
            tracker_params = {}
        self.spike_times = spike_times
        self.pos_times = pos_times
        self.spk_clusters = spk_clusters
        # There can be more spikes than pos samples in terms of sampling as
        # the open-ephys buffer probably needs to finish writing and the
        # camera has already stopped, so cut off any cluster indices and
        # spike times that exceed the length of the pos indices
        idx_to_keep = self.spike_times < self.pos_times[-1]
        self.spike_times = self.spike_times[idx_to_keep]
        self.spk_clusters = self.spk_clusters[idx_to_keep]
        self._pos_sample_rate = 30
        self._spk_sample_rate = 3e4
        self._pos_samples_for_spike = None
        self._min_runlength = 0.4  # in seconds
        self.posCalcs = PosCalcsGeneric(
            x, y, 230, cm=True, jumpmax=100, tracker_params=tracker_params)
        # NOTE(review): SpikeCalcsGeneric receives the *untrimmed* spike
        # times and clusters, unlike self.spike_times / self.spk_clusters
        self.spikeCalcs = SpikeCalcsGeneric(spike_times)
        self.spikeCalcs.spk_clusters = spk_clusters
        self.posCalcs.postprocesspos(tracker_params)
        xy = self.posCalcs.xy
        hdir = self.posCalcs.dir
        self.posCalcs.calcSpeed(xy)
        self._xy = xy
        self._hdir = hdir
        self._speed = self.posCalcs.speed
        # TEMPORARY FOR POWER SPECTRUM STUFF
        self.smthKernelWidth = 2
        self.smthKernelSigma = 0.1875
        self.sn2Width = 2
        self.thetaRange = [7, 11]
        self.xmax = 11

    @property
    def spk_sample_rate(self):
        return self._spk_sample_rate

    @spk_sample_rate.setter
    def spk_sample_rate(self, value):
        self._spk_sample_rate = value

    @property
    def pos_sample_rate(self):
        return self._pos_sample_rate

    @pos_sample_rate.setter
    def pos_sample_rate(self, value):
        self._pos_sample_rate = value

    @property
    def min_runlength(self):
        return self._min_runlength

    @min_runlength.setter
    def min_runlength(self, value):
        self._min_runlength = value

    @property
    def xy(self):
        return self._xy

    @xy.setter
    def xy(self, value):
        self._xy = value

    @property
    def hdir(self):
        return self._hdir

    @hdir.setter
    def hdir(self, value):
        self._hdir = value

    @property
    def speed(self):
        return self._speed

    @speed.setter
    def speed(self, value):
        self._speed = value

    @property
    def pos_samples_for_spike(self):
        return self._pos_samples_for_spike

    @pos_samples_for_spike.setter
    def pos_samples_for_spike(self, value):
        self._pos_samples_for_spike = value

    def _rolling_window(self, a: np.array, window: int):
        """
        Return a strided view of ``a`` whose last axis holds every
        consecutive window of length ``window``.

        Totally nabbed from SO:
        https://stackoverflow.com/questions/6811183/rolling-window-for-1d-arrays-in-numpy
        """
        shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
        strides = a.strides + (a.strides[-1], )
        return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

    def getPosIndices(self):
        """Set pos_samples_for_spike: the pos sample index of every spike
        (spike time * pos sample rate, floored)."""
        self.pos_samples_for_spike = np.floor(
            self.spike_times * self.pos_sample_rate).astype(int)

    def getClusterPosIndices(self, cluster: int) -> np.array:
        """Pos sample indices for the spikes of ``cluster``, clipped to the
        last valid pos sample."""
        if self.pos_samples_for_spike is None:
            self.getPosIndices()
        cluster_pos_indices = self.pos_samples_for_spike[
            self.spk_clusters == cluster]
        cluster_pos_indices[cluster_pos_indices >= len(self.pos_times)] = (
            len(self.pos_times) - 1)
        return cluster_pos_indices

    def getClusterSpikeTimes(self, cluster: int):
        """Spike times (seconds) of ``cluster``; also ensures
        pos_samples_for_spike has been computed."""
        ts = self.spike_times[self.spk_clusters == cluster]
        if self.pos_samples_for_spike is None:
            self.getPosIndices()
        return ts

    def getDirectionalBinPerPosition(self, binwidth: int):
        """
        Direction is in degrees as that what is created by me in some of the
        other bits of this package.

        Parameters
        ----------
        binwidth : int - the bin width in degrees

        Outputs
        -------
        A digitization of which directional bin each position sample belongs
        to (np.digitize with bin edges 0, binwidth, ... < 360, so directions
        in [0, binwidth) map to 1)
        """
        bins = np.arange(0, 360, binwidth)
        return np.digitize(self.hdir, bins)

    def getDirectionalBinForCluster(self, cluster: int):
        """45-degree directional bin of the position sample at each of
        ``cluster``'s spikes."""
        b = self.getDirectionalBinPerPosition(45)
        cluster_pos = self.getClusterPosIndices(cluster)
        return b[cluster_pos]

    def getRunsOfMinLength(self):
        """
        Identifies runs of at least self.min_runlength seconds long, which
        at 30Hz pos sampling rate equals 12 samples, and returns the start
        and end indices at which the run was occurred and the directional
        bin that run belongs to

        Returns
        -------
        np.array - nRuns x 3; each row is (start index, end index
        (exclusive), directional bin) into the position samples
        """
        b = self.getDirectionalBinPerPosition(45)
        # nabbed from SO
        from itertools import groupby

        grouped_runs = [(k, sum(1 for i in g)) for k, g in groupby(b)]
        grouped_runs = np.array(grouped_runs)
        run_start_indices = np.cumsum(grouped_runs[:, 1]) - grouped_runs[:, 1]
        min_len_in_samples = int(self.pos_sample_rate * self.min_runlength)
        min_len_runs_mask = grouped_runs[:, 1] >= min_len_in_samples
        ret = np.array([
            run_start_indices[min_len_runs_mask],
            grouped_runs[min_len_runs_mask, 1],
        ]).T
        # ret is (start, length); add end = start + length after the start
        # column and the directional bin after that, then drop the length
        ret = np.insert(ret, 1, np.sum(ret, 1), 1)
        ret = np.insert(ret, 2, grouped_runs[min_len_runs_mask, 0], 1)
        return ret[:, 0:3]

    def speedFilterRuns(self, runs: np.array, minspeed=5.0):
        """
        Given the runs identified in getRunsOfMinLength, filter for speed
        and return runs that meet the min speed criteria

        The function goes over the runs with a moving window of length equal
        to self.min_runlength in samples and sees if any of those segments
        meets the speed criteria and splits them out into separate runs if
        true

        NB For now this means the same spikes might get included in the
        autocorrelation procedure later as the moving window will use
        overlapping periods - can be modified later

        Parameters
        ----------
        runs - nRuns x 3 np.array generated from getRunsOfMinLength
        minspeed - float - min running speed in cm/s for an epoch (minimum
                   epoch length defined previously in getRunsOfMinLength as
                   minlength, usually 0.4s)

        Returns
        -------
        None - work in progress: currently it only prints "got one" for each
        run containing a qualifying window; building the filtered run array
        described above is not implemented yet.
        """
        minlength_in_samples = int(self.pos_sample_rate * self.min_runlength)
        run_list = runs.tolist()
        all_speed = np.array(self.speed)
        for start_idx, end_idx, dir_bin in run_list:
            this_runs_speed = all_speed[start_idx:end_idx]
            this_runs_runs = self._rolling_window(
                this_runs_speed, minlength_in_samples)
            run_mask = np.all(this_runs_runs > minspeed, 1)
            if np.any(run_mask):
                print("got one")

    def intrinsic_freq_autoCorr(
        self,
        spkTimes=None,
        posMask=None,
        maxFreq=25,
        acBinSize=0.002,
        acWindow=0.5,
        plot=True,
        **kwargs,
    ):
        """
        This is taken and adapted from ephysiopy.common.eegcalcs.EEGCalcs

        Parameters
        ----------
        spkTimes - np.array of times in seconds of the cells firing
        posMask - boolean array with one entry per position sample; True
                  marks samples to keep
        maxFreq - the maximum frequency to do the power spectrum out to
        acBinSize - the bin size of the autocorrelogram in seconds
        acWindow - the range of the autocorr in seconds
        plot - if True, plot the power spectrum (via power_spectrum) and
               overlay the autocorrelation grid; if False nothing is plotted
        kwargs - passed on to power_spectrum

        Returns
        -------
        dict - the power_spectrum output plus "meanNormdAc"

        NB Make sure all times are in seconds
        """
        acBinsPerPos = 1.0 / self.pos_sample_rate / acBinSize
        acWindowSizeBins = np.round(acWindow / acBinSize)
        binCentres = np.arange(0.5, len(posMask) * acBinsPerPos) * acBinSize
        spkTrHist, _ = np.histogram(spkTimes, bins=binCentres)

        # split the single histogram into individual chunks
        splitIdx = np.nonzero(np.diff(posMask.astype(int)))[0] + 1
        splitMask = np.split(posMask, splitIdx)
        splitSpkHist = np.split(spkTrHist,
                                (splitIdx * acBinsPerPos).astype(int))
        histChunks = []
        for i in range(len(splitSpkHist)):
            if np.all(splitMask[i]):
                if np.sum(splitSpkHist[i]) > 2:
                    if len(splitSpkHist[i]) > int(acWindowSizeBins) * 2:
                        histChunks.append(splitSpkHist[i])
        autoCorrGrid = np.zeros((int(acWindowSizeBins) + 1, len(histChunks)))
        chunkLens = []
        from scipy import signal

        print(f"num chunks = {len(histChunks)}")
        for i in range(len(histChunks)):
            lenThisChunk = len(histChunks[i])
            chunkLens.append(lenThisChunk)
            tmp = np.zeros(lenThisChunk * 2)
            tmp[lenThisChunk // 2:lenThisChunk // 2 + lenThisChunk] = histChunks[i]
            tmp2 = signal.fftconvolve(tmp, histChunks[i][::-1], mode="valid")
            # the autocorrelation, lags 0..acWindowSizeBins
            autoCorrGrid[:, i] = (
                tmp2[lenThisChunk // 2:lenThisChunk // 2
                     + int(acWindowSizeBins) + 1] / acBinsPerPos)

        totalLen = np.sum(chunkLens)
        autoCorrSum = np.nansum(autoCorrGrid, 1) / totalLen
        meanNormdAc = autoCorrSum[1::] - np.nanmean(autoCorrSum[1::])
        # Forward ``plot`` so plot=False really suppresses plotting
        # (power_spectrum plots by default).
        out = self.power_spectrum(
            eeg=meanNormdAc,
            plot=plot,
            binWidthSecs=acBinSize,
            maxFreq=maxFreq,
            pad2pow=16,
            **kwargs,
        )
        out.update({"meanNormdAc": meanNormdAc})
        if plot:
            fig = plt.gcf()
            ax = fig.gca()
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()
            ax.imshow(
                autoCorrGrid,
                extent=[
                    maxFreq * 0.6,
                    maxFreq,
                    np.max(out["Power"]) * 0.6,
                    ax.get_ylim()[1],
                ],
            )
            ax.set_ylim(ylim)
            ax.set_xlim(xlim)
        return out

    def power_spectrum(
        self,
        eeg,
        plot=True,
        binWidthSecs=None,
        maxFreq=25,
        pad2pow=None,
        ymax=None,
        **kwargs,
    ):
        """
        Method used by eeg_power_spectra and intrinsic_freq_autoCorr
        Signal in must be mean normalised already

        Parameters
        ----------
        eeg - 1d signal
        plot - whether to draw the spectrum into a new figure
        binWidthSecs - sample spacing of ``eeg`` in seconds
        maxFreq - spectrum is cropped to frequencies <= maxFreq
        pad2pow - FFT length is 2 ** pad2pow; if None, the next power of two
                  >= len(eeg) is used
        ymax - only used when plotting (currently unused for limits)

        Returns
        -------
        dict with keys "maxFreq", "Power", "Freqs", "s2n", "Power_raw",
        "k", "kernelLen", "kernelSig", "binsPerHz"
        """
        # Get raw power spectrum
        nqLim = 1 / binWidthSecs / 2.0
        origLen = len(eeg)
        if pad2pow is None:
            fftLen = int(2 ** np.ceil(np.log2(origLen)))
        else:
            fftLen = int(np.power(2, pad2pow))
        fftHalfLen = int(fftLen / float(2) + 1)

        fftRes = np.fft.fft(eeg, fftLen)
        # get power density from fft and discard second half of spectrum
        _power = np.power(np.abs(fftRes), 2) / origLen
        power = np.delete(_power, np.s_[fftHalfLen::])
        # one-sided spectrum: double every bin except DC (first) and
        # Nyquist (last); the old [1:-2] also skipped the bin below Nyquist
        power[1:-1] = power[1:-1] * 2

        # calculate freqs and crop spectrum to requested range
        freqs = nqLim * np.linspace(0, 1, fftHalfLen)
        freqs = freqs[freqs <= maxFreq].T
        power = power[0:len(freqs)]

        # smooth spectrum using gaussian kernel
        binsPerHz = (fftHalfLen - 1) / nqLim
        kernelLen = np.round(self.smthKernelWidth * binsPerHz)
        kernelSig = self.smthKernelSigma * binsPerHz
        from scipy import signal

        # scipy.signal.gaussian was removed in SciPy 1.13; the window now
        # lives in scipy.signal.windows
        k = signal.windows.gaussian(int(kernelLen), kernelSig) / (kernelLen / 2 / 2)
        power_sm = signal.fftconvolve(power, k[::-1], mode="same")

        # calculate some metrics
        # find max in theta band
        spectrumMaskBand = np.logical_and(freqs > self.thetaRange[0],
                                          freqs < self.thetaRange[1])
        bandMaxPower = np.max(power_sm[spectrumMaskBand])
        maxBinInBand = np.argmax(power_sm[spectrumMaskBand])
        bandFreqs = freqs[spectrumMaskBand]
        freqAtBandMaxPower = bandFreqs[maxBinInBand]

        # find power in small window around peak and divide by power in rest
        # of spectrum to get snr
        spectrumMaskPeak = np.logical_and(
            freqs > freqAtBandMaxPower - self.sn2Width / 2,
            freqs < freqAtBandMaxPower + self.sn2Width / 2,
        )
        s2n = np.nanmean(power_sm[spectrumMaskPeak]) / np.nanmean(
            power_sm[~spectrumMaskPeak])
        self.freqs = freqs
        self.power_sm = power_sm
        self.spectrumMaskPeak = spectrumMaskPeak
        if plot:
            fig = plt.figure()
            ax = fig.add_subplot(111)
            if ymax is None:
                ymax = np.min([2 * np.max(power), np.max(power_sm)])
                if ymax == 0:
                    ymax = 1
            ax.plot(freqs, power, c=[0.9, 0.9, 0.9])
            ax.plot(freqs, power_sm, "k", lw=2)
            ax.axvline(self.thetaRange[0], c="b", ls="--")
            ax.axvline(self.thetaRange[1], c="b", ls="--")
            _, stemlines, _ = ax.stem([freqAtBandMaxPower], [bandMaxPower],
                                      linefmt="r")
            ax.fill_between(
                freqs,
                0,
                power_sm,
                where=spectrumMaskPeak,
                color="r",
                alpha=0.25,
                zorder=25,
            )
            ax.set_xlabel("Frequency (Hz)")
            ax.set_ylabel("Power density (W/Hz)")
        out_dict = {
            "maxFreq": freqAtBandMaxPower,
            "Power": power_sm,
            "Freqs": freqs,
            "s2n": s2n,
            "Power_raw": power,
            "k": k,
            "kernelLen": kernelLen,
            "kernelSig": kernelSig,
            "binsPerHz": binsPerHz,
        }
        return out_dict
def __init__(
    self,
    lfp_sig: np.array,
    lfp_fs: int,
    xy: np.array,
    spike_ts: np.array,
    pos_ts: np.array,
    pp_config: dict = phase_precession_config,
):
    """
    Set up the phase-precession analysis for one set of spikes.

    Parameters
    ----------
    lfp_sig : np.array
        The LFP signal; passed to getFreqPhase with the band [6, 12] to get
        the filtered signal and its phase.
    lfp_fs : int
        Sampling rate of ``lfp_sig``.
    xy : np.array
        Position data; row 0 is x and row 1 is y.
    spike_ts : np.array
        Spike times, converted to pos-sample indices via getSpikePosIndices.
    pos_ts : np.array
        Position timestamps (stored as ``self._pos_ts``).
    pp_config : dict
        Every entry is set as an attribute on the instance; ``ppm``,
        ``cms_per_bin`` and ``field_smoothing_kernel_len`` are read below.
    """
    # Copy every config entry onto the instance
    for key, value in pp_config.items():
        setattr(self, key, value)
    self._pos_ts = pos_ts

    # Template for the stats values recorded against each regressor
    stats_dict = {
        "values": None,
        "pha": None,
        "slope": None,
        "intercept": None,
        "cor": None,
        "p": None,
        "cor_boot": None,
        "p_shuffled": None,
        "ci": None,
        "reg": None,
    }
    from collections import defaultdict

    regressor_keys = [
        "spk_numWithinRun",
        "pos_exptdRate_cum",
        "pos_instFR",
        "pos_timeInRun",
        "pos_d_cum",
        "pos_d_meanDir",
        "pos_d_currentdir",
        "spk_thetaBatchLabelInRun",
    ]
    # Each regressor gets its own (shallow) copy of stats_dict; any other
    # key looked up later is filled the same way by the defaultdict
    self.regressors = defaultdict(lambda: stats_dict.copy())
    for key in regressor_keys:
        self.regressors[key] = stats_dict.copy()
    self.k = 1000
    self.alpha = 0.05
    self.hyp = 0
    self.conf = True

    # Process the EEG data a bit...
    self.eeg = lfp_sig
    L = LFPOscillations(lfp_sig, lfp_fs)
    filt_sig, phase, _, _ = L.getFreqPhase(lfp_sig, [6, 12], 2)
    self.filteredEEG = filt_sig
    self.phase = phase
    self.phaseAdj = None

    # ... and the position data
    P = PosCalcsGeneric(
        xy[0, :],
        xy[1, :],
        ppm=self.ppm,
        cm=True,
    )
    P.postprocesspos(tracker_params={"AxonaBadValue": 1023})
    # ... do the ratemap creation here once
    R = RateMap(P.xy, P.dir, P.speed)
    R.cmsPerBin = self.cms_per_bin
    R.smooth_sz = self.field_smoothing_kernel_len
    R.ppm = self.ppm
    spk_times_in_pos_samples = self.getSpikePosIndices(spike_ts)
    # spike count per position sample; self.pos_ts is presumably a property
    # over self._pos_ts defined elsewhere in the class - confirm
    spk_weights = np.bincount(spk_times_in_pos_samples,
                              minlength=len(self.pos_ts))
    self.spk_times_in_pos_samples = spk_times_in_pos_samples
    self.spk_weights = spk_weights
    self.RateMap = R  # this will be used a fair bit below
    self.spike_ts = spike_ts