def __init__(self, path=None, file_schema=None, timefreq_path=None, **file_schema_kwargs): """Create a new object to estimate a STRF from a dataset. There are many computation steps which must be done in order. Here is a full pipeline illustrating its use: # Get the files expt_path = expt_path_l[0] expt = STRF.base.Experiment(expt_path) expt.file_schema.timefreq_path = timefreq_dir expt.file_schema.populate() # Load the timefreq and concatenate expt.read_all_timefreq() expt.compute_full_stimulus_matrix() # Load responses and bin expt.read_all_responses() expt.compute_binned_responses() # Grab the stimulus and responses fsm = expt.compute_full_stimulus_matrix() frm = expt.compute_full_response_matrix() """ # Location of data self.path = path if file_schema is None: self.file_schema = STRFlabFileSchema(self.path, **file_schema_kwargs) # Hack to make it load the timefreq files self.file_schema.timefreq_path = timefreq_path self.file_schema.populate() # How to read timefreq files self.timefreq_file_reader = io.read_timefreq_from_matfile
class Experiment:
    """Object encapsulating STRF estimation for a specific dataset.

    There are many computation steps which must be done in order. Here is a
    full pipeline illustrating its use:

        # Get the files
        expt_path = expt_path_l[0]
        expt = STRF.base.Experiment(expt_path)
        expt.file_schema.timefreq_path = timefreq_dir
        expt.file_schema.populate()

        # Load the timefreq and concatenate
        expt.read_all_timefreq()
        expt.compute_full_stimulus_matrix()

        # Load responses and bin
        expt.read_all_responses()
        expt.compute_binned_responses()

        # Grab the stimulus and responses
        fsm = expt.compute_full_stimulus_matrix()
        frm = expt.compute_full_response_matrix()
    """
    def __init__(self, path=None, file_schema=None, timefreq_path=None,
        **file_schema_kwargs):
        """Create a new object to estimate a STRF from a dataset.

        path : location of the dataset on disk
        file_schema : optional pre-built file schema object. If None, a
            STRFlabFileSchema is constructed from `path` and populated.
        timefreq_path : where the timefreq files live; stored on the
            auto-built schema before populating (hack to make it load the
            timefreq files).
        **file_schema_kwargs : passed through to STRFlabFileSchema.
        """
        # Location of data
        self.path = path

        if file_schema is None:
            self.file_schema = STRFlabFileSchema(self.path,
                **file_schema_kwargs)

            # Hack to make it load the timefreq files
            self.file_schema.timefreq_path = timefreq_path
            self.file_schema.populate()
        else:
            # BUGFIX: previously a caller-supplied file_schema was silently
            # dropped and self.file_schema was never assigned, so any later
            # access raised AttributeError. Store it; a caller-provided
            # schema is presumed already configured/populated.
            self.file_schema = file_schema

        # How to read timefreq files
        self.timefreq_file_reader = io.read_timefreq_from_matfile

    def read_timefreq(self, label):
        """Read and return (Pxx, freqs, t) for the stimulus named `label`."""
        filename = self.file_schema.timefreq_filename[label]
        return self.timefreq_file_reader(filename)

    def read_all_timefreq(self, store_intermediates=True):
        """Read timefreq from disk. Store and return.

        Reads all timefreq from self.file_schema. Each consists of
        Pxx, freqs, t. If the freqs is the same for all, then stores in
        self.freqs. Otherwise, self.freqs is None. Same for t.

        Returns: List of Pxx, list of freqs, list of t
        """
        # Load all
        Pxx_l, freqs_l, t_l = zip(*[
            self.read_timefreq(label)
            for label in self.file_schema.timefreq_file_labels])

        # Optionally store
        if store_intermediates:
            self.timefreq_list = Pxx_l
            self.freqs_l = freqs_l
            self.t_l = t_l

        # Test for freqs consistency; store the consensus when they agree
        self.freqs = None
        if allclose_2d(freqs_l):
            self.freqs = np.mean(freqs_l, axis=0)

        # Test for t consistency
        self.t = None
        if allclose_2d(t_l):
            self.t = np.mean(t_l, axis=0)

        return Pxx_l, freqs_l, t_l

    def read_response(self, label):
        """Read and return the folded spike response for stimulus `label`."""
        folded = io.read_single_stimulus(self.file_schema.spike_path, label)
        return folded

    def read_all_responses(self):
        """Reads all response files and stores in self.response_l"""
        # Read in all spikes, recentering
        dfolded = io.read_directory(self.file_schema.spike_path,
            subtract_off_center=True)

        # Order by label
        response_l = [dfolded[label]
            for label in self.file_schema.spike_file_labels]

        self.response_l = response_l
        return response_l

    def compute_binned_responses(self, dilation_before_binning=.99663):
        """Bins the stored responses in the same way as the stimuli.

        The bins are inferred from the binwidth of the timefreq, as stored
        in self.t_l, independently for each stimulus. Optionally, a dilation
        is applied to these bins to convert them into the neural timebase.
        Finally, the values in self.response_l are binned and stored in
        self.binned_response_l

        I also store self.trials_l to identify how many repetitions of each
        timepoint occurred.
        """
        self.binned_response_l = []
        self.trials_l = []

        # Iterate over stimuli
        for folded, t_stim in zip(self.response_l, self.t_l):
            # Get bins from t_stim by recovering original edges
            t_stim_width = np.mean(np.diff(t_stim))
            edges = np.linspace(0, len(t_stim) * t_stim_width,
                len(t_stim) + 1)

            # Optionally apply a kkpandas dilation
            # Spike times are always shorter than behavior times
            edges = edges * dilation_before_binning

            # Bin each, using the same number of bins as in t
            binned = kkpandas.Binned.from_folded(folded, bins=edges)

            # Save the results
            self.binned_response_l.append(binned.rate.values.flatten())
            self.trials_l.append(binned.trials.values.flatten())

    def compute_concatenated_stimuli(self):
        """Returns concatenated spectrograms as (N_freqs, N_timepoints).

        This is really only for visualization, not computation, because
        it doesn't include the delays.
        """
        return np.concatenate(self.timefreq_list, axis=1)

    def compute_concatenated_responses(self):
        """Returns a 1d array of concatenated binned responses"""
        return np.concatenate(self.binned_response_l)

    def compute_full_stimulus_matrix(self, n_delays=3, timefreq_list=None,
        blanking_value=-np.inf):
        """Given a list of spectrograms, returns the full stimulus matrix.

        See concatenate_and_reshape_timefreq for the implementation details.
        This function actually returns a transposed version, more suitable
        for fitting. The shape is (N_timepoints, N_freqs * N_delays), ie,
        (N_constraints, N_inputs)
        """
        # Determine what list to operate on: explicit argument, then the
        # cached attribute, then load from disk.
        if timefreq_list is None:
            # BUGFIX: plain `self.timefreq_list` raised AttributeError when
            # read_all_timefreq had not been called yet, defeating the
            # intended fallback chain below. getattr keeps the fallback.
            timefreq_list = getattr(self, 'timefreq_list', None)
        if timefreq_list is None:
            timefreq_list = self.read_all_timefreq()[0]
        if timefreq_list is None:
            raise ValueError("cannot determine timefreq lists")

        # Concatenate and reshape
        self.full_stimulus_matrix = concatenate_and_reshape_timefreq(
            timefreq_list, n_delays=n_delays,
            blanking_value=blanking_value).T

        # Write out
        return self.full_stimulus_matrix

    def compute_full_response_matrix(self):
        """Returns a response matrix, suitable for fitting.

        Returned array has shape (N_timepoints, 1)
        """
        self.full_response_matrix = \
            self.compute_concatenated_responses()[:, None]
        return self.full_response_matrix